Task 2: Generative Models - Investigating VAE vs. GAN Biases


Mount the Drive

In [5]:
from google.colab import drive

# mount the google drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


Load the CIFAR-10 data:

In [6]:
!pip -q install torch torchvision

import os, math, itertools, random
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms, utils
from tqdm import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
IMG_SIZE = 32
NC = 3  # channels
DATA_ROOT = "/content/drive/MyDrive/datasets"

# VAE usually uses inputs in [0,1]; DCGAN prefers [-1,1] (tanh generator).
# We'll build two separate loaders so each model gets what it needs.
tfm_vae = transforms.Compose([
    transforms.ToTensor(),  # [0,1]
])

tfm_gan = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))  # [-1,1]
])

train_vae = torchvision.datasets.CIFAR10(DATA_ROOT, train=True, download=True, transform=tfm_vae)
train_gan = torchvision.datasets.CIFAR10(DATA_ROOT, train=True, download=True, transform=tfm_gan)

dl_vae = DataLoader(train_vae, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
dl_gan = DataLoader(train_gan, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

save_path = "/content/drive/MyDrive/AI Masters/ATML/PA1-Task2"
samples_vae = os.path.join(save_path, "samples_vae")
samples_gan = os.path.join(save_path, "samples_gan")

os.makedirs(samples_vae, exist_ok=True)
os.makedirs(samples_gan, exist_ok=True)

1) Variational Autoencoder (VAE)

Simple convolutional VAE with BCE reconstruction + KL term. Latent size = 128.

In [7]:
class ConvVAE(nn.Module):
    def __init__(self, z_dim=128):
        super().__init__()
        self.z_dim = z_dim
        # Encoder: 32x32 -> 16 -> 8 -> 4
        self.enc = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.ReLU(True),     # 16x16
            nn.Conv2d(64,128,4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),  # 8x8
            nn.Conv2d(128,256,4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(True), # 4x4
        )
        self.enc_out = nn.Flatten()
        self.fc_mu = nn.Linear(256*4*4, z_dim)
        self.fc_logvar = nn.Linear(256*4*4, z_dim)

        # Decoder: 4 -> 8 -> 16 -> 32
        self.fc_dec = nn.Linear(z_dim, 256*4*4)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(256,128,4,2,1), nn.BatchNorm2d(128), nn.ReLU(True), # 8x8
            nn.ConvTranspose2d(128,64, 4,2,1), nn.BatchNorm2d(64),  nn.ReLU(True), # 16x16
            nn.ConvTranspose2d(64, 3,   4,2,1), nn.Sigmoid() # 32x32, output in [0,1]
        )

    def encode(self, x):
        h = self.enc(x); h = self.enc_out(h)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        std = (0.5*logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        h = self.fc_dec(z).view(-1,256,4,4)
        return self.dec(h)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        x_hat = self.decode(z)
        return x_hat, mu, logvar

def vae_loss(x, x_hat, mu, logvar, beta=1.0):
    # recon with BCE (x, x_hat in [0,1]); reduction=sum for stable KL scaling
    bce = F.binary_cross_entropy(x_hat, x, reduction='sum')
    # KL(N(mu, sigma) || N(0, I))
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return (bce + beta*kl), bce, kl

@torch.no_grad()
def save_vae_samples(model, n=64, z_dim=128, step=0):
    model.eval()
    z = torch.randn(n, z_dim, device=DEVICE)
    imgs = model.decode(z).cpu()
    utils.save_image(imgs, f"{samples_vae}/vae_{step:06d}.png", nrow=8)

# Train VAE
vae = ConvVAE(z_dim=128).to(DEVICE)
opt_vae = torch.optim.Adam(vae.parameters(), lr=2e-4)

EPOCHS_VAE = 20 # 5 in the beginning for sanity check, 20 otherwise
global_step = 0
for epoch in range(1, EPOCHS_VAE+1):
    vae.train()
    pbar = tqdm(dl_vae, desc=f"VAE Epoch {epoch}/{EPOCHS_VAE}")
    running = {"loss":0., "bce":0., "kl":0., "n":0}
    for x,_ in pbar:
        x = x.to(DEVICE)
        x_hat, mu, logvar = vae(x)
        loss, bce, kl = vae_loss(x, x_hat, mu, logvar, beta=1.0)

        opt_vae.zero_grad(set_to_none=True)
        loss.backward()
        opt_vae.step()

        bs = x.size(0)
        running["loss"] += loss.item()
        running["bce"]  += bce.item()
        running["kl"]   += kl.item()
        running["n"]    += bs
        pbar.set_postfix(loss=running["loss"]/running["n"])

        if global_step % 500 == 0:
            save_vae_samples(vae, n=64, z_dim=vae.z_dim, step=global_step)
        global_step += 1

    # epoch-end samples
    save_vae_samples(vae, n=64, z_dim=vae.z_dim, step=global_step)
    torch.save(vae.state_dict(), f"{save_path}/snap_vae_epoch{epoch}.pt")


VAE Epoch 1/20: 100%|██████████| 391/391 [00:04<00:00, 96.61it/s, loss=1.97e+3]
VAE Epoch 2/20: 100%|██████████| 391/391 [00:03<00:00, 99.28it/s, loss=1.87e+3]
VAE Epoch 3/20: 100%|██████████| 391/391 [00:03<00:00, 98.14it/s, loss=1.84e+3]
VAE Epoch 4/20: 100%|██████████| 391/391 [00:04<00:00, 79.89it/s, loss=1.84e+3]
VAE Epoch 5/20: 100%|██████████| 391/391 [00:04<00:00, 92.90it/s, loss=1.83e+3] 
VAE Epoch 6/20: 100%|██████████| 391/391 [00:03<00:00, 97.89it/s, loss=1.83e+3]
VAE Epoch 7/20: 100%|██████████| 391/391 [00:03<00:00, 99.82it/s, loss=1.83e+3]
VAE Epoch 8/20: 100%|██████████| 391/391 [00:04<00:00, 96.94it/s, loss=1.83e+3]
VAE Epoch 9/20: 100%|██████████| 391/391 [00:04<00:00, 96.01it/s, loss=1.83e+3]
VAE Epoch 10/20: 100%|██████████| 391/391 [00:04<00:00, 93.03it/s, loss=1.82e+3] 
VAE Epoch 11/20: 100%|██████████| 391/391 [00:04<00:00, 96.47it/s, loss=1.82e+3] 
VAE Epoch 12/20: 100%|██████████| 391/391 [00:03<00:00, 99.63it/s, loss=1.82e+3] 
VAE Epoch 13/20: 100%|██████████|

VAE Reconstruction:

In [8]:
# --- 3) Restore your trained VAE (pick the checkpoint you want) ---
vae = ConvVAE(z_dim=128).to(DEVICE)
ckpt_path = os.path.join(save_path, "snap_vae_epoch20.pt")  # <-- change epoch if needed
vae.load_state_dict(torch.load(ckpt_path, map_location=DEVICE))
vae.eval()

# --- 4) Make a test loader (no normalization; VAE expects [0,1]) ---
tfm_vae = transforms.Compose([transforms.ToTensor()])  # [0,1]
testset = torchvision.datasets.CIFAR10(root=save_path, train=False, download=True, transform=tfm_vae)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True, num_workers=2, pin_memory=True)

# --- 5) Take a random batch, reconstruct, and save side-by-side ---
@torch.no_grad()
def save_vae_recons(model, loader, n=4, out_name="vae_recon.png"):
    x, _ = next(iter(loader))      # random batch due to shuffle=True
    x = x[:n].to(DEVICE)           # take first n images
    x_hat, _, _ = model(x)         # encode -> sample z -> decode

    # Stack originals over reconstructions for visual comparison
    grid = torch.cat([x.cpu(), x_hat.cpu()], dim=0)  # first n originals, then n recons
    utils.save_image(grid, os.path.join(samples_vae, out_name), nrow=n)
    print(f"Saved to: {os.path.join(samples_vae, out_name)}")

save_vae_recons(vae, testloader, n=4, out_name="vae_recon_test.png")


100%|██████████| 170M/170M [00:14<00:00, 11.8MB/s]


Saved to: /content/drive/MyDrive/AI Masters/ATML/PA1-Task2/samples_vae/vae_recon_test.png


DCGAN (GAN)

Classic DCGAN for 32×32 RGB. Uses tanh output and [-1,1] inputs.

In [9]:
NZ = 128   # latent size
NGF = 128  # generator feature maps
NDF = 128  # discriminator feature maps
LR = 2e-4
BETA1 = 0.5
BETA2 = 0.999

class Generator(nn.Module):
    def __init__(self, z_dim=NZ, ngf=NGF, nc=NC):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(z_dim, ngf*4, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf*4), nn.ReLU(True),   # 4x4
            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf*2), nn.ReLU(True),   # 8x8
            nn.ConvTranspose2d(ngf*2, ngf,   4, 2, 1, bias=False), nn.BatchNorm2d(ngf),   nn.ReLU(True),   # 16x16
            nn.ConvTranspose2d(ngf,   nc,    4, 2, 1, bias=False), nn.Tanh()                                # 32x32
        )
    def forward(self, z):
        return self.net(z)

class Discriminator(nn.Module):
    def __init__(self, ndf=NDF, nc=NC):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(nc,  ndf,   4, 2, 1, bias=False), nn.LeakyReLU(0.2, True),       # 16x16
            nn.Conv2d(ndf, ndf*2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf*2), nn.LeakyReLU(0.2, True), # 8x8
            nn.Conv2d(ndf*2, ndf*4,4, 2, 1, bias=False), nn.BatchNorm2d(ndf*4), nn.LeakyReLU(0.2, True), # 4x4
            nn.Conv2d(ndf*4, 1,    4, 1, 0, bias=False) # 1x1
        )
    def forward(self, x):
        return self.net(x).view(-1)

G = Generator().to(DEVICE)
D = Discriminator().to(DEVICE)

optG = torch.optim.Adam(G.parameters(), lr=LR, betas=(BETA1, BETA2))
optD = torch.optim.Adam(D.parameters(), lr=LR, betas=(BETA1, BETA2))

criterion = nn.BCEWithLogitsLoss()

@torch.no_grad()
def save_gan_samples(G, step=0, n=64):
    G.eval()
    z = torch.randn(n, NZ, 1, 1, device=DEVICE)
    fake = G(z).cpu()  # in [-1,1]
    # denormalize to [0,1] for saving
    imgs = (fake + 1)/2
    utils.save_image(imgs, f"{samples_gan}/gan_{step:06d}.png", nrow=8)

EPOCHS_GAN = 50
fixed_z = torch.randn(64, NZ, 1, 1, device=DEVICE)  # for consistent previews
step = 0

for epoch in range(1, EPOCHS_GAN+1):
    pbar = tqdm(dl_gan, desc=f"GAN Epoch {epoch}/{EPOCHS_GAN}")
    for x,_ in pbar:
        x = x.to(DEVICE)  # real in [-1,1]

        # 1) Update D: maximize log D(x) + log(1 - D(G(z)))
        D.train(); G.train()
        bs = x.size(0)
        real_labels = torch.ones(bs, device=DEVICE)
        fake_labels = torch.zeros(bs, device=DEVICE)

        # real
        d_real = D(x)
        loss_real = criterion(d_real, real_labels)

        # fake
        z = torch.randn(bs, NZ, 1, 1, device=DEVICE)
        with torch.no_grad():
            x_fake = G(z)
        d_fake = D(x_fake)
        loss_fake = criterion(d_fake, fake_labels)

        loss_D = loss_real + loss_fake
        optD.zero_grad(set_to_none=True)
        loss_D.backward()
        optD.step()

        # 2) Update G: maximize log D(G(z))  <=> minimize BCE(D(G(z)), 1)
        z = torch.randn(bs, NZ, 1, 1, device=DEVICE)
        x_fake = G(z)
        d_fake = D(x_fake)
        loss_G = criterion(d_fake, real_labels)

        optG.zero_grad(set_to_none=True)
        loss_G.backward()
        optG.step()

        if step % 500 == 0:
            with torch.no_grad():
                imgs = G(fixed_z).cpu()
                utils.save_image((imgs+1)/2, f"{samples_gan}/gan_fixed_{step:06d}.png", nrow=8)
        step += 1

    save_gan_samples(G, step=step)
    torch.save(G.state_dict(), f"{save_path}/snap_g_epoch{epoch}.pt")
    torch.save(D.state_dict(), f"{save_path}/snap_d_epoch{epoch}.pt")


GAN Epoch 1/50: 100%|██████████| 391/391 [00:06<00:00, 56.59it/s]
GAN Epoch 2/50: 100%|██████████| 391/391 [00:06<00:00, 62.32it/s]
GAN Epoch 3/50: 100%|██████████| 391/391 [00:06<00:00, 61.05it/s]
GAN Epoch 4/50: 100%|██████████| 391/391 [00:06<00:00, 61.26it/s]
GAN Epoch 5/50: 100%|██████████| 391/391 [00:06<00:00, 61.70it/s]
GAN Epoch 6/50: 100%|██████████| 391/391 [00:06<00:00, 61.57it/s]
GAN Epoch 7/50: 100%|██████████| 391/391 [00:06<00:00, 65.05it/s]
GAN Epoch 8/50: 100%|██████████| 391/391 [00:06<00:00, 59.81it/s]
GAN Epoch 9/50: 100%|██████████| 391/391 [00:06<00:00, 63.03it/s]
GAN Epoch 10/50: 100%|██████████| 391/391 [00:06<00:00, 63.29it/s]
GAN Epoch 11/50: 100%|██████████| 391/391 [00:06<00:00, 62.21it/s]
GAN Epoch 12/50: 100%|██████████| 391/391 [00:06<00:00, 62.72it/s]
GAN Epoch 13/50: 100%|██████████| 391/391 [00:06<00:00, 64.57it/s]
GAN Epoch 14/50: 100%|██████████| 391/391 [00:06<00:00, 58.92it/s]
GAN Epoch 15/50: 100%|██████████| 391/391 [00:06<00:00, 63.79it/s]
GAN 

Latent Space Structure - helper code:

In [10]:
import math
import random
from pathlib import Path
from typing import Tuple, Optional, Iterable
import torch
import torch.nn.functional as F
from torch import Tensor
from torchvision import datasets, transforms
from torchvision import utils as vutils

# ------------------------------------------------------------------
# Globals and small utilities
# ------------------------------------------------------------------

try:
    DEVICE  # type: ignore[name-defined]
except NameError:
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD  = (0.2470, 0.2435, 0.2616)

def get_cifar10_transforms(normalize: bool = True):
    t = [transforms.ToTensor()]
    if normalize:
        t.append(transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD))
    return transforms.Compose(t)

def denormalize_cifar10(x: Tensor) -> Tensor:
    """Inverse of CIFAR-10 normalization. Clamps into [0,1]."""
    mean = torch.tensor(_CIFAR10_MEAN, device=x.device).view(1, -1, 1, 1)
    std  = torch.tensor(_CIFAR10_STD,  device=x.device).view(1, -1, 1, 1)
    x = x * std + mean
    return x.clamp(0, 1)

def ensure_dir(path: str):
    Path(path).parent.mkdir(parents=True, exist_ok=True)

def to_device(x: Tensor) -> Tensor:
    return x.to(DEVICE, non_blocking=True)

def reparameterize(mu: Tensor, logvar: Tensor, deterministic: bool = False) -> Tensor:
    if deterministic:
        return mu
    std = torch.exp(0.5 * logvar)
    eps = torch.randn_like(std)
    return mu + eps * std

# ------------------------------------------------------------------
# Picking two test images of different classes (CIFAR-10)
# ------------------------------------------------------------------

def pick_two_test_images_from_different_classes(
    root: str = DATA_ROOT,

    normalize: bool = True,
    seed: Optional[int] = 123
) -> Tuple[Tuple[Tensor, int], Tuple[Tensor, int]]:
    """
    Returns two (image, label) tuples from CIFAR-10 test set with different labels.
    Images are returned as CHW tensors, optionally normalized with CIFAR-10 stats.
    """
    if seed is not None:
        random.seed(seed)

    tfm = get_cifar10_transforms(normalize=normalize)
    testset = datasets.CIFAR10(root=root, train=False, download=True, transform=tfm)

    idx1 = random.randrange(len(testset))
    x1, y1 = testset[idx1]

    # Find another with a different class
    for _ in range(20000):
        idx2 = random.randrange(len(testset))
        x2, y2 = testset[idx2]
        if y2 != y1:
            return (x1, y1), (x2, y2)

    # fallback (shouldn't happen)
    return (x1, y1), (x2, y2)

# ------------------------------------------------------------------
# Interpolation math (LERP and SLERP)
# ------------------------------------------------------------------

def lerp(z1: Tensor, z2: Tensor, t: float) -> Tensor:
    return (1.0 - t) * z1 + t * z2

def slerp(z1: Tensor, z2: Tensor, t: float, eps: float = 1e-7) -> Tensor:
    """
    Spherical linear interpolation for smoother perceptual paths in GAN latent spaces.
    Works for z vectors (B, Z) or (Z,). Operates per-sample if batched.
    """
    # Flatten batch if needed
    squeeze = False
    if z1.dim() == 1:
        z1 = z1.unsqueeze(0)
        z2 = z2.unsqueeze(0)
        squeeze = True

    z1n = torch.nn.functional.normalize(z1, dim=1, eps=eps)
    z2n = torch.nn.functional.normalize(z2, dim=1, eps=eps)
    dot = (z1n * z2n).sum(dim=1, keepdim=True).clamp(-1 + eps, 1 - eps)
    omega = torch.acos(dot)
    so = torch.sin(omega)

    t = torch.as_tensor(t, device=z1.device, dtype=z1.dtype).view(1, 1)

    # When z1 ~= z2, fall back to lerp to avoid division by small numbers
    mask = (so < 1e-3).float()
    s1 = torch.sin((1.0 - t) * omega) / (so + eps)
    s2 = torch.sin(t * omega) / (so + eps)
    out = (s1 * z1 + s2 * z2) * (1.0 - mask) + lerp(z1, z2, t.item()) * mask

    if squeeze:
        out = out.squeeze(0)
    return out

def interp_1d(z1: Tensor, z2: Tensor, n_intermediate: int, mode: str = 'lerp') -> Tensor:
    """
    Returns a tensor of shape (n_intermediate+2, Z) including endpoints.
    """
    ts = torch.linspace(0.0, 1.0, steps=n_intermediate + 2, device=z1.device, dtype=z1.dtype)
    outs = []
    for t in ts.tolist():
        if mode == 'slerp':
            outs.append(slerp(z1, z2, t))
        else:
            outs.append(lerp(z1, z2, t))
    return torch.stack(outs, dim=0)

# ------------------------------------------------------------------
# VAE interpolation
# ------------------------------------------------------------------

@torch.no_grad()
def interpolate_vae_and_save_grid(
    model,
    x1: Tensor, x2: Tensor,
    n_intermediate: int = 10,
    out_path: str = f'{save_path}/samples_vae/interp/vae_interp_grid.png',
    use_mu: bool = True,
    detach: bool = True,
    denormalize_before_save: bool = True,
):
    """
    - Encodes x1, x2 with model.encode -> (mu, logvar)
    - Takes z = mu (default) for each endpoint, optionally samples via reparameterize
    - Linearly interpolates between z1 and z2
    - Decodes each z to an image and saves a single-row grid
    """
    ensure_dir(out_path)
    model.eval()

    x1 = x1.unsqueeze(0).to(DEVICE)
    x2 = x2.unsqueeze(0).to(DEVICE)

    # Some VAEs implement encode(x) -> (mu, logvar); support both possibilities.
    enc1 = model.encode(x1)
    enc2 = model.encode(x2)

    if isinstance(enc1, (tuple, list)) and len(enc1) >= 2:
        mu1, logvar1 = enc1[0], enc1[1]
    else:
        raise ValueError("model.encode(x) must return (mu, logvar) for this helper.")

    if isinstance(enc2, (tuple, list)) and len(enc2) >= 2:
        mu2, logvar2 = enc2[0], enc2[1]
    else:
        raise ValueError("model.encode(x) must return (mu, logvar) for this helper.")

    z1 = reparameterize(mu1, logvar1, deterministic=use_mu).squeeze(0)
    z2 = reparameterize(mu2, logvar2, deterministic=use_mu).squeeze(0)

    zs = interp_1d(z1, z2, n_intermediate=n_intermediate, mode='lerp')  # VAE: lerp is typical
    imgs = model.decode(zs.to(DEVICE))

    if imgs.dim() != 4:
        raise ValueError("Decoded images must be a 4D tensor (N,C,H,W).")

    # VAE usually outputs [0,1]; clamp to be safe
    grid = vutils.make_grid(imgs.clamp(0, 1), nrow=imgs.size(0), padding=2)
    vutils.save_image(grid, out_path)
    return out_path

# ------------------------------------------------------------------
# GAN interpolation
# ------------------------------------------------------------------

@torch.no_grad()
def interpolate_gan_and_save_grid(G, z1: Optional[Tensor]=None, z2: Optional[Tensor]=None, z_dim: int=128,
                                  n_intermediate: int=10, out_path: str='samples_gan/interp/gan_interp_grid.png',
                                  mode: str='slerp', tanh_output: bool=True, reshape_policy: str='auto'):
    ensure_dir(out_path); G.eval()
    if z1 is None: z1=torch.randn(z_dim, device=DEVICE)
    if z2 is None: z2=torch.randn(z_dim, device=DEVICE)
    zs = interp_1d(z1, z2, n_intermediate, mode=mode).to(DEVICE)  # (T,Z)

    def forward(zflat: Tensor):
        if reshape_policy=='always':
            return G(zflat.view(zflat.size(0), zflat.size(1), 1, 1))
        if reshape_policy=='never':
            return G(zflat)
        # auto
        try:
            return G(zflat)
        except RuntimeError:
            return G(zflat.view(zflat.size(0), zflat.size(1), 1, 1))

    imgs = forward(zs)
    if imgs.dim()!=4: raise ValueError("Generator must return (N,C,H,W).")
    if tanh_output: imgs = (imgs+1)/2.0
    imgs = imgs.clamp(0,1)
    vutils.save_image(vutils.make_grid(imgs, nrow=imgs.size(0), padding=2), out_path)
    return out_path

Actual Code for Interpolation:

In [11]:
# 1) Two CIFAR-10 test images from different classes (normalized for your encoder)
(x1, y1), (x2, y2) = pick_two_test_images_from_different_classes(root='./data', normalize=True)

# 2) VAE interpolation (μ-only for clean path). Saves a 12-frame row: endpoints + 10 steps
vae_grid_path = interpolate_vae_and_save_grid(
    model=vae,        # <-- your VAE instance
    x1=x1, x2=x2,
    n_intermediate=10,
    out_path=f'{save_path}/samples_vae/interp/vae_interp_grid.png',
    use_mu=True,            # set False to sample via reparameterization
)

# 3) GAN interpolation (SLERP recommended)
gan_grid_path = interpolate_gan_and_save_grid(
    G=G,                    # <-- your trained generator
    z_dim=128,              # match your training
    n_intermediate=10,
    out_path=f'{save_path}/samples_gan/interp/gan_interp_grid.png',
    mode='slerp',           # 'lerp' also available
    tanh_output=True        # set False if G outputs [0,1]
)

print("Saved:", vae_grid_path, gan_grid_path)


100%|██████████| 170M/170M [00:18<00:00, 9.05MB/s]


Saved: /content/drive/MyDrive/AI Masters/ATML/PA1-Task2/samples_vae/interp/vae_interp_grid.png /content/drive/MyDrive/AI Masters/ATML/PA1-Task2/samples_gan/interp/gan_interp_grid.png


Helper functions for 2D PCA.

In [12]:

# plot_vae_pca.py
import torch
from torch import Tensor
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from pathlib import Path

try:
    from sklearn.decomposition import PCA as SKPCA
    _HAVE_SKLEARN = True
except Exception:
    _HAVE_SKLEARN = False

try:
    DEVICE  # type: ignore[name-defined]
except NameError:
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD  = (0.2470, 0.2435, 0.2616)
_CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

def get_cifar10_transform(normalize: bool=True):
    tfms = [transforms.ToTensor()]
    if normalize:
        tfms.append(transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD))
    return transforms.Compose(tfms)

def collect_10_per_class(root: str=DATA_ROOT, normalize: bool=True):
    ds = datasets.CIFAR10(root=root, train=False, download=True, transform=get_cifar10_transform(normalize))
    per_class = {k: [] for k in range(10)}
    for idx in range(len(ds)):
        x, y = ds[idx]
        if len(per_class[y]) < 10:
            per_class[y].append(x)
        if all(len(v) >= 10 for v in per_class.values()):
            break
    xs, ys = [], []
    for c in range(10):
        xs.extend(per_class[c])
        ys.extend([c]*len(per_class[c]))
    X = torch.stack(xs, dim=0)  # [100,3,32,32]
    y = torch.tensor(ys, dtype=torch.long)
    return X, y

@torch.no_grad()
def encode_vae_mu(model, X: Tensor, batch_size: int=64) -> Tensor:
    model.eval()
    Zs = []
    for i in range(0, X.size(0), batch_size):
        xb = X[i:i+batch_size].to(DEVICE)
        mu, logvar = model.encode(xb)[:2]
        Zs.append(mu.detach().cpu())
    return torch.cat(Zs, dim=0)

def pca_2d(Z: Tensor) -> Tensor:
    Z = Z.detach().cpu().float()
    if _HAVE_SKLEARN:
        pca = SKPCA(n_components=2, svd_solver='auto', random_state=42)
        import numpy as np
        Z2 = torch.from_numpy(pca.fit_transform(Z.numpy()))
    else:
        Zc = Z - Z.mean(dim=0, keepdim=True)
        U, S, Vh = torch.linalg.svd(Zc, full_matrices=False)
        comps = Vh[:2]  # [2, D]
        Z2 = Zc @ comps.t()
    return Z2

def plot_scatter_by_class(Z2: Tensor, y: Tensor, out_path: str) -> str:
    plt.figure(figsize=(7.5, 6))
    cmap = plt.get_cmap('tab10')
    for c in range(10):
        idx = (y==c).nonzero(as_tuple=False).squeeze(1)
        pts = Z2[idx]
        plt.scatter(pts[:,0].numpy(), pts[:,1].numpy(),
                    label=_CIFAR10_CLASSES[c], s=28, alpha=0.85, c=[cmap(c)])
    plt.legend(loc='best', fontsize=9, ncol=2, frameon=True)
    plt.title('VAE Latent Space (PCA to 2D) — 10 images per CIFAR-10 class')
    plt.xlabel('PC1'); plt.ylabel('PC2')
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(out_path, dpi=160)
    plt.close()
    return out_path

def plot_vae_latents_pca(model,
                         root: str='./data',
                         batch_size: int=64,
                         normalize: bool=True,
                         out_path: str='plots/vae_cifar10_pca.png') -> str:
    X, y = collect_10_per_class(root=DATA_ROOT, normalize=normalize)
    Z = encode_vae_mu(model, X, batch_size=batch_size)
    Z2 = pca_2d(Z)
    return plot_scatter_by_class(Z2, y, out_path)


Plotting 2D PCA:

In [13]:
fig_path = plot_vae_latents_pca(
    model=vae,
    root=DATA_ROOT,
    batch_size=64,
    normalize=True,
    out_path=f'{save_path}/plots/vae_cifar10_pca.png'
)
print("PCA plot saved:", fig_path)


PCA plot saved: /content/drive/MyDrive/AI Masters/ATML/PA1-Task2/plots/vae_cifar10_pca.png


Helper Functions for Semantic Factors:

In [14]:

import math, numpy as np
from pathlib import Path
from typing import Optional, Tuple, List

import torch
from torch import Tensor
import matplotlib.pyplot as plt

try:
    DEVICE  # type: ignore[name-defined]
except NameError:
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

_CIFAR10_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(1,3,1,1)
_CIFAR10_STD  = torch.tensor([0.2470, 0.2435, 0.2616]).view(1,3,1,1)
_CIFAR10_NAMES = ["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]
_CAT_LABEL = 3

def _to_chw_float(x: np.ndarray) -> torch.Tensor:
    if x.dtype != np.float32:
        x = x.astype(np.float32) / 255.0
    t = torch.from_numpy(x).permute(2,0,1).unsqueeze(0)  # 1x3x32x32
    return t

def _normalize(x: Tensor, normalize: bool=True) -> Tensor:
    if not normalize: return x
    return (x - _CIFAR10_MEAN) / _CIFAR10_STD

def _load_cifar10_cat_keras(normalize: bool=True) -> Tuple[Tensor, int]:
    from tensorflow.keras.datasets import cifar10
    (_, _), (x_test, y_test) = cifar10.load_data()
    y = y_test.flatten()
    idx = int(np.where(y == _CAT_LABEL)[0][0])
    img = x_test[idx]
    x = _to_chw_float(img)
    x = _normalize(x, normalize)
    return x, int(_CAT_LABEL)

def _ensure_dir(path: str):
    Path(path).parent.mkdir(parents=True, exist_ok=True)

def _make_single_row_grid(imgs: Tensor) -> np.ndarray:
    N, C, H, W = imgs.shape
    canvas = torch.zeros(C, H, W*N, dtype=imgs.dtype)
    for i in range(N):
        canvas[:, :, i*W:(i+1)*W] = imgs[i]
    arr = canvas.permute(1,2,0).detach().cpu().numpy()
    arr = np.clip(arr, 0.0, 1.0)
    return arr

@torch.no_grad()
def vary_single_latent_axis_for_cat(
    vae_model,
    axis: Optional[int] = None,
    steps: int = 11,
    sigma_span: float = 3.0,
    normalize_input: bool = True,
    decoder_output: str = "sigmoid",   # 'sigmoid' or 'tanh'
    out_path: str = f"{save_path}/plots/vae_cat_axis_sweep.png",
) -> Tuple[str, List[float], int]:
    vae_model.eval().to(DEVICE)

    x, y = _load_cifar10_cat_keras(normalize=normalize_input)  # [1,3,32,32]
    x = x.to(DEVICE)

    enc = vae_model.encode(x)
    if isinstance(enc, (tuple, list)) and len(enc) >= 2:
        mu, logvar = enc[0], enc[1]
    else:
        raise ValueError("Expected vae_model.encode(x) -> (mu, logvar).")

    mu = mu.squeeze(0); logvar = logvar.squeeze(0)
    sigma = torch.exp(0.5 * logvar)

    if axis is None:
        axis = int(torch.argmax(sigma).item())

    vals = torch.linspace(-sigma_span, sigma_span, steps=steps, device=DEVICE)
    values_used = (mu[axis] + vals * sigma[axis]).detach().cpu().tolist()

    zs = []
    for t in vals:
        z = mu.clone()
        z[axis] = mu[axis] + t * sigma[axis]
        zs.append(z)
    Z = torch.stack(zs, dim=0)

    imgs = vae_model.decode(Z)
    if decoder_output == "tanh":
        imgs = (imgs + 1.0) / 2.0
    imgs = imgs.clamp(0,1)

    grid = _make_single_row_grid(imgs)
    _ensure_dir(out_path)
    plt.imsave(out_path, grid)
    return out_path, values_used, axis


Semantic Factors Code: Changing one dimention while fixing all others.

In [15]:
# Run traversal (auto-picks axis with largest posterior sigma)
out_path, values_used, axis = vary_single_latent_axis_for_cat(
    vae,
    axis=None,            # or specify an int index, e.g., axis=7
    steps=11,             # number of images (endpoints included)
    sigma_span=3.0,       # sweep mu_k ± 3 * sigma_k
    normalize_input=True, # set False if your VAE expects raw [0,1]
    decoder_output="sigmoid",  # "tanh" if decoder outputs [-1,1]
    out_path=f"{save_path}/plots/vae_cat_axis_sweep.png"
)
print("Saved grid at:", out_path)
print("Axis varied:", axis)
print("Values used (z_k):", values_used)


Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m123s[0m 1us/step
Saved grid at: /content/drive/MyDrive/AI Masters/ATML/PA1-Task2/plots/vae_cat_axis_sweep.png
Axis varied: 56
Values used (z_k): [-6.496976852416992, -4.946293830871582, -3.3956100940704346, -1.8449268341064453, -0.29424333572387695, 1.2564395666122437, 2.8071231842041016, 4.35780668258667, 5.908490180969238, 7.459174156188965, 9.009857177734375]


Helper code for OOD reconstruction:

In [16]:

import numpy as np
from pathlib import Path
from typing import Tuple, Dict

import torch
import matplotlib.pyplot as plt

try:
    DEVICE  # type: ignore[name-defined]
except NameError:
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# CIFAR-10 normalization constants
_CIFAR10_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(1,3,1,1)
_CIFAR10_STD  = torch.tensor([0.2470, 0.2435, 0.2616]).view(1,3,1,1)

# -------------------------------
# Data helpers (no torchvision)
# -------------------------------

def _to_chw_float(x: np.ndarray) -> torch.Tensor:
    # x: HxWxC uint8 or float in [0,1]; return 1x3x32x32 float32 in [0,1]
    if x.dtype != np.float32:
        x = x.astype(np.float32) / 255.0
    t = torch.from_numpy(x).permute(2,0,1).unsqueeze(0)  # 1x3x32x32
    return t

def _normalize(x: torch.Tensor, normalize: bool=True) -> torch.Tensor:
    if not normalize:
        return x
    return (x - _CIFAR10_MEAN) / _CIFAR10_STD

def load_cifar10_cat_keras(normalize: bool=True) -> torch.Tensor:
    """
    Load one 'cat' image from CIFAR-10 test set via Keras and return [1,3,32,32] tensor.
    """
    from tensorflow.keras.datasets import cifar10
    (_, _), (x_test, y_test) = cifar10.load_data()
    y = y_test.flatten()
    idx = int(np.where(y == 3)[0][0])  # label 3 = 'cat'
    x = _to_chw_float(x_test[idx])
    x = _normalize(x, normalize)
    return x

def make_synthetic_house(normalize: bool=True) -> torch.Tensor:
    """
    Create a simple 32x32 'house' icon (square + triangle roof), [1,3,32,32] in [0,1] then normalize.
    """
    H, W = 32, 32
    img = np.ones((H,W,3), dtype=np.float32)  # white background

    # House body (rectangle)
    body_top, body_bottom = 14, 28
    body_left, body_right = 8, 24
    img[body_top:body_bottom, body_left:body_right, :] = np.array([0.6, 0.6, 0.8], dtype=np.float32)  # light bluish

    # Door
    door_top, door_bottom = 20, 28
    door_left, door_right = 14, 18
    img[door_top:door_bottom, door_left:door_right, :] = np.array([0.4, 0.2, 0.1], dtype=np.float32)

    # Roof (triangle)
    for r in range(10, 15):
        span = r - 10  # increases with row
        c_center = 16
        c_left = c_center - span - 1
        c_right = c_center + span + 1
        img[r, max(0,c_left):min(W, c_right), :] = np.array([0.8, 0.2, 0.2], dtype=np.float32)

    x = _to_chw_float((img * 255).astype(np.uint8))
    x = _normalize(x, normalize)
    return x

# -------------------------------
# VAE reconstruction + metrics
# -------------------------------

@torch.no_grad()
def vae_reconstruct(vae_model, x: torch.Tensor, decoder_output: str="sigmoid"):
    """
    Given x [1,3,32,32], returns reconstruction x_hat [1,3,32,32] in [0,1] (mapped if decoder_output='tanh').
    Uses z = mu for deterministic reconstruction.
    """
    vae_model.eval().to(DEVICE)
    x = x.to(DEVICE)

    enc = vae_model.encode(x)
    if isinstance(enc, (tuple, list)) and len(enc) >= 2:
        mu, logvar = enc[0], enc[1]
    else:
        raise ValueError("Expected vae_model.encode(x) -> (mu, logvar).")

    z = mu  # deterministic
    x_hat = vae_model.decode(z)

    if decoder_output == "tanh":
        x_hat = (x_hat + 1.0) / 2.0

    return x_hat.clamp(0,1)

def mse(a: torch.Tensor, b: torch.Tensor) -> float:
    return float(torch.mean((a - b) ** 2).cpu())

def l1(a: torch.Tensor, b: torch.Tensor) -> float:
    return float(torch.mean(torch.abs(a - b)).cpu())

def psnr(a: torch.Tensor, b: torch.Tensor, data_range: float = 1.0) -> float:
    e = torch.mean((a - b) ** 2)
    e = float(e.cpu())
    if e == 0:
        return float("inf")
    import math
    return 20.0 * math.log10(data_range) - 10.0 * math.log10(e)

def save_pair(original: torch.Tensor, recon: torch.Tensor, path: str, title_top: str, title_bottom: str):
    """
    Save a 2-row image: original on top, reconstruction on bottom.
    """
    import matplotlib.pyplot as plt
    orig = original.squeeze(0).permute(1,2,0).detach().cpu().numpy()
    rec  = recon.squeeze(0).permute(1,2,0).detach().cpu().numpy()

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    plt.figure(figsize=(3.2, 6.0))
    plt.subplot(2,1,1); plt.imshow(orig); plt.axis("off"); plt.title(title_top)
    plt.subplot(2,1,2); plt.imshow(rec);  plt.axis("off"); plt.title(title_bottom)
    plt.tight_layout()
    plt.savefig(path, dpi=140)
    plt.close()

def compare_cat_vs_house_reconstruction(
    vae_model,
    normalize_input: bool = True,
    decoder_output: str = "sigmoid",
    out_dir: str = "plots/vae_compare"
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """
    Reconstructs a CIFAR-10 'cat' sample and a synthetic 'house' image with the VAE,
    computes MSE/L1/PSNR errors, and saves side-by-side images.
    Returns (cat_metrics, house_metrics).
    """
    Path(out_dir).mkdir(parents=True, exist_ok=True)

    # 1) Get images
    x_cat   = load_cifar10_cat_keras(normalize=normalize_input)     # [1,3,32,32]
    x_house = make_synthetic_house(normalize=normalize_input)       # [1,3,32,32]

    # 2) Reconstruct
    x_cat_hat   = vae_reconstruct(vae_model, x_cat,   decoder_output=decoder_output)
    x_house_hat = vae_reconstruct(vae_model, x_house, decoder_output=decoder_output)

    # 3) Metrics
    cat_metrics = {
        "MSE":  mse(x_cat,   x_cat_hat),
        "L1":   l1(x_cat,    x_cat_hat),
        "PSNR": psnr(x_cat,  x_cat_hat),
    }
    house_metrics = {
        "MSE":  mse(x_house,   x_house_hat),
        "L1":   l1(x_house,    x_house_hat),
        "PSNR": psnr(x_house,  x_house_hat),
    }

    # 4) Save visuals
    save_pair(x_cat,   x_cat_hat,   f"{out_dir}/cat_reconstruction.png",   "Original: Cat",   "Reconstruction: Cat")
    save_pair(x_house, x_house_hat, f"{out_dir}/house_reconstruction.png", "Original: House", "Reconstruction: House")

    # 5) CSV
    import csv
    with open(f"{out_dir}/reconstruction_metrics.csv", "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["image", "MSE", "L1", "PSNR"])
        writer.writerow(["cat",   cat_metrics["MSE"],   cat_metrics["L1"],   cat_metrics["PSNR"]])
        writer.writerow(["house", house_metrics["MSE"], house_metrics["L1"], house_metrics["PSNR"]])

    return cat_metrics, house_metrics


Actual code for OOD reconstruction:

In [18]:
# --- Patch: consistent device + consistent scale ([0,1]) for metrics and visuals ---

import os, torch, numpy as np, matplotlib.pyplot as plt

# CIFAR-10 normalization
_CIFAR10_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(1,3,1,1)
_CIFAR10_STD  = torch.tensor([0.2470, 0.2435, 0.2616]).view(1,3,1,1)

def _denorm_01(x):  # x in normalized space -> [0,1]
    return (x * _CIFAR10_STD + _CIFAR10_MEAN).clamp(0, 1)

def _mse_cpu(a, b):  # both assumed CPU, [0,1]
    return float(((a - b) ** 2).mean().item())

def _l1_cpu(a, b):
    return float((a - b).abs().mean().item())

def _psnr_cpu(a, b, data_range=1.0):
    e = ((a - b) ** 2).mean().item()
    if e == 0: return float('inf')
    import math
    return 20.0 * math.log10(data_range) - 10.0 * math.log10(e)

def _save_pair_cpu(original_01, recon_01, path, title_top, title_bottom):
    os.makedirs(os.path.dirname(path), exist_ok=True)
    plt.figure(figsize=(3.2, 6.0))
    plt.subplot(2,1,1); plt.imshow(original_01.squeeze(0).permute(1,2,0).numpy()); plt.axis("off"); plt.title(title_top)
    plt.subplot(2,1,2); plt.imshow(recon_01.squeeze(0).permute(1,2,0).numpy());    plt.axis("off"); plt.title(title_bottom)
    plt.tight_layout(); plt.savefig(path, dpi=140); plt.close()


def compare_cat_vs_house_reconstruction_FIXED(
    vae_model,
    normalize_input: bool = True,
    decoder_output: str = "sigmoid",
    out_dir: str = "plots/vae_compare"
):
    os.makedirs(out_dir, exist_ok=True)

    # Load inputs (possibly normalized if normalize_input=True)
    x_cat   = load_cifar10_cat_keras(normalize=normalize_input)   # [1,3,32,32], CPU
    x_house = make_synthetic_house(normalize=normalize_input)     # [1,3,32,32], CPU

    # Reconstructions (on the model’s device); function returns [0,1]
    x_cat_hat   = vae_reconstruct(vae_model, x_cat,   decoder_output=decoder_output).detach().cpu()
    x_house_hat = vae_reconstruct(vae_model, x_house, decoder_output=decoder_output).detach().cpu()

    # Map originals to [0,1] before comparing, then move to CPU
    x_cat_img   = (_denorm_01(x_cat)   if normalize_input else x_cat).detach().cpu()
    x_house_img = (_denorm_01(x_house) if normalize_input else x_house).detach().cpu()

    # Metrics on CPU in [0,1]
    cat_metrics = {
        "MSE":  _mse_cpu(x_cat_img,   x_cat_hat),
        "L1":   _l1_cpu(x_cat_img,    x_cat_hat),
        "PSNR": _psnr_cpu(x_cat_img,  x_cat_hat),
    }
    house_metrics = {
        "MSE":  _mse_cpu(x_house_img,   x_house_hat),
        "L1":   _l1_cpu(x_house_img,    x_house_hat),
        "PSNR": _psnr_cpu(x_house_img,  x_house_hat),
    }

    # Visuals
    _save_pair_cpu(x_cat_img,   x_cat_hat,   f"{out_dir}/cat_reconstruction.png",   "Original: Cat",   "Reconstruction: Cat")
    _save_pair_cpu(x_house_img, x_house_hat, f"{out_dir}/house_reconstruction.png", "Original: House", "Reconstruction: House")

    # CSV
    import csv
    with open(f"{out_dir}/reconstruction_metrics.csv", "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["image", "MSE", "L1", "PSNR"])
        w.writerow(["cat",   cat_metrics["MSE"],   cat_metrics["L1"],   cat_metrics["PSNR"]])
        w.writerow(["house", house_metrics["MSE"], house_metrics["L1"], house_metrics["PSNR"]])

    return cat_metrics, house_metrics

# ---- RUN IT ----
cat_metrics, house_metrics = compare_cat_vs_house_reconstruction_FIXED(
    vae,                # <-- your VAE variable name
    normalize_input=True,     # True if you trained with CIFAR-10 normalization
    decoder_output="sigmoid", # "tanh" if your decoder outputs [-1,1]
    out_dir=f"{save_path}/plots/vae_compare"
)
print("Cat metrics:", cat_metrics)
print("House metrics:", house_metrics)


Cat metrics: {'MSE': 0.14413614571094513, 'L1': 0.3470536470413208, 'PSNR': 8.41227095422466}
House metrics: {'MSE': 0.030931850895285606, 'L1': 0.07367781549692154, 'PSNR': 15.095940919735813}
