In [None]:
## Importing necessary libraries

import torch as tor
import torch.nn as nn
import torch.utils.data
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt

from skimage.measure import compare_ssim as ssim

In [None]:
def calc_total_mean(datafiles,num_chn = 3,verbose = False):

    """
        Find the total (per-pixel) mean image over `datafiles`, for mean centering.

        Every image is resized to 256 x 256 before accumulation so that all images
        can be summed regardless of their original size.

        Arguments:
            datafiles: sequence of image file paths.
            num_chn: 3 to read with cv2's default (colour) flag, 1 to read as grayscale.
            verbose: if True, print the index and path of every file as it is read.

        Returns:
            float32 array with the mean image: (256, 256, 3) for num_chn == 3,
            (256, 256) for num_chn == 1.

        Raises:
            ValueError: if num_chn is not 1 or 3, or datafiles is empty.
            FileNotFoundError: if an image cannot be read.
    """

    # A bare `assert "message"` is always true, so validate explicitly instead.
    if(num_chn not in (1,3)):
        raise ValueError("Incorrect number of channels")

    num_files = len(datafiles)
    if(num_files == 0):
        raise ValueError("datafiles is empty; cannot compute a mean image")

    img_sum = None

    for e,file in enumerate(datafiles):

        img = cv2.imread(file) if num_chn == 3 else cv2.imread(file,0)
        if(img is None):
            raise FileNotFoundError("Could not read image: " + str(file))
        img = cv2.resize(img,(256,256))

        # Accumulate in float64: adding uint8 images to each other wraps around at 256.
        if(img_sum is None):
            img_sum = img.astype(np.float64)
        else:
            img_sum += img

        if(verbose):
            print(e,file)

    return np.float32(img_sum / num_files)

In [None]:
class dncnn(nn.Module):

    """
        DnCNN-style residual denoiser.

        self.net is a plain nn.Sequential of `depth` 3x3 convolutions (padding 1):
            conv(in_channels -> 64), ReLU
            (depth - 2) x [conv(64 -> 64), ReLU, BatchNorm2d(64)]
            conv(64 -> in_channels)
        forward() treats the stack's output as a noise estimate and subtracts it
        from the input.

        NOTE(review): the hidden groups apply ReLU before BatchNorm, unlike the
        Conv -> BN -> ReLU order of the DnCNN paper. Reordering would shift the
        Sequential indices and break loading existing state_dicts, so it is kept.
    """

    def __init__(self,in_channels = 3,depth = 17):

        super().__init__()

        width = 64

        def conv3x3(c_in,c_out):
            # 3x3 kernel with padding 1 keeps H and W unchanged.
            return nn.Conv2d(in_channels = c_in,out_channels = c_out,kernel_size = (3,3),padding = 1)

        stack = [conv3x3(in_channels,width),nn.ReLU(inplace = True)]
        for _ in range(depth - 2):
            stack += [conv3x3(width,width),nn.ReLU(inplace = True),nn.BatchNorm2d(width)]
        stack.append(conv3x3(width,in_channels))

        self.net = nn.Sequential(*stack)

        self.init_weights()

    def forward(self,y):

        # Residual learning: the stack estimates the noise; removing it gives the denoised image.
        noise_estimate = self.net(y)
        return y - noise_estimate

    def init_weights(self):

        """Orthogonal conv weights with zero bias; BatchNorm weight 1 and bias 0."""

        for m in self.modules():

            if(isinstance(m,nn.BatchNorm2d)):
                nn.init.constant_(m.weight,1)
                nn.init.constant_(m.bias,0)
                continue

            if(not isinstance(m,nn.Conv2d)):
                continue

            nn.init.orthogonal_(m.weight)
            if(m.bias is not None):
                nn.init.constant_(m.bias,0)

In [None]:
def getpatch(img,patchsize):

    """
        Crop a random square patch of side `patchsize` from `img`.

        Arguments:
            img: array of shape (H, W) or (H, W, C).
            patchsize: side length of the patch in pixels.

        Returns:
            A view of `img` with shape (patchsize, patchsize) or (patchsize, patchsize, C).

        Raises:
            ValueError: if `img` is smaller than `patchsize` in height or width.
    """

    h,w = img.shape[:2]
    if(h < patchsize or w < patchsize):
        raise ValueError("Image of size %dx%d is smaller than patch size %d" % (h,w,patchsize))

    # randint's upper bound is exclusive: +1 makes the last valid offset reachable and
    # allows images that are exactly `patchsize` wide/high.
    x,y = np.random.randint(0,w - patchsize + 1),np.random.randint(0,h - patchsize + 1)

    return img[y:y + patchsize,x:x + patchsize]

In [None]:
class dataset(torch.utils.data.Dataset):

    """
        Main dataset class.

        Each item is a random patch cropped from one image file in `data_dir`, returned
        as (noisy_patch, clean_patch) channel-first float32 tensors. The noisy patch is
        the clean patch plus Gaussian noise whose sigma is drawn per item from sigma_range.
    """

    # Mean image (C, H, W) computed when a "train"-phase dataset is built; shared via the class.
    total_mean = 0.0

    def __init__(self,data_dir,data_size = -1,patchsize = 40,sigma_range = None,phase = "",apply_transform = False):

        """
            Arguments:
                data_dir: directory whose files are the images.
                data_size: number of images to use; any negative value means all of them.
                patchsize: side length of the square patches.
                sigma_range: [low, high] for the noise sigma; None or empty means [5, 40].
                phase: "train" also computes dataset.total_mean over the used images.
                apply_transform: accepted for compatibility; currently unused.
        """

        super().__init__()

        # Sort so that index -> file is deterministic (os.listdir order is arbitrary); skip subdirectories.
        self.datafiles = [os.path.join(data_dir,x) for x in sorted(os.listdir(data_dir))]
        self.datafiles = [f for f in self.datafiles if os.path.isfile(f)]

        if(data_size < 0):
            self.data_size = len(self.datafiles)
        else:
            self.data_size = min(data_size,len(self.datafiles))

        if(phase == "train"):
            mean = calc_total_mean(self.datafiles[:self.data_size])
            dataset.total_mean = tor.from_numpy(mean).permute(2,0,1)

        self.patchsize = patchsize

        if(sigma_range is None or len(sigma_range) == 0):
            self.sigma_range = [5,40]
        else:
            self.sigma_range = list(sigma_range)

        # NOTE(review): stored for compatibility; no transform is applied anywhere yet.
        self.apply_transform = apply_transform

    def __len__(self):
        return self.data_size

    def __getitem__(self,idx):

        imgname = self.datafiles[idx]
        bgr = cv2.imread(imgname,1)
        if(bgr is None):
            raise IOError("Could not read image file: " + imgname)
        img = np.float32(bgr)
        clean_patch = getpatch(img,self.patchsize)
        clean_patch = tor.from_numpy(clean_patch).permute(2,0,1)

        # np.random.randint excludes the upper bound, so sigma lies in [low, high - 1].
        sigma = np.random.randint(self.sigma_range[0],self.sigma_range[1])
        noise = tor.randn(clean_patch.size()).mul_(sigma)

        noisy_patch = clean_patch + noise

        return noisy_patch,clean_patch

In [None]:
def train(net,epochs,dataloaders,hyper_params,reset = True,save = False,device = None):

    """
        Train `net` with Adam on an MSE loss, printing mean train and validation loss per epoch.

        Arguments:
            net: model mapping a noisy batch to a clean estimate; must provide
                 init_weights() when reset is True.
            epochs: number of passes over the training loader.
            dataloaders: (trainloader, valloader); each yields (noisy, clean) tensor pairs.
            hyper_params: (lr, reg) -- Adam learning rate and weight decay.
            reset: if True, re-initialise the weights via net.init_weights() first.
            save: accepted but currently unused.
            device: device batches are moved to. Defaults to the device of net's first
                    parameter (CPU if it has none) instead of a notebook-global `device`.

        The net is left in training mode on return.
    """

    if(reset):
        net.init_weights()
        print("/////////////// Weights Reset \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\")

    if(device is None):
        first_param = next(net.parameters(),None)
        device = first_param.device if first_param is not None else tor.device("cpu")

    trainloader,valloader = dataloaders

    lr,reg = hyper_params

    optimizer = tor.optim.Adam(net.parameters(),lr = lr,weight_decay = reg)
    criterion = nn.MSELoss()

    for epoch in range(epochs):

        # Set every epoch: validation below switches to eval mode, and the caller may have
        # left the net in eval mode (BatchNorm would then use running stats while training).
        net.train()

        batch_losses = []

        for noisy_patch,clean_patch in trainloader:

            noisy_patch,clean_patch = noisy_patch.to(device),clean_patch.to(device)

            optimizer.zero_grad()

            out = net(noisy_patch)
            loss = criterion(out,clean_patch)
            loss.backward()
            optimizer.step()

            batch_losses.append(loss.item())

        print("Epoch: ",epoch,"train loss: ",np.mean(batch_losses))

        net.eval()
        with tor.no_grad():

            batch_losses = []

            for valdata,vallabel in valloader:

                valdata,vallabel = valdata.to(device),vallabel.to(device)

                valout = net(valdata)
                loss = criterion(valout,vallabel)

                batch_losses.append(loss.item())

            print("Val Loss: ",np.mean(batch_losses))


        print("-------------------------------------------------------------------------")
        net.train()

In [None]:
# Training images: every epoch yields one random noisy/clean patch pair per image file.
traindir = "./dncnn_dataset/train/"
trainset = dataset(traindir)
trainloader = torch.utils.data.DataLoader(trainset,shuffle = True,batch_size = 32)

In [None]:
# Validation images. Evaluation only needs one pass over the data, so no shuffling.
valdir = "./dncnn_dataset/val/"
valset = dataset(valdir)
valloader = torch.utils.data.DataLoader(valset,batch_size = 32,shuffle = False)

In [None]:
# Use the first GPU when one is visible, otherwise fall back to the CPU.
device = tor.device("cuda:0") if tor.cuda.is_available() else tor.device("cpu")
net = dncnn(in_channels = 3,depth = 17).to(device)

In [None]:
# map_location remaps stored tensors onto the current device, so a checkpoint saved on a GPU
# also loads on a CPU-only machine (and vice versa).
state = tor.load("/content/drive/My Drive/datasets/dncnn_dataset/dncnn_chk.pth.tar",map_location = device)
net.load_state_dict(state)

In [None]:
_ = net.train(mode = True)   # switch to training mode; assigning to _ suppresses the module repr

In [None]:
# Continue training the loaded weights: with reset = False, train() does not call init_weights().
epochs = 200
lr,reg = 1e-3,0.0
reset = False
hyper_params = [lr,reg]

dataloaders = [trainloader,valloader]

train(net,epochs,dataloaders,hyper_params,reset = reset)

In [None]:
# Test images (Set12). Build the set from testdir (previously valdir was used by mistake),
# and keep a fixed order since this loader is only used for evaluation.
testdir = "./dncnn_dataset/test/Set12/"
testset = dataset(testdir)
testloader = torch.utils.data.DataLoader(testset,batch_size = 32,shuffle = False)

In [None]:
criterion = nn.MSELoss()

net.eval()
with tor.no_grad():

    # One MSE value per test batch; the printed figure is their unweighted mean.
    batch_losses = [
        criterion(net(testdata.to(device)),testlabel.to(device)).item()
        for testdata,testlabel in testloader
    ]

    print("Test Loss: ",np.mean(batch_losses))


print("-------------------------------------------------------------------------")
_ = net.train()

In [None]:
# Persist only the parameters and buffers (the state_dict), not the module object itself.
state = net.state_dict()
tor.save(obj = state,f = "/content/drive/My Drive/datasets/dncnn_dataset/dncnn_sigmix_chk.pth.tar")

In [None]:
def psnr(img1, img2,PIXEL_MAX):
    """
        Calculates the peak signal-to-noise ratio of 2 images

        Arguments:
            img1: Image1 (array-like)
            img2: Image2 (array-like, same shape as img1 or broadcastable to it)
            PIXEL_MAX: largest possible pixel value (e.g. 255 for 8-bit images,
                       1.0 for images scaled to [0, 1])

        Returns:
            The peak signal-to-noise ratio of the 2 images in dB, or 100 when the
            images are identical (mean squared error of 0).
    """
    # Work in float64 so integer inputs (e.g. uint8) do not wrap around when subtracted.
    diff = np.asarray(img1,dtype = np.float64) - np.asarray(img2,dtype = np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100

    return 20 * np.log10(PIXEL_MAX / np.sqrt(mse))

In [None]:
net.eval()

sigma = 35

img_path = "./dncnn_dataset/test/Set12/06.png"
img_bgr = cv2.imread(img_path)
if img_bgr is None:
    raise FileNotFoundError("Could not read image: " + img_path)
img = tor.from_numpy(np.float32(img_bgr))

# A fixed-seed generator makes the noisy input (and the reported metrics) reproducible
# without touching the global torch RNG.
noise = tor.randn(img.size(),generator = tor.Generator().manual_seed(0)).mul_(sigma)

noisy_img = img + noise

noisy_img_n = noisy_img.numpy()

# Scale both images to [0, 1] with the same fixed mapping (clip to 0..255, divide by 255).
# Per-image min-max stretching would rescale the noisy image by its noise outliers and
# make PSNR/SSIM incomparable with the usual 8-bit figures.
noisy_img1 = np.clip(noisy_img_n,0.0,255.0) / 255.0
img1 = np.clip(img.detach().clone().cpu().numpy(),0.0,255.0) / 255.0

print("input psnr: ",psnr(noisy_img1,img1,1.0))
# For float images, data_range must be given; otherwise SSIM assumes the float range [-1, 1].
print("input ssim: ",ssim(noisy_img1,img1,data_range = 1.0,multichannel = True))

# OpenCV loads BGR; reverse the channel axis for matplotlib's RGB display.
plt.imshow(noisy_img1[...,::-1]);

In [None]:
# Convert the [0, 1] float image to 8 bits: round (rather than truncate), and clip so that
# out-of-range values cannot wrap around in the uint8 cast.
noisy_img2 = np.uint8(np.clip(np.rint(noisy_img1 * 255),0,255))

cv2.imwrite("plane_noisy.png",noisy_img2)

In [None]:
# Denoise the noisy test image. The batched tensor gets its own name so that re-running this
# cell does not permute `noisy_img` a second time.
net.eval()
noisy_batch = noisy_img.permute(2,0,1).unsqueeze(0).to(device)   # H x W x C -> 1 x C x H x W

with tor.no_grad():   # inference only; no autograd graph is needed
    res = net(noisy_batch)

# squeeze(0) drops only the batch axis, so single-channel outputs keep their channel axis.
clean_img = res.squeeze(0).permute(1,2,0).cpu().numpy()

# Use the same fixed [0, 255] -> [0, 1] mapping for the output and the clean reference, so the
# metrics are not distorted by per-image min-max stretching.
clean_img1 = np.clip(clean_img,0.0,255.0) / 255.0
clean_ref = np.clip(img.detach().cpu().numpy(),0.0,255.0) / 255.0

print("psnr: ",psnr(clean_img1,clean_ref,1.0))
# For float images, data_range must be given; otherwise SSIM assumes the float range [-1, 1].
print("ssim: ",ssim(clean_img1,clean_ref,data_range = 1.0,multichannel = True))

# OpenCV loads BGR; reverse the channel axis for matplotlib's RGB display.
plt.imshow(clean_img1[...,::-1]);

In [None]:
# Convert the [0, 1] float image to 8 bits: round (rather than truncate), and clip so that
# out-of-range values cannot wrap around in the uint8 cast.
clean_img2 = np.uint8(np.clip(np.rint(clean_img1 * 255),0,255))

cv2.imwrite("plane_denoised.png",clean_img2)