
# Classroom DDPM (PyTorch) — CUDA / Apple Silicon (MPS) / CPU

This notebook implements a **minimal, classroom-friendly DDPM** (Denoising Diffusion Probabilistic Model) that runs on **NVIDIA CUDA**, **Apple Silicon (MPS/Metal)**, or **CPU**.

**What you’ll get:**
- A tiny U-Net-based DDPM that trains on **CIFAR-10 (32×32, RGB)**
- **Auto device** selection: MPS → CUDA → CPU
- Clear, heavily commented cells you can teach from
- Quick **training**, **sampling**, and **visualization** cells

> Tip: On Macs with M1/M2/M3, PyTorch will automatically use the MPS backend if available.


## 0) Setup

## 1) Reproducibility & Device

In [1]:

import os, math, random
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils

def seed_everything(seed: int = 42):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def get_device() -> torch.device:
    # Prefer MPS (Apple Silicon), else CUDA, else CPU
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")

seed_everything(42)
device = get_device()
print(f"Using device: {device}")
try:
    torch.set_float32_matmul_precision("high")
except Exception as e:
    print("Note:", e)


Using device: mps


## 2) Diffusion Schedule Helpers

In [2]:

@dataclass
class DiffusionSchedule:
    betas: torch.Tensor
    alphas: torch.Tensor
    alpha_cumprod: torch.Tensor
    alpha_cumprod_prev: torch.Tensor
    sqrt_alpha_cumprod: torch.Tensor
    sqrt_one_minus_alpha_cumprod: torch.Tensor
    sqrt_recip_alpha: torch.Tensor
    posterior_variance: torch.Tensor

def make_beta_schedule(T: int, beta_start: float = 1e-4, beta_end: float = 2e-2) -> torch.Tensor:
    # Linear schedule from the original DDPM paper
    return torch.linspace(beta_start, beta_end, T)

def build_schedule(T: int) -> DiffusionSchedule:
    betas = make_beta_schedule(T)
    alphas = 1.0 - betas
    alpha_cumprod = torch.cumprod(alphas, dim=0)
    alpha_cumprod_prev = torch.cat([torch.tensor([1.0]), alpha_cumprod[:-1]])

    sqrt_alpha_cumprod = torch.sqrt(alpha_cumprod)
    sqrt_one_minus_alpha_cumprod = torch.sqrt(1.0 - alpha_cumprod)
    sqrt_recip_alpha = torch.sqrt(1.0 / alphas)

    # From Ho et al. (posterior variance)
    posterior_variance = betas * (1.0 - alpha_cumprod_prev) / (1.0 - alpha_cumprod)

    return DiffusionSchedule(
        betas=betas,
        alphas=alphas,
        alpha_cumprod=alpha_cumprod,
        alpha_cumprod_prev=alpha_cumprod_prev,
        sqrt_alpha_cumprod=sqrt_alpha_cumprod,
        sqrt_one_minus_alpha_cumprod=sqrt_one_minus_alpha_cumprod,
        sqrt_recip_alpha=sqrt_recip_alpha,
        posterior_variance=posterior_variance,
    )


## 3) Time Embeddings (Sinusoidal)

In [3]:

import math

class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half_dim = self.dim // 2
        device = t.device
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = t.float().unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
        if self.dim % 2 == 1:
            emb = F.pad(emb, (0, 1))
        return emb


## 4) Tiny U-Net (Classroom-Friendly)

In [4]:

class ResBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, time_dim: int, groups: int = 8):
        super().__init__()
        self.norm1 = nn.GroupNorm(groups, in_ch)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)

        self.norm2 = nn.GroupNorm(groups, out_ch)
        self.act2 = nn.SiLU()
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)

        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(time_dim, out_ch)
        )

        self.res_conv = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor) -> torch.Tensor:
        h = self.conv1(self.act1(self.norm1(x)))
        t = self.time_mlp(t_emb).type_as(h)
        h = h + t[:, :, None, None]
        h = self.conv2(self.act2(self.norm2(h)))
        return h + self.res_conv(x)

class Down(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, time_dim: int):
        super().__init__()
        self.block1 = ResBlock(in_ch, out_ch, time_dim)
        self.block2 = ResBlock(out_ch, out_ch, time_dim)
        self.down = nn.Conv2d(out_ch, out_ch, 4, stride=2, padding=1)

    def forward(self, x: torch.Tensor, t_emb: torch.Tensor):
        x = self.block1(x, t_emb)
        x = self.block2(x, t_emb)
        skip = x
        x = self.down(x)
        return x, skip

class Up(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, time_dim: int):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 4, stride=2, padding=1)
        self.block1 = ResBlock(out_ch * 2, out_ch, time_dim)
        self.block2 = ResBlock(out_ch, out_ch, time_dim)

    def forward(self, x: torch.Tensor, skip: torch.Tensor, t_emb: torch.Tensor):
        x = self.up(x)
        x = torch.cat([x, skip], dim=1)
        x = self.block1(x, t_emb)
        x = self.block2(x, t_emb)
        return x

class TinyUNet(nn.Module):
    """A very small U-Net that works for 32x32 images (CIFAR-10)."""
    def __init__(self, base_ch: int = 64, time_dim: int = 256):
        super().__init__()
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim * 4),
            nn.SiLU(),
            nn.Linear(time_dim * 4, time_dim)
        )

        self.in_conv = nn.Conv2d(3, base_ch, 3, padding=1)

        # Down: 32 -> 16 -> 8
        self.down1 = Down(base_ch, base_ch, time_dim)
        self.down2 = Down(base_ch, base_ch * 2, time_dim)

        # Bottleneck at 8x8
        self.bot1 = ResBlock(base_ch * 2, base_ch * 4, time_dim)
        self.bot2 = ResBlock(base_ch * 4, base_ch * 2, time_dim)

        # Up: 8 -> 16 -> 32
        self.up1 = Up(base_ch * 2, base_ch, time_dim)
        self.up2 = Up(base_ch, base_ch, time_dim)

        self.out_norm = nn.GroupNorm(8, base_ch)
        self.out_act = nn.SiLU()
        self.out_conv = nn.Conv2d(base_ch, 3, 3, padding=1)

        self._init_weights()

    def _init_weights(self):
        def init(m):
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
        self.apply(init)

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        t_emb = self.time_mlp(t)

        x = self.in_conv(x)
        x, s1 = self.down1(x, t_emb)
        x, s2 = self.down2(x, t_emb)

        x = self.bot1(x, t_emb)
        x = self.bot2(x, t_emb)

        x = self.up1(x, s2, t_emb)
        x = self.up2(x, s1, t_emb)

        x = self.out_conv(self.out_act(self.out_norm(x)))
        return x


## 5) DDPM Core (Forward/Reverse Diffusion)

In [5]:

class DDPM(nn.Module):
    def __init__(self, model: nn.Module, T: int = 1000):
        super().__init__()
        self.model = model
        self.T = T
        sched = build_schedule(T)
        for k, v in sched.__dict__.items():
            self.register_buffer(k, v)

    @torch.no_grad()
    def q_sample(self, x_start: torch.Tensor, t: torch.Tensor, noise: torch.Tensor = None) -> torch.Tensor:
        if noise is None:
            noise = torch.randn_like(x_start)
        sqrt_ac = self.sqrt_alpha_cumprod[t].view(-1, 1, 1, 1)
        sqrt_om = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
        return sqrt_ac * x_start + sqrt_om * noise

    def p_losses(self, x_start: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        noise = torch.randn_like(x_start)
        x_noisy = self.q_sample(x_start, t, noise)
        noise_pred = self.model(x_noisy, t)  # predict epsilon
        return F.mse_loss(noise_pred, noise)

    @torch.no_grad()
    def p_sample(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        betas_t = self.betas[t].view(-1, 1, 1, 1)
        sqrt_one_minus_ac_t = self.sqrt_one_minus_alpha_cumprod[t].view(-1, 1, 1, 1)
        sqrt_recip_alpha_t = self.sqrt_recip_alpha[t].view(-1, 1, 1, 1)

        eps_theta = self.model(x, t)
        model_mean = sqrt_recip_alpha_t * (x - betas_t * eps_theta / sqrt_one_minus_ac_t)

        if (t == 0).all():
            return model_mean
        else:
            posterior_var = self.posterior_variance[t].view(-1, 1, 1, 1)
            noise = torch.randn_like(x)
            return model_mean + torch.sqrt(posterior_var) * noise

    @torch.no_grad()
    def sample(self, batch_size: int, device: torch.device, channels: int = 3, size: int = 32) -> torch.Tensor:
        x = torch.randn(batch_size, channels, size, size, device=device)
        for i in reversed(range(self.T)):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x = self.p_sample(x, t)
        return x


## 6) Data: CIFAR-10 Loader (Normalized to [-1, 1])

In [6]:

def get_cifar10_loader(data_root: str, batch_size: int, num_workers: int = 2) -> DataLoader:
    transform = transforms.Compose([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),  # map to [-1, 1]
    ])
    train_set = datasets.CIFAR10(root=data_root, train=True, download=True, transform=transform)
    return DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)


## 7) Training Loop

In [7]:

def train_ddpm(epochs=2, batch_size=128, lr=2e-4, timesteps=1000, base_channels=64, time_dim=256, 
               data_root="data", num_workers=2, run_dir="runs", log_every=50, sample_every=1000, seed=42):
    seed_everything(seed)
    dev = get_device()
    print(f"Training on: {dev}")
    os.makedirs(run_dir, exist_ok=True)

    model = TinyUNet(base_ch=base_channels, time_dim=time_dim).to(dev)
    ddpm = DDPM(model, T=timesteps).to(dev)
    loader = get_cifar10_loader(data_root, batch_size, num_workers)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)

    global_step = 0
    for epoch in range(epochs):
        model.train()
        for x, _ in loader:
            x = x.to(dev)
            t = torch.randint(0, ddpm.T, (x.size(0),), device=dev).long()
            loss = ddpm.p_losses(x, t)

            opt.zero_grad(set_to_none=True)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            opt.step()

            if global_step % log_every == 0:
                print(f"epoch {epoch+1}/{epochs} | step {global_step:06d} | loss {loss.item():.4f}")

            if global_step % sample_every == 0 and global_step > 0:
                model.eval()
                with torch.no_grad():
                    samples = ddpm.sample(batch_size=36, device=dev)
                os.makedirs(run_dir, exist_ok=True)
                grid_path = os.path.join(run_dir, f"sample_step_{global_step}.png")
                grid = vutils.make_grid(samples, nrow=6, normalize=True, value_range=(-1, 1))
                vutils.save_image(grid, grid_path)
                print(f"Saved {grid_path}")
                model.train()

            global_step += 1

        ckpt_path = os.path.join(run_dir, "ddpm.pt")
        torch.save({"model": model.state_dict(), "ddpm_T": ddpm.T}, ckpt_path)
        print(f"Saved checkpoint to {ckpt_path}")


## 8) Sampling

In [8]:

@torch.no_grad()
def sample_images(n_samples=36, outfile="samples/grid.png", timesteps=1000, base_channels=64, time_dim=256, checkpoint="runs/ddpm.pt"):
    dev = get_device()
    print(f"Sampling on: {dev}")
    os.makedirs(os.path.dirname(outfile), exist_ok=True)

    model = TinyUNet(base_ch=base_channels, time_dim=time_dim).to(dev)
    T = timesteps
    if os.path.isfile(checkpoint):
        ckpt = torch.load(checkpoint, map_location=dev)
        model.load_state_dict(ckpt["model"])
        T = ckpt.get("ddpm_T", timesteps)
        print(f"Loaded checkpoint with T={T}")
    else:
        print("No checkpoint found; sampling with randomly initialized model (results will be noise).")

    ddpm = DDPM(model, T=T).to(dev)
    samples = ddpm.sample(batch_size=n_samples, device=dev, channels=3, size=32)
    grid = vutils.make_grid(samples, nrow=int(n_samples**0.5), normalize=True, value_range=(-1, 1))
    vutils.save_image(grid, outfile)
    print(f"Saved {outfile}")


## 9) Visualize the Beta Schedule

In [9]:

import matplotlib.pyplot as plt

def visualize_beta_schedule(timesteps=1000, out="runs/beta_schedule.png"):
    os.makedirs("runs", exist_ok=True)
    sched = build_schedule(timesteps)
    plt.figure()
    plt.plot(sched.betas.cpu().numpy())
    plt.title("Beta schedule (linear)")
    plt.xlabel("t")
    plt.ylabel("beta_t")
    plt.savefig(out, bbox_inches="tight")
    print(f"Saved {out}")


## 10) Quickstart: Train → Sample

In [10]:
# TRAIN (recommended demo settings for a quick classroom run)
train_ddpm(epochs=2, batch_size=128, lr=2e-4, timesteps=1000, base_channels=64, time_dim=256,
          data_root="data", num_workers=2, run_dir="runs", log_every=50, sample_every=1000, seed=42)

Training on: mps


100%|██████████| 170M/170M [06:29<00:00, 437kB/s]  



RuntimeError: Expected weight to be a vector of size equal to the number of channels in input, but got weight of shape [128] and input of shape [128, 192, 16, 16]

In [None]:

# SAMPLE (after training). Will use runs/ddpm.pt by default
# sample_images(n_samples=36, outfile="samples/grid.png", timesteps=1000, base_channels=64, time_dim=256, checkpoint="runs/ddpm.pt")

print("Uncomment the line above to sample.")


In [None]:

# VIZ (optional): visualize beta schedule
# visualize_beta_schedule(timesteps=1000, out="runs/beta_schedule.png")

print("Uncomment the line above to visualize the beta schedule.")


In [5]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Check PyTorch MPS (Metal Performance Shaders) availability
print("PyTorch version:", torch.__version__)
print("MPS (Metal) device available:", torch.backends.mps.is_available())
print("MPS (Metal) device built:", torch.backends.mps.is_built())

# Set up device
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS device")
else:
    device = torch.device("cpu")
    print("Using CPU device")

# Test device with a simple operation
x = torch.randn(5, 3).to(device)
print("\nTest tensor device:", x.device)
y = x @ x.t()
print("Matrix multiplication successful on device:", y.device)

PyTorch version: 2.6.0
MPS (Metal) device available: True
MPS (Metal) device built: True
Using MPS device

Test tensor device: mps:0
Matrix multiplication successful on device: mps:0
