### Terminal

In [None]:
#!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
#!python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
#!pip install --quiet pytorch-lightning torchsummary

### Imports

In [None]:
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms, models
from skimage import io, transform
from torchsummary import summary
import matplotlib.pyplot as plt
import torch.nn.functional as F
import pytorch_lightning as pl
import torch.nn as nn
import numpy as np
import torch
import os, time

### Hyperparameters

In [None]:
# Side length (pixels) used as the Discriminator's default `hr_img_size`.
img_size = 128
# Mini-batch size for the DataLoader; it also scales the Adam learning rates
# (lr = 1e-4 * batch_size / 16) and rounds the dataset length cap in SRDataset.
batch_size = 16

### Generator Blocks

In [None]:
def lrelu(x, negative_slope=0.2):
    """Leaky ReLU multiplied by the matching initialization gain.

    Args:
        x: Input tensor.
        negative_slope: Slope applied to negative inputs (default 0.2).

    Returns:
        ``calculate_gain('leaky_relu', negative_slope) * leaky_relu(x)``,
        a new tensor with the same shape as ``x``.
    """
    # The functional form avoids constructing a fresh nn.LeakyReLU module on
    # every call, which the previous lambda did.
    gain = nn.init.calculate_gain('leaky_relu', negative_slope)
    return gain * F.leaky_relu(x, negative_slope)

def initialize_weights(m):
    """Initialize one module in place; intended for ``nn.Module.apply``.

    * ``nn.Conv2d``: weights ~ N(0, (0.1 / sqrt(fan_in))^2), bias zeroed.
    * ``nn.BatchNorm2d``: weight set to 1 and bias to 0.
    * ``nn.Linear``: Kaiming-uniform weights, bias zeroed.

    Modules of any other type are left untouched. Missing parameters
    (``bias=False`` / ``affine=False``) are skipped instead of raising.

    Args:
        m: The module to initialize.
    """
    if isinstance(m, nn.Conv2d):
        scale = 0.1
        # fan_in = (in_channels / groups) * kernel_h * kernel_w
        fan_in = np.prod(m.weight.shape[1:])
        he_gain = (1 / fan_in) ** 0.5
        nn.init.normal_(m.weight.data, std=float(scale * he_gain))
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.BatchNorm2d):
        # weight/bias are None when the layer was built with affine=False.
        if m.weight is not None:
            nn.init.constant_(m.weight.data, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
    elif isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
            
class ScaledConv(nn.Conv2d):
    """Named alias of ``nn.Conv2d`` used by every conv layer in this notebook.

    The constructor only forwards its arguments to ``nn.Conv2d``; no extra
    scaling or behavior is added here.
    """

    def __init__(self, *conv_args, **conv_kwargs):
        nn.Conv2d.__init__(self, *conv_args, **conv_kwargs)

class ResBlock(nn.Module):
    """Residual block: conv -> ReLU -> conv, added back onto the input.

    Args:
        nf: Number of input and output channels of both convolutions.
    """

    def __init__(self, nf=64):
        super().__init__()
        # 3x3 kernels with padding=1 keep the spatial size unchanged.
        self.conv1 = ScaledConv(nf, nf, 3, padding=1)
        self.conv2 = ScaledConv(nf, nf, 3, padding=1)

    def forward(self, inputs):
        """Return ``inputs + conv2(relu(conv1(inputs)))``."""
        residual = self.conv2(F.relu(self.conv1(inputs)))
        return inputs + residual
    
class DenseBlock(nn.Module):
    """Densely connected stack of five 3x3 convolutions.

    Each of the first four convolutions sees the channel-wise concatenation of
    the block input and every earlier activation, and emits ``gc`` channels.
    The final convolution maps the full concatenation back to ``nf`` channels.

    Args:
        nf: Channels of the block input and output.
        gc: Growth channels produced by each intermediate convolution.
    """

    def __init__(self, nf=64, gc=32):
        super().__init__()
        # The i-th conv receives nf input channels plus gc per earlier conv.
        in_channels = [nf + gc * i for i in range(4)]
        self.convs = nn.ModuleList(
            [ScaledConv(channels, gc, 3, padding=1) for channels in in_channels]
        )
        self.final_conv = ScaledConv(nf + gc * 4, nf, 3, padding=1)

    def forward(self, inputs):
        """Run the dense stack; no activation follows the final convolution."""
        features = [inputs]
        for conv in self.convs:
            features.append(lrelu(conv(torch.cat(features, dim=1))))
        return self.final_conv(torch.cat(features, dim=1))

class RRDB(nn.Module):
    """Three DenseBlocks chained with scaled residual additions.

    Each dense block's output is scaled by ``beta`` and added to the running
    activation; the result is scaled by ``beta`` again and added to the
    original block input.

    Args:
        nf: Channels of the block input and output.
        gc: Growth channels passed to each DenseBlock.
        beta: Residual scaling factor.
    """

    def __init__(self, nf=64, gc=32, beta=0.2):
        super().__init__()
        self.dbs = nn.ModuleList([DenseBlock(nf, gc) for _ in range(3)])
        self.beta = beta

    def forward(self, inputs):
        """Return ``inputs + beta * x``, where ``x`` accumulates the dense blocks."""
        x = inputs
        for db in self.dbs:
            # Out-of-place add: `x += ...` would write into `inputs` (same
            # tensor object), corrupting the outer skip connection and
            # invalidating tensors autograd saved for the backward pass.
            x = x + self.beta * db(x)
        return inputs + self.beta * x

### Generator Arch

In [None]:
class Generator(nn.Module):
    """Super-resolution generator: conv stem, residual trunk, two upsampling stages.

    Args:
        nblocks: Number of trunk blocks.
        block_type: Zero-argument callable that builds one trunk block.
    """

    def __init__(self, nblocks=23, block_type=RRDB):
        super().__init__()
        self.first_conv = ScaledConv(3, 64, 9, padding=4)
        trunk = [block_type() for _ in range(nblocks)]
        self.blocks = nn.Sequential(*trunk)
        self.end_block_conv = ScaledConv(64, 64, 3, padding=1)
        # Each up-conv emits 256 channels, which PixelShuffle(2) rearranges
        # into 64 channels at twice the spatial resolution.
        upsamplers = [ScaledConv(64, 256, 9, padding=4) for _ in range(2)]
        self.up_convs = nn.Sequential(*upsamplers)
        self.final_conv = ScaledConv(64, 3, 9, padding=4)
        self.pix_shuf = nn.PixelShuffle(2)

    def forward(self, inputs):
        """Map a 3-channel image batch to a 3-channel batch upsampled by two PixelShuffle(2) stages."""
        stem = lrelu(self.first_conv(inputs))

        # Long skip connection around the whole trunk.
        x = self.end_block_conv(self.blocks(stem)) + stem

        for up_conv in self.up_convs:
            x = lrelu(self.pix_shuf(up_conv(x)))

        return self.final_conv(x)

### Discriminator Arch

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing one raw score per image.

    A 9x9 stem convolution is followed by seven 3x3 conv + BatchNorm + leaky
    ReLU stages (four of them with stride 2), then two fully connected layers.

    Args:
        hr_img_size: Side length of the square input images. It fixes the
            input width of the first fully connected layer, so inputs must
            match it.
    """

    def __init__(self, hr_img_size=img_size):
        super().__init__()
        self.conv_filters = [64, 64, 128, 128, 256, 256, 512, 512]
        self.strides = [1, 2] * 4

        self.first_conv = ScaledConv(3, 64, 9, padding=4)
        # Stage idx maps conv_filters[idx-1] -> conv_filters[idx] channels
        # using strides[idx]; index 0 is covered by first_conv above.
        self.convs = nn.Sequential(*[ScaledConv(
            self.conv_filters[idx-1], self.conv_filters[idx],
            3, stride=self.strides[idx], padding=1)
                      for idx in range(1, len(self.conv_filters))])
        self.bns = nn.ModuleList([nn.BatchNorm2d(nf) for nf in self.conv_filters[1:]])

        # Four stride-2 stages shrink each spatial side by 2 ** 4.
        self.fc1 = nn.Linear((hr_img_size // 2 ** 4) ** 2 * self.conv_filters[-1], 1024)
        self.fc2 = nn.Linear(1024, 1)

    def forward(self, inputs):
        """Return raw (pre-sigmoid) scores with a trailing dimension of size 1."""
        x = lrelu(self.first_conv(inputs))

        for conv, bn in zip(self.convs, self.bns):
            x = lrelu(bn(conv(x)))

        x = torch.flatten(x, start_dim=1)  # only flatten CHW in NCHW
        x = lrelu(self.fc1(x))
        return self.fc2(x)

### Pretrain GAN

In [None]:
class PretrainGenerator(pl.LightningModule):
    """Lightning module that trains the Generator alone with an L1 loss."""

    def __init__(self):
        super().__init__()
        self.gen = Generator()
        self.gen.apply(initialize_weights)
        self.l1 = nn.L1Loss()
        # Manual epoch counter and timestamp used only for the progress print.
        self.epoch_num = 0
        self.time = time.time()

    def forward(self, inputs):
        """Super-resolve ``inputs`` with the wrapped generator."""
        return self.gen(inputs)

    def training_step(self, batch, batch_idx):
        """Return the L1 loss between generated and ground-truth images.

        Args:
            batch: Dict with ``'lr'`` and ``'hr'`` image tensors.
            batch_idx: Index of the batch (unused).
        """
        lr, hr = batch['lr'], batch['hr']
        sr = self.gen(lr)
        return self.l1(sr, hr)

    def configure_optimizers(self):
        """Adam with lr scaled by batch size, halved every 2e5 optimizer steps."""
        optimizer = torch.optim.Adam(self.gen.parameters(), lr=1e-4 * batch_size / 16, betas=(0.9, 0.999))
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, int(2e5), gamma=0.5)
        # The 2e5 step size counts iterations. Lightning steps schedulers once
        # per epoch unless told otherwise, which would never reach the decay.
        return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]

    def training_epoch_end(self, outputs) -> None:
        """Print the wall-clock time of the epoch that just finished."""
        end_time = time.time()
        # NOTE(review): for the first epoch this also includes everything that
        # happened between constructing the module and starting training.
        duration = round(end_time - self.time)
        print('Epoch {} ended | Time to complete: {}m {}s       \r'.format(self.epoch_num, duration // 60, duration % 60), end='')
        self.epoch_num += 1
        self.time = end_time

### Adversarial GAN

In [None]:
class GAN(pl.LightningModule):
    """Adversarial fine-tuning of a pretrained generator against a Discriminator.

    The generator loss combines a VGG19 feature loss, a relativistic
    adversarial term and an L1 pixel loss. The discriminator uses the matching
    relativistic loss.

    Args:
        pretrained_generator: Generator module to fine-tune.
    """

    def __init__(self, pretrained_generator):
        super().__init__()
        self.gen = pretrained_generator
        self.disc = Discriminator()
        self.disc.apply(initialize_weights)
        self.vgg = models.vgg19(pretrained=True).features[:35] # 4th conv before 5th max pool
        # VGG is a fixed feature extractor: it is not passed to any optimizer,
        # so skip computing gradients for its weights.
        for param in self.vgg.parameters():
            param.requires_grad_(False)
        self.sig = nn.Sigmoid()
        self.mse = nn.MSELoss()
        self.l1 = nn.L1Loss()

    def forward(self, inputs): # idk not really needed
        """Identity; present only to satisfy the module interface."""
        return inputs

    def d_ra(self, observed_preds, base_preds):
        """Relativistic score: sigmoid(observed - mean(base))."""
        return self.sig(observed_preds - torch.mean(base_preds))

    def disc_loss(self, true_preds, fake_preds):
        """Discriminator loss: real should outscore fake on average, and vice versa.

        Args:
            true_preds: Discriminator outputs for ground-truth images.
            fake_preds: Discriminator outputs for generated images.
        """
        loss_real = -torch.mean(torch.log(self.d_ra(true_preds, fake_preds)))
        loss_fake = -torch.mean(torch.log(1 - self.d_ra(fake_preds, true_preds)))
        return loss_real + loss_fake

    def gen_loss(self, sr, hr, true_preds, fake_preds, lb=5e-3, eta=1e-2):
        """Generator loss: ``vgg_loss + lb * adversarial_loss + eta * l1_loss``.

        Args:
            sr: Generated images.
            hr: Ground-truth images.
            true_preds: Discriminator outputs for ``hr``.
            fake_preds: Discriminator outputs for ``sr``.
            lb: Weight of the adversarial term.
            eta: Weight of the L1 pixel term.
        """
        vgg_hr = self.vgg(hr)
        vgg_sr = self.vgg(sr)
        vgg_loss = self.mse(vgg_sr, vgg_hr)

        # Mirror image of disc_loss: the roles of real and fake are swapped.
        loss_real = -torch.mean(torch.log(1 - self.d_ra(true_preds, fake_preds)))
        loss_fake = -torch.mean(torch.log(self.d_ra(fake_preds, true_preds)))
        adv_loss = loss_real + loss_fake

        l1_loss = self.l1(sr, hr)

        return vgg_loss + lb * adv_loss + eta * l1_loss

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer selected by ``optimizer_idx``.

        Args:
            batch: Dict with ``'lr'`` and ``'hr'`` image tensors, as produced
                by SRDataset.
            batch_idx: Index of the batch (unused).
            optimizer_idx: 0 for the generator step, 1 for the discriminator step.

        Returns:
            Dict holding the loss under ``'loss'`` plus logging entries.
        """
        lr, hr = batch['lr'], batch['hr']
        sr = self.gen(lr)

        if optimizer_idx == 0:
            true_preds = self.disc(hr)
            fake_preds = self.disc(sr)
            gloss = self.gen_loss(sr, hr, true_preds, fake_preds)
            tqdm_dict = {'g_loss': gloss}
            output = {'loss': gloss, 'progress_bar': tqdm_dict, 'log': tqdm_dict}
            return output

        if optimizer_idx == 1:
            true_preds = self.disc(hr)
            # Detach so the discriminator step never back-propagates into the generator.
            fake_preds = self.disc(sr.detach())
            dloss = self.disc_loss(true_preds, fake_preds)
            tqdm_dict = {'d_loss': dloss}
            output = {'loss': dloss, 'progress_bar': tqdm_dict, 'log': tqdm_dict}
            return output

    def configure_optimizers(self):
        """Return Adam + MultiStepLR pairs: index 0 generator, index 1 discriminator."""
        gen_opt = torch.optim.Adam(self.gen.parameters(), lr=1e-4 * batch_size / 16, betas=(0.9, 0.999))
        gen_scheduler = torch.optim.lr_scheduler.MultiStepLR(gen_opt, [5e4, 1e5, 2e5, 3e5], gamma=0.5)
        disc_opt = torch.optim.Adam(self.disc.parameters(), lr=1e-4 * batch_size / 16, betas=(0.9, 0.999))
        disc_scheduler = torch.optim.lr_scheduler.MultiStepLR(disc_opt, [5e4, 1e5, 2e5, 3e5], gamma=0.5)
        return [gen_opt, disc_opt], [gen_scheduler, disc_scheduler]

### Image Dataset

In [None]:
class SRDataset(Dataset):
    """Paired low-/high-resolution images read from ``imgs32`` and ``imgs128``.

    File names are listed from ``<root_dir>/imgs32``; the same name is then
    looked up under ``<root_dir>/imgs128`` for the high-resolution image.

    Args:
        root_dir: Directory containing the ``imgs32`` and ``imgs128`` folders.
        max_size: Upper bound on the dataset length (rounded down to a
            multiple of the global ``batch_size``).
        transform: Optional callable applied to each ``{'lr', 'hr'}`` sample.
    """

    def __init__(self, root_dir='/kaggle/input/art-sr', max_size=20000, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.max_size = max_size
        # Sorted so that an index always refers to the same file.
        self.img_filenames = sorted(os.listdir(os.path.join(root_dir, 'imgs32')))

    def __len__(self):
        """Number of usable samples: file count, capped by the rounded ``max_size``."""
        capped = int(batch_size * (self.max_size // batch_size))
        return min(len(self.img_filenames), capped)

    def __getitem__(self, idx):
        """Load the image pair at ``idx`` as a dict with keys ``'lr'`` and ``'hr'``."""
        filename = self.img_filenames[idx]
        sample = {
            key: io.imread(os.path.join(self.root_dir, folder, filename))
            for key, folder in (('lr', 'imgs32'), ('hr', 'imgs128'))
        }
        return self.transform(sample) if self.transform else sample

class ToTensor(object):
    """Convert an ``{'lr', 'hr'}`` sample of HWC arrays to CHW float32 tensors.

    Pixel values are divided by 255.0.
    """

    def __call__(self, sample):
        converted = {}
        for key in ('lr', 'hr'):
            image = sample[key]
            chw = (image.transpose((2, 0, 1)) / 255.0).astype(np.float32)  # HWC -> CHW
            converted[key] = torch.from_numpy(chw)
        return converted

class SRDataModule(pl.LightningDataModule):
    """Lightning data module wrapping SRDataset with a random train/val split.

    Args:
        train_test_split: Fraction of the dataset assigned to the training split.
    """

    def __init__(self, train_test_split=0.9):
        super().__init__()
        self.dataset = SRDataset(
            transform=transforms.Compose([ToTensor()])
        )

        # Kept on the instance so setup() can redo the split; they used to be
        # locals of __init__, which made setup() raise NameError.
        self.train_len = int(train_test_split * len(self.dataset))
        self.val_len = len(self.dataset) - self.train_len
        self.dataset_train, self.dataset_val = random_split(
            self.dataset, [self.train_len, self.val_len])

    def setup(self, stage=None):
        """Re-draw the random train/val split for the 'train' stage or when no stage is given."""
        # NOTE(review): Lightning's Trainer normally passes stage='fit', which
        # does not match this condition; confirm whether 'fit' should re-split too.
        if stage == 'train' or stage is None:
            self.dataset_train, self.dataset_val = random_split(
                self.dataset, [self.train_len, self.val_len])

    def train_dataloader(self):
        """Shuffled DataLoader over the training split."""
        return DataLoader(self.dataset_train, batch_size=batch_size, shuffle=True, num_workers=4)

    # Validation loader, currently disabled:
    # def val_dataloader(self):
    #     return DataLoader(self.dataset_val, batch_size=batch_size, shuffle=True, num_workers=4)

### Train

In [None]:
# Data module: lists the image files and draws the random train/val split.
dm = SRDataModule()
# Lightning module that pretrains the generator with an L1 loss.
model = PretrainGenerator()

In [None]:
# Sanity check on one sample: print the value range of both images, then show them.
img_pair = dm.dataset[508]
lr, hr = img_pair['lr'], img_pair['hr']
for tensor in (lr, hr):
    print(torch.min(tensor), torch.max(tensor))

# CHW tensors -> HWC numpy arrays, the layout matplotlib expects.
lr = lr.permute(1, 2, 0).detach().numpy()
hr = hr.permute(1, 2, 0).detach().numpy()
for image in (lr, hr):
    plt.imshow(image)
    plt.show()

In [None]:
# Pretrain the generator: at most 10 epochs, requesting one GPU.
trainer = pl.Trainer(max_epochs=10, gpus=1)
# Uses the training loader provided by the data module `dm`.
trainer.fit(model, dm)

In [1]:
# Qualitative check of the generator on the first sample.
# Requires `dm` and `model` from the cells above; running this cell on a fresh
# kernel raises NameError.
img_pair = dm.dataset[0]
print(torch.std(model.gen.first_conv.weight))
lr = img_pair['lr'].unsqueeze(0)  # add the batch dimension: CHW -> NCHW
gen_device = next(model.gen.parameters()).device
with torch.no_grad():  # inference only, no autograd graph needed
    # The generator may still live on the GPU after training.
    sr = model.gen(lr.to(gen_device))
sr = sr.permute(0, 2, 3, 1).cpu().numpy()[0]  # NCHW -> HWC for matplotlib
plt.imshow(sr)
plt.show()

hr = img_pair['hr']
# Mean absolute error between the ground truth and the generated image.
print(nn.L1Loss()(hr, torch.from_numpy(sr).permute(2, 0, 1)))
hr = hr.permute(1, 2, 0).detach().numpy()
plt.imshow(hr)
plt.show()

NameError: name 'dm' is not defined