In [1]:
import jax
import jax.numpy as jnp
import gym
import gymnax



# Pendulum-v0

In [2]:
# Pendulum-v0 via gymnax.
# 2D State Space, 3D Obs Space, 1D Action Space [Continuous - Torque]
# gymnax.make returns a PRNG key together with the functional reset/step
# pair and the environment parameters.
rng, reset, step, env_params = gymnax.make(
    "Pendulum-v0"
)
# Split the key three ways: the first piece replaces `rng`, the other two are
# dedicated to the reset call and the step call respectively.
rng, key_reset, key_step = jax.random.split(rng, num=3)
# Initial observation and state from the JAX environment.
(obs, state) = reset(key_reset, env_params)

In [3]:
# Reference transition from the original Gym implementation of Pendulum-v0.
env = gym.make("Pendulum-v0")
obs = env.reset()
# Snapshot the pre-step simulator state so the JAX step can be started from
# exactly the same point. `env.state[:]` on a NumPy array is only a view that
# shares memory with the environment's array; an explicit copy stays valid
# even if the environment were to update its state in place during `step`.
state = env.state.copy()

# Single-element torque command, reused unchanged for the JAX step.
action = jnp.array([1])
obs_gym, reward_gym, done_gym, _ = env.step(action)
# Display the Gym transition for side-by-side comparison.
obs_gym, reward_gym, done_gym

(array([-0.78571081, -0.61859398,  0.21487095]), -6.205431164909779, False)

In [4]:
# Replay the same transition with the gymnax step function, starting from the
# state captured from the Gym environment and using the same action.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
# Show the JAX results; they should agree with the Gym values up to
# floating-point precision.
(obs_jax, reward_jax, done_jax)

(DeviceArray([-0.78571075, -0.61859405,  0.21487087], dtype=float32),
 DeviceArray(-6.205431, dtype=float32),
 DeviceArray(False, dtype=bool))

# CartPole-v0

In [5]:
# CartPole-v0 via gymnax.
# 4D State = Obs Space, 1D Action Space [Discrete - Left, Right]
# gymnax.make returns a PRNG key together with the functional reset/step
# pair and the environment parameters.
rng, reset, step, env_params = gymnax.make(
    "CartPole-v0"
)
# Three-way key split: carry-over key, reset key, step key.
rng, key_reset, key_step = jax.random.split(rng, num=3)
# Initial observation and state from the JAX environment.
(obs, state) = reset(key_reset, env_params)

In [6]:
# Reference transition from the original Gym implementation of CartPole-v0.
env = gym.make("CartPole-v0")
obs = env.reset()
# Build the state handed to the JAX step from the Gym observation plus one
# trailing 0.
# NOTE(review): the extra slot is presumably a step counter kept in the
# gymnax state -- confirm against the gymnax CartPole implementation.
state = jnp.hstack((obs, 0))

# Discrete action index, reused unchanged for the JAX step.
action = 0
obs_gym, reward_gym, done_gym, _ = env.step(action)
# Display the Gym transition for side-by-side comparison.
(obs_gym, reward_gym, done_gym)

(array([-0.01735631, -0.23905205,  0.02145337,  0.29879691]), 1.0, False)

In [7]:
# Replay the same CartPole transition with the gymnax step function, from the
# state assembled out of the Gym observation and with the same action.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
# Show the JAX results next to the Gym values above.
(obs_jax, reward_jax, done_jax)

(DeviceArray([-0.01735631, -0.23905204,  0.02145337,  0.2987969 ], dtype=float32),
 DeviceArray(1., dtype=float32),
 DeviceArray(False, dtype=bool))

# MountainCar-v0

In [8]:
# MountainCar-v0 via gymnax.
# 2D State = Obs Space, 1D Action Space [Discrete - Acc. Left, No Acc., Acc. Right]
# gymnax.make returns a PRNG key together with the functional reset/step
# pair and the environment parameters.
rng, reset, step, env_params = gymnax.make(
    "MountainCar-v0"
)
# Three-way key split: carry-over key, reset key, step key.
rng, key_reset, key_step = jax.random.split(rng, num=3)
# Initial observation and state from the JAX environment.
(obs, state) = reset(key_reset, env_params)

In [9]:
# Reference transition from the original Gym implementation of MountainCar-v0.
env = gym.make("MountainCar-v0")
obs = env.reset()
# Build the state handed to the JAX step from the Gym observation plus one
# trailing 0.
# NOTE(review): the extra slot is presumably a step counter kept in the
# gymnax state -- confirm against the gymnax MountainCar implementation.
state = jnp.hstack((obs, 0))

# Discrete action index, reused unchanged for the JAX step.
action = 0
obs_gym, reward_gym, done_gym, _ = env.step(action)
# Display the Gym transition for side-by-side comparison.
(obs_gym, reward_gym, done_gym)

(array([-0.42897494, -0.00171251]), -1.0, False)

In [10]:
# Replay the same MountainCar transition with the gymnax step function, from
# the state assembled out of the Gym observation and with the same action.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
# Show the JAX results next to the Gym values above.
(obs_jax, reward_jax, done_jax)

(DeviceArray([-0.42897493, -0.00171251], dtype=float32),
 DeviceArray(-1., dtype=float32),
 DeviceArray(False, dtype=bool))

# MountainCarContinuous-v0

In [11]:
# MountainCarContinuous-v0 via gymnax.
# 2D State = Obs Space, 1D Action Space [Cont.: -1, 1]
# gymnax.make returns a PRNG key together with the functional reset/step
# pair and the environment parameters.
rng, reset, step, env_params = gymnax.make(
    "MountainCarContinuous-v0"
)
# Three-way key split: carry-over key, reset key, step key.
rng, key_reset, key_step = jax.random.split(rng, num=3)
# Initial observation and state from the JAX environment.
(obs, state) = reset(key_reset, env_params)

In [12]:
# Reference transition from the original Gym implementation of
# MountainCarContinuous-v0.
env = gym.make("MountainCarContinuous-v0")
obs = env.reset()
# Build the state handed to the JAX step from the Gym observation plus one
# trailing 0.
# NOTE(review): the extra slot is presumably a step counter kept in the
# gymnax state -- confirm against the gymnax implementation.
state = jnp.hstack((obs, 0))

# Single-element continuous action, reused unchanged for the JAX step.
action = jnp.array([0.5])
obs_gym, reward_gym, done_gym, _ = env.step(action)
# Display the Gym transition for side-by-side comparison.
(obs_gym, reward_gym, done_gym)

(array([-0.57712346,  0.00115829], dtype=float32), -0.025, False)

In [13]:
# Replay the same MountainCarContinuous transition with the gymnax step
# function, from the state assembled out of the Gym observation and with the
# same action.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
# NOTE(review): the recorded output shows reward and done as length-1 arrays
# here, whereas Gym returns scalars -- presumably the (1,)-shaped action is
# broadcast through the JAX step; confirm whether that is intended.
(obs_jax, reward_jax, done_jax)

(DeviceArray([-0.57712346,  0.00115829], dtype=float32),
 DeviceArray([-0.025], dtype=float32),
 DeviceArray([False], dtype=bool))

# Acrobot-v1

In [14]:
# Acrobot-v1 via gymnax.
# 6D Obs Space, 1D Action Space [Discrete - integer action index]
# The simulator state is not the observation: Gym documents a 4D state
# (two joint angles and their angular velocities) from which the 6D
# observation is derived, and a Discrete(3) action space selecting the torque.
# gymnax.make returns a PRNG key together with the functional reset/step
# pair and the environment parameters.
rng, reset, step, env_params = gymnax.make("Acrobot-v1")
# Three-way key split: carry-over key, reset key, step key.
rng, key_reset, key_step = jax.random.split(rng, 3)
# Initial observation and state from the JAX environment.
obs, state = reset(key_reset, env_params)

In [15]:
# Reference transition from the original Gym implementation of Acrobot-v1.
env = gym.make("Acrobot-v1")
obs = env.reset()
# Snapshot the pre-step simulator state so the JAX step can be started from
# exactly the same point. `env.state[:]` on a NumPy array is only a view that
# shares memory with the environment's array; an explicit copy stays valid
# even if the environment were to update its state in place during `step`.
state = env.state.copy()

# Discrete action index, reused unchanged for the JAX step.
action = 0
obs_gym, reward_gym, done_gym, _ = env.step(action)
# Display the Gym transition for side-by-side comparison.
obs_gym, reward_gym, done_gym

(array([ 0.99984468,  0.01762455,  0.99666636, -0.08158527,  0.11572381,
        -0.38668822]), -1.0, False)

In [16]:
# Replay the same Acrobot transition with the gymnax step function, starting
# from the state captured from the Gym environment and using the same action.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
# Show the JAX results next to the Gym values above.
(obs_jax, reward_jax, done_jax)

(DeviceArray([ 0.9998447 ,  0.01762455,  0.9966664 , -0.08158528,
               0.11572383, -0.3866882 ], dtype=float32),
 DeviceArray(-1., dtype=float32),
 DeviceArray(False, dtype=bool))