In [None]:
# Mount Google Drive at /content/drive (Colab only); later cells read data from and
# write outputs to folders under /content/drive/MyDrive/Thesis_final.
import google.colab.drive as drive
drive.mount(mountpoint='/content/drive')

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, sampler
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid

def show(img):
    """Draw a channels-first (C, H, W) image tensor on the current matplotlib axes.

    The tensor is moved to the CPU, converted to numpy and permuted to (H, W, C),
    the layout `plt.imshow` expects. Nothing is returned.
    """
    hwc = img.cpu().numpy().transpose(1, 2, 0)
    plt.imshow(hwc)

# Make sure the Drive output folder exists. exist_ok avoids the check-then-create race
# and is a no-op when the folder is already there.
os.makedirs('/content/drive/MyDrive/Thesis_final/', exist_ok=True)

# Reproducibility: seed every RNG the notebook may touch. This covers Python's global
# `random` (the stdlib module is used for shuffling elsewhere in this notebook), numpy,
# and torch (weight init, DataLoader shuffling, reparameterization noise).
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

GPU = True  # set to False to force CPU even when CUDA is available
if GPU:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
    device = torch.device("cpu")
print(f'Using {device}')

In [None]:
# ---- Training hyper-parameters ----
num_epochs = 800
learning_rate = 4e-4
batch_size = 8
latent_dim = 5   # size of the VAE latent vector z
beta = 0.5       # weight on the KL term of the loss (beta-VAE objective)

# No preprocessing: an empty Compose returns each sample unchanged.
transform = transforms.Compose([])

def denorm(x):
    """Map model-space values back to display space before plotting.

    The transform pipeline is empty (no normalisation is applied), so this is the
    identity: `x` is returned unchanged (the very same object).
    """
    return x

In [None]:
import pickle
from torch.utils.data import Dataset
import random

class SpeakersDataset(Dataset):
    """Speaker-level train/test split over the vectors stored in a pickled dict.

    The pickle is expected to hold a dict whose values are the per-speaker samples
    (presumably speaker id -> normalised embedding — confirm against the producer).
    Each item is reshaped to (1, 16, 16), so every value must contain 256 elements.

    Args:
        train: if True keep the first `n_train` shuffled speakers, else the remainder.
        transform: optional callable applied to each raw sample before reshaping.
        path: location of the pickle file (relative to the Colab working directory).
        n_train: number of speakers assigned to the training split.
        seed: seed for the split permutation. The train and test instances load and
            shuffle independently, so they must use the same seed to get complementary,
            non-overlapping splits.
    """

    def __init__(self, train=True, transform=None,
                 path="drive/MyDrive/Thesis_final/VCTK-Corpus/autovc/speakers_norm.pkl",
                 n_train=61, seed=0):
        self.transform = transform

        # NOTE: pickle.load can execute arbitrary code; only load files you trust.
        with open(path, 'rb') as f:
            speakers = pickle.load(f)

        data = list(speakers.values())

        # A private, seeded RNG gives every instance the SAME permutation. The previous
        # unseeded global random.shuffle produced a different permutation per instance,
        # so train and test overlapped and some speakers were never used.
        random.Random(seed).shuffle(data)

        self.data = data[:n_train] if train else data[n_train:]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        if self.transform:
            sample = self.transform(sample)

        # Present each speaker vector as a single-channel 16x16 "image" for the conv VAE.
        return sample.reshape((1, 16, 16))

In [None]:
# Speaker-level splits and their loaders; only the training loader reshuffles each epoch.
train_dat = SpeakersDataset(train=True, transform=transform)
test_dat = SpeakersDataset(train=False, transform=transform)

loader_train = DataLoader(train_dat, batch_size=batch_size, shuffle=True)
loader_test = DataLoader(test_dat, batch_size=batch_size, shuffle=False)

In [None]:
class Reshape1(nn.Module):
    """Flatten all dimensions except the leading batch dimension: (N, ...) -> (N, -1)."""

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, -1)

class Reshape2(nn.Module):
    """Reshape inputs to a fixed target shape chosen at construction time (via `Tensor.view`)."""

    def __init__(self, shape):
        super().__init__()
        # Target shape, unpacked into Tensor.view; may contain one -1 (e.g. (-1, 64, 5, 5)).
        self.shape = shape

    def forward(self, x):
        target = self.shape
        return x.view(*target)

class VAE(nn.Module):
    """Convolutional variational autoencoder for (N, 1, 16, 16) inputs.

    Encoder: BatchNorm -> three 3x3 convs (16x16 -> 14x14 -> 12x12 -> 5x5, 64 channels)
    -> flatten to 1600 -> MLP producing 2 * latent_dim values, split into the posterior
    mean and log-variance. There is no activation between the last conv and the
    first Linear.
    Decoder mirrors it: MLP up to 1600 -> view as (64, 5, 5) -> three transposed convs
    back to (1, 16, 16), squashed into (0, 1) by a sigmoid.
    NOTE(review): because of the sigmoid, target values outside (0, 1) cannot be
    reproduced exactly — confirm the data range.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Layer order/indices are significant: they fix state_dict keys and init order.
        self.encoder = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=0),   # 16x16 -> 14x14
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=0),  # 14x14 -> 12x12
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0),  # 12x12 -> 5x5
            Reshape1(),                                             # -> (N, 64*5*5 = 1600)
            nn.Linear(1600, 120),
            nn.ReLU(),
            nn.Linear(120, 60),
            nn.ReLU(),
            nn.Linear(60, latent_dim * 2),                          # [mean | logvar]
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 60),
            nn.ReLU(),
            nn.Linear(60, 120),
            nn.ReLU(),
            nn.Linear(120, 1600),
            Reshape2((-1, 64, 5, 5)),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=0),  # 5x5 -> 12x12
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=1, padding=0),  # 12x12 -> 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(16, 1, kernel_size=3, stride=1, padding=0),   # 14x14 -> 16x16
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return (mean, logvar), each of shape (N, latent_dim)."""
        stats = self.encoder(x)
        mean, logvar = stats.chunk(2, dim=1)
        return mean, logvar

    def reparametrize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) in training mode; return mu in eval mode."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
        else:
            std, eps = 0, 0
        return mu + eps * std

    def decode(self, z):
        """Map latent codes (N, latent_dim) back to (N, 1, 16, 16) outputs."""
        return self.decoder(z)

    def forward(self, x):
        """Return (reconstruction, mu, logvar) for the input batch."""
        mu, logvar = self.encode(x)
        recon = self.decode(self.reparametrize(mu, logvar))
        return recon, mu, logvar

# Build the VAE on the selected device and report its size.
model = VAE(latent_dim).to(device)
params = sum(p.numel() for p in model.parameters() if p.requires_grad)  # trainable only
print(f"Total number of parameters is: {params}")
print(model)

# Adam over every model parameter at the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
def loss_function_VAE(recon_x, x, mu, logvar, beta):
    """beta-VAE objective.

    Returns a 3-tuple of scalar tensors:
        (reconstruction + beta * KL, reconstruction, beta * KL)
    The reconstruction term is the summed squared error between `recon_x` and `x`. KL is the
    closed-form KL divergence between N(mu, exp(logvar)) and N(0, I), summed over the batch
    and latent dimensions.
    """
    reconstruction = F.mse_loss(recon_x, x, reduction='sum')
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    weighted_kl = beta * kl_divergence
    return reconstruction + weighted_kl, reconstruction, weighted_kl

# Per-epoch (total, reconstruction, beta*KLD) loss triples, each summed over its split.
train_losses = []
test_losses = []

for epoch in range(num_epochs):
    # ---- training pass (stochastic latent, BatchNorm uses/updates batch statistics) ----
    model.train()
    train_loss = train_RL = train_KLD = 0.0

    for batch_idx, data in enumerate(loader_train):
        data = data.to(device)
        optimizer.zero_grad()
        recon_x, mu, logvar = model(data)
        loss, recon_loss, KLD = loss_function_VAE(recon_x, data, mu, logvar, beta)
        train_loss += loss.item()
        train_RL += recon_loss.item()
        train_KLD += KLD.item()

        loss.backward()
        optimizer.step()

        if batch_idx % 250 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(loader_train.dataset),
                100. * batch_idx / len(loader_train),
                loss.item() / len(data)))

    print('Epoch: {} Average loss: {:.4f}'.format(epoch, train_loss / len(loader_train.dataset)))
    print()
    print("="*50)
    print()

    train_losses.append((train_loss, train_RL, train_KLD))

    # ---- evaluation pass ----
    # eval(): BatchNorm uses its running statistics (and does not update them with test
    # data) and reparametrize returns the posterior mean instead of sampling.
    model.eval()
    test_loss = test_RL = test_KLD = 0.0
    with torch.no_grad():
        for test_data in loader_test:
            test_data = test_data.to(device)
            recon_x, mu, logvar = model(test_data)
            loss, recon_loss, KLD = loss_function_VAE(recon_x, test_data, mu, logvar, beta)
            test_loss += loss.item()
            test_RL += recon_loss.item()
            test_KLD += KLD.item()

    test_losses.append((test_loss, test_RL, test_KLD))

    if epoch == num_epochs - 1:
        # Traced while in eval mode. Tracing records the branch taken on `self.training`
        # and the BatchNorm mode, so the saved module is the deterministic inference model.
        with torch.no_grad():
            torch.jit.save(torch.jit.trace(model, data, check_trace=False),
                           '/content/drive/MyDrive/Thesis_final/VAE_model.pth')


In [None]:
# Loss curves per epoch: total, reconstruction and weighted KL, one figure per split.
legend = ["RL+β*KLD","RL","β*KLD"]

for losses, ylabel in ((train_losses, "Train Loss"), (test_losses, "Test Loss")):
    plt.figure()
    plt.plot(losses)
    plt.legend(legend)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)

In [None]:
# ---- Inputs, reconstructions and prior samples from the trained VAE ----
# eval(): BatchNorm uses running statistics and the latent is the posterior mean.
model.eval()

print('Input images')
print('-'*50)

sample_inputs = next(iter(loader_test))
fixed_input = sample_inputs

img = make_grid(denorm(fixed_input))
plt.figure()
show(img)

print('Reconstructed images')
print('-'*50)
with torch.no_grad():
    # Reconstruct exactly the test batch shown above.
    recon_batch = model(fixed_input.to(device))[0].cpu()
    # Recent torchvision rejects make_grid's old `range=` keyword, so it is not passed;
    # the remaining arguments are the defaults, spelled out for clarity.
    recon_batch = make_grid(denorm(recon_batch), nrow=8, padding=2, normalize=False,
                            scale_each=False, pad_value=0)
    plt.figure()
    show(recon_batch)

print('Generated Images')
print('-'*50)
n_samples = 256
# Decode draws from the N(0, I) prior into a 16x16 grid of generated samples.
z = torch.randn(n_samples, latent_dim).to(device)
with torch.no_grad():
    samples = model.decode(z)

    copy_over = samples.cpu()
    samples = samples.cpu()
    samples = make_grid(denorm(samples), nrow=16, padding=2, normalize=False,
                        scale_each=False, pad_value=0)
    plt.figure(figsize=(8, 8))
    show(samples)