In [52]:
import torch
import torchvision
import torch.nn as nn
from torchsummary import summary
import os
import cv2
import numpy as np
from google.colab.patches import cv2_imshow
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as transforms
from torchvision.models.vgg import vgg16
from torchvision.utils import save_image
import gc
import math
import sys, time
from torch.autograd import Variable

In [53]:
!pip install onnxruntime
import torch.utils.model_zoo as model_zoo
import torch.onnx
import onnxruntime



In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [54]:
class MobileNet(nn.Module):
    def __init__(self):
        super(MobileNet, self).__init__()

        def conv_bn(inp, oup, stride):
            return nn.Sequential(
                nn.Conv2d(inp, oup, 3, stride,1, bias=False),
                nn.PReLU()
            )

        def conv_dw(inp, oup, stride,kernel_size):
            return nn.Sequential(
                nn.Conv2d(inp, inp, kernel_size, stride,1, groups=inp, bias=False),
                nn.PReLU(),
                nn.Conv2d(inp, oup, 1, 1, bias=False),
                nn.PReLU(),
            )

        def deconv_dw(inp, oup, upscale,kernel_size):
          return nn.Sequential(
              nn.Conv2d(inp, oup, 1, 1,padding=1, bias=False),
              nn.PReLU(),
              nn.Conv2d(oup, oup * upscale ** 2, kernel_size,padding=1,groups=oup),
              nn.PixelShuffle(upscale),
              nn.PReLU(), 
          )
        
        def res_block(inp):
          return nn.Sequential(
              conv_dw(inp,inp,1,3),
              nn.PReLU(),
              conv_dw(inp,inp,1,3),
          )
        
        self.conv_1 = conv_bn(3,32,1)
        self.res_block1 = res_block(32)
        self.res_block2 = res_block(32)
        self.upsample_1 = deconv_dw(32,48,2,3)
        self.finconv = deconv_dw(48,3,1,9)
        
    def forward(self, inp):
      pre_res = self.conv_1(inp)
      x1 = self.res_block1(pre_res)
      x1+= pre_res
      x2 = self.res_block2(x1)
      x2+= x1
      x2+= pre_res
      x = self.upsample_1(x2)
      x = self.finconv(x)
      return (torch.tanh(x) + 1) / 2

In [55]:
import torch.nn.functional as F

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        def conv_dw(inp, oup, stride):
            return nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(inp, inp, 3, stride, groups=inp, bias=False),
                nn.LeakyReLU(0.2),
                nn.Conv2d(inp, oup, 1, 1, bias=False),
                nn.LeakyReLU(0.2),
                nn.BatchNorm2d(oup)
            )
        
        self.model = nn.Sequential(
            conv_dw(3,64,1),
            conv_dw(64,64,2),
            conv_dw(64,128,1),
            conv_dw(128,128,2),
            conv_dw(128,256,1),
            conv_dw(256,256,2),
            conv_dw(256,512,1),
            conv_dw(512,512,2),
            nn.ReflectionPad2d(1),
            nn.Conv2d(512,1,1,stride = 1)
        )

    def forward(self, x):
        x = self.model(x)
        return torch.sigmoid(F.avg_pool2d(x, x.size()[2:])).view(x.size()[0], -1)

In [56]:
class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval().cuda()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        # Adversarial Loss
        adversarial_loss = torch.mean(1 - out_labels)
        # Perception Loss
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        # Image Loss
        image_loss = self.mse_loss(out_images, target_images)
        # TV Loss
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss


class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]

In [None]:
! cp "/content/drive/My Drive/DIV2K_train_HR.zip"  "/content/"
! unzip DIV2K_train_HR.zip 
! rm DIV2K_train_HR.zip
! mv /content/DIV2K_train_HR /content/dataset

In [None]:
! mkdir /content/data/
! mkdir /content/data/lowres/ 
! mkdir /content/data/highres/

In [None]:
count = 0 
dims = 128

files = os.listdir('/content/dataset')
percheck = 0

for item in files:

  filename = os.path.join('/content/dataset',item)
  img = cv2.imread(filename)

  for r in range(0,img.shape[0], dims):
    for c in range(0,img.shape[1], dims):
      highres_name = "/content/data/highres/"+ str(count) + ".png"
      lowres_name = "/content/data/lowres/"+ str(count) + ".png"
      cropped = img[r:r+dims, c:c+dims,:]
      if(cropped.shape != (dims,dims,3)):
        continue
      cropped_lowres = cv2.resize(cropped,(int(dims/2),int(dims/2)))
      cv2.imwrite(highres_name,cropped)
      cv2.imwrite(lowres_name,cropped_lowres)
      count+=1
    
  percheck+=1
  print('\r',percheck*100/len(files),end='')

 100.0

In [None]:
!rm -rf dataset

In [57]:
class SupaResDataset(Dataset):
  def __init__(self):
    files = os.listdir("/content/data/lowres")
    self.transform = transforms.ToTensor()
    self.toPIL = transforms.ToPILImage()
    self.lowres_img = []
    self.highres_img = []
    count = 0
    for file in files:
      low_name = "/content/data/lowres/"+file
      high_name = "/content/data/highres/"+file
      self.lowres_img.append(low_name)
      self.highres_img.append(high_name)


  def __getitem__(self, i):

      lr_img = Image.open(self.lowres_img[i], mode='r')
      lr_img = self.transform(lr_img)
      hr_img = Image.open(self.highres_img[i], mode='r')
      hr_img = self.transform(hr_img)

      return lr_img,hr_img
  
  def __len__(self):
    return len(self.lowres_img)

In [62]:
generator = MobileNet().cuda()
discriminator = Discriminator().cuda()
gen_loss = GeneratorLoss().cuda()

In [75]:
batch_size = 32
dataset = SupaResDataset()
train_loader = torch.utils.data.DataLoader(
  dataset, 
  batch_size=batch_size, 
  shuffle=True, 
  num_workers=4,
  pin_memory=True
)

In [60]:
def save_models(generator,discriminator,model_dir,epochno):
  torch.save(generator.state_dict(),model_dir+str(epochno)+"_Gen.pth")
  torch.save(discriminator.state_dict(),model_dir+str(epochno)+"_Dis.pth")

def test_images(generator,image_dir,epochno):
  img = Image.open("/content/drive/My Drive/Datasets/Super Resolution/input.jpg", mode='r')
  img = Variable(transforms.ToTensor()(img), requires_grad=False).unsqueeze(0).cuda()
  generator.eval()
  with torch.no_grad():
    out = generator(img)
  generator.train()
  out_img = transforms.ToPILImage()(out[0].data.cpu())
  out_img.save(image_dir+str(epochno)+".jpg")

def load_models(generator,discriminator,model_dir,num):
  generator.load_state_dict(torch.load(model_dir+str(num)+"_Gen.pth"))
  discriminator.load_state_dict(torch.load(model_dir+str(num)+"_Dis.pth"))

In [76]:
epochs = 10
model_dir = "/content/drive/My Drive/Datasets/Super Resolution/Model Checkpoints/"
img_dir = "/content/drive/My Drive/Datasets/Super Resolution/Image checkpoints/"
filecount = len(os.listdir("/content/data/lowres"))
steps_per_epoch = int(filecount/batch_size)

In [63]:
optimizerG = torch.optim.Adam(generator.parameters())
optimizerD = torch.optim.Adam(discriminator.parameters())

In [78]:
displacement = 11

In [None]:
generator.train()
discriminator.train()
for i in range(epochs):

  runningG_loss = 0
  runningD_loss = 0
  iterator = iter(train_loader)
  print("Epoch "+str(displacement+i)+" start")
  start_time = time.time()

  for z in range(steps_per_epoch):

    lr_img, hr_img = next(iterator)
    lr_img = lr_img.cuda()
    hr_img = hr_img.cuda()

    generator.zero_grad()
    fk_img = generator(lr_img)
    fake_out = discriminator(fk_img).detach()
    g_loss = gen_loss(fake_out,fk_img,hr_img)
    g_loss.backward()
    runningG_loss+=g_loss.item()
    optimizerG.step()
    
    discriminator.zero_grad()
    real_out = discriminator(hr_img).mean()
    fake_out = discriminator(fk_img.detach()).mean()
    d_loss = 1-real_out + fake_out
    runningD_loss+=d_loss.item()
    d_loss.backward()
    optimizerD.step()    
  
    print(
          '\r',str((z+1)*100/steps_per_epoch)[0:4]+
          "\t GenLoss: "+str(runningG_loss/((z+1)*batch_size))+
          "\t DiscLoss: "+str(runningD_loss/((z+1)*batch_size))+
          "\t "+str(time.time()-start_time),
          end=''
         )
  
  save_models(generator,discriminator,model_dir,displacement+i)
  test_images(generator,img_dir,displacement+i)
  
  print()

Epoch 11 start
 100.	 GenLoss: 7.498227774421612e-05	 DiscLoss: 2.579525898954184e-10	 1335.5373508930206
Epoch 12 start
 100.	 GenLoss: 7.470391589509806e-05	 DiscLoss: 2.0462895456939256e-11	 1332.2308876514435
Epoch 13 start
 100.	 GenLoss: 7.444949362647993e-05	 DiscLoss: 8.639404290613273e-12	 1333.0729355812073
Epoch 14 start
 42.4	 GenLoss: 7.429587517826629e-05	 DiscLoss: 1.3132153119155893e-12	 566.2831008434296

In [50]:
img = Image.open("/content/drive/My Drive/Datasets/Super Resolution/input.jpg", mode='r')
img = Variable(transforms.ToTensor()(img), requires_grad=False).unsqueeze(0).cuda()
out = generator(img)
out_img = transforms.ToPILImage()(out[0].data.cpu())
out_img.save("oup.jpg")

In [None]:
img = cv2.imread("oup.jpg")
from google.colab.patches import cv2_imshow
cv2_imshow(img)