In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp
import gym
import gymnax



# Pendulum-v0

In [2]:
# Pendulum-v0: 2-D internal state (theta, theta_dot), 3-D observation,
# 1-D continuous action (torque).
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
rng, key_reset, key_step = jax.random.split(rng, num=3)
obs, state = reset(key_reset, env_params)

In [3]:
# Reference step with the original gym implementation. Pendulum's 3-D
# observation is not its 2-D state, so the gymnax state dict is copied from
# `env.state` (theta, theta_dot) rather than from the observation.
env = gym.make("Pendulum-v0")
obs = env.reset()
state_gym_to_jax = {"theta": env.state[0],
                    "theta_dot": env.state[1],
                    "time": 0,
                    "terminal": 0}

# Continuous torque action: use a float literal so neither backend has to
# rely on implicit int -> float promotion.
action = jnp.array([1.0])
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.93178695, -0.36300567, -0.5864714 ], dtype=float32),
 DeviceArray(-7.5325613, dtype=float32),
 False)

In [4]:
# Step gymnax from the same state/action as the gym reference step and check
# that the transition matches. gymnax computes in float32, so compare with a
# small tolerance rather than exact equality.
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state_gym_to_jax, action)
assert jnp.allclose(jnp.ravel(obs_jax), obs_gym, rtol=1e-5, atol=1e-6), "obs mismatch"
assert jnp.allclose(reward_jax, reward_gym, rtol=1e-5, atol=1e-6), "reward mismatch"
assert jnp.all(done_jax == done_gym), "done mismatch"
obs_jax, reward_jax, done_jax

(DeviceArray([-0.93178695, -0.36300567, -0.58647144], dtype=float32),
 DeviceArray(-7.532561, dtype=float32),
 DeviceArray(False, dtype=bool))

# CartPole-v0

In [5]:
# CartPole-v0: the 4-D state (x, x_dot, theta, theta_dot) is also the
# observation; actions are discrete (push Left / push Right).
rng, reset, step, env_params = gymnax.make("CartPole-v0")
rng, key_reset, key_step = jax.random.split(rng, num=3)
obs, state = reset(key_reset, env_params)

In [6]:
env = gym.make("CartPole-v0")
obs = env.reset()
# CartPole's observation is its full state, so the gymnax state dict is
# built straight from the reset observation plus episode bookkeeping.
state = dict(
    zip(("x", "x_dot", "theta", "theta_dot"), obs),
    time=0,
    terminal=0,
)

action = 0  # discrete action index
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.00690449, -0.19804861, -0.02369296,  0.31639668]), 1.0, False)

In [7]:
# Step gymnax from the same state/action as the gym reference step and check
# that the transition matches. gymnax computes in float32 while gym's CartPole
# uses float64, so compare with a small tolerance rather than exact equality.
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
assert jnp.allclose(jnp.ravel(obs_jax), obs_gym, rtol=1e-5, atol=1e-6), "obs mismatch"
assert jnp.allclose(reward_jax, reward_gym, rtol=1e-5, atol=1e-6), "reward mismatch"
assert jnp.all(done_jax == done_gym), "done mismatch"
obs_jax, reward_jax, done_jax

(DeviceArray([-0.00690449, -0.1980486 , -0.02369296,  0.31639668], dtype=float32),
 DeviceArray(1., dtype=float32),
 DeviceArray(False, dtype=bool))

# MountainCar-v0

In [8]:
# MountainCar-v0: the 2-D state (position, velocity) doubles as the
# observation; three discrete actions (accelerate left, no acceleration,
# accelerate right).
rng, reset, step, env_params = gymnax.make("MountainCar-v0")
rng, key_reset, key_step = jax.random.split(rng, num=3)
obs, state = reset(key_reset, env_params)

In [9]:
env = gym.make("MountainCar-v0")
obs = env.reset()
# The observation is the full (position, velocity) state, so the gymnax
# state dict is built directly from it plus episode bookkeeping.
state = dict(zip(("position", "velocity"), obs), time=0, terminal=0)

action = 0  # discrete action index
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.42761389, -0.00172236]), -1.0, False)

In [10]:
# Step gymnax from the same state/action as the gym reference step and check
# that the transition matches. gymnax computes in float32 while gym returns
# float64 here, so compare with a small tolerance rather than exact equality.
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
assert jnp.allclose(jnp.ravel(obs_jax), obs_gym, rtol=1e-5, atol=1e-6), "obs mismatch"
assert jnp.allclose(reward_jax, reward_gym, rtol=1e-5, atol=1e-6), "reward mismatch"
assert jnp.all(done_jax == done_gym), "done mismatch"
obs_jax, reward_jax, done_jax

(DeviceArray([-0.42761388, -0.00172236], dtype=float32),
 DeviceArray(-1., dtype=float32),
 DeviceArray(False, dtype=bool))

# MountainCarContinuous-v0

In [11]:
# MountainCarContinuous-v0: the 2-D state (position, velocity) doubles as
# the observation; 1-D continuous action in [-1, 1].
rng, reset, step, env_params = gymnax.make("MountainCarContinuous-v0")
rng, key_reset, key_step = jax.random.split(rng, num=3)
obs, state = reset(key_reset, env_params)

In [12]:
env = gym.make("MountainCarContinuous-v0")
obs = env.reset()
# The observation is the full (position, velocity) state, so the gymnax
# state dict is built directly from it plus episode bookkeeping.
state = dict(zip(("position", "velocity"), obs), time=0, terminal=0)

action = jnp.array([0.5])  # continuous action, shape (1,)
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([-0.5517053 ,  0.00096778], dtype=float32), -0.025, False)

In [13]:
# Step gymnax from the same state/action as the gym reference step and check
# that the transition matches, with a small tolerance for float32 arithmetic.
# NOTE(review): with a shape-(1,) action, gymnax returns obs of shape (2, 1)
# and reward/done of shape (1,), whereas gym returns obs of shape (2,) and
# scalars. The obs is flattened before comparing values. The shape difference
# itself looks like a gymnax quirk worth confirming upstream.
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
assert jnp.allclose(jnp.ravel(obs_jax), obs_gym, rtol=1e-5, atol=1e-6), "obs mismatch"
assert jnp.allclose(reward_jax, reward_gym, rtol=1e-5, atol=1e-6), "reward mismatch"
assert jnp.all(done_jax == done_gym), "done mismatch"
obs_jax, reward_jax, done_jax

(DeviceArray([[-0.5517053 ],
              [ 0.00096778]], dtype=float32),
 DeviceArray([-0.025], dtype=float32),
 DeviceArray([False], dtype=bool))

# Acrobot-v1

In [14]:
# Acrobot-v1: 4-D internal state (two joint angles, two angular velocities),
# 6-D observation (cos/sin of both joint angles plus the two velocities),
# 1-D discrete action space (Discrete(3): torque -1, 0 or +1).
rng, reset, step, env_params = gymnax.make("Acrobot-v1")
rng, key_reset, key_step = jax.random.split(rng, 3)
obs, state = reset(key_reset, env_params)

In [15]:
env = gym.make("Acrobot-v1")
obs = env.reset()
# Acrobot's 6-value observation differs from its 4-value state, so the
# gymnax state dict is copied from `env.state` rather than from `obs`.
state = dict(
    zip(("joint_angle1", "joint_angle2", "velocity_1", "velocity_2"),
        env.state),
    time=0,
    terminal=0,
)

action = 2  # discrete action index
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

(array([ 0.99989775,  0.01429985,  0.99975243, -0.02225053, -0.12747209,
         0.32437752]), -1.0, False)

In [16]:
# Step gymnax from the same state/action as the gym reference step and check
# that the transition matches. gymnax computes in float32 while gym returns
# float64 here, so compare with a small tolerance rather than exact equality.
obs_jax, state_jax, reward_jax, done_jax, _ = step(key_step, env_params,
                                                   state, action)
assert jnp.allclose(jnp.ravel(obs_jax), obs_gym, rtol=1e-5, atol=1e-6), "obs mismatch"
assert jnp.allclose(reward_jax, reward_gym, rtol=1e-5, atol=1e-6), "reward mismatch"
assert jnp.all(done_jax == done_gym), "done mismatch"
obs_jax, reward_jax, done_jax

(DeviceArray([ 0.9998978 ,  0.01429985,  0.9997524 , -0.02225053,
              -0.12747204,  0.32437748], dtype=float32),
 DeviceArray(-1., dtype=float32),
 DeviceArray(False, dtype=bool))