In [5]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load the autoreload extension; mode 2 reloads all imported modules before
# each cell executes, so edits to local source files are picked up live.
%load_ext autoreload
%autoreload 2
# Ask the inline backend for high-resolution ('retina') figures.
%config InlineBackend.figure_format = 'retina'

import jax
import jax.numpy as jnp
import gym
import gymnax



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
# Class-based gymnax environments. Only Pendulum is used in the cells visible
# here; CartPole is imported but not referenced.
from gymnax.environments.classic_control import CartPole, Pendulum
# Root PRNG key (seed 0) handed to the environment calls in the next cells.
rng = jax.random.PRNGKey(0)

In [20]:
# Instantiate the class-based Pendulum environment and reset it with the root
# key, which yields the initial observation together with the state.
env = Pendulum()
obs, state = env.reset(rng)

In [23]:
# Take one step from `state` with action 0. The displayed 5-tuple appears to be
# (obs, state, reward, done, info) - TODO confirm against the gymnax API.
# NOTE(review): `rng` was already passed to env.reset above; the later cells
# split dedicated reset/step keys before use - consider doing the same here.
env.step(rng, state, 0)

(DeviceArray([-0.21629913, -0.9763271 , -0.12499666], dtype=float32),
 {'theta': DeviceArray(-1.7888186, dtype=float32),
  'theta_dot': DeviceArray(-0.12499666, dtype=float32),
  'time': 1,
  'terminal': 0},
 DeviceArray(-3.214548, dtype=float32),
 0,
 {})

In [24]:
# Compile the environment step, then map the compiled function over a batch of
# 10 PRNG keys (axis 0) while the same state and action are broadcast to every
# call (in_axes entries of None).
jit_step = jax.jit(env.step)
batched_step = jax.vmap(jit_step, in_axes=(0, None, None))
step_keys = jax.random.split(rng, 10)
vmap_step = batched_step(step_keys, state, 0)
vmap_step

(DeviceArray([[-0.21629913, -0.9763271 , -0.12499667],
              [-0.21629913, -0.9763271 , -0.12499667],
              [-0.21629913, -0.9763271 , -0.12499667],
              [-0.21629913, -0.9763271 , -0.12499667],
              [-0.21629913, -0.9763271 , -0.12499667],
              [-0.21629913, -0.9763271 , -0.12499667],
              [-0.21629913, -0.9763271 , -0.12499667],
              [-0.21629913, -0.9763271 , -0.12499667],
              [-0.21629913, -0.9763271 , -0.12499667],
              [-0.21629913, -0.9763271 , -0.12499667]], dtype=float32),
 {'terminal': DeviceArray([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=int32),
  'theta': DeviceArray([-1.7888186, -1.7888186, -1.7888186, -1.7888186, -1.7888186,
               -1.7888186, -1.7888186, -1.7888186, -1.7888186, -1.7888186],            dtype=float32),
  'theta_dot': DeviceArray([-0.12499667, -0.12499667, -0.12499667, -0.12499667,
               -0.12499667, -0.12499667, -0.12499667, -0.12499667,
               -0.12499667,

In [None]:
# Display the environment's `name` attribute.
env.name

In [None]:
# Call update_env_params with the key "gravity" and the value 10, then display
# env.env_params to inspect the result.
# NOTE(review): presumably this overrides the parameter in place - confirm.
env.update_env_params("gravity", 10)
env.env_params

# Pendulum-v0

In [None]:
# 2D State Space, 3D Obs Space, 1D Action Space [Continuous - Torque]
# gymnax.make is unpacked into a PRNG key, functional reset/step handles and
# the environment parameters for the named environment.
rng, reset, step, env_params = gymnax.make("Pendulum-v0")
# Split into three keys: a new `rng` plus dedicated keys for reset and step.
rng, key_reset, key_step = jax.random.split(rng, 3)
# Initial observation and state from the functional reset.
obs, state = reset(key_reset, env_params)

In [None]:
# Reference run with OpenAI gym: reset the environment and copy its internal
# state into the dict layout that is passed to the gymnax step function in
# the next cell.
env = gym.make("Pendulum-v0")
obs = env.reset()
state_gym_to_jax = dict(
    theta=env.state[0],
    theta_dot=env.state[1],
    time=0,
    terminal=0,
)

# Apply one torque action and display the gym transition for comparison.
action = jnp.array([1])
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

In [None]:
# Replay the same transition through the gymnax step function, starting from
# the state copied out of gym, so the displayed values can be compared with
# the gym results.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state_gym_to_jax, action
)
obs_jax, reward_jax, done_jax

# CartPole-v0

In [None]:
# 4D State = Obs Space, 1D Action Space [Discrete - Left, Right]
# gymnax.make is unpacked into a PRNG key, functional reset/step handles and
# the environment parameters for the named environment.
rng, reset, step, env_params = gymnax.make("CartPole-v0")
# Split into three keys: a new `rng` plus dedicated keys for reset and step.
rng, key_reset, key_step = jax.random.split(rng, 3)
# Initial observation and state from the functional reset.
obs, state = reset(key_reset, env_params)

In [None]:
# Reference run with OpenAI gym: reset the environment and repackage the four
# observation entries into the state dict passed to the gymnax step function
# in the next cell.
env = gym.make("CartPole-v0")
obs = env.reset()
state = dict(
    x=obs[0],
    x_dot=obs[1],
    theta=obs[2],
    theta_dot=obs[3],
    time=0,
    terminal=0,
)

# Take discrete action 0 and display the gym transition for comparison.
action = 0
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

In [None]:
# Replay the same transition through the gymnax step function, starting from
# the state copied out of gym, so the displayed values can be compared with
# the gym results.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
obs_jax, reward_jax, done_jax

# MountainCar-v0

In [None]:
# 2D State = Obs Space, 1D Action Space [Discrete - Acc. Left, No Acc., Acc. Right]
# gymnax.make is unpacked into a PRNG key, functional reset/step handles and
# the environment parameters for the named environment.
rng, reset, step, env_params = gymnax.make("MountainCar-v0")
# Split into three keys: a new `rng` plus dedicated keys for reset and step.
rng, key_reset, key_step = jax.random.split(rng, 3)
# Initial observation and state from the functional reset.
obs, state = reset(key_reset, env_params)

In [None]:
# Reference run with OpenAI gym: reset the environment and repackage the two
# observation entries into the state dict passed to the gymnax step function
# in the next cell.
env = gym.make("MountainCar-v0")
obs = env.reset()
state = dict(
    position=obs[0],
    velocity=obs[1],
    time=0,
    terminal=0,
)

# Take discrete action 0 and display the gym transition for comparison.
action = 0
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

In [None]:
# Replay the same transition through the gymnax step function, starting from
# the state copied out of gym, so the displayed values can be compared with
# the gym results.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
obs_jax, reward_jax, done_jax

# MountainCarContinuous-v0

In [None]:
# 2D State = Obs Space, 1D Action Space [Cont.: -1, 1]
# gymnax.make is unpacked into a PRNG key, functional reset/step handles and
# the environment parameters for the named environment.
rng, reset, step, env_params = gymnax.make("MountainCarContinuous-v0")
# Split into three keys: a new `rng` plus dedicated keys for reset and step.
rng, key_reset, key_step = jax.random.split(rng, 3)
# Initial observation and state from the functional reset.
obs, state = reset(key_reset, env_params)

In [None]:
# Reference run with OpenAI gym: reset the environment and repackage the two
# observation entries into the state dict passed to the gymnax step function
# in the next cell.
env = gym.make("MountainCarContinuous-v0")
obs = env.reset()
state = dict(
    position=obs[0],
    velocity=obs[1],
    time=0,
    terminal=0,
)

# Apply one continuous action and display the gym transition for comparison.
action = jnp.array([0.5])
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

In [None]:
# Replay the same transition through the gymnax step function, starting from
# the state copied out of gym, so the displayed values can be compared with
# the gym results.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
obs_jax, reward_jax, done_jax

# Acrobot-v1

In [None]:
# 4D State Space, 6D Obs Space, 1D Action Space [Discrete - Torque -1, 0, +1]
# (the state dict built in the next cell holds two joint angles and two
# velocities; the action used there is an integer index)
# gymnax.make is unpacked into a PRNG key, functional reset/step handles and
# the environment parameters for the named environment.
rng, reset, step, env_params = gymnax.make("Acrobot-v1")
# Split into three keys: a new `rng` plus dedicated keys for reset and step.
rng, key_reset, key_step = jax.random.split(rng, 3)
# Initial observation and state from the functional reset.
obs, state = reset(key_reset, env_params)

In [None]:
# Reference run with OpenAI gym: reset the environment and copy the four
# entries of its internal state into the dict passed to the gymnax step
# function in the next cell.
env = gym.make("Acrobot-v1")
obs = env.reset()
state = dict(
    joint_angle1=env.state[0],
    joint_angle2=env.state[1],
    velocity_1=env.state[2],
    velocity_2=env.state[3],
    time=0,
    terminal=0,
)

# Take discrete action 2 and display the gym transition for comparison.
action = 2
obs_gym, reward_gym, done_gym, _ = env.step(action)
obs_gym, reward_gym, done_gym

In [None]:
# Replay the same transition through the gymnax step function, starting from
# the state copied out of gym, so the displayed values can be compared with
# the gym results.
obs_jax, state_jax, reward_jax, done_jax, _ = step(
    key_step, env_params, state, action
)
obs_jax, reward_jax, done_jax