In [1]:
# Notebook setup: render matplotlib figures inline (at retina resolution) and
# let the autoreload extension re-import edited modules (such as gymnax)
# before each cell runs, so source changes are picked up without a restart.
# NOTE(review): no cell visible here plots anything; the matplotlib settings
# may be unnecessary — confirm before removing.
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

In [2]:
import jax
import jax.numpy as jnp
import gym
import gymnax



# Pendulum-v0

In [3]:
# Pendulum-v0 -- state: 2D, observation: 3D, action: 1D continuous torque.
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
# Carve out one key for reset and one for step, carrying `rng` forward.
rng, key_reset, key_step = jax.random.split(rng, num=3)
obs, state = reset(key_reset, env_params)

In [4]:
# Same unit torque that will later be fed to the gymnax step function.
action = jnp.array([1])

# Reset gym's reference Pendulum and copy its 2D internal state into a JAX
# array, appending a trailing 0 for the extra gymnax state slot
# (presumably a step counter -- TODO confirm against gymnax's state layout).
env = gym.make("Pendulum-v0")
obs = env.reset()
state_gym_to_jax = jnp.hstack((env.state, 0))

obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.6645005 , -0.7472878 ,  0.13394395], dtype=float32),
 DeviceArray(-5.3401594, dtype=float32),
 False)

In [5]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state_gym_to_jax, action)
# Check (rather than eyeball) that gymnax reproduces gym's transition from the
# same state. The printed values differ only in the last float digits, so
# compare with a small tolerance instead of exact equality.
assert jnp.allclose(obs_jax, obs_gym, rtol=1e-5, atol=1e-6), (obs_jax, obs_gym)
assert jnp.allclose(reward_jax, reward_gym, rtol=1e-5, atol=1e-6), (reward_jax, reward_gym)
assert bool(done_jax) == bool(done_gym), (done_jax, done_gym)
obs_jax, reward_jax, done_jax

(DeviceArray([-0.6645005 , -0.74728787,  0.13394386], dtype=float32),
 DeviceArray(-5.340159, dtype=float32),
 DeviceArray(False, dtype=bool))

# CartPole-v0

In [6]:
# CartPole-v0 -- the 4D state doubles as the observation; the action is
# discrete (push left / push right).
rng, reset, step, env_params = gymnax.make("CartPole-v0")
# Carve out one key for reset and one for step, carrying `rng` forward.
rng, key_reset, key_step = jax.random.split(rng, num=3)
obs, state = reset(key_reset, env_params)

In [7]:
# Discrete action 0 (push left), reused for the gymnax step below.
action = 0

# For CartPole the reset observation is the full state. Pad it with two
# zeros to fill the extra gymnax state slots (presumably time/terminal
# bookkeeping -- TODO confirm against gymnax's CartPole state layout).
env = gym.make("CartPole-v0")
obs = env.reset()
state = jnp.hstack((obs, 0, 0))

obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.0153318 , -0.15174199, -0.01085852,  0.27575762]), 1.0, False)

In [8]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Check (rather than eyeball) that gymnax reproduces gym's transition. gym's
# observation prints without a float32 dtype while gymnax's is float32, so
# compare with a small tolerance instead of exact equality.
assert jnp.allclose(obs_jax, obs_gym, rtol=1e-5, atol=1e-6), (obs_jax, obs_gym)
assert jnp.allclose(reward_jax, reward_gym, rtol=1e-5, atol=1e-6), (reward_jax, reward_gym)
assert bool(done_jax) == bool(done_gym), (done_jax, done_gym)
obs_jax, reward_jax, done_jax

(DeviceArray([-0.0153318 , -0.15174198, -0.01085852,  0.2757576 ], dtype=float32),
 DeviceArray(1., dtype=float32),
 DeviceArray(False, dtype=bool))

# MountainCar-v0

In [9]:
# MountainCar-v0 -- the 2D state doubles as the observation; the action is
# discrete: accelerate left, no acceleration, accelerate right.
rng, reset, step, env_params = gymnax.make("MountainCar-v0")
# Carve out one key for reset and one for step, carrying `rng` forward.
rng, key_reset, key_step = jax.random.split(rng, num=3)
obs, state = reset(key_reset, env_params)

In [10]:
# Discrete action 0 (accelerate left), reused for the gymnax step below.
action = 0

# The reset observation is the full 2D state. Pad it with two zeros to fill
# the extra gymnax state slots (presumably time/terminal bookkeeping --
# TODO confirm against gymnax's MountainCar state layout).
env = gym.make("MountainCar-v0")
obs = env.reset()
state = jnp.hstack((obs, 0, 0))

obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.53087833, -0.00095255]), -1.0, False)

In [11]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Check (rather than eyeball) that gymnax reproduces gym's transition; the
# printed values differ only by float round-off, hence the tolerance.
assert jnp.allclose(obs_jax, obs_gym, rtol=1e-5, atol=1e-6), (obs_jax, obs_gym)
assert jnp.allclose(reward_jax, reward_gym, rtol=1e-5, atol=1e-6), (reward_jax, reward_gym)
assert bool(done_jax) == bool(done_gym), (done_jax, done_gym)
obs_jax, reward_jax, done_jax

(DeviceArray([-0.5308783 , -0.00095255], dtype=float32),
 DeviceArray(-1., dtype=float32),
 DeviceArray(False, dtype=bool))

# MountainCarContinuous-v0

In [12]:
# MountainCarContinuous-v0 -- the 2D state doubles as the observation; the
# action is a single continuous value in [-1, 1].
rng, reset, step, env_params = gymnax.make("MountainCarContinuous-v0")
# Carve out one key for reset and one for step, carrying `rng` forward.
rng, key_reset, key_step = jax.random.split(rng, num=3)
obs, state = reset(key_reset, env_params)

In [13]:
# Half-strength push to the right, reused for the gymnax step below.
action = jnp.array([0.5])

# The reset observation is the full 2D state. Pad it with two zeros to fill
# the extra gymnax state slots (presumably time/terminal bookkeeping --
# TODO confirm against gymnax's MountainCarContinuous state layout).
env = gym.make("MountainCarContinuous-v0")
obs = env.reset()
state = jnp.hstack((obs, 0, 0))

obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.57035965,  0.00110778], dtype=float32), -0.025, False)

In [14]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Check (rather than eyeball) that gymnax reproduces gym's transition.
# gymnax returns reward/done with shape (1,) here while gym returns scalars,
# so compare via broadcasting (allclose) and bool() on the size-1 done array.
assert jnp.allclose(obs_jax, obs_gym, rtol=1e-5, atol=1e-6), (obs_jax, obs_gym)
assert jnp.allclose(reward_jax, reward_gym, rtol=1e-5, atol=1e-6), (reward_jax, reward_gym)
assert bool(done_jax) == bool(done_gym), (done_jax, done_gym)
obs_jax, reward_jax, done_jax

(DeviceArray([-0.57035965,  0.00110778], dtype=float32),
 DeviceArray([-0.025], dtype=float32),
 DeviceArray([False], dtype=bool))

# Acrobot-v1

In [18]:
# Acrobot-v1 -- 4D internal state (two joint angles and their angular
# velocities), 6D observation (cos/sin of both angles plus both velocities),
# and a discrete action space with 3 actions (torque -1, 0, +1 on the actuated
# joint). Unlike the continuous MountainCar/Pendulum cells, actions are ints.
rng, reset, step, env_params = gymnax.make("Acrobot-v1")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)

In [25]:
# Discrete action 2, reused for the gymnax step below.
action = 2

# Copy gym's internal state into a JAX array and append one extra value to
# fill the remaining gymnax state slot.
# NOTE(review): the other environments append 0 here, but this cell appends
# 100 -- presumably step-counter/bookkeeping; confirm against gymnax's
# Acrobot state layout.
env = gym.make("Acrobot-v1")
obs = env.reset()
state = jnp.hstack((env.state, 100))

obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([ 0.99969461,  0.02471213,  0.99990262,  0.01395561, -0.2104983 ,
         0.50916523]), -1.0, False)

In [26]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Check (rather than eyeball) that gymnax reproduces gym's transition. gym's
# observation prints without a float32 dtype while gymnax's is float32, so
# compare with a small tolerance instead of exact equality.
assert jnp.allclose(obs_jax, obs_gym, rtol=1e-5, atol=1e-6), (obs_jax, obs_gym)
assert jnp.allclose(reward_jax, reward_gym, rtol=1e-5, atol=1e-6), (reward_jax, reward_gym)
assert bool(done_jax) == bool(done_gym), (done_jax, done_gym)
obs_jax, reward_jax, done_jax

(Buffer([ 0.9996946 ,  0.02471212,  0.9999026 ,  0.01395561, -0.21049826,
          0.5091652 ], dtype=float32),
 Buffer(-1., dtype=float32),
 Buffer(False, dtype=bool))