# Variational Auto - Encoder using CNNs

## Importing required packages

In [1]:
import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt

# Defining the model

<img src="Architecture_cnn.png" width="800" />

In [2]:
class VAE(nn.Module):
  
    def __init__(self):
        super(VAE, self).__init__()
        
        self.conv1 = nn.Conv2d(1, 32, 3)
        self.conv2 = nn.Conv2d(32, 16, 3)
        self.max_pool = nn.MaxPool2d(kernel_size = 2, stride = 2)
        self.fc1_mu = nn.Linear(12 * 12 * 16, 20)
        self.fc1_sig = nn.Linear(12 * 12 * 16, 20)
        self.fc2 = nn.Linear(20, 12 * 12 * 16)
        self.up_sample = nn.UpsamplingNearest2d(scale_factor=2)
        self.conv3 = nn.ConvTranspose2d(16, 32, 3)
        self.conv4 = nn.ConvTranspose2d(32, 1, 3)
  
    def encode(self,x):
        a1 = F.relu(self.conv1(x))
        a2 = F.relu(self.conv2(a1))
        mx_poold = self.max_pool(a2)
        a_reshaped = mx_poold.reshape(-1 , 12 * 12 * 16)
        a_mu = self.fc1_mu(a_reshaped)
        a_logvar = self.fc1_sig(a_reshaped)
        return a_mu, a_logvar
  
    def decode(self,z):
        a3 = F.relu(self.fc2(z))
        a3 = a3.reshape(-1, 16, 12, 12)
        a3_upsample = self.up_sample(a3)
        a4 = F.relu(self.conv3(a3_upsample))
        a5 = torch.sigmoid(self.conv4(a4))
        return a5
  
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return eps.mul(std).add_(mu)
  
    def forward(self,x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
    

# Getting the dataset

In [3]:
train_dataset = torchvision.datasets.MNIST(root='./data',train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='./data',train=False, transform=transforms.ToTensor(),download = True)
batch_size = 100
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,  batch_size=100, shuffle=False)

## Instantiating the model and optimizer

In [4]:
model = VAE()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

## Defining loss

In [None]:
def loss_function(recon_x, x, mu, logvar):
    BCE = F.binary_cross_entropy(recon_x, x.view(-1, 784), reduction='sum')

    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE + KLD

# Training

In [None]:
num_epochs = 3
print_per = 100
model.train()
for epoch in range(num_epochs):
    train_loss = 0
    print_loss = 0
    loss_record = []
    for i, (images, _) in enumerate(train_loader):
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(images)
        loss = loss_function(recon_batch, images, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        print_loss += loss.item()
        optimizer.step()
        if (i%print_per == 0):
            print("Epoch : {} , Minibatch : {} Loss = {:.4f}".format(epoch+1, i, print_loss))
            loss_record.append(print_loss)
            print_loss = 0
    print("Epoch {} : Loss = ({:.4f}) ".format(epoch+1, train_loss))

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch : 1 , Minibatch : 0 Loss = 54326.6719
Epoch : 1 , Minibatch : 100 Loss = 2963207.3535
Epoch : 1 , Minibatch : 200 Loss = 1878940.3311
Epoch : 1 , Minibatch : 300 Loss = 1469711.3965
Epoch : 1 , Minibatch : 400 Loss = 1313537.5762
Epoch : 1 , Minibatch : 500 Loss = 1249602.0713
Epoch 1 : Loss = (10128098.4697) 
Epoch : 2 , Minibatch : 0 Loss = 12200.3232
Epoch : 2 , Minibatch : 100 Loss = 1188733.4014
Epoch : 2 , Minibatch : 200 Loss = 1168597.8184
Epoch : 2 , Minibatch : 300 Loss = 1164925.1621
Epoch : 2 , Minibatch : 400 Loss = 1150479.5059
Epoch : 2 , Minibatch : 500 Loss = 1135445.5137
Epoch 2 : Loss = (6940009.0449) 
Epoch : 3 , Minibatch : 0 Loss = 11324.7793
Epoch : 3 , Minibatch : 100 Loss = 1120590.1504


In [None]:
print(model(images)[0].data[0].numpy().shape)
image1 = images[8].reshape(1, 1, 28, 28)
print(image1.shape)
plt.imshow(model(images)[0].data[8].numpy().reshape(28, 28), cmap='gray')
plt.show(block=True)

# Testing

In [None]:
test_loss = 0
print_per = 10
with torch.no_grad():
    for i, (images, _) in enumerate(test_loader):
        recon_batch, mu, logvar = model(images)
        test_loss += loss_function(recon_batch, images, mu, logvar).item()
        if (i%print_per == 0):
            plt.imshow(model(images)[0].data[0].numpy().reshape(28, 28), cmap='gray')
            plt.show(block=True)

In [None]:
image2 = images[1].reshape(1, 1, 28, 28)
print(image2.shape)

In [None]:
plt.imshow(model(images)[0].data[1].numpy().reshape(28, 28), cmap='gray')
plt.show(block=True)

In [None]:
plt.imshow(images[1].numpy().reshape(28, 28), cmap='gray')
plt.show(block=True)

# Experiments

All of these are similar to those given in this [notebook](https://github.com/ac-alpha/VAEs-using-Pytorch/blob/master/VAE.ipynb)

In [None]:
with torch.no_grad():
    mu1, logvar1 = model.encode(image1)
    std1 = torch.exp(0.5*logvar1)
    mu2, logvar2 = model.encode(image2)
    std2 = torch.exp(0.5*logvar2)
    print(mu1.shape)

In [None]:
with torch.no_grad():
    recon_images1 = []
    for ctr in range(0, 100, 5):
        eps_val = torch.full_like(mu1, fill_value = ctr * 0.01 )
        z_val1 = eps_val.mul(std1).add_(mu1)
        recon_image1 = model.decode(z_val1)
        recon_images1.append(recon_image1)

In [None]:
print(recon_images1[0] - recon_images1[1])

In [None]:
fig=plt.figure(figsize=(28, 28))
columns = 4
rows = 5
for i in range(1, columns*rows +1):
    img = recon_images1[i-1].detach().numpy().reshape(28, 28)
    fig.add_subplot(rows, columns, i)
    plt.imshow(img, cmap="gray")
plt.show()

In [None]:
with torch.no_grad():
    recon_images1 = []
    eps_val = torch.randn_like(mu1)
    for ctr in range(0, 100, 5):
        eps_val[:, 7] = ctr * 0.05 * std1[:, 7] + mu1[:, 7]
        z_val1 = eps_val.mul(std1).add_(mu1)
        recon_image1 = model.decode(z_val1)
        recon_images1.append(recon_image1)

In [None]:
fig=plt.figure(figsize=(28, 28))
columns = 4
rows = 5
for i in range(1, columns*rows +1):
    img = recon_images1[i-1].detach().numpy().reshape(28, 28)
    fig.add_subplot(rows, columns, i)
    plt.imshow(img, cmap = "gray")
plt.show()

In [None]:
eps_any = torch.randn_like(mu1)
z1 = eps_any.mul(std1).add_(mu1)
z2 = eps_any.mul(std2).add_(mu2)
all_recons = []
for i in range(20):
    z_bet = z1 + torch.full_like(mu1, fill_value = 0.05*i).mul(z2 - z1)
    recon_image = model.decode(z_bet)
    all_recons.append(recon_image)

In [None]:
fig=plt.figure(figsize=(28, 28))
columns = 4
rows = 5
for i in range(1, columns*rows +1):
    img = all_recons[i-1].detach().numpy().reshape(28, 28)
    fig.add_subplot(rows, columns, i)
    plt.imshow(img)
plt.show()