In [None]:
%load_ext autoreload
%reload_ext autoreload
%env XLA_PYTHON_CLIENT_ALLOCATOR=platform

env: XLA_PYTHON_CLIENT_ALLOCATOR=platform


# Hamiltonian Neural Network
> Hamiltonian neural network for meta-learning trajectory prediction.

In [None]:
# | default_exp hnn

In [None]:
# | hide
import nbdev
from fastcore.test import test_eq
from nbdev.showdoc import *

In [None]:
# | export
import jax
import jax.numpy as jnp
import equinox as eqx

In [None]:
# | export
from jaxDiversity.mlp import MultiActMLP, deterministic_init, init_linear_weight

In [None]:
# | export
def hamiltonian_factory(model, afuncs):
    """Build a scalar Hamiltonian function H(q, p) backed by ``model``.

    Args:
        model: callable invoked as ``model(x, afuncs)`` where ``x`` is the 1-D
            concatenation ``[q..., p...]``; the first element of its output must
            contain exactly one value.
        afuncs: forwarded unchanged to ``model`` on every call.

    Returns:
        A function ``hamiltonian(q, p)`` returning a 0-d array, suitable for
        ``jax.grad``.
    """

    def hamiltonian(q, p):
        """Evaluate H at one phase-space point; ``q`` and ``p`` are flattened first."""
        # Model input layout: all of q, then all of p, as a single flat vector.
        phase_point = jnp.concatenate([q.ravel(), p.ravel()])
        first_output = model(phase_point, afuncs)[0]
        # Force a 0-d result so the function is differentiable with `jax.grad`.
        return jnp.reshape(first_output, ())

    return hamiltonian

In [None]:
# | export
@eqx.filter_value_and_grad()
def compute_loss(model, x, y, afuncs):
    """MSE between Hamilton's-equations derivatives of the learned H and targets.

    The model is treated as a Hamiltonian H(q, p) (see ``hamiltonian_factory``);
    the predicted time derivatives are dq/dt = dH/dp and dp/dt = -dH/dq.

    Args:
        model: PyTree model; gradients are taken with respect to it by the
            ``eqx.filter_value_and_grad`` decorator.
        x: phase-space samples, assumed (batch, 2 * d); split along axis 1 into
            q = x[:, :d] and p = x[:, d:]. Axis 1 must have even length
            (``jnp.split`` raises otherwise).
        y: targets with the same layout as the prediction, ``[dq/dt, dp/dt]``.
        afuncs: activation functions forwarded to the model.

    Returns:
        Because of the decorator, calling this returns ``(loss, grads)`` where
        ``loss`` is the scalar mean squared error.
    """
    hamiltonian = hamiltonian_factory(model, afuncs)
    q, p = jnp.split(x, 2, axis=1)
    # One reverse pass per sample yields both partials, rather than tracing
    # two separate `jax.grad` transforms of the same Hamiltonian.
    dHdq, dHdp = jax.vmap(jax.grad(hamiltonian, argnums=(0, 1)))(q, p)
    # Hamilton's equations: dq/dt = dH/dp, dp/dt = -dH/dq.
    pred_y = jnp.concatenate([dHdp, -dHdq], axis=1)
    loss = jnp.mean((pred_y - y) ** 2)

    return loss

In [None]:
# | test
# test compute_loss
# Smoke-test `compute_loss` on a small, deterministically initialised MLP.
key = jax.random.PRNGKey(0)
model_key, init_key = jax.random.split(key)

# NOTE(review): positional args presumably (in_size=2, out_size=1, hidden=[18]) — confirm.
model = init_linear_weight(
    MultiActMLP(2, 1, [18], model_key, bias=False),
    deterministic_init,
    init_key,
)

# Five identical samples (q = p = 1) with unit targets for both derivatives.
x = jnp.ones((5, 2))
y = jnp.ones_like(x)

# Constant activations. NOTE(review): if MultiActMLP applies these to its hidden units,
# H is constant in (q, p), both gradients vanish, and the loss is mean((0 - 1)**2) = 1.
afuncs = [lambda _: 1, lambda _: 0]

loss, _ = compute_loss(model, x, y, afuncs)
test_eq(loss, 1.0)

In [None]:
# | hide
# Regenerate the library modules from the notebooks' `# | export` cells
# (this notebook's exports target the `hnn` module named by `default_exp`).
nbdev.nbdev_export()