In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F

In [None]:
# ToTensor yields values in [0, 1]; subtracting 0.5 and dividing by 0.5 per
# channel rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# CIFAR-10 training split, fetched into ./data when it is not already there.
trainset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 64, produced by two worker processes into pinned memory.
loader = DataLoader(
    trainset,
    batch_size=64,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

In [None]:
import torch
import math

class VPSDE:
    """Variance-preserving SDE with a linear noise schedule.

    The schedule is beta(t) = beta_min + t * (beta_max - beta_min), and the
    perturbation kernel implied by the methods below is
    x_t = alpha(t) * x_0 + sigma(t) * eps with
    alpha(t) = exp(-1/2 * int_0^t beta(s) ds) and sigma(t)^2 = 1 - alpha(t)^2.

    Args:
        beta_min: value of beta(t) at t = 0.
        beta_max: value of beta(t) at t = 1.
        T: stored on the instance as ``self.T``; no method of this class reads it.

    All methods taking ``t`` expect a torch tensor (they call torch functions on it).
    """

    def __init__(self, beta_min=0.1, beta_max=20.0, T=1.0):
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.T = T

    def beta(self, t):
        """Linear noise rate beta(t); same shape as ``t``."""
        return self.beta_min + t * (self.beta_max - self.beta_min)

    def int_beta(self, t):
        """Closed-form integral of beta from 0 to ``t`` for the linear schedule."""
        return self.beta_min * t + 0.5 * (self.beta_max - self.beta_min) * t**2

    def alpha(self, t):
        """Signal scale alpha(t) = exp(-1/2 * int_beta(t))."""
        return torch.exp(-0.5 * self.int_beta(t))

    def sigma(self, t):
        """Noise scale sigma(t) = sqrt(1 - alpha(t)^2), floored at 1e-5.

        Since alpha(t)^2 = exp(-int_beta(t)), the quantity under the root is
        1 - exp(-int_beta) = -expm1(-int_beta). Evaluating it through expm1
        avoids the catastrophic cancellation that ``1 - a*a`` suffers in
        float32 when t is close to 0 (alpha close to 1).
        """
        one_minus_alpha_sq = -torch.expm1(-self.int_beta(t))
        # The floor keeps sigma strictly positive so callers can divide by it.
        return torch.sqrt(one_minus_alpha_sq).clamp(min=1e-5)

    def diffusion(self, t):
        """Diffusion coefficient g(t) = sqrt(beta(t)), floored at 1e-5."""
        return torch.sqrt(self.beta(t)).clamp(min=1e-5)

    def drift(self, x, t):
        """Drift f(x, t) = -1/2 * beta(t) * x.

        ``t`` is reshaped to [-1, 1, 1, 1] so that one value per batch element
        broadcasts over a 4-D ``x``.
        """
        b = self.beta(t).view(-1, 1, 1, 1)
        return -0.5 * b * x


In [None]:
import torch
import torch.nn as nn

class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of a per-sample time value followed by a linear layer.

    Args:
        dim: width of the embedding. Must be an even integer >= 2, because the
            features are ``dim // 2`` sines concatenated with ``dim // 2``
            cosines and are then fed to ``nn.Linear(dim, dim)``.

    Raises:
        ValueError: if ``dim`` is odd or smaller than 2 (previously this only
            surfaced as a shape mismatch inside the linear layer at forward time).
    """

    def __init__(self, dim=32):
        super().__init__()
        if dim < 2 or dim % 2 != 0:
            raise ValueError(f"TimeEmbedding dim must be an even integer >= 2, got {dim}")
        self.dim = dim
        self.lin = nn.Linear(dim, dim)

    def forward(self, t):
        """Map times ``t`` of shape [B] to an embedding of shape [B, dim]."""
        half_dim = self.dim // 2
        # Geometric frequency ladder exp(-k * ln(10000) / half_dim), k = 0..half_dim-1,
        # created on the same device as t.
        emb = torch.exp(
            torch.arange(half_dim, device=t.device) * -(torch.log(torch.tensor(10000.0, device=t.device)) / half_dim)
        )
        # Outer product of times and frequencies: [B, half_dim].
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
        return self.lin(emb)  # [B, dim]


class ConvBlock(nn.Module):
    """Two consecutive (3x3 conv -> GroupNorm -> SiLU) stages.

    The convolutions use padding 1, so height and width are unchanged.

    Args:
        in_ch: channel count of the input feature map.
        out_ch: channel count produced by both convolutions.
        groups: number of GroupNorm groups; it has to divide ``out_ch``
            (32 works for 128, 256, 512 and 1024 channels).
    """

    def __init__(self, in_ch, out_ch, groups=32):
        super().__init__()
        stages = []
        # First stage changes the width in_ch -> out_ch, second keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.GroupNorm(num_groups=groups, num_channels=out_ch),
                nn.SiLU(inplace=True),
            ])
        # A single Sequential called `block`, so parameters are named block.0 ... block.4.
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        return self.block(x)


In [None]:
class UNetScoreCIFAR3Level(nn.Module):
    """Three-level U-Net conditioned on time by channel concatenation.

    The time embedding is broadcast over the spatial grid and appended to the
    image as ``time_dim`` extra channels. With C = ``base_channels`` the
    encoder widths are C, 2C, 4C (each followed by a 2x max-pool), the
    bottleneck is 8C, and the decoder mirrors the encoder: nearest-neighbour
    2x upsampling plus a 3x3 conv, concatenation with the matching encoder
    output, then a ConvBlock.

    Args:
        time_dim: width of the time embedding / number of extra input channels.
        base_channels: C, the width of the first encoder level.
        img_channels: channels of the input image and of the returned tensor.
    """

    @staticmethod
    def _upsample(in_ch, out_ch):
        """Nearest-neighbour 2x upsampling followed by a 3x3 conv in_ch -> out_ch."""
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        )

    def __init__(self, time_dim=32, base_channels=128, img_channels=3):
        super().__init__()
        width = base_channels

        self.time_mlp = TimeEmbedding(dim=time_dim)

        # Encoder: each level is a ConvBlock followed by a 2x2 max-pool.
        # The size comments assume 32x32 inputs.
        self.down1 = ConvBlock(img_channels + time_dim, width)    # 32x32
        self.pool1 = nn.MaxPool2d(2)                              # 32 -> 16
        self.down2 = ConvBlock(width, 2 * width)                  # 16x16
        self.pool2 = nn.MaxPool2d(2)                              # 16 -> 8
        self.down3 = ConvBlock(2 * width, 4 * width)              # 8x8
        self.pool3 = nn.MaxPool2d(2)                              # 8 -> 4

        # Bottleneck at the coarsest resolution.
        self.bottleneck = ConvBlock(4 * width, 8 * width)         # 4x4

        # Decoder: upsample, concatenate the skip (doubling the channels), ConvBlock.
        self.up3 = self._upsample(8 * width, 4 * width)           # 4 -> 8
        self.dec3 = ConvBlock(8 * width, 4 * width)               # 4C + 4C -> 4C
        self.up2 = self._upsample(4 * width, 2 * width)           # 8 -> 16
        self.dec2 = ConvBlock(4 * width, 2 * width)               # 2C + 2C -> 2C
        self.up1 = self._upsample(2 * width, width)               # 16 -> 32
        self.dec1 = ConvBlock(2 * width, width)                   # C + C -> C

        # Project back to the image's channel count.
        self.out_conv = nn.Conv2d(width, img_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        """Run the network.

        Args:
            x: image batch [B, img_channels, H, W].
            t: one time value per sample, shape [B].

        Returns:
            Tensor with the same shape as ``x``.
        """
        # [B, time_dim] -> [B, time_dim, H, W], then stack onto the image channels.
        t_map = self.time_mlp(t)[:, :, None, None].expand(-1, -1, x.size(2), x.size(3))
        h = torch.cat([x, t_map], dim=1)

        # Encoder; keep each level's pre-pool output for the skip connections.
        skip1 = self.down1(h)                     # [B, C, H, W]
        skip2 = self.down2(self.pool1(skip1))     # [B, 2C, H/2, W/2]
        skip3 = self.down3(self.pool2(skip2))     # [B, 4C, H/4, W/4]

        h = self.bottleneck(self.pool3(skip3))    # [B, 8C, H/8, W/8]

        # Decoder, coarsest level first.
        h = self.dec3(torch.cat([self.up3(h), skip3], dim=1))   # [B, 4C, H/4, W/4]
        h = self.dec2(torch.cat([self.up2(h), skip2], dim=1))   # [B, 2C, H/2, W/2]
        h = self.dec1(torch.cat([self.up1(h), skip1], dim=1))   # [B, C, H, W]

        return self.out_conv(h)                   # [B, img_channels, H, W]


In [None]:
@torch.no_grad()
def vp_sample_xt(x0, t, sde):
    """Sample x_t from the perturbation kernel of ``sde``.

    Args:
        x0: clean batch, [B, 3, 32, 32].
        t: one time per sample, [B].
        sde: object providing ``alpha(t)`` and ``sigma(t)``.

    Returns:
        ``(xt, eps)`` where ``eps`` is standard normal noise shaped like ``x0``
        and ``xt = alpha(t) * x0 + sigma(t) * eps``.
    """
    # Per-sample scales reshaped to broadcast over channels and pixels.
    signal_scale = sde.alpha(t).view(-1, 1, 1, 1)
    noise_scale = sde.sigma(t).view(-1, 1, 1, 1)
    noise = torch.randn_like(x0)
    return signal_scale * x0 + noise_scale * noise, noise


In [None]:
import os
import torch
import torch.nn.functional as F
import torch.optim as optim

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Checkpoint file: read to resume training and written every `save_every` epochs.
ckpt_path = "vp_cifar_ckpt.pth"

# ---- hyperparams ----
# Network shape.
time_dim, base_channels, img_channels = 32, 128, 3

# VP-SDE noise schedule.
beta_min, beta_max, T = 0.1, 20.0, 1.0

# AdamW settings.
lr, weight_decay = 3e-5, 1e-4

# EMA decay for the shadow weights, and the smallest diffusion time sampled in training.
ema_decay, t_min = 0.999, 1e-4

# Training length and checkpoint cadence, both in epochs.
num_epochs_total = 150      # total epochs you want to reach
save_every = 10


@torch.no_grad()
def update_ema(ema_model, model, decay):
    """Blend the current weights of ``model`` into ``ema_model`` in place.

    Parameters are paired by their position in ``parameters()`` and updated as
    ``p_ema <- decay * p_ema + (1 - decay) * p``. Only parameters are touched.

    Args:
        ema_model: module holding the exponential moving average; modified in place.
        model: module whose current parameters are blended in; left unchanged.
        decay: weight kept from the previous average.
    """
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        # The no_grad decorator already permits in-place updates of leaf
        # parameters, so the legacy `.data` indirection (which hides the write
        # from autograd's version tracking) is not needed.
        p_ema.mul_(decay).add_(p, alpha=1 - decay)


# ---- build models ----
# The trained network and its EMA shadow use the same architecture.
model = UNetScoreCIFAR3Level(
    time_dim=time_dim,
    base_channels=base_channels,
    img_channels=img_channels,
).to(device)
ema_model = UNetScoreCIFAR3Level(
    time_dim=time_dim,
    base_channels=base_channels,
    img_channels=img_channels,
).to(device)

# ---- build SDE ----
sde = VPSDE(beta_min=beta_min, beta_max=beta_max, T=T)

# ---- optimizer ----
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

# ---- resume if possible ----
start_epoch = 0
if not os.path.exists(ckpt_path):
    # Fresh run: the EMA copy starts out identical to the newly initialised model.
    ema_model.load_state_dict(model.state_dict())
    print(f"No checkpoint found at {ckpt_path}. Starting fresh at epoch 0 (lr={lr}).")
else:
    ckpt = torch.load(ckpt_path, map_location=device)

    # load_state_dict raises here if the checkpoint was saved for another architecture.
    model.load_state_dict(ckpt["model_state"])
    ema_model.load_state_dict(ckpt["ema_state"])

    # Older checkpoints may lack the optimizer state; keep the fresh optimizer then.
    if "optimizer_state" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer_state"])

    # Use the noise schedule stored with the weights, when all three values are
    # there, so it matches what those weights were trained with.
    if all(key in ckpt for key in ("beta_min", "beta_max", "T")):
        sde = VPSDE(
            beta_min=float(ckpt["beta_min"]),
            beta_max=float(ckpt["beta_max"]),
            T=float(ckpt["T"]),
        )

    # The checkpoint stores the last finished epoch index; continue with the next one.
    start_epoch = int(ckpt.get("epoch", -1)) + 1

    # Loading the optimizer state also restores its old lr / weight decay;
    # overwrite them with the values configured for this run.
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
        param_group["weight_decay"] = weight_decay

    print(f"Resuming from {ckpt_path} at epoch {start_epoch} (lr={optimizer.param_groups[0]['lr']})")


# ---- training ----
# Epochs are 0-indexed internally; training covers start_epoch .. end_epoch - 1.
end_epoch = num_epochs_total

for epoch in range(start_epoch, end_epoch):
    model.train()
    epoch_loss = 0.0
    n_batches = 0

    for x0, _ in loader:
        # non_blocking lets the host-to-device copy run asynchronously when the
        # batch sits in pinned memory; it is a no-op otherwise.
        x0 = x0.to(device, non_blocking=True)

        # t ~ Uniform(t_min, 1), one time per sample
        t = t_min + (1.0 - t_min) * torch.rand(x0.size(0), device=device)

        # Noise the batch to time t and train the network to recover the noise.
        xt, eps = vp_sample_xt(x0, t, sde)
        eps_pred = model(xt, t)

        loss = F.mse_loss(eps_pred, eps)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Track the moving average of the weights after every step.
        update_ema(ema_model, model, ema_decay)

        epoch_loss += loss.item()
        n_batches += 1

    mean_loss = epoch_loss / max(n_batches, 1)
    print(f"Epoch {epoch+1}/{end_epoch}: mean loss = {mean_loss:.4f}")

    # ---- checkpoint ----
    if (epoch + 1) % save_every == 0 or (epoch + 1) == end_epoch:
        # Write to a temporary file first and swap it in afterwards, so an
        # interrupt or crash during the save cannot corrupt the existing
        # checkpoint that resuming depends on.
        tmp_ckpt_path = ckpt_path + ".tmp"
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "ema_state": ema_model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "beta_min": sde.beta_min,
                "beta_max": sde.beta_max,
                "T": sde.T,
                "time_dim": time_dim,
                "base_channels": base_channels,
                "img_channels": img_channels,
                "t_min": t_min,
            },
            tmp_ckpt_path,
        )
        os.replace(tmp_ckpt_path, ckpt_path)
        print(f"Checkpoint saved at epoch {epoch+1} → {ckpt_path}")


In [None]:
@torch.no_grad()
def sample_prob_flow_ode(model, sde, num_steps=2000, batch_size=16, device="cuda", t_min=1e-4,
                         img_shape=(3, 32, 32)):
    """Generate samples by Euler-integrating the probability-flow ODE.

    Starting from standard normal noise at t = 1, the state is stepped along
    ``num_steps`` evenly spaced times down to ``t_min`` using

        dx/dt = -1/2 * beta(t) * x - 1/2 * beta(t) * score(x, t),
        score(x, t) = -model(x, t) / sigma(t).

    Args:
        model: callable ``model(x, t)`` returning a noise prediction shaped
            like ``x``; it is switched to eval mode and left that way.
        sde: object providing ``beta(t)`` and ``sigma(t)`` for a [B] tensor ``t``.
        num_steps: number of grid points (``num_steps - 1`` Euler updates).
        batch_size: number of samples to generate.
        device: device on which the time grid and the samples are created.
        t_min: final (smallest) time of the integration.
        img_shape: per-sample shape of the generated tensor; the default keeps
            the previous fixed (3, 32, 32).

    Returns:
        Tensor of shape ``(batch_size, *img_shape)``; values are not clamped.
    """
    model.eval()
    t_grid = torch.linspace(1.0, t_min, num_steps, device=device)

    x = torch.randn(batch_size, *img_shape, device=device)
    # Shape that broadcasts one scalar per sample over all remaining axes.
    per_sample = (batch_size,) + (1,) * len(img_shape)

    for i in range(num_steps - 1):
        t_cur = t_grid[i]
        dt = t_grid[i + 1] - t_cur  # negative: time runs backwards

        # Same time for every sample in the batch.
        t_batch = torch.full((batch_size,), t_cur.item(), device=device)

        beta = sde.beta(t_batch).view(per_sample)
        sigma = sde.sigma(t_batch).view(per_sample)

        # Turn the predicted noise into a score estimate.
        eps_pred = model(x, t_batch)
        score = -eps_pred / sigma

        # Probability-flow drift: f(x, t) - 1/2 * g(t)^2 * score with g^2 = beta.
        f = -0.5 * beta * x
        drift = f - 0.5 * beta * score

        x = x + drift * dt

    return x


In [None]:
import matplotlib.pyplot as plt
import math

def show_cifar_grid(x, nrow=4, title="Samples"):
    """Plot a batch of images as a grid with ``nrow`` rows.

    Args:
        x: image batch [B, C, H, W] with values nominally in [-1, 1];
            anything outside that range is clamped before display.
        nrow: number of rows; the column count is ceil(B / nrow).
        title: figure title.
    """
    # Detach, move to CPU, clamp to [-1, 1] and rescale to [0, 1] for imshow.
    images = (x.detach().cpu().clamp(-1, 1) + 1) / 2.0
    count = images.size(0)
    ncol = math.ceil(count / nrow)

    plt.figure(figsize=(ncol * 2, nrow * 2))
    for slot, chw in enumerate(images, start=1):
        plt.subplot(nrow, ncol, slot)
        # Channels-first -> channels-last, as imshow expects.
        plt.imshow(chw.permute(1, 2, 0).numpy())
        plt.axis("off")
    plt.suptitle(title)
    plt.tight_layout()
    plt.show()


In [None]:
# Restore both weight sets from the checkpoint on disk.
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt["model_state"])
ema_model.load_state_dict(ckpt["ema_state"])

# Sample from the EMA weights with a 4000-point time grid, then display a 4-row grid.
samples = sample_prob_flow_ode(
    model=ema_model,
    sde=sde,
    num_steps=4000,
    batch_size=16,
    device=device,
    t_min=1e-4,
)
show_cifar_grid(x=samples, nrow=4, title="CIFAR-10 samples (VP + ODE, EMA)")
