<a href="https://colab.research.google.com/github/Ermi1223/conditional-diffusion-tutorial/blob/main/conditional_diffusion_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Step 1: Set Up GPU Acceleration**

In [2]:
# First cell: Setup GPU
# Report the PyTorch build and whether a CUDA device is visible to this runtime.
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
# torch.cuda.get_device_name raises when no CUDA device exists, so only query it
# on a GPU runtime; on CPU the rest of the notebook still runs (just slowly).
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("GPU: none (running on CPU)")

PyTorch version: 2.6.0+cu124
CUDA available: True
GPU: Tesla T4


**Step 2: Install Required Packages**

In [3]:
# Second cell: Install dependencies
!pip install torch torchvision matplotlib tqdm ipywidgets
!pip install diffusers transformers

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

**Step 3: Import Libraries and Configuration**

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from diffusers import DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

# Configuration
IMG_SIZE = 32
BATCH_SIZE = 64
TIMESTEPS = 500  # Reduced for faster training
NUM_EPOCHS = 10  # Reduced for demonstration
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42  # Fixed seed so weight init, shuffling and noise draws repeat across runs

# Seed every RNG the notebook draws from. torch.manual_seed also seeds all CUDA
# devices; some GPU kernels may still be non-deterministic.
torch.manual_seed(SEED)
np.random.seed(SEED)

print(f"Using device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
PyTorch version: 2.6.0+cu124
GPU: Tesla T4


**Step 4: Dataset Preparation**

In [5]:
from transformers import CLIPTokenizer, CLIPTextModel

# Load the CLIP text tower a single time; it is only needed to embed the fixed
# class prompts below, so it stays in eval mode.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_tokenizer = CLIPTokenizer.from_pretrained(_CLIP_CHECKPOINT)
clip_text_encoder = CLIPTextModel.from_pretrained(_CLIP_CHECKPOINT).to(DEVICE).eval()

# One descriptive prompt per CIFAR-10 label index.
class_to_text = {
    0: "an airplane flying in the sky",
    1: "an automobile on the road",
    2: "a bird perched on a branch",
    3: "a cat sitting on a windowsill",
    4: "a deer in the forest",
    5: "a dog playing in the park",
    6: "a frog on a lily pad",
    7: "a horse running in a field",
    8: "a ship sailing on the ocean",
    9: "a truck driving on the highway"
}


@torch.no_grad()
def _embed_prompt(prompt):
    """Encode one prompt with CLIP and average the final hidden states over tokens.

    Returns the pooled vector with the leading batch axis removed, on CPU.
    """
    tokens = clip_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(DEVICE)
    hidden = clip_text_encoder(**tokens).last_hidden_state
    return hidden.mean(dim=1).squeeze(0).cpu()


# label -> text embedding lookup table, stored on CPU (read by the dataset class).
precomputed_embeddings = {label: _embed_prompt(prompt) for label, prompt in class_to_text.items()}

# Dataset Class
class CIFAR10Conditional(Dataset):
    """CIFAR-10 dataset yielding ``(image, label, text_embedding)`` triples.

    Images go through ``Resize(IMG_SIZE)`` and ``ToTensor`` (values in [0, 1]) and
    are then normalized to [-1, 1]. The text embedding is looked up by class label.

    Args:
        root: directory passed to ``torchvision.datasets.CIFAR10``.
        train: use the training split if True, otherwise the test split.
        embeddings: optional ``label -> tensor`` mapping. When None (default) the
            global ``precomputed_embeddings`` table is used, looked up on each access.
        download: forwarded to ``datasets.CIFAR10``; set False for offline runs.
    """
    def __init__(self, root, train=True, embeddings=None, download=True):
        super().__init__()
        self._embeddings = embeddings
        self.cifar = datasets.CIFAR10(
            root=root, train=train, download=download,
            transform=transforms.Compose([
                transforms.Resize(IMG_SIZE),
                transforms.ToTensor(),
                # (x - 0.5) / 0.5 on each of the three RGB channels: [0, 1] -> [-1, 1].
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ])
        )

    def __len__(self):
        return len(self.cifar)

    def __getitem__(self, idx):
        img, label = self.cifar[idx]
        # Resolve the table lazily so a default-constructed dataset always sees the
        # current global embeddings.
        table = precomputed_embeddings if self._embeddings is None else self._embeddings
        return img, label, table[label]

# Load Data
train_dataset = CIFAR10Conditional(root='./data', train=True)
test_dataset = CIFAR10Conditional(root='./data', train=False)

# Page-locked host memory only speeds up host->GPU copies. On a CPU-only runtime
# it has no benefit and PyTorch emits a warning, so enable it only for CUDA.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=(DEVICE == "cuda"))
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_dataset, **_loader_kwargs)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/592 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/605M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/605M [00:00<?, ?B/s]


  0%|          | 0.00/170M [00:00<?, ?B/s][A
  0%|          | 98.3k/170M [00:00<04:16, 665kB/s][A
  0%|          | 328k/170M [00:00<01:56, 1.46MB/s][A
  1%|          | 1.25M/170M [00:00<00:38, 4.42MB/s][A
  3%|▎         | 4.92M/170M [00:00<00:10, 15.7MB/s][A
  5%|▌         | 9.11M/170M [00:00<00:12, 13.4MB/s][A
  6%|▌         | 10.6M/170M [00:01<00:31, 5.09MB/s][A
  9%|▊         | 14.8M/170M [00:01<00:17, 8.67MB/s][A
 12%|█▏        | 20.1M/170M [00:01<00:10, 14.0MB/s][A
 16%|█▌        | 26.5M/170M [00:02<00:06, 20.9MB/s][A
 19%|█▉        | 32.2M/170M [00:02<00:05, 27.1MB/s][A
 22%|██▏       | 37.9M/170M [00:02<00:04, 32.7MB/s][A
 26%|██▌       | 43.7M/170M [00:02<00:03, 38.2MB/s][A
 29%|██▉       | 49.2M/170M [00:02<00:02, 42.1MB/s][A
 32%|███▏      | 54.8M/170M [00:02<00:02, 44.3MB/s][A
 36%|███▌      | 60.7M/170M [00:02<00:02, 47.3MB/s][A
 39%|███▉      | 66.9M/170M [00:02<00:02, 50.9MB/s][A
 42%|████▏     | 72.4M/170M [00:02<00:01, 50.9MB/s][A
 46%|████▌     | 77.

Train dataset size: 50000
Test dataset size: 10000


**Step 5: Model Architecture**

In [7]:
class TimeEmbedding(nn.Module):
    """Sinusoidal time embedding.

    Maps a 1-D batch of timesteps ``t`` to a ``(len(t), dim)`` tensor whose first
    half is ``sin(t * f_i)`` and second half ``cos(t * f_i)``. The frequencies
    ``f_i`` are spaced geometrically from 1 down to 1/10000.

    Args:
        dim: output width. For odd ``dim`` the last column is zero-padded so the
            output is always exactly ``dim`` wide.
    """
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, t):
        half_dim = self.dim // 2
        # Build everything on t's device (not a global DEVICE) so the module keeps
        # working after model.to(...) and for CPU inputs.
        # max(..., 1) avoids a division by zero when dim is 2 or 3.
        exponent = torch.log(torch.tensor(10000.0, device=t.device)) / max(half_dim - 1, 1)
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -exponent)
        angles = t[:, None] * freqs[None, :]
        emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
        if self.dim % 2:
            emb = nn.functional.pad(emb, (0, 1))
        return emb

class ConditionalBlock(nn.Module):
    """U-Net residual block modulated per channel by time and condition embeddings.

    Two (3x3 conv -> GroupNorm(8) -> SiLU) stages are summed with a residual path
    (1x1 conv when the channel count changes, identity otherwise). The result is
    then scaled per channel by ``1 + scale``, where
    ``scale = SiLU(Linear(t_emb)) + SiLU(Linear(c_emb))``.

    Args:
        in_ch: input channels.
        out_ch: output channels (must be divisible by 8 for GroupNorm).
        time_dim: width of the time embedding passed to ``forward``.
        cond_dim: width of the condition embedding passed to ``forward``.
    """
    def __init__(self, in_ch, out_ch, time_dim, cond_dim):
        super().__init__()
        self.time_layer = nn.Sequential(
            nn.Linear(time_dim, out_ch),
            nn.SiLU()
        )
        self.cond_layer = nn.Sequential(
            nn.Linear(cond_dim, out_ch),
            nn.SiLU()
        )
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.GroupNorm(8, out_ch),
            nn.SiLU()
        )
        self.residual = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()

    def forward(self, x, t_emb, c_emb):
        """x: (B, in_ch, H, W); t_emb: (B, time_dim); c_emb: (B, cond_dim) -> (B, out_ch, H, W)."""
        h = self.conv(x) + self.residual(x)
        scale = self.time_layer(t_emb) + self.cond_layer(c_emb)
        scale = scale[:, :, None, None]
        # Modulate as (1 + scale): a zero/near-zero scale (e.g. SiLU outputs near 0
        # at initialisation) then leaves the features intact instead of
        # multiplying the whole block's output toward zero.
        return h * (1 + scale)

class ConditionalUNet(nn.Module):
    """U-Net with conditioning for diffusion.

    Predicts the noise in ``x`` from the timestep, an integer class label and a
    text embedding. The encoder has 64 -> 128 -> 256 channels with 2x
    average-pool downsampling between levels. It feeds a 512-channel bottleneck
    and a mirrored decoder with bilinear 2x upsampling and skip concatenation.

    Args:
        time_dim: width of the timestep embedding and its MLP.
        cond_dim: width of the conditioning vector. Must match the last dimension
            of ``text_emb`` (512 for the CLIP ViT-B/32 text encoder in this notebook).
        num_classes: number of distinct class labels for the learned label embedding.
        in_channels: image channels, used for both the input and the predicted noise.

    The defaults reproduce the original CIFAR-10 configuration.
    """
    def __init__(self, time_dim=128, cond_dim=512, num_classes=10, in_channels=3):
        super().__init__()

        # Time embedding: sinusoidal features followed by a small MLP
        self.time_mlp = nn.Sequential(
            TimeEmbedding(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
            nn.Linear(time_dim, time_dim)
        )

        # Learned class embedding (added to the projected text embedding)
        self.class_embed = nn.Embedding(num_classes, cond_dim)

        # Conditioning projection for the text embedding
        self.cond_proj = nn.Sequential(
            nn.Linear(cond_dim, cond_dim),
            nn.SiLU(),
            nn.Linear(cond_dim, cond_dim)
        )

        # Down blocks
        self.down1 = ConditionalBlock(in_channels, 64, time_dim, cond_dim)
        self.down2 = ConditionalBlock(64, 128, time_dim, cond_dim)
        self.down3 = ConditionalBlock(128, 256, time_dim, cond_dim)

        # Bottleneck
        self.bottleneck = ConditionalBlock(256, 512, time_dim, cond_dim)

        # Up blocks (input = upsampled features + matching skip connection)
        self.up1 = ConditionalBlock(512 + 256, 256, time_dim, cond_dim)
        self.up2 = ConditionalBlock(256 + 128, 128, time_dim, cond_dim)
        self.up3 = ConditionalBlock(128 + 64, 64, time_dim, cond_dim)

        # Output: project back to image channels (predicted noise)
        self.out = nn.Conv2d(64, in_channels, 1)

        # Pooling/upsampling
        self.downsample = nn.AvgPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x, t, class_labels, text_emb):
        """Predict the noise in ``x`` at timestep ``t``.

        Args:
            x: images, (B, in_channels, H, W). H and W must be divisible by 8, because
                the encoder halves the resolution three times and the decoder
                concatenates skips at matching sizes.
            t: 1-D batch of timesteps.
            class_labels: integer labels for ``nn.Embedding``.
            text_emb: (B, cond_dim) text embeddings.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if H or W is not divisible by 8.
        """
        if x.shape[-2] % 8 or x.shape[-1] % 8:
            raise ValueError(
                f"ConditionalUNet needs H and W divisible by 8, got {tuple(x.shape[-2:])}"
            )

        # Embeddings
        t_emb = self.time_mlp(t)
        c_emb = self.class_embed(class_labels) + self.cond_proj(text_emb)

        # Down path
        d1 = self.down1(x, t_emb, c_emb)
        d2 = self.down2(self.downsample(d1), t_emb, c_emb)
        d3 = self.down3(self.downsample(d2), t_emb, c_emb)

        # Bottleneck
        b = self.bottleneck(self.downsample(d3), t_emb, c_emb)

        # Up path with skip connections
        u1 = self.up1(torch.cat([self.upsample(b), d3], dim=1), t_emb, c_emb)
        u2 = self.up2(torch.cat([self.upsample(u1), d2], dim=1), t_emb, c_emb)
        u3 = self.up3(torch.cat([self.upsample(u2), d1], dim=1), t_emb, c_emb)

        return self.out(u3)

# Create model: build the U-Net on the configured device and report its size.
model = ConditionalUNet().to(DEVICE)
n_params = sum(param.numel() for param in model.parameters())
print(f"Model created with {n_params:,} parameters")

Model created with 9,687,747 parameters


**Step 6: Diffusion Utilities**

In [8]:
class Diffusion:
    """Handles forward (noising) and reverse (denoising) DDPM steps.

    Holds a linear beta schedule of length ``timesteps`` together with
    ``alphas = 1 - betas`` and ``alpha_bars = cumprod(alphas)``. The schedule
    tensors are created on CPU. Per-sample coefficients are gathered and moved to
    the device/dtype of the image tensor on each call, so one object serves both
    CPU and GPU batches.

    Args:
        timesteps: number of diffusion steps. ``None`` (default) uses the notebook
            constant ``TIMESTEPS`` at construction time.
        beta_start: first value of the linear beta schedule.
        beta_end: last value of the linear beta schedule.
    """
    def __init__(self, timesteps=None, beta_start=1e-4, beta_end=0.02):
        # Resolve the default at call time so Diffusion() still means
        # Diffusion(TIMESTEPS) without binding the global at class definition.
        self.timesteps = TIMESTEPS if timesteps is None else timesteps

        # Linear noise schedule
        self.betas = torch.linspace(beta_start, beta_end, self.timesteps)
        self.alphas = 1. - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)

    @staticmethod
    def _gather(schedule, t, like):
        """Return ``schedule[t]`` as a (B, 1, 1, 1) tensor on ``like``'s device and dtype."""
        # Index on the target device: indexing a CPU tensor with CUDA indices, or
        # multiplying a CPU (B,1,1,1) tensor with a CUDA image, raises.
        t = torch.as_tensor(t, device=like.device).long()
        coeff = schedule.to(like.device)[t]
        return coeff.to(like.dtype).reshape(-1, 1, 1, 1)

    def forward_process(self, x0, t, noise=None):
        """Add noise to image at timestep t: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps.

        Args:
            x0: clean images, assumed (B, C, H, W).
            t: integer timesteps, one per sample (a scalar broadcasts).
            noise: optional pre-drawn noise shaped like ``x0``. When omitted it is
                drawn with ``torch.randn_like``.

        Returns:
            (noisy_img, noise): the noised batch and the noise that was added.
        """
        alpha_bar = self._gather(self.alpha_bars, t, x0)
        if noise is None:
            noise = torch.randn_like(x0)
        noisy_img = torch.sqrt(alpha_bar) * x0 + torch.sqrt(1 - alpha_bar) * noise
        return noisy_img, noise

    def reverse_process(self, model, noisy_img, t, class_labels, text_emb,
                        noise_pred=None, add_noise=False):
        """One denoising step x_t -> x_{t-1} using the DDPM epsilon-parameterised mean.

            mean = (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)

        Args:
            model: noise predictor called as ``model(noisy_img, t, class_labels, text_emb)``.
                It is only used when ``noise_pred`` is None.
            noisy_img: current sample x_t.
            t: integer timesteps, one per sample (a scalar broadcasts).
            class_labels, text_emb: conditioning forwarded to ``model``.
            noise_pred: optional precomputed eps prediction.
            add_noise: if True, add ``sqrt(beta_t) * z`` for samples with t > 0
                (stochastic DDPM sampling). The default False returns the mean only.

        Returns:
            (prev_noisy, pred_x0): the x_{t-1} estimate and the implied x_0 prediction.
        """
        if noise_pred is None:
            noise_pred = model(noisy_img, t, class_labels, text_emb)

        # Compute coefficients, shaped (B, 1, 1, 1) on the image's device
        alpha_t = self._gather(self.alphas, t, noisy_img)
        alpha_bar_t = self._gather(self.alpha_bars, t, noisy_img)
        beta_t = self._gather(self.betas, t, noisy_img)
        sqrt_one_minus_alpha_bar_t = torch.sqrt(1 - alpha_bar_t)

        # Predict x0 by inverting the forward process
        pred_x0 = (noisy_img - sqrt_one_minus_alpha_bar_t * noise_pred) / torch.sqrt(alpha_bar_t)

        # Posterior mean. The eps coefficient is beta_t / sqrt(1 - abar_t), not
        # sqrt(beta_t); the two coincide only at t = 0.
        eps_coeff = beta_t / sqrt_one_minus_alpha_bar_t
        prev_noisy = (noisy_img - eps_coeff * noise_pred) / torch.sqrt(alpha_t)

        if add_noise:
            # No noise on the final step (t == 0).
            nonzero = (torch.as_tensor(t, device=noisy_img.device) > 0).to(noisy_img.dtype)
            nonzero = nonzero.reshape(-1, 1, 1, 1)
            prev_noisy = prev_noisy + nonzero * torch.sqrt(beta_t) * torch.randn_like(noisy_img)

        return prev_noisy, pred_x0

# Initialize diffusion utilities: one shared noise schedule used by both the
# training and sampling code.
diffusion = Diffusion(timesteps=TIMESTEPS)
print(f"Diffusion process initialized with {TIMESTEPS} timesteps")

Diffusion process initialized with 500 timesteps
