In [None]:
#!pip -q install colorlog
#!pip -q install keras_radam
#!pip -q install keras_tqdm
#!pip install tensorflow==2.1.0-rc1

# Settings

In [None]:
import os
import sys

# Set before any Keras-related package is imported. Presumably read by keras_radam /
# keras_tqdm (see the install cell) to select tf.keras — TODO confirm.
# Same effect as `%env TF_KERAS=1`, but plain Python, so this cell also runs outside IPython.
os.environ['TF_KERAS'] = '1'
sep_local = os.path.sep

# Store an absolute path: a relative sys.path entry is resolved against the *current*
# working directory at import time. A later cell os.chdir()s to the repository root,
# which would silently retarget a relative '../..' entry.
sys.path.append(os.path.abspath('..' + sep_local + '..'))
print(sep_local)

In [None]:
import tensorflow as tf
# Record the TensorFlow version in the cell output (the commented install cell pins 2.1.0-rc1).
print(tf.__version__)

In [None]:
os.chdir('..'+sep_local+'..'+sep_local+'..'+sep_local+'..')
print(os.getcwd())

In [None]:
#%cd /content/
#!git clone https://github.com/kkahloots/Generative_Models.git # clones the repository; make sure the correct branch is checked out
#%cd /content/Generative_Models/

In [None]:
# Colab alternative: switch to the cloned repository instead of the relative chdir above.
#os.chdir('/content/Generative_Models/')
# Show the working directory that later relative paths ('data', 'experiments', ...) resolve against.
print(os.getcwd())

# Dataset loading

In [None]:
import tensorflow_datasets as tfds

dataset_name = 'dsprites'

# Inspect the builder metadata (features, available splits); no data is prepared here.
info = tfds.builder(dataset_name).info
print(info)

In [None]:
# Model / input hyperparameters.
inputs_shape = (64, 64, 1)  # image shape
intermediate_dim = 20       # hidden Dense width (the encoder heads use half of it)
latents_dim = 6             # size of the latent code
batch_size = 32

In [None]:
# Shuffle-buffer sizes for the training and evaluation pipelines.
TRAIN_BUF, TEST_BUF = 600, 100

In [None]:
import tensorflow_datasets as tfds

# Construct shuffled, batched tf.data pipelines.
train_ds = tfds.load(name=dataset_name, split=tfds.Split.TRAIN).shuffle(TRAIN_BUF).batch(batch_size)

# Some datasets (e.g. dsprites) ship only a 'train' split. Decide from the builder metadata
# instead of a bare `except:`. That also swallowed unrelated failures (network errors,
# typos, KeyboardInterrupt) and silently fell back to training data.
if 'test' in tfds.builder(dataset_name).info.splits:
    test_ds = tfds.load(name=dataset_name, split=tfds.Split.TEST).shuffle(TEST_BUF).batch(batch_size)
else:
    print(f"'{dataset_name}' has no test split; evaluating on reshuffled training data.")
    test_ds = tfds.load(name=dataset_name, split=tfds.Split.TRAIN).shuffle(TEST_BUF).batch(batch_size)

In [None]:
from data.gt_load.datasets import load

In [None]:
#DATA_DOWN_PATH = '..'+sep_local+'..'+sep_local+'..'+sep_local+'data'
DATA_DOWN_PATH = os.getcwd() + sep_local+'data'
Script_dir = os.getcwd() + sep_local+'data'+sep_local+'download_gt_data.sh'
# Script call to download "dsprites_full" dataset_name 
!/bin/bash $Script_dir -f $DATA_DOWN_PATH -d $dataset_name

In [None]:
# Root directory for the ground-truth datasets (presumably where download_gt_data.sh puts them).
DATA_PATH = os.path.sep.join([os.getcwd(), 'data', '.gt_datasets'])
print(DATA_PATH)

In [None]:
!ls /content/Generative_Models/data/.gt_datasets/dsprites

In [None]:
# Ground-truth dataset handed to the disentanglement-metric callbacks as ground_truth_data.
eval_dataset = load(dataset_path=DATA_PATH, dataset_name='dsprites_full')

In [None]:
# Peak pixel value of the first training batch; stays 1.0 if train_ds yields nothing.
_instance_scale = next(
    (float(batch['image'].numpy().max()) for batch in train_ds),
    1.0,
)

In [None]:
# Display the value measured from the first training batch (not used by any later visible cell).
_instance_scale

In [None]:
import numpy as np
from collections.abc import Iterable

In [None]:
# Flattened size of one image: the width of the decoder's last Dense layer, whose output
# is then reshaped back to inputs_shape. A scalar inputs_shape is used as-is. Previously
# that case left _outputs_shape undefined, and the layer cell raised NameError.
if isinstance(inputs_shape, Iterable):
    _outputs_shape = int(np.prod(inputs_shape))
else:
    _outputs_shape = int(inputs_shape)

In [None]:
# Display the flattened image size used by the decoder's output Dense layer.
_outputs_shape

# Model layer definitions

In [None]:
def _make_dense_encoder():
    """Return a fresh layer list for one inference head.

    Two ReLU Dense layers of width intermediate_dim // 2, a Flatten, then a linear Dense
    projection to latents_dim. New layer objects are built on every call, so the heads
    built from separate calls do not share weights.
    """
    return [
        tf.keras.layers.Dense(units=intermediate_dim // 2, activation='relu'),
        tf.keras.layers.Dense(units=intermediate_dim // 2, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(units=latents_dim),
    ]


# NOTE(review): the Dense layers come before Flatten, so they act on the last axis only
# (per pixel, presumably, for image inputs of inputs_shape) — confirm this is intended.
menc_lays = _make_dense_encoder()  # used by the 'inference_mean' head
venc_lays = _make_dense_encoder()  # used by the 'inference_logvariance' head

# Decoder: latent code -> ReLU Dense layers -> _outputs_shape values -> reshaped to an image.
dec_lays = [
    tf.keras.layers.Dense(units=latents_dim, activation='relu'),
    tf.keras.layers.Dense(units=intermediate_dim, activation='relu'),
    tf.keras.layers.Dense(units=_outputs_shape),
    tf.keras.layers.Reshape(inputs_shape),
]

# Model definition

In [None]:
# Experiment identifier; also names the output directory under experiments/.
model_name = dataset_name + 'GAVAE_Dense_reconst_ell'
experiments_dir = os.path.join('experiments', model_name)

In [None]:
from training.adversarial_basic.generative_adversarial.autoencoders.VAAE import VAAE as AE

In [None]:
# One spec per sub-network: the two inference heads map images to latent vectors and the
# generative network maps latent vectors back to images.
variables_params = [
    dict(name='inference_mean',
         inputs_shape=inputs_shape, outputs_shape=latents_dim, layers=menc_lays),
    dict(name='inference_logvariance',
         inputs_shape=inputs_shape, outputs_shape=latents_dim, layers=venc_lays),
    dict(name='generative',
         inputs_shape=latents_dim, outputs_shape=inputs_shape, layers=dec_lays),
]

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist

In [None]:
# Directory for saved model variables (pass it as `filepath` to AE to restore a trained model).
_restore = os.path.join(experiments_dir, 'var_save_dir')

In [None]:
# Create the save directory if needed, then display its path.
create_if_not_exist(_restore)
_restore

In [None]:
# To restore a trained model, pass filepath=_restore to AE(...) in the next cell.

In [None]:
# Build the model from the layer specs. Pass filepath=_restore instead of None to restore
# a previously trained model.
ae = AE(name=model_name,
        latents_dim=latents_dim,
        batch_size=batch_size,
        variables_params=variables_params,
        filepath=None)

In [None]:
from statistical.basic_adversarial_losses import (
    create_generative_discriminator_real_losses,
    create_generative_discriminator_fake_losses,
    create_generative_generator_fake_losses,
)

# Map each adversarial output name to the loss-creating function used for it.
generative_discriminator_losses = dict(
    generative_discriminator_real_outputs=create_generative_discriminator_real_losses,
    generative_discriminator_fake_outputs=create_generative_discriminator_fake_losses,
    generative_generator_fake_outputs=create_generative_generator_fake_losses,
)

In [None]:
# Weights of the adversarial terms; the names suggest ratios (discriminator vs. generator,
# generator vs. the traditional VAE loss) — confirm in AE.compile.
discr2gen_rate = 0.001
gen2trad_rate = 0.5

ae.compile(
    adversarial_losses=generative_discriminator_losses,
    adversarial_weights=dict(generator_weight=gen2trad_rate,
                             discriminator_weight=discr2gen_rate),
)

# Callbacks

In [None]:
from training.callbacks.sample_generation import SampleGeneration
from training.callbacks.save_model import ModelSaver
from training.callbacks.disentangle_supervied import DisentanglementSuperviedMetrics
from training.callbacks.disentangle_unsupervied import DisentanglementUnsuperviedMetrics

In [None]:
# Output files for the supervised (gts) and unsupervised (gtu) disentanglement metrics.
_metrics_csv_root = os.path.join(experiments_dir, 'csv_dir')
gts_csv = os.path.join(_metrics_csv_root, 'gts_metrics')
gtu_csv = os.path.join(_metrics_csv_root, 'gtu_metrics')

In [None]:
def _ground_truth_metric_kwargs(file_name):
    """Keyword arguments shared by both ground-truth disentanglement callbacks.

    A new RandomState(0) is created on every call, so each callback gets its own
    identically seeded generator rather than a shared one.
    """
    return dict(
        ground_truth_data=eval_dataset,
        representation_fn=lambda x: ae.encode(x),
        random_state=np.random.RandomState(0),
        file_Name=file_name,
        num_train=10000,
        num_test=200,
        batch_size=batch_size,
        gt_freq=10,
    )


gts_mertics = DisentanglementSuperviedMetrics(
    continuous_factors=False, **_ground_truth_metric_kwargs(gts_csv))
gtu_mertics = DisentanglementUnsuperviedMetrics(**_ground_truth_metric_kwargs(gtu_csv))

In [None]:
# Stop after 5 epochs without a decrease in training loss and roll back to the best
# weights. min_delta=1e-12 counts practically any decrease as an improvement.
es = tf.keras.callbacks.EarlyStopping(monitor='loss', min_delta=1e-12, patience=5,
                                      verbose=1, restore_best_weights=True)

In [None]:
# Save the model under _restore; save_freq=5 presumably means every 5 epochs — confirm in ModelSaver.
ms = ModelSaver(filepath=_restore,save_freq=5)

In [None]:
# Per-epoch training log written as CSV, appended to across runs.
_csv_folder = os.path.join(experiments_dir, 'csv_dir')
create_if_not_exist(_csv_folder)
# Despite its name, csv_dir ends up holding the CSV *file* path.
csv_dir = os.path.join(_csv_folder, ae.name + '.csv')
csv_log = tf.keras.callbacks.CSVLogger(csv_dir, append=True)
csv_dir

In [None]:
# Output directory passed to the SampleGeneration callback as its filepath.
image_gen_dir = os.path.join(experiments_dir, 'image_gen_dir')
create_if_not_exist(image_gen_dir)

In [None]:
# Generate sample images during training and save them (grayscale plots) under image_gen_dir;
# gen_freq=5 presumably means every 5 epochs — confirm in SampleGeneration.
sg = SampleGeneration(latents_shape=latents_dim, filepath=image_gen_dir, gen_freq=5, save_img=True, gray_plot=True)

In [None]:
import numpy as np

# Model Training

In [None]:
# Train until EarlyStopping (`es`) fires: epochs=1e6 is effectively unbounded.
# Assuming AE.fit follows Keras semantics, each epoch draws 1000 batches from train_ds and
# validates on 1000 batches of test_ds. When the dataset has no test split, test_ds is
# reshuffled training data.
# NOTE(review): in tf.keras, `workers` / `use_multiprocessing` apply only to generator or
# keras.utils.Sequence inputs and are ignored for a tf.data.Dataset; workers=-1 is not a
# meaningful worker count either — confirm how AE.fit forwards them.
ae.fit(
    x=train_ds,
    input_kw='image',  # presumably selects the 'image' entry of each dataset element
    steps_per_epoch=1000,
    epochs=int(1e6), 
    verbose=2,
    callbacks=[ es, ms, csv_log, sg, gts_mertics, gtu_mertics],
    workers=-1,
    use_multiprocessing=True,
    validation_data=test_ds,
    validation_steps=1000
)

# Model Evaluation

## Inception score

In [None]:
from evaluation.generativity_metrics.inception_metrics import inception_score

In [None]:
# Inception score of the trained model's generated images, reported as mean and sigma.
# Stopping controls: tolerance_threshold / max_iteration (semantics per inception_metrics).
is_mean, is_sigma = inception_score(ae, tolerance_threshold=1e-6, max_iteration=200)
print(f'inception_score mean: {is_mean}, sigma: {is_sigma}')

## Fréchet inception distance

In [None]:
from evaluation.generativity_metrics.inception_metrics import frechet_inception_distance

In [None]:
# NOTE(review): `training_generator` / `testing_generator` were never defined in this
# notebook, so this cell and every later cell using them raised NameError on a fresh kernel.
# They look like leftovers from a generator-based template. Alias them to the tfds
# pipelines built earlier. TODO confirm the metric and image-generation helpers accept a
# tf.data.Dataset whose elements are dicts with an 'image' key.
training_generator = train_ds
testing_generator = test_ds

fis_score = frechet_inception_distance(ae, training_generator, tolerance_threshold=1e-6, max_iteration=10, batch_size=32)
print(f'frechet inception distance: {fis_score}')

## Perceptual path length score

In [None]:
from evaluation.generativity_metrics.perceptual_path_length import perceptual_path_length_score

In [None]:
# Perceptual path length of the model, computed against training data.
# NOTE(review): requires `training_generator` to be defined by an earlier cell.
ppl_mean_score = perceptual_path_length_score(ae, training_generator, tolerance_threshold=1e-6, max_iteration=200, batch_size=32)
print(f'perceptual path length score: {ppl_mean_score}')

## Precision score

In [None]:
from evaluation.generativity_metrics.precision_recall import precision_score

In [None]:
# Precision of generated samples relative to training data.
# NOTE(review): requires `training_generator` to be defined by an earlier cell.
_precision_score = precision_score(ae, training_generator, tolerance_threshold=1e-6, max_iteration=200)
print(f'precision score: {_precision_score}')

## Recall score

In [None]:
from evaluation.generativity_metrics.precision_recall import recall_score

In [None]:
# Recall of generated samples relative to training data.
# NOTE(review): requires `training_generator` to be defined by an earlier cell.
_recall_score = recall_score(ae, training_generator, tolerance_threshold=1e-6, max_iteration=200)
print(f'recall score: {_recall_score}')

# Image Generation

## Image reconstruction

### Training dataset

In [None]:
# Automatically reload edited project modules before each cell executes.
# NOTE(review): conventionally enabled in the first cells so it applies to the whole notebook.
%load_ext autoreload
%autoreload 2

In [None]:
from training.generators.image_generation_testing import reconstruct_from_a_batch

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Reconstruct a batch of training images and write the results to their own directory.
# NOTE(review): requires `training_generator` to be defined by an earlier cell.
save_dir = os.path.join(experiments_dir, 'reconstruct_training_images_like_a_batch_dir')
create_if_not_exist(save_dir)

reconstruct_from_a_batch(ae, training_generator, save_dir)

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Reconstruct a batch of evaluation images and write the results to their own directory.
# NOTE(review): requires `testing_generator` to be defined by an earlier cell.
save_dir = os.path.join(experiments_dir, 'reconstruct_testing_images_like_a_batch_dir')
create_if_not_exist(save_dir)

reconstruct_from_a_batch(ae, testing_generator, save_dir)

## With randomness

In [None]:
from training.generators.image_generation_testing import generate_images_like_a_batch

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Generate images conditioned on a training batch (with randomness) and save them.
# NOTE(review): requires `training_generator` to be defined by an earlier cell.
save_dir = os.path.join(experiments_dir, 'generate_training_images_like_a_batch_dir')
create_if_not_exist(save_dir)

generate_images_like_a_batch(ae, training_generator, save_dir)

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Generate images conditioned on an evaluation batch (with randomness) and save them.
# NOTE(review): requires `testing_generator` to be defined by an earlier cell.
save_dir = os.path.join(experiments_dir, 'generate_testing_images_like_a_batch_dir')
create_if_not_exist(save_dir)

generate_images_like_a_batch(ae, testing_generator, save_dir)

### Complete Randomness

In [None]:
from training.generators.image_generation_testing import generate_images_randomly

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Generate images without conditioning on any dataset batch and save them.
save_dir = os.path.join(experiments_dir, 'random_synthetic_dir')
create_if_not_exist(save_dir)

generate_images_randomly(ae, save_dir)

In [None]:
from training.generators.image_generation_testing import interpolate_a_batch

In [None]:
from utils.data_and_files.file_utils import create_if_not_exist
# Interpolate within an evaluation batch and save the resulting images.
# NOTE(review): requires `testing_generator` to be defined by an earlier cell.
save_dir = os.path.join(experiments_dir, 'interpolate_dir')
create_if_not_exist(save_dir)

interpolate_a_batch(ae, testing_generator, save_dir)