# Kintsugi Demo

Generate kintsugi overlays for out-of-distribution (OOD) perturbations of MNIST digits: occlusion, Gaussian noise, rotation, and digit mixing.

In [None]:
from pathlib import Path

import torch
import numpy as np
from torch import optim
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from PIL import Image

from src.models.vae import KintsugiVAE, vae_loss
from src.viz.kintsugi import (
    add_gaussian_noise,
    add_occlusion,
    create_kintsugi_grid,
    mix_digits,
    rotate_image,
)

# Seed the global torch and numpy RNGs once, after all imports, so that
# everything below drawing from them (e.g. the shuffled training DataLoader)
# is reproducible from run to run.
torch.manual_seed(42)
np.random.seed(42)

# True: quick sanity run on a small training subset for a couple of epochs.
# False: full training run.
SMOKE_TEST = True

results_dir = Path('results')
results_dir.mkdir(exist_ok=True)

# The smoke-test model gets its own checkpoint file. Weights are reloaded
# whenever the checkpoint file merely exists, so a shared filename would let
# a briefly-trained smoke model be silently reused once SMOKE_TEST is False.
checkpoint_name = 'vae_mnist_smoke.pt' if SMOKE_TEST else 'vae_mnist.pt'
checkpoint_path = results_dir / checkpoint_name

device = torch.device('cpu')
model = KintsugiVAE(z_dim=20).to(device)

if checkpoint_path.exists():
    # The checkpoint is a plain state_dict (see torch.save below), so restrict
    # unpickling to tensors/containers. weights_only=True refuses arbitrary
    # pickled objects and avoids the torch.load FutureWarning on recent PyTorch.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
else:
    # No cached weights: train from scratch on MNIST, then cache the result.
    transform = transforms.Compose([transforms.ToTensor()])
    train_dataset = datasets.MNIST(
        root='data', train=True, download=True, transform=transform
    )
    if SMOKE_TEST:
        # Quick sanity run: first 1,000 training images, 2 epochs.
        train_dataset = Subset(train_dataset, list(range(1000)))
        epochs = 2
    else:
        epochs = 30
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Nothing in the loop switches modes, so setting train mode once suffices.
    model.train()
    for _ in range(epochs):
        for batch, _labels in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            recon, mu, log_var = model(batch)
            loss = vae_loss(recon, batch, mu, log_var)
            loss.backward()
            optimizer.step()
    torch.save(model.state_dict(), checkpoint_path)

# Re-seed before drawing evaluation data. The training branch above consumes
# torch RNG state (shuffled DataLoader) but the checkpoint-loading branch does
# not. Without this, the sampled test digits, and any randomness inside the
# perturbation helpers (presumably torch/numpy based), would differ between
# the first run and later runs that reuse the cached checkpoint.
torch.manual_seed(42)
np.random.seed(42)

n_images = 4
transform = transforms.Compose([transforms.ToTensor()])
test_dataset = datasets.MNIST(
    root='data', train=False, download=True, transform=transform
)
# Shuffled loader: the first batch is a random sample of n_images test digits.
test_loader = DataLoader(test_dataset, batch_size=n_images, shuffle=True)
base_images, _ = next(iter(test_loader))
n_base = len(base_images)

# Dict literal entries are evaluated in order, so the helpers are called
# in the same sequence as before: occlusion, noise, rotation, mix.
perturbations = {
    'occlusion': torch.stack([add_occlusion(img) for img in base_images]),
    'noise': torch.stack([add_gaussian_noise(img) for img in base_images]),
    'rotation': torch.stack([rotate_image(img, angle=45) for img in base_images]),
    # Pair each digit with the next one, wrapping around at the end.
    'mix': torch.stack([
        mix_digits(img, base_images[(i + 1) % n_base])
        for i, img in enumerate(base_images)
    ]),
}

# One kintsugi grid per perturbation type, in the dict's insertion order.
# n_mc_samples is forwarded to create_kintsugi_grid (presumably the number of
# stochastic samples drawn per image; see that helper).
rows = [
    create_kintsugi_grid(model, images, n_mc_samples=20, cols=4)
    for images in perturbations.values()
]

# Stack the per-perturbation grids vertically into a single gallery image.
# Size the canvas to the widest row so no grid is cropped if widths differ
# (previously only rows[0].width was used).
grid_width = max(row.width for row in rows)
grid_height = sum(row.height for row in rows)
gallery = Image.new('RGB', (grid_width, grid_height))
y_offset = 0
for row in rows:
    gallery.paste(row, (0, y_offset))
    y_offset += row.height

output_path = results_dir / 'kintsugi_gallery.png'
gallery.save(output_path)
print(f'Saved gallery to {output_path}')
