-
Notifications
You must be signed in to change notification settings - Fork 60
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support custom recurrent policies (#244)
* Make PPO2 support custom recurrent policies * BasePolicy: make initial_state a property * Use stateful class attribute rather than issubclass check * Introduce StatefulActorCriticPolicy to standardize masks_ph/states_ph * Dead import removal * Make placeholders and tensors publicly exposed properties * docs * DQN bugfix: don't set initial_state property * Make BasePolicy use abstractmethods * Update policies.py * Rename masks_ph -> dones_ph * Support >1D states * Bugfix: state shape * Unit test to verify LSTM policy training works in environment requiring memory * Linting * Use tuples for shapes * Merge conftest * Clean-up test cases * Docs improvements * Update docs/misc/changelog.rst Co-Authored-By: AdamGleave <adam@gleave.me> * Update tests/test_lstm_policy.py Co-Authored-By: AdamGleave <adam@gleave.me> * Rename stateful to recurrent * Doc improvements by araffin * Update tests/test_lstm_policy.py Co-Authored-By: AdamGleave <adam@gleave.me> * Avoid Codacy linting error * Add missing abstract method Trying to fix codacy warning
- Loading branch information
1 parent
0eac3f5
commit baa4aa5
Showing
14 changed files
with
294 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,22 @@ | ||
"""Configures pytest to ignore certain unit tests unless the appropriate flag is passed. | ||
--rungpu: tests that require GPU. | ||
--expensive: tests that take a long time to run (e.g. training an RL algorithm for many timestesps).""" | ||
|
||
import pytest | ||
|
||
|
||
def pytest_addoption(parser): | ||
parser.addoption("--rungpu", action="store_true", default=False, help="run gpu tests") | ||
parser.addoption("--expensive", action="store_true", | ||
help="run expensive tests (which are otherwise skipped).") | ||
|
||
|
||
def pytest_collection_modifyitems(config, items): | ||
if config.getoption("--rungpu"): | ||
return | ||
skip_gpu = pytest.mark.skip(reason="need --rungpu option to run") | ||
flags = {'gpu': '--rungpu', 'expensive': '--expensive'} | ||
skips = {keyword: pytest.mark.skip(reason="need {} option to run".format(flag)) | ||
for keyword, flag in flags.items() if not config.getoption(flag)} | ||
for item in items: | ||
if "gpu" in item.keywords: | ||
item.add_marker(skip_gpu) | ||
for keyword, skip in skips.items(): | ||
if keyword in item.keywords: | ||
item.add_marker(skip) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.