
# JYM (Jax Gym)


JAX implementations of standard RL algorithms and vectorized environments

## 🚀 Stack

Python · JAX · Haiku · Optax · Chex

## 🌟 Key Features

### Progress

#### 🤖 Algorithms

| Type | Name | Source |
| --- | --- | --- |
| Bandits | Simple Bandits (ε-Greedy policy) | Sutton & Barto, 1998 |
| Tabular | Q-learning | Watkins & Dayan, 1992 |
| Tabular | Expected SARSA | Van Seijen et al., 2009 |
| Tabular | Double Q-learning | Van Hasselt, 2010 |
| Deep RL | Deep Q-Network (DQN) | Mnih et al., 2015 |
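As a taste of how these tabular methods translate to JAX, below is a minimal, hypothetical sketch of the Q-learning update (illustrative only, not the repository's actual code; `q_table`, `alpha`, and `gamma` are assumed names):

```python
import jax.numpy as jnp
from jax import jit


@jit
def q_learning_update(q_table, state, action, reward, next_state, alpha=0.1, gamma=0.99):
    """One Q-learning step: Q(s,a) += alpha * (r + gamma * max_a' Q(s',a') - Q(s,a))."""
    td_target = reward + gamma * jnp.max(q_table[next_state])
    td_error = td_target - q_table[state, action]
    # JAX arrays are immutable, so return an updated copy via .at[...].set(...)
    return q_table.at[state, action].set(q_table[state, action] + alpha * td_error)
```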

#### 🌍 Environments

| Type | Name | Source |
| --- | --- | --- |
| Bandits | Casino (K-armed Bandits) | Sutton & Barto, 1998 |
| Tabular | GridWorld | - |
| Tabular | Cliff Walking | - |
| Continuous Control | CartPole | Barto, Sutton, & Anderson, 1983 |
| MinAtar | Breakout | Young et al., 2019 |
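Vectorization in JAX typically comes from `jax.vmap`, which batches a single-environment `step` over many states; the toy sketch below illustrates the idea (the environment here is a made-up stand-in, not one of the repository's):

```python
import jax
import jax.numpy as jnp


def step(state, action):
    """Toy deterministic environment: the state drifts by the action, reward penalises distance from 0."""
    next_state = state + action
    reward = -jnp.abs(next_state)
    return next_state, reward


# vmap turns the single-environment step into a step over a batch of environments
batched_step = jax.vmap(step)

states = jnp.zeros(8)    # 8 parallel environment states
actions = jnp.ones(8)    # one action per environment
next_states, rewards = batched_step(states, actions)
```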

### Coming Soon

#### 🤖 Algorithms

| Type | Name |
| --- | --- |
| Bandits | UCB (Upper Confidence Bound) |
| Tabular (model-based) | Dyna-Q, Dyna-Q+ |

#### 🌍 Environments

| Type | Name |
| --- | --- |
| MinAtar | Asterix, Freeway, Seaquest, SpaceInvaders |

## 🧪 Experiments

### K-armed Bandits Testbed

Reproduction of the 10-armed Testbed experiment presented in *Reinforcement Learning: An Introduction* (section 2.3, pages 28-29).

This experiment showcases the difference in performance between several values of ε, and therefore the long-term trade-off between exploration and exploitation.
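For reference, here is a minimal sketch of ε-greedy action selection in JAX (illustrative only; `q_values` is an assumed array of estimated action values):

```python
import jax
import jax.numpy as jnp


def epsilon_greedy(key, q_values, epsilon):
    """With probability epsilon pick a random arm, otherwise the greedy arm."""
    explore_key, arm_key = jax.random.split(key)
    random_arm = jax.random.randint(arm_key, (), 0, q_values.shape[0])
    greedy_arm = jnp.argmax(q_values)
    return jnp.where(jax.random.uniform(explore_key) < epsilon, random_arm, greedy_arm)
```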

*Figure: the 10-armed Testbed reward distribution, generated with the K-armed Bandits JAX environment.*

*Figure: results obtained in* Reinforcement Learning: An Introduction *alongside the replicated results using the K-armed Bandits JAX environment.*

### Cliff Walking

Reproduction of the CliffWalking environment presented in *Reinforcement Learning: An Introduction* (chapter 6, page 132).

This experiment highlights the difference in behaviour between TD algorithms: Q-learning is greedy (its TD target is the maximum Q-value over the next state), whereas Expected SARSA is safer (its TD target is the expected Q-value over the next state).
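To make the distinction concrete, here are the two TD targets side by side in a hypothetical sketch (an ε-greedy behaviour policy is assumed for the Expected SARSA expectation; names are illustrative):

```python
import jax.numpy as jnp


def q_learning_target(q_next, reward, gamma):
    """Greedy target: r + gamma * max_a' Q(s', a')."""
    return reward + gamma * jnp.max(q_next)


def expected_sarsa_target(q_next, reward, gamma, epsilon):
    """Expected target: r + gamma * E_pi[Q(s', a')] under an epsilon-greedy policy."""
    n_actions = q_next.shape[0]
    probs = jnp.full(n_actions, epsilon / n_actions)      # exploration mass spread evenly
    probs = probs.at[jnp.argmax(q_next)].add(1.0 - epsilon)  # remaining mass on the greedy arm
    return reward + gamma * jnp.dot(probs, q_next)
```

Because Expected SARSA weights every action by its probability under the policy, occasional exploratory moves near the cliff are priced into its targets, which pushes it toward the safer path.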

*Figure: described behaviour for the CliffWalking environment.*

*Figure: comparison of Expected SARSA (top) and Q-learning (bottom) on CliffWalking.*

## 💾 Installation

To install and set up the project, follow these steps:

1. Clone the repository to your local machine:

   ```bash
   git clone https://github.com/RPegoud/jax_rl.git
   ```

2. Navigate to the project directory:

   ```bash
   cd jax_rl
   ```

3. Install Poetry (if not already installed):

   ```bash
   python -m pip install poetry
   ```

4. Install the project dependencies using Poetry:

   ```bash
   poetry install
   ```

5. Activate the virtual environment created by Poetry:

   ```bash
   poetry shell
   ```
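As a quick sanity check after installation, you can confirm that JAX is importable and sees your devices (this is generic JAX usage, not a project-specific command):

```bash
poetry run python -c "import jax; print(jax.devices())"
```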

## 📝 References


Official JAX Documentation
