[RLlib contrib] TD3. (ray-project#36726)
avnishn authored and Andrew Xue committed Oct 10, 2023
1 parent 9c5a1d2 commit b737bcb
Showing 8 changed files with 341 additions and 3 deletions.
21 changes: 18 additions & 3 deletions .buildkite/pipeline.ml.yml
@@ -490,6 +490,9 @@
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/alpha_star && pip install -r requirements.txt && pip install -e .)
- ./ci/env/env_info.sh
- pytest rllib_contrib/alpha_star/tests/
@@ -523,14 +526,14 @@
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT

# Install mujoco necessary for the testing environments
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- sudo apt install libosmesa6-dev libgl1-mesa-glx libglfw3 patchelf -y
- [ ! -d "/root/.mujoco" ] && mkdir -p /root/.mujoco && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz \
&& mv mujoco210-linux-x86_64.tar.gz /root/.mujoco/. && \
- mkdir -p /root/.mujoco && \
wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz && \
mv mujoco210-linux-x86_64.tar.gz /root/.mujoco/. && \
(cd /root/.mujoco && tar -xf /root/.mujoco/mujoco210-linux-x86_64.tar.gz)
- export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mujoco210/bin
- (cd rllib_contrib/maml && pip install -r requirements.txt && pip install -e ".[development]")
@@ -548,3 +551,15 @@
- ./ci/env/env_info.sh
- pytest rllib_contrib/r2d2/tests/
- python rllib_contrib/r2d2/examples/r2d2_stateless_cartpole.py --run-as-test

- label: ":exploding_death_star: RLlib Contrib: TD3 Tests"
conditions: ["NO_WHEELS_REQUIRED", "RAY_CI_RLLIB_CONTRIB_AFFECTED"]
commands:
- cleanup() { if [ "${BUILDKITE_PULL_REQUEST}" = "false" ]; then ./ci/build/upload_build_info.sh; fi }; trap cleanup EXIT
- conda deactivate
- conda create -n rllib_contrib python=3.8 -y
- conda activate rllib_contrib
- (cd rllib_contrib/td3 && pip install -r requirements.txt && pip install -e .)
- ./ci/env/env_info.sh
- pytest rllib_contrib/td3/tests/
- python rllib_contrib/td3/examples/td3_pendulum_v1.py --run-as-test
25 changes: 25 additions & 0 deletions rllib_contrib/td3/README.md
@@ -0,0 +1,25 @@
# TD3 (Twin Delayed DDPG)

[TD3](https://arxiv.org/pdf/1802.09477) is a successor to DDPG. While DDPG can achieve great performance sometimes, it is frequently brittle with respect to hyperparameters and other kinds of tuning. A common failure mode for DDPG is that the learned Q-function begins to dramatically overestimate Q-values, which then leads to the policy breaking, because it exploits the errors in the Q-function. Twin Delayed DDPG (TD3) is an algorithm that addresses this issue by introducing three critical tricks:

Trick One: Clipped Double-Q Learning. TD3 learns two Q-functions instead of one (hence “twin”), and uses the smaller of the two Q-values to form the targets in the Bellman error loss functions.

Trick Two: “Delayed” Policy Updates. TD3 updates the policy (and target networks) less frequently than the Q-function. The paper recommends one policy update for every two Q-function updates.

Trick Three: Target Policy Smoothing. TD3 adds noise to the target action, to make it harder for the policy to exploit Q-function errors by smoothing out Q along changes in action.

Together, these three tricks result in substantially improved performance over baseline DDPG.
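For intuition, the snippet below is a minimal NumPy sketch of how tricks one and three combine when the critic target is computed; the function and array names (`td3_critic_target`, `q1_target`, ...) are illustrative stand-ins, not RLlib APIs. Trick two lives in the training loop instead: the actor and the target networks are updated only once every `policy_delay` critic updates.

```python
import numpy as np

rng = np.random.default_rng(0)

def td3_critic_target(r, next_obs, done, q1_target, q2_target, pi_target,
                      gamma=0.99, noise_std=0.2, noise_clip=0.5,
                      act_low=-2.0, act_high=2.0):
    """Illustrative TD3 target: y = r + gamma * (1 - done) * min(Q1', Q2')(s', a~)."""
    # Trick three: smooth the target action with clipped Gaussian noise.
    a_next = pi_target(next_obs)
    noise = np.clip(noise_std * rng.standard_normal(a_next.shape),
                    -noise_clip, noise_clip)
    a_next = np.clip(a_next + noise, act_low, act_high)
    # Trick one: take the element-wise minimum of the two target critics.
    q_min = np.minimum(q1_target(next_obs, a_next), q2_target(next_obs, a_next))
    return r + gamma * (1.0 - done) * q_min

# Toy usage with stand-in "networks" (simple callables).
s_next = rng.standard_normal((4, 3))
targets = td3_critic_target(
    r=rng.standard_normal(4),
    next_obs=s_next,
    done=np.zeros(4),
    q1_target=lambda s, a: s.sum(axis=1) + a.sum(axis=1),
    q2_target=lambda s, a: s.sum(axis=1) + 0.9 * a.sum(axis=1),
    pi_target=lambda s: np.tanh(s[:, :1]),
)
print(targets)  # one scalar target per transition in the toy batch
```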


## Installation

```bash
conda create -n rllib-td3 python=3.10
conda activate rllib-td3
pip install -r requirements.txt
pip install -e '.[development]'
```

## Usage

[TD3 Example](examples/td3_pendulum_v1.py)
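For a quick interactive check, here is a rough sketch (assuming the package is installed as above) of building and training the algorithm directly, without Tune:

```python
from rllib_td3.td3 import TD3Config

# Minimal config: torch framework, Pendulum-v1, otherwise TD3 defaults.
config = TD3Config().framework("torch").environment("Pendulum-v1")
algo = config.build()
results = algo.train()  # one training iteration
print(results["episode_reward_mean"])
algo.stop()
```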
53 changes: 53 additions & 0 deletions rllib_contrib/td3/examples/td3_pendulum_v1.py
@@ -0,0 +1,53 @@
import argparse

from rllib_td3.td3 import TD3, TD3Config

import ray
from ray import air, tune
from ray.rllib.utils.test_utils import check_learning_achieved


def get_cli_args():
"""Create CLI parser and return parsed arguments"""
parser = argparse.ArgumentParser()
parser.add_argument("--run-as-test", action="store_true", default=False)
args = parser.parse_args()
print(f"Running with following CLI args: {args}")
return args


if __name__ == "__main__":
args = get_cli_args()

ray.init()

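# Configure TD3 with small (64, 64) actor/critic networks, a plain
# (non-prioritized) replay buffer, 5k sampled steps before learning starts,
# and purely random actions for the first 5k exploration timesteps.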
config = (
TD3Config()
.framework("torch")
.environment("Pendulum-v1")
.training(
actor_hiddens=[64, 64],
critic_hiddens=[64, 64],
replay_buffer_config={"type": "MultiAgentReplayBuffer"},
num_steps_sampled_before_learning_starts=5000,
)
.exploration(exploration_config={"random_timesteps": 5000})
)

stop_reward = -900

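# Run with Tune: stop once the mean episode reward reaches -900 or after
# 100k total timesteps; fail fast (raise) on any trial error.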
tuner = tune.Tuner(
TD3,
param_space=config.to_dict(),
run_config=air.RunConfig(
stop={
"sampler_results/episode_reward_mean": stop_reward,
"timesteps_total": 100000,
},
failure_config=air.FailureConfig(fail_fast="raise"),
),
)
results = tuner.fit()

if args.run_as_test:
check_learning_achieved(results, stop_reward)
18 changes: 18 additions & 0 deletions rllib_contrib/td3/pyproject.toml
@@ -0,0 +1,18 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[project]
name = "rllib-td3"
authors = [{name = "Anyscale Inc."}]
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
dependencies = ["gymnasium==0.26.3", "ray[rllib]==2.5.1"]

[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"]
2 changes: 2 additions & 0 deletions rllib_contrib/td3/requirements.txt
@@ -0,0 +1,2 @@
tensorflow==2.11.0
torch==1.12.0
7 changes: 7 additions & 0 deletions rllib_contrib/td3/src/rllib_td3/td3/__init__.py
@@ -0,0 +1,7 @@
from rllib_td3.td3.td3 import TD3, TD3Config

from ray.tune.registry import register_trainable

__all__ = ["TD3Config", "TD3"]

register_trainable("rllib-contrib-td3", TD3)
108 changes: 108 additions & 0 deletions rllib_contrib/td3/src/rllib_td3/td3/td3.py
@@ -0,0 +1,108 @@
"""A more stable successor to TD3.
By default, this uses a near-identical configuration to that reported in the
TD3 paper.
"""
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
from ray.rllib.algorithms.ddpg.ddpg import DDPG, DDPGConfig
from ray.rllib.utils.annotations import override
from ray.rllib.utils.deprecation import DEPRECATED_VALUE


class TD3Config(DDPGConfig):
"""Defines a configuration class from which a TD3 Algorithm can be built.
Example:
>>> from rllib_td3.td3 import TD3Config
>>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
>>> print(config.to_dict()) # doctest: +SKIP
>>> # Build an Algorithm object from the config and run one training iteration.
>>> algo = config.build(env="Pendulum-v1") # doctest: +SKIP
>>> algo.train() # doctest: +SKIP
Example:
>>> from rllib_td3.td3 import TD3Config
>>> from ray import air
>>> from ray import tune
>>> config = TD3Config()
>>> # Print out some default values.
>>> print(config.lr) # doctest: +SKIP
>>> # Update the config object.
>>> config = config.training(lr=tune.grid_search( # doctest: +SKIP
... [0.001, 0.0001])) # doctest: +SKIP
>>> # Set the config object's env.
>>> config.environment(env="Pendulum-v1") # doctest: +SKIP
>>> # Use to_dict() to get the old-style python config dict
>>> # when running with tune.
>>> tune.Tuner( # doctest: +SKIP
... "TD3",
... run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
... param_space=config.to_dict(),
... ).fit()
"""

def __init__(self, algo_class=None):
"""Initializes a TD3Config instance."""
super().__init__(algo_class=algo_class or TD3)

# fmt: off
# __sphinx_doc_begin__

# Override some of DDPG/SimpleQ/Algorithm's default values with TD3-specific
# values.

# .training()

# largest changes: twin Q functions, delayed policy updates, target
# smoothing, no l2-regularization.
self.twin_q = True
self.policy_delay = 2
self.smooth_target_policy = True
self.l2_reg = 0.0
# Different tau (affecting target network update).
self.tau = 5e-3
# Different batch size.
self.train_batch_size = 100
# No prioritized replay by default (we may want to change this at some
# point).
self.replay_buffer_config = {
"type": "MultiAgentReplayBuffer",
# Specify prioritized replay by supplying a buffer type that supports
# prioritization, for example: MultiAgentPrioritizedReplayBuffer.
"prioritized_replay": DEPRECATED_VALUE,
"capacity": 1000000,
"worker_side_prioritization": False,
}
# Number of timesteps to collect from rollout workers before we start
# sampling from replay buffers for learning. Whether we count this in agent
# steps or environment steps depends on config.multi_agent(count_steps_by=..).
self.num_steps_sampled_before_learning_starts = 10000

# .exploration()
# TD3 uses Gaussian Noise by default.
self.exploration_config = {
# TD3 uses simple Gaussian noise on top of deterministic NN-output
# actions (after a possible pure random phase of n timesteps).
"type": "GaussianNoise",
# For how many timesteps should we return completely random
# actions, before we start adding (scaled) noise?
"random_timesteps": 10000,
# Gaussian stddev of action noise for exploration.
"stddev": 0.1,
# Scaling settings by which the Gaussian noise is scaled before
# being added to the actions. NOTE: The scale timesteps start only
# after(!) any random steps have been finished.
# By default, do not anneal over time (fixed 1.0).
"initial_scale": 1.0,
"final_scale": 1.0,
"scale_timesteps": 1,
}
# __sphinx_doc_end__
# fmt: on


class TD3(DDPG):
@classmethod
@override(DDPG)
def get_default_config(cls) -> AlgorithmConfig:
return TD3Config()
110 changes: 110 additions & 0 deletions rllib_contrib/td3/tests/test_td3.py
@@ -0,0 +1,110 @@
import unittest

import numpy as np
import rllib_td3.td3.td3 as td3

import ray
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.test_utils import (
check,
check_compute_single_action,
check_train_results,
framework_iterator,
)

tf1, tf, tfv = try_import_tf()


class TestTD3(unittest.TestCase):
@classmethod
def setUpClass(cls) -> None:
ray.init()

@classmethod
def tearDownClass(cls) -> None:
ray.shutdown()

def test_td3_compilation(self):
"""Test whether TD3 can be built with both frameworks."""
config = td3.TD3Config()

# Test against all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True):
algo = config.build(env="Pendulum-v1")
num_iterations = 1
for i in range(num_iterations):
results = algo.train()
check_train_results(results)
print(results)
check_compute_single_action(algo)
algo.stop()

def test_td3_exploration_and_with_random_prerun(self):
"""Tests TD3's Exploration (w/ random actions for n timesteps)."""
config = td3.TD3Config().environment(env="Pendulum-v1")
no_random_init = config.exploration_config.copy()
random_init = {
# Act randomly at beginning ...
"random_timesteps": 30,
# Then act very closely to deterministic actions thereafter.
"stddev": 0.001,
"initial_scale": 0.001,
"final_scale": 0.001,
}
obs = np.array([0.0, 0.1, -0.1])

# Test against all frameworks.
for _ in framework_iterator(config, with_eager_tracing=True):
config.exploration(exploration_config=no_random_init)
# Default GaussianNoise setup.
algo = config.build()
# Setting explore=False should always return the same action.
a_ = algo.compute_single_action(obs, explore=False)
check(convert_to_numpy(algo.get_policy().global_timestep), 1)
for i in range(50):
a = algo.compute_single_action(obs, explore=False)
check(convert_to_numpy(algo.get_policy().global_timestep), i + 2)
check(a, a_)
# explore=None (default: explore) should return different actions.
actions = []
for i in range(50):
actions.append(algo.compute_single_action(obs))
check(convert_to_numpy(algo.get_policy().global_timestep), i + 52)
check(np.std(actions), 0.0, false=True)
algo.stop()

# Check randomness at beginning.
config.exploration(exploration_config=random_init)
algo = config.build()
# ts=0 (get a deterministic action as per explore=False).
deterministic_action = algo.compute_single_action(obs, explore=False)
check(convert_to_numpy(algo.get_policy().global_timestep), 1)
# ts=1-29 (in random window).
random_a = []
for i in range(1, 30):
random_a.append(algo.compute_single_action(obs, explore=True))
check(convert_to_numpy(algo.get_policy().global_timestep), i + 1)
check(random_a[-1], deterministic_action, false=True)
self.assertTrue(np.std(random_a) > 0.3)

# ts > 30 (a=deterministic_action + scale * N[0,1])
for i in range(50):
a = algo.compute_single_action(obs, explore=True)
check(convert_to_numpy(algo.get_policy().global_timestep), i + 31)
check(a, deterministic_action, rtol=0.1)

# ts >> 30 (BUT: explore=False -> expect deterministic action).
for i in range(50):
a = algo.compute_single_action(obs, explore=False)
check(convert_to_numpy(algo.get_policy().global_timestep), i + 81)
check(a, deterministic_action)
algo.stop()


if __name__ == "__main__":
import sys

import pytest

sys.exit(pytest.main(["-v", __file__]))
