<a href="https://colab.research.google.com/github/AnanyaTyagi/VAE-GAN-Diffusion-Benchmark/blob/main/VAE_Image_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# New setup

In [None]:
# ┌─────────────────────────────────────────────────────────────┐
# │ EVALUATION APPROACH                                         │
# │ Challenge: Generated images won’t match specific data items. │
# │ Solution: Use FID & IS → compare DISTRIBUTIONS, not pixels.  │
# │ Standard in generative research (DALL·E, SD, StyleGAN).      │
# └─────────────────────────────────────────────────────────────┘

# --- Colab / Environment ---
# --- Check GPU and mount Google Drive ---
!nvidia-smi -L || true

# Colab-only helper used to mount Google Drive below.
from google.colab import drive
try:
    # Preferred output location: a folder on the mounted Drive.
    drive.mount('/content/drive', force_remount=True)
    OUT_DIR = "/content/drive/MyDrive/vae_cifar10_runs"   # change if you prefer
    DRIVE_OK = True
    print("✅ Drive mounted, results will be saved to:", OUT_DIR)
except Exception as e:
    # Best-effort fallback: any mount failure is reported and outputs go to a
    # local directory instead, so the rest of the notebook can still run.
    print("⚠️ Drive mount failed:", e)
    OUT_DIR = "/content/vae_cifar10_runs"
    DRIVE_OK = False
    print("Saving locally to:", OUT_DIR)

import os
# Create the chosen output directory up front; OUT_DIR also becomes CFG.out_dir's default.
os.makedirs(OUT_DIR, exist_ok=True)


import os, json, csv, time, math, random
import numpy as np
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
import torchvision.utils as vutils


# --- Reproducibility ---
def set_seed(seed=42):
    """Seed Python's `random`, NumPy, and PyTorch (CPU plus all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same call order as a straight-line sequence of the four seeding calls.
    for seed_fn in seeders:
        seed_fn(seed)

# Seed all RNGs once, before any model or data loader is created.
set_seed(42)
# String device tag; later cells compare it against "cuda" to decide whether AMP is enabled.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Let cuDNN autotune convolution algorithms.
# NOTE(review): autotuning may select non-deterministic kernels, so runs may not be
# bit-reproducible despite the fixed seed — confirm whether that matters here.
torch.backends.cudnn.benchmark = True

# --- Helpers ---
def denorm(x):
    """Map values from [-1, 1] to [0, 1]; inputs outside [-1, 1] are clamped first."""
    clipped = x.clamp(-1, 1)
    return (clipped + 1) / 2

def save_grid_pair(orig, recon, path, nrow=8):
    """
    Save one image file containing a grid of `orig` on top of a grid of `recon`.

    Both batches are passed through `denorm` before gridding; `nrow` is
    forwarded to `vutils.make_grid` for each grid.
    """
    grids = [vutils.make_grid(denorm(batch), nrow=nrow) for batch in (orig, recon)]
    # Concatenate along dim=1 so the second grid sits below the first.
    stacked = torch.cat(grids, dim=1)
    vutils.save_image(stacked, path)

def beta_schedule(epoch, warmup_epochs=10, beta_start=0.0, beta_end=1.0):
    """
    Linear KL-weight warm-up.

    Returns `beta_end` once `epoch` exceeds `warmup_epochs`; otherwise
    interpolates linearly from `beta_start` (at epoch 0) to `beta_end`
    (at epoch == warmup_epochs). A warm-up length below 1 is treated as 1
    for the division.
    """
    if epoch > warmup_epochs:
        return beta_end
    progress = epoch / max(1, warmup_epochs)
    return beta_start + progress * (beta_end - beta_start)

# --- Config ---
@dataclass
class CFG:
    """Run configuration for the CIFAR-10 VAE notebook.

    The instance's `__dict__` is dumped to config.json and stored in every
    checkpoint, so field values should stay JSON-serializable.
    """
    data_root: str = "./data"          # root directory handed to datasets.CIFAR10
    out_dir: str = OUT_DIR             # where logs, image grids and checkpoints are written
    batch_size: int = 256              # used for both the train and test DataLoader
    epochs: int = 100                  # number of training epochs (loop runs 1..epochs)
    lr: float = 1e-3                   # Adam learning rate
    z_dim: int = 256                   # latent dimensionality passed to VAE(z_dim=...)
    num_workers: int = 2               # DataLoader worker processes
    log_interval: int = 100            # print a progress line every N training batches
    beta_warmup_epochs: int = 10       # warm-up length passed to beta_schedule
    beta_start: float = 0.0            # KL weight at the start of warm-up
    beta_end: float = 1.0              # KL weight after warm-up
    grid_n: int = 64                   # images in the per-epoch recon / sample grids
    samples_export_total: int = 10_000 # total PNG samples written by the export cell
    samples_export_bs: int = 100       # samples generated per decoder call during export
    smoke_subset: int = 0   # set to 1000 for a quick smoke test
    use_amp: bool = True               # mixed precision; only takes effect when device == "cuda"

cfg = CFG()
print("Saving outputs to:", cfg.out_dir)

# Persist the configuration next to the run artifacts.
config_path = os.path.join(cfg.out_dir, "config.json")
with open(config_path, "w") as f:
    json.dump(vars(cfg), f, indent=2)

# Report the runtime environment.
cuda_ok = torch.cuda.is_available()
print("PyTorch:", torch.__version__, "| CUDA:", cuda_ok)
if cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))


GPU 0: Tesla T4 (UUID: GPU-d5ef296a-956d-c984-8619-772b012f77d8)
Mounted at /content/drive
✅ Drive mounted, results will be saved to: /content/drive/MyDrive/vae_cifar10_runs
Saving outputs to: /content/drive/MyDrive/vae_cifar10_runs
PyTorch: 2.9.0+cu126 | CUDA: True
GPU: Tesla T4


In [None]:
# --- Reparameterization ---
def reparameterize(mu, logvar):
    """Return mu + sigma * eps, with sigma = exp(0.5 * logvar) and eps drawn from a standard normal."""
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + sigma * noise

# --- Encoder ---
class Encoder(nn.Module):
    """
    Convolutional encoder.

    (B,3,32,32) → four stride-2 Conv2d(k=4,p=1)+BatchNorm+LeakyReLU stages → (B,256,2,2)
    → flatten to (B,1024) → two linear heads returning (mu, logvar), each (B, z_dim).
    """
    def __init__(self, z_dim=128):
        super().__init__()
        widths = (3, 32, 64, 128, 256)
        stages = []
        # Each stage halves the spatial size: 32 → 16 → 8 → 4 → 2.
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        self.net = nn.Sequential(*stages)
        self.flatten = nn.Flatten()
        feat_dim = 256 * 2 * 2
        self.fc_mu = nn.Linear(feat_dim, z_dim)
        self.fc_logvar = nn.Linear(feat_dim, z_dim)

    def forward(self, x):
        feats = self.flatten(self.net(x))
        return self.fc_mu(feats), self.fc_logvar(feats)

# --- Decoder ---
class Decoder(nn.Module):
    """
    Transposed-convolution decoder.

    z (B,z_dim) → Linear to 1024 → view (B,256,2,2) → four stride-2
    ConvTranspose2d(k=4,p=1) stages → Tanh → (B,3,32,32) in [-1,1].
    """
    def __init__(self, z_dim=128):
        super().__init__()
        self.fc = nn.Linear(z_dim, 256*2*2)
        layers = []
        # Three upsampling stages with BatchNorm + ReLU ...
        for c_in, c_out in ((256, 128), (128, 64), (64, 32)):
            layers.extend([
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        # ... followed by the final projection to 3 channels and Tanh.
        layers.append(nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1, bias=False))
        layers.append(nn.Tanh())
        self.net = nn.Sequential(*layers)

    def forward(self, z):
        flat = self.fc(z)
        feature_map = flat.view(flat.size(0), 256, 2, 2)
        return self.net(feature_map)

# --- VAE wrapper ---
class VAE(nn.Module):
    """Encoder/decoder pair; `forward` returns (reconstruction, mu, logvar, z)."""
    def __init__(self, z_dim=128):
        super().__init__()
        self.enc = Encoder(z_dim)
        self.dec = Decoder(z_dim)

    def forward(self, x):
        mu, logvar = self.enc(x)
        latent = reparameterize(mu, logvar)
        return self.dec(latent), mu, logvar, latent

# --- Loss ---
def vae_loss(x, xr, mu, logvar, recon_type="mse"):
    """
    Compute the two VAE loss terms separately so the caller can weight the KL term.

    Args:
        x: target images.
        xr: reconstructions, same shape as `x`.
        mu, logvar: encoder outputs.
        recon_type: "mse" (default) or "bce" (case-insensitive). For "bce"
            both tensors are mapped to [0, 1] with `denorm` first.

    Returns:
        (recon, kl): scalar tensors. Both use a mean reduction over all
        elements (kl is averaged over batch and latent dimensions).

    Raises:
        ValueError: if `recon_type` is neither "mse" nor "bce". Previously any
            other value (including typos) silently selected BCE.
    """
    kind = str(recon_type).lower()
    if kind == "mse":
        recon = F.mse_loss(xr, x, reduction="mean")
    elif kind == "bce":  # BCE expects [0,1]
        recon = F.binary_cross_entropy(denorm(xr), denorm(x), reduction="mean")
    else:
        raise ValueError(f"recon_type must be 'mse' or 'bce', got {recon_type!r}")
    kl = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon, kl

# --- DataLoaders (CIFAR-10) ---
def get_loaders(cfg):
    """
    Build CIFAR-10 train/test DataLoaders with images normalized to [-1, 1].

    Reads `data_root`, `batch_size`, `num_workers` and `smoke_subset` from
    `cfg`. When `smoke_subset` is positive, training uses the first
    `smoke_subset` images and testing the first `smoke_subset // 5`
    (capped at the test-set size). Returns (trainloader, testloader);
    only the train loader shuffles.
    """
    normalize = transforms.Normalize((0.5,) * 3, (0.5,) * 3)  # [-1,1]
    pipeline = transforms.Compose([transforms.ToTensor(), normalize])

    train_set = datasets.CIFAR10(cfg.data_root, train=True, download=True, transform=pipeline)
    test_set = datasets.CIFAR10(cfg.data_root, train=False, download=True, transform=pipeline)

    n_smoke = cfg.smoke_subset
    if n_smoke and n_smoke > 0:
        train_set = Subset(train_set, list(range(n_smoke)))
        n_test = min(len(test_set), n_smoke // 5)
        test_set = Subset(test_set, list(range(n_test)))

    shared = dict(batch_size=cfg.batch_size, num_workers=cfg.num_workers, pin_memory=True)
    train_dl = DataLoader(train_set, shuffle=True, **shared)
    test_dl = DataLoader(test_set, shuffle=False, **shared)
    return train_dl, test_dl

# Build the loaders once; the training and plotting cells reuse `trainloader` / `testloader`.
trainloader, testloader = get_loaders(cfg)
print("Train batches:", len(trainloader), "| Test batches:", len(testloader))


100%|██████████| 170M/170M [00:16<00:00, 10.1MB/s]


Train batches: 196 | Test batches: 40


In [None]:
# Train the VAE and save artifacts to Drive
vae = VAE(z_dim=cfg.z_dim).to(device)
opt = torch.optim.Adam(vae.parameters(), lr=cfg.lr)

# Mixed precision is used only when requested AND running on CUDA; otherwise
# the scaler and autocast context below are disabled pass-throughs.
amp_enabled = bool(cfg.use_amp and device == "cuda")
# torch.amp.* replaces the deprecated torch.cuda.amp.GradScaler / autocast entry points.
scaler = torch.amp.GradScaler(device, enabled=amp_enabled)

# fixed test batch for recon snapshots (same images every epoch; testloader does not shuffle)
fixed_imgs, _ = next(iter(testloader))
fixed_imgs = fixed_imgs[:cfg.grid_n].to(device)

csv_path = os.path.join(cfg.out_dir, "train_log.csv")
# The header is written only if the log does not exist yet; re-running this
# cell appends new rows to the existing file.
if not os.path.exists(csv_path):
    with open(csv_path, "w", newline="") as f:
        csv.writer(f).writerow(["epoch","beta","avg_total","avg_recon","avg_kl","time_sec"])

for epoch in range(1, cfg.epochs + 1):
    t0 = time.time()
    vae.train()
    ep_total = ep_recon = ep_kl = 0.0
    # KL weight for this epoch (linear warm-up, then constant).
    beta = beta_schedule(epoch, cfg.beta_warmup_epochs, cfg.beta_start, cfg.beta_end)

    for i, (x, _) in enumerate(trainloader, start=1):
        x = x.to(device)
        opt.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=device, enabled=amp_enabled):
            xr, mu, logvar, _ = vae(x)
            recon, kl = vae_loss(x, xr, mu, logvar, recon_type="mse")
            total = recon + beta * kl
        # Scale the loss before backward; step()/update() unscale and adjust the scale factor.
        scaler.scale(total).backward()
        scaler.step(opt)
        scaler.update()

        ep_total += total.item(); ep_recon += recon.item(); ep_kl += kl.item()
        if i % cfg.log_interval == 0:
            print(f"[E{epoch:03d} {i:04d}/{len(trainloader)}] β={beta:.3f} total={total.item():.4f} recon={recon.item():.4f} kl={kl.item():.4f}")

    # epoch summary (per-batch averages)
    n = len(trainloader); dt = time.time() - t0
    avg_total, avg_recon, avg_kl = ep_total/n, ep_recon/n, ep_kl/n
    print(f"==> Epoch {epoch} | β={beta:.3f} | {dt:.1f}s | total={avg_total:.4f} recon={avg_recon:.4f} kl={avg_kl:.4f}")
    with open(csv_path, "a", newline="") as f:
        csv.writer(f).writerow([epoch, beta, avg_total, avg_recon, avg_kl, round(dt,2)])

    # save recon grid (originals top, recon bottom)
    vae.eval()
    with torch.no_grad():
        xr, _, _, _ = vae(fixed_imgs)
    save_grid_pair(fixed_imgs, xr, os.path.join(cfg.out_dir, f"recon_epoch_{epoch:03d}.png"), nrow=8)

    # save random samples grid (decoder applied to standard-normal latents)
    with torch.no_grad():
        z = torch.randn(cfg.grid_n, cfg.z_dim, device=device)
        xs = vae.dec(z)
    vutils.save_image(vutils.make_grid(denorm(xs), nrow=8),
                      os.path.join(cfg.out_dir, f"samples_epoch_{epoch:03d}.png"))

    # checkpoint (one file per epoch; includes the config so the model can be rebuilt later)
    ckpt = {"epoch": epoch, "model": vae.state_dict(), "optimizer": opt.state_dict(),
            "cfg": cfg.__dict__}
    torch.save(ckpt, os.path.join(cfg.out_dir, f"vae_epoch_{epoch:03d}.pth"))

print("Training complete. Outputs in:", cfg.out_dir)


  scaler = torch.cuda.amp.GradScaler(enabled=(cfg.use_amp and device=="cuda"))
  with torch.cuda.amp.autocast(enabled=(cfg.use_amp and device=="cuda")):


[E001 0100/196] β=0.100 total=0.1191 recon=0.1048 kl=0.1428
==> Epoch 1 | β=0.100 | 14.5s | total=0.1346 recon=0.1206 kl=0.1406
[E002 0100/196] β=0.200 total=0.0994 recon=0.0784 kl=0.1049
==> Epoch 2 | β=0.200 | 13.8s | total=0.1024 recon=0.0810 kl=0.1070
[E003 0100/196] β=0.300 total=0.1074 recon=0.0812 kl=0.0872
==> Epoch 3 | β=0.300 | 12.2s | total=0.1037 recon=0.0776 kl=0.0867
[E004 0100/196] β=0.400 total=0.1093 recon=0.0791 kl=0.0755
==> Epoch 4 | β=0.400 | 11.3s | total=0.1096 recon=0.0804 kl=0.0731
[E005 0100/196] β=0.500 total=0.1149 recon=0.0842 kl=0.0613
==> Epoch 5 | β=0.500 | 11.4s | total=0.1149 recon=0.0836 kl=0.0627
[E006 0100/196] β=0.600 total=0.1177 recon=0.0851 kl=0.0543
==> Epoch 6 | β=0.600 | 12.1s | total=0.1200 recon=0.0870 kl=0.0549
[E007 0100/196] β=0.700 total=0.1242 recon=0.0894 kl=0.0497
==> Epoch 7 | β=0.700 | 12.0s | total=0.1248 recon=0.0907 kl=0.0488
[E008 0100/196] β=0.800 total=0.1300 recon=0.0949 kl=0.0439
==> Epoch 8 | β=0.800 | 12.1s | total=0.1288

In [None]:
# Export individual PNG samples (one file per generated image).
export_dir = os.path.join(cfg.out_dir, "samples_10k")
os.makedirs(export_dir, exist_ok=True)
vae.eval()

n_target = cfg.samples_export_total
saved = 0
with torch.no_grad():
    while saved < n_target:
        # The last batch may be smaller than samples_export_bs.
        cur = min(cfg.samples_export_bs, n_target - saved)
        batch = denorm(vae.dec(torch.randn(cur, cfg.z_dim, device=device)))
        for offset, img in enumerate(batch):
            vutils.save_image(img, os.path.join(export_dir, f"sample_{saved+offset:05d}.png"))
        saved += cur

print(f"Saved {saved} samples to {export_dir}")


Saved 10000 samples to /content/drive/MyDrive/vae_cifar10_runs/samples_10k


In [None]:
# Vary one latent dimension while keeping others fixed
dim = 0  # change the latent dimension to traverse

tfm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
])
test_only = datasets.CIFAR10(cfg.data_root, train=False, download=True, transform=tfm)
# Use the first test image as the anchor for the traversal.
x0, _ = test_only[0]
x0 = x0.unsqueeze(0).to(device)

vae.eval()
with torch.no_grad():
    mu, logvar = vae.enc(x0)
    # Start from the posterior mean (no sampling noise).
    base = mu[0].clone()
    vals = torch.linspace(-3, 3, 13, device=device)
    imgs = []
    for v in vals:
        # Overwrite only coordinate `dim`; every other latent stays at its mean.
        z = base.clone(); z[dim] = v
        img = vae.dec(z.unsqueeze(0))
        imgs.append(denorm(img))
    grid = vutils.make_grid(torch.cat(imgs, dim=0), nrow=len(vals))
    # The file name follows `dim`, so traversals of different dimensions are
    # not mislabelled as dim0 and do not overwrite each other.
    out_path = os.path.join(cfg.out_dir, f"latent_traversal_dim{dim}.png")
    vutils.save_image(grid, out_path)
out_path


'/content/drive/MyDrive/vae_cifar10_runs/latent_traversal_dim0.png'

In [None]:
# === FID & IS in one go ===
# - Computes FID between CIFAR-10 test set and your generated samples (10k recommended)
# - Computes Inception Score on your generated samples
# - Saves results to metrics.txt in your Drive run folder

# 1) Install deps
!pip -q install pytorch-fid torchmetrics torch-fidelity

import os, math
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision import datasets
from pytorch_fid import fid_score
from torchmetrics.image.inception import InceptionScore
import torch, os
from torch_fidelity import calculate_metrics


device = "cuda" if torch.cuda.is_available() else "cpu"

# 2) Paths (edit GEN_DIR if you used a different folder)
RUN_DIR = "/content/drive/MyDrive/vae_cifar10_runs"
GEN_DIR = os.path.join(RUN_DIR, "samples_10k")   # generated images you exported
REAL_DIR = "/content/cifar10_real_test"          # will be created once for CIFAR-10 test imgs
for needed_dir in (RUN_DIR, REAL_DIR):
    os.makedirs(needed_dir, exist_ok=True)

# 3) Prepare REAL CIFAR-10 (test set) as PNGs if not already
#    Raw test images are written without any normalization (PIL) as the FID reference.
existing_pngs = sum(1 for name in os.listdir(REAL_DIR) if name.endswith(".png"))
if existing_pngs >= 10000:
    print("Found existing CIFAR-10 test images in:", REAL_DIR)
else:
    print("Preparing CIFAR-10 test images (first time only)...")
    real_ds = datasets.CIFAR10(root="./data", train=False, download=True, transform=None)
    for i in range(len(real_ds)):
        img, _ = real_ds[i]            # PIL Image
        img.save(f"{REAL_DIR}/{i:05d}.png")
    print("Saved CIFAR-10 test images to:", REAL_DIR)

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/983.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m42.0 MB/s[0m eta [36m0:00:00[0m
[?25hPreparing CIFAR-10 test images (first time only)...
Saved CIFAR-10 test images to: /content/cifar10_real_test


In [None]:
# 4) Compute FID (lower is better)
# Guard: the generated-image folder must exist and be non-empty.
gen_dir_ready = os.path.isdir(GEN_DIR) and len(os.listdir(GEN_DIR)) > 0
assert gen_dir_ready, f"No generated images found at: {GEN_DIR}"

fid_paths = [REAL_DIR, GEN_DIR]
fid = fid_score.calculate_fid_given_paths(fid_paths, batch_size=64, device=device, dims=2048)
print(f"\nFID: {fid:.2f}")

Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth


100%|██████████| 91.2M/91.2M [00:00<00:00, 285MB/s]
100%|██████████| 157/157 [00:41<00:00,  3.80it/s]
100%|██████████| 157/157 [02:14<00:00,  1.17it/s]



FID: 240.41


In [None]:
# 5) Compute Inception Score (higher is better)
#    Torchmetrics expects tensors in [0,1] with shape (B,3,H,W). We’ll just load your PNGs.
from torchvision.transforms import InterpolationMode

class FolderDataset(Dataset):
    """Dataset over the ``.png`` files directly inside `path`, in sorted path order.

    Each item is the image converted to RGB, resized and center-cropped to
    299, and returned by `PILToTensor` as a uint8 tensor (C,H,W).
    """
    def __init__(self, path):
        self.files = sorted(
            os.path.join(path, name)
            for name in os.listdir(path)
            if name.endswith(".png")
        )
        # Keep uint8 [0,255] tensors
        resize = T.Resize(299, interpolation=InterpolationMode.BILINEAR)
        self.tf = T.Compose([resize, T.CenterCrop(299), T.PILToTensor()])

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        rgb = Image.open(self.files[idx]).convert("RGB")
        return self.tf(rgb)


gen_ds = FolderDataset(GEN_DIR)
gen_loader = DataLoader(
    gen_ds,
    batch_size=64,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

is_metric = InceptionScore(splits=10).to(device)

# Accumulate statistics over every generated image, batch by batch.
with torch.no_grad():
    for batch in gen_loader:
        is_metric.update(batch.to(device))   # uint8 [0,255]

is_mean, is_std = is_metric.compute()
print(f"Inception Score: {is_mean:.2f} ± {is_std:.2f}")

Downloading: "https://github.com/toshas/torch-fidelity/releases/download/v0.2.0/weights-inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/weights-inception-2015-12-05-6726825d.pth
100%|██████████| 91.2M/91.2M [00:02<00:00, 40.4MB/s]


Inception Score: 1.98 ± 0.04


In [None]:
from torchvision.transforms import InterpolationMode
from torchmetrics.image.kid import KernelInceptionDistance

class FolderDatasetUint8(Dataset):
    """PNG-folder dataset yielding uint8 (C,H,W) tensors resized and center-cropped to 299."""
    def __init__(self, path):
        png_paths = [os.path.join(path, entry) for entry in os.listdir(path) if entry.endswith(".png")]
        self.files = sorted(png_paths)
        steps = [
            T.Resize(299, interpolation=InterpolationMode.BILINEAR),
            T.CenterCrop(299),
            T.PILToTensor(),        # -> uint8 [0,255], CxHxW
        ]
        self.tf = T.Compose(steps)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        picture = Image.open(self.files[idx]).convert("RGB")
        return self.tf(picture)

real_ds = FolderDatasetUint8(REAL_DIR)
gen_ds  = FolderDatasetUint8(GEN_DIR)

loader_kwargs = dict(batch_size=64, shuffle=False, num_workers=2, pin_memory=True)
real_loader = DataLoader(real_ds, **loader_kwargs)
gen_loader  = DataLoader(gen_ds, **loader_kwargs)

kid_metric = KernelInceptionDistance(subset_size=1000).to(device)  # subset_size<=num_images

with torch.no_grad():
    # Feed all real images first, then all generated ones, tagging each batch accordingly.
    for loader, is_real in ((real_loader, True), (gen_loader, False)):
        for batch in loader:
            kid_metric.update(batch.to(device), real=is_real)

kid_mean, kid_std = kid_metric.compute()
print(f"KID: {kid_mean.item():.4f} ± {kid_std.item():.4f}")



KID: 0.2471 ± 0.0039


In [None]:
# 6) Save results to Drive
metrics_path = os.path.join(RUN_DIR, "metrics.txt")
# Explicit UTF-8: the lines contain "±", which a non-UTF-8 default locale
# encoding could fail to write (or write inconsistently across machines).
with open(metrics_path, "w", encoding="utf-8") as f:
    f.write(f"FID: {fid:.2f}\n")
    f.write(f"Inception Score: {is_mean:.2f} ± {is_std:.2f}\n")
    f.write(f"KID: {kid_mean:.6f} ± {kid_std:.6f}\n")

print("\nSaved metrics to:", metrics_path)


Saved metrics to: /content/drive/MyDrive/vae_cifar10_runs/metrics.txt


In [None]:
# --- Generate Later: load checkpoint and sample images ---
import os, torch
from torchvision.utils import save_image
from google.colab import drive

# Mount Drive again so this cell can run on its own, without the setup cell.
drive.mount('/content/drive', force_remount=True)
CKPT_DIR = "/content/drive/MyDrive/vae_cifar10_runs"     # folder you used before
# Must match a file written by the training cell, which names them vae_epoch_{epoch:03d}.pth.
CKPT_FILE = "vae_epoch_050.pth"                           # choose a saved checkpoint

# --- Rebuild model definitions (same as before) ---
import torch.nn as nn
import torch.nn.functional as F
import torchvision.utils as vutils

def denorm(x):
    """Clamp to [-1, 1], then rescale linearly to [0, 1]."""
    shifted = x.clamp(-1, 1) + 1  # now in [0, 2]
    return shifted / 2

def reparameterize(mu, logvar):
    """Sample mu + std * eps, where std = exp(0.5 * logvar) and eps is standard-normal noise."""
    std = logvar.mul(0.5).exp()
    return mu + std * torch.randn_like(std)

class Encoder(nn.Module):
    """Conv encoder rebuilt for checkpoint loading: image batch → (mu, logvar), each (B, z_dim)."""
    def __init__(self, z_dim=128):
        super().__init__()
        # Layer order and attribute names must match the checkpoint's state_dict keys.
        self.net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.flatten = nn.Flatten()
        self.fc_mu = nn.Linear(256*2*2, z_dim)
        self.fc_logvar = nn.Linear(256*2*2, z_dim)

    def forward(self, x):
        flat = self.flatten(self.net(x))
        mu = self.fc_mu(flat)
        logvar = self.fc_logvar(flat)
        return mu, logvar

class Decoder(nn.Module):
    """Transposed-conv decoder rebuilt for checkpoint loading: z (B, z_dim) → image (B,3,32,32) via Tanh."""
    def __init__(self, z_dim=128):
        super().__init__()
        self.fc = nn.Linear(z_dim, 256*2*2)
        # Layer order and attribute names must match the checkpoint's state_dict keys.
        self.net = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        hidden = self.fc(z)
        hidden = hidden.view(hidden.size(0), 256, 2, 2)
        return self.net(hidden)

class VAE(nn.Module):
    """Encoder + decoder wrapper; `forward` returns (xr, mu, logvar, z)."""
    def __init__(self, z_dim=128):
        super().__init__()
        self.enc = Encoder(z_dim)
        self.dec = Decoder(z_dim)

    def forward(self, x):
        mu, logvar = self.enc(x)
        z = reparameterize(mu, logvar)
        xr = self.dec(z)
        return xr, mu, logvar, z

device = "cuda" if torch.cuda.is_available() else "cpu"
# weights_only=True restricts unpickling to tensors and plain containers, which is
# all the training cell stores (state dicts plus a config dict of primitives).
# Passing it explicitly keeps older PyTorch versions from unpickling arbitrary objects.
ckpt = torch.load(os.path.join(CKPT_DIR, CKPT_FILE), map_location=device, weights_only=True)

# Rebuild the model with the latent size recorded at training time.
z_dim = ckpt["cfg"]["z_dim"]
vae = VAE(z_dim=z_dim).to(device)
vae.load_state_dict(ckpt["model"])
vae.eval()

# Generate 64 samples and save a grid to Drive
with torch.no_grad():
    z = torch.randn(64, z_dim, device=device)
    xs = vae.dec(z)
grid_path = os.path.join(CKPT_DIR, "samples_from_loaded.png")
vutils.save_image(vutils.make_grid(denorm(xs), nrow=8), grid_path)
print("Saved:", grid_path)


Mounted at /content/drive
Saved: /content/drive/MyDrive/vae_cifar10_runs/samples_from_loaded.png


In [None]:
# ============================================================
#  Plotting module for VAE training
#  Generates: recon loss, KL loss, total loss, beta curve,
#  and optional latent histograms + recon grids
# ============================================================

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from torchvision.utils import make_grid, save_image

log_path = os.path.join(cfg.out_dir, "train_log.csv")
plot_dir = os.path.join(cfg.out_dir, "plots")
os.makedirs(plot_dir, exist_ok=True)

def save_plt(path):
    """Write the current matplotlib figure to `path` and close it (nothing is shown inline)."""
    plt.savefig(path, bbox_inches='tight')
    plt.close()

# Load CSV (columns: epoch, beta, avg_total, avg_recon, avg_kl, time_sec)
df = pd.read_csv(log_path)

print("Loaded training log from:", log_path)

# --- Reconstruction vs KL loss ---
plt.figure(figsize=(8,5))
plt.plot(df["epoch"], df["avg_recon"], label="Reconstruction Loss", linewidth=2)
plt.plot(df["epoch"], df["avg_kl"], label="KL Loss", linewidth=2)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("VAE Reconstruction & KL Loss")
plt.grid(True)
# Two curves share this axis, so the legend is required to tell them apart.
plt.legend()
save_plt(os.path.join(plot_dir, "VAE_Reconstruction_KL_Loss.png"))

# --- Total loss ---
plt.figure(figsize=(8,5))
plt.plot(df["epoch"], df["avg_total"], label="Total Loss", color="black", linewidth=2)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Total Loss Over Training")
plt.grid(True)
plt.legend()
save_plt(os.path.join(plot_dir, "total_loss.png"))

# --- KL weight schedule ---
plt.figure(figsize=(8,4))
plt.plot(df["epoch"], df["beta"], label="KL Weight β", color="purple", linewidth=2)
plt.xlabel("Epoch")
plt.ylabel("β")
plt.title("KL Weight Schedule")
plt.grid(True)
plt.legend()
save_plt(os.path.join(plot_dir, "KL_weight_schedule.png"))

# --- Latent histogram: collect encoder means over a subset of training batches ---
vae.eval()
all_mu = []

with torch.no_grad():
    for x, _ in trainloader:
        x = x.to(device)
        mu, logvar = vae.enc(x)
        all_mu.append(mu.cpu())
        if len(all_mu) > 30:   # stop after 31 batches (31 * cfg.batch_size images)
            break

all_mu = torch.cat(all_mu, dim=0)
z0 = all_mu[:, 0].numpy()    # pick latent dimension 0

plt.figure(figsize=(8,5))
sns.histplot(z0, bins=50, kde=True)
plt.title("Histogram of Latent Dimension z₀")
plt.xlabel("Value")
plt.ylabel("Frequency")
save_plt(os.path.join(plot_dir, "latent_dimension_histogram.png"))

# --- Originals vs reconstructions for the first 8 test images ---
vae.eval()

x, _ = next(iter(testloader))
x = x[:8].to(device)

with torch.no_grad():
    xr, _, _, _ = vae(x)

# Denormalize
x_dn = denorm(x)
xr_dn = denorm(xr)

# 16 images with nrow=8: first row holds the originals, second row the reconstructions.
comparison = torch.cat([x_dn, xr_dn], dim=0)

grid = make_grid(comparison, nrow=8)
plt.figure(figsize=(12,4))
plt.title("Original (top) vs Reconstruction (bottom)")
# imshow expects channels last, so move (C,H,W) → (H,W,C).
plt.imshow(grid.permute(1,2,0).cpu().numpy())
plt.axis('off')
save_plt(os.path.join(plot_dir, "original_vs_reconstruction.png"))


Loaded training log from: /content/drive/MyDrive/vae_cifar10_runs/train_log.csv
