In [19]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.distributions import Normal, MultivariateNormal
from torch.utils.data import Dataset, TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data.dataset import random_split
from torch.optim import Adam
from torch.nn import NLLLoss, CrossEntropyLoss
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import plotly.offline as py
import plotly.graph_objs as go
import plotly.express as px
import plotly.figure_factory as ff

import sys
import shutil
import warnings
from zipfile import ZipFile
from pathlib import Path
from typing import Optional, Tuple

In [31]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(device)

cuda


In [78]:
# Plot y = exp((x + 160) / 64 - 5) on 1000 evenly spaced points of [0, 160].
x = np.linspace(start=0, stop=160, num=1000)
exponent = (x + 160) / 64 - 5
y = np.exp(exponent)
trace = [go.Scatter(mode='lines', x=x, y=y)]
figure = go.Figure(data=trace)
figure.show()

In [34]:
class Flatten(nn.Module):
    """Flatten each sample of a batch into a single feature vector.

    ``Flatten(c, h, w)`` reshapes with ``view(-1, c * h * w)``; ``Flatten()``
    reshapes with ``view(x.shape[0], -1)``. Both forms are accepted because this
    notebook defines ``Flatten`` in two cells, and whichever cell ran last
    rebinds the name for every later instantiation.
    """

    def __init__(self, c=None, h=None, w=None):
        super().__init__()
        self.c = c
        self.h = h
        self.w = w

    def forward(self, x):
        if self.c is None or self.h is None or self.w is None:
            # No explicit shape: keep the batch dimension, merge the rest.
            return x.view(x.shape[0], -1)
        return x.view(-1, self.c * self.h * self.w)


class Unflatten(nn.Module):
    """Reshape flat rows back into ``(N, c, h, w)`` feature maps via ``view``."""

    def __init__(self, c, h, w):
        super().__init__()
        self.c, self.h, self.w = c, h, w

    def forward(self, x):
        target_shape = (-1, self.c, self.h, self.w)
        return x.view(*target_shape)


class Encoder(nn.Module):
    """Convolutional VAE encoder that returns the posterior parameters ``(mu, logvar)``.

    Four stages each apply a 5x5 / stride-1 / pad-2 conv (spatial size kept) and a
    4x4 / stride-2 / pad-1 conv (spatial size halved), widening the channels
    32 -> 64 -> 128 -> 256. Each stage ends in BatchNorm2d; the LeakyReLU that
    opens the next stage acts as its activation.

    The output is flattened with ``Flatten(256, 8, 8)``, so after four halvings
    the input is presumably ``(N, image_channels, 128, 128)`` -- TODO confirm.
    """

    def __init__(self, image_channels: int):
        super().__init__()

        self.encoder1 = nn.Sequential(
            nn.Conv2d(image_channels, 32, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(32, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32)
        )
        self.encoder2 = nn.Sequential(
            nn.LeakyReLU(),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64)
        )
        self.encoder3 = nn.Sequential(
            nn.LeakyReLU(),
            nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(128, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128)
        )
        self.encoder4 = nn.Sequential(
            nn.LeakyReLU(),
            nn.Conv2d(128, 256, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(),
            nn.Conv2d(256, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256)
        )
        # Relies on the notebook-level `Flatten` accepting (c, h, w).
        self.flatten = Flatten(256, 8, 8)

    def forward(self, x):
        """Return ``(mu, logvar)``, each of width 256*8*8 // 2 = 8192."""
        z = self.encoder1(x)
        z = self.encoder2(z)
        z = self.encoder3(z)
        z = self.encoder4(z)
        z = self.flatten(z)

        # First half of the flattened features is mu, second half is logvar;
        # both come straight from the last BatchNorm2d, with no activation.
        n, d = z.shape
        mu, logvar = torch.split(z, split_size_or_sections=d//2, dim=1)
        return mu, logvar


class Decoder(nn.Module):
    """Convolutional VAE decoder mapping a flat latent back to an image.

    The latent is reshaped with ``Unflatten(128, 8, 8)`` (so each row has
    128*8*8 = 8192 values). Four 4x4 / stride-2 / pad-1 ConvTranspose2d layers
    each double the spatial size (8 -> 16 -> 32 -> 64 -> 128) while narrowing the
    channels 256 -> 128 -> 64 -> 32. A final 5x5 / stride-1 / pad-2
    ConvTranspose2d (size-preserving) produces ``image_channels`` maps passed
    through Sigmoid, so outputs lie in (0, 1).
    """

    def __init__(self, image_channels: int):
        super().__init__()

        self.decoder1 = nn.Sequential(
            Unflatten(128, 8, 8),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256)
        )
        self.decoder2 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128)
        )
        self.decoder3 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64)
        )
        self.decoder4 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32)
        )
        self.decoder5 = nn.Sequential(
            nn.ReLU(),
            nn.ConvTranspose2d(32, image_channels, kernel_size=5, stride=1, padding=2),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Decode latent rows ``z`` (width 8192) into images with values in (0, 1)."""
        y = self.decoder1(z)
        y = self.decoder2(y)
        y = self.decoder3(y)
        y = self.decoder4(y)
        y = self.decoder5(y)
        return y


class ConvolutionalVAE(nn.Module):
    """Convolutional VAE composed of the notebook's ``Encoder`` and ``Decoder``.

    ``hidden_size`` (128 * 8 * 8 = 8192) is the latent width used by
    ``generate``; it must match the width of ``mu``/``logvar`` returned by the
    encoder and the width the decoder unflattens.
    """

    def __init__(self, channels=3, device=torch.device('cpu')):
        super().__init__()

        # Kept for backward compatibility. Sampling below follows the device the
        # tensors/parameters actually live on, so moving the model with
        # .to()/.cuda() after construction cannot cause a device mismatch.
        self.device = device
        self.hidden_size = 128 * 8 * 8

        self.encoder = Encoder(image_channels=channels)
        self.decoder = Decoder(image_channels=channels)

    def _param_device(self):
        """Device of the first registered parameter, or ``self.device`` if none exist."""
        param = next(self.parameters(), None)
        return self.device if param is None else param.device

    def forward(self, x):
        """Encode ``x``, sample ``z`` with the reparameterisation trick, and decode it.

        Returns:
            ``(reconstruction clamped to [0, 1], mu_z, logvar_z)``.
        """
        mu_z, logvar_z = self.encoder(x)
        # randn_like already matches mu_z's device and dtype.
        eps = torch.randn_like(mu_z)
        z = mu_z + torch.exp(0.5 * logvar_z) * eps
        y = self.decoder(z)

        return y.clamp(0, 1), mu_z, logvar_z

    def generate(self, n):
        """Decode ``n`` latent vectors drawn from N(0, I); output clamped to [0, 1]."""
        z = torch.randn(n, self.hidden_size, device=self._param_device())
        y = self.decoder(z)
        return y.clamp(0, 1)
    

def vae_mse_loss(x, y, mu_z, logvar_z):
    """Return the two VAE loss terms, each summed over every element.

    Returns:
        ``(kldiv_loss, mse_loss)``: the KL divergence between
        N(mu_z, exp(logvar_z)) and N(0, I), and the summed squared error
        between the reconstruction ``y`` and the input ``x``.
    """
    mse_loss = F.mse_loss(y, x, reduction='sum')
    variance = logvar_z.exp()
    kldiv_loss = 0.5 * (variance + mu_z.pow(2) - 1 - logvar_z).sum()
    return kldiv_loss, mse_loss

In [35]:
class Classifier(nn.Module):
    """MLP head: input_size -> 1024 -> 512 -> 16 -> 1 with Tanh between layers.

    The last Linear has no activation, so each row maps to one unbounded scalar.
    """

    def __init__(self, input_size):
        super().__init__()
        widths = [input_size, 1024, 512, 16, 1]
        modules = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if idx > 0:
                modules.append(nn.Tanh())
            modules.append(nn.Linear(fan_in, fan_out))
        # Same module order as a hand-written Sequential (Linear at indices
        # 0, 2, 4, 6), so saved state_dicts keep loading.
        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        out = self.layers(x)
        return out

In [20]:
class DEAMDataset(Dataset):
    """In-memory dataset of pre-extracted audio feature tensors and their targets.

    ``load()`` reads ``audio_samples.pt`` and ``audio_targets.pt`` from
    ``DATASET_PATH``, reshapes the samples to
    ``(N, 1, n_mfcc, frames_per_sec * sample_duration)`` and casts both tensors
    to float32. ``train_test_split()`` returns two new ``DEAMDataset`` objects
    that hold a shuffled, disjoint partition of the loaded rows.
    """

    DATASET_PATH = Path('dataset')

    def __init__(self,
                 sample_rate: int = 20480,
                 n_mfcc: int = 40,
                 frames_per_sec: int = 40,
                 sample_duration: int = 4):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_mfcc = n_mfcc
        self.frames_per_sec = frames_per_sec
        self.sample_duration = sample_duration
        self.hop_length = sample_rate // self.frames_per_sec

        # Filled by load(), or assigned directly on the splits made by train_test_split().
        self.random_state = None
        self.samples = None
        self.targets = None

    def _load_dataset(self, torch_samples_path: Path, torch_targets_path: Path):
        """Read the raw sample and target tensors saved with ``torch.save``."""
        self.samples = torch.load(torch_samples_path)
        self.targets = torch.load(torch_targets_path)

    def load(self):
        """Load the cached tensors from ``DATASET_PATH`` and reshape / cast them."""
        torch_samples_path = DEAMDataset.DATASET_PATH / 'audio_samples.pt'
        torch_targets_path = DEAMDataset.DATASET_PATH / 'audio_targets.pt'
        # NOTE(review): not read anywhere in this class.
        self.random_state = np.random.RandomState(seed=13)

        self._load_dataset(torch_samples_path, torch_targets_path)

        self.samples = self.samples.reshape(-1, 1, self.n_mfcc, self.frames_per_sec * self.sample_duration).float()
        self.targets = self.targets.float()

    def _check_loaded(self):
        """Raise a clear error rather than an opaque failure on ``None`` data."""
        if self.samples is None or self.targets is None:
            raise RuntimeError('DEAMDataset is empty; call load() first')

    def train_test_split(self, test_size=0.2, random_seed=0) -> Tuple[Dataset, Dataset]:
        """Shuffle the row indices with ``random_seed`` and split them into train / test.

        Args:
            test_size: fraction of rows placed in the test split, in [0, 1].
            random_seed: seed for the ``np.random.RandomState`` used to shuffle.

        Returns:
            ``(train_dataset, test_dataset)``, both ``DEAMDataset`` instances. The
            train split gets ``int((1 - test_size) * len(self))`` rows, computed
            after rounding away floating-point noise; the test split gets the rest.

        Raises:
            ValueError: if ``test_size`` is outside [0, 1].
            RuntimeError: if the dataset holds no data.
        """
        if not 0 <= test_size <= 1:
            raise ValueError('test_size must be in [0, 1], got {}'.format(test_size))

        total_size = len(self)
        # Round away float noise before truncating: e.g. (1 - 0.9) * 10 evaluates
        # to 0.9999999999999998, which int() alone would truncate to 0, not 1.
        train_size = int(round((1 - test_size) * total_size, 6))

        ids = np.arange(total_size)
        random = np.random.RandomState(random_seed)
        random.shuffle(ids)

        train_ids, test_ids = ids[:train_size], ids[train_size:]

        train_dataset = DEAMDataset(self.sample_rate, self.n_mfcc, self.frames_per_sec, self.sample_duration)
        train_dataset.samples, train_dataset.targets = self[train_ids]

        test_dataset = DEAMDataset(self.sample_rate, self.n_mfcc, self.frames_per_sec, self.sample_duration)
        test_dataset.samples, test_dataset.targets = self[test_ids]

        return train_dataset, test_dataset

    def __getitem__(self, index):
        return self.samples[index], self.targets[index]

    def __iter__(self):
        for sample, target in zip(self.samples, self.targets):
            yield sample, target

    def __len__(self):
        self._check_loaded()
        return self.samples.shape[0]


In [107]:
class Flatten(nn.Module):
    """Flatten each sample of a batch into a single feature vector.

    ``Flatten(c, h, w)`` reshapes with ``view(-1, c * h * w)``; ``Flatten()``
    reshapes with ``view(x.shape[0], -1)``. Both forms are accepted because this
    notebook defines ``Flatten`` in two cells, and whichever cell ran last
    rebinds the name for every later instantiation.
    """

    def __init__(self, c=None, h=None, w=None):
        super().__init__()
        self.c = c
        self.h = h
        self.w = w

    def forward(self, x):
        if self.c is None or self.h is None or self.w is None:
            # No explicit shape: keep the batch dimension, merge the rest.
            return x.view(x.shape[0], -1)
        return x.view(-1, self.c * self.h * self.w)


class AudioLSTMEncoder(nn.Module):
    """Encode a ``(N, 1, n_mfcc, n_len)`` feature map into an ``(N, n_out)`` vector.

    Each of the ``n_len`` columns is treated as one sequence step. It is passed
    through a small BatchNorm/Conv1d stack, and the resulting sequence goes
    through a 2-layer LSTM. The last LSTM output is projected to ``n_out``
    features in (-1, 1) by Linear -> BatchNorm1d -> Tanh.

    Shape requirements follow from the layers: the channel dim must be 1
    (``Conv1d(in_channels=1)``), the last dim must equal ``n_len`` (it is
    multiplied by ``weights``), and ``n_mfcc`` is assumed even so that the
    stride-2 conv output matches the LSTM input size ``16 * n_mfcc // 2``.

    ``forward`` returns ``(h, reg)``. ``reg`` is ``-mean((h_last - h_prev) ** 2)``
    when that mean is below 0.2, and the int ``0`` otherwise.
    """

    def __init__(self, n_mfcc: int, n_len: int, n_hidden: int, n_out: int, device=torch.device('cpu')):
        super().__init__()

        self.n_mfcc = n_mfcc
        self.n_len = n_len
        self.n_hidden = n_hidden
        self.n_out = n_out

        # Per-column input scaling (all ones). Registered as a non-persistent
        # buffer so Module.to()/.cuda()/.cpu() move it together with the model,
        # while state_dict() stays exactly as before (existing checkpoints load).
        self.register_buffer('weights', torch.ones(n_len, device=device), persistent=False)

        self.transform = nn.Sequential(
            nn.BatchNorm1d(1),
            nn.Conv1d(in_channels=1, out_channels=8, kernel_size=5, stride=1, padding=2),
            nn.LeakyReLU(),
            nn.BatchNorm1d(8),
            nn.Conv1d(in_channels=8, out_channels=16, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(),
            nn.BatchNorm1d(16),
            # torch's built-in flatten (keeps dim 0), so this module does not depend
            # on which notebook-level `Flatten` definition is currently bound.
            nn.Flatten(),
        )
        self.lstm = nn.LSTM(input_size=16 * self.n_mfcc // 2, hidden_size=self.n_hidden, num_layers=2, dropout=0.2, batch_first=True)
        self.out = nn.Sequential(
            nn.Linear(in_features=self.n_hidden, out_features=self.n_out),
            nn.BatchNorm1d(self.n_out),
            nn.Tanh()
        )

    def forward(self, x):
        n, c, h, w = x.shape

        x = x * self.weights
        # (N, C, H, W) -> (N, W, C, H): one (C, H) frame per column.
        x = x.transpose(2, 3).transpose(1, 2)

        frames = x.reshape(n * w, c, h)
        frames = self.transform(frames)

        frames = frames.reshape(n, w, -1)
        lstm_out, _ = self.lstm(frames)

        # NOTE(review): self.out (with BatchNorm1d) is applied twice per forward,
        # so in train mode its running statistics are updated twice per batch.
        h_prev = self.out(lstm_out[:, -2])
        h_last = self.out(lstm_out[:, -1])
        reg = (h_last - h_prev).pow(2).mean()
        return h_last, (-reg if reg < 0.2 else 0)

In [93]:
# Build the full dataset (defaults spelled out) from the cached tensors under DEAMDataset.DATASET_PATH.
dataset = DEAMDataset(sample_rate=20480, n_mfcc=40, frames_per_sec=40, sample_duration=4)
dataset.load()

In [94]:
train_dataset, test_dataset = dataset.train_test_split(test_size=0.2)
# The two splits must partition the full dataset.
assert len(train_dataset) + len(test_dataset) == len(dataset)
len(train_dataset), len(test_dataset)

(45924, 11482)

In [82]:
batch_size = 128

# Training batches are drawn in random order; test batches keep dataset order.
train_sampler = RandomSampler(data_source=train_dataset)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, sampler=train_sampler)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [83]:
# Load the checkpoint onto the CPU so this cell also works without CUDA (a
# checkpoint saved from GPU tensors otherwise fails to deserialize there); the
# classifier is moved to `device` in the next cell.
classifier = Classifier(100)  # input width = the encoder's n_out (100)
classifier_checkpoint = torch.load(str(Path('models', 'classifier_gan.pt')), map_location='cpu')
classifier.load_state_dict(classifier_checkpoint['state_dict'])

<All keys matched successfully>

In [84]:
# The classifier is not optimised in this notebook (the Adam optimizer only
# receives encoder.parameters()), and nothing in the training loop zeroes its
# gradients. Freeze its weights so every loss.backward() does not keep
# accumulating unused .grad tensors. Gradients still flow *through* it to the encoder.
classifier = classifier.to(device)
classifier = classifier.eval()
classifier = classifier.requires_grad_(False)

In [110]:
# LSTM encoder whose 100-d output feeds the classifier, plus its Adam optimizer.
encoder = AudioLSTMEncoder(n_mfcc=40, n_len=160, n_hidden=512, n_out=100, device=device)
encoder = encoder.to(device)
optimizer = Adam(params=encoder.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5)

In [None]:
epochs = 50

# best_loss is compared against the *summed* test loss of an epoch; the number of
# test batches is fixed, so this ranks epochs the same way as the mean would.
best_loss = 1e9
C = 1  # weight of the encoder's `reg` term in the loss

for epoch in tqdm(range(epochs)):
    train_loss, test_loss = 0, 0

    # Follow `device` instead of hard-coding CUDA. This also brings the model
    # back after the checkpoint save below has moved it to the CPU.
    encoder.to(device)
    encoder.train()
    for batch in train_dataloader:
        batch = [tensor.to(device) for tensor in batch]
        samples, targets = batch
        targets = targets[:,0].view(-1,1)  # arousal target

        optimizer.zero_grad()

        h, reg = encoder(samples)
        pred = classifier(h)

        loss = F.mse_loss(pred, targets, reduction='mean') + reg * C
        train_loss += loss.item()

        loss.backward()
        optimizer.step()

    encoder.eval()
    # The test loss includes the same `reg * C` term as the training loss.
    with torch.no_grad():
        for batch in test_dataloader:
            batch = [tensor.to(device) for tensor in batch]
            samples, targets = batch
            targets = targets[:,0].view(-1,1)  # arousal target

            h, reg = encoder(samples)
            pred = classifier(h)

            loss = F.mse_loss(pred, targets, reduction='mean') + reg * C
            test_loss += loss.item()

    print('epoch={},\t'.format(epoch+1),
          'train_loss={},\t'.format(train_loss / len(train_dataloader)),
          'test_loss={}'.format(test_loss / len(test_dataloader)))

    if test_loss < best_loss:
        best_loss = test_loss
        # Saved from the CPU so the checkpoint can be loaded without CUDA.
        torch.save({'state_dict': encoder.cpu().state_dict()}, str(Path('models', 'encoder_rnn_gan_v3.pt')))

HBox(children=(FloatProgress(value=0.0, max=50.0), HTML(value='')))

epoch=1,	 train_loss=0.04330220032991805,	 test_loss=0.04663829964896043
epoch=2,	 train_loss=0.032888638431564345,	 test_loss=0.024296878164427146
epoch=3,	 train_loss=0.02880753344314012,	 test_loss=0.02026207922026515
epoch=4,	 train_loss=0.023003170500226672,	 test_loss=0.017925267563098007
epoch=5,	 train_loss=0.01766419018979358,	 test_loss=0.011262295565878351
epoch=6,	 train_loss=0.01351494583091115,	 test_loss=0.007377280874384774
epoch=7,	 train_loss=0.010531749831860195,	 test_loss=0.004501763027575281
epoch=8,	 train_loss=0.008785131163013098,	 test_loss=0.0028551633831941418
epoch=9,	 train_loss=0.006953997645153557,	 test_loss=0.00328555617791911
epoch=10,	 train_loss=0.006070678749492698,	 test_loss=0.0016897401875919766
epoch=11,	 train_loss=0.005031183507931639,	 test_loss=0.0009434926582293378
epoch=12,	 train_loss=0.004833253570067019,	 test_loss=0.00199113504236771
epoch=13,	 train_loss=0.004229194814764731,	 test_loss=0.0045065494590542385
epoch=14,	 train_loss=0.0