In [None]:
from __future__ import print_function
import matplotlib.pyplot as plt

import os
#os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import numpy as np
import h5py

import torch
import torch.optim

from skimage.metrics import peak_signal_noise_ratio as compare_psnr
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark =True
dtype = torch.cuda.FloatTensor

import torch
import torch.nn as nn
import numpy as np
from torch.nn import Parameter
import torch.nn.functional as F
import torchvision
import sys


from PIL import Image
import PIL
import imageio
from scipy.signal import convolve2d as conv2
import re

In [None]:
class MeanOnlyBatchNorm(nn.Module):
    """Subtract each sample's per-channel spatial mean, then add a learned bias.

    Unlike nn.BatchNorm2d there is no variance scaling and no running
    statistics: the mean is taken over the spatial positions of every sample
    independently.  The bias reshape to (1, C, 1, 1) means the input is
    expected to be 4-D (N, C, H, W).

    Args:
        num_features: number of channels C.
        momentum: accepted for signature compatibility with BatchNorm; unused.
    """

    def __init__(self, num_features, momentum=0.1):
        super(MeanOnlyBatchNorm, self).__init__()
        self.num_features = num_features
        self.bias = Parameter(torch.Tensor(num_features))
        self.bias.data.zero_()

    def forward(self, inp):
        batch, channels = inp.size(0), inp.size(1)
        # Flatten the spatial dims so the mean is per (sample, channel).
        channel_means = inp.view(batch, self.num_features, -1).mean(dim=2)
        centered = inp - channel_means.view(batch, channels, 1, 1)
        return centered + self.bias.view(1, self.num_features, 1, 1)

def bn(num_features):
    """Return the normalization layer used inside the decoder blocks.

    This is a MeanOnlyBatchNorm (per-sample channel-mean removal plus a learned
    bias). An nn.BatchNorm2d(num_features) could be swapped in here instead.
    """
    norm_layer = MeanOnlyBatchNorm(num_features)
    return norm_layer

def l2normalize(v, eps=1e-12):
    """Scale tensor `v` to unit L2 norm; `eps` avoids division by zero."""
    denominator = v.norm() + eps
    return v / denominator

class SpectralNorm(nn.Module):
    """Wrap `module` so that its weight's spectral norm is at most `ln_lambda`.

    On construction the wrapped module's `<name>` parameter is re-registered
    as `<name>_bar`. Before each forward pass the largest singular value sigma
    of `<name>_bar` (viewed as (out_channels, -1)) is computed and the weight
    actually used becomes `<name>_bar / max(1, sigma / ln_lambda)`. The weight
    is therefore only shrunk when its spectral norm exceeds `ln_lambda`.

    Args:
        module: layer whose weight is constrained (e.g. nn.Conv2d).
        ln_lambda: upper bound on the spectral norm (Lipschitz constant).
        name: attribute name of the weight parameter on `module`.

    Note: `ln_lambda` is stored as a plain tensor attribute, not a buffer, so
    `.cuda()`/`.type()` on this module do not move it.
    """

    def __init__(self, module, ln_lambda=2.0, name='weight'):
        super(SpectralNorm, self).__init__()
        self.module = module
        self.name = name
        self.ln_lambda = torch.tensor(ln_lambda)
        if not self._made_params():
            self._make_params()

    def _update_u_v(self):
        """Recompute the rescaled weight and install it on the wrapped module."""
        w_bar = getattr(self.module, self.name + "_bar")
        rows = w_bar.data.shape[0]
        # Singular values only. They are computed on `.data`, so no gradient flows through sigma.
        singular_values = torch.svd(w_bar.view(rows, -1).data, some=False, compute_uv=False)[1]
        top = singular_values[0]
        divisor = torch.max(torch.ones_like(top), top / self.ln_lambda)
        setattr(self.module, self.name, w_bar / divisor.expand_as(w_bar))

    def _made_params(self):
        """Return True if `<name>_bar` is already present on the wrapped module."""
        return hasattr(self.module, self.name + "_bar")

    def _make_params(self):
        """Move `<name>` out of the module's parameters into a new `<name>_bar` parameter."""
        current = getattr(self.module, self.name)
        replacement = Parameter(current.data)
        del self.module._parameters[self.name]
        self.module.register_parameter(self.name + "_bar", replacement)

    def forward(self, *args):
        self._update_u_v()
        return self.module.forward(*args)

def conv(in_f, out_f, kernel_size=3, ln_lambda=2, stride=1, bias=True, pad='zero'):
    """Build an nn.Sequential holding an optionally padded, optionally spectrally normalized conv.

    Args:
        in_f, out_f: input / output channel counts.
        kernel_size: square kernel size; "same"-style padding of (k - 1) / 2 is used.
        ln_lambda: if > 0, wrap the conv in SpectralNorm with this bound.
        stride: convolution stride.
        bias: whether the conv has a bias term.
        pad: 'reflection' adds an explicit nn.ReflectionPad2d in front of an
            unpadded conv; any other value uses the conv's own zero padding.
    """
    same_pad = int((kernel_size - 1) / 2)
    stages = []
    if pad == 'reflection':
        # Pad explicitly so the convolution itself needs no padding.
        stages.append(nn.ReflectionPad2d(same_pad))
        same_pad = 0

    layer = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=same_pad, bias=bias)
    nn.init.kaiming_uniform_(layer.weight, a=0, mode='fan_in')
    if ln_lambda > 0:
        layer = SpectralNorm(layer, ln_lambda)
    stages.append(layer)

    return nn.Sequential(*stages)

def get_kernel(kernel_width=5, sigma=0.5):
    """Return a normalized 2-D Gaussian kernel of shape (kernel_width, kernel_width).

    Pixel offsets from the kernel centre are halved before the Gaussian is
    evaluated, and the result is rescaled to sum to 1.

    Args:
        kernel_width: side length of the square kernel.
        sigma: standard deviation of the Gaussian (in halved-offset units).
    """
    kernel = np.zeros([kernel_width, kernel_width])
    center = (kernel_width + 1.)/2.
    sigma_sq = sigma * sigma

    for i in range(1, kernel.shape[0] + 1):
        for j in range(1, kernel.shape[1] + 1):
            di = (i - center)/2.
            dj = (j - center)/2.
            kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj)/(2 * sigma_sq))
            kernel[i - 1][j - 1] = kernel[i - 1][j - 1]/(2. * np.pi * sigma_sq)

    # The explicit normalization makes the 2*pi*sigma^2 factor above irrelevant.
    kernel /= kernel.sum()

    return kernel

class gaussian(nn.Module):
    """Fixed (non-trainable) 2x upsampler built from a Gaussian kernel.

    Each of the `n_planes` channels is upsampled independently by a depthwise
    (groups=n_planes) stride-2 transposed convolution whose weights are the
    kernel from `get_kernel`. The output spatial size is exactly twice the
    input size for any `kernel_width >= 1`.

    Args:
        n_planes: number of channels.
        kernel_width: side length of the Gaussian kernel.
        sigma: standard deviation passed to `get_kernel`.
    """
    def __init__(self, n_planes,  kernel_width=5, sigma=0.5):
        super(gaussian, self).__init__()
        self.n_planes = n_planes
        self.kernel = get_kernel(kernel_width=kernel_width,sigma=sigma)

        # ConvTranspose2d with stride 2 gives out = 2*(H-1) - 2*padding + kernel + output_padding.
        # These choices make out == 2*H; kernel_width=5 reproduces padding=2, output_padding=1.
        padding = (kernel_width - 1) // 2
        output_padding = kernel_width % 2
        convolver = nn.ConvTranspose2d(n_planes, n_planes, kernel_size=kernel_width, stride=2,
                                       padding=padding, output_padding=output_padding, groups=n_planes)
        convolver.weight.data[:] = 0
        convolver.bias.data[:] = 0
        convolver.weight.requires_grad = False
        convolver.bias.requires_grad = False

        kernel_torch = torch.from_numpy(self.kernel)
        for i in range(n_planes):
            convolver.weight.data[i, 0] = kernel_torch

        self.upsampler_ = convolver

    def forward(self, x):

        x = self.upsampler_(x)

        return x

In [None]:
class decoder(nn.Module):
    '''Generator network: a conv stem, five 2x upsampling stages and a 1x1 output conv.

    Args:
        num_input_channels: channels of the input code tensor.
        num_output_channels: channels of the generated image.
        ln_lambda: spectral-norm bound passed to the 3x3 convolutions
            (the final 1x1 conv is built with ln_lambda=0, i.e. no spectral norm).
        upsample_mode: one of 'deconv', 'nearest', 'bilinear', 'gaussian'.
        pad: 'reflection' for reflection padding; any other value gives zero padding.
        need_sigmoid: append an nn.Sigmoid after the output conv.
        need_bias: use bias terms in the convolutions.
    '''
    def __init__(self, num_input_channels=3, num_output_channels=3, ln_lambda=2,
                       upsample_mode='gaussian', pad='zero', need_sigmoid=True, need_bias=True):
        super(decoder, self).__init__()

        filters = [128, 128, 128, 128, 128]
        # Per-stage Gaussian widths; only used when upsample_mode == 'gaussian'.
        sigmas = [0.1, 0.1, 0.1, 0.5, 0.5]

        stem = unetConv2(num_input_channels, filters[0], ln_lambda, need_bias, pad)
        stages = [unetUp(width, upsample_mode, ln_lambda, need_bias, pad, sig)
                  for width, sig in zip(filters, sigmas)]
        head = conv(filters[-1], num_output_channels, 1, 0, bias=need_bias, pad=pad)

        modules = [stem] + stages + [head]
        if need_sigmoid:
            modules.append(nn.Sigmoid())

        self.net = nn.Sequential(*modules)

    def forward(self, x):
        return self.net(x)


class unetConv2(nn.Module):
    """Two stacked (3x3 conv -> mean-only BN -> LeakyReLU) stages."""
    def __init__(self, in_size, out_size, ln_lambda, need_bias, pad):
        super(unetConv2, self).__init__()

        def make_stage(channels_in):
            return nn.Sequential(conv(channels_in, out_size, 3, ln_lambda, bias=need_bias, pad=pad),
                                 bn(out_size),
                                 nn.LeakyReLU(),)

        self.conv1 = make_stage(in_size)
        self.conv2 = make_stage(out_size)

    def forward(self, x):
        return self.conv2(self.conv1(x))


class unetUp(nn.Module):
    """One decoder stage: 2x spatial upsampling followed by convolution.

    upsample_mode:
        'deconv'              -> stride-2 ConvTranspose2d, then a single 3x3 conv.
        'bilinear'/'nearest'  -> nn.Upsample(scale_factor=2), then unetConv2.
        'gaussian'            -> fixed Gaussian upsampler (width `sigma`), then unetConv2.
    Any other mode triggers an AssertionError.
    """
    def __init__(self, out_size, upsample_mode, ln_lambda, need_bias, pad, sigma=None):
        super(unetUp, self).__init__()

        num_filt = out_size
        if upsample_mode == 'deconv':
            self.up = nn.ConvTranspose2d(num_filt, out_size, 4, stride=2, padding=1)
            self.conv = conv(out_size, out_size, 3, ln_lambda, bias=need_bias, pad=pad)
            return

        if upsample_mode in ('bilinear', 'nearest'):
            self.up = nn.Upsample(scale_factor=2, mode=upsample_mode)
        elif upsample_mode == 'gaussian':
            self.up = gaussian(out_size, kernel_width=5, sigma=sigma)
        else:
            assert False
        self.conv = unetConv2(out_size, out_size, ln_lambda, need_bias, pad)

    def forward(self, x):
        return self.conv(self.up(x))

In [None]:
def hwc_to_uint8(hwc_img):
    """Clip an H x W x C float image to [0, 1] and convert it to uint8 (truncating).

    A single-channel image (H x W x 1) is squeezed to H x W, because imageio
    cannot write arrays with a trailing channel axis of size 1.
    """
    img = np.clip(hwc_img, 0, 1)
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    return (img * 255.0).astype('uint8')

def save2img(d_img, fn):
    """Write a C x H x W float image in [0, 1] to file `fn` (1- or 3-channel)."""
    imageio.imwrite(fn, hwc_to_uint8(d_img.transpose(1, 2, 0)))

def save2enhanceimg(ori_img, out_img, fn):
    """Write an edge-enhanced version of `ori_img` to `fn`.

    Both inputs are C x H x W floats, clipped to [0, 1]. The positive part of
    (ori_img - out_img) is added back onto ori_img and the sum is clipped to [0, 1].
    """
    ori = np.clip(ori_img.transpose(1, 2, 0), 0, 1)
    out = np.clip(out_img.transpose(1, 2, 0), 0, 1)

    edge = np.clip(ori - out, 0, 1)
    enhanced = np.clip(edge + ori, 0, 1)

    imageio.imwrite(fn, hwc_to_uint8(enhanced))

def rgb2gray(rgb):
    """Collapse a channels-last (H x W x >=3) image to H x W grayscale, clipped to [0, 1].

    Uses the luma weights 0.2989 / 0.5870 / 0.1140 on the first three channels.
    """
    red = rgb[:, :, 0]
    green = rgb[:, :, 1]
    blue = rgb[:, :, 2]
    luminance = 0.2989 * red + 0.5870 * green + 0.1140 * blue
    return np.clip(luminance, 0, 1)

def crop_image(img, d=32):
    '''Center-crop a PIL image so that its width and height are multiples of `d`.'''
    width, height = img.size
    keep_w = width - width % d
    keep_h = height - height % d

    left = int((width - keep_w) / 2)
    upper = int((height - keep_h) / 2)
    right = int((width + keep_w) / 2)
    lower = int((height + keep_h) / 2)

    return img.crop([left, upper, right, lower])

def get_params(opt_over, net, net_input, downsampler=None):
    '''Return the list of tensors to optimize.

    Args:
        opt_over: comma separated list of groups, e.g. "net,input" or "net".
            Supported groups: 'net', 'down' and 'input'.
        net: network whose parameters form the 'net' group.
        net_input: torch.Tensor that stores the input `z`; for the 'input' group
            it is marked `requires_grad = True` (in place).
        downsampler: module whose parameters form the 'down' group.

    Returns:
        List of tensors, in the order the groups appear in `opt_over`.

    Raises:
        AssertionError: unknown group, or 'down' requested without a downsampler.
    '''
    params = []

    for opt in opt_over.split(','):
        if opt == 'net':
            params += list(net.parameters())
        elif opt == 'down':
            assert downsampler is not None
            # Accumulate rather than overwrite, so groups listed earlier are kept.
            params += list(downsampler.parameters())
        elif opt == 'input':
            net_input.requires_grad = True
            params += [net_input]
        else:
            assert False, 'what is it?'

    return params

def get_image_grid(images_np, nrow=8):
    '''Tile a list of C x H x W numpy images with `torchvision.utils.make_grid`.

    `nrow` is the number of images per row. Returns the grid as a numpy array.
    '''
    tensors = list(map(torch.from_numpy, images_np))
    return torchvision.utils.make_grid(tensors, nrow).numpy()

def plot_image_grid(images_np, nrow =8, factor=1, interpolation='lanczos'):
    """Draw images in a grid with matplotlib and return the grid array.

    Args:
        images_np: list of images, each a np.array of shape 3xHxW or 1xHxW.
            If any image has 3 channels, single-channel ones are repeated to 3.
        nrow: how many images go in one row.
        factor: added to both figure dimensions.
        interpolation: interpolation used by plt.imshow.
    """
    n_channels = max(im.shape[0] for im in images_np)
    assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels"

    promoted = []
    for im in images_np:
        promoted.append(im if im.shape[0] == n_channels else np.concatenate([im, im, im], axis=0))
    images_np = promoted

    grid = get_image_grid(images_np, nrow)

    plt.figure(figsize=(len(images_np) + factor, 12 + factor))
    single_channel = images_np[0].shape[0] == 1
    if single_channel:
        plt.imshow(grid[0], cmap='gray', interpolation=interpolation)
    else:
        plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation)
    plt.show()

    return grid

def load(path):
    """Open the image file at `path` with PIL and return the image object."""
    image = Image.open(path)
    return image

def get_image(path, imsize=-1):
    """Load an image and optionally resize it to a specific size.

    Args:
        path: path to the image file.
        imsize: (width, height) tuple, or a scalar for a square size; -1 for no resize.

    Returns:
        (PIL image, numpy array produced by `pil_to_np`).
    """
    img = load(path)

    if isinstance(imsize, int):
        imsize = (imsize, imsize)

    if imsize[0]!= -1 and img.size != imsize:
        if imsize[0] > img.size[0]:
            img = img.resize(imsize, Image.BICUBIC)
        else:
            # Image.ANTIALIAS was removed in Pillow 10; it was an alias of LANCZOS.
            img = img.resize(imsize, Image.LANCZOS)

    img_np = pil_to_np(img)

    return img, img_np



def fill_noise(x, noise_type):
    """Fill tensor `x` in place: 'u' -> uniform_() on [0, 1), 'n' -> standard normal_()."""
    if noise_type == 'u':
        x.uniform_()
        return
    if noise_type == 'n':
        x.normal_()
        return
    assert False

def get_noise(input_depth, method, spatial_size, noise_type='u', var=1./10):
    """Return a 1 x `input_depth` x H x W tensor used as the network input.

    Args:
        input_depth: number of channels in the tensor.
        method: 'noise' fills it with random noise; 'fourier' draws input_depth // 2
            random channels and returns [sin(2*pi*z), cos(2*pi*z)] concatenated along
            the channel axis; 'meshgrid' returns normalized x/y coordinate channels
            (requires input_depth == 2).
        spatial_size: (H, W) tuple, or an int for a square size.
        noise_type: 'u' for uniform, 'n' for normal.
        var: multiplier applied to the noise (for 'n' it is the standard deviation).
    """
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)
    height, width = spatial_size[0], spatial_size[1]

    if method in ('noise', 'fourier'):
        depth = input_depth if method == 'noise' else input_depth // 2
        code = torch.zeros([1, depth, height, width])
        fill_noise(code, noise_type)
        code *= var
        if method == 'fourier':
            code = torch.cat([torch.sin(2.*np.pi*code), torch.cos(2.*np.pi*code)], 1)
        return code

    if method == 'meshgrid':
        assert input_depth == 2
        X, Y = np.meshgrid(np.arange(0, width)/float(width-1), np.arange(0, height)/float(height-1))
        return np_to_torch(np.concatenate([X[None,:], Y[None,:]]))

    assert False

def pil_to_np(img_PIL):
    '''Convert a PIL image (or array-like) to a channels-first float32 array in [0, 1].

    An H x W x C array becomes C x H x W; a 2-D (grayscale) array becomes 1 x H x W.
    Values are divided by 255.
    '''
    pixels = np.array(img_PIL)
    pixels = pixels.transpose(2, 0, 1) if pixels.ndim == 3 else pixels[None, ...]
    return pixels.astype(np.float32) / 255.

def np_to_pil(img_np):
    '''Convert a channels-first float array in [0, 1] to a PIL image.

    Values are scaled by 255, clipped to [0, 255] and cast to uint8. A single
    channel becomes a 2-D (grayscale) image; otherwise channels are moved last.
    '''
    pixels = np.clip(img_np*255,0,255).astype(np.uint8)
    pixels = pixels[0] if img_np.shape[0] == 1 else pixels.transpose(1, 2, 0)
    return Image.fromarray(pixels)

def np_to_torch(img_np):
    '''Wrap a C x H x W numpy array as a 1 x C x H x W torch tensor.

    The tensor shares memory with `img_np` (no copy is made).
    '''
    return torch.from_numpy(img_np).unsqueeze(0)

def torch_to_np(img_var):
    '''Return the first item of a batched tensor as a numpy array.

    The tensor is detached and moved to the CPU first, e.g. 1 x C x H x W -> C x H x W.
    '''
    host_array = img_var.detach().cpu().numpy()
    return host_array[0]


def optimize(optimizer_type, parameters, closure, LR, num_iter):
    """Run the optimization loop.

    Args:
        optimizer_type: 'LBFGS' or 'adam'.
        parameters: iterable of tensors to optimize. A one-shot generator such
            as `net.parameters()` is accepted; it is materialized once.
        closure: callable that computes the loss, calls `.backward()` and returns the loss.
        LR: learning rate.
        num_iter: number of Adam steps, or `max_iter` for LBFGS.

    For 'LBFGS', 100 Adam steps with lr=0.001 are run first as a warm start.
    Any other optimizer_type raises AssertionError.
    """
    # Materialize once. Otherwise a generator would be exhausted by the Adam
    # warm-up, and LBFGS would then receive an empty parameter list.
    parameters = list(parameters)

    if optimizer_type == 'LBFGS':
        warmup = torch.optim.Adam(parameters, lr=0.001)
        for _ in range(100):
            warmup.zero_grad()
            closure()
            warmup.step()

        print('Starting optimization with LBFGS')
        optimizer = torch.optim.LBFGS(parameters, max_iter=num_iter, lr=LR, tolerance_grad=-1, tolerance_change=-1)

        def closure2():
            optimizer.zero_grad()
            return closure()

        optimizer.step(closure2)

    elif optimizer_type == 'adam':
        print('Starting optimization with ADAM')
        optimizer = torch.optim.Adam(parameters, lr=LR)

        for _ in range(num_iter):
            optimizer.zero_grad()
            closure()
            optimizer.step()
    else:
        assert False

def get_noisy_image(img_np, sigma):
    """Add zero-mean Gaussian noise to an image and clip the result to [0, 1].

    Args:
        img_np: image, np.array with values from 0 to 1.
        sigma: standard deviation of the noise.

    Returns:
        (PIL image, float32 np.array) of the noisy image.
    """
    perturbation = np.random.normal(scale=sigma, size=img_np.shape)
    noisy_np = np.clip(img_np + perturbation, 0, 1).astype(np.float32)
    return np_to_pil(noisy_np), noisy_np

In [None]:
def get_circular_statastic(img_it, img_gt, size=0.2):
    """Per-frequency-band ratio of the spectrum magnitudes of `img_it` and `img_gt`.

    3-D inputs are first converted with `rgb2gray`. The ratio
    |FFT(img_it)| / (|FFT(img_gt)| + 1e-8), clipped to [0, 1], is averaged over
    int(1/size) concentric rings around the centre of the shifted spectrum.
    Ring radii are the fractions linspace(size, 1, int(1/size)) of the
    horizontal half-width, so the spectrum corners outside the largest
    inscribed circle are ignored.

    Returns:
        List of ring averages, innermost (lowest frequency) first. An empty ring
        gives nan.
    """
    if len(img_it.shape) == 3:
        img_it = rgb2gray(img_it)
    if len(img_gt.shape) == 3:
        img_gt = rgb2gray(img_gt)

    assert(size > 0 and size < 1)

    spec_it = abs(np.fft.fftshift(np.fft.fft2(img_it)))
    spec_gt = abs(np.fft.fftshift(np.fft.fft2(img_gt)))
    ratio = np.clip(spec_it / (spec_gt + 1e-8), 0, 1)

    h, w = ratio.shape
    cx, cy = int(w/2), int(h/2)
    rows, cols = np.ogrid[:h, :w]
    dist = np.sqrt((cols - cx)**2 + (rows - cy)**2)

    band_means = []
    inner_disk = np.zeros((h, w))
    for frac in np.linspace(size, 1, int(1/size)):
        disk = (dist <= cx * frac).astype(np.int32)
        ring = (disk - inner_disk).astype(np.int32)
        inner_disk = disk
        band_means.append(np.sum(ring * ratio) / np.sum(ring))

    return band_means

def PerceptualBlurMetric (Image, FiltSize=9):
    """No-reference blur score: how much variation a further box blur removes.

    The image is re-blurred with 1-D box filters. Absolute neighbouring-pixel
    differences of the original (D_F) and re-blurred (D_B) images are compared.
    The variation removed by the blur is D_V = max(D_F - D_B, 0).
    NOTE(review): this structure appears to follow the Crete et al. "blur
    effect" metric; confirm against the reference before relying on details.

    Args:
        Image: 2-D array, or a 3-D channels-last array converted with `rgb2gray`.
            (This parameter shadows the PIL `Image` module inside the function.)
        FiltSize: length of the 1-D box filter.

    Returns:
        The max of the two directional scores (s_F - s_V) / s_F, where s_* are the
        sums of D_* over the cropped image. Each score lies in [0, 1] because
        0 <= D_V <= D_F. Larger values mean the re-blur removed relatively less
        variation. A direction with no variation (s_F == 0) gives nan, with a
        NumPy RuntimeWarning.
    """

    if len(Image.shape)==3:
        Image = rgb2gray(Image)

    m, n = Image.shape[0],Image.shape[1]

    # Box filters: a 1 x FiltSize row kernel (Hv) and its FiltSize x 1 transpose (Hh).
    Hv = 1.0/FiltSize*np.ones((1,FiltSize))
    Hh = Hv.T
    Bver = conv2(Image, Hv, 'same')
    Bhor = conv2(Image, Hh, 'same')

    # Drop ceil(FiltSize/2) rows/cols at the start and floor(FiltSize/2) at the end,
    # removing the region affected by the filter's zero-padded borders.
    s_ind = int(np.ceil(FiltSize/2))
    e_ind = int(np.floor(FiltSize/2))
    Bver = Bver[s_ind:m-e_ind, s_ind:n-e_ind]
    Bhor = Bhor[s_ind:m-e_ind, s_ind:n-e_ind]
    Image = Image[s_ind:m-e_ind, s_ind:n-e_ind]
    m, n = Image.shape[0],Image.shape[1]

    # Absolute first differences ([1, -1] kernels) of the original image,
    # with a further 1-pixel border removed.
    Hv = np.asarray([[1, -1]])
    Hh = Hv.T
    D_Fver = abs(conv2(Image, Hv, 'same'))
    D_Fhor = abs(conv2(Image, Hh, 'same'))
    D_Fver = D_Fver[1:m-1, 1:n-1]
    D_Fhor = D_Fhor[1:m-1, 1:n-1]

    # Same differences for the re-blurred images.
    D_Bver = abs(conv2(Bver, Hv, 'same'))
    D_Bhor = abs(conv2(Bhor, Hh, 'same'))
    D_Bver = D_Bver[1:m-1, 1:n-1]
    D_Bhor = D_Bhor[1:m-1, 1:n-1]


    # Variation removed by the re-blur, clamped at 0.
    D_Vver = D_Fver-D_Bver
    D_Vver[D_Vver<0] = 0
    D_Vhor = D_Fhor-D_Bhor
    D_Vhor[D_Vhor<0] = 0

    s_Fver = np.sum(D_Fver)
    s_Fhor = np.sum(D_Fhor)
    s_Vver = np.sum(D_Vver)
    s_Vhor = np.sum(D_Vhor)

    # Per-direction score in [0, 1]; nan when s_F is 0 (no variation).
    b_Fver = (s_Fver - s_Vver)/s_Fver
    b_Fhor = (s_Fhor - s_Vhor)/s_Fhor

    IDM = max(b_Fver, b_Fhor)

    return IDM

def MLVMap(im):
    """Map of the maximum signed difference between each interior pixel and its 8 neighbours.

    NOTE(review): named after the Maximum Local Variation (MLV) sharpness map.
    This version takes max(neighbour - centre), a signed difference, not an
    absolute one; confirm against the reference definition.

    Args:
        im: 2-D array, or a 3-D channels-last array converted with `rgb2gray`.

    Returns:
        Array of shape (H - 3, W - 3) covering rows/cols 2 .. H-2 / 2 .. W-2 of `im`.
        Each shifted copy is zero where it received no data. Along the last
        output row/column (and the first column for some shifts), some terms
        therefore compare against 0 instead of a real neighbour.
    """

    if len(im.shape)==3:
        im = rgb2gray(im)

    xs, ys = im.shape
    x=im

    # Nine zero-initialised canvases; x5 will hold the centre pixel, the others its neighbours.
    x1=np.zeros((xs,ys))
    x2=np.zeros((xs,ys))
    x3=np.zeros((xs,ys))
    x4=np.zeros((xs,ys))
    x5=np.zeros((xs,ys))
    x6=np.zeros((xs,ys))
    x7=np.zeros((xs,ys))
    x8=np.zeros((xs,ys))
    x9=np.zeros((xs,ys))

    # Paste the same interior block at 9 offsets. After cropping below, x_k[r, c]
    # equals x at (r+1,c+1), (r+1,c), (r+1,c-1), (r,c+1), (r,c), (r,c-1),
    # (r-1,c+1), (r-1,c), (r-1,c-1) for k = 1..9 respectively.
    x1[1:xs-2,1:ys-2] = x[2:xs-1,2:ys-1]
    x2[1:xs-2,2:ys-1] = x[2:xs-1,2:ys-1]
    x3[1:xs-2,3:ys]   = x[2:xs-1,2:ys-1]
    x4[2:xs-1,1:ys-2] = x[2:xs-1,2:ys-1]
    x5[2:xs-1,2:ys-1] = x[2:xs-1,2:ys-1]
    x6[2:xs-1,3:ys]   = x[2:xs-1,2:ys-1]
    x7[3:xs,1:ys-2]   = x[2:xs-1,2:ys-1]
    x8[3:xs,2:ys-1]   = x[2:xs-1,2:ys-1]
    x9[3:xs,3:ys]     = x[2:xs-1,2:ys-1]

    # Crop every canvas to the same (xs-3) x (ys-3) window.
    x1=x1[2:xs-1,2:ys-1]
    x2=x2[2:xs-1,2:ys-1]
    x3=x3[2:xs-1,2:ys-1]
    x4=x4[2:xs-1,2:ys-1]
    x5=x5[2:xs-1,2:ys-1]
    x6=x6[2:xs-1,2:ys-1]
    x7=x7[2:xs-1,2:ys-1]
    x8=x8[2:xs-1,2:ys-1]
    x9=x9[2:xs-1,2:ys-1]

    # Signed neighbour-minus-centre differences (x5 is the centre).
    d1=x1-x5
    d2=x2-x5
    d3=x3-x5
    d4=x4-x5
    d5=x6-x5
    d6=x7-x5
    d7=x8-x5
    d8=x9-x5

    # Element-wise maximum over the 8 differences.
    dd=np.maximum(d1,d2)
    dd=np.maximum(dd,d3)
    dd=np.maximum(dd,d4)
    dd=np.maximum(dd,d5)
    dd=np.maximum(dd,d6)
    dd=np.maximum(dd,d7)
    dd=np.maximum(dd,d8)

    return dd

def MLVSharpnessMeasure(im):
    """Scalar sharpness score derived from `MLVMap(im)`.

    The map values are sorted in descending order and the largest is dropped.
    The remaining values are multiplied by exp(-0.01 * k) for k = 1, 2, ...
    The first weighted value is then dropped too, and the root-mean-square of
    the next (at most T - 1 = 999) weighted values is returned.
    """
    T = 1000
    alpha = -0.01

    local_map = MLVMap(im)
    n_values = local_map.shape[0] * local_map.shape[1]

    ranked = sorted(np.reshape(local_map, (n_values)).tolist(), reverse=True)
    ranked = np.array(ranked[1:int(n_values)])

    weights = np.exp(np.dot(alpha, range(1, n_values)))
    kept = (ranked * weights)[1:T]

    return np.sqrt(np.mean(np.power(kept, 2)))

In [None]:
# Use a larger default font size for all figures produced below.
plt.rcParams.update({'font.size': 18})

def get_log_data(file_name):
    """Parse the per-iteration training log into frequency, PSNR and ratio data.

    Each non-empty line is expected to contain
        ... PSRN_gt: <float>, mask: [<b1>, <b2>, ...], ratio: <float>
    Band values may be plain numbers or wrapped like `np.float64(0.1)`. NumPy >= 2
    uses that form when a list of NumPy scalars is printed, and it broke the
    previous positional parsing.

    Returns:
        frequency_lists: list with one 1-D float array of band values per line.
        psnr: float array of the PSRN_gt values.
        ratio: float array of the ratio values.

    Raises:
        ValueError: if a line does not match the expected format.
    """
    record_re = re.compile(r'PSRN_gt:\s*([^,]+),\s*mask:\s*\[(.*)\],\s*ratio:\s*(\S+)')
    # Unwrap e.g. "np.float64(0.1)" -> "0.1".
    scalar_wrapper_re = re.compile(r'np\.\w+\(([^()]*)\)')

    with open(file_name, 'r') as log:
        lines = log.readlines()

    frequency_lists = []
    psnr_list = []
    ratio_list = []

    for line in lines:
        line = line.strip()
        if not line:
            continue
        match = record_re.search(line)
        if match is None:
            raise ValueError('unrecognized log line: %r' % line)

        bands_text = scalar_wrapper_re.sub(r'\1', match.group(2))
        bands = [float(tok) for tok in bands_text.split(',') if tok.strip()]
        frequency_lists.append(np.array(bands))

        psnr_list.append(float(match.group(1)))
        ratio_list.append(float(match.group(3)))

    return frequency_lists, np.array(psnr_list), np.array(ratio_list)

def get_fbc_fig(all_norms,num_iter,ylim=1,save_path='',img_name=''):
    """Plot the frequency-band correspondence (FBC) curves and save them to `save_path`.

    Args:
        all_norms: sequence of per-iteration band vectors (one row per iteration,
            one column per band). The label/colour lists support up to 5 bands.
        num_iter: x-axis extent; at most this many iterations are plotted.
        ylim: upper y-axis limit.
        save_path: output path passed to `savefig`.
        img_name: unused; kept for call compatibility.
    """
    norms = np.array(all_norms)
    # Plot only iterations that were actually logged (e.g. after an early stop);
    # otherwise x and y would have different lengths.
    n_plot = min(num_iter, len(norms))
    n_bands = norms.shape[1] if norms.ndim == 2 else 0

    fig, ax = plt.subplots(figsize=(7,6))
    ax.set_xlim(0, num_iter)
    ax.set_ylim(0, ylim)
    ax.set_xlabel("Optimization Iteration")
    ax.set_ylabel("FBC ($\\bar{H}$)")

    label_list = ['Frequency band (1,lowest)','Frequency band (2)','Frequency band (3)','Frequency band (4)','Frequency band (5, highest)']
    color_list = ['#331900', '#994C00', '#CC6600',  '#FF8000', '#FF9933']
    rate = 1
    for band in range(n_bands):
        ax.plot(range(0, n_plot, rate), norms[:n_plot:rate, band], linewidth=4,
                color=color_list[band], label=label_list[band])

    ax.legend(loc=4)
    ax.grid()
    fig.savefig(save_path)
    plt.close(fig)


def get_psnr_ratio_fig(all_datas,num_iter,ylim=35, ylabel='',save_path='',img_name=''):
    """Plot the PSNR and ratio curves on shared axes and save them to `save_path`.

    Args:
        all_datas: sequence of 1-D series; item 0 is labelled 'PSNR', item 1 'Ratio'.
        num_iter: x-axis extent; at most this many points of each series are plotted.
        ylim: upper y-axis limit.
        ylabel: currently unused (no y label is drawn); kept for call compatibility.
        save_path: output path passed to `savefig`.
        img_name: unused; kept for call compatibility.
    """
    fig, ax = plt.subplots(figsize=(7,6))
    ax.set_xlim(0, num_iter)
    ax.set_ylim(0, ylim)
    ax.set_xlabel("Optimization Iteration")

    label_list = ['PSNR','Ratio']
    color_list = ['#d94a31','#4b43db']
    rate = 1
    for k, series in enumerate(all_datas):
        # A series may be shorter than num_iter (e.g. after an automatic stop).
        n_plot = min(num_iter, len(series))
        ax.plot(range(0, n_plot, rate), series[0:n_plot:rate], linewidth=4,
                color=color_list[k], label=label_list[k])

    ax.legend(loc=0)
    ax.grid()
    fig.savefig(save_path)
    plt.close(fig)

In [None]:
#boat,barbara,Cameraman256,couple,fingerprint,hill
#house,Lena512,man,montage,peppers256
img_name = 'boat'
# Input directory (Colab layout); change this one constant when running elsewhere.
DATA_DIR = '/content/data'
fname = os.path.join(DATA_DIR, '{}.png'.format(img_name))
h5py_fname = os.path.join(DATA_DIR, '{}.h5'.format(img_name))

# Output directories for figures and per-iteration logs.
os.makedirs('./figs', exist_ok=True)
os.makedirs('./logs', exist_ok=True)

log_path = "./logs/%s.txt"%img_name
# Kept open for the training cell, which writes one line per iteration and closes it.
log_file = open(log_path, "w")

#read image
img_pil, img_np = get_image(fname, -1)
# Alternative: a random 50% mask instead of the stored one.
#np.random.seed(10000)
#img_mask_np = (np.random.random_sample(size=img_np.shape) > 0.5).astype(int)
with h5py.File(h5py_fname, 'r') as hf:
    img_mask_np = hf['mask'][()]
img_mask_pil = np_to_pil(img_mask_np)

# Crop image and mask identically so H and W are divisible by 32 (= 2**5 upsampling stages).
img_mask_pil = crop_image(img_mask_pil, 32)
img_pil      = crop_image(img_pil,      32)
img_np      = pil_to_np(img_pil)
img_mask_np = pil_to_np(img_mask_pil)

#input type
INPUT = 'fourier' # 'meshgrid', 'noise', 'fourier'
var=1
input_depth = 32
# Input code at 1/32 of the image resolution; the decoder upsamples it back.
net_input = get_noise(input_depth, INPUT, (img_pil.size[1]//32, img_pil.size[0]//32),var=var).type(dtype).detach()

#network parameters
ln_lambda=1.6#the lambda in Lipschitz normalization, which is used to control spectral bias
upsample_mode='bilinear'#['deconv', 'nearest', 'bilinear', 'gaussian'], where 'gaussian' denotes our Gaussian upsampling.
pad = 'reflection'
#decoder is the used network architecture in the paper
net = decoder(num_input_channels=input_depth, num_output_channels=1, ln_lambda=ln_lambda,
                   upsample_mode=upsample_mode, pad=pad, need_sigmoid=True, need_bias=True).type(dtype)

#optimization parameters
OPTIMIZER='adam'
num_iter = 4000
LR = 0.001
reg_noise_std = 0#1./30, injecting noise in the input.
show_every = 100

#automatic stopping
ratio_list = np.zeros((num_iter))
ratio_iter=100#the n in Eq. (8)
ratio_epsilon=0.01#the ratio difference threshold
auto_stop = False

# Loss
mse = torch.nn.MSELoss().type(dtype)
img_var = np_to_torch(img_np).type(dtype)
mask_var = np_to_torch(img_mask_np).type(dtype)

net_input_saved = net_input.detach().clone()
noise = net_input.detach().clone()
class AutoStop(Exception):
    """Raised by `closure` when the automatic-stopping criterion is met."""


i = 0
def closure():
    """One optimization step: forward, masked MSE, backward, metrics and logging.

    Updates the globals `i` (iteration counter), `out` (latest network output)
    and `net_input`. When `auto_stop` is enabled, it raises AutoStop once the
    mean blur/sharpness ratio over the window ratio_list[i-2n:i-n] differs from
    the mean over ratio_list[i-n+1:i] by less than `ratio_epsilon` (n = ratio_iter).
    """

    global i, out, net_input

    if reg_noise_std > 0:
        net_input = net_input_saved + (noise.normal_() * reg_noise_std)

    out = net(net_input)

    # Inpainting loss: only pixels selected by the mask contribute.
    total_loss = mse(out * mask_var, img_var * mask_var)
    total_loss.backward()

    pre_img_chw = out.detach().cpu().numpy()[0]
    psrn_gt = compare_psnr(img_np, pre_img_chw)
    pre_img = pre_img_chw.transpose(1, 2, 0)

    img_noisy_np = img_np*img_mask_np
    noisy_img = img_noisy_np.transpose(1, 2, 0)

    #frequency-band correspondence metric (first channel only)
    avg_mask_it = get_circular_statastic(pre_img[:,:,0], noisy_img[:,:,0],  size=0.2)
    # Plain floats so the logged list reads "[0.1, ...]". NumPy >= 2 would print
    # "np.float64(0.1)", which get_log_data cannot parse.
    avg_mask_it = [float(v) for v in avg_mask_it]

    #automatic stopping
    blur_it = PerceptualBlurMetric (pre_img[:,:,0])#the blurriness of the output image
    sharp_it = MLVSharpnessMeasure(pre_img[:,:,0])#the sharpness of the output image
    ratio_it = blur_it/sharp_it#the ratio

    # ratio_list has num_iter slots; LBFGS may call the closure more often than that.
    if auto_stop and i < len(ratio_list):
        ratio_list[i] = ratio_it
        if i>ratio_iter*2:
            ratio1 = np.mean(ratio_list[i-ratio_iter*2:i-ratio_iter])
            ratio2 = np.mean(ratio_list[i-ratio_iter+1:i])
            if np.abs(ratio1-ratio2)<ratio_epsilon:
                print("The optimization is automatically stopped!")
                # exit() would shut down the Jupyter kernel; end the loop instead.
                # The result is saved by the code after optimize().
                raise AutoStop()

    print ('Iteration: %05d, Loss: %f, PSRN_gt: %f' % (i, total_loss.item(), psrn_gt))
    log_file.write('Iteration: %05d, Loss: %f, PSRN_gt: %f, mask: %s, ratio: %f\n' % (i, total_loss.item(), psrn_gt, avg_mask_it, ratio_it))
    log_file.flush()

    i += 1

    return total_loss

try:
    # Pass a list: a parameters() generator would be exhausted by the LBFGS warm-up.
    optimize(OPTIMIZER, list(net.parameters()), closure, LR, num_iter)
except AutoStop:
    pass
finally:
    log_file.close()

#visualization

out_np = torch_to_np(out)
save2img(out_np, "./figs/%s_inpainted.png" % img_name)#save the inpainted image

frequency_lists, psnr_list, ratio_list = get_log_data(log_path)
# Plot only the iterations that were actually logged (fewer after an automatic stop).
n_plot = min(num_iter, len(psnr_list))
get_fbc_fig(frequency_lists,n_plot,ylim=1,save_path="./figs/%s_fbc.png"%img_name)#save the fbc figure

data_lists =[]
data_lists.append(psnr_list)
data_lists.append(ratio_list)
get_psnr_ratio_fig(data_lists,n_plot,ylim=35, ylabel='PSNR', save_path="./figs/%s_psnr_ratio.png"%img_name)#save the psnr_ratio figure


Starting optimization with ADAM
Iteration: 00000, Loss: 0.017116, PSRN_gt: 14.660783
Iteration: 00001, Loss: 0.015921, PSRN_gt: 14.973128
Iteration: 00002, Loss: 0.009531, PSRN_gt: 17.187246
Iteration: 00003, Loss: 0.014354, PSRN_gt: 15.440063
Iteration: 00004, Loss: 0.009378, PSRN_gt: 17.274614
Iteration: 00005, Loss: 0.010682, PSRN_gt: 16.700172
Iteration: 00006, Loss: 0.009454, PSRN_gt: 17.223781
Iteration: 00007, Loss: 0.007399, PSRN_gt: 18.285815
Iteration: 00008, Loss: 0.006522, PSRN_gt: 18.841015
Iteration: 00009, Loss: 0.006546, PSRN_gt: 18.832986
Iteration: 00010, Loss: 0.005859, PSRN_gt: 19.312805
Iteration: 00011, Loss: 0.005492, PSRN_gt: 19.592278
Iteration: 00012, Loss: 0.005429, PSRN_gt: 19.644470
Iteration: 00013, Loss: 0.005109, PSRN_gt: 19.911804
Iteration: 00014, Loss: 0.004788, PSRN_gt: 20.195771
Iteration: 00015, Loss: 0.004540, PSRN_gt: 20.427324
Iteration: 00016, Loss: 0.004323, PSRN_gt: 20.640524
Iteration: 00017, Loss: 0.004197, PSRN_gt: 20.766737
Iteration: 000

ValueError: Can't write images with one color channel.