# Settings

In [None]:
%env TF_KERAS = 1
import os
import sys

# Platform path separator, reused by later cells when building paths.
sep_local = os.path.sep

# Resolve the project root (five levels above the starting directory) to an
# absolute path BEFORE changing directory. A relative entry in sys.path is
# interpreted against the current working directory at import time, so
# appending '../../../../..' and then chdir-ing would leave sys.path pointing
# ten levels up instead of at the project root.
_project_root = os.path.abspath(os.path.join(*(['..'] * 5)))
if _project_root not in sys.path:
    sys.path.append(_project_root)  # makes project packages importable
os.chdir(_project_root)  # later cells build paths from os.getcwd()
print(sep_local)
print(os.getcwd())

In [None]:
import tensorflow as tf
# Record the installed TensorFlow version in the notebook output.
print(tf.__version__)

# Dataset loading

In [None]:
# Dataset identifier; used further down as the prefix of `model_name`.
dataset_name='atari_qbert'

In [None]:
# Directory handed to the image generators as `image_dir`. This is a
# machine-specific absolute path -- adjust it for your environment.
# The commented-out alternatives are leftovers from Windows / Ms. Pac-Man runs.
#images_dir = 'D:\\datasets\\Ms_Pacman\\Ms_Pacman\\Test\\'
images_dir = '/home/azeghost/datasets/.qbert/atari_v1/screens/qbert' #Linux
#images_dir = '..\..\..\\data\\atari_v1\\screens\\mspacman'
# Forwarded to create_image_lists as `validation_pct`
# (presumably the percentage of images held out for validation -- confirm).
validation_percentage = 25
# Forwarded to create_image_lists as the accepted image format.
valid_format = 'png'

In [None]:
from training.generators.file_image_generator import create_image_lists, get_generators

In [None]:
# Index the images under `images_dir`; the result is passed to
# get_generators as `images_list`.
# NOTE(review): the keyword `valid_imgae_formats` looks misspelled; it
# presumably mirrors the parameter name declared by create_image_lists --
# confirm against that function before renaming it here.
imgs_list = create_image_lists(
    image_dir=images_dir, 
    validation_pct=validation_percentage, 
    valid_imgae_formats=valid_format
)

In [None]:
# Integer divisor applied to both spatial dimensions of a frame.
scale = 1
# (160 // scale, 210 // scale, 3): two spatial dimensions plus 3 channels.
# NOTE(review): which of the two is width vs. height depends on what
# get_generators expects -- confirm.
image_size = tuple(extent // scale for extent in (160, 210)) + (3,)

# Batching / episode parameters forwarded to get_generators.
batch_size, EPIS_LEN, EPIS_SHIFT = 10, 10, 5

# The model consumes frames at exactly the generator's image size.
inputs_shape = image_size
# Latent size and the base width used for the hidden Dense layers.
latent_dim = intermediate_dim = 30

In [None]:
# Build the training and test generators from the indexed image lists.
# NOTE(review): with class_mode='episode_flat' each yielded batch presumably
# holds `batch_size` episodes of `EPIS_LEN` frames flattened along the first
# axis (the tf.data shapes declared below use batch_size * EPIS_LEN) --
# confirm against get_generators.
train_gen, test_gen = get_generators(
    images_list=imgs_list, 
    image_dir=images_dir, 
    image_size=image_size, 
    batch_size=batch_size, 
    class_mode='episode_flat', 
    episode_len=EPIS_LEN, 
    episode_shift=EPIS_SHIFT
)

In [None]:
import tensorflow as tf

In [None]:
# Both datasets yield (input, target) float32 pairs of one common static
# shape: (batch_size * EPIS_LEN,) followed by the image dimensions.
_episode_batch_shape = tf.TensorShape((batch_size * EPIS_LEN,) + image_size)
_dataset_spec = dict(
    output_types=(tf.float32, tf.float32),
    output_shapes=(_episode_batch_shape, _episode_batch_shape),
)

# The lambdas look the generators up by their global names at call time,
# exactly as before, so re-binding train_gen/test_gen is still picked up.
train_ds = tf.data.Dataset.from_generator(lambda: train_gen, **_dataset_spec)
test_ds = tf.data.Dataset.from_generator(lambda: test_gen, **_dataset_spec)

In [None]:
# Probe the first training batch and record the maximum value of its input
# component; falls back to 1.0 if the dataset yields nothing.
# Only one batch is inspected (note the `break`).
# NOTE: the next cell resets `_instance_scale` to 1.0, so this measurement
# is discarded when the cells are run in order.
_instance_scale=1.0
for data in train_ds:
    _instance_scale = float(data[0].numpy().max())
    break

In [None]:
# Fixed scale of 1.0; this overrides the value measured from the first
# training batch in the previous cell. `_instance_scale` is not referenced
# anywhere else in the visible part of this notebook.
_instance_scale = 1.0

In [None]:
import numpy as np
from collections.abc import Iterable

# Number of units of the decoder's last Dense layer: the total element count
# of one input sample.
if isinstance(inputs_shape, Iterable):
    # int(...) turns NumPy's scalar result into a built-in int before it is
    # handed to Dense(units=...).
    _outputs_shape = int(np.prod(inputs_shape))
else:
    # A scalar shape already is the element count. Previously this case left
    # `_outputs_shape` undefined, which surfaced later as a NameError when
    # the decoder layers were built.
    _outputs_shape = inputs_shape

# Model layers definition

In [None]:
# Encoder: two identical ReLU Dense layers, a Flatten, then a linear
# projection down to `latent_dim` units.
_hidden_units = intermediate_dim // 2
enc_lays = [
    tf.keras.layers.Dense(units=_hidden_units, activation='relu')
    for _ in range(2)
]
enc_lays.extend([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=latent_dim),
])

# Decoder: two ReLU Dense layers, a linear Dense layer with one unit per
# output element, and a Reshape back to the input shape.
dec_lays = [
    tf.keras.layers.Dense(units=latent_dim, activation='relu'),
    tf.keras.layers.Dense(units=(3 * intermediate_dim) // 2, activation='relu'),
    tf.keras.layers.Dense(units=_outputs_shape),
    tf.keras.layers.Reshape(target_shape=inputs_shape),
]


# Model definition

In [None]:
# Run identifier: the dataset name immediately followed by the model tag
# (no separator between the two).
model_name = f'{dataset_name}Dense_AE_ssmi'

# Every artefact of this run lives under <cwd>/experiments/<model_name>.
# (Linux layout. The earlier Windows variant instead used a path relative to
# the notebook's start directory: five levels up, then
# 'experiments' + sep_local + model_name.)
experiments_dir = f'{os.getcwd()}{sep_local}experiments{sep_local}{model_name}'

In [None]:
from training.autoencoding_basic.transformative.AE import autoencoder as AE

In [None]:
# Specification of the inference sub-network: input-shaped data in,
# `latent_dim` out, built from the encoder layers.
_inference_params = {
    'name': 'inference',
    'inputs_shape': inputs_shape,
    'outputs_shape': latent_dim,
    'layers': enc_lays,
}

# Specification of the generative sub-network: `latent_dim` in,
# input-shaped data out, built from the decoder layers.
_generative_params = {
    'name': 'generative',
    'inputs_shape': latent_dim,
    'outputs_shape': inputs_shape,
    'layers': dec_lays,
}

# Order matters to readers at least: inference first, generative second.
variables_params = [_inference_params, _generative_params]

In [None]:
from os.path import abspath
from utils.data_and_files.file_utils import create_if_not_exist

# Directory used for saved model variables (handed to ModelSaver later and
# usable as `filepath` when restoring); created through the project helper.
_restore = os.path.join(experiments_dir, 'var_save_dir')
create_if_not_exist(_restore)

# Echo the resolved locations so the run's output paths are visible.
for _label, _location in (("Restore_dir", _restore),
                          ("Recording_dir", experiments_dir)):
    absolute = abspath(_location)
    print(_label, absolute)
print("Current working dir", os.getcwd())

In [None]:
# Instantiate the autoencoder from the two sub-network specifications.
# Inputs and outputs share the same shape. `filepath=None` skips restoring;
# per the inline note, pass `_restore` to load a trained model.
ae = AE( 
    name=model_name,
    inputs_shape=inputs_shape,
    outputs_shape=inputs_shape,
    latent_dim=latent_dim,
    batch_size=batch_size,
    variables_params=variables_params, 
    filepath=None#to restore trained model, set filepath=_restore
    )

In [None]:
# from evaluation.quantitive_metrics.psnr import psnr
from evaluation.quantitive_metrics.ssmi import ssmi


In [None]:
# Use `ssmi` as the loss attached to the 'x_logits' key.
# NOTE(review): `ssmi` presumably implements a structural-similarity measure;
# confirm it returns a quantity that should be minimised.
ae.compile(loss={'x_logits': ssmi})

# Callbacks

In [None]:
# Raise matplotlib's logger threshold to WARNING so its lower-level messages
# are suppressed (added because of noisy output on Linux).
import logging
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)


# Project callbacks used during training below.
from training.callbacks.sample_generation import SampleGeneration
from training.callbacks.save_model import ModelSaver

In [None]:
# Stop training once the monitored loss has failed to improve by more than
# `min_delta` for `patience` consecutive epochs, then roll back to the best
# weights seen.
# NOTE(review): this monitors the training 'loss' even though validation data
# is passed to fit() below -- confirm that 'val_loss' is not the intended
# quantity.
es = tf.keras.callbacks.EarlyStopping(
    monitor='loss', 
    min_delta=1e-12, 
    patience=5, 
    verbose=1, 
    restore_best_weights=True
)

In [None]:
# Checkpoint callback writing into the restore directory.
# NOTE(review): `save_freq=5` is presumably counted in epochs -- confirm
# against ModelSaver.
ms = ModelSaver(filepath=_restore,save_freq=5)

In [None]:
# Make sure the log folder exists; afterwards `csv_dir` holds the full path
# of the CSV log *file* inside that folder (the name is historical).
_csv_folder = os.path.join(experiments_dir, 'csv_dir')
create_if_not_exist(_csv_folder)
csv_dir = os.path.join(_csv_folder, f'{model_name}.csv')

# append=True: keep rows from earlier runs instead of overwriting the file.
csv_log = tf.keras.callbacks.CSVLogger(filename=csv_dir, append=True)

absolute = abspath(csv_dir)
print("Csv_dir", absolute)

In [None]:
# Directory handed to the SampleGeneration callback as `filepath`;
# created if missing, then echoed as an absolute path.
image_gen_dir = os.path.join(experiments_dir, 'image_gen_dir')
create_if_not_exist(image_gen_dir)
absolute = abspath(image_gen_dir)
print("Image_gen_dir",absolute)

In [None]:
# Callback that generates sample images into `image_gen_dir`.
# NOTE(review): `gen_freq=5` is presumably counted in epochs and
# `gray_plot=False` presumably selects colour output -- confirm against
# SampleGeneration.
sg = SampleGeneration(latent_shape=latent_dim, filepath=image_gen_dir, gen_freq=5, save_img=True, gray_plot=False)

# Model Training

In [None]:
# Train the autoencoder. `epochs` is set to an effectively unbounded value,
# so in practice the EarlyStopping callback (`es`) decides when training ends.
# NOTE(review): in stock Keras, `workers` / `use_multiprocessing` apply to
# generator or Sequence inputs, not to tf.data datasets; `ae.fit` is the
# project's own method (it accepts `input_kw`), so confirm how it treats them.
# NOTE(review): `validation_steps=1000` requires test_ds to be able to supply
# that many batches per validation pass -- confirm the generator repeats.
ae.fit(
    x=train_ds,
    input_kw=None,
    steps_per_epoch=1000,
    epochs=int(1e6), 
    verbose=2,
    callbacks=[ es, ms, csv_log, sg],
    workers=-1,
    use_multiprocessing=True,
    validation_data=test_ds,
    validation_steps=1000
)

# Prediction preview

In [None]:
# Pull a single batch from the test dataset: the loop exits immediately and
# `batch` stays bound to the first element for the cells below.
for batch in test_ds:
    break

In [None]:
import matplotlib.pyplot as plt

In [None]:
# First 10 entries of each batch component: component 0 is what gets fed to
# ae.predict below, component 1 is plotted as the 'output' reference.
xt0 = batch[0][:10]
xt1 = batch[1][:10]

In [None]:
# Tile the individual frames side by side along axis 1 so each set can be
# displayed as one wide image.
xt0 = tf.concat(list(xt0), axis=1)
xt1 = tf.concat(list(xt1), axis=1)

In [None]:
# Run the model on the same 10 input frames, then tile the predicted frames
# along axis 1 for display.
_predicted_frames = ae.predict(batch[0][:10])
xt1_pred = tf.concat(list(_predicted_frames), axis=1)

In [None]:
# Show the three tiled strips one after another, each in its own
# axis-free, high-resolution figure, preceded by a printed caption.
for _caption, _strip in (('input', xt0),
                         ('output', xt1),
                         ('prediction', xt1_pred)):
    plt.figure(figsize=(20, 80), dpi=300)
    plt.axis('off')
    print(_caption)
    plt.imshow(_strip)
    # Saving the figure was left disabled in the original notebook.
    plt.show()
