# Importing Libraries

In [1]:
from __future__ import print_function
import torch
import torch.optim
from functions.models.optim import *

from skimage.measure import compare_psnr
from functions.models import *
from copy import deepcopy
from functions.utils.global_parameters import *
from functions.utils.common_utils import torch_to_np
from IPython.core.debugger import set_trace

import matplotlib.pyplot as plt
import numpy as np
import cv2

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark =True
#dtype = torch.cuda.FloatTensor
dtype = torch.FloatTensor

                the kernel may be left running.  Please let us know
                about your system (bitness, Python, etc.) at
                ipython-dev@scipy.org
  ipython-dev@scipy.org""")


In [2]:
def dip(img_np, arch = 'default', LR = 0.01, num_iter = 1000, reg_noise_std = 1.0/30,exp_weight = 0.99, INPUT = 'noise', save = False, save_path = '', plot = True, input_depth = None, name = None, loss_fn = "MSE", OPTIMIZER = "adam", pad = 'zero',  OPT_OVER = 'net' ):
    
    glparam = global_parameters()
    glparam.set_params(save, plot, reg_noise_std, exp_weight)
    glparam.load_images(img_np)
    glparam.img_torch = glparam.img_torch.type(dtype)
    
    if arch == 'simple':
        if input_depth == None:
            input_depth = 3 
        glparam.net = get_net(input_depth,'skip', pad,
                skip_n33d=16, 
                skip_n33u=16, 
                skip_n11=0, 
                num_scales=3,
                upsample_mode='bilinear').type(dtype)
    else:
        assert False

    net_input = get_noise(input_depth, INPUT, (glparam.img_np.shape[1], glparam.img_np.shape[2])).type(dtype).detach()   
    glparam.net_input_saved = net_input.detach().clone()
    glparam.noise = net_input.detach().clone()
    
    # Compute number of parameters
    param_numbers  = sum([np.prod(list(p.size())) for p in glparam.net.parameters()]) 
    print ('\n Number of params: %d' % param_numbers)

    # Loss function
    if loss_fn == 'MSE':
        criterion = torch.nn.MSELoss().type(dtype)
    if loss_fn == 'KLDiv':
        criterion = torch.nn.KLDivLoss().type(dtype)
        
    if save == True:
        f= open("{}/Stats.txt".format(save_path),"w+")
        f.write("{:>11}{:>12}{:>12}\n".format('Iterations','Total_Loss','PSNR'))
        save_net_details(save_path, arch, param_numbers, pad, OPT_OVER, OPTIMIZER, input_depth,
                 loss_fn = loss_fn, LR = LR, num_iter = num_iter, exp_weight = glparam.exp,
                 reg_noise_std = reg_noise_std, INPUT = 'INPUT', net = glparam.net)
                
    def closure(iter_value):
        show_every = 100
        figsize = 4
        
        ## Initialiaze/ Update variables
        if glparam.noise_std > 0.0:
            net_input = glparam.net_input_saved + (glparam.noise.normal_() * glparam.noise_std)
        net_input = torch.tensor(net_input, dtype=torch.float32, requires_grad=True)
        out = glparam.net(net_input)

        ## Exponential Smoothing
        if glparam.out_avg is None:
            glparam.out_avg = out.detach()
        else:
            glparam.out_avg = glparam.out_avg * glparam.exp + out.detach() * (1 - glparam.exp)
        
        ## Calculate loss
        total_loss = criterion(out, glparam.img_torch)
        total_loss.backward()
        set_trace()
        
        glparam.psnr_noisy = compare_psnr(glparam.img_np, out.detach().cpu().numpy()[0]).astype(np.float32)
            
        print ('DIP Iteration {:>11}   Loss {:>11.7f}   PSNR_noisy: {:>5.4f}'.format(
            iter_value, total_loss.item(), glparam.psnr_noisy), end='\r')
        
        ## Backtracking   
        if (glparam.psnr_noisy_last - glparam.psnr_noisy) > 5.0:
            glparam.interrupts = glparam.interrupts + 1
            print('\n Falling back to previous checkpoint.')
            glparam.net.load_state_dict(glparam.last_net.state_dict())
            glparam.optimizer.load_state_dict(glparam.optimizer_last.state_dict())
            
            if glparam.interrupts > 3:
                glparam.psnr_noisy_last = glparam.psnr_noisy
                
            if OPTIMIZER == "adam":     
                for j in range(iter_value % show_every - 1):                
                    glparam.optimizer.zero_grad()
                    closure(iter_value - (iter_value % show_every) + j + 1)
                    glparam.optimizer.step()
                glparam.optimizer.zero_grad()
                closure(iter_value)          
                print('\n Return back to the original')                        
                return total_loss 
            
            if OPTIMIZER == "EntropySGD":
                for j in range(iter_value % show_every - 1):
                    glparam.optimizer.zero_grad()
                    glparam.optimizer.step(iter_value - (iter_value % show_every) + j + 1, closure, glparam.net, criterion)
                glparam.optimizer.zero_grad()
                closure(iter_value)   
                print('\n Return back to the original')                        
                return total_loss                      
            
        if (iter_value % show_every) == 0: 
            glparam.last_net = deepcopy(glparam.net)
            glparam.psnr_noisy_last = glparam.psnr_noisy
            glparam.optimizer_last = deepcopy(glparam.optimizer)
            
            if glparam.interrupts > 3 :
                print("\n Error, was not able to converge after reset")
            glparam.interrupts = 0
            
            if glparam.PLOT:
                fig=plt.figure(figsize=(16, 16))
                fig.add_subplot(1, 3, 1)
                plt.imshow(np.clip(torch_to_np(out), 0, 1).transpose(1, 2, 0))
                plt.title('Output')
                fig.add_subplot(1, 3, 2)
                plt.imshow(np.clip(torch_to_np(glparam.out_avg), 0, 1).transpose(1, 2, 0))
                plt.title('Averaged Output')
                fig.add_subplot(1, 3, 3)
                plt.title('Original/Target')
                plt.imshow(glparam.img_np.transpose(1, 2, 0))
                plt.show()
                
            if glparam.save:
                f = open("{}/Stats.txt".format(save_path),"a")
                f.write("{:>11}{:>12.8f}{:>12.8f}\n".format(iter_value, total_loss.item(), glparam.psnr_noisy))
                plt.imsave("{}/it_{}.png".format(save_path,iter_value),
                       np.clip(torch_to_np(glparam.out_avg), 0, 1).transpose(1,2,0), format="png")
                
        return total_loss
        
    ### Optimize
    glparam.net.train()
    p = get_params(OPT_OVER, glparam.net, net_input)
    
    if OPTIMIZER == "adam":
        glparam.optimizer = torch.optim.Adam(p, lr = LR)
        for j in range(num_iter):
            glparam.optimizer.zero_grad()
            closure(j)
            glparam.optimizer.step()            
    if OPTIMIZER == "EntropySGD":
        glparam.optimizer = EntropySGD(p,config=dict(lr = LR))
        for j in range(num_iter):
            glparam.optimizer.zero_grad()
            glparam.optimizer.step(j, closure, glparam.net, criterion)    
    print('\n')       
    
    out = glparam.net(net_input)
    glparam.out_avg = glparam.out_avg * glparam.exp + out.detach() * (1 - glparam.exp)
    return glparam.out_avg

In [None]:
import cv2
img = cv2.imread('data/goldfish.jpg')[..., ::-1]
img_np = np.array(img)
_, dip(img_np, arch = 'simple')


 Number of params: 20355


  "See the documentation of nn.Upsample for details.".format(mode))


> <ipython-input-2-f5a0c3771562>(62)closure()
     60         set_trace()
     61 
---> 62         glparam.psnr_noisy = compare_psnr(glparam.img_np, out.detach().cpu().numpy()[0]).astype(np.float32)
     63 
     64         print ('DIP Iteration {:>11}   Loss {:>11.7f}   PSNR_noisy: {:>5.4f}'.format(

ipdb> out.grad
ipdb> input.grad
*** AttributeError: 'function' object has no attribute 'grad'
ipdb> net_input.grad
tensor([[[[ 1.7442e-07,  2.6749e-07,  1.6338e-07,  ...,  7.1478e-08,
           -1.0701e-07, -2.7952e-07],
          [ 1.0810e-07, -6.1325e-08, -4.6639e-08,  ..., -5.8087e-08,
            4.3745e-07,  1.5662e-07],
          [ 4.6022e-09, -2.3011e-07,  7.3344e-08,  ...,  1.9324e-07,
           -2.1687e-07, -1.2123e-07],
          ...,
          [-2.4005e-07, -4.4894e-09, -4.0040e-08,  ..., -2.6967e-07,
            1.6164e-07,  4.4678e-09],
          [-3.6670e-09,  1.1470e-07,  1.3444e-08,  ..., -1.9295e-08,
           -1.5533e-08, -9.0350e-08],
          [ 7.7571e-08, -6.9927e

          [-0.0035, -0.0059, -0.0554]]]], requires_grad=True)
ipdb> glparam.net[2][0].weight.grad
tensor([[[[-5.6845e-06,  3.3405e-06,  1.0355e-05],
          [ 1.4907e-05,  2.6446e-05,  3.5287e-05],
          [ 3.2837e-05,  4.5812e-05,  5.5732e-05]],

         [[ 5.5808e-05,  6.3223e-05,  6.8482e-05],
          [ 7.0558e-05,  7.9171e-05,  8.4847e-05],
          [ 8.1478e-05,  9.0887e-05,  9.6603e-05]],

         [[ 1.9398e-04,  2.1412e-04,  2.1874e-04],
          [ 2.2372e-04,  2.4952e-04,  2.5885e-04],
          [ 2.3946e-04,  2.6962e-04,  2.8303e-04]],

         ...,

         [[ 1.6346e-05,  8.4575e-06, -2.7543e-06],
          [ 3.6035e-05,  3.0402e-05,  2.0166e-05],
          [ 4.9293e-05,  4.6063e-05,  3.7428e-05]],

         [[ 5.5071e-05,  3.0745e-05,  5.0483e-06],
          [ 3.3234e-05,  5.7691e-06, -2.2066e-05],
          [ 5.5956e-06, -2.2662e-05, -4.9337e-05]],

         [[ 8.9389e-06,  9.0128e-06,  1.1829e-05],
          [ 1.8891e-07,  2.8809e-06,  8.0722e-06],
          

         [[-3.9979e-06]]]])
ipdb> glparam.net[4][0].weight.grad
*** TypeError: 'LeakyReLU' object does not support indexing
ipdb> glparam.net[8][0].weight.grad
tensor([[[[0.0039]],

         [[0.0066]],

         [[0.0005]],

         [[0.0132]],

         [[0.0051]],

         [[0.0029]],

         [[0.0075]],

         [[0.0074]],

         [[0.0162]],

         [[0.0200]],

         [[0.0132]],

         [[0.0007]],

         [[0.0022]],

         [[0.0174]],

         [[0.0083]],

         [[0.0100]]],


        [[[0.0038]],

         [[0.0083]],

         [[0.0005]],

         [[0.0129]],

         [[0.0068]],

         [[0.0040]],

         [[0.0070]],

         [[0.0093]],

         [[0.0189]],

         [[0.0173]],

         [[0.0211]],

         [[0.0008]],

         [[0.0029]],

         [[0.0199]],

         [[0.0098]],

         [[0.0105]]],


        [[[0.0038]],

         [[0.0080]],

         [[0.0004]],

         [[0.0131]],

         [[0.0061]],

         [[0.0031]],



In [None]:
image_dataset = ['it_0.png', 'it_100.png']
image_dataset.append(['it_{}.png'.format(100*i) for i in range(60, 101)])

In [None]:
print(image_dataset)

In [None]:
'it_' + 100*i + '.png'

In [None]:
import cv2
import matplotlib.pyplot as plt
img = cv2.imread('data/knife.jpg')

plt.imshow(img)
plt.show()

In [None]:
img