In [None]:
# -*- coding: utf-8 -*-

# =============================================================================
#  @article{zhang2017beyond,
#    title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
#    author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
#    journal={IEEE Transactions on Image Processing},
#    year={2017},
#    volume={26}, 
#    number={7}, 
#    pages={3142-3155}, 
#  }
# by Kai Zhang (08/2018)
# cskaizhang@gmail.com
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================

# run this to test the model

In [1]:
!pwd
%cd drive/MyDrive/Colab\ Notebooks

/home/gbu06095
[Errno 2] No such file or directory: 'drive/MyDrive/Colab Notebooks'
/home/gbu06095


In [4]:
!pip list

Package                        Version
------------------------------ -------------------
aiohttp                        3.7.4.post0
ansiwrap                       0.8.4
anyio                          3.2.0
appdirs                        1.4.4
argon2-cffi                    20.1.0
arrow                          1.1.0
asn1crypto                     1.4.0
async-generator                1.10
async-timeout                  3.0.1
attrs                          21.2.0
backcall                       0.2.0
backports.functools-lru-cache  1.6.4
beatrix-jupyterlab             0.0.3
binaryornot                    0.4.4
black                          21.5b2
bleach                         3.3.0
blinker                        1.4
Bottleneck                     1.3.2
brotlipy                       0.7.0
cachetools                     4.2.2
caip-notebooks-serverextension 1.0.0
certifi                        2021.10.8
cffi                           1.14.5
chardet                

In [1]:
import argparse
import os, time, datetime
# import PIL.Image as Image
import numpy as np
import torch.nn as nn
import torch.nn.init as init
import torch
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
#from skimage.measure import compare_psnr, compare_ssim
from skimage.io import imread, imsave
import cv2

import easydict
import import_ipynb
import data_generator as dg

importing Jupyter notebook from data_generator.ipynb


In [2]:
def parse_args():
    #parser = argparse.ArgumentParser()
    #parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset')
    #parser.add_argument('--set_names', default=['Set68', 'Set12'], help='directory of test dataset')
    #parser.add_argument('--sigma', default=25, type=int, help='noise level')
    #parser.add_argument('--model_dir', default=os.path.join('models', 'DnCNN_sigma25'), help='directory of the model')
    #parser.add_argument('--model_name', default='model_001.pth', type=str, help='the model name')
    #parser.add_argument('--result_dir', default='results', type=str, help='directory of test dataset')
    #parser.add_argument('--save_result', default=0, type=int, help='save the denoised image, 1 or 0')

    args = easydict.EasyDict({
    "set_dir" : 'Test_set',
    "set_names" : ['Test_set(qp22)'],
    "sigma" : 25,
    "model_dir" : os.path.join('models', 'DnCNN_sigma25'),
    "model_name" : 'model_010.pth',
    "result_dir" : 'results',
    "save_result" : 0
    })
    return args


def log(*args, **kwargs):
     print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)


def save_result(result, path):
    path = path if path.find('.') != -1 else path+'.png'
    ext = os.path.splitext(path)[-1]
    if ext in ('.txt', '.dlm'):
        np.savetxt(path, result, fmt='%2.4f')
    else:
        imsave(path, np.clip(result, 0, 1))


def show(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()


class DnCNN(nn.Module):

    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []
        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y-out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)


if __name__ == '__main__':

    args = parse_args()

    # model = DnCNN()
    if not os.path.exists(os.path.join(args.model_dir, args.model_name)):

        model = torch.load(os.path.join(args.model_dir, 'model.pth'))
        # load weights into new model
        log('load trained model on Train400 dataset by kai')
    else:
        # model.load_state_dict(torch.load(os.path.join(args.model_dir, args.model_name)))
        model = torch.load(os.path.join(args.model_dir, args.model_name))
        log('load trained model')

#    params = model.state_dict()
#    print(params.values())
#    print(params.keys())
#
#    for key, value in params.items():
#        print(key)    # parameter name
#    print(params['dncnn.12.running_mean'])
#    print(model.state_dict())

    model.eval()  # evaluation mode
#    model.train()

    if torch.cuda.is_available():
        model = model.cuda()

    if not os.path.exists(args.result_dir):
        os.mkdir(args.result_dir)

    for set_cur in args.set_names:

        if not os.path.exists(os.path.join(args.result_dir, set_cur)):
            os.mkdir(os.path.join(args.result_dir, set_cur))
        psnrs = []
        ssims = []

        for im in os.listdir(os.path.join(args.set_dir, set_cur)):
            if im.endswith(".jpg") or im.endswith(".bmp") or im.endswith(".png"):
                
                x = np.array(cv2.imread(os.path.join(args.set_dir, set_cur, im), 0), dtype=np.float32)/255.0
                h, w = x.shape
                x = cv2.resize(x, (h, h), interpolation=cv2.INTER_CUBIC)
                np.random.seed(seed=0)  # for reproducibility
                y = x + np.random.normal(0, args.sigma/255.0, x.shape)  # Add Gaussian noise without clipping
                y = y.astype(np.float32)
                y_ = torch.from_numpy(y).view(1, -1, y.shape[0], y.shape[1])

                torch.cuda.synchronize()
                start_time = time.time()
                y_ = y_.cuda()
                x_ = model(y_)  # inference
                x_ = x_.view(y.shape[0], y.shape[1])
                x_ = x_.cpu()
                x_ = x_.detach().numpy().astype(np.float32)
                torch.cuda.synchronize()
                elapsed_time = time.time() - start_time
                print('%10s : %10s : %2.4f second' % (set_cur, im, elapsed_time))

                psnr_x_ = compare_psnr(x, x_, data_range = 2)
                ssim_x_ = compare_ssim(x, x_)
                if args.save_result:
                    name, ext = os.path.splitext(im)
                    show(np.hstack((y, x_)))  # show the image
                    save_result(x_, path=os.path.join(args.result_dir, set_cur, name+'_dncnn'+ext))  # save the denoised image
                psnrs.append(psnr_x_)
                ssims.append(ssim_x_)
        psnr_avg = np.mean(psnrs)
        ssim_avg = np.mean(ssims)
        psnrs.append(psnr_avg)
        ssims.append(ssim_avg)
        if args.save_result:
            save_result(np.hstack((psnrs, ssims)), path=os.path.join(args.result_dir, set_cur, 'results.txt'))
        log('Datset: {0:10s} \n  PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))

2021-12-12 14:26:47: load trained model
Test_set(qp22) : 000000022935.png : 0.6071 second
Test_set(qp22) : 000000002153.png : 0.0806 second
Test_set(qp22) : 000000166287.png : 0.0628 second
Test_set(qp22) : 000000025560.png : 0.0602 second
Test_set(qp22) : 000000008021.png : 0.0535 second
Test_set(qp22) : 000000155341.png : 0.0536 second
Test_set(qp22) : 000000024610.png : 0.0530 second
Test_set(qp22) : 000000023230.png : 0.0534 second
Test_set(qp22) : 000000017627.png : 0.0529 second
Test_set(qp22) : 000000027620.png : 0.0524 second
Test_set(qp22) : 000000161397.png : 0.0540 second
Test_set(qp22) : 000000009891.png : 0.0527 second
Test_set(qp22) : 000000018150.png : 0.0536 second
Test_set(qp22) : 000000025593.png : 0.0526 second
Test_set(qp22) : 000000009769.png : 0.0540 second
Test_set(qp22) : 000000165336.png : 0.0531 second
Test_set(qp22) : 000000001000.png : 0.0541 second
Test_set(qp22) : 000000022589.png : 0.0532 second
Test_set(qp22) : 000000018837.png : 0.0535 second
Test_set(q

In [18]:
true_min, true_max = np.min(x), np.max(x)
true_min, true_max

(-0.11421569, 1.1099265)