Installation | Overview | Example Usage
Important
New reinforcement learning algorithms are frequently added to this project. However, for benchmarking purposes, please refer to the original implementations.
Note
The following README is an overview of what the library offers. Please refer to the documentation for more details.
Kitae aims to be a middle ground between 'clean RL' implementations and 'use-only' libraries (e.g. SB3).
In Kitae, an Agent is entirely defined by a configuration and 4 factory functions:
- train_state_factory: creates the Agent's state
- explore_factory: creates the function used to interact in the environment
- process_experience_factory: creates the function to process the data before updating
- update_step_factory: creates the function to update the Agent's state
These functions can be implemented very closely to 'clean RL' implementations, but are ultimately encapsulated into a single class, which simplifies the use of multiple environments, saving and loading, etc. (see the sketch below).
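To make the pattern concrete, here is a rough, hypothetical sketch of the four factories for a toy policy-gradient agent. The signatures, the plain-dict config, and the manual SGD update are illustrative assumptions and do not reflect Kitae's actual interfaces; refer to the documentation for the real ones.

# Hypothetical sketch: signatures and logic are illustrative, not Kitae's actual API.
from typing import Any, Callable

import jax
import jax.numpy as jnp

def train_state_factory(key: jax.Array, config: Any) -> dict:
    """Creates the Agent's state (here, the parameters of a linear policy)."""
    params = {
        "w": jax.random.normal(key, (config["obs_dim"], config["n_actions"])) * 0.01,
        "b": jnp.zeros((config["n_actions"],)),
    }
    return {"params": params, "step": 0}

def explore_factory(config: Any) -> Callable:
    """Creates the function used to interact with the environment."""
    def explore(state: dict, key: jax.Array, observation: jnp.ndarray):
        logits = observation @ state["params"]["w"] + state["params"]["b"]
        return jax.random.categorical(key, logits), logits
    return jax.jit(explore)

def process_experience_factory(config: Any) -> Callable:
    """Creates the function that turns raw transitions into training data."""
    def process_experience(state: dict, experience: dict) -> dict:
        def discounted(carry, reward):
            ret = reward + config["gamma"] * carry
            return ret, ret
        # Scan backwards over rewards to compute discounted returns.
        _, returns = jax.lax.scan(discounted, 0.0, experience["rewards"], reverse=True)
        return {**experience, "returns": returns}
    return jax.jit(process_experience)

def update_step_factory(config: Any) -> Callable:
    """Creates the function that updates the Agent's state from processed data."""
    def update_step(state: dict, batch: dict):
        def loss_fn(params):
            logits = batch["observations"] @ params["w"] + params["b"]
            log_probs = jax.nn.log_softmax(logits)
            chosen = jnp.take_along_axis(log_probs, batch["actions"][:, None], axis=1)
            # REINFORCE-style loss, purely for illustration.
            return -(chosen.squeeze(-1) * batch["returns"]).mean()
        loss, grads = jax.value_and_grad(loss_fn)(state["params"])
        new_params = jax.tree_util.tree_map(
            lambda p, g: p - config["learning_rate"] * g, state["params"], grads
        )
        return {"params": new_params, "step": state["step"] + 1}, {"loss": loss}
    return jax.jit(update_step)

Because each factory returns a pure, jit-compiled function, the class that wraps them can take care of vectorized environments, checkpointing, and loading without the factories themselves needing to know about those details.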
Kitae offers a few tools to simplify writing agents. In particular, self-play in multi-agent settings and vectorized environments are automatically handled by the library.
This package requires Python 3.10 or later and a working JAX installation. To install JAX, refer to the official installation instructions.
pip install --upgrade pip
pip install --upgrade git+https://github.com/Raffaelbdl/kitae
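As an example, a CPU-only build of JAX can usually be installed with pip; check the JAX documentation for the command that matches your accelerator:

pip install --upgrade jax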
Kitae is designed as a versatile toolbox library for reinforcement learning. The goal is to simplify every step of the process, from creating agents to training and evaluating them.
One of Kitae's main features is that it is designed to simplify working in vectorized settings with multiple instances of an environment.
from kitae.algos.collections import ppo
from kitae import config as cfg
from kitae.envs.make import make_vec_env
SEED = 0
ENV_ID = "CartPole-v1"
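# Create a vectorized environment (here, 16 parallel instances of CartPole-v1).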
env = make_vec_env(ENV_ID, 16, capture_video=False, run_name=None)
env_cfg = cfg.EnvConfig(
ENV_ID,
env.single_observation_space,
env.single_action_space,
n_envs=16,
n_agents=1
)
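# Instantiate a PPO agent from its nested configuration.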
agent = ppo.PPO(
"example-ppo",
cfg.AlgoConfig(
seed=SEED,
algo_params=ppo.PPOParams(
gamma=0.99,
_lambda=0.95,
clip_eps=0.2,
entropy_coef=0.01,
value_coef=0.5,
normalize=True,
),
update_cfg=cfg.UpdateConfig(
learning_rate=0.0003,
learning_rate_annealing=True,
max_grad_norm=0.5,
max_buffer_size=64,
batch_size=256,
n_epochs=1,
shared_encoder=True,
),
train_cfg=cfg.TrainConfig(n_env_steps=5*10**5, save_frequency=-1),
env_cfg=env_cfg,
),
tabulate=True,
)
agent.train(env, agent.config.train_cfg.n_env_steps)
The process of building a custom agent is detailed in this Google Colab.