In [None]:
import os
import sys
import time
import random

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_probability as tfp
tfb = tfp.bijectors
tfd = tfp.distributions
import matplotlib.pyplot as plt

from normalizingflows.flow_catalog import RealNVP
from utils.train_utils import train_density_estimation, nll
from data import dataset_loader

# Seed both TensorFlow and NumPy: TensorFlow drives sampling from the flow,
# NumPy's global RNG picks which image of the first batch is previewed.
# Python's `random` module is intentionally NOT seeded, because it generates
# the per-run suffix of the checkpoint directory, which should stay unique.
tf.random.set_seed(1234)
np.random.seed(1234)


In [None]:
# Experiment configuration: every value a reader might want to tune.
batch_size = 64
dataset = "celeb_a"
layers = 8  # number of (BatchNormalization, RealNVP, Permute) groups in the flow
base_lr = 1e-3  # initial learning rate of the decay schedule
end_lr = 1e-4  # final learning rate of the decay schedule
max_epochs = 100
shape = [128, 128]  # passed as `n_hidden` to every RealNVP layer
exp_number = 1  # only used to label the exported sample images
# Number of examples in the TFDS celeb_a *train* split. The earlier value,
# 202599, is the size of the whole dataset (train + validation + test).
celeb_trainsize = 162770

In [None]:
# Fetch all CelebA splits from TFDS, already grouped into batches of
# `batch_size`, then bind each split to its own name.
celeb_dataset = tfds.load(name="celeb_a", batch_size=batch_size, shuffle_files=True)
batched_train_data, batched_val_data, batched_test_data = (
    celeb_dataset[split_name] for split_name in ("train", "validation", "test")
)

In [None]:
# Pull one batch to preview an image and to derive the event shapes.
# assumes batch size first
sample_batch = next(iter(batched_train_data))

# get one random image of the batch, display it and save it as ground truth
plt.imshow(sample_batch["image"][int(np.random.rand()*batch_size)])
preview_path = "celeb_a/gt_{}.png".format(3)
# savefig does not create missing directories, so make sure the target exists
os.makedirs(os.path.dirname(preview_path), exist_ok=True)
plt.savefig(preview_path)

# get shapes: per-image shape (batch axis dropped) and its flattened length
celeb_shape = sample_batch["image"].shape[1:]
input_shape = celeb_shape[0] * celeb_shape[1] * celeb_shape[2]

# Permutation that swaps the two halves of the flattened vector. Integer
# division keeps every index an exact int, so the result is a valid
# permutation even when `input_shape` is odd (float halves would not be).
half = input_shape // 2
permutation = tf.cast(np.concatenate((np.arange(half, input_shape), np.arange(0, half))), tf.int32)

# Base distribution of the flow: diagonal Gaussian centred at zero.
base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros(shape=input_shape, dtype=tf.float32))

In [None]:
alpha = 1e-3

# Assemble the flow in "data flows through" order: `layers` repetitions of
# BatchNormalization -> RealNVP -> Permute, followed by a Reshape from the
# flat vector back to the image shape.
bijectors = []
for _ in range(layers):
    bijectors.extend([
        tfb.BatchNormalization(),
        RealNVP(input_shape=input_shape, n_hidden=shape),
        tfb.Permute(permutation),
    ])

to_image_shape = tfb.Reshape(event_shape_out=(celeb_shape),
                             event_shape_in=(input_shape,))
bijectors.append(to_image_shape)

# tfb.Chain applies the LAST list entry first, hence the reversal.
bijector = tfb.Chain(bijectors=bijectors[::-1], name='chain_of_real_nvp')

flow = tfd.TransformedDistribution(distribution=base_dist, bijector=bijector)

# number of trainable variables
n_trainable_variables = len(flow.trainable_variables)

In [None]:
def show_image(flow, save_dir=None):
    """Draw one sample from ``flow`` and display it in a new figure.

    The sample is passed through a sigmoid before plotting.

    Args:
        flow: object exposing ``sample(n)``; the first element of
            ``flow.sample(1)`` is what gets plotted.
        save_dir: optional path stem. When not ``None`` the figure is also
            written to ``save_dir + ".png"``.
    """
    plt.figure()
    sample = tf.sigmoid(flow.sample(1))
    plt.imshow(sample[0])
    if save_dir is None:
        return
    plt.savefig(save_dir + ".png", format="png")

In [None]:
# PolynomialDecay is evaluated on the optimizer's step counter, which advances
# once per applied update, i.e. once per batch, not once per epoch. The decay
# horizon is therefore expressed in batches so the learning rate reaches
# `end_lr` at the end of training rather than after `max_epochs` batches.
# NOTE(review): assumes train_density_estimation applies exactly one optimizer
# update per call and that `celeb_trainsize` is the number of training examples.
steps_per_epoch = (celeb_trainsize + batch_size - 1) // batch_size  # ceil division
decay_steps = max_epochs * steps_per_epoch
learning_rate_fn = tf.keras.optimizers.schedules.PolynomialDecay(base_lr, decay_steps, end_lr, power=0.5)

# Random suffix keeps checkpoints of different runs from overwriting each other.
checkpoint_directory = "{}/tmp_{}".format(dataset, hex(random.getrandbits(32)))
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

opt = tf.keras.optimizers.Adam(learning_rate=learning_rate_fn)
checkpoint = tf.train.Checkpoint(optimizer=opt, model=flow)

In [None]:
# Train with early stopping on the validation loss, then evaluate the best
# checkpoint on the test split.
global_step = []
train_losses = []
val_losses = []
min_val_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)  # high value to ensure that first loss < min_loss
min_train_loss = tf.convert_to_tensor(np.inf, dtype=tf.float32)
min_val_epoch = 0
min_train_epoch = 0
delta_stop = 50  # threshold for early stopping
val_interval = 1  # run validation every `val_interval` epochs

# Dataset.shuffle returns a NEW dataset; calling it without using the result
# (as was done inside the epoch loop before) has no effect. Build the shuffled
# view once; reshuffle_each_iteration=True gives a new order every epoch.
# The dataset was loaded already batched, so the buffer holds whole batches and
# is kept small to bound memory use.
shuffle_buffer_batches = 64
shuffled_train_data = batched_train_data.shuffle(
    buffer_size=shuffle_buffer_batches, reshuffle_each_iteration=True)

t_start = time.time()  # start time

# start training
for i in range(max_epochs):

    batch_train_losses = []
    for batch in shuffled_train_data:
        batch_loss = train_density_estimation(flow, opt, dataset_loader.logit(tf.cast(batch["image"], tf.float32)))
        batch_train_losses.append(batch_loss)

    # epoch loss = mean of the per-batch losses
    train_loss = tf.reduce_mean(batch_train_losses)

    if i % val_interval == 0:
        batch_val_losses = []
        for batch in batched_val_data:
            batch_loss = nll(flow, dataset_loader.logit(tf.cast(batch["image"], tf.float32)))
            batch_val_losses.append(batch_loss)

        val_loss = tf.reduce_mean(batch_val_losses)

        global_step.append(i)
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print(f"{i}, train_loss: {train_loss}, val_loss: {val_loss}")

        if train_loss < min_train_loss:
            min_train_loss = train_loss
            min_train_epoch = i

        if val_loss < min_val_loss:
            # new best model: remember the epoch and overwrite the checkpoint
            min_val_loss = val_loss
            min_val_epoch = i
            checkpoint.write(file_prefix=checkpoint_prefix)

        elif i - min_val_epoch > delta_stop:  # no decrease in min_val_loss for "delta_stop epochs"
            break

train_time = time.time() - t_start

# load best model with min validation loss
checkpoint.restore(checkpoint_prefix)

# perform on test dataset
t_start = time.time()

test_losses = []
for batch in batched_test_data:
    batch_loss = nll(flow, dataset_loader.logit(tf.cast(batch["image"], tf.float32)))
    test_losses.append(batch_loss)

test_loss = tf.reduce_mean(test_losses)

test_time = time.time() - t_start


In [None]:
# Overlay the per-epoch training and validation loss curves.
loss_curves = {"train loss": train_losses, "val loss": val_losses}
for curve_label, curve_values in loss_curves.items():
    plt.plot(global_step, curve_values, label=curve_label)
plt.legend()

In [None]:
# Draw `n_images` samples from the trained flow and save each as a PNG.
n_images = 5

save_dir = "{}/sampling_{}_{}_{}/".format(dataset, layers, shape[0], shape[1])

# makedirs also creates a missing parent directory and, with exist_ok=True,
# does nothing when the directory is already there (os.mkdir fails on both).
os.makedirs(save_dir, exist_ok=True)
for j in range(n_images):
    plt.figure()
    data = flow.sample(1)
    data = tf.sigmoid(data)  # squash the sample before plotting
    plt.imshow(data[0])
    # file name encodes experiment number, best validation epoch and sample index
    plt.savefig(os.path.join(save_dir, "{}_{}_i{}.png".format(exp_number, min_val_epoch, j)))
    plt.close()  # release the figure so the loop does not accumulate open figures