<table style="background-color:#FFFFFF">   
  <tr>     
  <td><img src="https://upload.wikimedia.org/wikipedia/commons/9/95/Logo_EPFL_2019.svg" width="150"/>
  </td>     
  <td>
  <h1> <b>CS-461: Foundation Models and Generative AI</b> </h1>
  Prof. Charlotte Bunne  
  </td>   
  </tr>
</table>

# 📚 Exercise Session 3 - Code Demonstration: VAEs

In this notebook, we demonstrate a simple implementation of a Variational Autoencoder (VAE) using the MNIST dataset.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Device: run on a CUDA GPU when one is visible to PyTorch, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Hyperparameters & Data Preparation

In [None]:
# Hyperparameters
batch_size = 128
learning_rate = 1e-3
num_epochs = 15
latent_dim = 2  # 2D latent space for visualization
kl_factor = 0.0  # weight of the KL term in the loss; 0.0 optimises the reconstruction term only
seed = 0  # fixes weight init, data shuffling and latent sampling so runs are reproducible

# Seed PyTorch's RNGs (CPU and CUDA) once, before any model is built or data is shuffled
torch.manual_seed(seed)

# MNIST dataset: ToTensor turns each image into a float tensor with values in [0, 1].
# Only the training split triggers a download; the test split reuses ./data.
transform = transforms.Compose([transforms.ToTensor()])
train_dataset, test_dataset = (
    datasets.MNIST(root='./data', train=is_train, transform=transform, download=is_train)
    for is_train in (True, False)
)

# Mini-batch loaders: the training set is reshuffled every epoch, the test set keeps a fixed order
train_loader, test_loader = (
    DataLoader(split, batch_size=batch_size, shuffle=reshuffle)
    for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)

### VAE Model Architecture

We model the encoder as a Gaussian distribution with a mean and a diagonal covariance matrix, both predicted by a convolutional neural network, i.e.,
$$q_\phi(z|x) = \mathcal{N}\big(\mu_\phi(x), \mathrm{diag}(\sigma^2_\phi(x))\big).$$
In practice, the network predicts the log-variance $\log \sigma^2_\phi(x)$ instead of the variance. The log-variance can take any real value, and $\sigma^2 = \exp(\log \sigma^2)$ is then guaranteed to be positive.

In [None]:
class ConvVAE(nn.Module):
    """Convolutional VAE for single-channel 28x28 images (MNIST).

    Encoder: two stride-2 convolutions (28 -> 14 -> 7) followed by an MLP that
    outputs the mean and log-variance of a diagonal Gaussian q(z|x).
    Decoder: an MLP followed by two stride-2 transposed convolutions
    (7 -> 14 -> 28) and a sigmoid, giving per-pixel probabilities in [0, 1].

    Args:
        z_dim: size of the latent space. If omitted, the notebook-level
            ``latent_dim`` hyperparameter is used (read when the model is built),
            so ``ConvVAE()`` behaves exactly as before.
    """

    def __init__(self, z_dim=None):
        super().__init__()
        # Passing z_dim explicitly avoids depending on hidden global notebook state.
        self.latent_dim = latent_dim if z_dim is None else z_dim
        # Encoder
        self.enc_conv1 = nn.Conv2d(1, 32, 3, 2, 1)   # 1x28x28  -> 32x14x14
        self.enc_conv2 = nn.Conv2d(32, 64, 3, 2, 1)  # 32x14x14 -> 64x7x7
        self.enc_fc1 = nn.Linear(64*7*7, 256)
        self.fc_mu = nn.Linear(256, self.latent_dim)
        self.fc_logvar = nn.Linear(256, self.latent_dim)
        # Decoder (mirrors the encoder)
        self.dec_fc = nn.Linear(self.latent_dim, 256)
        self.dec_fc2 = nn.Linear(256, 64*7*7)
        # output_padding=1 makes each transposed conv exactly double the spatial size
        self.dec_conv1 = nn.ConvTranspose2d(64, 32, 3, 2, 1, output_padding=1)  # 7 -> 14
        self.dec_conv2 = nn.ConvTranspose2d(32, 1, 3, 2, 1, output_padding=1)   # 14 -> 28

    def encode(self, x):
        """Map images of shape (B, 1, 28, 28) to (mu, logvar), each of shape (B, latent_dim)."""
        h = F.relu(self.enc_conv1(x))
        h = F.relu(self.enc_conv2(h))
        h = h.view(h.size(0), -1)  # flatten 64x7x7 feature maps
        h = F.relu(self.enc_fc1(h))
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I), keeping the sample differentiable in mu/logvar.

        Always samples, including in eval mode.
        """
        std = torch.exp(0.5*logvar)  # std = sqrt(exp(logvar))
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes of shape (B, latent_dim) to per-pixel probabilities of shape (B, 1, 28, 28)."""
        h = F.relu(self.dec_fc(z))
        h = F.relu(self.dec_fc2(h))
        h = h.view(-1, 64, 7, 7)
        h = F.relu(self.dec_conv1(h))
        return torch.sigmoid(self.dec_conv2(h))

    def forward(self, x):
        """Encode, sample one latent code per image, decode; return (x_recon, mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

### VAE Loss

Since MNIST pixel intensities lie in $[0,1]$, we assume that our probabilistic decoder is a Bernoulli distribution, i.e., $p_\theta(x|z) = \mathrm{Ber}(\hat{x}_\theta(z))$, where $\hat{x}_\theta(z)$ is the output of our decoder network.\
Further, we assume a standard Gaussian prior $p(z) = \mathcal{N}(0,I)$.

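In this case, the loss we minimize (the negative ELBO) becomes:\
$$
\mathcal{L}(x; \theta, \phi)
= \mathbb{E}_{q_\phi(z|x)} \Bigg[ \underbrace{- \sum_{i=1}^D \Big( x_i \log \hat{x}_{\theta,i}(z) + (1 - x_i) \log \big(1 - \hat{x}_{\theta,i}(z)\big) \Big)}_{\text{BCE}(\hat{x}_\theta(z),\, x)} \Bigg]
+ \underbrace{\frac{1}{2} \sum_{j=1}^d \Big( \mu_{\phi,j}(x)^2 + \sigma_{\phi,j}(x)^2 - \log \sigma_{\phi,j}(x)^2 - 1 \Big)}_{D_{\mathrm{KL}}\left(q_\phi(z|x) \,\|\, p(z)\right)}
$$
where $D$ is the number of pixels and $d$ is the latent dimension. In the code below:
- The expectation is approximated with a single reparameterized sample $z = \mu_\phi(x) + \sigma_\phi(x) \odot \epsilon$, with $\epsilon \sim \mathcal{N}(0, I)$.
- The KL term is multiplied by `kl_factor`.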


In [None]:
# Loss
def vae_loss(x_recon, x, mu, logvar, kl_weight=None):
    """Negative ELBO for a Bernoulli decoder and an N(0, I) prior, summed over the batch.

    Args:
        x_recon: decoder output (per-pixel probabilities in [0, 1]), same shape as ``x``.
        x: target images with values in [0, 1].
        mu: mean of the diagonal Gaussian q(z|x), shape (B, latent_dim).
        logvar: log-variance of q(z|x), same shape as ``mu``.
        kl_weight: multiplier for the KL term. If omitted, the notebook-level
            ``kl_factor`` hyperparameter is used (read at call time).

    Returns:
        Scalar tensor ``BCE(x_recon, x) + kl_weight * KL(q(z|x) || N(0, I))``.
    """
    if kl_weight is None:
        kl_weight = kl_factor
    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')
    # Closed-form KL divergence between N(mu, diag(exp(logvar))) and N(0, I)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_weight * kl_loss

In [None]:
# Model and optimizer: rebuilt every time this cell runs, so training always starts from fresh weights
model = ConvVAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training
n_train_images = len(train_loader.dataset)
model.train()
for epoch in range(1, num_epochs + 1):
    running_loss = 0.0
    for images, _targets in train_loader:
        images = images.to(device)
        optimizer.zero_grad()
        recon, mu, logvar = model(images)
        batch_loss = vae_loss(recon, images, mu, logvar)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    # The loss is summed over each batch, so dividing by the dataset size gives the mean loss per image
    print(f'Epoch [{epoch}/{num_epochs}], Loss: {running_loss / n_train_images:.4f}')

In [None]:
# Visualize reconstructed images
def show_digits(axes_row, images):
    """Draw one 28x28 grayscale image per axis in `axes_row`, with all axes hidden.

    Axes and images are paired with zip, so a batch shorter than the row leaves
    the remaining axes blank instead of raising an IndexError.
    """
    for ax in axes_row:
        ax.axis('off')
    for ax, img in zip(axes_row, images):
        ax.imshow(img.view(28, 28), cmap='gray')


model.eval()
with torch.no_grad():
    # First test batch; test_loader is not shuffled, so this is the same batch on every run
    x, _labels = next(iter(test_loader))
    x = x.to(device)
    # forward() still samples z = mu + eps * std here, so reconstructions are stochastic
    x_recon, _mu, _logvar = model(x)

# Plot original and reconstructed images
x = x.cpu()
x_recon = x_recon.cpu()

fig, axes = plt.subplots(2, 10, figsize=(12, 3))
show_digits(axes[0], x)
show_digits(axes[1], x_recon)
plt.suptitle("Top: Original, Bottom: Reconstructed")
plt.show()

# Generate new digits by decoding latent vectors drawn from the N(0, I) prior.
# The latent size is taken from the trained model, so it always matches the decoder
# even if `latent_dim` was edited without retraining.
with torch.no_grad():
    z = torch.randn(10, model.fc_mu.out_features, device=device)
    samples = model.decode(z).cpu()

fig, axes = plt.subplots(1, 10, figsize=(12, 2))
show_digits(axes, samples)
plt.suptitle("Generated digits from random latent vectors")
plt.show()

In [None]:
# Visualize latent space
def project_to_2d(points):
    """Return (coords, (xlabel, ylabel)) for plotting an (N, d) tensor of latent means.

    A 2-D latent space is plotted as-is. For d > 2 (e.g. latent_dim = 20), plotting
    the first two coordinates would show an arbitrary slice, so the points are
    projected onto their top two principal components instead. The components are
    computed with an SVD of the mean-centred data. Assumes d >= 2.
    """
    if points.size(1) == 2:
        return points, ("z1", "z2")
    centred = points - points.mean(dim=0, keepdim=True)
    _, _, vh = torch.linalg.svd(centred, full_matrices=False)
    return centred @ vh[:2].T, ("PC1", "PC2")


# Encode the whole test set; plot the posterior means mu(x), coloured by digit label
model.eval()
all_mu = []
all_labels = []
with torch.no_grad():
    for x, labels in test_loader:
        mu, _logvar = model.encode(x.to(device))
        all_mu.append(mu.cpu())
        all_labels.append(labels)
all_mu = torch.cat(all_mu, dim=0)
all_labels = torch.cat(all_labels, dim=0)

coords, (xlabel, ylabel) = project_to_2d(all_mu)
d = all_mu.size(1)

fig, ax = plt.subplots(figsize=(8, 6))
scatter = ax.scatter(coords[:, 0], coords[:, 1], c=all_labels, cmap='tab10', s=15)
fig.colorbar(scatter, ax=ax, ticks=range(10))
ax.set_xlabel(xlabel)
ax.set_ylabel(ylabel)
ax.set_title("2D Latent Space of MNIST" if d == 2
             else f"{d}D Latent Space of MNIST (first two principal components)")
plt.show()

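### Task
Try out different values for `latent_dim` and `kl_factor`; both are set in the hyperparameter cell, so re-run the notebook from there after changing them. In particular, compare the reconstructions, the generated digits, and the latent-space plot for all four combinations of:\
`latent_dim` $\in \{2,20\}$\
`kl_factor` $\in \{0.0,1.0\}$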