In [None]:
from tqdm import tqdm

import fastai
from fastai.vision import *
from fastai.callbacks import *
from multiprocessing import Pool
import matplotlib.pyplot as plt
import numpy as np
from  PIL import Image
import torch
import torchvision
from torchvision.models import vgg16_bn
from skimage.metrics import structural_similarity as ssim
import os
import sys

from scipy import ndimage
from torch import nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.io import read_image, ImageReadMode
from torch.utils.data import Dataset
from torch import is_tensor, FloatTensor,tensor

sys.path.append('MZSR')
from image_resize import image_resize

from scipy.signal import convolve2d

In [None]:
from utils.metrics import *

In [None]:
import time

def measure(fun):
    """Decorator that times a method call and stores the duration on the instance.

    The decorated method is called as ``fun(self, *args, **kwargs)``. After it
    returns, the elapsed time in seconds is written to ``self.time`` and the
    method's return value is passed through to the caller.

    Args:
        fun: a method (first positional argument is the instance).

    Returns:
        The wrapping function, with ``fun``'s name and docstring preserved.
    """
    import functools

    @functools.wraps(fun)
    def wrapper(self, *args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        result = fun(self, *args, **kwargs)
        self.time = time.perf_counter() - start
        return result

    return wrapper

class AbstractModel:
    """Base class for a super-resolution model under evaluation.

    Subclasses override ``get_name`` (and optionally ``get_result``). The
    evaluation code stores the prediction in ``self.result`` and the elapsed
    time in ``self.time`` before calling ``get_metrics``.
    """

    def __init__(self):
        self.gt_image = None
        self.lr_image = None
        self.result = None
        # Initialised here so get_metrics() does not raise AttributeError
        # when no timing has been recorded yet.
        self.time = None

    def get_name(self) -> str:
        """Return the identifier used for result directories and CSV file names."""
        raise NotImplementedError()

    def get_result(self) -> np.ndarray:
        """Return the model's super-resolved image (implemented by subclasses)."""
        raise NotImplementedError()

    def get_metrics(self) -> list:
        """Return ``[PSNR, SSIM, time]`` of ``self.result`` against ``self.gt_image``."""
        result = np.array(self.result)
        gt = np.array(self.gt_image)
        return [PSNR(result, gt), SSIM(result, gt), self.time]

    def set_input(self, lr_image: "Image.Image | np.ndarray", gt_image: "Image.Image | np.ndarray"):
        """Store the LR and GT images as float32 arrays divided by 255.

        Both inputs are expected on a 0-255 scale (PIL images or arrays);
        after this call ``self.lr_image`` / ``self.gt_image`` are on 0-1.
        """
        self.lr_image = np.array(lr_image).astype(np.float32) / 255
        self.gt_image = np.array(gt_image).astype(np.float32) / 255

In [None]:
class UNetModel(AbstractModel):
    """Evaluation wrapper for the UNet model's stored results."""

    def get_name(self) -> str:
        """Identifier of the result directory ('Nowszy' is Polish for 'newer')."""
        model_name = 'UNet_Model_Nowszy'
        return model_name

In [None]:
class KPNLPModel(AbstractModel):
    """Evaluation wrapper for the KPNLP model's stored results."""

    def get_name(self) -> str:
        """Identifier used for this model's result directory and CSV file."""
        model_name = 'KPNLP_Model'
        return model_name

In [None]:
class MZSRModel(AbstractModel):
    """Evaluation wrapper for MZSR results in one of two variants.

    ``bicubic=True`` selects the 'MZSR_bicubic' result set; otherwise the
    'MZSR_kernelGan' result set is used.
    """

    def __init__(self, bicubic=False):
        # Delegate to the base initialiser instead of duplicating it, so every
        # attribute AbstractModel.__init__ defines also exists on this subclass.
        super().__init__()
        self.bicubic = bicubic

    def get_name(self) -> str:
        """Return 'MZSR_bicubic' or 'MZSR_kernelGan' depending on ``self.bicubic``."""
        suffix = 'bicubic' if self.bicubic else 'kernelGan'
        return 'MZSR_' + suffix

In [None]:
class BicubicModel(AbstractModel):
    """Evaluation wrapper for the plain bicubic-upscaling baseline."""

    def get_name(self) -> str:
        """Identifier used for this baseline's result directory and CSV file."""
        model_name = 'bicubic'
        return model_name

In [None]:
def calc_means(image, kernel_size=7, boundary='fill'):
    """Return the local mean of a 2-D array over a square window.

    Args:
        image: 2-D array.
        kernel_size: side length of the square averaging window (>= 1).
        boundary: forwarded to ``scipy.signal.convolve2d``. The default
            'fill' pads with zeros, so means within ``kernel_size // 2``
            pixels of the border are biased towards 0; 'symm' mirrors the
            edges instead and avoids that bias.

    Returns:
        Array of the same shape as ``image`` ('same' convolution mode).

    Raises:
        ValueError: if ``kernel_size`` is smaller than 1.
    """
    if kernel_size < 1:
        raise ValueError(f'kernel_size must be >= 1, got {kernel_size}')
    kernel = np.ones((kernel_size, kernel_size)) / kernel_size ** 2
    return convolve2d(image, kernel, mode='same', boundary=boundary)

In [None]:
def color_equalize(y_sr, y_lr, kernel_size=7, scale=2):
    """Shift the local colour of a super-resolved image towards a reference.

    ``y_sr`` is blurred with a bicubic down-scale by ``1/scale`` followed by an
    up-scale by ``scale``. For each of the first three channels, the difference
    between the local means (``calc_means`` with a ``kernel_size`` window) of
    ``y_lr`` and of the blurred image is added to ``y_sr``. The result is
    clipped to [0, 1].

    Args:
        y_sr: H x W x C array with C >= 3, presumably float in [0, 1]. It is
            modified IN PLACE; callers may rely on that.
        y_lr: reference array with the same H x W as ``y_sr`` and at least three
            channels (in this notebook: the bicubically upscaled LR image).
        kernel_size: averaging window forwarded to ``calc_means`` (default 7).
        scale: resize factor of the blur round trip (default 2).

    Returns:
        ``y_sr`` itself, after modification.

    Raises:
        ValueError: if the blur round trip yields an image smaller than ``y_sr``.
    """
    height, width = y_sr.shape[:2]

    blurred = image_resize(y_sr, scale=1 / scale, kernel='cubic')
    blurred = image_resize(blurred, scale=scale, kernel='cubic')
    # The down/up round trip need not return exactly H x W (e.g. for sizes not
    # divisible by `scale`); crop so the per-channel difference has matching shapes.
    if blurred.shape[0] < height or blurred.shape[1] < width:
        raise ValueError(
            f'blurred image {blurred.shape[:2]} is smaller than input {(height, width)}')
    blurred = blurred[:height, :width]

    for channel in range(3):
        mean_sr = calc_means(blurred[:, :, channel], kernel_size)
        mean_lr = calc_means(y_lr[:, :, channel], kernel_size)
        y_sr[:, :, channel] = np.clip(y_sr[:, :, channel] + (mean_lr - mean_sr), 0, 1)

    return y_sr

In [None]:
def get_tests(path):
    """Read a ';'-separated test configuration file.

    Each line has its trailing line terminator ('\\n', and a stray '\\r' from
    CRLF files) removed and is split on ';'. Blank lines are skipped. Previously
    they raised IndexError and would break the (hr, lr) unpacking of the rows.

    Args:
        path: path of the configuration file.

    Returns:
        list[list[str]]: the fields of each non-blank line, in file order.
    """
    with open(path, 'r') as config_file:
        stripped = (raw.rstrip('\r\n') for raw in config_file)
        return [line.split(';') for line in stripped if line]

In [None]:
def _equalize_one(model, file_name, lr, gt, p_result, csv_file):
    """Equalize one model's saved prediction for one image, save it, and log a CSV row."""
    # set_input divides by 255 but `lr` is already in [0, 1]; rescale so that
    # model.lr_image ends up in [0, 1] too.
    model.set_input(lr * 255, gt)
    height, width = lr.shape[:2]

    pred = np.array(Image.open(p_result / model.get_name() / file_name))
    pred = pred[:height, :width, 0:3].astype(np.float32) / 255

    start = time.perf_counter()
    pred = color_equalize(pred, lr)
    model.time = time.perf_counter() - start
    model.result = pred

    # Round (rather than truncate) when converting back to 8-bit for saving.
    out_image = Image.fromarray(np.round(pred * 255).astype(np.uint8))
    out_image.save(p_result / f'{model.get_name()}_color_equalize' / file_name)

    row = ';'.join([str(file_name)] + [f'{metric}' for metric in model.get_metrics()])
    csv_file.write(f'{row}\n')
    # Flush and fsync after every row so partial results survive a crash.
    csv_file.flush()
    os.fsync(csv_file.fileno())


def test_on_dataset(path, dataset_lr, dataset_gt, models):
    """Colour-equalize each model's stored SR outputs for one dataset and log metrics.

    For every file in ``path/datasets/<dataset_lr>`` (sorted, regular files only):
      * the LR image is bicubically upscaled x2, clipped and scaled to [0, 1];
      * each model's saved prediction ``path/results/<dataset_lr>/<name>/<file>``
        is cropped to that size, passed through ``color_equalize`` and written to
        ``path/results/<dataset_lr>/<name>_color_equalize/<file>``;
      * a ``Name;PSNR;SSIM;time`` row is appended to
        ``path/results/<dataset_lr>_<name>_color_equalize.csv``.

    Args:
        path: root test directory (str or Path).
        dataset_lr: sub-directory of ``path/datasets`` holding the LR inputs.
        dataset_gt: sub-directory of ``path/datasets`` holding ground truth with
            the same file names.
        models: AbstractModel instances whose saved results are processed.
    """
    from contextlib import ExitStack
    from pathlib import Path

    path = Path(path)
    lr_dir = path / 'datasets' / dataset_lr
    gt_dir = path / 'datasets' / dataset_gt
    p_result = path / 'results' / dataset_lr

    # os.listdir order is arbitrary and may include sub-directories.
    file_names = sorted(name for name in os.listdir(lr_dir) if (lr_dir / name).is_file())

    # ExitStack closes every CSV even if processing an image raises.
    with ExitStack() as stack:
        csv_files = [
            stack.enter_context(
                open(path / f'results/{dataset_lr}_{model.get_name()}_color_equalize.csv', 'w'))
            for model in models
        ]
        for csv_file in csv_files:
            csv_file.write('Name;PSNR;SSIM;time\n')

        print(p_result)

        for model in models:
            os.makedirs(p_result / f'{model.get_name()}_color_equalize', exist_ok=True)

        pbar = tqdm(file_names)
        for file_name in pbar:
            lr = image_resize(np.array(Image.open(lr_dir / file_name)),
                              scale=2, kernel='cubic').clip(0, 255) / 255
            gt = Image.open(gt_dir / file_name)

            for model, csv_file in zip(models, csv_files):
                pbar.set_postfix({'Model': model.get_name()})
                _equalize_one(model, file_name, lr, gt, p_result, csv_file)

In [None]:
# Models whose saved predictions get colour-equalized. Other options:
# UNetModel(), KPNLPModel(), MZSRModel(bicubic=True), MZSRModel().
models = [KPNLPModel()]
test_path = Path('test')

# Each config row is "<hr dataset>;<lr dataset>".
tests = get_tests(test_path / 'config.csv')
total = len(tests)

for number, (hr, lr) in enumerate(tests, start=1):
    print(f'{number}/{total}: {lr} -> {hr}')
    test_on_dataset(test_path, lr, hr, models)