In [1]:
pip install matplotlib

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install gdown

Note: you may need to restart the kernel to use updated packages.


In [3]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Default compute device for the notebook: "cuda" when PyTorch reports a usable
# CUDA runtime, otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
import gdown
import os

# Google Drive file id of the checkpoint and the local path it is saved under.
file_id = "1WuKug4sbbvmNGrZ5PUaCUdw8tnM6xiqX"
out_path = "mnist_cnn_weights.pth"

# Only hit the network when there is no local copy yet.
if os.path.exists(out_path):
    print("EMA weights already downloaded.")
else:
    gdown.download(
        f"https://drive.google.com/uc?id={file_id}",
        out_path,
        quiet=False,
    )


Downloading...
From: https://drive.google.com/uc?id=1WuKug4sbbvmNGrZ5PUaCUdw8tnM6xiqX
To: /home/onyxia/work/denoising-diffusion-model/measurement code/mnist_cnn_weights.pth
100%|██████████| 1.01M/1.01M [00:00<00:00, 43.0MB/s]


In [5]:
import math
import torch
import torch.nn as nn
import torchvision.utils as vutils
import matplotlib.pyplot as plt

class ForwardOU:
    """Forward Ornstein-Uhlenbeck noising process with rate ``lambda_``.

    Exposes the three quantities the rest of the notebook needs:
    the mean scaling ``exp(-lambda * t)``, the standard deviation
    ``sqrt(1 - exp(-2 * lambda * t))`` and the constant diffusion
    coefficient ``sqrt(2 * lambda)``.
    """

    def __init__(self, lambda_=1.0):
        # Kept as a plain Python float.
        self.lmbd = float(lambda_)

    def mean(self, x0, t):
        """Return ``exp(-lambda * t) * x0``.

        ``t`` is indexed as a 1-D tensor ([B]); three singleton axes are
        appended to the decay factor so it broadcasts against ``x0``.
        """
        decay = torch.exp(-self.lmbd * t)
        return decay[:, None, None, None] * x0

    def std(self, t):
        """Return ``sqrt(1 - exp(-2 * lambda * t))`` with shape [B, 1, 1, 1]."""
        variance = 1.0 - torch.exp(-2.0 * self.lmbd * t)
        return torch.sqrt(variance)[:, None, None, None]

    def diffusion_coeff(self, t):
        """Return a tensor shaped like ``t`` filled with ``sqrt(2 * lambda)``."""
        g_const = math.sqrt(2.0 * self.lmbd)
        return torch.full_like(t, g_const)



In [6]:
class TimeEmbedding(nn.Module):
    """Sinusoidal time features followed by a two-layer MLP.

    Args:
        dim: width of the embedding. Must be even, because half of the
            features are sines and the other half cosines.
    """

    def __init__(self, dim=64):
        super().__init__()
        assert dim % 2 == 0, "dim must be even"
        self.dim = dim
        self.lin1 = nn.Linear(dim, dim)
        self.act = nn.SiLU()
        self.lin2 = nn.Linear(dim, dim)

    def forward(self, t):
        """
        t: [B] (the original note states values in [0,1])
        returns: [B, dim]
        """
        half = self.dim // 2
        # Log-spaced frequencies from 1 down to 1/10000. With dim == 2 there
        # is a single frequency and `half - 1` would be 0, so the denominator
        # is clamped to 1 to avoid a ZeroDivisionError; for half >= 2 the
        # value is unchanged.
        log_step = -math.log(10000.0) / max(half - 1, 1)
        freqs = torch.exp(
            torch.arange(half, device=t.device, dtype=t.dtype) * log_step
        )  # [half]
        args = t[:, None] * freqs[None, :]  # [B, half]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=1)  # [B, dim]
        # small MLP
        emb = self.lin1(emb)
        emb = self.act(emb)
        emb = self.lin2(emb)
        return emb


class ScoreNet(nn.Module):
    """Small time-conditioned convolutional network.

    The time embedding is broadcast over the spatial grid and concatenated to
    the single-channel input as ``time_dim`` extra channels, then passed
    through three 3x3 convolutions (padding 1) with SiLU in between.
    """

    def __init__(self, time_dim=64):
        super().__init__()
        self.time_mlp = TimeEmbedding(dim=time_dim)
        conv_stack = [
            nn.Conv2d(1 + time_dim, 64, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.SiLU(),
            nn.Conv2d(64, 1, 3, padding=1),
        ]
        self.net = nn.Sequential(*conv_stack)

    def forward(self, x, t):
        """Return the network output for images ``x`` at times ``t``."""
        t_feat = self.time_mlp(t)                      # [B, time_dim]
        height, width = x.size(2), x.size(3)
        # Add two singleton spatial axes, then broadcast to the image grid.
        t_map = t_feat[:, :, None, None].expand(-1, -1, height, width)
        stacked = torch.cat([x, t_map], dim=1)         # channels: image first, then time
        return self.net(stacked)

@torch.no_grad()
def sample_reverse_euler_maruyama(
    model,
    sde: ForwardOU,
    num_steps=1000,
    batch_size=64,
    device="cuda",
    t_min=1e-3,
    img_size=28,
):
    """Draw samples by integrating the reverse-time SDE with Euler-Maruyama.

    Starts from standard normal noise of shape
    [batch_size, 1, img_size, img_size] and walks a uniform time grid from 1.0
    down to ``t_min``. The grid has ``num_steps`` points, so the model is
    evaluated ``num_steps - 1`` times. The model is switched to eval mode and
    is not switched back. Returns the final state tensor.
    """
    model.eval()

    t_start = 1.0
    times = torch.linspace(t_start, t_min, num_steps, device=device)  # decreasing

    # Initial state: unit Gaussian noise, used as an approximation of the
    # forward marginal at t_start.
    x = torch.randn(batch_size, 1, img_size, img_size, device=device)

    for step in range(num_steps - 1):
        t_now = times[step]
        step_size = times[step + 1] - t_now  # negative: time runs backwards

        t_vec = torch.full((batch_size,), t_now, device=device)

        g_flat = sde.diffusion_coeff(t_vec)                 # [B]
        g_sq = (g_flat ** 2).view(batch_size, 1, 1, 1)      # [B,1,1,1]
        g_col = g_flat.view(batch_size, 1, 1, 1)

        score = model(x, t_vec)
        drift = -sde.lmbd * x - g_sq * score                # f - g^2 * score

        eps = torch.randn_like(x)
        x = x + drift * step_size + g_col * torch.sqrt(-step_size) * eps

    return x

def denorm(x):
    """Map values from [-1, 1] to [0, 1], clamping anything that falls outside."""
    rescaled = (x + 1.0) * 0.5
    return torch.clamp(rescaled, min=0.0, max=1.0)

@torch.no_grad()
def show_samples(x, nrow=8, title="Samples"):
    """Display a batch of images as one grayscale grid.

    The batch is detached, moved to the CPU, rescaled with ``denorm`` and
    tiled with ``nrow`` images per row before being shown with matplotlib.
    """
    images = denorm(x.detach().cpu())
    tiled = vutils.make_grid(images, nrow=nrow, padding=2)
    plt.figure(figsize=(6, 6))
    # Channels-first -> channels-last for imshow.
    tiled_hwc = tiled.permute(1, 2, 0).squeeze()
    plt.imshow(tiled_hwc, cmap="gray")
    plt.axis("off")
    plt.title(title)
    plt.show()




In [7]:
# Load the checkpoint from `out_path`, mapping its tensors onto `device`.
# NOTE(review): torch.load unpickles the file, and this file is fetched from a
# remote URL earlier in the notebook. Consider passing weights_only=True if the
# checkpoint contents allow it — TODO confirm against the saved objects.
ckpt = torch.load(out_path, map_location=device)

# Recover hyperparameters from the checkpoint, falling back to defaults:
#   - "cfg" may be missing or falsy, in which case an empty dict is used;
#   - time_dim comes from cfg["time_dim"] (default 64);
#   - lambda_ comes from the top-level "sde_lambda" key, else cfg["lambda_"],
#     else 1.0.
ckpt_cfg = ckpt.get("cfg", {}) or {}
time_dim = int(ckpt_cfg.get("time_dim", 64))
lambda_ = float(ckpt.get("sde_lambda", ckpt_cfg.get("lambda_", 1.0)))

print("Loaded checkpoint.")
print("time_dim:", time_dim)
print("lambda_ :", lambda_)

# Rebuild the SDE and the network with the recovered settings, then restore the
# weights stored under the "model" key and switch to eval mode.
sde = ForwardOU(lambda_=lambda_)
model = ScoreNet(time_dim=time_dim).to(device)
model.load_state_dict(ckpt["model"])
model.eval()


import time
import numpy as np
import torch
import torch.nn as nn

# ---------------------------
# Utils
# ---------------------------
def _sync_if_cuda(device):
    if isinstance(device, str):
        is_cuda = "cuda" in device
    else:
        is_cuda = (device.type == "cuda")
    if is_cuda and torch.cuda.is_available():
        torch.cuda.synchronize()

def _percentiles(xs, ps=(50, 95)):
    xs = np.asarray(xs, dtype=np.float64)
    return {f"p{p}": float(np.percentile(xs, p)) for p in ps}

class TimedModel(nn.Module):
    """Wrapper that counts forward calls (NFE) and accumulates their duration.

    Attributes:
        nfe: number of forward calls since the last ``reset_stats``.
        model_ms: total time spent inside the wrapped model, in milliseconds.
    """

    def __init__(self, model: nn.Module, device="cuda"):
        super().__init__()
        self.model = model
        self.device = device
        self.reset_stats()

    def reset_stats(self):
        """Zero the call counter and the accumulated forward time."""
        self.nfe, self.model_ms = 0, 0.0

    @torch.no_grad()
    def forward(self, *args, **kwargs):
        """Run the wrapped model under no_grad and record one timed evaluation."""
        self.nfe += 1
        use_cuda_events = torch.cuda.is_available() and ("cuda" in str(self.device))

        if not use_cuda_events:
            # Wall-clock timing on the host.
            tic = time.perf_counter()
            result = self.model(*args, **kwargs)
            toc = time.perf_counter()
            self.model_ms += (toc - tic) * 1000.0
            return result

        # CUDA events bracket the forward call; the synchronize makes sure
        # both events have completed before their elapsed time is read.
        ev_start = torch.cuda.Event(enable_timing=True)
        ev_end = torch.cuda.Event(enable_timing=True)
        ev_start.record()
        result = self.model(*args, **kwargs)
        ev_end.record()
        torch.cuda.synchronize()
        self.model_ms += ev_start.elapsed_time(ev_end)
        return result


# ---------------------------
# Point this to YOUR sampler
# ---------------------------
# Name of the sampling function to benchmark. It is looked up in the notebook's
# global namespace, so the function must already be defined when this runs;
# otherwise the assertion below fails with a descriptive message.
SAMPLE_FN_NAME = "sample_reverse_euler_maruyama"  # change if your notebook uses another name
assert SAMPLE_FN_NAME in globals(), f"Sampler {SAMPLE_FN_NAME} not found."
_sampler = globals()[SAMPLE_FN_NAME]


@torch.no_grad()
def sampler_timed(timed_model, sde, *, num_steps, batch_size, device, t_min):
    """Adapter that forwards the benchmark arguments to the notebook sampler.

    The sampler itself is called unchanged; if it uses different keyword
    names, adjust the mapping below.
    """
    sampler_kwargs = {
        "model": timed_model,
        "sde": sde,
        "num_steps": num_steps,
        "batch_size": batch_size,
        "device": device,
        "t_min": t_min,
    }
    return _sampler(**sampler_kwargs)


def run_sampling_benchmark(
    model: nn.Module,
    sde,
    sampler_fn,
    *,
    num_steps: int,
    batch_size: int,
    device="cuda",
    t_min=0.02,
    warmup_runs=2,
    timed_runs=10,
    reset_cuda_peak_mem=True,
):
    """Benchmark ``sampler_fn`` and summarise latency, NFE, throughput and memory.

    The model is wrapped in ``TimedModel`` (and moved to ``device``) so that
    forward calls are counted and timed separately from the rest of the
    sampling loop.

    Args:
        model: network handed to the sampler through the timing wrapper.
        sde: passed through to ``sampler_fn`` untouched.
        sampler_fn: callable invoked as
            ``sampler_fn(timed_model, sde, num_steps=..., batch_size=...,
            device=..., t_min=...)``.
        num_steps, batch_size, device, t_min: forwarded to ``sampler_fn``.
        warmup_runs: untimed runs executed first.
        timed_runs: number of measured runs; must be at least 1.
        reset_cuda_peak_mem: reset CUDA peak-memory counters before each
            timed run (only when running on CUDA).

    Returns:
        dict with the run configuration plus p50/p95 statistics:
        ``nfe_sample``, ``total_ms_img``, ``model_ms_img``, ``overhd_ms_img``,
        ``ms_forward_total``, ``ms_forward_model``, ``throughput`` and the
        scalar ``nfe_img_legacy_p50``. On CUDA, ``peak_alloc_mb_p50`` and
        ``peak_reserved_mb_p50`` are added.

    Raises:
        ValueError: if ``batch_size`` or ``timed_runs`` is smaller than 1.
    """
    # Fail fast: a zero batch size would only surface as a ZeroDivisionError
    # after the warmup and a full timed run, and zero timed runs would leave
    # nothing to compute percentiles from.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if timed_runs < 1:
        raise ValueError(f"timed_runs must be >= 1, got {timed_runs}")

    # Evaluated once instead of on every use inside the timed loop.
    on_cuda = torch.cuda.is_available() and ("cuda" in str(device))

    timed_model = TimedModel(model, device=device).to(device)

    # Warmup
    for _ in range(warmup_runs):
        timed_model.reset_stats()
        _sync_if_cuda(device)
        _ = sampler_fn(timed_model, sde, num_steps=num_steps, batch_size=batch_size, device=device, t_min=t_min)
        _sync_if_cuda(device)

    per_img_total_ms = []
    per_img_model_ms = []
    per_img_overhead_ms = []
    nfes = []
    ms_per_forward_total = []
    ms_per_forward_model = []
    throughput_img_s = []
    peak_alloc_mb = []
    peak_reserved_mb = []

    for _ in range(timed_runs):
        timed_model.reset_stats()

        if reset_cuda_peak_mem and on_cuda:
            torch.cuda.reset_peak_memory_stats()

        # Synchronise on both sides so the wall-clock interval covers all
        # queued GPU work of this run and nothing from the previous one.
        _sync_if_cuda(device)
        t0 = time.perf_counter()
        _ = sampler_fn(timed_model, sde, num_steps=num_steps, batch_size=batch_size, device=device, t_min=t_min)
        _sync_if_cuda(device)
        t1 = time.perf_counter()

        total_ms_batch = (t1 - t0) * 1000.0
        model_ms_batch = float(timed_model.model_ms)
        # Everything that is not a model forward call (sampler arithmetic,
        # noise generation, Python overhead).
        overhead_ms_batch = total_ms_batch - model_ms_batch

        nfe = int(timed_model.nfe)
        nfes.append(nfe)

        total_ms_img = total_ms_batch / batch_size
        per_img_total_ms.append(total_ms_img)
        per_img_model_ms.append(model_ms_batch / batch_size)
        per_img_overhead_ms.append(overhead_ms_batch / batch_size)

        # Guard against a sampler that never called the model.
        denom = max(nfe, 1)
        ms_per_forward_total.append(total_ms_batch / denom)
        ms_per_forward_model.append(model_ms_batch / denom)

        throughput_img_s.append(1000.0 / max(total_ms_img, 1e-12))

        if on_cuda:
            peak_alloc_mb.append(torch.cuda.max_memory_allocated() / (1024**2))
            peak_reserved_mb.append(torch.cuda.max_memory_reserved() / (1024**2))

    summary = {
        "num_steps": num_steps,
        "batch_size": batch_size,
        "t_min": t_min,
        "timed_runs": timed_runs,

        "nfe_sample": _percentiles(nfes, ps=(50, 95)),
        "nfe_img_legacy_p50": _percentiles([n / batch_size for n in nfes], ps=(50,))["p50"],

        "total_ms_img": _percentiles(per_img_total_ms, ps=(50, 95)),
        "model_ms_img": _percentiles(per_img_model_ms, ps=(50, 95)),
        "overhd_ms_img": _percentiles(per_img_overhead_ms, ps=(50, 95)),

        "ms_forward_total": _percentiles(ms_per_forward_total, ps=(50, 95)),
        "ms_forward_model": _percentiles(ms_per_forward_model, ps=(50, 95)),

        "throughput": _percentiles(throughput_img_s, ps=(50, 95)),
    }

    if peak_alloc_mb:
        summary["peak_alloc_mb_p50"] = _percentiles(peak_alloc_mb, ps=(50,))["p50"]
        summary["peak_reserved_mb_p50"] = _percentiles(peak_reserved_mb, ps=(50,))["p50"]

    return summary


def pretty_print_summary(summary: dict, name=""):
    """Print a benchmark summary dict as a fixed-format text report.

    ``summary`` is expected to carry the keys read below (configuration
    values, p50/p95 dicts and, optionally, the two peak-memory entries).
    """
    def _print_pair(label, key, digits):
        # One "p50 | p95" line with the requested number of decimals.
        stats = summary[key]
        print(f"{label} p50: {stats['p50']:.{digits}f} | p95: {stats['p95']:.{digits}f}")

    print(f"\n=== Benchmark: {name} ===")
    print(f"steps={summary['num_steps']} | batch={summary['batch_size']} | timed_runs={summary['timed_runs']}")
    _print_pair("NFE/sample", "nfe_sample", 1)
    print(f"NFE/img (legacy: NFE/sample ÷ batch) p50: {summary['nfe_img_legacy_p50']:.3f}")

    rows = (
        ("total ms/img", "total_ms_img", 3),
        ("model ms/img", "model_ms_img", 3),
        ("overhd ms/img", "overhd_ms_img", 3),
        ("ms/forward TOTAL", "ms_forward_total", 6),
        ("ms/forward MODEL", "ms_forward_model", 6),
        ("throughput img/s", "throughput", 2),
    )
    for label, key, digits in rows:
        _print_pair(label, key, digits)

    if "peak_alloc_mb_p50" in summary:
        alloc_mb = summary['peak_alloc_mb_p50']
        reserved_mb = summary['peak_reserved_mb_p50']
        print(f"peak alloc MB p50: {alloc_mb:.1f} | peak reserved MB p50: {reserved_mb:.1f}")


# ---------------------------
# Run (batch=16 and 128)
# ---------------------------
# Reuse the notebook-level `device` when it exists; otherwise pick CUDA if it
# is available and fall back to the CPU.
device = device if "device" in globals() else ("cuda" if torch.cuda.is_available() else "cpu")
# Prefer an `ema_model` global when one has been defined in this kernel,
# otherwise benchmark `model`.
# NOTE(review): this depends on kernel state; no `ema_model` is defined in the
# visible cells — confirm which model is intended.
model_to_bench = ema_model if "ema_model" in globals() else model

# Number of time-grid points and smallest time passed to the sampler.
NUM_STEPS = 500
T_MIN = 0.02

# Same configuration at two batch sizes; each call does 2 warmup runs followed
# by 10 timed runs, and the summary is printed right after.
b16 = run_sampling_benchmark(model_to_bench, sde, sampler_timed,
                             num_steps=NUM_STEPS, batch_size=16, device=device, t_min=T_MIN,
                             warmup_runs=2, timed_runs=10)
pretty_print_summary(b16, name=f"MNIST diffusion CNN | Euler--Maruyama | steps={NUM_STEPS} | batch=16")

b128 = run_sampling_benchmark(model_to_bench, sde, sampler_timed,
                              num_steps=NUM_STEPS, batch_size=128, device=device, t_min=T_MIN,
                              warmup_runs=2, timed_runs=10)
pretty_print_summary(b128, name=f"MNIST diffusion CNN | Euler--Maruyama | steps={NUM_STEPS} | batch=128")


Loaded checkpoint.
time_dim: 64
lambda_ : 1.0

=== Benchmark: MNIST diffusion CNN | Euler--Maruyama | steps=500 | batch=16 ===
steps=500 | batch=16 | timed_runs=10
NFE/sample p50: 499.0 | p95: 499.0
NFE/img (legacy: NFE/sample ÷ batch) p50: 31.188
total ms/img p50: 35.044 | p95: 36.213
model ms/img p50: 25.669 | p95: 26.217
overhd ms/img p50: 9.353 | p95: 10.007
ms/forward TOTAL p50: 1.123644 | p95: 1.161132
ms/forward MODEL p50: 0.823052 | p95: 0.840640
throughput img/s p50: 28.54 | p95: 29.01
peak alloc MB p50: 34.2 | peak reserved MB p50: 40.0

=== Benchmark: MNIST diffusion CNN | Euler--Maruyama | steps=500 | batch=128 ===
steps=500 | batch=128 | timed_runs=10
NFE/sample p50: 499.0 | p95: 499.0
NFE/img (legacy: NFE/sample ÷ batch) p50: 3.898
total ms/img p50: 18.610 | p95: 18.701
model ms/img p50: 17.318 | p95: 17.378
overhd ms/img p50: 1.289 | p95: 1.335
ms/forward TOTAL p50: 4.773582 | p95: 4.796964
ms/forward MODEL p50: 4.442250 | p95: 4.457754
throughput img/s p50: 53.74 | p95: