# First we import our CIFAR-10 dataset

In [None]:
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

DATA_DIR = "data"
BATCH_SIZE = 64

# Build the train split first, then the test split, with identical preprocessing
# (a single ToTensor step, no augmentation). Files are downloaded into DATA_DIR if missing.
train_dataset, test_dataset = (
    CIFAR10(root=DATA_DIR, train=is_train, transform=transforms.Compose([transforms.ToTensor()]), download=True)
    for is_train in (True, False)
)

# Training batches are shuffled and the incomplete final batch is dropped;
# the test loader keeps dataset order and every image.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

num_train_images, num_test_images = len(train_dataset), len(test_dataset)

# Summary of what was loaded (the shape is read from the first training image)
print(f"Loaded batches of size {BATCH_SIZE}:")
print(f" - {len(train_dataloader)} batches for training ({num_train_images} images)")
print(f" - {len(test_dataloader)} batches for validation ({num_test_images} images)")
print(f"for a total of {num_train_images + num_test_images} images (Shape: {train_dataset[0][0].shape}).")

In [None]:
import matplotlib.pyplot as plt


def show_some_images(num_images, num_rows, dataloader, classes):
    """Plot the first images of one batch in a grid, each titled with its class name.

    Args:
        num_images: Number of images to draw. Capped at the size of the batch
            fetched from ``dataloader``.
        num_rows: Number of grid rows; must be positive.
        dataloader: Iterable yielding ``(images, labels)`` batches. Each image is
            permuted from channel-first to channel-last before display.
        classes: Sequence mapping an integer label to its display name.

    Raises:
        ValueError: If ``num_rows`` is not positive.
    """
    if num_rows <= 0:
        raise ValueError(f"num_rows must be positive, got {num_rows}")

    images, labels = next(iter(dataloader))

    # The batch may hold fewer images than requested (small batch size).
    num_images = min(num_images, len(images))

    # Ceiling division: a partially filled last row must still fit in the grid.
    # With floor division, plt.subplot raised as soon as num_rows did not divide num_images.
    num_cols = max(1, -(-num_images // num_rows))

    plt.figure(figsize=(num_cols, num_rows))
    for i in range(num_images):
        img = images[i].permute(1, 2, 0).numpy()
        plt.subplot(num_rows, num_cols, i + 1)
        plt.imshow(img)
        plt.title(classes[labels[i].item()])
        plt.axis("off")
    plt.tight_layout()
    plt.show()


# Preview 16 images from one training batch in a 2-row grid, titled with the dataset's class names
show_some_images(16, 2, train_dataloader, train_dataset.classes)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from pathlib import Path

# Seed the PyTorch and NumPy global RNGs with the same value for reproducibility
for seed_rng in (torch.manual_seed, np.random.seed):
    seed_rng(42)

# Device configuration: use CUDA when PyTorch reports it as available, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Model Implementation - Core Components

In [None]:
# Helper functions
def zero_init(module: nn.Module) -> nn.Module:
    """Overwrite every parameter of ``module`` with zeros and return the same module.

    The zeroing is done in place on each parameter's underlying data, so the
    call can wrap a layer constructor inline.
    """
    for param in module.parameters():
        param.data.zero_()
    return module


def get_timestep_embedding(timesteps, embedding_dim: int, dtype=torch.float32, max_timescale=10_000, min_timescale=1):
    """Create sinusoidal timestep embeddings.

    Args:
        timesteps: 1-D tensor of timesteps. It is multiplied by 1000 before the
            sinusoids are applied; the input tensor itself is not modified.
        embedding_dim: Size of the returned embedding; must be even.
        dtype: Dtype the scaled timesteps are cast to before multiplying by the
            inverse timescales.
        max_timescale: Largest timescale; its inverse is the lowest frequency.
        min_timescale: Smallest timescale; its inverse is the highest frequency.

    Returns:
        Tensor of shape ``(len(timesteps), embedding_dim)``: the sine features
        followed by the cosine features along dim 1.
    """
    assert timesteps.ndim == 1
    assert embedding_dim % 2 == 0
    # Scale out of place: an in-place `*=` would silently change the caller's
    # tensor and raises outright for integer-typed timesteps.
    scaled_timesteps = timesteps * 1000.0
    num_timescales = embedding_dim // 2
    # Frequencies are log-spaced between 1/min_timescale and 1/max_timescale.
    inv_timescales = torch.logspace(
        -np.log10(min_timescale),
        -np.log10(max_timescale),
        num_timescales,
        device=timesteps.device,
    )
    emb = scaled_timesteps.to(dtype)[:, None] * inv_timescales[None, :]
    return torch.cat([emb.sin(), emb.cos()], dim=1)


def fourier_encode(x: torch.Tensor, num_frequencies: int = 7) -> torch.Tensor:
    """Append sine/cosine Fourier features to a 4-D input along the channel axis.

    For each k in ``range(num_frequencies)`` the input is multiplied by
    ``2**k * 2*pi`` and passed through sin and cos. The result concatenates the
    raw input, all sine features and all cosine features on dim 1, giving
    ``C * (1 + 2 * num_frequencies)`` channels.
    """
    batch, channels, height, width = x.shape
    exponents = torch.arange(num_frequencies, device=x.device, dtype=x.dtype)
    freqs = (2.0**exponents) * (2.0 * math.pi)

    # Insert a frequency axis after the channel axis: (B, C, F, H, W)
    angles = x.unsqueeze(2) * freqs.view(1, 1, -1, 1, 1)

    # Fold the frequency axis back into the channel axis
    folded_shape = (batch, channels * num_frequencies, height, width)
    features = [x, angles.sin().reshape(folded_shape), angles.cos().reshape(folded_shape)]
    return torch.cat(features, dim=1)

In [None]:
# ResNet Block
class ResnetBlock(nn.Module):
    """Residual block of two GroupNorm -> SiLU -> 3x3 conv stages with optional conditioning.

    When ``condition_dim`` is given, the conditioning vector is linearly
    projected (the projection starts at zero) and added to the feature map
    between the two stages. When the input and output channel counts differ,
    the skip path goes through a 1x1 convolution.
    """

    def __init__(self, in_channels=128, out_channels=None, condition_dim=None, norm_num_groups=32):
        super().__init__()
        if not out_channels:
            out_channels = in_channels
        self.out_channels = out_channels
        self.condition_dim = condition_dim

        self.net1 = self._norm_act_conv(norm_num_groups, in_channels, out_channels)
        self.net2 = self._norm_act_conv(norm_num_groups, out_channels, out_channels)

        # The skip-path projection only exists when the channel count changes
        if in_channels != out_channels:
            self.shortcut_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        if condition_dim is None:
            self.cond_proj = None
        else:
            self.cond_proj = zero_init(nn.Linear(condition_dim, out_channels))

    @staticmethod
    def _norm_act_conv(num_groups, channels_in, channels_out):
        """Build one GroupNorm -> SiLU -> 3x3 same-padding convolution stage."""
        return nn.Sequential(
            nn.GroupNorm(num_groups, channels_in),
            nn.SiLU(),
            nn.Conv2d(channels_in, channels_out, kernel_size=3, padding=1),
        )

    def forward(self, x, condition):
        """Return ``skip(x) + net2(net1(x) + projected condition)``.

        ``condition`` is ignored when the block was built without ``condition_dim``.
        """
        hidden = self.net1(x)
        if self.cond_proj is not None:
            # Broadcast the per-sample, per-channel bias over the spatial dims
            hidden = hidden + self.cond_proj(condition)[:, :, None, None]
        hidden = self.net2(hidden)

        skip = x if x.shape[1] == self.out_channels else self.shortcut_conv(x)
        return skip + hidden

In [None]:
# Attention Block
def attention_inner_heads(qkv, num_heads):
    """Run scaled dot-product attention on a packed ``(batch, 3 * heads * ch, length)`` tensor.

    The channel dimension is split into three equal chunks (queries, keys,
    values); each chunk is then viewed as ``batch * num_heads`` independent
    heads of ``ch`` channels. Returns a ``(batch, heads * ch, length)`` tensor.
    """
    batch, packed_channels, seq_len = qkv.shape
    head_dim = packed_channels // (3 * num_heads)
    per_head_shape = (batch * num_heads, head_dim, seq_len)

    queries, keys, values = qkv.chunk(3, dim=1)

    # Scale q and k by ch**-0.25 each, so their product carries the usual ch**-0.5
    scale = head_dim ** (-1 / 4)
    queries = (queries * scale).view(*per_head_shape)
    keys = (keys * scale).view(*per_head_shape)
    values = values.reshape(*per_head_shape)

    scores = torch.einsum("bct,bcs->bts", queries, keys)
    # Softmax runs in float32, then is cast back to the score dtype
    probs = F.softmax(scores.float(), dim=-1).to(scores.dtype)
    attended = torch.einsum("bts,bcs->bct", probs, values)
    return attended.reshape(batch, num_heads * head_dim, seq_len)


class Attention(nn.Module):
    """Parameter-free attention over all spatial positions of a packed qkv feature map."""

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """Flatten the trailing spatial dims, attend across them, and restore the layout."""
        batch, channels, *spatial_dims = qkv.shape
        flattened = qkv.view(batch, channels, -1)
        attended = attention_inner_heads(flattened, self.n_heads)
        out_batch, out_channels = attended.shape[:2]
        return attended.view(out_batch, out_channels, *spatial_dims)


class AttentionBlock(nn.Module):
    """Residual self-attention block: GroupNorm -> 1x1 qkv conv -> attention -> 1x1 output conv."""

    def __init__(self, n_heads, n_channels, norm_groups):
        super().__init__()
        normalize = nn.GroupNorm(num_groups=norm_groups, num_channels=n_channels)
        to_qkv = nn.Conv2d(n_channels, 3 * n_channels, kernel_size=1)
        attend = Attention(n_heads)
        # The output projection is zeroed, so the residual branch contributes nothing at initialization
        project_out = zero_init(nn.Conv2d(n_channels, n_channels, kernel_size=1))
        self.layers = nn.Sequential(normalize, to_qkv, attend, project_out)

    def forward(self, x):
        """Return the attention branch output added to the input."""
        attended = self.layers(x)
        return attended + x

In [None]:
# Up/Down Block
class UpDownBlock(nn.Module):
    """Pair a conditioned ResNet block with an optional attention block applied after it."""

    def __init__(self, resnet_block, attention_block=None):
        super().__init__()
        self.resnet_block = resnet_block
        self.attention_block = attention_block

    def forward(self, x, cond):
        """Apply the ResNet block with ``cond``, then the attention block if one was given."""
        features = self.resnet_block(x, cond)
        if self.attention_block is None:
            return features
        return self.attention_block(features)

In [None]:
# U-Net Model
class UNet(nn.Module):
    """Constant-resolution residual U-Net conditioned on a noise level ``g_t``.

    The network has ``n_blocks`` conditioned ResNet blocks on the way "down",
    a ResNet / attention / ResNet middle section, and ``n_blocks + 1`` ResNet
    blocks on the way "up". Each up block consumes one saved skip activation.
    No layer changes the spatial size (all convolutions are stride-1 with
    same-padding).

    Args:
        embedding_dim: Channel width of the network and size of the sinusoidal
            embedding. Must be compatible with the 32 GroupNorm groups used.
        n_blocks: Number of down blocks (there is one more up block).
        input_channels: Channels of the input and of the prediction.
        use_fourier: If True, the input is expanded with Fourier features.
        num_fourier_features: Number of Fourier frequencies per channel.
        gamma_min: Lower end of the gamma range used to normalise ``g_t``.
        gamma_max: Upper end of the gamma range used to normalise ``g_t``.
            Both should match the schedule that produces ``g_t``; the defaults
            reproduce the previously hard-coded range.
    """

    def __init__(
        self,
        embedding_dim=128,
        n_blocks=32,
        input_channels=3,
        use_fourier=True,
        num_fourier_features=7,
        gamma_min=-13.3,
        gamma_max=5.0,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.n_blocks = n_blocks
        self.use_fourier = use_fourier
        self.num_fourier_features = num_fourier_features
        self.gamma_min = gamma_min
        self.gamma_max = gamma_max

        attention_params = dict(n_heads=1, n_channels=embedding_dim, norm_groups=32)
        resnet_params = dict(
            in_channels=embedding_dim, out_channels=embedding_dim, condition_dim=4 * embedding_dim, norm_num_groups=32
        )

        # MLP that lifts the sinusoidal embedding to the conditioning vector
        self.embed_conditioning = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 4),
            nn.SiLU(),
            nn.Linear(embedding_dim * 4, embedding_dim * 4),
            nn.SiLU(),
        )

        # Fourier encoding keeps the raw channels and adds sin + cos per frequency
        total_input_ch = input_channels * (1 + 2 * num_fourier_features) if use_fourier else input_channels
        self.input_conv = nn.Conv2d(total_input_ch, embedding_dim, kernel_size=3, padding=1)

        self.down_blocks = nn.ModuleList([UpDownBlock(ResnetBlock(**resnet_params)) for _ in range(n_blocks)])

        self.mid_resnet_block_1 = ResnetBlock(**resnet_params)
        self.mid_attn_block = AttentionBlock(**attention_params)
        self.mid_resnet_block_2 = ResnetBlock(**resnet_params)

        # Up blocks take the current features concatenated with a skip activation
        resnet_params["in_channels"] *= 2
        self.up_blocks = nn.ModuleList([UpDownBlock(ResnetBlock(**resnet_params)) for _ in range(n_blocks + 1)])

        # The final convolution is zeroed, so the prediction branch is zero at initialization
        self.output_conv = nn.Sequential(
            nn.GroupNorm(32, embedding_dim),
            nn.SiLU(),
            zero_init(nn.Conv2d(embedding_dim, input_channels, kernel_size=3, padding=1)),
        )

    def forward(self, z, g_t):
        """Return the network prediction for latent ``z`` at noise level ``g_t``.

        Args:
            z: Latent batch with ``input_channels`` channels.
            g_t: Gamma value(s); expanded to one value per sample in ``z``.

        Returns:
            Tensor shaped like ``z``: the output convolution's result plus ``z``.
        """
        g_t = g_t.expand(z.shape[0])
        # Normalise gamma to [0, 1] over the configured range before embedding
        t = (g_t - self.gamma_min) / (self.gamma_max - self.gamma_min)
        t_embedding = get_timestep_embedding(t, self.embedding_dim)
        condition = self.embed_conditioning(t_embedding)

        z_in = fourier_encode(z, self.num_fourier_features) if self.use_fourier else z
        h = self.input_conv(z_in)

        # Save the input of every down block, plus the final down output,
        # giving n_blocks + 1 skips for the n_blocks + 1 up blocks.
        skip_connections = []
        for down_block in self.down_blocks:
            skip_connections.append(h)
            h = down_block(h, condition)

        skip_connections.append(h)
        h = self.mid_resnet_block_1(h, condition)
        h = self.mid_attn_block(h)
        h = self.mid_resnet_block_2(h, condition)

        # Skips are consumed in reverse order of creation
        for up_block in self.up_blocks:
            h = torch.cat([h, skip_connections.pop()], dim=1)
            h = up_block(h, condition)

        prediction = self.output_conv(h)
        return prediction + z


# Cell-completion marker shown in the notebook output
print("UNet model defined")

In [None]:
# Variational Diffusion Model
class VDM(nn.Module):
    """Variational diffusion model around a noise-prediction network.

    The noise level is described by ``gamma``: the forward process uses
    ``alpha^2 = sigmoid(-gamma)`` and ``sigma^2 = sigmoid(gamma)``, with
    ``gamma`` running linearly from ``gamma_min`` at t=0 to ``gamma_max`` at t=1.

    Args:
        model: Callable ``model(z_t, gamma_t)`` returning the predicted noise,
            shaped like ``z_t``.
        image_shape: Shape of one image without the batch dimension.
        gamma_min: Gamma at t=0.
        gamma_max: Gamma at t=1.
        device: Fallback device for sampling when ``model`` has no parameters.
            Training always follows the device of the input batch.
    """

    def __init__(self, model, image_shape, gamma_min=-13.3, gamma_max=5.0, device=None):
        super().__init__()
        self.model = model
        self.image_shape = image_shape
        self.vocab_size = 256
        self.gamma_min = gamma_min
        self.gamma_max = gamma_max
        self.device = device

    def _sampling_device(self):
        """Return the device samples are created on.

        The device of the module's parameters is preferred because it stays
        correct after ``.to(...)``. ``self.device`` is used only when there are
        no parameters (it may be None, which means the PyTorch default device).
        """
        first_param = next(self.parameters(), None)
        if first_param is not None:
            return first_param.device
        return self.device

    def gamma_schedule(self, t):
        """Linear gamma schedule: ``gamma_min`` at t=0, ``gamma_max`` at t=1."""
        return self.gamma_min + (self.gamma_max - self.gamma_min) * t

    def encode(self, x):
        """Map images in [0, 1] to bin centres in (-1, 1).

        Values are first rounded to the integer levels 0..255, then mapped to
        ``2 * (level + 0.5) / vocab_size - 1``.
        """
        x_discrete = (x * 255).round()
        return 2 * ((x_discrete + 0.5) / self.vocab_size) - 1

    def decode(self, z, g_0):
        """Return log-probabilities over the ``vocab_size`` levels for every element of ``z``.

        Args:
            z: Latent already divided by alpha_0, shape ``(B, C, H, W)``.
            g_0: Gamma at t=0, scalar or one value per sample.

        Returns:
            Tensor of shape ``(B, C, H, W, vocab_size)``, normalised with
            log-softmax over the last dimension.
        """
        x_vals = torch.arange(0, self.vocab_size, device=z.device, dtype=z.dtype)
        x_vals = x_vals.view(1, 1, 1, 1, self.vocab_size)
        x_vals_encoded = 2 * ((x_vals + 0.5) / self.vocab_size) - 1

        if g_0.dim() == 0:
            g_0 = g_0.expand(z.shape[0])
        # exp(-gamma_0 / 2): inverse standard deviation of the rescaled latent
        inv_stdev = torch.exp(-0.5 * g_0).view(-1, 1, 1, 1, 1)
        z_expanded = z.unsqueeze(-1)
        logits = -0.5 * torch.square((z_expanded - x_vals_encoded) * inv_stdev)
        return F.log_softmax(logits, dim=-1)

    def q_sample(self, x0, t, noise=None):
        """Sample z_t ~ q(z_t | x0, t).

        Returns:
            Tuple ``(z_t, noise, gamma_t)`` where ``noise`` is the Gaussian
            noise that was used (drawn here when not supplied).
        """
        gamma_t = self.gamma_schedule(t)
        gamma_t_padded = gamma_t.view(-1, 1, 1, 1)

        mean = x0 * torch.sqrt(torch.sigmoid(-gamma_t_padded))
        scale = torch.sqrt(torch.sigmoid(gamma_t_padded))

        if noise is None:
            noise = torch.randn_like(x0)

        z_t = mean + noise * scale
        return z_t, noise, gamma_t

    def forward(self, batch, noise=None):
        """Compute the VDM training loss for a batch of images in [0, 1].

        Args:
            batch: Image tensor of shape ``(B, *image_shape)``.
            noise: Optional noise for the reconstruction term only; the
                diffusion term always draws fresh noise.

        Returns:
            Tuple ``(total_loss, bpd_total, bpd_components)``: the mean loss in
            nats, the same value in bits per dimension, and a dict of float
            components ``bpd_recon``, ``bpd_klz`` and ``bpd_diff``.
        """
        x = batch
        batch_size = x.shape[0]

        # Encode to latent space
        f = self.encode(x)

        # Get gamma values at both ends of the schedule
        g_0 = self.gamma_schedule(torch.tensor(0.0, device=x.device))
        g_1 = self.gamma_schedule(torch.tensor(1.0, device=x.device))
        var_1 = torch.sigmoid(g_1)

        # 1. Reconstruction loss: -log p(x | z_0), using z_0 / alpha_0 = f + exp(gamma_0 / 2) * eps
        eps_0 = torch.randn_like(f) if noise is None else noise
        z_0_rescaled = f + torch.exp(0.5 * g_0) * eps_0

        x_discrete = (x * 255).round().long()
        log_probs = self.decode(z_0_rescaled, g_0)
        x_onehot = F.one_hot(x_discrete, num_classes=self.vocab_size).float()
        log_prob = torch.sum(x_onehot * log_probs, dim=[1, 2, 3, 4])
        loss_recon = -log_prob

        # 2. KL loss between q(z_1 | x) and a standard normal
        mean1_sqr = (1.0 - var_1) * torch.square(f)
        loss_klz = 0.5 * torch.sum(mean1_sqr + var_1 - torch.log(var_1) - 1.0, dim=[1, 2, 3])

        # 3. Diffusion loss. Timesteps are evenly spaced over [0, 1) with one shared
        # random offset. They are built on the batch's device so training does not
        # depend on the `device` constructor argument.
        t0 = np.random.uniform(0, 1 / batch_size)
        t = t0 + torch.arange(batch_size, device=x.device) / batch_size
        z_t, eps, gamma_t = self.q_sample(f, t, noise=None)
        eps_pred = self.model(z_t, gamma_t)
        loss_diff_mse = torch.sum(torch.square(eps - eps_pred), dim=[1, 2, 3])
        # d(gamma)/dt of the linear schedule is constant
        g_t_grad = self.gamma_max - self.gamma_min
        loss_diff = 0.5 * g_t_grad * loss_diff_mse

        # Total loss
        total_loss_per_sample = loss_recon + loss_klz + loss_diff
        total_loss = torch.mean(total_loss_per_sample)

        # Convert nats per image to bits per dimension
        num_dims = np.prod(self.image_shape)
        rescale_to_bpd = 1.0 / (num_dims * np.log(2.0))
        bpd_total = total_loss * rescale_to_bpd

        bpd_components = {
            "bpd_recon": (torch.mean(loss_recon) * rescale_to_bpd).item(),
            "bpd_klz": (torch.mean(loss_klz) * rescale_to_bpd).item(),
            "bpd_diff": (torch.mean(loss_diff) * rescale_to_bpd).item(),
        }

        return total_loss, bpd_total, bpd_components

    @torch.no_grad()
    def sample_p_s_t(self, z, t, s, clip_samples=True):
        """Sample from p(z_s | z_t) for one denoising step from time ``t`` to ``s``.

        With ``clip_samples`` the implied clean image is clamped to [-1, 1]
        before the posterior mean is formed.
        """
        gamma_t = self.gamma_schedule(t)
        gamma_s = self.gamma_schedule(s)
        c = -torch.expm1(gamma_s - gamma_t)

        alpha_t = torch.sqrt(torch.sigmoid(-gamma_t))
        alpha_s = torch.sqrt(torch.sigmoid(-gamma_s))
        sigma_t = torch.sqrt(torch.sigmoid(gamma_t))
        sigma_s = torch.sqrt(torch.sigmoid(gamma_s))

        pred_noise = self.model(z, gamma_t)

        if clip_samples:
            x_start = (z - sigma_t * pred_noise) / alpha_t
            x_start.clamp_(-1.0, 1.0)
            mean = alpha_s * (z * (1 - c) / alpha_t + c * x_start)
        else:
            mean = alpha_s / alpha_t * (z - c * sigma_t * pred_noise)

        scale = sigma_s * torch.sqrt(c)
        return mean + scale * torch.randn_like(z)

    @torch.no_grad()
    def sample(self, batch_size, n_sample_steps=100, clip_samples=True):
        """Generate ``batch_size`` images with ``n_sample_steps`` steps from t=1 to t=0.

        Returns:
            Tensor of shape ``(batch_size, *image_shape)`` with values in
            [0, 1]: the most likely level per element divided by ``vocab_size - 1``.
        """
        device = self._sampling_device()
        z = torch.randn((batch_size, *self.image_shape), device=device)
        steps = torch.linspace(1.0, 0.0, n_sample_steps + 1, device=device)

        for i in range(n_sample_steps):
            z = self.sample_p_s_t(z, steps[i], steps[i + 1], clip_samples)

        # Decode final z_0 to image: divide by alpha_0, then take the most likely level
        g_0 = self.gamma_schedule(torch.tensor(0.0, device=device))
        z_0_rescaled = z / torch.sqrt(torch.sigmoid(-g_0))
        logprobs = self.decode(z_0_rescaled, g_0)
        x = torch.argmax(logprobs, dim=-1)
        return x.float() / (self.vocab_size - 1)


# Cell-completion marker shown in the notebook output
print("VDM model defined")

# Initialize Model

In [None]:
# Create model instance: a U-Net noise predictor wrapped in the VDM, moved to the selected device
image_shape = (3, 32, 32)  # CIFAR-10 image shape
unet = UNet(embedding_dim=128, n_blocks=32, input_channels=3, use_fourier=True, num_fourier_features=7)
vdm = VDM(unet, image_shape=image_shape, device=device).to(device)

# Count parameters (every parameter registered on the VDM, i.e. the wrapped U-Net's)
num_params = 0
for param in vdm.parameters():
    num_params += param.numel()
print(f"Total parameters: {num_params:,}")

In [None]:
# Load checkpoint if available
checkpoint_paths = list(Path("outputs").glob("**/model.pt"))
if checkpoint_paths:
    # glob() returns files in filesystem order, so the last entry is not
    # necessarily the newest; pick the most recently modified file explicitly.
    checkpoint_path = max(checkpoint_paths, key=lambda p: p.stat().st_mtime)
    print(f"Loading checkpoint from: {checkpoint_path}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Prefer the raw model weights; fall back to the EMA copy if that is all there is
    if "model" in checkpoint:
        state_dict = checkpoint["model"]
    elif checkpoint.get("ema") is not None:
        state_dict = checkpoint["ema"]["ema_model"]
    else:
        state_dict = None

    if state_dict is None:
        print("Checkpoint contains no 'model' or 'ema' weights - model keeps its current initialization")
    else:
        # strict=False tolerates key mismatches, so report them instead of ignoring them silently
        load_result = vdm.load_state_dict(state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(
                f"Warning: {len(load_result.missing_keys)} missing and "
                f"{len(load_result.unexpected_keys)} unexpected keys while loading the checkpoint"
            )
        print(f"Loaded checkpoint from step {checkpoint.get('step', 'unknown')}")
else:
    print("No checkpoint found - model has random initialization")

# Training and Results Visualization

In [None]:
# Training configuration
TRAIN_MODEL = True
NUM_EPOCHS = 5
LEARNING_RATE = 1e-4
TRAIN_BATCH_SIZE = 32

if TRAIN_MODEL:
    from tqdm import tqdm
    import torch.optim as optim

    # Create fresh dataloaders for training
    train_loader = DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        drop_last=True,
        num_workers=0,  # Use 0 for notebook to avoid multiprocessing issues
    )

    # Reinitialize model for training (this replaces any model built or loaded earlier)
    unet = UNet(embedding_dim=128, n_blocks=32, input_channels=3, use_fourier=True, num_fourier_features=7)
    vdm = VDM(unet, image_shape=image_shape, device=device)
    vdm = vdm.to(device)

    # Setup optimizer
    optimizer = optim.AdamW(vdm.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.99), weight_decay=0.01, eps=1e-8)

    print(f"Starting training for {NUM_EPOCHS} epochs...")
    print(f"Batch size: {TRAIN_BATCH_SIZE}, Batches per epoch: {len(train_loader)}")

    # Training loop: per-epoch means of the loss (nats) and bits per dimension
    training_losses = []
    training_bpds = []

    for epoch in range(NUM_EPOCHS):
        vdm.train()
        epoch_loss = 0.0
        epoch_bpd = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
        for images, _ in pbar:  # labels are not used
            images = images.to(device)

            # Forward pass
            optimizer.zero_grad()
            loss, bpd, bpd_components = vdm(images)

            # Backward pass, clipping the global gradient norm to 1.0
            loss.backward()
            torch.nn.utils.clip_grad_norm_(vdm.parameters(), 1.0)
            optimizer.step()

            # Read each scalar off the device once and reuse it for tracking and display
            loss_value = loss.item()
            bpd_value = bpd.item()
            epoch_loss += loss_value
            epoch_bpd += bpd_value

            # Update progress bar
            pbar.set_postfix(
                {
                    "loss": f"{loss_value:.4f}",
                    "bpd": f"{bpd_value:.4f}",
                    "recon": f'{bpd_components["bpd_recon"]:.3f}',
                    "kl": f'{bpd_components["bpd_klz"]:.3f}',
                    "diff": f'{bpd_components["bpd_diff"]:.3f}',
                }
            )

        # Calculate epoch averages
        avg_loss = epoch_loss / len(train_loader)
        avg_bpd = epoch_bpd / len(train_loader)
        training_losses.append(avg_loss)
        training_bpds.append(avg_bpd)

        print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}, BPD: {avg_bpd:.4f}")

    print("\nTraining completed!")

    # Plot training progress
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    epoch_numbers = range(1, NUM_EPOCHS + 1)

    axes[0].plot(epoch_numbers, training_losses, marker="o", linewidth=2)
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("Loss")
    axes[0].set_title("Training Loss")
    axes[0].grid(True, alpha=0.3)

    axes[1].plot(epoch_numbers, training_bpds, marker="o", linewidth=2, color="orange")
    axes[1].set_xlabel("Epoch")
    axes[1].set_ylabel("BPD")
    axes[1].set_title("Training BPD")
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print("Training skipped - using loaded model")

# Generate Samples from Model

In [None]:
# Set model to evaluation mode
vdm.eval()

num_samples = 16
n_sample_steps = 100  # Number of denoising steps

# Generate samples
print("Generating samples from the model...")
print("This may take a few minutes...")

# Sampling runs without gradient tracking; results are moved to the CPU for plotting
with torch.no_grad():
    samples = vdm.sample(batch_size=num_samples, n_sample_steps=n_sample_steps, clip_samples=True).cpu()

print(f"Generated {num_samples} samples with {n_sample_steps} denoising steps")

In [None]:
# Visualize generated samples
# Derive a near-square grid from the sample count instead of assuming exactly 16 images
# (16 samples still give the original 4x4 grid at 8x8 inches).
grid_cols = max(1, math.ceil(math.sqrt(num_samples)))
grid_rows = max(1, math.ceil(num_samples / grid_cols))
# squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(2 * grid_cols, 2 * grid_rows), squeeze=False)
axes = axes.flatten()

# Hide every axis up front so unused trailing cells stay blank
for ax in axes:
    ax.axis("off")

for i in range(num_samples):
    img = samples[i].permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)  # Ensure values are in [0, 1]
    axes[i].imshow(img)

plt.suptitle("Generated Samples from VDM", fontsize=14, y=0.98)
plt.tight_layout()
plt.show()

In [None]:
# Show real images alongside generated ones
fig, axes = plt.subplots(2, 8, figsize=(12, 3.5))

real_images, real_labels = next(iter(test_dataloader))

# Row 0: first 8 images of the first test batch; row 1: first 8 generated samples.
# Only the generated row is clipped to [0, 1] before display.
comparison_rows = (
    ("Real", real_images, False),
    ("Generated", samples, True),
)
for row_idx, (row_title, row_images, needs_clip) in enumerate(comparison_rows):
    for col_idx in range(8):
        img = row_images[col_idx].permute(1, 2, 0).numpy()
        if needs_clip:
            img = np.clip(img, 0, 1)
        cell = axes[row_idx, col_idx]
        cell.imshow(img)
        cell.axis("off")
        # Label each row once, on its left-most image
        if col_idx == 0:
            cell.set_title(row_title, loc="left", fontsize=10)

plt.tight_layout()
plt.show()

In [None]:
# Generate samples and save intermediate steps to visualize denoising
vdm.eval()
num_steps = 100
steps_to_show = [0, 10, 25, 50, 75, 90, 99]

z = torch.randn((1, *image_shape), device=device)
steps = torch.linspace(1.0, 0.0, num_steps + 1, device=device)

intermediate_steps = []
with torch.no_grad():
    for i in range(num_steps):
        if i in steps_to_show:
            # Undo the alpha_t scaling of the current latent
            g_t = vdm.gamma_schedule(steps[i])
            z_rescaled = z / torch.sqrt(torch.sigmoid(-g_t))
            # VDM.encode places data in (-1, 1), so map back to [0, 1] for display.
            # Clipping the raw latent to [0, 1] would render every negative value as black.
            preview = ((z_rescaled + 1.0) / 2.0).clamp(0.0, 1.0)
            intermediate_steps.append((i, preview.cpu()))

        z = vdm.sample_p_s_t(z, steps[i], steps[i + 1], clip_samples=True)

    # Add final result, decoded to discrete pixel levels like VDM.sample does
    g_0 = vdm.gamma_schedule(torch.tensor(0.0, device=device))
    z_0_rescaled = z / torch.sqrt(torch.sigmoid(-g_0))
    logprobs = vdm.decode(z_0_rescaled, g_0)
    final_img = torch.argmax(logprobs, dim=-1).float() / (vdm.vocab_size - 1)
    intermediate_steps.append((num_steps, final_img.cpu()))

# Visualize the denoising process
fig, axes = plt.subplots(1, len(intermediate_steps), figsize=(16, 2.5))

for idx, (step_num, img_tensor) in enumerate(intermediate_steps):
    img = img_tensor[0].permute(1, 2, 0).numpy()
    img = np.clip(img, 0, 1)
    axes[idx].imshow(img)
    axes[idx].set_title(f"Step {step_num}", fontsize=9)
    axes[idx].axis("off")

plt.suptitle("Denoising Process: From Noise to Image", fontsize=12, y=1.02)
plt.tight_layout()
plt.show()