# Deep Image Prior
## Lab session #3 2023
#### Julien.rabin (at) ensicaen.fr - 2023

![Logo](Ensicaen-logo.png) 

________________________________
### Last name : 
### First name : 
### Group :
### Date : 
________________________________

In this notebook, the goal is to implement the deep image prior optimization :
- [At the beginning](#0---import--load-image) some useful functions are given to load/read a pytorch tensor image
- [Deep Prior](#1---create-a-deep-neural-network):
    - create your own `deep image prior` by creating a neural network architecture, such as an MLP
    - experiment with various architectures
- [Random model initialisation](#2---test-the-model)
    - display a random image from the prior distribution
- [Train the model on the image](#3---optimization)
- [Comparison with patch-based denoising](#4---comparison-with-nl-pca-for-image-denoising)
- [Application to artefact reduction from image compression](#5---application-to-jpeg-artefact-reduction)
- [Application to super-resolution](#6---application-to-super-resolution)


________________________________
<a id='cell_0'></a>
## 0 - Import & Load image


In [None]:
# import packages
import matplotlib.pyplot  as plt
%matplotlib inline

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm

In [None]:
%env CUDA_VISIBLE_DEVICES=0
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
if device.type == 'cuda' : # a torch.device is compared through its .type attribute
    print("id of GPU :", torch.cuda.current_device())
    print("number of GPU available :", torch.cuda.device_count())
    print(torch.cuda.get_device_name(0))

In [None]:
# Utility functions for the algorithm

def Tensor_display(img_torch,IMG_NORMALIZE = False, figsize=(10,10)) : # display images coded as torch tensor
    img_np = img_torch.squeeze(0).permute(1, 2, 0).cpu().numpy() #is an array, shaped as an image for plt with permute

    if IMG_NORMALIZE : # image is in [-1,1]
        img_np = img_np/2.+.5 # now in [0,1]

    plt.figure(figsize=figsize)
    plt.imshow(img_np, interpolation="bicubic")
    plt.axis('off')
    plt.show()

def Tensor_load(file_name, IMG_NORMALIZE = False) : # load an image as a torch tensor : BATCHSIZE=1 x COLOR=3 x Height x Width
    img_np0 = plt.imread(file_name)
    if img_np0.max()>1 : # normalization to handle the range mismatch between PNG (in [0,1]) and JPG (in [0,255])
        img_np0 = img_np0/img_np0.max()

    if IMG_NORMALIZE :
        img_np0 = (img_np0-0.5)*2. # now in [-1,1]

    img_torch = torch.tensor(img_np0, dtype = torch.float, device=device, requires_grad = False).permute(2, 0, 1).unsqueeze(0)
    if img_torch.size(1) == 4 : # remove transparent layer
        img_torch = img_torch[:,:3,:,:]
    return img_torch

In [None]:
data_rep = 'img/' #'res/'
! ls $data_rep/
im_name  =  'butterfly.png' #

In [None]:
save_rep = 'save/'

file_name = data_rep + im_name

IMG_NORMALIZE = True # if True [-1,1], else in [0,1]

## load image as a Tensor
img_torch = Tensor_load(file_name, IMG_NORMALIZE)
print(img_torch.shape)

## resize (required for large images)
imsize = 256 # desired input/output resolution

img_torch = nn.functional.adaptive_avg_pool2d(img_torch, imsize).to(device)
img_torch0 = torch.clone(img_torch).to(device)

img_torch += torch.randn_like(img_torch) * 80./255.

C = img_torch.size(1) # channel dim = 3
H = img_torch.size(2) # height
W = img_torch.size(3) # width

# display image
Tensor_display(img_torch0.detach(), IMG_NORMALIZE, figsize=(5,5))
print(img_torch0.shape)
Tensor_display(img_torch.detach(), IMG_NORMALIZE, figsize=(5,5))
img_torch.shape

## 1 - Create a Deep Neural Network

- complete the following code to design your generative network
- several architectures can be considered (MLP, ConvNet, U-Net, ResNet, etc.)
- the input is a (fixed) random latent variable of arbitrary size (batch size is 1)
- the output image should match the desired tensor resolution, e.g. 1 x 3 x 256 x 256

In [None]:
class create_Generator(nn.Module) :
    def __init__(self, latent_dim=16, hidden_dim=16, out_res=256):
        super().__init__()

        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.out_res = out_res

        self.conv1 = nn.Conv2d(latent_dim, hidden_dim, 5, stride=1, padding=2)
        self.act = nn.ReLU()
        self.conv2 = nn.Conv2d(hidden_dim, 3, 5, stride=1, padding=2)
        

    def forward(self,x) :
        x = self.conv1(x)
        x = self.act(x)
        out = self.conv2(x)
        
        return out

In [None]:
# print the number of parameters
def print_nb_params(model) :
    n = 0
    for x in model.parameters() :
        #print("layer : ", x)
        n+= x.numel()
    print("total nb of parameters :", n)


## 2 - Test the model

In [None]:
latent_dim = 3
hidden_dim = 16
out_res = imsize
model = create_Generator(latent_dim=latent_dim, hidden_dim=hidden_dim, out_res=out_res)

model = model.to(device)
#print(model.parameters)
print_nb_params(model)

In [None]:
BATCH_SIZE = 1 # should be one, the dataset is reduced to 1 image, but a batch of random inputs can be used
z0 = torch.randn((BATCH_SIZE,latent_dim,out_res,out_res), device=device, requires_grad = False)
z = z0.clone()
out = model(z)
out.shape

In [None]:
# display image
Tensor_display(out.detach()[0], IMG_NORMALIZE, figsize=(3,3))


## 3 - Optimization

In [None]:
criterion = torch.nn.MSELoss()
out = model(z)
loss = criterion(out,img_torch)
print(f" loss = |(output) - noisy_data|^2 = {loss.detach().cpu().numpy()}")

In [None]:
def eval_psnr(im, ref, IMG_NORMALIZE = False) :
    if IMG_NORMALIZE is True : # images in [-1,1]
        return 0
    else : # images in [0,1]
        return "to be completed"

out = model(z)
PSNR = eval_psnr(out, img_torch0, IMG_NORMALIZE)
print(f" PSNR of output vs ground truth = {PSNR.detach().cpu().numpy()} dB")
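
As a starting point for completing the PSNR function, here is a minimal sketch assuming the usual definition $\mathrm{PSNR} = 10 \log_{10}(R^2 / \mathrm{MSE})$, where $R$ is the dynamic range of the data ($R=2$ for images in $[-1,1]$, $R=1$ for images in $[0,1]$). The name `psnr_sketch` and the toy tensors are hypothetical and not part of the lab code.

```python
import torch

def psnr_sketch(im, ref, IMG_NORMALIZE=False):
    # dynamic range: 2 for images in [-1, 1], 1 for images in [0, 1] (assumption)
    data_range = 2.0 if IMG_NORMALIZE else 1.0
    mse = torch.mean((im - ref) ** 2)
    return 10.0 * torch.log10(data_range ** 2 / mse)

# toy check: a constant error of 0.1 gives MSE = 0.01
ref = torch.zeros(1, 3, 8, 8)
im = ref + 0.1
psnr_01 = psnr_sketch(im, ref, IMG_NORMALIZE=False)
psnr_11 = psnr_sketch(im, ref, IMG_NORMALIZE=True)
print(f"PSNR (range 1) = {psnr_01.item():.2f} dB, PSNR (range 2) = {psnr_11.item():.2f} dB")
```

The function returns a tensor, so it can be used with `.detach().cpu().numpy()` as in the cell above.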

In [None]:
optimiser = torch.optim ...

Loss = []
PSNR = []

In [None]:
niter = int(1e2)
for it in tqdm(range(niter)) :
    # sample the model
    out = ...

    # optimization of the NN
    loss = ...

    # save loss & psnr
    Loss = ...
    PSNR = ...
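
The loop above can be filled in several ways. The following is only a sketch of one possibility, assuming an Adam optimiser with a learning rate of `1e-2` (both are arbitrary choices, not prescribed by the lab). It runs on small stand-in tensors: `net`, `z_fixed`, `target` and `loss_history` are hypothetical names playing the roles of `model`, `z`, `img_torch` and `Loss`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# small stand-in for the generator of section 1 (assumption)
net = nn.Sequential(
    nn.Conv2d(3, 16, 5, padding=2),
    nn.ReLU(),
    nn.Conv2d(16, 3, 5, padding=2),
)
z_fixed = torch.randn(1, 3, 32, 32)  # fixed random input, never updated

# stand-in for the noisy image: a smooth ramp in [-1, 1] plus Gaussian noise
ramp = torch.linspace(-1, 1, 32).view(1, 1, 1, 32).expand(1, 3, 32, 32)
target = ramp + 0.3 * torch.randn(1, 3, 32, 32)

criterion = nn.MSELoss()
optimiser = torch.optim.Adam(net.parameters(), lr=1e-2)

loss_history = []
for it in range(50):
    optimiser.zero_grad()          # reset gradients from the previous step
    out = net(z_fixed)             # sample the model on the fixed input
    loss = criterion(out, target)  # data-fitting term on the noisy image
    loss.backward()                # back-propagate through the network
    optimiser.step()               # update the network weights only
    loss_history.append(loss.item())

print(f"first loss = {loss_history[0]:.4f}, last loss = {loss_history[-1]:.4f}")
```

In the notebook, a PSNR value computed against `img_torch0` could be appended to a second list at each iteration, so that both curves can be plotted afterwards.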


In [None]:
# display images
print('output of generative network')
out = model(z0)
Tensor_display(out.detach()[0], IMG_NORMALIZE, figsize=(4,4))
print('noisy data')
Tensor_display(img_torch.detach(), IMG_NORMALIZE, figsize=(4,4))
print('ground truth')
Tensor_display(img_torch0.detach(), IMG_NORMALIZE, figsize=(4,4))

- plot the training error on the dataset and the evaluation on the ground truth to detect overfitting
- for upsampling (if needed in the network) compare the results obtained with `nn.Upsample` (with `mode='nearest'`, `'bilinear'`, `'bicubic'` ...) and with `nn.ConvTranspose2d`
- experiment with various architectures
- for a given architecture, examine the role of the number of layers, the size of the latent space, the number of hidden dimensions ...
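
To make the upsampling comparison concrete, the sketch below applies the interpolation-based and the learned options to a hypothetical feature map of size 1 x 8 x 16 x 16 and checks that each one doubles the resolution. The channel count and the transposed-convolution settings (kernel 4, stride 2, padding 1) are illustrative assumptions.

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 16, 16)  # hypothetical low-resolution feature map

# fixed interpolation: no trainable parameters
up_nearest = nn.Upsample(scale_factor=2, mode='nearest')
up_bilinear = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
up_bicubic = nn.Upsample(scale_factor=2, mode='bicubic', align_corners=False)

# learned upsampling: output size = (16 - 1) * 2 - 2 * 1 + 4 = 32
up_learned = nn.ConvTranspose2d(8, 8, kernel_size=4, stride=2, padding=1)

shapes = {
    'nearest': tuple(up_nearest(x).shape),
    'bilinear': tuple(up_bilinear(x).shape),
    'bicubic': tuple(up_bicubic(x).shape),
    'convtranspose': tuple(up_learned(x).shape),
}
n_params_learned = sum(p.numel() for p in up_learned.parameters())

for name, shape in shapes.items():
    print(name, shape)
print("trainable parameters in ConvTranspose2d :", n_params_learned)
```

Since `nn.Upsample` has no weights, it is usually followed by a convolution in a generator, whereas `nn.ConvTranspose2d` learns the interpolation kernel itself.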

________________________________
## 4 - Comparison with NL-PCA for image denoising


- compare the denoising results for different levels of noise

________________________________
## 5 - Application to JPEG artefact reduction

- modify the previous code to train a network to remove artefacts from JPEG compression
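
One way to produce the degraded input for this experiment is to round-trip the image through an in-memory JPEG file. The sketch below assumes Pillow is installed and uses a hypothetical helper `jpeg_degrade` with an arbitrary quality factor of 10. The random array only stands in for a real image.

```python
import io
import numpy as np
from PIL import Image

def jpeg_degrade(img_uint8, quality=10):
    """Round-trip an H x W x 3 uint8 array through in-memory JPEG compression."""
    buffer = io.BytesIO()
    Image.fromarray(img_uint8).save(buffer, format='JPEG', quality=quality)
    buffer.seek(0)
    return np.array(Image.open(buffer).convert('RGB'))

rng = np.random.default_rng(0)
clean = rng.integers(0, 256, size=(64, 64, 3), dtype=np.uint8)  # stand-in image
compressed = jpeg_degrade(clean, quality=10)
print(compressed.shape, compressed.dtype)
```

Under this reading of the exercise, the compressed image, once converted to a tensor and rescaled as in `Tensor_load`, would replace the noisy `img_torch` as the optimization target.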

________________________________
## 6 - Application to super-resolution

- create a function which downsamples the image by a factor K (e.g. K=4), using for instance `nn.AvgPool2d`
- modify the previous code to train a network to solve the corresponding inverse problem (generate an image $K$ times larger)
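
The two steps above can be sketched as follows, assuming the loss is computed between the downsampled network output and the low-resolution observation. All names (`downsample`, `hr_reference`, `lr_observed`, `net`, `z_fixed`) and the 64 x 64 size are hypothetical stand-ins for the notebook variables.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

K = 4
downsample = nn.AvgPool2d(kernel_size=K, stride=K)  # averages K x K blocks

hr_reference = torch.rand(1, 3, 64, 64)   # stand-in for the high-resolution image
lr_observed = downsample(hr_reference)    # low-resolution data, 1 x 3 x 16 x 16

# small stand-in generator producing a high-resolution estimate (assumption)
net = nn.Sequential(
    nn.Conv2d(3, 16, 5, padding=2),
    nn.ReLU(),
    nn.Conv2d(16, 3, 5, padding=2),
)
z_fixed = torch.randn(1, 3, 64, 64)

criterion = nn.MSELoss()
hr_estimate = net(z_fixed)
# the data term compares the *downsampled* output with the observation
loss = criterion(downsample(hr_estimate), lr_observed)
loss.backward()

print(lr_observed.shape, hr_estimate.shape, loss.item())
```

Because average pooling is differentiable, the gradient flows through `downsample` to the network weights, so the same training loop as in section 3 can be reused with this modified loss.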