In [1]:

from typing import *
import importlib
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import Image, display, clear_output
import numpy as np
import seaborn as sns
import pandas as pd
from torch import nn, Tensor
from torch.distributions import Distribution, Dirichlet as TorchDirichlet
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from functools import reduce
from torch.distributions import Bernoulli
from plotting import make_vae_plots
import plotting
importlib.reload(plotting)
import math 
import torch
from torch.nn.functional import softplus
from collections import defaultdict
import torch.nn.functional as F

In [2]:
# Define the train and test sets
dset_train = MNIST("./", train=True,  transform=ToTensor(), download=True)
dset_test  = MNIST("./", train=False, transform=ToTensor())

# The digit classes to use
classes = [3, 7]

def stratified_sampler(labels):
    """Sampler that only picks datapoints corresponding to the specified classes"""
    (indices,) = np.where(reduce(lambda x, y: x | y, [labels.numpy() == i for i in classes]))
    indices = torch.from_numpy(indices)
    return SubsetRandomSampler(indices)


In [3]:

batch_size = 64
eval_batch_size = 100
# The loaders perform the actual work
train_loader = DataLoader(dset_train, batch_size=batch_size,
                          sampler=stratified_sampler(dset_train.train_labels))
test_loader  = DataLoader(dset_test, batch_size=eval_batch_size, 
                          sampler=stratified_sampler(dset_test.test_labels))
sns.set_style("whitegrid")

images, labels = next(iter(train_loader))

def reduce(x: Tensor) -> Tensor:
    """Reduce only if tensor has more than one dimension."""
    return x if x.ndim == 1 else x.view(x.size(0), -1).sum(dim=-1)



In [4]:

class VariationalInference(nn.Module):
    def __init__(self, beta: float = 1.):
        super().__init__()
        self.beta = beta
        
    def forward(self, model: nn.Module, x: Tensor):
        # forward pass through the model
        outputs = model(x)
        
        # unpack outputs
        px, pz, qz, z = [outputs[k] for k in ["px", "pz", "qz", "z"]]
        
        # evaluate log probabilities
        log_px = reduce(px.log_prob(x))
        log_pz = reduce(pz.log_prob(z))
        log_qz = reduce(qz.log_prob(z))
        
        # compute the KL divergence term
        kl = log_qz - log_pz
        
        # ELBO: log p(x|z) - KL(q||p)
        elbo = log_px - kl
        
        # β-ELBO: log p(x|z) - β * KL(q||p)
        beta_elbo = log_px - self.beta * kl
        
        # loss = -E_q[Lβ]
        loss = -beta_elbo.mean()
        
        # diagnostics for monitoring
        with torch.no_grad():
            diagnostics = {'elbo': elbo, 'log_px': log_px, 'kl': kl}
            
        return loss, diagnostics, outputs


In [5]:

import torch
import torch.nn.functional as F
from torchmetrics.functional.regression import mean_squared_error, mean_absolute_error
from torchmetrics.functional.image.ssim import structural_similarity_index_measure
from models.convolutional_vae import VariationalAutoencoder
def compute_reconstruction_errors(x_true: torch.Tensor, x_recon: torch.Tensor):
    """
    Compute standard reconstruction error metrics using torchmetrics.

    Args:
        x_true: Ground-truth tensor of shape [B, C, H, W].
        x_recon: Reconstructed tensor of same shape.

    Returns:
        dict with keys: 'MSE', 'RMSE', 'MAE', 'PSNR', 'SSIM'
    """

    # Ensure same dtype and device
    x_true = x_true.to(x_recon.device, dtype=x_recon.dtype)

    # Ensure same shape
    assert x_true.shape == x_recon.shape, "x_true and x_recon must have the same shape"

    # Compute metrics
    mse = mean_squared_error(x_recon, x_true).item()
    rmse = mse ** 0.5
    mae = mean_absolute_error(x_recon, x_true).item()
    ssim = structural_similarity_index_measure(x_recon, x_true, data_range=1.0).item()

    return {
        "MSE": mse,
        "RMSE": rmse,
        "MAE": mae,
        "SSIM": ssim,
    }

def plot_reconstructions_with_metrics(vae: VariationalAutoencoder, test_loader: DataLoader, device: torch.device):
    with torch.no_grad():
        vae.eval()
        x, y = next(iter(test_loader))
        x = x.to(device)
        outputs = vae(x)

        # Plot reconstructions with their metrics
        num_images = 8
        fig, axes = plt.subplots(3, num_images, figsize=(num_images * 2, 6))
        plt.subplots_adjust(hspace=0.4)

        for i in range(num_images):
            # Compute metrics for each image
            diagnostics = compute_reconstruction_errors(
                x[i].unsqueeze(0), outputs['px'].mean[i].unsqueeze(0)
            )

            # --- Original ---
            axes[0, i].imshow(x[i].cpu().squeeze(), cmap='gray')
            axes[0, i].set_title(f"Image {i+1}", fontsize=10)
            axes[0, i].axis('off')

            # --- Reconstruction ---
            axes[1, i].imshow(outputs['px'].mean[i].cpu().squeeze(), cmap='gray')
            axes[1, i].axis('off')

            # --- Metrics ---
            metrics_text = "\n".join([f"{k}: {v:.4f}" for k, v in diagnostics.items()])
            axes[2, i].text(
                0.5, 0.5, metrics_text,
                color='black',
                fontsize=16,
                ha='center', va='center',
                family='monospace'
            )
            axes[2, i].axis('off')

        axes[0, 0].set_ylabel("Original", fontsize=12)
        axes[1, 0].set_ylabel("Reconstruction", fontsize=12)
        axes[2, 0].set_ylabel("Metrics", fontsize=12)
        plt.tight_layout()
        plt.show()



In [6]:
from collections import defaultdict
from utils import LatentType
from models.convolutional_vae import VariationalAutoencoder

# define the models, evaluator and optimizer
# VAE
latent_features = 2  # can always be changed
vae = VariationalAutoencoder(images[0].shape, latent_features, latent_type = LatentType.DIRICHLET)

# Evaluator: Variational Inference
beta = 1
vi = VariationalInference(beta=beta)

# The Adam optimizer works really well with VAEs.
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

# define dictionary to store the training curves
training_data = defaultdict(list)
validation_data = defaultdict(list)

epoch = 0



TODO:
- Pull and update code (create utils class)
- Update T-SNE diagram to work with x dimensional latent spaces
- Create one-hot-encoding reconstructions with latent variables like the dirichlet paper
- Save and loading of models
- Explore UMAP and MDS more
- Look more at greyscale transformations

In [None]:
num_epochs = 25
from plotting import make_vae_plots

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f">> Using device: {device}")

vae = vae.to(device)


# annealing parameters
tau_min = 0.5
anneal_rate = 0.0003
kl_warmup_epochs = 20  # number of epochs to reach full β
max_beta = 0.3
num_batches = len(train_loader)

while epoch < num_epochs:
    epoch += 1
    training_epoch_data = defaultdict(list)
    vae.train()

    # KL warmup: gradually increase β
    vi.beta = min(max_beta, epoch / kl_warmup_epochs)
    
    # training loop
    for batch_idx, (x, y) in  enumerate(train_loader, start=1):
        if batch_idx == num_batches:
              print("epoch:",epoch)
              print("elbo:",round(diagnostics["elbo"].mean().item()),2)
              print("log px:",round(diagnostics["log_px"].mean().item()),2)
              print("kl-divergence:",round(diagnostics["kl"].mean().item()),2)

        x = x.to(device)
        loss, diagnostics, outputs = vi(vae, x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        for k, v in diagnostics.items():
            training_epoch_data[k] += [v.mean().item()]

    for k, v in training_epoch_data.items():
        training_data[k] += [np.mean(training_epoch_data[k])]

    # validation
    with torch.no_grad():
        vae.eval()
        x, y = next(iter(test_loader))
        x = x.to(device)
        loss, diagnostics, outputs = vi(vae, x)
        for k, v in diagnostics.items():
            validation_data[k] += [v.mean().item()]

    # visualize
    # make_vae_plots(vae, x, y, outputs, training_data, validation_data)


>> Using device: cpu
epoch: 1
elbo: -161.8885040283203
log px: -159.9197998046875
kl-divergence: 1.9687070846557617
epoch: 2
elbo: -173.84388732910156
log px: -170.15965270996094
kl-divergence: 3.6842355728149414
epoch: 3
elbo: -160.43516540527344
log px: -158.02325439453125
kl-divergence: 2.411924362182617
epoch: 4
elbo: -168.97825622558594
log px: -166.19334411621094
kl-divergence: 2.784891128540039
epoch: 5
elbo: -166.98187255859375
log px: -164.47889709472656
kl-divergence: 2.5029830932617188
epoch: 6
elbo: -152.03614807128906
log px: -149.53038024902344
kl-divergence: 2.505770683288574


In [None]:
from plotting import latent_morphing
plot_reconstructions_with_metrics(vae, train_loader, device=device)
#latent_morphing(vae)