In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
from matplotlib import cm
from scipy import misc
import numpy as np
import math, time
import tensorflow as tf
from models import generator
import models
import utils
import os
import sys
import imageio
from os import environ
from metrics import MultiScaleSSIM
from skimage.measure import compare_ssim
from PIL import Image
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
import cv2


In [None]:
# --- Run configuration ---------------------------------------------------------
# Each setting can be overridden through the environment variable named in its
# environ.get(...) call, so the notebook can be driven from a shell script.
data_dir = 'images/'
# "false" hides the GPU from TensorFlow when the session config is built below.
# It used to be hard-coded to "true"; the default is unchanged.
use_gpu = environ.get('use_gpu', 'true').lower()
gpu = environ.get('gpu', '0')
os.environ['CUDA_VISIBLE_DEVICES'] = gpu
phone = environ.get('phone', 'Nova2i')
model = environ.get('model', 'StyleEnhance')
iteration = environ.get('iteration', '27239')  # checkpoint iteration to restore
ground_truth = environ.get('gt', 'iPhone8')
resolution = environ.get('res', 'iPhone8_resize')
eval_step = int(environ.get('eval_step', '1000'))
n_resnet = int(environ.get('resnet', '16'))

# Which inputs to enhance. Recognised values:
#   "patches"         -> pre-cut patches in results/input_patches/ (no ground truth)
#   "mobile_patches"  -> images/Nova2i/test_patches/ vs images/iPhone8/test_patches/
#   "mobile_full"     -> images/Nova2i/test_images/ vs images/iPhone8/test_images/
#   "test_images"     -> images/Nova2i/test_images/, cut into padded patches first
#   "<name>images"    -> results/<name>images/, cut into padded patches first
#   anything else     -> images/<phone>/<folder>/ vs images/<gt>/<folder>/
test_folder = environ.get('folder', 'input_images')

start = time.time()

# Spectral normalisation is opt-in: only the exact string "True" enables it.
use_sn = environ.get('use_sn', 'False') == 'True'
if use_sn:
    print("Spectral Norm")

pad = 20         # pixels of neighbouring context added around each cut patch
evaluate = True  # score outputs against ground truth; cleared for folders without one

# Resolve the patch size for this phone/resolution pair.
res_sizes = utils.get_resolutions()
IMAGE_HEIGHT, IMAGE_WIDTH, IMAGE_SIZE = utils.get_specified_res(res_sizes, phone, resolution)
print("Height: " + str(IMAGE_HEIGHT))
print("Width: " + str(IMAGE_WIDTH))
    
# --- Resolve test_dir / gt_dir (cutting whole images into patches if needed) ---

# Cut input patches are written here; per-patch network outputs go to MERGE_PATCH_DIR.
INPUT_PATCH_DIR = os.getcwd() + "/results/input_patches/"
MERGE_PATCH_DIR = os.getcwd() + "/results/merge_patches/"


def _clear_folder(folder):
    """Delete every regular file directly inside `folder` (sub-folders are kept).

    Failures are printed rather than raised so one undeletable file does not
    abort the run.
    """
    for the_file in os.listdir(folder):
        file_path = os.path.join(folder, the_file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(e)


def _list_images(folder):
    """Return the sorted names of the .png/.jpg files directly inside `folder`."""
    return sorted(name for name in os.listdir(folder)
                  if os.path.isfile(os.path.join(folder, name))
                  and name.endswith((".png", ".jpg")))


def _numbered_stem(filename):
    """Map 'IMG (12).jpg' to '12'.

    Names without '(' fall back to the text before the first '.', instead of
    raising IndexError as the original inline expression did.
    """
    stem = filename.split('.')[0].replace(')', '')
    return stem.split('(')[1] if '(' in stem else stem


def _cut_patches(img, grid_w, grid_h, out_stem, tile_w, tile_h, border):
    """Tile a grid_w x grid_h area of `img` into padded patches saved as JPEGs.

    Each tile_w x tile_h tile is placed on a (tile_w + 2*border) x (tile_h + 2*border)
    black canvas. Up to `border` pixels of neighbouring context are copied in on
    each side where the grid has room; crops beyond the real image stay black.
    Patches are numbered row-major and saved as '<out_stem>_<k>.jpg' in
    INPUT_PATCH_DIR.

    Returns:
        The number of patches written.
    """
    n_patches = 0
    for top in range(0, grid_h, tile_h):
        for left in range(0, grid_w, tile_w):
            pad_top = border if top > 0 else 0
            pad_left = border if left > 0 else 0
            pad_right = border if left + tile_w + border <= grid_w else 0
            pad_bottom = border if top + tile_h + border <= grid_h else 0
            box = (left - pad_left, top - pad_top,
                   left + tile_w + pad_right, top + tile_h + pad_bottom)
            patch = Image.new('RGB', (tile_w + border * 2, tile_h + border * 2))
            patch.paste(img.crop(box), (border - pad_left, border - pad_top))
            patch.save(INPUT_PATCH_DIR + out_stem + "_" + str(n_patches) + ".jpg")
            n_patches += 1
    return n_patches


if test_folder == "patches":
    # Pre-cut patches without ground truth.
    evaluate = False
    test_dir = "results/input_patches/"

elif test_folder == "mobile_patches":
    test_dir = data_dir + "Nova2i/test_patches/"
    gt_dir = data_dir + "iPhone8/test_patches/"

elif test_folder == "mobile_full":
    test_dir = data_dir + "Nova2i/test_images/"
    gt_dir = data_dir + "iPhone8/test_images/"

elif test_folder.endswith("images"):
    # Whole images: start from empty patch folders, cut every image into padded
    # patches and enhance those patches instead (the merge cell stitches them).
    evaluate = False
    _clear_folder(INPUT_PATCH_DIR)
    _clear_folder(MERGE_PATCH_DIR)
    if test_folder == "test_images":
        source_dir = data_dir + "Nova2i/" + test_folder + "/"
    else:
        source_dir = "results/" + test_folder + "/"

    for image in _list_images(source_dir):
        print(image)
        with Image.open(source_dir + image) as img_lq:
            width, height = img_lq.size
            if test_folder == "test_images":
                # Measure the grid in whole tiles (rounded up); the context-border
                # checks in _cut_patches use this grid. Patches are named by the
                # number in parentheses, e.g. 'IMG (12).jpg' -> '12_<k>.jpg'.
                grid_w = -(-width // IMAGE_WIDTH) * IMAGE_WIDTH
                grid_h = -(-height // IMAGE_HEIGHT) * IMAGE_HEIGHT
                out_stem = _numbered_stem(image)
            else:
                grid_w, grid_h = width, height
                out_stem = image
            k = _cut_patches(img_lq, grid_w, grid_h, out_stem, IMAGE_WIDTH, IMAGE_HEIGHT, pad)
        print(k)

    test_dir = "results/input_patches/"

else:
    # Any other folder: images/<phone>/<folder>/ with ground truth in images/<gt>/<folder>/.
    test_dir = data_dir + phone.split('_')[0] + "/" + test_folder + "/"
    gt_dir = data_dir + ground_truth + "/" + test_folder + "/"
            
# Patches cut from whole images carry a `pad`-pixel border on every side, so the
# network input grows by 2*pad in each dimension.
if test_folder.endswith("images"):
    IMAGE_HEIGHT += pad * 2
    IMAGE_WIDTH += pad * 2
    IMAGE_SIZE = IMAGE_HEIGHT * IMAGE_WIDTH * 3

# Hide the GPU from TensorFlow when use_gpu is "false"; either way, allocate
# GPU memory on demand rather than all at once.
if use_gpu == "false":
    config = tf.ConfigProto(device_count={'GPU': 0})
else:
    config = tf.ConfigProto()
config.gpu_options.allow_growth = True

# Images are fed flattened and reshaped to (batch, height, width, 3) for the generator.
input_ = tf.placeholder(tf.float32, [None, IMAGE_SIZE])
input_image = tf.reshape(input_, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3])

# Patch-based runs write per-patch outputs that the merge cell stitches together.
writes_patches = test_folder == "test_split" or test_folder.endswith("images")
destination = "results/merge_patches/" if writes_patches else "results/"



In [None]:
# generate enhanced image
# Run the generator over every test photo, save the outputs to `destination`
# and, when ground truth exists, accumulate SSIM / MS-SSIM / RMSE / PSNR.
with tf.Session(config=config) as sess:

    test_photos = [f for f in sorted(os.listdir(test_dir))
                   if os.path.isfile(test_dir + f) and (f.endswith(".png") or f.endswith(".jpg"))]

    enhanced = generator(input_image, n_resnet=n_resnet, isTraining=False, use_sn=use_sn)

    # Checkpoints are restored from "models/<phone>_<model>_iteration_<i>.ckpt", so
    # count files with that prefix (the original prefix lacked the "_" and never
    # matched). NOTE(review): the /2 assumes two files per checkpoint — confirm.
    ckpt_prefix = str(phone) + "_" + model + "_iteration"
    num_saved_models = int(len([f for f in os.listdir("models/") if f.startswith(ckpt_prefix)]) / 2)
    print("testing " + str(num_saved_models) + " models")

    # Keep the env string `iteration` intact so this cell can be re-run.
    iterations = [int(iteration)]

    # One Saver / RunOptions serves every checkpoint and photo; creating them in
    # the loops only adds graph ops.
    saver = tf.train.Saver()
    run_options = tf.RunOptions(report_tensor_allocations_upon_oom=True)
    # Patch-based runs must drop the `pad`-pixel context border again.
    crop_border = test_folder == "test_split" or test_folder.endswith("images")

    for i in iterations:
        ckpt_path = "models/" + phone + "_" + model + "_iteration_" + str(i) + ".ckpt"
        print("Loading model: " + ckpt_path)
        saver.restore(sess, ckpt_path)

        t_rmse = 0
        t_psnr = 0
        t_ssim = 0
        t_msssim = 0

        for photo in test_photos:

            print("iteration " + str(i) + ", processing image " + photo + " of " + str(len(test_photos)))

            # convert("RGB") keeps RGBA/greyscale files from breaking the reshape below.
            with Image.open(test_dir + photo) as image:
                image_resize = np.float32(image.convert("RGB").resize([IMAGE_WIDTH, IMAGE_HEIGHT])) / 255
            image_resize = image_resize * 2.0 - 1  # scale [0, 1] -> [-1, 1]

            if evaluate:
                with Image.open(gt_dir + photo.split(".")[0] + ".jpg") as HD:
                    hd_resize = np.float32(HD.convert("RGB").resize([IMAGE_WIDTH, IMAGE_HEIGHT])) / 255

            image_resize_2d = np.reshape(image_resize, [1, IMAGE_SIZE])
            enhanced_2d = sess.run(enhanced, feed_dict={input_: image_resize_2d}, options=run_options)
            enhanced_image = np.reshape(enhanced_2d, [IMAGE_HEIGHT, IMAGE_WIDTH, 3])

            if crop_border:
                enhanced_image = enhanced_image[pad:IMAGE_HEIGHT - pad, pad:IMAGE_WIDTH - pad, :]
                image_resize = image_resize[pad:IMAGE_HEIGHT - pad, pad:IMAGE_WIDTH - pad, :]
                if evaluate:
                    hd_resize = hd_resize[pad:IMAGE_HEIGHT - pad, pad:IMAGE_WIDTH - pad, :]

            # Map the generator output from [-1, 1] back to [0, 1].
            enhanced_image = ((enhanced_image + 1) / 2).clip(min=0, max=1)

            # save the result as a .png image (both branches of the original
            # wrote to the same path)
            photo_name = photo.rsplit(".", 1)[0]
            plt.imsave(destination + photo_name + ".png", enhanced_image)
            print("image saved")

            if evaluate:
                ssim = compare_ssim(hd_resize, enhanced_image, multichannel=True)
                msssim = MultiScaleSSIM(np.expand_dims(hd_resize, axis=0), np.expand_dims(enhanced_image, axis=0))
                # Only MultiScaleSSIM is imported from `metrics`, so `metrics.rmse`
                # was a NameError; use the utils helpers the log already relied on.
                rmse = utils.rmse(hd_resize, enhanced_image)
                psnr = utils.psnr(hd_resize, enhanced_image)
                t_ssim += ssim
                t_msssim += msssim
                t_rmse += rmse
                t_psnr += psnr

                # Compare input and output on the same [0, 1] scale.
                input_01 = (image_resize + 1) / 2
                print("SSIM (target) " + photo + ": " + str(ssim))
                print("SSIM (orig) " + photo + ": " + str(compare_ssim(input_01, enhanced_image, multichannel=True)))
                print("MS-SSIM " + photo + ": " + str(msssim))
                print("RMSE " + photo + ": " + str(rmse))
                print("PSNR " + photo + ": " + str(psnr))

                with open(destination + phone + '.txt', "a") as logs:
                    logs.write(phone + '_' + photo + ": " + str(ssim))
                    logs.write('|' + str(rmse))
                    logs.write('|' + str(psnr))
                    logs.write('\n\n')

In [None]:
# Average the per-photo metrics accumulated by the inference cell, print them and
# append them to the same per-phone log file.
if evaluate and test_photos:
    n_photos = len(test_photos)
    avg_ssim = t_ssim / n_photos
    avg_msssim = t_msssim / n_photos
    avg_rmse = t_rmse / n_photos
    avg_psnr = t_psnr / n_photos

    print("SSIM: " + str(avg_ssim))
    print("MSSSIM: " + str(avg_msssim))
    print("\nRMSE: " + str(avg_rmse))
    print("PSNR: " + str(avg_psnr))

    with open(destination + phone + '.txt', "a") as logs:
        logs.write("\nSSIM: " + str(avg_ssim))
        logs.write("\nMSSSIM: " + str(avg_msssim))
        logs.write("\nRMSE: " + str(avg_rmse))
        logs.write("\nPSNR: " + str(avg_psnr))
        logs.write('\n\n')
elif evaluate:
    # Avoid ZeroDivisionError when the test folder had no .png/.jpg files.
    print("No test photos found in " + test_dir + "; nothing to average.")

In [None]:
# Stitch the enhanced patches back into full-size images (patch-based runs only).


def _patch_num(name):
    """Return the patch index k from an output named '<stem>_<k>.<ext>'.

    rsplit keeps stems that themselves contain '_' working.
    """
    return int(name.rsplit('_', 1)[1].split('.')[0])


def _image_files(folder):
    """Return the sorted .png/.jpg file names directly inside `folder` (path ends in '/')."""
    return sorted(f for f in os.listdir(folder)
                  if os.path.isfile(folder + f) and f.endswith((".png", ".jpg")))


def _stitch_patches(patch_dir, source_dir, tile_w, tile_h):
    """Reassemble the patches in `patch_dir` into one JPEG per image in `source_dir`.

    Assumes every source image has the size of the first one in `source_dir` and
    was cut row-major into tile_w x tile_h tiles, so its patches form a contiguous
    run of patches_per_image names in the sorted patch listing. Non-image files
    (e.g. the metrics .txt log) are ignored. Results go to <cwd>/results/<stem>.jpg.
    """
    sources = _image_files(source_dir)
    if not sources:
        print("No source images found in " + source_dir)
        return

    with Image.open(source_dir + sources[0]) as first_source:
        frame_width, frame_height = first_source.size
    images_per_row = math.ceil(frame_width / tile_w)
    patches_per_image = images_per_row * math.ceil(frame_height / tile_h)
    print("Frame width:" + str(frame_width))
    print("Patches per row:" + str(images_per_row))
    print("Patches per image:" + str(patches_per_image))
    print(len(sources))

    patches = _image_files(patch_dir)
    for t in range(len(sources)):
        group = sorted(patches[t * patches_per_image:(t + 1) * patches_per_image], key=_patch_num)
        if not group:
            break  # fewer patches on disk than expected

        # Output patches were cropped back to the tile size, so their own size is the paste stride.
        with Image.open(patch_dir + group[0]) as first_patch:
            patch_w, patch_h = first_patch.size

        new_im = Image.new('RGB', (frame_width, frame_height))
        for num, name in enumerate(group):
            with Image.open(patch_dir + name) as patch:
                new_im.paste(patch, ((num % images_per_row) * patch_w,
                                     (num // images_per_row) * patch_h))

        # Stems of non-"test_images" folders still carry the source extension;
        # strip it so we do not write 'foo.jpg.jpg'.
        out_stem = group[0].rsplit('_', 1)[0]
        if out_stem.endswith((".png", ".jpg")):
            out_stem = out_stem[:-4]
        new_im.save(os.getcwd() + "/results/" + out_stem + ".jpg")
        print(out_stem)


if test_folder == "test_split" or test_folder.endswith("images"):
    #Evan Russenberger-Rosica
    PATH = os.getcwd() + "/" + destination
    if test_folder == "test_split" or test_folder == "test_images":
        test_images = os.getcwd() + "/images/Nova2i/test_images/"
    else:
        test_images = os.getcwd() + "/results/" + test_folder + "/"

    # The "*images" folders were cut with the unpadded tile size as stride, but
    # the setup cell has since grown IMAGE_WIDTH/HEIGHT by 2*pad for the network
    # input; undo that so the stitching grid matches the cutting grid.
    if test_folder.endswith("images"):
        tile_w, tile_h = IMAGE_WIDTH - 2 * pad, IMAGE_HEIGHT - 2 * pad
    else:
        tile_w, tile_h = IMAGE_WIDTH, IMAGE_HEIGHT
    _stitch_patches(PATH, test_images, tile_w, tile_h)
print("Total Time Elapsed: " + str(time.time() - start))

In [None]:
# Optional clean-up of the intermediate patch folders. Flip a flag to True to
# delete the regular files in that results/ sub-folder (errors are printed).
CLEAR_MERGE_PATCHES = False
CLEAR_INPUT_PATCHES = False

for enabled, relative_dir in ((CLEAR_MERGE_PATCHES, "/results/merge_patches/"),
                              (CLEAR_INPUT_PATCHES, "/results/input_patches/")):
    if not enabled:
        continue
    folder = os.getcwd() + relative_dir
    for the_file in os.listdir(folder):
        file_path = os.path.join(folder, the_file)
        try:
            if os.path.isfile(file_path):
                os.unlink(file_path)
        except Exception as e:
            print(e)