In [1]:
pip install matplotlib

Note: you may need to restart the kernel to use updated packages.


In [2]:
pip install gdown

Note: you may need to restart the kernel to use updated packages.


In [3]:
import gdown
import os

# Google Drive file id of the pretrained EMA weights and the local file they are saved to.
file_id = "1UWUP4u-zcY1AOOD5GoF0xKZiQ-kn2X7Z"
out_path = "cifar_ou_no_attention_ema_weights.pth"

# Treat a zero-byte file (e.g. left by an interrupted download) as missing.
already_present = os.path.isfile(out_path) and os.path.getsize(out_path) > 0

if not already_present:
    downloaded = gdown.download(
        f"https://drive.google.com/uc?id={file_id}",
        out_path,
        quiet=False,
    )
    # gdown may report failure by returning None instead of raising (version dependent);
    # fail here instead of with an obscure error when the weights are loaded later.
    if downloaded is None or not os.path.isfile(out_path):
        raise RuntimeError(
            f"Failed to download EMA weights (Drive id {file_id}) to {out_path!r}."
        )
else:
    print("EMA weights already downloaded.")


Downloading...
From (original): https://drive.google.com/uc?id=1UWUP4u-zcY1AOOD5GoF0xKZiQ-kn2X7Z
From (redirected): https://drive.google.com/uc?id=1UWUP4u-zcY1AOOD5GoF0xKZiQ-kn2X7Z&confirm=t&uuid=fe782980-c6d9-41f1-8f91-2292ee7d251e
To: /home/onyxia/work/cifar_ou_no_attention_ema_weights.pth
100%|██████████| 123M/123M [00:01<00:00, 107MB/s]  


In [4]:
import math
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"


In [5]:
class Forward:  # OU process
    """Forward Ornstein-Uhlenbeck noising process.

    The transition from ``x0`` is Gaussian with mean ``x0 * exp(-lambda * t)`` and
    std ``sqrt(1 - exp(-2 * lambda * t))``; the diffusion coefficient is the
    constant ``sqrt(2 * lambda)``. ``mean``/``std`` reshape the per-sample times
    ``t`` (shape ``[B]``) to ``[B, 1, 1, 1]`` so they broadcast over image batches.
    """

    def __init__(self, lambda_=0.3):
        # Kept as a plain Python float; the sampler reads it directly as ``sde.lmbd``.
        self.lmbd = float(lambda_)

    def mean(self, x0, t):
        """Conditional mean ``x0 * exp(-lambda * t)``."""
        decay = torch.exp(-self.lmbd * t[:, None, None, None])
        return x0 * decay

    def std(self, t):
        """Conditional std ``sqrt(1 - exp(-2 * lambda * t))``, floored at 1e-3 so it stays positive near t=0."""
        t4d = t[:, None, None, None]
        variance = 1.0 - torch.exp(-2 * self.lmbd * t4d)
        return variance.sqrt().clamp(min=1e-3)

    def diffusion_coeff(self, t):
        """Constant diffusion ``g(t) = sqrt(2 * lambda)``, returned with the same shape as ``t`` ([B])."""
        g = (2 * self.lmbd) ** 0.5
        return torch.ones_like(t) * g


In [6]:
class TimeEmbedding(nn.Module):
    """Sinusoidal embedding of per-sample times followed by a learned linear map.

    Args:
        dim: Output embedding size. ``dim // 2`` frequencies each contribute a sin
            and a cos feature; when ``dim`` is odd the feature vector is zero-padded
            by one column so it still matches the ``Linear(dim, dim)`` layer.
    """

    def __init__(self, dim=32):
        super().__init__()
        self.dim = dim
        self.lin = nn.Linear(dim, dim)

    def forward(self, t):
        """Embed times ``t`` of shape ``[B]`` into a ``[B, dim]`` tensor."""
        # Integer time steps would make arange/log integer-typed (and log fails); promote.
        if not torch.is_floating_point(t):
            t = t.float()
        half_dim = self.dim // 2
        # Frequencies 10000^(-k / half_dim) for k = 0 .. half_dim - 1.
        freqs = torch.exp(
            torch.arange(half_dim, device=t.device, dtype=t.dtype)
            * -(torch.log(torch.tensor(10000.0, device=t.device, dtype=t.dtype)) / half_dim)
        )
        args = t[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        if self.dim % 2 == 1:
            # 2 * half_dim == dim - 1 here: pad so the Linear layer's input size matches.
            emb = nn.functional.pad(emb, (0, 1))
        return self.lin(emb)  # [B, dim]


class ConvBlock(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by an in-place ReLU.

    Spatial size is preserved; channels go ``in_ch -> out_ch -> out_ch``.
    The layers live in ``self.block`` (an ``nn.Sequential``), so checkpoint keys
    are ``block.0.*`` and ``block.2.*`` (indices 1 and 3 are parameter-free ReLUs).
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for c_in in (in_ch, out_ch):
            layers.append(nn.Conv2d(c_in, out_ch, 3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        out = self.block(x)
        return out


class UNetScoreCIFAR3Level(nn.Module):
    """
    Three-level convolutional U-Net used as the score network of the OU reverse SDE.

    ``model(x, t)`` maps an image batch ``x`` (``[B, 3, 32, 32]`` for CIFAR) and times
    ``t`` (``[B]``) to a tensor with the same shape as ``x``. The time embedding is
    broadcast over every pixel and concatenated to ``x`` as ``time_dim`` extra channels.
    Because of the three 2x poolings / 2x up-samplings, H and W must be divisible by 8.

    NOTE: attribute names (time_mlp, down1, ..., out_conv) are the state_dict keys of
    the pretrained checkpoint and must not be renamed.
    """

    def __init__(self, time_dim=32, base_channels=128, img_channels=3):
        super().__init__()
        w = base_channels

        self.time_mlp = TimeEmbedding(dim=time_dim)

        # Encoder: each stage widens channels; the following pool halves H and W.
        self.down1 = ConvBlock(img_channels + time_dim, w)   # 32x32
        self.pool1 = nn.MaxPool2d(2)                          # 32 -> 16
        self.down2 = ConvBlock(w, 2 * w)                      # 16x16
        self.pool2 = nn.MaxPool2d(2)                          # 16 -> 8
        self.down3 = ConvBlock(2 * w, 4 * w)                  # 8x8
        self.pool3 = nn.MaxPool2d(2)                          # 8 -> 4

        self.bottleneck = ConvBlock(4 * w, 8 * w)             # 4x4

        # Decoder: a stride-2 transposed conv doubles H and W; the dec block then
        # fuses it with the matching encoder skip (hence twice the input channels).
        self.up3 = nn.ConvTranspose2d(8 * w, 4 * w, kernel_size=2, stride=2)  # 4 -> 8
        self.dec3 = ConvBlock(8 * w, 4 * w)
        self.up2 = nn.ConvTranspose2d(4 * w, 2 * w, kernel_size=2, stride=2)  # 8 -> 16
        self.dec2 = ConvBlock(4 * w, 2 * w)
        self.up1 = nn.ConvTranspose2d(2 * w, w, kernel_size=2, stride=2)      # 16 -> 32
        self.dec1 = ConvBlock(2 * w, w)

        self.out_conv = nn.Conv2d(w, img_channels, kernel_size=3, padding=1)

    def forward(self, x, t):
        height, width = x.size(2), x.size(3)
        # [B, time_dim] -> [B, time_dim, H, W], stacked onto the image as extra channels.
        t_map = self.time_mlp(t)[:, :, None, None].expand(-1, -1, height, width)
        h = torch.cat([x, t_map], dim=1)

        encoder = ((self.down1, self.pool1), (self.down2, self.pool2), (self.down3, self.pool3))
        skips = []
        for block, pool in encoder:
            h = block(h)
            skips.append(h)  # pre-pool activation, reused by the decoder
            h = pool(h)

        h = self.bottleneck(h)

        decoder = ((self.up3, self.dec3), (self.up2, self.dec2), (self.up1, self.dec1))
        for upsample, fuse in decoder:
            # pop() yields the deepest skip first: up3 pairs with down3, ..., up1 with down1.
            h = fuse(torch.cat([upsample(h), skips.pop()], dim=1))

        return self.out_conv(h)


In [7]:
@torch.no_grad()
def sample_reverse_euler_maruyama_cifar(
    model,
    sde,
    num_steps=1000,
    batch_size=16,
    device="cuda",
    t_min=0.01,
    T=1.0,
    shape=(3, 32, 32),
):
    """Draw samples by integrating the reverse-time OU SDE with Euler-Maruyama.

    Steps ``x <- x + (-lambda * x - g^2 * score) * dt + g * sqrt(-dt) * noise``
    backwards on a uniform grid from ``t = T`` to ``t = t_min``, starting from
    ``N(0, I)``. N(0, I) is the OU stationary law; it is only close to the forward
    marginal at ``T`` when ``exp(-lambda * T)`` is small, so raise ``T`` for small lambda.

    Args:
        model: score network called as ``model(x, t_batch)``; it is put in eval mode.
        sde: forward process exposing ``lmbd`` and ``diffusion_coeff(t)``.
        num_steps: number of time-grid points (``num_steps - 1`` updates); must be >= 2.
        batch_size: number of samples to draw.
        device: device for the time grid, the state and the noise.
        t_min: final (smallest) time; must be smaller than ``T``.
        T: starting time of the reverse integration.
        shape: per-sample shape; defaults to CIFAR's ``(3, 32, 32)``.

    Returns:
        Tensor of shape ``[batch_size, *shape]``: the state at ``t = t_min``.

    Raises:
        ValueError: if ``num_steps < 2`` or ``t_min >= T``.
    """
    if num_steps < 2:
        raise ValueError(f"num_steps must be >= 2 to take any step, got {num_steps}")
    if not t_min < T:
        raise ValueError(f"t_min ({t_min}) must be smaller than T ({T})")

    model.eval()
    t_grid = torch.linspace(T, t_min, num_steps, device=device)
    # Shape that broadcasts a per-sample [B] coefficient against the [B, *shape] state.
    bcast = (batch_size,) + (1,) * len(shape)

    x = torch.randn(batch_size, *shape, device=device)

    for i in range(num_steps - 1):
        t_cur = t_grid[i]
        dt = t_grid[i + 1] - t_cur  # negative: integrating backwards in time

        t_batch = torch.full((batch_size,), t_cur, device=device)

        g = sde.diffusion_coeff(t_batch).view(bcast)  # [B] -> broadcastable
        score = model(x, t_batch)
        drift = -sde.lmbd * x - g ** 2 * score

        noise = torch.randn_like(x)
        x = x + drift * dt + g * torch.sqrt(-dt) * noise

    return x


In [8]:
def show_cifar_grid(x, nrow=4, title="Samples"):
    """Display a batch of images with values in [-1, 1] as a grid with ``nrow`` rows.

    Args:
        x: image batch shaped ``[B, C, H, W]`` with ``C`` = 3 (RGB) or 1 (grayscale);
            values are clamped to ``[-1, 1]`` and mapped to ``[0, 1]`` for display.
        nrow: number of grid rows; the column count is ``ceil(B / nrow)``.
        title: figure title.

    Raises:
        ValueError: if ``nrow < 1`` or the batch is empty.
    """
    if nrow < 1:
        raise ValueError(f"nrow must be >= 1, got {nrow}")
    imgs = (x.detach().cpu().clamp(-1, 1) + 1) / 2.0  # [-1, 1] -> [0, 1]
    n = imgs.size(0)
    if n == 0:
        raise ValueError("cannot show an empty batch")
    ncol = math.ceil(n / nrow)

    fig, axes = plt.subplots(nrow, ncol, figsize=(ncol * 2, nrow * 2), squeeze=False)
    # Row-major fill; trailing cells (when n < nrow * ncol) are left blank.
    for i, ax in enumerate(axes.flat):
        ax.axis("off")
        if i >= n:
            continue
        img = imgs[i].permute(1, 2, 0).numpy()  # CHW -> HWC
        if img.shape[-1] == 1:
            # imshow rejects (H, W, 1); draw single-channel images as grayscale.
            ax.imshow(img[..., 0], cmap="gray", vmin=0.0, vmax=1.0)
        else:
            ax.imshow(img)
    fig.suptitle(title)
    fig.tight_layout()
    plt.show()


In [None]:
# ---- match ckpt metadata / training hyperparams ----
time_dim = 32
base_channels = 128
# NOTE(review): this cell used to declare lambda_ = 0.3 but then built the SDE with
# Forward(lambda_=2), so that variable was never used. 2.0 keeps the behaviour that
# produced the current samples. TODO: confirm against the value used for training.
lambda_ = 2.0
t_min = 0.01
num_steps = 2000
batch_size = 16
seed = 0  # fixes the initial noise and every Euler-Maruyama noise draw

ema_weights_path = "cifar_ou_no_attention_ema_weights.pth"  # output of save_ema_only()

sde = Forward(lambda_=lambda_)

ema_model = UNetScoreCIFAR3Level(
    time_dim=time_dim,
    base_channels=base_channels,
    img_channels=3,
).to(device)

# The file was downloaded from the network and is a plain state_dict (loaded strictly
# below), so refuse arbitrary pickled objects.
ema_sd = torch.load(ema_weights_path, map_location=device, weights_only=True)
ema_model.load_state_dict(ema_sd, strict=True)
ema_model.eval()

torch.manual_seed(seed)
samples = sample_reverse_euler_maruyama_cifar(
    ema_model, sde, num_steps=num_steps, batch_size=batch_size, device=device, t_min=t_min
)
show_cifar_grid(samples, nrow=4, title="CIFAR-10 samples (EMA, no-attention OU)")
