Skip to content

Commit

Permalink
Merge pull request #55 from hill-a/doc-custom-policy
Browse files Browse the repository at this point in the history
Update doc + add tests for custom policies
  • Loading branch information
hill-a committed Oct 12, 2018
2 parents 5f3c4b6 + fff7119 commit 3ac278a
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 0 deletions.
9 changes: 9 additions & 0 deletions docs/guide/custom_policy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ However, you can also easily define a custom architecture for the policy (or val
# Train the agent
model.learn(total_timesteps=100000)
del model
# When loading a model with a custom policy
# you MUST pass explicitly the policy when loading the saved model
model = A2C.load("model_path", policy=CustomPolicy)
.. warning::

When loading a model with a custom policy, you must pass the custom policy explicitly when loading the model (see the previous example).


You can also register your policy, to help with code simplicity: you can then refer to your custom policy using a string.

Expand Down
77 changes: 77 additions & 0 deletions tests/test_custom_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import os

import pytest

from stable_baselines import A2C, ACER, ACKTR, DQN, PPO1, PPO2, TRPO, DDPG
from stable_baselines.common.policies import FeedForwardPolicy
from stable_baselines.deepq.policies import FeedForwardPolicy as DQNPolicy
from stable_baselines.ddpg.policies import FeedForwardPolicy as DDPGPolicy

N_TRIALS = 100

class CustomCommonPolicy(FeedForwardPolicy):
def __init__(self, *args, **kwargs):
super(CustomCommonPolicy, self).__init__(*args, **kwargs,
layers=[8, 8],
feature_extraction="mlp")

class CustomDQNPolicy(DQNPolicy):
def __init__(self, *args, **kwargs):
super(CustomDQNPolicy, self).__init__(*args, **kwargs,
layers=[8, 8],
feature_extraction="mlp")

class CustomDDPGPolicy(DDPGPolicy):
def __init__(self, *args, **kwargs):
super(CustomDDPGPolicy, self).__init__(*args, **kwargs,
layers=[8, 8],
feature_extraction="mlp")


MODEL_DICT = {
'a2c': (A2C, CustomCommonPolicy),
'acer': (ACER, CustomCommonPolicy),
'acktr': (ACKTR, CustomCommonPolicy),
'dqn': (DQN, CustomDQNPolicy),
'ddpg': (DDPG, CustomDDPGPolicy),
'ppo1': (PPO1, CustomCommonPolicy),
'ppo2': (PPO2, CustomCommonPolicy),
'trpo': (TRPO, CustomCommonPolicy),
}


@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_custom_policy(model_name):
"""
Test if the algorithm (with a custom policy) can be loaded and saved without any issues.
:param model_class: (BaseRLModel) A RL model
"""

try:
model_class, policy = MODEL_DICT[model_name]
if model_name == 'ddpg':
env = 'MountainCarContinuous-v0'
else:
env = 'CartPole-v1'
# create and train
model = model_class(policy, env)
model.learn(total_timesteps=100, seed=0)

env = model.get_env()
# predict and measure the acc reward
obs = env.reset()
for _ in range(N_TRIALS):
action, _ = model.predict(obs)
# Test action probability method
if model_name != 'ddpg':
model.action_probability(obs)
obs, _, _, _ = env.step(action)
# saving
model.save("./test_model")
del model, env
# loading
model = model_class.load("./test_model", policy=policy)

finally:
if os.path.exists("./test_model"):
os.remove("./test_model")

0 comments on commit 3ac278a

Please sign in to comment.