!pip install -r requirements.txt

!pip uninstall -y PyWavelets
!pip install PyWavelets

In [1]:
import os
import time
import random
import numpy as np
import scipy, multiprocessing
import tensorflow as tf
import tensorlayer as tl
from model import get_G, get_D
from config import config
from PIL import Image

import math
from random import randrange

import pandas as pd

import matplotlib.pyplot as plt

from skimage import img_as_float
from skimage.measure import compare_ssim as ssim, compare_psnr as psnr

import json

def evaluate(checkpoint_dir, model, valid_lr_img, valid_hr_img, image_name, G = None, save_dir = "validation-samples"):

    os.makedirs(save_dir, exist_ok=True)
    
    valid_lr_img = (valid_lr_img / 127.5) - 1  # rescale to ［－1, 1]

    if not G:
        G = get_G([1, None, None, 3])
    G.load_weights(os.path.join(checkpoint_dir, model))
    G.eval()

    valid_lr_img = np.asarray(valid_lr_img, dtype=np.float32)
    valid_lr_img = valid_lr_img[np.newaxis,:,:,:]
    size = [valid_lr_img.shape[1], valid_lr_img.shape[2]]

    out = G(valid_lr_img).numpy()
    
    model_num = model.replace(".h5","").split("-")[1]

    print("LR size: %s /  generated HR size: %s" % (size, out.shape))  # LR size: (339, 510, 3) /  gen HR size: (1, 1356, 2040, 3)
    
    if not os.path.isfile('sr-' + model_num + "-" + image_name):
        tl.vis.save_image(out[0], os.path.join(save_dir, 'sr-' + model_num + "-" + image_name))

        out_bicu = scipy.misc.imresize(valid_lr_img[0], [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
        tl.vis.save_image(out_bicu, os.path.join(save_dir, 'bic-' + model_num + "-" + image_name))

    sr_smaller = tf.image.resize(out[0], size=size)
    hr_smaller = tf.image.resize(valid_hr_img, size=size)

    validate = {
        "sr" : out[0],
        "sr_resized" : sr_smaller.numpy(),
        
        "lr" : valid_lr_img[0],
        "bic" : out_bicu,
        
        "hr" : valid_hr_img, 
        "hr_resized" : hr_smaller.numpy(),
    }
    
    data = {
        "G" : G,
        
        "model" : model,

        "psnr_lr" : psnr( validate.get("lr"),  validate.get("sr_resized")),
        "ssim_lr" : ssim(validate.get("lr"),  validate.get("sr_resized"), multichannel=True),

        "psnr_hr_4" : psnr( validate.get("hr_resized"),  validate.get("sr_resized"), data_range = 255),
        "ssim_hr_4" : ssim(validate.get("hr_resized"),  validate.get("sr_resized"), multichannel=True),
        
        "psnr_hr" : psnr( validate.get("hr"),  validate.get("sr")),
        "ssim_hr" : ssim(validate.get("hr"),  validate.get("sr"), multichannel=True),

        "psnr_bic_hr" : psnr( validate.get("hr"),  validate.get("bic")),
        "ssim_bic_hr" : ssim( validate.get("hr"),  validate.get("bic"), multichannel=True),
    }
    return data
       

In [2]:
def evaluate_downsample(checkpoint_dir, model, valid_hr_img, image_name, G = None, save_dir = "validation-ds-samples"):

    os.makedirs(save_dir, exist_ok=True)
    
    size = [int(valid_hr_img.shape[0]/4), int(valid_hr_img.shape[1]/4)]
    
    hr_smaller = tf.image.resize(valid_hr_img, size=size)
    
    valid_lr_img = (hr_smaller / 127.5) - 1  # rescale to ［－1, 1]

    if not G:
        G = get_G([1, None, None, 3])
    G.load_weights(os.path.join(checkpoint_dir, model))
    G.eval()

    valid_lr_img = np.asarray(valid_lr_img, dtype=np.float32)
    valid_lr_img = valid_lr_img[np.newaxis,:,:,:]
    

    out = G(valid_lr_img).numpy()
    
    model_num = model.replace(".h5","").split("-")[1]

    print("LR size: %s /  generated HR size: %s" % (size, out.shape))  # LR size: (339, 510, 3) /  gen HR size: (1, 1356, 2040, 3)
    
    if not os.path.isfile('sr-' + model_num + "-" + image_name):
        tl.vis.save_image(out[0], os.path.join(save_dir, 'sr-' + model_num + "-" + image_name))

        out_bicu = scipy.misc.imresize(valid_lr_img[0], [size[0] * 4, size[1] * 4], interp='bicubic', mode=None)
        tl.vis.save_image(out_bicu, os.path.join(save_dir, 'bic-' + model_num + "-" + image_name))

    sr_smaller = tf.image.resize(out[0], size=size)
    

    validate = {
        "sr" : out[0],
        "sr_resized" : sr_smaller.numpy(),
        
        "lr" : valid_lr_img[0],
        "bic" : out_bicu,
        
        "hr" : valid_hr_img, 
        "hr_resized" : hr_smaller.numpy(),
    }
    
    data = {
        "G" : G,
        
        "model" : model,

        "psnr_lr" : psnr( validate.get("lr"),  validate.get("sr_resized")),
        "ssim_lr" : ssim(validate.get("lr"),  validate.get("sr_resized"), multichannel=True),

        "psnr_hr_4" : psnr( validate.get("hr_resized"),  validate.get("sr_resized"), data_range = 255),
        "ssim_hr_4" : ssim(validate.get("hr_resized"),  validate.get("sr_resized"), multichannel=True),
        
        "psnr_hr" : psnr( validate.get("hr"),  validate.get("sr")),
        "ssim_hr" : ssim(validate.get("hr"),  validate.get("sr"), multichannel=True),

        "psnr_bic_hr" : psnr( validate.get("hr"),  validate.get("bic")),
        "ssim_bic_hr" : ssim( validate.get("hr"),  validate.get("bic"), multichannel=True),
    }
    return data
       

In [3]:
###====================== PRE-LOAD DATA ===========================###
valid_hr_img_list = sorted(tl.files.load_file_list(path=config.VALID.hr_img_path, regx='.*.png', printable=False))[:20]
valid_lr_img_list = sorted(tl.files.load_file_list(path=config.VALID.lr_img_path, regx='.*.png', printable=False))[:20]

valid_lr_imgs = tl.vis.read_images(valid_lr_img_list, path=config.VALID.lr_img_path, n_threads=32)

valid_hr_imgs = tl.vis.read_images(valid_hr_img_list, path=config.VALID.hr_img_path, n_threads=32)

[TL] read 20 from DIV2K/DIV2K_valid_LR_difficult/
[TL] read 20 from DIV2K/DIV2K_valid_HR/


In [4]:
def createPyPlot(validate_data, resized = True):


    label = 'SSIM: {:.2f}, sk_psnr:{:.2f} PSNR: {:.2f}'

    if resized: # show the images at size == the size of the input LR image
        fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(17, 12),
                             sharex=True, sharey=True)
        ax = axes.ravel()
        
        ax[0].imshow(validate_data.get("images").get("lr"))
        ax[0].set_xlabel(label.format(1.00, 100.0, 100.0))
        ax[0].set_title('valid LR image')
        
        ax[1].imshow(validate_data.get("images").get("sr_resized"))
        ax[1].set_xlabel(label.format(validate_data.get("ssim_lr"), validate_data.get("psnr_lr"), validate_data.get("PSNR_lr")))
        ax[1].set_title('generated image resized *-4 vs LR image')
        
        ax[2].imshow(validate_data.get("images").get("hr_resized"))
        ax[2].set_xlabel(label.format(1.00, 100.0, 100.0))
        ax[2].set_title('valid HR resized *-4')      
        
        ax[3].imshow(validate_data.get("images").get("sr_resized"))
        ax[3].set_xlabel(label.format(validate_data.get("ssim_hr_4"), validate_data.get("psnr_hr_4"), validate_data.get("PSNR_hr_4")))
        ax[3].set_title('generated image resized *-4 vs HR resized')
        
    else: 
        
        fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(17, 12),
                             sharex=True, sharey=True)
        ax = axes.ravel()
    
        ax[0].imshow(validate_data.get("images").get("hr"))
        ax[0].set_xlabel(label.format(1.00, 100.0, 100.0))
        ax[0].set_title('valid HR image')

        ax[1].imshow(validate_data.get("images").get("bic"))
        ax[1].set_xlabel(label.format(validate_data.get("ssim_bic_hr"), validate_data.get("psnr_bic_hr"), validate_data.get("PSNR_bic_hr")))
        ax[1].set_title('bicubic interpolation *4 vs HR')

        ax[2].imshow(validate_data.get("images").get("sr"))
        ax[2].set_xlabel(label.format(validate_data.get("ssim_hr"), validate_data.get("psnr_hr"), validate_data.get("PSNR_bic_hr")))
        ax[2].set_title('generated image vs HR')
    
    plt.tight_layout()
    plt.show()

In [5]:
def compare_models_names(a):
    return int(a.replace(".h5","").split("-")[1])

In [6]:
def rand_three(l):
    return [i for i in set((randrange(l), randrange(l), randrange(l), randrange(l), randrange(l)))][:3]


In [7]:
models = ["g-830-base.h5", "g-830-cyclic.h5"]

In [8]:
G = None

l = len(valid_hr_img_list)

for image in rand_three(l):
    validate_array = []
    for model in models:
        valid_lr_img = valid_lr_imgs[image]
        valid_hr_img = valid_hr_imgs[image]
        image_name = valid_hr_img_list[image]
        
        ev = evaluate("checkpoint", model, valid_lr_img, valid_hr_img, image_name, G = G)
        
        G = ev.pop("G", G)
        validate_array.append(ev) 
        
    with open("logs/" + image_name + ".json", mode='w', encoding='utf-8') as f:
        json.dump(validate_array, f)
    

In [9]:
l = len(valid_hr_img_list)

validate_ds_array_base = []
validate_ds_array_cyclic = []

for image in rand_three(l):    
    valid_hr_img = valid_hr_imgs[image]
    image_name = valid_hr_img_list[image]

    ev_base = evaluate_downsample("checkpoint", "g-830-base.h5", valid_hr_img, image_name, G = G)
    G = ev.pop("G", G)
    ev_cyclic = evaluate_downsample("checkpoint", "g-830-cyclic.h5", valid_hr_img, image_name, G = G)
    
    validate_ds_array_cyclic.append(ev_cyclic) 
    validate_ds_array_base.append(ev_base)
    

[TL] Input  _inputlayer_1: [1, None, None, 3]
[TL] Conv2d conv2d_1: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: relu
[TL] Conv2d conv2d_2: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_1: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
[TL] Conv2d conv2d_3: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_2: decay: 0.900000 epsilon: 0.000010 act: No Activation is_train: False
[TL] Elementwise elementwise_1: fn: add act: No Activation
[TL] Conv2d conv2d_4: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_3: decay: 0.900000 epsilon: 0.000010 act: relu is_train: False
[TL] Conv2d conv2d_5: n_filter: 64 filter_size: (3, 3) strides: (1, 1) pad: SAME act: No Activation
[TL] BatchNorm batchnorm2d_4: decay: 0.900000 epsilon: 0.000010 act: No Activation is_train: False
[TL] Elementwise elementwi

`imresize` is deprecated in SciPy 1.0.0, and will be removed in 1.3.0.
Use Pillow instead: ``numpy.array(Image.fromarray(arr).resize())``.
  warn("Inputs have mismatched dtype.  Setting data_range based on "
  warn("Inputs have mismatched dtype.  Setting data_range based on "


NameError: name 'ev' is not defined

In [None]:
base = pd.DataFrame(validate_ds_array_base)

In [None]:
cyclic = pd.DataFrame(validate_ds_array_cyclic)

In [None]:
createPyPlot(validate, resized = True)

createPyPlot(validate, resized = False)