In [None]:
!rm -r *
!git clone https://github.com/LudoRey/dip-deblur/
!mv dip-deblur/* ./
!rm -r dip-deblur

In [1]:
import torch
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr
from unet import *
from utils import *

import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The notebook is stochastic: degradation noise, the random network input and
# the weight init all draw random numbers. Seed the generators so
# Restart & Run All reproduces the same figures and metrics.
# NOTE(review): it is not visible here whether utils.poisson_noise draws from
# torch or numpy, so both are seeded. cuDNN kernels may still be
# nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)

In [None]:
# Read the clean image and synthesise the degraded observation: blur it with the
# kernel from gaussian_kernel(sigma=1.2), then apply poisson_noise(peak=1e4).
im = read_img('data/airplane.tif', as_tensor=True).to(device)
ker = gaussian_kernel(sigma=1.2, as_tensor=True).to(device)

im_blur = blur(im, ker)
im_blur_noisy = poisson_noise(im_blur, peak=1e4)

# Alternative experiment on real data (veil nebula + its PSF), displayed by the
# last cell of the notebook.
# NOTE(review): enabling these lines leaves `im` as the airplane image. The PSNR
# printed during optimisation and the ground-truth panels then no longer
# correspond to the observation being deblurred.
#im_blur_noisy = read_img('data/veil128.tif', as_tensor=True).to(device)
#ker = read_img('data/veil128_psf.tif', as_tensor=True).to(device)

# One row of three panels: a wide, short canvas avoids the large empty margins
# that a tall (20, 30) figure produced.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
panels = [
    (im, 'Ground truth'),
    (im_blur, 'Blurred'),
    (im_blur_noisy, 'Blurred + noisy (observation)'),
]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(to_numpy(img), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# Optimisation hyperparameters.
nb_epoch = 2000
learning_rate = 0.01

# Fixed random input to the network: uniform noise in [0, 0.1) with
# `input_channels` channels. Its height/width are taken from the observation
# instead of being hard-coded to 128x128. The loss compares
# blur(model(input)) with im_blur_noisy, so the spatial sizes must agree.
# NOTE(review): assumes the U-Net preserves spatial size and that the last two
# dims of im_blur_noisy are (H, W) — confirm. With depth 4 the size presumably
# has to be divisible by 2**4.
# (`input` shadows the builtin, but later cells refer to it by this name.)
input_channels = 32
input = (torch.rand(1, input_channels, *im_blur_noisy.shape[-2:]) / 10).to(device)
#input = im_blur_noisy  # alternative: feed the observation itself
net_params = {
    'down_channels': 128,
    'up_channels': 128,
    'skip_channels': 4,
    'depth': 4,
}

# Create network
model = Unet(in_channels=input.shape[1], out_channels=1, **net_params).to(device)

# Loss and optimizer. torch.nn is referenced explicitly instead of relying on the
# star imports to re-export `nn`.
criterion = torch.nn.MSELoss()
#criterion = CsiszarDiv()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Optimizing loop
model.train()
for i_epoch in range(nb_epoch):
    optimizer.zero_grad()
    im_recon = model(input)
    loss = criterion(blur(im_recon, ker), im_blur_noisy)
    loss.backward()
    optimizer.step()
    if (i_epoch+1) % 100 == 0:
        metric = psnr(to_numpy(im), to_numpy(im_recon))
        print(' [-] epoch {:}/{:}, loss {:.6f}, psnr {:.5f}'.format(i_epoch+1, nb_epoch, loss.item(), metric))

In [None]:
# Ground truth, observation and reconstruction side by side. The canvas is wide
# and short to suit a single row of panels.
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
panels = [
    (im, 'Ground truth'),
    (im_blur_noisy, 'Observation (blurred + noisy)'),
    (im_recon.detach(), 'Reconstruction'),
]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(to_numpy(img), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
# Consistency check: re-blur the reconstruction with the same kernel and compare
# it to the observation it was fitted to. No gradients are needed here.
with torch.no_grad():
    im_recon_blur = blur(im_recon, ker)

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
panels = [
    (im_blur_noisy, 'Observation'),
    (im_recon_blur, 'Reconstruction, re-blurred'),
]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(to_numpy(img), cmap='gray')
    ax.set_title(title)
    ax.axis('off')
# Explicit show() keeps the AxesImage repr out of the cell output.
plt.show()

In [None]:
# Display for the veil-nebula experiment (unstretched data). Observation and
# reconstruction are both passed through match_hist against a stretched
# reference image before display.
# NOTE(review): only meaningful when im_blur_noisy / ker were loaded from the
# veil files. With the default airplane input this matches the airplane to the
# nebula's histogram.
ref = read_img('data/veil128_stretched.tif', as_tensor=False)

fig, axes = plt.subplots(1, 2, figsize=(12, 6))
panels = [
    (im_blur_noisy, 'Observation (hist-matched)'),
    (im_recon.detach(), 'Reconstruction (hist-matched)'),
]
for ax, (img, title) in zip(axes, panels):
    ax.imshow(match_hist(to_numpy(img), ref), vmin=0, vmax=1, cmap='gray')
    ax.set_title(title)
    ax.axis('off')
# Explicit show() keeps the AxesImage repr out of the cell output.
plt.show()