### 0. Preparation

In [None]:
import numpy as np
import tensorlayer as tl
import tensorflow as tf
import os,sys
sys.path.append('../lib/SRGAN/')
from model import SRGAN_g

In [None]:
def mse(img1, img2):
    '''
    Return the mean squared error between two images of the same shape.

    Both inputs are converted to float64 before subtracting. Images may be
    uint8 (predict() saves uint8 images), and uint8 subtraction/squaring wraps
    around modulo 256. That silently corrupts the error for any pixel
    difference larger than 15.
    '''
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    return np.square(diff).mean()

def psnr(img1, img2, max_val=255.0):
    '''
    Return the peak signal-to-noise ratio (in dB) between two images.

    img1 and img2 are arrays of the same shape (e.g. H x W x 3 images).
    max_val is the largest possible pixel value (255 for 8-bit images).
    Identical images have zero error and yield inf.

    Pixels are promoted to float64 first, because uint8 subtraction/squaring
    would wrap around modulo 256 and corrupt the error term.
    '''
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    err = np.square(diff).mean()
    if err == 0:
        # avoid a divide-by-zero warning; PSNR is unbounded for a perfect match
        return float('inf')
    return 10 * np.log10(max_val * max_val / err)

### 1. Prediction

In [None]:
# set the paths
# The LR inputs must be the low-resolution counterparts of the ground-truth
# images in '../data/test_set/HR' (scored in section 2). Otherwise the MSE
# compares unrelated images.
test_lr_path = '../data/test_set/LR'
checkpoint_path = '../output/SRGAN2/checkpoint'
save_path = '../output'

####### set different start and end images #######
# The test set is processed in chunks of 500 images: keep exactly one
# start/end pair active per run.
start = 0
end = 500
# start = 500
# end = 1000
# start = 1000
# end = 1500

In [None]:
def predict(test_lr_path, checkpoint_path, save_path, start, end):
    '''
    Super-resolve a slice of low-resolution images with a trained SRGAN generator.

    Parameters
    ----------
    test_lr_path : str
        Directory containing the low-resolution input images (*.jpg).
    checkpoint_path : str
        Directory containing the generator weights file 'g_srgan.npz'.
    save_path : str
        Output root. Generated images are written to <save_path>/test_gen
        under the same file names as their inputs.
    start, end : int
        Slice bounds into the sorted list of input file names, so a large
        set can be processed in several runs.

    Raises
    ------
    FileNotFoundError
        If the generator checkpoint file does not exist.

    The generator is built with reuse=False in the current default graph.
    The notebook calls tf.reset_default_graph() before each call.
    '''
    # Fail fast rather than produce images from an untrained (randomly
    # initialised) generator when the checkpoint is missing.
    weights_file = os.path.join(checkpoint_path, 'g_srgan.npz')
    if not os.path.isfile(weights_file):
        raise FileNotFoundError('generator checkpoint not found: %s' % weights_file)

    ## create folders to save result images
    save_dir = os.path.join(save_path, 'test_gen')
    tl.files.exists_or_mkdir(save_dir)

    ###======PRE-LOAD DATA======###
    # Escaped dot + anchor: only names that really end in ".jpg".
    test_lr_img_list = sorted(tl.files.load_file_list(path=test_lr_path, regx=r'.*\.jpg$', printable=False))
    test_lr_img_list = test_lr_img_list[start:end]
    test_lr_imgs = tl.vis.read_images(test_lr_img_list, path=test_lr_path)

    ###======DEFINE MODEL======###
    test_image = tf.placeholder('float32', [1, None, None, 3], name='input_image')
    net_g = SRGAN_g(test_image, is_train=False, reuse=False)

    ###======RESTORE G======###
    config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    # The context manager closes the session when inference is done.
    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        tl.files.load_and_assign_npz(sess=sess, name=weights_file, network=net_g)

        ###======EVALUATION======###
        # `count` is the 1-based number of images saved so far.
        for count, (name, img) in enumerate(zip(test_lr_img_list, test_lr_imgs), start=1):
            # rescale pixels from [0, 255] to [-1, 1]
            lr_img = (img / 127.5) - 1
            out = sess.run(net_g.outputs, {test_image: [lr_img]})
            # Map back to [0, 255]. Round (astype alone would truncate) and
            # clip so the uint8 cast cannot wrap.
            out = np.clip(np.rint((out[0] + 1) * 127.5), 0, 255)
            tl.vis.save_image(out.astype(np.uint8), os.path.join(save_dir, name))
            if count % 10 == 0:
                print('saving %d images, ok' % count)
    print("finish")

In [None]:
# prediction
# Start from an empty default graph. predict() builds the generator with
# reuse=False, so this presumably keeps re-runs of this cell from clashing
# with variables left over from an earlier run.
tf.reset_default_graph()
predict(test_lr_path, checkpoint_path, save_path, start, end)

### 2. Calculate MSE

In [None]:
# set the paths
# folder that predict() writes its super-resolved outputs into
gen_hr_img_path = os.path.join(save_path, 'test_gen')
# ground-truth high-resolution images used as the reference
test_hr_img_path = '../data/test_set/HR'

In [None]:
# calculate mse
# Escaped dot + anchor: only names that really end in ".jpg".
test_hr_list = sorted(tl.files.load_file_list(path=test_hr_img_path, regx=r'.*\.jpg$', printable=False))
test_hr_list = test_hr_list[start:end]
test_gen_list = sorted(tl.files.load_file_list(path=gen_hr_img_path, regx=r'.*\.jpg$', printable=False))

# Images are paired by position after sorting. predict() never clears
# test_gen, so outputs from another start/end run can be present. zip() would
# then silently truncate and score images against the wrong ground truth.
# Refuse to report a misleading number instead.
if len(test_hr_list) != len(test_gen_list):
    raise ValueError(
        '%d ground-truth images but %d generated images in %s; '
        'clear that folder or re-run prediction for start=%d, end=%d'
        % (len(test_hr_list), len(test_gen_list), gen_hr_img_path, start, end))

test_hr_imgs = tl.vis.read_images(test_hr_list, path=test_hr_img_path)
test_gen_imgs = tl.vis.read_images(test_gen_list, path=gen_hr_img_path)

# mse -- promote pixels to float so uint8 subtraction cannot wrap around
np.mean([mse(img1.astype(np.float64), img2.astype(np.float64))
         for img1, img2 in zip(test_hr_imgs, test_gen_imgs)])