In [1]:
#Importing all the libraries
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset,DataLoader
from torchvision import models,datasets,transforms

from tqdm import tqdm
import os
from PIL import Image
import matplotlib.pyplot as plt
import math
import random
import pickle

In [2]:
# Pick the compute device once: a CUDA GPU when one is visible, otherwise the CPU.
# The models and batches in the cells below are moved to this device.
if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
device

device(type='cuda')

In [3]:
#Dataset returning pre-stored noisy training patches, or one (clean, noisy) test pair
class CustomDataset(Dataset):
  """Dataset for unsupervised denoising.

  * Training mode (``train_flag=True``): ``dir_name`` is the path of a pickle file
    holding an indexable collection of image patches (presumably tensors of shape
    (3, 40, 40), per the original comments - TODO confirm). Items are returned as-is.
  * Evaluation mode (``train_flag=False``): ``dir_name`` is ignored. The dataset has
    exactly one item, a ``(clean, noisy)`` pair. Both images are loaded from
    ``clean_path`` / ``noisy_path``, converted to RGB and turned into (3, H, W)
    float tensors by ``transforms.ToTensor()``.

  NOTE: ``pickle.load`` can execute arbitrary code - only load trusted files.
  """

  # Default evaluation images (kept identical to the original hard-coded paths).
  DEFAULT_CLEAN_PATH = "/content/drive/MyDrive/R2R_ML/an_image/test/Sony_4-5_125_3200_plant_13_mean.JPG"
  DEFAULT_NOISY_PATH = "/content/drive/MyDrive/R2R_ML/an_image/test/Sony_4-5_125_3200_plant_13_real.JPG"

  def __init__(self, dir_name, train_flag, clean_path=None, noisy_path=None):
    """
    dir_name:   path of the training pickle file (unused when train_flag is False)
    train_flag: True for the training dataset, False for the evaluation dataset
    clean_path, noisy_path: optional overrides for the evaluation image files
    """
    super().__init__()
    # Transform applied to the evaluation images (PIL -> tensor).
    self.test_transform = transforms.Compose([transforms.ToTensor()])
    self.train_flag = train_flag
    self.clean_path = clean_path if clean_path is not None else self.DEFAULT_CLEAN_PATH
    self.noisy_path = noisy_path if noisy_path is not None else self.DEFAULT_NOISY_PATH

    if train_flag:
      # 'with' guarantees the file is closed even if unpickling raises.
      with open(dir_name, 'rb') as pfile:
        self.dataset = pickle.load(pfile)

  def __len__(self):
    # The evaluation split always consists of a single image pair.
    return len(self.dataset) if self.train_flag else 1

  def _load_image(self, path):
    """Open ``path``, convert it to RGB and return it as a (3, H, W) tensor."""
    # Context manager closes the underlying file handle after conversion.
    with Image.open(path) as im:
      return self.test_transform(im.convert("RGB"))

  def __getitem__(self, index):
    if self.train_flag:
      return self.dataset[index]
    # Evaluation mode ignores ``index``: there is only one (clean, noisy) pair.
    return self._load_image(self.clean_path), self._load_image(self.noisy_path)



In [4]:
#One hidden stage of the DnCNN: conv -> batch norm -> ReLU
class Block(nn.Module):
  """Hidden DnCNN stage: bias-free Conv2d, then BatchNorm2d, then an in-place ReLU.

  Args:
    k: convolution kernel size.
    p: zero padding; with the defaults (k=3, p=1) the spatial size is preserved.
    c: number of channels (input and output channel counts are equal).
  """

  def __init__(self, k=3, p=1, c=64):
    super().__init__()
    # Attribute names are part of the checkpoint format (state_dict keys) - keep them.
    self.conv = nn.Conv2d(c, c, k, padding=p, bias=False)
    self.norm = nn.BatchNorm2d(c)
    self.relu = nn.ReLU(inplace=True)

  def forward(self, x):
    return self.relu(self.norm(self.conv(x)))

In [5]:
#DnCNN-style residual denoiser
class DCNN(nn.Module):
  """DnCNN-style network whose forward pass returns ``x - self.all(x)``.

  Layout of ``self.all`` (an ``nn.Sequential``):
    conv(in_c -> c) + ReLU, then ``l - 2`` ``Block`` stages, then conv(c -> in_c).
  All convolutions are bias-free.

  Args:
    k: kernel size of every convolution.
    p: padding of every convolution (k=3, p=1 keeps the spatial size).
    c: number of hidden feature channels.
    l: total depth, counting the first and the last convolution.
    in_c: number of image channels (3 for the RGB images used in this notebook).
  """

  def __init__(self, k=3, p=1, c=64, l=17, in_c=1):
    super().__init__()
    head = [nn.Conv2d(in_c, c, k, padding=p, bias=False), nn.ReLU(inplace=True)]
    body = [Block(k, p, c) for _ in range(l - 2)]
    tail = [nn.Conv2d(c, in_c, k, padding=p, bias=False)]
    # The name "all" and this exact layer order fix the checkpoint keys (all.0.weight, ...).
    self.all = nn.Sequential(*head, *body, *tail)

  def forward(self, x):
    # The stack's output is treated as a residual and subtracted from the input.
    residual = self.all(x)
    return x - residual


In [6]:
#Function for unsupervised training of the model (one epoch)
def train(criterion,optimizer,model,device,train_loader,sigma,alpha):
  """Run one epoch of recorruption-based unsupervised training.

  For every batch ``y`` of noisy images, a Gaussian sample ``z`` with standard
  deviation ``sigma/255`` is drawn and two recorrupted copies are built:
  ``y1 = y + alpha*z`` (model input) and ``y2 = y - z/alpha`` (target).
  The model is optimised so that ``model(y1)`` matches ``y2``.

  Args:
    criterion: loss function, called as ``criterion(prediction, target)``.
    optimizer: optimizer over ``model``'s parameters.
    model: the denoiser being trained (switched to train mode here).
    device: torch device the batches are moved to.
    train_loader: iterable yielding batches of noisy images.
    sigma: noise level on the 0-255 scale (divided by 255 below).
    alpha: recorruption constant (20 in this notebook).

  Returns:
    float: mean training loss over the epoch's batches (0.0 for an empty loader).
  """
  model.train()
  loop = tqdm(train_loader) # progress bar over the batches
  cur_loss = 0.0
  n_batches = 0

  for noisy in loop:
    # .float() mirrors the dtype handling of the evaluation loop.
    noisy = noisy.to(device).float()

    # randn_like samples directly on the batch's device and dtype instead of
    # drawing on the CPU and copying to the device every iteration.
    noise = torch.randn_like(noisy) * (sigma / 255.)
    noisy1 = noisy + alpha * noise  # model input
    noisy2 = noisy - noise / alpha  # training target

    noisy2_pred = model(noisy1)
    loss = criterion(noisy2_pred, noisy2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    cur_loss += loss.item()
    n_batches += 1
    loop.set_postfix(loss=cur_loss / n_batches) # running mean of the batch losses

  return cur_loss / n_batches if n_batches else 0.0

In [7]:
#Function for evaluating the model (returns the mean PSNR)
def test(criterion,model,device,test_loader,alpha,sigma,T):
  """Evaluate the denoiser and return its mean PSNR over the test set.

  Each noisy input is recorrupted ``T`` times (``noisy + alpha*z`` with
  ``z ~ N(0, (sigma/255)^2)``). The ``T`` model outputs are averaged, clipped
  to [0, 1] and compared against the clean image.

  Args:
    criterion: element-wise loss such as ``nn.MSELoss(reduction='none')``. Its
      output is averaged over dims (1, 2, 3) to get one MSE per image, so it must
      NOT reduce over the batch.
    model: the denoiser (switched to eval mode here).
    device: torch device the batches are moved to.
    test_loader: iterable of ``(clean, noisy)`` batches shaped (N, C, H, W).
    alpha: recorruption constant (20 in this notebook).
    sigma: noise level on the 0-255 scale.
    T: number of recorrupted forward passes averaged per image (must be >= 1).

  Returns:
    float: PSNR in dB with peak value 1.0, averaged over images.

  Raises:
    ValueError: if ``T < 1`` (the average would be a division by zero / NaN).
  """
  if T < 1:
    raise ValueError(f"T must be >= 1, got {T}")

  model.eval()
  loop = tqdm(test_loader) # progress bar over the batches
  total_mse = []

  with torch.no_grad(): # no gradients needed for evaluation
    for clean, noisy in loop:
      clean, noisy = clean.to(device).float(), noisy.to(device).float()

      # Average T predictions, each from an independently recorrupted input.
      # *_like allocates/samples directly on the right device and dtype.
      out = torch.zeros_like(clean)
      for _ in range(T):
        noise = torch.randn_like(noisy) * (sigma / 255.)
        out += model(noisy + alpha * noise)

      clean_pred = torch.clamp(out / T, min=0.0, max=1.0)
      # One MSE value per image in the batch.
      total_mse.extend(criterion(clean_pred, clean).mean(dim=(1, 2, 3)).tolist())

  # PSNR = -10*log10(MSE) for signals in [0, 1]; averaged per image.
  psnr = (-10 * torch.log10(torch.tensor(total_mse))).mean()

  print(f"The PSNR is {psnr}")
  return psnr.item()

In [8]:
#Function to compute the number of trainable parameters in a model
def number_of_parameters(model):
    """Return the total element count of ``model``'s parameters that require gradients."""
    total = 0
    for tensor in model.parameters():
        if tensor.requires_grad:  # frozen parameters are not counted
            total += tensor.numel()
    return total

In [9]:
#Wrapper function to train and evaluate the denoiser model
def wrapper(sigma, epochs=50, alpha=20, T=50, lr=0.0001, batch_size=128,
            num_workers=None, train_path="/content/drive/MyDrive/R2R_ML/an_image/PolyU_train.pkl",
            seed=None):
  """Train a 3-channel DCNN with recorruption-based unsupervised training.

  Evaluates PSNR on the test pair after every epoch and saves a checkpoint
  (``Unsupervised<epoch>_<psnr>_<sigma>.pt`` in the working directory) whenever
  the PSNR improves.

  Args:
    sigma: noise level on the 0-255 scale, forwarded to ``train``/``test``.
    epochs, alpha, T, lr, batch_size: training hyperparameters (defaults match the
      values previously hard-coded here).
    num_workers: DataLoader workers for training; ``None`` -> ``min(8, os.cpu_count())``.
    train_path: pickle file with the training patches.
    seed: optional ``torch.manual_seed`` value for reproducible runs.

  Returns:
    float: the best test PSNR reached.
  """
  print(f"This is for sigma of {sigma}")
  if seed is not None:
    torch.manual_seed(seed)

  # Datasets and loaders for the train and test sets.
  train_set = CustomDataset(train_path, True)
  test_set = CustomDataset("", False)

  # The former 64 workers far exceeded typical CPU counts (and produced a fork
  # warning); cap at the available CPUs. The single-image test set is loaded in
  # the main process.
  if num_workers is None:
    num_workers = min(8, os.cpu_count() or 1)
  train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
  test_loader = DataLoader(test_set, batch_size=1, shuffle=False, num_workers=0)

  print(f"The number of images in the train set is {len(train_set)}")
  print(f"The number of images in the test set is {(len(test_set))}")

  # Model, loss functions and optimizer.
  model = DCNN(in_c=3).to(device)
  criterion_train = nn.MSELoss()
  # reduction='none' replaces the deprecated reduce=False: per-pixel errors are
  # kept so they can be averaged per image.
  criterion_test = nn.MSELoss(reduction='none')
  optimizer = torch.optim.Adam(model.parameters(), lr=lr)

  print(f"The model has {number_of_parameters(model)} parameters")

  # Baseline: PSNR of the noisy input against the clean image (no denoising).
  total_mse = []
  with torch.no_grad():
    for clean, noisy in test_loader:
      clean, noisy = clean.to(device).float(), noisy.to(device).float()
      total_mse.extend(criterion_test(noisy, clean).mean(dim=(1, 2, 3)).tolist())
  psnr = (-10 * torch.log10(torch.tensor(total_mse))).mean()
  print(f"The PSNR of the noisy input (no denoising) is {psnr}")

  # Start at -inf so the first evaluated epoch always produces a checkpoint.
  best_psnr = float("-inf")
  for epoch in range(epochs):
    print(f"The current epoch is {epoch}")
    train(criterion_train, optimizer, model, device, train_loader, sigma, alpha)
    cur_psnr = test(criterion_test, model, device, test_loader, alpha, sigma, T)
    if cur_psnr > best_psnr: # save the model whenever the PSNR improves
      best_psnr = cur_psnr
      torch.save(model.state_dict(), "Unsupervised" + str(epoch) + "_" + str(round(cur_psnr, 2)) + "_" + str(sigma) + ".pt")
  return best_psnr


In [10]:
# Run the full train/evaluate pipeline for one noise level (sigma ~= 0.44 on the 0-255 scale).
# NOTE(review): the origin of this sigma value is not shown in the notebook; it may be
# a noise estimate for the test image - confirm before reusing it elsewhere.
wrapper(sigma=0.4434824560303241)

This is for sigma of 0.4434824560303241
The number of images in the train set is 52367
The number of images in the test set is 1
The model has 558336 parameters


  self.pid = os.fork()


The PSNR for an untrained densoiser is 31.182998657226562
The current epoch is 0


100%|██████████| 410/410 [00:48<00:00,  8.39it/s, loss=0.00325]
100%|██████████| 1/1 [00:01<00:00,  1.97s/it]


The PSNR is 28.343982696533203
The current epoch is 1


100%|██████████| 410/410 [00:48<00:00,  8.53it/s, loss=0.00126]
100%|██████████| 1/1 [00:01<00:00,  1.99s/it]


The PSNR is 30.9163818359375
The current epoch is 2


100%|██████████| 410/410 [00:48<00:00,  8.46it/s, loss=0.00121]
100%|██████████| 1/1 [00:02<00:00,  2.03s/it]


The PSNR is 30.984088897705078
The current epoch is 3


100%|██████████| 410/410 [00:49<00:00,  8.35it/s, loss=0.00112]
100%|██████████| 1/1 [00:02<00:00,  2.14s/it]


The PSNR is 30.27202796936035
The current epoch is 4


100%|██████████| 410/410 [00:50<00:00,  8.09it/s, loss=0.000953]
100%|██████████| 1/1 [00:02<00:00,  2.10s/it]


The PSNR is 31.207237243652344
The current epoch is 5


100%|██████████| 410/410 [00:50<00:00,  8.15it/s, loss=0.000739]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 31.161134719848633
The current epoch is 6


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000575]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 30.00558853149414
The current epoch is 7


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000462]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 31.536746978759766
The current epoch is 8


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000396]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 31.91158103942871
The current epoch is 9


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000356]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 31.754207611083984
The current epoch is 10


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000331]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 31.879947662353516
The current epoch is 11


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000312]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 31.687519073486328
The current epoch is 12


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000295]
100%|██████████| 1/1 [00:02<00:00,  2.11s/it]


The PSNR is 32.00699996948242
The current epoch is 13


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000279]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 32.0134391784668
The current epoch is 14


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000268]
100%|██████████| 1/1 [00:02<00:00,  2.11s/it]


The PSNR is 32.07668685913086
The current epoch is 15


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000261]
100%|██████████| 1/1 [00:02<00:00,  2.11s/it]


The PSNR is 31.656661987304688
The current epoch is 16


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000249]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 31.908782958984375
The current epoch is 17


100%|██████████| 410/410 [00:50<00:00,  8.10it/s, loss=0.000242]
100%|██████████| 1/1 [00:02<00:00,  2.14s/it]


The PSNR is 32.14962387084961
The current epoch is 18


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000234]
100%|██████████| 1/1 [00:02<00:00,  2.15s/it]


The PSNR is 31.617238998413086
The current epoch is 19


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000232]
100%|██████████| 1/1 [00:02<00:00,  2.11s/it]


The PSNR is 31.97002410888672
The current epoch is 20


100%|██████████| 410/410 [00:50<00:00,  8.12it/s, loss=0.000227]
100%|██████████| 1/1 [00:02<00:00,  2.11s/it]


The PSNR is 31.420948028564453
The current epoch is 21


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000224]
100%|██████████| 1/1 [00:02<00:00,  2.11s/it]


The PSNR is 31.645999908447266
The current epoch is 22


100%|██████████| 410/410 [00:50<00:00,  8.12it/s, loss=0.000221]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 31.99186134338379
The current epoch is 23


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000218]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 32.226829528808594
The current epoch is 24


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000214]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 32.04560852050781
The current epoch is 25


100%|██████████| 410/410 [00:50<00:00,  8.10it/s, loss=0.000213]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 32.074947357177734
The current epoch is 26


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000212]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 31.94856834411621
The current epoch is 27


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000208]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 32.19762420654297
The current epoch is 28


100%|██████████| 410/410 [00:50<00:00,  8.12it/s, loss=0.000207]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 31.82235336303711
The current epoch is 29


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000207]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 32.08175277709961
The current epoch is 30


100%|██████████| 410/410 [00:50<00:00,  8.12it/s, loss=0.000206]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 31.385520935058594
The current epoch is 31


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000217]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 32.08048629760742
The current epoch is 32


100%|██████████| 410/410 [00:50<00:00,  8.11it/s, loss=0.000203]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 31.875328063964844
The current epoch is 33


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000203]
100%|██████████| 1/1 [00:02<00:00,  2.15s/it]


The PSNR is 32.024864196777344
The current epoch is 34


100%|██████████| 410/410 [00:50<00:00,  8.12it/s, loss=0.0002]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 32.13422393798828
The current epoch is 35


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000199]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 32.076351165771484
The current epoch is 36


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000198]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 30.8680362701416
The current epoch is 37


100%|██████████| 410/410 [00:50<00:00,  8.11it/s, loss=0.000199]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 31.946237564086914
The current epoch is 38


100%|██████████| 410/410 [00:50<00:00,  8.12it/s, loss=0.000197]
100%|██████████| 1/1 [00:02<00:00,  2.15s/it]


The PSNR is 31.79949951171875
The current epoch is 39


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000196]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 31.663591384887695
The current epoch is 40


100%|██████████| 410/410 [00:50<00:00,  8.12it/s, loss=0.000197]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 31.81661605834961
The current epoch is 41


100%|██████████| 410/410 [00:50<00:00,  8.11it/s, loss=0.000196]
100%|██████████| 1/1 [00:02<00:00,  2.14s/it]


The PSNR is 30.887052536010742
The current epoch is 42


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000195]
100%|██████████| 1/1 [00:02<00:00,  2.14s/it]


The PSNR is 32.147090911865234
The current epoch is 43


100%|██████████| 410/410 [00:50<00:00,  8.14it/s, loss=0.000194]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 31.876781463623047
The current epoch is 44


100%|██████████| 410/410 [00:50<00:00,  8.12it/s, loss=0.000195]
100%|██████████| 1/1 [00:02<00:00,  2.12s/it]


The PSNR is 30.916034698486328
The current epoch is 45


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000195]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 31.940099716186523
The current epoch is 46


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000192]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]


The PSNR is 32.09718322753906
The current epoch is 47


100%|██████████| 410/410 [00:50<00:00,  8.09it/s, loss=0.000191]
100%|██████████| 1/1 [00:02<00:00,  2.14s/it]


The PSNR is 31.876808166503906
The current epoch is 48


100%|██████████| 410/410 [00:50<00:00,  8.12it/s, loss=0.000191]
100%|██████████| 1/1 [00:02<00:00,  2.15s/it]


The PSNR is 32.08868408203125
The current epoch is 49


100%|██████████| 410/410 [00:50<00:00,  8.13it/s, loss=0.000192]
100%|██████████| 1/1 [00:02<00:00,  2.13s/it]

The PSNR is 32.10338592529297



