diff --git a/ml-agents/mlagents/trainers/buffer.py b/ml-agents/mlagents/trainers/buffer.py index 03810a9716..b97723282a 100644 --- a/ml-agents/mlagents/trainers/buffer.py +++ b/ml-agents/mlagents/trainers/buffer.py @@ -253,7 +253,7 @@ def truncate(self, max_length: int, sequence_length: int = 1) -> None: max_length -= max_length % sequence_length if current_length > max_length: for _key in self.keys(): - self[_key] = self[_key][current_length - max_length :] + self[_key][:] = self[_key][current_length - max_length :] def resequence_and_append( self, diff --git a/ml-agents/mlagents/trainers/rl_trainer.py b/ml-agents/mlagents/trainers/rl_trainer.py index 12f0ab475c..c07ce22bbe 100644 --- a/ml-agents/mlagents/trainers/rl_trainer.py +++ b/ml-agents/mlagents/trainers/rl_trainer.py @@ -73,5 +73,5 @@ def advance(self) -> None: Steps the trainer, taking in trajectories and updates if ready """ super().advance() - if not self.is_training: + if not self.should_still_train: self.clear_update_buffer() diff --git a/ml-agents/mlagents/trainers/tests/test_buffer.py b/ml-agents/mlagents/trainers/tests/test_buffer.py index 55549a74e5..a4b2086cdb 100644 --- a/ml-agents/mlagents/trainers/tests/test_buffer.py +++ b/ml-agents/mlagents/trainers/tests/test_buffer.py @@ -152,3 +152,5 @@ def test_buffer_truncate(): # Test LSTM, truncate should be some multiple of sequence_length update_buffer.truncate(4, sequence_length=3) assert update_buffer.num_experiences == 3 + for buffer_field in update_buffer.values(): + assert isinstance(buffer_field, AgentBuffer.AgentBufferField) diff --git a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py index 986746ea74..0fe48771d2 100644 --- a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py +++ b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py @@ -3,6 +3,7 @@ import mlagents.trainers.tests.mock_brain as mb from mlagents.trainers.rl_trainer import RLTrainer from mlagents.trainers.tests.test_buffer import construct_fake_buffer +from mlagents.trainers.agent_processor import AgentManagerQueue def dummy_config(): @@ -10,6 +11,7 @@ def dummy_config(): """ summary_path: "test/" summary_freq: 1000 + max_steps: 100 reward_signals: extrinsic: strength: 1.0 @@ -75,3 +77,31 @@ def test_clear_update_buffer(): trainer.clear_update_buffer() for _, arr in trainer.update_buffer.items(): assert len(arr) == 0 + + +@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.clear_update_buffer") +def test_advance(mocked_clear_update_buffer): + trainer = create_rl_trainer() + trajectory_queue = AgentManagerQueue("testbrain") + trainer.subscribe_trajectory_queue(trajectory_queue) + time_horizon = 15 + trajectory = mb.make_fake_trajectory( + length=time_horizon, + max_step_complete=True, + vec_obs_size=1, + num_vis_obs=0, + action_space=[2], + ) + trajectory_queue.put(trajectory) + + trainer.advance() + # Check that get_step is correct + assert trainer.get_step == time_horizon + # Check that we can turn off the trainer and that the buffer is cleared + for _ in range(0, 10): + trajectory_queue.put(trajectory) + trainer.advance() + + # Check that the buffer has been cleared + assert not trainer.should_still_train + assert mocked_clear_update_buffer.call_count > 0