In [2]:
import torch
import torchvision
import torch.nn as nn
from torchsummary import summary
import os
import cv2
import numpy as np
from google.colab.patches import cv2_imshow
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as transforms
from torchvision.utils import save_image
import gc
import math
import sys, time

In [None]:
!pip install onnxruntime
import torch.utils.model_zoo as model_zoo
import torch.onnx
import onnxruntime

Collecting onnxruntime
[?25l  Downloading https://files.pythonhosted.org/packages/91/a5/a70b05bc5a6037e0bf29d21828945a49fa4d341690c8ae7f01a62a177a2b/onnxruntime-1.5.2-cp36-cp36m-manylinux2014_x86_64.whl (3.8MB)
[K     |████████████████████████████████| 3.8MB 9.3MB/s 
Installing collected packages: onnxruntime
Successfully installed onnxruntime-1.5.2


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
class MobileNet(nn.Module):
    def __init__(self):
        super(MobileNet, self).__init__()

        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(inp, oup, 3, stride,0, bias=False),
                nn.PReLU()
            )

        def conv_dw(inp, oup, stride,kernel_size):
            return nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(inp, inp, kernel_size, stride, groups=inp, bias=False),
                nn.PReLU(),

                nn.Conv2d(inp, oup, 1, 1, bias=False),
                nn.PReLU(),
            )

        def deconv_dw(inp, oup, upscale,kernel_size):
          return nn.Sequential(
              nn.ReflectionPad2d(1),
              nn.Conv2d(inp, oup, 1, 1, bias=False),
              nn.PReLU(),

              nn.ReflectionPad2d(1),
              nn.Conv2d(oup, oup * upscale ** 2, kernel_size ,groups=oup),
              nn.PixelShuffle(upscale),
              nn.PReLU(), 
          )
        
        def res_block(inp):
          return nn.Sequential(
              conv_dw(inp,inp,1,3),
              nn.BatchNorm2d(inp),
              nn.PReLU(),

              conv_dw(inp,inp,1,3),
              nn.BatchNorm2d(inp)
          )

        def finblock(inp):
          return nn.Sequential(
              nn.ReflectionPad2d(1),
              deconv_dw(inp,3,1,9)
          )
        
        self.conv_1 = conv_bn(3,32,1)
        self.res_block1 = res_block(32)
        self.res_block2 = res_block(32)
        self.upsample_1 = deconv_dw(32,48,2,3)
        self.finconv = finblock(48)
        
    def forward(self, inp):
      pre_res = self.conv_1(inp)
      x1 = self.res_block1(pre_res)
      x1+= pre_res
      x2 = self.res_block2(x1)
      x2+= x1
      x2+= pre_res
      x = self.upsample_1(x2)
      x = self.finconv(x)
      return x

In [5]:
import torch.nn.functional as F

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(inp, inp, 3, stride, groups=inp, bias=False),
                nn.ReLU6(inplace=True),

                nn.Conv2d(inp, oup, 1, 1, bias=False),
                nn.ReLU6(inplace=True),
                nn.BatchNorm2d(oup)
            )
        
        self.model = nn.Sequential(
            conv_dw(3,64,1),
            conv_dw(64,64,2),
            conv_dw(64,128,1),
            conv_dw(128,128,2),
            conv_dw(128,256,1),
            conv_dw(256,256,2),
            conv_dw(256,512,1),
            conv_dw(512,512,2),
            nn.ReflectionPad2d(1),
            nn.Conv2d(512,1,1,stride = 1)
        )

    def forward(self, x):
        x = self.model(x)
        return torch.sigmoid(F.avg_pool2d(x, x.size()[2:])).view(x.size()[0], -1)

In [6]:
##### Perceptual Loss Network using VGG-16 Pre Trained Model

class VGGPerceptualLoss(nn.Module):
    def __init__(self, resize=True):

        super(VGGPerceptualLoss, self).__init__()

        blocks = []

        ## Notes:
        ## .eval() makes the model in evaluation mode to disable use of dropout
        ## the feature blocks have been done before every maxpool layer of the VGG-16 NET
        ## check the below link to get the sequential orders of the layers and match the numbers 4,9,16,23
        ## visit this link : https://stackoverflow.com/questions/53114882/pytorch-modifying-vgg16-architecture 
        temp = torchvision.models.vgg16(pretrained=True)
        # for param in temp.parameters():
        #   param.requires_grad = False

        blocks.append(temp.features[:4].eval())
        blocks.append(temp.features[4:9].eval())
        blocks.append(temp.features[9:16].eval())
        blocks.append(temp.features[16:23].eval())

        for bl in blocks:
            for p in bl:

                # since we are doing only forward prop and no backpropagation, gradient is not necessary
                p.requires_grad = False
        
        ## this is basically like an array of the first 4 layers of the VGG
        self.blocks = torch.nn.ModuleList(blocks)
        self.transform = torch.nn.functional.interpolate

        ## this is given in torchvision documentation and is necessary to normalize for input
        ## can simplify this to one line check doc later
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
        self.resize = resize

    def forward(self, input, target):
        

        ## Applying the normalization defined before
        if input.shape[1] != 3:
            input = input.repeat(1, 3, 1, 1)
            target = target.repeat(1, 3, 1, 1)
        input = (input-self.mean) / self.std
        target = (target-self.mean) / self.std

        ## VGG-16 takes in only images of size 224x224 hence the reshapes
        if self.resize:
            input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
            target = self.transform(target, mode='bilinear', size=(224, 224), align_corners=False)
        
        ## Here we compute loss
        loss = 0.0
        x = input
        y = target

        for block in self.blocks:

            # The image and ground truth are fed to the VGG blocks
            x = block(x)
            y = block(y)

            # simple L1 loss between the output of the network from giving the actual input-x and GT-y to the network
            loss += nn.functional.l1_loss(x, y)
        return loss

In [7]:
! cp "/content/drive/My Drive/DIV2K_train_HR.zip"  "/content/"
! unzip DIV2K_train_HR.zip 
! rm DIV2K_train_HR.zip
! mv /content/DIV2K_train_HR /content/dataset

Archive:  DIV2K_train_HR.zip
   creating: DIV2K_train_HR/
  inflating: DIV2K_train_HR/0103.png  
  inflating: DIV2K_train_HR/0413.png  
  inflating: DIV2K_train_HR/0031.png  
  inflating: DIV2K_train_HR/0660.png  
  inflating: DIV2K_train_HR/0126.png  
  inflating: DIV2K_train_HR/0793.png  
  inflating: DIV2K_train_HR/0764.png  
  inflating: DIV2K_train_HR/0550.png  
  inflating: DIV2K_train_HR/0437.png  
  inflating: DIV2K_train_HR/0374.png  
  inflating: DIV2K_train_HR/0755.png  
  inflating: DIV2K_train_HR/0614.png  
  inflating: DIV2K_train_HR/0646.png  
  inflating: DIV2K_train_HR/0371.png  
  inflating: DIV2K_train_HR/0312.png  
  inflating: DIV2K_train_HR/0108.png  
  inflating: DIV2K_train_HR/0556.png  
  inflating: DIV2K_train_HR/0794.png  
  inflating: DIV2K_train_HR/0722.png  
  inflating: DIV2K_train_HR/0780.png  
  inflating: DIV2K_train_HR/0555.png  
  inflating: DIV2K_train_HR/0439.png  
  inflating: DIV2K_train_HR/0396.png  
  inflating: DIV2K_train_HR/0666.png  
  infl

In [8]:
! mkdir /content/data/
! mkdir /content/data/lowres/ 
! mkdir /content/data/highres/

In [9]:
count = 0 
dims = 128

files = os.listdir('/content/dataset')
percheck = 0

for item in files:

  filename = os.path.join('/content/dataset',item)
  img = cv2.imread(filename)

  for r in range(0,img.shape[0], dims):
    for c in range(0,img.shape[1], dims):
      highres_name = "/content/data/highres/"+ str(count) + ".png"
      lowres_name = "/content/data/lowres/"+ str(count) + ".png"
      cropped = img[r:r+dims, c:c+dims,:]
      if(cropped.shape != (dims,dims,3)):
        continue
      cropped_lowres = cv2.resize(cropped,(int(dims/2),int(dims/2)))
      cv2.imwrite(highres_name,cropped)
      cv2.imwrite(lowres_name,cropped_lowres)
      count+=1
    
  percheck+=1
  print('\r',percheck*100/len(files),end='')

 100.0

In [10]:
!rm -rf dataset

In [11]:
class SupaResDataset(Dataset):
  def __init__(self):
    files = os.listdir("/content/data/lowres")
    self.transform = transforms.ToTensor()
    self.lowres_img = []
    self.highres_img = []
    count = 0
    for file in files:
      low_name = "/content/data/lowres/"+file
      high_name = "/content/data/highres/"+file
      self.lowres_img.append(low_name)
      self.highres_img.append(high_name)


  def __getitem__(self, i):

      lr_img = Image.open(self.lowres_img[i], mode='r')
      lr_img = lr_img.convert('RGB')
      lr_img = self.transform(lr_img)
      hr_img = Image.open(self.highres_img[i], mode='r')
      hr_img = hr_img.convert('RGB')
      hr_img = self.transform(hr_img)

      return lr_img,hr_img
  
  def __len__(self):
    return len(self.lowres_img)

In [18]:
batch_size = 38
dataset = SupaResDataset()
train_loader = torch.utils.data.DataLoader(
  dataset, 
  batch_size=batch_size, 
  shuffle=True, 
  num_workers=0,
  pin_memory=True
)

In [19]:
cuda = True

In [30]:
def save_models(generator,discriminator,model_dir,epochno):
  torch.save(generator.state_dict(),model_dir+str(epochno)+"_Gen.pth")
  torch.save(discriminator.state_dict(),model_dir+str(epochno)+"_Dis.pth")

def test_images(generator,image_dir,epochno):
  img = Image.open("/content/drive/My Drive/Datasets/Super Resolution/input.jpg", mode='r')
  img = img.convert('RGB')
  transform = transforms.ToTensor()
  img = transform(img)
  img = img.unsqueeze(0)
  generator.eval()
  with torch.no_grad():
    pred = generator(img.cuda())
  generator.train()
  save_image(pred[0], image_dir+str(epochno)+".jpg")

def load_models(generator,discriminator,model_dir,num):
  generator.load_state_dict(torch.load(model_dir+str(num)+"_Gen.pth"))
  discriminator.load_state_dict(torch.load(model_dir+str(num)+"_Dis.pth"))

In [21]:
epochs = 10
model_dir = "/content/drive/My Drive/Datasets/Super Resolution/Model Checkpoints/"
img_dir = "/content/drive/My Drive/Datasets/Super Resolution/Image checkpoints/"
filecount = len(os.listdir("/content/data/lowres"))
steps_per_epoch = int(filecount/batch_size)

In [22]:
generator = MobileNet()
discriminator = Discriminator()
perceptual = VGGPerceptualLoss()
load_models(generator,discriminator,model_dir,11)
if(cuda):
  generator = generator.cuda()
  discriminator = discriminator.cuda()
  perceptual = perceptual.cuda()

In [23]:
adversarial_loss = nn.BCEWithLogitsLoss()

In [24]:
optimizerG = torch.optim.Adam(generator.parameters(), lr=1e-4)
optimizerD = torch.optim.Adam(discriminator.parameters(), lr=1e-4)

In [25]:
displacement = 12

In [None]:
for i in range(epochs):
  runningG_loss = 0
  runningD_loss = 0

  iterator = iter(train_loader)
  print("Epoch "+str(displacement+i)+" start")
  start_time = time.time()
  for z in range(steps_per_epoch):

    lr_img, hr_img = next(iterator)
    if(cuda):
      lr_img = lr_img.cuda()
      hr_img = hr_img.cuda()
    
    sr_pred = generator(lr_img)
    disc_pred = discriminator(sr_pred)

    perceptualloss = perceptual(hr_img.detach(),sr_pred)
    advloss = adversarial_loss(disc_pred,torch.ones_like(disc_pred))
    loss = perceptualloss+ 0.001*advloss
    runningG_loss+=loss.item()
    optimizerG.zero_grad()
    loss.backward()
    optimizerG.step()

    hr_disc = discriminator(hr_img)
    sr_disc = discriminator(sr_pred.detach())
    
    advloss = adversarial_loss(sr_disc, torch.zeros_like(sr_disc)) + \
              adversarial_loss(hr_disc, torch.ones_like(hr_disc))
    runningD_loss+=advloss.item()
    optimizerD.zero_grad()
    advloss.backward()
    optimizerD.step()
    
    print('\r',str((z+1)*100/steps_per_epoch)[0:4]+
          "\t GenLoss: "+str(runningG_loss/((z+1)*batch_size))+
          "\t DiscLoss: "+str(runningD_loss/((z+1)*batch_size))+
          "\t "+str(time.time()-start_time),
          end='')
  
  save_models(generator,discriminator,model_dir,displacement+i)
  test_images(generator,img_dir,displacement+i)
  
  print()

Epoch 9 start
 81.3	 GenLoss: 0.029021943512352384	 DiscLoss: 0.02648444243034608	 2257.6144506931305