## Variational autoencoder in PyTorch

The following cells contain the code for building a variational autoencoder in PyTorch and training it on the MNIST dataset, along with some visualisations and analysis. Running this in Google Colab is recommended, as training without a GPU will take a very long time.

In [None]:
# Get the required packages
# (lines starting with "!" are run as shell commands by the notebook)
!pip install -q torch torchvision altair matplotlib pandas
# Helper repository that provides plot_tsne, used by the t-SNE plotting cell
!git clone -q https://github.com/afspies/icl_dl_cw2_utils
from icl_dl_cw2_utils.utils.plotting import plot_tsne
# Colab-only IPython extension; presumably enables interactive table output -
# NOTE(review): confirm it is still needed, nothing below references it directly
%load_ext google.colab.data_table

In [None]:
# Mount google drive
from google.colab import drive
# Later cells write images, plots and the traced model under
# /content/drive/MyDrive/VAE/, so the mount point must be /content/drive
drive.mount('/content/drive') # Outputs will be saved in google drive

In [None]:
# Setting up
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, sampler
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import torch.nn.functional as F
import matplotlib.pyplot as plt

def show(img):
    """Draw a 3-D image tensor on the current matplotlib axes.

    The tensor is copied to host memory and its first axis is moved to the
    end (axis order 0,1,2 -> 1,2,0) before being handed to ``plt.imshow``.
    """
    channels_last = img.cpu().numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)

# Output directory on the mounted Drive. exist_ok=True makes the call
# idempotent and avoids the check-then-create race of a separate exists() test.
os.makedirs('/content/drive/MyDrive/VAE/', exist_ok=True)

# Set a random seed to ensure that the results are reproducible.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True
torch.manual_seed(0)

GPU = True # Choose whether to use GPU
# Fall back to the CPU when the GPU is disabled or no CUDA device is present.
device = torch.device("cuda" if GPU and torch.cuda.is_available() else "cpu")
print(f'Using {device}')

In [None]:
# Necessary Hyperparameters
num_epochs = 20          # passes over the training set in each training cell
learning_rate = 0.0005   # step size passed to the Adam optimizer
batch_size = 64          # used for both the train and the test DataLoader
latent_dim = 24     # Choose a value for the size of the latent space

# Additional Hyperparameters
# Width of the widest hidden layer in the VAE encoder and decoder
hidden_layer = 400

# (Optionally) Modify transformations on input
# transform = transforms.Compose([
#     transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))
# ])

# ToTensor only, without Normalize: the decoder ends in a Sigmoid and the
# reconstruction loss is binary cross-entropy, so inputs must stay in [0, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
])

# (Optionally) Modify the network's output for visualizing your images
# def denorm(x):
#     x = x*0.3081 + 0.1307
#     return x

def denorm(x):
    """Undo the input normalisation before visualising images.

    The active ``transform`` applies no normalisation, so this is the
    identity: the argument is handed back untouched.
    """
    unchanged = x
    return unchanged

In [None]:
# Load dataset
# Only the training split passes download=True; the test split is read from
# the same "data/" directory (presumably populated by that download - confirm).
train_dat = datasets.MNIST(
    "data/", train=True, download=True, transform=transform
)
test_dat = datasets.MNIST("data/", train=False, transform=transform)

# Training batches are reshuffled; the test loader keeps dataset order so the
# same images are seen on every evaluation.
loader_train = DataLoader(train_dat, batch_size, shuffle=True)
loader_test = DataLoader(test_dat, batch_size, shuffle=False)

# Don't change 
# First 32 images of the first (unshuffled) test batch, saved as a reference
# grid to compare reconstructions against.
sample_inputs, _ = next(iter(loader_test))
fixed_input = sample_inputs[:32, :, :, :]
save_image(fixed_input, '/content/drive/MyDrive/VAE/image_original.png')

In [None]:
# Define VAE

class VAE(nn.Module):
    """Fully-connected variational autoencoder for 1x28x28 images.

    Encoder: 784 -> hidden_dim -> 100, followed by two linear heads that
    produce the mean and log-variance of the approximate posterior.
    Decoder: latent_dim -> 100 -> hidden_dim -> 784 with a final Sigmoid, so
    reconstructions lie in [0, 1].

    Note: a convolutional variant (four Conv2d/BatchNorm2d/LeakyReLU stages,
    1 -> 16 -> 32 -> 64 -> 128 channels, mirrored by ConvTranspose2d layers)
    was also tried, but the fully-connected model performed better, so only
    the fully-connected architecture is kept here.

    Args:
        latent_dim: size of the latent space.
        hidden_dim: width of the widest hidden layer. When ``None`` (the
            default) the notebook-level ``hidden_layer`` hyperparameter is
            used, which keeps ``VAE(latent_dim)`` behaving as before.
    """

    def __init__(self, latent_dim, hidden_dim=None):
        super().__init__()

        # Backward-compatible fallback to the global hyperparameter; passing
        # hidden_dim explicitly makes the class usable outside the notebook.
        if hidden_dim is None:
            hidden_dim = hidden_layer

        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim

        # Layers are created in the same order as before so that a fixed
        # random seed yields the same initial weights.
        self.encoder = nn.Sequential(
            nn.Linear(28*28, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 100),
            nn.LeakyReLU(),
        )

        self.fc_mean = nn.Linear(100, latent_dim)
        self.fc_logvar = nn.Linear(100, latent_dim)

        self.decoder = nn.Sequential(
            nn.Linear(self.latent_dim, 100),
            nn.LeakyReLU(),
            nn.Linear(100, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 28*28),
            nn.Sigmoid()
        )

    def encode(self, x):
        """Map a batch of images to posterior parameters.

        Args:
            x: tensor whose trailing dimensions flatten to 784 values per
                sample, e.g. ``(batch, 1, 28, 28)``.

        Returns:
            ``(mu, logvar)``, each of shape ``(batch, latent_dim)``.
        """
        # flatten (rather than view) also accepts non-contiguous inputs
        x = torch.flatten(x, start_dim=1)
        x = self.encoder(x)
        mu = self.fc_mean(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Drawing the noise separately keeps the sample differentiable with
        respect to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Decode latent vectors ``(batch, latent_dim)`` into images.

        Returns:
            tensor of shape ``(batch, 1, 28, 28)`` with values in [0, 1].
        """
        z = self.decoder(z)
        z = z.view(z.shape[0], 1, 28, 28)
        return z

    def forward(self, x):
        """Encode, sample and decode.

        Returns:
            ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        out = self.decode(z)
        return out, mu, logvar


In [None]:
# Instantiate the VAE on the selected device and report its size
model = VAE(latent_dim).to(device)
trainable_sizes = (p.numel() for p in model.parameters() if p.requires_grad)
params = sum(trainable_sizes)
print(f"Total number of parameters is: {params}")
print(model)
# Adam over all model parameters, using the notebook-level learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Define functions

def loss_function_VAE(recon_x, x, mu, logvar, beta = 1):
    """Beta-weighted VAE objective, summed over the whole batch.

    Args:
        recon_x: reconstruction with values in [0, 1], same shape as ``x``.
        x: target images with values in [0, 1].
        mu: posterior means.
        logvar: posterior log-variances.
        beta: weight applied to the KL term.

    Returns:
        ``(loss, bce, kl_divergence)`` where ``loss = bce + beta * kl``.
        All three are scalar tensors using sum (not mean) reduction.
    """
    # Reconstruction term: binary cross-entropy summed over pixels and samples
    reconstruction = F.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over dims and samples
    kl_terms = 1 + logvar - mu**2 - logvar.exp()
    kl = -0.5 * torch.sum(kl_terms)

    total = reconstruction + beta * kl
    return total, reconstruction, kl


def test_part(loader, model, beta = 1):
    """Evaluate ``model`` on ``loader`` and report per-batch average losses.

    The model is switched to eval mode and left there; callers that go on
    training must call ``model.train()`` again.

    Args:
        loader: iterable of ``(images, labels)`` batches; labels are ignored.
        model: module whose forward returns ``(recon_x, mu, logvar)``.
        beta: KL weight forwarded to ``loss_function_VAE``.

    Returns:
        ``(total_loss, recon_loss, kl_div)`` as Python floats. Each is the
        batch-summed loss averaged over the number of batches. Plain floats
        (rather than device tensors) can be stored and plotted directly.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # running sums of the batch-summed losses
    recon_loss = 0.0
    kl_div = 0.0
    total_loss = 0.0
    count = 0
    model.eval()
    with torch.no_grad():
        for x, _ in loader:
            x = x.to(device = device)
            # pass input through the model
            recon_x, mu, logvar = model(x)
            loss, rec_loss, kld = loss_function_VAE(recon_x, x, mu, logvar, beta = beta)
            # .item() detaches from the device so no tensors are accumulated
            recon_loss += rec_loss.item()
            kl_div += kld.item()
            total_loss += loss.item()
            count += 1

    if count == 0:
        raise ValueError("test_part received an empty loader")

    print(f'test loss: {total_loss/count:.5f}')
    return (total_loss/count), (recon_loss/count), (kl_div/count)


In [None]:
# Training the model (beta = 1)
# Loss histories for plotting, sampled every `print_every` iterations.
# Plain Python floats are stored: appending the loss tensors themselves would
# keep every autograd graph alive and cannot be plotted by matplotlib.
print_every = 200
train_losses = []
train_recon_losses = []
train_klds = []
test_losses = []
test_recon_losses = []
test_klds = []

for epoch in range(num_epochs):

    for t, (data, _) in enumerate(loader_train):
        # test_part() leaves the model in eval mode, so re-enable train
        # mode at the start of every step
        model.train()
        data = data.to(device = device)
        # train the model
        recon_x, mu, logvar = model(data)
        loss, recon_loss, kld = loss_function_VAE(recon_x, data, mu, logvar, beta = 1)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if t%print_every == 0:
            print(f'Epoch:{epoch}, Iteration:{t}, Loss = {loss.item():.5f}')
            train_losses.append(loss.item())
            train_recon_losses.append(recon_loss.item())
            train_klds.append(kld.item())
            aa,bb,cc  = test_part(loader_test, model, beta = 1)
            # float() accepts both 0-dim tensors and plain numbers
            test_losses.append(float(aa))
            test_recon_losses.append(float(bb))
            test_klds.append(float(cc))

    # save the model (traced with the last training batch) after the final epoch
    if epoch == num_epochs - 1:
        with torch.no_grad():
            torch.jit.save(torch.jit.trace(model, (data), check_trace=False),
                '/content/drive/MyDrive/VAE/VAE_model.pth')


In [None]:
# experimenting with other values of beta
# Same procedure as the previous cell, but on a fresh model with its own
# optimizer so the beta = 1 model is left untouched.

betaa = 0.4
model2 = VAE(latent_dim).to(device)
# The existing `optimizer` only holds the parameters of `model`; model2 needs
# its own optimizer or it would never be updated.
optimizer2 = torch.optim.Adam(model2.parameters(), lr=learning_rate)

# Loss histories (plain floats) for plotting
train_losses2 = []
train_recon_losses2 = []
train_klds2 = []
test_losses2 = []
test_recon_losses2 = []
test_klds2 = []

for epoch in range(num_epochs):
    for t, (data, _) in enumerate(loader_train):
        # test_part() leaves the model in eval mode
        model2.train()
        data = data.to(device = device)

        recon_x, mu, logvar = model2(data)
        loss, recon_loss, kld = loss_function_VAE(recon_x, data, mu, logvar, beta = betaa)
        optimizer2.zero_grad()
        loss.backward()
        optimizer2.step()

        if t%print_every == 0:
            print(f'Epoch:{epoch}, Iteration:{t}, Loss = {loss.item():.5f}')
            train_losses2.append(loss.item())
            train_recon_losses2.append(recon_loss.item())
            train_klds2.append(kld.item())
            aa,bb,cc  = test_part(loader_test, model2, beta = betaa)
            # float() accepts both 0-dim tensors and plain numbers
            test_losses2.append(float(aa))
            test_recon_losses2.append(float(bb))
            test_klds2.append(float(cc))

In [None]:
# Plotting loss
# Loss curves with beta = 1: training history on the left, test on the right
iterations = list(range(len(train_losses)))

fig = plt.figure()

curve_names = ("total_loss", "recon_loss", "kl_loss")
panels = (
    ("train, Beta = 1", (train_losses, train_recon_losses, train_klds)),
    ("test, Beta = 1", (test_losses, test_recon_losses, test_klds)),
)

axes = []
for position, (title, curves) in enumerate(panels, start=1):
    ax = fig.add_subplot(1, 2, position)
    for values, name in zip(curves, curve_names):
        ax.plot(iterations, values, label = name)
    # the losses span orders of magnitude, hence the log scale
    ax.set_yscale('log')
    ax.set_title(title)
    ax.legend(loc='best')
    axes.append(ax)
ax1, ax2 = axes

plt.savefig("/content/drive/MyDrive/VAE/loss.png")

In [None]:
# loss curves with beta less than 1
# here is an example of a loss plot with beta = 0.4
iterations = list(range(len(train_losses2)))

fig = plt.figure()

# left panel: training history of the beta experiment
ax1 = fig.add_subplot(1,2,1)
ax1.plot(iterations, train_losses2, label = "total_loss")
ax1.plot(iterations, train_recon_losses2, label = "recon_loss")
ax1.plot(iterations, train_klds2, label = "kl_loss")
ax1.set_yscale('log')
ax1.set_title(f"train, Beta = {betaa}")
ax1.legend(loc='best')

# right panel: test history of the SAME experiment (the *2 lists, not the
# beta = 1 histories)
ax2 = fig.add_subplot(1,2,2)
ax2.plot(iterations, test_losses2, label = "total_loss")
ax2.plot(iterations, test_recon_losses2, label = "recon_loss")
ax2.plot(iterations, test_klds2, label = "kl_loss")
ax2.set_yscale('log')
ax2.set_title(f"test, Beta = {betaa}")
ax2.legend(loc='best')

plt.savefig("/content/drive/MyDrive/VAE/loss_otherbeta4.png")

In [None]:
# Show samples and reconstruction

# Put the model in inference mode before any forward pass in this cell
model.eval()

print('Input images')
print('-'*50)

# loader_test is not shuffled, so this is the FIRST batch of the test set
sample_inputs, _ = next(iter(loader_test))
fixed_input = sample_inputs[0:32, :, :, :]

# Visualize the original images.
# Only non-default make_grid arguments are passed: the old `range=` keyword
# was renamed to `value_range` and is rejected by newer torchvision releases.
img = make_grid(denorm(fixed_input), nrow=8, padding=2)
plt.figure()
show(img)

print('Reconstructed images')
print('-'*50)
with torch.no_grad():
    # visualize the reconstructions of the same 32 test images
    fixed_input = fixed_input.to(device = device)
    recon_batch, mu, logvar = model(fixed_input)

    recon_batch = recon_batch.cpu()
    recon_batch = make_grid(denorm(recon_batch), nrow=8, padding=2)
    plt.figure()
    show(recon_batch)

print('Generated Images')  
print('-'*50)
n_samples = 256
# Sample latent codes from the standard-normal prior and decode them
z = torch.randn(n_samples,latent_dim).to(device)
with torch.no_grad():
    samples = model.decode(z)

    samples = samples.cpu()
    samples = make_grid(denorm(samples), nrow=16, padding=2)
    plt.figure(figsize = (8,8))
    show(samples)


In [None]:
# Visualisation with TSNE
from sklearn.manifold import TSNE

# Single-sample loader, kept because the interpolation cell iterates it to
# pick individual digits.
tsne_loader = DataLoader(test_dat, batch_size = 1, shuffle = False)

# Encode the test set in large batches: one forward pass per image plus a
# growing torch.cat is needlessly slow.
encode_loader = DataLoader(test_dat, batch_size = 1000, shuffle = False)

# Embed the latents of the main `model`, the same network used by the
# reconstruction and interpolation cells.
tsne_model = model
tsne_model.eval()

mu_batches = []
label_batches = []
with torch.no_grad():
    for x, y in encode_loader:
        # only the posterior mean is embedded; logvar is not needed
        mu, _ = tsne_model.encode(x.to(device = device))
        mu_batches.append(mu.cpu())
        label_batches.append(y)

test_mu = torch.cat(mu_batches, dim = 0).numpy()
# float labels, one per test image, in dataset order (used as scatter colours)
label = torch.cat(label_batches).numpy().astype(np.float64)

# use TSNE to create z_embedded; fixed random_state keeps the plot reproducible
z_embedded = TSNE(n_components=2, random_state=0).fit_transform(test_mu)
print(z_embedded.shape)


In [None]:
# Inputs to plot_tsne are
#     z_embedded - X, Y positions for every point in test_dataloader
#     test_dataloader - dataloader with batchsize set to 10000
#     num_points - number of points plotted (will slow down with >1k)
test_dataloader = DataLoader(test_dat, batch_size=10000, shuffle=False)
plot_tsne(
    z_embedded,
    test_dataloader,
    num_points=1000,
    darkmode=False,
)


In [None]:
# Custom Visualizations
# One colour per digit; nums are the matching colour-bar tick labels
colors = ['red','purple','blue','orange','yellow','green','black','brown','pink','violet']
nums = ['red: 0','purple: 1','blue: 2','orange: 3','yellow: 4','green: 5','black: 6','brown: 7','pink: 8','violet: 9']

### the codes below are experimenting with different ways to plot the digits
# zz = np.hstack((z_embedded, label.reshape(10000,1)))
# zz = zz[zz[:,2].argsort()]
# fig = plt.figure(figsize = (8,8))
# ax1 = fig.add_subplot(1,1,1)
# for i in range(10):
#   ax1.scatter(zz[i*1000:i*1000+1000, 0], zz[i*1000:i*1000+1000, 1], c = colors[i], label = i)
# ax1.legend(loc='best')
# fig.show()


from matplotlib.colors import ListedColormap
n_classes = len(colors)
fig = plt.figure(figsize = (8,8))
# plot the digits using different colours.
# vmin/vmax place each integer label k in the middle of colour band k
# ([k - 0.5, k + 0.5]), so the colour-bar ticks below line up with the
# band centres instead of the band edges.
plt.scatter(z_embedded[:,0], z_embedded[:,1], c = label,
            cmap = ListedColormap(colors), vmin = -0.5, vmax = n_classes - 0.5)

# use colorbar for labelling: one tick per digit
cb = plt.colorbar()
loc = np.arange(n_classes)
cb.set_ticks(loc)
cb.set_ticklabels(nums)


In [None]:
# Custom Visualizations
# One colour per digit; nums are the matching colour-bar tick labels
colors = ['red','purple','blue','orange','yellow','green','black','brown','pink','violet']
nums = ['red: 0','purple: 1','blue: 2','orange: 3','yellow: 4','green: 5','black: 6','brown: 7','pink: 8','violet: 9']


# zz = np.hstack((z_embedded, label.reshape(10000,1)))
# zz = zz[zz[:,2].argsort()]
# fig = plt.figure(figsize = (8,8))
# ax1 = fig.add_subplot(1,1,1)
# for i in range(10):
#   ax1.scatter(zz[i*1000:i*1000+1000, 0], zz[i*1000:i*1000+1000, 1], c = colors[i], label = i)
# ax1.legend(loc='best')
# fig.show()


from matplotlib.colors import ListedColormap
n_classes = len(colors)
fig = plt.figure(figsize = (8,8))
# vmin/vmax place each integer label k in the middle of colour band k, so
# the colour-bar ticks line up with the band centres instead of the edges.
plt.scatter(z_embedded[:,0], z_embedded[:,1], c = label,
            cmap = ListedColormap(colors), vmin = -0.5, vmax = n_classes - 0.5)

# one colour-bar tick per digit
cb = plt.colorbar()
loc = np.arange(n_classes)
cb.set_ticks(loc)
cb.set_ticklabels(nums)


In [None]:
# Interpolating in the latent space
from itertools import islice

# look through the first 10 test samples to decide what digits to use
data = []
numbers = []
for x, y in islice(tsne_loader, 10):
    data.append(x)
    numbers.append(y)
print(numbers)

model.eval()
with torch.no_grad():
    # position 2 and 3 contain the number 1 and 0, which I will use for this exercise.
    mu_one, _ = model.encode(data[2].to(device=device))
    mu_zero, _ = model.encode(data[3].to(device=device))

    # linear interpolation with 32 evenly spaced points, going from the
    # latent mean of the "0" (step 0) to that of the "1" (last step)
    frequency = 32
    diff = (mu_one - mu_zero) / (frequency - 1)
    steps = torch.arange(frequency, device=mu_zero.device, dtype=mu_zero.dtype).unsqueeze(1)
    interpolation = mu_zero + steps*diff   # broadcasts to (frequency, latent_dim)

    out = model.decode(interpolation)

# plot the interpolation.
# Only non-default make_grid arguments are passed: the old `range=` keyword
# was renamed to `value_range` and is rejected by newer torchvision releases.
samples = make_grid(out, nrow=8, padding=2)
samples = samples.detach().cpu()
plt.figure(figsize = (8,8))
show(samples)
