In [1]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension, then (mode 2) reload all imported modules
# before each cell executes so edits to source files take effect without a
# kernel restart. The extension must be loaded before %autoreload is called.
%load_ext autoreload
%autoreload 2
# Ask the inline backend for high-resolution ("retina") figures.
%config InlineBackend.figure_format = 'retina'

In [3]:
import jax
import jax.numpy as jnp
import gymnax
from bsuite.environments import catch, deep_sea

# Catch

In [4]:
# gymnax side of the Catch comparison: `make` returns a PRNG key, the
# functional reset/step pair and the environment parameters.
rng, reset, step, env_params = gymnax.make("Catch-bsuite")

# Derive dedicated keys for reset and step; `rng` is rebound to the first split.
rng, key_reset, key_step = jax.random.split(rng, num=3)

obs, state = reset(key_reset, env_params)

# Show the dimensions of the initial observation.
obs.shape

(10, 5)

In [5]:
# Reference bsuite Catch environment for the side-by-side check against gymnax.
# A fixed seed makes the ball's starting column -- and therefore every output
# derived from this env -- identical after Restart & Run All. With the default
# (seed=None) the start position changes on every fresh kernel.
# NOTE(review): outputs saved from an earlier unseeded run may show a different
# ball column; the gymnax comparison copies its state from this env, so the
# two sides still line up.
env = catch.Catch(seed=0)
timestep = env.reset()

In [6]:
# Snapshot the bsuite env's internal fields *before* stepping it, so that the
# gymnax step function can be run afterwards from the same starting point.
state = jnp.array(
    [
        env._ball_x,
        env._ball_y,
        env._paddle_x,
        env._paddle_y,
        0,  # NOTE(review): hard-coded slot; its meaning comes from gymnax's Catch state layout -- confirm
        env._reset_next_step,
    ]
).copy()

# Advance the bsuite env by one step with a fixed action.
action = 1
timestep = env.step(action)

# Emit observation, reward and discount, one per line.
print(timestep.observation, timestep.reward, timestep.discount, sep="\n")

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 1. 0. 0.]]
0.0
1.0


In [7]:
# Run the gymnax step function from the state copied out of the bsuite env,
# using the same action, so the two implementations can be compared.
# The fifth return value is not used here; it is bound to a named variable
# rather than `_`, which IPython reserves for the previous cell's output.
# NOTE(review): presumably an info/debug structure -- confirm against gymnax.
obs_jax, state_jax, reward_jax, done_jax, extra_jax = step(
    key_step, env_params, state, action
)

# Display observation, reward and (1 - done); the latter is shown so it can be
# read against the discount printed for the bsuite step.
obs_jax, reward_jax, 1 - done_jax

(DeviceArray([[0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 1.],
              [0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0.],
              [0., 0., 1., 0., 0.]], dtype=float32),
 DeviceArray(0, dtype=int32),
 DeviceArray(1, dtype=int32))

# DeepSea

In [8]:
# gymnax side of the DeepSea comparison. This rebinds `rng`, `reset`, `step`
# and `env_params` (and, below, `obs` and `state`), replacing the Catch
# versions created earlier in the notebook.
rng, reset, step, env_params = gymnax.make("DeepSea-bsuite")

# Derive dedicated keys for reset and step; `rng` is rebound to the first split.
rng, key_reset, key_step = jax.random.split(rng, num=3)

obs, state = reset(key_reset, env_params)

# Inspect the initial environment state.
state

{'bad_episode': DeviceArray(False, dtype=bool),
 'column': DeviceArray(0, dtype=int32),
 'denoised_return': DeviceArray(0, dtype=int32),
 'optimal_no_cost': DeviceArray(1., dtype=float32),
 'optimal_return': DeviceArray(0.99, dtype=float32),
 'row': DeviceArray(0, dtype=int32),
 'terminal': DeviceArray(False, dtype=bool),
 'total_bad_episodes': DeviceArray(0, dtype=int32)}

In [9]:
# Reference bsuite DeepSea environment, configured from the gymnax
# `env_params` so both implementations are guaranteed to use the same size,
# determinism flag, move cost and action-randomisation setting.
# NOTE(review): with randomize_actions enabled bsuite draws its action mapping
# from its own RNG; whether that mapping matches gymnax's is not visible
# here -- confirm before treating output differences as bugs.
env = deep_sea.DeepSea(
    size=env_params["size"],
    deterministic=env_params["deterministic"],
    unscaled_move_cost=env_params["unscaled_move_cost"],
    randomize_actions=env_params["randomize_actions"],
)

# The result is bound to `timestep` (as in the Catch section) so the gymnax
# observation held in `obs` is not overwritten by a different kind of object.
timestep = env.reset()

In [10]:
# Advance the bsuite DeepSea env by one step with a fixed action.
action = 1
timestep = env.step(action)

# Emit observation, reward and discount, one per line.
print(timestep.observation, timestep.reward, timestep.discount, sep="\n")

[[0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0. 0. 0. 0.]]
-0.00125
1.0


In [11]:
# Display the gymnax DeepSea parameters for reference next to the bsuite
# constructor arguments.
env_params

FrozenDict({
    size: 8,
    deterministic: True,
    unscaled_move_cost: 0.01,
    randomize_actions: True,
})

In [12]:
# Run the gymnax DeepSea step from the state returned by gymnax's own reset,
# using the same action that was applied to the bsuite env.
# The fifth return value is not used here; it is bound to a named variable
# rather than `_`, which IPython reserves for the previous cell's output.
# NOTE(review): presumably an info/debug structure -- confirm against gymnax.
obs_jax, state_jax, reward_jax, done_jax, extra_jax = step(
    key_step, env_params, state, action
)

# Display observation, reward and (1 - done); the latter is shown so it can be
# read against the discount printed for the bsuite step.
# NOTE(review): in the saved run the observation here is all zeros while the
# bsuite observation has a 1 at row 1, column 1 -- possibly due to action
# randomisation or an observation bug; confirm.
obs_jax, reward_jax, 1 - done_jax

(DeviceArray([[0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.],
              [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32),
 DeviceArray([-0.00125], dtype=float32),
 DeviceArray(1, dtype=int32))