In [1]:
pip install matplotlib

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install gdown

Note: you may need to restart the kernel to use updated packages.


In [3]:
import gdown
import os

file_id = "12Gb_fwzUVRAK81mZ2k0YkbTj80gBrV73"
out_path = "vp_cfg_attn8_4_ema_weights.pth"

# Fetch the EMA checkpoint from Google Drive only when it is not already on disk,
# so re-running this cell reuses the local copy.
if os.path.exists(out_path):
    print("EMA weights already downloaded.")
else:
    gdown.download(
        f"https://drive.google.com/uc?id={file_id}",
        out_path,
        quiet=False,
    )


Downloading...
From (original): https://drive.google.com/uc?id=12Gb_fwzUVRAK81mZ2k0YkbTj80gBrV73
From (redirected): https://drive.google.com/uc?id=12Gb_fwzUVRAK81mZ2k0YkbTj80gBrV73&confirm=t&uuid=1b5da246-8577-414a-b426-a3ed13961f7e
To: /home/onyxia/work/vp_cfg_attn8_4_ema_weights.pth
100%|██████████| 158M/158M [00:27<00:00, 5.70MB/s] 


In [4]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


In [5]:
class VPSDE:
    """
    Variance-preserving SDE with a linear noise schedule.

        beta(t)  = beta_min + t * (beta_max - beta_min)
        alpha(t) = exp(-0.5 * int_0^t beta(s) ds)
        sigma(t) = sqrt(1 - alpha(t)^2)

    `T` is stored for reference only; none of the methods below read it.
    """
    def __init__(self, beta_min=0.1, beta_max=20.0, T=1.0):
        self.beta_min = float(beta_min)
        self.beta_max = float(beta_max)
        self.T = float(T)

    def beta(self, t):
        """Instantaneous noise rate beta(t); works for floats and tensors."""
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def int_beta(self, t):
        """Closed-form integral of beta(s) from 0 to t."""
        return self.beta_min * t + 0.5 * (self.beta_max - self.beta_min) * t**2

    def alpha(self, t):
        """Signal scale exp(-0.5 * int_beta(t)). Accepts a float or a tensor; returns a tensor."""
        return torch.exp(-0.5 * torch.as_tensor(self.int_beta(t)))

    def sigma(self, t):
        """
        Noise scale sqrt(1 - alpha(t)^2) = sqrt(1 - exp(-int_beta(t))).

        Accepts a float or a tensor; returns a tensor.

        Uses -expm1(-int_beta) rather than 1 - alpha^2. Near t = 0 (e.g. t_min = 1e-4),
        alpha^2 lies within ~1e-5 of 1, and the float32 subtraction loses most significant
        digits. That error would propagate into score = -eps / sigma.
        The 1e-12 floor on the variance keeps sigma strictly positive at t = 0.
        """
        variance = -torch.expm1(-torch.as_tensor(self.int_beta(t)))
        return torch.sqrt(variance.clamp(min=1e-12))

    def drift(self, x, t):
        """Forward-SDE drift f(x, t) = -0.5 * beta(t) * x; beta(t) is reshaped to (B, 1, 1, 1)."""
        b = self.beta(t).view(-1, 1, 1, 1)
        return -0.5 * b * x


In [None]:
class TimeEmbedding(nn.Module):
    """
    Sinusoidal embedding of a per-sample time value, followed by a learned linear layer.

    Frequencies are exp(-k * ln(10000) / (dim // 2)) for k = 0 .. dim//2 - 1. The sin
    and cos features are concatenated and zero-padded on the right when `dim` is odd.
    Input t is indexed as t[:, None], so a 1-D tensor of shape (B,) is expected.
    Output shape is (B, dim).
    """
    def __init__(self, dim=32):
        super().__init__()
        self.dim = dim
        self.lin = nn.Linear(dim, dim)

    def forward(self, t):
        n_freq = self.dim // 2
        log_base = torch.log(torch.tensor(10000.0, device=t.device, dtype=t.dtype))
        exponents = torch.arange(n_freq, device=t.device, dtype=t.dtype) * -(log_base / n_freq)
        freqs = torch.exp(exponents)

        phases = t.unsqueeze(1) * freqs.unsqueeze(0)
        features = torch.cat((torch.sin(phases), torch.cos(phases)), dim=-1)

        # 2 * (dim // 2) falls one short of dim when dim is odd; pad with a zero column.
        missing = self.dim - features.shape[1]
        if missing:
            features = F.pad(features, (0, missing))
        return self.lin(features)


class ConvBlock(nn.Module):
    """
    Two stages of (3x3 conv -> GroupNorm -> SiLU). Spatial size is preserved (padding=1).

    The six layers live in `self.block` at indices 0-5 (conv, norm, act, conv, norm, act),
    so parameter names are block.0/1/3/4. Keep this layout to stay compatible with
    checkpoints loaded via load_state_dict(strict=True).
    """
    def __init__(self, in_ch, out_ch, groups=32):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.GroupNorm(groups, out_ch),
                nn.SiLU(inplace=True),
            ])
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out


class SelfAttention2d(nn.Module):
    """
    Multi-head self-attention over the H*W positions of a feature map, with a residual connection.

    The input is GroupNorm-ed and projected to q, k, v with a 1x1 conv. Each is split
    along channels into `num_heads` heads of `head_dim = channels // num_heads`.
    Scaled dot-product attention runs per head, and the merged heads go through a
    1x1 output conv before being added back to `x`.
    """
    def __init__(self, channels, num_heads=4, gn_groups=32):
        super().__init__()
        assert channels % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = channels // num_heads
        self.scale = self.head_dim ** -0.5

        self.norm = nn.GroupNorm(gn_groups, channels)
        self.qkv = nn.Conv2d(channels, 3 * channels, kernel_size=1, bias=False)
        self.proj = nn.Conv2d(channels, channels, kernel_size=1)

    def forward(self, x):
        B, C, H, W = x.shape
        n_tokens = H * W
        q, k, v = self.qkv(self.norm(x)).chunk(3, dim=1)

        def split_heads(z):
            # [B, C, H, W] -> [B, heads, head_dim, N]; channel c maps to (c // head_dim, c % head_dim)
            return z.view(B, self.num_heads, self.head_dim, n_tokens)

        queries = split_heads(q).transpose(-1, -2)  # [B, h, N, d]
        keys = split_heads(k)                       # [B, h, d, N]
        values = split_heads(v).transpose(-1, -2)   # [B, h, N, d]

        weights = torch.softmax(torch.matmul(queries, keys) * self.scale, dim=-1)  # [B, h, N, N]
        mixed = torch.matmul(weights, values).transpose(-1, -2).reshape(B, C, H, W)
        return x + self.proj(mixed)


class UNetCIFAR3Level_Attn_CFG(nn.Module):
    """
    eps-prediction U-Net with attention at 8x8 and 4x4, plus CFG conditioning.
    y in {0 .. num_classes-1}, or y = null_label (= num_classes) for unconditional.

    Layout (C = base_channels; the 8x8 / 4x4 names assume a 32x32 input):
        encoder:    ConvBlock(C) -> pool -> ConvBlock(2C) -> pool -> ConvBlock(4C) + attn -> pool
        bottleneck: ConvBlock(8C) + attn
        decoder:    3 x (nearest 2x upsample + conv, concat the matching skip, ConvBlock)

    Conditioning: the time and label embeddings (both `time_dim` wide) are summed,
    broadcast over H x W, and concatenated to the image as extra input channels.
    Three 2x poolings mirrored by 2x upsamplings plus skip concatenation mean H and W
    must be divisible by 8.
    Attribute names and construction order are what checkpoints are keyed on
    (they are loaded with strict=True), so keep them unchanged.
    """
    def __init__(self, time_dim=32, base_channels=128, img_channels=3, num_classes=10, attn_heads=4):
        super().__init__()
        self.time_dim = time_dim
        self.num_classes = num_classes
        self.null_label = num_classes  # extra embedding row reserved for "no label"

        self.time_mlp = TimeEmbedding(dim=time_dim)
        self.label_emb = nn.Embedding(num_classes + 1, time_dim)

        C = base_channels

        # Encoder (spatial sizes for a 32x32 input): 32 -> 16 -> 8 (attention) -> 4
        self.down1 = ConvBlock(img_channels + time_dim, C)
        self.pool1 = nn.MaxPool2d(2)
        self.down2 = ConvBlock(C, 2 * C)
        self.pool2 = nn.MaxPool2d(2)
        self.down3 = ConvBlock(2 * C, 4 * C)
        self.attn8 = SelfAttention2d(4 * C, num_heads=attn_heads)
        self.pool3 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = ConvBlock(4 * C, 8 * C)
        self.attn4 = SelfAttention2d(8 * C, num_heads=attn_heads)

        # Decoder: each up-conv halves the channels; concatenating the skip doubles them back.
        self.up3 = self._upsample_conv(8 * C, 4 * C)
        self.dec3 = ConvBlock(8 * C, 4 * C)
        self.up2 = self._upsample_conv(4 * C, 2 * C)
        self.dec2 = ConvBlock(4 * C, 2 * C)
        self.up1 = self._upsample_conv(2 * C, C)
        self.dec1 = ConvBlock(2 * C, C)

        self.out_conv = nn.Conv2d(C, img_channels, 3, padding=1)

    @staticmethod
    def _upsample_conv(in_ch, out_ch):
        """2x nearest-neighbour upsample followed by a 3x3 conv from in_ch to out_ch channels."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
        )

    def forward(self, x, t, y=None):
        """
        Predict eps for images x at times t.

        x: (B, img_channels, H, W). t: (B,) float times. y: (B,) integer labels,
        or None to use the null (unconditional) label for the whole batch.
        """
        batch, _, height, width = x.shape
        if y is None:
            y = torch.full((batch,), self.null_label, device=x.device, dtype=torch.long)

        cond = self.time_mlp(t) + self.label_emb(y)               # [B, time_dim]
        cond_map = cond[:, :, None, None].expand(-1, -1, height, width)
        h = torch.cat([x, cond_map], dim=1)

        skip1 = self.down1(h)
        skip2 = self.down2(self.pool1(skip1))
        skip3 = self.attn8(self.down3(self.pool2(skip2)))
        bottom = self.attn4(self.bottleneck(self.pool3(skip3)))

        h = self.dec3(torch.cat([self.up3(bottom), skip3], dim=1))
        h = self.dec2(torch.cat([self.up2(h), skip2], dim=1))
        h = self.dec1(torch.cat([self.up1(h), skip1], dim=1))
        return self.out_conv(h)


In [None]:
# Architecture / SDE hyperparameters. The architecture values must match the checkpoint:
# load_state_dict(strict=True) below rejects any missing, unexpected or mis-shaped keys.
num_classes = 10
time_dim = 32
base_channels = 128
img_channels = 3
attn_heads = 4

beta_min = 0.1
beta_max = 20.0
T = 1.0
t_min = 1e-4  # smallest time the sampler integrates down to (sigma(t_min) > 0)

ema_weights_path = "vp_cfg_attn8_4_ema_weights.pth"

sde = VPSDE(beta_min=beta_min, beta_max=beta_max, T=T)

ema_model = UNetCIFAR3Level_Attn_CFG(
    time_dim=time_dim,
    base_channels=base_channels,
    img_channels=img_channels,
    num_classes=num_classes,
    attn_heads=attn_heads,
).to(device)

# The checkpoint was fetched from a third-party URL. weights_only=True limits unpickling to
# tensors and plain containers, so a tampered file cannot execute arbitrary code
# (this argument requires torch >= 1.13).
ema_sd = torch.load(ema_weights_path, map_location=device, weights_only=True)
ema_model.load_state_dict(ema_sd, strict=True)
ema_model.eval()

import time
import numpy as np
import torch
import torch.nn as nn

def _sync_if_cuda(device):
    if isinstance(device, str):
        is_cuda = "cuda" in device
    else:
        is_cuda = (device.type == "cuda")
    if is_cuda and torch.cuda.is_available():
        torch.cuda.synchronize()

def _percentiles(xs, ps=(50, 95)):
    xs = np.asarray(xs, dtype=np.float64)
    return {f"p{p}": float(np.percentile(xs, p)) for p in ps}

class TimedModel(nn.Module):
    """
    Wrapper that counts forward calls (`nfe`) and accumulates their duration in ms (`model_ms`).

    On CUDA devices each call is bracketed by CUDA events and followed by
    torch.cuda.synchronize(), so the elapsed time is read immediately.
    Otherwise wall-clock time from time.perf_counter() is used.
    Forward runs under torch.no_grad(), so outputs carry no autograd graph.
    """
    def __init__(self, model: nn.Module, device="cuda"):
        super().__init__()
        self.model = model
        self.device = device
        self.reset_stats()

    def reset_stats(self):
        """Zero the call counter and the accumulated model time."""
        self.nfe, self.model_ms = 0, 0.0

    def _use_cuda_events(self):
        # CUDA-event timing only when a GPU is present and this wrapper targets one.
        return torch.cuda.is_available() and ("cuda" in str(self.device))

    @torch.no_grad()
    def forward(self, *args, **kwargs):
        self.nfe += 1

        if not self._use_cuda_events():
            tic = time.perf_counter()
            result = self.model(*args, **kwargs)
            toc = time.perf_counter()
            self.model_ms += (toc - tic) * 1000.0
            return result

        ev_start = torch.cuda.Event(enable_timing=True)
        ev_end = torch.cuda.Event(enable_timing=True)
        ev_start.record()
        result = self.model(*args, **kwargs)
        ev_end.record()
        torch.cuda.synchronize()  # elapsed_time needs both events completed
        self.model_ms += ev_start.elapsed_time(ev_end)
        return result

@torch.no_grad()
def _vp_probflow_drift_cfg(timed_model, sde, x, t_batch, y, w: float):
    """
    timed_model is assumed to predict epsilon: eps = model(x, t, y)
    CFG: eps = eps_u + w * (eps_c - eps_u)
    score = -eps / sigma(t)
    drift = f - 0.5 * beta(t) * score   (probability flow ODE for VP SDE)
    where f = -0.5 * beta(t) * x
    """
    B = x.shape[0]
    beta = sde.beta(t_batch).view(B, 1, 1, 1)
    sigma = sde.sigma(t_batch).view(B, 1, 1, 1).clamp_min(1e-12)

    y_uncond = None

    eps_u = timed_model(x, t_batch, y_uncond)
    eps_c = timed_model(x, t_batch, y)
    eps = eps_u + w * (eps_c - eps_u)

    score = -eps / sigma
    f = -0.5 * beta * x
    drift = f - 0.5 * beta * score
    return drift

@torch.no_grad()
def sample_vp_probflow_heun_cfg_timed(
    timed_model,
    sde,
    *,
    batch_size: int,
    num_steps: int,
    w: float,
    device="cuda",
    t_min: float = 1e-4,
    num_classes: int = 10,
    image_shape=(3, 32, 32),
):
    """
    Sample with Heun's (2nd-order) method on the VP probability-flow ODE with CFG.

    Integrates from t = 1.0 down to t = t_min on a uniform grid of `num_steps` points,
    i.e. `num_steps - 1` Heun steps. Each step evaluates the CFG drift twice, and each
    drift evaluation calls the model twice (unconditional + conditional), so one call
    of this function makes 4 * (num_steps - 1) model forward calls.

    Args:
        timed_model: eps-prediction model called as model(x, t, y); y=None means unconditional.
        sde: object providing beta(t) and sigma(t).
        batch_size: number of samples.
        num_steps: number of time-grid points (>= 2).
        w: CFG weight in eps = eps_u + w * (eps_c - eps_u).
        device: device for the noise, labels and time grid.
        t_min: final integration time.
        num_classes: labels are drawn uniformly from [0, num_classes).
        image_shape: per-sample shape of x (default (3, 32, 32)).

    Returns:
        (x, y): samples of shape (batch_size, *image_shape) and their labels.

    Raises:
        ValueError: if num_steps < 2. With a single grid point no step would be taken
        and the initial noise would be returned silently.
    """
    if num_steps < 2:
        raise ValueError(f"num_steps must be >= 2 (got {num_steps}); fewer grid points take no ODE step")

    timed_model.eval()
    t_grid = torch.linspace(1.0, t_min, num_steps, device=device)
    x = torch.randn(batch_size, *image_shape, device=device)
    y = torch.randint(low=0, high=num_classes, size=(batch_size,), device=device)

    for i in range(num_steps - 1):
        t_cur, t_next = t_grid[i], t_grid[i + 1]
        dt = t_next - t_cur  # negative: integrating backwards in time

        # Broadcast the grid times on-device. Calling .item() here (as a CUDA tensor)
        # would force a host-device sync twice per step and distort the overhead timing.
        t_batch = t_cur.repeat(batch_size)
        t_batch_next = t_next.repeat(batch_size)

        drift1 = _vp_probflow_drift_cfg(timed_model, sde, x, t_batch, y, w=w)
        x_pred = x + drift1 * dt  # Euler predictor
        drift2 = _vp_probflow_drift_cfg(timed_model, sde, x_pred, t_batch_next, y, w=w)
        x = x + 0.5 * (drift1 + drift2) * dt  # trapezoidal corrector

    return x, y

def run_sampling_benchmark(
    model: nn.Module,
    sde,
    sampler_fn,
    *,
    num_steps: int,
    batch_size: int,
    device="cuda",
    t_min=1e-4,
    warmup_runs=2,
    timed_runs=10,
    reset_cuda_peak_mem=True,
    **sampler_kwargs,
):
    """
    Benchmark a sampler end to end and summarise latency, NFEs and memory.

    `model` is wrapped in a TimedModel (and moved to `device`) so every forward call is
    counted and timed. `sampler_fn` is invoked as
    sampler_fn(timed_model, sde, batch_size=..., num_steps=..., device=..., t_min=..., **sampler_kwargs).
    `warmup_runs` untimed passes precede `timed_runs` timed passes.

    Per timed run it records:
      - wall-clock total ms (bracketed by CUDA syncs when on GPU), model ms as accumulated
        by TimedModel, and their difference ("overhead"), each divided by batch_size;
      - NFE = number of batched forward calls. Every sample goes through each call, so
        this is also the NFE per sample;
      - ms per forward call, and images per second;
      - CUDA peak allocated / reserved MB (CUDA only). Peaks are reset before each run
        when reset_cuda_peak_mem is True.

    Returns:
        dict of p50/p95 percentiles over the timed runs (keys consumed by pretty_print_summary).

    Raises:
        ValueError: if batch_size < 1 or timed_runs < 1.
    """
    # Validate up front. timed_runs=0 would otherwise surface as an opaque numpy error from
    # np.percentile on an empty list, and batch_size=0 as a ZeroDivisionError after a full run.
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")
    if timed_runs < 1:
        raise ValueError(f"timed_runs must be >= 1, got {timed_runs}")

    on_cuda = torch.cuda.is_available() and ("cuda" in str(device))
    timed_model = TimedModel(model, device=device).to(device)

    def _timed_pass():
        # One full sampling pass; returns wall-clock ms. Syncing on both sides makes the
        # interval cover all queued GPU work. The samples are dropped immediately so they
        # are not kept alive (and counted in peak memory) during the next pass.
        timed_model.reset_stats()
        _sync_if_cuda(device)
        start = time.perf_counter()
        sampler_fn(
            timed_model, sde,
            batch_size=batch_size, num_steps=num_steps, device=device, t_min=t_min,
            **sampler_kwargs,
        )
        _sync_if_cuda(device)
        return (time.perf_counter() - start) * 1000.0

    for _ in range(warmup_runs):
        _timed_pass()

    per_img_total_ms = []
    per_img_model_ms = []
    per_img_overhead_ms = []
    nfes_per_run = []
    ms_per_forward_total = []
    ms_per_forward_model = []
    throughput_img_per_s = []
    peak_alloc_mb = []
    peak_reserved_mb = []

    for _ in range(timed_runs):
        if reset_cuda_peak_mem and on_cuda:
            torch.cuda.reset_peak_memory_stats()

        total_ms_batch = _timed_pass()
        model_ms_batch = float(timed_model.model_ms)
        overhead_ms_batch = total_ms_batch - model_ms_batch

        nfe_batch = int(timed_model.nfe)
        nfes_per_run.append(nfe_batch)

        total_ms_img = total_ms_batch / batch_size
        per_img_total_ms.append(total_ms_img)
        per_img_model_ms.append(model_ms_batch / batch_size)
        per_img_overhead_ms.append(overhead_ms_batch / batch_size)

        denom = max(nfe_batch, 1)
        ms_per_forward_total.append(total_ms_batch / denom)
        ms_per_forward_model.append(model_ms_batch / denom)

        throughput_img_per_s.append(1000.0 / max(total_ms_img, 1e-12))

        if on_cuda:
            peak_alloc_mb.append(torch.cuda.max_memory_allocated() / (1024**2))
            peak_reserved_mb.append(torch.cuda.max_memory_reserved() / (1024**2))

    # Each forward call processes the whole batch, so the per-run call count is the NFE per sample.
    nfe_per_sample = nfes_per_run
    nfe_per_img_legacy = [n / batch_size for n in nfes_per_run]

    summary = {
        "num_steps": num_steps,
        "batch_size": batch_size,
        "t_min": t_min,
        "timed_runs": timed_runs,

        "nfe_per_sample_percentiles": _percentiles(nfe_per_sample, ps=(50, 95)),
        "nfe_per_img_legacy_percentiles": _percentiles(nfe_per_img_legacy, ps=(50,)),

        "total_ms_percentiles": _percentiles(per_img_total_ms, ps=(50, 95)),
        "model_ms_percentiles": _percentiles(per_img_model_ms, ps=(50, 95)),
        "overhead_ms_percentiles": _percentiles(per_img_overhead_ms, ps=(50, 95)),

        "ms_per_forward_total_percentiles": _percentiles(ms_per_forward_total, ps=(50, 95)),
        "ms_per_forward_model_percentiles": _percentiles(ms_per_forward_model, ps=(50, 95)),

        "throughput_img_per_s_percentiles": _percentiles(throughput_img_per_s, ps=(50, 95)),
    }

    if peak_alloc_mb:
        summary["peak_alloc_mb_p50"] = _percentiles(peak_alloc_mb, ps=(50,))["p50"]
        summary["peak_reserved_mb_p50"] = _percentiles(peak_reserved_mb, ps=(50,))["p50"]

    return summary


def pretty_print_summary(summary: dict, name=""):
    """
    Print a human-readable report of a run_sampling_benchmark summary dict.

    NFE/sample is shown with 1 decimal, per-image latencies with 3, per-forward
    latencies with 6, and throughput with 2. The peak-memory line appears only when
    the summary contains CUDA peak-memory entries.
    """
    def p50_p95(key, fmt):
        # Render "p50: <a> | p95: <b>" for one percentile entry of the summary.
        pct = summary[key]
        return f"p50: {pct['p50']:{fmt}} | p95: {pct['p95']:{fmt}}"

    print(f"\n=== Benchmark: {name} ===")
    print(f"steps={summary['num_steps']} | batch={summary['batch_size']} | timed_runs={summary['timed_runs']}")
    print(f"NFE/sample {p50_p95('nfe_per_sample_percentiles', '.1f')}")
    print(f"NFE/img (legacy: NFE/sample ÷ batch) p50: {summary['nfe_per_img_legacy_percentiles']['p50']:.3f}")

    per_image_rows = (
        ("total", "total_ms_percentiles"),
        ("model", "model_ms_percentiles"),
        ("overhd", "overhead_ms_percentiles"),
    )
    for label, key in per_image_rows:
        print(f"{label} ms/img {p50_p95(key, '.3f')}")

    per_forward_rows = (
        ("TOTAL", "ms_per_forward_total_percentiles"),
        ("MODEL", "ms_per_forward_model_percentiles"),
    )
    for label, key in per_forward_rows:
        print(f"ms/forward {label} {p50_p95(key, '.6f')}")

    print(f"throughput img/s {p50_p95('throughput_img_per_s_percentiles', '.2f')}")

    if "peak_alloc_mb_p50" in summary:
        print(f"peak alloc MB p50: {summary['peak_alloc_mb_p50']:.1f} | peak reserved MB p50: {summary['peak_reserved_mb_p50']:.1f}")

NUM_STEPS = 500     # time-grid points -> NUM_STEPS - 1 Heun steps
GUIDANCE_W = 4.0    # CFG weight w in eps_u + w * (eps_c - eps_u)


def _bench_and_report(batch_size):
    """Benchmark the CFG Heun ODE sampler at one batch size, print the report, and return the summary."""
    summary = run_sampling_benchmark(
        ema_model, sde, sample_vp_probflow_heun_cfg_timed,
        num_steps=NUM_STEPS, batch_size=batch_size, device=device, t_min=t_min,
        # Use the configured class count instead of a hard-coded 10 so labels stay valid for the model.
        w=GUIDANCE_W, num_classes=num_classes,
    )
    pretty_print_summary(
        summary,
        name=f"VP CFG + attn | Heun ODE | steps={NUM_STEPS} | w={GUIDANCE_W} | batch={batch_size}",
    )
    return summary


bench_cfg_b16 = _bench_and_report(16)
bench_cfg_b128 = _bench_and_report(128)



=== Benchmark: VP CFG + attn | Heun ODE | steps=500 | w=4.0 | batch=16 ===
steps=500 | batch=16 | timed_runs=10
NFE/sample p50: 1996.0 | p95: 1996.0
NFE/img (legacy: NFE/sample ÷ batch) p50: 124.750
total ms/img p50: 2654.144 | p95: 2705.568
model ms/img p50: 2608.652 | p95: 2657.949
overhd ms/img p50: 46.735 | p95: 48.914
ms/forward TOTAL p50: 21.275703 | p95: 21.687921
ms/forward MODEL p50: 20.911042 | p95: 21.306202
throughput img/s p50: 0.38 | p95: 0.39
peak alloc MB p50: 482.5 | peak reserved MB p50: 674.0

=== Benchmark: VP CFG + attn | Heun ODE | steps=500 | w=4.0 | batch=128 ===
steps=500 | batch=128 | timed_runs=10
NFE/sample p50: 1996.0 | p95: 1996.0
NFE/img (legacy: NFE/sample ÷ batch) p50: 15.594
total ms/img p50: 1932.798 | p95: 1940.036
model ms/img p50: 1924.935 | p95: 1932.199
overhd ms/img p50: 7.865 | p95: 7.965
ms/forward TOTAL p50: 123.946948 | p95: 124.411100
ms/forward MODEL p50: 123.442733 | p95: 123.908530
throughput img/s p50: 0.52 | p95: 0.52
peak alloc MB p5