Skip to content

Commit

Permalink
feature: support session tag chaining for training job (#4596)
Browse files Browse the repository at this point in the history
* feature: support session tag chaining for training job

* fix: resolve typo

* fix: resolve typo and build failure

* fix: resolve typo and unit test failure

---------

Co-authored-by: Jessica Zhu <jessicazhu3@106775307+jessicazhu3@users.noreply.github.com>
  • Loading branch information
jessicazhu3 and Jessica Zhu committed Apr 24, 2024
1 parent 30c9bf6 commit fe32d79
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 1 deletion.
22 changes: 21 additions & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def __init__(
container_arguments: Optional[List[str]] = None,
disable_output_compression: bool = False,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -544,7 +545,9 @@ def __init__(
enable_infra_check (bool or PipelineVariable): Optional.
Specifies whether it is running Sagemaker built-in infra check jobs.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job
Specifies whether RemoteDebug is enabled for the training job.
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job.
"""
instance_count = renamed_kwargs(
"train_instance_count", "instance_count", instance_count, kwargs
Expand Down Expand Up @@ -785,6 +788,8 @@ def __init__(

self._enable_remote_debug = enable_remote_debug

self._enable_session_tag_chaining = enable_session_tag_chaining

@abstractmethod
def training_image_uri(self):
"""Return the Docker image to use for training.
Expand Down Expand Up @@ -2318,6 +2323,14 @@ def get_remote_debug_config(self):
else {"EnableRemoteDebug": self._enable_remote_debug}
)

def get_session_chaining_config(self):
"""dict: Return the configuration of SessionChaining"""
return (
None
if self._enable_session_tag_chaining is None
else {"EnableSessionTagChaining": self._enable_session_tag_chaining}
)

def enable_remote_debug(self):
"""Enable remote debug for a training job."""
self._update_remote_debug(True)
Expand Down Expand Up @@ -2574,6 +2587,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
if estimator.get_remote_debug_config() is not None:
train_args["remote_debug_config"] = estimator.get_remote_debug_config()

if estimator.get_session_chaining_config() is not None:
train_args["session_chaining_config"] = estimator.get_session_chaining_config()

return train_args

@classmethod
Expand Down Expand Up @@ -2766,6 +2782,7 @@ def __init__(
disable_output_compression: bool = False,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -3129,6 +3146,8 @@ def __init__(
Specifies whether it is running Sagemaker built-in infra check jobs.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job
"""
self.image_uri = image_uri
self._hyperparameters = hyperparameters.copy() if hyperparameters else {}
Expand Down Expand Up @@ -3181,6 +3200,7 @@ def __init__(
container_arguments=container_arguments,
disable_output_compression=disable_output_compression,
enable_remote_debug=enable_remote_debug,
enable_session_tag_chaining=enable_session_tag_chaining,
**kwargs,
)

Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(
container_arguments: Optional[List[str]] = None,
disable_output_compression: Optional[bool] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
):
"""Initializes a ``JumpStartEstimator``.
Expand Down Expand Up @@ -500,6 +501,8 @@ def __init__(
to Amazon S3 without compression after training finishes.
enable_remote_debug (bool or PipelineVariable): Optional.
Specifies whether RemoteDebug is enabled for the training job
enable_session_tag_chaining (bool or PipelineVariable): Optional.
Specifies whether SessionTagChaining is enabled for the training job
Raises:
ValueError: If the model ID is not recognized by JumpStart.
Expand Down Expand Up @@ -578,6 +581,7 @@ def _validate_model_id_and_get_type_hook():
disable_output_compression=disable_output_compression,
enable_infra_check=enable_infra_check,
enable_remote_debug=enable_remote_debug,
enable_session_tag_chaining=enable_session_tag_chaining,
)

self.model_id = estimator_init_kwargs.model_id
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def get_init_kwargs(
disable_output_compression: Optional[bool] = None,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
) -> JumpStartEstimatorInitKwargs:
"""Returns kwargs required to instantiate `sagemaker.estimator.Estimator` object."""

Expand Down Expand Up @@ -188,6 +189,7 @@ def get_init_kwargs(
disable_output_compression=disable_output_compression,
enable_infra_check=enable_infra_check,
enable_remote_debug=enable_remote_debug,
enable_session_tag_chaining=enable_session_tag_chaining,
)

estimator_init_kwargs = _add_model_version_to_kwargs(estimator_init_kwargs)
Expand Down
3 changes: 3 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"disable_output_compression",
"enable_infra_check",
"enable_remote_debug",
"enable_session_tag_chaining",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -1818,6 +1819,7 @@ def __init__(
disable_output_compression: Optional[bool] = None,
enable_infra_check: Optional[Union[bool, PipelineVariable]] = None,
enable_remote_debug: Optional[Union[bool, PipelineVariable]] = None,
enable_session_tag_chaining: Optional[Union[bool, PipelineVariable]] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""

Expand Down Expand Up @@ -1877,6 +1879,7 @@ def __init__(
self.disable_output_compression = disable_output_compression
self.enable_infra_check = enable_infra_check
self.enable_remote_debug = enable_remote_debug
self.enable_session_tag_chaining = enable_session_tag_chaining


class JumpStartEstimatorFitKwargs(JumpStartKwargs):
Expand Down
24 changes: 24 additions & 0 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ def train( # noqa: C901
environment: Optional[Dict[str, str]] = None,
retry_strategy=None,
remote_debug_config=None,
session_chaining_config=None,
):
"""Create an Amazon SageMaker training job.
Expand Down Expand Up @@ -877,6 +878,15 @@ def train( # noqa: C901
remote_debug_config = {
"EnableRemoteDebug": True,
}
session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``)
The dict can contain 'EnableSessionTagChaining'(bool).
For example,
.. code:: python
session_chaining_config = {
"EnableSessionTagChaining": True,
}
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
Expand Down Expand Up @@ -970,6 +980,7 @@ def train( # noqa: C901
profiler_rule_configs=profiler_rule_configs,
profiler_config=inferred_profiler_config,
remote_debug_config=remote_debug_config,
session_chaining_config=session_chaining_config,
environment=environment,
retry_strategy=retry_strategy,
)
Expand Down Expand Up @@ -1013,6 +1024,7 @@ def _get_train_request( # noqa: C901
profiler_rule_configs=None,
profiler_config=None,
remote_debug_config=None,
session_chaining_config=None,
environment=None,
retry_strategy=None,
):
Expand Down Expand Up @@ -1133,6 +1145,15 @@ def _get_train_request( # noqa: C901
remote_debug_config = {
"EnableRemoteDebug": True,
}
session_chaining_config(dict): Configuration for SessionChaining. (default: ``None``)
The dict can contain 'EnableSessionTagChaining'(bool).
For example,
.. code:: python
session_chaining_config = {
"EnableSessionTagChaining": True,
}
environment (dict[str, str]) : Environment variables to be set for
use during training job (default: ``None``)
retry_strategy(dict): Defines RetryStrategy for InternalServerFailures.
Expand Down Expand Up @@ -1239,6 +1260,9 @@ def _get_train_request( # noqa: C901
if remote_debug_config is not None:
train_request["RemoteDebugConfig"] = remote_debug_config

if session_chaining_config is not None:
train_request["SessionChainingConfig"] = session_chaining_config

if retry_strategy is not None:
train_request["RetryStrategy"] = retry_strategy

Expand Down
35 changes: 35 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2089,6 +2089,41 @@ def test_framework_disable_remote_debug(sagemaker_session):
assert len(args) == 2


def test_framework_with_session_chaining_config(sagemaker_session):
f = DummyFramework(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
instance_groups=[
InstanceGroup("group1", "ml.c4.xlarge", 1),
InstanceGroup("group2", "ml.m4.xlarge", 2),
],
enable_session_tag_chaining=True,
)
f.fit("s3://mydata")
sagemaker_session.train.assert_called_once()
_, args = sagemaker_session.train.call_args
assert args["session_chaining_config"]["EnableSessionTagChaining"]
assert f.get_session_chaining_config()["EnableSessionTagChaining"]


def test_framework_without_session_chaining_config(sagemaker_session):
f = DummyFramework(
entry_point=SCRIPT_PATH,
role=ROLE,
sagemaker_session=sagemaker_session,
instance_groups=[
InstanceGroup("group1", "ml.c4.xlarge", 1),
InstanceGroup("group2", "ml.m4.xlarge", 2),
],
)
f.fit("s3://mydata")
sagemaker_session.train.assert_called_once()
_, args = sagemaker_session.train.call_args
assert args.get("SessionTagChaining") is None
assert f.get_remote_debug_config() is None


@patch("time.strftime", return_value=TIMESTAMP)
def test_custom_code_bucket(time, sagemaker_session):
code_bucket = "codebucket"
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,6 +2197,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
CONTAINER_ENTRY_POINT = ["bin/bash", "test.sh"]
CONTAINER_ARGUMENTS = ["--arg1", "value1", "--arg2", "value2"]
remote_debug_config = {"EnableRemoteDebug": True}
session_chaining_config = {"EnableSessionTagChaining": True}

sagemaker_session.train(
image_uri=IMAGE,
Expand All @@ -2222,6 +2223,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
container_entry_point=CONTAINER_ENTRY_POINT,
container_arguments=CONTAINER_ARGUMENTS,
remote_debug_config=remote_debug_config,
session_chaining_config=session_chaining_config,
)

_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
Expand All @@ -2245,6 +2247,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
)
assert actual_train_args["AlgorithmSpecification"]["ContainerArguments"] == CONTAINER_ARGUMENTS
assert actual_train_args["RemoteDebugConfig"]["EnableRemoteDebug"]
assert actual_train_args["SessionChainingConfig"]["EnableSessionTagChaining"]


def test_create_transform_job_with_sagemaker_config_injection(sagemaker_session):
Expand Down

0 comments on commit fe32d79

Please sign in to comment.