In [5]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))
        break

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/div2k-data/DIV2K_valid_HR/DIV2K_valid_HR/0857.png
/kaggle/input/div2k-data/DIV2K_train_HR/DIV2K_train_HR/0566.png
/kaggle/input/div2k-data/DIV2K_valid_LR_difficult/DIV2K_valid_LR_difficult/0888x4d.png
/kaggle/input/div2k-data/DIV2K_train_LR_difficult/DIV2K_train_LR_difficult/0087x4d.png


In [6]:
## generator_loss.py
import torch
from torch import nn
from torchvision.models.vgg import vgg16

class TVLoss(nn.Module):
    def __init__(self, tv_loss_weight=1):
        super(TVLoss, self).__init__()
        self.tv_loss_weight = tv_loss_weight

    def forward(self, x):
        batch_size, c_x, h_x, w_x = x.size()
        count_h = self.tensor_size(x[:, :, 1:, :])
        count_w = self.tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    @staticmethod
    def tensor_size(t):
        return t.size()[1] * t.size()[2] * t.size()[3]

class GeneratorLoss(nn.Module):
    def __init__(self):
        super(GeneratorLoss, self).__init__()
        vgg = vgg16(pretrained=True)
        loss_network = nn.Sequential(*list(vgg.features)[:31]).eval()
        for param in loss_network.parameters():
            param.requires_grad = False
        self.loss_network = loss_network
        self.mse_loss = nn.MSELoss()
        self.tv_loss = TVLoss()

    def forward(self, out_labels, out_images, target_images):
        # Adversarial Loss
        adversarial_loss = torch.mean(1 - out_labels)
        # Perception Loss
        perception_loss = self.mse_loss(self.loss_network(out_images), self.loss_network(target_images))
        # Image Loss
        image_loss = self.mse_loss(out_images, target_images)
        # TV Loss
        tv_loss = self.tv_loss(out_images)
        return image_loss + 0.001 * adversarial_loss + 0.006 * perception_loss + 2e-8 * tv_loss

In [7]:
## generator.py
import torch.nn as nn
import torch.nn.functional as F

class ganGenerator(nn.Module):
    def __init__(self):
        super(ganGenerator, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=9, stride=1, padding=4)
        self.prelu1 = nn.PReLU()
        self.GRB1 = GeneratorResidualBlock()
        self.GRB2 = GeneratorResidualBlock()
        self.GRB3 = GeneratorResidualBlock()
        self.GRB4 = GeneratorResidualBlock()
        self.GRB5 = GeneratorResidualBlock()
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.pxlshuffle1 = nn.PixelShuffle(2)
        self.prelu2 = nn.PReLU()
        self.conv4 = nn.Conv2d(in_channels=64, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.pxlshuffle2 = nn.PixelShuffle(2)
        self.prelu3 = nn.PReLU()
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.prelu1(x1)

        x2 = self.GRB1(x1)
        x2 = self.GRB2(x2)
        x2 = self.GRB3(x2)
        x2 = self.GRB4(x2)
        x2 = self.GRB5(x2)

        x2 = self.conv2(x2)
        x2 = self.bn1(x2)
        x3 = x1 + x2

        x3 = self.conv3(x3)
        x3 = self.pxlshuffle1(x3)
        x3 = self.prelu2(x3)
        x3 = self.conv4(x3)
        x3 = self.pxlshuffle2(x3)
        x4 = self.prelu3(x3)

        x5 = self.conv5(x4)

        return x5

class GeneratorResidualBlock(nn.Module):
    def __init__(self):
        super(GeneratorResidualBlock, self).__init__()
        # convolution
        self.conv1 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        # batchnorm
        self.bn1 = nn.BatchNorm2d(64)
        # prelu
        self.prelu1 = nn.PReLU()
        #convolution
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1)
        # batchnorm
        self.bn2 = nn.BatchNorm2d(64)
        # prelu
        self.prelu2 = nn.PReLU()
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.prelu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.prelu2(out)
        return out + x

In [8]:
## discriminator.py
import torch.nn as nn

class ganDiscriminator(nn.Module):
    def __init__(self):
        super(ganDiscriminator, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1)
        self.lrelu1 = nn.LeakyReLU()

        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2)
        self.bn2 = nn.BatchNorm2d(64)
        self.lrelu2 = nn.LeakyReLU()

        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.lrelu3 = nn.LeakyReLU()
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2)
        self.bn4 = nn.BatchNorm2d(128)
        self.lrelu4 = nn.LeakyReLU()

        self.conv5 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.lrelu5 = nn.LeakyReLU()
        self.conv6 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2)
        self.bn6 = nn.BatchNorm2d(256)
        self.lrelu6 = nn.LeakyReLU()

        self.conv7 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1)
        self.bn7 = nn.BatchNorm2d(512)
        self.lrelu7 = nn.LeakyReLU()
        self.conv8 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=2)
        self.bn8 = nn.BatchNorm2d(512)
        self.lrelu8 = nn.LeakyReLU()

        # self.flat = nn.Flatten()

        # self.dense9 = nn.Linear(in_features=60*124*512, out_features=1024, bias=True)
        # self.lrelu9 = nn.LeakyReLU()

        # self.dense10 = nn.Linear(in_features=1024, out_features=1, bias=True)
        # self.sigmoid10 = nn.Sigmoid()

        self.adaptive_pool_1 = nn.AdaptiveAvgPool2d(1)
        self.conv9 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1)
        self.lrelu9 = nn.LeakyReLU()
        self.conv10 = nn.Conv2d(in_channels=1024, out_channels=1, kernel_size=1) 
        self.sigmoid9 = nn.Sigmoid()

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = self.lrelu1(x1)

        x2 = self.conv2(x1)
        x2 = self.bn2(x2)
        x2 = self.lrelu2(x2)

        x3 = self.conv3(x2)
        x3 = self.bn3(x3)
        x3 = self.lrelu3(x3)

        x4 = self.conv4(x3)
        x4 = self.bn4(x4)
        x4 = self.lrelu4(x4)

        x5 = self.conv5(x4)
        x5 = self.bn5(x5)
        x5 = self.lrelu5(x5)

        x6 = self.conv6(x5)
        x6 = self.bn6(x6)
        x6 = self.lrelu6(x6)

        x7 = self.conv7(x6)
        x7 = self.bn7(x7)
        x7 = self.lrelu7(x7)

        x8 = self.conv8(x7)
        x8 = self.bn8(x8)
        x8 = self.lrelu8(x8)

        x9 = self.adaptive_pool_1(x8)
        x9 = self.conv9(x9)
        x9 = self.lrelu9(x9)

        x10 = self.conv10(x9)
        x10 = self.sigmoid9(x10.view(x10.size()[0]))

        # x9 = self.flat(x8)
        # x10 = self.dense9(x9)
        # x10 = self.lrelu9(x10)
        # x11 = self.dense10(x10)
        # x11 = self.sigmoid10(x11)

        x11 = x10

        return x11

In [9]:
## data_utils.py
from PIL import Image
from torchvision.transforms import Compose, ToTensor, ToPILImage, CenterCrop, transforms, Resize
from torch.utils.data.dataset import Dataset

def is_image_file(filename):
    return any(filename.endswith(extension) for extension in ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG'])

hr_target_size = (1020, 2040)

hr_transform = transforms.Compose([
    transforms.Lambda(lambda img: img.rotate(0) if img.size[0] > img.size[1] else img.rotate(90)),
    transforms.CenterCrop(hr_target_size),
    transforms.ToTensor()
])

lr_target_size = (1020 // 4, 2040 // 4)

lr_transform = transforms.Compose([
    transforms.Lambda(lambda img: img.rotate(0) if img.size[0] > img.size[1] else img.rotate(90)),
    transforms.CenterCrop(lr_target_size),
    transforms.ToTensor()
])


class Div2kTrainDataset(Dataset):
    def __init__(self, hr_base_dir, lr_base_dir):
        super(Div2kTrainDataset, self).__init__()

        self.hr_base_dir = hr_base_dir
        self.lr_base_dir = lr_base_dir

        self.hr_image_filenames = [f'{self.hr_base_dir}/{i:0>4}.png' for i in range(1, 801)]
        self.lr_image_filenames = [f'{self.lr_base_dir}/{i:0>4}x4d.png' for i in range(1, 801)]

        self.hr_transform = hr_transform
        self.lr_transform = lr_transform

    def __getitem__(self, index):
        hr_image = self.hr_transform(Image.open(self.hr_image_filenames[index]))
        lr_image = self.lr_transform(Image.open(self.lr_image_filenames[index]))
        return hr_image, lr_image

    def __len__(self):
        return len(self.hr_image_filenames)

class Div2kValDataset(Dataset):
    def __init__(self, hr_base_dir, lr_base_dir):
        super(Div2kValDataset, self).__init__()

        self.hr_base_dir = hr_base_dir
        self.lr_base_dir = lr_base_dir

        self.hr_image_filenames = [f'{self.hr_base_dir}/{i:0>4}.png' for i in range(801, 901)]
        self.lr_image_filenames = [f'{self.lr_base_dir}/{i:0>4}x4d.png' for i in range(801, 901)]

        self.hr_transform = hr_transform
        self.lr_transform = lr_transform

    def __getitem__(self, index):
        hr_image = self.hr_transform(Image.open(self.hr_image_filenames[index]))
        lr_image = self.lr_transform(Image.open(self.lr_image_filenames[index]))
        return hr_image, lr_image

    def __len__(self):
        return len(self.hr_image_filenames)
    
def display_transform():
    return Compose([
        ToPILImage(),
        Resize(400),
        CenterCrop(400),
        ToTensor()
    ])

In [12]:
## train.py
# aliased imports
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.utils as utils

# package imports no alias
import torch
import os
# import pytorch_ssim
import math

# Class/Fuction Imports
from tqdm import tqdm
from torch.autograd import Variable
from torch.utils.data import DataLoader

# Local Imports
# from generator import ganGenerator
# from discriminator import ganDiscriminator
# from generator_loss import GeneratorLoss
# from data_utils import Div2kValDataset, Div2kTrainDataset, display_transform

# filter warnings
import warnings
warnings.filterwarnings("ignore")

train_hr_loc = "/kaggle/input/div2k-data/DIV2K_train_HR/DIV2K_train_HR"
train_lr_loc = "/kaggle/input/div2k-data/DIV2K_train_LR_difficult/DIV2K_train_LR_difficult"
valid_hr_loc = "/kaggle/input/div2k-data/DIV2K_valid_HR/DIV2K_valid_HR"
valid_lr_loc = "/kaggle/input/div2k-data/DIV2K_valid_LR_difficult/DIV2K_valid_LR_difficult"

if __name__ == '__main__':

    NUM_EPOCHS = 40

    train_set = Div2kTrainDataset(train_hr_loc, train_lr_loc)
    val_set = Div2kValDataset(valid_hr_loc, valid_lr_loc)
    train_loader = DataLoader(dataset=train_set, batch_size=1, shuffle=True)
    val_loader = DataLoader(dataset=val_set, batch_size=1, shuffle=False)


    netG = ganGenerator()
    netG.load_state_dict(torch.load("/kaggle/working/epochs/netG_epoch_4_26.pth"), strict=False)
    print('# generator parameters:', sum(param.numel() for param in netG.parameters()))
    netD = ganDiscriminator()
    netD.load_state_dict(torch.load("/kaggle/working/epochs/netD_epoch_4_26.pth"), strict=False)
    print('# discriminator parameters:', sum(param.numel() for param in netD.parameters()))

    generator_criterion = GeneratorLoss()

    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_gen = torch.device("cuda:0")
    device_disc = torch.device("cuda:0")
    # if torch.cuda.is_available():
    #     netG.to(device)
    #     print("netG sent to cuda")
    #     netD.to(device)
    #     print("netD sent to cuda")
    #     generator_criterion.to(device)
    #     print("generator criterion sent to cuda")

    if torch.cuda.is_available():
        netG.to(device_gen)
        print("netG sent to cuda")
        netD.to(device_disc)
        print("netD sent to cuda")
        generator_criterion.to(device_gen)
        print("generator criterion sent to cuda")

    optimizerG = optim.Adam(netG.parameters())
    optimizerD = optim.Adam(netD.parameters())

    results = {'d_loss': [], 'g_loss': [], 'd_score': [], 'g_score': [], 'psnr': [], 'ssim': []}

    for epoch in range(27, NUM_EPOCHS + 1):
        train_bar = tqdm(train_loader)
        running_results = {'batch_sizes': 0, 'd_loss': 0, 'g_loss': 0, 'd_score': 0, 'g_score': 0}

        netG.train()
        netD.train()

        for target, source in train_bar:
            g_update_first = True
            batch_size = target.size(0)
            running_results['batch_sizes'] += batch_size

            ############################
            # (1) Update D network: maximize D(x)-1-D(G(z))
            ###########################
            real_img = Variable(target)
            if torch.cuda.is_available():
                real_img = real_img.to(device_disc)
            z = Variable(source)
            if torch.cuda.is_available():
                z = z.to(device_gen)
            fake_img = netG(z).to(device_disc)

            netD.zero_grad()
            real_out = netD(real_img).mean()
            fake_out = netD(fake_img).mean()
            d_loss = 1 - real_out + fake_out
            d_loss.backward()
            optimizerD.step()

            ############################
            # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
            ###########################
            netG.zero_grad()
            ## The two lines below are added to prevent runetime error in Google Colab ##
#             fake_img = netG(z).detach().to(device_disc)
            fake_img = netG(z).to(device_disc)
            fake_out = netD(fake_img).mean().to(device_gen)
            real_img = real_img.to(device_gen)
            fake_img = fake_img.to(device_gen)
            ##
            g_loss = generator_criterion(fake_out, fake_img, real_img)
            g_loss.backward()

#             fake_img = netG(z).detach().to(device_disc)
            fake_img = netG(z).to(device_disc)
            fake_out = netD(fake_img).mean()

            optimizerG.step()

            # loss for current batch before optimization
            running_results['g_loss'] += g_loss.item() * batch_size
            running_results['d_loss'] += d_loss.item() * batch_size
            running_results['d_score'] += real_out.item() * batch_size
            running_results['g_score'] += fake_out.item() * batch_size

            train_bar.set_description(desc='[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f' % (
                epoch, NUM_EPOCHS, running_results['d_loss'] / running_results['batch_sizes'],
                running_results['g_loss'] / running_results['batch_sizes'],
                running_results['d_score'] / running_results['batch_sizes'],
                running_results['g_score'] / running_results['batch_sizes']))

        netG.eval()
        out_path = '/kaggle/working/training_results/'
        if not os.path.exists(out_path):
            os.makedirs(out_path)

        with torch.no_grad():
            val_bar = tqdm(val_loader)
            validation_results = {'mse': 0, 'ssims': 0, 'psnr': 0, 'ssim': 0, 'batch_sizes': 0}
            val_images = []
            for val_hr, val_lr in val_bar:
                batch_size = val_lr.size(0)
                validation_results['batch_sizes'] += batch_size
                lr = val_lr
                hr = val_hr
                if torch.cuda.is_available():
                    lr = lr.to(device_gen)
                    hr = hr.to(device_gen)
                sr = netG(lr)

                batch_mse = ((sr - hr) ** 2).data.mean()
                validation_results['mse'] += batch_mse * batch_size
                # batch_ssim = pytorch_ssim.ssim(sr, hr).item()
                # validation_results['ssims'] += batch_ssim * batch_size
                validation_results['psnr'] = 10 * math.log10((hr.max()**2) / (validation_results['mse'] / validation_results['batch_sizes']))
                # validation_results['ssim'] = validation_results['ssims'] / validation_results['batch_sizes']
                val_bar.set_description(
                    desc='[converting LR images to SR images] PSNR: %.4f dB SSIM: %.4f' % (
                        validation_results['psnr'], 0))

                val_images.extend(
                    [display_transform()(lr.squeeze(0)), display_transform()(hr.data.cpu().squeeze(0)),
                        display_transform()(sr.data.cpu().squeeze(0))])
            val_images = torch.stack(val_images)
            val_images = torch.chunk(val_images, val_images.size(0) // 15)
            val_save_bar = tqdm(val_images, desc='[saving training results]')
            index = 1
            for image in val_save_bar:
                image = utils.make_grid(image, nrow=3, padding=5)
                utils.save_image(image, out_path + 'epoch_%d_index_%d.png' % (epoch, index), padding=5)
                index += 1

        if not os.path.exists('/kaggle/working/epochs/'):
            os.makedirs('/kaggle/working/epochs/')
        # save model parameters
        torch.save(netG.state_dict(), '/kaggle/working/epochs/netG_epoch_%d_%d.pth' % (4, epoch))
        torch.save(netD.state_dict(), '/kaggle/working/epochs/netD_epoch_%d_%d.pth' % (4, epoch))
        # save loss\scores\psnr\ssim
        results['d_loss'].append(running_results['d_loss'] / running_results['batch_sizes'])
        results['g_loss'].append(running_results['g_loss'] / running_results['batch_sizes'])
        results['d_score'].append(running_results['d_score'] / running_results['batch_sizes'])
        results['g_score'].append(running_results['g_score'] / running_results['batch_sizes'])
        results['psnr'].append(validation_results['psnr'])
        # results['ssim'].append(validation_results['ssim'])

#         if epoch % 10 == 0 and epoch != 0:
#             out_path = '/kaggle/working/statistics/'
#             if not os.path.exists(out_path):
#                 os.makedirs(out_path)
#             data_frame = pd.DataFrame(
#                 data={'Loss_D': results['d_loss'], 'Loss_G': results['g_loss'], 'Score_D': results['d_score'],
#                         'Score_G': results['g_score'], 'PSNR': results['psnr'], 'SSIM': 0},
#                 index=range(1, epoch + 1))
#             data_frame.to_csv(out_path + 'srf_' + str(4) + '_train_results.csv', index_label='Epoch')

# generator parameters: 734224
# discriminator parameters: 5215425


Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:03<00:00, 172MB/s] 


netG sent to cuda
netD sent to cuda
generator criterion sent to cuda


[27/40] Loss_D: 1.0000 Loss_G: 0.0248 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 800/800 [40:32<00:00,  3.04s/it]
[converting LR images to SR images] PSNR: 16.8992 dB SSIM: 0.0000: 100%|██████████| 100/100 [00:54<00:00,  1.84it/s]
[saving training results]: 100%|██████████| 20/20 [00:25<00:00,  1.30s/it]
[28/40] Loss_D: 1.0000 Loss_G: 0.0195 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 800/800 [39:30<00:00,  2.96s/it]
[converting LR images to SR images] PSNR: 17.6777 dB SSIM: 0.0000: 100%|██████████| 100/100 [00:47<00:00,  2.11it/s]
[saving training results]: 100%|██████████| 20/20 [00:25<00:00,  1.29s/it]
[29/40] Loss_D: 1.0000 Loss_G: 0.0182 D(x): 1.0000 D(G(z)): 1.0000: 100%|██████████| 800/800 [39:29<00:00,  2.96s/it]
[converting LR images to SR images] PSNR: 17.5770 dB SSIM: 0.0000: 100%|██████████| 100/100 [00:47<00:00,  2.12it/s]
[saving training results]: 100%|██████████| 20/20 [00:25<00:00,  1.28s/it]
[30/40] Loss_D: 1.0000 Loss_G: 0.0187 D(x): 1.0000 D(G(z)): 1.0000: 10

KeyboardInterrupt: 