[RLlib] Early improvements to Catalogs and RL Modules docs + Catalogs improvements (ray-project#37245)

Signed-off-by: Artur Niederfahrenhorst <attaismyname@googlemail.com>
ArturNiederfahrenhorst authored and Bhav00 committed Jul 22, 2023
1 parent b766284 commit 1213f20
Showing 16 changed files with 660 additions and 215 deletions.
27 changes: 9 additions & 18 deletions doc/source/rllib/doc_code/catalog_guide.py
@@ -30,34 +30,27 @@
import gymnasium as gym
import torch

from ray.rllib.core.models.base import STATE_IN, ENCODER_OUT
# ENCODER_OUT is a constant we use to enumerate Encoder I/O.
from ray.rllib.core.models.base import ENCODER_OUT
from ray.rllib.core.models.catalog import Catalog
from ray.rllib.core.models.configs import MLPHeadConfig
from ray.rllib.policy.sample_batch import SampleBatch

env = gym.make("CartPole-v1")

catalog = Catalog(env.observation_space, env.action_space, model_config_dict={})
# We expect a categorical distribution for CartPole.
action_dist_class = catalog.get_action_dist_cls(framework="torch")
# Therefore, we need `env.action_space.n` action distribution inputs.
expected_action_dist_input_dims = (env.action_space.n,)

# Build an encoder that fits CartPole's observation space.
encoder = catalog.build_encoder(framework="torch")
# Build a suitable head model for the action distribution.
head_config = MLPHeadConfig(
    input_dims=catalog.latent_dims, hidden_layer_dims=expected_action_dist_input_dims
)
head = head_config.build(framework="torch")
# We need `env.action_space.n` action distribution inputs.
head = torch.nn.Linear(catalog.latent_dims[0], env.action_space.n)
# Now we are ready to interact with the environment
obs, info = env.reset()
# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {
    SampleBatch.OBS: torch.Tensor([obs]),
    STATE_IN: None,
    SampleBatch.SEQ_LENS: None,
}
input_batch = {SampleBatch.OBS: torch.Tensor([obs])}
# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT]
action_dist_inputs = head(encoding)
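# Illustrative sketch (assuming the distribution class returned by the Catalog
# follows RLlib's Distribution API): an action can be sampled from the
# computed distribution inputs like this.
action_dist = action_dist_class.from_logits(action_dist_inputs)
action = action_dist.sample()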
@@ -75,6 +68,8 @@
import torch

from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog

# STATE_IN and ENCODER_OUT are constants we use to enumerate Encoder I/O;
# ACTOR keys the actor branch of the actor-critic encoder's output.
from ray.rllib.core.models.base import STATE_IN, ENCODER_OUT, ACTOR
from ray.rllib.policy.sample_batch import SampleBatch

@@ -91,11 +86,7 @@
obs, info = env.reset()
# Encoders check for state and sequence lengths for recurrent models.
# We don't need either in this case because default encoders are not recurrent.
input_batch = {
    SampleBatch.OBS: torch.Tensor([obs]),
    STATE_IN: None,
    SampleBatch.SEQ_LENS: None,
}
input_batch = {SampleBatch.OBS: torch.Tensor([obs])}
# Pass the batch through our models and the action distribution.
encoding = encoder(input_batch)[ENCODER_OUT][ACTOR]
action_dist_inputs = policy_head(encoding)
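# Illustrative sketch: for CartPole's categorical action distribution, a greedy
# action can be read directly off the distribution inputs (the logits).
greedy_action = torch.argmax(action_dist_inputs, dim=-1)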
47 changes: 47 additions & 0 deletions doc/source/rllib/doc_code/rlmodule_guide.py
@@ -399,3 +399,50 @@ def setup(self):

module = spec.build()
# __pass-custom-marlmodule-shared-enc-end__


# __checkpointing-begin__
import gymnasium as gym
import shutil
import tempfile
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

config = PPOConfig().environment("CartPole-v1")
env = gym.make("CartPole-v1")
# Create an RL Module that we would like to checkpoint
module_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"fcnet_hiddens": [32]},
    catalog_class=PPOCatalog,
)
module = module_spec.build()

# Create the checkpoint
module_ckpt_path = tempfile.mkdtemp()
module.save_to_checkpoint(module_ckpt_path)

# Create a new RL Module from the checkpoint
module_to_load_spec = SingleAgentRLModuleSpec(
    module_class=PPOTorchRLModule,
    observation_space=env.observation_space,
    action_space=env.action_space,
    model_config_dict={"fcnet_hiddens": [32]},
    catalog_class=PPOCatalog,
    load_state_path=module_ckpt_path,
)

# Train with the checkpointed RL Module
config.rl_module(
    rl_module_spec=module_to_load_spec,
    _enable_rl_module_api=True,
)
algo = config.build()
algo.train()
# __checkpointing-end__
algo.stop()
shutil.rmtree(module_ckpt_path)
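# Illustrative sketch (assuming the RLModule get_state()/set_state() API):
# weights can also be copied between two in-memory module instances via their
# state, without going through a checkpoint directory.
fresh_module = module_spec.build()
fresh_module.set_state(module.get_state())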
49 changes: 21 additions & 28 deletions doc/source/rllib/key-concepts.rst
@@ -44,6 +44,7 @@ The model that tries to maximize the expected sum over all future rewards is cal
The RL simulation feedback loop repeatedly collects data for one (single-agent case) or multiple (multi-agent case) policies, trains the policies on the collected data, and makes sure the policies' weights are kept in sync. The collected environment data contains observations, actions taken, rewards received, and so-called **done** flags, which mark the boundaries of the different episodes the agents play through in the simulation.

These simulation iterations of action -> reward -> next state, repeated until an end state is reached, are called an **episode**, or in RLlib, a **rollout**.
The most common API to define environments is the `Farama-Foundation Gymnasium <rllib-env.html#gymnasium>`__ API, which we also use in most of our examples.
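
A minimal interaction loop with a Gymnasium environment looks roughly like this (a plain Gymnasium sketch, independent of RLlib):

.. code-block:: python

    import gymnasium as gym

    env = gym.make("CartPole-v1")
    obs, info = env.reset()
    terminated = truncated = False
    episode_return = 0.0
    while not (terminated or truncated):
        # A random policy; RLlib's job is to learn a better one.
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        episode_return += reward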

.. _algorithms:

@@ -115,40 +116,32 @@ You can `configure the parallelism <rllib-training.html#specifying-resources>`__
Check out our `scaling guide <rllib-training.html#scaling-guide>`__ for more details.


Policies
--------

`Policies <rllib-concepts.html#policies>`__ are a core concept in RLlib. In a nutshell, policies are
Python classes that define how an agent acts in an environment.
`Rollout workers <rllib-concepts.html#policy-evaluation>`__ query the policy to determine agent actions.
In a `Farama-Foundation Gymnasium <rllib-env.html#gymnasium>`__ environment, there is a single agent and policy.
In `vector envs <rllib-env.html#vectorized>`__, policy inference is for multiple agents at once,
and in `multi-agent <rllib-env.html#multi-agent-and-hierarchical>`__, there may be multiple policies,
each controlling one or more agents:
RL Modules
----------

.. image:: images/multi-flat.svg
`RL Modules <rllib-rlmodule.html>`__ are framework-specific neural network containers.
In a nutshell, they carry the neural networks and define how to use them during the three phases that occur in
reinforcement learning: exploration, inference, and training.
A minimal RL Module can carry a single neural network and define its exploration, inference, and
training logic to simply map observations to actions. Because RL Modules can map observations to actions, they naturally
implement reinforcement learning policies in RLlib and can therefore be found in the :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`,
where their exploration and inference logic is used to sample from an environment.
The second place in RLlib where RL Modules commonly occur is the :py:class:`~ray.rllib.core.learner.learner.Learner`,
where their training logic is used to train the neural networks.
RL Modules extend to the multi-agent case, where a single :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule`
contains multiple RL Modules. The following figure is a rough sketch of how this can look in practice:

Policies can be implemented using `any framework <https://github.com/ray-project/ray/blob/master/rllib/policy/policy.py>`__.
However, for TensorFlow and PyTorch, RLlib has
`build_tf_policy <rllib-concepts.html#building-policies-in-tensorflow>`__ and
`build_torch_policy <rllib-concepts.html#building-policies-in-pytorch>`__ helper functions that let you
define a trainable policy with a functional-style API, for example:
.. image:: images/rllib-concepts-rlmodules-sketch.png

.. TODO: test this code snippet

.. code-block:: python

    def policy_gradient_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        return -tf.reduce_mean(
            action_dist.logp(train_batch["actions"]) * train_batch["rewards"])
.. note::

    # <class 'ray.rllib.policy.tf_policy_template.MyTFPolicy'>
    MyTFPolicy = build_tf_policy(
        name="MyTFPolicy",
        loss_fn=policy_gradient_loss)
    RL Modules are currently in the alpha stage. They are wrapped in legacy :py:class:`~ray.rllib.policy.Policy` objects
    to be used in :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` for sampling.
    This should be transparent to the user, but the following
    `Policy Evaluation <key-concepts.html#policy-evaluation>`__ section still refers to these legacy Policy objects.
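
As a rough usage sketch (a minimal example using RLlib's single-agent PPO torch module, not a complete training setup), building and querying an RL Module can look like this:

.. code-block:: python

    import gymnasium as gym
    import torch

    from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
    from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
    from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec

    env = gym.make("CartPole-v1")
    # Build a single-agent RL Module from a spec.
    module = SingleAgentRLModuleSpec(
        module_class=PPOTorchRLModule,
        observation_space=env.observation_space,
        action_space=env.action_space,
        model_config_dict={"fcnet_hiddens": [32]},
        catalog_class=PPOCatalog,
    ).build()

    obs, info = env.reset()
    # forward_inference() runs the module's inference logic on a batch of observations.
    out = module.forward_inference({"obs": torch.Tensor([obs])})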

.. _policy-evaluation:
Policy Evaluation
-----------------
13 changes: 6 additions & 7 deletions doc/source/rllib/package_ref/catalogs.rst
@@ -11,8 +11,7 @@ Basic usage
-----------

Use the following basic API to get a default ``encoder`` or ``action distribution``
out of Catalog. You can inherit from Catalog and modify the following methods to
directly inject custom components into a given RLModule.
out of Catalog. To change the catalog behavior, modify the following methods.
Algorithm-specific implementations of Catalog have additional methods,
for example, for building ``heads``.

@@ -24,18 +23,18 @@
    Catalog
    Catalog.build_encoder
    Catalog.get_action_dist_cls
    Catalog.get_preprocessor
    Catalog.get_tokenizer_config
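
For example, the default behavior can be changed by subclassing ``Catalog`` and overriding one of these methods. A minimal sketch (``MyCatalog`` is a hypothetical subclass that simply falls back to the default):

.. code-block:: python

    from ray.rllib.core.models.catalog import Catalog

    class MyCatalog(Catalog):
        def get_action_dist_cls(self, framework: str):
            # A real override would return a custom distribution class here;
            # this sketch just delegates to the default implementation.
            return super().get_action_dist_cls(framework=framework)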


Advanced usage
--------------

The following methods are used internally by the Catalog to build the default models.
The following methods and attributes are used internally by the Catalog to build the default models. Only override them when you need more granular control.

.. autosummary::
    :toctree: doc/

    Catalog.latent_dims
    Catalog.__post_init__
    Catalog.get_encoder_config
    Catalog.get_tokenizer_config
    Catalog._determine_components_hook
    Catalog._get_encoder_config
    Catalog._get_dist_cls_from_action_space
