<pre style="text-align: right; width: 100%; font-size: 0.75em; line-height: 0.75em;">
+ ------------------------- + <br>
| 28/04/2025                | <br>
| Héctor Tablero Díaz       | <br>
| Álvaro Martínez Gamo      | <br>
+ ------------------------- + 
</pre>

# **Colorization**

In [None]:
import sys
sys.path.append('./..')

import os

import torch
from torch.utils.data import Subset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

from image_gen import GenerativeModel
from image_gen.samplers import EulerMaruyama
from image_gen.diffusion import VarianceExploding

from image_gen.visualization import display_images

In [None]:
epochs = 500
class_id = 1  # CIFAR-10 class 1: automobile
seed = 123

In [None]:
# Load the dataset. ToTensor keeps pixel values in [0, 1], the same range the
# colorization output is clamped to (the unused Normalize pipeline was dropped)
transform = ToTensor()

data = datasets.CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transform
)

# Keep only the images of class `class_id` to speed up training
targets = torch.tensor(data.targets)
idx = (targets == class_id).nonzero().flatten()
data = Subset(data, idx)
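The filtering step above keeps the indices whose label equals `class_id`. In CIFAR-10, class 1 is *automobile*, and the training split holds 5,000 images per class. The sketch below mirrors the tensor expression with plain Python on a made-up list of labels; `indices_for_class` and `toy_targets` are illustrative names, not part of `image_gen`.

```python
def indices_for_class(targets, class_id):
    """Pure-Python analogue of (targets == class_id).nonzero().flatten()."""
    return [i for i, label in enumerate(targets) if label == class_id]


# Toy labels standing in for data.targets (CIFAR-10 labels run from 0 to 9).
toy_targets = [6, 9, 9, 4, 1, 1, 2, 7, 8, 3]
print(indices_for_class(toy_targets, class_id=1))  # [4, 5]
```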

In [None]:
model = GenerativeModel(
    diffusion=VarianceExploding,
    sampler=EulerMaruyama
)

In [None]:
filename = f'saved_models/cifar10_{class_id}_ve_{epochs}e.pth'

if os.path.isfile(filename):
    model.load(filename)
else:
    model.train(data, epochs=epochs)
    # Tip: save trained models so they can be loaded from the dashboard
    os.makedirs(os.path.dirname(filename), exist_ok=True)
    model.save(filename)

## **Overview**

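`model.colorize` turns a grayscale image into a color image. It works in YUV space: the input supplies the luminance (Y), and the diffusion model generates the chrominance (U, V).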

Key features:
- Requires a 3-channel (RGB) diffusion model
- Preserves the luminance of the input image
- Generates multiple plausible color variations of the same input
- Supports interactive visualization of the image's evolution during generation

## **Parameters**

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `x` | Tensor | - | Input grayscale image (1 or 3 channels) |
| `n_steps` | int | 500 | Number of reverse-process (sampling) steps |
| `seed` | int | None | Random seed for reproducible results |
| `class_labels` | Tensor | None | Optional class labels for conditional generation |
| `progress_callback` | Callable | None | Function called to report generation progress |
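The table does not specify which arguments `colorize` passes to `progress_callback`, so the sketch below uses a hypothetical callback that accepts any arguments and simply records them. `ProgressRecorder` is our own illustrative name, not part of `image_gen`, and the keyword usage shown in the comment is an assumption based on the parameter name.

```python
class ProgressRecorder:
    """Hypothetical progress handler that stores every call it receives.

    It accepts arbitrary positional and keyword arguments because the
    exact callback signature used by image_gen is not documented here.
    """

    def __init__(self):
        self.calls = []

    def __call__(self, *args, **kwargs):
        self.calls.append((args, kwargs))


recorder = ProgressRecorder()
# Assumed usage: model.colorize(gray_image, seed=seed, progress_callback=recorder)
# Simulate two calls to show what gets stored:
recorder(1, 500)
recorder(2, 500)
print(len(recorder.calls))  # 2
```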

## **Usage Examples**

### **Basic Colorization**

Generate a color sample, convert it to grayscale by averaging its RGB channels, then colorize the grayscale version:

In [None]:
generated_image = model.generate(num_samples=1, seed=seed)
# Average the RGB channels to obtain a single-channel grayscale image
gray_image = torch.mean(generated_image, dim=1, keepdim=True)
display_images(generated_image)
display_images(gray_image)

In [None]:
colorized = model.colorize(gray_image, seed=seed)
display_images(colorized)
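The example builds its grayscale input with `torch.mean(..., dim=1)`, which weights R, G and B equally. The standard BT.601 luma instead weights them 0.299, 0.587 and 0.114, so colors with the same channel mean can have very different luma. The plain-Python sketch below compares the two on pure primaries; whether `colorize` uses BT.601 internally is an assumption, not something this notebook states.

```python
def channel_mean(rgb):
    """Grayscale as the plain average of R, G and B (what torch.mean over dim=1 does)."""
    r, g, b = rgb
    return (r + g + b) / 3


def bt601_luma(rgb):
    """BT.601 luma: a weighted sum that favors green."""
    r, g, b = rgb
    return 0.299 * r + 0.587 * g + 0.114 * b


primaries = {"red": (1.0, 0.0, 0.0), "green": (0.0, 1.0, 0.0), "blue": (0.0, 0.0, 1.0)}
for name, pixel in primaries.items():
    print(name, round(channel_mean(pixel), 3), round(bt601_luma(pixel), 3))
# red 0.333 0.299
# green 0.333 0.587
# blue 0.333 0.114
```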

### **Multiple Variations**

Colorize 16 copies of the same grayscale image to obtain different color hypotheses:

In [None]:
# Repeat the grayscale image 16 times along the batch dimension
gray_batch = gray_image.repeat(16, 1, 1, 1)
colorized_batch = model.colorize(gray_batch, seed=seed)
display_images(colorized_batch)
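All 16 copies share the same grayscale input and the same `seed`, yet each receives different colors. A minimal sketch of why, assuming `colorize` draws each sample's chrominance noise from a single generator seeded with `seed`: Python's `random.Random` stands in for the PyTorch generator, and `batch_noise` is a hypothetical helper.

```python
import random


def batch_noise(batch_size, n_values, seed):
    """Draw independent Gaussian noise for each sample from one seeded generator."""
    rng = random.Random(seed)
    return [[rng.gauss(0.0, 1.0) for _ in range(n_values)] for _ in range(batch_size)]


first_run = batch_noise(batch_size=16, n_values=4, seed=123)
second_run = batch_noise(batch_size=16, n_values=4, seed=123)

print(first_run == second_run)       # True: same seed, same batch
print(first_run[0] == first_run[1])  # False: each sample gets its own noise
```

The seed makes the whole batch reproducible across runs, while the sequential draws give every copy a different starting point, hence different color hypotheses.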

## **Implementation Details**

The colorization process works as follows:

1. **YUV Conversion**: Convert the input to YUV space; the grayscale values serve as the luminance channel Y
2. **UV Initialization**: Initialize the chrominance channels (U, V) with random noise
3. **Luminance Enforcement**: At each sampling step, blend the generated luminance with the original luminance
4. **RGB Conversion**: Convert the final YUV result back to RGB
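The four steps can be sketched per pixel with the BT.601 YUV definition. This is an assumption: the notebook does not say which YUV coefficients `image_gen` uses, and the real implementation operates on tensors rather than single pixels. The sketch shows that a gray pixel has zero chrominance (step 1) and that replacing U and V leaves Y unchanged after the round trip through RGB (steps 2 to 4).

```python
# BT.601 YUV coefficients (assumed; image_gen may use different ones).
W_R, W_G, W_B = 0.299, 0.587, 0.114
U_MAX, V_MAX = 0.436, 0.615


def rgb_to_yuv(r, g, b):
    """Convert one RGB pixel (values in [0, 1]) to YUV."""
    y = W_R * r + W_G * g + W_B * b
    u = U_MAX * (b - y) / (1 - W_B)
    v = V_MAX * (r - y) / (1 - W_R)
    return y, u, v


def yuv_to_rgb(y, u, v):
    """Exact algebraic inverse of rgb_to_yuv."""
    b = y + u * (1 - W_B) / U_MAX
    r = y + v * (1 - W_R) / V_MAX
    g = (y - W_R * r - W_B * b) / W_G
    return r, g, b


# Step 1: a gray pixel (r = g = b) carries only luminance.
y, u, v = rgb_to_yuv(0.5, 0.5, 0.5)
print(round(y, 6), abs(u) < 1e-12, abs(v) < 1e-12)  # 0.5 True True

# Steps 2-4: replace the chrominance, keep Y, convert back to RGB.
r, g, b = yuv_to_rgb(y, 0.1, -0.05)
print(round(rgb_to_yuv(r, g, b)[0], 6))  # 0.5
```

Clamping the RGB result to [0, 1] can shift Y slightly whenever the new chrominance pushes a channel out of that range.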

The luminance update applied during sampling is:
$$ Y_{t} = (1-\alpha)Y^{\text{generated}} + \alpha Y^{\text{original}} $$
where the blending weight $\alpha$ decreases linearly from 1 to 0.
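A sketch of the blending rule with a linear schedule. It assumes $\alpha$ is indexed by sampling step, equal to 1 at the first step and 0 at the last; the notebook only states that $\alpha$ decreases linearly from 1 to 0, so the exact indexing in `image_gen` may differ. `linear_alpha` and `blend_luminance` are hypothetical helpers working on flat lists of Y values.

```python
def linear_alpha(step, n_steps):
    """Blending weight: 1.0 at step 0, falling linearly to 0.0 at step n_steps - 1."""
    if n_steps < 2:
        raise ValueError("n_steps must be at least 2")
    return 1.0 - step / (n_steps - 1)


def blend_luminance(y_generated, y_original, alpha):
    """Per-pixel Y_t = (1 - alpha) * Y_generated + alpha * Y_original."""
    return [(1.0 - alpha) * yg + alpha * yo
            for yg, yo in zip(y_generated, y_original)]


n_steps = 5
print([linear_alpha(s, n_steps) for s in range(n_steps)])
# [1.0, 0.75, 0.5, 0.25, 0.0]

# At alpha = 1 the luminance comes entirely from the input image.
print(blend_luminance([0.2, 0.8], [0.6, 0.4], alpha=1.0))
# [0.6, 0.4]
```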

## **Important Notes**

- The model **must** be initialized with 3 channels
- The input can be 1-channel (grayscale) or 3-channel (RGB)
- The dashboard expects grayscale PNG inputs
- Output values are clamped to the [0, 1] range
- More sampling steps (500+) improve color coherence
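The constraints above can be expressed as a small pre-flight check plus the final clamp. Both helpers (`check_colorize_setup`, `clamp01`) are hypothetical and only illustrate the stated rules; they are not functions of `image_gen`.

```python
def check_colorize_setup(model_channels, input_channels):
    """Hypothetical check mirroring the notes: 3-channel model, 1- or 3-channel input."""
    if model_channels != 3:
        raise ValueError("colorize needs a model initialized with 3 channels")
    if input_channels not in (1, 3):
        raise ValueError("input must have 1 (grayscale) or 3 (RGB) channels")


def clamp01(value):
    """Clamp a pixel value to the [0, 1] output range."""
    return min(max(value, 0.0), 1.0)


check_colorize_setup(model_channels=3, input_channels=1)  # passes silently
print([clamp01(v) for v in (-0.2, 0.5, 1.3)])  # [0.0, 0.5, 1.0]
```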