Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from torchvision.utils import save_image
from copy import deepcopy
import os
import math

config

In [None]:
class Config:
    # Central hyper-parameter container; a single instance is created below
    # and read as the module-level name `config`.

    # --- Data ---
    # Root directory handed to datasets.FashionMNIST (downloaded if missing).
    DATA_DIR = "./data"
    # NOTE(review): IMAGE_SIZE is not referenced anywhere in the visible code.
    IMAGE_SIZE = 28
    # Channel count of the encoder input / decoder output.
    IMAGE_CHANNELS = 1
    # Size of the class-embedding table and number of classes sampled.
    NUM_CLASSES = 10
    # DataLoader batch size and worker count.
    BATCH_SIZE = 128
    NUM_WORKERS = 2

    # --- Diffusion training ---
    # Default number of diffusion steps for ConditionalDenoiseDiffusion.
    N_STEPS = 1000
    # AdamW learning rate for the latent noise-prediction model.
    LEARNING_RATE = 1e-4
    # Epochs of latent diffusion training.
    N_EPOCHS = 100
    # Checkpoint (and sample) every this many epochs, plus on the last epoch.
    SAVE_EVERY_N_EPOCHS = 10
    # Decay factor of the exponential moving average of model parameters.
    EMA_DECAY = 0.999
    # Weight of the KL term in the VAE loss.
    KL_WEIGHT = 1e-4

    # --- VAE ---
    # Latent tensor layout used when sampling: (LATENT_CHANNELS, LATENT_SIZE, LATENT_SIZE).
    LATENT_CHANNELS = 4
    LATENT_SIZE = 7
    # Epochs of VAE pre-training and where its weights are stored.
    VAE_EPOCHS = 35
    VAE_PATH = "fashion_vae.pt"

    # Checkpoint paths for the raw and EMA latent models.
    LATENT_MODEL_PATH = "latent_model.pt"
    LATENT_EMA_PATH = "latent_model.ema.pt"

    # --- Sampling ---
    # Total number of images generated by sample_latent_images.
    N_SAMPLES = 80
    # NOTE(review): GUIDANCE_WEIGHT is not referenced anywhere in the visible code.
    GUIDANCE_WEIGHT = 0.0

    # --- Noise-prediction network ---
    # Width of the first conv; each stage uses BASE_CHANNELS * multiplier.
    BASE_CHANNELS = 64
    CHANNEL_MULT = [1, 2, 2]
    # Residual blocks per stage.
    NUM_RES_BLOCKS = 2
    # Width of the sinusoidal timestep embedding and of the time MLP.
    TIME_EMBED_DIM = BASE_CHANNELS*4
    # Width of the class embedding used as cross-attention context.
    CONTEXT_DIM = TIME_EMBED_DIM

# Shared instance read by every other cell.
config = Config()

Data Loader

In [None]:
def get_dataloader(train: bool = True):
    """Build a Fashion-MNIST DataLoader.

    Images are converted to tensors and normalised with mean 0.5 / std 0.5,
    which maps ToTensor's [0, 1] range onto [-1, 1]. The dataset is downloaded
    into ``config.DATA_DIR`` when absent.

    Args:
        train: Selects the training split when True, the test split otherwise.
            Also controls shuffling: only the training split is shuffled.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches of
        ``config.BATCH_SIZE`` items.
    """
    to_unit_range = transforms.ToTensor()
    to_signed_range = transforms.Normalize((0.5,), (0.5,))
    preprocessing = transforms.Compose([to_unit_range, to_signed_range])

    fashion = datasets.FashionMNIST(
        root=config.DATA_DIR,
        train=train,
        download=True,
        transform=preprocessing,
    )

    loader_kwargs = dict(
        batch_size=config.BATCH_SIZE,
        shuffle=train,
        num_workers=config.NUM_WORKERS,
        pin_memory=True,
    )
    return DataLoader(fashion, **loader_kwargs)

In [None]:
def get_timestep_embedding(timesteps, dim):
    """Return sinusoidal embeddings for a batch of diffusion timesteps.

    The first ``dim // 2`` columns hold ``sin(t * w_k)`` and the next
    ``dim // 2`` hold ``cos(t * w_k)``, with ``w_k = 10000 ** (-k / (dim // 2))``.

    Args:
        timesteps: 1-D tensor of timesteps (any numeric dtype).
        dim: Width of the returned embedding.

    Returns:
        A float32 tensor of shape ``(len(timesteps), dim)`` on the same device
        as ``timesteps``.
    """
    half_dim = dim // 2
    exponent = torch.exp(
        -math.log(10000) * torch.arange(half_dim, device=timesteps.device) / half_dim
    )
    timesteps = timesteps.float().unsqueeze(-1)
    emb = timesteps * exponent.unsqueeze(0)
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    if dim % 2 == 1:
        # sin/cos pairs only fill 2 * (dim // 2) columns; zero-pad the last one
        # so the output width always equals `dim`, as downstream layers expect.
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb.to(dtype=torch.float32)

def save_checkpoint(model, path):
    """Write ``model.state_dict()`` to ``path``, creating parent directories.

    Args:
        model: Module whose parameters and buffers are saved.
        path: Destination file. May be a bare filename (saved in the current
            working directory) or include directories that do not exist yet.
    """
    parent = os.path.dirname(path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one.
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save(model.state_dict(), path)

def load_checkpoint(model, path, device="cpu"):
    """Restore ``model`` in place from a state dict stored at ``path``.

    Args:
        model: Module to receive the weights.
        path: Checkpoint file to read.
        device: ``map_location`` used when deserialising the tensors.

    Returns:
        True when the file existed and was loaded, False when it is missing
        (in which case ``model`` is left untouched).
    """
    if not os.path.exists(path):
        return False

    state = torch.load(path, map_location=device)
    model.load_state_dict(state)
    print(f"Checkpoint loaded from {path}")
    return True

VAE

In [None]:
class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel gate.

    Each channel is globally average-pooled, passed through a two-layer MLP
    with a ``channels // reduction`` bottleneck, and the resulting sigmoid
    weight rescales that channel of the input.
    """

    def __init__(self, channels, reduction=8):
        super().__init__()
        bottleneck = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, bottleneck),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck, channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Scale every channel of a 4-D input by its learned gate in (0, 1)."""
        batch, channels, _, _ = x.shape
        pooled = self.avg_pool(x).flatten(1)  # (batch, channels)
        gate = self.fc(pooled)
        return x * gate.view(batch, channels, 1, 1)

class Encoder(nn.Module):
    """Convolutional VAE encoder producing per-pixel latent mean and log-variance.

    Two stride-2 convolutions (kernel 4, padding 1) each halve the spatial
    size, so a 28x28 input yields 7x7 feature maps with 128 channels, from
    which two 1x1 convolutions predict ``mu`` and ``logvar``.
    """

    def __init__(self, latent_channels=config.LATENT_CHANNELS):
        super().__init__()

        # Full-resolution stem.
        stem = [
            nn.Conv2d(config.IMAGE_CHANNELS, 32, kernel_size=3, padding=1),
            nn.GroupNorm(8, 32),
            nn.SiLU(),
        ]
        # First halving stage, followed by channel gating.
        stage_one = [
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            ChannelAttention(64),
        ]
        # Second halving stage, followed by channel gating.
        stage_two = [
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(8, 128),
            nn.SiLU(),
            ChannelAttention(128),
        ]
        # Layers are concatenated in the same order so Sequential indices
        # (and therefore checkpoint keys) are unchanged.
        self.net = nn.Sequential(*stem, *stage_one, *stage_two)

        self.to_mu = nn.Conv2d(128, latent_channels, kernel_size=1)
        self.to_logvar = nn.Conv2d(128, latent_channels, kernel_size=1)

    def forward(self, x):
        """Return ``(mu, logvar)`` for the image batch ``x``."""
        features = self.net(x)
        mu = self.to_mu(features)
        logvar = self.to_logvar(features)
        return mu, logvar

def reparameterize(mu, logvar):
    """Draw ``mu + sigma * eps`` with ``eps ~ N(0, I)`` and ``sigma = exp(logvar / 2)``.

    Sampling is expressed as a deterministic function of ``mu`` and ``logvar``
    plus external noise, so gradients flow back to both inputs.
    """
    sigma = (0.5 * logvar).exp()
    noise = torch.randn_like(sigma)
    return mu + noise * sigma

class Decoder(nn.Module):
    """Convolutional VAE decoder mapping latents back to images in [-1, 1].

    Two stride-2 transposed convolutions (kernel 4, padding 1) each double the
    spatial size, so 7x7 latents become 28x28 images. The final ``Tanh``
    bounds the output to the same range as the normalised training data.
    """

    def __init__(self, latent_channels=config.LATENT_CHANNELS):
        super().__init__()

        # Lift the latent to 128 feature channels at latent resolution.
        lift = [
            nn.Conv2d(latent_channels, 128, kernel_size=3, padding=1),
            nn.GroupNorm(8, 128),
            nn.SiLU(),
        ]
        # First doubling stage, followed by channel gating.
        stage_one = [
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(8, 64),
            nn.SiLU(),
            ChannelAttention(64),
        ]
        # Second doubling stage.
        stage_two = [
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.GroupNorm(8, 32),
            nn.SiLU(),
        ]
        # Project to image channels and squash to [-1, 1].
        head = [
            nn.Conv2d(32, config.IMAGE_CHANNELS, kernel_size=3, padding=1),
            nn.Tanh(),
        ]
        # Same layer order as before, so Sequential indices are unchanged.
        self.net = nn.Sequential(*lift, *stage_one, *stage_two, *head)

    def forward(self, z):
        """Decode the latent batch ``z`` into images."""
        image = self.net(z)
        return image

class VAE(nn.Module):
    """Variational auto-encoder pairing ``Encoder`` and ``Decoder``."""

    def __init__(self, latent_channels=config.LATENT_CHANNELS):
        super().__init__()
        self.encoder = Encoder(latent_channels)
        self.decoder = Decoder(latent_channels)

    def encode(self, x):
        """Return the encoder's ``(mu, logvar)`` for images ``x``."""
        return self.encoder(x)

    def decode(self, z):
        """Return the decoder's reconstruction for latents ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Encode, sample a latent, and decode.

        Returns:
            ``(recon, mu, logvar, z)`` where ``z`` is the sampled latent that
            produced ``recon``.
        """
        mu, logvar = self.encoder(x)
        latent = reparameterize(mu, logvar)
        reconstruction = self.decoder(latent)
        return reconstruction, mu, logvar, latent

UNet

In [None]:
class CrossAttention(nn.Module):
    """Residual multi-head attention from a token sequence to a context sequence.

    Queries come from ``x``; keys and values come from ``context``, which is
    first projected from ``context_dim`` to ``query_dim``.
    """

    def __init__(self, query_dim, context_dim, num_heads=4):
        super().__init__()
        self.q_proj = nn.Linear(query_dim, query_dim)
        self.k_proj = nn.Linear(context_dim, query_dim)
        self.v_proj = nn.Linear(context_dim, query_dim)
        self.attn = nn.MultiheadAttention(query_dim, num_heads, batch_first=True)
        self.proj_out = nn.Linear(query_dim, query_dim)

    def forward(self, x, context):
        """Return ``x`` plus the projected attention output (same shape as ``x``)."""
        queries = self.q_proj(x)
        keys = self.k_proj(context)
        values = self.v_proj(context)
        # MultiheadAttention returns (output, weights); only the output is used.
        attended = self.attn(queries, keys, values)[0]
        return x + self.proj_out(attended)

class LatentResBlock(nn.Module):
    """Residual conv block conditioned on a time embedding and a context sequence.

    Pipeline: group-norm -> SiLU -> conv1, add the projected time embedding,
    add cross-attention over the spatial positions, then group-norm -> SiLU ->
    conv2, and finally add the (optionally 1x1-projected) input.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim, context_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_ch))
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        # A 1x1 conv aligns channels for the residual sum only when they differ.
        self.shortcut = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

        self.norm_ca = nn.GroupNorm(8, out_ch)
        self.ca = CrossAttention(out_ch, context_dim)

    def forward(self, x, t_emb, context):
        """Apply the block to feature map ``x`` of shape (B, C_in, H, W)."""
        batch, _, height, width = x.shape
        channels = self.conv1.out_channels

        # Parameter-free group norm; the group count is capped by the channel count.
        hidden = F.group_norm(x, min(8, x.shape[1]))
        hidden = self.conv1(F.silu(hidden))

        # Broadcast the per-sample time projection over all spatial positions.
        time_bias = self.time_mlp(t_emb)
        hidden = hidden + time_bias[:, :, None, None]

        # Flatten the spatial grid into a token sequence (B, H*W, C) for attention.
        tokens = F.silu(self.norm_ca(hidden))
        n_tokens = int(height) * int(width)
        tokens = tokens.permute(0, 2, 3, 1).reshape(batch, n_tokens, channels)
        attended = self.ca(tokens, context)
        # Back to (B, C, H, W) before adding to the feature map.
        attended = attended.transpose(1, 2).reshape(batch, channels, height, width)
        hidden = hidden + attended

        hidden = F.group_norm(hidden, min(8, hidden.shape[1]))
        hidden = self.conv2(F.silu(hidden))

        return hidden + self.shortcut(x)

class LatentEpsModel(nn.Module):
    """Class-conditional noise-prediction network operating on VAE latents.

    The layout is U-Net-like: a stack of "down" stages, two bottleneck blocks,
    and a mirrored stack of "up" stages that concatenate skip features. All
    sizes come from the module-level ``config``. The output has
    ``config.LATENT_CHANNELS`` channels, the same as the input.
    """

    def __init__(self):
        super().__init__()

        time_dim = config.TIME_EMBED_DIM
        context_dim = config.CONTEXT_DIM

        # MLP applied on top of the sinusoidal timestep embedding.
        self.time_mlp = nn.Sequential(nn.Linear(time_dim, time_dim), nn.SiLU(), nn.Linear(time_dim, time_dim))
        # One learned context vector per class label.
        self.class_emb = nn.Embedding(config.NUM_CLASSES, context_dim)
        self.conv_in = nn.Conv2d(config.LATENT_CHANNELS, config.BASE_CHANNELS, 3, padding=1)

        self.downs = nn.ModuleList()
        # `ch` tracks the channel count flowing between stages.
        ch = config.BASE_CHANNELS
        for mult in config.CHANNEL_MULT:
            out = config.BASE_CHANNELS * mult
            blocks = nn.ModuleList()
            # Only the first block of a stage changes the channel count.
            blocks.append(LatentResBlock(ch, out, time_dim, context_dim))
            for _ in range(config.NUM_RES_BLOCKS - 1):
                blocks.append(LatentResBlock(out, out, time_dim, context_dim))

            # NOTE: kernel 3, stride 1, padding 1 keeps the spatial size, so
            # despite its name this layer does not reduce resolution.
            downsample = nn.Conv2d(out, out, 3, 1, 1)
            self.downs.append(nn.ModuleDict({"blocks": blocks, "down": downsample}))
            ch = out

        # Bottleneck at the width of the last down stage.
        self.bot1 = LatentResBlock(ch, ch, time_dim, context_dim)
        self.bot2 = LatentResBlock(ch, ch, time_dim, context_dim)

        self.ups = nn.ModuleList()
        for mult in reversed(config.CHANNEL_MULT):
            skip_ch = config.BASE_CHANNELS * mult
            out = skip_ch
            # The first block sees the "up" output concatenated with the skip.
            in_ch_res = out + skip_ch

            blocks = nn.ModuleList()
            blocks.append(LatentResBlock(in_ch_res, out, time_dim, context_dim))
            for _ in range(config.NUM_RES_BLOCKS - 1):
                blocks.append(LatentResBlock(out, out, time_dim, context_dim))

            # NOTE: kernel 3, stride 1, padding 1 keeps the spatial size; this
            # layer only maps `ch` channels to `out` channels.
            upsample = nn.ConvTranspose2d(ch, out, 3, 1, 1)
            self.ups.append(nn.ModuleDict({"blocks": blocks, "up": upsample}))
            ch = out

        self.conv_out = nn.Conv2d(ch, config.LATENT_CHANNELS, 3, padding=1)

    def forward(self, x, t, y):
        """Predict the noise in latent ``x`` at timestep ``t`` for class ``y``.

        Args:
            x: Noisy latent batch fed to ``conv_in``.
            t: Timesteps passed to ``get_timestep_embedding``.
            y: Class indices looked up in ``class_emb``.
                # assumes y is a 1-D (batch,) long tensor — TODO confirm

        Returns:
            Tensor produced by ``conv_out`` with ``config.LATENT_CHANNELS`` channels.
        """
        t_emb = self.time_mlp(get_timestep_embedding(t, config.TIME_EMBED_DIM))
        # Insert a length-1 sequence axis so the class vector can act as
        # cross-attention context.
        context = self.class_emb(y).unsqueeze(1)

        # Stack of skip features, one per down stage.
        hs = []
        h = self.conv_in(x)
        for module in self.downs:
            for block in module["blocks"]:
                h = block(h, t_emb, context)
            # The skip is taken before the stage's "down" conv.
            hs.append(h)
            h = module["down"](h)

        h = self.bot1(h, t_emb, context)
        h = self.bot2(h, t_emb, context)

        for module in self.ups:
            # Skips are consumed in reverse order of creation.
            skip = hs.pop()
            h = module["up"](h)

            h = torch.cat([h, skip], dim=1)
            for block in module["blocks"]:
                h = block(h, t_emb, context)

        return self.conv_out(h)

Scheduler

In [None]:
class ConditionalDenoiseDiffusion():
    """DDPM forward/reverse process around a conditional noise-prediction model.

    Uses a linear beta schedule from 1e-4 to 0.02 over ``n_steps`` steps and
    precomputes the cumulative products needed for ``q_sample`` and for the
    posterior mean/variance used in ``p_sample``.
    """

    def __init__(self, eps_model, n_steps=config.N_STEPS, device=None):
        super().__init__()
        self.eps_model = eps_model
        self.n_steps = n_steps
        self.device = torch.device("cpu") if device is None else device

        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(self.device)
        self.alpha = 1 - self.beta
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.sqrt_alpha_bar = torch.sqrt(self.alpha_bar)
        self.sqrt_one_minus_alpha_bar = torch.sqrt(1 - self.alpha_bar)

        # alpha_bar shifted right by one step, with 1.0 for the step before t=0.
        self.alphas_cumprod_prev = F.pad(self.alpha_bar[:-1], (1, 0), value=1.0)
        self.post_variance = self.beta * (1. - self.alphas_cumprod_prev) / (1. - self.alpha_bar)

    @staticmethod
    def _at(table, t):
        """Gather ``table[t]`` and shape it to broadcast over (B, C, H, W)."""
        return table[t].reshape(-1, 1, 1, 1)

    def q_sample(self, x0, t, eps=None):
        """Sample ``x_t`` from ``x0`` by mixing in Gaussian noise for step ``t``.

        Fresh noise is drawn when ``eps`` is None.
        """
        noise = torch.randn_like(x0) if eps is None else eps
        signal_scale = self._at(self.sqrt_alpha_bar, t)
        noise_scale = self._at(self.sqrt_one_minus_alpha_bar, t)
        return signal_scale * x0 + noise_scale * noise

    def p_sample(self, xt, t, c=None):
        """Take one reverse step from ``x_t`` to ``x_{t-1}``.

        Args:
            xt: Current noisy batch.
            t: Timestep, either a tensor with one entry per sample or a Python
                number applied to the whole batch.
            c: Conditioning passed straight through to ``eps_model``.

        Returns:
            The posterior mean plus noise, or the mean alone when the first
            timestep in the batch is 0.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.tensor([t] * xt.shape[0], device=xt.device, dtype=torch.long)

        predicted_noise = self.eps_model(xt, t, c)

        beta_t = self._at(self.beta, t)
        alpha_t = self._at(self.alpha, t)
        alpha_bar_t = self._at(self.alpha_bar, t)
        alpha_bar_prev = self._at(self.alphas_cumprod_prev, t)

        # Invert q_sample to estimate the clean sample.
        x0_pred = (xt - self._at(self.sqrt_one_minus_alpha_bar, t) * predicted_noise) / self._at(self.sqrt_alpha_bar, t)
        # NOTE(review): this clamps the estimate to [-1, 1]; when sampling VAE
        # latents, confirm that the latent range actually fits in that interval.
        x0_pred = torch.clamp(x0_pred, -1., 1.)

        # Posterior mean as a blend of the x0 estimate and the current sample.
        coef_x0 = alpha_bar_prev.sqrt() * beta_t / (1. - alpha_bar_t)
        coef_xt = alpha_t.sqrt() * (1. - alpha_bar_prev) / (1. - alpha_bar_t)
        mean = coef_x0 * x0_pred + coef_xt * xt

        # Only the first entry of t decides whether noise is added.
        if not (t[0] > 0):
            return mean

        variance = self._at(self.post_variance, t)
        noise = torch.randn_like(xt)
        return mean + torch.sqrt(variance) * noise

    def sample(self, shape, device, c=None):
        """Run the full reverse chain from pure noise and return the final sample."""
        current = torch.randn(shape, device=device)
        for step in tqdm(reversed(range(self.n_steps)), desc="Sampling"):
            current = self.p_sample(current, step, c)
        return current

    def loss(self, x0, labels=None):
        """Return the MSE between true and predicted noise at random timesteps."""
        n_items = x0.shape[0]
        steps = torch.randint(0, self.n_steps, (n_items,), device=x0.device, dtype=torch.long)
        true_noise = torch.randn_like(x0)
        noisy = self.q_sample(x0, steps, true_noise)
        predicted_noise = self.eps_model(noisy, steps, labels)
        return F.mse_loss(true_noise, predicted_noise)

Training & Sampling

In [None]:
def train_vae(device):
    """Pre-train the VAE on Fashion-MNIST and save its weights.

    The objective is mean-squared reconstruction error plus
    ``config.KL_WEIGHT`` times the mean KL divergence to a unit Gaussian.

    Args:
        device: Device on which the model and batches are placed.

    Returns:
        The trained ``VAE`` instance (also written to ``config.VAE_PATH``).
    """
    vae = VAE().to(device)
    for param in vae.parameters():
        param.requires_grad_(True)

    vae_optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)
    dataloader = get_dataloader(train=True)
    n_epochs = config.VAE_EPOCHS

    print("--- Starting Attention-Enhanced VAE Pre-training (Fashion MNIST) ---")
    for epoch_idx in range(n_epochs):
        progress = tqdm(dataloader, desc=f"VAE Epoch {epoch_idx+1}/{n_epochs}")
        for batch, _ in progress:
            batch = batch.to(device)

            recon, mu, logvar, _ = vae(batch)

            recon_loss = F.mse_loss(recon, batch, reduction='mean')
            # Closed-form KL(N(mu, sigma^2) || N(0, 1)), averaged over all elements.
            kl_div = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
            loss = recon_loss + config.KL_WEIGHT * kl_div

            vae_optimizer.zero_grad()
            loss.backward()
            vae_optimizer.step()

            progress.set_postfix(
                recon=f"{recon_loss.item():.4f}",
                kl=f"{kl_div.item():.4f}",
            )

    save_checkpoint(vae, config.VAE_PATH)
    print(f"Fashion VAE weights saved to {config.VAE_PATH}. VAE training complete.")
    return vae

def sample_vae(device, vae):
    """Save a grid comparing test images with their VAE reconstructions.

    Loads the weights at ``config.VAE_PATH`` into ``vae``, reconstructs the
    first 10 images of the test split and writes originals (top row) and
    reconstructions (bottom row) to ``vae_reconstruction_check.png``.

    Args:
        device: Device on which inference runs.
        vae: ``VAE`` instance to load weights into; it is left in eval mode.

    Returns:
        None. Prints an error and returns early when the weights are missing.
    """
    if not load_checkpoint(vae, config.VAE_PATH, device=device):
        print("ERROR: VAE weights not found. Cannot sample VAE.")
        return
    vae.eval()

    dataloader = get_dataloader(train=False)
    imgs, _ = next(iter(dataloader))
    imgs = imgs[:10].to(device)

    with torch.no_grad():
        recon, _, _, _ = vae(imgs)

    combined = torch.cat([imgs, recon], dim=0)
    # Both the normalised inputs and the Tanh outputs live in [-1, 1], while
    # save_image expects [0, 1] and clips everything below 0 to black.
    # Rescale first, exactly as sample_latent_images does.
    combined = (combined + 1) * 0.5

    out_path = "vae_reconstruction_check.png"
    save_image(combined, out_path, nrow=10)
    print(f"VAE Reconstruction check saved to {out_path}.")

def train_latent_ddpm(device, vae):
    """Train the class-conditional latent diffusion model on frozen VAE latents.

    The VAE weights are loaded from ``config.VAE_PATH`` and frozen. Training
    resumes from ``config.LATENT_MODEL_PATH`` / ``config.LATENT_EMA_PATH`` when
    those files exist. Every ``config.SAVE_EVERY_N_EPOCHS`` epochs, and after
    the last epoch, both models are checkpointed and a sample grid is
    generated from the EMA model.

    Args:
        device: Device on which training runs.
        vae: ``VAE`` instance to load weights into; it is frozen and put in
            eval mode.

    Returns:
        ``(model, ema_model)`` after training, or None when the VAE weights
        are missing.
    """
    if not load_checkpoint(vae, config.VAE_PATH, device=device):
        print("ERROR: VAE weights not found. Please run 'train_vae' first!")
        return

    for p in vae.parameters():
        p.requires_grad_(False)
    vae.eval()
    print("Attention-Enhanced VAE is loaded and frozen.")

    model = LatentEpsModel().to(device)
    # Restore the model BEFORE cloning it, so that on resume the EMA does not
    # start from fresh random weights and later overwrite the saved EMA file.
    resumed = load_checkpoint(model, config.LATENT_MODEL_PATH, device=device)

    ema_model = deepcopy(model)
    if resumed:
        # Prefer the saved running average; if it is missing, the EMA simply
        # continues from the restored model weights copied above.
        load_checkpoint(ema_model, config.LATENT_EMA_PATH, device=device)
    for p in ema_model.parameters():
        p.requires_grad_(False)

    sched = ConditionalDenoiseDiffusion(model, device=device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=config.LEARNING_RATE)
    dataloader = get_dataloader(train=True)

    print("--- Starting Fashion CLDM Training ---")
    for epoch in range(config.N_EPOCHS):
        model.train()
        pbar = tqdm(dataloader, desc=f"CLDM Epoch {epoch+1}/{config.N_EPOCHS}")
        total_loss = 0.0

        for imgs, labels in pbar:
            imgs, labels = imgs.to(device), labels.to(device)

            # Encode with the frozen VAE and draw a latent sample as the
            # diffusion target; no gradients are needed through the VAE.
            with torch.no_grad():
                mu, log_var = vae.encode(imgs)
                z0 = mu + torch.randn_like(mu) * torch.exp(0.5 * log_var)

            optimizer.zero_grad()
            loss = sched.loss(z0, labels)

            loss.backward()
            optimizer.step()

            # ema <- decay * ema + (1 - decay) * current weights
            with torch.no_grad():
                for ema_p, p in zip(ema_model.parameters(), model.parameters()):
                    ema_p.mul_(config.EMA_DECAY).add_(p * (1 - config.EMA_DECAY))

            total_loss += loss.item()
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        print(f"\n--- Epoch {epoch+1} finished. Avg Loss: {total_loss / len(dataloader):.4f} ---")

        if (epoch+1) % config.SAVE_EVERY_N_EPOCHS == 0 or (epoch+1) == config.N_EPOCHS:
            save_checkpoint(model, config.LATENT_MODEL_PATH)
            save_checkpoint(ema_model, config.LATENT_EMA_PATH)
            print(f"Fashion CLDM Checkpoint saved at epoch {epoch+1}")

            sample_latent_images(device, vae, ema_model)
    return model, ema_model

def sample_latent_images(device, vae, model=None):
    """Generate class-conditional samples and save them as an image grid.

    Latents are drawn with the diffusion sampler, decoded by the VAE, rescaled
    from [-1, 1] to [0, 1] and written to ``fashion_latent_samples.png`` with
    one class per row.

    The same number of images is generated for every class:
    ``config.N_SAMPLES // config.NUM_CLASSES`` (at least 1). When
    ``N_SAMPLES`` is not a multiple of ``NUM_CLASSES`` the total is therefore
    rounded to that per-class count times ``NUM_CLASSES``.

    Args:
        device: Device on which sampling runs.
        vae: ``VAE`` instance; weights are (re)loaded from ``config.VAE_PATH``
            and the module is frozen and put in eval mode.
        model: Noise-prediction model to sample from. When None, a fresh
            ``LatentEpsModel`` is loaded from ``config.LATENT_EMA_PATH``.
            The model is put in eval mode.

    Returns:
        None. Prints an error and returns early when a checkpoint is missing.
    """
    if model is None:
        model = LatentEpsModel().to(device)
        if not load_checkpoint(model, config.LATENT_EMA_PATH, device=device):
            print("ERROR: EMA Latent model not found. Cannot sample.")
            return

    if not load_checkpoint(vae, config.VAE_PATH, device=device):
        print("ERROR: VAE weights not found. Cannot sample.")
        return

    for p in vae.parameters():
        p.requires_grad_(False)
    vae.eval()
    model.eval()

    sched = ConditionalDenoiseDiffusion(model, device=device)

    # The label vector and the latent batch must have the same length, so the
    # batch size is derived from the per-class count rather than taken from
    # N_SAMPLES directly (the two differ when N_SAMPLES % NUM_CLASSES != 0).
    n_per_class = max(1, config.N_SAMPLES // config.NUM_CLASSES)
    n_total = n_per_class * config.NUM_CLASSES

    # Labels are grouped by class: [0, 0, ..., 1, 1, ..., 9, 9, ...].
    target_labels = torch.arange(config.NUM_CLASSES, device=device).repeat_interleave(n_per_class)

    latent_shape = (n_total, config.LATENT_CHANNELS, config.LATENT_SIZE, config.LATENT_SIZE)
    print(f"Generating {n_total} samples in latent space {latent_shape}...")

    with torch.no_grad():
        z_samples = sched.sample(
            shape=latent_shape,
            device=device,
            c=target_labels
        )

        x_samples = vae.decode(z_samples).clamp(-1, 1)

    # Map [-1, 1] to the [0, 1] range save_image expects.
    x_samples = (x_samples + 1) * 0.5
    out_path = "fashion_latent_samples.png"
    # nrow = images per grid row, so each row shows a single class.
    save_image(x_samples, out_path, nrow=n_per_class)
    print(f"Generated samples saved to {out_path}.")