# VAE Training (MNIST)

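Set `SMOKE_TEST = True` for a lightweight run (the first 1,000 training images, 2 epochs); set it to `False` to train on the full training set for 30 epochs.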

In [None]:
from pathlib import Path

import torch
import numpy as np
from torch import optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms

from src.models.vae import KintsugiVAE, vae_loss
# --- Run configuration -------------------------------------------------------

# When True, the data cell trains on a small subset for only a few epochs.
SMOKE_TEST = True

# Seed NumPy's and PyTorch's global RNGs with the same fixed value.
np.random.seed(42)
torch.manual_seed(42)

# Output directory for the checkpoint; an already-existing directory is fine.
results_dir = Path('results')
results_dir.mkdir(parents=False, exist_ok=True)

# --- Data --------------------------------------------------------------------

# ToTensor is the only preprocessing step applied to each sample.
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split rooted at ./data (download=True is passed through).
train_dataset = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transform,
)

# Smoke test: 2 epochs on the first 1000 samples.
# Full run: 30 epochs on the entire training split.
epochs = 2 if SMOKE_TEST else 30
if SMOKE_TEST:
    subset_indices = list(range(1000))
    train_dataset = Subset(train_dataset, subset_indices)

# Reshuffled mini-batches of 128 samples.
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
)

# --- Model and optimizer -----------------------------------------------------

# Everything runs on the CPU.
device = torch.device('cpu')

# VAE constructed with z_dim=20, then moved to the chosen device.
model = KintsugiVAE(z_dim=20)
model = model.to(device)

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# --- Training ----------------------------------------------------------------
# Run `epochs` passes over `train_loader`; after each pass, print the sum of
# the per-batch loss values divided by the number of batches.
num_batches = len(train_loader)  # constant across epochs, so read it once

for epoch in range(epochs):
    model.train()
    total_loss = 0.0  # running sum of this epoch's per-batch loss values

    # Second element of each pair (presumably the label) is not used.
    for batch, _ in train_loader:
        batch = batch.to(device)

        # Forward pass: the model returns three values that, together with
        # the input batch, are handed to vae_loss.
        recon, mu, log_var = model(batch)
        loss = vae_loss(recon, batch, mu, log_var)

        # Backward pass: clear stale gradients, backpropagate, update weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / num_batches
    print(f'Epoch {epoch + 1}/{epochs} - Loss: {avg_loss:.4f}')

# --- Checkpoint --------------------------------------------------------------
# Save only the model's state_dict (not the whole model object) under the
# results directory, then report where it was written.
checkpoint_path = results_dir.joinpath('vae_mnist.pt')
torch.save(model.state_dict(), checkpoint_path)
print(f'Saved checkpoint to {checkpoint_path}')
