In [None]:
# Fetch and unpack the DIV2K high-resolution training images into ./datasets.
# The directory test makes the cell a no-op once the extracted folder exists.
!if [ ! -d "./datasets/DIV2K_train_HR" ]; then \
    wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip -P ./datasets && \
    unzip ./datasets/DIV2K_train_HR.zip -d ./datasets; \
fi

# Same guard-then-download step for the x8 low-resolution validation images.
!if [ ! -d "./datasets/DIV2K_valid_LR_x8" ]; then \
    wget http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_x8.zip -P ./datasets && \
    unzip ./datasets/DIV2K_valid_LR_x8.zip -d ./datasets; \
fi

In [None]:
import os
import numpy as np
import albumentations as A
import torch 
from math import log2
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torchvision.models import vgg19
from torchvision.utils import save_image
from PIL import Image
from albumentations import Resize
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm

In [None]:
# NOTE(review): presumably a workaround for the abort that occurs when more than
# one OpenMP runtime gets loaded into the process — confirm it is still needed.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

In [None]:
# --- Data locations -------------------------------------------------------
TRAIN_PATH = './datasets/DIV2K_train_HR/'
VAL_PATH = './datasets/DIV2K_valid_LR_x8/'

# --- Checkpointing --------------------------------------------------------
CHECKPOINT_GEN = 'gen.pth.tar'
CHECKPOINT_DISC = 'disc.pth.tar'
LOAD_MODEL = False
SAVE_MODEL = True

# --- Runtime --------------------------------------------------------------
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_WORKERS = 0

# --- Optimisation schedule ------------------------------------------------
LEARNING_RATE = 1e-4
BATCH_SIZE = 64
START_EPOCHS = 1  # 1-based epoch to start (or resume) from
NUM_EPOCHS = 100

# --- Image geometry -------------------------------------------------------
IMG_CHANNELS = 3
HIGH_RES = 256                 # side length of the square training crop
RATIO = 8                      # super-resolution factor
LOW_RES = HIGH_RES // RATIO    # side length of the generator input

In [None]:
# Albumentations forwards `interpolation` to OpenCV, so it must be an OpenCV
# flag. PIL's Image.BICUBIC has the integer value 3, which OpenCV interprets
# as INTER_AREA; the value below is cv2.INTER_CUBIC, i.e. real bicubic.
_CV2_INTER_CUBIC = 2

# High-resolution target: normalised with mean 0.5 / std 0.5, which lines up
# with the generator's tanh output and the `* 0.5 + 0.5` applied before saving.
high_res_transform = A.Compose([
  A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
  ToTensorV2(),
])

# Low-resolution input: bicubic downscale of the same crop to LOW_RES x LOW_RES,
# normalised with mean 0 / std 1.
low_res_transform = A.Compose([
  A.Resize(width=LOW_RES, height=LOW_RES, interpolation=_CV2_INTER_CUBIC),
  A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
  ToTensorV2(),
])

# Augmentations applied once to the source image, before it is split into the
# high-res / low-res pair, so both views share the same crop and orientation.
both_transforms = A.Compose([
  A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
  A.HorizontalFlip(p=0.5),
  A.RandomRotate90(p=0.5),
])

In [None]:
def gradient_penalty(critic, real, fake, device):
    """Return the mean squared deviation from 1 of the critic's gradient norm.

    The gradient is evaluated at random per-sample blends of ``real`` and the
    detached ``fake`` batch.

    Args:
        critic: callable mapping an image batch to per-sample scores.
        real: tensor of shape (N, C, H, W).
        fake: tensor broadcast-compatible with ``real``; detached before mixing.
        device: device the mixing coefficients are moved to.

    Returns:
        Scalar tensor ``mean((||grad||_2 - 1) ** 2)``; it stays differentiable
        because the gradient is taken with ``create_graph=True``.
    """
    n_samples, n_channels, height, width = real.shape

    # One uniform coefficient per sample, replicated over every pixel.
    mix = torch.rand((n_samples, 1, 1, 1)).repeat(1, n_channels, height, width).to(device)
    blended = real * mix + fake.detach() * (1 - mix)
    blended.requires_grad_(True)

    scores = critic(blended)

    # d(scores)/d(blended); create_graph keeps the result usable inside a loss.
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )

    per_sample_norm = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return torch.mean((per_sample_norm - 1) ** 2)


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The file holds a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"``.
    """
    print("=> Saving checkpoint")
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore ``model`` and ``optimizer`` in place from ``checkpoint_file``.

    The file must hold a dict with the keys ``"state_dict"`` and
    ``"optimizer"``. Tensors are mapped onto ``DEVICE`` while loading.

    Args:
        checkpoint_file: path (or file-like object) accepted by ``torch.load``.
        model: module whose parameters are overwritten.
        optimizer: optimizer whose state is overwritten.
        lr: learning rate written into every param group after loading.
    """
    print("=> Loading checkpoint")
    saved = torch.load(checkpoint_file, map_location=DEVICE)

    model.load_state_dict(saved["state_dict"])
    optimizer.load_state_dict(saved["optimizer"])

    # Loading the optimizer state also restores the checkpoint's learning rate,
    # so the requested one has to be written back explicitly.
    for group in optimizer.param_groups:
        group["lr"] = lr


def plot_examples(low_res_folder, gen):
    """Upscale up to ten random images from ``low_res_folder`` into ``saved/``.

    Existing files in ``saved/`` are removed first (the folder is created when
    missing). The generator is put in eval mode for inference and is always
    switched back to train mode afterwards.

    Args:
        low_res_folder: directory containing the low-resolution input images.
        gen: generator network; its output is rescaled with ``* 0.5 + 0.5``
            before being written with ``save_image``.
    """
    out_dir = "saved"
    # Portable stand-in for `rm saved/*` that also works on the very first call,
    # when the folder does not exist yet.
    os.makedirs(out_dir, exist_ok=True)
    for stale in os.listdir(out_dir):
        stale_path = os.path.join(out_dir, stale)
        if os.path.isfile(stale_path):
            os.remove(stale_path)

    # Same normalisation as the low-res training input, without the resize:
    # these images are fed to the generator at their stored size.
    to_model_input = A.Compose([
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ])

    files = os.listdir(low_res_folder)
    np.random.shuffle(files)
    gen.eval()
    try:
        for file in files[:10]:
            # Close the file handle promptly and force three channels.
            with Image.open(os.path.join(low_res_folder, file)) as handle:
                pixels = np.asarray(handle.convert("RGB"))
            try:
                with torch.no_grad():
                    upscaled_img = gen(
                        to_model_input(image=pixels)["image"]
                        .unsqueeze(0)
                        .to(DEVICE)
                    )
                save_image(upscaled_img * 0.5 + 0.5, f"saved/{file}")
            except RuntimeError as err:
                # Only allocation failures are skipped; anything else is a real
                # bug and must not be hidden behind the memory message.
                if "memory" not in str(err).lower():
                    raise
                print('Memory insufficient for that image')
    finally:
        gen.train()

def compress(img: np.ndarray, ratio:int) -> np.ndarray:
    """Downscale ``img`` by an integer ``ratio`` using bicubic interpolation.

    Args:
        img: image array whose first two axes are (height, width).
        ratio: integer divisor applied to both height and width (floor division).

    Returns:
        The resized image array produced by albumentations' ``Resize``.
    """
    # Albumentations expects an OpenCV interpolation flag. PIL's Image.BICUBIC
    # equals 3, which OpenCV reads as INTER_AREA; 2 is cv2.INTER_CUBIC.
    inter_cubic = 2
    height, width = img.shape[0], img.shape[1]
    resized = Resize(
        width=width // ratio,
        height=height // ratio,
        interpolation=inter_cubic,
    )(image=img)["image"]
    return resized

In [None]:
class MyDataset(Dataset):
  """Image-folder dataset yielding ``(high_res, low_res)`` tensor pairs.

  Every regular file directly inside ``root_dir`` is treated as an image.
  Each access applies the shared augmentations once and then derives both
  views from the same augmented image via the module-level transforms.
  """

  def __init__(self, root_dir):
    super().__init__()
    self.root_dir = root_dir
    # Sorted so index -> file is stable across runs and filesystems;
    # sub-directories are skipped because they cannot be opened as images.
    candidates = (os.path.join(root_dir, name) for name in os.listdir(root_dir))
    self.data = sorted(path for path in candidates if os.path.isfile(path))

  def __len__(self):
    """Number of image files found in ``root_dir``."""
    return len(self.data)

  def __getitem__(self, index):
    """Return ``(high_res, low_res)`` for the image at ``index``."""
    # The context manager closes the file handle right after decoding;
    # converting to RGB guarantees the 3 channels the transforms normalise.
    with Image.open(self.data[index]) as handle:
      image = np.array(handle.convert('RGB'))
    image = both_transforms(image=image)['image']
    high_res = high_res_transform(image=image)['image']
    low_res = low_res_transform(image=image)['image']
    return high_res, low_res

In [None]:
class ConvBlock(nn.Module):
  """Conv2d followed by optional BatchNorm2d and an optional activation.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels produced by the convolution.
    discriminator: pick LeakyReLU(0.2) when true, per-channel PReLU otherwise.
    use_act: apply the activation in ``forward``.
    use_bn: insert BatchNorm2d; the convolution has a bias only when this is false.
    **kwargs: forwarded to ``nn.Conv2d`` (kernel_size, stride, padding, ...).
  """

  def __init__(
    self,
    in_channels,
    out_channels,
    discriminator=False,
    use_act=True,
    use_bn=True,
    **kwargs,
  ):
    super().__init__()
    self.use_act = use_act
    self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=(not use_bn))

    if use_bn:
      self.bn = nn.BatchNorm2d(out_channels)
    else:
      self.bn = nn.Identity()

    # The activation module is always created, even when use_act is false.
    if discriminator:
      self.act = nn.LeakyReLU(0.2, inplace=True)
    else:
      self.act = nn.PReLU(num_parameters=out_channels)

  def forward(self, x):
    normalised = self.bn(self.cnn(x))
    return self.act(normalised) if self.use_act else normalised

class ResidualBlock(nn.Module):
  """Two 3x3 ConvBlocks with an identity skip connection.

  The second block has no activation; the block input is added to its output.
  Channel count and spatial size are preserved.
  """

  def __init__(self, in_channels):
    super().__init__()
    # 3x3, stride 1, padding 1 keeps the spatial size unchanged.
    same_size_conv = dict(kernel_size=3, stride=1, padding=1)
    self.block1 = ConvBlock(in_channels, in_channels, **same_size_conv)
    self.block2 = ConvBlock(in_channels, in_channels, use_act=False, **same_size_conv)

  def forward(self, x):
    residual = self.block2(self.block1(x))
    return residual + x

class UnsampleBlock(nn.Module):
  """Upscale a feature map by ``scale_factor`` with conv + PixelShuffle + PReLU.

  The convolution widens the channels by ``scale_factor ** 2``; PixelShuffle
  then folds them back into space, so the output has ``in_channels`` channels
  and ``scale_factor`` times the input height and width.
  """

  def __init__(self, in_channels, scale_factor):
    super().__init__()
    widened = in_channels * scale_factor**2
    self.conv = nn.Conv2d(in_channels, widened, kernel_size=3, stride=1, padding=1)
    self.ps = nn.PixelShuffle(upscale_factor=scale_factor)
    self.act = nn.PReLU(num_parameters=in_channels)

  def forward(self, x):
    return self.act(self.ps(self.conv(x)))

class Generator(nn.Module):
  """Residual super-resolution generator with tanh output.

  Pipeline: 9x9 conv -> ``num_blocks`` residual blocks -> 3x3 conv with a
  global skip connection -> ``log2(ratio)`` x2 upsampling stages -> 9x9 conv
  -> tanh.

  Args:
    in_channels: channels of the input (and output) image.
    num_channels: width of the internal feature maps.
    num_blocks: number of residual blocks.
    ratio: total upscaling factor; must be a power of two (1, 2, 4, 8, ...).

  Raises:
    ValueError: if ``ratio`` is not a power of two >= 1.
  """

  def __init__(self, in_channels=3, num_channels=64, num_blocks=16, ratio=4):
    super().__init__()
    # Each upsampling stage doubles the resolution, so only powers of two can
    # be built exactly. Truncating log2 would silently yield a smaller scale.
    num_upsamples = log2(ratio)
    if num_upsamples < 0 or not num_upsamples.is_integer():
      raise ValueError(f"ratio must be a power of two >= 1, got {ratio!r}")

    self.initial = ConvBlock(
      in_channels, 
      num_channels, 
      kernel_size=9, 
      stride=1, 
      padding=4, 
      use_bn=False
    )
    self.residuals = nn.Sequential(
      *[ResidualBlock(num_channels) for _ in range(num_blocks)]
    )
    self.convblock = ConvBlock(
      num_channels, 
      num_channels, 
      kernel_size=3, 
      stride=1, 
      padding=1, 
      use_act=False
    )
    self.unsamples = nn.Sequential(
      *[UnsampleBlock(num_channels, 2) for _ in range(int(num_upsamples))]
    )
    self.final = nn.Conv2d(
      num_channels,
      in_channels,
      kernel_size=9,
      stride=1,
      padding=4
    )
  
  def forward(self, x):
    initial = self.initial(x)
    x = self.residuals(initial)
    # Global skip: add the shallow features back after the residual trunk.
    x = self.convblock(x) + initial
    x = self.unsamples(x)
    return torch.tanh(self.final(x))

class Discriminator(nn.Module):
  """Convolutional discriminator producing one raw logit per image.

  Args:
    in_channels: channels of the input image.
    features: output channels of each successive 3x3 ConvBlock. Blocks at odd
      positions use stride 2; the first block has no batch norm.

  The classifier head pools to 6x6, flattens and maps to a single
  un-activated score.
  """

  def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
    super().__init__()
    blocks = []
    for idx, feature in enumerate(features):
      blocks.append(
        ConvBlock(
          in_channels, 
          feature, 
          kernel_size=3, 
          stride=(1 + idx%2), 
          padding=1, 
          discriminator=True,
          use_act=True,
          use_bn=(idx != 0),  # explicit bool: no batch norm on the first block
        )
      )
      in_channels = feature

    # After the loop `in_channels` holds the width of the last conv block, so
    # the head matches whatever `features` was supplied (512 for the default).
    last_channels = in_channels

    self.blocks = nn.Sequential(*blocks)
    self.classifier = nn.Sequential(
      nn.AdaptiveAvgPool2d((6, 6)),
      nn.Flatten(),
      nn.Linear(last_channels*6*6, 1024),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(1024, 1),
    )

  def forward(self, x):
    x = self.blocks(x)
    return self.classifier(x)

In [None]:
class VGGLoss(nn.Module):
  """MSE between frozen VGG19 feature maps of two image batches.

  Uses the first 36 layers of the pretrained ``vgg19().features`` stack, kept
  in eval mode on ``DEVICE`` with gradients disabled for its weights.
  """

  def __init__(self):
    super().__init__()
    feature_extractor = vgg19(pretrained=True).features[:36]
    self.vgg = feature_extractor.eval().to(DEVICE)
    self.loss = nn.MSELoss()

    # The extractor is a fixed feature space, never a trainable component.
    for weight in self.vgg.parameters():
      weight.requires_grad = False

  def forward(self, input, target):
    """Return the MSE between VGG features of ``input`` and ``target``."""
    input_features = self.vgg(input)
    target_features = self.vgg(target)
    return self.loss(input_features, target_features)

In [None]:
# Training data: shuffled mini-batches drawn from the high-resolution folder.
dataset = MyDataset(root_dir=TRAIN_PATH)
loader = DataLoader(
  dataset,
  shuffle=True,
  batch_size=BATCH_SIZE,
  num_workers=NUM_WORKERS,
  pin_memory=True,
)

# Networks, created generator-first and moved to the selected device.
gen = Generator(in_channels=3, ratio=RATIO).to(DEVICE)
disc = Discriminator(in_channels=3).to(DEVICE)

# Both networks use Adam with identical hyper-parameters.
adam_settings = dict(lr=LEARNING_RATE, betas=(0.9, 0.999))
opt_gen = optim.Adam(gen.parameters(), **adam_settings)
opt_disc = optim.Adam(disc.parameters(), **adam_settings)

# Loss terms: pixel-wise, adversarial (on raw logits) and VGG-feature based.
mse = nn.MSELoss()
bce = nn.BCEWithLogitsLoss()
vgg = VGGLoss()

In [None]:
def train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg):
  """Run one epoch of alternating discriminator / generator updates.

  Args:
    loader: iterable of ``(high_res, low_res)`` batches, in that order (the
      order ``MyDataset.__getitem__`` returns them).
    disc: discriminator network returning raw logits.
    gen: generator network mapping low-res batches to high-res predictions.
    opt_gen: optimizer for ``gen``.
    opt_disc: optimizer for ``disc``.
    mse: pixel-wise loss between prediction and target.
    bce: adversarial loss operating on logits.
    vgg: feature-space loss between prediction and target.

  Every 200 batches (including the first) example upscales are written with
  ``plot_examples`` and the current losses are printed.
  """
  loop = tqdm(loader, leave=True)

  # Unpack in the order the dataset yields the pair. Swapping the names would
  # feed the high-res target to the generator and compare mismatched shapes.
  for idx, (high_res, low_res) in enumerate(loop):
    high_res = high_res.to(DEVICE)
    low_res = low_res.to(DEVICE)

    # ---- Discriminator step ------------------------------------------------
    fake = gen(low_res)
    disc_real = disc(high_res)
    # Detach so this step does not backpropagate into the generator.
    disc_fake = disc(fake.detach())
    # One-sided label smoothing: real targets are drawn from (0.9, 1.0].
    disc_loss_real = bce(disc_real, torch.ones_like(disc_real) - 0.1*torch.rand_like(disc_real))
    disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))
    disc_loss = disc_loss_fake + disc_loss_real

    opt_disc.zero_grad()
    disc_loss.backward()
    opt_disc.step()

    # ---- Generator step ----------------------------------------------------
    # Re-score the fakes with the freshly updated discriminator, this time
    # keeping the graph back to the generator.
    disc_fake = disc(fake)
    l2_loss = mse(fake, high_res)
    adversarial_loss = bce(disc_fake, torch.ones_like(disc_fake))
    loss_for_vgg = vgg(fake, high_res)
    gen_loss = 6e-2*loss_for_vgg + 1e-2*adversarial_loss + 0.92*l2_loss

    opt_gen.zero_grad()
    gen_loss.backward()
    opt_gen.step()

    if not idx % 200:
      plot_examples(VAL_PATH, gen)
      print(f'Discrimantor loss:{disc_loss}')
      print(f'Generative loss:{gen_loss}')

In [None]:
# Optionally resume both networks (and their optimizers) from disk, resetting
# the learning rate to the configured value.
if LOAD_MODEL:
  load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
  load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)

# START_EPOCHS is 1-based; `epoch` is the matching 0-based loop index.
first_epoch_index = START_EPOCHS - 1
for epoch in range(first_epoch_index, NUM_EPOCHS):
  print(f'====================== EPOCH: {epoch+1} =====================')
  train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg)

  if not SAVE_MODEL:
    continue

  # Overwrite the checkpoints after every completed epoch.
  save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
  save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)
