In [104]:
import jax
import jumanji
from jumanji.wrappers import AutoResetWrapper
import numpy as np

In [20]:
env = jumanji.make("Maze-v0")  # Create a Snake environment
env = AutoResetWrapper(env)     # Automatically reset the environment when an episode terminates

In [270]:
batch_size = 1
rollout_length = 1000
num_actions = env.action_spec.num_values

In [271]:
random_key = jax.random.PRNGKey(0)
key1, key2 = jax.random.split(random_key)

In [272]:
def step_fn(state, key):
  action = jax.random.randint(key=key, minval=0, maxval=num_actions, shape=())
  new_state, timestep = env.step(state, action)
  return new_state, {"state": state.agent_position, "action": action, "next_state": new_state.agent_position, "reward": timestep.reward, "env_info": timestep}

def run_n_steps(state, key, n):
  random_keys = jax.random.split(key, n)
  state, rollout = jax.lax.scan(step_fn, state, random_keys)
  return rollout

# Instantiate a batch of environment states
keys = jax.random.split(key1, batch_size)
state, timestep = jax.vmap(env.reset)(keys)

# Collect a batch of rollouts
keys = jax.random.split(key2, batch_size)
rollout = jax.vmap(run_n_steps, in_axes=(0, 0, None))(state, keys, rollout_length)

In [273]:
rollout["reward"]

Array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [245]:
j = 0
states = [np.array([rollout["timestep"].observation.agent_position.row[j][i], rollout["timestep"].observation.agent_position.col[j][i]]) for i in range(0, rollout_length-1)]
next_states = [np.array([rollout["timestep"].observation.agent_position.row[j][i], rollout["timestep"].observation.agent_position.col[j][i]]) for i in range(1, rollout_length)]
rewards = [np.array(rollout["timestep"].reward[j][i]) for i in range(1, rollout_length)]
actions = [rollout["actions"][j][i] for i in range(1, rollout_length)]
target_position = [np.array([rollout["timestep"].observation.target_position.row[j][i], rollout["timestep"].observation.target_position.col[j][i]]) for i in range(0, rollout_length-1)]

In [None]:
for k in range(len(states)):
    print(states[k], actions[k], next_states[k], rollout["timestep"].observation.action_mask[j][k], target_position[k], rewards[k], (target_position[k] == next_states[k]).all())

[3 0] 2 [4 0] [ True False  True False] [9 8] 0.0 False
[4 0] 0 [3 0] [ True  True  True False] [9 8] 0.0 False
[3 0] 3 [3 0] [ True False  True False] [9 8] 0.0 False
[3 0] 2 [4 0] [ True False  True False] [9 8] 0.0 False
[4 0] 2 [5 0] [ True  True  True False] [9 8] 0.0 False
[5 0] 1 [5 0] [ True False  True False] [9 8] 0.0 False
[5 0] 2 [6 0] [ True False  True False] [9 8] 0.0 False
[6 0] 1 [6 0] [ True False  True False] [9 8] 0.0 False
[6 0] 1 [6 0] [ True False  True False] [9 8] 0.0 False
[6 0] 2 [7 0] [ True False  True False] [9 8] 0.0 False
[7 0] 0 [6 0] [ True False  True False] [9 8] 0.0 False
[6 0] 2 [7 0] [ True False  True False] [9 8] 0.0 False
[7 0] 3 [7 0] [ True False  True False] [9 8] 0.0 False
[7 0] 2 [8 0] [ True False  True False] [9 8] 0.0 False
[8 0] 3 [8 0] [ True  True False False] [9 8] 0.0 False
[8 0] 3 [8 0] [ True  True False False] [9 8] 0.0 False
[8 0] 1 [8 1] [ True  True False False] [9 8] 0.0 False
[8 1] 3 [8 0] [False  True False  True] [9 8] 0.

array([[False,  True],
       [False,  True],
       [False,  True],
       ...,
       [False, False],
       [False, False],
       [False, False]])

In [170]:
rewards = [rollout["timestep"].reward[j][i] for i in range(rollout_length)]

In [171]:
rewards

[Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 Array(0., dtype=float32),
 