In [1]:
# Notebook configuration.
# NOTE(review): nothing in this notebook plots, so the two matplotlib-related
# magics (`%matplotlib inline`, the retina `%config`) look unused -- confirm
# before removing.
%matplotlib inline
# Reload edited modules before each cell executes (presumably so changes to a
# local gymnax checkout are picked up without restarting the kernel).
%load_ext autoreload
%autoreload 2
# Render inline figures at high (retina) resolution.
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp
import gym
import gymnax



# Pendulum-v0

In [2]:
# 2D State Space, 3D Obs Space, 1D Action Space [Continuous - Torque]
# `gymnax.make` hands back a PRNG key, `reset`/`step` callables and the env
# parameters; both callables take a key and `env_params` explicitly.
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
# One key for `reset`, one for `step`; `rng` keeps the remainder.
rng, key_reset, key_step = jax.random.split(rng, 3)
# NOTE: this `obs, state` pair is not what gets compared below -- the next
# cells feed the gym env's own state into `step` so both start identically.
obs, state = reset(key_reset, env_params)

In [3]:
env = gym.make("Pendulum-v0")
# Seed gym's RNG so the random initial state (and therefore every output
# below) is reproducible under Restart & Run All.
env.seed(0)
obs = env.reset()
# gym's 3D observation is derived from its 2D physical state, so copy
# `env.state` itself and append one extra trailing entry for gymnax.
# NOTE(review): the appended 0 is presumably a step counter -- confirm.
state_gym_to_jax = jnp.hstack([env.state[:], 0])

# Pendulum takes a continuous torque, so use a float action rather than an
# integer array.
action = jnp.array([1.0])
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.87099123,  0.49129856,  0.57785404], dtype=float32),
 DeviceArray(-6.7565894, dtype=float32),
 False)

In [4]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state_gym_to_jax, action)
# Fail loudly if gymnax drifts from gym; the tolerance absorbs float32
# rounding differences between the two implementations.
assert jnp.allclose(obs_gym, obs_jax, atol=1e-5), (obs_gym, obs_jax)
assert jnp.allclose(reward_gym, reward_jax, atol=1e-5), (reward_gym, reward_jax)
assert bool(done_gym) == bool(done_jax), (done_gym, done_jax)
obs_jax, reward_jax, done_jax

(DeviceArray([-0.87099123,  0.49129853,  0.5778539 ], dtype=float32),
 DeviceArray(-6.75659, dtype=float32),
 DeviceArray(False, dtype=bool))

# CartPole-v0

In [5]:
# 4D State = Obs Space, 1D Action Space [Discrete - Left, Right]
# `gymnax.make` hands back a PRNG key, `reset`/`step` callables and the env
# parameters; both callables take a key and `env_params` explicitly.
rng, reset, step, env_params = gymnax.make("CartPole-v0")
# One key for `reset`, one for `step`; `rng` keeps the remainder.
rng, key_reset, key_step = jax.random.split(rng, 3)
# NOTE: `obs` and `state` are overwritten in the next cell, where the gym
# env's initial observation is converted into the state passed to `step`.
obs, state = reset(key_reset, env_params)

In [6]:
env = gym.make("CartPole-v0")
# Seed gym's RNG so the random initial state (and therefore every output
# below) is reproducible under Restart & Run All.
env.seed(0)
obs = env.reset()
# CartPole's observation is its state; gymnax's state carries two extra
# trailing entries (zeros here).
# NOTE(review): confirm what those two entries mean in gymnax.
state = jnp.hstack([obs, 0, 0])

action = 0  # discrete: 0 = left, 1 = right
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([ 0.01443894, -0.15467995,  0.04936708,  0.35307481]), 1.0, False)

In [7]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Fail loudly if gymnax drifts from gym; gym computes CartPole in float64
# while gymnax returns float32, so compare with a small tolerance.
assert jnp.allclose(obs_gym, obs_jax, atol=1e-5), (obs_gym, obs_jax)
assert jnp.allclose(reward_gym, reward_jax, atol=1e-5), (reward_gym, reward_jax)
assert bool(done_gym) == bool(done_jax), (done_gym, done_jax)
obs_jax, reward_jax, done_jax

(DeviceArray([ 0.01443894, -0.15467994,  0.04936708,  0.3530748 ], dtype=float32),
 DeviceArray(1., dtype=float32),
 DeviceArray(False, dtype=bool))

# MountainCar-v0

In [8]:
# 2D State = Obs Space, 1D Action Space [Discrete - Acc. Left, No Acc., Acc. Right]
# `gymnax.make` hands back a PRNG key, `reset`/`step` callables and the env
# parameters; both callables take a key and `env_params` explicitly.
rng, reset, step, env_params = gymnax.make("MountainCar-v0")
# One key for `reset`, one for `step`; `rng` keeps the remainder.
rng, key_reset, key_step = jax.random.split(rng, 3)
# NOTE: `obs` and `state` are overwritten in the next cell, where the gym
# env's initial observation is converted into the state passed to `step`.
obs, state = reset(key_reset, env_params)

In [9]:
env = gym.make("MountainCar-v0")
# Seed gym's RNG so the random initial state (and therefore every output
# below) is reproducible under Restart & Run All.
env.seed(0)
obs = env.reset()
# MountainCar's observation is its state; gymnax's state carries two extra
# trailing entries (zeros here).
# NOTE(review): confirm what those two entries mean in gymnax.
state = jnp.hstack([obs, 0, 0])

action = 0  # discrete: 0 = accelerate left, 1 = no acceleration, 2 = accelerate right
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.45243659, -0.00154096]), -1.0, False)

In [10]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Fail loudly if gymnax drifts from gym; the tolerance absorbs the
# float64 (gym) vs float32 (gymnax) rounding difference.
assert jnp.allclose(obs_gym, obs_jax, atol=1e-5), (obs_gym, obs_jax)
assert jnp.allclose(reward_gym, reward_jax, atol=1e-5), (reward_gym, reward_jax)
assert bool(done_gym) == bool(done_jax), (done_gym, done_jax)
obs_jax, reward_jax, done_jax

(DeviceArray([-0.4524366 , -0.00154096], dtype=float32),
 DeviceArray(-1., dtype=float32),
 DeviceArray(False, dtype=bool))

# MountainCarContinuous-v0

In [11]:
# 2D State = Obs Space, 1D Action Space [Cont.: -1, 1]
# `gymnax.make` hands back a PRNG key, `reset`/`step` callables and the env
# parameters; both callables take a key and `env_params` explicitly.
rng, reset, step, env_params = gymnax.make("MountainCarContinuous-v0")
# One key for `reset`, one for `step`; `rng` keeps the remainder.
rng, key_reset, key_step = jax.random.split(rng, 3)
# NOTE: `obs` and `state` are overwritten in the next cell, where the gym
# env's initial observation is converted into the state passed to `step`.
obs, state = reset(key_reset, env_params)

In [12]:
env = gym.make("MountainCarContinuous-v0")
# Seed gym's RNG so the random initial state (and therefore every output
# below) is reproducible under Restart & Run All.
env.seed(0)
obs = env.reset()
# The observation is the state; gymnax's state carries two extra trailing
# entries (zeros here).
# NOTE(review): confirm what those two entries mean in gymnax.
state = jnp.hstack([obs, 0, 0])

action = jnp.array([0.5])  # continuous force in [-1, 1]
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.5645939 ,  0.00106461], dtype=float32), -0.025, False)

In [13]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Fail loudly if gymnax drifts from gym. gymnax returns shape-(1,) reward and
# done here; `allclose` broadcasts against gym's scalars and `bool()` accepts
# a single-element array.
assert jnp.allclose(obs_gym, obs_jax, atol=1e-5), (obs_gym, obs_jax)
assert jnp.allclose(reward_gym, reward_jax, atol=1e-5), (reward_gym, reward_jax)
assert bool(done_gym) == bool(done_jax), (done_gym, done_jax)
obs_jax, reward_jax, done_jax

(DeviceArray([-0.5645939 ,  0.00106461], dtype=float32),
 DeviceArray([-0.025], dtype=float32),
 DeviceArray([False], dtype=bool))

# Acrobot-v1

In [14]:
# 6D Obs Space (cos/sin of both joint angles + both angular velocities),
# derived from a smaller internal state; 1D Action Space
# [Discrete - Torque -1, 0, +1].
# `gymnax.make` hands back a PRNG key, `reset`/`step` callables and the env
# parameters; both callables take a key and `env_params` explicitly.
rng, reset, step, env_params = gymnax.make("Acrobot-v1")
# One key for `reset`, one for `step`; `rng` keeps the remainder.
rng, key_reset, key_step = jax.random.split(rng, 3)
# NOTE: `obs` and `state` are overwritten in the next cell, where gym's
# internal state is copied into the state passed to `step`.
obs, state = reset(key_reset, env_params)

In [15]:
env = gym.make("Acrobot-v1")
# Seed gym's RNG so the random initial state (and therefore every output
# below) is reproducible under Restart & Run All.
env.seed(0)
obs = env.reset()
# The 6D observation is derived from gym's internal state, so copy
# `env.state` itself and append one extra trailing entry for gymnax.
# NOTE(review): the appended 100 differs from the 0 used for Pendulum; it is
# presumably a step counter -- confirm why 100 is used here.
state = jnp.hstack([env.state[:], 100])

action = 2  # discrete: index into the torques {-1, 0, +1}
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([ 0.9986493 ,  0.05195746,  0.99746759,  0.07112244, -0.25985733,
         0.4668123 ]), -1.0, False)

In [16]:
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
# Fail loudly if gymnax drifts from gym; the tolerance absorbs the
# float64 (gym) vs float32 (gymnax) rounding difference.
assert jnp.allclose(obs_gym, obs_jax, atol=1e-5), (obs_gym, obs_jax)
assert jnp.allclose(reward_gym, reward_jax, atol=1e-5), (reward_gym, reward_jax)
assert bool(done_gym) == bool(done_jax), (done_gym, done_jax)
obs_jax, reward_jax, done_jax

(DeviceArray([ 0.9986493 ,  0.05195746,  0.9974676 ,  0.07112244,
              -0.2598573 ,  0.46681225], dtype=float32),
 DeviceArray(-1., dtype=float32),
 DeviceArray(False, dtype=bool))