# **Modelos de difusión: entrenamiento y muestreo en CIFAR-10**

In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('./..')

from image_gen import GenerativeModel
from image_gen.diffusion import VarianceExploding, VariancePreserving, SubVariancePreserving
from image_gen.noise import LinearNoiseSchedule, CosineNoiseSchedule
from image_gen.samplers import EulerMaruyama, ExponentialIntegrator, ODEProbabilityFlow, PredictorCorrector

import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Experiment configuration. Each short key selects an entry from the maps
# below; the diffusion_model and sampler keys are also used later in the
# checkpoint filename.
diffusion_model = "ve"  # "ve", "vp" or "svp"
noise_schedule = "l"    # "l" (linear) or "c" (cosine)
sampler = "euler"       # "euler", "exp", "ode" or "pc"

schedule_map = {
    "l": LinearNoiseSchedule(beta_min=0.0001, beta_max=10),
    "c": CosineNoiseSchedule(beta_max=0.9999),
}
diffusion_map = {
    "ve": VarianceExploding,
    "vp": VariancePreserving,
    "svp": SubVariancePreserving,
}
sampler_map = {
    "euler": EulerMaruyama,
    "exp": ExponentialIntegrator,
    "ode": ODEProbabilityFlow,
    "pc": PredictorCorrector,
}

# Index with [] rather than .get(): a typo in one of the keys above should
# fail right here with a KeyError naming the bad key, instead of silently
# passing None into GenerativeModel.
model = GenerativeModel(
    diffusion=diffusion_map[diffusion_model],
    sampler=sampler_map[sampler],
    noise_schedule=schedule_map[noise_schedule],
)

# For non-VE diffusions the noise-schedule key is appended to the run name,
# which is embedded in the checkpoint filename.
if diffusion_model != 've':
    diffusion_model = f"{diffusion_model}_{noise_schedule}"

In [None]:
# Human-readable CIFAR-10 class names; position i is the name of label i.
CLASSES = [
    'Airplane', 'Car', 'Bird', 'Cat', 'Deer',
    'Dog', 'Frog', 'Horse', 'Ship', 'Truck',
]

In [None]:
def get_cifar_dataset(class_id=None, root='./data', train=True):
    """Load CIFAR-10 normalized to [-1, 1], optionally keeping a single class.

    Args:
        class_id: Integer label in ``range(len(CLASSES))`` to keep, or None to
            keep every image.
        root: Directory where the dataset is stored / downloaded to.
        train: Load the training split if True, the test split otherwise.

    Returns:
        The torchvision CIFAR10 dataset, or a ``torch.utils.data.Subset`` of it
        containing only images of ``class_id`` when one is given.

    Raises:
        ValueError: if ``class_id`` is not None and is not a valid label.
    """
    # Validate before downloading/filtering: an out-of-range id would
    # otherwise yield an empty subset (and, for negative ids, a misleading
    # class name in the message below).
    if class_id is not None and not 0 <= class_id < len(CLASSES):
        raise ValueError(
            f"class_id must be None or in [0, {len(CLASSES) - 1}], got {class_id!r}"
        )

    # ToTensor maps pixels to [0, 1]; Normalize(0.5, 0.5) then maps them to [-1, 1].
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    cifar = datasets.CIFAR10(
        root=root,
        train=train,
        download=True,
        transform=transform
    )

    if class_id is not None:
        # Plain int indices of every sample whose label matches class_id.
        indices = [i for i, label in enumerate(cifar.targets) if label == class_id]
        cifar = torch.utils.data.Subset(cifar, indices)
        print(f"Selected {len(cifar)} images of class: {CLASSES[class_id]}")

    return cifar

In [None]:
# None uses every class; set an integer label (an index into CLASSES) to
# restrict the data to a single class.
class_id = None
dataset = get_cifar_dataset(class_id)

In [None]:
epochs = 500

# The checkpoint name encodes the dataset, the optional class, the epoch count
# and the model configuration. It is built once, before training, so saving
# and loading always agree and a bad class_id fails before the long run.
class_name = f"_{CLASSES[class_id]}" if class_id is not None else ""
checkpoint_path = f'cifar10{class_name}_{epochs}e_{diffusion_model}_{sampler}.pth'

model.train(dataset, epochs=epochs)
model.save(checkpoint_path)
# To reuse an existing checkpoint instead of retraining, comment out the two
# lines above and uncomment the next one:
# model.load(checkpoint_path)

In [None]:
# Number of samples to draw from the trained model (fills a 4x4 display grid).
n_images = 4 * 4
samples = model.generate(n_images)

In [None]:
def show_images(images, n_images=4, contrast=1.0):
    """Plot the first ``n_images`` of a batch of images in a near-square grid.

    Args:
        images: Tensor batch of images with values in [-1, 1]. Assumes an
            (N, C, H, W) layout, as implied by the permute below
            (TODO confirm C == 3 / RGB).
        n_images: How many images from the start of the batch to show.
        contrast: Per-image contrast factor applied around each image's mean
            value; 1.0 leaves the images unchanged.

    Raises:
        ValueError: if ``n_images`` is less than 1.
    """
    if n_images < 1:
        raise ValueError(f"n_images must be >= 1, got {n_images}")

    # Keep the first n_images, move to (N, H, W, C) for matplotlib,
    # then rescale from [-1, 1] to [0, 1].
    images = images[:n_images].detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (images + 1) / 2

    # Scale each image's deviation from its own mean (taken over H, W and C):
    # contrast > 1 stretches it, contrast < 1 flattens it. Clip back into the
    # [0, 1] range imshow expects for float images.
    mean = images.mean(axis=(1, 2, 3), keepdims=True)
    images = np.clip(mean + contrast * (images - mean), 0, 1)

    # Smallest near-square grid that fits n_images cells; int(sqrt(n)) alone
    # would be too small whenever n is not a perfect square.
    n_cols = int(np.ceil(np.sqrt(n_images)))
    n_rows = int(np.ceil(n_images / n_cols))
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so flatten() works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4, 4), squeeze=False)
    axes = axes.flatten()

    for idx, ax in enumerate(axes):
        if idx < len(images):
            ax.imshow(images[idx])
        ax.axis('off')  # also hides empty cells of a partially filled grid

    plt.tight_layout()
    plt.show()

In [None]:
def visualize_cifar_images(dataset, n_images=4):
    """Plot the first ``n_images`` images of a normalized CIFAR dataset in a grid.

    Args:
        dataset: Indexable dataset yielding ``(image, label)`` pairs. Each image
            is expected to be a (C, H, W) tensor normalized to [-1, 1], such
            as those produced by ``get_cifar_dataset``.
        n_images: How many images to show; at most ``len(dataset)`` are drawn.

    Raises:
        ValueError: if ``n_images`` is less than 1.
    """
    if n_images < 1:
        raise ValueError(f"n_images must be >= 1, got {n_images}")
    n_shown = min(n_images, len(dataset))

    # Smallest near-square grid that fits n_images cells; int(sqrt(n)) alone
    # would be too small whenever n is not a perfect square.
    n_cols = int(np.ceil(np.sqrt(n_images)))
    n_rows = int(np.ceil(n_images / n_cols))
    # squeeze=False keeps `axes` 2-D even for a 1x1 grid, so flatten() works.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4, 4), squeeze=False)
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        if i < n_shown:
            img, _ = dataset[i]  # the label is not displayed
            # (C, H, W) -> (H, W, C), then [-1, 1] -> [0, 1]; the clip is a
            # defensive guard keeping values in imshow's float range.
            img = np.clip((img.permute(1, 2, 0).numpy() + 1) / 2, 0, 1)
            ax.imshow(img)
        ax.axis('off')  # also hides empty cells of a partially filled grid

    plt.tight_layout()
    plt.show()

### Originales

In [None]:
# Show the same number of real training images as generated samples.
visualize_cifar_images(dataset, n_images)

### Generadas

In [None]:
# contrast=1 leaves the generated samples' pixel values unmodified.
show_images(samples, contrast=1, n_images=n_images)