In [None]:
# -*- coding: utf-8 -*-

# PyTorch 0.4.1, https://pytorch.org/docs/stable/index.html

# =============================================================================
#  @article{zhang2017beyond,
#    title={Beyond a {Gaussian} denoiser: Residual learning of deep {CNN} for image denoising},
#    author={Zhang, Kai and Zuo, Wangmeng and Chen, Yunjin and Meng, Deyu and Zhang, Lei},
#    journal={IEEE Transactions on Image Processing},
#    year={2017},
#    volume={26}, 
#    number={7}, 
#    pages={3142-3155}, 
#  }
# by Kai Zhang (08/2018)
# cskaizhang@gmail.com
# https://github.com/cszn
# modified on the code from https://github.com/SaoYan/DnCNN-PyTorch
# =============================================================================

# run this to train the model

# =============================================================================
# For batch normalization layer, momentum should be a value from [0.1, 1] rather than the default 0.1. 
# The Gaussian noise output helps to stablize the batch normalization, thus a large momentum (e.g., 0.95) is preferred.
# =============================================================================

In [1]:
!pip install import_ipynb



In [2]:
!pip install easydict



In [1]:
import argparse
import re
import os, glob, datetime, time
import numpy as np
import torch
import torch.nn as nn
from torch.nn.modules.loss import _Loss
import torch.nn.init as init
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.optim.lr_scheduler import MultiStepLR

import import_ipynb
import data_generator as dg
from data_generator import DenoisingDataset

import easydict

importing Jupyter notebook from data_generator.ipynb


In [13]:
# Params

#parser = argparse.ArgumentParser(description='PyTorch DnCNN')
#parser.add_argument('--model', default='DnCNN', type=str, help='choose a type of model')
#parser.add_argument('--batch_size', default=128, type=int, help='batch size')
#parser.add_argument('--train_data', default='Colab Notebooks/Set12', type=str, help='path of train data')
#parser.add_argument('--sigma', default=25, type=int, help='noise level')
#parser.add_argument('--epoch', default=180, type=int, help='number of train epoches')
#parser.add_argument('--lr', default=1e-3, type=float, help='initial learning rate for Adam')
#args = parser.parse_args()

args = easydict.EasyDict({
    "model" : "DnCNN",
    "batch_size" : 128,
    "train_data" : 'Noise_image_Training(qp22)',
    "sigma" : 25,
    "epoch" : 10,
    "lr" : 1e-3
})

batch_size = args.batch_size
cuda = torch.cuda.is_available()
n_epoch = args.epoch
sigma = args.sigma

save_dir = os.path.join('models', args.model+'_' + 'sigma' + str(sigma))

if not os.path.exists(save_dir):
    os.mkdir(save_dir)

class DnCNN(nn.Module):
    def __init__(self, depth=17, n_channels=64, image_channels=1, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        kernel_size = 3
        padding = 1
        layers = []

        layers.append(nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True))
        layers.append(nn.ReLU(inplace=True))
        for _ in range(depth-2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=False))
            layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum = 0.95))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=False))
        self.dncnn = nn.Sequential(*layers)
        self._initialize_weights()

    def forward(self, x):
        y = x
        out = self.dncnn(x)
        return y-out

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                print('init weight')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)


class sum_squared_error(_Loss):  # PyTorch 0.4.1
    """
    Definition: sum_squared_error = 1/2 * nn.MSELoss(reduction = 'sum')
    The backward is defined as: input-target
    """
    def __init__(self, size_average=None, reduce=None, reduction='mean'):
        super(sum_squared_error, self).__init__(size_average, reduce, reduction)
    
    def forward(self, input, target):
        # return torch.sum(torch.pow(input-target,2), (0,1,2,3)).div_(2)
        return torch.nn.functional.l1_loss(input, target, size_average=None, reduce=None, reduction='mean').div_(2)

    #def forward(self, input, target):
    #    # return torch.sum(torch.pow(input-target,2), (0,1,2,3)).div_(2)
    #    return torch.nn.functional.mse_loss(input, target, size_average=None, reduce=None, reduction='sum').div_(2)


def findLastCheckpoint(save_dir):
    file_list = glob.glob(os.path.join(save_dir, 'model_*.pth'))
    if file_list:
        epochs_exist = []
        for file_ in file_list:
            result = re.findall(".*model_(.*).pth.*", file_)
            epochs_exist.append(int(result[0]))
        initial_epoch = max(epochs_exist)
    else:
        initial_epoch = 0
    return initial_epoch


def log(*args, **kwargs):
     print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs)

if __name__ == '__main__':
    # model selection
    print('===> Building model')
    model = DnCNN()
    
    initial_epoch = findLastCheckpoint(save_dir=save_dir)  # load the last model in matconvnet style
    
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        # model.load_state_dict(torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)))
        model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))
    
    model.train()
    # criterion = nn.MSELoss(reduction = 'sum')  # PyTorch 0.4.1
    criterion = sum_squared_error()
    
    if cuda:
        model = model.cuda()
         # device_ids = [0]
         # model = nn.DataParallel(model, device_ids=device_ids).cuda()
         # criterion = criterion.cuda()
    
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2)  # learning rates
    
    for epoch in range(initial_epoch, n_epoch):

        scheduler.step(epoch)  # step to the learning rate in this epcoh
        
        xs = dg.datagenerator(data_dir='Noise_image_Training(qp22)')
        xs = xs.astype('float32')/255.0
        xs = torch.from_numpy(xs.transpose((0, 3, 1, 2)))
        
        DDataset = DenoisingDataset(xs, sigma)
        DLoader = DataLoader(dataset=DDataset, num_workers=4, drop_last=True, batch_size=batch_size, shuffle=True)
        epoch_loss = 0
        start_time = time.time()

        for n_count, batch_yx in enumerate(DLoader):
                optimizer.zero_grad()
                if cuda:
                    batch_x, batch_y = batch_yx[1].cuda(), batch_yx[0].cuda()
                loss = criterion(model(batch_y), batch_x)
                epoch_loss += loss.item()
                loss.backward()
                optimizer.step()
                if n_count % 10 == 0:
                    print('%4d %4d / %4d loss = %2.4f' % (epoch+1, n_count, xs.size(0)//batch_size, loss.item()/batch_size))
        elapsed_time = time.time() - start_time

        log('epoch = %4d , loss = %4.4f , time = %4.2f s' % (epoch+1, epoch_loss/n_count, elapsed_time))
        np.savetxt('train_result.txt', np.hstack((epoch+1, epoch_loss/n_count, elapsed_time)), fmt='%2.4f')
        # torch.save(model.state_dict(), os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))
        torch.save(model, os.path.join(save_dir, 'model_%03d.pth' % (epoch+1)))

===> Building model
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
init weight
^_^-training data finished-^_^
   1    0 / 4884 loss = 0.0018
   1   10 / 4884 loss = 0.0003
   1   20 / 4884 loss = 0.0003
   1   30 / 4884 loss = 0.0003
   1   40 / 4884 loss = 0.0003
   1   50 / 4884 loss = 0.0003
   1   60 / 4884 loss = 0.0003
   1   70 / 4884 loss = 0.0003
   1   80 / 4884 loss = 0.0002
   1   90 / 4884 loss = 0.0002
   1  100 / 4884 loss = 0.0002
   1  110 / 4884 loss = 0.0002
   1  120 / 4884 loss = 0.0002
   1  130 / 4884 loss = 0.0002
   1  140 / 4884 loss = 0.0002
   1  150 / 4884 loss = 0.0002
   1  160 / 4884 loss = 0.0002
   1  170 / 4884 loss = 0.0002
   1  180 / 4884 loss = 0.0002
   1  190 / 4884 loss = 0.0002
   1  200 / 4884 loss = 0.0002
   1  210 / 4884 loss = 0.0002
   1  220 / 4884 loss = 0.0002
   1  230 / 4884 loss = 0.0002
 

In [11]:
test = np.array([[1, 2], [3, 4]])
test.shape
test

array([[1, 2],
       [3, 4]])

In [5]:
import cv2
img = cv2.imread('Noise_image_Training(qp22)_original/000000166642.jpg', 0)

In [8]:
xs.shape

torch.Size([625152, 1, 40, 40])