-
Notifications
You must be signed in to change notification settings - Fork 4.4k
encapsulate the agent mapping operations #3481
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,7 +3,12 @@ | |
| import numpy as np | ||
|
|
||
| from gym import spaces | ||
| from gym_unity.envs import UnityEnv, UnityGymException | ||
| from gym_unity.envs import ( | ||
| UnityEnv, | ||
| UnityGymException, | ||
| AgentIdIndexMapper, | ||
| AgentIdIndexMapperSlow, | ||
| ) | ||
| from mlagents_envs.base_env import AgentGroupSpec, ActionType, BatchedStepResult | ||
|
|
||
|
|
||
|
|
@@ -183,3 +188,33 @@ def setup_mock_unityenvironment(mock_env, mock_spec, mock_result): | |
| mock_env.return_value.get_agent_groups.return_value = ["MockBrain"] | ||
| mock_env.return_value.get_agent_group_spec.return_value = mock_spec | ||
| mock_env.return_value.get_step_result.return_value = mock_result | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("mapper_cls", [AgentIdIndexMapper, AgentIdIndexMapperSlow]) | ||
| def test_agent_id_index_mapper(mapper_cls): | ||
| mapper = mapper_cls() | ||
| initial_agent_ids = [1001, 1002, 1003, 1004] | ||
| mapper.set_initial_agents(initial_agent_ids) | ||
|
|
||
| # Mark some agents as done with their last rewards. | ||
| mapper.mark_agent_done(1001, 42.0) | ||
| mapper.mark_agent_done(1004, 1337.0) | ||
|
|
||
| # Now add new agents, and get the rewards of the agent they replaced. | ||
| old_reward1 = mapper.register_new_agent_id(2001) | ||
| old_reward2 = mapper.register_new_agent_id(2002) | ||
|
|
||
| # Order of the rewards don't matter | ||
| assert {old_reward1, old_reward2} == {42.0, 1337.0} | ||
|
|
||
| new_agent_ids = [1002, 1003, 2001, 2002] | ||
| permutation = mapper.get_id_permutation(new_agent_ids) | ||
| # Make sure it's actually a permutation - needs to contain 0..N-1 with no repeats. | ||
| assert set(permutation) == set(range(0, 4)) | ||
|
|
||
| # For initial agents that were in the initial group, they need to be in the same slot. | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @vincentpierre Is my claim here correct about the new agents? Order shouldn't matter as long as it's consistent, right?
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, although anywhere might be confusing. They must replace previously done agents |
||
| # Agents that were added later can be anywhere. | ||
| permuted_ids = [new_agent_ids[i] for i in permutation] | ||
| for idx, agent_id in enumerate(initial_agent_ids): | ||
| if agent_id in permuted_ids: | ||
| assert permuted_ids[idx] == agent_id | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's pretty cool.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you mean
pytest.mark.parametrize, or passing class types as parameters?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
both