In [12]:
import os
import sys
import numpy as np
# Put the directory two levels above the current working directory on the
# import path (presumably so the project packages imported below resolve).
# The membership test keeps re-runs of this cell from appending duplicates.
module_path = os.path.abspath('../..')
if module_path not in sys.path:
    sys.path.append(module_path)

from data.plot_samples import plot_samples_2d
from data.visu_density import plot_heatmap_2d
from data.dataset_loader import load_and_preprocess_uci
from normalizingflows.flow_catalog import NeuralSplineFlow
from utils.train_utils import train_density_no_tf, train_density_estimation
import tensorflow as tf
import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions
import time
from utils.train_utils import sanity_check


### Fetch the data and the related specifications

In [13]:
# --- Experiment configuration -------------------------------------------------
dataset = "miniboone"   # key into `uci_trainsizes`; also used in output paths
batch_size = 32         # mini-batch size for the loader; also embedded in the
                        # result file names built by the evaluation cell
layers = 8              # number of (spline flow, permutation) pairs in the chain
shape = [64, 64]        # only used for naming checkpoints/plots in the visible cells
delta_count = 15        # early-stopping threshold compared against epochs since best val loss

# NOTE(review): the training cell also reads `base_lr` and `max_epochs`, which no
# visible cell defines. Set them here before running from a fresh kernel.

# Presumably the number of training samples per UCI dataset — TODO confirm.
uci_trainsizes = {"power": 1659917,
                 "gas": 852174,
                 "hepmass": 315123,
                 "miniboone": 29556,
                 "bsds300": 1000000}
trainsize = uci_trainsizes[dataset]

# Load the batched splits. `intervals` is forwarded unchanged to the spline
# flow's `b_interval` argument when the flow is built.
batched_train_data, batched_val_data, batched_test_data, intervals = load_and_preprocess_uci(dataset, batch_size=batch_size)

# Peek at one training batch to read the size of axis 1 (used as the flow's input_dim).
sample_batch = next(iter(batched_train_data))
input_shape = int(sample_batch.shape[1])

### Create the flow

In [14]:
# Index permutation that moves the second half of the feature indices in front
# of the first half. Integer division keeps the bounds exact: with float
# division an odd `input_shape` produces fractional ranges that truncate to a
# duplicated index and one index too many (e.g. 44 indices for 43 features).
half = input_shape // 2
permutation = tf.cast(np.concatenate((np.arange(half, input_shape), np.arange(0, half))), tf.int32)

# Base distribution: diagonal Gaussian with zero mean (scale left at its default).
base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros(input_shape, tf.float32))

# Alternate a spline flow with the half-swap permutation, `layers` times.
bijector_chain = []
for i in range(layers):
    bijector_chain.append(NeuralSplineFlow(input_dim=input_shape, d_dim=half + 1, number_of_bins=8, b_interval=intervals))
    bijector_chain.append(tfb.Permute(permutation))

# The list is reversed before it is handed to tfb.Chain.
bijector = tfb.Chain(bijectors=list(reversed(bijector_chain)), name='chain_of_real_nvp')

flow = tfd.TransformedDistribution(
    distribution=base_dist,
    bijector=bijector
)

# Number of trainable variables (count of variable tensors, not of scalar
# parameters). Previously this stored the variables themselves, which then
# ended up inside the `results` dict.
# NOTE(review): if the spline networks create their weights lazily, this runs
# before any forward pass and may undercount — confirm.
n_trainable_variables = len(flow.trainable_variables)

### Train

In [None]:
# NOTE(review): `base_lr` and `max_epochs` are read below but are not defined in
# any visible cell — confirm they are set before running from a fresh kernel.

# Where the best-so-far weights are written during training.
checkpoint_directory = "{}/tmp_{}_{}_{}".format(dataset, layers, shape[0], base_lr)
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

optimizer = tf.keras.optimizers.Adam(learning_rate=base_lr)
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=flow)

# Histories, appended to on every validation pass.
global_step, train_losses, val_losses = [], [], []

# Best-so-far trackers start at +inf so the first measured loss always wins.
min_val_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)
min_train_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)
min_val_epoch = min_train_epoch = 0
delta = 0  # threshold for early stopping (not read anywhere in this cell)

t_start = time.time()  # wall-clock start of training

for i in range(max_epochs):
    # One pass over the training batches; only the loss returned for the
    # last batch survives as `train_loss`.
    for batch in batched_train_data:
        train_loss = train_density_no_tf(flow, optimizer, batch)

    # Validation, bookkeeping and early stopping run on every 100th epoch only.
    if i % 100 != 0:
        continue

    # Validation loss = mean over batches of the per-batch mean negative log-likelihood.
    batch_val_losses = []
    for batch in batched_val_data:
        batch_loss = -tf.reduce_mean(flow.log_prob(batch))
        batch_val_losses.append(batch_loss)
    val_loss = tf.reduce_mean(batch_val_losses)

    global_step.append(i)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    print(f"{i}, train_loss: {train_loss}, val_loss: {val_loss}")

    if train_loss < min_train_loss:
        min_train_loss = train_loss
        min_train_epoch = i

    if val_loss < min_val_loss:
        # New best validation loss: remember it and overwrite the checkpoint.
        min_val_loss = val_loss
        min_val_epoch = i
        checkpoint.write(file_prefix=checkpoint_prefix)
    elif i - min_val_epoch > delta_count:
        # NOTE(review): this branch is only reached on multiples of 100, so with
        # delta_count = 15 the first non-improving check after epoch 0 already
        # stops training — confirm whether the threshold was meant in epochs
        # or in validation checks.
        break

train_time = time.time() - t_start

0, train_loss: 54.0181884765625, val_loss: 53.489410400390625


### Evaluate

In [None]:
# Load the weights from the checkpoint the training cell wrote at the best
# validation loss (epoch `min_val_epoch`); everything below evaluates that model.
checkpoint.restore(checkpoint_prefix)

# Test loss = mean over batches of the per-batch mean negative log-likelihood.
t_start = time.time()
test_losses = []
for batch in batched_test_data:
    batch_loss = -tf.reduce_mean(flow.log_prob(batch))
    test_losses.append(batch_loss)

test_loss = tf.reduce_mean(test_losses)

test_time = time.time() - t_start

# save density estimation of best model
# NOTE(review): the plotting helpers are named *_2d while `input_shape` comes
# from the dataset — confirm they accept flows with more than two dimensions.
save_dir = "{}/{}_density_{}_{}_{}_{}_{}".format(dataset, dataset, batch_size, layers, shape, base_lr, min_val_epoch)
plot_heatmap_2d(flow, -4.0, 4.0, -4.0, 4.0, name=save_dir)

save_dir = "{}/{}_sampling_{}_{}_{}_{}_{}".format(dataset, dataset, batch_size, layers, shape, base_lr, min_val_epoch)
plot_samples_2d(flow.sample(1000), name=save_dir)

# remove checkpoint: delete every file in the directory, then the directory itself
filelist = os.listdir(checkpoint_directory)
for f in filelist:
    os.remove(os.path.join(checkpoint_directory, f))
os.removedirs(checkpoint_directory)

# The test loss belongs to the restored best-validation model, so it is
# reported against `min_val_epoch`, not the last loop index `i`.
print(f'Test loss: {test_loss} at epoch: {min_val_epoch}')
print(f'Min val loss: {min_val_loss} at epoch: {min_val_epoch}')
print(f'Last val loss: {val_loss} at epoch: {i}')
print(f'Min train loss: {min_train_loss} at epoch: {min_train_epoch}')
print(f'Last train loss: {train_loss} at epoch: {i}')  # label fixed: this line prints the train loss
print(f'Training time: {train_time}')
print(f'Test time: {test_time}')

# Summary of the run; tensor-valued losses are converted to plain floats.
results = {
    'test_loss': float(test_loss),
    'min_val_loss': float(min_val_loss),
    'min_val_epoch': min_val_epoch,
    'val_loss': float(val_loss),
    'min_train_loss': float(min_train_loss),
    'min_train_epoch': min_train_epoch,
    'train_loss': float(train_loss),
    'train_time': train_time,
    'test_time': test_time,
    'trained_epochs': i,  # last value of the zero-based epoch index
    'trainable variables': n_trainable_variables,
}