In [1]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from flax import linen as nn
from flax.training import train_state, checkpoints
from matplotlib.animation import FuncAnimation
from tqdm import tqdm



In [2]:
def annulus_data(n_samples, n_t, key):
    """Simulate noisy particle trajectories in a rotating ring-shaped potential.

    Particles start from a Gaussian blob (standard deviation 0.1) centred at
    (0.5, 0.5) and are advanced with fixed-step Euler-Maruyama updates: the
    drift is the negative gradient of the potential ``V`` plus a rotation
    around the centre, and Gaussian noise scaled by ``0.05 * sqrt(dt)`` is
    added in every step.

    Args:
        n_samples: number of particles.
        n_t: number of time steps to integrate (and snapshots to return).
        key: JAX PRNG key; it is split internally so that the initial
            condition and the per-step noise use independent keys.

    Returns:
        Array of shape (n_t, n_samples, 2). Entry ``k`` holds the particle
        positions after ``k + 1`` steps; the initial positions themselves are
        not part of the output.
    """
    dt = 5e-2

    def V(p):
        # Potential that is minimal where the squared distance to the centre is 0.1.
        x, y = p
        r_squared = (x - 0.5)**2 + (y - 0.5)**2
        return (r_squared - 0.1)**2

    def u(p):
        # Drift: gradient descent on V plus a rotation around (0.5, 0.5).
        x, y = p
        return -jax.grad(V)(p) + 0.5 * jnp.array([y - 0.5, -(x - 0.5)])

    def time_step(x, key):
        # One Euler-Maruyama step for all particles at once.
        noise = jax.random.normal(key, x.shape)
        return x + (dt * jax.vmap(u)(x) + 0.05 * noise * jnp.sqrt(dt))

    @jax.jit
    def integrate(x, keys):
        def step(x, key):
            x = time_step(x, key)
            return x, x
        _, x_traj = jax.lax.scan(step, x, keys)
        return x_traj

    # Never consume the same key twice: one subkey for the initial condition,
    # one that is split further into the per-step noise keys.
    init_key, noise_key = jax.random.split(key)
    x0 = jax.random.normal(init_key, (n_samples, 2)) * 0.1 + jnp.array([0.5, 0.5])
    return integrate(x0, jax.random.split(noise_key, n_t))

In [17]:
###
# Parameters
###
key = jax.random.PRNGKey(0)

d = 2           # spatial dimension
Nt = 256        # number of time steps
Nx = 10_000     # number of samples
Nμ = 1          # number of parameter samples

# Use a dedicated subkey for the data generation: `key` is split again in later
# cells, so handing it to annulus_data directly would consume it twice.
key, data_key = jax.random.split(key)
x_data = annulus_data(Nx, Nt, data_key)
t_data = jnp.linspace(0, 1, Nt)   # the Nt snapshots are placed on a uniform grid on [0, 1]
μ_data = jnp.zeros(Nμ)

# Batch sizes handed to the loss (time points, samples and parameters per evaluation)
nt = 256
nx = 1024
nμ = 1

num_iterations = 20_000
initial_learning_rate = 5e-4
final_learning_rate   = 1e-6

In [18]:
class MLP(nn.Module):
    """Fully connected network with swish activations acting on the stacked inputs (x, t, mu)."""
    num_hid: int      # width of every hidden layer
    num_out: int      # number of output features
    num_layers: int   # number of hidden layers

    def setup(self):
        """Create the hidden Dense layers and the final linear output layer."""
        hidden_layers = []
        for _ in range(self.num_layers):
            hidden_layers.append(nn.Dense(features=self.num_hid))
        self.layers = hidden_layers
        self.out = nn.Dense(features=self.num_out)

    def __call__(self, x, t, mu):
        """Evaluate the network on the horizontally stacked inputs x, t and mu."""
        activations = jnp.hstack([x, t, mu])
        for dense in self.layers:
            activations = nn.swish(dense(activations))
        # The output layer is linear (no activation).
        return self.out(activations)

###
# Neural network representing s = s(x, t, μ)
###
net_width = 64   # features per hidden layer
net_depth = 3    # number of hidden layers
# Single output feature: the network NN represents the scalar s
model = MLP(num_hid=net_width, num_out=1, num_layers=net_depth)

In [19]:
###
# Initialize the model and optimizer
###
key, p_key = jax.random.split(key)
# Dummy inputs only fix the input sizes: d entries for x, one each for t and mu.
params = model.init(p_key, jnp.zeros(d), jnp.zeros(1), jnp.zeros(1))
# optax interprets `alpha` as a multiplier of `init_value` (the schedule ends at
# init_value * alpha), so the desired final rate has to be passed as a ratio.
learning_rate_schedule = optax.cosine_decay_schedule(
    init_value=initial_learning_rate,
    decay_steps=num_iterations,
    alpha=final_learning_rate / initial_learning_rate)
optimizer = optax.adam(learning_rate=learning_rate_schedule)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)

def s(params, x, t, mu):
    """Evaluate the network with `params` at (x, t, mu) and reduce its output to a scalar.

    The output is summed so the result can be differentiated with jax.grad.
    """
    network_output = state.apply_fn(params, x, t, mu)
    return network_output.sum()

In [20]:
### Utility functions

def get_random_subset(key, arr, bs):
    """Return `bs` distinct, randomly chosen entries of `arr` along its first axis.

    Args:
        key: JAX PRNG key used to draw the indices.
        arr: array to sample from.
        bs: number of entries to draw (without replacement).
    """
    picked = jax.random.choice(key, len(arr), shape=(bs,), replace=False)
    return arr[picked]

def get_random_timegrid(key, t_data, bs_t):
    """Draw a sorted random time grid that always contains both end points.

    `bs_t` distinct times are sampled from `t_data` and sorted; the first and
    last entries are then overwritten with `t_data[0]` and `t_data[-1]`.

    Args:
        key: JAX PRNG key used for the sampling.
        t_data: 1-D array of available times (presumably sorted ascending,
            since its first/last entries are used as the grid end points).
        bs_t: number of grid points; must be at least 2.

    Returns:
        t_q: the time grid of shape (bs_t,).
        w_q: trapezoidal quadrature weights for `t_q` (half the distance to
            the single neighbour at the end points, half the distance between
            both neighbours in the interior).

    Raises:
        ValueError: if `bs_t` is smaller than 2, because the weights need at
            least two grid points.
    """
    if bs_t < 2:
        raise ValueError(f"bs_t must be at least 2, got {bs_t}")
    t_q = jnp.sort(get_random_subset(key, t_data, bs_t))
    # Pin the end points so every grid covers the full time interval.
    t_q = t_q.at[0].set(t_data[0])
    t_q = t_q.at[-1].set(t_data[-1])
    w_q = 0.5 * jnp.concatenate([
        jnp.array([t_q[1] - t_q[0]]),
        t_q[2:] - t_q[:-2],
        jnp.array([t_q[-1] - t_q[-2]]),
    ])
    return t_q, w_q
        
def get_expected_value_fct(x_data, t_data, mu_data, bs_n):
    """Build a Monte-Carlo estimator of the mean of f(x, t, mu) over samples taken at time tau.

    Args:
        x_data: sample array, indexed as x_data[sample, time, mu, :].
        t_data: 1-D array with the time belonging to each time index of x_data.
        mu_data: 1-D array with the parameter value belonging to each mu index.
        bs_n: number of samples drawn (without replacement) per evaluation.

    Returns:
        expected_value(f, tau, t, i_mu, key) -> scalar, the average of
        f(x, t, mu_data[i_mu]) over `bs_n` random samples x taken at time tau.
    """
    def expected_value(f, tau, t, i_mu, key):
        # rho is evaluated at tau and s is evaluated at t.
        # Look up the grid point closest to tau. Truncating
        # tau / (t_end - t_start) * (n_t - 1) to an integer can land one index
        # too low through floating-point rounding and ignores a non-zero t_data[0].
        i_t = jnp.argmin(jnp.abs(t_data - tau))
        # get a random subset of x at time tau
        x = get_random_subset(key, x_data[:, i_t, i_mu, :], bs_n)
        mu = mu_data[i_mu]
        # evaluate f at the time t at all x_tau and average
        # signature of f should be (x, t, mu) -> scalar
        return jnp.mean(jax.vmap(lambda _x: f(_x, t, mu))(x))
    return expected_value

def get_s_derivatives(s):
    """Return derivative closures of a scalar function s(x, t, mu).

    Returns:
        Tuple (grad_s, grad_s_squared, partial_t_s, laplace_s), each with the
        signature (x, t, mu):
          grad_s          gradient of s with respect to x
          grad_s_squared  sum of the squared entries of that gradient
          partial_t_s     derivative of s with respect to t
          laplace_s       trace of the Hessian of s with respect to x
    """
    ds_dx = jax.grad(s, argnums=0)
    ds_dt = jax.grad(s, argnums=1)
    hessian_x = jax.hessian(s, argnums=0)

    def grad_s(x, t, mu):
        return ds_dx(x, t, mu)

    def grad_s_squared(x, t, mu):
        gradient = grad_s(x, t, mu)
        return jnp.sum(gradient**2)

    def partial_t_s(x, t, mu):
        return ds_dt(x, t, mu)

    def laplace_s(x, t, mu):
        second_derivatives = hessian_x(x, t, mu)
        return jnp.trace(second_derivatives)

    return grad_s, grad_s_squared, partial_t_s, laplace_s

### DICE

def get_dice_loss(_s, x_data, t_data, mu_data, bs_n, bs_t, bs_mu):
    """Build the stochastic DICE training loss for the scalar function `_s`.

    Args:
        _s: callable (params, x, t, mu) -> scalar.
        x_data: sample array, indexed as x_data[sample, time, mu, :].
        t_data: 1-D array of times belonging to the time axis of x_data.
        mu_data: 1-D array of parameter values.
        bs_n: number of samples drawn for every expectation.
        bs_t: number of time points used per loss evaluation.
        bs_mu: number of parameter values used per loss evaluation.

    Returns:
        dice_loss(state, params, key) -> scalar loss. `state` is accepted for
        the caller's convenience but is not read.
    """
    n_mu = mu_data.shape[0]
    expected_value = get_expected_value_fct(x_data, t_data, mu_data, bs_n)

    def dice_loss_mu(state, params, key, i_mu):
        """Loss contribution of the parameter with index `i_mu`."""
        # closures for s
        s = lambda x, t, mu: _s(params, x, t, mu)
        # Only the squared gradient norm enters the loss.
        _, grad_s_squared, _, _ = get_s_derivatives(s)

        key, x_key, t_key = jax.random.split(key, 3)
        # One sampling key per time point: whenever the same time index shows up
        # in several terms below, the same random subset of samples is used.
        x_keys = jax.random.split(x_key, bs_t)
        t_q, w_q = get_random_timegrid(t_key, t_data, bs_t)

        E_s = lambda tau, t, key: expected_value(s, tau, t, i_mu, key)
        E_s_v = jax.vmap(E_s)

        # Sums over consecutive time pairs (n, n+1); "E_a s_b" means that the
        # samples are taken at time a while s is evaluated at time b.
        sum_En_snplus1 =      jnp.sum(E_s_v(t_q[:-1], t_q[1:],  x_keys[:-1]))
        sum_Enplus1_sn =      jnp.sum(E_s_v(t_q[1:],  t_q[:-1], x_keys[1:]))
        sum_En_sn =           jnp.sum(E_s_v(t_q[:-1], t_q[:-1], x_keys[:-1]))
        sum_Enplus1_snplus1 = jnp.sum(E_s_v(t_q[1:],  t_q[1:],  x_keys[1:]))
        loss = (+ 0.5 * sum_En_snplus1
                - 0.5 * sum_Enplus1_sn
                + 0.5 * sum_En_sn
                - 0.5 * sum_Enplus1_snplus1)

        # Time integral (quadrature weights w_q) of the expected squared gradient norm.
        E_grad_s_squared = lambda tau, t, key: expected_value(grad_s_squared, tau, t, i_mu, key)
        loss += 0.5 * jnp.sum( w_q * jax.vmap(E_grad_s_squared)(t_q, t_q, x_keys))

        return loss

    def dice_loss(state, params, key):
        """Average the per-parameter loss over `bs_mu` randomly chosen parameters."""
        key, mu_key = jax.random.split(key)
        i_mus = jax.random.choice(mu_key, n_mu, shape=(bs_mu,), replace=False)
        # NOTE(review): every parameter in the batch receives the same `key`, i.e.
        # the same time grid and sample indices - confirm that this is intended.
        loss = jnp.mean(jax.vmap(lambda i_mu: dice_loss_mu(state, params, key, i_mu))(i_mus))
        return loss

    return dice_loss

In [21]:
###
# Our loss function expects a tensor of shape (Nx, Nt, Nμ, d). 
# We have no μ here, so we add a dummy dimension.
# The ndim guard keeps this cell safe to re-run: reshaping the already converted
# 4-D array a second time would silently scramble samples and time steps.

if x_data.ndim == 3:  # raw simulator output of shape (Nt, Nx, d)
    x_data = jnp.reshape(x_data, (Nt, Nx, 1, d))
    x_data = jnp.transpose(x_data, (1, 0, 2, 3))  # Shape is now (Nx, Nt, Nμ, d)

In [22]:
# Batch sizes: nx samples per expectation, nt time points, nμ parameter values.
loss_fn = get_dice_loss(s, x_data, t_data, μ_data, bs_n=nx, bs_t=nt, bs_mu=nμ)

@jax.jit
def train_step(state, key):
    """Perform one optimizer update.

    Args:
        state: current train state (parameters and optimizer state).
        key: JAX PRNG key for the stochastic loss evaluation.

    Returns:
        Tuple (updated state, loss value before the update).
    """
    # Differentiate with respect to the second argument, the parameters.
    loss_and_grads = jax.value_and_grad(loss_fn, argnums=1)
    loss, grads = loss_and_grads(state, state.params, key)
    return state.apply_gradients(grads=grads), loss

key, loc_key = jax.random.split(key)
# Warm-up call that triggers the jit compilation. Its result is discarded on
# purpose: keeping it would silently take an optimizer step here, and one more
# every time this cell is re-run.
_warmup_state, _warmup_loss = train_step(state, loc_key)  # compilation

In [None]:
###
# Optimization loop
###

loss_plot = [ ]
states =    [ ]
key, loop_key = jax.random.split(key)

with tqdm(range(num_iterations)) as pbar:
    for iteration in pbar:
        # Derive a fresh key for every step from the dedicated loop key; a key
        # that is split further must not also be consumed by train_step.
        loop_key, step_key = jax.random.split(loop_key)
        state, loss = train_step(state, step_key)
        loss_plot.append(loss)
        if iteration % 10 == 0:
            states.append(state) # we save the state every 10 iterations for diagnostics
        pbar.set_postfix({"loss": loss})
loss_plot_np = np.array(loss_plot)
# Use the state after the last update; states[-1] is only the most recent
# diagnostic snapshot and can be up to 9 iterations old.
s_opt_state = state

###
# Alternatively: Load pre-trained model
# state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=optimizer)
# s_opt_state = state
# s_opt_state = checkpoints.restore_checkpoint(ckpt_dir="...", target=s_opt_state)
### (Path has to be absolute)

 33%|██████████████████████████████████████████████████████████▌                                                                                                                      | 6619/20000 [00:51<01:43, 129.17it/s, loss=-0.04482279]

In [24]:
### Generate samples with explicit euler
def euler(x, f, t, dt):
    """Return the explicit Euler update x + dt * f(x, t)."""
    slope = f(x, t)
    return x + dt * slope

def generate_sample_s(x0, v, times):
    """Integrate dx/dt = v(x, t) from `x0` along `times` with explicit Euler steps.

    Args:
        x0: initial state.
        v: velocity field with signature (x, t).
        times: time grid; the step sizes are the differences of consecutive entries.

    Returns:
        The stacked trajectory: `x0` first, followed by the state at every
        entry of `times[1:]`.
    """
    def advance(state, t_next):
        x_prev, t_prev = state
        # The velocity is evaluated at the beginning of the interval.
        x_next = euler(x_prev, v, t_prev, t_next - t_prev)
        return (x_next, t_next), x_next

    _, trajectory = jax.lax.scan(advance, (x0, times[0]), times[1:])
    return jnp.vstack([x0[None, ...], trajectory])

def grad_s(params, x, t, mu):
    """Gradient of s(params, x, t, mu) with respect to x (its argument at position 1)."""
    ds_dx = jax.grad(s, argnums=1)
    return ds_dx(params, x, t, mu)

def generate_batch(params, nx, times, mu, key):
    """Generate `nx` trajectories by following the gradient of s for parameter `mu`.

    Initial positions are Gaussian (standard deviation 0.1) around (0.5, 0.5);
    every position is then integrated along `times` with the velocity grad_s.

    Returns:
        Array whose first axis runs over the `nx` trajectories; each entry is
        the output of generate_sample_s for one initial position.
    """
    # Only the first half of the split is consumed.
    init_key, _ = jax.random.split(key)
    centre = jnp.array([0.5, 0.5])
    starts = jax.random.normal(init_key, (nx, 2)) * 0.1 + centre

    def velocity(x, t):
        return grad_s(params, x, t, mu)

    return jax.vmap(lambda start: generate_sample_s(start, velocity, times))(starts)

In [None]:
# Evaluate the model
key, eval_key = jax.random.split(key)

# One PRNG key per parameter value, derived from the dedicated eval_key
# (the carried `key` must stay unused so later cells can keep splitting it).
mu_keys = jax.random.split(eval_key, len(μ_data))
# out_axes=2 places the parameter axis third: (sample, time, μ, coordinate).
x_gen = jax.vmap(
    lambda mu, mu_key: generate_batch(s_opt_state.params, Nx, t_data, mu, mu_key),
    out_axes=2,
)(μ_data, mu_keys)

In [26]:
def save_gif(x_data, name):
    """Animate particle positions as a scatter plot and save the animation as a GIF.

    Particles are coloured by the quadrant (relative to the point (0.5, 0.5))
    in which they start, and keep that colour for the whole animation.

    Args:
        x_data: array indexed as x_data[particle, frame, mu, :]; only the
            entries with mu index 0 are drawn, and the last axis is expected
            to hold the two plot coordinates.
        name: output file name handed to the animation writer.
    """
    # Convert once to NumPy so the per-frame slicing and the boolean-mask
    # indexing below work on host arrays, and take the sizes from the data
    # itself instead of from notebook globals.
    x_data = np.asarray(x_data)
    n_particles, n_frames = x_data.shape[0], x_data.shape[1]

    fig, ax = plt.subplots()
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1)
    ax.set_facecolor('white')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')

    # Define the colors for the quadrants
    COLOR_TOP_RIGHT = 'purple'
    COLOR_BOTTOM_RIGHT = 'black'
    COLOR_BOTTOM_LEFT = 'teal'
    COLOR_TOP_LEFT = 'grey'

    # Get the initial positions (x and y) from the first frame
    initial_positions = x_data[:, 0, 0, :]  # one row per particle
    initial_x = initial_positions[:, 0]
    initial_y = initial_positions[:, 1]

    # Create boolean masks for each quadrant based on initial positions
    mask_top_right = (initial_x >= 0.5) & (initial_y >= 0.5)
    mask_bottom_left = (initial_x < 0.5) & (initial_y < 0.5)
    mask_top_left = (initial_x < 0.5) & (initial_y >= 0.5)

    # Start with the bottom-right colour everywhere, then overwrite the other
    # three quadrants.
    particle_colors = np.full(n_particles, COLOR_BOTTOM_RIGHT, dtype=object)
    particle_colors[mask_top_right] = COLOR_TOP_RIGHT
    particle_colors[mask_bottom_left] = COLOR_BOTTOM_LEFT
    particle_colors[mask_top_left] = COLOR_TOP_LEFT

    # Create the initial scatter plot object, passing the array of colors to `c`
    scatter = ax.scatter(
        initial_x,
        initial_y,
        c=particle_colors,
        s=2
    )

    def update(frame_num):
        """Updates the scatter plot positions for a given frame."""
        new_positions = x_data[:, frame_num, 0, :]
        scatter.set_offsets(new_positions)
        return scatter,

    anim = FuncAnimation(fig, update, frames=n_frames, interval=50, blit=True)

    print("Saving animation...")
    anim.save(name, writer='pillow', dpi=150)
    print(f"Animation saved as {name}")

    plt.close(fig)

In [None]:
# Render the reference data and the generated trajectories with the same routine.
for trajectories, filename in ((x_data, 'true.gif'), (x_gen, 'generated.gif')):
    save_gif(trajectories, filename)

Saving animation...
Animation saved as generated.gif
