# VAE Training on CIFAR-10
**Variational Autoencoder**


## 1. Setup

In [None]:
# Fetch the project sources and switch into the checkout; the `src.*` imports
# further down presumably rely on this being the working directory.
# NOTE: YOUR_USERNAME is a placeholder - replace it with the GitHub account
# that actually hosts ML2_final before running this cell.
!git clone https://github.com/YOUR_USERNAME/ML2_final.git
%cd ML2_final

In [None]:
# Install extra packages quietly (-q). tqdm provides the progress bar used in
# the training loop.
# NOTE(review): wandb is installed here but is not imported in the visible cells.
!pip install wandb tqdm -q

In [None]:
import torch
import torch.optim as optim
import os
from tqdm.auto import tqdm

from src.models import VAE
from src.losses import vae_loss
from src.data import get_dataloader, denormalize
from src.utils import show_samples, save_samples

## 2. Configuration

In [None]:
# Run settings, read by the cells below via config[...].
config = dict(
    # Model
    latent_dim=128,
    beta=1.0,  # weight applied to the KL term (beta-VAE)

    # Optimisation
    epochs=200,
    batch_size=128,
    lr=1e-4,

    # Cadence settings for sampling / saving
    sample_every=10,
    save_every=25,

    # Seed passed to torch.manual_seed
    seed=42,
)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = (
    torch.device('cuda')
    if torch.cuda.is_available()
    else torch.device('cpu')
)

## 3. Initialize

In [None]:
# Seed PyTorch's RNG before anything random (e.g. model construction) happens.
torch.manual_seed(config['seed'])

# Output folders; exist_ok makes re-running this cell harmless.
for _out_dir in ('checkpoints', 'samples'):
    os.makedirs(_out_dir, exist_ok=True)

train_loader = get_dataloader(batch_size=config['batch_size'])

# Build the VAE, move it to the selected device, and let Adam optimise all of
# its parameters.
model = VAE(latent_dim=config['latent_dim'])
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=config['lr'])

## 4. Training


In [None]:
# Training loop: one optimizer step per mini-batch, a sample preview every
# config['sample_every'] epochs and a checkpoint every config['save_every']
# epochs.
model.train()

for epoch in range(config['epochs']):
    pbar = tqdm(train_loader)
    for images, _ in pbar:  # the second element of each batch is not used
        images = images.to(device)

        x_recon, mu, log_var = model(images)
        # Only `loss` is back-propagated; recon_loss / kl_loss are presumably
        # its reconstruction and KL components and are kept for inspection.
        loss, recon_loss, kl_loss = vae_loss(
            images, x_recon, mu, log_var, config['beta']
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_description(
            f"Epoch {epoch+1} | Loss {loss.item():.4f}"
        )

    if (epoch + 1) % config['sample_every'] == 0:
        # Preview samples in eval mode and without autograd: layers such as
        # dropout / batch-norm (if the model has any) then behave as at
        # inference time and no graph is built. Training mode is restored
        # afterwards so the next epoch is unaffected.
        model.eval()
        with torch.no_grad():
            samples = model.sample(64, device)
        model.train()
        show_samples(samples)

    if (epoch + 1) % config['save_every'] == 0:
        # Persist model and optimizer state so a long run can be resumed or
        # evaluated later; the 'checkpoints' directory is created in the
        # initialisation cell.
        torch.save(
            {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
            },
            os.path.join('checkpoints', f'vae_epoch_{epoch + 1}.pt'),
        )


## 5. Evaluation

FID calculation, reconstruction visualization

In [None]:
def generate_samples(model, n_samples, device, batch_size=64):
    """Generate exactly ``n_samples`` outputs by decoding standard-normal latents.

    The model is switched to eval mode (and left in eval mode) and decoding
    runs under ``torch.no_grad()``.

    Args:
        model: Module exposing ``latent_dim`` and a ``decoder`` callable.
        n_samples: Number of samples to generate; must be positive.
        device: Device the latent vectors are moved to before decoding.
        batch_size: Maximum number of latents decoded per forward pass.

    Returns:
        Tensor whose first dimension is ``n_samples``, holding the decoder
        outputs mapped through ``(x + 1) / 2``.

    Raises:
        ValueError: If ``n_samples`` or ``batch_size`` is not positive.
    """
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    model.eval()
    samples = []

    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            # The last batch may be smaller so that a request that is not a
            # multiple of batch_size is not silently truncated.
            current = min(batch_size, n_samples - start)
            z = torch.randn(current, model.latent_dim).to(device)
            samples.append(model.decoder(z))

    samples = torch.cat(samples, dim=0)
    # Assumes the decoder outputs values in [-1, 1]; map them to [0, 1].
    samples = (samples + 1) / 2
    return samples

from torchvision.models import inception_v3
from torchvision.transforms import Resize

def get_inception_features(images, device, batch_size=32):
    """Extract pooled Inception-v3 features for a batch of images.

    A pretrained torchvision Inception v3 is loaded, its final ``fc`` layer is
    replaced by an identity so the forward pass returns the pooled features
    (2048-dimensional for torchvision's Inception v3), and the images are
    pushed through it in mini-batches.

    Args:
        images: Tensor of images, indexed along the first dimension. Callers
            in this file pass values scaled to [0, 1]; presumably the layout
            is (N, 3, H, W) - confirm against the data loader.
        device: Device on which the network is run.
        batch_size: Number of images per forward pass.

    Returns:
        CPU tensor with one feature row per input image.

    Raises:
        ValueError: If ``images`` is empty.
    """
    if len(images) == 0:
        raise ValueError("images must contain at least one sample")

    # NOTE: the network is re-created on every call; hoist it out if this
    # function is called many times.
    model = inception_v3(pretrained=True, transform_input=False)
    model.fc = torch.nn.Identity()
    model.to(device)
    model.eval()

    resize = Resize((299, 299))
    features = []

    with torch.no_grad():
        for i in range(0, len(images), batch_size):
            batch = images[i:i + batch_size].to(device)
            batch = resize(batch)
            # With transform_input=False torchvision does no input conversion,
            # and the pretrained weights expect inputs scaled to [-1, 1]
            # (i.e. (x - 0.5) / 0.5), so map the [0, 1] images accordingly.
            batch = batch * 2 - 1
            feat = model(batch)
            features.append(feat.cpu())

    return torch.cat(features, dim=0)

import numpy as np
from scipy.linalg import sqrtm

def calculate_fid(real_feats, fake_feats, eps=1e-6):
    """Compute the Frechet distance between two sets of feature vectors.

    Each set is summarised by its mean and covariance, and the result is
    ``||mu_r - mu_f||^2 + Tr(S_r + S_f - 2 * sqrt(S_r S_f))``.

    Args:
        real_feats: Array-like of shape (n_real, dim), one row per sample.
        fake_feats: Array-like of shape (n_fake, dim), one row per sample.
        eps: Diagonal offset added to both covariances when the matrix square
            root of their product is not finite.

    Returns:
        The distance as a Python float.

    Raises:
        ValueError: If an input is not 2-D, the feature dimensions differ, or
            either set has fewer than two samples (covariance undefined).
    """
    real_feats = np.asarray(real_feats, dtype=np.float64)
    fake_feats = np.asarray(fake_feats, dtype=np.float64)

    if real_feats.ndim != 2 or fake_feats.ndim != 2:
        raise ValueError("features must be 2-D arrays of shape (n_samples, dim)")
    if real_feats.shape[1] != fake_feats.shape[1]:
        raise ValueError("real and fake features must have the same dimension")
    if real_feats.shape[0] < 2 or fake_feats.shape[0] < 2:
        raise ValueError("at least two samples per set are required")

    mu_r = real_feats.mean(axis=0)
    mu_f = fake_feats.mean(axis=0)
    # np.cov returns a 0-d array when dim == 1; force a matrix so the matrix
    # product and trace below work for any feature dimension.
    sigma_r = np.atleast_2d(np.cov(real_feats, rowvar=False))
    sigma_f = np.atleast_2d(np.cov(fake_feats, rowvar=False))

    diff = mu_r - mu_f
    covmean = sqrtm(sigma_r @ sigma_f)

    if not np.isfinite(covmean).all():
        # A (near-)singular product can make sqrtm blow up; retry with a
        # small ridge on both covariances.
        offset = np.eye(sigma_r.shape[0]) * eps
        covmean = sqrtm((sigma_r + offset) @ (sigma_f + offset))

    # Numerical error can leave a tiny imaginary part; keep the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    fid = diff @ diff + np.trace(sigma_r + sigma_f - 2 * covmean)
    return float(fid)

# Compare equally sized sets of real and generated images. A single loader
# batch is far smaller than the generated set, so gather real batches until
# n_eval images are available (or the loader is exhausted).
# NOTE: FID estimates are noisy for small sample sizes; increase n_eval for
# numbers meant to be reported.
n_eval = 1024

real_batches = []
n_collected = 0
for batch, _ in train_loader:
    real_batches.append(batch)
    n_collected += len(batch)
    if n_collected >= n_eval:
        break

real_images = torch.cat(real_batches, dim=0)[:n_eval]
# Assumes the loader yields images normalised to [-1, 1]; map them to [0, 1]
# to match the range of the generated samples.
real_images = (real_images + 1) / 2

fake_images = generate_samples(model, n_eval, device)

real_feats = get_inception_features(real_images, device).numpy()
fake_feats = get_inception_features(fake_images, device).numpy()

fid = calculate_fid(real_feats, fake_feats)
print("FID:", fid)
