In [None]:
from typing import *
import jax
import jax.numpy as jnp
import numpy as np
import shapely
import matplotlib.pyplot as plt
import diffrax
import matplotlib.animation as animation
import matplotlib
import flax.linen as nn
import flax
import chex
import optax
import warnings

import sys

sys.path.append("../..")
from action_angle_networks import models

matplotlib.rc("animation", html="jshtml")

In [None]:
def fq(q):
    """Toy linear map used to sanity-check the VJP-based momentum update.

    For a 1-D input ``q`` this returns ``[q[0] + 2 * q[1], q[0] + 5 * q[1]]``.
    """
    # Index with ``None`` so each component keeps a length-1 axis to concatenate on.
    first, second = q[0, None], q[1, None]
    outputs = (first + 2 * second, first + 5 * second)
    return jnp.concatenate(outputs, axis=-1)


# Sanity check: the pullback returned by jax.vjp, applied to the momentums,
# should agree with the explicit product of the transposed Jacobian and the momentums.
positions = jnp.asarray([1, 2], dtype=jnp.float32)
momentums = jnp.asarray([3, 4], dtype=jnp.float32)

# jax.vjp returns fq(positions) together with the pullback function.
positions_grad1, momentums_grad1_fn = jax.vjp(fq, positions)
momentums_grad1 = -momentums_grad1_fn(momentums)[0]

# Reference value computed from the full Jacobian.
jac = jnp.squeeze(jax.jacobian(fq)(positions))
print(jac.shape, momentums.shape)
expected_momentums_grad1 = (-jac.T) @ momentums

print("jac", jac)
print("positions_grad1", positions_grad1.shape)
print("expected_momentums_grad1", expected_momentums_grad1)
print("computed_momentums_grad", momentums_grad1)

In [None]:
# Candidate curve 1: the rhombus |q| + |p| = 1, built one edge per quadrant.
# NOTE: qs/ps from this section are overwritten by the astroid just below.
q_init = jnp.linspace(0, 1, 100)
p_init = 1 - q_init
q_reversed, p_reversed = q_init[::-1], p_init[::-1]
qs = jnp.concatenate((q_reversed, -q_init, -q_reversed, q_init))
ps = jnp.concatenate((p_reversed, p_init, -p_reversed, -p_init))

# Candidate curve 2 (the one actually used): the astroid (cos^3 t, sin^3 t).
ts = jnp.linspace(0, 2 * jnp.pi, 100)
qs, ps = jnp.cos(ts) ** 3, jnp.sin(ts) ** 3

# Add a trailing axis so each curve has shape (num_points, 1).
qs = jnp.expand_dims(qs, axis=1)
ps = jnp.expand_dims(ps, axis=1)

plt.plot(qs, ps)
plt.xlim(-5, 5)
plt.ylim(-5, 5)
plt.show();

In [None]:
def compute_area(qs, ps):
    """Returns the area of the polygon whose vertices are the points (qs[i], ps[i]).

    Args:
        qs: Horizontal (position) coordinates of the vertices, in traversal
            order. May be shaped (N,) or (N, 1).
        ps: Vertical (momentum) coordinates, with the same shape as ``qs``.

    Returns:
        The ``area`` attribute of the shapely ``Polygon`` built from the vertices.
    """
    # Convert to host NumPy arrays once and hand shapely a single (N, 2) array.
    # Zipping JAX arrays instead yields one device scalar per coordinate, each
    # of which has to be converted to a Python float individually.
    vertices = np.column_stack([np.asarray(qs), np.asarray(ps)])
    return shapely.geometry.Polygon(vertices).area

In [None]:
class ContinuousSymplecticFlow(flax.struct.PyTreeNode):
    """Continuous-time flow on (positions, momentums) built from four encoders.

    The vector field integrated by ``apply`` is the sum of three pieces, each
    derived from one or two of the encoders (see ``make_compute_derivatives_fn``).
    ``inverse_apply`` integrates the negated, time-reversed field so that it
    undoes ``apply``.
    """

    # Encoder f(q, t) -> R^num_dims used by the "type 1" flow.
    position_matched_encoder: nn.Module
    # Encoder g(p, t) -> R^num_dims used by the "type 2" flow.
    momentum_matched_encoder: nn.Module
    # Scalar-valued encoder of (q, t) used by the "type 3" flow.
    position_cross_encoder: nn.Module
    # Scalar-valued encoder of (p, t) used by the "type 3" flow.
    momentum_cross_encoder: nn.Module
    # If True, the pseudo-time is appended (as extra columns) to every encoder input.
    concatenate_time: bool = True

    def init(
        self, init_rng: chex.PRNGKey, positions: chex.Array, momentums: chex.Array
    ) -> optax.Params:
        """Initializes the parameters of all four encoders.

        Args:
            init_rng: PRNG key; split into one key per encoder.
            positions: Example positions, indexed as (sample, dimension).
            momentums: Example momentums, same layout as ``positions``.

        Returns:
            A FrozenDict mapping each encoder's field name to its parameters.
        """
        init_rngs = jax.random.split(init_rng, num=4)

        if self.concatenate_time:
            # Only the input width matters here, so ones stand in for the time columns.
            positions = jnp.concatenate([positions, jnp.ones_like(positions)], axis=1)
            momentums = jnp.concatenate([momentums, jnp.ones_like(momentums)], axis=1)

        return flax.core.FrozenDict(
            {
                "position_matched_encoder": self.position_matched_encoder.init(
                    init_rngs[0], positions
                ),
                "momentum_matched_encoder": self.momentum_matched_encoder.init(
                    init_rngs[1], momentums
                ),
                "position_cross_encoder": self.position_cross_encoder.init(
                    init_rngs[2], positions
                ),
                "momentum_cross_encoder": self.momentum_cross_encoder.init(
                    init_rngs[3], momentums
                ),
            }
        )

    def make_compute_derivatives_fn(self, params: optax.Params, direction: str):
        """Returns a function that computes the derivatives of the symplectic flow.

        Args:
            params: Parameters as returned by ``init``.
            direction: Either "forward" or "backward".

        Returns:
            A callable ``(t, y, args) -> (positions_grad, momentums_grad)`` with
            ``y = (positions, momentums)``, suitable for ``diffrax.ODETerm``.
        """
        assert direction in ["forward", "backward"]

        def compute_derivatives(t, y, unused_args):
            del unused_args

            positions, momentums = y
            num_dims = positions.shape[-1]

            if self.concatenate_time:
                # Concatenate time to the positions and momentums.
                positions_with_t = jnp.concatenate(
                    [positions, jnp.ones_like(positions) * t], axis=1
                )
                momentums_with_t = jnp.concatenate(
                    [momentums, jnp.ones_like(momentums) * t], axis=1
                )
            else:
                # Time-independent encoders: feed the raw coordinates. Without
                # this branch the names below would be unbound.
                positions_with_t = positions
                momentums_with_t = momentums

            # Symplectic flow of type 1:
            #   dq/dt = f(q, t),  dp/dt = -(df/dq)^T p.
            positions_grad1, momentums_grad1_fn = jax.vjp(
                lambda q: self.position_matched_encoder.apply(
                    params["position_matched_encoder"], q
                ),
                positions_with_t,
            )
            (momentums_grad1,) = momentums_grad1_fn(momentums)
            # Keep only the columns that correspond to q (drop the time columns).
            momentums_grad1 = -momentums_grad1[:, :num_dims]

            # Symplectic flow of type 2:
            #   dp/dt = g(p, t),  dq/dt = -(dg/dp)^T q.
            momentums_grad2, positions_grad2_fn = jax.vjp(
                lambda p: self.momentum_matched_encoder.apply(
                    params["momentum_matched_encoder"], p
                ),
                momentums_with_t,
            )
            (positions_grad2,) = positions_grad2_fn(positions)
            positions_grad2 = -positions_grad2[:, :num_dims]

            # Symplectic flow of type 3: per-sample gradients of the two
            # scalar-valued cross encoders; dq/dt depends only on (p, t) and
            # dp/dt only on (q, t).
            positions_grad3 = jax.vmap(
                jax.grad(
                    lambda p: self.momentum_cross_encoder.apply(
                        params["momentum_cross_encoder"], p
                    ).squeeze()
                )
            )(momentums_with_t)
            positions_grad3 = positions_grad3[:, :num_dims]
            momentums_grad3 = jax.vmap(
                jax.grad(
                    lambda q: self.position_cross_encoder.apply(
                        params["position_cross_encoder"], q
                    ).squeeze()
                )
            )(positions_with_t)
            momentums_grad3 = momentums_grad3[:, :num_dims]

            # Add up contributions from each of the symplectic flows.
            positions_grad = positions_grad1 + positions_grad2 + positions_grad3
            momentums_grad = momentums_grad1 + momentums_grad2 + momentums_grad3

            return positions_grad, momentums_grad

        # If we are integrating backwards in time, we need to invert time as if we had started at t1.
        # Note, t0 = 0 and t1 = 1 for this inversion to be valid.
        # We also need to flip the sign of the gradients.
        if direction == "backward":
            return lambda t, y, args: tuple(
                -grad for grad in compute_derivatives(1 - t, y, args)
            )

        return compute_derivatives

    def _apply(
        self,
        params: optax.Params,
        positions: chex.Array,
        momentums: chex.Array,
        save_intermediates: bool,
        direction: str,
    ) -> Tuple[chex.Array, chex.Array]:
        """Helper function to integrate both forwards and backwards in time.

        Args:
            params: Parameters as returned by ``init``.
            positions: Initial positions, indexed as (sample, dimension).
            momentums: Initial momentums, same layout as ``positions``.
            save_intermediates: If True, return the state at 100 evenly spaced
                times in [0, 1] (a leading time axis is added); otherwise return
                only the state at t1 = 1.
            direction: Either "forward" or "backward".

        Returns:
            The integrated ``(positions, momentums)``.

        Raises:
            ValueError: If ``direction`` is not "forward" or "backward".
        """

        if direction not in ["forward", "backward"]:
            raise ValueError(
                f"direction must be either 'forward' or 'backward', got {direction}"
            )

        term = diffrax.ODETerm(self.make_compute_derivatives_fn(params, direction))
        solver = diffrax.Dopri5()
        if save_intermediates:
            saveat = diffrax.SaveAt(ts=jnp.linspace(0, 1, 100))
        else:
            saveat = diffrax.SaveAt(t1=True)

        # The integration interval is fixed to [0, 1]; the backward vector field
        # relies on this when it evaluates the encoders at 1 - t.
        solution = diffrax.diffeqsolve(
            term,
            solver,
            t0=0,
            t1=1,
            dt0=0.1,
            y0=(positions, momentums),
            saveat=saveat,
        )
        positions, momentums = solution.ys

        if not save_intermediates:
            # SaveAt(t1=True) still yields a length-1 leading time axis; drop it.
            positions = jnp.squeeze(positions, axis=0)
            momentums = jnp.squeeze(momentums, axis=0)

        return positions, momentums

    def apply(
        self,
        params: optax.Params,
        positions: chex.Array,
        momentums: chex.Array,
        save_intermediates: bool = False,
    ) -> Tuple[chex.Array, chex.Array]:
        """Performs the forward computation of the symplectic flow."""
        return self._apply(
            params,
            positions,
            momentums,
            save_intermediates=save_intermediates,
            direction="forward",
        )

    def inverse_apply(
        self,
        params: optax.Params,
        positions: chex.Array,
        momentums: chex.Array,
        save_intermediates: bool = False,
    ) -> Tuple[chex.Array, chex.Array]:
        """Performs the inverse computation of the symplectic flow."""
        return self._apply(
            params,
            positions,
            momentums,
            save_intermediates=save_intermediates,
            direction="backward",
        )


# q_tau, p_tau = R(omega * tau) * (q_0, p_0)
# q_tau, p_tau = ((cos(omega * tau), -sin(omega * tau)), (sin(omega * tau), cos(omega * tau))) * (q_0, p_0)
# q_tau, p_tau = ((q_0cos(omega * tau) - p_0sin(omega * tau)), (q_0sin(omega * tau) + p_0cos(omega * tau)))
# d(q_tau, p_tau)/dtau = (-q_0*omega*sin(omega * tau) - p_0*omega*cos(omega * tau), q_0*omega*cos(omega * tau) - p_0*omega*sin(omega * tau))
# d(q_tau, p_tau)/dtau = omega * (-p_tau, q_tau)


# Use the same (q, p) curve for both phase-space dimensions.
positions = jnp.concatenate((qs, qs), axis=-1)
momentums = jnp.concatenate((ps, ps), axis=-1)

# Unrotated targets: shift, then squeeze q by 1.5 and stretch p by 1.5.
target_positions = (positions + 1) / 1.5
target_momentums = (momentums + 5) * 1.5


def make_rotation_matrix(theta):
    """Returns the 2x2 rotation matrix [[cos, -sin], [sin, cos]] for angle `theta`."""
    cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)
    top_row = [cos_theta, -sin_theta]
    bottom_row = [sin_theta, cos_theta]
    return jnp.array([top_row, bottom_row])


# Rotate each phase-space plane (q_i, p_i) of the targets by its own angle.
# Points are stored as rows, so a rotation R is applied as `coords @ R.T`.

# Dimension 0: rotate by theta1.
theta1 = -1.0
R1 = make_rotation_matrix(theta1)
target_coords1 = jnp.concatenate(
    [target_positions[:, 0, None], target_momentums[:, 0, None]], axis=-1
)
target_coords1 = jnp.matmul(target_coords1, R1.T)
target_positions1, target_momentums1 = (
    target_coords1[:, 0, None],
    target_coords1[:, 1, None],
)

# Dimension 1: rotate by theta2.
theta2 = 1.0
R2 = make_rotation_matrix(theta2)
target_coords2 = jnp.concatenate(
    [target_positions[:, 1, None], target_momentums[:, 1, None]], axis=-1
)
# Rotate dimension 1's own coordinates; using target_coords1 here would
# discard the array built just above and rotate dimension 0's result instead.
target_coords2 = jnp.matmul(target_coords2, R2.T)
target_positions2, target_momentums2 = (
    target_coords2[:, 0, None],
    target_coords2[:, 1, None],
)

# Reassemble the per-dimension columns into (num_points, 2) target arrays.
target_positions = jnp.concatenate([target_positions1, target_positions2], axis=-1)
target_momentums = jnp.concatenate([target_momentums1, target_momentums2], axis=-1)

num_dims = positions.shape[-1]


def _softplus_mlp(latent_sizes):
    """Builds a `models.MLP` with softplus activations and the given layer sizes."""
    return models.MLP(latent_sizes=latent_sizes, activation=jax.nn.softplus)


# The matched encoders output one value per phase-space dimension; the cross
# encoders output a single scalar.
model = ContinuousSymplecticFlow(
    position_matched_encoder=_softplus_mlp([10, num_dims]),
    momentum_matched_encoder=_softplus_mlp([10, num_dims]),
    position_cross_encoder=_softplus_mlp([10, 10, 1]),
    momentum_cross_encoder=_softplus_mlp([10, 10, 1]),
)

init_rng, rng = jax.random.split(jax.random.PRNGKey(0))
params = model.init(init_rng, positions, momentums)

# SGD with momentum.
tx = optax.sgd(learning_rate=1e-3, momentum=0.99)
opt_state = tx.init(params)


@jax.jit
def train_step(step, params, opt_state):
    """Runs one optimizer step on the flow-matching loss.

    Args:
        step: Index of the current step (not used inside the computation).
        params: Current model parameters.
        opt_state: Current optimizer state.

    Returns:
        A tuple ``(params, opt_state, loss)`` with the updated parameters, the
        updated optimizer state and the loss evaluated before the update. The
        optimizer state must be threaded back into the next call so that the
        momentum accumulator carries over between steps.
    """

    def loss_fn(params):
        predicted_positions, predicted_momentums = model.apply(
            params, positions, momentums
        )
        # Mean over all samples and dimensions of the summed l2 losses.
        return (
            optax.l2_loss(predicted_positions, target_positions)
            + optax.l2_loss(predicted_momentums, target_momentums)
        ).mean()

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(params)
    updates, opt_state = tx.update(grad, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


num_optimization_steps = 10000
all_predicted_positions, all_predicted_momentums = None, None
for step in range(num_optimization_steps):
    # Keep the returned optimizer state; dropping it would restart the
    # momentum accumulator from its initial value on every step.
    params, opt_state, loss = train_step(step, params, opt_state)

    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=FutureWarning)
        if step % 100 == 0:
            print(f"Step {step}: loss = {loss}")

    # Stop early once the fit is good enough.
    if loss < 1e-3:
        break

In [None]:
# Integrate the trained flow, keeping every intermediate pseudo-time.
all_predicted_positions, all_predicted_momentums = model.apply(
    params, positions, momentums, save_intermediates=True
)

# Compare the final frame of the flow against the targets.
final_positions = all_predicted_positions[-1]
final_momentums = all_predicted_momentums[-1]
plt.plot(final_positions, final_momentums)
plt.plot(target_positions, target_momentums)
plt.show();

In [None]:
def make_animation(
    predicted_positions,
    predicted_momentums,
    target_positions,
    target_momentums,
    direction: str,
):
    """Animates the predicted curves over pseudo-time against static targets.

    Args:
        predicted_positions: Predicted positions indexed as [frame, point, dim].
            Only dims 0 and 1 are drawn.
        predicted_momentums: Predicted momentums, same layout.
        target_positions: Target positions indexed as [point, dim].
        target_momentums: Target momentums, same layout.
        direction: "forward" if the frames run from pseudo-time 0 to 1,
            "backward" if they run from 1 to 0. Frames are assumed to be evenly
            spaced over [0, 1] with both endpoints included.

    Returns:
        The ``matplotlib.animation.FuncAnimation`` object.

    Raises:
        ValueError: If ``direction`` is not "forward" or "backward".
    """
    # Validate up front; otherwise the failure would only surface inside the
    # per-frame callback as an unbound local variable.
    if direction not in ("forward", "backward"):
        raise ValueError(
            f"direction must be either 'forward' or 'backward', got {direction}"
        )

    num_frames = predicted_positions.shape[0]
    num_dims = predicted_positions.shape[-1]
    # With both endpoints included, frame k sits at k / (num_frames - 1), so the
    # last frame lands exactly on pseudo-time 1. Guard the single-frame case.
    last_frame_index = max(num_frames - 1, 1)

    fig, ax = plt.subplots()
    (ln1,) = ax.plot([], [], "o", markersize=2, label="Predicted Shape 1", color="C0")
    (ln3,) = ax.plot(
        target_positions[:, 0],
        target_momentums[:, 0],
        "o",
        markersize=2,
        label="Target Shape 1",
        color="C2",
    )
    (ln2,) = ax.plot([], [], "o", markersize=2, label="Predicted Shape 2", color="C1")
    (ln4,) = ax.plot(
        target_positions[:, 1],
        target_momentums[:, 1],
        "o",
        markersize=2,
        label="Target Shape 2",
        color="C3",
    )
    # Fix the axis limits so every frame and the targets fit, with a margin of 1.
    ax.set_xlim(
        -1 + min(predicted_positions.min(), target_positions.min()),
        1 + max(predicted_positions.max(), target_positions.max()),
    )
    ax.set_ylim(
        -1 + min(predicted_momentums.min(), target_momentums.min()),
        1 + max(predicted_momentums.max(), target_momentums.max()),
    )

    ax.legend(loc="upper right")
    ax.set_xlabel("q")
    ax.set_ylabel("p")
    ax.grid()

    def init():
        return ln1, ln2

    def update(frame):
        ln1.set_data(predicted_positions[frame, :, 0], predicted_momentums[frame, :, 0])
        ln2.set_data(predicted_positions[frame, :, 1], predicted_momentums[frame, :, 1])

        # Total area enclosed by the predicted curves, summed over all dimensions.
        sum_areas = 0
        for index in range(num_dims):
            sum_areas += compute_area(
                predicted_positions[frame, :, index],
                predicted_momentums[frame, :, index],
            )

        if direction == "forward":
            frame_time = frame / last_frame_index
        else:
            frame_time = 1 - frame / last_frame_index

        ax.set_title(
            "Pseudo-Time = {:0.2f} \n Sum of Predicted Shape Areas= {:0.3f}".format(
                frame_time, sum_areas
            )
        )
        return ln1, ln2

    ani = animation.FuncAnimation(
        fig, update, frames=num_frames, init_func=init, blit=True
    )
    return ani

In [None]:
# Forward flow: transport the initial curves towards the targets and animate it.
all_predicted_positions, all_predicted_momentums = model.apply(
    params, positions, momentums, save_intermediates=True
)
ani = make_animation(
    predicted_positions=all_predicted_positions,
    predicted_momentums=all_predicted_momentums,
    target_positions=target_positions,
    target_momentums=target_momentums,
    direction="forward",
)
# Save the animation in both formats.
for output_path in (
    "notebook_outputs/continuous_time_flow_targeted_2dims.mp4",
    "notebook_outputs/continuous_time_flow_targeted_2dims.gif",
):
    ani.save(output_path)
ani

In [None]:
# Inverse flow: start from the targets and integrate back towards the initial curves.
all_inverse_predicted_positions, all_inverse_predicted_momentums = model.inverse_apply(
    params, target_positions, target_momentums, save_intermediates=True
)
ani = make_animation(
    predicted_positions=all_inverse_predicted_positions,
    predicted_momentums=all_inverse_predicted_momentums,
    target_positions=positions,
    target_momentums=momentums,
    direction="backward",
)
# Save the animation in both formats.
for output_path in (
    "notebook_outputs/continuous_time_flow_targeted_inverse_2dims.mp4",
    "notebook_outputs/continuous_time_flow_targeted_inverse_2dims.gif",
):
    ani.save(output_path)
ani

In [None]:
def fq(q):
    """Applies the fixed linear map [[1, 1], [0, 2]] to the vector `q`."""
    matrix = jnp.array([[1, 1], [0, 2]], dtype=np.float32)
    # Treat q as a column vector for the product, then drop the unit axis again.
    column = jnp.expand_dims(q, axis=1)
    return jnp.squeeze(matrix @ column)


def compute_derivatives(t, y, args):
    """Vector field of the type-1 flow generated by `fq`, for a single sample.

    Args:
        t: Time (ignored).
        y: Tuple ``(positions, momentums)`` for one sample.
        args: Extra solver arguments (ignored).

    Returns:
        ``(fq(positions), -J^T momentums)`` where J is the Jacobian of `fq` at
        `positions`. The momentum derivative returned is the one computed from
        the explicit Jacobian.
    """
    del t, args
    positions, momentums = y

    # jax.vjp yields fq(positions) and the pullback function.
    positions_grad, pullback = jax.vjp(fq, positions)
    # Pullback-based value; computed for comparison but not returned.
    momentums_grad = -pullback(momentums)[0]

    # Reference value from the full Jacobian.
    jacobian = jnp.squeeze(jax.jacobian(fq)(positions))
    expected_momentums_grad = (-jacobian.T) @ momentums
    return positions_grad, expected_momentums_grad


# Use the same (q, p) curve for both phase-space dimensions.
positions = jnp.concatenate((qs, qs), axis=-1)
momentums = jnp.concatenate((ps, ps), axis=-1)

# compute_derivatives handles one sample; vmap it over the leading axis of y.
batched_derivatives = jax.vmap(compute_derivatives, in_axes=(None, 0, None))
term = diffrax.ODETerm(batched_derivatives)
solver = diffrax.Dopri5()
saveat = diffrax.SaveAt(ts=jnp.linspace(0, 1, 100))
solution = diffrax.diffeqsolve(
    term,
    solver,
    t0=0,
    t1=1,
    dt0=0.1,
    y0=(positions, momentums),
    saveat=saveat,
)
new_positions, new_momentums = solution.ys

# Plot the curves at the final saved time.
plt.plot(new_positions[-1], new_momentums[-1])
plt.show();

In [None]:
# Print, for every saved time, the polygon areas summed over the phase-space
# dimensions.
for new_position_frame, new_momentums_frame in zip(new_positions, new_momentums):
    # Take the number of dimensions from the frame itself rather than from a
    # global set in another cell, so this cell only needs the solve above.
    areas = sum(
        compute_area(new_position_frame[:, index], new_momentums_frame[:, index])
        for index in range(new_position_frame.shape[-1])
    )
    print(areas)