In [None]:
import jax
import jax.numpy as jnp
from jax import random, jit, vmap
from jax.experimental import ode
from jax.scipy.stats import norm
from functools import partial

# Defining the model components

## Flow matching

In [None]:
def make_cond_flow(sigma_min):
    '''
    Build the conditional flow and the matching conditional vector field.

    Args:
        sigma_min: noise scale reached at t = 1.

    Returns:
        (cond_flow, cond_vec_field) where

        * cond_flow(x0, x1, t) is vmapped over the leading axis of all three
          arguments and returns (1 - (1 - sigma_min) * t) * x0 + t * x1;
        * cond_vec_field(x, x1, t) is NOT vmapped and returns
          (x1 - (1 - sigma_min) * x) / (1 - (1 - sigma_min) * t).
    '''
    # Rate at which the noise scale shrinks from 1 (t = 0) to sigma_min (t = 1).
    shrink = 1 - sigma_min

    def _noise_scale(t):
        return 1 - shrink*t

    def _flow_single(x0, x1, t):
        # Scaled noise plus a mean that moves linearly towards x1.
        return _noise_scale(t) * x0 + t*x1

    cond_flow = vmap(_flow_single, in_axes=(0, 0, 0), out_axes=0)

    def cond_vec_field(x, x1, t):
        return (x1 - shrink*x)/_noise_scale(t)

    return cond_flow, cond_vec_field

## NeuralODE 

In [None]:
def NeuralODE(vec_field_net, batch_size, dim):
    '''
    Build a sampler and a log-likelihood evaluator for a learned flow.

    Args:
        vec_field_net: callable (params, x, t) -> dx/dt for a single sample
            x of shape (dim,) and a scalar time t.
        batch_size: number of samples drawn per batched_sampler call.
        dim: dimensionality of one sample.

    Returns:
        (batched_sampler, logp) where

        * batched_sampler(rng, params) draws standard-normal noise of shape
          (batch_size, dim) and returns its image under the flow at t = 1;
        * logp(params, x) is vmapped over the leading axis of x and returns
          one log-density per row.
    '''
    # NOTE(review): these tolerances are below float32 resolution; presumably
    # they are only meaningful with jax_enable_x64 switched on -- confirm.
    rtol = atol = 1e-10

    def divergence_fwd(f):
        '''Return a function computing the divergence of f(params, ., t) at x.'''
        def _div_f(params, x, t):
            jac = jax.jacfwd(lambda x: f(params, x, t))
            return jnp.trace(jac(x))
        return _div_f

    def batched_sampler(rng, params):
        '''
        flow from x0 to x1
        '''
        x0 = random.normal(rng, (batch_size, dim))

        def _ode(x, t):
            # Same params and time for every sample; only x is batched.
            return vmap(vec_field_net, (None, 0, None), 0)(params, x, t)

        xt = ode.odeint(_ode,
                        x0,
                        jnp.array([0.0, 1.0]),
                        rtol=rtol, atol=atol
                        )
        return xt[-1]

    @partial(vmap, in_axes=(None, 0), out_axes=0)
    def logp(params, x):
        '''
        likelihood of given samples

        Integrates the flow backwards from the sample (t = 1) to the base
        distribution (t = 0) while accumulating the divergence, then applies
        the change-of-variables formula
        log p1(x) = log p0(x0) - integral of div v(x_t, t) dt.
        '''
        def base_logp(x0):
            return norm.logpdf(x0).sum(-1)

        div_v = divergence_fwd(vec_field_net)

        def _reverse_ode(state, s):
            # The solver variable s runs 0 -> 1 while the model time runs
            # backwards, so the time-dependent field must be evaluated at
            # t = 1 - s (evaluating it at s follows the wrong trajectory).
            y = state[0]
            t = 1.0 - s
            return -vec_field_net(params, y, t), div_v(params, y, t)

        ys, div_integral = ode.odeint(_reverse_ode,
                                      [x, jnp.zeros((), dtype=x.dtype)],
                                      jnp.array([0.0, 1.0]),
                                      rtol=rtol, atol=atol
                                      )
        return base_logp(ys[-1]) - div_integral[-1]

    return batched_sampler, logp

## Loss

In [None]:
def make_loss(vec_field_net, cond_vec_field):
    '''
    Build the conditional flow-matching loss.

    Args:
        vec_field_net: callable (params, x, t) -> predicted velocity for one
            sample.
        cond_vec_field: callable (x, x1, t) -> target conditional velocity
            for one sample.

    Returns:
        loss(params, x, x1, t): the mean over the batch (leading axis of x,
        x1 and t) of the squared L2 distance between prediction and target.
    '''
    def _squared_error(params, x, x1, t):
        residual = vec_field_net(params, x, t) - cond_vec_field(x, x1, t)
        # Squared norm, not the plain norm: its minimiser is the conditional
        # mean of the target field, and its gradient stays finite when the
        # residual is exactly zero (the plain norm yields NaN there).
        return jnp.sum(residual ** 2)

    per_sample = vmap(_squared_error, in_axes=(None, 0, 0, 0), out_axes=0)

    def loss(params, x, x1, t):
        return per_sample(params, x, x1, t).mean()

    return loss

# Loading datasets

In [None]:
import matplotlib.pyplot as plt
from sklearn import datasets, preprocessing
from IPython.display import clear_output

%matplotlib inline

In [None]:
n_samples = 100_000
plot_range = [(-2, 2), (-2, 2)]
n_bins = 100
# Fixed seed so that Restart & Run All regenerates exactly the same dataset.
data_seed = 42

scaler = preprocessing.StandardScaler()
# Keep the raw draw under its own name so re-running this cell is idempotent.
X_raw, _ = datasets.make_moons(n_samples=n_samples, noise=.05, random_state=data_seed)
# Standardise each coordinate with the fitted scaler.
X = scaler.fit_transform(X_raw)
# hist2d returns (counts, xedges, yedges, image); [-1] keeps the cell output
# down to the image handle instead of echoing the histogram arrays.
plt.hist2d(X[:, 0], X[:, 1], bins=n_bins, range=plot_range)[-1]

# Building networks

In [None]:
from jax.example_libraries.stax import serial, Dense, Relu, Softplus

# Root PRNG key; later cells split their own keys off from it.
rng = random.PRNGKey(42)

# Sample dimensionality, read from the dataset's second axis.
dim = X.shape[1]

# Noise scale of the conditional flow at t = 1.
sigma_min = 0.01

# Training schedule.
num_epochs = 400
batch_size = 1000

In [None]:
def net_in_ode(rng):
    '''
    Create the vector-field network and a wrapper that feeds it (x, t).

    Args:
        rng: PRNG key used to initialise the network parameters.

    Returns:
        (net_params, net_apply_with_t) where net_apply_with_t(params, x, t)
        concatenates t (reshaped to length 1) onto x and applies the network.
    '''
    layers = (Dense(8), Softplus, Dense(dim))
    init_fn, apply_fn = serial(*layers)

    # The network input is a sample with the time appended, hence dim + 1.
    _, net_params = init_fn(rng, (-1, dim+1))

    def net_apply_with_t(params, x, t):
        x_and_t = jnp.concatenate((x, t.reshape(1)))
        return apply_fn(params, x_and_t)

    return net_params, net_apply_with_t

In [None]:
from jax.example_libraries import optimizers

# Split off a key for parameter initialisation; `rng` carries on as the root key.
split_keys = random.split(rng)
init_rng, rng = split_keys[0], split_keys[1]

# Network parameters and the (params, x, t) -> velocity function.
params, vec_field_net = net_in_ode(init_rng)

# Conditional flow (batched) and its target vector field.
batched_cond_flow, cond_vec_field = make_cond_flow(sigma_min)

# Sampler and log-likelihood built on top of the network.
batched_sampler, batched_logp = NeuralODE(vec_field_net, batch_size, dim)

# Training objective and its value-and-gradient with respect to params.
loss = make_loss(vec_field_net, cond_vec_field)
value_and_grad = jax.value_and_grad(loss)

# Adam optimiser triple and the initial optimiser state.
adam_triple = optimizers.adam(step_size=1e-2)
opt_init, opt_update, get_params = adam_triple
opt_state = opt_init(params)

# Training step

In [None]:
@jit
def step(rng, i, opt_state, inputs):
    '''
    Run one optimiser update on a mini-batch.

    Args:
        rng: PRNG key; split into one key for the noise and one for the times.
        i: step counter forwarded to opt_update.
        opt_state: current optimiser state.
        inputs: mini-batch of data samples (presumably shape (n, dim) --
            confirm against the caller).

    Returns:
        (new_opt_state, loss_value) for this mini-batch.
    '''
    params = get_params(opt_state)

    # Size the random draws from the batch itself rather than the global
    # batch_size, so a shorter final batch does not break the vmapped calls.
    n_rng, u_rng = random.split(rng)
    t = random.uniform(u_rng, (inputs.shape[0],))
    x0 = random.normal(n_rng, inputs.shape)
    x = batched_cond_flow(x0, inputs, t)

    value, grad = value_and_grad(params, x, inputs, t)
    return opt_update(i, grad, opt_state), value

In [None]:
import itertools

# Global step counter handed to the optimiser.
itercount = itertools.count()

# Loss value of every mini-batch update, in order.
LOSS = []
for epoch in range(num_epochs):
    clear_output(wait=True) # Clear plotting

    permute_rng, step_rng, rng = random.split(rng, 3)
    # Shuffle into a new name so the dataset `X` itself is left untouched.
    X_shuffled = random.permutation(permute_rng, X)

    batch_starts = range(0, len(X_shuffled), batch_size)
    # One independent key per mini-batch; reusing a single key would give
    # every batch in the epoch the same noise and time draws.
    batch_rngs = random.split(step_rng, len(batch_starts))
    for batch_index, batch_rng in zip(batch_starts, batch_rngs):
        batch = X_shuffled[batch_index:batch_index+batch_size]
        opt_state, value = step(batch_rng, next(itercount), opt_state, batch)
        LOSS.append(value)
     
# Read the trained parameters back out of the optimiser state.
params = get_params(opt_state)
sample_rng, rng = random.split(rng)

# Draw one batch of synthetic samples from the trained flow.
X_syn = batched_sampler(sample_rng, params)

# Left panel: 2-D histogram of the samples; right panel: per-update loss.
fig, (ax_samples, ax_loss) = plt.subplots(1, 2, figsize=(12, 6))
ax_samples.hist2d(X_syn[:, 0], X_syn[:, 1], bins=n_bins, range=plot_range)
ax_loss.plot(LOSS)

plt.show()