<a href="https://colab.research.google.com/github/paulsubarna/FedWit/blob/main/vae_example_tutorial_curated_annotated.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Tutorial 3: Variational AutoEncoder

In this tutorial, we will learn how to model and train our first generative latent variable model: the Variational Autoencoder (VAE). It connects the pieces covered in the previous lectures. Happy generating!


## Imports and Setup


In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.datasets import load_digits
from sklearn import datasets
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F

# May need installing first: pip install pytorch-model-summary
from pytorch_model_summary import summary

## Dataset


This cell defines a PyTorch `Dataset` class for the scikit-learn Digits dataset, which contains 1797 images $x \in \mathbb{R}^{8 \times 8}$ (stored as flattened 64-dimensional vectors) with pixel values in $\{0, 1, ..., 16\}$. The dataset is partitioned into train/val/test splits of 1000/350/447 samples. This setup prepares the data for easy batching while keeping the tutorial lightweight and runnable on CPUs.

In [None]:
class Digits(Dataset):
    """Scikit-Learn Digits dataset."""

    def __init__(self, mode='train', transforms=None):
        digits = load_digits()
        if mode == 'train':
            self.data = digits.data[:1000].astype(np.float32)
        elif mode == 'val':
            self.data = digits.data[1000:1350].astype(np.float32)
        else:
            self.data = digits.data[1350:].astype(np.float32)

        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        if self.transforms:
            sample = self.transforms(sample)
        return sample

**Probability Distributions**

This cell implements log-likelihood functions for categorical, Bernoulli, and diagonal normal distributions, used for the VAE's likelihood and prior terms:
- *Categorical log-likelihood*: $\log p(x) = \sum_i x_i \log p_i$
- *Bernoulli log-likelihood*: $\log p(x) = x \log p + (1-x)\log(1-p)$
- *Diagonal Normal*: $\log \mathcal{N}(x;\mu,\sigma^2) = -0.5 D \log(2\pi) - 0.5\sum_j \log \sigma_j^2 - 0.5 \sum_j \frac{(x_j-\mu_j)^2}{\sigma_j^2}$

The cell enables computation of log-probabilities essential for the VAE's loss.

In [None]:
PI = torch.from_numpy(np.asarray(np.pi))
EPS = 1.e-5

def log_categorical(x, p, num_classes=256, reduction=None, dim=None):

        return log_p

def log_bernoulli(x, p, reduction=None, dim=None):

        return log_p

def log_normal_diag(x, mu, log_var, reduction=None, dim=None):

        return log_p


def log_standard_normal(x, reduction=None, dim=None):

        return log_p
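One possible way to fill in the four stubs above is sketched below. It follows the formulas listed in this section directly; the `_reduce` helper and the clamping of probabilities with `EPS` are assumptions made for this sketch, not something the notebook prescribes.

```python
import math

import torch
import torch.nn.functional as F

EPS = 1.0e-5  # clamp bound that keeps log() away from 0 and 1


def _reduce(log_p, reduction=None, dim=None):
    """Hypothetical helper: optionally average or sum the element-wise terms."""
    if reduction == 'avg':
        return log_p.mean() if dim is None else log_p.mean(dim)
    if reduction == 'sum':
        return log_p.sum() if dim is None else log_p.sum(dim)
    return log_p


def log_categorical(x, p, num_classes=256, reduction=None, dim=None):
    # x holds integer class indices; p holds probabilities over the last axis.
    x_one_hot = F.one_hot(x.long(), num_classes=num_classes)
    log_p = x_one_hot * torch.log(torch.clamp(p, EPS, 1.0 - EPS))
    return _reduce(log_p, reduction, dim)


def log_bernoulli(x, p, reduction=None, dim=None):
    pp = torch.clamp(p, EPS, 1.0 - EPS)
    log_p = x * torch.log(pp) + (1.0 - x) * torch.log(1.0 - pp)
    return _reduce(log_p, reduction, dim)


def log_normal_diag(x, mu, log_var, reduction=None, dim=None):
    # Element-wise terms; summing over the feature axis gives the formula above.
    log_p = (-0.5 * math.log(2.0 * math.pi) - 0.5 * log_var
             - 0.5 * torch.exp(-log_var) * (x - mu) ** 2)
    return _reduce(log_p, reduction, dim)


def log_standard_normal(x, reduction=None, dim=None):
    log_p = -0.5 * math.log(2.0 * math.pi) - 0.5 * x ** 2
    return _reduce(log_p, reduction, dim)
```

Passing `reduction='sum'` with `dim=-1` returns one log-probability per sample, which is the form the ELBO needs.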

## VAE Encoder

**Encoder Module**

Defines the encoder module $q_\phi(z|x) = \mathcal{N}(z;\mu(x),\sigma^2(x)I)$, outputting parameters $\mu(x)$ and $\log \sigma^2(x)$ for each input. The *reparameterization trick* computes $z = \mu + \sigma \odot \epsilon$ for $\epsilon \sim \mathcal{N}(0,I)$, which makes VAE optimization via stochastic gradient descent possible.

In [None]:
class Encoder(nn.Module):
    def __init__(self, encoder_net):
        super(Encoder, self).__init__()

        self.encoder = encoder_net

    @staticmethod
    def reparameterization(mu, log_var):


    def encode(self, x):


        return

    def sample(self, x=None, mu_e=None, log_var_e=None):

        return z

    def log_prob(self, x=None, mu_e=None, log_var_e=None, z=None):

        return

    def forward(self, x, type='log_prob'):

        return
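The reparameterization trick described in this section can be sketched as follows. The function names and the convention that the encoder network emits `2 * L` numbers (first half $\mu$, second half $\log\sigma^2$) are assumptions made for illustration; the random tensor `h` stands in for the output of `encoder_net(x)`.

```python
import torch


def reparameterize(mu, log_var):
    """Draw z = mu + sigma * eps with eps ~ N(0, I)."""
    std = torch.exp(0.5 * log_var)   # sigma = exp(0.5 * log sigma^2)
    eps = torch.randn_like(std)      # noise does not depend on the parameters
    return mu + std * eps


def split_encoder_output(h):
    """Assumed convention: the network emits 2*L numbers, [mu, log_var]."""
    mu, log_var = torch.chunk(h, 2, dim=1)
    return mu, log_var


torch.manual_seed(0)
h = torch.randn(4, 32, requires_grad=True)   # stand-in for encoder_net(x), L = 16
mu, log_var = split_encoder_output(h)
z = reparameterize(mu, log_var)
z.sum().backward()                           # gradients reach h through mu and sigma
```

Because the randomness lives entirely in `eps`, the sample `z` is a differentiable function of `mu` and `log_var`, which is what lets gradients reach the encoder.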

## Decoder Module

This cell sets up the VAE decoder $p_\theta(x|z)$:
- For categorical likelihood (image pixel classes), outputs are reshaped and a softmax produces pixel probabilities.
- For Bernoulli likelihood, applies a sigmoid.

Given a latent vector $z$, the decoder produces parameters for the likelihood $p_\theta(x|z)$, reconstructing images from codes.

In [None]:
class Decoder(nn.Module):
    def __init__(self, decoder_net, distribution='categorical', num_vals=None):
        super(Decoder, self).__init__()

        self.decoder = decoder_net
        self.distribution = distribution
        self.num_vals=num_vals

    def decode(self, z):
        if self.distribution == 'categorical':

            return

        elif self.distribution == 'bernoulli':

            return

        else:
            raise ValueError('Either `categorical` or `bernoulli`')

    def sample(self, z):


        return x_new

    def log_prob(self, x, z):

        return log_p

    def forward(self, z, x=None, type='log_prob'):
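
For the categorical branch, the reshape, softmax, and sampling steps described above might look like the sketch below. The helper names are hypothetical, and the random tensor `h` stands in for the output of `decoder_net(z)`.

```python
import torch

D, num_vals = 64, 17                 # values used later in this tutorial


def decode_categorical(h, num_vals):
    """Reshape logits to (batch, D, num_vals) and normalise over pixel values."""
    b = h.shape[0]
    d = h.shape[1] // num_vals
    return torch.softmax(h.view(b, d, num_vals), dim=2)


def sample_categorical(probs):
    """Draw one pixel value per position from the per-pixel distributions."""
    b, d, k = probs.shape
    flat = probs.reshape(b * d, k)
    return torch.multinomial(flat, num_samples=1).view(b, d)


torch.manual_seed(0)
h = torch.randn(2, D * num_vals)     # stand-in for decoder_net(z)
probs = decode_categorical(h, num_vals)
x_new = sample_categorical(probs)    # integer pixel values in {0, ..., 16}
```

For the Bernoulli branch, the analogous step would be `torch.sigmoid(h)` followed by `torch.bernoulli`.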


**Prior Module**

Implements the VAE prior $p(z) = \mathcal{N}(0, I)$:
- Sampling: $z \sim \mathcal{N}(0, I)$
- Log-probability: Computes the log-density of $z$ under the standard normal.

Intuition: Regularizes the latent space, forcing it to match a simple, tractable prior distribution.

In [None]:
class Prior(nn.Module):
    def __init__(self, L):
        super(Prior, self).__init__()
        self.L = L

    def sample(self, batch_size):

        return z

    def log_prob(self, z):
        return
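
A minimal sketch of the two prior operations, written as free functions rather than methods. Summing the log-density over the latent dimensions is an assumption of this sketch; the `Prior` class above may instead be meant to return element-wise values.

```python
import math

import torch


def prior_sample(batch_size, L):
    """Draw z ~ N(0, I) with shape (batch_size, L)."""
    return torch.randn(batch_size, L)


def prior_log_prob(z):
    """Log-density of z under N(0, I), summed over the latent dimensions."""
    return (-0.5 * math.log(2.0 * math.pi) - 0.5 * z ** 2).sum(dim=-1)


torch.manual_seed(0)
z = prior_sample(5, 16)
log_p = prior_log_prob(z)   # one value per sample
```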

## Full VAE, Loss, and Sampling

Defines the complete VAE as a composition of the encoder, decoder, and prior modules.

**Forward pass**:
- Encode $x$ to $\mu(x), \log\sigma^2(x)$
- Sample $z$ using the reparameterization trick
- Decode $z$ to the parameters of $p(x|z)$ and evaluate the reconstruction term $\log p(x|z)$ under the chosen likelihood (categorical or Bernoulli)
- Compute the Evidence Lower Bound (ELBO):
$$
       \mathcal{L}_{\text{VAE}} = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \mathrm{KL}(q(z|x)||p(z))
$$

The KL term is between two diagonal Gaussians, $q(z|x)$ and $p(z)$, and the reconstruction term comes from the chosen likelihood. Training minimizes the negative ELBO.
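
For reference, when $q(z|x) = \mathcal{N}(z;\mu(x),\sigma^2(x)I)$ and $p(z) = \mathcal{N}(0, I)$, the KL term has a standard closed form (a textbook identity; the notebook's code may instead estimate it by sampling, as $\log q(z|x) - \log p(z)$):

$$
       \mathrm{KL}(q(z|x)||p(z)) = \frac{1}{2}\sum_{j=1}^{L}\left(\mu_j^2(x) + \sigma_j^2(x) - \log\sigma_j^2(x) - 1\right)
$$

Each summand is zero exactly when $\mu_j(x) = 0$ and $\sigma_j^2(x) = 1$, so the term vanishes only when the posterior equals the prior.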

### Sampling
This cell draws $z\sim\mathcal{N}(0,I)$ and decodes to generate new samples $\hat x\sim p_\theta(x\mid z)$.

In [None]:
class VAE(nn.Module):
    def __init__(self, encoder_net, decoder_net, num_vals=256, L=16, likelihood_type='categorical'):
        super(VAE, self).__init__()

        print('VAE by JT.')

        self.encoder = Encoder(encoder_net=encoder_net)
        self.decoder = Decoder(distribution=likelihood_type, decoder_net=decoder_net, num_vals=num_vals)
        self.prior = Prior(L=L)

        self.num_vals = num_vals

        self.likelihood_type = likelihood_type

    def forward(self, x, reduction='avg'):


    def sample(self, batch_size=64):

        return
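
The forward pass listed above can be sketched as a single function that returns the negative ELBO for a categorical likelihood. This is an illustrative sketch under stated assumptions, not the notebook's reference solution: the small stand-in networks, the `negative_elbo` name, and the use of the closed-form Gaussian KL (rather than a sampled estimate) are all choices made here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

D, L, M, num_vals = 64, 16, 32, 17    # M is shrunk here to keep the demo small

# Hypothetical stand-in networks (the tutorial leaves these for you to define).
encoder_net = nn.Sequential(nn.Linear(D, M), nn.LeakyReLU(), nn.Linear(M, 2 * L))
decoder_net = nn.Sequential(nn.Linear(L, M), nn.LeakyReLU(), nn.Linear(M, D * num_vals))


def negative_elbo(x, encoder_net, decoder_net, num_vals):
    """Per-batch average of ReconLoss + KL for a categorical likelihood."""
    # 1) encode x to mu(x) and log sigma^2(x)
    mu, log_var = torch.chunk(encoder_net(x), 2, dim=1)
    # 2) sample z with the reparameterization trick
    z = mu + torch.exp(0.5 * log_var) * torch.randn_like(mu)
    # 3) decode z to per-pixel logits over the num_vals pixel values
    logits = decoder_net(z).view(x.shape[0], -1, num_vals)
    # 4) reconstruction term: -log p(x|z), summed over pixels
    log_probs = F.log_softmax(logits, dim=2)
    recon = -log_probs.gather(2, x.long().unsqueeze(2)).squeeze(2).sum(dim=1)
    # 5) closed-form KL(q(z|x) || N(0, I)), summed over latent dimensions
    kl = 0.5 * (mu ** 2 + log_var.exp() - log_var - 1.0).sum(dim=1)
    return (recon + kl).mean(), recon.mean(), kl.mean()


torch.manual_seed(0)
x = torch.randint(0, num_vals, (8, D)).float()   # fake batch of 8x8 "images"
loss, recon, kl = negative_elbo(x, encoder_net, decoder_net, num_vals)
loss.backward()                                  # gradients flow to both networks
```

Returning the two terms separately makes it easy to monitor whether the reconstruction or the KL part dominates during training.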

## Auxiliary Functions: Training, Evaluation, Plotting

Implements utility functions for model training, validation/testing, and result visualization:
- Computes and tracks loss (negative ELBO)
- Saves models and generated samples
- Plots the training curve, outputs images

These tools facilitate model development and result inspection.

In [None]:
def evaluation(test_loader, name=None, model_best=None, epoch=None):
    # EVALUATION

    return loss


def samples_real(name, test_loader):
    # REAL-------
    num_x = 4
    num_y = 4
    x = next(iter(test_loader)).detach().numpy()

    fig, axes = plt.subplots(num_x, num_y)
    for i, ax in enumerate(axes.flatten()):
        plottable_image = np.reshape(x[i], (8, 8))
        ax.imshow(plottable_image, cmap='gray')
        ax.axis('off')

    plt.savefig(name+'_real_images.pdf', bbox_inches='tight')
    plt.close()


def samples_generated(name, data_loader, extra_name=''):
    # GENERATIONS-------
    # The file holds a whole pickled model object, so full unpickling is needed
    # (recent PyTorch versions default to weights_only=True).
    model_best = torch.load(name + '.model', weights_only=False)
    model_best.eval()

    num_x = 4
    num_y = 4
    x = model_best.sample(num_x * num_y)
    x = x.detach().numpy()

    fig, axes = plt.subplots(num_x, num_y)
    for i, ax in enumerate(axes.flatten()):
        plottable_image = np.reshape(x[i], (8, 8))
        ax.imshow(plottable_image, cmap='gray')
        ax.axis('off')

    plt.savefig(name + '_generated_images' + extra_name + '.pdf', bbox_inches='tight')
    plt.close()


def plot_curve(name, nll_val):
    plt.plot(np.arange(len(nll_val)), nll_val, linewidth=3)
    plt.xlabel('epochs')
    plt.ylabel('nll')
    plt.savefig(name + '_nll_val_curve.pdf', bbox_inches='tight')
    plt.close()

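### Optimization setup
Training minimizes the negative ELBO, $-\mathcal{L}_{\text{VAE}}=\text{ReconLoss}+\mathrm{KL}$, where $\text{ReconLoss}=-\mathbb{E}_{q(z|x)}[\log p(x|z)]$ and $\mathrm{KL}=\mathrm{KL}(q(z|x)||p(z))$. The optimizer is passed to the training function as an argument.

### Training loop
For each batch: encode $x\to(\mu,\log\sigma^2)$, sample $z$, decode to $\hat x$, compute the loss, backpropagate, and update the parameters. After each epoch the validation loss is recorded, and training stops early once it has failed to improve for more than `max_patience` consecutive epochs.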

In [None]:
def training(name, max_patience, num_epochs, model, optimizer, training_loader, val_loader):
    nll_val = []
    best_nll = 1000.
    patience = 0

    # Main loop


        # Validation


        if patience > max_patience:
            break

    nll_val = np.asarray(nll_val)

    return nll_val
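
The loop described above, with early stopping, could be sketched as below. Several things here are assumptions made for illustration: the function is named `train_sketch`, it expects `model(batch)` to return a scalar loss, and it keeps the best weights in memory instead of saving them to disk. `ToyLoss` is a hypothetical stand-in for the VAE so that the example runs on its own.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def train_sketch(max_patience, num_epochs, model, optimizer, training_loader, val_loader):
    """Assumes model(batch) returns a scalar loss (for the VAE: the negative ELBO)."""
    nll_val, best_nll, best_state, patience = [], float('inf'), None, 0
    for epoch in range(num_epochs):
        model.train()
        for batch in training_loader:
            loss = model(batch)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validation: average the loss over all validation samples
        model.eval()
        total, n = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                total += model(batch).item() * batch.shape[0]
                n += batch.shape[0]
        loss_val = total / n
        nll_val.append(loss_val)

        if loss_val < best_nll:          # improvement: keep weights, reset patience
            best_nll, patience = loss_val, 0
            best_state = copy.deepcopy(model.state_dict())
        else:
            patience += 1
        if patience > max_patience:      # early stopping
            break
    return nll_val, best_state


class ToyLoss(nn.Module):
    """Hypothetical stand-in for the VAE: forward returns a scalar loss."""

    def __init__(self):
        super().__init__()
        self.mean = nn.Parameter(torch.zeros(4))

    def forward(self, x):
        return ((x - self.mean) ** 2).sum(dim=1).mean()


torch.manual_seed(0)
data = torch.randn(256, 4) + 3.0
train_loader = DataLoader(data[:192], batch_size=32, shuffle=True)
val_loader = DataLoader(data[192:], batch_size=32)
toy = ToyLoss()
opt = torch.optim.Adamax(toy.parameters(), lr=0.1)
history, best_state = train_sketch(5, 50, toy, opt, train_loader, val_loader)
```

Keeping a copy of the best `state_dict` means the weights from the best validation epoch can be restored with `model.load_state_dict(best_state)` after the loop ends.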

## Initialize dataloaders

**Datasets and loaders**

Creates the train/val/test splits and wraps each in a `DataLoader` with batch size 64; only the training loader shuffles. The cell also creates the `results/` directory and sets the file-name prefix used for saved models and plots.

In [None]:
train_data = Digits(mode='train')
val_data = Digits(mode='val')
test_data = Digits(mode='test')

training_loader = DataLoader(train_data, batch_size=64, shuffle=True)
val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

result_dir = 'results/'
os.makedirs(result_dir, exist_ok=True)
name = 'vae'

### Hyperparams

**Hyperparameters**

Specifies the model and training hyperparameters:
- Input dimension $D$, latent dimension $L$, network size $M$
- Learning rate, maximum number of epochs, early-stopping patience

These control model structure, training speed, and when training stops.

In [None]:
D = 64   # input dimension
L = 16  # number of latents
M = 256  # network size (number of hidden neurons)

lr = 1e-3 # learning rate
num_epochs = 1000 # max. number of epochs
max_patience = 20 # early stopping: training stops once validation has not improved for more than 20 epochs

### Initialize VAE

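**Instantiate Model & Summaries**

Chooses the likelihood type, configures the encoder and decoder networks, and creates the VAE object. With the categorical likelihood, `num_vals = 17` because pixel values lie in $\{0, 1, ..., 16\}$. The model summaries help verify shapes and parameter counts.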

In [None]:
likelihood_type = 'categorical'

if likelihood_type == 'categorical':
    num_vals = 17
elif likelihood_type == 'bernoulli':
    num_vals = 1

encoder =
decoder =

prior =
model = VAE(encoder_net=encoder, decoder_net=decoder, num_vals=num_vals, L=L, likelihood_type=likelihood_type)

# Print the summary (like in Keras)
print("ENCODER:\n", summary(encoder, torch.zeros(1, D), show_input=False, show_hierarchical=False))
print("\nDECODER:\n", summary(decoder, torch.zeros(1, L), show_input=False, show_hierarchical=False))
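
The `encoder` and `decoder` assignments above are left open. One plausible choice, sketched here under the assumption of plain MLPs with `LeakyReLU` activations, is shown below. The output sizes follow from the rest of the notebook: the encoder emits `2 * L` values ($\mu$ and $\log\sigma^2$), and a categorical decoder emits `num_vals` logits per pixel.

```python
import torch
import torch.nn as nn

D, L, M, num_vals = 64, 16, 256, 17

# Assumed architecture: two hidden layers of width M in each network.
encoder = nn.Sequential(nn.Linear(D, M), nn.LeakyReLU(),
                        nn.Linear(M, M), nn.LeakyReLU(),
                        nn.Linear(M, 2 * L))
decoder = nn.Sequential(nn.Linear(L, M), nn.LeakyReLU(),
                        nn.Linear(M, M), nn.LeakyReLU(),
                        nn.Linear(M, num_vals * D))

enc_out = encoder(torch.zeros(1, D))   # shape (1, 2 * L)
dec_out = decoder(torch.zeros(1, L))   # shape (1, num_vals * D)
```

The dummy inputs match those passed to `summary` above, so the same shapes should appear in the printed summaries.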

### Let's play! Training

**Optimizer Setup**

Configures the optimizer (Adamax) and sets the learning rate for updating the model parameters.

In [None]:
# OPTIMIZER
optimizer = torch.optim.Adamax([p for p in model.parameters() if p.requires_grad == True], lr=lr)

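**Training Procedure**

Runs the training loop with early stopping. Saves the weights whenever the validation loss improves and visualizes progress.

**How to verify it worked:** the cell runs without errors and the validation loss decreases over the first epochs.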

In [None]:
# Training procedure
nll_val = training(name=result_dir + name, max_patience=max_patience, num_epochs=num_epochs, model=model, optimizer=optimizer,
                       training_loader=training_loader, val_loader=val_loader)


**Testing and Visualization**

Evaluates the best model on the test set, writes the test loss to a text file, and saves real test images and the validation learning curve.

In [None]:
test_loss = evaluation(name=result_dir + name, test_loader=test_loader)
with open(result_dir + name + '_test_loss.txt', "w") as f:
    f.write(str(test_loss))

samples_real(result_dir + name, test_loader)

plot_curve(result_dir + name, nll_val)

## Visualization TO-DO

### **Objective**
Analyze and interpret the geometry of the learned latent space.

### **Instructions**
1. Select a dataset sample (images or other input) and encode it into the latent space using the VAE encoder.
2. Apply a dimensionality reduction technique if needed (e.g., PCA or t-SNE) to project latent vectors to 2D.
3. Visualize the resulting 2D latent representations.
4. If labels are available, color points by class.
5. Train the model with different numbers of latent variables and compare the resulting latent spaces.
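
As a starting point for steps 2 to 4, the sketch below projects latent vectors to 2D with PCA (computed via SVD in NumPy) and colours the points by class. The function names are hypothetical and the random arrays are placeholders: in practice `fake_mu` would be the encoder means $\mu(x)$ and `fake_labels` the digit labels. Note that the `Digits` class above discards labels; they are available as `load_digits().target`, in the same order as `digits.data`.

```python
import numpy as np


def pca_2d(latents):
    """Project (N, L) latent vectors onto their top two principal components."""
    centered = latents - latents.mean(axis=0, keepdims=True)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:2].T


def plot_latents(points_2d, labels, path='latent_space.pdf'):
    """Scatter plot coloured by class; matplotlib is imported lazily."""
    import matplotlib.pyplot as plt
    plt.scatter(points_2d[:, 0], points_2d[:, 1], c=labels, cmap='tab10', s=8)
    plt.colorbar(label='digit')
    plt.savefig(path, bbox_inches='tight')
    plt.close()


rng = np.random.default_rng(0)
fake_mu = rng.normal(size=(100, 16))      # placeholder for encoder means mu(x)
fake_labels = rng.integers(0, 10, 100)    # placeholder for digit labels
points = pca_2d(fake_mu)
# plot_latents(points, fake_labels)       # uncomment to save the figure
```

PCA is linear and deterministic, which makes it a reasonable first look; t-SNE can reveal cluster structure that PCA misses, at the cost of distorting global distances.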




