In [5]:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np
import utils
import os
import vgg19
import model

In [8]:
# ---- Patch geometry --------------------------------------------------------
image_height, image_width, image_channels = 100, 100, 3
# Length of one flattened patch; the graph's placeholders are fed rows of this size.
image_size = image_channels * image_width * image_height

# ---- Optimization ----------------------------------------------------------
batch_size = 50          # also used as the divisor when normalizing the losses
learning_rate = 5e-4     # Adam step size, shared by generator and discriminator
num_train_iters = 20000  # presumably total training iterations (not used in this chunk)
eval_step = 1000         # presumably evaluation interval in iterations (not used in this chunk)

# ---- Generator loss weights ------------------------------------------------
# Multipliers for the content, color, texture (adversarial) and total-variation terms.
w_content, w_color, w_texture, w_tv = 10, 0.5, 1, 2000

In [None]:
# Build the training graph: generator + discriminator, the four generator
# losses (content, texture, color, total variation), a PSNR metric, one Adam
# optimizer per network and a Saver for the generator weights; then initialize
# every variable in a fresh TF1-style session.
#
# `tf.Graph` must be instantiated before `.as_default()` is called:
# `as_default` is an instance method, so calling it on the class itself
# raises a TypeError before the graph is ever built.
# The session is closed when this `with` block exits, so any loop that runs
# ops through `sess` has to live inside this block.
with tf.Graph().as_default(), tf.compat.v1.Session() as sess:
    # ---- Inputs -------------------------------------------------------------
    # Source patches, fed flattened as [batch, image_size] and viewed as NHWC.
    input_ = tf.compat.v1.placeholder(tf.float32, [None, image_size])
    input_image = tf.reshape(input_, [-1, image_height, image_width, image_channels])

    # Target patches, same layout as the source patches.
    output_ = tf.compat.v1.placeholder(tf.float32, [None, image_size])
    output_image = tf.reshape(output_, [-1, image_height, image_width, image_channels])

    # Per-example blend weight, shape [batch, 1]: 1 -> the discriminator is
    # shown the target patch, 0 -> it is shown the enhanced patch.
    adv_ = tf.compat.v1.placeholder(tf.float32, [None, 1])

    enhanced = model.Generator(input_image)

    # ---- Discriminator input ------------------------------------------------
    # Grayscale versions of the enhanced and target patches, shape [batch, H, W].
    enhanced_gray = tf.reshape(tf.image.rgb_to_grayscale(enhanced), [-1, image_height, image_width])
    output_gray = tf.reshape(tf.image.rgb_to_grayscale(output_image), [-1, image_height, image_width])

    # adv_ is [batch, 1]; multiplied directly with a [batch, H, W] tensor it
    # would be broadcast against the (H, W) axes rather than the batch axis
    # (a shape error, or silent label mixing when batch == H). Reshape it to
    # [batch, 1, 1] so each example is blended by its own weight.
    adv_weight = tf.reshape(adv_, [-1, 1, 1])
    adversarial_ = tf.multiply(enhanced_gray, 1 - adv_weight) + tf.multiply(output_gray, adv_weight)
    adversarial_image = tf.reshape(adversarial_, [-1, image_height, image_width, 1])

    discriminator_pred = model.Discriminator(adversarial_image)

    # ---- Loss #1: texture (adversarial) loss --------------------------------
    # Target [adv_, 1 - adv_]: column 0 corresponds to "target patch shown"
    # (one-hot when adv_ is exactly 0 or 1).
    discriminator_target = tf.concat([adv_, 1 - adv_], 1)

    # Cross-entropy, assuming discriminator_pred holds per-class probabilities;
    # predictions are clipped away from 0 so the log stays finite.
    loss_discriminator = -tf.reduce_sum(discriminator_target * tf.compat.v1.log(tf.clip_by_value(discriminator_pred, 1e-10, 1.0)))
    # The generator is trained to maximize the discriminator's loss.
    loss_texture = -loss_discriminator

    correct_predictions = tf.equal(tf.argmax(discriminator_pred, 1), tf.argmax(discriminator_target, 1))
    discriminator_accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

    # ---- Loss #2: content loss (VGG-19 features) ----------------------------
    content_layer = 'relu5_4'

    # Images are scaled by 255 before the VGG preprocessing step.
    enhanced_vgg = vgg19.vgg19NetworkPass(vgg19.preprocess(enhanced * 255))
    output_vgg = vgg19.vgg19NetworkPass(vgg19.preprocess(output_image * 255))

    # NOTE(review): this and the other normalizations below divide by the fixed
    # `batch_size`, while the placeholders accept any batch size; feeding a
    # different batch size rescales the reported losses.
    content_size = utils._tensor_size(output_vgg[content_layer]) * batch_size
    # tf.nn.l2_loss computes sum(x ** 2) / 2; the factor 2 undoes the halving.
    loss_content = 2 * tf.nn.l2_loss(enhanced_vgg[content_layer] - output_vgg[content_layer]) / content_size

    # ---- Loss #3: color loss (on blurred images) ----------------------------
    enhanced_blur = utils.blur(enhanced)
    output_blur = utils.blur(output_image)

    loss_color = tf.reduce_sum(tf.pow(output_blur - enhanced_blur, 2)) / (2 * batch_size)

    # ---- Loss #4: total variation (smoothness) loss -------------------------
    batch_shape = (batch_size, image_height, image_width, image_channels)
    tv_y_size = utils._tensor_size(enhanced[:, 1:, :, :])
    tv_x_size = utils._tensor_size(enhanced[:, :, 1:, :])
    # Differences between vertically / horizontally adjacent pixels.
    y_tv = tf.nn.l2_loss(enhanced[:, 1:, :, :] - enhanced[:, :batch_shape[1] - 1, :, :])
    x_tv = tf.nn.l2_loss(enhanced[:, :, 1:, :] - enhanced[:, :, :batch_shape[2] - 1, :])
    loss_tv = 2 * (x_tv / tv_x_size + y_tv / tv_y_size) / batch_size

    # ---- Generator objective ------------------------------------------------
    loss_generator = w_content * loss_content + w_texture * loss_texture + w_color * loss_color + w_tv * loss_tv

    # ---- PSNR metric (not used by either optimizer below) -------------------
    enhanced_flat = tf.reshape(enhanced, [-1, image_size])

    loss_mse = tf.reduce_sum(tf.pow(output_ - enhanced_flat, 2)) / (image_size * batch_size)
    # PSNR with a peak signal value of 1.0.
    loss_psnr = 20 * utils.log10(1.0 / tf.sqrt(loss_mse))

    # ---- Optimizers ---------------------------------------------------------
    # Variables are split by name prefix; this assumes model.Generator and
    # model.Discriminator create their variables under "generator" and
    # "discriminator" scopes respectively -- TODO confirm in model.py.
    generator_vars = [v for v in tf.compat.v1.global_variables() if v.name.startswith("generator")]
    discriminator_vars = [v for v in tf.compat.v1.global_variables() if v.name.startswith("discriminator")]

    # Each network is updated only through its own variables.
    train_step_gen = tf.compat.v1.train.AdamOptimizer(learning_rate).minimize(loss_generator, var_list=generator_vars)
    train_step_disc = tf.compat.v1.train.AdamOptimizer(learning_rate).minimize(loss_discriminator, var_list=discriminator_vars)

    # Only generator weights are checkpointed (up to 100 checkpoints kept).
    saver = tf.compat.v1.train.Saver(var_list=generator_vars, max_to_keep=100)

    print('Initializing variables')
    sess.run(tf.compat.v1.global_variables_initializer())
 