In [1]:
import torch
import torch as t
from skimage.measure import compare_psnr
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import numpy as np
import os
import math
import matplotlib.pyplot as plt
import random
import cv2
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import time
from PIL import Image
import models as models
from utils import *
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
dtype = t.cuda.HalfTensor

# Hyper Parameters

In [2]:
METHOD = "SDNet3"
basic_dir='datas/GoPro/test/GOPR0881_11_01/'
SAMPLE_DIR = basic_dir + "blur"
EXPDIR = basic_dir + "deblur"
sharp = basic_dir +'sharp'
    
GPU = 0

# Preprocessing

In [3]:
def save_images(images, name):
    filename =  name + '.png'
    torchvision.utils.save_image(images, filename)

def weight_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        m.weight.data.normal_(0, 0.5*math.sqrt(2. / n))
        if m.bias is not None:
            m.bias.data.zero_()
    elif classname.find('BatchNorm') != -1:
        m.weight.data.fill_(1)
        m.bias.data.zero_()
    elif classname.find('Linear') != -1:
        n = m.weight.size(1)
        m.weight.data.normal_(0, 0.01)
        m.bias.data = torch.ones(m.bias.data.size())


# DMPHN

In [4]:
encoder = {}
decoder = {}
encoder_optim = {}
decoder_optim = {}
encoder_scheduler = {}
decoder_scheduler = {}
for s in ['s1', 's2', 's3']:
    encoder[s] = {}
    decoder[s] = {}
    encoder_optim[s] = {}
    decoder_optim[s] = {}
    encoder_scheduler[s] = {}
    decoder_scheduler[s] = {}
    for lv in ['lv1', 'lv2', 'lv3']:
        encoder[s][lv] = models.Encoder()
        decoder[s][lv] = models.Decoder()
        encoder[s][lv].apply(weight_init).cuda(GPU).half()
        decoder[s][lv].apply(weight_init).cuda(GPU).half()
        
        if os.path.exists(str('./checkpoints/' + METHOD + "/encoder_" + s + "_" + lv + ".pkl")):
            encoder[s][lv].load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/encoder_" + s + "_" + lv + ".pkl")))
            #print("load encoder_" + s + "_" + lv + " successfully!")
        if os.path.exists(str('./checkpoints/' + METHOD + "/decoder_" + s + "_" + lv + ".pkl")):
            decoder[s][lv].load_state_dict(torch.load(str('./checkpoints/' + METHOD + "/decoder_" + s + "_" + lv + ".pkl")))
            #print("load decoder_" + s + "_" + lv + " successfully!")

In [5]:
def DMPHN(images_lv1):
    images = {}
    feature = {}
    residual = {}
    for s in ['s1', 's2', 's3']:
        feature[s] = {}
        residual[s] = {}


    images['lv1'] = images_lv1
    images['lv1'] = Variable(images['lv1']).cuda(GPU).half()
    
    H = images['lv1'].size(2)
    W = images['lv1'].size(3)

    images['lv2_1'] = images['lv1'][:,:,0:int(H/2),:]
    images['lv2_2'] = images['lv1'][:,:,int(H/2):H,:]
    images['lv3_1'] = images['lv2_1'][:,:,:,0:int(W/2)]
    images['lv3_2'] = images['lv2_1'][:,:,:,int(W/2):W]
    images['lv3_3'] = images['lv2_2'][:,:,:,0:int(W/2)]
    images['lv3_4'] = images['lv2_2'][:,:,:,int(W/2):W]

    s = 's1'		
    feature[s]['lv3_1'] = encoder[s]['lv3'](images['lv3_1'])
    feature[s]['lv3_2'] = encoder[s]['lv3'](images['lv3_2'])
    feature[s]['lv3_3'] = encoder[s]['lv3'](images['lv3_3'])
    feature[s]['lv3_4'] = encoder[s]['lv3'](images['lv3_4'])
    feature[s]['lv3_top'] = torch.cat((feature[s]['lv3_1'], feature[s]['lv3_2']), 3)
    feature[s]['lv3_bot'] = torch.cat((feature[s]['lv3_3'], feature[s]['lv3_4']), 3)
    residual[s]['lv3_top'] = decoder[s]['lv3'](feature[s]['lv3_top'])
    residual[s]['lv3_bot'] = decoder[s]['lv3'](feature[s]['lv3_bot'])

    feature[s]['lv2_1'] = encoder[s]['lv2'](images['lv2_1'] + residual[s]['lv3_top']) + feature[s]['lv3_top']
    feature[s]['lv2_2'] = encoder[s]['lv2'](images['lv2_2'] + residual[s]['lv3_bot']) + feature[s]['lv3_bot']
    feature[s]['lv2'] = torch.cat((feature[s]['lv2_1'], feature[s]['lv2_2']), 2)
    residual[s]['lv2'] = decoder[s]['lv2'](feature[s]['lv2'])

    feature[s]['lv1'] = encoder[s]['lv1'](images['lv1'] + residual[s]['lv2']) + feature[s]['lv2']
    residual[s]['lv1'] = decoder[s]['lv1'](feature[s]['lv1'])

    s = 's2'
    ps = 's1'
    feature[s]['lv3_1'] = encoder[s]['lv3'](residual[ps]['lv1'][:,:,0:int(H/2),0:int(W/2)])
    feature[s]['lv3_2'] = encoder[s]['lv3'](residual[ps]['lv1'][:,:,0:int(H/2),int(W/2):W])
    feature[s]['lv3_3'] = encoder[s]['lv3'](residual[ps]['lv1'][:,:,int(H/2):H,0:int(W/2)])
    feature[s]['lv3_4'] = encoder[s]['lv3'](residual[ps]['lv1'][:,:,int(H/2):H,int(W/2):W])
    feature[s]['lv3_top'] = torch.cat((feature[s]['lv3_1'], feature[s]['lv3_2']), 3) + feature[ps]['lv3_top']
    feature[s]['lv3_bot'] = torch.cat((feature[s]['lv3_3'], feature[s]['lv3_4']), 3) + feature[ps]['lv3_bot']
    residual[s]['lv3_top'] = decoder[s]['lv3'](feature[s]['lv3_top'])
    residual[s]['lv3_bot'] = decoder[s]['lv3'](feature[s]['lv3_bot'])

    feature[s]['lv2_1'] = encoder[s]['lv2'](residual[ps]['lv1'][:,:,0:int(H/2),:] + residual[s]['lv3_top']) + feature[s]['lv3_top'] + feature[ps]['lv2_1']
    feature[s]['lv2_2'] = encoder[s]['lv2'](residual[ps]['lv1'][:,:,int(H/2):H,:] + residual[s]['lv3_bot']) + feature[s]['lv3_bot'] + feature[ps]['lv2_2']
    feature[s]['lv2'] = torch.cat((feature[s]['lv2_1'], feature[s]['lv2_2']), 2)
    residual[s]['lv2'] = decoder[s]['lv2'](feature[s]['lv2']) + residual['s1']['lv1']

    feature[s]['lv1'] = encoder[s]['lv1'](residual[ps]['lv1'] + residual[s]['lv2']) + feature[s]['lv2'] + feature[ps]['lv1']
    residual[s]['lv1'] = decoder[s]['lv1'](feature[s]['lv1'])

    s = 's3'
    ps = 's2'
    feature[s]['lv3_1'] = encoder[s]['lv3'](residual[ps]['lv1'][:,:,0:int(H/2),0:int(W/2)])
    feature[s]['lv3_2'] = encoder[s]['lv3'](residual[ps]['lv1'][:,:,0:int(H/2),int(W/2):W])
    feature[s]['lv3_3'] = encoder[s]['lv3'](residual[ps]['lv1'][:,:,int(H/2):H,0:int(W/2)])
    feature[s]['lv3_4'] = encoder[s]['lv3'](residual[ps]['lv1'][:,:,int(H/2):H,int(W/2):W])
    feature[s]['lv3_top'] = torch.cat((feature[s]['lv3_1'], feature[s]['lv3_2']), 3) + feature[ps]['lv3_top']
    feature[s]['lv3_bot'] = torch.cat((feature[s]['lv3_3'], feature[s]['lv3_4']), 3) + feature[ps]['lv3_bot']
    residual[s]['lv3_top'] = decoder[s]['lv3'](feature[s]['lv3_top'])
    residual[s]['lv3_bot'] = decoder[s]['lv3'](feature[s]['lv3_bot'])

    feature[s]['lv2_1'] = encoder[s]['lv2'](residual[ps]['lv1'][:,:,0:int(H/2),:] + residual[s]['lv3_top']) + feature[s]['lv3_top'] + feature[ps]['lv2_1']
    feature[s]['lv2_2'] = encoder[s]['lv2'](residual[ps]['lv1'][:,:,int(H/2):H,:] + residual[s]['lv3_bot']) + feature[s]['lv3_bot'] + feature[ps]['lv2_2']
    feature[s]['lv2'] = torch.cat((feature[s]['lv2_1'], feature[s]['lv2_2']), 2)
    residual[s]['lv2'] = decoder[s]['lv2'](feature[s]['lv2']) + residual['s1']['lv1']

    feature[s]['lv1'] = encoder[s]['lv1'](residual[ps]['lv1'] + residual[s]['lv2']) + feature[s]['lv2'] + feature[ps]['lv1']
    residual[s]['lv1'] = decoder[s]['lv1'](feature[s]['lv1'])

    deblurred_image = residual[s]['lv1']
   
    return deblurred_image

In [6]:
class DnCNN(nn.Module):
    def __init__(self, channels, num_of_layers=17):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        features = 64
        layers = []
        layers.append(nn.Conv2d(in_channels=channels, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(num_of_layers-2):
            layers.append(nn.Conv2d(in_channels=features, out_channels=features, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(features))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=features, out_channels=channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
    def forward(self, x):
        out = self.dncnn(x)
        return out
    
dncnn_s = DnCNN(channels=3, num_of_layers=17)
device_ids = [0]
dncnn_s = nn.DataParallel(dncnn_s, device_ids=device_ids).cuda().half()
dncnn_s.load_state_dict(torch.load(os.path.join('../Denoising/Deep_Plug_and_play/checkpoints/dncnn_s25.pth')))
dncnn_s.eval()


DataParallel(
  (module): DnCNN(
    (dncnn): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): ReLU(inplace=True)
      (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): ReLU(inplace=True)
      (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU(inplace=True)
      (11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (13):

In [7]:
print("init data folders")            
psnr_test = 0
import time
psnr_last = []
i=0
ii = 0
#-------------------------------------
#（1）noise
#-------------------------------------
sigma = 0 # 噪音水平
sigma_ = sigma / 255.
start = time.time()
ssim_test = 0
for images_name in os.listdir(SAMPLE_DIR):
    ii+=1
    print('Testing The %d Pictures'%ii)

    gt_pil = Image.open(sharp+'/'+images_name)
    gt_np = pil_to_np(gt_pil)
    gt_torch = np_to_torch(gt_np).cuda(GPU).half()
    img_pil = Image.open(SAMPLE_DIR + '/' + images_name)
    img_np = pil_to_np(img_pil)
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)

    noisy_torch = np_to_torch(img_np)
    gt_torch = Variable(gt_torch)
    noisy_torch = Variable(noisy_torch.cuda(GPU).half())
    yita = 0.2
    #delta =8e-4
    delta =0.001
    lam = 1 - delta*(1 + yita)

    A_ = np.eye(gt_np.shape[1])*lam #shape: [a,a]
    A_T = np.eye(gt_np.shape[1]) # shape: [a,a]
    A = A_T
    rou = 0.00001
    space=np.zeros(gt_np.shape) #shape [3,a,b]
    y = torch_to_np(noisy_torch) #shape: [3,a,b]
    v_0 = np.zeros(gt_np.shape) #shape [3,a,b]
    x_0 = np.zeros(gt_np.shape) #shape [3,a,b]
    for i in range(y.shape[0]):   
        y_ = np.squeeze(y[i,:,:]) 
        x_0[i,:,:] = np.dot(A_T,y_)
    x_next = x_0  #shape [3,a,b]        
    v_next = v_0  #shape [3,a,b]

    with torch.no_grad():
        for k in range(2):  
            for i in range(x_next.shape[0]):
                x_next_ = x_next[i,:,:]
                y_ = y[i,:,:]
                v_next_ = v_next[i,:,:]
                x_next[i,:,:] = A_.dot(x_next_) + delta * A_T.dot(y_) + delta *v_next_ 
            x_next = np_to_torch(x_next).cuda().half()
            v_next = DMPHN(x_next.data-0.5).data+0.5
            x_next = torch_to_np(x_next)
            v_next = torch_to_np(v_next)  
            for j in range(x_next.shape[0]):
                A_temp =  torch_to_np(noisy_torch - dncnn_s(noisy_torch))[j,:,:]#shape [3,a,b]
                x_temp =  ((v_next[j,:,:].T).dot(A)).dot(v_next[j,:,:])[:gt_np.shape[1],:]#
                space[j,:,:] = rou*(A_temp+x_temp)
            t =  ((space[0,:,:]+space[1,:,:]+space[2,:,:])/3)[:gt_np.shape[1],:gt_np.shape[1]]

            A = (A -t)[:gt_np.shape[1],:gt_np.shape[1]]
            A_T = A.T

            v_hat = np_to_torch(v_next) 

        psnr = batch_PSNR(v_hat,gt_torch,1.)
        ssim = batch_SSIM(v_hat,gt_torch)
        ssim_test += ssim
        out_img = torch_to_np(v_hat).astype(np.float32)

        #plot_image_grid([np.clip(out_img, 0, 1),
        #                 np.clip(img_noisy_np, 0, 1),np.clip(gt_np, 0, 1)], factor=20, nrow=3)
        #save_images(np_to_torch(out_img),str(k))
        tim = time.time() - start
        print ('PSNR: %f  SSIM: %f  Time consumes: %f' % (psnr,ssim,tim), '\r',
        end='')
        print('\n')

    psnr_last.append(psnr)

psnr_last = np.asarray(psnr_last)
print('Average PSNR for testing is: %.8f'%(psnr_last.mean()))
#print(psnr_last)
ssim_test /= len(os.listdir(SAMPLE_DIR))
print("\nAverage SSIM on test data is: %f" % ssim_test)
end = time.time() - start
print('The total time is %f' % end)
print(basic_dir)

init data folders
Testing The 1 Pictures
PSNR: 33.777287  SSIM: 0.977680  Time consumes: 3.281894 

Testing The 2 Pictures
PSNR: 34.202271  SSIM: 0.983337  Time consumes: 6.319674 

Testing The 3 Pictures
PSNR: 34.810301  SSIM: 0.980196  Time consumes: 9.473063 

Testing The 4 Pictures
PSNR: 34.415710  SSIM: 0.980189  Time consumes: 12.556814 

Testing The 5 Pictures
PSNR: 34.022336  SSIM: 0.977208  Time consumes: 15.553538 

Testing The 6 Pictures
PSNR: 37.632935  SSIM: 0.991085  Time consumes: 18.610140 

Testing The 7 Pictures
PSNR: 36.640057  SSIM: 0.986168  Time consumes: 21.532096 

Testing The 8 Pictures
PSNR: 29.927356  SSIM: 0.963578  Time consumes: 24.635446 

Testing The 9 Pictures
PSNR: 31.327457  SSIM: 0.972799  Time consumes: 27.667075 

Testing The 10 Pictures
PSNR: 36.443943  SSIM: 0.987925  Time consumes: 30.617002 

Testing The 11 Pictures
PSNR: 37.104492  SSIM: 0.988002  Time consumes: 33.812771 

Testing The 12 Pictures
PSNR: 34.623210  SSIM: 0.979385  Time consumes

PSNR: 35.166948  SSIM: 0.982035  Time consumes: 436.477355 

Testing The 99 Pictures
PSNR: 35.744938  SSIM: 0.984157  Time consumes: 442.384478 

Testing The 100 Pictures
PSNR: 34.764120  SSIM: 0.982188  Time consumes: 447.976438 

Average PSNR for testing is: 34.10321497

Average SSIM on test data is: 0.977398
The total time is 447.979483
datas/GoPro/test/GOPR0881_11_01/
