<a href="https://colab.research.google.com/github/KarolineKlan/deep_project_group_38/blob/main/ExperimentSetup.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# # MNIST — GVAE vs DirVAE (inverse Gamma CDF) vs CC-Placeholder
# One model is trained **at a time**, results are saved, then a final comparison is plotted.
# Output dir: /content/outputs
#
# DirVAE notes (paper equations referenced below):
# • Dirichlet = normalized composition of K independent Gamma(α_k, β) variables. (Sec. 2.2)
# • KL between MultiGamma posteriors/prior (Eq. (3)):  Σ_k[ log Γ(α_k) − log Γ(α̂_k) + (α̂_k−α_k) ψ(α̂_k) ].
# • ELBO (Eq. (7)): reconstruction (BCE) + the MultiGamma KL term above.
# • “Fair” prior vs. softmax–Gaussian Laplace note (Eq. (5)): α_k = 1 − 1/K when μ=0, Σ=I; we adopt that prior.
# • **Sampling path**: inverse Gamma CDF approximation (Knowles 2015) — NO Laplace reparam. (Sec. 3)
#     X ~ Gamma(α, β): F^{-1}(u; α, β) ≈ β^{-1} (u α Γ(α))^{1/α}, with u ~ Uniform(0,1).
#     Log-space form (used for stability):  log X = (log u + log α + lgamma(α)) / α − log β.
#
# Citations: Joo et al., "Dirichlet Variational Autoencoder" (DirVAE); the inverse Gamma CDF approximation follows Knowles (2015).

import os, math, random
from dataclasses import dataclass, asdict
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, utils as vutils

import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import pandas as pd

# -----------------------------
# Config (all knobs in one place)
# -----------------------------
@dataclass
class Config:
    """All tunable settings for the experiment, gathered in one dataclass.

    The default paths point at /content, i.e. they assume a Colab-style filesystem.
    """
    data_root: str = "/content/data"      # where torchvision stores/downloads MNIST
    out_dir: str = "/content/outputs"     # images, plots and checkpoints are written here
    # training
    batch_size: int = 256
    epochs: int = 10
    lr: float = 1e-3                      # learning rate passed to Adam
    seed: int = 42                        # seeds the global RNGs, the data split and t-SNE
    # model
    latent_dim: int = 10         # K (match classes) for all models
    beta_gamma_rate: float = 1.0 # β for DirVAE Gamma rate (shared across dims)
    enc_ch: int = 32             # base channel count of the CNN encoder
    dec_ch: int = 32             # base channel count of the CNN decoder
    # eval
    tsne_samples: int = 5000     # maximum number of test embeddings fed to t-SNE
    # dataloader (keep simple & robust in Colab)
    num_workers: int = 0

# Single module-level configuration instance read by the rest of the script.
cfg = Config()
# Create the output directory up front; exist_ok makes re-running the cell harmless.
os.makedirs(cfg.out_dir, exist_ok=True)

# Reproducibility & device
def set_seeds(s: int):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) with `s`."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(s)
# Seed the global RNGs with the configured seed.
set_seeds(cfg.seed)
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): cuDNN autotuning may pick different kernels from run to run, so GPU results
# are presumably not bit-for-bit reproducible despite the seeding above — confirm whether
# exact reproducibility matters for this experiment.
torch.backends.cudnn.benchmark = True

# -----------------------------
# Data (45k/5k/10k split as requested)
# -----------------------------
# The only preprocessing is ToTensor; the same transform is used for every split.
transform = transforms.ToTensor()
mnist_full = datasets.MNIST(cfg.data_root, train=True, download=True, transform=transform)
test = datasets.MNIST(cfg.data_root, train=False, download=True, transform=transform)

# Stage 1: take 45k training images out of the official training set.
train_len = 45_000
rest_len = len(mnist_full) - train_len
train, rest = random_split(
    mnist_full,
    [train_len, rest_len],
    generator=torch.Generator().manual_seed(cfg.seed),
)

# Stage 2: of the remainder keep 5k for validation and discard the rest.
# The official MNIST test set is used as the test split.
val_len = 5_000
drop_len = rest_len - 5_000
val, _ = random_split(
    rest,
    [val_len, drop_len],
    generator=torch.Generator().manual_seed(cfg.seed + 1),
)

# Options shared by all three loaders; only the training loader shuffles.
_loader_opts = dict(batch_size=cfg.batch_size, num_workers=cfg.num_workers, pin_memory=False)
train_loader = DataLoader(train, shuffle=True, **_loader_opts)
val_loader = DataLoader(val, shuffle=False, **_loader_opts)
test_loader = DataLoader(test, shuffle=False, **_loader_opts)

print(f"Split sizes → train: {len(train)}, val: {len(val)}, test: {len(test)}")
print("Device:", device)

# -----------------------------
# Shared CNN encoder/decoder
# -----------------------------
class EncoderCNN(nn.Module):
    """Convolutional feature extractor: (B,1,28,28) images → (B,256) hidden features.

    `ch` is the base channel width; the three conv stages use ch, 2·ch and 4·ch channels.
    """
    def __init__(self, ch: int = 32):
        super().__init__()
        conv_stack = [
            nn.Conv2d(1, ch, kernel_size=4, stride=2, padding=1),            # 28→14
            nn.ReLU(inplace=True),
            nn.Conv2d(ch, ch * 2, kernel_size=4, stride=2, padding=1),       # 14→7
            nn.BatchNorm2d(ch * 2),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch * 2, ch * 4, kernel_size=3, stride=1, padding=1),   # 7→7
            nn.BatchNorm2d(ch * 4),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*conv_stack)

        # Flatten the (4·ch, 7, 7) feature map and project it to the 256-d hidden vector.
        flat_features = ch * 4 * 7 * 7
        projection = [nn.Flatten(), nn.Linear(flat_features, 256), nn.ReLU(inplace=True)]
        self.fc = nn.Sequential(*projection)

    def forward(self, x):
        """Map a (B,1,28,28) batch to (B,256)."""
        feature_map = self.net(x)
        return self.fc(feature_map)

class DecoderCNN(nn.Module):
    """Map a (B, latent_dim) code to (B,1,28,28) pixel logits.

    The output is unnormalised (no sigmoid), so pair it with a logits-based BCE loss.
    """
    def __init__(self, latent_dim: int, ch: int = 32):
        super().__init__()
        C = ch * 4
        # MLP that expands the latent code to a (C, 7, 7) feature map (flattened).
        expand_layers = [
            nn.Linear(latent_dim, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, C * 7 * 7),
            nn.ReLU(inplace=True),
        ]
        self.fc = nn.Sequential(*expand_layers)
        # Two stride-2 transposed convolutions upsample the map back to image size.
        upsample_layers = [
            nn.ConvTranspose2d(C, C // 2, kernel_size=4, stride=2, padding=1),   # 7→14
            nn.BatchNorm2d(C // 2),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(C // 2, 1, kernel_size=4, stride=2, padding=1),   # 14→28
        ]
        self.deconv = nn.Sequential(*upsample_layers)

    def forward(self, z):
        """Decode latent codes `z` into image logits."""
        batch = z.size(0)
        feature_map = self.fc(z).view(batch, -1, 7, 7)
        logits = self.deconv(feature_map)
        return logits

# -----------------------------
# Bottlenecks
# -----------------------------
class GaussianBottleneck(nn.Module):
    """Diagonal-Gaussian latent layer with a standard-normal prior.

    Two linear heads read the 256-d encoder feature and emit the posterior mean and
    log-variance; `forward` draws a reparameterised sample from them.
    """
    def __init__(self, latent_dim: int):
        super().__init__()
        self.mu = nn.Linear(256, latent_dim)
        self.logvar = nn.Linear(256, latent_dim)

    def forward(self, h):
        """Return a sample z = μ + σ·ε with ε ~ N(0, I), plus the posterior statistics."""
        mean = self.mu(h)
        log_variance = self.logvar(h)
        sigma = torch.exp(0.5 * log_variance)
        noise = torch.randn_like(sigma)
        sample = mean + sigma * noise
        return sample, {"mu": mean, "logvar": log_variance}

    def kl(self, aux):
        """Per-sample KL(q(z|x) || N(0, I)), summed over the latent dimensions."""
        mean = aux["mu"]
        log_variance = aux["logvar"]
        inner = 1 + log_variance - mean.pow(2) - log_variance.exp()
        return -0.5 * torch.sum(inner, dim=1)

    def embed(self, aux):
        """Deterministic embedding for visualisation (t-SNE): the posterior mean."""
        return aux["mu"]

class DirichletBottleneck(nn.Module):
    r"""
    DirVAE bottleneck using the inverse Gamma CDF approximation (Knowles 2015).

    - Per-dimension Gamma sample via F^{-1}(u; α̂, β) ≈ β^{-1} (u α̂ Γ(α̂))^{1/α̂}, u ~ U(0,1),
      evaluated in log-space:  log v = (log u + log α̂ + lgamma(α̂)) / α̂ − log β.
    - Normalisation in log-space: y = softmax(log v), which equals v / Σ v without overflow.
    - α̂ is clamped into [alpha_min, alpha_max] so lgamma/ψ stay well-behaved.
    - Prior α_k = 1 − 1/K (stored as the `alpha_prior` buffer); β is shared by prior
      and posterior.
    - KL between the two MultiGamma distributions with the shared rate β.

    Args:
        latent_dim: number of simplex dimensions K.
        beta_rate: Gamma rate β (> 0), shared across dimensions.
        eps: small constant added to the denominator in `embed`.
        alpha_min, alpha_max: clamp range applied to the posterior concentrations.

    Raises:
        ValueError: if `beta_rate` is not strictly positive (log β would be undefined).
    """
    def __init__(self, latent_dim: int, beta_rate: float = 1.0, eps: float = 1e-6,
                 alpha_min: float = 1e-1, alpha_max: float = 50.0):
        super().__init__()
        if beta_rate <= 0:
            # Fail at construction instead of with a "math domain error" during sampling.
            raise ValueError(f"beta_rate must be > 0, got {beta_rate}.")
        self.latent_dim = latent_dim
        self.beta = beta_rate
        self.eps = eps
        self.alpha_min = alpha_min
        self.alpha_max = alpha_max

        self.alpha_raw = nn.Linear(256, latent_dim)  # α̂(x) > 0 via softplus
        alpha0 = 1.0 - 1.0 / float(latent_dim)       # prior α_k = 1 − 1/K
        self.register_buffer("alpha_prior", torch.full((latent_dim,), alpha0))

    def _alpha_hat(self, h):
        """Posterior concentrations α̂(x): softplus of the linear head, then clamped."""
        ah = F.softplus(self.alpha_raw(h)) + 1e-6
        return ah.clamp(min=self.alpha_min, max=self.alpha_max)

    def _sample_multi_gamma_log(self, alpha_hat: torch.Tensor) -> torch.Tensor:
        """
        Sample log v for K independent Gamma(α̂_k, β) using the inverse CDF approximation.

        Returns log v (shape (B, K)) so the caller can normalise in log-space.
        """
        B, K = alpha_hat.shape
        # Keep u strictly inside (0, 1) so log u is finite.
        u = torch.clamp(
            torch.rand(B, K, device=alpha_hat.device, dtype=alpha_hat.dtype),
            1e-6, 1.0 - 1e-6,
        )
        # log v = (log u + log α̂ + lgamma(α̂)) / α̂ − log β.
        # The whole bracket is divided by α̂: the approximation raises u·α̂·Γ(α̂) to 1/α̂.
        log_v = (torch.log(u) + torch.log(alpha_hat) + torch.lgamma(alpha_hat)) / alpha_hat \
            - math.log(self.beta)
        # Bound the range so a pathological draw cannot produce inf during backprop.
        log_v = torch.clamp(log_v, min=-60.0, max=60.0)
        return log_v

    def forward(self, h):
        """Return a simplex-valued sample y (B, K) and the concentrations used to draw it."""
        alpha_hat = self._alpha_hat(h)           # (B,K) stabilized α̂
        log_v = self._sample_multi_gamma_log(alpha_hat)
        y = F.softmax(log_v, dim=1)              # v / Σ v, computed stably in log-space
        return y, {"alpha_hat": alpha_hat}

    def kl(self, aux):
        """
        Per-sample KL(Q||P) between MultiGamma(α̂, β) and MultiGamma(α, β):
          Σ_k [ log Γ(α_k) − log Γ(α̂_k) + (α̂_k − α_k) ψ(α̂_k) ]
        """
        alpha_hat = aux["alpha_hat"]  # already stabilized
        alpha = self.alpha_prior.view(1, -1).expand_as(alpha_hat)
        term = torch.lgamma(alpha) - torch.lgamma(alpha_hat) + (alpha_hat - alpha) * torch.digamma(alpha_hat)
        return torch.sum(term, dim=1)

    def embed(self, aux):
        """Deterministic embedding for visualisation: the Dirichlet mean α̂ / Σ α̂."""
        ah = aux["alpha_hat"]
        return ah / (ah.sum(dim=1, keepdim=True) + self.eps)


class CCPlaceholderBottleneck(nn.Module):
    r"""
    Stand-in for a Continuous–Categorical bottleneck (deterministic, no sampling):
      • A linear head produces logits g(x); the latent is softmax(g/τ) on the simplex.
      • The "KL" is the regulariser 0.5 * ||g||^2.  TODO: replace with true CC sampler & KL.
    """
    def __init__(self, latent_dim: int, temperature: float = 0.5):
        super().__init__()
        self.temperature = temperature
        self.logits = nn.Linear(256, latent_dim)

    def forward(self, h):
        """Return the tempered-softmax latent and the raw logits it was computed from."""
        raw_logits = self.logits(h)
        scaled = raw_logits / self.temperature
        simplex_code = F.softmax(scaled, dim=1)
        return simplex_code, {"logits": raw_logits}

    def kl(self, aux):
        """Per-sample regulariser: half the squared L2 norm of the logits."""
        squared = aux["logits"].pow(2)
        return 0.5 * torch.sum(squared, dim=1)

    def embed(self, aux):
        """Deterministic embedding for t-SNE: the same tempered softmax as `forward`."""
        scaled = aux["logits"] / self.temperature
        return F.softmax(scaled, dim=1)

# -----------------------------
# VAE wrapper
# -----------------------------
class VAE(nn.Module):
    """Encoder → pluggable bottleneck → decoder.

    The bottleneck is any module whose forward returns `(z, aux)`; the encoder and
    decoder are the shared CNNs, built here from the channel widths.
    """
    def __init__(self, bottleneck: nn.Module, latent_dim: int, enc_ch: int = 32, dec_ch: int = 32):
        super().__init__()
        self.encoder = EncoderCNN(enc_ch)
        self.bottleneck = bottleneck
        self.decoder = DecoderCNN(latent_dim, dec_ch)

    def forward(self, x):
        """Return `(x_logits, z, aux)` for an image batch `x`."""
        hidden = self.encoder(x)                  # (B,256)
        latent, stats = self.bottleneck(hidden)   # (B,K) and bottleneck statistics
        recon_logits = self.decoder(latent)       # (B,1,28,28)
        return recon_logits, latent, stats

# -----------------------------
# Train / Eval
# -----------------------------
def bce_recon_loss(x_logits, x):
    """Binary cross-entropy between decoder logits and target images, summed over the batch.

    Raises:
        ValueError: if the logits' trailing spatial dimensions are not 28×28.
    """
    spatial = x_logits.shape[2:]
    if spatial != (28, 28):
        raise ValueError(f"Decoder produced {tuple(x_logits.shape)}; expected (B,1,28,28).")
    batch = x.size(0)
    flat_logits = x_logits.view(batch, -1)
    flat_target = x.view(batch, -1)
    return F.binary_cross_entropy_with_logits(flat_logits, flat_target, reduction="sum")

def run_epoch(model, loader, optimizer=None):
    """Run one pass over `loader` and return per-sample average losses.

    Args:
        model: module whose forward returns `(x_logits, z, aux)` and which exposes
            `model.bottleneck.kl(aux)` giving a per-sample KL.
        loader: iterable of `(x, label)` batches; labels are ignored.
        optimizer: when given, the model is put in train mode and updated after every
            batch; when None, the model is put in eval mode and only evaluated.

    Returns:
        Dict with keys "bce", "kl" and "total", each averaged over the samples seen.

    Raises:
        ValueError: if the loader yields no samples (the averages would be undefined).
    """
    train = optimizer is not None
    model.train(train)
    bce_total, kl_total, n = 0.0, 0.0, 0
    # Autograd is only needed when weights are updated; evaluation passes skip building
    # the graph, which saves memory and time on the validation/test sets.
    with torch.set_grad_enabled(train):
        for x, _ in loader:
            x = x.to(device)
            if train:
                optimizer.zero_grad()
            x_logits, _, aux = model(x)
            bce = bce_recon_loss(x_logits, x)                 # sum over batch
            kl = model.bottleneck.kl(aux).sum()               # sum over batch
            if train:
                loss = bce + kl
                loss.backward()
                optimizer.step()
            bce_total += float(bce.item())
            kl_total += float(kl.item())
            n += x.size(0)
    if n == 0:
        raise ValueError("run_epoch received an empty loader; cannot average losses.")
    return {"bce": bce_total / n, "kl": kl_total / n, "total": (bce_total + kl_total) / n}

@torch.no_grad()
def make_recon_grid(model, loader, path: str, n: int = 10):
    """Save an image grid of originals followed by their reconstructions.

    Takes the first `n` images of the loader's first batch, runs them through the model
    in eval mode, and writes the originals and the sigmoid of the reconstruction logits
    to `path` with `n` images per row.
    """
    model.eval()
    images, _labels = next(iter(loader))
    originals = images.to(device)[:n]
    recon_logits, _latent, _aux = model(originals)
    reconstructions = torch.sigmoid(recon_logits)
    stacked = torch.cat([originals, reconstructions], dim=0)
    vutils.save_image(vutils.make_grid(stacked, nrow=n, padding=2), path)

@torch.no_grad()
def collect_embeddings(model, loader, max_n: int = 5000):
    """Gather deterministic latent embeddings and labels for up to `max_n` samples.

    The embedding depends on the bottleneck type: posterior mean (Gaussian), normalised
    concentrations α̂ / Σ α̂ (Dirichlet), or tempered softmax of the logits (CC placeholder).
    No latent sampling and no decoding takes place.

    Returns:
        `(X, Y)` NumPy arrays of embeddings and labels, truncated to `max_n` rows.

    Raises:
        ValueError: if the model's bottleneck is none of the three known types.
    """
    model.eval()
    bottleneck = model.bottleneck

    def embed_hidden(h):
        # Compute bottleneck statistics directly from the encoder features.
        if isinstance(bottleneck, GaussianBottleneck):
            stats = {"mu": bottleneck.mu(h), "logvar": bottleneck.logvar(h)}
            return bottleneck.embed(stats)
        if isinstance(bottleneck, DirichletBottleneck):
            conc = bottleneck._alpha_hat(h)  # same stabilized α̂ as in training
            return conc / (conc.sum(dim=1, keepdim=True) + 1e-6)
        if isinstance(bottleneck, CCPlaceholderBottleneck):
            raw = bottleneck.logits(h)
            return F.softmax(raw / bottleneck.temperature, dim=1)
        raise ValueError("Unknown bottleneck.")

    emb_chunks, label_chunks = [], []
    seen = 0
    for images, labels in loader:
        images = images.to(device)
        emb = embed_hidden(model.encoder(images))
        emb_chunks.append(emb.detach().cpu())
        label_chunks.append(labels)
        seen += images.size(0)
        if seen >= max_n:
            break  # enough samples; the final batch is trimmed below
    X = torch.cat(emb_chunks, 0).numpy()[:max_n]
    Y = torch.cat(label_chunks, 0).numpy()[:max_n]
    return X, Y

def tsne_plot(X, Y, title: str, path: str):
    """Project embeddings `X` to 2-D with t-SNE and save a label-coloured scatter plot.

    Args:
        X: 2-D array of embeddings, one row per sample.
        Y: labels used to colour the points (tab10 colormap).
        title: figure title.
        path: output image path.
    """
    # scikit-learn requires perplexity < n_samples; cap it so small inputs still work.
    perplexity = min(30, max(1, len(X) - 1))
    # The iteration count is left at the library default (1000, the value previously
    # passed explicitly): the keyword was renamed from `n_iter` to `max_iter`, so naming
    # either one breaks on some scikit-learn versions.
    Z = TSNE(n_components=2, init="pca", perplexity=perplexity,
             learning_rate="auto", random_state=cfg.seed).fit_transform(X)
    plt.figure(figsize=(6,5))
    plt.scatter(Z[:,0], Z[:,1], c=Y, s=5, cmap="tab10", alpha=0.8)
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path, dpi=150)
    plt.close()

# -----------------------------
# Run ONE model (training + eval + save)
# -----------------------------
def run_single(model_name: str) -> Dict[str, List[float]]:
    """Train, evaluate and checkpoint one model, returning its loss history.

    Args:
        model_name: "gvae", "dirvae" or "cc".

    Returns:
        History dict with per-epoch train/val BCE, KL and total, plus single-element
        lists holding the final test metrics.

    Raises:
        ValueError: if `model_name` is not one of the three known names.

    Side effects: writes a reconstruction grid, a t-SNE plot and a checkpoint into
    `cfg.out_dir`, and prints progress.
    """
    # Factories are lazy so only the requested bottleneck is ever constructed.
    bottleneck_factories = {
        "gvae": lambda: GaussianBottleneck(cfg.latent_dim),
        "dirvae": lambda: DirichletBottleneck(cfg.latent_dim, beta_rate=cfg.beta_gamma_rate),
        "cc": lambda: CCPlaceholderBottleneck(cfg.latent_dim, temperature=0.5),
    }
    make_bottleneck = bottleneck_factories.get(model_name)
    if make_bottleneck is None:
        raise ValueError("Unknown model name.")
    bottleneck = make_bottleneck()

    def out_path(filename):
        return os.path.join(cfg.out_dir, filename)

    model = VAE(bottleneck, cfg.latent_dim, enc_ch=cfg.enc_ch, dec_ch=cfg.dec_ch).to(device)
    opt = optim.Adam(model.parameters(), lr=cfg.lr)
    history = {"train_bce": [], "train_kl": [], "train_total": [],
               "val_bce": [], "val_kl": [], "val_total": []}

    print(f"\n=== Training {model_name.upper()} ===")
    for epoch in range(1, cfg.epochs+1):
        tr = run_epoch(model, train_loader, optimizer=opt)
        va = run_epoch(model, val_loader, optimizer=None)
        for split, stats in (("train", tr), ("val", va)):
            for metric in ("bce", "kl", "total"):
                history[f"{split}_{metric}"].append(stats[metric])
        # Report every fifth epoch and always the last one.
        if epoch % 5 == 0 or epoch == cfg.epochs:
            print(f"Epoch {epoch:02d} | Train ELBO {tr['total']:.2f} (BCE {tr['bce']:.2f}, KL {tr['kl']:.2f})  "
                  f"| Val ELBO {va['total']:.2f} (BCE {va['bce']:.2f}, KL {va['kl']:.2f})")

    # Reconstruction grid from the validation set.
    make_recon_grid(model, val_loader, out_path(f"recon_{model_name}.png"), n=10)

    # t-SNE of test-set embeddings.
    X, Y = collect_embeddings(model, test_loader, max_n=cfg.tsne_samples)
    tsne_plot(X, Y, f"t-SNE ({model_name.upper()})", out_path(f"tsne_{model_name}.png"))

    # Final test metrics, stored as one-element lists alongside the per-epoch curves.
    te = run_epoch(model, test_loader, optimizer=None)
    for metric in ("bce", "kl", "total"):
        history[f"test_{metric}"] = [te[metric]]

    # Checkpoint: weights, the configuration used, and the full history.
    checkpoint = {"model_state": model.state_dict(),
                  "cfg": asdict(cfg),
                  "history": history}
    torch.save(checkpoint, out_path(f"ckpt_{model_name}.pt"))

    # Release GPU memory before the next model is trained.
    del model
    torch.cuda.empty_cache()
    return history

# -----------------------------
# Run models sequentially (one at a time)
# -----------------------------
# Train the three variants back to back; results are stored as each run finishes, so a
# failure in a later run leaves the earlier histories available.
histories = {}
for name in ("gvae", "dirvae", "cc"):
    model_history = run_single(name)
    histories[name] = model_history

# -----------------------------
# Final comparison plots & table
# -----------------------------
epochs = range(1, cfg.epochs + 1)
plt.figure(figsize=(12,4))

# One panel per metric: (train-history key, val-history key, panel title).
_panels = [
    ("train_bce",   "val_bce",   "Reconstruction (BCE) ↓"),
    ("train_kl",    "val_kl",    "KL term ↓"),
    ("train_total", "val_total", "ELBO (BCE+KL) ↓"),
]
_model_labels = [("gvae","GVAE"), ("dirvae","DirVAE"), ("cc","CC")]

for _col, (_train_key, _val_key, _panel_title) in enumerate(_panels, start=1):
    plt.subplot(1, 3, _col)
    # Solid line = training curve, dashed line = validation curve, for every model.
    for _model_key, _label in _model_labels:
        plt.plot(epochs, histories[_model_key][_train_key], label=f"{_label} Train")
        plt.plot(epochs, histories[_model_key][_val_key], linestyle="--", label=f"{_label} Val")
    plt.title(_panel_title)
    plt.xlabel("Epoch")
    plt.ylabel("Avg per sample")
    plt.legend(fontsize=8)

plt.tight_layout()
plt.savefig(os.path.join(cfg.out_dir, "loss_curves.png"), dpi=150)
plt.show()

# Tiny summary table (final test metrics)
def last(arr):
    """Return the final element of `arr` as a float."""
    return float(arr[-1])

# (history key, display label) per model, in table row order.
_summary_rows = [("gvae", "GVAE"), ("dirvae", "DirVAE"), ("cc", "CC")]
summary = {
    "Model":     [label for _, label in _summary_rows],
    "Test BCE":  [last(histories[key]["test_bce"]) for key, _ in _summary_rows],
    "Test KL":   [last(histories[key]["test_kl"]) for key, _ in _summary_rows],
    "Test ELBO": [last(histories[key]["test_total"]) for key, _ in _summary_rows],
}
df = pd.DataFrame(summary)
print("\nFinal test metrics (avg per sample):")
print(df.to_string(index=False))

# List the artefact paths each run is expected to have written (existence is not checked).
_saved_files = [
    "recon_gvae.png", "recon_dirvae.png", "recon_cc.png",
    "tsne_gvae.png", "tsne_dirvae.png", "tsne_cc.png", "loss_curves.png",
    "ckpt_gvae.pt", "ckpt_dirvae.pt", "ckpt_cc.pt",
]
print(f"\nSaved files in {cfg.out_dir}:")
for name in _saved_files:
    print(" -", os.path.join(cfg.out_dir, name))


100%|██████████| 9.91M/9.91M [00:00<00:00, 20.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 477kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.43MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.87MB/s]


Split sizes → train: 45000, val: 5000, test: 10000
Device: cuda

=== Training GVAE ===
Epoch 05 | Train ELBO 105.86 (BCE 87.24, KL 18.63)  | Val ELBO 106.22 (BCE 87.70, KL 18.52)
Epoch 10 | Train ELBO 101.29 (BCE 82.05, KL 19.24)  | Val ELBO 102.23 (BCE 83.08, KL 19.15)


