From ee0b4050f0be9dc887d4fa7093e9bedcd1ba5471 Mon Sep 17 00:00:00 2001
From: Chris Elion
Date: Wed, 12 Feb 2020 10:09:18 -0800
Subject: [PATCH 1/2] rename and comments

---
 gym-unity/gym_unity/envs/__init__.py | 77 +++++++++++++++-------------
 1 file changed, 41 insertions(+), 36 deletions(-)

diff --git a/gym-unity/gym_unity/envs/__init__.py b/gym-unity/gym_unity/envs/__init__.py
index 7dac2960c2..df9a1fd4db 100644
--- a/gym-unity/gym_unity/envs/__init__.py
+++ b/gym-unity/gym_unity/envs/__init__.py
@@ -75,7 +75,8 @@ def __init__(
         self.visual_obs = None
         self._n_agents = -1
         self._done_agents: Set[int] = set()
-        self._stored_info: BatchedStepResult = None
+        # Save the step result from the last time all Agents requested decisions.
+        self._previous_step_result: BatchedStepResult = None
         self._multiagent = multiagent
         self._flattener = None
         # Hidden flag used by Atari environments to determine if the game is over
@@ -119,7 +120,7 @@ def __init__(
         self._env.reset()
         step_result = self._env.get_step_result(self.brain_name)
         self._check_agents(step_result.n_agents())
-        self._stored_info = step_result
+        self._previous_step_result = step_result
 
         # Set observation and action spaces
         if self.group_spec.is_action_discrete():
@@ -162,15 +163,15 @@ def reset(self) -> Union[List[np.ndarray], np.ndarray]:
         Returns: observation (object/list): the initial observation of the
         space.
         """
-        info = self._step(True)
-        n_agents = info.n_agents()
+        step_result = self._step(True)
+        n_agents = step_result.n_agents()
         self._check_agents(n_agents)
         self.game_over = False
 
         if not self._multiagent:
-            res: GymStepResult = self._single_step(info)
+            res: GymStepResult = self._single_step(step_result)
         else:
-            res = self._multi_step(info)
+            res = self._multi_step(step_result)
         return res[0]
 
     def step(self, action: List[Any]) -> GymStepResult:
@@ -215,17 +216,17 @@ def step(self, action: List[Any]) -> GymStepResult:
         action = self._sanitize_action(action)
 
         self._env.set_actions(self.brain_name, action)
-        info = self._step()
+        step_result = self._step()
 
-        n_agents = info.n_agents()
+        n_agents = step_result.n_agents()
         self._check_agents(n_agents)
 
         if not self._multiagent:
-            single_res = self._single_step(info)
+            single_res = self._single_step(step_result)
             self.game_over = single_res[2]
             return single_res
         else:
-            multi_res = self._multi_step(info)
+            multi_res = self._multi_step(step_result)
             self.game_over = all(multi_res[2])
             return multi_res

@@ -358,8 +359,8 @@ def _check_agents(self, n_agents: int) -> None:
                 "initialization. This is not supported."
            )
 
-    def _sanitize_info(self, info: BatchedStepResult) -> BatchedStepResult:
-        n_extra_agents = info.n_agents() - self._n_agents
+    def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
+        n_extra_agents = step_result.n_agents() - self._n_agents
         if n_extra_agents < 0 or n_extra_agents > self._n_agents:
             # In this case, some Agents did not request a decision when expected
             # or too many requested a decision
@@ -369,46 +370,48 @@ def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
 
         # remove the done Agents
         indices_to_keep: List[int] = []
-        for index in range(len(info.agent_id)):
-            if not info.done[index]:
+        for index in range(len(step_result.agent_id)):
+            if not step_result.done[index]:
                 indices_to_keep.append(index)
 
-        # set the new AgentDone flags to True
-        for index in range(len(info.agent_id)):
-            agent_id = info.agent_id[index]
-            if not self._stored_info.contains_agent(agent_id):
-                info.done[index] = True
+        # Set the new AgentDone flags to True
+        # Note that the corresponding agent_id that gets marked done will be different
+        # than the original agent that was done, but this is OK since the gym interface
+        # only cares about the ordering.
+        for index, agent_id in enumerate(step_result.agent_id):
+            if not self._previous_step_result.contains_agent(agent_id):
+                step_result.done[index] = True
             if agent_id in self._done_agents:
-                info.done[index] = True
+                step_result.done[index] = True
         self._done_agents = set()
-        self._stored_info = info  # store the new original
+        self._previous_step_result = step_result  # store the new original
 
         _mask: Optional[List[np.array]] = None
-        if info.action_mask is not None:
+        if step_result.action_mask is not None:
             _mask = []
-            for mask_index in range(len(info.action_mask)):
-                _mask.append(info.action_mask[mask_index][indices_to_keep])
+            for mask_index in range(len(step_result.action_mask)):
+                _mask.append(step_result.action_mask[mask_index][indices_to_keep])
         new_obs: List[np.array] = []
-        for obs_index in range(len(info.obs)):
-            new_obs.append(info.obs[obs_index][indices_to_keep])
+        for obs_index in range(len(step_result.obs)):
+            new_obs.append(step_result.obs[obs_index][indices_to_keep])
         return BatchedStepResult(
             obs=new_obs,
-            reward=info.reward[indices_to_keep],
-            done=info.done[indices_to_keep],
-            max_step=info.max_step[indices_to_keep],
-            agent_id=info.agent_id[indices_to_keep],
+            reward=step_result.reward[indices_to_keep],
+            done=step_result.done[indices_to_keep],
+            max_step=step_result.max_step[indices_to_keep],
+            agent_id=step_result.agent_id[indices_to_keep],
             action_mask=_mask,
         )
 
     def _sanitize_action(self, action: np.array) -> np.array:
-        if self._stored_info.n_agents() == self._n_agents:
+        if self._previous_step_result.n_agents() == self._n_agents:
             return action
         sanitized_action = np.zeros(
-            (self._stored_info.n_agents(), self.group_spec.action_size)
+            (self._previous_step_result.n_agents(), self.group_spec.action_size)
         )
         input_index = 0
-        for index in range(self._stored_info.n_agents()):
-            if not self._stored_info.done[index]:
+        for index in range(self._previous_step_result.n_agents()):
+            if not self._previous_step_result.done[index]:
                 sanitized_action[index, :] = action[input_index, :]
                 input_index = input_index + 1
         return sanitized_action
@@ -419,8 +422,10 @@ def _step(self, needs_reset: bool = False) -> BatchedStepResult:
         else:
             self._env.step()
         info = self._env.get_step_result(self.brain_name)
-        # In case some Agents raised a Done flag between steps, we re-request
-        # decisions until all agents request a real decision.
+        # Two possible cases here:
+        # 1) all agents requested decisions (some of which might be done)
+        # 2) some Agents were marked Done in between steps.
+        # In case 2, we re-request decisions until all agents request a real decision.
         while info.n_agents() - sum(info.done) < self._n_agents:
             if not info.done.all():
                 raise UnityGymException(

From e477b1a57df1f916edff103f0125fb952ba315fa Mon Sep 17 00:00:00 2001
From: Chris Elion
Date: Wed, 12 Feb 2020 10:11:36 -0800
Subject: [PATCH 2/2] enumerate

---
 gym-unity/gym_unity/envs/__init__.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/gym-unity/gym_unity/envs/__init__.py b/gym-unity/gym_unity/envs/__init__.py
index df9a1fd4db..f9cf21eab3 100644
--- a/gym-unity/gym_unity/envs/__init__.py
+++ b/gym-unity/gym_unity/envs/__init__.py
@@ -370,8 +370,8 @@ def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
 
         # remove the done Agents
         indices_to_keep: List[int] = []
-        for index in range(len(step_result.agent_id)):
-            if not step_result.done[index]:
+        for index, is_done in enumerate(step_result.done):
+            if not is_done:
                 indices_to_keep.append(index)
 
         # Set the new AgentDone flags to True
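
For readers skimming the diff, the heart of _sanitize_info is an index-filtering pass over the batched arrays: rows belonging to agents whose done flag is set are dropped, and the surviving rows keep their relative order, which is all the gym interface relies on. Below is a minimal, self-contained sketch of that idea in plain numpy; the dict of parallel arrays is only a stand-in for BatchedStepResult (same field names as in the diff), not the code being patched.

import numpy as np

# Illustrative sketch only -- not the ML-Agents implementation. A step result
# is modeled as a dict of parallel arrays, mirroring the fields used in the
# diff (obs, reward, done, agent_id).
def keep_undone_agents(step_result: dict) -> dict:
    # Same idea as indices_to_keep in _sanitize_info: keep the rows of agents
    # that did not finish this step, preserving their relative order.
    indices_to_keep = [i for i, is_done in enumerate(step_result["done"]) if not is_done]
    return {
        "obs": [obs[indices_to_keep] for obs in step_result["obs"]],
        "reward": step_result["reward"][indices_to_keep],
        "done": step_result["done"][indices_to_keep],
        "agent_id": step_result["agent_id"][indices_to_keep],
    }

# Example: agent 11 finished this step, so only agents 10 and 12 survive.
batch = {
    "obs": [np.arange(6).reshape(3, 2)],   # one observation array: 3 agents x 2 values
    "reward": np.array([1.0, 0.5, -0.2]),
    "done": np.array([False, True, False]),
    "agent_id": np.array([10, 11, 12]),
}
print(keep_undone_agents(batch)["agent_id"])  # [10 12]

The real method additionally rebuilds the per-branch action mask with the same kept indices and re-marks an agent as done when its id was absent from the previous step result, as shown in the hunks above.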