In [1]:
import torch
from torch.utils.data import Dataset
import json
import os
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from PIL import Image, ImageFilter
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt
import torch
from torch import nn
import torchvision
import math
import torch.backends.cudnn as cudnn
import torch
from torch import nn
from torchvision.utils import make_grid
from torch.utils.tensorboard import SummaryWriter
import os
from math import log10

In [2]:
# defining the convolutional block
class ConvolutionalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, batch_norm=False, activation=None):
        super(ConvolutionalBlock, self).__init__()
        # check if the activation is there or not
        if activation is not None:
            activation = activation.lower()
            assert activation in {'prelu', 'leakyrelu', 'tanh'}
        # initialize an empty list, which will store the layers of the model
        layers = list()
        # append the first layer of Conv
        layers.append(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                      padding=kernel_size // 2))
        # add the BatchNorm
        if batch_norm is True:
            layers.append(nn.BatchNorm2d(num_features=out_channels))
        # batchNorm is followed by an activation layer
        if activation == 'prelu':
            layers.append(nn.PReLU())
        elif activation == 'leakyrelu':
            layers.append(nn.LeakyReLU(0.2))
        elif activation == 'tanh':
            layers.append(nn.Tanh())
        # create a sequential model from the list of layers.
        self.conv_block = nn.Sequential(*layers)
    
    # define the forward function
    def forward(self, input):
        output = self.conv_block(input)
        return output


# SubPixel Convolution layer
class SubPixelConvolutionalBlock(nn.Module):
    # pixel shuffle layer is introduced to increase the dimension at the end.
    def __init__(self, kernel_size=3, n_channels=64, scaling_factor=2):
        super(SubPixelConvolutionalBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels=n_channels, out_channels=n_channels * (scaling_factor ** 2),
                              kernel_size=kernel_size, padding=kernel_size // 2)
        # r2*c*h*w --> r*c*rh*rw
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=scaling_factor)
        self.prelu = nn.PReLU()

    # define the forward function
    def forward(self, input):
        output = self.conv(input)
        output = self.pixel_shuffle(output)  
        output = self.prelu(output) 

        return output


# define the residual block
class ResidualBlock(nn.Module):
    def __init__(self, kernel_size=3, n_channels=64):
        super(ResidualBlock, self).__init__()
        self.conv_block1 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size,
                                              batch_norm=True, activation='PReLu')

        self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size,
                                              batch_norm=True, activation=None)

    def forward(self, input):
        residual = input  # (N, n_channels, w, h)
        output = self.conv_block1(input)  # (N, n_channels, w, h)
        output = self.conv_block2(output)  # (N, n_channels, w, h)
        output = output + residual  # (N, n_channels, w, h)

        return output


class SRResNet(nn.Module):
    # defining the entire structure of the SRResNet
    def __init__(self, large_kernel_size=9, small_kernel_size=3, n_channels=64, n_blocks=16, scaling_factor=4):
        super(SRResNet, self).__init__()
        scaling_factor = int(scaling_factor)
        assert scaling_factor in {2, 4, 8}

        self.conv_block1 = ConvolutionalBlock(in_channels=3, out_channels=n_channels, kernel_size=large_kernel_size,
                                              batch_norm=False, activation='PReLu')

        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(kernel_size=small_kernel_size, n_channels=n_channels) for i in range(n_blocks)])

        self.conv_block2 = ConvolutionalBlock(in_channels=n_channels, out_channels=n_channels,
                                              kernel_size=small_kernel_size,
                                              batch_norm=True, activation=None)

        n_subpixel_convolution_blocks = int(math.log2(scaling_factor))
        self.subpixel_convolutional_blocks = nn.Sequential(
            *[SubPixelConvolutionalBlock(kernel_size=small_kernel_size, n_channels=64, scaling_factor=4) for i
              in range(n_subpixel_convolution_blocks)])

        self.conv_block3 = ConvolutionalBlock(in_channels=n_channels, out_channels=3, kernel_size=large_kernel_size,
                                              batch_norm=False, activation='Tanh')

    def forward(self, lr_imgs):
        output = self.conv_block1(lr_imgs)  # (16, 3, 24, 24)
        residual = output  # (16, 64, 24, 24)
        output = self.residual_blocks(output)  # (16, 64, 24, 24)
        output = self.conv_block2(output)  # (16, 64, 24, 24)
        output = output + residual  # (16, 64, 24, 24)
        output = self.subpixel_convolutional_blocks(output)  # (16, 64, 24 * 2, 24 * 2)
        sr_imgs = self.conv_block3(output)  # (16, 3, 24 * 2, 24 * 2)

        return sr_imgs


In [5]:

crop_size = 600      
scaling_factor = 2 


large_kernel_size = 9   
small_kernel_size = 3   
n_channels = 64         
n_blocks = 16           


checkpoint = '/content/rsresnet.pth'  
batch_size = 2    
start_epoch = 1     
epochs = 20      
workers = 1        
lr = 1e-4           

pre_psnr=0

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

ngpu = 1

cudnn.benchmark = True 


model = torch.load(checkpoint).to(device)
print('Loading of previous model succeded')
    
# defining the optimiser of the model
optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model.parameters()),lr=lr)
# LR scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
# putting the model on the device
model = model.to(device)


Loading of previous model succeded


In [6]:
print(model)

SRResNet(
  (conv_block1): ConvolutionalBlock(
    (conv_block): Sequential(
      (0): Conv2d(3, 64, kernel_size=(9, 9), stride=(1, 1), padding=(4, 4))
      (1): PReLU(num_parameters=1)
    )
  )
  (residual_blocks): Sequential(
    (0): ResidualBlock(
      (conv_block1): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): PReLU(num_parameters=1)
        )
      )
      (conv_block2): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (1): ResidualBlock(
      (conv_block1): ConvolutionalBlock(
        (conv_block): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1)

In [7]:
class SRDataset(Dataset):
    def __init__(self, data_path, crop_size, scaling_factor):
        self.data_path=data_path
        self.crop_size = int(crop_size)
        self.scaling_factor = int(scaling_factor)
        self.images_path=[]

        
        for name in os.listdir(self.data_path):
            self.images_path.append(os.path.join(self.data_path,name))

        # transformation common for both
        self.pre_trans=transforms.Compose([
                                transforms.CenterCrop(self.crop_size),
                                transforms.RandomHorizontalFlip(p=0.5),
                                transforms.RandomVerticalFlip(p=0.5)
                                ])
        
        # transformation only for the input class [lr images]
        self.input_transform = transforms.Compose([
                                transforms.Resize(self.crop_size//self.scaling_factor),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5],std=[0.5])
                                ])
        
        # transformation only for the target class [hr images]
        self.target_transform = transforms.Compose([
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.5],std=[0.5]),
                                ])

    # get an element based on index value
    # get the lr image and the hr image
    def __getitem__(self, i):
        img = Image.open(self.images_path[i], mode='r')
        img = img.convert('RGB')
        img=self.pre_trans(img)

        lr_img = self.input_transform(img)
        hr_img = self.target_transform(img.copy())
        
        return lr_img, hr_img

    # get the length of the input image
    def __len__(self):
        return len(self.images_path)

In [8]:
!unzip Set5.zip

Archive:  Set5.zip
  inflating: Set5/baby.png           
  inflating: Set5/bird.png           
  inflating: Set5/butterfly.png      
  inflating: Set5/head.png           
  inflating: Set5/woman.png          


In [9]:
path = "/content/Set5"

In [27]:
 
dataset = SRDataset(path, crop_size, scaling_factor)

loader = torch.utils.data.DataLoader(dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=True)


model.eval()  
test_loss=0
all_psnr = 0
n_iter = len(loader)

criterion = nn.MSELoss()

for i, (lr_imgs, hr_imgs) in enumerate(loader):
    with torch.no_grad():
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        sr_imgs = model(lr_imgs)
        

        loss = criterion(sr_imgs, hr_imgs)

        psnr = 10 * log10(1 / loss.item())

        all_psnr+=psnr
        test_loss+=loss.item()
    

epoch_loss_test=test_loss/n_iter
epoch_psnr=all_psnr / n_iter


print(f"Average PSNR Set 5 dataset, Model with Batchnorm in Residual blocks: {epoch_psnr} dB.")

Average PSNR Set 5 dataset, Model with Batchnorm in Residual blocks: 30.776351084623087 dB.


In [13]:
!unzip Set14.zip


Archive:  Set14.zip
  inflating: Set14/baboon.png        
  inflating: Set14/barbara.png       
  inflating: Set14/bridge.png        
  inflating: Set14/coastguard.png    
  inflating: Set14/comic.png         
  inflating: Set14/face.png          
  inflating: Set14/flowers.png       
  inflating: Set14/foreman.png       
  inflating: Set14/lenna.png         
  inflating: Set14/man.png           
  inflating: Set14/monarch.png       
  inflating: Set14/pepper.png        
  inflating: Set14/ppt3.png          
  inflating: Set14/zebra.png         


In [16]:
set14path = "/content/Set14"
dataset = SRDataset(set14path, crop_size, scaling_factor)

loader = torch.utils.data.DataLoader(dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=True)


model.eval()  
test_loss=0
all_psnr = 0
n_iter = len(loader)

criterion = nn.MSELoss()

for i, (lr_imgs, hr_imgs) in enumerate(loader):
    with torch.no_grad():
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        sr_imgs = model(lr_imgs)




        loss = criterion(sr_imgs, hr_imgs)

        psnr = 10 * log10(1 / loss.item())

        all_psnr+=psnr
        test_loss+=loss.item()
    

epoch_loss_test=test_loss/n_iter
epoch_psnr=all_psnr / n_iter


print(f"Average PSNR Set 14 dataset, Model with Batchnorm in Residual blocks: {epoch_psnr} dB.")

Average PSNR Set 14 dataset, Model with Batchnorm in Residual blocks: 24.702948793059193 dB.


Model without Batchnorm

In [17]:
model_wo_bn_chkpt = "/content/resnet_default_False.pth"


crop_size = 600      
scaling_factor = 2 


large_kernel_size = 9   
small_kernel_size = 3   
n_channels = 64         
n_blocks = 16           


  
batch_size = 2    
start_epoch = 1     
epochs = 20      
workers = 1        
lr = 1e-4           

pre_psnr=0

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

ngpu = 1

cudnn.benchmark = True 


model_wo = torch.load(model_wo_bn_chkpt).to(device)
print('Loading of previous model succeded')
    
# defining the optimiser of the model
optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, model_wo.parameters()),lr=lr)
# LR scheduler
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.9)
# putting the model on the device
model_wo = model_wo.to(device)


Loading of previous model succeded


In [19]:
 
dataset = SRDataset(path, crop_size, scaling_factor)

loader = torch.utils.data.DataLoader(dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=True)


model_wo.eval()  
test_loss=0
all_psnr = 0
n_iter = len(loader)

criterion = nn.MSELoss()

for i, (lr_imgs, hr_imgs) in enumerate(loader):
    with torch.no_grad():
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        sr_imgs = model_wo(lr_imgs)




        loss = criterion(sr_imgs, hr_imgs)

        psnr = 10 * log10(1 / loss.item())

        all_psnr+=psnr
        test_loss+=loss.item()
    

epoch_loss_test=test_loss/n_iter
epoch_psnr=all_psnr / n_iter


print(f"Average PSNR Set 5 dataset, Model with Batchnorm in Residual blocks: {epoch_psnr} dB.")

Average PSNR Set 5 dataset, Model with Batchnorm in Residual blocks: 30.685715845418386 dB.


In [20]:



set14path = "/content/Set14"
dataset = SRDataset(set14path, crop_size, scaling_factor)

loader = torch.utils.data.DataLoader(dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
    pin_memory=True)


model_wo.eval()  
test_loss=0
all_psnr = 0
n_iter = len(loader)

criterion = nn.MSELoss()

for i, (lr_imgs, hr_imgs) in enumerate(loader):
    with torch.no_grad():
        lr_imgs = lr_imgs.to(device)
        hr_imgs = hr_imgs.to(device)

        sr_imgs = model_wo(lr_imgs)





        loss = criterion(sr_imgs, hr_imgs)

        psnr = 10 * log10(1 / loss.item())

        all_psnr+=psnr
        test_loss+=loss.item()
    

epoch_loss_test=test_loss/n_iter
epoch_psnr=all_psnr / n_iter


print(f"Average PSNR Set 14 dataset, Model with Batchnorm in Residual blocks: {epoch_psnr} dB.")

Average PSNR Set 14 dataset, Model with Batchnorm in Residual blocks: 25.981879256546158 dB.
