Lightweight Conditional Flow Matching using JAX

Use the library by cloning the repository to your working folder:

git clone https://github.com/albcab/fmx.git
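There is no packaged release, so the import resolves against the cloned folder. Assuming you run Python from the working folder that contains the clone, a quick check that the module is found:

# Quick import check; fails if the cloned folder is not visible from the working directory.
from fmx.data import flow_matching
print(flow_matching)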

Supported algorithms

As a running example, learn the two moons 2D density, using haiku to build the neural network and optax to optimize its parameters:

import jax

from sklearn.datasets import make_moons
import haiku as hk
import optax

init_W = hk.initializers.VarianceScaling(0.1)
init_b = hk.initializers.RandomNormal(0.01)

def mlp_generator(out_dim, hidden_dims, non_linearity=jax.nn.relu):
    # Stack of Linear layers interleaved with the non-linearity, ending in a
    # Linear projection back to the output dimension.
    layers = []
    for hidden_dim in hidden_dims:
        layers.append(hk.Linear(hidden_dim, w_init=init_W, b_init=init_b))
        layers.append(non_linearity)
    layers.append(hk.Linear(out_dim, w_init=init_W, b_init=init_b))
    return hk.Sequential(layers)

def mlp_vector_field(time, sample):
    # Time-dependent vector field: concatenate the scalar time with the sample
    # and map through an MLP back to the sample's dimension.
    (out_dim,) = sample.shape
    mlp = mlp_generator(out_dim, hidden_dims=[256] * 3)
    inputs = jax.numpy.hstack([time, sample])
    return mlp(inputs)

vector_field = hk.without_apply_rng(hk.transform(mlp_vector_field))

data, _ = make_moons(4096, noise=0.05)
data = jax.numpy.array(data)

rng_key = jax.random.PRNGKey(0)
rng_key, key_init = jax.random.split(rng_key)
vector_field_params = vector_field.init(key_init, 0.0, data[0])  # initialize with a dummy time and one observation

optim = optax.adamw(1e-3)
optim_state = optim.init(vector_field_params)

epochs, batch_size = 4096, 256
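Before training, it can be useful to check that the transformed network evaluates on a single observation; this is only a sanity check using the objects defined above:

# The vector field maps (time, sample) to a velocity with the same shape as the sample.
velocity = vector_field.apply(vector_field_params, 0.5, data[0])
print(velocity.shape)  # (2,) for the two moons data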

Load and initialize the generating functions of the method:

from fmx.data import flow_matching

fmx = flow_matching(vector_field.apply, data)
"""Arguments:
    vector_field_apply: Callable that computes the vector field, with signature (parameters, time, sample) -> (vector field)
    samples: PyTree where each leaf has leading dimension (number_of_observations, ...)
    sigma: Standard deviation of the conditional probability path
    odeint: Callable that computes the numerical approximation of the probability path, with signature (vector field, initial condition) -> (flow)
"""

Learn the parameters of the vector field:

def step_fn(carry, key):
    # One optimization step: draw a batch, evaluate the flow matching loss and
    # its gradient, then apply an optimizer update.
    vector_field_params, optim_state = carry
    fm_loss = fmx.get_loss(key, batch_size)
    loss_value, grads = jax.value_and_grad(fm_loss)(vector_field_params)
    updates, optim_state = optim.update(grads, optim_state, vector_field_params)
    vector_field_params = optax.apply_updates(vector_field_params, updates)
    return (vector_field_params, optim_state), loss_value

keys = jax.random.split(rng_key, epochs)
(vector_field_params, optim_state), loss_values = jax.lax.scan(
    step_fn, (vector_field_params, optim_state), keys
)
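loss_values contains one loss per epoch; a simple convergence check is to compare the average loss at the start and end of training:

# The average loss over the last epochs should be noticeably lower than over the first.
print(loss_values[:100].mean(), loss_values[-100:].mean())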

Generate new samples given the learned parameters:

(key_samples,) = jax.random.split(keys[0], 1)
num_samples = 4096
samples = fmx.sample(key_samples, vector_field_params, num_samples)
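To inspect the result, the generated samples can be overlaid on the training data, assuming they come back with the same layout as the data (matplotlib is used here only for visualization; it is not a dependency of the library):

import matplotlib.pyplot as plt

# Overlay model samples on the two moons training data.
plt.scatter(data[:, 0], data[:, 1], s=2, alpha=0.5, label="data")
plt.scatter(samples[:, 0], samples[:, 1], s=2, alpha=0.5, label="model")
plt.legend()
plt.show()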

Evaluate the learned log density of the data:

logprob_fn = fmx.get_logprob(vector_field_params)
logprob_data = jax.vmap(logprob_fn)(data)
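logprob_data holds one log density per observation, so, for example, the average log-likelihood of the training data is:

# Mean log density assigned to the training data by the learned flow.
print(logprob_data.mean())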

Only simplified conditional flow matching is supported.

Same as above, but adding the argument reference_gn, an optional random generator for the reference density:

from fmx.data import flow_matching

reference_gn = lambda key: jax.random.normal(key, (2,))  # standard normal reference in two dimensions
fmx = flow_matching(vector_field.apply, data, reference_gn=reference_gn)
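Since only samples from the reference are needed here, other reference generators can be plugged in the same way; for instance (purely illustrative, the dimension must still match the data), a uniform reference on a box:

# A uniform reference on [-2, 2]^2 in place of the standard normal.
reference_gn = lambda key: jax.random.uniform(key, (2,), minval=-2.0, maxval=2.0)
fmx = flow_matching(vector_field.apply, data, reference_gn=reference_gn)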
