In [1]:
import os
import random

import matplotlib.pyplot as plt
import torch
import torchvision.transforms as transforms
from torch.nn import MSELoss
from torch.optim import Adam
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor

from models.ResidualUNet import ResidualUNet
from utils.RandomNoise import AddGaussianNoise
from utils.trainer.DenoisingTrainer import DenoysingTrainer

# U-Net Denoising

Train a residual U-Net to remove additive Gaussian noise from CIFAR-10 images.

In [2]:
# Root directory of the CIFAR-10 python archive. Set the CIFAR10_ROOT environment
# variable to point elsewhere instead of editing this machine-specific absolute path.
dataset_path = os.environ.get('CIFAR10_ROOT', 'C:/Users/M1074839/Documents/datasets/cifar-10-python')

# ToTensor alone is the whole pipeline, so no Compose wrapper is needed; samples
# come out as (CHW float tensor, label) pairs.
to_tensor = ToTensor()
train_dataset = CIFAR10(root=dataset_path, train=True, transform=to_tensor)
test_dataset = CIFAR10(root=dataset_path, train=False, transform=to_tensor)

### Size of an image

In [None]:
# A dataset item is an (image, label) pair; display the image tensor's shape.
sample_image, sample_label = train_dataset[0]
sample_image.shape

### Size of dataset

In [None]:
# Number of samples in each split (messages are in French: "training" / "test").
train_size = len(train_dataset)
test_size = len(test_dataset)
print(f'taille dataset d\'entrainement {train_size}')
print(f'taille dataset de test {test_size}')

### Illustration of the noise applied to our dataset

In [None]:
# Show one training image next to a noisy copy. The noise settings (mean 0, std 0.1)
# match the AddGaussianNoise(0., .1) used for training below.
sample_clean = train_dataset[0][0]
sample_noisy = AddGaussianNoise(mean=0, std=0.1)(sample_clean)

fig, axs = plt.subplots(nrows=1, ncols=2)
for ax, title, image in zip(axs, ('clean', 'noisy (std=0.1)'), (sample_clean, sample_noisy)):
    ax.axis('off')
    ax.set_title(title)
    # Additive noise may push values outside [0, 1]; clamp so imshow does not emit
    # a clipping warning (it would clip to this range for float RGB data anyway).
    ax.imshow(image.clamp(0, 1).permute(1, 2, 0))
plt.show()

In [None]:
# Training configuration for the residual U-Net denoiser
# (imports for MSELoss, Adam, DenoysingTrainer and ResidualUNet live in the top cell).

# Seed every RNG in use so weight init and random sampling are reproducible.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# NOTE(review): 5e-2 is very high for Adam (PyTorch's default is 1e-3). It is kept
# to preserve the original runs; confirm it does not destabilise training.
learning_rate = 5e-2
# Gaussian noise (mean 0, std 0.1) handed to the trainer; presumably it is applied to
# clean images to build the noisy inputs (TODO confirm in DenoysingTrainer).
noise = AddGaussianNoise(0., .1)
model = ResidualUNet(in_channels=3, depth=5, num_classes=3, task_name='denoising', dataset_name='CIFAR10')

# Named loss_fn (not `loss`) so that a later cell reusing the name `loss` for a
# metric value cannot clobber the criterion object.
loss_fn = MSELoss()
optimizer = Adam(params=model.parameters(), lr=learning_rate)

trainer = DenoysingTrainer(
    model=model,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    noise=noise,
    loss_fn=loss_fn,
    optimizer=optimizer,
    batch_size=64,
    save_best=True,
)

In [3]:
NUM_EPOCHS = 10  # number of training epochs
trainer.train(num_epochs=NUM_EPOCHS)

In [None]:
# Evaluate the trained model; trainer.evaluate() returns a (loss, PSNR) pair.
# Stored under distinct names so the global `loss` is not rebound to a float.
test_loss, test_psnr = trainer.evaluate()
print(f'loss: {test_loss:.4f}')
print(f'PSNR: {test_psnr:.4f}')

### Illustration of the denoising effect of our model

In [4]:
# Pick a random test image, corrupt it with Gaussian noise and denoise it with the model.
# Note: std=0.05 here is milder than the std=0.1 noise used during training.
# randrange excludes the upper bound; randint(0, len) could return len -> IndexError.
test_index = random.randrange(len(test_dataset))
random_test_image = test_dataset[test_index][0]
random_test_image_noised = AddGaussianNoise(mean=0, std=0.05)(random_test_image)

# Run on whichever device holds the model's weights instead of hard-coding 'cuda'.
device = next(model.parameters()).device
was_training = model.training
model.eval()  # inference mode for layers such as BatchNorm/Dropout, if the model has any
with torch.no_grad():  # display-only forward pass: no autograd graph needed
    model_predict_unnoised_image = model(random_test_image_noised.unsqueeze(0).to(device))
model.train(was_training)  # restore the mode the model was in before this cell

# One row of three square panels: width 21, height 7.
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(21, 7))
panels = (
    ('original', random_test_image),
    ('noisy (std=0.05)', random_test_image_noised),
    ('model output', model_predict_unnoised_image[0].cpu()),
)
for ax, (title, image) in zip(axs, panels):
    ax.axis('off')
    ax.set_title(title)
    # Noise and model outputs may leave [0, 1]; clamp to avoid imshow clipping warnings.
    ax.imshow(image.clamp(0, 1).permute(1, 2, 0).numpy())
plt.show()

## Denoising a purely random image

In [5]:
# Feed the model an image of pure Gaussian noise (no CIFAR-10 content) and apply it
# repeatedly, to see what structure the denoiser produces from nothing.
# Figure layout: row 0 shows the input noise as it is built up step by step;
# the following rows show the outputs of successive model applications.
NUM_NOISE_STEPS = 5
NUM_DENOISE_STEPS = 10
NOISE_IMAGE_SHAPE = (3, 256, 256)  # (channels, height, width)

# Run on whichever device holds the model's weights instead of hard-coding 'cuda',
# and sample directly on that device rather than creating on CPU and copying.
device = next(model.parameters()).device

# Each step adds fresh N(0, 1) noise to the previous one, so step k has std sqrt(k).
noise_steps = [torch.randn(NOISE_IMAGE_SHAPE, device=device)]
for noise_step in range(1, NUM_NOISE_STEPS):
    noise_steps.append(noise_steps[-1] + torch.randn(NOISE_IMAGE_SHAPE, device=device))

# Feed each output back in as the next input. This is pure inference: without
# no_grad, autograd would keep the graph of every chained pass alive in memory.
was_training = model.training
model.eval()
denoised_steps = []
with torch.no_grad():
    current_image = noise_steps[-1].unsqueeze(0)  # add the batch dimension
    for denoise_step in range(NUM_DENOISE_STEPS):
        current_image = model(current_image)
        denoised_steps.append(current_image[0])
model.train(was_training)  # restore the mode the model was in before this cell

# One row for the noise steps, then enough rows to hold every model output.
ncols = NUM_NOISE_STEPS
nrows = 1 + (NUM_DENOISE_STEPS + ncols - 1) // ncols
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(21, 7), squeeze=False)
panels = noise_steps + denoised_steps
for panel_index, ax in enumerate(axs.flat):
    ax.axis('off')
    if panel_index < len(panels):
        # Values fall well outside [0, 1] (noise std up to sqrt(NUM_NOISE_STEPS));
        # clamp so imshow does not warn (it would clip to this range anyway).
        ax.imshow(panels[panel_index].clamp(0, 1).permute(1, 2, 0).cpu().numpy())
plt.show()