In [3]:
import math, os, argparse, random, itertools, time
from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils

# ---------------------------
# Utilities
# ---------------------------

def set_seed(seed=42):
    random.seed(seed); torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

def exists(x): return x is not None

def default(val, d): return val if exists(val) else d

def count_params(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

def save_image_grid(tensor, path, nrow=8, normalize=True, value_range=(-1,1)):
    # tensor in [-1,1] -> [0,1]
    if normalize:
        lo, hi = value_range
        tensor = (tensor - lo) / (hi - lo)
        tensor = tensor.clamp(0, 1)
    vutils.save_image(tensor, path, nrow=nrow)

# ---------------------------
# Sinusoidal timestep embeddings (Transformer-style)
# ---------------------------

def timestep_embedding(timesteps, dim, max_period=10000):
    """
    timesteps: (batch,) with values in [0, T-1]
    returns: (batch, dim)
    """
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(0, half, dtype=torch.float32, device=timesteps.device) / half)
    args = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

# ---------------------------
# Building blocks: ResBlock, Attention, Down/Up
# ---------------------------

class SiLU(nn.Module):
    def forward(self, x): return x * torch.sigmoid(x)

class TimeMLP(nn.Module):
    def __init__(self, time_dim, out_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(time_dim, out_dim),
            SiLU(),
            nn.Linear(out_dim, out_dim),
        )
    def forward(self, t_emb): return self.net(t_emb)

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim, dropout=0.0):
        super().__init__()
        self.norm1 = nn.GroupNorm(32, in_ch)
        self.act1  = SiLU()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)

        self.time = TimeMLP(time_dim, out_ch)

        self.norm2 = nn.GroupNorm(32, out_ch)
        self.act2  = SiLU()
        self.dropout = nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)

        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb):
        h = self.conv1(self.act1(self.norm1(x)))
        h = h + self.time(t_emb)[:, :, None, None]
        h = self.conv2(self.dropout(self.act2(self.norm2(h))))
        return h + self.skip(x)

class SelfAttention2d(nn.Module):
    def __init__(self, channels, num_heads=4):
        super().__init__()
        self.num_heads = num_heads
        self.norm = nn.GroupNorm(32, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        h_ = self.norm(x)
        qkv = self.qkv(h_)
        q, k, v = qkv.chunk(3, dim=1)

        # reshape to (b, heads, dim, hw)
        heads = self.num_heads
        dim_head = c // heads
        q = q.view(b, heads, dim_head, h*w)
        k = k.view(b, heads, dim_head, h*w)
        v = v.view(b, heads, dim_head, h*w)

        attn = torch.einsum('bhdi,bhdj->bhij', q, k) / math.sqrt(dim_head)
        attn = attn.softmax(dim=-1)
        out = torch.einsum('bhij,bhdj->bhdi', attn, v)
        out = out.reshape(b, c, h, w)
        return x + self.proj(out)

class Downsample(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Conv2d(ch_in, ch_out, 3, stride=2, padding=1)
    def forward(self, x): return self.conv(x)

class Upsample(nn.Module):
    def __init__(self, ch_in, ch_out):
        super().__init__()
        self.conv = nn.Conv2d(ch_in, ch_out, 3, padding=1)
    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(x)

# ---------------------------
# U-Net backbone (PixelCNN++ style, with attention at 16x16)
# ---------------------------

class UNet(nn.Module):
    def __init__(
        self,
        in_channels=3,
        base_channels=128,
        channel_mults=(1, 2, 2, 2),  # resolutions: 32,16,8,4
        num_res_blocks=2,
        time_dim=512,
        dropout=0.1,
        attn_resolutions=(16,),  # add attention at 16x16 as in the paper
        img_size=32
    ):
        super().__init__()
        self.img_size = img_size
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, time_dim),
            SiLU(),
            nn.Linear(time_dim, time_dim),
        )

        self.time_in = nn.Linear(time_dim//2*2, time_dim)  # we pass sinusoidal, then project

        # input conv
        self.in_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)

        # Down path
        in_ch = base_channels
        self.downs = nn.ModuleList()
        self.downs_attn = nn.ModuleList()
        hs = [in_ch]
        curr_res = img_size
        for i, mult in enumerate(channel_mults):
            out_ch = base_channels * mult
            for _ in range(num_res_blocks):
                self.downs.append(ResBlock(in_ch, out_ch, time_dim, dropout))
                in_ch = out_ch
                hs.append(in_ch)
                # attention at chosen resolution
                self.downs_attn.append(SelfAttention2d(in_ch) if curr_res in attn_resolutions else nn.Identity())
            if i != len(channel_mults) - 1:
                self.downs.append(Downsample(in_ch, in_ch))
                self.downs_attn.append(nn.Identity())
                curr_res //= 2
                hs.append(in_ch)

        # Middle
        self.mid_block1 = ResBlock(in_ch, in_ch, time_dim, dropout)
        self.mid_attn = SelfAttention2d(in_ch)
        self.mid_block2 = ResBlock(in_ch, in_ch, time_dim, dropout)

        # Up path
        self.ups = nn.ModuleList()
        self.ups_attn = nn.ModuleList()
        for i, mult in reversed(list(enumerate(channel_mults))):
            out_ch = base_channels * mult
            for _ in range(num_res_blocks + 1):  # +1 to consume skip from either res or downsample
                skip_ch = hs.pop()
                self.ups.append(ResBlock(in_ch + skip_ch, out_ch, time_dim, dropout))
                in_ch = out_ch
                self.ups_attn.append(SelfAttention2d(in_ch) if curr_res in attn_resolutions else nn.Identity())
            if i != 0:
                self.ups.append(Upsample(in_ch, in_ch))
                self.ups_attn.append(nn.Identity())
                curr_res *= 2

        assert len(hs) == 0

        # Output
        self.out_norm = nn.GroupNorm(32, in_ch)
        self.out_act = SiLU()
        self.out_conv = nn.Conv2d(in_ch, in_channels, 3, padding=1)

        self.time_dim = time_dim

    def forward(self, x, t):
        # t: (b,) in [0, T-1]
        t_emb = timestep_embedding(t, self.time_dim)
        t_emb = self.time_mlp(self.time_in(t_emb))

        x = self.in_conv(x)
        feats = [x]

        it = iter(self.downs_attn)
        for mod in self.downs:
            x = mod(x, t_emb) if isinstance(mod, ResBlock) else mod(x)
            attn = next(it)
            x = attn(x)
            feats.append(x)

        x = self.mid_block1(x, t_emb)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t_emb)

        it = iter(self.ups_attn)
        for mod in self.ups:
            if isinstance(mod, ResBlock):
                skip = feats.pop()
                x = torch.cat([x, skip], dim=1)
                x = mod(x, t_emb)
            else:
                x = mod(x)
            attn = next(it)
            x = attn(x)

        x = self.out_conv(self.out_act(self.out_norm(x)))
        return x  # predicts epsilon

# ---------------------------
# Diffusion process (DDPM)
# ---------------------------

@dataclass
class DiffusionConfig:
    image_size: int = 32
    channels: int = 3
    timesteps: int = 1000
    beta_start: float = 1e-4
    beta_end: float = 2e-2
    variance_type: str = "fixed_small"  # "fixed_small" -> posterior (beta_tilde), "fixed_large" -> beta_t

class DDPM(nn.Module):
    def __init__(self, model: nn.Module, config: DiffusionConfig):
        super().__init__()
        self.model = model
        self.config = config

        T = config.timesteps
        betas = torch.linspace(config.beta_start, config.beta_end, T, dtype=torch.float32)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = torch.cat([torch.tensor([1.], dtype=torch.float32), alphas_cumprod[:-1]], dim=0)

        # register buffers
        for name, tensor in {
            "betas": betas,
            "alphas": alphas,
            "alphas_cumprod": alphas_cumprod,
            "alphas_cumprod_prev": alphas_cumprod_prev,
            "sqrt_alphas_cumprod": torch.sqrt(alphas_cumprod),
            "sqrt_one_minus_alphas_cumprod": torch.sqrt(1. - alphas_cumprod),
            "sqrt_recip_alphas": torch.sqrt(1.0 / alphas),
        }.items():
            self.register_buffer(name, tensor)

        # posterior variance beta_tilde
        posterior_variance = (betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)).clamp(min=1e-20)
        self.register_buffer("posterior_variance", posterior_variance)
        self.register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))

        # mean coefficient helpers for p_theta
        self.register_buffer("posterior_mean_coef1", (betas * torch.sqrt(alphas_cumprod_prev)) / (1. - alphas_cumprod))
        self.register_buffer("posterior_mean_coef2", ((1. - alphas_cumprod_prev) * torch.sqrt(alphas)) / (1. - alphas_cumprod))

    def q_sample(self, x0, t, noise=None):
        """
        q(x_t | x0) = N(sqrt(alpha_bar_t) x0, (1-alpha_bar_t) I)
        """
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_ab = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_omb = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)
        return sqrt_ab * x0 + sqrt_omb * noise

    def p_mean_variance(self, x_t, t):
        """
        Use epsilon-prediction parameterization to compute mean.
        Optionally choose variance type.
        """
        eps_theta = self.model(x_t, t)
        beta_t = self.betas[t].view(-1,1,1,1)
        sqrt_recip_alpha_t = self.sqrt_recip_alphas[t].view(-1,1,1,1)
        sqrt_one_minus_ab_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1,1,1,1)

        # Eq. 11-style mean: mu = 1/sqrt(alpha_t) * (x_t - beta_t / sqrt(1-ab_t) * eps_theta)
        # Note: beta_t == (1 - alpha_t)
        model_mean = sqrt_recip_alpha_t * (x_t - beta_t / sqrt_one_minus_ab_t * eps_theta)

        if self.config.variance_type == "fixed_small":
            # posterior variance (beta_tilde) as in many DDPM refs
            model_var = self.posterior_variance[t].view(-1,1,1,1)
            model_log_var = self.posterior_log_variance_clipped[t].view(-1,1,1,1)
        else:
            # "fixed_large": set variance to beta_t (upper bound choice)
            model_var = beta_t
            model_log_var = torch.log(beta_t)

        return model_mean, model_var, model_log_var, eps_theta

    @torch.no_grad()
    def p_sample(self, x_t, t):
        """
        One reverse step: x_{t-1} ~ N(mean, var).
        For t == 0, return mean (z = 0).
        """
        b = x_t.shape[0]
        model_mean, model_var, model_log_var, _ = self.p_mean_variance(x_t, t)
        if (t == 0).all():
            return model_mean
        z = torch.randn_like(x_t)
        return model_mean + (model_var ** 0.5) * z

    @torch.no_grad()
    def sample(self, batch_size, device):
        x = torch.randn(batch_size, 3, self.config.image_size, self.config.image_size, device=device)
        for i in reversed(range(self.config.timesteps)):
            t = torch.full((batch_size,), i, device=device, dtype=torch.long)
            x = self.p_sample(x, t)
        return x

    def training_losses(self, x0):
        """
        L_simple: E_{t, eps} || eps - eps_theta(x_t, t) ||^2
        """
        b = x0.size(0)
        t = torch.randint(0, self.config.timesteps, (b,), device=x0.device).long()
        noise = torch.randn_like(x0)
        x_t = self.q_sample(x0, t, noise)
        eps_theta = self.model(x_t, t)
        loss = F.mse_loss(eps_theta, noise, reduction='mean')
        return loss

# ---------------------------
# EMA (as in the paper)
# ---------------------------

class EMA:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = {k: v.clone().detach() for k, v in model.state_dict().items() if v.dtype.is_floating_point}
    def update(self, model):
        with torch.no_grad():
            for k, v in model.state_dict().items():
                if k in self.shadow and v.dtype.is_floating_point:
                    self.shadow[k].mul_(self.decay).add_(v, alpha=1 - self.decay)
    def copy_to(self, model):
        with torch.no_grad():
            sd = model.state_dict()
            for k in self.shadow:
                sd[k].copy_(self.shadow[k])

# ---------------------------
# Data
# ---------------------------

def get_cifar10_loader(data_dir, batch_size, num_workers=4):
    tfm = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x * 2.0 - 1.0),  # [0,1] -> [-1,1]
    ])
    train_set = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=tfm)
    return DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, drop_last=True)

# ---------------------------
# Training loop
# ---------------------------

def train(args):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    set_seed(args.seed)

    # Model
    unet = UNet(
        in_channels=3,
        base_channels=args.base_channels,
        channel_mults=(1,2,2,2),
        num_res_blocks=2,
        time_dim=512,
        dropout=0.1,               # 0.1 for CIFAR-10 (paper)
        attn_resolutions=(16,),    # attention at 16x16
        img_size=32
    ).to(device)

    diffusion = DDPM(unet, DiffusionConfig(
        image_size=32,
        channels=3,
        timesteps=1000,
        beta_start=1e-4,
        beta_end=2e-2,
        variance_type=args.variance_type
    )).to(device)

    print(f"Trainable params: {count_params(diffusion)/1e6:.2f}M")

    # Optimizer
    opt = torch.optim.Adam(diffusion.parameters(), lr=2e-4)  # as in paper
    scaler = torch.cuda.amp.GradScaler(enabled=args.amp)

    # EMA
    ema = EMA(diffusion, decay=0.9999)

    # Data
    loader = get_cifar10_loader(args.data, args.batch_size, args.workers)

    os.makedirs(args.out, exist_ok=True)

    global_step = 0
    diffusion.train()
    for epoch in range(args.epochs):
        for imgs, _ in loader:
            imgs = imgs.to(device)

            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=args.amp):
                loss = diffusion.training_losses(imgs)

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            ema.update(diffusion)

            if global_step % args.log_interval == 0:
                print(f"[epoch {epoch} step {global_step}] loss: {loss.item():.4f}")

            # Sampling preview with EMA params
            if global_step % args.sample_interval == 0 and global_step > 0:
                eval_copy = DDPM(UNet(
                    in_channels=3,
                    base_channels=args.base_channels,
                    channel_mults=(1,2,2,2),
                    num_res_blocks=2,
                    time_dim=512,
                    dropout=0.0,
                    attn_resolutions=(16,),
                    img_size=32
                ), diffusion.config).to(device)
                eval_copy.load_state_dict(diffusion.state_dict(), strict=True)
                ema.copy_to(eval_copy)  # swap in EMA weights

                eval_copy.eval()
                with torch.no_grad():
                    samples = eval_copy.sample(args.nsample, device)
                save_image_grid(samples, os.path.join(args.out, f"samples_step{global_step}.png"), nrow=int(math.sqrt(args.nsample)))
                del eval_copy
                torch.cuda.empty_cache()

            if global_step % args.ckpt_interval == 0 and global_step > 0:
                ckpt = {
                    "model": diffusion.state_dict(),
                    "opt": opt.state_dict(),
                    "step": global_step,
                    "args": vars(args)
                }
                torch.save(ckpt, os.path.join(args.out, f"ddpm_step{global_step}.pt"))

            global_step += 1
            if global_step >= args.max_steps:
                break
        if global_step >= args.max_steps:
            break

    # final save
    ckpt = {"model": diffusion.state_dict(), "opt": opt.state_dict(), "step": global_step, "args": vars(args)}
    torch.save(ckpt, os.path.join(args.out, f"ddpm_final.pt"))

def parse_args(argv=None):
    p = argparse.ArgumentParser()
    p.add_argument("--data", type=str, default="./data")
    p.add_argument("--out", type=str, default="./runs/ddpm_cifar10")
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--workers", type=int, default=4)
    p.add_argument("--epochs", type=int, default=1000)
    p.add_argument("--max_steps", type=int, default=800_000)
    p.add_argument("--log_interval", type=int, default=100)
    p.add_argument("--sample_interval", type=int, default=5000)
    p.add_argument("--ckpt_interval", type=int, default=50_000)
    p.add_argument("--nsample", type=int, default=64)
    p.add_argument("--base_channels", type=int, default=128)
    p.add_argument("--variance_type", type=str, choices=["fixed_small", "fixed_large"], default="fixed_small")
    p.add_argument("--amp", action="store_true")
    p.add_argument("--seed", type=int, default=42)
    return p.parse_known_args(argv)  # <-- ignore unknown args from Jupyter

if __name__ == "__main__":
    args, _ = parse_args()   # args now has .seed and all other attributes
    train(args)

def parse_args(argv=None):
    p = argparse.ArgumentParser()
    # ... (same add_argument calls)
    return p.parse_known_args(argv)  # <— note: parse_known_args

if __name__ == "__main__":
    args, _unknown = parse_args()    # <— ignore the Jupyter -f arg
    train(args)



Trainable params: 37.41M


  scaler = torch.cuda.amp.GradScaler(enabled=args.amp)
100%|██████████| 170M/170M [01:51<00:00, 1.53MB/s]
  with torch.cuda.amp.autocast(enabled=args.amp):


[epoch 0 step 0] loss: 1.1330


KeyboardInterrupt: 