In [None]:
import os
import matplotlib.pyplot as plt
%matplotlib inline
from PIL import Image
import os
from numpy import asarray
import numpy as np
import tensorflow as tf

In [None]:
from model_custom import Trainer
from utils import Utils

# Directory passed to Trainer.load_checkpoint below.
weights_path = '../Examples/PreTrained/'

# Build the trainer around a Utils instance and load the checkpoint from
# weights_path.
# NOTE(review): the meaning of the second Trainer argument (128) is defined in
# model_custom and is not visible from this notebook - confirm before changing.
util = Utils()
trainer = Trainer(util, 128)
trainer.load_checkpoint(weights_path)


In [None]:
from data import DIV2K

# One DIV2K handle per subset; the 'train' handle is constructed first, then 'valid'.
div2k_train, div2k_valid = DIV2K(subset='train'), DIV2K(subset='valid')

In [None]:
# Datasets built from the two DIV2K handles, both with batch_size=1 and
# random_transform=True.
train_ds = div2k_train.dataset(
    batch_size=1,
    random_transform=True,
)
# repeat_count=1 is passed for the validation subset only.
valid_ds = div2k_valid.dataset(
    batch_size=1,
    random_transform=True,
    repeat_count=1,
)

In [None]:
def load_image(path):
    """Read the image file at ``path`` and return its pixels as a NumPy array.

    Args:
        path: Location of the image file, forwarded to ``PIL.Image.open``.

    Returns:
        ``np.ndarray`` built from the decoded image. No mode conversion is
        applied, so the channel layout is whatever the file contains.
    """
    # The context manager closes the underlying file handle as soon as the
    # pixel data has been copied into the array, even if decoding fails.
    with Image.open(path) as img:
        return np.array(img)

In [None]:
def resize_np_img(image, new_size):
    """Resize ``image`` to ``new_size`` by delegating to ``tf.image.resize``.

    No resize method is passed, so TensorFlow's default is used.
    """
    resized = tf.image.resize(images=image, size=new_size)
    return resized

In [None]:
%%time
import cv2
def resize_image_bicubic(image, new_size):
    """Resize ``image`` to ``new_size`` using OpenCV bicubic interpolation.

    Args:
        image: Source image array accepted by ``cv2.resize``.
        new_size: Target size forwarded as ``dsize``; OpenCV expects it in
            ``(width, height)`` order.

    Returns:
        The resized image produced by ``cv2.resize``.
    """
    # cv2.resize's third positional parameter is ``dst``, not the interpolation
    # flag. Passing cv2.INTER_CUBIC positionally therefore left the default
    # (bilinear) interpolation in effect; it must be given by keyword.
    return cv2.resize(image, new_size, interpolation=cv2.INTER_CUBIC)

In [None]:
%%time
import cv2
def resize_image_linear(image, new_size):
    """Resize ``image`` to ``new_size`` using OpenCV bilinear interpolation.

    Args:
        image: Source image array accepted by ``cv2.resize``.
        new_size: Target size forwarded as ``dsize``; OpenCV expects it in
            ``(width, height)`` order.

    Returns:
        The resized image produced by ``cv2.resize``.
    """
    # cv2.resize's third positional parameter is ``dst``, not the interpolation
    # flag, so the flag is passed by keyword. Bilinear is also OpenCV's default,
    # which is why the positional form happened to give the same result.
    return cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)

In [None]:
def resolve_single(model, lr):
    """Run ``resolve`` on a single image.

    A leading axis is added to ``lr`` so it forms a batch of one, the batch is
    passed through ``resolve``, and the first (only) result is returned.
    """
    single_batch = tf.expand_dims(lr, axis=0)
    resolved_batch = resolve(model, single_batch)
    return resolved_batch[0]

In [None]:
def resolve(model, lr_batch):
    """Apply ``model`` to ``lr_batch`` and return the result as a uint8 tensor.

    The input is cast to float32 before the model call. The model output is
    clipped to [0, 255], rounded, and cast to ``tf.uint8``.
    """
    model_output = model(tf.cast(lr_batch, tf.float32))
    # Clamp first, then round, so the final uint8 cast cannot wrap around.
    quantised = tf.round(tf.clip_by_value(model_output, 0, 255))
    return tf.cast(quantised, tf.uint8)

In [None]:
import numpy as np
def resolve_and_plot(lr,hr):
    """Upscale ``lr`` three ways and plot each result next to ``hr``.

    The low-resolution image is upscaled with bilinear interpolation, bicubic
    interpolation and ``trainer.generator``. PSNR and SSIM against ``hr`` are
    computed for each result (and for ``hr`` itself) and shown in the subplot
    titles of a 2x2 figure.

    Args:
        lr: Low-resolution image array. Its shape is printed.
        hr: High-resolution reference image array. Its shape is printed and
            its first two axes are used as the (height, width) upscaling
            target - assumes an (H, W, C) layout as produced by ``load_image``.

    Returns:
        None. The shapes are printed and the figure is drawn with matplotlib.
    """
    print(lr.shape)
    print(hr.shape)

    # Go through resolve_single so the generator output is clipped to [0, 255]
    # and rounded before the uint8 cast. Casting the raw output directly lets
    # out-of-range values wrap around and show up as artefacts.
    gan_sr = np.array(resolve_single(trainer.generator, lr))

    # Derive the target size from the reference instead of hard-coding it, so
    # the metrics work for any HR size. cv2.resize expects (width, height),
    # while array shapes are (height, width, ...).
    target_size = (hr.shape[1], hr.shape[0])
    lr_cubic = resize_image_bicubic(lr, target_size)
    lr_linear = resize_image_linear(lr, target_size)

    images = [lr_linear, lr_cubic, gan_sr, hr]
    titles = ['LINEAR', 'BICUBIC', 'GAN', 'HR']
    # Each candidate is scored against the reference; the HR-vs-HR entry gives
    # the upper bound of each metric for comparison.
    psnrs = [tf.image.psnr(img, hr, 255).numpy() for img in images]
    ssims = [tf.image.ssim(img, hr, 255).numpy() for img in images]

    plt.figure(figsize=(45, 30))
    # Subplot positions are 1-based.
    for pos, (img, title, psnr, ssim) in enumerate(zip(images, titles, psnrs, ssims), start=1):
        plt.subplot(2, 2, pos)
        plt.imshow(img/255)
        plt.title(f"{title}, psnr: {psnr:.2f}, ssim: {ssim:.2f}", fontsize=20)
        plt.xticks([])
        plt.yticks([])

In [None]:
def resolve_and_show(lr_image_path,hr_image_path):
    """Load an LR/HR image pair from disk and pass it to ``resolve_and_plot``.

    The LR image is loaded first, then the HR image.
    """
    resolve_and_plot(load_image(lr_image_path), load_image(hr_image_path))


In [None]:
# Run the comparison on one x4 low-resolution / high-resolution pair (image 0818).
resolve_and_show(
    "../Examples/div2k/images/DIV2K_valid_LR_bicubic/X4/0818x4.png",
    "../Examples/div2k/images/DIV2K_valid_HR/0818.png",
)