In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline
%config InlineBackend.figure_format='retina'

from math import ceil

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt
from matplotlib.collections import LineCollection

from flow_matching.field.gaussian import u, p
from flow_matching.dataset.toy import ToyDataset
from flow_matching.dataset.toy import ToyDataset

In [None]:
def draw_field(u, x_range=(-1, 1), y_range=(-1, 1), ax=None, n_points=30):
    """Draw the direction field of a 2-D vector field on a regular grid.

    Args:
        u: callable mapping an (M, 2) array of points to an (M, 2) array of
            vectors (one vector per point).
        x_range: (min, max) extent of the grid along x, both ends included.
        y_range: (min, max) extent of the grid along y, both ends included.
        ax: matplotlib Axes to draw on; when None, pyplot's current axes are used.
        n_points: number of grid points along each axis (default 30, i.e. 900 arrows).

    Every arrow is rescaled to unit length, so only the direction is shown.
    Zero vectors (e.g. at a fixed point of the field) are drawn as zero-length
    arrows instead of producing NaN.
    """
    grid_x, grid_y = np.mgrid[
        x_range[0] : x_range[1] : n_points * 1j,
        y_range[0] : y_range[1] : n_points * 1j,
    ]
    points = np.column_stack([grid_x.ravel(), grid_y.ravel()])
    vectors = np.asarray(u(points))
    norms = np.linalg.norm(vectors, axis=1, keepdims=True)
    # Divide by 1 where the norm is 0 so zero vectors stay zero rather than 0/0 = NaN.
    unit = vectors / np.where(norms > 0, norms, 1.0)
    target = plt if ax is None else ax
    target.quiver(points[:, 0], points[:, 1], unit[:, 0], unit[:, 1])


draw_field(lambda points: points)  # identity field u(x) = x: arrows point away from the origin

In [None]:
ds = ToyDataset()
# A single target point; it is indexed as x1[0], x1[1] below.
x1 = ds.sample(1)[0]

# Source samples x0 ~ N(0, I). A local seeded generator keeps the figure
# reproducible without touching NumPy's global random state.
rng = np.random.default_rng(0)
x0 = rng.multivariate_normal([0, 0], np.eye(2), 1000)

# Straight-line interpolation x_t = (1 - t) x0 + t x1 at a fixed time t.
t = 0.5
xt = (1 - t) * x0 + t * x1

fig, ax = plt.subplots()
ax.scatter(x0[:, 0], x0[:, 1], s=3, alpha=0.5, label="$x_0$")
ax.plot(x1[0], x1[1], "ro", label="$x_1$")
ax.scatter(xt[:, 0], xt[:, 1], s=3, alpha=0.5, c="g", label=f"$x_t$ (t = {t})")

# Uncomment to overlay the conditional field u_t(x | x1) (the imported `u`):
# draw_field(lambda x: u(x, t, x1), x_range=(-4, 4), y_range=(-4, 4), ax=ax)

ax.legend()
ax.set_title("Interpolating Gaussian samples toward one data point");

In [None]:
# Empirical target samples. u_t below reads this module-level x1 at call time,
# so rebinding x1 later changes the field it computes.
x1 = ds.sample(4000)


def u_t(x, t):
    """Marginal vector field u_t(x) estimated from the module-level samples ``x1``.

    u_t(x) = E_{x_1 ~ q}[u_t(x | x_1) p_t(x | x_1)] / p_t(x),
    p_t(x) = E_{x_1 ~ q}[p_t(x | x_1)],

    where both expectations are Monte Carlo means over the rows of ``x1``.

    Args:
        x: batch of query points; it is passed whole to the conditional field
            ``u`` and density ``p`` for each data sample.
        t: time at which the field is evaluated.

    Returns:
        The density-weighted average of ``u(x, t, x1[i])`` over all samples i.
    """
    # vmap over the data samples: the leading axis of both results indexes x1.
    u_t_given_x1 = jax.vmap(u, in_axes=(None, None, 0))(x, t, x1)
    p_t_given_x1 = jax.vmap(p, in_axes=(None, None, 0))(x, t, x1)
    p_t = p_t_given_x1.mean(0)
    # NOTE(review): if every p_t(x | x1) underflows to 0 (x far from the data,
    # t close to 1) this is 0/0 = NaN; a log-space weighting would avoid that.
    return (u_t_given_x1 * p_t_given_x1[..., None]).mean(axis=0) / p_t[..., None]


ts = np.linspace(0.001, 0.999, 10)
n_rows = 3
fig, axs = plt.subplots(n_rows, ceil(len(ts) / n_rows), figsize=(16, 10))
panels = axs.ravel()
for t, ax in zip(ts, panels):
    ax.scatter(x1[:, 0], x1[:, 1], s=3, alpha=0.3)
    # Bind t as a default argument so the closure does not depend on loop state.
    draw_field(
        lambda x, t=t: u_t(x, t),
        x_range=(-8, 8),
        y_range=(-6, 6),
        ax=ax,
    )
    # Three decimals so the endpoints read 0.001 / 0.999 rather than 0.0 / 1.0.
    ax.set_title(f"t = {t:.3f}")

# The rows x ceil(n / rows) grid can hold more panels than time points; hide the spares.
for ax in panels[len(ts):]:
    ax.set_axis_off()
fig.tight_layout()

In [None]:
# Data sample for the ODE below. This rebinds the module-level x1 that the
# field functions in this notebook read at call time.
x1 = ds.sample(2000)

# Initial ODE states x0 ~ N(0, I), drawn from a local seeded generator so the
# trajectories are reproducible.
rng = np.random.default_rng(0)
x0 = rng.multivariate_normal([0, 0], np.eye(2), 500)


def f(t, x, args):
    """ODE right-hand side dx/dt = u_t(x) in diffrax's ``(t, y, args)`` signature.

    Delegates to ``u_t`` so the integrated field is exactly the one visualised
    above (same formula, same module-level ``x1``). ``args`` is unused.
    """
    return u_t(x, t)


# Integration window. It stops short of t = 1 (presumably because the empirical
# marginal field becomes singular there — confirm). SaveAt times must lie in
# [T_START, T_END], so both are derived from the same constants.
T_START, T_END = 0.0, 0.995
N_SAVE = 51
DT0 = 0.05

term = ODETerm(f)
solver = Dopri5()
saveat = SaveAt(ts=jnp.linspace(T_START, T_END, N_SAVE))
# No stepsize_controller is passed, so diffrax uses fixed steps of size DT0.
solution = diffeqsolve(term, solver, t0=T_START, t1=T_END, dt0=DT0, y0=x0, saveat=saveat)

In [None]:
# Saved ODE states as a NumPy array, indexed (time, particle, coordinate).
ys = np.asarray(solution.ys)
T, N, _ = ys.shape

cm = plt.colormaps["coolwarm"]

# One segment per particle per time step, ordered time-major: segment
# k = i * N + j joins ys[i, j] to ys[i + 1, j]. Built in one vectorised step
# instead of T * N individual array slices.
segments = np.stack([ys[:-1], ys[1:]], axis=2).reshape(-1, 2, 2)
# Colour each segment by its start index i / T, repeated for all N particles.
colors = cm(np.repeat(np.arange(T - 1) / T, N))

fig, ax = plt.subplots()
ax.add_collection(LineCollection(segments, colors=colors, alpha=0.15))
ax.scatter(ys[0, :, 0], ys[0, :, 1], s=2, alpha=0.3, color=cm(0.0))
ax.scatter(ys[-1, :, 0], ys[-1, :, 1], s=2, alpha=0.3, color=cm(1.0))
ax.set_title("ODE trajectories: start (blue) to end (red)");