In [1]:
import os
from pathlib import Path

from tqdm import tqdm
import numpy as np
import pandas as pd

from PIL import Image, ImageDraw

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchvision.transforms as T

from models import srgan, esrgan, mesrgan
from data import *
from utils import denormalize, psnr, ssim_rgb

# Hyperparameters

In [2]:
DATA_FOLDER = 'data/image_SRF_4/'
OUTPUT_FOLDER = 'BSD100/'

# Data

In [3]:
data_mean  = [0.4439, 0.4517, 0.4054]
data_std   = [0.2738, 0.2607, 0.2856]

class Interpolate:
    def __call__(self, img):
        return F.interpolate(
            img.unsqueeze(0), 
            size=tuple(int(i/4) for i in img.shape[1:]), 
            mode='bicubic', 
            align_corners=False
        )
    
to_lr = T.Compose([
    T.Normalize(mean=data_mean, std=data_std),
    Interpolate()
])
dataset   = ImageDataset(DATA_FOLDER)

# Pandas

In [4]:
arrays = [
    np.array(["SRGAN", "ESRGAN", "MESRGAN_T2"] * 2),

    np.array(["PSNR", "SSIM", ] * 3),
]

df = pd.DataFrame(np.zeros((90, 6)), columns=arrays)
df.head(5)

Unnamed: 0_level_0,SRGAN,ESRGAN,MESRGAN_T2,SRGAN,ESRGAN,MESRGAN_T2
Unnamed: 0_level_1,PSNR,SSIM,PSNR,SSIM,PSNR,SSIM
0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0


# Benchmark

In [5]:
generator_names = ['SRGAN', 'ESRGAN', 'MESRGAN_T2']
generator_models = [srgan, esrgan, mesrgan]
generator_trained_paths = [name + '_ALL_DATA_stage2_generator.trch' for name in generator_names]

In [6]:
if not os.path.isdir(OUTPUT_FOLDER):
    os.mkdir(OUTPUT_FOLDER)
        
package = zip(generator_models, generator_names, generator_trained_paths)
for model, name, model_path in tqdm(package):
    generator = model.Generator()
    generator.load_state_dict(torch.load('trained_models/' + model_path))
    generator.to("cuda:1")
    
    for i, hr_image in enumerate(dataset):
        if not os.path.isdir(OUTPUT_FOLDER + 'hr_images/'):
            os.mkdir(OUTPUT_FOLDER + 'hr_images/')
        
        if not os.path.isdir(OUTPUT_FOLDER + name):
            os.mkdir(OUTPUT_FOLDER + name)
            
        hr_image.save(OUTPUT_FOLDER + 'hr_images/' + f'{i}.png')
        hr_image_tensor = T.ToTensor()(hr_image)
        
        lr_image_tensor = to_lr(hr_image_tensor).to("cuda:1")
        
        with torch.no_grad():
            output = generator(lr_image_tensor)
            sr_image_tensor = denormalize(output.cpu()).squeeze(0)
    
        psnr_item = psnr(hr_image_tensor, sr_image_tensor).item()
        ssim_item = ssim_rgb(hr_image_tensor, sr_image_tensor).item()

        df.loc[i, (name,'PSNR')] = psnr_item
        df.loc[i, (name,'SSIM')] = ssim_item

        sr_image_pil = T.ToPILImage()(sr_image_tensor).convert("RGB")
        sr_image_pil.save(OUTPUT_FOLDER + name + f'/{i}.png')

3it [07:08, 142.90s/it]


In [7]:
# We reformated this table in the report
print(df.mean().round(2).to_latex())

\begin{tabular}{llr}
\toprule
           &      &      0 \\
\midrule
SRGAN & PSNR &  28.47 \\
ESRGAN & SSIM &   0.96 \\
MESRGAN\_T2 & PSNR &  27.36 \\
SRGAN & SSIM &   0.97 \\
ESRGAN & PSNR &  27.72 \\
MESRGAN\_T2 & SSIM &   0.95 \\
\bottomrule
\end{tabular}

