# Complete Standalone Diffusion Training Notebook

This notebook contains all the necessary code to train a diffusion model on CIFAR-10 from scratch, including the model architecture and scheduler definitions. It is designed to be run directly on Kaggle.

In [None]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

# Hyperparameters
# Global configuration shared by the dataset, model and training cells below.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # Prefer the GPU when one is visible
BATCH_SIZE = 64  # Batch size handed to the training DataLoader
LEARNING_RATE = 1e-4  # Learning rate for the Adam optimizer
EPOCHS = 20  # Number of training epochs; also T_max of the cosine LR schedule
IMAGE_SIZE = 32  # Images are resized to IMAGE_SIZE x IMAGE_SIZE
CHANNELS = 3  # Image channels (RGB); also the model's input/output channels
N_TIMESTEPS = 1000  # Number of diffusion steps used by the sampler and timestep sampling
N_HEADS = 4  # NOTE(review): not referenced anywhere else in this notebook
N_EMBED = 320  # NOTE(review): not referenced anywhere else in this notebook
SAVE_INTERVAL = 5 # Save checkpoint every 5 epochs

print(f"Using device: {DEVICE}")

## 1. DDPM Scheduler (Modified for Training)

In [None]:
class DDPMSampler:
    """DDPM noise scheduler used for both training and ancestral sampling.

    Training uses :meth:`add_noise` to draw ``x_t`` from ``q(x_t | x_0)``.
    Sampling calls :meth:`set_inference_timesteps` once and then :meth:`step`
    for every timestep in ``self.timesteps``.

    Notation used in the comments: ``a_bar_t`` is the cumulative product of
    ``alpha`` up to step ``t`` (stored in ``self.alphacum``).
    """

    def __init__(
        self,
        generator: torch.Generator,
        num_training_steps: int = 10000,
        beta_start: float = 0.00085,
        beta_end: float = 0.0120,
    ) -> None:
        """Build the beta schedule and its cumulative alpha products.

        Args:
            generator: RNG used for every noise draw. It must live on the same
                device as the tensors passed to ``add_noise`` / ``step``.
            num_training_steps: Length of the forward diffusion chain.
            beta_start: First beta of the schedule.
            beta_end: Last beta of the schedule.
        """
        # NOTE(review): the default of 10000 steps looks like a typo for 1000;
        # it is left unchanged so existing callers keep their behavior.
        # The schedule is linear in sqrt(beta) and squared afterwards.
        self.beta = (
            torch.linspace(
                beta_start**0.5, beta_end**0.5, num_training_steps, dtype=torch.float32
            )
            ** 2
        )
        self.alpha = 1.0 - self.beta
        self.generator = generator
        self.one = torch.tensor(1.0)
        self.alphacum = torch.cumprod(self.alpha, dim=0)
        self.num_training_steps = num_training_steps

        # Until set_inference_timesteps() is called, sample over the full
        # training schedule so that step() and set_strength() are usable.
        self.num_inference_steps = num_training_steps
        self.timesteps = torch.from_numpy(
            np.arange(0, num_training_steps)[::-1].copy()
        )
        # Original attribute name, kept for backward compatibility.
        self.timestep = self.timesteps
        self.start_step = 0

    def set_inference_timesteps(self, num_inference_steps=50):
        """Select ``num_inference_steps`` evenly spaced timesteps, descending.

        Raises:
            ValueError: If ``num_inference_steps`` is not in
                ``[1, num_training_steps]``.
        """
        if not 0 < num_inference_steps <= self.num_training_steps:
            raise ValueError(
                "num_inference_steps must be between 1 and "
                f"{self.num_training_steps}, got {num_inference_steps}"
            )
        self.num_inference_steps = num_inference_steps
        step_ratio = self.num_training_steps // self.num_inference_steps
        timesteps = (
            (np.arange(0, num_inference_steps) * step_ratio)
            .round()[::-1]
            .copy()
            .astype(np.int64)
        )
        self.timesteps = torch.from_numpy(timesteps)

    def _get_previous_timestep(self, timestep: int) -> int:
        """Return the timestep visited after ``timestep`` (negative past the end)."""
        return timestep - (self.num_training_steps // self.num_inference_steps)

    def _get_variance(self, timestep: int):
        """Return the posterior variance for ``timestep``, clamped away from 0."""
        previous_t = self._get_previous_timestep(timestep)
        alpha_t = self.alphacum[timestep]
        alpha_prev_t = self.alphacum[previous_t] if previous_t >= 0 else self.one
        current_beta_t = 1 - alpha_t / alpha_prev_t

        variance = ((1 - alpha_prev_t) / (1 - alpha_t)) * current_beta_t
        # Clamp so the square root taken by the caller never sees 0.
        return torch.clamp(variance, min=1e-10)

    def set_strength(self, strength: float = 1.0):
        """Skip the noisiest part of the inference schedule.

        ``strength=1.0`` keeps every inference timestep; smaller values drop
        the first ``(1 - strength)`` fraction of them.

        Raises:
            ValueError: If ``strength`` is outside ``[0, 1]``.
        """
        if not 0.0 <= strength <= 1.0:
            raise ValueError(f"strength must be in [0, 1], got {strength}")
        # Must be an int: it is used as a slice index into self.timesteps.
        start_step = self.num_inference_steps - int(self.num_inference_steps * strength)
        self.timesteps = self.timesteps[start_step:]
        self.start_step = start_step

    def step(self, timestep: int, latents: torch.Tensor, model_outputs: torch.Tensor):
        """Run one reverse-diffusion step.

        Args:
            timestep: Current timestep ``t``.
            latents: Current noisy sample ``x_t``.
            model_outputs: Noise predicted by the model for ``x_t``.

        Returns:
            A sample for the previous timestep.
        """
        t = timestep
        previous_t = self._get_previous_timestep(t)

        alpha_t = self.alphacum[timestep]
        alpha_prev_t = self.alphacum[previous_t] if previous_t >= 0 else self.one
        beta_t = 1 - alpha_t  # 1 - a_bar_t
        beta_prev_t = 1 - alpha_prev_t  # 1 - a_bar_{t-1}
        curr_alpha_t = alpha_t / alpha_prev_t
        curr_beta_t = 1 - curr_alpha_t

        # x0 estimate: invert x_t = sqrt(a_bar_t) x0 + sqrt(1 - a_bar_t) eps
        prediction_original_sample = (
            latents - (beta_t**0.5) * model_outputs
        ) / alpha_t**0.5

        # Posterior mean:
        #   sqrt(a_bar_{t-1}) * beta_t / (1 - a_bar_t) * x0
        # + sqrt(alpha_t) * (1 - a_bar_{t-1}) / (1 - a_bar_t) * x_t
        predicted_sample_coeff = (alpha_prev_t**0.5 * curr_beta_t) / beta_t
        current_sample_coeff = (curr_alpha_t**0.5 * beta_prev_t) / beta_t

        predicted_prev_samples = (
            predicted_sample_coeff * prediction_original_sample
            + current_sample_coeff * latents
        )

        # No noise is added on the final step (t == 0).
        if t > 0:
            noise = torch.randn(
                model_outputs.shape,
                generator=self.generator,
                device=model_outputs.device,
                dtype=model_outputs.dtype,
            )
            predicted_prev_samples = (
                predicted_prev_samples + (self._get_variance(t) ** 0.5) * noise
            )

        return predicted_prev_samples

    @staticmethod
    def _expand_like(values: torch.Tensor, reference: torch.Tensor) -> torch.Tensor:
        """Append trailing singleton dims until ``values`` has ``reference``'s rank."""
        values = values.flatten()
        while values.dim() < reference.dim():
            values = values.unsqueeze(-1)
        return values

    def add_noise(self, original_samples: torch.Tensor, timesteps: torch.Tensor):
        """Diffuse clean samples to the given timesteps.

        Args:
            original_samples: Clean batch ``x_0``; the first dim is the batch.
            timesteps: One integer timestep per batch element.

        Returns:
            ``(noisy_samples, noise)`` where ``noise`` is the Gaussian noise
            that was mixed in (the regression target for training).
        """
        # Index a copy of the schedule that lives on the samples' device;
        # indexing the CPU schedule with device-side timesteps fails.
        alphacum = self.alphacum.to(
            device=original_samples.device, dtype=original_samples.dtype
        )
        timesteps = timesteps.to(device=original_samples.device, dtype=torch.long)

        sqrt_alpha = self._expand_like(alphacum[timesteps] ** 0.5, original_samples)
        sqrt_alpha_minus = self._expand_like(
            (1 - alphacum[timesteps]) ** 0.5, original_samples
        )

        noise = torch.randn(
            original_samples.shape,
            generator=self.generator,
            device=original_samples.device,
            dtype=original_samples.dtype,
        )

        final_noise = sqrt_alpha * original_samples + sqrt_alpha_minus * noise
        return final_noise, noise  # noise is returned for the training loss

## 2. Model Architecture (Modified for RGB Input)

In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Input and output are ``(batch, seq_len, d_embed)`` tensors.
    """

    def __init__(
        self,
        n_heads: int,
        d_embed: int,
        input_projection_bias: bool = True,
        output_projection_bias: bool = True,
    ) -> None:
        """
        Args:
            n_heads: Number of attention heads.
            d_embed: Embedding width; must be divisible by ``n_heads``.
            input_projection_bias: Whether the fused q/k/v projection has a bias.
            output_projection_bias: Whether the output projection has a bias.

        Raises:
            ValueError: If ``d_embed`` is not divisible by ``n_heads``.
        """
        super().__init__()
        if d_embed % n_heads != 0:
            raise ValueError(
                f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})"
            )
        # One fused projection produces q, k and v.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=input_projection_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=output_projection_bias)

        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, masked=False):
        """Attend over ``x``; with ``masked=True`` a causal mask is applied."""
        batch_size, seq_len, d_embed = x.shape
        head_shape = (batch_size, seq_len, self.n_heads, self.d_head)

        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # (batch, seq, d_embed) -> (batch, heads, seq, d_head)
        q = q.view(head_shape).transpose(1, 2)
        k = k.view(head_shape).transpose(1, 2)
        v = v.view(head_shape).transpose(1, 2)

        scores = q @ k.transpose(-1, -2)

        if masked:
            # Hide every position to the right of the diagonal (the future).
            future = torch.ones(
                seq_len, seq_len, dtype=torch.bool, device=scores.device
            ).triu(diagonal=1)
            scores = scores.masked_fill(future, float("-inf"))

        # Scale before the softmax so logits do not grow with d_head.
        scores = scores / math.sqrt(self.d_head)
        weights = F.softmax(scores, dim=-1)

        output = weights @ v

        # (batch, heads, seq, d_head) -> (batch, seq, d_embed)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, d_embed)

        return self.out_proj(output)


class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from ``x`` (queries) to ``y``.

    ``x`` is ``(batch, seq_len_q, d_embed)`` and ``y`` is
    ``(batch, seq_len_kv, d_cross)``; the two sequence lengths may differ.
    The output has the shape of ``x``.
    """

    def __init__(
        self,
        n_heads: int,
        d_embed: int,
        d_cross: int,
        input_projection_bias: bool = True,
        output_projection_bias: bool = True,
    ) -> None:
        """
        Args:
            n_heads: Number of attention heads.
            d_embed: Query/output width; must be divisible by ``n_heads``.
            d_cross: Width of the context sequence ``y``.
            input_projection_bias: Whether the q/k/v projections have a bias.
            output_projection_bias: Whether the output projection has a bias.

        Raises:
            ValueError: If ``d_embed`` is not divisible by ``n_heads``.
        """
        super().__init__()
        if d_embed % n_heads != 0:
            raise ValueError(
                f"d_embed ({d_embed}) must be divisible by n_heads ({n_heads})"
            )
        self.q_proj = nn.Linear(d_embed, d_embed, bias=input_projection_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=input_projection_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=input_projection_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=output_projection_bias)

        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y, masked=False):
        """Attend from ``x`` to ``y``.

        With ``masked=True`` query position ``i`` cannot see context
        positions ``j > i``.
        """
        batch_size, seq_len, d_embed = x.shape
        context_len = y.shape[1]

        q = self.q_proj(x)
        k = self.k_proj(y)
        v = self.v_proj(y)

        # Keys/values follow the context length, which need not equal seq_len.
        q = q.view(batch_size, seq_len, self.n_heads, self.d_head).transpose(1, 2)
        k = k.view(batch_size, context_len, self.n_heads, self.d_head).transpose(1, 2)
        v = v.view(batch_size, context_len, self.n_heads, self.d_head).transpose(1, 2)

        scores = q @ k.transpose(-1, -2)

        if masked:
            future = torch.ones(
                seq_len, context_len, dtype=torch.bool, device=scores.device
            ).triu(diagonal=1)
            scores = scores.masked_fill(future, float("-inf"))

        # Scale before the softmax so logits do not grow with d_head.
        scores = scores / math.sqrt(self.d_head)
        weights = F.softmax(scores, dim=-1)

        output = weights @ v

        # (batch, heads, seq, d_head) -> (batch, seq, d_embed)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, d_embed)

        return self.out_proj(output)


class TimeEmbedding(nn.Module):
    """Two-layer MLP that widens an ``n_embed`` vector to ``4 * n_embed``."""

    def __init__(self, n_embed: int) -> None:
        super().__init__()
        hidden = 4 * n_embed
        self.linear1 = nn.Linear(n_embed, hidden)
        self.linear2 = nn.Linear(hidden, hidden)

    def forward(self, x):
        """Apply linear -> SiLU -> linear along the last dimension of ``x``."""
        return self.linear2(F.silu(self.linear1(x)))


class UNET_AttentionBlock(nn.Module):
    """Transformer block applied to a feature map.

    The ``(n, c, h, w)`` input is flattened to a sequence of ``h * w`` tokens
    of width ``c``, passed through self-attention, cross-attention to
    ``context`` and a GeGLU feed-forward layer (each with a residual
    connection), then reshaped back and added to the input.
    """

    def __init__(self, n_heads: int, n_embed: int, d_cross: int = 768) -> None:
        """
        Args:
            n_heads: Number of attention heads.
            n_embed: Width of each head; the block operates on
                ``n_heads * n_embed`` channels.
            d_cross: Width of the context sequence given to cross-attention.
        """
        super().__init__()
        channels = n_heads * n_embed

        self.groupnorm = nn.GroupNorm(32, channels)
        self.convlayer = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

        # The attention layers see tokens of width `channels`, so that is
        # their embedding size; each head then has width n_embed.
        self.layernorm1 = nn.LayerNorm(channels)
        self.selfattention = SelfAttention(n_heads, channels)

        self.layernorm2 = nn.LayerNorm(channels)
        self.crosssattention = CrossAttention(n_heads, channels, d_cross=d_cross)

        # GeGLU: linear1 emits value and gate halves of 4 * channels each,
        # which is the width linear2 expects after gating.
        self.layernorm3 = nn.LayerNorm(channels)
        self.linear1 = nn.Linear(channels, 4 * channels * 2)
        self.linear2 = nn.Linear(4 * channels, channels)

        self.convout = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

    def forward(self, x, context):
        """
        Args:
            x: Feature map of shape ``(n, n_heads * n_embed, h, w)``.
            context: Conditioning sequence whose last dim is ``d_cross``
                (assumed ``(n, seq, d_cross)`` -- confirm against the caller).

        Returns:
            Tensor with the same shape as ``x``.
        """
        long_residue = x

        x = self.groupnorm(x)
        x = self.convlayer(x)

        n, c, h, w = x.shape

        # (n, c, h, w) -> (n, h * w, c): one token per spatial position.
        x = x.view(n, c, h * w).transpose(-1, -2)

        # Self-attention with residual.
        short_residue = x
        x = self.layernorm1(x)
        x = self.selfattention(x)
        x = x + short_residue

        # Cross-attention to the context with residual.
        short_residue = x
        x = self.layernorm2(x)
        x = self.crosssattention(x, context)
        x = x + short_residue

        # GeGLU feed-forward with residual.
        short_residue = x
        x = self.layernorm3(x)
        x, gate = self.linear1(x).chunk(2, dim=-1)
        x = x * F.gelu(gate)
        x = self.linear2(x)
        x = x + short_residue

        # (n, h * w, c) -> (n, c, h, w)
        x = x.transpose(-1, -2).reshape(n, c, h, w)

        x = self.convout(x)

        return x + long_residue


class Upsample(nn.Module):
    """Double the spatial resolution, then refine with a 3x3 convolution."""

    def __init__(self, channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Nearest-neighbour upsample ``x`` by 2 and convolve the result."""
        enlarged = F.interpolate(x, scale_factor=2, mode="nearest")
        return self.conv(enlarged)


class UNET_ResidualBlock(nn.Module):
    """Residual convolution block conditioned on a time embedding.

    Both ``in_channels`` and ``out_channels`` must be divisible by 32, the
    number of groups used by the GroupNorm layers.
    """

    def __init__(self, in_channels: int, out_channels: int, n_time: int = 1280) -> None:
        """
        Args:
            in_channels: Channels of the incoming feature map.
            out_channels: Channels of the returned feature map.
            n_time: Width of the time embedding.
        """
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.convlayer = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.timelayer = nn.Linear(n_time, out_channels)

        # The second conv runs on the merged feature map, which already has
        # out_channels channels, so its input width is out_channels.
        self.groupnorm_time = nn.GroupNorm(32, out_channels)
        self.convlayer_time = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, padding=1
        )

        # The skip path needs a 1x1 projection only when the width changes.
        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, padding=0
            )

    def forward(self, image: torch.Tensor, time):
        """
        Args:
            image: Feature map of shape ``(n, in_channels, h, w)``.
            time: Time embedding whose last dim is ``n_time``; its leading dim
                must broadcast against the batch dim of ``image``.

        Returns:
            Feature map of shape ``(n, out_channels, h, w)``.
        """
        residue = image

        image = self.groupnorm(image)
        image = F.silu(image)
        image = self.convlayer(image)

        time = self.timelayer(time)

        # (n, out_channels) -> (n, out_channels, 1, 1) to broadcast over h, w.
        y = image + time.unsqueeze(-1).unsqueeze(-1)
        y = self.groupnorm_time(y)
        y = F.silu(y)
        y = self.convlayer_time(y)

        return y + self.residual_layer(residue)


class SwitchSequential(nn.Module):
    """Sequential container that routes extra inputs by layer type.

    Residual blocks also receive ``time``, attention blocks also receive
    ``context``, and every other layer receives only ``x``.
    """

    def __init__(self, *layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)

    def forward(self, x, context, time):
        """Run ``x`` through every layer in order and return the result."""
        for module in self.layers:
            if isinstance(module, UNET_ResidualBlock):
                extra = (time,)
            elif isinstance(module, UNET_AttentionBlock):
                extra = (context,)
            else:
                extra = ()
            x = module(x, *extra)
        return x


class UNET(nn.Module):
    """U-Net backbone: 12 encoder stages, a bottleneck and 12 decoder stages.

    Every encoder stage pushes its output onto a skip stack; every decoder
    stage pops one entry and concatenates it on the channel axis. The number
    of decoder stages therefore equals the number of encoder stages, and each
    decoder stage's input width is (current width + popped skip width).

    The input is downsampled three times by stride-2 convolutions and
    upsampled three times by a factor of 2.
    """

    def __init__(self, in_channels: int = 4) -> None:
        """
        Args:
            in_channels: Channels of the input image/latent.
        """
        super().__init__()

        # Skip widths produced, in order:
        # 320, 320, 320, 320, 640, 640, 640, 1280, 1280, 1280, 1280, 1280
        self.encoder = nn.ModuleList(
            [
                SwitchSequential(nn.Conv2d(in_channels, 320, kernel_size=3, padding=1)),
                SwitchSequential(
                    UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)
                ),
                SwitchSequential(
                    UNET_ResidualBlock(320, 320), UNET_AttentionBlock(8, 40)
                ),
                SwitchSequential(
                    nn.Conv2d(320, 320, kernel_size=3, stride=2, padding=1)
                ),
                SwitchSequential(
                    UNET_ResidualBlock(320, 640), UNET_AttentionBlock(8, 80)
                ),
                SwitchSequential(
                    UNET_ResidualBlock(640, 640), UNET_AttentionBlock(8, 80)
                ),
                SwitchSequential(
                    nn.Conv2d(640, 640, kernel_size=3, stride=2, padding=1)
                ),
                SwitchSequential(
                    UNET_ResidualBlock(640, 1280), UNET_AttentionBlock(8, 160)
                ),
                SwitchSequential(
                    UNET_ResidualBlock(1280, 1280), UNET_AttentionBlock(8, 160)
                ),
                SwitchSequential(
                    nn.Conv2d(1280, 1280, kernel_size=3, stride=2, padding=1)
                ),
                SwitchSequential(UNET_ResidualBlock(1280, 1280)),
                SwitchSequential(UNET_ResidualBlock(1280, 1280)),
            ]
        )

        # SwitchSequential (not a bare ModuleList) so the bottleneck can be
        # called with (x, context, time) like every other stage.
        self.bottleneck = SwitchSequential(
            UNET_ResidualBlock(1280, 1280),
            UNET_AttentionBlock(8, 160),
            UNET_ResidualBlock(1280, 1280),
        )

        # One stage per skip connection, consumed in reverse encoder order.
        # Each input width is (current width + popped skip width).
        self.decoder = nn.ModuleList(
            [
                # 1280 + 1280
                SwitchSequential(UNET_ResidualBlock(2560, 1280)),
                SwitchSequential(UNET_ResidualBlock(2560, 1280)),
                SwitchSequential(UNET_ResidualBlock(2560, 1280), Upsample(1280)),
                SwitchSequential(
                    UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)
                ),
                SwitchSequential(
                    UNET_ResidualBlock(2560, 1280), UNET_AttentionBlock(8, 160)
                ),
                # 1280 + 640
                SwitchSequential(
                    UNET_ResidualBlock(1920, 1280),
                    UNET_AttentionBlock(8, 160),
                    Upsample(1280),
                ),
                # 1280 + 640
                SwitchSequential(
                    UNET_ResidualBlock(1920, 640),
                    UNET_AttentionBlock(8, 80),
                ),
                # 640 + 640
                SwitchSequential(
                    UNET_ResidualBlock(1280, 640),
                    UNET_AttentionBlock(8, 80),
                ),
                # 640 + 320
                SwitchSequential(
                    UNET_ResidualBlock(960, 640),
                    UNET_AttentionBlock(8, 80),
                    Upsample(640),
                ),
                # 640 + 320
                SwitchSequential(
                    UNET_ResidualBlock(960, 320),
                    UNET_AttentionBlock(8, 40),
                ),
                # 320 + 320
                SwitchSequential(
                    UNET_ResidualBlock(640, 320),
                    UNET_AttentionBlock(8, 40),
                ),
                SwitchSequential(
                    UNET_ResidualBlock(640, 320),
                    UNET_AttentionBlock(8, 40),
                ),
            ]
        )

    def forward(self, x, context, time):
        """
        Args:
            x: Input of shape ``(n, in_channels, h, w)``.
            context: Conditioning sequence forwarded to the attention blocks.
            time: Time embedding forwarded to the residual blocks.

        Returns:
            Feature map with 320 channels.
        """
        skip_connections = []
        for layer in self.encoder:
            x = layer(x, context, time)
            skip_connections.append(x)

        x = self.bottleneck(x, context, time)

        for layer in self.decoder:
            # Concatenate the matching encoder output on the channel axis.
            x = torch.cat((x, skip_connections.pop()), dim=1)
            x = layer(x, context, time)

        return x


class UNET_Out(nn.Module):
    """Output head: GroupNorm -> SiLU -> 3x3 convolution to ``out_channels``."""

    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        # 3x3 convolution with padding 1 keeps the spatial size unchanged.
        self.convlayer = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Map ``(n, in_channels, h, w)`` features to ``(n, out_channels, h, w)``."""
        activated = F.silu(self.groupnorm(x))
        return self.convlayer(activated)


class Diffusion(nn.Module):
    """Noise-prediction network: time embedding -> U-Net -> output head."""

    # Width of the sinusoidal timestep embedding fed to the time MLP.
    TIME_EMBED_DIM = 320

    def __init__(
        self,
        n_channels: int = 4
    ) -> None:
        """
        Args:
            n_channels: Channels of both the input and the predicted noise.
        """
        super().__init__()
        self.time = TimeEmbedding(320)
        self.unet = UNET(in_channels=n_channels)
        self.output = UNET_Out(320, n_channels)

    @staticmethod
    def _sinusoidal_embedding(timesteps, dim: int = 320):
        """Encode timesteps as ``[cos(t * f_i), sin(t * f_i)]`` features.

        Args:
            timesteps: Tensor of timesteps; it is flattened to ``(n,)``.
            dim: Embedding width (even).

        Returns:
            Float32 tensor of shape ``(n, dim)``.
        """
        half = dim // 2
        # Geometric frequency ladder from 1 down towards 1 / 10000.
        exponents = torch.arange(half, dtype=torch.float32, device=timesteps.device)
        freqs = torch.exp(-math.log(10000.0) * exponents / half)
        angles = timesteps.reshape(-1, 1).to(torch.float32) * freqs.unsqueeze(0)
        return torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)

    def forward(self, x, context, time):
        """
        Args:
            x: Noisy input of shape ``(n, n_channels, h, w)``.
            context: Conditioning sequence forwarded to the U-Net.
            time: Either raw timesteps (any integer tensor, or a float tensor
                that is not already ``(..., 320)``) or a precomputed float
                embedding of shape ``(n, 320)``.

        Returns:
            Predicted noise with the same shape as ``x``.
        """
        already_embedded = (
            time.is_floating_point()
            and time.dim() >= 2
            and time.shape[-1] == self.TIME_EMBED_DIM
        )
        if not already_embedded:
            # Raw timesteps cannot go through Linear(320, ...) directly.
            time = self._sinusoidal_embedding(time, self.TIME_EMBED_DIM)
        time = time.to(device=x.device, dtype=x.dtype)

        time = self.time(time)

        out = self.unet(x, context, time)

        out = self.output(out)

        return out

## 3. Dataset Loading (CIFAR-10) with Verification

In [None]:
# Preprocessing: resize, convert to a tensor, then normalize each channel
# with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

print("Loading Dataset...")
train_dataset = datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
)

# --- Verification Logic ---
print("\n--- Verifying Dataset ---")
print(f"Total samples: {len(train_dataset)}")

# Pull a single batch and compare its layout with the configuration.
sample_batch, sample_labels = next(iter(train_loader))
print(f"Batch Shape: {sample_batch.shape}")
print(f"Expected Shape: ({BATCH_SIZE}, {CHANNELS}, {IMAGE_SIZE}, {IMAGE_SIZE})")

got_channels, got_height, got_width = sample_batch.shape[1:4]
assert got_channels == CHANNELS, f"Channel mismatch! Expected {CHANNELS}, got {got_channels}"
assert got_height == IMAGE_SIZE, f"Height mismatch! Expected {IMAGE_SIZE}, got {got_height}"
assert got_width == IMAGE_SIZE, f"Width mismatch! Expected {IMAGE_SIZE}, got {got_width}"

print(f"Value Range: Min {sample_batch.min():.4f}, Max {sample_batch.max():.4f}")

if got_channels == 3:
    # Show the first eight images side by side.
    print("Visualizing Sample Images...")
    grid_img = sample_batch[:8]
    # Undo the normalization and move channels last for imshow.
    grid_img = (grid_img * 0.5 + 0.5).permute(0, 2, 3, 1).numpy()

    fig, axes = plt.subplots(1, 8, figsize=(12, 2))
    for ax, picture in zip(axes, grid_img):
        ax.imshow(picture)
        ax.axis('off')
    plt.show()

print("--- Dataset Verification Passed ---")

## 4. Training Loop with Logging and Checkpointing

In [None]:
# Initialize Diffusion Model
print("Initializing Model...")
model = Diffusion(n_channels=CHANNELS).to(DEVICE)

# Smoke-test one forward pass before training starts. It runs under no_grad
# so no autograd graph (and its activation memory) is built for the check.
dummy_input = torch.randn(1, CHANNELS, IMAGE_SIZE, IMAGE_SIZE, device=DEVICE)
dummy_context = torch.zeros(1, 77, 768, device=DEVICE)
dummy_time = torch.tensor([1], device=DEVICE)
try:
    with torch.no_grad():
        model(dummy_input, dummy_context, dummy_time)
    print("Model forward pass check: SUCCESS")
except Exception as e:
    print(f"Model forward pass check: FAILED with error: {e}")
    raise

# Optimizer and Loss
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
criterion = nn.MSELoss()
# Cosine learning-rate decay, stepped once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

# DDPM Sampler
# The generator lives on DEVICE because noise is drawn on that device.
generator = torch.Generator(device=DEVICE)
sampler = DDPMSampler(generator, num_training_steps=N_TIMESTEPS)

# Training
print("Starting Training...")
output_history = {"loss": []}  # Average loss per epoch

try:
    for epoch in range(EPOCHS):
        model.train()
        epoch_loss = 0.0

        # Progress bar for the epoch
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

        for images, _ in pbar:
            images = images.to(DEVICE)
            batch_size = images.shape[0]

            # Sample one random timestep per image
            t = torch.randint(0, N_TIMESTEPS, (batch_size,), device=DEVICE).long()

            # Add noise to images; `noise` is the regression target
            noisy_images, noise = sampler.add_noise(images, t)

            # Unconditional context (null tokens). Built per batch because
            # the last batch of an epoch can be smaller than BATCH_SIZE.
            context = torch.zeros(batch_size, 77, 768, device=DEVICE)

            # Forward pass
            noise_pred = model(noisy_images, context, t)

            # Calculate loss
            loss = criterion(noise_pred, noise)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Read the loss once per step; each .item() call syncs the device.
            loss_value = loss.item()
            epoch_loss += loss_value
            pbar.set_postfix({"Loss": loss_value})

        scheduler.step()
        avg_loss = epoch_loss / len(train_loader)
        output_history["loss"].append(avg_loss)
        print(f"Epoch {epoch+1} Completed. Average Loss: {avg_loss:.4f}")

        # Save Checkpoint Logic
        if (epoch + 1) % SAVE_INTERVAL == 0:
            print(f"Saving checkpoint at epoch {epoch+1}...")
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Needed to resume the LR schedule where it stopped.
                'scheduler_state_dict': scheduler.state_dict(),
                'loss': avg_loss,
            }
            torch.save(checkpoint, f"diffusion_checkpoint_epoch_{epoch+1}.pt")
            print(f"Checkpoint saved: diffusion_checkpoint_epoch_{epoch+1}.pt")

except KeyboardInterrupt:
    print("Training interrupted by user. Saving emergency checkpoint...")
    torch.save(model.state_dict(), "diffusion_emergency_checkpoint.pt")
    print("Emergency checkpoint saved.")
except Exception as e:
    print(f"Error occurred: {e}. Saving emergency checkpoint...")
    torch.save(model.state_dict(), "diffusion_error_checkpoint.pt")
    # Bare raise keeps the original traceback.
    raise

print("Training process finished.")