In [None]:
import os
import torch
import torchvision
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch import nn
from torchvision.models import vgg19

In [None]:
# Pick the compute backend once; the models and batches below are moved to it.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'

In [None]:
class Dataset(torch.utils.data.Dataset):
  """Image-folder dataset yielding (high-res crop, low-res copy) pairs.

  Args:
    dtry: directory containing the images; every regular file directly inside
      it is treated as an image.
  """

  def __init__(self, dtry):
    super().__init__()
    self.dtry = dtry
    # Keep regular files only and sort them. A plain os.listdir also returns
    # sub-directories (e.g. .ipynb_checkpoints), which read_image cannot
    # decode, and its ordering is arbitrary, so the index -> image mapping
    # could differ from one run to the next.
    with os.scandir(self.dtry) as entries:
      self.images = sorted(entry.name for entry in entries if entry.is_file())

  def __len__(self):
    """Return the number of image files found in the directory."""
    return len(self.images)

  def __getitem__(self, idx):
    """Load image `idx` and return `(high_res_img, low_res_img)`.

    The high-res item is a random 96x96 crop of the image and the low-res item
    is that crop resized to 24x24; both are divided by 255.0 before returning.
    """
    img = torchvision.io.read_image(os.path.join(self.dtry, self.images[idx]))
    high_res_img = torchvision.transforms.RandomCrop((96, 96))(img)
    # The low-res input is derived from the crop, so the pair stays aligned.
    low_res_img = torchvision.transforms.Resize((24, 24))(high_res_img)
    return high_res_img / 255.0, low_res_img / 255.0

In [None]:
class ConvBlock(nn.Module):
  """Conv2d optionally followed by BatchNorm2d and an activation.

  Args:
    disc: selects the activation when `use_act` is set: LeakyReLU(0.2) if
      truthy, otherwise a per-channel PReLU.
    use_bn: append BatchNorm2d(out_channels) after the convolution.
    use_act: append the activation as the last stage.
    in_channels, out_channels, kernel_size, stride, padding: passed
      positionally to nn.Conv2d.
  """

  def __init__(self, disc, use_bn, use_act, in_channels, out_channels, kernel_size, stride, padding):
    super().__init__()
    stages = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)]
    if use_bn:
      stages.append(nn.BatchNorm2d(out_channels))
    if use_act:
      activation = nn.LeakyReLU(0.2, inplace=True) if disc else nn.PReLU(num_parameters=out_channels)
      stages.append(activation)
    # Stored under `layers` so state_dict keys read `layers.<index>.*`.
    self.layers = nn.Sequential(*stages)

  def forward(self, x):
    """Run `x` through the stages in order and return the result."""
    out = self.layers(x)
    return out

class UpsampleBlock(nn.Module):
  """3x3 convolution, PixelShuffle, then a per-channel PReLU.

  The convolution widens the tensor to `in_channels * scale_factor**2`
  channels and PixelShuffle rearranges them, so the output has `in_channels`
  channels with height and width multiplied by `scale_factor`.
  """

  def __init__(self, in_channels, scale_factor):
    super().__init__()
    expanded_channels = in_channels * scale_factor**2
    self.conv = nn.Conv2d(in_channels, expanded_channels, 3, 1, 1)
    self.ps = nn.PixelShuffle(scale_factor)
    self.act = nn.PReLU(in_channels)

  def forward(self, x):
    """Apply conv -> pixel shuffle -> activation to `x`."""
    widened = self.conv(x)
    shuffled = self.ps(widened)
    return self.act(shuffled)

class ResidualBlock(nn.Module):
  """Two stacked 3x3 ConvBlocks with an identity skip connection.

  block1 is conv + batch norm + PReLU, block2 is conv + batch norm without an
  activation; both keep `in_channels` channels and the spatial size, so the
  input can be added back onto the output.
  """

  def __init__(self, in_channels):
    super().__init__()
    self.block1 = ConvBlock(disc=False, use_bn=True, use_act=True,  in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1)
    self.block2 = ConvBlock(disc=False, use_bn=True, use_act=False, in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=1, padding=1)

  def forward(self, x):
    """Return `block2(block1(x)) + x`."""
    out = self.block1(x)
    # Chain the blocks: block2 must consume block1's output. Feeding it `x`
    # (as before) threw block1's result away, so block1 never contributed to
    # the output and received no gradient.
    out = self.block2(out)
    return out + x

In [None]:
class Generator(nn.Module):
  """Super-resolution generator: 9x9 stem, residual trunk, two 2x upsamplers.

  Args:
    in_channels: channels of the input image; the output has the same count.
    n_channels: width of every intermediate feature map.
    n_blocks: number of ResidualBlocks in the trunk.
  """

  def __init__(self, in_channels=3, n_channels=64, n_blocks=16):
    super().__init__()
    # Settings shared by the two 9x9 ConvBlocks (stem and post-trunk conv).
    wide_conv = dict(disc=False, use_bn=True, kernel_size=9, stride=1, padding=4)
    self.initial = ConvBlock(use_act=True, in_channels=in_channels, out_channels=n_channels, **wide_conv)
    trunk = [ResidualBlock(n_channels) for _ in range(n_blocks)]
    self.residuals = nn.Sequential(*trunk)
    self.conv = ConvBlock(use_act=False, in_channels=n_channels, out_channels=n_channels, **wide_conv)
    # Two pixel-shuffle stages of factor 2 each.
    self.upsample = nn.Sequential(UpsampleBlock(n_channels, 2), UpsampleBlock(n_channels, 2))
    self.final = nn.Conv2d(n_channels, in_channels, 9, 1, 4)

  def forward(self, x):
    """Map a low-res batch `x` to an upscaled batch squashed by a sigmoid."""
    skip = self.initial(x)
    # Long skip connection around the whole residual trunk.
    features = self.conv(self.residuals(skip)) + skip
    upscaled = self.upsample(features)
    return torch.sigmoid(self.final(upscaled))

class Discriminator(nn.Module):
  """Stack of strided ConvBlocks followed by a small fully connected head.

  Args:
    in_channels: channels of the input image.
    features: output width of each ConvBlock, in order. Blocks at odd
      positions use stride 2, the others stride 1; the first block has no
      batch norm.

  The forward pass returns one sigmoid output per sample.
  """

  def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
    super().__init__()
    blocks = []
    for idx, feature in enumerate(features):
      blocks.append(ConvBlock(disc=True, use_bn=idx != 0, use_act=True, in_channels=in_channels,
                              out_channels=feature, kernel_size=3, stride=1 + (idx % 2), padding=1))
      in_channels = feature
    self.layers = nn.Sequential(*blocks)
    # After the loop `in_channels` is the width of the last conv block (or the
    # image channels when `features` is empty). Size the first Linear from it
    # instead of a hard-coded 512 so custom `features` work as well.
    self.classifier = nn.Sequential(nn.AdaptiveAvgPool2d((6, 6)),
                                    nn.Flatten(),
                                    nn.Linear(in_channels * 6 * 6, 1024),
                                    nn.LeakyReLU(0.2, inplace=True),
                                    nn.Linear(1024, 1),
                                    nn.Sigmoid())

  def forward(self, x):
    """Return the sigmoid score for each image in `x`, one value per sample."""
    x = self.layers(x)
    return self.classifier(x)


In [None]:
class VGGLoss(nn.Module):
  """MSE between VGG19 feature maps of two image batches.

  Uses the first 36 layers of torchvision's pretrained `vgg19().features`,
  switched to eval mode, moved to `device`, and with gradients disabled on
  all of its parameters.
  """

  def __init__(self):
    super().__init__()
    backbone = vgg19(pretrained=True).features[:36]
    self.vgg = backbone.eval().to(device)
    self.loss = nn.MSELoss()

    # The extractor is a fixed feature function; its weights are never trained.
    for weight in self.vgg.parameters():
      weight.requires_grad = False

  def forward(self, y, y_hat):
    """Return the MSE between the VGG features of `y` and of `y_hat`."""
    return self.loss(self.vgg(y), self.vgg(y_hat))

In [None]:
class SRGAN:
  """Training and inference wrapper around the Generator and Discriminator.

  Args:
    args: dict with the keys 'n_epochs', 'dtry' (directory the checkpoints
      SRGAN_G.pkl / SRGAN_D.pkl are read from and written to), 'batch_size',
      'dataset', 'lrG', 'lrD', 'beta1' and 'beta2'.
  """

  def __init__(self,args):
    self.n_epochs = args['n_epochs']
    self.dtry = args['dtry']
    self.batch_size = args['batch_size']
    self.dataset = args['dataset']

    # NOTE(review): the loader does not shuffle; confirm that is intended.
    self.data_loader = torch.utils.data.DataLoader(self.dataset, self.batch_size, num_workers=4)

    self.G = Generator().to(device)
    self.D = Discriminator().to(device)
    self.G.train()
    self.D.train()
    # Resume when a generator checkpoint exists. The discriminator checkpoint
    # is loaded only if it is actually present, so a missing SRGAN_D.pkl no
    # longer aborts construction.
    saved = set(os.listdir(self.dtry))
    if 'SRGAN_G.pkl' in saved:
      self.load(G=True, D='SRGAN_D.pkl' in saved)
    self.G_optim = torch.optim.Adam(self.G.parameters(), lr=args['lrG'], betas=(args['beta1'], args['beta2']))
    self.D_optim = torch.optim.Adam(self.D.parameters(), lr=args['lrD'], betas=(args['beta1'], args['beta2']))
    self.vgg_loss = VGGLoss()
    self.bce = nn.BCELoss()

  def train(self):
    """Run `n_epochs` passes over the loader, saving a checkpoint after each."""
    for epoch in range(self.n_epochs):
      loop = tqdm(self.data_loader,  position=0, leave=True)

      # Each batch is unpacked as (y, x): target first, generator input second,
      # matching the (high_res, low_res) order of Dataset.__getitem__.
      for y, x in loop:
        y = y.to(device)
        x = x.to(device)
        y_hat = self.G(x)
        real = self.D(y)
        # detach() keeps the discriminator update from back-propagating into G.
        fake = self.D(y_hat.detach())

        # update D network
        D_loss = self.bce(real, torch.ones_like(real)) + self.bce(fake, torch.zeros_like(fake))
        self.D_optim.zero_grad()
        D_loss.backward()
        self.D_optim.step()

        # update G network
        # Re-score the fake batch with the freshly updated D, this time with
        # the graph back to G intact.
        fake = self.D(y_hat)
        adversarial_loss = 1e-3 * self.bce(fake, torch.ones_like(fake))
        perceptual_loss = 0.006 * self.vgg_loss(y, y_hat)
        G_loss = adversarial_loss + perceptual_loss
        self.G_optim.zero_grad()
        G_loss.backward()
        self.G_optim.step()

        loop.set_postfix(loss=(G_loss.item(),D_loss.item()))
      self.save()
    # Also writes a checkpoint when n_epochs is 0 and the loop never ran.
    self.save()

  def test(self, x):
    """Run the generator on `x` for inference and return its output.

    The generator is put in eval mode and autograd is disabled for the call,
    so batch norm uses its running statistics, those statistics are not
    updated, and no graph is kept for the (possibly full-size) image.
    The generator's previous train/eval mode is restored afterwards.
    """
    was_training = self.G.training
    self.G.eval()
    try:
      with torch.no_grad():
        y_hat = self.G(x)
    finally:
      self.G.train(was_training)
    return y_hat

  def save(self):
    """Write both state dicts to `dtry` as SRGAN_G.pkl and SRGAN_D.pkl."""
    torch.save(self.G.state_dict(), os.path.join(self.dtry,'SRGAN_G.pkl'))
    torch.save(self.D.state_dict(), os.path.join(self.dtry,'SRGAN_D.pkl'))

  def load(self, G, D):
    """Restore weights from the checkpoints in `dtry`.

    Args:
      G: load SRGAN_G.pkl into the generator when truthy.
      D: load SRGAN_D.pkl into the discriminator when truthy.

    Tensors are mapped onto `device`, so a checkpoint written on a GPU machine
    can be opened on a CPU-only one.
    """
    if G:
      self.G.load_state_dict(torch.load(os.path.join(self.dtry,'SRGAN_G.pkl'), map_location=device))
    if D:
      self.D.load_state_dict(torch.load(os.path.join(self.dtry,'SRGAN_D.pkl'), map_location=device))

In [None]:
# Build the training set and the hyper-parameter dict, then train.
# Checkpoints are read from and written to args['dtry'].
dataset = Dataset('drive/MyDrive/Carvana/train')
args = dict(
    dtry='drive/MyDrive/Carvana',
    n_epochs=10,
    batch_size=64,
    dataset=dataset,
    lrG=1e-4,
    lrD=1e-4,
    beta1=0.9,
    beta2=0.999,
)
model = SRGAN(args)
model.train()

In [None]:
# Super-resolve one image: downscale it by 4 in each dimension, run it through
# the trained generator, and keep the result in `high_res_img_hat`.
idx = 0
dtry = 'drive/MyDrive/Carvana/train'
images = os.listdir(dtry)
high_res_img = torchvision.io.read_image(os.path.join(dtry,images[idx]))/255.0
low_res_size = (high_res_img.shape[1]//4, high_res_img.shape[2]//4)
low_res_img  = torchvision.transforms.Resize(low_res_size)(high_res_img)
low_res_img = low_res_img.unsqueeze(0)  # add the batch dimension
# Inference only: without no_grad the generator's whole autograd graph is kept
# alive for the full-size image, which wastes a large amount of memory.
with torch.no_grad():
  high_res_img_hat = model.test(low_res_img.to(device))
high_res_img_hat = high_res_img_hat.squeeze(0)