In [1]:
import jax
import jax.numpy as jnp
import gym
import gymnax

# Pendulum-v0

In [2]:
# Pendulum: 2D state space, 3D observation space, 1D continuous action
# space (torque).
rng, reset, step, env_params = gymnax.make("Pendulum-v0")

# Split the key returned by `make` three ways: a fresh carry key plus one
# dedicated key each for the reset and step calls.
rng, key_reset, key_step = jax.random.split(rng, num=3)

# Draw an initial observation/state pair from the gymnax environment.
obs, state = reset(key_reset, env_params)



In [3]:
# Build the reference gym environment and reset it to an initial state.
env = gym.make("Pendulum-v0")
obs = env.reset()

# Snapshot the simulator state so the gymnax `step` call can be started from
# exactly the same point. Slicing with `[:]` on a NumPy array returns a view
# that shares memory with the original, so take an explicit copy to make sure
# the snapshot can never alias gym's internal state.
# NOTE(review): assumes `env.state` is a NumPy array (or list) - confirm.
state = env.state.copy()

# Apply one fixed torque and record gym's transition as the reference result.
action = jnp.array([1])
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([ 0.84300198, -0.53791046, -0.57475187]), -0.30329336376075317, False)

In [4]:
# Advance the gymnax environment from the state captured from gym, using the
# same action that was applied to the gym environment.
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)

# Fail loudly if the two implementations disagree instead of relying on a
# visual comparison of the echoed values. The tolerance absorbs the
# precision gap between the float32 JAX results and gym's NumPy results.
assert jnp.allclose(obs_jax, obs_gym, atol=1e-5)
assert jnp.allclose(reward_jax, reward_gym, atol=1e-5)
assert bool(done_jax) == bool(done_gym)
obs_jax, reward_jax, done_jax

(DeviceArray([ 0.84300196, -0.53791046, -0.57475185], dtype=float32),
 DeviceArray(-0.3032934, dtype=float32),
 DeviceArray(False, dtype=bool))

# CartPole-v0

In [5]:
# CartPole: 4D state (identical to the observation space), 1D discrete
# action space (left / right).
rng, reset, step, env_params = gymnax.make("CartPole-v0")

# Split the key returned by `make` three ways: a fresh carry key plus one
# dedicated key each for the reset and step calls.
rng, key_reset, key_step = jax.random.split(rng, num=3)

# Draw an initial observation/state pair from the gymnax environment.
obs, state = reset(key_reset, env_params)

In [6]:
# Build the reference gym environment and reset it to an initial state.
env = gym.make("CartPole-v0")
obs = env.reset()

# Assemble the state handed to the gymnax `step` call: gym's observation
# with one extra trailing 0 appended.
# NOTE(review): the trailing entry is presumably an extra bookkeeping field
# in gymnax's CartPole state - confirm against its state layout.
state = jnp.hstack((obs, 0))

# Apply one fixed discrete action (presumably "left") and record gym's
# transition as the reference result.
action = 0
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.02128864, -0.22881371, -0.04331165,  0.23288206]), 1.0, False)

In [7]:
# Advance the gymnax environment from the state built from gym's observation,
# using the same action that was applied to the gym environment.
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)

# Fail loudly if the two implementations disagree instead of relying on a
# visual comparison of the echoed values. The tolerance absorbs the
# precision gap between the float32 JAX results and gym's NumPy results.
assert jnp.allclose(obs_jax, obs_gym, atol=1e-5)
assert jnp.allclose(reward_jax, reward_gym, atol=1e-5)
assert bool(done_jax) == bool(done_gym)
obs_jax, reward_jax, done_jax

(DeviceArray([-0.02128864, -0.22881371, -0.04331165,  0.23288204], dtype=float32),
 DeviceArray(1., dtype=float32),
 DeviceArray(False, dtype=bool))