This repository aims to provide a JAX version of CORL: clean, single-file implementations of offline RL algorithms with solid performance reports.
- 🌬️ Pursuing fast training: speed up via JAX functions such as `jit` and `vmap`.
- 🔪 As simple as possible: implement only the minimum requirements.
- 💠 Focus on a few battle-tested algorithms: Refer here.
- 📈 Solid performance report (README, Wiki).
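As a rough illustration of the two JAX transformations named above (a hypothetical sketch, not code from this repository), `jax.vmap` maps a function over a leading axis, here the members of a toy critic ensemble, and `jax.jit` compiles the result with XLA. The linear critic `q_value` and the parameter layout are assumptions made for the example.

```python
import jax
import jax.numpy as jnp


def q_value(params, obs, act):
    """Toy linear critic Q(s, a) = w . [s, a] + b (hypothetical, for illustration)."""
    x = jnp.concatenate([obs, act], axis=-1)
    return x @ params["w"] + params["b"]


# vmap over the leading axis of every params leaf; obs/act are shared (in_axes=None).
ensemble_q = jax.jit(jax.vmap(q_value, in_axes=(0, None, None)))

key = jax.random.PRNGKey(0)
n_critics, obs_dim, act_dim, batch = 2, 3, 2, 4
params = {
    "w": jax.random.normal(key, (n_critics, obs_dim + act_dim)),
    "b": jnp.zeros((n_critics,)),
}
obs = jnp.ones((batch, obs_dim))
act = jnp.ones((batch, act_dim))

q = ensemble_q(params, obs, act)  # one Q-value per (critic, sample)
print(q.shape)  # (2, 4)
```

Taking `q.min(axis=0)` would then give the clipped double-Q estimate used by TD3-style methods, computed in a single compiled call rather than a Python loop over critics.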
JAX-CORL complements the single-file RL ecosystem by offering the combination of offline x JAX.
- CleanRL: Online x PyTorch
- purejaxrl: Online x JAX
- CORL: Offline x PyTorch
- JAX-CORL(ours): Offline x JAX
Algorithm | Implementation | Training time (ours) | Training time (CORL) | wandb |
---|---|---|---|---|
AWAC | algos/awac.py | 665.5s (~11m) | 16083s (~4.47h) | link |
IQL | algos/iql.py | 516.5s (~9m) | 14775s (~4.10h) | link |
TD3+BC | algos/td3_bc.py | 524.4s (~9m) | 8895s (~2.47h) | link |
CQL | algos/cql.py | 3304.1s (~55m) | 41838s (~11.62h) | link |
DT | 🚧 | - | - | - |
Training time is for `1_000_000` update steps with batch size `256` on halfcheetah-medium-expert-v2 (there is little difference across D4RL MuJoCo environments) and includes the compile time for `jit`. The computations were performed using four GeForce GTX 1080 Ti GPUs. Overall, ours are at least 10 times faster than CORL.
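Because JAX dispatches work asynchronously, wall-clock timings like these need `block_until_ready()`, and the first call of a jitted function also pays for tracing and XLA compilation. A minimal sketch of measuring both, where the hypothetical `step` function stands in for a training step:

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def step(x):
    # Stand-in for a jitted training step (hypothetical workload).
    return jnp.tanh(x @ x.T).sum()


x = jnp.ones((256, 256))

start = time.perf_counter()
first = step(x).block_until_ready()   # first call: tracing + compilation + execution
compile_and_run = time.perf_counter() - start

start = time.perf_counter()
second = step(x).block_until_ready()  # cached executable: execution only
run_only = time.perf_counter() - start

print(f"first call: {compile_and_run:.4f}s, second call: {run_only:.4f}s")
```

Without `block_until_ready()`, the timer could stop as soon as the computation is enqueued, before it actually finishes on the device.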
Here, we used D4RL MuJoCo control tasks as the benchmark. We report the mean and standard deviation of the average normalized score of 5 episodes over 5 seeds. We plan to extend the verification to other D4RL benchmarks such as AntMaze. For those who would like to know about the source of the hyperparameters and the validity of the performance, please refer to the Wiki.
env | AWAC | IQL | TD3+BC | CQL |
---|---|---|---|---|
halfcheetah-medium-v2 | ||||
halfcheetah-medium-expert-v2 | ||||
hopper-medium-v2 | ||||
hopper-medium-expert-v2 | ||||
walker2d-medium-v2 | ||||
walker2d-medium-expert-v2 | | | | |
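The normalized score follows the D4RL convention, mapping a random policy to 0 and an expert policy to 100: $100 \cdot (R - R_\text{random}) / (R_\text{expert} - R_\text{random})$. Below is a sketch of the aggregation described above, assuming raw returns are arranged as `(n_seeds, n_episodes)` and that the population standard deviation (`ddof=0`) is used. The helper names are hypothetical and the numbers in the usage line are illustrative only, not results from this repository.

```python
import numpy as np


def normalized_score(episode_return, random_score, expert_score):
    """D4RL-style normalization: 0 = random policy, 100 = expert policy."""
    return 100.0 * (episode_return - random_score) / (expert_score - random_score)


def summarize(returns_per_seed, random_score, expert_score):
    """returns_per_seed has shape (n_seeds, n_episodes) of raw episode returns.

    Average the normalized score over episodes within each seed, then report
    the mean and (population) standard deviation of those per-seed averages.
    """
    returns = np.asarray(returns_per_seed, dtype=np.float64)
    scores = normalized_score(returns, random_score, expert_score)
    per_seed = scores.mean(axis=1)
    return per_seed.mean(), per_seed.std()


# Illustrative numbers only: 2 seeds x 2 episodes with reference scores 0 and 100.
mean, std = summarize([[50.0, 50.0], [70.0, 70.0]], random_score=0.0, expert_score=100.0)
print(mean, std)  # 60.0 10.0
```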
This codebase can be used independently as a baseline for D4RL projects. It is also designed to be flexible, allowing users to develop new algorithms or adapt it for datasets other than D4RL.
For researchers interested in using this code for their projects, we provide a detailed explanation of the code's shared structure:
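```python
from typing import NamedTuple

import jax.numpy as jnp


class Transition(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    next_observations: jnp.ndarray
    dones: jnp.ndarray


def get_dataset(*args, **kwargs) -> Transition:
    # Dataset-specific loading (arguments elided in this overview).
    # Whatever the source, the result must be packed into a Transition.
    ...
    return dataset
```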
The code includes a `Transition` class, defined as a `NamedTuple`, with fields for observations, actions, rewards, next observations, and done flags. The `get_dataset` function is expected to output data in the `Transition` format, making it adaptable to any dataset that conforms to this structure.
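Since any dataset that can be packed into a `Transition` works, adapting a new source mostly means writing a converter. The sketch below assumes a D4RL-style dict with the keys produced by `d4rl.qlearning_dataset` (`observations`, `actions`, `rewards`, `next_observations`, `terminals`). `dict_to_transition` is a hypothetical helper and the data is synthetic. Because a `NamedTuple` is a JAX pytree, `jax.tree_util.tree_map` can slice all fields at once.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np


class Transition(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    next_observations: jnp.ndarray
    dones: jnp.ndarray


def dict_to_transition(data: dict) -> Transition:
    """Convert a D4RL-style dict into a Transition (key names are an assumption)."""
    return Transition(
        observations=jnp.asarray(data["observations"], dtype=jnp.float32),
        actions=jnp.asarray(data["actions"], dtype=jnp.float32),
        rewards=jnp.asarray(data["rewards"], dtype=jnp.float32),
        next_observations=jnp.asarray(data["next_observations"], dtype=jnp.float32),
        dones=jnp.asarray(data["terminals"], dtype=jnp.float32),
    )


# Synthetic stand-in for a real dataset: 5 transitions, obs_dim=3, act_dim=2.
rng = np.random.default_rng(0)
raw = {
    "observations": rng.normal(size=(5, 3)),
    "actions": rng.uniform(-1, 1, size=(5, 2)),
    "rewards": rng.normal(size=(5,)),
    "next_observations": rng.normal(size=(5, 3)),
    "terminals": np.array([0, 0, 0, 0, 1], dtype=bool),
}
dataset = dict_to_transition(raw)

# Every field is sliced consistently because Transition is a pytree.
first_two = jax.tree_util.tree_map(lambda x: x[:2], dataset)
print(dataset.observations.shape, first_two.actions.shape)
```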
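```python
from functools import partial
from typing import NamedTuple

import jax
from flax.training.train_state import TrainState


class Trainer(NamedTuple):
    actor: TrainState
    critic: TrainState
    # hyperparameters
    discount: float = 0.99
    ...


def update_actor(agent: Trainer, batch: Transition) -> Trainer:
    ...  # algorithm-specific actor update
    return agent


def update_critic(agent: Trainer, batch: Transition) -> Trainer:
    ...  # algorithm-specific critic update
    return agent


@partial(jax.jit, static_argnames=("n_jitted_updates", "batch_size"))
def update_n_times(agent: Trainer, data: Transition, rng: jax.Array,
                   n_jitted_updates: int, batch_size: int = 256) -> Trainer:
    # The Python loop is unrolled at trace time, so n_jitted_updates must be static.
    for _ in range(n_jitted_updates):
        # JAX has no global RNG state: split the key and sample batch indices explicitly.
        rng, batch_rng = jax.random.split(rng)
        idx = jax.random.randint(batch_rng, (batch_size,), 0, data.observations.shape[0])
        batch = jax.tree_util.tree_map(lambda x: x[idx], data)
        agent = update_actor(agent, batch)
        agent = update_critic(agent, batch)
    return agent


def create_trainer(*args, **kwargs) -> Trainer:
    # initialize models (actor and critic TrainStates); details elided in this overview
    return Trainer(
        actor=actor,
        critic=critic,
    )
```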
For all algorithms, we have a `Trainer` class (e.g. `TD3BCTrainer` for TD3+BC) which encompasses all the necessary components for the algorithm: models, hyperparameters, and update logic. The `Trainer` class is versatile and can be used outside of the provided files as long as the `create_trainer` function is implemented to meet the specifications of the `Trainer` class.
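To make the pattern concrete, here is a toy, self-contained sketch in pure JAX. It is not one of the repository's algorithms: `ToyTrainer`, its linear actor (trained by behavior cloning), its linear state-value critic (trained by TD(0)), and the learning rate are all hypothetical, and plain gradient descent stands in for Flax `TrainState`/optax optimizers. It shows the key design choice: the trainer is an immutable `NamedTuple` pytree, so update functions return a new trainer via `_replace`, and the whole loop of `n_jitted_updates` steps compiles into one XLA program.

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Transition(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    next_observations: jnp.ndarray
    dones: jnp.ndarray


class ToyTrainer(NamedTuple):
    actor_w: jnp.ndarray   # hypothetical linear policy: a = s @ actor_w
    critic_w: jnp.ndarray  # hypothetical linear state value: V(s) = s @ critic_w
    # hyperparameters
    discount: float = 0.99
    lr: float = 0.1


def update_actor(agent: ToyTrainer, batch: Transition) -> ToyTrainer:
    # Behavior-cloning stand-in: regress dataset actions from observations.
    def loss(w):
        return jnp.mean((batch.observations @ w - batch.actions) ** 2)

    grad = jax.grad(loss)(agent.actor_w)
    return agent._replace(actor_w=agent.actor_w - agent.lr * grad)


def update_critic(agent: ToyTrainer, batch: Transition) -> ToyTrainer:
    # TD(0): regress V(s) towards r + discount * (1 - done) * V(s'), target held fixed.
    next_v = batch.next_observations @ agent.critic_w
    target = jax.lax.stop_gradient(batch.rewards + agent.discount * (1.0 - batch.dones) * next_v)

    def loss(w):
        return jnp.mean((batch.observations @ w - target) ** 2)

    grad = jax.grad(loss)(agent.critic_w)
    return agent._replace(critic_w=agent.critic_w - agent.lr * grad)


@partial(jax.jit, static_argnames=("n_jitted_updates", "batch_size"))
def update_n_times(agent, data, rng, n_jitted_updates, batch_size):
    for _ in range(n_jitted_updates):  # unrolled at trace time
        rng, batch_rng = jax.random.split(rng)
        idx = jax.random.randint(batch_rng, (batch_size,), 0, data.observations.shape[0])
        batch = jax.tree_util.tree_map(lambda x: x[idx], data)
        agent = update_actor(agent, batch)
        agent = update_critic(agent, batch)
    return agent


# Synthetic data: actions are an exact linear function of observations.
k_obs, k_next, k_rew, k_loop = jax.random.split(jax.random.PRNGKey(0), 4)
obs = jax.random.normal(k_obs, (64, 3))
true_w = jnp.array([[1.0, -1.0], [0.5, 0.0], [0.0, 2.0]])
data = Transition(
    observations=obs,
    actions=obs @ true_w,
    rewards=jax.random.normal(k_rew, (64,)),
    next_observations=jax.random.normal(k_next, (64, 3)),
    dones=jnp.zeros((64,)),
)
agent = ToyTrainer(actor_w=jnp.zeros((3, 2)), critic_w=jnp.zeros((3,)))


def bc_loss(agent: ToyTrainer) -> float:
    return float(jnp.mean((data.observations @ agent.actor_w - data.actions) ** 2))


initial_loss = bc_loss(agent)
for _ in range(20):
    k_loop, k_step = jax.random.split(k_loop)
    agent = update_n_times(agent, data, k_step, n_jitted_updates=10, batch_size=32)
final_loss = bc_loss(agent)
print(f"BC loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

Keeping `n_jitted_updates` moderate (here 10 steps per call) bounds compile time, since the unrolled loop grows the traced program linearly; the outer Python loop then reuses the cached executable.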
Note: So far, the CQL implementation does not follow this structure due to a technical issue. This will be addressed in the near future.
Great Offline RL libraries
- CORL: Comprehensive single-file implementations of offline RL algorithms in PyTorch.
Implementations of offline RL algorithms in JAX
- jaxrl: Includes an implementation of AWAC.
- JaxCQL: Clean implementation of CQL.
- implicit_q_learning: Official implementation of IQL.
- decision-transformer-jax: JAX implementation of Decision Transformer with Haiku.
- td3-bc-jax: Direct port of original implementation with Haiku.
Single-file implementations
- CleanRL: High-quality single-file implementations of online RL algorithms in PyTorch.
- purejaxrl: High-quality single-file implementations of online RL algorithms in JAX.
```bibtex
@article{nishimori2022jaxcorl,
  title={JAX-CORL: Clean Single-file Implementations of Offline RL Algorithms in JAX},
  author={Soichiro Nishimori},
  year={2024},
  url={https://github.com/nissymori/JAX-CORL}
}
```
- This project is inspired by CORL, clean single-file implementations of offline RL algorithms in PyTorch.
- I would like to thank @JohannesAck for his TD3-BC codebase and helpful advice.
- The IQL implementation is based on implicit_q_learning.
- AWAC implementation is based on jaxrl.
- CQL implementation is based on JaxCQL.
- DT implementation is based on min-decision-transformer.