
JAX-CORL

This repository aims to be a JAX version of CORL: clean single-file implementations of offline RL algorithms with solid performance reports.

  • 🌬️ Pursuing fast training: speed up training via JAX functions such as jit and vmap (see the minimal sketch below).
  • 🔪 As simple as possible: implement only the minimum requirements.
  • 💠 Focus on a few battle-tested algorithms: refer here.
  • 📈 Solid performance reports (README, Wiki).
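
To make the speed-up concrete, here is a minimal, hedged sketch of the jit/vmap pattern; the loss, parameters, and learning rate below are illustrative placeholders, not code from this repository.

```python
# Minimal illustration of the jit/vmap pattern (illustrative placeholders only,
# not code from this repository).
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Hypothetical mean-squared-error loss on an (inputs, targets) batch.
    inputs, targets = batch
    preds = inputs @ params
    return jnp.mean((preds - targets) ** 2)

@jax.jit
def update_step(params, batch, lr=1e-3):
    # One gradient step, compiled to a single XLA program by jax.jit.
    grads = jax.grad(loss_fn)(params, batch)
    return params - lr * grads

# vmap evaluates the loss for an ensemble of parameter sets without a Python loop.
ensemble_loss = jax.vmap(loss_fn, in_axes=(0, None))

params = jnp.zeros((3, 1))
batch = (jnp.ones((8, 3)), jnp.zeros((8, 1)))
params = update_step(params, batch)  # fast after the first (compiling) call
```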

JAX-CORL complements the single-file RL ecosystem by covering the offline x JAX combination:

  • CleanRL: Online x PyTorch
  • purejaxrl: Online x JAX
  • CORL: Offline x PyTorch
  • JAX-CORL(ours): Offline x JAX

Algorithms

| Algorithm | Implementation | Training time (ours) | Training time (CORL) | wandb |
|---|---|---|---|---|
| AWAC | algos/awac.py | 665.5s (~11m) | 16083s (~4.46h) | link |
| IQL | algos/iql.py | 516.5s (~9m) | 14775s (~4.08h) | link |
| TD3+BC | algos/td3_bc.py | 524.4s (~9m) | 8895s (~2.47h) | link |
| CQL | algos/cql.py | 3304.1s (~56m) | 41838s (~11.52h) | link |
| DT | 🚧 | - | - | - |

Training time is for 1,000,000 update steps with batch size 256 on halfcheetah-medium-expert-v2 (there is little difference across the D4RL MuJoCo environments), and includes jit compile time. The computations were performed on four GeForce GTX 1080 Ti GPUs. Overall, ours is at least 10 times faster than CORL.

Reports for D4RL mujoco

Normalized Score

Here, we used D4RL MuJoCo control tasks as the benchmark. We report the mean and standard deviation of the average normalized score over 5 episodes and 5 seeds (a sketch of the score computation follows the table). We plan to extend the verification to other D4RL benchmarks such as AntMaze. For the source of the hyperparameters and the validity of the performance, please refer to the Wiki.

| env | AWAC | IQL | TD3+BC | CQL |
|---|---|---|---|---|
| halfcheetah-medium-v2 | $41.56\pm0.79$ | $43.28\pm0.51$ | $48.12\pm0.42$ | $48.65\pm0.49$ |
| halfcheetah-medium-expert-v2 | $76.61\pm9.60$ | $92.87\pm0.61$ | $92.99\pm0.11$ | $53.76\pm14.53$ |
| hopper-medium-v2 | $51.45\pm5.40$ | $52.17\pm2.88$ | $46.51\pm4.57$ | $77.56\pm7.12$ |
| hopper-medium-expert-v2 | $51.89\pm2.11$ | $53.35\pm5.63$ | $105.47\pm5.03$ | $90.37\pm31.29$ |
| walker2d-medium-v2 | $68.12\pm12.08$ | $75.33\pm5.2$ | $72.73\pm4.66$ | $80.16\pm4.19$ |
| walker2d-medium-expert-v2 | $91.36\pm23.13$ | $109.07\pm0.32$ | $109.17\pm0.71$ | $110.03\pm0.72$ |
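
The normalized scores above follow the standard D4RL convention. As a hedged sketch (the `policy_fn` below is a hypothetical stand-in for a trained actor, not this repository's exact interface), the evaluation looks roughly like:

```python
# Sketch of D4RL normalized-score evaluation; `policy_fn` is a hypothetical
# stand-in for a trained actor, not this repository's exact interface.
import gym
import d4rl  # noqa: F401  (registers the D4RL environments)
import numpy as np

def evaluate(policy_fn, env_name="halfcheetah-medium-expert-v2", n_episodes=5):
    env = gym.make(env_name)
    returns = []
    for _ in range(n_episodes):
        obs, done, ep_ret = env.reset(), False, 0.0
        while not done:
            obs, reward, done, _ = env.step(policy_fn(obs))
            ep_ret += reward
        returns.append(ep_ret)
    # D4RL rescales the raw return so that random ~ 0 and expert ~ 100.
    return 100.0 * env.get_normalized_score(np.mean(returns))
```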

How to use this codebase for your own research

This codebase can be used independently as a baseline for D4RL projects. It is also designed to be flexible, allowing users to develop new algorithms or adapt it for datasets other than D4RL.

For researchers interested in using this code for their projects, we provide a detailed explanation of the code's shared structure:

Data structure
```python
from typing import NamedTuple
import jax.numpy as jnp

class Transition(NamedTuple):
    observations: jnp.ndarray
    actions: jnp.ndarray
    rewards: jnp.ndarray
    next_observations: jnp.ndarray
    dones: jnp.ndarray

def get_dataset(...) -> Transition:
    ...
    return dataset
```

The code includes a Transition class, defined as a NamedTuple, with fields for observations, actions, rewards, next observations, and done flags. The get_dataset function is expected to output data in the Transition format, making the algorithms adaptable to any dataset that conforms to this structure.
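
As a minimal sketch of how such a get_dataset could be filled in (assuming d4rl's qlearning_dataset format; the repository's actual implementation may differ in details such as reward scaling):

```python
# Sketch only: assumes d4rl.qlearning_dataset; the repository's get_dataset may
# differ (e.g. reward normalization, dones handling).
import gym
import d4rl
import jax.numpy as jnp

def get_dataset(env_name: str) -> Transition:
    env = gym.make(env_name)
    data = d4rl.qlearning_dataset(env)  # dict of numpy arrays
    return Transition(
        observations=jnp.asarray(data["observations"]),
        actions=jnp.asarray(data["actions"]),
        rewards=jnp.asarray(data["rewards"]),
        next_observations=jnp.asarray(data["next_observations"]),
        dones=jnp.asarray(data["terminals"]),
    )
```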

Trainer class
```python
from functools import partial
import jax
from flax.training.train_state import TrainState

class Trainer(NamedTuple):
    actor: TrainState
    critic: TrainState
    # hyperparameters
    discount: float = 0.99
    ...

    def update_actor(agent, batch: Transition):
        ...
        return agent

    def update_critic(agent, batch: Transition):
        ...
        return agent

    @partial(jax.jit, static_argnames="n_jitted_updates")
    def update_n_times(agent, data, n_jitted_updates):
        for _ in range(n_jitted_updates):
            batch = data.sample()
            agent = agent.update_actor(batch)
            agent = agent.update_critic(batch)
        return agent

def create_trainer(...) -> Trainer:
    # initialize models ...
    return Trainer(
        actor=actor,
        critic=critic,
    )
```

For all algorithms, we have a Trainer class (e.g. TD3BCTrainer for TD3+BC) which encompasses all of the necessary components for the algorithm: models, hyperparameters, and update logic. The Trainer class is versatile and can be used outside of the provided files as long as the create_trainer function is properly implemented to meet the necessary specifications for the Trainer class (a usage sketch follows). Note: so far, we could not follow this policy for CQL due to a technical issue; this will be handled in the near future.
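
A hedged usage sketch (the create_trainer arguments, the sampler behaviour, and the step counts below are assumptions for illustration, not the repository's exact interface):

```python
# Usage sketch only: create_trainer arguments and update_n_times usage are
# assumed for illustration and may differ from the actual algorithm files.
dataset = get_dataset("halfcheetah-medium-expert-v2")
num_updates = 1_000_000       # total gradient steps (matches the report above)
n_jitted_updates = 8          # gradient steps fused into one jitted call

agent = create_trainer(
    observation_dim=dataset.observations.shape[-1],
    action_dim=dataset.actions.shape[-1],
)

for _ in range(num_updates // n_jitted_updates):
    # Each call performs n_jitted_updates actor/critic updates inside jit.
    agent = agent.update_n_times(dataset, n_jitted_updates)
```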

See also

Great Offline RL libraries

  • CORL: Comprehensive single-file implementations of offline RL algorithms in PyTorch.

Implementations of offline RL algorithms in JAX

Single-file implementations

  • CleanRL: High-quality single-file implementations of online RL algorithms in PyTorch.
  • purejaxrl: High-quality single-file implementations of online RL algorithms in JAX.

Cite JAX-CORL

@article{nishimori2022jaxcorl,
  title={JAX-CORL: Clean Single-file Implementations of Offline RL Algorithms in JAX},
  author={Soichiro Nishimori},
  year={2024},
  url={https://github.com/nissymori/JAX-CORL}
}

Credits

  • This project is inspired by CORL, a clean single-file implementation of offline RL algorithms in PyTorch.
  • I would like to thank @JohannesAck for his TD3-BC codebase and helpful advice.
  • The IQL implementation is based on implicit_q_learning.
  • The AWAC implementation is based on jaxrl.
  • The CQL implementation is based on JaxCQL.
  • The DT implementation is based on min-decision-transformer.
