In [2]:
import os
os.makedirs("LikelihoodTrainer")

In [3]:
%%writefile LikelihoodTrainer/liklihood.py
"""
Likelihood Ledger — RealNVP Flow + Exact Log-Likelihood (PyTorch)
=================================================================
An end-to-end teaching project for **likelihood-based evaluation** of generative
models using an explicit-density **normalizing flow (RealNVP)**. We train on
MNIST (28×28 grayscale → 784-D) and report **exact log-likelihood** and
**bits-per-dimension (bpd)** across epochs. Rich visuals and artifacts included.

• Dataset: MNIST (default) — optionally Fashion-MNIST
• Model: RealNVP with affine coupling, alternating binary masks, random perms
• Preprocessing: uniform dequantization + logit transform (with λ-smoothing)
• Evaluation: exact log p(x), NLL, bits/dim; validation curves, histograms
• Visuals: loss & bpd curves, per-epoch sample grids, NLL histogram, t-SNE of latents
• Artifacts: CSV logs, animated GIF timeline of sample grids
• Epochs: configurable (5/10/20) for pedagogy

Run
----
python likelihood_flow.py \
  --dataset mnist \
  --data_root ./data \
  --outdir ./outputs_ll \
  --epochs 10 \
  --batch_size 128 \
  --layers 8 \
  --hidden 512 \
  --lambda_logit 0.05

Optional flags: --dataset {mnist,fashion}, --epochs {5,10,20}, --lr 1e-3, --device cuda

Notes
-----
• **Exact likelihood** comes from change of variables: log p(x) = log p(z) + log|det ∂z/∂x|.
• RealNVP’s triangular Jacobian makes log-determinant **tractable**.
• The **logit** preprocessing stabilizes flow training on image pixels in [0,1].
"""

import os
import math
import argparse
import random
from dataclasses import dataclass
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
import matplotlib.pyplot as plt
import imageio

try:
    from sklearn.manifold import TSNE
    _HAVE_SK = True
except Exception:
    _HAVE_SK = False

# ------------------------------
# Reproducibility
# ------------------------------

def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ------------------------------
# Data & Transforms (Dequantize + Logit)
# ------------------------------

class LogitTransform(nn.Module):
    """Applies x ∈ [0,1]  →  x' = λ + (1-2λ)x  →  y = logit(x'), with
    log|det J| = Σ[ log(1-2λ) - log(x') - log(1-x') ]."""
    def __init__(self, lam: float = 0.05, eps: float = 1e-6):
        super().__init__()
        self.lam = lam
        self.eps = eps
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x in [0,1]
        x_ = self.lam + (1 - 2 * self.lam) * x
        x_ = x_.clamp(self.eps, 1 - self.eps)
        y = torch.log(x_) - torch.log(1 - x_)
        logdet = torch.sum(
            torch.log(torch.tensor(1 - 2 * self.lam, device=x.device))
            - torch.log(x_) - torch.log(1 - x_), dim=list(range(1, x.dim()))
        )
        return y, logdet
    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        x_ = torch.sigmoid(y)
        x = (x_ - self.lam) / (1 - 2 * self.lam)
        return x.clamp(0.0, 1.0)


def get_dataloaders(dataset: str, data_root: str, batch_size: int, workers: int):
    tf = transforms.ToTensor()  # yields [0,1]
    if dataset == 'mnist':
        train_set = datasets.MNIST(root=data_root, train=True, download=True, transform=tf)
        test_set  = datasets.MNIST(root=data_root, train=False, download=True, transform=tf)
        channels = 1; H = W = 28
    elif dataset == 'fashion':
        train_set = datasets.FashionMNIST(root=data_root, train=True, download=True, transform=tf)
        test_set  = datasets.FashionMNIST(root=data_root, train=False, download=True, transform=tf)
        channels = 1; H = W = 28
    else:
        raise ValueError('dataset must be mnist or fashion')
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=workers, drop_last=True)
    test_loader  = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=workers)
    return train_loader, test_loader, channels, H, W

# ------------------------------
# RealNVP Building Blocks
# ------------------------------

class Permute(nn.Module):
    """Fixed random permutation layer for 1D vectors."""
    def __init__(self, D: int, seed: int = 0):
        super().__init__()
        rng = np.random.default_rng(seed)
        perm = torch.tensor(rng.permutation(D), dtype=torch.long)
        inv = torch.empty_like(perm)
        inv[perm] = torch.arange(D)
        self.register_buffer('perm', perm)
        self.register_buffer('inv', inv)
    def forward(self, x):
        return x[:, self.perm]
    def inverse(self, y):
        return y[:, self.inv]

class CouplingMLP(nn.Module):
    def __init__(self, D_in: int, D_out: int, hidden: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(D_in, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, D_out)
        )
    def forward(self, x):
        return self.net(x)

class AffineCoupling(nn.Module):
    """RealNVP affine coupling layer with a binary mask m ∈ {0,1}^D.
    y_A = x_A
    y_B = x_B * exp(s(x_A)) + t(x_A)
    log|det J| = Σ s(x_A) over B dims.
    """
    def __init__(self, mask: torch.Tensor, hidden: int):
        super().__init__()
        self.register_buffer('m', mask.float())
        D = mask.numel()
        D_A = int(self.m.sum().item())
        D_B = D - D_A
        self.s_net = CouplingMLP(D_A, D_B, hidden)
        self.t_net = CouplingMLP(D_A, D_B, hidden)
        self.scale = nn.Parameter(torch.zeros(1))  # global scale for stability
    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        x_A = x * self.m                       # keep A
        x_A_flat = x_A[:, self.m.bool()]        # (N, D_A)
        s = self.s_net(x_A_flat)
        t = self.t_net(x_A_flat)
        s = torch.tanh(s) * torch.exp(self.scale)  # stabilize
        # apply to B
        x_B = x[:, (~self.m.bool())]
        y_B = x_B * torch.exp(s) + t
        y = torch.empty_like(x)
        y[:, self.m.bool()] = x_A_flat
        y[:, (~self.m.bool())] = y_B
        logdet = torch.sum(s, dim=1)
        return y, logdet
    def inverse(self, y: torch.Tensor) -> torch.Tensor:
        y_A = y[:, self.m.bool()]
        s = self.s_net(y_A)
        t = self.t_net(y_A)
        s = torch.tanh(s) * torch.exp(self.scale)
        y_B = y[:, (~self.m.bool())]
        x_B = (y_B - t) * torch.exp(-s)
        x = torch.empty_like(y)
        x[:, self.m.bool()] = y_A
        x[:, (~self.m.bool())] = x_B
        return x

class RealNVP(nn.Module):
    def __init__(self, D: int, layers: int = 8, hidden: int = 512, seed: int = 0):
        super().__init__()
        ms = []
        perms = []
        rng = np.random.default_rng(seed)
        # alternating 0/1 masks with random perms in between
        base_mask = torch.zeros(D)
        base_mask[::2] = 1.0  # start with even dims as A
        for k in range(layers):
            # permute dims to increase mixing
            perm = Permute(D, seed=int(rng.integers(0, 10_000)))
            perms.append(perm)
            # flip mask pattern each layer
            mask = base_mask.clone() if (k % 2 == 0) else (1 - base_mask)
            ms.append(AffineCoupling(mask, hidden))
        self.perms = nn.ModuleList(perms)
        self.couplings = nn.ModuleList(ms)
        self.D = D
    def f(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # x -> z, accumulate logdet
        logdet_total = torch.zeros(x.size(0), device=x.device)
        h = x
        for perm, coup in zip(self.perms, self.couplings):
            h = perm(h)
            h, logdet = coup(h)
            logdet_total = logdet_total + logdet
        return h, logdet_total
    def inv(self, z: torch.Tensor) -> torch.Tensor:
        h = z
        for perm, coup in reversed(list(zip(self.perms, self.couplings))):
            h = coup.inverse(h)
            h = perm.inverse(h)
        return h

# ------------------------------
# Utilities: sampling, logging, plots
# ------------------------------

def save_image_grid(tensor: torch.Tensor, path: str, nrow: int = 8):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    grid = vutils.make_grid(tensor, nrow=nrow, padding=2)
    vutils.save_image(grid, path)

def save_csv(path: str, header: List[str], rows: List[List]):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, 'w') as f:
        f.write(','.join(header) + '\n')
        for r in rows:
            f.write(','.join(map(str, r)) + '\n')

def make_gif_from_folder(folder: str, out_path: str, fps: int = 4):
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    frames = []
    if not os.path.isdir(folder):
        return
    for name in sorted(os.listdir(folder)):
        if name.lower().endswith('.png'):
            frames.append(imageio.v2.imread(os.path.join(folder, name)))
    if frames:
        imageio.mimsave(out_path, frames, fps=fps)

def plot_curves(curves: dict, path: str, title: str, ylabel: str):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    plt.figure(figsize=(7, 4))
    for k, v in curves.items():
        xs = list(range(1, len(v)+1))
        plt.plot(xs, v, label=k)
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.legend()
    plt.tight_layout()
    plt.savefig(path)
    plt.close()

# ------------------------------
# Training & Evaluation
# ------------------------------

@dataclass
class TrainConfig:
    dataset: str = 'mnist'            # mnist or fashion
    data_root: str = './data'
    outdir: str = './outputs_ll'
    batch_size: int = 128
    workers: int = 2
    epochs: int = 10
    lr: float = 1e-3
    layers: int = 8
    hidden: int = 512
    lambda_logit: float = 0.05
    seed: int = 42
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

class LikelihoodTrainer:
    def __init__(self, cfg: TrainConfig):
        self.cfg = cfg
        self.device = torch.device(cfg.device)
        set_seed(cfg.seed)
        os.makedirs(cfg.outdir, exist_ok=True)

        # Data
        self.train_loader, self.test_loader, C, H, W = get_dataloaders(cfg.dataset, cfg.data_root, cfg.batch_size, cfg.workers)
        self.D = C * H * W
        self.logit = LogitTransform(cfg.lambda_logit)

        # Model
        self.flow = RealNVP(D=self.D, layers=cfg.layers, hidden=cfg.hidden, seed=cfg.seed).to(self.device)
        self.opt = optim.Adam(self.flow.parameters(), lr=cfg.lr)

        # Logs
        self.hist_train_nll: List[float] = []
        self.hist_val_nll: List[float] = []
        self.hist_val_bpd: List[float] = []

        # Fixed noise for generation from base
        self.fixed_z = torch.randn(64, self.D, device=self.device)
        self.C, self.H, self.W = C, H, W

    def _preprocess_batch(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Dequantize, logit, flatten. Returns (y, logdet_logit)."""
        # x in [0,1]; uniform dequantization
        u = torch.rand_like(x)
        x_deq = (x * 255.0 + u) / 256.0
        y, logdet = self.logit(x_deq)
        y = y.view(y.size(0), -1)  # flatten to (N,D)
        return y, logdet

    def _nll_batch(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute per-sample negative log-likelihood and bits/dim."""
        y, logdet_logit = self._preprocess_batch(x)
        z, logdet_flow = self.flow.f(y)
        # Standard normal log-prob
        log_pz = -0.5 * (z**2 + math.log(2 * math.pi)).sum(dim=1)
        log_px = log_pz + logdet_flow + logdet_logit
        nll = -log_px  # per-sample
        bpd = nll / (self.D * math.log(2))
        return nll, bpd

    @torch.no_grad()
    def evaluate(self):
        self.flow.eval()
        nlls, bpds = [], []
        for x, _ in self.test_loader:
            x = x.to(self.device)
            nll, bpd = self._nll_batch(x)
            nlls.append(nll.cpu()); bpds.append(bpd.cpu())
        nlls = torch.cat(nlls)
        bpds = torch.cat(bpds)
        return float(nlls.mean().item()), float(bpds.mean().item()), nlls, bpds

    @torch.no_grad()
    def sample(self, n: int = 64) -> torch.Tensor:
        self.flow.eval()
        z = torch.randn(n, self.D, device=self.device)
        y = self.flow.inv(z)
        x = self.logit.inverse(y.view(n, self.C, self.H, self.W))
        return x

    def train_epoch(self, epoch: int):
        self.flow.train()
        running = 0.0
        for i, (x, _) in enumerate(self.train_loader):
            x = x.to(self.device)
            nll, _ = self._nll_batch(x)
            loss = nll.mean()
            self.opt.zero_grad(); loss.backward(); self.opt.step()
            running += loss.item()
            if (i+1) % 100 == 0:
                print(f"Epoch {epoch:03d} [{i+1:04d}/{len(self.train_loader)}]  NLL: {loss.item():.2f}")
        self.hist_train_nll.append(running / len(self.train_loader))

        # Save sample grid for the epoch
        with torch.no_grad():
            sample = self.sample(64)
            save_image_grid(sample, os.path.join(self.cfg.outdir, f"samples/epoch_{epoch:03d}.png"), nrow=8)

    def post_plots(self):
        # Curves
        plot_curves({"Train NLL": self.hist_train_nll, "Val NLL": self.hist_val_nll},
                    os.path.join(self.cfg.outdir, "curves/nll_curve.png"),
                    title="Negative Log-Likelihood over Epochs", ylabel="NLL (lower is better)")
        plot_curves({"Val bpd": self.hist_val_bpd},
                    os.path.join(self.cfg.outdir, "curves/bpd_curve.png"),
                    title="Bits per Dimension over Epochs", ylabel="bits/dim (lower is better)")
        # Animated GIF of samples
        make_gif_from_folder(os.path.join(self.cfg.outdir, "samples"),
                             os.path.join(self.cfg.outdir, "samples/timeline.gif"), fps=3)

    def run(self):
        for epoch in range(1, self.cfg.epochs + 1):
            self.train_epoch(epoch)
            val_nll, val_bpd, nlls, bpds = self.evaluate()
            self.hist_val_nll.append(val_nll); self.hist_val_bpd.append(val_bpd)
            print(f"[VAL] Epoch {epoch:03d}  NLL = {val_nll:.2f}  |  bpd = {val_bpd:.4f}")

            # NLL histogram
            os.makedirs(os.path.join(self.cfg.outdir, 'curves'), exist_ok=True)
            plt.figure(figsize=(6,4))
            plt.hist(nlls.numpy(), bins=40, alpha=0.8)
            plt.title(f"Per-sample NLL (epoch {epoch})")
            plt.xlabel("NLL"); plt.ylabel("Count"); plt.tight_layout()
            plt.savefig(os.path.join(self.cfg.outdir, f"curves/nll_hist_epoch_{epoch:03d}.png")); plt.close()

            # (Optional) t-SNE of latents for a single batch
            if _HAVE_SK:
                x_vis, _ = next(iter(self.test_loader))
                x_vis = x_vis.to(self.device)
                y_vis, _ = self._preprocess_batch(x_vis)
                z_vis, _ = self.flow.f(y_vis)
                emb = TSNE(n_components=2, init='random', random_state=42, perplexity=30).fit_transform(z_vis.detach().cpu().numpy())
                plt.figure(figsize=(5,5))
                plt.scatter(emb[:,0], emb[:,1], s=6, alpha=0.6)
                plt.title(f"t-SNE of Latent z (epoch {epoch})")
                plt.tight_layout(); plt.savefig(os.path.join(self.cfg.outdir, f"curves/tsne_latent_{epoch:03d}.png")); plt.close()

        # Save logs
        save_csv(os.path.join(self.cfg.outdir, 'logs/nll.csv'), ['epoch','train_nll','val_nll'],
                 [[i+1, self.hist_train_nll[i], self.hist_val_nll[i]] for i in range(len(self.hist_val_nll))])
        save_csv(os.path.join(self.cfg.outdir, 'logs/bpd.csv'), ['epoch','val_bpd'],
                 [[i+1, self.hist_val_bpd[i]] for i in range(len(self.hist_val_bpd))])

        self.post_plots()

        print("\nArtifacts saved:")
        print(" - Sample grids:", os.path.join(self.cfg.outdir, "samples/epoch_###.png"))
        print(" - Sample animation:", os.path.join(self.cfg.outdir, "samples/timeline.gif"))
        print(" - Curves:", os.path.join(self.cfg.outdir, "curves/"))
        print(" - Logs (CSV):", os.path.join(self.cfg.outdir, "logs/"))
        print("Training complete. Outputs saved to:", self.cfg.outdir)

# ------------------------------
# Main
# ------------------------------

def parse_args() -> 'TrainConfig':
    p = argparse.ArgumentParser()
    p.add_argument('--dataset', type=str, default='mnist')
    p.add_argument('--data_root', type=str, default='./data')
    p.add_argument('--outdir', type=str, default='./outputs_ll')
    p.add_argument('--batch_size', type=int, default=128)
    p.add_argument('--workers', type=int, default=2)
    p.add_argument('--epochs', type=int, default=10)
    p.add_argument('--lr', type=float, default=1e-3)
    p.add_argument('--layers', type=int, default=8)
    p.add_argument('--hidden', type=int, default=512)
    p.add_argument('--lambda_logit', type=float, default=0.05)
    p.add_argument('--seed', type=int, default=42)
    p.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    args = p.parse_args(args=None if __name__ == '__main__' else [])
    return TrainConfig(**vars(args))


def main():
    cfg = parse_args()
    print("\nLikelihood Ledger — RealNVP Flow + Exact Log-Likelihood (PyTorch)")
    print("Config:", cfg)
    trainer = LikelihoodTrainer(cfg)
    trainer.run()

if __name__ == '__main__':
    main()


Writing LikelihoodTrainer/liklihood.py


In [5]:
!python LikelihoodTrainer/liklihood.py \
  --dataset mnist \
  --epochs 10 \
  --layers 8 \
  --hidden 512 \
  --lambda_logit 0.05 \
  --outdir ./outputs_ll



Likelihood Ledger — RealNVP Flow + Exact Log-Likelihood (PyTorch)
Config: TrainConfig(dataset='mnist', data_root='./data', outdir='./outputs_ll', batch_size=128, workers=2, epochs=10, lr=0.001, layers=8, hidden=512, lambda_logit=0.05, seed=42, device='cuda')
100% 9.91M/9.91M [00:01<00:00, 5.01MB/s]
100% 28.9k/28.9k [00:00<00:00, 132kB/s]
100% 1.65M/1.65M [00:01<00:00, 1.27MB/s]
100% 4.54k/4.54k [00:00<00:00, 16.1MB/s]
Epoch 001 [0100/468]  NLL: -1462.94
Epoch 001 [0200/468]  NLL: -2013.84
Epoch 001 [0300/468]  NLL: -1608.11
Epoch 001 [0400/468]  NLL: -2325.74
[VAL] Epoch 001  NLL = -2435.83  |  bpd = -4.4823
Epoch 002 [0100/468]  NLL: -2428.37
Epoch 002 [0200/468]  NLL: -2571.56
Epoch 002 [0300/468]  NLL: -2573.00
Epoch 002 [0400/468]  NLL: -2098.82
[VAL] Epoch 002  NLL = -2482.48  |  bpd = -4.5682
Epoch 003 [0100/468]  NLL: -2641.11
Epoch 003 [0200/468]  NLL: -2665.64
Epoch 003 [0300/468]  NLL: -2712.14
Epoch 003 [0400/468]  NLL: -2682.07
[VAL] Epoch 003  NLL = -2705.81  |  bpd = -4.