In [1]:
pip install matplotlib

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install gdown

Note: you may need to restart the kernel to use updated packages.


In [3]:
import gdown
import os
# Google Drive id of the EMA weights file and the local path it is stored under.
file_id = "1UWUP4u-zcY1AOOD5GoF0xKZiQ-kn2X7Z"
out_path = "cifar_ou_no_attention_ema_weights.pth"

# Skip the download when a file already exists at out_path.
if os.path.exists(out_path):
    print("EMA weights already downloaded.")
else:
    gdown.download(
        f"https://drive.google.com/uc?id={file_id}",
        out_path,
        quiet=False,
    )


Downloading...
From (original): https://drive.google.com/uc?id=1UWUP4u-zcY1AOOD5GoF0xKZiQ-kn2X7Z
From (redirected): https://drive.google.com/uc?id=1UWUP4u-zcY1AOOD5GoF0xKZiQ-kn2X7Z&confirm=t&uuid=fec3ac64-dd02-485a-af17-9dc83c827e59
To: /home/onyxia/work/cifar_ou_no_attention_ema_weights.pth
100%|██████████| 123M/123M [00:01<00:00, 106MB/s]  


In [4]:
import math
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


In [5]:
class Forward:
    """Forward Ornstein-Uhlenbeck (OU) noising process with rate ``lambda_``.

    The time argument ``t`` of ``mean`` and ``std`` is indexed as a 1-D tensor
    and expanded to four dimensions so it broadcasts against a 4-D batch.
    """

    def __init__(self, lambda_=0.3):
        # Stored as a plain Python float.
        self.lmbd = float(lambda_)

    def mean(self, x0, t):
        """Return ``x0 * exp(-lmbd * t)`` with ``t`` broadcast per batch element."""
        decay = torch.exp(-self.lmbd * t[:, None, None, None])
        return x0 * decay

    def std(self, t):
        """Return ``sqrt(1 - exp(-2 * lmbd * t))``, floored at 1e-3, shaped [B,1,1,1]."""
        variance = 1.0 - torch.exp(-2 * self.lmbd * t[:, None, None, None])
        return torch.sqrt(variance).clamp(min=1e-3)

    def diffusion_coeff(self, t):
        """Return the constant ``sqrt(2 * lmbd)`` as a tensor shaped like ``t``."""
        coeff = (2 * self.lmbd) ** 0.5
        return coeff * torch.ones_like(t)


In [None]:
class TimeEmbedding(nn.Module):
    """Sinusoidal time embedding followed by a learned linear projection.

    Each scalar time ``t`` is mapped to ``[sin(t * f_k), cos(t * f_k)]`` for
    ``dim // 2`` frequencies ``f_k = exp(-k * log(10000) / (dim // 2))`` and the
    result is passed through ``nn.Linear(dim, dim)``.

    Args:
        dim: Size of the embedding. For an odd ``dim`` the sinusoidal features
            only fill ``dim - 1`` columns; the remaining column is zero-padded
            so the linear layer always receives ``dim`` features.
    """

    def __init__(self, dim=32):
        super().__init__()
        self.dim = dim
        self.lin = nn.Linear(dim, dim)

    def forward(self, t):
        """Embed a 1-D tensor of times ``t`` ([B]) into a [B, dim] tensor."""
        half_dim = self.dim // 2
        # Geometric frequency ladder, built on t's device and dtype.
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=t.dtype)
            * -(torch.log(torch.tensor(10000.0, device=t.device, dtype=t.dtype)) / half_dim)
        )
        args = t[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

        # sin/cos yield 2 * (dim // 2) features, one short of dim when dim is odd.
        missing = self.dim - emb.size(-1)
        if missing > 0:
            emb = nn.functional.pad(emb, (0, missing))

        return self.lin(emb)  # [B, dim]


class ConvBlock(nn.Module):
    """Two 3x3 convolutions with padding 1, each followed by an in-place ReLU.

    The first convolution maps ``in_ch`` to ``out_ch`` channels, the second
    keeps ``out_ch``; the spatial size is unchanged.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the conv/ReLU stack to ``x``."""
        out = self.block(x)
        return out


class UNetScoreCIFAR3Level(nn.Module):
    """
    Three-level U-Net used as the score network of the OU reverse SDE:
      score = model(x, t)  with shape [B,3,32,32]

    The time embedding is broadcast over the spatial grid and concatenated to
    the image channels before the first encoder block. Each encoder level is
    followed by a 2x max-pool, and each decoder level upsamples 2x with a
    transposed convolution and concatenates the matching encoder output.
    """
    def __init__(self, time_dim=32, base_channels=128, img_channels=3):
        super().__init__()
        width = base_channels
        self.time_mlp = TimeEmbedding(dim=time_dim)

        # Encoder: channels double at every level, resolution halves after each pool.
        self.down1 = ConvBlock(img_channels + time_dim, width)   # 32x32
        self.pool1 = nn.MaxPool2d(2)                              # 32->16

        self.down2 = ConvBlock(width, 2 * width)                  # 16x16
        self.pool2 = nn.MaxPool2d(2)                              # 16->8

        self.down3 = ConvBlock(2 * width, 4 * width)              # 8x8
        self.pool3 = nn.MaxPool2d(2)                              # 8->4

        self.bottleneck = ConvBlock(4 * width, 8 * width)         # 4x4

        # Decoder: each ConvBlock takes the upsampled features plus the skip
        # connection, hence twice the channels produced by the upsampling layer.
        self.up3 = nn.ConvTranspose2d(8 * width, 4 * width, kernel_size=2, stride=2)  # 4->8
        self.dec3 = ConvBlock(8 * width, 4 * width)

        self.up2 = nn.ConvTranspose2d(4 * width, 2 * width, kernel_size=2, stride=2)  # 8->16
        self.dec2 = ConvBlock(4 * width, 2 * width)

        self.up1 = nn.ConvTranspose2d(2 * width, width, kernel_size=2, stride=2)      # 16->32
        self.dec1 = ConvBlock(2 * width, width)

        self.out_conv = nn.Conv2d(width, img_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Return the network output for image batch ``x`` at times ``t``."""
        rows, cols = x.size(2), x.size(3)
        # [B, time_dim] -> [B, time_dim, H, W] so it can be stacked on the image channels.
        t_map = self.time_mlp(t)[:, :, None, None].expand(-1, -1, rows, cols)
        h = torch.cat([x, t_map], dim=1)

        # Encoder path; the pre-pool activations are kept as skip connections.
        skip1 = self.down1(h)
        skip2 = self.down2(self.pool1(skip1))
        skip3 = self.down3(self.pool2(skip2))

        h = self.bottleneck(self.pool3(skip3))

        # Decoder path: upsample, concatenate the skip, refine.
        h = self.dec3(torch.cat([self.up3(h), skip3], dim=1))
        h = self.dec2(torch.cat([self.up2(h), skip2], dim=1))
        h = self.dec1(torch.cat([self.up1(h), skip1], dim=1))

        return self.out_conv(h)


In [8]:
@torch.no_grad()
def sample_reverse_euler_maruyama_cifar(model, sde, num_steps=1000, batch_size=16, device="cuda", t_min=0.01):
    """Sample images by Euler-Maruyama integration of the reverse-time SDE.

    Starts from standard Gaussian noise of shape [batch_size, 3, 32, 32] at
    t = 1.0 and walks a uniform grid of ``num_steps`` times down to ``t_min``,
    taking ``num_steps - 1`` update steps.

    Args:
        model: Module called as ``model(x, t)`` whose output is used as the
            score; it is switched to eval mode.
        sde: Object providing ``lmbd`` and ``diffusion_coeff(t)``.
        num_steps: Number of points on the time grid.
        batch_size: Number of samples drawn.
        device: Device on which the grid, noise and state are created.
        t_min: Final time of the integration.

    Returns:
        Tensor of shape [batch_size, 3, 32, 32] holding the final state.
    """
    model.eval()
    t_start = 1.0
    times = torch.linspace(t_start, t_min, num_steps, device=device)
    column = (batch_size, 1, 1, 1)

    x = torch.randn(batch_size, 3, 32, 32, device=device)

    for t_cur, t_next in zip(times[:-1], times[1:]):
        step = t_next - t_cur  # negative: time decreases along the grid

        t_batch = torch.full((batch_size,), t_cur, device=device)

        # Diffusion coefficient reshaped to broadcast over [B,3,32,32].
        sigma = sde.diffusion_coeff(t_batch).view(column)

        score = model(x, t_batch)                # [B,3,32,32]
        reverse_drift = -sde.lmbd * x - (sigma ** 2) * score

        eps = torch.randn_like(x)
        x = x + reverse_drift * step + sigma * torch.sqrt(-step) * eps

    return x


In [None]:
def show_cifar_grid(x, nrow=4, title="Samples"):
    """Display a batch of images as a matplotlib grid with ``nrow`` rows.

    Args:
        x: Image batch indexed as [B, C, H, W]; values are clamped to [-1, 1]
            and rescaled to [0, 1] before plotting.
        nrow: Number of grid rows; the column count is ``ceil(B / nrow)``.
        title: Figure title.
    """
    images = x.detach().cpu()
    images = (images.clamp(-1, 1) + 1) / 2.0  # to [0,1]

    count = images.size(0)
    ncol = math.ceil(count / nrow)

    plt.figure(figsize=(ncol * 2, nrow * 2))
    for slot, chw in enumerate(images, start=1):
        plt.subplot(nrow, ncol, slot)
        # imshow expects channels last.
        plt.imshow(chw.permute(1, 2, 0).numpy())
        plt.axis("off")
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()


In [None]:
# --- Sampling configuration -------------------------------------------------
time_dim = 32
base_channels = 128
# OU rate handed to the sampler's SDE. This is the single source of truth for
# the rate: the SDE below is built from it rather than from a separate literal.
# NOTE(review): must equal the rate the checkpoint was trained with — confirm.
lambda_ = 2.0
t_min = 0.01

ema_weights_path = "cifar_ou_no_attention_ema_weights.pth"

sde = Forward(lambda_=lambda_)

ema_model = UNetScoreCIFAR3Level(
    time_dim=time_dim,
    base_channels=base_channels,
    img_channels=3,
).to(device)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary code
# for an untrusted checkpoint; consider weights_only=True where the installed
# torch version supports it.
ema_sd = torch.load(ema_weights_path, map_location=device)
# strict=True: fail loudly if the checkpoint keys do not match the architecture.
ema_model.load_state_dict(ema_sd, strict=True)
ema_model.eval()

samples = sample_reverse_euler_maruyama_cifar(
    ema_model, sde, num_steps=2000, batch_size=16, device=device, t_min=t_min
)
show_cifar_grid(samples, nrow=4, title="CIFAR-10 samples (EMA, no-attention OU)")


NameError: name 'sample_reverse_euler_maruyama_cifar' is not defined

In [None]:
def run_sampling_benchmark(
    model: nn.Module,
    sde,
    sampler_fn,
    *,
    num_steps: int,
    batch_size: int,
    device="cuda",
    t_min=0.01,
    warmup_runs=2,
    timed_runs=10,
    reset_cuda_peak_mem=True,
):
    """Time repeated calls of a sampler and summarise latency, NFE and memory.

    The model is wrapped in ``TimedModel`` (defined elsewhere in this notebook;
    expected to expose ``reset_stats()``, ``model_ms`` and ``nfe``). The sampler
    is first run ``warmup_runs`` times without recording anything, then
    ``timed_runs`` times with wall-clock timing around each call.

    Args:
        model: Network handed to ``TimedModel``.
        sde: Passed through to ``sampler_fn``.
        sampler_fn: Called as ``sampler_fn(timed_model, sde, num_steps=...,
            batch_size=..., device=..., t_min=...)``.
        num_steps: Forwarded to the sampler.
        batch_size: Forwarded to the sampler; also the divisor for per-image metrics.
        device: Device for the wrapped model and the sampler. CUDA memory
            statistics are collected for this device when its string form
            contains "cuda" and CUDA is available.
        t_min: Forwarded to the sampler.
        warmup_runs: Number of untimed runs executed first.
        timed_runs: Number of measured runs.
        reset_cuda_peak_mem: Reset the CUDA peak-memory counters before each
            measured run.

    Returns:
        dict with the run configuration, the raw per-run lists and their
        p50/p95 percentiles (as computed by ``_percentiles``):
        per-image total/model/overhead milliseconds, ``nfe_per_sample`` (the
        wrapper's ``nfe`` counter per run, not divided by the batch size),
        ``nfe_per_img_legacy`` (that counter divided by ``batch_size``),
        milliseconds per forward call, images per second, derived
        ``ms_per_forward_equiv_*`` ratios, and, on CUDA, the median peak
        allocated/reserved memory in MB.
    """
    timed_model = TimedModel(model, device=device).to(device)

    # Decide once whether CUDA memory statistics apply, and pin them to the
    # requested device instead of whichever CUDA device happens to be current.
    track_cuda_mem = torch.cuda.is_available() and ("cuda" in str(device))
    cuda_device = torch.device(device) if track_cuda_mem else None

    # Warm-up: results and timings are discarded.
    for _ in range(warmup_runs):
        timed_model.reset_stats()
        _sync_if_cuda(device)
        _ = sampler_fn(timed_model, sde, num_steps=num_steps, batch_size=batch_size, device=device, t_min=t_min)
        _sync_if_cuda(device)

    per_img_total_ms = []
    per_img_model_ms = []
    per_img_overhead_ms = []

    nfes_per_run = []

    ms_per_forward_total = []
    ms_per_forward_model = []

    throughput_img_per_s = []

    peak_alloc_mb = []
    peak_reserved_mb = []

    for _ in range(timed_runs):
        timed_model.reset_stats()

        if reset_cuda_peak_mem and track_cuda_mem:
            torch.cuda.reset_peak_memory_stats(cuda_device)

        # Synchronise on both sides of the call so the wall-clock interval
        # covers the sampler's work and nothing queued before it.
        _sync_if_cuda(device)
        t0 = time.perf_counter()
        _ = sampler_fn(timed_model, sde, num_steps=num_steps, batch_size=batch_size, device=device, t_min=t_min)
        _sync_if_cuda(device)
        t1 = time.perf_counter()

        total_ms_batch = (t1 - t0) * 1000.0
        model_ms_batch = float(timed_model.model_ms)
        # Everything in the wall-clock time that the wrapper did not attribute to the model.
        overhead_ms_batch = total_ms_batch - model_ms_batch

        nfe_batch = int(timed_model.nfe)
        nfes_per_run.append(nfe_batch)

        total_ms_img = total_ms_batch / batch_size
        model_ms_img = model_ms_batch / batch_size
        overhead_ms_img = overhead_ms_batch / batch_size

        per_img_total_ms.append(total_ms_img)
        per_img_model_ms.append(model_ms_img)
        per_img_overhead_ms.append(overhead_ms_img)

        # Guard against a run that recorded no forward calls.
        denom = max(nfe_batch, 1)
        ms_per_forward_total.append(total_ms_batch / denom)
        ms_per_forward_model.append(model_ms_batch / denom)

        throughput_img_per_s.append(1000.0 / max(total_ms_img, 1e-12))

        if track_cuda_mem:
            peak_alloc_mb.append(torch.cuda.max_memory_allocated(cuda_device) / (1024**2))
            peak_reserved_mb.append(torch.cuda.max_memory_reserved(cuda_device) / (1024**2))

    # The per-run counter is reported as-is; the "legacy" variant divides it by the batch size.
    nfe_per_sample = nfes_per_run
    nfe_per_img_legacy = [n / batch_size for n in nfes_per_run]

    total_ms_img_pct = _percentiles(per_img_total_ms, ps=(50, 95))
    model_ms_img_pct = _percentiles(per_img_model_ms, ps=(50, 95))
    nfe_sample_pct = _percentiles(nfe_per_sample, ps=(50, 95))

    summary = {
        "num_steps": num_steps,
        "batch_size": batch_size,
        "t_min": t_min,
        "timed_runs": timed_runs,

        "per_img_total_ms": per_img_total_ms,
        "per_img_model_ms": per_img_model_ms,
        "per_img_overhead_ms": per_img_overhead_ms,
        "nfe_per_sample": nfe_per_sample,
        "nfe_per_img_legacy": nfe_per_img_legacy,
        "ms_per_forward_total": ms_per_forward_total,
        "ms_per_forward_model": ms_per_forward_model,
        "throughput_img_per_s": throughput_img_per_s,

        "total_ms_percentiles": total_ms_img_pct,
        "model_ms_percentiles": model_ms_img_pct,
        "overhead_ms_percentiles": _percentiles(per_img_overhead_ms, ps=(50, 95)),
        "nfe_per_sample_percentiles": nfe_sample_pct,
        "nfe_per_img_legacy_percentiles": _percentiles(nfe_per_img_legacy, ps=(50, 95)),
        "ms_per_forward_total_percentiles": _percentiles(ms_per_forward_total, ps=(50, 95)),
        "ms_per_forward_model_percentiles": _percentiles(ms_per_forward_model, ps=(50, 95)),
        "throughput_img_per_s_percentiles": _percentiles(throughput_img_per_s, ps=(50, 95)),

        "throughput_img_per_s_p50": _percentiles(throughput_img_per_s, ps=(50,))["p50"],

        # Per-image time divided by the forward-call count at the same percentile.
        "ms_per_forward_equiv_from_total_p50": total_ms_img_pct["p50"] / max(nfe_sample_pct["p50"], 1e-9),
        "ms_per_forward_equiv_from_model_p50": model_ms_img_pct["p50"] / max(nfe_sample_pct["p50"], 1e-9),

        "ms_per_forward_equiv_from_total_p95": total_ms_img_pct["p95"] / max(nfe_sample_pct["p95"], 1e-9),
        "ms_per_forward_equiv_from_model_p95": model_ms_img_pct["p95"] / max(nfe_sample_pct["p95"], 1e-9),
    }

    # Memory figures exist only when the benchmark ran on CUDA.
    if peak_alloc_mb:
        summary["peak_alloc_mb_p50"] = _percentiles(peak_alloc_mb, ps=(50,))["p50"]
        summary["peak_reserved_mb_p50"] = _percentiles(peak_reserved_mb, ps=(50,))["p50"]

    return summary


def pretty_print_summary(summary: dict, name=""):
    """Print a benchmark summary as a human-readable report.

    Args:
        summary: Mapping with the configuration keys, ``*_percentiles`` entries
            (each holding ``"p50"`` and ``"p95"``) and derived
            ``ms_per_forward_equiv_*`` values read below. The peak-memory line
            is printed only when ``"peak_alloc_mb_p50"`` is present.
        name: Label shown in the report header.
    """

    def _report_lines():
        # Lines are produced lazily, one per printed line, in report order.
        yield f"\n=== Benchmark: {name} ==="
        yield f"steps={summary['num_steps']} | batch={summary['batch_size']} | timed_runs={summary['timed_runs']}"

        yield f"NFE/sample p50: {summary['nfe_per_sample_percentiles']['p50']:.1f} | p95: {summary['nfe_per_sample_percentiles']['p95']:.1f}"
        yield f"NFE/img (legacy: NFE/sample ÷ batch) p50: {summary['nfe_per_img_legacy_percentiles']['p50']:.3f}"

        yield f"total ms/img p50: {summary['total_ms_percentiles']['p50']:.3f} | p95: {summary['total_ms_percentiles']['p95']:.3f}"
        yield f"model ms/img p50: {summary['model_ms_percentiles']['p50']:.3f} | p95: {summary['model_ms_percentiles']['p95']:.3f}"
        yield f"overhd ms/img p50: {summary['overhead_ms_percentiles']['p50']:.3f} | p95: {summary['overhead_ms_percentiles']['p95']:.3f}"

        yield (
            f"ms/forward TOTAL p50: {summary['ms_per_forward_total_percentiles']['p50']:.6f} | "
            f"p95: {summary['ms_per_forward_total_percentiles']['p95']:.6f}"
        )
        yield (
            f"ms/forward MODEL p50: {summary['ms_per_forward_model_percentiles']['p50']:.6f} | "
            f"p95: {summary['ms_per_forward_model_percentiles']['p95']:.6f}"
        )

        yield (
            f"throughput img/s p50: {summary['throughput_img_per_s_percentiles']['p50']:.2f} | "
            f"p95: {summary['throughput_img_per_s_percentiles']['p95']:.2f}"
        )

        yield f"(derived) ms/forward equiv from TOTAL p50: {summary['ms_per_forward_equiv_from_total_p50']:.6f}"
        yield f"(derived) ms/forward equiv from MODEL p50: {summary['ms_per_forward_equiv_from_model_p50']:.6f}"

        yield f"(derived) ms/forward equiv from TOTAL p95: {summary['ms_per_forward_equiv_from_total_p95']:.6f}"
        yield f"(derived) ms/forward equiv from MODEL p95: {summary['ms_per_forward_equiv_from_model_p95']:.6f}"

        if "peak_alloc_mb_p50" in summary:
            yield f"peak alloc MB p50: {summary['peak_alloc_mb_p50']:.1f} | peak reserved MB p50: {summary['peak_reserved_mb_p50']:.1f}"

    for line in _report_lines():
        print(line)


def _benchmark_euler_maruyama(batch_size, num_steps=2000):
    """Benchmark the timed Euler–Maruyama sampler and print its summary.

    The printed label is built from the actual ``num_steps`` and ``batch_size``
    so it cannot drift from the configuration that was really measured.
    Returns the summary dict produced by ``run_sampling_benchmark``.
    """
    summary = run_sampling_benchmark(
        ema_model, sde, sample_reverse_euler_maruyama_cifar_timed,
        num_steps=num_steps, batch_size=batch_size, device=device, t_min=t_min,
        warmup_runs=2, timed_runs=10
    )
    pretty_print_summary(
        summary,
        name=f"OU no-attention | Euler–Maruyama | steps={num_steps} | batch={batch_size}",
    )
    return summary


bench1 = _benchmark_euler_maruyama(batch_size=16)
bench2 = _benchmark_euler_maruyama(batch_size=128)



=== Benchmark: OU no-attention | Euler–Maruyama | steps=2000 | batch=16 ===
steps=2000 | batch=16 | timed_runs=10
NFE/sample p50: 1999.0 | p95: 1999.0
NFE/img (legacy: NFE/sample ÷ batch) p50: 124.938
total ms/img p50: 2131.308 | p95: 2295.116
model ms/img p50: 2076.202 | p95: 2235.849
overhd ms/img p50: 58.076 | p95: 60.922
ms/forward TOTAL p50: 17.058996 | p95: 18.370111
ms/forward MODEL p50: 16.617925 | p95: 17.895738
throughput img/s p50: 0.47 | p95: 0.48
(derived) ms/forward equiv from TOTAL p50: 1.066187
(derived) ms/forward equiv from MODEL p50: 1.038620
(derived) ms/forward equiv from TOTAL p95: 1.148132
(derived) ms/forward equiv from MODEL p95: 1.118484
peak alloc MB p50: 420.5 | peak reserved MB p50: 2430.0

=== Benchmark: OU no-attention | Euler–Maruyama | steps=2000 | batch=256 ===
steps=2000 | batch=128 | timed_runs=10
NFE/sample p50: 1999.0 | p95: 1999.0
NFE/img (legacy: NFE/sample ÷ batch) p50: 15.617
total ms/img p50: 1393.599 | p95: 1394.657
model ms/img p50: 1382.59