#### Import Libraries

In [1]:
#*******________________________Code for Training DSD_Net______________________________*******

import torch
import torch.nn as nn
import os, glob
import random, csv
from torch.autograd import Variable

#import visdom
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
from PIL import Image
from tqdm.auto import tqdm
import numpy as np


#### Custom Dataloader

In [2]:
import random
import numpy as np
import torch.utils.data as data
import utils.utils_image as util


class DatasetSR(data.Dataset):
    

    def __init__(self):
        super(DatasetSR, self).__init__()
        self.n_channels = 1
        self.sf = 1
        self.patch_size = 128
        self.L_size = self.patch_size // self.sf

        # ------------------------------------
        # get paths of L/H
        # ------------------------------------
        self.paths_H = util.get_image_paths('/notebooks/JPEG_HR_Gray/DIV2K_train_HR_Grayscale')  #The Path of Ground Truth Images
        self.paths_L = util.get_image_paths('/notebooks/JPEG_LR_Gray/JPEG_Grayscale/JPEG_05')    #The path of LR Images 
        '''
        assert self.paths_H, 'Error: H path is empty.'
        if self.paths_L and self.paths_H:
            assert len(self.paths_L) == len(self.paths_H), 'L/H mismatch - {}, {}.'.format(len(self.paths_L), len(self.paths_H))
        '''
    def __getitem__(self, index):

        H_path = None
        # ------------------------------------
        # get H image
        # ------------------------------------
        L_path = self.paths_L[index]
        img_L = util.imread_uint(L_path, self.n_channels)
        img_L = util.uint2single(img_L)

        # ------------------------------------
        # modcrop
        # ------------------------------------
        

        # ------------------------------------
        # get L image
        # ------------------------------------
        if self.paths_H:
            # --------------------------------
            # directly load L image
            # --------------------------------
            H_path = self.paths_H[index % 800]
            img_H = util.imread_uint(H_path, self.n_channels)
            img_H = util.uint2single(img_H)
            img_H = util.modcrop(img_H, self.sf)
        else:
            # --------------------------------
            # sythesize L image via matlab's bicubic
            # --------------------------------
            H, W = img_H.shape[:2]
            img_L = util.imresize_np(img_H, 1 / self.sf, True)

        # ------------------------------------
        # if train, get L/H patch pair
        # ------------------------------------

        H, W, C = img_L.shape

        # --------------------------------
        # randomly crop the L patch
        # --------------------------------
        rnd_h = random.randint(0, max(0, H - self.L_size))
        rnd_w = random.randint(0, max(0, W - self.L_size))
        img_L = img_L[rnd_h:rnd_h + self.L_size, rnd_w:rnd_w + self.L_size]

        # --------------------------------
        # crop corresponding H patch
        # --------------------------------
        rnd_h_H, rnd_w_H = int(rnd_h * self.sf), int(rnd_w * self.sf)
        img_H = img_H[rnd_h_H:rnd_h_H + self.patch_size, rnd_w_H:rnd_w_H + self.patch_size]

        # --------------------------------
        # augmentation - flip and/or rotate
        # --------------------------------
        
        mode = random.randint(0, 7)
        img_L, img_H = util.augment_img(img_L, mode=mode), util.augment_img(img_H, mode=mode)
        
        # ------------------------------------
        # L/H pairs, HWC to CHW, numpy to tensor
        # ------------------------------------
        img_H, img_L = util.single2tensor3(img_H), util.single2tensor3(img_L)

        if L_path is None:
            L_path = H_path

        return {'L': img_L, 'H': img_H, 'L_path': L_path, 'H_path': H_path}

    def __len__(self):
        return len(self.paths_L)


#### Create Dataset and Dataloader

In [3]:
train_set = DatasetSR()
#train_sampler = DistributedSampler(train_set, shuffle=True, drop_last=True)
train_loader = DataLoader(train_set,
                          batch_size=8,
                          shuffle=True,
                          drop_last=True,
                          num_workers = 4,
                          pin_memory=True)

In [4]:
'''
from matplotlib import pyplot as plt
arr = x[1].permute(1,2,0)
#arr = torch.Tensor.numpy(x[1])
print(arr.shape)
plt.imshow(arr)

plt.imshow(x[1].permute(1, 2, 0),cmap='gray')
'''

"\nfrom matplotlib import pyplot as plt\narr = x[1].permute(1,2,0)\n#arr = torch.Tensor.numpy(x[1])\nprint(arr.shape)\nplt.imshow(arr)\n\nplt.imshow(x[1].permute(1, 2, 0),cmap='gray')\n"

### DCT Submodule

In [5]:
import torch
import torch.nn as nn
#import tensorflow as tf
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR

import torchvision

class LWAB(nn.Module):
    def __init__(self):
        super(LWAB, self).__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels= 64, out_channels= 256, kernel_size=(3,3), stride=(1,1), padding=(1,1)),
            nn.PReLU(),
            nn.Conv2d(in_channels=256, out_channels= 64, kernel_size=(3,3), stride=(1,1), padding=(1,1)),
            nn.PReLU(),
        )

    def forward(self, x):
        out_1 = self.block(x)
        out_2 = self.block(out_1)
        
        return x + out_1 + out_2

class DCT_Net(nn.Module):
    def __init__(self):
        super(DCT_Net, self).__init__()
        #Patch_Extraction is the 1st layer of DCT Network with Channels of 64 and filters of 8 X 8
        self.Patch_Extraction = nn.Sequential(nn.Conv2d(in_channels= 1, out_channels= 64, kernel_size=(8,8), stride=(1,1), padding=(3,3)),
                                      nn.ReLU()
                                      )
        #DCT and IDCT layers Have Channel size 64 and filter size 1 X 1
        self.DCT = nn.Sequential(nn.Conv2d(in_channels= 64, out_channels= 64, kernel_size=(1,1), stride=(1,1), padding=(0,0)),
                                      nn.ReLU()
                                      )
        
        self.IDCT = nn.Sequential(nn.Conv2d(in_channels= 64, out_channels= 64, kernel_size=(1,1), stride=(1,1), padding=(0,0)),
                                      nn.ReLU()
                                      )
        # The Weights of DCT and IDCT are not updated during the training process
        self.DCT.requires_grad_(False)    
        self.IDCT.requires_grad_(False)
        
        #LWAB_body contains the code for implementing LWAB. 3 LWAB Blocks are used in D3SN
        
        LWAB_body = [LWAB() for _ in range(3)]       
        self.LWAB_body = nn.Sequential(*LWAB_body)
        
        #Patch_Reconstruction Outputs the Final Output of DCT Branch
        
        self.Patch_Reconstruction = nn.Sequential(nn.Conv2d(in_channels= 64, out_channels= 1, kernel_size=(8,8), stride=(1,1), padding=(4,4)),
                                      )
       
        
        
        
        
        
        
        
    def forward(self, x):
        output = self.Patch_Extraction(x)                                #Extract the image Patch   
        #print('E Layer: ', output.shape)
        output_DCT = self.DCT(output)                                    #Pass Through DCT layer
        output_LWARG = self.LWAB_body(output_DCT)                        #Pass Through LWARGs
        #print('LWARGs: ', output_LWARG.shape)
        output_final = output_LWARG + output_DCT                         #Residual Connection: LWARGs Output + DCT
        #print('LWARG_Res: ', output_final.shape)
        output_IDCT = self.IDCT(output_final)                            #Pass Through IDCT layer
        #print('IDCT: ', output_IDCT.shape)
        output_image_1 = self.Patch_Reconstruction(output_IDCT)          #Final output of DCT Branch
         
        
        #print(output_image_1.shape)
        
        
       
        return output_image_1
    


### DPD Submodule

In [6]:
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F 
import numpy as np
from torch.autograd import Variable

class Up_Sample(nn.Module):
    def __init__(self, scale, act=False):
        super(Up_Sample, self).__init__()
        modules = []
        modules.append(nn.PixelShuffle(scale))
        self.body = nn.Sequential(*modules)
    def forward(self, x):
        x = self.body(x)
        return x


class make_dense(nn.Module):
  def __init__(self, nChannels, growthRate, kernel_size=3):
    super(make_dense, self).__init__()
    self.conv = nn.Sequential(nn.Conv2d(in_channels= nChannels, out_channels= growthRate, kernel_size=3, stride=(1,1), padding=(1,1),bias = True),
                                      nn.PReLU()
                                      )
  def forward(self, x):
    out = self.conv(x)
    out = torch.cat((x, out), 1)
    return out

# Residual dense block (RDB) architecture
class RDB(nn.Module):
  def __init__(self, nChannels, nDenselayer, growthRate):
    super(RDB, self).__init__()
    nChannels_ = nChannels
    modules = []
    for i in range(nDenselayer):    
        modules.append(make_dense(nChannels_, growthRate))
        nChannels_ += growthRate 
    self.dense_layers = nn.Sequential(*modules)    
    self.conv_1x1 = nn.Conv2d(nChannels_, nChannels, kernel_size=1, padding=0, bias=False)
  def forward(self, x):
    out = self.dense_layers(x)
    out = self.conv_1x1(out)
    out = out + x
    return out
class DPD_Net(nn.Module):
    def __init__(self):
        super(DPD_Net, self).__init__()
        #Patch_Extraction is the 1st layer of DPD_Net. 1 X 1 Convolution Performed on the input image followed by a PRelu Layer
        self.Patch_Extraction = nn.Sequential(nn.Conv2d(in_channels= 1, out_channels= 64, kernel_size=(1,1), stride=(1,1), padding=(0,0)),
                                      nn.PReLU()
                                      )
        
        
        
        self.AWDRU_64X64 = RDB(64,3,64)   
        
        
        self.AWDRU_128X128 = RDB(128, 3, 128)
        
        
        self.AWDRU_256X256 = RDB(256, 3, 256)
        
        
        #Up-sampling Module. Upsamples by a scale of 2
        self.upsample = Up_Sample(2)
        

        
        #Downsampling Module
        self.Down_Sample_64_128 = nn.Sequential(nn.Conv2d(in_channels= 64, out_channels= 128, kernel_size=(1,1), stride=(2,2), padding=(0,0)),
                                      nn.PReLU()
                                      )
        
        self.Down_Sample_128_256 = nn.Sequential(nn.Conv2d(in_channels= 128, out_channels= 256, kernel_size=(1,1), stride=(2,2), padding=(0,0)),
                                      nn.PReLU()
                                      )
        
        
        
        self.Patch_Reconstruction = nn.Sequential(nn.Conv2d(in_channels= 128, out_channels= 1, kernel_size=(1,1), stride=(1,1), padding=(0,0)),
                                                  nn.PReLU()
                                      )
       
        
        
        
        
        
        
        
    def forward(self, x):
        feature = self.Patch_Extraction(x)
        #print(feature.shape)
        AWDRU_1 = self.AWDRU_64X64(feature)
        #print('k')
        AWDRU_2 = self.AWDRU_64X64(AWDRU_1)
        AWDRU_3 = self.AWDRU_64X64(AWDRU_2)
        Down_1 = self.Down_Sample_64_128(AWDRU_3)
        AWDRU_4 = self.AWDRU_128X128(Down_1)
        AWDRU_5 = self.AWDRU_128X128(AWDRU_4)
        AWDRU_6 = self.AWDRU_128X128(AWDRU_5)
        Down_2 = self.Down_Sample_128_256(AWDRU_6)
        AWDRU_7 = self.AWDRU_256X256(Down_2)
        AWDRU_8 = self.AWDRU_256X256(AWDRU_7)
        AWDRU_9 = self.AWDRU_256X256(AWDRU_8)
        Concat_1 = torch.cat((AWDRU_7, AWDRU_9), 1)
        #print('COncat_1',Concat_1.shape)
        Up_1 = self.upsample(Concat_1)
        #print('Up1: ',Up_1.shape)
        AWDRU_10 = self.AWDRU_128X128(Up_1)
        Concat_2 = torch.cat((AWDRU_10, AWDRU_6), 1)
        AWDRU_11 = self.AWDRU_256X256(Concat_2)
        AWDRU_12 = self.AWDRU_256X256(AWDRU_11)
        Up_2 = self.upsample(AWDRU_12)
        AWDRU_13 = self.AWDRU_64X64(Up_2)
        Concat_3 = torch.cat((AWDRU_3, AWDRU_13), 1)
        AWDRU_14 = self.AWDRU_128X128(Concat_3)
        AWDRU_15 = self.AWDRU_128X128(AWDRU_14)
        out = self.Patch_Reconstruction(AWDRU_15)
        final = out + x
        
        
            
        
        
       
        return final


### D3SN Module

In [7]:
class D3SN(nn.Module):
    def __init__(self):
        super(D3SN, self).__init__()
        self.DCT = DCT_Net()
        self.DPD = DPD_Net()
        
        self.Patch_Reconstruction = nn.Sequential(nn.Conv2d(in_channels= 2, out_channels= 1, kernel_size=(1,1), stride=(1,1), padding=(0,0),bias = False)
                                      )      
        
    def forward(self, x):
        DCT = self.DCT(x)
        DPD = self.DPD(x)
        Concat = torch.cat((DCT,DPD),1)
        final = self.Patch_Reconstruction(Concat)
        
            
        
        
       
        return final


In [8]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = D3SN()
model.to(device)
co_ef = 1e-3
error = nn.L1Loss()

learning_rate = 1e-4 
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,betas=(0.9, 0.999), eps=1e-08) # Code to use the Adam Optimizer
scheduler = MultiStepLR(optimizer, milestones=[10 * 100 ,20 * 100 , 30 * 100, 40 * 100], gamma=0.5) #milestones=[Intended Epoch * Batch Size]
print(model)

D3SN(
  (DCT): DCT_Net(
    (Patch_Extraction): Sequential(
      (0): Conv2d(1, 64, kernel_size=(8, 8), stride=(1, 1), padding=(3, 3))
      (1): ReLU()
    )
    (DCT): Sequential(
      (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): ReLU()
    )
    (IDCT): Sequential(
      (0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): ReLU()
    )
    (LWAB_body): Sequential(
      (0): LWAB(
        (block): Sequential(
          (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): PReLU(num_parameters=1)
          (2): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): PReLU(num_parameters=1)
        )
      )
      (1): LWAB(
        (block): Sequential(
          (0): Conv2d(64, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): PReLU(num_parameters=1)
          (2): Conv2d(256, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (3): PReLU(num_parameters=1)


In [9]:
#Template for loading presaved check points for training.
'''
#IAM_3 = IAM_Net()
saved_state_dict = torch.load('/notebooks/D3SN_Files/models_JPEG_10/D3SN_JPEG_10_5000.pth')
model.load_state_dict(saved_state_dict)
#IAM_3 = IAM_3.to(device)
model.train()


#Load IAM Branch_1 Optimizer Checkpoint
saved_state_dict = torch.load('/notebooks/JPEG_Gray_IDPD_Files/models_1_channel/Optimizer300.pth')
optimizer.load_state_dict(saved_state_dict)
'''

"\n#Load IAM Branch_1 Optimizer Checkpoint\nsaved_state_dict = torch.load('/notebooks/JPEG_Gray_IDPD_Files/models_1_channel/Optimizer300.pth')\noptimizer.load_state_dict(saved_state_dict)\n"

#### Create Log File

In [10]:
import logging
logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(levelname)s %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        filename="/notebooks/D3SN_Files/logs/train_D3SN_JPEG_05.log",
    )

In [None]:
from tqdm.auto import tqdm
num_epochs = 4000
count = 0 * 450
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.train()


for epoch in range(1,num_epochs): 
    for idx, train_data in tqdm(enumerate(train_loader)):
        labels = []
        L = train_data['L'].to(device)
        H = train_data['H'].to(device)
        
        
        
        
        
        
        output = model(L)
        
        

        loss = error(output, H)

        # Initializing a gradient as 0 so there is no mixing of gradient among the batches
        optimizer.zero_grad()
        

        #Propagating the error backward
        loss.backward()

        # Optimizing the parameters
        optimizer.step()
        
        scheduler.step()
        
        count += 1

    if (count % 100 == 0):
        logging.info("Iteration: {}, Loss: {}%, epoch: {}, Learning Rate: {}".format(count, loss.data,epoch, scheduler.get_last_lr()))
        #scheduler.get_last_lr(), Learning Rate: {},
    if (count % 1000 == 0 and epoch != 0) or epoch ==3999 :
        torch.save(model.state_dict(), os.path.join("/notebooks/D3SN_Files/models_JPEG_05", 'D3SN_JPEG_05_' + str(count) + '.pth'))
        torch.save(optimizer.state_dict(), os.path.join("/notebooks/D3SN_Files/models_JPEG_05", 'Optimizer_' + str(count) + '.pth'))

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]