<pre style="text-align: right; width: 100%; font-size: 0.75em; line-height: 0.75em;">
+ ------------------------- + <br>
| 20/04/2025                | <br>
| Héctor Tablero Díaz       | <br>
| Álvaro Martínez Gamo      | <br>
+ ------------------------- + 
</pre>

# **Samplers**

In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('./..')

import os

import torch
from torch import Tensor
from torch.utils.data import Subset
from torchvision import datasets
from torchvision.transforms import ToTensor

from image_gen import GenerativeModel
from image_gen.samplers import EulerMaruyama, ExponentialIntegrator, ODEProbabilityFlow, PredictorCorrector, BaseSampler
from image_gen.diffusion import VarianceExploding
from image_gen.noise import LinearNoiseSchedule

from typing import Callable, Optional

from image_gen.visualization import display_images

In [None]:
epochs = 50
digit = 3

seed = 42

In [None]:
# Load the dataset
data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

# Keep only the images of the selected digit to speed up training
indices_digit = torch.where(data.targets == digit)[0]
data = Subset(data, indices_digit)

Samplers numerically integrate the reverse-time process that turns noise into images. When choosing one, the key considerations are:  
- Numerical stability  
- Sample quality  
- Computational cost  

Implemented samplers:  

| Sampler | Characteristics | Best for |
|---------|-----------------|----------|
| [Euler-Maruyama](#euler) | Simple SDE integration | Quick generation |
| [Exponential Integrator](#exp) | Adaptive step sizing | Stable trajectories |
| [ODE Probability Flow](#ode) | Deterministic sampling | High-fidelity outputs |
| [Predictor-Corrector](#pc) | Iterative refinement | Challenging distributions |

### <span id="euler">**Euler-Maruyama**</span>

#### Mathematical Definition
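$$x_{t-1} = x_t - \left[f(x_t,t) - g(t)^2∇_x\log p_t(x_t)\right]Δt + g(t)\sqrt{Δt}\,ϵ$$

where $ϵ \sim \mathcal{N}(0, I)$, the bracket is the drift of the reverse-time SDE, and the score $∇_x\log p_t$ is estimated by the trained model.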
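
To make the update concrete, here is a minimal, library-independent sketch of reverse-time Euler-Maruyama on a toy problem. It assumes 1-D Gaussian data $x_0 \sim \mathcal{N}(0, σ_0^2)$ and a VE SDE $dx = g\,dW$ with constant $g$, so that $p_t = \mathcal{N}(0, σ_0^2 + g^2 t)$ and the exact score $-x/(σ_0^2 + g^2 t)$ can stand in for a trained `score_model`. The function names and parameters are hypothetical and do not reflect the internals of `EulerMaruyama`.

```python
import math
import random
from typing import List


def toy_score(x: float, t: float, data_var: float, g: float) -> float:
    """Exact score of p_t = N(0, data_var + g**2 * t), the marginal of the toy VE SDE dx = g dW."""
    return -x / (data_var + g**2 * t)


def euler_maruyama_sample(
    n_samples: int,
    n_steps: int = 500,
    data_var: float = 1.0,
    g: float = 3.0,
    seed: int = 0,
) -> List[float]:
    """Integrate the reverse-time SDE from t = 1 down to t = 0 with Euler-Maruyama steps."""
    rng = random.Random(seed)
    dt = 1.0 / n_steps
    samples = []
    for _ in range(n_samples):
        # Start from the exact marginal at t = 1: N(0, data_var + g^2).
        x = rng.gauss(0.0, math.sqrt(data_var + g**2))
        for k in range(n_steps):
            t = 1.0 - k * dt
            f = 0.0  # the VE SDE has no forward drift
            reverse_drift = f - g**2 * toy_score(x, t, data_var, g)
            # x_{t-1} = x_t - reverse_drift * dt + g * sqrt(dt) * noise
            x = x - reverse_drift * dt + g * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        samples.append(x)
    return samples


samples = euler_maruyama_sample(n_samples=1000)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean = {mean:.3f}, variance = {var:.3f}")
```

If the sampler is working, the sample variance lands close to `data_var` rather than the starting variance `data_var + g**2`; increasing `n_steps` shrinks the remaining discretization bias.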

#### Characteristics  
- **Speed**: Fastest sampler, with one score evaluation per step  
- **Quality**: Low-order method; with few steps it may produce noisy artifacts

#### Usage Example

In [None]:
model = GenerativeModel(
    sampler=EulerMaruyama,
    diffusion=VarianceExploding,
    noise_schedule=LinearNoiseSchedule
)

In [None]:
filename = f'saved_models/mnist_{digit}_ve_euler_{epochs}e.pth'

if os.path.isfile(filename):
    model.load(filename)
else:
    model.train(data, epochs=epochs)
    # Tip: save the model so that it can be accessed from the dashboard
    model.save(filename)

In [None]:
n_images = 16
samples = model.generate(n_images, seed=seed)
display_images(samples)

### <span id="exp">**Exponential Integrator**</span>

#### Mathematical Definition
$$x_{t-1} = x_t e^{λΔt} + \frac{g^2}{2λ}(e^{2λΔt} - 1)∇_{x}\log p_t(x)$$
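
The idea behind an exponential integrator can be sketched on a toy problem whose forward drift is linear, $f(x,t) = λx$: over each step the linear part is solved in closed form, and only the score is frozen at its value at the start of the step (exponential Euler). This sketch assumes 1-D Gaussian data and an Ornstein-Uhlenbeck forward process with $λ < 0$ and $g^2 = -2λ$, for which the exact score is known. It illustrates the technique under those assumptions and is not the implementation of `ExponentialIntegrator`; all names are hypothetical.

```python
import math
import random
from typing import List


def marginal_var(t: float, data_var: float, lam: float, g: float) -> float:
    """Var[x_t] for the forward SDE dx = lam * x dt + g dW started at x_0 ~ N(0, data_var)."""
    e = math.exp(2.0 * lam * t)
    return data_var * e + g**2 * (e - 1.0) / (2.0 * lam)


def exponential_integrator_sample(
    n_samples: int,
    n_steps: int = 500,
    T: float = 5.0,
    data_var: float = 0.25,
    lam: float = -1.0,  # must be negative (and nonzero) in this toy
    seed: int = 0,
) -> List[float]:
    g = math.sqrt(-2.0 * lam)  # g^2 = -2 * lam makes p_T close to N(0, 1)
    h = T / n_steps
    # Step coefficients for the reverse SDE dx/dtau = -lam * x + g^2 * score,
    # with the linear part solved exactly and the score frozen over the step.
    decay = math.exp(-lam * h)
    score_coef = g**2 * (1.0 - math.exp(-lam * h)) / lam
    noise_std = math.sqrt(g**2 * (1.0 - math.exp(-2.0 * lam * h)) / (2.0 * lam))
    rng = random.Random(seed)
    samples = []
    for _ in range(n_samples):
        x = rng.gauss(0.0, math.sqrt(marginal_var(T, data_var, lam, g)))
        for k in range(n_steps):
            t = T - k * h
            score = -x / marginal_var(t, data_var, lam, g)  # exact Gaussian score
            x = decay * x + score_coef * score + noise_std * rng.gauss(0.0, 1.0)
        samples.append(x)
    return samples


samples = exponential_integrator_sample(n_samples=1000)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean = {mean:.3f}, variance = {var:.3f}")
```

For a fixed $g$, as $λ \to 0$ the coefficients `decay`, `score_coef` and the noise variance tend to $1$, $g^2 h$ and $g^2 h$, which is exactly the reverse Euler-Maruyama step.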

#### Characteristics  
- **Adaptive**: Automatic step size adjustment  
- **Stability**: Robust to parameter choices  
- **Cost**: Moderate computational overhead

#### Usage Example

In [None]:
model = GenerativeModel(
    sampler=ExponentialIntegrator,
    diffusion=VarianceExploding,
    noise_schedule=LinearNoiseSchedule
)

In [None]:
filename = f'saved_models/mnist_{digit}_ve_exp_{epochs}e.pth'

if os.path.isfile(filename):
    model.load(filename)
else:
    model.train(data, epochs=epochs)
    # Tip: save the model so that it can be accessed from the dashboard
    model.save(filename)

In [None]:
n_images = 16
samples = model.generate(n_images, seed=seed)
display_images(samples)

### <span id="ode">**ODE Probability Flow**</span>

#### Mathematical Definition

$$\frac{dx}{dt} = f(x,t) - \frac{1}{2}g(t)^2∇_x\log p_t(x)$$
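
Because no noise is injected, the probability flow ODE maps each initial noise sample to a single output. The sketch below integrates it backwards with plain Euler steps on the same kind of toy problem (1-D Gaussian data, VE SDE with constant $g$, exact score in place of `score_model`). For this toy the flow has a closed-form solution, $x_0 = x_T\sqrt{σ_0^2/(σ_0^2 + g^2)}$, against which the numerical result can be checked. The function is hypothetical, and the library may use a different (for example higher-order) solver.

```python
import math


def probability_flow_sample(
    x_T: float, n_steps: int = 1000, data_var: float = 1.0, g: float = 3.0
) -> float:
    """Integrate dx/dt = f(x, t) - 0.5 * g^2 * score(x, t) from t = 1 back to t = 0 (Euler steps)."""
    dt = 1.0 / n_steps
    x = x_T
    for k in range(n_steps):
        t = 1.0 - k * dt
        score = -x / (data_var + g**2 * t)  # exact score of p_t = N(0, data_var + g^2 t)
        velocity = 0.0 - 0.5 * g**2 * score  # f = 0 for the VE SDE
        x = x - velocity * dt  # step backwards in time; no noise is added
    return x


x0 = probability_flow_sample(2.0)
exact = 2.0 * math.sqrt(1.0 / (1.0 + 3.0**2))  # closed-form flow for the Gaussian toy
print(x0, exact)
```

The two printed values agree to within about 1%, and calling the function again with the same `x_T` returns exactly the same result.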

#### Characteristics  
- **Determinism**: The same initial noise always produces the same sample  
- **Precision**: High sample quality  
- **Cost**: 2-3× slower than Euler-Maruyama

#### Usage Example

In [None]:
model = GenerativeModel(
    sampler=ODEProbabilityFlow,
    diffusion=VarianceExploding,
    noise_schedule=LinearNoiseSchedule
)

In [None]:
filename = f'saved_models/mnist_{digit}_ve_ode_{epochs}e.pth'

if os.path.isfile(filename):
    model.load(filename)
else:
    model.train(data, epochs=epochs)
    # Tip: save the model so that it can be accessed from the dashboard
    model.save(filename)

In [None]:
n_images = 16
samples = model.generate(n_images, seed=seed)
display_images(samples)

### <span id="pc">**Predictor-Corrector**</span>

#### Mathematical Definition

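**Predictor** (a reverse-time Euler-Maruyama step): $x'_{t-1} = x_t - \left[f(x_t,t) - g(t)^2∇_x\log p_t(x_t)\right]Δt + g(t)\sqrt{Δt}\,ϵ$

**Corrector** (a Langevin step, optionally repeated): $x_{t-1} = x'_{t-1} + γg(t)^2Δt\,∇_x\log p_{t-1}(x'_{t-1}) + \sqrt{2γg(t)^2Δt}\,ϵ'$, with $ϵ, ϵ' \sim \mathcal{N}(0, I)$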
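
A minimal sketch of the predictor-corrector loop, again on a toy 1-D Gaussian problem with a VE SDE and the exact score standing in for `score_model`. The predictor is a reverse Euler-Maruyama step and the corrector runs `n_corrector` Langevin steps targeting the marginal at the new time. The Langevin step-size rule (the hypothetical `step_ratio` times the current marginal variance) is an assumption chosen for this toy, not the rule used by `PredictorCorrector`.

```python
import math
import random
from typing import List


def toy_score(x: float, t: float, data_var: float, g: float) -> float:
    """Exact score of p_t = N(0, data_var + g**2 * t) for the toy VE SDE dx = g dW."""
    return -x / (data_var + g**2 * t)


def predictor_corrector_sample(
    n_samples: int,
    n_steps: int = 200,
    n_corrector: int = 1,
    step_ratio: float = 0.02,  # hypothetical knob: Langevin step = step_ratio * Var[p_t]
    data_var: float = 1.0,
    g: float = 3.0,
    seed: int = 0,
) -> List[float]:
    rng = random.Random(seed)
    dt = 1.0 / n_steps
    samples = []
    for _ in range(n_samples):
        x = rng.gauss(0.0, math.sqrt(data_var + g**2))
        for k in range(n_steps):
            t = 1.0 - k * dt
            t_next = t - dt
            # Predictor: one reverse-time Euler-Maruyama step from t to t_next (f = 0).
            x = (x + g**2 * toy_score(x, t, data_var, g) * dt
                 + g * math.sqrt(dt) * rng.gauss(0.0, 1.0))
            # Corrector: Langevin MCMC steps that target p_{t_next}.
            eps = step_ratio * (data_var + g**2 * t_next)
            for _ in range(n_corrector):
                x = (x + eps * toy_score(x, t_next, data_var, g)
                     + math.sqrt(2.0 * eps) * rng.gauss(0.0, 1.0))
        samples.append(x)
    return samples


samples = predictor_corrector_sample(n_samples=1000)
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"mean = {mean:.3f}, variance = {var:.3f}")
```

Raising `n_corrector` or `step_ratio` lets the corrector mix more at each noise level, but a large Langevin step also biases the stationary variance, so the two are best tuned together.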

#### Characteristics  
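- **Quality**: Often gives the best samples in practice  
- **Flexibility**: The number of corrector steps per noise level can be tuned  
- **Cost**: Most computationally intensive, since every corrector step needs an extra score evaluation  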

#### Usage Example

In [None]:
model = GenerativeModel(
    sampler=PredictorCorrector,
    diffusion=VarianceExploding,
    noise_schedule=LinearNoiseSchedule
)

In [None]:
filename = f'saved_models/mnist_{digit}_ve_pred_{epochs}e.pth'

if os.path.isfile(filename):
    model.load(filename)
else:
    model.train(data, epochs=epochs)
    # Tip: save the model so that it can be accessed from the dashboard
    model.save(filename)

In [None]:
n_images = 16
samples = model.generate(n_images, seed=seed)
display_images(samples)

### **Creating Custom Samplers**

Custom samplers can be created by subclassing `BaseSampler` and implementing its `__call__` method, which takes the initial noise `x_T` and a `score_model` and returns a `Tensor`.

#### Implementation Example

In [None]:
class CustomSampler(BaseSampler):
    def __call__(
        self,
        x_T: Tensor,
        score_model: Callable,
        n_steps: int = 500,
        seed: Optional[int] = None,
        callback: Optional[Callable[[Tensor, int], None]] = None,
        callback_frequency: int = 50,
        guidance: Optional[Callable[[Tensor, Tensor], Tensor]] = None
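    ) -> Tensor:
        # x_T: batch of initial noise at the final time T
        # score_model: callable that estimates the score ∇_x log p_t(x)
        # n_steps: number of integration steps from T back to 0
        # seed: optional seed for reproducible sampling
        # callback: optional hook, called every `callback_frequency` steps
        #           (presumably with the current batch and the step index)
        # guidance: optional guidance function applied during sampling
        # Return the final batch of generated samples.
        ...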