Commit

Fix partial minibatch computation in GAIL dataset. (#724)
* Fix partial minibatch computation in GAIL dataset.

* Updated changelog.

* Added name to bottom of changelog.
richardwu committed Mar 6, 2020
1 parent a4efff0 commit ac46c37
Showing 2 changed files with 7 additions and 6 deletions.
11 changes: 6 additions & 5 deletions docs/misc/changelog.rst
@@ -22,16 +22,16 @@ Breaking Changes:

- Algorithms no longer import from each other, and ``common`` does not import from algorithms.
- ``a2c/utils.py`` removed and split into other files:
- common/tf_util.py: ``sample``, ``calc_entropy``, ``mse``, ``avg_norm``, ``total_episode_reward_logger``,
``q_explained_variance``, ``gradient_add``, ``avg_norm``, ``check_shape``,

- common/tf_util.py: ``sample``, ``calc_entropy``, ``mse``, ``avg_norm``, ``total_episode_reward_logger``,
``q_explained_variance``, ``gradient_add``, ``avg_norm``, ``check_shape``,
``seq_to_batch``, ``batch_to_seq``.
- common/tf_layers.py: ``conv``, ``linear``, ``lstm``, ``_ln``, ``lnlstm``, ``conv_to_fc``, ``ortho_init``.
- a2c/a2c.py: ``discount_with_dones``.
- acer/acer_simple.py: ``get_by_index``, ``EpisodeStats``.
- common/schedules.py: ``constant``, ``linear_schedule``, ``middle_drop``, ``double_linear_con``, ``double_middle_drop``,
``SCHEDULES``, ``Scheduler``.

- ``trpo_mpi/utils.py`` functions moved (``traj_segment_generator`` moved to ``common/runners.py``, ``flatten_lists`` to ``common/misc_util.py``).
- ``ppo2/ppo2.py`` functions moved (``safe_mean`` to ``common/math_util.py``, ``constfn`` and ``get_schedule_fn`` to ``common/schedules.py``).
- ``sac/policies.py`` function ``mlp`` moved to ``common/tf_layers.py``.
@@ -69,6 +69,7 @@ Bug Fixes:
- Fixed a bug in ``BaseRLModel`` when seeding vectorized environments. (@NeoExtended)
- Fixed ``num_timesteps`` computation to be consistent between algorithms (updated after ``env.step()``)
Only ``TRPO`` and ``PPO1`` update it differently (after synchronization) because they rely on MPI
- Fixed partial minibatch computation in ExpertDataset (@richardwu)

Deprecations:
^^^^^^^^^^^^^
@@ -652,4 +653,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk
 @XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
 @Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
 @MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
-@flodorner @KuKuXia @NeoExtended @solliet @mmcenta
+@flodorner @KuKuXia @NeoExtended @solliet @mmcenta @richardwu
2 changes: 1 addition & 1 deletion stable_baselines/gail/dataset/dataset.py
@@ -204,7 +204,7 @@ def __init__(self, indices, observations, actions, batch_size, n_workers=1,
         self.n_minibatches = len(indices) // batch_size
         # Add a partial minibatch, for instance
         # when there is not enough samples
-        if partial_minibatch and len(indices) / batch_size > 0:
+        if partial_minibatch and len(indices) % batch_size > 0:
             self.n_minibatches += 1
         self.batch_size = batch_size
         self.observations = observations
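
The effect of replacing `/` with `%` in the hunk above is easiest to see with small numbers. The following is a minimal standalone sketch, not code from the repository: the helper names `count_minibatches_old` and `count_minibatches_fixed` are hypothetical and only mirror the two versions of the condition shown in the diff.

```python
def count_minibatches_old(n_samples, batch_size, partial_minibatch=True):
    """Hypothetical helper mirroring the condition before the fix."""
    n_minibatches = n_samples // batch_size
    # True division is positive whenever there is at least one sample,
    # so an extra minibatch is added even when the data divides evenly.
    if partial_minibatch and n_samples / batch_size > 0:
        n_minibatches += 1
    return n_minibatches


def count_minibatches_fixed(n_samples, batch_size, partial_minibatch=True):
    """Hypothetical helper mirroring the condition after the fix."""
    n_minibatches = n_samples // batch_size
    # The remainder is positive only when some samples are left over
    # after the full minibatches, which is the partial-minibatch case.
    if partial_minibatch and n_samples % batch_size > 0:
        n_minibatches += 1
    return n_minibatches


if __name__ == "__main__":
    for n_samples in (100, 105, 5):
        print(n_samples,
              count_minibatches_old(n_samples, 10),
              count_minibatches_fixed(n_samples, 10))
```

With 100 samples and a batch size of 10, the old condition counts one minibatch too many, while the fixed condition counts exactly the full minibatches. When the sample count is not a multiple of the batch size, both versions agree.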