
# Variational Autoencoder (VAE) on Fashion‑MNIST — **PyTorch**

This notebook reimplements the activity **using PyTorch**, and is **fully annotated** for clarity: data prep → model → training →
evaluation → visualization → sampling, plus an **AE baseline** for extra credit.



## Dependencies

- `torch` (CPU is fine)
- `pandas`, `numpy`, `matplotlib`

> If you don't have PyTorch yet, install a CPU‑only wheel:
> ```bash
> pip install torch --index-url https://download.pytorch.org/whl/cpu
> ```


In [8]:

import os, math
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split

def seed_everything(seed: int = 7):
    torch.manual_seed(seed)
    np.random.seed(seed)

seed_everything(7)

ROOT = Path(".")
ASSETS = ROOT / "assets"
DATA = ROOT / "data"
ASSETS.mkdir(exist_ok=True, parents=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cpu



## 1) Data Preparation (CSV → PyTorch DataLoader)

We use the Kaggle‑style CSVs where the **first column is the label** and the remaining **784 columns** are pixel values.
We normalize to **[0,1]**, split **90/10** train/validation, and create `DataLoader`s.


In [9]:

class FashionCSVDataset(Dataset):
    def __init__(self, csv_path):
        df = pd.read_csv(csv_path)
        y = df.iloc[:, 0].to_numpy(dtype=np.int64)
        X = df.iloc[:, 1:].to_numpy(dtype=np.float32) / 255.0
        self.x = torch.from_numpy(X)   # (N, 784) in [0,1]
        self.y = torch.from_numpy(y)   # (N,)
    def __len__(self): return self.x.shape[0]
    def __getitem__(self, idx): return self.x[idx], self.y[idx]

train_csv = DATA / "fashion-mnist_train.csv"
test_csv  = DATA / "fashion-mnist_test.csv"

full_train = FashionCSVDataset(train_csv)
N = len(full_train)
n_val = int(0.1 * N)
n_train = N - n_val
train_set, val_set = random_split(full_train, [n_train, n_val], generator=torch.Generator().manual_seed(123))

# subsample for speed while prototyping (set to None to use all)
SUB_TRAIN = 15000
SUB_VAL   = 5000
if SUB_TRAIN is not None and SUB_TRAIN < n_train:
    train_set = torch.utils.data.Subset(train_set, range(SUB_TRAIN))
if SUB_VAL is not None and SUB_VAL < n_val:
    val_set = torch.utils.data.Subset(val_set, range(SUB_VAL))

train_loader = DataLoader(train_set, batch_size=256, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_set, batch_size=256, shuffle=False, num_workers=0)

len(train_set), len(val_set)


(15000, 5000)


## 2) VAE in PyTorch

**Encoder**: `x → ReLU → (μ, logσ²)`  
**Reparameterization**: `z = μ + σ·ε` with `ε ~ N(0, I)`  
**Decoder**: `z → ReLU → logits(784)`  
**Loss**: `ELBO = BCEWithLogits + KL`


In [10]:

class Encoder(nn.Module):
    def __init__(self, in_dim=784, hidden=256, latent=2):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(inplace=True))
        self.mu = nn.Linear(hidden, latent)
        self.logvar = nn.Linear(hidden, latent)
    def forward(self, x):
        h = self.net(x)
        return self.mu(h), self.logvar(h)

class Decoder(nn.Module):
    def __init__(self, latent=2, hidden=256, out_dim=784):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(latent, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, out_dim))
    def forward(self, z): return self.net(z)

class VAE(nn.Module):
    def __init__(self, in_dim=784, hidden=256, latent=2):
        super().__init__()
        self.encoder = Encoder(in_dim, hidden, latent)
        self.decoder = Decoder(latent, hidden, in_dim)
    @staticmethod
    def reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon_logits = self.decoder(z)
        return recon_logits, mu, logvar
    def decode_from_mu(self, x):
        mu, logvar = self.encoder(x)
        return self.decoder(mu), mu, logvar

def elbo_loss(recon_logits, x, mu, logvar, reduction="mean"):
    bce = nn.functional.binary_cross_entropy_with_logits(recon_logits, x, reduction="sum")
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    if reduction == "mean":
        N = x.size(0)
        return (bce + kl) / N, bce / N, kl / N
    return bce + kl, bce, kl



## 3) Training (ELBO)

We use Adam, track **train** and **validation** ELBO, and save curves to `assets/`.
For validation, we decode deterministically with `z=μ` to stabilize the metric.


In [11]:

def train_vae(model, train_loader, val_loader, device, epochs=5, lr=1e-3):
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    tr_hist, va_hist = [], []
    for ep in range(1, epochs+1):
        model.train(); ep_losses = []
        for xb, _ in train_loader:
            xb = xb.to(device)
            opt.zero_grad(set_to_none=True)
            recon_logits, mu, logvar = model(xb)
            loss, bce, kl = elbo_loss(recon_logits, xb, mu, logvar, reduction='mean')
            loss.backward(); opt.step()
            ep_losses.append(loss.item())
        tr_hist.append(float(np.mean(ep_losses)))
        model.eval(); val_losses = []
        with torch.no_grad():
            for xb, _ in val_loader:
                xb = xb.to(device)
                recon_logits, mu, logvar = model.decode_from_mu(xb)
                vloss, _, _ = elbo_loss(recon_logits, xb, mu, logvar, reduction='mean')
                val_losses.append(vloss.item())
        va_hist.append(float(np.mean(val_losses)))
        print(f"Epoch {ep:02d} | train {tr_hist[-1]:.3f} | val {va_hist[-1]:.3f}")
    # Curves
    plt.figure(); plt.plot(tr_hist); plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('VAE — Training Loss (ELBO approx.)')
    plt.tight_layout(); plt.savefig(ASSETS / 'vae_train_loss.png'); plt.close()
    plt.figure(); plt.plot(va_hist); plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('VAE — Validation Loss (ELBO approx.)')
    plt.tight_layout(); plt.savefig(ASSETS / 'vae_val_loss.png'); plt.close()
    return tr_hist, va_hist

vae = VAE(in_dim=784, hidden=256, latent=2)
EPOCHS = 30
train_hist, val_hist = train_vae(vae, train_loader, val_loader, device, epochs=EPOCHS, lr=1e-3)


Epoch 01 | train 375.257 | val 312.153
Epoch 02 | train 302.491 | val 292.237
Epoch 03 | train 289.729 | val 284.550
Epoch 04 | train 284.528 | val 280.986
Epoch 05 | train 281.714 | val 279.163
Epoch 06 | train 279.918 | val 277.751
Epoch 07 | train 278.719 | val 276.370
Epoch 08 | train 277.558 | val 275.306
Epoch 09 | train 276.439 | val 274.252
Epoch 10 | train 275.517 | val 273.484
Epoch 11 | train 274.949 | val 273.200
Epoch 12 | train 274.044 | val 272.145
Epoch 13 | train 273.427 | val 271.287
Epoch 14 | train 272.505 | val 270.443
Epoch 15 | train 272.168 | val 269.932
Epoch 16 | train 271.465 | val 270.047
Epoch 17 | train 271.002 | val 269.171
Epoch 18 | train 270.754 | val 268.431
Epoch 19 | train 270.134 | val 268.426
Epoch 20 | train 269.703 | val 267.981



## 4) Evaluation & Visualization

- **Reconstructions (val, z=μ)** → `assets/vae_recon_val.png`  
- **Latent scatter (μ)** → `assets/vae_latent_scatter.png` (2‑D or PCA)  
- **Prior samples** → `assets/vae_samples.png`


In [12]:

def save_image_grid(images, grid_shape, out_png, suptitle=None):
    R, C = grid_shape
    fig, axes = plt.subplots(R, C, figsize=(C*1.5, R*1.5))
    axes = np.array(axes).reshape(R, C)
    k = 0
    for i in range(R):
        for j in range(C):
            ax = axes[i, j]
            ax.imshow(images[k], cmap='gray', vmin=0, vmax=1)
            ax.axis('off'); k += 1
            if k >= len(images): break
    if suptitle: fig.suptitle(suptitle, y=0.98)
    plt.tight_layout(); plt.savefig(out_png, bbox_inches='tight'); plt.close()

vae.eval()
with torch.no_grad():
    xb, yb = next(iter(val_loader))
    xb = xb.to(device)
    recon_logits, mu_b, logvar_b = vae.decode_from_mu(xb)  # deterministic
    recon = torch.sigmoid(recon_logits).cpu().numpy().reshape(-1, 28, 28)
    orig  = xb.cpu().numpy().reshape(-1, 28, 28)
    grid = np.vstack([orig[:12], recon[:12]])
    save_image_grid(grid, (4,6), ASSETS / 'vae_recon_val.png', 'VAE — Originals (top) vs Reconstructions (bottom)')

    # Latent scatter
    all_mu, all_y = [], []
    for xb2, yb2 in val_loader:
        xb2 = xb2.to(device)
        mu2, _ = vae.encoder(xb2)
        all_mu.append(mu2.cpu().numpy()); all_y.append(yb2.numpy())
    MU = np.concatenate(all_mu, axis=0); Y = np.concatenate(all_y, axis=0)
    if MU.shape[1] == 2:
        pts = MU
    else:
        Xc = MU - MU.mean(0, keepdims=True)
        U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
        pts = Xc @ Vt[:2].T
    if len(pts) > 2000:
        idx = np.random.default_rng(0).choice(len(pts), size=2000, replace=False)
        pts, Y = pts[idx], Y[idx]
    plt.figure(figsize=(5.2, 4.4))
    plt.scatter(pts[:,0], pts[:,1], c=Y, s=8, alpha=0.7)
    plt.xlabel('z1'); plt.ylabel('z2'); plt.title('VAE — Latent Space (μ) on Validation')
    plt.tight_layout(); plt.savefig(ASSETS / 'vae_latent_scatter.png'); plt.close()

    # Prior samples
    Z = torch.randn(24, 2, device=device)
    samples = torch.sigmoid(vae.decoder(Z)).cpu().numpy().reshape(-1, 28, 28)
    save_image_grid(samples, (4,6), ASSETS / 'vae_samples.png', 'VAE — Samples from N(0, I)')



## 5) Extra — Autoencoder (AE) Baseline

The AE reconstructs well but is **not** generative. We compare curves and reconstructions.


In [13]:

class AE(nn.Module):
    def __init__(self, in_dim=784, hidden=256, latent=32):
        super().__init__()
        self.enc = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, latent))
        self.dec = nn.Sequential(nn.Linear(latent, hidden), nn.ReLU(inplace=True), nn.Linear(hidden, in_dim))
    def forward(self, x):
        z = self.enc(x); return self.dec(z)

def train_ae(model, train_loader, val_loader, device, epochs=3, lr=1e-3):
    model.to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    bce = nn.BCEWithLogitsLoss(reduction='mean')
    tr, va = [], []
    for ep in range(1, epochs+1):
        model.train(); ep_losses = []
        for xb, _ in train_loader:
            xb = xb.to(device); opt.zero_grad(set_to_none=True)
            logits = model(xb); loss = bce(logits, xb)
            loss.backward(); opt.step(); ep_losses.append(loss.item())
        tr.append(float(np.mean(ep_losses)))
        model.eval(); val_losses = []
        with torch.no_grad():
            for xb, _ in val_loader:
                xb = xb.to(device); logits = model(xb); val_losses.append(bce(logits, xb).item())
        va.append(float(np.mean(val_losses)))
        print(f"AE Epoch {ep:02d} | train {tr[-1]:.3f} | val {va[-1]:.3f}")
    plt.figure(); plt.plot(tr); plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('AE — Training Recon Loss (BCE)')
    plt.tight_layout(); plt.savefig(ASSETS / 'ae_train_loss.png'); plt.close()
    plt.figure(); plt.plot(va); plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('AE — Validation Recon Loss (BCE)')
    plt.tight_layout(); plt.savefig(ASSETS / 'ae_val_loss.png'); plt.close()
    return tr, va

ae = AE(in_dim=784, hidden=256, latent=32)
ae_tr, ae_va = train_ae(ae, train_loader, val_loader, device, epochs=30, lr=1e-3)

# Recon grid
ae.eval()
with torch.no_grad():
    xb, yb = next(iter(val_loader))
    xr = torch.sigmoid(ae(xb.to(device))).cpu().numpy().reshape(-1, 28, 28)
    orig = xb.cpu().numpy().reshape(-1, 28, 28)
    grid = np.vstack([orig[:12], xr[:12]])
    # naming consistent with README
    plt.figure()
    plt.close()


AE Epoch 01 | train 0.468 | val 0.369
AE Epoch 02 | train 0.346 | val 0.330
AE Epoch 03 | train 0.323 | val 0.317
AE Epoch 04 | train 0.313 | val 0.311
AE Epoch 05 | train 0.308 | val 0.305
AE Epoch 06 | train 0.304 | val 0.301
AE Epoch 07 | train 0.301 | val 0.299
AE Epoch 08 | train 0.298 | val 0.296
AE Epoch 09 | train 0.296 | val 0.297
AE Epoch 10 | train 0.295 | val 0.293
AE Epoch 11 | train 0.293 | val 0.292
AE Epoch 12 | train 0.291 | val 0.290
AE Epoch 13 | train 0.290 | val 0.289
AE Epoch 14 | train 0.289 | val 0.289
AE Epoch 15 | train 0.288 | val 0.287
AE Epoch 16 | train 0.287 | val 0.286
AE Epoch 17 | train 0.286 | val 0.286
AE Epoch 18 | train 0.286 | val 0.285
AE Epoch 19 | train 0.285 | val 0.284
AE Epoch 20 | train 0.284 | val 0.283
AE Epoch 21 | train 0.283 | val 0.283
AE Epoch 22 | train 0.282 | val 0.282
AE Epoch 23 | train 0.282 | val 0.282
AE Epoch 24 | train 0.281 | val 0.281
AE Epoch 25 | train 0.281 | val 0.281
AE Epoch 26 | train 0.280 | val 0.280
AE Epoch 27 


## 6) Wrap‑up

- ELBO = reconstruction (BCE) + KL to a unit Gaussian prior.  
- VAE learns a structured latent space for **sampling**; AE is **not** generative.  
- Increase epochs/hidden size for better quality; try different latent sizes; use PCA (or t‑SNE/UMAP) for visualization when `latent_dim > 3`.
