In [None]:
"""
In this notebook, we consider a simple example of normalizing flows, based on the free-form formulation of

Draxler, F., Sorrenson, P., Zimmermann, L., Rousselot, A., & Köthe, U. (2024, April). 
Free-form flows: Make any architecture a normalizing flow. 
In International Conference on Artificial Intelligence and Statistics (pp. 2197-2205). PMLR.

As documented in the reference, the advantage of this formulation is that it avoids the need for explicitly invertible
neural network architectures. Instead, (approximate) invertibility is encouraged through a penalty that rewards accurate reconstruction.

As a learning exercise, the following code uses only core JAX (no higher-level neural-network or optimizer libraries), similar to

https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html

"""

In [222]:
import time
from typing import Callable
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpy as np

In [223]:
# Show which devices JAX will run on. A bare `jax.devices()` that is not the last
# expression of the cell is evaluated and then silently discarded.
SEED = 0  # root seed for every JAX random draw in this notebook
print("JAX devices:", jax.devices())
key = jax.random.key(SEED)

In [224]:
# settings — everything a reader might tune
# data
DATA_SIZE = 10000        # number of training points passed to generate_dataset
# optimization (plain SGD, see `update`)
NUMBER_OF_EPOCHS = 1200  # full passes over the training data
BATCH_SIZE = 256         # samples per SGD step
LEARNING_RATE = 1e-3     # initial step size; decayed after every epoch in the training loop
# loss
BETA = 20                # weight of the reconstruction penalty in the FFF loss

In [225]:
# dataset

def generate_dataset(n: int = 1000, r_inner = 0.5, r_outer = 1.5, rng = None):
    """Sample ``n`` noisy 2-D points around a closed curve with oscillating radius.

    Each point has angle ``theta ~ U[0, 2*pi)`` and radius
    ``r(theta) = r_inner + (r_outer - r_inner) / 2 * (1 + sin(2 * theta))``,
    so the clean radius lies in ``[r_inner, r_outer]``. Both coordinates are then
    scaled by the same factor ``1 + noise`` with ``noise ~ U[-0.1, 0.1)``. This
    perturbs each point radially.

    Args:
        n: number of points.
        r_inner, r_outer: minimum / maximum clean radius.
        rng: optional ``numpy.random.Generator`` for reproducible sampling. When
            ``None`` (default), NumPy's legacy global RNG (``np.random.rand``) is
            used, exactly as before, so ``np.random.seed`` still controls it.

    Returns:
        NumPy array of shape ``(n, 2)`` with columns ``(x, y)``.
    """
    # Both APIs take a size argument and return U[0, 1) samples.
    uniform = np.random.rand if rng is None else rng.random

    theta = uniform(n) * 2 * np.pi
    r = r_inner + (r_outer - r_inner) / 2 * (1 + np.sin(2 * theta))
    noise = 0.1 * (2.0 * uniform(n) - 1.0)
    x = r * np.cos(theta) * (1 + noise)
    y = r * np.sin(theta) * (1 + noise)

    return np.stack([x, y], axis=1)


# generate_dataset samples from NumPy's global RNG. Seed it so the train/valid data
# are identical after "Restart Kernel -> Run All" (JAX randomness uses `key` separately).
DATA_SEED = 0
VALID_SIZE = 100  # number of held-out validation points
np.random.seed(DATA_SEED)
train_data = jnp.array(generate_dataset(n=DATA_SIZE))
valid_data = jnp.array(generate_dataset(n=VALID_SIZE))

In [None]:
# Visual check of the training distribution; equal aspect so the curve's shape is not distorted.
fig, ax = plt.subplots()
ax.scatter(train_data[:, 0], train_data[:, 1], marker='.')
ax.set(title="Training data", xlabel="x", ylabel="y", aspect="equal")
plt.show()

In [227]:
# def random_layer_params(
#         key: jax.random.key,
#         in_dim: int,
#         out_dim: int,
#         b_scale: float = 0.0,
# ) -> tuple[jnp.array]:
    
#     w_key, b_key = jax.random.split(key, 2)
#     return  (
#       1/jnp.sqrt(in_dim) * jax.random.uniform(
#            w_key,
#            shape=(in_dim, out_dim),
#            minval=-1.0,
#            maxval=1.0,
#         ),
#       b_scale * jax.random.uniform(
#            b_key,
#            shape=(out_dim,),
#            minval=-1.0,
#            maxval=1.0,
#         ),
#     )

def random_layer_params(
        key: jax.Array,
        in_dim: int,
        out_dim: int,
        b_scale: float = 0.0,
        gain: float = 2.0,
) -> tuple[jax.Array, jax.Array]:
    """Randomly initialize one dense layer ``(w, b)``.

    Weights are ``sqrt(gain / in_dim) * N(0, 1)``. The default ``gain=2`` is
    He/Kaiming scaling, derived for ReLU. ``gain=1`` gives LeCun scaling, which is
    commonly recommended for tanh hidden layers. Biases are ``b_scale * N(0, 1)``,
    so the default ``b_scale=0`` yields all-zero biases.

    Args:
        key: JAX PRNG key. It is split into independent weight and bias keys.
        in_dim, out_dim: layer fan-in / fan-out.
        b_scale: standard-deviation multiplier for the bias draw.
        gain: variance gain of the weight initialization.

    Returns:
        ``(w, b)`` with ``w.shape == (in_dim, out_dim)`` and ``b.shape == (out_dim,)``.
    """
    w_key, b_key = jax.random.split(key, 2)
    w = jnp.sqrt(gain / in_dim) * jax.random.normal(w_key, shape=(in_dim, out_dim))
    b = b_scale * jax.random.normal(b_key, shape=(out_dim,))
    return w, b

# Initialize all layers for a fully-connected neural network with sizes `layer_sizes`.
def init_network_params(
    key: jax.Array, layer_sizes: list[int], b_scale: float = 0.0
) -> list[tuple[jax.Array, jax.Array]]:
    """Initialize a dense network as a list of ``(w, b)`` pairs, one per layer.

    Layer ``i`` maps ``layer_sizes[i]`` -> ``layer_sizes[i + 1]`` and is
    initialized by ``random_layer_params`` with the given ``b_scale``.

    Note: ``key`` is split into ``len(layer_sizes)`` subkeys, but there are only
    ``len(layer_sizes) - 1`` layers, so the last subkey is unused (``zip`` stops
    early). The split count is kept as-is because changing it may change the
    generated keys and hence the initialization of existing runs.

    Returns:
        List of length ``len(layer_sizes) - 1``; empty when fewer than two sizes are given.
    """
    keys = jax.random.split(key, len(layer_sizes))
    return [
        random_layer_params(layer_key, in_dim, out_dim, b_scale)
        for in_dim, out_dim, layer_key in zip(layer_sizes[:-1], layer_sizes[1:], keys)
    ]

# Encoder and decoder share one architecture. The first and last widths must both equal
# the data dimension (2), because predict_single adds its input back onto the output.
layer_sizes = [2, 32, 64, 64, 32, 2]
b_scale = 1e-3  # standard-deviation multiplier for the random bias initialization

# Carry `key` forward and draw one independent subkey per network (encoder first).
key, enc_key, dec_key = jax.random.split(key, 3)
encoder_params, decoder_params = (
    init_network_params(net_key, layer_sizes, b_scale) for net_key in (enc_key, dec_key)
)

In [228]:

# model prediction
def predict_single(
    params: list[tuple[jax.Array, jax.Array]],
    x: jax.Array,
    activation: Callable[[jax.Array], jax.Array] = jnp.tanh,
) -> jax.Array:
    """Evaluate the residual MLP ``x + MLP(x)`` on an input.

    Every layer except the last is affine followed by ``activation``. The last layer
    is affine only, and its output is added to the input ``x``. The residual is
    global (input to output), not per layer: empirically, layer-wise residual
    connections (e.g. ``h = h + activation(h)``) made the loss blow up to NaNs.
    Trainability with plain SGD is quite sensitive to the network characteristics
    and the network/loss hyperparameters.

    Args:
        params: non-empty list of ``(w, b)`` pairs; ``w`` is ``(in_dim, out_dim)``.
        x: input whose last axis matches the first layer's ``in_dim``. Elsewhere
            in this notebook it is a single ``(d,)`` sample (batched via ``jax.vmap``).
        activation: elementwise hidden-layer nonlinearity. Defaults to ``jnp.tanh``,
            the previously hard-coded choice; e.g. ``jax.nn.relu`` also works.

    Returns:
        ``x + MLP(x)``. This requires the last layer's ``out_dim`` to equal ``x``'s last dimension.

    Raises:
        ValueError: if ``params`` is empty.
    """
    if not params:
        raise ValueError("predict_single needs at least one (w, b) layer")

    h = x
    for w, b in params[:-1]:
        h = activation(jnp.dot(h, w) + b)

    final_w, final_b = params[-1]
    return x + jnp.dot(h, final_w) + final_b

In [229]:
# loss

def fff_components_single(
    key: jax.Array,
    encoder_params: list[tuple[jax.Array, jax.Array]],
    decoder_params: list[tuple[jax.Array, jax.Array]],
    x: jax.Array,
    hutchinson_samples: int = 1,
) -> tuple[jax.Array, jax.Array]:
    """
    Free-form flow (fff) loss components for ONE sample, as defined in Algorithm 1
    of the primary reference and derived in Appendix A.2.

    Jacobians are taken with respect to the encoder/decoder *inputs*, not their
    parameters. The function is written for a single ``(d,)`` sample and is
    batched with ``jax.vmap``.

    Args:
        key: PRNG key for the Hutchinson test vector(s).
        encoder_params, decoder_params: ``(w, b)`` lists for ``predict_single``.
        x: one sample of shape ``(d,)``; the encoder and decoder both map ``d -> d``.
        hutchinson_samples: number of test vectors averaged in the estimator
            (must be >= 1). The default of 1 follows the reference and uses ``key``
            directly, reproducing the single-vector computation exactly.

    Returns:
        ``(nll, L_reconstr)``, both scalars:
        - ``nll = 0.5 * ||z||^2 - surrogate``, with ``z = f(x)`` and
          ``surrogate = (v^T f'(x)) . stop_gradient(g'(z) v)``. The surrogate is the
          reference's *gradient* surrogate for the log-determinant; its value is not
          ``log|det f'(x)|``, so ``nll`` is a training signal, not a true NLL value.
        - ``L_reconstr = ||g(f(x)) - x||^2``.

    Raises:
        ValueError: if ``hutchinson_samples < 1``.
    """
    if hutchinson_samples < 1:
        raise ValueError(f"hutchinson_samples must be >= 1, got {hutchinson_samples}")

    def encoder_fn(x1: jax.Array) -> jax.Array:
        return predict_single(encoder_params, x1)

    def decoder_fn(z1: jax.Array) -> jax.Array:
        return predict_single(decoder_params, z1)

    # Encoder forward pass, keeping the VJP closure so that v^T f'(x) can be formed per test vector.
    z, encoder_vjp = jax.vjp(encoder_fn, x)

    def reconstruction_and_surrogate(k: jax.Array) -> tuple[jax.Array, jax.Array]:
        # Hutchinson trace estimation works better with test vectors v on the
        # sqrt(d)-sphere than with raw Gaussians:
        # https://www.ethanepperly.com/index.php/2024/01/28/dont-use-gaussians-in-stochastic-trace-estimation/
        v = jax.random.normal(k, shape=x.shape)
        v *= jnp.sqrt(v.shape[-1]) / jnp.sqrt(jnp.square(v).sum(axis=-1, keepdims=True))

        # v^T f'(x) via reverse mode; gradients flow through this factor.
        vjp_encoder = encoder_vjp(v)[0]
        # g(z) and g'(z) v via forward mode through the decoder.
        xr, jvp_decoder = jax.jvp(decoder_fn, (z,), (v,))

        # vjp_encoder, xr and jvp_decoder all have the shape of x. The decoder factor
        # is treated as a constant, as in the reference.
        surrogate = (jax.lax.stop_gradient(jvp_decoder) * vjp_encoder).sum(axis=-1)
        return xr, surrogate

    if hutchinson_samples == 1:
        xr, surrogate = reconstruction_and_surrogate(key)
    else:
        estimates = [
            reconstruction_and_surrogate(k)
            for k in jax.random.split(key, hutchinson_samples)
        ]
        xr = estimates[0][0]  # g(f(x)) does not depend on v
        surrogate = sum(s for _, s in estimates) / hutchinson_samples

    nll = 0.5 * jnp.square(z).sum(axis=-1) - surrogate
    L_reconstr = jnp.square(xr - x).sum(axis=-1)

    return nll, L_reconstr


# Vectorize the per-sample loss components: per-sample keys (axis 0) and samples
# (axis 0) are mapped, while both parameter lists are shared (None).
batch_fff_components = jax.vmap(fff_components_single, in_axes=(0, None, None, 0))

def batch_loss(
    key: jax.Array,
    encoder_params: list[tuple[jax.Array, jax.Array]],
    decoder_params: list[tuple[jax.Array, jax.Array]],
    x: jax.Array,
    beta=None,
):
    """Mean free-form-flow training loss ``nll + beta * L_reconstr`` over a batch.

    Args:
        key: PRNG key, split into one key per row of ``x``.
        encoder_params, decoder_params: ``(w, b)`` lists for the two networks.
        x: batch of samples, one per row (mapped over axis 0).
        beta: weight of the reconstruction penalty. ``None`` (default) uses the
            notebook-level ``BETA``, read at call (or trace) time as before.

    Returns:
        Scalar mean loss over the batch.
    """
    if beta is None:
        beta = BETA
    keys = jax.random.split(key, x.shape[0])
    nll, L_reconstr = batch_fff_components(keys, encoder_params, decoder_params, x)
    return (nll + beta * L_reconstr).mean()

In [230]:

# class AdamW():
#     def __init__(self)


@jax.jit
def update(
    key: jax.random.key,
    encoder_params: list[jnp.array],
    decoder_params: list[jnp.array],
    x: jnp.array,
    lr: float,
) -> tuple[list]:
    """Take one plain-SGD step on both networks.

    ``batch_loss`` is differentiated w.r.t. the encoder and decoder parameter lists
    (argnums 1 and 2). Every leaf moves by ``-lr * grad``. The result is
    ``(new_encoder_params, new_decoder_params)``, each still a list of ``(w, b)``
    tuples. Under ``jax.jit`` a Python-float ``lr`` is traced, so decaying it
    between calls does not force a recompile.
    """
    enc_grads, dec_grads = jax.grad(batch_loss, argnums=(1, 2))(
        key, encoder_params, decoder_params, x
    )

    def sgd_step(param, grad):
        return param - lr * grad

    new_encoder = jax.tree_util.tree_map(sgd_step, encoder_params, enc_grads)
    new_decoder = jax.tree_util.tree_map(sgd_step, decoder_params, dec_grads)
    return new_encoder, new_decoder

          

In [None]:
# training loop

def batch_iterate(key: jax.Array, batch_size: int, x: jax.Array, drop_last: bool = False):
    """Yield the rows of ``x`` in shuffled mini-batches.

    A fresh permutation of the rows is drawn from ``key``. Consecutive chunks of
    ``batch_size`` rows are then yielded, each shaped ``(<=batch_size, *x.shape[1:])``.

    Args:
        key: JAX PRNG key for the shuffle.
        batch_size: rows per batch (must be >= 1).
        x: data array; rows (axis 0) are the samples.
        drop_last: if True, skip the final partial batch. This avoids a tiny, noisy
            gradient step and an extra jit compilation for the odd batch shape.
            The default (False) keeps it, as before.

    Raises:
        ValueError: if ``batch_size < 1`` (raised on the first ``next()``).
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    n = x.shape[0]
    perm = jax.random.permutation(key, n)
    stop = n - n % batch_size if drop_last else n
    for start in range(0, stop, batch_size):
        yield x[perm[start : start + batch_size]]

losses = []  # per-epoch (train_loss, valid_loss) as Python floats, e.g. for plotting
lr = LEARNING_RATE
scale = 0.99  # multiplicative learning-rate decay applied after every epoch
LOG_EVERY = 25  # print a progress line every LOG_EVERY epochs (and after the last one)


def _evaluate_fff(key, encoder_params, decoder_params, data):
    """Evaluate the mean FFF loss components on every row of ``data``.

    ``key`` is split into one carry-over key plus one key per row. Returns
    ``(carry_key, mean_nll, mean_reconstr, mean_nll + BETA * mean_reconstr)``.
    """
    subkeys = jax.random.split(key, data.shape[0] + 1)
    nll, reconstr = batch_fff_components(subkeys[1:], encoder_params, decoder_params, data)
    nll = nll.mean()
    reconstr = reconstr.mean()
    return subkeys[0], nll, reconstr, nll + BETA * reconstr


for epoch in range(NUMBER_OF_EPOCHS):
    key, subkey1 = jax.random.split(key, 2)
    start_time = time.time()
    for xb in batch_iterate(subkey1, BATCH_SIZE, train_data):
        key, subkey2 = jax.random.split(key, 2)
        encoder_params, decoder_params = update(
            subkey2,
            encoder_params,
            decoder_params,
            xb,
            lr,
        )
    lr *= scale

    key, train_nll, train_reconstr, train_loss = _evaluate_fff(
        key, encoder_params, decoder_params, train_data
    )
    key, valid_nll, valid_reconstr, valid_loss = _evaluate_fff(
        key, encoder_params, decoder_params, valid_data
    )
    # JAX dispatches work asynchronously. Block until the results exist so that
    # epoch_time measures the actual compute, not just the time to enqueue it.
    train_loss.block_until_ready()
    valid_loss.block_until_ready()
    epoch_time = time.time() - start_time

    losses.append((train_loss.item(), valid_loss.item()))
    if epoch % LOG_EVERY == 0 or epoch == NUMBER_OF_EPOCHS - 1:
        print(f"Epoch {epoch} | lr: {lr:.2e} | Train nll, reconstr, loss: {train_nll.item():.3f}, {train_reconstr.item():.3f}, {train_loss.item():.3f} | Valid loss {valid_loss.item():.3f} | Epoch time {epoch_time:.4f}s")

In [232]:
# Batched model evaluation: `params` is shared (None) and inputs are mapped over axis 0.
batch_predict = jax.vmap(predict_single, in_axes=(None, 0), out_axes=0)

In [None]:
# generation plot
num_samples = 1000
key, subkey = jax.random.split(key, 2)
z = jax.random.normal(subkey, shape=(num_samples, 2))
x_gen = batch_predict(decoder_params, z)

plt.scatter(x_gen[:, 0], x_gen[:, 1], c='blue', marker='.', label='decoded/generated')
plt.scatter(z[:, 0], z[:, 1], marker='o', facecolors='none', edgecolors='black', alpha=0.2, label='latents')
plt.scatter(train_data[:, 0], train_data[:, 1], c='red', marker='.', alpha = 1e-2, label='data')
leg = plt.legend()
for lhandle in leg.legend_handles: 
    lhandle.set_alpha(1)