# Settings

In [None]:
# TF_KERAS=1 is presumably read by a downstream library to select tf.keras
# (confirm). It is set through os.environ instead of the IPython-only `%env`
# magic, so this cell is also valid plain Python. It runs before TensorFlow
# is imported below.
import os
os.environ['TF_KERAS'] = '1'
# Platform path separator, used to build paths throughout the notebook.
sep_local = os.path.sep
print(sep_local)

In [None]:
# Show the starting working directory (before the chdir that follows).
print(os.getcwd())

In [None]:
import sys
# Climb six directory levels from the notebook's folder. This is presumably
# the project root, so that `training`, `utils`, ... are importable (Linux layout).
# os.pardir is '..' and os.path.join uses the platform separator.
os.chdir(os.path.join(*([os.pardir] * 6)))
print(os.getcwd())

In [None]:
import tensorflow as tf
# Record which TensorFlow version this run used.
print(tf.__version__)

# Dataset loading

In [None]:
# Dataset identifier; also used as the prefix of `model_name`.
dataset_name = "CBSD68"

In [None]:
# Image folder relative to the current working directory, the percentage
# passed to create_image_lists as validation_pct, and the accepted extension.
# NOTE(review): '.CBSD68' is a dot-prefixed (hidden) folder name; confirm
# that this is intentional.
images_dir = sep_local.join(('.', 'data', '.CBSD68'))
validation_percentage = 10
valid_format = 'png'

In [None]:
# Display the resolved image directory as the cell output.
images_dir

In [None]:
from training.generators.file_image_generator import create_image_lists, get_generators

In [None]:
# Index the images under images_dir, presumably into training/validation
# lists.
# `valid_imgae_formats` (sic) is the keyword as spelled by the project API;
# do not "fix" it here.
imgs_list = create_image_lists(
    valid_imgae_formats=valid_format,
    validation_pct=validation_percentage,
    image_dir=images_dir,
)

In [None]:
from PIL import Image
# Reference image <images_dir>/original/0066.png. It is later downscaled and
# passed to the reconstruction-tracing callback.
trace_image = Image.open(
    sep_local.join([images_dir, 'original', '{:04d}.png'.format(66)])
)

In [None]:
# Preview the reference image.
trace_image

In [None]:
# Full-resolution frame size and the factor used to shrink the training targets.
image_size_original = (481, 321, 3)
scale = 2
# Floor-divide each spatial dimension by `scale`; keep the channel count.
image_size = (
    tuple(dim // scale for dim in image_size_original[:-1])
    + image_size_original[-1:]
)
batch_size = 16
latents_dim = 150
intermediate_dim = 50

In [None]:
# Display the target (output) image size.
image_size

In [None]:
# Build the train/test batch iterators at the target resolution.
# class_mode=None: presumably image batches only, with no labels (the images
# are their own targets); confirm against get_generators.
training_generator, testing_generator = get_generators(
    batch_size=batch_size,
    image_size=image_size,
    image_dir=images_dir,
    images_list=imgs_list,
    class_mode=None,
)

## The model input is the target image downscaled to half its size

In [None]:
# Model input: each spatial dimension of image_size floor-halved, same channels.
inputs_shape = tuple(dim // 2 for dim in image_size[:-1]) + (image_size[-1],)

In [None]:
# Compare the target and input shapes side by side.
image_size, inputs_shape

In [None]:
def shrink_fn(image):
    """Resize ``image`` to the model input resolution ``inputs_shape[:-1]``."""
    return tf.image.resize(image, inputs_shape[:-1])


def enlarge_fn(image):
    """Resize ``image`` to the target resolution ``image_size[:-1]``."""
    return tf.image.resize(image, image_size[:-1])

In [None]:
def generator_reshaper(generator):
    """Yield ``(shrink_fn(batch), batch)`` pairs from a batch iterator.

    Every batch pulled from ``generator`` is paired with a copy resized by
    ``shrink_fn`` (to ``inputs_shape``). This gives (model input, target)
    pairs.

    Args:
        generator: an iterator of image batches (anything ``next()`` accepts).

    Yields:
        tuple: ``(shrink_fn(batch), batch)``.

    The generator ends cleanly when ``generator`` is exhausted. Letting the
    ``StopIteration`` from ``next()`` escape would be turned into a
    ``RuntimeError`` (PEP 479).
    """
    while True:
        try:
            batch = next(generator)
        except StopIteration:
            return
        yield shrink_fn(batch), batch


In [None]:
def _reshaped_dataset(batch_iterator):
    """Wrap ``generator_reshaper(batch_iterator)`` as a ``tf.data.Dataset``.

    Elements are float32 ``(input, target)`` pairs with shapes
    ``(batch_size,) + inputs_shape`` and ``(batch_size,) + image_size``.
    One helper serves both splits, so their signatures cannot drift apart.

    NOTE(review): the batch dimension is fixed. Every batch from the
    generator must hold exactly ``batch_size`` images, so a short final
    batch would fail the shape check. Confirm get_generators never yields one.
    """
    return tf.data.Dataset.from_generator(
        lambda: generator_reshaper(batch_iterator),
        output_types=(tf.float32, tf.float32),
        output_shapes=(
            tf.TensorShape((batch_size,) + inputs_shape),
            tf.TensorShape((batch_size,) + image_size),
        ),
    )


train_ds = _reshaped_dataset(training_generator)
test_ds = _reshaped_dataset(testing_generator)



In [None]:
# Peek at a single training batch and record the largest value in its input
# tensor. The value stays 1.0 if the dataset yields nothing.
_instance_scale = 1.0
for data in train_ds.take(1):
    _instance_scale = float(data[0].numpy().max())

In [None]:
# Display the observed maximum input value.
_instance_scale

In [None]:
import numpy as np
from collections.abc import Iterable  # kept: other cells may rely on it
# Total number of values in one output image (product of image_size dims).
# np.prod accepts both a shape tuple and a scalar size, so the name is always
# defined. An isinstance(Iterable) guard would leave it unset for scalars.
_outputs_shape = np.prod(image_size)
_outputs_shape

# Model layer definitions

In [None]:
# Convolution kernel size and the stride of the down-sampling convolutions.
kernel_size, stride = 3, 2

In [None]:
# Shape of the decoder's first feature map: image_size floor-divided by
# stride**2 in each spatial dimension, with intermediate_dim channels.
c = tuple(dim // (stride * stride) for dim in image_size[:-1]) + (intermediate_dim,)
c

In [None]:
# Encoder layers (the 'inference' network): flatten the input image and
# project it to a latents_dim vector.
enc_lays = [

    tf.keras.layers.Flatten(),
    # No activation: the latent projection is linear.
    tf.keras.layers.Dense(latents_dim)
]

# Decoder layers (the 'generative' network). Dense + Reshape produce a
# feature map of shape `c`. Each Conv2D(strides=stride, padding="SAME")
# shrinks the spatial size by `stride` (rounding up), and each
# UpSampling2D((2, 2)) doubles it. Net per stage: x2, x2, x1, so x4 overall,
# which matches image_size when the dimensions divide evenly.
dec_lays = [
    # np.prod, not np.product: the np.product alias was deprecated in
    # NumPy 1.25 and removed in NumPy 2.0.
    tf.keras.layers.Dense(units=np.prod(c), activation=tf.nn.relu),
    tf.keras.layers.Reshape(target_shape=c),

    # Stage 1: /2 then x4.
    tf.keras.layers.Conv2D(filters=intermediate_dim, kernel_size=kernel_size, strides=(stride, stride), padding="SAME", activation='relu'),
    tf.keras.layers.UpSampling2D(size=(2, 2)),
    tf.keras.layers.UpSampling2D(size=(2, 2)),

    # Stage 2: /2 then x4.
    tf.keras.layers.Conv2D(filters=intermediate_dim, kernel_size=kernel_size, strides=(stride, stride), padding="SAME", activation='relu'),
    tf.keras.layers.UpSampling2D(size=(2, 2)),
    tf.keras.layers.UpSampling2D(size=(2, 2)),

    # Stage 3: /2 then x2.
    tf.keras.layers.Conv2D(filters=intermediate_dim, kernel_size=kernel_size, strides=(stride, stride), padding="SAME", activation='relu'),
    tf.keras.layers.UpSampling2D(size=(2, 2)),

    # Project to 3 channels without changing the spatial size (no activation).
    tf.keras.layers.Conv2DTranspose(filters=3, kernel_size=kernel_size, strides=(1, 1), padding="SAME")
]

# Model definition

In [None]:
model_name = dataset_name + '_Conv_Progressive_Upsampling_psnr_ssim_total_variation'
# Experiment outputs go to <cwd>/experiments/<model_name> (Linux layout).
# A Windows variant instead used a path five parent directories up:
# ../../../../../experiments/<model_name>.
experiments_dir = sep_local.join([os.getcwd(), 'experiments', model_name])

In [None]:
# Specs for the AE's two sub-networks:
#   'inference'  : inputs_shape -> latents_dim vector, built from enc_lays
#   'generative' : latents_dim vector -> image_size, built from dec_lays
variables_params = [
    dict(
        name='inference',
        inputs_shape=inputs_shape,
        outputs_shape=latents_dim,
        layers=enc_lays,
    ),
    dict(
        name='generative',
        inputs_shape=latents_dim,
        outputs_shape=image_size,
        layers=dec_lays,
    ),
]

In [None]:
from os.path import abspath
from utils.data_and_files.file_utils import create_if_not_exist
# Folder for saved model variables (created if missing), then report it
# together with the experiment folder and the working directory.
_restore = os.path.join(experiments_dir, 'var_save_dir')
create_if_not_exist(_restore)
print("Restore_dir", abspath(_restore))
# `absolute` ends up holding the experiment folder's absolute path.
absolute = abspath(experiments_dir)
print("Recording_dir", absolute)
print("Current working dir", os.getcwd())

In [None]:
from training.autoencoding_basic.transformative.AE import autoencoder as AE

In [None]:
# Build the autoencoder from the two sub-network specs.
# To resume from the checkpoints in `_restore`, pass filepath=_restore.
ae = AE(
    name=model_name,
    variables_params=variables_params,
    latents_dim=latents_dim,
    batch_size=batch_size,
    episode_len=1,
    filepath=None,
)

In [None]:
# Re-check the target and input shapes after building the model.
image_size, inputs_shape

In [None]:
from evaluation.quantitive_metrics.structural_similarity import prepare_ssim_multiscale
from evaluation.quantitive_metrics.peak_signal_to_noise_ratio import prepare_psnr
from evaluation.quantitive_metrics.total_variation import prepare_total_variation
from statistical.losses_utilities import similarity_to_distance
from statistical.ae_losses import expected_loglikelihood_with_lower_bound as ellwlb

In [None]:
# Batch-level output shape ([batch_size] + model output shape), used to
# build the metric functions.
outputs_flat_shape = [ae.batch_size]+ae.get_outputs_shape()
# PSNR and SSIM are similarities, so they are converted to distances before
# being summed into the loss.
psnr = similarity_to_distance(prepare_psnr(outputs_flat_shape))
ssim = similarity_to_distance(prepare_ssim_multiscale(outputs_flat_shape))
# Total-variation term: built with prepare_total_variation (not a second
# prepare_psnr), so the loss really contains the TV term that model_name
# ("..._psnr_ssim_total_variation") advertises.
# NOTE(review): assumes prepare_total_variation takes the same shape argument
# as the other prepare_* helpers and returns a (x_true, x_logits) callable
# that is already a penalty (so no similarity_to_distance). Confirm.
total_variation = prepare_total_variation(outputs_flat_shape)

In [None]:
# Training loss: the PSNR-distance, SSIM-distance and total-variation terms
# summed for one (target, prediction) pair. x_logits is presumably the
# decoder output; confirm against the AE.
loss= lambda x_true, x_logits: psnr(x_true, x_logits)+ssim(x_true, x_logits)+total_variation(x_true, x_logits)

In [None]:

# Attach the loss to the output keyed 'x_logits'. The key presumably matches
# the AE's output name; confirm.
ae.compile(loss={'x_logits': loss})

# Callbacks

In [None]:
# Suppress matplotlib log output below WARNING (added to quiet it on Linux).
import logging
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel('WARNING')


from training.callbacks.trace_image_reconstruction import trace_reconstruction
from training.callbacks.save_model import ModelSaver

In [None]:
# Stop training once the training loss has not improved for 12 epochs.
# min_delta=1e-12 treats practically any decrease as an improvement, and the
# final (not best) weights are kept.
es = tf.keras.callbacks.EarlyStopping(
    monitor='loss',
    patience=12,
    min_delta=1e-12,
    restore_best_weights=False,
    verbose=1,
)

In [None]:
# Save the model into `_restore`. save_freq=5 presumably means every 5 epochs;
# ModelSaver is project code, so confirm its unit.
ms = ModelSaver(filepath=_restore,save_freq=5)

In [None]:
# CSV log written to <experiments_dir>/csv_dir/<model_name>.csv, appended to
# across runs (append=True).
# Note: `csv_dir` first names the folder, then is rebound to the CSV file path.
csv_dir = os.path.join(experiments_dir, 'csv_dir')
create_if_not_exist(csv_dir)
csv_dir = os.path.join(csv_dir, model_name+'.csv')
csv_log = tf.keras.callbacks.CSVLogger(csv_dir, append=True)
absolute = abspath(csv_dir)
print("Csv_dir",absolute)

In [None]:
# Folder for images produced by the reconstruction-tracing callback.
# The "reconstuction" spelling is part of the on-disk folder name. Renaming it
# would separate new outputs from earlier runs, so it is kept as-is.
image_reconstuction_dir = os.path.join(experiments_dir, 'image_reconstuction_dir')
create_if_not_exist(image_reconstuction_dir)
absolute = abspath(image_reconstuction_dir)
print("image_reconstuction_dir",absolute)

In [None]:
# Reference input for the tracing callback: the trace image resized to the
# model input resolution, then divided by 255.0 (presumably mapping 8-bit
# pixels into [0, 1]).
# NOTE(review): assumes the training batches use the same [0, 1] range
# (compare _instance_scale) and that the PIL array's (rows, cols) order
# matches image_size. Confirm.
image = shrink_fn(np.asarray(trace_image)).numpy()/255.0
img_reconst = trace_reconstruction(filepath=image_reconstuction_dir, image=image, gen_freq=5)

# Model Training

In [None]:
# Train until EarlyStopping fires (epochs=1e6 is only an upper bound), with
# 1e4 training and 1e4 validation steps per epoch.
# NOTE(review): if these arguments are forwarded to tf.keras Model.fit,
# `workers` and `use_multiprocessing` only apply to generator/Sequence input
# and are ignored for tf.data datasets. `input_kw` is specific to the
# project's AE.fit.
ae.fit(
    x=train_ds,
    input_kw=None,
    steps_per_epoch=int(1e4),
    epochs=int(1e6), 
    verbose=2,
    callbacks=[ es, ms, csv_log, img_reconst],
    workers=-1,
    use_multiprocessing=True,
    validation_data=test_ds,
    validation_steps=int(1e4)
)