# Variational Autoencoders With Alternative Bottlenecks  
Authors:
| Name | Student ID |
|------|------------|
| Olivia Sommer Droob | S214696 |
| Karoline Klan Hansen | S214638 |
| Martha Falkesgaard Nybroe | S214692 |
| Signe Djernis Olsen | S206759 |
| Bella Strandfort | S214205 |

# Table of Contents

1. [Project Overview](#1.-project-overview)
2. [Hydra Configuration System](#2.-hydra-configuration-system)
    - [base_config.yaml](#base_configyaml)
    - [wandb_sweep_config.yaml](#wandb_sweep_configyaml)
3. [Data Pipeline](#3.-data-pipeline)
    - [data.py](#datapy)
    - [Dataset filtering and class selection](#dataset-filtering)
    - [Dataloaders and splitting](#dataloaders-and-splitting)
4. [Model Implementations](#4.-model-implementations)
    - [Gaussian Bottleneck](#gaussian-bottleneck)
    - [Dirichlet Bottleneck](#dirichlet-bottleneck)
    - [Continuous Categorical (CC) Bottleneck](#cc-bottleneck)
5. [Training Pipeline](#5.-training-pipeline)
    - [train.py](#trainpy)
    - [Training loop](#training-loop)
    - [Early stopping](#early-stopping)
    - [Checkpoint saving and loading](#checkpointing)
    - [WandB integration](#wandb-integration)
6. [Evaluation Scripts](#6.-evaluation-scripts)
    - [evaluate.py](#evaluatepy)
    - [evaluate_multiple.py](#evaluate_multiplepy)
    - [Latent space visualization (t-SNE)](#latent-space-visualization)
    - [Reconstruction grids](#reconstruction-grids)
7. [Visualization Utilities](#visualization-utilities)
    - [visualize.py](#visualizepy)
8. [Final Findings](#final-findings)
    - [Gaussian vs. Dirichlet vs. CC](#gauss-vs-dir-vs-cc)
    - [Latent structure comparison](#latent-structure)
    - [Reconstruction comparison](#reconstruction-comparison)
    - [Overall conclusions](#overall-conclusions)

# 1. Project Overview

This repository contains the implementation for our Deep Learning project at DTU (course 02456), where we investigate how different latent bottlenecks affect the behaviour of a Variational Autoencoder (VAE).
Specifically, we compare:

- Gaussian VAE (standard)
- Dirichlet VAE (simplex-constrained latent space)
- Continuous Categorical (CC) VAE (a newer exponential-family simplex distribution)
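
For reference, the role of the bottleneck can be read off the standard VAE objective (the evidence lower bound). The notation below is the conventional one and is an assumption here, not taken from the repository; the code may, for example, weight the KL term differently.

```latex
\mathcal{L}(\theta, \phi; x)
  = \mathbb{E}_{q_\phi(z \mid x)}\big[\log p_\theta(x \mid z)\big]
  - \mathrm{KL}\big(q_\phi(z \mid x) \,\|\, p(z)\big)
```

The choice of bottleneck fixes the family of $q_\phi(z \mid x)$ and of the prior $p(z)$: a Gaussian places $z$ in $\mathbb{R}^d$, while the Dirichlet and CC distributions restrict $z$ to the probability simplex.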


![image](https://github.com/KarolineKlan/deep_project_group_38/blob/main/project_images/VAE%20figure.png)

# 2. Hydra Configuration System
Explain configuration setup

In [1]:
#insert code



# 3. Data Pipeline

The data script [`src/deep_proj/data.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/data.py) handles both the **MNIST** and **MedMNIST** datasets through a single interface.

The script provides:
- **Dataset builders**:
  - `_build_mnist`  
    Loads MNIST with standard normalization and optionally filters the dataset to the classes specified in the config file (`mnist_classes=[0,1,...]`).
  - `_build_medmnist`  
    Loads MedMNIST subsets using metadata from the `medmnist` package and applies appropriate normalization.

- **Dataloader construction** via `get_dataloaders`:
  - Automatically selects which dataset loader to use (`mnist` or `medmnist`)
  - Splits the training set deterministically into **train/val** subsets based on `val_split`
  - Wraps everything into PyTorch `DataLoader` objects for training, validation, and testing
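
The class filtering and the deterministic train/val split described above can be sketched in plain Python. This is an illustration only: `filter_by_class`, `split_indices` and the `seed` argument are hypothetical names, and the real `data.py` works on PyTorch datasets rather than index lists.

```python
import random
from typing import List, Sequence, Tuple


def filter_by_class(labels: Sequence[int], keep: Sequence[int]) -> List[int]:
    """Return the indices of all samples whose label is in `keep`."""
    keep_set = set(keep)
    return [i for i, y in enumerate(labels) if y in keep_set]


def split_indices(
    indices: Sequence[int], val_split: float, seed: int = 0
) -> Tuple[List[int], List[int]]:
    """Shuffle with a fixed seed, then split into (train, val) index lists."""
    shuffled = list(indices)
    # A seeded generator makes the split reproducible across runs.
    random.Random(seed).shuffle(shuffled)
    n_val = int(round(len(shuffled) * val_split))
    return shuffled[n_val:], shuffled[:n_val]


# Toy labels standing in for a dataset's targets (assumption for the demo).
labels = [0, 1, 2, 4, 4, 1, 0, 9, 3, 1]
kept = filter_by_class(labels, keep=[0, 1, 4])
train_idx, val_idx = split_indices(kept, val_split=0.25, seed=42)
print(len(kept), len(train_idx), len(val_idx))
```

Calling `split_indices` twice with the same seed returns the same partition, which is what makes validation metrics comparable between runs.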




In [None]:
from omegaconf import OmegaConf
from src.deep_proj.data import _build_mnist, _build_medmnist

# Load base config (we will override the dataset field manually)
cfg = OmegaConf.load("configs/base_config.yaml")

# -------------------------
# MNIST
# -------------------------
cfg.dataset = "mnist"
cfg.mnist_classes = [0, 1, 4]  # keep only the digits 0, 1 and 4

mnist_train, mnist_test = _build_mnist(cfg)

print("=== MNIST ===")
print("Train samples:", len(mnist_train))
print("Test samples:", len(mnist_test))
print()

# -------------------------
# MedMNIST
# -------------------------
# ensure correct subset is set (organcmnist by default)
cfg.dataset = "medmnist"
cfg.medmnist_subset = "organcmnist"   # change if needed

med_train, med_test = _build_medmnist(cfg)

print("=== MedMNIST (subset:", cfg.medmnist_subset, ") ===")
print("Train samples:", len(med_train))
print("Test samples:", len(med_test))



=== MNIST ===
Train samples: 60000
Test samples: 10000

Downloading https://zenodo.org/records/10519652/files/organcmnist.npz?download=1 to data/organcmnist.npz




Using downloaded and verified file: data/organcmnist.npz
=== MedMNIST (subset: organcmnist ) ===
Train samples: 12975
Test samples: 8216


# 4. Model Implementations
This is the most important section ... explain the models

## Gaussian Bottleneck


In [None]:
# insert code for gauss
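
Until the repository code is added here, the following is a minimal, generic sketch of what a Gaussian bottleneck computes: the reparameterized sample and the closed-form KL divergence to a standard normal prior. It uses only the standard library, and the function names are hypothetical, not those of the project's model code.

```python
import math
import random
from typing import List, Sequence


def reparameterize(
    mu: Sequence[float], log_var: Sequence[float], rng: random.Random
) -> List[float]:
    """Draw z = mu + sigma * eps with eps ~ N(0, 1), elementwise."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu: Sequence[float], log_var: Sequence[float]) -> float:
    """Closed-form KL( N(mu, diag(sigma^2)) || N(0, I) )."""
    return 0.5 * sum(math.exp(lv) + m * m - 1.0 - lv for m, lv in zip(mu, log_var))


rng = random.Random(0)
mu, log_var = [0.5, -1.0], [0.0, -2.0]  # toy encoder outputs (assumption)
z = reparameterize(mu, log_var, rng)
kl = kl_to_standard_normal(mu, log_var)
print(z, kl)
```

Writing the sample as `mu + sigma * eps` moves the randomness into `eps`, so gradients can flow through `mu` and `log_var` in an autodiff framework.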

## Dirichlet Bottleneck


In [None]:
# insert code for dir
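
As a generic illustration of the simplex constraint (not the project's implementation), a Dirichlet sample can be obtained by normalizing independent Gamma draws. The helper name `sample_dirichlet` and the concentration values are assumptions made for this sketch.

```python
import random
from typing import List, Sequence


def sample_dirichlet(alpha: Sequence[float], rng: random.Random) -> List[float]:
    """Sample from Dirichlet(alpha) by normalizing Gamma(alpha_i, 1) draws."""
    draws = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]


rng = random.Random(0)
z = sample_dirichlet([2.0, 1.0, 0.5], rng)
# Every coordinate is non-negative and the coordinates sum to one.
print(z, sum(z))
```

This plain sampler is not differentiable with respect to `alpha`. Training a VAE needs a reparameterized sampler instead, for example `rsample` on `torch.distributions.Dirichlet`, or an approximation of the Dirichlet; which of these the repository uses is not stated here.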

## CC Bottleneck


In [None]:
# insert code for cc
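
For orientation, the CC density is shown below next to the Dirichlet density, both up to their normalizing constants and for a point $z$ on the simplex with $K$ components. These are the forms given in the literature, not expressions extracted from the repository's code; $\lambda$, $\alpha$ and $\eta$ are the conventional parameter names.

```latex
\mathrm{Dir}(z; \alpha) \propto \prod_{i=1}^{K} z_i^{\alpha_i - 1},
\qquad
\mathrm{CC}(z; \lambda) \propto \prod_{i=1}^{K} \lambda_i^{z_i}
  = \exp\Big( \sum_{i=1}^{K} \eta_i z_i \Big),
\quad \eta_i = \log \lambda_i
```

The roles of parameter and variable are swapped between the two: in the CC, $z$ appears in the exponent, so the log-density is linear in $z$ and the distribution is an exponential family whose sufficient statistic is $z$ itself.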

# 5. Training Pipeline


In [None]:
# insert explanations and code