In [1]:
import numpy as np
import torch
from PIL import Image
from scipy.interpolate import interpn
from torchvision import transforms
from tqdm import tqdm
from utils.tensor_utils import *
import wandb
from torch.utils.data import random_split

from Dataset.AerialDataset import AerialDataset
from utils.metrics_utils import calculate_ssim, calculate_psnr

In [2]:
# Benchmark configuration: side lengths used for the low-/high-resolution
# images and the location of the dataset on disk.
lr_size = 64
hr_size = 256
dataset_dir = 'E:\\TFG\\air_dataset'

# NOTE(review): return_pil=True presumably makes each sample a dict of PIL
# images -- confirm against AerialDataset.
dataset = AerialDataset(dataset_dir, lr_size, hr_size, return_pil=True)

# Seeded 60/20/20 split so the split is reproducible; only the last part
# (the test split) is kept.
split_generator = torch.Generator().manual_seed(420)
_, _, test_dataset = random_split(dataset, [0.6, 0.2, 0.2], generator=split_generator)

In [5]:
def interpolation(image, objective_dim, method):
    """Resample an image onto a new pixel grid with ``scipy.interpolate.interpn``.

    The target grid spans the same coordinate range as the source grid
    (``0 .. width - 1`` and ``0 .. height - 1``), so the corner pixels of the
    output coincide with the corner pixels of the input.

    Args:
        image: Array-like of shape ``(height, width)`` or
            ``(height, width, channels)``. Anything ``np.asarray`` accepts
            (ndarray, nested list, PIL image) can be passed.
        objective_dim: Target size given as ``(new_width, new_height)``.
        method: Interpolation method forwarded to ``interpn`` (for example
            ``"nearest"`` or ``"linear"``). The image-processing names
            ``"bilinear"`` and ``"bicubic"`` are accepted as aliases for
            ``"linear"`` and ``"cubic"``.

    Returns:
        The resampled image as returned by ``interpn``, with shape
        ``(new_height, new_width)`` plus any trailing channel axis of the
        input. The result is not cast back to the input dtype.
    """
    # Coerce first so that non-ndarray inputs expose ``.shape``.
    image = np.asarray(image)

    # Source and target dimensions.
    height, width = image.shape[0], image.shape[1]
    new_width, new_height = objective_dim[0], objective_dim[1]

    # Translate common image-resize names into the ones interpn understands;
    # any other value is passed through untouched.
    aliases = {"bilinear": "linear", "bicubic": "cubic"}
    method = aliases.get(method, method)

    # Coordinate axes of the original image and of the interpolated one.
    x = np.linspace(0, width - 1, width)
    y = np.linspace(0, height - 1, height)
    new_x = np.linspace(0, width - 1, new_width)
    new_y = np.linspace(0, height - 1, new_height)

    # new_y[:, None] broadcasts against new_x to form the full
    # (new_height, new_width) grid of query points.
    return interpn(
        (y, x),
        image,
        (new_y[:, None], new_x),
        method=method,
        bounds_error=False,
        fill_value=0,
    )

In [5]:
# Benchmark PIL's resampling filters on the test split: one W&B run per
# method, logging the mean SSIM / PSNR of the upscaled LR image against HR.
#
# Every method is resolved through this table and goes through the same
# `lr.resize` path, so no method can fall through the branches and silently
# reuse an `interpolated` image left over from a previous iteration.
resampling_filters = {
    "nearest": Image.NEAREST,
    "bilinear": Image.BILINEAR,
    "bicubic": Image.BICUBIC,
}

# Logging in does not depend on the method, so it is done once up front.
wandb.login()

for method, resample in resampling_filters.items():
    n_samples = 0
    ssim = 0
    psnr = 0
    wandb.init(project="SR model benchmarking", name=method)
    try:
        for images in tqdm(test_dataset):
            lr = images["lr"]
            # Convert each image to an array once and reuse it for both metrics.
            interpolated = np.array(lr.resize((hr_size, hr_size), resample))
            hr = np.array(images["hr"])

            ssim += calculate_ssim(interpolated, hr)
            psnr += calculate_psnr(interpolated, hr)
            n_samples += 1

        if n_samples == 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError("test_dataset is empty; there are no metrics to log")

        wandb.log({"ssim": ssim / n_samples, "psnr": psnr / n_samples})
    except BaseException:
        # Close the run (marked as failed) so the next wandb.init starts clean.
        wandb.finish(exit_code=1)
        raise
    else:
        wandb.finish()

In [7]:
# Stand-alone benchmark of PIL bicubic upscaling, logged to W&B as run "cubic":
# mean SSIM / PSNR of the upscaled LR image against the HR image.
method = "cubic"
n_samples = 0
ssim = 0
psnr = 0
wandb.login()
wandb.init(project="SR model benchmarking", name=method)
try:
    for images in tqdm(test_dataset):
        lr = images["lr"]
        # Convert each image to an array once and reuse it for both metrics.
        interpolated = np.array(lr.resize((hr_size, hr_size), Image.BICUBIC))
        hr = np.array(images["hr"])

        ssim += calculate_ssim(interpolated, hr)
        psnr += calculate_psnr(interpolated, hr)
        n_samples += 1

    if n_samples == 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("test_dataset is empty; there are no metrics to log")

    wandb.log({"ssim": ssim / n_samples, "psnr": psnr / n_samples})
except BaseException:
    # Close the run (marked as failed) instead of leaving it open in the kernel.
    wandb.finish(exit_code=1)
    raise
else:
    wandb.finish()