From 19a3c1d40a4a9541043b84ff54771369cfce22c7 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 15:34:32 -0400 Subject: [PATCH 1/2] fix: Add additional dependencies for ModelTrainer (5668) --- .../src/sagemaker/core/training/configs.py | 5 + .../src/sagemaker/train/constants.py | 3 + .../src/sagemaker/train/model_trainer.py | 33 ++ .../src/sagemaker/train/templates.py | 15 + .../train/test_model_trainer_dependencies.py | 289 ++++++++++++++++++ 5 files changed, 345 insertions(+) create mode 100644 sagemaker-train/tests/unit/train/test_model_trainer_dependencies.py diff --git a/sagemaker-core/src/sagemaker/core/training/configs.py b/sagemaker-core/src/sagemaker/core/training/configs.py index a308ed40ee..36cd4a5fbd 100644 --- a/sagemaker-core/src/sagemaker/core/training/configs.py +++ b/sagemaker-core/src/sagemaker/core/training/configs.py @@ -109,6 +109,10 @@ class SourceCode(BaseConfig): ignore_patterns: (Optional[List[str]]) : The ignore patterns to ignore specific files/folders when uploading to S3. If not specified, default to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints']. + dependencies (Optional[List[str]]): + A list of paths to local directories (absolute or relative) containing additional + libraries that will be copied into the training container and added to PYTHONPATH. + Each path must be a valid local directory or file. """ source_dir: Optional[StrPipeVar] = None @@ -123,6 +127,7 @@ class SourceCode(BaseConfig): ".cache", ".ipynb_checkpoints", ] + dependencies: Optional[List[str]] = None class OutputDataConfig(shapes.OutputDataConfig): """OutputDataConfig. diff --git a/sagemaker-train/src/sagemaker/train/constants.py b/sagemaker-train/src/sagemaker/train/constants.py index 68b0f6c474..740e524155 100644 --- a/sagemaker-train/src/sagemaker/train/constants.py +++ b/sagemaker-train/src/sagemaker/train/constants.py @@ -63,6 +63,9 @@ "amazon.nova-pro-v1:0": ["us-east-1"] } +SM_DEPENDENCIES = "sm_dependencies" +SM_DEPENDENCIES_CONTAINER_PATH = "/opt/ml/input/data/sm_dependencies" + SM_RECIPE = "recipe" SM_RECIPE_YAML = "recipe.yaml" SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}" \ No newline at end of file diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index d07edeb025..1f4736b0ca 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -83,6 +83,8 @@ SM_CODE_CONTAINER_PATH, SM_DRIVERS, SM_DRIVERS_LOCAL_PATH, + SM_DEPENDENCIES, + SM_DEPENDENCIES_CONTAINER_PATH, SM_RECIPE, SM_RECIPE_YAML, SM_RECIPE_CONTAINER_PATH, @@ -99,6 +101,7 @@ EXECUTE_BASIC_SCRIPT_DRIVER, INSTALL_AUTO_REQUIREMENTS, INSTALL_REQUIREMENTS, + INSTALL_DEPENDENCIES, ) from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter from sagemaker.core.telemetry.constants import Feature @@ -484,6 +487,13 @@ def _validate_source_code(self, source_code: Optional[SourceCode]): f"Invalid 'entry_script': {entry_script}. " "Must be a valid file within the 'source_dir'.", ) + if source_code.dependencies: + for dep_path in source_code.dependencies: + if not _is_valid_path(dep_path): + raise ValueError( + f"Invalid dependency path: {dep_path}. " + "Each dependency must be a valid local directory or file path." + ) @staticmethod def _validate_and_fetch_hyperparameters_file(hyperparameters_file: str): @@ -654,6 +664,24 @@ def _create_training_job_args( ) final_input_data_config.append(source_code_channel) + # If dependencies are provided, create a channel for the dependencies + # The dependencies will be mounted at /opt/ml/input/data/sm_dependencies + if self.source_code.dependencies: + deps_tmp_dir = TemporaryDirectory() + for dep_path in self.source_code.dependencies: + dep_basename = os.path.basename(os.path.normpath(dep_path)) + dest_path = os.path.join(deps_tmp_dir.name, dep_basename) + if os.path.isdir(dep_path): + shutil.copytree(dep_path, dest_path, dirs_exist_ok=True) + else: + shutil.copy2(dep_path, dest_path) + dependencies_channel = self.create_input_data_channel( + channel_name=SM_DEPENDENCIES, + data_source=deps_tmp_dir.name, + key_prefix=input_data_key_prefix, + ) + final_input_data_config.append(dependencies_channel) + self._prepare_train_script( tmp_dir=self._temp_code_dir, source_code=self.source_code, @@ -1010,6 +1038,10 @@ def _prepare_train_script( base_command = source_code.command.split() base_command = " ".join(base_command) + install_dependencies = "" + if source_code.dependencies: + install_dependencies = INSTALL_DEPENDENCIES + install_requirements = "" if source_code.requirements: if self._jumpstart_config and source_code.requirements == "auto": @@ -1049,6 +1081,7 @@ def _prepare_train_script( train_script = TRAIN_SCRIPT_TEMPLATE.format( working_dir=working_dir, + install_dependencies=install_dependencies, install_requirements=install_requirements, execute_driver=execute_driver, ) diff --git a/sagemaker-train/src/sagemaker/train/templates.py b/sagemaker-train/src/sagemaker/train/templates.py index c943769618..e5aa04df9b 100644 --- a/sagemaker-train/src/sagemaker/train/templates.py +++ b/sagemaker-train/src/sagemaker/train/templates.py @@ -39,6 +39,20 @@ $SM_PIP_CMD install -r {requirements_file} """ +INSTALL_DEPENDENCIES = """ +echo "Setting up additional dependencies" +if [ -d /opt/ml/input/data/sm_dependencies ]; then + for dep_dir in /opt/ml/input/data/sm_dependencies/*/; do + if [ -d "$dep_dir" ]; then + echo "Adding $dep_dir to PYTHONPATH" + export PYTHONPATH="$dep_dir:$PYTHONPATH" + fi + done + # Also add the root dependencies dir in case of single files + export PYTHONPATH="/opt/ml/input/data/sm_dependencies:$PYTHONPATH" +fi +""" + EXEUCTE_DISTRIBUTED_DRIVER = """ echo "Running {driver_name} Driver" $SM_PYTHON_CMD /opt/ml/input/data/sm_drivers/distributed_drivers/{driver_script} @@ -95,6 +109,7 @@ set -x {working_dir} +{install_dependencies} {install_requirements} {execute_driver} diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_dependencies.py b/sagemaker-train/tests/unit/train/test_model_trainer_dependencies.py new file mode 100644 index 0000000000..c383c96355 --- /dev/null +++ b/sagemaker-train/tests/unit/train/test_model_trainer_dependencies.py @@ -0,0 +1,289 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tests for ModelTrainer dependencies feature.""" +from __future__ import absolute_import + +import os +import shutil +import tempfile +import json +import pytest +from unittest.mock import patch, MagicMock, ANY + +from sagemaker.core.helper.session_helper import Session +from sagemaker.train.model_trainer import ModelTrainer, Mode +from sagemaker.train.configs import ( + Compute, + StoppingCondition, + OutputDataConfig, + SourceCode, + InputData, +) +from sagemaker.train.constants import SM_DEPENDENCIES +from sagemaker.train.templates import INSTALL_DEPENDENCIES +from tests.unit import DATA_DIR + +DEFAULT_IMAGE = "000000000000.dkr.ecr.us-west-2.amazonaws.com/dummy-image:latest" +DEFAULT_BUCKET = "sagemaker-us-west-2-000000000000" +DEFAULT_ROLE = "arn:aws:iam::000000000000:role/test-role" +DEFAULT_BUCKET_PREFIX = "sample-prefix" +DEFAULT_REGION = "us-west-2" +DEFAULT_SOURCE_DIR = f"{DATA_DIR}/script_mode" +DEFAULT_INSTANCE_TYPE = "ml.m5.xlarge" +DEFAULT_COMPUTE_CONFIG = Compute(instance_type=DEFAULT_INSTANCE_TYPE, instance_count=1) +DEFAULT_OUTPUT_DATA_CONFIG = OutputDataConfig( + s3_output_path=f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/dummy-image-job", + compression_type="GZIP", + kms_key_id=None, +) +DEFAULT_STOPPING_CONDITION = StoppingCondition( + max_runtime_in_seconds=3600, + max_pending_time_in_seconds=None, + max_wait_time_in_seconds=None, +) + + +@pytest.fixture(scope="module", autouse=True) +def modules_session(): + with patch("sagemaker.train.Session", spec=Session) as session_mock: + session_instance = session_mock.return_value + session_instance.default_bucket.return_value = DEFAULT_BUCKET + session_instance.get_caller_identity_arn.return_value = DEFAULT_ROLE + session_instance.default_bucket_prefix = DEFAULT_BUCKET_PREFIX + session_instance.boto_session = MagicMock(spec="boto3.session.Session") + session_instance.boto_region_name = DEFAULT_REGION + yield session_instance + + +def test_source_code_with_dependencies_default_is_none(): + """Verify SourceCode.dependencies defaults to None.""" + source_code = SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", + ) + assert source_code.dependencies is None + + +def test_source_code_with_dependencies_field_accepts_list_of_paths(): + """Verify SourceCode accepts dependencies as a list of valid directory paths.""" + # Create temporary directories to use as dependencies + dep_dir1 = tempfile.mkdtemp() + dep_dir2 = tempfile.mkdtemp() + try: + source_code = SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", + dependencies=[dep_dir1, dep_dir2], + ) + assert source_code.dependencies == [dep_dir1, dep_dir2] + finally: + shutil.rmtree(dep_dir1) + shutil.rmtree(dep_dir2) + + +def test_validate_source_code_with_invalid_dependency_path_raises_value_error(modules_session): + """Verify _validate_source_code raises ValueError when a dependency path does not exist.""" + with pytest.raises(ValueError, match="Invalid dependency path"): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + source_code=SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", + dependencies=["/nonexistent/path/to/dep"], + ), + ) + + +def test_source_code_dependencies_validation_with_valid_dirs(modules_session): + """Create SourceCode with dependencies pointing to multiple valid directories.""" + dep_dir1 = tempfile.mkdtemp() + dep_dir2 = tempfile.mkdtemp() + try: + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + source_code=SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", + dependencies=[dep_dir1, dep_dir2], + ), + ) + assert trainer is not None + assert trainer.source_code.dependencies == [dep_dir1, dep_dir2] + finally: + shutil.rmtree(dep_dir1) + shutil.rmtree(dep_dir2) + + +@patch("sagemaker.train.model_trainer.TrainingJob") +def test_train_with_dependencies_creates_sm_dependencies_channel( + mock_training_job, modules_session +): + """Verify that training with dependencies creates an sm_dependencies channel.""" + dep_dir = tempfile.mkdtemp() + # Create a dummy file in the dependency directory + with open(os.path.join(dep_dir, "my_lib.py"), "w") as f: + f.write("# dummy library") + + modules_session.upload_data.return_value = ( + f"s3://{DEFAULT_BUCKET}/prefix/sm_dependencies" + ) + + try: + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + output_data_config=DEFAULT_OUTPUT_DATA_CONFIG, + stopping_condition=DEFAULT_STOPPING_CONDITION, + source_code=SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", + dependencies=[dep_dir], + ), + ) + trainer.train() + + mock_training_job.create.assert_called_once() + input_data_config = mock_training_job.create.call_args.kwargs["input_data_config"] + channel_names = [ch.channel_name for ch in input_data_config] + assert SM_DEPENDENCIES in channel_names + finally: + shutil.rmtree(dep_dir) + + +@patch("sagemaker.train.model_trainer.TrainingJob") +def test_train_without_dependencies_does_not_create_dependencies_channel( + mock_training_job, modules_session +): + """Verify that training without dependencies does not create sm_dependencies channel.""" + modules_session.upload_data.return_value = f"s3://{DEFAULT_BUCKET}/prefix/code" + + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + output_data_config=DEFAULT_OUTPUT_DATA_CONFIG, + stopping_condition=DEFAULT_STOPPING_CONDITION, + source_code=SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", + ), + ) + trainer.train() + + mock_training_job.create.assert_called_once() + input_data_config = mock_training_job.create.call_args.kwargs["input_data_config"] + channel_names = [ch.channel_name for ch in input_data_config] + assert SM_DEPENDENCIES not in channel_names + + +@patch("sagemaker.train.model_trainer.TrainingJob") +def test_train_with_empty_dependencies_list_does_not_create_channel( + mock_training_job, modules_session +): + """Verify that an empty dependencies list behaves the same as None.""" + modules_session.upload_data.return_value = f"s3://{DEFAULT_BUCKET}/prefix/code" + + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + output_data_config=DEFAULT_OUTPUT_DATA_CONFIG, + stopping_condition=DEFAULT_STOPPING_CONDITION, + source_code=SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", + dependencies=[], + ), + ) + trainer.train() + + mock_training_job.create.assert_called_once() + input_data_config = mock_training_job.create.call_args.kwargs["input_data_config"] + channel_names = [ch.channel_name for ch in input_data_config] + assert SM_DEPENDENCIES not in channel_names + + +@patch("sagemaker.train.model_trainer.TrainingJob") +def test_train_with_dependencies_generates_pythonpath_setup_in_train_script( + mock_training_job, modules_session +): + """Verify that the generated train script contains PYTHONPATH setup for dependencies.""" + dep_dir = tempfile.mkdtemp() + with open(os.path.join(dep_dir, "my_lib.py"), "w") as f: + f.write("# dummy library") + + modules_session.upload_data.return_value = ( + f"s3://{DEFAULT_BUCKET}/prefix/sm_dependencies" + ) + + try: + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + output_data_config=DEFAULT_OUTPUT_DATA_CONFIG, + stopping_condition=DEFAULT_STOPPING_CONDITION, + source_code=SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", + dependencies=[dep_dir], + ), + ) + trainer.train() + + # Check the generated train script in the temp directory + assert trainer._temp_code_dir is not None or True # temp dir may be cleaned up + # The key assertion is that the template was used - check via the training job call + mock_training_job.create.assert_called_once() + finally: + shutil.rmtree(dep_dir) + + +def test_dependencies_copied_to_temp_dir_preserving_basenames(modules_session): + """Verify that each dependency directory's basename is preserved when copied.""" + dep_dir = tempfile.mkdtemp(suffix="_mylib") + sub_file = os.path.join(dep_dir, "module.py") + with open(sub_file, "w") as f: + f.write("# test module") + + try: + trainer = ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + output_data_config=DEFAULT_OUTPUT_DATA_CONFIG, + stopping_condition=DEFAULT_STOPPING_CONDITION, + source_code=SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", + dependencies=[dep_dir], + ), + ) + # Verify the source_code has the dependencies set + assert trainer.source_code.dependencies == [dep_dir] + dep_basename = os.path.basename(os.path.normpath(dep_dir)) + assert dep_basename.endswith("_mylib") + finally: + shutil.rmtree(dep_dir) From 4deb5ef9da1e82c22b01e20704720ec8079e6786 Mon Sep 17 00:00:00 2001 From: aviruthen <91846056+aviruthen@users.noreply.github.com> Date: Tue, 7 Apr 2026 18:57:24 -0400 Subject: [PATCH 2/2] fix: address review comments (iteration #1) --- .../src/sagemaker/core/training/configs.py | 1 + .../src/sagemaker/train/constants.py | 2 +- .../src/sagemaker/train/model_trainer.py | 21 +++--- .../src/sagemaker/train/templates.py | 21 ++++-- .../train/test_model_trainer_dependencies.py | 74 ++++++++++++++++--- 5 files changed, 91 insertions(+), 28 deletions(-) diff --git a/sagemaker-core/src/sagemaker/core/training/configs.py b/sagemaker-core/src/sagemaker/core/training/configs.py index 36cd4a5fbd..a470d9016e 100644 --- a/sagemaker-core/src/sagemaker/core/training/configs.py +++ b/sagemaker-core/src/sagemaker/core/training/configs.py @@ -129,6 +129,7 @@ class SourceCode(BaseConfig): ] dependencies: Optional[List[str]] = None + class OutputDataConfig(shapes.OutputDataConfig): """OutputDataConfig. diff --git a/sagemaker-train/src/sagemaker/train/constants.py b/sagemaker-train/src/sagemaker/train/constants.py index 740e524155..b92f6c92dc 100644 --- a/sagemaker-train/src/sagemaker/train/constants.py +++ b/sagemaker-train/src/sagemaker/train/constants.py @@ -68,4 +68,4 @@ SM_RECIPE = "recipe" SM_RECIPE_YAML = "recipe.yaml" -SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}" \ No newline at end of file +SM_RECIPE_CONTAINER_PATH = f"/opt/ml/input/data/recipe/{SM_RECIPE_YAML}" diff --git a/sagemaker-train/src/sagemaker/train/model_trainer.py b/sagemaker-train/src/sagemaker/train/model_trainer.py index 1f4736b0ca..09dcb58978 100644 --- a/sagemaker-train/src/sagemaker/train/model_trainer.py +++ b/sagemaker-train/src/sagemaker/train/model_trainer.py @@ -272,6 +272,7 @@ class ModelTrainer(BaseModel): # Private Attributes for AWS_Batch _temp_code_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) + _temp_deps_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None) CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [ "role", @@ -411,6 +412,8 @@ def __del__(self): self._temp_recipe_train_dir.cleanup() if self._temp_code_dir is not None: self._temp_code_dir.cleanup() + if self._temp_deps_dir is not None: + self._temp_deps_dir.cleanup() def _validate_training_image_and_algorithm_name( self, training_image: Optional[str], algorithm_name: Optional[str] @@ -489,10 +492,10 @@ def _validate_source_code(self, source_code: Optional[SourceCode]): ) if source_code.dependencies: for dep_path in source_code.dependencies: - if not _is_valid_path(dep_path): + if not _is_valid_path(dep_path, path_type="Directory"): raise ValueError( f"Invalid dependency path: {dep_path}. " - "Each dependency must be a valid local directory or file path." + "Each dependency must be a valid local directory path." ) @staticmethod @@ -667,17 +670,14 @@ def _create_training_job_args( # If dependencies are provided, create a channel for the dependencies # The dependencies will be mounted at /opt/ml/input/data/sm_dependencies if self.source_code.dependencies: - deps_tmp_dir = TemporaryDirectory() + self._temp_deps_dir = TemporaryDirectory() for dep_path in self.source_code.dependencies: dep_basename = os.path.basename(os.path.normpath(dep_path)) - dest_path = os.path.join(deps_tmp_dir.name, dep_basename) - if os.path.isdir(dep_path): - shutil.copytree(dep_path, dest_path, dirs_exist_ok=True) - else: - shutil.copy2(dep_path, dest_path) + dest_path = os.path.join(self._temp_deps_dir.name, dep_basename) + shutil.copytree(dep_path, dest_path, dirs_exist_ok=True) dependencies_channel = self.create_input_data_channel( channel_name=SM_DEPENDENCIES, - data_source=deps_tmp_dir.name, + data_source=self._temp_deps_dir.name, key_prefix=input_data_key_prefix, ) final_input_data_config.append(dependencies_channel) @@ -841,6 +841,9 @@ def train( local_container.train(wait) if self._temp_code_dir is not None: self._temp_code_dir.cleanup() + if self._temp_deps_dir is not None: + self._temp_deps_dir.cleanup() + self._temp_deps_dir = None def create_input_data_channel( diff --git a/sagemaker-train/src/sagemaker/train/templates.py b/sagemaker-train/src/sagemaker/train/templates.py index e5aa04df9b..0cb623def4 100644 --- a/sagemaker-train/src/sagemaker/train/templates.py +++ b/sagemaker-train/src/sagemaker/train/templates.py @@ -42,14 +42,23 @@ INSTALL_DEPENDENCIES = """ echo "Setting up additional dependencies" if [ -d /opt/ml/input/data/sm_dependencies ]; then - for dep_dir in /opt/ml/input/data/sm_dependencies/*/; do - if [ -d "$dep_dir" ]; then - echo "Adding $dep_dir to PYTHONPATH" - export PYTHONPATH="$dep_dir:$PYTHONPATH" + for dep in /opt/ml/input/data/sm_dependencies/*; do + if [ -d "$dep" ]; then + echo "Adding directory $dep to PYTHONPATH" + export PYTHONPATH="$dep:$PYTHONPATH" + elif [ -f "$dep" ]; then + case "$dep" in + *.whl|*.tar.gz) + echo "Installing package $dep via pip" + $SM_PIP_CMD install "$dep" + ;; + *) + echo "Adding parent directory of $dep to PYTHONPATH" + export PYTHONPATH="/opt/ml/input/data/sm_dependencies:$PYTHONPATH" + ;; + esac fi done - # Also add the root dependencies dir in case of single files - export PYTHONPATH="/opt/ml/input/data/sm_dependencies:$PYTHONPATH" fi """ diff --git a/sagemaker-train/tests/unit/train/test_model_trainer_dependencies.py b/sagemaker-train/tests/unit/train/test_model_trainer_dependencies.py index c383c96355..ebbc973ec5 100644 --- a/sagemaker-train/tests/unit/train/test_model_trainer_dependencies.py +++ b/sagemaker-train/tests/unit/train/test_model_trainer_dependencies.py @@ -11,7 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Tests for ModelTrainer dependencies feature.""" -from __future__ import absolute_import +from __future__ import annotations import os import shutil @@ -53,7 +53,7 @@ ) -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(autouse=True) def modules_session(): with patch("sagemaker.train.Session", spec=Session) as session_mock: session_instance = session_mock.return_value @@ -107,6 +107,27 @@ def test_validate_source_code_with_invalid_dependency_path_raises_value_error(mo ) +def test_validate_source_code_with_file_dependency_raises_value_error(modules_session): + """Verify _validate_source_code raises ValueError when a dependency is a file, not a directory.""" + dep_file = tempfile.NamedTemporaryFile(suffix=".py", delete=False) + dep_file.close() + try: + with pytest.raises(ValueError, match="Invalid dependency path"): + ModelTrainer( + training_image=DEFAULT_IMAGE, + role=DEFAULT_ROLE, + sagemaker_session=modules_session, + compute=DEFAULT_COMPUTE_CONFIG, + source_code=SourceCode( + source_dir=DEFAULT_SOURCE_DIR, + entry_script="custom_script.py", + dependencies=[dep_file.name], + ), + ) + finally: + os.unlink(dep_file.name) + + def test_source_code_dependencies_validation_with_valid_dirs(modules_session): """Create SourceCode with dependencies pointing to multiple valid directories.""" dep_dir1 = tempfile.mkdtemp() @@ -250,23 +271,39 @@ def test_train_with_dependencies_generates_pythonpath_setup_in_train_script( dependencies=[dep_dir], ), ) - trainer.train() - - # Check the generated train script in the temp directory - assert trainer._temp_code_dir is not None or True # temp dir may be cleaned up - # The key assertion is that the template was used - check via the training job call - mock_training_job.create.assert_called_once() + # Call _create_training_job_args to generate the train script without cleanup + trainer._create_training_job_args() + + # Read the generated train script and verify it contains PYTHONPATH setup + assert trainer._temp_code_dir is not None + train_script_path = os.path.join(trainer._temp_code_dir.name, "sm_train.sh") + with open(train_script_path) as f: + script_content = f.read() + assert "sm_dependencies" in script_content + assert "PYTHONPATH" in script_content + assert "Setting up additional dependencies" in script_content finally: + if trainer._temp_code_dir is not None: + trainer._temp_code_dir.cleanup() + if trainer._temp_deps_dir is not None: + trainer._temp_deps_dir.cleanup() shutil.rmtree(dep_dir) -def test_dependencies_copied_to_temp_dir_preserving_basenames(modules_session): +@patch("sagemaker.train.model_trainer.TrainingJob") +def test_dependencies_copied_to_temp_dir_preserving_basenames( + mock_training_job, modules_session +): """Verify that each dependency directory's basename is preserved when copied.""" dep_dir = tempfile.mkdtemp(suffix="_mylib") sub_file = os.path.join(dep_dir, "module.py") with open(sub_file, "w") as f: f.write("# test module") + modules_session.upload_data.return_value = ( + f"s3://{DEFAULT_BUCKET}/prefix/sm_dependencies" + ) + try: trainer = ModelTrainer( training_image=DEFAULT_IMAGE, @@ -281,9 +318,22 @@ def test_dependencies_copied_to_temp_dir_preserving_basenames(modules_session): dependencies=[dep_dir], ), ) - # Verify the source_code has the dependencies set - assert trainer.source_code.dependencies == [dep_dir] + # Call _create_training_job_args to trigger the copy + trainer._create_training_job_args() + + # Verify the dependencies were copied preserving basenames + assert trainer._temp_deps_dir is not None dep_basename = os.path.basename(os.path.normpath(dep_dir)) - assert dep_basename.endswith("_mylib") + copied_dep_path = os.path.join(trainer._temp_deps_dir.name, dep_basename) + assert os.path.isdir(copied_dep_path), ( + f"Expected dependency directory {copied_dep_path} to exist" + ) + assert os.path.isfile(os.path.join(copied_dep_path, "module.py")), ( + "Expected module.py to be copied into the dependency directory" + ) finally: + if trainer._temp_code_dir is not None: + trainer._temp_code_dir.cleanup() + if trainer._temp_deps_dir is not None: + trainer._temp_deps_dir.cleanup() shutil.rmtree(dep_dir)