In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpy as np

from flow_matching.field.gaussian import u_ot
from flow_matching.dataset.toy import ToyDataset

In [None]:
def draw_field(u, x_range=(-1, 1), y_range=(-1, 1), n=20, ax=None, **quiver_kwargs):
    """Draw a 2-D vector field as a quiver plot sampled on a regular grid.

    Args:
        u: Callable taking an (n*n, 2) array of points and returning an
            (n*n, 2) array with one vector per point.
        x_range: (min, max) extent of the grid along x, endpoints included.
        y_range: (min, max) extent of the grid along y, endpoints included.
        n: Number of grid points per axis (default 20).
        ax: Matplotlib Axes to draw on; defaults to the current axes.
        **quiver_kwargs: Extra keyword arguments forwarded to ``ax.quiver``.

    Returns:
        The Quiver artist, so callers can restyle or animate it (``set_UVC``).
    """
    if ax is None:
        ax = plt.gca()
    # An imaginary step in np.mgrid means "this many points", endpoints included.
    x, y = np.mgrid[x_range[0]:x_range[1]:n * 1j, y_range[0]:y_range[1]:n * 1j]
    p = np.column_stack([x.ravel(), y.ravel()])
    # u may return a JAX array; hand quiver a plain NumPy array.
    dp = np.asarray(u(p))
    return ax.quiver(p[:, 0], p[:, 1], dp[:, 0], dp[:, 1], **quiver_kwargs)

# Sanity check: the identity field u(x) = x places at each grid point an arrow
# pointing along that point's position vector, i.e. radially away from the origin.
draw_field(u=lambda pts: pts);

In [None]:
# Conditional probability path for one data point x1: samples x0 from the N(0, I)
# source and their linear interpolation x_t = (1 - t) * x0 + t * x1 at a fixed t.
ds = ToyDataset()
x1 = ds.sample(1)[0]  # a single target point (also used by the animation cell)

# Seeded generator so the source samples are reproducible on Restart & Run All.
# NOTE(review): ToyDataset.sample may draw from its own RNG — seed it too if so.
rng = np.random.default_rng(0)
x0 = rng.multivariate_normal([0, 0], np.eye(2), 1000)

t = 0.5
xt = (1 - t) * x0 + t * x1

fig, ax = plt.subplots()
ax.scatter(x0[:, 0], x0[:, 1], s=3, alpha=0.5, label=r'$x_0 \sim \mathcal{N}(0, I)$')
ax.plot(x1[0], x1[1], 'ro', label='$x_1$')
ax.scatter(xt[:, 0], xt[:, 1], s=3, alpha=0.5, c='g', label=f'$x_t$, t={t}')
ax.legend()

# draw_field(lambda x: u_ot(x, t, x1), x_range=(-4, 4), y_range=(-4, 4))
plt.show()

In [None]:
# Animate the conditional OT field u_t(x | x1) for the single target x1 as t goes from 0 to 1.

from matplotlib.animation import FuncAnimation

# With the inline backend, plt.show() only renders a static image of the first
# frame; rendering the animation through its HTML repr makes it playable in the notebook.
plt.rcParams['animation.html'] = 'jshtml'

fig, ax = plt.subplots()
x_range = (-4, 4)
y_range = (-4, 4)
x, y = np.mgrid[x_range[0]:x_range[1]:20j, y_range[0]:y_range[1]:20j]
p = np.vstack([x.ravel(), y.ravel()]).T

dp = u_ot(p, 0.0, x1)
# Quiver computes its autoscale once (on first draw) and set_UVC keeps it, so
# arrow lengths stay comparable across frames.
q = ax.quiver(p[:, 0], p[:, 1], dp[:, 0], dp[:, 1])
ax.plot(x1[0], x1[1], 'ro')
ax.set(title=r'Conditional OT field $u_t(x \mid x_1)$, $t \in [0, 1]$', xlabel='$x$', ylabel='$y$')


def update(t):
    """Recompute the field at time t and update the quiver arrows in place."""
    dp = u_ot(p, t, x1)
    q.set_UVC(dp[:, 0], dp[:, 1])
    return (q,)


# NOTE(review): frames include t = 1.0; if u_ot divides by (1 - (1 - sigma_min) * t)
# with sigma_min = 0, the last frame is singular — confirm against u_ot.
ani = FuncAnimation(fig, update, frames=np.linspace(0, 1, 100), blit=True)
# ani.save('ot_field.gif', writer='pillow', fps=10)

# Close the figure so only the animation (not an extra static frame) is displayed;
# the bare `ani` below renders via the 'jshtml' animation repr configured above.
plt.close(fig)
ani


In [None]:
# Marginal field at a fixed time t, approximated by averaging the conditional OT
# field over many data samples. The *_batch names keep the single target point x1
# (used by the animation cell) intact, so earlier cells can be re-run safely.
t = 0.2  # must be set before x_t is computed below

x1_batch = ds.sample(4000)
# Seeded generator so the source samples are reproducible on Restart & Run All.
rng = np.random.default_rng(1)
x0_batch = rng.multivariate_normal([0, 0], np.eye(2), len(x1_batch))
xt_batch = (1 - t) * x0_batch + t * x1_batch

# Map u_ot over the data axis only (in_axes=(None, None, 0)): the query points and t
# are shared, each x1 sample yields one field over all query points, and
# .mean(0) below averages those fields over the samples.
u_ot_over_x1 = jax.vmap(u_ot, in_axes=(None, None, 0))

plt.scatter(x1_batch[:, 0], x1_batch[:, 1], s=3, alpha=0.3, label='$x_1$ (data)')
plt.scatter(xt_batch[:, 0], xt_batch[:, 1], s=3, alpha=0.3, c='g', label=f'$x_t$, t={t}')
# NOTE(review): an unweighted mean over x1 is not the true marginal field, which weights
# each conditional field by p_t(x | x1) / p_t(x) — confirm this approximation is intended.
draw_field(lambda x: u_ot_over_x1(x, t, x1_batch).mean(0), x_range=(-8, 8), y_range=(-4, 4))
plt.legend()
plt.show()
