<a href="https://colab.research.google.com/github/FaisalAhmed0/variational-autoencoder/blob/main/VAEs_reimplementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
'''
This is a reimplementation of the variational autoencoder (VAE) described in the original paper "Auto-Encoding Variational Bayes" by Kingma and Welling.
'''

# Imports, Setup, and Data Preparation

In [None]:
import torch
import torch.nn as nn
import torch.optim as opt
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

from torchvision.datasets import MNIST
from torchvision import transforms
from torchvision.utils import make_grid

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat

In [None]:
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
# Mini-batch size passed to every DataLoader created by the loaders below.
batch_size: int = 100

In [None]:
# function to load MNIST dataset
def load_mnist(batch_size, root="./", shuffle=True):
  """Download MNIST (if needed) and wrap the train/test splits in DataLoaders.

  Args:
    batch_size: number of images per mini-batch for both loaders.
    root: directory torchvision uses to store / look up the dataset.
    shuffle: reshuffle the *training* set every epoch. Mini-batch stochastic
      optimisation assumes randomly drawn batches; iterating MNIST in its fixed
      file order every epoch produces correlated batches. The test loader is
      never shuffled, so evaluation order stays deterministic.

  Returns:
    (train_dataloader, test_dataloader). Each batch is an ``(images, labels)``
    pair, with images converted by ``transforms.ToTensor``.
  """
  # A single transform does not need a Compose wrapper.
  to_tensor = transforms.ToTensor()
  mnist = MNIST(root, train=True, download=True, transform=to_tensor)
  mnist_test = MNIST(root, train=False, download=True, transform=to_tensor)
  mnist_dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=shuffle)
  mnist_test_dataloader = DataLoader(mnist_test, batch_size=batch_size)
  return mnist_dataloader, mnist_test_dataloader

In [None]:
# function to load frey face dataset
def load_frey_face(batch_size, file_name="frey_rawface.mat",
                   url="https://cs.nyu.edu/~roweis/data/frey_rawface.mat",
                   train_fraction=0.9, shuffle=True):
  """Load the Frey Face dataset and split it into train/test DataLoaders.

  The ``.mat`` file is fetched from ``url`` only when ``file_name`` does not
  exist yet. This avoids re-downloading on every run, and avoids IPython's
  ``!wget`` shell escape, which is not valid Python.

  Args:
    batch_size: number of images per mini-batch for both loaders.
    file_name: local path of the ``.mat`` file (must contain an ``'ff'`` matrix).
    url: where to download the file from when it is missing.
    train_fraction: leading fraction of the images used for training; the
      remaining images form the test set.
    shuffle: reshuffle the *training* set every epoch; the test loader is never
      shuffled.

  Returns:
    (train_dataloader, test_dataloader). Each batch is ``(images, targets)``.
    Images have shape (batch, 1, 28, 20) and keep the dtype stored in the
    ``.mat`` file. Targets are all zeros, so batches unpack like the MNIST ones.
  """
  import os
  import urllib.request

  if not os.path.exists(file_name):
    # Download to a temporary name first so an interrupted transfer never
    # leaves a truncated file that later runs would try to load.
    tmp_name = file_name + ".part"
    urllib.request.urlretrieve(url, tmp_name)
    os.replace(tmp_name, file_name)

  frey_face_mat = loadmat(file_name)  # load the mat file
  # Each column of 'ff' is treated as one flattened image: transpose so images
  # are rows, then restore a (channel=1, 28, 20) layout.
  frey_face_input = torch.tensor(frey_face_mat['ff'].T.reshape(-1, 1, 28, 20))
  dummy_targets = torch.zeros(frey_face_input.shape[0])

  train_size = int(train_fraction * frey_face_input.shape[0])
  frey_face = TensorDataset(frey_face_input[:train_size], dummy_targets[:train_size])
  frey_face_test = TensorDataset(frey_face_input[train_size:], dummy_targets[train_size:])

  frey_face_dataloader = DataLoader(frey_face, batch_size=batch_size, shuffle=shuffle)
  frey_face_test_dataloader = DataLoader(frey_face_test, batch_size=batch_size)
  return frey_face_dataloader, frey_face_test_dataloader

In [None]:
def plot_grid(dataloader):
  """Display the first batch yielded by ``dataloader`` as one image grid.

  The loader must yield ``(images, targets)`` pairs; the targets are ignored.
  """
  images, _targets = next(iter(dataloader))
  tiled = make_grid(images)
  plt.figure(figsize=(10, 10))
  # make_grid returns channels-first; imshow wants height x width x channels.
  plt.imshow(tiled.permute(1, 2, 0))

In [None]:
# load the data: train/test DataLoaders for MNIST and Frey Face, both built
# with the batch size defined above.
mnist, mnist_test = load_mnist(batch_size)
freyface, freyface_test = load_frey_face(batch_size)

In [None]:
# plot a grid of images: one batch from the MNIST training loader, as a
# visual sanity check of the data pipeline.
plot_grid(mnist)

In [None]:
# Same sanity check for one batch from the Frey Face training loader.
plot_grid(freyface)

# Model Architecture 

In [None]:
class Encoder(nn.Module):
  '''
  Recognition model q(z|x): one tanh hidden layer shared by two linear heads,
  which output the mean and the log-variance of the latent code z.
  '''
  def __init__(self, input_size, hidden_size, bottleneck):
    super().__init__()
    self.linear1 = nn.Linear(input_size, hidden_size)
    self.mean = nn.Linear(hidden_size, bottleneck)
    # Despite its name this head produces the *log*-variance (forward returns
    # it as ``log_var``). The attribute name is kept so state_dict keys stay
    # compatible with existing checkpoints.
    self.var = nn.Linear(hidden_size, bottleneck)

    # Weights ~ N(0, std=0.01); biases keep nn.Linear's default init.
    # Same order as before, so seeded runs produce identical parameters.
    for layer in (self.linear1, self.mean, self.var):
      nn.init.normal_(layer.weight, mean=0.0, std=0.01)

  def forward(self, x):
    """Return ``(mean, log_var)`` of q(z|x).

    Args:
      x: batch of flattened inputs — assumed shape (batch, input_size).
    """
    # Evaluate the shared hidden layer once; computing it separately for each
    # head doubled the forward and backward cost of linear1.
    hidden = torch.tanh(self.linear1(x))
    return self.mean(hidden), self.var(hidden)

In [None]:
class Decoder(nn.Module):
  '''
  Generative model p(x|z): a tanh hidden layer followed by a linear layer whose
  output is optionally passed through a caller-supplied activation.
  '''
  def __init__(self, bottleneck, hidden_size, input_size):
    super().__init__()
    self.linear1 = nn.Linear(bottleneck, hidden_size)
    self.mean = nn.Linear(hidden_size, input_size)

    # Weights ~ N(0, std=0.01), initialised hidden layer first, then output.
    for layer in (self.linear1, self.mean):
      nn.init.normal_(layer.weight, mean=0.0, std=0.01)

  def forward(self, x, output_activation=None):
    """Decode latent codes ``x``.

    Args:
      x: latent codes, assumed shape (batch, bottleneck).
      output_activation: optional callable applied to the output (e.g.
        ``torch.sigmoid``); skipped when falsy.
    """
    hidden = torch.tanh(self.linear1(x))
    decoded = self.mean(hidden)
    return output_activation(decoded) if output_activation else decoded

# Loss function and Training loop

In [None]:
def vae_loss(logvar_z, mean_z, output, target, size, batch_size, mse=True):
  """Negative variational lower bound averaged over a mini-batch.

  Args:
    logvar_z: log-variance of q(z|x); summed over dim 1, assumed (batch, latent).
    mean_z: mean of q(z|x), same shape as ``logvar_z``.
    output: decoder reconstruction.
    target: the (flattened) input the reconstruction is compared to.
    size: dataset size — accepted but not used by the computation.
    batch_size: divisor that turns the summed loss into a per-sample value.
    mse: if True, summed squared error; otherwise summed binary cross-entropy
      (``output`` must then be probabilities).

  Returns:
    Scalar tensor ``(KL + reconstruction_loss) / batch_size``.
  """
  # Closed-form KL divergence between the Gaussian posterior and the N(0, I)
  # prior: summed over latent dimensions, then over the batch.
  kl_terms = (1 + logvar_z - mean_z.pow(2) - logvar_z.exp()).sum(dim=1)
  kl_divergence = -0.5 * kl_terms.sum()

  reconstruction = F.mse_loss if mse else F.binary_cross_entropy
  reconstruction_loss = reconstruction(output, target, reduction="sum")

  total = kl_divergence + reconstruction_loss
  return (1 / batch_size) * total

In [None]:
# simple function to implement the reparametrization trick
def reparametrization(mean, logv):
  """Draw z ~ N(mean, exp(logv)) as a differentiable function of its parameters.

  z = mean + eps * std with eps ~ N(0, I), so gradients flow into ``mean`` and
  ``logv`` while the randomness is isolated in ``eps``.

  Args:
    mean: Gaussian mean.
    logv: Gaussian log-variance, broadcast-compatible with ``mean``.

  Returns:
    A sample shaped like ``mean`` (after broadcasting), on ``mean``'s device.
  """
  # randn_like already matches mean's shape, dtype and device. Forcing the
  # notebook-global ``device`` caused a device mismatch for CPU tensors on a
  # CUDA machine.
  eps = torch.randn_like(mean)
  # exp(0.5 * logv) == sqrt(exp(logv)), but stays finite for large log-variances
  # (exp(logv) overflows float32 to inf once logv exceeds ~88).
  std = torch.exp(0.5 * logv)
  return mean + eps * std

In [None]:
def _flatten_batch(img, data):
  """Flatten an image batch to (batch, features) and move it to ``device``.

  Frey Face batches are cast to float32 because that loader yields the raw
  ``.mat`` values; other datasets keep their dtype.
  """
  x = img.flatten(start_dim=1)
  if data == "freyface":
    x = x.to(torch.float32)
  return x.to(device)


def _vae_forward(encoder, decoder, x, activation):
  """Encode ``x``, sample z with the reparametrization trick, decode it.

  Returns ``(mu, logv, output)``; a sigmoid is applied to the decoder output
  when ``activation`` is truthy.
  """
  mu, logv = encoder(x)
  z = reparametrization(mu, logv)
  output = decoder(z, torch.sigmoid) if activation else decoder(z)
  return mu, logv, output


def train(encoder, decoder, loss, optimizer, dataloader, epochs, dataset_size, testloader, channels=1, height=28, width=28, plot=True, mse=False,  activation=True, data="mnist", plot_freq=10):
  """Train a VAE and evaluate it on ``testloader`` after every epoch.

  Args:
    encoder, decoder: the Encoder / Decoder modules.
    loss: loss *function* with the signature of ``vae_loss``:
      ``(logvar_z, mean_z, output, target, size, batch_size, mse=...)``.
    optimizer: optimizer over the encoder and decoder parameters.
    dataloader, testloader: loaders yielding ``(images, targets)`` batches.
    epochs: number of passes over ``dataloader``.
    dataset_size: forwarded to ``loss`` as its ``size`` argument.
    channels, height, width: layout used to reshape reconstructions for display.
    plot: show target / reconstruction grids every ``plot_freq`` epochs.
    mse: use the squared-error reconstruction loss instead of binary cross-entropy.
    activation: apply a sigmoid to the decoder output.
    data: ``"freyface"`` casts inputs to float32; anything else uses them as-is.
    plot_freq: epoch interval between reconstruction snapshots.

  Returns:
    ``(losses, test_losses, target_grid, output_grid)``.
    - ``losses``, ``test_losses``: per-epoch *negated* per-sample average loss
      (i.e. average lower bound) over the training / test set, as floats.
    - ``target_grid``, ``output_grid``: the last snapshot grids built, or
      ``None`` if fewer than ``plot_freq`` epochs ran.
  """
  loss_fn = loss  # `loss` is the loss function; batch values get their own names
  losses = []
  test_losses = []
  # Initialised up front so the return works even when no snapshot was taken
  # (previously an UnboundLocalError when epochs < plot_freq).
  target_grid = output_grid = None

  for epoch in range(epochs):
    # ---- training pass ----
    train_total, train_count = 0.0, 0
    for img, _ in dataloader:
      x = _flatten_batch(img, data)
      mu, logv, output = _vae_forward(encoder, decoder, x, activation)
      batch_loss = loss_fn(logv, mu, output, x, dataset_size, len(img), mse=mse)
      optimizer.zero_grad()
      batch_loss.backward()
      optimizer.step()
      # .item() gives a plain float, so no autograd graph is kept alive and the
      # history is directly plottable. Weight by batch size because the loss is
      # a per-sample average and the last batch may be smaller.
      train_total += batch_loss.item() * len(img)
      train_count += len(img)
    train_loss = train_total / max(train_count, 1)
    losses.append(-train_loss)

    # ---- reconstruction snapshot of the last training batch ----
    if (epoch + 1) % plot_freq == 0 and train_count > 0:
      targets = img[:10]
      output_reshaped = output.reshape(-1, channels, height, width)[:10].detach().cpu()
      target_grid = make_grid(targets.detach().cpu(), nrow=10)
      if mse:
        output_grid = make_grid(output_reshaped.to(torch.int32), nrow=10)
      else:
        output_grid = make_grid(output_reshaped, nrow=10)
      if plot:
        plt.figure(figsize=(15, 10))
        plt.imshow(target_grid.permute(1, 2, 0))
        plt.figure(figsize=(15, 10))
        plt.imshow(output_grid.permute(1, 2, 0))
        plt.show()

    # ---- evaluation: average over the *whole* test set, not its last batch ----
    test_total, test_count = 0.0, 0
    with torch.no_grad():
      for img, _ in testloader:
        x = _flatten_batch(img, data)
        mu, logv, output = _vae_forward(encoder, decoder, x, activation)
        batch_test_loss = loss_fn(logv, mu, output, x, dataset_size, len(img), mse=mse)
        test_total += batch_test_loss.item() * len(img)
        test_count += len(img)
    test_loss = test_total / max(test_count, 1)
    test_losses.append(-test_loss)

    print(f"Epoch: {epoch+1}, train loss: {train_loss}, test loss: {test_loss}")

  return losses, test_losses, target_grid, output_grid

# Testing the Implementation

In [None]:
# Hyper-parameters for the quick MNIST sanity-check run below.
hidden_size = 500  # units in the tanh hidden layer of encoder and decoder
bottleneck = 5     # latent-space dimensionality
input_size = 784   # flattened 28 x 28 image
stepsize = 0.01    # Adagrad learning rate
epochs = 100
# TODO: add the parameters for weight initialization

In [None]:
encoder = Encoder(input_size, hidden_size, bottleneck).to(device) # define the encoder
decoder = Decoder(bottleneck, hidden_size, input_size).to(device) # define the decoder
# NOTE(review): weight_decay=1 is a very strong L2 penalty compared with a loss
# that vae_loss divides by the batch size; the experiment() runs below use no
# weight decay at all. Confirm this value is intended.
optimizer = opt.Adagrad(list(encoder.parameters()) + list(decoder.parameters()) , lr=stepsize, weight_decay=1) # define the optimizer

# Quick sanity-check training run on MNIST; the return value is not stored.
train(encoder, decoder, vae_loss, optimizer, mnist, epochs, dataset_size=60000, testloader=mnist_test)

# Experimental Setup

In [None]:
def experiment(epochs,input_size, hidden_size, bottleneck, height=None, width=None, plot_freq=10, dataset_size=None, lr=None):
  """Build a fresh encoder/decoder pair and train it on MNIST or Frey Face.

  Frey Face is used when both ``height`` and ``width`` are given. It is trained
  with the squared-error loss and no output activation. Otherwise MNIST is
  used, with binary cross-entropy and a sigmoid output.

  Args:
    epochs: number of training epochs.
    input_size, hidden_size, bottleneck: network dimensions.
    height, width: image size used to reshape Frey Face reconstructions.
    plot_freq: epoch interval between reconstruction snapshots.
    dataset_size: value forwarded to ``train``; defaults to the number of
      examples in the selected training loader.
    lr: Adagrad learning rate; defaults to the notebook-level ``stepsize``.

  Returns:
    ``(losses, test_losses, target_grid, output_grid)`` as returned by ``train``.
  """
  encoder = Encoder(input_size, hidden_size, bottleneck).to(device) # define the encoder
  decoder = Decoder(bottleneck, hidden_size, input_size).to(device) # define the decoder

  if lr is None:
    lr = stepsize  # notebook-level hyper-parameter set by the experiment cells
  optimizer = opt.Adagrad(list(encoder.parameters()) + list(decoder.parameters()), lr=lr) # define the optimizer

  use_frey_face = height is not None and width is not None
  if dataset_size is None:
    # Derive it from the data instead of a notebook global: `dataset_size` used
    # to be defined only in the Frey Face cell, so the MNIST runs (which come
    # first) raised NameError.
    dataset_size = len((freyface if use_frey_face else mnist).dataset)

  if use_frey_face:
    return train(encoder, decoder, vae_loss, optimizer, freyface, epochs, dataset_size=dataset_size, plot=True, testloader=freyface_test, height=height, width=width, data="freyface", mse=True, plot_freq=plot_freq, activation=False)
  return train(encoder, decoder, vae_loss, optimizer, mnist, epochs, dataset_size=dataset_size, plot=True, testloader=mnist_test, activation=True, mse=False, plot_freq=plot_freq)

In [None]:
def plot_loss(loss, loss_test, n, data="MNIST", samples_per_epoch=10**6):
  """Plot per-epoch train/test curves on a log-scaled x axis and save the figure.

  Args:
    loss, loss_test: per-epoch values, as floats or 0-d tensors; both must
      have the same length.
    n: latent-space size, used only in the saved file name.
    data: dataset name, used in the saved file name.
    samples_per_epoch: training samples per epoch, used to label the x axis
      ("# Training samples evaluated"). The default keeps the previous
      hard-coded 10**6.
  """
  # float() accepts Python numbers and 0-d tensors, including tensors on the
  # GPU or still attached to the autograd graph, which matplotlib cannot convert.
  train_values = [float(v) for v in loss]
  test_values = [float(v) for v in loss_test]
  x_values = [epoch * samples_per_epoch for epoch in range(1, len(train_values) + 1)]

  plt.figure()  # draw into a fresh figure rather than whatever is current
  plt.plot(x_values, train_values, '-r', label="AEVB (train)")
  plt.plot(x_values, test_values, '--r', label="AEVB (test)")
  plt.xscale('log')
  plt.xlabel("# Training samples evaluated")
  plt.ylabel("Loss")
  plt.legend()
  plt.savefig(f"{data} N={n}")

# MNIST Experiment

In [None]:
# Experiment setup for MNIST
# Network parameters
hidden_size = 500
input_size = 784
# different sizes of the latent space
N = [3, 5, 10, 20, 200]
epochs = 10
stepsize = 0.03
# Size of the MNIST training set. Defined here so a notebook-level
# `dataset_size` exists for these runs; its only other definition is in the
# Frey Face cell below, which has not run yet when executing top to bottom.
dataset_size = len(mnist.dataset)
for bottleneck in N:
  loss, test_loss, data, output = experiment(epochs, input_size, hidden_size, bottleneck)
  if bottleneck == N[0]:
    print("Original images")
    plt.figure(figsize=(15, 10))
    plt.imshow(data.permute(1, 2, 0))
    plt.savefig("original_image")
    plt.show()
  print(f"MNIST Image Generated with latent space size of {bottleneck}")
  plt.figure(figsize=(15, 10))
  plt.imshow(output.permute(1, 2, 0))
  plt.savefig(f"MNIST Image Generated with latent space size of {bottleneck}")
  plt.show()
  print(f"Losses for N={bottleneck}")
  plot_loss(loss, test_loss, bottleneck)

# Experiments (Frey Face)

In [None]:
# Experiment setup for Frey Face
# Network parameters
hidden_size = 100
input_size = 560  # 28 x 20 pixels
# different sizes of the latent space
N = [2, 5, 10, 20]
epochs = 5000
# Size of the Frey Face training split, taken from the loader instead of a
# hard-coded count so it always matches the train/test split actually used.
dataset_size = len(freyface.dataset)
stepsize = 0.1
for bottleneck in N:
  loss, test_loss, data, output = experiment(epochs, input_size, hidden_size, bottleneck, height=28, width=20, plot_freq=100)
  if bottleneck == N[0]:
    print("Original images frey face")
    plt.figure(figsize=(15, 10))
    plt.imshow(data.permute(1, 2, 0))
    # Distinct file name: the MNIST cell already saves "original_image".
    plt.savefig("frey face original_image")
    plt.show()
  print(f"frey face Image Generated with latent space size of {bottleneck}")
  plt.figure(figsize=(15, 10))
  plt.imshow(output.permute(1, 2, 0))
  plt.savefig(f"frey face Image Generated with latent space size of {bottleneck}")
  plt.show()
  print(f"frey face Losses for N={bottleneck}")
  plot_loss(loss, test_loss, bottleneck, data="Frey face")