diff --git a/ml-agents/mlagents/trainers/agent_processor.py b/ml-agents/mlagents/trainers/agent_processor.py
index 0be92b7458..a811307a5a 100644
--- a/ml-agents/mlagents/trainers/agent_processor.py
+++ b/ml-agents/mlagents/trainers/agent_processor.py
@@ -1,8 +1,8 @@
 import sys
-from typing import List, Dict, Deque, TypeVar, Generic
+from typing import List, Dict, Deque, TypeVar, Generic, Tuple, Set
 from collections import defaultdict, Counter, deque
 
-from mlagents_envs.base_env import BatchedStepResult
+from mlagents_envs.base_env import BatchedStepResult, StepResult
 from mlagents.trainers.trajectory import Trajectory, AgentExperience
 from mlagents.trainers.tf_policy import TFPolicy
 from mlagents.trainers.policy import Policy
@@ -36,7 +36,7 @@ def __init__(
         :param stats_category: The category under which to write the stats. Usually, this comes from the Trainer.
         """
         self.experience_buffers: Dict[str, List[AgentExperience]] = defaultdict(list)
-        self.last_step_result: Dict[str, BatchedStepResult] = {}
+        self.last_step_result: Dict[str, Tuple[StepResult, int]] = {}
         # last_take_action_outputs stores the action a_t taken before the current observation s_(t+1), while
         # grabbing previous_action from the policy grabs the action PRIOR to that, a_(t-1).
         self.last_take_action_outputs: Dict[str, ActionInfoOutputs] = {}
@@ -69,13 +69,14 @@ def add_experiences(
                 "Policy/Learning Rate", take_action_outputs["learning_rate"]
             )
 
-        terminated_agents: List[str] = []
+        terminated_agents: Set[str] = set()
         # Make unique agent_ids that are global across workers
         action_global_agent_ids = [
             get_global_agent_id(worker_id, ag_id) for ag_id in previous_action.agent_ids
         ]
         for global_id in action_global_agent_ids:
-            self.last_take_action_outputs[global_id] = take_action_outputs
+            if global_id in self.last_step_result:  # Don't store if agent just reset
+                self.last_take_action_outputs[global_id] = take_action_outputs
 
         for _id in batched_step_result.agent_id:  # Assume agent_id is 1-D
             local_id = int(
@@ -83,14 +84,12 @@ def add_experiences(
             )  # Needed for mypy to pass since ndarray has no content type
             curr_agent_step = batched_step_result.get_agent_step_result(local_id)
             global_id = get_global_agent_id(worker_id, local_id)
-            stored_step = self.last_step_result.get(global_id, None)
+            stored_agent_step, idx = self.last_step_result.get(global_id, (None, None))
             stored_take_action_outputs = self.last_take_action_outputs.get(
                 global_id, None
             )
-            if stored_step is not None and stored_take_action_outputs is not None:
+            if stored_agent_step is not None and stored_take_action_outputs is not None:
                 # We know the step is from the same worker, so use the local agent id.
-                stored_agent_step = stored_step.get_agent_step_result(local_id)
-                idx = stored_step.agent_id_to_index[local_id]
                 obs = stored_agent_step.obs
                 if not stored_agent_step.done:
                     if self.policy.use_recurrent:
@@ -155,29 +154,37 @@ def add_experiences(
                     "Environment/Episode Length",
                     self.episode_steps.get(global_id, 0),
                 )
-                terminated_agents += [global_id]
+                terminated_agents.add(global_id)
             elif not curr_agent_step.done:
                 self.episode_steps[global_id] += 1
 
-            self.last_step_result[global_id] = batched_step_result
-
-        if "action" in take_action_outputs:
-            self.policy.save_previous_action(
-                previous_action.agent_ids, take_action_outputs["action"]
+            # Index is needed to grab from last_take_action_outputs
+            self.last_step_result[global_id] = (
+                curr_agent_step,
+                batched_step_result.agent_id_to_index[_id],
             )
 
         for terminated_id in terminated_agents:
             self._clean_agent_data(terminated_id)
 
+        for _gid in action_global_agent_ids:
+            # If the ID doesn't have a last step result, the agent just reset,
+            # don't store the action.
+            if _gid in self.last_step_result:
+                if "action" in take_action_outputs:
+                    self.policy.save_previous_action(
+                        [_gid], take_action_outputs["action"]
+                    )
+
     def _clean_agent_data(self, global_id: str) -> None:
         """
         Removes the data for an Agent.
         """
         del self.experience_buffers[global_id]
         del self.last_take_action_outputs[global_id]
+        del self.last_step_result[global_id]
         del self.episode_steps[global_id]
         del self.episode_rewards[global_id]
-        del self.last_step_result[global_id]
         self.policy.remove_previous_action([global_id])
         self.policy.remove_memories([global_id])
 
diff --git a/ml-agents/mlagents/trainers/tests/mock_brain.py b/ml-agents/mlagents/trainers/tests/mock_brain.py
index af0ebbd7e8..4999bd30d2 100644
--- a/ml-agents/mlagents/trainers/tests/mock_brain.py
+++ b/ml-agents/mlagents/trainers/tests/mock_brain.py
@@ -39,6 +39,7 @@ def create_mock_batchedstep(
     num_vis_observations: int = 0,
     action_shape: List[int] = None,
     discrete: bool = False,
+    done: bool = False,
 ) -> BatchedStepResult:
     """
     Creates a mock BatchedStepResult with observations. Imitates constant
@@ -68,7 +69,7 @@ def create_mock_batchedstep(
     ]
 
     reward = np.array(num_agents * [1.0], dtype=np.float32)
-    done = np.array(num_agents * [False], dtype=np.bool)
+    done = np.array(num_agents * [done], dtype=np.bool)
     max_step = np.array(num_agents * [False], dtype=np.bool)
     agent_id = np.arange(num_agents, dtype=np.int32)
 
diff --git a/ml-agents/mlagents/trainers/tests/test_agent_processor.py b/ml-agents/mlagents/trainers/tests/test_agent_processor.py
index d8ad6b2862..24adb03c7c 100644
--- a/ml-agents/mlagents/trainers/tests/test_agent_processor.py
+++ b/ml-agents/mlagents/trainers/tests/test_agent_processor.py
@@ -10,6 +10,7 @@
 from mlagents.trainers.action_info import ActionInfo
 from mlagents.trainers.trajectory import Trajectory
 from mlagents.trainers.stats import StatsReporter
+from mlagents.trainers.brain_conversion_utils import get_global_agent_id
 
 
 def create_mock_brain():
@@ -91,6 +92,68 @@ def test_agentprocessor(num_vis_obs):
     assert len(processor.experience_buffers[0]) == 0
 
 
+def test_agent_deletion():
+    policy = create_mock_policy()
+    tqueue = mock.Mock()
+    name_behavior_id = "test_brain_name"
+    processor = AgentProcessor(
+        policy,
+        name_behavior_id,
+        max_trajectory_length=5,
+        stats_reporter=StatsReporter("testcat"),
+    )
+
+    fake_action_outputs = {
+        "action": [0.1],
+        "entropy": np.array([1.0], dtype=np.float32),
+        "learning_rate": 1.0,
+        "pre_action": [0.1],
+        "log_probs": [0.1],
+    }
+    mock_step = mb.create_mock_batchedstep(
+        num_agents=1,
+        num_vector_observations=8,
+        action_shape=[2],
+        num_vis_observations=0,
+    )
+    mock_done_step = mb.create_mock_batchedstep(
+        num_agents=1,
+        num_vector_observations=8,
+        action_shape=[2],
+        num_vis_observations=0,
+        done=True,
+    )
+    fake_action_info = ActionInfo(
+        action=[0.1],
+        value=[0.1],
+        outputs=fake_action_outputs,
+        agent_ids=mock_step.agent_id,
+    )
+
+    processor.publish_trajectory_queue(tqueue)
+    # This is like the initial state after the env reset
+    processor.add_experiences(mock_step, 0, ActionInfo.empty())
+
+    # Run 3 trajectories, with different workers (to simulate different agents)
+    add_calls = []
+    remove_calls = []
+    for _ep in range(3):
+        for _ in range(5):
+            processor.add_experiences(mock_step, _ep, fake_action_info)
+            add_calls.append(mock.call([get_global_agent_id(_ep, 0)], [0.1]))
+        processor.add_experiences(mock_done_step, _ep, fake_action_info)
+        # Make sure we don't add experiences from the prior agents after the done
+        remove_calls.append(mock.call([get_global_agent_id(_ep, 0)]))
+
+    policy.save_previous_action.assert_has_calls(add_calls)
+    policy.remove_previous_action.assert_has_calls(remove_calls)
+    # Check that there are no experiences left
+    assert len(processor.experience_buffers.keys()) == 0
+    assert len(processor.last_take_action_outputs.keys()) == 0
+    assert len(processor.episode_steps.keys()) == 0
+    assert len(processor.episode_rewards.keys()) == 0
+
+
 def test_agent_manager():
     policy = create_mock_policy()
     name_behavior_id = "test_brain_name"
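
For readers skimming the patch, the short sketch below (not part of the patch) illustrates the bookkeeping pattern agent_processor.py moves to: per-agent state is keyed by an agent id as a (step, index-in-batch) tuple, it is only updated while the agent is still tracked, and it is deleted in one place once the agent's episode ends, so nothing stale survives into the next episode. StepLike and process_batch are illustrative stand-ins, not ML-Agents APIs.

from typing import Dict, List, NamedTuple, Tuple


class StepLike(NamedTuple):
    # Stand-in for a single agent's step: just an id and a done flag.
    agent_id: int
    done: bool


def process_batch(
    last_step: Dict[int, Tuple[StepLike, int]], batch: List[StepLike]
) -> None:
    terminated = set()
    for idx, step in enumerate(batch):
        # Unpack the previously stored (step, batch index) pair; the index
        # records where the agent sat in its batch so the matching action
        # outputs can be looked up later (the role of last_step_result above).
        stored_step, stored_idx = last_step.get(step.agent_id, (None, None))
        if stored_step is not None:
            pass  # ...an experience would be built from stored_step / stored_idx...
        # Overwrite with the current step and its position in this batch.
        last_step[step.agent_id] = (step, idx)
        if step.done:
            terminated.add(step.agent_id)
    # Remove all per-agent state in one place when the episode has ended,
    # mirroring _clean_agent_data in the patch.
    for agent_id in terminated:
        del last_step[agent_id]


if __name__ == "__main__":
    state: Dict[int, Tuple[StepLike, int]] = {}
    process_batch(state, [StepLike(agent_id=0, done=False)])
    process_batch(state, [StepLike(agent_id=0, done=True)])
    assert state == {}  # the done step removed the agent's entry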