In [None]:
!rm -r *
!git clone https://github.com/LudoRey/dip-deblur/
!mv dip-deblur/* ./
!rm -r dip-deblur

In [None]:
import os

import matplotlib.pyplot as plt
import torch
from skimage.metrics import peak_signal_noise_ratio as psnr

from unet import *
from utils import *

# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Load the reference image and derive its degraded counterpart: the image is
# passed through `blur` with the kernel from `gaussian_kernel`, then through
# `gaussian_noise`. All tensors are placed on `device`.
im = read_img('data/airplane.tif', as_tensor=True).to(device)
ker = gaussian_kernel(sigma=1.2, as_tensor=True).to(device)

im_blur = blur(im, ker)
im_blur_noisy = gaussian_noise(im_blur, sigma=0.1)

# Show the reference (left panel) next to the degraded image (right panel).
fig, axes = plt.subplots(1, 2, figsize=(5, 10))
for ax, picture in zip(axes, (im, im_blur_noisy)):
    ax.imshow(to_numpy(picture), cmap='gray')
    # Zero-length ticks and empty labels: only the picture itself is shown.
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
plt.show()

In [None]:
# Fit a freshly created `Unet` to the degraded observation: the network maps a
# fixed random tensor to an image, and its weights are optimised so that the
# blurred network output matches `im_blur_noisy`.
nb_epoch = 10
learning_rate = 0.01
# A progress line is printed every `log_every` iterations. With the values
# above (10 iterations, interval 100) nothing is printed; lower `log_every`
# or raise `nb_epoch` to see progress.
log_every = 100
# Fixed network input drawn uniformly from [0, 0.1). NOTE: this name shadows
# the builtin `input`; it is left unchanged in case other cells refer to it.
input = (torch.rand(1,32,128,128)/10).to(device)
criterion = nn.MSELoss()

net_params = {'feature_channels' : 128, 'skip_channels' : 4, 'depth' : 4}

# Per-iteration history, read by the plotting cell.
psnr_tracker = []
loss_tracker = []

# Create network
model = Unet(in_channels=input.shape[1], out_channels=1, **net_params)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Optimizing loop
model.train()
for i_epoch in range(nb_epoch):
    optimizer.zero_grad()
    im_recon = model(input)
    # The loss compares in the degraded domain: the reconstruction is blurred
    # with the same kernel before being compared with the observation.
    loss = criterion(blur(im_recon, ker), im_blur_noisy)
    loss.backward()
    optimizer.step()

    # PSNR against the clean image is used for monitoring only (it does not
    # enter the loss). Compute it once per iteration and reuse it for both the
    # history and the log line.
    metric = psnr(to_numpy(im), to_numpy(im_recon))
    psnr_tracker.append(metric)
    loss_tracker.append(loss.item())
    if (i_epoch+1) % log_every == 0:
        # The previous message formatted an undefined `i_run`, which raised
        # NameError the first time this branch was reached.
        print(' [-] epoch {:}/{:}, loss {:.6f}, psnr {:.5f}'.format(i_epoch+1, nb_epoch, loss.item(), metric))

In [None]:
# Plot the optimisation history on twin y-axes sharing one x-axis:
# MSE loss on the left axis, PSNR on the right axis, each colour-coded to
# match its curve.
fig, ax1 = plt.subplots(1,1, figsize=(5,3))

ax1.plot(loss_tracker, color='C0')
ax1.set_xlabel('Itérations')
ax1.set_ylabel('MSE', color='C0')
ax1.tick_params(axis='y', labelcolor='C0', color='C0')
ax2 = ax1.twinx()
ax2.plot(psnr_tracker, color='C3')
ax2.set_ylabel('PSNR', color='C3')
ax2.tick_params(axis='y', labelcolor='C3', color='C3')

# Create the output directory if needed so savefig cannot fail on a missing
# folder, and write the file before displaying so the saved image does not
# depend on how the backend treats figures after show().
os.makedirs('figs', exist_ok=True)
fig.savefig('figs/poisson.png', dpi=300, bbox_inches='tight', transparent=True)
plt.show()