In [None]:
from PIL import Image
import os
import glob
from PIL import Image, ImageFont, ImageDraw
import numpy as np
import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils import data as dat
import time
import torch.nn.functional as F
from torch.autograd import Variable
from utils import *

# Enumerate GPUs in PCI bus order and make only device 0 visible to CUDA.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

# Deep prior - DnCNN

In [None]:
class DnCNN(nn.Module):
    """Feed-forward stack: Conv+ReLU, repeated Conv+BatchNorm+ReLU, final Conv.

    Every convolution is 3x3 with padding 1 and no bias, so the spatial size
    of the input is preserved. Hidden layers are 64 channels wide.

    Args:
        channels (int): number of input and output channels
        num_of_layers (int): total number of convolution layers, Default 17
    """

    def __init__(self, channels, num_of_layers=17):
        super().__init__()
        width = 64

        def conv(n_in, n_out):
            # 3x3, padding 1, bias-free convolution shared by all stages
            return nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=3, padding=1, bias=False)

        stack = [conv(channels, width), nn.ReLU(inplace=True)]
        for _ in range(num_of_layers - 2):
            stack += [conv(width, width), nn.BatchNorm2d(width), nn.ReLU(inplace=True)]
        stack.append(conv(width, channels))
        # The attribute name and the layer order determine the state_dict keys.
        self.dncnn = nn.Sequential(*stack)

    def forward(self, x):
        """Run `x` through the sequential stack and return the raw output."""
        return self.dncnn(x)
    
# Build the DnCNN prior and wrap it in DataParallel on GPU 0. DataParallel
# stores the network under `module`, so the checkpoint keys are presumably
# prefixed with "module." — TODO confirm against the checkpoint file.
device_ids = [0]
dncnn_s = nn.DataParallel(DnCNN(channels=3, num_of_layers=17), device_ids=device_ids).cuda()
#----------------------------------------------------------------
# Pretrained weights: adjust this path to where the file is stored
#----------------------------------------------------------------
dncnn_s.load_state_dict(torch.load('checkpoints/dncnn_s25.pth'))
# Switch to inference mode (BatchNorm layers use their running statistics).
dncnn_s.eval()

# Deep prior - VDN

In [None]:
def conv3x3(in_chn, out_chn, bias=True):
    """Build a 3x3 convolution with stride 1 and padding 1.

    Args:
        in_chn (int): number of input channels
        out_chn (int): number of output channels
        bias (bool): whether the layer gets a learnable bias, Default True

    Returns:
        nn.Conv2d: the freshly constructed layer
    """
    return nn.Conv2d(
        in_channels=in_chn,
        out_channels=out_chn,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=bias,
    )
class dncnn(nn.Module):
    def __init__(self, in_channels, out_channels, dep=20, num_filters=64, slope=0.2):
        '''
        Chain of 3x3 convolutions (with bias) separated by LeakyReLU activations.

        Args:
            in_channels (int): number of input channels
            out_channels (int): number of output channels
            dep (int): depth of the network, Default 20
            num_filters (int): number of filters in each layer, Default 64
            slope (float): negative slope of the LeakyReLU, Default 0.2
        '''
        super().__init__()
        self.conv1 = conv3x3(in_channels, num_filters, bias=True)
        self.relu = nn.LeakyReLU(slope, inplace=True)
        # dep - 2 hidden conv/activation pairs sit between the first and last conv
        hidden = []
        for _ in range(dep - 2):
            hidden += [conv3x3(num_filters, num_filters, bias=True),
                       nn.LeakyReLU(slope, inplace=True)]
        self.mid_layer = nn.Sequential(*hidden)
        self.conv_last = conv3x3(num_filters, out_channels, bias=True)

    def forward(self, x):
        """Apply first conv + activation, the hidden stack, then the last conv."""
        head = self.relu(self.conv1(x))
        return self.conv_last(self.mid_layer(head))

class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=6, depth=4, wf=64, slope=0.2):
        """
        Encoder/decoder network with skip connections between matching levels.

        Args:
            in_channels (int): number of input channels, Default 3
            out_channels (int): number of output channels, Default 6
            depth (int): depth of the network, Default 4
            wf (int): number of filters in the first layer, Default 64
            slope (float): slope passed on to every UNetConvBlock / UNetUpBlock, Default 0.2
        """
        super().__init__()
        self.depth = depth
        # Channel width doubles at every encoder level: wf, 2*wf, 4*wf, ...
        widths = [(2 ** level) * wf for level in range(depth)]

        channels = in_channels
        self.down_path = nn.ModuleList()
        for width in widths:
            self.down_path.append(UNetConvBlock(channels, width, slope))
            channels = width

        # The decoder mirrors every encoder level except the deepest one.
        self.up_path = nn.ModuleList()
        for width in reversed(widths[:-1]):
            self.up_path.append(UNetUpBlock(channels, width, slope))
            channels = width

        self.last = conv3x3(channels, out_channels, bias=True)

    def forward(self, x):
        """Encode with 2x average pooling between stages, decode with skips, project."""
        skips = []
        bottom = len(self.down_path) - 1
        for level, stage in enumerate(self.down_path):
            x = stage(x)
            if level == bottom:
                continue
            # Remember the pre-pooling features for the decoder, then halve the resolution.
            skips.append(x)
            x = F.avg_pool2d(x, 2)

        # Decoder stage k consumes the k-th most recent skip tensor.
        for k, stage in enumerate(self.up_path, start=1):
            x = stage(x, skips[-k])

        return self.last(x)

class UNetConvBlock(nn.Module):
    """Two 3x3 convolutions (padding 1, with bias), each followed by LeakyReLU.

    Args:
        in_size (int): number of input channels
        out_size (int): number of output channels of both convolutions
        slope (float): negative slope of the LeakyReLU, Default 0.2
    """

    def __init__(self, in_size, out_size, slope=0.2):
        super().__init__()
        # Positions 0 and 2 hold the convolutions; this fixes the state_dict keys.
        self.block = nn.Sequential(
            nn.Conv2d(in_size, out_size, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(slope, inplace=True),
            nn.Conv2d(out_size, out_size, kernel_size=3, padding=1, bias=True),
            nn.LeakyReLU(slope, inplace=True),
        )

    def forward(self, x):
        """Return the output of the conv/activation stack for `x`."""
        return self.block(x)

class UNetUpBlock(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, conv block.

    Args:
        in_size (int): channels of the incoming feature map
        out_size (int): channels produced by the upsampling and by the conv block
        slope (float): slope passed on to UNetConvBlock, Default 0.2
    """

    def __init__(self, in_size, out_size, slope=0.2):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2, bias=True)
        # The conv block is built for `in_size` input channels, i.e. the channel-wise
        # concatenation of the upsampled map (out_size channels) and the bridge.
        self.conv_block = UNetConvBlock(in_size, out_size, slope)

    def center_crop(self, layer, target_size):
        """Return the centred window of `target_size` (height, width) cut from a 4-D `layer`."""
        _batch, _chan, height, width = layer.shape
        target_h, target_w = target_size[0], target_size[1]
        top = (height - target_h) // 2
        left = (width - target_w) // 2
        return layer[:, :, top:top + target_h, left:left + target_w]

    def forward(self, x, bridge):
        """Upsample `x`, crop `bridge` to the same spatial size, concatenate, convolve."""
        upsampled = self.up(x)
        skip = self.center_crop(bridge, upsampled.shape[2:])
        return self.conv_block(torch.cat([upsampled, skip], 1))


def weight_init_kaiming(net):
    """Initialise the Conv2d and BatchNorm2d parameters of `net` in place.

    Conv2d weights are drawn with Kaiming-normal initialisation
    (mode='fan_in', nonlinearity='relu') and their biases, when present, are
    set to 0. BatchNorm2d weights are set to 1 and biases to 0; layers created
    with ``affine=False`` have neither and are left untouched instead of
    raising. Modules of any other type are not modified.

    Args:
        net (nn.Module): network whose sub-modules are visited via ``net.modules()``

    Returns:
        nn.Module: the same `net` object, to allow chaining
    """
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # With affine=False both attributes are None; nn.init would fail on them.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
    return net

class VDN(nn.Module):
    """Pair of networks evaluated on the same input: a UNet (DNet) and a dncnn (SNet).

    Both sub-networks output ``in_channels * 2`` channels and are passed
    through `weight_init_kaiming` at construction time.

    Args:
        in_channels (int): number of input channels
        wf (int): number of filters in the first UNet layer, Default 64
        dep_S (int): depth of the dncnn sub-network (SNet), Default 5
        dep_U (int): depth of the UNet sub-network (DNet), Default 4
        slope (float): LeakyReLU slope used by both sub-networks, Default 0.2
    """

    def __init__(self, in_channels, wf=64, dep_S=5, dep_U=4, slope=0.2):
        super(VDN, self).__init__()
        net1 = UNet(in_channels, in_channels*2, wf=wf, depth=dep_U, slope=slope)
        self.DNet = weight_init_kaiming(net1)
        # SNet keeps a fixed width of 64 filters, independent of `wf`.
        net2 = dncnn(in_channels, in_channels*2, dep=dep_S, num_filters=64, slope=slope)
        self.SNet = weight_init_kaiming(net2)

    def forward(self, x, mode='train'):
        """Evaluate the sub-network(s) selected by `mode` (case-insensitive).

        Args:
            x: input passed unchanged to the selected sub-network(s)
            mode (str): 'train' returns ``(DNet(x), SNet(x))``, 'test' returns
                ``DNet(x)`` and 'sigma' returns ``SNet(x)``

        Raises:
            ValueError: if `mode` is none of the three supported values
                (previously this case silently returned None)
        """
        selected = mode.lower()
        if selected == 'train':
            phi_Z = self.DNet(x)
            phi_sigma = self.SNet(x)
            return phi_Z, phi_sigma
        if selected == 'test':
            return self.DNet(x)
        if selected == 'sigma':
            return self.SNet(x)
        raise ValueError("mode must be 'train', 'test' or 'sigma', got %r" % (mode,))
#------------------------------------------------------------------
# Pretrained VDN weights: adjust this path to where the file is stored
#------------------------------------------------------------------
# NOTE: torch.load relies on pickle, so only load checkpoint files you trust.
checkpoint = torch.load('checkpoints/model_state_niidgauss')
# Wrap the network in DataParallel on the GPU before the state dict is loaded.
vdn = torch.nn.DataParallel(VDN(3, dep_U=4, wf=64)).cuda()
vdn.load_state_dict(checkpoint)
# Switch to inference mode.
vdn.eval()

# Adaptive Deep Priors Framework

In [None]:
print('Loading model ...\n')
#----------------
#Basic settings
#----------------
Plot = True        # show denoised / noisy / input images for every picture
psnr_test = 0      # running PSNR sum, averaged after the loop
ssim_test = 0      # running SSIM sum, averaged after the loop
sigma = 15         # noise level, handed to get_noisy_image as sigma/255
test_data = 'data/BSD68/'
iters = 2          # number of framework iterations per image
ii = 0
start = time.time()
# List the directory once so the loop and the averages use the same image count.
image_names = os.listdir(test_data)
if not image_names:
    # Fail with a clear message instead of a ZeroDivisionError when averaging.
    raise RuntimeError('No test images found in %s' % test_data)
#-----------------------------------------
#Adaptive Deep Priors Framework on BSD68
#-----------------------------------------
for img in image_names:
    ii+=1
    print('Testing on the %dth pictures'%ii)
    #--------------------------------------------------------------
    #(1.1) Load the input; rotate it by 90 degrees when axis 1 is
    #      longer than axis 2 (the slicing in step (3.4) presumably
    #      relies on axis 2 being at least as long — TODO confirm)
    #--------------------------------------------------------------
    fname = os.path.join(test_data, img)
    img_pil = crop_image(Image.open(fname))
    img_np = pil_to_np(img_pil)
    if img_np.shape[1]>img_np.shape[2]:
        out = img_pil.transpose(Image.ROTATE_90)
        img_np = pil_to_np(out)
    img_torch = np_to_torch(img_np).cuda()
    gt_pil=crop_image(get_image(fname, -1)[0])
    gt_np = pil_to_np(gt_pil)
    if gt_np.shape[1]>gt_np.shape[2]:
        out = gt_pil.transpose(Image.ROTATE_90)
        gt_np = pil_to_np(out)
    ISource = img_torch
    #---------------------------------
    #(1.2) Add the noise of the input
    #---------------------------------
    img_noisy_pil,img_noisy_np = get_noisy_image(gt_np,sigma/255.)
    INoisy = np_to_torch(img_noisy_np)
    ISource, INoisy = Variable(ISource.cuda()), Variable(INoisy.cuda())
    img_noisy_np = torch_to_np(INoisy)
    #--------------------------------------------------------
    #(2) Initialization of the parameters
    # You need to change the value for different conditions
    #--------------------------------------------------------
    yita = 0.6
    delta =0.01
    lam = 1 - delta*(1 + yita)
    rou = 0.00001
    A_ = np.eye(gt_np.shape[1])*lam #shape: [a,a]; stays fixed during the iterations
    A_T = np.eye(gt_np.shape[1]) # shape: [a,a]
    A = A_T
    y = torch_to_np(INoisy) #shape: [3,a,b]
    v_0 = np.zeros(gt_np.shape) #shape [3,a,b]
    x_0 = np.zeros(gt_np.shape) #shape [3,a,b]
    for c in range(y.shape[0]):
        y_ = np.squeeze(y[c,:,:])
        x_0[c,:,:] = np.dot(A_T,y_)
    x_next = x_0  #shape [3,a,b]; alias of x_0, updated in place in step (3.1)
    v_next = v_0  #shape [3,a,b]
    space=np.zeros(gt_np.shape) #shape [3,a,b]

    with torch.no_grad(): # this can save much memory
        # INoisy minus the DnCNN output depends neither on the iteration nor
        # on the channel, so the network is evaluated once per image here
        # instead of once per channel per iteration.
        noisy_minus_dncnn = torch_to_np(INoisy - dncnn_s(INoisy)) #shape [3,a,b]
        #--------------------------------------
        #(3) The main loop of our framework
        #--------------------------------------
        for it in range(iters):
            #-----------------------------------------------
            #(3.1) The forward model with gradient descent
            #-----------------------------------------------
            for c in range(x_next.shape[0]):
                x_next_ = x_next[c,:,:]
                y_ = y[c,:,:]
                v_next_ = v_next[c,:,:]
                x_next[c,:,:] = A_.dot(x_next_) + delta * A_T.dot(y_) + delta *v_next_
            x_next = np_to_torch(x_next).cuda().float()
            #-----------------------------------------
            #(3.2) Plug the deep prior
            #-----------------------------------------
            phi_Z = vdn(x_next, 'test')
            err = phi_Z.cpu().numpy()
            im_noisy = x_next.cpu().numpy()
            #-----------------------------------------
            #(3.3) Denoising: subtract the first three
            #      output channels of VDN from its input
            #-----------------------------------------
            im_denoise = im_noisy - err[:, :3,]
            im_denoise = im_denoise.squeeze()
            v_next = np_to_torch(im_denoise).cuda()

            x_next = torch_to_np(x_next)
            v_next = torch_to_np(v_next)
            #-----------------------------------------
            #(3.4) Update A for overcoming the blurriness
            #-----------------------------------------
            for j in range(x_next.shape[0]):
                A_temp = noisy_minus_dncnn[j,:,:]
                x_temp =  ((v_next[j,:,:].T).dot(A)).dot(v_next[j,:,:])[:gt_np.shape[1],:]
                space[j,:,:] = rou*(A_temp+x_temp)
            # Average the three channel updates and keep the leading [a,a] square.
            t =  ((space[0,:,:]+space[1,:,:]+space[2,:,:])/3)[:gt_np.shape[1],:gt_np.shape[1]]
            A = (A -t)[:gt_np.shape[1],:gt_np.shape[1]]
            A_T = A.T
    v_next =pil_to_np(np_to_pil(v_next).convert('L'))
    v_hat = np_to_torch(v_next)
    Out = torch.clamp(v_hat,0.,1.)
    # NOTE: `start` is set once before the loop, so this is the time elapsed
    # since the whole run began, not the time spent on this image alone.
    stop = time.time() - start
    out_img = torch_to_np(Out)

    if Plot:
        plot_image_grid([np.clip(out_img, 0, 1),
                         np.clip(img_noisy_np, 0, 1),np.clip(img_np, 0, 1)], factor=20, nrow=3)

    psnr = batch_PSNR(Out, ISource,1.)
    ssim = batch_SSIM(Out,ISource)
    psnr_test += psnr
    ssim_test += ssim
    print('Time consumes %.3f s'%stop)
    print("PSNR is: %f SSIM is: %f" % (psnr,ssim))
num_images = len(image_names)
psnr_test /= num_images
print("\nAverage PSNR on test data is: %f" % psnr_test)
ssim_test /= num_images
print("\nAverage SSIM on test data is: %f" % ssim_test)
end = time.time() - start
# The average time is only reported when plotting is off, since the elapsed
# time otherwise includes the plotting calls.
if not Plot:
    ave_time = end/num_images
    print("\nAverage Time on test data is: %f" % ave_time)