From bb8ddecf7c1b553808bf12e807a3709f1a3aac3f Mon Sep 17 00:00:00 2001
From: vincentpierre
Date: Mon, 5 Oct 2020 10:09:07 -0700
Subject: [PATCH] Rename NNCheckpoint to ModelCheckpoint as Model can be NN or ONNX

---
 .../trainers/policy/checkpoint_manager.py     | 10 ++++----
 ml-agents/mlagents/trainers/sac/trainer.py    |  4 ++--
 .../trainers/tests/test_rl_trainer.py         |  8 ++++---
 .../trainers/tests/test_training_status.py    | 24 ++++++++++---------
 .../mlagents/trainers/trainer/rl_trainer.py   | 12 +++++-----
 5 files changed, 31 insertions(+), 27 deletions(-)

diff --git a/ml-agents/mlagents/trainers/policy/checkpoint_manager.py b/ml-agents/mlagents/trainers/policy/checkpoint_manager.py
index 93f21346e4..961d5290be 100644
--- a/ml-agents/mlagents/trainers/policy/checkpoint_manager.py
+++ b/ml-agents/mlagents/trainers/policy/checkpoint_manager.py
@@ -9,14 +9,14 @@
 
 
 @attr.s(auto_attribs=True)
-class NNCheckpoint:
+class ModelCheckpoint:
     steps: int
     file_path: str
     reward: Optional[float]
     creation_time: float
 
 
-class NNCheckpointManager:
+class ModelCheckpointManager:
     @staticmethod
     def get_checkpoints(behavior_name: str) -> List[Dict[str, Any]]:
         checkpoint_list = GlobalTrainingStatus.get_parameter_state(
@@ -60,12 +60,12 @@ def _cleanup_extra_checkpoints(
         while len(checkpoints) > keep_checkpoints:
             if keep_checkpoints <= 0 or len(checkpoints) == 0:
                 break
-            NNCheckpointManager.remove_checkpoint(checkpoints.pop(0))
+            ModelCheckpointManager.remove_checkpoint(checkpoints.pop(0))
         return checkpoints
 
     @classmethod
     def add_checkpoint(
-        cls, behavior_name: str, new_checkpoint: NNCheckpoint, keep_checkpoints: int
+        cls, behavior_name: str, new_checkpoint: ModelCheckpoint, keep_checkpoints: int
     ) -> None:
         """
         Make room for new checkpoint if needed and insert new checkpoint information.
@@ -83,7 +83,7 @@ def add_checkpoint(
 
     @classmethod
     def track_final_checkpoint(
-        cls, behavior_name: str, final_checkpoint: NNCheckpoint
+        cls, behavior_name: str, final_checkpoint: ModelCheckpoint
     ) -> None:
         """
         Ensures number of checkpoints stored is within the max number of checkpoints
diff --git a/ml-agents/mlagents/trainers/sac/trainer.py b/ml-agents/mlagents/trainers/sac/trainer.py
index 220f6205d6..241095349f 100644
--- a/ml-agents/mlagents/trainers/sac/trainer.py
+++ b/ml-agents/mlagents/trainers/sac/trainer.py
@@ -7,7 +7,7 @@
 import os
 
 import numpy as np
 
-from mlagents.trainers.policy.checkpoint_manager import NNCheckpoint
+from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint
 from mlagents_envs.logging_util import get_logger
 from mlagents_envs.timers import timed
@@ -88,7 +88,7 @@ def __init__(
 
         self.checkpoint_replay_buffer = self.hyperparameters.save_replay_buffer
 
-    def _checkpoint(self) -> NNCheckpoint:
+    def _checkpoint(self) -> ModelCheckpoint:
         """
         Writes a checkpoint model to memory
         Overrides the default to save the replay buffer.
diff --git a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py
index 30d570e67c..ef248224a4 100644
--- a/ml-agents/mlagents/trainers/tests/test_rl_trainer.py
+++ b/ml-agents/mlagents/trainers/tests/test_rl_trainer.py
@@ -2,7 +2,7 @@
 from unittest import mock
 import pytest
 import mlagents.trainers.tests.mock_brain as mb
-from mlagents.trainers.policy.checkpoint_manager import NNCheckpoint
+from mlagents.trainers.policy.checkpoint_manager import ModelCheckpoint
 from mlagents.trainers.trainer.rl_trainer import RLTrainer
 from mlagents.trainers.tests.test_buffer import construct_fake_buffer
 from mlagents.trainers.agent_processor import AgentManagerQueue
@@ -126,7 +126,9 @@ def test_advance(mocked_clear_update_buffer, mocked_save_model):
     "framework", [FrameworkType.TENSORFLOW, FrameworkType.PYTORCH], ids=["tf", "torch"]
 )
 @mock.patch("mlagents.trainers.trainer.trainer.StatsReporter.write_stats")
-@mock.patch("mlagents.trainers.trainer.rl_trainer.NNCheckpointManager.add_checkpoint")
+@mock.patch(
+    "mlagents.trainers.trainer.rl_trainer.ModelCheckpointManager.add_checkpoint"
+)
 def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary, framework):
     trainer = create_rl_trainer(framework)
     mock_policy = mock.Mock()
@@ -170,7 +172,7 @@ def test_summary_checkpoint(mock_add_checkpoint, mock_write_summary, framework):
     add_checkpoint_calls = [
         mock.call(
             trainer.brain_name,
-            NNCheckpoint(
+            ModelCheckpoint(
                 step,
                 f"{trainer.model_saver.model_path}/{trainer.brain_name}-{step}.{export_ext}",
                 None,
diff --git a/ml-agents/mlagents/trainers/tests/test_training_status.py b/ml-agents/mlagents/trainers/tests/test_training_status.py
index d1fae24a8f..db9992fe1c 100644
--- a/ml-agents/mlagents/trainers/tests/test_training_status.py
+++ b/ml-agents/mlagents/trainers/tests/test_training_status.py
@@ -9,8 +9,8 @@
     GlobalTrainingStatus,
 )
 from mlagents.trainers.policy.checkpoint_manager import (
-    NNCheckpointManager,
-    NNCheckpoint,
+    ModelCheckpointManager,
+    ModelCheckpoint,
 )
 
 
@@ -78,25 +78,27 @@ def test_model_management(tmpdir):
         brain_name, StatusType.CHECKPOINTS, test_checkpoint_list
     )
 
-    new_checkpoint_4 = NNCheckpoint(
+    new_checkpoint_4 = ModelCheckpoint(
         4, os.path.join(final_model_path, f"{brain_name}-4.nn"), 2.678, time.time()
     )
-    NNCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
-    assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
+    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_4, 4)
+    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
 
-    new_checkpoint_5 = NNCheckpoint(
+    new_checkpoint_5 = ModelCheckpoint(
         5, os.path.join(final_model_path, f"{brain_name}-5.nn"), 3.122, time.time()
     )
-    NNCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
-    assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
+    ModelCheckpointManager.add_checkpoint(brain_name, new_checkpoint_5, 4)
+    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
 
     final_model_path = f"{final_model_path}.nn"
     final_model_time = time.time()
     current_step = 6
-    final_model = NNCheckpoint(current_step, final_model_path, 3.294, final_model_time)
+    final_model = ModelCheckpoint(
+        current_step, final_model_path, 3.294, final_model_time
+    )
 
-    NNCheckpointManager.track_final_checkpoint(brain_name, final_model)
-    assert len(NNCheckpointManager.get_checkpoints(brain_name)) == 4
+    ModelCheckpointManager.track_final_checkpoint(brain_name, final_model)
+    assert len(ModelCheckpointManager.get_checkpoints(brain_name)) == 4
 
     check_checkpoints = GlobalTrainingStatus.saved_state[brain_name][
         StatusType.CHECKPOINTS.value
diff --git a/ml-agents/mlagents/trainers/trainer/rl_trainer.py b/ml-agents/mlagents/trainers/trainer/rl_trainer.py
index da8c172615..903b3ae770 100644
--- a/ml-agents/mlagents/trainers/trainer/rl_trainer.py
+++ b/ml-agents/mlagents/trainers/trainer/rl_trainer.py
@@ -5,8 +5,8 @@
 import time
 import attr
 from mlagents.trainers.policy.checkpoint_manager import (
-    NNCheckpoint,
-    NNCheckpointManager,
+    ModelCheckpoint,
+    ModelCheckpointManager,
 )
 from mlagents_envs.logging_util import get_logger
 from mlagents_envs.timers import timed
@@ -176,7 +176,7 @@ def _policy_mean_reward(self) -> Optional[float]:
             return sum(rewards) / len(rewards)
 
     @timed
-    def _checkpoint(self) -> NNCheckpoint:
+    def _checkpoint(self) -> ModelCheckpoint:
         """
         Checkpoints the policy associated with this trainer.
         """
@@ -187,13 +187,13 @@ def _checkpoint(self) -> NNCheckpoint:
         )
         checkpoint_path = self.model_saver.save_checkpoint(self.brain_name, self.step)
         export_ext = "nn" if self.framework == FrameworkType.TENSORFLOW else "onnx"
-        new_checkpoint = NNCheckpoint(
+        new_checkpoint = ModelCheckpoint(
            int(self.step),
            f"{checkpoint_path}.{export_ext}",
            self._policy_mean_reward(),
            time.time(),
         )
-        NNCheckpointManager.add_checkpoint(
+        ModelCheckpointManager.add_checkpoint(
             self.brain_name, new_checkpoint, self.trainer_settings.keep_checkpoints
         )
         return new_checkpoint
@@ -217,7 +217,7 @@ def save_model(self) -> None:
         final_checkpoint = attr.evolve(
             model_checkpoint, file_path=f"{self.model_saver.model_path}.{export_ext}"
         )
-        NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
+        ModelCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint)
 
     @abc.abstractmethod
     def _update_policy(self) -> bool:
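
Note (illustrative, not part of the patch): a minimal sketch of the renamed API as the tests above exercise it, assuming the post-rename ml-agents package is importable. The behavior name, file paths, step count, and reward values here are hypothetical.

import os
import time

import attr

from mlagents.trainers.policy.checkpoint_manager import (
    ModelCheckpoint,
    ModelCheckpointManager,
)

brain_name = "DemoBrain"  # hypothetical behavior name

# Record a new checkpoint; the manager keeps at most `keep_checkpoints`
# entries and removes the oldest one once the limit is exceeded.
checkpoint = ModelCheckpoint(
    steps=1000,
    file_path=os.path.join("results", f"{brain_name}-1000.onnx"),  # hypothetical path
    reward=1.5,
    creation_time=time.time(),
)
ModelCheckpointManager.add_checkpoint(brain_name, checkpoint, keep_checkpoints=4)

# Mark the end-of-training model, as rl_trainer.save_model does: copy the
# checkpoint record with attr.evolve, pointing it at the final export path.
final_checkpoint = attr.evolve(
    checkpoint, file_path=os.path.join("results", f"{brain_name}.onnx")
)
ModelCheckpointManager.track_final_checkpoint(brain_name, final_checkpoint)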