In [None]:
!rm -r *
!git clone https://github.com/LudoRey/dip-deblur/
!mv dip-deblur/* ./
!rm -r dip-deblur

In [None]:
import numpy as np
import torch
from torch import nn
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
# Project modules; presumably provide Unet, read_img, gaussian_kernel, blur,
# poisson_noise, to_numpy and CsiszarDiv used below — TODO replace the star
# imports with explicit names once confirmed.
from unet import *
from utils import *

# `np` and `nn` are imported explicitly above: later cells use them, and they
# should not depend on what the star imports happen to re-export.

# Seed every RNG the notebook relies on (random network input, Poisson noise,
# network initialisation) so Restart & Run All reproduces the results.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Read the ground-truth image and create blurred + Poisson-noise observations,
# one per peak level in `peaks`.
im = read_img('data/airplane.tif', as_tensor=True).to(device)
ker = gaussian_kernel(sigma=1.2, as_tensor=True).to(device)

im_blur = blur(im, ker)
peaks = [10**4, 10**3, 10**2]
all_im_blur_noisy = [poisson_noise(im_blur, peak=peak) for peak in peaks]

# One panel per peak level, laid out in a single row.
fig, axes = plt.subplots(1, len(peaks), figsize=(4 * len(peaks), 4))
for ax, peak, noisy in zip(np.atleast_1d(axes), peaks, all_im_blur_noisy):
    ax.imshow(to_numpy(noisy), cmap='gray')
    ax.set_title('peak = {}'.format(peak))
    # Hide ticks and tick labels: these are images, not plots.
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
plt.show()

In [None]:
# Deep-image-prior deconvolution. For every run i (noisy observation i, network
# config i) and every loss, a freshly initialised U-Net is fitted so that
# blur(net(z)) matches the noisy observation; z is a fixed random input and
# net(z) is the reconstruction.
# NOTE(review): run i varies BOTH the noise level (peaks[i]) and the network
# depth (all_net_params[i]), so their effects are confounded — confirm intended.
nb_epoch = 2000
learning_rate = 0.01
log_every = 100  # report loss / PSNR every `log_every` epochs
img_size = 128   # spatial size of the network input; the recon buffer assumes the same size

# Fixed random network input shared by every run, with values in [0, 0.1).
net_input = (torch.rand(1, 32, img_size, img_size) / 10).to(device)
criterions = [nn.MSELoss(), CsiszarDiv]
criterion_names = ['MSE', 'I-div']  # same order as `criterions` (used in the log)

all_net_params = [{'feature_channels': 128, 'skip_channels': 4, 'depth': depth}
                  for depth in (2, 4, 6)]
assert len(all_net_params) == len(all_im_blur_noisy), "need one network config per noisy observation"

n_runs, n_crit = len(all_net_params), len(criterions)
all_psnr = np.zeros((n_runs, n_crit))
all_im_recon = np.zeros((n_runs, n_crit, img_size, img_size))
im_np = to_numpy(im)  # ground truth as numpy, converted once for all PSNR evaluations

for i_run, (net_params, im_blur_noisy) in enumerate(zip(all_net_params, all_im_blur_noisy)):
    for j_crit, (criterion, crit_name) in enumerate(zip(criterions, criterion_names)):
        # Fresh network and optimizer for every (run, criterion) pair.
        model = Unet(in_channels=net_input.shape[1], out_channels=1, **net_params).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        model.train()
        for i_epoch in range(nb_epoch):
            optimizer.zero_grad()
            im_recon = model(net_input)
            # Fidelity is measured in the observation domain: blur the estimate first.
            loss = criterion(blur(im_recon, ker), im_blur_noisy)
            loss.backward()
            optimizer.step()
            if (i_epoch + 1) % log_every == 0:
                metric = psnr(im_np, to_numpy(im_recon.detach()))
                print(' [-] Run {:} ({}) epoch {:}/{:}, loss {:.6f}, psnr {:.5f}'.format(
                    i_run, crit_name, i_epoch + 1, nb_epoch, loss.item(), metric))
        # Keep the output of the last forward pass (computed before the final step).
        im_recon_np = to_numpy(im_recon.detach())
        all_psnr[i_run, j_crit] = psnr(im_np, im_recon_np)
        all_im_recon[i_run, j_crit] = im_recon_np

In [None]:
# Results grid. Rows: run i (noisy observation i, network config i).
# Columns: the observation y, then one reconstruction per loss, annotated with its PSNR.
col_titles = ['y', 'MSE', 'I-div']
n_rows = len(all_im_blur_noisy)
fig, axes = plt.subplots(n_rows, len(col_titles), figsize=(9, 3 * n_rows), squeeze=False)
for ax, title in zip(axes[0], col_titles):
    ax.set_title(title)  # was `set_title['I-div']`, which raised TypeError
for i in range(n_rows):
    # The observation is a torch tensor (possibly on the GPU): convert before imshow.
    axes[i, 0].imshow(to_numpy(all_im_blur_noisy[i]), cmap='gray')
    axes[i, 0].set_ylabel('peak={}\ndepth={}'.format(peaks[i], all_net_params[i]['depth']))
    for j in range(all_im_recon.shape[1]):
        axes[i, j + 1].imshow(all_im_recon[i, j], cmap='gray')
        axes[i, j + 1].set_xlabel('PSNR {:.2f}'.format(all_psnr[i, j]))
for ax in axes.flat:
    # Hide ticks and tick labels but keep the axis labels (row info / PSNR).
    ax.tick_params(axis='both', which='both', length=0)
    ax.set_xticklabels([])
    ax.set_yticklabels([])

plt.show()

save_fig = False  # set to True to export the figure
if save_fig:
    fig.savefig('net_depth.png', dpi=300, bbox_inches='tight', transparent=True)