Code for **deconvolution**: recovering sharp images from copies blurred with Gaussian, motion and defocus kernels.

You can change the parameters below and see how they affect the result.

# Import libraries

In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import argparse
import os
# Expose only GPU 0 to this process; this has to run before torch initializes CUDA.
os.environ.update({'CUDA_VISIBLE_DEVICES': '0'})

import numpy as np
from models import *

import torch
import torch.optim

from skimage.measure import compare_psnr
from models.convolution import Convolution

from utils.sr_utils import *

# Let cuDNN benchmark and cache the fastest convolution algorithms; the network
# input keeps the same shape for all iterations of a run.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Tensor type used for every model and input below. Fall back to CPU tensors
# when no GPU is visible instead of failing on the first `.type(dtype)` call.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor


# Set up parameters (the network itself is built per image/kernel pair in the experiment loop)

In [None]:
# --- network input / architecture -------------------------------------------
input_depth = 32            # channels of the input tensor produced by get_noise
INPUT = 'noise'             # input type passed to get_noise
pad = 'reflection'          # padding mode passed to get_net
OPT_OVER = 'net'            # what get_params collects for the optimizer
KERNEL_TYPE = 'udf'         # kernel_type for Convolution; presumably "user-defined", read from kernel_path -- confirm

# --- optimization -----------------------------------------------------------
OPTIMIZER, LR, num_iter = 'adam', 0.01, 5000
tv_weight = 0.0             # weight of the TV term in closure; 0.0 disables it

# --- display ----------------------------------------------------------------
PLOT = True                 # closure plots the current estimate every 100 iterations

# Define the optimization closure, then run it over every image/kernel pair

In [None]:
def closure():
    """Run one forward/backward pass of the deconvolution fit and log progress.

    Called once per iteration by the project helper `optimize`. Reads the
    per-run globals prepared in the experiment loop (net, net_input_saved,
    noise, convolution, mse, img_LR_var, img_lr_np, img_hr_np, mse_arr,
    psnr_history) and updates the globals `i`, `net_input` and `info_dict`.

    Returns:
        The total loss tensor (data term plus optional TV term) after
        `.backward()` has been called on it.
    """
    global i, net_input, info_dict

    # Perturb the fixed network input with fresh noise at every step.
    reg_noise_std = 0.01
    net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out_HR = net(net_input)
    # Re-blur the current estimate with the known kernel so it can be compared
    # against the observed blurred image.
    out_LR = convolution(out_HR)

    total_loss = mse(out_LR, img_LR_var)
    # Record the data term as a plain Python float: writing the graph-attached
    # (and possibly CUDA) tensor into a float64 numpy array goes through
    # Tensor.__array__, which refuses tensors that require grad / live on GPU.
    mse_arr[i] = total_loss.item()

    if tv_weight > 0:
        # Out-of-place add keeps the recorded data-term tensor untouched.
        total_loss = total_loss + tv_weight * tv_loss(out_HR)

    total_loss.backward()

    # Convert once and reuse for logging, plotting and the final snapshot.
    out_LR_np = torch_to_np(out_LR)
    out_HR_np = torch_to_np(out_HR)

    # Log
    psnr_LR = compare_psnr(img_lr_np, out_LR_np)
    psnr_HR = compare_psnr(img_hr_np, out_HR_np)
    print('Iteration %05d    PSNR_LR %.3f   PSNR_HR %.3f' % (i, psnr_LR, psnr_HR), '\r', end='')

    # History
    psnr_history.append([psnr_LR, psnr_HR])

    if PLOT and i % 100 == 0:
        plot_image_grid([img_hr_np, np.clip(out_HR_np, 0, 1)], factor=13, nrow=3)

    i += 1
    if i == num_iter:
        # Snapshot the loss curve and the final restored image for this run.
        info_dict = {'mse': mse_arr, 'X': out_HR_np}

    return total_loss

In [None]:
# Restore every (photo, blur kernel) pair whose blurred image exists on disk and
# collect each run's final snapshot in result[photo][kernel]['info'].
#
# Everything below deliberately stays at module level: `closure` reads the
# per-run objects (net, net_input, net_input_saved, noise, convolution, mse,
# img_LR_var, img_lr_np, img_hr_np, mse_arr, psnr_history) and updates `i` and
# `info_dict` as globals, so these names must not be moved into a function.
PHOTOS = ['Cameraman256', 'Lena512', 'image_House256rgb', 'image_Peppers512rgb']
# (kernel name, kernel .mat file, suffix of the matching blurred image)
KERNELS = (
    ('gauss',   'data/sr/kernels/kernel_gauss.mat',      '_gauss'),
    ('motion',  'data/sr/kernels/kernel_motionblur.mat', '_motionblur'),
    ('defocus', 'data/sr/kernels/kernel_defocus.mat',    '_defocus'),
)
NET_TYPE = 'skip'  # UNet, ResNet

result = dict()
for photo in PHOTOS:
    photo_dict = dict()
    path_HR_image = 'data/sr/images/' + photo + '.png'
    for kernel, kernel_path, lr_suffix in KERNELS:
        path_LR_image = 'data/sr/images/' + photo + lr_suffix + '.png'

        # Skip pairs whose blurred input has not been generated.
        if not os.path.isfile(path_LR_image):
            continue

        ## load images
        img_lr_pil, img_lr_np = get_image(path_LR_image, -1)
        img_hr_pil, img_hr_np = get_image(path_HR_image, -1)
        n_channels = img_hr_np.shape[0]

        # Random input with the HR image's spatial size (PIL .size is (width, height)).
        net_input = get_noise(input_depth, INPUT, (img_hr_pil.size[1], img_hr_pil.size[0])).type(dtype).detach()
        net = get_net(input_depth, NET_TYPE, pad, n_channels=n_channels, skip_n33d=128, skip_n33u=128, skip_n11=4,
                      num_scales=5, upsample_mode='bilinear').type(dtype)
        mse = torch.nn.MSELoss().type(dtype)
        img_LR_var = np_to_torch(img_lr_np).type(dtype)

        # Forward model used by `closure`: blur the network output with this kernel.
        convolution = Convolution(n_planes=n_channels, kernel_type=KERNEL_TYPE, kernel_path=kernel_path, preserve_size=True).type(dtype)

        # initialization (state shared with `closure`)
        psnr_history = []
        net_input_saved = net_input.detach().clone()
        noise = net_input.detach().clone()

        i = 0
        p = get_params(OPT_OVER, net, net_input)
        mse_arr = np.zeros(num_iter)
        info_dict = dict()

        # call network
        optimize(OPTIMIZER, p, closure, LR, num_iter)

        # save result
        photo_dict[kernel] = {'info': info_dict}

    # Only record photos for which at least one kernel was processed.
    if photo_dict:
        result[photo] = photo_dict

In [None]:
# `json` is not imported in the notebook's import cell, so bring it in here.
import json


def _to_json_compatible(obj):
    """Fallback for json.dumps: convert numpy arrays/scalars into plain Python values.

    The `info` entries in `result` hold numpy arrays ('mse' is the per-iteration
    loss array, 'X' the final restored image), which the json module cannot
    serialize on its own. Anything else is rejected with TypeError, which is
    what json expects from a `default` hook.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    raise TypeError('Object of type %s is not JSON serializable' % type(obj).__name__)


# NOTE(review): 'X' stores every pixel of every restored image as JSON floats,
# so result.json can get very large -- consider np.save for the images.
jsObj = json.dumps(result, default=_to_json_compatible)

with open('result.json', 'w') as result_file:
    result_file.write(jsObj)