From 1d894898c1772c779b1018c9ae98e8caae238747 Mon Sep 17 00:00:00 2001 From: vincentpierre Date: Tue, 18 Aug 2020 15:52:21 -0700 Subject: [PATCH] Fix of the test for multi visual input --- ml-agents/mlagents/trainers/policy/policy.py | 5 ----- .../trainers/torch/model_serialization.py | 18 +++++++++--------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/ml-agents/mlagents/trainers/policy/policy.py b/ml-agents/mlagents/trainers/policy/policy.py index ec6756c84b..0e23f10ec8 100644 --- a/ml-agents/mlagents/trainers/policy/policy.py +++ b/ml-agents/mlagents/trainers/policy/policy.py @@ -43,11 +43,6 @@ def __init__( self.vis_obs_size = sum( 1 for shape in behavior_spec.observation_shapes if len(shape) == 3 ) - self.vis_obs_shape = ( - [shape for shape in behavior_spec.observation_shapes if len(shape) == 3][0] - if self.vis_obs_size > 0 - else None - ) self.use_continuous_act = behavior_spec.is_action_continuous() self.num_branches = self.behavior_spec.action_size self.previous_action_dict: Dict[str, np.array] = {} diff --git a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch/model_serialization.py index 5e1f16cdec..e664166495 100644 --- a/ml-agents/mlagents/trainers/torch/model_serialization.py +++ b/ml-agents/mlagents/trainers/torch/model_serialization.py @@ -13,15 +13,15 @@ def __init__(self, policy): self.policy = policy batch_dim = [1] dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])] - dummy_vis_obs = ( - [torch.zeros(batch_dim + list(self.policy.vis_obs_shape))] - if self.policy.vis_obs_size > 0 - else [] - ) + dummy_vis_obs = [ + torch.zeros(batch_dim + list(shape)) + for shape in self.policy.behavior_spec.observation_shapes + if len(shape) == 3 + ] dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)]) dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size]) - # Need to pass all posslible inputs since currently keyword arguments is not + # Need to pass all possible inputs since currently keyword arguments is not # supported by torch.nn.export() self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories) @@ -32,9 +32,9 @@ def __init__(self, policy): if self.policy.use_vec_obs: self.input_names.append("vector_observation") self.dynamic_axes.update({"vector_observation": {0: "batch"}}) - if self.policy.use_vis_obs: - self.input_names.append("visual_observation") - self.dynamic_axes.update({"visual_observation": {0: "batch"}}) + for i in range(self.policy.vis_obs_size): + self.input_names.append(f"visual_observation_{i}") + self.dynamic_axes.update({f"visual_observation_{i}": {0: "batch"}}) if not self.policy.use_continuous_act: self.input_names.append("action_masks") self.dynamic_axes.update({"action_masks": {0: "batch"}})