In [1]:
%%capture
!pip install kaggle 

In [2]:
from google.colab import files
files.upload()

!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json

!kaggle datasets download -d joe1995/div2k-dataset

from zipfile import ZipFile
dataset_path = './div2k-dataset.zip'

with ZipFile(dataset_path, 'r') as z:
    z.extractall(path='./dataset')
    print("Done")

Saving kaggle.json to kaggle.json
Downloading div2k-dataset.zip to /content
100% 3.71G/3.71G [02:48<00:00, 24.0MB/s]
100% 3.71G/3.71G [02:48<00:00, 23.6MB/s]
Done


In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision.models import vgg19
from torch.utils.data import DataLoader

import albumentations as A
from albumentations.pytorch import ToTensorV2

import os
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
class _ResidualDenseBlock(nn.Module):
    """Achieves densely connected convolutional layers.
    `Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993v5.pdf>` paper.
    Args:
        channels (int): The number of channels in the input image.
        growth_channels (int): The number of channels that increase in each layer of convolution.
    """

    def __init__(self, channels: int, growth_channels: int) -> None:
        super(_ResidualDenseBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels + growth_channels * 0, growth_channels, (3, 3), (1, 1), (1, 1))
        self.conv2 = nn.Conv2d(channels + growth_channels * 1, growth_channels, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(channels + growth_channels * 2, growth_channels, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(channels + growth_channels * 3, growth_channels, (3, 3), (1, 1), (1, 1))
        self.conv5 = nn.Conv2d(channels + growth_channels * 4, channels, (3, 3), (1, 1), (1, 1))

        self.leaky_relu = nn.LeakyReLU(0.2, True)
        self.identity = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out1 = self.leaky_relu(self.conv1(x))
        out2 = self.leaky_relu(self.conv2(torch.cat([x, out1], 1)))
        out3 = self.leaky_relu(self.conv3(torch.cat([x, out1, out2], 1)))
        out4 = self.leaky_relu(self.conv4(torch.cat([x, out1, out2, out3], 1)))
        out5 = self.identity(self.conv5(torch.cat([x, out1, out2, out3, out4], 1)))
        out = torch.mul(out5, 0.2)
        out = torch.add(out, identity)

        return out


class _ResidualResidualDenseBlock(nn.Module):
    """Multi-layer residual dense convolution block.
    Args:
        channels (int): The number of channels in the input image.
        growth_channels (int): The number of channels that increase in each layer of convolution.
    """

    def __init__(self, channels: int, growth_channels: int) -> None:
        super(_ResidualResidualDenseBlock, self).__init__()
        self.rdb1 = _ResidualDenseBlock(channels, growth_channels)
        self.rdb2 = _ResidualDenseBlock(channels, growth_channels)
        self.rdb3 = _ResidualDenseBlock(channels, growth_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        identity = x

        out = self.rdb1(x)
        out = self.rdb2(out)
        out = self.rdb3(out)
        out = torch.mul(out, 0.2)
        out = torch.add(out, identity)

        return out


class Discriminator(nn.Module):
    def __init__(self) -> None:
        super(Discriminator, self).__init__()
        self.features = nn.Sequential(
            # input size. (3) x 128 x 128
            nn.Conv2d(3, 64, (3, 3), (1, 1), (1, 1), bias=True),
            nn.LeakyReLU(0.2, True),
            # state size. (64) x 64 x 64
            nn.Conv2d(64, 64, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            # state size. (128) x 32 x 32
            nn.Conv2d(128, 128, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            # state size. (256) x 16 x 16
            nn.Conv2d(256, 256, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # state size. (512) x 8 x 8
            nn.Conv2d(512, 512, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 512, (3, 3), (1, 1), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            # state size. (512) x 4 x 4
            nn.Conv2d(512, 512, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True)
        )

        self.classifier = nn.Sequential(
            nn.Linear(512 * 4 * 4, 100),
            nn.LeakyReLU(0.2, True),
            nn.Linear(100, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.features(x)
        out = torch.flatten(out, 1)
        out = self.classifier(out)

        return out


class RRDBNet(nn.Module):
    def __init__(
            self,
            in_channels: int = 3,
            out_channels: int = 3,
            channels: int = 64,
            growth_channels: int = 32,
            num_blocks: int = 23,
            upscale_factor: int = 4,
    ) -> None:
        super(RRDBNet, self).__init__()
        self.upscale_factor = upscale_factor

        # The first layer of convolutional layer.
        self.conv1 = nn.Conv2d(in_channels, channels, (3, 3), (1, 1), (1, 1))

        # Feature extraction backbone network.
        trunk = []
        for _ in range(num_blocks):
            trunk.append(_ResidualResidualDenseBlock(channels, growth_channels))
        self.trunk = nn.Sequential(*trunk)

        # After the feature extraction network, reconnect a layer of convolutional blocks.
        self.conv2 = nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1))

        # Upsampling convolutional layer.
        if upscale_factor == 2:
            self.upsampling1 = nn.Sequential(
                nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
                nn.LeakyReLU(0.2, True)
            )
        if upscale_factor == 4:
            self.upsampling1 = nn.Sequential(
                nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
                nn.LeakyReLU(0.2, True)
            )
            self.upsampling2 = nn.Sequential(
                nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
                nn.LeakyReLU(0.2, True)
            )
        if upscale_factor == 8:
            self.upsampling1 = nn.Sequential(
                nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
                nn.LeakyReLU(0.2, True)
            )
            self.upsampling2 = nn.Sequential(
                nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
                nn.LeakyReLU(0.2, True)
            )
            self.upsampling3 = nn.Sequential(
                nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
                nn.LeakyReLU(0.2, True)
            )

        # Reconnect a layer of convolution block after upsampling.
        self.conv3 = nn.Sequential(
            nn.Conv2d(channels, channels, (3, 3), (1, 1), (1, 1)),
            nn.LeakyReLU(0.2, True)
        )

        # Output layer.
        self.conv4 = nn.Conv2d(channels, out_channels, (3, 3), (1, 1), (1, 1))

        # Initialize all layer
        self._initialize_weights()

    # The model should be defined in the Torch.script method.
    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        out1 = self.conv1(x)
        out = self.trunk(out1)
        out2 = self.conv2(out)
        out = torch.add(out1, out2)

        if self.upscale_factor == 2:
            out = self.upsampling1(F.interpolate(out, scale_factor=2, mode="nearest"))
        if self.upscale_factor == 4:
            out = self.upsampling1(F.interpolate(out, scale_factor=2, mode="nearest"))
            out = self.upsampling2(F.interpolate(out, scale_factor=2, mode="nearest"))
        if self.upscale_factor == 8:
            out = self.upsampling1(F.interpolate(out, scale_factor=2, mode="nearest"))
            out = self.upsampling2(F.interpolate(out, scale_factor=2, mode="nearest"))
            out = self.upsampling3(F.interpolate(out, scale_factor=2, mode="nearest"))

        out = self.conv3(out)
        out = self.conv4(out)

        out = torch.clamp_(out, 0.0, 1.0)

        return out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_impl(x)

    def _initialize_weights(self) -> None:
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight)
                module.weight.data *= 0.1
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

In [13]:
class LowHighResDataset:
    def __init__(self, img_paths, high_low_res_transform, low_res_transform, high_res_transform):
        super().__init__()
        self.img_paths = img_paths

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index):
        img_path = self.img_paths[index]
        img = np.array(Image.open(img_path))

        img = high_low_res_transform(image=img)["image"]
        img_low = low_res_transform(image=img)["image"]
        img_high = high_res_transform(image=img)["image"]

        return img_low, img_high

In [14]:
HIGH_RES = 128
LOW_RES = HIGH_RES // 4
TRAIN_DIR = "/content/dataset/DIV2K_train_HR/DIV2K_train_HR"
BATCH_SIZE = 16
NUM_WORKERS = 0


high_low_res_transform = A.Compose(
    [
        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

low_res_transform = A.Compose(
    [
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=Image.BICUBIC),
        A.Normalize(
            mean=[0, 0, 0], 
            std=[1, 1, 1]
        ),
        ToTensorV2(),
    ]
)

high_res_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)
test_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

train_img_paths = [os.path.join(TRAIN_DIR, img_name) for img_name in os.listdir(TRAIN_DIR)]
train_dataset = LowHighResDataset(train_img_paths, high_low_res_transform, low_res_transform, high_res_transform)
train_data_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
)

In [15]:
def save_checkpoint(model, optimizer, filename="model.pth.tar"):
    print("=>Saving Checkpoint...")
    checkpoint = {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()}
    torch.save(checkpoint, filename)

def load_checkpoint(checkpoint_path, model, optimizer, learning_rate):
    print("=>Loading Checkpoint...")
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(DEVICE))
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    for param_group in optimizer.param_groups:
        param_group["lr"] = learning_rate

In [16]:
def save_imgs(gen, test_transform, test_dir, dir="saved_imgs/"):
    gen.eval()
    img_names = os.listdir(test_dir)
    for img_name in img_names:
        img = np.array(Image.open(os.path.join(test_dir, img_name)))
        img = test_transform(image=img)["image"]
        img_batch = img.unsqueeze(0).to(DEVICE)
        with torch.no_grad():
            fake_img = gen(img_batch)
        torchvision.utils.save_image(fake_img, f"{dir}{img_name}")

In [17]:
def gradientPenalty(critic, real_imgs, fake_imgs):
    batch_size, channels, height, width = real_imgs.shape
    epsilon = torch.rand((batch_size, 1, 1, 1)).repeat(1, channels, height, width).to(DEVICE)

    interpolated_imgs = epsilon*real_imgs + (1-epsilon)*fake_imgs
    interpolated_imgs.requires_grad_(True)
    outputs_interpolated_imgs = critic(interpolated_imgs)

    gradient = torch.autograd.grad(
        outputs=outputs_interpolated_imgs,
        inputs=interpolated_imgs,
        grad_outputs=torch.ones_like(outputs_interpolated_imgs),
        create_graph=True,
        retain_graph=True,
    )[0]
    gradient = gradient.view(gradient.shape[0], -1)
    gradient_norm = gradient.norm(2, dim=1)
    gradient_penalty = torch.mean((gradient_norm-1)**2)
    return gradient_penalty

In [18]:
def train(train_data_loader, gen, critic, opt_gen, opt_critic, l1_loss, vgg_loss, gen_scaler, critic_scaler, epoch):
    gen.train()
    critic.train()

    train_progress_bar = tqdm(enumerate(train_data_loader), total=len(train_data_loader), leave=True)
    for idx, (low_res_imgs, high_res_imgs) in train_progress_bar:
        low_res_imgs = low_res_imgs.to(DEVICE)
        high_res_imgs = high_res_imgs.to(DEVICE)

        with torch.cuda.amp.autocast():
            fake_imgs = gen(low_res_imgs)
            outputs_real_imgs = critic(high_res_imgs)
            outputs_fake_imgs = critic(fake_imgs.detach())

            gp = gradientPenalty(critic, high_res_imgs, fake_imgs.detach())
            loss_critic = -(torch.mean(outputs_real_imgs) - torch.mean(outputs_fake_imgs)) + LAMBDA_GP*gp
        
        critic.zero_grad()
        critic_scaler.scale(loss_critic).backward()
        critic_scaler.step(opt_critic)
        critic_scaler.update()

        with torch.cuda.amp.autocast():
            outputs_fake_imgs = critic(fake_imgs).view(-1)
            loss_gen_1 = -torch.mean(outputs_fake_imgs) * ADVERSARIAL_LAMBDA
            loss_gen_2 = l1_loss(fake_imgs, high_res_imgs) * L1_LAMBDA
            loss_gen_3 = vgg_loss(fake_imgs, high_res_imgs)
            loss_gen = loss_gen_1 + loss_gen_2 + loss_gen_3

        opt_gen.zero_grad()
        gen_scaler.scale(loss_gen).backward()
        gen_scaler.step(opt_gen)
        gen_scaler.update()

        train_progress_bar.set_description(f"Epoch [{epoch}/{EPOCHS-1}]")
        train_progress_bar.set_postfix(critic_loss=loss_critic.item(), gen_loss=loss_gen.item())
        
    train_progress_bar.close()

In [19]:
class VGGLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = vgg19(pretrained=True).features[:35].eval().to(DEVICE)

        for param in self.vgg.parameters():
            param.requires_grad = False

        self.mse = nn.MSELoss()

    def forward(self, inputs, targets):
        input_features = self.vgg(inputs)
        target_features = self.vgg(targets)
        return self.mse(input_features, target_features)

In [26]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

CHECKPOINT_GEN = "gen.pth.tar"
CHECKPOINT_CRITIC = "critic.pth.tar"
LOAD_GEN = True
LOAD_CRITIC = True

LEARNING_RATE = 1e-4
EPOCHS = 3

LAMBDA_GP = 10
ADVERSARIAL_LAMBDA = 5e-3
L1_LAMBDA = 1e-2

gen = Generator().to(DEVICE)
critic = Critic().to(DEVICE)
initialize_weights(gen)

l1_loss = nn.L1Loss()
vgg_loss = VGGLoss()

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.999))
opt_critic = optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=(0.0, 0.999))

gen_scaler = torch.cuda.amp.GradScaler()
critic_scaler = torch.cuda.amp.GradScaler()

if LOAD_GEN:
    load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
if LOAD_CRITIC:
    load_checkpoint(CHECKPOINT_CRITIC, critic, opt_critic, LEARNING_RATE)

for epoch in range(EPOCHS):
    train(train_data_loader, gen, critic, opt_gen, opt_critic, l1_loss, vgg_loss, gen_scaler, critic_scaler, epoch)

    save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
    save_checkpoint(critic, opt_critic, filename=CHECKPOINT_CRITIC)
    save_imgs(gen, test_transform, test_dir="test_imgs")

=>Loading Checkpoint...
=>Loading Checkpoint...


Epoch [0/2]: 100%|██████████| 50/50 [01:50<00:00,  2.20s/it, critic_loss=2.76, gen_loss=2.27]


=>Saving Checkpoint...
=>Saving Checkpoint...


Epoch [1/2]: 100%|██████████| 50/50 [01:49<00:00,  2.19s/it, critic_loss=-30, gen_loss=3.29]


=>Saving Checkpoint...
=>Saving Checkpoint...


Epoch [2/2]: 100%|██████████| 50/50 [01:50<00:00,  2.21s/it, critic_loss=-17.8, gen_loss=2.66]


=>Saving Checkpoint...
=>Saving Checkpoint...
