In [1]:
import equinox as eqx
from jax import numpy as jnp, random as jr, vmap, nn, lax

import numpy as np
from matplotlib import pyplot as plt

from pets.control import plan
from pets.model import Ensemble
from pets.dataset import Dataset
from pets.envs.atari import AtariEnv, reward_fn

In [2]:
def rollout_fn(state, action):
    """Roll the learned ensemble forward and score the imagined trajectories.

    Passed to ``plan`` as its rollout callback.

    Relies on notebook globals that are read at call time: ``model`` (the
    loaded ``Ensemble``), ``stats`` (from ``dataset.stats()``) and ``key``
    (the current PRNG key).

    Args:
        state: starting state handed over by the planner.
        action: candidate action sequence(s) handed over by the planner.

    Returns:
        A tuple ``(states, rewards)``:
        - ``states`` is whatever ``model.rollout`` produced.
        - ``rewards`` is ``reward_fn(states)`` averaged over its last axis,
          with a trailing singleton axis appended.
        NOTE(review): the last axis is presumably the ensemble dimension;
        confirm against ``reward_fn``'s output layout.
    """
    predicted_states = model.rollout(state, action, stats, key)
    per_member_reward = reward_fn(predicted_states)
    mean_reward = per_member_reward.mean(-1)
    # Keep a trailing axis of size 1 (jnp.newaxis is None), as the caller expects.
    return predicted_states, mean_reward[..., jnp.newaxis]


def act(state):
    """Choose an action for ``state`` by planning through the learned model.

    Calls ``plan`` with the module-level ``rollout_fn``, ``action_dim`` and the
    current global ``key``. ``plan`` returns three values, of which only the
    first (action probabilities) is used.

    The probabilities are averaged over axis 1. The arg-max is then taken over
    the last axis, and the first entry is returned.
    NOTE(review): presumably axis 1 indexes planner samples and the first
    entry is the first step of the plan (MPC-style); confirm against ``plan``.

    Returns:
        The selected action index as produced by ``argmax`` (a 0-d array).
    """
    action_probs, _unused_a, _unused_b = plan(state, rollout_fn, action_dim, key)
    averaged_probs = action_probs.mean(1)
    best_per_row = averaged_probs.argmax(-1)
    return best_per_row[0]

In [3]:
# --- experiment parameters ---
key = jr.PRNGKey(1)

num_steps = 80  # maximum number of environment steps; the loop stops early when the episode ends
ensemble_dim, hidden_dim = 5, 200  # must match the architecture stored in model.eqx
data_dir = "../data"

# --- environment ---
env = AtariEnv("PongDeterministic-v4", render_mode="human")
# NOTE(review): 3 is presumably a reduced action set that AtariEnv / reward_fn expect; confirm.
state_dim, action_dim = env.observation_space.shape[0], 3

# --- model ---
# Build a template with the right structure, then load the trained leaves into it.
# NOTE(review): the input width state_dim + 1 presumably means state + one action value; confirm.
model = Ensemble(state_dim + 1, state_dim, hidden_dim, ensemble_dim, key=key)
model = eqx.tree_deserialise_leaves(f"{data_dir}/model.eqx", model)

# --- dataset ---
# `stats` is consumed by model.rollout inside rollout_fn (presumably normalisation statistics).
# NOTE(review): a .pkl file is presumably unpickled; only load trusted files.
dataset = Dataset.load(f"{data_dir}/dataset.pkl")
stats = dataset.stats()

# --- test model: closed-loop planning in the real environment ---
try:
    (state, _), total_reward = env.reset(), 0.0
    for _ in range(num_steps):
        # act() and rollout_fn() read the global `key` when they run. Advance it
        # every step so the planner does not reuse identical randomness (the
        # same key that also initialised the model) at every step.
        key = jr.split(key)[0]
        action = act(state)
        state, reward, done, truncated, info = env.step(action)
        total_reward += reward
        if done or truncated:
            # Stepping an environment whose episode has ended is undefined.
            break
finally:
    # Always release the render window, even if planning or stepping raises.
    env.close()

print(f"reward: {total_reward}")

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


[109.  22.   0.  60. 109.  22.   0.  60.]
[109.  16.   0.   0.   0.  -6.   0. -60.]
[109.   8.   0.   0.   0.  -8.   0.   0.]
[109.   2.   0.   0.   0.  -6.   0.   0.]
[109.   0.   0.   0.   0.  -2.   0.   0.]
[109.   0.   0.   0.   0.   0.   0.   0.]
[109.   0.   0.   0.   0.   0.   0.   0.]
[109.   0.   0.   0.   0.   0.   0.   0.]
[109.   0.   0.   0.   0.   0.   0.   0.]
[109.   0.   0.   0.   0.   0.   0.   0.]
[109.   0.   0.   0.   0.   0.   0.   0.]
[109.   0.   0.   0.   0.   0.   0.   0.]
[109.   0.   0.   0.   0.   0.   0.   0.]
[109.   0.   0.   0.   0.   0.   0.   0.]
[109.   0.   0.   0.   0.   0.   0.   0.]
[109. 130. 126. 130.   0. 130. 126. 130.]
[109. 134. 122. 134.   0.   4.  -4.   4.]
[109. 138. 118. 138.   0.   4.  -4.   4.]
[109. 142. 114. 142.   0.   4.  -4.   4.]
[109. 146. 110. 146.   0.   4.  -4.   4.]
[109. 150. 106. 150.   0.   4.  -4.   4.]
[109. 154. 102. 154.   0.   4.  -4.   4.]
[109. 158.  98. 158.   0.   4.  -4.   4.]
[109. 162.  94. 162.   0.   4.  -4

: 