<pre style="text-align: right; width: 100%; font-size: 0.75em; line-height: 0.75em;">
+ ------------------------- + <br>
| 28/04/2025                | <br>
| Héctor Tablero Díaz       | <br>
| Álvaro Martínez Gamo      | <br>
+ ------------------------- + 
</pre>

# **Imputation (Inpainting)**

In [None]:
import sys
sys.path.append('./..')

import os

import torch
from torch.utils.data import Subset
from torchvision import datasets
from torchvision.transforms import ToTensor

from image_gen import GenerativeModel
from image_gen.samplers import ExponentialIntegrator
from image_gen.diffusion import VariancePreserving
from image_gen.noise import LinearNoiseSchedule

from image_gen.visualization import display_images

In [None]:
epochs = 50
digit = 3
seed = 0

In [None]:
# Load the dataset
data = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)

# Select a subset to speed up the training process
indices_digit = torch.where(data.targets == digit)[0]
data = Subset(data, indices_digit)

In [None]:
model = GenerativeModel(
    diffusion=VariancePreserving,
    sampler=ExponentialIntegrator,
    noise_schedule=LinearNoiseSchedule
)

In [None]:
filename = f'saved_models/mnist_{digit}_vp-lin_{epochs}e.pth'

if os.path.isfile(filename):
    model.load(filename)
else:
    model.train(data, epochs=epochs)
    # Tip: save the model so that it can be loaded from the dashboard
    model.save(filename)

## **Overview**

Mask-guided generation for filling missing image regions using diffusion models.

Key features:
- Supports arbitrary binary masks
- Preserves known pixel values
- Blends generated content with original image
- The dashboard accepts masks supplied as transparent PNGs

## **Parameters**
| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `x` | Tensor | - | Input image tensor (B,C,H,W) |
| `mask` | Tensor | - | Binary mask (1=generate, 0=preserve) |
| `n_steps` | int | 500 | Number of reverse steps |
| `seed` | int | None | Random seed for reproducibility |
| `class_labels` | Tensor | None | Optional class conditioning |
| `progress_callback` | function | None | Generation progress handler |
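The shape and value conventions in the table can be checked before calling the model. The following is a minimal NumPy sketch with a hypothetical helper `validate_inputs` (not part of `image_gen`). It assumes `x` is `(B, C, H, W)` and that `mask` is either single-channel `(B, 1, H, W)` or already matches the image channels, with 1 = generate and 0 = preserve.

```python
import numpy as np

def validate_inputs(x: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: check x and mask against the parameter table.

    Assumes x has shape (B, C, H, W) and mask has shape (B, 1, H, W) or
    (B, C, H, W), with 1 = generate and 0 = preserve.
    Returns the mask broadcast to the shape of x.
    """
    if x.ndim != 4:
        raise ValueError(f"x must be (B, C, H, W), got shape {x.shape}")
    if mask.ndim != 4:
        raise ValueError(f"mask must be 4-D, got shape {mask.shape}")
    b, c, h, w = x.shape
    # Batch size and spatial size must match the image exactly
    if mask.shape[0] != b or mask.shape[2:] != (h, w):
        raise ValueError(f"mask shape {mask.shape} incompatible with x shape {x.shape}")
    # A single-channel mask is broadcast across all image channels
    if mask.shape[1] not in (1, c):
        raise ValueError("mask must have 1 channel or as many channels as x")
    # Binary means every entry is exactly 0 or 1
    if not np.isin(mask, (0, 1)).all():
        raise ValueError("mask must be binary (only 0s and 1s)")
    return np.broadcast_to(mask, x.shape)
```

Broadcasting a single-channel mask, rather than copying it, mirrors the "Mask Preparation" step described under Implementation Details.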

## **Usage Examples**

### **Manual Mask Creation**

Create a custom mask programmatically:

In [None]:
base_image = model.generate(num_samples=1, seed=seed)
display_images(base_image)

In [None]:
# Mask = 1 (generate) everywhere except the central rectangle, which is set to 0 (preserve)
mask = torch.ones_like(base_image)
h, w = base_image.shape[2], base_image.shape[3]
mask[:, :, h//4:3*h//4, w//4:3*w//4] = 0

# Repeat the image and the mask 16 times to produce 16 completions of the same masked image
mask_batch = mask.repeat(16, 1, 1, 1)
base_image_batch = base_image.repeat(16, 1, 1, 1)

results_batch = model.imputation(base_image_batch, mask_batch, n_steps=500, seed=seed)
display_images(results_batch)
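Other mask patterns can be built the same way. The NumPy sketch below defines two hypothetical helpers, `half_mask` and `random_pixel_mask` (neither is part of `image_gen`). Both follow the table's convention (1 = generate, 0 = preserve) and return a single-channel `(1, 1, H, W)` mask.

```python
import numpy as np

def half_mask(h: int, w: int, side: str = "bottom") -> np.ndarray:
    """Hypothetical helper: regenerate one half of the image (1 = generate)."""
    mask = np.zeros((1, 1, h, w), dtype=np.float32)
    if side == "bottom":
        mask[:, :, h // 2:, :] = 1
    elif side == "right":
        mask[:, :, :, w // 2:] = 1
    else:
        raise ValueError(f"unsupported side: {side}")
    return mask

def random_pixel_mask(h: int, w: int, p: float, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: mark each pixel for regeneration with probability p."""
    rng = np.random.default_rng(seed)
    return (rng.random((1, 1, h, w)) < p).astype(np.float32)
```

Such a mask could be converted with `torch.from_numpy(mask)` and repeated along the batch dimension as in the cell above before being passed to `model.imputation`. Whether the method expects a particular dtype or device is an assumption to check against the library.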

## **Implementation Details**

The imputation process:

1. **Normalization**: Scale input to [-1,1] range
2. **Mask Preparation**: Expand mask to match image channels
3. **Noise Injection**: Apply noise only to masked regions
4. **Guided Sampling**: Blend generated content with original pixels
5. **Denormalization**: Convert back to original value range
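Steps 1, 2, 3 and 5 can be illustrated with plain array operations. The following is a minimal NumPy sketch, not the library's implementation. It assumes the input pixels lie in [0, 1] (as produced by `ToTensor`) and that "noise injection" means replacing masked pixels with standard Gaussian noise while leaving unmasked pixels untouched. Step 4 is the blending equation that follows.

```python
import numpy as np

def normalize(x: np.ndarray) -> np.ndarray:
    """Step 1: map [0, 1] pixels (e.g. from ToTensor) to [-1, 1]."""
    return 2.0 * x - 1.0

def expand_mask(mask: np.ndarray, channels: int) -> np.ndarray:
    """Step 2: repeat a (B, 1, H, W) mask across the channel axis."""
    return np.repeat(mask, channels, axis=1)

def inject_noise(x_norm: np.ndarray, mask: np.ndarray,
                 rng: np.random.Generator) -> np.ndarray:
    """Step 3: replace masked pixels (mask == 1) with Gaussian noise, keep the rest."""
    noise = rng.standard_normal(x_norm.shape)
    return mask * noise + (1.0 - mask) * x_norm

def denormalize(x_norm: np.ndarray) -> np.ndarray:
    """Step 5: map [-1, 1] back to [0, 1], clipping any overshoot."""
    return np.clip((x_norm + 1.0) / 2.0, 0.0, 1.0)
```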

Key equation during sampling:
$$x_{t} = \text{mask} \cdot x_{t}^{\text{generated}} + (1-\text{mask}) \cdot \left(\frac{t}{N} \, x_{t}^{\text{original}} + \left(1 - \frac{t}{N}\right) x_{t}^{\text{generated}}\right)$$

where $N$ is `n_steps`.
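The blending equation translates directly into array code. The NumPy sketch below implements one application of it with a hypothetical function `blend_step`. It takes the generated and original states at step `t` as given arrays and does not model how the library computes them.

```python
import numpy as np

def blend_step(x_generated: np.ndarray, x_original: np.ndarray,
               mask: np.ndarray, t: int, n_steps: int) -> np.ndarray:
    """Hypothetical helper: one blending step of the key equation.

    mask: 1 = generate, 0 = preserve.
    """
    w = t / n_steps  # weight given to the original pixels in the preserved region
    preserved = w * x_original + (1.0 - w) * x_generated
    # Masked pixels take the generated value; unmasked pixels take the blend
    return mask * x_generated + (1.0 - mask) * preserved
```

By this equation alone, the preserved region equals $x_t^{\text{original}}$ only at $t = $ `n_steps`, and its weight decreases linearly with $t$. This sketch does not capture any additional step the library may use to keep unmasked pixels exact in the final output.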

## **Important Notes**

- The mask must be a single-channel binary tensor
- Input images are automatically normalized
- The dashboard requires a PNG with an alpha channel; the alpha channel is used as the mask
- For color images, the mask is applied to all channels
- The function preserves original pixel values exactly in unmasked regions
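A PNG's alpha channel can be turned into a mask with a simple threshold. The NumPy sketch below uses a hypothetical helper, `alpha_to_mask`, and only approximates what the dashboard might do. It assumes the PNG has already been decoded into a `uint8` array of shape `(H, W, 4)`. It also assumes that transparent pixels (alpha below `threshold`) mark the region to regenerate. That assumption is exposed as the `transparent_means_generate` flag because the source does not say which way the dashboard interprets transparency.

```python
import numpy as np

def alpha_to_mask(rgba: np.ndarray, threshold: int = 128,
                  transparent_means_generate: bool = True) -> np.ndarray:
    """Hypothetical helper: convert the alpha channel of a decoded RGBA PNG to a mask.

    rgba: uint8 array of shape (H, W, 4).
    Returns a float32 mask of shape (1, 1, H, W) with 1 = generate, 0 = preserve.
    ASSUMPTION: transparent pixels (alpha < threshold) are the ones to regenerate,
    unless transparent_means_generate is False.
    """
    if rgba.ndim != 3 or rgba.shape[2] != 4:
        raise ValueError(f"expected an (H, W, 4) RGBA array, got {rgba.shape}")
    transparent = rgba[:, :, 3] < threshold
    mask = transparent if transparent_means_generate else ~transparent
    # Add batch and channel axes to match the single-channel mask convention
    return mask.astype(np.float32)[None, None, :, :]
```

With Pillow installed, such an array could be obtained via `np.asarray(Image.open(path).convert("RGBA"))`.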