In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load
!pip install lpips

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Kaggle boilerplate: walk the read-only input tree. The print is commented out,
# so this loop currently has no visible effect; uncomment it to list the files.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        pass
        #print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session
import os
import numpy as np
import math
import itertools
import sys
import glob
import random
import time

import torchvision
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
from torchvision.models import vgg19

from torch.utils.data import DataLoader, Dataset, random_split
from torch.autograd import Variable


import torch.nn.init as init
import torch.nn as nn
import torch.nn.functional as F
import torch 

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
from math import sqrt
from datasets import load_dataset

import lpips

'''
!pip install GPUtil

import torch
from GPUtil import showUtilization as gpu_usage
from numba import cuda

def free_gpu_cache():
    print("Initial GPU Usage")
    gpu_usage()                             

    torch.cuda.empty_cache()

    cuda.select_device(0)
    cuda.close()
    cuda.select_device(0)

    print("GPU Usage after emptying the cache")
    gpu_usage()

free_gpu_cache()  
'''

In [None]:
# Resolve the Google Cloud Storage path of the attached DIV2K dataset.
# NOTE(review): in the visible cells GCS_PATH is only displayed; training reads
# the local ../input copy via train_path.
from kaggle_datasets import KaggleDatasets
GCS_PATH = KaggleDatasets().get_gcs_path('div2k-high-resolution-images')

In [None]:
# Show the resolved GCS path as the cell output.
GCS_PATH

In [None]:
!pip install datasets

In [None]:
#dataset


# ImageNet normalization statistics used by torchvision's pre-trained models.
# They are expressed on the [0, 1] pixel scale produced by transforms.ToTensor().
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])


def denormalize(tensors):
    """Undo ImageNet normalization on a batch of 3-channel image tensors.

    Computes ``tensors * std + mean`` per channel and clamps the result to the
    [0, 1] pixel range (the scale on which ``mean``/``std`` are defined).

    Args:
        tensors: float tensor of shape (batch, 3, H, W), assumed to have been
            normalized with ``transforms.Normalize(mean, std)``.

    Returns:
        A new tensor. The input is not modified, so calling this twice on the
        same tensor gives the same result both times.
    """
    # Broadcastable (1, C, 1, 1) statistics on the input's device and dtype.
    std_t = torch.as_tensor(std, dtype=tensors.dtype, device=tensors.device).view(1, -1, 1, 1)
    mean_t = torch.as_tensor(mean, dtype=tensors.dtype, device=tensors.device).view(1, -1, 1, 1)
    # Valid output range is [0, 1]; an upper bound of 255 never clipped anything.
    return torch.clamp(tensors * std_t + mean_t, 0, 1)


class ImageDataset(Dataset):
    """Paired low-/high-resolution images loaded from a flat folder.

    Every file matching ``root + "/*.*"`` (sorted) is one sample. Each image is
    converted to RGB and resized with bicubic interpolation to ``hr_shape`` for
    the high-resolution target, and to ``hr_shape // 4`` for the low-resolution
    input. ``transforms.Normalize`` is commented out, so both tensors keep the
    [0, 1] range produced by ``ToTensor``.

    Args:
        root: directory containing the images (not searched recursively).
        hr_shape: ``(height, width)`` of the high-resolution target.
        split: currently unused.

    Items are dicts ``{"lr": tensor, "hr": tensor}``.
    """

    def __init__(self, root, hr_shape, split="train"):
        hr_height, hr_width = hr_shape
        bicubic = transforms.InterpolationMode.BICUBIC
        # Transforms for low resolution images and high resolution images.
        # Both dimensions are honoured (width previously reused hr_height).
        self.lr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height // 4, hr_width // 4), bicubic),
                transforms.ToTensor(),
                #transforms.Normalize(mean, std),
            ]
        )
        self.hr_transform = transforms.Compose(
            [
                transforms.Resize((hr_height, hr_width), bicubic),
                transforms.ToTensor(),
                #transforms.Normalize(mean, std),
            ]
        )
        print(root + "/*.*")
        self.files = sorted(glob.glob(root + "/*.*"))
        if not self.files:
            # Otherwise an empty folder only fails later, inside DataLoader.
            print("Warning: no image files found in " + root)

    def __getitem__(self, index):
        path = self.files[index % len(self.files)]
        # The context manager closes the file handle promptly. convert("RGB")
        # returns a loaded copy and guarantees 3 channels for grayscale,
        # palette or RGBA files.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        return {"lr": self.lr_transform(img), "hr": self.hr_transform(img)}

    def __len__(self):
        return len(self.files)

In [None]:
class RFAB(nn.Module):
    """Residual block with a feature-attention (channel gating) branch.

    Two 3x3 convolutions with a LeakyReLU between them produce features ``f``.
    The attention branch global-average-pools ``f`` to 1x1 and applies two 3x3
    convolutions (LeakyReLU between them) and a Sigmoid. The resulting
    per-channel gate rescales ``f``. The gated features are multiplied by
    ``res_scale`` and added back to the input.

    Note: on the 1x1 pooled map with padding=1, each 3x3 convolution only
    touches the centre tap of its kernel, so it acts like a 1x1 convolution.

    Args:
        filters: number of input and output channels.
        res_scale: weight applied to the residual branch before the skip add.
    """

    def __init__(self, filters, res_scale=0.2):
        super(RFAB, self).__init__()
        self.res_scale = res_scale

        def conv3x3():
            return nn.Conv2d(filters, filters, 3, 1, 1, bias=True)

        # Attribute names and layer order are kept, so state_dict keys
        # (model.0 / model.2, FAB.1 / FAB.3) stay checkpoint-compatible.
        self.model = nn.Sequential(conv3x3(), nn.LeakyReLU(), conv3x3())
        self.FAB = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            conv3x3(),
            nn.LeakyReLU(),
            conv3x3(),
            nn.Sigmoid(),
        )

    def forward(self, x):
        features = self.model(x)
        gate = self.FAB(features)
        return x + self.res_scale * (features * gate)

In [None]:
#model

class FeatureExtractor(nn.Module):
    """Frozen VGG-19 feature extractor for the perceptual (content) loss.

    Keeps the first 35 modules of ``vgg19(pretrained=True).features``. In
    torchvision's VGG-19 layout the last of these (index 34) is conv5_4, so the
    output is the conv5_4 activation before its ReLU.
    """

    def __init__(self):
        super(FeatureExtractor, self).__init__()
        vgg19_model = vgg19(pretrained=True)
        self.vgg19_54 = nn.Sequential(*list(vgg19_model.features.children())[:35])
        # The visible optimizers only update generator/discriminator parameters.
        # Freezing the VGG weights stops autograd from computing and accumulating
        # unused gradients in them; gradients still flow to the input image.
        for param in self.vgg19_54.parameters():
            param.requires_grad = False

    def forward(self, img):
        # NOTE(review): torchvision's pretrained VGG expects ImageNet-normalized
        # input, while ImageDataset currently yields un-normalized [0, 1] images.
        return self.vgg19_54(img)


class DenseResidualBlock(nn.Module):
    """Dense residual block from RDN (Residual Dense Network, CVPR 2018).

    Five 3x3 convolutions run in sequence. The k-th one receives the channel-wise
    concatenation of the block input and all earlier convolution outputs
    (k * filters channels). The first four are followed by LeakyReLU. The last
    output is scaled by ``res_scale`` and added to the block input.
    """

    def __init__(self, filters, res_scale=0.2):
        super(DenseResidualBlock, self).__init__()
        self.res_scale = res_scale

        def make_conv(in_features, non_linearity=True):
            conv = nn.Conv2d(in_features, filters, 3, 1, 1, bias=True)
            return nn.Sequential(conv, nn.LeakyReLU()) if non_linearity else nn.Sequential(conv)

        # Registered as attributes b1..b5 (in that order) so state_dict keys
        # match existing checkpoints; only the last conv has no activation.
        for k in range(1, 6):
            setattr(self, "b%d" % k, make_conv(k * filters, non_linearity=(k < 5)))
        self.blocks = [getattr(self, "b%d" % k) for k in range(1, 6)]

    def forward(self, x):
        collected = [x]
        for conv_block in self.blocks:
            newest = conv_block(torch.cat(collected, 1))
            collected.append(newest)
        return newest.mul(self.res_scale) + x


class ResidualInResidualDenseBlock(nn.Module):
    """Residual-in-residual block: three stacked RFAB modules plus a scaled skip.

    Despite the ESRGAN-derived name, the inner blocks are ``RFAB``
    feature-attention blocks rather than ``DenseResidualBlock`` instances.
    Output: ``RFAB(RFAB(RFAB(x))) * res_scale + x``.
    """

    def __init__(self, filters, res_scale=0.2):
        super(ResidualInResidualDenseBlock, self).__init__()
        self.res_scale = res_scale
        # Attribute name kept as ``dense_blocks`` for state_dict compatibility.
        self.dense_blocks = nn.Sequential(*[RFAB(filters) for _ in range(3)])

    def forward(self, x):
        residual = self.dense_blocks(x)
        return residual.mul(self.res_scale) + x

In [None]:
class GeneratorRRDB(nn.Module):
    """Super-resolution generator: shallow conv, residual trunk, upsampler, output head.

    Pipeline: conv1 -> ``num_res_blocks`` ResidualInResidualDenseBlocks -> conv2,
    with a global skip that adds conv1's output. A conv then expands the channels
    to ``filters * 2**num_upsample``. ``num_upsample`` stages of (2x
    nearest-neighbour upsample + conv) follow, each halving the channel count
    back towards ``filters``. A two-conv head maps back to ``channels``. The
    overall spatial scale factor is ``2 ** num_upsample``.

    Args:
        channels: number of image channels in and out.
        filters: width of the feature trunk.
        num_res_blocks: number of residual-in-residual blocks.
        num_upsample: number of 2x upsampling stages.
    """

    def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
        super(GeneratorRRDB, self).__init__()

        # Shallow feature extraction.
        self.conv1 = nn.Conv2d(channels, filters, kernel_size=3, stride=1, padding=1)
        # Residual trunk.
        self.res_blocks = nn.Sequential(*[ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)])
        # Conv after the trunk; its output is added to conv1's output in forward().
        self.conv2 = nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=1)

        # Upsampling: expand to filters * 2**num_upsample channels, then each
        # stage doubles H and W and halves the channel count.
        upsample_layers = [
            nn.Conv2d(filters, filters * (2 ** num_upsample), kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        ]
        for i in range(num_upsample):
            in_ch = filters * 2 ** (num_upsample - i)
            upsample_layers += [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(in_ch, in_ch // 2, kernel_size=3, stride=1, padding=1),
            ]
        self.upsampling = nn.Sequential(*upsample_layers)

        # Output head ("Exp 1" variant). The scale_factor=1 Upsample is a no-op,
        # but it stays so the Sequential indices (conv3.2, conv3.4) keep matching
        # checkpoints saved from this architecture.
        self.conv3 = nn.Sequential(
            nn.Upsample(scale_factor=1, mode='nearest'),
            nn.ReflectionPad2d(1),
            nn.Conv2d(filters, filters, kernel_size=3, stride=1, padding=0),
            nn.LeakyReLU(),
            nn.Conv2d(filters, channels, kernel_size=3, stride=1, padding=1),
        )

    def forward(self, x):
        shallow = self.conv1(x)
        trunk = self.conv2(self.res_blocks(shallow))
        out = torch.add(shallow, trunk)  # global residual connection
        out = self.upsampling(out)
        return self.conv3(out)

In [None]:
class Discriminator(nn.Module):
    """Patch discriminator producing a spatial map of real/fake logits.

    Four stages with 64, 128, 256 and 512 filters. Each stage is a 3x3 stride-1
    conv (+ BatchNorm, except in the first stage) + LeakyReLU(0.2), followed by
    a 3x3 stride-2 conv + BatchNorm + LeakyReLU(0.2). A final 3x3 conv outputs
    one logit channel (no sigmoid; the notebook uses BCEWithLogitsLoss).

    Args:
        input_shape: ``(channels, height, width)`` of the images to classify.

    Attributes:
        output_shape: ``(1, H', W')`` shape of the logit map for ``input_shape``.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape

        def halved(size, times=4):
            # Each 3x3 / stride-2 / padding-1 conv maps n -> ceil(n / 2).
            # int(n / 2 ** 4) under-counts whenever n is not a multiple of 16.
            for _ in range(times):
                size = (size + 1) // 2
            return size

        self.output_shape = (1, halved(in_height), halved(in_width))

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate([64, 128, 256, 512]):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        layers.append(nn.Conv2d(out_filters, 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        return self.model(img)

In [None]:
# Hyper-parameters shared by the training and evaluation cells.
opt = {
    "epoch": 0,  # first epoch index of the training loop
    "n_epochs": 3,  # training loop runs epochs [epoch, n_epochs)
    "batch_size": 1,  # training batch size (Set5/Set14 loaders use 1)
    "lr_g": 0.00002,  # Adam learning rate for the generator
    "lr_d": 0.0009,  # Adam learning rate for the discriminator
    "b1": 0.9,  # Adam beta1
    "b2": 0.999,  # Adam beta2
    "decay_epoch": 50,  # NOTE(review): not referenced in the visible cells
    "n_cpu": 2,  # DataLoader worker processes
    "hr_height": 1024,  # HR target height; LR input is a quarter of the HR size
    "hr_width": 1024,  # HR target width
    "channels": 3,  # image channels
    "sample_interval": 100,  # log and save a sample grid every N batches (after warm-up)
    "checkpoint_interval": 2000,  # save generator/discriminator weights every N batches (after warm-up)
    "residual_blocks": 23,  # residual-in-residual blocks in the generator trunk
    "warmup_batches": 400,  # initial batches trained with the pixel loss only
    "lambda_adv": 9e-3,  # weight of the adversarial term in the generator loss
    "lambda_pixel": 2e-2,  # weight of the pixel (L1) term in the generator loss
}

In [None]:
# Use the GPU when one is available; `device` is shown as the cell output.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# (height, width) of the high-resolution training targets.
hr_shape = (opt["hr_height"], opt["hr_width"])
device

In [None]:
# Build the networks on `device`. The discriminator is sized for full HR images;
# the VGG feature extractor provides the features for the content loss.
generator = GeneratorRRDB(opt["channels"], filters=64, num_res_blocks=opt["residual_blocks"]).to(device)
discriminator = Discriminator(input_shape=(opt["channels"], *hr_shape)).to(device)
feature_extractor = FeatureExtractor().to(device)

In [None]:
# print(discriminator)

In [None]:
# print(generator)

### Texture Loss

In [None]:
def gram_matrix(y):
    """Return the normalized Gram matrices of a batch of feature maps.

    ``y`` has shape (b, ch, h, w). The result has shape (b, ch, ch), where entry
    [n, i, j] is the dot product of the flattened maps y[n, i] and y[n, j],
    divided by ch * h * w.
    """
    batch, channels, height, width = y.size()
    flat = y.view(batch, channels, height * width)
    return torch.bmm(flat, flat.transpose(1, 2)) / (channels * height * width)

class VGGTexture(nn.Module):
    """Frozen pre-trained VGG-19 ``features`` that returns activations at chosen indices.

    Args:
        layers: an index, or an iterable of indices, into ``vgg19().features``
            whose outputs are collected (in network order).
        replace_pooling: if True, replace the max-pool layers (indices 4, 9, 18,
            27, 36) with average pooling.

    ``forward`` returns the list of selected activations. It stops running the
    network once all of them have been collected.
    """

    def __init__(self, layers=(0,), replace_pooling = False):
        super(VGGTexture, self).__init__()
        # The old default ``(0)`` is just the int 0, which made ``in``/``len``
        # in forward() raise TypeError; a bare int is now wrapped in a tuple.
        self.layers = self._normalize_layers(layers)
        self.instance_normalization = nn.InstanceNorm2d(128)  # NOTE(review): unused
        self.relu = nn.ReLU()  # NOTE(review): unused
        self.model = vgg19(pretrained=True).features
        # Changing Max Pooling to Average Pooling.
        # NOTE(review): padding=(1, 1) makes each pooled map one pixel larger
        # than the MaxPool2d(2, 2) it replaces — confirm this is intended.
        if replace_pooling:
            for idx in ('4', '9', '18', '27', '36'):
                self.model._modules[idx] = nn.AvgPool2d((2,2), (2,2), (1,1))
        for param in self.model.parameters():
            param.requires_grad = False

    @staticmethod
    def _normalize_layers(layers):
        """Return ``layers`` as a tuple of indices (a single int becomes a 1-tuple)."""
        if isinstance(layers, int):
            return (layers,)
        return tuple(layers)

    def forward(self, x):
        features = []
        for name, layer in enumerate(self.model):
            x = layer(x)
            if name in self.layers:
                features.append(x)
                if len(features) == len(self.layers):
                    break
        return features

In [None]:
from torch.autograd import Variable

# VGG-19 `features` indices used by the texture (Gram-matrix) loss. In
# torchvision's VGG-19 layout these are the ReLUs after conv2_2, conv3_4,
# conv4_4 and conv5_4.
vgg_layers = [8, 17, 26, 35]
vgg_texture = VGGTexture(layers=vgg_layers, replace_pooling = False).to(device)

def to_variable(x):
    """Move ``x`` to the notebook-level ``device`` and wrap it in ``Variable``.

    In modern PyTorch ``Variable(t)`` simply returns a Tensor.
    """
    return Variable(x.to(device))

def mse_texture(a, b):
    """Mean squared error between two same-shaped tensors, as a 0-d tensor.

    Reduces with ``torch.mean`` over all elements instead of ``.view(-1)``, so
    non-contiguous inputs (e.g. transposed tensors) are accepted. The square is
    already non-negative, so no ``abs`` is needed.
    """
    return torch.mean((a - b) ** 2)

def TEXTURE_LOSS(data_fake, data_real):
    """Texture loss: summed MSE between Gram matrices of VGG activations.

    Both image batches are passed through the module-level ``vgg_texture``.
    For each selected layer, the Gram matrices of the fake and real activations
    are compared with ``mse_texture``, and the per-layer losses are summed.
    """
    grams_fake = list(map(gram_matrix, vgg_texture.forward(data_fake)))
    grams_real = list(map(gram_matrix, vgg_texture.forward(data_real)))
    per_layer = [mse_texture(grams_fake[k], grams_real[k]) for k in range(len(grams_fake))]
    return sum(per_layer)

In [None]:
# Loss functions: BCE-with-logits for the relativistic adversarial term, L1 for
# the VGG content and pixel terms, and TEXTURE_LOSS (Gram matrices) for texture.
criterion_GAN = torch.nn.BCEWithLogitsLoss().to(device)
criterion_content = torch.nn.L1Loss().to(device)
criterion_pixel = torch.nn.L1Loss().to(device)
criterion_texture = TEXTURE_LOSS

In [None]:
# Separate Adam optimizers for generator and discriminator (lrs and betas from `opt`).
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt["lr_g"], betas=(opt["b1"],opt["b2"]))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt["lr_d"], betas=(opt["b1"], opt["b2"]))

In [None]:
# Kaggle input folders: Set5/Set14 benchmark images and the DIV2K HR training set.
set5_hr_path, set14_hr_path = "../input/set5-sr/images","../input/set14-sr/images"

train_path = "../input/div2k-high-resolution-images/DIV2K_train_HR/DIV2K_train_HR"
# train_path = "../input/div2k-bicubic-x4/DIV2K_train_LR_bicubic/X4"

In [None]:
# Legacy tensor type used to cast batches to float32 on the active device.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
img_dataset = ImageDataset(train_path, hr_shape=hr_shape)
set_5_dataset = ImageDataset(set5_hr_path, hr_shape=hr_shape)
set_14_dataset = ImageDataset(set14_hr_path, hr_shape=hr_shape)

# Training uses the whole DIV2K training folder. An 80/20 split (640/160 of the
# 800 files) is available but disabled:
#train, test = random_split(img_dataset, [640, 160], generator=torch.Generator().manual_seed(42))
# If re-enabled, pass `train` below and define `test_ds = DataLoader(test, batch_size=opt["batch_size"],
# shuffle=False, num_workers=opt["n_cpu"])` for the DIV2K evaluation cell.
train_ds = DataLoader(
    img_dataset, # train,
    batch_size=opt["batch_size"],
    shuffle=True,
    num_workers=opt["n_cpu"],
)

# Evaluation loaders: shuffle=False so the i-th saved result always corresponds
# to the i-th (sorted) benchmark file and runs are reproducible.
set_5_ds = DataLoader(
    set_5_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=opt["n_cpu"],
)
set_14_ds = DataLoader(
    set_14_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=opt["n_cpu"],
)

In [None]:
# Image-quality metrics for the training log. `.to(device)` instead of `.cuda()`
# keeps the notebook runnable on CPU (the rest of it already supports CPU).
# NOTE(review): PSNR/SSIM are not defined or imported in the visible cells.
PSNR_F = PSNR().to(device)
SSIM_F = SSIM().to(device)

In [None]:
# Folders for training samples, checkpoints and benchmark outputs.
for out_dir in ("./images/training", "./saved_models", "./images/testing_set5/", "./images/testing_set14/"):
    Path(out_dir).mkdir(parents=True, exist_ok=True)

In [None]:
# try:
#     generator = GeneratorRRDB(opt["channels"], filters=64, num_res_blocks=opt["residual_blocks"]).to(device)
#     generator.load_state_dict(torch.load("../input/xsrgan-generator/generator.pth"))
    
#     discriminator = Discriminator(input_shape=(opt["channels"], *hr_shape)).to(device)
#     discriminator.load_state_dict(torch.load("../input/xsrgan-generator/discriminator.pth"))
# except:
#     print("Failed to load saved models")

In [None]:
# Training: the first `warmup_batches` batches optimise the pixel loss only.
# After that, the generator is trained with texture + content + adversarial +
# pixel losses, and the discriminator with a relativistic-average GAN loss.
for epoch in range(opt["epoch"], opt["n_epochs"]):
    for i, imgs in enumerate(train_ds):

        batches_done = epoch * len(train_ds) + i

        # Configure model input (ImageDataset yields un-normalized [0, 1] tensors).
        imgs_lr = imgs["lr"].type(Tensor)
        imgs_hr = imgs["hr"].type(Tensor)

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # Generate a high resolution image from low resolution input
        gen_hr = generator(imgs_lr)
        # Measure pixel-wise loss against ground truth
        loss_pixel = criterion_pixel(gen_hr, imgs_hr)

        if batches_done < opt["warmup_batches"]:
            # Warm-up (pixel-wise loss only)
            loss_pixel.backward()
            optimizer_G.step()
            print(
                "[Epoch {}/{}] [Batch {}/{}] [G pixel: {}]".format(epoch, opt["n_epochs"], i, len(train_ds), loss_pixel.item())
            )
            continue

        # Extract validity predictions from discriminator
        pred_real = discriminator(imgs_hr).detach()
        pred_fake = discriminator(gen_hr)

        # Adversarial ground truths, shaped exactly like the discriminator output.
        valid = torch.ones_like(pred_fake)
        fake = torch.zeros_like(pred_fake)

        # Adversarial loss (relativistic average GAN)
        loss_GAN = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), valid)

        # Content loss
        gen_features = feature_extractor(gen_hr)
        real_features = feature_extractor(imgs_hr).detach()
        loss_content = criterion_content(gen_features, real_features)

        # Texture loss (Gram matrices of VGG features): fake, real
        loss_texture = criterion_texture(gen_hr, imgs_hr)

        # Loss variants (loss_pixel plays the role of TSRGAN's perceptual term):
        #   Original: loss_content + lambda_adv * loss_GAN + lambda_pixel * loss_pixel
        #   TSRGAN:   loss_texture + lambda_pixel * loss_content + lambda_adv * loss_GAN + loss_pixel
        # Used here: texture term + the original weighting.
        loss_G = loss_texture + loss_content + opt["lambda_adv"] * loss_GAN + opt["lambda_pixel"] * loss_pixel

        loss_G.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        pred_real = discriminator(imgs_hr)
        pred_fake = discriminator(gen_hr.detach())

        # Adversarial loss for real and fake images (relativistic average GAN)
        loss_real = criterion_GAN(pred_real - pred_fake.mean(0, keepdim=True), valid)
        loss_fake = criterion_GAN(pred_fake - pred_real.mean(0, keepdim=True), fake)

        # Total loss
        loss_D = (loss_real + loss_fake) / 2

        loss_D.backward()
        optimizer_D.step()

        # --------------
        #  Log Progress
        # --------------

        if batches_done % opt["sample_interval"] == 0:
            # Metrics are computed only when logged and without autograd. They use
            # the raw [0, 1] tensors, like the evaluation cells; denormalize() is not
            # applied because Normalize is disabled in ImageDataset.
            # NOTE(review): confirm the meaning of the third argument (255) for [0, 1] data.
            with torch.no_grad():
                psnr = PSNR_F(gen_hr.detach(), imgs_hr, 255)
                ssim = SSIM_F(gen_hr.detach(), imgs_hr, 255)
            print(
                "[Epoch {}/{}] [Batch {}/{}] [D loss: {}] [G loss: {}, content: {}, adv: {}, pixel: {}, texture: {}, psnr: {}, ssim: {}]"
                .format(
                    epoch,
                    opt["n_epochs"],
                    i,
                    len(train_ds),
                    loss_D.item(),
                    loss_G.item(),
                    loss_content.item(),
                    loss_GAN.item(),
                    loss_pixel.item(),
                    loss_texture.item(),
                    psnr,
                    ssim
                )
            )
            # Save the LR input (nearest-upsampled to HR size) next to the generator output.
            lr_up = nn.functional.interpolate(imgs_lr, scale_factor=4)
            img_grid = torch.cat((lr_up, gen_hr.detach()), -1)
            save_image(img_grid, "./images/training/%d.png" % batches_done, nrow=1, normalize=False)

        if batches_done % opt["checkpoint_interval"] == 0:
            # Save model checkpoints (named per epoch, so a later save in the same
            # epoch overwrites an earlier one).
            torch.save(generator.state_dict(), "./saved_models/generator_%d.pth" % epoch)
            torch.save(discriminator.state_dict(), "./saved_models/discriminator_%d.pth" %epoch)
            print("./saved_models/generator_%d.pth" % epoch)

In [None]:
# Persist the final weights and record the generator path for the evaluation cells.
opt["final_generator_pth"] = "./saved_models/generator.pth"
torch.save(generator.state_dict(), opt["final_generator_pth"])
torch.save(discriminator.state_dict(), "./saved_models/discriminator.pth")
print(opt["final_generator_pth"])

In [None]:
# Evaluation metrics on the active device. `.to(device)` instead of `.cuda()`
# also works on CPU-only machines.
# NOTE(review): PSNR/SSIM are not defined or imported in the visible cells.
PSNR_F = PSNR().to(device)
SSIM_F = SSIM().to(device)

LPIPS_F = lpips.LPIPS(net='vgg').to(device)

In [None]:
Path("./images/testing").mkdir(parents=True, exist_ok=True)
Path("./saved_models").mkdir(parents=True, exist_ok=True)

# Reload the final generator. map_location lets a GPU-saved checkpoint load on CPU.
generator = GeneratorRRDB(opt["channels"], filters=64, num_res_blocks=opt["residual_blocks"]).to(device)
generator.load_state_dict(torch.load(opt["final_generator_pth"], map_location=device))
generator.eval()

psnr_ls = []
ssim_ls = []
lpis_ls = []

# A held-out DIV2K loader `test_ds` only exists when the random_split in the
# dataset cell is enabled. Referencing it unconditionally raised NameError and
# stopped "Run All" before the Set5/Set14 evaluation.
try:
    div2k_test_loader = test_ds
except NameError:
    div2k_test_loader = None

if div2k_test_loader is None:
    print("Skipping DIV2K evaluation: no held-out `test_ds` loader is defined.")
else:
    for i, imgs in enumerate(div2k_test_loader):
        imgs_lr = imgs["lr"].type(Tensor)
        imgs_hr = imgs["hr"].type(Tensor)

        start_time = time.time()
        with torch.no_grad():
            sr_image = generator(imgs_lr)
            if device.type == "cuda":
                # CUDA kernels run asynchronously; wait so the timing is real.
                torch.cuda.synchronize()
        print("--- %s seconds ---" % (time.time() - start_time))

        with torch.no_grad():
            # NOTE(review): tensors are in [0, 1] but 255 is passed as the third
            # argument; confirm what PSNR/SSIM expect there.
            psnr = PSNR_F(sr_image, imgs_hr, 255)
            ssim = SSIM_F(sr_image, imgs_hr, 255)
            # LPIPS expects inputs in [-1, 1]; normalize=True rescales [0, 1] images.
            lp = LPIPS_F(sr_image, imgs_hr, normalize=True).flatten()

        print("[psnr: {}, ssim: {}, lpips: {}]".format(psnr, ssim, lp))
        # as_tensor(...).flatten() accepts both scalar and per-image metric outputs.
        psnr_ls.extend(torch.as_tensor(psnr).flatten().tolist())
        ssim_ls.extend(torch.as_tensor(ssim).flatten().tolist())
        lpis_ls.extend(lp.tolist())

        # Save LR (nearest-upsampled) | HR | SR side by side.
        lr_up = nn.functional.interpolate(imgs_lr, scale_factor=4)
        img_grid = torch.cat((lr_up, imgs_hr, sr_image), -1)
        save_image(img_grid, "./images/testing/%d.png" % i, nrow=1, normalize=False)

        # Rows: LR input, HR ground truth, SR output; one column per batch item.
        for k, im in enumerate(sr_image):
            for row, panel in enumerate((lr_up[k], imgs_hr[k], im)):
                plt.subplot(3, 4, k + 1 + 4 * row)
                plt.imshow(panel.cpu().permute(1, 2, 0).clamp(0, 1))
                plt.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
        plt.show()

    # Dataset-level summary, printed once after all images.
    psnr_arr = np.array(psnr_ls)
    print("PSNR on DIV2K Test: ",psnr_arr.mean())
    ssim_arr = np.array(ssim_ls)
    print("SSIM on DIV2K Test: ",ssim_arr.mean())
    lpips_arr = np.array(lpis_ls)
    print("LPIPS on DIV2K Test: ",lpips_arr.mean())

In [None]:
# Reload the final generator. map_location lets a GPU-saved checkpoint load on CPU.
generator = GeneratorRRDB(opt["channels"], filters=64, num_res_blocks=opt["residual_blocks"]).to(device)
generator.load_state_dict(torch.load(opt["final_generator_pth"], map_location=device))
generator.eval()

psnr_ls = []
ssim_ls = []
lpis_ls = []
time_ls = []

for i, imgs in enumerate(set_5_ds):
    imgs_lr = imgs["lr"].type(Tensor)
    imgs_hr = imgs["hr"].type(Tensor)

    start_time = time.time()
    with torch.no_grad():
        sr_image = generator(imgs_lr)
        if device.type == "cuda":
            # CUDA kernels run asynchronously; wait so the timing is real.
            torch.cuda.synchronize()
    time_ls.append(time.time() - start_time)

    with torch.no_grad():
        # NOTE(review): tensors are in [0, 1] but 255 is passed as the third
        # argument; confirm what PSNR/SSIM expect there.
        psnr = PSNR_F(sr_image, imgs_hr, 255)
        ssim = SSIM_F(sr_image, imgs_hr, 255)
        # LPIPS expects inputs in [-1, 1]; normalize=True rescales [0, 1] images.
        lp = LPIPS_F(sr_image, imgs_hr, normalize=True).flatten()

    print("[psnr: {}, ssim: {}, lpips: {}]".format(psnr, ssim, lp))
    # as_tensor(...).flatten() accepts both scalar and per-image metric outputs.
    psnr_ls.extend(torch.as_tensor(psnr).flatten().tolist())
    ssim_ls.extend(torch.as_tensor(ssim).flatten().tolist())
    lpis_ls.extend(lp.tolist())

    # Save LR (nearest-upsampled) | HR | SR side by side.
    lr_up = nn.functional.interpolate(imgs_lr, scale_factor=4)
    img_grid = torch.cat((lr_up, imgs_hr, sr_image), -1)
    save_image(img_grid, "./images/testing_set5/%d.png" % i, nrow=1, normalize=False)

    # One row per batch item: LR input, HR ground truth, SR output.
    plt.figure(figsize=(15,15))
    for k, im in enumerate(sr_image):
        for col, panel in enumerate((lr_up[k], imgs_hr[k], im)):
            plt.subplot(5, 3, 3 * k + col + 1)
            plt.imshow(panel.cpu().permute(1, 2, 0).clamp(0, 1))
            plt.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
    plt.show()

In [None]:
# Per-image inference times on Set5 and their total.
print(time_ls)
times = np.asarray(time_ls)
print("Total Reconstruction Time on Set5:", times.sum())

In [None]:
# Per-image SSIM on Set5, then the mean.
ssim_arr = np.asarray(ssim_ls)
print(ssim_arr)
print("SSIM on Set5 Test: ", ssim_arr.mean())

In [None]:
# Per-image PSNR on Set5, then the mean.
psnr_arr = np.asarray(psnr_ls)
print(psnr_arr)
print("PSNR on Set5 Test: ", psnr_arr.mean())

In [None]:
# Per-image LPIPS on Set5, then the mean.
lpips_arr = np.asarray(lpis_ls)
print(lpips_arr)
print("LPIPS on Set5 Test: ", lpips_arr.mean())

In [None]:
# Reload the final generator. map_location lets a GPU-saved checkpoint load on CPU.
generator = GeneratorRRDB(opt["channels"], filters=64, num_res_blocks=opt["residual_blocks"]).to(device)
generator.load_state_dict(torch.load(opt["final_generator_pth"], map_location=device))
generator.eval()

psnr_ls = []
ssim_ls = []
lpis_ls = []
time_ls = []

for i, imgs in enumerate(set_14_ds):
    imgs_lr = imgs["lr"].type(Tensor)
    imgs_hr = imgs["hr"].type(Tensor)

    start_time = time.time()
    with torch.no_grad():
        sr_image = generator(imgs_lr)
        if device.type == "cuda":
            # CUDA kernels run asynchronously; wait so the timing is real.
            torch.cuda.synchronize()
    time_ls.append(time.time() - start_time)

    with torch.no_grad():
        # NOTE(review): tensors are in [0, 1] but 255 is passed as the third
        # argument; confirm what PSNR/SSIM expect there.
        psnr = PSNR_F(sr_image, imgs_hr, 255)
        ssim = SSIM_F(sr_image, imgs_hr, 255)
        # LPIPS expects inputs in [-1, 1]; normalize=True rescales [0, 1] images.
        lp = LPIPS_F(sr_image, imgs_hr, normalize=True).flatten()

    print("[psnr: {}, ssim: {}, lpips: {}]".format(psnr, ssim, lp))
    # as_tensor(...).flatten() accepts both scalar and per-image metric outputs.
    psnr_ls.extend(torch.as_tensor(psnr).flatten().tolist())
    ssim_ls.extend(torch.as_tensor(ssim).flatten().tolist())
    lpis_ls.extend(lp.tolist())

    # Save LR (nearest-upsampled) | HR | SR side by side.
    lr_up = nn.functional.interpolate(imgs_lr, scale_factor=4)
    img_grid = torch.cat((lr_up, imgs_hr, sr_image), -1)
    save_image(img_grid, "./images/testing_set14/%d.png" % i, nrow=1, normalize=False)

    # One row per batch item: LR input, HR ground truth, SR output.
    plt.figure(figsize=(15,15))
    for k, im in enumerate(sr_image):
        for col, panel in enumerate((lr_up[k], imgs_hr[k], im)):
            plt.subplot(5, 3, 3 * k + col + 1)
            plt.imshow(panel.cpu().permute(1, 2, 0).clamp(0, 1))
            plt.tick_params(left=False, right=False, labelleft=False, labelbottom=False, bottom=False)
    plt.show()

In [None]:
# Per-image inference times on Set14 and their total.
print(time_ls)
times = np.asarray(time_ls)
print("Total Reconstruction Time on Set14:", times.sum())

In [None]:
# Per-image SSIM on Set14, then the mean.
ssim_arr = np.asarray(ssim_ls)
print(ssim_arr)
print("SSIM on Set14 Test: ", ssim_arr.mean())

In [None]:
# Per-image PSNR on Set14, then the mean.
psnr_arr = np.asarray(psnr_ls)
print(psnr_arr)
print("PSNR on Set14 Test: ", psnr_arr.mean())

In [None]:
# Per-image LPIPS on Set14, then the mean.
lpips_arr = np.asarray(lpis_ls)
print(lpips_arr)
print("LPIPS on Set14 Test: ", lpips_arr.mean())

In [None]:
!zip -r testing_set14.zip ./images/testing_set14

In [None]:
!zip -r testing_set5.zip ./images/testing_set5

In [None]:
!zip -r training.zip ./images/training

In [None]:
!zip -r models.zip ./saved_models