
Commit c56c617
Add additional logic to avoid load being called on every advance (#4934)
andrewcoh committed Feb 10, 2021
1 parent aeedd0b commit c56c617
Showing 3 changed files with 30 additions and 20 deletions.
1 change: 1 addition & 0 deletions com.unity.ml-agents/CHANGELOG.md
@@ -71,6 +71,7 @@ and this project adheres to
while waiting for a connection, and raises a better error message if it crashes. (#4880)
- Passing a `-logfile` option in the `--env-args` option to `mlagents-learn` is
no longer overwritten. (#4880)
- The `load_weights` function was being called unnecessarily often in the Ghost Trainer leading to training slowdowns. (#4934)


## [1.7.2-preview] - 2020-12-22
40 changes: 28 additions & 12 deletions ml-agents/mlagents/trainers/ghost/trainer.py
@@ -247,25 +247,19 @@ def advance(self) -> None:

next_learning_team = self.controller.get_learning_team

# CASE 1: Current learning team is managed by this GhostTrainer.
# If the learning team changes, the following loop over queues will push the
# new policy into the policy queue for the new learning agent if
# that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
# CASE 2: Current learning team is managed by a different GhostTrainer.
# If the learning team changes to a team managed by this GhostTrainer, this loop
# will push the current_snapshot into the correct queue. Otherwise,
# it will continue skipping and swap_snapshot will continue to handle
# pushing fixed snapshots
# Case 3: No team change. The if statement just continues to push the policy
# Case 1: No team change. The if statement just continues to push the policy
# into the correct queue (or not if not learning team).
for brain_name in self._internal_policy_queues:
internal_policy_queue = self._internal_policy_queues[brain_name]
try:
policy = internal_policy_queue.get_nowait()
self.current_policy_snapshot[brain_name] = policy.get_weights()
except AgentManagerQueue.Empty:
pass
if next_learning_team in self._team_to_name_to_policy_queue:
continue
if (
self._learning_team == next_learning_team
and next_learning_team in self._team_to_name_to_policy_queue
):
name_to_policy_queue = self._team_to_name_to_policy_queue[
next_learning_team
]
@@ -277,6 +271,28 @@ def advance(self) -> None:
policy.load_weights(self.current_policy_snapshot[brain_name])
name_to_policy_queue[brain_name].put(policy)

# CASE 2: Current learning team is managed by this GhostTrainer.
# If the learning team changes, the following loop over queues will push the
# new policy into the policy queue for the new learning agent if
# that policy is managed by this GhostTrainer. Otherwise, it will save the current snapshot.
# CASE 3: Current learning team is managed by a different GhostTrainer.
# If the learning team changes to a team managed by this GhostTrainer, this loop
# will push the current_snapshot into the correct queue. Otherwise,
# it will continue skipping and swap_snapshot will continue to handle
# pushing fixed snapshots
if (
self._learning_team != next_learning_team
and next_learning_team in self._team_to_name_to_policy_queue
):
name_to_policy_queue = self._team_to_name_to_policy_queue[
next_learning_team
]
for brain_name in name_to_policy_queue:
behavior_id = create_name_behavior_id(brain_name, next_learning_team)
policy = self.get_policy(behavior_id)
policy.load_weights(self.current_policy_snapshot[brain_name])
name_to_policy_queue[brain_name].put(policy)

# Note save and swap should be on different step counters.
# We don't want to save unless the policy is learning.
if self.get_step - self.last_save > self.steps_between_save:
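For context, below is a minimal, self-contained sketch of the pattern this diff introduces: `load_weights` is now invoked only when the learning team actually changes, rather than on every call to `advance`. The class and variable names (`MiniGhostTrainer`, `MockPolicy`) are hypothetical stand-ins for illustration, not ml-agents classes.

# Minimal sketch, assuming simplified names; not the actual GhostTrainer code.

class MockPolicy:
    def __init__(self):
        self.load_count = 0

    def load_weights(self, weights):
        # The real policy copies a weight snapshot into its network;
        # here we only count how often that copy would happen.
        self.load_count += 1


class MiniGhostTrainer:
    def __init__(self, policy, snapshot):
        self._policy = policy
        self._snapshot = snapshot
        self._learning_team = 0

    def advance(self, next_learning_team):
        # Before this commit the snapshot was pushed (and load_weights called)
        # whenever the learning team's policy queue belonged to this trainer,
        # i.e. on essentially every advance. The equality check below limits
        # the load to actual team changes.
        if self._learning_team != next_learning_team:
            self._policy.load_weights(self._snapshot)
        self._learning_team = next_learning_team


trainer = MiniGhostTrainer(MockPolicy(), snapshot={})
for step in range(1000):
    trainer.advance(next_learning_team=step // 250)  # team swaps every 250 steps
print(trainer._policy.load_count)  # 3 loads instead of up to 1000

With this guard in place, the number of weight loads scales with the number of team swaps rather than the number of trainer steps, which is the slowdown the CHANGELOG entry refers to.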
9 changes: 1 addition & 8 deletions ml-agents/mlagents/trainers/tests/torch/test_ghost.py
@@ -23,7 +23,7 @@ def dummy_config():
VECTOR_ACTION_SPACE = 1
VECTOR_OBS_SPACE = 8
DISCRETE_ACTION_SPACE = [3, 3, 3, 2]
BUFFER_INIT_SAMPLES = 513
BUFFER_INIT_SAMPLES = 10241
NUM_AGENTS = 12


@@ -193,13 +193,6 @@ def test_publish_queue(dummy_config):
# clear
policy_queue1.get_nowait()

mock_specs = mb.setup_test_behavior_specs(
False,
False,
vector_action_space=VECTOR_ACTION_SPACE,
vector_obs_space=VECTOR_OBS_SPACE,
)

buffer = mb.simulate_rollout(BUFFER_INIT_SAMPLES, mock_specs)
# Mock out reward signal eval
copy_buffer_fields(
