A sample notebook implementing the proposed compressed version of the DnCNN
framework for image denoising.

In [None]:
import torch
import torch.nn as nn
import numpy as np
import random
import h5py

from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.autograd import Variable
from skimage.measure.simple_metrics import compare_psnr
import torch.nn.functional as F

torch.backends.cudnn.benchmark = True

from network import DnCNN, DnCNN_cheap

In [None]:
def batch_PSNR(img, imclean, data_range):
    """Return the mean peak signal-to-noise ratio (dB) over a batch.

    PSNR is computed per image along the first (batch) axis as
    ``10 * log10(data_range**2 / MSE)`` and then averaged. This reproduces
    skimage's (now removed) ``compare_psnr`` without depending on it.

    Args:
        img: tensor of reconstructed/denoised images; axis 0 is the batch.
        imclean: reference tensor with the same shape as ``img``.
        data_range: dynamic range of the pixel values (e.g. ``1.`` for data
            in [0, 1]).

    Returns:
        The batch-mean PSNR as a NumPy float; ``inf`` for images identical
        to their reference.

    Raises:
        ValueError: if the shapes differ or the batch is empty.
    """
    # Match the original precision path: cast to float32, then compute in
    # float64 as compare_psnr did internally.
    Img = img.detach().cpu().numpy().astype(np.float32).astype(np.float64)
    Iclean = imclean.detach().cpu().numpy().astype(np.float32).astype(np.float64)
    if Img.shape != Iclean.shape:
        raise ValueError("batch_PSNR: input images must have the same shape")
    if Img.shape[0] == 0:
        raise ValueError("batch_PSNR: batch must contain at least one image")
    # Per-image MSE: flatten everything except the batch axis.
    mse = ((Iclean - Img) ** 2).reshape(Img.shape[0], -1).mean(axis=1)
    # A perfect reconstruction has MSE == 0 -> PSNR of +inf (no warning).
    with np.errstate(divide='ignore'):
        psnr = 10 * np.log10((data_range ** 2) / mse)
    return psnr.mean()

In [None]:
class Dataset(Dataset):
    """Dataset whose samples are the entries of an HDF5 file.

    Reads ``train.h5`` when ``train`` is true and ``val.h5`` otherwise. Every
    top-level key in the file is one sample; the key order is shuffled once
    at construction. ``__getitem__`` returns the entry as a ``torch.Tensor``
    (default float dtype).

    Note: this class intentionally shadows the imported
    ``torch.utils.data.Dataset`` name it subclasses.
    """

    def __init__(self, train=True):
        # Zero-argument super() avoids referring to the shadowed class name.
        super().__init__()
        self.train = train
        # Open only long enough to list the keys; the context manager closes
        # the handle even if reading the key list fails.
        with h5py.File(self._h5_path(), 'r') as h5f:
            self.keys = list(h5f.keys())
        random.shuffle(self.keys)

    def _h5_path(self):
        """Return the HDF5 file backing this split (depends on ``self.train``)."""
        return 'train.h5' if self.train else 'val.h5'

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        # The file is re-opened per item rather than held open — presumably so
        # no h5py handle is shared across DataLoader worker processes.
        with h5py.File(self._h5_path(), 'r') as h5f:
            data = np.array(h5f[self.keys[index]])
        return torch.Tensor(data)


In [None]:
# Train the (compressed) DnCNN model to predict the additive Gaussian noise
# (residual learning), validate after each epoch, and keep the checkpoint with
# the best validation PSNR in 'model.pth'.
noiseL = 25  # noise standard deviation on the 0-255 scale (divided by 255 below)
learning_rate = 0.001
batchSize = 64
print('Loading dataset ...\n')
dataset_train = Dataset(train=True)
dataset_val = Dataset(train=False)
loader_train = DataLoader(dataset=dataset_train, num_workers=7,
                          batch_size=batchSize, shuffle=True)
model = DnCNN_cheap()  # use DnCNN() for the uncompressed baseline
model = model.cuda()
# Summed squared error; normalised manually below. `reduction='sum'` replaces
# the deprecated `size_average=False`.
criterion = nn.MSELoss(reduction='sum')
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
best_val = 0

for epoch in range(50):
    test_loss = 0.
    epoch_loss = 0.

    # ---- training ----
    model.train()
    for i, data in enumerate(loader_train, 0):
        # The optimizer owns all model parameters, so this clears every grad.
        optimizer.zero_grad()
        img_train = data
        noise = torch.FloatTensor(img_train.size()).normal_(mean=0, std=noiseL/255.)
        imgn_train = img_train + noise
        img_train, imgn_train = img_train.cuda(), imgn_train.cuda()
        noise = noise.cuda()
        # The network predicts the noise; the denoised image is input - output.
        out_train = model(imgn_train)
        loss = criterion(out_train, noise) / (img_train.size()[0]*2)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # results (PSNR only computed when it is actually reported)
        if i % 30 == 0:
            with torch.no_grad():
                psnr_train = batch_PSNR(imgn_train - out_train, img_train, 1.)
            print("[epoch %d][%d/%d] loss: %.4f PSNR_train: %.4f" %
                  (epoch+1, i+1, len(loader_train), loss.item(), psnr_train))
    epoch_loss = epoch_loss / len(loader_train)

    # ---- validation (one image at a time) ----
    model.eval()
    psnr_val = 0
    with torch.no_grad():
        for k in range(len(dataset_val)):
            img_val = torch.unsqueeze(dataset_val[k], 0)
            noise = torch.FloatTensor(img_val.size()).normal_(mean=0, std=noiseL/255.)
            imgn_val = (img_val + noise).cuda()
            img_val = img_val.cuda()
            out_val = torch.clamp(imgn_val - model(imgn_val), 0., 1.)
            psnr_val += batch_PSNR(out_val, img_val, 1.)
            # Normalise by the *validation* batch size; the original divided
            # by the size of the last training batch.
            test_loss += criterion(out_val, img_val).item() / (img_val.size()[0]*2)
    psnr_val /= len(dataset_val)
    test_loss /= len(dataset_val)
    print("[epoch %d] Train Loss: %.4f Val Loss: %.4f PSNR_val: %.4f" %
          (epoch+1, epoch_loss, test_loss, psnr_val))

    # NOTE(review): `epoch % 10 == 0` also halves the LR right after the first
    # epoch (epoch index 0). Kept as-is to preserve the training schedule;
    # confirm whether `(epoch + 1) % 10 == 0` was intended.
    if epoch % 10 == 0:
        learning_rate = learning_rate * 0.5
        for param_group in optimizer.param_groups:
            param_group['lr'] = learning_rate

    # Keep only the best-so-far checkpoint (ties overwrite).
    if psnr_val >= best_val:
        torch.save(model.state_dict(), 'model.pth')
        best_val = psnr_val