In [None]:
"""
colab version of DNCNN https://github.com/SaoYan/DnCNN-PyTorch using GPU
"""

'\ncolab version of DNCNN https://github.com/SaoYan/DnCNN-PyTorch using GPU\n'

In [None]:
from google.colab import drive
# Mount Google Drive so the dataset folder and save_dir paths under /content/drive are reachable.
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Install tensorboardX, which provides the SummaryWriter imported in the training cell.
# TODO: pin a version for reproducibility.
!pip install tensorboardX



In [None]:
import math
import torch
import torch.nn as nn
import numpy as np
import os
import os.path
import random
import h5py
import cv2
import glob
import torch.utils.data as udata
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.autograd import Variable
from torchvision import utils

# Plain bool: can this kernel use a CUDA device?
cuda = torch.cuda.is_available()
# Release unoccupied memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

In [None]:
def get_default_device():
    """Return the CUDA device when one is available, otherwise the CPU.

    The chosen device is printed; for CUDA the name of GPU 0 is printed too.
    """
    if not torch.cuda.is_available():
        chosen = torch.device('cpu')
        print(chosen)
        return chosen
    chosen = torch.device('cuda')
    print(chosen, torch.cuda.get_device_name(0))
    return chosen

device = get_default_device()

cuda Tesla P100-PCIE-16GB


In [None]:
#From utils.py
def data_augmentation(image, mode):
    """Apply one of eight rotation/flip variants to a channel-first image.

    Args:
        image: array indexed as (channels, rows, cols).
        mode: 0-7. ``mode // 2`` is the number of 90-degree counter-clockwise
            rotations (``np.rot90``) over the spatial axes, and an odd ``mode``
            adds an up-down flip after the rotation. Any other value leaves
            the image unchanged.

    Returns:
        The transformed array, channel-first again (may be a view of ``image``).
    """
    # mode -> (quarter turns, flip up-down afterwards)
    transforms = {
        0: (0, False), 1: (0, True),
        2: (1, False), 3: (1, True),
        4: (2, False), 5: (2, True),
        6: (3, False), 7: (3, True),
    }
    quarter_turns, flip = transforms.get(mode, (0, False))
    # Work in (rows, cols, channels) so rot90/flipud act on the spatial axes.
    hwc = np.transpose(image, (1, 2, 0))
    if quarter_turns:
        hwc = np.rot90(hwc, k=quarter_turns)
    if flip:
        hwc = np.flipud(hwc)
    return np.transpose(hwc, (2, 0, 1))


#From dataset.py
def normalize(data):
    """Map 8-bit intensities from [0, 255] onto [0, 1] by dividing by 255."""
    full_scale = 255.0
    return data / full_scale

def Im2Patch(img, win, stride=1):
    """Extract every ``win`` x ``win`` patch from a channel-first image.

    Args:
        img: array indexed as (channels, rows, cols).
        win: side length of each square patch.
        stride: step between neighbouring patch origins along both spatial axes.

    Returns:
        float32 array of shape (channels, win, win, num_patches). Patch ``n`` is
        ``result[:, :, :, n]``, and patches are ordered row-major by their
        top-left corner.
    """
    n_chan = img.shape[0]
    span_r = img.shape[1] - win + 1   # candidate origins along axis 1, before striding
    span_c = img.shape[2] - win + 1   # candidate origins along axis 2, before striding
    # The origin grid at offset (0, 0) fixes how many patches there are.
    origins = img[:, 0:span_r:stride, 0:span_c:stride]
    n_patches = origins.shape[1] * origins.shape[2]
    out = np.zeros([n_chan, win * win, n_patches], np.float32)
    # Fill pixel (r, c) of every patch at once by shifting the whole origin grid.
    for r in range(win):
        for c in range(win):
            shifted = img[:, r:r + span_r:stride, c:c + span_c:stride]
            out[:, r * win + c, :] = np.array(shifted).reshape(n_chan, n_patches)
    return out.reshape([n_chan, win, win, n_patches])

def prepare_data(data_path, patch_size, stride, aug_times=1):
    """Write the training and validation HDF5 files read by ``Dataset``.

    Training: each ``data_path/train/*.png`` is resized by every factor in
    ``scales``, reduced to its first colour channel, scaled to [0, 1] and cut
    into ``patch_size`` x ``patch_size`` patches with the given ``stride``.
    Each patch is stored in ``train.h5`` under its running index. When
    ``aug_times > 1``, ``aug_times - 1`` randomly rotated/flipped copies are
    also stored, under ``"<index>_aug_<m>"``.

    Validation: each ``data_path/Set12/*.png`` is stored whole (first channel,
    scaled to [0, 1]) in ``val.h5``.

    Both files are written to the current working directory, overwriting any
    existing file.

    Args:
        data_path: directory containing the ``train`` and ``Set12`` folders.
        patch_size: side length of the square training patches.
        stride: step between neighbouring patch origins.
        aug_times: total number of stored variants per patch (1 = no augmentation).

    Raises:
        IOError: if OpenCV cannot decode one of the image files.
    """
    # train
    print('process training data')
    scales = [1, 0.9, 0.8, 0.7]
    files = sorted(glob.glob(os.path.join(data_path, 'train', '*.png')))
    train_num = 0
    with h5py.File('train.h5', 'w') as h5f:
        for path in files:
            img = cv2.imread(path)
            if img is None:
                raise IOError("could not read image: %s" % path)
            h, w = img.shape[:2]
            for scale in scales:
                # cv2.resize expects dsize as (width, height), not (height, width).
                Img = cv2.resize(img, (int(w*scale), int(h*scale)), interpolation=cv2.INTER_CUBIC)
                # Keep only channel 0 (OpenCV loads BGR) as a (1, H, W) array.
                Img = np.expand_dims(Img[:, :, 0].copy(), 0)
                Img = np.float32(normalize(Img))
                patches = Im2Patch(Img, win=patch_size, stride=stride)
                print("file: %s scale %.1f # samples: %d" % (path, scale, patches.shape[3]*aug_times))
                for n in range(patches.shape[3]):
                    data = patches[:, :, :, n].copy()
                    h5f.create_dataset(str(train_num), data=data)
                    train_num += 1
                    for m in range(aug_times-1):
                        # Modes 1..7 only: mode 0 would duplicate the original patch.
                        data_aug = data_augmentation(data, np.random.randint(1, 8))
                        h5f.create_dataset(str(train_num)+"_aug_%d" % (m+1), data=data_aug)
                        train_num += 1
    # val
    print('\nprocess validation data')
    files = sorted(glob.glob(os.path.join(data_path, 'Set12', '*.png')))
    val_num = 0
    with h5py.File('val.h5', 'w') as h5f:
        for path in files:
            print("file: %s" % path)
            img = cv2.imread(path)
            if img is None:
                raise IOError("could not read image: %s" % path)
            img = np.expand_dims(img[:, :, 0], 0)
            img = np.float32(normalize(img))
            h5f.create_dataset(str(val_num), data=img)
            val_num += 1
    print('training set, # samples %d\n' % train_num)
    print('val set, # samples %d\n' % val_num)

# New Section

In [None]:
# Prepare the HDF5 datasets (DnCNN with known noise level). This is slow, so it is
# skipped when both files already exist in the working directory. Set
# REBUILD_DATA = True to force a rebuild, e.g. after changing the images or the
# patch settings, or if a previous run was interrupted mid-write.
REBUILD_DATA = False
if REBUILD_DATA or not (os.path.exists('train.h5') and os.path.exists('val.h5')):
    prepare_data(data_path='/content/drive/My Drive/data', patch_size=40, stride=10, aug_times=1)

process training data
file: /content/drive/My Drive/data/train/test_001.png scale 1.0 # samples: 225
file: /content/drive/My Drive/data/train/test_001.png scale 0.9 # samples: 169
file: /content/drive/My Drive/data/train/test_001.png scale 0.8 # samples: 121
file: /content/drive/My Drive/data/train/test_001.png scale 0.7 # samples: 81
file: /content/drive/My Drive/data/train/test_002.png scale 1.0 # samples: 225
file: /content/drive/My Drive/data/train/test_002.png scale 0.9 # samples: 169
file: /content/drive/My Drive/data/train/test_002.png scale 0.8 # samples: 121
file: /content/drive/My Drive/data/train/test_002.png scale 0.7 # samples: 81
file: /content/drive/My Drive/data/train/test_003.png scale 1.0 # samples: 225
file: /content/drive/My Drive/data/train/test_003.png scale 0.9 # samples: 169
file: /content/drive/My Drive/data/train/test_003.png scale 0.8 # samples: 121
file: /content/drive/My Drive/data/train/test_003.png scale 0.7 # samples: 81
file: /content/drive/My Drive/dat

In [None]:
# Root output directory on the mounted Drive; the training cell saves the model
# state_dict under save_dir + 'savedncnn/net.pth'. The trailing slash matters
# wherever paths are built by plain string concatenation.
save_dir = "/content/drive/My Drive/"

In [None]:
class Dataset(udata.Dataset):
    """Map-style dataset over the HDF5 files written by ``prepare_data``.

    Each item is one stored array (a training patch from ``train.h5`` or a
    whole validation image from ``val.h5``), returned as a float tensor.
    The key order is shuffled once at construction.

    The HDF5 file is reopened on every ``__getitem__`` instead of being kept
    open. Presumably this lets each DataLoader worker use its own handle
    (TODO confirm).
    """
    def __init__(self, train=True):
        super(Dataset, self).__init__()
        self.train = train
        with h5py.File(self._h5_path(), 'r') as h5f:
            self.keys = list(h5f.keys())
        random.shuffle(self.keys)

    def _h5_path(self):
        # Relative to the working directory, which is where prepare_data writes them.
        return 'train.h5' if self.train else 'val.h5'

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        key = self.keys[index]
        # The context manager closes the file even if the read raises.
        with h5py.File(self._h5_path(), 'r') as h5f:
            data = np.array(h5f[key])
        return torch.Tensor(data)

In [None]:
class DnCNN(nn.Module):
    """DnCNN residual denoiser: the network predicts the noise, not the clean image.

    Structure (all 3x3 convolutions, padding 1, no bias, 64 feature maps):
    Conv+ReLU, then ``num_of_layers - 2`` x (Conv+BatchNorm+ReLU), then a final
    Conv back to ``channels``. The layers are stored in ``self.dncnn``, so
    state_dict keys have the form ``dncnn.<index>.*``.
    """
    def __init__(self, channels, num_of_layers=17):
        super().__init__()
        width = 64

        def conv(c_in, c_out):
            return nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1, bias=False)

        body = [conv(channels, width), nn.ReLU(inplace=True)]
        for _ in range(num_of_layers - 2):
            body += [conv(width, width), nn.BatchNorm2d(width), nn.ReLU(inplace=True)]
        body.append(conv(width, channels))
        self.dncnn = nn.Sequential(*body)

    def forward(self, x):
        # Output has the same shape as x; callers subtract it from the noisy input.
        return self.dncnn(x)


In [None]:
#From utils.py
from skimage.metrics import peak_signal_noise_ratio
def weights_init_kaiming(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('Linear') != -1:
        nn.init.kaiming_normal(m.weight.data, a=0, mode='fan_in')
    elif classname.find('BatchNorm') != -1:
        # nn.init.uniform(m.weight.data, 1.0, 0.02)
        m.weight.data.normal_(mean=0, std=math.sqrt(2./9./64.)).clamp_(-0.025,0.025)
        nn.init.constant(m.bias.data, 0.0)

def batch_PSNR(img, imclean, data_range):
    """Mean PSNR between two batches of images, compared item by item.

    Args:
        img: batch of reconstructed images (tensor, indexed on dim 0 by item).
        imclean: batch of reference images with the same layout as ``img``.
        data_range: dynamic range of the pixel values (1. for [0, 1] data).

    Returns:
        Average of skimage's ``peak_signal_noise_ratio`` over the batch.
    """
    denoised = img.data.cpu().numpy().astype(np.float32)
    clean = imclean.data.cpu().numpy().astype(np.float32)
    scores = [
        peak_signal_noise_ratio(clean[n, :, :, :], denoised[n, :, :, :], data_range=data_range)
        for n in range(denoised.shape[0])
    ]
    return sum(scores) / denoised.shape[0]

In [None]:
#From Main.py main()
# Load dataset
print('Loading dataset ...\n')
dataset_train = Dataset(train=True)
dataset_val = Dataset(train=False)
loader_train = DataLoader(dataset=dataset_train, num_workers=4, batch_size=128, shuffle=True)
print("# of training samples: %d\n" % int(len(dataset_train)))

# Build model
net = DnCNN(channels=1, num_of_layers=17)
net.apply(weights_init_kaiming)
# Sum of squared errors over the batch (the training loop divides by 2*batch_size).
# reduction='sum' replaces the deprecated size_average=False.
criterion = nn.MSELoss(reduction='sum')

# Move to GPU. Without a GPU, nn.DataParallel simply calls the wrapped module, so
# wrapping on both paths keeps `model` defined and the saved state_dict keys
# ("module.*") identical.
if cuda:
    device_ids = [0]
    model = nn.DataParallel(net, device_ids=device_ids).cuda()
    criterion.cuda()
else:
    model = nn.DataParallel(net)

# Optimizer
learning_rate = 1e-3
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
from tensorboardX import SummaryWriter
import argparse

# Training. DnCNN predicts the residual (the added noise), so the denoised image
# is imgn - model(imgn). Per-step and per-epoch metrics go to TensorBoard in ./logs.
writer = SummaryWriter("logs")
step = 0
train_epoch = 1
milestone = 30   # epoch at which the LR is divided by 10; should be less than train_epoch to take effect
noise_level = 25  # training noise std on the 0-255 scale
val_noiseL = 25   # validation noise std on the 0-255 scale

for epoch in range(train_epoch):
    current_lr = learning_rate if epoch < milestone else learning_rate / 10.

    # set learning rate
    for param_group in optimizer.param_groups:
        param_group["lr"] = current_lr
    print('learning rate %f' % current_lr)

    # train
    for i, data in enumerate(loader_train, 0):
        # training step
        model.train()
        optimizer.zero_grad()
        img_train = data
        noise = torch.FloatTensor(img_train.size()).normal_(mean=0, std=noise_level/255.)
        imgn_train = img_train + noise
        img_train, imgn_train, noise = img_train.to(device), imgn_train.to(device), noise.to(device)

        out_train = model(imgn_train)
        loss = criterion(out_train, noise) / (imgn_train.size()[0]*2)
        loss.backward()
        optimizer.step()

        # results: a second, gradient-free forward pass in eval mode for PSNR
        model.eval()
        with torch.no_grad():
            out_train = torch.clamp(imgn_train - model(imgn_train), 0., 1.)
        psnr_train = batch_PSNR(out_train, img_train, 1.)
        print("[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f" %
              (epoch+1, i+1, len(loader_train), loss.item(), psnr_train))
        if step % 10 == 0:
            # Log the scalar values
            writer.add_scalar('loss', loss.item(), step)
            writer.add_scalar('PSNR on training data', psnr_train, step)
        step += 1

    ## the end of each epoch
    model.eval()

    # validate: whole images, one at a time, without gradients
    psnr_val = 0
    with torch.no_grad():
        for k in range(len(dataset_val)):
            img_val = torch.unsqueeze(dataset_val[k], 0)
            noise = torch.FloatTensor(img_val.size()).normal_(mean=0, std=val_noiseL/255.)
            imgn_val = img_val + noise
            img_val, imgn_val = img_val.to(device), imgn_val.to(device)
            out_val = torch.clamp(imgn_val - model(imgn_val), 0., 1.)
            psnr_val += batch_PSNR(out_val, img_val, 1.)
    psnr_val /= len(dataset_val)
    print("\n[epoch %d] PSNR_val: %.4f" % (epoch+1, psnr_val))
    writer.add_scalar('PSNR on validation data', psnr_val, epoch)

    # log the images of the last training batch
    with torch.no_grad():
        out_train = torch.clamp(imgn_train - model(imgn_train), 0., 1.)
    Img = utils.make_grid(img_train, nrow=8, normalize=True, scale_each=True)
    Imgn = utils.make_grid(imgn_train, nrow=8, normalize=True, scale_each=True)
    Irecon = utils.make_grid(out_train, nrow=8, normalize=True, scale_each=True)
    writer.add_image('clean image', Img, epoch)
    writer.add_image('noisy image', Imgn, epoch)
    writer.add_image('reconstructed image', Irecon, epoch)

    # save model (create the checkpoint directory on first use)
    save_path = os.path.join(save_dir, 'savedncnn', 'net.pth')
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)

writer.close()