In [1]:
pip install matplotlib

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install gdown

Note: you may need to restart the kernel to use updated packages.


In [3]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [4]:
import gdown
import os

# Pretrained EMA weights for the U-Net below, fetched from Google Drive on first run only.
file_id = "18zvmgkm-T_PT8L-bSntddZhCVUZcMlNO"
out_path = "vp_cifar_ema_weights.pth"

if os.path.exists(out_path):
    print("EMA weights already downloaded.")
else:
    downloaded = gdown.download(
        f"https://drive.google.com/uc?id={file_id}",
        out_path,
        quiet=False,
    )
    # gdown may signal failure (e.g. permission/quota problems) by returning None rather than
    # raising; fail here with a clear message instead of later inside torch.load.
    if downloaded is None or not os.path.exists(out_path):
        raise RuntimeError(
            f"Failed to download EMA weights (Drive id {file_id}) to {out_path!r}."
        )


EMA weights already downloaded.


In [5]:
import torch
import math

class VPSDE:
    """
    Variance-preserving SDE with a linear beta schedule.

        beta(t)  = beta_min + t * (beta_max - beta_min),   t in [0, 1]
        f(x, t)  = -1/2 beta(t) x,   g(t) = sqrt(beta(t))
        alpha(t) = exp(-1/2 int_0^t beta(s) ds),   sigma(t) = sqrt(1 - alpha(t)^2)

    Methods take ``t`` as a tensor (typically shape [B]). ``T`` is stored for
    reference only: no method rescales ``t`` by it, so T != 1 is not supported as written.
    """

    def __init__(self, beta_min=0.1, beta_max=20.0, T=1.0):
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.T = T

    def beta(self, t):
        """Noise rate beta(t), linear in t."""
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def int_beta(self, t):
        """Closed-form int_0^t beta(s) ds for the linear schedule."""
        return self.beta_min * t + 0.5 * (self.beta_max - self.beta_min) * t**2

    def alpha(self, t):
        """Signal coefficient alpha(t) = exp(-1/2 int_0^t beta)."""
        return torch.exp(-0.5 * self.int_beta(t))

    def sigma(self, t):
        """
        Noise std sigma(t) = sqrt(1 - alpha(t)^2), floored at 1e-5.

        Because alpha(t)^2 = exp(-int_beta(t)), this is evaluated as
        sqrt(-expm1(-int_beta(t))). The naive ``1 - alpha*alpha`` cancels
        catastrophically for small t (e.g. the sampler's t_min = 1e-4) and loses
        several significant digits in float32; expm1 keeps full precision there.
        """
        return torch.sqrt(-torch.expm1(-self.int_beta(t))).clamp(min=1e-5)

    def diffusion(self, t):
        """Diffusion coefficient g(t) = sqrt(beta(t)), floored at 1e-5."""
        return torch.sqrt(self.beta(t)).clamp(min=1e-5)

    def drift(self, x, t):
        """Drift f(x, t) = -1/2 beta(t) x; per-sample beta of shape [B] is broadcast over a 4-D x."""
        b = self.beta(t).view(-1, 1, 1, 1)
        return -0.5 * b * x


In [6]:
class TimeEmbedding(nn.Module):
    """
    Sinusoidal embedding of one scalar time per sample, followed by a learned linear layer.

    With ``k = dim // 2`` frequencies ``w_j = 10000 ** (-j / k)`` (j = 0..k-1), the
    features are ``[sin(t * w), cos(t * w)]``, zero-padded on the right to ``dim``
    columns when ``dim`` is odd.

    Input ``t``: [B]. Output: [B, dim].
    """

    def __init__(self, dim=32):
        super().__init__()
        self.dim = dim
        self.lin = nn.Linear(dim, dim)

    def forward(self, t):
        n_freq = self.dim // 2
        log_base = torch.log(torch.tensor(10000.0, device=t.device, dtype=t.dtype))
        exponents = torch.arange(n_freq, device=t.device, dtype=t.dtype) * -(log_base / n_freq)
        phases = t[:, None] * torch.exp(exponents)[None, :]
        feats = torch.cat([phases.sin(), phases.cos()], dim=-1)

        n_pad = self.dim - feats.shape[1]
        if n_pad:
            # Odd dim: sin/cos only fill 2 * (dim // 2) = dim - 1 columns.
            feats = F.pad(feats, (0, n_pad))
        return self.lin(feats)


class ConvBlock(nn.Module):
    """
    Two stages of (3x3 conv -> GroupNorm -> SiLU); padding=1 keeps H and W unchanged.

    ``out_ch`` must be divisible by ``groups`` (GroupNorm requirement).
    """

    def __init__(self, in_ch, out_ch, groups=32):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.GroupNorm(groups, out_ch),
                nn.SiLU(inplace=True),
            ]
        # A single flat Sequential keeps the checkpoint keys block.{0,1,3,4}.*
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class UNetScoreCIFAR3Level(nn.Module):
    """
    Three-level U-Net that predicts the noise: ``eps_pred = model(x, t)``.

    The time embedding is broadcast over the image grid and concatenated to ``x`` as
    extra input channels (only at the input; inner blocks get no time signal).
    Widths are C, 2C, 4C per level with an 8C bottleneck (C = base_channels); each
    decoder ConvBlock sees [upsampled features, encoder skip] concatenated on channels.

    Input x: [B, img_channels, H, W] with H and W divisible by 8 (e.g. [B,3,32,32]);
    t: [B]. Output has the same shape as x.

    The submodule attribute names define the checkpoint keys (loaded with
    strict=True), so they must not be renamed.
    """

    def __init__(self, time_dim=32, base_channels=128, img_channels=3, gn_groups=32):
        super().__init__()
        C = base_channels
        self.time_mlp = TimeEmbedding(dim=time_dim)

        # Encoder: ConvBlock, then 2x max-pool (32 -> 16 -> 8 -> 4 for CIFAR-sized input).
        self.down1 = ConvBlock(img_channels + time_dim, C, groups=gn_groups)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = ConvBlock(C, 2 * C, groups=gn_groups)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = ConvBlock(2 * C, 4 * C, groups=gn_groups)
        self.pool3 = nn.MaxPool2d(2)

        self.bottleneck = ConvBlock(4 * C, 8 * C, groups=gn_groups)

        # Decoder: nearest-upsample + conv halves the width, then fuse with the skip.
        self.up3 = self._upsample_conv(8 * C, 4 * C)
        self.dec3 = ConvBlock(8 * C, 4 * C, groups=gn_groups)
        self.up2 = self._upsample_conv(4 * C, 2 * C)
        self.dec2 = ConvBlock(4 * C, 2 * C, groups=gn_groups)
        self.up1 = self._upsample_conv(2 * C, C)
        self.dec1 = ConvBlock(2 * C, C, groups=gn_groups)

        self.out_conv = nn.Conv2d(C, img_channels, 3, padding=1)

    @staticmethod
    def _upsample_conv(in_ch, out_ch):
        # Sequential(Upsample, Conv2d) so checkpoint keys are up*.1.weight / up*.1.bias.
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
        )

    def forward(self, x, t):
        # [B, time_dim] -> [B, time_dim, H, W], appended as extra input channels.
        t_map = self.time_mlp(t)[:, :, None, None].expand(-1, -1, x.size(2), x.size(3))
        h = torch.cat([x, t_map], dim=1)

        skips = []
        for down, pool in ((self.down1, self.pool1), (self.down2, self.pool2), (self.down3, self.pool3)):
            h = down(h)
            skips.append(h)
            h = pool(h)

        h = self.bottleneck(h)

        # Deepest skip is consumed first: up3 with d3, up2 with d2, up1 with d1.
        for up, dec in ((self.up3, self.dec3), (self.up2, self.dec2), (self.up1, self.dec1)):
            h = dec(torch.cat([up(h), skips.pop()], dim=1))

        return self.out_conv(h)


In [7]:
@torch.no_grad()
def sample_prob_flow_ode(model, sde, num_steps=4000, batch_size=16, device="cuda", t_min=1e-4,
                         img_shape=(3, 32, 32), generator=None):
    """
    Draw samples by integrating the VP probability-flow ODE from t=1 down to t_min with Euler steps.

        dx/dt = f(x, t) - 1/2 beta(t) * score(x, t),   f(x, t) = -1/2 beta(t) x,
        score(x, t) = -eps_pred(x, t) / sigma(t)

    Args:
        model: epsilon-prediction network, called as ``model(x, t_batch)`` with t_batch of shape [B].
        sde: schedule object providing ``beta(t)`` and ``sigma(t)`` (e.g. ``VPSDE``).
        num_steps: number of time-grid points in [t_min, 1]. The integrator takes
            ``num_steps - 1`` Euler steps, i.e. ``num_steps - 1`` model evaluations.
        batch_size: number of samples drawn in parallel.
        device: device for the noise and the time grid.
        t_min: final integration time (integration stops short of t = 0).
        img_shape: per-sample shape of x; the default matches 3x32x32 images.
        generator: optional ``torch.Generator`` (on ``device``) for reproducible initial noise.

    Returns:
        Tensor of shape ``[batch_size, *img_shape]`` at time ``t_min`` (no extra denoising step).

    Raises:
        ValueError: if ``num_steps < 2``, since no step would be taken and pure noise returned.
    """
    if num_steps < 2:
        raise ValueError(
            f"num_steps must be >= 2 (got {num_steps}); fewer grid points take no ODE step."
        )

    model.eval()
    t_grid = torch.linspace(1.0, t_min, num_steps, device=device)
    x = torch.randn(batch_size, *img_shape, device=device, generator=generator)
    # Per-sample coefficient shape that broadcasts against x, e.g. [B, 1, 1, 1].
    coef_shape = (batch_size,) + (1,) * len(img_shape)

    for i in range(num_steps - 1):
        t_cur = t_grid[i]
        dt = t_grid[i + 1] - t_cur  # negative: integrating backwards in time

        # repeat() keeps the time value on-device; going through .item() would force a
        # device->host synchronization on every step.
        t_batch = t_cur.repeat(batch_size)

        beta = sde.beta(t_batch).view(coef_shape)
        sigma = sde.sigma(t_batch).view(coef_shape).clamp_min(1e-12)

        eps_pred = model(x, t_batch)
        score = -eps_pred / sigma

        f = -0.5 * beta * x                 # VP forward drift
        drift = f - 0.5 * beta * score      # probability-flow ODE with g(t)^2 = beta(t)

        x = x + drift * dt

    return x


In [9]:
import time
import numpy as np
import torch
import torch.nn as nn


ema_weights_path = "vp_cifar_ema_weights.pth"  # EDIT ME if needed

# Architecture hyperparameters must match the checkpoint (it is loaded with strict=True).
time_dim = 32
base_channels = 128
img_channels = 3

# Noise schedule; presumably the values used at training time — TODO confirm.
beta_min = 0.1
beta_max = 20.0
T = 1.0

sde = VPSDE(beta_min=beta_min, beta_max=beta_max, T=T)

ema_model = UNetScoreCIFAR3Level(
    time_dim=time_dim,
    base_channels=base_channels,
    img_channels=img_channels,
).to(device)

# The checkpoint is a plain state_dict downloaded from an external URL. weights_only=True
# restricts unpickling to tensors/containers instead of executing arbitrary pickled objects.
ema_sd = torch.load(ema_weights_path, map_location=device, weights_only=True)
ema_model.load_state_dict(ema_sd, strict=True)
ema_model.eval()

# ---------------------------
# (B) Benchmark harness (generic, sampler chosen by name)
# ---------------------------
def _sync_if_cuda(device):
    if isinstance(device, str):
        is_cuda = "cuda" in device
    else:
        is_cuda = (device.type == "cuda")
    if is_cuda and torch.cuda.is_available():
        torch.cuda.synchronize()

def _percentiles(xs, ps=(50, 95)):
    xs = np.asarray(xs, dtype=np.float64)
    return {f"p{p}": float(np.percentile(xs, p)) for p in ps}

class TimedModel(nn.Module):
    """
    Wrapper that counts forward calls (NFEs) and accumulates the time spent in them, in ms.

    When ``device`` names CUDA and CUDA is available, each call is bracketed by two
    CUDA timing events followed by ``torch.cuda.synchronize()`` so the elapsed time
    can be read immediately; otherwise ``time.perf_counter`` is used. The wrapped
    model is a registered submodule, so ``.to()`` / ``.eval()`` reach it.
    """

    def __init__(self, model: nn.Module, device="cuda"):
        super().__init__()
        self.model = model
        self.device = device
        self.reset_stats()

    def reset_stats(self):
        """Zero the NFE counter and the accumulated model time."""
        self.nfe = 0
        self.model_ms = 0.0

    def _uses_cuda_events(self):
        # CUDA timing only when a GPU exists and this wrapper targets one.
        return torch.cuda.is_available() and ("cuda" in str(self.device))

    @torch.no_grad()
    def forward(self, *args, **kwargs):
        self.nfe += 1

        if not self._uses_cuda_events():
            started = time.perf_counter()
            result = self.model(*args, **kwargs)
            self.model_ms += (time.perf_counter() - started) * 1000.0
            return result

        start_evt = torch.cuda.Event(enable_timing=True)
        stop_evt = torch.cuda.Event(enable_timing=True)
        start_evt.record()
        result = self.model(*args, **kwargs)
        stop_evt.record()
        torch.cuda.synchronize()  # both events must have completed before elapsed_time()
        self.model_ms += start_evt.elapsed_time(stop_evt)
        return result

# ---------------------------
# EDIT ME: choose the sampler function used by the benchmark
# ---------------------------
# Any sampler defined earlier in this notebook that accepts the keyword arguments
# model, sde, num_steps, batch_size, device, t_min works — e.g. "sample_prob_flow_ode"
# (or "sample_prob_flow_ode_timed" / "sample_prob_flow_heun" if those are defined).
SAMPLE_FN_NAME = "sample_prob_flow_ode"  # EDIT ME if needed

# Explicit check rather than `assert` (which is skipped under `python -O`); it also rejects
# a non-callable global that merely shares the name.
_sampler = globals().get(SAMPLE_FN_NAME)
if not callable(_sampler):
    raise NameError(f"Sampler {SAMPLE_FN_NAME} not found in globals() or is not callable.")

@torch.no_grad()
def sampler_timed(timed_model, sde, *, num_steps, batch_size, device, t_min, **kwargs):
    """
    Adapter with the ``sampler_fn(model, sde, *, num_steps, ...)`` shape expected by
    ``run_sampling_benchmark``; forwards every argument, unmodified and by keyword,
    to the module-level ``_sampler``. Extra ``kwargs`` are passed through as well.
    """
    fixed = dict(
        model=timed_model,
        sde=sde,
        num_steps=num_steps,
        batch_size=batch_size,
        device=device,
        t_min=t_min,
    )
    return _sampler(**fixed, **kwargs)

def run_sampling_benchmark(
    model: nn.Module,
    sde,
    sampler_fn,
    *,
    num_steps: int,
    batch_size: int,
    device="cuda",
    t_min=1e-4,
    warmup_runs=2,
    timed_runs=10,
    reset_cuda_peak_mem=True,
    sampler_kwargs=None,
):
    """
    Time ``sampler_fn`` end to end and summarize latency, NFE and GPU-memory statistics.

    ``model`` is wrapped in ``TimedModel`` so every forward call is counted and timed, then
    ``sampler_fn(timed_model, sde, num_steps=..., batch_size=..., device=..., t_min=...,
    **sampler_kwargs)`` runs ``warmup_runs`` times untimed and ``timed_runs`` times timed.
    Each timed run records wall time, time inside the model, the difference ("overhead"),
    NFEs, ms per forward and images/s; on CUDA also the peak allocated / reserved memory
    of ``device``.

    Args:
        model: network to benchmark.
        sde: schedule object forwarded to ``sampler_fn``.
        sampler_fn: sampler callable with the signature described above.
        num_steps, batch_size, device, t_min: forwarded to ``sampler_fn``.
        warmup_runs: untimed runs executed first.
        timed_runs: runs that contribute to the statistics (>= 1).
        reset_cuda_peak_mem: reset CUDA peak-memory counters before each timed run.
        sampler_kwargs: extra keyword arguments for ``sampler_fn``.

    Returns:
        dict with the run configuration, p50/p95 summaries of each metric (see
        ``_percentiles``), and ``peak_alloc_mb_p50`` / ``peak_reserved_mb_p50`` on CUDA.

    Raises:
        ValueError: if ``timed_runs < 1`` (nothing to summarize) or ``batch_size < 1``.
    """
    if timed_runs < 1:
        raise ValueError(f"timed_runs must be >= 1, got {timed_runs}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if sampler_kwargs is None:
        sampler_kwargs = {}

    on_cuda = torch.cuda.is_available() and ("cuda" in str(device))
    timed_model = TimedModel(model, device=device).to(device)

    def _sample_once():
        # One full sampling call bracketed by device syncs; returns wall time in ms.
        _sync_if_cuda(device)
        t0 = time.perf_counter()
        sampler_fn(
            timed_model, sde,
            num_steps=num_steps, batch_size=batch_size,
            device=device, t_min=t_min,
            **sampler_kwargs
        )
        _sync_if_cuda(device)
        return (time.perf_counter() - t0) * 1000.0

    # Warmup so one-time start-up costs do not land in the timed runs.
    for _ in range(warmup_runs):
        timed_model.reset_stats()
        _sample_once()

    per_img_total_ms = []
    per_img_model_ms = []
    per_img_overhead_ms = []
    nfes = []
    ms_per_forward_total = []
    ms_per_forward_model = []
    throughput_img_s = []
    peak_alloc_mb = []
    peak_reserved_mb = []

    for _ in range(timed_runs):
        timed_model.reset_stats()

        if reset_cuda_peak_mem and on_cuda:
            # Pass `device` so the counters refer to the benchmarked GPU, not just the current one.
            torch.cuda.reset_peak_memory_stats(device)

        total_ms_batch = _sample_once()
        model_ms_batch = float(timed_model.model_ms)
        overhead_ms_batch = total_ms_batch - model_ms_batch

        nfe = int(timed_model.nfe)
        nfes.append(nfe)

        total_ms_img = total_ms_batch / batch_size
        per_img_total_ms.append(total_ms_img)
        per_img_model_ms.append(model_ms_batch / batch_size)
        per_img_overhead_ms.append(overhead_ms_batch / batch_size)

        denom = max(nfe, 1)  # guard against a sampler that never calls the model
        ms_per_forward_total.append(total_ms_batch / denom)
        ms_per_forward_model.append(model_ms_batch / denom)

        throughput_img_s.append(1000.0 / max(total_ms_img, 1e-12))

        if on_cuda:
            peak_alloc_mb.append(torch.cuda.max_memory_allocated(device) / (1024**2))
            peak_reserved_mb.append(torch.cuda.max_memory_reserved(device) / (1024**2))

    summary = {
        "num_steps": num_steps,
        "batch_size": batch_size,
        "t_min": t_min,
        "timed_runs": timed_runs,
        # Each forward processes the whole batch, so NFEs per batch == NFEs per sample.
        "nfe_sample": _percentiles(nfes, ps=(50, 95)),
        "nfe_img_legacy_p50": _percentiles([n / batch_size for n in nfes], ps=(50,))["p50"],
        "total_ms_img": _percentiles(per_img_total_ms, ps=(50, 95)),
        "model_ms_img": _percentiles(per_img_model_ms, ps=(50, 95)),
        "overhd_ms_img": _percentiles(per_img_overhead_ms, ps=(50, 95)),
        "ms_forward_total": _percentiles(ms_per_forward_total, ps=(50, 95)),
        "ms_forward_model": _percentiles(ms_per_forward_model, ps=(50, 95)),
        "throughput": _percentiles(throughput_img_s, ps=(50, 95)),
    }

    if peak_alloc_mb:
        summary["peak_alloc_mb_p50"] = _percentiles(peak_alloc_mb, ps=(50,))["p50"]
        summary["peak_reserved_mb_p50"] = _percentiles(peak_reserved_mb, ps=(50,))["p50"]

    return summary

def pretty_print_summary(summary: dict, name=""):
    """
    Print a ``run_sampling_benchmark`` summary, one line per metric.

    The peak-memory line appears only when the summary contains ``peak_alloc_mb_p50``.
    """
    def p50_p95(key, spec):
        stats = summary[key]
        return f"p50: {stats['p50']:{spec}} | p95: {stats['p95']:{spec}}"

    print(f"\n=== Benchmark: {name} ===")
    print(f"steps={summary['num_steps']} | batch={summary['batch_size']} | timed_runs={summary['timed_runs']}")
    print(f"NFE/sample {p50_p95('nfe_sample', '.1f')}")
    print(f"NFE/img (legacy: NFE/sample ÷ batch) p50: {summary['nfe_img_legacy_p50']:.3f}")
    for label, key, spec in (
        ("total ms/img", "total_ms_img", ".3f"),
        ("model ms/img", "model_ms_img", ".3f"),
        ("overhd ms/img", "overhd_ms_img", ".3f"),
        ("ms/forward TOTAL", "ms_forward_total", ".6f"),
        ("ms/forward MODEL", "ms_forward_model", ".6f"),
        ("throughput img/s", "throughput", ".2f"),
    ):
        print(f"{label} {p50_p95(key, spec)}")
    if "peak_alloc_mb_p50" in summary:
        print(f"peak alloc MB p50: {summary['peak_alloc_mb_p50']:.1f} | peak reserved MB p50: {summary['peak_reserved_mb_p50']:.1f}")

# ---------------------------
# (C) Run benchmarks (batch=16 and 128)
# ---------------------------
# Explicit settings instead of silently picking up `num_steps` / `t_min` / `time_power`
# from whatever globals the kernel happens to hold, so a fresh "Restart & Run All"
# reproduces these numbers.
NUM_STEPS = 500
T_MIN = 1e-4
TIME_POWER = None  # set to a float to forward time_power=... to samplers that accept it
BENCH_BATCH_SIZES = (16, 128)

sampler_kwargs = {}
if TIME_POWER is not None:
    sampler_kwargs["time_power"] = TIME_POWER

bench_results = {}
for bs in BENCH_BATCH_SIZES:
    bench_results[bs] = run_sampling_benchmark(
        ema_model, sde, sampler_timed,
        num_steps=NUM_STEPS, batch_size=bs, device=device, t_min=T_MIN,
        warmup_runs=2, timed_runs=10, sampler_kwargs=sampler_kwargs
    )
    pretty_print_summary(
        bench_results[bs],
        name=f"VP (no attn) | {SAMPLE_FN_NAME} | steps={NUM_STEPS} | batch={bs}",
    )

# Preserve the original per-batch result names.
b16 = bench_results.get(16)
b128 = bench_results.get(128)



=== Benchmark: VP (no attn) | sample_prob_flow_ode | steps=500 | batch=16 ===
steps=500 | batch=16 | timed_runs=10
NFE/sample p50: 499.0 | p95: 499.0
NFE/img (legacy: NFE/sample ÷ batch) p50: 31.188
total ms/img p50: 629.928 | p95: 630.806
model ms/img p50: 606.357 | p95: 607.259
overhd ms/img p50: 24.005 | p95: 26.586
ms/forward TOTAL p50: 20.198080 | p95: 20.226245
ms/forward MODEL p50: 19.442322 | p95: 19.471245
throughput img/s p50: 1.59 | p95: 1.59
peak alloc MB p50: 587.4 | peak reserved MB p50: 1562.0

=== Benchmark: VP (no attn) | sample_prob_flow_ode | steps=500 | batch=128 ===
steps=500 | batch=128 | timed_runs=10
NFE/sample p50: 499.0 | p95: 499.0
NFE/img (legacy: NFE/sample ÷ batch) p50: 3.898
total ms/img p50: 446.149 | p95: 447.447
model ms/img p50: 442.976 | p95: 443.982
overhd ms/img p50: 3.189 | p95: 3.848
ms/forward TOTAL p50: 114.443096 | p95: 114.776072
ms/forward MODEL p50: 113.628994 | p95: 113.887269
throughput img/s p50: 2.24 | p95: 2.24
peak alloc MB p50: 1268