**Table of contents**<a id='toc0_'></a>    
- [Variational Autoencoders With Alternative Bottlenecks](#toc1_)    
- [Project Overview](#toc2_)    
- [Hydra Configuration System](#toc3_)    
      - [Main training configuration:](#toc3_1_1_1_)    
      - [Hyperparameter sweep setup:](#toc3_1_1_2_)    
- [Data Pipeline](#toc4_)    
- [Model Implementations](#toc5_)    
  - [Shared blocks](#toc5_1_)    
  - [VAE classes](#toc5_2_)    
  - [Loss functions](#toc5_3_)    
- [Training Pipeline](#toc6_)    
- [Evaluation Script](#toc7_)    
  - [Multi-model Comparison](#toc7_1_)    
  - [Visualization Utilities](#toc7_2_)    
- [Final Findings](#toc8_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# <a id='toc1_'></a>[Variational Autoencoders With Alternative Bottlenecks](#toc0_)
Authors:

| Name | Student ID |
|------|------------|
| Olivia Sommer Droob | S214696 |
| Karoline Klan Hansen | S214638 |
| Martha Falkesgaard Nybroe | S214692 |
| Signe Djernis Olsen | S206759 |
| Bella Strandfort | S214205 |

# <a id='toc2_'></a>[Project Overview](#toc0_)

This repository contains the implementation for our Deep Learning project at DTU (course 02456), where we investigate how different latent bottlenecks affect the behaviour of a Variational Autoencoder (VAE).
Specifically, we compare:

- Gaussian VAE (standard)
- Dirichlet VAE (simplex-constrained latent space)
- Continuous Categorical (CC) VAE (a newer exponential-family simplex distribution)


![image](project_images/VAE_figure.png)

# <a id='toc3_'></a>[Hydra Configuration System](#toc0_)
Hydra is used to keep all experiment settings clean, centralized, and easy to override from the command line.  
In this project, two configuration files control how training runs: [`base_config.yaml`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/configs/base_config.yaml) and [`wandb_sweep_config.yaml`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/configs/wandb_sweep_config.yaml).

#### <a id='toc3_1_1_1_'></a>[Main training configuration:](#toc0_)
[`configs/base_config.yaml`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/configs/base_config.yaml)

This is the default configuration used when running:

```bash
python -m src.deep_proj.train
```

It defines:

- Weights and Biases (WandB) settings
    - Whether logging is enabled, project name, entity, grouping of runs.

- Dataset setup
    - Choose `mnist` or `medmnist`, data path, batch size, validation split, and optional class filtering.

- Model configuration
    - Select bottleneck type (gaussian, dirichlet, cc), latent dimension, and model-specific parameters.

- Training hyperparameters
    - Learning rate, number of epochs, GPU/CPU selection, early stopping.

- Output structure
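Hydra overrides are passed on the command line as `key=value` pairs after the script name, with dotted keys for nested groups. The toy function below only illustrates what such an override does to a nested config. It is not Hydra's implementation. Apart from `dataset` and `model_name`, the key names, nesting and default values are placeholders rather than the contents of `base_config.yaml`.

```python
import copy


def apply_overrides(cfg: dict, overrides: list) -> dict:
    """Toy illustration of Hydra-style `key=value` overrides (NOT Hydra's own parser).

    Dotted keys address nested groups. Unknown keys are rejected, mirroring Hydra,
    which refuses to override a key that is not in the config unless it is prefixed
    with `+`. The new value is converted to the type of the default it replaces.
    """
    out = copy.deepcopy(cfg)  # never mutate the base config
    for item in overrides:
        key, _, raw = item.partition("=")
        *parents, leaf = key.split(".")
        node = out
        for p in parents:
            node = node[p]
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")
        default = node[leaf]
        if isinstance(default, bool):  # check bool before int: bool is a subclass of int
            node[leaf] = raw.lower() == "true"
        else:
            node[leaf] = type(default)(raw)
    return out


# Hypothetical base config (placeholder keys and values)
base = {
    "dataset": "mnist",
    "model_name": "gaussian",
    "latent_dim": 10,
    "lr": 1e-3,
    "wandb": {"enabled": False},
}

cfg = apply_overrides(base, ["dataset=medmnist", "latent_dim=20", "wandb.enabled=true"])
print(cfg["dataset"], cfg["latent_dim"], cfg["wandb"]["enabled"])  # medmnist 20 True
```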

#### <a id='toc3_1_1_2_'></a>[Hyperparameter sweep setup:](#toc0_)
[`configs/wandb_sweep_config.yaml`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/configs/wandb_sweep_config.yaml)

This file defines a WandB grid search sweep, allowing us to automatically try multiple combinations of:
- datasets
- model types (gaussian, dirichlet and cc)
- learning rates
- latent dimensions

To launch the sweep, run the following from the terminal:
```bash
wandb sweep configs/wandb_sweep_config.yaml
wandb agent <sweep_id>
```
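A grid sweep runs every combination of the listed values, so the number of runs is the product of the list lengths. The snippet below enumerates such a grid in plain Python. The dataset and model names come from the text above. The learning rates and latent dimensions are placeholder values, not necessarily those in `wandb_sweep_config.yaml`.

```python
from itertools import product

# Dataset and model names from the text; lr and latent_dim values are placeholders.
grid = {
    "dataset": ["mnist", "medmnist"],
    "model_name": ["gaussian", "dirichlet", "cc"],
    "lr": [1e-3, 4e-4],
    "latent_dim": [3, 10],
}

# One dict per run, in the same key order as `grid`
runs = [dict(zip(grid, values)) for values in product(*grid.values())]
print(len(runs))  # 2 * 3 * 2 * 2 = 24
print(runs[0])
```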

# <a id='toc4_'></a>[Data Pipeline](#toc0_)

The data script [`src/deep_proj/data.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/data.py) handles both **MNIST** and **MedMNIST** datasets in a unified way.

The script provides:
- **Dataset builders**:
  - `_build_mnist` – Loads MNIST with standard normalization (mean 0.1307, std 0.3081) and optionally keeps only the classes listed in the config (`mnist_classes=[0,1,...]`).
  - `_build_medmnist` – Loads the MedMNIST subset given by `medmnist_subset`, using metadata from the `medmnist` package, and normalizes it with mean 0.5 and std 0.5.

- **Dataloader construction** via `get_dataloaders`:
  - Selects which dataset loader to use (`mnist` or `medmnist`)
  - Splits the training set deterministically into **train/val** subsets based on `val_split`
  - Wraps everything into PyTorch `DataLoader` objects for training, validation, and testing
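The two data steps (class filtering, then a deterministic train/val split) can be sketched with NumPy as below. The helper names, the seed, and the use of `int(n * val_split)` are assumptions. The project may instead use something like `torch.utils.data.random_split` with a seeded generator. If `val_split` is 0.1 (inferred, not confirmed), this rounding reproduces the validation size of 1266 that the evaluation cell prints for the filtered {0, 1} MNIST training set of 12665 samples.

```python
import numpy as np


def filter_classes(labels: np.ndarray, classes: list) -> np.ndarray:
    """Indices of samples whose label is in `classes` (hypothetical helper)."""
    return np.flatnonzero(np.isin(labels, classes))


def train_val_indices(n: int, val_split: float, seed: int = 42):
    """Deterministic split (hypothetical helper): same seed -> same permutation -> same split."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n)
    n_val = int(n * val_split)  # truncates, e.g. int(12665 * 0.1) == 1266
    return perm[n_val:], perm[:n_val]


# Toy labels standing in for MNIST targets
labels = np.array([0, 1, 4, 7, 1, 0, 9, 4])
kept = filter_classes(labels, [0, 1, 4])
print(kept)  # [0 1 2 4 5 7]

train_idx, val_idx = train_val_indices(12665, 0.1)
print(len(train_idx), len(val_idx))  # 11399 1266
```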




```python
from omegaconf import OmegaConf
from src.deep_proj.data import _build_mnist, _build_medmnist

# Load the base config (the dataset fields are overridden manually below)
cfg = OmegaConf.load("configs/base_config.yaml")

# -------------------------
# MNIST
# -------------------------
cfg.dataset = "mnist"
cfg.mnist_classes = [0, 1, 4]  # keep only digits 0, 1 and 4

print("=== MNIST ===")
mnist_train, mnist_test = _build_mnist(cfg)

print()

# -------------------------
# MedMNIST
# -------------------------
cfg.dataset = "medmnist"
cfg.medmnist_subset = "organcmnist"  # change if needed
# Intended to keep 3 organ classes; note that the counts printed below equal the
# full OrganCMNIST train/test splits, so this key may not be used by _build_medmnist
cfg.medmnist_classes = [0, 1, 2]

# Print the subset name only after it has been set
print("=== MedMNIST (subset:", cfg.medmnist_subset, ") ===")
med_train, med_test = _build_medmnist(cfg)

print("Train samples:", len(med_train))
print("Test samples:", len(med_test))
```



```
=== MNIST ===
[MNIST FILTER] Selected classes: {0, 1, 4}
[MNIST FILTER] Train samples: 18507
[MNIST FILTER] Test samples:  3097

=== MedMNIST (subset: organcmnist ) ===
Using downloaded and verified file: data/organcmnist.npz
Using downloaded and verified file: data/organcmnist.npz
Train samples: 12975
Test samples: 8216
```


# <a id='toc5_'></a>[Model Implementations](#toc0_)
The module [`src/deep_proj/model.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/model.py) collects all neural network models and ELBO loss functions used in the project. All three VAEs share the same encoder/decoder architecture and differ only in the latent bottleneck distribution (and in the encoder output head that parameterizes it).

## <a id='toc5_1_'></a>[Shared blocks](#toc0_)

- `MLPEncoder`  
  A fully connected network that maps a flattened image to latent parameters.  
  It behaves differently depending on `bottle`:
  - `"gaus"` – outputs mean `μ` and log-variance `logvar` for a Gaussian latent.
  - `"dir"` – outputs positive concentration parameters `α̂` for a Dirichlet latent (clamped to a safe range).
  - `"cc"` – outputs positive parameters `λ̂` for the Continuous Categorical latent.

- `BernoulliDecoder`  
  A mirrored MLP that maps a latent vector back to image logits.  
  The last layer outputs raw logits which are used with `BCEWithLogitsLoss` to model Bernoulli pixels.
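Working directly with logits avoids computing `log(sigmoid(l))` explicitly. The NumPy sketch below writes out the numerically stable binary cross-entropy with logits. It uses the same log-sum-exp idea that PyTorch's `BCEWithLogitsLoss` relies on. Summing over the 784 pixels per sample is an assumption about the reduction used in the project's ELBO helpers.

```python
import numpy as np


def bernoulli_recon_loss(logits: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Per-sample negative Bernoulli log-likelihood from raw decoder logits.

    Stable form of binary cross-entropy with logits, per pixel:
        max(l, 0) - l * x + log(1 + exp(-|l|))
    summed over pixels (assumed reduction).
    """
    per_pixel = np.maximum(logits, 0) - logits * x + np.log1p(np.exp(-np.abs(logits)))
    return per_pixel.sum(axis=1)


rng = np.random.default_rng(0)
x = rng.random((4, 784))            # targets in [0, 1], like the unnormalized pixels
logits = rng.normal(size=(4, 784))  # stand-in for decoder output
print(bernoulli_recon_loss(logits, x).shape)  # (4,)
```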

## <a id='toc5_2_'></a>[VAE classes](#toc0_)

- `GaussianVAE`  
  Standard VAE:
  1. Encoder produces `μ` and `logvar`.
  2. `reparameterize` draws `z = μ + σ ⊙ ε` with `ε ~ N(0, I)`.
  3. Decoder maps `z` to reconstruction logits.

- `DirVAE`  
  Uses a Dirichlet latent:
  1. Encoder outputs `α̂` (Dirichlet parameters).
  2. Samples are generated via an approximate inverse Gamma CDF, then normalized to the simplex.
  3. The prior is another Dirichlet with concentration `prior_alpha`.
  4. `dirvae_elbo_loss` combines a Bernoulli reconstruction term with a KL term based on the MultiGamma representation of the Dirichlet.

- `CCVAE`  
  Uses the Continuous Categorical distribution on the simplex:
  1. Encoder outputs `λ̂`, which is normalized to a mean vector `λ` on the simplex.
  2. `sample_cc_ordered_reparam` implements a differentiable sampler that returns `z` on the simplex.
  3. The prior is uniform over the simplex, encoded as a flat `prior_lambda` (`torch.ones(latent_dim)`).
  4. `ccvae_elbo_loss` uses a Bernoulli reconstruction term and a CC-specific KL term (`cc_kl`) based on natural parameters `η` and the log-normalizing constant `log C(η)`.
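The Dirichlet sampling step in `DirVAE` can be illustrated with a commonly used approximation to the inverse Gamma CDF, `F⁻¹(u; α, β) ≈ (u·α·Γ(α))^(1/α) / β`. The original Dirichlet VAE paper uses this approximation, and it is known to be most accurate for small α. Below is a NumPy sketch of it, *assuming* this is the approximation the project uses. It takes rate β = 1, since any common rate cancels when the Gamma draws are normalized onto the simplex. The function name is hypothetical.

```python
import math

import numpy as np


def sample_dirichlet_approx(alpha: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Draw z on the simplex via the approximate inverse Gamma CDF (hypothetical helper).

    For each component, a Gamma(alpha_k, 1) draw is approximated by
        v_k = (u_k * alpha_k * Gamma(alpha_k)) ** (1 / alpha_k),  u_k ~ Uniform(0, 1),
    which is differentiable in alpha for a fixed u. Then z = v / sum(v).
    """
    u = rng.uniform(1e-12, 1.0, size=alpha.shape)  # keep u away from 0
    log_gamma_alpha = np.vectorize(math.lgamma)(alpha)
    log_v = (np.log(u) + np.log(alpha) + log_gamma_alpha) / alpha
    # Normalize in log space (a softmax of log v) so tiny v cannot underflow to 0/0
    log_v = log_v - log_v.max(axis=-1, keepdims=True)
    v = np.exp(log_v)
    return v / v.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
alpha_hat = np.full((4, 10), 0.7)  # close to the untrained alpha-hat values printed later
z = sample_dirichlet_approx(alpha_hat, rng)
print(z.sum(axis=1))  # each row sums to 1 (up to floating point)
```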

## <a id='toc5_3_'></a>[Loss functions](#toc0_)

Each model has a corresponding ELBO helper:

- `gaussian_vae_elbo_loss(model, x)`  
  Standard VAE ELBO with analytic Gaussian KL.

- `dirvae_elbo_loss(model, x)`  
  ELBO for the Dirichlet VAE using the MultiGamma KL.

- `ccvae_elbo_loss(model, x)`  
  ELBO for the CC-VAE using the CC KL based on natural parameters and the normalizing constant.
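The Gaussian and MultiGamma KL terms both have closed forms. The NumPy sketch below assumes unit Gamma rates, as in the DirVAE MultiGamma formulation, and per-sample sums over latent dimensions. It includes a small hand-written digamma so that only NumPy and the standard library are needed. The project's `gaussian_vae_elbo_loss` and `dirvae_elbo_loss` may differ in reduction, clamping or parameterization.

```python
import math

import numpy as np


def gaussian_kl(mu: np.ndarray, logvar: np.ndarray) -> np.ndarray:
    """KL(N(mu, diag(exp(logvar))) || N(0, I)) per sample, summed over latent dims."""
    return -0.5 * np.sum(1.0 + logvar - mu**2 - np.exp(logvar), axis=-1)


def digamma(x: float) -> float:
    """psi(x) for x > 0: recurrence until x >= 6, then the asymptotic series."""
    result = 0.0
    while x < 6.0:
        result -= 1.0 / x  # psi(x) = psi(x + 1) - 1/x
        x += 1.0
    inv2 = 1.0 / (x * x)
    return result + math.log(x) - 0.5 / x - inv2 * (1 / 12 - inv2 * (1 / 120 - inv2 / 252))


def multigamma_kl(alpha_hat: np.ndarray, alpha_prior: np.ndarray) -> float:
    """KL(prod_k Gamma(alpha_hat_k, 1) || prod_k Gamma(alpha_prior_k, 1)).

    Per component: log Gamma(a_prior) - log Gamma(a_hat) + (a_hat - a_prior) * psi(a_hat).
    """
    total = 0.0
    for a_hat, a_prior in zip(np.ravel(alpha_hat), np.ravel(alpha_prior)):
        total += math.lgamma(a_prior) - math.lgamma(a_hat) + (a_hat - a_prior) * digamma(a_hat)
    return total


print(gaussian_kl(np.array([0.5, 0.0]), np.zeros(2)))  # 0.125
# Hypothetical prior_alpha = 1 (uniform Dirichlet) against a posterior with alpha-hat = 0.7
print(multigamma_kl(np.full(10, 0.7), np.ones(10)))
```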

Together, these components let us isolate the effect of changing only the latent distribution while keeping the encoder/decoder architecture and reconstruction loss fixed.
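The CC KL term in `ccvae_elbo_loss` depends on the log-normalizing constant `log C(η)`. For pairwise-distinct natural parameters, the integral of `exp(η·x)` over the simplex has a standard closed form: a divided difference of `exp`. `C(η)` is its reciprocal. The sketch below evaluates it and checks it against a Monte Carlo estimate. The sum cancels badly when entries of η are close, so a practical implementation needs a more careful evaluation. How the project's `cc_kl` computes `log C(η)` is not shown here, and the function name is hypothetical.

```python
import math

import numpy as np


def cc_log_integral(eta) -> float:
    """log of the integral of exp(eta . x) over the probability simplex (Lebesgue measure
    on the first K-1 coordinates), for pairwise-DISTINCT entries of eta:

        integral = sum_k exp(eta_k) / prod_{i != k} (eta_k - eta_i)

    The CC normalizing constant is the reciprocal: log C(eta) = -cc_log_integral(eta).
    """
    eta = np.asarray(eta, dtype=float)
    total = 0.0
    for k in range(eta.size):
        diffs = np.delete(eta[k] - eta, k)  # eta_k - eta_i for all i != k
        total += math.exp(eta[k]) / np.prod(diffs)
    return math.log(total)


# Monte Carlo check: uniform points on the simplex are Dirichlet(1, ..., 1), and the
# simplex has volume 1 / (K - 1)! in the first K-1 coordinates.
eta = np.array([2.0, -1.0, 0.5])
rng = np.random.default_rng(0)
x = rng.dirichlet(np.ones(eta.size), size=200_000)
mc = np.exp(x @ eta).mean() / math.factorial(eta.size - 1)
print(math.exp(cc_log_integral(eta)), mc)  # the two estimates should agree closely
```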

```python
import torch
from src.deep_proj.model import (
    GaussianVAE, DirVAE, CCVAE,
    gaussian_vae_elbo_loss, dirvae_elbo_loss, ccvae_elbo_loss,
)

# Setup: identical encoder/decoder sizes for all three models, only the bottleneck differs
input_dim = 28 * 28
enc_hidden = [500, 500]
dec_hidden = [500]
latent_dim = 10

# Random batch of 8 flattened "images" with values in [0, 1] (untrained models, sanity check only)
x = torch.rand(8, input_dim)

gaus = GaussianVAE(input_dim, enc_hidden, dec_hidden, latent_dim)
dirv = DirVAE(input_dim, enc_hidden, dec_hidden, latent_dim)
ccv  = CCVAE(input_dim, enc_hidden, dec_hidden, latent_dim)


def pretty(title):
    """Print a centered section banner."""
    print("\n" + "=" * 70)
    print(title.center(70))
    print("=" * 70)


# ------------------------------------------------------------------------
# Gaussian VAE
# ------------------------------------------------------------------------
pretty("Gaussian VAE (Standard Normal Bottleneck)")

with torch.no_grad():
    logits_g, mu, logvar, z_g = gaus(x)
    _, _, kl_g = gaussian_vae_elbo_loss(gaus, x)

print("μ (mean) for first sample:")
print(" ", mu[0].tolist())
print("\nσ² (variance) for first sample:")
print(" ", logvar[0].exp().tolist())

print(f"\nKL( q(z|x) || N(0, I) ):   {kl_g.item():.4f}")


# ------------------------------------------------------------------------
# Dirichlet VAE
# ------------------------------------------------------------------------
pretty("Dirichlet VAE (Simplex Constrained)")

with torch.no_grad():
    logits_d, z_d, alpha_hat, v = dirv(x)
    _, _, kl_d = dirvae_elbo_loss(dirv, x)

a0 = alpha_hat[0]
z0 = z_d[0]

print("α̂ (concentration) for first sample:")
print(" ", a0.tolist())

print("\nSample z on simplex:")
print(f"  sum(z) = {z0.sum().item():.4f},   min(z) = {z0.min().item():.4f}")

print(f"\nKL( Dirichlet posterior || prior ):   {kl_d.item():.4f}")


# ------------------------------------------------------------------------
# CC VAE
# ------------------------------------------------------------------------
pretty("Continuous Categorical VAE (CC Bottleneck)")

with torch.no_grad():
    logits_c, z_c, lam = ccv(x)
    _, _, kl_c = ccvae_elbo_loss(ccv, x)

lam0 = lam[0]
z0c = z_c[0]

print("λ (mean vector) for first sample:")
print(" ", lam0.tolist())
print(f"  sum(λ) = {lam0.sum().item():.4f}")

print("\nSample z on simplex:")
print(f"  sum(z) = {z0c.sum().item():.4f},   min(z) = {z0c.min().item():.4f}")

print(f"\nKL( CC posterior || prior ):          {kl_c.item():.4f}")
```



```
              Gaussian VAE (Standard Normal Bottleneck)               
μ (mean) for first sample:
  [0.006777741014957428, -0.018862128257751465, 0.0059876032173633575, 0.008925219997763634, 0.05221676081418991, -0.009353311732411385, -0.04312464967370033, 0.002020535059273243, 0.07870528101921082, -0.02894168347120285]

σ² (variance) for first sample:
  [1.050595998764038, 0.9927116632461548, 0.977631151676178, 0.9883387684822083, 1.0319002866744995, 0.9219326376914978, 1.057778239250183, 1.007879614830017, 1.019608974456787, 1.0233640670776367]

KL( q(z|x) || N(0, I) ):   0.0185

                 Dirichlet VAE (Simplex Constrained)                  
α̂ (concentration) for first sample:
  [0.6840500831604004, 0.7139037251472473, 0.7407388091087341, 0.7028700709342957, 0.6711534857749939, 0.6853014826774597, 0.7448384165763855, 0.6747437715530396, 0.6862039566040039, 0.7255652546882629]

Sample z on simplex:
  sum(z) = 1.0000,   min(z) = 0.0001

KL( Dirichlet posterior || prior ):   0
```

# <a id='toc6_'></a>[Training Pipeline](#toc0_)

The training script is found in [`src/deep_proj/train.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/train.py).

It does the following:
1. Loads config and picks device  
   Hydra loads `base_config.yaml`, and `get_device(cfg)` chooses GPU if available (or CPU otherwise).

2. Builds dataloaders and model  
   Using `get_dataloaders(cfg)` and the `model_name` in the config, the script creates one of the following models, together with its matching ELBO loss function:
   - `GaussianVAE`
   - `DirVAE`
   - `CCVAE`

3. Sets up optimizer, early stopping, and logging 
   - Adam optimizer  
   - `EarlyStopping` based on validation loss  
   - Optional Weights & Biases logging, and a checkpoint directory under `models/`.

The core of the script is the epoch loop:
```python
for epoch in range(1, epochs + 1):
    model.train()
    tot_loss = tot_recon = tot_kl = 0.0
    n = 0

    for xb, _ in train_loader:
        xb = xb.to(device)

        # undo normalization - pixels back in [0,1]
        if cfg.dataset.lower() == "mnist":
            xb = xb * 0.3081 + 0.1307
        elif cfg.dataset.lower() == "medmnist":
            xb = xb * 0.5 + 0.5

        xb = xb.view(xb.size(0), -1)

        optimizer.zero_grad()
        loss, recon, kl = loss_fn(model, xb, reduction="mean")
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        bs = xb.size(0)
        n += bs
        tot_loss += loss.item() * bs
        tot_recon += recon.item() * bs
        tot_kl += kl.item() * bs
```

Key points:
- Inputs are mapped back to [0, 1] before computing the ELBO, because the Bernoulli likelihood (binary cross-entropy) expects targets in [0, 1].
- The chosen ELBO loss returns total loss, reconstruction term, and KL term.
- Gradients are clipped (max_norm=5) for stability.
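The loop accumulates `loss.item() * bs` because `loss_fn(..., reduction="mean")` returns a batch-averaged value. Multiplying by the batch size and dividing by `n` afterwards gives a per-sample epoch average, even when the last batch is smaller. The division happens after the excerpt above and is not shown. A small numeric illustration with made-up values:

```python
# Hypothetical per-batch mean losses; the last batch is smaller, as happens when
# the dataset size is not a multiple of the batch size.
batch_means = [100.0, 110.0, 90.0]
batch_sizes = [128, 128, 32]

tot_loss, n = 0.0, 0
for mean_loss, bs in zip(batch_means, batch_sizes):
    tot_loss += mean_loss * bs  # undo the "mean" reduction: total over the batch
    n += bs

epoch_loss = tot_loss / n  # per-sample average over the whole epoch
naive = sum(batch_means) / len(batch_means)  # over-weights the small last batch
print(round(epoch_loss, 4), round(naive, 4))  # 103.3333 100.0
```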

After each epoch, the script:
- Runs a validation pass with `evaluate_split(...)`.
- Prints train/val losses and updates history.
- Saves a “best” checkpoint if the validation loss improves.
- Optionally logs metrics and visualizations to WandB.
- Checks early stopping; if there is no improvement for 10 epochs, training stops.
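Patience-based early stopping on the validation loss can be sketched in a few lines. The class below is hypothetical and is not the project's `EarlyStopping`, which may, for example, use a different improvement threshold. It stops once the loss has failed to improve for `patience` consecutive epochs.

```python
class EarlyStoppingSketch:
    """Patience-based early stopping on the validation loss (hypothetical helper)."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # improvement: the point where a "best" checkpoint is saved
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStoppingSketch(patience=10)
losses = [120.0, 115.0, 112.0] + [112.5] * 10  # no improvement after epoch 3
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        print(f"stopping after epoch {epoch}")  # stopping after epoch 13
        break
```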

For full details see the script in [`src/deep_proj/train.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/train.py).

# <a id='toc7_'></a>[Evaluation Script](#toc0_)
We evaluate trained models with [`src/deep_proj/evaluate.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/evaluate.py). The script takes a single checkpoint file, restores the config stored in it, and rebuilds the model, loss function and dataloaders needed to evaluate it.

It also saves:
- a simplex projection of the latent space (`plot_latent_simplex`)
- a t-SNE latent plot (`plot_latent`)
- a reconstruction grid (`plot_recons`)

All plots are stored under:
`reports/figures/<run_id>/evaluation/`

```python
import torch
from omegaconf import OmegaConf

from src.deep_proj.evaluate import build_model_from_config
from src.deep_proj.data import get_dataloaders
from src.deep_proj.train import evaluate_split

# Path to one of the saved checkpoints (adjust as needed)
ckpt_path = "models/mnist_gaussian_z10_lr0.0004_best.pt"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the checkpoint and restore the config it was trained with
ckpt = torch.load(ckpt_path, map_location=device)
cfg = OmegaConf.create(ckpt["config"])  # returns a DictConfig

# Rebuild model and loss function, then load the trained weights
model, loss_fn = build_model_from_config(cfg, device)
model.load_state_dict(ckpt["model_state_dict"])
model.to(device)
model.eval()

# Rebuild dataloaders with the same data settings as in training
loaders = get_dataloaders(cfg)
val_loader = loaders["val"]
test_loader = loaders["test"]

print("Validation size:", len(val_loader.dataset))
print("Test size:      ", len(test_loader.dataset))

# Evaluate: each call returns (total loss, reconstruction term, KL term)
val_loss, val_recon, val_kl = evaluate_split(model, val_loader, loss_fn, cfg, device)
test_loss, test_recon, test_kl = evaluate_split(model, test_loader, loss_fn, cfg, device)

print("\n=== Final model evaluation ===")
print(f"Validation | Loss {val_loss:.4f} | Recon {val_recon:.4f} | KL {val_kl:.4f}")
print(f"Test       | Loss {test_loss:.4f} | Recon {test_recon:.4f} | KL {test_kl:.4f}")
```



```
[MNIST FILTER] Selected classes: {0, 1}
[MNIST FILTER] Train samples: 12665
[MNIST FILTER] Test samples:  2115
Validation size: 1266
Test size:       2115

=== Final model evaluation ===
Validation | Loss 108.5586 | Recon 96.2061 | KL 12.3525
Test       | Loss 106.6115 | Recon 94.1676 | KL 12.4439
```


## <a id='toc7_1_'></a>[Multi-model Comparison](#toc0_)

We also provide [`src/deep_proj/evaluate_multiple.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/evaluate_multiple.py), a script that loads three trained checkpoints (Gaussian, Dirichlet and CC), rebuilds each model from its saved config, and evaluates all three on the same MNIST test set. The script then:

- Extracts latent representations and runs **t-SNE** to create a side-by-side latent space comparison.
- Uses a fixed batch of test images to plot a 4 × N reconstruction grid:
  1. Original images  
  2. Gaussian-VAE reconstructions  
  3. Dirichlet-VAE reconstructions  
  4. CC-VAE reconstructions
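Because every model reconstructs the same fixed batch, each column of the grid shows one test image across all three bottlenecks. The NumPy sketch below shows one way to assemble such a 4 × N grid as a single array. The decoders output logits, so reconstructions are taken as `sigmoid(logits)`. The helper names and random stand-in data are hypothetical, and how `evaluate_multiple.py` actually draws the figure is not shown here.

```python
import numpy as np


def sigmoid(logits: np.ndarray) -> np.ndarray:
    """Bernoulli means (pixel intensities) from decoder logits."""
    return 1.0 / (1.0 + np.exp(-logits))


def reconstruction_grid(rows: list) -> np.ndarray:
    """Tile R rows of N images (each an N x H x W array) into one (R*H) x (N*W) image."""
    tiled_rows = [np.concatenate(list(imgs), axis=1) for imgs in rows]  # each row: H x (N*W)
    return np.concatenate(tiled_rows, axis=0)


rng = np.random.default_rng(0)
n, h, w = 8, 28, 28
originals = rng.random((n, h, w))
# Random logits standing in for the Gaussian / Dirichlet / CC decoders on the same batch
recons = [sigmoid(rng.normal(size=(n, h * w))).reshape(n, h, w) for _ in range(3)]

grid = reconstruction_grid([originals, *recons])
print(grid.shape)  # (112, 224)
```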

The resulting comparison figures are saved as:

- `reports/figures/multi_eval/latent_comparison.png`
- `reports/figures/multi_eval/reconstruction_comparison.png`

## <a id='toc7_2_'></a>[Visualization Utilities](#toc0_)

All plotting code for the project lives in two helper modules:

- **[`simplex.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/simplex.py)**  
  Contains the tools for plotting latent simplex projections. Simplex-valued latents (Dirichlet / CC) are projected onto a polygon in 2D and overlaid with colored class clusters and example MNIST images placed at the simplex corners. The main function is `plot_latent_simplex(...)`.

- **[`visualize.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/visualize.py)**  
  Collects the general-purpose visualization functions used during training and evaluation, including:
  - `plot_latent(...)` – t-SNE latent space plots  
  - `plot_recons(...)` – original vs. reconstructed image grids  
  - `plot_side_by_side(...)` – combine two images/plots into a single figure  
  - `plot_training_progress(...)` – side-by-side reconstructions + latent t-SNE over epochs  
  - `plot_training_loss(...)` – training and validation loss and KL curves
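A simple way to draw simplex-valued latents in 2D is to send each corner of the simplex to a vertex of a regular K-gon. Every latent then maps to the matching convex combination of vertices, `xy = z @ V`. The sketch below assumes this linear projection; `plot_latent_simplex` may place the polygon differently. For K > 3 the map is not injective, because different latent vectors can land on the same 2D point, so the plot is a summary rather than a faithful embedding.

```python
import numpy as np


def simplex_to_polygon(z: np.ndarray) -> np.ndarray:
    """Map points on the (K-1)-simplex to 2D via a regular K-gon (hypothetical helper).

    Corner e_k goes to vertex k on the unit circle; every point goes to the matching
    convex combination of the vertices.
    """
    k = z.shape[-1]
    angles = np.pi / 2 + 2 * np.pi * np.arange(k) / k             # first vertex at the top
    vertices = np.stack([np.cos(angles), np.sin(angles)], axis=1)  # K x 2
    return z @ vertices


z = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [1 / 3, 1 / 3, 1 / 3]])
print(np.round(simplex_to_polygon(z), 3))  # two corners, then the centre of the triangle
```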

These utilities are called from `train.py`, `evaluate.py`, and `evaluate_multiple.py` to produce the figures shown in the report and in this notebook.

# <a id='toc8_'></a>[Final Findings](#toc0_)