Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def truncate(self, max_length: int, sequence_length: int = 1) -> None:
max_length -= max_length % sequence_length
if current_length > max_length:
for _key in self.keys():
self[_key] = self[_key][current_length - max_length :]
self[_key][:] = self[_key][current_length - max_length :]

def resequence_and_append(
self,
Expand Down
2 changes: 1 addition & 1 deletion ml-agents/mlagents/trainers/rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,5 +73,5 @@ def advance(self) -> None:
Steps the trainer, taking in trajectories and updates if ready
"""
super().advance()
if not self.is_training:
if not self.should_still_train:
self.clear_update_buffer()
2 changes: 2 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,5 @@ def test_buffer_truncate():
# Test LSTM, truncate should be some multiple of sequence_length
update_buffer.truncate(4, sequence_length=3)
assert update_buffer.num_experiences == 3
for buffer_field in update_buffer.values():
assert isinstance(buffer_field, AgentBuffer.AgentBufferField)
30 changes: 30 additions & 0 deletions ml-agents/mlagents/trainers/tests/test_rl_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
import mlagents.trainers.tests.mock_brain as mb
from mlagents.trainers.rl_trainer import RLTrainer
from mlagents.trainers.tests.test_buffer import construct_fake_buffer
from mlagents.trainers.agent_processor import AgentManagerQueue


def dummy_config():
return yaml.safe_load(
"""
summary_path: "test/"
summary_freq: 1000
max_steps: 100
reward_signals:
extrinsic:
strength: 1.0
Expand Down Expand Up @@ -75,3 +77,31 @@ def test_clear_update_buffer():
trainer.clear_update_buffer()
for _, arr in trainer.update_buffer.items():
assert len(arr) == 0


@mock.patch("mlagents.trainers.rl_trainer.RLTrainer.clear_update_buffer")
def test_advance(mocked_clear_update_buffer):
    """Exercise RLTrainer.advance() with fake trajectories.

    Verifies that the step counter matches the first trajectory's length,
    and that after many further trajectories the trainer reports it should
    no longer train and clear_update_buffer (patched here) has been called.
    """
    steps_per_trajectory = 15

    rl_trainer = create_rl_trainer()
    queue = AgentManagerQueue("testbrain")
    rl_trainer.subscribe_trajectory_queue(queue)

    fake_trajectory = mb.make_fake_trajectory(
        length=steps_per_trajectory,
        max_step_complete=True,
        vec_obs_size=1,
        num_vis_obs=0,
        action_space=[2],
    )

    # A single trajectory should advance the step count by its full length.
    queue.put(fake_trajectory)
    rl_trainer.advance()
    assert rl_trainer.get_step == steps_per_trajectory

    # Keep feeding trajectories; presumably this pushes the step count past
    # the configured max_steps so the trainer turns itself off -- confirm
    # against create_rl_trainer's config.
    for _ in range(10):
        queue.put(fake_trajectory)
        rl_trainer.advance()

    # Once training has stopped, advance() is expected to clear the buffer.
    assert not rl_trainer.should_still_train
    assert mocked_clear_update_buffer.call_count > 0