Commit
Review of code (A2C, PPO and refactoring) (#35)
* Split torch module code into torch_layers file
* Updated reference to CNN
* Change 'CxWxH' to 'CxHxW', as per common notation
* Fix missing import in policies.py
* Move PPOPolicy to OnlineActorCriticPolicy
* Create OnPolicyRLModel from PPO, and make A2C and PPO inherit
* Update A2C optimizer comment
* Clean weight init scales for clarity
* Fix A2C log_interval default parameter
* Rename 'progress' to 'progress_remaining'
* Rename 'Models' to 'Algorithms'
* Rename 'OnlineActorCriticPolicy' to 'ActorCriticPolicy'
* Move static functions out from BaseAlgorithm
* Move on/off_policy base algorithms to their own files
* Add files for A2C/PPO
* Fix docs
* Fix pytype
* Update documentation on OnPolicyAlgorithm
* Add proper docstring for on_policy rollout gathering
* Add a bit of clarification on the mlppolicy/cnnpolicy naming
* Move static function is_vectorized_policies to utils.py
* Checking docstrings, pep8 fixes
* Update changelog
* Clean changelog
* Remove policy warnings for sac/td3
* Add monitor_wrapper for OnPolicyAlgorithm. Clean tb logging variables. Add parameter keywords to OffPolicyAlgorithm super init

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
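
The central refactoring in this list, a shared on-policy base class from which A2C and PPO both inherit, can be pictured with a minimal sketch. Every name below (`OnPolicySketch`, `A2CSketch`, `PPOSketch`, `collect_rollouts`, `num_updates`) is a hypothetical stand-in chosen for illustration and is not the code from this commit; the rollout and the update are placeholders with no environment or gradient step involved.

```python
from abc import ABC, abstractmethod
from typing import List


class OnPolicySketch(ABC):
    """Hypothetical stand-in for a shared on-policy base class."""

    def __init__(self, n_steps: int) -> None:
        self.n_steps = n_steps
        self.rollout: List[float] = []
        self.num_updates = 0

    def collect_rollouts(self) -> None:
        # Placeholder: a real implementation would step an environment here.
        self.rollout = [1.0] * self.n_steps

    @abstractmethod
    def train(self) -> None:
        """Algorithm-specific update, supplied by each subclass."""

    def learn(self, total_timesteps: int) -> "OnPolicySketch":
        # Shared loop: gather a rollout, then let the subclass update on it.
        timesteps = 0
        while timesteps < total_timesteps:
            self.collect_rollouts()
            timesteps += self.n_steps
            self.train()
        return self


class A2CSketch(OnPolicySketch):
    def train(self) -> None:
        # Assumed here: one update per rollout (placeholder counter only).
        self.num_updates += 1


class PPOSketch(OnPolicySketch):
    def __init__(self, n_steps: int, n_epochs: int) -> None:
        super().__init__(n_steps)
        self.n_epochs = n_epochs

    def train(self) -> None:
        # Assumed here: several passes over the same rollout.
        self.num_updates += self.n_epochs
```

With this layout the rollout-gathering loop lives in one place, and each algorithm only overrides `train`, which is the kind of sharing the "make A2C and PPO inherit" item describes.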
Showing 26 changed files with 1,511 additions and 1,279 deletions.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -1,2 +1,2 @@` |
| 1 | 1 | `from stable_baselines3.a2c.a2c import A2C` |
| 2 | | `-from stable_baselines3.ppo.policies import MlpPolicy, CnnPolicy` |
| | 2 | `+from stable_baselines3.a2c.policies import MlpPolicy, CnnPolicy` |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| | | `@@ -0,0 +1,9 @@` |
| | 1 | `+# This file is here just to define MlpPolicy/CnnPolicy` |
| | 2 | `+# that work for A2C` |
| | 3 | `+from stable_baselines3.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, register_policy` |
| | 4 | `+` |
| | 5 | `+MlpPolicy = ActorCriticPolicy` |
| | 6 | `+CnnPolicy = ActorCriticCnnPolicy` |
| | 7 | `+` |
| | 8 | `+register_policy("MlpPolicy", ActorCriticPolicy)` |
| | 9 | `+register_policy("CnnPolicy", ActorCriticCnnPolicy)` |
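
The new file does two things: it aliases the shared actor-critic classes under the names `MlpPolicy` and `CnnPolicy`, and it registers them under string keys. A minimal, self-contained sketch of that alias-plus-registry pattern follows. It is not the library's implementation: the flat dictionary, the `*Sketch` classes, and the helpers `register_policy_sketch` and `get_policy_from_name_sketch` are all assumptions made for illustration.

```python
from typing import Dict, Type


class ActorCriticPolicySketch:
    """Hypothetical stand-in for the shared MLP actor-critic policy."""


class ActorCriticCnnPolicySketch(ActorCriticPolicySketch):
    """Hypothetical stand-in for the CNN variant."""


# Assumed design: one flat name -> class mapping.
_registry: Dict[str, Type[ActorCriticPolicySketch]] = {}


def register_policy_sketch(name: str, policy: Type[ActorCriticPolicySketch]) -> None:
    """Register a policy class under a string name (illustrative only)."""
    existing = _registry.get(name)
    if existing is not None and existing is not policy:
        raise ValueError(f"'{name}' is already registered to a different class")
    _registry[name] = policy


def get_policy_from_name_sketch(name: str) -> Type[ActorCriticPolicySketch]:
    """Resolve a string such as 'MlpPolicy' to the registered class."""
    if name not in _registry:
        raise KeyError(f"Unknown policy '{name}'")
    return _registry[name]


# Module-level aliases, mirroring the shape of the file in the diff.
MlpPolicy = ActorCriticPolicySketch
CnnPolicy = ActorCriticCnnPolicySketch

register_policy_sketch("MlpPolicy", ActorCriticPolicySketch)
register_policy_sketch("CnnPolicy", ActorCriticCnnPolicySketch)
```

Under these assumptions, a caller can pass either the class (`MlpPolicy`) or the string `"MlpPolicy"`, and both resolve to the same object, which is presumably why the per-algorithm file defines the alias and the registration together.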