In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
import torch
from PIL import Image
from torch.utils.data import TensorDataset, DataLoader
import torchvision
import torch.nn as nn
from torch.optim import Adam

In [None]:
%matplotlib inline

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Download MNIST once (raw PIL images, no transform yet) so we can inspect a sample.
trainset = torchvision.datasets.MNIST('./data/', download=True, train=True)

In [None]:
# With no transform attached, each item is an (image, label) pair.
trainset[0]

In [None]:
# Render the first training image.
display(trainset[0][0])

In [None]:
# PILToTensor converts images to tensors without rescaling pixel values (unlike ToTensor),
# so batches keep the image's integer dtype; later cells cast to float32 explicitly.
transform = torchvision.transforms.PILToTensor()

In [None]:
# Rebuild the (already downloaded) dataset with the tensor transform attached,
# then wrap it in a shuffled mini-batch loader.
trainset = torchvision.datasets.MNIST(
    root='./data/',
    train=True,
    download=False,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)

In [None]:
# Peek at a single batch to check the tensor layout and dtype produced by the loader.
X, label = next(iter(trainloader))
print(X.shape)
print(X.dtype)

![Autoencoder architecture](auto_encoder.png)

In [None]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.lin1 = nn.Linear(32*7*7, 512)
        self.lin2 = nn.Linear(512, 1)

    def forward(self, inp):
        out = self.conv1(inp)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.conv2(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = out.reshape(inp.shape[0], -1)
        out = self.lin1(out)
        return self.lin2(self.relu(out))

In [None]:
# Standalone (untrained) encoder, used only to sanity-check output shapes.
encoder = Encoder()

In [None]:
# Push one float batch through the untrained encoder to inspect its output.
X, label = next(iter(trainloader))
X = X.to(torch.float32)
out = encoder(X)

In [None]:
# One latent value per image: (batch_size, 1).
out.shape

In [None]:
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(1, 128)
        self.lin2 = nn.Linear(128, 512)
        self.lin3 = nn.Linear(512, 784)
        self.relu = nn.ReLU()

    def forward(self, z):
        out = self.lin1(z)
        out = self.relu(out)
        out = self.lin2(out)
        out = self.relu(out)
        out = self.lin3(out)
        out = out.reshape(z.shape[0], 28, 28)
        return out

In [None]:
# Standalone (untrained) decoder for the shape round-trip check below.
decoder = Decoder()

In [None]:
# Round-trip one batch through the untrained encoder/decoder pair and compare shapes.
X, label = next(iter(trainloader))
X = X.to(torch.float32)
print(X.shape)
out = encoder(X)
X_hat = decoder(out).unsqueeze(1)  # (B, 28, 28) -> (B, 1, 28, 28) to line up with X
print(X_hat.shape)

In [None]:
class Autoencoder(nn.Module):
    """Encoder/decoder pair trained to reconstruct its input.

    NOTE: ``Encoder`` is resolved when the model is constructed. A later cell
    redefines ``Encoder`` for the VAE, so build this model before that cell runs.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, inp):
        # Output is (batch, 28, 28); callers re-add the channel dim with unsqueeze(1).
        return self.decoder(self.encoder(inp))

In [None]:
# Reconstruction loss, model and optimiser for the deterministic autoencoder.
criterion = nn.MSELoss()
auto_encoder = Autoencoder()
optimizer = Adam(auto_encoder.parameters(), lr=1e-3)

In [None]:
# Train the autoencoder and report the mean reconstruction loss per epoch
# (same reporting as the VAE training loop).
epochs = 30
auto_encoder.train()  # restore training mode in case .eval() was called before a re-run
for e in tqdm(range(epochs)):
    running_loss = 0.0
    for X, label in trainloader:
        X = X.to(torch.float32)
        X_hat = auto_encoder(X).unsqueeze(1)  # (B, 28, 28) -> (B, 1, 28, 28) to match X
        loss = criterion(X_hat, X)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Loss at the end of epoch {e}: {running_loss / len(trainloader)}")

In [None]:
# Persist only the learned weights (state_dict), not the module object.
torch.save(auto_encoder.state_dict(), 'mnist_auto_encoder.pt')

In [None]:
# Rebuild a fresh model and restore the saved weights into it.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust
# (newer torch versions accept weights_only=True for plain state_dicts).
auto_encoder = Autoencoder()
auto_encoder.load_state_dict(torch.load('mnist_auto_encoder.pt'))

In [None]:
def sample_image(cls):
    """Return the first ``(image, label)`` item of ``trainset`` whose label equals ``cls``.

    The dataset is scanned in index order, so repeated calls return the same item.
    Returns None when no item carries that label.
    """
    return next((item for item in trainset if item[1] == cls), None)

In [None]:
# First training image labelled 2.
sample = sample_image(2)

In [None]:
# Single image with no batch dimension, presumably (1, 28, 28) = (C, H, W).
sample[0].shape

In [None]:
# Drop the leading channel dim so imshow gets a 2-D array.
plt.imshow(sample[0].squeeze(0).numpy(), cmap='gray')

In [None]:
# Inference mode (the model has no dropout/batch-norm, so this is mostly a convention).
auto_encoder.eval()

In [None]:
# Reconstruct the sample. The encoder flattens with reshape(batch, -1), so add an
# explicit batch dimension instead of relying on the single channel standing in
# for it. no_grad skips building an autograd graph during inference.
with torch.no_grad():
    out = auto_encoder(sample[0].unsqueeze(0).to(torch.float32))

In [None]:
# Reconstructed image, for visual comparison with the original sample.
plt.imshow(out.squeeze(0).detach().numpy(), cmap='gray')

In [None]:
# Latent code of the sample (explicit batch dimension; inference only).
with torch.no_grad():
    z_sample = auto_encoder.encoder(sample[0].unsqueeze(0).to(torch.float32))
z_sample

In [None]:
# Sample one image for each of the 10 digits and compare their latent values z

In [None]:
# Latent code of the first training image of each digit (batched input, no autograd).
with torch.no_grad():
    for i in range(10):
        sample = sample_image(i)
        z = auto_encoder.encoder(sample[0].unsqueeze(0).to(torch.float32))
        print(i, z.item())

In [None]:
# Hand-picked latent value (shape (1, 1)) to decode. Presumably chosen from the
# range of z values printed above; change it to explore other digits.
inp = torch.tensor([[205.0]])

In [None]:
# Decode the chosen latent value into a (1, 28, 28) image; inference only.
with torch.no_grad():
    out = auto_encoder.decoder(inp)

In [None]:
# Image generated from the hand-picked latent value.
plt.imshow(out.squeeze(0).detach().numpy(), cmap='gray')

# Variational auto-encoder

![Variational autoencoder architecture](vae.png)

$$
\Large D_{KL}\left[N(\mu, \sigma) \parallel N(0, 1)\right] = -\frac{1}{2}\left(\log\sigma^2 + 1 - \sigma^2 - \mu^2 \right)
$$

In [None]:
class Encoder(nn.Module):
    """Convolutional VAE encoder predicting the mean and standard deviation of q(z | x).

    NOTE: this redefines the ``Encoder`` name used by the deterministic autoencoder
    above. Any model constructed after this cell picks up this version.

    Args:
        z: latent dimensionality, i.e. how many means and standard deviations are
            predicted per image (default 1).
    """

    def __init__(self, z=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.lin1 = nn.Linear(32*7*7, 512)
        # Two heads of width `z`: one predicts \mu, the other \sigma.
        self.lin_m = nn.Linear(512, z)
        # NOTE(review): sigma is an unconstrained linear output (can be <= 0);
        # predicting log-variance and exponentiating would be numerically safer.
        self.lin_s = nn.Linear(512, z)

    def forward(self, inp):
        """Return ``(mu, sigma)``, each of shape ``(batch, z)``."""
        out = self.conv1(inp)
        out = self.relu(out)
        out = self.maxpool(out)
        out = self.conv2(out)
        out = self.relu(out)
        out = self.maxpool(out)
        out = out.reshape(inp.shape[0], -1)
        hidden = self.relu(self.lin1(out))  # shared by both heads, computed once
        return self.lin_m(hidden), self.lin_s(hidden)

In [None]:
class Variationaautoencoder(nn.Module):
    """Variational autoencoder: the encoder predicts (mu, sigma), and the decoder maps a sampled z to an image.

    The class name (typo included) is kept because later cells instantiate it.
    """

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()  # the autoencoder's decoder architecture is reused

    def forward(self, X):
        """Encode ``X``, sample ``z`` via the reparameterisation trick, and decode it.

        Sampling happens in eval mode too, so outputs are stochastic.

        Returns:
            ``(X_hat, mu, sigma)``. mu and sigma are needed for the KL term of the loss.
        """
        m, std = self.encoder(X)
        # z = mu + eps * sigma with eps ~ N(0, I). randn_like matches sigma's shape,
        # dtype and device, so this works for any latent size and on GPU.
        eps = torch.randn_like(std)
        z = m + eps * std
        return self.decoder(z), m, std

In [None]:
def kl_div(m, s):
    """Element-wise KL divergence KL(N(m, s^2) || N(0, 1)).

    Computes 0.5 * (m^2 + s^2 - 1 - log s^2). Reduce the result (mean/sum) over
    latent dimensions and batch to obtain a scalar loss.
    """
    var = s ** 2
    return 0.5 * (m ** 2 + var - 1 - torch.log(var))

In [None]:
# Reconstruction loss, model and optimiser for the VAE (the KL term is added in the training loop).
criterion = nn.MSELoss()
vae = Variationaautoencoder()
optimizer = Adam(vae.parameters(), lr=1e-3)

In [None]:
# Train the VAE on reconstruction (MSE) + KL loss and report the mean loss per epoch.
# NOTE(review): inputs appear to be raw pixel values (PILToTensor does not rescale),
# so the MSE term likely dwarfs the KL term; consider normalising or weighting them.
epochs = 15
vae.train()  # a later cell calls vae.eval(); restore training mode on re-runs
for e in tqdm(range(epochs)):
    running_loss = 0
    batches = 0
    for X, label in trainloader:
        batches += 1
        X = X.to(torch.float32)
        X_hat, m, s = vae(X)
        loss = criterion(X_hat.unsqueeze(1), X) + torch.mean(kl_div(m, s))
        running_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Loss at the end of epoch {e}: {running_loss/batches}")

In [None]:
# Persist only the learned VAE weights (state_dict).
torch.save(vae.state_dict(), 'mnist_vae.pt')

In [None]:
# Restore the saved weights into the existing model.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
vae.load_state_dict(torch.load('mnist_vae.pt'))

In [None]:
# Inference mode. Note: this does not stop forward() from sampling z.
vae.eval()

In [None]:
# First training image labelled 9.
sample = sample_image(9)

In [None]:
# Original image (channel dim dropped for imshow).
plt.imshow(sample[0].squeeze(0).numpy(), cmap='gray')

In [None]:
# Reconstruct the sample with the VAE (explicit batch dimension; inference only).
# forward() samples z, so re-running this cell gives slightly different images.
with torch.no_grad():
    out = vae(sample[0].unsqueeze(0).to(torch.float32))

In [None]:
# out is (X_hat, mu, sigma); show the reconstructed image X_hat.
plt.imshow(out[0].squeeze(0).detach().numpy(), cmap='gray')

In [None]:
# Predicted (mu, sigma) for the first training image of each digit
# (explicit batch dimension, no autograd; tolist() also works for latent sizes > 1).
with torch.no_grad():
    for i in range(10):
        sample = sample_image(i)
        mu, sigma = vae.encoder(sample[0].unsqueeze(0).to(torch.float32))
        print(i, mu.squeeze(0).tolist(), sigma.squeeze(0).tolist())

In [None]:
import matplotlib as mpl

In [None]:
# Large default figure size for the latent-sweep grid below.
mpl.rcParams.update({'figure.figsize': (10, 10)})

In [None]:
# 250 evenly spaced latent values in [-5, 10], one per panel of the sweep grid.
x = np.linspace(start=-5, stop=10, num=250)

In [None]:
# Decode every latent value in `x` and tile the images in a grid. With zero std the
# reparameterised sample z = mu + eps * std is just mu, so decode mu directly.
n_cols = 25
n_rows = -(-len(x) // n_cols)  # ceiling division so every latent value gets a panel
fig = plt.figure()
with torch.no_grad():
    for i, m in enumerate(x, start=1):
        mu = torch.tensor([[m]], dtype=torch.float32)
        out = vae.decoder(mu)  # (1, 28, 28)
        ax = fig.add_subplot(n_rows, n_cols, i)
        ax.imshow(out[0].numpy(), cmap='gray')
        ax.axis('off')

# Task: Multiple latent dimensions

Using the same dataset, design a VAE with two latent dimensions. The encoder then predicts two means and two standard deviations, one pair per latent dimension.

In [None]:
# 1. Adapt the encoder architecture to predict mu and sigma for each of the two latent dimensions.
# 2. Adapt the KL-divergence loss to multiple dimensions. For a diagonal Gaussian the KL divergence
#    is the sum of the per-dimension terms; averaging them instead (as torch.mean does) only rescales it.
# 3. The decoder will now take a vector of size 2 as input.


Sample images from each class, compute the corresponding latent vector for each of them, and draw a scatter plot coloured by class. Can you identify the clusters?

Change the number of latent dimensions and observe how it affects the results.