[RLlib contrib] TD3. (ray-project#36726)
Showing 8 changed files with 341 additions and 3 deletions.
README.md (new file, 25 lines added):

# TD3 (Twin Delayed DDPG)

[TD3](https://arxiv.org/pdf/1802.09477) is a direct successor to DDPG. While DDPG can sometimes achieve great performance, it is frequently brittle with respect to hyperparameters and other kinds of tuning. A common failure mode for DDPG is that the learned Q-function begins to dramatically overestimate Q-values, which then leads to the policy breaking, because it exploits the errors in the Q-function. Twin Delayed DDPG (TD3) addresses this issue by introducing three critical tricks:

Trick One: Clipped Double-Q Learning. TD3 learns two Q-functions instead of one (hence “twin”), and uses the smaller of the two Q-values to form the targets in the Bellman error loss functions.

Trick Two: “Delayed” Policy Updates. TD3 updates the policy (and target networks) less frequently than the Q-function. The paper recommends one policy update for every two Q-function updates.

Trick Three: Target Policy Smoothing. TD3 adds noise to the target action, to make it harder for the policy to exploit Q-function errors by smoothing out Q along changes in action.

Together, these three tricks result in substantially improved performance over baseline DDPG.

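A minimal sketch of how the three tricks combine in the target computation (illustrative pseudo-implementation with hypothetical helper names such as `q1_target` and `pi_target`, not RLlib's actual code):

```python
import torch

def td3_targets(q1_target, q2_target, pi_target, rewards, next_obs,
                gamma=0.99, noise_std=0.2, noise_clip=0.5, low=-1.0, high=1.0):
    # Trick Three: target policy smoothing -- perturb the target action with
    # clipped Gaussian noise before evaluating the target critics.
    next_actions = pi_target(next_obs)
    noise = (torch.randn_like(next_actions) * noise_std).clamp(-noise_clip, noise_clip)
    next_actions = (next_actions + noise).clamp(low, high)

    # Trick One: clipped double-Q -- take the minimum of the twin target
    # critics so overestimation in either one cannot inflate the target.
    min_q = torch.min(q1_target(next_obs, next_actions),
                      q2_target(next_obs, next_actions))
    return rewards + gamma * min_q

# Trick Two: delayed policy updates -- update the actor and target networks
# only every `policy_delay`-th critic update, e.g.:
#
#   if critic_update_step % policy_delay == 0:
#       actor_loss = -q1(obs, pi(obs)).mean()
#       ...update actor, then Polyak-average the target networks...
```
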
## Installation

```
conda create -n rllib-td3 python=3.10
conda activate rllib-td3
pip install -r requirements.txt
pip install -e '.[development]'
```

## Usage

[TD3 Example]()
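
A minimal usage sketch, assuming the package has been installed as above (it mirrors the config calls in the example script shipped with this commit):

```python
from rllib_td3.td3 import TD3Config

# Build a TD3 algorithm on Pendulum-v1 and run a few training iterations.
config = (
    TD3Config()
    .framework("torch")
    .environment("Pendulum-v1")
)
algo = config.build()
for _ in range(3):
    results = algo.train()
    print(results["episode_reward_mean"])
algo.stop()
```
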
Example training script (Pendulum-v1; new file, 53 lines added):

```python
import argparse

from rllib_td3.td3 import TD3, TD3Config

import ray
from ray import air, tune
from ray.rllib.utils.test_utils import check_learning_achieved


def get_cli_args():
    """Create CLI parser and return parsed arguments"""
    parser = argparse.ArgumentParser()
    parser.add_argument("--run-as-test", action="store_true", default=False)
    args = parser.parse_args()
    print(f"Running with following CLI args: {args}")
    return args


if __name__ == "__main__":
    args = get_cli_args()

    ray.init()

    config = (
        TD3Config()
        .framework("torch")
        .environment("Pendulum-v1")
        .training(
            actor_hiddens=[64, 64],
            critic_hiddens=[64, 64],
            replay_buffer_config={"type": "MultiAgentReplayBuffer"},
            num_steps_sampled_before_learning_starts=5000,
        )
        .exploration(exploration_config={"random_timesteps": 5000})
    )

    stop_reward = -900

    tuner = tune.Tuner(
        TD3,
        param_space=config.to_dict(),
        run_config=air.RunConfig(
            stop={
                "sampler_results/episode_reward_mean": stop_reward,
                "timesteps_total": 100000,
            },
            failure_config=air.FailureConfig(fail_fast="raise"),
        ),
    )
    results = tuner.fit()

    if args.run_as_test:
        check_learning_achieved(results, stop_reward)
```
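
The stop criteria mean the run ends when either the mean episode reward reaches -900 or 100,000 total timesteps have been sampled, whichever comes first. With `--run-as-test`, `check_learning_achieved` raises an error if the reward threshold was not reached, turning the script into a pass/fail learning test.
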
pyproject.toml (new file, 18 lines added):

```toml
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["src"]

[project]
name = "rllib-td3"
authors = [{name = "Anyscale Inc."}]
version = "0.1.0"
description = ""
readme = "README.md"
requires-python = ">=3.7, <3.11"
dependencies = ["gymnasium==0.26.3", "ray[rllib]==2.5.1"]

[project.optional-dependencies]
development = ["pytest>=7.2.2", "pre-commit==2.21.0", "tensorflow==2.11.0", "torch==1.12.0"]
```
requirements.txt (new file, 2 lines added):

```
tensorflow==2.11.0
torch==1.12.0
```
rllib_td3/td3/__init__.py (new file, 7 lines added):

```python
from rllib_td3.td3.td3 import TD3, TD3Config

from ray.tune.registry import register_trainable

__all__ = ["TD3Config", "TD3"]

register_trainable("rllib-contrib-td3", TD3)
```
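
Because the module registers the algorithm under the string name "rllib-contrib-td3", Tune can launch it without importing the class directly. A minimal sketch (it assumes importing the package, which triggers the `register_trainable` call above):

```python
import rllib_td3.td3  # noqa: F401 -- the import runs register_trainable().

from ray import tune

# Refer to the algorithm by its registered string name.
tuner = tune.Tuner(
    "rllib-contrib-td3",
    param_space={"env": "Pendulum-v1", "framework": "torch"},
)
results = tuner.fit()
```
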
rllib_td3/td3/td3.py (new file, 108 lines added):

"""A more stable successor to TD3. | ||
By default, this uses a near-identical configuration to that reported in the | ||
TD3 paper. | ||
""" | ||
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig | ||
from ray.rllib.algorithms.ddpg.ddpg import DDPG, DDPGConfig | ||
from ray.rllib.utils.annotations import override | ||
from ray.rllib.utils.deprecation import DEPRECATED_VALUE | ||
|
||
|
||
class TD3Config(DDPGConfig):
    """Defines a configuration class from which a TD3 Algorithm can be built.

    Example:
        >>> from ray.rllib.algorithms.td3 import TD3Config
        >>> config = TD3Config().training(lr=0.01).resources(num_gpus=1)
        >>> print(config.to_dict())  # doctest: +SKIP
        >>> # Build an Algorithm object from the config and run one training iteration.
        >>> algo = config.build(env="Pendulum-v1")  # doctest: +SKIP
        >>> algo.train()  # doctest: +SKIP

    Example:
        >>> from ray.rllib.algorithms.td3 import TD3Config
        >>> from ray import air
        >>> from ray import tune
        >>> config = TD3Config()
        >>> # Print out some default values.
        >>> print(config.lr)  # doctest: +SKIP
        >>> # Update the config object.
        >>> config = config.training(lr=tune.grid_search(  # doctest: +SKIP
        ...     [0.001, 0.0001]))  # doctest: +SKIP
        >>> # Set the config object's env.
        >>> config.environment(env="Pendulum-v1")  # doctest: +SKIP
        >>> # Use to_dict() to get the old-style python config dict
        >>> # when running with tune.
        >>> tune.Tuner(  # doctest: +SKIP
        ...     "TD3",
        ...     run_config=air.RunConfig(stop={"episode_reward_mean": 200}),
        ...     param_space=config.to_dict(),
        ... ).fit()
    """

    def __init__(self, algo_class=None):
        """Initializes a TD3Config instance."""
        super().__init__(algo_class=algo_class or TD3)

        # fmt: off
        # __sphinx_doc_begin__

        # Override some of DDPG/SimpleQ/Algorithm's default values with TD3-specific
        # values.

        # .training()

        # Largest changes: twin Q functions, delayed policy updates, target
        # smoothing, no l2-regularization.
        self.twin_q = True
        self.policy_delay = 2
        self.smooth_target_policy = True
        self.l2_reg = 0.0
        # Different tau (affecting target network update).
        self.tau = 5e-3
        # Different batch size.
        self.train_batch_size = 100
        # No prioritized replay by default (we may want to change this at some
        # point).
        self.replay_buffer_config = {
            "type": "MultiAgentReplayBuffer",
            # Specify prioritized replay by supplying a buffer type that supports
            # prioritization, for example: MultiAgentPrioritizedReplayBuffer.
            "prioritized_replay": DEPRECATED_VALUE,
            "capacity": 1000000,
            "worker_side_prioritization": False,
        }
        # Number of timesteps to collect from rollout workers before we start
        # sampling from replay buffers for learning. Whether we count this in agent
        # steps or environment steps depends on config.multi_agent(count_steps_by=..).
        self.num_steps_sampled_before_learning_starts = 10000

        # .exploration()
        # TD3 uses Gaussian Noise by default.
        self.exploration_config = {
            # TD3 uses simple Gaussian noise on top of deterministic NN-output
            # actions (after a possible pure random phase of n timesteps).
            "type": "GaussianNoise",
            # For how many timesteps should we return completely random
            # actions, before we start adding (scaled) noise?
            "random_timesteps": 10000,
            # Gaussian stddev of action noise for exploration.
            "stddev": 0.1,
            # Scaling settings by which the Gaussian noise is scaled before
            # being added to the actions. NOTE: The scale timesteps start only
            # after(!) any random steps have been finished.
            # By default, do not anneal over time (fixed 1.0).
            "initial_scale": 1.0,
            "final_scale": 1.0,
            "scale_timesteps": 1,
        }
        # __sphinx_doc_end__
        # fmt: on


class TD3(DDPG):
    @classmethod
    @override(DDPG)
    def get_default_config(cls) -> AlgorithmConfig:
        return TD3Config()
```
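
As the `replay_buffer_config` comment notes, prioritized replay can be enabled by supplying a buffer type that supports prioritization. A minimal sketch (the prioritization keys shown are the standard RLlib ones and are assumptions here, not part of this file):

```python
from rllib_td3.td3 import TD3Config

# Swap the default buffer for a prioritized one (see comment in TD3Config).
config = TD3Config().training(
    replay_buffer_config={
        "type": "MultiAgentPrioritizedReplayBuffer",
        "capacity": 1000000,
        # Standard prioritization hyperparameters (assumed defaults).
        "prioritized_replay_alpha": 0.6,
        "prioritized_replay_beta": 0.4,
        "prioritized_replay_eps": 1e-6,
    }
)
```
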
TD3 test suite (new file, 110 lines added):

```python
import unittest

import numpy as np
import rllib_td3.td3.td3 as td3

import ray
from ray.rllib.utils.framework import try_import_tf
from ray.rllib.utils.numpy import convert_to_numpy
from ray.rllib.utils.test_utils import (
    check,
    check_compute_single_action,
    check_train_results,
    framework_iterator,
)

tf1, tf, tfv = try_import_tf()


class TestTD3(unittest.TestCase):
    @classmethod
    def setUpClass(cls) -> None:
        ray.init()

    @classmethod
    def tearDownClass(cls) -> None:
        ray.shutdown()

    def test_td3_compilation(self):
        """Test whether TD3 can be built with both frameworks."""
        config = td3.TD3Config()

        # Test against all frameworks.
        for _ in framework_iterator(config, with_eager_tracing=True):
            algo = config.build(env="Pendulum-v1")
            num_iterations = 1
            for i in range(num_iterations):
                results = algo.train()
                check_train_results(results)
                print(results)
            check_compute_single_action(algo)
            algo.stop()

    def test_td3_exploration_and_with_random_prerun(self):
        """Tests TD3's Exploration (w/ random actions for n timesteps)."""
        config = td3.TD3Config().environment(env="Pendulum-v1")
        no_random_init = config.exploration_config.copy()
        random_init = {
            # Act randomly at beginning ...
            "random_timesteps": 30,
            # Then act very closely to deterministic actions thereafter.
            "stddev": 0.001,
            "initial_scale": 0.001,
            "final_scale": 0.001,
        }
        obs = np.array([0.0, 0.1, -0.1])

        # Test against all frameworks.
        for _ in framework_iterator(config, with_eager_tracing=True):
            config.exploration(exploration_config=no_random_init)
            # Default GaussianNoise setup.
            algo = config.build()
            # Setting explore=False should always return the same action.
            a_ = algo.compute_single_action(obs, explore=False)
            check(convert_to_numpy(algo.get_policy().global_timestep), 1)
            for i in range(50):
                a = algo.compute_single_action(obs, explore=False)
                check(convert_to_numpy(algo.get_policy().global_timestep), i + 2)
                check(a, a_)
            # explore=None (default: explore) should return different actions.
            actions = []
            for i in range(50):
                actions.append(algo.compute_single_action(obs))
                check(convert_to_numpy(algo.get_policy().global_timestep), i + 52)
            check(np.std(actions), 0.0, false=True)
            algo.stop()

            # Check randomness at beginning.
            config.exploration(exploration_config=random_init)
            algo = config.build()
            # ts=0 (get a deterministic action as per explore=False).
            deterministic_action = algo.compute_single_action(obs, explore=False)
            check(convert_to_numpy(algo.get_policy().global_timestep), 1)
            # ts=1-29 (in random window).
            random_a = []
            for i in range(1, 30):
                random_a.append(algo.compute_single_action(obs, explore=True))
                check(convert_to_numpy(algo.get_policy().global_timestep), i + 1)
                check(random_a[-1], deterministic_action, false=True)
            self.assertTrue(np.std(random_a) > 0.3)

            # ts > 30 (a=deterministic_action + scale * N[0,1])
            for i in range(50):
                a = algo.compute_single_action(obs, explore=True)
                check(convert_to_numpy(algo.get_policy().global_timestep), i + 31)
                check(a, deterministic_action, rtol=0.1)

            # ts >> 30 (BUT: explore=False -> expect deterministic action).
            for i in range(50):
                a = algo.compute_single_action(obs, explore=False)
                check(convert_to_numpy(algo.get_policy().global_timestep), i + 81)
                check(a, deterministic_action)
            algo.stop()


if __name__ == "__main__":
    import sys

    import pytest

    sys.exit(pytest.main(["-v", __file__]))
```