diff --git a/doc/source/rllib/doc_code/catalog_guide.py b/doc/source/rllib/doc_code/catalog_guide.py index 4c212dbc77f306..afeecdca68bf67 100644 --- a/doc/source/rllib/doc_code/catalog_guide.py +++ b/doc/source/rllib/doc_code/catalog_guide.py @@ -30,9 +30,9 @@ import gymnasium as gym import torch -from ray.rllib.core.models.base import STATE_IN, ENCODER_OUT +# ENCODER_OUT is a constant we use to enumerate Encoder I/O. +from ray.rllib.core.models.base import ENCODER_OUT from ray.rllib.core.models.catalog import Catalog -from ray.rllib.core.models.configs import MLPHeadConfig from ray.rllib.policy.sample_batch import SampleBatch env = gym.make("CartPole-v1") @@ -40,24 +40,17 @@ catalog = Catalog(env.observation_space, env.action_space, model_config_dict={}) # We expect a categorical distribution for CartPole. action_dist_class = catalog.get_action_dist_cls(framework="torch") -# Therefore, we need `env.action_space.n` action distribution inputs. -expected_action_dist_input_dims = (env.action_space.n,) + # Build an encoder that fits CartPole's observation space. encoder = catalog.build_encoder(framework="torch") # Build a suitable head model for the action distribution. -head_config = MLPHeadConfig( - input_dims=catalog.latent_dims, hidden_layer_dims=expected_action_dist_input_dims -) -head = head_config.build(framework="torch") +# We need `env.action_space.n` action distribution inputs. +head = torch.nn.Linear(catalog.latent_dims[0], env.action_space.n) # Now we are ready to interact with the environment obs, info = env.reset() # Encoders check for state and sequence lengths for recurrent models. # We don't need either in this case because default encoders are not recurrent. -input_batch = { - SampleBatch.OBS: torch.Tensor([obs]), - STATE_IN: None, - SampleBatch.SEQ_LENS: None, -} +input_batch = {SampleBatch.OBS: torch.Tensor([obs])} # Pass the batch through our models and the action distribution. encoding = encoder(input_batch)[ENCODER_OUT] action_dist_inputs = head(encoding) @@ -75,6 +68,8 @@ import torch from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog + +# STATE_IN, STATE_OUT and ENCODER_OUT are constants we use to enumerate Encoder I/O. from ray.rllib.core.models.base import STATE_IN, ENCODER_OUT, ACTOR from ray.rllib.policy.sample_batch import SampleBatch @@ -91,11 +86,7 @@ obs, info = env.reset() # Encoders check for state and sequence lengths for recurrent models. # We don't need either in this case because default encoders are not recurrent. -input_batch = { - SampleBatch.OBS: torch.Tensor([obs]), - STATE_IN: None, - SampleBatch.SEQ_LENS: None, -} +input_batch = {SampleBatch.OBS: torch.Tensor([obs])} # Pass the batch through our models and the action distribution. 
encoding = encoder(input_batch)[ENCODER_OUT][ACTOR] action_dist_inputs = policy_head(encoding) diff --git a/doc/source/rllib/doc_code/rlmodule_guide.py b/doc/source/rllib/doc_code/rlmodule_guide.py index 2cad7880581171..fa97168d12abf1 100644 --- a/doc/source/rllib/doc_code/rlmodule_guide.py +++ b/doc/source/rllib/doc_code/rlmodule_guide.py @@ -399,3 +399,50 @@ def setup(self): module = spec.build() # __pass-custom-marlmodule-shared-enc-end__ + + +# __checkpointing-begin__ +import gymnasium as gym +import shutil +import tempfile +from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule +from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec + +config = PPOConfig().environment("CartPole-v1") +env = gym.make("CartPole-v1") +# Create an RL Module that we would like to checkpoint +module_spec = SingleAgentRLModuleSpec( + module_class=PPOTorchRLModule, + observation_space=env.observation_space, + action_space=env.action_space, + model_config_dict={"fcnet_hiddens": [32]}, + catalog_class=PPOCatalog, +) +module = module_spec.build() + +# Create the checkpoint +module_ckpt_path = tempfile.mkdtemp() +module.save_to_checkpoint(module_ckpt_path) + +# Create a new RL Module from the checkpoint +module_to_load_spec = SingleAgentRLModuleSpec( + module_class=PPOTorchRLModule, + observation_space=env.observation_space, + action_space=env.action_space, + model_config_dict={"fcnet_hiddens": [32]}, + catalog_class=PPOCatalog, + load_state_path=module_ckpt_path, +) + +# Train with the checkpointed RL Module +config.rl_module( + rl_module_spec=module_to_load_spec, + _enable_rl_module_api=True, +) +algo = config.build() +algo.train() +# __checkpointing-end__ +algo.stop() +shutil.rmtree(module_ckpt_path) diff --git a/doc/source/rllib/images/rllib-concepts-rlmodules-sketch.png b/doc/source/rllib/images/rllib-concepts-rlmodules-sketch.png new file mode 100644 index 00000000000000..9e3155a2876b14 Binary files /dev/null and b/doc/source/rllib/images/rllib-concepts-rlmodules-sketch.png differ diff --git a/doc/source/rllib/key-concepts.rst b/doc/source/rllib/key-concepts.rst index b0f1259de273af..2e033399c85242 100644 --- a/doc/source/rllib/key-concepts.rst +++ b/doc/source/rllib/key-concepts.rst @@ -44,6 +44,7 @@ The model that tries to maximize the expected sum over all future rewards is cal The RL simulation feedback loop repeatedly collects data, for one (single-agent case) or multiple (multi-agent case) policies, trains the policies on these collected data, and makes sure the policies' weights are kept in sync. Thereby, the collected environment data contains observations, taken actions, received rewards and so-called **done** flags, indicating the boundaries of different episodes the agents play through in the simulation. The simulation iterations of action -> reward -> next state -> train -> repeat, until the end state, is called an **episode**, or in RLlib, a **rollout**. +The most common API to define environments is the `Farama-Foundation Gymnasium `__ API, which we also use in most of our examples. .. _algorithms: @@ -115,40 +116,32 @@ You can `configure the parallelism `__ Check out our `scaling guide `__ for more details here. -Policies --------- - -`Policies `__ are a core concept in RLlib. In a nutshell, policies are -Python classes that define how an agent acts in an environment. -`Rollout workers `__ query the policy to determine agent actions. 
-In a `Farama-Foundation Gymnasium `__ environment, there is a single agent and policy. -In `vector envs `__, policy inference is for multiple agents at once, -and in `multi-agent `__, there may be multiple policies, -each controlling one or more agents: +RL Modules +---------- -.. image:: images/multi-flat.svg +`RLModules `__ are framework-specific neural network containers. +In a nutshell, they carry the neural networks and define how to use them during three phases that occur in +reinforcement learning: Exploration, inference and training. +A minimal RL Module can contain a single neural network and define its exploration-, inference- and +training logic to only map observations to actions. Since RL Modules can map observations to actions, they naturally +implement reinforcement learning policies in RLlib and can therefore be found in the :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker`, +where their exploration and inference logic is used to sample from an environment. +The second place in RLlib where RL Modules commonly occur is the :py:class:`~ray.rllib.core.learner.learner.Learner`, +where their training logic is used in training the neural network. +RL Modules extend to the multi-agent case, where a single :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule` +contains multiple RL Modules. The following figure is a rough sketch of how the above can look in practice: -Policies can be implemented using `any framework `__. -However, for TensorFlow and PyTorch, RLlib has -`build_tf_policy `__ and -`build_torch_policy `__ helper functions that let you -define a trainable policy with a functional-style API, for example: +.. image:: images/rllib-concepts-rlmodules-sketch.png -.. TODO: test this code snippet -.. code-block:: python - - def policy_gradient_loss(policy, model, dist_class, train_batch): - logits, _ = model.from_batch(train_batch) - action_dist = dist_class(logits, model) - return -tf.reduce_mean( - action_dist.logp(train_batch["actions"]) * train_batch["rewards"]) +.. note:: - # - MyTFPolicy = build_tf_policy( - name="MyTFPolicy", - loss_fn=policy_gradient_loss) + RL Modules are currently in alpha stage. They are wrapped in legacy :py:class:`~ray.rllib.policy.Policy` objects + to be used in :py:class:`~ray.rllib.evaluation.rollout_worker.RolloutWorker` for sampling. + This should be transparent to the user, but the following + `Policy Evaluation `__ section still refers to these legacy Policy objects. +.. policy-evaluation: Policy Evaluation ----------------- diff --git a/doc/source/rllib/package_ref/catalogs.rst b/doc/source/rllib/package_ref/catalogs.rst index 3b4bc2f1e607c5..155512194a5205 100644 --- a/doc/source/rllib/package_ref/catalogs.rst +++ b/doc/source/rllib/package_ref/catalogs.rst @@ -11,8 +11,7 @@ Basic usage ----------- Use the following basic API to get a default ``encoder`` or ``action distribution`` -out of Catalog. You can inherit from Catalog and modify the following methods to -directly inject custom components into a given RLModule. +out of Catalog. To change the catalog behavior, modify the following methods. Algorithm-specific implementations of Catalog have additional methods, for example, for building ``heads``. @@ -24,18 +23,18 @@ for example, for building ``heads``. Catalog Catalog.build_encoder Catalog.get_action_dist_cls - Catalog.get_preprocessor + Catalog.get_tokenizer_config Advanced usage -------------- -The following methods are used internally by the Catalog to build the default models. 
+The following methods and attributes are used internally by the Catalog to build the default models. Only override them when you need more granular control. .. autosummary:: :toctree: doc/ Catalog.latent_dims - Catalog.__post_init__ - Catalog.get_encoder_config - Catalog.get_tokenizer_config + Catalog._determine_components_hook + Catalog._get_encoder_config + Catalog._get_dist_cls_from_action_space diff --git a/doc/source/rllib/rllib-catalogs.rst b/doc/source/rllib/rllib-catalogs.rst index 5491f2fa20bbe1..6d6b0a75291ae9 100644 --- a/doc/source/rllib/rllib-catalogs.rst +++ b/doc/source/rllib/rllib-catalogs.rst @@ -4,30 +4,35 @@ .. include:: /_includes/rllib/rlmodules_rollout.rst -.. note:: Interacting with Catalogs mainly covers advanced use cases. Catalog (Alpha) =============== + +Catalog is a utility abstraction that modularizes the construction of components for `RLModules `__. +It includes information such as how input observation spaces should be encoded, +what action distributions should be used, and so on. Each :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` has its own default :py:class:`~ray.rllib.core.models.catalog.Catalog`. For example, :py:class:`~ray.rllib.algorithms.ppo.ppo_torch_rl_module.PPOTorchRLModule` has the :py:class:`~ray.rllib.algorithms.ppo.ppo_catalog.PPOCatalog`. -You can override Catalogs’ methods to alter the behavior of existing RLModules. -This makes Catalogs a means of configuration for RLModules. -You interact with Catalogs when making deeper customization to what :py:class:`~ray.rllib.core.models.Model` and :py:class:`~ray.rllib.models.distributions.Distribution` RLlib creates by default. +To customize existing RLModules, either change the RLModule directly by inheriting the class and changing its +:py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.setup` method or, alternatively, extend the Catalog class +associated with that `RLModule`. Use Catalogs only if your customizations fit the abstractions provided by Catalog. + +.. note:: + Modifying Catalogs signifies advanced use cases, so you should only consider this if modifying an RLModule or writing one does not cover your use case. + We recommend modifying Catalogs only when making deeper customizations to the decision trees that determine what :py:class:`~ray.rllib.core.models.base.Model` and :py:class:`~ray.rllib.models.distributions.Distribution` RLlib creates by default. .. note:: - If you simply want to modify a :py:class:`~ray.rllib.core.models.Model` by changing its default values, - have a look at the ``model config dict``: + If you simply want to modify a Model by changing its default values, + have a look at the model config dict: - .. dropdown:: **``MODEL_DEFAULTS`` dict** + .. dropdown:: ``MODEL_DEFAULTS`` :animate: fade-in-slide-down This dict (or an overriding sub-set) is part of :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig` and therefore also part of any algorithm-specific config. - You can override its values and pass it to an :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig` + To change the behavior of RLlib's default models, override it and pass it to an AlgorithmConfig. - to change the behavior RLlib's default models. ..
literalinclude:: ../../../rllib/models/catalog.py :language: python :start-after: __sphinx_doc_begin__ :end-before: __sphinx_doc_end__ @@ -35,22 +40,107 @@ -While Catalogs have a base class :py:class:`~ray.rllib.core.models.catalog.Catalog`, you mostly interact with +While Catalogs have a base class Catalog, you mostly interact with Algorithm-specific Catalogs. Therefore, this doc also includes examples around PPO from which you can extrapolate to other algorithms. Prerequisites for this user guide is a rough understanding of `RLModules `__. This user guide covers the following topics: -- Basic usage - What are Catalogs +- Catalog design and ideas +- Catalog and AlgorithmConfig +- Basic usage - Inject your custom models into RLModules - Inject your custom action distributions into RLModules +- Write a Catalog from scratch + +What are Catalogs +~~~~~~~~~~~~~~~~~ + +Catalogs have two primary roles: Choosing the right :py:class:`~ray.rllib.core.models.base.Model` and choosing the right :py:class:`~ray.rllib.models.distributions.Distribution`. +By default, all catalogs implement decision trees that decide model architecture based on a combination of input configurations. +These mainly include the ``observation space`` and ``action space`` of the :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`, the ``model config dict`` and the ``deep learning framework backend``. + +The following diagram shows the breakdown of the information flow towards ``models`` and ``distributions`` within an RLModule. +An RLModule creates an instance of the Catalog class it receives as part of its constructor. +It then creates its internal ``models`` and ``distributions`` with the help of this Catalog. + +.. note:: + You can also modify Model or Distribution in an RLModule directly by overriding the RLModule's constructor! + +.. image:: images/catalog/catalog_and_rlm_diagram.svg :align: center -.. - Extend RLlib’s selection of Models and Distributions with your own -.. - Write a Catalog from scratch +The following diagram shows a concrete case in more detail. + +.. dropdown:: **Example of catalog in a PPORLModule** + :animate: fade-in-slide-down + + The :py:class:`~ray.rllib.algorithms.ppo.ppo_catalog.PPOCatalog` is fed an ``observation space``, ``action space``, + a ``model config dict`` and the ``view requirements`` of the :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`. + The ``model config dict`` and the ``view requirements`` are only of interest in special cases, such as + recurrent networks or attention networks. A PPORLModule has four components that are created by the PPOCatalog: + ``Encoder``, ``value function head``, ``policy head``, and ``action distribution``. + + .. image:: images/catalog/ppo_catalog_and_rlm_diagram.svg + :align: center + + +Catalog design and ideas +~~~~~~~~~~~~~~~~~~~~~~~~ + +Since the main use cases for this component involve deep modifications of it, we explain the design and ideas behind Catalogs in this section. + +What problems do Catalogs solve? +-------------------------------- + +RL algorithms need neural network ``models`` and ``distributions``. +Within an algorithm, many different architectures for such sub-components are valid. +Moreover, models and distributions vary with environments. +However, most algorithms require models that are similar to one another. +The problem is finding sensible sub-components for a wide range of use cases while sharing this functionality +across algorithms.
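The rest of this guide builds toward patterns like the following condensed sketch. It is lifted from the MobileNetV2 encoder example that this change set adds under ``rllib/examples/catalog/mobilenet_v2_encoder.py``, so treat it as an illustration of the idea rather than new API: a single Catalog subclass injects a custom encoder for the observation spaces it supports, falls back to the default decision tree otherwise, and can then be reused across RLModules and algorithms.

.. code-block:: python

    import gymnasium as gym

    from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog
    from ray.rllib.examples.models.mobilenet_v2_encoder import (
        MOBILENET_INPUT_SHAPE,
        MobileNetV2EncoderConfig,
    )


    class MobileNetEnhancedPPOCatalog(PPOCatalog):
        """Reuses one custom encoder wherever the observation space fits it."""

        @classmethod
        def _get_encoder_config(cls, observation_space: gym.Space, **kwargs):
            # Inject the custom encoder only for matching observation spaces ...
            if (
                isinstance(observation_space, gym.spaces.Box)
                and observation_space.shape == MOBILENET_INPUT_SHAPE
            ):
                return MobileNetV2EncoderConfig()
            # ... and fall back to RLlib's default decision tree otherwise.
            return super()._get_encoder_config(observation_space, **kwargs)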
+ +How do Catalogs solve this? +---------------------------- + +As stated above, Catalogs implement decision trees for sub-components of `RLModules`. +Models and distributions from a Catalog object are meant to fit together. +Since we mostly build RLModules out of :py:class:`~ray.rllib.core.models.base.Encoder` s, Heads and :py:class:`~ray.rllib.models.distributions.Distribution` s, Catalogs also generally reflect this. +For example, the PPOCatalog builds Encoders that output a latent vector and two Heads that take this latent vector as input. +(That's why Catalogs have a ``latent_dims`` attribute.) Heads and distributions behave accordingly. +Whenever you create a Catalog, the decision tree is executed to find suitable configs for models and classes for distributions. +By default this happens in :py:meth:`~ray.rllib.core.models.catalog.Catalog._get_encoder_config` and :py:meth:`~ray.rllib.core.models.catalog.Catalog._get_dist_cls_from_action_space`. +Whenever you build a model, the config is turned into a model. +Distributions are instantiated per forward pass of an `RLModule` and are therefore not built. + +API philosophy +-------------- + +Catalogs attempt to encapsulate most complexity around models inside the :py:class:`~ray.rllib.core.models.base.Encoder`. +This means that recurrency, attention and other special cases are fully handled inside the Encoder and are transparent +to other components. +Encoders are the only components that the Catalog base class builds. +This is because many algorithms require custom heads and distributions but most of them can use the same encoders. +The Catalog API is designed such that interaction usually happens in two stages: + +- Instantiate a Catalog. This executes the decision tree. +- Generate an arbitrary number of decided components through Catalog methods. + +The two default methods to access components on the base class are: + +- :py:meth:`~ray.rllib.core.models.catalog.Catalog.build_encoder` +- :py:meth:`~ray.rllib.core.models.catalog.Catalog.get_action_dist_cls` + +You can override these to quickly hack what models RLModules build. +Other methods are private and should only be overridden to make deep changes to the decision tree to enhance the capabilities of Catalogs. +Additionally, :py:meth:`~ray.rllib.core.models.catalog.Catalog.get_tokenizer_config` is a method that can be used when tokenization +is required. Tokenization means single-step embedding. Encoding also means embedding but can span multiple timesteps. +In fact, RLlib's tokenizers used in its recurrent Encoders (e.g. :py:class:`~ray.rllib.core.models.torch.encoder.TorchLSTMEncoder`) +are instances of non-recurrent Encoder classes. Catalog and AlgorithmConfig -~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~~~~~~ Since Catalogs effectively control what ``models`` and ``distributions`` RLlib uses under the hood, they are also part of RLlib’s configurations. As the primary entry point for configuring RLlib, @@ -71,10 +161,14 @@ created by PPO. :start-after: __sphinx_doc_algo_configs_begin__ :end-before: __sphinx_doc_algo_configs_end__ + Basic usage ~~~~~~~~~~~ -The following three examples illustrate three basic usage patterns of Catalogs. +In the following three examples, we interact with Catalogs to illustrate their API. + +High-level API +-------------- The first example showcases the general API for interacting with Catalogs. @@ -83,24 +177,12 @@ .. dropdown:: **Interact with Catalogs** :animate: fade-in-slide-down .. literalinclude:: doc_code/catalog_guide.py :language: python
:start-after: __sphinx_doc_basic_interaction_begin__ :end-before: __sphinx_doc_basic_interaction_end__ -The second example showcases how to use the :py:class:`~ray.rllib.algorithms.ppo.ppo_catalog.PPOCatalog` -to create a ``model`` and an ``action distribution``. -This is more similar to what RLlib does internally. - -.. dropdown:: **Use catalog-generated models** - :animate: fade-in-slide-down - - .. literalinclude:: doc_code/catalog_guide.py - :language: python - :start-after: __sphinx_doc_ppo_models_begin__ - :end-before: __sphinx_doc_ppo_models_end__ +Creating models and distributions +--------------------------------- -The third example showcases how to use the base :py:class:`~ray.rllib.core.models.catalog.Catalog` -to create an ``encoder`` and an ``action distribution``. -Besides these, we create a ``head network`` that fits these two by hand to show how you can combine RLlib's -:py:class:`~ray.rllib.core.models.base.ModelConfig` API and Catalog. -Extending Catalog to also build this head is how :py:class:`~ray.rllib.core.models.catalog.Catalog` is meant to be -extended, which we cover later in this guide. +The second example showcases how to use the base :py:class:`~ray.rllib.core.models.catalog.Catalog` +to create a ``model`` and an ``action distribution``. +Besides these, we create a ``head network`` by hand that fits these two. .. dropdown:: **Customize a policy head** :animate: fade-in-slide-down @@ -110,40 +192,28 @@ extended, which we cover later in this guide. :start-after: __sphinx_doc_modelsworkflow_begin__ :end-before: __sphinx_doc_modelsworkflow_end__ -What are Catalogs -~~~~~~~~~~~~~~~~~ - -Catalogs have two primary roles: Choosing the right :py:class:`~ray.rllib.core.models.Model` and choosing the right :py:class:`~ray.rllib.models.distributions.Distribution`. -By default, all catalogs implement decision trees that decide model architecture based on a combination of input configurations. -These mainly include the ``observation space`` and ``action space`` of the :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`, the ``model config dict`` and the ``deep learning framework backend``. - -The following diagram shows the break down of the information flow towards ``models`` and ``distributions`` within an :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`. -An :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` creates an instance of the Catalog class they receive as part of their constructor. -It then create its internal ``models`` and ``distributions`` with the help of this Catalog. - -.. note:: - You can also modify :py:class:`~ray.rllib.core.models.Model` or :py:class:`~ray.rllib.models.distributions.Distribution` in an :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` directly by overriding the RLModule's constructor! - -.. image:: images/catalog/catalog_and_rlm_diagram.svg - :align: center +Creating models and distributions for PPO +----------------------------------------- -The following diagram shows a concrete case in more detail. +The third example showcases how to use the :py:class:`~ray.rllib.algorithms.ppo.ppo_catalog.PPOCatalog` +to create an ``encoder`` and an ``action distribution``. +This is more similar to what RLlib does internally. -.. dropdown:: **Example of catalog in a PPORLModule** +..
dropdown:: **Use catalog-generated models** :animate: fade-in-slide-down - The :py:class:`~ray.rllib.algorithms.ppo.ppo_catalog.PPOCatalog` is fed an ``observation space``, ``action space``, - a ``model config dict`` and the ``view requirements`` of the :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`. - The model config dicts and the view requirements are only of interest in special cases, such as - recurrent networks or attention networks. A PPORLModule has four components that are created by the - :py:class:`~ray.rllib.algorithms.ppo.ppo_catalog.PPOCatalog`: - ``Encoder``, ``value function head``, ``policy head``, and ``action distribution``. + .. literalinclude:: doc_code/catalog_guide.py + :language: python + :start-after: __sphinx_doc_ppo_models_begin__ + :end-before: __sphinx_doc_ppo_models_end__ - .. image:: images/catalog/ppo_catalog_and_rlm_diagram.svg - :align: center +Note that the above two examples illustrate in principle what it takes to implement a Catalog. +In this case, we see the difference between `Catalog` and `PPOCatalog`. +In most cases, we can reuse the capabilities of the :py:class:`~ray.rllib.core.models.catalog.Catalog` base class +and only need to add methods to build head networks that we can then use in the appropriate `RLModule`. -Inject your custom model or action distributions into RLModules +Inject your custom model or action distributions into Catalogs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ You can make a :py:class:`~ray.rllib.core.models.catalog.Catalog` build custom ``models`` by overriding the Catalog’s methods used by RLModules to build ``models``. @@ -154,28 +224,68 @@ Have a look at these lines from the constructor of the :py:class:`~ray.rllib.alg :start-after: __sphinx_doc_begin__ :end-before: __sphinx_doc_end__ +Note that what happens inside the constructor of PPOTorchRLModule is similar to the earlier example `Creating models and distributions for PPO `__. + Consequently, in order to build a custom :py:class:`~ray.rllib.core.models.Model` compatible with a PPORLModule, you can override methods by inheriting from :py:class:`~ray.rllib.algorithms.ppo.ppo_catalog.PPOCatalog` or write a :py:class:`~ray.rllib.core.models.catalog.Catalog` that implements them from scratch. -The following showcases such modifications. +The following examples showcase such modifications: -This example shows two modifications: -- How to write a custom :py:class:`~ray.rllib.models.distributions.Distribution` -- How to inject a custom action distribution into a :py:class:`~ray.rllib.core.models.catalog.Catalog` +.. tab-set:: -.. literalinclude:: ../../../rllib/examples/catalog/custom_action_distribution.py - :language: python - :start-after: __sphinx_doc_begin__ - :end-before: __sphinx_doc_end__ + .. tab-item:: Adding a custom Encoder + + This example shows two modifications: + + - How to write a custom :py:class:`~ray.rllib.core.models.base.Encoder` + - How to inject the custom Encoder into a :py:class:`~ray.rllib.core.models.catalog.Catalog` + + Note that, if you only want to inject your Encoder into a single :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule`, the recommended workflow is to inherit + from an existing RL Module and place the Encoder there. + + .. literalinclude:: ../../../rllib/examples/catalog/mobilenet_v2_encoder.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + + ..
tab-item:: Adding a custom action distribution + + This example shows two modifications: + + - How to write a custom :py:class:`~ray.rllib.models.distributions.Distribution` + - How to inject the custom action distribution into a :py:class:`~ray.rllib.core.models.catalog.Catalog` + + .. literalinclude:: ../../../rllib/examples/catalog/custom_action_distribution.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + +These examples target PPO but the workflows apply to all RLlib algorithms. +Note that PPO adds the :py:class:`~ray.rllib.core.models.base.ActorCriticEncoder` and two heads (policy- and value-head) to the base class. +You can override these similarly to the above. +Other algorithms may add different sub-components or override default ones. + +Write a Catalog from scratch +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +You only need this when you want to write a new Algorithm under RLlib. +Note that writing an Algorithm does not strictly require writing a new Catalog but you can use Catalogs as a tool to create +the fitting default sub-components, such as models or distributions. +The following are typical requirements and steps for writing a new Catalog: + +- Does the Algorithm need a special Encoder? Override :py:meth:`~ray.rllib.core.models.catalog.Catalog._get_encoder_config`. +- Does the Algorithm need an additional network? Write a method to build it. You can use RLlib's model configurations to build models from dimensions. +- Does the Algorithm need a custom distribution? Override :py:meth:`~ray.rllib.core.models.catalog.Catalog._get_dist_cls_from_action_space`. +- Does the Algorithm need a special tokenizer? Override :py:meth:`~ray.rllib.core.models.catalog.Catalog.get_tokenizer_config`. +- Does the Algorithm not need an Encoder at all? Override :py:meth:`~ray.rllib.core.models.catalog.Catalog._determine_components_hook`. + +The following example shows our implementation of a Catalog for PPO that follows the above steps: -Notable TODOs -------------- +.. dropdown:: **Catalog for PPORLModules** -- Add cross references to Model and Distribution API docs -- Add example that shows how to inject own model -- Add more instructions on how to write a catalog from scratch -- Add section "Extend RLlib’s selection of Models and Distributions with your own" -- Add section "Write a Catalog from scratch" \ No newline at end of file + .. literalinclude:: ../../../rllib/algorithms/ppo/ppo_catalog.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ \ No newline at end of file diff --git a/doc/source/rllib/rllib-connector.rst b/doc/source/rllib/rllib-connector.rst index 3717a499af4131..0c609f985e8a6f 100644 --- a/doc/source/rllib/rllib-connector.rst +++ b/doc/source/rllib/rllib-connector.rst @@ -2,7 +2,7 @@ .. include:: /_includes/rllib/we_are_hiring.rst -Connectors (Alpha) +Connectors (Beta) ================== Connector are components that handle transformations on inputs and outputs of a given RL policy, with the goal of improving diff --git a/doc/source/rllib/rllib-rlmodule.rst b/doc/source/rllib/rllib-rlmodule.rst index 91473cc98452ad..f1eec46db53edd 100644 --- a/doc/source/rllib/rllib-rlmodule.rst +++ b/doc/source/rllib/rllib-rlmodule.rst @@ -302,7 +302,6 @@ In :py:class:`~ray.rllib.core.rl_module.rl_module.RLModule` you can enforce the To learn more, see the `SpecType` documentation.
- Writing Custom Multi-Agent RL Modules (Advanced) ------------------------------------------------ @@ -313,14 +312,11 @@ The :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModule` offers The following example creates a custom multi-agent RL module with underlying modules. The modules share an encoder, which gets applied to the global part of the observations space. The local part passes through a separate encoder, specific to each policy. -.. tab-set:: - - .. tab-item:: Multi agent with shared encoder (Torch) - .. literalinclude:: doc_code/rlmodule_guide.py - :language: python - :start-after: __write-custom-marlmodule-shared-enc-begin__ - :end-before: __write-custom-marlmodule-shared-enc-end__ +.. literalinclude:: doc_code/rlmodule_guide.py + :language: python + :start-after: __write-custom-marlmodule-shared-enc-begin__ + :end-before: __write-custom-marlmodule-shared-enc-end__ To construct this custom multi-agent RL module, pass the class to the :py:class:`~ray.rllib.core.rl_module.marl_module.MultiAgentRLModuleSpec` constructor. Also, pass the :py:class:`~ray.rllib.core.rl_module.rl_module.SingleAgentRLModuleSpec` for each agent because RLlib requires the observation, action spaces, and model hyper-parameters for each agent. @@ -334,7 +330,11 @@ To construct this custom multi-agent RL module, pass the class to the :py:class: Extending Existing RLlib RL Modules ----------------------------------- -RLlib provides a number of RL Modules for different frameworks (e.g., PyTorch, TensorFlow, etc.). Extend these modules by inheriting from them and overriding the methods you need to customize. For example, extend :py:class:`~ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule` and augment it with your own customization. Then pass the new customized class into the algorithm configuration. +RLlib provides a number of RL Modules for different frameworks (e.g., PyTorch, TensorFlow, etc.). +To customize existing RLModules you can change the RLModule directly by inheriting the class and changing the +:py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.setup` or other methods. +For example, extend :py:class:`~ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule` and augment it with your own customization. +Then pass the new customized class into the appropriate :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig`. There are two possible ways to extend existing RL Modules: @@ -342,7 +342,10 @@ There are two possible ways to extend existing RL Modules: .. tab-item:: Inheriting existing RL Modules - One way to extend existing RL Modules is to inherit from them and override the methods you need to customize. For example, extend :py:class:`~ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule` and augment it with your own customization. Then pass the new customized class into the algorithm configuration to use the PPO algorithm to optimize your custom RL Module. + The default way to extend existing RL Modules is to inherit from them and override the methods you need to customize. + Then pass the new customized class into the :py:class:`~ray.rllib.algorithms.algorithm_config.AlgorithmConfig` to optimize your custom RL Module. + This is the preferred approach. With it, we can define our own models explicitly within a given RL Module + and don't need to interact with a Catalog, so you don't need to learn about Catalog. .. 
code-block:: python @@ -357,19 +360,39 @@ There are two possible ways to extend existing RL Modules: rl_module_spec=SingleAgentRLModuleSpec(module_class=MyPPORLModule) ) + A concrete example: If you want to replace the default encoder that RLlib builds for torch, PPO and a given observation space, + you can override :py:class:`~ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule`'s + :py:meth:`~ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module.PPOTorchRLModule.__init__` to create your custom + encoder instead of the default one. We do this in the following example. + + .. literalinclude:: ../../../rllib/examples/rl_module/mobilenet_rlm.py + :language: python + :start-after: __sphinx_doc_begin__ + :end-before: __sphinx_doc_end__ + .. tab-item:: Extending RL Module Catalog - Another way to customize your module is by extending its :py:class:`~ray.rllib.core.models.catalog.Catalog`. The :py:class:`~ray.rllib.core.models.catalog.Catalog` is a component that defines the default architecture and behavior of a model based on factors such as ``observation_space``, ``action_space``, etc. To modify sub-components of an existing RL Module, extend the corresponding Catalog class. + An advanced way to customize your module is by extending its :py:class:`~ray.rllib.core.models.catalog.Catalog`. + The Catalog is a component that defines the default models and other sub-components for RL Modules based on factors such as ``observation_space``, ``action_space``, etc. + For more information on the :py:class:`~ray.rllib.core.models.catalog.Catalog` class, refer to the `Catalog user guide `__. + By modifying the Catalog, you can alter what sub-components are being built for existing RL Modules. + This approach is useful mostly if you want your custom component to integrate with the decision trees that the Catalogs represent. + The following use cases are examples of what may require you to extend the Catalogs: - For instance, to adapt the existing ``PPORLModule`` for a custom graph observation space not supported by RLlib out-of-the-box, extend the :py:class:`~ray.rllib.core.models.catalog.Catalog` class used to create the ``PPORLModule`` and override the method responsible for returning the encoder component to ensure that your custom encoder replaces the default one initially provided by RLlib. For more information on the :py:class:`~ray.rllib.core.models.catalog.Catalog` class, refer to the `Catalog user guide `__. + - Choosing a custom model only for a certain observation space. + - Using a custom action distribution in multiple distinct Algorithms. + - Reusing your custom component in many distinct RL Modules. + For instance, to adapt existing ``PPORLModules`` for a custom graph observation space not supported by RLlib out-of-the-box, + extend the :py:class:`~ray.rllib.core.models.catalog.Catalog` class used to create the ``PPORLModule`` + and override the method responsible for returning the encoder component to ensure that your custom encoder replaces the default one initially provided by RLlib. .. 
code-block:: python class MyAwesomeCatalog(PPOCatalog): - def get_actor_critic_encoder_config(): + def build_actor_critic_encoder(): # create your awesome graph encoder here and return it pass @@ -380,6 +403,17 @@ There are two possible ways to extend existing RL Modules: ) +Checkpointing RL Modules +------------------------ + +RL Modules can be checkpointed with their two methods :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.save_to_checkpoint` and :py:meth:`~ray.rllib.core.rl_module.rl_module.RLModule.from_checkpoint`. +The following example shows how these methods can be used outside of, or in conjunction with, an RLlib Algorithm. + +.. literalinclude:: doc_code/rlmodule_guide.py + :language: python + :start-after: __checkpointing-begin__ + :end-before: __checkpointing-end__ + Migrating from Custom Policies and Models to RL Modules ------------------------------------------------------- @@ -541,12 +575,4 @@ See `Writing Custom Single Agent RL Modules`_ for more details on how to impleme ... def _forward_exploration(self, batch): - ... - - -Notable TODOs -------------- - -- [] Add support for RNNs. -- [] Checkpointing. -- [] End to end example for custom RL Modules extending PPORLModule (e.g. LLM) \ No newline at end of file + ... \ No newline at end of file diff --git a/rllib/BUILD b/rllib/BUILD index 8a31f265b32b8c..971c99031806ba 100644 --- a/rllib/BUILD +++ b/rllib/BUILD @@ -3143,6 +3143,32 @@ py_test( args = ["--run=PPO", "--as-test", "--framework=torch", "--stop-reward=28", "--num-cpus=4", "--use-prev-action", "--use-prev-reward"] ) +py_test( + name = "examples/catalog/custom_action_distribution", + main = "examples/catalog/custom_action_distribution.py", + tags = ["team:rllib", "examples", "no_main"], + size = "small", + srcs = ["examples/catalog/custom_action_distribution.py"], +) + + +py_test( + name = "examples/catalog/mobilenet_v2_encoder", + main = "examples/catalog/mobilenet_v2_encoder.py", + tags = ["team:rllib", "examples", "no_main"], + size = "small", + srcs = ["examples/catalog/mobilenet_v2_encoder.py"], +) + + +py_test( + name = "examples/rl_module/mobilenet_rlm", + main = "examples/rl_module/mobilenet_rlm.py", + tags = ["team:rllib", "examples", "no_main"], + size = "small", + srcs = ["examples/rl_module/mobilenet_rlm.py"], +) + py_test( name = "examples/centralized_critic_tf", main = "examples/centralized_critic.py", diff --git a/rllib/algorithms/ppo/ppo_catalog.py b/rllib/algorithms/ppo/ppo_catalog.py index 2a6df6710324cd..474491e3d4a30b 100644 --- a/rllib/algorithms/ppo/ppo_catalog.py +++ b/rllib/algorithms/ppo/ppo_catalog.py @@ -1,3 +1,4 @@ +# __sphinx_doc_begin__ import gymnasium as gym from ray.rllib.core.models.catalog import Catalog @@ -8,6 +9,7 @@ ) from ray.rllib.core.models.base import Encoder, ActorCriticEncoder, Model from ray.rllib.utils import override +from ray.rllib.utils.annotations import OverrideToImplementCustomLogic def _check_if_diag_gaussian(action_distribution_cls, framework): @@ -72,12 +74,14 @@ def __init__( # Replace EncoderConfig by ActorCriticEncoderConfig self.actor_critic_encoder_config = ActorCriticEncoderConfig( - base_encoder_config=self.encoder_config, - shared=self.model_config_dict["vf_share_layers"], + base_encoder_config=self._encoder_config, + shared=self._model_config_dict["vf_share_layers"], ) - self.pi_and_vf_head_hiddens = self.model_config_dict["post_fcnet_hiddens"] - self.pi_and_vf_head_activation = self.model_config_dict["post_fcnet_activation"] + self.pi_and_vf_head_hiddens = 
self._model_config_dict["post_fcnet_hiddens"] + self.pi_and_vf_head_activation = self._model_config_dict[ + "post_fcnet_activation" + ] # We don't have the exact (framework specific) action dist class yet and thus # cannot determine the exact number of output nodes (action space) required. @@ -92,6 +96,7 @@ def __init__( output_layer_dim=1, ) + @OverrideToImplementCustomLogic def build_actor_critic_encoder(self, framework: str) -> ActorCriticEncoder: """Builds the ActorCriticEncoder. @@ -114,9 +119,10 @@ def build_encoder(self, framework: str) -> Encoder: Since PPO uses an ActorCriticEncoder, this method should not be implemented. """ raise NotImplementedError( - "Use PPOCatalog.build_actor_critic_encoder() instead." + "Use PPOCatalog.build_actor_critic_encoder() instead for PPO." ) + @OverrideToImplementCustomLogic def build_pi_head(self, framework: str) -> Model: """Builds the policy head. @@ -132,18 +138,18 @@ def build_pi_head(self, framework: str) -> Model: """ # Get action_distribution_cls to find out about the output dimension for pi_head action_distribution_cls = self.get_action_dist_cls(framework=framework) - if self.model_config_dict["free_log_std"]: + if self._model_config_dict["free_log_std"]: _check_if_diag_gaussian( action_distribution_cls=action_distribution_cls, framework=framework ) required_output_dim = action_distribution_cls.required_input_dim( - space=self.action_space, model_config=self.model_config_dict + space=self.action_space, model_config=self._model_config_dict ) # Now that we have the action dist class and number of outputs, we can define # our pi-config and build the pi head. pi_head_config_class = ( FreeLogStdMLPHeadConfig - if self.model_config_dict["free_log_std"] + if self._model_config_dict["free_log_std"] else MLPHeadConfig ) self.pi_head_config = pi_head_config_class( @@ -156,6 +162,7 @@ def build_pi_head(self, framework: str) -> Model: return self.pi_head_config.build(framework=framework) + @OverrideToImplementCustomLogic def build_vf_head(self, framework: str) -> Model: """Builds the value function head. @@ -170,3 +177,6 @@ def build_vf_head(self, framework: str) -> Model: The value function head. """ return self.vf_head_config.build(framework=framework) + + +# __sphinx_doc_end__ diff --git a/rllib/core/models/catalog.py b/rllib/core/models/catalog.py index 48089700267d95..b956343babae75 100644 --- a/rllib/core/models/catalog.py +++ b/rllib/core/models/catalog.py @@ -24,16 +24,25 @@ from ray.rllib.utils.spaces.space_utils import flatten_space from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space from ray.rllib.utils.typing import ViewRequirementsDict +from ray.rllib.utils.annotations import ( + OverrideToImplementCustomLogic, + OverrideToImplementCustomLogic_CallToSuperRecommended, +) class Catalog: - """Describes the sub-modules architectures to be used in RLModules. + """Describes the sub-module-architectures to be used in RLModules. RLlib's native RLModules get their Models from a Catalog object. By default, that Catalog builds the configs it has as attributes. - You can modify a Catalog so that it builds different Models by subclassing and - overriding the build_* methods. Alternatively, you can customize the configs - inside RLlib's Catalogs to customize what is being built by RLlib. + This component was build to be hackable and extensible. You can inject custom + components into RL Modules by overriding the `build_xxx` methods of this class. + Note that it is recommended to write a custom RL Module for a single use-case. 
+ Modifications to Catalogs mostly make sense if you want to reuse the same + Catalog for different RL Modules. For example if you have written a custom + encoder and want to inject it into different RL Modules (e.g. for PPO, DQN, etc.). + You can influence the decision tree that determines the sub-components by modifying + `Catalog._determine_components_hook`. Usage example: @@ -77,8 +86,6 @@ def __init__( self, observation_space: gym.Space, action_space: gym.Space, - # TODO (Artur): Turn model_config into model_config_dict to distinguish - # between ModelConfig and a model_config_dict dict library-wide. model_config_dict: dict, view_requirements: dict = None, ): @@ -97,62 +104,78 @@ def __init__( self.action_space = action_space # TODO (Artur): Make model defaults a dataclass - self.model_config_dict = {**MODEL_DEFAULTS, **model_config_dict} - self.view_requirements = view_requirements - + self._model_config_dict = {**MODEL_DEFAULTS, **model_config_dict} + self._view_requirements = view_requirements self._latent_dims = None - # Overwrite this post-init hook in subclasses - self.__post_init__() - - @property - def latent_dims(self): - """Returns the latent dimensions of the encoder. - - This establishes an agreement between encoder and heads about the latent - dimensions. Encoders can be built to output a latent tensor with - `latent_dims` dimensions, and heads can be built with tensors of - `latent_dims` dimensions as inputs. This can be safely ignored if this - agreement is not needed in case of modifications to the Catalog. + self._determine_components_hook() - Returns: - The latent dimensions of the encoder. - """ - return self._latent_dims + @OverrideToImplementCustomLogic_CallToSuperRecommended + def _determine_components_hook(self): + """Decision tree hook for subclasses to override. - @latent_dims.setter - def latent_dims(self, value): - self._latent_dims = value + By default, this method executes the decision tree that determines the + components that a Catalog builds. You can extend the components by overriding + this or by adding to the constructor of your subclass. - def __post_init__(self): - """Post-init hook for subclasses to override. + Override this method if you don't want to use the default components + determined here. If you want to use them but add additional components, you + should call `super()._determine_components()` at the beginning of your + implementation. This makes it so that subclasses are not forced to create an encoder config if the rest of their catalog is not dependent on it or if it breaks. - At the end of Catalog initialization, an attribute `Catalog.latent_dims` + At the end of this method, an attribute `Catalog.latent_dims` should be set so that heads can be built using that information. """ - self.encoder_config = self.get_encoder_config( + self._encoder_config = self._get_encoder_config( observation_space=self.observation_space, action_space=self.action_space, - model_config_dict=self.model_config_dict, - view_requirements=self.view_requirements, + model_config_dict=self._model_config_dict, + view_requirements=self._view_requirements, ) # Create a function that can be called when framework is known to retrieve the # class type for action distributions self._action_dist_class_fn = functools.partial( - self.get_dist_cls_from_action_space, action_space=self.action_space + self._get_dist_cls_from_action_space, action_space=self.action_space ) # The dimensions of the latent vector that is output by the encoder and fed # to the heads. 
- self.latent_dims = self.encoder_config.output_dims + self.latent_dims = self._encoder_config.output_dims + @property + def latent_dims(self): + """Returns the latent dimensions of the encoder. + + This establishes an agreement between encoder and heads about the latent + dimensions. Encoders can be built to output a latent tensor with + `latent_dims` dimensions, and heads can be built with tensors of + `latent_dims` dimensions as inputs. This can be safely ignored if this + agreement is not needed in case of modifications to the Catalog. + + Returns: + The latent dimensions of the encoder. + """ + return self._latent_dims + + @latent_dims.setter + def latent_dims(self, value): + self._latent_dims = value + + @OverrideToImplementCustomLogic def build_encoder(self, framework: str) -> Encoder: """Builds the encoder. - By default this method builds an encoder instance from Catalog.encoder_config. + By default, this method builds an encoder instance from Catalog._encoder_config. + + You should override this if you want to use RLlib's default RL Modules but + only want to change the encoder. For example, if you want to use a custom + encoder, but want to use RLlib's default heads, action distribution and how + tensors are routed between them. If you want to have full control over the + RL Module, we recommend writing your own RL Module by inheriting from one of + RLlib's RL Modules instead. Args: framework: The framework to use. Either "torch" or "tf2". @@ -160,20 +183,24 @@ def build_encoder(self, framework: str) -> Encoder: Returns: The encoder. """ - assert hasattr(self, "encoder_config"), ( - "You must define a `Catalog.encoder_config` attribute in your Catalog " + assert hasattr(self, "_encoder_config"), ( + "You must define a `Catalog._encoder_config` attribute in your Catalog " "subclass or override the `Catalog.build_encoder` method. By default, " "an encoder_config is created in the __post_init__ method." ) - return self.encoder_config.build(framework=framework) + return self._encoder_config.build(framework=framework) + @OverrideToImplementCustomLogic def get_action_dist_cls(self, framework: str): """Get the action distribution class. The default behavior is to get the action distribution from the - `Catalog.action_dist_class_fn`. This can be overridden to build a custom action - distribution as a means of configuring the behavior of a RLModule - implementation. + `Catalog._action_dist_class_fn`. + + You should override this to have RLlib build your custom action + distribution instead of the default one. For example, if you don't want to + use RLlib's default RLModules with their default models, but only want to + change the distribution that Catalog returns. Args: framework: The framework to use. Either "torch" or "tf2". @@ -190,7 +217,7 @@ def get_action_dist_cls(self, framework: str): return self._action_dist_class_fn(framework=framework) @classmethod - def get_encoder_config( + def _get_encoder_config( cls, observation_space: gym.Space, model_config_dict: dict, @@ -312,6 +339,7 @@ def get_encoder_config( return encoder_config @classmethod + @OverrideToImplementCustomLogic def get_tokenizer_config( cls, observation_space: gym.Space, @@ -324,13 +352,19 @@ def get_tokenizer_config( inputs. By default, RLlib uses the models supported by Catalog out of the box to tokenize. + You should override this method if you want to change the custom tokenizer + inside current encoders that Catalog returns without providing the recurrent + network as a whole. 
For example, if you want to define some custom CNN layers + as a tokenizer for a recurrent encoder that already includes the recurrent + layers and handles the state. + Args: observation_space: The observation space to use. model_config_dict: The model config to use. view_requirements: The view requirements to use if anything else than observation_space is to be encoded. This signifies an advanced use case. """ - return cls.get_encoder_config( + return cls._get_encoder_config( observation_space=observation_space, # Use model_config_dict without flags that would end up in complex models model_config_dict={ @@ -341,7 +375,7 @@ def get_tokenizer_config( ) @classmethod - def get_dist_cls_from_action_space( + def _get_dist_cls_from_action_space( cls, action_space: gym.Space, *, @@ -534,7 +568,7 @@ def _multi_action_dist_partial_helper( action_space_struct = get_base_struct_from_space(action_space) flat_action_space = flatten_space(action_space) child_distribution_cls_struct = tree.map_structure( - lambda s: catalog_cls.get_dist_cls_from_action_space( + lambda s: catalog_cls._get_dist_cls_from_action_space( action_space=s, framework=framework, ), diff --git a/rllib/core/models/tests/test_catalog.py b/rllib/core/models/tests/test_catalog.py index 3d38d7d5e064f9..60959bd5118daa 100644 --- a/rllib/core/models/tests/test_catalog.py +++ b/rllib/core/models/tests/test_catalog.py @@ -199,7 +199,7 @@ def test_get_encoder_config(self): view_requirements=None, ) - model_config = catalog.get_encoder_config( + model_config = catalog._get_encoder_config( observation_space=input_space, model_config_dict=model_config_dict ) self.assertEqual(type(model_config), model_config_type) @@ -328,7 +328,7 @@ def test_get_dist_cls_from_action_space(self): if framework == "tf2": framework = "tf2" - dist_cls = catalog.get_dist_cls_from_action_space( + dist_cls = catalog._get_dist_cls_from_action_space( action_space=action_space, framework=framework, ) @@ -450,9 +450,9 @@ def _forward(self, input_dict, **kwargs): } class MyCustomCatalog(PPOCatalog): - def __post_init__(self): + def _determine_components(self): self._action_dist_class_fn = functools.partial( - self.get_dist_cls_from_action_space, action_space=self.action_space + self._get_dist_cls_from_action_space, action_space=self.action_space ) self.latent_dims = (10,) self.encoder_config = MyCostumTorchEncoderConfig( diff --git a/rllib/examples/catalog/custom_action_distribution.py b/rllib/examples/catalog/custom_action_distribution.py index 42263f086a21ec..979bee581bd98c 100644 --- a/rllib/examples/catalog/custom_action_distribution.py +++ b/rllib/examples/catalog/custom_action_distribution.py @@ -19,23 +19,23 @@ class CustomTorchCategorical(Distribution): def __init__(self, logits): self.torch_dist = torch.distributions.categorical.Categorical(logits=logits) - def sample(self, sample_shape=torch.Size()): + def sample(self, sample_shape=torch.Size(), **kwargs): return self.torch_dist.sample(sample_shape) - def rsample(self, sample_shape=torch.Size()): + def rsample(self, sample_shape=torch.Size(), **kwargs): return self._dist.rsample(sample_shape) - def logp(self, value): + def logp(self, value, **kwargs): return self.torch_dist.log_prob(value) def entropy(self): return self.torch_dist.entropy() - def kl(self, other): + def kl(self, other, **kwargs): return torch.distributions.kl.kl_divergence(self.torch_dist, other.torch_dist) @staticmethod - def required_input_dim(space): + def required_input_dim(space, **kwargs): return int(space.n) @classmethod diff --git 
a/rllib/examples/catalog/mobilenet_v2_encoder.py b/rllib/examples/catalog/mobilenet_v2_encoder.py new file mode 100644 index 00000000000000..3d22fe92f06074 --- /dev/null +++ b/rllib/examples/catalog/mobilenet_v2_encoder.py @@ -0,0 +1,80 @@ +""" +This example shows two modifications: +- How to write a custom Encoder (using MobileNet v2) +- How to enhance Catalogs with this custom Encoder + +With the pattern shown in this example, we can enhance Catalogs such that they extend +to new observation- or action spaces while retaining their original functionality. +""" +# __sphinx_doc_begin__ +import gymnasium as gym +import numpy as np + +from ray.rllib.algorithms.ppo.ppo import PPOConfig +from ray.rllib.algorithms.ppo.ppo_catalog import PPOCatalog +from ray.rllib.examples.models.mobilenet_v2_encoder import ( + MobileNetV2EncoderConfig, + MOBILENET_INPUT_SHAPE, +) +from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.examples.env.random_env import RandomEnv + + +# Define a PPO Catalog that we can use to inject our MobileNetV2 Encoder into RLlib's +# decision tree of what model to choose +class MobileNetEnhancedPPOCatalog(PPOCatalog): + @classmethod + def _get_encoder_config( + cls, + observation_space: gym.Space, + **kwargs, + ): + if ( + isinstance(observation_space, gym.spaces.Box) + and observation_space.shape == MOBILENET_INPUT_SHAPE + ): + # Inject our custom encoder here, only if the observation space fits it + return MobileNetV2EncoderConfig() + else: + return super()._get_encoder_config(observation_space, **kwargs) + + +# Create a generic config with our enhanced Catalog +ppo_config = ( + PPOConfig() + .rl_module( + rl_module_spec=SingleAgentRLModuleSpec( + catalog_class=MobileNetEnhancedPPOCatalog + ) + ) + .rollouts(num_rollout_workers=0) + # The following training settings make it so that a training iteration is very + # quick. This is just for the sake of this example. PPO will not learn properly + # with these settings! + .training(train_batch_size=32, sgd_minibatch_size=16, num_sgd_iter=1) +) + +# CartPole's observation space is not compatible with our MobileNetV2 Encoder, so +# this will use the default behaviour of Catalogs +ppo_config.environment("CartPole-v1") +results = ppo_config.build().train() +print(results) + +# For this training, we use a RandomEnv with observations of shape +# MOBILENET_INPUT_SHAPE. This will use our custom Encoder. +ppo_config.environment( + RandomEnv, + env_config={ + "action_space": gym.spaces.Discrete(2), + # Test a simple Image observation space. + "observation_space": gym.spaces.Box( + 0.0, + 1.0, + shape=MOBILENET_INPUT_SHAPE, + dtype=np.float32, + ), + }, +) +results = ppo_config.build().train() +print(results) +# __sphinx_doc_end__ diff --git a/rllib/examples/models/mobilenet_v2_encoder.py b/rllib/examples/models/mobilenet_v2_encoder.py new file mode 100644 index 00000000000000..6a3482f547b0f8 --- /dev/null +++ b/rllib/examples/models/mobilenet_v2_encoder.py @@ -0,0 +1,47 @@ +""" +This file implements a MobileNet v2 Encoder. +It uses MobileNet v2 to encode images into a latent space of 1000 dimensions. + +Depending on the experiment, the MobileNet v2 encoder layers can be frozen or +unfrozen. This is controlled by the `freeze` parameter in the config. + +This is an example of how a pre-trained neural network can be used as an encoder +in RLlib. You can modify this example to accommodate your own encoder network or +other pre-trained networks. 
+""" + +from ray.rllib.core.models.base import Encoder, ENCODER_OUT +from ray.rllib.core.models.configs import ModelConfig +from ray.rllib.core.models.torch.base import TorchModel +from ray.rllib.utils.framework import try_import_torch + +torch, nn = try_import_torch() + +MOBILENET_INPUT_SHAPE = (3, 224, 224) + + +class MobileNetV2EncoderConfig(ModelConfig): + # MobileNet v2 has a flat output with a length of 1000. + output_dims = (1000,) + freeze = True + + def build(self, framework): + assert framework == "torch", "Unsupported framework `{}`!".format(framework) + return MobileNetV2Encoder(self) + + +class MobileNetV2Encoder(TorchModel, Encoder): + """A MobileNet v2 encoder for RLlib.""" + + def __init__(self, config): + super().__init__(config) + self.net = torch.hub.load( + "pytorch/vision:v0.6.0", "mobilenet_v2", pretrained=True + ) + if config.freeze: + # We don't want to train this encoder, so freeze its parameters! + for p in self.net.parameters(): + p.requires_grad = False + + def _forward(self, input_dict, **kwargs): + return {ENCODER_OUT: (self.net(input_dict["obs"]))} diff --git a/rllib/examples/rl_module/mobilenet_rlm.py b/rllib/examples/rl_module/mobilenet_rlm.py new file mode 100644 index 00000000000000..906b655024a815 --- /dev/null +++ b/rllib/examples/rl_module/mobilenet_rlm.py @@ -0,0 +1,82 @@ +""" +This example shows how to take full control over what models and action distribution +are being built inside an RL Module. With this pattern, we can bypass a Catalog and +explicitly define our own models within a given RL Module. +""" +# __sphinx_doc_begin__ +import gymnasium as gym +import numpy as np + +from ray.rllib.algorithms.ppo.ppo import PPOConfig +from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule +from ray.rllib.core.models.configs import MLPHeadConfig +from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec +from ray.rllib.examples.env.random_env import RandomEnv +from ray.rllib.models.torch.torch_distributions import TorchCategorical +from ray.rllib.examples.models.mobilenet_v2_encoder import ( + MobileNetV2EncoderConfig, + MOBILENET_INPUT_SHAPE, +) +from ray.rllib.core.models.configs import ActorCriticEncoderConfig + + +class MobileNetTorchPPORLModule(PPOTorchRLModule): + """A PPORLModules with mobilenet v2 as an encoder. + + The idea behind this model is to demonstrate how we can bypass catalog to + take full control over what models and action distribution are being built. + In this example, we do this to modify an existing RLModule with a custom encoder. + """ + + def setup(self): + mobilenet_v2_config = MobileNetV2EncoderConfig() + # Since we want to use PPO, which is an actor-critic algorithm, we need to + # use an ActorCriticEncoderConfig to wrap the base encoder config. 
+ actor_critic_encoder_config = ActorCriticEncoderConfig( + base_encoder_config=mobilenet_v2_config + ) + + self.encoder = actor_critic_encoder_config.build(framework="torch") + mobilenet_v2_output_dims = mobilenet_v2_config.output_dims + + pi_config = MLPHeadConfig( + input_dims=mobilenet_v2_output_dims, + output_layer_dim=2, + ) + + vf_config = MLPHeadConfig( + input_dims=mobilenet_v2_output_dims, output_layer_dim=1 + ) + + self.pi = pi_config.build(framework="torch") + self.vf = vf_config.build(framework="torch") + + self.action_dist_cls = TorchCategorical + + +config = ( + PPOConfig() + .rl_module( + rl_module_spec=SingleAgentRLModuleSpec(module_class=MobileNetTorchPPORLModule) + ) + .environment( + RandomEnv, + env_config={ + "action_space": gym.spaces.Discrete(2), + # Test a simple Image observation space. + "observation_space": gym.spaces.Box( + 0.0, + 1.0, + shape=MOBILENET_INPUT_SHAPE, + dtype=np.float32, + ), + }, + ) + # The following training settings make it so that a training iteration is very + # quick. This is just for the sake of this example. PPO will not learn properly + # with these settings! + .training(train_batch_size=32, sgd_minibatch_size=16, num_sgd_iter=1) +) + +config.build().train() +# __sphinx_doc_end__
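One usage note on the checkpointing section above: the guide names ``RLModule.from_checkpoint`` as the counterpart to ``save_to_checkpoint``, but the doc code only exercises the ``load_state_path`` route. The following is a minimal, illustrative sketch of restoring the saved CartPole module directly and querying it once; it assumes the checkpoint directory from the checkpointing example still exists, and the batch format simply mirrors the Catalog guide's doc code rather than being part of this change set.

.. code-block:: python

    import gymnasium as gym
    import torch

    from ray.rllib.algorithms.ppo.torch.ppo_torch_rl_module import PPOTorchRLModule
    from ray.rllib.policy.sample_batch import SampleBatch

    # Restore the module written by `save_to_checkpoint()` in the checkpointing example.
    loaded_module = PPOTorchRLModule.from_checkpoint(module_ckpt_path)

    # Query it for one inference step on a fresh CartPole observation.
    env = gym.make("CartPole-v1")
    obs, info = env.reset()
    fwd_out = loaded_module.forward_inference({SampleBatch.OBS: torch.Tensor([obs])})
    # For CartPole's discrete action space, this carries the action distribution inputs (logits).
    print(fwd_out)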