HER now passes info dictionary to compute_reward (#802)
* HER now passes the info dict to compute_reward. Updated SAC to match.

* Updated DQN, TD3, and DDPG to match the HER changes. Fixed a bug introduced by the previous commit.

* Moved the HER check to the off-policy base class to avoid duplicating code.

* Removed a leftover import.

* Fixed a circular import.

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Tirafesi and araffin committed May 30, 2020
1 parent 257ab9c commit c20df90
Showing 7 changed files with 39 additions and 11 deletions.
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
@@ -16,6 +16,7 @@ New Features:
^^^^^^^^^^^^^
- Added momentum parameter to A2C for the embedded RMSPropOptimizer (@kantneel)
- ActionNoise is now an abstract base class and implements ``__call__``, ``NormalActionNoise`` and ``OrnsteinUhlenbeckActionNoise`` have return types (@solliet)
- HER now passes info dictionary to compute_reward, allowing for the computation of rewards that are independent of the goal (@tirafesi)

Bug Fixes:
^^^^^^^^^^
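Not part of this commit, but as a minimal sketch of what the changelog entry above enables: a custom goal-based environment can now include a goal-independent term in its reward, recovered from the info dict that HER forwards to compute_reward. The environment class and the 'action_norm' key below are hypothetical, and the sketch assumes the gym.GoalEnv interface used by stable-baselines at the time.

```python
import numpy as np
import gym


class ReachWithControlCostEnv(gym.GoalEnv):
    """Hypothetical goal-based env whose reward has a goal-independent term."""

    def compute_reward(self, achieved_goal, desired_goal, info):
        # Sparse goal-reaching term, as in typical GoalEnv implementations
        goal_reward = -float(np.linalg.norm(achieved_goal - desired_goal) >= 0.05)
        # Goal-independent control cost recovered from the info dict;
        # HER can now pass the stored per-step info here when it recomputes
        # rewards for relabelled (imagined) transitions.
        control_cost = info.get('action_norm', 0.0) if isinstance(info, dict) else 0.0
        return goal_reward - 0.1 * control_cost
```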
25 changes: 25 additions & 0 deletions stable_baselines/common/base_class.py
@@ -979,6 +979,31 @@ def __init__(self, policy, env, replay_buffer=None, _init_setup_model=False, ver

self.replay_buffer = replay_buffer

def is_using_her(self) -> bool:
"""
Check if the model is using HER
:return: (bool) Whether the model is using HER or not
"""
# Avoid circular import
from stable_baselines.her.replay_buffer import HindsightExperienceReplayWrapper
return isinstance(self.replay_buffer, HindsightExperienceReplayWrapper)

def replay_buffer_add(self, obs_t, action, reward, obs_tp1, done, info):
"""
Add a new transition to the replay buffer
:param obs_t: (np.ndarray) the last observation
:param action: ([float]) the action
:param reward: (float) the reward of the transition
:param obs_tp1: (np.ndarray) the new observation
:param done: (bool) is the episode done
:param info: (dict) extra values used to compute the reward when using HER
"""
# Pass info dict when using HER, as it can be used to compute the reward
kwargs = dict(info=info) if self.is_using_her() else {}
self.replay_buffer.add(obs_t, action, reward, obs_tp1, float(done), **kwargs)

@abstractmethod
def setup_model(self):
pass
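For illustration (not part of the diff), here is a standalone sketch of the dispatch that the new replay_buffer_add helper performs, using stand-in buffer classes instead of the real ReplayBuffer and HindsightExperienceReplayWrapper:

```python
class PlainBuffer:
    """Stand-in for the plain replay buffer, whose add() takes no info."""
    def add(self, obs_t, action, reward, obs_tp1, done):
        print("stored without info")


class HERWrapper:
    """Stand-in for HindsightExperienceReplayWrapper, which needs the info dict."""
    def add(self, obs_t, action, reward, obs_tp1, done, info):
        print("stored with info:", info)


def replay_buffer_add(buffer, obs_t, action, reward, obs_tp1, done, info):
    # Pass info only to the HER wrapper, keeping the plain buffer's
    # add() signature unchanged.
    kwargs = dict(info=info) if isinstance(buffer, HERWrapper) else {}
    buffer.add(obs_t, action, reward, obs_tp1, float(done), **kwargs)


replay_buffer_add(PlainBuffer(), None, None, 0.0, None, False, {"is_success": True})
replay_buffer_add(HERWrapper(), None, None, 0.0, None, False, {"is_success": True})
```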
7 changes: 4 additions & 3 deletions stable_baselines/ddpg/ddpg.py
@@ -625,7 +625,7 @@ def _policy(self, obs, apply_noise=True, compute_q=True):
action = np.clip(action, -1, 1)
return action, q_value

def _store_transition(self, obs, action, reward, next_obs, done):
def _store_transition(self, obs, action, reward, next_obs, done, info):
"""
Store a transition in the replay buffer
@@ -634,9 +634,10 @@ def _store_transition(self, obs, action, reward, next_obs, done):
:param reward: (float) the reward
:param next_obs: ([float] or [int]) the next observation
:param done: (bool) Whether the episode is over
:param info: (dict) extra values used to compute reward when using HER
"""
reward *= self.reward_scale
self.replay_buffer.add(obs, action, reward, next_obs, float(done))
self.replay_buffer_add(obs, action, reward, next_obs, done, info)
if self.normalize_observations:
self.obs_rms.update(np.array([obs]))

@@ -915,7 +916,7 @@ def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="D
# Avoid changing the original ones
obs_, new_obs_, reward_ = obs, new_obs, reward

self._store_transition(obs_, action, reward_, new_obs_, done)
self._store_transition(obs_, action, reward_, new_obs_, done, info)
obs = new_obs
# Save the unnormalized observation
if self._vec_normalize_env is not None:
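For context, the info argument threaded through _store_transition above is the dictionary returned by env.step(). A generic collection loop, sketched here with the classic Gym API and a random policy rather than the actual DDPG internals, shows where it comes from:

```python
import gym

env = gym.make("CartPole-v1")
transitions = []

obs = env.reset()
for _ in range(100):
    action = env.action_space.sample()  # stand-in for the learned policy
    new_obs, reward, done, info = env.step(action)
    # `info` is the per-step dict from the environment; the commit threads
    # exactly this dict through _store_transition / replay_buffer_add.
    transitions.append((obs, action, reward, new_obs, float(done), info))
    obs = env.reset() if done else new_obs
```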
2 changes: 1 addition & 1 deletion stable_baselines/deepq/dqn.py
@@ -232,7 +232,7 @@ def learn(self, total_timesteps, callback=None, log_interval=100, tb_log_name="D
# Avoid changing the original ones
obs_, new_obs_, reward_ = obs, new_obs, rew
# Store transition in the replay buffer.
self.replay_buffer.add(obs_, action, reward_, new_obs_, float(done))
self.replay_buffer_add(obs_, action, reward_, new_obs_, done, info)
obs = new_obs
# Save the unnormalized observation
if self._vec_normalize_env is not None:
11 changes: 6 additions & 5 deletions stable_baselines/her/replay_buffer.py
@@ -60,7 +60,7 @@ def __init__(self, replay_buffer, n_sampled_goal, goal_selection_strategy, wrapp
self.episode_transitions = []
self.replay_buffer = replay_buffer

def add(self, obs_t, action, reward, obs_tp1, done):
def add(self, obs_t, action, reward, obs_tp1, done, info):
"""
Add a new transition to the buffer
@@ -69,10 +69,11 @@ def add(self, obs_t, action, reward, obs_tp1, done):
:param reward: (float) the reward of the transition
:param obs_tp1: (np.ndarray) the new observation
:param done: (bool) is the episode done
:param info: (dict) extra values used to compute reward
"""
assert self.replay_buffer is not None
# Update current episode buffer
self.episode_transitions.append((obs_t, action, reward, obs_tp1, done))
self.episode_transitions.append((obs_t, action, reward, obs_tp1, done, info))
if done:
# Add transitions (and imagined ones) to buffer only when an episode is over
self._store_episode()
@@ -146,7 +147,7 @@ def _store_episode(self):
# create a set of artificial transitions
for transition_idx, transition in enumerate(self.episode_transitions):

obs_t, action, reward, obs_tp1, done = transition
obs_t, action, reward, obs_tp1, done, info = transition

# Add to the replay buffer
self.replay_buffer.add(obs_t, action, reward, obs_tp1, done)
@@ -162,7 +163,7 @@ def _store_episode(self):
# For each sampled goals, store a new transition
for goal in sampled_goals:
# Copy transition to avoid modifying the original one
obs, action, reward, next_obs, done = copy.deepcopy(transition)
obs, action, reward, next_obs, done, info = copy.deepcopy(transition)

# Convert concatenated obs to dict, so we can update the goals
obs_dict, next_obs_dict = map(self.env.convert_obs_to_dict, (obs, next_obs))
@@ -172,7 +173,7 @@ def _store_episode(self):
next_obs_dict['desired_goal'] = goal

# Update the reward according to the new desired goal
reward = self.env.compute_reward(next_obs_dict['achieved_goal'], goal, None)
reward = self.env.compute_reward(next_obs_dict['achieved_goal'], goal, info)
# Can we use achieved_goal == desired_goal?
done = False

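End to end, nothing changes for users of the default sparse rewards, but goal-based environments whose compute_reward consumes info now behave correctly under HER relabelling. A usage sketch assuming the stable-baselines 2.x HER API (BitFlippingEnv ships with the library; exact constructor keywords may differ between versions):

```python
from stable_baselines import HER, SAC
from stable_baselines.common.bit_flipping_env import BitFlippingEnv

# Goal-based toy environment bundled with stable-baselines
# (continuous actions so it can be paired with SAC)
env = BitFlippingEnv(10, continuous=True, max_steps=10)

# HER wraps SAC's replay buffer with HindsightExperienceReplayWrapper, so the
# info dict stored at collection time is reused when relabelled rewards are
# recomputed via env.compute_reward(achieved_goal, goal, info).
model = HER('MlpPolicy', env, SAC, n_sampled_goal=4,
            goal_selection_strategy='future', verbose=1)
model.learn(total_timesteps=5000)
```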
2 changes: 1 addition & 1 deletion stable_baselines/sac/sac.py
@@ -425,7 +425,7 @@ def learn(self, total_timesteps, callback=None,
obs_, new_obs_, reward_ = obs, new_obs, reward

# Store transition in the replay buffer.
self.replay_buffer.add(obs_, action, reward_, new_obs_, float(done))
self.replay_buffer_add(obs_, action, reward_, new_obs_, done, info)
obs = new_obs
# Save the unnormalized observation
if self._vec_normalize_env is not None:
2 changes: 1 addition & 1 deletion stable_baselines/td3/td3.py
@@ -345,7 +345,7 @@ def learn(self, total_timesteps, callback=None,
obs_, new_obs_, reward_ = obs, new_obs, reward

# Store transition in the replay buffer.
self.replay_buffer.add(obs_, action, reward_, new_obs_, float(done))
self.replay_buffer_add(obs_, action, reward_, new_obs_, done, info)
obs = new_obs
# Save the unnormalized observation
if self._vec_normalize_env is not None:
