In [1]:
# Notebook setup.
# - Render matplotlib figures inline, at high ("retina") resolution.
# - autoreload 2: re-import changed modules before each cell runs, presumably
#   so edits to a local gymnax checkout are picked up without a kernel restart.
# NOTE(review): no visible cell plots anything; the two matplotlib lines look unused.
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import gym
import jax
import jax.numpy as jnp
import numpy as np

import gymnax

In [2]:
# NOTE(review): scratch cell. Every line is commented out, so running it does nothing.
# If revived, it must run *after* one of the environment cells below: it reads
# `env`, `rng` and `state`, which are not defined at this point in the notebook.
# Either move it below the environment checks or delete it.
# Idea sketched here: jit `env.step`, then vmap it over 10 PRNG keys with a
# shared `state` and action 0 (in_axes=(0, None, None)).
# jit_step = jax.jit(env.step)
# vmap_step = jax.vmap(jit_step, in_axes=(0, None, None))(jax.random.split(rng, 10), state, 0)
# vmap_step
# env.name
# env.update_env_params("gravity", 10)
# env.params

In [9]:
# Lower bounds of the Acrobot-v1 observation space (the array shown below):
# four entries bounded at -1, followed by the two velocity limits (-4*pi, -9*pi).
# The gym env is built inline instead of reusing `env_gym`, which is only
# defined by later cells; depending on it breaks "Restart Kernel -> Run All".
gym.make("Acrobot-v1").observation_space.low

array([ -1.      ,  -1.      ,  -1.      ,  -1.      , -12.566371,
       -28.274334], dtype=float32)

# Pendulum-v0
### 2D State Space, 3D Obs Space, 1D Action Space [Continuous - Torque]


In [3]:
# Pendulum-v0: step the gymnax port and the reference gym env once from the
# same starting state, then check that their outputs and next states agree.
env_name = "Pendulum-v0"

# 1. Create gymnax based environment. The reset is only exercised here; its
#    state is replaced in step 2 by the state taken from gym.
rng, env = gymnax.make(env_name)
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = env.reset(key_reset)

# 2. Setup gym-based environment, then convert its current state into a
#    gymnax state so both environments start from the same point.
env_gym = gym.make(env_name)
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, env_name)

# 3. Transition in both environments.
#    The continuous gym env takes a length-1 action vector. Build it with NumPy,
#    in the action space's own dtype, so the reference implementation receives
#    ordinary NumPy input instead of a JAX int32 array.
action = 1
action_gym = np.asarray([action], dtype=env_gym.action_space.dtype)
obs_gym, reward_gym, done_gym, _ = env_gym.step(action_gym)
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, env_name, state_jax)



# CartPole-v0
#### 4D State = Obs Space, 1D Action Space [Discrete - Left, Right]

In [4]:
# CartPole-v0: run one step in gymnax and in reference gym from an identical
# state, and require matching transition outputs and next states.
env_name = "CartPole-v0"
action = 1  # discrete action; index 1 = push right

# gymnax side: build the env, split off keys, exercise reset
# (the state it returns is overwritten just below)
rng, env = gymnax.make(env_name)
rng, key_reset, key_step = jax.random.split(rng, num=3)
obs, state = env.reset(key_reset)

# gym side: start a fresh episode, then mirror gym's state into gymnax
env_gym = gym.make(env_name)
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, env_name)

# Advance both environments with the same discrete action
obs_gym, reward_gym, done_gym, _ = env_gym.step(action)
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# Do observation, reward and done flag agree?
gymnax.utils.assert_correct_transit(
    obs_gym, reward_gym, done_gym,
    obs_jax, reward_jax, done_jax,
)
# Do the post-step internal states agree?
gymnax.utils.assert_correct_state(env_gym, env_name, state_jax)

# MountainCar-v0
#### 2D State = Obs Space, 1D Action Space [Discrete - Acc. Left, No Acc., Acc. Right]

In [5]:
# MountainCar-v0: run one step in gymnax and in reference gym from an identical
# state, and require matching transition outputs and next states.
env_name = "MountainCar-v0"
action = 1  # discrete action; index 1 = no acceleration

# gymnax side: build the env, split off keys, exercise reset
# (the state it returns is overwritten just below)
rng, env = gymnax.make(env_name)
rng, key_reset, key_step = jax.random.split(rng, num=3)
obs, state = env.reset(key_reset)

# gym side: start a fresh episode, then mirror gym's state into gymnax
env_gym = gym.make(env_name)
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, env_name)

# Advance both environments with the same discrete action
obs_gym, reward_gym, done_gym, _ = env_gym.step(action)
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# Do observation, reward and done flag agree?
gymnax.utils.assert_correct_transit(
    obs_gym, reward_gym, done_gym,
    obs_jax, reward_jax, done_jax,
)
# Do the post-step internal states agree?
gymnax.utils.assert_correct_state(env_gym, env_name, state_jax)

# MountainCarContinuous-v0
#### 2D State = Obs Space, 1D Action Space [Cont.: -1, 1]

In [6]:
# MountainCarContinuous-v0: step the gymnax port and the reference gym env once
# from the same starting state, then check that outputs and next states agree.
env_name = "MountainCarContinuous-v0"

# 1. Create gymnax based environment. The reset is only exercised here; its
#    state is replaced in step 2 by the state taken from gym.
rng, env = gymnax.make(env_name)
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = env.reset(key_reset)

# 2. Setup gym-based environment, then convert its current state into a
#    gymnax state so both environments start from the same point.
env_gym = gym.make(env_name)
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, env_name)

# 3. Transition in both environments.
#    The continuous gym env takes a length-1 action vector. Build it with NumPy,
#    in the action space's own dtype. Otherwise `action[0]` would be a JAX
#    scalar, and (presumably) part of gym's reference update would be evaluated
#    by JAX in float32 instead of by NumPy. That would weaken the comparison.
action = 1
action_gym = np.asarray([action], dtype=env_gym.action_space.dtype)
obs_gym, reward_gym, done_gym, _ = env_gym.step(action_gym)
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# 4a. Check correctness of transition
gymnax.utils.assert_correct_transit(obs_gym, reward_gym, done_gym,
                                    obs_jax, reward_jax, done_jax)

# 4b. Check that post-transition states are equal
gymnax.utils.assert_correct_state(env_gym, env_name, state_jax)

# Acrobot-v1
#### 4D State, 6D Obs Space, 1D Action Space [Discrete - Torque -1, 0, +1]

In [11]:
# Acrobot-v1: run one step in gymnax and in reference gym from an identical
# state, and require matching transition outputs and next states.
env_name = "Acrobot-v1"
action = 1  # discrete action index; in gym's Acrobot, presumably 1 = zero torque

# gymnax side: build the env, split off keys, exercise reset
# (the state it returns is overwritten just below)
rng, env = gymnax.make(env_name)
rng, key_reset, key_step = jax.random.split(rng, num=3)
obs, state = env.reset(key_reset)

# gym side: start a fresh episode, then mirror gym's state into gymnax
env_gym = gym.make(env_name)
_ = env_gym.reset()
state = gymnax.utils.np_state_to_jax(env_gym, env_name)

# Advance both environments with the same discrete action
obs_gym, reward_gym, done_gym, _ = env_gym.step(action)
obs_jax, state_jax, reward_jax, done_jax, _ = env.step(key_step, state, action)

# Do observation, reward and done flag agree?
gymnax.utils.assert_correct_transit(
    obs_gym, reward_gym, done_gym,
    obs_jax, reward_jax, done_jax,
)
# Do the post-step internal states agree?
gymnax.utils.assert_correct_state(env_gym, env_name, state_jax)