## Import Modules

In [1]:
import argparse
import re
import os, glob, datetime, time
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.init as init
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR
import easydict
import cv2
# from multiprocessing import Pool
from torch.utils.data import Dataset

## Define Model Constants

In [2]:
# Training configuration.
#   train_data  : directory of compression-noisy input images (read as *.png below)
#   origin_data : directory of the matching original images (read as *.jpg below)
#   sigma       : only used here to name the checkpoint directory
#   epoch       : total epochs; training resumes from the last checkpoint up to this
args = easydict.EasyDict({
    "model" : "DnCNN",
    "batch_size" : 128,
    "train_data" : 'Noise_image_Training(qp22)',
    "origin_data" : 'Noise_image_Training(qp22)_original',
    "sigma" : 25,
    "epoch" : 10,
    "lr" : 1e-3
})

batch_size = args.batch_size
cuda = torch.cuda.is_available()
n_epoch = args.epoch
sigma = args.sigma

save_dir = os.path.join('models', args.model + '_' + 'sigma' + str(sigma))

# makedirs also creates the parent 'models' directory; os.mkdir would raise
# FileNotFoundError on a fresh checkout where 'models' does not exist yet.
os.makedirs(save_dir, exist_ok=True)

## Define Data Loader

In [3]:
# Patch-extraction / augmentation settings.
# NOTE(review): none of these names is referenced by any other cell in this
# notebook — presumably left over from a patch-based data generator; confirm
# before relying on them.
patch_size = 40
stride = 10
aug_times = 1
scales = [1, 0.9, 0.8, 0.7]

class DenoisingDataset(Dataset):
    """Map-style dataset pairing noisy inputs with their original images.

    Item ``i`` is the tuple ``(ys[i], xs[i])``: the noisy (compressed) image
    first, the original image second. The two containers are indexed in
    lock-step, so they must have the same length and the same ordering.

    Args:
        ys: indexable collection of noisy images (the training cell passes a
            torch tensor).
        xs: indexable collection of original images, same length as ``ys``.

    Raises:
        ValueError: if ``ys`` and ``xs`` have different lengths.
    """

    def __init__(self, ys, xs):
        super(DenoisingDataset, self).__init__()
        # Fail fast: a length mismatch means the noisy/clean pairs are
        # misaligned. Otherwise extra samples are silently dropped, or an
        # IndexError surfaces later inside a DataLoader worker.
        if len(ys) != len(xs):
            raise ValueError(
                'noisy/clean length mismatch: %d noisy vs %d clean images'
                % (len(ys), len(xs)))
        self.ys = ys  # compression-noisy images
        self.xs = xs  # original images

    def __getitem__(self, index):
        return self.ys[index], self.xs[index]

    def __len__(self):
        # len() equals size(0) for tensors and also works for numpy arrays/lists.
        return len(self.xs)

def datagenerator(data_dir, img_type):
    """Load every ``*.<img_type>`` file in ``data_dir`` as one grayscale stack.

    Files are read in sorted filename order. ``glob`` returns files in an
    arbitrary, filesystem-dependent order. Sorting ensures that two directories
    holding corresponding images under matching names produce stacks whose
    entries line up index-by-index.

    Args:
        data_dir: directory to scan (non-recursive).
        img_type: file extension without the dot, e.g. ``'png'``.

    Returns:
        uint8 array of shape ``(N, H, W, 1)``. All images must share the same
        ``H x W``; otherwise numpy raises ``ValueError`` while stacking.

    Raises:
        FileNotFoundError: if no matching file exists.
        IOError: if OpenCV cannot decode one of the files.
    """
    # Escape the directory so glob metacharacters in its name are taken literally.
    pattern = os.path.join(glob.escape(data_dir), '*.' + img_type)
    file_list = sorted(glob.glob(pattern))
    if not file_list:
        raise FileNotFoundError('no *.%s images found in %r' % (img_type, data_dir))

    data = []
    for path in file_list:
        img = cv2.imread(path, 0)  # flag 0: load as single-channel grayscale
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('could not read image %r' % path)
        data.append(img)

    data = np.array(data, dtype='uint8')
    data = np.expand_dims(data, axis=3)  # add trailing channel axis -> (N, H, W, 1)
    return data

## Define DnCNN Model

In [4]:
class DnCNN(nn.Module):
    """Residual denoising CNN (DnCNN-style).

    Layout, ``depth`` convolution layers in total:
      * 1 x [Conv(image_channels -> n_channels) + ReLU]
      * (depth - 2) x [Conv(n_channels -> n_channels) (+ BatchNorm) + ReLU],
        each hidden layer with its own weights
      * 1 x Conv(n_channels -> image_channels)

    The network predicts a residual; ``forward`` returns ``input - residual``
    (the denoised image), which has the same shape as the input.

    Args:
        depth: total number of convolution layers (>= 2).
        n_channels: number of feature maps in the hidden layers.
        image_channels: channels of the input/output images (1 for grayscale).
        use_bnorm: insert BatchNorm2d after each hidden convolution.
        kernel_size: odd convolution kernel size; padding keeps H and W unchanged.

    Raises:
        ValueError: if ``depth < 2`` or ``kernel_size`` is even.

    Note:
        An earlier version of this class reused one Conv+BN+ReLU module for
        all hidden layers (shared weights) and ignored ``kernel_size`` and
        ``use_bnorm``. Pickled checkpoints from that version do not match this
        architecture.
    """

    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        if depth < 2:
            raise ValueError('depth must be >= 2, got %d' % depth)
        if kernel_size % 2 != 1:
            raise ValueError('kernel_size must be odd, got %d' % kernel_size)
        padding = kernel_size // 2  # 'same' padding for an odd kernel
        self.depth = depth

        self.conv_relu_layers = nn.Sequential(
            nn.Conv2d(in_channels=image_channels, out_channels=n_channels,
                      kernel_size=kernel_size, padding=padding, bias=True),
            nn.ReLU(inplace=True)
        )

        # Build an independent block per hidden layer so that each layer
        # learns its own weights.
        hidden_blocks = []
        for _ in range(depth - 2):
            # A bias is redundant when BatchNorm follows (BN re-centres the output).
            layers = [nn.Conv2d(in_channels=n_channels, out_channels=n_channels,
                                kernel_size=kernel_size, padding=padding,
                                bias=not use_bnorm)]
            if use_bnorm:
                # NOTE(review): in PyTorch, momentum weights the *new* batch
                # statistic, so 0.95 makes running stats follow the latest
                # batch closely — confirm this is intended.
                layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.95))
            layers.append(nn.ReLU(inplace=True))
            hidden_blocks.append(nn.Sequential(*layers))
        # An empty Sequential (depth == 2) simply passes its input through.
        self.conv_bn_relu_layers = nn.Sequential(*hidden_blocks)

        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=n_channels, out_channels=image_channels,
                      kernel_size=kernel_size, padding=padding, bias=False)
        )

    def forward(self, x):
        out = self.conv_relu_layers(x)
        out = self.conv_bn_relu_layers(out)
        residual = self.conv_layers(out)
        # Residual learning: the network estimates the noise, which is subtracted.
        return x - residual

In [5]:
def findLastCheckpoint(save_dir):
    """Return the highest epoch among ``model_<digits>.pth`` files in ``save_dir``.

    Only file names of exactly the form ``model_`` + digits + ``.pth`` count.
    Other files matched by the glob (e.g. ``model_best.pth``) are ignored
    rather than crashing ``int()``.

    Args:
        save_dir: directory holding the checkpoints.

    Returns:
        The largest epoch number found, or 0 if there is no numbered
        checkpoint (i.e. train from scratch).
    """
    name_re = re.compile(r'model_(\d+)\.pth')
    epochs_exist = []
    for path in glob.glob(os.path.join(glob.escape(save_dir), 'model_*.pth')):
        # Match the file name only, so directory components cannot interfere.
        match = name_re.fullmatch(os.path.basename(path))
        if match:
            epochs_exist.append(int(match.group(1)))
    return max(epochs_exist, default=0)


def log(*args, **kwargs):
    """Print ``args`` prefixed with a local ``YYYY-mm-dd HH:MM:SS:`` timestamp.

    Keyword arguments (``sep``, ``end``, ``file``, ...) are passed to ``print``.
    """
    stamp = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:")
    print(stamp, *args, **kwargs)

## Train Model

In [None]:
if __name__ == '__main__':
    import warnings  # only used to silence the scheduler fast-forward warning below

    # model selection
    print('===> Building model')
    model = DnCNN()

    # Resume from the newest numbered checkpoint in save_dir, if any.
    initial_epoch = findLastCheckpoint(save_dir=save_dir)
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        ckpt_path = os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)
        # Checkpoints are whole pickled modules (see torch.save below).
        # - map_location='cpu' lets a GPU-saved file open on a CPU-only machine.
        # - weights_only=False is required on PyTorch >= 2.6, where the default
        #   became True.
        # SECURITY: unpickling can execute arbitrary code; only load
        # checkpoints you produced yourself.
        try:
            model = torch.load(ckpt_path, map_location='cpu', weights_only=False)
        except TypeError:  # PyTorch < 1.13 has no weights_only argument
            model = torch.load(ckpt_path, map_location='cpu')

    model.train()
    criterion = nn.MSELoss(reduction='sum')

    device = torch.device('cuda' if cuda else 'cpu')
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)  # learning rates
    # On resume, advance the LR schedule to where the previous run stopped.
    # NOTE: Adam's moment estimates are not checkpointed, so they restart from zero.
    with warnings.catch_warnings():
        # Stepping before any optimizer.step() emits a harmless ordering UserWarning.
        warnings.simplefilter('ignore', UserWarning)
        for _ in range(initial_epoch):
            scheduler.step()

    # Noisy inputs: uint8 (N, H, W, 1) -> float32 in [0, 1] -> (N, 1, H, W) tensor.
    train_dataset = datagenerator(args.train_data, 'png')
    train_dataset = train_dataset.astype('float32') / 255.0
    train_dataset = torch.from_numpy(train_dataset.transpose((0, 3, 1, 2)))

    # Original (target) images, prepared the same way.
    train_target = datagenerator(args.origin_data, 'jpg')
    train_target = train_target.astype('float32') / 255.0
    train_target = torch.from_numpy(train_target.transpose((0, 3, 1, 2)))

    Noisy_Dataset = DenoisingDataset(train_dataset, train_target)
    # NOTE(review): num_workers > 0 requires worker processes to unpickle the
    # Dataset class. On spawn-based platforms (Windows/macOS), a class defined
    # in a notebook cell may fail to unpickle — use num_workers=0 if workers
    # do not start.
    DLoader = DataLoader(dataset=Noisy_Dataset, num_workers=4, batch_size=batch_size, shuffle=True)

    for epoch in range(initial_epoch, n_epoch):
        epoch_loss = 0.0
        n_batches = 0  # batches actually processed this epoch
        start_time = time.time()

        for n_count, (batch_y, batch_x) in enumerate(DLoader):
            # Move the batch to the training device (CPU or GPU alike).
            batch_y = batch_y.to(device)
            batch_x = batch_x.to(device)

            optimizer.zero_grad()
            denoising = model(batch_y)
            # Compute the loss on the live output tensor. Converting it to numpy
            # first would detach it from the autograd graph, and nn.MSELoss
            # does not accept numpy arrays.
            loss = criterion(denoising, batch_x)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            n_batches += 1
            if n_count % 10 == 0:
                print('%4d %4d / %4d loss = %2.4f' % (epoch+1, n_count, train_dataset.size(0)//batch_size, batch_loss/batch_size))

        scheduler.step()  # advance the LR schedule once per completed epoch
        elapsed_time = time.time() - start_time

        # Mean summed-MSE per batch; max() guards against an empty loader.
        mean_loss = epoch_loss / max(n_batches, 1)
        log('epoch = %4d , loss = %4.4f , time = %4.2f s' % (epoch+1, mean_loss, elapsed_time))
        # Overwritten each epoch: the file holds only the latest epoch's summary.
        np.savetxt('train_result.txt', np.hstack((epoch+1, mean_loss, elapsed_time)), fmt='%2.4f')
        torch.save(model, os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))

## Import Test Modules

In [None]:
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as compare_ssim

## Define Model Test Constants

In [None]:
def parse_args():
    """Return the evaluation settings as an attribute-accessible dict.

    Keys:
        set_dir:     root directory containing the test sets.
        set_names:   sub-directories of ``set_dir`` to evaluate.
        sigma:       std-dev (0-255 scale) of the Gaussian noise added to test images.
        model_dir:   directory holding the checkpoint.
        model_name:  checkpoint file to load.
        result_dir:  root directory for outputs.
        save_result: non-zero to write denoised images and metric tables.
    """
    settings = {}
    settings["set_dir"] = 'Test_set'
    settings["set_names"] = ['Test_set(qp22)']
    settings["sigma"] = 25
    # Must match the training cell's save_dir naming (model name + '_sigma' + sigma).
    settings["model_dir"] = os.path.join('models', 'DnCNN_sigma25')
    settings["model_name"] = 'model_010.pth'
    settings["result_dir"] = 'results'
    settings["save_result"] = 0
    return easydict.EasyDict(settings)


def log(*args, **kwargs):
    """Timestamped ``print``: emit ``args`` after a ``YYYY-mm-dd HH:MM:SS:`` prefix.

    Keyword arguments (``sep``, ``end``, ``file``, ...) go straight to ``print``.
    This duplicates the helper in the training section, presumably so the
    evaluation cells can run on their own.
    """
    prefix = format(datetime.datetime.now(), "%Y-%m-%d %H:%M:%S:")
    print(prefix, *args, **kwargs)


def save_result(result, path):
    """Write ``result`` to ``path`` as a text table or as an 8-bit image.

    * If the file name has no extension, ``.png`` is appended.
    * ``.txt`` / ``.dlm``: written with ``np.savetxt`` using 4 decimals.
    * Anything else: treated as a float image in [0, 1]. It is clipped,
      scaled to 0-255 and written with OpenCV.

    Raises:
        IOError: if OpenCV fails to write the image (e.g. unsupported
            extension or missing directory).
    """
    # Check the file name's own extension, not just any dot in the path
    # (a directory such as './results' contains a dot).
    if not os.path.splitext(path)[1]:
        path = path + '.png'
    ext = os.path.splitext(path)[-1]
    if ext in ('.txt', '.dlm'):
        np.savetxt(path, result, fmt='%2.4f')
    else:
        # `imsave` is not imported in this notebook; OpenCV (already used for
        # reading) is. cv2 expects 0-255 integers, so convert explicitly.
        img = (np.clip(result, 0, 1) * 255.0).round().astype(np.uint8)
        if not cv2.imwrite(path, img):
            raise IOError('could not write image %r' % path)

## Model Test

In [None]:
if __name__ == '__main__':

    args = parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    use_cuda = device.type == 'cuda'

    model_path = os.path.join(args.model_dir, args.model_name)
    if not os.path.exists(model_path):
        # Fall back to a generic checkpoint name when the requested one is missing.
        model_path = os.path.join(args.model_dir, 'model.pth')
        load_msg = 'load trained model on Train400 dataset by kai'
    else:
        load_msg = 'load trained model'
    # Checkpoints are whole pickled modules.
    # - map_location='cpu' lets GPU-saved files open on any machine.
    # - weights_only=False is required on PyTorch >= 2.6.
    # SECURITY: unpickling can execute arbitrary code; only load trusted files.
    try:
        model = torch.load(model_path, map_location='cpu', weights_only=False)
    except TypeError:  # PyTorch < 1.13 has no weights_only argument
        model = torch.load(model_path, map_location='cpu')
    log(load_msg)

    model.eval()
    model = model.to(device)

    os.makedirs(args.result_dir, exist_ok=True)

    for set_cur in args.set_names:

        os.makedirs(os.path.join(args.result_dir, set_cur), exist_ok=True)
        psnrs = []
        ssims = []

        for im in os.listdir(os.path.join(args.set_dir, set_cur)):
            if not im.endswith(('.jpg', '.bmp', '.png')):
                continue

            x = np.array(cv2.imread(os.path.join(args.set_dir, set_cur, im), 0), dtype=np.float32)/255.0
            h, _ = x.shape
            # NOTE(review): cv2.resize takes dsize as (width, height), so this
            # squashes every image to h x h and changes its aspect ratio. DnCNN
            # is fully convolutional and does not need square inputs — confirm
            # this resize is intended.
            x = cv2.resize(x, (h, h), interpolation=cv2.INTER_CUBIC)
            np.random.seed(seed=0)  # for reproducibility
            # NOTE(review): the model is trained on compression noise (qp22),
            # but here it is evaluated on synthetic Gaussian noise added to the
            # test image — confirm this is the intended benchmark.
            y = x + np.random.normal(0, args.sigma/255.0, x.shape)  # Add Gaussian noise without clipping
            y = y.astype(np.float32)
            y_ = torch.from_numpy(y).view(1, -1, y.shape[0], y.shape[1])

            # Synchronize only on GPU, so the timing covers the full kernel run.
            if use_cuda:
                torch.cuda.synchronize()
            start_time = time.time()
            with torch.no_grad():  # inference only: don't build an autograd graph
                x_ = model(y_.to(device))
            x_ = x_.view(y.shape[0], y.shape[1]).cpu().numpy().astype(np.float32)
            if use_cuda:
                torch.cuda.synchronize()
            elapsed_time = time.time() - start_time
            print('%10s : %10s : %2.4f second' % (set_cur, im, elapsed_time))

            # Pixel values are scaled to [0, 1], so data_range is 1
            # (data_range=2 would inflate PSNR by ~6 dB). SSIM needs the range
            # stated explicitly for float images.
            psnr_x_ = compare_psnr(x, x_, data_range=1)
            ssim_x_ = compare_ssim(x, x_, data_range=1)
            if args.save_result:
                name, ext = os.path.splitext(im)
                # (No on-screen preview: `show` is not defined in this notebook.)
                save_result(x_, path=os.path.join(args.result_dir, set_cur, name+'_dncnn'+ext))  # save the denoised image
            psnrs.append(psnr_x_)
            ssims.append(ssim_x_)

        if not psnrs:
            # Avoid np.mean([]) -> nan when a set contains no supported images.
            log('no images found in %s' % set_cur)
            continue

        psnr_avg = np.mean(psnrs)
        ssim_avg = np.mean(ssims)
        # The per-set average is appended as the last entry of each table.
        psnrs.append(psnr_avg)
        ssims.append(ssim_avg)
        if args.save_result:
            save_result(np.hstack((psnrs, ssims)), path=os.path.join(args.result_dir, set_cur, 'results.txt'))
        log('Datset: {0:10s} \n  PSNR = {1:2.2f}dB, SSIM = {2:1.4f}'.format(set_cur, psnr_avg, ssim_avg))