In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import skimage as sk
from unet import *
from utils import *

In [None]:
# Load the reference image and derive its degraded versions: first blurred
# with the kernel returned by gaussian_kernel, then passed through poisson_noise.

im = read_img('data/airplane.tif', as_tensor=True)
ker = gaussian_kernel(sigma=1, as_tensor=True)

im_blur = blur(im, ker)
im_blur_noisy = poisson_noise(im_blur, peak=1e5)

# One panel per stage, left to right: original, blurred, blurred + noise.
fig, axes = plt.subplots(1, 3)
for ax, stage in zip(axes, (im, im_blur, im_blur_noisy)):
    ax.imshow(display(stage), cmap='gray')
plt.show()

In [None]:
# Optimisation settings.
nb_epoch = 100
learning_rate = 0.01

# Network input: the degraded observation itself.
# (Alternative kept for reference: a random tensor, torch.rand(1,32,128,128)/10.)
# NOTE: this name shadows the builtin input(); it is kept because the
# optimizing loop in a later cell refers to it.
input = im_blur_noisy

# Keyword arguments forwarded to the Unet constructor.
net_params = dict(
    down_channels=128,
    up_channels=128,
    skip_channels=4,
    depth=2,
)

# Build the network (input channel count taken from dimension 1 of the
# input tensor), the loss and the optimizer.
model = Unet(in_channels=input.shape[1], out_channels=1, **net_params)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Optimizing loop.
# The loss compares the blurred network output with the degraded observation;
# `error` (distance to the clean image `im`) is a monitoring metric only and
# never contributes to the parameter update.
model.train()
for i_epoch in range(nb_epoch):
    optimizer.zero_grad()
    im_recon = model(input)
    loss = criterion(blur(im_recon, ker), im_blur_noisy)
    loss.backward()
    optimizer.step()
    if (i_epoch+1) % 10 == 0:
        # Evaluate the metric only when it is reported, and without autograd,
        # so no unused computation graph is built for it on every epoch.
        with torch.no_grad():
            error = criterion(im_recon, im)
        # .item() on both values so plain Python floats are formatted.
        print(' [-] epoch {:}/{:}, loss {:.4f}, error {:.4f}'.format(i_epoch+1, nb_epoch, loss.item(), error.item()))

In [None]:

# Side-by-side comparison, left to right: original image, degraded
# observation, network reconstruction from the last loop iteration.
fig, axes = plt.subplots(1, 3, figsize=(20, 30))
panels = (im, im_blur_noisy, im_recon)
for ax, panel in zip(axes, panels):
    ax.imshow(display(panel), cmap='gray')
plt.show()

In [None]:
# Re-apply the blur to the reconstruction and compare it with the blurred
# original. The same `blur` helper that produced `im_blur` and that is used
# in the training loss is called here, so both panels go through one and the
# same forward operator.
im_recon_blur = blur(im_recon, ker)
fig, axes = plt.subplots(1, 2, figsize=(10, 20))
axes[0].imshow(display(im_blur), cmap='gray')
axes[1].imshow(display(im_recon_blur), cmap='gray')
plt.show()