In [1]:
import optax
import equinox as eqx

from jax import numpy as jnp, random as jr, nn

from logsumexp import Hopfield, SlotAttention, solve, energy


@eqx.filter_jit
def train_step(edges, nodes, opt_state, optim):
    """Run one optimisation step on the edge parameters.

    The nodes are first replaced by ``solve(edges, nodes)``. The gradient of
    ``energy`` is then taken with respect to ``edges`` only: the solved nodes
    are closed over by the differentiated function, so no gradient flows
    back through ``solve``.

    Args:
        edges: Pytree of edge modules; its array leaves are the parameters
            being optimised.
        nodes: Node values handed to ``solve`` and ``energy``.
        opt_state: Optimiser state compatible with ``optim``.
        optim: Object exposing ``update(grads, state, params)``, e.g. an
            optax gradient transformation.

    Returns:
        A tuple ``(edges, nodes, opt_state)`` holding the updated edges, the
        solved nodes and the new optimiser state.
    """
    nodes = solve(edges, nodes)

    def edge_energy(candidate_edges):
        # `nodes` is captured from the enclosing scope and therefore treated
        # as a constant by the gradient computation.
        return energy(candidate_edges, nodes)

    grads = eqx.filter_grad(edge_energy)(edges)

    # Supply the current parameters so optimisers that need them (weight
    # decay and similar) work too; plain Adam simply ignores the argument.
    params = eqx.filter(edges, eqx.is_array)
    updates, opt_state = optim.update(grads, opt_state, params)

    edges = eqx.apply_updates(edges, updates)
    return edges, nodes, opt_state


# PRNG key handed to the SlotAttention constructor below.
key = jr.PRNGKey(0)

# D is the trailing dimension shared by every node array; N, K and M are the
# leading dimensions of the "x", "z" and "m" nodes respectively.
D = 4
M = 3
K = 10
N = 32

# Every node starts out as an all-zero array.
x = jnp.zeros(shape=(N, D))
z = jnp.zeros(shape=(K, D))
m = jnp.zeros(shape=(M, D))

nodes = dict(x=x, z=z, m=m)

# Each edge pairs a module with the names of the nodes it is attached to.
edges = [
    (Hopfield(), ["z", "m"]),
    (SlotAttention(D, key), ["x", "z"]),
]

# The optimiser state is built from the array leaves of the edges only.
optim = optax.adam(learning_rate=1e-3)
trainable = eqx.filter(edges, eqx.is_array)
opt_state = optim.init(trainable)

# A single training step; all three results replace their inputs.
edges, nodes, opt_state = train_step(edges, nodes, opt_state, optim)