Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sagemaker-train/src/sagemaker/train/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,6 @@ def _is_nova_model_for_telemetry(self) -> bool:
return _is_nova_model(model_name) if model_name else False

@abstractmethod
def train(self, input_data_config: List[InputData], wait: bool = True, logs: bool = True):
def train(self, input_data_config: List[InputData], wait: bool = True, logs: bool = True, wait_timeout: Optional[int] = None):
"""Common training method that calls the specific implementation."""
pass
11 changes: 9 additions & 2 deletions sagemaker-train/src/sagemaker/train/dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,8 @@ def _process_hyperparameters(self):
def train(self,
training_dataset: Optional[Union[str, DataSet]] = None,
validation_dataset: Optional[Union[str, DataSet]] = None,
wait: bool = True):
wait: bool = True,
wait_timeout: Optional[int] = None):
"""Execute the DPO training job.

Parameters:
Expand All @@ -192,6 +193,9 @@ def train(self,
Can be an S3 URI, dataset ARN, or DataSet object.
wait (bool):
Whether to wait for the training job to complete. Defaults to True.
wait_timeout (Optional[int]):
Maximum time in seconds to wait for the training job to complete. Only used when wait=True.
If None, uses the default timeout from the wait utility.

Returns:
TrainingJob: The SageMaker training job object.
Expand Down Expand Up @@ -276,7 +280,10 @@ def train(self,
from sagemaker.train.common_utils.trainer_wait import wait as _wait
from sagemaker.core.utils.exceptions import TimeoutExceededError
try :
_wait(training_job)
wait_kwargs = {}
if wait_timeout is not None:
wait_kwargs['timeout'] = wait_timeout
_wait(training_job, **wait_kwargs)
except TimeoutExceededError as e:
logger.error("Error: %s", e)

Expand Down
10 changes: 8 additions & 2 deletions sagemaker-train/src/sagemaker/train/rlaif_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _validate_reward_model_id(self, reward_model_id):


@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLAIFTrainer.train")
def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None):
"""Execute the RLAIF training job.

Parameters:
Expand All @@ -209,6 +209,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
Can be an S3 URI, dataset ARN, or DataSet object.
wait (bool):
Whether to wait for the training job to complete. Defaults to True.
wait_timeout (Optional[int]):
Maximum time in seconds to wait for the training job to complete. Only used when wait=True.
If None, uses the default timeout from the wait utility.

Returns:
TrainingJob: The SageMaker training job object.
Expand Down Expand Up @@ -295,7 +298,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
from sagemaker.train.common_utils.trainer_wait import wait as _wait
from sagemaker.core.utils.exceptions import TimeoutExceededError
try :
_wait(training_job)
wait_kwargs = {}
if wait_timeout is not None:
wait_kwargs['timeout'] = wait_timeout
_wait(training_job, **wait_kwargs)
except TimeoutExceededError as e:
logger.error("Error: %s", e)

Expand Down
10 changes: 8 additions & 2 deletions sagemaker-train/src/sagemaker/train/rlvr_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def _process_hyperparameters(self):

@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="RLVRTrainer.train")
def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None):
"""Execute the RLVR training job.

Parameters:
Expand All @@ -195,6 +195,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
Can be an S3 URI, dataset ARN, or DataSet object.
wait (bool):
Whether to wait for the training job to complete. Defaults to True.
wait_timeout (Optional[int]):
Maximum time in seconds to wait for the training job to complete. Only used when wait=True.
If None, uses the default timeout from the wait utility.

Returns:
TrainingJob: The SageMaker training job object.
Expand Down Expand Up @@ -283,7 +286,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None,
from sagemaker.train.common_utils.trainer_wait import wait as _wait
from sagemaker.core.utils.exceptions import TimeoutExceededError
try:
_wait(training_job)
wait_kwargs = {}
if wait_timeout is not None:
wait_kwargs['timeout'] = wait_timeout
_wait(training_job, **wait_kwargs)
except TimeoutExceededError as e:
logger.error("Error: %s", e)

Expand Down
10 changes: 8 additions & 2 deletions sagemaker-train/src/sagemaker/train/sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def _process_hyperparameters(self):
self.hyperparameters._specs.pop('validation_data_path', None)

@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="SFTTrainer.train")
def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True):
def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validation_dataset: Optional[Union[str, DataSet]] = None, wait: bool = True, wait_timeout: Optional[int] = None):
"""Execute the SFT training job.

Parameters:
Expand All @@ -192,6 +192,9 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
Can be an S3 URI, dataset ARN, or DataSet object.
wait (bool):
Whether to wait for the training job to complete. Defaults to True.
wait_timeout (Optional[int]):
Maximum time in seconds to wait for the training job to complete. Only used when wait=True.
If None, uses the default timeout from the wait utility.

Returns:
TrainingJob: The SageMaker training job object.
Expand Down Expand Up @@ -277,7 +280,10 @@ def train(self, training_dataset: Optional[Union[str, DataSet]] = None, validati
from sagemaker.train.common_utils.trainer_wait import wait as _wait
from sagemaker.core.utils.exceptions import TimeoutExceededError
try :
_wait(training_job)
wait_kwargs = {}
if wait_timeout is not None:
wait_kwargs['timeout'] = wait_timeout
_wait(training_job, **wait_kwargs)
except TimeoutExceededError as e:
logger.error("Error: %s", e)

Expand Down
129 changes: 129 additions & 0 deletions sagemaker-train/tests/unit/train/test_dpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,132 @@ def test_accepts_stopping_condition(self, mock_finetuning, mock_validate):

assert trainer.stopping_condition == stopping_condition
assert trainer.stopping_condition.max_runtime_in_seconds == 14400

@patch('sagemaker.train.common_utils.trainer_wait.wait')
@patch('sagemaker.train.dpo_trainer._resolve_model_and_name')
@patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role')
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session')
@patch('sagemaker.train.dpo_trainer._get_unique_name')
@patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
@patch('sagemaker.train.dpo_trainer._create_input_data_config')
@patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels')
@patch('sagemaker.train.dpo_trainer._create_output_config')
@patch('sagemaker.train.dpo_trainer._create_serverless_config')
@patch('sagemaker.train.dpo_trainer._create_mlflow_config')
@patch('sagemaker.train.dpo_trainer._create_model_package_config')
@patch('sagemaker.core.resources.TrainingJob.create')
def test_train_passes_wait_timeout(self, mock_training_job_create, mock_model_package_config,
mock_mlflow_config, mock_serverless_config, mock_output_config,
mock_convert_channels, mock_input_config, mock_validate_group,
mock_unique_name, mock_get_sagemaker_session, mock_get_role,
mock_get_options, mock_resolve_model, mock_wait):
"""Test that wait_timeout is passed to _wait as timeout kwarg."""
mock_validate_group.return_value = "test-group"
mock_resolve_model.return_value = ("test-model", "test-model")
mock_get_sagemaker_session.return_value = Mock()
mock_fine_tuning_options = Mock()
mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"}
mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False)
mock_get_role.return_value = "test-role"
mock_unique_name.return_value = "test-job-name"
mock_input_config.return_value = [Mock()]
mock_convert_channels.return_value = [Mock()]
mock_output_config.return_value = Mock()
mock_serverless_config.return_value = Mock()
mock_mlflow_config.return_value = Mock()
mock_model_package_config.return_value = Mock()
mock_training_job = Mock()
mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job"
mock_training_job_create.return_value = mock_training_job

trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train")
trainer.train(wait=True, wait_timeout=600)

mock_wait.assert_called_once_with(mock_training_job, timeout=600)

@patch('sagemaker.train.common_utils.trainer_wait.wait')
@patch('sagemaker.train.dpo_trainer._resolve_model_and_name')
@patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role')
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session')
@patch('sagemaker.train.dpo_trainer._get_unique_name')
@patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
@patch('sagemaker.train.dpo_trainer._create_input_data_config')
@patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels')
@patch('sagemaker.train.dpo_trainer._create_output_config')
@patch('sagemaker.train.dpo_trainer._create_serverless_config')
@patch('sagemaker.train.dpo_trainer._create_mlflow_config')
@patch('sagemaker.train.dpo_trainer._create_model_package_config')
@patch('sagemaker.core.resources.TrainingJob.create')
def test_train_without_wait_timeout_uses_default(self, mock_training_job_create, mock_model_package_config,
mock_mlflow_config, mock_serverless_config, mock_output_config,
mock_convert_channels, mock_input_config, mock_validate_group,
mock_unique_name, mock_get_sagemaker_session, mock_get_role,
mock_get_options, mock_resolve_model, mock_wait):
"""Test that _wait is called without timeout kwarg when wait_timeout is None."""
mock_validate_group.return_value = "test-group"
mock_resolve_model.return_value = ("test-model", "test-model")
mock_get_sagemaker_session.return_value = Mock()
mock_fine_tuning_options = Mock()
mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"}
mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False)
mock_get_role.return_value = "test-role"
mock_unique_name.return_value = "test-job-name"
mock_input_config.return_value = [Mock()]
mock_convert_channels.return_value = [Mock()]
mock_output_config.return_value = Mock()
mock_serverless_config.return_value = Mock()
mock_mlflow_config.return_value = Mock()
mock_model_package_config.return_value = Mock()
mock_training_job = Mock()
mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job"
mock_training_job_create.return_value = mock_training_job

trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train")
trainer.train(wait=True)

mock_wait.assert_called_once_with(mock_training_job)

@patch('sagemaker.train.common_utils.trainer_wait.wait')
@patch('sagemaker.train.dpo_trainer._resolve_model_and_name')
@patch('sagemaker.train.dpo_trainer._get_fine_tuning_options_and_model_arn')
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_role')
@patch('sagemaker.train.dpo_trainer.TrainDefaults.get_sagemaker_session')
@patch('sagemaker.train.dpo_trainer._get_unique_name')
@patch('sagemaker.train.dpo_trainer._validate_and_resolve_model_package_group')
@patch('sagemaker.train.dpo_trainer._create_input_data_config')
@patch('sagemaker.train.dpo_trainer._convert_input_data_to_channels')
@patch('sagemaker.train.dpo_trainer._create_output_config')
@patch('sagemaker.train.dpo_trainer._create_serverless_config')
@patch('sagemaker.train.dpo_trainer._create_mlflow_config')
@patch('sagemaker.train.dpo_trainer._create_model_package_config')
@patch('sagemaker.core.resources.TrainingJob.create')
def test_train_wait_false_skips_wait(self, mock_training_job_create, mock_model_package_config,
mock_mlflow_config, mock_serverless_config, mock_output_config,
mock_convert_channels, mock_input_config, mock_validate_group,
mock_unique_name, mock_get_sagemaker_session, mock_get_role,
mock_get_options, mock_resolve_model, mock_wait):
"""Test that _wait is not called when wait=False."""
mock_validate_group.return_value = "test-group"
mock_resolve_model.return_value = ("test-model", "test-model")
mock_get_sagemaker_session.return_value = Mock()
mock_fine_tuning_options = Mock()
mock_fine_tuning_options.to_dict.return_value = {"learning_rate": "0.001"}
mock_get_options.return_value = (mock_fine_tuning_options, "model-arn", False)
mock_get_role.return_value = "test-role"
mock_unique_name.return_value = "test-job-name"
mock_input_config.return_value = [Mock()]
mock_convert_channels.return_value = [Mock()]
mock_output_config.return_value = Mock()
mock_serverless_config.return_value = Mock()
mock_mlflow_config.return_value = Mock()
mock_model_package_config.return_value = Mock()
mock_training_job = Mock()
mock_training_job.arn = "arn:aws:sagemaker:us-east-1:123456789012:training-job/test-job"
mock_training_job_create.return_value = mock_training_job

trainer = DPOTrainer(model="test-model", model_package_group="test-group", training_dataset="s3://bucket/train")
trainer.train(wait=False, wait_timeout=600)

mock_wait.assert_not_called()
Loading
Loading