In [2]:
# NOTE(review): nothing in this notebook imports a top-level `utils` module
# (`torchvision.utils` ships with torchvision itself), so this install looks
# unnecessary — TODO confirm and remove. If an install is truly needed, prefer
# `%pip install` with a pinned version so it targets the running kernel.
!pip install utils



# Imports

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torchvision
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from IPython.display import Image
from IPython.core.display import Image, display
from torch.utils.data.sampler import SubsetRandomSampler

In [4]:
import os

# Synchronous CUDA kernel launches make device-side errors point at the failing
# line (a debugging aid that slows GPU execution). CUDA reads this setting from
# the process environment, so it must be exported via os.environ before the
# first CUDA call; a bare Python assignment has no effect.
CUDA_LAUNCH_BLOCKING = 1
os.environ["CUDA_LAUNCH_BLOCKING"] = str(CUDA_LAUNCH_BLOCKING)

# Seed torch's RNGs (CPU and CUDA) so weight init, sampler order and the VAE's
# sampling noise are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
training_batch_size = 128

# Data Loading

In [5]:
# Load the MNIST training split. Only the first TRAIN_SUBSET_SIZE images are used
# (visited in random order each epoch by the sampler) to reduce training time.
TRAIN_SUBSET_SIZE = 10000

dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
num_train = len(dataset)
# Clamp so the subset never indexes past the end of the dataset.
train_idx = list(range(min(TRAIN_SUBSET_SIZE, num_train)))
train_sampler = SubsetRandomSampler(train_idx)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=training_batch_size, sampler=train_sampler)

In [6]:
# Load the MNIST test split. Only the first TEST_SUBSET_SIZE images are evaluated
# to reduce testing time.
TEST_SUBSET_SIZE = 200

test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor(), download=True)
num_test = len(test_dataset)
# Clamp so the subset never indexes past the end of the dataset.
test_idx = list(range(min(TEST_SUBSET_SIZE, num_test)))
test_sampler = SubsetRandomSampler(test_idx)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, sampler=test_sampler)

# Utilities

In [7]:
# Flatten each image of a batch into a 1-D feature vector before feeding it to the
# fully connected network.
def convert2DTensorto1DTensor(img):
    """Flatten a batch of images to shape ``(batch, features)``.

    Every dimension after the first (batch) dimension is collapsed into one, so a
    ``(N, 1, 28, 28)`` MNIST batch becomes ``(N, 784)``. The result is moved to the
    GPU when CUDA is available (the same condition used to choose ``device``).

    Args:
        img: tensor whose first dimension is the batch dimension.

    Returns:
        The flattened tensor, on the GPU if one is available, otherwise on CPU.
    """
    # reshape (unlike view) also works for non-contiguous inputs, e.g. after a transpose.
    x = img.reshape(img.size(0), -1)
    # No `Variable` wrapper: it has been a no-op since PyTorch 0.4.
    if torch.cuda.is_available():
        x = x.cuda()
    return x

# Regularization weight for the KL term. loss_fn multiplies it by
# sum(1 + logvar - mu^2 - exp(logvar)), which equals -2 * KL(q(z|x) || N(0, I)),
# so the effective KL weight is -2 * alpha: alpha = -0.2 weights the KL by 0.4
# (alpha = -0.5 gives the standard, unweighted VAE objective).
alpha = -0.2

# Loss function: reconstruction error (binary cross-entropy) plus the weighted KL
# divergence of the latent distribution produced at the end of the encoder.
def loss_fn(x_bar, x, mu, logvar):
    """VAE training loss, summed over all elements of the batch.

    Args:
        x_bar: reconstruction with values in [0, 1] (decoder output after Sigmoid).
        x: target images, same shape as ``x_bar``, with values in [0, 1].
        mu: mean of the approximate posterior produced by the encoder.
        logvar: log-variance of the approximate posterior, same shape as ``mu``.

    Returns:
        Scalar tensor: summed binary cross-entropy plus the KL term scaled through
        the module-level ``alpha``.
    """
    # reduction='sum' replaces the deprecated size_average=False (same value).
    recon_loss = F.binary_cross_entropy(x_bar, x, reduction='sum')
    kld_loss = alpha * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kld_loss

# Network Implementation

In [8]:
# Network implementation of a VAE with 3 linear layers in the encoder and 3 in the decoder
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened images.

    The encoder maps ``image_size`` input features to ``2 * latent_dim`` values,
    split into the posterior mean and log-variance. The decoder maps a
    ``latent_dim`` sample back to ``image_size`` values in (0, 1) via a Sigmoid.

    Args:
        image_size: number of input features per sample (784 = flattened 28x28).
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, image_size=784, latent_dim=40):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(image_size, 400),
            nn.LeakyReLU(0.2),
            nn.Linear(400, 200),
            nn.LeakyReLU(0.2),
            # latent_dim * 2 outputs: first half is the mean, second half the log-variance.
            nn.Linear(200, latent_dim * 2),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 200),
            nn.ReLU(),
            nn.Linear(200, 400),
            nn.ReLU(),
            nn.Linear(400, image_size),
            nn.Sigmoid(),
        )

    def reparameterize(self, mean, log_variance):
        """Draw z ~ N(mean, exp(log_variance)) using the reparameterization trick.

        Sampling the noise separately and returning ``mean + std * eps`` keeps the
        sample differentiable with respect to ``mean`` and ``log_variance``.

        Returns:
            Tensor with the same shape, dtype and device as the inputs.
        """
        std = torch.exp(0.5 * log_variance)
        # randn_like allocates the noise on the inputs' device and dtype, so this works
        # for a model on CPU or on any GPU. The old code always called .cuda() when a GPU
        # existed, which broke a CPU-resident model on a GPU machine.
        eps = torch.randn_like(std)
        return mean + std * eps

    def forward(self, x):
        """Encode ``x``, sample a latent code and decode it.

        Args:
            x: 2-D input of shape ``(batch, image_size)``.

        Returns:
            Tuple ``(reconstruction, mean, log_variance)``.
        """
        latent = self.encoder(x)
        mean, log_variance = torch.chunk(latent, 2, dim=1)
        sample = self.reparameterize(mean, log_variance)
        return self.decoder(sample), mean, log_variance

In [9]:
# Build the VAE, move it to the selected device (Module.to works in place and
# returns the module), and optimise all of its parameters with Adam.
model = VAE()
model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

# Training

In [None]:
# Train for `epochs` epochs, recording the average per-image loss of each epoch.
epochs = 1000
# Show an original/reconstruction pair (first batch) only every PLOT_EVERY epochs:
# one figure per epoch (the previous behaviour, PLOT_EVERY = 1) means 1000 figures.
PLOT_EVERY = 50
loss_values = []
model.train()
for epoch in range(epochs):
    epoch_loss_sum = 0.0
    num_samples = 0
    for idx, (images, _) in enumerate(dataloader):
        images = convert2DTensorto1DTensor(images)
        images_bar, mu, logvar = model(images)
        loss = loss_fn(images_bar, images, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # loss_fn sums over the batch, so normalise by the true number of images
        # seen; the last batch can be smaller than training_batch_size.
        epoch_loss_sum += loss.item()
        num_samples += images.size(0)

        if idx == 0 and epoch % PLOT_EVERY == 0:
            plt.subplot(1, 2, 1)
            plt.imshow(images[0].detach().cpu().reshape(28, 28), cmap='gray')
            plt.title("Original")
            plt.subplot(1, 2, 2)
            plt.imshow(images_bar[0].detach().cpu().reshape(28, 28), cmap='gray')
            plt.title("Reconstruction")
            plt.show()
    running_loss = epoch_loss_sum / max(num_samples, 1)
    loss_values.append(running_loss)
    print("Epoch {} Loss: {:.2f}".format(epoch + 1, running_loss))

plt.plot(np.array(loss_values), 'r')
plt.title("Training loss")
plt.xlabel("Epoch")
plt.ylabel("Average loss per image")
plt.show()

# Testing

In [None]:
# Run the model on the test images: report the average per-image loss and show
# original/reconstruction pairs for the first MAX_TEST_PLOTS images.
MAX_TEST_PLOTS = 10  # previously every test image (200 figures) was displayed
model.eval()
total_test_loss = 0.0
num_test_images = 0
num_plotted = 0
# No parameters are updated here, so skip building the autograd graph.
with torch.no_grad():
    for images, _ in test_dataloader:
        images = convert2DTensorto1DTensor(images)
        recon_images, mu, logvar = model(images)
        loss = loss_fn(recon_images, images, mu, logvar)
        total_test_loss += loss.item()
        num_test_images += images.size(0)
        for original, reconstructed in zip(images, recon_images):
            if num_plotted >= MAX_TEST_PLOTS:
                break
            plt.subplot(1, 2, 1)
            plt.imshow(original.cpu().numpy().reshape(28, 28), cmap='gray')
            plt.title("Original Image")
            plt.subplot(1, 2, 2)
            plt.imshow(reconstructed.cpu().numpy().reshape(28, 28), cmap='gray')
            plt.title("Reconstructed image")
            plt.show()
            num_plotted += 1
# loss_fn sums over each batch, so divide by the number of images rather than the
# number of batches (identical while batch_size=1, but stays correct otherwise).
avg_test_loss = total_test_loss / max(num_test_images, 1)
print(avg_test_loss)