diff --git a/gym-unity/gym_unity/envs/__init__.py b/gym-unity/gym_unity/envs/__init__.py
index 3f6edd2ae1..c7da3af9b9 100644
--- a/gym-unity/gym_unity/envs/__init__.py
+++ b/gym-unity/gym_unity/envs/__init__.py
@@ -75,8 +75,7 @@ def __init__(
         self.visual_obs = None
         self._n_agents = -1
 
-        self._gym_id_order: List[int] = []
-        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+        self.agent_mapper = AgentIdIndexMapper()
 
         # Save the step result from the last time all Agents requested decisions.
         self._previous_step_result: BatchedStepResult = None
@@ -124,7 +123,7 @@ def __init__(
         step_result = self._env.get_step_result(self.brain_name)
         self._check_agents(step_result.n_agents())
         self._previous_step_result = step_result
-        self._gym_id_order = list(self._previous_step_result.agent_id)
+        self.agent_mapper.set_initial_agents(list(self._previous_step_result.agent_id))
 
         # Set observation and action spaces
         if self.group_spec.is_action_discrete():
@@ -379,11 +378,7 @@ def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
 
         for index, agent_id in enumerate(step_result.agent_id):
             if step_result.done[index]:
-                gym_index = self._gym_id_order.index(agent_id)
-                self._done_agents_index_to_last_reward[gym_index] = step_result.reward[
-                    index
-                ]
-                self._gym_id_order[gym_index] = -1  # no agent at that index
+                self.agent_mapper.mark_agent_done(agent_id, step_result.reward[index])
 
         # Set the new AgentDone flags to True
         # Note that the corresponding agent_id that gets marked done will be different
@@ -391,23 +386,17 @@ def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:
         # only cares about the ordering.
         for index, agent_id in enumerate(step_result.agent_id):
             if not self._previous_step_result.contains_agent(agent_id):
-                # insert the id in the current id list:
-                original_index = self._gym_id_order.index(-1)
-                self._gym_id_order[original_index] = agent_id
-                # This is a new agent
+                # Register this agent, and get the reward of the previous agent that
+                # was in its index, so that we can return it to the gym.
+                last_reward = self.agent_mapper.register_new_agent_id(agent_id)
                 step_result.done[index] = True
-                # The index of the agent among not-done agents is
-                step_result.reward[index] = self._done_agents_index_to_last_reward[
-                    original_index
-                ]
-        self._done_agents_index_to_last_reward = {}
+                step_result.reward[index] = last_reward
 
         self._previous_step_result = step_result  # store the new original
 
-        new_id_order = []
-        agent_id_list = list(step_result.agent_id)
-        for agent_id in self._gym_id_order:
-            new_id_order.append(agent_id_list.index(agent_id))
+        # Get a permutation of the agent IDs so that a given ID stays in the same
+        # index as where it was first seen.
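+        # For example (illustrative IDs): if agent 1001 was first seen at gym index 0
+        # and the next step arrives ordered [2002, 1001], the permutation is [1, 0];
+        # indexing the step arrays with it puts 1001 back at gym index 0.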
+        new_id_order = self.agent_mapper.get_id_permutation(list(step_result.agent_id))
 
         _mask: Optional[List[np.array]] = None
         if step_result.action_mask is not None:
@@ -430,12 +419,9 @@ def _sanitize_action(self, action: np.array) -> np.array:
 
         sanitized_action = np.zeros(
             (self._previous_step_result.n_agents(), self.group_spec.action_size)
         )
-        agent_id_to_gym_index = {
-            agent_id: gym_index for gym_index, agent_id in enumerate(self._gym_id_order)
-        }
         for index, agent_id in enumerate(self._previous_step_result.agent_id):
             if not self._previous_step_result.done[index]:
-                array_index = agent_id_to_gym_index[agent_id]
+                array_index = self.agent_mapper.get_gym_index(agent_id)
                 sanitized_action[index, :] = action[array_index, :]
         return sanitized_action
@@ -456,9 +442,7 @@ def _step(self, needs_reset: bool = False) -> BatchedStepResult:
                     + "Some agents did not request decisions at the same time."
                 )
             for agent_id, reward in zip(info.agent_id, info.reward):
-                gym_index = self._gym_id_order.index(agent_id)
-                self._done_agents_index_to_last_reward[gym_index] = reward
-                self._gym_id_order[gym_index] = -1  # no agent at that index
+                self.agent_mapper.mark_agent_done(agent_id, reward)
 
             self._env.step()
             info = self._env.get_step_result(self.brain_name)
@@ -526,3 +510,91 @@ def lookup_action(self, action):
         :return: The List containing the branched actions.
         """
         return self.action_lookup[action]
+
+
+class AgentIdIndexMapper:
+    def __init__(self) -> None:
+        self._agent_id_to_gym_index: Dict[int, int] = {}
+        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+
+    def set_initial_agents(self, agent_ids: List[int]) -> None:
+        """
+        Provide the initial list of agent ids for the mapper.
+        """
+        for idx, agent_id in enumerate(agent_ids):
+            self._agent_id_to_gym_index[agent_id] = idx
+
+    def mark_agent_done(self, agent_id: int, reward: float) -> None:
+        """
+        Declare the agent done with the corresponding final reward.
+        """
+        gym_index = self._agent_id_to_gym_index.pop(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+
+    def register_new_agent_id(self, agent_id: int) -> float:
+        """
+        Adds the new agent ID and returns the reward to use for the previous agent in this index.
+        """
+        # Any free index is OK here.
+        free_index, last_reward = self._done_agents_index_to_last_reward.popitem()
+        self._agent_id_to_gym_index[agent_id] = free_index
+        return last_reward
+
+    def get_id_permutation(self, agent_ids: List[int]) -> List[int]:
+        """
+        Get the permutation from new agent ids to the order that preserves the positions of previous agents.
+        The result is a list with each integer from 0 to len(agent_ids)-1 appearing exactly once.
+        """
+        # Map the new agent ids to their index.
+        new_agent_ids_to_index = {
+            agent_id: idx for idx, agent_id in enumerate(agent_ids)
+        }
+
+        # Make the output list. We don't write to it sequentially, so start with dummy values.
+        new_permutation = [-1] * len(agent_ids)
+
+        # For each agent ID, find the new index of the agent, and write it in the original index.
+        for agent_id, original_index in self._agent_id_to_gym_index.items():
+            new_permutation[original_index] = new_agent_ids_to_index[agent_id]
+        return new_permutation
+
+    def get_gym_index(self, agent_id: int) -> int:
+        """
+        Get the gym index for the current agent.
+        """
+        return self._agent_id_to_gym_index[agent_id]
+
+
+class AgentIdIndexMapperSlow:
+    """
+    Reference implementation of AgentIdIndexMapper.
+    The operations are O(N^2), so it shouldn't be used for large numbers of agents.
+    See AgentIdIndexMapper for method descriptions.
+    """
+
+    def __init__(self) -> None:
+        self._gym_id_order: List[int] = []
+        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+
+    def set_initial_agents(self, agent_ids: List[int]) -> None:
+        self._gym_id_order = list(agent_ids)
+
+    def mark_agent_done(self, agent_id: int, reward: float) -> None:
+        gym_index = self._gym_id_order.index(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+        self._gym_id_order[gym_index] = -1
+
+    def register_new_agent_id(self, agent_id: int) -> float:
+        original_index = self._gym_id_order.index(-1)
+        self._gym_id_order[original_index] = agent_id
+        reward = self._done_agents_index_to_last_reward.pop(original_index)
+        return reward
+
+    def get_id_permutation(self, agent_ids):
+        new_id_order = []
+        for agent_id in self._gym_id_order:
+            new_id_order.append(agent_ids.index(agent_id))
+        return new_id_order
+
+    def get_gym_index(self, agent_id: int) -> int:
+        return self._gym_id_order.index(agent_id)
diff --git a/gym-unity/gym_unity/tests/test_gym.py b/gym-unity/gym_unity/tests/test_gym.py
index 29f6182222..e2015dc64c 100644
--- a/gym-unity/gym_unity/tests/test_gym.py
+++ b/gym-unity/gym_unity/tests/test_gym.py
@@ -3,7 +3,12 @@
 import numpy as np
 from gym import spaces
 
-from gym_unity.envs import UnityEnv, UnityGymException
+from gym_unity.envs import (
+    UnityEnv,
+    UnityGymException,
+    AgentIdIndexMapper,
+    AgentIdIndexMapperSlow,
+)
 from mlagents_envs.base_env import AgentGroupSpec, ActionType, BatchedStepResult
 
 
@@ -183,3 +188,33 @@ def setup_mock_unityenvironment(mock_env, mock_spec, mock_result):
     mock_env.return_value.get_agent_groups.return_value = ["MockBrain"]
     mock_env.return_value.get_agent_group_spec.return_value = mock_spec
     mock_env.return_value.get_step_result.return_value = mock_result
+
+
+@pytest.mark.parametrize("mapper_cls", [AgentIdIndexMapper, AgentIdIndexMapperSlow])
+def test_agent_id_index_mapper(mapper_cls):
+    mapper = mapper_cls()
+    initial_agent_ids = [1001, 1002, 1003, 1004]
+    mapper.set_initial_agents(initial_agent_ids)
+
+    # Mark some agents as done with their last rewards.
+    mapper.mark_agent_done(1001, 42.0)
+    mapper.mark_agent_done(1004, 1337.0)
+
+    # Now add new agents, and get the rewards of the agents they replaced.
+    old_reward1 = mapper.register_new_agent_id(2001)
+    old_reward2 = mapper.register_new_agent_id(2002)
+
+    # The order of the rewards doesn't matter.
+    assert {old_reward1, old_reward2} == {42.0, 1337.0}
+
+    new_agent_ids = [1002, 1003, 2001, 2002]
+    permutation = mapper.get_id_permutation(new_agent_ids)
+    # Make sure it's actually a permutation - it needs to contain 0..N-1 with no repeats.
+    assert set(permutation) == set(range(0, 4))
+
+    # Agents that were in the initial group need to stay in the same slot.
+    # Agents that were added later can be anywhere.
+    permuted_ids = [new_agent_ids[i] for i in permutation]
+    for idx, agent_id in enumerate(initial_agent_ids):
+        if agent_id in permuted_ids:
+            assert permuted_ids[idx] == agent_id
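
For reference, a minimal standalone sketch (not part of the patch; the agent IDs and reward values below are made up) of the call sequence the wrapper performs on the mapper when one agent is replaced between steps:

    from gym_unity.envs import AgentIdIndexMapper

    mapper = AgentIdIndexMapper()
    mapper.set_initial_agents([1001, 1002])  # gym indices: 1001 -> 0, 1002 -> 1

    # Agent 1002 finishes its episode; remember its final reward (_sanitize_info / _step).
    mapper.mark_agent_done(1002, 3.5)

    # Unity spawns agent 2002 in its place; the gym reports 1002's last reward for it.
    assert mapper.register_new_agent_id(2002) == 3.5

    # The next step arrives ordered [2002, 1001]; permute it back to the original slots.
    assert mapper.get_id_permutation([2002, 1001]) == [1, 0]

    # Actions coming back from the gym are routed by gym index (_sanitize_action).
    assert mapper.get_gym_index(1001) == 0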