In [None]:
!rm -r *
!git clone https://github.com/LudoRey/dip-deblur/
!mv dip-deblur/* ./
!rm -r dip-deblur

In [None]:
# `np` and `nn` are used by the cells below; import them explicitly instead of
# relying on them leaking out of the wildcard imports.
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from skimage.metrics import peak_signal_noise_ratio as psnr

from unet import *
from utils import *

# Run on the first CUDA GPU when PyTorch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Read image and create degraded version

# Seed the global RNGs so the simulated noise is identical on every run.
# `poisson_noise` comes from `utils`; which generator it draws from is not
# visible here, so both torch and NumPy are seeded — TODO confirm and drop
# the one that is not needed.
DATA_SEED = 0
torch.manual_seed(DATA_SEED)
np.random.seed(DATA_SEED)

# Clean reference image and blur kernel, loaded as tensors on the compute device.
im = read_img('data/airplane.tif', as_tensor=True).to(device)
ker = gaussian_kernel(sigma=1.2, as_tensor=True).to(device)

# Degradation: blur the clean image with the kernel, then apply Poisson noise.
im_blur = blur(im, ker)
im_blur_noisy = poisson_noise(im_blur, peak=1e4)

In [None]:
# Fit one Unet per architecture in `all_net_params`: each network maps a fixed
# random input to an image whose blurred version should match the noisy
# observation `im_blur_noisy`. The reconstruction is scored against the clean
# image `im` with PSNR, and the last score of each run is stored in `all_psnr`.

# Seed torch so the random input below is the same on every run
# (presumably this also fixes the Unet weight initialisation — depends on `unet`).
TRAIN_SEED = 0
torch.manual_seed(TRAIN_SEED)

nb_epoch = 2000
learning_rate = 0.01
log_every = 100  # print loss / PSNR once every `log_every` epochs
# NOTE: `input` shadows the Python builtin; the name is kept because other
# cells may reference it.
input = (torch.rand(1,32,128,128)/10).to(device)
criterion = nn.MSELoss()

all_net_params = [{'down_channels' : 128, 'up_channels' : 128, 'skip_channels' : 4, 'depth' : 4},
                  {'down_channels' : 96, 'up_channels' : 96, 'skip_channels' : 16, 'depth' : 4},
                  {'down_channels' : 64, 'up_channels' : 64, 'skip_channels' : 64, 'depth' : 4}]
# One PSNR slot per configuration; sized from the list so that adding or
# removing a configuration needs no other change.
all_psnr = np.zeros(len(all_net_params))

# The reference image never changes, so convert it once instead of at every log step.
im_ref = to_numpy(im)

for i, net_params in enumerate(all_net_params):
    # Create network
    model = Unet(in_channels=input.shape[1], out_channels=1, **net_params)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Optimizing loop
    model.train()
    for i_epoch in range(nb_epoch):
        optimizer.zero_grad()
        im_recon = model(input)
        # The loss is measured in the degraded domain: blur the reconstruction
        # and compare it with the noisy blurred observation.
        loss = criterion(blur(im_recon, ker), im_blur_noisy)
        loss.backward()
        optimizer.step()
        if (i_epoch+1) % log_every == 0:
            # Detach so the metric never touches the autograd graph.
            metric = psnr(im_ref, to_numpy(im_recon.detach()))
            print(' [-] epoch {:}/{:}, loss {:.6f}, psnr {:.5f}'.format(i_epoch+1, nb_epoch, loss.item(), metric))
    # Score of the last reconstruction produced for this configuration
    # (the output computed just before the final optimizer step).
    all_psnr[i] = psnr(im_ref, to_numpy(im_recon.detach()))