In [1]:
# ===== 1) Imports & device ====================================================
import math, os, time, random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils
import matplotlib.pyplot as plt

# Prefer the GPU whenever CUDA is usable; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Using device:", DEVICE)


Using device: cpu


In [2]:
# ===== 2) Hyperparameters =====================================================
SEED       = 42         # fixed seed so a fresh "Restart & Run All" repeats the same run
IMG_SIZE   = 28         # side length of the square images fed to the model
CHANNELS   = 1          # number of image channels
BATCH_SIZE = 128
LR         = 2e-4
EPOCHS     = 3          # for demo; increase to ~20-40 for nicer samples
T          = 300        # diffusion steps (200-400 is fine for MNIST)

beta_start, beta_end = 1e-4, 2e-2   # linear beta schedule

# Seed the RNGs this notebook draws from *before* any shuffling, weight
# initialisation or noise sampling happens. torch.manual_seed covers the CPU
# generator and, when present, the CUDA generators as well.
random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Pixel preprocessing: ToTensor yields floats in [0, 1]; Normalize with
# mean=0.5 and std=0.5 computes (x - 0.5) / 0.5 = 2x - 1, i.e. maps to [-1, 1].
# Normalize is used instead of transforms.Lambda(lambda x: x * 2 - 1) because a
# lambda defined in the notebook cannot be pickled, which breaks DataLoader
# worker processes (num_workers > 0) on platforms that start them with "spawn"
# (Windows, macOS). The two forms agree up to float rounding.
transform = transforms.Compose([
    transforms.ToTensor(),                   # [0,1]
    transforms.Normalize((0.5,), (0.5,)),    # -> [-1,1]
])

# download=True fetches MNIST into ./data on first use and reuses it afterwards.
train_set = datasets.MNIST(root="./data", train=True,  download=True, transform=transform)
test_set  = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
# drop_last=True keeps every training batch at exactly BATCH_SIZE samples.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)

def plot_tensor_grid(x, nrow=8, title=None):
    """
    Show a square grid of images with matplotlib.

    x     : [B, C, H, W] tensor with values in [-1,1]
    nrow  : images per row; only the first nrow*nrow images are drawn
    title : optional figure title (skipped when falsy)
    """
    n_shown = nrow * nrow
    images = x[:n_shown].detach().cpu()
    images = (images + 1) / 2                      # [-1,1] -> [0,1] for display
    grid = utils.make_grid(images, nrow=nrow, padding=2)
    canvas = grid.permute(1, 2, 0).squeeze()       # channels-last layout for imshow

    plt.figure(figsize=(6, 6))
    plt.imshow(canvas, interpolation="nearest")
    plt.axis("off")
    if title:
        plt.title(title)
    plt.show()

100%|██████████| 9.91M/9.91M [00:00<00:00, 62.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.75MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.8MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.34MB/s]


In [4]:
# Linear variance schedule beta_t for t = 0..T-1, stored on DEVICE so it can be
# indexed directly with timestep tensors living on the same device.
betas = torch.linspace(beta_start, beta_end, steps=T, device=DEVICE)
alphas = 1.0 - betas
# alpha_bar_t: running product of alphas up to and including step t.
alphas_cumprod = alphas.cumprod(dim=0)
# alpha_bar_{t-1}: the same sequence shifted right by one, with 1.0 in front.
alphas_cumprod_prev = torch.cat((torch.ones(1, device=DEVICE), alphas_cumprod[:-1]), dim=0)

# Coefficients used by the closed-form forward process and by the reverse step.
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()
sqrt_recip_alphas = (1.0 / alphas).sqrt()
# DDPM posterior variance (Eq. 7): beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

# ===== 5) q_sample: forward/noising ==========================================
def q_sample(x0, t, noise=None):
    """
    Forward (noising) process: draw x_t ~ q(x_t | x_0) in a single step.

        x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps

    x0   : [B, C, H, W] clean images in [-1,1]
    t    : [B] int64 timesteps, one per batch element
    noise: [B, C, H, W] standard-normal eps; drawn here when None
    returns: [B, C, H, W] noised images x_t
    """
    def per_sample(schedule):
        # Pick one coefficient per batch element and shape it to broadcast over C, H, W.
        return schedule[t].reshape(-1, 1, 1, 1)

    eps = torch.randn_like(x0) if noise is None else noise
    signal_scale = per_sample(sqrt_alphas_cumprod)            # sqrt(alpha_bar_t)
    noise_scale = per_sample(sqrt_one_minus_alphas_cumprod)   # sqrt(1 - alpha_bar_t)
    return signal_scale * x0 + noise_scale * eps

# ===== 6) Sinusoidal timestep embedding ======================================
def timestep_embedding(timesteps, dim=64):
    """
    Sinusoidal embedding of integer timesteps.

    The first dim//2 columns are cosines and the next dim//2 are sines of
    t * 10000^(-k / (dim//2)) for k = 0..dim//2-1. When dim is odd, one
    zero column is appended so the result is exactly dim wide.

    timesteps: [B] int64
    returns  : [B, dim] float tensor on the same device as timesteps
    """
    n_freqs = dim // 2
    index = torch.arange(0, n_freqs, device=timesteps.device).float()
    inv_freqs = torch.exp(-math.log(10000) * index / n_freqs)       # [n_freqs]
    phases = timesteps.float()[:, None] * inv_freqs[None, :]        # [B, n_freqs]
    embedding = torch.cat((phases.cos(), phases.sin()), dim=1)      # [B, 2*n_freqs]
    if dim % 2 == 1:
        # Odd width: right-pad the last dimension with a single zero column.
        embedding = F.pad(embedding, (0, 1))
    return embedding

In [5]:
class ResBlock(nn.Module):
    """
    Residual conv block conditioned on a timestep embedding.

    Layout: GroupNorm -> SiLU -> 3x3 conv, then a per-channel scale/shift
    derived from the time embedding, then GroupNorm -> SiLU -> 3x3 conv, and
    finally a residual connection (1x1 conv when channel counts differ).
    Both GroupNorm layers use 8 groups, so c_in and c_out must be multiples of 8.
    """

    def __init__(self, c_in, c_out, t_emb_dim=64):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, c_in)
        self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(8, c_out)
        self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3, padding=1)
        # Projects the time embedding to 2*c_out values: c_out scales + c_out shifts.
        self.time_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(t_emb_dim, 2 * c_out),
        )
        # Residual path: identity when shapes already agree, else a 1x1 projection.
        if c_in != c_out:
            self.skip = nn.Conv2d(c_in, c_out, kernel_size=1)
        else:
            self.skip = nn.Identity()

    def forward(self, x, t_emb):
        """
        x     : [B, c_in, H, W]
        t_emb : [B, t_emb_dim]
        returns [B, c_out, H, W]
        """
        residual = self.skip(x)

        # [B, 2*c_out] -> [B, 2*c_out, 1, 1], split into scale (gamma) and shift (beta).
        film = self.time_mlp(t_emb)[:, :, None, None]
        scale, shift = film.chunk(2, dim=1)

        hidden = self.norm1(x)
        hidden = self.conv1(F.silu(hidden))
        hidden = hidden * (1 + scale) + shift
        hidden = self.norm2(hidden)
        hidden = self.conv2(F.silu(hidden))
        return hidden + residual

class Down(nn.Module):
    """
    Encoder stage: two time-conditioned ResBlocks followed by a strided conv
    that halves the spatial resolution. The pre-downsample features are
    returned as well so the decoder can use them as a skip connection.
    """

    def __init__(self, c_in, c_out, t_emb_dim=64):
        super().__init__()
        self.b1 = ResBlock(c_in, c_out, t_emb_dim)
        self.b2 = ResBlock(c_out, c_out, t_emb_dim)
        # kernel 4 / stride 2 / padding 1 -> exactly H/2 x W/2 for even H, W
        self.down = nn.Conv2d(c_out, c_out, kernel_size=4, stride=2, padding=1)

    def forward(self, x, t_emb):
        """Return (downsampled features, full-resolution skip features)."""
        features = self.b2(self.b1(x, t_emb), t_emb)
        return self.down(features), features

class Up(nn.Module):
    """
    Decoder stage: transposed-conv 2x upsampling (c_in -> c_out channels),
    channel-wise concatenation with the encoder skip tensor (c_skip channels),
    then two time-conditioned ResBlocks that bring the result to c_out channels.
    """

    def __init__(self, c_in, c_skip, c_out, t_emb_dim=64):
        super().__init__()
        # kernel 4 / stride 2 / padding 1 doubles H and W
        self.up = nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
        # First block consumes the concatenation of upsampled and skip channels.
        self.b1 = ResBlock(c_out + c_skip, c_out, t_emb_dim)
        self.b2 = ResBlock(c_out, c_out, t_emb_dim)

    def forward(self, x, skip, t_emb):
        """
        x    : [B, c_in, H, W] features from the previous (coarser) stage
        skip : [B, c_skip, 2H, 2W] encoder features at the target resolution
        t_emb: [B, t_emb_dim]
        returns [B, c_out, 2H, 2W]
        """
        upsampled = self.up(x)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.b2(self.b1(merged, t_emb), t_emb)

class UNetMNIST(nn.Module):
    """
    Small U-Net noise predictor eps_theta(x_t, t) for 28x28 images.

    Two encoder stages (28 -> 14 -> 7), two bottleneck ResBlocks at 7x7, and
    two decoder stages (7 -> 14 -> 28) with skip connections. Every ResBlock
    is conditioned on a learned projection of the sinusoidal timestep embedding.

    c         : number of image channels (input and output)
    base      : base channel width; must be a multiple of 8 because the
                GroupNorm layers use 8 groups
    t_emb_dim : width of the timestep embedding fed to every block
    """

    def __init__(self, c=CHANNELS, base=64, t_emb_dim=64):
        super().__init__()
        # Remember the embedding width so forward() builds a sinusoidal embedding
        # of matching size (previously hard-coded to 64, which broke any other value).
        self.t_emb_dim = t_emb_dim

        self.time_mlp = nn.Sequential(
            nn.Linear(t_emb_dim, t_emb_dim * 4),
            nn.SiLU(),
            nn.Linear(t_emb_dim * 4, t_emb_dim),
        )

        self.in_conv = nn.Conv2d(c, base, 3, padding=1)

        self.down1 = Down(base, base, t_emb_dim)         # out: base,     skip s1 @ 28x28
        self.down2 = Down(base, base * 2, t_emb_dim)     # out: base*2,   skip s2 @ 14x14

        self.mid1 = ResBlock(base * 2, base * 2, t_emb_dim)
        self.mid2 = ResBlock(base * 2, base * 2, t_emb_dim)

        # Up: (c_in_from_prev, c_skip, c_out_after_up)
        self.up1 = Up(base * 2, base * 2, base, t_emb_dim)  # 7->14, concat with s2 (base*2), out base
        self.up2 = Up(base,     base,     base, t_emb_dim)  # 14->28, concat with s1 (base),   out base

        self.out_norm = nn.GroupNorm(8, base)
        self.out_conv = nn.Conv2d(base, c, 3, padding=1)

    def forward(self, x, t):
        """
        x : [B, c, 28, 28] noised images x_t
        t : [B] int64 timesteps
        returns [B, c, 28, 28] predicted noise
        """
        # Sinusoidal embedding sized to match time_mlp's input width.
        t_emb = timestep_embedding(t, dim=self.t_emb_dim)
        t_emb = self.time_mlp(t_emb)

        x = self.in_conv(x)
        x, s1 = self.down1(x, t_emb)   # s1: [B, base,     28, 28]
        x, s2 = self.down2(x, t_emb)   # s2: [B, base*2,   14, 14]

        x = self.mid1(x, t_emb)        # x:  [B, base*2,    7,  7]
        x = self.mid2(x, t_emb)

        x = self.up1(x, s2, t_emb)     # -> [B, base,      14, 14]
        x = self.up2(x, s1, t_emb)     # -> [B, base,      28, 28]

        x = F.silu(self.out_norm(x))
        x = self.out_conv(x)           # [B, c, 28, 28]
        return x

# Build the denoiser with its default sizes, move it to the selected device, and
# optimise all of its parameters with AdamW (weight decay switched off).
model = UNetMNIST()
model = model.to(DEVICE)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=LR, weight_decay=0.0)

In [6]:
def train_epoch(ep):
    """
    Run one pass over train_loader, training the global model to predict noise.

    For every batch: draw a uniform random timestep per image, noise the batch
    with q_sample, and minimise the MSE between predicted and true noise.
    The mean loss of each completed window of 100 steps is printed.

    ep: epoch number, used only in the progress message
    """
    model.train()
    running = 0.0
    for i, batch in enumerate(train_loader):
        x0 = batch[0].to(DEVICE)                 # images only; labels are unused
        n_images = x0.size(0)

        # One timestep in [0, T) per image, plus the Gaussian noise to be recovered.
        t = torch.randint(0, T, (n_images,), device=DEVICE).long()
        eps = torch.randn_like(x0)
        noised = q_sample(x0, t, eps)

        loss = F.mse_loss(model(noised, t), eps)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running += loss.item()
        if (i + 1) % 100 != 0:
            continue
        # Report the average over the last 100 steps, then start a new window.
        print(f"Epoch {ep}  |  Step {i+1:4d}/{len(train_loader)}  |  loss={running/100:.4f}")
        running = 0.0

# Next: the reverse (denoising) process used to generate samples.

In [7]:
@torch.no_grad()
def p_sample(x_t, t):
    """
    One reverse (denoising) step: x_t -> x_{t-1}.

    Uses the DDPM Eq. (11) mean computed from the model's noise prediction,
    plus Gaussian noise scaled by the posterior variance.

    x_t : [B, C, H, W] current noisy images
    t   : [B] int64 timesteps (indices into the schedule tensors)
    returns [B, C, H, W] sample of x_{t-1}; when every entry of t is 0 the
    mean is returned directly and no noise is drawn.
    """
    eps_theta = model(x_t, t)

    # Per-sample schedule coefficients, shaped [B,1,1,1] to broadcast over C, H, W.
    beta_t    = betas[t].view(-1,1,1,1)
    sqrt_one_minus_ac = sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)
    sqrt_recip_alpha  = sqrt_recip_alphas[t].view(-1,1,1,1)

    # Model mean (Eq. 11): (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)
    model_mean = sqrt_recip_alpha * (x_t - (beta_t / sqrt_one_minus_ac) * eps_theta)

    # At t=0, no further noise
    if (t == 0).all():
        return model_mean

    var = posterior_variance[t].view(-1,1,1,1)
    noise = torch.randn_like(x_t)
    return model_mean + torch.sqrt(var) * noise

@torch.no_grad()
def sample(n=64):
    """
    Generate n images by ancestral sampling.

    Starts from pure Gaussian noise x_T ~ N(0, I) and applies p_sample for
    timesteps T-1, T-2, ..., 0. Puts the global model into eval mode.

    Returns: [n, CHANNELS, IMG_SIZE, IMG_SIZE] tensor clamped to [-1,1]
    """
    model.eval()
    shape = (n, CHANNELS, IMG_SIZE, IMG_SIZE)
    x = torch.randn(*shape, device=DEVICE)
    for step in range(T - 1, -1, -1):
        # Every image in the batch is at the same timestep.
        t_batch = torch.full((n,), step, dtype=torch.long, device=DEVICE)
        x = p_sample(x, t_batch)
    return torch.clamp(x, -1, 1)

In [None]:
GRID_SIDE = 6  # preview a 6x6 grid: small enough to keep per-epoch sampling quick

# Train for EPOCHS epochs and look at fresh samples after each one.
for ep in range(1, EPOCHS + 1):
    train_epoch(ep)
    with torch.no_grad():
        imgs = sample(GRID_SIDE * GRID_SIDE)
    plot_tensor_grid(imgs, nrow=GRID_SIDE, title=f"Epoch {ep} samples")
