**Table of contents**<a id='toc0_'></a>    
- [Variational Autoencoders With Alternative Bottlenecks](#toc1_)    
- [Project Overview](#toc2_)    
- [Hydra Configuration System](#toc3_)    
      - [Main training configuration:](#toc3_1_1_1_)    
      - [Hyperparameter sweep setup:](#toc3_1_1_2_)    
- [Data Pipeline](#toc4_)    
- [Model Implementations](#toc5_)    
  - [Shared blocks](#toc5_1_)    
  - [VAE classes](#toc5_2_)    
- [Training Pipeline](#toc6_)    
- [Evaluation Script](#toc7_)    
  - [Multi-model Comparison](#toc7_1_)    
  - [Visualization Utilities](#toc7_2_)    
- [Final Findings](#toc8_)    
    - [The final 3 best models](#toc8_1_1_)    
      - [Reconstruction Comparison](#toc8_1_1_1_)    
      - [Latent space representation](#toc8_1_1_2_)    
  - [MedMNIST Experiment](#toc8_2_)    

<!-- vscode-jupyter-toc-config
	numbering=false
	anchor=true
	flat=false
	minLevel=1
	maxLevel=6
	/vscode-jupyter-toc-config -->
<!-- THIS CELL WILL BE REPLACED ON TOC UPDATE. DO NOT WRITE YOUR TEXT IN THIS CELL -->

# <a id='toc1_'></a>[Variational Autoencoders With Alternative Bottlenecks](#toc0_)
Authors:

| Name | Student ID |
|------|------------|
| Olivia Sommer Droob | S214696 |
| Karoline Klan Hansen | S214638 |
| Martha Falkesgaard Nybroe | S214692 |
| Signe Djernis Olsen | S206759 |
| Bella Strandfort | S214205 |

# <a id='toc2_'></a>[Project Overview](#toc0_)

This repository contains the implementation for our Deep Learning project at DTU (course 02456), where we investigate how different latent bottlenecks affect the behaviour of a Variational Autoencoder (VAE).
Specifically, we compare:

- Gaussian VAE (standard)
- Dirichlet VAE (simplex-constrained latent space)
- Continuous Categorical (CC) VAE (a newer exponential-family simplex distribution)


![image](project_images/VAE_figure_FINAL.png)

This explainer notebook documents the code setup and the experiments we ran. We refer to the report for a more elaborate discussion of the findings and our conclusions.

# <a id='toc3_'></a>[Hydra Configuration System](#toc0_)
Hydra is used to keep all experiment settings clean, centralized, and easy to override from the command line.  
In this project, two configuration files control how training runs: [`base_config.yaml`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/configs/base_config.yaml) and [`wandb_sweep_config.yaml`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/configs/wandb_sweep_config.yaml).

#### <a id='toc3_1_1_1_'></a>[Main training configuration:](#toc0_)
[`configs/base_config.yaml`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/configs/base_config.yaml)

This is the default configuration used when running:

```bash
python -m src.deep_proj.train
```

It defines:

- Weights and Biases (WandB) settings
    - Whether logging is enabled, project name, entity, grouping of runs.

- Dataset setup
    - Choose mnist or medmnist, data path, batch size, validation split, and optional class filtering.

- Model configuration
    - Select bottleneck type (gaussian, dirichlet, cc), latent dimension, and model-specific parameters.

- Training hyperparameters
    - Learning rate, number of epochs, GPU/CPU selection, early stopping.

- Output structure

#### <a id='toc3_1_1_2_'></a>[Hyperparameter sweep setup:](#toc0_)
[`configs/wandb_sweep_config.yaml`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/configs/wandb_sweep_config.yaml)

This file defines a WandB grid search sweep, allowing us to automatically try multiple combinations of:
- datasets
- model types (gaussian, dirichlet and cc)
- learning rates
- latent dimensions

To launch the sweep, run the following from the terminal:
```bash
wandb sweep configs/wandb_sweep_config.yaml
wandb agent <sweep_id>
```
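A grid sweep simply trains one run for every combination of the listed values. The sketch below enumerates such a grid; the values are illustrative (taken from the latent sizes and learning rates that appear in this notebook), not a copy of `wandb_sweep_config.yaml`, and the run-name pattern is assumed from the checkpoint file names used later (e.g. `mnist_gaussian_z8_lr0.0007_best.pt`).

```python
from itertools import product

# Illustrative grid; the actual values live in configs/wandb_sweep_config.yaml
grid = {
    "dataset": ["mnist"],
    "model_name": ["gaussian", "dirichlet", "cc"],
    "latent_dim": [3, 5, 8],
    "lr": [0.0003, 0.0005, 0.0007],
}

# Grid search: every combination of values exactly once (product preserves key order)
runs = [dict(zip(grid, values)) for values in product(*grid.values())]

# Assumed naming scheme, mirroring the checkpoint file names in models/final_sweep/
run_names = [f"{r['dataset']}_{r['model_name']}_z{r['latent_dim']}_lr{r['lr']}" for r in runs]

print(len(runs), "runs, e.g.", run_names[0])
```

The grid grows multiplicatively with every axis, which is why the MedMNIST experiment below was limited to a single learning rate and latent size.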

# <a id='toc4_'></a>[Data Pipeline](#toc0_)

The data script [`src/deep_proj/data.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/data.py) handles both the **MNIST** and **MedMNIST** datasets in a unified way.  

The script provides:
- **Dataset builders**:
  - `_build_mnist`: loads MNIST with standard normalization (mean 0.1307, std 0.3081) and optionally filters the dataset to the classes listed in the config file (`mnist_classes=[0,1,...]`).
  - `_build_medmnist`: loads MedMNIST subsets using metadata from the `medmnist` package and applies normalization (mean 0.5, std 0.5).

- **Dataloader construction** via `get_dataloaders`:
  - Selects which dataset loader to use (`mnist` or `medmnist`)
  - Splits the training set deterministically into **train/val** subsets based on `val_split`
  - Wraps everything into PyTorch `DataLoader` objects for training, validation, and testing

**Note on normalization:** Normalization was initially added as a standard way to stabilize optimization, but the training loop later undoes it (mapping pixels back to [0, 1]) to keep the same simple setup as the paper we follow in the report.
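The class filtering and the deterministic train/validation split described above can be sketched as follows. This is a minimal sketch, not the code in `data.py`: the helper names `filter_classes` and `split_train_val`, the seed, and the use of `torch.utils.data.random_split` with a seeded generator are assumptions about how such a split is typically made reproducible. A small random `TensorDataset` stands in for MNIST so the example runs without downloading anything.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset, random_split

def filter_classes(dataset, targets, classes):
    """Hypothetical helper: keep only samples whose label is in `classes`."""
    mask = torch.isin(targets, torch.tensor(classes))
    indices = torch.nonzero(mask, as_tuple=True)[0].tolist()
    return Subset(dataset, indices)

def split_train_val(dataset, val_split, seed=42):
    """Hypothetical helper: the same seed always yields the same train/val indices."""
    n_val = int(len(dataset) * val_split)
    gen = torch.Generator().manual_seed(seed)
    return random_split(dataset, [len(dataset) - n_val, n_val], generator=gen)

# Stand-in data: 100 fake 28x28 images with labels 0..9 (ten of each)
images = torch.rand(100, 1, 28, 28)
labels = torch.arange(100) % 10
full = TensorDataset(images, labels)

subset = filter_classes(full, labels, [0, 1, 4])        # 30 samples remain
train_set, val_set = split_train_val(subset, val_split=0.2)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
print(len(subset), len(train_set), len(val_set))        # 30 24 6
```

Seeding a dedicated `Generator` (rather than the global RNG) keeps the split fixed even if other code consumes random numbers first.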


In [2]:
from omegaconf import OmegaConf
from src.deep_proj.data import _build_mnist, _build_medmnist

# we will override the base_config manually
cfg = OmegaConf.load("configs/base_config.yaml")

# -------------------------
# MNIST
# -------------------------
cfg.dataset = "mnist"
cfg.mnist_classes = [0,1,4]   # only include the 3 subclasses 

print("=== MNIST ===")
mnist_train, mnist_test = _build_mnist(cfg)
print("Train samples:", len(mnist_train))
print("Test samples:", len(mnist_test))


print()

# -------------------------
# MedMNIST
# -------------------------
cfg.dataset = "medmnist"
cfg.medmnist_subset = "organcmnist"   
cfg.medmnist_classes = [3,5,8]        # only include the 3 organ classes
print("=== MedMNIST (subset:", cfg.medmnist_subset, ") ===")   # print after the subset is set

med_train, med_test = _build_medmnist(cfg)


print("Train samples:", len(med_train))
print("Test samples:", len(med_test))



=== MNIST ===
Train samples: 18507
Test samples: 3097

=== MedMNIST (subset: organcmnist ) ===
Using downloaded and verified file: data/organcmnist.npz
Using downloaded and verified file: data/organcmnist.npz
Train samples: 2792
Test samples: 1713


# <a id='toc5_'></a>[Model Implementations](#toc0_)
The module [`src/deep_proj/model.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/model.py) collects all neural network models and ELBO loss functions used in the project. The key idea is that all three VAEs share the same encoder/decoder architecture but differ in the latent bottleneck distribution.

## <a id='toc5_1_'></a>[Shared blocks](#toc0_)

- `MLPEncoder`  
  A fully connected network that maps a flattened image to latent parameters.  
  It behaves differently depending on `bottle`:
  - `"gaus"` – outputs mean `μ` and log-variance `logvar` for a Gaussian latent.
  - `"dir"` – outputs positive concentration parameters `α̂` for a Dirichlet latent (clamped to a safe range).
  - `"cc"` – outputs positive parameters `λ̂` for the Continuous Categorical latent.

- `BernoulliDecoder`  
  A mirrored MLP that maps a latent vector back to image logits.  
  The last layer outputs raw logits which are used with `BCEWithLogitsLoss` to model Bernoulli pixels.
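To make the `bottle` switch concrete, here is a minimal sketch of an encoder/decoder pair in the same spirit. It is not the repository's `MLPEncoder`/`BernoulliDecoder`: the class names are hypothetical, and the use of `softplus` for positivity, the clamp range `(1e-3, 50.0)` standing in for the "safe range", and the ReLU activations are all assumptions. The hidden sizes match those used in the model summary cell below.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def mlp(sizes):
    """Linear layers with ReLU in between and no activation after the last layer."""
    layers = []
    for i in range(len(sizes) - 1):
        layers.append(nn.Linear(sizes[i], sizes[i + 1]))
        if i < len(sizes) - 2:
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

class SketchEncoder(nn.Module):
    """Hypothetical encoder with a `bottle` switch ("gaus", "dir" or "cc")."""
    def __init__(self, input_dim, hidden, latent_dim, bottle="gaus", alpha_range=(1e-3, 50.0)):
        super().__init__()
        self.bottle = bottle
        self.alpha_range = alpha_range  # assumed "safe range" for Dirichlet parameters
        out_dim = 2 * latent_dim if bottle == "gaus" else latent_dim
        self.net = mlp([input_dim, *hidden, out_dim])

    def forward(self, x):
        h = self.net(x)
        if self.bottle == "gaus":
            mu, logvar = h.chunk(2, dim=-1)           # mean and log-variance
            return mu, logvar
        pos = F.softplus(h) + 1e-6                    # strictly positive parameters
        if self.bottle == "dir":
            return pos.clamp(*self.alpha_range)       # Dirichlet concentrations alpha_hat
        return pos / pos.sum(dim=-1, keepdim=True)    # CC: lambda_hat normalized to sum to 1

class SketchDecoder(nn.Module):
    """Mirrored MLP returning raw logits for Bernoulli pixels."""
    def __init__(self, latent_dim, hidden, output_dim):
        super().__init__()
        self.net = mlp([latent_dim, *hidden, output_dim])

    def forward(self, z):
        return self.net(z)

x = torch.rand(16, 28 * 28)                           # fake flattened images in [0, 1]
enc = SketchEncoder(28 * 28, [500, 500], 3, bottle="cc")
dec = SketchDecoder(3, [500], 28 * 28)
lam = enc(x)
logits = dec(lam)   # decoding lambda directly only to check shapes; a VAE decodes a sample z
recon = F.binary_cross_entropy_with_logits(logits, x, reduction="sum") / x.size(0)
```

Keeping the logits raw and using `binary_cross_entropy_with_logits` is numerically safer than applying a sigmoid and then a plain BCE.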

## <a id='toc5_2_'></a>[VAE classes](#toc0_)

- `GaussianVAE`  
  Standard VAE:
  1. Encoder produces `μ` and `logvar`.
  2. `reparameterize` draws `z = μ + σ ⊙ ε` with `ε ~ N(0, I)`.
  3. Decoder maps `z` to reconstruction logits.

- `DirVAE`  
  Uses a Dirichlet latent:
  1. Encoder outputs `α̂` (Dirichlet parameters).
  2. Samples are generated via an approximate inverse Gamma CDF, then normalized to the simplex.
  3. The prior is another Dirichlet with concentration `prior_alpha`.
  4. `dirvae_elbo_loss` combines a Bernoulli reconstruction term with a KL term based on the MultiGamma representation of the Dirichlet.

- `CCVAE`  
  Uses the Continuous Categorical distribution on the simplex:
  1. Encoder outputs `λ̂` (CC parameters), normalized to sum to 1.
  2. `sample_cc_ordered_reparam` implements a differentiable sampler based on the inverse CDF of the continuous Bernoulli distribution and returns samples `z` on the simplex.
  3. The prior is uniform over the simplex, encoded as a flat `prior_lambda` (`torch.ones(latent_dim)/K`, with K = `latent_dim`).
  4. `ccvae_elbo_loss` combines a reconstruction term with a Monte Carlo estimate of the CC-specific KL term (`cc_kl`), based on the log-normalizing constant `log C(η)` and the sample `z`.
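The Gaussian steps above (reparameterization, Bernoulli reconstruction, KL to `N(0, I)`) can be sketched in a few lines. This is a minimal sketch with hypothetical function names; in particular `sketch_gaussian_elbo` takes logits and encoder outputs directly, whereas the repository's `gaussian_vae_elbo_loss` is called as `loss_fn(model, xb, reduction="mean")`. It assumes the reconstruction term is the pixel-wise BCE summed per image and averaged over the batch.

```python
import torch
import torch.nn.functional as F

def reparameterize(mu, logvar):
    """z = mu + sigma * eps with eps ~ N(0, I); gradients flow through mu and logvar."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + std * eps

def gaussian_kl(mu, logvar):
    """Closed-form KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dims."""
    return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)

def sketch_gaussian_elbo(logits, x, mu, logvar):
    """Negative ELBO averaged over the batch: Bernoulli reconstruction + KL."""
    recon = F.binary_cross_entropy_with_logits(logits, x, reduction="none").sum(dim=-1)
    kl = gaussian_kl(mu, logvar)
    return (recon + kl).mean(), recon.mean(), kl.mean()
```

For example, `mu = 1`, `logvar = 0` in each of 3 latent dimensions gives a KL of `0.5 * 3 = 1.5`, while `mu = 0`, `logvar = 0` (the prior itself) gives 0.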

See  [`src/deep_proj/model.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/model.py) for all details.
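The Dirichlet sampler and KL term can be sketched as follows, assuming the small-shape approximation of the Gamma inverse CDF used for Dirichlet VAEs (Joo et al.), F⁻¹(u; α, β) ≈ (u·α·Γ(α))^(1/α) / β, and the closed-form KL between MultiGamma distributions with rate 1. The function names and the clamp on `u` are hypothetical, and the repository's `DirVAE` may differ in detail. Because the rate β is shared across components, it cancels when the Gamma draws are normalized to the simplex.

```python
import torch

def sample_dirichlet_approx(alpha_hat):
    """Approximate reparameterized Dirichlet sample from Gamma(alpha_hat, 1) draws.

    Uses F^{-1}(u; a, 1) ≈ (u * a * Gamma(a))^(1/a), evaluated in log space.
    """
    u = torch.rand_like(alpha_hat).clamp_min(1e-6)    # avoid log(0)
    log_v = (torch.log(u) + torch.log(alpha_hat) + torch.lgamma(alpha_hat)) / alpha_hat
    # v_k / sum_j v_j equals a softmax over log v, which avoids under/overflow
    return torch.softmax(log_v, dim=-1)

def multigamma_kl(alpha_hat, alpha_prior):
    """KL(MultiGamma(alpha_hat, 1) || MultiGamma(alpha_prior, 1)), summed over components."""
    return (torch.lgamma(alpha_prior) - torch.lgamma(alpha_hat)
            + (alpha_hat - alpha_prior) * torch.digamma(alpha_hat)).sum(dim=-1)

torch.manual_seed(0)
alpha_hat = torch.tensor([[0.5, 2.0, 5.0]], requires_grad=True)
prior = torch.full((3,), 0.98)                        # flat prior as in the summary table below
z = sample_dirichlet_approx(alpha_hat)                # a point on the simplex
kl = multigamma_kl(alpha_hat, prior)
kl.sum().backward()                                   # gradients reach the encoder outputs
```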


In [3]:
# Showing model summaries
import torch
from src.deep_proj.model import (
    GaussianVAE,
    DirVAE,
    CCVAE,
    gaussian_vae_elbo_loss,
    dirvae_elbo_loss,
    ccvae_elbo_loss
)

input_dim   = 28 * 28
enc_hidden  = [500, 500]
dec_hidden  = [500]
latent_dim  = 3  # or 5 / 8

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# Instantiate one copy of each model
gauss = GaussianVAE(input_dim, enc_hidden, dec_hidden, latent_dim)
dirichlet = DirVAE(input_dim, enc_hidden, dec_hidden, latent_dim)
cc = CCVAE(input_dim, enc_hidden, dec_hidden, latent_dim)

# Build rows *directly from the model objects*
rows = [
    {
        "Model": "Gauss-VAE",
        "Latent dim": gauss.latent_dim,
        "# trainable params": count_params(gauss),
        "Prior": "N(0, I)",  # GaussianVAE uses a fixed standard normal prior
        "ELBO loss fn": gaussian_vae_elbo_loss.__name__,
    },
    {
        "Model": "Dir-VAE",
        "Latent dim": dirichlet.latent_dim,
        "# trainable params": count_params(dirichlet),
        # shows the actual prior vector stored in the model
        "Prior": f" α={dirichlet.prior_alpha.cpu().numpy()} (same as in paper)",
        "ELBO loss fn": dirvae_elbo_loss.__name__,  # called as loss_fn(model, xb, reduction="mean") in train.py
    },
    {
        "Model": "CC-VAE",
        "Latent dim": cc.latent_dim,
        "# trainable params": count_params(cc),
        "Prior": f"λ={cc.prior_lambda.cpu().numpy()} (flat on simplex)",
        "ELBO loss fn": ccvae_elbo_loss.__name__,
    },
]

# Pretty print table
header = rows[0].keys()
print("{:<12} {:<11} {:<20} {:<60} {:<25}".format(*header))
print("-" * 125)
for r in rows:
    print("{Model:<12} {Latent dim:<11} {# trainable params:<20} {Prior:<60} {ELBO loss fn:<25}".format(**r))


Model        Latent dim  # trainable params   Prior                                                        ELBO loss fn             
-----------------------------------------------------------------------------------------------------------------------------
Gauss-VAE    3           1043796              N(0, I)                                                      gaussian_vae_elbo_loss   
Dir-VAE      3           1043796               α=[0.98 0.98 0.98] (same as in paper)                       dirvae_elbo_loss         
CC-VAE       3           1043796              λ=[0.33333334 0.33333334 0.33333334] (flat on simplex)       ccvae_elbo_loss          
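The CC prior in the table is flat (λ = 1/K in every component), so its density is constant on the simplex and equals (K−1)! with respect to Lebesgue measure on the first K−1 coordinates. For a non-flat CC, the KL estimate needs the log-normalizer. The sketch below assumes natural parameters η = log λ̂ and uses the closed form Z(η) = ∫ exp(η·z) dz = Σₖ exp(ηₖ) / Πᵢ≠ₖ (ηₖ − ηᵢ), i.e. the divided difference of exp (depending on convention, log C(η) = −log Z(η)). This naive evaluation is not the repository's `cc_kl`: it requires distinct ηₖ and loses precision when they are close, and the function names are hypothetical.

```python
import math
import torch

def cc_log_partition(eta):
    """log of the integral of exp(eta . z) over the simplex; eta has shape (..., K).

    Naive closed form sum_k exp(eta_k) / prod_{i != k} (eta_k - eta_i);
    requires distinct entries and is unstable when entries are close.
    """
    eta = eta.to(torch.float64)
    K = eta.shape[-1]
    diffs = eta.unsqueeze(-1) - eta.unsqueeze(-2)          # diffs[..., k, i] = eta_k - eta_i
    eye = torch.eye(K, dtype=torch.bool, device=eta.device)
    diffs = diffs.masked_fill(eye, 1.0)                    # drop the i == k factor
    return torch.log((torch.exp(eta) / diffs.prod(dim=-1)).sum(dim=-1))

def cc_log_prob(z, eta):
    """CC log-density at simplex points z (rows sum to 1)."""
    eta = eta.to(torch.float64)
    return (z.to(torch.float64) * eta).sum(dim=-1) - cc_log_partition(eta)

def cc_kl_mc(z, eta):
    """Monte Carlo KL(q || flat CC prior) from samples z ~ q; the prior density is (K-1)!."""
    K = eta.shape[-1]
    return (cc_log_prob(z, eta) - math.lgamma(K)).mean()
```

For K = 2 the formula reduces to (e^η₁ − e^η₂) / (η₁ − η₂), the familiar continuous Bernoulli normalizer.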


# <a id='toc6_'></a>[Training Pipeline](#toc0_)

The training script is found in [`src/deep_proj/train.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/train.py).

It does the following:
1. Loads config and picks device  
   Hydra loads `base_config.yaml`, and `get_device(cfg)` chooses GPU if available (or CPU otherwise).

2. Builds dataloaders and model  
   Using `get_dataloaders(cfg)` and the `model_name` in the config, the script creates one of:
   - `GaussianVAE`
   - `DirVAE`
   - `CCVAE`  
   together with the matching ELBO loss function.

3. Sets up optimizer, early stopping, and logging 
   - Adam optimizer  
   - `EarlyStopping` based on validation loss  
   - Optional Weights & Biases logging and checkpoint directory under `models/`.

The core of the script is the epoch loop:
```python
for epoch in range(1, epochs + 1):
    model.train()
    tot_loss = tot_recon = tot_kl = 0.0
    n = 0

    for xb, _ in train_loader:
        xb = xb.to(device)

        # Undo normalization: map pixels back to [0, 1], since the paper we follow does not normalize its inputs
        if cfg.dataset.lower() == "mnist":
            xb = xb * 0.3081 + 0.1307
        elif cfg.dataset.lower() == "medmnist":
            xb = xb * 0.5 + 0.5

        xb = xb.view(xb.size(0), -1)                                # Flatten images to match the MLP architecture adopted from [7]

        optimizer.zero_grad()
        loss, recon, kl = loss_fn(model, xb, reduction="mean")      # loss_fn is one of {gaussian_vae_elbo_loss, dirvae_elbo_loss, ccvae_elbo_loss}
        loss.backward()                                             # Backpropagate stochastic gradient of the ELBO
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)     # Gradient clipping (max norm 5) to stabilize optimization
        optimizer.step()                                            # Single Adam update step (same optimizer for all models)

        # Accumulate sums so we can later compute averages over the full epoch
        bs = xb.size(0)
        n += bs
        tot_loss += loss.item() * bs
        tot_recon += recon.item() * bs
        tot_kl += kl.item() * bs
```

Key points:
- Inputs are unnormalized back to [0,1] before the ELBO, so the Bernoulli likelihood is consistent.
- The chosen ELBO loss returns total loss, reconstruction term, and KL term.
- Gradients are clipped (max_norm=5) for stability.
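Both stability-related steps can be checked in isolation. In this minimal sketch, undoing `(x - mean) / std` with the MNIST constants from the loop recovers the original pixels in [0, 1], and `torch.nn.utils.clip_grad_norm_` rescales the gradients of a dummy parameter (global norm 10) down to the maximum norm of 5; it returns the norm measured before clipping.

```python
import torch

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081

x = torch.rand(4, 1, 28, 28)                        # pixels in [0, 1]
x_norm = (x - MNIST_MEAN) / MNIST_STD               # what the normalizing transform produces
x_back = x_norm * MNIST_STD + MNIST_MEAN            # the undo step used in the training loop

# Gradient clipping: rescale all gradients so their global L2 norm is at most max_norm
w = torch.nn.Parameter(torch.zeros(100))
w.grad = torch.full((100,), 1.0)                    # global norm = sqrt(100) = 10
total_norm = torch.nn.utils.clip_grad_norm_([w], max_norm=5.0)
print(total_norm.item(), w.grad.norm().item())
```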

After each epoch, the script:
- Runs a validation pass with `evaluate_split(...)`.
- Prints train/val losses and updates history.
- Saves a “best” checkpoint if the validation loss improves.
- Optionally logs metrics and visualizations to WandB.
- Checks early stopping; if there is no improvement for 10 epochs, training stops.
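The early-stopping logic with a patience of 10 epochs can be sketched as a small class. This is a hypothetical stand-in for the repository's `EarlyStopping` (its real interface may differ); it also reports whether the validation loss improved, which is when a "best" checkpoint would be saved.

```python
class SketchEarlyStopping:
    """Hypothetical early stopping on validation loss."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.counter = 0                  # epochs since the last improvement

    def step(self, val_loss):
        """Return (improved, should_stop) for this epoch's validation loss."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
            return True, False            # improved: save a "best" checkpoint
        self.counter += 1
        return False, self.counter >= self.patience

es = SketchEarlyStopping(patience=10)
history = [es.step(loss) for loss in [5.0, 4.0, 3.0] + [3.5] * 10]
```

After three improving epochs, the tenth epoch without improvement is the first one that signals a stop.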

For full details see the script in [`src/deep_proj/train.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/train.py).

# <a id='toc7_'></a>[Evaluation Script](#toc0_)
We evaluate the trained models using the [`src/deep_proj/evaluate.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/evaluate.py) script. It takes a single checkpoint file and rebuilds everything needed to evaluate it.

It also saves:
- A simplex projection of the latent space (`plot_latent_simplex`)
- A t-SNE latent plot (`plot_latent`)
- A reconstruction grid (`plot_recons`)

All plots are stored under:
`reports/figures/<run_id>/evaluation/`

## <a id='toc7_1_'></a>[Multi-model Comparison](#toc0_)

We also create a [`src/deep_proj/evaluate_multiple.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/evaluate_multiple.py) script that 
loads three trained checkpoints (Gaussian, Dirichlet, CC model), rebuilds each model from its saved config, and evaluates them on the same MNIST test set. For each model, it:

- Extracts latent representations and runs **t-SNE** to create a side-by-side latent space comparison.
- Uses a fixed batch of test images to plot a 4 × N reconstruction grid:
  1. Original images  
  2. Gaussian-VAE reconstructions  
  3. Dirichlet-VAE reconstructions  
  4. CC-VAE reconstructions

The resulting comparison figures are saved under:

`reports/figures/multi_eval/latent_comparison.png` 
and  
`reports/figures/multi_eval/reconstruction_comparison.png`.
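The 4 × N layout can be assembled by tiling image batches into a single array before plotting. The helper below is hypothetical (the repository's plotting code may rely on matplotlib subplots instead), and random arrays stand in for the original and reconstructed test images.

```python
import numpy as np

def make_recon_grid(rows, pad=1):
    """Tile a list of image batches, each of shape (N, H, W), into one 2D array (one row per batch)."""
    n, h, w = rows[0].shape
    grid = np.ones((len(rows) * (h + pad) - pad, n * (w + pad) - pad), dtype=np.float32)
    for r, batch in enumerate(rows):
        for c in range(n):
            top, left = r * (h + pad), c * (w + pad)
            grid[top:top + h, left:left + w] = batch[c]
    return grid

rng = np.random.default_rng(0)
originals = rng.random((8, 28, 28))
recons = [np.clip(originals + 0.1 * rng.standard_normal(originals.shape), 0, 1) for _ in range(3)]
grid = make_recon_grid([originals, *recons])        # rows: original, Gaussian, Dirichlet, CC
```

The resulting array can be shown with a single `plt.imshow(grid, cmap="gray")` call, which keeps all rows on a shared pixel scale.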

## <a id='toc7_2_'></a>[Visualization Utilities](#toc0_)

All plotting code for the project lives in two helper modules:

- **[`simplex.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/simplex.py)**  
  Contains the tools for plotting latent simplex projections, where simplex-valued latents (Dirichlet / CC) are projected onto a polygon in 2D and overlaid with:
  - Colored class clusters
  - Example MNIST images placed at the simplex corners  
  The main function is:
  - `plot_latent_simplex(...)`

- **[`visualize.py`](https://github.com/KarolineKlan/deep_project_group_38/blob/main/src/deep_proj/visualize.py)**  
  Collects the general-purpose visualization functions used during training and evaluation, including:
  - `plot_latent(...)` – t-SNE latent space plots  
  - `plot_recons(...)` – original vs. reconstructed image grids  
  - `plot_side_by_side(...)` – combine two images/plots into a single figure (used for training logs mostly) 
  - `plot_training_loss(...)` – training and validation loss/ KL curves  

These utilities are called from `train.py`, `evaluate.py`, and `evaluate_multiple.py` to produce the figures shown in the report and in this notebook.
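One common way to draw simplex-valued latents in 2D, and a plausible reading of "projected onto a polygon", is a barycentric projection: each simplex vertex is mapped to a corner of a regular K-gon and each latent vector becomes the corresponding convex combination of corners. The sketch below is an assumption about the idea, not the code in `simplex.py`; for K > 3 the projection is lossy, since different latents can land on the same 2D point.

```python
import numpy as np

def simplex_to_polygon(z):
    """Project rows of z (shape (N, K), each row on the simplex) to 2D points.

    Vertex k maps to the k-th corner of a regular K-gon on the unit circle,
    starting at the top; each point is the convex combination of corners with weights z.
    """
    K = z.shape[1]
    angles = np.pi / 2 + 2 * np.pi * np.arange(K) / K
    corners = np.stack([np.cos(angles), np.sin(angles)], axis=1)   # (K, 2)
    return z @ corners

rng = np.random.default_rng(0)
z = rng.dirichlet(np.ones(3), size=100)      # stand-in latents on the 2-simplex
points = simplex_to_polygon(z)               # (100, 2), all inside the triangle
```

Points near a corner correspond to latents dominated by a single component, which is how "spreading toward the corners" shows up in these plots.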

# <a id='toc8_'></a>[Final Findings](#toc0_)

This project explored how three different latent bottlenecks - Gaussian, Dirichlet, and CC - shape the behaviour of a Variational Autoencoder, both in terms of representation geometry and reconstruction quality. By keeping the encoder/decoder architecture and training pipeline identical across models, differences can be directly attributed to the choice of latent distribution.

All training logs can be found in the [Weights and Biases report](https://api.wandb.ai/links/spicy-mlops/92cfrf8p).

### <a id='toc8_1_1_'></a>[The final 3 best models](#toc0_)

For the Gaussian and Dirichlet VAEs, the best models were those with the largest latent dimension, M=8, trained with learning rate 0.0007. For the CC-VAE, the best model had M=3 and learning rate 0.0003.

In [7]:
GAUSS_FINAL_MODEL = "models/final_sweep/mnist_gaussian_z8_lr0.0007_best.pt"
DIR_FINAL_MODEL   = "models/final_sweep/mnist_dirichlet_z8_lr0.0007_best.pt"
CC_FINAL_MODEL    = "models/final_sweep/mnist_cc_z3_lr0.0003_best.pt"

In [None]:
import torch
from omegaconf import OmegaConf
from src.deep_proj.evaluate_multiple import build_model_from_config
from src.deep_proj.data import get_dataloaders
from src.deep_proj.train import evaluate_split
from src.deep_proj.model import (
    gaussian_vae_elbo_loss, dirvae_elbo_loss, ccvae_elbo_loss
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def run_eval(ckpt_path, name):
    # Load checkpoint + config (OmegaConf.create returns a DictConfig)
    ckpt = torch.load(ckpt_path, map_location=device)
    cfg = OmegaConf.create(ckpt["config"])

    # Build model
    model = build_model_from_config(cfg, device)
    model.load_state_dict(ckpt["model_state_dict"])
    model.to(device)
    model.eval()

    # Get loss function automatically
    if cfg.model_name.lower() in ("gaussian", "gaus", "gauss"):
        loss_fn = gaussian_vae_elbo_loss
    elif cfg.model_name.lower() in ("dirichlet", "dir"):
        loss_fn = dirvae_elbo_loss
    else:
        loss_fn = ccvae_elbo_loss

    # Dataloaders
    loaders = get_dataloaders(cfg)
    val_loader = loaders["val"]
    test_loader = loaders["test"]

    # Eval
    val_loss, val_recon, val_kl = evaluate_split(model, val_loader, loss_fn, cfg, device)
    test_loss, test_recon, test_kl = evaluate_split(model, test_loader, loss_fn, cfg, device)

    print(f"\n=== {name} (latent dim = {cfg.latent_dim}) ===")
    print(f"Validation | Loss {val_loss:.4f} | Recon {val_recon:.4f} | KL {val_kl:.4f}")
    print(f"Test       | Loss {test_loss:.4f} | Recon {test_recon:.4f} | KL {test_kl:.4f}")

run_eval(GAUSS_FINAL_MODEL, "Gaussian-VAE")
run_eval(DIR_FINAL_MODEL,   "Dirichlet-VAE")
run_eval(CC_FINAL_MODEL,    "CC-VAE")



=== Gaussian-VAE (latent dim = 8) ===
Validation | Loss 83.4912 | Recon 68.1760 | KL 15.3152
Test       | Loss 84.2690 | Recon 69.1138 | KL 15.1553

=== Dirichlet-VAE (latent dim = 8) ===
Validation | Loss 104.3500 | Recon 90.9932 | KL 13.3568
Test       | Loss 101.9431 | Recon 88.7045 | KL 13.2386

=== CC-VAE (latent dim = 3) ===
Validation | Loss 158.1684 | Recon 150.3114 | KL 7.8570
Test       | Loss 158.9817 | Recon 151.1213 | KL 7.8604



Based on our experiments, the Gaussian bottleneck consistently achieved the lowest reconstruction error and the CC bottleneck the highest, with the Dirichlet in between. Note that the best CC-VAE uses M=3, whereas the best Gaussian and Dirichlet VAEs use M=8. This trend is also visible in the reconstruction examples:

#### <a id='toc8_1_1_1_'></a>[Reconstruction Comparison](#toc0_)

<img src="project_images/reconstruction_comparison_final_mnist.png" width="60%">


#### <a id='toc8_1_1_2_'></a>[Latent space representation](#toc0_)

**Latent space of the 3 models**
In latent space, the Gaussian-VAE and Dirichlet-VAE form clearly separated clusters, whereas the CC-VAE clusters remain partly connected, suggesting that the CC model does not fully separate the classes:

<img src="project_images/latent_comparison_FINAL.png" width="60%">


**Dirichlet versus CC on the Simplex**
When projecting the latent representations onto the simplex, we observe that:
- The CC-VAE uses the extremes of the simplex more strongly than the Dirichlet-VAE, for both M=3 and M=5.
- The Dirichlet-VAE tends to cluster toward the interior, while the CC-VAE spreads toward the corners.

<img src="project_images/dir_vs_cc_simplex_grid.png" width="40%">

## <a id='toc8_2_'></a>[MedMNIST Experiment](#toc0_)

To test the robustness of the three bottlenecks beyond MNIST, we repeated the experiments on OrganCMNIST, again restricted to three classes. Due to time constraints, no larger sweep was run: we only trained models with learning rate 0.0005 and M=5.

Training logs can be found in the [MedMNIST WandB Report](https://api.wandb.ai/links/spicy-mlops/4fqt4npo).

In [20]:
cfg.dataset = "medmnist"
cfg.medmnist_subset = "organcmnist"   
cfg.medmnist_classes = [3,5,8]        # only include the 3 organ classes
print("=== MedMNIST (subset:", cfg.medmnist_subset, ") ===")   # print after the subset is set

med_train, med_test = _build_medmnist(cfg)


print("Train samples:", len(med_train))
print("Test samples:", len(med_test))


=== MedMNIST (subset: organcmnist ) ===
Using downloaded and verified file: data/organcmnist.npz
Using downloaded and verified file: data/organcmnist.npz
Train samples: 2792
Test samples: 1713


Evaluating the three models gives:

In [15]:
GAUSS_MED_z5 = "models/medmnist_gaussian_z5_lr0.0005_best.pt"
DIR_MED_z5   = "models/medmnist_dirichlet_z5_lr0.0005_best.pt"
CC_MED_z5    = "models/medmnist_cc_z5_lr0.0005_best.pt"

run_eval(GAUSS_MED_z5, "Gaussian-VAE (MedMNIST, M=5)")
run_eval(DIR_MED_z5,   "Dirichlet-VAE (MedMNIST, M=5)")
run_eval(CC_MED_z5,    "CC-VAE (MedMNIST, M=5)")

Using downloaded and verified file: data/organcmnist.npz
Using downloaded and verified file: data/organcmnist.npz

=== Gaussian-VAE (MedMNIST, M=5) (latent dim = 5) ===
Validation | Loss 421.1960 | Recon 410.7639 | KL 10.4321
Test       | Loss 443.6786 | Recon 434.0534 | KL 9.6253
Using downloaded and verified file: data/organcmnist.npz
Using downloaded and verified file: data/organcmnist.npz

=== Dirichlet-VAE (MedMNIST, M=5) (latent dim = 5) ===
Validation | Loss 448.4669 | Recon 439.9898 | KL 8.4770
Test       | Loss 459.2871 | Recon 451.5473 | KL 7.7398
Using downloaded and verified file: data/organcmnist.npz
Using downloaded and verified file: data/organcmnist.npz

=== CC-VAE (MedMNIST, M=5) (latent dim = 5) ===
Validation | Loss 490.4989 | Recon 460.0301 | KL 30.4688
Test       | Loss 508.5288 | Recon 477.6432 | KL 30.8856


As described in the report, the images in OrganCMNIST are much more complex than MNIST digits: they contain textures, anatomy, shading, and organ variation that a small MLP-based VAE cannot model well. In addition, restricting OrganCMNIST to three classes leaves a much smaller training set than for MNIST (2,792 vs. 18,507 samples before the validation split). As a result, all models reconstruct poorly, which drives the reconstruction term, and therefore the total loss (the negative ELBO), far higher than on MNIST. The Gaussian-VAE still performs best.
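To put the size of these losses in perspective, the sketch below divides the test-set reconstruction terms reported above by the number of pixels. It assumes the reconstruction term is a Bernoulli log-loss summed over 28 × 28 = 784 pixels per image for both datasets. Latent sizes differ between the runs (M=8/8/3 on MNIST, M=5 on MedMNIST), so this is only a rough scale comparison.

```python
N_PIXELS = 28 * 28  # assumes 28x28 single-channel images for both datasets

# Test-set reconstruction terms copied from the evaluation outputs in this notebook
test_recon = {
    "MNIST":    {"Gaussian": 69.1138, "Dirichlet": 88.7045, "CC": 151.1213},
    "MedMNIST": {"Gaussian": 434.0534, "Dirichlet": 451.5473, "CC": 477.6432},
}

# Per-pixel reconstruction loss for each dataset/model pair
per_pixel = {
    ds: {model: value / N_PIXELS for model, value in models.items()}
    for ds, models in test_recon.items()
}

for ds, models in per_pixel.items():
    print(ds, {m: round(v, 3) for m, v in models.items()})
```

Under these assumptions, even the best MedMNIST model loses roughly six times more per pixel than the best MNIST model, and every MedMNIST model is worse than every MNIST model.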

Accordingly, the reconstructions of all three models are noticeably blurry:

<img src="project_images/gauss_dir_cc_recon.png" width="40%">

The Gaussian-VAE forms reasonably separated clusters, but the class boundaries are not as clean as on MNIST. The CC-VAE produces pointy, almost star-shaped clusters, a pattern we also saw in the training logs for the CC-VAE on MNIST with M=5.

<img src="project_images/medmnist_latent_tsne.png" width="40%">


Projecting the MedMNIST latents onto the simplex gives a similar picture: the CC-VAE classes spread toward the corners, though less distinctly than on MNIST, and some classes overlap.

<img src="project_images/medmnist_dir_vs_cc_simplex_grid.png" width="30%">

We refer to the report for a discussion of the findings and our conclusions.