In [None]:
!pip install torch_snippets

In [None]:
!pip install pytorch_ssim 

In [None]:
# -*- coding: utf-8 -*-
"""SRGAN

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1CsyvA790GM30uX1J3wnSnfoTGSQLXrr9
"""

import torch
import math
import torch.nn as nn
from torchvision import transforms as T
from torchvision import models
from torchvision.transforms import InterpolationMode
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, lr_scheduler
from torch.nn import functional as F
from tqdm import tqdm
from glob import glob
from PIL import Image
from torch_snippets import *
from torchvision.models import vgg19, VGG19_Weights
from torchvision.models.feature_extraction import create_feature_extractor
from torchsummary import summary
from torchvision.utils import save_image, make_grid
import matplotlib.pyplot as plt
import pytorch_ssim
import numpy as np
from math import log10

# Target device shared by the datasets, models and losses in this file:
# the GPU when CUDA is available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

class SuperResolutionDatasetTrain(Dataset):
  """Training pairs built on the fly from a folder of PNG images.

  Every item is a random 96x96 crop of one image together with the same crop
  shrunk 4x (to 24x24) using bicubic interpolation.
  """

  def __init__(self, data_dir):
    """Index every ``*.png`` directly inside ``data_dir`` and build the transforms.

    Args:
      data_dir: directory that holds the high-resolution training images.
    """
    super().__init__()
    crop_size = 96
    scale = 4
    self.image_files = glob(data_dir + "/*.png")
    # Applied to the high-res *tensor*, hence the round trip through PIL.
    self.lr_transform = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((crop_size // scale, crop_size // scale), InterpolationMode.BICUBIC),
            T.ToTensor(),
        ])
    self.hr_transform = T.Compose(
        [
            # pad_if_needed keeps images smaller than the crop from raising;
            # the crop already has the final size, so no extra resize is needed.
            T.RandomCrop(crop_size, pad_if_needed=True),
            T.ToTensor(),
        ]
    )

  def __len__(self):
    """Number of PNG files found in the data directory."""
    return len(self.image_files)

  def __getitem__(self, index):
    """Return ``(hr_image, lr_image)`` tensors for the file at ``index``.

    Both tensors are moved to the module-level ``device`` before being returned.
    NOTE(review): doing the transfer here presumably means the DataLoader has
    to stay on ``num_workers=0`` when ``device`` is CUDA - confirm before
    enabling worker processes.
    """
    img = Image.open(self.image_files[index]).convert("RGB")
    hr_image = self.hr_transform(img)
    lr_image = self.lr_transform(hr_image)

    return hr_image.to(device), lr_image.to(device)

class SuperResolutionDatasetTest(Dataset):
  """Evaluation pairs of (ground-truth, low-resolution) images.

  Files are paired by position after sorting both directory listings.  When no
  low-resolution directory is given, the low-resolution image is derived from
  the ground truth by bicubic down-sampling, so the dataset can be built from
  a single directory.
  """

  def __init__(self, gt_data_dir, lr_data_dir=None, upscale_factor=4):
    """Collect the sorted file lists.

    Args:
      gt_data_dir: directory with the ground-truth (high-resolution) images.
      lr_data_dir: directory with the matching low-resolution images, or
        ``None`` to down-sample the ground truth instead.
      upscale_factor: down-sampling factor used only when ``lr_data_dir`` is
        ``None``.
    """
    super().__init__()
    self.gt_data_dir = sorted(glob(gt_data_dir + "/*"))
    self.lr_data_dir = None if lr_data_dir is None else sorted(glob(lr_data_dir + "/*"))
    self.upscale_factor = upscale_factor

  def __len__(self):
    """Number of ground-truth files."""
    return len(self.gt_data_dir)

  def __getitem__(self, index):
    """Return ``(gt_image, lr_image)`` tensors, both moved to ``device``."""
    gt_image = Image.open(self.gt_data_dir[index]).convert("RGB")

    if self.lr_data_dir is None:
      factor = self.upscale_factor
      width, height = gt_image.size
      # Trim to exact multiples of the factor so that an image up-scaled from
      # the low-res copy has the same size as the ground truth.
      width, height = width - width % factor, height - height % factor
      gt_image = gt_image.crop((0, 0, width, height))
      shrink = T.Resize((height // factor, width // factor), InterpolationMode.BICUBIC)
      lr_image = shrink(gt_image)
    else:
      lr_image = Image.open(self.lr_data_dir[index]).convert("RGB")

    to_tensor = T.ToTensor()
    gt_image = to_tensor(gt_image)
    lr_image = to_tensor(lr_image)

    return gt_image.to(device), lr_image.to(device)
    
class Residual_Block(nn.Module):
  """Conv-BN-PReLU-Conv-BN branch whose result is added back onto its input."""

  def __init__(self, channels):
    """Build the branch; ``channels`` is both the input and output width."""
    super().__init__()

    def conv3x3():
      # 3x3 kernel, stride 1, padding 1; bias is disabled on these convolutions.
      return nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1), bias=False)

    layers = [
        conv3x3(),
        nn.BatchNorm2d(channels),
        nn.PReLU(),
        conv3x3(),
        nn.BatchNorm2d(channels),
    ]
    self.block = nn.Sequential(*layers)

  def forward(self, x):
    """Return ``x`` plus the output of the convolutional branch."""
    return x + self.block(x)

class Upsample_Block(nn.Module):
  """Sub-pixel up-sampling: Conv -> BatchNorm -> PixelShuffle -> PReLU.

  The convolution widens the feature map to ``channels * upscale_factor ** 2``
  channels; PixelShuffle then trades those channels for an ``upscale_factor``
  times larger height and width, leaving ``channels`` channels.
  """

  def __init__(self, channels, upscale_factor):
    """
    Args:
      channels: number of input (and output) feature channels.
      upscale_factor: spatial enlargement factor of this block.
    """
    super().__init__()
    expanded = channels * upscale_factor ** 2
    self.block = nn.Sequential(
        nn.Conv2d(channels, expanded, kernel_size = (3,3), stride = 1, padding = 1),
        # BatchNorm and PixelShuffle must follow upscale_factor too, otherwise
        # any factor other than 2 produces a channel mismatch.
        nn.BatchNorm2d(expanded),
        nn.PixelShuffle(upscale_factor),
        nn.PReLU()
    )

  def forward(self, x):
    """Return ``x`` enlarged ``upscale_factor`` times along height and width."""
    return self.block(x)

class Generator(nn.Module):
  """Super-resolution generator.

  Pipeline: 9x9 conv + PReLU -> residual blocks -> 3x3 conv + BatchNorm, added
  to the output of the first stage -> 2x up-sampling blocks -> 9x9 conv back to
  3 channels, clamped to [0, 1].

  Args:
    upscale_factor: total spatial enlargement; must be a power of two because
      it is realised as a chain of 2x up-sampling blocks.
    num_res_blocks: number of residual blocks in the trunk.
  """

  def __init__(self, upscale_factor=4, num_res_blocks=16):
    super().__init__()
    if upscale_factor < 1:
      raise ValueError(f"upscale_factor must be a positive power of two, got {upscale_factor}")
    num_upsample = round(math.log2(upscale_factor))
    if 2 ** num_upsample != upscale_factor:
      raise ValueError(f"upscale_factor must be a positive power of two, got {upscale_factor}")

    self.pre_res = nn.Sequential(
        # TODO: check what impact removing padding has
        nn.Conv2d(3, 64, kernel_size = 9, stride = 1, padding = 4),
        nn.PReLU()
    )

    self.res_blocks = nn.Sequential(*(Residual_Block(64) for _ in range(num_res_blocks)))

    self.pos_res = nn.Sequential(
        nn.Conv2d(64, 64, (3,3), (1,1), (1,1)),
        nn.BatchNorm2d(64),
    )

    # Each block doubles height and width, so log2(upscale_factor) are needed.
    self.upsampling = nn.Sequential(*(Upsample_Block(64, 2) for _ in range(num_upsample)))

    self.pos_up = nn.Conv2d(64,3, kernel_size = (9,9), stride = (1,1), padding = (4,4))

    self._initialize_weights()

  def forward(self, x):
    """Map a low-resolution batch to a super-resolved batch with values in [0, 1]."""
    out1 = self.pre_res(x)
    out = self.res_blocks(out1)
    out2 = self.pos_res(out)
    # Long skip connection around the whole residual trunk.
    out = torch.add(out1, out2)
    out = self.upsampling(out)
    out = self.pos_up(out)

    out = torch.clamp_(out, 0.0, 1.0)

    return out

  def _initialize_weights(self) -> None:
    """Kaiming-normal conv weights, zero conv biases, BatchNorm scale set to 1."""
    for module in self.modules():
      if isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight)
        if module.bias is not None:
          nn.init.constant_(module.bias, 0)
      elif isinstance(module, nn.BatchNorm2d):
        nn.init.constant_(module.weight, 1)


class Discriminator(nn.Module):
    """Convolutional real/fake classifier that returns one raw logit per image.

    The feature extractor contains five stride-2 convolutions, so a square
    input of side ``s`` ends up as a ``512 x (s // 32) x (s // 32)`` map; for
    the default 96x96 input that is 512 x 3 x 3 = 4608 features.

    Args:
        input_size: side length of the (square) images the network will be
            fed.  It only determines the width of the first linear layer.

    Raises:
        ValueError: if ``input_size`` is smaller than 32, which would leave no
            spatial positions after the strided convolutions.
    """

    def __init__(self, input_size: int = 96) -> None:
        super(Discriminator, self).__init__()
        final_size = input_size // 32
        if final_size < 1:
            raise ValueError(f"input_size must be at least 32, got {input_size}")

        # (in_channels, out_channels, kernel, stride); every stride-2 entry
        # halves the spatial size, every stride-1 entry keeps it.
        conv_specs = [
            (64, 64, 4, 2),
            (64, 128, 3, 1),
            (128, 128, 4, 2),
            (128, 256, 3, 1),
            (256, 256, 4, 2),
            (256, 512, 3, 1),
            (512, 512, 4, 2),
            (512, 512, 3, 1),
            (512, 512, 4, 2),
        ]
        # The first convolution is the only one with a bias and without BatchNorm.
        layers = [
            nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=True),
            nn.LeakyReLU(0.2, True),
        ]
        for in_channels, out_channels, kernel, stride in conv_specs:
            layers += [
                nn.Conv2d(in_channels, out_channels, (kernel, kernel), (stride, stride), (1, 1), bias=False),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, True),
            ]
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Linear(512 * final_size * final_size, 100),
            nn.LeakyReLU(0.2, True),
            nn.Linear(100, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a ``(batch, 1)`` tensor of logits (no sigmoid is applied here)."""
        out = self.features(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out
    
class PerceptualLoss(nn.Module):
  """L1 distance between frozen VGG19 feature maps of two image batches."""

  def __init__(self):
    super(PerceptualLoss, self).__init__()

    # ``pretrained=True`` is deprecated in torchvision; the explicit weights
    # enum selects the same ImageNet checkpoint.
    vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
    # Keep only the first 35 layers of the feature extractor, in eval mode.
    loss_network = nn.Sequential(*list(vgg.features)[:35]).eval()
    # The loss network is a fixed feature extractor and must not be trained.
    for param in loss_network.parameters():
      param.requires_grad = False
    self.loss_network = loss_network
    self.l1_loss = nn.L1Loss()

  def forward(self, high_resolution, fake_high_resolution):
    """Return the mean absolute difference between the two feature maps.

    NOTE(review): the inputs are passed to VGG as-is; the ImageNet-trained
    weights presumably expect mean/std-normalised images - confirm whether
    normalisation should be applied first.
    """
    perception_loss = self.l1_loss(self.loss_network(high_resolution), self.loss_network(fake_high_resolution))
    return perception_loss

def define_loss():
  """Return ``(pixel_loss, content_loss, adversarial_loss)`` placed on ``device``.

  The criteria are built on the first call and reused afterwards.  ``train``
  calls this function for every batch, and constructing ``PerceptualLoss``
  instantiates a pretrained VGG19 each time, so rebuilding per call is
  needlessly expensive.
  """
  criteria = getattr(define_loss, "_criteria", None)
  if criteria is None:
    pixel_loss = nn.MSELoss().to(device)
    content_loss = PerceptualLoss().to(device)
    adversarial_loss = nn.BCEWithLogitsLoss().to(device)
    criteria = (pixel_loss, content_loss, adversarial_loss)
    # Cached on the function object so no extra module-level name is needed.
    define_loss._criteria = criteria

  return criteria

def get_optimizers(netD, netG):
  """Create one Adam optimizer (learning rate 1e-4) per network.

  Returns:
    ``(discriminator_optimizer, generator_optimizer)``.
  """
  learning_rate = 1e-4
  discriminator_opt, generator_opt = (
      Adam(network.parameters(), lr = learning_rate) for network in (netD, netG)
  )
  return discriminator_opt, generator_opt

def get_lr_scheduler(opt_d, opt_g, epoch_range):
  """Build a StepLR scheduler for each optimizer.

  The learning rate is multiplied by 0.1 every ``epoch_range // 2`` scheduler
  steps.  The period is clamped to at least 1 because a period of 0 (any
  ``epoch_range`` below 2) makes ``StepLR.step()`` divide by zero.

  Returns:
    ``(discriminator_scheduler, generator_scheduler)``.
  """
  step_size = max(1, epoch_range // 2)
  lr_d = lr_scheduler.StepLR(opt_d, step_size, gamma = 0.1)
  lr_g = lr_scheduler.StepLR(opt_g, step_size, gamma = 0.1)
  return lr_d, lr_g

def load_dataset(train_data_dir, test_data_dir, test_lr_data_dir=None,
                 train_batch_size=64, test_batch_size=16):
  """Build the training and test DataLoaders.

  Args:
    train_data_dir: directory handed to ``SuperResolutionDatasetTrain``.
    test_data_dir: directory with the ground-truth test images.
    test_lr_data_dir: directory with the matching low-resolution test images.
      When given it is forwarded as the second constructor argument of
      ``SuperResolutionDatasetTest``; when omitted only ``test_data_dir`` is
      passed.
      NOTE(review): pass this directory whenever the test dataset's
      constructor requires both a ground-truth and a low-resolution path.
    train_batch_size: batch size of the (shuffled) training loader.
    test_batch_size: batch size of the (unshuffled) test loader.

  Returns:
    ``(train_loader, test_loader)``.
  """
  tr_ds = SuperResolutionDatasetTrain(train_data_dir)
  if test_lr_data_dir is None:
    ts_ds = SuperResolutionDatasetTest(test_data_dir)
  else:
    ts_ds = SuperResolutionDatasetTest(test_data_dir, test_lr_data_dir)
  tr_dl = DataLoader(tr_ds, batch_size = train_batch_size, shuffle = True)
  ts_dl = DataLoader(ts_ds, batch_size = test_batch_size, shuffle = False)

  return tr_dl, ts_dl

def train(model_d, model_g, opt_d, opt_g,data):
    """Run one adversarial training step on a single batch.

    The discriminator is updated first (real batch labelled 1, generated batch
    labelled 0), then frozen while the generator is updated with
    pixel + content + 0.001 * adversarial loss.

    Args:
        model_d: discriminator network.
        model_g: generator network.
        opt_d: optimizer of the discriminator.
        opt_g: optimizer of the generator.
        data: ``(hr, lr)`` batch; ``hr`` is indexed as a 4-D tensor.

    Returns:
        ``(d_loss, g_loss, d_gt_probability, d_sr_probability)`` as Python
        floats; the two probabilities are the sigmoid of the mean
        discriminator logit on the real and on the generated batch.
    """
    model_d.train()
    model_g.train()
    hr, lr = data

    batch_size = hr.shape[0]

    real_label = torch.ones((batch_size, 1), dtype = hr.dtype, device = hr.device)
    fake_label = torch.zeros((batch_size, 1), dtype = hr.dtype, device = hr.device)

    pixel_criterion, content_criterion, adversarial_criterion = define_loss()

    # --- Discriminator step ---
    for d in model_d.parameters():
        d.requires_grad = True
    model_d.zero_grad(set_to_none=True)
    gt_output = model_d(hr)
    d_loss_gt = adversarial_criterion(gt_output, real_label)
    # The real and fake passes build separate graphs, so nothing needs to be
    # retained between the two backward calls.
    d_loss_gt.backward()

    sr = model_g(lr)
    # detach() keeps the discriminator loss from back-propagating into the generator.
    sr_output = model_d(sr.detach())
    d_loss_sr = adversarial_criterion(sr_output, fake_label)
    d_loss_sr.backward()

    d_loss = d_loss_gt + d_loss_sr
    opt_d.step()

    # Freeze the discriminator so the generator step does not accumulate
    # gradients in its parameters.
    for d in model_d.parameters():
        d.requires_grad = False

    # --- Generator step ---
    model_g.zero_grad(set_to_none=True)
    pixel_loss = pixel_criterion(sr, hr)
    content_loss = content_criterion(sr, hr)
    adversarial_loss = adversarial_criterion(model_d(sr), real_label) * 0.001

    g_loss = pixel_loss + content_loss + adversarial_loss
    g_loss.backward()
    opt_g.step()

    d_gt_probability = torch.sigmoid(torch.mean(gt_output.detach()))
    d_sr_probability = torch.sigmoid(torch.mean(sr_output.detach()))

    torch.cuda.empty_cache()
    return d_loss.item(), g_loss.item(), d_gt_probability.item(), d_sr_probability.item()

def test(model_d, model_g, data, epoch):
  """Evaluate the generator on one ``(hr, lr)`` batch.

  Both networks are switched to eval mode, the low-resolution batch is
  super-resolved without gradient tracking, and the PSNR of every sample is
  computed on the 0-255 scale.  When ``epoch`` is a multiple of 10 a
  comparison figure is saved as well.

  Returns:
    A tensor holding one PSNR value per sample of the batch.
  """
  for network in (model_d, model_g):
    network.eval()

  hr, lr = data

  with torch.no_grad():
    gen_sr = model_g(lr)

  # The small constant keeps the error strictly positive before the division.
  squared_error = (hr * 255 - gen_sr * 255) ** 2 + 1e-8
  mse = squared_error.mean(dim = [1,2,3])
  psnr = 10 * torch.log10((255.0 ** 2) / mse)

  save_figure = epoch % 10 == 0
  if save_figure:
    plot_train_result(hr, gen_sr, lr, epoch, save = True, show = False)

  return psnr

def to_np(x):
    """Return the values of tensor ``x`` as a NumPy array in host (CPU) memory."""
    host_tensor = x.data.cpu()
    return host_tensor.numpy()

def plot_train_result(real_image, gen_image, recon_image, epoch=1, save=False,  show=True, fig_size=(15, 15)):
    """Draw the first sample of three image batches side by side.

    Each image is min-max scaled to 0-255 on its own before display.

    Args:
        real_image: batch shown in the left panel.
        gen_image: batch shown in the middle panel.
        recon_image: batch shown in the right panel.
        epoch: zero-based epoch; the title and the file name use ``epoch + 1``.
        save: when true, write the figure to ``Result_epoch_<epoch + 1>.png``.
        show: when true, display the figure; otherwise close it.
        fig_size: figure size forwarded to ``plt.subplots``.
    """
    fig, axes = plt.subplots(1, 3, figsize=fig_size)
    imgs = [to_np(batch)[0] for batch in (real_image, gen_image, recon_image)]
    for ax, img in zip(axes.flatten(), imgs):
        ax.axis('off')
        img = img.squeeze()
        value_range = img.max() - img.min()
        if value_range > 0:
            img = (img - img.min()) * 255 / value_range
        else:
            # A constant image would divide by zero and turn into NaN; show it as black.
            img = np.zeros_like(img)
        img = img.transpose(1, 2, 0).astype(np.uint8)
        ax.imshow(img, cmap=None, aspect='equal')
    plt.subplots_adjust(wspace=0, hspace=0)

    title = 'Epoch {0}'.format(epoch + 1)
    fig.text(0.5, 0.04, title, ha='center')

    if save:
        save_fn = 'Result_epoch_{:d}'.format(epoch+1) + '.png'
        fig.savefig(save_fn)

    if show:
        plt.show()
    else:
        plt.close(fig)

def live_practice(image_dir, output_dir, model_pth):
  """Super-resolve the images of a folder with a saved generator.

  For every item of ``SuperResolutionDatasetTrain(image_dir)`` three files are
  written to ``output_dir``: ``<n>_sr.png`` (generator output), ``<n>_hr.png``
  and ``<n>_lr.png``.  The loader is shuffled, so ``<n>`` is the processing
  order rather than the file order.

  Args:
    image_dir: directory with the input images.
    output_dir: directory for the results; created when missing.
    model_pth: path of a generator ``state_dict`` checkpoint.
  """
  import os

  os.makedirs(output_dir, exist_ok=True)

  model = Generator().to(device)
  # map_location lets a checkpoint saved on the GPU be loaded on a CPU-only machine.
  model.load_state_dict(torch.load(model_pth, map_location=device))
  model.eval()
  # The single-directory dataset in this file; the previously referenced
  # ``SuperResolutionDataset`` is not defined anywhere.
  ds = SuperResolutionDatasetTrain(image_dir)
  dl = DataLoader(ds, shuffle = True)
  # Inference only: no gradients are needed.
  with torch.no_grad():
    for ix, data in enumerate(dl):
      hr, lr = data
      gen_hr = model(lr)
      save_image(gen_hr, f"{output_dir}/{ix+1}_sr.png",normalize = True)
      save_image(hr, f"{output_dir}/{ix+1}_hr.png",normalize = True)
      save_image(lr, f"{output_dir}/{ix+1}_lr.png",normalize = True)

  print("[INFO] FINISHED CONVERTING IMAGES TO HIGH RESOLUTION")

def make_img_grids(img_lr, img_hr, gen_hr, output_dir,epoch):
  """Save a comparison image: low-res | ground truth | generated, side by side.

  The low-resolution batch is first enlarged 4x so all three batches share
  one size.  Each batch becomes a one-column grid, and the three grids are
  joined along the width.  The result is written to
  ``<output_dir>/grid_results_<epoch>.png``.
  """
  img_lr = nn.functional.interpolate(img_lr, scale_factor = 4)
  img_hr = make_grid(img_hr, nrow = 1, normalize = True)
  gen_hr = make_grid(gen_hr, nrow = 1, normalize = True)
  img_lr = make_grid(img_lr, nrow = 1, normalize = True)
  # Join along the last (width) axis; the default dim=0 would stack the
  # channels into a 9-channel tensor that cannot be saved as an image.
  img_grid = torch.cat([img_lr, img_hr, gen_hr], dim = -1)
  save_image(img_grid, f"{output_dir}/grid_results_{epoch}.png", normalize = False)

def main(train_data_dir="C:/Users/richa/source/repos/Data/Image Resolution/train",
         test_data_dir="C:/Users/richa/source/repos/Data/Image Resolution/test",
         epoch_range=200):
  """Train the SRGAN and evaluate it after every epoch.

  Generator and discriminator checkpoints are written every 10th epoch
  (``Epoch_<n>_model_g.pth`` / ``Epoch_<n>_model_d.pth``) and the loss curves
  are plotted at the end.

  Args:
    train_data_dir: directory with the training images.
    test_data_dir: directory with the test images.
    epoch_range: number of epochs to train for.
  """
  training_dl, testing_dl = load_dataset(train_data_dir, test_data_dir)

  # NOTE(review): Report is not imported by name in this file; presumably it
  # comes from the ``torch_snippets`` star import - confirm.
  log = Report(epoch_range)
  model_g = Generator().to(device)
  model_d = Discriminator().to(device)

  summary(model_g, (3,24,24))
  summary(model_d, (3,96,96))

  opt_d, opt_g = get_optimizers(model_d, model_g)
  lr_d, lr_g = get_lr_scheduler(opt_d, opt_g, epoch_range)
  print("[INFO] SRGAN MODEL BUILT")
  print("[INFO] DATASETS LOADED")
  print("[INFO] OPTIMIZERS AND SCHEDULERS BUILT")
  print(f"[INFO] CURRENTLY USING {device}")

  for epoch in range(epoch_range):
    N = len(training_dl)
    for i, data in enumerate(training_dl):
      d_loss, g_loss, d_gt, d_sr  = train(model_d, model_g, opt_d, opt_g, data)
      # "\r" (carriage return) rewrites the progress line in place; the former
      # "/r" was printed literally.
      log.record(epoch + (1+i)/N, d_loss = d_loss, g_loss = g_loss,d_gt = d_gt, d_sr = d_sr, end = "\r")

    N = len(testing_dl)
    for i, data in enumerate(testing_dl):
      psnr = test(model_d,model_g,data, epoch)
      # test() returns one PSNR per sample; log the batch mean as a plain float.
      log.record(epoch + (1+i)/N, psnr = psnr.mean().item(), end = "\r")

    # Advance the learning-rate schedules once per epoch; without these calls
    # the schedulers are built but never take effect.
    lr_d.step()
    lr_g.step()

    torch.cuda.empty_cache()
    if(epoch % 10 == 0):
      torch.save(model_g.state_dict(), f"Epoch_{epoch}_model_g.pth")
      torch.save(model_d.state_dict(), f"Epoch_{epoch}_model_d.pth")

    log.report_avgs(epoch+1)

  log.plot_epochs(['d_loss','g_loss'])

if __name__ == "__main__":
  # Script entry point: only instantiates both networks and prints their layer
  # summaries. main() (the training loop) is not invoked from here.
  model_g = Generator().to(device)
  model_d = Discriminator().to(device)

  # 24x24 is the low-resolution size produced by the training dataset and
  # 96x96 its high-resolution crop size.
  summary(model_g, (3,24,24))
  summary(model_d, (3,96,96))
