From 74e092bded7e91ce251b57315fdd93e6f79257b7 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Tue, 4 Aug 2020 17:21:15 -0700 Subject: [PATCH 1/2] Don't save model twice, copy instead --- ml-agents/mlagents/model_serialization.py | 18 ++++++++++++++++++ .../mlagents/trainers/trainer/rl_trainer.py | 8 +++++--- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/ml-agents/mlagents/model_serialization.py b/ml-agents/mlagents/model_serialization.py index edc7a5f6ee..879ec8d065 100644 --- a/ml-agents/mlagents/model_serialization.py +++ b/ml-agents/mlagents/model_serialization.py @@ -1,5 +1,6 @@ from distutils.util import strtobool import os +import shutil from typing import Any, List, Set, NamedTuple from distutils.version import LooseVersion @@ -227,3 +228,20 @@ def _enforce_onnx_conversion() -> bool: return strtobool(val) except Exception: return False + + +def copy_model_files(source_nn_path: str, destination_nn_path: str) -> None: + """ + Copy the .nn file at the given source to the destination. + Also copies the corresponding .onnx file if it exists. + """ + shutil.copyfile(source_nn_path, destination_nn_path) + logger.info(f"Copied {source_nn_path} to {destination_nn_path}.") + # Copy the onnx file if it exists + source_onnx_path = os.path.splitext(source_nn_path)[0] + ".onnx" + destination_onnx_path = os.path.splitext(destination_nn_path)[0] + ".onnx" + try: + shutil.copyfile(source_onnx_path, destination_onnx_path) + logger.info(f"Copied {source_onnx_path} to {destination_onnx_path}.") + except Exception: + pass diff --git a/ml-agents/mlagents/trainers/trainer/rl_trainer.py b/ml-agents/mlagents/trainers/trainer/rl_trainer.py index 2ab44bcf1f..9765ab17be 100644 --- a/ml-agents/mlagents/trainers/trainer/rl_trainer.py +++ b/ml-agents/mlagents/trainers/trainer/rl_trainer.py @@ -5,7 +5,7 @@ import abc import time import attr -from mlagents.model_serialization import SerializationSettings +from mlagents.model_serialization import SerializationSettings, copy_model_files from mlagents.trainers.policy.checkpoint_manager import ( NNCheckpoint, NNCheckpointManager, @@ -131,12 +131,14 @@ def save_model(self) -> None: "Trainer has multiple policies, but default behavior only saves the first." ) policy = list(self.policies.values())[0] - settings = SerializationSettings(policy.model_path, self.brain_name) model_checkpoint = self._checkpoint() + + # Copy the checkpointed model files to the final output location + copy_model_files(model_checkpoint.file_path, f"{policy.model_path}.nn") + final_checkpoint = attr.evolve( model_checkpoint, file_path=f"{policy.model_path}.nn" ) - policy.save(policy.model_path, settings) NNCheckpointManager.track_final_checkpoint(self.brain_name, final_checkpoint) @abc.abstractmethod From 92db69a6b2f87b114fc02d9b6e07225ef7c295d9 Mon Sep 17 00:00:00 2001 From: Chris Elion Date: Wed, 5 Aug 2020 09:26:25 -0700 Subject: [PATCH 2/2] narrower exception --- ml-agents/mlagents/model_serialization.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml-agents/mlagents/model_serialization.py b/ml-agents/mlagents/model_serialization.py index 879ec8d065..11714c3ec2 100644 --- a/ml-agents/mlagents/model_serialization.py +++ b/ml-agents/mlagents/model_serialization.py @@ -243,5 +243,5 @@ def copy_model_files(source_nn_path: str, destination_nn_path: str) -> None: try: shutil.copyfile(source_onnx_path, destination_onnx_path) logger.info(f"Copied {source_onnx_path} to {destination_onnx_path}.") - except Exception: + except OSError: pass