In [1]:
import equinox as eqx
from jax import numpy as jnp, random as jr, vmap, nn, lax

import numpy as np
from matplotlib import pyplot as plt

from pets.control import plan
from pets.model import Ensemble
from pets.dataset import Normalizer
from pets.envs.atari import AtariEnv, reward_fn

In [2]:
_tile = lambda x: jnp.tile(x[:, None, ...], (1, ensemble_dim, 1))

def forward(model, normalizer, state, action, key):
    """Sample one next state from the probabilistic dynamics model.

    The model is evaluated on the normalised ``[state, action]`` input. It
    predicts a Gaussian over the state *delta*; one sample of that delta is
    added to ``state``.

    Args:
        model: Callable returning ``(delta_mean, delta_logvar)`` for a single
            input; it is ``vmap``-ed over axis 0 of the inputs.
        normalizer: Object exposing ``normalize(inputs)``.
        state: Current states. They are concatenated with ``action`` on the
            last axis, so all other axes must match ``action``'s.
        action: Actions with the same leading shape as ``state``.
        key: PRNG key for the Gaussian sample.

    Returns:
        ``state + delta`` with ``delta ~ N(delta_mean, exp(delta_logvar))``.
    """
    inputs = normalizer.normalize(jnp.concatenate([state, action], axis=-1))
    delta_mean, delta_logvar = vmap(model)(inputs)
    # std = exp(logvar / 2) is algebraically sqrt(exp(logvar)). Here the
    # intermediate does not overflow to inf until logvar is twice as large
    # (float32 exp overflows past ~88).
    delta_std = jnp.exp(0.5 * delta_logvar)
    delta = delta_mean + delta_std * jr.normal(key, delta_mean.shape)
    return state + delta


@eqx.filter_jit
def rollout_fn(state, actions, rng_key=None):
    """Simulate the learned dynamics over an action sequence and score it.

    ``state`` and every per-step action slice are replicated on a new axis 1
    via ``_tile``. The dynamics are then stepped with ``lax.scan`` (one step per
    entry of ``actions``' leading axis), and ``reward_fn`` is applied to all
    predicted states.

    Args:
        state: Starting states; ``_tile`` inserts the replication axis at 1.
        actions: Action sequence whose leading axis is scanned over.
        rng_key: Optional PRNG key for the transition noise. When omitted, the
            notebook-level ``key`` is used, as before (see note below).

    Returns:
        ``(states, rewards)``:
            - ``states`` stacks the predicted next state of every scan step.
            - ``rewards`` is ``reward_fn(states)`` averaged over its last axis,
              with a trailing singleton axis appended. The last axis is
              presumably the ensemble axis; TODO confirm the shape
              ``reward_fn`` returns.
    """
    # NOTE: under `filter_jit`, module-level arrays read here (`key`, `model`,
    # `normalizer`) are captured as constants on the first trace. Reassigning
    # them later in the notebook does not reach an already-traced function, so
    # without `rng_key` every call reuses the same transition noise.
    scan_key = key if rng_key is None else rng_key
    tiled_state, tiled_actions = _tile(state), vmap(_tile)(actions)

    def scan_fn(carry, action):
        current, step_key = carry
        step_key, noise_key = jr.split(step_key)
        next_state = forward(model, normalizer, current, action, noise_key)
        return (next_state, step_key), next_state

    _, states = lax.scan(scan_fn, (tiled_state, scan_key), tiled_actions)
    rewards = reward_fn(states)
    return states, rewards.mean(-1)[..., None]


In [3]:
key = jr.PRNGKey(0)

env = AtariEnv("PongDeterministic-v4", render_mode="human")
state_dim = env.observation_space.shape[0]
ensemble_dim, hidden_dim, action_dim, num_steps = 5, 200, 3, 70

# Build a model skeleton with the right structure, then overwrite its weights
# from the saved checkpoint. Use the fresh `subkey` so `key` is not reused.
# `state_dim + 1`: `forward` feeds `[state, action]`, i.e. a single action
# feature is appended to the state.
key, subkey = jr.split(key)
model = Ensemble(state_dim + 1, state_dim, hidden_dim, ensemble_dim, key=subkey)
model = eqx.tree_deserialise_leaves("../data/model.eqx", model)
# NOTE(review): a `.pkl` file is presumably unpickled here — only load trusted files.
normalizer = Normalizer.load("../data/normalizer.pkl")

(state, _), total_reward = env.reset(), 0.0
for _ in range(num_steps):
    key, subkey = jr.split(key)

    probs, _, _ = plan(state, rollout_fn, action_dim, subkey)
    # Average the planned distribution over axis 1, pick the most likely action
    # at index 0 (presumably the first planning step — TODO confirm), and hand
    # gym a plain Python int rather than a JAX scalar.
    action = int(probs.mean(1).argmax(-1)[0])

    next_state, reward, done, truncated, info = env.step(action)
    total_reward += reward
    # Advance to the observed state; without this, every step replans from the
    # reset observation.
    state = next_state
    if done or truncated:
        break
print(f"reward: {total_reward}")

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


reward: 0.0


: 