In [None]:
"""
test the accuracy of DNCNN model from https://github.com/SaoYan/DnCNN-PyTorch
"""

'\ntest the accuracy of DNCNN model from https://github.com/SaoYan/DnCNN-PyTorch\n'

In [None]:
# Mount Google Drive (Colab only) so the checkpoint and test images referenced
# by the path settings are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Locations on the mounted Google Drive: training logs / checkpoints and test data.
trained_model_path, data_path = (
    '/content/drive/My Drive/DNCNN_logs',
    '/content/drive/My Drive/data',
)
# Keeps its leading '/' because it is appended to data_path by plain string
# concatenation, not os.path.join.
test_dataset = '/Set12'
# Standard deviation of the additive Gaussian test noise on the 0-255 pixel scale
# (it is divided by 255 before being applied to the normalized images).
test_noiseL = 25

In [None]:
import cv2
import os
import glob
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from skimage.metrics import peak_signal_noise_ratio

In [None]:
class DnCNN(nn.Module):
    """DnCNN network: a plain stack of 3x3 convolutions.

    Layout (all convolutions are 3x3, padding 1, no bias, 64 feature maps):
        Conv(channels -> 64) + ReLU
        (num_of_layers - 2) x [Conv(64 -> 64) + BatchNorm + ReLU]
        Conv(64 -> channels)
    Padding 1 with 3x3 kernels keeps the spatial size unchanged, so the output
    has the same shape as the input.

    Args:
        channels: number of input (and output) image channels.
        num_of_layers: total number of convolution layers.
    """

    def __init__(self, channels, num_of_layers=17):
        super().__init__()

        def conv3x3(n_in, n_out):
            return nn.Conv2d(in_channels=n_in, out_channels=n_out, kernel_size=3, padding=1, bias=False)

        width = 64
        # Layer order (and therefore the Sequential indices used as state-dict
        # keys, e.g. 'dncnn.0.weight') must stay exactly as built here.
        stack = [conv3x3(channels, width), nn.ReLU(inplace=True)]
        for _ in range(num_of_layers - 2):
            stack += [conv3x3(width, width), nn.BatchNorm2d(width), nn.ReLU(inplace=True)]
        stack.append(conv3x3(width, channels))
        self.dncnn = nn.Sequential(*stack)

    def forward(self, x):
        """Apply the convolution stack to ``x`` and return its raw output."""
        return self.dncnn(x)

def batch_PSNR(img, imclean, data_range):
    """Average PSNR (dB) over a batch of images.

    Args:
        img: restored / denoised batch as a torch tensor; the first dimension
            indexes images (e.g. (N, C, H, W) or (N, H, W)).
        imclean: reference batch, same shape as ``img``.
        data_range: dynamic range of the pixel values (e.g. 1. for [0, 1] images).

    Returns:
        Mean of the per-image PSNRs computed by
        ``skimage.metrics.peak_signal_noise_ratio`` with ``imclean`` as reference.

    Raises:
        ValueError: if the two batches differ in shape or the batch is empty.
    """
    # detach() is the supported way to drop autograd history (``.data`` is legacy).
    restored = img.detach().cpu().numpy().astype(np.float32)
    reference = imclean.detach().cpu().numpy().astype(np.float32)
    if restored.shape != reference.shape:
        raise ValueError('batch_PSNR: shape mismatch %s vs %s' % (restored.shape, reference.shape))
    n_images = restored.shape[0]
    if n_images == 0:
        # Previously surfaced as an opaque ZeroDivisionError.
        raise ValueError('batch_PSNR: empty batch')
    total = 0
    # Iterate over the batch axis only, so any per-image rank is accepted.
    for ref_i, out_i in zip(reference, restored):
        total += peak_signal_noise_ratio(ref_i, out_i, data_range=data_range)
    return total / n_images


In [None]:
def normalize(data):
    """Divide by 255, mapping 8-bit pixel values in [0, 255] onto [0, 1].

    Works with anything that supports true division by a float (NumPy arrays,
    tensors, plain numbers) and returns a new object.
    """
    max_pixel = 255.
    scaled = data / max_pixel
    return scaled

def main(seed=0):
    """Evaluate the pretrained DnCNN-S-25 checkpoint on the test set and print PSNR.

    For every ``*.png`` in ``data_path + test_dataset`` (sorted by name), this:
    1. takes the first image channel and scales it to [0, 1];
    2. adds Gaussian noise with std ``test_noiseL / 255``;
    3. denoises as ``noisy - model(noisy)`` clamped to [0, 1];
    4. prints the PSNR against the clean image, then prints the mean.

    Reads the notebook globals ``trained_model_path``, ``data_path``,
    ``test_dataset`` and ``test_noiseL``.

    Args:
        seed: seed for a private noise generator, so the reported PSNR is
            reproducible across runs. Pass ``None`` to draw noise from torch's
            global RNG (the previous, unseeded behaviour).

    Returns:
        The mean PSNR (dB) over all test images.

    Raises:
        FileNotFoundError: if no test images are found or one cannot be read.
    """
    # Run on the GPU when available; otherwise fall back to the CPU instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build model
    print('Loading model ...\n')
    net = DnCNN(channels=1, num_of_layers=17)
    # The state dict is loaded into a DataParallel wrapper (keys prefixed with
    # 'module.'), so keep the wrapper. Without CUDA, DataParallel just forwards
    # calls to the wrapped module.
    model = nn.DataParallel(net, device_ids=[0]).to(device)
    # NOTE(review): the directory name suggests a model trained for sigma=25;
    # keep test_noiseL consistent if the checkpoint is swapped.
    # torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = os.path.join(trained_model_path, 'DnCNN-S-25', 'net.pth')
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()

    # load data info
    print('Loading data info ...\n')
    # test_dataset starts with '/', so it is concatenated rather than os.path.join-ed.
    test_dir = data_path + test_dataset
    files_source = sorted(glob.glob(os.path.join(test_dir, '*.png')))
    if not files_source:
        raise FileNotFoundError('No *.png test images found in %s' % test_dir)

    # A private generator makes the noise reproducible without reseeding torch's global RNG.
    noise_gen = None if seed is None else torch.Generator().manual_seed(seed)

    # process data
    psnr_test = 0
    for f in files_source:
        # image: first channel only (the model is single-channel), shaped (1, 1, H, W)
        img = cv2.imread(f)
        if img is None:  # cv2.imread reports failure by returning None
            raise FileNotFoundError('Could not read image %s' % f)
        clean = normalize(np.float32(img[:, :, 0]))[np.newaxis, np.newaxis, :, :]
        ISource = torch.from_numpy(clean)
        # noise: additive Gaussian, std given on the 0-255 scale
        noise = torch.empty(ISource.size()).normal_(mean=0, std=test_noiseL / 255., generator=noise_gen)
        INoisy = (ISource + noise).to(device)
        ISource = ISource.to(device)
        with torch.no_grad():  # inference only: no autograd graph, saves memory
            # The network output is subtracted from the noisy input (residual learning).
            Out = torch.clamp(INoisy - model(INoisy), 0., 1.)
        psnr = batch_PSNR(Out, ISource, 1.)
        psnr_test += psnr
        print("%s PSNR %f" % (f, psnr))
    psnr_test /= len(files_source)
    print("\nPSNR on test data %f" % psnr_test)
    return psnr_test

# Run the evaluation: prints the PSNR of each test image, then the mean over the set.
main()

Loading model ...

Loading data info ...

/content/drive/My Drive/data/Set12/01.png PSNR 30.074796
/content/drive/My Drive/data/Set12/02.png PSNR 33.030703
/content/drive/My Drive/data/Set12/03.png PSNR 30.770762
/content/drive/My Drive/data/Set12/04.png PSNR 29.382573
/content/drive/My Drive/data/Set12/05.png PSNR 30.279145
/content/drive/My Drive/data/Set12/06.png PSNR 29.153004
/content/drive/My Drive/data/Set12/07.png PSNR 29.406059
/content/drive/My Drive/data/Set12/08.png PSNR 32.349526
/content/drive/My Drive/data/Set12/09.png PSNR 29.888108
/content/drive/My Drive/data/Set12/10.png PSNR 30.158338
/content/drive/My Drive/data/Set12/11.png PSNR 30.013587
/content/drive/My Drive/data/Set12/12.png PSNR 29.991291

PSNR on test data 30.374824
