In [1]:
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import numpy as np
from tqdm import tqdm
from data.CustomImageDataset import GenerationImageDataset
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.transforms.v2 import Grayscale, ToDtype, Lambda, CenterCrop, Resize, ToPILImage
from einops import rearrange
from utils.visualization import plot_graph
torch.cuda.empty_cache()

## Chargement des données

In [2]:
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader


# Image preprocessing: convert to a float tensor in [0, 1], then rescale to [-1, 1].
# No horizontal flip here: a mirrored digit is (for most classes) not a valid digit,
# so a generative model trained with that augmentation would learn to draw mirrored digits.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda t: (t * 2) - 1),
])

dataset = datasets.MNIST(
    root="./data/datasets",
    train=True,
    download=True,
    transform=transform,
)
# Each dataset item is an (image, label) pair; the image tensor is (channels, height, width).
channels, image_size, _ = dataset[0][0].shape
batch_size = 128

# create dataloader
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

## Model declaration

In [3]:
from models.TemporalUNet import TemporalUNet
from models.DDPM import DDPM

# Denoising network: its input/output channel count follows the dataset images
# (`channels` is read from the first sample in the data-loading cell) instead of a literal 1.
unet = TemporalUNet(in_channels=channels, out_channels=channels, channels_mult=(1, 2, 4))

# Diffusion wrapper around the denoiser: 300 timesteps with a linear schedule.
diffusion_model = DDPM(denoiser=unet, timestep=300, schedule_type='linear')

### Check model settings

### Sur-apprentissage

In [4]:
from torch.optim import Adam
from utils.trainer.DiffusionModelTrainer import DiffusionModelTrainer

learning_rate = 1e-3

# The optimizer is built over the U-Net's parameters.
optimizer = Adam(params=unet.parameters(), lr=learning_rate)

# NOTE(review): the test split is the training set itself — presumably intentional for this
# over-fitting experiment; confirm before reading anything into the test metrics.
trainer = DiffusionModelTrainer(
    model=diffusion_model,
    train_dataset=dataset,
    test_dataset=dataset,
    loss_fn=diffusion_model.compute_loss,
    optimizer=optimizer,
    batch_size=batch_size,  # reuse the value defined with the data loader instead of repeating 128
)

In [5]:
# Train for 6 epochs. The returned object is kept because a later cell plots
# losses['training_loss'].
losses = trainer.train(num_epochs=6)
# Checkpoint saving is currently disabled; a later cell loads this same path, so it only
# works if the file already exists on disk.
#torch.save(unet.state_dict(), 'weights/diffusion_unet_overfit_b8_t300_cosine_lr25_c1.pt')

In [6]:
def cosine_beta_schedule(timesteps, s=0.008):
    """Build the cosine noise schedule proposed in https://arxiv.org/abs/2102.09672.

    Args:
        timesteps: number of diffusion steps; one beta is produced per step.
        s: small offset added to the normalized time before the cosine is taken.

    Returns:
        1-D tensor holding ``timesteps`` betas, clipped to [0.0001, 0.9999].
    """
    # timesteps + 1 grid points so that consecutive ratios yield exactly `timesteps` betas
    grid = torch.linspace(0, timesteps, timesteps + 1)
    phase = ((grid / timesteps) + s) / (1 + s) * torch.pi * 0.5
    cumulative = torch.cos(phase) ** 2
    # normalize so the cumulative product starts at 1
    cumulative = cumulative / cumulative[0]
    step_ratios = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - step_ratios, 0.0001, 0.9999)

def linear_beta_schedule(timesteps):
    """Return ``timesteps`` betas evenly spaced from 0.0001 to 0.02 (both ends included)."""
    first_beta, last_beta = 0.0001, 0.02
    return torch.linspace(start=first_beta, end=last_beta, steps=timesteps)

In [7]:
import torch.nn.functional as F

# Number of diffusion steps used by the sampling helpers defined below.
timesteps = 300

# Per-step noise variances beta_t (linear schedule).
betas = linear_beta_schedule(timesteps=timesteps)

# alpha_t = 1 - beta_t and its running product alpha_bar_t.
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
# alpha_bar_{t-1}: shift right by one and put 1.0 in front for the first step.
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
sqrt_recip_alphas = (1.0 / alphas).sqrt()

# Coefficients of the forward process q(x_t | x_0).
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1. - alphas_cumprod).sqrt()

# Variance of the posterior q(x_{t-1} | x_t, x_0).
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

def extract(a, t, x_shape):
    """Pick one entry of a 1-D schedule tensor per batch element, shaped for broadcasting.

    Args:
        a: 1-D tensor indexed by timestep (e.g. ``betas``).
        t: 1-D integer tensor holding one timestep index per batch element.
        x_shape: shape of the tensor the result will be broadcast against;
            only its number of dimensions is used.

    Returns:
        Tensor of shape ``(len(t), 1, ..., 1)`` with ``len(x_shape)`` dimensions,
        placed on ``t``'s device.
    """
    batch_size = t.shape[0]
    # gather needs the index on the same device as the source; follow `a` instead of
    # hard-coding CPU so the helper also works when the schedule lives on the GPU.
    out = a.gather(-1, t.to(a.device))
    return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(t.device)


In [8]:
@torch.no_grad()
def p_sample(model, x, t, t_index):
    """Perform one reverse-diffusion step on the batch ``x``.

    Args:
        model: noise predictor, called as ``model(x, t)``.
        x: current noisy batch.
        t: tensor of timestep indices, one per batch element.
        t_index: the same timestep as a plain integer; 0 marks the final step.

    Returns:
        The predicted mean at the final step (``t_index == 0``), otherwise the mean
        plus Gaussian noise scaled by the posterior standard deviation.
    """
    beta_t = extract(betas, t, x.shape)
    sigma_bar_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    inv_sqrt_alpha_t = extract(sqrt_recip_alphas, t, x.shape)

    # Mean of the reverse step, computed from the model's noise prediction
    # (Equation 11 in the paper).
    predicted_noise = model(x, t)
    mean = inv_sqrt_alpha_t * (x - beta_t * predicted_noise / sigma_bar_t)

    # No noise is added on the very last step.
    if t_index == 0:
        return mean

    # Algorithm 2 line 4: add noise scaled by the posterior standard deviation.
    variance_t = extract(posterior_variance, t, x.shape)
    fresh_noise = torch.randn_like(x)
    return mean + torch.sqrt(variance_t) * fresh_noise

# Algorithm 2 (including returning all images)
@torch.no_grad()
def p_sample_loop(model, shape):
    """Run the full reverse chain from pure noise and record every intermediate batch.

    Args:
        model: noise predictor; its first parameter decides the device used.
        shape: shape of the batch to generate; ``shape[0]`` is the batch size.

    Returns:
        List of NumPy arrays, one per timestep in sampling order (from
        ``timesteps - 1`` down to 0); the last entry is the final sample.
    """
    device = next(model.parameters()).device
    n_samples = shape[0]

    # initial state: independent Gaussian noise for every example of the batch
    current = torch.randn(shape, device=device)
    trajectory = []

    progress = tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps)
    for step in progress:
        step_batch = torch.full((n_samples,), step, device=device, dtype=torch.long)
        current = p_sample(model, current, step_batch, step)
        trajectory.append(current.cpu().numpy())
    return trajectory

@torch.no_grad()
def sample(model, image_size, batch_size=16, channels=3):
    """Generate a batch of square images with ``p_sample_loop``.

    Args:
        model: noise predictor forwarded to ``p_sample_loop``.
        image_size: height and width of the generated images.
        batch_size: number of images to generate.
        channels: number of channels per image.

    Returns:
        Whatever ``p_sample_loop`` returns for a
        ``(batch_size, channels, image_size, image_size)`` batch.
    """
    batch_shape = (batch_size, channels, image_size, image_size)
    return p_sample_loop(model, shape=batch_shape)

In [11]:
samples = sample(unet, image_size=image_size, batch_size=128, channels=channels)
# samples[-1] is the last sampling step; each image in it is (C, H, W).
# Permute the axes to channels-last for imshow: a plain reshape only happens to be
# correct when C == 1 and would scramble pixels for multi-channel images.
plt.imshow(rearrange(samples[-1][6], 'c h w -> h w c'), cmap='gray')

In [12]:
samples[-1][0].min()

In [None]:
from utils.visualization import plot_graph
# Plot the recorded training losses against their index.
training_losses = losses['training_loss']
plot_graph(
    x=np.arange(len(training_losses)),
    y=training_losses,
    xlabel='epoch',
    ylabel=r'$loss$',
    title='MSE Loss evolution on training set',
)

In [None]:
# Load saved U-Net weights onto whatever device the model currently lives on
# (instead of a hard-coded 'cuda:0').
# NOTE(review): the matching torch.save call is commented out in the training cell, and the
# file name mentions a cosine schedule while the model in this notebook is built with
# schedule_type='linear' — confirm this checkpoint is the intended one.
unet.load_state_dict(torch.load('weights/diffusion_unet_overfit_b8_t300_cosine_lr25_c1.pt', map_location=next(unet.parameters()).device))

## Illustration du processus avant de diffusion

### Evaluating variable shape

In [None]:
num_images = 10


def reverse_transformation(img):
    """Turn a (C, H, W) tensor in the loader's [-1, 1] range into a displayable image.

    This helper is not defined or imported anywhere else in the notebook, so it is
    provided here as the inverse of the loading transform ``t * 2 - 1``.
    Values are clamped to [-1, 1] first (noisy samples overshoot that range), rescaled
    to [0, 1] and returned channels-last; a single channel is squeezed to (H, W).
    """
    img = (img.detach().clamp(-1, 1) + 1) / 2
    img = rearrange(img, 'c h w -> h w c')
    if img.shape[-1] == 1:
        img = img.squeeze(-1)
    return img.numpy()


# Dataset items are (image, label) tuples: take the image before moving it to the GPU.
x0 = dataset[0][0].to('cuda')

fig, axs = plt.subplots(nrows=1, ncols=(num_images+1), figsize=((num_images+1)*2, 2))
for ax in axs:
    ax.axis('off')

# First panel: the clean image.
axs[0].set_title(f't: {0}')
axs[0].imshow(reverse_transformation(x0.cpu()), cmap='gray')

x0 = rearrange(x0, 'c h w -> 1 c h w')

# Evenly spaced intermediate timesteps, then the very last timestep for the final panel.
# The final panel is handled once here rather than being redrawn on every loop iteration.
noisy_steps = [int(diffusion_model.timestep / num_images * k) for k in range(1, num_images)]
noisy_steps.append(diffusion_model.timestep - 1)

for ax, t in zip(axs[1:], noisy_steps):
    x_noisy = diffusion_model.forward_sampling(x0, torch.tensor([t], dtype=torch.int64, device=x0.device))
    ax.set_title(f't: {t}')
    ax.imshow(reverse_transformation(rearrange(x_noisy, '1 c h w -> c h w').cpu()), cmap='gray')

## Processus arrière de diffusion

In [None]:
# Run the model's reverse process starting from x_noisy, the sample noised at the last
# timestep by the forward-diffusion cell.
# NOTE(review): later cells index the result as x_rec[t] and treat each entry as a
# (1, c, h, w) tensor — presumably one entry per reverse step; confirm against DDPM.reverse_sampling.
x_rec = diffusion_model.reverse_sampling(x_noisy)

In [None]:
def normalize(x):
    """Min-max rescale ``x`` to the range [0, 1].

    NOTE(review): the original cell contained only the signature (a syntax error).
    Min-max scaling is the presumed intent given the surrounding display cells — confirm.

    Args:
        x: tensor or array exposing ``min()`` and ``max()``.

    Returns:
        ``(x - min) / (max - min)``; all zeros when ``x`` is constant.
    """
    lowest = x.min()
    span = x.max() - lowest
    if span == 0:
        # constant input: avoid 0/0 and return zeros of the same shape
        return x - lowest
    return (x - lowest) / span

In [None]:
x_rec[-1].max()

In [None]:
num_images = 5

fig, axs = plt.subplots(nrows=1, ncols=num_images+1, figsize=((num_images+1)*2, 2))
plt.axis('off')

# First num_images panels: entries of x_rec at evenly spaced indices.
for i, ax in enumerate(axs[:num_images]):
    t = int(diffusion_model.timestep * i / num_images)
    ax.axis('off')
    ax.set_title(f't: {t}')
    ax.imshow(rearrange(x_rec[t], '1 c h w -> h w c').cpu(), cmap='gray')

# Last panel: the final entry of x_rec.
last_ax = axs[-1]
last_ax.axis('off')
last_ax.set_title(f't: {diffusion_model.timestep-1}')
last_ax.imshow(rearrange(x_rec[-1], '1 c h w -> h w c').cpu(), cmap='gray')

## Affiche les infos sur le modèle

In [None]:
x_rec[299]

In [None]:
# `train_dataset` is never defined in this notebook; the dataset is named `dataset`
# and its items are (image, label) tuples, so inspect the image tensor's shape.
dataset[0][0].shape

In [None]:
from torchinfo import summary

# Options forwarded to torchinfo.summary: the per-layer columns to report and the
# nesting depth of the table.
summary_kwargs = dict(col_names=['input_size', 'output_size', 'kernel_size', 'num_params', 'mult_adds'], depth=3, verbose=0)

# NOTE(review): the dummy batch is 16 single-channel 256x256 images although the MNIST
# images used in this notebook are 28x28, and the second input is a float (16, 1) tensor
# whereas the sampling code calls the model with a 1-D int64 timestep tensor — confirm
# these are the inputs TemporalUNet expects before trusting the reported numbers.
summary(unet, input_data=(torch.ones(16, 1, 256, 256), torch.ones(16, 1)), batch_dim=0, **summary_kwargs)

Rappel : une RTX 4080 a une puissance de 48.74 TFlops (FP32), soit $48.74 \times 10^{12}$ opérations par seconde en float32.
Avec une profondeur de 5 blocs :
Ici, nous lisons que le modèle effectue $82.15 \times 10^9$ opérations pour une image avec des couches résiduelles. Nous avons aussi 114 millions de paramètres au total.
En remplaçant les couches résiduelles par des ResidualBottleneck, nous obtenons $20.26 \times 10^9$ opérations et 33 millions de paramètres pour le modèle.

In [None]:
import torch
a = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])

In [None]:
torch.tensor([1, *a[:-1]])