# Diffusion experiments

## Imports

In [1]:
from torchvision import datasets
import tqdm
import os
import numpy as np
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from unet import Unet
from denoiser import Diffusion

### Plot utilities

In [2]:
def show_images(original, noisy, reconstructed):
    """Display three CHW image tensors side by side: original, noisy, reconstructed.

    Each argument is a (C, H, W) tensor whose values are meant to lie in [0, 1]
    (the callers in this notebook map from [-1, 1] with ``(1 + x) / 2``).
    Tensors are detached and moved to the CPU, so GPU tensors and tensors that
    require grad are accepted. Values are clipped to [0, 1] before display.

    Args:
        original: the clean reference image.
        noisy: the noised input image.
        reconstructed: the image produced by the reverse diffusion process.
    """
    titles = ("Image Originale", "Image Bruitée", "Image Reconstituée")
    # CHW -> HWC for imshow. imshow would clip out-of-range RGB floats anyway;
    # clipping here gives the same pixels without matplotlib's warning.
    images = [
        np.clip(np.transpose(t.detach().cpu().numpy(), (1, 2, 0)), 0, 1)
        for t in (original, noisy, reconstructed)
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, image, title in zip(axes, images, titles):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")

    plt.show()

## Experiment: CelebA

In [None]:

# Set up the data loader. ToTensor gives values in [0, 1]; the Lambda rescales
# them to [-1, 1] before resizing to 64x64.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: 2 * x - 1),
    transforms.Resize((64, 64)),
])
# train_dataset = datasets.CIFAR10(root="data", train=True, transform=transform, download=True)
train_dataset = datasets.CelebA(root="data", transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)


def gene():
    """Yield the image tensors of each ``train_loader`` batch, with a tqdm progress bar.

    Each call starts a new pass over ``train_loader``; only element 0 of each
    batch (the images) is yielded.
    """
    for batch in tqdm.tqdm(train_loader):
        yield batch[0]


# Initialize the diffusion process and the model.
timesteps = 1000
t_reconst = 200  # timestep used to noise the demo image and to start reverse sampling from
num_epochs = 10
# Fall back to the CPU when no GPU is available instead of failing on .to("cuda").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Unet(channels=64).to(device)
nb_params = sum(p.numel() for p in model.parameters())
print(f"nb params : {nb_params}")
diffusion = Diffusion(model, timesteps=timesteps, device=device)

# Keep one image, shape (1, C, H, W), for the demos below. Iterating the loader
# directly avoids leaving a half-finished tqdm bar behind.
full_image = next(iter(train_loader))[0][:1]
# Train the model.
diffusion.train(gene, num_epochs=num_epochs)

# Sample a new image from scratch. The model is trained on [-1, 1] data, so map
# the sample back to [0, 1] (as done for show_images below) and go CHW -> HWC.
# NOTE(review): assumes sampling returns a tensor shaped like `shape`, (1, C, H, W).
shape = full_image.shape
img_gen = diffusion.sampling(shape)
generated = ((1 + img_gen[0]) / 2).permute(1, 2, 0).detach().cpu().numpy()
plt.imshow(np.clip(generated, 0, 1))
plt.show()

# Reconstruction demo: noise the kept image up to step t_reconst, then run the
# reverse process starting from that step.
noisy_img = full_image.clone()
# noisy_img[0, :, :] = 0
noisy_img, _ = diffusion.forward_diffusion(noisy_img, t_reconst)
noisy_img = noisy_img.squeeze(0)
reconstructed_full_image = diffusion.sampling(shape=None, xT=noisy_img, T=t_reconst).squeeze(0)
# reconstructed_full_image = diffusion.sampling(shape = noisy_img.shape, T=t_reconst).squeeze(0).clip(0., 1.)
# reconstructed_full_image = diffusion.sampling(xT=reconstructed_full_image, T=20).squeeze(0).clip(0., 1.)

# Display the results, mapping [-1, 1] to [0, 1].
show_images((1 + full_image[0]) / 2, (1 + noisy_img) / 2, (1 + reconstructed_full_image) / 2)

torch.save(model.state_dict(), "model_latest.pt")

## Experiment: H&E liver histology patches

In [None]:

# Fix both global RNGs: torch's (weight init, DataLoader shuffling) and NumPy's
# (used for the random patch sampling below).
torch.manual_seed(42)
np.random.seed(42)
# Dataset that extracts random patches from a folder of images.
class HistoLiverPatchDataset(Dataset):
    """Random square patches sampled from a folder of .jpg/.png images.

    Every item is a freshly drawn patch. An image is picked uniformly at random
    with NumPy's global RNG, then a ``patch_size`` x ``patch_size`` window is cut
    at a uniformly random position. ``idx`` is ignored, and ``len(dataset)`` is
    ``nb_steps`` (the number of patches that make up one epoch).

    Args:
        image_dir: directory scanned (non-recursively) for ``.jpg``/``.png`` files.
        nb_steps: value returned by ``__len__``.
        patch_size: side length of the square patches, in pixels.
        transform: stored as ``self.transform`` but not applied. ``__getitem__``
            always converts with ``ToTensor`` and rescales to [-1, 1].

    Raises:
        ValueError: if ``image_dir`` contains no matching image, or (from
            ``__getitem__``) if a picked image is smaller than ``patch_size``.
    """

    def __init__(self, image_dir, nb_steps=1000, patch_size=64, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.patch_size = patch_size
        # Sorted so that, with a fixed seed, image indices map to the same files
        # regardless of the filesystem's listing order.
        self.images = [
            os.path.join(image_dir, name)
            for name in sorted(os.listdir(image_dir))
            if name.endswith(('.jpg', '.png'))
        ]
        if not self.images:
            raise ValueError(f"no .jpg/.png images found in {image_dir!r}")
        # Decoded RGB PIL images held in memory (not NumPy arrays, despite the name).
        self.np_images = [self._load_rgb(path) for path in self.images]
        self.nb_img = len(self.images)
        self.nb_steps = nb_steps

    @staticmethod
    def _load_rgb(path):
        """Load the image at ``path`` as RGB and close the underlying file."""
        with Image.open(path) as img:
            # convert() loads the pixels and returns an independent image,
            # so the file can be closed on exit.
            return img.convert("RGB")

    def __len__(self):
        return self.nb_steps

    def __getitem__(self, idx):
        """Return a random (3, patch_size, patch_size) float patch scaled to [-1, 1]."""
        image = self.np_images[np.random.randint(self.nb_img)]
        img_width, img_height = image.size

        max_row = img_height - self.patch_size
        max_col = img_width - self.patch_size
        if max_row < 0 or max_col < 0:
            raise ValueError(
                f"image of size {img_width}x{img_height} is smaller than "
                f"patch_size={self.patch_size}"
            )
        # randint's upper bound is exclusive: +1 lets the patch touch the
        # bottom/right border and lets images exactly patch_size wide work.
        row = np.random.randint(0, max_row + 1)
        col = np.random.randint(0, max_col + 1)

        # Crop before converting, so only the patch (not the whole slide) is
        # turned into a tensor. The values are identical to cropping afterwards.
        patch = image.crop((col, row, col + self.patch_size, row + self.patch_size))
        return 2 * transforms.ToTensor()(patch) - 1

# Transform handed to the dataset.
# NOTE(review): HistoLiverPatchDataset stores `transform` but its __getitem__
# converts with ToTensor and rescales to [-1, 1] itself. Confirm this is intended.
transform = transforms.Compose([
    transforms.ToTensor()
])

# Build the patch dataset.
image_dir = '../../dataset/data/liver_HE'  # make sure this folder contains your images
patch_size = 64
batch_size = 64
nb_patch_per_epoch = 6400  # dataset length, i.e. patches drawn per epoch
dataset = HistoLiverPatchDataset(nb_steps=nb_patch_per_epoch, image_dir=image_dir, patch_size=patch_size, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Model and diffusion process.
timesteps = 1000
t_reconst = 100  # timestep used to noise the demo patch and to start reverse sampling from
# Fall back to the CPU when no GPU is available instead of failing on .to("cuda").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Unet(channels=100).to(device)
nb_params = sum(p.numel() for p in model.parameters())
print(device)
print(f"nb params : {nb_params}")
diffusion = Diffusion(model, timesteps=timesteps, beta_end=2e-2, device=device)
diffusion.train(dataloader, lr=1e-3, num_epochs=100)

# Draw one patch and show its original, noisy and reconstructed versions.
full_image = dataset[0][:, :patch_size, :patch_size]
noisy_img = full_image.clone()
# noisy_img[0, :, :] = 0
noisy_img, _ = diffusion.forward_diffusion(noisy_img, t_reconst)
noisy_img = noisy_img.squeeze(0)
reconstructed_full_image = diffusion.sampling(shape=None, xT=noisy_img, T=t_reconst).squeeze(0)
# reconstructed_full_image = diffusion.sampling(shape = noisy_img.shape, T=t_reconst).squeeze(0).clip(0., 1.)
# reconstructed_full_image = diffusion.sampling(xT=reconstructed_full_image, T=20).squeeze(0).clip(0., 1.)

# Display the results. Patches are in [-1, 1], so map them to [0, 1].
show_images((1 + full_image) / 2, (1 + noisy_img) / 2, (1 + reconstructed_full_image) / 2)

cuda
nb params : 11434115
Epoch 1/100, Loss: 0.2838, nb_step: 100
Epoch 2/100, Loss: 0.1674, nb_step: 100
Epoch 3/100, Loss: 0.1242, nb_step: 100
Epoch 4/100, Loss: 0.08726, nb_step: 100
Epoch 5/100, Loss: 0.0986, nb_step: 100
Epoch 6/100, Loss: 0.09496, nb_step: 100
Epoch 7/100, Loss: 0.07203, nb_step: 100
Epoch 8/100, Loss: 0.06779, nb_step: 100
Epoch 9/100, Loss: 0.06251, nb_step: 100
Epoch 10/100, Loss: 0.07291, nb_step: 100
Epoch 11/100, Loss: 0.07452, nb_step: 100
Epoch 12/100, Loss: 0.09058, nb_step: 100
Epoch 13/100, Loss: 0.05127, nb_step: 100
Epoch 14/100, Loss: 0.06355, nb_step: 100
Epoch 15/100, Loss: 0.08679, nb_step: 100
Epoch 16/100, Loss: 0.06474, nb_step: 100
Epoch 17/100, Loss: 0.05874, nb_step: 100
Epoch 18/100, Loss: 0.05672, nb_step: 100
Epoch 19/100, Loss: 0.06628, nb_step: 100
