In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

In [2]:
import jax
import jax.numpy as jnp
import gymnax
from bsuite.environments import catch, deep_sea, discounting_chain, memory_chain, umbrella_chain, mnist, bandit

# Catch

In [19]:
# # 1. Create gymnax based environment
# rng, env = gymnax.make("Catch-bsuite")
# rng, key_reset, key_step, key_action = jax.random.split(rng, 4)
# obs, state = env.reset(key_reset)

# # 2. Setup bsuite-based environment, get state dict
# env_gym = catch.Catch()
# _ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, "Catch-bsuite")

# 3. Transition in both environments
action = 0
timestep = env_gym.step(jnp.array(action))
obs_gym, reward_gym, done_gym = timestep.observation, timestep.reward, 1-timestep.discount
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, "Catch-bsuite", state_jax)

In [18]:
_ = env_gym.reset()
_

TimeStep(step_type=<StepType.FIRST: 0>, reward=None, discount=None, observation=array([[1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0.]], dtype=float32))

In [20]:
obs_gym

array([[0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.]], dtype=float32)

In [21]:
obs_jax

DeviceArray([[0., 0., 0., 0., 0.],
             [1., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0.],
             [0., 1., 0., 0., 0.]], dtype=float32)

# DeepSea

In [4]:
# 1. Create gymnax based environment
rng, env = gymnax.make("DeepSea-bsuite")
rng, key_reset, key_step, key_action = jax.random.split(rng, 4)
obs, state = env.reset(key_reset)

# 2. Setup bsuite-based environment, get state dict
env_gym = deep_sea.DeepSea(size=8, randomize_actions=False)
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, "DeepSea-bsuite")

# 3. Transition in both environments
action = env.action_space.sample(key_action)
timestep = env_gym.step(jnp.array(action))
obs_gym, reward_gym, done_gym = timestep.observation, timestep.reward, 1-timestep.discount
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, "DeepSea-bsuite", state_jax)



# Discounting Chain

In [5]:
# 1. Create gymnax based environment
rng, env = gymnax.make("DiscountingChain-bsuite")
rng, key_reset, key_step, key_action = jax.random.split(rng, 4)
obs, state = env.reset(key_reset)

# 2. Setup bsuite-based environment, get state dict
env_gym = discounting_chain.DiscountingChain(mapping_seed=0)
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, "DiscountingChain-bsuite")

# 3. Transition in both environments
action = env.action_space.sample(key_action)
timestep = env_gym.step(jnp.array(action))
obs_gym, reward_gym, done_gym = timestep.observation, timestep.reward, 1-timestep.discount
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, "DiscountingChain-bsuite", state_jax)

# Memory Chain

In [6]:
# 1. Create gymnax based environment
rng, env = gymnax.make("MemoryChain-bsuite")
rng, key_reset, key_step, key_action = jax.random.split(rng, 4)
obs, state = env.reset(key_reset)

# 2. Setup bsuite-based environment, get state dict
env_gym = memory_chain.MemoryChain(memory_length=5, num_bits=1)
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, "MemoryChain-bsuite")

# 3. Transition in both environments
action = env.action_space.sample(key_action)
timestep = env_gym.step(jnp.array(action))
obs_gym, reward_gym, done_gym = timestep.observation, timestep.reward, 1-timestep.discount
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, "MemoryChain-bsuite", state_jax)

AssertionError: 

# Umbrella Chain

In [7]:
# 1. Create gymnax based environment
rng, env = gymnax.make("UmbrellaChain-bsuite")
rng, key_reset, key_step, key_action = jax.random.split(rng, 4)
obs, state = env.reset(key_reset)

# 2. Setup bsuite-based environment, get state dict
env_gym = umbrella_chain.UmbrellaChain(chain_length=10, n_distractor=0)
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, "UmbrellaChain-bsuite")

# 3. Transition in both environments
action = env.action_space.sample(key_action)
timestep = env_gym.step(jnp.array(action))
obs_gym, reward_gym, done_gym = timestep.observation, timestep.reward, 1-timestep.discount
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, "UmbrellaChain-bsuite", state_jax)

# MNIST Bandit

In [8]:
# 1. Create gymnax based environment
rng, env = gymnax.make("MNISTBandit-bsuite")
rng, key_reset, key_step, key_action = jax.random.split(rng, 4)
obs, state = env.reset(key_reset)

# 2. Setup bsuite-based environment, get state dict
env_gym = mnist.MNISTBandit()
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, "MNISTBandit-bsuite")

# 3. Transition in both environments
action = env.action_space.sample(key_action)
timestep = env_gym.step(jnp.array(action))
obs_gym, reward_gym, done_gym = timestep.observation, timestep.reward, 1-timestep.discount
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, "MNISTBandit-bsuite", state_jax)

# Simple Bandit

In [11]:
# 1. Create gymnax based environment
rng, env = gymnax.make("SimpleBandit-bsuite")
rng, key_reset, key_step, key_action = jax.random.split(rng, 4)
obs, state = env.reset(key_reset)

# 2. Setup bsuite-based environment, get state dict
env_gym = bandit.SimpleBandit()
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, "SimpleBandit-bsuite")

# 3. Transition in both environments
action = env.action_space.sample(key_action)
timestep = env_gym.step(jnp.array(action))
obs_gym, reward_gym, done_gym = timestep.observation, timestep.reward, 1-timestep.discount
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, "SimpleBandit-bsuite", state_jax)