From c0cb54de62864db87beaf77fbd236fbb4d9b67f0 Mon Sep 17 00:00:00 2001
From: Ruo-Ping Dong
Date: Thu, 6 Aug 2020 18:11:26 -0700
Subject: [PATCH 1/2] fix discrete export

---
 ml-agents/mlagents/trainers/torch/distributions.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py
index c83ae4649e..6b57453430 100644
--- a/ml-agents/mlagents/trainers/torch/distributions.py
+++ b/ml-agents/mlagents/trainers/torch/distributions.py
@@ -100,7 +100,8 @@ def sample(self):
         return torch.multinomial(self.probs, 1)
 
     def pdf(self, value):
-        return torch.diag(self.probs.T[value.flatten().long()])
+        idx = torch.tensor(range(len(value))).unsqueeze(-1)
+        return torch.gather(self.probs.permute(1, 0)[value.flatten().long()], -1, idx).squeeze(-1)
 
     def log_prob(self, value):
         return torch.log(self.pdf(value))

From c66b6f11675d43df423a8fdda850aa4734062c94 Mon Sep 17 00:00:00 2001
From: "Ruo-Ping (Rachel) Dong"
Date: Fri, 7 Aug 2020 16:31:06 -0700
Subject: [PATCH 2/2] Update ml-agents/mlagents/trainers/torch/distributions.py

Co-authored-by: Ervin T.
---
Review note (not part of the commit message; text here is ignored by
`git am`): the suggested `torch.range(end=len(value))` is replaced below
with `torch.arange(start=0, end=len(value))`. torch.range is deprecated
and includes its end point, so it would produce len(value) + 1 indices,
one more row than the tensor handed to torch.gather. torch.arange is
end-exclusive and yields exactly one index per entry of `value`.
Assumes `value` holds one action per row -- TODO confirm with the caller.
The indentation inside both hunks was reconstructed from a
whitespace-collapsed copy of this series; check it against the target
file before applying. The post-image blob id on the `index` line below
predates this correction.

 ml-agents/mlagents/trainers/torch/distributions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py
index 6b57453430..0d7c9278e2 100644
--- a/ml-agents/mlagents/trainers/torch/distributions.py
+++ b/ml-agents/mlagents/trainers/torch/distributions.py
@@ -100,7 +100,7 @@ def sample(self):
         return torch.multinomial(self.probs, 1)
 
     def pdf(self, value):
-        idx = torch.tensor(range(len(value))).unsqueeze(-1)
+        idx = torch.arange(start=0, end=len(value)).unsqueeze(-1)
         return torch.gather(self.probs.permute(1, 0)[value.flatten().long()], -1, idx).squeeze(-1)
 
     def log_prob(self, value):