# VAE from scartch on MNIST

## Importing Libraries

In [1]:
import torch
from torch import nn

## VAE Model

In [16]:
class VariationalAutoencoder(nn.Module):
    def  __init__(self, input_dim, h_dim =200, z_dim=20):
        super().__init__()
        # encoder
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)
        # decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)
        self.relu = nn.ReLU()

    def encode(self, x):
        # q_phi(z|x) = N(mu_x|z, sigma_x|z)
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        # p_teta(x|z) = N(mu_z|x, sigma_z|x)
        h = self.relu(self.z_2hid(z))
        return torch.sigmoid(self.hid_2img(h)) # pixel value should be in 0 and 1

    def forward(self,x):
        mu, sigma = self.encode(x)
        z_reparametrized = mu +sigma*torch.randn_like(sigma)
        x_reconstructed = self.decode(z_reparametrized)
        return  x_reconstructed, mu, sigma

## Test Model with random

In [15]:
x = torch.randn(4,28*28) #  MNIST input image is of size (batch, channels=1, height=28, width=28)
vae = VariationalAutoencoder(input_dim=784)
x_reconstructed,mu,sigma = vae(x)
print(x_reconstructed.shape, mu.shape, sigma.shape)

torch.Size([4, 784]) torch.Size([4, 20]) torch.Size([4, 20])


## Training the model

In [None]:
import torch
import torchvision.datasets as datasets
from tqdm import tqdm # progress bar
from torch import nn , optim
from torchvision import transforms # image augmentation
from torchvision.utils import save_image
from torch.utils.data import DataLoader

## Config

In [None]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20
BATCH_SIZE = 100
NUM_EPOCHS = 20 
BATCH_SIZE = 64
LR_RATE = 3e-6  # karpathy constant

## Dataset loading

In [None]:
datasets =datasets.MNIST(root="dataser/", train=True, transform=transforms.ToTensor(),download=True)
train_loader = DataLoader(dataset=datasets, batch_size=BATCH_SIZE, shuffle=True)
model = VariationalAutoencoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum") ## binary cross entropy loss

## Start Trainning

In [54]:
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    for i, (x,_) in loop:
        # forward pass
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)

        # compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x) # push it to reconstruct the image
        kl_div = -torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2)) #push it to gaussian latant space
        
        # backprop 
        loss = reconstruction_loss + kl_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

1200it [00:33, 36.05it/s, loss=6.94e+10]
1200it [00:35, 34.13it/s, loss=4.38e+10]
1200it [00:46, 26.06it/s, loss=4.92e+10]
1200it [00:33, 35.78it/s, loss=3.4e+10] 
1200it [00:38, 30.78it/s, loss=2.53e+10]
1200it [00:33, 36.34it/s, loss=1.78e+10]
1200it [00:29, 40.11it/s, loss=1.11e+10]
1200it [00:30, 38.77it/s, loss=8.93e+9] 
1200it [00:31, 38.46it/s, loss=5.67e+9]
1200it [00:31, 38.53it/s, loss=3.98e+9]
1200it [00:27, 43.07it/s, loss=2.61e+9]
1200it [00:28, 42.79it/s, loss=1.51e+9]
1200it [00:27, 43.11it/s, loss=7.8e+8] 
1200it [00:27, 43.36it/s, loss=4.07e+8]
1200it [00:27, 43.34it/s, loss=1.55e+8]
1200it [00:28, 42.83it/s, loss=6.37e+7]
1200it [00:25, 47.17it/s, loss=2.16e+7]
1200it [00:24, 48.79it/s, loss=4.67e+6]
1200it [00:24, 48.82it/s, loss=1.3e+6] 
1200it [00:24, 49.56it/s, loss=2.76e+5]


## Inference

In [55]:
model = model.to("cpu")
def inference(digit, num_examples=1):
    """Generate and return `num_examples` images of digit from 0 to 9.
    Args: 
        digit (int): The digit to generate images for ranging from 0 to 9
        num_examples (int): Number of examples/images to generate
    
    Returns:
    A list of `PIL.Image` objects containing the generated images
    """
    images = []
    idx = 0
    for x, y in datasets:
        if y == idx:
            images.append(x)
            idx +=1
        if idx == 10:
            break

    encoding_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encode(images[d].view(1,784))
        encoding_digit.append((mu, sigma))

    mu, sigma = encoding_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon 
        out = model.decode(z)
        out = out.view(-1,1,28,28)
        save_image(out, f"test/generated_[Digit-{digit}]_{example}.png")

In [56]:
for idx in range(10):
    inference(idx, num_examples=1)