# Variational Autoencoder

### Imports

In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

### Data

In [2]:
image_size = 784
hidden_dim = 400
latent_dim = 20
batch_size = 128
epochs = 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_dataset = torchvision.datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root='data', train=False, transform=transforms.ToTensor())

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Create directory for reconstructed images
sample_dir = 'vae_images'
if not os.path.exists(sample_dir):
    os.makedirs(sample_dir)

### Model

In [3]:
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(image_size, hidden_dim)
        self.fc2_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc2_logvar = nn.Linear(hidden_dim, latent_dim)
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, image_size)
    
    def encode(self, x):
        h = F.relu(self.fc1(x))
        return self.fc2_mean(h), self.fc2_logvar(h)
    
    def reparameterize(self, mean, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + eps * std
    
    def decode(self, z):
        h = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h))

    def forward(self, x):
        mean, logvar = self.encode(x.view(-1, image_size))
        z = self.reparameterize(mean, logvar)
        return self.decode(z), mean, logvar

### Training

In [4]:
model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [5]:
def loss_function(recon_image, original_image, mean, logvar):
    bce = F.binary_cross_entropy(recon_image, original_image.view(-1, image_size), reduction='sum')
    kld = 0.5 * torch.sum(mean.pow(2) + logvar.exp() - logvar - 1)
    # If dim=1, then the return ternsor will be of shape (batch_size, 1)
    # So another sum is needed to get the total loss
    return bce + kld

In [6]:
def train(epoch):
    model.train()
    train_loss = 0
    for i, (images, _) in enumerate(train_loader):
        images = images.to(device)
        recon_images, mean, logvar = model(images)
        loss = loss_function(recon_images, images, mean, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        if (i+1) % 100 == 0:
            print(f'Epoch [{epoch+1}], Batch [{i+1}/{len(train_loader)}], Loss: {loss.item() / len(train_loader):.4f}')
    print(f'Epoch [{epoch+1}], Average Loss: {train_loss/len(train_loader.dataset):.4f}')

In [7]:
def test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (images, _) in enumerate(test_loader):
            images = images.to(device)
            recon_images, mean, logvar = model(images)
            loss = loss_function(recon_images, images, mean, logvar)
            test_loss += loss.item()
            if i == 0:
                comparison = torch.cat([images[:5], recon_images.view(batch_size, 1, 28, 28)[:5]])
                torchvision.utils.save_image(comparison.cpu(), os.path.join(sample_dir, f'reconstruction_{epoch+1}.png'), nrow=5)
    print(f'Test Loss: {test_loss / len(test_loader.dataset):.4f}')

In [8]:
for epoch in range(0, epochs):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        z = torch.randn(64, latent_dim).to(device)
        sample = model.decode(z).view(-1, 1, 28, 28)
        torchvision.utils.save_image(sample.cpu(), os.path.join(sample_dir, f'sample_{epoch+1}.png'))

Epoch [1], Batch [100/469], Loss: 50.0405
Epoch [1], Batch [200/469], Loss: 42.5522
Epoch [1], Batch [300/469], Loss: 38.0322
Epoch [1], Batch [400/469], Loss: 34.3826
Epoch [1], Average Loss: 164.5556
Test Loss: 127.1950
Epoch [2], Batch [100/469], Loss: 34.6890
Epoch [2], Batch [200/469], Loss: 33.2544
Epoch [2], Batch [300/469], Loss: 32.2783
Epoch [2], Batch [400/469], Loss: 31.7993
Epoch [2], Average Loss: 121.2632
Test Loss: 115.4512
Epoch [3], Batch [100/469], Loss: 31.7116
Epoch [3], Batch [200/469], Loss: 32.0206
Epoch [3], Batch [300/469], Loss: 31.4955
Epoch [3], Batch [400/469], Loss: 30.2868
Epoch [3], Average Loss: 114.6442
Test Loss: 111.7778
Epoch [4], Batch [100/469], Loss: 30.2835
Epoch [4], Batch [200/469], Loss: 31.0493
Epoch [4], Batch [300/469], Loss: 29.3748
Epoch [4], Batch [400/469], Loss: 31.8540
Epoch [4], Average Loss: 111.7311
Test Loss: 109.7383
Epoch [5], Batch [100/469], Loss: 30.6597
Epoch [5], Batch [200/469], Loss: 28.9398
Epoch [5], Batch [300/469], 