<a href="https://colab.research.google.com/github/5w7Tch/GM-final/blob/main/notebooks/train_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# VAE Training on CIFAR-10
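**Variational Autoencoder (β-VAE)**: trains on CIFAR-10, logs losses and sample grids to Weights & Biases, and evaluates the model with reconstructions and FID.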


## 1. Setup

In [1]:
# Fetch the project repository and work from its root so the `src.*` modules imported below resolve.
!git clone https://github.com/5w7Tch/GM-final
%cd GM-final

Cloning into 'GM-final'...
remote: Enumerating objects: 111, done.[K
remote: Counting objects: 100% (111/111), done.[K
remote: Compressing objects: 100% (98/98), done.[K
remote: Total 111 (delta 51), reused 43 (delta 11), pack-reused 0 (from 0)[K
Receiving objects: 100% (111/111), 5.68 MiB | 4.22 MiB/s, done.
Resolving deltas: 100% (51/51), done.
/content/GM-final


In [2]:
!pip install wandb tqdm -q

In [3]:
import torch
import torch.optim as optim
import os
from tqdm.auto import tqdm
from torchvision.utils import make_grid

from src.models import VAE
from src.losses import vae_loss
from src.data import get_dataloader
from src.utils import show_samples, save_samples

import wandb
# Authenticate with Weights & Biases before wandb.init() in the configuration cell.
# Never hardcode the API key in this notebook.
wandb.login()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: [wandb.login()] Using explicit session credentials for https://api.wandb.ai.
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33meghib22[0m ([33mnurch22-free-university-of-tbilisi-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

## 2. Configuration

In [4]:
# All tunable hyperparameters for the run; the whole dict is also sent to W&B.
config = dict(
    # Model
    latent_dim=64,
    beta=0.1,         # KL weight (beta-VAE)

    # Optimisation
    epochs=60,
    batch_size=128,
    lr=1e-4,

    # Logging cadence, in epochs
    sample_every=10,  # how often a sample grid is logged
    save_every=25,    # intended checkpoint interval

    seed=42,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the project name mentions NCSN; presumably it is shared across
# the generative models of this project. Confirm before changing.
run_name = f"VAE_latent{config['latent_dim']}_beta{config['beta']}"
wandb.init(project="ML2-NCSN-CIFAR10", config=config, name=run_name)

## 3. Initialize

In [5]:
# Reproducibility: seed the CPU RNG and every CUDA device RNG from the config.
seed = config['seed']
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Output directories (created if missing, so re-running the cell is safe).
for out_dir in ('checkpoints', 'samples'):
    os.makedirs(out_dir, exist_ok=True)

train_loader = get_dataloader(batch_size=config['batch_size'])

# Model and optimizer.
model = VAE(latent_dim=config['latent_dim']).to(device)
optimizer = optim.Adam(model.parameters(), lr=config['lr'])

100%|██████████| 170M/170M [00:13<00:00, 12.9MB/s]


## 4. Training


In [None]:
# Train the VAE.
# - Per-step losses are logged to W&B.
# - A sample grid is logged every `sample_every` epochs.
# - A checkpoint is written every `save_every` epochs and after the final epoch.
model.train()

for epoch in range(config['epochs']):
    pbar = tqdm(train_loader)
    for images, _ in pbar:
        images = images.to(device)

        x_recon, mu, log_var = model(images)
        loss, recon_loss, kl_loss = vae_loss(
            images, x_recon, mu, log_var, config['beta']
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a device sync.
        loss_value = loss.item()
        wandb.log({
            "loss/total": loss_value,
            "loss/recon": recon_loss.item(),
            "loss/kl": kl_loss.item(),
            "epoch": epoch + 1
        })

        pbar.set_description(
            f"Epoch {epoch+1} | Loss {loss_value:.4f}"
        )

    if (epoch + 1) % config['sample_every'] == 0:
        # Sample in eval mode and without autograd. Tensor.numpy() below raises
        # on tensors that require grad, in case model.sample() does not disable
        # autograd itself. Eval mode keeps train-only layers (if the VAE has
        # any) out of sampling. Restore train mode afterwards.
        model.eval()
        with torch.no_grad():
            samples = model.sample(64, device)
        model.train()

        # Samples are assumed to lie in [-1, 1]. Map them to [0, 1] and clamp,
        # so small overshoots do not distort the logged image.
        normalized_samples = ((samples + 1) / 2).clamp(0, 1)

        # make_grid expects [N, C, H, W] and returns [C, H_grid, W_grid].
        grid = make_grid(normalized_samples.cpu(), nrow=8, padding=2, normalize=False)

        # wandb.Image takes an HWC array: [C, H, W] -> [H, W, C].
        grid_np = grid.permute(1, 2, 0).numpy()

        wandb.log(
            {
                "samples": wandb.Image(grid_np, caption=f"Epoch {epoch+1}")
            }
        )
        show_samples(samples)

    # Checkpoint periodically and always after the last epoch, so the final
    # weights are never lost.
    is_last_epoch = (epoch + 1) == config['epochs']
    if (epoch + 1) % config['save_every'] == 0 or is_last_epoch:
        ckpt_path = os.path.join('checkpoints', f"vae_epoch{epoch+1}.pt")
        torch.save(
            {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
            },
            ckpt_path,
        )

## 5. Evaluation

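Visualize reconstructions of training images, then compute an approximate FID (Fréchet Inception Distance) between generated samples and real CIFAR-10 training images.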

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision.models import inception_v3
from torchvision.transforms import Resize
from scipy.linalg import sqrtm

# -------------------------
# Reconstruction visualization
# -------------------------
def show_reconstructions(model, dataloader, device, n=8):
    """Plot the first `n` images of one batch above their VAE reconstructions.

    Args:
        model: VAE whose forward pass returns ``(x_recon, mu, log_var)``.
        dataloader: Iterable of ``(images, labels)`` batches. Images are
            assumed to be scaled to [-1, 1]; they are mapped back with
            ``(x + 1) / 2``.
        device: Device the images are moved to for the forward pass.
        n: Number of image pairs to show. Capped at the batch size.

    Returns:
        The matplotlib Figure. Row 0 holds originals, row 1 reconstructions.

    Raises:
        ValueError: if `n` is not positive.
    """
    if n < 1:
        raise ValueError(f"n must be positive, got {n}")

    model.eval()
    images, _ = next(iter(dataloader))
    n = min(n, images.size(0))  # the batch may hold fewer than n images
    images = images[:n].to(device)

    with torch.no_grad():
        x_recon, _, _ = model(images)

    # Map [-1, 1] -> [0, 1] and clamp, so imshow does not clip with warnings.
    originals = ((images + 1) / 2).clamp(0, 1).cpu()
    recons = ((x_recon + 1) / 2).clamp(0, 1).cpu()

    # squeeze=False keeps `axes` 2-D even when n == 1.
    fig, axes = plt.subplots(2, n, figsize=(2 * n, 4), squeeze=False)

    rows = ((originals, "Original"), (recons, "Reconstruction"))
    for row, (row_images, label) in enumerate(rows):
        for i in range(n):
            ax = axes[row, i]
            ax.imshow(row_images[i].permute(1, 2, 0).numpy())
            # Hide ticks and frame but leave the axis on: axis("off") would
            # also hide the row label set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
        axes[row, 0].set_ylabel(label)

    plt.tight_layout()
    plt.show()
    return fig


# -------------------------
# Sample generation
# -------------------------
def generate_samples(model, n_samples, device, batch_size=64):
    """Decode `n_samples` latents drawn from N(0, I) and return images in [0, 1].

    Args:
        model: VAE exposing ``latent_dim`` and a ``decoder`` module. The decoder
            output is assumed to lie in [-1, 1] (TODO confirm against the VAE).
        n_samples: Exact number of images to generate.
        device: Device on which latents are drawn and decoded.
        batch_size: Maximum number of latents decoded per forward pass.

    Returns:
        Tensor with exactly `n_samples` images on `device`, mapped from
        [-1, 1] to [0, 1] and clamped.

    Raises:
        ValueError: if `n_samples` or `batch_size` is not positive.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be positive, got {n_samples}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    model.eval()
    samples = []

    with torch.no_grad():
        # The last batch may be smaller, so exactly n_samples images are
        # produced even when n_samples is not a multiple of batch_size.
        for start in range(0, n_samples, batch_size):
            current = min(batch_size, n_samples - start)
            z = torch.randn(current, model.latent_dim, device=device)
            samples.append(model.decoder(z))

    samples = torch.cat(samples, dim=0)
    return ((samples + 1) / 2).clamp(0, 1)  # [-1, 1] -> [0, 1]


# -------------------------
# Inception feature extraction
# -------------------------
def get_inception_features(images, device, batch_size=32):
    """Extract pooled InceptionV3 features (the input of its final `fc` layer) for FID.

    Args:
        images: Image tensor of shape (N, 3, H, W). Values are assumed to lie
            in [0, 1]; both callers in this notebook map to [0, 1] first.
            Confirm this for any new caller.
        device: Device used for the forward passes.
        batch_size: Number of images per forward pass.

    Returns:
        NumPy array of shape (N, 2048).

    Raises:
        ValueError: if `images` is empty.
    """
    # Local import: only this function needs the weights enum (torchvision >= 0.13).
    from torchvision.models import Inception_V3_Weights

    if len(images) == 0:
        raise ValueError("images must contain at least one image")

    # `weights=` replaces the deprecated `pretrained=True`. With
    # transform_input=False the model applies no input scaling of its own.
    inception = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1, transform_input=False)
    inception.fc = torch.nn.Identity()  # return pooled features, not class logits
    inception.to(device)
    inception.eval()

    resize = Resize((299, 299))
    features = []

    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = resize(images[start:start + batch_size].to(device))
            # The ported ImageNet weights expect inputs scaled to [-1, 1].
            # torchvision only does this itself when transform_input=True.
            batch = batch * 2 - 1
            features.append(inception(batch).cpu())

    return torch.cat(features, dim=0).numpy()


# -------------------------
# FID computation
# -------------------------
def calculate_fid(real_feats, fake_feats, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two feature sets.

    FID = ||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 (S_r S_f)^(1/2))

    Args:
        real_feats: Array of shape (N, D), with N >= 2.
        fake_feats: Array of shape (M, D), with M >= 2.
        eps: Diagonal offset added to both covariances when the matrix square
            root is not finite. This happens with rank-deficient covariances,
            e.g. fewer samples than feature dimensions.

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: if the inputs are not 2-D, their feature dimensions
            differ, or either set has fewer than 2 samples.
    """
    real_feats = np.asarray(real_feats, dtype=np.float64)
    fake_feats = np.asarray(fake_feats, dtype=np.float64)
    if real_feats.ndim != 2 or fake_feats.ndim != 2:
        raise ValueError("features must be 2-D arrays of shape (n_samples, dim)")
    if real_feats.shape[1] != fake_feats.shape[1]:
        raise ValueError(
            f"feature dims differ: {real_feats.shape[1]} vs {fake_feats.shape[1]}"
        )
    if min(len(real_feats), len(fake_feats)) < 2:
        raise ValueError("need at least 2 samples per set to estimate a covariance")

    # np.cov returns a 0-d array when D == 1; atleast_2d keeps the `@` below valid.
    mu_r = real_feats.mean(axis=0)
    sigma_r = np.atleast_2d(np.cov(real_feats, rowvar=False))
    mu_f = fake_feats.mean(axis=0)
    sigma_f = np.atleast_2d(np.cov(fake_feats, rowvar=False))

    diff = mu_r - mu_f
    covmean = sqrtm(sigma_r @ sigma_f)
    if not np.isfinite(covmean).all():
        # Singular product: regularise both covariances and retry.
        offset = np.eye(sigma_r.shape[0]) * eps
        covmean = sqrtm((sigma_r + offset) @ (sigma_f + offset))

    # Numerical error can leave tiny imaginary parts; only the real part matters.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma_r) + np.trace(sigma_f) - 2 * np.trace(covmean)
    return float(fid)


# -------------------------
# Run evaluation
# -------------------------
# Number of real and generated images used for FID. FID is biased upward at
# small sample sizes, so treat this value as a relative score (the standard
# benchmark setting uses 50k images).
N_FID_SAMPLES = 1024

fig = show_reconstructions(model, train_loader, device)

# Collect as many real images as generated ones, so both covariance estimates
# use equal sample sizes; a single batch is far too few for 2048-d features.
# Loader images are assumed to be in [-1, 1], as in training.
real_batches, n_real = [], 0
for batch_images, _ in train_loader:
    real_batches.append(batch_images)
    n_real += batch_images.size(0)
    if n_real >= N_FID_SAMPLES:
        break
real_images = (torch.cat(real_batches, dim=0)[:N_FID_SAMPLES] + 1) / 2

fake_images = generate_samples(model, N_FID_SAMPLES, device)

real_feats = get_inception_features(real_images, device)
fake_feats = get_inception_features(fake_images, device)

fid = calculate_fid(real_feats, fake_feats)
print("FID:", fid)

# Log both results in one call so they share a single W&B step.
wandb.log({
    "FID": fid,
    "reconstructions": wandb.Image(fig)
})
plt.close(fig)  # release the figure once it has been logged
wandb.finish()