In [1]:
%%capture
# Install the Kaggle command-line client used below to download the dataset.
# The %%capture cell magic above hides pip's output for this cell.
!pip install kaggle

In [2]:
# Colab-only: open the browser upload widget. The shell commands below expect
# the uploaded file to be named kaggle.json.
from google.colab import files
files.upload()

# Copy the uploaded credentials into ~/.kaggle (presumably where the kaggle CLI
# looks for them) and restrict the file to owner read/write.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

# Download the DIV2K archive into the current working directory.
!kaggle datasets download -d joe1995/div2k-dataset

from zipfile import ZipFile
dataset_path = './div2k-dataset.zip'

# Unpack the whole archive under ./dataset and print a marker once finished.
with ZipFile(dataset_path, 'r') as z:
    z.extractall(path='./dataset')
    print("Done")

Saving kaggle.json to kaggle.json
Downloading div2k-dataset.zip to /content
100% 3.71G/3.71G [02:48<00:00, 24.0MB/s]
100% 3.71G/3.71G [02:48<00:00, 23.6MB/s]
Done


In [3]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision.models import vgg19
from torch.utils.data import DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import os
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

In [4]:
class convBlock(nn.Module):
    """A 2-D convolution optionally followed by LeakyReLU(0.2).

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output feature channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero padding added on each side.
        use_act: when truthy the conv output goes through an in-place
            LeakyReLU with slope 0.2; otherwise it is returned unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_act):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        if use_act:
            self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply the convolution, then the (possibly identity) activation."""
        features = self.conv(x)
        return self.act(features)

In [5]:
class upSampleBlock(nn.Module):
    """Nearest-neighbour upsampling followed by a 3x3 conv and LeakyReLU(0.2).

    Args:
        channels: channel count; the convolution keeps it unchanged.
        scale_factor: spatial enlargement factor passed to ``nn.Upsample``.
    """

    def __init__(self, channels, scale_factor=2):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale_factor, mode="nearest")
        self.conv = nn.Conv2d(channels, channels, 3, 1, 1, bias=True)
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Enlarge ``x`` spatially, then refine it with conv + activation."""
        enlarged = self.upsample(x)
        refined = self.conv(enlarged)
        return self.act(refined)

In [6]:
class denseResidualBlock(nn.Module):
    """Five densely connected 3x3 conv layers with a scaled residual connection.

    Layer ``i`` receives the block input concatenated (along channels) with the
    outputs of all previous layers. The first four layers emit ``channels``
    feature maps with an activation; the fifth maps back to ``in_channels``
    without one, so its output can be added to the block input.

    Args:
        in_channels: channels of the block input (and output).
        channels: channels produced by each of the first four layers.
        residual_beta: factor applied to the last layer's output before it is
            added to the input.
    """

    def __init__(self, in_channels, channels=32, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta

        last_idx = 4
        layers = [
            convBlock(
                in_channels + channels * depth,
                in_channels if depth == last_idx else channels,
                kernel_size=3,
                stride=1,
                padding=1,
                use_act=(depth != last_idx),
            )
            for depth in range(5)
        ]
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Return ``x + residual_beta * last_layer_output``."""
        collected = [x]
        for layer in self.blocks:
            # Every layer sees the input plus all earlier layer outputs.
            latest = layer(torch.cat(collected, dim=1))
            collected.append(latest)
        return x + self.residual_beta * latest

In [7]:
class residualInResidualDenseBlocks(nn.Module):
    """Three chained dense residual blocks wrapped in an outer scaled residual.

    Args:
        in_channels: channels of the input; the output has the same shape.
        residual_beta: scale used both inside each dense block and for the
            outer skip connection.
    """

    def __init__(self, in_channels, residual_beta=0.2):
        super().__init__()
        self.residual_beta = residual_beta
        self.rrdbs = nn.Sequential(
            *(denseResidualBlock(in_channels, residual_beta=residual_beta) for _ in range(3))
        )

    def forward(self, x):
        """Return ``x + residual_beta * rrdbs(x)``."""
        refined = self.rrdbs(x)
        return x + self.residual_beta * refined

In [8]:
class Generator(nn.Module):
    """Super-resolution generator that enlarges its input 4x per spatial side.

    Pipeline: 3x3 stem conv -> ``num_res_blocks`` residual-in-residual dense
    blocks -> 3x3 conv added back onto the stem output -> two 2x upsample
    blocks -> conv / LeakyReLU / conv head producing 3 output channels.

    Args:
        in_channels: channels of the low-resolution input image.
        channels: width of the internal feature maps.
        num_res_blocks: number of residual-in-residual dense blocks in the trunk.
    """

    def __init__(self, in_channels=3, channels=64, num_res_blocks=23):
        super().__init__()

        def conv3x3(n_in, n_out):
            # Shape-preserving 3x3 convolution used throughout the generator.
            return nn.Conv2d(n_in, n_out, kernel_size=3, stride=1, padding=1, bias=True)

        # Sub-modules are created in the same order as they are used in forward().
        self.initial = conv3x3(in_channels, channels)
        self.res_blocks = nn.Sequential(
            *[residualInResidualDenseBlocks(channels) for _ in range(num_res_blocks)]
        )
        self.conv = conv3x3(channels, channels)
        self.upscales = nn.Sequential(
            *[upSampleBlock(channels, scale_factor=2) for _ in range(2)]
        )
        self.final = nn.Sequential(
            conv3x3(channels, channels),
            nn.LeakyReLU(0.2, inplace=True),
            conv3x3(channels, 3),
        )

    def forward(self, x):
        """Map a low-resolution batch to its 4x-upscaled 3-channel output."""
        shallow = self.initial(x)
        deep = self.conv(self.res_blocks(shallow))
        # Long skip connection around the whole residual trunk.
        upscaled = self.upscales(shallow + deep)
        return self.final(upscaled)

In [9]:
def initialize_weights(model, scale=0.1):
    """Re-initialise every Conv2d / Linear weight with scaled Kaiming-normal values.

    Args:
        model: module whose sub-modules (including itself) are visited.
        scale: factor the freshly initialised weights are multiplied by.

    Biases and all other layer types are left untouched. The update is done
    in place; nothing is returned.
    """
    scaled_layer_types = (nn.Conv2d, nn.Linear)
    targets = [layer for layer in model.modules() if isinstance(layer, scaled_layer_types)]
    for layer in targets:
        weights = layer.weight.data
        nn.init.kaiming_normal_(weights)
        weights.mul_(scale)

In [10]:
# Smoke test: build a generator, rescale its weights and push a random
# (5, 3, 21, 64) batch through it; the displayed shape should be 4x larger in
# height and width.
gen = Generator()
initialize_weights(gen)
x = torch.randn((5, 3, 21, 64))
gen(x).shape

torch.Size([5, 3, 84, 256])

In [11]:
class Critic(nn.Module):
    """Convolutional critic that maps an image batch to one score per image.

    Architecture: a stride-1 stem conv, then one conv block per remaining entry
    of ``channels`` (alternating stride 2, 1, 2, ...), adaptive average pooling
    to 6x6, and a two-layer MLP head ending in a single unbounded output
    (no sigmoid).

    Args:
        in_channels: channels of the input images.
        channels: output width of every conv block, stem first. Any non-empty
            sequence is accepted; the head is sized from its last entry.

    Raises:
        ValueError: if ``channels`` is empty.
    """

    def __init__(self, in_channels=3, channels=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        # Immutable default + private copy: never share a mutable default list.
        channels = list(channels)
        if not channels:
            raise ValueError("channels must contain at least one entry")

        # NOTE(review): padding=3 with a 3x3 kernel enlarges each spatial side
        # by 4 pixels. It is kept so existing checkpoints compute the same
        # function; confirm whether padding=1 was intended.
        self.initial = convBlock(in_channels, channels[0], kernel_size=3, stride=1, padding=3, use_act=True)

        blocks = []
        prev_channels = channels[0]
        for block_idx, width in enumerate(channels[1:]):
            # Every other block (starting with the first) halves the resolution.
            stride = 2 if block_idx % 2 == 0 else 1
            blocks.append(
                convBlock(prev_channels, width, kernel_size=3, stride=stride, padding=1, use_act=True)
            )
            prev_channels = width

        self.blocks = nn.Sequential(*blocks)

        self.final = nn.Sequential(
            # Fixed 6x6 pooling makes the head independent of the input size.
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            # Sized from the last conv width instead of a hard-coded 512 so
            # custom ``channels`` configurations work.
            nn.Linear(channels[-1] * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of critic scores."""
        x = self.initial(x)
        x = self.blocks(x)
        x = self.final(x)
        return x

In [12]:
# Smoke test: a random (5, 3, 84, 256) batch should yield one score per image,
# i.e. a displayed shape of (5, 1).
critic = Critic()
x = torch.randn((5, 3, 84, 256))
critic(x).shape

torch.Size([5, 1])

In [13]:
class LowHighResDataset:
    """Map-style dataset yielding (low-res, high-res) views of the same image crop.

    Args:
        img_paths: sequence of image file paths.
        high_low_res_transform: callable applied first to the full image; must
            accept ``image=<ndarray>`` and return a mapping with an ``"image"``
            entry (the albumentations calling convention).
        low_res_transform: same convention; turns the shared result into the
            low-resolution sample.
        high_res_transform: same convention; turns the shared result into the
            high-resolution sample.
    """

    def __init__(self, img_paths, high_low_res_transform, low_res_transform, high_res_transform):
        super().__init__()
        self.img_paths = img_paths
        # Keep the transforms on the instance: the dataset must not depend on
        # same-named notebook globals being defined (and unchanged) at call time.
        self.high_low_res_transform = high_low_res_transform
        self.low_res_transform = low_res_transform
        self.high_res_transform = high_res_transform

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Load image ``index`` and return ``(img_low, img_high)``."""
        img_path = self.img_paths[index]
        # Context manager closes the file handle promptly; convert("RGB")
        # guarantees three channels even for grayscale / RGBA / palette files.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert("RGB"))

        # Both views are derived from the same randomly augmented crop.
        img = self.high_low_res_transform(image=img)["image"]
        img_low = self.low_res_transform(image=img)["image"]
        img_high = self.high_res_transform(image=img)["image"]

        return img_low, img_high

In [14]:
# Patch sizes: high-res crops are HIGH_RES x HIGH_RES, low-res inputs are 4x
# smaller per side.
HIGH_RES = 128
LOW_RES = HIGH_RES // 4
TRAIN_DIR = "/content/dataset/DIV2K_train_HR/DIV2K_train_HR"
BATCH_SIZE = 16
NUM_WORKERS = 0


# Applied first, to the full image: a random HIGH_RES crop plus random
# flip / 90-degree rotation. Its result feeds both transforms below.
high_low_res_transform = A.Compose(
    [
        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

# Produces the low-res network input from the augmented crop.
# NOTE(review): Image.BICUBIC is a PIL constant, whereas albumentations' Resize
# presumably expects an OpenCV interpolation flag (cv2.INTER_*). Confirm the
# integer value actually selects bicubic resampling.
low_res_transform = A.Compose(
    [
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=Image.BICUBIC),
        # mean 0 / std 1: no per-channel shift; presumably only the library's
        # default pixel-value rescaling is applied — TODO confirm.
        A.Normalize(
            mean=[0, 0, 0], 
            std=[1, 1, 1]
        ),
        ToTensorV2(),
    ]
)

# Produces the high-res target: same normalisation, no resizing.
high_res_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)
# Used by save_imgs on held-out images; same steps as high_res_transform.
test_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

# Every directory entry of TRAIN_DIR is treated as a training image
# (no extension filtering, order as returned by os.listdir).
train_img_paths = [os.path.join(TRAIN_DIR, img_name) for img_name in os.listdir(TRAIN_DIR)]
train_dataset = LowHighResDataset(train_img_paths, high_low_res_transform, low_res_transform, high_res_transform)
train_data_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [15]:
def save_checkpoint(model, optimizer, filename="model.pth.tar"):
    """Write the model and optimizer state dicts to ``filename``.

    The file holds a dict with the keys ``"state_dict"`` (model) and
    ``"optimizer"``, serialised with ``torch.save``.
    """
    print("=>Saving Checkpoint...")
    torch.save(
        dict(state_dict=model.state_dict(), optimizer=optimizer.state_dict()),
        filename,
    )

def load_checkpoint(checkpoint_path, model, optimizer, learning_rate):
    """Restore model and optimizer state from a file written by save_checkpoint.

    Args:
        checkpoint_path: path of the checkpoint file.
        model: module receiving the stored ``"state_dict"``.
        optimizer: optimizer receiving the stored ``"optimizer"`` state.
        learning_rate: value written into every param group after loading.

    Tensors are mapped onto the module-level ``DEVICE`` while loading.
    """
    print("=>Loading Checkpoint...")
    target_device = torch.device(DEVICE)
    saved_state = torch.load(checkpoint_path, map_location=target_device)
    model.load_state_dict(saved_state["state_dict"])
    optimizer.load_state_dict(saved_state["optimizer"])

    # The restored optimizer state carries the learning rate it was saved with;
    # overwrite it with the requested one.
    for group in optimizer.param_groups:
        group["lr"] = learning_rate

In [16]:
def save_imgs(gen, test_transform, test_dir, dir="saved_imgs/"):
    """Run the generator on every image in ``test_dir`` and save the results.

    Args:
        gen: generator module; it is switched to eval mode and left there.
        test_transform: callable taking ``image=<ndarray>`` and returning a
            mapping whose ``"image"`` entry is a CHW tensor.
        test_dir: directory whose entries are all treated as input images.
        dir: output directory (created if missing). Each result is written
            under its input's file name. A trailing slash is optional.

    Reads the module-level ``DEVICE`` to place the input batch.
    """
    gen.eval()
    # Create the output folder up front so the first save cannot fail on a
    # fresh runtime. An empty string means "current directory".
    if dir:
        os.makedirs(dir, exist_ok=True)

    # Sorted for a deterministic processing order.
    for img_name in sorted(os.listdir(test_dir)):
        # Close the file promptly and force three channels for non-RGB inputs.
        with Image.open(os.path.join(test_dir, img_name)) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        img = test_transform(image=img)["image"]
        img_batch = img.unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            fake_img = gen(img_batch)
        # os.path.join works with or without a trailing separator on ``dir``.
        torchvision.utils.save_image(fake_img, os.path.join(dir, img_name))

In [17]:
def gradientPenalty(critic, real_imgs, fake_imgs):
    """Gradient penalty on random interpolations of real and fake images.

    For each sample a weight ``epsilon`` in [0, 1) mixes the real and fake
    image. The critic is evaluated on the mixtures, and the penalty is the
    batch mean of ``(||d critic / d mixture||_2 - 1) ** 2``.

    Args:
        critic: callable mapping an image batch to per-sample scores.
        real_imgs: batch of real images, batch dimension first.
        fake_imgs: batch of generated images with the same shape and device.

    Returns:
        Scalar tensor. It stays attached to the autograd graph
        (``create_graph=True``) so it can be back-propagated into the critic.
    """
    batch_size = real_imgs.shape[0]
    # One mixing weight per sample, created directly on the images' device
    # (no dependency on a global device constant). The singleton trailing
    # dimensions broadcast over every non-batch dimension, so no repeat is
    # needed. On CUDA this draws from the CUDA generator, not the CPU one.
    epsilon_shape = (batch_size,) + (1,) * (real_imgs.dim() - 1)
    epsilon = torch.rand(epsilon_shape, device=real_imgs.device)

    interpolated_imgs = epsilon * real_imgs + (1 - epsilon) * fake_imgs
    interpolated_imgs.requires_grad_(True)
    outputs_interpolated_imgs = critic(interpolated_imgs)

    (gradient,) = torch.autograd.grad(
        outputs=outputs_interpolated_imgs,
        inputs=interpolated_imgs,
        grad_outputs=torch.ones_like(outputs_interpolated_imgs),
        create_graph=True,
        retain_graph=True,
    )
    # reshape (not view) also handles gradients that come back non-contiguous.
    gradient = gradient.reshape(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    return torch.mean((gradient_norm - 1) ** 2)

In [18]:
def train(train_data_loader, gen, critic, opt_gen, opt_critic, l1_loss, vgg_loss, gen_scaler, critic_scaler, epoch):
    """Run one epoch of alternating critic / generator updates.

    For every batch the critic is updated first (loss of the WGAN-GP form:
    negated score gap plus a weighted gradient penalty), then the generator
    (weighted adversarial + weighted L1 + VGG feature loss). Both updates run
    under CUDA autocast and use their own gradient scaler.

    Reads the module-level names DEVICE, LAMBDA_GP, ADVERSARIAL_LAMBDA,
    L1_LAMBDA and EPOCHS, which must be defined before this is called.

    Args:
        train_data_loader: yields (low_res_imgs, high_res_imgs) batches.
        gen: generator module.
        critic: critic module.
        opt_gen: optimizer for the generator.
        opt_critic: optimizer for the critic.
        l1_loss: callable computing the pixel-wise loss between fake and real.
        vgg_loss: callable computing the feature loss between fake and real.
        gen_scaler: gradient scaler used for the generator step.
        critic_scaler: gradient scaler used for the critic step.
        epoch: current epoch index, used only for the progress-bar label.

    Returns:
        None.
    """
    gen.train()
    critic.train()

    train_progress_bar = tqdm(enumerate(train_data_loader), total=len(train_data_loader), leave=True)
    for idx, (low_res_imgs, high_res_imgs) in train_progress_bar:
        low_res_imgs = low_res_imgs.to(DEVICE)
        high_res_imgs = high_res_imgs.to(DEVICE)

        # ---- Critic update ----
        with torch.cuda.amp.autocast():
            fake_imgs = gen(low_res_imgs)
            outputs_real_imgs = critic(high_res_imgs)
            # detach(): the critic loss must not back-propagate into the generator.
            outputs_fake_imgs = critic(fake_imgs.detach())

            gp = gradientPenalty(critic, high_res_imgs, fake_imgs.detach())
            # Minimising this raises real scores, lowers fake scores, and
            # penalises the gradient-norm term weighted by LAMBDA_GP.
            loss_critic = -(torch.mean(outputs_real_imgs) - torch.mean(outputs_fake_imgs)) + LAMBDA_GP*gp
        
        critic.zero_grad()
        critic_scaler.scale(loss_critic).backward()
        critic_scaler.step(opt_critic)
        critic_scaler.update()

        # ---- Generator update ----
        # fake_imgs (not detached) is reused from above, so gradients reach the
        # generator through the just-updated critic.
        with torch.cuda.amp.autocast():
            outputs_fake_imgs = critic(fake_imgs).view(-1)
            loss_gen_1 = -torch.mean(outputs_fake_imgs) * ADVERSARIAL_LAMBDA
            loss_gen_2 = l1_loss(fake_imgs, high_res_imgs) * L1_LAMBDA
            loss_gen_3 = vgg_loss(fake_imgs, high_res_imgs)
            loss_gen = loss_gen_1 + loss_gen_2 + loss_gen_3

        # Only the generator's optimizer is zeroed and stepped here; gradients
        # this backward leaves on the critic are cleared by critic.zero_grad()
        # in the next iteration.
        opt_gen.zero_grad()
        gen_scaler.scale(loss_gen).backward()
        gen_scaler.step(opt_gen)
        gen_scaler.update()

        # Show the latest per-batch losses on the progress bar.
        train_progress_bar.set_description(f"Epoch [{epoch}/{EPOCHS-1}]")
        train_progress_bar.set_postfix(critic_loss=loss_critic.item(), gen_loss=loss_gen.item())
        
    train_progress_bar.close()

In [19]:
class VGGLoss(nn.Module):
    """Feature-space MSE using the first 35 layers of a pretrained VGG-19.

    The feature extractor is put in eval mode, moved to the module-level
    ``DEVICE`` and frozen (no parameter requires gradients). Gradients still
    flow through it to whatever produced ``inputs``.
    """

    def __init__(self):
        super().__init__()
        extractor = vgg19(pretrained=True).features[:35]
        self.vgg = extractor.eval().to(DEVICE)

        # Freeze the backbone: it is a fixed feature extractor, never trained.
        for weight in self.vgg.parameters():
            weight.requires_grad = False

        self.mse = nn.MSELoss()

    def forward(self, inputs, targets):
        """Return the MSE between VGG features of ``inputs`` and ``targets``."""
        return self.mse(self.vgg(inputs), self.vgg(targets))

In [26]:
# ---- Run configuration ----
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CHECKPOINT_GEN = "gen.pth.tar"
CHECKPOINT_CRITIC = "critic.pth.tar"
# Resume from the files above when they exist. On a fresh runtime they do not,
# and training simply starts from scratch instead of failing.
LOAD_GEN = True
LOAD_CRITIC = True

LEARNING_RATE = 1e-4
EPOCHS = 3

# Loss weights read by train(): gradient penalty, adversarial term, L1 term.
LAMBDA_GP = 10
ADVERSARIAL_LAMBDA = 5e-3
L1_LAMBDA = 1e-2

# Folder of held-out images rendered after every epoch (optional).
TEST_IMG_DIR = "test_imgs"

# ---- Models, losses, optimizers ----
gen = Generator().to(DEVICE)
critic = Critic().to(DEVICE)
# Only the generator gets the scaled initialisation; the critic keeps
# PyTorch's default initialisation.
initialize_weights(gen)

l1_loss = nn.L1Loss()
vgg_loss = VGGLoss()

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.999))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.999))

gen_scaler = torch.cuda.amp.GradScaler()
critic_scaler = torch.cuda.amp.GradScaler()

# ---- Optional resume ----
if LOAD_GEN:
    if os.path.isfile(CHECKPOINT_GEN):
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
    else:
        print(f"=>No checkpoint found at {CHECKPOINT_GEN}; training generator from scratch")
if LOAD_CRITIC:
    if os.path.isfile(CHECKPOINT_CRITIC):
        load_checkpoint(CHECKPOINT_CRITIC, critic, opt_critic, LEARNING_RATE)
    else:
        print(f"=>No checkpoint found at {CHECKPOINT_CRITIC}; training critic from scratch")

# ---- Training loop: checkpoint and render samples after every epoch ----
for epoch in range(EPOCHS):
    train(train_data_loader, gen, critic, opt_gen, opt_critic, l1_loss, vgg_loss, gen_scaler, critic_scaler, epoch)

    save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
    save_checkpoint(critic, opt_critic, filename=CHECKPOINT_CRITIC)
    if os.path.isdir(TEST_IMG_DIR):
        save_imgs(gen, test_transform, test_dir=TEST_IMG_DIR)
    else:
        # Do not lose a finished epoch to a missing sample folder.
        print(f"=>Skipping sample images: {TEST_IMG_DIR} not found")

=>Loading Checkpoint...
=>Loading Checkpoint...


Epoch [0/2]: 100%|██████████| 50/50 [01:50<00:00,  2.20s/it, critic_loss=2.76, gen_loss=2.27]


=>Saving Checkpoint...
=>Saving Checkpoint...


Epoch [1/2]: 100%|██████████| 50/50 [01:49<00:00,  2.19s/it, critic_loss=-30, gen_loss=3.29]


=>Saving Checkpoint...
=>Saving Checkpoint...


Epoch [2/2]: 100%|██████████| 50/50 [01:50<00:00,  2.21s/it, critic_loss=-17.8, gen_loss=2.66]


=>Saving Checkpoint...
=>Saving Checkpoint...
