**My Implementation of the Variational Autoencoder (VAE)**

[Paper: Kingma & Welling (2013), "Auto-Encoding Variational Bayes"](https://arxiv.org/pdf/1312.6114)

In [7]:
# Colab-only: mount Google Drive at /content/drive. The training cell later
# saves the model weights to a path under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [11]:
import numpy as mp
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# --- Configuration -----------------------------------------------------------
latent_dim = 64   # size of the latent vector z (read by the VAE and the sampling cell)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 128
EPOCHS = 60
SEED = 42

# Seed PyTorch's random number generators so that weight initialisation, batch
# shuffling and the reparameterization noise are repeatable after
# "Restart kernel -> Run all". (Some CUDA kernels are still non-deterministic,
# so GPU runs may differ slightly.)
torch.manual_seed(SEED)

# --- Data --------------------------------------------------------------------
# ToTensor converts each PIL image to a float tensor scaled to [0, 1], which
# matches the range of the decoder's final Sigmoid.
transform = transforms.Compose([transforms.ToTensor()])

trainset = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)

train_loader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

In [2]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for 3x32x32 images.

    The encoder maps an image batch to the mean and log-variance of a diagonal
    Gaussian over the latent code; the decoder maps a latent code back to an
    image whose values lie in [0, 1] (final Sigmoid).

    Args:
        z_dim: size of the latent code. When omitted, the notebook-level
            ``latent_dim`` is used, so ``VAE()`` behaves as before.

    NOTE(review): there is no non-linearity between ``fc`` and the
    ``mu``/``logvar`` heads, nor between ``fcinv`` and ``fcinv2``, so each pair
    of Linear layers composes into a single affine map. Adding activations
    would change the behaviour of already-saved checkpoints, so the
    architecture is left as is here.
    """

    def __init__(self, z_dim=None):
        super().__init__()

        # Fall back to the notebook-level setting when no size is given.
        self.latent_dim = latent_dim if z_dim is None else z_dim

        # encoder: each stride-2 conv halves the spatial size (32 -> 16 -> 8 -> 4)
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
        )

        self.fc = nn.Linear(128*4*4, 256)
        self.mu = nn.Linear(256, self.latent_dim)
        self.logvar = nn.Linear(256, self.latent_dim)

        # decoder: mirror of the encoder (4 -> 8 -> 16 -> 32)
        self.fcinv = nn.Linear(self.latent_dim, 256)
        self.fcinv2 = nn.Linear(256, 128*4*4)

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
            nn.Sigmoid(),
        )

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for image batch ``x``."""
        feats = self.encoder(x)
        feats = feats.view(feats.size(0), -1)
        h = self.fc(feats)
        return self.mu(h), self.logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + sigma * eps`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        return mu + std * torch.randn_like(mu)

    def decode(self, z):
        """Map latent codes ``z`` of shape (batch, latent_dim) to images."""
        h = self.fcinv2(self.fcinv(z))
        h = h.view(h.size(0), 128, 4, 4)
        return self.decoder(h)

    def forward(self, x):
        """Encode, sample a latent code and decode.

        Returns:
            ``(x_hat, mu, logvar)``: the reconstruction and the posterior
            parameters needed for the KL term of the loss.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    def forward_training(self, x):
        """Backward-compatible name for :meth:`forward`; same return value."""
        return self.forward(x)

In [12]:
# Build the VAE, move its parameters to the selected device, and attach an
# Adam optimizer over all of them.
model = VAE()
model = model.to(DEVICE)
opt = torch.optim.Adam(params=model.parameters(), lr=1e-3)

def loss_fn(x_hat, x, mu, logvar, beta=1.0):
    """Negative ELBO for a Gaussian VAE, summed (not averaged) over the batch.

    Args:
        x_hat: reconstructed batch, same shape as ``x``.
        x: input batch.
        mu: posterior means.
        logvar: posterior log-variances, same shape as ``mu``.
        beta: weight on the KL term. The default of 1.0 gives the standard
            VAE objective; other values give the beta-VAE variant.

    Returns:
        Scalar tensor ``recon + beta * kl``.
    """
    # Reconstruction term: summed squared error over every pixel and sample.
    recon = F.mse_loss(x_hat, x, reduction="sum")
    # Closed-form KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over
    # latent dimensions and samples.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta * kl

def train(epochs=None, save_path="/content/drive/MyDrive/vae_cifar10_2.pth"):
    """Train the notebook-level ``model`` on ``train_loader`` and checkpoint it.

    Uses the module-level ``model``, ``opt``, ``train_loader``, ``loss_fn`` and
    ``DEVICE``. Prints the average per-sample loss after every epoch.

    Args:
        epochs: number of passes over the data; defaults to ``EPOCHS``.
        save_path: where the model ``state_dict`` is written. The file is
            overwritten after every epoch, so an interrupted run keeps the
            latest weights and a bad path fails after the first epoch instead
            of after the whole run.
    """
    n_epochs = EPOCHS if epochs is None else epochs
    n_samples = len(train_loader.dataset)

    for epoch in range(n_epochs):
        # Training mode. This model has no dropout or batch-norm layers, so the
        # call is currently a no-op kept as a convention; use model.eval() for
        # inference.
        model.train()
        total_loss = 0.0

        for x, _ in train_loader:
            x = x.to(DEVICE)

            opt.zero_grad()

            x_hat, mu, logvar = model.forward_training(x)
            loss = loss_fn(x_hat, x, mu, logvar)

            loss.backward()
            opt.step()

            # loss is summed over the batch, so accumulate and normalise once below
            total_loss += loss.item()

        avg = total_loss / n_samples
        print(f"Epoch {epoch+1}/{n_epochs} | Loss: {avg:.4f}")

        # Checkpoint after every epoch so progress survives an interruption.
        torch.save(model.state_dict(), save_path)

    print("Training complete.")
    # Final save, so a checkpoint is written even when n_epochs == 0.
    torch.save(model.state_dict(), save_path)


In [14]:
# Run the training loop defined above; it prints the average loss after each epoch.
train()

Epoch 1/60 | Loss: 73.9501
Epoch 2/60 | Loss: 73.9006
Epoch 3/60 | Loss: 73.8839
Epoch 4/60 | Loss: 73.8239
Epoch 5/60 | Loss: 73.8484
Epoch 6/60 | Loss: 73.8927
Epoch 7/60 | Loss: 73.7871
Epoch 8/60 | Loss: 73.8504
Epoch 9/60 | Loss: 73.7696
Epoch 10/60 | Loss: 73.7463
Epoch 11/60 | Loss: 73.7254
Epoch 12/60 | Loss: 73.7532
Epoch 13/60 | Loss: 73.7823
Epoch 14/60 | Loss: 73.7002
Epoch 15/60 | Loss: 73.7247
Epoch 16/60 | Loss: 73.6939
Epoch 17/60 | Loss: 73.7070
Epoch 18/60 | Loss: 73.7142
Epoch 19/60 | Loss: 73.6935
Epoch 20/60 | Loss: 73.6101
Epoch 21/60 | Loss: 73.6808
Epoch 22/60 | Loss: 73.6247
Epoch 23/60 | Loss: 73.7283
Epoch 24/60 | Loss: 73.6193
Epoch 25/60 | Loss: 73.5811
Epoch 26/60 | Loss: 73.5708
Epoch 27/60 | Loss: 73.5548
Epoch 28/60 | Loss: 73.5522
Epoch 29/60 | Loss: 73.5548
Epoch 30/60 | Loss: 73.5700
Epoch 31/60 | Loss: 73.5603
Epoch 32/60 | Loss: 73.5571
Epoch 33/60 | Loss: 73.5403
Epoch 34/60 | Loss: 73.4641
Epoch 35/60 | Loss: 73.4843
Epoch 36/60 | Loss: 73.5010
E

In [15]:
import torchvision.utils as vutils

# Switch to inference mode before producing any images.
model.eval()

# Take a single batch from the training loader and move it to the model's device.
x, _ = next(iter(train_loader))
x = x.to(DEVICE)

# Reconstruct the batch without tracking gradients.
with torch.no_grad():
    x_hat, _, _ = model.forward_training(x)

# Reconstruction grid: top row = originals, bottom row = reconstructions.
comparison = torch.cat((x[:8], x_hat[:8]), dim=0)
vutils.save_image(comparison, "reconstructions.png", nrow=8)

# Prior samples: draw z ~ N(0, I) and push it through the decoder half only.
with torch.no_grad():
    z = torch.randn(32, latent_dim).to(DEVICE)
    z = model.fcinv2(model.fcinv(z))
    z = z.view(z.size(0), 128, 4, 4)
    samples = model.decoder(z)

vutils.save_image(samples, "samples.png", nrow=8)

print("Saved reconstructions.png and samples.png")

Saved reconstructions.png and samples.png
