In [1]:
import os
import sys
import time

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions
import matplotlib.pyplot as plt

from normalizingflows.flow_catalog import RealNVP
from utils.train_utils import train_density_estimation, nll
from data import dataset_loader




tensorflow:  2.0.0
tensorflow-probability:  0.8.0-rc0


In [3]:
category = 2  # passed to the loader as `classes`; presumably the single MNIST digit class to model

# The fourth return value of the loader is not needed in this notebook.
train_data, val_data, test_data, _ = dataset_loader.load_and_preprocess_mnist(
    logit_space=True,
    batch_size=128,
    shuffle=True,
    classes=category,
)

In [4]:
mnist_shape = (28, 28, 1)
size = 28
input_shape = size * size  # flattened event size of one size x size image

# Half-swap permutation: moves the second half of the dimensions in front of the first half.
# Integer division keeps every arange bound integral, so the result stays a valid permutation
# even when input_shape is odd (true division gives fractional bounds that truncate to a
# duplicated index after the int32 cast). For even input_shape the values are unchanged.
half = input_shape // 2
permutation = tf.cast(np.concatenate((np.arange(half, input_shape), np.arange(0, half))), tf.int32)

# Standard normal base distribution over the flattened event (diagonal, scale defaults to ones).
base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros(input_shape, tf.float32))

In [6]:
# --- experiment bookkeeping (used in checkpoint / output directory names) ---
dataset = "mnist"
exp_number = 1
n_images = 1  # number of samples drawn from the trained flow and saved as PNGs

# --- flow architecture ---
layers = 5          # number of (BatchNormalization, RealNVP, Permute) blocks in the chain
shape = [256, 256]  # passed to RealNVP as n_hidden (presumably hidden-layer widths)

# --- optimization ---
max_epochs = 200
base_lr, end_lr = 1e-4, 1e-5  # start / end learning rate of the PolynomialDecay schedule

In [7]:
# Build the normalizing flow: `layers` blocks of
# (BatchNormalization -> RealNVP coupling -> half-swap Permute), followed by a Reshape
# that maps the flat (size*size,) event to a (size, size) image.
alpha = 1e-3  # NOTE(review): not referenced in the visible cells; kept in case later cells use it

bijectors = []
for _ in range(layers):
    bijectors.extend([
        tfb.BatchNormalization(),
        RealNVP(input_shape=input_shape, n_hidden=shape),
        tfb.Permute(permutation),
    ])
bijectors.append(tfb.Reshape(event_shape_out=(size, size), event_shape_in=(size * size,)))

# tfb.Chain applies its list right-to-left in the forward (sampling) direction, so reversing
# the list makes the first-appended bijector run first and the Reshape run last.
bijector = tfb.Chain(bijectors=bijectors[::-1], name='chain_of_real_nvp')

flow = tfd.TransformedDistribution(distribution=base_dist, bijector=bijector)

# number of trainable variables
n_trainable_variables = len(flow.trainable_variables)

In [8]:
# Keras optimizers evaluate a LearningRateSchedule at `opt.iterations`, which advances once per
# applied gradient update (batch), not once per epoch. decay_steps is therefore expressed in
# batches: max_epochs * batches-per-epoch, so the rate decays from base_lr to end_lr over the whole
# training run instead of reaching end_lr after only `max_epochs` batches.
# NOTE(review): assumes train_density_estimation applies exactly one optimizer update per batch.
steps_per_epoch = max(1, sum(1 for _ in train_data))  # one pass to count batches; max() avoids decay_steps=0
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(
    base_lr, max_epochs * steps_per_epoch, end_lr, power=0.5)

# Temporary checkpoint location for the best (lowest validation loss) model; removed after training.
checkpoint_directory = "{}/tmp_{}_{}_{}_{}_{}".format(dataset, layers, shape[0], shape[1], exp_number, category)
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

opt = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn)
checkpoint = tf.train.Checkpoint(optimizer=opt, model=flow)

In [None]:
# Train with early stopping on the validation NLL. The best model (lowest validation loss) is
# checkpointed, restored for test evaluation and sampling, and the checkpoint is deleted at the end.
global_step = []
train_losses = []
val_losses = []
min_val_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)  # high value to ensure that first loss < min_loss
min_train_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)
min_val_epoch = 0
min_train_epoch = 0
delta_stop = 50  # early stopping: stop after this many epochs without a new best validation loss
val_every = 1  # run validation every `val_every` epochs (epoch 0 always validates)

t_start = time.time()  # start time

# start training
for i in range(max_epochs):

    # No per-epoch `train_data.shuffle(...)` here: tf.data datasets are immutable and shuffle()
    # returns a new dataset, so an unassigned call is a no-op. The loader was called with
    # shuffle=True; presumably it reshuffles every epoch (confirm in dataset_loader).
    batch_train_losses = []
    for batch in train_data:
        batch_loss = train_density_estimation(flow, opt, batch)
        batch_train_losses.append(batch_loss)

    train_loss = tf.reduce_mean(batch_train_losses)

    if i % val_every == 0:
        batch_val_losses = []
        for batch in val_data:
            batch_loss = nll(flow, batch)
            batch_val_losses.append(batch_loss)

        val_loss = tf.reduce_mean(batch_val_losses)

        global_step.append(i)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"{i}, train_loss: {train_loss}, val_loss: {val_loss}")

        if train_loss < min_train_loss:
            min_train_loss = train_loss
            min_train_epoch = i

        if val_loss < min_val_loss:
            min_val_loss = val_loss
            min_val_epoch = i
            checkpoint.write(file_prefix=checkpoint_prefix)

        elif i - min_val_epoch > delta_stop:  # no decrease in min_val_loss for "delta_stop epochs"
            break

train_time = time.time() - t_start

# load best model with min validation loss (the last checkpoint written, at min_val_epoch)
checkpoint.restore(checkpoint_prefix)

# perform on test dataset
t_start = time.time()

test_losses = []
for batch in test_data:
    batch_loss = nll(flow, batch)
    test_losses.append(batch_loss)

test_loss = tf.reduce_mean(test_losses)

test_time = time.time() - t_start

# Draw n_images samples from the restored flow, map them through inverse_logit and save them as images.
save_dir = "{}/sampling_{}_{}_{}_{}/".format(dataset, layers, shape[0], shape[1], category)
os.makedirs(save_dir, exist_ok=True)  # unlike os.mkdir, also creates a missing `dataset` parent dir
for j in range(n_images):
    plt.figure()
    data = flow.sample(1)
    data = dataset_loader.inverse_logit(data)
    data = tf.reshape(data, (1, size, size))
    plt.imshow(data[0], cmap='gray')
    plt.savefig(os.path.join(save_dir, "{}_{}_i{}.png".format(exp_number, min_val_epoch, j)))
    plt.close()

# Remove the checkpoint files, then the checkpoint directory. os.removedirs also prunes empty
# parent directories; the `dataset` directory is non-empty here because save_dir lives inside it.
for f in os.listdir(checkpoint_directory):
    os.remove(os.path.join(checkpoint_directory, f))
os.removedirs(checkpoint_directory)

# Test metrics come from the restored best checkpoint, i.e. from epoch min_val_epoch (not the last epoch i).
print(f'Test loss: {test_loss} at epoch: {min_val_epoch}')
print(f'Average test log likelihood: {-test_loss} at epoch: {min_val_epoch}')
print(f'Min val loss: {min_val_loss} at epoch: {min_val_epoch}')
print(f'Last val loss: {val_loss} at epoch: {i}')
print(f'Min train loss: {min_train_loss} at epoch: {min_train_epoch}')
print(f'Last train loss: {train_loss} at epoch: {i}')
print(f'Training time: {train_time}')
print(f'Test time: {test_time}')

results = {
    'test_loss': float(test_loss),
    'avg_test_logll': float(-test_loss),
    'min_val_loss': float(min_val_loss),
    'min_val_epoch': min_val_epoch,
    'val_loss': float(val_loss),
    'min_train_loss': float(min_train_loss),
    'min_train_epoch': min_train_epoch,
    'train_loss': float(train_loss),
    'train_time': train_time,
    'test_time': test_time,
    'trained_epochs': i + 1,  # i is the 0-based index of the last epoch run, so i + 1 epochs were trained
    'trainable variables': n_trainable_variables,
    'exp_number': exp_number
}