<pre style="text-align: right; width: 100%; font-size: 0.75em; line-height: 0.75em;">
+ ------------------------- + <br>
| 29/04/2025                | <br>
| Héctor Tablero Díaz       | <br>
| Álvaro Martínez Gamo      | <br>
+ ------------------------- + 
</pre>

# **Class-Conditional Sampling**

In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('./..')

import os

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor

from image_gen import GenerativeModel
from image_gen.samplers import ExponentialIntegrator
from image_gen.diffusion import VariancePreserving
from image_gen.noise import LinearNoiseSchedule

from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
from IPython.display import HTML
from image_gen.visualization import display_images, create_evolution_widget

In [None]:
epochs = 50

In [None]:
# Load the dataset
data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

In [None]:
model = GenerativeModel(
    diffusion=VariancePreserving,
    sampler=ExponentialIntegrator,
    noise_schedule=LinearNoiseSchedule
)

In [None]:
filename = f'saved_models/mnist_vp-lin_{epochs}e.pth'

if os.path.isfile(filename):
    model.load(filename)
else:
    model.train(data, epochs=epochs)
    # Tip: save trained models so that the dashboard can load them
    model.save(filename)

## **Overview**

Class-conditional sampling uses class labels to control which kind of image the model generates.

Key features:
- Label-guided generation
- Classifier-free guidance
- Multi-class conditional sampling
- Batch-wise label assignments

## **Parameters**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `class_labels` | int/list/Tensor | None | Class indices to condition on |
| `guidance_scale` | float | 3.0 | Guidance strength (0=unconditional) |
| `num_samples` | int | - | Number of images to generate |
| `n_steps` | int | 500 | Reverse process steps |
| `seed` | int | None | Random seed |

## **Usage Examples**

### **Single Class Generation**

Generate 16 samples from class 7:

In [None]:
samples = model.generate(
    num_samples=16,
    class_labels=7
)
display_images(samples)

### **Mixed Class Batch**

Assign a specific class to each sample in the batch:

In [None]:
# Create label tensor [0,0,1,1,2,2,...]
labels = torch.repeat_interleave(torch.arange(0, model.num_classes), 2)

samples = model.generate(
    num_samples=len(labels),
    class_labels=labels
)
display_images(samples)
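
Since the parameter table lists `list` as an accepted type for `class_labels`, other label layouts can be built without tensors. The following is a minimal sketch in plain Python; `interleaved_labels` and `tiled_labels` are hypothetical helper names used only for illustration and are not part of `image_gen`.

```python
from typing import List


def interleaved_labels(num_classes: int, per_class: int) -> List[int]:
    """Hypothetical helper: [0, 0, 1, 1, ...], each class repeated in a row."""
    return [c for c in range(num_classes) for _ in range(per_class)]


def tiled_labels(num_classes: int, repeats: int) -> List[int]:
    """Hypothetical helper: [0, 1, 2, ..., 0, 1, 2, ...], the class range repeated."""
    return list(range(num_classes)) * repeats


print(interleaved_labels(3, 2))  # [0, 0, 1, 1, 2, 2]
print(tiled_labels(3, 2))        # [0, 1, 2, 0, 1, 2]
```

Either list could then be passed as `class_labels`, with `num_samples=len(labels)` as in the cell above. The interleaved layout groups samples of the same class side by side, while the tiled layout gives one full pass over the classes per repeat.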

## **Implementation Details**

The conditional sampling process:

1. **Guidance Formulation**:
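$$ \hat{\epsilon} = \epsilon_{uncond} + s \cdot (\epsilon_{cond} - \epsilon_{uncond}) $$
Where $s$ = `guidance_scale`, $\epsilon_{uncond}$ and $\epsilon_{cond}$ are the unconditional and class-conditional predictions, and $\hat{\epsilon}$ is the guided prediction used for sampling. With $s = 0$ the result is purely unconditional; with $s = 1$ it equals the conditional prediction.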

2. **Label Validation**:
- Automatic conversion to model's training labels
- Invalid labels replaced with first valid class

3. **Batch Handling**:
- Single label → applied to all samples
- Multiple labels → 1:1 mapping with batch
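
The label validation and batch handling rules above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's actual code: `normalize_labels` is a hypothetical name, the conversion to training labels is left out, and raising an error on a length mismatch is an assumption, since the behaviour for that case is not described here.

```python
from typing import List, Sequence, Union


def normalize_labels(
    class_labels: Union[int, Sequence[int]],
    num_samples: int,
    valid_labels: Sequence[int],
) -> List[int]:
    """Hypothetical helper: expand and sanitize the labels for one batch."""
    if isinstance(class_labels, int):
        # Single label: applied to all samples.
        labels = [class_labels] * num_samples
    else:
        # Multiple labels: 1:1 mapping with the batch.
        labels = [int(label) for label in class_labels]
        # Assumption: a per-sample list must match the batch size exactly.
        if len(labels) != num_samples:
            raise ValueError(f"expected {num_samples} labels, got {len(labels)}")

    # Invalid labels are replaced with the first valid class.
    valid = set(valid_labels)
    fallback = valid_labels[0]
    return [label if label in valid else fallback for label in labels]


mnist_labels = list(range(10))
print(normalize_labels(7, 4, mnist_labels))               # [7, 7, 7, 7]
print(normalize_labels([0, 3, 42, -1], 4, mnist_labels))  # [0, 3, 0, 0]
```

Because invalid labels are replaced rather than rejected, a typo such as `42` on MNIST would quietly produce samples of class 0, so it is worth checking labels before generating.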

## **Important Notes**

- Model must be trained with class conditioning
- Valid labels: 0 to `num_classes-1`
- Invalid labels are auto-corrected (replaced with the first valid class)
- Higher guidance scales (5-10) give clearer class features
- Set `guidance_scale=0` for unconditional sampling
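
To make the role of `guidance_scale` concrete, here is a minimal sketch of the guidance combination from Implementation Details. The function name `apply_guidance` is hypothetical, and the scalar inputs stand in for the model's predictions; the same arithmetic works elementwise on tensors.

```python
def apply_guidance(eps_uncond, eps_cond, guidance_scale):
    """Hypothetical helper: blend unconditional and conditional predictions.

    Works with floats or with any array type that supports
    elementwise arithmetic (for example torch tensors).
    """
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)


# Toy scalar predictions, chosen only for illustration.
eps_uncond, eps_cond = 0.25, 0.75

print(apply_guidance(eps_uncond, eps_cond, 0.0))  # 0.25 (unconditional)
print(apply_guidance(eps_uncond, eps_cond, 1.0))  # 0.75 (conditional)
print(apply_guidance(eps_uncond, eps_cond, 3.0))  # 1.75 (pushed past conditional)
```

A scale of 0 ignores the label entirely, a scale of 1 returns the conditional prediction unchanged, and larger scales extrapolate beyond it in the direction of the class.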

## **Visualization**

In [None]:
HTML(create_evolution_widget(
    model,
    class_labels=5
).to_jshtml(default_mode="once"))

### **Guidance Scale Comparison**

Effect of different guidance strengths on a single class (6). The seed is fixed so that only the scale changes between panels:

In [None]:
model.verbose = False
fig, axs = plt.subplots(2, 4, figsize=(15, 8))
for i, scale in tqdm(list(enumerate([0, 0.5, 1, 2, 3, 5, 7.5, 10])), desc="Generating samples"):
    samples = model.generate(
        num_samples=1,
        class_labels=6,
        guidance_scale=scale,
        seed=123
    ).cpu()
    axs[i//4, i%4].imshow(samples[0].permute(1,2,0), cmap="gray")
    axs[i//4, i%4].set_title(f'Scale={scale}')
plt.tight_layout()
plt.show()