From a16999b2b5b850da2472c9fe62daebf945171adc Mon Sep 17 00:00:00 2001 From: Ruo-Ping Dong Date: Wed, 16 Dec 2020 20:36:14 -0500 Subject: [PATCH] Fix model inference issue with Barracuda v1.2.1 (#4766) Co-authored-by: Ervin T. --- ml-agents/mlagents/trainers/torch/distributions.py | 8 +++++--- ml-agents/mlagents/trainers/torch/networks.py | 14 ++++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/ml-agents/mlagents/trainers/torch/distributions.py b/ml-agents/mlagents/trainers/torch/distributions.py index e5b44e8550..b3e7b34f3c 100644 --- a/ml-agents/mlagents/trainers/torch/distributions.py +++ b/ml-agents/mlagents/trainers/torch/distributions.py @@ -173,9 +173,11 @@ def forward(self, inputs: torch.Tensor) -> List[DistInstance]: log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2) else: # Expand so that entropy matches batch size. Note that we're using - # torch.cat here instead of torch.expand() becuase it is not supported in the - # verified version of Barracuda (1.0.2). - log_sigma = torch.cat([self.log_sigma] * inputs.shape[0], axis=0) + # mu*0 here to get the batch size implicitly since Barracuda 1.2.1 + # throws an error on runtime broadcasting for an unknown reason. We + # use this to replace torch.expand() because it is not supported in + # the verified version of Barracuda (1.0.X). 
+ log_sigma = mu * 0 + self.log_sigma if self.tanh_squash: return TanhGaussianDistInstance(mu, torch.exp(log_sigma)) else: diff --git a/ml-agents/mlagents/trainers/torch/networks.py b/ml-agents/mlagents/trainers/torch/networks.py index 2308b91d7a..ed82c1f2ca 100644 --- a/ml-agents/mlagents/trainers/torch/networks.py +++ b/ml-agents/mlagents/trainers/torch/networks.py @@ -258,9 +258,11 @@ def __init__( ): super().__init__() self.action_spec = action_spec - self.version_number = torch.nn.Parameter(torch.Tensor([2.0])) + self.version_number = torch.nn.Parameter( + torch.Tensor([2.0]), requires_grad=False + ) self.is_continuous_int_deprecated = torch.nn.Parameter( - torch.Tensor([int(self.action_spec.is_continuous())]) + torch.Tensor([int(self.action_spec.is_continuous())]), requires_grad=False ) self.continuous_act_size_vector = torch.nn.Parameter( torch.Tensor([int(self.action_spec.continuous_size)]), requires_grad=False @@ -283,6 +285,9 @@ def __init__( self.encoding_size = network_settings.memory.memory_size // 2 else: self.encoding_size = network_settings.hidden_units + self.memory_size_vector = torch.nn.Parameter( + torch.Tensor([int(self.network_body.memory_size)]), requires_grad=False + ) self.action_model = ActionModel( self.encoding_size, @@ -335,10 +340,7 @@ def forward( disc_action_out, action_out_deprecated, ) = self.action_model.get_action_out(encoding, masks) - export_out = [ - self.version_number, - torch.Tensor([self.network_body.memory_size]), - ] + export_out = [self.version_number, self.memory_size_vector] if self.action_spec.continuous_size > 0: export_out += [cont_action_out, self.continuous_act_size_vector] if self.action_spec.discrete_size > 0: