Commit

[ci skip] documented replay buffer classes
hill-a committed Sep 18, 2018
1 parent 53feaae commit e954001
Showing 3 changed files with 50 additions and 2 deletions.
26 changes: 25 additions & 1 deletion docs/modules/her.rst
@@ -29,6 +29,7 @@ Train a HER agent on `MountainCarContinuous-v0`.
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.her.reward_class import ProximalReward
from stable_baselines.her.utils import stack_obs_goal
from stable_baselines.her.replay_buffer import FutureHERBuffer
from stable_baselines.ddpg.noise import NormalActionNoise
from stable_baselines import DDPG, HER
@@ -38,7 +39,8 @@ Train a HER agent on `MountainCarContinuous-v0`.
n_actions = env.action_space.shape[-1]
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=float(0.2) * np.ones(n_actions))
model = HER(DDPG, 'MlpPolicy', env, ProximalReward(eps=0.1), action_noise=action_noise) # define the reward function for HER
# define the reward function, buffer_class and model for HER (+ model parameters)
model = HER(DDPG, 'MlpPolicy', env, ProximalReward(eps=0.1), buffer_class=FutureHERBuffer, action_noise=action_noise)
model.learn(total_timesteps=25000)
model.save("her_dqn_mountaincar")
@@ -75,6 +77,28 @@ Reward function
  :inherited-members:


HER Replay Buffer
-----------------

.. autoclass:: HERBuffer
  :members:

.. autoclass:: EpisodeHERBuffer
  :members:
  :inherited-members:

.. autoclass:: RandomHERBuffer
  :members:
  :inherited-members:

.. autoclass:: FutureHERBuffer
  :members:
  :inherited-members:

Utility functions
-----------------

1 change: 1 addition & 0 deletions stable_baselines/her/__init__.py
@@ -1,2 +1,3 @@
from stable_baselines.her.her import HER
from stable_baselines.her.reward_class import HERRewardFunctions, ProximalReward
from stable_baselines.her.replay_buffer import HERBuffer, EpisodeHERBuffer, RandomHERBuffer, FutureHERBuffer
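With all four buffer classes exported from stable_baselines.her, switching the goal-sampling strategy is a one-argument change. A minimal sketch of that swap, assuming the HER signature shown in the docs example above (RandomHERBuffer in place of FutureHERBuffer is the only difference):

import gym

from stable_baselines import DDPG, HER
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.her import ProximalReward, RandomHERBuffer

env = DummyVecEnv([lambda: gym.make('MountainCarContinuous-v0')])

# Only buffer_class changes; the reward function and model stay the same
model = HER(DDPG, 'MlpPolicy', env, ProximalReward(eps=0.1), buffer_class=RandomHERBuffer)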
25 changes: 24 additions & 1 deletion stable_baselines/her/replay_buffer.py
@@ -6,9 +6,11 @@

class HERBuffer(ReplayBuffer):
    """
    Replay Buffer for HER that implements the Episode strategy
    Base class for the Replay Buffer for HER.

    :param size: (int) The size of the buffer
    :param reward_func: (HERRewardFunctions) the reward function
    :param num_sample_goals: (int) the ratio of HER samples to regular experience replay samples
    """

    def __init__(self, size, reward_func, num_sample_goals):
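The constructor body is collapsed in this diff. Judging from the sample() methods below, which read self._storage and self.num_sample_goals, it presumably forwards size to ReplayBuffer and stores the two HER-specific arguments; a hedged sketch, not the actual implementation:

def __init__(self, size, reward_func, num_sample_goals):
    super(HERBuffer, self).__init__(size)     # ReplayBuffer sets up self._storage
    self.reward_func = reward_func            # recomputes rewards for substituted goals
    self.num_sample_goals = num_sample_goals  # k, the HER-to-regular sample ratio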
@@ -59,6 +61,13 @@ def add(self, obs_t, action, reward, obs_tp1, done):


class FutureHERBuffer(HERBuffer):
    """
    HER replay buffer that implements the Future strategy

    :param size: (int) The size of the buffer
    :param reward_func: (HERRewardFunctions) the reward function
    :param num_sample_goals: (int) the ratio of HER samples to regular experience replay samples
    """
    def sample(self, batch_size, **kwargs):
        idxes = [np.random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
        future_proba = 1 - (1. / (1 + self.num_sample_goals))
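All three strategies turn num_sample_goals into a relabelling probability the same way: with k = num_sample_goals, a fraction 1 - 1/(1 + k) = k/(k + 1) of each sampled batch gets a substituted goal. A quick numeric check of the formula used on the line above:

for k in (1, 2, 4, 8):
    proba = 1 - (1. / (1 + k))
    print(k, proba)  # 1 -> 0.5, 2 -> 0.666..., 4 -> 0.8, 8 -> 0.888...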
@@ -80,6 +89,13 @@ def sample(self, batch_size, **kwargs):


class EpisodeHERBuffer(HERBuffer):
    """
    HER replay buffer that implements the Episode strategy

    :param size: (int) The size of the buffer
    :param reward_func: (HERRewardFunctions) the reward function
    :param num_sample_goals: (int) the ratio of HER samples to regular experience replay samples
    """
    def sample(self, batch_size, **kwargs):
        idxes = [np.random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
        episode_proba = 1 - (1. / (1 + self.num_sample_goals))
@@ -101,6 +117,13 @@ def sample(self, batch_size, **kwargs):


class RandomHERBuffer(HERBuffer):
    """
    HER replay buffer that implements the Random strategy

    :param size: (int) The size of the buffer
    :param reward_func: (HERRewardFunctions) the reward function
    :param num_sample_goals: (int) the ratio of HER samples to regular experience replay samples
    """
    def sample(self, batch_size, **kwargs):
        idxes = [np.random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
        random_proba = 1 - (1. / (1 + self.num_sample_goals))
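The three subclasses differ only in where the substitute goal is drawn from, mirroring the strategies of Andrychowicz et al. (2017). The selection logic itself is collapsed in this diff, so here is an illustrative sketch of the index choice; pick_goal_index and its arguments are hypothetical names, not the actual API:

import numpy as np

def pick_goal_index(strategy, idx, episode_start, episode_end, buffer_len):
    """Illustrative only: where each strategy draws the substitute goal."""
    if strategy == 'future':   # FutureHERBuffer: a later step of the same episode
        return np.random.randint(idx + 1, episode_end)
    if strategy == 'episode':  # EpisodeHERBuffer: any step of the same episode
        return np.random.randint(episode_start, episode_end)
    if strategy == 'random':   # RandomHERBuffer: any step in the whole buffer
        return np.random.randint(0, buffer_len)
    raise ValueError("unknown strategy: {}".format(strategy))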
