Skip to content

feat: SDK Defaults - DebugHookConfig defaults in TrainingJob API #3947

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/sagemaker/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from sagemaker.config.config_schema import ( # noqa: F401
KEY,
TRAINING_JOB,
ESTIMATOR_DEBUG_HOOK_CONFIG_PATH,
TRAINING_JOB_INTER_CONTAINER_ENCRYPTION_PATH,
TRAINING_JOB_ROLE_ARN_PATH,
TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
Expand Down Expand Up @@ -158,4 +159,6 @@
CONTAINERS,
PRIMARY_CONTAINER,
INFERENCE_SPECIFICATION,
ESTIMATOR,
DEBUG_HOOK_CONFIG,
)
31 changes: 31 additions & 0 deletions src/sagemaker/config/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@
INFERENCE_SPECIFICATION = "InferenceSpecification"
PROFILER_CONFIG = "ProfilerConfig"
DISABLE_PROFILER = "DisableProfiler"
ESTIMATOR = "Estimator"
DEBUG_HOOK_CONFIG = "DebugHookConfig"


def _simple_path(*args: str):
Expand Down Expand Up @@ -338,6 +340,9 @@ def _simple_path(*args: str):
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, SESSION, DEFAULT_S3_OBJECT_KEY_PREFIX
)
ESTIMATOR_DEBUG_HOOK_CONFIG_PATH = _simple_path(
SAGEMAKER, PYTHON_SDK, MODULES, ESTIMATOR, DEBUG_HOOK_CONFIG
)


SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA = {
Expand Down Expand Up @@ -645,6 +650,27 @@ def _simple_path(*args: str):
},
},
},
ESTIMATOR: {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
PROPERTIES: {
DEBUG_HOOK_CONFIG: {
TYPE: "boolean",
"description": (
"Sets a boolean for `debugger_hook_config` of"
"Estimator which will be then used for training job"
"API call. Today, the config_schema doesn't support"
"a dictionary as a valid value to be provided."
"In the future to add support for DebugHookConfig"
"as a dictionary, schema should be added under"
"the config path `SageMaker.TrainingJob` instead of"
"here, since the TrainingJob API supports"
"DebugHookConfig as a dictionary, we can add"
"a schema for it at API level."
),
},
},
},
REMOTE_FUNCTION: {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
Expand Down Expand Up @@ -990,6 +1016,11 @@ def _simple_path(*args: str):
},
# Training Job
# https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CreateTrainingJob.html
# Please note that we currently support 'DebugHookConfig' as a boolean value
# which can be provided under [SageMaker.PythonSDK.Modules.Estimator] config path.
# As of today, config_schema does not support the dict as a valid value to be
# provided. In case, we decide to support it in the future, we can add a new schema
# for it under [SageMaker.TrainingJob] config path.
TRAINING_JOB: {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
Expand Down
22 changes: 21 additions & 1 deletion src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from sagemaker import git_utils, image_uris, vpc_utils, s3
from sagemaker.analytics import TrainingJobAnalytics
from sagemaker.config import (
ESTIMATOR_DEBUG_HOOK_CONFIG_PATH,
TRAINING_JOB_VOLUME_KMS_KEY_ID_PATH,
TRAINING_JOB_SECURITY_GROUP_IDS_PATH,
TRAINING_JOB_SUBNETS_PATH,
Expand Down Expand Up @@ -675,7 +676,26 @@ def __init__(
self.checkpoint_local_path = checkpoint_local_path

self.rules = rules
self.debugger_hook_config = debugger_hook_config

# Today, we ONLY support debugger_hook_config to be provided as a boolean value
# from sagemaker_config. We resolve value for this parameter as per the order
# 1. value from direct_input which can be a boolean or a dictionary
# 2. value from sagemaker_config which can be a boolean
# In future, if we support debugger_hook_config to be provided as a dictionary
# from sagemaker_config [SageMaker.TrainingJob] then we will need to update the
# logic below to resolve the values as per the type of value received from
# direct_input and sagemaker_config
self.debugger_hook_config = resolve_value_from_config(
direct_input=debugger_hook_config,
config_path=ESTIMATOR_DEBUG_HOOK_CONFIG_PATH,
sagemaker_session=sagemaker_session,
)
# If customer passes True from either direct_input or sagemaker_config, we will
# create a default hook config as an empty dict which will later be populated
# with default s3_output_path from _prepare_debugger_for_training function
if self.debugger_hook_config is True:
self.debugger_hook_config = {}

self.tensorboard_output_config = tensorboard_output_config

self.debugger_rule_configs = None
Expand Down
2 changes: 2 additions & 0 deletions tests/data/config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ SageMaker:
Session:
DefaultS3Bucket: 'sagemaker-python-sdk-test-bucket'
DefaultS3ObjectKeyPrefix: 'test-prefix'
Estimator:
DebugHookConfig: false
RemoteFunction:
Dependencies: "./requirements.txt"
EnvironmentVariables:
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@
CONTAINERS,
PRIMARY_CONTAINER,
INFERENCE_SPECIFICATION,
ESTIMATOR,
DEBUG_HOOK_CONFIG,
)

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
Expand Down Expand Up @@ -323,6 +325,13 @@
SAGEMAKER_CONFIG_TRAINING_JOB = {
SCHEMA_VERSION: "1.0",
SAGEMAKER: {
PYTHON_SDK: {
MODULES: {
ESTIMATOR: {
DEBUG_HOOK_CONFIG: False,
},
},
},
TRAINING_JOB: {
ENABLE_INTER_CONTAINER_TRAFFIC_ENCRYPTION: True,
ENABLE_NETWORK_ISOLATION: True,
Expand All @@ -337,6 +346,32 @@
},
}

SAGEMAKER_CONFIG_TRAINING_JOB_WITH_DEBUG_HOOK_CONFIG_AS_FALSE = {
SCHEMA_VERSION: "1.0",
SAGEMAKER: {
PYTHON_SDK: {
MODULES: {
ESTIMATOR: {
DEBUG_HOOK_CONFIG: False,
},
},
},
},
}

SAGEMAKER_CONFIG_TRAINING_JOB_WITH_DEBUG_HOOK_CONFIG_AS_TRUE = {
SCHEMA_VERSION: "1.0",
SAGEMAKER: {
PYTHON_SDK: {
MODULES: {
ESTIMATOR: {
DEBUG_HOOK_CONFIG: True,
},
},
},
},
}

SAGEMAKER_CONFIG_TRANSFORM_JOB = {
SCHEMA_VERSION: "1.0",
SAGEMAKER: {
Expand Down
9 changes: 9 additions & 0 deletions tests/unit/sagemaker/config/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ def valid_session_config():
}


@pytest.fixture()
def valid_estimator_config():
return {
"DebugHookConfig": False,
}


@pytest.fixture()
def valid_environment_config():
return {
Expand Down Expand Up @@ -251,10 +258,12 @@ def valid_config_with_all_the_scopes(
valid_training_job_config,
valid_edge_packaging_config,
valid_remote_function_config,
valid_estimator_config,
):
return {
"PythonSDK": {
"Modules": {
"Estimator": valid_estimator_config,
"RemoteFunction": valid_remote_function_config,
"Session": valid_session_config,
}
Expand Down
19 changes: 19 additions & 0 deletions tests/unit/sagemaker/config/test_config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,25 @@ def test_valid_remote_function_schema(base_config_with_schema, valid_remote_func
)


def test_valid_estimator_schema(base_config_with_schema, valid_estimator_config):
_validate_config(
base_config_with_schema,
{"PythonSDK": {"Modules": {"Estimator": valid_estimator_config}}},
)


def test_invalid_estimator_schema(base_config_with_schema, valid_estimator_config):
invalid_estimator_config = {
"DebugHookConfig": {
"S3OutputPath": "s3://somepath",
}
}
config = base_config_with_schema
config["SageMaker"] = {"PythonSDK": {"Modules": {"Estimator": invalid_estimator_config}}}
with pytest.raises(exceptions.ValidationError):
validate(config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA)


def test_tags_with_invalid_schema(base_config_with_schema, valid_edge_packaging_config):
edge_packaging_config = valid_edge_packaging_config.copy()
edge_packaging_config["Tags"] = [{"Key": "somekey"}]
Expand Down
Loading