In [None]:
from matplotlib import pyplot as plt
from keras.models import Model
from tensorflow.image import ssim
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
import numpy as np
from IPython import display
import tensorflow as tf
import os
import pathlib
import time
import datetime
import glob
import gc

# Root folder of the pre-processed dataset; it is expected to contain
# 'train/' and 'test/' sub-folders with PNG files.
DATASET_PATH = 'task_1/dataset/processed_mirror_hk_dataset_320x256_ac_4/'

BATCH_SIZE = 4

IMG_WIDTH = 320
IMG_HEIGHT = 256

# --------------------------------------------------------------------------------------------------------------- #

# Data distribution
#
# 209 subjects
# 150 training   72%
# 30 validation  14%
# 29 testing     14%
#
# The files on disk are only split into train/ and test/; the validation
# set is carved out of test/ below, based on the subject id in the file name.

# Files whose subject id is <= this value go to validation, the rest stay in test.
VALIDATION_MAX_SUBJECT = 180

training_files = glob.glob(DATASET_PATH + 'train/*.png')
test_files = glob.glob(DATASET_PATH + 'test/*.png')
test_files.sort()
trainingset_size = len(training_files)

validation_files = []
new_test_files = []

for file in test_files:
    # The subject id is the three characters located 11..9 positions before
    # the end of the path. Slicing relative to each path's own end works for
    # any number of files (including zero or one) and does not assume that
    # every path has the same length as one particular entry.
    sub = file[-11:-8]
    if int(sub) <= VALIDATION_MAX_SUBJECT:
        validation_files.append(file)
    else:
        new_test_files.append(file)

validationset_size = len(validation_files)

test_files = new_test_files
testset_size = len(test_files)

# From here on DATASET_PATH is a pathlib.Path (used with the '/' operator later).
DATASET_PATH = pathlib.Path(DATASET_PATH)

# --------------------------------------------------------------------------------------------------------------- #

# Training schedule. VALIDATION_START_EPOCHS is the first "epoch" from which
# the training loop starts running validation.
EPOCHS = 430
VALIDATION_START_EPOCHS = 300

# Total number of training iterations handed to fit()
NSTEPS = EPOCHS * trainingset_size

# Dataset is made of RGB images
OUTPUT_CHANNELS = 3

# Shuffle buffer spans the whole training set
BUFFER_SIZE = trainingset_size

# TensorBoard log directory, tagged with the run's step count and batch size
LOG_DIR = f"logs/_nsteps_{NSTEPS}_bs_{BATCH_SIZE}/"

In [None]:
def load(image_file):
    """Read a PNG that holds two pictures side by side and split it in half.

    Args:
        image_file: path of the PNG file to read.

    Returns:
        (input_image, real_image): the left and right halves of the decoded
        picture (cut at half the width), both cast to float32. Pixel values
        are not rescaled here.
    """
    decoded = tf.io.decode_png(tf.io.read_file(image_file))

    # Column at which the combined picture is cut in two
    half_width = tf.shape(decoded)[1] // 2

    left_half = tf.cast(decoded[:, :half_width, :], tf.float32)
    right_half = tf.cast(decoded[:, half_width:, :], tf.float32)

    return left_half, right_half

In [None]:
def resize(input_image, real_image, height, width):
    """Resize both images to (height, width) with nearest-neighbour sampling.

    Args:
        input_image: first image tensor.
        real_image: second image tensor.
        height: target height in pixels.
        width: target width in pixels.

    Returns:
        The resized (input_image, real_image) pair.
    """
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR

    resized_input = tf.image.resize(input_image, target_size, method=nearest)
    resized_real = tf.image.resize(real_image, target_size, method=nearest)

    return resized_input, resized_real

In [None]:
# Visual sanity check: load one training file and show its two halves.
# `inp` is reused further down to probe the untrained generator.
inp, re = load(training_files[1])
# Optional resize; note that resize() expects (height, width):
# inp, re = resize(inp, re, IMG_HEIGHT, IMG_WIDTH)

plt.figure(figsize=(6, 6))

# load() returns raw float pixel values, so divide by 255 for imshow
display_list = [(inp / 255.0), (re / 255.0)]
title = ['Input Image', 'Ground Truth']

for i, (panel, caption) in enumerate(zip(display_list, title)):
    plt.subplot(1, 2, i + 1)
    plt.title(caption)
    plt.imshow(panel)
    plt.axis('off')

plt.show()

In [None]:
def random_crop(input_image, real_image):
    """Cut the same random IMG_HEIGHT x IMG_WIDTH window out of both images.

    The two images are stacked so that a single random crop is applied to
    both, keeping them spatially aligned. Three channels are assumed.

    Returns:
        The cropped (input_image, real_image) pair.
    """
    pair = tf.stack([input_image, real_image], axis=0)
    crop_shape = [2, IMG_HEIGHT, IMG_WIDTH, 3]
    cropped_pair = tf.image.random_crop(pair, size=crop_shape)

    return cropped_pair[0], cropped_pair[1]

In [None]:
# Rescale pixel values: 0 -> -1, 127.5 -> 0, 255 -> 1
def normalize(input_image, real_image):
    """Apply `x / 127.5 - 1` to both images.

    For pixel values in [0, 255] the result lies in [-1, 1].

    Returns:
        The rescaled (input_image, real_image) pair.
    """
    def _rescale(image):
        return (image / 127.5) - 1

    return _rescale(input_image), _rescale(real_image)

In [None]:
@tf.function()
def random_jitter(input_image, real_image):
    """Augment an image pair with resize, random crop and random mirroring.

    Both images are enlarged by `jitter_offset` pixels in height and width,
    randomly cropped back to IMG_HEIGHT x IMG_WIDTH, and flipped left-right
    together when a uniform random draw exceeds 0.5.

    Returns:
        The augmented (input_image, real_image) pair.
    """
    # Enlarge to (IMG_HEIGHT + 30) x (IMG_WIDTH + 30)
    jitter_offset = 30
    enlarged_height = IMG_HEIGHT + jitter_offset
    enlarged_width = IMG_WIDTH + jitter_offset
    input_image, real_image = resize(input_image, real_image, enlarged_height, enlarged_width)

    # Random crop back to IMG_HEIGHT x IMG_WIDTH
    input_image, real_image = random_crop(input_image, real_image)

    # Random mirroring, applied to both images or to neither
    mirror = tf.random.uniform(()) > 0.5
    if mirror:
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    return input_image, real_image

In [None]:
# -------- PRINT JITTERED IMAGES -------- #

# plt.figure(figsize=(6, 6))
# for i in range(4):
#     rj_inp, rj_re = random_jitter(inp, re)
#     plt.subplot(2, 2, i + 1)
#     plt.imshow(rj_inp / 255.0)
#     plt.axis('off')
# plt.show()

In [None]:
def load_image_train(image_file):
    """Training loader: read the file, apply random jitter, then normalize.

    Returns:
        The (input_image, real_image) pair produced by normalize().
    """
    raw_pair = load(image_file)
    augmented_pair = random_jitter(*raw_pair)

    return normalize(*augmented_pair)

In [None]:
def load_image_validation(image_file):
    """Validation loader: read the file and normalize, without augmentation.

    Returns:
        The (input_image, real_image) pair produced by normalize().
    """
    return normalize(*load(image_file))

In [None]:
def load_image_test(image_file):
    """Test loader: read the file and normalize, without augmentation.

    Returns:
        The (input_image, real_image) pair produced by normalize().
    """
    left_half, right_half = load(image_file)

    return normalize(left_half, right_half)

In [None]:
# Training pipeline: list files -> load + augment + normalize -> shuffle -> batch
train_dataset = (
    tf.data.Dataset.list_files(str(DATASET_PATH / 'train/*.png'))
    .map(load_image_train, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

In [None]:
# Validation pipeline: fixed file order, no augmentation, same batch size as training
validation_dataset = (
    tf.data.Dataset.from_tensor_slices(validation_files)
    .map(load_image_validation)
    .batch(BATCH_SIZE)
)

In [None]:
# Test pipeline: fixed file order, no augmentation, one image pair per batch
test_dataset = (
    tf.data.Dataset.from_tensor_slices(test_files)
    .map(load_image_test, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(1)
)

In [None]:
# Report how many batches each pipeline yields. The elements are counted while
# streaming instead of building `list(dataset)`, which would keep every decoded
# image batch in memory at once just to take its length.
# Note: these are batch counts (the datasets are already batched), not image counts.
print("trainset size : ", sum(1 for _ in train_dataset))
print("validation size : ", sum(1 for _ in validation_dataset))
print("testset size : ", sum(1 for _ in test_dataset))

In [None]:
def downsample(filters, size, apply_batchnorm=True):
    """Build an encoder block: stride-2 Conv2D -> (BatchNorm) -> LeakyReLU.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        apply_batchnorm: insert a BatchNormalization layer after the convolution.

    Returns:
        A tf.keras.Sequential holding the layers above.
    """
    conv = tf.keras.layers.Conv2D(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )

    block_layers = [conv]
    if apply_batchnorm:
        block_layers.append(tf.keras.layers.BatchNormalization())
    block_layers.append(tf.keras.layers.LeakyReLU())

    return tf.keras.Sequential(block_layers)

In [None]:
def upsample(filters, size, apply_dropout=False):
    """Build a decoder block: stride-2 Conv2DTranspose -> BatchNorm -> (Dropout) -> ReLU.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        apply_dropout: insert a Dropout(0.5) layer after batch normalization.

    Returns:
        A tf.keras.Sequential holding the layers above.
    """
    deconv = tf.keras.layers.Conv2DTranspose(
        filters,
        size,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
    )

    block_layers = [deconv, tf.keras.layers.BatchNormalization()]
    if apply_dropout:
        block_layers.append(tf.keras.layers.Dropout(0.5))
    block_layers.append(tf.keras.layers.ReLU())

    return tf.keras.Sequential(block_layers)

In [None]:
def Generator():
    """Build the encoder-decoder generator with skip connections.

    Six stride-2 downsample blocks are followed by five stride-2 upsample
    blocks and a final stride-2 transposed convolution with tanh activation,
    so the output has the same height and width as the input. Every encoder
    output except the deepest one is concatenated to the decoder output of
    matching depth.

    Returns:
        A tf.keras.Model mapping [IMG_HEIGHT, IMG_WIDTH, OUTPUT_CHANNELS]
        inputs to an OUTPUT_CHANNELS-channel image.
    """
    inputs = tf.keras.layers.Input(shape=[IMG_HEIGHT, IMG_WIDTH, OUTPUT_CHANNELS])

    # Encoder: only the first block skips batch normalization
    encoder_blocks = [downsample(64, 4, apply_batchnorm=False)]
    encoder_blocks += [downsample(n_filters, 4) for n_filters in (128, 256, 512, 512, 512)]

    # Decoder: only the first block uses dropout
    decoder_blocks = [upsample(512, 4, apply_dropout=True)]
    decoder_blocks += [upsample(n_filters, 4) for n_filters in (512, 256, 128, 64)]

    output_layer = tf.keras.layers.Conv2DTranspose(
        OUTPUT_CHANNELS,
        4,
        strides=2,
        padding='same',
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        activation='tanh',
    )

    # Downsampling through the model, remembering every intermediate output
    features = inputs
    encoder_outputs = []
    for encoder_block in encoder_blocks:
        features = encoder_block(features)
        encoder_outputs.append(features)

    # Deepest output is the bottleneck itself; the others become skips,
    # visited from deepest to shallowest.
    skip_outputs = encoder_outputs[-2::-1]

    # Upsampling and establishing the skip connections
    for decoder_block, skip in zip(decoder_blocks, skip_outputs):
        features = decoder_block(features)
        features = tf.keras.layers.Concatenate()([features, skip])

    outputs = output_layer(features)

    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# Instantiate the generator once; the training, validation and plotting code
# below refers to this module-level `generator`.
generator = Generator()
# Last expression of the cell: renders the model graph with layer shapes.
tf.keras.utils.plot_model(generator, show_shapes=True, dpi=64)

In [None]:
# -------- PRINT GENERATED IMAGES -------- #

# Sanity check of the untrained generator on the sample loaded earlier.
# `inp` holds raw pixel values, whereas the data pipelines feed the model
# images passed through normalize(), so apply the same scaling here.
normalized_inp = normalize(inp, re)[0]
gen_output = generator(normalized_inp[tf.newaxis, ...], training=False)
# The generator ends in tanh, so map [-1, 1] back to [0, 1] for imshow
# (same conversion as in generate_images).
plt.imshow(gen_output[0, ...] * 0.5 + 0.5)

In [None]:
# Binary cross-entropy computed on raw logits.
# NOTE(review): the same object is created again in the next cell, and no
# code visible in this notebook reads `loss_object` — confirm whether it is still needed.
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
# -------------------------------- VGG LOSS -------------------------------- #

# Binary cross-entropy on raw logits (not read by any loss defined in this cell)
loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# ImageNet-pretrained VGG16 without its classifier head, sized for our images
vgg = VGG16(
    include_top=False,
    weights='imagenet',
    input_shape=(IMG_HEIGHT, IMG_WIDTH, OUTPUT_CHANNELS),
)

# Feature extractor truncated at layer index 9
# (presumably block3_conv3 / relu3_3, as the name suggests — confirm with summary()).
# Index 13 was noted as an alternative cut point.
vgg_relu3_3 = Model(inputs=vgg.input, outputs=vgg.layers[9].output)
#vgg_relu3_3.summary()


def normalize_tensor(in_feat):
    """Divide a feature tensor by its L2 norm along the last axis.

    A small constant (1e-10) is added to the norm to avoid division by zero.
    """
    squared_sum = tf.keras.backend.sum(in_feat**2, axis=-1, keepdims=True)
    norm_factor = tf.math.sqrt(squared_sum)

    return in_feat / (norm_factor + 1e-10)


def vgg_loss_3(y_true, y_pred):
    """Perceptual loss: mean squared distance between normalized VGG features.

    Both images are multiplied by 255, passed through the VGG16
    `preprocess_input` and the truncated `vgg_relu3_3` network; the resulting
    features are channel-normalized before comparison.

    NOTE(review): the `* 255` scaling looks like it expects inputs in [0, 1],
    while the data pipelines in this notebook normalize to [-1, 1] — confirm
    the intended range before using this loss.
    """
    def _normalized_features(batch):
        return normalize_tensor(vgg_relu3_3(preprocess_input(batch * 255)))

    feature_gap = _normalized_features(y_true) - _normalized_features(y_pred)

    return tf.math.reduce_mean(tf.math.square(feature_gap))


def l1_loss(y_true, y_pred):
    """Mean absolute error between target and prediction."""
    absolute_error = tf.abs(y_true - y_pred)
    return tf.reduce_mean(absolute_error)


def l2_loss(y_true, y_pred):
    """Mean squared error between target and prediction."""
    squared_error = tf.square(y_true - y_pred)
    return tf.reduce_mean(squared_error)


def ssim_loss(y_true, y_pred):
    """Structural dissimilarity: one minus the mean SSIM over the batch.

    NOTE(review): `max_val=1.0` assumes a dynamic range of 1, while the data
    pipelines in this notebook normalize images to [-1, 1] — confirm before use.
    """
    mean_ssim = tf.reduce_mean(ssim(y_true, y_pred, max_val=1.0))
    return (1 - mean_ssim)

In [None]:
def generator_loss(gen_output, target):
    """Training objective of the generator: plain L1 distance to the target.

    Note the argument order: the prediction comes first here, while
    l1_loss() takes the target first.
    """
    # ------- l1 loss ------- #
    reconstruction_error = l1_loss(target, gen_output)
    return reconstruction_error

In [None]:
def validation_step(generator, validation_ds):
    """Average generator loss over every batch of the validation dataset.

    The generator runs with training=False. Each batch contributes equally
    to the mean, regardless of how many images it holds.

    Returns:
        The mean of the per-batch losses (numpy scalar).
    """
    batch_losses = [
        generator_loss(generator(input_image, training=False), target)
        for input_image, target in validation_ds
    ]

    return np.mean(batch_losses)

In [None]:
def save_models(string, generator):
    """Save `generator` under the 'models/' folder.

    Args:
        string: file name (appended to 'models/').
        generator: object exposing a `save(path)` method.
    """
    target_path = 'models/' + string
    generator.save(target_path)

In [None]:
# Adam optimizer for the generator: learning rate 2e-4, beta_1 lowered to 0.5
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

In [None]:
def generate_images(model, test_input, tar):
    """Show input, ground truth and model prediction side by side.

    Only the first element of each batch is plotted. The model is called
    with training=True (unlike validation_step, which uses training=False).
    """
    prediction = model(test_input, training=True)
    plt.figure(figsize=(15, 15))

    panels = [
        ('Input Image', test_input[0]),
        ('Ground Truth', tar[0]),
        ('Predicted Image', prediction[0]),
    ]

    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(caption)

        # Getting the pixel values in the [0, 1] range to plot.
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')

    plt.show()

In [None]:
def save_images(model, test_input, tar, step):
    """Write input, ground truth and prediction side by side to a JPEG.

    Only the first element of each batch is plotted. The picture is saved in
    the current working directory as 'image_<step>.jpg'.

    Args:
        model: generator to run on `test_input` (called with training=True).
        test_input: batch of input images scaled to [-1, 1].
        tar: batch of target images scaled to [-1, 1].
        step: value embedded in the output file name.
    """
    prediction = model(test_input, training=True)
    fig = plt.figure(figsize=(15, 15))

    display_list = [test_input[0], tar[0], prediction[0]]
    title = ['Input Image', 'Ground Truth', 'Predicted Image']

    for i in range(3):
        plt.subplot(1, 3, i+1)
        plt.title(title[i])

        # Getting the pixel values in the [0, 1] range to plot.
        plt.imshow(display_list[i] * 0.5 + 0.5)
        plt.axis('off')

    filename = 'image_' + str(step) + '.jpg'
    fig.savefig(filename)
    # The figure is only needed for the file on disk. Without closing it,
    # every call made during training would leave another figure open.
    plt.close(fig)

In [None]:
# One TensorBoard log folder per run, named after the launch time
run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary_writer = tf.summary.create_file_writer(LOG_DIR + run_timestamp)

In [None]:
@tf.function
def train_step(input_image, target, step):
    """Run one optimization step of the generator and log its loss.

    Args:
        input_image: batch of generator inputs.
        target: batch of ground-truth images.
        step: current training step; the summary is written under
            `step // trainingset_size`.
    """
    with tf.GradientTape() as gen_tape:
        gen_output = generator(input_image, training=True)
        gen_l1_loss = generator_loss(gen_output, target)

    trainable = generator.trainable_variables
    generator_gradients = gen_tape.gradient(gen_l1_loss, trainable)
    generator_optimizer.apply_gradients(zip(generator_gradients, trainable))

    summary_step = step // trainingset_size
    with summary_writer.as_default():
        tf.summary.scalar('gen_loss', gen_l1_loss, step=summary_step)

In [None]:
def fit(train_ds, validation_ds, test_ds, steps):
    """Train the module-level `generator` for `steps` iterations.

    Every `trainingset_size` iterations the notebook output is refreshed with
    a prediction on the first test example. Once
    `trainingset_size * VALIDATION_START_EPOCHS` iterations have been done,
    the validation loss is computed at the end of each such period, and the
    generator (plus a preview image) is saved whenever it improves on the
    best value seen so far.

    NOTE(review): one iteration consumes one element of `train_ds`. If that
    dataset is batched (as `train_dataset` in this notebook is), a period of
    `trainingset_size` iterations covers more than one pass over the training
    files — confirm this is the intended meaning of "epoch".

    Args:
        train_ds: training dataset yielding (input, target) pairs.
        validation_ds: dataset passed to validation_step().
        test_ds: dataset whose first element is used for preview images.
        steps: total number of training iterations.
    """
    example_input, example_target = next(iter(test_ds.take(1)))
    start = time.time()
    min_val_error = float("inf")
    first_validation_step = trainingset_size * VALIDATION_START_EPOCHS

    for step, (input_image, target) in train_ds.repeat().take(steps).enumerate():
        # `step` is a scalar tensor. Keep a plain Python int for bookkeeping
        # and file names: str() of the tensor would produce
        # 'tf.Tensor(..., shape=(), dtype=int64)' instead of the number.
        step_index = int(step)
        completed_steps = step_index + 1

        if step_index % trainingset_size == 0:
            display.clear_output(wait=True)

            if step_index != 0:
                print(f'Time taken for an epoch: {time.time()-start:.2f} sec\n')

            start = time.time()

            generate_images(generator, example_input, example_target)
            print(f"Step: {step_index//trainingset_size}k")

        # Pass the tensor (not the int) so the tf.function is not retraced
        # for every new step value.
        train_step(input_image, target, step)

        if completed_steps % trainingset_size == 0 and completed_steps >= first_validation_step:
            val_error = validation_step(generator, validation_ds)
            if val_error < min_val_error:
                min_val_error = val_error

                # Files are written only on improvement, so the most recent
                # 'best_*' file is the best model so far.
                filename = 'best_' + str(step_index) + '.h5'
                save_images(generator, example_input, example_target, completed_steps)
                save_models(filename, generator)

        # Progress marker every 10 iterations
        if completed_steps % 10 == 0:
            print('.', end='', flush=True)

In [None]:
# Run the whole training schedule (NSTEPS iterations); test_dataset only
# supplies the example used for preview images.
fit(train_dataset, validation_dataset, test_dataset, steps= NSTEPS)
# Also keep the generator as it is after the final iteration, independently
# of the 'best_*' checkpoints written during training.
save_models('last_model_.h5', generator)