Skip to content

Commit

Permalink
Fix - sample type inconsistency in (Multi)Categorical Probability Dis…
Browse files Browse the repository at this point in the history
…tribution (#588)

* Fix - sample type inconsistency in CategoricalProbabilityDistribution

* Adding info on fix to changelog.

* Fix - sample type inconsistency (change sample type of CategoricalProbabilityDistribution, MultiCategoricalProbabilityDistribution to tf.int64)

* Change dtype of actions to int64 of ACER

* Update changelog.rst
  • Loading branch information
seheevic authored and araffin committed Dec 2, 2019
1 parent 04c35e1 commit 6039b89
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 8 deletions.
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Breaking Changes:
- `allow_early_resets` of the `Monitor` wrapper now default to `True`
- `make_atari_env` now returns a `DummyVecEnv` by default (instead of a `SubprocVecEnv`)
this usually improves performance.
- Fix inconsistency of sample type, so that mode/sample function returns tensor of tf.int64 in CategoricalProbabilityDistribution/MultiCategoricalProbabilityDistribution (@seheevic)

New Features:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -546,4 +547,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic
2 changes: 1 addition & 1 deletion stable_baselines/a2c/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ def get_by_index(input_tensor, idx):
"""
assert len(input_tensor.get_shape()) == 2
assert len(idx.get_shape()) == 1
idx_flattened = tf.range(0, input_tensor.shape[0]) * input_tensor.shape[1] + idx
idx_flattened = tf.range(0, input_tensor.shape[0], dtype=idx.dtype) * input_tensor.shape[1] + idx
offset_tensor = tf.gather(tf.reshape(input_tensor, [-1]), # flatten input
idx_flattened) # use flattened indices
return offset_tensor
Expand Down
4 changes: 2 additions & 2 deletions stable_baselines/acer/acer_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def run(self):
"""
Run a step leaning of the model
:return: ([float], [float], [float], [float], [float], [bool], [float])
:return: ([float], [float], [int64], [float], [float], [bool], [float])
encoded observation, observations, actions, rewards, mus, dones, masks
"""
enc_obs = [self.obs]
Expand Down Expand Up @@ -666,7 +666,7 @@ def run(self):

enc_obs = np.asarray(enc_obs, dtype=self.obs_dtype).swapaxes(1, 0)
mb_obs = np.asarray(mb_obs, dtype=self.obs_dtype).swapaxes(1, 0)
mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0)
mb_actions = np.asarray(mb_actions, dtype=np.int64).swapaxes(1, 0)
mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
mb_mus = np.asarray(mb_mus, dtype=np.float32).swapaxes(1, 0)
mb_dones = np.asarray(mb_dones, dtype=np.bool).swapaxes(1, 0)
Expand Down
8 changes: 4 additions & 4 deletions stable_baselines/common/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def sample_shape(self):
return []

def sample_dtype(self):
return tf.int32
return tf.int64


class MultiCategoricalProbabilityDistributionType(ProbabilityDistributionType):
Expand Down Expand Up @@ -211,7 +211,7 @@ def sample_shape(self):
return [len(self.n_vec)]

def sample_dtype(self):
return tf.int32
return tf.int64


class DiagGaussianProbabilityDistributionType(ProbabilityDistributionType):
Expand Down Expand Up @@ -353,7 +353,7 @@ def flatparam(self):
return self.flat

def mode(self):
return tf.cast(tf.stack([p.mode() for p in self.categoricals], axis=-1), tf.int32)
return tf.stack([p.mode() for p in self.categoricals], axis=-1)

def neglogp(self, x):
return tf.add_n([p.neglogp(px) for p, px in zip(self.categoricals, tf.unstack(x, axis=-1))])
Expand All @@ -365,7 +365,7 @@ def entropy(self):
return tf.add_n([p.entropy() for p in self.categoricals])

def sample(self):
return tf.cast(tf.stack([p.sample() for p in self.categoricals], axis=-1), tf.int32)
return tf.stack([p.sample() for p in self.categoricals], axis=-1)

@classmethod
def fromflat(cls, flat):
Expand Down

0 comments on commit 6039b89

Please sign in to comment.