# Diffusion Model Lab: Building Stable Diffusion from Scratch

(adapted from Harvard seminars)

## Introduction
Welcome to the Diffusion Models Laboratory! In this lab, you will implement essential components
of a Stable Diffusion-like model and have a working text-to-image generation system by the end.

## Learning Objectives
By completing this lab, you will:
1. Understand the theory behind diffusion models
2. Implement key components of a diffusion-based generative model
3. Learn how to train and sample from diffusion models
4. Experience how attention mechanisms enable text-conditioning
5. Explore the efficiency benefits of latent diffusion models

## Lab Overview
This lab will guide you through building a complete diffusion-based generative model with the following components:
1. Forward and reverse diffusion processes for generating new content
2. Neural networks with appropriate inductive biases for images (U-Net architecture)
3. Cross-attention mechanisms to bridge text and visual features
4. Latent spaces for efficient image representation via autoencoders

We'll use the MNIST dataset (28x28 pixel handwritten digits) as our training data to make the
process reasonably fast. By the end, your model will convert a text prompt like "7" into
an image of the handwritten digit 7.

## Lab Structure
This lab is divided into several tasks, each focusing on a different component of diffusion models:
- Task 1: Implement forward and reverse diffusion processes
- Task 2: Implement the U-Net architecture for diffusion
- Task 3: Develop loss functions and sampling mechanisms
- Task 4: Build attention mechanisms for conditioning
- Task 5: Create a latent diffusion model with an autoencoder

For each task, you'll find:
- Theoretical background explaining key concepts
- Starter code with TODOs for you to complete
- Testing functions to verify your implementations
- Visualization helpers to understand what's happening

## Managing computational resources
If you encounter resource limitations (e.g., end of Colab quota or VPN issues):
- Use CPU for training (can be done on Colab or your laptop)
- Implement checkpointing to save model progress (save checkpoints to Google Drive in Colab)
- Load from checkpoints to continue training from where you left off

### Google Drive Integration
When running this lab in Google Colab, you can save your model checkpoints to Google Drive, which prevents
losing your training progress if your Colab session crashes or times out. To enable this:

1. Set `USE_GOOGLE_DRIVE = True` near the top of the code
2. When prompted, authorize the connection to your Google Drive
3. Your checkpoints will be saved to the "diffusion_model_checkpoints" folder in your Drive

You can always resume training from these checkpoints even if you start a new Colab session.

Let's get started with building diffusion models!

In [None]:
# Required library imports
import functools
import os
import time
import math
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, TensorDataset
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from tqdm import tqdm, trange
from einops import rearrange

In [None]:
# Detect if running in Google Colab
try:
    import google.colab
    IN_COLAB = True
    print("Running in Google Colab")
except ImportError:
    IN_COLAB = False
    print("Not running in Google Colab")

# Function to setup Google Drive for checkpoint storage
def setup_checkpoint_dir(use_google_drive=False):
    """
    Sets up the directory for storing model checkpoints.

    Args:
        use_google_drive: If True and running in Colab, mounts Google Drive
                         and creates checkpoint directory there

    Returns:
        Path to the checkpoint directory
    """
    if use_google_drive and IN_COLAB:
        try:
            from google.colab import drive
            # Mount Google Drive
            drive.mount('/content/drive')

            # Create checkpoint directory in Google Drive
            checkpoint_dir = "/content/drive/MyDrive/diffusion_model_checkpoints"
            os.makedirs(checkpoint_dir, exist_ok=True)
            print(f"Checkpoint directory set to Google Drive: {checkpoint_dir}")
            return checkpoint_dir
        except Exception as e:
            print(f"Failed to set up Google Drive for checkpoints: {e}")
            print("Falling back to local checkpoint directory")

    # Use local directory if not using Google Drive or if setup failed
    checkpoint_dir = "model_checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    print(f"Checkpoint directory set to local path: {checkpoint_dir}")
    return checkpoint_dir

In [None]:
# Set up device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
# Optional: Use Google Drive for checkpoints when running in Colab
# Set to True to save checkpoints to Google Drive
USE_GOOGLE_DRIVE = False  # Change to True to use Google Drive for checkpoints

# Create directory for saving models
CHECKPOINT_DIR = setup_checkpoint_dir(use_google_drive=USE_GOOGLE_DRIVE)

## 1. Forward and Reverse Diffusion Processes

The core concept of diffusion models is to gradually add noise to data samples until they become
pure noise, then learn to reverse this process. This allows us to generate new samples by starting
with random noise and applying the learned denoising process.

### Forward Diffusion
In forward diffusion, we gradually add noise to our data samples according to:

x(t + Δt) = x(t) + σ(t) √(Δt) * ε

where:
- σ(t) is the noise strength at time t
- Δt is the step size
- ε ~ N(0, 1) is a standard normal random variable

As t increases, our sample becomes increasingly noisy until it's indistinguishable from pure noise.
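
One consequence of this update rule is easy to check numerically: with constant σ, every step adds independent noise of variance σ²Δt, so after time t the sample variance should be close to σ²t. The sketch below is a standalone, standard-library illustration of that check. The names `simulate_endpoints` and `num_paths` are made up for this example and are not part of the lab's starter code.

```python
import random
import statistics

def simulate_endpoints(sigma, step_size, num_steps, num_paths, seed=0):
    """Run x <- x + sigma * sqrt(dt) * eps for many paths starting at x = 0
    and return the final value of each path."""
    rng = random.Random(seed)
    scale = sigma * step_size ** 0.5
    endpoints = []
    for _ in range(num_paths):
        x = 0.0
        for _ in range(num_steps):
            x += scale * rng.gauss(0.0, 1.0)
        endpoints.append(x)
    return endpoints

sigma, step_size, num_steps = 1.0, 0.1, 100
endpoints = simulate_endpoints(sigma, step_size, num_steps, num_paths=2000)

empirical_var = statistics.pvariance(endpoints)
expected_var = sigma ** 2 * step_size * num_steps  # sigma^2 * t
print(f"empirical variance: {empirical_var:.2f}, expected: {expected_var:.2f}")
```

The two numbers should agree up to sampling error, which shrinks as `num_paths` grows.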

### Task 1.1: Implement the Forward Diffusion Process

In [None]:
# Example noise strength function: constant value regardless of time
def constant_noise_strength(t):
    """Returns a constant noise strength regardless of time."""
    return 1.0

# TODO: Implement the forward diffusion process
def forward_diffusion_1D(initial_sample, noise_strength_fn, initial_time, num_steps, step_size):
    """
    Simulates the forward diffusion process in 1D.

    Args:
        initial_sample: Starting value (scalar)
        noise_strength_fn: Function that returns noise strength at given time
        initial_time: Starting time value
        num_steps: Number of diffusion steps to simulate
        step_size: Size of each time step

    Returns:
        tuple: (trajectory array, time array)
    """
    # Initialize arrays to store the trajectory and time values
    trajectory = np.zeros(num_steps + 1)
    trajectory[0] = initial_sample
    time_values = initial_time + np.arange(num_steps + 1) * step_size

    # Perform Euler-Maruyama steps for the diffusion process
    for i in range(num_steps):
        #  1. Get the current noise strength using noise_strength_fn
        noise_strength = noise_strength_fn(time_values[i])

        ############# YOUR CODE HERE (2 lines)
        #  TODO: 2. Sample random noise from a standard normal distribution
        #  TODO: 3. Update the trajectory according to the forward diffusion equation
        #############

    return trajectory, time_values


# Test the forward diffusion function
def visualize_forward_diffusion():
    """Visualize multiple forward diffusion trajectories starting from the same point."""
    num_steps = 100
    initial_time = 0
    step_size = 0.1
    initial_sample = 0
    num_trajectories = 5

    plt.figure(figsize=(10, 6))
    for i in range(num_trajectories):
        trajectory, time_values = forward_diffusion_1D(
            initial_sample,
            constant_noise_strength,
            initial_time,
            num_steps,
            step_size
        )
        plt.plot(time_values, trajectory)

    plt.xlabel('Time', fontsize=14)
    plt.ylabel('Sample Value', fontsize=14)
    plt.title('Forward Diffusion Process - Multiple Trajectories', fontsize=16)
    plt.grid(True, alpha=0.3)
    plt.show()

### Reverse Diffusion

The reverse diffusion process allows us to generate new data by starting from noise and
progressively denoising it. The reverse process is given by:

x(t + Δt) = x(t) + σ(T-t)² · ∇_x log p(x, T-t) · Δt + σ(T-t) · √(Δt) · ε

The function ∇_x log p(x, t) is called the "score function" and represents the gradient of the
log probability density. If we can learn this score function, we can reverse the diffusion process
to generate data from noise.

For the special case where our initial distribution is concentrated at x₀ = 0 and the noise strength
is constant, the score function has the closed form:

s(x, t) = -x/(σ² · t)
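
To see where this closed form comes from: with x₀ = 0 and constant σ, x(t) is Gaussian with mean 0 and variance σ²t, so log p(x, t) = -x²/(2σ²t) + const, and differentiating with respect to x gives -x/(σ²t). The snippet below is a small standalone check that compares the formula with a finite-difference derivative of the log density. `log_density` and `finite_difference_score` are illustrative helpers, not lab functions.

```python
import math

def log_density(x, sigma, t):
    """Log of the N(0, sigma^2 * t) density, i.e. the law of x(t) when x(0) = 0."""
    var = sigma ** 2 * t
    return -0.5 * math.log(2 * math.pi * var) - x ** 2 / (2 * var)

def closed_form_score(x, sigma, t):
    """Closed-form score -x / (sigma^2 * t)."""
    return -x / (sigma ** 2 * t)

def finite_difference_score(x, sigma, t, h=1e-5):
    """Central-difference approximation of d/dx log p(x, t)."""
    return (log_density(x + h, sigma, t) - log_density(x - h, sigma, t)) / (2 * h)

for x, t in [(0.5, 1.0), (-2.0, 3.0), (4.0, 10.0)]:
    exact = closed_form_score(x, sigma=1.5, t=t)
    approx = finite_difference_score(x, sigma=1.5, t=t)
    print(f"x={x:+.1f}, t={t:4.1f}: closed form {exact:+.6f}, finite difference {approx:+.6f}")
```

Note that the score always points back toward 0, and more strongly at small t, where the distribution is narrow.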

In [None]:
# Score function for the special case
def simple_score_function(x, x0, noise_strength, t):
    """
    Score function for the special case where the initial distribution
    is concentrated at x0 and noise strength is constant.

    Args:
        x: Current sample value
        x0: Initial distribution center (typically 0)
        noise_strength: Current noise strength
        t: Current time

    Returns:
        Score value
    """
    score = -(x - x0) / ((noise_strength**2) * t)

    return score

### Task 1.2: Implement the Reverse Diffusion Process

In [None]:
# TODO: Implement the reverse diffusion process
def reverse_diffusion_1D(initial_sample, noise_strength_fn, score_fn, final_time, num_steps, step_size):
    """
    Simulates the reverse diffusion process in 1D.

    Args:
        initial_sample: Starting value (scalar, typically sampled from noise distribution)
        noise_strength_fn: Function that returns noise strength at given time
        score_fn: Score function that guides the reverse process
        final_time: Total diffusion time (T)
        num_steps: Number of reverse steps to simulate
        step_size: Size of each time step

    Returns:
        tuple: (trajectory array, time array)
    """
    # Initialize arrays to store the trajectory and time values
    trajectory = np.zeros(num_steps + 1)
    trajectory[0] = initial_sample
    time_values = np.arange(num_steps + 1) * step_size

    for i in range(num_steps):
        # 1. Calculate the reverse time (final_time - current_time)
        reverse_time = final_time - time_values[i]
        # 2. Get the current noise strength
        noise_strength = noise_strength_fn(reverse_time)
        # 3. Evaluate the score at the current sample value (not the current time), with x0 = 0
        score = score_fn(trajectory[i], 0, noise_strength, reverse_time)

        ############# YOUR CODE HERE (2 lines) Perform reverse diffusion steps
        # TODO: 4. Generate random noise
        # TODO: 5. Update the trajectory according to the reverse diffusion equation
        #############

    return trajectory, time_values

# Test the reverse diffusion function
def visualize_reverse_diffusion():
    """Visualize multiple reverse diffusion trajectories starting from noise."""
    num_steps = 100
    step_size = 0.1
    final_time = 11
    num_trajectories = 5

    plt.figure(figsize=(10, 6))
    for i in range(num_trajectories):
        # Start from a noisy point: with constant noise strength 1, x(T) has
        # variance T, so its standard deviation is sqrt(T)
        noisy_sample = np.random.normal(loc=0, scale=np.sqrt(final_time))

        trajectory, time_values = reverse_diffusion_1D(
            noisy_sample,
            constant_noise_strength,
            simple_score_function,
            final_time,
            num_steps,
            step_size
        )
        plt.plot(time_values, trajectory)

    plt.xlabel('Time', fontsize=14)
    plt.ylabel('Sample Value', fontsize=14)
    plt.title('Reverse Diffusion Process - Multiple Trajectories', fontsize=16)
    plt.grid(True, alpha=0.3)
    plt.show()

## 2. Neural Network Architecture for Image Diffusion

To effectively apply diffusion models to images, we need neural network architectures that work well
with spatial data. The U-Net architecture is an excellent choice because it:

1. Maintains spatial information through skip connections
2. Processes images at multiple resolutions via downsampling and upsampling
3. Can effectively capture both local and global image features
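
To make the multi-resolution idea concrete, the sketch below tracks the spatial size of a 28x28 MNIST image through four 3x3 convolutions with strides 1, 2, 2, 2 and padding only on the first, which is the encoder configuration used in the starter code of this section. It then shows one possible choice of transposed-convolution settings that retraces those sizes on the way back up. The helper names and the `output_padding` values are assumptions for illustration; the size formulas are the standard ones for dilation 1.

```python
def conv_out_size(size, kernel, stride, padding=0):
    """Output size of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1

def conv_transpose_out_size(size, kernel, stride, padding=0, output_padding=0):
    """Output size of a transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding

# Encoder: (kernel, stride, padding) per layer
encoder = [(3, 1, 1), (3, 2, 0), (3, 2, 0), (3, 2, 0)]
sizes = [28]
for kernel, stride, padding in encoder:
    sizes.append(conv_out_size(sizes[-1], kernel, stride, padding))
print("encoder sizes:", sizes)  # [28, 28, 13, 6, 2]

# One possible decoder: (kernel, stride, padding, output_padding) per layer
decoder = [(3, 2, 0, 1), (3, 2, 0, 0), (3, 2, 0, 1)]
up_sizes = [sizes[-1]]
for kernel, stride, padding, output_padding in decoder:
    up_sizes.append(
        conv_transpose_out_size(up_sizes[-1], kernel, stride, padding, output_padding)
    )
print("decoder sizes:", up_sizes)  # [2, 6, 13, 28]
```

Matching sizes level by level matters because a skip connection combines an encoder feature map with a decoder feature map, which requires equal height and width.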

Since our score function must be time-dependent, we also need a way to condition our network on time.
We'll use sinusoidal time embeddings that convert a scalar time value into a high-dimensional vector.
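
As a framework-free illustration of that idea, the sketch below maps a scalar time t to the vector [sin(2π w₁ t), ..., cos(2π w₁ t), ...] with random Gaussian frequencies wᵢ, which mirrors what a Gaussian Fourier feature layer computes. `fourier_time_embedding` and its arguments are hypothetical names used for this example only.

```python
import math
import random

def fourier_time_embedding(t, frequencies):
    """Embed scalar time t as [sin(2*pi*w*t) for each w] + [cos(2*pi*w*t) for each w]."""
    angles = [2 * math.pi * w * t for w in frequencies]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]

rng = random.Random(0)
embedding_dim, scale = 8, 30.0
# Frequencies are sampled once and then kept fixed.
frequencies = [rng.gauss(0.0, 1.0) * scale for _ in range(embedding_dim // 2)]

emb_early = fourier_time_embedding(0.10, frequencies)
emb_late = fourier_time_embedding(0.11, frequencies)
print(len(emb_early))  # 8
print([round(v, 3) for v in emb_early])
print([round(v, 3) for v in emb_late])
```

Even nearby times tend to differ noticeably in the high-frequency components, which gives the network a richer signal than the raw scalar t.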

### Task 2.1: Implement Time Embedding Modules

In [None]:
# Time embedding modules
class GaussianFourierProjection(nn.Module):
    """
    Gaussian random features for encoding time steps using sinusoidal functions.
    This creates a high-dimensional representation of time that helps the model
    distinguish between different diffusion timesteps.
    """
    def __init__(self, embedding_dim, scale=30.0):
        super().__init__()
        # Randomly sample frequencies during initialization
        # These frequencies are fixed during optimization
        self.weight = nn.Parameter(torch.randn(embedding_dim // 2) * scale, requires_grad=False)

    def forward(self, x):
        """
        Args:
            x: Time values [batch_size]

        Returns:
            Time embeddings [batch_size, embedding_dim]
        """

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)


class TimeEmbedding(nn.Module):
    """
    A dense layer that reshapes time embeddings for injection into convolutional layers.
    This allows time information to be added directly to feature maps.
    """
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """
        Args:
            x: Time embeddings [batch_size, input_dim]

        Returns:
            Reshaped time embeddings [batch_size, output_dim, 1, 1]
            which can be broadcast-added to convolutional feature maps
        """
        # TODO: Apply the dense layer and reshape the output for broadcasting to feature maps
        # The output should have shape [batch_size, output_dim, 1, 1]

        # YOUR CODE HERE
        return None  # Replace with your implementation

### Task 2.2: Implement the U-Net Architecture for Diffusion Models

In [None]:
class DiffusionUNet(nn.Module):
    """
    U-Net architecture for image-based diffusion models with time conditioning.
    """
    def __init__(self, noise_distribution_fn, channels=[32, 64, 128, 256], embed_dim=256):
        """
        Initialize a time-dependent score-based network based on U-Net architecture.

        Args:
            noise_distribution_fn: Function that maps time t to the standard deviation
                                  of the perturbation kernel p(x(t)|x(0))
            channels: Number of channels for feature maps at each resolution
            embed_dim: Dimensionality of time embeddings
        """
        super().__init__()

        # Time embedding layers
        self.time_embed = nn.Sequential(
            GaussianFourierProjection(embedding_dim=embed_dim),
            nn.Linear(embed_dim, embed_dim)
        )

        self.conv_in = nn.Conv2d(1, channels[0], kernel_size=3, stride=1, padding=1, bias=False)
        # Create a TimeEmbedding layer to inject time information
        self.time_embed1 = TimeEmbedding(embed_dim, channels[0])
        self.norm1 = nn.GroupNorm(4, num_channels=channels[0])

        self.conv2 = nn.Conv2d(channels[0], channels[1], kernel_size=3, stride=2, bias=False)
        self.time_embed2 = TimeEmbedding(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])

        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = TimeEmbedding(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])

        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = TimeEmbedding(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # TODO: Define the decoding (upsampling) path with skip connections
        # For each level in the upsampling path:
        # 1. Create a ConvTranspose2d layer
        # 2. Create a TimeEmbedding layer
        # 3. Add normalization

        # YOUR CODE HERE - Decoding Path

        # TODO: Define the output layer

        # Activation function: Swish/SiLU
        self.activation = nn.SiLU()

        # Store noise function for normalization during inference
        self.noise_distribution_fn = noise_distribution_fn

    def forward(self, x, t, class_labels=None):
        """
        Forward pass of the U-Net.

        Args:
            x: Input images [batch_size, 1, height, width]
            t: Diffusion timesteps [batch_size]
            class_labels: Optional class conditioning (not used in base model)

        Returns:
            Predicted score (noise estimate) [batch_size, 1, height, width]
        """
        # TODO: Implement the forward pass through the U-Net
        # 1. Get time embeddings
        # 2. Pass through the encoding path
        # 3. Pass through the decoding path with skip connections
        # 4. Apply the output layer
        # 5. Normalize the output by the noise level

        # YOUR CODE HERE

        # Obtain the Gaussian random feature embedding for t
        temb = self.activation(self.time_embed(t))

        # TODO: Implement encoding path (downsampling)
        h1 = None  # Replace with your implementation
        h2 = None  # Replace with your implementation
        h3 = None  # Replace with your implementation
        h4 = None  # Replace with your implementation

        # TODO: Implement decoding path (upsampling) with skip connections
        h = None  # Replace with your implementation

        # TODO: Normalize output
        return None  # Replace with your implementation

### Task 2.3: Implement a Residual U-Net Variant (Optional - for bonus points)

For advanced students: implement a variant of the U-Net with residual connections
instead of concatenation for skip connections.

In [None]:
class DiffusionUNetResidual(nn.Module):
    """
    Alternative U-Net architecture with residual connections instead of concatenation.
    """
    def __init__(self, noise_distribution_fn, channels=[32, 64, 128, 256], embed_dim=256):
        """
        Initialize a U-Net with residual skip connections instead of concatenation.

        Args:
            noise_distribution_fn: Function that maps time t to the standard deviation
            channels: Number of channels for feature maps at each resolution
            embed_dim: Dimensionality of time embeddings
        """
        super().__init__()

        # TODO: Implement the residual U-Net architecture
        # The structure is similar to the regular U-Net but uses addition
        # instead of concatenation for skip connections

        # YOUR CODE HERE

    def forward(self, x, t, class_labels=None):
        """
        Forward pass of the U-Net with residual connections.

        Args:
            x: Input images [batch_size, 1, height, width]
            t: Diffusion timesteps [batch_size]
            class_labels: Optional class conditioning (not used in this model)

        Returns:
            Predicted score (noise estimate) [batch_size, 1, height, width]
        """
        # TODO: Implement the forward pass with residual connections

        # YOUR CODE HERE

        return None  # Replace with your implementation

## 3. Training the Diffusion Model

Now we'll define the specific diffusion process and the loss function to train our model.
Our forward process follows the stochastic differential equation (SDE):

dx = σ^t dw

where σ is a noise parameter and w is the Wiener process (Brownian motion).
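
Because this SDE has no drift term, x(t) given x(0) is Gaussian with mean x(0) and variance ∫₀ᵗ σ^(2s) ds = (σ^(2t) - 1) / (2 ln σ). The sketch below checks that closed form against a numerical integral using only the standard library. The function names are illustrative and not part of the starter code.

```python
import math

def marginal_variance_closed_form(t, sigma):
    """Variance of x(t) given x(0) for dx = sigma^t dw."""
    return (sigma ** (2 * t) - 1.0) / (2.0 * math.log(sigma))

def marginal_variance_numeric(t, sigma, num_intervals=10_000):
    """Midpoint-rule approximation of the integral of sigma^(2s) over [0, t]."""
    h = t / num_intervals
    return h * sum(sigma ** (2 * (k + 0.5) * h) for k in range(num_intervals))

sigma = 25.0
for t in (0.1, 0.5, 1.0):
    exact = marginal_variance_closed_form(t, sigma)
    approx = marginal_variance_numeric(t, sigma)
    print(f"t={t}: std (closed form) = {math.sqrt(exact):.4f}, "
          f"std (numeric) = {math.sqrt(approx):.4f}")
```

The square root of this variance is the standard deviation of the noise added to a clean image at time t.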

### Task 3.1: Implement the Noise Distribution and Diffusion Functions

In [None]:
# TODO: Implement the marginal noise distribution
def marginal_prob_std(t, sigma=25.0):
    """
    Compute the standard deviation of p_{0t}(x(t) | x(0)).

    Args:
        t: Time values [batch_size]
        sigma: Base noise multiplier

    Returns:
        Standard deviation at time t
    """

    t = torch.as_tensor(t, device=device)
    return torch.sqrt((sigma**(2 * t) - 1.) / 2. / np.log(sigma))


def diffusion_coeff(t, sigma=25.0):
    """
    Compute the diffusion coefficient of the SDE.

    Args:
        t: Time values [batch_size]
        sigma: Base noise multiplier

    Returns:
        Diffusion coefficient at time t
    """
    return torch.as_tensor(sigma**t, device=device)

# Create functions with default parameter
noise_distribution_fn = functools.partial(marginal_prob_std, sigma=25.0)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=25.0)

### Task 3.2: Implement the Diffusion Loss Function

In [None]:
# TODO: Implement the loss function for training the diffusion model
def diffusion_loss_fn(model, images, noise_distribution_fn, epsilon=1e-5):
    """
    Loss function for training score-based generative models.

    Args:
        model: The score model network
        images: Training images [batch_size, 1, height, width]
        noise_distribution_fn: Function computing noise std at time t
        epsilon: Small value for numerical stability

    Returns:
        Average loss value for the batch
    """
    # TODO: Implement the diffusion loss function with the following steps:
    # 1. Sample random timesteps for each image
    # 2. Get the noise standard deviation at the sampled times
    # 3. Generate random noise with the same shape as the input images
    # 4. Add noise to the input images according to the noise schedule
    # 5. Predict the score using the model
    # 6. Calculate the MSE between predicted and actual noise, properly weighted

    # 1. Sample random timesteps in [epsilon, 1] for each image
    random_t = torch.rand(images.shape[0], device=images.device) * (1. - epsilon) + epsilon
    # 2. Get the noise standard deviation at the sampled times
    std = noise_distribution_fn(random_t)

    ##### YOUR CODE HERE
    # TODO: 3. Generate random noise with the same shape as the input images
    z = None  # Replace with your implementation
    # TODO: 4. Add noise to the input images according to the noise schedule
    noisy_images = None  # Replace with your implementation
    # TODO: 5. Predict the score at the noisy images and sampled times
    score = None  # Replace with your implementation

    # 6. Weighted MSE loss: score * std should cancel the injected noise z
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z)**2, dim=(1, 2, 3)))

    return loss

### Task 3.3: Implement the Conditional Loss Function

In [None]:
# TODO: Implement the conditional loss function for class-conditioned generation
def conditional_diffusion_loss_fn(model, images, class_labels, noise_distribution_fn, epsilon=1e-5):
    """
    Loss function for training class-conditioned diffusion models.

    Args:
        model: The score model network
        images: Training images [batch_size, 1, height, width]
        class_labels: Class labels for conditioning [batch_size]
        noise_distribution_fn: Function computing noise std at time t
        epsilon: Small value for numerical stability

    Returns:
        Average loss value for the batch
    """
    # TODO: Implement the conditional diffusion loss function
    # This is similar to the unconditional loss, but passes class_labels to the model

    # YOUR CODE HERE

    return None  # Replace with your implementation

### Task 3.4: Implement the Euler-Maruyama Sampler

In [None]:
def euler_maruyama_sampler(
    score_model,
    noise_distribution_fn,
    diffusion_coeff_fn,
    batch_size=64,
    img_shape=(1, 28, 28),
    num_steps=500,
    device="cuda",
    epsilon=1e-3,
    class_labels=None
):
    """
    Generate samples from the score-based model using Euler-Maruyama solver.

    Args:
        score_model: The trained score model
        noise_distribution_fn: Function mapping time to noise std
        diffusion_coeff_fn: Function mapping time to diffusion coefficient
        batch_size: Number of samples to generate
        img_shape: Shape of a single image
        num_steps: Number of sampling steps
        device: Device to run on ('cuda' or 'cpu')
        epsilon: Small time value for numerical stability
        class_labels: Optional class labels for conditional generation

    Returns:
        Generated samples [batch_size, channels, height, width]
    """

    # Start from pure noise
    t = torch.ones(batch_size, device=device)
    init_noise = torch.randn(batch_size, *img_shape, device=device) * noise_distribution_fn(t)[:, None, None, None]

    # Setup time steps
    time_steps = torch.linspace(1., epsilon, num_steps, device=device)
    step_size = time_steps[0] - time_steps[1]

    # Progressive denoising
    x = init_noise
    with torch.no_grad():
        for time_step in tqdm(time_steps):
            batch_time_step = torch.ones(batch_size, device=device) * time_step
            g = diffusion_coeff_fn(batch_time_step)
            # The model's forward() takes the conditioning as `class_labels`
            score = score_model(x, batch_time_step, class_labels=class_labels)
            mean_x = x + (g**2)[:, None, None, None] * score * step_size
            x = mean_x + torch.sqrt(step_size) * g[:, None, None, None] * torch.randn_like(x)

    return x

# Utility function to save generated samples
def save_samples(samples, filename):
    """
    Save generated samples as an image.

    Args:
        samples: Generated samples [batch_size, channels, height, width]
        filename: Path to save the image
    """
    # Ensure pixel values are in valid range
    samples = samples.clamp(0.0, 1.0)

    # Create a grid of images
    sample_grid = make_grid(samples, nrow=int(np.sqrt(samples.shape[0])))

    # Convert to numpy and transpose dimensions for matplotlib
    sample_np = sample_grid.permute(1, 2, 0).cpu().numpy()

    # Save using matplotlib
    plt.figure(figsize=(10, 10))
    plt.axis('off')
    plt.imshow(sample_np, vmin=0., vmax=1.)
    plt.savefig(filename, bbox_inches='tight', pad_inches=0.1, dpi=300)
    plt.close()

    return sample_np  # Return for display

## 4. Training Utility Functions

Below are utility functions for training diffusion models and saving/loading checkpoints.
These functions include proper checkpointing to allow resuming training sessions.

### Task 4.1: Implement Model Training Function with Checkpointing

In [None]:
# TODO: Implement the diffusion model training function
def train_diffusion_model(
    model,
    dataset,
    noise_distribution_fn,
    num_epochs=50,
    batch_size=256,
    learning_rate=1e-4,
    conditional=False,
    checkpoint_freq=5,
    checkpoint_path=None,
    resume_training=False
):
    """
    Train a diffusion model with checkpointing.

    Args:
        model: The diffusion model to train
        dataset: Dataset for training
        noise_distribution_fn: Function to compute noise std at different times
        num_epochs: Number of epochs to train for
        batch_size: Batch size for training
        learning_rate: Learning rate for the optimizer
        conditional: Whether the model is class-conditional
        checkpoint_freq: How often to save checkpoints (in epochs)
        checkpoint_path: Path to save/load model checkpoints. If None, uses CHECKPOINT_DIR.
        resume_training: Whether to resume from a checkpoint

    Returns:
        Trained model and training history
    """
    # Set default checkpoint path if not provided
    if checkpoint_path is None:
        checkpoint_filename = "conditional_diffusion.pt" if conditional else "diffusion_model.pt"
        checkpoint_path = os.path.join(CHECKPOINT_DIR, checkpoint_filename)

    # TODO: Implement the training function with these steps:
    # 1. Setup data loader
    # 2. Initialize optimizer and learning rate scheduler
    # 3. Initialize tracking variables
    # 4. Resume from checkpoint if requested
    # 5. Implement the training loop:
    #    a. Process batches and compute loss (conditional or not)
    #    b. Backpropagation and optimization
    #    c. Track statistics and update progress
    # 6. Save checkpoints periodically

    # YOUR CODE HERE

    # Setup data loader
    data_loader = None  # Replace with your implementation

    # Initialize optimizer and scheduler
    optimizer = None  # Replace with your implementation
    scheduler = None  # Replace with your implementation

    # Initialize tracking variables
    start_epoch = 0
    training_history = []

    # Resume from checkpoint if requested
    if resume_training and os.path.exists(checkpoint_path):
        # TODO: Implement checkpoint loading
        pass

    # Create full path for checkpoint (parent directories)
    os.makedirs(os.path.dirname(os.path.abspath(checkpoint_path)), exist_ok=True)

    # Setup training loop
    model.train()
    progress_bar = trange(start_epoch, num_epochs)

    for epoch in progress_bar:
        # TODO: Implement epoch loop
        epoch_loss = 0.0
        num_batches = 0

        # TODO: Process batches

        # TODO: Update learning rate

        # TODO: Calculate average loss and update history

        # TODO: Save checkpoint periodically and at the end
        # When implementing the checkpoint saving, use the checkpoint_path variable
        # This ensures compatibility with Google Drive storage when enabled

    # Switch to evaluation mode before returning
    model.eval()
    return model, training_history

### Task 4.2: Implement Model Loading Function

In [None]:
# TODO: Implement the function to load a model from a checkpoint
def load_model_from_checkpoint(
    model_class,
    checkpoint_path,
    noise_distribution_fn,
    map_location=None,
    **model_kwargs
):
    """
    Load a diffusion model from a checkpoint.

    Args:
        model_class: The model class to instantiate
        checkpoint_path: Path to the checkpoint file
        noise_distribution_fn: Function to compute noise distribution
        map_location: Device mapping for loading the checkpoint
        **model_kwargs: Additional arguments for the model constructor

    Returns:
        Loaded model and training history
    """
    # TODO: Implement the model loading function with these steps:
    # 1. Set map_location to the appropriate device
    # 2. Check if checkpoint exists
    # 3. Load the checkpoint
    # 4. Create a model with the same architecture
    # 5. Handle parallel models if needed
    # 6. Load model weights
    # 7. Return the model and training history

    # YOUR CODE HERE

    if map_location is None:
        map_location = device

    # Check if checkpoint exists
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found at {checkpoint_path}")

    # TODO: Load checkpoint and model
    checkpoint = None  # Replace with your implementation
    model = None  # Replace with your implementation

    # TODO: Handle parallel models and load weights

    return model, checkpoint.get('history', [])

## 5. Attention Mechanisms for Conditional Generation

To enable conditional image generation (generating a specific digit based on a text prompt),
we'll enhance our model with attention mechanisms. Attention allows the model to focus on
relevant parts of the input when making predictions, and helps bridge the gap between
text and image modalities.

We'll implement:
1. Cross-Attention: To make the image features attend to text features
2. Self-Attention: To make each part of the image aware of other parts
3. Transformer blocks: Combining attention with feed-forward networks
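
The core computation behind items 1 and 2 is scaled dot-product attention: each query is compared with every key, the scores are softmax-normalized, and the output is the corresponding weighted average of the values. Below is a dependency-free sketch on plain Python lists (single head, no learned projections). The function names are made up for this example, and the actual layers in this lab would operate on batched tensors instead.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values):
    """Scaled dot-product attention.

    queries: list of query vectors (e.g. one per image location)
    keys, values: lists of key/value vectors (e.g. one per text token)
    Returns one output vector per query, plus the attention weights.
    """
    dim = len(keys[0])
    outputs, all_weights = [], []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(dim) for k in keys]
        weights = softmax(scores)
        out = [sum(w * v[j] for w, v in zip(weights, values))
               for j in range(len(values[0]))]
        outputs.append(out)
        all_weights.append(weights)
    return outputs, all_weights

# Two "image" queries attending to three "text" tokens.
queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [[4.0, 0.0], [0.0, 4.0], [0.0, 0.0]]
values = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
outputs, weights = attention(queries, keys, values)
for w in weights:
    print([round(x, 3) for x in w])
```

In cross-attention the queries come from image features and the keys and values from the text embedding; in self-attention all three come from the image features. With a single context token, such as one digit label, the softmax weight is 1 and every query simply receives that token's value vector.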

### Task 5.1: Implement the Text Embedding Layer

In [None]:
# TODO: Implement word embedding for digit labels (0-9)
class TextEmbedding(nn.Module):
    """
    Simple word embedding layer to convert digit labels into vectors.
    """
    def __init__(self, vocab_size=10, embed_dim=256):
        super(TextEmbedding, self).__init__()
        # TODO: Initialize an embedding layer for the vocabulary
        # Add +1 for padding token if needed

        # YOUR CODE HERE
        self.embedding = None  # Replace with your implementation

    def forward(self, text_indices):
        """
        Args:
            text_indices: Integer indices representing text tokens [batch_size]

        Returns:
            Embedded vectors [batch_size, embed_dim]
        """
        # TODO: Apply the embedding layer to the input indices

        # YOUR CODE HERE
        return None  # Replace with your implementation
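
As a point of reference for the shapes involved, the snippet below shows how PyTorch's `nn.Embedding` maps integer labels to vectors. It is a standalone sketch rather than the solution to the task; the sizes mirror the defaults above and the variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Lookup table with one learnable 256-dimensional vector per digit (0-9).
embedding = nn.Embedding(num_embeddings=10, embedding_dim=256)

labels = torch.tensor([7, 3, 7])   # [batch_size]
vectors = embedding(labels)        # [batch_size, embed_dim]

print(vectors.shape)
# Identical labels map to identical vectors.
print(torch.equal(vectors[0], vectors[2]))
```

Cross-attention expects a sequence of context tokens, so the result is typically given a length-1 sequence axis before use, for example `vectors.unsqueeze(1)` with shape `[batch_size, 1, embed_dim]`.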

### Task 5.2: Implement the Cross-Attention Mechanism

In [None]:
# TODO: Implement the cross-attention layer
class CrossAttention(nn.Module):
    """
    Cross-attention layer that allows image features to attend to text features.
    """
    def __init__(self, query_dim, key_dim=None, embed_dim=256, num_heads=1):
        """
        Initialize cross-attention module.

        Args:
            query_dim: Dimension of query vectors
            key_dim: Dimension of key/value vectors (if None, same as query_dim)
            embed_dim: Dimension of the shared space that queries and keys are projected into
            num_heads: Number of attention heads (simplified to 1 for clarity)
        """
        super(CrossAttention, self).__init__()
        self.embed_dim = embed_dim
        self.query_dim = query_dim
        self.key_dim = key_dim if key_dim is not None else query_dim

        # TODO: Initialize linear projections for Q, K, V

        # YOUR CODE HERE
        self.to_q = None  # Replace with your implementation
        self.to_k = None  # Replace with your implementation
        self.to_v = None  # Replace with your implementation

        self.is_self_attention = key_dim is None

    def forward(self, x, context=None):
        """
        Apply attention mechanism.

        Args:
            x: Query tensor [batch_size, seq_len_q, query_dim]
            context: Key/value tensor [batch_size, seq_len_kv, key_dim]
                    (if None, uses self-attention)

        Returns:
            Attended features [batch_size, seq_len_q, query_dim]
        """
        # TODO: Implement the attention mechanism with these steps:
        # 1. Handle self-attention vs cross-attention cases
        # 2. Project inputs to query, key, and value
        # 3. Calculate attention scores (dot product of query and key)
        # 4. Scale scores and apply softmax to get attention weights
        # 5. Apply attention weights to values

        # YOUR CODE HERE

        # Handle self-attention vs cross-attention
        q = None  # Replace with your implementation
        k = None  # Replace with your implementation
        v = None  # Replace with your implementation

        # Calculate attention scores
        attn_scores = None  # Replace with your implementation

        # Scale scores and apply softmax
        attn_weights = None  # Replace with your implementation

        # Apply attention weights to values
        output = None  # Replace with your implementation

        return output
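
Steps 3 to 5 of the TODO list (scores, scaled softmax, weighted sum) can be sketched independently of the module. The function below operates on tensors that are assumed to be already projected to a common dimension; the name `scaled_dot_product_attention` and the example shapes are illustrative choices, not the lab's reference solution.

```python
import math
import torch

torch.manual_seed(0)


def scaled_dot_product_attention(q, k, v):
    """Single-head attention on already-projected tensors.

    q: [batch, seq_len_q, dim]; k, v: [batch, seq_len_kv, dim].
    Returns (output, weights).
    """
    # Similarity of every query with every key: [batch, seq_len_q, seq_len_kv]
    scores = torch.matmul(q, k.transpose(-2, -1))
    # Scale by sqrt(dim) so the softmax does not saturate for large dim.
    weights = torch.softmax(scores / math.sqrt(q.shape[-1]), dim=-1)
    # Weighted average of the values: [batch, seq_len_q, dim]
    return torch.matmul(weights, v), weights


q = torch.randn(2, 49, 64)   # e.g. a 7x7 feature map flattened to 49 tokens
k = torch.randn(2, 1, 64)    # a single text token per sample
v = torch.randn(2, 1, 64)
out, attn = scaled_dot_product_attention(q, k, v)
print(out.shape, attn.shape)
```

With a single context token the softmax weight is always 1, so every query position receives the same value vector; in that setting the conditioning signal comes entirely from the learned value projection.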

### Task 5.3: Implement the Transformer Block

In [None]:
# TODO: Implement a transformer block combining self-attention, cross-attention, and feed-forward networks
class TransformerBlock(nn.Module):
    """
    Transformer block combining self-attention, cross-attention, and feed-forward layers.
    """
    def __init__(self, feature_dim, context_dim=None):
        """
        Initialize a transformer block.

        Args:
            feature_dim: Dimension of input features
            context_dim: Dimension of context features for cross-attention
        """
        super(TransformerBlock, self).__init__()

        # TODO: Initialize transformer components:
        # 1. Self-attention layer
        # 2. Cross-attention layer (if context_dim is provided)
        # 3. Normalization layers
        # 4. Feed-forward network

        # YOUR CODE HERE
        self.self_attn = None  # Replace with your implementation
        self.cross_attn = None  # Replace with your implementation

        self.norm1 = None  # Replace with your implementation
        self.norm2 = None  # Replace with your implementation
        self.norm3 = None  # Replace with your implementation

        self.ff_network = None  # Replace with your implementation

    def forward(self, x, context=None):
        """
        Apply transformer block.

        Args:
            x: Input features [batch_size, seq_len, feature_dim]
            context: Optional context features for cross-attention
                    [batch_size, context_len, context_dim]

        Returns:
            Transformed features [batch_size, seq_len, feature_dim]
        """
        # TODO: Implement the transformer block forward pass:
        # 1. Apply self-attention with residual connection
        # 2. Apply cross-attention with residual connection (if context is provided)
        # 3. Apply feed-forward network with residual connection

        # YOUR CODE HERE

        return x
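
The residual pattern in the forward-pass steps can be illustrated with the feed-forward sublayer alone. This sketch assumes a pre-normalization layout (LayerNorm applied before the sublayer), a 4x hidden expansion, and a GELU activation, which are common choices but not prescribed by the lab; `FeedForwardSublayer` is a hypothetical name.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class FeedForwardSublayer(nn.Module):
    """Pre-norm residual sublayer: x + FFN(LayerNorm(x))."""

    def __init__(self, feature_dim, expansion=4):
        super().__init__()
        self.norm = nn.LayerNorm(feature_dim)
        self.ff_network = nn.Sequential(
            nn.Linear(feature_dim, expansion * feature_dim),
            nn.GELU(),
            nn.Linear(expansion * feature_dim, feature_dim),
        )

    def forward(self, x):
        # The residual connection keeps the input shape unchanged.
        return x + self.ff_network(self.norm(x))


block = FeedForwardSublayer(feature_dim=64)
tokens = torch.randn(2, 49, 64)   # [batch_size, seq_len, feature_dim]
print(block(tokens).shape)
```

The self-attention and cross-attention steps follow the same shape: normalize, apply the sublayer, and add the result back to `x`.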

### Task 5.4: Implement the Spatial Transformer for 2D Feature Maps

In [None]:
# TODO: Implement a transformer for spatial (image) data
class SpatialTransformer(nn.Module):
    """
    Transformer module for spatial data (images) that reshapes 2D feature maps
    into sequences, applies transformer operations, and reshapes back.
    """
    def __init__(self, feature_dim, context_dim=None):
        """
        Initialize spatial transformer.

        Args:
            feature_dim: Number of feature channels
            context_dim: Dimension of context features
        """
        super(SpatialTransformer, self).__init__()

        # TODO: Initialize a transformer block
        self.transformer = None  # Replace with your implementation

    def forward(self, x, context=None):
        """
        Apply spatial transformer to image features.

        Args:
            x: Image features [batch_size, channels, height, width]
            context: Context features [batch_size, context_len, context_dim]

        Returns:
            Transformed image features [batch_size, channels, height, width]
        """
        # TODO: Implement spatial transformer forward pass:
        # 1. Save input and extract dimensions
        # 2. Reshape spatial dimensions into sequence length
        # 3. Apply transformer block
        # 4. Reshape back to spatial representation
        # 5. Add residual connection

        # YOUR CODE HERE

        # Save input and extract dimensions
        orig_x = x
        batch, channels, height, width = x.shape

        # Reshape spatial dimensions into sequence length
        # [batch, channels, height, width] -> [batch, height*width, channels]
        x = None  # Replace with your implementation

        # Apply transformer
        x = None  # Replace with your implementation

        # Reshape back to spatial representation
        # [batch, height*width, channels] -> [batch, channels, height, width]
        x = None  # Replace with your implementation

        # Add residual connection
        return None  # Replace with your implementation
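
The two reshapes described in the comments are easy to get subtly wrong, so here is a small standalone check of the round trip. The tensor sizes are arbitrary and chosen only for illustration.

```python
import torch

torch.manual_seed(0)

x = torch.randn(2, 32, 7, 7)                 # [batch, channels, height, width]
batch, channels, height, width = x.shape

# [batch, channels, height, width] -> [batch, height*width, channels]
tokens = x.flatten(2).transpose(1, 2)

# [batch, height*width, channels] -> [batch, channels, height, width]
restored = tokens.transpose(1, 2).reshape(batch, channels, height, width)

print(tokens.shape)
print(torch.equal(restored, x))
```

A plain `x.reshape(batch, height * width, channels)` would produce the right shape but mix up channels and positions, which is why the transpose is needed.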

### Task 5.5: Implement the Conditional Diffusion U-Net Model

In [None]:
# TODO: Implement a U-Net with cross-attention for text-conditioned image generation
class ConditionalDiffusionUNet(nn.Module):
    """
    U-Net model with cross-attention for text-conditioned image generation.
    """
    def __init__(self, noise_distribution_fn, channels=[32, 64, 128, 256],
                 embed_dim=256, text_dim=256, num_classes=10):
        """
        Initialize conditional U-Net model with attention mechanisms.

        Args:
            noise_distribution_fn: Function that maps time t to noise standard deviation
            channels: List of channel dimensions for each resolution level
            embed_dim: Dimension of time embeddings
            text_dim: Dimension of text embeddings
            num_classes: Number of possible classes (digits 0-9 for MNIST)
        """
        super(ConditionalDiffusionUNet, self).__init__()

        # TODO: Initialize model components:
        # 1. Time embedding
        # 2. Text embedding
        # 3. Encoder (downsampling) path with attention blocks
        # 4. Decoder (upsampling) path with attention blocks and skip connections
        # 5. Output layer

        # YOUR CODE HERE

        # Time embedding
        self.time_embed = None  # Replace with your implementation

        # Text embedding for digit conditioning
        self.text_embedding = None  # Replace with your implementation

        # TODO: Implement encoding (downsampling) path

        # TODO: Implement decoding (upsampling) path with skip connections

        # Activation function
        self.activation = nn.SiLU()

        # Store noise function for normalization
        self.noise_distribution_fn = noise_distribution_fn

    def forward(self, x, t, class_labels=None):
        """
        Forward pass including text conditioning via cross-attention.

        Args:
            x: Input noisy images [batch_size, 1, height, width]
            t: Diffusion timesteps [batch_size]
            class_labels: Class indices for conditioning [batch_size]

        Returns:
            Predicted score (noise estimate) [batch_size, 1, height, width]
        """
        # TODO: Implement the conditional forward pass with these steps:
        # 1. Get time embedding
        # 2. Get text conditioning
        # 3. Apply encoding path with attention to text features
        # 4. Apply decoding path with attention and skip connections
        # 5. Generate output and normalize by noise level

        # YOUR CODE HERE

        # Get time embedding
        time_emb = None  # Replace with your implementation

        # Get text conditioning
        text_emb = None  # Replace with your implementation

        # TODO: Implement encoding path with attention

        # TODO: Implement decoding path with attention and skip connections

        # TODO: Generate and normalize output

        return None  # Replace with your implementation
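
The last step, normalizing by the noise level, relies on broadcasting a per-sample scalar over an image-shaped tensor. Below is a small sketch of that step in isolation. `example_noise_std` is a hypothetical stand-in for the lab's `noise_distribution_fn`, whose actual form is defined earlier in the lab.

```python
import torch

torch.manual_seed(0)


def example_noise_std(t):
    """Hypothetical noise schedule used only for this illustration."""
    return 0.1 + t


raw_output = torch.randn(4, 1, 28, 28)        # network output before scaling
t = torch.tensor([0.1, 0.4, 0.7, 1.0])        # one timestep per sample

# Add three singleton axes so [batch] broadcasts against [batch, 1, H, W].
std = example_noise_std(t)[:, None, None, None]
score = raw_output / std

print(std.shape, score.shape)
```

Without the extra axes, dividing a `[4, 1, 28, 28]` tensor by a `[4]` tensor would fail to broadcast, because PyTorch aligns shapes from the trailing dimension.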

## 6. Latent Diffusion with Autoencoders

Instead of performing diffusion directly in pixel space, we can operate in a compressed latent space.
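This approach, introduced in the latent diffusion paper that underlies Stable Diffusion (Rombach et al., 2022), has several advantages:

1. Computational efficiency: The U-Net processes far fewer values per image, so training and sampling are faster
2. Better modeling: The autoencoder absorbs fine pixel-level detail, so the diffusion model can focus on semantic structure
3. Better quality: For a fixed compute budget, this can lead to higher-quality generation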

We'll implement an autoencoder to compress MNIST images to a latent space, then
perform diffusion in this compressed representation.

### Task 6.1: Implement the Autoencoder Architecture

In [None]:
# TODO: Implement the autoencoder for latent space compression
class Autoencoder(nn.Module):
    """
    Autoencoder for compressing images to a latent representation and reconstructing them.
    """
    def __init__(self, channels=[8, 16, 32]):
        """
        Initialize autoencoder with encoder and decoder components.

        Args:
            channels: Channel dimensions for feature maps at different levels
        """
        super().__init__()

        # TODO: Implement encoder network to compress image to latent representation
        # The encoder should contain:
        # 1. Multiple convolutional layers with downsampling
        # 2. Batch normalization
        # 3. Activation functions

        # YOUR CODE HERE
        self.encoder = None  # Replace with your implementation

        # TODO: Implement decoder network to reconstruct image from latent representation
        # The decoder should contain:
        # 1. Multiple transpose convolutional layers with upsampling
        # 2. Batch normalization
        # 3. Activation functions
        # 4. Final output activation (e.g., sigmoid for 0-1 normalized images)

        # YOUR CODE HERE
        self.decoder = None  # Replace with your implementation

    def forward(self, x):
        """
        Full autoencoder forward pass: encode then decode.

        Args:
            x: Input images [batch_size, 1, height, width]

        Returns:
            Reconstructed images [batch_size, 1, height, width]
        """
        # TODO: Implement the autoencoder forward pass:
        # 1. Encode input to latent representation
        # 2. Decode latent back to image
        # 3. Handle any size mismatch (due to transpose convolutions)

        # YOUR CODE HERE

        # Encode input to latent representation
        latent = None  # Replace with your implementation

        # Decode latent back to image
        reconstructed = None  # Replace with your implementation

        # Handle potential size mismatch
        if reconstructed.shape != x.shape:
            # TODO: Implement center cropping to match original size
            pass

        return reconstructed

    def encode(self, x):
        """Encode input to latent representation."""
        return self.encoder(x)

    def decode(self, latent):
        """Decode latent representation to image."""
        return self.decoder(latent)
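
The size mismatch mentioned in the forward pass comes from the way strided convolutions round. The sketch below traces a 28-pixel side through three downsampling and three upsampling layers using the standard output-size formulas for `nn.Conv2d` and `nn.ConvTranspose2d` (with dilation 1). The kernel size, stride, padding, and `output_padding` are assumptions chosen for illustration, not the lab's required settings.

```python
def conv_out(size, kernel=3, stride=2, padding=1):
    """Output side length of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel=3, stride=2, padding=1, output_padding=1):
    """Output side length of a transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + kernel + output_padding


side = 28
down = [side]
for _ in range(3):
    down.append(conv_out(down[-1]))

up = [down[-1]]
for _ in range(3):
    up.append(conv_transpose_out(up[-1]))

# Offset for a center crop back to the original side length.
crop_start = (up[-1] - side) // 2

print(down)        # [28, 14, 7, 4]
print(up)          # [4, 8, 16, 32]
print(crop_start)  # 2
```

Under these assumptions the decoder would return 32x32 images, and slicing `reconstructed[..., 2:30, 2:30]` would recover the 28x28 center.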

### Task 6.2: Implement the Latent Diffusion Model

In [None]:
# TODO: Implement the latent diffusion model
class LatentDiffusionModel(nn.Module):
    """
    Diffusion model operating in the latent space of an autoencoder.
    """
    def __init__(self, autoencoder, noise_distribution_fn,
                 latent_channels=32, hidden_channels=[64, 128, 256],
                 embed_dim=256, text_dim=256, num_classes=10):
        """
        Initialize latent diffusion model.

        Args:
            autoencoder: Trained autoencoder model
            noise_distribution_fn: Function mapping time to noise std
            latent_channels: Number of channels in autoencoder latent space
            hidden_channels: Channel dimensions for U-Net layers
            embed_dim: Dimension of time embeddings
            text_dim: Dimension of text/class embeddings
            num_classes: Number of possible conditioning classes
        """
        super().__init__()

        # TODO: Initialize the latent diffusion model:
        # 1. Store the autoencoder (frozen)
        # 2. Create a conditional diffusion U-Net for the latent space

        # YOUR CODE HERE

        # Store autoencoder (should be pre-trained and frozen)
        self.autoencoder = autoencoder
        # TODO: Freeze the autoencoder parameters

        # Complete channel specification for U-Net
        channels = [latent_channels] + hidden_channels

        # Use conditional U-Net for latent diffusion
        self.diffusion_model = None  # Replace with your implementation

    def forward(self, x, t, class_labels=None):
        """
        Forward pass for training.

        Args:
            x: Input latent representations [batch_size, latent_channels, h, w]
            t: Diffusion timesteps [batch_size]
            class_labels: Optional conditioning labels [batch_size]

        Returns:
            Predicted score/noise
        """
        # TODO: Implement the forward pass (just using the U-Net on latent inputs)

        # YOUR CODE HERE
        return None  # Replace with your implementation

    def encode_images(self, images):
        """
        Encode images to latent representations.

        Args:
            images: Input images [batch_size, 1, height, width]

        Returns:
            Latent representations [batch_size, latent_channels, h', w']
        """
        # TODO: Implement image encoding using the autoencoder

        # YOUR CODE HERE
        return None  # Replace with your implementation

    def decode_latents(self, latents):
        """
        Decode latent representations to images.

        Args:
            latents: Latent representations [batch_size, latent_channels, h', w']

        Returns:
            Reconstructed images [batch_size, 1, height, width]
        """
        # TODO: Implement latent decoding using the autoencoder

        # YOUR CODE HERE
        return None  # Replace with your implementation

    def generate(self, num_samples, class_labels=None, num_steps=500, latent_shape=None):
        """
        Generate images from noise via latent diffusion.

        Args:
            num_samples: Number of images to generate
            class_labels: Optional conditioning labels [num_samples]
            num_steps: Number of diffusion sampling steps
            latent_shape: Shape of latent representations

        Returns:
            Generated images [num_samples, 1, height, width]
        """
        # TODO: Implement image generation with these steps:
        # 1. Determine latent shape (if not provided)
        # 2. Generate latents using the diffusion model
        # 3. Decode latents to images

        # YOUR CODE HERE

        # Determine latent shape if not provided
        if latent_shape is None:
            # Default latent shape for MNIST with our encoder
            latent_shape = None  # Replace with your implementation

        # Generate latents using the diffusion model
        latents = None  # Replace with your implementation

        # Decode latents to images
        images = None  # Replace with your implementation

        return images
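
For the "freeze the autoencoder parameters" TODO, one common pattern is to disable gradients on every parameter and switch the module to evaluation mode. The sketch below uses a tiny stand-in network (hypothetical, for illustration) in place of the trained autoencoder.

```python
import torch
import torch.nn as nn

# Small stand-in for a pre-trained autoencoder (hypothetical, for illustration).
autoencoder = nn.Sequential(nn.Conv2d(1, 8, 3, padding=1), nn.BatchNorm2d(8))

# Stop gradients from flowing into the autoencoder weights...
for param in autoencoder.parameters():
    param.requires_grad_(False)
# ...and make batch norm use its running statistics.
autoencoder.eval()

trainable = sum(p.numel() for p in autoencoder.parameters() if p.requires_grad)
print(trainable)
```

Note that calling `.train()` on a parent module switches its children back to training mode, so a frozen autoencoder stored as a submodule may need `.eval()` again after such a call.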

## 7. Putting It All Together: Training and Evaluation

In this section, you'll use the provided functions below to train and evaluate your diffusion models.


In [None]:
# These functions are provided. You can modify them for your needs or create new ones.

def train_and_evaluate_basic_model(train_model=False):
    """Train and evaluate a basic unconditioned diffusion model."""
    print("\n1. Basic diffusion model (unconditioned)...\n")

    # Load MNIST dataset
    transform = transforms.ToTensor()
    mnist_train = MNIST('.', train=True, transform=transform, download=True)
    mnist_test = MNIST('.', train=False, transform=transform, download=True)

    # Model path
    model_path = os.path.join(CHECKPOINT_DIR, "basic_diffusion.pt")

    if train_model:
        # Create and train the model
        basic_model = torch.nn.DataParallel(DiffusionUNet(noise_distribution_fn=noise_distribution_fn))
        basic_model = basic_model.to(device)

        print("Training basic diffusion model...")
        basic_model, history = train_diffusion_model(
            basic_model,
            mnist_train,
            noise_distribution_fn,
            num_epochs=30,
            batch_size=256,
            learning_rate=1e-4,
            checkpoint_path=model_path,
            resume_training=False
        )

        # Plot training loss
        plt.figure(figsize=(10, 6))
        plt.plot([h['epoch'] for h in history], [h['loss'] for h in history])
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss')
        plt.savefig(os.path.join(CHECKPOINT_DIR, "basic_model_loss.png"))
        plt.close()

    try:
        # Load model from checkpoint
        print(f"Loading model from {model_path}")
        basic_model, _ = load_model_from_checkpoint(
            DiffusionUNet,
            model_path,
            noise_distribution_fn
        )

        # Generate samples
        print("Generating samples...")
        samples = euler_maruyama_sampler(
            basic_model,
            noise_distribution_fn,
            diffusion_coeff_fn,
            batch_size=25,
            num_steps=200,
            device=device
        )

        # Save samples
        save_path = os.path.join(CHECKPOINT_DIR, "basic_model_samples.png")
        save_samples(samples, save_path)
        print(f"Samples saved to {save_path}")

        return basic_model

    except FileNotFoundError:
        print(f"Model checkpoint not found at {model_path}. Train the model first with train_model=True.")
        return None


def train_and_evaluate_conditional_model(train_model=False):
    """Train and evaluate a conditional diffusion model with attention."""
    print("\n2. Conditional diffusion model with attention...\n")

    # Load MNIST dataset
    transform = transforms.ToTensor()
    mnist_train = MNIST('.', train=True, transform=transform, download=True)
    mnist_test = MNIST('.', train=False, transform=transform, download=True)

    # Model path
    model_path = os.path.join(CHECKPOINT_DIR, "conditional_diffusion.pt")

    if train_model:
        # Create and train the model
        cond_model = torch.nn.DataParallel(ConditionalDiffusionUNet(
            noise_distribution_fn=noise_distribution_fn,
            channels=[32, 64, 128, 256],
            embed_dim=256,
            text_dim=128,
            num_classes=10
        ))
        cond_model = cond_model.to(device)

        print("Training conditional diffusion model...")
        cond_model, history = train_diffusion_model(
            cond_model,
            mnist_train,
            noise_distribution_fn,
            num_epochs=50,
            batch_size=256,
            learning_rate=1e-4,
            conditional=True,
            checkpoint_path=model_path,
            resume_training=False
        )

        # Plot training loss
        plt.figure(figsize=(10, 6))
        plt.plot([h['epoch'] for h in history], [h['loss'] for h in history])
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss (Conditional Model)')
        plt.savefig(os.path.join(CHECKPOINT_DIR, "conditional_model_loss.png"))
        plt.close()

    try:
        # Load model from checkpoint
        print(f"Loading model from {model_path}")
        cond_model, _ = load_model_from_checkpoint(
            ConditionalDiffusionUNet,
            model_path,
            noise_distribution_fn,
            channels=[32, 64, 128, 256],
            embed_dim=256,
            text_dim=128,
            num_classes=10
        )

        # Generate samples for each digit
        for digit in range(10):
            print(f"Generating samples for digit {digit}...")

            # Create labels tensor for the digit
            labels = torch.ones(25, dtype=torch.long, device=device) * digit

            # Generate samples
            samples = euler_maruyama_sampler(
                cond_model,
                noise_distribution_fn,
                diffusion_coeff_fn,
                batch_size=25,
                num_steps=200,
                device=device,
                class_labels=labels
            )

            # Save samples
            save_path = os.path.join(CHECKPOINT_DIR, f"conditional_model_samples_digit{digit}.png")
            save_samples(samples, save_path)
            print(f"Samples for digit {digit} saved to {save_path}")

        return cond_model

    except FileNotFoundError:
        print(f"Model checkpoint not found at {model_path}. Train the model first with train_model=True.")
        return None


def train_and_evaluate_latent_model(train_model=False, train_autoencoder=False):
    """Train and evaluate a latent diffusion model."""
    print("\n3. Latent diffusion model with autoencoder...\n")

    # Load MNIST dataset
    transform = transforms.ToTensor()
    mnist_train = MNIST('.', train=True, transform=transform, download=True)
    mnist_test = MNIST('.', train=False, transform=transform, download=True)

    # Paths for saving models
    ae_path = os.path.join(CHECKPOINT_DIR, "autoencoder.pt")
    latent_model_path = os.path.join(CHECKPOINT_DIR, "latent_diffusion.pt")

    # First train or load the autoencoder
    autoencoder = Autoencoder(channels=[8, 16, 32]).to(device)

    if train_autoencoder:
        # Define loss function (MSE) and optimizer
        print("Training autoencoder...")
        mse_loss = nn.MSELoss()
        optimizer = Adam(autoencoder.parameters(), lr=1e-3)

        # Setup data loader
        data_loader = DataLoader(mnist_train, batch_size=256, shuffle=True, num_workers=4)

        # Training loop
        num_epochs = 20
        for epoch in range(num_epochs):
            epoch_loss = 0.0
            num_samples = 0

            # Process batches
            batch_progress = tqdm(data_loader, desc=f"AE Epoch {epoch+1}/{num_epochs}", leave=False)
            for x, _ in batch_progress:
                x = x.to(device)

                # Forward pass
                x_recon = autoencoder(x)

                # Calculate loss
                loss = mse_loss(x_recon, x)

                # Backward and optimize
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # Track statistics
                epoch_loss += loss.item() * x.shape[0]
                num_samples += x.shape[0]

                # Update batch progress
                batch_progress.set_postfix({"Batch Loss": f"{loss.item():.6f}"})

            # Print epoch results
            avg_loss = epoch_loss / num_samples
            print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.6f}")

        # Save trained autoencoder
        torch.save(autoencoder.state_dict(), ae_path)
        print(f"Autoencoder saved to {ae_path}")

        # Visualize reconstructions (eval mode so batch norm uses running statistics)
        autoencoder.eval()
        with torch.no_grad():
            test_samples = next(iter(DataLoader(mnist_test, batch_size=16)))[0].to(device)
            reconstructions = autoencoder(test_samples)

            # Create comparison grid
            comparison = torch.cat([test_samples, reconstructions])
            save_path = os.path.join(CHECKPOINT_DIR, "autoencoder_reconstructions.png")
            save_samples(comparison, save_path)
            print(f"Reconstruction examples saved to {save_path}")
    else:
        # Load pre-trained autoencoder
        try:
            autoencoder.load_state_dict(torch.load(ae_path, map_location=device))
            print(f"Loaded autoencoder from {ae_path}")
        except FileNotFoundError:
            print(f"Autoencoder checkpoint not found at {ae_path}. Train it first with train_autoencoder=True.")
            return None

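    # Use eval mode so batch-norm layers rely on running statistics for encoding and decoding
    autoencoder.eval()

    # Create latent dataset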
    if train_model:
        print("Creating latent dataset...")
        data_loader = DataLoader(mnist_train, batch_size=256, shuffle=False, num_workers=4)

        all_latents = []
        all_labels = []

        with torch.no_grad():
            for x, y in tqdm(data_loader, desc="Encoding dataset to latents"):
                x = x.to(device)
                latents = autoencoder.encode(x)
                all_latents.append(latents.cpu())
                all_labels.append(y)

        latent_data = torch.cat(all_latents, dim=0)
        label_data = torch.cat(all_labels, dim=0)

        # Create dataset
        latent_dataset = TensorDataset(latent_data, label_data)

        # Get latent dimensions for model creation
        sample_latent_shape = all_latents[0].shape[1:]
        latent_channels = sample_latent_shape[0]

        # Create and train latent diffusion model
        latent_model = LatentDiffusionModel(
            autoencoder=autoencoder,
            noise_distribution_fn=noise_distribution_fn,
            latent_channels=latent_channels,
            hidden_channels=[64, 128, 256],
            embed_dim=256,
            text_dim=128,
            num_classes=10
        ).to(device)

        print("Training latent diffusion model...")
        # We train just the diffusion part, not the autoencoder
        # We're using the same training function but passing the diffusion model directly
        latent_diffusion, history = train_diffusion_model(
            latent_model.diffusion_model,
            latent_dataset,
            noise_distribution_fn,
            num_epochs=50,
            batch_size=256,
            learning_rate=1e-4,
            conditional=True,
            checkpoint_path=latent_model_path,
            resume_training=False
        )

        # Update the diffusion model part of latent_model
        latent_model.diffusion_model = latent_diffusion

        # Plot training loss
        plt.figure(figsize=(10, 6))
        plt.plot([h['epoch'] for h in history], [h['loss'] for h in history])
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training Loss (Latent Diffusion)')
        plt.savefig(os.path.join(CHECKPOINT_DIR, "latent_model_loss.png"))
        plt.close()

    try:
        # Load model from checkpoint
        print(f"Loading diffusion model from {latent_model_path}")
        # First get the latent dimensions
        with torch.no_grad():
            sample_x = torch.zeros(1, 1, 28, 28, device=device)
            sample_latent = autoencoder.encode(sample_x)
            latent_channels = sample_latent.shape[1]
            latent_shape = (latent_channels, sample_latent.shape[2], sample_latent.shape[3])

        # Create a latent diffusion model
        latent_model = LatentDiffusionModel(
            autoencoder=autoencoder,
            noise_distribution_fn=noise_distribution_fn,
            latent_channels=latent_channels,
            hidden_channels=[64, 128, 256],
            embed_dim=256,
            text_dim=128,
            num_classes=10
        ).to(device)

        # Load just the diffusion part
        checkpoint = torch.load(latent_model_path, map_location=device)
        latent_model.diffusion_model.load_state_dict(checkpoint['model_state_dict'])

        # Generate samples for each digit
        for digit in range(10):
            print(f"Generating samples for digit {digit} using latent diffusion...")

            # Generate using the helper method that handles the entire pipeline
            samples = latent_model.generate(
                num_samples=25,
                class_labels=torch.ones(25, dtype=torch.long, device=device) * digit,
                num_steps=200,
                latent_shape=latent_shape
            )

            # Save samples
            save_path = os.path.join(CHECKPOINT_DIR, f"latent_model_samples_digit{digit}.png")
            save_samples(samples, save_path)
            print(f"Samples for digit {digit} saved to {save_path}")

        return latent_model

    except FileNotFoundError:
        print(f"Latent diffusion model checkpoint not found at {latent_model_path}. Train it first with train_model=True.")
        return None

### Task 7: Run and Experiment

After implementing all the required components, you can run your models and experiment with different
parameters to observe their effects on the generated images.

In [None]:
if __name__ == "__main__":
    # Choose which models to train and evaluate
    # Set parameters to True to train the models, or False to load pre-trained models

    # TODO: Uncomment the models you want to train/evaluate

    # 1. Basic diffusion model (no conditioning)
    # train_and_evaluate_basic_model(train_model=True)

    # 2. Conditional diffusion model with attention
    # train_and_evaluate_conditional_model(train_model=True)

    # 3. Latent diffusion model
    # train_and_evaluate_latent_model(train_autoencoder=True, train_model=True)

    # Or just evaluate pre-trained models:
    # train_and_evaluate_basic_model(train_model=False)
    # train_and_evaluate_conditional_model(train_model=False)
    # train_and_evaluate_latent_model(train_model=False, train_autoencoder=False)

    print("\nDiffusion Model Lab completed! Check the 'model_checkpoints' directory for results.")

## Lab Extensions (Optional)

If you've completed all the tasks and want to explore further, you can try these extensions:

1. Compare the quality and training times of the different approaches
2. Experiment with different noise schedules and model architectures
3. Apply the model to a different dataset (e.g., FashionMNIST)
4. Implement classifier-free guidance for improved conditional generation
5. Add more control to the generation process (e.g., controlling digit size or style)
