In [None]:
# 1st part: implement coupling layer
# - split input
# - implement y1 = x1, y2 = x2 + m(y1)   (forward)
# - implement backward


# the most direct split: cut the last (feature) axis into two contiguous halves
x1, x2 = tf.split(value=x, num_or_size_splits=2, axis=-1)
y1, y2 = whatever_the_coupling_layer_does(x1, x2)
y = tf.concat(values=[y1, y2], axis=-1)

# the drawback: the cut goes right through the middle of the data.
# for data with spatial structure (e.g. images) the model then has to
# condition, say, the upper half of the image on the lower half, and vice versa.
# long-range dependencies like that are hard to model.

# another option could be an even-odd split:
def split_even_odd(inp, even_fn=None, odd_fn=None):
    """Split `inp` along axis 1 into even/odd entries, process them, re-interleave.

    Entries at even positions (0, 2, 4, ...) and odd positions (1, 3, 5, ...)
    of axis 1 are separated, optionally transformed, and then woven back
    together so that every entry returns to the position it came from.

    Args:
        inp: tensor of rank >= 2; axis 0 is kept as-is, axis 1 is the one
            being split.
        even_fn: optional callable applied to the even part. Must return a
            tensor of the same shape as its input. Defaults to identity.
        odd_fn: optional callable applied to the odd part. Must return a
            tensor of the same shape as its input. Defaults to identity.

    Returns:
        A tensor with the same shape as `inp`. With the default (identity)
        transforms the result equals `inp`.
    """
    # strided slicing instead of gather(range(...)): same selection, but it
    # does not need the size of axis 1 to be statically known.
    even = inp[:, 0::2]
    odd = inp[:, 1::2]

    # process even and odd parts separately (identity if no transform given)
    even_output = even if even_fn is None else even_fn(even)
    odd_output = odd if odd_fn is None else odd_fn(odd)

    # combine
    full_shape = tf.shape(inp)
    n_pairs = full_shape[1] // 2  # number of complete (even, odd) pairs

    # stacking right behind axis 1 gives (batch, n_pairs, 2, ...); merging the
    # (n_pairs, 2) axes then yields the order e0, o0, e1, o1, ...
    paired = tf.stack([even_output[:, :n_pairs], odd_output], axis=2)
    interleaved_shape = tf.concat(
        [full_shape[:1], (full_shape[1:2] // 2) * 2, full_shape[2:]], axis=0)
    interleaved = tf.reshape(paired, interleaved_shape)

    # if axis 1 has odd size, the even part has one entry without a partner;
    # it belongs at the very end. for even sizes this slice is empty.
    return tf.concat([interleaved, even_output[:, n_pairs:]], axis=1)

In [None]:
import tensorflow_probability as tfp
tfd = tfp.distributions
import tensorflow as tf

In [None]:
# 2nd part: NICE model
# - stack a bunch of coupling layers
# - switch x1, x2 for each layer
# - forward: apply coupling layers
# - backward: apply backward coupling layers (in reverse)

# - rescaling:
# - create d-dimensional vector (model weight)
# - forward: multiply x * exp(vector) after applying coupling layers
# - backwards: x / exp(vector) (or x * tf.exp(-vector)) BEFORE applying reverse coupling layers


# - training: map x -> h using NICE model
# - compute log_p_simple(h)
# - add log determinant of jacobian: simply sum of scaling values
# use -log_likelihood as loss

# a note on simple (base) distributions.
# build them from tfd = tfp.distributions...
input_dim = 12  # example
batch_size = 8
dummy_data = tf.random.normal(shape=(batch_size, input_dim))

# option 1: a per-dimension (factorized) normal.
# log_prob returns a batch x dim matrix, so SUM over axis 1 (the data dimension)
# to get one value per sample. only the batch axis should be averaged!
simple_distribution = tfd.Normal(loc=tf.zeros((input_dim,)), scale=tf.ones((input_dim,)))
log_p_simple = tf.reduce_sum(simple_distribution.log_prob(dummy_data), axis=1)
print(log_p_simple)

# option 2: a multivariate distribution.
# here log_prob already returns one value per batch element -- no manual sum.
simple_distribution = tfd.MultivariateNormalDiag(loc=tf.zeros((input_dim,)))
print(simple_distribution.log_prob(dummy_data))

print("Same results!")

# just don't mix it up!!! I had a bug where I was using multivariate, but summing over the last dimension.
# this resulted in summing my loss over the batch axis instead of averaging.
# that's bad, because your effective learning rate is MUCH higher than expected.

# the paper proposes using tfd.Logistic instead. that does not have a multivariate version AFAIK,
# so you will have to use the first option.

In [None]:
# last but not least: a small toy dataset for testing flow models.
# it is a 2D point cloud shaped roughly like a parabola.
# even a simple model like NICE should be able to fit this!
n_samples = 2048

# x2 is drawn from a zero-mean normal with scale 0.5
x2_dist = tfd.Normal(loc=0., scale=0.5)
x2_samples = x2_dist.sample(n_samples)

# x1 is normal around x2^2 with scale 0.1, which produces the parabola shape
x1 = tfd.Normal(loc=tf.square(x2_samples),
                scale=tf.ones(n_samples, dtype=tf.float32) * 0.1)
x1_samples = x1.sample()
x_samples = tf.stack(values=[x1_samples, x2_samples], axis=1)

# NOTE(review): `plt` is not imported in the cells visible here -- presumably
# `import matplotlib.pyplot as plt` has to run before this cell; confirm.
as_np = x_samples.numpy()
plt.scatter(x=as_np[:, 0], y=as_np[:, 1])
a = plt.gca()
plt.show()


# x_samples is the dataset!
# NOTE that I say in the assignment that LayerNorm > BatchNorm for these models apparently.
# but for these simple models with very low data dimensionality, it appears layernorm causes issues sometimes.
# so maybe leave normalization out completely.