diff --git a/ml-agents/mlagents/trainers/policy/policy.py b/ml-agents/mlagents/trainers/policy/policy.py
index ec6756c84b..0e23f10ec8 100644
--- a/ml-agents/mlagents/trainers/policy/policy.py
+++ b/ml-agents/mlagents/trainers/policy/policy.py
@@ -43,11 +43,6 @@ def __init__(
         self.vis_obs_size = sum(
             1 for shape in behavior_spec.observation_shapes if len(shape) == 3
         )
-        self.vis_obs_shape = (
-            [shape for shape in behavior_spec.observation_shapes if len(shape) == 3][0]
-            if self.vis_obs_size > 0
-            else None
-        )
         self.use_continuous_act = behavior_spec.is_action_continuous()
         self.num_branches = self.behavior_spec.action_size
         self.previous_action_dict: Dict[str, np.array] = {}
diff --git a/ml-agents/mlagents/trainers/torch/model_serialization.py b/ml-agents/mlagents/trainers/torch/model_serialization.py
index 5e1f16cdec..e664166495 100644
--- a/ml-agents/mlagents/trainers/torch/model_serialization.py
+++ b/ml-agents/mlagents/trainers/torch/model_serialization.py
@@ -13,15 +13,15 @@ def __init__(self, policy):
         self.policy = policy
         batch_dim = [1]
         dummy_vec_obs = [torch.zeros(batch_dim + [self.policy.vec_obs_size])]
-        dummy_vis_obs = (
-            [torch.zeros(batch_dim + list(self.policy.vis_obs_shape))]
-            if self.policy.vis_obs_size > 0
-            else []
-        )
+        dummy_vis_obs = [
+            torch.zeros(batch_dim + list(shape))
+            for shape in self.policy.behavior_spec.observation_shapes
+            if len(shape) == 3
+        ]
         dummy_masks = torch.ones(batch_dim + [sum(self.policy.actor_critic.act_size)])
         dummy_memories = torch.zeros(batch_dim + [1] + [self.policy.m_size])
 
-        # Need to pass all posslible inputs since currently keyword arguments is not
+        # Need to pass all possible inputs since currently keyword arguments are not
         # supported by torch.nn.export()
         self.dummy_input = (dummy_vec_obs, dummy_vis_obs, dummy_masks, dummy_memories)
 
@@ -32,9 +32,9 @@ def __init__(self, policy):
         if self.policy.use_vec_obs:
             self.input_names.append("vector_observation")
             self.dynamic_axes.update({"vector_observation": {0: "batch"}})
-        if self.policy.use_vis_obs:
-            self.input_names.append("visual_observation")
-            self.dynamic_axes.update({"visual_observation": {0: "batch"}})
+        for i in range(self.policy.vis_obs_size):
+            self.input_names.append(f"visual_observation_{i}")
+            self.dynamic_axes.update({f"visual_observation_{i}": {0: "batch"}})
         if not self.policy.use_continuous_act:
             self.input_names.append("action_masks")
             self.dynamic_axes.update({"action_masks": {0: "batch"}})
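
For context, here is a minimal standalone sketch of what the patched `ModelSerializer` logic produces for an agent with more than one camera. The `obs_shapes` list is a hypothetical example, not taken from the patch; only the dummy-tensor comprehension and the per-index input naming mirror the diff above.

```python
import torch

# Hypothetical observation spec: one vector obs plus two visual (3-D) obs.
obs_shapes = [(8,), (84, 84, 3), (32, 32, 1)]

batch_dim = [1]

# Mirrors the patched comprehension: one dummy tensor per 3-D shape,
# each with a leading batch dimension of 1.
dummy_vis_obs = [
    torch.zeros(batch_dim + list(shape)) for shape in obs_shapes if len(shape) == 3
]

# Mirrors the patched naming loop: one ONNX input name per visual
# observation, each with its own dynamic batch axis.
input_names = [f"visual_observation_{i}" for i in range(len(dummy_vis_obs))]
dynamic_axes = {name: {0: "batch"} for name in input_names}

print([t.shape for t in dummy_vis_obs])
# [torch.Size([1, 84, 84, 3]), torch.Size([1, 32, 32, 1])]
print(input_names)
# ['visual_observation_0', 'visual_observation_1']
```

In the real serializer these names, axes, and dummy tensors are handed to `torch.onnx.export`, which takes example inputs positionally rather than by keyword, hence the `dummy_input` tuple. The removed code built a single dummy tensor from the first 3-D shape (`policy.vis_obs_shape`) under one `visual_observation` input name, which could not represent agents with more than one visual observation.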