In [1]:
import jax
import jax.numpy as jnp
import gym
import gymnax

# Pendulum-v0

In [2]:
# 2D State Space, 3D Obs Space, 1D Action Space [Continuous - Torque]
# gymnax.make is unpacked into four values: `rng` (used below as a JAX PRNG
# key), the `reset` and `step` callables, and `env_params`, which is passed
# to both callables.
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
# One split yields a new carry key plus dedicated keys for reset and step.
rng, key_reset, key_step = jax.random.split(rng, 3)
# NOTE(review): `obs` and `state` are rebound from the gym env in the next
# cell, so the values produced here are not what the later step call uses.
obs, state = reset(key_reset, env_params)



In [3]:
# Reference run: take one step in the OpenAI gym version of the environment.
env = gym.make("Pendulum-v0")
# Rebinds `obs` (previously set by the gymnax reset) to gym's initial observation.
obs = env.reset()
# Grab gym's internal state before stepping; it is passed to the gymnax
# `step` call in the next cell so both implementations start from the same point.
# NOTE(review): if env.state is a NumPy array, `[:]` is a view, not a copy —
# confirm gym does not mutate it in place during step.
state = env.state[:]

# The same action is reused for the gymnax step in the next cell.
action = jnp.array([1])
obs_gym, reward_gym, done_gym, _ = env.step(action)
# Bare tuple as the last expression so the cell displays the gym results.
obs_gym, reward_gym, done_gym

(array([-0.87901473,  0.47679461,  1.49043396]), -6.69361665024985, False)

In [4]:
# Step the gymnax environment with the state captured from gym before its
# step and the same action, so the displayed tuple can be compared with the
# gym output of the previous cell.
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Displayed for side-by-side comparison with (obs_gym, reward_gym, done_gym).
obs_jax, reward_jax, done_jax

(DeviceArray([-0.8790148,  0.4767945,  1.4904337], dtype=float32),
 DeviceArray(-6.693618, dtype=float32),
 DeviceArray(False, dtype=bool))

# CartPole-v0

In [5]:
# 4D State = Obs Space, 1D Action Space [Discrete - Left, Right]
# Rebinds rng/reset/step/env_params to the CartPole environment; the Pendulum
# bindings from the earlier cells are replaced.
rng, reset, step, env_params = gymnax.make("CartPole-v0")
# One split yields a new carry key plus dedicated keys for reset and step.
rng, key_reset, key_step = jax.random.split(rng, 3)
# NOTE(review): `obs` and `state` are rebound from the gym env in the next
# cell, so the values produced here are not what the later step call uses.
obs, state = reset(key_reset, env_params)

In [6]:
# Reference run: take one step in the OpenAI gym version of the environment.
env = gym.make("CartPole-v0")
# Rebinds `obs` (previously set by the gymnax reset) to gym's initial observation.
obs = env.reset()
# Build the state handed to gymnax's `step`: gym's observation with a 0 appended.
# NOTE(review): presumably the extra entry is a step/time counter expected by
# gymnax — TODO confirm against the gymnax CartPole state layout.
state = jnp.hstack([obs, 0])

# The same discrete action is reused for the gymnax step in the next cell.
action = 0
obs_gym, reward_gym, done_gym, _ = env.step(action)
# Bare tuple as the last expression so the cell displays the gym results.
obs_gym, reward_gym, done_gym

(array([ 0.01243853, -0.23243149, -0.02471123,  0.27305714]), 1.0, False)

In [7]:
# Step the gymnax environment from the state built out of gym's initial
# observation, using the same action, so the displayed tuple can be compared
# with the gym output of the previous cell.
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Displayed for side-by-side comparison with (obs_gym, reward_gym, done_gym).
obs_jax, reward_jax, done_jax

(DeviceArray([ 0.01243853, -0.23243147, -0.02471123,  0.2730571 ], dtype=float32),
 DeviceArray(1., dtype=float32),
 DeviceArray(False, dtype=bool))

# MountainCar-v0

In [8]:
# 2D State = Obs Space, 1D Action Space [Discrete - Acc. Left, No Acc., Acc. Right]
# Rebinds rng/reset/step/env_params to the MountainCar environment; the
# bindings from the earlier cells are replaced.
rng, reset, step, env_params = gymnax.make("MountainCar-v0")
# One split yields a new carry key plus dedicated keys for reset and step.
rng, key_reset, key_step = jax.random.split(rng, 3)
# NOTE(review): `obs` and `state` are rebound from the gym env in the next
# cell, so the values produced here are not what the later step call uses.
obs, state = reset(key_reset, env_params)

In [9]:
# Reference run: take one step in the OpenAI gym version of the environment.
env = gym.make("MountainCar-v0")
# Rebinds `obs` (previously set by the gymnax reset) to gym's initial observation.
obs = env.reset()
# Build the state handed to gymnax's `step`: gym's observation with a 0 appended.
# NOTE(review): presumably the extra entry is a step/time counter expected by
# gymnax — TODO confirm against the gymnax MountainCar state layout.
state = jnp.hstack([obs, 0])

# The same discrete action is reused for the gymnax step in the next cell.
action = 0
obs_gym, reward_gym, done_gym, _ = env.step(action)
# Bare tuple as the last expression so the cell displays the gym results.
obs_gym, reward_gym, done_gym

(array([-0.535379  , -0.00091855]), -1.0, False)

In [10]:
# Step the gymnax environment from the state built out of gym's initial
# observation, using the same action, so the displayed tuple can be compared
# with the gym output of the previous cell.
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Displayed for side-by-side comparison with (obs_gym, reward_gym, done_gym).
obs_jax, reward_jax, done_jax

(DeviceArray([-0.535379  , -0.00091855], dtype=float32),
 DeviceArray(-1., dtype=float32),
 DeviceArray(False, dtype=bool))

# MountainCarContinuous-v0

In [11]:
# 2D State = Obs Space, 1D Action Space [Continuous]
# (Unlike MountainCar-v0, the action here is a real-valued 1-element array,
# not one of three discrete choices.)
# Rebinds rng/reset/step/env_params to the MountainCarContinuous environment;
# the bindings from the earlier cells are replaced.
rng, reset, step, env_params = gymnax.make("MountainCarContinuous-v0")
# One split yields a new carry key plus dedicated keys for reset and step.
rng, key_reset, key_step = jax.random.split(rng, 3)
# NOTE(review): `obs` and `state` are rebound from the gym env in the next
# cell, so the values produced here are not what the later step call uses.
obs, state = reset(key_reset, env_params)

In [12]:
# Reference run: take one step in the OpenAI gym version of the environment.
env = gym.make("MountainCarContinuous-v0")
# Rebinds `obs` (previously set by the gymnax reset) to gym's initial observation.
obs = env.reset()
# Build the state handed to gymnax's `step`: gym's observation with a 0 appended.
# NOTE(review): presumably the extra entry is a step/time counter expected by
# gymnax — TODO confirm against the gymnax MountainCarContinuous state layout.
state = jnp.hstack([obs, 0])

# Continuous action given as a 1-element float array; it is reused for the
# gymnax step in the next cell.
action = jnp.array([0.5])
obs_gym, reward_gym, done_gym, _ = env.step(action)
# Bare tuple as the last expression so the cell displays the gym results.
obs_gym, reward_gym, done_gym

(array([-0.5142275 ,  0.00068486], dtype=float32), -0.025, False)

In [13]:
# Step the gymnax environment from the state built out of gym's initial
# observation, using the same action, so the displayed tuple can be compared
# with the gym output of the previous cell.
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Displayed for side-by-side comparison with (obs_gym, reward_gym, done_gym).
# NOTE(review): the recorded output shows reward and done as 1-element arrays
# here (scalars for the other envs) — likely inherited from the 1-element
# `action` array; confirm whether gymnax should squeeze them.
obs_jax, reward_jax, done_jax

(DeviceArray([-0.5142275 ,  0.00068486], dtype=float32),
 DeviceArray([-0.025], dtype=float32),
 DeviceArray([False], dtype=bool))

# Acrobot-v1