In [1]:
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline

import argparse
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import numpy as np
from models import *

import torch
import torch.optim

from skimage.measure import compare_psnr
from models.blindconvolution import BiConvolution

from utils.sr_utils import *

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
dtype = torch.cuda.FloatTensor
#dtype = torch.FloatTensor


In [2]:
# Starts here
#path_LR_image = 'data/sr/images/Cameraman256_gauss.png'
#path_HR_image = 'data/sr/images/Cameraman256.png'
kernel_path ='data/sr/images/kernel_gauss.png'
img_k_pil, img_k_np = get_image(kernel_path, -1)
#img_lr_pil, img_lr_np = get_image(path_LR_image, -1)
#img_hr_pil, img_hr_np = get_image(path_HR_image, -1)
#plot_image_grid([img_lr_np,img_hr_np])
img_k_np.shape

(1, 9, 9)

In [3]:
input_depth = 32
 
INPUT =     'noise'
pad   =     'reflection'
OPT_OVER =  'net'
KERNEL_TYPE='udf'

LR = 0.01
tv_weight = 0.0
num_iter=10000
reg_noise_std = 0.03


OPTIMIZER = 'adam'

PLOT = True


In [4]:
def closure():
    global i, net_input, net_input_kernel,info_dict,X,psnr_HR_max,K,psnr_k_max
    
    reg_noise_std = 0.01
    net_input = net_input_saved + (noise.normal_() * reg_noise_std)
    net_input_kernel = net_input_kernel + (noise_kernel.normal_() * reg_noise_std)

    out_HR = net(net_input)
    out_kernel = net_kernel(net_input_kernel)
    #print(out_kernel.shape)
    out_kernel = torch.nn.Softmax(2)(out_kernel.view(*out_kernel.size()[:2], -1)).view_as(out_kernel)
    #print(out_kernel.shape)
    padding = nn.ReflectionPad2d(4)
    out_LR = torch.nn.functional.conv2d(padding(out_HR), out_kernel.expand(-1, out_HR.shape[1], -1, -1), padding=0)

    total_loss = mse(out_LR, img_LR_var)
    
    if tv_weight > 0:
        total_loss += tv_weight * tv_loss(out_HR)
        
    total_loss.backward()

    # Log
    print(img_lr_np.shape,torch_to_np(out_LR).shape)
    psnr_LR = compare_psnr(img_lr_np, torch_to_np(out_LR))
    psnr_HR = compare_psnr(img_hr_np, torch_to_np(out_HR))
    psnr_k = compare_psnr(img_k_np, torch_to_np(out_kernel.view(1, 1, 9, 9)))
    print ('Iteration %05d' % (i), '\r', end='')
                      
    # History
    psnr_history.append([psnr_LR, psnr_HR, psnr_k])
    
   # if PLOT and i % 100 == 0:
        #out_HR_np = torch_to_np(out_HR)
        #plot_image_grid([img_hr_np, np.clip(out_HR_np, 0, 1)], factor=13, nrow=3)
    if psnr_HR > psnr_HR_max:
        psnr_HR_max = psnr_HR
        X = torch_to_np(out_HR)
        K = torch_to_np(out_kernel)
        psnr_k_max = psnr_k
    
    i += 1
    if  i == num_iter:
        info_dict = {'psnr_HR':psnr_HR_max,'psnr_k':psnr_k_max,'X': X,'K':K}
    
    return total_loss

In [5]:
result = dict()
for photo in ['image_House256rgb','image_Peppers512rgb']:
    photo_dict = dict()
    path_HR_image = 'data/sr/images/' + photo +'.png'
    for kernel in ['gauss','motion','defocus']:#'gauss','motion','defocus'
        if kernel == 'gauss':
            kernel_path ='data/sr/kernels/kernel_gauss.mat'
            path_LR_image = 'data/sr/images/'+ photo +'_gauss.png'
        elif kernel == 'motion':
            kernel_path ='data/sr/kernels/kernel_motionblur.mat'
            path_LR_image = 'data/sr/images/'+ photo +'_motion.png'
        elif kernel == 'defocus':
            kernel_path ='data/sr/kernels/kernel_defocus.mat'
            path_LR_image = 'data/sr/images/'+ photo +'_defocus.png'
        else:
            print('no image in kernel.')
        
        #continue when there is no file:
        if os.path.isfile(path_LR_image) == False:
            continue
        ## load images
        img_lr_pil, img_lr_np = get_image(path_LR_image, -1)
        img_hr_pil, img_hr_np = get_image(path_HR_image, -1)
        n_channels = img_hr_np.shape[0]
        
        net_input = get_noise(input_depth, INPUT, (img_hr_pil.size[1], img_hr_pil.size[0])).type(dtype).detach()
        net_input_kernel= get_noise(input_depth, INPUT, (img_k_pil.size[1], img_k_pil.size[0])).type(dtype).detach()


        NET_TYPE = 'skip' # UNet, ResNet
        net = get_net(input_depth, 'skip', pad, n_channels=3, skip_n33d=128, skip_n33u=128, skip_n11=4, 
              num_scales=5, upsample_mode='bilinear').type(dtype)
        net_kernel = biget_net(input_depth, 'skip', pad, n_channels=1, skip_n33d=128, skip_n33u=128, skip_n11=4, 
              num_scales=5, upsample_mode='bilinear').type(dtype)
        mse = torch.nn.MSELoss().type(dtype)
        img_LR_var = np_to_torch(img_lr_np).type(dtype)
        img_k_var = np_to_torch(img_k_np).type(dtype)
             
        psnr_history = [] 
        net_input_saved = net_input.detach().clone()
        net_input_kernel_saved = net_input_kernel.detach().clone()
        noise = net_input.detach().clone()
        noise_kernel = net_input_kernel.clone()

        i = 0
        p = get_params(OPT_OVER, net, net_input) + get_params(OPT_OVER, net_kernel, net_input_kernel)
        mse_arr = np.zeros(num_iter)
        info_dict = dict()
        X = 0
        K = 0
        psnr_HR_max = 0
        # call network
        optimize(OPTIMIZER, p, closure, LR, num_iter)
        
        # save result
        kernel_dict = dict()
        kernel_dict['info'] = info_dict
        photo_dict[kernel] = kernel_dict
        result[photo]=photo_dict
        np.save('result_blind'+photo+kernel+'.npy',result)


Starting optimization with ADAM


  "See the documentation of nn.Upsample for details.".format(mode))


(3, 256, 256) (1, 256, 256)


ValueError: Input images must have the same dimensions.