# Week 9, Day 3: Diffusion Models

## Learning Objectives
- Understand the core concepts behind diffusion models
- Learn the forward (noising) and reverse (denoising) processes
- Master common diffusion-model architectures
- Practice implementing diffusion models

## Topics Covered
1. Diffusion Process
2. Denoising Models
3. Score-Based Models
4. Advanced Architectures

In [None]:
# Import required libraries
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers

## 1. Basic Diffusion Model

In [None]:
class DiffusionModel:
    """Minimal DDPM-style diffusion model for single-channel 32x32 images.

    A linear beta schedule defines the forward (noising) process. A small
    U-Net is trained to predict the noise ``eps`` that was added to an image.

    NOTE(review): the U-Net is *not* conditioned on the timestep ``t``. A
    full DDPM feeds a timestep embedding into the network. Adding one would
    change the model's input signature that ``train_diffusion_model`` relies
    on, so it is left as a TODO.
    """

    # Spatial size must be divisible by 8 (three 2x poolings in the U-Net).
    image_shape = (32, 32, 1)

    def __init__(self, timesteps=1000, beta_start=0.0001, beta_end=0.02):
        """Create the noise schedule and build the denoising network.

        Args:
            timesteps: number of diffusion steps T.
            beta_start: variance beta_0 of the first forward step.
            beta_end: variance beta_{T-1} of the last forward step.
        """
        self.timesteps = timesteps

        # Compute in float64 for an accurate cumulative product, then store as
        # float32. tf.gather() on these arrays then yields float32 tensors that
        # match the images / model outputs. TensorFlow does not implicitly
        # promote float64 * float32, so a float64 schedule would raise.
        beta = np.linspace(beta_start, beta_end, timesteps, dtype=np.float64)
        alpha = 1. - beta
        alpha_bar = np.cumprod(alpha)
        self.beta = beta.astype(np.float32)
        self.alpha = alpha.astype(np.float32)
        self.alpha_bar = alpha_bar.astype(np.float32)

        # Build model
        self.model = self.build_model()

    def build_model(self):
        """Build a small U-Net that maps an image to a same-shaped noise estimate.

        Encoder: three Conv-BN-ReLU stages (64/128/256 filters). Each stage
        keeps its output as a skip connection before MaxPooling.
        Decoder: mirrors the encoder with UpSampling and skip concatenation.
        """
        inputs = layers.Input(shape=self.image_shape)
        x = inputs

        # Encoder
        skips = []
        for filters in [64, 128, 256]:
            x = layers.Conv2D(filters, 3, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.ReLU()(x)
            skips.append(x)
            x = layers.MaxPooling2D()(x)

        # Middle
        x = layers.Conv2D(512, 3, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)

        # Decoder
        for filters, skip in zip([256, 128, 64], reversed(skips)):
            x = layers.UpSampling2D()(x)
            x = layers.Concatenate()([x, skip])
            x = layers.Conv2D(filters, 3, padding='same')(x)
            x = layers.BatchNormalization()(x)
            x = layers.ReLU()(x)

        outputs = layers.Conv2D(self.image_shape[-1], 3, padding='same')(x)

        return tf.keras.Model(inputs, outputs)

    def _schedule_at(self, schedule, t, dtype):
        """Gather ``schedule[t]`` as ``dtype``, reshaped to broadcast over (batch, H, W, C)."""
        values = tf.cast(tf.gather(schedule, t), dtype)
        return tf.reshape(values, (-1, 1, 1, 1))

    def diffusion_forward(self, x0, t):
        """Forward diffusion: sample x_t ~ q(x_t | x_0) in closed form.

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)

        Args:
            x0: clean images, assumed (batch, H, W, C) floating point.
            t: integer timestep(s) in [0, timesteps); a scalar or shape (batch,).

        Returns:
            Tuple ``(x_t, eps)``: the noisy images and the noise that was added.
        """
        x0 = tf.convert_to_tensor(x0)
        # tf.shape (not x0.shape) so an unknown batch size inside tf.function works.
        noise = tf.random.normal(shape=tf.shape(x0), dtype=x0.dtype)
        alpha_bar_t = self._schedule_at(self.alpha_bar, t, x0.dtype)
        return tf.sqrt(alpha_bar_t) * x0 + tf.sqrt(1. - alpha_bar_t) * noise, noise

    def diffusion_reverse(self, xt, t):
        """One ancestral sampling step x_t -> x_{t-1} of the reverse process.

        mean    = (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta(x_t)) / sqrt(alpha_t)
        x_{t-1} = mean + sqrt(beta_t) * z,  z ~ N(0, I)

        The final step (t == 0) returns the mean without adding noise.

        Args:
            xt: noisy images at step t, assumed (batch, H, W, C).
            t: integer timestep(s); a scalar or shape (batch,).

        Returns:
            The sample x_{t-1} with the same shape as ``xt``.
        """
        xt = tf.convert_to_tensor(xt)
        dtype = xt.dtype
        predicted_noise = tf.cast(self.model(xt, training=False), dtype)

        alpha_t = self._schedule_at(self.alpha, t, dtype)
        alpha_bar_t = self._schedule_at(self.alpha_bar, t, dtype)
        beta_t = self._schedule_at(self.beta, t, dtype)

        eps_coef = beta_t / tf.sqrt(1. - alpha_bar_t)
        mean = (xt - eps_coef * predicted_noise) / tf.sqrt(alpha_t)

        # Mask out the noise term wherever t == 0.
        nonzero = tf.reshape(tf.cast(tf.not_equal(t, 0), dtype), (-1, 1, 1, 1))
        noise = tf.random.normal(shape=tf.shape(xt), dtype=dtype)
        return mean + nonzero * tf.sqrt(beta_t) * noise

    def sample(self, batch_size=16):
        """Generate ``batch_size`` images by running the reverse process from pure noise.

        Returns:
            Tensor of shape (batch_size,) + image_shape.
        """
        # Start from pure noise x_T ~ N(0, I)
        x = tf.random.normal(shape=(batch_size,) + tuple(self.image_shape))

        # Reverse diffusion: t = T-1, ..., 0
        for t in reversed(range(self.timesteps)):
            x = self.diffusion_reverse(x, t)

        return x

## 2. Score-Based Model

In [None]:
class ScoreBasedModel:
    """Score network s_theta(x) approximating grad_x log p_sigma(x).

    The network is trained with denoising score matching.

    NOTE(review): the network receives only the image, not the noise level
    sigma. One instance therefore models the score at a single noise scale.
    Multi-scale score models (e.g. NCSN) condition the network on sigma.
    TODO if that is needed.
    """

    def __init__(self):
        self.model = self.build_score_model()

    @staticmethod
    def _conv_gn_swish(x, filters):
        """Apply Conv2D(3x3, 'same') -> GroupNormalization -> swish to ``x``."""
        x = layers.Conv2D(filters, 3, padding='same')(x)
        x = layers.GroupNormalization()(x)
        return layers.Activation('swish')(x)

    def build_score_model(self):
        """Build an encoder/decoder score network for (32, 32, 1) inputs.

        Three AveragePooling2D downsamples (32 -> 4) are followed by three
        UpSampling2D upsamples (4 -> 32), so the output score has the input's
        spatial shape. There are no skip connections between the two halves.
        """
        inputs = layers.Input(shape=(32, 32, 1))
        x = inputs

        # Downsampling: two conv blocks per resolution, then halve H and W.
        for filters in [64, 128, 256]:
            x = self._conv_gn_swish(x, filters)
            x = self._conv_gn_swish(x, filters)
            x = layers.AveragePooling2D()(x)

        # Middle
        x = self._conv_gn_swish(x, 512)

        # Upsampling: double H and W, then two conv blocks.
        for filters in [256, 128, 64]:
            x = layers.UpSampling2D()(x)
            x = self._conv_gn_swish(x, filters)
            x = self._conv_gn_swish(x, filters)

        outputs = layers.Conv2D(1, 3, padding='same')(x)

        return tf.keras.Model(inputs, outputs)

    def score_matching_loss(self, x, sigma):
        """Denoising score matching loss at a single noise level ``sigma``.

        For x_tilde = x + sigma * z with z ~ N(0, I), the score of the
        perturbation kernel is -(x_tilde - x) / sigma**2 = -z / sigma. That
        is the regression target for the network's output at x_tilde.

        Args:
            x: batch of clean images, assumed (batch, 32, 32, 1).
            sigma: positive noise standard deviation (scalar). sigma == 0
                divides by zero.

        Returns:
            Scalar mean squared error between predicted and target scores.
        """
        x = tf.convert_to_tensor(x)
        # tf.shape handles an unknown batch dimension; match x's float dtype.
        noise = tf.random.normal(shape=tf.shape(x), dtype=x.dtype)
        sigma = tf.cast(sigma, x.dtype)
        perturbed_x = x + sigma * noise
        score = tf.cast(self.model(perturbed_x), x.dtype)
        target = -noise / sigma
        return tf.reduce_mean(tf.square(score - target))

## 3. Training Loop

In [None]:
def train_diffusion_model(model, dataset, epochs=100):
    """Train ``model.model`` to predict the noise added by the forward process.

    Args:
        model: a DiffusionModel-like object exposing ``timesteps``,
            ``diffusion_forward(x0, t)``, ``sample(batch_size)`` and a Keras
            network in ``model.model``.
        dataset: iterable of image batches, assumed (batch, 32, 32, 1). It is
            iterated once per epoch, so it must be re-iterable (e.g. a
            ``tf.data.Dataset``).
        epochs: number of passes over ``dataset``.

    Every 10th epoch (epoch 0, 10, 20, ...) this prints the mean training loss
    of that epoch and plots 4 generated samples.

    Raises:
        ValueError: if ``dataset`` yields no batches in an epoch.
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

    for epoch in range(epochs):
        # Average over the whole epoch. A single minibatch loss is very noisy
        # because each batch draws different random timesteps.
        epoch_loss = tf.keras.metrics.Mean()
        num_batches = 0
        for batch in dataset:
            epoch_loss.update_state(_diffusion_train_step(model, optimizer, batch))
            num_batches += 1

        if num_batches == 0:
            raise ValueError('dataset yielded no batches')

        if epoch % 10 == 0:
            print(f'Epoch {epoch}, Loss: {float(epoch_loss.result()):.4f}')
            _plot_diffusion_samples(model, num_samples=4)


def _diffusion_train_step(model, optimizer, batch):
    """Run one noise-prediction gradient step on ``batch`` and return its loss."""
    # One random timestep per example; tf.shape supports a dynamic batch size.
    t = tf.random.uniform(
        shape=tf.shape(batch)[:1],
        minval=0,
        maxval=model.timesteps,
        dtype=tf.int32
    )

    # Forward process (no trainable variables involved).
    noisy_batch, noise = model.diffusion_forward(batch, t)

    with tf.GradientTape() as tape:
        predicted_noise = tf.cast(model.model(noisy_batch, training=True), noise.dtype)
        loss = tf.reduce_mean(tf.square(noise - predicted_noise))

    variables = model.model.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))
    return loss


def _plot_diffusion_samples(model, num_samples=4):
    """Generate ``num_samples`` images with ``model.sample`` and show them in one row."""
    samples = model.sample(batch_size=num_samples)

    plt.figure(figsize=(2 * num_samples, 2))
    for i in range(num_samples):
        plt.subplot(1, num_samples, i + 1)
        plt.imshow(samples[i, ..., 0], cmap='gray')
        plt.axis('off')
    plt.show()

## Practical Exercises

In [None]:
# Exercise 1: Simple Diffusion

def diffusion_exercise():
    """Print the task list for the basic-diffusion exercise (solution left to the reader)."""
    instructions = (
        "Task: Implement basic diffusion model",
        "1. Create noise schedule",
        "2. Implement forward process",
        "3. Implement reverse process",
        "4. Generate samples",
    )
    for text in instructions:
        print(text)

    # Your code here


diffusion_exercise()

In [None]:
# Exercise 2: Score Matching

def score_matching_exercise():
    """Print the task list for the score-matching exercise (solution left to the reader)."""
    instructions = (
        "Task: Implement score-based model",
        "1. Create score network",
        "2. Implement score matching",
        "3. Train model",
        "4. Generate samples",
    )
    for text in instructions:
        print(text)

    # Your code here


score_matching_exercise()

## MCQ Quiz

1. What is a diffusion model?
   - a) Classification model
   - b) Noise-based generation
   - c) Regression model
   - d) Clustering model

2. What is the forward process?
   - a) Generation
   - b) Noise addition
   - c) Classification
   - d) Regression

3. What is the reverse process?
   - a) Noise addition
   - b) Denoising
   - c) Classification
   - d) Regression

4. What is score matching?
   - a) Classification
   - b) Gradient learning
   - c) Regression
   - d) Clustering

5. What is the noise schedule?
   - a) Random noise
   - b) Per-timestep noise-level parameters (e.g. β_t)
   - c) Model architecture
   - d) Loss function

6. What is denoising diffusion?
   - a) Noise addition
   - b) Generation process
   - c) Classification
   - d) Regression

7. What is a U-Net?
   - a) Loss function
   - b) Network architecture
   - c) Optimization method
   - d) Sampling strategy

8. What is timestep embedding?
   - a) Loss function
   - b) Time conditioning
   - c) Model architecture
   - d) Sampling method

9. What is guidance?
   - a) Training method
   - b) Generation control
   - c) Loss function
   - d) Model architecture

10. What is ancestral sampling?
    - a) Training method
    - b) Sampling strategy
    - c) Loss function
    - d) Model architecture

Answers: 1-b, 2-b, 3-b, 4-b, 5-b, 6-b, 7-b, 8-b, 9-b, 10-b