In [None]:
import numpy as np
import math
import random
import tensorflow as tf
print(tf.__version__)

from metrics import MultiScaleSSIM
from metrics import tf_ms_ssim

import time
from keras.callbacks import TensorBoard

import models as models
from load_dataset import load_test_data, load_train_data
import utils
import vgg
import loss

import os, shutil
import sys
from os import environ

from PIL import Image
from IPython.display import display
from matplotlib import pyplot as plt

In [None]:
# Paths
dataset_dir = 'images/'
# Only the exact string "false" disables the GPU; read from the environment like
# every other setting (defaults to "true").
use_gpu = environ.get('use_gpu', 'true')
vgg_dir = "vgg_pretrained/imagenet-vgg-verydeep-19.mat"

# Every setting below can be overridden with an environment variable of the
# given name, e.g. `epochs=20 batch_size=32`.
num_epochs = int(environ.get('epochs', '10'))
test_size = int(environ.get('test_size', '200'))
batch_size = int(environ.get('batch_size', '16'))
# Training iterations per data load; -1 means one full epoch per load.
load_step = int(environ.get('load_step', '1000'))

# Number of training patches to use; -1 uses all available training patches.
num_train_data = int(environ.get('train_data', '-1'))
target = environ.get('target', 'iPhone8')
source = environ.get('source', 'Nova2i')
learning_rate_gen = float(environ.get('lr_g', '0.0001'))
learning_rate_disc = float(environ.get('lr_d', '0.0004'))
beta1 = float(environ.get('beta', '0.9'))
# Only used when num_epochs == 0; otherwise the epoch count fixes the iterations.
num_train_iters = int(environ.get('iterations', '1000'))
num_resnet = int(environ.get('resnet', '16'))
# Generator / discriminator updates per training iteration.
g_iters = int(environ.get('g_iters', '1'))
d_iters = int(environ.get('d_iters', '1'))
pixel = environ.get('pixel', 'L2')
VGG_LAYER = environ.get('vgg_layer', 'conv5_4')
gan = environ.get('gan', 'hinge')
model = environ.get('model', 'styleenhance')
# Iteration of the checkpoint to resume from; 0 starts from scratch.
start_iter = int(environ.get('start_iter', '0'))
# Model name of the checkpoint to initialize from when resuming;
# "same" means the checkpoint was saved under this run's model name.
init = environ.get('init', 'same')

# Label-smoothing factor; 1 disables label smoothing.
label_smoothing = 1

# Spectral normalization (passed to the discriminator); on unless use_sn != 'True'.
use_sn = environ.get('use_sn', 'True') == 'True'
if use_sn:
    print("Spectral Norm")
    
phone = source

# Per-phone image folders under ./images/, each with pre-cut train/test patches.
hqFolder = os.getcwd() + "/images/" + target + "/"
lqFolder = os.getcwd() + "/images/" + source + "/"
registeredFolder = "registered/"
hqPatches = hqFolder + "train_patches/"
lqPatches = lqFolder + "train_patches/"
hqTestPatches = hqFolder + "test_patches/"
lqTestPatches = lqFolder + "test_patches/"

# Number of source-phone training patches on disk (regular files only).
total_train_data = len([name for name in os.listdir(lqPatches) if os.path.isfile(os.path.join(lqPatches, name))])

if(num_train_data == -1):
    num_train_data = total_train_data

# Every division below assumes at least one full batch; fail early with a
# readable message instead of a bare ZeroDivisionError.
if batch_size <= 0:
    raise ValueError("batch_size must be positive, got %d" % batch_size)
iters_per_epoch = math.floor(num_train_data / batch_size)
if iters_per_epoch <= 0:
    raise ValueError("Not enough training patches for one batch: %d patches, batch_size=%d (folder: %s)"
                     % (num_train_data, batch_size, lqPatches))

if(num_epochs != 0):
    num_train_iters = int(num_epochs * iters_per_epoch)
if(load_step == -1):
    load_step = iters_per_epoch
if load_step <= 0:
    raise ValueError("load_step must be positive or -1 (one epoch per load), got %d" % load_step)

# Patches loaded at once, and how many loads make up one epoch.
train_size = int(load_step * batch_size)
load_per_epoch = int(math.ceil(num_train_data / train_size))
if(train_size > num_train_data):
    load_per_epoch = 1

print("Train Data: " + str(num_train_data))
print("Train Size: " + str(train_size))
print("Batch Size: " + str(batch_size))
print("Load Iterations: " + str(load_step))
print("Number of Epochs:" + str(num_train_iters / iters_per_epoch))
print("Iterations per Epoch: " + str(iters_per_epoch))
print("Train data loading per Epoch: " + str(load_per_epoch))
print("Number of Iterations: " + str(num_train_iters)) 

# Weights of the generator's multi-term loss (adversarial, pixel, perceptual).
w_gan = float(environ.get('w_gan', '1000'))
w_pixel = float(environ.get('w_pixel', '1'))
w_perceptual = float(environ.get('w_perceptual', '0.5'))

# Hide all GPUs from the session only when use_gpu is exactly "false";
# otherwise leave the session config unset.
if use_gpu == "false":
    config = tf.ConfigProto(device_count={'GPU': 0})
else:
    config = None

# RGB patch geometry; PATCH_SIZE is the number of values in one patch.
PATCH_WIDTH = int(environ.get('width', '96'))
PATCH_HEIGHT = int(environ.get('height', '96'))
PATCH_SIZE = 3 * PATCH_HEIGHT * PATCH_WIDTH

print("GAN Loss: " + gan)
print("Model Name: " + model)
print("Target: " + target)
print("Source: " + source)  

# Restrict which GPU(s) TensorFlow may see, then report whether one is usable.
gpu = environ.get('gpu', '0')
environ['CUDA_VISIBLE_DEVICES'] = gpu
print(tf.test.is_gpu_available())

# "same": resume from a checkpoint saved under this run's model name.
init = model if init == "same" else init
    

In [None]:
# Index of the next training shard to load; advanced after every load.
data_idx = 0

print("Loading test data...")
test_data, test_target = load_test_data(
    phone, dataset_dir, test_size, PATCH_SIZE, PATCH_WIDTH, target)
print("Test data was loaded\n")

print("Loading training data...")
train_data, train_target, reload = load_train_data(
    phone, dataset_dir, train_size, PATCH_SIZE, PATCH_WIDTH, target,
    data_idx % load_per_epoch, train_data_size=num_train_data)
print("Training data was loaded\n")
data_idx += 1

# A falsy `reload` means the entire training set has been loaded in one go, so
# the effective train size is however many patches actually came back.
if not reload:
    print("Reload")
    train_size = train_data.shape[0]

# Only full batches are used, for both test and training data.
TEST_SIZE = test_data.shape[0]
num_test_batches = TEST_SIZE // batch_size
num_train_batches = train_data.shape[0] // batch_size

In [None]:
start = time.time()
# Keep the CPU-only device restriction when use_gpu == "false" (a plain
# ConfigProto would silently re-enable GPUs), and let GPU memory grow on demand.
config = tf.ConfigProto(device_count={'GPU': 0}) if use_gpu == "false" else tf.ConfigProto()
config.gpu_options.allow_growth = True

with tf.Graph().as_default() as graph:
    with tf.Session(config=config) as sess:
        # Low-quality (source) and high-quality (target) patches, presumably fed in
        # [0, 1]; both are rescaled to [-1, 1] before entering the networks.
        lq_ = tf.placeholder(tf.float32, [None, PATCH_HEIGHT, PATCH_WIDTH, 3])
        hq_ = tf.placeholder(tf.float32, [None, PATCH_HEIGHT, PATCH_WIDTH, 3])
        hq_image = hq_ * 2.0 - 1.0
        lq_image = lq_ * 2.0 - 1.0

        # Fed as 1 for training steps and 0 for evaluation.
        is_train = tf.placeholder(tf.int16)

        # `is_train` is a graph placeholder, so it cannot pick a Python branch while
        # the graph is being built: in TF1 graph mode `tensor == 1` is object
        # identity and is always False. The generator is therefore built once with
        # isTraining=False, the only branch that could ever execute.
        # TODO(review): if models.generator accepts a tensor-valued isTraining, pass
        # tf.equal(is_train, 1) so any training-dependent layers actually train.
        gen_hq = models.generator(lq_image, hq_image, "_lq", n_resnet=num_resnet, isTraining=False, use_sn=False)

        # Adversarial loss: real and enhanced patches are scored in one batch
        # (real first) and the discriminator outputs are split back into halves.
        _, logits_disc, logits_disc_pred = models.discriminator(tf.concat([hq_image, gen_hq], 0), "_ch", use_sn=use_sn)
        logits_hq, logits_en = tf.split(logits_disc, num_or_size_splits=2, axis=0)
        logits_hq_pred, logits_en_pred = tf.split(logits_disc_pred, num_or_size_splits=2, axis=0)

        discrim_accuracy = tf.reduce_mean(logits_hq_pred + (1-logits_en_pred)) / 2
        accuracy_real = tf.reduce_mean(logits_hq)
        accuracy_fake = tf.reduce_mean(logits_en)

        # NOTE(review): this rebinds the config-level `gan` to the helper's third
        # return value; re-running this cell feeds that value back in. Confirm intended.
        loss_gen, loss_discrim, gan = loss.generator_loss(gan, gen_hq, hq_image, logits_hq, logits_en, label_smoothing, batch_size)

        # Pixel loss
        loss_pixel = loss.pixel_loss(pixel, hq_image, gen_hq, batch_size)

        # Perceptual loss: map [-1, 1] back to [0, 255] ((x + 1) / 2 is the inverse of
        # the input rescale) — presumably the pixel range vgg.preprocess expects.
        gen_hq_vgg = vgg.net(vgg_dir, vgg.preprocess((gen_hq + 1) / 2 * 255))
        hq_vgg = vgg.net(vgg_dir, vgg.preprocess((hq_image + 1) / 2 * 255))

        loss_perceptual = loss.perceptual_loss(gen_hq_vgg, hq_vgg, VGG_LAYER, pixel, batch_size)

        # Image quality metrics on the [-1, 1] scale, whose dynamic range is 2.0.
        loss_mse = tf.reduce_sum(tf.pow(hq_image - gen_hq, 2))/(PATCH_SIZE * batch_size)
        # Negated PSNR (lower is better, like the losses); peak value 2.0 for [-1, 1].
        loss_psnr = -(20 * utils.log10(2.0 / tf.sqrt(loss_mse)))
        # 1 - SSIM (lower is better); max_val must equal the images' dynamic range.
        loss_ssim = 1 - tf.reduce_mean(tf.image.ssim(hq_image, gen_hq, max_val=2.0))

        # Multi-term generator loss
        loss_final = w_perceptual * loss_perceptual + w_pixel * loss_pixel + w_gan * loss_gen

        # Optimizers: each network only updates its own variables.
        optimizer_gen = tf.train.AdamOptimizer(learning_rate_gen, beta1=beta1)
        optimizer_disc = tf.train.AdamOptimizer(learning_rate_disc, beta1=beta1)

        generator_vars = [v for v in tf.global_variables() if "generator" in v.name]
        discriminator_vars = [v for v in tf.global_variables() if "discriminator" in v.name]

        with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
            train_step_gen = optimizer_gen.minimize(loss_final, var_list=generator_vars)
            train_step_disc = optimizer_disc.minimize(loss_discrim, var_list=discriminator_vars)

        with tf.name_scope("pixel_loss"):
            tf.summary.scalar('pixel_Loss', loss_pixel)

        with tf.name_scope("adversarial_loss"):
            tf.summary.scalar('Discriminator_Loss', loss_discrim)
            tf.summary.scalar('Discriminator_Accuracy', discrim_accuracy)
            tf.summary.scalar('Generator_Loss', loss_gen)

        with tf.name_scope("perceptual_loss"):
            tf.summary.scalar('VGG_Loss', loss_perceptual)

        with tf.name_scope("psnr_mse"):
            tf.summary.scalar('PSNR', loss_psnr)
            tf.summary.scalar('MSE', loss_mse)
            tf.summary.scalar('SSIM', loss_ssim)

        with tf.name_scope("final_loss"):
            tf.summary.scalar('Final_Loss', loss_final)
            tf.summary.scalar('Adv_Loss', w_gan * loss_gen)
            tf.summary.scalar('Pixel_Loss', w_pixel * loss_pixel)
            tf.summary.scalar('Perceptual_Loss', w_perceptual * loss_perceptual)

        # Resume from a saved iteration, or start from freshly initialized weights.
        if(start_iter > 0):
            print("Loading variables")
            saver = tf.train.Saver()
            saver.restore(sess, 'models/' + str(phone) + "_"  + init + '_iteration_' + str(start_iter) + ".ckpt")
            start_iter += 1
            print(start_iter)
            start_iter = int(start_iter)
        else:
            print('Initializing variables')
            sess.run(tf.global_variables_initializer())

        print("Time elapsed: " + str(time.time() - start))
        print('Training network')  

        saver = tf.train.Saver(max_to_keep=3)
        epoch_saver = tf.train.Saver()
        gen_saver = tf.train.Saver(var_list=generator_vars, max_to_keep=1)

        # Running means over the current load_step window (reset every load_step).
        train_loss_gen = 0.0
        train_loss_disc = 0.0
        train_acc_discrim = 0.0

        # Fixed test crops used for the visual previews (at most 5).
        crops_i = random.sample(range(TEST_SIZE), min(5, TEST_SIZE))
        test_crops = test_data[crops_i, :]
        test_gt = test_target[crops_i, :]

        # The single log file appended to below; start it fresh only for a fresh
        # run so that resumed runs keep their history.
        log_path = 'models/' + phone + model + '.txt'
        if start_iter <= 0:
            open(log_path, "w").close()

        # NOTE(review): the model name appears twice in these directory names; kept
        # as-is so existing TensorBoard run directories keep matching.
        merge = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter( os.path.join(os.getcwd(),'tensorboard/others/train_' + phone + '_' + model + model), tf.get_default_graph())
        test_writer = tf.summary.FileWriter( os.path.join(os.getcwd(),'tensorboard/others/test_' + phone + '_' + model + model))
        epoch_writer = tf.summary.FileWriter( os.path.join(os.getcwd(),'tensorboard/others/epoch_' + phone + '_' + model + model))
        epoch_test_writer = tf.summary.FileWriter( os.path.join(os.getcwd(),'tensorboard/others/epoch_test_' + phone + '_' + model + model))

        run_options = tf.RunOptions(report_tensor_allocations_upon_oom = True)

        for i in range(start_iter, num_train_iters):
            print(i)
            iter_start = time.time()
            print("Training " + str((i + 1) * batch_size) + "/" + str(total_train_data))

            # Take the last remaining full batch of the currently loaded shard.
            be = (num_train_batches - 1) * batch_size
            en = num_train_batches * batch_size
            lq_images = train_data[be:en]
            hq_images = train_target[be:en]

            num_train_batches -= 1

            for j in range(g_iters):
                # train generator
                [summary, loss_temp, _, _] = sess.run([merge, loss_final, loss_gen, train_step_gen],
                                                feed_dict={lq_: lq_images, hq_: hq_images, is_train: 1}, options = run_options)
                train_loss_gen += loss_temp / (load_step * g_iters)
            for j in range(d_iters):
                # train discriminator
                [summary, accuracy_temp, loss_temp, _] = sess.run([merge, discrim_accuracy, loss_discrim, train_step_disc],
                                                feed_dict={lq_:lq_images, hq_:hq_images, is_train: 1}, options = run_options)
                train_loss_disc += loss_temp / (load_step * d_iters)
                train_acc_discrim += accuracy_temp / (load_step * d_iters)
            train_writer.add_summary(summary, i)
            if (i % iters_per_epoch == 0 or i == num_train_iters-1):
                epoch_writer.add_summary(summary, math.ceil(i / iters_per_epoch))

            print("Iteration Time Elapsed: " + str(time.time() - iter_start))

            # Evaluate at every load boundary, every epoch boundary and at the end.
            if((i + 1) % load_step == 0 or i % iters_per_epoch == 0 or i == num_train_iters-1):

                # test generator and discriminator CNNs
                test_losses_gen = np.zeros((1, 9))
                test_accuracy_disc = 0.0

                for j in range(num_test_batches):

                    be = j * batch_size
                    en = (j+1) * batch_size

                    lq_images = test_data[be:en]
                    hq_images = test_target[be:en]

                    [summary, accuracy_disc, losses] = sess.run([merge, discrim_accuracy, \
                                 [loss_final, loss_gen, loss_pixel, loss_perceptual, loss_discrim, accuracy_real, accuracy_fake, loss_psnr, loss_ssim]], \
                                    feed_dict={lq_: lq_images, hq_: hq_images, is_train: 0})
                    # NOTE(review): one summary per test batch is written at the same step i.
                    test_writer.add_summary(summary, i)

                    if (i % iters_per_epoch == 0 or i == num_train_iters-1):
                        epoch_test_writer.add_summary(summary, math.ceil(i / iters_per_epoch))

                    test_losses_gen += np.asarray(losses) / num_test_batches
                    test_accuracy_disc += accuracy_disc / num_test_batches

                logs_disc = "Iteration %d, %s | discriminator accuracy | train: %.4g, test: %.4g | discriminator loss | train: %.4g test: %.4g real: %.4g fake: %.4g" % \
                  (i, phone, train_acc_discrim, test_accuracy_disc, train_loss_disc, test_losses_gen[0][4], test_losses_gen[0][5], test_losses_gen[0][6])
                logs_gen = "generator losses | train: %.4g, test: %.4g | gen: %.4g, pixel: %.4g vgg: %.4g PSNR: %.4g SSIM: %.4g\n" % \
                  (train_loss_gen, test_losses_gen[0][0], test_losses_gen[0][1], test_losses_gen[0][2], test_losses_gen[0][3], test_losses_gen[0][7], test_losses_gen[0][8])
                print(logs_disc)
                print(logs_gen)

                # save the results to log file
                with open(log_path, "a") as logs:
                    if(i == 0):
                        logs.write("\nGAN loss: " + str(gan))
                        logs.write("\nTrain Data: " + str(num_train_data))
                        logs.write("\nBatch Size: " + str(batch_size))
                        logs.write("\nNumber of Epochs:" + str(num_train_iters / iters_per_epoch))
                        logs.write("\nIterations: " + str(num_train_iters))
                        logs.write("\nIterations per Epoch: " + str(iters_per_epoch))
                        logs.write("\nTrain Size: " + str(train_size))
                        logs.write("\nTest Data Size: " + str(test_size))
                        logs.write("\nLoad Iteration: " + str(load_step))
                        logs.write("\nLearning Rate (Gen): " + str(learning_rate_gen))
                        logs.write("\nLearning Rate (Disc): " + str(learning_rate_disc))
                        logs.write("\nStart Iteration: " + str(start_iter))
                        logs.write("\nInitialization: " + init)   
                        logs.write('\n')
                        logs.write('\n')

                    if ((i + 1) % load_step == 0 or i == num_train_iters-1):
                        logs.write(logs_gen)
                        logs.write(logs_disc)
                        logs.write('\n')

                # Preview images (source | enhanced | target). The hq_ feed must be the
                # targets that belong to the same crops, with the same batch size.
                gen_hq_crops = sess.run(gen_hq, feed_dict={lq_: test_crops, hq_: test_gt, is_train: 0}, options = run_options)
                for idx, crop in enumerate(gen_hq_crops):
                    # Back to [0, 1]; clip because plt.imsave rejects float RGB outside [0, 1].
                    crop = np.clip((crop + 1) / 2, 0.0, 1.0)
                    before_after = np.hstack((np.reshape(test_crops[idx], [PATCH_HEIGHT, PATCH_WIDTH, 3]), crop, (np.reshape(test_gt[idx], [PATCH_HEIGHT, PATCH_WIDTH, 3]))))
                    plt.imsave('validation_results/' + str(phone) + model + "_" + str(idx) + '_iteration_' + str(i) + '.jpg', before_after)

                if((i + 1) % load_step == 0): 
                    train_loss_gen = 0.0
                    train_loss_disc = 0.0
                    train_acc_discrim = 0.0

                # save the model that corresponds to the current iteration
                if(i % iters_per_epoch == 0 and i != 0):
                    epoch_saver.save(sess, 'models/' + str(phone) + "_"  + model + '_epoch_' + str(int(math.floor(i / iters_per_epoch))) + '.ckpt', write_meta_graph=False)
                if((i + 1) % load_step == 0 or i == num_train_iters-1):
                    saver.save(sess, 'models/' + str(phone) + "_"  + model + '_iteration_' + str(i) + '.ckpt', write_meta_graph=False)
                print("Loading new batch...")

            # reload a next batch of training data
            if(reload and num_train_batches == 0):
                print("Loading " + str(i))
                del train_data
                del train_target
                print((data_idx % load_per_epoch) * train_size)
                print(((data_idx % load_per_epoch) + 1) * train_size)
                train_data, train_target, reload = load_train_data(phone, dataset_dir, train_size, PATCH_SIZE, PATCH_WIDTH, target, data_idx % load_per_epoch, train_data_size=num_train_data)
                data_idx += 1

            # Shard exhausted: reset the batch counter and reshuffle the loaded data.
            if(num_train_batches == 0):
                num_train_batches = int(math.floor(train_data.shape[0] / batch_size))
                print(num_train_batches)
                indices = np.arange(train_data.shape[0])
                np.random.shuffle(indices)

                train_data = train_data[indices]
                train_target = train_target[indices]

            print("New Batch Loaded...")
            print("Time Elapsed: " + str(time.time() - start))

        print("Total Time Elapsed: " + str(time.time() - start))

        with open(log_path, "a") as logs:
            logs.write("\n\nTotal Time Elapsed: " + str(time.time() - start))
            logs.write('\n')

        train_writer.close()
        test_writer.close()
        epoch_writer.close()
        epoch_test_writer.close()
    
                
                    
                    