<a href="https://colab.research.google.com/github/amanshenoy/image-super-resolution/blob/master/training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# TensorBoard for visual logging

In [0]:
# Load the TensorBoard notebook extension and show the dashboard for the `runs` directory.
# train() creates SummaryWriter() without a log_dir, so it presumably writes under ./runs
# (PyTorch's default location) - confirm if the dashboard stays empty.
%load_ext tensorboard 
%tensorboard --logdir runs

Reusing TensorBoard on port 6006 (pid 3486), started 2:23:39 ago. (Use '!kill 3486' to kill it.)

<IPython.core.display.Javascript object>

# Imports

In [0]:
import torchvision, torchvision.transforms as transforms
import torch, torch.nn as nn
import random, subprocess, os
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
from google.colab import drive

# Mount Google Drive: checkpoints and visualization frames are written under /content/drive/
drive.mount("/content/drive/")

# Select the first CUDA GPU when one is visible, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


# Dataloaders

In [0]:
def load_loader_stl(crop_size: int = 33, batch_size: int = 128, num_workers: int = 1, scale: float = 2.0):
    """
    Build a pair of aligned STL-10 dataloaders: degraded (low resolution) and original (high resolution).

    input : crop_size   -> side of the square sub images the model is trained on
            batch_size  -> number of source images per batch (each yields 10 crops)
            num_workers -> worker processes used by each DataLoader
            scale       -> factor by which the image is bicubically downscaled before being
                           bicubically upscaled back to 96 px to produce the low resolution input
    output: (low_res_loader, high_res_loader), two iterables that walk the dataset in the same
            order, so zipping them pairs every degraded crop with its original crop

    Augmentation Schemes: augmentation is done inside the dataloader transforms with TenCrop.
                          Every image yields 5 crops (center + 4 corners) plus the horizontal
                          flip of each. TenCrop returns a tuple, which is stacked into a single
                          tensor by a Lambda transform and flattened into the batch by the
                          training loop.
    """
    tensorize = transforms.ToTensor()

    # Deterministic TenCrop -> tensor of stacked crops (10, C, crop_size, crop_size)
    ten_crop_to_tensor = transforms.Compose([
        transforms.TenCrop(crop_size),
        transforms.Lambda(lambda crops: torch.stack([tensorize(piece) for piece in crops])),
    ])

    # Degrade with bicubic (interpolation = 3) down/up sampling, then crop exactly like the originals
    degrade_then_crop = transforms.Compose([
        transforms.Resize(int(96 / scale), interpolation=3),
        transforms.Resize(96, interpolation=3),
        ten_crop_to_tensor,
    ])

    def make_loader(transform, download):
        # Not shuffled: both loaders must stay aligned, and progress is checked on the same examples
        stl10 = torchvision.datasets.STL10('.', transform=transform, download=download)
        return torch.utils.data.DataLoader(stl10, batch_size=batch_size, num_workers=num_workers, shuffle=False)

    # The high resolution set is created first so the download happens before the second, download-free, set
    high_res_loader = make_loader(ten_crop_to_tensor, True)
    low_res_loader = make_loader(degrade_then_crop, False)
    return low_res_loader, high_res_loader

# Models used

In [0]:
class SuperResolution(nn.Module):
    """
    Three-layer convolutional super-resolution network, following the architecture in the paper.

    The chosen configuration for successive filter sizes is 9-5-5.
    The chosen configuration for successive filter depths is 128-64(-3).
    Every convolution is zero-padded by ``kernel_size // 2``, so for odd kernel sizes the output
    has the same spatial size as the input. The network therefore refines an input that has
    already been upscaled to the target resolution (the dataloader does this with bicubic resizing).

    input : sub_image    -> side of the square training sub images (not used by the layers;
                            kept for interface compatibility)
            spatial      -> kernel sizes of the three convolutions (any indexable sequence)
            filter       -> output depths of the first two convolutions (any indexable sequence)
            num_channels -> number of image channels at input and output
    forward output: (reconstruction, features). Here features is the ReLU output of the second
                    layer, returned so that intermediate channels can be visualized.
    """
    # Tuple defaults: immutable, so they cannot be shared and mutated across instances.
    # (`filter` shadows the builtin but is kept as a public keyword argument name.)
    def __init__(self, sub_image: int = 33, spatial: tuple = (9, 5, 5), filter: tuple = (128, 64), num_channels: int = 3):
        super().__init__()
        self.layer_1 = nn.Conv2d(num_channels, filter[0], spatial[0], padding = spatial[0] // 2)
        self.layer_2 = nn.Conv2d(filter[0], filter[1], spatial[1], padding = spatial[1] // 2)
        self.layer_3 = nn.Conv2d(filter[1], num_channels, spatial[2], padding = spatial[2] // 2)
        self.relu = nn.ReLU()

    def forward(self, image_batch):
        """Return (reconstruction, second-layer ReLU features) for an (N, C, H, W) batch."""
        x = self.relu(self.layer_1(image_batch))
        features = self.relu(self.layer_2(x))
        reconstruction = self.layer_3(features)
        return reconstruction, features

# Training loop

In [0]:
def train(epochs: int = 500, learning_rate: float = 1e-04, output_dir: str = '/content/drive/My Drive/isr'):
    """
    Train the SuperResolution model on STL-10 while logging progress to tensorboard and Drive.

    Logging at every backward step:
      - the MSE loss;
      - the PSNR of the reconstruction;
      - the PSNR of the bicubic input (for comparison);
      - a "low res | high res | reconstruction" image.

    Logging after every epoch, on the last batch:
      - a fixed patch and 8 intermediate feature channels are logged;
      - those images, plus two fixed reconstructions, are saved as PNG frames under
        ``output_dir`` (for creating animated gifs);
      - the model weights are checkpointed.

    We achieve a lower PSNR than the paper with the same configuration because we train for far
    fewer steps (they train for 10^8 backward steps). Complete training according to the paper
    was simply infeasible, given that the idle time of a colab notebook is only 90 minutes.

    input : epochs        -> number of passes over the dataset
            learning_rate -> Adam learning rate
            output_dir    -> directory for frames and checkpoints (on the mounted Drive by default)
    output: the trained model
    """
    # Initialize model, data, writer, optimizer, loss, and backward count
    low_res_loader, high_res_loader = load_loader_stl()
    model = SuperResolution()

    # Uncomment the line below to continue training from a saved checkpoint. Tensorboard graphs
    # make more sense if training trends are seen from scratch.
    # model.load_state_dict(torch.load(os.path.join(output_dir, 'isr_best.pth')))
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), learning_rate)
    loss_fn = nn.MSELoss()
    writer = SummaryWriter()
    n = 0

    # Loop-invariant image helpers, built once rather than on every step
    to_pil = torchvision.transforms.ToPILImage()
    to_tensor = torchvision.transforms.ToTensor()
    resize = torchvision.transforms.Resize((48 * 7, 144 * 7))
    other_resize = torchvision.transforms.Resize((48 * 5, 48 * 5))

    def to_image(tensor):
        """Detach, move to CPU, clamp to [0, 1] and convert a CHW or HW tensor to a PIL image."""
        # Clamp BEFORE the PIL conversion. ToPILImage multiplies float tensors by 255 and casts
        # them to uint8, so out-of-range network values would overflow instead of saturating.
        return to_pil(tensor.detach().cpu().clamp(0, 1))

    ind = 4  # Patch of every batch shown in the per-step comparison image
    index = 30  # Patch of the last batch to visualize on, up till 80 (size of the remnant batch)
    channels_to_visualize = [1, 2, 3, 4, 5, 6, 7, 8]  # Channel numbers out of 64 to visualize

    # Create every output directory up front, including output_dir itself (the patch image is saved there)
    for sub_dir in ['r1', 'r2'] + ['channel_{}'.format(feature) for feature in channels_to_visualize]:
        os.makedirs(os.path.join(output_dir, sub_dir), exist_ok=True)

    for epoch in tqdm(range(epochs), desc="Training", ncols=120):
        for low_res, high_res in zip(low_res_loader, high_res_loader):

            # Element [0] is the TenCrop image tensor (batch_size, 10, c, h, w);
            # fold the crops into the batch to get a trainable shape of (batch_size * 10, c, h, w)
            low_res_batch, high_res_batch = low_res[0], high_res[0]
            _, _, c, h, w = low_res_batch.size()
            low_res_batch = low_res_batch.to(device).view(-1, c, h, w)
            high_res_batch = high_res_batch.to(device).view(-1, c, h, w)
            reconstructed_batch, intermediate = model(low_res_batch)

            # Backward step on the reconstruction MSE (prediction first, target second).
            # The MSE of the bicubic network input is only tracked as a baseline.
            loss = loss_fn(reconstructed_batch, high_res_batch)
            loss_to_compare = loss_fn(low_res_batch, high_res_batch)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            # PSNR with a peak value of 1, since the loader produces images in [0, 1]
            n += 1
            psnr = 10 * torch.log10(1 / loss.detach())
            psnr_tc = 10 * torch.log10(1 / loss_to_compare)

            # Write relevant scalars and a comparative image on tensorboard
            writer.add_scalar("MSE loss", loss.item() * (255 ** 2), n)
            writer.add_scalar("PSNR of Reconstruction", psnr.item(), n)
            writer.add_scalar("PSNR of BiCubic Interpolation (For comparision)", psnr_tc.item(), n)
            comparison = torch.cat((low_res_batch[ind], high_res_batch[ind], reconstructed_batch[ind]), dim=2)
            writer.add_image("Low Resolution Image | High Resolution Image | Reconstructed Image",
                             to_tensor(resize(to_image(comparison))), n, dataformats='CHW')

        # Visualize the chosen patch of the last batch and write it to drive as well
        patch = other_resize(to_image(high_res_batch[index]))
        writer.add_image("Image Patch {}".format(index), to_tensor(patch), n, dataformats='CHW')
        patch.save(os.path.join(output_dir, 'patch_{}.png'.format(index)))

        # Write the progress of training on two standard examples - patches 25 and 30 of the last batch
        for sub_dir, sample in (('r1', 25), ('r2', 30)):
            frame = other_resize(to_image(reconstructed_batch[sample]))
            frame.save(os.path.join(output_dir, sub_dir, 'frame_{}.png'.format(epoch)))

        # Intermediate (second-layer) feature maps of the chosen patch, as frames for animated gifs
        for feature in channels_to_visualize:
            frame = other_resize(to_image(intermediate[index, feature, :, :]))
            writer.add_image("Channel {}".format(feature), to_tensor(frame), n, dataformats='CHW')
            frame.save(os.path.join(output_dir, 'channel_{}'.format(feature), 'frame_{}.png'.format(epoch)))

        # Save the model every epoch (overwrites the previous checkpoint)
        torch.save(model.state_dict(), os.path.join(output_dir, 'isr_best_2.pth'))

    # Flush pending events and release the event file
    writer.close()
    return model

# Run the full training loop and keep the trained model returned by train()
model = train()