39 changes: 23 additions & 16 deletions ml-agents/mlagents/trainers/agent_processor.py
@@ -1,8 +1,8 @@
import sys
-from typing import List, Dict, Deque, TypeVar, Generic
+from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Set
from collections import defaultdict, Counter, deque

-from mlagents_envs.base_env import BatchedStepResult
+from mlagents_envs.base_env import BatchedStepResult, StepResult
from mlagents.trainers.trajectory import Trajectory, AgentExperience
from mlagents.trainers.tf_policy import TFPolicy
from mlagents.trainers.policy import Policy
@@ -36,7 +36,7 @@ def __init__(
:param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
"""
self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
-        self.last_step_result: Dict[str, BatchedStepResult] = {}
+        self.last_step_result: Dict[str, Tuple[StepResult, int]] = {}
# last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
# grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
@@ -69,28 +69,27 @@ def add_experiences(
"Policy/Learning Rate", take_action_outputs["learning_rate"]
)

-        terminated_agents: List[str] = []
+        terminated_agents: Set[str] = set()
# Make unique agent_ids that are global across workers
action_global_agent_ids = [
get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids
]
for global_id in action_global_agent_ids:
-            self.last_take_action_outputs[global_id] = take_action_outputs
+            if global_id in self.last_step_result:  # Don't store if agent just reset
+                self.last_take_action_outputs[global_id] = take_action_outputs

for _id in batched_step_result.agent_id: # Assume agent_id is 1-D
local_id = int(
_id
) # Needed for mypy to pass since ndarray has no content type
curr_agent_step = batched_step_result.get_agent_step_result(local_id)
global_id = get_global_agent_id(worker_id, local_id)
-            stored_step = self.last_step_result.get(global_id, None)
+            stored_agent_step, idx = self.last_step_result.get(global_id, (None, None))
stored_take_action_outputs = self.last_take_action_outputs.get(
global_id, None
)
-            if stored_step is not None and stored_take_action_outputs is not None:
+            if stored_agent_step is not None and stored_take_action_outputs is not None:
# We know the step is from the same worker, so use the local agent id.
-                stored_agent_step = stored_step.get_agent_step_result(local_id)
-                idx = stored_step.agent_id_to_index[local_id]
obs = stored_agent_step.obs
if not stored_agent_step.done:
if self.policy.use_recurrent:
@@ -155,29 +154,37 @@ def add_experiences(
"Environment/Episode Length",
self.episode_steps.get(global_id, 0),
)
-                    terminated_agents += [global_id]
+                    terminated_agents.add(global_id)
elif not curr_agent_step.done:
self.episode_steps[global_id] += 1

-            self.last_step_result[global_id] = batched_step_result
-
-        if "action" in take_action_outputs:
-            self.policy.save_previous_action(
-                previous_action.agent_ids, take_action_outputs["action"]
+            # Index is needed to grab from last_take_action_outputs
+            self.last_step_result[global_id] = (
+                curr_agent_step,
+                batched_step_result.agent_id_to_index[_id],
+            )

Review comment (Contributor): _gid ?
Reply (Contributor Author): This one needs to be global_id; it's in a different loop.

        for terminated_id in terminated_agents:
            self._clean_agent_data(terminated_id)

+        for _gid in action_global_agent_ids:
+            # If the ID doesn't have a last step result, the agent just reset,
+            # don't store the action.
+            if _gid in self.last_step_result:
+                if "action" in take_action_outputs:
+                    self.policy.save_previous_action(
+                        [_gid], take_action_outputs["action"]
+                    )

def _clean_agent_data(self, global_id: str) -> None:
"""
Removes the data for an Agent.
"""
del self.experience_buffers[global_id]
del self.last_take_action_outputs[global_id]
+        del self.last_step_result[global_id]
del self.episode_steps[global_id]
del self.episode_rewards[global_id]
-        del self.last_step_result[global_id]
self.policy.remove_previous_action([global_id])
self.policy.remove_memories([global_id])

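The structural change in this file is worth spelling out: last_step_result now maps each global agent id to a (StepResult, index) tuple instead of holding a whole BatchedStepResult, so every piece of per-agent state can be deleted individually in _clean_agent_data once that agent terminates. Below is a minimal standalone sketch of the pattern; the class, method names, and id format are illustrative stand-ins, not the real mlagents API:

from collections import defaultdict
from typing import Dict, List, Tuple

class ToyAgentProcessor:
    # Toy stand-in for AgentProcessor: all per-agent state is keyed by a
    # global agent id so it can be removed entry-by-entry on termination.
    def __init__(self) -> None:
        self.experience_buffers: Dict[str, List[float]] = defaultdict(list)
        # Store (this agent's step, its index in the batch), not the batch.
        self.last_step_result: Dict[str, Tuple[object, int]] = {}

    def add_step(self, global_id: str, step: object, index: int, reward: float) -> None:
        self.experience_buffers[global_id].append(reward)
        self.last_step_result[global_id] = (step, index)

    def clean_agent(self, global_id: str) -> None:
        # Mirrors _clean_agent_data: delete every per-agent entry.
        del self.experience_buffers[global_id]
        del self.last_step_result[global_id]

proc = ToyAgentProcessor()
proc.add_step("0-1", step="obs_t", index=0, reward=1.0)  # "worker-agent" id, hypothetical format
proc.clean_agent("0-1")
assert not proc.experience_buffers and not proc.last_step_result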
3 changes: 2 additions & 1 deletion ml-agents/mlagents/trainers/tests/mock_brain.py
@@ -39,6 +39,7 @@ def create_mock_batchedstep(
num_vis_observations: int = 0,
action_shape: List[int] = None,
discrete: bool = False,
+    done: bool = False,
) -> BatchedStepResult:
"""
Creates a mock BatchedStepResult with observations. Imitates constant
@@ -68,7 +69,7 @@ def create_mock_batchedstep(
]

reward = np.array(num_agents * [1.0], dtype=np.float32)
-    done = np.array(num_agents * [False], dtype=np.bool)
+    done = np.array(num_agents * [done], dtype=np.bool)
max_step = np.array(num_agents * [False], dtype=np.bool)
agent_id = np.arange(num_agents, dtype=np.int32)

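The new done flag lets a caller mark every agent's step in the mock batch as terminal. A usage sketch, mirroring how test_agent_deletion below builds its terminal step (assumes mock_brain is imported as mb, as in the tests):

done_step = mb.create_mock_batchedstep(
    num_agents=1,
    num_vector_observations=8,
    action_shape=[2],
    done=True,
)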
63 changes: 63 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_agent_processor.py
@@ -10,6 +10,7 @@
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.trajectory import Trajectory
from mlagents.trainers.stats import StatsReporter
+from mlagents.trainers.brain_conversion_utils import get_global_agent_id


def create_mock_brain():
@@ -91,6 +92,68 @@ def test_agentprocessor(num_vis_obs):
assert len(processor.experience_buffers[0]) == 0


def test_agent_deletion():
policy = create_mock_policy()
tqueue = mock.Mock()
name_behavior_id = "test_brain_name"
processor = AgentProcessor(
policy,
name_behavior_id,
max_trajectory_length=5,
stats_reporter=StatsReporter("testcat"),
)

fake_action_outputs = {
"action": [0.1],
"entropy": np.array([1.0], dtype=np.float32),
"learning_rate": 1.0,
"pre_action": [0.1],
"log_probs": [0.1],
}
mock_step = mb.create_mock_batchedstep(
num_agents=1,
num_vector_observations=8,
action_shape=[2],
num_vis_observations=0,
)
mock_done_step = mb.create_mock_batchedstep(
num_agents=1,
num_vector_observations=8,
action_shape=[2],
num_vis_observations=0,
done=True,
)
fake_action_info = ActionInfo(
action=[0.1],
value=[0.1],
outputs=fake_action_outputs,
agent_ids=mock_step.agent_id,
)

processor.publish_trajectory_queue(tqueue)
# This is like the initial state after the env reset
processor.add_experiences(mock_step, 0, ActionInfo.empty())

# Run 3 trajectories, with different workers (to simulate different agents)
add_calls = []
remove_calls = []
for _ep in range(3):
for _ in range(5):
processor.add_experiences(mock_step, _ep, fake_action_info)
add_calls.append(mock.call([get_global_agent_id(_ep, 0)], [0.1]))
processor.add_experiences(mock_done_step, _ep, fake_action_info)
# Make sure we don't add experiences from the prior agents after the done
remove_calls.append(mock.call([get_global_agent_id(_ep, 0)]))

policy.save_previous_action.assert_has_calls(add_calls)
policy.remove_previous_action.assert_has_calls(remove_calls)
# Check that there are no experiences left
assert len(processor.experience_buffers.keys()) == 0
assert len(processor.last_take_action_outputs.keys()) == 0
assert len(processor.episode_steps.keys()) == 0
assert len(processor.episode_rewards.keys()) == 0


def test_agent_manager():
policy = create_mock_policy()
name_behavior_id = "test_brain_name"
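To exercise just the new behavior locally, the added test can be selected with standard pytest filtering (file path taken from the diff header above):

pytest ml-agents/mlagents/trainers/tests/test_agent_processor.py -k test_agent_deletion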