implemented future, episode and random strategy for HER sampling
hill-a committed Sep 18, 2018
1 parent 20f3a33 commit 53feaae
Showing 2 changed files with 134 additions and 53 deletions.
17 changes: 11 additions & 6 deletions stable_baselines/her/her.py
@@ -1,5 +1,5 @@
from stable_baselines.common import OffPolicyRLModel
from stable_baselines.her.replay_buffer import make_her_buffer
from stable_baselines.her.replay_buffer import make_her_buffer, FutureHERBuffer
from stable_baselines.her.env_wrapper import HERWrapper


@@ -17,8 +17,8 @@ class HER(OffPolicyRLModel):
:param *args: positional arguments for the model
:param **kwargs: keyword arguments for the model
"""
def __init__(self, model, policy, env, reward_function, num_sample_goals=4, verbose=0, _init_setup_model=True,
*args, **kwargs):
def __init__(self, model, policy, env, reward_function, num_sample_goals=4, buffer_class=FutureHERBuffer,
verbose=0, _init_setup_model=True, *args, **kwargs):
super(HER, self).__init__(policy=None, env=env, replay_buffer=None,
verbose=verbose, policy_base=None, requires_vec_env=True)

@@ -27,10 +27,12 @@ def __init__(self, model, policy, env, reward_function, num_sample_goals=4, verb

self.reward_function = reward_function
self.model_class = model
self.buffer_class = buffer_class
self.num_sample_goals = num_sample_goals

if self.env is not None:
env = HERWrapper(self.env, reward_function)
replay_buffer = make_her_buffer(reward_function, num_sample_goals)
replay_buffer = make_her_buffer(buffer_class, reward_function, num_sample_goals)
self.model = model(policy=policy, env=env, verbose=verbose, replay_buffer=replay_buffer,
_init_setup_model=_init_setup_model, *args, **kwargs)

@@ -66,6 +68,8 @@ def get_save_data(self):
"model_class": self.model_class,
"reward_function": self.reward_function,
"observation_space": self.observation_space,
"num_sample_goals": self.num_sample_goals,
"buffer_class": self.buffer_class,
}

def save(self, save_path):
@@ -80,8 +84,9 @@ def save(self, save_path):
def load(cls, load_path, env=None, **kwargs):
(data, model_data), params = cls._load_from_file(load_path)

her_model = cls(model=data["model_class"], policy=model_data["policy"],
reward_function=data["reward_function"], env=None, _init_setup_model=False)
her_model = cls(model=data["model_class"], policy=model_data["policy"], reward_function=data["reward_function"],
buffer_class=data["buffer_class"], num_sample_goals=data["num_sample_goals"], env=None,
_init_setup_model=False)
her_model.__dict__.update(data)
her_model.model.__dict__.update(model_data)
her_model.model.__dict__.update(kwargs)
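For reference, a minimal sketch (not part of this commit) of how the new buffer_class argument could be passed when constructing HER on this branch. The DQN/MlpPolicy choice and the env and reward_fn objects are illustrative assumptions, and the import paths assume the stable-baselines layout of this period:

from stable_baselines.deepq import DQN
from stable_baselines.deepq.policies import MlpPolicy
from stable_baselines.her.her import HER
from stable_baselines.her.replay_buffer import EpisodeHERBuffer

# env: a goal-based environment, and reward_fn: a HERRewardFunctions-style object
# exposing get_reward(obs, action, goal) -- both assumed to exist already
model = HER(model=DQN, policy=MlpPolicy, env=env, reward_function=reward_fn,
            num_sample_goals=4, buffer_class=EpisodeHERBuffer, verbose=1)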
170 changes: 123 additions & 47 deletions stable_baselines/her/replay_buffer.py
@@ -4,58 +4,134 @@
from stable_baselines.her.utils import unstack_goal, stack_obs_goal


def make_her_buffer(reward_func, num_sample_goals):
class HERBuffer(ReplayBuffer):
"""
Replay Buffer for HER that implements the Episode strategy
:param size: (int) The size of the buffer
"""

def __init__(self, size, reward_func, num_sample_goals):
super(HERBuffer, self).__init__(size=size)
self.reward_func = reward_func
self.num_sample_goals = num_sample_goals
self.eps_count = 0 # the episode counter
self.eps_idx = 0 # the frame in the episode
self.eps_goal = [] # the goals in the current episode
self.eps_goals = {0: self.eps_goal} # the lookup table for the goals

def _encode_sample(self, idxes):
obses_t, actions, rewards, obses_tp1, dones, eps_ids, eps_idx = [], [], [], [], [], [], []
for i in idxes:
data = self._storage[i]
obs_t, action, reward, obs_tp1, done, eps_id, eps_pos = data
obses_t.append(np.array(obs_t, copy=False))
actions.append(np.array(action, copy=False))
rewards.append(reward)
obses_tp1.append(np.array(obs_tp1, copy=False))
dones.append(done)
eps_ids.append(eps_id)
eps_idx.append(eps_pos)
return np.array(obses_t), np.array(actions), np.array(rewards), np.array(obses_tp1), np.array(dones), \
np.array(eps_ids), np.array(eps_idx)

def sample(self, batch_size, **kwargs):
raise NotImplementedError()

def add(self, obs_t, action, reward, obs_tp1, done):
if done:
# clean up unused goals
self.eps_goals = {k: v for k, v in self.eps_goals.items()
if len(self._storage) <= self._next_idx or k >= self._storage[self._next_idx][5]}
self.eps_count += 1
self.eps_idx = 0
self.eps_goal = []
self.eps_goals[self.eps_count] = self.eps_goal
data = (obs_t, action, reward, obs_tp1, done, self.eps_count, self.eps_idx)
self.eps_idx += 1
self.eps_goal.append(self._next_idx)

if self._next_idx >= len(self._storage):
self._storage.append(data)
else:
self._storage[self._next_idx] = data
self._next_idx = (self._next_idx + 1) % self._maxsize


class FutureHERBuffer(HERBuffer):
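    """
    HER buffer implementing the 'future' sampling strategy: relabelled goals are
    drawn from the sampled step onwards within the same episode.
    """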
def sample(self, batch_size, **kwargs):
idxes = [np.random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
future_proba = 1 - (1. / (1 + self.num_sample_goals))
future_idx = np.random.uniform(size=batch_size) < future_proba
obs_t, actions, rewards, obs_tp1, dones, eps_ids, eps_idx = self._encode_sample(idxes)

for idx in np.where(future_idx)[0]:
future_goal_idx = np.array(self.eps_goals[eps_ids[idx]])[eps_idx[idx]:]
if future_goal_idx.shape[0] == 0:
continue
goal_idx = future_goal_idx[np.random.randint(future_goal_idx.shape[0])]
goal_obs, _, _, _, _, _, _ = self._storage[goal_idx]
goal_obs = unstack_goal(goal_obs)
obs_t[idx] = stack_obs_goal(unstack_goal(obs_t[idx]), goal_obs)
obs_tp1[idx] = stack_obs_goal(unstack_goal(obs_tp1[idx]), goal_obs)
rewards[idx] = self.reward_func.get_reward(unstack_goal(obs_t[idx]), actions[idx], goal_obs)

return obs_t, actions, rewards, obs_tp1, dones


class EpisodeHERBuffer(HERBuffer):
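    """
    HER buffer implementing the 'episode' sampling strategy: relabelled goals are
    drawn from any step of the same episode as the sampled transition.
    """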
def sample(self, batch_size, **kwargs):
idxes = [np.random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
episode_proba = 1 - (1. / (1 + self.num_sample_goals))
episode_idx = np.random.uniform(size=batch_size) < episode_proba
obs_t, actions, rewards, obs_tp1, dones, eps_ids, eps_idx = self._encode_sample(idxes)

for idx in np.where(episode_idx)[0]:
episode_goal_idx = np.array(self.eps_goals[eps_ids[idx]])
if episode_goal_idx.shape[0] == 0:
continue
goal_idx = episode_goal_idx[np.random.randint(episode_goal_idx.shape[0])]
goal_obs, _, _, _, _, _, _ = self._storage[goal_idx]
goal_obs = unstack_goal(goal_obs)
obs_t[idx] = stack_obs_goal(unstack_goal(obs_t[idx]), goal_obs)
obs_tp1[idx] = stack_obs_goal(unstack_goal(obs_tp1[idx]), goal_obs)
rewards[idx] = self.reward_func.get_reward(unstack_goal(obs_t[idx]), actions[idx], goal_obs)

return obs_t, actions, rewards, obs_tp1, dones


class RandomHERBuffer(HERBuffer):
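    """
    HER buffer implementing the 'random' sampling strategy: relabelled goals are
    drawn from any transition currently stored in the buffer.
    """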
def sample(self, batch_size, **kwargs):
idxes = [np.random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
random_proba = 1 - (1. / (1 + self.num_sample_goals))
random_idx = np.random.uniform(size=batch_size) < random_proba
obs_t, actions, rewards, obs_tp1, dones, eps_ids, eps_idx = self._encode_sample(idxes)

for idx in np.where(random_idx)[0]:
if len(self._storage) == 0:
break
goal_idx = np.random.randint(len(self._storage))
goal_obs, _, _, _, _, _, _ = self._storage[goal_idx]
goal_obs = unstack_goal(goal_obs)
obs_t[idx] = stack_obs_goal(unstack_goal(obs_t[idx]), goal_obs)
obs_tp1[idx] = stack_obs_goal(unstack_goal(obs_tp1[idx]), goal_obs)
rewards[idx] = self.reward_func.get_reward(unstack_goal(obs_t[idx]), actions[idx], goal_obs)

return obs_t, actions, rewards, obs_tp1, dones


def make_her_buffer(buffer_class, reward_func, num_sample_goals):
"""
Creates the Hindsight Experience Replay Buffer for HER
:param buffer_class: (HERBuffer) the buffer type you wish to use
:param reward_func: (HERRewardFunctions) The reward function to apply to the buffer
:param num_sample_goals: (int) the number of goals to sample for every step
"""
class HindsightExperienceReplayBuffer(ReplayBuffer):
"""
Replay Buffer for HER
assert issubclass(buffer_class, HERBuffer), "Error: the buffer type must be of type HERBuffer."

:param size: (int) The size of the buffer
"""
class _Buffer(buffer_class):
def __init__(self, size):
super(HindsightExperienceReplayBuffer, self).__init__(size=size)
self.reward_func = reward_func
self.num_sample_goals = num_sample_goals

def add(self, obs_t, action, reward, obs_tp1, done):
super().add(obs_t, action, reward, obs_tp1, done)

start = None
length = 0
for i in range(1, self._maxsize + 1):
# walk backwards to know the range of the episode
step = (self._next_idx - i) % (self._maxsize - 1)

if step == 0 and len(self._storage) < (self._maxsize - 1):
start = 0
length = self._next_idx
break
elif self._storage[step][4]: # if end of episode
start = step
length = (start + self._next_idx) % (self._maxsize - 1)

if start is None:
start = 0
length = self._maxsize

# sample goals randomly within the current episode
goals = [self._storage
[(np.random.randint(0, length) + start) % (self._maxsize - 1)]
[0]
[:obs_t.shape[-1] // 2]
for _ in range(num_sample_goals)]

# stack the new goals to the current obs, obs+1 and recalculate the reward
for goal in goals:
obs_t_, obs_tp1_, reward_ = \
(stack_obs_goal(unstack_goal(obs_t), goal),
stack_obs_goal(unstack_goal(obs_tp1), goal),
reward_func.get_reward(unstack_goal(obs_t), action, goal))
super().add(obs_t_, action, reward_, obs_tp1_, done)

return HindsightExperienceReplayBuffer
super(_Buffer, self).__init__(size, reward_func, num_sample_goals)

return _Buffer
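As a standalone numeric illustration of the relabelling ratio shared by the three strategies above (a sketch, not part of the commit): with num_sample_goals = k, a sampled transition has its goal replaced with probability 1 - 1/(1 + k), i.e. 0.8 for the default k = 4.

import numpy as np

num_sample_goals = 4                               # k, as in the buffers above
batch_size = 8
relabel_proba = 1 - (1. / (1 + num_sample_goals))  # 0.8 for k = 4
relabel_mask = np.random.uniform(size=batch_size) < relabel_proba

# 'future' strategy: for a transition taken at step t of an episode of length T,
# the replacement goal is drawn uniformly from the remaining steps t..T-1
episode_length, t = 50, 12
future_goal_step = t + np.random.randint(episode_length - t)

print(relabel_proba, relabel_mask, future_goal_step)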
