# Density Estimation using Masked Autoregressive Flow (MAF) on UCI datasets

In [1]:
import numpy as np
import time
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
import os
import random

from data.dataset_loader import load_and_preprocess_uci
from normalizingflows.flow_catalog import Made, BatchNorm, get_trainable_variables
from utils.train_utils import train_density_estimation, nll

# Fix TensorFlow's global random seed for this session.
tf.random.set_seed(1234)

# Short aliases for the TensorFlow Probability sub-packages used below.
tfb = tfp.bijectors
tfd = tfp.distributions

tensorflow:  2.0.0
tensorflow-probability:  0.8.0-rc0


In [None]:
# Report whether TensorFlow can see a GPU (the notebook displays the returned value).
# NOTE(review): later TensorFlow releases deprecate this call in favour of
# tf.config.list_physical_devices('GPU') -- confirm against the pinned version.
tf.test.is_gpu_available()

In [None]:
# Experiment configuration.
dataset = "power"      # key into uci_trainsizes below
exp_number = 1

# Model: number of autoregressive layers and hidden-unit sizes of each MADE network.
layers = 4
shape = [64, 64]

# Optimisation.
batch_size = 512
max_epochs = 500
base_lr = 1e-3         # learning rate at the start of the schedule
end_lr = 1e-4          # learning rate at the end of the schedule

# Size of the training split for each supported UCI dataset.
uci_trainsizes = {
    "power": 1659917,
    "gas": 852174,
    "hepmass": 315123,
    "miniboone": 29556,
    "bsds300": 1000000,
}
trainsize = uci_trainsizes[dataset]

### Data loading and preprocessing

In [4]:
# Load the selected dataset; the loader hands back train, validation and test data
# already split into batches of `batch_size`.
(batched_train_data,
 batched_val_data,
 batched_test_data) = load_and_preprocess_uci(dataset, batch_size=batch_size)

In [5]:
# Pull a single batch to read the size of its second axis (the data dimensionality).
sample_batch = next(iter(batched_train_data))
input_shape = sample_batch.shape[1]
print(input_shape)

# Permutation that swaps the two halves of the feature vector:
# [half, ..., input_shape - 1, 0, ..., half - 1].
# Integer floor division keeps this a valid permutation for odd dimensions too;
# with float division the two ranges overlap by one index when input_shape is odd.
half = input_shape // 2
permutation = tf.cast(
    np.concatenate((np.arange(half, input_shape), np.arange(0, half))),
    tf.int32,
)

# Base distribution of the flow: a diagonal Gaussian with zero mean over input_shape dimensions.
base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros(shape=input_shape, dtype=tf.float32))

6


### Create MAF flow

In [6]:
# Not passed to TransformedDistribution below; kept as in the original cell.
event_shape = [input_shape]

# According to [Papamakarios et al. (2017)]:
# BatchNorm between the last autoregressive layer and the base distribution,
# and after every two autoregressive layers.
# NOTE(review): eps=10e-5 evaluates to 1e-4 -- confirm that this is the intended value.
bijectors = [BatchNorm(eps=10e-5, decay=0.95)]

for layer_idx in range(layers):
    made = Made(params=2, hidden_units=shape, activation="relu")
    bijectors.append(tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=made))
    # Permutation improves density estimation results.
    bijectors.append(tfb.Permute(permutation=permutation))

    # Add a BatchNorm after every second autoregressive layer.
    if layer_idx % 2 == 1:
        bijectors.append(BatchNorm(eps=10e-5, decay=0.95))

# The chain receives the bijectors in reversed list order.
bijector = tfb.Chain(bijectors=bijectors[::-1], name='chain_of_maf')

maf = tfd.TransformedDistribution(distribution=base_dist, bijector=bijector)

# Important: call log_prob once so the moving averages of the layer statistics
# in the batch norm layers are initialized.
maf.log_prob(sample_batch)
print("Successfully initialized!")

Successfully initialized!


In [7]:
# Query the helper for the flow's trainable-variable figure and print it
# (21588 in the recorded run).
n_trainable_variables = get_trainable_variables(maf)
print(n_trainable_variables)

21588


In [8]:
# Learning rate schedule: polynomial decay (power 0.5) from base_lr to end_lr.
# NOTE(review): decay_steps is max_epochs. If the optimizer advances the schedule
# once per batch rather than once per epoch, end_lr is reached after max_epochs
# batches -- confirm this is intended.
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=base_lr,
    decay_steps=max_epochs,
    end_learning_rate=end_lr,
    power=0.5,
)

In [9]:
# Checkpoint location: a per-run directory under the dataset name, tagged with a
# random 32-bit hex value.
run_tag = hex(random.getrandbits(32))
checkpoint_directory = f"{dataset}/tmp_{run_tag}"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

# Adam optimizer driven by the learning rate schedule; the checkpoint tracks
# both the optimizer and the flow.
opt = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn)
checkpoint = tf.train.Checkpoint(optimizer=opt, model=maf)

In [None]:
# Bookkeeping for the loss curves and for early stopping.
global_step = []
train_losses = []
val_losses = []
min_val_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)  # +inf so that the first loss < min_loss
min_train_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)
min_val_epoch = 0
min_train_epoch = 0
delta_stop = 25      # early stopping: epochs allowed without a new best validation loss
validate_every = 1   # run validation every `validate_every` epochs

# Dataset.shuffle() returns a NEW dataset instead of shuffling in place, so the
# result has to be kept and iterated over; calling it and discarding the result
# has no effect. With reshuffle_each_iteration=True every pass over the dataset
# (i.e. every epoch) gets a fresh order, so building it once is enough.
shuffled_train_data = batched_train_data.shuffle(buffer_size=trainsize, reshuffle_each_iteration=True)

t_start = time.time()  # start time

# start training
for epoch in range(max_epochs):

    # One optimisation step per batch; collect the per-batch losses.
    batch_train_losses = []
    for batch in shuffled_train_data:
        batch_loss = train_density_estimation(maf, opt, batch)
        batch_train_losses.append(batch_loss)

    train_loss = tf.reduce_mean(batch_train_losses)

    if epoch % validate_every != 0:
        continue

    # Mean loss over the validation batches.
    batch_val_losses = []
    for batch in batched_val_data:
        batch_loss = nll(maf, batch)
        batch_val_losses.append(batch_loss)

    val_loss = tf.reduce_mean(batch_val_losses)

    global_step.append(epoch)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    print(f"{epoch}, train_loss: {train_loss}, val_loss: {val_loss}")

    if train_loss < min_train_loss:
        min_train_loss = train_loss
        min_train_epoch = epoch

    if val_loss < min_val_loss:
        # New best validation loss: remember it and write a checkpoint of this state.
        min_val_loss = val_loss
        min_val_epoch = epoch
        checkpoint.write(file_prefix=checkpoint_prefix)

    elif epoch - min_val_epoch > delta_stop:  # no decrease in min_val_loss for "delta_stop" epochs
        break

train_time = time.time() - t_start

0, train_loss: 2.6511566638946533, val_loss: 1.676918387413025
1, train_loss: 1.291669487953186, val_loss: 1.0746381282806396
2, train_loss: 0.968338668346405, val_loss: 0.9203070998191833
3, train_loss: 0.8498308658599854, val_loss: 0.8211522698402405
4, train_loss: 0.7767125964164734, val_loss: 0.749708354473114
5, train_loss: 0.7166681289672852, val_loss: 0.703943133354187
6, train_loss: 0.6707810163497925, val_loss: 0.6692290902137756
7, train_loss: 0.6355478167533875, val_loss: 0.623394250869751
8, train_loss: 0.6012700200080872, val_loss: 0.5904493927955627
9, train_loss: 0.5746986865997314, val_loss: 0.5953004360198975
10, train_loss: 0.5566505789756775, val_loss: 0.5552653670310974
11, train_loss: 0.5321587324142456, val_loss: 0.5344226360321045
12, train_loss: 0.5176129937171936, val_loss: 0.5268900990486145
13, train_loss: 0.5031675100326538, val_loss: 0.497269868850708
14, train_loss: 0.4913642108440399, val_loss: 0.4924618601799011
15, train_loss: 0.47979089617729187, val_l

In [None]:
# Restore the weights written at the epoch with the lowest validation loss.
checkpoint.restore(checkpoint_prefix)

# Evaluate on the test set and time the pass.
t_start = time.time()

test_losses = [nll(maf, test_batch) for test_batch in batched_test_data]
test_loss = tf.reduce_mean(test_losses)

test_time = time.time() - t_start

In [None]:
# Draw the training and validation loss curves over the recorded epochs.
for curve, curve_label in ((train_losses, "train loss"), (val_losses, "val loss")):
    plt.plot(global_step, curve, label=curve_label)
plt.legend()