# Discrete Diffusion (D3PM) – Practical

Interactive, hands-on introduction to discrete diffusion models (D3PMs). Built as a sequel to the continuous diffusion session; assumes familiarity with probability and with Gaussian (continuous) diffusion.

**You will:**
- Build the discrete forward process via categorical transition matrices.
- Derive and code the exact posterior $q(x_{t-1} \mid x_t, x_0)$ and sample with Gumbel-max.
- Implement the hybrid D3PM loss (VB term + CE on $x_0$ logits).
- Train a tiny D3PM on MNIST (quantized) and sample class-conditional images.
- Probe the process interactively (sliders for $t$, corruption visualization).

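**Run order:** top-to-bottom, with one caveat: in each exercise the check cell comes *before* the reference implementation, so run the implementation cell (or your own solution) first, then the check. Coding exercises are clearly marked; downstream cells expect them to be filled. If time is short, focus on Sections 1–5 and use the provided small training loop.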



### Why discrete diffusion?
- Data are categorical/quantized (tokens, color bins); Gaussian noise is a poor fit.
- Forward process: multiply by categorical transition matrices instead of adding Gaussian noise.
- Reverse process: model predicts $x_0$ logits; exact posterior $q(x_{t-1}\mid x_t, x_0)$ is closed-form.
- Sampling uses Gumbel-max instead of reparameterized Gaussians.

We will derive each component, implement it, and immediately probe it with toy checks.


## Setup
- Keep runtime light: we use small channel sizes and few diffusion steps so everything fits in a class session.
- A GPU is used automatically if available; CPU also works, but training and sampling are slower.


In [None]:
import math
import random
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
import matplotlib.pyplot as plt
import seaborn as sns

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', DEVICE)


In [None]:
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(42)


## Section 1 – Discrete forward process

We corrupt categorical data with a transition matrix $Q^{(t)}$ at each step. For simplicity we start with **uniform corruption** governed by a noise schedule $\beta_t$.

- One-step matrix (uniform): $Q^{(t)} = (1 - \beta_t) I + \beta_t \cdot \frac{1}{K}\mathbf{1}\mathbf{1}^T$.
- Cumulative: $Q_{1:t} = Q^{(1)} Q^{(2)} \dots Q^{(t)}$.

We will: (1) build these matrices, (2) visualize how the distribution flattens as $t$ grows.
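As a concrete instance of the one-step formula, take $K=3$ and $\beta_t=0.3$ (illustrative values, not the notebook's schedule): the diagonal is $1-\beta_t+\beta_t/K = 0.8$ and every off-diagonal entry is $\beta_t/K = 0.1$, so each row sums to 1. A minimal NumPy sketch:

```python
import numpy as np

K, beta = 3, 0.3  # illustrative values, not the notebook's schedule

# Q = (1 - beta) I + beta * (1/K) 1 1^T
Q = (1 - beta) * np.eye(K) + beta * np.ones((K, K)) / K

print(Q)
print("row sums:", Q.sum(axis=1))
```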



**Theory:** For $K$ categories and noise schedule $\beta_t$, the one-step corruption is
$$Q^{(t)} = (1-\beta_t) I + \beta_t \frac{1}{K}\mathbf{1}\mathbf{1}^T,$$
so each step keeps the current label with probability $1-\beta_t$ and otherwise resamples it uniformly from all $K$ classes. The cumulative matrix
$$Q_{1:t} = Q^{(1)} Q^{(2)} \cdots Q^{(t)}$$
flattens the distribution as $t \to T$ (entropy increases). We will build these matrices and verify row-stochasticity and entropy growth.
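For uniform corruption the cumulative product has a closed form. Every uniform matrix has the shape $aI + (1-a)U$ with $U=\tfrac{1}{K}\mathbf{1}\mathbf{1}^T$ and $U^2=U$, and $(aI+(1-a)U)(bI+(1-b)U) = abI + (1-ab)U$, so $Q_{1:t} = \bar\alpha_t I + (1-\bar\alpha_t)U$ with $\bar\alpha_t = \prod_{s\le t}(1-\beta_s)$. (The symbol $\bar\alpha_t$ is borrowed from Gaussian DDPM notation; it is not used elsewhere in this notebook.) The NumPy sketch below checks the closed form against explicit matrix products on the same toy schedule as the implementation cell, and prints the entropy of the row for $x_0=0$:

```python
import numpy as np

K, T = 4, 10
betas = np.linspace(1e-3, 0.15, T)   # toy schedule (same values as the exploration cell)
U = np.ones((K, K)) / K               # "jump anywhere" matrix, U @ U == U

# Explicit products Q_{1:t} = Q^(1) ... Q^(t)
q_cum = []
acc = np.eye(K)
for b in betas:
    acc = acc @ ((1 - b) * np.eye(K) + b * U)
    q_cum.append(acc)
q_cum = np.stack(q_cum)

# Closed form: alpha_bar_t I + (1 - alpha_bar_t) U
alpha_bar = np.cumprod(1 - betas)
closed = alpha_bar[:, None, None] * np.eye(K) + (1 - alpha_bar)[:, None, None] * U

row = q_cum[:, 0, :]                  # q(x_t | x_0 = 0) for t = 1..T
entropy = -(row * np.log(row)).sum(axis=1)
print("max |product - closed form|:", np.abs(q_cum - closed).max())
print("entropy by t:", np.round(entropy, 3))
```

The entropy approaches, but never exceeds, $\log K$, the entropy of the uniform distribution.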



### Exercise 1: One-step and cumulative transition matrices
Implement `make_uniform_q_onestep` and `cumulative_q` so rows sum to 1 and cumulative entropy grows with t.


In [None]:

# Check: shapes and entropy monotonicity
K, T = 3, 5
betas = torch.linspace(0.01, 0.2, T)
q1 = make_uniform_q_onestep(K, betas)
assert q1.shape == (T, K, K)
assert torch.allclose(q1.sum(-1), torch.ones_like(q1.sum(-1)))
qcum = cumulative_q(q1)
x0 = torch.tensor([0])
probs_t = [qcum[t, x0] for t in range(T)]
ents = [-(p * p.log()).sum().item() for p in probs_t]
print("Entropies by t:", ents)
assert all(ents[i] <= ents[i+1] + 1e-6 for i in range(len(ents)-1))
print("Exercise 1 check passed")


In [None]:

def make_uniform_q_onestep(num_classes: int, betas: torch.Tensor) -> torch.Tensor:
    # Build T one-step transition matrices for uniform corruption.
    mats = []
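    for beta in betas:
        beta = float(beta)  # plain Python float: avoids device/dtype mixing if betas live on the GPU
        # with prob (1 - beta) stay; otherwise jump uniformly over all K classes
        mat = torch.ones(num_classes, num_classes, dtype=torch.float64) * (beta / num_classes)
        mat.fill_diagonal_(1 - (num_classes - 1) * beta / num_classes)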
        mats.append(mat)
    return torch.stack(mats, dim=0)


def cumulative_q(q_onestep: torch.Tensor) -> torch.Tensor:
    # Multiply one-step matrices to get cumulative Q_{1:t} for all t.
    qs = []
    acc = q_onestep[0]
    qs.append(acc)
    for i in range(1, q_onestep.shape[0]):
        acc = acc @ q_onestep[i]
        qs.append(acc)
    return torch.stack(qs, dim=0)

# toy example (kept for exploration)
K = 4
T = 10
betas = torch.linspace(1e-3, 0.15, T)
q1 = make_uniform_q_onestep(K, betas)
qcum = cumulative_q(q1)

fig, axes = plt.subplots(1, 3, figsize=(9, 3))
for ax, t in zip(axes, [0, 4, 9]):
    sns.heatmap(qcum[t].numpy(), ax=ax, vmin=0, vmax=1, cmap='magma', cbar=False)
    ax.set_title(f'Q_1:{t+1}')
plt.tight_layout(); plt.show()


## Section 2 – Posterior $q(x_{t-1} \mid x_t, x_0)$

For the uniform case, the exact posterior is proportional to:
$$q(x_{t-1}\mid x_t, x_0) \propto Q^{(t)}(x_t \mid x_{t-1}) \cdot Q_{1:(t-1)}(x_{t-1}\mid x_0)$$

We work in logits, then sample via Gumbel-max. This mirrors Eq. (3) in D3PM.
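The Gumbel-max trick: if $g_k$ are i.i.d. standard Gumbel variables, then $\arg\max_k(\log p_k + g_k)$ is distributed as $\mathrm{Cat}(p)$. It also works with unnormalized logits, because adding a constant to every logit does not change the argmax. A quick NumPy sketch with an arbitrary $p$:

```python
import numpy as np

rng = np.random.default_rng(0)
p = np.array([0.6, 0.3, 0.1])              # arbitrary target distribution
logits = np.log(p) + 5.0                   # a constant shift must not matter

n = 100_000
u = rng.uniform(1e-12, 1.0, size=(n, p.size))
gumbel = -np.log(-np.log(u))               # standard Gumbel noise
samples = np.argmax(logits + gumbel, axis=1)

freq = np.bincount(samples, minlength=p.size) / n
print("target:", p, "empirical:", np.round(freq, 3))
```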



**Theory:** The exact posterior (Eq. 3 in D3PM) decomposes into two factors:
$$q(x_{t-1}\mid x_t,x_0) \propto Q^{(t)}(x_t\mid x_{t-1}) \cdot Q_{1:t-1}(x_{t-1}\mid x_0).$$
We operate in logits, sum the log-factors, and sample via Gumbel-max. At $t=1$ we have $x_{t-1}=x_0$, so the posterior is simply the distribution over $x_0$; the code returns the $x_0$ logits directly in that case.
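The proportionality hides a normalizer. By Bayes' rule it is $q(x_t\mid x_0) = Q_{1:t}(x_t\mid x_0)$: summing the unnormalized product over $x_{t-1}$ must reproduce the corresponding entry of the cumulative matrix. A small NumPy sketch (toy $K$, schedule and indices chosen only for illustration) confirms this, using the row-stochastic convention $Q[i, j] = q(x_t = j \mid x_{t-1} = i)$ from Section 1:

```python
import numpy as np

K = 3
betas = np.array([0.05, 0.1, 0.2])    # toy schedule, t = 1..3
onestep = [(1 - b) * np.eye(K) + b * np.ones((K, K)) / K for b in betas]

def cum(t):
    # Q_{1:t}; Q_{1:0} is the identity
    out = np.eye(K)
    for Q in onestep[:t]:
        out = out @ Q
    return out

t, x0, xt = 3, 0, 2
# unnormalized posterior over x_{t-1}: Q^(t)(x_t | x_{t-1}) * Q_{1:t-1}(x_{t-1} | x_0)
unnorm = onestep[t - 1][:, xt] * cum(t - 1)[x0, :]
posterior = unnorm / unnorm.sum()

print("posterior over x_{t-1}:", np.round(posterior, 4))
print("normalizer:", unnorm.sum(), "vs Q_{1:t}[x0, xt]:", cum(t)[x0, xt])
```

Because the softmax in the notebook normalizes anyway, the code never needs this denominator explicitly.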



### Exercise 2: Posterior logits and Gumbel sampling
Complete `gather_at` and `q_posterior_logits` (and use `gumbel_max_sample`).


In [None]:

# Check: posterior probabilities match empirical samples
K = 3
betas = torch.linspace(0.01, 0.05, 4)
q1 = make_uniform_q_onestep(K, betas)
qcum = cumulative_q(q1)
q1T = q1.transpose(1, 2)

x0 = torch.tensor([0, 1])
xt = torch.tensor([1, 2])
t = torch.tensor([3, 3])

logits = q_posterior_logits(x0, xt, t, q1T, qcum)
probs = torch.softmax(logits, dim=-1)
# empirical
N = 2000
counts = torch.zeros_like(probs)
for _ in range(N):
    s = gumbel_max_sample(logits)
    for k in range(K):
        counts[..., k] += (s == k).float()
emp = counts / N
print("Analytic probs:\n", probs)
print("Empirical probs:\n", emp)
assert torch.allclose(probs, emp, atol=0.05)
print("Exercise 2 check passed")


In [None]:

def gather_at(a: torch.Tensor, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Index helper: select rows of a by time t and value x.
    t_shape = (x.shape[0],) + (1,) * (x.dim() - 1)
    t = t.view(t_shape).to(torch.long)
    return a[t - 1, x, :]


def q_posterior_logits(x0_logits: torch.Tensor, xt: torch.Tensor, t: torch.Tensor,
                       q_one_step_T: torch.Tensor, q_cum: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # Compute logits for q(x_{t-1} | x_t, x_0) = log Q^{(t)}(x_t|x_{t-1}) + log Q_{1:t-1}(x_{t-1}|x_0)
    if x0_logits.dtype in (torch.int32, torch.int64):
        x0_logits = torch.log(F.one_hot(x0_logits, q_cum.shape[-1]).float() + eps)
    soft = torch.softmax(x0_logits, dim=-1)

    fact1 = gather_at(q_one_step_T, t, xt)  # from x_t
    qm = q_cum[t - 2].to(dtype=soft.dtype)   # from x_0
    fact2 = torch.einsum('b...k,bkd->b...d', soft, qm)

    logits = torch.log(fact1 + eps) + torch.log(fact2 + eps)
    t_b = t.view((t.shape[0],) + (1,) * xt.dim())
    return torch.where(t_b == 1, x0_logits, logits)


def gumbel_max_sample(logits: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Categorical sampling via Gumbel-max
    noise = -torch.log(-torch.log(torch.clamp(torch.rand_like(logits), eps, 1.0 - eps)))
    return torch.argmax(logits + noise, dim=-1)

# toy check
B = 3
x0 = torch.tensor([0, 1, 2])
xt = torch.tensor([1, 2, 0])
t = torch.tensor([3, 3, 3])
q_one_step_T = q1.transpose(1, 2)
logits = q_posterior_logits(x0, xt, t, q_one_step_T, qcum)
samples = gumbel_max_sample(logits)
print('posterior logits shape', logits.shape)
print('samples', samples)


## Section 3 – Hybrid loss (VB + CE)
We combine a variational-bound (VB) term, which matches the model posterior to the true posterior, with a cross-entropy on the predicted $x_0$ logits, which stabilizes training.



**Theory:** Training uses a hybrid objective:
- Variational bound term $\mathrm{VB} = \mathbb{E}\big[\mathrm{KL}\big(q(x_{t-1}\mid x_t,x_0)\,\|\,q_\theta(x_{t-1}\mid x_t)\big)\big]$ over $t$.
- Cross-entropy on predicted $x_0$ logits to stabilize learning.
Total: $\mathcal{L} = \lambda\,\mathrm{VB} + \mathrm{CE}(\hat{x}_0, x_0)$, where $\lambda$ is `hybrid_coeff` (default `0.001`).
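The VB term is a KL divergence between two categorical distributions given as logits, $\mathrm{KL}(p\,\|\,q) = \sum_k p_k(\log p_k - \log q_k)$ with $p = \mathrm{softmax}(\ell_1)$ and $q = \mathrm{softmax}(\ell_2)$. One consequence worth remembering when writing tests: adding the same constant to every logit leaves the softmax, and therefore the KL, unchanged; only a non-uniform change produces a positive divergence. A NumPy sketch (`kl_from_logits` is an illustrative helper, not the notebook's `vb`):

```python
import numpy as np

def log_softmax(l):
    l = l - l.max(axis=-1, keepdims=True)   # numerical stability
    return l - np.log(np.exp(l).sum(axis=-1, keepdims=True))

def kl_from_logits(l1, l2):
    # mean over rows of KL(softmax(l1) || softmax(l2))
    lp, lq = log_softmax(l1), log_softmax(l2)
    return np.mean(np.sum(np.exp(lp) * (lp - lq), axis=-1))

rng = np.random.default_rng(0)
l1 = rng.normal(size=(4, 5))

same = kl_from_logits(l1, l1)
shifted = kl_from_logits(l1, l1 + 1.0)      # uniform shift: same distribution
perturbed = l1.copy()
perturbed[:, 0] += 1.0                      # non-uniform change: different distribution
changed = kl_from_logits(l1, perturbed)
print(same, shifted, changed)
```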



### Exercise 3: Hybrid loss pieces
Implement `vb` and ensure `hybrid_loss` returns sensible values.


In [None]:

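# Check: vb ~ 0 for identical logits; > 0 when the distributions differ.
# (Adding the same constant to all logits leaves the softmax unchanged, so it would also give vb ~ 0.)
logits = torch.randn(4, 5)
assert vb(logits, logits) < 1e-6
logits2 = logits.clone()
logits2[..., 0] += 1.0  # non-uniform change -> different distribution
assert vb(logits, logits2) > 0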

# Hybrid loss shape sanity
B = 2; K = 4
model_logits = torch.randn(B, 1, 1, 1, K)
x = torch.randint(0, K, (B, 1, 1, 1))
xt = torch.randint(0, K, (B, 1, 1, 1))
t = torch.tensor([2, 2])
q1 = make_uniform_q_onestep(K, torch.linspace(0.01, 0.1, 3))
qcum = cumulative_q(q1)
q1T = q1.transpose(1, 2)
loss, info = hybrid_loss(model_logits, x, xt, t, q1T, qcum)
print('Loss', loss.item(), 'info', info)
print("Exercise 3 check passed")


In [None]:

def vb(dist1_logits: torch.Tensor, dist2_logits: torch.Tensor, eps: float = 1e-9) -> torch.Tensor:
    # KL(softmax(dist1) || softmax(dist2)), averaged over all positions after flattening.
    dist1 = dist1_logits.reshape(-1, dist1_logits.shape[-1])
    dist2 = dist2_logits.reshape(-1, dist2_logits.shape[-1])
    p1 = torch.softmax(dist1 + eps, dim=-1)
    return torch.mean(torch.sum(p1 * (torch.log_softmax(dist1 + eps, dim=-1) - torch.log_softmax(dist2 + eps, dim=-1)), dim=-1))


def hybrid_loss(model_logits: torch.Tensor, x: torch.Tensor, xt: torch.Tensor, t: torch.Tensor,
                q_one_step_T: torch.Tensor, q_cum: torch.Tensor, hybrid_coeff: float = 0.001):
    # Hybrid objective: lambda * VB + CE on x0.
    true_post = q_posterior_logits(x, xt, t, q_one_step_T, q_cum)
    pred_post = q_posterior_logits(model_logits, xt, t, q_one_step_T, q_cum)
    vb_loss = vb(true_post, pred_post)

    B = model_logits.shape[0]
    ce_loss = F.cross_entropy(model_logits.reshape(B, -1, model_logits.shape[-1]).flatten(0,1), x.flatten(), reduction='mean')
    total = hybrid_coeff * vb_loss + ce_loss
    return total, {'vb': vb_loss.item(), 'ce': ce_loss.item()}


## Section 4 – Small conditional model (UNet-lite)
A compact UNet-style network to predict $x_0$ logits given $x_t$, time $t$, and optional class label.



**Theory:** A compact UNet-like model predicts $x_0$ logits conditioned on time $t$ and optional class $y$.
- Sinusoidal time embedding modulates features at multiple scales.
- Label embeddings are added per scale for class-conditional sampling.
- Output logits are reshaped to `[B, C, H, W, K]` (per-pixel categorical logits).
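The final $1\times1$ convolution produces a `[B, K*C, H, W]` tensor, and turning it into `[B, C, H, W, K]` needs a permutation that moves only the class axis. Swapping axes 2 and -1 (what `transpose(2, -1)` does) yields `[B, C, W, H, K]` instead, which silently transposes the image whenever $H = W$ because the shapes still match. A NumPy sketch of the bookkeeping (`np.swapaxes` and `np.transpose` follow the same axis semantics as torch's `transpose` and `permute`):

```python
import numpy as np

B, C, K, H, W = 1, 1, 4, 3, 3
out = np.arange(B * C * K * H * W).reshape(B, C * K, H, W)   # stand-in for the conv output
x = out.reshape(B, C, K, H, W)

swapped = np.swapaxes(x, 2, -1)                 # -> [B, C, W, H, K]
permuted = np.transpose(x, (0, 1, 3, 4, 2))     # -> [B, C, H, W, K]

# the logits for pixel (i, j) should be x[0, 0, :, i, j]
i, j = 0, 2
print("permute:", permuted[0, 0, i, j], "swap:", swapped[0, 0, i, j], "expected:", x[0, 0, :, i, j])
```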


In [None]:

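# NOTE: conv_block, up_block and sinusoidal_time_embed are used below but not defined
# elsewhere in this notebook; these are minimal assumed versions so the model runs.
# GroupNorm(8, ...) requires channel counts divisible by 8 (true for base_channels=32).
def sinusoidal_time_embed(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Transformer-style sinusoidal embedding: [B] float timesteps -> [B, dim] (dim even)
    half = dim // 2
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device, dtype=torch.float32) / half)
    args = t[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)


def conv_block(in_ch: int, out_ch: int) -> nn.Module:
    # Two 3x3 convs with GroupNorm + SiLU; keeps the spatial size
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.GroupNorm(8, out_ch), nn.SiLU(),
        nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.GroupNorm(8, out_ch), nn.SiLU(),
    )


def up_block(in_ch: int, out_ch: int) -> nn.Module:
    # 2x upsampling with a transposed conv, then a conv block
    return nn.Sequential(nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2), conv_block(out_ch, out_ch))


class TinyX0Model(nn.Module):
    """Small UNet-like predictor for x0 logits."""
    def __init__(self, n_classes: int, img_channels: int = 1, base_channels: int = 32, cond_classes: int = 10, t_dim: int = 32):
        super().__init__()
        self.t_dim = t_dim
        # Project time and condition to feature dims matching each scale
        self.t_proj1 = nn.Linear(t_dim, base_channels)
        self.t_proj2 = nn.Linear(t_dim, base_channels * 2)
        self.cond_emb = nn.Embedding(cond_classes, base_channels)
        self.cond_proj2 = nn.Linear(base_channels, base_channels * 2)

        self.down1 = conv_block(img_channels, base_channels)
        self.down2 = conv_block(base_channels, base_channels * 2)
        self.pool = nn.AvgPool2d(2)

        self.up1 = up_block(base_channels * 2, base_channels)
        self.up2 = conv_block(base_channels, base_channels)
        self.final = nn.Conv2d(base_channels, n_classes * img_channels, 1)
        self.n_classes = n_classes

    def forward(self, x, t, y=None):
        # Scale integer classes {0, ..., K-1} to [-1, 1] for stability
        x = (2 * x.float() / (self.n_classes - 1)) - 1.0
        emb_t = sinusoidal_time_embed(t.float(), self.t_dim)
        t1 = self.t_proj1(emb_t).unsqueeze(-1).unsqueeze(-1)
        t2 = self.t_proj2(emb_t).unsqueeze(-1).unsqueeze(-1)

        cond1 = 0
        cond2 = 0
        if y is not None:
            cond_raw = self.cond_emb(y)
            cond1 = cond_raw.unsqueeze(-1).unsqueeze(-1)
            cond2 = self.cond_proj2(cond_raw).unsqueeze(-1).unsqueeze(-1)

        h1 = self.down1(x) + t1 + cond1                 # [B, base, H, W]
        h2 = self.down2(self.pool(h1)) + t2 + cond2     # [B, 2*base, H/2, W/2]
        h = self.up1(h2)                                # back to [B, base, H, W]
        h = self.up2(h + h1)                            # skip connection
        out = self.final(h)                             # [B, K*C, H, W]
        # [B, C, K, H, W] -> [B, C, H, W, K]: move only the class axis to the end
        out = out.reshape(out.shape[0], -1, self.n_classes, *x.shape[2:]).permute(0, 1, 3, 4, 2).contiguous()
        return out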


## Section 5 – D3PM wrapper (forward, loss, sampling)
Slim wrapper around forward noising, loss, and sampling. Based on the minimal implementation in `d3pm_runner.py`.



**Theory:** The wrapper handles forward noising, posterior utilities, loss, and reverse sampling.
- Forward: sample $x_t \sim q(x_t\mid x_0)$ using the cumulative $Q_{1:t}$ and Gumbel noise.
- Reverse step: model logits $\to$ posterior logits $q_\theta(x_{t-1}\mid x_t)$ $\to$ Gumbel-max sample.
- Sampling loop: start from uniform noise and iterate $t = T-1, \dots, 1$ (training draws $t$ from $\{1, \dots, T-1\}$), skipping the noise on the final step.
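To see the whole reverse loop in isolation, here is a NumPy sketch that replaces the network with a hypothetical *oracle* that always predicts the true $x_0$, so it tests only the plumbing, not learning. It follows the conventions of the `D3PM` class in this section: start from uniform noise at $t = T-1$, compute posterior logits at each step, add Gumbel noise except at $t = 1$, where the posterior logits are the $x_0$ logits and the step is a plain argmax. With a perfect predictor the chain must end exactly at $x_0$.

```python
import numpy as np

rng = np.random.default_rng(0)
K, T, eps = 4, 20, 1e-9
betas = np.linspace(1e-3, 0.2, T)            # toy schedule
onestep = np.stack([(1 - b) * np.eye(K) + b * np.ones((K, K)) / K for b in betas])
q_cum = np.empty_like(onestep)               # q_cum[t-1] = Q_{1:t}
acc = np.eye(K)
for i, Q in enumerate(onestep):
    acc = acc @ Q
    q_cum[i] = acc

x0 = rng.integers(0, K, size=16)             # "data": 16 categorical pixels
x0_onehot = np.eye(K)[x0]
x0_logits = np.log(x0_onehot + eps)          # oracle prediction (hypothetical model)

xt = rng.integers(0, K, size=x0.shape)       # start from uniform noise
for t in range(T - 1, 0, -1):                # t = T-1, ..., 1
    if t == 1:
        logits = x0_logits                   # x_{t-1} = x_0: use the x0 logits directly
    else:
        fact1 = onestep[t - 1][:, xt].T      # Q^(t)(x_t | x_{t-1}) per pixel, shape [16, K]
        fact2 = x0_onehot @ q_cum[t - 2]     # Q_{1:t-1}(x_{t-1} | x_0), shape [16, K]
        logits = np.log(fact1 + eps) + np.log(fact2 + eps)
    noise = 0.0 if t == 1 else -np.log(-np.log(rng.uniform(1e-12, 1.0, logits.shape)))
    xt = np.argmax(logits + noise, axis=-1)

print("recovered x0:", np.array_equal(xt, x0))
```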



### Exercise 4: Forward noising `_q_sample`
Use the cumulative Q and Gumbel to sample $x_t$ from $x_0$.
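Forward noising is a single categorical draw per pixel: take the row $Q_{1:t}(\cdot\mid x_0)$, add Gumbel noise to its log, and take the argmax. For uniform corruption that row is $\bar\alpha_t\,\mathbf{e}_{x_0} + (1-\bar\alpha_t)/K$ with $\bar\alpha_t=\prod_{s\le t}(1-\beta_s)$, so the probability of keeping the original class is $\bar\alpha_t + (1-\bar\alpha_t)/K$, which gives an easy statistical check. A NumPy sketch (toy schedule; `q_sample_np` is a hypothetical helper, not the notebook's `_q_sample`):

```python
import numpy as np

rng = np.random.default_rng(0)
K, T, eps = 4, 50, 1e-6
betas = np.linspace(1e-3, 0.1, T)
alpha_bar = np.cumprod(1 - betas)

def q_sample_np(x0, t):
    # closed-form row Q_{1:t}[x0, :] for uniform corruption (t is 1-indexed)
    a = alpha_bar[t - 1]
    probs = a * np.eye(K)[x0] + (1 - a) / K
    u = np.clip(rng.uniform(size=probs.shape), eps, 1 - eps)
    gumbel = -np.log(-np.log(u))             # uniform -> standard Gumbel
    return np.argmax(np.log(probs + eps) + gumbel, axis=-1)

x0 = rng.integers(0, K, size=200_000)
t = 30
xt = q_sample_np(x0, t)
kept = np.mean(xt == x0)
expected = alpha_bar[t - 1] + (1 - alpha_bar[t - 1]) / K
print(f"kept fraction {kept:.3f} vs expected {expected:.3f}")
```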



### Exercise 5: One reverse step `p_sample`
Use model logits → posterior → Gumbel-max; mask noise when t==1.


In [None]:

# Check: with identity logits and small betas, p_sample keeps xt with overwhelming probability (the step is stochastic)
class IdentityModel(nn.Module):
    def __init__(self, K): super().__init__(); self.K = K
    def forward(self, xt, t, y=None):
        return torch.log(F.one_hot(xt, self.K).float() + 1e-9)

K = 3
betas = torch.linspace(0.01, 0.05, 4)
d3pm_id = D3PM(IdentityModel(K), n_T=4, num_classes=K, betas=betas)
xt = torch.tensor([[0,1],[1,2]])
t = torch.tensor([2,2])
xt_prev = d3pm_id.p_sample(xt, t)
assert torch.equal(xt_prev, xt)
print("Exercise 5 check passed")


In [None]:

# Check: _q_sample outputs valid classes
K = 3
betas = torch.linspace(0.01, 0.05, 4)
tmp_model = TinyX0Model(K)
d3pm_tmp = D3PM(tmp_model, n_T=4, num_classes=K, betas=betas)
x0 = torch.tensor([[0,1],[1,2]])
t = torch.tensor([2,3])
noise = torch.rand((*x0.shape, K))
xt = d3pm_tmp._q_sample(x0, t, noise)
assert xt.shape == x0.shape
assert (xt >= 0).all() and (xt < K).all()
print("Exercise 4 check passed")


In [None]:

class D3PM(nn.Module):
    """Minimal D3PM wrapper: forward noising, loss, reverse sampling."""
    def __init__(self, x0_model: nn.Module, n_T: int, num_classes: int, betas: torch.Tensor, hybrid_coeff: float = 0.001):
        super().__init__()
        self.x0_model = x0_model
        self.n_T = n_T
        self.num_classes = num_classes
        self.hybrid_coeff = hybrid_coeff

        q_onestep = make_uniform_q_onestep(num_classes, betas)
        q_cum = cumulative_q(q_onestep)
        self.register_buffer('q_onestep', q_onestep)
        self.register_buffer('q_one_step_T', q_onestep.transpose(1, 2))
        self.register_buffer('q_cum', q_cum)
        self.eps = 1e-6

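    def _q_sample(self, x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        # Sample x_t ~ q(x_t | x0): log of the row Q_{1:t}[x0, :] plus Gumbel noise, then argmax.
        logits = torch.log(gather_at(self.q_cum, t, x0) + self.eps)
        noise = torch.clamp(noise, self.eps, 1.0 - self.eps)
        gumbel = -torch.log(-torch.log(noise))  # uniform -> standard Gumbel
        return torch.argmax(logits + gumbel, dim=-1)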

    def forward(self, x: torch.Tensor, y: torch.Tensor = None):
        # Training step: pick random t, corrupt x0->xt, predict logits, compute loss.
        B = x.shape[0]
        t = torch.randint(1, self.n_T, (B,), device=x.device)
        noise = torch.rand((*x.shape, self.num_classes), device=x.device)
        xt = self._q_sample(x, t, noise)

        model_logits = self.x0_model(xt, t, y)
        loss, info = hybrid_loss(model_logits, x, xt, t, self.q_one_step_T, self.q_cum, self.hybrid_coeff)
        return loss, info

    @torch.no_grad()
    def p_sample(self, xt: torch.Tensor, t: torch.Tensor, y: torch.Tensor = None) -> torch.Tensor:
        # One reverse step: model logits -> posterior -> sample (skip noise at t=1).
        model_logits = self.x0_model(xt, t, y)
        pred_post = q_posterior_logits(model_logits, xt, t, self.q_one_step_T, self.q_cum)
        noise = torch.rand((*xt.shape, self.num_classes), device=xt.device)
        t_b = t.view((t.shape[0],) + (1,) * xt.dim())
        mask = (t_b != 1).float()
        gumbel = -torch.log(-torch.log(torch.clamp(noise, self.eps, 1.0 - self.eps)))
        sample = torch.argmax(pred_post + gumbel * mask, dim=-1)
        return sample

    @torch.no_grad()
    def sample(self, shape, y=None, stride: int = 20):
        # Full reverse chain for generation; collect frames every `stride`.
        B = shape[0]
        xt = torch.randint(0, self.num_classes, shape, device=self.q_cum.device)
        imgs = []
        for step, t_int in enumerate(reversed(range(1, self.n_T))):
            t = torch.tensor([t_int] * B, device=xt.device)
            xt = self.p_sample(xt, t, y)
            if step % stride == 0 or t_int == 1:
                imgs.append(xt.detach().cpu())
        return imgs


## Section 6 – Data and discretization
Use MNIST and discretize pixel values into a small number of bins (`N_CLASSES`) to keep the state space small and sampling fast. The discretization itself happens in the training loop.
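The discretization maps an intensity $x\in[0,1]$ to $\mathrm{round}(x\,(N-1))$ with $N$ = `N_CLASSES`. A quick NumPy sketch with hand-picked intensities: rounding makes the two outer bins half as wide as the inner ones, and both NumPy and PyTorch round exact halves to the nearest even integer (so $0.5 \cdot 3 = 1.5$ becomes 2).

```python
import numpy as np

N_CLASSES = 4
x = np.array([0.0, 0.1, 0.2, 0.5, 0.9, 1.0])      # example intensities in [0, 1]
x_disc = np.clip(np.round(x * (N_CLASSES - 1)), 0, N_CLASSES - 1).astype(np.int64)
x_back = x_disc / (N_CLASSES - 1)                   # same decoding used for plotting
print(x_disc, np.round(x_back, 3))
```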


In [None]:
N_CLASSES = 4  # discretization bins
IMG_CHANNELS = 1
BATCH_SIZE = 128
N_T = 200  # diffusion steps (small for speed)
BETAS = torch.linspace(1e-3, 0.1, N_T)

transform = transforms.Compose([transforms.ToTensor()])
train_ds = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
print('Train set size', len(train_ds))


## Section 7 – Training (short run)
A brief training loop keeps runtime manageable. The model will be far from converged, but it is enough to draw (noisy) samples.



**Theory/practice:** Training samples a random $t$, corrupts $x_0 \to x_t$, predicts logits, and applies the hybrid loss.
- We keep epochs/steps small for classroom runtime; expect noisy but usable samples.
- Gradient clipping helps with stability on small models.
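`clip_grad_norm_(params, 1.0)` computes one global L2 norm over all gradients and, if it exceeds the threshold, rescales every gradient by the same factor (about `max_norm / total_norm`), so the update direction is preserved while its length is capped. A NumPy sketch of that rule with made-up gradient arrays (`clip_by_global_norm` is an illustrative helper, not a PyTorch function):

```python
import numpy as np

def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    # global L2 norm over all gradient arrays
    total = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = min(1.0, max_norm / (total + eps))   # never scales gradients up
    return [g * scale for g in grads], total

grads = [np.array([3.0, 4.0]), np.array([[12.0]])]   # global norm = 13
clipped, before = clip_by_global_norm(grads, max_norm=1.0)
after = np.sqrt(sum(np.sum(g ** 2) for g in clipped))
print(before, after)
```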


In [None]:
model = TinyX0Model(n_classes=N_CLASSES, img_channels=IMG_CHANNELS).to(DEVICE)
d3pm = D3PM(model, n_T=N_T, num_classes=N_CLASSES, betas=BETAS.to(DEVICE)).to(DEVICE)
opt = torch.optim.AdamW(d3pm.parameters(), lr=1e-3)

EPOCHS = 2  # keep small

for epoch in range(1, EPOCHS + 1):
    d3pm.train()
    running = {'loss': 0, 'vb': 0, 'ce': 0}
    for i, (x, y) in enumerate(train_loader):
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        x_disc = (x * (N_CLASSES - 1)).round().long().clamp(0, N_CLASSES - 1)

        opt.zero_grad()
        loss, info = d3pm(x_disc, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(d3pm.parameters(), 1.0)
        opt.step()

        running['loss'] += loss.item()
        running['vb'] += info['vb']
        running['ce'] += info['ce']
        if (i + 1) % 100 == 0:
            n = i + 1
            print(f"Epoch {epoch} Step {n}: loss {running['loss']/n:.3f} | vb {running['vb']/n:.3f} | ce {running['ce']/n:.3f}")
    print('Epoch done')


## Section 8 – Sampling
Reverse the chain to generate samples. We record frames every `stride` steps.
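Which frames end up in `imgs_seq`? `sample` walks $t = T-1, \dots, 1$ and stores a frame whenever the step counter is a multiple of `stride` or $t = 1$, so the last entry is always the fully denoised output. A pure-Python sketch of that bookkeeping for `N_T = 200` and `stride = 40`:

```python
def recorded_timesteps(n_T, stride):
    # mirrors the loop in D3PM.sample: `step` counts up while t counts down
    ts = []
    for step, t in enumerate(reversed(range(1, n_T))):
        if step % stride == 0 or t == 1:
            ts.append(t)
    return ts

print(recorded_timesteps(200, 40))  # [199, 159, 119, 79, 39, 1]
```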


In [None]:
d3pm.eval()
with torch.no_grad():
    y = torch.arange(0, 16, device=DEVICE) % 10  # class labels
    imgs_seq = d3pm.sample((16, 1, 28, 28), y=y, stride=40)  # MNIST resolution used in training

final = imgs_seq[-1].float() / (N_CLASSES - 1)
grid = vutils.make_grid(final, nrow=4)
plt.figure(figsize=(6,6))
plt.axis('off')
plt.imshow(grid.permute(1,2,0), vmin=0, vmax=1, cmap='gray')
plt.show()


## Section 9 – Interactive probe (forward corruption)
Use a slider over $t$ to visualize how a single image is corrupted by $Q_{1:t}$.
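A corruption slider should *sample* $x_t$ rather than show the most likely class. For uniform corruption the row $Q_{1:t}(\cdot\mid x_0)$ equals $\bar\alpha_t + (1-\bar\alpha_t)/K$ on $x_0$ and $(1-\bar\alpha_t)/K$ elsewhere, with $\bar\alpha_t=\prod_{s\le t}(1-\beta_s)$, so its argmax is $x_0$ for every $t$ and an argmax-based picture would never change. A NumPy check using the notebook's schedule (`N_T = 200`, betas from `1e-3` to `0.1`, `N_CLASSES = 4`):

```python
import numpy as np

K, N_T = 4, 200
betas = np.linspace(1e-3, 0.1, N_T)
alpha_bar = np.cumprod(1 - betas)

rows = alpha_bar[:, None] * np.eye(K)[0] + (1 - alpha_bar[:, None]) / K   # q(x_t | x_0 = 0)
modes = rows.argmax(axis=1)
kept = rows[:, 0]                          # probability a pixel keeps its class
print("argmax ever differs from x0:", bool((modes != 0).any()))
print("P(keep) at t=1, 50, 200:", np.round(kept[[0, 49, 199]], 3))
```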


In [None]:
import ipywidgets as widgets

x0_img, _ = next(iter(train_loader))
x0_img = x0_img[:1].to(DEVICE)
x0_disc = (x0_img * (N_CLASSES - 1)).round().long().clamp(0, N_CLASSES - 1)

@widgets.interact(t=(1, min(50, N_T), 1))
def show_corruption(t=1):
    with torch.no_grad():
        logits = torch.log(gather_at(d3pm.q_cum, torch.tensor([t], device=DEVICE), x0_disc) + d3pm.eps)
        xt = gumbel_max_sample(logits)  # sample x_t ~ q(x_t | x_0); a plain argmax would always return x_0
        img = (xt.float() / (N_CLASSES - 1)).cpu()
    plt.figure(figsize=(3,3)); plt.axis('off');
    plt.imshow(img[0,0], cmap='gray', vmin=0, vmax=1)
    plt.title(f't={t}')
    plt.show()
