In [1]:
# ddpm_full_finetune_toy.py
import math, os, random, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

In [2]:
# ----------------------------
# 0) Utilities
# ----------------------------
def set_seed(seed=0):
    """Seed Python's `random`, NumPy's global RNG and torch (CPU + all CUDA devices)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)

# Prefer the GPU whenever torch can see one.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
os.makedirs("samples", exist_ok=True)  # output folder for the saved sample grids


In [3]:
# ----------------------------
# 1) Toy datasets (no downloads)
# ----------------------------
def make_blob(h, w, radius, cx, cy):
    """Boolean (h, w) mask that is True on and inside the circle of `radius` around (cx, cy)."""
    row_idx = np.arange(h).reshape(h, 1)
    col_idx = np.arange(w).reshape(1, w)
    sq_dist = np.square(col_idx - cx) + np.square(row_idx - cy)
    return sq_dist <= radius ** 2

def make_square(h, w, side, cx, cy):
    """Return an (h, w) boolean mask containing a filled `side` x `side` square.

    The square's top-left corner is (cx - side//2, cy - side//2). Any part that
    falls outside the canvas is clipped, and a square lying entirely off-canvas
    yields an all-False mask.
    """
    x0, y0 = int(cx - side//2), int(cy - side//2)
    mask = np.zeros((h, w), dtype=bool)
    # Clamp both slice ends to [0, size]. An unclamped negative stop index would be
    # interpreted relative to the end of the axis and paint spurious pixels.
    r0, r1 = min(max(y0, 0), h), min(max(y0 + side, 0), h)
    c0, c1 = min(max(x0, 0), w), min(max(x0 + side, 0), w)
    mask[r0:r1, c0:c1] = True
    return mask

def render_toy(h=32, w=32, kind="blobs"):
    """Render one random single-channel toy image.

    Args:
        h, w: image height and width in pixels.
        kind: "blobs" (1-3 filled circles, radius 3-7), "squares" (1-2 filled
            squares, side 4-9) or "mixed" (a fair coin chooses between the two).

    Returns:
        float32 array of shape (h, w) with values clipped to [0, 1].

    Raises:
        ValueError: if `kind` is not "blobs", "squares" or "mixed".
    """
    # Reject unknown kinds instead of silently returning a noise-only image.
    if kind not in ("blobs", "squares", "mixed"):
        raise ValueError(f"unknown toy kind {kind!r}; expected 'blobs', 'squares' or 'mixed'")
    if kind == "mixed":
        return render_toy(h, w, "blobs" if np.random.rand() < 0.5 else "squares")
    img = np.zeros((h, w), dtype=np.float32)
    if kind == "blobs":
        # 1-3 circles with random radii/centres, each kept fully inside the canvas
        for _ in range(np.random.randint(1, 4)):
            r = np.random.randint(3, 8)
            cx, cy = np.random.randint(r, w-r), np.random.randint(r, h-r)
            img[make_blob(h, w, r, cx, cy)] = 1.0
    else:  # "squares"
        # 1-2 squares with random side lengths/centres
        for _ in range(np.random.randint(1, 3)):
            s = np.random.randint(4, 10)
            cx, cy = np.random.randint(s//2, w-s//2), np.random.randint(s//2, h-s//2)
            img[make_square(h, w, s, cx, cy)] = 1.0
    # Additive Gaussian pixel noise (std 0.1) over the whole image, then clip back to [0, 1].
    img = img + 0.1 * np.random.randn(h, w).astype(np.float32)
    return np.clip(img, 0, 1)

class ToyDataset(Dataset):
    """Synthetic dataset of single-channel toy images scaled to [-1, 1].

    `idx` is ignored: every access renders a fresh random image via `render_toy`,
    so the dataset has a nominal length `n` but no fixed contents.
    """

    def __init__(self, n=2000, kind="blobs"):
        self.n = n
        self.kind = kind

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        pixels = render_toy(kind=self.kind)           # (H, W), values in [0, 1]
        scaled = pixels * 2.0 - 1.0                   # rescale to [-1, 1]
        return torch.from_numpy(scaled).unsqueeze(0)  # (1, H, W)


In [8]:
# ----------------------------
# 2) Tiny UNet (2D, single-channel)
# ----------------------------
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> GroupNorm(8 groups) -> SiLU) stages; H and W are preserved."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second keeps out_ch; indices 0..5 in `net`.
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.GroupNorm(8, out_ch),
                nn.SiLU(),
            ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Down(nn.Module):
    """Downsampling stage: stride-2 3x3 conv (H, W -> ceil(H/2), ceil(W/2)), then DoubleConv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Learned downsampling; despite the attribute name this is a strided conv, not pooling.
        self.pool = nn.Conv2d(in_ch, in_ch, kernel_size=3, stride=2, padding=1)
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        return self.conv(self.pool(x))

class Up(nn.Module):
    """Upsampling stage: 2x nearest upsample, concat with the skip tensor, then DoubleConv.

    `in_ch` must equal the channel count of `x` plus that of `skip`.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x, skip):
        x = self.up(x)
        # A stride-2/padding-1 downsample rounds odd sizes up, so doubling can overshoot
        # the skip by one pixel. Crop to the skip's spatial size so inputs whose H/W are
        # not divisible by 4 still work; for matching sizes this is a no-op.
        h, w = skip.shape[-2:]
        x = x[..., :h, :w]
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

class TimeEmbedding(nn.Module):
    """Sinusoidal timestep embedding followed by an MLP (dim -> 4*dim -> dim).

    Args:
        dim: embedding width (any positive int; odd widths are zero-padded by one
            column before the MLP).
        max_period: sets the lowest sinusoid frequency; frequencies are
            max_period ** (-k / half) for k in [0, half), with half = dim // 2.
    """

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period  # previously accepted but ignored (10000 was hard-coded)
        self.lin1 = nn.Linear(dim, dim*4)
        self.lin2 = nn.Linear(dim*4, dim)

    def forward(self, t):
        """Embed a 1-D tensor of timesteps `t` into a (len(t), dim) tensor."""
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(self.max_period) * torch.arange(0, half, device=t.device).float() / half
        )
        args = t[:, None].float() * freqs[None]
        emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.dim % 2:
            # cos/sin give 2*half = dim-1 features for odd dim; pad so lin1's input width matches.
            emb = F.pad(emb, (0, 1))
        return self.lin2(F.silu(self.lin1(emb)))

class UNetSmall(nn.Module):
    """Small two-level UNet conditioned on the diffusion timestep.

    Encoder: inc -> down1 -> down2 -> bot. Decoder: up2 (skip x2) -> up1 (skip x1) -> 1x1 outc.
    The timestep embedding FiLM-modulates the two skip tensors and the bottleneck as
    x * (1 + scale) + shift, using a separate linear projection per stage so the
    channel counts match.

    Args:
        in_ch: image channels (also the number of output channels).
        base: channel width of the first stage; deeper stages use 2*base and 4*base.
            DoubleConv uses GroupNorm(8, ...), so `base` should be divisible by 8.
        time_dim: width of the timestep embedding.
    """

    def __init__(self, in_ch=1, base=32, time_dim=64):
        super().__init__()
        self.time_mlp = TimeEmbedding(time_dim)
        self.inc = DoubleConv(in_ch, base)
        self.down1 = Down(base, base*2)
        self.down2 = Down(base*2, base*4)
        self.bot  = DoubleConv(base*4, base*4)
        self.up2  = Up(base*4 + base*2, base*2)
        self.up1  = Up(base*2 + base, base)
        self.outc = nn.Conv2d(base, in_ch, 1)
        # Per-stage FiLM projections (scale s*, shift b*) whose widths match
        # x1 (base), x2 (2*base) and the bottleneck (4*base).
        self.s1, self.b1 = nn.Linear(time_dim, base),    nn.Linear(time_dim, base)
        self.s2, self.b2 = nn.Linear(time_dim, base*2),  nn.Linear(time_dim, base*2)
        self.s3, self.b3 = nn.Linear(time_dim, base*4),  nn.Linear(time_dim, base*4)

    def film(self, x, s, b):
        """Apply x * (1 + s) + b, broadcasting per-channel s, b of shape [B, C] over H and W."""
        return x * (1 + s[..., None, None]) + b[..., None, None]

    def apply_time(self, xs, t_emb):
        """FiLM-condition the stage features [x1, x2, xb] on the timestep embedding.

        Each tensor is paired with its own projection: (s1, b1) for x1, (s2, b2) for x2
        and (s3, b3) for the bottleneck. (The previous version referenced removed
        `t_to_scale`/`t_to_shift` layers and always raised AttributeError.)

        Raises:
            ValueError: if `xs` does not contain exactly three feature maps.
        """
        heads = ((self.s1, self.b1), (self.s2, self.b2), (self.s3, self.b3))
        if len(xs) != len(heads):
            raise ValueError(f"apply_time expects {len(heads)} feature maps, got {len(xs)}")
        return [self.film(x, s(t_emb), b(t_emb)) for x, (s, b) in zip(xs, heads)]

    def forward(self, x, t):
        """Return an `in_ch`-channel output at the resolution of `x` for 1-D timesteps `t`."""
        t_emb = self.time_mlp(t)
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        xb = self.bot(x3)
        # Conditioning is applied after the encoder pass: the encoder sees unconditioned
        # features, and only the skips and bottleneck fed to the decoder are modulated.
        x1, x2, xb = self.apply_time([x1, x2, xb], t_emb)
        u2 = self.up2(xb, x2)
        u1 = self.up1(u2, x1)
        return self.outc(u1)


In [9]:
# ----------------------------
# 3) DDPM schedule & helpers
# ----------------------------
class DDPM:
    """DDPM with a linear beta schedule: forward noising and ancestral sampling.

    Args:
        timesteps: number of diffusion steps T.
        beta_start, beta_end: endpoints of the linear beta schedule.
        device: device for the schedule tensors. ``None`` (default) uses the
            module-level ``DEVICE`` looked up at construction time, so re-running a
            config cell that changes ``DEVICE`` takes effect without redefining this class.
    """

    def __init__(self, timesteps=200, beta_start=1e-4, beta_end=0.02, device=None):
        if device is None:
            device = DEVICE
        self.T = timesteps
        beta = torch.linspace(beta_start, beta_end, self.T, device=device)
        alpha = 1.0 - beta
        self.register(alpha, beta)

    def register(self, alpha, beta):
        """Store the schedule and precompute per-step coefficients (shape [T] unless noted)."""
        self.alpha = alpha
        self.beta = beta
        self.alphabar = torch.cumprod(alpha, dim=0)
        self.sqrt_ab = torch.sqrt(self.alphabar)
        self.sqrt_1mab = torch.sqrt(torch.clamp(1.0 - self.alphabar, 1e-20))
        self.one_over_sqrt_a = torch.sqrt(1.0 / self.alpha)
        # Posterior variance beta_t * (1 - abar_{t-1}) / (1 - abar_t). Length T-1:
        # entry i belongs to timestep t = i + 1 (only defined for t >= 1).
        self.posterior_var = (
            self.beta[1:] * (1.0 - self.alphabar[:-1])
            / torch.clamp(1.0 - self.alphabar[1:], 1e-20)
        )

    def q_sample(self, x0, t, noise=None):
        """Return sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise for per-sample steps `t`.

        `x0` is expected to be 4-D ([B, C, H, W]); `noise` defaults to a fresh N(0, I) draw.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_ab = self.sqrt_ab[t].view(-1,1,1,1)
        sqrt_1mab = self.sqrt_1mab[t].view(-1,1,1,1)
        return sqrt_ab * x0 + sqrt_1mab * noise

    @torch.no_grad()
    def p_sample(self, model, x_t, t):
        """One reverse step x_t -> x_{t-1} with an epsilon-predicting `model(x_t, t)`.

        Batch entries with t == 0 return the predicted clean image x0_hat (no noise).
        Other entries get the posterior mean plus sqrt(posterior variance) * N(0, I).
        Mixed batches are handled per sample.
        """
        eps = model(x_t, t)
        a_t = self.alpha[t].view(-1,1,1,1)
        ab_t = self.alphabar[t].view(-1,1,1,1)
        one_over_sqrt_a = self.one_over_sqrt_a[t].view(-1,1,1,1)
        sqrt_1mab_t = torch.sqrt(torch.clamp(1 - ab_t, 1e-20))
        x0_hat = (x_t - sqrt_1mab_t * eps) / torch.sqrt(torch.clamp(ab_t, 1e-20))
        is_last = t == 0
        if is_last.all():
            # Final step for the whole batch: return the clean prediction, draw no noise.
            return x0_hat
        mean = one_over_sqrt_a * (x_t - (1 - a_t) / sqrt_1mab_t * eps)
        # posterior_var is indexed by t-1; the clamp only guards t == 0 entries,
        # whose noisy result is discarded by the torch.where below.
        var = self.posterior_var[(t-1).clamp(min=0)].view(-1,1,1,1)
        noise = torch.randn_like(x_t)
        x_prev = mean + torch.sqrt(torch.clamp(var, 1e-20)) * noise
        return torch.where(is_last.view(-1,1,1,1), x0_hat, x_prev)

    @torch.no_grad()
    def sample(self, model, n=16, shape=(1,32,32)):
        """Run the full reverse chain from pure noise and return n samples clamped to [-1, 1].

        Sampling happens on the schedule's device. The model is put in eval mode for
        the duration, and its previous train/eval mode is restored afterwards.
        """
        device = self.beta.device
        was_training = model.training
        model.eval()
        try:
            x = torch.randn(n, *shape, device=device)
            for ti in reversed(range(self.T)):
                t = torch.full((n,), ti, device=device, dtype=torch.long)
                x = self.p_sample(model, x, t)
        finally:
            model.train(was_training)
        return x.clamp(-1,1)

In [10]:
# ----------------------------
# 4) Training / Fine-tuning loops
# ----------------------------
def train_epoch(model, ddpm, loader, opt):
    """Run one epoch of epsilon-prediction DDPM training.

    For each batch x0: sample t ~ U{0, ..., T-1}, noise x0 with `ddpm.q_sample`, and
    regress the model output onto the injected noise with MSE.

    Args:
        model: network called as model(x_t, t); batches are moved to its device.
        ddpm: object exposing `T` and `q_sample(x0, t, noise)`.
        loader: iterable of x0 batches (tensors whose first dim is the batch size).
        opt: optimizer over the model's parameters.

    Returns:
        Mean per-sample loss over the samples actually processed. This stays correct
        with `drop_last=True`, unlike dividing by len(loader.dataset).

    Raises:
        ValueError: if `loader` yields no samples.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    model.train()
    total, seen = 0.0, 0
    for x0 in loader:
        x0 = x0.to(device)
        b = x0.size(0)
        t = torch.randint(0, ddpm.T, (b,), device=device, dtype=torch.long)
        noise = torch.randn_like(x0)
        x_t = ddpm.q_sample(x0, t, noise)
        loss = F.mse_loss(model(x_t, t), noise)
        opt.zero_grad()
        loss.backward()
        opt.step()
        total += loss.item() * b
        seen += b
    if seen == 0:
        raise ValueError("train_epoch: loader produced no samples")
    return total / seen

def save_grid(tensor, path, nrow=4):
    """Tile a batch of images into one grayscale grid and save it with matplotlib.

    Args:
        tensor: [N, C, H, W] tensor with values in [-1, 1]; only channel 0 is drawn.
            It may live on any device and may require grad (it is detached).
        path: output file path passed to `Figure.savefig`.
        nrow: images per row; unused cells stay white. An empty batch produces a
            single blank row instead of crashing.
    """
    imgs = tensor.detach().float().cpu().add(1).mul(0.5).numpy()  # [-1, 1] -> [0, 1]
    n, _, h, w = imgs.shape
    rows = max(1, -(-n // nrow))  # ceil(n / nrow), at least one row
    canvas = np.ones((rows * h, nrow * w), dtype=np.float32)
    for idx in range(n):
        r, c = divmod(idx, nrow)
        canvas[r*h:(r+1)*h, c*w:(c+1)*w] = imgs[idx, 0]
    fig, ax = plt.subplots(figsize=(nrow, rows))
    try:
        ax.axis("off")
        ax.imshow(canvas, vmin=0, vmax=1, cmap="gray")
        fig.tight_layout()
        fig.savefig(path, dpi=150)
    finally:
        plt.close(fig)  # release the figure even if saving fails


In [11]:
def main():
    """Pretrain the toy DDPM on blobs and save samples, then fully fine-tune on squares and save again."""
    set_seed(3)
    # Domain A (blobs) is used for pretraining; domain B (squares) is the
    # out-of-distribution fine-tuning target.
    trainA = ToyDataset(n=2000, kind="blobs")
    trainB = ToyDataset(n=2000, kind="squares")
    loaderA, loaderB = (
        DataLoader(ds, batch_size=64, shuffle=True, num_workers=0) for ds in (trainA, trainB)
    )

    model = UNetSmall(in_ch=1, base=32, time_dim=64).to(DEVICE)
    ddpm = DDPM(timesteps=200, beta_start=1e-4, beta_end=2e-2, device=DEVICE)

    def run_stage(tag, loader, lr, out_path):
        # A fresh AdamW per stage, so no optimizer state carries over. Every parameter
        # is handed to it (full fine-tuning: nothing is frozen).
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        for epoch in range(1, 6):
            loss = train_epoch(model, ddpm, loader, optimizer)
            print(f"[{tag}] epoch {epoch:02d} loss={loss:.4f}")
        with torch.no_grad():
            samples = ddpm.sample(model, n=16, shape=(1, 32, 32))
        save_grid(samples, out_path)
        print("saved:", out_path)

    run_stage("pretrain", loaderA, 2e-4, "samples/pretrained_on_blobs.png")
    run_stage("finetune", loaderB, 1e-4, "samples/finetuned_on_squares.png")

if __name__ == "__main__":
    main()

[pretrain] epoch 01 loss=0.4832
[pretrain] epoch 02 loss=0.2481
[pretrain] epoch 03 loss=0.2010
[pretrain] epoch 04 loss=0.1857
[pretrain] epoch 05 loss=0.1676
saved: samples/pretrained_on_blobs.png
[finetune] epoch 01 loss=0.1564
[finetune] epoch 02 loss=0.1373
[finetune] epoch 03 loss=0.1381
[finetune] epoch 04 loss=0.1311
[finetune] epoch 05 loss=0.1285
saved: samples/finetuned_on_squares.png
