In [1]:
import torch
import torch as t
from skimage.measure import compare_psnr
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
import numpy as np
import os
import math
import matplotlib.pyplot as plt
import random
import cv2
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import time
from PIL import Image
import models as models
from utils import *
# Expose only physical GPU 0 to this process (takes effect only if set before
# CUDA is first initialised), and keep a handle to the legacy half-precision
# CUDA tensor type.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
dtype = torch.cuda.HalfTensor

# Hyperparameters and data paths

In [15]:
# Checkpoint sub-folder name (checkpoints/<METHOD>/...).
METHOD = "DMPHN_1_2_4"

# One sequence of the GoPro test split; every sub-folder below lives inside it.
basic = 'datas/GoPro/test/'
cho = 'GOPR0854_11_00/'
basic_dir = '%s%s' % (basic, cho)
SAMPLE_DIR = '%sblur' % basic_dir     # blurred inputs
EXPDIR = '%sdeblur' % basic_dir       # destination folder used by save_images
sharp = '%ssharp' % basic_dir         # ground-truth images (same file names)

GPU = 0  # CUDA device index passed to .cuda(GPU)

# Helper functions (image saving, weight initialisation)

In [3]:
def save_images(images, name):
    """Save `images` to the file EXPDIR/<name> with torchvision.utils.save_image.

    NOTE(review): the EXPDIR folder is not created here; it must already exist.
    """
    target = "/".join((EXPDIR, name))
    torchvision.utils.save_image(images, target)

def weight_init(m):
    """Initialise a single module in place; meant to be used via ``nn.Module.apply``.

    Dispatch is on the class name:
      * contains ``Conv``: weight ~ N(0, 0.5 * sqrt(2 / n)), where
        n = out_channels * product of kernel_size; bias (if any) set to 0.
      * contains ``BatchNorm``: weight set to 1, bias set to 0 (each only if present,
        i.e. skipped for ``affine=False`` layers).
      * contains ``Linear``: weight ~ N(0, 0.01); bias (if any) set to 1.
    Any other module is left untouched. Returns None.
    """
    classname = m.__class__.__name__
    # In-place updates under no_grad keep each parameter's device and dtype.
    with torch.no_grad():
        if classname.find('Conv') != -1:
            # Fan-out over the whole kernel; works for 1-D/2-D/3-D kernels alike.
            n = m.out_channels
            for k in m.kernel_size:
                n *= k
            m.weight.normal_(0, 0.5 * math.sqrt(2. / n))
            if m.bias is not None:
                m.bias.zero_()
        elif classname.find('BatchNorm') != -1:
            # affine=False BatchNorm layers have weight/bias set to None.
            if m.weight is not None:
                m.weight.fill_(1)
            if m.bias is not None:
                m.bias.zero_()
        elif classname.find('Linear') != -1:
            m.weight.normal_(0, 0.01)
            if m.bias is not None:
                # Fill in place: the former `torch.ones(size)` replacement always
                # produced a CPU float32 tensor regardless of the module's device/dtype.
                m.bias.fill_(1)


# DMPHN

In [4]:
# One encoder/decoder pair per DMPHN level (lv1 = full image, lv2 = halves,
# lv3 = quarters). Each module is first given weight_init; a saved checkpoint,
# when present, then overrides those weights.
encoder_lv1 = models.Encoder().apply(weight_init).cuda(GPU)
encoder_lv2 = models.Encoder().apply(weight_init).cuda(GPU)
encoder_lv3 = models.Encoder().apply(weight_init).cuda(GPU)

decoder_lv1 = models.Decoder().apply(weight_init).cuda(GPU)
decoder_lv2 = models.Decoder().apply(weight_init).cuda(GPU)
decoder_lv3 = models.Decoder().apply(weight_init).cuda(GPU)


def _load_checkpoint(module, name):
    """Load checkpoints/<METHOD>/<name>.pkl into `module` if that file exists.

    A missing file is reported instead of silently ignored: the module would
    otherwise be evaluated with its random initialisation.
    """
    path = 'checkpoints/' + METHOD + '/' + name + '.pkl'
    if os.path.exists(path):
        # Load to CPU first; load_state_dict copies into the (GPU) parameters, so a
        # checkpoint saved on a different device index still loads.
        module.load_state_dict(torch.load(path, map_location='cpu'))
    else:
        print('WARNING: checkpoint %s not found; %s keeps random weights' % (path, name))


for _name, _module in (('encoder_lv1', encoder_lv1), ('encoder_lv2', encoder_lv2),
                       ('encoder_lv3', encoder_lv3), ('decoder_lv1', decoder_lv1),
                       ('decoder_lv2', decoder_lv2), ('decoder_lv3', decoder_lv3)):
    _load_checkpoint(_module, _name)
del _name, _module


In [5]:
def DMPHN(images_lv1):
    """Run the 1-2-4 DMPHN hierarchy on a batch of images.

    Dimension 2 of the input is split into top/bottom halves (level 2) and
    dimension 3 of each half into left/right halves (level 3).
    - Level 3: the four quarter patches are encoded by ``encoder_lv3``. Their
      features are stitched back per half and decoded into residuals.
    - Level 2: the residuals are added to the level-2 inputs. The level-2
      features, plus the stitched level-3 features, are decoded into a
      residual for the full image.
    - Level 1: the full image plus that residual is encoded, the level-2
      features are added, and ``decoder_lv1`` produces the output.

    Args:
        images_lv1: 4-D tensor; dims 2 and 3 are the ones split in half.
    Returns:
        Whatever ``decoder_lv1`` returns for the level-1 features.
    """
    half_h = images_lv1.size(2) // 2
    half_w = images_lv1.size(3) // 2

    # Level 2 inputs: top / bottom halves.
    top = images_lv1[:, :, :half_h, :]
    bottom = images_lv1[:, :, half_h:, :]

    # Level 3: encode the quarters in top-left, top-right, bottom-left, bottom-right order.
    quarters = (top[:, :, :, :half_w], top[:, :, :, half_w:],
                bottom[:, :, :, :half_w], bottom[:, :, :, half_w:])
    feat_tl, feat_tr, feat_bl, feat_br = [encoder_lv3(q) for q in quarters]
    feat_top = torch.cat((feat_tl, feat_tr), 3)
    feat_bottom = torch.cat((feat_bl, feat_br), 3)
    feature_lv3 = torch.cat((feat_top, feat_bottom), 2)
    res_top = decoder_lv3(feat_top)
    res_bottom = decoder_lv3(feat_bottom)

    # Level 2: refine each half with its level-3 residual.
    feat_lv2_top = encoder_lv2(top + res_top)
    feat_lv2_bottom = encoder_lv2(bottom + res_bottom)
    feature_lv2 = torch.cat((feat_lv2_top, feat_lv2_bottom), 2) + feature_lv3
    residual_lv2 = decoder_lv2(feature_lv2)

    # Level 1: the full image.
    feature_lv1 = encoder_lv1(images_lv1 + residual_lv2) + feature_lv2
    return decoder_lv1(feature_lv1)

In [12]:
class DnCNN(nn.Module):
    """Plain DnCNN: Conv-ReLU, then (num_of_layers - 2) x Conv-BN-ReLU, then Conv.

    Every convolution is 3x3 with padding 1 and no bias, and 64 feature maps are
    used internally. The spatial size is therefore preserved, and the output has
    ``channels`` channels.
    The stack is stored as ``self.dncnn``. Its positional indices are state_dict
    keys, so layer order must not change.
    """

    def __init__(self, channels, num_of_layers=17):
        super().__init__()
        width = 64

        def conv3x3(c_in, c_out):
            return nn.Conv2d(in_channels=c_in, out_channels=c_out,
                             kernel_size=3, padding=1, bias=False)

        stack = [conv3x3(channels, width), nn.ReLU(inplace=True)]
        for _ in range(num_of_layers - 2):
            stack.extend([conv3x3(width, width), nn.BatchNorm2d(width), nn.ReLU(inplace=True)])
        stack.append(conv3x3(width, channels))
        self.dncnn = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the layer stack to ``x``."""
        return self.dncnn(x)
    
# Pretrained DnCNN used as a denoiser in the evaluation loop.
DNCNN_CKPT = '../Denoising/Deep_Plug_and_play/checkpoints/dncnn_s15.pth'
dncnn_s = DnCNN(channels=3, num_of_layers=17)
# Use the same device index as the DMPHN modules; DataParallel needs the module
# on device_ids[0].
device_ids = [GPU]
# Wrapped before loading: the checkpoint was presumably saved from a DataParallel
# model (keys prefixed "module."). TODO confirm.
dncnn_s = nn.DataParallel(dncnn_s, device_ids=device_ids).cuda(GPU)
# Load to CPU first; load_state_dict copies the tensors onto the model's device.
dncnn_s.load_state_dict(torch.load(DNCNN_CKPT, map_location='cpu'))
dncnn_s.eval()


DataParallel(
  (module): DnCNN(
    (dncnn): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (4): ReLU(inplace=True)
      (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (7): ReLU(inplace=True)
      (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (9): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU(inplace=True)
      (11): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (12): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (13):

In [19]:
print("init data folders")
plot = False      # True: show (output, noisy input, ground truth) for each image
psnr_last = []    # per-image PSNR values
ssim_test = 0     # running SSIM sum, averaged at the end

#-------------------------------------
# (1) noise
#-------------------------------------
sigma = 0  # noise level on the 0-255 scale (rescaled to [0, 1] below)
sigma_ = sigma / 255.

# Constants of the iterative restoration (identical for every image).
yita = 0.1
delta = 0.001  # previously tried: 8e-4
lam = 1 - delta * (1 + yita)
rou = 0.00001
n_iters = 2

start = time.time()
ii = 0
# sorted(): deterministic processing order and reproducible logs.
for images_name in sorted(os.listdir(SAMPLE_DIR)):
    ii += 1
    print('Testing The %d Pictures' % ii)

    # Ground truth is the file with the same name in the `sharp` folder.
    with Image.open(sharp + '/' + images_name) as gt_pil:
        gt_np = pil_to_np(gt_pil)
    gt_torch = np_to_torch(gt_np).cuda(GPU)
    with Image.open(SAMPLE_DIR + '/' + images_name) as img_pil:
        img_np = pil_to_np(img_pil)
    _, img_noisy_np = get_noisy_image(img_np, sigma_)

    # Feed the *noisy* image to the pipeline. Previously `img_np` was used here, so
    # `sigma` had no effect. `.float()` guards against get_noisy_image returning
    # float64 (NOTE(review): confirm its dtype).
    noisy_torch = np_to_torch(img_noisy_np).float().cuda(GPU)

    n_rows = gt_np.shape[1]            # "a" in the [3, a, b] shape notation below
    A_ = np.eye(n_rows) * lam          # shape [a, a]; never updated
    A_T = np.eye(n_rows)               # shape [a, a]; updated every iteration
    A = A_T
    space = np.zeros(gt_np.shape)      # shape [3, a, b]
    y = torch_to_np(noisy_torch)       # shape [3, a, b]
    v_next = np.zeros(gt_np.shape)     # shape [3, a, b]
    x_next = np.zeros(gt_np.shape)     # shape [3, a, b]
    for c in range(y.shape[0]):
        x_next[c, :, :] = np.dot(A_T, np.squeeze(y[c, :, :]))

    with torch.no_grad():
        # noisy_torch never changes inside the iterations, so the DnCNN term is
        # computed once per image instead of once per channel per iteration.
        dn_residual = torch_to_np(noisy_torch - dncnn_s(noisy_torch))  # shape [3, a, b]
        for _ in range(n_iters):
            for c in range(x_next.shape[0]):
                x_next[c, :, :] = (A_.dot(x_next[c, :, :]) + delta * A_T.dot(y[c, :, :])
                                   + delta * v_next[c, :, :])
            x_torch = np_to_torch(x_next).cuda(GPU).float()
            v_next = torch_to_np(DMPHN(x_torch.data - 0.5).data + 0.5)
            x_next = torch_to_np(x_torch)
            for c in range(x_next.shape[0]):
                x_temp = v_next[c, :, :].T.dot(A).dot(v_next[c, :, :])[:n_rows, :]
                space[c, :, :] = rou * (dn_residual[c, :, :] + x_temp)
            # Named A_step (formerly `t`) so it no longer shadows `import torch as t`.
            A_step = ((space[0, :, :] + space[1, :, :] + space[2, :, :]) / 3)[:n_rows, :n_rows]
            A = (A - A_step)[:n_rows, :n_rows]
            A_T = A.T
        v_hat = np_to_torch(v_next)

        psnr = batch_PSNR(v_hat, gt_torch, 1.)
        ssim = batch_SSIM(v_hat, gt_torch)
        ssim_test += ssim
        if plot:
            out_img = torch_to_np(v_hat).astype(np.float32)
            plot_image_grid([np.clip(out_img, 0, 1),
                             np.clip(img_noisy_np, 0, 1), np.clip(gt_np, 0, 1)], factor=20, nrow=3)
        tim = time.time() - start
        print('PSNR: %f  SSIM: %f  Time consumes: %f' % (psnr, ssim, tim), '\r', end='')
        print('\n')

    psnr_last.append(psnr)

psnr_last = np.asarray(psnr_last)
if ii == 0:
    print('No images found in %s' % SAMPLE_DIR)
else:
    print('Average PSNR for testing is: %.8f' % (psnr_last.mean()))
    # Average over the images actually processed rather than re-listing the folder.
    ssim_test /= ii
    print("\nAverage SSIM on test data is: %f" % ssim_test)
end = time.time() - start
print('The total time is %f' % end)
print(cho)

init data folders
Testing The 1 Pictures
PSNR: 27.453828  SSIM: 0.892091  Time consumes: 1.000281 

Testing The 2 Pictures
PSNR: 24.134423  SSIM: 0.807821  Time consumes: 1.902080 

Testing The 3 Pictures
PSNR: 28.954836  SSIM: 0.931098  Time consumes: 2.804260 

Testing The 4 Pictures
PSNR: 28.113755  SSIM: 0.904334  Time consumes: 3.706643 

Testing The 5 Pictures
PSNR: 26.013758  SSIM: 0.882014  Time consumes: 4.610795 

Testing The 6 Pictures
PSNR: 32.316207  SSIM: 0.969289  Time consumes: 5.519813 

Testing The 7 Pictures
PSNR: 26.825066  SSIM: 0.884891  Time consumes: 6.423859 

Testing The 8 Pictures
PSNR: 21.569924  SSIM: 0.734042  Time consumes: 7.329440 

Testing The 9 Pictures
PSNR: 26.852755  SSIM: 0.872078  Time consumes: 8.228904 

Testing The 10 Pictures
PSNR: 25.459673  SSIM: 0.876178  Time consumes: 9.133497 

Testing The 11 Pictures
PSNR: 26.834524  SSIM: 0.895785  Time consumes: 10.030993 

Testing The 12 Pictures
PSNR: 28.672987  SSIM: 0.921256  Time consumes: 10.92

PSNR: 22.917132  SSIM: 0.805430  Time consumes: 89.743865 

Testing The 100 Pictures
PSNR: 25.387540  SSIM: 0.857994  Time consumes: 90.646687 

Average PSNR for testing is: 26.98221730

Average SSIM on test data is: 0.884104
The total time is 90.648338
GOPR0854_11_00/
