diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py index 3ad4274a90..56d3c984b2 100644 --- a/ml-agents/mlagents/trainers/torch/distributions.py +++ b/ml-agents/mlagents/trainers/torch/distributions.py @@ -112,13 +112,13 @@ def pdf(self, value): ).squeeze(-1) def log_prob(self, value): - return torch.log(self.pdf(value)) + return torch.log(self.pdf(value) + EPSILON) def all_log_prob(self): - return torch.log(self.probs) + return torch.log(self.probs + EPSILON) def entropy(self): - return -torch.sum(self.probs * torch.log(self.probs), dim=-1) + return -torch.sum(self.probs * torch.log(self.probs + EPSILON), dim=-1) class GaussianDistribution(nn.Module): @@ -187,10 +187,13 @@ def _create_policy_branches(self, hidden_size: int) -> nn.ModuleList: return nn.ModuleList(branches) def _mask_branch(self, logits: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: - raw_probs = torch.nn.functional.softmax(logits, dim=-1) * mask - normalized_probs = raw_probs / torch.sum(raw_probs, dim=-1).unsqueeze(-1) - normalized_logits = torch.log(normalized_probs + EPSILON) - return normalized_logits + # Zero out masked logits, then subtract a large value. Technique mentioned here: + # https://arxiv.org/abs/2006.14171. Our implementation is ONNX and Barracuda-friendly. + flipped_mask = 1.0 - mask + adj_logits = logits * mask - 1e8 * flipped_mask + probs = torch.nn.functional.softmax(adj_logits, dim=-1) + log_probs = torch.log(probs + EPSILON) + return log_probs def _split_masks(self, masks: torch.Tensor) -> List[torch.Tensor]: split_masks = []