In [1]:
import optax
import equinox as eqx
from jax import Array, numpy as jnp, random as jr, tree_util as jtu, vmap, lax

from matplotlib import pyplot as plt

from rssm import Model

In [2]:
def init_dataset(key: jr.PRNGKey, T: int, B: int) -> tuple[Array, Array]:
    t = jnp.linspace(0, 2 * jnp.pi, T)
    phases = jr.uniform(key, shape=(B,), minval=0, maxval=2 * jnp.pi)

    def per_batch(i):
        phase_i = phases[i]
        x = jnp.sin(t + phase_i)
        y = jnp.cos(t + phase_i)
        return jnp.stack([x, y], axis=-1)

    obs_seq = vmap(per_batch)(jnp.arange(B))
    action_seq = jnp.zeros((B, T, 1))
    return obs_seq, action_seq

In [None]:
@eqx.filter_jit
def train_step(params, obs_seq, action_seq, optimizer, opt_state, key):
    def loss_fn(params):
        B, T, D = obs_seq.shape
        out_seq, post_logits, prior_logits, _, _ = vmap(
            lambda o, a, k: forward(params, o, a, k)
        )(obs_seq, action_seq, jr.split(key, B))
        return mse_loss(obs_seq, out_seq) + kl_loss(post_logits, prior_logits)

    loss, grads = eqx.filter_value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = eqx.apply_updates(params, updates)
    return params, opt_state, loss


def forward(params, obs_seq, action_seq, key):
    rssm, encoder, decoder = params

    embed_seq = vmap(encoder)(obs_seq)
    init_state = rssm.init_state()
    outputs = rssm.rollout(init_state, embed_seq, action_seq, key)
    post_logits, prior_logits, deter, stoch = outputs
    out_feats = jnp.concatenate([deter, stoch], axis=-1)
    out_seq = vmap(decoder)(out_feats)

    return out_seq, post_logits, prior_logits, deter, stoch


def kl_loss(
    prior_logits: Array, post_logits: Array, free_nats: float = 1.0, alpha: float = 0.5
) -> Array:
    kl_lhs = optax.losses.kl_divergence_with_log_targets(
        lax.stop_gradient(post_logits), prior_logits
    ).sum(axis=-1)
    kl_rhs = optax.losses.kl_divergence_with_log_targets(
        post_logits, lax.stop_gradient(prior_logits)
    ).sum(axis=-1)

    kl_lhs, kl_rhs = jnp.mean(kl_lhs), jnp.mean(kl_rhs)
    if free_nats > 0.0:
        kl_lhs = jnp.maximum(kl_lhs, free_nats)
        kl_rhs = jnp.maximum(kl_rhs, free_nats)
    return (alpha * kl_lhs) + ((1 - alpha) * kl_rhs)


def mse_loss(out_seq: Array, obs_seq: Array) -> Array:
    return jnp.mean(jnp.sum((out_seq - obs_seq) ** 2, axis=-1))

In [4]:
num_epochs = 500

key = jr.PRNGKey(0)
keys = jr.split(key, 5)

rssm = Model(
    embed_dim=64,
    action_dim=1,
    deter_dim=200,
    num_discrete=16,
    num_classes=16,
    hidden_dim=200,
    key=keys[0],
)

encoder = eqx.nn.MLP(in_size=2, out_size=64, width_size=200, depth=2, key=keys[1])
decoder = eqx.nn.MLP(
    in_size=200 + (16 * 16), out_size=2, width_size=200, depth=2, key=keys[2]
)
params = (rssm, encoder, decoder)

optimizer = optax.chain(optax.clip_by_global_norm(1.0), optax.adam(1e-3))
opt_state = optimizer.init(eqx.filter(params, eqx.is_array))

obs_seq, action_seq = init_dataset(keys[4], T=50, B=50)

key = keys[5]
for epoch in range(num_epochs):
    key, train_key = jr.split(key)
    params, opt_state, loss = train_step(
        params, obs_seq, action_seq, optimizer, opt_state, train_key
    )
    if epoch % 100 == 0:
        print(f"Epoch {epoch}: {loss:.4f}")

Epoch 0: 83.5334
Epoch 100: 1.9986
Epoch 200: 1.9962
Epoch 300: 2.0045
Epoch 400: 1.8640


In [None]:
def rollout(params, obs_seq, action_seq, key):
    keys = jr.split(key, 2)
    rssm, encoder, decoder = params

    outputs = forward(params, obs_seq, action_seq, keys[0])
    out_seq, _, _, deter, stoch = outputs

    outputs = rssm.rollout_prior((stoch[-1], deter[-1]), action_seq, keys[1])
    prior_logits, deter, stoch = outputs

    out_feats = jnp.concatenate([deter, stoch], axis=-1)
    rollout_seq = vmap(decoder)(out_feats)
    return out_seq, rollout_seq


post_seq, rollout_seq = rollout(params, obs_seq[0], action_seq[0], key)

plt.figure(figsize=(6, 6))
plt.plot(
    obs_seq[0][:, 0],
    obs_seq[0][:, 1],
    label="GT Traj",
    marker="o",
    color="orange",
)
plt.plot(
    post_seq[:, 0],
    post_seq[:, 1],
    label="Post Traj",
    marker="x",
    color="red",
)
plt.plot(
    rollout_seq[:, 0],
    rollout_seq[:, 1],
    label="Rollout Traj",
    marker="x",
    color="green",
)
plt.legend()
plt.xlabel("x (sine)")
plt.ylabel("y (cosine)")
plt.axis("equal")
plt.grid(True)
plt.tight_layout()
plt.show()

NameError: name 'rollout' is not defined