In [16]:
import argparse
import os, time, datetime
import PIL.Image as Image
import numpy as np
import torch.nn as nn
import torch.nn.init as init
import torch
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim
from skimage.io import imread, imsave
import sys
import warnings

In [17]:
warnings.filterwarnings('ignore')

In [18]:
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset')
    parser.add_argument('--set_names', default=['Set68', 'Set12'], help='directory of test dataset')
    parser.add_argument('--sigma', default=4, type=int, help='noise level')
    parser.add_argument('--model_dir', default=os.path.join('models', 'DnCNN_sigma4'), help='directory of the model')
    parser.add_argument('--model_name', default='model_003.pth', type=str, help='the model name')
    parser.add_argument('--result_dir', default='results', type=str, help='directory of test dataset')
    parser.add_argument('--save_result', default=0, type=int, help='save the denoised image, 1 or 0')
    return parser.parse_args()


def log(*args, **kwargs):
     print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)


def save_result(result, path):
    path = path if path.find('.') != -1 else path+'.png'
    ext = os.path.splitext(path)[-1]
    if ext in ('.txt', '.dlm'):
        np.savetxt(path, result, fmt='%2.4f')
    else:
        imsave(path, np.clip(result, 0, 1))


def show(x, title=None, cbar=False, figsize=None):
    import matplotlib.pyplot as plt
    plt.figure(figsize=figsize)
    plt.imshow(x, interpolation='nearest', cmap='gray')
    if title:
        plt.title(title)
    if cbar:
        plt.colorbar()
    plt.show()

In [19]:
class Residual_Blocks(nn.Module):
    def __init__(self, in_channels=64, out_channels=64, padding = 1):
        super(Residual_Blocks, self).__init__()
        self.convx_1 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
        self.Leakyrelu = nn.LeakyReLU(inplace=True)
        self.BN = nn.BatchNorm2d(64)
        self.convx_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=False)
        self.Leakyrelu = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        x = self.BN(self.Leakyrelu(self.convx_1(x)))
        out = self.BN(self.Leakyrelu(self.convx_2(x)))
        return out
    
class Attention(nn.Module):
    def __init__(self, reduction=8):
        super(Attention, self).__init__()
        self.red = reduction
        self.query = nn.Conv2d(in_channels=64, out_channels=64//self.red, kernel_size=3, padding=1, stride=1, bias=False)
        self.value = nn.Conv2d(in_channels=64, out_channels=64//self.red, kernel_size=3, padding=1, stride=1, bias=False)
        self.key = nn.Conv2d(in_channels=64, out_channels=64//self.red, kernel_size=3, padding=1, stride=1, bias=False)
        self.out = nn.Conv2d(in_channels=64//self.red, out_channels=64, kernel_size=3, padding=1, stride=1, bias=False)
        
    def forward(self, x):
        b, c, h, w = x.size()
        query = self.query(x).view(b, -1, h*w)
        value = self.value(x).view(b, -1, h*w)
        key = self.key(x).view(b, -1, h*w)
        mul_1 = torch.bmm(query, value.transpose(1,2))
        attention_weights = mul_1 / np.sqrt(64//self.red)
        res_1 = nn.functional.softmax(attention_weights, dim=-1)
        mul_2 = torch.bmm(res_1, key).view(b, -1, h, w)
        out = self.out(mul_2)
        return x + out

        
class DnCNN(nn.Module):
    def __init__(self, n_channels=64, image_channels=1):
        super(DnCNN, self).__init__()
        self.conv1 = nn.Conv2d(image_channels, n_channels, kernel_size=3, padding=1, bias=False)
        self.attention = Attention()
        self.act = nn.LeakyReLU(inplace=True)
        self.out = nn.Sequential(*[Residual_Blocks() for _ in range(1)])
        self.dn = nn.Conv2d(64, 1, kernel_size=3, padding=1, bias=True)
        self._initialize_weights()
        
    def forward(self, x):
        y = x
        out = self.act(self.attention(self.conv1(x)))
        out = self.out(out)
        dn = y - self.dn(out)
        return dn
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

In [20]:
def parse_args(args=None):
    parser = argparse.ArgumentParser()
    parser.add_argument('--set_dir', default='data/Test', type=str, help='directory of test dataset')
    parser.add_argument('--set_names', nargs='+', default=['Set68', 'Set12'], help='list of dataset names')
    parser.add_argument('--sigma', default=4, type=int, help='noise level')
    parser.add_argument('--model_dir', default='models/DnCNN_sigma4', type=str, help='directory of the model')
    parser.add_argument('--model_name', default='model_003.pth', type=str, help='the model name')
    parser.add_argument('--result_dir', default='results', type=str, help='directory for storing results')
    parser.add_argument('--save_result', default=False, action='store_true', help='save the denoised image, true or false')

    return parser.parse_args(args)
 
if __name__ == '__main__':
    args = parse_args(['--set_dir', 'data/Test', '--set_names', 'Set68', 'Set12', '--sigma', '4', '--model_dir', 'models/DnCNN_sigma4', '--model_name', 'model_003.pth', '--result_dir', 'results', '--save_result'])

    if not os.path.exists(os.path.join(args.model_dir, args.model_name)):

        model = torch.load(os.path.join(args.model_dir, 'model_003.pth'))
        # load weights into new model
        log('load trained model on Train400 dataset by kai')
    else:
        # model.load_state_dict(torch.load(os.path.join(args.model_dir, args.model_name)))
        model = torch.load(os.path.join(args.model_dir, args.model_name))
        log('load trained model')

#    params = model.state_dict()
#    print(params.values())
#    print(params.keys())
#
#    for key, value in params.items():
#        print(key)    # parameter name
#    print(params['dncnn.12.running_mean'])
#    print(model.state_dict())

    model.eval()  # evaluation mode
#    model.train()

    if torch.cuda.is_available():
        model = model.cuda()

    if not os.path.exists(args.result_dir):
        os.mkdir(args.result_dir)

    for set_cur in args.set_names:

        if not os.path.exists(os.path.join(args.result_dir, set_cur)):
            os.mkdir(os.path.join(args.result_dir, set_cur))
        psnrs = []
        ssims = []

        for im in os.listdir(os.path.join(args.set_dir, set_cur)):
            if im.endswith(".jpg") or im.endswith(".bmp") or im.endswith(".png"):

                x = np.array(imread(os.path.join(args.set_dir, set_cur, im)), dtype=np.float32)/255.0
                np.random.seed(seed=0)  # for reproducibility
                y = x + np.random.normal(0, args.sigma/255.0, x.shape)  # Add Gaussian noise without clipping
                y = y.astype(np.float32)
                y_ = torch.from_numpy(y).view(1, -1, y.shape[0], y.shape[1])

                torch.cuda.synchronize()
                start_time = time.time()
                y_ = y_.cuda()
                x_ = model(y_)  # inference
                x_ = x_.view(y.shape[0], y.shape[1])
                x_ = x_.cpu()
                x_ = x_.detach().numpy().astype(np.float32)
                torch.cuda.synchronize()
                elapsed_time = time.time() - start_time
                print('%10s : %10s : %2.4f second' % (set_cur, im, elapsed_time))

                im_x = np.array(x, dtype=np.float32)
                im_x_ = np.array(x_, dtype=np.float32)
                data_range = im_x.max() - im_x.min()

                psnr_x_ = compare_psnr(x, x_)
                ssim_x_ = compare_ssim(x, x_, data_range = data_range)
                if args.save_result:
                    name, ext = os.path.splitext(im)
                    show(np.hstack((y, x_)))  # show the image
                    res = np.array(x_)
                    res_dn = np.clip(res, 0, 255).astype(np.uint8)
                    im_dn = Image.fromarray(res_dn)
                    save_result(im_dn, path=os.path.join(args.result_dir, set_cur, name+'_dncnn'+ext))  # save the denoised image
                psnrs.append(psnr_x_)
                ssims.append(ssim_x_)
        psnr_avg = np.mean(psnrs)
        ssim_avg = np.mean(ssims)
        psnrs.append(psnr_avg)
        ssims.append(ssim_avg)
        if args.save_result:
            save_result(np.hstack((psnrs, ssims)), path=os.path.join(args.result_dir, set_cur, 'results.txt'))
        log('Datset: {0:10s} \n  PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))

FileNotFoundError: [Errno 2] No such file or directory: 'models/DnCNN_sigma4/model_003.pth'