In [1]:
import optax
import equinox as eqx
from jax import Array, numpy as jnp, random as jr

from rssm import Model, State, rssm_loss

In [2]:
@eqx.filter_jit
def train_step(model, obs_seq, action_seq, optimizer, opt_state, key):
    """Apply one optimizer update to ``model`` using the summed RSSM loss.

    Args:
        model: Equinox module being trained; differentiated with
            ``eqx.filter_value_and_grad``.
        obs_seq: Observation batch, forwarded unchanged to ``rssm_loss``.
        action_seq: Action batch, forwarded unchanged to ``rssm_loss``.
        optimizer: Object exposing ``update(grads, opt_state, params)``
            (an optax gradient transformation in this notebook).
        opt_state: Optimizer state that belongs to ``optimizer``.
        key: PRNG key, forwarded unchanged to ``rssm_loss``.

    Returns:
        A ``(model, opt_state, loss)`` tuple: the updated model, the updated
        optimizer state, and the loss value computed *before* the update.
    """

    def loss_fn(model):
        # rssm_loss yields two terms (presumably an observation term and a
        # KL term -- confirm in rssm.py); they are added with equal weight.
        o_loss, kl_loss = rssm_loss(model, obs_seq, action_seq, key)
        return o_loss + kl_loss

    loss, grads = eqx.filter_value_and_grad(loss_fn)(model)
    # Give the optimizer only the array leaves of the model. This matches the
    # tree the optimizer state is initialised from with
    # ``optimizer.init(eqx.filter(model, eqx.is_array))`` and keeps non-array
    # leaves (e.g. activation functions) away from transformations that read
    # ``params``, such as weight decay.
    params = eqx.filter(model, eqx.is_array)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss


def dataset(
    key: Array, T: int, B: int, obs_dim: int, action_dim: int
) -> tuple[Array, Array]:
    """Draw a synthetic batch of observation and action sequences.

    Both arrays are filled with independent standard-normal samples; the
    input key is split so the two arrays use different subkeys.

    Args:
        key: JAX PRNG key used as the source of randomness.
        T: Size of the second axis of both outputs.
        B: Size of the leading axis of both outputs.
        obs_dim: Size of the last axis of the observation array.
        action_dim: Size of the last axis of the action array.

    Returns:
        ``(obs_seq, action_seq)`` with shapes ``(B, T, obs_dim)`` and
        ``(B, T, action_dim)`` respectively.
    """
    key_obs, key_act = jr.split(key, 2)
    obs_seq = jr.normal(key_obs, (B, T, obs_dim))
    action_seq = jr.normal(key_act, (B, T, action_dim))
    return obs_seq, action_seq

In [None]:
# Number of optimisation steps; every step reuses the same fixed batch.
epochs = 2000

# Data dimensions: dataset() builds arrays of shape (B, T, feature_dim).
T = 20
B = 16
obs_size = 8
action_size = 1

# Model sizes, passed straight through to the Model constructor.
stoch_size = 30
deter_size = 200
embed_size = 200
mlp_hidden_size = 200

# Root PRNG key. A fresh subkey is split off for every consumer below so that
# no key is ever used twice.
key = jr.PRNGKey(0)

# Synthetic training batch.
key, subkey = jr.split(key)
obs_seq, action_seq = dataset(subkey, T, B, obs_size, action_size)

# Model initialisation.
key, subkey = jr.split(key)
model = Model(
    key=subkey,
    obs_size=obs_size,
    action_size=action_size,
    stoch_size=stoch_size,
    deter_size=deter_size,
    embed_size=embed_size,
    mlp_hidden_size=mlp_hidden_size,
)

# The optimizer state is built from the model's array leaves only.
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

# Training loop: one gradient step per iteration, logging every 100th step.
for epoch in range(epochs):
    key, subkey = jr.split(key)
    model, opt_state, loss = train_step(
        model,
        obs_seq,
        action_seq,
        optimizer,
        opt_state,
        subkey,
    )
    if epoch % 100 == 0:
        print(f"{epoch}: {loss:.5f}")

0: 9.42690
100: 7.44335
200: 6.57504
300: 5.15820
400: 3.34989
500: 2.03363
600: 1.33689
700: 0.99205
800: 0.86378
900: 0.69207
1000: 0.63396
1100: 0.56418
1200: 0.52629
1300: 0.48196
1400: 0.54422
1500: 0.43645
1600: 0.48892
1700: 0.38476
1800: 0.41599
1900: 0.38261
