# Nearly identical to the train.py script; used to experiment with new approaches and hyperparameters.


In [None]:
import sys
import datetime
import tensorflow as tf
import tensorflow.keras as keras
sys.path.append('..')
from scripts.inpaint_ops import d_wasserstein_loss, g_wasserstein_loss, min_max_tf_scaler2, mask_2_crop
from scripts.data_processing_utils import DataLoader
from models.networks.generator import dem_fill_net
from models.networks.global_critic import build_wgan_discriminator_global
from models.networks.local_critic import build_wgan_discriminator_local
from models.inpaintModel import WGAN_GP, GANMonitor
from keras.optimizers.schedules import ExponentialDecay

# Bind the loader instance to its own name instead of rebinding the imported
# `DataLoader` class to an instance, which would hide the class itself.
data_loader = DataLoader()

# Dataset version: selects the sub-directory of the DataSet root that is used.
version = 'V1.7'

dataset_root = f'/home/robin/Nextcloud_sn/Masterarbeit/DataSet/{version}'
dem = f'{dataset_root}/DEMs'
inner_outer = f'{dataset_root}/Inner-Outer Mask'
intersection = f'{dataset_root}/Intersection Mask'
intersection_small = f'{dataset_root}/Intersection Mask Small'

# Collect the file list of every input directory.
dem_list = data_loader.populate_list(dem)
inner_outer_list = data_loader.populate_list(inner_outer)
intersection_list = data_loader.populate_list(intersection)
intersection_small_list = data_loader.populate_list(intersection_small)

# Shuffle all four lists in a single call.
# NOTE(review): this presumably applies one shared permutation so DEM i stays
# paired with mask i after the datasets are zipped later — confirm in
# DataLoader.shuffle_lists.
dem_list, inner_outer_list, intersection_list, intersection_small_list = data_loader.shuffle_lists(
    [dem_list, inner_outer_list, intersection_list, intersection_small_list]
)

batch_size = 16
large_dim = 256  # side length of DEMs / full-size masks; generator and global critic size
small_dim = 64   # side length of the small intersection masks; local critic size


def _batched_dataset(generator_fn, args, dim, batch_size):
    """Wrap ``generator_fn`` in a batched tf.data pipeline.

    Args:
        generator_fn: callable passed to ``tf.data.Dataset.from_generator``;
            tf.data calls it with ``args`` converted to NumPy values.
        args: list of positional arguments for ``generator_fn``.
        dim: spatial side length of each yielded ``(dim, dim, 1)`` float32 element.
        batch_size: number of elements per batch.

    Returns:
        A ``tf.data.Dataset`` of float32 batches.
    """
    return tf.data.Dataset.from_generator(
        generator_fn,
        args=args,
        output_signature=tf.TensorSpec(shape=(dim, dim, 1), dtype=tf.float32),
    ).batch(batch_size)


dem_tensor = _batched_dataset(data_loader.load_dem_cv2, [dem_list], large_dim, batch_size)
train_inner_outer = _batched_dataset(data_loader.load_mask, [inner_outer_list], large_dim, batch_size)
train_intersection = _batched_dataset(data_loader.load_mask, [intersection_list], large_dim, batch_size)
# The small masks pass an explicit target size to load_mask.
train_intersection_small = _batched_dataset(
    data_loader.load_mask, [intersection_small_list, (small_dim, small_dim)], small_dim, batch_size
)

# Normalise the DEMs with min_max_tf_scaler2. The map runs after .batch(),
# so the scaler receives whole batches.
train_norm_dems = dem_tensor.map(min_max_tf_scaler2, num_parallel_calls=tf.data.AUTOTUNE)

# Build the generator and the two critics (global: full-size input, local: small input).
generator = dem_fill_net(image_size=large_dim)
global_discriminator = build_wgan_discriminator_global(image_size=large_dim)
local_discriminator = build_wgan_discriminator_local(image_size=small_dim)

wgan = WGAN_GP(generator=generator, global_critic=global_discriminator, local_critic=local_discriminator)

In [None]:
# set the prefetch buffer size
train_dem = train_norm_dems.cache().prefetch(buffer_size = tf.data.AUTOTUNE)
train_inner_outer = train_inner_outer.cache().prefetch(buffer_size = tf.data.AUTOTUNE)
train_intersection = train_intersection.cache().prefetch(buffer_size = tf.data.AUTOTUNE)
train_intersection_small = train_intersection_small.cache().prefetch(buffer_size = tf.data.AUTOTUNE)

In [None]:
# zip the datasets
dataset = tf.data.Dataset.zip((train_dem, train_inner_outer, train_intersection,
                               train_intersection_small))

In [None]:
# Define the callbacks:
log_dir = '/home/robin/Nextcloud_sn/Masterarbeit/Results/logs'


gan_monitor = GANMonitor(log_dir = log_dir,
                        save_gen_path = '/home/robin/Nextcloud_sn/Masterarbeit/Results/networks/generator',
                        data = dataset)

In [None]:
NUM_EPOCHS = 10
LR = 0.0001

# Learning rate LR * 0.96 ** floor(step / 100000) (staircase decay).
# The schedule is built from `tensorflow.keras` (imported as `keras`) so it comes
# from the same Keras implementation as the optimizers below. The standalone
# `keras` package can be a different Keras version than TF's bundled one, and
# its schedule objects may not be accepted by tf.keras optimizers.
LR_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=LR, decay_steps=100000, decay_rate=0.96, staircase=True
)


def _make_adam():
    """Return a new Adam optimizer on the shared LR schedule (beta_1=0.5, beta_2=0.999)."""
    return keras.optimizers.Adam(learning_rate=LR_schedule, beta_1=0.5, beta_2=0.999)


# One optimizer per network, so each Adam instance keeps its own moment estimates.
d_optimizer_local = _make_adam()
d_optimizer_global = _make_adam()
g_optimizer = _make_adam()

# Wasserstein losses for the critics and the generator.
d_loss_fn = d_wasserstein_loss
g_loss_fn = g_wasserstein_loss

wgan.compile(
    d_optimizer_local=d_optimizer_local,
    d_optimizer_global=d_optimizer_global,
    g_optimizer=g_optimizer,
    d_loss_fn=d_loss_fn,
    g_loss_fn=g_loss_fn,
)



In [None]:
# Train on the zipped dataset for NUM_EPOCHS epochs, with gan_monitor registered as a callback.
training_callbacks = [gan_monitor]
wgan.fit(
    dataset,
    epochs = NUM_EPOCHS,
    callbacks = training_callbacks,
)