In [1]:
import warnings
import random
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
# from tqdm.notebook import tqdm
import os

import torch.nn as nn
import torch
from torch.utils.data import DataLoader, Dataset
import tifffile
import torch.optim as optim
import torchvision
from torchvision.io import read_image
from torchvision.utils import make_grid
import torchvision.transforms as transforms
import torchvision.transforms.functional as tvf
from torchmetrics import PeakSignalNoiseRatio

from model import *
from model_reference import *
from config import *

In [2]:
class DIV2K_Data(Dataset):
    def __init__(self, csv_path, is_transform=False):
        self.data = pd.read_csv(csv_path)
        self.is_transform = is_transform
    
    def __len__(self):
        return self.data.shape[0]

    def get_difference(self, tensor_image_1, tensor_image_2):
        image_1 = tensor_image_1.detach().numpy()
        image_2 = tensor_image_2.detach().numpy()

        difference = image_1 - image_2

        return torch.from_numpy(difference)
    
    def __getitem__(self, index):
        lr_image = read_image(self.data.iloc[index, 2])
        lr_height, lr_width = tvf.get_image_size(lr_image)
        resize = transforms.Resize((lr_width*2, lr_height*2))
        
        # Normalization using ImageNet measures of center and spread.
        normalize = transforms.Normalize(mean=torch.Tensor([0.485, 0.456, 0.406]), 
                                          std=torch.Tensor([0.229, 0.224, 0.225]), 
                                          inplace=True)
        # tensorify = transforms.ToTensor()
        lr_interpolated_image = resize(lr_image)
        hr_image = read_image(self.data.iloc[index, 5])
        if self.is_transform:
            if random.random() > 0.5:
                angle = random.randint(0, 180)
                lr_interpolated_image = tvf.rotate(lr_interpolated_image, angle)
                hr_image = tvf.rotate(hr_image, angle)
                
            if random.random() > 0.5:
                lr_interpolated_image = tvf.hflip(lr_interpolated_image)
                hr_image = tvf.hflip(hr_image)
            
            if random.random() > 0.5:
                lr_interpolated_image = tvf.vflip(lr_interpolated_image)
                hr_image = tvf.vflip(hr_image)
        
        lr_interpolated_image = normalize(lr_interpolated_image.type(torch.float32))
        hr_image = normalize(hr_image.type(torch.float32))
        return lr_interpolated_image, hr_image, self.get_difference(hr_image, lr_interpolated_image).type(torch.float32)

In [3]:
# https://pytorch.org/vision/stable/auto_examples/plot_visualization_utils.html
def plot_multiple_images(imgs, title):
    if not isinstance(imgs, list):
        imgs = [imgs]
    fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(18, 8))
    
    for i, img in enumerate(imgs):
        img = img.detach()
        img = tvf.to_pil_image(img)
        axs[0, i].imshow(np.asarray(img))
        axs[0, i].set(title=title)
        # axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

In [4]:
def visualize_random_number_of_images(data_torch, n_rand_images = 5):
    # Show n random images from the training set
    rand_indices = [random.randint(0, len(data_torch)) for _ in range(n_rand_images)]

    for i in rand_indices:
        sample_lr, sample_hr, residual_diff = data_torch[i]
        print(i, sample_lr.shape, sample_hr.shape, residual_diff.shape)
        grid = make_grid([sample_lr, sample_hr, residual_diff])
        plot_multiple_images(grid, "Low Resolution image - " + "High Resolution image - " + "Difference in images")

In [5]:
class Trainer:
    def __init__(self, train_dataloader, valid_dataloader, config):
        self.patience = 5
        self.config = config
        self.model = VDSR(in_channels=self.config.INPUT_CHANNELS, 
                              out_channels=self.config.OUTPUT_CHANNELS, )
        self.loss_function = self.config.LOSS_MSE
        self.batch_size = self.config.BATCH_SIZE
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.device = self.config.DEVICE
        self.epochs = self.config.EPOCHS
        self.lr = self.config.LEARNING_RATE
        self.optim_step_size = self.config.OPTIM_STEP_SIZE
        self.optim_gamma = self.config.OPTIM_GAMMA
        self.grad_clip_max_norm = self.config.GRAD_CLIP_MAX_NORM
        self.momentum=self.config.MOMENTUM
        self.weight_decay=self.config.WEIGHT_DECAY
        self.evaluation_metric = self.config.EVALUATION_METRIC
        self.device = self.config.DEVICE
        self.val_for_early_stopping = 9999999 #early stopping
        
        if not os.path.isdir(self.config.MODEL_SAVEPATH):
            os.makedirs(self.config.MODEL_SAVEPATH)
        
        self.log = pd.DataFrame(columns=["model_name", "train_loss", "train_PSNR", "valid_loss", "valid_PSNR"])
        self.optimizer = optim.SGD(params=self.model.parameters(), 
                                   lr=self.lr, 
                                   momentum=self.momentum,
                                   weight_decay=self.weight_decay)
        self.optim_scheduler = optim.lr_scheduler.StepLR(self.optimizer, 
                                      step_size=self.optim_step_size,
                                      gamma=self.optim_gamma)
    
    def calculate_metrics(self, dataloader):
        self.model.eval()
        total_loss = 0
        total_performance = 0
        with torch.no_grad():
            for lr_image, hr_image, _ in tqdm(dataloader, total=len(dataloader)):
                low_res_image = lr_image.to(self.device)
                high_res_image = hr_image.to(self.device)
                out = self.model(low_res_image)
                loss = self.loss_function(out.data, high_res_image)
                total_loss += loss.item()
                total_performance += self.evaluation_metric(out.data.to("cpu"), high_res_image.cpu())
        return total_performance/len(dataloader), total_loss/len(dataloader)
    
    
    def early_stopping(self, val_loss):
        if val_loss < self.val_for_early_stopping:
            self.val_for_early_stopping = val_loss
            return True
        else:
            self.patience -= 1
            return False
    
    
    def fit(self):
        print("-"*25, "THE MODEL BASED ON" , self.config.ARCHITECTURE, "BEGINS TRAINING", "-"*25)
        print(f"TRAINING ON {self.config.DEVICE.upper()}")
        
        best_loss = 9e+6
        
        for epoch in range(self.epochs):
            self.model.train()
            self.model.to(self.device)
            log_file = open(f"log_epoch_{epoch}.txt", "a")
            log_messages = ""
            
            performance_train = 0
            loss_train = 0
            b_num = 0
            
            for lr_image, hr_image, res_diff in tqdm(self.train_dataloader, total=len(self.train_dataloader)):
                
                low_res_batch = lr_image.to(self.device)
                # print(low_res_batch.dtype)
                high_res_image = hr_image.to(self.device)
                self.optimizer.zero_grad()
                output = self.model(low_res_batch)
                loss = self.loss_function(output, high_res_image)
                loss.backward()
                self.optimizer.step()
                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip_max_norm)
                loss_train += loss.item()
                performance_train += self.evaluation_metric(output.data.to("cpu"), hr_image.to("cpu"))
                log_messages += f"LOSS, Performance at batch {b_num} after epoch {epoch}: {loss_train} : {performance_train}\n"
                b_num += 1
            
            self.optim_scheduler.step()
            
            log_file.write(log_messages)
            log_file.close()
            
            performance_valid, loss_valid = self.calculate_metrics(self.valid_dataloader)
            performance_train /= len(self.train_dataloader)
            loss_train /= len(self.train_dataloader)
            
            print("-"*10, "STATUS AT EPOCH NO.", epoch, "-"*10)
            print(f"Train performance : {performance_train}, Train loss {loss_train}")
            print(f"Valid performance : {performance_valid}, valid loss {loss_valid}")
            
            self.log.loc[epoch,:] = [f"{self.config.ARCHITECTURE}_{self.config.BATCH_SIZE}.pth", 
                                     f"{loss_train}",
                                     f"{performance_train}",
                                     f"{loss_valid}",
                                     f"{performance_valid}"]
            self.log.to_csv(self.config.MODEL_SAVEPATH + 
                            f"/{self.config.ARCHITECTURE}_{self.config.BATCH_SIZE}_valid_{epoch}.csv",index=False)
            
            if self.patience >= 0 and self.early_stopping(loss_valid):
                print(f"Saving model at Epoch: {epoch}")
                torch.save(self.model.state_dict(), self.config.MODEL_SAVEPATH + "/" +
                           f"{self.config.ARCHITECTURE}_{self.config.BATCH_SIZE}.pth")
                self.patience = 5
            
            if self.patience <= 0:
                print("-"*10, "EARLY STOPPING", "-"*10)
                print("Training terminated, no improvement in valid loss")
                break
                

In [6]:
def main(config, n_images_to_viz=0):
    train_data = DIV2K_Data(csv_path=config.TRAIN_PATH, is_transform=True)
    valid_data = DIV2K_Data(csv_path=config.VALID_PATH, is_transform=False)
    if(n_images_to_viz):
        visualize_random_number_of_images(train_data, n_images_to_viz)
    train_dataloader = DataLoader(train_data, batch_size=vdsr_config.BATCH_SIZE, shuffle=True)
    valid_dataloader = DataLoader(valid_data, batch_size=vdsr_config.BATCH_SIZE, shuffle=True)
    
    trainer = Trainer(train_dataloader, valid_dataloader, config)
    trainer.fit()

In [7]:
vdsr_config = Configuration()
print(vdsr_config)

----------------------------------------CONFIGURATION DETAILS----------------------------------------
Architecture : VDSR-Net
Batch Size : 2
Number of Input Channels : 3
Number of Output Channels : 3
Depth of the network : 6
Training platform : CUDA
Number of epochs : 80
Gradient clipping with max norm : 0.01
Loss Function : MSE Loss
Performance metric : Peak Signal To Noise Ratio (PSNR)
----------------------------------------------------------------------------------------------------


In [8]:
main(vdsr_config)

------------------------- THE MODEL BASED ON VDSR-Net BEGINS TRAINING -------------------------
TRAINING ON CUDA


100%|██████████| 215/215 [09:37<00:00,  2.69s/it]
100%|██████████| 27/27 [00:46<00:00,  1.72s/it]


---------- STATUS AT EPOCH NO. 0 ----------
Train performance : nan, Train loss nan
Valid performance : nan, valid loss nan


 39%|███▉      | 84/215 [02:41<04:12,  1.93s/it]


KeyboardInterrupt: 