# Introduction to Generative AI and Hands-On with Variational Autoencoders (VAEs)

**Learning Objectives:**
- Understand what Generative AI is and its real-world applications.
- Compare generative models to traditional deep learning approaches.
- Explain autoencoders, latent spaces, and the concept behind VAEs.
- Gain practical experience by exploring latent space interpolation using a VAE on CelebA.

## 🔧 Prerequisites

Before running this notebook locally, make sure to set up your environment:

```bash
# Create a virtual environment
python -m venv genai-env
source genai-env/bin/activate  # On Windows: genai-env\Scripts\activate

# Install required packages (wandb is used for experiment tracking)
pip install torch torchvision matplotlib numpy wandb
```

## Generative AI vs. Traditional Approaches


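**Traditional deep learning** models focus on discriminative tasks such as classification and regression: they learn to map an input to a label or a value.

For example, in image classification, given an image $x$, the model predicts whether it shows a cat or a dog, i.e., it models $p(y \mid x)$.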

<img src="./images/classification-object-detection.png" width="600" style="display: block; margin: auto;">

*Image Source: [ambolt.io](https://ambolt.io/en/image-classification-and-object-detection/)*  

**Generative AI** models, on the other hand, learn the underlying data distribution $p^{*}(x)$ (more precisely, a model $p_\theta(x)$ that approximates it), so that we can generate new, synthetic data similar to the training examples by sampling $x \sim p_\theta(x)$.
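
As a toy illustration of "learn the data distribution, then sample from it", the sketch below fits a one-dimensional Gaussian to synthetic data by maximum likelihood and then draws new samples from the fitted model. The data and its parameters are made up for the example; real generative models such as VAEs replace the Gaussian with a neural network, but the workflow (fit, then sample) is the same.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# Synthetic "training data": 10,000 draws from a distribution the model does not know
# (here secretly a Gaussian with mean 3.0 and standard deviation 0.5).
train_x = rng.normal(loc=3.0, scale=0.5, size=10_000)

# "Training": maximum-likelihood estimates of the Gaussian's parameters.
mu_hat = train_x.mean()
sigma_hat = train_x.std()  # ddof=0 is the maximum-likelihood estimate

# "Generation": sample new, synthetic data points from the learned model.
new_x = rng.normal(loc=mu_hat, scale=sigma_hat, size=5)
print(f"learned mean={mu_hat:.2f}, std={sigma_hat:.2f}")
print("new samples:", np.round(new_x, 2))
```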

An example of an AI-generated image of a kitten:

<img src="./images/AI_generated_kitten.webp" width="300" style="display: block; margin: auto;">

*Image Source: [Dall-E](https://openai.com/dall-e-3)*

## Latent Spaces & Autoencoders


**Latent space** is a simplified, compressed representation of data where similar items are grouped closer together, capturing hidden patterns or features that define their structure. It's like a map that organizes complex data into meaningful, lower-dimensional coordinates.

Example of a latent space learned on the MNIST dataset (here by an adversarial autoencoder):
<img src="./images/aae_latent.png" width="500" style="display: block; margin: auto;">
*Image Source: [https://github.com/greentfrapp/keras-aae](https://github.com/greentfrapp/keras-aae)*

**Autoencoders** are neural networks that learn such a latent space from the data itself, without labels.
- **Autoencoder Architecture:**  
  - **Encoder:** Compresses the input data into a lower-dimensional latent representation.
  - **Decoder:** Reconstructs the input from the latent representation.
- **Loss Function:** Typically, a reconstruction loss (e.g., Mean Squared Error) is used to train the autoencoder.
- **Latent Space:**  
  - It represents the compressed, learned features.
  - Enables operations like interpolation between data points and semantic modifications.

<img src="./images/variational-autoencoder-neural-network.png" width="900" style="display: block; margin: auto;">

*Image Source: [ibm.com](https://www.ibm.com/de-de/think/topics/variational-autoencoder)*

# MNIST Autoencoder in PyTorch

In the following, we’ll build and train a simple autoencoder on the MNIST dataset using PyTorch.


In [1]:
# Import necessary libraries
import torch
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import torch.nn as nn

import matplotlib.pyplot as plt

In [2]:
import wandb
wandb.login()  # only needed once per machine/session

config = {
    "epochs": 40,
    "batch_size": 64,
    "learning_rate": 1e-3,
    "architecture": "Autoencoder",
    "dataset": "MNIST",
    "latent_dim": 10,
    "train_split": 0.8
}

wandb.init(
    project="mnist-autoencoder",
    config=config,
)


## Load the MNIST Dataset

We use the torchvision `MNIST` dataset class and apply a basic transformation to convert images to tensors.


In [3]:
# Load MNIST dataset
transform = ToTensor()

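train_data = MNIST(root='./data', train=True, transform=transform, download=True)
# The official MNIST test split serves as the validation set here.
val_data = MNIST(root='./data', train=False, transform=transform, download=True)

# Use the batch size from the W&B config so the logged config matches the run.
train_loader = DataLoader(train_data, batch_size=wandb.config.batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=wandb.config.batch_size, shuffle=False)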

## Define the Autoencoder Model

The autoencoder has two main parts:  
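- **Encoder:** Compresses the flattened 28×28 image (784 values) into a small latent vector
- **Decoder:** Reconstructs the image from the latent vector

We use linear layers with ReLU activations and end the decoder with a sigmoid activation, so that reconstructed pixel values lie in [0, 1], the same range as the `ToTensor` inputs.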
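
To make the data flow concrete, the sketch below traces tensor shapes through an encoder and decoder with the same layer sizes as the model in the next cell, using a dummy batch of 64 random "images". It hard-codes `latent_dim = 10` (an assumption matching the W&B config above) instead of reading `wandb.config`, so it runs on its own.

```python
import torch
import torch.nn as nn

latent_dim = 10  # assumption: matches config["latent_dim"] above

encoder = nn.Sequential(
    nn.Linear(28 * 28, 128), nn.ReLU(),
    nn.Linear(128, 64), nn.ReLU(),
    nn.Linear(64, latent_dim),
)
decoder = nn.Sequential(
    nn.Linear(latent_dim, 64), nn.ReLU(),
    nn.Linear(64, 128), nn.ReLU(),
    nn.Linear(128, 28 * 28), nn.Sigmoid(),
)

x = torch.rand(64, 1, 28, 28)     # dummy batch: (batch, channels, height, width)
x_flat = x.flatten(start_dim=1)   # (64, 784): one row per image
z = encoder(x_flat)               # (64, 10): latent codes
x_hat = decoder(z).reshape_as(x)  # (64, 1, 28, 28): back to image shape
for name, t in [("input", x), ("flattened", x_flat), ("latent", z), ("output", x_hat)]:
    print(f"{name:>9}: {tuple(t.shape)}")
```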


In [4]:
# Define the Autoencoder model
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, wandb.config.latent_dim),
        )
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(wandb.config.latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 28 * 28),
            nn.Sigmoid()
        )

    def forward(self, x):
        x_flat = x.flatten(start_dim=1)

        encoded = self.encoder(x_flat)
        decoded = self.decoder(encoded)

        return decoded.reshape_as(x)


## Initialize the Model, Loss Function, and Optimizer

We’ll use:
- **MSELoss** to measure the reconstruction error.
- The **Adam** optimizer for training.

Make sure to define the device (CPU or GPU) before calling `.to(device)`.
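
With its default `reduction='mean'`, `nn.MSELoss` averages the squared differences over every element of the batch (all images and all pixels). A quick check on random tensors whose shapes mimic an MNIST batch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
outputs = torch.rand(8, 1, 28, 28)  # stand-in for reconstructions
targets = torch.rand(8, 1, 28, 28)  # stand-in for original images

builtin = nn.MSELoss()(outputs, targets)
manual = ((outputs - targets) ** 2).mean()  # mean over all 8 * 1 * 28 * 28 elements
print(builtin.item(), manual.item())
```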


In [5]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model, loss function, and optimizer
autoencoder = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=wandb.config.learning_rate)


## Train the Autoencoder

Loop through the dataset for multiple epochs.  
For each batch:
- Forward pass through the model  
- Compute the loss  
- Backpropagate and update weights  
- Track the loss for monitoring
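
These four steps can be tried in isolation before the full 40-epoch run. The sketch below is a smoke test with a small stand-in model (an assumption; it is not the `Autoencoder` class): it repeats the steps for 100 iterations on one fixed batch of random "images" and records the loss, which should go down as the model fits the batch.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(  # small stand-in autoencoder: 784 -> 10 -> 784
    nn.Flatten(),
    nn.Linear(28 * 28, 10), nn.ReLU(),
    nn.Linear(10, 28 * 28), nn.Sigmoid(),
    nn.Unflatten(1, (1, 28, 28)),
)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch = torch.rand(16, 1, 28, 28)  # one fixed synthetic batch

losses = []
for step in range(100):
    optimizer.zero_grad()             # clear gradients from the previous step
    outputs = model(batch)            # forward pass
    loss = criterion(outputs, batch)  # reconstruction error
    loss.backward()                   # backpropagate
    optimizer.step()                  # update weights
    losses.append(loss.item())        # track the loss

print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")
```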


In [6]:
def evaluate_model(model, dataloader, criterion):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for images, _ in dataloader:
            images = images.to(device)
            outputs = model(images)
            loss = criterion(outputs, images)
            val_loss += loss.item()
    return val_loss / len(dataloader)

In [7]:
import matplotlib.pyplot as plt

def log_reconstruction_images(model, dataloader, device, n=6, tag="Reconstruction"):
    model.eval()
    images, _ = next(iter(dataloader))
    images = images.to(device)
    with torch.no_grad():
        outputs = model(images)

    # Plot original and reconstructed
    fig, axs = plt.subplots(2, n, figsize=(n * 2, 4))
    for i in range(n):
        axs[0, i].imshow(images[i].squeeze().cpu(), cmap="gray")
        axs[0, i].set_title("Original")
        axs[0, i].axis("off")
        axs[1, i].imshow(outputs[i].squeeze().cpu(), cmap="gray")
        axs[1, i].set_title("Reconstructed")
        axs[1, i].axis("off")
    plt.tight_layout()

    # Log to wandb
    wandb.log({tag: wandb.Image(fig)})
    plt.close(fig)  # free the figure's memory after logging

In [8]:
for epoch in range(wandb.config.epochs):
    autoencoder.train()
    total_loss = 0

    for images, _ in train_loader:
        images = images.to(device)
        optimizer.zero_grad()
        outputs = autoencoder(images)
        loss = criterion(outputs, images)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_loader)
    avg_val_loss = evaluate_model(autoencoder, val_loader, criterion)

    wandb.log({
        "epoch": epoch + 1,
        "train_loss": avg_train_loss,
        "val_loss": avg_val_loss,
    })

    print(f"Epoch {epoch+1}/{wandb.config.epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

    # Every 5 epochs: log images
    if (epoch + 1) % 5 == 0:
        log_reconstruction_images(autoencoder, val_loader, device)

wandb.finish()



Epoch 1/40, Train Loss: 0.0499, Val Loss: 0.0322
Epoch 2/40, Train Loss: 0.0279, Val Loss: 0.0242
Epoch 3/40, Train Loss: 0.0231, Val Loss: 0.0215
Epoch 4/40, Train Loss: 0.0212, Val Loss: 0.0203
Epoch 5/40, Train Loss: 0.0202, Val Loss: 0.0193
Epoch 6/40, Train Loss: 0.0194, Val Loss: 0.0189
Epoch 7/40, Train Loss: 0.0188, Val Loss: 0.0184
Epoch 8/40, Train Loss: 0.0184, Val Loss: 0.0179
Epoch 9/40, Train Loss: 0.0180, Val Loss: 0.0176
Epoch 10/40, Train Loss: 0.0177, Val Loss: 0.0175
Epoch 11/40, Train Loss: 0.0174, Val Loss: 0.0172
Epoch 12/40, Train Loss: 0.0172, Val Loss: 0.0169
Epoch 13/40, Train Loss: 0.0170, Val Loss: 0.0167
Epoch 14/40, Train Loss: 0.0168, Val Loss: 0.0167
Epoch 15/40, Train Loss: 0.0166, Val Loss: 0.0168
Epoch 16/40, Train Loss: 0.0165, Val Loss: 0.0164
Epoch 17/40, Train Loss: 0.0163, Val Loss: 0.0162
Epoch 18/40, Train Loss: 0.0162, Val Loss: 0.0161
Epoch 19/40, Train Loss: 0.0161, Val Loss: 0.0160
Epoch 20/40, Train Loss: 0.0160, Val Loss: 0.0160
Epoch 21/

*W&B run history: sparkline plots of `epoch`, `train_loss`, and `val_loss` over the 40 epochs.*

**W&B run summary:**

| Metric | Final value |
|---|---|
| epoch | 40 |
| train_loss | 0.01475 |
| val_loss | 0.015 |


## Save the Trained Model

Save the trained weights (the model's `state_dict`) so the model can be reloaded later without retraining.


In [10]:
from pathlib import Path

# Save the trained model (optional)
model_dir = "autoencoder-mnist"

# Create the directory if it doesn't exist
Path(model_dir).mkdir(parents=True, exist_ok=True)

# Save the model
torch.save(autoencoder.state_dict(), Path(model_dir) / "autoencoder_mnist.pth")
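
To reuse saved weights, create a fresh model with the same architecture and load the `state_dict` into it. The sketch below shows the round trip on a small stand-in model (an assumption; in this notebook you would instantiate `Autoencoder()` and load `autoencoder-mnist/autoencoder_mnist.pth`), writing to a temporary directory so it leaves no files behind.

```python
import tempfile
from pathlib import Path

import torch
import torch.nn as nn

def make_model():
    # Stand-in architecture; saved and restored models must match for load_state_dict.
    return nn.Sequential(nn.Linear(784, 10), nn.ReLU(), nn.Linear(10, 784))

torch.manual_seed(0)
trained = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.pth"
    torch.save(trained.state_dict(), path)

    restored = make_model()  # freshly initialized, so its weights differ at first
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine.
    restored.load_state_dict(torch.load(path, map_location="cpu"))
    restored.eval()  # switch to inference mode before evaluating

x = torch.rand(2, 784)
same = torch.equal(trained(x), restored(x))
print("outputs identical:", same)
```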


## Variational Autoencoders (VAEs)

**Motivation for VAEs:**
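- **Regularization:** Instead of mapping each input to a single point, the VAE encoder outputs the parameters (mean $\mu$ and log-variance $\log \sigma^2$) of a Gaussian distribution over latent vectors. Combined with the KL divergence term, this encourages a continuous, smooth latent space.
- **Key Components:**  
  - **KL Divergence:** A regularization term that encourages the latent distribution to be close to a prior (usually a standard Gaussian).
  - **Reparameterization Trick:** Writes the latent sample as $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, so gradients can flow through $\mu$ and $\sigma$ even though $z$ is random (backpropagation through a stochastic node).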
- **Benefits:**  
  - Smooth interpolation between data points.
  - More meaningful manipulations in the latent space.
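
For a Gaussian encoder $q(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and a standard Gaussian prior $\mathcal{N}(0, I)$, the KL term has a closed form. With a squared-error reconstruction term summed over pixels (as in the training cell later in this notebook), the per-image loss over a $d$-dimensional latent space can be written as:

$$
\mathcal{L}(x) = \underbrace{\lVert x - \hat{x} \rVert_2^2}_{\text{reconstruction}} + \underbrace{D_{\mathrm{KL}}\big(\mathcal{N}(\mu, \operatorname{diag}(\sigma^2)) \,\|\, \mathcal{N}(0, I)\big)}_{\text{regularization}},
\qquad
D_{\mathrm{KL}} = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

Here $\hat{x}$ is the decoder's reconstruction of $z = \mu + \sigma \odot \epsilon$. The KL term is zero exactly when $\mu_j = 0$ and $\sigma_j = 1$ for every dimension, i.e., when the encoder's distribution equals the prior.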

*Concept Check:* Think about how changing parts of a latent vector might modify attributes (e.g., facial expression, pose) in generated images.
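
The reparameterization trick is what lets gradients reach the encoder. In the sketch below, random tensors stand in for the encoder outputs $\mu$ and $\log\sigma^2$ (an assumption for illustration). Because the randomness lives in `eps`, which does not depend on the parameters, `z` stays differentiable with respect to both. `torch.distributions.Normal` exposes the same idea through `rsample()`, while its `sample()` returns a tensor detached from the parameters.

```python
import torch

torch.manual_seed(0)
latent_dim = 4
# Stand-ins for encoder outputs (one image, 4 latent dimensions).
mu = torch.randn(1, latent_dim, requires_grad=True)
logvar = torch.randn(1, latent_dim, requires_grad=True)

# Reparameterization: z = mu + eps * std, with eps independent of mu and logvar.
std = torch.exp(0.5 * logvar)
eps = torch.randn_like(std)
z = mu + eps * std

# Any downstream loss can now be backpropagated to mu and logvar.
z.sum().backward()
print("dz/dmu:", mu.grad)          # all ones, since z is linear in mu
print("dz/dlogvar:", logvar.grad)  # equals 0.5 * eps * std

# Same trick via torch.distributions: rsample() keeps the graph, sample() does not.
dist = torch.distributions.Normal(mu, std)
print(dist.rsample().requires_grad, dist.sample().requires_grad)
```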


## Section 4: Hands-On Exercise – VAE with CelebA

In this section, we’ll work with a VAE on the CelebA dataset. The goals are:
- Load and preprocess the CelebA dataset.
- Define and (optionally) load a pretrained VAE model.
- Encode images to obtain their latent representations.
- Interpolate between latent vectors and decode the results to see how the generated images change.

**Note:** Training a VAE on CelebA from scratch is computationally intensive. For this demonstration, you can either:
- Use a pretrained model checkpoint (if available), or
- Train a simplified model on a subset of the dataset for a few epochs.


In [None]:
# Cell 1: Setup and Library Imports

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


In [None]:
# Cell 2: Data Preparation - Load CelebA (subset)

# Define transformations for the CelebA dataset
transform = transforms.Compose([
    transforms.CenterCrop(178),
    transforms.Resize(64),  # Resize to a smaller size for quick experimentation
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # map [0, 1] to [-1, 1] per RGB channel
])

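# Download the CelebA dataset (this may take a while). The files are hosted on Google Drive,
# so the automatic download can fail when the download quota is exceeded; in that case,
# place the files under ./data/celeba manually and set download=False.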
celeba_data = datasets.CelebA(root='./data', split='train', transform=transform, download=True)

# Create a DataLoader with a small subset (e.g., first 500 images)
subset_size = 500
subset_indices = list(range(subset_size))
celeba_subset = torch.utils.data.Subset(celeba_data, subset_indices)
data_loader = DataLoader(celeba_subset, batch_size=64, shuffle=True)

print("Number of images in subset:", len(celeba_subset))


## Section 5: Defining a Simple VAE Model

Below is a simplified convolutional VAE model. In practice, you might use a more refined architecture or a pretrained checkpoint. For this exercise, we define the model and show how to encode images into latent vectors and decode latent vectors back into images.
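
Each encoder convolution uses kernel size 4, stride 2 and padding 1, which halves the spatial size: $\lfloor (64 + 2 \cdot 1 - 4)/2 \rfloor + 1 = 32$. Each transposed convolution with the same settings doubles it: $(4 - 1) \cdot 2 - 2 \cdot 1 + 4 = 8$. This is why the flattened encoder output has $256 \cdot 4 \cdot 4$ features. The sketch below checks both formulas against PyTorch layers with the same hyperparameters; the helper names `conv_out` and `conv_transpose_out` are made up for the example.

```python
import torch
import torch.nn as nn

def conv_out(size, kernel=4, stride=2, padding=1):
    # Output size of nn.Conv2d with dilation 1: floor((size + 2p - k) / s) + 1
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out(size, kernel=4, stride=2, padding=1):
    # Output size of nn.ConvTranspose2d with dilation 1 and output_padding 0
    return (size - 1) * stride - 2 * padding + kernel

down = [64]
for _ in range(4):
    down.append(conv_out(down[-1]))
print("encoder:", down)  # 64 -> 32 -> 16 -> 8 -> 4

up = [4]
for _ in range(4):
    up.append(conv_transpose_out(up[-1]))
print("decoder:", up)    # 4 -> 8 -> 16 -> 32 -> 64

# Cross-check the formulas against real layers.
y = nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1)(torch.zeros(1, 3, 64, 64))
w = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)(torch.zeros(1, 256, 4, 4))
print(tuple(y.shape), tuple(w.shape))
```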


In [None]:
# Cell 3: Define the VAE Model

class VAE(nn.Module):
    def __init__(self, latent_dim=32):
        super(VAE, self).__init__()
        self.latent_dim = latent_dim
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),  # 64 -> 32
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 32 -> 16
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # 16 -> 8
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1), # 8 -> 4
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(256*4*4, latent_dim)
        self.fc_logvar = nn.Linear(256*4*4, latent_dim)
        
        # Decoder
        self.decoder_input = nn.Linear(latent_dim, 256*4*4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1), # 4 -> 8
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 8 -> 16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # 16 -> 32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),    # 32 -> 64
            nn.Tanh()  # Output values in [-1, 1] due to normalization
        )
        
    def encode(self, x):
        x_enc = self.encoder(x)
        x_enc = x_enc.view(x_enc.size(0), -1)
        mu = self.fc_mu(x_enc)
        logvar = self.fc_logvar(x_enc)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std
    
    def decode(self, z):
        x_dec = self.decoder_input(z)
        x_dec = x_dec.view(-1, 256, 4, 4)
        x_recon = self.decoder(x_dec)
        return x_recon
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decode(z)
        return x_recon, mu, logvar

# Initialize the VAE model and move it to the device
latent_dim = 32
model = VAE(latent_dim=latent_dim).to(device)
print(model)


## Section 6: (Optional) Training the VAE

Training a VAE can be time-consuming. For demonstration, you can either load a pretrained model or train for a few epochs on the subset. Below is a simple training loop that you can run if you wish to see the model learn.
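
The KL term in `loss_function` uses the closed-form expression for the divergence between the encoder's Gaussian and the standard Gaussian prior. One way to convince yourself the formula is right is to compare it with `torch.distributions.kl_divergence` on random encoder outputs (random values stand in for `mu` and `logvar`):

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
mu = torch.randn(8, 32)      # batch of 8 images, latent_dim = 32
logvar = torch.randn(8, 32)

# Closed form, as used in loss_function.
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Reference: KL(N(mu, sigma^2) || N(0, 1)) per latent dimension, summed over everything.
q = Normal(mu, torch.exp(0.5 * logvar))
p = Normal(torch.zeros_like(mu), torch.ones_like(mu))
kl_ref = kl_divergence(q, p).sum()
print(kl_closed.item(), kl_ref.item())
```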


In [None]:
# Cell 4: Training Setup (Optional)

# Loss function components: Reconstruction Loss and KL Divergence
def loss_function(recon_x, x, mu, logvar):
    # Reconstruction loss (MSE)
    recon_loss = nn.functional.mse_loss(recon_x, x, reduction='sum')
    # KL Divergence
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + kl_loss

optimizer = optim.Adam(model.parameters(), lr=1e-3)
num_epochs = 3  # For demonstration; increase for better results

model.train()
for epoch in range(num_epochs):
    train_loss = 0
    for batch_idx, (data, _) in enumerate(data_loader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    
    print(f"Epoch {epoch+1} Average Loss: {train_loss/len(celeba_subset):.4f}")

# After training, you can save the model if desired:
# torch.save(model.state_dict(), "vae_celeba.pth")


## Section 7: Latent Space Manipulation & Interpolation

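Now that our VAE is (optionally) trained, let’s demonstrate latent space interpolation. If you skipped training, the decoded images will not resemble faces, because the model weights are still random. We will: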
1. Select two images from the dataset.
2. Encode them to obtain their latent vectors.
3. Linearly interpolate between the two latent vectors.
4. Decode each interpolated vector to visualize the transition between the two images.


In [None]:
# Cell 5: Latent Space Interpolation Demo

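def interpolate(z1, z2, num_steps=8):
    """Linearly interpolate between latent vectors z1 and z2 (each of shape (1, latent_dim)).

    Returns a tensor of shape (num_steps, latent_dim) that starts at z1 and ends at z2.
    """
    ratios = torch.linspace(0, 1, num_steps, device=z1.device).unsqueeze(1)  # (num_steps, 1)
    return (1 - ratios) * z1 + ratios * z2  # broadcasts to (num_steps, latent_dim)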

# Get two images from the dataset
data_iter = iter(data_loader)
images, _ = next(data_iter)
img1 = images[0].unsqueeze(0).to(device)
img2 = images[1].unsqueeze(0).to(device)

# Encode images
model.eval()
with torch.no_grad():
    mu1, _ = model.encode(img1)
    mu2, _ = model.encode(img2)

# Interpolate in the latent space
num_steps = 8
interpolated_z = interpolate(mu1, mu2, num_steps=num_steps)

# Decode the interpolated latent vectors
with torch.no_grad():
    decoded_imgs = model.decode(interpolated_z).cpu()

# Plot the interpolation results
fig, axes = plt.subplots(1, num_steps, figsize=(20, 3))
for i, ax in enumerate(axes):
    # Denormalize image from [-1,1] to [0,1]
    img = (decoded_imgs[i].permute(1, 2, 0).numpy() + 1) / 2.0
    ax.imshow(np.clip(img, 0, 1))
    ax.axis('off')
plt.suptitle("Latent Space Interpolation between Two CelebA Images", fontsize=16)
plt.show()


## Section 8: Discussion & Next Steps

- **Discussion Points:**  
  - What do you observe in the interpolation results?  
  - How does the latent space capture semantic features of faces?
  - How might you extend this approach (e.g., by conditioning on attributes)?

- **Further Experiments:**  
  - Try random sampling: draw latent vectors from the prior $\mathcal{N}(0, I)$ and decode them.
  - Modify specific dimensions in the latent vector to observe changes in output attributes.
  - Explore different architectures or pretrained models for improved image quality.
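
The first two experiments can be sketched as small helpers that accept any decoder callable, for example `model.decode` from the VAE above. The function names `sample_from_prior` and `traverse_dimension` are hypothetical, and the dummy linear "decoder" in the usage lines only stands in for a trained model so the snippet runs on its own.

```python
import torch

@torch.no_grad()
def sample_from_prior(decode_fn, n, latent_dim, device="cpu"):
    """Draw n latent vectors from the N(0, I) prior and decode them."""
    z = torch.randn(n, latent_dim, device=device)
    return decode_fn(z)

@torch.no_grad()
def traverse_dimension(decode_fn, z, dim, values):
    """Decode copies of one latent vector z (shape (1, D)) with dimension `dim` set to each value."""
    zs = z.repeat(len(values), 1)  # (len(values), D)
    zs[:, dim] = torch.as_tensor(values, dtype=z.dtype, device=z.device)
    return decode_fn(zs)

# Usage with a stand-in decoder (replace with model.decode after training).
latent_dim = 32
dummy_decoder = torch.nn.Linear(latent_dim, 3 * 64 * 64)
samples = sample_from_prior(dummy_decoder, n=8, latent_dim=latent_dim)
sweep = traverse_dimension(dummy_decoder, torch.zeros(1, latent_dim), dim=0,
                           values=torch.linspace(-3, 3, 7))
print(samples.shape, sweep.shape)
```

With the VAE, `sample_from_prior(model.decode, 8, latent_dim, device)` would return a batch of shape (8, 3, 64, 64) that can be plotted like the interpolation results.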

---

## Homework / Reflection

- **Write a short summary:** Explain in your own words what a VAE is and how the latent space allows for creative manipulation of images.
- **Experiment:** If you have extra time, modify the interpolation function to explore non-linear paths in latent space.

---

**Recommended Reading & Resources:**
- [An Introduction to Variational Autoencoders (Kingma & Welling, arXiv)](https://arxiv.org/abs/1906.02691)
- [PyTorch VAE Tutorial](https://github.com/pytorch/examples/tree/main/vae)
- [Understanding Latent Space](https://towardsdatascience.com/understanding-latent-space-in-machine-learning-4de5473c44f)

---

*End of Day 1 Notebook*

