In [None]:
import torch
import torchvision
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
import torchvision.transforms as transforms

# Choose the compute backend in priority order: CUDA first, then Apple MPS,
# and finally plain CPU when neither accelerator is available.
if torch.cuda.is_available():
    backend_name = "cuda"
elif torch.backends.mps.is_available():
    backend_name = "mps"
else:
    backend_name = "cpu"

device = torch.device(backend_name)
print(f'Using device: {device}')

In [None]:
from AE import ConvAE, train_AE
from diffusion import DiffusionNet, train_diffusion, sample, compute_latent_stats

In [None]:
# The only preprocessing step is the ToTensor conversion.
to_tensor = transforms.ToTensor()
transform = transforms.Compose([to_tensor])

# MNIST training split, rooted at "mnist/" with downloading enabled.
mnist_kwargs = dict(
    root="mnist/",
    train=True,
    download=True,
    transform=transform,
)
dataset = torchvision.datasets.MNIST(**mnist_kwargs)

In [None]:
# Small 8-image loader used for the preview in the next cell; the same name is
# rebound to the full training loader further down.
train_dataloader = DataLoader(
    dataset=dataset,
    batch_size=8,
)

In [None]:
# Pull a single batch from the loader and report its shape and labels.
first_batch = next(iter(train_dataloader))
x, y = first_batch
print('Input shape:', x.shape)
print('Labels:', y)

# make_grid arranges the batch into one image; channel 0 of that grid is shown.
image_grid = torchvision.utils.make_grid(x)
plt.imshow(image_grid[0], cmap='Greys');

In [None]:
# Training loader: large shuffled batches served by 8 persistent worker processes.
# Pinned host memory is requested only when running on CUDA. This notebook also
# selects "mps" or "cpu" as the device, and on those backends pin_memory=True
# brings no benefit and makes PyTorch emit a warning.
train_dataloader = DataLoader(dataset, batch_size=2048, shuffle=True,
                              pin_memory=(device.type == "cuda"),
                              num_workers=8, persistent_workers=True)

In [None]:
# --- Hyperparameters ---------------------------------------------------------

# Autoencoder
latent_channels = 64   # passed to ConvAE and reused as the diffusion token_dim
ae_epochs = 20
lr_ae = 1e-3

# Diffusion transformer
T = 100                # diffusion timesteps
n_layers = 4           # transformer layers
n_heads = 4            # attention heads (hidden_size 128 / 4 heads = 32 per head)
mlp_size = 512         # FFN intermediate size
dropout_rate = 0.05
diff_epochs = 200
lr_diff = 5e-4

In [None]:
# Build the autoencoder and report the token count / token width it exposes.
model_AE = ConvAE(latent_channels=latent_channels)
ae_n_tokens, ae_token_dim = model_AE.n_tokens, model_AE.token_dim
print(f"ConvAE: {ae_n_tokens} tokens x {ae_token_dim} dim")

In [None]:
# Fit the autoencoder with train_AE (imported from AE.py). The positional
# arguments are the model, ae_epochs and the training loader; lr and device are
# passed by keyword. NOTE(review): presumably runs ae_epochs passes over the
# loader and updates model_AE in place — confirm against AE.train_AE.
train_AE(model_AE, ae_epochs, train_dataloader, lr=lr_ae, device=device)

In [None]:
# Compare a handful of inputs with their autoencoder reconstructions.
model_AE.eval()
with torch.no_grad():
    batch_images, _ = next(iter(train_dataloader))
    x_sample = batch_images[:8].to(device)
    recon_raw, _ = model_AE(x_sample)
    # The model output is passed through a sigmoid before display
    # (presumably the decoder emits logits — confirm against ConvAE).
    recon = torch.sigmoid(recon_raw)

# Top row: originals. Bottom row: reconstructions. One column per image.
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
image_rows = (x_sample, recon)
for col in range(8):
    for row, images in enumerate(image_rows):
        ax = axes[row, col]
        ax.imshow(images[col, 0].detach().cpu(), cmap='Greys')
        ax.axis('off')
axes[0, 0].set_title('Original')
axes[1, 0].set_title('Reconstructed')
plt.tight_layout()

In [None]:
# Latent statistics used for normalization; they are handed to both diffusion
# training and sampling further down.
latent_stats = compute_latent_stats(model_AE, train_dataloader, device)
latent_mean, latent_std = latent_stats
print(f"Latent mean range: [{latent_mean.min():.3f}, {latent_mean.max():.3f}]")
print(f"Latent std range:  [{latent_std.min():.3f}, {latent_std.max():.3f}]")

In [None]:
# Denoiser configuration; token_dim is tied to the autoencoder's latent_channels.
diffusion_config = dict(
    token_dim=latent_channels,
    hidden_size=128,  # internal transformer width
    n_layers=n_layers,
    n_heads=n_heads,
    mlp_size=mlp_size,
    dropout_rate=dropout_rate,
    T=T,
)
diff_model = DiffusionNet(**diffusion_config)
print(f"DiffusionNet params: {sum(p.numel() for p in diff_model.parameters()):,}")

In [None]:
# Train the diffusion model with train_diffusion (imported from diffusion.py).
# It receives the denoiser, the autoencoder, diff_epochs, the training loader
# and T positionally, plus the learning rate, latent statistics and device by
# keyword. NOTE(review): presumably the autoencoder encodes each batch into
# latents that are normalized with latent_mean / latent_std — confirm against
# diffusion.train_diffusion.
train_diffusion(diff_model, model_AE, diff_epochs, train_dataloader, T, lr=lr_diff,
                latent_mean=latent_mean, latent_std=latent_std, device=device)

In [None]:
# Draw 16 samples, passing the same T and latent statistics used for training.
sampling_kwargs = dict(
    n_samples=16,
    T=T,
    n_tokens=model_AE.n_tokens,
    token_dim=model_AE.token_dim,
    latent_mean=latent_mean,
    latent_std=latent_std,
    device=device,
)
generated = sample(diff_model, model_AE, **sampling_kwargs)

# Show the samples on a 2 x 8 grid, filled row by row.
fig, axes = plt.subplots(2, 8, figsize=(16, 4))
for idx, ax in enumerate(axes.flat):
    ax.imshow(generated[idx, 0].detach().cpu(), cmap='Greys')
    ax.axis('off')
plt.suptitle('Generated MNIST Digits')
plt.tight_layout()