128 changes: 100 additions & 28 deletions gym-unity/gym_unity/envs/__init__.py
@@ -75,8 +75,7 @@ def __init__(
         self.visual_obs = None
         self._n_agents = -1

-        self._gym_id_order: List[int] = []
-        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+        self.agent_mapper = AgentIdIndexMapper()

         # Save the step result from the last time all Agents requested decisions.
         self._previous_step_result: BatchedStepResult = None
@@ -124,7 +123,7 @@ def __init__(
         step_result = self._env.get_step_result(self.brain_name)
         self._check_agents(step_result.n_agents())
         self._previous_step_result = step_result
-        self._gym_id_order = list(self._previous_step_result.agent_id)
+        self.agent_mapper.set_initial_agents(list(self._previous_step_result.agent_id))

         # Set observation and action spaces
         if self.group_spec.is_action_discrete():
@@ -379,35 +378,25 @@ def _sanitize_info(self, step_result: BatchedStepResult) -> BatchedStepResult:

         for index, agent_id in enumerate(step_result.agent_id):
             if step_result.done[index]:
-                gym_index = self._gym_id_order.index(agent_id)
-                self._done_agents_index_to_last_reward[gym_index] = step_result.reward[
-                    index
-                ]
-                self._gym_id_order[gym_index] = -1  # no agent at that index
+                self.agent_mapper.mark_agent_done(agent_id, step_result.reward[index])

         # Set the new AgentDone flags to True
         # Note that the corresponding agent_id that gets marked done will be different
         # than the original agent that was done, but this is OK since the gym interface
         # only cares about the ordering.
         for index, agent_id in enumerate(step_result.agent_id):
             if not self._previous_step_result.contains_agent(agent_id):
-                # insert the id in the current id list:
-                original_index = self._gym_id_order.index(-1)
-                self._gym_id_order[original_index] = agent_id
-                # This is a new agent
+                # Register this agent, and get the reward of the previous agent that
+                # was in its index, so that we can return it to the gym.
+                last_reward = self.agent_mapper.register_new_agent_id(agent_id)
                 step_result.done[index] = True
-                # The index of the agent among not-done agents is
-                step_result.reward[index] = self._done_agents_index_to_last_reward[
-                    original_index
-                ]
-        self._done_agents_index_to_last_reward = {}
+                step_result.reward[index] = last_reward

         self._previous_step_result = step_result  # store the new original

-        new_id_order = []
-        agent_id_list = list(step_result.agent_id)
-        for agent_id in self._gym_id_order:
-            new_id_order.append(agent_id_list.index(agent_id))
+        # Get a permutation of the agent IDs so that a given ID stays in the same
+        # index as where it was first seen.
+        new_id_order = self.agent_mapper.get_id_permutation(list(step_result.agent_id))

         _mask: Optional[List[np.array]] = None
         if step_result.action_mask is not None:
@@ -430,12 +419,9 @@ def _sanitize_action(self, action: np.array) -> np.array:
         sanitized_action = np.zeros(
             (self._previous_step_result.n_agents(), self.group_spec.action_size)
         )
-        agent_id_to_gym_index = {
-            agent_id: gym_index for gym_index, agent_id in enumerate(self._gym_id_order)
-        }
         for index, agent_id in enumerate(self._previous_step_result.agent_id):
             if not self._previous_step_result.done[index]:
-                array_index = agent_id_to_gym_index[agent_id]
+                array_index = self.agent_mapper.get_gym_index(agent_id)
                 sanitized_action[index, :] = action[array_index, :]
         return sanitized_action

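To see the scatter that _sanitize_action performs in isolation, here is a toy sketch (my own illustration with made-up values, not part of the diff): gym supplies one action row per gym slot, and the wrapper rewrites them into the order the current step lists its agents.

    import numpy as np

    gym_actions = np.array([[0.1], [0.2], [0.3]])  # one row per gym slot
    step_agent_slots = [2, 0, 1]  # gym slot of each agent, in this step's order
    sanitized = np.zeros_like(gym_actions)
    for step_index, slot in enumerate(step_agent_slots):
        sanitized[step_index, :] = gym_actions[slot, :]
    assert (sanitized == np.array([[0.3], [0.1], [0.2]])).all()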
@@ -456,9 +442,7 @@ def _step(self, needs_reset: bool = False) -> BatchedStepResult:
+ "Some agents did not request decisions at the same time."
)
for agent_id, reward in zip(info.agent_id, info.reward):
gym_index = self._gym_id_order.index(agent_id)
self._done_agents_index_to_last_reward[gym_index] = reward
self._gym_id_order[gym_index] = -1 # no agent at that index
self.agent_mapper.mark_agent_done(agent_id, reward)

self._env.step()
info = self._env.get_step_result(self.brain_name)
@@ -526,3 +510,91 @@ def lookup_action(self, action):
         :return: The List containing the branched actions.
         """
         return self.action_lookup[action]
+
+
+class AgentIdIndexMapper:
+    def __init__(self) -> None:
+        self._agent_id_to_gym_index: Dict[int, int] = {}
+        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+
+    def set_initial_agents(self, agent_ids: List[int]) -> None:
+        """
+        Provide the initial list of agent ids for the mapper.
+        """
+        for idx, agent_id in enumerate(agent_ids):
+            self._agent_id_to_gym_index[agent_id] = idx
+
+    def mark_agent_done(self, agent_id: int, reward: float) -> None:
+        """
+        Declare the agent done with the corresponding final reward.
+        """
+        gym_index = self._agent_id_to_gym_index.pop(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+
+    def register_new_agent_id(self, agent_id: int) -> float:
+        """
+        Adds the new agent ID and returns the reward to use for the previous agent in this index.
+        """
+        # Any free index is OK here.
+        free_index, last_reward = self._done_agents_index_to_last_reward.popitem()
+        self._agent_id_to_gym_index[agent_id] = free_index
+        return last_reward
+
+    def get_id_permutation(self, agent_ids: List[int]) -> List[int]:
+        """
+        Get the permutation from new agent ids to the order that preserves the positions of previous agents.
+        The result is a list with each integer from 0 to len(agent_ids)-1 appearing exactly once.
+        """
+        # Map the new agent ids to their index.
+        new_agent_ids_to_index = {
+            agent_id: idx for idx, agent_id in enumerate(agent_ids)
+        }
+
+        # Make the output list. We don't write to it sequentially, so start with dummy values.
+        new_permutation = [-1] * len(agent_ids)
+
+        # For each agent ID, find the new index of the agent, and write it in the original index.
+        for agent_id, original_index in self._agent_id_to_gym_index.items():
+            new_permutation[original_index] = new_agent_ids_to_index[agent_id]
+        return new_permutation
+
+    def get_gym_index(self, agent_id: int) -> int:
+        """
+        Get the gym index for the current agent.
+        """
+        return self._agent_id_to_gym_index[agent_id]
+
+
+class AgentIdIndexMapperSlow:
+    """
+    Reference implementation of AgentIdIndexMapper.
+    The operations are O(N^2), so it shouldn't be used for large numbers of agents.
+    See AgentIdIndexMapper for method descriptions.
+    """
+
+    def __init__(self) -> None:
+        self._gym_id_order: List[int] = []
+        self._done_agents_index_to_last_reward: Dict[int, float] = {}
+
+    def set_initial_agents(self, agent_ids: List[int]) -> None:
+        self._gym_id_order = list(agent_ids)
+
+    def mark_agent_done(self, agent_id: int, reward: float) -> None:
+        gym_index = self._gym_id_order.index(agent_id)
+        self._done_agents_index_to_last_reward[gym_index] = reward
+        self._gym_id_order[gym_index] = -1  # no agent at that index
+
+    def register_new_agent_id(self, agent_id: int) -> float:
+        original_index = self._gym_id_order.index(-1)
+        self._gym_id_order[original_index] = agent_id
+        reward = self._done_agents_index_to_last_reward.pop(original_index)
+        return reward
+
+    def get_id_permutation(self, agent_ids: List[int]) -> List[int]:
+        new_id_order = []
+        for agent_id in self._gym_id_order:
+            new_id_order.append(agent_ids.index(agent_id))
+        return new_id_order
+
+    def get_gym_index(self, agent_id: int) -> int:
+        return self._gym_id_order.index(agent_id)
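To make the new bookkeeping concrete, here is a minimal usage sketch of the mapper (my own illustration against the class above, not part of the diff; the agent ids and reward are made up). It walks one agent swap end to end, including why the new agent inherits the old agent's final reward:

    from gym_unity.envs import AgentIdIndexMapper

    mapper = AgentIdIndexMapper()
    mapper.set_initial_agents([10, 20, 30])  # gym slots: 10 -> 0, 20 -> 1, 30 -> 2

    # Agent 20 finishes with reward 5.0; its slot is parked until a new agent arrives.
    mapper.mark_agent_done(20, 5.0)

    # Agent 40 spawns, takes the free slot, and inherits 20's final reward so the
    # wrapper can report it to gym (gym only tracks slots, not agent ids).
    assert mapper.register_new_agent_id(40) == 5.0

    # The next step lists ids as [10, 30, 40]; the permutation restores slot order.
    permutation = mapper.get_id_permutation([10, 30, 40])
    assert permutation == [0, 2, 1]
    assert [[10, 30, 40][i] for i in permutation] == [10, 40, 30]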
37 changes: 36 additions & 1 deletion gym-unity/gym_unity/tests/test_gym.py
@@ -3,7 +3,12 @@
 import numpy as np

 from gym import spaces
-from gym_unity.envs import UnityEnv, UnityGymException
+from gym_unity.envs import (
+    UnityEnv,
+    UnityGymException,
+    AgentIdIndexMapper,
+    AgentIdIndexMapperSlow,
+)
 from mlagents_envs.base_env import AgentGroupSpec, ActionType, BatchedStepResult


@@ -183,3 +188,33 @@ def setup_mock_unityenvironment(mock_env, mock_spec, mock_result):
     mock_env.return_value.get_agent_groups.return_value = ["MockBrain"]
     mock_env.return_value.get_agent_group_spec.return_value = mock_spec
     mock_env.return_value.get_step_result.return_value = mock_result
+
+
+@pytest.mark.parametrize("mapper_cls", [AgentIdIndexMapper, AgentIdIndexMapperSlow])
Contributor: That's pretty cool.

Contributor Author: Do you mean pytest.mark.parametrize, or passing class types as parameters?

Contributor: both
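As an aside, a minimal sketch of the pattern being discussed (FastImpl and SlowImpl are made-up stand-ins, not from this PR): parametrize generates one test per class object, and the test body instantiates whichever class it receives.

    import pytest

    class FastImpl:
        def double(self, x: int) -> int:
            return x * 2

    class SlowImpl:
        def double(self, x: int) -> int:
            return x + x

    @pytest.mark.parametrize("impl_cls", [FastImpl, SlowImpl])
    def test_double(impl_cls):
        assert impl_cls().double(3) == 6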

+def test_agent_id_index_mapper(mapper_cls):
+    mapper = mapper_cls()
+    initial_agent_ids = [1001, 1002, 1003, 1004]
+    mapper.set_initial_agents(initial_agent_ids)
+
+    # Mark some agents as done with their last rewards.
+    mapper.mark_agent_done(1001, 42.0)
+    mapper.mark_agent_done(1004, 1337.0)
+
+    # Now add new agents, and get the rewards of the agents they replaced.
+    old_reward1 = mapper.register_new_agent_id(2001)
+    old_reward2 = mapper.register_new_agent_id(2002)
+
+    # The order of the rewards doesn't matter.
+    assert {old_reward1, old_reward2} == {42.0, 1337.0}
+
+    new_agent_ids = [1002, 1003, 2001, 2002]
+    permutation = mapper.get_id_permutation(new_agent_ids)
+    # Make sure it's actually a permutation - it needs to contain 0..N-1 with no repeats.
+    assert set(permutation) == set(range(0, 4))
+
+    # Agents that were in the initial group must stay in the same slot.
Contributor Author: @vincentpierre Is my claim here correct about the new agents? Order shouldn't matter as long as it's consistent, right?

Contributor: Yes, although "anywhere" might be confusing. They must replace previously done agents.
+    # Agents that were added later can be anywhere.
+    permuted_ids = [new_agent_ids[i] for i in permutation]
+    for idx, agent_id in enumerate(initial_agent_ids):
+        if agent_id in permuted_ids:
+            assert permuted_ids[idx] == agent_id
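One detail worth calling out: the two implementations are free to return different, equally valid permutations for the same input, which is why the test checks set membership and slot stability rather than an exact list. A quick trace of the test data (my own, assuming CPython's LIFO dict.popitem() ordering):

    from gym_unity.envs import AgentIdIndexMapper, AgentIdIndexMapperSlow

    for mapper_cls in (AgentIdIndexMapper, AgentIdIndexMapperSlow):
        mapper = mapper_cls()
        mapper.set_initial_agents([1001, 1002, 1003, 1004])
        mapper.mark_agent_done(1001, 42.0)
        mapper.mark_agent_done(1004, 1337.0)
        mapper.register_new_agent_id(2001)
        mapper.register_new_agent_id(2002)
        print(mapper_cls.__name__, mapper.get_id_permutation([1002, 1003, 2001, 2002]))
    # Hand-traced output:
    #   AgentIdIndexMapper [3, 0, 1, 2]
    #   AgentIdIndexMapperSlow [2, 0, 1, 3]
    # Both keep 1002 in slot 1 and 1003 in slot 2, which is all the test requires.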