In [None]:
import os
import sys
import time
import random 

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions
import matplotlib.pyplot as plt

from normalizingflows.flow_catalog import RealNVP
from utils.train_utils import train_density_estimation, nll
from data import dataset_loader




In [None]:

# Load the "miniboone" UCI dataset through the project loader.
# The loader returns four values; the first three are used as the
# train / validation / test splits and the fourth is discarded.
uci_splits = dataset_loader.load_and_preprocess_uci("miniboone", batch_size=256)
train_data, val_data, test_data, _ = uci_splits

In [None]:
# Read the feature dimension off the first training batch.
# assumes each element of train_data is a 2-D (batch, features) tensor -- TODO confirm against the loader
for sample in train_data.take(1):
    input_shape = sample.shape[1]
print(input_shape)

# Permutation that swaps the two halves of the feature vector so that
# successive layers act on different coordinates.
# Integer division is required here: with "/" an odd input_shape gave
# fractional arange bounds, which produced input_shape + 1 indices containing
# a duplicate after the int32 cast, i.e. not a valid permutation.
half = input_shape // 2
permutation = tf.cast(np.concatenate((np.arange(half, input_shape), np.arange(0, half))), tf.int32)

# Zero-mean diagonal Gaussian base distribution (scale left at the library default).
base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros(input_shape, tf.float32))

In [None]:
# --- Experiment configuration -------------------------------------------
n_images = 1
# Must name the dataset that is actually loaded by
# dataset_loader.load_and_preprocess_uci(...) in this notebook ("miniboone");
# it is used to build the checkpoint directory and to look up the shuffle
# buffer size in uci_trainsizes. It was previously "gas", which did not match
# the loaded data.
dataset = "miniboone"
exp_number = 1      # experiment index, becomes part of the checkpoint directory name
max_epochs = 10     # upper bound on the number of training epochs
layers = 4          # number of flow layers stacked in the chain
shape = [256, 256]  # hidden-layer sizes passed to RealNVP as n_hidden
base_lr = 1e-4      # initial learning rate of the decay schedule
end_lr = 1e-5       # final learning rate of the decay schedule
# Per-dataset sizes used as the shuffle buffer size
# (presumably the number of training samples -- verify against the loader).
uci_trainsizes = {"power": 1659917,
                  "gas": 852174,
                  "hepmass": 315123,
                  "miniboone": 29556,
                  "bsds300": 1000000}

In [None]:
# Assemble the flow: every layer contributes a batch-norm bijector, a RealNVP
# bijector and a fixed permutation, in that order.
bijectors = []
for _layer in range(layers):
    bijectors.extend([
        tfb.BatchNormalization(),
        RealNVP(input_shape=input_shape, n_hidden=shape),
        tfb.Permute(permutation),
    ])

# The list is handed to Chain in reverse construction order.
bijector = tfb.Chain(bijectors=bijectors[::-1], name='chain_of_real_nvp')

# Transformed distribution: base_dist pushed through the bijector chain.
flow = tfd.TransformedDistribution(distribution=base_dist, bijector=bijector)

# number of trainable variables
# NOTE(review): counted before any data has passed through the flow; if the
# bijectors create their weights lazily this may undercount -- confirm.
n_trainable_variables = len(flow.trainable_variables)

In [None]:
# Polynomial (power 0.5) decay of the learning rate from base_lr to end_lr.
# NOTE(review): decay_steps is max_epochs; if the optimizer advances the
# schedule once per batch rather than once per epoch, end_lr is reached after
# max_epochs batches -- confirm this is intended.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=base_lr,
    decay_steps=max_epochs,
    end_learning_rate=end_lr,
    power=0.5,
)

# Checkpoint location is derived from the experiment configuration.
checkpoint_directory = f"{dataset}/tmp_{layers}_{shape[0]}_{shape[1]}_{exp_number}"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

# Adam optimizer driven by the schedule; the checkpoint tracks both the
# optimizer state and the flow.
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn)
checkpoint = tf.train.Checkpoint(optimizer=opt, model=flow)

In [None]:
# --- Training with per-epoch validation, checkpointing and early stopping ---
global_step = []    # epoch indices at which losses were recorded
train_losses = []   # mean training loss of each recorded epoch
val_losses = []     # mean validation loss of each recorded epoch
min_val_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)  # high value to ensure that first loss < min_loss
min_train_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)
min_val_epoch = 0
min_train_epoch = 0
delta_stop = 10  # threshold for early stopping

# Dataset.shuffle() returns a NEW dataset and leaves the original untouched.
# The previous code called it inside the loop and discarded the result, so the
# data was never shuffled. Build the shuffled view once; by default it
# reshuffles on every iteration (i.e. every epoch).
# NOTE(review): if train_data is already batched by the loader, this shuffles
# the order of batches, not the samples inside them -- confirm.
shuffled_train_data = train_data.shuffle(buffer_size=uci_trainsizes[dataset])

t_start = time.time()  # start time

# start training
for i in range(max_epochs):

    # one optimisation pass over the training data
    batch_train_losses = []
    for batch in shuffled_train_data:
        batch_loss = train_density_estimation(flow, opt, batch)
        batch_train_losses.append(batch_loss)

    train_loss = tf.reduce_mean(batch_train_losses)

    # validation after every epoch (mean of the per-batch losses)
    batch_val_losses = []
    for batch in val_data:
        batch_loss = nll(flow, batch)
        batch_val_losses.append(batch_loss)
    val_loss = tf.reduce_mean(batch_val_losses)

    global_step.append(i)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    print(f"{i}, train_loss: {train_loss}, val_loss: {val_loss}")

    if train_loss < min_train_loss:
        min_train_loss = train_loss
        min_train_epoch = i

    if val_loss < min_val_loss:
        # new best validation loss: remember it and save the weights
        min_val_loss = val_loss
        min_val_epoch = i
        checkpoint.write(file_prefix=checkpoint_prefix)

    elif i - min_val_epoch > delta_stop:  # no decrease in min_val_loss for "delta_stop epochs"
        break

train_time = time.time() - t_start

# Reload the weights that were written at the best validation loss.
checkpoint.restore(checkpoint_prefix)

# Evaluate on the test split and time the evaluation.
t_start = time.time()
test_losses = [nll(flow, batch) for batch in test_data]
test_loss = tf.reduce_mean(test_losses)  # mean of the per-batch losses
test_time = time.time() - t_start


# Clean up: delete every file in the checkpoint directory, then the directory.
for filename in os.listdir(checkpoint_directory):
    os.remove(os.path.join(checkpoint_directory, filename))
# removedirs also prunes parent directories that are left empty
os.removedirs(checkpoint_directory)


# Summary of the run: (label, value, epoch) rows printed in a fixed order.
# NOTE(review): the test loss is labelled with the last epoch index `i`, while
# the weights evaluated are the ones restored from the best-validation
# checkpoint (min_val_epoch) -- confirm which epoch should be reported.
last_epoch = i  # zero-based index of the last epoch that ran
summary_rows = [
    ('Test loss', test_loss, last_epoch),
    ('Average test log likelihood', -test_loss, last_epoch),
    ('Min val loss', min_val_loss, min_val_epoch),
    ('Last val loss', val_loss, last_epoch),
    ('Min train loss', min_train_loss, min_train_epoch),
    ('Last train loss', train_loss, last_epoch),
]
for label, value, epoch in summary_rows:
    print(f'{label}: {value} at epoch: {epoch}')
print(f'Training time: {train_time}')
print(f'Test time: {test_time}')

# Plain-Python record of the run; tensor-valued losses are converted to float.
results = {
    'test_loss': float(test_loss),
    'avg_test_logll': float(-test_loss),
    'min_val_loss': float(min_val_loss),
    'min_val_epoch': min_val_epoch,
    'val_loss': float(val_loss),
    'min_train_loss': float(min_train_loss),
    'min_train_epoch': min_train_epoch,
    'train_loss': float(train_loss),
    'train_time': train_time,
    'test_time': test_time,
    'trained_epochs': last_epoch,  # index of the last epoch, not a count
    'trainable variables': n_trainable_variables,
    'exp_number': exp_number,
}

In [None]:
# Learning curves: one point per recorded epoch. The curves are labelled so
# that validation and training loss can be told apart in the figure.
fig, ax = plt.subplots()
ax.plot(val_losses, label="validation loss")
ax.plot(train_losses, label="training loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.set_title("Training and validation loss per epoch")
ax.legend()
plt.show()