**1st Implementation, MNIST Dataset**

In [None]:
# -----------------------
# 0) Imports & Setup
# -----------------------
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import matplotlib.pyplot as plt



In [None]:
# 0.1: Seed the global RNG and pick the compute device.
# A fixed seed keeps weight initialisation, shuffling and latent sampling
# repeatable between runs; the `...` placeholder is not a valid seed value.
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ============================================================
# PART A — Minimal VAE on MNIST (MLP)
# ============================================================

# -----------------------
# 1) Data (MNIST)
# -----------------------
batch_size = 128
# ToTensor scales pixels to [0, 1], the range the BCE reconstruction loss expects.
transform = transforms.ToTensor()
train_ds = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Pinned host memory only speeds up host->GPU copies, so request it only on CUDA.
train_loader = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=(device.type == 'cuda'),
)




In [None]:
# -----------------------
# 2) Model (MLP VAE)
# -----------------------
class VAE(nn.Module):
    """Minimal MLP variational auto-encoder for 28x28 single-channel images.

    Encoder: 784 -> 400 -> (mu, logvar), each of size ``z_dim``.
    Decoder: ``z_dim`` -> 400 -> 784, with a sigmoid so outputs lie in [0, 1].

    Args:
        z_dim: Dimensionality of the latent space (A2.0; default 20).
    """

    def __init__(self, z_dim=20):
        super().__init__()
        # Encoder MLP
        # A2.1: flattened image (784) -> hidden (400)
        self.fc1 = nn.Linear(28*28, 400)
        # A2.2: two heads on the shared hidden layer: mean and log-variance
        self.fc_mu     = nn.Linear(400, z_dim)
        self.fc_logvar = nn.Linear(400, z_dim)

        # Decoder MLP
        # A2.3: latent -> hidden
        self.fc3 = nn.Linear(z_dim, 400)
        # A2.4: hidden -> flattened image; output size equals 28*28 = 784
        self.fc4 = nn.Linear(400, 28*28)

    def encode(self, x):
        """Map a flattened batch x of shape [B, 784] to (mu, logvar), each [B, z_dim]."""
        # A2.5: shared hidden representation
        h = F.relu(self.fc1(x))
        # A2.6: both heads read the same hidden activations
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample z = mu + sigma * eps with eps ~ N(0, I).

        Sampling eps separately keeps z differentiable w.r.t. mu and logvar.
        """
        # A2.7: sigma = exp(0.5 * log(sigma^2))
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes z of shape [B, z_dim] to x_hat in [0, 1] of shape [B, 784]."""
        # A2.8: decoder forward pass; sigmoid bounds the output for the BCE loss
        h = F.relu(self.fc3(z))
        x_hat = torch.sigmoid(self.fc4(h))
        return x_hat

    def forward(self, x):
        """Full VAE pass.

        Args:
            x: Image batch whose trailing dimensions flatten to 784 values
               per example (e.g. [B, 1, 28, 28]).

        Returns:
            Tuple ``(recon, mu, logvar)`` with shapes [B, 784], [B, z_dim], [B, z_dim].
        """
        # A2.9: flatten x to [B, 784]
        x = x.view(x.size(0), -1)
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar

# Build the MLP VAE with a 20-dimensional latent space and move its parameters to `device`.
model = VAE(z_dim=20).to(device)


In [None]:

# -----------------------
# 3) Loss (negative ELBO)
# -----------------------
def vae_loss(recon, x, mu, logvar):
    """Negative ELBO: BCE reconstruction term plus analytic KL divergence.

    Args:
        recon: Reconstructed batch with values in [0, 1]; any shape that
            flattens to the same number of values per example as ``x``.
        x: Target batch with values in [0, 1]; first dimension is the batch.
        mu: Posterior means, one row per example.
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        Tuple ``(total, bce, kld)`` of scalar tensors, each averaged over the
        batch (summed over pixels / latent dimensions, divided by batch size).
    """
    # A3.1: flatten both tensors so BCE compares matching [B, D] shapes
    b = x.size(0)
    x = x.view(b, -1)
    recon = recon.view(b, -1)
    # A3.2: sum (not mean) over pixels so the term is on the same scale as the KL
    bce = F.binary_cross_entropy(recon, x, reduction='sum')
    # A3.3: KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    # A3.4: per-example averages (total, bce, kld)
    return (bce + kld) / b, bce / b, kld / b


In [None]:

# -----------------------
# 4) Train (MNIST)
# -----------------------
# A4.1: Training hyperparameters (10 epochs of Adam at lr=1e-3).
epochs = 10
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(1, epochs+1):
    # Running sums weighted by batch size, so the last (smaller) batch
    # does not skew the per-example epoch averages.
    total, total_bce, total_kld, n = 0.0, 0.0, 0.0, 0
    for x, _ in train_loader:
        x = x.to(device)
        recon, mu, logvar = model(x)
        loss, bce, kld = vae_loss(recon, x, mu, logvar)

        opt.zero_grad()
        loss.backward()
        opt.step()

        bs = x.size(0)
        total     += loss.item() * bs
        total_bce += bce.item()  * bs
        total_kld += kld.item()  * bs
        n += bs
    print(f"[MNIST][Epoch {epoch:02d}] loss={total/n:.2f} recon={total_bce/n:.2f} kl={total_kld/n:.2f}")



In [None]:
# -----------------------
# 5) Visualize (MNIST)
# -----------------------
model.eval()
with torch.no_grad():
    # A5.1: Sample z ~ N(0,I) and decode.
    # Read the latent width from the model so this cell follows whatever
    # z_dim the VAE was built with.
    n_show = 64
    z_dim = model.fc_mu.out_features
    z = torch.randn(n_show, z_dim, device=device)
    gen = model.decode(z).view(-1, 1, 28, 28).cpu()
    gen_grid = make_grid(gen, nrow=8)

    # A5.2: Take a batch, reconstruct
    x, _ = next(iter(train_loader))
    x = x.to(device)[:n_show]
    recon, _, _ = model(x)
    recon = recon.view(-1, 1, 28, 28).cpu()
    recon_grid = make_grid(recon, nrow=8)
    inp_grid = make_grid(x.view(-1, 1, 28, 28).cpu(), nrow=8)

def show(t, title=None):
    """Display an image grid tensor laid out as [C, H, W], with an optional title."""
    plt.figure(figsize=(6,6)); plt.axis('off')
    if title: plt.title(title)
    # A5.3: matplotlib wants HWC, so move the channel axis last; gray cmap
    # applies if the result is single-channel after squeeze().
    plt.imshow(t.permute(1, 2, 0).squeeze(), cmap='gray')

show(gen_grid,  "MNIST Samples (z ~ N(0,I))")
show(inp_grid,  "MNIST Inputs")
show(recon_grid,"MNIST Reconstructions")





**PART B — Convolutional VAE on Oxford-IIIT Pet (Cats subset)**






In [None]:

# -----------------------
# 6) Data (Cats subset)
# -----------------------
# B1.1: Batch size
bs = 64

# B1.2: Transform to a 64x64 tensor in [0,1].
# Resize(64) scales the shorter side to 64 (keeping aspect ratio), then
# CenterCrop(64) trims the longer side so every image is exactly 64x64.
tx = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor()
])

ds_full = datasets.OxfordIIITPet(root='./data', split='trainval', target_types='category', transform=tx, download=True)
cat_names = ['Abyssinian','Bengal','Birman','Bombay','Persian','Ragdoll','Siamese','Sphynx']
cat_idx = { ds_full.class_to_idx[n] for n in cat_names }

# Filter by label without decoding every image when possible.
# NOTE(review): `_labels` is a private torchvision attribute — presumably the
# per-sample category ids; if it is missing we fall back to indexing the dataset.
labels = getattr(ds_full, '_labels', None)
if labels is None:
    labels = [ ds_full[i][1] for i in range(len(ds_full)) ]
idxs = [ i for i, label in enumerate(labels) if label in cat_idx ]
ds = Subset(ds_full, idxs)
# Pinned host memory only speeds up host->GPU copies, so request it only on CUDA.
dl = DataLoader(ds, batch_size=bs, shuffle=True, num_workers=2, pin_memory=(device.type == 'cuda'))

In [None]:

# -----------------------
# 7) Conv VAE Model
# -----------------------
class ConvVAE(nn.Module):
    """Convolutional VAE for 3-channel 64x64 images with values in [0, 1].

    The encoder halves the spatial size four times (64 -> 4) while growing the
    channels 3 -> 64 -> 128 -> 256 -> 512; the decoder mirrors it with
    transposed convolutions and ends in a sigmoid.

    Args:
        z_dim: Dimensionality of the latent space (default 128).
    """

    def __init__(self, z_dim=128):
        super().__init__()
        # Encoder: 64 -> 32 -> 16 -> 8 -> 4
        # Each Conv2d(k=4, s=2, p=1) halves H and W exactly.
        self.enc = nn.Sequential(
            nn.Conv2d(3,   64, 4, 2, 1),                      nn.ReLU(True),  # 64x64 -> 32x32
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),  # -> 16x16
            nn.Conv2d(128,256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(True),  # -> 8x8
            nn.Conv2d(256,512, 4, 2, 1), nn.BatchNorm2d(512), nn.ReLU(True)   # -> 4x4
        )
        # B2.1: FC heads to mean / log-variance from the flattened 512*4*4 features
        self.fc_mu = nn.Linear(512*4*4, z_dim)
        self.fc_lv = nn.Linear(512*4*4, z_dim)

        # B2.2: FC to expand from z_dim back to 512*4*4
        self.fc_dec = nn.Linear(z_dim, 512*4*4)

        # Decoder: 4 -> 8 -> 16 -> 32 -> 64
        # Each ConvTranspose2d(k=4, s=2, p=1) doubles H and W exactly.
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(512,256,4,2,1), nn.BatchNorm2d(256), nn.ReLU(True),  # 4->8
            nn.ConvTranspose2d(256,128,4,2,1), nn.BatchNorm2d(128), nn.ReLU(True),  # 8->16
            nn.ConvTranspose2d(128, 64,4,2,1), nn.BatchNorm2d(64),  nn.ReLU(True),  # 16->32
            nn.ConvTranspose2d(64,   3,4,2,1), nn.Sigmoid()                          # 32->64, [0,1]
        )

    def encode(self, x):
        """Map images x of shape [B, 3, 64, 64] to (mu, lv), each [B, z_dim]."""
        # B2.3: flatten encoder features to [B, 512*4*4]
        h = self.enc(x).view(x.size(0), -1)
        return self.fc_mu(h), self.fc_lv(h)

    def reparameterize(self, mu, lv):
        """Sample z = mu + sigma * eps, where sigma = exp(0.5 * lv) and eps ~ N(0, I)."""
        # B2.4: sampling eps separately keeps z differentiable w.r.t. mu and lv
        std = (0.5 * lv).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes z of shape [B, z_dim] to x_hat in [0, 1] of shape [B, 3, 64, 64]."""
        # B2.5: map z to [B,512,4,4], then upsample with the transposed convs
        h = self.fc_dec(z).view(z.size(0), 512, 4, 4)
        return self.dec(h)

    def forward(self, x):
        """Encode, sample and decode; returns ``(recon, mu, lv)``."""
        mu, lv = self.encode(x)
        z = self.reparameterize(mu, lv)
        recon = self.decode(z)
        return recon, mu, lv

# B2.6: Instantiate ConvVAE and move to device.
# z_dim=128 matches the latent width sampled in the visualisation cell.
model = ConvVAE(z_dim=128).to(device)


In [None]:

# -----------------------
# 8) Loss (β-VAE ready)
# -----------------------
def vae_loss_conv(recon, x, mu, lv, beta=1.0):
    """β-VAE loss: BCE reconstruction on [0,1] images plus β-weighted analytic KL.

    Args:
        recon: Reconstructed images with values in [0, 1], same shape as ``x``.
        x: Target images with values in [0, 1]; first dimension is the batch.
        mu: Posterior means, one row per example.
        lv: Posterior log-variances, same shape as ``mu``.
        beta: Weight on the KL term in the total loss (1.0 gives the plain VAE).

    Returns:
        Tuple ``(total, bce, kld)`` of scalar tensors averaged over the batch.
        ``kld`` is the unweighted KL; only ``total`` includes ``beta``.
    """
    b = x.size(0)
    # B3.1: BCE summed over all pixels, KL in closed form against N(0, I)
    bce = F.binary_cross_entropy(recon, x, reduction='sum')
    kld = -0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp())
    return (bce + beta * kld) / b, bce / b, kld / b


In [None]:

# -----------------------
# 9) Train (Conv VAE)
# -----------------------
# B4.1: Training hyperparameters. Adam with its standard betas; beta_kl=1.0
# is the plain VAE objective (raise it to experiment with β-VAE).
epochs = 30
beta_kl = 1.0
opt = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))

model.train()
for ep in range(1, epochs+1):
    # Running sums weighted by batch size, so the last (smaller) batch
    # does not skew the per-example epoch averages.
    total, total_bce, total_kld, n = 0.0, 0.0, 0.0, 0
    for x, _ in dl:
        x = x.to(device)
        recon, mu, lv = model(x)
        loss, bce, kld = vae_loss_conv(recon, x, mu, lv, beta=beta_kl)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Trailing underscore avoids clobbering the `bs` batch-size setting.
        bs_ = x.size(0)
        total     += loss.item() * bs_
        total_bce += bce.item()  * bs_
        total_kld += kld.item()  * bs_
        n += bs_
    print(f"[CATS][Epoch {ep:03d}] loss={total/n:.3f} recon={total_bce/n:.3f} kl={total_kld/n:.3f}")



In [None]:
# -----------------------
# 10) Visualize (Conv VAE)   no todos here
# -----------------------
model.eval()
with torch.no_grad():
    n_show = 64
    # Read the latent width from the model instead of hard-coding 128, so
    # sampling keeps working when ConvVAE is built with a different z_dim.
    z_dim = model.fc_mu.out_features
    z = torch.randn(n_show, z_dim, device=device)
    gen = model.decode(z).cpu()
    x, _ = next(iter(dl))
    x = x.to(device)[:n_show]
    r, _, _ = model(x)
    gen_grid = make_grid(gen, 8)
    inp_grid = make_grid(x.cpu(), 8)
    rec_grid = make_grid(r.cpu(), 8)

def show(t, title=None):
    """Display an image grid tensor laid out as [C, H, W], with an optional title.

    `title` is optional so this redefinition stays call-compatible with the
    two-argument `show(grid, title)` calls used for the MNIST figures.
    """
    plt.figure(figsize=(6,6))
    plt.axis('off')
    if title: plt.title(title)
    # matplotlib expects HWC, so move the channel axis last.
    plt.imshow(t.permute(1,2,0))

# The generated samples were computed above but never displayed.
show(gen_grid, "Cat Samples (z ~ N(0,I))")
show(inp_grid, "Cat Inputs")
show(rec_grid, "Cat Reconstructions")



In [None]:
# ============================================================
# OPTIONAL EXPLORATIONS
# ============================================================
# - TODO X1: β-VAE — try beta in {0.5, 1, 4, 10} and discuss recon vs. latent structure.
# - TODO X2: Latent dim sweep — try z_dim ∈ {2, 16, 64, 128, 256}.
# - TODO X3: Replace BCE with MSE; discuss differences in smoothness/sharpness.
# - TODO X4: Latent traversal — fix z, vary one dimension in [-3, 3], visualize effect.
# - TODO X5: Add simple augmentations (RandomHorizontalFlip) and observe robustness.