diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py
index c83ae4649e..0d7c9278e2 100644
--- a/ml-agents/mlagents/trainers/torch/distributions.py
+++ b/ml-agents/mlagents/trainers/torch/distributions.py
@@ -100,7 +100,13 @@ def sample(self):
         return torch.multinomial(self.probs, 1)
 
     def pdf(self, value):
-        return torch.diag(self.probs.T[value.flatten().long()])
+        # Select probs[i, value[i]] for every batch row i without torch.diag.
+        # torch.arange gives int64 indices 0..len(value)-1 on the same device
+        # as probs; the deprecated torch.range includes the endpoint and
+        # returns floats, which torch.gather does not accept as an index.
+        idx = torch.arange(len(value), device=self.probs.device).unsqueeze(-1)
+        return torch.gather(
+            self.probs.permute(1, 0)[value.flatten().long()], -1, idx
+        ).squeeze(-1)
 
     def log_prob(self, value):
         return torch.log(self.pdf(value))