Add additional logic to avoid load being called on every advance #4934

Merged 3 commits on Feb 10, 2021
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -69,6 +69,7 @@ and this project adheres to
   while waiting for a connection, and raises a better error message if it crashes. (#4880)
 - Passing a `-logfile` option in the `--env-args` option to `mlagents-learn` is
   no longer overwritten. (#4880)
+- The `load_weights` function was being called unnecessarily often in the Ghost Trainer leading to training slowdowns. (#4934)
 
 
 ## [1.7.2-preview] - 2020-12-22
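To make the changelog entry above concrete, here is a minimal sketch of the pattern it describes. The names (`SketchPolicy`, `advance_before`, `advance_after`) are hypothetical, not ml-agents APIs; the assumption is that the cost comes from copying network weights on every call.

```python
# Hypothetical sketch of the pattern behind the #4934 entry; not ml-agents APIs.
from typing import List


class SketchPolicy:
    def __init__(self) -> None:
        self.weights: List[float] = []

    def load_weights(self, values: List[float]) -> None:
        # For a real network this copies every weight tensor, which adds up
        # when it happens on every training step.
        self.weights = list(values)


def advance_before(policy: SketchPolicy, snapshot: List[float]) -> None:
    policy.load_weights(snapshot)  # reloaded on every advance, needed or not


def advance_after(policy: SketchPolicy, snapshot: List[float], team_changed: bool) -> None:
    if team_changed:  # reload only when the learning team actually switches
        policy.load_weights(snapshot)
```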
40 changes: 28 additions & 12 deletions ml-agents/mlagents/trainers/ghost/trainer.py
@@ -247,25 +247,19 @@ def advance(self) -> None:
 
         next_learning_team = self.controller.get_learning_team
 
-        # CASE 1: Current learning team is managed by this GhostTrainer.
-        # If the learning team changes, the following loop over queues will push the
-        # new policy into the policy queue for the new learning agent if
-        # that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
-        # CASE 2: Current learning team is managed by a different GhostTrainer.
-        # If the learning team changes to a team managed by this GhostTrainer, this loop
-        # will push the current_snapshot into the correct queue. Otherwise,
-        # it will continue skipping and swap_snapshot will continue to handle
-        # pushing fixed snapshots
-        # Case 3: No team change. The if statement just continues to push the policy
+        # Case 1: No team change. The if statement just continues to push the policy
         # into the correct queue (or not if not learning team).
         for brain_name in self._internal_policy_queues:
             internal_policy_queue = self._internal_policy_queues[brain_name]
             try:
                 policy = internal_policy_queue.get_nowait()
                 self.current_policy_snapshot[brain_name] = policy.get_weights()
             except AgentManagerQueue.Empty:
-                pass
-            if next_learning_team in self._team_to_name_to_policy_queue:
+                continue
+            if (
+                self._learning_team == next_learning_team
+                and next_learning_team in self._team_to_name_to_policy_queue
+            ):
                 name_to_policy_queue = self._team_to_name_to_policy_queue[
                     next_learning_team
                 ]
@@ -277,6 +271,28 @@ def advance(self) -> None:
                 policy.load_weights(self.current_policy_snapshot[brain_name])
                 name_to_policy_queue[brain_name].put(policy)
 
+        # CASE 2: Current learning team is managed by this GhostTrainer.
+        # If the learning team changes, the following loop over queues will push the
+        # new policy into the policy queue for the new learning agent if
+        # that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
+        # CASE 3: Current learning team is managed by a different GhostTrainer.
+        # If the learning team changes to a team managed by this GhostTrainer, this loop
+        # will push the current_snapshot into the correct queue. Otherwise,
+        # it will continue skipping and swap_snapshot will continue to handle
+        # pushing fixed snapshots
+        if (
+            self._learning_team != next_learning_team
+            and next_learning_team in self._team_to_name_to_policy_queue
+        ):
+            name_to_policy_queue = self._team_to_name_to_policy_queue[
+                next_learning_team
+            ]
+            for brain_name in name_to_policy_queue:
+                behavior_id = create_name_behavior_id(brain_name, next_learning_team)
+                policy = self.get_policy(behavior_id)
+                policy.load_weights(self.current_policy_snapshot[brain_name])
+                name_to_policy_queue[brain_name].put(policy)
+
         # Note save and swap should be on different step counters.
         # We don't want to save unless the policy is learning.
         if self.get_step - self.last_save > self.steps_between_save:
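To summarize the reworked `advance()` logic shown in the diff above, here is a self-contained toy, not the ml-agents implementation: `ToyPolicy` and `GhostLike` are hypothetical stand-ins, queue wiring and snapshot swapping are simplified, and only the branching that decides when `load_weights` runs mirrors the change.

```python
# Hypothetical stand-ins for the ghost-trainer bookkeeping; not ml-agents code.
from queue import Empty, Queue
from typing import Dict, List


class ToyPolicy:
    def __init__(self, weights: List[float]) -> None:
        self.weights = weights
        self.load_calls = 0

    def get_weights(self) -> List[float]:
        return list(self.weights)

    def load_weights(self, values: List[float]) -> None:
        self.load_calls += 1  # counts how often the expensive copy happens
        self.weights = list(values)


class GhostLike:
    """Pushes a policy to a per-team queue only when something actually changed."""

    def __init__(self, learning_team: int, teams: List[int]) -> None:
        self._learning_team = learning_team
        self.snapshot: List[float] = [0.0]
        self.policies: Dict[int, ToyPolicy] = {t: ToyPolicy([0.0]) for t in teams}
        self.internal_queue: "Queue[ToyPolicy]" = Queue()
        self.team_queues: Dict[int, "Queue[ToyPolicy]"] = {t: Queue() for t in teams}

    def advance(self, next_learning_team: int) -> None:
        # Case 1: no team change. Push only if a freshly trained policy arrived.
        try:
            policy = self.internal_queue.get_nowait()
        except Empty:
            pass
        else:
            self.snapshot = policy.get_weights()
            if self._learning_team == next_learning_team:
                self.team_queues[next_learning_team].put(policy)

        # Cases 2/3: the learning team switched. Load the snapshot once and push.
        if self._learning_team != next_learning_team:
            policy = self.policies[next_learning_team]
            policy.load_weights(self.snapshot)
            self.team_queues[next_learning_team].put(policy)
            self._learning_team = next_learning_team


if __name__ == "__main__":
    ghost = GhostLike(learning_team=0, teams=[0, 1])
    for _ in range(100):
        ghost.advance(next_learning_team=0)   # no change: zero load_weights calls
    ghost.advance(next_learning_team=1)       # team switch: exactly one call
    print(ghost.policies[1].load_calls)       # -> 1
```

The point of the restructuring is the same in the toy and in the diff: the full weight copy now happens once per team change (or when a fresh policy arrives), rather than once per `advance()` call.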
9 changes: 1 addition & 8 deletions ml-agents/mlagents/trainers/tests/torch/test_ghost.py
@@ -23,7 +23,7 @@ def dummy_config():
 VECTOR_ACTION_SPACE = 1
 VECTOR_OBS_SPACE = 8
 DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
-BUFFER_INIT_SAMPLES = 513
+BUFFER_INIT_SAMPLES = 10241
 NUM_AGENTS = 12
 
 
@@ -193,13 +193,6 @@ def test_publish_queue(dummy_config):
     # clear
     policy_queue1.get_nowait()
 
-    mock_specs = mb.setup_test_behavior_specs(
-        False,
-        False,
-        vector_action_space=VECTOR_ACTION_SPACE,
-        vector_obs_space=VECTOR_OBS_SPACE,
-    )
-
     buffer = mb.simulate_rollout(BUFFER_INIT_SAMPLES, mock_specs)
     # Mock out reward signal eval
     copy_buffer_fields(