# Deep Learning Course - Spring 2025 - Sharif University of Technology
## Homework 4 - VAE (100 points)

*Instructor:  Dr. Soleymani*

---

*Full Name:*

*SID:*

---

### Variational AutoEncoder(VAE)
In this notebook we want to implement VAE and also get familiar with latent space and downstream tasks we can do with these latents.

First, let's download the dataset (Fashion-MNIST) and create the train and test data loaders.
These parts of code are implemented and you don't need to change them.

In [None]:
import argparse
import os
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# A worker process and pinned host memory are only requested when batches
# will be copied to a GPU; on CPU the DataLoader defaults are used.
if device == 'cuda':
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {}

# Fashion-MNIST splits stored under ./data; only the training split triggers a download.
train_data = datasets.FashionMNIST(
    './data', train=True, download=True, transform=transforms.ToTensor(),
)
test_data = datasets.FashionMNIST(
    './data', train=False, transform=transforms.ToTensor(),
)

# Both loaders serve shuffled mini-batches of 128 images.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=128, shuffle=True, **kwargs,
)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=128, shuffle=True, **kwargs,
)

In [None]:
import matplotlib.pyplot as plt

def plot_image_with_label(image, label):
    """Show one grayscale image titled with its Fashion-MNIST class name.

    Args:
        image: Tensor holding a single image. Singleton dimensions (such as a
            leading channel axis) are squeezed away before plotting. The
            tensor may live on any device and may be part of an autograd graph.
        label: Integer index (0-9) into the class-name list below.
    """
    # Map from numerical labels to real labels
    fashion_mnist_classes = [
        'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
    ]

    # Tensor.numpy() raises for tensors that require grad or sit on a GPU, so
    # detach from the graph and copy to host memory first; then drop the
    # channel dimension for imshow.
    image = image.detach().cpu().squeeze().numpy()

    # Plot the image
    plt.imshow(image, cmap='gray')
    plt.title(f"Label: {fashion_mnist_classes[label]}")
    plt.axis('off')  # Hide the axes
    plt.show()

In [None]:
# Preview the first training sample: element 0 of the dataset item is the image, element 1 its label.
plot_image_with_label(train_data[0][0], train_data[0][1])

### AutoEncoder
In this cell, you will implement an AutoEncoder. An AutoEncoder consists of an encoder and a decoder. The encoder is responsible for mapping an image to the latent space, where the image is compressed into a low-dimensional vector.

Fashion-MNIST images have (28, 28) shape, which is a (784, ) vector of real numbers. We want to compress this vector and just keep the most important and informational part of data (like what we do in PCA).

Then in decoder, we need to reconstruct the latent vector and recreate the given image with minimum loss. So the decoder does the opposite of what encoder does.

In [None]:

class AutoEncoder(nn.Module):
    """Fully connected autoencoder for 28x28 images.

    The encoder flattens its input and maps 784 -> 256 -> 128 -> ``latent_dim``;
    the decoder mirrors it (``latent_dim`` -> 128 -> 256 -> 784) and ends in a
    sigmoid, so every reconstructed value lies in [0, 1].
    """

    def __init__(self, latent_dim=2):
        super().__init__()
        self.latent_dim = latent_dim
        n_pixels = 28 * 28

        # Compression path (image -> latent code).
        encoder_layers = [
            nn.Flatten(),
            nn.Linear(n_pixels, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        # Reconstruction path (latent code -> flat image).
        decoder_layers = [
            nn.Linear(latent_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, n_pixels),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return the latent code produced by the encoder for ``x``."""
        return self.encoder(x)

    def decode(self, z):
        """Return the flat (784-value) reconstruction decoded from ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode then decode ``x``; returns ``(reconstruction, latent)``."""
        latent = self.encode(x)
        reconstruction = self.decode(latent)
        return reconstruction, latent


Define the model, optimizer, and loss function here.
The reconstruction loss forces the model to recreate the given image from its latent vector z with minimum reconstruction error.

In [None]:

# model and optimizer
# AutoEncoder() uses its default latent size (2). `optimizer` is a module-level
# name that train() reads; a later cell rebinds it for the VAE.
AE = AutoEncoder().to(device)
optimizer = optim.Adam(AE.parameters(), lr=1e-3)

# loss function (binary cross entropy)
def loss_function(recon_x, x):
    """Return the binary cross-entropy of `recon_x` against `x`, summed over every element."""
    return F.binary_cross_entropy(recon_x, x, reduction='sum')


Train the model with the reconstruction loss

In [None]:

def train(epochs, model):
    """Train `model` on the module-level `train_loader` for `epochs` epochs.

    Uses the module-level `optimizer`, `loss_function` and `device`. After
    each epoch the summed loss divided by the dataset size is printed.
    """
    model.train()
    for epoch_idx in range(epochs):
        running_loss = 0
        for images, _ in train_loader:
            # images: [batch size, 1, 28, 28] -> flattened to [batch size, 784]
            flat = images.to(device).view(-1, 28*28)

            optimizer.zero_grad()
            reconstruction, _ = model(flat)
            batch_loss = loss_function(reconstruction, flat)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        dataset_size = len(train_loader.dataset)
        print(f"Epoch {epoch_idx+1}: avg loss {running_loss/dataset_size:.4f}")


Reconstruct the first image of the test dataset and plot it next to the original.

In [None]:

# NOTE(review): no call to train(...) with AE is visible in this file, so unless
# AE was trained elsewhere this reconstruction comes from randomly initialised weights.
image, label = test_data[0]
plot_image_with_label(image, label)

# Encode the image with AE, decode it again, and plot the reconstruction.
with torch.no_grad():
    # Add a batch axis, move to the model's device and flatten to a single row.
    x_decode, z = AE(image.unsqueeze(0).to(device).view(1, -1))
    plot_image_with_label(x_decode.view(28, 28).to("cpu"), label)


Try to sample from the AE: create random latent vectors and decode them. You will probably get meaningless images, because nothing constrains the AE's latent space to follow a known distribution.

In [None]:

def sample_from_AE():
    """Decode one standard-normal latent vector with the global `AE` and plot it."""
    with torch.no_grad():
        # One random latent code of the autoencoder's latent size.
        latent = torch.randn(1, AE.latent_dim, device=device)
        decoded = AE.decode(latent)

    picture = decoded.view(28, 28).cpu()
    # The label only feeds the plot title; a random latent has no true class.
    plot_image_with_label(picture, 1) # random label

In [None]:
# Draw and plot five independent samples from the autoencoder's decoder.
for _ in range(5):
    sample_from_AE()




# Create an instance of the MNIST dataset whose images are normalised with mean 0.5 and std 0.5.
# NOTE(review): `dataset`, `dataloader` and `batch_size` are not read by any later
# module-level code visible in this file, and `transform` is rebound before the MNIST
# fine-tuning cell — this cell looks like a leftover; confirm before relying on it.



batch_size = 128
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Create a DataLoader instance for loading the data in shuffled batches
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)



In [None]:

class VariationalAE(AutoEncoder):
    """Variational autoencoder built on the AutoEncoder backbone.

    The encoder is reused up to its last hidden layer; two linear heads then
    produce the mean and log-variance of the approximate posterior. The
    backbone's final encoder projection is bypassed by `encode` and is
    therefore not used by this class.

    Args:
        latent_dim: Size of the latent space (default 2, as before).
    """

    def __init__(self, latent_dim=2):
        super().__init__(latent_dim)
        # Width of the last hidden layer = input width of the backbone's
        # final encoder projection.
        hidden_dim = self.encoder[-1].in_features
        self.mu_layer = nn.Linear(hidden_dim, self.latent_dim)
        self.logvar_layer = nn.Linear(hidden_dim, self.latent_dim)

    def encode(self, x):
        """Return `(mu, logvar)` of the approximate posterior for `x`."""
        # Run every encoder module except the final projection to latent_dim.
        h = self.encoder[:-1](x)
        mu = self.mu_layer(h)
        logvar = self.logvar_layer(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample `z = mu + eps * std` with `eps ~ N(0, I)`.

        Writing the sample this way keeps it differentiable with respect to
        `mu` and `logvar`.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Return `(reconstruction, mu, logvar)` for the input batch `x`."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar


Define model, optimizer, and loss here and then train vae.

The loss function in a Variational Autoencoder (VAE) consists of two components:

1. **Reconstruction Loss** (`criterion(decoded, images)`)  
   - Measures how well the decoded output matches the original input.  
   - Typically, Mean Squared Error (MSE) or Binary Cross-Entropy (BCE) is used.  
   - Encourages the VAE to accurately reconstruct inputs.

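2. **Kullback-Leibler Divergence (KLD)**  
   ```python
   KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
   ```

**KLD** measures how much the learned latent distribution deviates from a standard normal distribution (N(0, I)). The expression above is the closed-form KL divergence between a diagonal Gaussian and the standard normal: $D_{KL}\big(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\big) = -\tfrac{1}{2}\sum_j \big(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\big)$, where `logvar` stores $\log\sigma^2$. It encourages the latent space to be continuous and structured, improving generative capabilities.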

In [None]:

# Fresh VAE on the selected device. This rebinds the module-level `optimizer`
# (previously bound to the AE's Adam instance) to one over the VAE's parameters.
VAE = VariationalAE().to(device)
optimizer = optim.Adam(VAE.parameters(), lr=1e-3)

def loss_function(recon_x, x, mu, logvar):
    """Return the VAE loss: summed reconstruction BCE plus the KL term.

    The KL term is the closed-form divergence between the diagonal Gaussian
    described by `mu` / `logvar` and a standard normal, summed over all
    elements.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_elements)
    return reconstruction_term + kl_term


In [None]:

def train(epochs, model):
    """Train a VAE-style `model` on the module-level `train_loader`.

    Relies on the module-level `optimizer`, `loss_function` and `device`.
    `model` must return `(reconstruction, mu, logvar)`. After each epoch the
    summed loss divided by the dataset size is printed.
    """
    for epoch_idx in range(epochs):
        model.train()
        running_loss = 0
        for images, _ in train_loader:
            # images: [batch size, 1, 28, 28] -> flattened to [batch size, 784]
            flat = images.to(device).view(-1, 28*28)

            optimizer.zero_grad()
            reconstruction, mu, logvar = model(flat)
            batch_loss = loss_function(reconstruction, flat, mu, logvar)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item()

        dataset_size = len(train_loader.dataset)
        print(f"Epoch {epoch_idx+1}: avg loss {running_loss/dataset_size:.4f}")

# Train the VAE for 15 epochs with the VAE-specific train() defined in this cell.
epochs = 15
train(epochs, VAE)


Sample from the VAE: draw latent vectors from a standard normal distribution N(0, I), decode them, and plot the results.

In [None]:

import numpy as np

def plot_multiple_images(N, images):
    """Plot the first `N` images in a near-square grid.

    Args:
        N: Number of images to draw.
        images: Tensor whose dimension 1 is a singleton channel axis (it is
            squeezed away); the first `N` entries along dimension 0 are shown.
    """
    # detach()/cpu() make the conversion safe for GPU tensors and for tensors
    # that are still attached to an autograd graph.
    images = images.detach().cpu().numpy().squeeze(1)
    cols = int(np.ceil(np.sqrt(N)))
    rows = int(np.ceil(N / cols))
    # squeeze=False always returns a 2-D array of axes; without it a 1x1 grid
    # (N == 1) yields a bare Axes object that has no .flatten().
    fig, axes = plt.subplots(rows, cols, figsize=(cols, rows), squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        if i < N:
            ax.imshow(images[i], cmap='gray')
            ax.axis('off')
        else:
            # Unused cells of the grid are removed rather than left blank.
            ax.remove()

    plt.tight_layout()
    plt.show()

def sample_from_model(N, model):
    """Decode `N` latent vectors drawn from N(0, I) with `model` and plot them.

    The prior p(z) = N(0, I) is the distribution the KL term pulls the
    posterior towards, so it is the distribution to sample from.
    """
    with torch.no_grad():
        latents = torch.randn(N, model.latent_dim, device=device)
        decoded = model.decode(latents)

    # Restore the image layout expected by the plotting helper.
    plot_multiple_images(N, decoded.view(N, 1, 28, 28).cpu())


In [None]:
# Generate nine images from the VAE; nine images are laid out by the helper as a 3x3 grid.
sample_from_model(9, VAE)

# Traversing Latent Dimensions in VAE

This function explores how individual dimensions of the **latent space** influence the generated output. By systematically varying one latent dimension at a time, we can understand how each dimension encodes different aspects of the data.

Create a range of values for each of the two latent dimensions, then decode every combination and plot all of them in a single grid.

In [None]:
def plot_along_axis(model, low=-3.0, high=3.0, steps=10):
    """Decode a regular grid over a 2-D latent space and plot every result.

    Each latent axis is sampled at `steps` evenly spaced values in
    [`low`, `high`]; every (z1, z2) combination is decoded and the images are
    shown in one grid. Because both axes use the same number of steps, each
    row of the plotted grid holds one z1 value and each column one z2 value.

    Args:
        model: Model exposing `decode(z)` for latent vectors of size 2.
        low: Smallest value visited on each latent axis.
        high: Largest value visited on each latent axis.
        steps: Number of values per axis (the grid has steps * steps images).

    Raises:
        ValueError: If the model reports a latent size other than 2.
    """
    if getattr(model, "latent_dim", 2) != 2:
        raise ValueError("plot_along_axis only supports a 2-dimensional latent space")

    # create the value range for each latent axis
    z1 = torch.linspace(low, high, steps)
    z2 = torch.linspace(low, high, steps)
    num_z1 = z1.shape[0]
    num_z2 = z2.shape[0]
    num_z = num_z1 * num_z2

    # create all possible combinations; row i * num_z2 + j holds (z1[i], z2[j])
    sample = torch.cartesian_prod(z1, z2).reshape(num_z, 2).to(device)

    # decode and plot them
    with torch.no_grad():
        sample = model.decode(sample).view(num_z, 1, 28, 28).cpu()
    plot_multiple_images(num_z, sample)


In [None]:
# Visualise the latent-space traversal for the VAE.
plot_along_axis(VAE)

# Clustering the VAE Latent Space

This function analyzes the structure of the **latent space** in a Variational Autoencoder (VAE) by applying **K-Means clustering** and visualizing the results.

Extract the latent representations of the test dataset. Use KMeans to cluster them and plot the result in 2D.

Plot one image from each cluster to see if they really represent labels from real dataset.

In [None]:

from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
import seaborn as sns

def cluster_latent_space(vae, dataloader, n_clusters, device):
    """Cluster the latent codes of a dataset with K-Means and plot them in 2-D.

    Each image is represented by the posterior mean returned by `vae.encode`.
    Using the mean instead of a random sample keeps the embedding
    deterministic, so the fixed `random_state` actually makes the clustering
    reproducible.

    Args:
        vae: Model exposing `encode(x) -> (mu, logvar)`; switched to eval mode.
        dataloader: Iterable of `(images, labels)` batches.
        n_clusters: Number of K-Means clusters.
        device: Device the image batches are moved to before encoding.

    Returns:
        Tuple `(latents, clusters, labels)` of NumPy arrays: the latent means,
        the cluster id of every sample, and the true labels.
    """
    vae.eval()
    latents = []
    labels = []

    # iterate through images
    with torch.no_grad():
        for img, label in dataloader:
            img = img.to(device).view(-1, 28*28)
            mu, _ = vae.encode(img)
            latents.append(mu.cpu())
            labels.append(label)

    latents = torch.cat(latents, dim=0).numpy()
    labels = torch.cat(labels, dim=0).numpy()

    # KMeans clustering
    kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    clusters = kmeans.fit_predict(latents)

    # A 2-D latent space can be plotted as is; t-SNE is only needed to
    # project other dimensionalities down to two.
    if latents.shape[1] == 2:
        latents_2d = latents
    else:
        tsne = TSNE(n_components=2, random_state=42)
        latents_2d = tsne.fit_transform(latents)

    plt.figure(figsize=(8, 6))
    sns.scatterplot(x=latents_2d[:, 0], y=latents_2d[:, 1], hue=clusters, palette='tab10', legend=False)
    plt.title('Latent space clustering')
    plt.show()

    return latents, clusters, labels


From each cluster, generate an image. Ideally, we expect to get 10 different images from 10 classes.

In [None]:

# Generate one sample per cluster center (approximate)
latents, clusters, labels = cluster_latent_space(VAE, test_loader, n_clusters=10, device=device)
cluster_images = []
cluster_labels = []
# Iterate over the cluster ids that actually occur, so an empty cluster can
# never produce a NaN mean.
for c in np.unique(clusters):
    members = clusters == c
    # take mean latent of cluster c
    z_c = torch.tensor(latents[members].mean(axis=0), device=device).unsqueeze(0)
    with torch.no_grad():
        img = VAE.decode(z_c).view(28, 28).cpu()
    cluster_images.append(img)
    # A cluster id is not a class id: title each image with the most frequent
    # true label among the cluster's members instead.
    cluster_labels.append(int(np.bincount(labels[members]).argmax()))

# Plot cluster representatives
for img, majority_label in zip(cluster_images, cluster_labels):
    plot_image_with_label(img, majority_label)


# Adding a Classifier to VAE

This code extends the **Variational Autoencoder (VAE)** by adding a **classification head**. This allows the model to **predict class labels** from the latent space.

Add a linear layer that takes the encoding of an image as input and outputs the image's class logits. Change the forward function so that it also returns the class logits. You can also add a `classify` function that returns just the class logits.

In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

class VAEWithClassifier(VariationalAE):
    """VAE with a linear classification head on top of the latent mean.

    Args:
        num_classes: Number of output classes of the head.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.classifier = nn.Linear(self.latent_dim, num_classes)

    def classify(self, x):
        """Return only the class logits for `x`.

        Skips latent sampling and decoding, so the result is deterministic.
        """
        mu, _ = self.encode(x)
        return self.classifier(mu)

    def forward(self, x):
        """Return `(reconstruction, mu, logvar, logits)` for `x`.

        The logits are computed from the posterior mean `mu`, not from the
        sampled latent, so they do not depend on the sampling noise.
        """
        recon, mu, logvar = super().forward(x)
        logits = self.classifier(mu)
        return recon, mu, logvar, logits


Load the pretrained vae weights, then use CrossEntropyLoss to train the model for a few epochs.



In [None]:

def fine_tune_vae(vae_model, train_loader, test_loader, num_classes, num_epochs=4, learning_rate=1e-3, device=device):
    """Build a `VAEWithClassifier` from a pre-trained VAE and fine-tune it.

    The optimised loss is reconstruction BCE + KL divergence + classification
    cross-entropy. All three terms are summed over the batch so that they
    share one scale and the printed value is a true per-sample average.

    Args:
        vae_model: Pre-trained VAE whose weights initialise the new model.
        train_loader: Loader yielding `(images, labels)` training batches.
        test_loader: Accepted for interface compatibility; not used here.
        num_classes: Number of classes of the classification head.
        num_epochs: Number of passes over `train_loader`.
        learning_rate: Adam learning rate.
        device: Device the model and batches are placed on.

    Returns:
        The fine-tuned `VAEWithClassifier`, left in training mode.
    """
    # Add the classification head to the VAE
    model = VAEWithClassifier(num_classes).to(device)

    # Load the pre-trained VAE weights; strict=False because the classifier
    # head has no counterpart in the pre-trained state dict.
    model.load_state_dict(vae_model.state_dict(), strict=False)

    # Sum-reduced like the reconstruction and KL terms below, so the
    # classification term is not divided by the batch size while the others
    # are not.
    cls_criterion = nn.CrossEntropyLoss(reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        for data, labels in train_loader:
            data = data.to(device).view(-1, 28*28)
            labels = labels.to(device)

            optimizer.zero_grad()
            recon, mu, logvar, logits = model(data)
            recon_loss = F.binary_cross_entropy(recon, data, reduction='sum')
            kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            cls_loss = cls_criterion(logits, labels)
            loss = recon_loss + kl_loss + cls_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}: loss {total_loss/len(train_loader.dataset):.4f}")

    return model

# Attach a 10-class head to the trained VAE and fine-tune it for 4 epochs on the current loaders.
fine_tuned_model = fine_tune_vae(VAE, train_loader, test_loader, num_classes=10, num_epochs=4)


Evaluate fine tuned model on test data.

In [None]:

# Classification accuracy of the fine-tuned model over the whole test loader.
fine_tuned_model.eval()
correct, total = 0, 0
with torch.no_grad():
    for data, labels in test_loader:
        data = data.to(device).view(-1, 28*28)
        labels = labels.to(device)
        # forward returns (recon, mu, logvar, logits); only the logits are needed.
        logits = fine_tuned_model(data)[3]
        hits = logits.argmax(dim=1) == labels
        correct += hits.sum().item()
        total += labels.size(0)

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")


### Adversarial Examples in Variational Autoencoders

In this experiment, we explore how a small, intentional perturbation to the input image can significantly alter the VAE's output, even though the input still looks visually similar.

We use the **Fast Gradient Sign Method (FGSM)** to generate an *adversarial image* that "fools" the VAE.

#### Steps:

1. Take a test image.
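2. Compute the VAE's reconstruction loss and backpropagate to get gradients w.r.t. the input.
3. Generate a perturbed image: $x_{adv} = x + \epsilon \cdot \mathrm{sign}(\nabla_x \mathcal{L})$, with a small step size $\epsilon$ (0.1 in the code below).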
4. Clamp pixel values to keep them in the valid range \([0, 1]\).
5. Feed both original and adversarial images through the VAE.
6. Compare:
   - Original vs Reconstruction
   - Adversarial vs Reconstruction

In [None]:

# Get an input image (and its true label, for the plot titles) from test_loader
x, y = next(iter(test_loader))
x = x[0:1].to(device)
x_label = int(y[0])
x.requires_grad_(True)

# Forward pass
recon, mu, logvar = VAE(x.view(-1, 28*28))

# Compute reconstruction loss and its gradient w.r.t. the input only;
# autograd.grad leaves the .grad fields of the VAE parameters untouched.
loss = F.binary_cross_entropy(recon, x.view(-1, 28*28), reduction='sum')
grad_x = torch.autograd.grad(loss, x)[0]

# Generate adversarial example (FGSM)
eta = 0.1 * grad_x.sign()
x_adv = torch.clamp(x.detach() + eta, 0, 1)  # Ensure pixel range is valid

# Reconstruct adversarial
with torch.no_grad():
    recon_adv, _, _ = VAE(x_adv.view(-1, 28*28))

# Original vs its reconstruction, then adversarial vs its reconstruction.
# Tensors that are part of the autograd graph are detached before plotting,
# because converting them to NumPy would otherwise raise.
plot_image_with_label(x.detach().squeeze(0).cpu(), x_label)
plot_image_with_label(recon.detach().view(28, 28).cpu(), x_label)
plot_image_with_label(x_adv.squeeze(0).cpu(), x_label)
plot_image_with_label(recon_adv.view(28, 28).cpu(), x_label)


### Fine-Tuning VAE for MNIST Digit Classification

In this experiment, we explore **cross-domain representation learning** by fine-tuning a **VAE originally trained on Fashion-MNIST** to classify digits from the **MNIST** dataset.

Although the VAE never saw digits during training, it learned general-purpose visual features (edges, strokes, textures). We now test whether these features can be **repurposed** to classify a completely different kind of visual data.

#### Steps:
1. The **Fashion-MNIST-trained VAE encoder** is reused as a feature extractor.
2. A **classifier head** is added to the encoder.
3. The model is fine-tuned on MNIST digit labels.
4. Only a few epochs (here, one) are used, to test how quickly the model can adapt.

Calculate test accuracy on MNIST after fine-tuning.

In [None]:
from torch.utils.data import DataLoader

# NOTE: this cell rebinds `transform`, `train_loader` and `test_loader`; from here on
# the loaders yield MNIST digits instead of Fashion-MNIST items.
transform = transforms.ToTensor()

# Training split: downloaded if missing, served in shuffled batches of 128.
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Test split: served in shuffled batches of 16.
test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.ToTensor())
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)

In [None]:
# Fine-tune the VAE weights plus a new 10-class head for a single epoch on the current train_loader.
num_classes = 10
fine_tuned_model = fine_tune_vae(VAE, train_loader, test_loader, num_classes, num_epochs=1)

In [None]:
# Classification accuracy of the fine-tuned model over the whole test loader.
fine_tuned_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, labels in test_loader:
        # Flatten each image to one row per sample and move the batch to the model's device.
        data, labels = data.to(device).view(data.shape[0], -1), labels.to(device)
        # The model's forward returns (recon, mu, logvar, logits); using the
        # whole tuple as if it were the logits raises AttributeError.
        _, _, _, class_logits = fine_tuned_model(data)
        predicted = class_logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")