Skip to content

Commit

Permalink
Rename the observations variable in the evaluation util to avoid shad…
Browse files Browse the repository at this point in the history
…owing (#1288)

* Rename the observations variable in the evaluation util to avoid shadowing

This enables a callback in evaluate_policy to have access to the
observation vector that is fed to the environment step function,
which is currently shadowed by the output observation.

* Update changelog

* Add test

* Move assignment outside of the loop

---------

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Co-authored-by: Antonin Raffin <antonin.raffin@dlr.de>
  • Loading branch information
3 people committed Apr 11, 2023
1 parent 84f5511 commit 4232f9d
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 3 deletions.
30 changes: 29 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,34 @@
Changelog
==========

Release 1.8.1a0 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit)

New Features:
^^^^^^^^^^^^^

`SB3-Contrib`_
^^^^^^^^^^^^^^

`RL Zoo`_
^^^^^^^^^

Bug Fixes:
^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^


Release 1.8.0 (2023-04-07)
--------------------------
Expand Down Expand Up @@ -1271,4 +1299,4 @@ And all the contributors:
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit
4 changes: 3 additions & 1 deletion stable_baselines3/common/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def evaluate_policy(
episode_starts = np.ones((env.num_envs,), dtype=bool)
while (episode_counts < episode_count_targets).any():
actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic)
observations, rewards, dones, infos = env.step(actions)
new_observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards
current_lengths += 1
for i in range(n_envs):
Expand Down Expand Up @@ -120,6 +120,8 @@ def evaluate_policy(
current_rewards[i] = 0
current_lengths[i] = 0

observations = new_observations

if render:
env.render()

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.8.0
1.8.1a0
4 changes: 4 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def test_evaluate_policy(direct_policy: bool):

def dummy_callback(locals_, _globals):
locals_["model"].n_callback_calls += 1
assert "observations" in locals_
assert "new_observations" in locals_
assert locals_["new_observations"] is not locals_["observations"]
assert not np.allclose(locals_["new_observations"], locals_["observations"])

assert model.policy is not None
policy = model.policy if direct_policy else model
Expand Down

0 comments on commit 4232f9d

Please sign in to comment.