# Importation des bibliothèques

# Hyperparamètres
## Hyperparamètres de diffusion


In [9]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML
from diffusion_utilities import ResidualConvBlock, UnetUp, UnetDown, EmbedFC, plot_sample
import os


In [6]:
# Diffusion hyperparameters: number of timesteps and the linear beta range.
timesteps = 500
beta1 = 1e-4
beta2 = 0.02

# Network hyperparameters
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
n_feat = 64       # number of hidden feature maps
n_cfeat = 5       # size of the context vector
height = 16       # images are height x height (16x16)


# DDPM noise schedule built from a linear beta schedule (timesteps + 1 entries).
b_t = beta1 + (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device)
a_t = 1 - b_t
# Cumulative product of a_t, computed as exp(cumsum(log a_t)); entry 0 is pinned to 1.
ab_t = a_t.log().cumsum(dim=0).exp()
ab_t[0] = 1

# L'architecture U-Net pour la prédiction du bruit
# (Cette partie reste inchangée ; l'objectif est de comprendre le fonctionnement du U-Net.)

In [7]:

class ContextUnet(nn.Module):
    """Context-conditioned U-Net that predicts the noise of a diffusion model.

    Architecture, as built below:
      * ``init_conv``: residual conv block, ``in_channels -> n_feat``;
      * two down-sampling levels (``down1``, ``down2``), ending with ``2 * n_feat`` maps;
      * ``to_vec``: average pooling over the ``h // 4`` bottleneck, then GELU;
      * three up-sampling levels (``up0``, ``up1``, ``up2``) with skip connections
        to ``down2``, ``down1`` and the initial features;
      * before ``up1`` and ``up2``, the features are multiplied by a context
        embedding and shifted by a timestep embedding.

    Args:
        in_channels: number of channels of the input (and output) images.
        n_feat: base number of hidden feature maps.
        n_cfeat: size of the context vector ``c``.
        height: side of the square input images; must be divisible by 4.

    Raises:
        ValueError: if ``height`` is not divisible by 4.
    """

    def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28):  # n_cfeat: context features
        super().__init__()
        if height % 4 != 0:
            raise ValueError(f"height must be divisible by 4 (e.g. 28, 24, 20, 16), got {height}")
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height  # square images: h == w

        # Initial convolution: in_channels -> n_feat
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Down-sampling path of the U-Net, two levels
        self.down1 = UnetDown(n_feat, n_feat)
        self.down2 = UnetDown(n_feat, 2 * n_feat)

        # Collapse the bottleneck feature map to a 1x1 "vector".
        # The pool size follows the input height. This assumes each UnetDown halves
        # the spatial size, so the bottleneck is (h // 4) x (h // 4). A fixed
        # AvgPool2d(4) only yields 1x1 when h // 4 is between 4 and 7, and for h > 16
        # it averages only the top-left 4x4 window. For h == 16 both are identical.
        self.to_vec = nn.Sequential(nn.AvgPool2d(self.h // 4), nn.GELU())

        # Timestep and context embeddings (fully connected)
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, n_feat)

        # Up-sampling path of the U-Net, three levels.
        # up0 expands the 1x1 vector back to (h // 4) x (h // 4).
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4),
            nn.GroupNorm(8, 2 * n_feat),
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Output head: maps back to the number of input channels
        self.out = nn.Sequential(
            nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),
            nn.GroupNorm(8, n_feat),
            nn.ReLU(),
            nn.Conv2d(n_feat, self.in_channels, 3, 1, 1),
        )

    def forward(self, x, t, c=None):
        """Predict the noise contained in ``x``.

        Args:
            x: input images, (batch, in_channels, h, w).
            t: timestep tensor fed to ``EmbedFC(1, ...)``. The sampling code in this
               notebook passes the normalised timestep ``i / timesteps`` shaped (1, 1, 1, 1).
            c: optional context, (batch, n_cfeat). When None, an all-zero context is used.

        Returns:
            Tensor with ``in_channels`` channels: the predicted noise.
        """
        x = self.init_conv(x)
        down1 = self.down1(x)
        down2 = self.down2(down1)

        hiddenvec = self.to_vec(down2)

        if c is None:
            # Unconditional call: all-zero context on the same device/dtype as x.
            c = torch.zeros(x.shape[0], self.n_cfeat, device=x.device, dtype=x.dtype)

        # Reshape the embeddings to (batch-like, channels, 1, 1) so they act per channel.
        cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)

        up1 = self.up0(hiddenvec)
        # Context scales the features, timestep shifts them; skip connection from down2.
        up2 = self.up1(cemb1 * up1 + temb1, down2)
        # Same modulation at the next level; skip connection from down1.
        up3 = self.up2(cemb2 * up2 + temb2, down1)
        # Final skip connection with the initial features.
        out = self.out(torch.cat((up3, x), 1))
        return out

# Build the model for 3-channel images, using the network hyperparameters
# defined in the configuration cell, and move it to the selected device.
nn_model = ContextUnet(
    in_channels=3,
    n_feat=n_feat,
    n_cfeat=n_cfeat,
    height=height,
).to(device)



In [10]:
# Load the pre-trained weights. The default is the Colab location; set the
# MODEL_PATH environment variable to load a checkpoint from elsewhere.
model_path = os.environ.get("MODEL_PATH", '/content/context_model_trained.pth')
# NOTE(review): torch.load is pickle-based; only load checkpoints from a trusted
# source (or pass weights_only=True on torch versions that support it).
nn_model.load_state_dict(torch.load(model_path, map_location=device))
nn_model.eval()
print("Modèle pré-entraîné chargé.")

# Visualisations go next to the checkpoint; fall back to the current directory
# when model_path has no directory component.
save_dir = os.path.dirname(model_path) or "."

Modèle pré-entraîné chargé.


In [11]:
# Helper: one reverse-diffusion step (denoise, then optionally re-inject noise)
def denoise_add_noise(x, t, pred_noise, z=None):
    """Apply one DDPM reverse step at integer timestep ``t``.

    Computes
        (x - pred_noise * (1 - a_t[t]) / sqrt(1 - ab_t[t])) / sqrt(a_t[t]) + sqrt(b_t[t]) * z
    using the module-level schedule tensors ``b_t``, ``a_t`` and ``ab_t``.

    Args:
        x: current noisy sample.
        t: integer index into the schedule tensors.
        pred_noise: noise predicted by the model for ``x``.
        z: noise to re-inject. When None, a fresh standard-normal tensor shaped
           like ``x`` is drawn. Passing 0 removes the stochastic term.

    Returns:
        The sample for the previous timestep.
    """
    if z is None:
        z = torch.randn_like(x)
    alpha = a_t[t]
    scaled_eps = pred_noise * ((1 - alpha) / (1 - ab_t[t]).sqrt())
    denoised = (x - scaled_eps) / alpha.sqrt()
    return denoised + b_t[t].sqrt() * z

In [12]:
# 4. The sampling process
@torch.no_grad()
def sample_ddpm(n_sample, save_rate=20):
    """Draw ``n_sample`` images with the standard DDPM reverse process.

    Starting from Gaussian noise, walk the timesteps from ``timesteps`` down to 1.
    At each step, predict the noise with ``nn_model`` and apply ``denoise_add_noise``.
    Fresh Gaussian noise is re-injected at every step except the last (i == 1).

    Args:
        n_sample: number of images to generate.
        save_rate: keep a snapshot every ``save_rate`` steps. The first step
            (i == timesteps) and every step with i < 8 are always kept.

    Returns:
        (samples, intermediate): the final batch and a numpy array stacking the
        kept snapshots along a new leading axis.
    """
    x = torch.randn(n_sample, 3, height, height).to(device)
    snapshots = []
    for i in reversed(range(1, timesteps + 1)):
        print(f'sampling timestep {i:3d}', end='\r')
        # Normalised timestep, shaped (1, 1, 1, 1)
        t_norm = torch.tensor([i / timesteps])[:, None, None, None].to(device)
        # The very last step adds no extra noise
        extra_noise = 0 if i <= 1 else torch.randn_like(x)
        predicted_eps = nn_model(x, t_norm)
        x = denoise_add_noise(x, i, predicted_eps, extra_noise)
        keep = i % save_rate == 0 or i == timesteps or i < 8
        if keep:
            snapshots.append(x.detach().cpu().numpy())
    return x, np.stack(snapshots)

@torch.no_grad()
def sample_ddpm_incorrect(n_sample, save_rate=20):
    """Deliberately wrong DDPM sampler that never re-injects noise.

    Same loop as ``sample_ddpm``, except that ``z = 0`` is passed to
    ``denoise_add_noise`` at every step. This illustrates what happens when the
    stochastic term is dropped.

    Args:
        n_sample: number of images to generate.
        save_rate: keep a snapshot every ``save_rate`` steps. The first step
            (i == timesteps) and every step with i < 8 are always kept. The
            default of 20 matches the previously hard-coded value.

    Returns:
        (samples, intermediate): the final batch and a numpy array stacking the
        kept snapshots along a new leading axis.
    """
    samples = torch.randn(n_sample, 3, height, height).to(device)
    intermediate = []
    for i in range(timesteps, 0, -1):
        print(f'sampling timestep {i:3d}', end='\r')
        # Normalised timestep, shaped (1, 1, 1, 1)
        t = torch.tensor([i / timesteps])[:, None, None, None].to(device)
        z = 0  # the deliberate mistake: no extra noise is added
        eps = nn_model(samples, t)
        samples = denoise_add_noise(samples, i, eps, z)
        if i % save_rate == 0 or i == timesteps or i < 8:
            intermediate.append(samples.detach().cpu().numpy())
    intermediate = np.stack(intermediate)
    return samples, intermediate

In [14]:
# Visualise the standard DDPM sampler, which re-injects fresh noise at every
# step except the last one.
plt.clf()
samples, intermediate_ddpm = sample_ddpm(32)
# The second argument (32) mirrors the n_sample passed to sample_ddpm above.
animation_ddpm = plot_sample(
    intermediate_ddpm, 32, 4,
    save_dir, "ani_run", None,
    save=False,
)
HTML(animation_ddpm.to_jshtml())



<Figure size 640x480 with 0 Axes>

In [13]:



# Visualise the deliberately wrong sampler, which never re-injects noise
# (z = 0 at every step).
plt.clf()
samples, intermediate = sample_ddpm_incorrect(32)
# The second argument (32) mirrors the n_sample passed to sample_ddpm_incorrect above.
animation = plot_sample(
    intermediate, 32, 4,
    save_dir, "ani_run", None,
    save=False,
)
HTML(animation.to_jshtml())




<Figure size 640x480 with 0 Axes>