<pre style="text-align: right; width: 100%; font-size: 0.75em; line-height: 0.75em;">
+ ------------------------- + <br>
| 29/04/2025                | <br>
| Héctor Tablero Díaz       | <br>
| Álvaro Martínez Gamo      | <br>
+ ------------------------- + 
</pre>

# **Getting Started**

In [None]:
import sys
# Add the parent directory to the import path so `image_gen` can be found
sys.path.append('./..')

import os

import torch
from torch.utils.data import Subset
from torchvision import datasets
from torchvision.transforms import ToTensor

from image_gen import GenerativeModel
from image_gen.samplers import EulerMaruyama
from image_gen.diffusion import VarianceExploding

from IPython.display import HTML
from image_gen.visualization import display_images, create_evolution_widget

## **1. Prepare Data**

Load the MNIST training set and keep only the images of the digit 3.

In [None]:
data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

# Keep only the samples whose label equals `digit`
digit = 3
indices = torch.where(data.targets == digit)[0]
dataset = Subset(data, indices)

## **2. Initialize Model**

Default configuration: Variance Exploding (VE) diffusion with the Euler-Maruyama sampler.

In [None]:
model = GenerativeModel(
    diffusion=VarianceExploding,
    sampler=EulerMaruyama
)
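For intuition about what "variance exploding" means, the sketch below uses a common form of the VE forward process: the noise scale grows geometrically, $\sigma(t) = \sigma_{\min}(\sigma_{\max}/\sigma_{\min})^t$, and a clean value $x_0$ is perturbed to $x_t = x_0 + \sigma(t)\,z$ with $z \sim \mathcal{N}(0, 1)$. This is a standalone illustration, not the `image_gen` implementation. The names `ve_sigma` and `ve_perturb` and the values `sigma_min=0.01` and `sigma_max=50.0` are assumptions chosen for the example.

```python
import math
import random

def ve_sigma(t, sigma_min=0.01, sigma_max=50.0):
    """Noise scale at time t in [0, 1]; grows geometrically from sigma_min to sigma_max."""
    return sigma_min * (sigma_max / sigma_min) ** t

def ve_perturb(x0, t, rng, sigma_min=0.01, sigma_max=50.0):
    """Sample x_t = x0 + sigma(t) * z elementwise, with z ~ N(0, 1)."""
    sigma = ve_sigma(t, sigma_min, sigma_max)
    return [x + sigma * rng.gauss(0.0, 1.0) for x in x0]

rng = random.Random(0)
pixels = [0.0, 0.5, 1.0]  # toy "image" with three pixel intensities (hypothetical data)
for t in (0.0, 0.5, 1.0):
    noisy = ve_perturb(pixels, t, rng)
    print(f"t={t:.1f}  sigma={ve_sigma(t):8.4f}  x_t={[round(v, 3) for v in noisy]}")
```

At `t=0` the pixels are almost unchanged, while at `t=1` the noise dominates the signal, which is why sampling can start from pure noise.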

## **3. Train Model**

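Train for 50 epochs, which is enough for a quick run. If a checkpoint for this configuration already exists, it is loaded instead of retraining.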

In [None]:
epochs = 50

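# Checkpoint path encodes the digit, the diffusion type, and the epoch count
filename = f'saved_models/mnist_{digit}_ve_{epochs}e.pth'

if not os.path.exists(filename):
    # Make sure the target directory exists, in case `model.save` does not create it
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    model.train(dataset, epochs=epochs, batch_size=64)
    model.save(filename)
else:
    # Reuse the existing checkpoint instead of retraining
    model.load(filename)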

## **4. Generate Samples**

Generate 16 new samples, using a fixed seed so the result is reproducible.

In [None]:
seed = 42

samples = model.generate(num_samples=16, seed=seed)
display_images(samples)

## **5. Generation Process**

Watch how a sample evolves from noise to a digit during the denoising process. The same seed as above is used.

In [None]:
HTML(create_evolution_widget(model, seed=seed).to_jshtml(default_mode="once"))
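To show what the animation is stepping through, here is a minimal one-dimensional sketch of Euler-Maruyama integration of the reverse-time VE SDE. It is not the library's `EulerMaruyama` class: the trained score network is replaced by the exact score of a toy Gaussian data distribution, and every name and constant (`SIGMA_MIN`, `SIGMA_MAX`, `MU`, `S0`, `euler_maruyama_sample`) is an assumption made for the illustration.

```python
import math
import random
import statistics

SIGMA_MIN, SIGMA_MAX = 0.01, 10.0   # assumed VE noise range
MU, S0 = 2.0, 0.5                   # toy data distribution N(MU, S0^2)

def sigma(t):
    """VE noise scale at time t in [0, 1]."""
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t

def g(t):
    """Diffusion coefficient, defined by g(t)^2 = d sigma(t)^2 / dt."""
    return sigma(t) * math.sqrt(2.0 * math.log(SIGMA_MAX / SIGMA_MIN))

def score(x, t):
    """Exact score of the noised marginal N(MU, S0^2 + sigma(t)^2)."""
    return -(x - MU) / (S0 ** 2 + sigma(t) ** 2)

def euler_maruyama_sample(rng, n_steps=500):
    """Integrate the reverse-time VE SDE from t=1 down to t=0; return the trajectory."""
    dt = 1.0 / n_steps
    x = rng.gauss(0.0, SIGMA_MAX)          # start from (almost) pure noise
    trajectory = [x]
    for k in range(n_steps):
        t = 1.0 - k * dt
        gt = g(t)
        drift = gt ** 2 * score(x, t)      # pulls x toward high-density regions
        x = x + drift * dt + gt * math.sqrt(dt) * rng.gauss(0.0, 1.0)
        trajectory.append(x)
    return trajectory

rng = random.Random(42)
finals = [euler_maruyama_sample(rng)[-1] for _ in range(500)]
print(f"mean={statistics.mean(finals):.3f}  std={statistics.stdev(finals):.3f}")
```

The printed mean and standard deviation should land close to `MU` and `S0`. Each list returned by `euler_maruyama_sample` is the scalar analogue of one frame sequence in the widget.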

## **Next Steps**

- Try other digits by changing `digit`
- Use a three-channel (RGB) dataset
- Increase the number of epochs (for example, 100 to 500)
- Explore other samplers and diffusion processes
- See the advanced notebooks