From da9cc71f85ab9ceea98092c2fd867e62706bb691 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 10 Aug 2020 10:48:38 -0700 Subject: [PATCH 1/3] Fix test and replace range with arange --- .../mlagents/trainers/tests/torch/test_distributions.py | 4 ++-- ml-agents/mlagents/trainers/torch/distributions.py | 6 ++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_distributions.py b/ml-agents/mlagents/trainers/tests/torch/test_distributions.py index b2c6afaf0f..68f0eb272c 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_distributions.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_distributions.py @@ -125,13 +125,13 @@ def test_categorical_dist_instance(): torch.manual_seed(0) act_size = 4 test_prob = torch.tensor( - [1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1) + [[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)] ) # High prob for first action dist_instance = CategoricalDistInstance(test_prob) for _ in range(10): action = dist_instance.sample() - assert action.shape == (1,) + assert action.shape == (1, 1) assert action < act_size # Make sure the first action as higher probability than the others. diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py index 4c8bdecc52..cfe25de1e1 100644 --- a/ml-agents/mlagents/trainers/torch/distributions.py +++ b/ml-agents/mlagents/trainers/torch/distributions.py @@ -100,8 +100,10 @@ def sample(self): return torch.multinomial(self.probs, 1) def pdf(self, value): - idx = torch.range(end=len(value)).unsqueeze(-1) - return torch.gather(self.probs.permute(1, 0)[value.flatten().long()], -1, idx).squeeze(-1) + idx = torch.arange(start=0, end=len(value)).unsqueeze(-1) + return torch.gather( + self.probs.permute(1, 0)[value.flatten().long()], -1, idx + ).squeeze(-1) def log_prob(self, value): return torch.log(self.pdf(value)) From 8807a706500cfed791f71ad2a2faf727ddc28dbd Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 10 Aug 2020 11:03:07 -0700 Subject: [PATCH 2/3] Added comment --- ml-agents/mlagents/trainers/torch/distributions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py index cfe25de1e1..570460e36d 100644 --- a/ml-agents/mlagents/trainers/torch/distributions.py +++ b/ml-agents/mlagents/trainers/torch/distributions.py @@ -100,6 +100,8 @@ def sample(self): return torch.multinomial(self.probs, 1) def pdf(self, value): + # This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]), + # but torch.diag is not supported by ONNX export. idx = torch.arange(start=0, end=len(value)).unsqueeze(-1) return torch.gather( self.probs.permute(1, 0)[value.flatten().long()], -1, idx From 35b976749897b5566a7a37dd4588cba4726d5e91 Mon Sep 17 00:00:00 2001 From: Ervin Teng Date: Mon, 10 Aug 2020 12:05:01 -0700 Subject: [PATCH 3/3] Fix util test --- ml-agents/mlagents/trainers/tests/torch/test_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ml-agents/mlagents/trainers/tests/torch/test_utils.py b/ml-agents/mlagents/trainers/tests/torch/test_utils.py index cad04067fa..70306a89f3 100644 --- a/ml-agents/mlagents/trainers/tests/torch/test_utils.py +++ b/ml-agents/mlagents/trainers/tests/torch/test_utils.py @@ -187,14 +187,14 @@ def test_get_probs_and_entropy(): # Add two dists to the list. act_size = 2 test_prob = torch.tensor( - [1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1) + [[1.0 - 0.1 * (act_size - 1)] + [0.1] * (act_size - 1)] ) # High prob for first action dist_list = [CategoricalDistInstance(test_prob), CategoricalDistInstance(test_prob)] action_list = [torch.tensor([0]), torch.tensor([1])] log_probs, entropies, all_probs = ModelUtils.get_probs_and_entropy( action_list, dist_list ) - assert all_probs.shape == (len(dist_list * act_size),) - assert entropies.shape == (len(dist_list),) + assert all_probs.shape == (1, len(dist_list * act_size)) + assert entropies.shape == (1, len(dist_list)) # Make sure the first action has high probability than the others. assert log_probs.flatten()[0] > log_probs.flatten()[1]