<pre style="text-align: right; width: 100%; font-size: 0.75em; line-height: 0.75em;">
+ ------------------------- + <br>
| 20/04/2025                | <br>
| Héctor Tablero Díaz       | <br>
| Álvaro Martínez Gamo      | <br>
+ ------------------------- + 
</pre>

# **Getting Started**

### **Setup & Imports**

The package is organized into the following modules:

- `image_gen.diffusion` contains the diffusion processes (diffusers):
  - `VarianceExploding`
  - `VariancePreserving`
  - `SubVariancePreserving`
- `image_gen.metrics` has functions to estimate the quality of the generated images
- `image_gen.noise` contains two noise schedulers that control the amount of noise added at each time step of the process:
  - `LinearNoiseSchedule`
  - `CosineNoiseSchedule`
- `image_gen.samplers` contains four samplers that generate images from random noise:
  - `EulerMaruyama`
  - `ExponentialIntegrator`
  - `ODEProbabilityFlow`
  - `PredictorCorrector`
- `image_gen.visualization` has functions to display both the results and the progress of the generative process
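
As a rough illustration of what the two noise schedules control, the sketch below computes the signal-retention coefficient ᾱ(t), the fraction of the original signal variance kept at time t ∈ [0, 1], for a linear β(t) schedule and for the cosine schedule of Nichol & Dhariwal. The constants (`BETA_MIN = 0.1`, `BETA_MAX = 20.0` and the cosine offset `0.008`) are common defaults from the literature and are assumptions here, not values read from `image_gen.noise`.

```python
import math

# Assumed constants: common literature defaults, not read from image_gen.noise.
BETA_MIN = 0.1
BETA_MAX = 20.0
COSINE_OFFSET = 0.008


def linear_alpha_bar(t: float) -> float:
    """Linear schedule beta(t) = BETA_MIN + t * (BETA_MAX - BETA_MIN).

    The retained signal is alpha_bar(t) = exp(-integral of beta from 0 to t).
    """
    integral = BETA_MIN * t + 0.5 * t ** 2 * (BETA_MAX - BETA_MIN)
    return math.exp(-integral)


def cosine_alpha_bar(t: float, s: float = COSINE_OFFSET) -> float:
    """Cosine schedule: alpha_bar(t) = f(t) / f(0), f(u) = cos^2(((u + s) / (1 + s)) * pi / 2)."""
    def f(u: float) -> float:
        return math.cos((u + s) / (1 + s) * math.pi / 2) ** 2
    return f(t) / f(0.0)


for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(f"t={t:.2f}  linear={linear_alpha_bar(t):.4f}  cosine={cosine_alpha_bar(t):.4f}")
```

Both curves start at 1 (clean data) and fall towards 0 (pure noise); with these constants the cosine schedule retains noticeably more signal around t = 0.5.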

In [None]:
import sys
sys.path.append('./..')

import os

import torch
from torch.utils.data import Subset
from torchvision import datasets
from torchvision.transforms import ToTensor

from image_gen import GenerativeModel
from image_gen.samplers import EulerMaruyama
from image_gen.diffusion import VarianceExploding

from image_gen.visualization import display_images, display_evolution

### **Setup & Training**

The main class is `GenerativeModel`, available directly under `image_gen`. It can be configured with any of the diffusers, samplers and noise schedulers listed above. If none are specified, it defaults to the configuration used in this notebook: `VarianceExploding` together with `EulerMaruyama`.
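
The diffusers differ in how they corrupt a clean sample x₀ at time t ∈ [0, 1]. Assuming they follow the standard score-SDE formulation (VE, VP and sub-VP SDEs), each one has a Gaussian perturbation kernel x_t = m(t)·x₀ + std(t)·z with z ~ N(0, 1). The sketch below computes m(t) and std(t) for a single scalar pixel; the hyperparameters (`SIGMA_MIN`, `SIGMA_MAX`, `BETA_MIN`, `BETA_MAX`) and the linear β schedule are assumed defaults, not values taken from `image_gen.diffusion`.

```python
import math
import random

# Assumed hyperparameters (score-SDE literature defaults), not image_gen's.
SIGMA_MIN, SIGMA_MAX = 0.01, 50.0   # VE noise scale range
BETA_MIN, BETA_MAX = 0.1, 20.0      # VP / sub-VP linear beta schedule


def ve_marginal(t):
    """VE: the mean is unchanged and the std grows geometrically with t."""
    return 1.0, SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def _log_mean_coeff(t):
    """-0.5 * integral of the linear beta(t) from 0 to t."""
    return -0.25 * t ** 2 * (BETA_MAX - BETA_MIN) - 0.5 * t * BETA_MIN


def vp_marginal(t):
    """VP: the signal shrinks so that unit-variance data keeps a total variance of 1."""
    mean_coeff = math.exp(_log_mean_coeff(t))
    return mean_coeff, math.sqrt(1.0 - mean_coeff ** 2)


def subvp_marginal(t):
    """sub-VP: same mean as VP, but std = 1 - alpha_bar, so its variance never exceeds VP's."""
    mean_coeff = math.exp(_log_mean_coeff(t))
    return mean_coeff, 1.0 - mean_coeff ** 2


def perturb(x0, t, marginal, rng):
    """Draw x_t ~ N(mean_coeff * x0, std^2) for a scalar pixel value x0."""
    mean_coeff, std = marginal(t)
    return mean_coeff * x0 + std * rng.gauss(0.0, 1.0)


rng = random.Random(42)
for name, marginal in [("VE", ve_marginal), ("VP", vp_marginal), ("sub-VP", subvp_marginal)]:
    m, s = marginal(0.5)
    x_t = perturb(1.0, 0.5, marginal, rng)
    print(f"{name:6s} t=0.5  mean_coeff={m:.3f}  std={s:.3f}  sample={x_t:.3f}")
```

VE keeps the signal and only adds noise of growing scale, whereas VP and sub-VP also shrink the signal; at any given time sub-VP adds less noise than VP, which is where its name comes from.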

In [None]:
# Set up the common variables
epochs = 50
digit = 3

seed = 42

In [None]:
# Load the dataset
data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

# Select a subset to speed up the training process
indices_digit = torch.where(data.targets == digit)[0]
data = Subset(data, indices_digit)

In [None]:
model = GenerativeModel(
    diffusion=VarianceExploding,
    sampler=EulerMaruyama
)

Components can also be selected by their short codes, although this approach does not allow setting custom parameters.

The codes are:
- **VarianceExploding**: `ve`
- **VariancePreserving**: `vp`
- **SubVariancePreserving**: `sub-vp`, `svp`

<br>

- **LinearNoiseSchedule**: `linear`, `lin`
- **CosineNoiseSchedule**: `cosine`, `cos`

<br>

- **EulerMaruyama**: `euler-maruyama`, `em`
- **ExponentialIntegrator**: `exponential`, `exp`
- **ODEProbabilityFlow**: `ode`
- **PredictorCorrector**: `predictor-corrector`, `pred`
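
Resolving a short code amounts to a lookup table from aliases to classes. The sketch below shows one way such a mapping could be written; the `ALIASES` table and the `resolve` helper are hypothetical, mirror only the codes listed above, and are not `image_gen`'s internal implementation (in particular, the lowercasing is a choice made here, not documented behavior).

```python
# Hypothetical alias table mirroring the short codes listed above.
ALIASES = {
    "diffusion": {
        "VarianceExploding": ["ve"],
        "VariancePreserving": ["vp"],
        "SubVariancePreserving": ["sub-vp", "svp"],
    },
    "noise_schedule": {
        "LinearNoiseSchedule": ["linear", "lin"],
        "CosineNoiseSchedule": ["cosine", "cos"],
    },
    "sampler": {
        "EulerMaruyama": ["euler-maruyama", "em"],
        "ExponentialIntegrator": ["exponential", "exp"],
        "ODEProbabilityFlow": ["ode"],
        "PredictorCorrector": ["predictor-corrector", "pred"],
    },
}


def resolve(kind: str, code: str) -> str:
    """Return the class name registered under a short code for the given component kind."""
    normalized = code.strip().lower()
    for class_name, codes in ALIASES[kind].items():
        if normalized in codes:
            return class_name
    valid = sorted(c for codes in ALIASES[kind].values() for c in codes)
    raise ValueError(f"Unknown {kind} code {code!r}; expected one of {valid}")


print(resolve("diffusion", "ve"))
print(resolve("sampler", "euler-maruyama"))
```

Since a string carries no constructor arguments, a lookup like this can only yield a component with its default settings, consistent with the note above that short codes cannot set custom parameters.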

In [None]:
# Same initialization as before
model = GenerativeModel(
    diffusion="ve",
    sampler="euler-maruyama"
)

In [None]:
filename = f'saved_models/mnist_{digit}_ve_euler_{epochs}e.pth'

if os.path.isfile(filename):
    # Reuse a previously trained model instead of retraining it
    model.load(filename)
else:
    model.train(data, epochs=epochs)
    # Tip: save the model so that it is accessible from the dashboard
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    model.save(filename)
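
The notebook does not show what `model.train` optimizes internally. Score-based diffusion models are commonly trained with denoising score matching; assuming that objective (not confirmed here), the sketch below evaluates it for the VE perturbation on a 1-D toy dataset x₀ ~ N(0, 1), where the exact score of the perturbed distribution is known in closed form. The network is replaced by plain Python functions, and the σ² weighting and noise range are assumed choices.

```python
import math
import random

DATA_STD = 1.0                     # toy data: x0 ~ N(0, DATA_STD^2)
SIGMA_MIN, SIGMA_MAX = 0.01, 50.0  # assumed VE noise range


def sigma(t):
    """Geometric VE noise scale."""
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def exact_score(x, t):
    """Score of the perturbed marginal N(0, DATA_STD^2 + sigma(t)^2)."""
    return -x / (DATA_STD ** 2 + sigma(t) ** 2)


def dsm_loss(score_fn, n=20_000, seed=0):
    """Monte Carlo estimate of E[(sigma * score(x0 + sigma * z, t) + z)^2]."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        t = rng.uniform(1e-3, 1.0)     # random diffusion time
        x0 = rng.gauss(0.0, DATA_STD)  # clean sample
        z = rng.gauss(0.0, 1.0)        # injected noise
        s = sigma(t)
        x_t = x0 + s * z
        total += (s * score_fn(x_t, t) + z) ** 2
    return total / n


print(f"exact score: {dsm_loss(exact_score):.3f}")
print(f"zero score:  {dsm_loss(lambda x, t: 0.0):.3f}")
```

The exact score yields the lower loss. This is what makes the objective usable for training: its minimizer is the score of the perturbed data distribution, which is what the samplers need.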

### **Image Generation**

Images can be created with the `generate` method of the `GenerativeModel` class.

It can be configured with these parameters:
- `num_samples` (int): The number of images to generate
- `n_steps` (int): The number of discretization steps (iterations) the sampler performs
- `seed` (int, optional): A seed for reproducible results
- `class_labels` (int/Tensor, optional): An int or a Tensor of ints giving the class of each generated image. If a single int is passed, all images belong to that class.
- `progress_callback` (Callable, optional): A function called every few sampler iterations during generation. Useful for displaying intermediate images or estimating the remaining generation time.
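
To make `n_steps`, `seed` and `progress_callback` concrete, the sketch below runs Euler-Maruyama on the reverse-time VE SDE for a 1-D toy distribution whose score is known exactly, so no trained network is required. The toy data, noise range, `callback_every` parameter and the `(step, n_steps, samples)` callback signature are all hypothetical choices for illustration; they are not `GenerativeModel.generate`'s actual internals.

```python
import math
import random
from typing import Callable, List, Optional

# Toy setup (assumptions): data x0 ~ N(DATA_MEAN, DATA_STD^2) and an assumed VE noise range.
DATA_MEAN, DATA_STD = 2.0, 0.5
SIGMA_MIN, SIGMA_MAX = 0.01, 50.0
LOG_RATIO = math.log(SIGMA_MAX / SIGMA_MIN)


def sigma(t: float) -> float:
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t


def score(x: float, t: float) -> float:
    """Exact score of N(DATA_MEAN, DATA_STD^2 + sigma(t)^2); a trained network would replace this."""
    return -(x - DATA_MEAN) / (DATA_STD ** 2 + sigma(t) ** 2)


def generate(
    num_samples: int,
    n_steps: int = 500,
    seed: Optional[int] = None,
    progress_callback: Optional[Callable[[int, int, List[float]], None]] = None,
    callback_every: int = 50,
    eps: float = 1e-3,
) -> List[float]:
    """Euler-Maruyama integration of the reverse-time VE SDE from t = 1 down to t = eps."""
    rng = random.Random(seed)  # the same seed reproduces the same samples
    x = [rng.gauss(0.0, SIGMA_MAX) for _ in range(num_samples)]  # prior N(0, SIGMA_MAX^2)
    dt = (1.0 - eps) / n_steps
    for step in range(n_steps):
        t = 1.0 - step * dt
        g2 = 2.0 * LOG_RATIO * sigma(t) ** 2  # g(t)^2 = d[sigma^2]/dt for the VE SDE
        # Step backwards in time: drift g^2 * score * dt plus fresh noise g * sqrt(dt) * z
        x = [xi + g2 * score(xi, t) * dt + math.sqrt(g2 * dt) * rng.gauss(0.0, 1.0) for xi in x]
        if progress_callback is not None and (step + 1) % callback_every == 0:
            progress_callback(step + 1, n_steps, x)
    return x


def report(step: int, total: int, x: List[float]) -> None:
    print(f"step {step}/{total}: mean={sum(x) / len(x):.3f}")


samples = generate(1000, n_steps=500, seed=42, progress_callback=report)
```

Fixing `seed` makes the run repeatable, a larger `n_steps` reduces discretization error at the cost of more score evaluations, and the callback receives the intermediate samples, the kind of hook that a progress bar or an evolution plot can be built on.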

In [None]:
n_images = 16
samples = model.generate(n_images, seed=seed)
display_images(samples)

The progress of the generative process, from random noise to the final digits, can be visualized with `display_evolution`:

In [None]:
display_evolution(model, num_samples=4, seed=seed)