10 changes: 5 additions & 5 deletions ml-agents/mlagents/trainers/policy/checkpoint_manager.py
@@ -9,14 +9,14 @@
 
 
 @attr.s(auto_attribs=True)
-class NNCheckpoint:
+class ModelCheckpoint:
     steps: int
     file_path: str
     reward: Optional[float]
     creation_time: float
 
 
-class NNCheckpointManager:
+class ModelCheckpointManager:
     @staticmethod
     def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
         checkpoint_list = GlobalTrainingStatus.get_parameter_state(
@@ -60,12 +60,12 @@ def _cleanup_extra_checkpoints(
         while len(checkpoints) > keep_checkpoints:
             if keep_checkpoints <= 0 or len(checkpoints) == 0:
                 break
-            NNCheckpointManager.remove_checkpoint(checkpoints.pop(0))
+            ModelCheckpointManager.remove_checkpoint(checkpoints.pop(0))
         return checkpoints
 
     @classmethod
     def add_checkpoint(
-        cls, behavior_name: str, new_checkpoint: NNCheckpoint, keep_checkpoints: int
+        cls, behavior_name: str, new_checkpoint: ModelCheckpoint, keep_checkpoints: int
     ) -> None:
         """
         Make room for new checkpoint if needed and insert new checkpoint information.
@@ -83,7 +83,7 @@ def add_checkpoint(
 
     @classmethod
     def track_final_checkpoint(
-        cls, behavior_name: str, final_checkpoint: NNCheckpoint
+        cls, behavior_name: str, final_checkpoint: ModelCheckpoint
     ) -> None:
         """
         Ensures number of checkpoints stored is within the max number of checkpoints
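
For reference, a minimal usage sketch of the renamed API, mirroring the calls exercised in test_training_status.py further down in this diff. It assumes the ml-agents trainers package from this branch is importable; the behavior name, step count, path, and reward values are placeholders, not anything recorded in the PR.

import os
import time

from mlagents.trainers.policy.checkpoint_manager import (
    ModelCheckpoint,
    ModelCheckpointManager,
)

# Build a checkpoint record for a behavior (all values illustrative).
checkpoint = ModelCheckpoint(
    steps=1000,
    file_path=os.path.join("results", "run_id", "MyBehavior-1000.onnx"),
    reward=1.23,
    creation_time=time.time(),
)

# Register it; entries beyond keep_checkpoints (here 4) are pruned oldest-first.
ModelCheckpointManager.add_checkpoint("MyBehavior", checkpoint, 4)

# The tracked history can then be listed per behavior name.
history = ModelCheckpointManager.get_checkpoints("MyBehavior")
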
4 changes: 2 additions & 2 deletions ml-agents/mlagents/trainers/sac/trainer.py
@@ -7,7 +7,7 @@
 import os
 
 import numpy as np
-from mlagents.trainers.policy.checkpoint_manager import NNCheckpoint
+from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint
 
 from mlagents_envs.logging_util import get_logger
 from mlagents_envs.timers import timed
@@ -88,7 +88,7 @@ def __init__(
 
         self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer
 
-    def _checkpoint(self) -> NNCheckpoint:
+    def _checkpoint(self) -> ModelCheckpoint:
         """
         Writes a checkpoint model to memory
         Overrides the default to save the replay buffer.
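
The body of the SAC override is not part of this hunk; going only by the docstring above, a rough sketch of its shape would be: delegate to the base RLTrainer checkpoint, then optionally persist the replay buffer. The class name and the save_replay_buffer helper below are assumptions for illustration, not code from this PR.

from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint
from mlagents.trainers.trainer.rl_trainer import RLTrainer


class ReplayBufferCheckpointSketch(RLTrainer):
    """Sketch only: the shape of a _checkpoint override like SAC's."""

    def _checkpoint(self) -> ModelCheckpoint:
        # Let the base class export the model and register the checkpoint.
        checkpoint = super()._checkpoint()
        if self.checkpoint_replay_buffer:
            # Hypothetical helper name; the real trainer writes its replay
            # buffer to disk alongside the exported model.
            self.save_replay_buffer()
        return checkpoint
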
8 changes: 5 additions & 3 deletions ml-agents/mlagents/trainers/tests/test_rl_trainer.py
@@ -2,7 +2,7 @@
 from unittest import mock
 import pytest
 import mlagents.trainers.tests.mock_brain as mb
-from mlagents.trainers.policy.checkpoint_manager import NNCheckpoint
+from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint
 from mlagents.trainers.trainer.rl_trainer import RLTrainer
 from mlagents.trainers.tests.test_buffer import construct_fake_buffer
 from mlagents.trainers.agent_processor import AgentManagerQueue
@@ -126,7 +126,9 @@ def test_advance(mocked_clear_update_buffer, mocked_save_model):
     "framework", [FrameworkType.TENSORFLOW, FrameworkType.PYTORCH], ids=["tf", "torch"]
 )
 @mock.patch("mlagents.trainers.trainer.trainer.StatsReporter.write_stats")
-@mock.patch("mlagents.trainers.trainer.rl_trainer.NNCheckpointManager.add_checkpoint")
+@mock.patch(
+    "mlagents.trainers.trainer.rl_trainer.ModelCheckpointManager.add_checkpoint"
+)
 def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary, framework):
     trainer = create_rl_trainer(framework)
     mock_policy = mock.Mock()
@@ -170,7 +172,7 @@ def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary, framework):
     add_checkpoint_calls = [
         mock.call(
             trainer.brain_name,
-            NNCheckpoint(
+            ModelCheckpoint(
                 step,
                 f"{trainer.model_saver.model_path}/{trainer.brain_name}-{step}.{export_ext}",
                 None,
24 changes: 13 additions & 11 deletions ml-agents/mlagents/trainers/tests/test_training_status.py
@@ -9,8 +9,8 @@
     GlobalTrainingStatus,
 )
 from mlagents.trainers.policy.checkpoint_manager import (
-    NNCheckpointManager,
-    NNCheckpoint,
+    ModelCheckpointManager,
+    ModelCheckpoint,
 )
 
 
@@ -78,25 +78,27 @@ def test_model_management(tmpdir):
         brain_name, StatusType.CHECKPOINTS, test_checkpoint_list
     )
 
-    new_checkpoint_4 = NNCheckpoint(
+    new_checkpoint_4 = ModelCheckpoint(
         4, os.path.join(final_model_path, f"{brain_name}-4.nn"), 2.678, time.time()
     )
-    NNCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
-    assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
+    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
+    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
 
-    new_checkpoint_5 = NNCheckpoint(
+    new_checkpoint_5 = ModelCheckpoint(
         5, os.path.join(final_model_path, f"{brain_name}-5.nn"), 3.122, time.time()
     )
-    NNCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
-    assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
+    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
+    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
 
     final_model_path = f"{final_model_path}.nn"
     final_model_time = time.time()
     current_step = 6
-    final_model = NNCheckpoint(current_step, final_model_path, 3.294, final_model_time)
+    final_model = ModelCheckpoint(
+        current_step, final_model_path, 3.294, final_model_time
+    )
 
-    NNCheckpointManager.track_final_checkpoint(brain_name, final_model)
-    assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
+    ModelCheckpointManager.track_final_checkpoint(brain_name, final_model)
+    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
 
     check_checkpoints = GlobalTrainingStatus.saved_state[brain_name][
         StatusType.CHECKPOINTS.value
12 changes: 6 additions & 6 deletions ml-agents/mlagents/trainers/trainer/rl_trainer.py
@@ -5,8 +5,8 @@
 import time
 import attr
 from mlagents.trainers.policy.checkpoint_manager import (
-    NNCheckpoint,
-    NNCheckpointManager,
+    ModelCheckpoint,
+    ModelCheckpointManager,
 )
 from mlagents_envs.logging_util import get_logger
 from mlagents_envs.timers import timed
@@ -176,7 +176,7 @@ def _policy_mean_reward(self) -> Optional[float]:
         return sum(rewards) / len(rewards)
 
     @timed
-    def _checkpoint(self) -> NNCheckpoint:
+    def _checkpoint(self) -> ModelCheckpoint:
         """
         Checkpoints the policy associated with this trainer.
         """
@@ -187,13 +187,13 @@ def _checkpoint(self) -> NNCheckpoint:
         )
         checkpoint_path = self.model_saver.save_checkpoint(self.brain_name, self.step)
         export_ext = "nn" if self.framework == FrameworkType.TENSORFLOW else "onnx"
-        new_checkpoint = NNCheckpoint(
+        new_checkpoint = ModelCheckpoint(
            int(self.step),
             f"{checkpoint_path}.{export_ext}",
             self._policy_mean_reward(),
             time.time(),
         )
-        NNCheckpointManager.add_checkpoint(
+        ModelCheckpointManager.add_checkpoint(
             self.brain_name, new_checkpoint, self.trainer_settings.keep_checkpoints
         )
         return new_checkpoint
@@ -217,7 +217,7 @@ def save_model(self) -> None:
         final_checkpoint = attr.evolve(
             model_checkpoint, file_path=f"{self.model_saver.model_path}.{export_ext}"
         )
-        NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
+        ModelCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
 
     @abc.abstractmethod
     def _update_policy(self) -> bool:
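
One detail worth noting from the save_model hunk above: the final checkpoint record is not rebuilt from scratch but derived from the last recorded one with attr.evolve, which copies an attrs instance while replacing selected fields. A minimal standalone illustration, assuming the package from this branch is installed; the paths and values below are made up.

import time

import attr
from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint

# The last checkpoint written during training (illustrative values).
last_checkpoint = ModelCheckpoint(
    steps=6000,
    file_path="results/run_id/MyBehavior-6000.onnx",
    reward=3.2,
    creation_time=time.time(),
)

# save_model() keeps the step/reward/time but repoints file_path at the
# final exported model; attr.evolve returns a modified copy.
final_checkpoint = attr.evolve(
    last_checkpoint, file_path="results/run_id/MyBehavior.onnx"
)
assert final_checkpoint.steps == last_checkpoint.steps
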