From 0b9be17570d6fc30e82eb10e36be914e62299218 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Mon, 12 May 2025 13:35:01 -0700 Subject: [PATCH 01/17] Improve default handling --- src/sagemaker/modules/configs.py | 68 +++++++++++++++- src/sagemaker/modules/train/model_trainer.py | 83 ++++++++++++++++---- 2 files changed, 130 insertions(+), 21 deletions(-) diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index ac54e2ad0b..9abddce83c 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -30,7 +30,6 @@ from sagemaker_core.shapes import ( StoppingCondition, RetryStrategy, - OutputDataConfig, Channel, ShuffleConfig, DataSource, @@ -42,9 +41,7 @@ InfraCheckConfig, RemoteDebugConfig, SessionChainingConfig, - InstanceGroup, - TensorBoardOutputConfig, - CheckpointConfig, + InstanceGroup ) from sagemaker.modules.utils import convert_unassigned_to_none @@ -131,6 +128,8 @@ class Compute(shapes.ResourceConfig): subsequent training jobs. instance_groups (Optional[List[InstanceGroup]]): A list of instance groups for heterogeneous clusters to be used in the training job. + training_plan_arn (Optional[str]): + The Amazon Resource Name (ARN) of the training plan to use for this resource configuration. enable_managed_spot_training (Optional[bool]): To train models using managed spot training, choose True. Managed spot training provides a fully managed and scalable infrastructure for training machine learning @@ -224,3 +223,64 @@ class InputData(BaseConfig): channel_name: str = None data_source: Union[str, FileSystemDataSource, S3DataSource] = None + + +class OutputDataConfig(shapes.OutputDataConfig): + """OutputDataConfig. + + The OutputDataConfig class is a subclass of ``sagemaker_core.shapes.OutputDataConfig`` + and allows the user to specify the output data configuration for the training job. + + Parameters: + s3_output_path (Optional[str]): + The S3 URI where the output data will be stored. This is the location where the + training job will save its output data, such as model artifacts and logs. + kms_key_id (Optional[str]): + The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that + SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side + encryption. + compression_type (Optional[str]): + The model output compression type. Select None to output an uncompressed model, + recommended for large model outputs. Defaults to gzip. + """ + + s3_output_path: Optional[str] = None + kms_key_id: Optional[str] = None + compression_type: Optional[str] = None + + +class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig): + """TensorBoardOutputConfig. + + The TensorBoardOutputConfig class is a subclass of ``sagemaker_core.shapes.TensorBoardOutputConfig`` + and allows the user to specify the storage locations for the Amazon SageMaker + Debugger TensorBoard. + + Parameters: + s3_output_path (Optional[str]): + Path to Amazon S3 storage location for TensorBoard output. If not specified, will default to the default artifact location for the training job. + ``s3://////`` + local_path (Optional[str]): + Path to local storage location for tensorBoard output. Defaults to /opt/ml/output/tensorboard. + """ + + s3_output_path: Optional[str] = None + local_path: Optional[str] = "/opt/ml/output/tensorboard" + + +class CheckpointConfig(shapes.CheckpointConfig): + """CheckpointConfig. + + The CheckpointConfig class is a subclass of ``sagemaker_core.shapes.CheckpointConfig`` + and allows the user to specify the checkpoint configuration for the training job. + + Parameters: + s3_uri (Optional[str]): + Path to Amazon S3 storage location for the Checkpoint data. If not specified, will default to the default artifact location for the training job. + ``s3://////`` + local_path (Optional[str]): + The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints. + """ + + s3_uri: Optional[str] = None + local_path: Optional[str] = "/opt/ml/checkpoints" diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 96078d1aeb..92709b84a1 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -181,7 +181,7 @@ class ModelTrainer(BaseModel): The output data configuration. This is used to specify the output data location for the training job. If not specified in the session, will default to - ``s3://///``. + s3://///. input_data_config (Optional[List[Union[Channel, InputData]]]): The input data config for the training job. Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI @@ -477,6 +477,20 @@ def model_post_init(self, __context: Any): ) logger.warning(f"Compute not provided. Using default:\n{self.compute}") + if self.compute.instance_type is None: + self.compute.instance_type = DEFAULT_INSTANCE_TYPE + logger.warning(f"Instance type not provided. Using default:\n{DEFAULT_INSTANCE_TYPE}") + if self.compute.instance_count is None: + self.compute.instance_count = 1 + logger.warning( + f"Instance count not provided. Using default:\n{self.compute.instance_count}" + ) + if self.compute.volume_size_in_gb is None: + self.compute.volume_size_in_gb = 30 + logger.warning( + f"Volume size not provided. Using default:\n{self.compute.volume_size_in_gb}" + ) + if self.stopping_condition is None: self.stopping_condition = StoppingCondition( max_runtime_in_seconds=3600, @@ -486,6 +500,12 @@ def model_post_init(self, __context: Any): logger.warning( f"StoppingCondition not provided. Using default:\n{self.stopping_condition}" ) + if self.stopping_condition.max_runtime_in_seconds is None: + self.stopping_condition.max_runtime_in_seconds = 3600 + logger.info( + "Max runtime not provided. Using default:\n" + f"{self.stopping_condition.max_runtime_in_seconds}" + ) if self.hyperparameters and isinstance(self.hyperparameters, str): if not os.path.exists(self.hyperparameters): @@ -511,23 +531,40 @@ def model_post_init(self, __context: Any): ) if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None: - session = self.sagemaker_session - base_job_name = self.base_job_name - self.output_data_config = OutputDataConfig( - s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}" - f"/{base_job_name}", - compression_type="GZIP", - kms_key_id=None, - ) - logger.warning( - f"OutputDataConfig not provided. Using default:\n{self.output_data_config}" - ) + if self.output_data_config is None: + session = self.sagemaker_session + base_job_name = self.base_job_name + self.output_data_config = OutputDataConfig( + s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}" + f"/{base_job_name}", + compression_type="GZIP", + kms_key_id=None, + ) + logger.warning( + f"OutputDataConfig not provided. Using default:\n{self.output_data_config}" + ) + if self.output_data_config.s3_output_path is None: + session = self.sagemaker_session + base_job_name = self.base_job_name + self.output_data_config.s3_output_path = ( + f"s3://{self._fetch_bucket_name_and_prefix(session)}/{base_job_name}" + ) + logger.warning( + f"OutputDataConfig s3_output_path not provided. Using default:\n" + f"{self.output_data_config.s3_output_path}" + ) + if self.output_data_config.compression_type is None: + self.output_data_config.compression_type = "GZIP" + logger.warning( + f"OutputDataConfig compression type not provided. Using default:\n" + f"{self.output_data_config.compression_type}" + ) - # TODO: Autodetect which image to use if source_code is provided if self.training_image: logger.info(f"Training image URI: {self.training_image}") - def _fetch_bucket_name_and_prefix(self, session: Session) -> str: + @staticmethod + def _fetch_bucket_name_and_prefix(session: Session) -> str: """Helper function to get the bucket name with the corresponding prefix if applicable""" if session.default_bucket_prefix is not None: return f"{session.default_bucket()}/{session.default_bucket_prefix}" @@ -558,16 +595,28 @@ def train( """ self._populate_intelligent_defaults() current_training_job_name = _get_unique_name(self.base_job_name) - input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input" - if input_data_config: + default_artifact_path = f"{self.base_job_name}/{current_training_job_name}" + input_data_key_prefix = f"{default_artifact_path}/input" + if input_data_config and self.input_data_config: self.input_data_config = input_data_config + # Add missing input data channels to the existing input_data_config + final_input_channel_names = {i.channel_name for i in input_data_config} + for input_data in self.input_data_config: + if input_data.channel_name not in final_input_channel_names: + input_data_config.append(input_data) + + self.input_data_config = input_data_config or self.input_data_config or [] - input_data_config = [] if self.input_data_config: input_data_config = self._get_input_data_config( self.input_data_config, input_data_key_prefix ) + if self.checkpoint_config and not self.checkpoint_config.s3_uri: + self.checkpoint_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}" + if self._tensorboard_output_config and not self._tensorboard_output_config.s3_uri: + self._tensorboard_output_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}" + string_hyper_parameters = {} if self.hyperparameters: for hyper_parameter, value in self.hyperparameters.items(): From e265d787f5f6a46c171303b86e388097e12cd4a3 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Mon, 12 May 2025 13:42:05 -0700 Subject: [PATCH 02/17] format --- src/sagemaker/modules/configs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index 9abddce83c..89403fbd8f 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -41,7 +41,7 @@ InfraCheckConfig, RemoteDebugConfig, SessionChainingConfig, - InstanceGroup + InstanceGroup, ) from sagemaker.modules.utils import convert_unassigned_to_none From 405dfd921b784b1aeaab10e402e9e7c0beca6bf7 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Mon, 12 May 2025 16:14:30 -0700 Subject: [PATCH 03/17] add tests & update docs --- pyproject.toml | 3 +- src/sagemaker/modules/configs.py | 10 +- src/sagemaker/modules/train/model_trainer.py | 139 +++++++++++++++--- .../modules/train/test_model_trainer.py | 34 +++++ 4 files changed, 161 insertions(+), 25 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c6508f54ad..17dfab3571 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,8 @@ dependencies = [ "tblib>=1.7.0,<4", "tqdm", "urllib3>=1.26.8,<3.0.0", - "uvicorn" + "uvicorn", + "graphene>=3,<4" ] [project.scripts] diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index 89403fbd8f..5af5b049f7 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -258,8 +258,9 @@ class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig): Parameters: s3_output_path (Optional[str]): - Path to Amazon S3 storage location for TensorBoard output. If not specified, will default to the default artifact location for the training job. - ``s3://////`` + Path to Amazon S3 storage location for TensorBoard output. If not specified, will + default to + ``s3://////tensorboard-output`` local_path (Optional[str]): Path to local storage location for tensorBoard output. Defaults to /opt/ml/output/tensorboard. """ @@ -276,8 +277,9 @@ class CheckpointConfig(shapes.CheckpointConfig): Parameters: s3_uri (Optional[str]): - Path to Amazon S3 storage location for the Checkpoint data. If not specified, will default to the default artifact location for the training job. - ``s3://////`` + Path to Amazon S3 storage location for the Checkpoint data. If not specified, will + default to + ``s3://////checkpoints`` local_path (Optional[str]): The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints. """ diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 92709b84a1..b84904e7f8 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -181,7 +181,7 @@ class ModelTrainer(BaseModel): The output data configuration. This is used to specify the output data location for the training job. If not specified in the session, will default to - s3://///. + ``s3://///``. input_data_config (Optional[List[Union[Channel, InputData]]]): The input data config for the training job. Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI @@ -595,8 +595,7 @@ def train( """ self._populate_intelligent_defaults() current_training_job_name = _get_unique_name(self.base_job_name) - default_artifact_path = f"{self.base_job_name}/{current_training_job_name}" - input_data_key_prefix = f"{default_artifact_path}/input" + input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input" if input_data_config and self.input_data_config: self.input_data_config = input_data_config # Add missing input data channels to the existing input_data_config @@ -613,9 +612,15 @@ def train( ) if self.checkpoint_config and not self.checkpoint_config.s3_uri: - self.checkpoint_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}" - if self._tensorboard_output_config and not self._tensorboard_output_config.s3_uri: - self._tensorboard_output_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}" + self.checkpoint_config.s3_uri = ( + f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/" + f"{self.base_job_name}/{current_training_job_name}/checkpoints" + ) + if self._tensorboard_output_config and not self._tensorboard_output_config.s3_output_path: + self._tensorboard_output_config.s3_output_path = ( + f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/" + f"{self.base_job_name}" + ) string_hyper_parameters = {} if self.hyperparameters: @@ -646,7 +651,7 @@ def train( data_source=self.source_code.source_dir, key_prefix=input_data_key_prefix, ) - input_data_config.append(source_code_channel) + self.input_data_config.append(source_code_channel) self._prepare_train_script( tmp_dir=tmp_dir, @@ -667,7 +672,7 @@ def train( data_source=tmp_dir.name, key_prefix=input_data_key_prefix, ) - input_data_config.append(sm_drivers_channel) + self.input_data_config.append(sm_drivers_channel) # If source_code is provided, we will always use # the default container entrypoint and arguments @@ -970,10 +975,43 @@ def from_recipe( ) -> "ModelTrainer": """Create a ModelTrainer from a training recipe. + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import Compute + + recipe_overrides = { + "run": { + "results_dir": "/opt/ml/model", + }, + "model": { + "data": { + "use_synthetic_data": True + } + } + } + + compute = Compute( + instance_type="ml.p5.48xlarge", + keep_alive_period_in_seconds=3600 + ) + + model_trainer = ModelTrainer.from_recipe( + training_recipe="fine-tuning/deepseek/hf_deepseek_r1_distilled_llama_8b_seq8k_gpu_fine_tuning", + recipe_overrides=recipe_overrides, + compute=compute, + ) + + model_trainer.train(wait=False) + + Args: training_recipe (str): The training recipe to use for training the model. This must be the name of a sagemaker training recipe or a path to a local training recipe .yaml file. + For available training recipes, see: https://github.com/aws/sagemaker-hyperpod-recipes/ compute (Compute): The compute configuration. This is used to specify the compute resources for the training job. If not specified, will default to 1 instance of ml.m5.xlarge. @@ -1081,55 +1119,116 @@ def from_recipe( return model_trainer def with_tensorboard_output_config( - self, tensorboard_output_config: TensorBoardOutputConfig + self, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None ) -> "ModelTrainer": """Set the TensorBoard output configuration. + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + + model_trainer = ModelTrainer( + ... + ).with_tensorboard_output_config() + Args: tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig): The TensorBoard output configuration. """ - self._tensorboard_output_config = tensorboard_output_config + self._tensorboard_output_config = tensorboard_output_config or TensorBoardOutputConfig() return self def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": """Set the retry strategy for the training job. + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + from sagemaker.modules.configs import RetryStrategy + + retry_strategy = RetryStrategy(maximum_retry_attempts=3) + + model_trainer = ModelTrainer( + ... + ).with_retry_strategy(retry_strategy) + Args: - retry_strategy (RetryStrategy): + retry_strategy (sagemaker.modules.configs.RetryStrategy): The retry strategy for the training job. """ self._retry_strategy = retry_strategy return self - def with_infra_check_config(self, infra_check_config: InfraCheckConfig) -> "ModelTrainer": + def with_infra_check_config( + self, infra_check_config: Optional[InfraCheckConfig] = None + ) -> "ModelTrainer": """Set the infra check configuration for the training job. + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + + model_trainer = ModelTrainer( + ... + ).with_infra_check_config() + Args: - infra_check_config (InfraCheckConfig): + infra_check_config (sagemaker.modules.configs.InfraCheckConfig): The infra check configuration for the training job. """ - self._infra_check_config = infra_check_config + self._infra_check_config = infra_check_config or InfraCheckConfig(enable_infra_check=True) return self def with_session_chaining_config( - self, session_chaining_config: SessionChainingConfig + self, session_chaining_config: Optional[SessionChainingConfig] = None ) -> "ModelTrainer": """Set the session chaining configuration for the training job. + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + + model_trainer = ModelTrainer( + ... + ).with_session_chaining_config() + Args: - session_chaining_config (SessionChainingConfig): + session_chaining_config (sagemaker.modules.configs.SessionChainingConfig): The session chaining configuration for the training job. """ - self._session_chaining_config = session_chaining_config + self._session_chaining_config = session_chaining_config or SessionChainingConfig( + enable_session_tag_chaining=True + ) return self - def with_remote_debug_config(self, remote_debug_config: RemoteDebugConfig) -> "ModelTrainer": + def with_remote_debug_config( + self, remote_debug_config: Optional[RemoteDebugConfig] = None + ) -> "ModelTrainer": """Set the remote debug configuration for the training job. + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + + model_trainer = ModelTrainer( + ... + ).with_remote_debug_config() + Args: - remote_debug_config (RemoteDebugConfig): + remote_debug_config (sagemaker.modules.configs.RemoteDebugConfig): The remote debug configuration for the training job. """ - self._remote_debug_config = remote_debug_config + self._remote_debug_config = remote_debug_config or RemoteDebugConfig( + enable_remote_debug=True + ) return self diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 6001c5db36..489d35a67e 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -1228,3 +1228,37 @@ def test_hyperparameters_invalid(mock_exists, modules_session): compute=DEFAULT_COMPUTE_CONFIG, hyperparameters="hyperparameters.yaml", ) + + +@patch("sagemaker.modules.train.model_trainer._get_unique_name") +@patch("sagemaker.modules.train.model_trainer.TrainingJob") +def test_model_trainer_default_paths(mock_training_job, mock_unique_name, modules_session): + def mock_upload_data(path, bucket, key_prefix): + return f"s3://{bucket}/{key_prefix}" + + unique_name = "base-job-0123456789" + base_name = "base-job" + + modules_session.upload_data.side_effect = mock_upload_data + mock_unique_name.return_value = unique_name + + model_trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + sagemaker_session=modules_session, + checkpoint_config=CheckpointConfig(), + base_job_name=base_name, + ).with_tensorboard_output_config(TensorBoardOutputConfig()) + model_trainer.train() + + _, kwargs = mock_training_job.create.call_args + + default_base_path = f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{base_name}" + + assert kwargs["output_data_config"].s3_output_path == default_base_path + assert kwargs["output_data_config"].compression_type == "GZIP" + + assert kwargs["checkpoint_config"].s3_uri == f"{default_base_path}/{unique_name}/checkpoints" + assert kwargs["checkpoint_config"].local_path == "/opt/ml/checkpoints" + + assert kwargs["tensor_board_output_config"].s3_output_path == default_base_path + assert kwargs["tensor_board_output_config"].local_path == "/opt/ml/output/tensorboard" From a50e8681c90f2af6ebf0a6fd1f6ec34c67ede240 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Mon, 12 May 2025 16:33:23 -0700 Subject: [PATCH 04/17] fix docstyle --- src/sagemaker/modules/train/model_trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index b84904e7f8..1e6cdc537f 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -972,7 +972,7 @@ def from_recipe( sagemaker_session: Optional[Session] = None, role: Optional[str] = None, base_job_name: Optional[str] = None, - ) -> "ModelTrainer": + ) -> "ModelTrainer": # noqa: D412 """Create a ModelTrainer from a training recipe. Example: @@ -1120,7 +1120,7 @@ def from_recipe( def with_tensorboard_output_config( self, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None - ) -> "ModelTrainer": + ) -> "ModelTrainer": # noqa: D412 """Set the TensorBoard output configuration. Example: @@ -1140,7 +1140,7 @@ def with_tensorboard_output_config( self._tensorboard_output_config = tensorboard_output_config or TensorBoardOutputConfig() return self - def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": + def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": # noqa: D412 """Set the retry strategy for the training job. Example: @@ -1165,7 +1165,7 @@ def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": def with_infra_check_config( self, infra_check_config: Optional[InfraCheckConfig] = None - ) -> "ModelTrainer": + ) -> "ModelTrainer": # noqa: D412 """Set the infra check configuration for the training job. Example: @@ -1187,7 +1187,7 @@ def with_infra_check_config( def with_session_chaining_config( self, session_chaining_config: Optional[SessionChainingConfig] = None - ) -> "ModelTrainer": + ) -> "ModelTrainer": # noqa: D412 """Set the session chaining configuration for the training job. Example: @@ -1211,7 +1211,7 @@ def with_session_chaining_config( def with_remote_debug_config( self, remote_debug_config: Optional[RemoteDebugConfig] = None - ) -> "ModelTrainer": + ) -> "ModelTrainer": # noqa: D412 """Set the remote debug configuration for the training job. Example: From 6c750be5186a462c30baeb20643b569c2ef1d625 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Mon, 12 May 2025 16:44:08 -0700 Subject: [PATCH 05/17] fix input_data_config --- src/sagemaker/modules/configs.py | 4 ++-- src/sagemaker/modules/train/model_trainer.py | 20 +++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index 5af5b049f7..e2855707e8 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -240,8 +240,8 @@ class OutputDataConfig(shapes.OutputDataConfig): SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side encryption. compression_type (Optional[str]): - The model output compression type. Select None to output an uncompressed model, - recommended for large model outputs. Defaults to gzip. + The model output compression type. Select `NONE` to output an uncompressed model, + recommended for large model outputs. Defaults to `GZIP`. """ s3_output_path: Optional[str] = None diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 1e6cdc537f..b9545e1e8c 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -597,14 +597,16 @@ def train( current_training_job_name = _get_unique_name(self.base_job_name) input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input" if input_data_config and self.input_data_config: - self.input_data_config = input_data_config - # Add missing input data channels to the existing input_data_config - final_input_channel_names = {i.channel_name for i in input_data_config} - for input_data in self.input_data_config: - if input_data.channel_name not in final_input_channel_names: - input_data_config.append(input_data) - - self.input_data_config = input_data_config or self.input_data_config or [] + final_channels = { + input_data.channel_name: input_data for input_data in self.input_data_config + } + # Update with precedence on the input_data_config passed into the train method + final_channels.update( + {input_data.channel_name: input_data for input_data in input_data_config} + ) + self.input_data_config = list(final_channels.values()) + else: + self.input_data_config = input_data_config or self.input_data_config or [] if self.input_data_config: input_data_config = self._get_input_data_config( @@ -699,7 +701,7 @@ def train( training_job_name=current_training_job_name, algorithm_specification=algorithm_specification, hyper_parameters=string_hyper_parameters, - input_data_config=input_data_config, + input_data_config=self.input_data_config, resource_config=resource_config, vpc_config=vpc_config, # Public Instance Attributes From 152a50c08c4bd06e2a3c181d4046f89767405ebf Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Mon, 12 May 2025 17:13:56 -0700 Subject: [PATCH 06/17] fix use input_data_config parameter in train as authoritative source --- src/sagemaker/modules/train/model_trainer.py | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index b9545e1e8c..45f86a27cd 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -596,20 +596,11 @@ def train( self._populate_intelligent_defaults() current_training_job_name = _get_unique_name(self.base_job_name) input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input" - if input_data_config and self.input_data_config: - final_channels = { - input_data.channel_name: input_data for input_data in self.input_data_config - } - # Update with precedence on the input_data_config passed into the train method - final_channels.update( - {input_data.channel_name: input_data for input_data in input_data_config} - ) - self.input_data_config = list(final_channels.values()) - else: - self.input_data_config = input_data_config or self.input_data_config or [] + + self.input_data_config = input_data_config or self.input_data_config or [] if self.input_data_config: - input_data_config = self._get_input_data_config( + self.input_data_config = self._get_input_data_config( self.input_data_config, input_data_key_prefix ) From 0510d982b9d11c22d3568c5a71be1154610760c8 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 13 May 2025 10:28:12 -0700 Subject: [PATCH 07/17] fix tests --- src/sagemaker/modules/train/model_trainer.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 45f86a27cd..4252590d71 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -25,7 +25,12 @@ from sagemaker_core.main import resources from sagemaker_core.resources import TrainingJob -from sagemaker_core.shapes import AlgorithmSpecification +from sagemaker_core.shapes import ( + AlgorithmSpecification, + OutputDataConfig, + CheckpointConfig, + TensorBoardOutputConfig +) from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call @@ -52,7 +57,6 @@ Compute, StoppingCondition, RetryStrategy, - OutputDataConfig, SourceCode, TrainingImageConfig, Channel, @@ -64,8 +68,6 @@ InfraCheckConfig, RemoteDebugConfig, SessionChainingConfig, - TensorBoardOutputConfig, - CheckpointConfig, InputData, ) @@ -737,7 +739,7 @@ def train( sagemaker_session=self.sagemaker_session, container_entrypoint=algorithm_specification.container_entrypoint, container_arguments=algorithm_specification.container_arguments, - input_data_config=input_data_config, + input_data_config=self.input_data_config, hyper_parameters=string_hyper_parameters, environment=self.environment, ) From f5791cea95b363dbb87a3677e07d5c4c8cbe9b3c Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 13 May 2025 10:42:38 -0700 Subject: [PATCH 08/17] format --- src/sagemaker/modules/train/model_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 4252590d71..1caf924c7e 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -29,7 +29,7 @@ AlgorithmSpecification, OutputDataConfig, CheckpointConfig, - TensorBoardOutputConfig + TensorBoardOutputConfig, ) from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call From 088d94941b94c7138d95f9a73739c974ba180221 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 13 May 2025 11:15:18 -0700 Subject: [PATCH 09/17] update checkpoint config --- src/sagemaker/modules/train/model_trainer.py | 27 ++++++++++++++++++- .../modules/train/test_model_trainer.py | 16 ++++++----- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 1caf924c7e..209c0ba569 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -53,6 +53,7 @@ from sagemaker.utils import resolve_value_from_config from sagemaker.modules import Session, get_execution_role +from sagemaker.modules import configs from sagemaker.modules.configs import ( Compute, StoppingCondition, @@ -1132,7 +1133,9 @@ def with_tensorboard_output_config( tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig): The TensorBoard output configuration. """ - self._tensorboard_output_config = tensorboard_output_config or TensorBoardOutputConfig() + self._tensorboard_output_config = ( + tensorboard_output_config or configs.TensorBoardOutputConfig() + ) return self def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": # noqa: D412 @@ -1227,3 +1230,25 @@ def with_remote_debug_config( enable_remote_debug=True ) return self + + def with_checkpoint_config( + self, checkpoint_config: Optional[CheckpointConfig] = None + ) -> "ModelTrainer": + """Set the checkpoint configuration for the training job. + + Example: + + .. code:: python + + from sagemaker.modules.train import ModelTrainer + + model_trainer = ModelTrainer( + ... + ).with_checkpoint_config() + + Args: + checkpoint_config (sagemaker.modules.configs.CheckpointConfig): + The checkpoint configuration for the training job. + """ + self.checkpoint_config = checkpoint_config or configs.CheckpointConfig() + return self diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 489d35a67e..e06b30f3e1 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -1242,12 +1242,16 @@ def mock_upload_data(path, bucket, key_prefix): modules_session.upload_data.side_effect = mock_upload_data mock_unique_name.return_value = unique_name - model_trainer = ModelTrainer( - training_image=DEFAULT_IMAGE, - sagemaker_session=modules_session, - checkpoint_config=CheckpointConfig(), - base_job_name=base_name, - ).with_tensorboard_output_config(TensorBoardOutputConfig()) + model_trainer = ( + ModelTrainer( + training_image=DEFAULT_IMAGE, + sagemaker_session=modules_session, + base_job_name=base_name, + ) + .with_tensorboard_output_config() + .with_checkpoint_config() + ) + model_trainer.train() _, kwargs = mock_training_job.create.call_args From a9a68badc1f945e73c2c72dd38ddf427b5a36c77 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 13 May 2025 11:58:41 -0700 Subject: [PATCH 10/17] docstyle --- src/sagemaker/modules/train/model_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 209c0ba569..8b4e08e51b 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -1233,7 +1233,7 @@ def with_remote_debug_config( def with_checkpoint_config( self, checkpoint_config: Optional[CheckpointConfig] = None - ) -> "ModelTrainer": + ) -> "ModelTrainer": # noqa: D412 """Set the checkpoint configuration for the training job. Example: From 5f6dd1a25afa366635eceb2599aaa977340ecda8 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 13 May 2025 13:28:34 -0700 Subject: [PATCH 11/17] make config creation backwards compatible --- src/sagemaker/modules/train/model_trainer.py | 28 +++++++++----------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 8b4e08e51b..6b424b24e8 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -25,11 +25,9 @@ from sagemaker_core.main import resources from sagemaker_core.resources import TrainingJob +from sagemaker_core import shapes from sagemaker_core.shapes import ( - AlgorithmSpecification, - OutputDataConfig, - CheckpointConfig, - TensorBoardOutputConfig, + AlgorithmSpecification ) from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call @@ -224,9 +222,9 @@ class ModelTrainer(BaseModel): training_image: Optional[str] = None training_image_config: Optional[TrainingImageConfig] = None algorithm_name: Optional[str] = None - output_data_config: Optional[OutputDataConfig] = None + output_data_config: Optional[shapes.OutputDataConfig] = None input_data_config: Optional[List[Union[Channel, InputData]]] = None - checkpoint_config: Optional[CheckpointConfig] = None + checkpoint_config: Optional[shapes.CheckpointConfig] = None training_input_mode: Optional[str] = "File" environment: Optional[Dict[str, str]] = {} hyperparameters: Optional[Union[Dict[str, Any], str]] = {} @@ -237,7 +235,7 @@ class ModelTrainer(BaseModel): _latest_training_job: Optional[resources.TrainingJob] = PrivateAttr(default=None) # Private TrainingJob Parameters - _tensorboard_output_config: Optional[TensorBoardOutputConfig] = PrivateAttr(default=None) + _tensorboard_output_config: Optional[shapes.TensorBoardOutputConfig] = PrivateAttr(default=None) _retry_strategy: Optional[RetryStrategy] = PrivateAttr(default=None) _infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None) _session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None) @@ -268,8 +266,8 @@ class ModelTrainer(BaseModel): "networking": Networking, "stopping_condition": StoppingCondition, "training_image_config": TrainingImageConfig, - "output_data_config": OutputDataConfig, - "checkpoint_config": CheckpointConfig, + "output_data_config": configs.OutputDataConfig, + "checkpoint_config": configs.CheckpointConfig, } def _populate_intelligent_defaults(self): @@ -321,7 +319,7 @@ def _populate_intelligent_defaults_from_training_job_space(self): config_path=TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH ) if default_output_data_config: - self.output_data_config = OutputDataConfig( + self.output_data_config = configs.OutputDataConfig( **self._convert_keys_to_snake(default_output_data_config) ) @@ -537,7 +535,7 @@ def model_post_init(self, __context: Any): if self.output_data_config is None: session = self.sagemaker_session base_job_name = self.base_job_name - self.output_data_config = OutputDataConfig( + self.output_data_config = configs.OutputDataConfig( s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}" f"/{base_job_name}", compression_type="GZIP", @@ -959,9 +957,9 @@ def from_recipe( requirements: Optional[str] = None, training_image: Optional[str] = None, training_image_config: Optional[TrainingImageConfig] = None, - output_data_config: Optional[OutputDataConfig] = None, + output_data_config: Optional[shapes.OutputDataConfig] = None, input_data_config: Optional[List[Union[Channel, InputData]]] = None, - checkpoint_config: Optional[CheckpointConfig] = None, + checkpoint_config: Optional[shapes.CheckpointConfig] = None, training_input_mode: Optional[str] = "File", environment: Optional[Dict[str, str]] = None, tags: Optional[List[Tag]] = None, @@ -1115,7 +1113,7 @@ def from_recipe( return model_trainer def with_tensorboard_output_config( - self, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None + self, tensorboard_output_config: Optional[shapes.TensorBoardOutputConfig] = None ) -> "ModelTrainer": # noqa: D412 """Set the TensorBoard output configuration. @@ -1232,7 +1230,7 @@ def with_remote_debug_config( return self def with_checkpoint_config( - self, checkpoint_config: Optional[CheckpointConfig] = None + self, checkpoint_config: Optional[shapes.CheckpointConfig] = None ) -> "ModelTrainer": # noqa: D412 """Set the checkpoint configuration for the training job. From 716e71b486d2f9ebc5c69f5698ff1ef07885ef86 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Tue, 13 May 2025 13:37:32 -0700 Subject: [PATCH 12/17] format --- src/sagemaker/modules/train/model_trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index 6b424b24e8..e7ac5cf57b 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -26,9 +26,7 @@ from sagemaker_core.main import resources from sagemaker_core.resources import TrainingJob from sagemaker_core import shapes -from sagemaker_core.shapes import ( - AlgorithmSpecification -) +from sagemaker_core.shapes import AlgorithmSpecification from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call From d301784f194a65db9d833a44d434d72faf8a2364 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 14 May 2025 11:28:08 -0700 Subject: [PATCH 13/17] fix condition --- src/sagemaker/modules/train/model_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/modules/train/model_trainer.py b/src/sagemaker/modules/train/model_trainer.py index e7ac5cf57b..58ae724074 100644 --- a/src/sagemaker/modules/train/model_trainer.py +++ b/src/sagemaker/modules/train/model_trainer.py @@ -529,7 +529,7 @@ def model_post_init(self, __context: Any): "Must be a valid JSON or YAML file." ) - if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB and self.output_data_config is None: + if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB: if self.output_data_config is None: session = self.sagemaker_session base_job_name = self.base_job_name From f32a18614bbdaad1b53a7ecc98d3ba67da19d5ad Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 14 May 2025 12:47:40 -0700 Subject: [PATCH 14/17] fix Compute and Networking config when attributes are None --- src/sagemaker/modules/configs.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index e2855707e8..808b220311 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -150,8 +150,12 @@ def _to_resource_config(self) -> shapes.ResourceConfig: compute_config_dict = self.model_dump() resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys()) filtered_dict = { - k: v for k, v in compute_config_dict.items() if k in resource_config_fields + k: v + for k, v in compute_config_dict.items() + if k in resource_config_fields and v is not None } + if not filtered_dict: + return None return shapes.ResourceConfig(**filtered_dict) @@ -193,10 +197,14 @@ def _model_validator(self) -> "Networking": def _to_vpc_config(self) -> shapes.VpcConfig: """Convert to a sagemaker_core.shapes.VpcConfig object.""" compute_config_dict = self.model_dump() - resource_config_fields = set(shapes.VpcConfig.__annotations__.keys()) + vpc_config_fields = set(shapes.VpcConfig.__annotations__.keys()) filtered_dict = { - k: v for k, v in compute_config_dict.items() if k in resource_config_fields + k: v + for k, v in compute_config_dict.items() + if k in vpc_config_fields and v is not None } + if not filtered_dict: + return None return shapes.VpcConfig(**filtered_dict) From 862fecd26b4b5f3e5bbe57800584d962f6a10785 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 14 May 2025 12:48:33 -0700 Subject: [PATCH 15/17] format --- src/sagemaker/modules/configs.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/sagemaker/modules/configs.py b/src/sagemaker/modules/configs.py index 808b220311..3739c73c5d 100644 --- a/src/sagemaker/modules/configs.py +++ b/src/sagemaker/modules/configs.py @@ -199,9 +199,7 @@ def _to_vpc_config(self) -> shapes.VpcConfig: compute_config_dict = self.model_dump() vpc_config_fields = set(shapes.VpcConfig.__annotations__.keys()) filtered_dict = { - k: v - for k, v in compute_config_dict.items() - if k in vpc_config_fields and v is not None + k: v for k, v in compute_config_dict.items() if k in vpc_config_fields and v is not None } if not filtered_dict: return None From a3485ae5ff892cb5d3088c892dfe5792c064e015 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 14 May 2025 13:44:20 -0700 Subject: [PATCH 16/17] fix --- tests/unit/sagemaker/modules/train/test_model_trainer.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index e06b30f3e1..2efa5a4235 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -326,11 +326,7 @@ def test_train_with_intelligent_defaults_training_job_space( resource_config=ResourceConfig( volume_size_in_gb=30, instance_type="ml.m5.xlarge", - instance_count=1, - volume_kms_key_id=None, - keep_alive_period_in_seconds=None, - instance_groups=None, - training_plan_arn=None, + instance_count=1 ), vpc_config=None, session=ANY, @@ -870,8 +866,6 @@ def mock_upload_data(path, bucket, key_prefix): volume_size_in_gb=compute.volume_size_in_gb, volume_kms_key_id=compute.volume_kms_key_id, keep_alive_period_in_seconds=compute.keep_alive_period_in_seconds, - instance_groups=None, - training_plan_arn=None, ), vpc_config=VpcConfig( security_group_ids=networking.security_group_ids, From 0f1e7132c9a2730e2518ecdfad79809f1fa8aab5 Mon Sep 17 00:00:00 2001 From: Erick Benitez-Ramos Date: Wed, 14 May 2025 14:09:59 -0700 Subject: [PATCH 17/17] format --- tests/unit/sagemaker/modules/train/test_model_trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unit/sagemaker/modules/train/test_model_trainer.py b/tests/unit/sagemaker/modules/train/test_model_trainer.py index 2efa5a4235..b1348b5ac9 100644 --- a/tests/unit/sagemaker/modules/train/test_model_trainer.py +++ b/tests/unit/sagemaker/modules/train/test_model_trainer.py @@ -324,9 +324,7 @@ def test_train_with_intelligent_defaults_training_job_space( hyper_parameters={}, input_data_config=[], resource_config=ResourceConfig( - volume_size_in_gb=30, - instance_type="ml.m5.xlarge", - instance_count=1 + volume_size_in_gb=30, instance_type="ml.m5.xlarge", instance_count=1 ), vpc_config=None, session=ANY,