diff --git a/.gitignore b/.gitignore
index 9829ed9781..cae8f890ea 100644
--- a/.gitignore
+++ b/.gitignore
@@ -30,5 +30,6 @@ env/
.vscode/
**/tmp
.python-version
-**/_repack_model.py
-**/_repack_script_launcher.sh
\ No newline at end of file
+**/_repack_script_launcher.sh
+tests/data/**/_repack_model.py
+tests/data/experiment/sagemaker-dev-1.0.tar.gz
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 95e4a7b9cf..de20a8a0df 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,85 @@
# Changelog
+## v2.122.0 (2022-12-14)
+
+### Features
+
+ * Feature Store dataset builder, delete_record, get_record, list_feature_group
+ * Add OSU region to frameworks for DLC
+
+### Bug Fixes and Other Changes
+
+ * the Hyperband support fix for the HPO
+ * unpin packaging version
+ * Remove content type image/jpg from analysis configuration schema
+
+## v2.121.2 (2022-12-12)
+
+### Bug Fixes and Other Changes
+
+ * Update for Tensorflow Serving 2.11 inference DLCs
+ * Revert "fix: type hint of PySparkProcessor __init__"
+ * Skip Bad Transform Test
+
+## v2.121.1 (2022-12-09)
+
+### Bug Fixes and Other Changes
+
+ * Pop out ModelPackageName from pipeline definition
+ * Fix failing jumpstart cache unit tests
+
+## v2.121.0 (2022-12-08)
+
+### Features
+
+ * Algorithms Region Expansion OSU/DXB
+
+### Bug Fixes and Other Changes
+
+ * FrameworkProcessor S3 uploads
+ * Add constraints file for apache-airflow
+
+## v2.120.0 (2022-12-07)
+
+### Features
+
+ * Add Neo image uri config for Pytorch 1.12
+ * Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12
+ * Update registries with new region account number mappings.
+ * Add DXB region to frameworks by DLC
+
+### Bug Fixes and Other Changes
+
+ * support idempotency for framework and spark processors
+
+## v2.119.0 (2022-12-03)
+
+### Features
+
+ * Add Code Owners file
+ * Added transform with monitoring pipeline step in transformer
+ * Update TF 2.9 and TF 2.10 inference DLCs
+ * make estimator accept json file as modelparallel config
+ * SageMaker Training Compiler does not support p4de instances
+ * Add support for SparkML v3.3
+
+### Bug Fixes and Other Changes
+
+ * Fix bug forcing uploaded tar to be named sourcedir
+ * Update local_requirements.txt PyYAML version
+ * refactoring : using with statement
+ * Allow Py 3.7 for MMS Test Docker env
+ * fix PySparkProcessor __init__ params type
+ * type hint of PySparkProcessor __init__
+ * Return ARM XGB/SKLearn tags if `image_scope` is `inference_graviton`
+ * Update scipy to 1.7.3 to support M1 development envs
+ * Fixing type hints for Spark processor that has instance type/count params in reverse order
+ * Add DeepAR ap-northeast-3 repository.
+ * Fix AsyncInferenceConfig documentation typo
+ * fix ml_inf to ml_inf1 in Neo multi-version support
+ * Fix type annotations
+ * add neo mvp region accounts
+
## v2.118.0 (2022-12-01)
### Features
diff --git a/CODEOWNERS b/CODEOWNERS
new file mode 100644
index 0000000000..7f7ac28644
--- /dev/null
+++ b/CODEOWNERS
@@ -0,0 +1 @@
+* @aws/sagemaker-ml-frameworks
diff --git a/VERSION b/VERSION
index 34d47b7f52..6d7f044fa2 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-2.118.1.dev0
+2.122.1.dev0
diff --git a/doc/experiments/index.rst b/doc/experiments/index.rst
new file mode 100644
index 0000000000..8c12f30edc
--- /dev/null
+++ b/doc/experiments/index.rst
@@ -0,0 +1,10 @@
+############################
+Amazon SageMaker Experiments
+############################
+
+The SageMaker Python SDK lets you track and organize your machine learning workflows across SageMaker jobs, such as Processing, Training, and Transform, or locally.
+
+.. toctree::
+ :maxdepth: 2
+
+ sagemaker.experiments
diff --git a/doc/experiments/sagemaker.experiments.rst b/doc/experiments/sagemaker.experiments.rst
new file mode 100644
index 0000000000..f0776ec43b
--- /dev/null
+++ b/doc/experiments/sagemaker.experiments.rst
@@ -0,0 +1,20 @@
+Experiments
+============
+
+Run
+-------------
+
+.. autoclass:: sagemaker.experiments.Run
+ :members:
+
+.. automethod:: sagemaker.experiments.load_run
+
+.. automethod:: sagemaker.experiments.list_runs
+
+.. autoclass:: sagemaker.experiments.SortByType
+ :members:
+ :undoc-members:
+
+.. autoclass:: sagemaker.experiments.SortOrderType
+ :members:
+ :undoc-members:
diff --git a/doc/index.rst b/doc/index.rst
index 2d4ebe32c1..69038056b0 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -60,6 +60,16 @@ Orchestrate your SageMaker training and inference workflows with Airflow and Kub
workflows/index
+****************************
+Amazon SageMaker Experiments
+****************************
+You can use Amazon SageMaker Experiments to track machine learning experiments.
+
+.. toctree::
+ :maxdepth: 2
+
+ experiments/index
+
*************************
Amazon SageMaker Debugger
*************************
diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt
index b52f394bd0..494b6dca11 100644
--- a/requirements/extras/test_requirements.txt
+++ b/requirements/extras/test_requirements.txt
@@ -11,6 +11,7 @@ contextlib2==21.6.0
awslogs==0.14.0
black==22.3.0
stopit==1.1.2
+# Update tox.ini to have correct version of airflow constraints file
apache-airflow==2.4.1
apache-airflow-providers-amazon==4.0.0
attrs==22.1.0
@@ -19,3 +20,4 @@ requests==2.27.1
sagemaker-experiments==0.1.35
Jinja2==3.0.3
pandas>=1.3.5,<1.5
+scikit-learn==1.0.2
diff --git a/setup.py b/setup.py
index 4327045760..e2adb6b433 100644
--- a/setup.py
+++ b/setup.py
@@ -48,7 +48,7 @@ def read_requirements(filename):
# Declare minimal set for installation
required_packages = [
"attrs>=20.3.0,<23",
- "boto3>=1.26.20,<2.0",
+ "boto3>=1.26.28,<2.0",
"google-pasta",
"numpy>=1.9.0,<2.0",
"protobuf>=3.1,<4.0",
diff --git a/src/sagemaker/amazon/amazon_estimator.py b/src/sagemaker/amazon/amazon_estimator.py
index b156f2e65f..1abea5e48c 100644
--- a/src/sagemaker/amazon/amazon_estimator.py
+++ b/src/sagemaker/amazon/amazon_estimator.py
@@ -27,7 +27,7 @@
from sagemaker.deprecations import renamed_warning
from sagemaker.estimator import EstimatorBase, _TrainingJob
from sagemaker.inputs import FileSystemInput, TrainingInput
-from sagemaker.utils import sagemaker_timestamp
+from sagemaker.utils import sagemaker_timestamp, check_and_get_run_experiment_config
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.pipeline_context import runnable_by_pipeline
from sagemaker.workflow import is_pipeline_variable
@@ -242,8 +242,8 @@ def fit(
generates a default job name, based on the training image name
and current timestamp.
experiment_config (dict[str, str]): Experiment management configuration.
- Optionally, the dict can contain three keys:
- 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+ Optionally, the dict can contain four keys:
+ 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
The behavior of setting these keys is as follows:
* If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
automatically created and the job's Trial Component associated with the Trial.
@@ -255,6 +255,7 @@ def fit(
"""
self._prepare_for_training(records, job_name=job_name, mini_batch_size=mini_batch_size)
+ experiment_config = check_and_get_run_experiment_config(experiment_config)
self.latest_training_job = _TrainingJob.start_new(
self, records, experiment_config=experiment_config
)
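# Illustration (not part of the diff): a minimal sketch of passing the extended
# experiment_config to fit(). The estimator, records, and all names below are
# hypothetical placeholders.
estimator.fit(
    records,
    experiment_config={
        "ExperimentName": "my-experiment",
        "TrialName": "my-trial",
        "TrialComponentDisplayName": "training",
        "RunName": "my-run",  # new in this change: records an experiment run
    },
)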
diff --git a/src/sagemaker/apiutils/_base_types.py b/src/sagemaker/apiutils/_base_types.py
index e920797b18..9a7359e12b 100644
--- a/src/sagemaker/apiutils/_base_types.py
+++ b/src/sagemaker/apiutils/_base_types.py
@@ -173,8 +173,10 @@ def _search(
search_items = search_method_response.get("Results", [])
next_token = search_method_response.get(boto_next_token_name)
for item in search_items:
- if cls.__name__ in item:
- yield search_item_factory(item[cls.__name__])
+ # _TrialComponent class in experiments module is not public currently
+ class_name = cls.__name__.lstrip("_")
+ if class_name in item:
+ yield search_item_factory(item[class_name])
if not next_token:
break
except StopIteration:
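# Illustration (not part of the diff) of the lstrip("_") fix above: Search API
# results are keyed by the public resource name, so a private class such as
# _TrialComponent must be matched against the "TrialComponent" key.
class _TrialComponent:
    pass

item = {"TrialComponent": {"TrialComponentName": "tc-1"}}  # hypothetical search result
class_name = _TrialComponent.__name__.lstrip("_")  # "TrialComponent"
assert class_name in item and _TrialComponent.__name__ not in item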
diff --git a/src/sagemaker/apiutils/_boto_functions.py b/src/sagemaker/apiutils/_boto_functions.py
index 1e29f2ebea..a227d30ca8 100644
--- a/src/sagemaker/apiutils/_boto_functions.py
+++ b/src/sagemaker/apiutils/_boto_functions.py
@@ -68,7 +68,9 @@ def from_boto(boto_dict, boto_name_to_member_name, member_name_to_type):
api_type, is_collection = member_name_to_type[member_name]
if is_collection:
if isinstance(boto_value, dict):
- member_value = api_type.from_boto(boto_value)
+ member_value = {
+ key: api_type.from_boto(value) for key, value in boto_value.items()
+ }
else:
member_value = [api_type.from_boto(item) for item in boto_value]
else:
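# Illustration (not part of the diff): with the fix above, a dict-valued
# collection is converted entry by entry instead of as one object.
# ApiTypeExample is a hypothetical stand-in for an ApiObject subclass.
class ApiTypeExample:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    @classmethod
    def from_boto(cls, boto_dict):
        return cls(**boto_dict)

boto_value = {"train": {"Uri": "s3://a"}, "test": {"Uri": "s3://b"}}
member_value = {key: ApiTypeExample.from_boto(value) for key, value in boto_value.items()}
assert member_value["train"].Uri == "s3://a"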
diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py
index 4765630ce8..f082679401 100644
--- a/src/sagemaker/clarify.py
+++ b/src/sagemaker/clarify.py
@@ -282,7 +282,6 @@
"text/csv",
"application/jsonlines",
"image/jpeg",
- "image/jpg",
"image/png",
"application/x-npy",
),
diff --git a/src/sagemaker/dataset_definition/inputs.py b/src/sagemaker/dataset_definition/inputs.py
index 90a272c4d7..468be22ac3 100644
--- a/src/sagemaker/dataset_definition/inputs.py
+++ b/src/sagemaker/dataset_definition/inputs.py
@@ -124,8 +124,10 @@ class DatasetDefinition(ApiObject):
"""DatasetDefinition input."""
_custom_boto_types = {
- "redshift_dataset_definition": (RedshiftDatasetDefinition, True),
- "athena_dataset_definition": (AthenaDatasetDefinition, True),
+ # RedshiftDatasetDefinition and AthenaDatasetDefinition are not collection
+ # Instead they are singleton objects. Thus, set the is_collection flag to False.
+ "redshift_dataset_definition": (RedshiftDatasetDefinition, False),
+ "athena_dataset_definition": (AthenaDatasetDefinition, False),
}
def __init__(
diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 6f729267de..a46e106a1d 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -79,6 +79,7 @@
get_config_value,
name_from_base,
to_string,
+ check_and_get_run_experiment_config,
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
@@ -884,8 +885,6 @@ def _check_debugger_rule(self, rule):
def _prepare_debugger_for_training(self):
"""Prepare debugger rules and debugger configs for training."""
- if self.debugger_rules and self.debugger_hook_config is None:
- self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
# If debugger_hook_config was provided without an S3 URI, default it for the customer.
if self.debugger_hook_config and not self.debugger_hook_config.s3_output_path:
self.debugger_hook_config.s3_output_path = self.output_path
@@ -899,10 +898,7 @@ def _validate_and_set_debugger_configs(self):
self.sagemaker_session.boto_region_name
)
- if region_supports_debugger:
- if self.debugger_hook_config in [None, {}]:
- self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
- else:
+ if not region_supports_debugger:
if self.debugger_hook_config is not False and self.debugger_hook_config:
# when user set debugger config in a unsupported region
raise ValueError(
@@ -1103,8 +1099,8 @@ def fit(
job_name (str): Training job name. If not specified, the estimator generates
a default job name based on the training image name and current timestamp.
experiment_config (dict[str, str]): Experiment management configuration.
- Optionally, the dict can contain three keys:
- 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+ Optionally, the dict can contain four keys:
+                'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
The behavior of setting these keys is as follows:
* If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
automatically created and the job's Trial Component associated with the Trial.
@@ -1122,6 +1118,7 @@ def fit(
"""
self._prepare_for_training(job_name=job_name)
+ experiment_config = check_and_get_run_experiment_config(experiment_config)
self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
self.jobs.append(self.latest_training_job)
if wait:
@@ -2023,8 +2020,8 @@ def start_new(cls, estimator, inputs, experiment_config):
inputs (str): Parameters used when called
:meth:`~sagemaker.estimator.EstimatorBase.fit`.
experiment_config (dict[str, str]): Experiment management configuration.
- Optionally, the dict can contain three keys:
- 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+ Optionally, the dict can contain four keys:
+ 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
The behavior of setting these keys is as follows:
* If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
automatically created and the job's Trial Component associated with the Trial.
@@ -2033,6 +2030,7 @@ def start_new(cls, estimator, inputs, experiment_config):
* If both `ExperimentName` and `TrialName` are not supplied the trial component
will be unassociated.
* `TrialComponentDisplayName` is used for display in Studio.
+ * `RunName` is used to record an experiment run.
Returns:
sagemaker.estimator._TrainingJob: Constructed object that captures
all information about the started training job.
@@ -2053,8 +2051,8 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
inputs (str): Parameters used when called
:meth:`~sagemaker.estimator.EstimatorBase.fit`.
experiment_config (dict[str, str]): Experiment management configuration.
- Optionally, the dict can contain three keys:
- 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+ Optionally, the dict can contain four keys:
+ 'ExperimentName', 'TrialName', 'TrialComponentDisplayName' and 'RunName'.
The behavior of setting these keys is as follows:
* If `ExperimentName` is supplied but `TrialName` is not, a Trial will be
automatically created and the job's Trial Component associated with the Trial.
@@ -2063,6 +2061,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
* If both `ExperimentName` and `TrialName` are not supplied the trial component
will be unassociated.
* `TrialComponentDisplayName` is used for display in Studio.
+ * `RunName` is used to record an experiment run.
Returns:
Dict: dict for `sagemaker.session.Session.train` method
diff --git a/src/sagemaker/experiments/__init__.py b/src/sagemaker/experiments/__init__.py
new file mode 100644
index 0000000000..b87656b1ab
--- /dev/null
+++ b/src/sagemaker/experiments/__init__.py
@@ -0,0 +1,20 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Sagemaker Experiment Module"""
+from __future__ import absolute_import
+
+from sagemaker.experiments.run import Run # noqa: F401
+from sagemaker.experiments.run import load_run # noqa: F401
+from sagemaker.experiments.run import list_runs # noqa: F401
+from sagemaker.experiments.run import SortOrderType # noqa: F401
+from sagemaker.experiments.run import SortByType # noqa: F401
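# Illustration (not part of the diff): a minimal, hypothetical sketch of the
# public API exported above. Method names follow the Run class introduced in
# this release; exact signatures may differ.
from sagemaker.experiments import Run, list_runs

with Run(experiment_name="my-experiment", run_name="run-1") as run:
    run.log_parameter("learning_rate", 0.01)
    run.log_metric(name="loss", value=0.42, step=1)

for finished_run in list_runs(experiment_name="my-experiment"):
    print(finished_run.run_name)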
diff --git a/src/sagemaker/experiments/_api_types.py b/src/sagemaker/experiments/_api_types.py
new file mode 100644
index 0000000000..78f82565aa
--- /dev/null
+++ b/src/sagemaker/experiments/_api_types.py
@@ -0,0 +1,251 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains API objects for SageMaker experiments."""
+from __future__ import absolute_import
+
+import enum
+import numbers
+
+from sagemaker.apiutils import _base_types
+
+
+class TrialComponentMetricSummary(_base_types.ApiObject):
+ """Summary model of a trial component.
+
+ Attributes:
+ metric_name (str): The name of the metric.
+ source_arn (str): The ARN of the source.
+        time_stamp (datetime): When the metric was last updated.
+ max (float): The max value of the metric.
+ min (float): The min value of the metric.
+ last (float): The last value of the metric.
+ count (float): The number of samples used to generate the metric.
+ avg (float): The average value of the metric.
+ std_dev (float): The standard deviation of the metric.
+ """
+
+ metric_name = None
+ source_arn = None
+ time_stamp = None
+ max = None
+ min = None
+ last = None
+ count = None
+ avg = None
+ std_dev = None
+
+ def __init__(self, metric_name=None, source_arn=None, **kwargs):
+ super(TrialComponentMetricSummary, self).__init__(
+ metric_name=metric_name, source_arn=source_arn, **kwargs
+ )
+
+
+class TrialComponentParameters(_base_types.ApiObject):
+ """A dictionary of TrialComponentParameterValues"""
+
+ @classmethod
+ def from_boto(cls, boto_dict, **kwargs):
+ """Converts a boto dict to a dictionary of TrialComponentParameterValues
+
+ Args:
+ boto_dict (dict): boto response dictionary.
+ **kwargs: Arbitrary keyword arguments.
+
+ Returns:
+ dict: Dictionary of parameter values.
+ """
+ return_map = {}
+ for key, value in boto_dict.items():
+ return_map[key] = value.get("NumberValue", value.get("StringValue", None))
+ return return_map
+
+ @classmethod
+ def to_boto(cls, parameters):
+ """Converts TrialComponentParameters to dict.
+
+ Args:
+ parameters (TrialComponentParameters): Dictionary to convert.
+
+ Returns:
+ dict: Dictionary of trial component parameters in boto format.
+ """
+ boto_map = {}
+ for key, value in parameters.items():
+ if isinstance(value, numbers.Number):
+ boto_map[key] = {"NumberValue": value}
+ else:
+ boto_map[key] = {"StringValue": str(value)}
+ return boto_map
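# Illustration (not part of the diff) of the boto round trip above, with
# made-up parameter values:
params = TrialComponentParameters.from_boto(
    {"lr": {"NumberValue": 0.01}, "optimizer": {"StringValue": "sgd"}}
)
assert params == {"lr": 0.01, "optimizer": "sgd"}
assert TrialComponentParameters.to_boto(params) == {
    "lr": {"NumberValue": 0.01},
    "optimizer": {"StringValue": "sgd"},
}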
+
+
+class TrialComponentArtifact(_base_types.ApiObject):
+ """Trial component artifact.
+
+ Attributes:
+ value (str): The artifact value.
+ media_type (str): The media type.
+ """
+
+ value = None
+ media_type = None
+
+ def __init__(self, value=None, media_type=None, **kwargs):
+ super(TrialComponentArtifact, self).__init__(value=value, media_type=media_type, **kwargs)
+
+
+class _TrialComponentStatusType(enum.Enum):
+ """The type of trial component status"""
+
+ InProgress = "InProgress"
+ Completed = "Completed"
+ Failed = "Failed"
+
+
+class TrialComponentStatus(_base_types.ApiObject):
+ """Status of the trial component.
+
+ Attributes:
+ primary_status (str): The status of a trial component.
+ message (str): Status message.
+ """
+
+ primary_status = None
+ message = None
+
+ def __init__(self, primary_status=None, message=None, **kwargs):
+ super(TrialComponentStatus, self).__init__(
+ primary_status=primary_status, message=message, **kwargs
+ )
+
+
+class TrialComponentSummary(_base_types.ApiObject):
+ """Summary model of a trial component.
+
+ Attributes:
+ trial_component_name (str): Name of trial component.
+ trial_component_arn (str): ARN of the trial component.
+ display_name (str): Friendly display name in UI.
+ source_arn (str): ARN of the trial component source.
+ status (str): Status.
+ start_time (datetime): Start time.
+ end_time (datetime): End time.
+ creation_time (datetime): Creation time.
+ created_by (str): Created by.
+ last_modified_time (datetime): Date last modified.
+        last_modified_by (str): The user who last modified the trial component.
+ """
+
+ _custom_boto_types = {
+ "status": (TrialComponentStatus, False),
+ }
+ trial_component_name = None
+ trial_component_arn = None
+ display_name = None
+ source_arn = None
+ status = None
+ start_time = None
+ end_time = None
+ creation_time = None
+ created_by = None
+ last_modified_time = None
+ last_modified_by = None
+
+
+class TrialComponentSource(_base_types.ApiObject):
+ """Trial Component Source
+
+ Attributes:
+ source_arn (str): The ARN of the source.
+ """
+
+ source_arn = None
+
+ def __init__(self, source_arn=None, **kwargs):
+ super(TrialComponentSource, self).__init__(source_arn=source_arn, **kwargs)
+
+
+class Parent(_base_types.ApiObject):
+ """The trial/experiment/run that a trial component is associated with.
+
+ Attributes:
+ trial_name (str): Name of the trial.
+ experiment_name (str): Name of the experiment.
+ run_name (str): Name of the run.
+ """
+
+ trial_name = None
+ experiment_name = None
+ run_name = None
+
+
+class TrialComponentSearchResult(_base_types.ApiObject):
+ """Summary model of an Trial Component search result.
+
+ Attributes:
+ trial_component_arn (str): ARN of the trial component.
+ trial_component_name (str): Name of the trial component.
+ display_name (str): Display name of the trial component for UI display.
+ source (dict): The source of the trial component.
+ status (dict): The status of the trial component.
+ start_time (datetime): Start time.
+ end_time (datetime): End time.
+ creation_time (datetime): Creation time.
+ created_by (str): Created by.
+ last_modified_time (datetime): Date last modified.
+        last_modified_by (str): The user who last modified the trial component.
+ parameters (dict): The hyperparameters of the component.
+ input_artifacts (dict): The input artifacts of the component.
+ output_artifacts (dict): The output artifacts of the component.
+ metrics (list): The metrics for the component.
+ source_detail (dict): The source of the trial component.
+ tags (list): The list of tags that are associated with the trial component.
+        parents (list[Parent]): The parents of the trial component.
+ """
+
+ _custom_boto_types = {
+ "parents": (Parent, True), # parents is a collection (list) of Parent objects
+ }
+ trial_component_arn = None
+ trial_component_name = None
+ display_name = None
+ source = None
+ status = None
+ start_time = None
+ end_time = None
+ creation_time = None
+ created_by = None
+ last_modified_time = None
+ last_modified_by = None
+ parameters = None
+ input_artifacts = None
+ output_artifacts = None
+ metrics = None
+ source_detail = None
+ tags = None
+ parents = None
+
+
+class TrialSummary(_base_types.ApiObject):
+ """Summary model of a trial.
+
+ Attributes:
+ trial_arn (str): The ARN of the trial.
+ trial_name (str): The name of the trial.
+ creation_time (datetime): When the trial was created.
+ last_modified_time (datetime): When the trial was last modified.
+ """
+
+ trial_arn = None
+ trial_name = None
+ creation_time = None
+ last_modified_time = None
diff --git a/src/sagemaker/experiments/_environment.py b/src/sagemaker/experiments/_environment.py
new file mode 100644
index 0000000000..441661ae5a
--- /dev/null
+++ b/src/sagemaker/experiments/_environment.py
@@ -0,0 +1,132 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains the _RunEnvironment class."""
+from __future__ import absolute_import
+
+import enum
+import json
+import logging
+import os
+
+from sagemaker.experiments import trial_component
+from sagemaker.utils import retry_with_backoff
+
+TRAINING_JOB_ARN_ENV = "TRAINING_JOB_ARN"
+PROCESSING_JOB_CONFIG_PATH = "/opt/ml/config/processingjobconfig.json"
+TRANSFORM_JOB_ENV_BATCH_VAR = "SAGEMAKER_BATCH"
+MAX_RETRY_ATTEMPTS = 7
+
+logger = logging.getLogger(__name__)
+
+
+class _EnvironmentType(enum.Enum):
+ """SageMaker jobs which data can be pulled from the environment."""
+
+ SageMakerTrainingJob = 1
+ SageMakerProcessingJob = 2
+ SageMakerTransformJob = 3
+
+
+class _RunEnvironment(object):
+ """Retrieves job specific data from the environment."""
+
+ def __init__(self, environment_type, source_arn):
+ """Init for _RunEnvironment.
+
+ Args:
+ environment_type (_EnvironmentType): The environment type.
+ source_arn (str): The ARN of the current job.
+ """
+ self.environment_type = environment_type
+ self.source_arn = source_arn
+
+ @classmethod
+ def load(
+ cls,
+ training_job_arn_env=TRAINING_JOB_ARN_ENV,
+ processing_job_config_path=PROCESSING_JOB_CONFIG_PATH,
+ transform_job_batch_var=TRANSFORM_JOB_ENV_BATCH_VAR,
+ ):
+ """Loads source arn of current job from environment.
+
+ Args:
+ training_job_arn_env (str): The environment key for training job ARN
+ (default: `TRAINING_JOB_ARN`).
+ processing_job_config_path (str): The processing job config path
+ (default: `/opt/ml/config/processingjobconfig.json`).
+ transform_job_batch_var (str): The environment variable indicating if
+ it is a transform job (default: `SAGEMAKER_BATCH`).
+
+ Returns:
+ _RunEnvironment: Job data loaded from the environment. None if config does not exist.
+ """
+ if training_job_arn_env in os.environ:
+ environment_type = _EnvironmentType.SageMakerTrainingJob
+ source_arn = os.environ.get(training_job_arn_env)
+ return _RunEnvironment(environment_type, source_arn)
+ if os.path.exists(processing_job_config_path):
+ environment_type = _EnvironmentType.SageMakerProcessingJob
+            with open(processing_job_config_path) as config_file:
+                source_arn = json.load(config_file)["ProcessingJobArn"]
+ return _RunEnvironment(environment_type, source_arn)
+ if transform_job_batch_var in os.environ and os.environ[transform_job_batch_var] == "true":
+ environment_type = _EnvironmentType.SageMakerTransformJob
+ # TODO: need to figure out how to get source_arn from job env
+ # with Transform team's help.
+ source_arn = ""
+ return _RunEnvironment(environment_type, source_arn)
+
+ return None
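# Illustration (not part of the diff) of the detection order above: training
# jobs are identified by the TRAINING_JOB_ARN env var, processing jobs by the
# presence of their config file, transform jobs by SAGEMAKER_BATCH=true. The
# ARN below is a placeholder.
import os

os.environ["TRAINING_JOB_ARN"] = "arn:aws:sagemaker:us-west-2:111122223333:training-job/demo"
env = _RunEnvironment.load()
assert env.environment_type == _EnvironmentType.SageMakerTrainingJob
assert env.source_arn.endswith("training-job/demo")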
+
+ def get_trial_component(self, sagemaker_session):
+ """Retrieves the trial component from the job in the environment.
+
+ Args:
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+ Returns:
+ _TrialComponent: The trial component created from the job. None if not found.
+ """
+ # TODO: Remove this condition check once we have a way to retrieve source ARN
+ # from transform job env
+ if self.environment_type == _EnvironmentType.SageMakerTransformJob:
+ logger.error(
+ "Currently getting the job trial component from the transform job environment "
+ "is not supported. Returning None."
+ )
+ return None
+
+ def _get_trial_component():
+ summaries = list(
+ trial_component._TrialComponent.list(
+ source_arn=self.source_arn.lower(), sagemaker_session=sagemaker_session
+ )
+ )
+ if summaries:
+ summary = summaries[0]
+ return trial_component._TrialComponent.load(
+ trial_component_name=summary.trial_component_name,
+ sagemaker_session=sagemaker_session,
+ )
+ return None
+
+ job_tc = None
+ try:
+ job_tc = retry_with_backoff(_get_trial_component, MAX_RETRY_ATTEMPTS)
+ except Exception as ex: # pylint: disable=broad-except
+ logger.error(
+ "Failed to get trail component in the current environment due to %s", str(ex)
+ )
+ return job_tc
diff --git a/src/sagemaker/experiments/_helper.py b/src/sagemaker/experiments/_helper.py
new file mode 100644
index 0000000000..0c689b1125
--- /dev/null
+++ b/src/sagemaker/experiments/_helper.py
@@ -0,0 +1,266 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains the helper classes for SageMaker Experiment."""
+from __future__ import absolute_import
+
+import json
+import logging
+import os
+
+import botocore
+
+from sagemaker.experiments._utils import is_already_exist_error
+
+logger = logging.getLogger(__name__)
+
+
+_DEFAULT_ARTIFACT_PREFIX = "trial-component-artifacts"
+_DEFAULT_ARTIFACT_TYPE = "Tracker"
+
+
+class _ArtifactUploader(object):
+ """Artifact uploader"""
+
+ def __init__(
+ self,
+ trial_component_name,
+ sagemaker_session,
+ artifact_bucket=None,
+ artifact_prefix=_DEFAULT_ARTIFACT_PREFIX,
+ ):
+ """Initialize a `_ArtifactUploader` instance.
+
+ Args:
+ trial_component_name (str): The name of the trial component,
+ which is used to generate the S3 path to upload the artifact to.
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed.
+ artifact_bucket (str): The S3 bucket to upload the artifact to.
+ If not specified, the default bucket defined in `sagemaker_session`
+ will be used.
+ artifact_prefix (str): The S3 key prefix used to generate the S3 path
+ to upload the artifact to (default: "trial-component-artifacts").
+ """
+ self.sagemaker_session = sagemaker_session
+ self.trial_component_name = trial_component_name
+ self.artifact_bucket = artifact_bucket
+ self.artifact_prefix = artifact_prefix
+ self._s3_client = self.sagemaker_session.boto_session.client("s3")
+
+ def upload_artifact(self, file_path):
+ """Upload an artifact file to S3.
+
+ Args:
+ file_path (str): the file path of the artifact
+
+ Returns:
+            (str, str): The S3 URI of the uploaded file and its ETag.
+
+ Raises:
+ ValueError: If file does not exist.
+ """
+ file_path = os.path.expanduser(file_path)
+ if not os.path.isfile(file_path):
+ raise ValueError(
+ "{} does not exist or is not a file. Please supply a file path.".format(file_path)
+ )
+ if not self.artifact_bucket:
+ self.artifact_bucket = self.sagemaker_session.default_bucket()
+ artifact_name = os.path.basename(file_path)
+ artifact_s3_key = "{}/{}/{}".format(
+ self.artifact_prefix, self.trial_component_name, artifact_name
+ )
+ self._s3_client.upload_file(file_path, self.artifact_bucket, artifact_s3_key)
+ etag = self._try_get_etag(artifact_s3_key)
+ return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag
+
+ def upload_object_artifact(self, artifact_name, artifact_object, file_extension=None):
+ """Upload an artifact object to S3.
+
+ Args:
+ artifact_name (str): the name of the artifact.
+ artifact_object (obj): the object of the artifact
+ file_extension (str): Optional file extension.
+
+ Returns:
+            (str, str): The S3 URI of the uploaded file and its ETag.
+ """
+ if not self.artifact_bucket:
+ self.artifact_bucket = self.sagemaker_session.default_bucket()
+ if file_extension:
+ artifact_name = (
+ artifact_name + ("" if file_extension.startswith(".") else ".") + file_extension
+ )
+ artifact_s3_key = "{}/{}/{}".format(
+ self.artifact_prefix, self.trial_component_name, artifact_name
+ )
+ self._s3_client.put_object(
+ Body=json.dumps(artifact_object), Bucket=self.artifact_bucket, Key=artifact_s3_key
+ )
+ etag = self._try_get_etag(artifact_s3_key)
+ return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag
+
+ def _try_get_etag(self, key):
+ """Get ETag of given key and return None if not allowed
+
+ Args:
+ key (str): The S3 object key.
+
+ Returns:
+            str: The S3 object ETag if it can be retrieved, otherwise None.
+ """
+ try:
+ response = self._s3_client.head_object(Bucket=self.artifact_bucket, Key=key)
+ return response["ETag"]
+ except botocore.exceptions.ClientError as error:
+ # requires read permissions
+ logger.warning("Failed to get ETag of %s due to %s", key, error)
+ return None
+
+
+class _LineageArtifactManager(object):
+ """A helper class to manage Lineage Artifacts"""
+
+ def __init__(
+ self,
+ name,
+ source_uri,
+ etag,
+ source_arn=None,
+ dest_arn=None,
+ artifact_type=_DEFAULT_ARTIFACT_TYPE,
+ ):
+ """Initialize a `_LineageArtifactManager` instance.
+
+ Args:
+ name (str): The name of the Lineage artifact to be created.
+ source_uri (str): The source URI used to create the Lineage artifact.
+ etag (str): The S3 Etag used to create the Lineage artifact.
+            source_arn (str): The source ARN of a trial component to associate
+ this Lineage artifact with (default: None).
+ dest_arn (str): The destination ARN of a trial component to associate
+ this Lineage artifact with (default: None).
+ artifact_type (str): The type of the Lineage artifact (default: "Tracker").
+ """
+ self.name = name
+ self.source_uri = source_uri
+ self.etag = etag
+ self.source_arn = source_arn
+ self.dest_arn = dest_arn
+ self.artifact_arn = None
+ self.artifact_type = artifact_type
+
+ def create_artifact(self, sagemaker_session):
+ """Create the artifact by calling `CreateArtifact` API
+
+ Args:
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed.
+ """
+ source_ids = []
+ if self.etag:
+ source_ids.append({"SourceIdType": "S3ETag", "Value": self.etag})
+
+ try:
+ response = sagemaker_session.sagemaker_client.create_artifact(
+ ArtifactName=self.name,
+ ArtifactType=self.artifact_type,
+ Source={"SourceUri": self.source_uri, "SourceTypes": source_ids},
+ )
+ self.artifact_arn = response["ArtifactArn"]
+ except botocore.exceptions.ClientError as err:
+ err_info = err.response["Error"]
+ if not is_already_exist_error(err_info):
+ raise
+ logger.warning(
+ "Skip creating the artifact since it already exists: %s", err_info["Message"]
+ )
+
+ def add_association(self, sagemaker_session):
+ """Associate the artifact with a source/destination ARN (e.g. trial component arn)
+
+ Args:
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed.
+ """
+ source_arn = self.source_arn if self.source_arn else self.artifact_arn
+ dest_arn = self.dest_arn if self.dest_arn else self.artifact_arn
+ # if the trial component (job) is the source then it produced the artifact,
+ # otherwise the artifact contributed to the trial component (job)
+ association_edge_type = "Produced" if self.source_arn else "ContributedTo"
+ try:
+ sagemaker_session.sagemaker_client.add_association(
+ SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_edge_type
+ )
+ except botocore.exceptions.ClientError as err:
+ err_info = err.response["Error"]
+ if not is_already_exist_error(err_info):
+ raise
+ logger.warning(
+ "Skip associating since the association already exists: %s", err_info["Message"]
+ )
+
+
+class _LineageArtifactTracker(object):
+ """Lineage Artifact Tracker"""
+
+ def __init__(self, trial_component_arn, sagemaker_session):
+ """Initialize a `_LineageArtifactTracker` instance.
+
+ Args:
+ trial_component_arn (str): The ARN of the trial component to be
+ associated with the input/output artifacts.
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed.
+ """
+ self.trial_component_arn = trial_component_arn
+ self.sagemaker_session = sagemaker_session
+ self.artifacts = []
+
+ def add_input_artifact(self, name, source_uri, etag, artifact_type):
+ """Add a Lineage input artifact locally
+
+ Args:
+ name (str): The name of the Lineage input artifact to be added.
+ source_uri (str): The source URI used to create the Lineage input artifact.
+ etag (str): The S3 Etag used to create the Lineage input artifact.
+ artifact_type (str): The type of the Lineage input artifact.
+ """
+ artifact = _LineageArtifactManager(
+ name, source_uri, etag, dest_arn=self.trial_component_arn, artifact_type=artifact_type
+ )
+ self.artifacts.append(artifact)
+
+ def add_output_artifact(self, name, source_uri, etag, artifact_type):
+ """Add a Lineage output artifact locally
+
+ Args:
+ name (str): The name of the Lineage output artifact to be added.
+ source_uri (str): The source URI used to create the Lineage output artifact.
+ etag (str): The S3 Etag used to create the Lineage output artifact.
+ artifact_type (str): The type of the Lineage output artifact.
+ """
+ artifact = _LineageArtifactManager(
+ name, source_uri, etag, source_arn=self.trial_component_arn, artifact_type=artifact_type
+ )
+ self.artifacts.append(artifact)
+
+ def save(self):
+ """Persist any artifact data saved locally"""
+ for artifact in self.artifacts:
+ artifact.create_artifact(self.sagemaker_session)
+ artifact.add_association(self.sagemaker_session)
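# Illustration (not part of the diff): hypothetical usage of the tracker above.
# Artifacts are buffered locally and only persisted (CreateArtifact followed by
# AddAssociation) when save() is called. All arguments are placeholders.
tracker = _LineageArtifactTracker(trial_component_arn, sagemaker_session)
tracker.add_input_artifact("dataset", "s3://my-bucket/data.csv", "etag-1", "DataSet")
tracker.add_output_artifact("model", "s3://my-bucket/model.tar.gz", "etag-2", "Model")
tracker.save()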
diff --git a/src/sagemaker/experiments/_metrics.py b/src/sagemaker/experiments/_metrics.py
new file mode 100644
index 0000000000..f80c43f337
--- /dev/null
+++ b/src/sagemaker/experiments/_metrics.py
@@ -0,0 +1,413 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains classes to manage metrics for Sagemaker Experiment"""
+from __future__ import absolute_import
+
+import datetime
+import json
+import logging
+import os
+import time
+import threading
+import queue
+
+import dateutil.tz
+
+from sagemaker.session import Session
+
+METRICS_DIR = os.environ.get("SAGEMAKER_METRICS_DIRECTORY", ".")
+METRIC_TS_LOWER_BOUND_TO_NOW = 1209600  # in seconds
+METRIC_TS_UPPER_BOUND_FROM_NOW = 7200  # in seconds
+
+BATCH_SIZE = 10
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+# TODO: remove this _SageMakerFileMetricsWriter class
+# when _MetricsManager is fully ready
+class _SageMakerFileMetricsWriter(object):
+ """Write metric data to file."""
+
+ def __init__(self, metrics_file_path=None):
+ """Construct a `_SageMakerFileMetricsWriter` object"""
+ self._metrics_file_path = metrics_file_path
+ self._file = None
+ self._closed = False
+
+ def log_metric(self, metric_name, value, timestamp=None, step=None):
+ """Write a metric to file.
+
+ Args:
+ metric_name (str): The name of the metric.
+ value (float): The value of the metric.
+ timestamp (datetime.datetime): Timestamp of the metric.
+ If not specified, the current UTC time will be used.
+ step (int): Iteration number of the metric (default: None).
+
+ Raises:
+ SageMakerMetricsWriterException: If the metrics file is closed.
+            AttributeError: If writing fails even though the file has already been initialized.
+ """
+ raw_metric_data = _RawMetricData(
+ metric_name=metric_name, value=value, timestamp=timestamp, step=step
+ )
+ try:
+ logger.debug("Writing metric: %s", raw_metric_data)
+ self._file.write(json.dumps(raw_metric_data.to_record()))
+ self._file.write("\n")
+ except AttributeError as attr_err:
+ if self._closed:
+ raise SageMakerMetricsWriterException("log_metric called on a closed writer")
+ if not self._file:
+ self._file = open(self._get_metrics_file_path(), "a", buffering=1)
+ self._file.write(json.dumps(raw_metric_data.to_record()))
+ self._file.write("\n")
+ else:
+ raise attr_err
+
+ def close(self):
+ """Closes the metric file."""
+ if not self._closed and self._file:
+ self._file.close()
+ self._file = None # invalidate reference, causing subsequent log_metric to fail.
+ self._closed = True
+
+ def __enter__(self):
+ """Return self"""
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ """Execute self.close()"""
+ self.close()
+
+ def __del__(self):
+ """Execute self.close()"""
+ self.close()
+
+ def _get_metrics_file_path(self):
+ """Get file path to store metrics"""
+ pid_filename = "{}.json".format(str(os.getpid()))
+ metrics_file_path = self._metrics_file_path or os.path.join(METRICS_DIR, pid_filename)
+ logger.debug("metrics_file_path = %s", metrics_file_path)
+ return metrics_file_path
+
+
+class SageMakerMetricsWriterException(Exception):
+ """SageMakerMetricsWriterException"""
+
+ def __init__(self, message, errors=None):
+ """Construct a `SageMakerMetricsWriterException` instance"""
+ super().__init__(message)
+ if errors:
+ self.errors = errors
+
+
+class _RawMetricData(object):
+ """A Raw Metric Data Object"""
+
+ MetricName = None
+ Value = None
+ Timestamp = None
+ Step = None
+
+ def __init__(self, metric_name, value, timestamp=None, step=None):
+ """Construct a `_RawMetricData` instance.
+
+ Args:
+ metric_name (str): The name of the metric.
+ value (float): The value of the metric.
+ timestamp (datetime.datetime or float or str): Timestamp of the metric.
+ If not specified, the current UTC time will be used.
+ step (int): Iteration number of the metric (default: None).
+ """
+ if timestamp is None:
+ timestamp = time.time()
+ elif isinstance(timestamp, datetime.datetime):
+ # If the input is a datetime then convert it to UTC time.
+ # Assume a naive datetime is in local timezone
+ if not timestamp.tzinfo:
+ timestamp = timestamp.replace(tzinfo=dateutil.tz.tzlocal())
+ timestamp = (timestamp - timestamp.utcoffset()).replace(tzinfo=datetime.timezone.utc)
+ timestamp = timestamp.timestamp()
+ else:
+ timestamp = float(timestamp)
+
+ if timestamp < (time.time() - METRIC_TS_LOWER_BOUND_TO_NOW) or timestamp > (
+ time.time() + METRIC_TS_UPPER_BOUND_FROM_NOW
+ ):
+ raise ValueError(
+ "Supplied timestamp %f is invalid."
+ " Timestamps must be between two weeks before and two hours from now." % timestamp
+ )
+
+ self.MetricName = metric_name
+ self.Value = float(value)
+ self.Timestamp = timestamp
+ if step is not None:
+ if not isinstance(step, int):
+ raise ValueError("step must be int.")
+ self.Step = step
+
+ def to_record(self):
+ """Convert the `_RawMetricData` object to dict"""
+ return self.__dict__
+
+ def to_raw_metric_data(self):
+ """Converts the metric data to a BatchPutMetrics RawMetricData item"""
+ # Convert timestamp from float to timestamp str.
+ # Otherwise will get ParamValidationError
+ raw_metric_data = {
+ "MetricName": self.MetricName,
+ "Value": self.Value,
+ "Timestamp": str(int(self.Timestamp)),
+ }
+ if self.Step is not None:
+ raw_metric_data["Step"] = int(self.Step)
+ return raw_metric_data
+
+ def __str__(self):
+ """String representation of the `_RawMetricData` object."""
+ return repr(self)
+
+ def __repr__(self):
+ """Return a string representation of this _RawMetricData` object."""
+ return "{}({})".format(
+ type(self).__name__,
+ ",".join(["{}={}".format(k, repr(v)) for k, v in vars(self).items()]),
+ )
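# Illustration (not part of the diff) of the timestamp normalization above: a
# naive datetime is treated as local time and converted to a UTC epoch float,
# which must fall within two weeks before and two hours after "now".
import datetime
import time

data = _RawMetricData("loss", 0.42, timestamp=datetime.datetime.now(), step=3)
assert abs(data.Timestamp - time.time()) < 5
assert data.to_raw_metric_data()["Timestamp"] == str(int(data.Timestamp))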
+
+
+class _MetricsManager(object):
+ """Collects metrics and sends them directly to SageMaker Metrics data plane APIs."""
+
+ def __init__(self, trial_component_name: str, sagemaker_session: Session, sink=None) -> None:
+ """Initialize a `_MetricsManager` instance
+
+ Args:
+ trial_component_name (str): The Name of the Trial Component to log metrics to
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+ sink (object): The metrics sink to use.
+ """
+ if sink is None:
+ self.sink = _SyncMetricsSink(
+ trial_component_name, sagemaker_session.sagemaker_metrics_client
+ )
+ else:
+ self.sink = sink
+
+ def log_metric(self, metric_name, value, timestamp=None, step=None):
+ """Sends a metric to metrics service."""
+
+ metric_data = _RawMetricData(metric_name, value, timestamp, step)
+ self.sink.log_metric(metric_data)
+
+ def __enter__(self):
+ """Return self"""
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ """Execute self.close()"""
+ self.sink.close()
+
+ def close(self):
+ """Close the metrics object."""
+ self.sink.close()
+
+
+class _SyncMetricsSink(object):
+ """Collects metrics and sends them directly to metrics service."""
+
+ def __init__(self, trial_component_name, metrics_client) -> None:
+ """Initialize a `_SyncMetricsSink` instance
+
+ Args:
+            trial_component_name (str): The Name of the Trial Component to log metrics to.
+ metrics_client (boto3.client): boto client for metrics service
+ """
+ self._trial_component_name = trial_component_name
+ self._metrics_client = metrics_client
+ self._buffer = []
+
+ def log_metric(self, metric_data):
+ """Sends a metric to metrics service."""
+
+ # this is a simplistic solution which calls BatchPutMetrics
+ # on the same thread as the client code
+ self._buffer.append(metric_data)
+ self._drain()
+
+ def _drain(self, close=False):
+ """Pops off all metrics in the buffer and starts sending them to metrics service."""
+
+ if not self._buffer:
+ return
+
+ if len(self._buffer) < BATCH_SIZE and not close:
+ return
+
+ # pop all the available metrics
+ available_metrics, self._buffer = self._buffer, []
+
+ self._send_metrics(available_metrics)
+
+ def _send_metrics(self, metrics):
+ """Calls BatchPutMetrics directly on the metrics service."""
+ while metrics:
+ batch, metrics = (
+ metrics[:BATCH_SIZE],
+ metrics[BATCH_SIZE:],
+ )
+ request = self._construct_batch_put_metrics_request(batch)
+ response = self._metrics_client.batch_put_metrics(**request)
+ errors = response["Errors"] if "Errors" in response else None
+ if errors:
+ message = errors[0]["Message"]
+ raise Exception(f'{len(errors)} errors with message "{message}"')
+
+ def _construct_batch_put_metrics_request(self, batch):
+ """Creates dictionary object used as request to metrics service."""
+ return {
+ "TrialComponentName": self._trial_component_name.lower(),
+ "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
+ }
+
+ def close(self):
+ """Drains any remaining metrics."""
+ self._drain(close=True)
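# Illustration (not part of the diff): the sink above buffers metrics until
# BATCH_SIZE (10) accumulate, then sends them via BatchPutMetrics; close()
# drains whatever remains. metrics_client is a hypothetical boto client.
sink = _SyncMetricsSink("my-trial-component", metrics_client)
for step in range(9):
    sink.log_metric(_RawMetricData("loss", 1.0 / (step + 1), step=step))  # buffered
sink.close()  # one BatchPutMetrics call with the 9 buffered metrics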
+
+
+class _MetricQueue(object):
+ """A thread safe queue for sending metrics to SageMaker.
+
+ Args:
+        trial_component_name (str): the name of the Trial Component to log metrics to
+ metric_name (str): the name of the metric
+ metrics_client (boto_client): the boto client for SageMaker Metrics service
+ """
+
+ _CONSUMER_SLEEP_SECONDS = 5
+
+ def __init__(self, trial_component_name, metric_name, metrics_client):
+ # infinite queue size
+ self._queue = queue.Queue()
+ self._buffer = []
+ self._thread = threading.Thread(target=self._run)
+ self._started = False
+ self._finished = False
+ self._trial_component_name = trial_component_name
+ self._metrics_client = metrics_client
+ self._metric_name = metric_name
+ self._logged_metrics = 0
+
+ def log_metric(self, metric_data):
+ """Adds a metric data point to the queue"""
+ self._buffer.append(metric_data)
+
+ if len(self._buffer) < BATCH_SIZE:
+ return
+
+ self._enqueue_all()
+
+ if not self._started:
+ self._thread.start()
+ self._started = True
+
+ def _run(self):
+ """Starts the metric thread which sends metrics to SageMaker in batches"""
+
+ while not self._queue.empty() or not self._finished:
+ if self._queue.empty():
+ time.sleep(self._CONSUMER_SLEEP_SECONDS)
+ else:
+ batch = self._queue.get()
+ self._send_metrics(batch)
+
+ def _send_metrics(self, metrics_batch):
+ """Calls BatchPutMetrics directly on the metrics service."""
+ request = self._construct_batch_put_metrics_request(metrics_batch)
+ self._logged_metrics += len(metrics_batch)
+ self._metrics_client.batch_put_metrics(**request)
+
+ def _construct_batch_put_metrics_request(self, batch):
+ """Creates dictionary object used as request to metrics service."""
+
+ return {
+ "TrialComponentName": self._trial_component_name,
+ "MetricData": list(map(lambda x: x.to_raw_metric_data(), batch)),
+ }
+
+ def _enqueue_all(self):
+ """Enqueue all buffered metrics to be sent to SageMaker"""
+
+ available_metrics, self._buffer = self._buffer, []
+ if available_metrics:
+ self._queue.put(available_metrics)
+
+ def close(self):
+ """Flushes any buffered metrics"""
+
+ self._enqueue_all()
+ self._finished = True
+
+ def is_active(self):
+ """Is the thread active (still draining metrics to SageMaker)"""
+
+ return self._thread.is_alive()
+
+
+class _AsyncMetricsSink(object):
+ """Collects metrics and sends them directly to metrics service."""
+
+ _COMPLETE_SLEEP_SECONDS = 1.0
+
+ def __init__(self, trial_component_name, metrics_client) -> None:
+ """Initialize a `_AsyncMetricsSink` instance
+
+ Args:
+ trial_component_name (str): The Name of the Trial Component to log metrics to.
+ metrics_client (boto3.client): boto client for metrics service
+ """
+ self._trial_component_name = trial_component_name
+ self._metrics_client = metrics_client
+ self._buffer = []
+ self._is_draining = False
+ self._metric_queues = {}
+
+ def log_metric(self, metric_data):
+ """Sends a metric to metrics service."""
+
+ if metric_data.MetricName in self._metric_queues:
+ self._metric_queues[metric_data.MetricName].log_metric(metric_data)
+ else:
+ cur_metric_queue = _MetricQueue(
+ self._trial_component_name, metric_data.MetricName, self._metrics_client
+ )
+ self._metric_queues[metric_data.MetricName] = cur_metric_queue
+ cur_metric_queue.log_metric(metric_data)
+
+ def close(self):
+ """Closes the metric file."""
+ logging.debug("Closing")
+ for q in self._metric_queues.values():
+ q.close()
+
+ # TODO should probably use join
+ while any(map(lambda x: x.is_active(), self._metric_queues.values())):
+ time.sleep(self._COMPLETE_SLEEP_SECONDS)
+ logging.debug("Closed")
diff --git a/src/sagemaker/experiments/_run_context.py b/src/sagemaker/experiments/_run_context.py
new file mode 100644
index 0000000000..9a7dada5f4
--- /dev/null
+++ b/src/sagemaker/experiments/_run_context.py
@@ -0,0 +1,58 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains the SageMaker Experiment _RunContext class."""
+from __future__ import absolute_import
+
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+ from sagemaker.experiments import Run
+
+
+class _RunContext:
+ """A static context variable to keep track of the current Run object"""
+
+ _context_run = None
+
+ @classmethod
+ def add_run_object(cls, run: "Run"):
+ """Keep track of the current executing Run object
+
+ by adding it to a class static variable.
+
+ Args:
+ run (Run): The current Run object to be tracked.
+ """
+ cls._context_run = run
+
+ @classmethod
+ def drop_current_run(cls) -> "Run":
+ """Drop the Run object tracked in the global static variable
+
+ as its execution finishes (its "with" block ends).
+
+ Return:
+ Run: the dropped Run object.
+ """
+ current_run = cls._context_run
+ cls._context_run = None
+ return current_run
+
+ @classmethod
+ def get_current_run(cls) -> "Run":
+ """Return the current Run object without dropping it.
+
+ Return:
+ Run: the current Run object to be returned.
+ """
+ return cls._context_run
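# Illustration (not part of the diff): a schematic sketch of how the Run class
# is expected to use this context hook; the real Run implementation does more
# work in __enter__/__exit__.
class _SketchRun:
    def __enter__(self):
        _RunContext.add_run_object(self)
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        _RunContext.drop_current_run()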
diff --git a/src/sagemaker/experiments/_utils.py b/src/sagemaker/experiments/_utils.py
new file mode 100644
index 0000000000..5ef5d99dad
--- /dev/null
+++ b/src/sagemaker/experiments/_utils.py
@@ -0,0 +1,218 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains the SageMaker Experiment utility methods."""
+from __future__ import absolute_import
+
+import logging
+import os
+
+import mimetypes
+import urllib
+from functools import wraps
+from typing import Optional
+
+from sagemaker import Session
+from sagemaker.apiutils import _utils
+from sagemaker.experiments._environment import _RunEnvironment, _EnvironmentType
+from sagemaker.experiments.trial_component import _TrialComponent
+from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression
+from sagemaker.utils import retry_with_backoff
+
+
+def resolve_artifact_name(file_path):
+ """Resolve artifact name from given file path.
+
+    If a file name cannot be resolved from the path, one is auto-generated.
+
+ Args:
+ file_path (str): Path to the file.
+
+ Returns:
+ str: The resolved artifact name.
+ """
+ _, filename = os.path.split(file_path)
+ if filename:
+ return filename
+
+ return _utils.name("artifact")
+
+
+def guess_media_type(file_path):
+ """Infer the media type of a file based on its file name.
+
+ Args:
+ file_path (str): Path to the file.
+
+ Returns:
+ str: The guessed media type.
+ """
+ file_url = urllib.parse.urljoin("file:", urllib.request.pathname2url(file_path))
+ guessed_media_type, _ = mimetypes.guess_type(file_url, strict=False)
+ return guessed_media_type
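# Illustration (not part of the diff): on a typical installation,
#   guess_media_type("/tmp/report.json")  -> "application/json"
#   guess_media_type("/tmp/unknown.xyz")  -> None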
+
+
+def verify_length_of_true_and_predicted(true_labels, predicted_attrs, predicted_attrs_name):
+ """Verify if lengths match between lists of true labels and predicted attributes.
+
+ Args:
+ true_labels (list or array): The list of the true labels.
+ predicted_attrs (list or array): The list of the predicted labels/probabilities/scores.
+ predicted_attrs_name (str): The name of the predicted attributes.
+
+ Raises:
+ ValueError: If lengths mismatch between true labels and predicted attributes.
+ """
+ if len(true_labels) != len(predicted_attrs):
+ raise ValueError(
+ "Lengths mismatch between true labels and {}: "
+ "({} vs {}).".format(predicted_attrs_name, len(true_labels), len(predicted_attrs))
+ )
+
+
+def validate_invoked_inside_run_context(func):
+ """A Decorator to force the decorated method called under Run context."""
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ self_instance = args[0]
+ if not self_instance._inside_load_context and not self_instance._inside_init_context:
+ raise RuntimeError("This method should be called inside context of 'with' statement.")
+ return func(*args, **kwargs)
+
+ return wrapper
+
+
+def is_already_exist_error(error):
+ """Check if the error indicates resource already exists
+
+ Args:
+ error (dict): The "Error" field in the response of the
+ `botocore.exceptions.ClientError`
+ """
+ return error["Code"] == "ValidationException" and "already exists" in error["Message"]
+
+
+def get_tc_and_exp_config_from_job_env(
+ environment: _RunEnvironment,
+ sagemaker_session: Session,
+) -> dict:
+ """Retrieve an experiment config from the job environment.
+
+ Args:
+ environment (_RunEnvironment): The run environment object with job specific data.
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+ """
+ job_name = environment.source_arn.split("/")[-1]
+ if environment.environment_type == _EnvironmentType.SageMakerTrainingJob:
+ job_response = retry_with_backoff(
+ callable_func=lambda: sagemaker_session.describe_training_job(job_name),
+ num_attempts=4,
+ )
+ elif environment.environment_type == _EnvironmentType.SageMakerProcessingJob:
+ job_response = retry_with_backoff(
+ callable_func=lambda: sagemaker_session.describe_processing_job(job_name),
+ num_attempts=4,
+ )
+ else: # environment.environment_type == _EnvironmentType.SageMakerTransformJob
+ raise RuntimeError(
+ "Failed to load the Run as loading experiment config "
+ "from transform job environment is not currently supported. "
+ "As a workaround, please explicitly pass in "
+ "the experiment_name and run_name in load_run."
+ )
+
+ job_exp_config = job_response.get("ExperimentConfig", dict())
+ from sagemaker.experiments.run import RUN_NAME
+
+ if job_exp_config.get(RUN_NAME, None):
+ return job_exp_config
+ raise RuntimeError(
+        "Unable to fetch the RunName from the ExperimentConfig of the SageMaker job. "
+ "Please make sure the ExperimentConfig is correctly set."
+ )
+
+
+def verify_load_input_names(
+ run_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+):
+ """Verify the run_name and the experiment_name inputs in load_run.
+
+ Args:
+ run_name (str): The run_name supplied by the user (default: None).
+ experiment_name (str): The experiment_name supplied by the user
+ (default: None).
+
+ Raises:
+ ValueError: If run_name is supplied while experiment_name is not.
+ """
+ if not run_name and experiment_name:
+ logging.warning(
+ "No run_name is supplied. Ignoring the provided experiment_name "
+ "since it only takes effect along with run_name. "
+ "Will load the Run object from the job environment or current Run context."
+ )
+ if run_name and not experiment_name:
+ raise ValueError(
+ "Invalid input: experiment_name is missing when run_name is supplied. "
+ "Please supply a valid experiment_name when the run_name is not None."
+ )
+
+
+def is_run_trial_component(trial_component_name: str, sagemaker_session: Session) -> bool:
+ """Check if a trial component is generated by `sagemaker.experiments.Run`
+
+ Args:
+ trial_component_name (str): The name of the trial component.
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+ Returns:
+        bool: Indicates whether the trial component was created by
+            `sagemaker.experiments.Run`.
+ """
+ search_filter = Filter(
+ name="TrialComponentName",
+ operator=Operator.EQUALS,
+ value=trial_component_name,
+ )
+ search_expression = SearchExpression(filters=[search_filter])
+
+ def search():
+ return list(
+ _TrialComponent.search(
+ search_expression=search_expression,
+ max_results=1, # TrialComponentName is unique in an account
+ sagemaker_session=sagemaker_session,
+ )
+ )[0]
+
+ try:
+ tc_search_res = retry_with_backoff(search, 4)
+ from sagemaker.experiments.run import RUN_TC_TAG
+
+ if not tc_search_res.tags or RUN_TC_TAG not in tc_search_res.tags:
+ return False
+ return True
+ except Exception as ex: # pylint: disable=broad-except
+ logging.warning(
+ "Failed to inspect the type of the trial component (%s), due to (%s)",
+ trial_component_name,
+ str(ex),
+ )
+ return False
diff --git a/src/sagemaker/experiments/experiment.py b/src/sagemaker/experiments/experiment.py
new file mode 100644
index 0000000000..8f59ff36b3
--- /dev/null
+++ b/src/sagemaker/experiments/experiment.py
@@ -0,0 +1,237 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains the SageMaker Experiment class."""
+from __future__ import absolute_import
+
+import time
+
+from sagemaker.apiutils import _base_types
+from sagemaker.experiments.trial import _Trial
+from sagemaker.experiments.trial_component import _TrialComponent
+
+
+class _Experiment(_base_types.Record):
+ """An Amazon SageMaker experiment, which is a collection of related trials.
+
+ New experiments are created by calling `experiments.experiment._Experiment.create`.
+ Existing experiments can be reloaded by calling `experiments.experiment._Experiment.load`.
+
+ Attributes:
+ experiment_name (str): The name of the experiment. The name must be unique
+ within an account.
+ display_name (str): Name of the experiment that will appear in UI,
+ such as SageMaker Studio.
+ description (str): A description of the experiment.
+ tags (List[Dict[str, str]]): A list of tags to associate with the experiment.
+ """
+
+ experiment_name = None
+ display_name = None
+ description = None
+ tags = None
+
+ _boto_create_method = "create_experiment"
+ _boto_load_method = "describe_experiment"
+ _boto_update_method = "update_experiment"
+ _boto_delete_method = "delete_experiment"
+
+ _boto_update_members = ["experiment_name", "description", "display_name"]
+ _boto_delete_members = ["experiment_name"]
+
+ _MAX_DELETE_ALL_ATTEMPTS = 3
+
+ def save(self):
+ """Save the state of this Experiment to SageMaker.
+
+ Returns:
+ dict: Update experiment API response.
+ """
+ return self._invoke_api(self._boto_update_method, self._boto_update_members)
+
+ def delete(self):
+ """Delete this Experiment from SageMaker.
+
+        Deleting an Experiment does not delete its associated Trials and their Trial Components.
+        Each Trial in the Experiment must be deleted first.
+
+ Returns:
+ dict: Delete experiment API response.
+ """
+ return self._invoke_api(self._boto_delete_method, self._boto_delete_members)
+
+ @classmethod
+ def load(cls, experiment_name, sagemaker_session=None):
+ """Load an existing experiment and return an `_Experiment` object representing it.
+
+ Args:
+ experiment_name: (str): Name of the experiment
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+ Returns:
+ experiments.experiment._Experiment: A SageMaker `_Experiment` object
+ """
+ return cls._construct(
+ cls._boto_load_method,
+ experiment_name=experiment_name,
+ sagemaker_session=sagemaker_session,
+ )
+
+ @classmethod
+ def create(
+ cls,
+ experiment_name,
+ display_name=None,
+ description=None,
+ tags=None,
+ sagemaker_session=None,
+ ):
+ """Create a new experiment in SageMaker and return an `_Experiment` object.
+
+ Args:
+ experiment_name: (str): Name of the experiment. Must be unique. Required.
+ display_name: (str): Name of the experiment that will appear in UI,
+ such as SageMaker Studio (default: None).
+ description: (str): Description of the experiment (default: None).
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+ tags (List[Dict[str, str]]): A list of tags to associate with the experiment
+ (default: None).
+
+ Returns:
+ experiments.experiment._Experiment: A SageMaker `_Experiment` object
+ """
+ return cls._construct(
+ cls._boto_create_method,
+ experiment_name=experiment_name,
+ display_name=display_name,
+ description=description,
+ tags=tags,
+ sagemaker_session=sagemaker_session,
+ )
+
+ @classmethod
+ def _load_or_create(
+ cls,
+ experiment_name,
+ display_name=None,
+ description=None,
+ tags=None,
+ sagemaker_session=None,
+ ):
+ """Load an experiment by name and create a new one if it does not exist.
+
+ Args:
+ experiment_name: (str): Name of the experiment. Must be unique. Required.
+ display_name: (str): Name of the experiment that will appear in UI,
+ such as SageMaker Studio (default: None). This is used only when the
+ given `experiment_name` does not exist and a new experiment has to be created.
+ description: (str): Description of the experiment (default: None).
+ This is used only when the given `experiment_name` does not exist and
+ a new experiment has to be created.
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+ tags (List[Dict[str, str]]): A list of tags to associate with the experiment
+ (default: None). This is used only when the given `experiment_name` does not
+ exist and a new experiment has to be created.
+
+ Returns:
+ experiments.experiment._Experiment: A SageMaker `_Experiment` object
+ """
+ sagemaker_client = sagemaker_session.sagemaker_client
+ try:
+ experiment = _Experiment.load(experiment_name, sagemaker_session)
+ except sagemaker_client.exceptions.ResourceNotFound:
+ experiment = _Experiment.create(
+ experiment_name=experiment_name,
+ display_name=display_name,
+ description=description,
+ tags=tags,
+ sagemaker_session=sagemaker_session,
+ )
+ return experiment
+
+ def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None):
+ """List trials in this experiment matching the specified criteria.
+
+ Args:
+ created_before (datetime.datetime): Return trials created before this instant
+ (default: None).
+ created_after (datetime.datetime): Return trials created after this instant
+ (default: None).
+ sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime'
+ (default: None).
+ sort_order (str): One of 'Ascending', or 'Descending' (default: None).
+
+ Returns:
+            collections.Iterator[experiments._api_types.TrialSummary]:
+                An iterator over trials matching the criteria.
+ """
+ return _Trial.list(
+ experiment_name=self.experiment_name,
+ created_before=created_before,
+ created_after=created_after,
+ sort_by=sort_by,
+ sort_order=sort_order,
+ sagemaker_session=self.sagemaker_session,
+ )
+
+ def _delete_all(self, action):
+        """Force-delete the experiment and its associated trials and trial components.
+
+ Args:
+            action (str): The string '--force' must be passed in to confirm the recursive
+                deletion of the experiment and all of its trials and trial components.
+ """
+ if action != "--force":
+ raise ValueError(
+                "Must confirm with string '--force' in order to delete the experiment and "
+                "its associated trials and trial components."
+ )
+
+ delete_attempt_count = 0
+ last_exception = None
+ while True:
+ if delete_attempt_count == self._MAX_DELETE_ALL_ATTEMPTS:
+ raise Exception("Failed to delete, please try again.") from last_exception
+ try:
+ for trial_summary in self.list_trials():
+ trial = _Trial.load(
+ sagemaker_session=self.sagemaker_session,
+ trial_name=trial_summary.trial_name,
+ )
+ for (
+ trial_component_summary
+ ) in trial.list_trial_components(): # pylint: disable=no-member
+ tc = _TrialComponent.load(
+ sagemaker_session=self.sagemaker_session,
+ trial_component_name=trial_component_summary.trial_component_name,
+ )
+ tc.delete(force_disassociate=True)
+ # to prevent throttling
+ time.sleep(1.2)
+ trial.delete() # pylint: disable=no-member
+ # to prevent throttling
+ time.sleep(1.2)
+ self.delete()
+ break
+ except Exception as ex: # pylint: disable=broad-except
+ last_exception = ex
+ finally:
+ delete_attempt_count = delete_attempt_count + 1
diff --git a/src/sagemaker/experiments/run.py b/src/sagemaker/experiments/run.py
new file mode 100644
index 0000000000..1492b6bafa
--- /dev/null
+++ b/src/sagemaker/experiments/run.py
@@ -0,0 +1,882 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains the SageMaker Experiment Run class."""
+from __future__ import absolute_import
+
+import datetime
+import logging
+from enum import Enum
+from math import isnan, isinf
+from numbers import Number
+from typing import Optional, List, Dict, TYPE_CHECKING, Union
+
+import dateutil.tz
+from numpy import array
+
+from sagemaker.apiutils import _utils
+from sagemaker.experiments import _api_types
+from sagemaker.experiments._api_types import TrialComponentArtifact, _TrialComponentStatusType
+from sagemaker.experiments._helper import (
+ _ArtifactUploader,
+ _LineageArtifactTracker,
+)
+from sagemaker.experiments._environment import _RunEnvironment
+from sagemaker.experiments._run_context import _RunContext
+from sagemaker.experiments.experiment import _Experiment
+from sagemaker.experiments._metrics import _MetricsManager
+from sagemaker.experiments.trial import _Trial
+from sagemaker.experiments.trial_component import _TrialComponent
+
+from sagemaker.utils import (
+ get_module,
+ unique_name_from_base,
+)
+
+from sagemaker.experiments._utils import (
+ guess_media_type,
+ resolve_artifact_name,
+ verify_length_of_true_and_predicted,
+ validate_invoked_inside_run_context,
+ get_tc_and_exp_config_from_job_env,
+ verify_load_input_names,
+ is_run_trial_component,
+)
+
+if TYPE_CHECKING:
+ from sagemaker import Session
+
+logger = logging.getLogger(__name__)
+
+RUN_NAME_BASE = "Sagemaker-Run".lower()
+TRIAL_NAME_TEMPLATE = "Default-Run-Group-{}"
+MAX_RUN_TC_ARTIFACTS_LEN = 30
+MAX_NAME_LEN_IN_BACKEND = 120
+EXPERIMENT_NAME = "ExperimentName"
+TRIAL_NAME = "TrialName"
+RUN_NAME = "RunName"
+DELIMITER = "-"
+RUN_TC_TAG_KEY = "sagemaker:trial-component-source"
+RUN_TC_TAG_VALUE = "run"
+RUN_TC_TAG = {"Key": RUN_TC_TAG_KEY, "Value": RUN_TC_TAG_VALUE}
+
+
+class SortByType(Enum):
+ """The type of property by which to sort the `list_runs` results."""
+
+ CREATION_TIME = "CreationTime"
+ NAME = "Name"
+
+
+class SortOrderType(Enum):
+ """The type of order to sort the list or search results."""
+
+ ASCENDING = "Ascending"
+ DESCENDING = "Descending"
+
+
+class Run(object):
+    """A collection of parameters, metrics, and artifacts to create an ML model."""
+
+ def __init__(
+ self,
+ experiment_name: str,
+ run_name: Optional[str] = None,
+ experiment_display_name: Optional[str] = None,
+ run_display_name: Optional[str] = None,
+ tags: Optional[List[Dict[str, str]]] = None,
+ sagemaker_session: Optional["Session"] = None,
+ ):
+ """Construct a `Run` instance.
+
+ SageMaker Experiments automatically tracks the inputs, parameters, configurations,
+ and results of your iterations as runs.
+ You can assign, group, and organize these runs into experiments.
+ You can also create, compare, and evaluate runs.
+
+ The code sample below shows how to initialize a run, log parameters to the Run object
+ and invoke a training job under the context of this Run object, which automatically
+ passes the run's ``experiment_config`` (including the experiment name, run name etc.)
+ to the training job.
+
+ Note:
+ All log methods (e.g. ``log_parameter``, ``log_metric``, etc.) have to be called within
+ the run context (i.e. the ``with`` statement). Otherwise, a ``RuntimeError`` is thrown.
+
+ .. code:: python
+
+ with Run(experiment_name="my-exp", run_name="my-run", ...) as run:
+ run.log_parameter(...)
+ ...
+ estimator.fit(job_name="my-job") # Create a training job
+
+ In order to reuse an existing run to log extra data, ``load_run`` is recommended.
+ The code snippet below displays how to load the run initialized above
+ in a custom training job script, where no ``run_name`` or ``experiment_name``
+ is presented as they are automatically retrieved from the experiment config
+ in the job environment.
+
+ Note:
+            In a job script, ``load_run`` is recommended over the ``Run`` constructor
+            for loading the existing run created before the job launch.
+ Otherwise, a new run may be created each time you launch a job.
+
+ .. code:: python
+
+ with load_run() as run:
+ run.log_metric(...)
+ ...
+
+ Args:
+ experiment_name (str): The name of the experiment. The name must be unique
+ within an account.
+ run_name (str): The name of the run. If it is not specified, one is auto generated.
+ experiment_display_name (str): Name of the experiment that will appear in UI,
+ such as SageMaker Studio. (default: None). This display name is used in
+ a create experiment call. If an experiment with the specified name already exists,
+ this display name won't take effect.
+ run_display_name (str): The display name of the run used in UI (default: None).
+ This display name is used in a create run call. If a run with the
+ specified name already exists, this display name won't take effect.
+ tags (List[Dict[str, str]]): A list of tags to be used for all create calls,
+ e.g. to create an experiment, a run group, etc. (default: None).
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+ """
+        # TODO: we should revert the lowercasing once the backend fix reaches prod
+ self.experiment_name = experiment_name.lower()
+ sagemaker_session = sagemaker_session or _utils.default_session()
+ self.run_name = run_name or unique_name_from_base(RUN_NAME_BASE)
+
+        # avoid confusion due to a mismatch in casing between the run name and the TC name
+ self.run_name = self.run_name.lower()
+
+ trial_component_name = Run._generate_trial_component_name(
+ run_name=self.run_name, experiment_name=self.experiment_name
+ )
+ self.run_group_name = Run._generate_trial_name(self.experiment_name)
+
+ self._experiment = _Experiment._load_or_create(
+ experiment_name=self.experiment_name,
+ display_name=experiment_display_name,
+ tags=tags,
+ sagemaker_session=sagemaker_session,
+ )
+
+ self._trial = _Trial._load_or_create(
+ experiment_name=self.experiment_name,
+ trial_name=self.run_group_name,
+ tags=tags,
+ sagemaker_session=sagemaker_session,
+ )
+
+ self._trial_component, is_existed = _TrialComponent._load_or_create(
+ trial_component_name=trial_component_name,
+ display_name=run_display_name,
+ tags=Run._append_run_tc_label_to_tags(tags),
+ sagemaker_session=sagemaker_session,
+ )
+ if is_existed:
+ logger.info(
+ "The run (%s) under experiment (%s) already exists. Loading it. "
+ "Note: sagemaker.experiments.load_run is recommended to use when "
+ "the desired run already exists.",
+ self.run_name,
+ self.experiment_name,
+ )
+ self._trial.add_trial_component(self._trial_component)
+
+ self._artifact_uploader = _ArtifactUploader(
+ trial_component_name=self._trial_component.trial_component_name,
+ sagemaker_session=sagemaker_session,
+ )
+ self._lineage_artifact_tracker = _LineageArtifactTracker(
+ trial_component_arn=self._trial_component.trial_component_arn,
+ sagemaker_session=sagemaker_session,
+ )
+ self._metrics_manager = _MetricsManager(
+ trial_component_name=self._trial_component.trial_component_name,
+ sagemaker_session=sagemaker_session,
+ )
+ self._inside_init_context = False
+ self._inside_load_context = False
+ self._in_load = False
+
+ @property
+ def experiment_config(self) -> dict:
+ """Get experiment config from run attributes."""
+ return {
+ EXPERIMENT_NAME: self.experiment_name,
+ TRIAL_NAME: self.run_group_name,
+ RUN_NAME: self._trial_component.trial_component_name,
+ }
+
+ @validate_invoked_inside_run_context
+ def log_parameter(self, name: str, value: Union[str, int, float]):
+ """Record a single parameter value for this run.
+
+ Overwrites any previous value recorded for the specified parameter name.
+
+ Args:
+ name (str): The name of the parameter.
+ value (str or int or float): The value of the parameter.
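+
+        A minimal sketch (the parameter name and value are illustrative):
+
+        .. code:: python
+
+            with Run(experiment_name="my-exp") as run:
+                run.log_parameter(name="learning_rate", value=0.01)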
+ """
+ if self._is_input_valid("parameter", name, value):
+ self._trial_component.parameters[name] = value
+
+ @validate_invoked_inside_run_context
+ def log_parameters(self, parameters: Dict[str, Union[str, int, float]]):
+ """Record a collection of parameter values for this run.
+
+ Args:
+ parameters (dict[str, str or int or float]): The parameters to record.
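+
+        A minimal sketch, assuming ``run`` comes from an active run context
+        (parameter names and values are illustrative):
+
+        .. code:: python
+
+            run.log_parameters({"batch_size": 32, "epochs": 10})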
+ """
+ filtered_parameters = {
+ key: value
+ for (key, value) in parameters.items()
+ if self._is_input_valid("parameter", key, value)
+ }
+ self._trial_component.parameters.update(filtered_parameters)
+
+ @validate_invoked_inside_run_context
+ def log_metric(
+ self,
+ name: str,
+ value: float,
+ timestamp: Optional[datetime.datetime] = None,
+ step: Optional[int] = None,
+ ):
+ """Record a custom scalar metric value for this run.
+
+ Note:
+ This method is for manual custom metrics, for automatic metrics see the
+ ``enable_sagemaker_metrics`` parameter on the ``estimator`` class.
+
+ Args:
+ name (str): The name of the metric.
+ value (float): The value of the metric.
+ timestamp (datetime.datetime): The timestamp of the metric.
+ If not specified, the current UTC time will be used.
+ step (int): The integer iteration number of the metric value (default: None).
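+
+        A minimal sketch, assuming this runs in a job script where the run can be
+        loaded (metric name and values are illustrative):
+
+        .. code:: python
+
+            with load_run() as run:
+                for epoch, loss in enumerate([0.9, 0.5, 0.3]):
+                    run.log_metric(name="train:loss", value=loss, step=epoch)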
+ """
+ if self._is_input_valid("metric", name, value):
+ self._metrics_manager.log_metric(
+ metric_name=name, value=value, timestamp=timestamp, step=step
+ )
+
+ @validate_invoked_inside_run_context
+ def log_precision_recall(
+ self,
+ y_true: Union[list, array],
+ predicted_probabilities: Union[list, array],
+ positive_label: Optional[Union[str, int]] = None,
+ title: Optional[str] = None,
+ is_output: bool = True,
+ no_skill: Optional[int] = None,
+ ):
+ """Create and log a precision recall graph artifact for Studio UI to render.
+
+ The artifact is stored in S3 and represented as a lineage artifact
+ with an association with the run.
+
+ You can view the artifact in the UI.
+ If your job is created by a pipeline execution you can view the artifact
+ by selecting the corresponding step in the pipelines UI.
+        See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_.
+
+        This method requires the sklearn library.
+
+ Args:
+ y_true (list or array): True labels. If labels are not binary
+ then positive_label should be given.
+ predicted_probabilities (list or array): Estimated/predicted probabilities.
+ positive_label (str or int): Label of the positive class (default: None).
+ title (str): Title of the graph (default: None).
+ is_output (bool): Determines direction of association to the
+ run. Defaults to True (output artifact).
+ If set to False then represented as input association.
+ no_skill (int): The precision threshold under which the classifier cannot discriminate
+ between the classes and would predict a random class or a constant class in
+ all cases (default: None).
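+
+        A minimal sketch, assuming ``run`` comes from an active run context
+        (labels and probabilities are dummy data):
+
+        .. code:: python
+
+            run.log_precision_recall(
+                y_true=[0, 1, 1, 0],
+                predicted_probabilities=[0.1, 0.9, 0.6, 0.4],
+                positive_label=1,
+                title="precision-recall",
+            )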
+ """
+
+ verify_length_of_true_and_predicted(
+ true_labels=y_true,
+ predicted_attrs=predicted_probabilities,
+ predicted_attrs_name="predicted probabilities",
+ )
+
+ get_module("sklearn")
+ from sklearn.metrics import precision_recall_curve, average_precision_score
+
+ kwargs = {}
+ if positive_label is not None:
+ kwargs["pos_label"] = positive_label
+
+ precision, recall, _ = precision_recall_curve(y_true, predicted_probabilities, **kwargs)
+
+ kwargs["average"] = "micro"
+ ap = average_precision_score(y_true, predicted_probabilities, **kwargs)
+
+ data = {
+ "type": "PrecisionRecallCurve",
+ "version": 0,
+ "title": title,
+ "precision": precision.tolist(),
+ "recall": recall.tolist(),
+ "averagePrecisionScore": ap,
+ "noSkill": no_skill,
+ }
+ self._log_graph_artifact(
+ artifact_name=title, data=data, graph_type="PrecisionRecallCurve", is_output=is_output
+ )
+
+ @validate_invoked_inside_run_context
+ def log_roc_curve(
+ self,
+ y_true: Union[list, array],
+ y_score: Union[list, array],
+ title: Optional[str] = None,
+ is_output: bool = True,
+ ):
+ """Create and log a receiver operating characteristic (ROC curve) artifact.
+
+ The artifact is stored in S3 and represented as a lineage artifact
+ with an association with the run.
+
+ You can view the artifact in the UI.
+ If your job is created by a pipeline execution you can view the artifact
+ by selecting the corresponding step in the pipelines UI.
+        See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_.
+
+        This method requires the sklearn library.
+
+ Args:
+ y_true (list or array): True labels. If labels are not binary
+ then positive_label should be given.
+ y_score (list or array): Estimated/predicted probabilities.
+ title (str): Title of the graph (default: None).
+ is_output (bool): Determines direction of association to the
+ run. Defaults to True (output artifact).
+ If set to False then represented as input association.
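+
+        A minimal sketch, assuming ``run`` comes from an active run context
+        (labels and scores are dummy data):
+
+        .. code:: python
+
+            run.log_roc_curve(
+                y_true=[0, 1, 1, 0],
+                y_score=[0.1, 0.9, 0.6, 0.4],
+                title="roc-curve",
+            )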
+ """
+ verify_length_of_true_and_predicted(
+ true_labels=y_true, predicted_attrs=y_score, predicted_attrs_name="predicted scores"
+ )
+
+ get_module("sklearn")
+ from sklearn.metrics import roc_curve, auc
+
+        fpr, tpr, _ = roc_curve(y_true, y_score)
+
+        # avoid shadowing the imported `auc` function with its result
+        area_under_curve = auc(fpr, tpr)
+
+        data = {
+            "type": "ROCCurve",
+            "version": 0,
+            "title": title,
+            "falsePositiveRate": fpr.tolist(),
+            "truePositiveRate": tpr.tolist(),
+            "areaUnderCurve": area_under_curve,
+ }
+ self._log_graph_artifact(
+ artifact_name=title, data=data, graph_type="ROCCurve", is_output=is_output
+ )
+
+ @validate_invoked_inside_run_context
+ def log_confusion_matrix(
+ self,
+ y_true: Union[list, array],
+ y_pred: Union[list, array],
+ title: Optional[str] = None,
+ is_output: bool = True,
+ ):
+ """Create and log a confusion matrix artifact.
+
+ The artifact is stored in S3 and represented as a lineage artifact
+ with an association with the run.
+
+ You can view the artifact in the UI.
+ If your job is created by a pipeline execution you can view the
+ artifact by selecting the corresponding step in the pipelines UI.
+        See also `SageMaker Pipelines <https://aws.amazon.com/sagemaker/pipelines/>`_.
+
+        This method requires the sklearn library.
+
+ Args:
+ y_true (list or array): True labels. If labels are not binary
+ then positive_label should be given.
+ y_pred (list or array): Predicted labels.
+ title (str): Title of the graph (default: None).
+ is_output (bool): Determines direction of association to the
+ run. Defaults to True (output artifact).
+ If set to False then represented as input association.
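+
+        A minimal sketch, assuming ``run`` comes from an active run context
+        (labels are dummy data):
+
+        .. code:: python
+
+            run.log_confusion_matrix(
+                y_true=[0, 1, 1, 0],
+                y_pred=[0, 1, 0, 0],
+                title="confusion-matrix",
+            )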
+ """
+ verify_length_of_true_and_predicted(
+ true_labels=y_true, predicted_attrs=y_pred, predicted_attrs_name="predicted labels"
+ )
+
+ get_module("sklearn")
+ from sklearn.metrics import confusion_matrix
+
+ matrix = confusion_matrix(y_true, y_pred)
+
+ data = {
+ "type": "ConfusionMatrix",
+ "version": 0,
+ "title": title,
+ "confusionMatrix": matrix.tolist(),
+ }
+ self._log_graph_artifact(
+ artifact_name=title, data=data, graph_type="ConfusionMatrix", is_output=is_output
+ )
+
+ @validate_invoked_inside_run_context
+ def log_artifact(
+ self, name: str, value: str, media_type: Optional[str] = None, is_output: bool = True
+ ):
+ """Record a single artifact for this run.
+
+ Overwrites any previous value recorded for the specified name.
+
+ Args:
+ name (str): The name of the artifact.
+ value (str): The value.
+ media_type (str): The MediaType (MIME type) of the value (default: None).
+ is_output (bool): Determines direction of association to the
+ run. Defaults to True (output artifact).
+ If set to False then represented as input association.
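+
+        A minimal sketch, assuming ``run`` comes from an active run context
+        (the artifact name and S3 URI are hypothetical):
+
+        .. code:: python
+
+            run.log_artifact(
+                name="training-data",
+                value="s3://my-bucket/train.csv",
+                media_type="text/csv",
+                is_output=False,
+            )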
+ """
+ self._verify_trial_component_artifacts_length(is_output=is_output)
+ if is_output:
+ self._trial_component.output_artifacts[name] = TrialComponentArtifact(
+ value, media_type=media_type
+ )
+ else:
+ self._trial_component.input_artifacts[name] = TrialComponentArtifact(
+ value, media_type=media_type
+ )
+
+ @validate_invoked_inside_run_context
+ def log_file(
+ self,
+ file_path: str,
+ name: Optional[str] = None,
+ media_type: Optional[str] = None,
+ is_output: bool = True,
+ ):
+        """Upload a file to S3 and store it as an input/output artifact in this run.
+
+ Args:
+ file_path (str): The path of the local file to upload.
+ name (str): The name of the artifact (default: None).
+ media_type (str): The MediaType (MIME type) of the file.
+ If not specified, this library will attempt to infer the media type
+ from the file extension of ``file_path``.
+ is_output (bool): Determines direction of association to the
+ run. Defaults to True (output artifact).
+ If set to False then represented as input association.
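+
+        A minimal sketch, assuming ``run`` comes from an active run context
+        (the local file path is hypothetical):
+
+        .. code:: python
+
+            run.log_file(file_path="output/report.json", name="report")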
+ """
+ self._verify_trial_component_artifacts_length(is_output)
+ media_type = media_type or guess_media_type(file_path)
+ name = name or resolve_artifact_name(file_path)
+ s3_uri, _ = self._artifact_uploader.upload_artifact(file_path)
+ if is_output:
+ self._trial_component.output_artifacts[name] = TrialComponentArtifact(
+ value=s3_uri, media_type=media_type
+ )
+ else:
+ self._trial_component.input_artifacts[name] = TrialComponentArtifact(
+ value=s3_uri, media_type=media_type
+ )
+
+ def close(self):
+ """Persist any data saved locally."""
+ try:
+ # Update the trial component with additions from the Run object
+ self._trial_component.save()
+ # Create Lineage entities for the artifacts
+ self._lineage_artifact_tracker.save()
+ finally:
+ if self._metrics_manager:
+ self._metrics_manager.close()
+
+ @staticmethod
+ def _generate_trial_name(base_name) -> str:
+        """Generate the reserved trial name based on the experiment name.
+
+ Args:
+ base_name (str): The ``experiment_name`` of this ``Run`` object.
+ """
+ available_length = MAX_NAME_LEN_IN_BACKEND - len(TRIAL_NAME_TEMPLATE)
+ return TRIAL_NAME_TEMPLATE.format(base_name[:available_length])
+
+ @staticmethod
+ def _is_input_valid(input_type, field_name, field_value) -> bool:
+        """Check whether the given input is valid.
+
+        Args:
+            input_type (str): The type of the input, one of ``parameter``, ``metric``.
+            field_name (str): The name of the field to be checked.
+            field_value (str or int or float): The value of the field to be checked.
+
+        Returns:
+            bool: Whether the input value is valid, i.e. not NaN or infinity.
+        """
+ if isinstance(field_value, Number) and (isnan(field_value) or isinf(field_value)):
+ logger.warning(
+ "Failed to log %s %s. Received invalid value: %s.",
+ input_type,
+ field_name,
+ field_value,
+ )
+ return False
+ return True
+
+ def _log_graph_artifact(self, data, graph_type, is_output, artifact_name=None):
+ """Log an artifact.
+
+ Logs an artifact by uploading data to S3, creating an artifact, and associating that
+ artifact with the run trial component.
+
+ Args:
+ data (dict): Artifacts data that will be saved to S3.
+ graph_type (str): The type of the artifact.
+ is_output (bool): Determines direction of association to the
+ trial component. Defaults to True (output artifact).
+ If set to False then represented as input association.
+ artifact_name (str): Name of the artifact (default: None).
+ """
+ # generate an artifact name
+ if not artifact_name:
+            artifact_name = unique_name_from_base(graph_type)
+
+ # create a json file in S3
+ s3_uri, etag = self._artifact_uploader.upload_object_artifact(
+ artifact_name, data, file_extension="json"
+ )
+
+ # create an artifact and association for the table
+ if is_output:
+ self._lineage_artifact_tracker.add_output_artifact(
+ name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type
+ )
+ else:
+ self._lineage_artifact_tracker.add_input_artifact(
+ name=artifact_name, source_uri=s3_uri, etag=etag, artifact_type=graph_type
+ )
+
+ def _verify_trial_component_artifacts_length(self, is_output):
+        """Verify that the number of trial component artifacts is within the limit.
+
+        Args:
+            is_output (bool): Determines direction of association to the
+                trial component.
+
+        Raises:
+            ValueError: If the number of trial component artifacts exceeds the limit.
+        """
+ err_msg_template = "Cannot add more than {} {}_artifacts under run"
+ if is_output:
+ if len(self._trial_component.output_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN:
+ raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "output"))
+ else:
+ if len(self._trial_component.input_artifacts) >= MAX_RUN_TC_ARTIFACTS_LEN:
+ raise ValueError(err_msg_template.format(MAX_RUN_TC_ARTIFACTS_LEN, "input"))
+
+ @staticmethod
+ def _generate_trial_component_name(run_name: str, experiment_name: str) -> str:
+ """Generate the TrialComponentName based on run_name and experiment_name
+
+ Args:
+ run_name (str): The run_name supplied by the user.
+ experiment_name (str): The experiment_name supplied by the user,
+ which is prepended to the run_name to generate the TrialComponentName.
+
+ Returns:
+ str: The TrialComponentName used to create a trial component
+ which is unique in an account.
+
+ Raises:
+ ValueError: If either the run_name or the experiment_name exceeds
+ the length limit.
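+
+        Example (illustrative)::
+
+            Run._generate_trial_component_name("my-run", "my-exp")  # -> "my-exp-my-run"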
+ """
+ buffer = 1 # leave length buffers for delimiters
+ max_len = int(MAX_NAME_LEN_IN_BACKEND / 2) - buffer
+ err_msg_template = "The {} (length: {}) must have length less than or equal to {}"
+ if len(run_name) > max_len:
+ raise ValueError(err_msg_template.format("run_name", len(run_name), max_len))
+ if len(experiment_name) > max_len:
+ raise ValueError(
+ err_msg_template.format("experiment_name", len(experiment_name), max_len)
+ )
+ trial_component_name = "{}{}{}".format(experiment_name, DELIMITER, run_name)
+ # due to mixed-case concerns on the backend
+ trial_component_name = trial_component_name.lower()
+ return trial_component_name
+
+ @staticmethod
+ def _extract_run_name_from_tc_name(trial_component_name: str, experiment_name: str) -> str:
+ """Extract the user supplied run name from a trial component name.
+
+ Args:
+ trial_component_name (str): The name of a run trial component.
+ experiment_name (str): The experiment_name supplied by the user,
+ which was prepended to the run_name to generate the trial_component_name.
+
+ Returns:
+ str: The name of the Run object supplied by a user.
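+
+        Example (illustrative)::
+
+            Run._extract_run_name_from_tc_name("my-exp-my-run", "my-exp")  # -> "my-run"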
+ """
+ return trial_component_name.replace("{}{}".format(experiment_name, DELIMITER), "", 1)
+
+ @staticmethod
+ def _append_run_tc_label_to_tags(tags: Optional[List[Dict[str, str]]] = None) -> list:
+ """Append the run trial component label to tags used to create a trial component.
+
+ Args:
+ tags (List[Dict[str, str]]): The tags supplied by users to initialize a Run object.
+
+ Returns:
+ list: The updated tags with the appended run trial component label.
+ """
+ if not tags:
+ tags = []
+ tags.append(RUN_TC_TAG)
+ return tags
+
+ def __enter__(self):
+ """Updates the start time of the run.
+
+ Returns:
+ object: self.
+ """
+ nested_with_err_msg_template = (
+ "It is not allowed to use nested 'with' statements on the {}."
+ )
+ if self._in_load:
+ if self._inside_load_context:
+ raise RuntimeError(nested_with_err_msg_template.format("load_run"))
+ self._inside_load_context = True
+ else:
+ if _RunContext.get_current_run():
+ raise RuntimeError(nested_with_err_msg_template.format("Run"))
+ self._inside_init_context = True
+ _RunContext.add_run_object(self)
+
+ if not self._trial_component.start_time:
+ start_time = datetime.datetime.now(dateutil.tz.tzlocal())
+ self._trial_component.start_time = start_time
+ self._trial_component.status = _api_types.TrialComponentStatus(
+ primary_status=_TrialComponentStatusType.InProgress.value,
+ message="Within a run context",
+ )
+ # Save the start_time and status changes to backend
+ self._trial_component.save()
+ return self
+
+ def __exit__(self, exc_type, exc_value, exc_traceback):
+ """Updates the end time of the run.
+
+ Args:
+ exc_type (str): The exception type.
+ exc_value (str): The exception value.
+ exc_traceback (str): The stack trace of the exception.
+ """
+ if self._in_load:
+ self._inside_load_context = False
+ self._in_load = False
+ else:
+ self._inside_init_context = False
+ _RunContext.drop_current_run()
+
+ end_time = datetime.datetime.now(dateutil.tz.tzlocal())
+ self._trial_component.end_time = end_time
+ if exc_value:
+ self._trial_component.status = _api_types.TrialComponentStatus(
+ primary_status=_TrialComponentStatusType.Failed.value, message=str(exc_value)
+ )
+ else:
+ self._trial_component.status = _api_types.TrialComponentStatus(
+ primary_status=_TrialComponentStatusType.Completed.value
+ )
+
+ self.close()
+
+
+def load_run(
+ run_name: Optional[str] = None,
+ experiment_name: Optional[str] = None,
+ sagemaker_session: Optional["Session"] = None,
+) -> Run:
+ """Load an existing run.
+
+ In order to reuse an existing run to log extra data, ``load_run`` is recommended.
+ It can be used in several ways:
+
+ 1. Use ``load_run`` by explicitly passing in ``run_name`` and ``experiment_name``.
+
+ If ``run_name`` and ``experiment_name`` are passed in, they are honored over
+ the default experiment config in the job environment or the run context
+ (i.e. within the ``with`` block).
+
+ Note:
+ Both ``run_name`` and ``experiment_name`` should be supplied to make this usage work.
+ Otherwise, you may get a ``ValueError``.
+
+ .. code:: python
+
+ with load_run(experiment_name="my-exp", run_name="my-run") as run:
+ run.log_metric(...)
+ ...
+
+ 2. Use the ``load_run`` in a job script without supplying ``run_name`` and ``experiment_name``.
+
+ In this case, the default experiment config (specified when creating the job) is fetched
+ from the job environment to load the run.
+
+ .. code:: python
+
+ # In a job script
+ with load_run() as run:
+ run.log_metric(...)
+ ...
+
+ 3. Use the ``load_run`` in a notebook within a run context (i.e. the ``with`` block)
+ but without supplying ``run_name`` and ``experiment_name``.
+
+    Every time we call ``with Run(...) as run1:``, the initialized ``run1`` is tracked
+    in the run context. When we then call ``load_run()`` within this ``with`` statement,
+    the ``run1`` in the context is loaded by default.
+
+ .. code:: python
+
+ # In a notebook
+ with Run(experiment_name="my-exp", run_name="my-run", ...) as run1:
+ run1.log_parameter(...)
+
+ with load_run() as run2: # run2 is the same object as run1
+ run2.log_metric(...)
+ ...
+
+ Args:
+ run_name (str): The name of the run to be loaded (default: None).
+ If it is None, the ``RunName`` in the ``ExperimentConfig`` of the job will be
+ fetched to load the run.
+ experiment_name (str): The name of the Experiment that the to be loaded run
+ is associated with (default: None).
+ Note: the experiment_name must be supplied along with a valid run_name.
+ Otherwise, it will be ignored.
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+ Returns:
+ Run: The loaded Run object.
+ """
+ sagemaker_session = sagemaker_session or _utils.default_session()
+ environment = _RunEnvironment.load()
+
+ verify_load_input_names(run_name=run_name, experiment_name=experiment_name)
+
+ if run_name or environment:
+ if run_name:
+ logger.warning(
+ "run_name is explicitly supplied in load_run, "
+ "which will be prioritized to load the Run object. "
+ "In other words, the run name in the experiment config, fetched from the "
+ "job environment or the current run context, will be ignored."
+ )
+ else:
+ exp_config = get_tc_and_exp_config_from_job_env(
+ environment=environment, sagemaker_session=sagemaker_session
+ )
+ run_name = Run._extract_run_name_from_tc_name(
+ trial_component_name=exp_config[RUN_NAME],
+ experiment_name=exp_config[EXPERIMENT_NAME],
+ )
+ experiment_name = exp_config[EXPERIMENT_NAME]
+
+ run_instance = Run(
+ experiment_name=experiment_name,
+ run_name=run_name,
+ sagemaker_session=sagemaker_session,
+ )
+ elif _RunContext.get_current_run():
+ run_instance = _RunContext.get_current_run()
+ else:
+ raise RuntimeError(
+ "Failed to load a Run object. "
+ "Please make sure a Run object has been initialized already."
+ )
+
+ run_instance._in_load = True
+ return run_instance
+
+
+def list_runs(
+ experiment_name: str,
+ created_before: Optional[datetime.datetime] = None,
+ created_after: Optional[datetime.datetime] = None,
+ sagemaker_session: Optional["Session"] = None,
+ max_results: Optional[int] = None,
+ next_token: Optional[str] = None,
+ sort_by: SortByType = SortByType.CREATION_TIME,
+ sort_order: SortOrderType = SortOrderType.DESCENDING,
+) -> list:
+ """Return a list of ``Run`` objects matching the given criteria.
+
+ Args:
+ experiment_name (str): Only Run objects related to the specified experiment
+ are returned.
+ created_before (datetime.datetime): Return Run objects created before this instant
+ (default: None).
+ created_after (datetime.datetime): Return Run objects created after this instant
+ (default: None).
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+ max_results (int): Maximum number of Run objects to retrieve (default: None).
+ next_token (str): Token for next page of results (default: None).
+ sort_by (SortByType): The property to sort results by. One of NAME, CREATION_TIME
+ (default: CREATION_TIME).
+ sort_order (SortOrderType): One of ASCENDING, or DESCENDING (default: DESCENDING).
+
+ Returns:
+ list: A list of ``Run`` objects.
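+
+    A minimal sketch (the experiment name is hypothetical):
+
+    .. code:: python
+
+        from sagemaker.experiments.run import list_runs
+
+        runs = list_runs(experiment_name="my-exp", max_results=10)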
+ """
+ tc_summaries = _TrialComponent.list(
+ experiment_name=experiment_name,
+ created_before=created_before,
+ created_after=created_after,
+ sort_by=sort_by.value,
+ sort_order=sort_order.value,
+ sagemaker_session=sagemaker_session,
+ max_results=max_results,
+ next_token=next_token,
+ )
+ run_list = []
+ for tc_summary in tc_summaries:
+ if not is_run_trial_component(
+ trial_component_name=tc_summary.trial_component_name,
+ sagemaker_session=sagemaker_session,
+ ):
+ continue
+ run_instance = Run(
+ experiment_name=experiment_name,
+ run_name=Run._extract_run_name_from_tc_name(
+ trial_component_name=tc_summary.trial_component_name,
+ experiment_name=experiment_name,
+ ),
+ sagemaker_session=sagemaker_session,
+ )
+ run_list.append(run_instance)
+ return run_list
diff --git a/src/sagemaker/experiments/trial.py b/src/sagemaker/experiments/trial.py
new file mode 100644
index 0000000000..146b24f18b
--- /dev/null
+++ b/src/sagemaker/experiments/trial.py
@@ -0,0 +1,289 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains the Trial class."""
+from __future__ import absolute_import
+
+from sagemaker.apiutils import _base_types
+from sagemaker.experiments import _api_types
+from sagemaker.experiments.trial_component import _TrialComponent
+
+
+class _Trial(_base_types.Record):
+    """An execution of a data-science workflow within an experiment.
+
+ Consists of a list of trial component objects, which document individual
+ activities within the workflow.
+
+ Attributes:
+ trial_name (str): The name of the trial.
+ experiment_name (str): The name of the trial's experiment.
+ display_name (str): The name of the trial that will appear in UI,
+ such as SageMaker Studio.
+ tags (List[Dict[str, str]]): A list of tags to associate with the trial.
+ """
+
+ trial_name = None
+ experiment_name = None
+ display_name = None
+ tags = None
+
+ _boto_create_method = "create_trial"
+ _boto_load_method = "describe_trial"
+ _boto_delete_method = "delete_trial"
+ _boto_update_method = "update_trial"
+
+ _boto_update_members = ["trial_name", "display_name"]
+ _boto_delete_members = ["trial_name"]
+
+ @classmethod
+ def _boto_ignore(cls):
+ """Response fields to ignore by default."""
+ return super(_Trial, cls)._boto_ignore() + ["CreatedBy"]
+
+ def save(self):
+ """Save the state of this Trial to SageMaker.
+
+ Returns:
+ dict: Update trial response.
+ """
+ return self._invoke_api(self._boto_update_method, self._boto_update_members)
+
+ def delete(self):
+ """Delete this Trial from SageMaker.
+
+ Does not delete associated Trial Components.
+
+ Returns:
+ dict: Delete trial response.
+ """
+ return self._invoke_api(self._boto_delete_method, self._boto_delete_members)
+
+ @classmethod
+ def load(cls, trial_name, sagemaker_session=None):
+ """Load an existing trial and return a `_Trial` object.
+
+ Args:
+ trial_name: (str): Name of the Trial.
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+ Returns:
+ experiments.trial._Trial: A SageMaker `_Trial` object
+ """
+ return super(_Trial, cls)._construct(
+ cls._boto_load_method,
+ trial_name=trial_name,
+ sagemaker_session=sagemaker_session,
+ )
+
+ @classmethod
+ def create(
+ cls, experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None
+ ):
+ """Create a new trial and return a `_Trial` object.
+
+ Args:
+ experiment_name: (str): Name of the experiment to create this trial in.
+ trial_name: (str): Name of the Trial.
+ display_name (str): Name of the trial that will appear in UI,
+ such as SageMaker Studio (default: None).
+ tags (List[dict]): A list of tags to associate with the trial (default: None).
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+ Returns:
+ experiments.trial._Trial: A SageMaker `_Trial` object
+ """
+ trial = super(_Trial, cls)._construct(
+ cls._boto_create_method,
+ trial_name=trial_name,
+ experiment_name=experiment_name,
+ display_name=display_name,
+ tags=tags,
+ sagemaker_session=sagemaker_session,
+ )
+ return trial
+
+ @classmethod
+ def list(
+ cls,
+ experiment_name=None,
+ trial_component_name=None,
+ created_before=None,
+ created_after=None,
+ sort_by=None,
+ sort_order=None,
+ sagemaker_session=None,
+ ):
+ """List all trials matching the specified criteria.
+
+ Args:
+ experiment_name (str): Name of the experiment. If specified, only trials in
+ the experiment will be returned (default: None).
+ trial_component_name (str): Name of the trial component. If specified, only
+ trials with this trial component name will be returned (default: None).
+ created_before (datetime.datetime): Return trials created before this instant
+ (default: None).
+ created_after (datetime.datetime): Return trials created after this instant
+ (default: None).
+ sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime'
+ (default: None).
+ sort_order (str): One of 'Ascending', or 'Descending' (default: None).
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+        Returns:
+ collections.Iterator[experiments._api_types.TrialSummary]: An iterator over trials
+ matching the specified criteria.
+ """
+ return super(_Trial, cls)._list(
+ "list_trials",
+ _api_types.TrialSummary.from_boto,
+ "TrialSummaries",
+ experiment_name=experiment_name,
+ trial_component_name=trial_component_name,
+ created_before=created_before,
+ created_after=created_after,
+ sort_by=sort_by,
+ sort_order=sort_order,
+ sagemaker_session=sagemaker_session,
+ )
+
+ def add_trial_component(self, trial_component):
+ """Add the specified trial component to this trial.
+
+ A trial component may belong to many trials and a trial may have many trial components.
+
+ Args:
+ trial_component (str or _TrialComponent): The trial component to add.
+                Can be a _TrialComponent instance or a string containing
+                the name of the trial component to add.
+ """
+ if isinstance(trial_component, _TrialComponent):
+ trial_component_name = trial_component.trial_component_name
+ elif isinstance(trial_component, str):
+ trial_component_name = trial_component
+ else:
+ raise TypeError(
+                "Unsupported type of trial component {}. "
+                "It has to be either a _TrialComponent instance or a str".format(trial_component)
+ )
+ self.sagemaker_session.sagemaker_client.associate_trial_component(
+ TrialName=self.trial_name, TrialComponentName=trial_component_name
+ )
+
+ def remove_trial_component(self, trial_component):
+ """Remove the specified trial component from this trial.
+
+ Args:
+            trial_component (str or _TrialComponent): The trial component to remove.
+                Can be a _TrialComponent instance or a string containing
+                the name of the trial component to remove.
+ """
+ if isinstance(trial_component, _TrialComponent):
+ trial_component_name = trial_component.trial_component_name
+ elif isinstance(trial_component, str):
+ trial_component_name = trial_component
+ else:
+ raise TypeError(
+                "Unsupported type of trial component {}. "
+                "It has to be either a _TrialComponent instance or a str".format(trial_component)
+ )
+ self.sagemaker_session.sagemaker_client.disassociate_trial_component(
+ TrialName=self.trial_name, TrialComponentName=trial_component_name
+ )
+
+ def list_trial_components(
+ self,
+ created_before=None,
+ created_after=None,
+ sort_by=None,
+ sort_order=None,
+ max_results=None,
+ next_token=None,
+ ):
+ """List trial components in this trial matching the specified criteria.
+
+ Args:
+ created_before (datetime.datetime): Return trials created before this instant
+ (default: None).
+ created_after (datetime.datetime): Return trials created after this instant
+ (default: None).
+ sort_by (str): Which property to sort results by. One of 'Name',
+ 'CreationTime' (default: None).
+ sort_order (str): One of 'Ascending', or 'Descending' (default: None).
+ max_results (int): maximum number of trial components to retrieve (default: None).
+ next_token (str): token for next page of results (default: None).
+
+ Returns:
+            collections.Iterator[experiments._api_types.TrialComponentSummary]: An iterator over
+                trial components matching the criteria.
+ """
+ return _TrialComponent.list(
+ trial_name=self.trial_name,
+ created_before=created_before,
+ created_after=created_after,
+ sort_by=sort_by,
+ sort_order=sort_order,
+ max_results=max_results,
+ next_token=next_token,
+ sagemaker_session=self.sagemaker_session,
+ )
+
+ @classmethod
+ def _load_or_create(
+ cls, experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None
+ ):
+ """Load a trial by name and create a new one if it does not exist.
+
+ Args:
+ experiment_name: (str): Name of the experiment to create this trial in.
+ trial_name: (str): Name of the Trial.
+ display_name (str): Name of the trial that will appear in UI,
+ such as SageMaker Studio (default: None). This is used only when the given
+ `trial_name` does not exist and a new trial has to be created.
+ tags (List[dict]): A list of tags to associate with the trial (default: None).
+ This is used only when the given `trial_name` does not exist and
+ a new trial has to be created.
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+ Returns:
+ experiments.trial._Trial: A SageMaker `_Trial` object
+ """
+ sagemaker_client = sagemaker_session.sagemaker_client
+ try:
+ trial = _Trial.load(trial_name, sagemaker_session)
+ if trial.experiment_name != experiment_name: # pylint: disable=no-member
+ raise ValueError(
+ "The given experiment_name {} ".format(experiment_name)
+ + "does not match that in the loaded trial {}".format(
+ trial.experiment_name # pylint: disable=no-member
+ )
+ )
+ except sagemaker_client.exceptions.ResourceNotFound:
+ trial = _Trial.create(
+ experiment_name=experiment_name,
+ trial_name=trial_name,
+ display_name=display_name,
+ tags=tags,
+ sagemaker_session=sagemaker_session,
+ )
+ return trial
diff --git a/src/sagemaker/experiments/trial_component.py b/src/sagemaker/experiments/trial_component.py
new file mode 100644
index 0000000000..e5701b2119
--- /dev/null
+++ b/src/sagemaker/experiments/trial_component.py
@@ -0,0 +1,341 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Contains the TrialComponent class."""
+from __future__ import absolute_import
+
+import time
+
+from sagemaker.apiutils import _base_types
+from sagemaker.experiments import _api_types
+from sagemaker.experiments._api_types import TrialComponentSearchResult
+
+
+class _TrialComponent(_base_types.Record):
+ """This class represents a SageMaker trial component object.
+
+ A trial component is a stage in a trial.
+ Trial components are created automatically within the SageMaker runtime and
+ may not be created directly. To automatically associate trial components with
+ a trial and experiment, supply an experiment config when creating a job.
+ For example: https://docs.aws.amazon.com/sagemaker/latest/dg/API_CreateTrainingJob.html
+
+ Attributes:
+ trial_component_name (str): The name of the trial component. Generated by SageMaker
+ from the name of the source job with a suffix specific to the type of source job.
+ trial_component_arn (str): The ARN of the trial component.
+ display_name (str): The name of the trial component that will appear in UI,
+ such as SageMaker Studio.
+ source (TrialComponentSource): A TrialComponentSource object with a source_arn attribute.
+ status (str): Status of the source job.
+ start_time (datetime): When the source job started.
+ end_time (datetime): When the source job ended.
+ creation_time (datetime): When the source job was created.
+ created_by (obj): Contextual info on which account created the trial component.
+ last_modified_time (datetime): When the trial component was last modified.
+ last_modified_by (obj): Contextual info on which account last modified the trial component.
+ parameters (dict): Dictionary of parameters to the source job.
+ input_artifacts (dict): Dictionary of input artifacts.
+ output_artifacts (dict): Dictionary of output artifacts.
+ metrics (obj): Aggregated metrics for the job.
+ parameters_to_remove (list): The hyperparameters to remove from the component.
+ input_artifacts_to_remove (list): The input artifacts to remove from the component.
+ output_artifacts_to_remove (list): The output artifacts to remove from the component.
+ tags (List[Dict[str, str]]): A list of tags to associate with the trial component.
+ """
+
+ trial_component_name = None
+ trial_component_arn = None
+ display_name = None
+ source = None
+ status = None
+ start_time = None
+ end_time = None
+ creation_time = None
+ created_by = None
+ last_modified_time = None
+ last_modified_by = None
+ parameters = None
+ input_artifacts = None
+ output_artifacts = None
+ metrics = None
+ parameters_to_remove = None
+ input_artifacts_to_remove = None
+ output_artifacts_to_remove = None
+ tags = None
+
+ _boto_load_method = "describe_trial_component"
+ _boto_create_method = "create_trial_component"
+ _boto_update_method = "update_trial_component"
+ _boto_delete_method = "delete_trial_component"
+
+ _custom_boto_types = {
+ "source": (_api_types.TrialComponentSource, False),
+ "status": (_api_types.TrialComponentStatus, False),
+ "parameters": (_api_types.TrialComponentParameters, False),
+ "input_artifacts": (_api_types.TrialComponentArtifact, True),
+ "output_artifacts": (_api_types.TrialComponentArtifact, True),
+ "metrics": (_api_types.TrialComponentMetricSummary, True),
+ }
+
+ _boto_update_members = [
+ "trial_component_name",
+ "display_name",
+ "status",
+ "start_time",
+ "end_time",
+ "parameters",
+ "input_artifacts",
+ "output_artifacts",
+ "parameters_to_remove",
+ "input_artifacts_to_remove",
+ "output_artifacts_to_remove",
+ ]
+ _boto_delete_members = ["trial_component_name"]
+
+ def __init__(self, sagemaker_session=None, **kwargs):
+ """Init for _TrialComponent"""
+ super().__init__(sagemaker_session, **kwargs)
+ self.parameters = self.parameters or {}
+ self.input_artifacts = self.input_artifacts or {}
+ self.output_artifacts = self.output_artifacts or {}
+
+ @classmethod
+ def _boto_ignore(cls):
+ """Response fields to ignore by default."""
+ return super(_TrialComponent, cls)._boto_ignore() + ["CreatedBy"]
+
+ def save(self):
+ """Save the state of this TrialComponent to SageMaker."""
+ return self._invoke_api(self._boto_update_method, self._boto_update_members)
+
+ def delete(self, force_disassociate=False):
+ """Delete this TrialComponent from SageMaker.
+
+ Args:
+ force_disassociate (boolean): Indicates whether to force disassociate the
+ trial component with the trials before deletion (default: False).
+ If set to true, force disassociate the trial component with associated trials
+ first, then delete the trial component.
+ If it's not set or set to false, the trial component is deleted directly,
+ without disassociation.
+
+ Returns:
+ dict: Delete trial component response.
+ """
+ if force_disassociate:
+ next_token = None
+
+ while True:
+ if next_token:
+ list_trials_response = self.sagemaker_session.sagemaker_client.list_trials(
+ TrialComponentName=self.trial_component_name, NextToken=next_token
+ )
+ else:
+ list_trials_response = self.sagemaker_session.sagemaker_client.list_trials(
+ TrialComponentName=self.trial_component_name
+ )
+
+ # Disassociate the trials and trial components
+ for per_trial in list_trials_response["TrialSummaries"]:
+ # to prevent DisassociateTrialComponent throttling
+ time.sleep(1.2)
+ self.sagemaker_session.sagemaker_client.disassociate_trial_component(
+ TrialName=per_trial["TrialName"],
+ TrialComponentName=self.trial_component_name,
+ )
+
+ if "NextToken" in list_trials_response:
+ next_token = list_trials_response["NextToken"]
+ else:
+ break
+
+ return self._invoke_api(self._boto_delete_method, self._boto_delete_members)
+
+ @classmethod
+ def load(cls, trial_component_name, sagemaker_session=None):
+ """Load an existing trial component and return an `_TrialComponent` object representing it.
+
+ Args:
+ trial_component_name (str): Name of the trial component
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+ Returns:
+ experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object
+ """
+ trial_component = cls._construct(
+ cls._boto_load_method,
+ trial_component_name=trial_component_name,
+ sagemaker_session=sagemaker_session,
+ )
+ return trial_component
+
+ @classmethod
+ def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None):
+ """Create a trial component and return a `_TrialComponent` object representing it.
+
+ Args:
+ trial_component_name (str): The name of the trial component.
+ display_name (str): Display name of the trial component used by Studio (default: None).
+ tags (List[Dict[str, str]]): Tags to add to the trial component (default: None).
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+ Returns:
+ experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
+ """
+ return super(_TrialComponent, cls)._construct(
+ cls._boto_create_method,
+ trial_component_name=trial_component_name,
+ display_name=display_name,
+ tags=tags,
+ sagemaker_session=sagemaker_session,
+ )
+
+ @classmethod
+ def list(
+ cls,
+ source_arn=None,
+ created_before=None,
+ created_after=None,
+ sort_by=None,
+ sort_order=None,
+ sagemaker_session=None,
+ trial_name=None,
+ experiment_name=None,
+ max_results=None,
+ next_token=None,
+ ):
+ """Return a list of trial component summaries.
+
+ Args:
+ source_arn (str): A SageMaker Training or Processing Job ARN (default: None).
+ created_before (datetime.datetime): Return trial components created before this instant
+ (default: None).
+ created_after (datetime.datetime): Return trial components created after this instant
+ (default: None).
+ sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime'
+ (default: None).
+ sort_order (str): One of 'Ascending', or 'Descending' (default: None).
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+ trial_name (str): If provided only trial components related to the trial are returned
+ (default: None).
+ experiment_name (str): If provided only trial components related to the experiment are
+ returned (default: None).
+ max_results (int): Maximum number of trial components to retrieve (default: None).
+ next_token (str): Token for the next page of results (default: None).
+ Returns:
+ collections.Iterator[experiments._api_types.TrialComponentSummary]: An iterator
+ over `TrialComponentSummary` objects.
+ """
+ return super(_TrialComponent, cls)._list(
+ "list_trial_components",
+ _api_types.TrialComponentSummary.from_boto,
+ "TrialComponentSummaries",
+ source_arn=source_arn,
+ created_before=created_before,
+ created_after=created_after,
+ sort_by=sort_by,
+ sort_order=sort_order,
+ sagemaker_session=sagemaker_session,
+ trial_name=trial_name,
+ experiment_name=experiment_name,
+ max_results=max_results,
+ next_token=next_token,
+ )
+
+ @classmethod
+ def search(
+ cls,
+ search_expression=None,
+ sort_by=None,
+ sort_order=None,
+ max_results=None,
+ sagemaker_session=None,
+ ):
+ """Search Experiment Trail Component.
+
+ Returns SearchResults in the account matching the search criteria.
+
+ Args:
+ search_expression (SearchExpression): A Boolean conditional statement (default: None).
+ Resource objects must satisfy this condition to be included in search results.
+ You must provide at least one subexpression, filter, or nested filter.
+ sort_by (str): The name of the resource property used to sort the SearchResults
+ (default: None).
+ sort_order (str): How SearchResults are ordered. Valid values are Ascending or
+ Descending (default: None).
+ max_results (int): The maximum number of results to return in a SearchResponse
+ (default: None).
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+ Returns:
+ collections.Iterator[SearchResult] : An iterator over search results matching the
+ search criteria.
+ """
+ return super(_TrialComponent, cls)._search(
+ search_resource="ExperimentTrialComponent",
+ search_item_factory=TrialComponentSearchResult.from_boto,
+ search_expression=None if search_expression is None else search_expression.to_boto(),
+ sort_by=sort_by,
+ sort_order=sort_order,
+ max_results=max_results,
+ sagemaker_session=sagemaker_session,
+ )
+
+ @classmethod
+ def _load_or_create(
+ cls, trial_component_name, display_name=None, tags=None, sagemaker_session=None
+ ):
+ """Load a trial component by name and create a new one if it does not exist.
+
+ Args:
+ trial_component_name (str): The name of the trial component.
+ display_name (str): Display name of the trial component used by Studio (default: None).
+ This is used only when the given `trial_component_name` does not
+ exist and a new trial component has to be created.
+ tags (List[Dict[str, str]]): Tags to add to the trial component (default: None).
+ This is used only when the given `trial_component_name` does not
+ exist and a new trial component has to be created.
+ sagemaker_session (sagemaker.session.Session): Session object which
+ manages interactions with Amazon SageMaker APIs and any other
+ AWS services needed. If not specified, one is created using the
+ default AWS configuration chain.
+
+ Returns:
+ experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
+ bool: A boolean variable indicating whether the trial component already exists.
+ """
+ sagemaker_client = sagemaker_session.sagemaker_client
+ is_existed = False
+ try:
+ run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
+ is_existed = True
+ except sagemaker_client.exceptions.ResourceNotFound:
+ run_tc = _TrialComponent.create(
+ trial_component_name=trial_component_name,
+ display_name=display_name,
+ tags=tags,
+ sagemaker_session=sagemaker_session,
+ )
+ return run_tc, is_existed
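+
+
+# A minimal usage sketch of the class above (names are illustrative;
+# `_TrialComponent` is a private helper and its API may change):
+#
+# from sagemaker.session import Session
+#
+# session = Session()
+# tc, existed = _TrialComponent._load_or_create(
+# "my-run-tc", display_name="My Run", sagemaker_session=session
+# )
+# tc.parameters.update({"learning_rate": 0.01})
+# tc.save()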
diff --git a/src/sagemaker/feature_store/dataset_builder.py b/src/sagemaker/feature_store/dataset_builder.py
new file mode 100644
index 0000000000..fc82997379
--- /dev/null
+++ b/src/sagemaker/feature_store/dataset_builder.py
@@ -0,0 +1,990 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Dataset Builder
+
+A Dataset Builder is a builder class for generating a dataset by providing conditions.
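+
+A typical flow looks roughly like the following sketch (``customers_fg``,
+``orders_fg`` and the S3 bucket are hypothetical, and both feature groups are
+assumed to have offline stores configured)::
+
+ from sagemaker import Session
+ from sagemaker.feature_store.feature_store import FeatureStore
+
+ feature_store = FeatureStore(sagemaker_session=Session())
+ builder = feature_store.create_dataset(
+ base=customers_fg, output_path="s3://my-bucket/dataset"
+ ).with_feature_group(orders_fg, target_feature_name_in_base="customer_id")
+ df, query = builder.to_dataframe()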
+"""
+from __future__ import absolute_import
+
+import datetime
+from enum import Enum
+import os
+from typing import Any, Dict, List, Tuple, Union
+
+import attr
+import pandas as pd
+
+from sagemaker import Session, s3, utils
+from sagemaker.feature_store.feature_group import FeatureDefinition, FeatureGroup, FeatureTypeEnum
+
+
+_DEFAULT_CATALOG = "AwsDataCatalog"
+_DEFAULT_DATABASE = "sagemaker_featurestore"
+
+
+class TableType(Enum):
+ """Enum of Table types.
+
+ The data type of a table can be FeatureGroup or DataFrame.
+ """
+
+ FEATURE_GROUP = "FeatureGroup"
+ DATA_FRAME = "DataFrame"
+
+
+@attr.s
+class FeatureGroupToBeMerged:
+ """FeatureGroup metadata which will be used for SQL join.
+
+ This class instantiates a FeatureGroupToBeMerged object that comprises a list of feature
+ names, a list of feature names to be included in the SQL query, a database, an Athena
+ table name, the record identifier feature name, the event time identifier feature name,
+ and the feature name in the base table that serves as the target join key.
+
+ Attributes:
+ features (List[str]): A list of strings representing feature names of this FeatureGroup.
+ included_feature_names (List[str]): A list of strings representing features to be
+ included in the sql join.
+ projected_feature_names (List[str]): A list of strings representing features to be
+ included for final projection in output.
+ catalog (str): A string representing the catalog.
+ database (str): A string representing the database.
+ table_name (str): A string representing the Athena table name of this FeatureGroup.
+ record_identifier_feature_name (str): A string representing the record identifier feature.
+ event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the
+ event time identifier feature.
+ target_feature_name_in_base (str): A string representing the feature name in base which will
+ be used as target join key (default: None).
+ table_type (TableType): A TableType indicating whether the table is backed by a
+ FeatureGroup or a pandas DataFrame (default: None).
+ """
+
+ features: List[str] = attr.ib()
+ included_feature_names: List[str] = attr.ib()
+ projected_feature_names: List[str] = attr.ib()
+ catalog: str = attr.ib()
+ database: str = attr.ib()
+ table_name: str = attr.ib()
+ record_identifier_feature_name: str = attr.ib()
+ event_time_identifier_feature: FeatureDefinition = attr.ib()
+ target_feature_name_in_base: str = attr.ib(default=None)
+ table_type: TableType = attr.ib(default=None)
+
+
+def construct_feature_group_to_be_merged(
+ feature_group: FeatureGroup,
+ included_feature_names: List[str],
+ target_feature_name_in_base: str = None,
+) -> FeatureGroupToBeMerged:
+ """Construct a FeatureGroupToBeMerged object by provided parameters.
+
+ Args:
+ feature_group (FeatureGroup): A FeatureGroup object.
+ included_feature_names (List[str]): A list of strings representing features to be
+ included in the output.
+ target_feature_name_in_base (str): A string representing the feature name in base which
+ will be used as target join key (default: None).
+ Returns:
+ A FeatureGroupToBeMerged object.
+
+ Raises:
+ RuntimeError: No metastore is configured with the FeatureGroup.
+ ValueError: Invalid feature name(s) in included_feature_names.
+ """
+ feature_group_metadata = feature_group.describe()
+ data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", {}).get(
+ "DataCatalogConfig", None
+ )
+ if not data_catalog_config:
+ raise RuntimeError(f"No metastore is configured with FeatureGroup {feature_group.name}.")
+
+ record_identifier_feature_name = feature_group_metadata.get("RecordIdentifierFeatureName", None)
+ feature_definitions = feature_group_metadata.get("FeatureDefinitions", [])
+ event_time_identifier_feature_name = feature_group_metadata.get("EventTimeFeatureName", None)
+ event_time_identifier_feature_type = FeatureTypeEnum(
+ next(
+ filter(
+ lambda f: f.get("FeatureName", None) == event_time_identifier_feature_name,
+ feature_definitions,
+ ),
+ {},
+ ).get("FeatureType", None)
+ )
+ table_name = data_catalog_config.get("TableName", None)
+ database = data_catalog_config.get("Database", None)
+ disable_glue = feature_group_metadata.get("DisableGlueTableCreation", False)
+ catalog = data_catalog_config.get("Catalog", None) if disable_glue else _DEFAULT_CATALOG
+ features = [feature.get("FeatureName", None) for feature in feature_definitions]
+
+ for included_feature in included_feature_names or []:
+ if included_feature not in features:
+ raise ValueError(
+ f"Feature {included_feature} not found in FeatureGroup {feature_group.name}"
+ )
+ if not included_feature_names:
+ included_feature_names = features
+ projected_feature_names = features.copy()
+ else:
+ projected_feature_names = included_feature_names.copy()
+ if record_identifier_feature_name not in included_feature_names:
+ included_feature_names.append(record_identifier_feature_name)
+ if event_time_identifier_feature_name not in included_feature_names:
+ included_feature_names.append(event_time_identifier_feature_name)
+ return FeatureGroupToBeMerged(
+ features,
+ included_feature_names,
+ projected_feature_names,
+ catalog,
+ database,
+ table_name,
+ record_identifier_feature_name,
+ FeatureDefinition(event_time_identifier_feature_name, event_time_identifier_feature_type),
+ target_feature_name_in_base,
+ TableType.FEATURE_GROUP,
+ )
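+
+
+# A sketch of how the helper above is typically used (`orders_fg` is a
+# hypothetical FeatureGroup whose offline store is configured):
+#
+# merged = construct_feature_group_to_be_merged(
+# orders_fg,
+# included_feature_names=["order_id", "amount"],
+# target_feature_name_in_base="customer_id",
+# )
+# # The record identifier and event time features are appended to
+# # merged.included_feature_names automatically when missing.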
+
+
+@attr.s
+class DatasetBuilder:
+ """DatasetBuilder definition.
+
+ This class instantiates a DatasetBuilder object that comprises a base, a list of feature names,
+ an output path and a KMS key ID.
+
+ Attributes:
+ _sagemaker_session (Session): Session instance to perform boto calls.
+ _base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
+ pandas.DataFrame and will be used to merge other FeatureGroups and generate a Dataset.
+ _output_path (str): An S3 URI which stores the output .csv file.
+ _record_identifier_feature_name (str): A string representing the record identifier feature
+ if base is a DataFrame (default: None).
+ _event_time_identifier_feature_name (str): A string representing the event time identifier
+ feature if base is a DataFrame (default: None).
+ _included_feature_names (List[str]): A list of strings representing features to be
+ included in the output (default: None).
+ _kms_key_id (str): A KMS key ID. If set, it will be used to encrypt the result file
+ (default: None).
+ _point_in_time_accurate_join (bool): Whether to perform a point-in-time accurate join
+ (default: False).
+ _include_duplicated_records (bool): Whether to include duplicated records in the
+ dataset (default: False).
+ _include_deleted_records (bool): Whether to include deleted records in the dataset
+ (default: False).
+ _number_of_recent_records (int): How many recent records to return for each record
+ identifier (default: None).
+ _number_of_records (int): How many records to return in total (default: None).
+ _write_time_ending_timestamp (datetime.datetime): A datetime; all records in the
+ dataset will have a write time before it (default: None).
+ _event_time_starting_timestamp (datetime.datetime): A datetime; all records in the
+ dataset will have an event time after it (default: None).
+ _event_time_ending_timestamp (datetime.datetime): A datetime; all records in the
+ dataset will have an event time before it (default: None).
+ _feature_groups_to_be_merged (List[FeatureGroupToBeMerged]): A list of
+ FeatureGroupToBeMerged which will be joined to base (default: []).
+ _event_time_identifier_feature_type (FeatureTypeEnum): A FeatureTypeEnum representing the
+ type of event time identifier feature (default: None).
+ """
+
+ _sagemaker_session: Session = attr.ib()
+ _base: Union[FeatureGroup, pd.DataFrame] = attr.ib()
+ _output_path: str = attr.ib()
+ _record_identifier_feature_name: str = attr.ib(default=None)
+ _event_time_identifier_feature_name: str = attr.ib(default=None)
+ _included_feature_names: List[str] = attr.ib(default=None)
+ _kms_key_id: str = attr.ib(default=None)
+
+ _point_in_time_accurate_join: bool = attr.ib(init=False, default=False)
+ _include_duplicated_records: bool = attr.ib(init=False, default=False)
+ _include_deleted_records: bool = attr.ib(init=False, default=False)
+ _number_of_recent_records: int = attr.ib(init=False, default=None)
+ _number_of_records: int = attr.ib(init=False, default=None)
+ _write_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None)
+ _event_time_starting_timestamp: datetime.datetime = attr.ib(init=False, default=None)
+ _event_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None)
+ _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = attr.ib(init=False, factory=list)
+ _event_time_identifier_feature_type: FeatureTypeEnum = attr.ib(default=None)
+
+ _DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP = {
+ "object": "STRING",
+ "int64": "INT",
+ "float64": "DOUBLE",
+ "bool": "BOOLEAN",
+ "datetime64[ns]": "TIMESTAMP",
+ }
+
+ def with_feature_group(
+ self,
+ feature_group: FeatureGroup,
+ target_feature_name_in_base: str = None,
+ included_feature_names: List[str] = None,
+ ):
+ """Join FeatureGroup with base.
+
+ Args:
+ feature_group (FeatureGroup): A FeatureGroup which will be joined to base.
+ target_feature_name_in_base (str): A string representing the feature name in base which
+ will be used as target join key (default: None).
+ included_feature_names (List[str]): A list of strings representing features to be
+ included in the output (default: None).
+ Returns:
+ This DatasetBuilder object.
+ """
+ self._feature_groups_to_be_merged.append(
+ construct_feature_group_to_be_merged(
+ feature_group, included_feature_names, target_feature_name_in_base
+ )
+ )
+ return self
+
+ def point_in_time_accurate_join(self):
+ """Set join type as point in time accurate join.
+
+ Returns:
+ This DatasetBuilder object.
+ """
+ self._point_in_time_accurate_join = True
+ return self
+
+ def include_duplicated_records(self):
+ """Include duplicated records in dataset.
+
+ Returns:
+ This DatasetBuilder object.
+ """
+ self._include_duplicated_records = True
+ return self
+
+ def include_deleted_records(self):
+ """Include deleted records in dataset.
+
+ Returns:
+ This DatasetBuilder object.
+ """
+ self._include_deleted_records = True
+ return self
+
+ def with_number_of_recent_records_by_record_identifier(self, number_of_recent_records: int):
+ """Set number_of_recent_records field with provided input.
+
+ Args:
+ number_of_recent_records (int): The number of recent records to return for
+ each record identifier.
+ Returns:
+ This DatasetBuilder object.
+ """
+ self._number_of_recent_records = number_of_recent_records
+ return self
+
+ def with_number_of_records_from_query_results(self, number_of_records: int):
+ """Set number_of_records field with provided input.
+
+ Args:
+ number_of_records (int): The number of records to return.
+ Returns:
+ This DatasetBuilder object.
+ """
+ self._number_of_records = number_of_records
+ return self
+
+ def as_of(self, timestamp: datetime.datetime):
+ """Set write_time_ending_timestamp field with provided input.
+
+ Args:
+ timestamp (datetime.datetime): A datetime; all records in the dataset will have
+ a write time before it.
+ Returns:
+ This DatasetBuilder object.
+ """
+ self._write_time_ending_timestamp = timestamp
+ return self
+
+ def with_event_time_range(
+ self,
+ starting_timestamp: datetime.datetime = None,
+ ending_timestamp: datetime.datetime = None,
+ ):
+ """Set event_time_starting_timestamp and event_time_ending_timestamp with provided inputs.
+
+ Args:
+ starting_timestamp (datetime.datetime): A datetime; all records in the dataset
+ will have an event time after it (default: None).
+ ending_timestamp (datetime.datetime): A datetime; all records in the dataset
+ will have an event time before it (default: None).
+ Returns:
+ This DatasetBuilder object.
+ """
+ self._event_time_starting_timestamp = starting_timestamp
+ self._event_time_ending_timestamp = ending_timestamp
+ return self
+
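+ # A sketch of chaining the condition setters above (timestamps are
+ # illustrative):
+ #
+ # import datetime
+ #
+ # builder = (
+ # builder.point_in_time_accurate_join()
+ # .with_number_of_recent_records_by_record_identifier(2)
+ # .with_event_time_range(
+ # starting_timestamp=datetime.datetime(2022, 1, 1),
+ # ending_timestamp=datetime.datetime(2022, 12, 1),
+ # )
+ # )
+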
+ def to_csv_file(self) -> Tuple[str, str]:
+ """Get query string and result in .csv format file
+
+ Returns:
+ The S3 path of the .csv file.
+ The query string executed.
+ """
+ if isinstance(self._base, pd.DataFrame):
+ temp_id = utils.unique_name_from_base("dataframe-base")
+ local_file_name = f"{temp_id}.csv"
+ desired_s3_folder = f"{self._output_path}/{temp_id}"
+ self._base.to_csv(local_file_name, index=False, header=False)
+ s3.S3Uploader.upload(
+ local_path=local_file_name,
+ desired_s3_uri=desired_s3_folder,
+ sagemaker_session=self._sagemaker_session,
+ kms_key=self._kms_key_id,
+ )
+ os.remove(local_file_name)
+ temp_table_name = f'dataframe_{temp_id.replace("-", "_")}'
+ self._create_temp_table(temp_table_name, desired_s3_folder)
+ base_features = list(self._base.columns)
+ event_time_identifier_feature_dtype = self._base[
+ self._event_time_identifier_feature_name
+ ].dtypes
+ self._event_time_identifier_feature_type = (
+ FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
+ str(event_time_identifier_feature_dtype), None
+ )
+ )
+ query_string = self._construct_query_string(
+ FeatureGroupToBeMerged(
+ base_features,
+ self._included_feature_names if self._included_feature_names else base_features,
+ self._included_feature_names if self._included_feature_names else base_features,
+ _DEFAULT_CATALOG,
+ _DEFAULT_DATABASE,
+ temp_table_name,
+ self._record_identifier_feature_name,
+ FeatureDefinition(
+ self._event_time_identifier_feature_name,
+ self._event_time_identifier_feature_type,
+ ),
+ None,
+ TableType.DATA_FRAME,
+ )
+ )
+ query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
+ # TODO: cleanup temp table, need more clarification, keep it for now
+ return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
+ "OutputLocation", None
+ ), query_result.get("QueryExecution", {}).get("Query", None)
+ if isinstance(self._base, FeatureGroup):
+ base_feature_group = construct_feature_group_to_be_merged(
+ self._base, self._included_feature_names
+ )
+ self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name
+ self._event_time_identifier_feature_name = (
+ base_feature_group.event_time_identifier_feature.feature_name
+ )
+ self._event_time_identifier_feature_type = (
+ base_feature_group.event_time_identifier_feature.feature_type
+ )
+ query_string = self._construct_query_string(base_feature_group)
+ query_result = self._run_query(
+ query_string,
+ base_feature_group.catalog,
+ base_feature_group.database,
+ )
+ return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get(
+ "OutputLocation", None
+ ), query_result.get("QueryExecution", {}).get("Query", None)
+ raise ValueError("Base must be either a FeatureGroup or a DataFrame.")
+
+ def to_dataframe(self) -> Tuple[pd.DataFrame, str]:
+ """Get query string and result in pandas.Dataframe
+
+ Returns:
+ The pandas.DataFrame object.
+ The query string executed.
+ """
+ csv_file, query_string = self.to_csv_file()
+ s3.S3Downloader.download(
+ s3_uri=csv_file,
+ local_path="./",
+ kms_key=self._kms_key_id,
+ sagemaker_session=self._sagemaker_session,
+ )
+ local_file_name = csv_file.split("/")[-1]
+ df = pd.read_csv(local_file_name)
+ os.remove(local_file_name)
+
+ local_metadata_file_name = local_file_name + ".metadata"
+ if os.path.exists(local_metadata_file_name):
+ os.remove(local_metadata_file_name)
+
+ if "row_recent" in df:
+ df = df.drop("row_recent", axis="columns")
+ return df, query_string
+
+ def _construct_event_time_conditions(
+ self,
+ table_name: str,
+ event_time_identifier_feature: FeatureDefinition,
+ ) -> List[str]:
+ """Internal method for constructing event time range sql range as string.
+
+ Args:
+ table_name (str): name of the table.
+ event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the
+ event time identifier feature.
+ Returns:
+ The list of query strings.
+ """
+ event_time_conditions = []
+ timestamp_cast_function_name = "from_unixtime"
+ if event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING:
+ timestamp_cast_function_name = "from_iso8601_timestamp"
+ if self._event_time_starting_timestamp:
+ event_time_conditions.append(
+ f"{timestamp_cast_function_name}({table_name}."
+ + f'"{event_time_identifier_feature.feature_name}") >= '
+ + f"from_unixtime({self._event_time_starting_timestamp.timestamp()})"
+ )
+ if self._event_time_ending_timestamp:
+ event_time_conditions.append(
+ f"{timestamp_cast_function_name}({table_name}."
+ + f'"{event_time_identifier_feature.feature_name}") <= '
+ + f"from_unixtime({self._event_time_ending_timestamp.timestamp()})"
+ )
+ return event_time_conditions
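+
+ # For a STRING event time feature "event_time" on table "table_base" with an
+ # event time range set, the helper above yields conditions of this shape
+ # (illustrative output):
+ #
+ # from_iso8601_timestamp(table_base."event_time") >= from_unixtime(1640995200.0)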
+
+ def _construct_write_time_condition(
+ self,
+ table_name: str,
+ ) -> str:
+ """Internal method for constructing write time condition.
+
+ Args:
+ table_name (str): name of the table.
+ Returns:
+ The write time condition string.
+ """
+ write_time_condition = (
+ f'{table_name}."write_time" <= '
+ f"to_timestamp('{self._write_time_ending_timestamp.replace(microsecond=0)}', "
+ f"'yyyy-mm-dd hh24:mi:ss')"
+ )
+ return write_time_condition
+
+ def _construct_where_query_string(
+ self,
+ suffix: str,
+ event_time_identifier_feature: FeatureDefinition,
+ where_conditions: List[str],
+ ) -> str:
+ """Internal method for constructing SQL WHERE query string by parameters.
+
+ Args:
+ suffix (str): A temp identifier of the FeatureGroup.
+ event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the
+ event time identifier feature.
+ where_conditions (List[str]): A list of strings representing existing where clauses.
+ Returns:
+ The WHERE query string.
+
+ Raises:
+ ValueError: An operation that requires a FeatureGroup base (e.g. as_of()) is
+ used while the base is a pandas.DataFrame and no FeatureGroup is joined.
+ """
+ if self._number_of_recent_records:
+ if self._number_of_recent_records < 0:
+ raise ValueError(
+ "Please provide non-negative integer for number_of_recent_records."
+ )
+ if self._number_of_records:
+ if self._number_of_records < 0:
+ raise ValueError("Please provide non-negative integer for number_of_records.")
+ if self._include_deleted_records:
+ if isinstance(self._base, pd.DataFrame):
+ if len(self._feature_groups_to_be_merged) == 0:
+ raise ValueError(
+ "include_deleted_records() only works for FeatureGroup,"
+ " if there is no join operation."
+ )
+ if self._include_duplicated_records:
+ if isinstance(self._base, pd.DataFrame):
+ if len(self._feature_groups_to_be_merged) == 0:
+ raise ValueError(
+ "include_duplicated_records() only works for FeatureGroup,"
+ " if there is no join operation."
+ )
+ if self._point_in_time_accurate_join:
+ if len(self._feature_groups_to_be_merged) == 0:
+ raise ValueError(
+ "point_in_time_accurate_join() this operation only works when there is "
+ "more than one feature group to join."
+ )
+ if self._write_time_ending_timestamp:
+ if isinstance(self._base, pd.DataFrame):
+ if len(self._feature_groups_to_be_merged) == 0:
+ raise ValueError(
+ "as_of() only works for FeatureGroup," " if there is no join operation."
+ )
+ if isinstance(self._base, FeatureGroup):
+ if self._write_time_ending_timestamp:
+ where_conditions.append(self._construct_write_time_condition(f"table_{suffix}"))
+
+ event_time_conditions = self._construct_event_time_conditions(
+ f"table_{suffix}", event_time_identifier_feature
+ )
+ where_conditions.extend(event_time_conditions)
+
+ if len(where_conditions) == 0:
+ return ""
+ return "WHERE " + "\nAND ".join(where_conditions)
+
+ def _construct_dedup_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
+ """Internal method for constructing removing duplicate records SQL query string.
+
+ Args:
+ feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the
+ FeatureGroup metadata.
+ suffix (str): A temp identifier of the FeatureGroup.
+ Returns:
+ The SQL query string.
+ """
+ record_feature_name = feature_group.record_identifier_feature_name
+ event_time_identifier_feature = feature_group.event_time_identifier_feature
+ event_time_feature_name = feature_group.event_time_identifier_feature.feature_name
+ rank_query_string = ""
+ where_conditions = []
+ where_conditions_str = ""
+ is_dedup_enabled = False
+
+ if feature_group.table_type is TableType.FEATURE_GROUP:
+ is_dedup_enabled = True
+ rank_query_string = (
+ f'ORDER BY origin_{suffix}."api_invocation_time" DESC, '
+ + f'origin_{suffix}."write_time" DESC\n'
+ )
+
+ if self._write_time_ending_timestamp:
+ where_conditions.append(self._construct_write_time_condition(f"origin_{suffix}"))
+
+ event_time_conditions = self._construct_event_time_conditions(
+ f"origin_{suffix}", event_time_identifier_feature
+ )
+ where_conditions.extend(event_time_conditions)
+
+ if len(where_conditions) != 0:
+ where_conditions_str = "WHERE " + "\nAND ".join(where_conditions) + "\n"
+
+ dedup_where_clause = f"WHERE dedup_row_{suffix} = 1\n" if is_dedup_enabled else ""
+ return (
+ f"table_{suffix} AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + f'PARTITION BY origin_{suffix}."{record_feature_name}", '
+ + f'origin_{suffix}."{event_time_feature_name}"\n'
+ + rank_query_string
+ + f") AS dedup_row_{suffix}\n"
+ + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n'
+ + where_conditions_str
+ + ")\n"
+ + dedup_where_clause
+ + ")"
+ )
+
+ def _construct_deleted_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
+ """Internal method for constructing removing deleted records SQL query string.
+
+ Args:
+ feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the
+ FeatureGroup metadata.
+ suffix (str): A temp identifier of the FeatureGroup.
+ Returns:
+ The SQL query string.
+ """
+ record_feature_name = feature_group.record_identifier_feature_name
+ event_time_identifier_feature = feature_group.event_time_identifier_feature
+ event_time_feature_name = feature_group.event_time_identifier_feature.feature_name
+ rank_query_string = f'ORDER BY origin_{suffix}."{event_time_feature_name}" DESC'
+ write_time_condition = "\n"
+ event_time_starting_condition = ""
+ event_time_ending_condition = ""
+
+ if feature_group.table_type is TableType.FEATURE_GROUP:
+ rank_query_string += (
+ f', origin_{suffix}."api_invocation_time" DESC, '
+ + f'origin_{suffix}."write_time" DESC\n'
+ )
+
+ if self._write_time_ending_timestamp:
+ write_time_condition += " AND "
+ write_time_condition += self._construct_write_time_condition(f"origin_{suffix}")
+ write_time_condition += "\n"
+
+ if self._event_time_starting_timestamp and self._event_time_ending_timestamp:
+ event_time_conditions = self._construct_event_time_conditions(
+ f"origin_{suffix}", event_time_identifier_feature
+ )
+ event_time_starting_condition = "AND " + event_time_conditions[0] + "\n"
+ event_time_ending_condition = "AND " + event_time_conditions[1] + "\n"
+
+ return (
+ f"deleted_{suffix} AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + f'PARTITION BY origin_{suffix}."{record_feature_name}"\n'
+ + rank_query_string
+ + f") AS deleted_row_{suffix}\n"
+ + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n'
+ + "WHERE is_deleted"
+ + write_time_condition
+ + event_time_starting_condition
+ + event_time_ending_condition
+ + ")\n"
+ + f"WHERE deleted_row_{suffix} = 1\n"
+ + ")"
+ )
+
+ def _construct_table_included_features(
+ self, feature_group: FeatureGroupToBeMerged, suffix: str
+ ) -> str:
+ """Internal method for constructing included features string of table.
+
+ Args:
+ feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object
+ which has the metadata.
+ suffix (str): A temp identifier of the table.
+ Returns:
+ A string listing all features of the table to be included in the query.
+ """
+
+ included_features = ", ".join(
+ [
+ f'table_{suffix}."{include_feature_name}"'
+ for include_feature_name in feature_group.included_feature_names
+ ]
+ )
+ return included_features
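+
+ # For suffix "base" and included features ["customer_id", "amount"], the
+ # helper above produces (illustrative):
+ #
+ # table_base."customer_id", table_base."amount"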
+
+ def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
+ """Internal method for constructing SQL query string by parameters.
+
+ Args:
+ feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the
+ FeatureGroup metadata.
+ suffix (str): A temp identifier of the FeatureGroup.
+ Returns:
+ The query string.
+ """
+ included_features = self._construct_table_included_features(feature_group, suffix)
+
+ # If the table is backed by a FeatureGroup, included_features_write_time additionally
+ # selects the write_time column; otherwise it is the same as included_features.
+ included_features_write_time = included_features
+
+ if feature_group.table_type is TableType.FEATURE_GROUP:
+ included_features_write_time += f', table_{suffix}."write_time"'
+ record_feature_name = feature_group.record_identifier_feature_name
+ event_time_feature_name = feature_group.event_time_identifier_feature.feature_name
+ if self._include_duplicated_records and self._include_deleted_records:
+ return (
+ f"SELECT {included_features}\n"
+ + f'FROM "{feature_group.database}"."{feature_group.table_name}" table_{suffix}\n'
+ + self._construct_where_query_string(
+ suffix, feature_group.event_time_identifier_feature, ["NOT is_deleted"]
+ )
+ )
+ if feature_group.table_type is TableType.FEATURE_GROUP and self._include_deleted_records:
+ rank_query_string = ""
+ if feature_group.table_type is TableType.FEATURE_GROUP:
+ rank_query_string = (
+ f'ORDER BY origin_{suffix}."api_invocation_time" DESC, '
+ + f'origin_{suffix}."write_time" DESC\n'
+ )
+ return (
+ f"SELECT {included_features}\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + f'PARTITION BY origin_{suffix}."{record_feature_name}", '
+ + f'origin_{suffix}."{event_time_feature_name}"\n'
+ + rank_query_string
+ + f") AS row_{suffix}\n"
+ + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n'
+ + "WHERE NOT is_deleted"
+ + f") AS table_{suffix}\n"
+ + self._construct_where_query_string(
+ suffix,
+ feature_group.event_time_identifier_feature,
+ [f"row_{suffix} = 1"],
+ )
+ )
+ rank_query_string = ""
+ if feature_group.table_type is TableType.FEATURE_GROUP:
+ rank_query_string = (
+ f'OR (table_{suffix}."{event_time_feature_name}" = '
+ + f'deleted_{suffix}."{event_time_feature_name}" '
+ + f'AND table_{suffix}."api_invocation_time" > '
+ + f'deleted_{suffix}."api_invocation_time")\n'
+ + f'OR (table_{suffix}."{event_time_feature_name}" = '
+ + f'deleted_{suffix}."{event_time_feature_name}" '
+ + f'AND table_{suffix}."api_invocation_time" = '
+ + f'deleted_{suffix}."api_invocation_time" '
+ + f'AND table_{suffix}."write_time" > deleted_{suffix}."write_time")\n'
+ )
+
+ final_query_string = ""
+ if feature_group.table_type is TableType.FEATURE_GROUP:
+ if self._include_duplicated_records:
+ final_query_string = (
+ f"WITH {self._construct_deleted_query(feature_group, suffix)}\n"
+ + f"SELECT {included_features}\n"
+ + "FROM (\n"
+ + f"SELECT {included_features_write_time}\n"
+ + f'FROM "{feature_group.database}"."{feature_group.table_name}"'
+ + f" table_{suffix}\n"
+ + f"LEFT JOIN deleted_{suffix}\n"
+ + f'ON table_{suffix}."{record_feature_name}" = '
+ + f'deleted_{suffix}."{record_feature_name}"\n'
+ + f'WHERE deleted_{suffix}."{record_feature_name}" IS NULL\n'
+ + "UNION ALL\n"
+ + f"SELECT {included_features_write_time}\n"
+ + f"FROM deleted_{suffix}\n"
+ + f'JOIN "{feature_group.database}"."{feature_group.table_name}"'
+ + f" table_{suffix}\n"
+ + f'ON table_{suffix}."{record_feature_name}" = '
+ + f'deleted_{suffix}."{record_feature_name}"\n'
+ + "AND (\n"
+ + f'table_{suffix}."{event_time_feature_name}" > '
+ + f'deleted_{suffix}."{event_time_feature_name}"\n'
+ + rank_query_string
+ + ")\n"
+ + f") AS table_{suffix}\n"
+ + self._construct_where_query_string(
+ suffix, feature_group.event_time_identifier_feature, []
+ )
+ )
+ else:
+ final_query_string = (
+ f"WITH {self._construct_dedup_query(feature_group, suffix)},\n"
+ + f"{self._construct_deleted_query(feature_group, suffix)}\n"
+ + f"SELECT {included_features}\n"
+ + "FROM (\n"
+ + f"SELECT {included_features_write_time}\n"
+ + f"FROM table_{suffix}\n"
+ + f"LEFT JOIN deleted_{suffix}\n"
+ + f'ON table_{suffix}."{record_feature_name}" = '
+ + f'deleted_{suffix}."{record_feature_name}"\n'
+ + f'WHERE deleted_{suffix}."{record_feature_name}" IS NULL\n'
+ + "UNION ALL\n"
+ + f"SELECT {included_features_write_time}\n"
+ + f"FROM deleted_{suffix}\n"
+ + f"JOIN table_{suffix}\n"
+ + f'ON table_{suffix}."{record_feature_name}" = '
+ + f'deleted_{suffix}."{record_feature_name}"\n'
+ + "AND (\n"
+ + f'table_{suffix}."{event_time_feature_name}" > '
+ + f'deleted_{suffix}."{event_time_feature_name}"\n'
+ + rank_query_string
+ + ")\n"
+ + f") AS table_{suffix}\n"
+ + self._construct_where_query_string(
+ suffix, feature_group.event_time_identifier_feature, []
+ )
+ )
+ else:
+ final_query_string = (
+ f"WITH {self._construct_dedup_query(feature_group, suffix)}\n"
+ + f"SELECT {included_features}\n"
+ + "FROM (\n"
+ + f"SELECT {included_features_write_time}\n"
+ + f"FROM table_{suffix}\n"
+ + f") AS table_{suffix}\n"
+ + self._construct_where_query_string(
+ suffix, feature_group.event_time_identifier_feature, []
+ )
+ )
+ return final_query_string
+
+ def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str:
+ """Internal method for constructing SQL query string by parameters.
+
+ Args:
+ base (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the metadata.
+ Returns:
+ The query string.
+
+ Raises:
+ ValueError: target_feature_name_in_base is an invalid feature name.
+ """
+ base_table_query_string = self._construct_table_query(base, "base")
+ query_string = f"WITH fg_base AS ({base_table_query_string})"
+ if len(self._feature_groups_to_be_merged) > 0:
+ with_subquery_string = "".join(
+ [
+ f",\nfg_{i} AS ({self._construct_table_query(feature_group, str(i))})"
+ for i, feature_group in enumerate(self._feature_groups_to_be_merged)
+ ]
+ )
+ query_string += with_subquery_string
+
+ selected_features = ""
+ selected_features += ", ".join(map("fg_base.{0}".format, base.projected_feature_names))
+ if len(self._feature_groups_to_be_merged) > 0:
+ for i, feature_group in enumerate(self._feature_groups_to_be_merged):
+ selected_features += ", "
+ selected_features += ", ".join(
+ [
+ f'fg_{i}."{feature_name}" as "{feature_name}.{(i+1)}"'
+ for feature_name in feature_group.projected_feature_names
+ ]
+ )
+
+ selected_features_final = ""
+ selected_features_final += ", ".join(base.projected_feature_names)
+ if len(self._feature_groups_to_be_merged) > 0:
+ for i, feature_group in enumerate(self._feature_groups_to_be_merged):
+ selected_features_final += ", "
+ selected_features_final += ", ".join(
+ [
+ '"{0}.{1}"'.format(feature_name, (i + 1))
+ for feature_name in feature_group.projected_feature_names
+ ]
+ )
+
+ query_string += (
+ f"\nSELECT {selected_features_final}\n"
+ + "FROM (\n"
+ + f"SELECT {selected_features}, row_number() OVER (\n"
+ + f'PARTITION BY fg_base."{base.record_identifier_feature_name}"\n'
+ + f'ORDER BY fg_base."{base.event_time_identifier_feature.feature_name}" DESC'
+ )
+
+ recent_record_where_clause = ""
+ if self._number_of_recent_records is not None and self._number_of_recent_records >= 0:
+ recent_record_where_clause = f"WHERE row_recent <= {self._number_of_recent_records}"
+
+ join_subquery_strings = []
+ if len(self._feature_groups_to_be_merged) > 0:
+ for i, feature_group in enumerate(self._feature_groups_to_be_merged):
+ if not feature_group.target_feature_name_in_base:
+ feature_group.target_feature_name_in_base = self._record_identifier_feature_name
+ else:
+ if feature_group.target_feature_name_in_base not in base.features:
+ raise ValueError(
+ f"Feature {feature_group.target_feature_name_in_base} not found in base"
+ )
+ query_string += (
+ f', fg_{i}."{feature_group.event_time_identifier_feature.feature_name}" DESC'
+ )
+ join_subquery_strings.append(self._construct_join_condition(feature_group, str(i)))
+
+ query_string += (
+ "\n) AS row_recent\n"
+ + "FROM fg_base"
+ + "".join(join_subquery_strings)
+ + "\n)\n"
+ + f"{recent_record_where_clause}"
+ )
+
+ if self._number_of_records is not None and self._number_of_records >= 0:
+ query_string += f"\nLIMIT {self._number_of_records}"
+ return query_string
+
+ def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str:
+ """Internal method for constructing SQL JOIN query string by parameters.
+
+ Args:
+ feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the
+ FeatureGroup metadata.
+ suffix (str): A temp identifier of the FeatureGroup.
+ Returns:
+ The JOIN query string.
+ """
+ join_condition_string = (
+ f"\nJOIN fg_{suffix}\n"
+ + f'ON fg_base."{feature_group.target_feature_name_in_base}" = '
+ + f'fg_{suffix}."{feature_group.record_identifier_feature_name}"'
+ )
+ base_timestamp_cast_function_name = "from_unixtime"
+ if self._event_time_identifier_feature_type == FeatureTypeEnum.STRING:
+ base_timestamp_cast_function_name = "from_iso8601_timestamp"
+ timestamp_cast_function_name = "from_unixtime"
+ if feature_group.event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING:
+ timestamp_cast_function_name = "from_iso8601_timestamp"
+ if self._point_in_time_accurate_join:
+ join_condition_string += (
+ f"\nAND {base_timestamp_cast_function_name}(fg_base."
+ + f'"{self._event_time_identifier_feature_name}") >= '
+ + f"{timestamp_cast_function_name}(fg_{suffix}."
+ + f'"{feature_group.event_time_identifier_feature.feature_name}")'
+ )
+ return join_condition_string
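+
+ # With point_in_time_accurate_join() enabled and unix-timestamp event times,
+ # the helper above emits a clause of roughly this shape (illustrative names):
+ #
+ # JOIN fg_0
+ # ON fg_base."customer_id" = fg_0."customer_id"
+ # AND from_unixtime(fg_base."event_time") >= from_unixtime(fg_0."event_time")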
+
+ def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str):
+ """Internal method for creating a temp Athena table for the base pandas.Dataframe.
+
+ Args:
+ temp_table_name (str): The Athena table name of base pandas.DataFrame.
+ desired_s3_folder (str): The S3 URI of the folder of the data.
+ """
+ columns_string = ", ".join(
+ [self._construct_athena_table_column_string(column) for column in self._base.columns]
+ )
+ serde_properties = '"separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\"'
+ query_string = (
+ f"CREATE EXTERNAL TABLE {temp_table_name} ({columns_string}) "
+ + "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' "
+ + f"WITH SERDEPROPERTIES ({serde_properties}) "
+ + f"LOCATION '{desired_s3_folder}';"
+ )
+ self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE)
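+
+ # The generated DDL has roughly this shape (illustrative names):
+ #
+ # CREATE EXTERNAL TABLE dataframe_base_abc (customer_id INT, amount DOUBLE)
+ # ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde'
+ # WITH SERDEPROPERTIES ("separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\")
+ # LOCATION 's3://my-bucket/dataset/dataframe-base-abc';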
+
+ def _construct_athena_table_column_string(self, column: str) -> str:
+ """Internal method for constructing string of Athena column.
+
+ Args:
+ column (str): The column name from the pandas.DataFrame.
+ Returns:
+ The Athena column string.
+
+ Raises:
+ RuntimeError: The pandas.DataFrame column type is not supported yet.
+ """
+ dataframe_type = self._base[column].dtypes
+ if str(dataframe_type) not in self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP:
+ raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.")
+ return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}"
+
+ def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]:
+ """Internal method for execute Athena query, wait for query finish and get query result.
+
+ Args:
+ query_string (str): The SQL query statements to be executed.
+ catalog (str): The name of the data catalog used in the query execution.
+ database (str): The name of the database used in the query execution.
+ Returns:
+ The query result.
+
+ Raises:
+ RuntimeError: Athena query failed.
+ """
+ query = self._sagemaker_session.start_query_execution(
+ catalog=catalog,
+ database=database,
+ query_string=query_string,
+ output_location=self._output_path,
+ kms_key=self._kms_key_id,
+ )
+ query_id = query.get("QueryExecutionId", None)
+ self._sagemaker_session.wait_for_athena_query(query_execution_id=query_id)
+ query_result = self._sagemaker_session.get_query_execution(query_execution_id=query_id)
+ query_state = query_result.get("QueryExecution", {}).get("Status", {}).get("State", None)
+
+ if query_state != "SUCCEEDED":
+ raise RuntimeError(f"Failed to execute query {query_id}.")
+ return query_result
diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py
index d486ab8a01..855e11488f 100644
--- a/src/sagemaker/feature_store/feature_group.py
+++ b/src/sagemaker/feature_store/feature_group.py
@@ -435,13 +435,14 @@ class FeatureGroup:
"uint64",
]
_FLOAT_TYPES = ["float_", "float16", "float32", "float64"]
- _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = {
+ DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = {
type: FeatureTypeEnum.INTEGRAL for type in _INTEGER_TYPES
}
- _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update(
+ DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update(
{type: FeatureTypeEnum.FRACTIONAL for type in _FLOAT_TYPES}
)
- _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING
+ DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING
+ DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["object"] = FeatureTypeEnum.STRING
_FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP = {
FeatureTypeEnum.INTEGRAL.value: "INT",
@@ -629,7 +630,7 @@ def load_feature_definitions(
"""
feature_definitions = []
for column in data_frame:
- feature_type = self._DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
+ feature_type = self.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get(
str(data_frame[column].dtype), None
)
if feature_type:
@@ -644,6 +645,23 @@ def load_feature_definitions(
self.feature_definitions = feature_definitions
return self.feature_definitions
+ def get_record(
+ self, record_identifier_value_as_string: str, feature_names: Sequence[str] = None
+ ) -> Sequence[Dict[str, str]]:
+ """Get a single record in a FeatureGroup
+
+ Args:
+ record_identifier_value_as_string (String):
+ a String representing the value of the record identifier.
+ feature_names (Sequence[String]):
+ a list of Strings representing feature names.
+
+ Returns:
+ The retrieved record, as a sequence of feature name/value dicts.
+ """
+ return self.sagemaker_session.get_record(
+ record_identifier_value_as_string=record_identifier_value_as_string,
+ feature_group_name=self.name,
+ feature_names=feature_names,
+ ).get("Record")
+
def put_record(self, record: Sequence[FeatureValue]):
"""Put a single record in the FeatureGroup.
@@ -654,6 +672,25 @@ def put_record(self, record: Sequence[FeatureValue]):
feature_group_name=self.name, record=[value.to_dict() for value in record]
)
+ def delete_record(
+ self,
+ record_identifier_value_as_string: str,
+ event_time: str,
+ ):
+ """Delete a single record from a FeatureGroup.
+
+ Args:
+ record_identifier_value_as_string (String):
+ a String representing the value of the record identifier.
+ event_time (String):
+ a timestamp format String indicating when the deletion event occurred.
+ """
+ return self.sagemaker_session.delete_record(
+ feature_group_name=self.name,
+ record_identifier_value_as_string=record_identifier_value_as_string,
+ event_time=event_time,
+ )
+
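+ # A minimal sketch of the record APIs above (values are illustrative):
+ #
+ # record = fg.get_record(record_identifier_value_as_string="42")
+ # fg.put_record(record=[FeatureValue(feature_name="amount", value_as_string="10.5")])
+ # fg.delete_record(
+ # record_identifier_value_as_string="42",
+ # event_time="2022-12-14T00:00:00Z",
+ # )
+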
def ingest(
self,
data_frame: DataFrame,
diff --git a/src/sagemaker/feature_store/feature_store.py b/src/sagemaker/feature_store/feature_store.py
new file mode 100644
index 0000000000..def8b2b2da
--- /dev/null
+++ b/src/sagemaker/feature_store/feature_store.py
@@ -0,0 +1,130 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Feature Store.
+
+Amazon SageMaker Feature Store is a fully managed, purpose-built repository to store, share, and
+manage features for machine learning (ML) models.
+"""
+from __future__ import absolute_import
+
+import datetime
+from typing import Any, Dict, Sequence, Union
+
+import attr
+import pandas as pd
+
+from sagemaker import Session
+from sagemaker.feature_store.dataset_builder import DatasetBuilder
+from sagemaker.feature_store.feature_group import FeatureGroup
+
+
+@attr.s
+class FeatureStore:
+ """FeatureStore definition.
+
+ This class instantiates a FeatureStore object that comprises a SageMaker session instance.
+
+ Attributes:
+ sagemaker_session (Session): session instance to perform boto calls.
+ """
+
+ sagemaker_session: Session = attr.ib(factory=Session)
+
+ def create_dataset(
+ self,
+ base: Union[FeatureGroup, pd.DataFrame],
+ output_path: str,
+ record_identifier_feature_name: str = None,
+ event_time_identifier_feature_name: str = None,
+ included_feature_names: Sequence[str] = None,
+ kms_key_id: str = None,
+ ) -> DatasetBuilder:
+ """Create a Dataset Builder for generating a Dataset.
+
+ Args:
+ base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
+ pandas.DataFrame and will be used to merge other FeatureGroups and generate a
+ Dataset.
+ output_path (str): An S3 URI which stores the output .csv file.
+ record_identifier_feature_name (str): A string representing the record identifier
+ feature if base is a DataFrame (default: None).
+ event_time_identifier_feature_name (str): A string representing the event time
+ identifier feature if base is a DataFrame (default: None).
+ included_feature_names (Sequence[str]): A list of features to be included in the
+ output (default: None).
+ kms_key_id (str): A KMS key ID. If set, it will be used to encrypt the result file
+ (default: None).
+
+ Raises:
+ ValueError: Base is a pandas DataFrame but the record identifier feature name or
+ the event time identifier feature name is not provided.
+ """
+ if isinstance(base, pd.DataFrame):
+ if record_identifier_feature_name is None or event_time_identifier_feature_name is None:
+ raise ValueError(
+ "You must provide a record identifier feature name and an event time "
+ + "identifier feature name if specify DataFrame as base."
+ )
+ return DatasetBuilder(
+ self.sagemaker_session,
+ base,
+ output_path,
+ record_identifier_feature_name,
+ event_time_identifier_feature_name,
+ included_feature_names,
+ kms_key_id,
+ )
+
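+ # A sketch of create_dataset() with a pandas.DataFrame base (column names and
+ # the S3 bucket are hypothetical):
+ #
+ # builder = feature_store.create_dataset(
+ # base=events_df,
+ # output_path="s3://my-bucket/dataset",
+ # record_identifier_feature_name="customer_id",
+ # event_time_identifier_feature_name="event_time",
+ # )
+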
+ def list_feature_groups(
+ self,
+ name_contains: str = None,
+ feature_group_status_equals: str = None,
+ offline_store_status_equals: str = None,
+ creation_time_after: datetime.datetime = None,
+ creation_time_before: datetime.datetime = None,
+ sort_order: str = None,
+ sort_by: str = None,
+ max_results: int = None,
+ next_token: str = None,
+ ) -> Dict[str, Any]:
+ """List all FeatureGroups satisfying given filters.
+
+ Args:
+ name_contains (str): A string that partially matches one or more FeatureGroups' names.
+ Filters FeatureGroups by name.
+ feature_group_status_equals (str): A FeatureGroup status.
+ Filters FeatureGroups by FeatureGroup status.
+ offline_store_status_equals (str): An OfflineStore status.
+ Filters FeatureGroups by OfflineStore status.
+ creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups
+ created after a specific date and time.
+ creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups
+ created before a specific date and time.
+ sort_order (str): The order in which FeatureGroups are listed.
+ sort_by (str): The value on which the FeatureGroup list is sorted.
+ max_results (int): The maximum number of results returned by ListFeatureGroups.
+ next_token (str): A token to resume pagination of ListFeatureGroups results.
+ Returns:
+ Response dict from service.
+ """
+ return self.sagemaker_session.list_feature_groups(
+ name_contains=name_contains,
+ feature_group_status_equals=feature_group_status_equals,
+ offline_store_status_equals=offline_store_status_equals,
+ creation_time_after=creation_time_after,
+ creation_time_before=creation_time_before,
+ sort_order=sort_order,
+ sort_by=sort_by,
+ max_results=max_results,
+ next_token=next_token,
+ )
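+
+
+# A minimal sketch of listing feature groups with the class above (the filter
+# value is illustrative):
+#
+# feature_store = FeatureStore()
+# response = feature_store.list_feature_groups(name_contains="customers")
+# names = [fg["FeatureGroupName"] for fg in response["FeatureGroupSummaries"]]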
diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py
index d82d3596ac..5efe530396 100644
--- a/src/sagemaker/fw_utils.py
+++ b/src/sagemaker/fw_utils.py
@@ -493,7 +493,7 @@ def framework_name_from_image(image_uri):
# We must support both the legacy and current image name format.
name_pattern = re.compile(
r"""^(?:sagemaker(?:-rl)?-)?
- (tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost
+ (tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost
|huggingface-tensorflow|huggingface-pytorch
|huggingface-tensorflow-trcomp|huggingface-pytorch-trcomp)(?:-)?
(scriptmode|training)?
diff --git a/src/sagemaker/git_utils.py b/src/sagemaker/git_utils.py
index 80bd62d5be..c424753286 100644
--- a/src/sagemaker/git_utils.py
+++ b/src/sagemaker/git_utils.py
@@ -279,9 +279,8 @@ def _run_clone_command(repo_url, dest_dir):
subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env)
elif repo_url.startswith("git@"):
with tempfile.NamedTemporaryFile() as sshnoprompt:
- write_pipe = open(sshnoprompt.name, "w")
- write_pipe.write("ssh -oBatchMode=yes $@")
- write_pipe.close()
+ with open(sshnoprompt.name, "w") as write_pipe:
+ write_pipe.write("ssh -oBatchMode=yes $@")
os.chmod(sshnoprompt.name, 0o511)
my_env["GIT_SSH"] = sshnoprompt.name
subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env)
diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json
index 3cc488c55d..590b6e5f82 100644
--- a/src/sagemaker/image_uri_config/autogluon.json
+++ b/src/sagemaker/image_uri_config/autogluon.json
@@ -26,9 +26,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -56,9 +58,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -86,9 +90,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -116,9 +122,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -146,9 +154,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -176,9 +186,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -204,6 +216,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -211,15 +224,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -237,6 +254,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -244,15 +262,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -270,6 +292,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -277,15 +300,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -303,6 +330,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -310,15 +338,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -336,6 +368,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -343,15 +376,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -369,6 +406,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -376,15 +414,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
diff --git a/src/sagemaker/image_uri_config/blazingtext.json b/src/sagemaker/image_uri_config/blazingtext.json
index c588d65c73..ae4295c59a 100644
--- a/src/sagemaker/image_uri_config/blazingtext.json
+++ b/src/sagemaker/image_uri_config/blazingtext.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "433757028032"
diff --git a/src/sagemaker/image_uri_config/factorization-machines.json b/src/sagemaker/image_uri_config/factorization-machines.json
index 0f9930357f..8fb1895707 100644
--- a/src/sagemaker/image_uri_config/factorization-machines.json
+++ b/src/sagemaker/image_uri_config/factorization-machines.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "174872318107"
diff --git a/src/sagemaker/image_uri_config/forecasting-deepar.json b/src/sagemaker/image_uri_config/forecasting-deepar.json
index 1acc96ed3e..e9beb7acb6 100644
--- a/src/sagemaker/image_uri_config/forecasting-deepar.json
+++ b/src/sagemaker/image_uri_config/forecasting-deepar.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "522234722520",
"us-east-2": "566113047672",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "156387875391"
diff --git a/src/sagemaker/image_uri_config/huggingface-neuron.json b/src/sagemaker/image_uri_config/huggingface-neuron.json
index 1e2246cb11..980dceed17 100644
--- a/src/sagemaker/image_uri_config/huggingface-neuron.json
+++ b/src/sagemaker/image_uri_config/huggingface-neuron.json
@@ -15,21 +15,25 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ca-central-1": "763104351884",
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
diff --git a/src/sagemaker/image_uri_config/huggingface-training-compiler.json b/src/sagemaker/image_uri_config/huggingface-training-compiler.json
index e771e2a548..482264b773 100644
--- a/src/sagemaker/image_uri_config/huggingface-training-compiler.json
+++ b/src/sagemaker/image_uri_config/huggingface-training-compiler.json
@@ -60,6 +60,7 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -89,6 +90,7 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
@@ -123,6 +125,7 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json
index 317c17030a..a0caa59a55 100644
--- a/src/sagemaker/image_uri_config/huggingface.json
+++ b/src/sagemaker/image_uri_config/huggingface.json
@@ -38,9 +38,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -70,9 +72,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -108,9 +112,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -140,9 +146,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -180,9 +188,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -213,9 +223,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -246,9 +258,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -279,9 +293,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -320,9 +336,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -353,9 +371,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -386,9 +406,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -419,9 +441,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -458,9 +482,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -491,9 +517,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -530,9 +558,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -563,9 +593,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -602,9 +634,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -635,9 +669,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -674,6 +710,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -681,15 +718,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -707,6 +748,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -714,15 +756,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -740,6 +786,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -747,15 +794,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -781,6 +832,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -788,15 +840,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -814,6 +870,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -821,15 +878,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -847,6 +908,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -854,15 +916,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -880,6 +946,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -887,15 +954,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -919,6 +990,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -926,15 +998,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -952,6 +1028,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -959,15 +1036,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -991,6 +1072,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -998,15 +1080,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1024,6 +1110,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1031,15 +1118,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1063,6 +1154,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1070,15 +1162,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1096,6 +1192,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1103,15 +1200,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
diff --git a/src/sagemaker/image_uri_config/image-classification.json b/src/sagemaker/image_uri_config/image-classification.json
index 44ccb3f08d..61e037da08 100644
--- a/src/sagemaker/image_uri_config/image-classification.json
+++ b/src/sagemaker/image_uri_config/image-classification.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "433757028032"
diff --git a/src/sagemaker/image_uri_config/ipinsights.json b/src/sagemaker/image_uri_config/ipinsights.json
index 4e56c149dc..cf3c70194f 100644
--- a/src/sagemaker/image_uri_config/ipinsights.json
+++ b/src/sagemaker/image_uri_config/ipinsights.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "174872318107"
diff --git a/src/sagemaker/image_uri_config/kmeans.json b/src/sagemaker/image_uri_config/kmeans.json
index 952724ce11..e8e947f094 100644
--- a/src/sagemaker/image_uri_config/kmeans.json
+++ b/src/sagemaker/image_uri_config/kmeans.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "174872318107"
diff --git a/src/sagemaker/image_uri_config/knn.json b/src/sagemaker/image_uri_config/knn.json
index 79b239966d..89e8ef6224 100644
--- a/src/sagemaker/image_uri_config/knn.json
+++ b/src/sagemaker/image_uri_config/knn.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "174872318107"
diff --git a/src/sagemaker/image_uri_config/linear-learner.json b/src/sagemaker/image_uri_config/linear-learner.json
index bb027284ab..606edd3791 100644
--- a/src/sagemaker/image_uri_config/linear-learner.json
+++ b/src/sagemaker/image_uri_config/linear-learner.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "174872318107"
diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json
index 12bc40fccf..588a03a76e 100644
--- a/src/sagemaker/image_uri_config/mxnet.json
+++ b/src/sagemaker/image_uri_config/mxnet.json
@@ -245,9 +245,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -277,9 +279,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -309,9 +313,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -341,9 +347,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -373,9 +381,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -619,6 +629,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -626,15 +637,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -651,6 +666,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -658,15 +674,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -683,6 +703,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -690,15 +711,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -715,6 +740,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -722,15 +748,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -747,6 +777,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -754,15 +785,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -852,6 +887,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -859,15 +895,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -884,6 +924,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -891,15 +932,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -916,6 +961,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -923,15 +969,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
diff --git a/src/sagemaker/image_uri_config/neo-pytorch.json b/src/sagemaker/image_uri_config/neo-pytorch.json
index bd15a6450e..c46dd3de5d 100644
--- a/src/sagemaker/image_uri_config/neo-pytorch.json
+++ b/src/sagemaker/image_uri_config/neo-pytorch.json
@@ -11,7 +11,9 @@
"1.7.0": "1.7",
"1.7.1": "1.7",
"1.8.0": "1.8",
- "1.8.1": "1.8"
+ "1.8.1": "1.8",
+ "1.12.0": "1.12",
+ "1.12.1": "1.12"
},
"versions": {
"1.4": {
@@ -173,6 +175,38 @@
"us-west-2": "301217895009"
},
"repository": "sagemaker-inference-pytorch"
+ },
+ "1.12": {
+ "py_versions": ["py3"],
+ "registries": {
+ "af-south-1": "774647643957",
+ "ap-east-1": "110948597952",
+ "ap-northeast-1": "941853720454",
+ "ap-northeast-2": "151534178276",
+ "ap-northeast-3": "925152966179",
+ "ap-south-1": "763008648453",
+ "ap-southeast-1": "324986816169",
+ "ap-southeast-2": "355873309152",
+ "ca-central-1": "464438896020",
+ "cn-north-1": "472730292857",
+ "cn-northwest-1": "474822919863",
+ "eu-central-1": "746233611703",
+ "eu-north-1": "601324751636",
+ "eu-south-1": "966458181534",
+ "eu-west-1": "802834080501",
+ "eu-west-2": "205493899709",
+ "eu-west-3": "254080097072",
+ "me-south-1": "836785723513",
+ "sa-east-1": "756306329178",
+ "us-east-1": "785573368785",
+ "us-east-2": "007439368137",
+ "us-gov-west-1": "263933020539",
+ "us-iso-east-1": "167761179201",
+ "us-isob-east-1": "406031935815",
+ "us-west-1": "710691900526",
+ "us-west-2": "301217895009"
+ },
+ "repository": "sagemaker-inference-pytorch"
}
}
}
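With the alias and registry block above, Neo-compiled PyTorch 1.12 models can resolve an inference image. A sketch, assuming `retrieve()` infers the single py3/inference scope from this file; the region and instance type are illustrative:

```python
from sagemaker import image_uris

uri = image_uris.retrieve(
    framework="neo-pytorch",
    region="us-west-2",
    version="1.12",
    instance_type="ml.c5.4xlarge",
)
print(uri)  # expect .../sagemaker-inference-pytorch with the 1.12 tag
```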
diff --git a/src/sagemaker/image_uri_config/ntm.json b/src/sagemaker/image_uri_config/ntm.json
index 115264b346..16f9565405 100644
--- a/src/sagemaker/image_uri_config/ntm.json
+++ b/src/sagemaker/image_uri_config/ntm.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "174872318107"
diff --git a/src/sagemaker/image_uri_config/object-detection.json b/src/sagemaker/image_uri_config/object-detection.json
index 6a7ba03695..67b60fe587 100644
--- a/src/sagemaker/image_uri_config/object-detection.json
+++ b/src/sagemaker/image_uri_config/object-detection.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "433757028032"
diff --git a/src/sagemaker/image_uri_config/object2vec.json b/src/sagemaker/image_uri_config/object2vec.json
index 39614d1273..b166cc96ff 100644
--- a/src/sagemaker/image_uri_config/object2vec.json
+++ b/src/sagemaker/image_uri_config/object2vec.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "174872318107"
diff --git a/src/sagemaker/image_uri_config/pca.json b/src/sagemaker/image_uri_config/pca.json
index 5f87d8528c..11982e2197 100644
--- a/src/sagemaker/image_uri_config/pca.json
+++ b/src/sagemaker/image_uri_config/pca.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "174872318107"
diff --git a/src/sagemaker/image_uri_config/pytorch-neuron.json b/src/sagemaker/image_uri_config/pytorch-neuron.json
index b116a8a36b..5b29406955 100644
--- a/src/sagemaker/image_uri_config/pytorch-neuron.json
+++ b/src/sagemaker/image_uri_config/pytorch-neuron.json
@@ -28,6 +28,7 @@
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
diff --git a/src/sagemaker/image_uri_config/pytorch-training-compiler.json b/src/sagemaker/image_uri_config/pytorch-training-compiler.json
new file mode 100644
index 0000000000..fd7df875a3
--- /dev/null
+++ b/src/sagemaker/image_uri_config/pytorch-training-compiler.json
@@ -0,0 +1,41 @@
+{
+ "training": {
+ "processors": [
+ "gpu"
+ ],
+ "version_aliases": {
+ "1.12": "1.12.0"
+ },
+ "versions": {
+ "1.12.0": {
+ "py_versions": [
+ "py38"
+ ],
+ "registries": {
+ "af-south-1": "626614931356",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ca-central-1": "763104351884",
+ "eu-central-1": "763104351884",
+ "eu-north-1": "763104351884",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "eu-south-1": "692866216735",
+ "me-south-1": "217643126080",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884"
+ },
+ "repository": "pytorch-trcomp-training"
+ }
+ }
+ }
+}
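This new file backs the PyTorch estimator support for SageMaker Training Compiler (1.12.0, py38, GPU processors, and a subset of commercial regions), and is why `name_pattern` in `fw_utils.py` gained the `pytorch-trcomp` alternative earlier in this diff. A sketch of the intended usage; the entry point and role are placeholders:

```python
from sagemaker.pytorch import PyTorch, TrainingCompilerConfig

estimator = PyTorch(
    entry_point="train.py",  # placeholder training script
    role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder role
    framework_version="1.12.0",
    py_version="py38",
    instance_type="ml.p3.2xlarge",  # the compiler targets GPU instances
    instance_count=1,
    compiler_config=TrainingCompilerConfig(),
)
# estimator.fit(...) resolves the pytorch-trcomp-training repository from
# the registries defined in this file.
```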
diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json
index 3bf8016ba8..85681a3423 100644
--- a/src/sagemaker/image_uri_config/pytorch.json
+++ b/src/sagemaker/image_uri_config/pytorch.json
@@ -17,6 +17,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -25,7 +26,9 @@
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
"eu-north-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-west-1": "763104351884",
+ "eu-south-2": "503227376785",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-west-2": "763104351884"
@@ -39,8 +42,11 @@
"registries": {
"ap-northeast-1": "763104351884",
"ap-northeast-2": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-3": "907027046896",
+ "eu-central-2": "380420809688",
"eu-west-1": "763104351884",
+ "eu-south-2": "503227376785",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
"us-west-2": "763104351884"
@@ -182,6 +188,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -189,15 +196,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -217,6 +228,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -224,15 +236,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -251,6 +267,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -258,15 +275,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -285,6 +306,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -292,15 +314,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -320,6 +346,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -327,15 +354,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -355,6 +386,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -362,15 +394,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -390,6 +426,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -397,15 +434,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -425,6 +466,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -432,15 +474,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -459,6 +505,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -466,15 +513,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -493,6 +544,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -500,15 +552,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -527,6 +583,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -534,15 +591,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -561,6 +622,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -568,15 +630,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -595,6 +661,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -602,15 +669,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -629,6 +700,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -636,15 +708,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -663,6 +739,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -670,15 +747,18 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -707,6 +787,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -714,15 +795,18 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -879,9 +963,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -914,9 +1000,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -949,9 +1037,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -983,9 +1073,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1018,9 +1110,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1053,9 +1147,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1088,9 +1184,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1123,9 +1221,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1157,9 +1257,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1191,9 +1293,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1225,9 +1329,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1259,9 +1365,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1293,9 +1401,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1327,9 +1437,11 @@
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1364,6 +1476,7 @@
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
diff --git a/src/sagemaker/image_uri_config/randomcutforest.json b/src/sagemaker/image_uri_config/randomcutforest.json
index ae7a3574be..15dc84dfc5 100644
--- a/src/sagemaker/image_uri_config/randomcutforest.json
+++ b/src/sagemaker/image_uri_config/randomcutforest.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "174872318107"
diff --git a/src/sagemaker/image_uri_config/semantic-segmentation.json b/src/sagemaker/image_uri_config/semantic-segmentation.json
index 866dd606b4..f49bc43109 100644
--- a/src/sagemaker/image_uri_config/semantic-segmentation.json
+++ b/src/sagemaker/image_uri_config/semantic-segmentation.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "433757028032"
diff --git a/src/sagemaker/image_uri_config/seq2seq.json b/src/sagemaker/image_uri_config/seq2seq.json
index bb3daf93b6..87810ad09d 100644
--- a/src/sagemaker/image_uri_config/seq2seq.json
+++ b/src/sagemaker/image_uri_config/seq2seq.json
@@ -22,10 +22,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "433757028032"
diff --git a/src/sagemaker/image_uri_config/sklearn.json b/src/sagemaker/image_uri_config/sklearn.json
index 7961fde282..4d093f5f62 100644
--- a/src/sagemaker/image_uri_config/sklearn.json
+++ b/src/sagemaker/image_uri_config/sklearn.json
@@ -24,10 +24,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -57,10 +59,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -90,10 +94,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -127,10 +133,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -160,10 +168,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -193,10 +203,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -230,10 +242,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json
index 6a01c3e3e6..a900aa4fe5 100644
--- a/src/sagemaker/image_uri_config/tensorflow.json
+++ b/src/sagemaker/image_uri_config/tensorflow.json
@@ -141,6 +141,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -148,15 +149,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "eu-south-2": "503227376785",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -172,6 +177,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -179,15 +185,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -203,6 +213,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -210,15 +221,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -234,6 +249,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -241,15 +257,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -285,7 +305,10 @@
"2.5": "2.5.1",
"2.6": "2.6.3",
"2.7": "2.7.0",
- "2.8": "2.8.0"
+ "2.8": "2.8.0",
+ "2.9": "2.9.2",
+ "2.10": "2.10.0",
+ "2.11": "2.11.0"
},
"versions": {
"1.10.0": {
@@ -386,6 +409,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -393,15 +417,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -417,6 +445,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -424,15 +453,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -448,6 +481,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -455,15 +489,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -479,6 +517,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -486,15 +525,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -510,6 +553,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -517,15 +561,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -541,6 +589,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -548,15 +597,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -572,6 +625,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -579,15 +633,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -795,6 +853,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -802,15 +861,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -826,6 +889,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -833,15 +897,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -857,6 +925,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -864,15 +933,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -888,6 +961,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -895,15 +969,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -919,6 +997,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -926,15 +1005,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -950,6 +1033,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -957,15 +1041,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -981,6 +1069,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -988,15 +1077,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1012,6 +1105,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1019,15 +1113,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1043,6 +1141,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1050,15 +1149,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1074,6 +1177,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1081,15 +1185,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1105,6 +1213,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1112,15 +1221,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1136,6 +1249,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1143,15 +1257,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1167,6 +1285,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1174,15 +1293,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1198,6 +1321,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1205,15 +1329,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1229,6 +1357,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1236,15 +1365,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1260,6 +1393,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1267,15 +1401,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1291,6 +1429,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1298,15 +1437,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1322,6 +1465,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1329,15 +1473,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1353,6 +1501,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1360,15 +1509,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1384,6 +1537,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1391,15 +1545,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1415,6 +1573,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1422,15 +1581,19 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1446,6 +1609,113 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ca-central-1": "763104351884",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
+ "eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "me-south-1": "217643126080",
+ "me-central-1": "914824155844",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
+ "us-iso-east-1": "886529160074",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884"
+ },
+ "repository": "tensorflow-inference"
+ },
+ "2.9.2": {
+ "registries": {
+ "af-south-1": "626614931356",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ca-central-1": "763104351884",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
+ "eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "me-south-1": "217643126080",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
+ "us-iso-east-1": "886529160074",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884"
+ },
+ "repository": "tensorflow-inference"
+ },
+ "2.10.0": {
+ "registries": {
+ "af-south-1": "626614931356",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
+ "ap-southeast-1": "763104351884",
+ "ap-southeast-2": "763104351884",
+ "ap-southeast-3": "907027046896",
+ "ca-central-1": "763104351884",
+ "cn-north-1": "727897471807",
+ "cn-northwest-1": "727897471807",
+ "eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
+ "eu-north-1": "763104351884",
+ "eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
+ "eu-west-1": "763104351884",
+ "eu-west-2": "763104351884",
+ "eu-west-3": "763104351884",
+ "me-south-1": "217643126080",
+ "sa-east-1": "763104351884",
+ "us-east-1": "763104351884",
+ "us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
+ "us-gov-west-1": "442386744353",
+ "us-iso-east-1": "886529160074",
+ "us-west-1": "763104351884",
+ "us-west-2": "763104351884"
+ },
+ "repository": "tensorflow-inference"
+ },
+ "2.11.0": {
+ "registries": {
+ "af-south-1": "626614931356",
+ "ap-east-1": "871362719292",
+ "ap-northeast-1": "763104351884",
+ "ap-northeast-2": "763104351884",
+ "ap-northeast-3": "364406365360",
+ "ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1453,8 +1723,10 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
@@ -1462,6 +1734,7 @@
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1490,6 +1763,7 @@
"ap-northeast-2": "763104351884",
"ap-northeast-3": "364406365360",
"ap-south-1": "763104351884",
+ "ap-south-2": "772153158452",
"ap-southeast-1": "763104351884",
"ap-southeast-2": "763104351884",
"ap-southeast-3": "907027046896",
@@ -1497,15 +1771,18 @@
"cn-north-1": "727897471807",
"cn-northwest-1": "727897471807",
"eu-central-1": "763104351884",
+ "eu-central-2": "380420809688",
"eu-north-1": "763104351884",
"eu-west-1": "763104351884",
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"eu-south-1": "692866216735",
+ "eu-south-2": "503227376785",
"me-south-1": "217643126080",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1696,9 +1973,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1732,9 +2011,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1767,9 +2048,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1803,9 +2086,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1839,9 +2124,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1875,9 +2162,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -1911,9 +2200,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2138,9 +2429,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2173,9 +2466,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2208,9 +2503,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2242,9 +2539,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2276,9 +2575,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2311,9 +2612,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2346,9 +2649,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2380,9 +2685,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2414,9 +2721,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2448,9 +2757,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2482,9 +2793,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2516,9 +2829,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2550,9 +2865,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2584,9 +2901,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2618,9 +2937,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2652,9 +2973,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2686,9 +3009,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2720,9 +3045,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2754,9 +3081,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2788,9 +3117,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2822,9 +3153,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2856,9 +3189,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2890,9 +3225,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2924,9 +3261,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2958,9 +3297,11 @@
"eu-west-2": "763104351884",
"eu-west-3": "763104351884",
"me-south-1": "217643126080",
+ "me-central-1": "914824155844",
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
@@ -2995,6 +3336,7 @@
"sa-east-1": "763104351884",
"us-east-1": "763104351884",
"us-east-2": "763104351884",
+ "us-gov-east-1": "446045086412",
"us-gov-west-1": "442386744353",
"us-iso-east-1": "886529160074",
"us-west-1": "763104351884",
diff --git a/src/sagemaker/image_uri_config/xgboost.json b/src/sagemaker/image_uri_config/xgboost.json
index a809083c4a..946e78ecc4 100644
--- a/src/sagemaker/image_uri_config/xgboost.json
+++ b/src/sagemaker/image_uri_config/xgboost.json
@@ -25,10 +25,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "433757028032"
@@ -58,10 +60,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -91,10 +95,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -124,10 +130,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -155,10 +163,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -186,10 +196,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -217,10 +229,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -248,10 +262,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -286,10 +302,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "433757028032"
@@ -319,10 +337,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -352,10 +372,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -385,10 +407,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -416,10 +440,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -447,10 +473,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -478,10 +506,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -509,10 +539,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -544,10 +576,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
@@ -575,10 +609,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249"
diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py
index 7d1d3bd835..c42ce02188 100644
--- a/src/sagemaker/image_uris.py
+++ b/src/sagemaker/image_uris.py
@@ -146,7 +146,7 @@ def retrieve(
tolerate_deprecated_model,
)
- if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK):
+ if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]):
final_image_scope = image_scope
config = _config_for_framework_and_scope(
framework + "-training-compiler", final_image_scope, accelerator_type
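The change above extends training-compiler image resolution beyond Hugging Face to PyTorch: when a training_compiler_config is supplied, retrieve() now maps framework="pytorch" onto the "pytorch-training-compiler" config. A minimal sketch of the call, assuming a region, version, and instance type supported by the PyTorch compiler DLCs:

    from sagemaker import image_uris
    from sagemaker.pytorch import TrainingCompilerConfig

    # Assumed example values; PyTorch Training Compiler support starts at 1.12.
    uri = image_uris.retrieve(
        framework="pytorch",
        region="us-west-2",
        version="1.12",
        py_version="py38",
        instance_type="ml.p3.2xlarge",
        image_scope="training",
        training_compiler_config=TrainingCompilerConfig(),
    )
    print(uri)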
diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py
index 202edff9ad..db607770a7 100644
--- a/src/sagemaker/jumpstart/cache.py
+++ b/src/sagemaker/jumpstart/cache.py
@@ -20,7 +20,7 @@
import boto3
import botocore
from packaging.version import Version
-from packaging.specifiers import SpecifierSet
+from packaging.specifiers import SpecifierSet, InvalidSpecifier
from sagemaker.jumpstart.constants import (
ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE,
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
@@ -371,7 +371,10 @@ def _select_version(
return None
return str(max(available_versions))
- spec = SpecifierSet(f"=={semantic_version_str}")
+ try:
+ spec = SpecifierSet(f"=={semantic_version_str}")
+ except InvalidSpecifier:
+ raise KeyError(f"Bad semantic version: {semantic_version_str}")
available_versions_filtered = list(spec.filter(available_versions))
return (
str(max(available_versions_filtered)) if available_versions_filtered != [] else None
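With newer packaging releases (>= 22), an unparseable version string makes SpecifierSet raise InvalidSpecifier instead of falling back to a legacy parser; the patch converts that into a KeyError so callers see a consistent cache-miss error. A minimal sketch mirroring the patched selection logic, with hypothetical version inputs:

    from packaging.specifiers import SpecifierSet, InvalidSpecifier

    def select_version(semantic_version_str, available_versions):
        # Mirrors the patched _select_version: surface a KeyError for bad input.
        try:
            spec = SpecifierSet(f"=={semantic_version_str}")
        except InvalidSpecifier:
            raise KeyError(f"Bad semantic version: {semantic_version_str}")
        matches = list(spec.filter(available_versions))
        return max(matches) if matches else None

    print(select_version("1.1.0", ["1.0.0", "1.1.0"]))  # -> 1.1.0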
diff --git a/src/sagemaker/lineage/_utils.py b/src/sagemaker/lineage/_utils.py
index 28732b0174..7c833a468e 100644
--- a/src/sagemaker/lineage/_utils.py
+++ b/src/sagemaker/lineage/_utils.py
@@ -12,7 +12,6 @@
# language governing permissions and limitations under the License.
"""SageMaker lineage utility methods."""
from __future__ import absolute_import
-from importlib import import_module
from sagemaker.lineage import association
@@ -38,22 +37,6 @@ def _disassociate(source_arn=None, destination_arn=None, sagemaker_session=None)
curr_association.delete()
-def get_module(module_name):
- """Import a module.
-
- Args:
- module_name (str): name of the module to import.
-
- Returns:
- [obj]: The imported module.
- Raises exceptions when the module name is not found
- """
- try:
- return import_module(module_name)
- except ImportError:
- raise Exception("Cannot import module {}, please try again.".format(module_name))
-
-
def get_resource_name_from_arn(arn):
"""Extract the resource name from an ARN string.
diff --git a/src/sagemaker/lineage/artifact.py b/src/sagemaker/lineage/artifact.py
index 3921562beb..718344095a 100644
--- a/src/sagemaker/lineage/artifact.py
+++ b/src/sagemaker/lineage/artifact.py
@@ -29,8 +29,9 @@
LineageEntityEnum,
LineageQueryDirectionEnum,
)
-from sagemaker.lineage._utils import get_module, _disassociate, get_resource_name_from_arn
+from sagemaker.lineage._utils import _disassociate, get_resource_name_from_arn
from sagemaker.lineage.association import Association
+from sagemaker.utils import get_module
LOGGER = logging.getLogger("sagemaker")
diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py
index db6ce2badd..af52da6288 100644
--- a/src/sagemaker/processing.py
+++ b/src/sagemaker/processing.py
@@ -23,6 +23,7 @@
import logging
from textwrap import dedent
from typing import Dict, List, Optional, Union
+from copy import copy
import attr
@@ -32,7 +33,12 @@
from sagemaker.job import _Job
from sagemaker.local import LocalSession
from sagemaker.network import NetworkConfig
-from sagemaker.utils import base_name_from_image, get_config_value, name_from_base
+from sagemaker.utils import (
+ base_name_from_image,
+ get_config_value,
+ name_from_base,
+ check_and_get_run_experiment_config,
+)
from sagemaker.session import Session
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.functions import Join
@@ -202,6 +208,7 @@ def run(
outputs=outputs,
)
+ experiment_config = check_and_get_run_experiment_config(experiment_config)
self.latest_job = ProcessingJob.start_new(
processor=self,
inputs=normalized_inputs,
@@ -604,6 +611,7 @@ def run(
kms_key=kms_key,
)
+ experiment_config = check_and_get_run_experiment_config(experiment_config)
self.latest_job = ProcessingJob.start_new(
processor=self,
inputs=normalized_inputs,
@@ -1587,13 +1595,13 @@ def run( # type: ignore[override]
framework script to run. Path (absolute or relative) to the local
Python source file which should be executed as the entry point
to training. When `code` is an S3 URI, ignore `source_dir`,
- `dependencies, and `git_config`. If ``source_dir`` is specified,
+ `dependencies`, and `git_config`. If ``source_dir`` is specified,
then ``code`` must point to a file located at the root of ``source_dir``.
source_dir (str): Path (absolute, relative or an S3 URI) to a directory
with any other processing source code dependencies aside from the entry
point file (default: None). If ``source_dir`` is an S3 URI, it must
- point to a tar.gz file. Structure within this directory are preserved
- when processing on Amazon SageMaker (default: None).
+ point to a file named `sourcedir.tar.gz`. Structure within this directory
+ is preserved when processing on Amazon SageMaker (default: None).
dependencies (list[str]): A list of paths to directories (absolute
or relative) with any additional libraries that will be exported
to the container (default: []). The library folders will be
@@ -1730,20 +1738,17 @@ def _pack_and_upload_code(
"sagemaker_session unspecified when creating your Processor to have one set up "
"automatically."
)
+ if "/sourcedir.tar.gz" in estimator.uploaded_code.s3_prefix:
+ # Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.
+ entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace(
+ "sourcedir.tar.gz",
+ "runproc.sh",
+ )
+ else:
+ raise RuntimeError("S3 source_dir file must be named `sourcedir.tar.gz`.")
- # Upload the bootstrapping code as s3://.../jobname/source/runproc.sh.
- entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace(
- "sourcedir.tar.gz",
- "runproc.sh",
- )
script = estimator.uploaded_code.script_name
- s3_runproc_sh = S3Uploader.upload_string_as_file_body(
- self._generate_framework_script(script),
- desired_s3_uri=entrypoint_s3_uri,
- kms_key=kms_key,
- sagemaker_session=self.sagemaker_session,
- )
- logger.info("runproc.sh uploaded to %s", s3_runproc_sh)
+ s3_runproc_sh = self._create_and_upload_runproc(script, kms_key, entrypoint_s3_uri)
return s3_runproc_sh, inputs, job_name
@@ -1827,14 +1832,17 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput
# a7399455f5386d83ddc5cb15c0db00c04bd518ec/src/sagemaker/processing.py#L425-L426
if inputs is None:
inputs = []
- inputs.append(
+
+ # make a shallow copy of user inputs
+ patched_inputs = copy(inputs)
+ patched_inputs.append(
ProcessingInput(
input_name="code",
source=s3_payload,
destination="/opt/ml/processing/input/code/",
)
)
- return inputs
+ return patched_inputs
def _set_entrypoint(self, command, user_script_name):
"""Framework processor override for setting processing job entrypoint.
@@ -1850,3 +1858,42 @@ def _set_entrypoint(self, command, user_script_name):
)
)
self.entrypoint = self.framework_entrypoint_command + [user_script_location]
+
+ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
+ """Create runproc shell script and upload to S3 bucket.
+
+ If leveraging a pipeline session with optimized S3 artifact paths,
+ the runproc.sh file is hashed and uploaded to a separate S3 location.
+
+ Args:
+ user_script (str): Relative path to ``code`` in the source bundle,
+ e.g. 'process.py'.
+ kms_key (str): The KMS key used for encryption.
+ entrypoint_s3_uri (str): The S3 upload path for the runproc script.
+ """
+ from sagemaker.workflow.utilities import _pipeline_config, hash_object
+
+ if _pipeline_config and _pipeline_config.pipeline_name:
+ runproc_file_str = self._generate_framework_script(user_script)
+ runproc_file_hash = hash_object(runproc_file_str)
+ s3_uri = (
+ f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/"
+ f"code/{runproc_file_hash}/runproc.sh"
+ )
+ s3_runproc_sh = S3Uploader.upload_string_as_file_body(
+ runproc_file_str,
+ desired_s3_uri=s3_uri,
+ kms_key=kms_key,
+ sagemaker_session=self.sagemaker_session,
+ )
+ else:
+ s3_runproc_sh = S3Uploader.upload_string_as_file_body(
+ self._generate_framework_script(user_script),
+ desired_s3_uri=entrypoint_s3_uri,
+ kms_key=kms_key,
+ sagemaker_session=self.sagemaker_session,
+ )
+ logger.info("runproc.sh uploaded to %s", s3_runproc_sh)
+
+ return s3_runproc_sh
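The _patch_inputs_with_payload change is subtle but important for idempotency: appending the generated code input directly to the caller's list mutated user state, so running the same processor twice appended the code input twice. A minimal sketch of the pattern, independent of SageMaker types:

    from copy import copy

    def patch_inputs(inputs, payload):
        patched = copy(inputs)  # shallow copy: the caller's list is untouched
        patched.append(payload)
        return patched

    user_inputs = ["data-input"]
    patch_inputs(user_inputs, "code-input")
    patch_inputs(user_inputs, "code-input")
    assert user_inputs == ["data-input"]  # no duplicates accumulate across runs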
diff --git a/src/sagemaker/pytorch/__init__.py b/src/sagemaker/pytorch/__init__.py
index cac5f94b9a..e2d14f4163 100644
--- a/src/sagemaker/pytorch/__init__.py
+++ b/src/sagemaker/pytorch/__init__.py
@@ -16,3 +16,5 @@
from sagemaker.pytorch.estimator import PyTorch # noqa: F401
from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor # noqa: F401
from sagemaker.pytorch.processing import PyTorchProcessor # noqa: F401
+
+from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig # noqa: F401
diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index 686de4a78c..29e254662f 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -28,6 +28,7 @@
)
from sagemaker.pytorch import defaults
from sagemaker.pytorch.model import PyTorchModel
+from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
from sagemaker.workflow.entities import PipelineVariable
@@ -51,7 +52,8 @@ def __init__(
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
distribution: Optional[Dict] = None,
- **kwargs
+ compiler_config: Optional[TrainingCompilerConfig] = None,
+ **kwargs,
):
"""This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment.
@@ -208,6 +210,31 @@ def __init__(
To learn more, see `Training with parameter servers
`_.
+ **To enable distributed training with
+ `SageMaker Training Compiler `_
+ for PyTorch:**
+
+ .. code:: python
+
+ {
+ "pytorchxla": {
+ "enabled": True
+ }
+ }
+
+ To learn more, see `SageMaker Training Compiler
+ `_
+ in the *Amazon SageMaker Developer Guide*.
+
+ .. note::
+
+ When you use this PyTorch XLA option for distributed training strategy,
+ you must add the ``compiler_config`` parameter and activate SageMaker
+ Training Compiler.
+
+ compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`):
+ Configures SageMaker Training Compiler to accelerate training.
+
**kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework`
constructor.
@@ -250,6 +277,25 @@ def __init__(
self.distribution = distribution or {}
+ if compiler_config is not None:
+ if not isinstance(compiler_config, TrainingCompilerConfig):
+ error_string = (
+ f"Expected instance of type {TrainingCompilerConfig}"
+ f"for argument compiler_config. "
+ f"Instead got {type(compiler_config)}"
+ )
+ raise ValueError(error_string)
+ if compiler_config:
+ compiler_config.validate(self)
+ elif distribution is not None and "pytorchxla" in distribution:
+ raise ValueError(
+ "Distributed training through PyTorch XLA is currently only supported "
+ "when SageMaker Training Compiler is enabled. To learn more, "
+ "see Enable SageMaker Training Compiler at "
+ "https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html."
+ )
+ self.compiler_config = compiler_config
+
def _pytorch_distribution_configuration(self, distribution):
"""Returns a dict of distribution config for PyTorch training
@@ -289,6 +335,12 @@ def hyperparameters(self):
hyperparameters.update(
EstimatorBase._json_encode_hyperparameters(additional_hyperparameters)
)
+ if self.compiler_config:
+ training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict()
+ hyperparameters.update(
+ EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters)
+ )
+
return hyperparameters
def create_model(
@@ -299,7 +351,7 @@ def create_model(
entry_point=None,
source_dir=None,
dependencies=None,
- **kwargs
+ **kwargs,
):
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.
@@ -350,7 +402,7 @@ def create_model(
sagemaker_session=self.sagemaker_session,
vpc_config=self.get_vpc_config(vpc_config_override),
dependencies=(dependencies or self.dependencies),
- **kwargs
+ **kwargs,
)
@classmethod
@@ -371,6 +423,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
)
image_uri = init_params.pop("image_uri")
framework, py_version, tag, _ = framework_name_from_image(image_uri)
+ if framework:
+ framework = framework.split("-")[0]
if tag is None:
framework_version = None
diff --git a/src/sagemaker/pytorch/training_compiler/__init__.py b/src/sagemaker/pytorch/training_compiler/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/sagemaker/pytorch/training_compiler/config.py b/src/sagemaker/pytorch/training_compiler/config.py
new file mode 100644
index 0000000000..7faf8acbbd
--- /dev/null
+++ b/src/sagemaker/pytorch/training_compiler/config.py
@@ -0,0 +1,151 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Configuration for the SageMaker Training Compiler."""
+from __future__ import absolute_import
+import logging
+from typing import Union
+from packaging.specifiers import SpecifierSet
+from packaging.version import Version
+
+from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig
+from sagemaker.workflow.entities import PipelineVariable
+
+logger = logging.getLogger(__name__)
+
+
+class TrainingCompilerConfig(BaseConfig):
+ """The SageMaker Training Compiler configuration class."""
+
+ SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"]
+ SUPPORTED_INSTANCE_TYPES_WITH_EFA = [
+ "ml.g4dn.8xlarge",
+ "ml.g4dn.12xlarge",
+ "ml.g5.48xlarge",
+ "ml.p3dn.24xlarge",
+ "ml.p4d.24xlarge",
+ ]
+
+ def __init__(
+ self,
+ enabled: Union[bool, PipelineVariable] = True,
+ debug: Union[bool, PipelineVariable] = False,
+ ):
+ """This class initializes a ``TrainingCompilerConfig`` instance.
+
+ `Amazon SageMaker Training Compiler
+ `_
+ is a feature of SageMaker Training
+ and speeds up training jobs by optimizing model execution graphs.
+
+ You can compile PyTorch models
+ by passing the object of this configuration class to the ``compiler_config``
+ parameter of the :class:`~sagemaker.pytorch.PyTorch`
+ estimator.
+
+ Args:
+ enabled (bool or PipelineVariable): Optional. Switch to enable SageMaker
+ Training Compiler. The default is ``True``.
+ debug (bool or PipelineVariable): Optional. Whether to dump detailed logs
+ for debugging. This comes with a potential performance slowdown.
+ The default is ``False``.
+
+ **Example**: The following code shows the basic usage of the
+ :class:`sagemaker.pytorch.TrainingCompilerConfig()` class
+ to run a PyTorch training job with the compiler.
+
+ .. code-block:: python
+
+ from sagemaker.pytorch import PyTorch, TrainingCompilerConfig
+
+ pytorch_estimator=PyTorch(
+ ...
+ compiler_config=TrainingCompilerConfig()
+ )
+
+ .. seealso::
+
+ For more information about how to enable SageMaker Training Compiler
+ for various training settings such as distributed training,
+ see `Enable SageMaker Training Compiler
+ `_
+ in the `Amazon SageMaker Training Compiler developer guide
+ `_.
+
+ """
+
+ super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug)
+
+ @classmethod
+ def validate(
+ cls,
+ estimator,
+ ):
+ """Checks if SageMaker Training Compiler is configured correctly.
+
+ Args:
+ estimator (:class:`sagemaker.pytorch.PyTorch`): An estimator object.
+ If SageMaker Training Compiler is enabled, it will validate whether
+ the estimator is configured to be compatible with Training Compiler.
+
+ Raises:
+ ValueError: Raised if the requested configuration is not compatible
+ with SageMaker Training Compiler.
+ """
+
+ super(TrainingCompilerConfig, cls).validate(estimator)
+
+ if estimator.image_uri:
+ error_helper_string = (
+ "Overriding the image URI is currently not supported "
+ "for SageMaker Training Compiler."
+ "Specify the following parameters to run the PyTorch training job "
+ "with SageMaker Training Compiler enabled: "
+ "framework_version, and compiler_config."
+ )
+ raise ValueError(error_helper_string)
+
+ if estimator.distribution:
+ pt_xla_present = "pytorchxla" in estimator.distribution
+ pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False)
+ if pt_xla_enabled:
+ if estimator.framework_version:
+ if Version(estimator.framework_version) in SpecifierSet("< 1.12"):
+ error_helper_string = (
+ "Distribution mechanism 'pytorchxla' is currently only supported for "
+ "PyTorch >= 1.12 when SageMaker Training Compiler is enabled."
+ " Received framework_version={} which is unsupported."
+ )
+ raise ValueError(error_helper_string.format(estimator.framework_version))
+ if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA:
+ logger.warning(
+ "Consider using instances with EFA support when "
+ "training with PyTorch >= 1.12 and SageMaker Training Compiler "
+ "enabled. SageMaker Training Compiler leverages EFA to provide better "
+ "performance for distributed training."
+ )
+ if not pt_xla_present:
+ if estimator.framework_version:
+ if Version(estimator.framework_version) in SpecifierSet(">= 1.12"):
+ error_helper_string = (
+ "'pytorchxla' is the only distribution mechanism currently supported "
+ "for PyTorch >= 1.12 when SageMaker Training Compiler is enabled."
+ " Received distribution={} which is unsupported."
+ )
+ raise ValueError(error_helper_string.format(estimator.distribution))
+ elif estimator.instance_count and estimator.instance_count > 1:
+ if estimator.framework_version:
+ if Version(estimator.framework_version) in SpecifierSet(">= 1.12"):
+ logger.warning(
+ "Consider setting 'distribution' to 'pytorchxla' for distributed "
+ "training with PyTorch >= 1.12 and SageMaker Training Compiler enabled."
+ )
diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py
index 00797c9ea0..ce6a3b99cd 100644
--- a/src/sagemaker/session.py
+++ b/src/sagemaker/session.py
@@ -89,6 +89,7 @@ def __init__(
sagemaker_featurestore_runtime_client=None,
default_bucket=None,
settings=SessionSettings(),
+ sagemaker_metrics_client=None,
):
"""Initialize a SageMaker ``Session``.
@@ -116,6 +117,10 @@ def __init__(
Example: "sagemaker-my-custom-bucket".
settings (sagemaker.session_settings.SessionSettings): Optional. Set of optional
parameters to apply to the session.
+ sagemaker_metrics_client (boto3.SageMakerMetrics.Client):
+ Client which makes SageMaker Metrics related calls to Amazon SageMaker
+ (default: None). If not provided, one will be created using
+ this instance's ``boto_session``.
"""
self._default_bucket = None
self._default_bucket_name_override = default_bucket
@@ -130,6 +135,7 @@ def __init__(
sagemaker_client=sagemaker_client,
sagemaker_runtime_client=sagemaker_runtime_client,
sagemaker_featurestore_runtime_client=sagemaker_featurestore_runtime_client,
+ sagemaker_metrics_client=sagemaker_metrics_client,
)
def _initialize(
@@ -138,6 +144,7 @@ def _initialize(
sagemaker_client,
sagemaker_runtime_client,
sagemaker_featurestore_runtime_client,
+ sagemaker_metrics_client,
):
"""Initialize this SageMaker Session.
@@ -172,6 +179,12 @@ def _initialize(
"sagemaker-featurestore-runtime"
)
+ if sagemaker_metrics_client:
+ self.sagemaker_metrics_client = sagemaker_metrics_client
+ else:
+ self.sagemaker_metrics_client = self.boto_session.client("sagemaker-metrics")
+ prepend_user_agent(self.sagemaker_metrics_client)
+
self.local_mode = False
@property
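The new sagemaker_metrics_client follows the same pattern as the other injected clients: pass one in to control its configuration, or let the session create it from boto_session. A minimal sketch:

    import boto3
    from sagemaker.session import Session

    boto_session = boto3.Session(region_name="us-west-2")
    session = Session(
        boto_session=boto_session,
        # Optional; if omitted, Session builds this client from boto_session.
        sagemaker_metrics_client=boto_session.client("sagemaker-metrics"),
    )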
@@ -312,7 +325,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
# For each object key, create the directory on the local machine if needed, and then
# download the file.
for key in keys:
- tail_s3_uri_path = os.path.basename(key_prefix)
+ tail_s3_uri_path = os.path.basename(key)
if not os.path.splitext(key_prefix)[1]:
tail_s3_uri_path = os.path.relpath(key, key_prefix)
destination_path = os.path.join(path, tail_s3_uri_path)
@@ -548,8 +561,8 @@ def train( # noqa: C901
checkpoints will be provided under `/opt/ml/checkpoints/`.
(default: ``None``).
experiment_config (dict[str, str]): Experiment management configuration.
- Optionally, the dict can contain three keys:
- 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+ Optionally, the dict can contain four keys:
+ 'ExperimentName', 'TrialName', 'TrialComponentDisplayName', and 'RunName'.
The behavior of setting these keys is as follows:
* If `ExperimentName` is supplied but `TrialName` is not a Trial will be
automatically created and the job's Trial Component associated with the Trial.
@@ -558,6 +571,7 @@ def train( # noqa: C901
* If both `ExperimentName` and `TrialName` are not supplied the trial component
will be unassociated.
* `TrialComponentDisplayName` is used for display in Studio.
+ * `RunName` is used to record an experiment run.
enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
Series. For more information see:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
@@ -703,8 +717,8 @@ def _get_train_request( # noqa: C901
checkpoints will be provided under `/opt/ml/checkpoints/`.
(default: ``None``).
experiment_config (dict[str, str]): Experiment management configuration.
- Optionally, the dict can contain three keys:
- 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
+ Optionally, the dict can contain four keys:
+ 'ExperimentName', 'TrialName', 'TrialComponentDisplayName', and 'RunName'.
The behavior of setting these keys is as follows:
* If `ExperimentName` is supplied but `TrialName` is not a Trial will be
automatically created and the job's Trial Component associated with the Trial.
@@ -713,6 +727,7 @@ def _get_train_request( # noqa: C901
* If both `ExperimentName` and `TrialName` are not supplied the trial component
will be unassociated.
* `TrialComponentDisplayName` is used for display in Studio.
+ * `RunName` is used to record an experiment run.
enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
Series. For more information see:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
@@ -2121,6 +2136,7 @@ def tune( # noqa: C901
stop_condition,
tags,
warm_start_config,
+ strategy_config=None,
enable_network_isolation=False,
image_uri=None,
algorithm_arn=None,
@@ -2136,6 +2152,8 @@ def tune( # noqa: C901
Args:
job_name (str): Name of the tuning job being created.
strategy (str): Strategy to be used for hyperparameter estimations.
+ strategy_config (dict): A configuration for the hyperparameter tuning
+ job optimisation strategy.
objective_type (str): The type of the objective metric for evaluating training jobs.
This value can be either 'Minimize' or 'Maximize'.
objective_metric_name (str): Name of the metric for evaluating training jobs.
@@ -2220,6 +2238,7 @@ def tune( # noqa: C901
objective_metric_name=objective_metric_name,
parameter_ranges=parameter_ranges,
early_stopping_type=early_stopping_type,
+ strategy_config=strategy_config,
),
"TrainingJobDefinition": self._map_training_config(
static_hyperparameters=static_hyperparameters,
@@ -2375,6 +2394,7 @@ def _map_tuning_config(
objective_type=None,
objective_metric_name=None,
parameter_ranges=None,
+ strategy_config=None,
):
"""Construct tuning job configuration dictionary.
@@ -2392,6 +2412,8 @@ def _map_tuning_config(
objective_metric_name (str): Name of the metric for evaluating training jobs.
parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can
be one of three types: Continuous, Integer, or Categorical.
+ strategy_config (dict): A configuration for the hyperparameter tuning job optimisation
+ strategy.
Returns:
A dictionary of tuning job configuration. For format details, please refer to
@@ -2415,6 +2437,8 @@ def _map_tuning_config(
if parameter_ranges is not None:
tuning_config["ParameterRanges"] = parameter_ranges
+ if strategy_config is not None:
+ tuning_config["StrategyConfig"] = strategy_config
return tuning_config
@classmethod
@@ -4332,6 +4356,56 @@ def update_feature_group(
FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions
)
+ def list_feature_groups(
+ self,
+ name_contains,
+ feature_group_status_equals,
+ offline_store_status_equals,
+ creation_time_after,
+ creation_time_before,
+ sort_order,
+ sort_by,
+ max_results,
+ next_token,
+ ) -> Dict[str, Any]:
+ """List all FeatureGroups satisfying given filters.
+
+ Args:
+ name_contains (str): A string that partially matches one or more FeatureGroups' names.
+ Filters FeatureGroups by name.
+ feature_group_status_equals (str): A FeatureGroup status.
+ Filters FeatureGroups by FeatureGroup status.
+ offline_store_status_equals (str): An OfflineStore status.
+ Filters FeatureGroups by OfflineStore status.
+ creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups
+ created after a specific date and time.
+ creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups
+ created before a specific date and time.
+ sort_order (str): The order in which FeatureGroups are listed.
+ sort_by (str): The value on which the FeatureGroup list is sorted.
+ max_results (int): The maximum number of results returned by ListFeatureGroups.
+ next_token (str): A token to resume pagination of ListFeatureGroups results.
+ Returns:
+ Response dict from service.
+ """
+ list_feature_groups_args = {}
+
+ def check_object(key, value):
+ if value is not None:
+ list_feature_groups_args[key] = value
+
+ check_object("NameContains", name_contains)
+ check_object("FeatureGroupStatusEquals", feature_group_status_equals)
+ check_object("OfflineStoreStatusEquals", offline_store_status_equals)
+ check_object("CreationTimeAfter", creation_time_after)
+ check_object("CreationTimeBefore", creation_time_before)
+ check_object("SortOrder", sort_order)
+ check_object("SortBy", sort_by)
+ check_object("MaxResults", max_results)
+ check_object("NextToken", next_token)
+
+ return self.sagemaker_client.list_feature_groups(**list_feature_groups_args)
+
def update_feature_metadata(
self,
feature_group_name: str,
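list_feature_groups takes every filter explicitly and forwards only the non-None ones to the ListFeatureGroups API. A usage sketch, assuming session is an existing sagemaker.session.Session and the filter values are hypothetical:

    response = session.list_feature_groups(
        name_contains="customers",         # hypothetical name fragment
        feature_group_status_equals="Created",
        offline_store_status_equals=None,  # None values are dropped from the request
        creation_time_after=None,
        creation_time_before=None,
        sort_order="Descending",
        sort_by="CreationTime",
        max_results=10,
        next_token=None,
    )
    for group in response["FeatureGroupSummaries"]:
        print(group["FeatureGroupName"])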
@@ -4399,6 +4473,48 @@ def put_record(
Record=record,
)
+ def delete_record(
+ self,
+ feature_group_name: str,
+ record_identifier_value_as_string: str,
+ event_time: str,
+ ):
+ """Deletes a single record from the FeatureGroup.
+
+ Args:
+ feature_group_name (str): name of the FeatureGroup.
+ record_identifier_value_as_string (str): value of the record identifier.
+ event_time (str): a timestamp indicating when the deletion event occurred.
+ """
+ return self.sagemaker_featurestore_runtime_client.delete_record(
+ FeatureGroupName=feature_group_name,
+ RecordIdentifierValueAsString=record_identifier_value_as_string,
+ EventTime=event_time,
+ )
+
+ def get_record(
+ self,
+ record_identifier_value_as_string: str,
+ feature_group_name: str,
+ feature_names: Sequence[str],
+ ) -> Dict[str, Sequence[Dict[str, str]]]:
+ """Gets a single record in the FeatureGroup.
+
+ Args:
+ record_identifier_value_as_string (str): value of the record identifier.
+ feature_group_name (str): name of the FeatureGroup.
+ feature_names (Sequence[str]): list of feature names.
+ """
+ get_record_args = {
+ "FeatureGroupName": feature_group_name,
+ "RecordIdentifierValueAsString": record_identifier_value_as_string,
+ }
+
+ if feature_names:
+ get_record_args["FeatureNames"] = feature_names
+
+ return self.sagemaker_featurestore_runtime_client.get_record(**get_record_args)
+
def start_query_execution(
self,
catalog: str,
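The two new Feature Store runtime wrappers mirror the GetRecord and DeleteRecord APIs. A sketch with hypothetical group and identifier values, again assuming an existing session:

    record = session.get_record(
        record_identifier_value_as_string="customer-123",
        feature_group_name="customers",
        feature_names=["age", "ltv"],  # pass None to fetch all features
    )

    session.delete_record(
        feature_group_name="customers",
        record_identifier_value_as_string="customer-123",
        event_time="2022-12-14T00:00:00Z",  # when the deletion event occurred
    )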
diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py
index dc3d26a355..912bc90d80 100644
--- a/src/sagemaker/spark/processing.py
+++ b/src/sagemaker/spark/processing.py
@@ -30,6 +30,7 @@
from enum import Enum
from io import BytesIO
from urllib.parse import urlparse
+from copy import copy
from typing import Union, List, Dict, Optional
@@ -279,6 +280,10 @@ def run(
def _extend_processing_args(self, inputs, outputs, **kwargs):
"""Extends processing job args such as inputs."""
+ # make a shallow copy of user outputs
+ outputs = outputs or []
+ extended_outputs = copy(outputs)
+
if kwargs.get("spark_event_logs_s3_uri"):
spark_event_logs_s3_uri = kwargs.get("spark_event_logs_s3_uri")
self._validate_s3_uri(spark_event_logs_s3_uri)
@@ -297,16 +302,21 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
s3_upload_mode="Continuous",
)
- outputs = outputs or []
- outputs.append(output)
+ extended_outputs.append(output)
+
+ # make a shallow copy of user inputs
+ inputs = inputs or []
+ extended_inputs = copy(inputs)
if kwargs.get("configuration"):
configuration = kwargs.get("configuration")
self._validate_configuration(configuration)
- inputs = inputs or []
- inputs.append(self._stage_configuration(configuration))
+ extended_inputs.append(self._stage_configuration(configuration))
- return inputs, outputs
+ return (
+ extended_inputs if extended_inputs else None,
+ extended_outputs if extended_outputs else None,
+ )
def start_history_server(self, spark_event_logs_s3_uri=None):
"""Starts a Spark history server.
@@ -940,9 +950,16 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
outputs: Processing outputs.
kwargs: Additional keyword arguments passed to `super()`.
"""
+
+ if inputs is None:
+ inputs = []
+
+ # make a shallow copy of user inputs
+ extended_inputs = copy(inputs)
+
self.command = [_SparkProcessorBase._default_command]
extended_inputs = self._handle_script_dependencies(
- inputs, kwargs.get("submit_py_files"), FileType.PYTHON
+ extended_inputs, kwargs.get("submit_py_files"), FileType.PYTHON
)
extended_inputs = self._handle_script_dependencies(
extended_inputs, kwargs.get("submit_jars"), FileType.JAR
@@ -1199,8 +1216,14 @@ def _extend_processing_args(self, inputs, outputs, **kwargs):
else:
raise ValueError("submit_class is required")
+ if inputs is None:
+ inputs = []
+
+ # make a shallow copy of user inputs
+ extended_inputs = copy(inputs)
+
extended_inputs = self._handle_script_dependencies(
- inputs, kwargs.get("submit_jars"), FileType.JAR
+ extended_inputs, kwargs.get("submit_jars"), FileType.JAR
)
extended_inputs = self._handle_script_dependencies(
extended_inputs, kwargs.get("submit_files"), FileType.FILE
diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py
index cfcc637b99..40ed143ebc 100644
--- a/src/sagemaker/transformer.py
+++ b/src/sagemaker/transformer.py
@@ -14,17 +14,24 @@
from __future__ import absolute_import
from typing import Union, Optional, List, Dict
-from botocore import exceptions
+import logging
+import copy
+import time
+from botocore import exceptions
from sagemaker.job import _Job
-from sagemaker.session import Session
+from sagemaker.session import Session, get_execution_role
from sagemaker.inputs import BatchDataCaptureConfig
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.functions import Join
-from sagemaker.workflow.pipeline_context import runnable_by_pipeline
+from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.execution_variables import ExecutionVariables
-from sagemaker.utils import base_name_from_image, name_from_base
+from sagemaker.utils import (
+ base_name_from_image,
+ name_from_base,
+ check_and_get_run_experiment_config,
+)
class Transformer(object):
@@ -248,6 +255,7 @@ def transform(
)
self._reset_output_path = True
+ experiment_config = check_and_get_run_experiment_config(experiment_config)
self.latest_transform_job = _TransformJob.start_new(
self,
data,
@@ -266,6 +274,155 @@ def transform(
if wait:
self.latest_transform_job.wait(logs=logs)
+ def transform_with_monitoring(
+ self,
+ monitoring_config,
+ monitoring_resource_config,
+ data: str,
+ data_type: str = "S3Prefix",
+ content_type: str = None,
+ compression_type: str = None,
+ split_type: str = None,
+ input_filter: str = None,
+ output_filter: str = None,
+ join_source: str = None,
+ model_client_config: Dict[str, str] = None,
+ batch_data_capture_config: BatchDataCaptureConfig = None,
+ monitor_before_transform: bool = False,
+ supplied_baseline_statistics: str = None,
+ supplied_baseline_constraints: str = None,
+ wait: bool = True,
+ pipeline_name: str = None,
+ role: str = None,
+ ):
+ """Runs a transform job with monitoring job.
+
+ Note that this function will not start a transform job immediately,
+ instead, it will create a SageMaker Pipeline and execute it.
+ If you provide an existing pipeline_name, no new pipeline will be created, otherwise,
+ each transform_with_monitoring call will create a new pipeline and execute.
+
+ Args:
+ monitoring_config (Union[
+ `sagemaker.workflow.quality_check_step.QualityCheckConfig`,
+ `sagemaker.workflow.clarify_check_step.ClarifyCheckConfig`
+ ]): the monitoring configuration used to run model monitoring.
+ monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`):
+ the check job (processing job) cluster resource configuration.
+ data (str): Input data location in S3 for the transform job
+ data_type (str): What the S3 location defines (default: 'S3Prefix').
+ Valid values:
+ * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix
+ will be used as inputs for the transform job.
+ * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3
+ object to use as an input for the transform job.
+ content_type (str): MIME type of the input data (default: None).
+ compression_type (str): Compression type of the input data, if
+ compressed (default: None). Valid values: 'Gzip', None.
+ split_type (str): The record delimiter for the input object
+ (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and
+ 'TFRecord'.
+ input_filter (str): A JSONPath to select a portion of the input to
+ pass to the algorithm container for inference. If you omit the
+ field, it gets the value '$', representing the entire input.
+ For CSV data, each row is taken as a JSON array,
+ so only index-based JSONPaths can be applied, e.g. $[0], $[1:].
+ CSV data should follow the `RFC format `_.
+ See `Supported JSONPath Operators
+ `_
+ for a table of supported JSONPath operators.
+ For more information, see the SageMaker API documentation for
+ `CreateTransformJob
+ `_.
+ Some examples: "$[1:]", "$.features" (default: None).
+ output_filter (str): A JSONPath to select a portion of the
+ joined/original output to return as the output.
+ For more information, see the SageMaker API documentation for
+ `CreateTransformJob
+ `_.
+ Some examples: "$[1:]", "$.prediction" (default: None).
+ join_source (str): The source of data to be joined to the transform
+ output. It can be set to 'Input' meaning the entire input record
+ will be joined to the inference result. You can use OutputFilter
+ to select the useful portion before uploading to S3. (default:
+ None). Valid values: Input, None.
+ model_client_config (dict[str, str]): Model configuration.
+ Dictionary contains two optional keys,
+ 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'.
+ (default: ``None``).
+ batch_data_capture_config (BatchDataCaptureConfig): Configuration object which
+ specifies the configurations related to the batch data capture for the transform job
+ (default: ``None``).
+ monitor_before_transform (bool): If running a data quality
+ or model explainability monitoring type, a value of ``True``
+ runs the check step before the transform job.
+ fail_on_violation (Union[bool, PipelineVariable]): An opt-out flag to avoid
+ failing the check step when a violation is detected.
+ supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path
+ to the supplied statistics object representing the statistics JSON file
+ which will be used for the drift check (default: None).
+ supplied_baseline_constraints (Union[str, PipelineVariable]): The S3 path
+ to the supplied constraints object representing the constraints JSON file
+ which will be used for the drift check (default: None).
+ wait (bool): Whether to wait for the pipeline execution to complete (default: True).
+ pipeline_name (str): The name of the Pipeline for the monitoring and transform
+ steps (default: None).
+ role (str): The execution role (default: None).
+ """
+
+ transformer = self
+ if not isinstance(self.sagemaker_session, PipelineSession):
+ sagemaker_session = self.sagemaker_session
+ self.sagemaker_session = None
+ transformer = copy.deepcopy(self)
+ transformer.sagemaker_session = PipelineSession()
+ self.sagemaker_session = sagemaker_session
+
+ transform_step_args = transformer.transform(
+ data=data,
+ data_type=data_type,
+ content_type=content_type,
+ compression_type=compression_type,
+ split_type=split_type,
+ input_filter=input_filter,
+ output_filter=output_filter,
+ batch_data_capture_config=batch_data_capture_config,
+ join_source=join_source,
+ model_client_config=model_client_config,
+ )
+
+ from sagemaker.workflow.monitor_batch_transform_step import MonitorBatchTransformStep
+
+ monitoring_batch_step = MonitorBatchTransformStep(
+ name="MonitorBatchTransformStep",
+ display_name="MonitorBatchTransformStep",
+ description="",
+ transform_step_args=transform_step_args,
+ monitor_configuration=monitoring_config,
+ check_job_configuration=monitoring_resource_config,
+ monitor_before_transform=monitor_before_transform,
+ supplied_baseline_constraints=supplied_baseline_constraints,
+ supplied_baseline_statistics=supplied_baseline_statistics,
+ )
+
+ pipeline_name = (
+ pipeline_name if pipeline_name else f"TransformWithMonitoring{int(time.time())}"
+ )
+ # if pipeline exists, just start the execution
+ from sagemaker.workflow.pipeline import Pipeline
+
+ pipeline = Pipeline(
+ name=pipeline_name,
+ steps=[monitoring_batch_step],
+ sagemaker_session=transformer.sagemaker_session,
+ )
+ pipeline.upsert(role_arn=role if role else get_execution_role())
+ execution = pipeline.start()
+ if wait:
+ logging.info("Waiting for transform with monitoring to execute ...")
+ execution.wait()
+ return execution
+
def delete_model(self):
"""Delete the corresponding SageMaker model for this Transformer."""
self.sagemaker_session.delete_model(self.model_name)
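Because transform_with_monitoring builds and upserts a one-step pipeline around the transform, calling it looks like a regular transform call plus the two monitoring configs. A hedged sketch in which transformer, data_quality_config, and check_job_config are pre-built placeholder objects:

    # data_quality_config: a QualityCheckConfig; check_job_config: a CheckJobConfig.
    execution = transformer.transform_with_monitoring(
        monitoring_config=data_quality_config,
        monitoring_resource_config=check_job_config,
        data="s3://my-bucket/batch-input/",  # placeholder S3 prefix
        content_type="text/csv",
        split_type="Line",
        monitor_before_transform=True,       # run the check before transforming
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # placeholder ARN
    )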
diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py
index 52b9d81d0d..9a694cbec9 100644
--- a/src/sagemaker/tuner.py
+++ b/src/sagemaker/tuner.py
@@ -282,8 +282,8 @@ def from_job_desc(cls, hyperband_strategy_config):
Returns:
sagemaker.tuner.HyperbandStrategyConfig: De-serialized instance of
- HyperbandStrategyConfig containing the max_resource and min_resource provided as part of
- ``hyperband_strategy_config``.
+ ``HyperbandStrategyConfig`` containing the max_resource
+ and min_resource provided as part of ``hyperband_strategy_config``.
"""
return cls(
min_resource=hyperband_strategy_config[HYPERBAND_MIN_RESOURCE],
@@ -306,7 +306,7 @@ def to_input_req(self):
Returns:
dict: Containing the "MaxResource" and
- "MinResource" as the first class fields.
+ "MinResource" as the first class fields.
"""
return {
HYPERBAND_MIN_RESOURCE: self.min_resource,
@@ -330,7 +330,7 @@ def __init__(
Args:
hyperband_strategy_config (sagemaker.tuner.HyperbandStrategyConfig): The configuration
- for the object that specifies the Hyperband strategy.
+ for the object that specifies the Hyperband strategy.
This parameter is only supported for the Hyperband selection for Strategy within
the HyperParameterTuningJobConfig.
"""
@@ -461,7 +461,7 @@ def __init__(
``WarmStartConfig`` object that has been initialized with the
configuration defining the nature of warm start tuning job.
strategy_config (sagemaker.tuner.StrategyConfig): A configuration for "Hyperparameter"
- tuning job optimisation strategy.
+ tuning job optimisation strategy.
early_stopping_type (str or PipelineVariable): Specifies whether early stopping is
enabled for the job. Can be either 'Auto' or 'Off' (default:
'Off'). If set to 'Off', early stopping will not be attempted.
@@ -1569,7 +1569,7 @@ def create(
strategy (str): Strategy to be used for hyperparameter estimations
(default: 'Bayesian').
strategy_config (dict): The configuration for a training job launched by a
- hyperparameter tuning job.
+ hyperparameter tuning job.
objective_type (str): The type of the objective metric for evaluating training jobs.
This value can be either 'Minimize' or 'Maximize' (default: 'Maximize').
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
@@ -1776,7 +1776,7 @@ def _get_tuner_args(cls, tuner, inputs):
}
if tuner.strategy_config is not None:
- tuning_config["strategy_config"] = tuner.strategy_config
+ tuning_config["strategy_config"] = tuner.strategy_config.to_input_req()
if tuner.objective_metric_name is not None:
tuning_config["objective_type"] = tuner.objective_type
diff --git a/src/sagemaker/utilities/search_expression.py b/src/sagemaker/utilities/search_expression.py
new file mode 100644
index 0000000000..5b2aaf3226
--- /dev/null
+++ b/src/sagemaker/utilities/search_expression.py
@@ -0,0 +1,133 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Simplify Search Expression by provide a simplified DSL"""
+from __future__ import absolute_import
+
+from enum import Enum, unique
+
+from sagemaker.apiutils._base_types import ApiObject
+
+
+# TODO: we should update the lineage to use search expressions
+# defined here in a separate change
+@unique
+class Operator(Enum):
+ """Search operators"""
+
+ EQUALS = "Equals"
+ NOT_EQUALS = "NotEquals"
+ GREATER_THAN = "GreaterThan"
+ GREATER_THAN_OR_EQUAL = "GreaterThanOrEqualTo"
+ LESS_THAN = "LessThan"
+ LESS_THAN_OR_EQUAL = "LessThanOrEqualTo"
+ CONTAINS = "Contains"
+ EXISTS = "Exists"
+ NOT_EXISTS = "NotExists"
+
+
+@unique
+class BooleanOperator(Enum):
+ """Boolean search operation enum"""
+
+ AND = "And"
+ OR = "Or"
+
+
+class SearchObject(ApiObject):
+ """Search Object"""
+
+ def to_boto(self):
+ """Convert a search object to boto"""
+ return ApiObject.to_boto(self)
+
+
+class Filter(SearchObject):
+ """A Python class represent a Search Filter object."""
+
+ name = None
+ operator = None
+ value = None
+
+ def __init__(self, name, operator=None, value=None, **kwargs):
+ """Construct a Filter object
+
+ Args:
+ name (str): filter field name
+ operator (Operator): one of Operator enum
+ value (str): value of the field
+ """
+ super().__init__(**kwargs)
+ self.name = name
+ self.operator = None if operator is None else operator.value
+ self.value = value
+
+
+class NestedFilter(SearchObject):
+ """A Python class represent a Nested Filter object."""
+
+ nested_property_name = None
+ filters = None
+
+ def __init__(self, property_name, filters, **kwargs):
+ """Construct a Nested Filter object
+
+ Args:
+ property_name (str): nested property name
+ filters (List[Filter]): list of Filter objects
+ """
+ super().__init__(**kwargs)
+ self.nested_property_name = property_name
+ self.filters = list(map(lambda x: x.to_boto(), filters))
+
+
+class SearchExpression(SearchObject):
+ """A Python class representation of a Search Expression object.
+
+ A sample search expression is defined here:
+ https://boto3.amazonaws.com/v1/documentation/api/1.12.8/reference/services/sagemaker.html#SageMaker.Client.search
+ """
+
+ filters = None
+ nested_filters = None
+ operator = None
+ sub_expressions = None
+
+ def __init__(
+ self,
+ filters=None,
+ nested_filters=None,
+ sub_expressions=None,
+ boolean_operator=BooleanOperator.AND,
+ **kwargs
+ ):
+ """Construct a Search Expression object
+
+ Args:
+ filters (List[Filter]): list of Filter objects
+ nested_filters (List[NestedFilter]): list of Nested Filters objects
+ sub_expressions (List[SearchExpression]): list of Search Expression objects
+ boolean_operator (BooleanOperator): one of the boolean operator enums
+ """
+ super().__init__(**kwargs)
+ if filters is None and nested_filters is None and sub_expressions is None:
+ raise ValueError(
+ "You must specify at least one subexpression, filter, or nested filter"
+ )
+ self.filters = None if filters is None else list(map(lambda x: x.to_boto(), filters))
+ self.nested_filters = (
+ None if nested_filters is None else list(map(lambda x: x.to_boto(), nested_filters))
+ )
+ self.sub_expressions = (
+ None if sub_expressions is None else list(map(lambda x: x.to_boto(), sub_expressions))
+ )
+ self.operator = boolean_operator.value
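Putting the DSL together, a search expression is a tree of filters joined by a boolean operator, serialized to the boto shape accepted by the Search API. A small sketch with a hypothetical experiment name:

    from sagemaker.utilities.search_expression import (
        BooleanOperator,
        Filter,
        Operator,
        SearchExpression,
    )

    expression = SearchExpression(
        filters=[Filter(name="ExperimentName", operator=Operator.EQUALS, value="my-exp")],
        boolean_operator=BooleanOperator.AND,
    )
    # Boto-style dict, e.g. {'Filters': [...], 'Operator': 'And'}
    print(expression.to_boto())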
diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py
index e668b2a8ed..9d28e3bf4e 100644
--- a/src/sagemaker/utils.py
+++ b/src/sagemaker/utils.py
@@ -29,6 +29,7 @@
from datetime import datetime
from typing import Optional
+from importlib import import_module
import botocore
from six.moves.urllib import parse
@@ -590,6 +591,27 @@ def retries(
)
+def retry_with_backoff(callable_func, num_attempts=8):
+ """Retry with backoff until maximum attempts are reached
+
+ Args:
+ callable_func (callable): The callable function to retry.
+ num_attempts (int): The maximum number of attempts to retry.
+ """
+ if num_attempts < 1:
+ raise ValueError(
+ "The num_attempts must be >= 1, but the given value is {}.".format(num_attempts)
+ )
+ for i in range(num_attempts):
+ try:
+ return callable_func()
+ except Exception as ex: # pylint: disable=broad-except
+ if i == num_attempts - 1:
+ raise ex
+ logger.error("Retrying in attempt %s, due to %s", (i + 1), str(ex))
+ time.sleep(2**i)
+
+
def _botocore_resolver():
"""Get the DNS suffix for the given region.
@@ -874,3 +896,47 @@ def _start_waiting(waiting_time: int):
print(progress, end="\r")
time.sleep(interval)
print(len(progress) * " ", end="\r")
+
+
+def get_module(module_name):
+ """Import a module.
+
+ Args:
+ module_name (str): name of the module to import.
+
+ Returns:
+ object: The imported module.
+
+ Raises:
+ Exception: when the module name is not found
+ """
+ try:
+ return import_module(module_name)
+ except ImportError:
+ raise Exception("Cannot import module {}, please try again.".format(module_name))
+
+
+def check_and_get_run_experiment_config(experiment_config: Optional[dict] = None) -> dict:
+ """Check user input experiment_config or get it from the current Run object if exists.
+
+ Args:
+ experiment_config (dict): The experiment_config supplied by the user.
+
+ Returns:
+ dict: Return the user-supplied experiment_config if it is not None.
+ Otherwise fetch the experiment_config from the current Run object if one exists.
+ """
+ from sagemaker.experiments._run_context import _RunContext
+
+ run_obj = _RunContext.get_current_run()
+ if experiment_config:
+ if run_obj:
+ logger.warning(
+ "The function is invoked within an Experiment Run context "
+ "but another experiment_config (%s) was supplied, so "
+ "ignoring the experiment_config fetched from the Run object.",
+ experiment_config,
+ )
+ return experiment_config
+
+ return run_obj.experiment_config if run_obj else None
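
A quick sketch of the intended use of the new retry helper, assuming an eventually consistent Describe call (the client and operation are illustrative):

```python
import boto3

from sagemaker.utils import retry_with_backoff

client = boto3.client("sagemaker")

def describe_tc():
    # May raise ResourceNotFound right after creation; retried with backoff.
    return client.describe_trial_component(TrialComponentName="my-component")

# Sleeps 2**i seconds between failed attempts (1s, 2s, 4s here) and re-raises
# the last exception if the final attempt also fails.
result = retry_with_backoff(describe_tc, num_attempts=4)
```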
diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py
index 8ba65f1eee..cdef9537c1 100644
--- a/src/sagemaker/workflow/_utils.py
+++ b/src/sagemaker/workflow/_utils.py
@@ -13,6 +13,7 @@
"""Scrapper utilities to support repacking of models."""
from __future__ import absolute_import
+import logging
import os
import shutil
import tarfile
@@ -37,6 +38,8 @@
if TYPE_CHECKING:
from sagemaker.workflow.step_collections import StepCollection
+logger = logging.getLogger(__name__)
+
FRAMEWORK_VERSION = "0.23-1"
INSTANCE_TYPE = "ml.m5.large"
REPACK_SCRIPT = "_repack_model.py"
@@ -479,10 +482,19 @@ def arguments(self) -> RequestType:
request_dict = get_create_model_package_request(**model_package_args)
# these are not available in the workflow service and will cause rejection
+ warn_msg_template = (
+ "Popping out '%s' from the pipeline definition "
+ "since it will be overridden in pipeline execution time."
+ )
if "CertifyForMarketplace" in request_dict:
request_dict.pop("CertifyForMarketplace")
+ logger.warning(warn_msg_template, "CertifyForMarketplace")
if "Description" in request_dict:
request_dict.pop("Description")
+ logger.warning(warn_msg_template, "Description")
+ if "ModelPackageName" in request_dict:
+ request_dict.pop("ModelPackageName")
+ logger.warning(warn_msg_template, "ModelPackageName")
return request_dict
diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py
index 89d7c5dfd9..08c170d424 100644
--- a/src/sagemaker/workflow/utilities.py
+++ b/src/sagemaker/workflow/utilities.py
@@ -114,11 +114,12 @@ def get_code_hash(step: Entity) -> str:
if isinstance(step, ProcessingStep) and step.step_args:
kwargs = step.step_args.func_kwargs
source_dir = kwargs.get("source_dir")
+ submit_class = kwargs.get("submit_class")
dependencies = get_processing_dependencies(
[
kwargs.get("dependencies"),
kwargs.get("submit_py_files"),
- kwargs.get("submit_class"),
+ [submit_class] if submit_class else None,
kwargs.get("submit_jars"),
kwargs.get("submit_files"),
]
@@ -168,7 +169,7 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str]
str: A hash string representing the unique code artifact(s) for the step
"""
- # FrameworkProcessor
+ # If the FrameworkProcessor specifies a source_dir
if source_dir:
source_dir_url = urlparse(source_dir)
if source_dir_url.scheme == "" or source_dir_url.scheme == "file":
@@ -400,5 +401,5 @@ def execute_job_functions(step_args: _StepArguments):
"""
chained_args = step_args.func(*step_args.func_args, **step_args.func_kwargs)
- if chained_args:
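+ # Recurse only when the chained call returned another _StepArguments; plain return values end the chain.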
+ if isinstance(chained_args, _StepArguments):
execute_job_functions(chained_args)
diff --git a/tests/conftest.py b/tests/conftest.py
index e92d98112b..f6682ebb8c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -73,6 +73,7 @@
"neo_pytorch",
"neo_tensorflow",
"pytorch",
+ "pytorch_training_compiler",
"ray_pytorch",
"ray_tensorflow",
"sklearn",
diff --git a/tests/data/experiment/inference.py b/tests/data/experiment/inference.py
new file mode 100644
index 0000000000..cdb9a7b8c6
--- /dev/null
+++ b/tests/data/experiment/inference.py
@@ -0,0 +1,85 @@
+# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License").
+# You may not use this file except in compliance with the License.
+# A copy of the License is located at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# or in the "license" file accompanying this file. This file is distributed
+# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
+# express or implied. See the License for the specific language governing
+# permissions and limitations under the License.
+import logging
+import os
+import pickle as pkl
+
+import boto3
+import numpy as np
+import sagemaker_xgboost_container.encoder as xgb_encoders
+
+sdk_name = "sagemaker-dev-1.0.tar.gz"
+code_dir = "/opt/ml/code"
+
+sdk_file = f"{code_dir}/{sdk_name}"
+os.system(f"pip install {sdk_file}")
+
+from sagemaker.session import Session
+from sagemaker.experiments import load_run
+
+boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
+sagemaker_session = Session(boto_session=boto_session)
+
+
+def model_fn(model_dir):
+ """
+ Deserialize and return fitted model.
+ """
+ with load_run(
+ experiment_name=os.environ["EXPERIMENT_NAME"],
+ run_name=os.environ["RUN_NAME"],
+ sagemaker_session=sagemaker_session,
+ ) as run:
+ logging.info(f"Run name: {run.run_name}")
+ logging.info(f"Experiment name: {run.experiment_name}")
+ logging.info(f"Trial component name: {run._trial_component.trial_component_name}")
+ run.log_parameters({"p3": 3.0, "p4": 4.0})
+ run.log_metric("test-job-load-log-metric", 0.1)
+
+ model_file = "xgboost-model"
+ booster = pkl.load(open(os.path.join(model_dir, model_file), "rb"))
+ return booster
+
+
+def input_fn(request_body, request_content_type):
+ """
+ The SageMaker XGBoost model server receives the request data body and the content type,
+ and invokes the `input_fn`.
+ Return a DMatrix (an object that can be passed to predict_fn).
+ """
+ if request_content_type == "text/libsvm":
+ return xgb_encoders.libsvm_to_dmatrix(request_body)
+ else:
+ raise ValueError("Content type {} is not supported.".format(request_content_type))
+
+
+def predict_fn(input_data, model):
+ """
+ SageMaker XGBoost model server invokes `predict_fn` on the return value of `input_fn`.
+ Return a two-dimensional NumPy array where the first columns are predictions
+ and the remaining columns are the feature contributions (SHAP values) for that prediction.
+ """
+ prediction = model.predict(input_data)
+ feature_contribs = model.predict(input_data, pred_contribs=True, validate_features=False)
+ output = np.hstack((prediction[:, np.newaxis], feature_contribs))
+ return output
+
+
+def output_fn(predictions, content_type):
+ """
+ After invoking predict_fn, the model server invokes `output_fn`.
+ """
+ if content_type == "text/csv" or content_type == "application/json":
+ return ",".join(str(x) for x in predictions[0])
+ else:
+ raise ValueError("Content type {} is not supported.".format(content_type))
diff --git a/tests/data/experiment/process_job_script_for_run_clz.py b/tests/data/experiment/process_job_script_for_run_clz.py
new file mode 100644
index 0000000000..32fd0ab4f6
--- /dev/null
+++ b/tests/data/experiment/process_job_script_for_run_clz.py
@@ -0,0 +1,37 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This script file runs on SageMaker processing job"""
+from __future__ import absolute_import
+
+import logging
+import os
+import boto3
+
+sdk_file = "sagemaker-dev-1.0.tar.gz"
+os.system(f"pip install {sdk_file}")
+
+
+from sagemaker import Session
+from sagemaker.experiments import load_run
+
+
+boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
+sagemaker_session = Session(boto_session=boto_session)
+
+
+with load_run(sagemaker_session=sagemaker_session) as run:
+ logging.info(f"Run name: {run.run_name}")
+ logging.info(f"Experiment name: {run.experiment_name}")
+ logging.info(f"Trial component name: {run._trial_component.trial_component_name}")
+ run.log_parameters({"p3": 3.0, "p4": 4.0})
+ run.log_metric("test-job-load-log-metric", 0.1)
diff --git a/tests/data/experiment/train_job_script_for_run_clz.py b/tests/data/experiment/train_job_script_for_run_clz.py
new file mode 100644
index 0000000000..34c86e0993
--- /dev/null
+++ b/tests/data/experiment/train_job_script_for_run_clz.py
@@ -0,0 +1,71 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""This script file runs on SageMaker training job"""
+from __future__ import absolute_import
+
+import logging
+import time
+import os
+import boto3
+
+sdk_file = "sagemaker-dev-1.0.tar.gz"
+os.system(f"pip install {sdk_file}")
+
+from sagemaker import Session
+from sagemaker.experiments import load_run, Run
+
+boto_session = boto3.Session(region_name=os.environ["AWS_REGION"])
+sagemaker_session = Session(boto_session=boto_session)
+
+if os.environ["RUN_OPERATION"] == "init":
+ logging.info("Initializing a Run")
+ with Run(
+ experiment_name=os.environ["EXPERIMENT_NAME"],
+ run_name=os.environ["RUN_NAME"],
+ sagemaker_session=sagemaker_session,
+ ) as run:
+ logging.info(f"Run name: {run.run_name}")
+ logging.info(f"Experiment name: {run.experiment_name}")
+ logging.info(f"Trial component name: {run._trial_component.trial_component_name}")
+ run.log_parameter("p1", 1.0)
+ run.log_parameter("p2", 2)
+
+ for i in range(2):
+ run.log_metric("A", i)
+ for i in range(2):
+ run.log_metric("B", i)
+ for i in range(2):
+ run.log_metric("C", i)
+ for i in range(2):
+ time.sleep(0.003)
+ run.log_metric("D", i)
+ for i in range(2):
+ time.sleep(0.003)
+ run.log_metric("E", i)
+ time.sleep(15)
+
+else:
+ logging.info("Loading a Run")
+ logging.info("Invoking load_run with name arguments")
+ with load_run(
+ experiment_name=os.environ["EXPERIMENT_NAME"],
+ run_name=os.environ["RUN_NAME"],
+ sagemaker_session=sagemaker_session,
+ ) as run:
+ run.log_parameters({"p3": 3.0, "p4": 4})
+ run.log_metric("test-job-load-log-metric", 0.1)
+
+ if os.environ.get("CALL_RUN_LOAD_WITH_NO_NAME_ARGS", None) == "True":
+ logging.info("Invoking load_run without name arguments")
+ with load_run(sagemaker_session=sagemaker_session) as run:
+ run.log_parameters({"p5": 5.0, "p6": 6})
diff --git a/tests/data/experiment/transform_job_materials/data.csv b/tests/data/experiment/transform_job_materials/data.csv
new file mode 100644
index 0000000000..9f1b6c0bb0
--- /dev/null
+++ b/tests/data/experiment/transform_job_materials/data.csv
@@ -0,0 +1 @@
+-99 1:3 2:0.37 3:0.29 4:0.095 5:0.249 6:0.1045 7:0.058 8:0.067
\ No newline at end of file
diff --git a/tests/data/experiment/transform_job_materials/xgb_model.tar.gz b/tests/data/experiment/transform_job_materials/xgb_model.tar.gz
new file mode 100644
index 0000000000..3969bede9e
Binary files /dev/null and b/tests/data/experiment/transform_job_materials/xgb_model.tar.gz differ
diff --git a/tests/data/huggingface_byoc/requirements.txt b/tests/data/huggingface_byoc/requirements.txt
new file mode 100644
index 0000000000..462542f1c1
--- /dev/null
+++ b/tests/data/huggingface_byoc/requirements.txt
@@ -0,0 +1,2 @@
+transformers
+datasets
diff --git a/tests/data/huggingface_byoc/run_glue.py b/tests/data/huggingface_byoc/run_glue.py
new file mode 100644
index 0000000000..1060398fa4
--- /dev/null
+++ b/tests/data/huggingface_byoc/run_glue.py
@@ -0,0 +1,568 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Finetuning the library models for sequence classification on GLUE."""
+# You can also adapt this script on your own text classification task. Pointers for this are left as comments.
+
+import logging
+import os
+import random
+import sys
+from dataclasses import dataclass, field
+from typing import Optional
+
+import numpy as np
+from datasets import load_dataset, load_metric
+
+import transformers
+from transformers import (
+ AutoConfig,
+ AutoModelForSequenceClassification,
+ AutoTokenizer,
+ DataCollatorWithPadding,
+ EvalPrediction,
+ HfArgumentParser,
+ PretrainedConfig,
+ Trainer,
+ TrainingArguments,
+ default_data_collator,
+ set_seed,
+)
+from transformers.trainer_utils import get_last_checkpoint, is_main_process
+from transformers.utils import check_min_version
+
+
+# Will error if the minimum version of Transformers is not installed. Remove at your own risk.
+check_min_version("4.5.0")
+
+task_to_keys = {
+ "cola": ("sentence", None),
+ "mnli": ("premise", "hypothesis"),
+ "mrpc": ("sentence1", "sentence2"),
+ "qnli": ("question", "sentence"),
+ "qqp": ("question1", "question2"),
+ "rte": ("sentence1", "sentence2"),
+ "sst2": ("sentence", None),
+ "stsb": ("sentence1", "sentence2"),
+ "wnli": ("sentence1", "sentence2"),
+}
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class DataTrainingArguments:
+ """
+ Arguments pertaining to what data we are going to input our model for training and eval.
+
+ Using `HfArgumentParser` we can turn this class
+ into argparse arguments to be able to specify them on
+ the command line.
+ """
+
+ task_name: Optional[str] = field(
+ default=None,
+ metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
+ )
+ max_seq_length: int = field(
+ default=128,
+ metadata={
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
+ "than this will be truncated, sequences shorter will be padded."
+ },
+ )
+ overwrite_cache: bool = field(
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
+ )
+ pad_to_max_length: bool = field(
+ default=True,
+ metadata={
+ "help": "Whether to pad all samples to `max_seq_length`. "
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
+ },
+ )
+ max_train_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
+ "value if set."
+ },
+ )
+ max_val_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "For debugging purposes or quicker training, truncate the number of validation examples to this "
+ "value if set."
+ },
+ )
+ max_test_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "For debugging purposes or quicker training, truncate the number of test examples to this "
+ "value if set."
+ },
+ )
+ train_file: Optional[str] = field(
+ default=None, metadata={"help": "A csv or a json file containing the training data."}
+ )
+ validation_file: Optional[str] = field(
+ default=None, metadata={"help": "A csv or a json file containing the validation data."}
+ )
+ test_file: Optional[str] = field(
+ default=None, metadata={"help": "A csv or a json file containing the test data."}
+ )
+
+ def __post_init__(self):
+ if self.task_name is not None:
+ self.task_name = self.task_name.lower()
+ if self.task_name not in task_to_keys.keys():
+ raise ValueError(
+ "Unknown task, you should pick one in " + ",".join(task_to_keys.keys())
+ )
+ elif self.train_file is None or self.validation_file is None:
+ raise ValueError("Need either a GLUE task or a training/validation file.")
+ else:
+ train_extension = self.train_file.split(".")[-1]
+ assert train_extension in [
+ "csv",
+ "json",
+ ], "`train_file` should be a csv or a json file."
+ validation_extension = self.validation_file.split(".")[-1]
+ assert (
+ validation_extension == train_extension
+ ), "`validation_file` should have the same extension (csv or json) as `train_file`."
+
+
+@dataclass
+class ModelArguments:
+ """
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
+ """
+
+ model_name_or_path: str = field(
+ metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
+ )
+ config_name: Optional[str] = field(
+ default=None,
+ metadata={"help": "Pretrained config name or path if not the same as model_name"},
+ )
+ tokenizer_name: Optional[str] = field(
+ default=None,
+ metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
+ )
+ cache_dir: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "Where do you want to store the pretrained models downloaded from huggingface.co"
+ },
+ )
+ use_fast_tokenizer: bool = field(
+ default=True,
+ metadata={
+ "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
+ },
+ )
+ model_revision: str = field(
+ default="main",
+ metadata={
+ "help": "The specific model version to use (can be a branch name, tag name or commit id)."
+ },
+ )
+ use_auth_token: bool = field(
+ default=False,
+ metadata={
+ "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
+ "with private models)."
+ },
+ )
+
+
+def main():
+ # See all possible arguments in src/transformers/training_args.py
+ # or by passing the --help flag to this script.
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
+ # If we pass only one argument to the script and it's the path to a json file,
+ # let's parse it to get our arguments.
+ model_args, data_args, training_args = parser.parse_json_file(
+ json_file=os.path.abspath(sys.argv[1])
+ )
+ else:
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+ # Detecting last checkpoint.
+ last_checkpoint = None
+ if (
+ os.path.isdir(training_args.output_dir)
+ and training_args.do_train
+ and not training_args.overwrite_output_dir
+ ):
+ last_checkpoint = get_last_checkpoint(training_args.output_dir)
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+ raise ValueError(
+ f"Output directory ({training_args.output_dir}) already exists and is not empty. "
+ "Use --overwrite_output_dir to overcome."
+ )
+ elif last_checkpoint is not None:
+ logger.info(
+ f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
+ "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
+ )
+
+ # Setup logging
+ logging.basicConfig(
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+ datefmt="%m/%d/%Y %H:%M:%S",
+ handlers=[logging.StreamHandler(sys.stdout)],
+ )
+ logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
+
+ # Log on each process the small summary:
+ logger.warning(
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
+ )
+ # Set the verbosity to info of the Transformers logger (on main process only):
+ if is_main_process(training_args.local_rank):
+ transformers.utils.logging.set_verbosity_info()
+ transformers.utils.logging.enable_default_handler()
+ transformers.utils.logging.enable_explicit_format()
+ logger.info(f"Training/evaluation parameters {training_args}")
+
+ # Set seed before initializing model.
+ set_seed(training_args.seed)
+
+ # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
+ # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
+ #
+ # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the
+ # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named
+ # label if at least two columns are provided.
+ #
+ # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this
+ # single column. You can easily tweak this behavior (see below)
+ #
+ # In distributed training, the load_dataset function guarantee that only one local process can concurrently
+ # download the dataset.
+ if data_args.task_name is not None:
+ # Downloading and loading a dataset from the hub.
+ datasets = load_dataset("glue", data_args.task_name)
+ else:
+ # Loading a dataset from your local files.
+ # CSV/JSON training and evaluation files are needed.
+ data_files = {"train": data_args.train_file, "validation": data_args.validation_file}
+
+ # Get the test dataset: you can provide your own CSV/JSON test file (see below)
+ # when you use `do_predict` without specifying a GLUE benchmark task.
+ if training_args.do_predict:
+ if data_args.test_file is not None:
+ train_extension = data_args.train_file.split(".")[-1]
+ test_extension = data_args.test_file.split(".")[-1]
+ assert (
+ test_extension == train_extension
+ ), "`test_file` should have the same extension (csv or json) as `train_file`."
+ data_files["test"] = data_args.test_file
+ else:
+ raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
+
+ for key in data_files.keys():
+ logger.info(f"load a local file for {key}: {data_files[key]}")
+
+ if data_args.train_file.endswith(".csv"):
+ # Loading a dataset from local csv files
+ datasets = load_dataset("csv", data_files=data_files)
+ else:
+ # Loading a dataset from local json files
+ datasets = load_dataset("json", data_files=data_files)
+ # See more about loading any type of standard or custom dataset at
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
+
+ # Labels
+ if data_args.task_name is not None:
+ is_regression = data_args.task_name == "stsb"
+ if not is_regression:
+ label_list = datasets["train"].features["label"].names
+ num_labels = len(label_list)
+ else:
+ num_labels = 1
+ else:
+ # Trying to have good defaults here, don't hesitate to tweak to your needs.
+ is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"]
+ if is_regression:
+ num_labels = 1
+ else:
+ # A useful fast method:
+ # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique
+ label_list = datasets["train"].unique("label")
+ label_list.sort() # Let's sort it for determinism
+ num_labels = len(label_list)
+
+ # Load pretrained model and tokenizer
+ #
+ # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
+ # download model & vocab.
+ config = AutoConfig.from_pretrained(
+ model_args.config_name if model_args.config_name else model_args.model_name_or_path,
+ num_labels=num_labels,
+ finetuning_task=data_args.task_name,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
+ cache_dir=model_args.cache_dir,
+ use_fast=model_args.use_fast_tokenizer,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+ model = AutoModelForSequenceClassification.from_pretrained(
+ model_args.model_name_or_path,
+ from_tf=bool(".ckpt" in model_args.model_name_or_path),
+ config=config,
+ cache_dir=model_args.cache_dir,
+ revision=model_args.model_revision,
+ use_auth_token=True if model_args.use_auth_token else None,
+ )
+
+ # Preprocessing the datasets
+ if data_args.task_name is not None:
+ sentence1_key, sentence2_key = task_to_keys[data_args.task_name]
+ else:
+ # Again, we try to have some nice defaults but don't hesitate to tweak to your use case.
+ non_label_column_names = [
+ name for name in datasets["train"].column_names if name != "label"
+ ]
+ if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names:
+ sentence1_key, sentence2_key = "sentence1", "sentence2"
+ else:
+ if len(non_label_column_names) >= 2:
+ sentence1_key, sentence2_key = non_label_column_names[:2]
+ else:
+ sentence1_key, sentence2_key = non_label_column_names[0], None
+
+ # Padding strategy
+ if data_args.pad_to_max_length:
+ padding = "max_length"
+ else:
+ # We will pad later, dynamically at batch creation, to the max sequence length in each batch
+ padding = False
+
+ # Some models have set the order of the labels to use, so let's make sure we do use it.
+ label_to_id = None
+ if (
+ model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id
+ and data_args.task_name is not None
+ and not is_regression
+ ):
+ # Some have all caps in their config, some don't.
+ label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
+ if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
+ label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
+ else:
+ logger.warning(
+ "Your model seems to have been trained with labels, but they don't match the dataset: "
+ f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+ "\nIgnoring the model labels as a result."
+ )
+ elif data_args.task_name is None and not is_regression:
+ label_to_id = {v: i for i, v in enumerate(label_list)}
+
+ if data_args.max_seq_length > tokenizer.model_max_length:
+ logger.warning(
+ f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
+ )
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
+
+ def preprocess_function(examples):
+ # Tokenize the texts
+ args = (
+ (examples[sentence1_key],)
+ if sentence2_key is None
+ else (examples[sentence1_key], examples[sentence2_key])
+ )
+ result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True)
+
+ # Map labels to IDs (not necessary for GLUE tasks)
+ if label_to_id is not None and "label" in examples:
+ result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
+ return result
+
+ datasets = datasets.map(
+ preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache
+ )
+ if training_args.do_train:
+ if "train" not in datasets:
+ raise ValueError("--do_train requires a train dataset")
+ train_dataset = datasets["train"]
+ if data_args.max_train_samples is not None:
+ train_dataset = train_dataset.select(range(data_args.max_train_samples))
+
+ if training_args.do_eval:
+ if "validation" not in datasets and "validation_matched" not in datasets:
+ raise ValueError("--do_eval requires a validation dataset")
+ eval_dataset = datasets[
+ "validation_matched" if data_args.task_name == "mnli" else "validation"
+ ]
+ if data_args.max_val_samples is not None:
+ eval_dataset = eval_dataset.select(range(data_args.max_val_samples))
+
+ if (
+ training_args.do_predict
+ or data_args.task_name is not None
+ or data_args.test_file is not None
+ ):
+ if "test" not in datasets and "test_matched" not in datasets:
+ raise ValueError("--do_predict requires a test dataset")
+ test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"]
+ if data_args.max_test_samples is not None:
+ test_dataset = test_dataset.select(range(data_args.max_test_samples))
+
+ # Log a few random samples from the training set:
+ if training_args.do_train:
+ for index in random.sample(range(len(train_dataset)), 3):
+ logger.info(f"Sample {index} of the training set: {train_dataset[index]}.")
+
+ # Get the metric function
+ if data_args.task_name is not None:
+ metric = load_metric("glue", data_args.task_name)
+ # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from
+ # compute_metrics
+
+ # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
+ # predictions and label_ids field) and has to return a dictionary string to float.
+ def compute_metrics(p: EvalPrediction):
+ preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
+ preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)
+ if data_args.task_name is not None:
+ result = metric.compute(predictions=preds, references=p.label_ids)
+ if len(result) > 1:
+ result["combined_score"] = np.mean(list(result.values())).item()
+ return result
+ elif is_regression:
+ return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
+ else:
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
+
+ # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding.
+ if data_args.pad_to_max_length:
+ data_collator = default_data_collator
+ elif training_args.fp16:
+ data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
+ else:
+ data_collator = None
+
+ # Initialize our Trainer
+ trainer = Trainer(
+ model=model,
+ args=training_args,
+ train_dataset=train_dataset if training_args.do_train else None,
+ eval_dataset=eval_dataset if training_args.do_eval else None,
+ compute_metrics=compute_metrics,
+ tokenizer=tokenizer,
+ data_collator=data_collator,
+ )
+
+ # Training
+ if training_args.do_train:
+ checkpoint = None
+ if last_checkpoint is not None:
+ checkpoint = last_checkpoint
+ elif os.path.isdir(model_args.model_name_or_path):
+ # Check the config from that potential checkpoint has the right number of labels before using it as a
+ # checkpoint.
+ if AutoConfig.from_pretrained(model_args.model_name_or_path).num_labels == num_labels:
+ checkpoint = model_args.model_name_or_path
+
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
+ metrics = train_result.metrics
+ max_train_samples = (
+ data_args.max_train_samples
+ if data_args.max_train_samples is not None
+ else len(train_dataset)
+ )
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
+
+ trainer.save_model() # Saves the tokenizer too for easy upload
+
+ trainer.log_metrics("train", metrics)
+ trainer.save_metrics("train", metrics)
+ trainer.save_state()
+
+ # Evaluation
+ if training_args.do_eval:
+ logger.info("*** Evaluate ***")
+
+ # Loop to handle MNLI double evaluation (matched, mis-matched)
+ tasks = [data_args.task_name]
+ eval_datasets = [eval_dataset]
+ if data_args.task_name == "mnli":
+ tasks.append("mnli-mm")
+ eval_datasets.append(datasets["validation_mismatched"])
+
+ for eval_dataset, task in zip(eval_datasets, tasks):
+ metrics = trainer.evaluate(eval_dataset=eval_dataset)
+
+ max_val_samples = (
+ data_args.max_val_samples
+ if data_args.max_val_samples is not None
+ else len(eval_dataset)
+ )
+ metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
+
+ trainer.log_metrics("eval", metrics)
+ trainer.save_metrics("eval", metrics)
+
+ if training_args.do_predict:
+ logger.info("*** Test ***")
+
+ # Loop to handle MNLI double evaluation (matched, mis-matched)
+ tasks = [data_args.task_name]
+ test_datasets = [test_dataset]
+ if data_args.task_name == "mnli":
+ tasks.append("mnli-mm")
+ test_datasets.append(datasets["test_mismatched"])
+
+ for test_dataset, task in zip(test_datasets, tasks):
+ # Removing the `label` column because it contains -1 and Trainer won't like that.
+ test_dataset.remove_columns_("label")
+ predictions = trainer.predict(test_dataset=test_dataset).predictions
+ predictions = (
+ np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1)
+ )
+
+ output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt")
+ if trainer.is_world_process_zero():
+ with open(output_test_file, "w") as writer:
+ logger.info(f"***** Test results {task} *****")
+ writer.write("index\tprediction\n")
+ for index, item in enumerate(predictions):
+ if is_regression:
+ writer.write(f"{index}\t{item:3.3f}\n")
+ else:
+ item = label_list[item]
+ writer.write(f"{index}\t{item}\n")
+
+
+def _mp_fn(index):
+ # For xla_spawn (TPUs)
+ main()
+
+
+if __name__ == "__main__":
+ main()
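
run_glue.py is test data for the HuggingFace BYOC path. A sketch of how it is typically launched; the estimator versions and hyperparameters are illustrative, not prescribed by this PR:

```python
from sagemaker.huggingface import HuggingFace

estimator = HuggingFace(
    entry_point="run_glue.py",
    source_dir="tests/data/huggingface_byoc",  # requirements.txt is installed automatically
    role="SageMakerRole",
    instance_type="ml.p3.2xlarge",
    instance_count=1,
    transformers_version="4.6.1",
    pytorch_version="1.7.1",
    py_version="py36",
    hyperparameters={
        "model_name_or_path": "distilbert-base-uncased",
        "task_name": "sst2",
        "do_train": True,
        "do_eval": True,
        "max_train_samples": 128,
        "output_dir": "/opt/ml/model",
    },
)
estimator.fit()
```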
diff --git a/tests/data/huggingface_byoc/train/dummy.csv b/tests/data/huggingface_byoc/train/dummy.csv
new file mode 100644
index 0000000000..fb1539d552
--- /dev/null
+++ b/tests/data/huggingface_byoc/train/dummy.csv
@@ -0,0 +1 @@
+# dummy data
\ No newline at end of file
diff --git a/tests/data/multimodel/container/Dockerfile b/tests/data/multimodel/container/Dockerfile
index 4792a429c1..71c38a6605 100644
--- a/tests/data/multimodel/container/Dockerfile
+++ b/tests/data/multimodel/container/Dockerfile
@@ -1,4 +1,5 @@
-FROM public.ecr.aws/ubuntu/ubuntu:18.04
+# use the latest LTS image from https://gallery.ecr.aws/lts/ubuntu
+FROM public.ecr.aws/ubuntu/ubuntu:22.04
# Set a docker label to advertise multi-model support on the container
LABEL com.amazonaws.sagemaker.capabilities.multi-models=true
@@ -15,7 +16,7 @@ RUN apt-get update && \
curl \
vim \
&& rm -rf /var/lib/apt/lists/* \
- && curl -O https://bootstrap.pypa.io/pip/3.6/get-pip.py \
+ && curl -O https://bootstrap.pypa.io/pip/get-pip.py \
&& python3 get-pip.py
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1
diff --git a/tests/data/pipeline/test_source_dir/script_1.py b/tests/data/pipeline/test_source_dir/script_1.py
new file mode 100644
index 0000000000..4a427b1898
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir/script_1.py
@@ -0,0 +1,11 @@
+"""
+Integ test file script_1.py
+"""
+import pathlib
+
+if __name__ == "__main__":
+
+ print("writing file to /opt/ml/processing/test/test.py...")
+ pathlib.Path("/opt/ml/processing/test").mkdir(parents=True, exist_ok=True)
+ with open("/opt/ml/processing/test/test.py", "w") as f:
+ f.write('print("test...")')
diff --git a/tests/data/pipeline/test_source_dir/script_2.py b/tests/data/pipeline/test_source_dir/script_2.py
new file mode 100644
index 0000000000..6245dac987
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir/script_2.py
@@ -0,0 +1,9 @@
+"""
+Integ test file script_2.py
+"""
+
+if __name__ == "__main__":
+
+ print("reading file: /opt/ml/procesing/test/test.py")
+ with open("/opt/ml/processing/test/test.py", "r") as f:
+ print(f.read())
diff --git a/tests/data/pipeline/test_source_dir_2/script_2.py b/tests/data/pipeline/test_source_dir_2/script_2.py
new file mode 100644
index 0000000000..6245dac987
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir_2/script_2.py
@@ -0,0 +1,9 @@
+"""
+Integ test file script_2.py
+"""
+
+if __name__ == "__main__":
+
+ print("reading file: /opt/ml/procesing/test/test.py")
+ with open("/opt/ml/processing/test/test.py", "r") as f:
+ print(f.read())
diff --git a/tests/data/pytorch_neo/code/inference.py b/tests/data/pytorch_neo/code/inference.py
index 5b89c2bebc..79fe66d716 100644
--- a/tests/data/pytorch_neo/code/inference.py
+++ b/tests/data/pytorch_neo/code/inference.py
@@ -71,8 +71,8 @@ def model_fn(model_dir):
logger.info("model_fn")
neopytorch.config(model_dir=model_dir, neo_runtime=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- # The compiled model is saved as "model.pth"
- model = torch.jit.load(os.path.join(model_dir, "model.pth"), map_location=device)
+ # The compiled model is saved as "model.pth" or "model.pt"
+ model = torch.jit.load(os.path.join(model_dir, "model.pt"), map_location=device)
# It is recommended to run warm-up inference during model load
sample_input_path = os.path.join(model_dir, "sample_input.pkl")
diff --git a/tests/data/spark/code/java/TestJarFile.jar b/tests/data/spark/code/java/TestJarFile.jar
new file mode 100644
index 0000000000..d528331d55
Binary files /dev/null and b/tests/data/spark/code/java/TestJarFile.jar differ
diff --git a/tests/data/spark/code/java/hello-java-spark/HelloJavaSparkApp.jar b/tests/data/spark/code/java/hello-java-spark/HelloJavaSparkApp.jar
new file mode 100644
index 0000000000..056675146d
Binary files /dev/null and b/tests/data/spark/code/java/hello-java-spark/HelloJavaSparkApp.jar differ
diff --git a/tests/integ/__init__.py b/tests/integ/__init__.py
index 00ed09577b..9133fc8904 100644
--- a/tests/integ/__init__.py
+++ b/tests/integ/__init__.py
@@ -158,7 +158,7 @@
"ap-northeast-1",
"eu-central-1",
]
-# TODO: SM Training Compiler team to add all supported regions.
+
TRAINING_COMPILER_SUPPORTED_REGIONS = [
"af-south-1",
"ap-east-1",
diff --git a/tests/integ/sagemaker/experiments/__init__.py b/tests/integ/sagemaker/experiments/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integ/sagemaker/experiments/conftest.py b/tests/integ/sagemaker/experiments/conftest.py
new file mode 100644
index 0000000000..ca40a3ba6d
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/conftest.py
@@ -0,0 +1,177 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import glob
+import logging
+import os
+import shutil
+import tempfile
+import time
+import uuid
+
+import boto3
+import pytest
+
+from sagemaker.experiments import Run
+from tests.integ import DATA_DIR
+
+from sagemaker.experiments import trial_component, trial, experiment
+from sagemaker.utils import retry_with_backoff, unique_name_from_base
+from tests.integ.sagemaker.experiments.helpers import name, names
+
+TAGS = [{"Key": "some-key", "Value": "some-value"}]
+EXP_NAME_BASE_IN_LOCAL = "Job-Exp-in-Local"
+RUN_NAME_IN_LOCAL = "job-run-in-local"
+
+
+@pytest.fixture(scope="module")
+def run_obj(sagemaker_session):
+ run = Run(
+ experiment_name=unique_name_from_base(EXP_NAME_BASE_IN_LOCAL),
+ run_name=RUN_NAME_IN_LOCAL,
+ sagemaker_session=sagemaker_session,
+ )
+ try:
+ yield run
+ time.sleep(0.5)
+ finally:
+ exp = experiment._Experiment.load(
+ experiment_name=run.experiment_name, sagemaker_session=sagemaker_session
+ )
+ exp._delete_all(action="--force")
+
+
+@pytest.fixture(scope="module")
+def trial_component_obj(sagemaker_session):
+ trial_component_obj = trial_component._TrialComponent.create(
+ trial_component_name=name(),
+ sagemaker_session=sagemaker_session,
+ tags=TAGS,
+ )
+ yield trial_component_obj
+ time.sleep(0.5)
+ _delete_associations(trial_component_obj.trial_component_arn, sagemaker_session)
+ retry_with_backoff(trial_component_obj.delete)
+
+
+@pytest.fixture(scope="module")
+def experiment_obj(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ description = "{}-{}".format("description", str(uuid.uuid4()))
+ boto3.set_stream_logger("", logging.INFO)
+ experiment_name = name()
+ experiment_obj = experiment._Experiment.create(
+ experiment_name=experiment_name,
+ description=description,
+ sagemaker_session=sagemaker_session,
+ tags=TAGS,
+ )
+ yield experiment_obj
+ time.sleep(0.5)
+ experiment_obj.delete()
+ with pytest.raises(client.exceptions.ResourceNotFound):
+ client.describe_experiment(ExperimentName=experiment_name)
+
+
+@pytest.fixture(scope="module")
+def trial_obj(sagemaker_session, experiment_obj):
+ trial_obj = trial._Trial.create(
+ trial_name=name(),
+ experiment_name=experiment_obj.experiment_name,
+ tags=TAGS,
+ sagemaker_session=sagemaker_session,
+ )
+ yield trial_obj
+ time.sleep(0.5)
+ trial_obj.delete()
+
+
+@pytest.fixture(scope="module")
+def trials(experiment_obj, sagemaker_session):
+ trial_objs = []
+ for trial_name in names():
+ next_trial = trial._Trial.create(
+ trial_name=trial_name,
+ experiment_name=experiment_obj.experiment_name,
+ sagemaker_session=sagemaker_session,
+ )
+ trial_objs.append(next_trial)
+ time.sleep(0.5)
+ yield trial_objs
+ for trial_obj in trial_objs:
+ trial_obj.delete()
+
+
+@pytest.fixture(scope="module")
+def trial_component_with_force_disassociation_obj(trials, sagemaker_session):
+ trial_component_obj = trial_component._TrialComponent.create(
+ trial_component_name=name(), sagemaker_session=sagemaker_session
+ )
+ for trial_obj in trials:
+ sagemaker_session.sagemaker_client.associate_trial_component(
+ TrialName=trial_obj.trial_name,
+ TrialComponentName=trial_component_obj.trial_component_name,
+ )
+ yield trial_component_obj
+ time.sleep(0.5)
+ trial_component_obj.delete(force_disassociate=True)
+
+
+@pytest.fixture(scope="module")
+def trial_components(sagemaker_session):
+ trial_component_objs = [
+ trial_component._TrialComponent.create(
+ trial_component_name=trial_component_name,
+ sagemaker_session=sagemaker_session,
+ )
+ for trial_component_name in names()
+ ]
+ yield trial_component_objs
+ for trial_component_obj in trial_component_objs:
+ trial_component_obj.delete()
+
+
+@pytest.fixture(scope="module")
+def tempdir():
+ temp_dir = tempfile.mkdtemp()
+ yield temp_dir
+ shutil.rmtree(temp_dir)
+
+
+_EXP_PLUS_SDK_TAR = "sagemaker-dev-1.0.tar.gz"
+
+
+@pytest.fixture(scope="module")
+def dev_sdk_tar():
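+ # Build an sdist of the SDK under development and stage it in the experiment
+ # test data dir; the job scripts above pip-install this tarball at runtime.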
+ resource_dir = os.path.join(DATA_DIR, "experiment")
+ os.system("python setup.py sdist")
+ sdist_path = max(glob.glob("dist/sagemaker-*"), key=os.path.getctime)
+ sdk_file = os.path.join(resource_dir, _EXP_PLUS_SDK_TAR)
+ shutil.copy(sdist_path, sdk_file)
+ return sdk_file
+
+
+def _delete_associations(arn, sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ outgoing_associations = client.list_associations(SourceArn=arn)["AssociationSummaries"]
+ incoming_associations = client.list_associations(DestinationArn=arn)["AssociationSummaries"]
+ associations = []
+ if outgoing_associations:
+ associations.extend(outgoing_associations)
+ if incoming_associations:
+ associations.extend(incoming_associations)
+ for association in associations:
+ source_arn = association["SourceArn"]
+ destination_arn = association["DestinationArn"]
+ client.delete_association(SourceArn=source_arn, DestinationArn=destination_arn)
diff --git a/tests/integ/sagemaker/experiments/helpers.py b/tests/integ/sagemaker/experiments/helpers.py
new file mode 100644
index 0000000000..b5e8064b08
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/helpers.py
@@ -0,0 +1,42 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from contextlib import contextmanager
+
+from sagemaker import utils
+from sagemaker.experiments.experiment import _Experiment
+
+EXP_INTEG_TEST_NAME_PREFIX = "experiments-integ"
+
+
+def name():
+ return utils.unique_name_from_base(EXP_INTEG_TEST_NAME_PREFIX)
+
+
+def names():
+ return [utils.unique_name_from_base(EXP_INTEG_TEST_NAME_PREFIX) for _ in range(3)]
+
+
+def to_seconds(dt):
+ return int(dt.timestamp())
+
+
+@contextmanager
+def cleanup_exp_resources(exp_names, sagemaker_session):
+ try:
+ yield
+ finally:
+ for exp_name in exp_names:
+ exp = _Experiment.load(experiment_name=exp_name, sagemaker_session=sagemaker_session)
+ exp._delete_all(action="--force")
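
A sketch of how the cleanup helper is used by the run tests below (the session setup is illustrative):

```python
import sagemaker
from sagemaker.experiments import Run
from tests.integ.sagemaker.experiments.helpers import cleanup_exp_resources, name

session = sagemaker.Session()

exp_name = name()
with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=session):
    with Run(experiment_name=exp_name, sagemaker_session=session) as run:
        run.log_parameter("p1", 1.0)
# On exit, each experiment is loaded and force-deleted together with its
# trials and trial components via _Experiment._delete_all.
```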
diff --git a/tests/integ/sagemaker/experiments/test_experiment.py b/tests/integ/sagemaker/experiments/test_experiment.py
new file mode 100644
index 0000000000..ff7d5fac37
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_experiment.py
@@ -0,0 +1,56 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from sagemaker.experiments import experiment
+from tests.integ.sagemaker.experiments.helpers import name
+
+
+def test_create_delete(experiment_obj):
+ # The fixture creates and deletes the experiment; just ensure the fixture is used at least once
+ assert experiment_obj.experiment_name
+
+
+def test_create_tags(experiment_obj, sagemaker_session):
+ client = sagemaker_session.sagemaker_client
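+ # Tags propagate asynchronously; poll until the listing is non-empty.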
+ while True:
+ actual_tags = client.list_tags(ResourceArn=experiment_obj.experiment_arn)["Tags"]
+ if actual_tags:
+ break
+ # Filter out AWS-managed tags without mutating the list while iterating over it
+ actual_tags = [tag for tag in actual_tags if "aws:tag" not in tag.get("Key")]
+ assert actual_tags == experiment_obj.tags
+
+
+def test_save(experiment_obj):
+ description = name()
+ experiment_obj.description = description
+ experiment_obj.save()
+
+
+def test_save_load(experiment_obj, sagemaker_session):
+ experiment_obj_two = experiment._Experiment.load(
+ experiment_name=experiment_obj.experiment_name, sagemaker_session=sagemaker_session
+ )
+ assert experiment_obj.experiment_name == experiment_obj_two.experiment_name
+ assert experiment_obj.description == experiment_obj_two.description
+
+ experiment_obj.description = name()
+ experiment_obj.display_name = name()
+ experiment_obj.save()
+ experiment_obj_three = experiment._Experiment.load(
+ experiment_name=experiment_obj.experiment_name, sagemaker_session=sagemaker_session
+ )
+ assert experiment_obj.description == experiment_obj_three.description
+ assert experiment_obj.display_name == experiment_obj_three.display_name
diff --git a/tests/integ/sagemaker/experiments/test_metrics.py b/tests/integ/sagemaker/experiments/test_metrics.py
new file mode 100644
index 0000000000..15c0c2f9dc
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_metrics.py
@@ -0,0 +1,39 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+import random
+from sagemaker.experiments._metrics import _MetricsManager
+from sagemaker.experiments.trial_component import _TrialComponent
+from sagemaker.utils import retry_with_backoff
+
+
+def test_end_to_end(trial_component_obj, sagemaker_session):
+ # The fixture creates and deletes the trial component; just ensure the fixture is used at least once
+ with _MetricsManager(trial_component_obj.trial_component_name, sagemaker_session) as mm:
+ for i in range(100):
+ mm.log_metric("test-x-step", random.random(), step=i)
+ mm.log_metric("test-x-timestamp", random.random())
+
+ def verify_metrics():
+ updated_tc = _TrialComponent.load(
+ trial_component_name=trial_component_obj.trial_component_name,
+ sagemaker_session=sagemaker_session,
+ )
+ metrics = updated_tc.metrics
+ # TODO: revert to len(metrics) == 2 once backend fix reaches prod
+ assert len(metrics) > 0
+ assert list(filter(lambda x: x.metric_name == "test-x-step", metrics))
+ assert list(filter(lambda x: x.metric_name == "test-x-timestamp", metrics))
+
+ # metrics -> eureka propagation
+ retry_with_backoff(verify_metrics)
diff --git a/tests/integ/sagemaker/experiments/test_run.py b/tests/integ/sagemaker/experiments/test_run.py
new file mode 100644
index 0000000000..713a6a3792
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_run.py
@@ -0,0 +1,662 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import datetime
+import os
+
+import pytest
+
+from tests.integ.sagemaker.experiments.conftest import TAGS
+from sagemaker.experiments._api_types import _TrialComponentStatusType
+from sagemaker.experiments._utils import is_run_trial_component
+from sagemaker.processing import FrameworkProcessor
+from sagemaker.pytorch import PyTorch
+from sagemaker.s3 import S3Uploader
+from sagemaker.xgboost import XGBoostModel
+from tests.integ import DATA_DIR
+from sagemaker.experiments._metrics import BATCH_SIZE
+from sagemaker.experiments.trial_component import _TrialComponent
+from sagemaker.sklearn import SKLearn
+from sagemaker.utils import retry_with_backoff, unique_name_from_base
+from tests.integ.sagemaker.experiments.helpers import name, cleanup_exp_resources
+from sagemaker.experiments.run import (
+ RUN_NAME_BASE,
+ DELIMITER,
+)
+from sagemaker.experiments import Run, load_run, list_runs
+from sagemaker.experiments._helper import _DEFAULT_ARTIFACT_PREFIX
+
+
+# When running integration tests locally, modify this to your test account's execution role.
+EXECUTION_ROLE = "SageMakerRole"
+
+
+@pytest.fixture
+def artifact_file_path(tempdir):
+ file_contents = "test artifact file"
+ file_path = os.path.join(tempdir, "artifact_file.txt")
+ with open(file_path, "w") as foo_file:
+ foo_file.write(file_contents)
+ return file_path
+
+
+artifact_name = unique_name_from_base("Test-Artifact")
+file_artifact_name = f"File-Artifact-{name()}"
+metric_name = "Test-Local-Init-Log-Metric"
+
+
+def test_local_run_with_load(sagemaker_session, artifact_file_path):
+ exp_name = f"My-Local-Exp-{name()}"
+ with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
+ # Run name is not provided, will create a new TC
+ with Run(experiment_name=exp_name, sagemaker_session=sagemaker_session) as run1:
+ run1_name = run1.run_name
+ assert RUN_NAME_BASE in run1_name
+ _local_run_log_behaviors(
+ artifact_file_path=artifact_file_path,
+ sagemaker_session=sagemaker_session,
+ )
+
+ def verify_load_run():
+ with load_run(
+ experiment_name=exp_name,
+ run_name=run1_name,
+ sagemaker_session=sagemaker_session,
+ ) as run2:
+ assert run2.run_name == run1_name
+ assert (
+ run2._trial_component.trial_component_name
+ == f"{run2.experiment_name}{DELIMITER}{run1_name}"
+ )
+ _check_run_from_local_end_result(
+ sagemaker_session=sagemaker_session, tc=run2._trial_component
+ )
+
+ # Add retry to make sure metrics -> eureka propagation is consistent
+ retry_with_backoff(verify_load_run, 4)
+
+
+def test_two_local_run_init_with_same_run_name_and_different_exp_names(sagemaker_session):
+ exp_name1 = f"my-two-local-exp1-{name()}"
+ exp_name2 = f"my-two-local-exp2-{name()}"
+ run_name = "test-run"
+ with cleanup_exp_resources(
+ exp_names=[exp_name1, exp_name2], sagemaker_session=sagemaker_session
+ ):
+ # Run name is not provided, will create a new TC
+ with Run(
+ experiment_name=exp_name1, run_name=run_name, sagemaker_session=sagemaker_session
+ ) as run1:
+ pass
+ with Run(
+ experiment_name=exp_name2, run_name=run_name, sagemaker_session=sagemaker_session
+ ) as run2:
+ pass
+
+ assert run1.experiment_name != run2.experiment_name
+ assert run1.run_name == run2.run_name
+ assert (
+ run1._trial_component.trial_component_name != run2._trial_component.trial_component_name
+ )
+ assert run1._trial_component.trial_component_name == f"{exp_name1}{DELIMITER}{run_name}"
+ assert run2._trial_component.trial_component_name == f"{exp_name2}{DELIMITER}{run_name}"
+
+
+@pytest.mark.parametrize(
+ "input_names",
+ [
+ (f"my-local-exp-{name()}", "test-run", None), # both have delimiter -
+ ("my-test-1", "my-test-1", None), # exp_name equals run_name
+ ("my-test-3", "my-test-3-run", None), # is subset of run_name
+ ("x" * 59, "test-run", None), # long exp_name
+ ("test-exp", "y" * 59, None), # long run_name
+ ("e" * 59, "y" * 59, None), # long exp_name and run_name
+ ("my-test4", "test-run", "run-display-name-test"), # with supplied display name
+ ],
+)
+def test_run_name_vs_trial_component_name_edge_cases(sagemaker_session, input_names):
+ exp_name, run_name, run_display_name = input_names
+ with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
+ with Run(
+ experiment_name=exp_name,
+ sagemaker_session=sagemaker_session,
+ run_name=run_name,
+ run_display_name=run_display_name,
+ ) as run1:
+ assert not run1._experiment.tags
+ assert not run1._trial.tags
+ is_run_tc = is_run_trial_component(
+ trial_component_name=run1._trial_component.trial_component_name,
+ sagemaker_session=sagemaker_session,
+ )
+ assert is_run_tc
+
+ with load_run(
+ experiment_name=exp_name,
+ run_name=run_name,
+ sagemaker_session=sagemaker_session,
+ ) as run2:
+ assert run2.experiment_name == exp_name
+ assert run2.run_name == run_name
+ assert run2._trial_component.trial_component_name == f"{exp_name}{DELIMITER}{run_name}"
+ assert run2._trial_component.display_name in (
+ run_display_name,
+ run2._trial_component.trial_component_name,
+ )
+
+
+_EXP_NAME_BASE_IN_SCRIPT = "job-exp-in-script"
+_RUN_NAME_IN_SCRIPT = "job-run-in-script"
+
+_EXP_DIR = os.path.join(DATA_DIR, "experiment")
+_ENTRY_POINT_PATH = os.path.join(_EXP_DIR, "train_job_script_for_run_clz.py")
+_PYTHON_PROCESS_SCRIPT = "process_job_script_for_run_clz.py"
+_TRANSFORM_MATERIALS = os.path.join(_EXP_DIR, "transform_job_materials")
+
+_RUN_INIT = "init"
+_RUN_LOAD = "load"
+
+
+def test_run_from_local_and_train_job_and_all_exp_cfg_match(sagemaker_session, dev_sdk_tar):
+ # Notes:
+ # 1. The 1st Run TC created locally and its exp config was auto passed to the job
+ # 2. In training job, the same exp and run names are given in the Run constructor
+ # which will load the 1st Run TC in training job and log parameters
+ # and metrics there
+ # 3. In a different training job, load the same Run TC and log more parameters there.
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
+ estimator = _generate_estimator(
+ sdk_tar=dev_sdk_tar, sagemaker_session=sagemaker_session, exp_name=exp_name
+ )
+ tc_name = Run._generate_trial_component_name(
+ experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT
+ )
+
+ with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
+ with Run(
+ experiment_name=exp_name,
+ run_name=_RUN_NAME_IN_SCRIPT,
+ sagemaker_session=sagemaker_session,
+ ) as run:
+ init_start_time = _check_tc_status_when_entering(run._trial_component)
+ _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session)
+ # experiment_config is auto passed in by _RunContext
+ estimator.fit(
+ job_name=f"train-job-{name()}",
+ wait=True, # wait for the training job to finish
+ logs="None", # set to "All" to display logs fetched from the training job
+ )
+ old_end_time = _check_tc_status_when_exiting(
+ trial_component_name=run._trial_component.trial_component_name,
+ init_start_time=init_start_time,
+ sagemaker_session=sagemaker_session,
+ )
+
+ _check_tc_status_when_exiting(
+ trial_component_name=run._trial_component.trial_component_name,
+ init_start_time=init_start_time,
+ old_end_time=old_end_time,
+ sagemaker_session=sagemaker_session,
+ )
+ assert run.experiment_name == exp_name
+ assert run.run_name == _RUN_NAME_IN_SCRIPT
+ _check_run_from_local_end_result(
+ tc=run._trial_component,
+ sagemaker_session=sagemaker_session,
+ is_complete_log=False,
+ )
+ _check_run_from_job_result(
+ tc_name=tc_name,
+ sagemaker_session=sagemaker_session,
+ )
+
+ with run:
+ estimator.environment["RUN_OPERATION"] = _RUN_LOAD
+ estimator.environment["CALL_RUN_LOAD_WITH_NO_NAME_ARGS"] = "True"
+ estimator.fit(
+ job_name=f"train-job-{name()}",
+            wait=True,  # wait for the training job to finish
+ logs="None", # set to "All" to display logs fetched from the training job
+ )
+
+ old_end_time = _check_tc_status_when_exiting(
+ trial_component_name=run._trial_component.trial_component_name,
+ init_start_time=init_start_time,
+ old_end_time=old_end_time,
+ sagemaker_session=sagemaker_session,
+ )
+
+ _check_tc_status_when_exiting(
+ trial_component_name=run._trial_component.trial_component_name,
+ init_start_time=init_start_time,
+ old_end_time=old_end_time,
+ sagemaker_session=sagemaker_session,
+ )
+ _check_run_from_job_result(
+ tc_name=tc_name,
+ sagemaker_session=sagemaker_session,
+ is_init=False,
+ has_extra_load=True,
+ )
+
+
+def test_run_from_local_and_train_job_and_exp_cfg_not_match(sagemaker_session, dev_sdk_tar):
+    # Notes:
+    # 1. The 1st Run TC is created locally, and its exp config is auto-passed to the job
+    # 2. In the training job, different exp and run names (i.e. a 2nd Run TC) are given
+    #    to the Run constructor, which creates a Run TC according to the run_name
+    #    passed in there and ignores the exp config from the job
+    # 3. Both metrics and parameters are logged to the Run TC created in the job
+    # 4. In a different training job, load the 2nd Run TC and log more parameters there.
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
+ exp_name2 = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
+ estimator = _generate_estimator(
+ sdk_tar=dev_sdk_tar, sagemaker_session=sagemaker_session, exp_name=exp_name
+ )
+ tc_name = Run._generate_trial_component_name(
+ experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT
+ )
+
+ with cleanup_exp_resources(
+ exp_names=[exp_name, exp_name2], sagemaker_session=sagemaker_session
+ ):
+ with Run(
+ experiment_name=exp_name2,
+ run_name=f"{_RUN_NAME_IN_SCRIPT}2",
+ sagemaker_session=sagemaker_session,
+ ) as run:
+ init_start_time = _check_tc_status_when_entering(run._trial_component)
+ # experiment_config is auto passed in by _RunContext
+ estimator.fit(
+ job_name=f"train-job-{name()}",
+                wait=True,  # wait for the training job to finish
+ logs="None", # set to "All" to display logs fetched from the training job
+ )
+ _check_tc_status_intermediate(
+ trial_component=run._trial_component,
+ sagemaker_session=sagemaker_session,
+ init_start_time=init_start_time,
+ )
+
+ old_end_time = _check_tc_status_when_exiting(
+ trial_component_name=run._trial_component.trial_component_name,
+ init_start_time=init_start_time,
+ sagemaker_session=sagemaker_session,
+ )
+ assert run.experiment_name != exp_name
+ assert run.run_name != _RUN_NAME_IN_SCRIPT
+ _check_run_from_job_result(
+ tc_name=tc_name,
+ sagemaker_session=sagemaker_session,
+ )
+
+ with run:
+ estimator.environment["RUN_OPERATION"] = _RUN_LOAD
+ estimator.fit(
+ job_name=f"train-job-{name()}",
+                wait=True,  # wait for the training job to finish
+ logs="None", # set to "All" to display logs fetched from the training job
+ )
+ _check_tc_status_intermediate(
+ trial_component=run._trial_component,
+ sagemaker_session=sagemaker_session,
+ init_start_time=init_start_time,
+ old_end_time=old_end_time,
+ )
+
+ _check_tc_status_when_exiting(
+ trial_component_name=run._trial_component.trial_component_name,
+ init_start_time=init_start_time,
+ old_end_time=old_end_time,
+ sagemaker_session=sagemaker_session,
+ )
+ _check_run_from_job_result(
+ tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False
+ )
+
+
+def test_run_from_train_job_only(sagemaker_session, dev_sdk_tar):
+    # Notes:
+    # 1. No Run TC is created locally or specified in the experiment config
+    # 2. In the training job, a Run is initialized,
+    #    which creates a Run TC according to the run_name passed in there
+    # 3. Both metrics and parameters are logged to the Run TC created in the job
+    # 4. In a different training job, load the same Run TC and log more parameters there.
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
+ estimator = _generate_estimator(
+ sdk_tar=dev_sdk_tar,
+ sagemaker_session=sagemaker_session,
+ exp_name=exp_name,
+ )
+ tc_name = Run._generate_trial_component_name(
+ experiment_name=exp_name, run_name=_RUN_NAME_IN_SCRIPT
+ )
+
+ with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
+ estimator.fit(
+ job_name=f"train-job-{name()}",
+            wait=True,  # wait for the training job to finish
+ logs="None", # set to "All" to display logs fetched from the training job
+ )
+ _check_run_from_job_result(
+ tc_name=tc_name,
+ sagemaker_session=sagemaker_session,
+ )
+
+ estimator.environment["RUN_OPERATION"] = _RUN_LOAD
+ estimator.fit(
+ job_name=f"train-job-{name()}",
+            wait=True,  # wait for the training job to finish
+ logs="None", # set to "All" to display logs fetched from the training job
+ )
+ _check_run_from_job_result(
+ tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False
+ )
+
+
+# dev_sdk_tar is required to trigger generating the dev SDK tar
+def test_run_from_processing_job_and_override_default_exp_config(
+ sagemaker_session, dev_sdk_tar, run_obj
+):
+    # Notes:
+    # 1. The 1st Run TC (run) is created locally
+    # 2. Within the 2nd Run TC (run_obj)'s context, invoke processor.run,
+    #    but override the default experiment config in the context of the 2nd Run TC
+    #    with the experiment config of the 1st Run TC
+    # 3. In the processing job script, load the 1st Run TC via the experiment config
+    #    fetched from the job env
+    # 4. All data is logged to the 1st Run TC, either locally or in the processing job
+ exp_name = unique_name_from_base(_EXP_NAME_BASE_IN_SCRIPT)
+ processor = FrameworkProcessor(
+ estimator_cls=PyTorch,
+ framework_version="1.10",
+ py_version="py38",
+ instance_count=1,
+ instance_type="ml.m5.xlarge",
+ role=EXECUTION_ROLE,
+ sagemaker_session=sagemaker_session,
+ )
+
+ with cleanup_exp_resources(exp_names=[exp_name], sagemaker_session=sagemaker_session):
+ with Run(
+ experiment_name=exp_name,
+ run_name=_RUN_NAME_IN_SCRIPT,
+ sagemaker_session=sagemaker_session,
+ ) as run:
+ _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session)
+
+ with run_obj:
+ # Override the default experiment_config in _RunContext of run_obj
+ # with the experiment_config of run
+ processor.run(
+ code=_PYTHON_PROCESS_SCRIPT,
+ source_dir=_EXP_DIR,
+ job_name=f"process-job-{name()}",
+                wait=True,  # wait for the job to finish
+ logs=False,
+ experiment_config=run.experiment_config,
+ )
+
+ assert run_obj.experiment_name != run.experiment_name
+ assert run_obj.run_name != run.run_name
+ _check_run_from_local_end_result(
+ tc=run._trial_component,
+ sagemaker_session=sagemaker_session,
+ is_complete_log=False,
+ )
+ tc_name = Run._generate_trial_component_name(
+ experiment_name=run.experiment_name, run_name=run.run_name
+ )
+ _check_run_from_job_result(
+ tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False
+ )
+
+ with run_obj:
+        # Do not override the exp config; use the default one from the context
+ processor.run(
+ code=_PYTHON_PROCESS_SCRIPT,
+ source_dir=_EXP_DIR,
+ job_name=f"process-job-{name()}",
+            wait=True,  # wait for the job to finish
+ logs=False,
+ )
+
+ tc_name = Run._generate_trial_component_name(
+ experiment_name=run_obj.experiment_name, run_name=run_obj.run_name
+ )
+ _check_run_from_job_result(
+ tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False
+ )
+
+
+# dev_sdk_tar is required to trigger generating the dev SDK tar
+def test_run_from_transform_job(sagemaker_session, dev_sdk_tar, run_obj, xgboost_latest_version):
+    # Notes:
+    # 1. The 1st Run TC (run_obj) is created locally
+    # 2. In the inference script running in the transform job, load the 1st Run TC
+    #    by explicitly passing in the experiment_name and run_name of the 1st Run TC
+    # TODO: once we're able to retrieve the exp config from the transform job env,
+    #    we should expand this test and call load_run() without explicitly supplying the names
+    # 3. All data is logged to the Run TC, either locally or in the transform job
+ xgb_model_data_s3 = sagemaker_session.upload_data(
+ path=os.path.join(_TRANSFORM_MATERIALS, "xgb_model.tar.gz"),
+ key_prefix="integ-test-data/xgboost/model",
+ )
+ xgboost_model = XGBoostModel(
+ sagemaker_session=sagemaker_session,
+ model_data=xgb_model_data_s3,
+ role=EXECUTION_ROLE,
+ entry_point="inference.py",
+ source_dir=_EXP_DIR,
+ framework_version=xgboost_latest_version,
+ env={
+ "EXPERIMENT_NAME": run_obj.experiment_name,
+ "RUN_NAME": run_obj.run_name,
+ },
+ )
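+    # The inference script reads EXPERIMENT_NAME and RUN_NAME from the container
+    # env and passes them to load_run(), since the exp config cannot yet be
+    # retrieved from the transform job env (see the TODO above).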
+ transformer = xgboost_model.transformer(
+ instance_count=1,
+ instance_type="ml.m5.4xlarge",
+ max_concurrent_transforms=5,
+ max_payload=1,
+ strategy="MultiRecord",
+ )
+ uri = "s3://{}/{}/input/data/{}".format(
+ sagemaker_session.default_bucket(),
+ "transform-test",
+ unique_name_from_base("json-data"),
+ )
+ input_data = S3Uploader.upload(
+ os.path.join(_TRANSFORM_MATERIALS, "data.csv"), uri, sagemaker_session=sagemaker_session
+ )
+
+ with run_obj:
+ _local_run_log_behaviors(is_complete_log=False, sagemaker_session=sagemaker_session)
+ transformer.transform(
+ data=input_data,
+ content_type="text/libsvm",
+ split_type="Line",
+ wait=True,
+ job_name=f"transform-job-{name()}",
+ )
+
+ _check_run_from_local_end_result(
+ tc=run_obj._trial_component,
+ sagemaker_session=sagemaker_session,
+ is_complete_log=False,
+ )
+ tc_name = Run._generate_trial_component_name(
+ experiment_name=run_obj.experiment_name, run_name=run_obj.run_name
+ )
+ _check_run_from_job_result(tc_name=tc_name, sagemaker_session=sagemaker_session, is_init=False)
+
+
+def test_list(run_obj, sagemaker_session):
+ tc1 = _TrialComponent.create(
+ trial_component_name=f"non-run-tc1-{name()}",
+ sagemaker_session=sagemaker_session,
+ )
+ tc2 = _TrialComponent.create(
+ trial_component_name=f"non-run-tc2-{name()}",
+ sagemaker_session=sagemaker_session,
+ tags=TAGS,
+ )
+ run_obj._trial.add_trial_component(tc1)
+ run_obj._trial.add_trial_component(tc2)
+
+ run_tcs = list_runs(
+ experiment_name=run_obj.experiment_name, sagemaker_session=sagemaker_session
+ )
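+    # Only the Run's own TC should be returned; the two plain TCs associated
+    # above must be filtered out by list_runs.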
+ assert len(run_tcs) == 1
+ assert run_tcs[0].run_name == run_obj.run_name
+ assert run_tcs[0].experiment_name == run_obj.experiment_name
+ assert run_tcs[0].experiment_config == run_obj.experiment_config
+
+
+def _generate_estimator(exp_name, sdk_tar, sagemaker_session):
+ return SKLearn(
+ framework_version="0.23-1",
+ entry_point=_ENTRY_POINT_PATH,
+ dependencies=[sdk_tar],
+ role=EXECUTION_ROLE,
+ instance_type="ml.m5.large",
+ instance_count=1,
+ volume_size=10,
+ max_run=900,
+ enable_sagemaker_metrics=True,
+ environment={
+ "EXPERIMENT_NAME": exp_name,
+ "RUN_NAME": _RUN_NAME_IN_SCRIPT,
+ "RUN_OPERATION": _RUN_INIT,
+ },
+ sagemaker_session=sagemaker_session,
+ )
+
+
+def _local_run_log_behaviors(
+ sagemaker_session,
+ artifact_file_path=None,
+ is_complete_log=True,
+):
+ with load_run(sagemaker_session=sagemaker_session) as run:
+ run.log_parameter("pa", 1.0)
+ run.log_parameter("pb", "p2-value")
+ run.log_parameters({"pc": 2.0, "pd": "p4-value"})
+
+ if is_complete_log:
+ run.log_file(file_path=artifact_file_path, name=file_artifact_name)
+ run.log_artifact(name=artifact_name, value="s3://Output")
+ run.log_artifact(name=artifact_name, value="s3://Input", is_output=False)
+
+ for i in range(BATCH_SIZE):
+ run.log_metric(name=metric_name, value=i, step=i)
+
+
+def _check_run_from_local_end_result(sagemaker_session, tc, is_complete_log=True):
+ assert tc.parameters == {"pa": 1.0, "pb": "p2-value", "pc": 2.0, "pd": "p4-value"}
+
+ if not is_complete_log:
+ return
+
+ s3_prefix = f"s3://{sagemaker_session.default_bucket()}/{_DEFAULT_ARTIFACT_PREFIX}"
+ assert s3_prefix in tc.output_artifacts[file_artifact_name].value
+ assert "text/plain" == tc.output_artifacts[file_artifact_name].media_type
+ assert "s3://Output" == tc.output_artifacts[artifact_name].value
+ assert not tc.output_artifacts[artifact_name].media_type
+ assert "s3://Input" == tc.input_artifacts[artifact_name].value
+ assert not tc.input_artifacts[artifact_name].media_type
+
+ # TODO: revert to len(tc.metrics) == 1 once backend fix reaches prod
+ assert len(tc.metrics) > 0
+ metric_summary = tc.metrics[0]
+ assert metric_summary.metric_name == metric_name
+ assert metric_summary.max == 9.0
+ assert metric_summary.min == 0.0
+
+
+def _check_run_from_job_result(sagemaker_session, tc_name=None, is_init=True, has_extra_load=False):
+ def validate_tc_updated_in_init():
+ assert tc.start_time
+ assert tc.end_time
+ assert tc.status.primary_status == _TrialComponentStatusType.Completed.value
+ assert tc.parameters["p1"] == 1.0
+ assert tc.parameters["p2"] == 2.0
+ # TODO: revert to assert len(tc.metrics) == 5 once
+ # backend fix hits prod
+ assert len(tc.metrics) > 0
+ for metric_summary in tc.metrics:
+            # metric deletion is not supported at this point,
+            # so the count accumulates across job runs
+ assert metric_summary.count > 0
+ assert metric_summary.min == 0.0
+ assert metric_summary.max == 1.0
+
+ def validate_tc_updated_in_load():
+ assert tc.parameters["p3"] == 3.0
+ assert tc.parameters["p4"] == 4.0
+ assert len(tc.metrics) > 0
+ for metric_summary in tc.metrics:
+ if metric_summary.metric_name != "test-job-load-log-metric":
+ continue
+ assert metric_summary.last == 0.1
+ assert metric_summary.max == 0.1
+ assert metric_summary.min == 0.1
+ if has_extra_load:
+ assert tc.parameters["p5"] == 5.0
+ assert tc.parameters["p6"] == 6.0
+
+ tc = _TrialComponent.load(trial_component_name=tc_name, sagemaker_session=sagemaker_session)
+ if is_init:
+        # Retry since the load results are sometimes inconsistent
+ retry_with_backoff(validate_tc_updated_in_init, 4)
+ else:
+ retry_with_backoff(validate_tc_updated_in_load, 4)
+
+
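+# Entering a Run context should stamp the TC with a start_time and an InProgress
+# status; exiting should set an end_time and flip the status to Completed. The
+# helpers below verify those transitions.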
+def _check_tc_status_when_entering(trial_component):
+ assert isinstance(trial_component.start_time, datetime.datetime)
+ assert not trial_component.end_time
+ assert trial_component.status.primary_status == _TrialComponentStatusType.InProgress.value
+ return trial_component.start_time
+
+
+def _check_tc_status_when_exiting(
+ trial_component_name, sagemaker_session, init_start_time, old_end_time=None
+):
+ tc = _TrialComponent.load(
+ trial_component_name=trial_component_name, sagemaker_session=sagemaker_session
+ )
+    # Allow a small deviation (< 1s) caused by the different timestamp
+    # precisions used by the backend and the SDK
+ assert abs(tc.start_time.timestamp() - init_start_time.timestamp()) < 1
+ assert tc.status.primary_status == _TrialComponentStatusType.Completed.value
+ assert isinstance(tc.end_time, datetime.datetime)
+ if old_end_time:
+ assert tc.end_time > old_end_time
+ return tc.end_time
+
+
+def _check_tc_status_intermediate(
+ trial_component, sagemaker_session, init_start_time, old_end_time=None
+):
+ tc_load = _TrialComponent.load(
+ trial_component_name=trial_component.trial_component_name,
+ sagemaker_session=sagemaker_session,
+ )
+ assert abs(tc_load.start_time.timestamp() - init_start_time.timestamp()) < 1
+ assert tc_load.status.primary_status == _TrialComponentStatusType.InProgress.value
+ if not old_end_time:
+ assert not trial_component.end_time
+ return
+ assert isinstance(tc_load.end_time, datetime.datetime)
+ assert tc_load.end_time == old_end_time
diff --git a/tests/integ/sagemaker/experiments/test_trial.py b/tests/integ/sagemaker/experiments/test_trial.py
new file mode 100644
index 0000000000..08f646c086
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_trial.py
@@ -0,0 +1,75 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import logging
+
+from sagemaker.experiments import trial
+from sagemaker.utils import retry_with_backoff
+
+
+def test_create_delete(trial_obj):
+    # The fixture handles create/delete; just ensure it is exercised at least once.
+ assert trial_obj.trial_name
+
+
+def test_create_tags(trial_obj, sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ while True:
+ actual_tags = client.list_tags(ResourceArn=trial_obj.trial_arn)["Tags"]
+ if actual_tags:
+ break
+    # Filter out AWS-reserved tags; removing items while iterating can skip elements
+    actual_tags = [tag for tag in actual_tags if "aws:tag" not in tag.get("Key")]
+    assert actual_tags == trial_obj.tags
+
+
+def test_save_load(trial_obj, sagemaker_session):
+ trial_obj.display_name = "foo"
+ trial_obj.save()
+ assert (
+ "foo"
+ == trial._Trial.load(
+ trial_name=trial_obj.trial_name,
+ sagemaker_session=sagemaker_session,
+ ).display_name
+ )
+
+
+def test_add_remove_trial_component(trial_obj, trial_component_obj):
+ trial_obj.add_trial_component(trial_component_obj)
+ logging.info(
+ f"Added trial component {trial_component_obj.trial_component_name} to trial {trial_obj.trial_name}"
+ )
+
+ def validate_add():
+ trial_components = list(trial_obj.list_trial_components())
+        assert 1 == len(
+            trial_components
+        ), "Expected the trial component to appear in the trial's list of trial components"
+
+ retry_with_backoff(validate_add)
+
+ trial_obj.remove_trial_component(trial_component_obj)
+ logging.info(
+ f"Removed trial component {trial_component_obj.trial_component_name} from trial {trial_obj.trial_name}"
+ )
+
+ def validate_remove():
+ trial_components = list(trial_obj.list_trial_components())
+        assert 0 == len(
+            trial_components
+        ), "Expected the trial component to be removed from the trial's list of trial components"
+
+ retry_with_backoff(validate_remove)
diff --git a/tests/integ/sagemaker/experiments/test_trial_component.py b/tests/integ/sagemaker/experiments/test_trial_component.py
new file mode 100644
index 0000000000..3d79e41cc4
--- /dev/null
+++ b/tests/integ/sagemaker/experiments/test_trial_component.py
@@ -0,0 +1,144 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import datetime
+import uuid
+
+from sagemaker.experiments._api_types import _TrialComponentStatusType
+from tests.integ.sagemaker.experiments.helpers import EXP_INTEG_TEST_NAME_PREFIX
+from sagemaker.experiments import _api_types, trial_component
+from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression
+
+
+def test_create_delete(trial_component_obj):
+    # The fixture handles create/delete; just ensure it is exercised at least once
+ assert trial_component_obj.trial_component_name
+ assert trial_component_obj.input_artifacts == {}
+ assert trial_component_obj.parameters == {}
+ assert trial_component_obj.output_artifacts == {}
+
+
+def test_create_tags(trial_component_obj, sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ while True:
+ actual_tags = client.list_tags(ResourceArn=trial_component_obj.trial_component_arn)["Tags"]
+ if actual_tags:
+ break
+    # Filter out AWS-reserved tags; removing items while iterating can skip elements
+    actual_tags = [tag for tag in actual_tags if "aws:tag" not in tag.get("Key")]
+    assert actual_tags == trial_component_obj.tags
+
+
+def test_delete_with_force_disassociate(
+ trial_component_with_force_disassociation_obj, sagemaker_session
+):
+ assert trial_component_with_force_disassociation_obj.trial_component_name
+ trials = sagemaker_session.sagemaker_client.list_trials(
+ TrialComponentName=trial_component_with_force_disassociation_obj.trial_component_name
+ )["TrialSummaries"]
+ assert len(trials) == 3
+
+
+def test_save(trial_component_obj, sagemaker_session):
+ trial_component_obj.display_name = str(uuid.uuid4())
+ trial_component_obj.status = _api_types.TrialComponentStatus(
+ primary_status=_TrialComponentStatusType.InProgress.value, message="Message"
+ )
+ trial_component_obj.start_time = datetime.datetime.now(
+ datetime.timezone.utc
+ ) - datetime.timedelta(days=1)
+ trial_component_obj.end_time = datetime.datetime.now(datetime.timezone.utc)
+ trial_component_obj.parameters = {"foo": "bar", "whizz": 100.1}
+ trial_component_obj.input_artifacts = {
+ "snizz": _api_types.TrialComponentArtifact(value="s3:/foo/bar", media_type="text/plain"),
+ "snizz1": _api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2"),
+ }
+ trial_component_obj.output_artifacts = {
+ "fly": _api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow"),
+ "fly2": _api_types.TrialComponentArtifact(
+ value="s3:/sky/far2", media_type="away/tomorrow2"
+ ),
+ }
+ trial_component_obj.parameters_to_remove = ["foo"]
+ trial_component_obj.input_artifacts_to_remove = ["snizz"]
+ trial_component_obj.output_artifacts_to_remove = ["fly2"]
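+    # save() applies both the field updates and the *_to_remove lists in one
+    # call, so "foo", "snizz", and "fly2" must be absent after the reload below.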
+
+ trial_component_obj.save()
+
+ loaded = trial_component._TrialComponent.load(
+ trial_component_name=trial_component_obj.trial_component_name,
+ sagemaker_session=sagemaker_session,
+ )
+
+ assert trial_component_obj.trial_component_name == loaded.trial_component_name
+ assert trial_component_obj.status == loaded.status
+
+    assert abs(trial_component_obj.start_time - loaded.start_time) < datetime.timedelta(seconds=1)
+    assert abs(trial_component_obj.end_time - loaded.end_time) < datetime.timedelta(seconds=1)
+
+ assert loaded.parameters == {"whizz": 100.1}
+ assert loaded.input_artifacts == {
+ "snizz1": _api_types.TrialComponentArtifact(value="s3:/foo/bar2", media_type="text/plain2")
+ }
+ assert loaded.output_artifacts == {
+ "fly": _api_types.TrialComponentArtifact(value="s3:/sky/far", media_type="away/tomorrow")
+ }
+
+
+def test_load(trial_component_obj, sagemaker_session):
+ loaded = trial_component._TrialComponent.load(
+ trial_component_name=trial_component_obj.trial_component_name,
+ sagemaker_session=sagemaker_session,
+ )
+ assert trial_component_obj.trial_component_arn == loaded.trial_component_arn
+
+
+def test_list_sort(trial_components, sagemaker_session):
+ slack = datetime.timedelta(minutes=1)
+ now = datetime.datetime.now(datetime.timezone.utc)
+ trial_component_names = [tc.trial_component_name for tc in trial_components]
+
+ for sort_order in ["Ascending", "Descending"]:
+ trial_component_names_listed = [
+ s.trial_component_name
+ for s in trial_component._TrialComponent.list(
+ created_after=now - slack,
+ created_before=now + slack,
+ sort_by="CreationTime",
+ sort_order=sort_order,
+ sagemaker_session=sagemaker_session,
+ )
+ if s.trial_component_name in trial_component_names
+ ]
+
+ if sort_order == "Descending":
+ trial_component_names_listed = trial_component_names_listed[::-1]
+ assert trial_component_names == trial_component_names_listed
+ assert trial_component_names # sanity test
+
+
+def test_search(sagemaker_session):
+ trial_component_names_searched = []
+ search_filter = Filter(
+ name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX
+ )
+ search_expression = SearchExpression(filters=[search_filter])
+ for s in trial_component._TrialComponent.search(
+ search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session
+ ):
+ trial_component_names_searched.append(s.trial_component_name)
+
+    assert len(trial_component_names_searched) > 0
diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py
index 3c416ffd36..abfe6f6d0d 100644
--- a/tests/integ/sagemaker/lineage/conftest.py
+++ b/tests/integ/sagemaker/lineage/conftest.py
@@ -26,6 +26,7 @@
artifact,
)
from sagemaker.model import ModelPackage
+from sagemaker.utils import retry_with_backoff
from tests.integ.sagemaker.workflow.test_workflow import (
test_end_to_end_pipeline_successful_execution,
)
@@ -43,7 +44,7 @@
)
from sagemaker.lineage.lineage_trial_component import LineageTrialComponent
-from tests.integ.sagemaker.lineage.helpers import name, names, retry
+from tests.integ.sagemaker.lineage.helpers import name, names
SLEEP_TIME_SECONDS = 1
SLEEP_TIME_TWO_SECONDS = 2
@@ -400,7 +401,7 @@ def model_obj(sagemaker_session):
yield model
time.sleep(SLEEP_TIME_SECONDS)
- retry(lambda: model.delete(disassociate=True), num_attempts=4)
+ retry_with_backoff(lambda: model.delete(disassociate=True), num_attempts=4)
@pytest.fixture
diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py
index fb71d1d88c..5548c63cff 100644
--- a/tests/integ/sagemaker/lineage/helpers.py
+++ b/tests/integ/sagemaker/lineage/helpers.py
@@ -15,7 +15,6 @@
import uuid
from datetime import datetime
-import time
def name():
@@ -33,19 +32,6 @@ def names():
]
-def retry(callable, num_attempts=8):
- assert num_attempts >= 1
- for i in range(num_attempts):
- try:
- return callable()
- except Exception as ex:
- if i == num_attempts - 1:
- raise ex
- print("Retrying", ex)
- time.sleep(2**i)
- assert False, "logic error in retry"
-
-
def traverse_graph_back(start_arn, sagemaker_session):
def visit(arn, visited: set):
visited.add(arn)
diff --git a/tests/integ/sagemaker/lineage/test_artifact.py b/tests/integ/sagemaker/lineage/test_artifact.py
index c629fcdc30..1980b51da2 100644
--- a/tests/integ/sagemaker/lineage/test_artifact.py
+++ b/tests/integ/sagemaker/lineage/test_artifact.py
@@ -20,7 +20,7 @@
import pytest
from sagemaker.lineage import artifact
-from tests.integ.sagemaker.lineage.helpers import retry
+from sagemaker.utils import retry_with_backoff
def test_create_delete(artifact_obj):
@@ -125,7 +125,7 @@ def validate():
assert len(trials) == 1
assert trial_obj.trial_name in trials
- retry(validate, num_attempts=3)
+ retry_with_backoff(validate, num_attempts=3)
def test_downstream_trials_v2(trial_associated_artifact, trial_obj, sagemaker_session):
diff --git a/tests/integ/sagemaker/utilities/__init__.py b/tests/integ/sagemaker/utilities/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/integ/sagemaker/utilities/test_search_expression.py b/tests/integ/sagemaker/utilities/test_search_expression.py
new file mode 100644
index 0000000000..ea7f4476bf
--- /dev/null
+++ b/tests/integ/sagemaker/utilities/test_search_expression.py
@@ -0,0 +1,67 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+
+from tests.integ.sagemaker.experiments.helpers import EXP_INTEG_TEST_NAME_PREFIX
+from sagemaker.experiments.trial_component import _TrialComponent
+from sagemaker.utilities.search_expression import Filter, Operator, SearchExpression, NestedFilter
+
+
+def test_search(sagemaker_session):
+ tc_names_searched = []
+ search_filter = Filter(
+ name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX
+ )
+ search_expression = SearchExpression(filters=[search_filter])
+ for tc in _TrialComponent.search(
+ search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session
+ ):
+ tc_names_searched.append(tc.trial_component_name)
+
+    assert len(tc_names_searched) > 0
+
+
+@pytest.mark.skip(reason="failed validation, need to wait for NestedFilter bug to be fixed")
+def test_nested_search(sagemaker_session):
+ tc_names_searched = []
+ search_filter = Filter(
+ name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX
+ )
+ nested_filter = NestedFilter(property_name="TrialComponentName", filters=[search_filter])
+ search_expression = SearchExpression(nested_filters=[nested_filter])
+ for tc in _TrialComponent.search(
+ search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session
+ ):
+ tc_names_searched.append(tc.trial_component_name)
+
+    assert len(tc_names_searched) > 0
+
+
+def test_sub_expression(sagemaker_session):
+ tc_names_searched = []
+ search_filter = Filter(
+ name="TrialComponentName", operator=Operator.CONTAINS, value=EXP_INTEG_TEST_NAME_PREFIX
+ )
+ sub_expression = SearchExpression(filters=[search_filter])
+ search_expression = SearchExpression(sub_expressions=[sub_expression])
+ for tc in _TrialComponent.search(
+ search_expression=search_expression, max_results=10, sagemaker_session=sagemaker_session
+ ):
+ tc_names_searched.append(tc.trial_component_name)
+
+    assert len(tc_names_searched) > 0
diff --git a/tests/integ/sagemaker/workflow/test_model_steps.py b/tests/integ/sagemaker/workflow/test_model_steps.py
index 31c518b100..f25723c440 100644
--- a/tests/integ/sagemaker/workflow/test_model_steps.py
+++ b/tests/integ/sagemaker/workflow/test_model_steps.py
@@ -112,6 +112,7 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen
inference_instances=["ml.m5.xlarge"],
transform_instances=["ml.m5.xlarge"],
description="test-description",
+ model_package_name="model-pkg-name-will-be-popped-out",
)
step_model_regis = ModelStep(
name="pytorch-register-model",
diff --git a/tests/integ/sagemaker/workflow/test_processing_steps.py b/tests/integ/sagemaker/workflow/test_processing_steps.py
index 781bce85a7..238eff6123 100644
--- a/tests/integ/sagemaker/workflow/test_processing_steps.py
+++ b/tests/integ/sagemaker/workflow/test_processing_steps.py
@@ -17,15 +17,18 @@
import re
import subprocess
from datetime import datetime
+from pathlib import Path
import pytest
from botocore.exceptions import WaiterError
+from sagemaker.workflow.utilities import hash_files_or_dirs, hash_object
from sagemaker import image_uris, get_execution_role, utils
from sagemaker.dataset_definition import DatasetDefinition, AthenaDatasetDefinition
-from sagemaker.processing import ProcessingInput, ProcessingOutput
-from sagemaker.s3 import S3Uploader
-from sagemaker.sklearn import SKLearnProcessor
+from sagemaker.processing import ProcessingInput, ProcessingOutput, FrameworkProcessor
+from sagemaker.s3 import S3Uploader, S3Downloader
+from sagemaker.sklearn import SKLearnProcessor, SKLearn
+from sagemaker.tensorflow import TensorFlow
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
from sagemaker.workflow.pipeline import Pipeline
from sagemaker.workflow.steps import (
@@ -379,6 +382,203 @@ def test_one_step_framework_processing_pipeline(
pass
+def test_multi_step_framework_processing_pipeline_same_source_dir(
+ pipeline_session, role, pipeline_name
+):
+ default_bucket = pipeline_session.default_bucket()
+ cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
+
+ SOURCE_DIR = "/pipeline/test_source_dir"
+
+ framework_processor_tf = FrameworkProcessor(
+ role=role,
+ instance_type="ml.m5.xlarge",
+ instance_count=1,
+ estimator_cls=TensorFlow,
+ framework_version="2.9",
+ py_version="py39",
+ sagemaker_session=pipeline_session,
+ )
+
+ framework_processor_sk = FrameworkProcessor(
+ framework_version="1.0-1",
+ instance_type="ml.m5.xlarge",
+ instance_count=1,
+ base_job_name="my-job",
+ role=role,
+ estimator_cls=SKLearn,
+ sagemaker_session=pipeline_session,
+ )
+
+ step_1 = ProcessingStep(
+ name="Step-1",
+ step_args=framework_processor_tf.run(
+ code="script_1.py",
+ source_dir=DATA_DIR + SOURCE_DIR,
+ outputs=[ProcessingOutput(output_name="test", source="/opt/ml/processing/test")],
+ ),
+ cache_config=cache_config,
+ )
+
+ step_2 = ProcessingStep(
+ name="Step-2",
+ step_args=framework_processor_sk.run(
+ code="script_2.py",
+ source_dir=DATA_DIR + SOURCE_DIR,
+ inputs=[
+ ProcessingInput(
+ source=step_1.properties.ProcessingOutputConfig.Outputs["test"].S3Output.S3Uri,
+ destination="/opt/ml/processing/test",
+ ),
+ ],
+ ),
+ cache_config=cache_config,
+ )
+
+ pipeline = Pipeline(
+ name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session
+ )
+ try:
+ pipeline.create(role)
+ definition = json.loads(pipeline.definition())
+
+ source_dir_1_s3_uri, entry_point_1 = _verify_code_artifacts_of_framework_processing_step(
+ pipeline_session,
+ framework_processor_tf,
+ default_bucket,
+ pipeline_name,
+ definition["Steps"][0],
+ SOURCE_DIR,
+ "script_1.py",
+ )
+ source_dir_2_s3_uri, entry_point_2 = _verify_code_artifacts_of_framework_processing_step(
+ pipeline_session,
+ framework_processor_sk,
+ default_bucket,
+ pipeline_name,
+ definition["Steps"][1],
+ SOURCE_DIR,
+ "script_2.py",
+ )
+
+ # the same local source_dirs should have the same s3 paths
+ assert source_dir_1_s3_uri == source_dir_2_s3_uri
+
+ # verify different entry_point paths
+ assert entry_point_1 != entry_point_2
+
+ execution = pipeline.start(parameters={})
+ try:
+ execution.wait(delay=540, max_attempts=3)
+ except WaiterError:
+ pass
+
+ execution_steps = execution.list_steps()
+ assert len(execution_steps) == 2
+ for step in execution_steps:
+ assert step["StepStatus"] == "Succeeded"
+
+ finally:
+ try:
+ pipeline.delete()
+ except Exception:
+ pass
+
+
+def test_multi_step_framework_processing_pipeline_different_source_dir(
+ pipeline_session, role, pipeline_name
+):
+ default_bucket = pipeline_session.default_bucket()
+ cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
+
+ SOURCE_DIR_1 = "/pipeline/test_source_dir"
+ SOURCE_DIR_2 = "/pipeline/test_source_dir_2"
+
+ framework_processor_tf = FrameworkProcessor(
+ role=role,
+ instance_type="ml.m5.xlarge",
+ instance_count=1,
+ estimator_cls=TensorFlow,
+ framework_version="2.9",
+ py_version="py39",
+ sagemaker_session=pipeline_session,
+ )
+
+ step_1 = ProcessingStep(
+ name="Step-1",
+ step_args=framework_processor_tf.run(
+ code="script_1.py",
+ source_dir=DATA_DIR + SOURCE_DIR_1,
+ outputs=[ProcessingOutput(output_name="test", source="/opt/ml/processing/test")],
+ ),
+ cache_config=cache_config,
+ )
+
+ step_2 = ProcessingStep(
+ name="Step-2",
+ step_args=framework_processor_tf.run(
+ code="script_2.py",
+ source_dir=DATA_DIR + SOURCE_DIR_2,
+ inputs=[
+ ProcessingInput(
+ source=step_1.properties.ProcessingOutputConfig.Outputs["test"].S3Output.S3Uri,
+ destination="/opt/ml/processing/test",
+ ),
+ ],
+ ),
+ cache_config=cache_config,
+ )
+
+ pipeline = Pipeline(
+ name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session
+ )
+ try:
+ pipeline.create(role)
+ definition = json.loads(pipeline.definition())
+
+ source_dir_1_s3_uri, entry_point_1 = _verify_code_artifacts_of_framework_processing_step(
+ pipeline_session,
+ framework_processor_tf,
+ default_bucket,
+ pipeline_name,
+ definition["Steps"][0],
+ SOURCE_DIR_1,
+ "script_1.py",
+ )
+ source_dir_2_s3_uri, entry_point_2 = _verify_code_artifacts_of_framework_processing_step(
+ pipeline_session,
+ framework_processor_tf,
+ default_bucket,
+ pipeline_name,
+ definition["Steps"][1],
+ SOURCE_DIR_2,
+ "script_2.py",
+ )
+
+ # different local source_dirs should have different s3 paths
+ assert source_dir_1_s3_uri != source_dir_2_s3_uri
+
+ # verify different entry_point paths
+ assert entry_point_1 != entry_point_2
+
+ execution = pipeline.start(parameters={})
+ try:
+ execution.wait(delay=540, max_attempts=3)
+ except WaiterError:
+ pass
+
+ execution_steps = execution.list_steps()
+ assert len(execution_steps) == 2
+ for step in execution_steps:
+ assert step["StepStatus"] == "Succeeded"
+
+ finally:
+ try:
+ pipeline.delete()
+ except Exception:
+ pass
+
+
def test_one_step_pyspark_processing_pipeline(
sagemaker_session,
role,
@@ -796,3 +996,46 @@ def test_two_processing_job_depends_on(
pipeline.delete()
except Exception:
pass
+
+
+def _verify_code_artifacts_of_framework_processing_step(
+ pipeline_session, processor, bucket, pipeline_name, step_definition, source_dir, entry_point
+):
+
+    source_dir_s3_uri = (
+        f"s3://{bucket}/{pipeline_name}"
+        f"/code/{hash_files_or_dirs([f'{DATA_DIR}/{source_dir}'])}"
+    )
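+    # Code artifacts are uploaded under a hash of the source dir contents, so
+    # steps sharing the same local source_dir resolve to the same S3 prefix.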
+
+ # verify runproc.sh prefix is different from code artifact prefix
+ runprocs = []
+ for input_obj in step_definition["Arguments"]["ProcessingInputs"]:
+ if input_obj["InputName"] == "entrypoint":
+ s3_uri = input_obj["S3Input"]["S3Uri"]
+ runprocs.append(s3_uri)
+
+            # compare string prefixes directly; Path would mangle the "s3://" scheme
+            assert s3_uri.rsplit("/", 1)[0] != source_dir_s3_uri
+
+ # verify only one entrypoint generated per step
+ assert len(runprocs) == 1
+
+    expected_source_dir_tar = (
+        f"{pipeline_name}"
+        f"/code/{hash_files_or_dirs([DATA_DIR + source_dir])}/sourcedir.tar.gz"
+    )
+
+ step_script = processor._generate_framework_script(entry_point)
+ expected_step_artifact = f"{pipeline_name}/code/{hash_object(step_script)}/runproc.sh"
+
+ expected_prefix = f"{pipeline_name}/code"
+ s3_code_objects = pipeline_session.list_s3_files(bucket=bucket, key_prefix=expected_prefix)
+
+ # verify all distinct artifacts were uploaded
+ assert expected_source_dir_tar in s3_code_objects
+ assert expected_step_artifact in s3_code_objects
+
+ # verify runprocs contain the correct commands
+ step_runproc = S3Downloader.read_file(
+ f"s3://{bucket}/{expected_step_artifact}", pipeline_session
+ )
+ assert f"python {entry_point}" in step_runproc
+    return source_dir_s3_uri, expected_step_artifact
diff --git a/tests/integ/sagemaker/workflow/test_workflow.py b/tests/integ/sagemaker/workflow/test_workflow.py
index 634ef752d6..44f4e2d26e 100644
--- a/tests/integ/sagemaker/workflow/test_workflow.py
+++ b/tests/integ/sagemaker/workflow/test_workflow.py
@@ -1168,7 +1168,13 @@ def walk():
def test_caching_behavior(
- pipeline_session, role, cpu_instance_type, pipeline_name, script_dir, athena_dataset_definition
+ pipeline_session,
+ role,
+ cpu_instance_type,
+ pipeline_name,
+ script_dir,
+ athena_dataset_definition,
+ region_name,
):
default_bucket = pipeline_session.default_bucket()
data_path = os.path.join(DATA_DIR, "workflow")
diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py
index c1b84117c3..e19cebdca4 100644
--- a/tests/integ/test_feature_store.py
+++ b/tests/integ/test_feature_store.py
@@ -14,6 +14,7 @@
import json
import time
+import datetime
from contextlib import contextmanager
import boto3
@@ -24,6 +25,7 @@
from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition
from sagemaker.feature_store.feature_group import FeatureGroup
+from sagemaker.feature_store.feature_store import FeatureStore
from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter, TableFormatEnum
from sagemaker.session import get_execution_role, Session
from tests.integ.timeout import timeout
@@ -80,6 +82,11 @@ def feature_group_name():
return f"my-feature-group-{int(time.time() * 10**7)}"
+@pytest.fixture
+def base_name():
+ return f"my-base-{int(time.time() * 10**7)}"
+
+
@pytest.fixture
def offline_store_s3_uri(feature_store_session, region_name):
bucket = f"sagemaker-test-featurestore-{region_name}-{feature_store_session.account_id()}"
@@ -107,6 +114,32 @@ def pandas_data_frame():
return df
+@pytest.fixture
+def base_dataframe():
+ base_data = [
+ [1, 187512346.0, 123, 128],
+ [2, 187512347.0, 168, 258],
+ [3, 187512348.0, 125, 184],
+ [1, 187512349.0, 195, 206],
+ ]
+ return pd.DataFrame(
+ base_data, columns=["base_id", "base_time", "base_feature_1", "base_feature_2"]
+ )
+
+
+@pytest.fixture
+def feature_group_dataframe():
+ feature_group_data = [
+ [1, 187512246.0, 456, 325],
+ [2, 187512247.0, 729, 693],
+ [3, 187512348.0, 129, 901],
+ [1, 187512449.0, 289, 286],
+ ]
+ return pd.DataFrame(
+ feature_group_data, columns=["fg_id", "fg_time", "fg_feature_1", "fg_feature_2"]
+ )
+
+
@pytest.fixture
def pandas_data_frame_without_string():
df = pd.DataFrame(
@@ -288,6 +321,92 @@ def test_create_feature_group_glue_table_format(
assert table_format == "Glue"
+def test_get_record(
+ feature_store_session,
+ role,
+ feature_group_name,
+ pandas_data_frame,
+ record,
+):
+ feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
+ feature_group.load_feature_definitions(data_frame=pandas_data_frame)
+
+ record_identifier_value_as_string = record[0].value_as_string
+ with cleanup_feature_group(feature_group):
+ feature_group.create(
+ s3_uri=False,
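+            # s3_uri=False disables the offline store (online-store-only group)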
+ record_identifier_name="feature1",
+ event_time_feature_name="feature3",
+ role_arn=role,
+ enable_online_store=True,
+ )
+ _wait_for_feature_group_create(feature_group)
+ # Ingest data
+ feature_group.put_record(record=record)
+ # Retrieve data
+ retrieved_record = feature_group.get_record(
+ record_identifier_value_as_string=record_identifier_value_as_string,
+ )
+ record_names = list(map(lambda r: r.feature_name, record))
+ assert len(retrieved_record) == len(record_names)
+ for feature in retrieved_record:
+ assert feature["FeatureName"] in record_names
+ removed_feature_name = record_names.pop()
+ # Retrieve data
+ retrieved_record = feature_group.get_record(
+ record_identifier_value_as_string=record_identifier_value_as_string,
+ feature_names=record_names,
+ )
+ assert len(retrieved_record) == len(record_names)
+ for feature in retrieved_record:
+ assert feature["FeatureName"] in record_names
+ assert feature["FeatureName"] is not removed_feature_name
+ # Retrieve data
+ retrieved_record = feature_group.get_record(
+ record_identifier_value_as_string="1.0",
+ )
+ assert retrieved_record is None
+
+
+def test_delete_record(
+ feature_store_session,
+ role,
+ feature_group_name,
+ pandas_data_frame,
+ record,
+):
+ feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
+ feature_group.load_feature_definitions(data_frame=pandas_data_frame)
+
+ record_identifier_value_as_string = record[0].value_as_string
+ with cleanup_feature_group(feature_group):
+ feature_group.create(
+ s3_uri=False,
+ record_identifier_name="feature1",
+ event_time_feature_name="feature3",
+ role_arn=role,
+ enable_online_store=True,
+ )
+ _wait_for_feature_group_create(feature_group)
+ # Ingest data
+ feature_group.put_record(record=record)
+ # Retrieve data
+ retrieved_record = feature_group.get_record(
+ record_identifier_value_as_string=record_identifier_value_as_string,
+ )
+ assert retrieved_record is not None
+ # Delete data
+ feature_group.delete_record(
+ record_identifier_value_as_string=record_identifier_value_as_string,
+ event_time=datetime.datetime.now().replace(microsecond=0).isoformat() + "Z",
+ )
+ # Retrieve data
+ retrieved_record = feature_group.get_record(
+ record_identifier_value_as_string=record_identifier_value_as_string,
+ )
+ assert retrieved_record is None
+
+
def test_update_feature_group(
feature_store_session,
role,
@@ -316,6 +435,25 @@ def test_update_feature_group(
assert any([True for elem in feature_definitions if new_feature_name in elem.values()])
+def test_list_feature_groups(feature_store_session, role, feature_group_name, pandas_data_frame):
+ feature_store = FeatureStore(sagemaker_session=feature_store_session)
+ feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
+ feature_group.load_feature_definitions(data_frame=pandas_data_frame)
+
+ with cleanup_feature_group(feature_group):
+ feature_group.create(
+ s3_uri=False,
+ record_identifier_name="feature1",
+ event_time_feature_name="feature3",
+ role_arn=role,
+ enable_online_store=True,
+ )
+ _wait_for_feature_group_create(feature_group)
+ output = feature_store.list_feature_groups(name_contains=feature_group_name)
+
+ assert output["FeatureGroupSummaries"][0]["FeatureGroupName"] == feature_group_name
+
+
def test_feature_metadata(
feature_store_session,
role,
@@ -420,6 +558,242 @@ def test_ingest_multi_process(
assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")
+def test_create_dataset_with_feature_group_base(
+ feature_store_session,
+ region_name,
+ role,
+ base_name,
+ feature_group_name,
+ offline_store_s3_uri,
+ base_dataframe,
+ feature_group_dataframe,
+):
+ base = FeatureGroup(name=base_name, sagemaker_session=feature_store_session)
+ feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
+ with cleanup_feature_group(base), cleanup_feature_group(feature_group):
+ _create_feature_group_and_ingest_data(
+ base, base_dataframe, offline_store_s3_uri, "base_id", "base_time", role
+ )
+ _create_feature_group_and_ingest_data(
+ feature_group, feature_group_dataframe, offline_store_s3_uri, "fg_id", "fg_time", role
+ )
+ base_table_name = _get_athena_table_name_after_data_replication(
+ feature_store_session, base, offline_store_s3_uri
+ )
+ feature_group_table_name = _get_athena_table_name_after_data_replication(
+ feature_store_session, feature_group, offline_store_s3_uri
+ )
+
+        with timeout(minutes=10), cleanup_offline_store(
+            base_table_name, feature_store_session
+        ), cleanup_offline_store(feature_group_table_name, feature_store_session):
+ feature_store = FeatureStore(sagemaker_session=feature_store_session)
+ df, query_string = (
+ feature_store.create_dataset(base=base, output_path=offline_store_s3_uri)
+ .with_number_of_recent_records_by_record_identifier(4)
+ .with_feature_group(feature_group)
+ .to_dataframe()
+ )
+ sorted_df = df.sort_values(by=list(df.columns)).reset_index(drop=True)
+ merged_df = base_dataframe.merge(
+ feature_group_dataframe, left_on="base_id", right_on="fg_id"
+ )
+
+ expect_df = merged_df.sort_values(by=list(merged_df.columns)).reset_index(drop=True)
+
+ expect_df.rename(
+ columns={
+ "fg_id": "fg_id.1",
+ "fg_time": "fg_time.1",
+ "fg_feature_1": "fg_feature_1.1",
+ "fg_feature_2": "fg_feature_2.1",
+ },
+ inplace=True,
+ )
+
+ assert sorted_df.equals(expect_df)
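+        # The generated Athena query below dedups each table on (record id,
+        # event time), filters out soft-deleted records, joins the feature
+        # group onto the base table, and keeps the 4 most recent rows per
+        # record identifier (row_recent <= 4).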
+ assert (
+ query_string
+ == "WITH fg_base AS (WITH table_base AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + 'PARTITION BY origin_base."base_id", origin_base."base_time"\n'
+ + 'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n'
+ + ") AS dedup_row_base\n"
+ + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n'
+ + ")\n"
+ + "WHERE dedup_row_base = 1\n"
+ + "),\n"
+ + "deleted_base AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + 'PARTITION BY origin_base."base_id"\n'
+ + 'ORDER BY origin_base."base_time" DESC,'
+ ' origin_base."api_invocation_time" DESC,'
+ ' origin_base."write_time" DESC\n'
+ + ") AS deleted_row_base\n"
+ + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n'
+ + "WHERE is_deleted\n"
+ + ")\n"
+ + "WHERE deleted_row_base = 1\n"
+ + ")\n"
+ + 'SELECT table_base."base_id", table_base."base_time",'
+ ' table_base."base_feature_1", table_base."base_feature_2"\n'
+ + "FROM (\n"
+ + 'SELECT table_base."base_id", table_base."base_time",'
+ ' table_base."base_feature_1", table_base."base_feature_2",'
+ ' table_base."write_time"\n'
+ + "FROM table_base\n"
+ + "LEFT JOIN deleted_base\n"
+ + 'ON table_base."base_id" = deleted_base."base_id"\n'
+ + 'WHERE deleted_base."base_id" IS NULL\n'
+ + "UNION ALL\n"
+ + 'SELECT table_base."base_id", table_base."base_time",'
+ ' table_base."base_feature_1", table_base."base_feature_2",'
+ ' table_base."write_time"\n'
+ + "FROM deleted_base\n"
+ + "JOIN table_base\n"
+ + 'ON table_base."base_id" = deleted_base."base_id"\n'
+ + "AND (\n"
+ + 'table_base."base_time" > deleted_base."base_time"\n'
+ + 'OR (table_base."base_time" = deleted_base."base_time" AND'
+ ' table_base."api_invocation_time" >'
+ ' deleted_base."api_invocation_time")\n'
+ + 'OR (table_base."base_time" = deleted_base."base_time" AND'
+ ' table_base."api_invocation_time" ='
+ ' deleted_base."api_invocation_time" AND'
+ ' table_base."write_time" > deleted_base."write_time")\n'
+ + ")\n"
+ + ") AS table_base\n"
+ + "),\n"
+ + "fg_0 AS (WITH table_0 AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + 'PARTITION BY origin_0."fg_id", origin_0."fg_time"\n'
+ + 'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n'
+ + ") AS dedup_row_0\n"
+ + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n'
+ + ")\n"
+ + "WHERE dedup_row_0 = 1\n"
+ + "),\n"
+ + "deleted_0 AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + 'PARTITION BY origin_0."fg_id"\n'
+ + 'ORDER BY origin_0."fg_time" DESC, origin_0."api_invocation_time" DESC,'
+ ' origin_0."write_time" DESC\n'
+ + ") AS deleted_row_0\n"
+ + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n'
+ + "WHERE is_deleted\n"
+ + ")\n"
+ + "WHERE deleted_row_0 = 1\n"
+ + ")\n"
+ + 'SELECT table_0."fg_id", table_0."fg_time", table_0."fg_feature_1",'
+ ' table_0."fg_feature_2"\n'
+ + "FROM (\n"
+ + 'SELECT table_0."fg_id", table_0."fg_time",'
+ ' table_0."fg_feature_1", table_0."fg_feature_2",'
+ ' table_0."write_time"\n'
+ + "FROM table_0\n"
+ + "LEFT JOIN deleted_0\n"
+ + 'ON table_0."fg_id" = deleted_0."fg_id"\n'
+ + 'WHERE deleted_0."fg_id" IS NULL\n'
+ + "UNION ALL\n"
+ + 'SELECT table_0."fg_id", table_0."fg_time",'
+ ' table_0."fg_feature_1", table_0."fg_feature_2",'
+ ' table_0."write_time"\n'
+ + "FROM deleted_0\n"
+ + "JOIN table_0\n"
+ + 'ON table_0."fg_id" = deleted_0."fg_id"\n'
+ + "AND (\n"
+ + 'table_0."fg_time" > deleted_0."fg_time"\n'
+ + 'OR (table_0."fg_time" = deleted_0."fg_time" AND'
+ ' table_0."api_invocation_time" >'
+ ' deleted_0."api_invocation_time")\n'
+ + 'OR (table_0."fg_time" = deleted_0."fg_time" AND'
+ ' table_0."api_invocation_time" ='
+ ' deleted_0."api_invocation_time" AND table_0."write_time" >'
+ ' deleted_0."write_time")\n'
+ + ")\n"
+ + ") AS table_0\n"
+ + ")\n"
+ + "SELECT base_id, base_time, base_feature_1, base_feature_2,"
+ ' "fg_id.1", "fg_time.1", "fg_feature_1.1",'
+ ' "fg_feature_2.1"\n' + "FROM (\n" + "SELECT fg_base.base_id, fg_base.base_time,"
+ " fg_base.base_feature_1, fg_base.base_feature_2,"
+ ' fg_0."fg_id" as "fg_id.1", fg_0."fg_time" as "fg_time.1",'
+ ' fg_0."fg_feature_1" as "fg_feature_1.1",'
+ ' fg_0."fg_feature_2" as "fg_feature_2.1", row_number()'
+ " OVER (\n"
+ + 'PARTITION BY fg_base."base_id"\n'
+ + 'ORDER BY fg_base."base_time" DESC, fg_0."fg_time" DESC\n'
+ + ") AS row_recent\n"
+ + "FROM fg_base\n"
+ + "JOIN fg_0\n"
+ + 'ON fg_base."base_id" = fg_0."fg_id"\n'
+ + ")\n"
+ + "WHERE row_recent <= 4"
+ )
+
+
+def _create_feature_group_and_ingest_data(
+ feature_group: FeatureGroup,
+ dataframe: DataFrame,
+ offline_store_s3_uri: str,
+ record_identifier_name: str,
+ event_time_name: str,
+ role: str,
+):
+ feature_group.load_feature_definitions(data_frame=dataframe)
+ feature_group.create(
+ s3_uri=offline_store_s3_uri,
+ record_identifier_name=record_identifier_name,
+ event_time_feature_name=event_time_name,
+ role_arn=role,
+ enable_online_store=True,
+ )
+ _wait_for_feature_group_create(feature_group)
+
+ ingestion_manager = feature_group.ingest(data_frame=dataframe, max_workers=3, wait=False)
+ ingestion_manager.wait()
+ assert 0 == len(ingestion_manager.failed_rows)
+
+
+def _get_athena_table_name_after_data_replication(
+ feature_store_session, feature_group: FeatureGroup, offline_store_s3_uri
+):
+ feature_group_metadata = feature_group.describe()
+ resolved_output_s3_uri = (
+ feature_group_metadata.get("OfflineStoreConfig", None)
+ .get("S3StorageConfig", None)
+ .get("ResolvedOutputS3Uri", None)
+ )
+ s3_prefix = resolved_output_s3_uri.replace(f"{offline_store_s3_uri}/", "")
+ region_name = feature_store_session.boto_session.region_name
+ s3_client = feature_store_session.boto_session.client(
+ service_name="s3", region_name=region_name
+ )
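+    # Offline store replication is asynchronous and can take several minutes;
+    # poll until multiple objects land under the feature group's prefix.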
+ while True:
+ objects_in_bucket = s3_client.list_objects(
+ Bucket=offline_store_s3_uri.replace("s3://", ""), Prefix=s3_prefix
+ )
+ if "Contents" in objects_in_bucket and len(objects_in_bucket["Contents"]) > 1:
+ break
+ else:
+ print(f"Waiting for {feature_group.name} data in offline store...")
+ time.sleep(60)
+ print(f"{feature_group.name} data available.")
+ return (
+ feature_group_metadata.get("OfflineStoreConfig", None)
+ .get("DataCatalogConfig", None)
+ .get("TableName", None)
+ )
+
+
def _wait_for_feature_group_create(feature_group: FeatureGroup):
status = feature_group.describe().get("FeatureGroupStatus")
while status == "Creating":
@@ -451,5 +825,31 @@ def cleanup_feature_group(feature_group: FeatureGroup):
finally:
try:
feature_group.delete()
+ print(f"{feature_group.name} is deleted")
except Exception:
raise RuntimeError(f"Failed to delete feature group with name {feature_group.name}")
+
+
+@contextmanager
+def cleanup_offline_store(table_name: str, feature_store_session: Session):
+ try:
+ yield
+ finally:
+ try:
+ region_name = feature_store_session.boto_session.region_name
+ s3_client = feature_store_session.boto_session.client(
+ service_name="s3", region_name=region_name
+ )
+ account_id = feature_store_session.account_id()
+ bucket_name = f"sagemaker-test-featurestore-{region_name}-{account_id}"
+ response = s3_client.list_objects_v2(
+ Bucket=bucket_name,
+ Prefix=f"{account_id}/sagemaker/{region_name}/offline-store/{table_name}/",
+ )
+ files_in_folder = response["Contents"]
+ files_to_delete = []
+ for f in files_in_folder:
+ files_to_delete.append({"Key": f["Key"]})
+ s3_client.delete_objects(Bucket=bucket_name, Delete={"Objects": files_to_delete})
+ except Exception:
+ raise RuntimeError(f"Failed to delete data under {table_name}")
diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py
index 53d966fe9b..a26d8c9101 100644
--- a/tests/integ/test_inference_pipeline.py
+++ b/tests/integ/test_inference_pipeline.py
@@ -50,6 +50,7 @@
)
+@pytest.mark.skip(reason="Test has likely been failing for a while. Suspected bad XGB model.")
def test_inference_pipeline_batch_transform(sagemaker_session, cpu_instance_type):
sparkml_model_data = sagemaker_session.upload_data(
path=os.path.join(SPARKML_DATA_PATH, "mleap_model.tar.gz"),
diff --git a/tests/integ/test_marketplace.py b/tests/integ/test_marketplace.py
index b9ff13c50e..28b537c1ea 100644
--- a/tests/integ/test_marketplace.py
+++ b/tests/integ/test_marketplace.py
@@ -23,6 +23,7 @@
import sagemaker
import tests.integ
+from tests.integ.utils import create_repository
from sagemaker import AlgorithmEstimator, ModelPackage, Model
from sagemaker.serializers import CSVSerializer
from sagemaker.tuner import IntegerParameter, HyperparameterTuner
@@ -33,7 +34,6 @@
from tests.integ.test_multidatamodel import (
_ecr_image_uri,
_ecr_login,
- _create_repository,
_delete_repository,
)
from tests.integ.retry import retries
@@ -214,7 +214,7 @@ def iris_image(sagemaker_session):
rm=True,
)
image.tag(ecr_image, tag="latest")
- _create_repository(ecr_client, algorithm_name)
+ create_repository(ecr_client, algorithm_name)
# Retry docker image push
for _ in retries(3, "Upload docker image to ECR repo", seconds_to_sleep=10):
diff --git a/tests/integ/test_multidatamodel.py b/tests/integ/test_multidatamodel.py
index 78ba62c3db..d6c14037a7 100644
--- a/tests/integ/test_multidatamodel.py
+++ b/tests/integ/test_multidatamodel.py
@@ -19,8 +19,8 @@
import docker
import numpy
import pytest
-from botocore.exceptions import ClientError
+from tests.integ.utils import create_repository
from sagemaker import utils
from sagemaker.amazon.randomcutforest import RandomCutForest
from sagemaker.deserializers import StringDeserializer
@@ -59,7 +59,7 @@ def container_image(sagemaker_session):
image.tag(ecr_image, tag="latest")
# Create AWS ECR and push the local docker image to it
- _create_repository(ecr_client, algorithm_name)
+ create_repository(ecr_client, algorithm_name)
# Retry docker image push
for _ in retries(3, "Upload docker image to ECR repo", seconds_to_sleep=10):
@@ -90,23 +90,6 @@ def _ecr_image_uri(sagemaker_session, algorithm_name):
return "{}.dkr.{}/{}:latest".format(account_id, endpoint_data["hostname"], algorithm_name)
-def _create_repository(ecr_client, repository_name):
- """
- Creates an ECS Repository (ECR). When a new transform is being registered,
- we'll need a repository to push the image (and composed model images) to
- """
- try:
- response = ecr_client.create_repository(repositoryName=repository_name)
- return response["repository"]["repositoryUri"]
- except ClientError as e:
- # Handle when the repository already exists
- if "RepositoryAlreadyExistsException" == e.response.get("Error", {}).get("Code"):
- response = ecr_client.describe_repositories(repositoryNames=[repository_name])
- return response["repositories"][0]["repositoryUri"]
- else:
- raise
-
-
def _delete_repository(ecr_client, repository_name):
"""
Deletes an ECS Repository (ECR). After the integration test completes
diff --git a/tests/integ/test_training_compiler.py b/tests/integ/test_training_compiler.py
index 67de050ed1..724cd8890c 100644
--- a/tests/integ/test_training_compiler.py
+++ b/tests/integ/test_training_compiler.py
@@ -20,6 +20,8 @@
from sagemaker.huggingface import TrainingCompilerConfig as HFTrainingCompilerConfig
from sagemaker.tensorflow import TensorFlow
from sagemaker.tensorflow import TrainingCompilerConfig as TFTrainingCompilerConfig
+from sagemaker.pytorch import PyTorch
+from sagemaker.pytorch import TrainingCompilerConfig as PTTrainingCompilerConfig
from tests import integ
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
@@ -48,8 +50,7 @@ def imagenet_val_set(request, sagemaker_session, tmpdir_factory):
key_prefix="Imagenet/TFRecords/validation",
)
train_input = sagemaker_session.upload_data(
- path=local_path,
- key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val",
+ path=local_path, key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val"
)
return train_input
@@ -84,8 +85,8 @@ def skip_if_incompatible(gpu_instance_type, request):
@pytest.mark.parametrize(
"gpu_instance_type,instance_count",
[
- ("ml.p3.2xlarge", 1),
- ("ml.p3.16xlarge", 2),
+ pytest.param("ml.p3.2xlarge", 1, marks=pytest.mark.release),
+ pytest.param("ml.p3.16xlarge", 2),
],
)
def test_huggingface_pytorch(
@@ -129,27 +130,32 @@ def test_huggingface_pytorch(
hf.fit(huggingface_dummy_dataset)
-@pytest.mark.release
-def test_huggingface_pytorch_release(
+@pytest.mark.parametrize(
+ "gpu_instance_type,instance_count",
+ [
+ pytest.param("ml.p3.2xlarge", 1, marks=pytest.mark.release),
+ pytest.param("ml.p3.16xlarge", 2),
+ ],
+)
+def test_pytorch(
sagemaker_session,
gpu_instance_type,
- huggingface_training_compiler_latest_version,
- huggingface_training_compiler_pytorch_latest_version,
+ instance_count,
+ pytorch_training_compiler_latest_version,
huggingface_dummy_dataset,
):
"""
- Test the HuggingFace estimator with PyTorch
+ Test the PyTorch estimator
"""
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
- data_path = os.path.join(DATA_DIR, "huggingface")
- hf = HuggingFace(
+ hf = PyTorch(
py_version="py38",
- entry_point=os.path.join(data_path, "run_glue.py"),
+ source_dir=os.path.join(DATA_DIR, "huggingface_byoc"),
+ entry_point="run_glue.py",
role="SageMakerRole",
- transformers_version=huggingface_training_compiler_latest_version,
- pytorch_version=huggingface_training_compiler_pytorch_latest_version,
- instance_count=1,
+ framework_version=pytorch_training_compiler_latest_version,
+ instance_count=instance_count,
instance_type=gpu_instance_type,
hyperparameters={
"model_name_or_path": "distilbert-base-cased",
@@ -163,7 +169,8 @@ def test_huggingface_pytorch_release(
},
sagemaker_session=sagemaker_session,
disable_profiler=True,
- compiler_config=HFTrainingCompilerConfig(),
+ compiler_config=PTTrainingCompilerConfig(),
+ distribution={"pytorchxla": {"enabled": True}} if instance_count > 1 else None,
)
hf.fit(huggingface_dummy_dataset)
@@ -209,10 +216,7 @@ def test_huggingface_tensorflow(
@pytest.mark.release
def test_tensorflow(
- sagemaker_session,
- gpu_instance_type,
- tensorflow_training_latest_version,
- imagenet_val_set,
+ sagemaker_session, gpu_instance_type, tensorflow_training_latest_version, imagenet_val_set
):
"""
Test the TensorFlow estimator
@@ -264,8 +268,4 @@ def test_tensorflow(
compiler_config=TFTrainingCompilerConfig(),
)
- tf.fit(
- inputs=imagenet_val_set,
- logs=True,
- wait=True,
- )
+ tf.fit(inputs=imagenet_val_set, logs=True, wait=True)
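
Note: the parametrization changes above fold the old copied `*_release` test
into the main parametrized test; `pytest.param(..., marks=pytest.mark.release)`
tags only the single-node configuration for the release suite instead of
duplicating the test body. A minimal sketch of the pattern (test name is
illustrative):

    @pytest.mark.parametrize(
        "instance_type,instance_count",
        [
            pytest.param("ml.p3.2xlarge", 1, marks=pytest.mark.release),
            pytest.param("ml.p3.16xlarge", 2),
        ],
    )
    def test_example(instance_type, instance_count):
        ...  # runs for both parameter sets; only the first is in the release suite
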
diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py
index a0e37ffc77..1de333b987 100644
--- a/tests/integ/test_transformer.py
+++ b/tests/integ/test_transformer.py
@@ -25,6 +25,7 @@
from sagemaker.transformer import Transformer
from sagemaker.estimator import Estimator
from sagemaker.inputs import BatchDataCaptureConfig
+from sagemaker.xgboost import XGBoostModel
from sagemaker.utils import unique_name_from_base
from tests.integ import (
datasets,
@@ -36,7 +37,7 @@
from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer
from tests.integ.vpc_test_utils import get_or_create_vpc_resources
-from sagemaker.model_monitor import DatasetFormat, Statistics
+from sagemaker.model_monitor import DatasetFormat, Statistics, Constraints
from sagemaker.workflow.check_job_config import CheckJobConfig
from sagemaker.workflow.quality_check_step import (
@@ -645,3 +646,66 @@ def _create_transformer_and_transform_job(
job_name=unique_name_from_base("test-transform"),
)
return transformer
+
+
+def test_transformer_and_monitoring_job(
+ pipeline_session,
+ sagemaker_session,
+ role,
+ pipeline_name,
+ check_job_config,
+ data_bias_check_config,
+):
+ xgb_model_data_s3 = pipeline_session.upload_data(
+        path=os.path.join(DATA_DIR, "xgboost_abalone", "xgb_model.tar.gz"),
+ key_prefix="integ-test-data/xgboost/model",
+ )
+ data_bias_supplied_baseline_constraints = Constraints.from_file_path(
+ constraints_file_path=os.path.join(
+ DATA_DIR, "pipeline/clarify_check_step/data_bias/good_cases/analysis.json"
+ ),
+ sagemaker_session=sagemaker_session,
+ ).file_s3_uri
+
+ xgb_model = XGBoostModel(
+ model_data=xgb_model_data_s3,
+ framework_version="1.3-1",
+ role=role,
+ sagemaker_session=sagemaker_session,
+        entry_point=os.path.join(DATA_DIR, "xgboost_abalone", "inference.py"),
+ enable_network_isolation=True,
+ )
+
+ xgb_model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE)
+
+ transform_output = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform"
+ transformer = Transformer(
+ model_name=xgb_model.name,
+ strategy="SingleRecord",
+ instance_type="ml.m5.xlarge",
+ instance_count=1,
+ output_path=transform_output,
+ sagemaker_session=pipeline_session,
+ )
+
+ transform_input = pipeline_session.upload_data(
+ path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"),
+ key_prefix="integ-test-data/xgboost_abalone/abalone",
+ )
+
+ execution = transformer.transform_with_monitoring(
+ monitoring_config=data_bias_check_config,
+ monitoring_resource_config=check_job_config,
+ data=transform_input,
+ content_type="text/libsvm",
+ supplied_baseline_constraints=data_bias_supplied_baseline_constraints,
+ role=role,
+ )
+
+ execution_steps = execution.list_steps()
+ assert len(execution_steps) == 2
+
+ for execution_step in execution_steps:
+ assert execution_step["StepStatus"] == "Succeeded"
+
+ xgb_model.delete_model()
diff --git a/tests/integ/test_xgboost.py b/tests/integ/test_xgboost.py
index 733ab4665a..df06a8863a 100644
--- a/tests/integ/test_xgboost.py
+++ b/tests/integ/test_xgboost.py
@@ -40,6 +40,26 @@ def xgboost_training_job(
)
+def test_sourcedir_naming(
+ sagemaker_session,
+ xgboost_latest_version,
+ xgboost_latest_py_version,
+ cpu_instance_type,
+):
+ with pytest.raises(RuntimeError):
+ processor = XGBoostProcessor(
+ framework_version=xgboost_latest_version,
+ role=ROLE,
+ instance_count=1,
+ instance_type=cpu_instance_type,
+ sagemaker_session=sagemaker_session,
+ )
+ processor.run(
+ source_dir="s3://bucket/deps.tar.gz",
+ code="main_script.py",
+ )
+
+
@pytest.mark.release
def test_framework_processing_job_with_deps(
sagemaker_session,
diff --git a/tests/integ/utils.py b/tests/integ/utils.py
index 53440f96f5..d7891321f2 100644
--- a/tests/integ/utils.py
+++ b/tests/integ/utils.py
@@ -14,6 +14,8 @@
import logging
from functools import wraps
+from botocore.exceptions import ClientError
+
from tests.conftest import NO_P3_REGIONS, NO_M4_REGIONS
from sagemaker.exceptions import CapacityError
@@ -69,3 +71,21 @@ def wrapper(*args, **kwargs):
return wrapper
return decorator
+
+
+def create_repository(ecr_client, repository_name):
+ """Creates an ECS Repository (ECR).
+
+ When a new transform is being registered,
+ we'll need a repository to push the image (and composed model images) to
+ """
+ try:
+ response = ecr_client.create_repository(repositoryName=repository_name)
+ return response["repository"]["repositoryUri"]
+ except ClientError as e:
+ # Handle when the repository already exists
+ if "RepositoryAlreadyExistsException" == e.response.get("Error", {}).get("Code"):
+ response = ecr_client.describe_repositories(repositoryNames=[repository_name])
+ return response["repositories"][0]["repositoryUri"]
+ else:
+ raise
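
Note: `create_repository` is deliberately idempotent: it swallows
`RepositoryAlreadyExistsException` and returns the existing repository URI, so
integration tests can be rerun without manual ECR cleanup. A usage sketch
(region and repository name are illustrative):

    import boto3

    ecr_client = boto3.client("ecr", region_name="us-west-2")
    # Safe to call repeatedly; returns the same URI on reruns.
    repository_uri = create_repository(ecr_client, "my-integ-test-repo")
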
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
new file mode 100644
index 0000000000..21fe49cc97
--- /dev/null
+++ b/tests/unit/conftest.py
@@ -0,0 +1,66 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+import sagemaker
+
+from mock import Mock, PropertyMock
+
+_ROLE = "DummyRole"
+_REGION = "us-west-2"
+_DEFAULT_BUCKET = "my-bucket"
+
+
+@pytest.fixture(scope="session")
+def client():
+ """Mock client.
+
+ Considerations when appropriate:
+
+ * utilize botocore.stub.Stubber
+ * separate runtime client from client
+ """
+ client_mock = Mock()
+ client_mock._client_config.user_agent = (
+ "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
+ )
+ return client_mock
+
+
+@pytest.fixture(scope="session")
+def boto_session(client):
+ role_mock = Mock()
+ type(role_mock).arn = PropertyMock(return_value=_ROLE)
+
+ resource_mock = Mock()
+ resource_mock.Role.return_value = role_mock
+
+ session_mock = Mock(region_name=_REGION)
+ session_mock.resource.return_value = resource_mock
+ session_mock.client.return_value = client
+
+ return session_mock
+
+
+@pytest.fixture(scope="session")
+def sagemaker_session(boto_session, client):
+ # ideally this would mock Session instead of instantiating it
+ # most unit tests do mock the session correctly
+ return sagemaker.session.Session(
+ boto_session=boto_session,
+ sagemaker_client=client,
+ sagemaker_runtime_client=client,
+ default_bucket=_DEFAULT_BUCKET,
+ sagemaker_metrics_client=client,
+ )
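
Note: these session-scoped fixtures hand unit tests a real
`sagemaker.session.Session` whose boto clients are mocks, so individual tests
stub API responses directly on the client. A minimal sketch of a consuming
test (hypothetical):

    def test_example(sagemaker_session):
        # Stub the SageMaker API call the code under test will make.
        sagemaker_session.sagemaker_client.describe_experiment.return_value = {
            "Description": "description-value"
        }
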
diff --git a/tests/unit/sagemaker/experiments/__init__.py b/tests/unit/sagemaker/experiments/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/unit/sagemaker/experiments/conftest.py b/tests/unit/sagemaker/experiments/conftest.py
new file mode 100644
index 0000000000..4d33ad759d
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/conftest.py
@@ -0,0 +1,86 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import unittest
+from unittest.mock import patch, MagicMock, Mock
+
+import pytest
+
+from sagemaker import Session
+from sagemaker.experiments.experiment import _Experiment
+from sagemaker.experiments.run import RUN_NAME_BASE
+from sagemaker.experiments import Run
+from tests.unit.sagemaker.experiments.helpers import (
+ mock_tc_load_or_create_func,
+ mock_trial_load_or_create_func,
+ TEST_EXP_NAME,
+)
+
+
+@pytest.fixture
+def client():
+ """Mock client.
+
+ Considerations when appropriate:
+
+ * utilize botocore.stub.Stubber
+ * separate runtime client from client
+ """
+ client_mock = unittest.mock.Mock()
+ client_mock._client_config.user_agent = (
+ "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
+ )
+ return client_mock
+
+
+@pytest.fixture
+def sagemaker_session(client):
+ return Session(
+ sagemaker_client=client,
+ )
+
+
+@pytest.fixture
+def run_obj(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ client.update_trial_component.return_value = {}
+ client.associate_trial_component.return_value = {}
+ with patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(
+ return_value=_Experiment(
+ experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session
+ )
+ ),
+ ):
+ with patch(
+ "sagemaker.experiments.run._TrialComponent._load_or_create",
+ MagicMock(side_effect=mock_tc_load_or_create_func),
+ ):
+ with patch(
+ "sagemaker.experiments.run._Trial._load_or_create",
+ MagicMock(side_effect=mock_trial_load_or_create_func),
+ ):
+ run = Run(
+ experiment_name=TEST_EXP_NAME,
+ sagemaker_session=sagemaker_session,
+ )
+ run._artifact_uploader = Mock()
+ run._lineage_artifact_tracker = Mock()
+ run._metrics_manager = Mock()
+
+ assert run.run_name.startswith(RUN_NAME_BASE)
+ assert run.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
+
+ return run
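
Note: the `run_obj` fixture patches the `_Experiment`/`_Trial`/`_TrialComponent`
load-or-create paths and replaces the artifact uploader, lineage tracker, and
metrics manager with mocks, so a test only has to enter the run context. A
minimal sketch (hypothetical test):

    def test_example(run_obj):
        with run_obj:
            run_obj.log_parameter("foo", "bar")
        assert run_obj._trial_component.parameters["foo"] == "bar"
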
diff --git a/tests/unit/sagemaker/experiments/helpers.py b/tests/unit/sagemaker/experiments/helpers.py
new file mode 100644
index 0000000000..b7914010e5
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/helpers.py
@@ -0,0 +1,44 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from sagemaker.experiments.trial import _Trial
+from sagemaker.experiments.trial_component import _TrialComponent
+
+
+TEST_EXP_NAME = "my-experiment"
+TEST_RUN_NAME = "my-run"
+
+
+def mock_tc_load_or_create_func(
+ trial_component_name, display_name=None, tags=None, sagemaker_session=None
+):
+ tc = _TrialComponent(
+ trial_component_name=trial_component_name,
+ display_name=display_name,
+ tags=tags,
+ sagemaker_session=sagemaker_session,
+ )
+ return tc, True
+
+
+def mock_trial_load_or_create_func(
+ experiment_name, trial_name, display_name=None, tags=None, sagemaker_session=None
+):
+ return _Trial(
+ trial_name=trial_name,
+ experiment_name=experiment_name,
+ display_name=display_name,
+ tags=tags,
+ sagemaker_session=sagemaker_session,
+ )
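
Note: these helpers stand in for the `_load_or_create` classmethods so unit
tests never touch the API; each returns a freshly constructed object (the
trial-component helper also returns a boolean flag, matching the tuple the
real method yields). A sketch of how they are wired in, as the conftest above
does:

    from unittest.mock import patch, MagicMock

    with patch(
        "sagemaker.experiments.run._TrialComponent._load_or_create",
        MagicMock(side_effect=mock_tc_load_or_create_func),
    ):
        ...  # construct Run objects without real API calls
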
diff --git a/tests/unit/sagemaker/experiments/test_environment.py b/tests/unit/sagemaker/experiments/test_environment.py
new file mode 100644
index 0000000000..8bb23db7b6
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_environment.py
@@ -0,0 +1,107 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import json
+import os
+import shutil
+import tempfile
+import unittest.mock
+
+import pytest
+
+from sagemaker.experiments import _environment
+from sagemaker.utils import retry_with_backoff
+
+
+@pytest.fixture
+def tempdir():
+    tmp_dir = tempfile.mkdtemp()
+    yield tmp_dir
+    shutil.rmtree(tmp_dir)
+
+
+@pytest.fixture
+def training_job_env():
+ old_value = os.environ.get("TRAINING_JOB_ARN")
+ os.environ["TRAINING_JOB_ARN"] = "arn:1234aBcDe"
+ yield os.environ
+ del os.environ["TRAINING_JOB_ARN"]
+ if old_value:
+ os.environ["TRAINING_JOB_ARN"] = old_value
+
+
+@pytest.fixture
+def transform_job_env():
+ old_value = os.environ.get("SAGEMAKER_BATCH")
+ os.environ["SAGEMAKER_BATCH"] = "true"
+ yield os.environ
+ del os.environ["SAGEMAKER_BATCH"]
+ if old_value:
+ os.environ["SAGEMAKER_BATCH"] = old_value
+
+
+def test_processing_job_environment(tempdir):
+ config_path = os.path.join(tempdir, "config.json")
+ with open(config_path, "w") as f:
+ f.write(json.dumps({"ProcessingJobArn": "arn:1234aBcDe"}))
+ environment = _environment._RunEnvironment.load(processing_job_config_path=config_path)
+
+ assert _environment._EnvironmentType.SageMakerProcessingJob == environment.environment_type
+ assert "arn:1234aBcDe" == environment.source_arn
+
+
+def test_training_job_environment(training_job_env):
+ environment = _environment._RunEnvironment.load()
+ assert _environment._EnvironmentType.SageMakerTrainingJob == environment.environment_type
+ assert "arn:1234aBcDe" == environment.source_arn
+
+
+def test_transform_job_environment(transform_job_env):
+ environment = _environment._RunEnvironment.load()
+ assert _environment._EnvironmentType.SageMakerTransformJob == environment.environment_type
+ # TODO: update if we figure out how to get source_arn from the transform job
+ assert not environment.source_arn
+
+
+def test_no_environment():
+ assert _environment._RunEnvironment.load() is None
+
+
+def test_resolve_trial_component(training_job_env, sagemaker_session):
+ trial_component_name = "foo-bar"
+ client = sagemaker_session.sagemaker_client
+ client.list_trial_components.return_value = {
+ "TrialComponentSummaries": [{"TrialComponentName": trial_component_name}]
+ }
+ client.describe_trial_component.return_value = {"TrialComponentName": trial_component_name}
+ environment = _environment._RunEnvironment.load()
+ tc = environment.get_trial_component(sagemaker_session)
+
+ assert trial_component_name == tc.trial_component_name
+ client.describe_trial_component.assert_called_with(TrialComponentName=trial_component_name)
+ client.list_trial_components.assert_called_with(SourceArn="arn:1234abcde")
+
+
+@unittest.mock.patch("sagemaker.experiments._environment.retry_with_backoff")
+def test_resolve_trial_component_fails(mock_retry, sagemaker_session, training_job_env):
+ mock_retry.side_effect = lambda func: retry_with_backoff(func, 2)
+ client = sagemaker_session.sagemaker_client
+ client.list_trial_components.side_effect = Exception("Failed test")
+ environment = _environment._RunEnvironment.load()
+ assert environment.get_trial_component(sagemaker_session) is None
+
+
+def test_resolve_transform_job_trial_component_fail(transform_job_env, sagemaker_session):
+ environment = _environment._RunEnvironment.load()
+ assert environment.get_trial_component(sagemaker_session) is None
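
Note: `_RunEnvironment.load()` detects the job type purely from the process
environment, which is exactly what the fixtures above simulate:
`TRAINING_JOB_ARN` marks a training job, `SAGEMAKER_BATCH` a transform job,
and a processing-job config file the processing case. A minimal sketch of the
training-job path:

    import os
    from sagemaker.experiments import _environment

    os.environ["TRAINING_JOB_ARN"] = "arn:1234aBcDe"  # simulate a training job
    env = _environment._RunEnvironment.load()
    assert env.environment_type == _environment._EnvironmentType.SageMakerTrainingJob
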
diff --git a/tests/unit/sagemaker/experiments/test_experiment.py b/tests/unit/sagemaker/experiments/test_experiment.py
new file mode 100644
index 0000000000..b0ad55c27f
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_experiment.py
@@ -0,0 +1,306 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+import unittest.mock
+import datetime
+
+from unittest.mock import patch
+
+from sagemaker import Session
+from sagemaker.experiments import experiment
+from sagemaker.experiments._api_types import TrialSummary
+
+
+@pytest.fixture
+def datetime_obj():
+ return datetime.datetime(2017, 6, 16, 15, 55, 0)
+
+
+def test_load(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ client.describe_experiment.return_value = {"Description": "description-value"}
+ experiment_obj = experiment._Experiment.load(
+ experiment_name="name-value", sagemaker_session=sagemaker_session
+ )
+ assert experiment_obj.experiment_name == "name-value"
+ assert experiment_obj.description == "description-value"
+
+ client.describe_experiment.assert_called_with(ExperimentName="name-value")
+
+
+def test_create(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
+ experiment_obj = experiment._Experiment.create(
+ experiment_name="name-value", sagemaker_session=sagemaker_session
+ )
+ assert experiment_obj.experiment_name == "name-value"
+ client.create_experiment.assert_called_with(ExperimentName="name-value")
+
+
+def test_create_with_tags(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ client.create_experiment.return_value = {"Arn": "arn:aws:1234"}
+ tags = [{"Key": "foo", "Value": "bar"}]
+ experiment_obj = experiment._Experiment.create(
+ experiment_name="name-value", sagemaker_session=sagemaker_session, tags=tags
+ )
+ assert experiment_obj.experiment_name == "name-value"
+ client.create_experiment.assert_called_with(ExperimentName="name-value", Tags=tags)
+
+
+def test_save(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+ client.update_experiment.return_value = {}
+ obj.save()
+ client.update_experiment.assert_called_with(ExperimentName="foo", Description="bar")
+
+
+def test_delete(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+ client.delete_experiment.return_value = {}
+ obj.delete()
+ client.delete_experiment.assert_called_with(ExperimentName="foo")
+
+
+@patch("sagemaker.experiments.experiment._Experiment.load")
+def test_load_or_create_when_exist(mock_load, sagemaker_session):
+ exp_name = "exp_name"
+ experiment._Experiment._load_or_create(
+ experiment_name=exp_name, sagemaker_session=sagemaker_session
+ )
+ mock_load.assert_called_once_with(exp_name, sagemaker_session)
+
+
+@patch("sagemaker.experiments.experiment._Experiment.load")
+@patch("sagemaker.experiments.experiment._Experiment.create")
+def test_load_or_create_when_not_exist(mock_create, mock_load):
+ sagemaker_session = Session()
+ client = sagemaker_session.sagemaker_client
+ exp_name = "exp_name"
+ not_found_err = client.exceptions.ResourceNotFound(
+ error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}},
+ operation_name="foo",
+ )
+ mock_load.side_effect = not_found_err
+
+ experiment._Experiment._load_or_create(
+ experiment_name=exp_name, sagemaker_session=sagemaker_session
+ )
+
+ mock_create.assert_called_once_with(
+ experiment_name=exp_name,
+ display_name=None,
+ description=None,
+ tags=None,
+ sagemaker_session=sagemaker_session,
+ )
+
+
+def test_list_trials_empty(sagemaker_session):
+ sagemaker_session.sagemaker_client.list_trials.return_value = {"TrialSummaries": []}
+ experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
+ assert list(experiment_obj.list_trials()) == []
+
+
+def test_list_trials_single(sagemaker_session, datetime_obj):
+ experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
+ sagemaker_session.sagemaker_client.list_trials.return_value = {
+ "TrialSummaries": [
+ {"Name": "trial-foo", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj}
+ ]
+ }
+
+ assert list(experiment_obj.list_trials()) == [
+ TrialSummary(name="trial-foo", creation_time=datetime_obj, last_modified_time=datetime_obj)
+ ]
+
+
+def test_list_trials_two_values(sagemaker_session, datetime_obj):
+ experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
+ sagemaker_session.sagemaker_client.list_trials.return_value = {
+ "TrialSummaries": [
+ {"Name": "trial-foo-1", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
+ {"Name": "trial-foo-2", "CreationTime": datetime_obj, "LastModifiedTime": datetime_obj},
+ ]
+ }
+
+ assert list(experiment_obj.list_trials()) == [
+ TrialSummary(
+ name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj
+ ),
+ TrialSummary(
+ name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj
+ ),
+ ]
+
+
+def test_next_token(sagemaker_session, datetime_obj):
+ experiment_obj = experiment._Experiment(sagemaker_session)
+ client = sagemaker_session.sagemaker_client
+ client.list_trials.side_effect = [
+ {
+ "TrialSummaries": [
+ {
+ "Name": "trial-foo-1",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ {
+ "Name": "trial-foo-2",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ ],
+ "NextToken": "foo",
+ },
+ {
+ "TrialSummaries": [
+ {
+ "Name": "trial-foo-3",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ }
+ ]
+ },
+ ]
+
+ assert list(experiment_obj.list_trials()) == [
+ TrialSummary(
+ name="trial-foo-1", creation_time=datetime_obj, last_modified_time=datetime_obj
+ ),
+ TrialSummary(
+ name="trial-foo-2", creation_time=datetime_obj, last_modified_time=datetime_obj
+ ),
+ TrialSummary(
+ name="trial-foo-3", creation_time=datetime_obj, last_modified_time=datetime_obj
+ ),
+ ]
+
+    client.list_trials.assert_any_call()
+ client.list_trials.assert_any_call(NextToken="foo")
+
+
+def test_list_trials_call_args(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
+ created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
+ experiment_obj = experiment._Experiment(sagemaker_session=sagemaker_session)
+ client.list_trials.return_value = {}
+ assert [] == list(
+ experiment_obj.list_trials(created_after=created_after, created_before=created_before)
+ )
+ client.list_trials.assert_called_with(CreatedBefore=created_before, CreatedAfter=created_after)
+
+
+def test_delete_all_with_incorrect_action_name(sagemaker_session):
+ obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+ with pytest.raises(ValueError) as err:
+ obj._delete_all(action="abc")
+
+ assert "Must confirm with string '--force'" in str(err)
+
+
+def test_delete_all(sagemaker_session, datetime_obj):
+ obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+ client = sagemaker_session.sagemaker_client
+ client.list_trials.return_value = {
+ "TrialSummaries": [
+ {
+ "TrialName": "trial-1",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ {
+ "TrialName": "trial-2",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ ]
+ }
+ client.describe_trial.side_effect = [
+ {"Trialname": "trial-1", "ExperimentName": "experiment-name-value"},
+ {"Trialname": "trial-2", "ExperimentName": "experiment-name-value"},
+ ]
+ client.list_trial_components.side_effect = [
+ {
+ "TrialComponentSummaries": [
+ {
+ "TrialComponentName": "trial-component-1",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ {
+ "TrialComponentName": "trial-component-2",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ ]
+ },
+ {
+ "TrialComponentSummaries": [
+ {
+ "TrialComponentName": "trial-component-3",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ {
+ "TrialComponentName": "trial-component-4",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ ]
+ },
+ ]
+
+ client.describe_trial_component.side_effect = [
+ {"TrialComponentName": "trial-component-1"},
+ {"TrialComponentName": "trial-component-2"},
+ {"TrialComponentName": "trial-component-3"},
+ {"TrialComponentName": "trial-component-4"},
+ ]
+
+ client.delete_trial_component.return_value = {}
+ client.delete_trial.return_value = {}
+ client.delete_experiment.return_value = {}
+
+ obj._delete_all(action="--force")
+
+ client.delete_experiment.assert_called_with(ExperimentName="foo")
+
+ delete_trial_expected_calls = [
+ unittest.mock.call(TrialName="trial-1"),
+ unittest.mock.call(TrialName="trial-2"),
+ ]
+ assert delete_trial_expected_calls == client.delete_trial.mock_calls
+
+ delete_trial_component_expected_calls = [
+ unittest.mock.call(TrialComponentName="trial-component-1"),
+ unittest.mock.call(TrialComponentName="trial-component-2"),
+ unittest.mock.call(TrialComponentName="trial-component-3"),
+ unittest.mock.call(TrialComponentName="trial-component-4"),
+ ]
+ assert delete_trial_component_expected_calls == client.delete_trial_component.mock_calls
+
+
+def test_delete_all_fail(sagemaker_session):
+ obj = experiment._Experiment(sagemaker_session, experiment_name="foo", description="bar")
+ sagemaker_session.sagemaker_client.list_trials.side_effect = Exception
+ with pytest.raises(Exception) as e:
+ obj._delete_all(action="--force")
+
+ assert str(e.value) == "Failed to delete, please try again."
diff --git a/tests/unit/sagemaker/experiments/test_helper.py b/tests/unit/sagemaker/experiments/test_helper.py
new file mode 100644
index 0000000000..a11f67389b
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_helper.py
@@ -0,0 +1,195 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import json
+import os
+import shutil
+import tempfile
+
+from mock import Mock, PropertyMock, call
+import pytest
+
+from sagemaker.experiments._helper import (
+    _LineageArtifactTracker,
+    _ArtifactUploader,
+)
+from sagemaker.experiments._utils import resolve_artifact_name
+from sagemaker.session import Session
+
+
+@pytest.fixture
+def client():
+ """Mock client.
+
+ Considerations when appropriate:
+
+ * utilize botocore.stub.Stubber
+ * separate runtime client from client
+ """
+ client_mock = Mock()
+ client_mock._client_config.user_agent = (
+ "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
+ )
+ return client_mock
+
+
+@pytest.fixture
+def boto_session(client):
+ role_mock = Mock()
+ type(role_mock).arn = PropertyMock(return_value="DummyRole")
+
+ resource_mock = Mock()
+ resource_mock.Role.return_value = role_mock
+
+ session_mock = Mock(region_name="us-west-2")
+ session_mock.resource.return_value = resource_mock
+ session_mock.client.return_value = client
+
+ return session_mock
+
+
+@pytest.fixture
+def sagemaker_session(client, boto_session):
+ return Session(
+ sagemaker_client=client,
+ boto_session=boto_session,
+ )
+
+
+@pytest.fixture
+def lineage_artifact_tracker(sagemaker_session):
+ return _LineageArtifactTracker("test_trial_component_arn", sagemaker_session)
+
+
+def test_lineage_artifact_tracker(lineage_artifact_tracker, sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ lineage_artifact_tracker.add_input_artifact(
+ "input_name", "input_source_uri", "input_etag", "text/plain"
+ )
+ lineage_artifact_tracker.add_output_artifact(
+ "output_name", "output_source_uri", "output_etag", "text/plain"
+ )
+ client.create_artifact.side_effect = [
+ {"ArtifactArn": "created_arn_1"},
+ {"ArtifactArn": "created_arn_2"},
+ ]
+
+ lineage_artifact_tracker.save()
+
+ expected_calls = [
+ call(
+ ArtifactName="input_name",
+ ArtifactType="text/plain",
+ Source={
+ "SourceUri": "input_source_uri",
+ "SourceTypes": [{"SourceIdType": "S3ETag", "Value": "input_etag"}],
+ },
+ ),
+ call(
+ ArtifactName="output_name",
+ ArtifactType="text/plain",
+ Source={
+ "SourceUri": "output_source_uri",
+ "SourceTypes": [{"SourceIdType": "S3ETag", "Value": "output_etag"}],
+ },
+ ),
+ ]
+ assert expected_calls == client.create_artifact.mock_calls
+
+ expected_calls = [
+ call(
+ SourceArn="created_arn_1",
+ DestinationArn="test_trial_component_arn",
+ AssociationType="ContributedTo",
+ ),
+ call(
+ SourceArn="test_trial_component_arn",
+ DestinationArn="created_arn_2",
+ AssociationType="Produced",
+ ),
+ ]
+ assert expected_calls == client.add_association.mock_calls
+
+
+@pytest.fixture
+def artifact_uploader(sagemaker_session):
+ return _ArtifactUploader(
+ trial_component_name="trial_component_name",
+ artifact_bucket="artifact_bucket",
+ artifact_prefix="artifact_prefix",
+ sagemaker_session=sagemaker_session,
+ )
+
+
+@pytest.fixture
+def tempdir():
+ tmp_dir = tempfile.mkdtemp()
+ yield tmp_dir
+ shutil.rmtree(tmp_dir)
+
+
+def test_artifact_uploader_init(artifact_uploader):
+ assert "trial_component_name" == artifact_uploader.trial_component_name
+ assert "artifact_bucket" == artifact_uploader.artifact_bucket
+ assert "artifact_prefix" == artifact_uploader.artifact_prefix
+
+
+def test_artifact_uploader_upload_artifact_file_not_exists(tempdir, artifact_uploader):
+ not_exist_file = os.path.join(tempdir, "not.exists")
+ with pytest.raises(ValueError) as error:
+ artifact_uploader.upload_artifact(not_exist_file)
+ assert "does not exist or is not a file" in str(error)
+
+
+def test_artifact_uploader_upload_artifact(tempdir, artifact_uploader):
+ path = os.path.join(tempdir, "exists")
+ with open(path, "a") as f:
+ f.write("boo")
+
+ name = resolve_artifact_name(path)
+ artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"}
+
+ s3_uri, etag = artifact_uploader.upload_artifact(path)
+ expected_key = "{}/{}/{}".format(
+ artifact_uploader.artifact_prefix, artifact_uploader.trial_component_name, name
+ )
+
+ artifact_uploader._s3_client.upload_file.assert_called_with(
+ path, artifact_uploader.artifact_bucket, expected_key
+ )
+
+ expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key)
+ assert expected_uri == s3_uri
+
+
+def test_artifact_uploader_upload_object_artifact(tempdir, artifact_uploader):
+ artifact_uploader._s3_client.head_object.return_value = {"ETag": "etag_value"}
+
+ artifact_name = "my-artifact"
+ artifact_object = {"key": "value"}
+ file_extension = ".csv"
+ s3_uri, etag = artifact_uploader.upload_object_artifact(
+ artifact_name, artifact_object, file_extension
+ )
+ name = artifact_name + file_extension
+ expected_key = "{}/{}/{}".format(
+ artifact_uploader.artifact_prefix, artifact_uploader.trial_component_name, name
+ )
+
+ artifact_uploader._s3_client.put_object.assert_called_with(
+ Body=json.dumps(artifact_object), Bucket=artifact_uploader.artifact_bucket, Key=expected_key
+ )
+
+ expected_uri = "s3://{}/{}".format(artifact_uploader.artifact_bucket, expected_key)
+ assert expected_uri == s3_uri
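
Note: both upload paths asserted above share one key layout under the artifact
bucket, with `upload_artifact` deriving the name from the file path and
`upload_object_artifact` appending the file extension. The layout, with
`uploader` standing in for the `_ArtifactUploader` fixture:

    # s3://{artifact_bucket}/{artifact_prefix}/{trial_component_name}/{name}
    expected_key = "{}/{}/{}".format(
        uploader.artifact_prefix, uploader.trial_component_name, name
    )
    expected_uri = "s3://{}/{}".format(uploader.artifact_bucket, expected_key)
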
diff --git a/tests/unit/sagemaker/experiments/test_metrics.py b/tests/unit/sagemaker/experiments/test_metrics.py
new file mode 100644
index 0000000000..21556f70fd
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_metrics.py
@@ -0,0 +1,178 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import os
+import pytest
+import tempfile
+import shutil
+import datetime
+import dateutil
+import json
+import time
+
+from sagemaker.experiments._metrics import (
+ _RawMetricData,
+ _SageMakerFileMetricsWriter,
+ SageMakerMetricsWriterException,
+)
+
+
+@pytest.fixture
+def tempdir():
+    tmp_dir = tempfile.mkdtemp()
+    yield tmp_dir
+    shutil.rmtree(tmp_dir)
+
+
+@pytest.fixture
+def filepath(tempdir):
+ return os.path.join(tempdir, "foo.json")
+
+
+@pytest.fixture
+def timestamp():
+ return datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(hours=1)
+
+
+def test_raw_metric_data_utc_timestamp():
+ utcnow = datetime.datetime.now(datetime.timezone.utc)
+ assert utcnow.tzinfo
+ metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=utcnow)
+ assert utcnow.timestamp() == metric.Timestamp
+
+
+def test_raw_metric_data_aware_timestamp():
+ aware_datetime = datetime.datetime.now(dateutil.tz.gettz("America/Chicago"))
+ assert aware_datetime.tzinfo
+ metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=aware_datetime)
+ assert (aware_datetime - aware_datetime.utcoffset()).replace(
+ tzinfo=datetime.timezone.utc
+ ).timestamp() == metric.Timestamp
+
+
+def test_raw_metric_data_naive_timestamp():
+ naive_datetime = datetime.datetime.now()
+ assert naive_datetime.tzinfo is None
+ metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=naive_datetime)
+ local_datetime = naive_datetime.replace(tzinfo=dateutil.tz.tzlocal())
+ assert (local_datetime - local_datetime.utcoffset()).replace(
+ tzinfo=datetime.timezone.utc
+ ).timestamp() == metric.Timestamp
+
+
+def test_raw_metric_data_number_timestamp():
+ time_now = time.time()
+ metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=time_now)
+ assert time_now == metric.Timestamp
+
+
+def test_raw_metric_data_request_item():
+ time_now = time.time()
+ metric = _RawMetricData(metric_name="foo", value=1.0, timestamp=time_now, step=10)
+ expected = {
+ "MetricName": "foo",
+ "Value": 1.0,
+ "Timestamp": str(int(time_now)),
+ "Step": 10,
+ }
+ assert expected == metric.to_raw_metric_data()
+
+
+def test_raw_metric_data_invalid_timestamp():
+ with pytest.raises(ValueError) as error1:
+ _RawMetricData(metric_name="IFail", value=100, timestamp=time.time() - 2000000)
+ assert "Timestamps must be between two weeks before and two hours from now" in str(error1)
+
+ with pytest.raises(ValueError) as error2:
+ _RawMetricData(metric_name="IFail", value=100, timestamp=time.time() + 10000)
+ assert "Timestamps must be between two weeks before and two hours from now" in str(error2)
+
+
+def test_file_metrics_writer_log_metric(timestamp, filepath):
+ now = datetime.datetime.now(datetime.timezone.utc)
+ writer = _SageMakerFileMetricsWriter(filepath)
+ writer.log_metric(metric_name="foo", value=1.0)
+ writer.log_metric(metric_name="foo", value=2.0, step=1)
+ writer.log_metric(metric_name="foo", value=3.0, timestamp=timestamp)
+ writer.log_metric(metric_name="foo", value=4.0, timestamp=timestamp, step=2)
+ writer.close()
+
+ lines = [x for x in open(filepath).read().split("\n") if x]
+ [entry_one, entry_two, entry_three, entry_four] = [json.loads(line) for line in lines]
+
+ assert "foo" == entry_one["MetricName"]
+ assert 1.0 == entry_one["Value"]
+ assert (now.timestamp() - entry_one["Timestamp"]) < 1
+ assert "Step" not in entry_one
+
+ assert 1 == entry_two["Step"]
+ assert timestamp.timestamp() == entry_three["Timestamp"]
+ assert 2 == entry_four["Step"]
+
+
+def test_file_metrics_writer_flushes_buffer_every_line_log_metric(filepath):
+ writer = _SageMakerFileMetricsWriter(filepath)
+
+ writer.log_metric(metric_name="foo", value=1.0)
+
+ lines = [x for x in open(filepath).read().split("\n") if x]
+ [entry_one] = [json.loads(line) for line in lines]
+ assert "foo" == entry_one["MetricName"]
+ assert 1.0 == entry_one["Value"]
+
+ writer.log_metric(metric_name="bar", value=2.0)
+ lines = [x for x in open(filepath).read().split("\n") if x]
+ [entry_one, entry_two] = [json.loads(line) for line in lines]
+ assert "bar" == entry_two["MetricName"]
+ assert 2.0 == entry_two["Value"]
+
+ writer.log_metric(metric_name="biz", value=3.0)
+ lines = [x for x in open(filepath).read().split("\n") if x]
+ [entry_one, entry_two, entry_three] = [json.loads(line) for line in lines]
+ assert "biz" == entry_three["MetricName"]
+ assert 3.0 == entry_three["Value"]
+
+ writer.close()
+
+
+def test_file_metrics_writer_context_manager(timestamp, filepath):
+ with _SageMakerFileMetricsWriter(filepath) as writer:
+ writer.log_metric("foo", value=1.0, timestamp=timestamp)
+ entry = json.loads(open(filepath, "r").read().strip())
+ assert {
+ "MetricName": "foo",
+ "Value": 1.0,
+ "Timestamp": timestamp.timestamp(),
+ }.items() <= entry.items()
+
+
+def test_file_metrics_writer_fail_write_on_close(filepath):
+ writer = _SageMakerFileMetricsWriter(filepath)
+ writer.log_metric(metric_name="foo", value=1.0)
+ writer.close()
+ with pytest.raises(SageMakerMetricsWriterException):
+ writer.log_metric(metric_name="foo", value=1.0)
+
+
+def test_file_metrics_writer_no_write(filepath):
+ writer = _SageMakerFileMetricsWriter(filepath)
+ writer.close()
+ assert not os.path.exists(filepath)
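
Note: the timestamp tests above pin down one normalization rule: naive
datetimes are read as local time, aware datetimes keep their offset, and both
reduce to a UTC epoch float. The arithmetic the asserts rely on, given a naive
`naive_dt` (illustrative):

    local_dt = naive_dt.replace(tzinfo=dateutil.tz.tzlocal())
    epoch = (local_dt - local_dt.utcoffset()).replace(
        tzinfo=datetime.timezone.utc
    ).timestamp()
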
diff --git a/tests/unit/sagemaker/experiments/test_run.py b/tests/unit/sagemaker/experiments/test_run.py
new file mode 100644
index 0000000000..0e4ebee181
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_run.py
@@ -0,0 +1,941 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import datetime
+import unittest
+from math import inf, nan
+from unittest.mock import patch, Mock, MagicMock
+
+import dateutil
+import pytest
+
+from sagemaker.experiments import _environment, SortOrderType
+from sagemaker.experiments._api_types import (
+ TrialComponentArtifact,
+ TrialComponentSummary,
+ TrialComponentStatus,
+ _TrialComponentStatusType,
+ TrialComponentSearchResult,
+)
+from sagemaker.experiments.experiment import _Experiment
+from sagemaker.experiments.run import (
+ TRIAL_NAME_TEMPLATE,
+ MAX_RUN_TC_ARTIFACTS_LEN,
+ MAX_NAME_LEN_IN_BACKEND,
+ EXPERIMENT_NAME,
+ RUN_NAME,
+ TRIAL_NAME,
+ DELIMITER,
+ RUN_TC_TAG,
+ SortByType,
+)
+from sagemaker.experiments import Run, load_run, list_runs
+from sagemaker.experiments.trial import _Trial
+from sagemaker.experiments.trial_component import _TrialComponent
+from tests.unit.sagemaker.experiments.helpers import (
+ mock_trial_load_or_create_func,
+ mock_tc_load_or_create_func,
+ TEST_EXP_NAME,
+ TEST_RUN_NAME,
+)
+
+
+@patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+ "sagemaker.experiments.run._Trial._load_or_create",
+ MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._TrialComponent._load_or_create",
+ MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch.object(_TrialComponent, "save")
+def test_run_init(mock_tc_save, sagemaker_session):
+ with Run(
+ experiment_name=TEST_EXP_NAME, run_name=TEST_RUN_NAME, sagemaker_session=sagemaker_session
+ ) as run_obj:
+ assert not run_obj._in_load
+ assert not run_obj._inside_load_context
+ assert run_obj._inside_init_context
+ assert not run_obj._trial_component.parameters
+
+ expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+ assert run_obj.experiment_name == TEST_EXP_NAME
+ assert run_obj.run_name == TEST_RUN_NAME
+ assert run_obj.run_group_name == TRIAL_NAME_TEMPLATE.format(TEST_EXP_NAME)
+ assert run_obj._trial_component.trial_component_name == expected_tc_name
+ assert run_obj._trial.trial_name == TRIAL_NAME_TEMPLATE.format(TEST_EXP_NAME)
+ assert run_obj._experiment.experiment_name == TEST_EXP_NAME
+ assert run_obj.experiment_config == {
+ EXPERIMENT_NAME: TEST_EXP_NAME,
+ TRIAL_NAME: run_obj.run_group_name,
+ RUN_NAME: expected_tc_name,
+ }
+
+    # trial_component.save is called when entering/exiting the with block
+ mock_tc_save.assert_called()
+
+
+def test_run_init_name_length_exceed_limit(sagemaker_session):
+ invalid_name = "x" * MAX_NAME_LEN_IN_BACKEND
+
+ # experiment_name exceeds
+ with pytest.raises(ValueError) as err:
+ Run(
+ experiment_name=invalid_name,
+ run_name=TEST_RUN_NAME,
+ sagemaker_session=sagemaker_session,
+ )
+
+ assert (
+ f"The experiment_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than"
+ in str(err)
+ )
+
+ # run_name exceeds
+ with pytest.raises(ValueError) as err:
+ Run(
+ experiment_name=TEST_EXP_NAME,
+ run_name=invalid_name,
+ sagemaker_session=sagemaker_session,
+ )
+
+ assert f"The run_name (length: {MAX_NAME_LEN_IN_BACKEND}) must have length less than" in str(
+ err
+ )
+
+
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+ "sagemaker.experiments.run._Trial._load_or_create",
+ MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._TrialComponent._load_or_create",
+ MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_no_run_name_and_in_train_job(mock_run_env, sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ job_name = "my-train-job"
+ rv = Mock()
+ rv.source_arn = f"arn:1234/{job_name}"
+ rv.environment_type = _environment._EnvironmentType.SageMakerTrainingJob
+ mock_run_env.load.return_value = rv
+
+ expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+ exp_config = {
+ EXPERIMENT_NAME: TEST_EXP_NAME,
+ TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+ RUN_NAME: expected_tc_name,
+ }
+ client.describe_training_job.return_value = {
+ "TrainingJobName": "train-job-experiments",
+        # The Run object has been created elsewhere
+ "ExperimentConfig": exp_config,
+ }
+ with load_run(sagemaker_session=sagemaker_session) as run_obj:
+ assert run_obj._in_load
+ assert not run_obj._inside_init_context
+ assert run_obj._inside_load_context
+ assert run_obj.run_name == TEST_RUN_NAME
+ assert run_obj._trial_component.trial_component_name == expected_tc_name
+ assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
+ assert run_obj._trial
+ assert run_obj.experiment_name == TEST_EXP_NAME
+ assert run_obj._experiment
+ assert run_obj.experiment_config == exp_config
+
+ client.describe_training_job.assert_called_once_with(TrainingJobName=job_name)
+
+
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_no_run_name_and_in_train_job_but_fail_to_get_exp_cfg(
+ mock_run_env, sagemaker_session
+):
+ rv = Mock()
+ rv.source_arn = "arn:1234/my-train-job"
+ rv.environment_type = _environment._EnvironmentType.SageMakerTrainingJob
+ mock_run_env.load.return_value = rv
+
+    # No Run object is created elsewhere
+ sagemaker_session.sagemaker_client.describe_training_job.return_value = {
+ "TrainingJobName": "train-job-experiments",
+ }
+
+ with pytest.raises(RuntimeError) as err:
+ with load_run(sagemaker_session=sagemaker_session):
+ pass
+
+ assert "Not able to fetch RunName in ExperimentConfig of the sagemaker job" in str(err)
+
+
+def test_run_load_no_run_name_and_not_in_train_job(run_obj, sagemaker_session):
+ with run_obj:
+ with load_run(sagemaker_session=sagemaker_session) as run:
+ assert run_obj == run
+
+
+def test_run_load_no_run_name_and_not_in_train_job_but_no_obj_in_context(sagemaker_session):
+ with pytest.raises(RuntimeError) as err:
+ with load_run(sagemaker_session=sagemaker_session):
+ pass
+
+ assert "Failed to load a Run object" in str(err)
+
+    # experiment_name is given but run_name is not, so experiment_name is ignored.
+ with pytest.raises(RuntimeError) as err:
+ with load_run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session):
+ pass
+
+ assert "Failed to load a Run object" in str(err)
+
+
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+ "sagemaker.experiments.run._Trial._load_or_create",
+ MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._TrialComponent._load_or_create",
+ MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+def test_run_load_with_run_name_and_exp_name(sagemaker_session):
+ with load_run(
+ run_name=TEST_RUN_NAME,
+ experiment_name=TEST_EXP_NAME,
+ sagemaker_session=sagemaker_session,
+ ) as run_obj:
+ expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+ expected_exp_config = {
+ EXPERIMENT_NAME: TEST_EXP_NAME,
+ TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+ RUN_NAME: expected_tc_name,
+ }
+
+ assert run_obj.run_name == TEST_RUN_NAME
+ assert run_obj.run_group_name == Run._generate_trial_name(TEST_EXP_NAME)
+ assert run_obj.experiment_name == TEST_EXP_NAME
+ assert run_obj._trial_component.trial_component_name == expected_tc_name
+ assert run_obj._trial
+ assert run_obj._experiment
+ assert run_obj.experiment_config == expected_exp_config
+
+
+def test_run_load_with_run_name_but_no_exp_name(sagemaker_session):
+ with pytest.raises(ValueError) as err:
+ with load_run(
+ run_name=TEST_RUN_NAME,
+ sagemaker_session=sagemaker_session,
+ ):
+ pass
+
+ assert "Invalid input: experiment_name is missing" in str(err)
+
+
+@patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+ "sagemaker.experiments.run._Trial._load_or_create",
+ MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._TrialComponent._load_or_create",
+ MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+@patch.object(_TrialComponent, "save", MagicMock(return_value=None))
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_in_sm_processing_job(mock_run_env, sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ job_name = "my-process-job"
+ rv = unittest.mock.Mock()
+ rv.source_arn = f"arn:1234/{job_name}"
+ rv.environment_type = _environment._EnvironmentType.SageMakerProcessingJob
+ mock_run_env.load.return_value = rv
+
+ expected_tc_name = f"{TEST_EXP_NAME}{DELIMITER}{TEST_RUN_NAME}"
+ exp_config = {
+ EXPERIMENT_NAME: TEST_EXP_NAME,
+ TRIAL_NAME: Run._generate_trial_name(TEST_EXP_NAME),
+ RUN_NAME: expected_tc_name,
+ }
+ client.describe_processing_job.return_value = {
+ "ProcessingJobName": "process-job-experiments",
+        # The Run object has been created elsewhere
+ "ExperimentConfig": exp_config,
+ }
+
+ with load_run(sagemaker_session=sagemaker_session):
+ pass
+
+ client.describe_processing_job.assert_called_once_with(ProcessingJobName=job_name)
+
+
+@patch("sagemaker.experiments.run._RunEnvironment")
+def test_run_load_in_sm_transform_job(mock_run_env, sagemaker_session):
+    # TODO: update this test once we figure out how to get source_arn from the transform job
+ rv = unittest.mock.Mock()
+ rv.environment_type = _environment._EnvironmentType.SageMakerTransformJob
+ rv.source_arn = ""
+ mock_run_env.load.return_value = rv
+
+ with pytest.raises(RuntimeError) as err:
+ with load_run(sagemaker_session=sagemaker_session):
+ pass
+
+ assert (
+ "loading experiment config from transform job environment is not currently supported"
+ ) in str(err)
+
+
+def test_log_parameter_outside_run_context(run_obj):
+ with pytest.raises(RuntimeError) as err:
+ run_obj.log_parameter("foo", "bar")
+ assert "This method should be called inside context of 'with' statement" in str(err)
+
+
+def test_log_parameter(run_obj):
+ with run_obj:
+ run_obj.log_parameter("foo", "bar")
+ assert run_obj._trial_component.parameters["foo"] == "bar"
+ run_obj.log_parameter("whizz", 1)
+ assert run_obj._trial_component.parameters["whizz"] == 1
+
+
+def test_log_parameter_skip_invalid_value(run_obj):
+ with run_obj:
+ run_obj.log_parameter("key", nan)
+ assert "key" not in run_obj._trial_component.parameters
+
+
+def test_log_parameters_outside_run_context(run_obj):
+ with pytest.raises(RuntimeError) as err:
+ run_obj.log_parameters({"a": "b", "c": "d", "e": 5})
+ assert "This method should be called inside context of 'with' statement" in str(err)
+
+
+def test_log_parameters(run_obj):
+ with run_obj:
+ run_obj.log_parameters({"a": "b", "c": "d", "e": 5})
+ assert run_obj._trial_component.parameters == {"a": "b", "c": "d", "e": 5}
+
+
+def test_log_parameters_skip_invalid_values(run_obj):
+ with run_obj:
+ run_obj.log_parameters({"a": "b", "c": "d", "e": 5, "f": nan})
+ assert run_obj._trial_component.parameters == {"a": "b", "c": "d", "e": 5}
+
+
+def test_log_input_outside_run_context(run_obj):
+ with pytest.raises(RuntimeError) as err:
+ run_obj.log_artifact("foo", "baz", "text/text", False)
+ assert "This method should be called inside context of 'with' statement" in str(err)
+
+
+def test_log_input(run_obj):
+ with run_obj:
+ run_obj.log_artifact("foo", "baz", "text/text", False)
+ assert run_obj._trial_component.input_artifacts == {
+ "foo": TrialComponentArtifact(value="baz", media_type="text/text")
+ }
+
+
+def test_log_output_outside_run_context(run_obj):
+ with pytest.raises(RuntimeError) as err:
+ run_obj.log_artifact("foo", "baz", "text/text")
+ assert "This method should be called inside context of 'with' statement" in str(err)
+
+
+def test_log_output(run_obj):
+ with run_obj:
+ run_obj.log_artifact("foo", "baz", "text/text")
+ assert run_obj._trial_component.output_artifacts == {
+ "foo": TrialComponentArtifact(value="baz", media_type="text/text")
+ }
+
+
+def test_log_metric_outside_run_context(run_obj):
+ with pytest.raises(RuntimeError) as err:
+ run_obj.log_metric(name="foo", value=1.0, step=1)
+ assert "This method should be called inside context of 'with' statement" in str(err)
+
+
+def test_log_metric(run_obj):
+ now = datetime.datetime.now()
+ with run_obj:
+ run_obj.log_metric(name="foo", value=1.0, step=1, timestamp=now)
+ run_obj._metrics_manager.log_metric.assert_called_with(
+ metric_name="foo", value=1.0, step=1, timestamp=now
+ )
+
+
+def test_log_metric_skip_invalid_value(run_obj):
+ with run_obj:
+ run_obj.log_metric(None, nan, None, None)
+ assert not run_obj._metrics_manager.log_metric.called
+
+
+def test_log_metric_attribute_error(run_obj):
+ now = datetime.datetime.now()
+ with run_obj:
+ run_obj._metrics_manager.log_metric.side_effect = AttributeError
+
+ with pytest.raises(AttributeError):
+ run_obj.log_metric("foo", 1.0, 1, now)
+
+
+def test_log_output_artifact_outside_run_context(run_obj):
+ with pytest.raises(RuntimeError) as err:
+ run_obj.log_file("foo.txt", "name", "whizz/bang")
+ assert "This method should be called inside context of 'with' statement" in str(err)
+
+
+def test_log_output_artifact(run_obj):
+ run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
+ with run_obj:
+ run_obj.log_file("foo.txt", "name", "whizz/bang")
+ run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
+ assert "whizz/bang" == run_obj._trial_component.output_artifacts["name"].media_type
+
+ run_obj.log_file("foo.txt")
+ run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
+ assert "foo.txt" in run_obj._trial_component.output_artifacts
+ assert "text/plain" == run_obj._trial_component.output_artifacts["foo.txt"].media_type
+
+
+def test_log_input_artifact_outside_run_context(run_obj):
+ with pytest.raises(RuntimeError) as err:
+ run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False)
+ assert "This method should be called inside context of 'with' statement" in str(err)
+
+
+def test_log_input_artifact(run_obj):
+ run_obj._artifact_uploader.upload_artifact.return_value = ("s3uri_value", "etag_value")
+ with run_obj:
+ run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False)
+ run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
+ assert "whizz/bang" == run_obj._trial_component.input_artifacts["name"].media_type
+
+ run_obj.log_file("foo.txt", is_output=False)
+ run_obj._artifact_uploader.upload_artifact.assert_called_with("foo.txt")
+ assert "foo.txt" in run_obj._trial_component.input_artifacts
+ assert "text/plain" == run_obj._trial_component.input_artifacts["foo.txt"].media_type
+
+
+def test_log_multiple_inputs(run_obj):
+ with run_obj:
+ for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN):
+ file_path = "foo" + str(index) + ".txt"
+ run_obj._trial_component.input_artifacts[file_path] = {
+ "foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text")
+ }
+ with pytest.raises(ValueError) as error:
+ run_obj.log_artifact("foo.txt", "name", "whizz/bang", False)
+ assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error)
+
+
+def test_log_multiple_outputs(run_obj):
+ with run_obj:
+ for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN):
+ file_path = "foo" + str(index) + ".txt"
+ run_obj._trial_component.output_artifacts[file_path] = {
+ "foo": TrialComponentArtifact(value="baz" + str(index), media_type="text/text")
+ }
+ with pytest.raises(ValueError) as error:
+ run_obj.log_artifact("foo.txt", "name", "whizz/bang")
+ assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error)
+
+
+def test_log_multiple_input_artifacts(run_obj):
+ with run_obj:
+ for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN):
+ file_path = "foo" + str(index) + ".txt"
+ run_obj._artifact_uploader.upload_artifact.return_value = (
+ "s3uri_value" + str(index),
+ "etag_value" + str(index),
+ )
+ run_obj.log_file(
+ file_path, "name" + str(index), "whizz/bang" + str(index), is_output=False
+ )
+ run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path)
+
+ run_obj._artifact_uploader.upload_artifact.return_value = (
+ "s3uri_value",
+ "etag_value",
+ )
+
+ # log an output artifact, should be fine
+ run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=True)
+
+ # log an extra input artifact, should raise exception
+ with pytest.raises(ValueError) as error:
+ run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False)
+ assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} input_artifacts" in str(error)
+
+
+def test_log_multiple_output_artifacts(run_obj):
+ with run_obj:
+ for index in range(0, MAX_RUN_TC_ARTIFACTS_LEN):
+ file_path = "foo" + str(index) + ".txt"
+ run_obj._artifact_uploader.upload_artifact.return_value = (
+ "s3uri_value" + str(index),
+ "etag_value" + str(index),
+ )
+ run_obj.log_file(file_path, "name" + str(index), "whizz/bang" + str(index))
+ run_obj._artifact_uploader.upload_artifact.assert_called_with(file_path)
+
+ run_obj._artifact_uploader.upload_artifact.return_value = (
+ "s3uri_value",
+ "etag_value",
+ )
+
+ # log an input artifact, should be fine
+ run_obj.log_file("foo.txt", "name", "whizz/bang", is_output=False)
+
+ # log an extra output artifact, should raise exception
+ with pytest.raises(ValueError) as error:
+ run_obj.log_file("foo.txt", "name", "whizz/bang")
+ assert f"Cannot add more than {MAX_RUN_TC_ARTIFACTS_LEN} output_artifacts" in str(error)
+
+
+def test_log_precision_recall_outside_run_context(run_obj):
+ y_true = [0, 0, 1, 1]
+ y_scores = [0.1, 0.4, 0.35, 0.8]
+ no_skill = 0.1
+ title = "TestPrecisionRecall"
+
+ with pytest.raises(RuntimeError) as err:
+ run_obj.log_precision_recall(
+ y_true, y_scores, 0, title=title, no_skill=no_skill, is_output=False
+ )
+ assert "This method should be called inside context of 'with' statement" in str(err)
+
+
+def test_log_precision_recall(run_obj):
+ y_true = [0, 0, 1, 1]
+ y_scores = [0.1, 0.4, 0.35, 0.8]
+ no_skill = 0.1
+ title = "TestPrecisionRecall"
+
+ run_obj._artifact_uploader.upload_object_artifact.return_value = (
+ "s3uri_value",
+ "etag_value",
+ )
+ with run_obj:
+ run_obj.log_precision_recall(
+ y_true, y_scores, 0, title=title, no_skill=no_skill, is_output=False
+ )
+
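+    # Expected curve values correspond to sklearn's precision_recall_curve and
+    # average_precision_score with 0 as the positive label (assuming an sklearn backend).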
+ expected_data = {
+ "type": "PrecisionRecallCurve",
+ "version": 0,
+ "title": title,
+ "precision": [0.5, 0.3333333333333333, 0.5, 0.0, 1.0],
+ "recall": [1.0, 0.5, 0.5, 0.0, 0.0],
+ "averagePrecisionScore": 0.5,
+ "noSkill": 0.1,
+ }
+ run_obj._artifact_uploader.upload_object_artifact.assert_called_with(
+ title, expected_data, file_extension="json"
+ )
+
+ run_obj._lineage_artifact_tracker.add_input_artifact.assert_called_with(
+ name=title,
+ source_uri="s3uri_value",
+ etag="etag_value",
+ artifact_type="PrecisionRecallCurve",
+ )
+
+
+def test_log_precision_recall_invalid_input(run_obj):
+ y_true = [0, 0, 1, 1]
+ y_scores = [0.1, 0.4, 0.35]
+ no_skill = 0.1
+
+ with run_obj:
+ with pytest.raises(ValueError) as error:
+ run_obj.log_precision_recall(
+ y_true, y_scores, 0, title="TestPrecisionRecall", no_skill=no_skill, is_output=False
+ )
+ assert "Lengths mismatch between true labels and predicted probabilities" in str(error)
+
+
+def test_log_confusion_matrix_outside_run_context(run_obj):
+ y_true = [2, 0, 2, 2, 0, 1]
+ y_pred = [0, 0, 2, 2, 0, 2]
+
+ with pytest.raises(RuntimeError) as err:
+ run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix")
+ assert "This method should be called inside context of 'with' statement" in str(err)
+
+
+def test_log_confusion_matrix(run_obj):
+ y_true = [2, 0, 2, 2, 0, 1]
+ y_pred = [0, 0, 2, 2, 0, 2]
+
+ run_obj._artifact_uploader.upload_object_artifact.return_value = (
+ "s3uri_value",
+ "etag_value",
+ )
+ with run_obj:
+ run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix")
+
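+    # The 3x3 matrix equals sklearn.metrics.confusion_matrix(y_true, y_pred):
+    # rows are true labels, columns are predicted labels.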
+ expected_data = {
+ "type": "ConfusionMatrix",
+ "version": 0,
+ "title": "TestConfusionMatrix",
+ "confusionMatrix": [[2, 0, 0], [0, 0, 1], [1, 0, 2]],
+ }
+
+ run_obj._artifact_uploader.upload_object_artifact.assert_called_with(
+ "TestConfusionMatrix", expected_data, file_extension="json"
+ )
+
+ run_obj._lineage_artifact_tracker.add_output_artifact.assert_called_with(
+ name="TestConfusionMatrix",
+ source_uri="s3uri_value",
+ etag="etag_value",
+ artifact_type="ConfusionMatrix",
+ )
+
+
+def test_log_confusion_matrix_invalid_input(run_obj):
+ y_true = [2, 0, 2, 2, 0, 1]
+ y_pred = [0, 0, 2, 2, 0]
+
+ with run_obj:
+ with pytest.raises(ValueError) as error:
+ run_obj.log_confusion_matrix(y_true, y_pred, title="TestConfusionMatrix")
+ assert "Lengths mismatch between true labels and predicted labels" in str(error)
+
+
+def test_log_roc_curve_outside_run_context(run_obj):
+ y_true = [0, 0, 1, 1]
+ y_scores = [0.1, 0.4, 0.35, 0.8]
+
+ with pytest.raises(RuntimeError) as err:
+ run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False)
+ assert "This method should be called inside context of 'with' statement" in str(err)
+
+
+def test_log_roc_curve(run_obj):
+ y_true = [0, 0, 1, 1]
+ y_scores = [0.1, 0.4, 0.35, 0.8]
+ with run_obj:
+ run_obj._artifact_uploader.upload_object_artifact.return_value = (
+ "s3uri_value",
+ "etag_value",
+ )
+
+ run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False)
+
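+        # FPR/TPR equal sklearn.metrics.roc_curve(y_true, y_scores) for these
+        # inputs, and auc(fpr, tpr) = 0.75.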
+ expected_data = {
+ "type": "ROCCurve",
+ "version": 0,
+ "title": "TestROCCurve",
+ "falsePositiveRate": [0.0, 0.0, 0.5, 0.5, 1.0],
+ "truePositiveRate": [0.0, 0.5, 0.5, 1.0, 1.0],
+ "areaUnderCurve": 0.75,
+ }
+ run_obj._artifact_uploader.upload_object_artifact.assert_called_with(
+ "TestROCCurve", expected_data, file_extension="json"
+ )
+
+ run_obj._lineage_artifact_tracker.add_input_artifact.assert_called_with(
+ name="TestROCCurve",
+ source_uri="s3uri_value",
+ etag="etag_value",
+ artifact_type="ROCCurve",
+ )
+
+
+def test_log_roc_curve_invalid_input(run_obj):
+ y_true = [0, 0, 1, 1]
+ y_scores = [0.1, 0.4, 0.35]
+
+ with run_obj:
+ with pytest.raises(ValueError) as error:
+ run_obj.log_roc_curve(y_true, y_scores, title="TestROCCurve", is_output=False)
+ assert "Lengths mismatch between true labels and predicted scores" in str(error)
+
+
+@patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+ "sagemaker.experiments.run._Trial._load_or_create",
+ MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch("sagemaker.experiments.run._TrialComponent._load_or_create")
+@patch("sagemaker.experiments.run._TrialComponent.list")
+@patch("sagemaker.experiments.run._TrialComponent.search")
+def test_list(mock_tc_search, mock_tc_list, mock_tc_load, run_obj, sagemaker_session):
+ start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1)
+ end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2)
+ creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3)
+ last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4)
+ tc_list_len = 20
+ tc_list_len_half = int(tc_list_len / 2)
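+    # Tag only the first half of the trial components as runs; list_runs should
+    # return exactly those.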
+ mock_tc_search.side_effect = [
+ [
+ TrialComponentSearchResult(
+ trial_component_name=Run._generate_trial_component_name(
+ "a" + str(i), TEST_EXP_NAME
+ ),
+ trial_component_arn="b" + str(i),
+ display_name="C" + str(i),
+ creation_time=creation_time + datetime.timedelta(hours=i),
+ last_modified_time=last_modified_time + datetime.timedelta(hours=i),
+ last_modified_by={},
+ tags=[RUN_TC_TAG] if i < tc_list_len_half else None,
+ )
+ ]
+ for i in range(tc_list_len)
+ ]
+ mock_tc_list.return_value = [
+ TrialComponentSummary(
+ trial_component_name=Run._generate_trial_component_name("A" + str(i), TEST_EXP_NAME),
+ trial_component_arn="b" + str(i),
+ display_name="C" + str(i),
+ source_arn="D" + str(i),
+ status=TrialComponentStatus(
+ primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i)
+ ),
+ start_time=start_time + datetime.timedelta(hours=i),
+ end_time=end_time + datetime.timedelta(hours=i),
+ creation_time=creation_time + datetime.timedelta(hours=i),
+ last_modified_time=last_modified_time + datetime.timedelta(hours=i),
+ last_modified_by={},
+ )
+ for i in range(tc_list_len)
+ ]
+ mock_tc_load.side_effect = [
+ (
+ _TrialComponent(
+ trial_component_name=Run._generate_trial_component_name(
+ "a" + str(i), TEST_EXP_NAME
+ ),
+ trial_component_arn="b" + str(i),
+ display_name="C" + str(i),
+ source_arn="D" + str(i),
+ status=TrialComponentStatus(
+ primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i)
+ ),
+ start_time=start_time + datetime.timedelta(hours=i),
+ end_time=end_time + datetime.timedelta(hours=i),
+ creation_time=creation_time + datetime.timedelta(hours=i),
+ last_modified_time=last_modified_time + datetime.timedelta(hours=i),
+ last_modified_by={},
+ ),
+ True,
+ )
+ for i in range(tc_list_len_half)
+ ]
+
+ run_list = list_runs(
+ experiment_name=TEST_EXP_NAME,
+ sort_by=SortByType.CREATION_TIME,
+ sort_order=SortOrderType.ASCENDING,
+ sagemaker_session=sagemaker_session,
+ )
+
+ mock_tc_list.assert_called_once_with(
+ experiment_name=TEST_EXP_NAME,
+ created_before=None,
+ created_after=None,
+ sort_by="CreationTime",
+ sort_order="Ascending",
+ sagemaker_session=sagemaker_session,
+ max_results=None,
+ next_token=None,
+ )
+ assert len(run_list) == tc_list_len_half
+ for i in range(tc_list_len_half):
+ run = run_list[i]
+ assert run.experiment_name == TEST_EXP_NAME
+ assert run.run_name == "a" + str(i)
+ assert run._experiment
+ assert run._trial
+ assert isinstance(run._trial_component, _TrialComponent)
+ assert run._trial_component.trial_component_name == Run._generate_trial_component_name(
+ "a" + str(i), TEST_EXP_NAME
+ )
+ assert run._in_load is False
+ assert run._inside_load_context is False
+ assert run._inside_init_context is False
+ assert run._artifact_uploader
+ assert run._lineage_artifact_tracker
+ assert run._metrics_manager
+
+
+@patch("sagemaker.experiments.run._TrialComponent.list")
+def test_list_empty(mock_tc_list, sagemaker_session):
+ mock_tc_list.return_value = []
+ assert [] == list_runs(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session)
+
+
+@patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+ "sagemaker.experiments.run._Trial._load_or_create",
+ MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch("sagemaker.experiments.run._TrialComponent._load_or_create")
+def test_enter_exit_locally(mock_load_tc, sagemaker_session, run_obj):
+ mock_load_tc.return_value = run_obj._trial_component, False
+ sagemaker_session.sagemaker_client.update_trial_component.return_value = {}
+ _verify_tc_status_before_enter_init(run_obj._trial_component)
+
+ with run_obj:
+ _verify_tc_status_when_entering(run_obj._trial_component)
+ init_start_time = run_obj._trial_component.start_time
+
+ with load_run(sagemaker_session=sagemaker_session):
+ _verify_tc_status_when_entering(
+ trial_component=run_obj._trial_component,
+ init_start_time=init_start_time,
+ )
+
+ old_end_time = _verify_tc_status_when_successfully_exit(
+ trial_component=run_obj._trial_component,
+ )
+
+ old_end_time = _verify_tc_status_when_successfully_exit(
+ trial_component=run_obj._trial_component,
+ old_end_time=old_end_time,
+ )
+
+    # Re-load to verify that:
+    # 1. load_run still works when it and the with statement are on separate lines
+    # 2. re-entering the load flips the "Completed" TC status back to "InProgress"
+    # 3. exiting the load overrides the end_time and status again
+ run_load = load_run(
+ experiment_name=run_obj.experiment_name,
+ run_name=run_obj.run_name,
+ sagemaker_session=sagemaker_session,
+ )
+ with run_load:
+ _verify_tc_status_when_entering(
+ trial_component=run_obj._trial_component,
+ init_start_time=init_start_time,
+ has_completed=True,
+ )
+ _verify_tc_status_when_successfully_exit(
+ trial_component=run_obj._trial_component, old_end_time=old_end_time
+ )
+
+
+def test_exit_fail(sagemaker_session, run_obj):
+ sagemaker_session.sagemaker_client.update_trial_component.return_value = {}
+ try:
+ with run_obj:
+ raise ValueError("Foo")
+ except ValueError:
+ pass
+
+ assert run_obj._trial_component.status.primary_status == _TrialComponentStatusType.Failed.value
+ assert run_obj._trial_component.status.message
+ assert isinstance(run_obj._trial_component.end_time, datetime.datetime)
+
+
+@pytest.mark.parametrize(
+ "metric_value",
+ [1.3, "nan", "inf", "-inf", None],
+)
+def test_is_input_valid(run_obj, metric_value):
+ assert run_obj._is_input_valid("metric", "Name", metric_value)
+
+
+@pytest.mark.parametrize(
+ "metric_value",
+ [nan, inf, -inf],
+)
+def test_is_input_valid_false(run_obj, metric_value):
+ assert not run_obj._is_input_valid("parameter", "Name", metric_value)
+
+
+def test_generate_trial_name():
+ base_name = "x" * MAX_NAME_LEN_IN_BACKEND
+ trial_name = Run._generate_trial_name(base_name=base_name)
+ assert len(trial_name) <= MAX_NAME_LEN_IN_BACKEND
+
+
+def test_append_run_tc_label_to_tags():
+ expected_tc_tag = RUN_TC_TAG
+
+ tags = None
+ ret = Run._append_run_tc_label_to_tags(tags)
+ assert len(ret) == 1
+ assert expected_tc_tag in ret
+
+ tags = []
+ ret = Run._append_run_tc_label_to_tags(tags)
+ assert len(ret) == 1
+ assert expected_tc_tag in ret
+
+ tags = [{"Key": "foo", "Value": "bar"}]
+ ret = Run._append_run_tc_label_to_tags(tags)
+ assert len(ret) == 2
+ assert expected_tc_tag in ret
+
+
+def _verify_tc_status_before_enter_init(trial_component):
+ assert not trial_component.start_time
+ assert not trial_component.end_time
+ assert not trial_component.status
+
+
+def _verify_tc_status_when_entering(trial_component, init_start_time=None, has_completed=False):
+ if not init_start_time:
+ assert isinstance(trial_component.start_time, datetime.datetime)
+ now = datetime.datetime.now(dateutil.tz.tzlocal())
+ assert (now.timestamp() - trial_component.start_time.timestamp()) < 1
+ else:
+ assert trial_component.start_time == init_start_time
+
+ if not has_completed:
+ assert not trial_component.end_time
+ assert trial_component.status.primary_status == _TrialComponentStatusType.InProgress.value
+
+
+def _verify_tc_status_when_successfully_exit(trial_component, old_end_time=None):
+ assert trial_component.status.primary_status == _TrialComponentStatusType.Completed.value
+ assert isinstance(trial_component.start_time, datetime.datetime)
+ assert isinstance(trial_component.end_time, datetime.datetime)
+ if old_end_time:
+ assert trial_component.end_time > old_end_time
+ return trial_component.end_time
diff --git a/tests/unit/sagemaker/experiments/test_run_context.py b/tests/unit/sagemaker/experiments/test_run_context.py
new file mode 100644
index 0000000000..7e068136a1
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_run_context.py
@@ -0,0 +1,191 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from unittest.mock import patch, MagicMock
+
+import pytest
+
+from sagemaker.estimator import Estimator, _TrainingJob
+from sagemaker.experiments.experiment import _Experiment
+from sagemaker.experiments.run import _RunContext
+from sagemaker.experiments import load_run, Run
+from sagemaker.experiments.trial import _Trial
+from tests.unit.sagemaker.experiments.helpers import (
+ TEST_EXP_NAME,
+ mock_trial_load_or_create_func,
+ mock_tc_load_or_create_func,
+)
+
+_bucket = "my-bucket"
+_train_input_path = f"s3://{_bucket}/data.csv"
+_train_output_path = f"s3://{_bucket}"
+
+
+@patch.object(_TrainingJob, "start_new")
+def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_session):
+ mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job")
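+    # With run_obj as the active run context, fit() should pick up and forward
+    # its experiment_config automatically.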
+ with run_obj:
+ estimator = Estimator(
+ role="arn:my-role",
+ image_uri="my-image",
+ sagemaker_session=sagemaker_session,
+ output_path=_train_output_path,
+ )
+ estimator.fit(
+ inputs=_train_input_path,
+ wait=False,
+ )
+
+ assert _RunContext.get_current_run() == run_obj
+
+ expected_exp_config = run_obj.experiment_config
+ mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config)
+
+ # _RunContext is cleaned up after exiting the with statement
+ assert not _RunContext.get_current_run()
+
+
+@patch.object(_TrainingJob, "start_new")
+def test_user_supply_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_session):
+ mock_start_job.return_value = _TrainingJob(sagemaker_session, "my-job")
+ supplied_exp_cfg = {
+ "ExperimentName": "my-supplied-exp-name",
+ "TrialName": "my-supplied-run-group-name",
+ "RunName": "my-supplied-run-name",
+ }
+ with run_obj:
+ estimator = Estimator(
+ role="arn:my-role",
+ image_uri="my-image",
+ sagemaker_session=sagemaker_session,
+ output_path=_train_output_path,
+ )
+ estimator.fit(
+ experiment_config=supplied_exp_cfg,
+ inputs=_train_input_path,
+ wait=False,
+ )
+
+ assert _RunContext.get_current_run() == run_obj
+
+ mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg)
+
+ # _RunContext is cleaned up after exiting the with statement
+ assert not _RunContext.get_current_run()
+
+
+def test_auto_fetch_created_run_obj_from_context(run_obj, sagemaker_session):
+ assert not run_obj._inside_init_context
+ assert not run_obj._inside_load_context
+ assert not run_obj._in_load
+ assert not _RunContext.get_current_run()
+
+ def train():
+ with load_run(sagemaker_session=sagemaker_session) as run_load:
+ assert run_load == run_obj
+ assert run_obj._inside_init_context
+ assert run_obj._inside_load_context
+ assert run_obj._in_load
+
+ run_load.log_parameter("foo", "bar")
+ run_load.log_parameter("whizz", 1)
+
+ with run_obj:
+ assert run_obj._inside_init_context
+ assert not run_obj._inside_load_context
+ assert not run_obj._in_load
+ assert _RunContext.get_current_run()
+
+ train()
+
+ assert run_obj._inside_init_context
+ assert not run_obj._inside_load_context
+ assert not run_obj._in_load
+ assert _RunContext.get_current_run()
+
+ run_obj.log_parameters({"a": "b", "c": 2})
+
+ assert run_obj._trial_component.parameters["foo"] == "bar"
+ assert run_obj._trial_component.parameters["whizz"] == 1
+ assert run_obj._trial_component.parameters["a"] == "b"
+ assert run_obj._trial_component.parameters["c"] == 2
+
+    # Verify that separating load_run and the with statement onto different lines still works
+ run_load2 = load_run(sagemaker_session=sagemaker_session)
+ with run_load2:
+ assert run_load2 == run_obj
+ assert run_obj._inside_init_context
+ assert run_obj._inside_load_context
+ assert run_obj._in_load
+
+ assert run_obj._inside_init_context
+ assert not run_obj._inside_load_context
+ assert not run_obj._in_load
+ assert _RunContext.get_current_run()
+
+ assert not run_obj._inside_init_context
+ assert not run_obj._inside_load_context
+ assert not run_obj._in_load
+ assert not _RunContext.get_current_run()
+
+
+def test_nested_run_init_context_on_same_run_object(run_obj, sagemaker_session):
+ assert not _RunContext.get_current_run()
+
+ with pytest.raises(RuntimeError) as err:
+ with run_obj:
+ assert _RunContext.get_current_run()
+
+ with run_obj:
+ pass
+ assert "It is not allowed to use nested 'with' statements on the Run" in str(err)
+
+
+@patch(
+ "sagemaker.experiments.run._Experiment._load_or_create",
+ MagicMock(return_value=_Experiment(experiment_name=TEST_EXP_NAME)),
+)
+@patch(
+ "sagemaker.experiments.run._Trial._load_or_create",
+ MagicMock(side_effect=mock_trial_load_or_create_func),
+)
+@patch.object(_Trial, "add_trial_component", MagicMock(return_value=None))
+@patch(
+ "sagemaker.experiments.run._TrialComponent._load_or_create",
+ MagicMock(side_effect=mock_tc_load_or_create_func),
+)
+def test_nested_run_init_context_on_different_run_object(run_obj, sagemaker_session):
+ assert not _RunContext.get_current_run()
+
+ with pytest.raises(RuntimeError) as err:
+ with Run(experiment_name=TEST_EXP_NAME, sagemaker_session=sagemaker_session):
+ assert _RunContext.get_current_run()
+
+ with run_obj:
+ pass
+ assert "It is not allowed to use nested 'with' statements on the Run" in str(err)
+
+
+def test_nested_run_load_context(run_obj, sagemaker_session):
+ assert not _RunContext.get_current_run()
+
+ with pytest.raises(RuntimeError) as err:
+ with run_obj:
+ assert _RunContext.get_current_run()
+
+ with load_run():
+ run_load = load_run()
+ with run_load:
+ pass
+ assert "It is not allowed to use nested 'with' statements on the load_run" in str(err)
diff --git a/tests/unit/sagemaker/experiments/test_trial.py b/tests/unit/sagemaker/experiments/test_trial.py
new file mode 100644
index 0000000000..f6996fefc3
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_trial.py
@@ -0,0 +1,276 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import datetime
+
+from unittest.mock import patch
+
+import pytest
+
+from sagemaker import Session
+from sagemaker.experiments._api_types import TrialSummary
+from sagemaker.experiments.trial import _Trial
+from sagemaker.experiments.trial_component import _TrialComponent
+
+
+@pytest.fixture
+def datetime_obj():
+ return datetime.datetime(2017, 6, 16, 15, 55, 0)
+
+
+def test_load(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ client.describe_trial.return_value = {"ExperimentName": "experiment-name-value"}
+ trial_obj = _Trial.load(trial_name="name-value", sagemaker_session=sagemaker_session)
+ assert trial_obj.trial_name == "name-value"
+ assert trial_obj.experiment_name == "experiment-name-value"
+ client.describe_trial.assert_called_with(TrialName="name-value")
+
+
+def test_create(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ client.create_trial.return_value = {
+ "Arn": "arn:aws:1234",
+ "TrialName": "name-value",
+ }
+ trial_obj = _Trial.create(
+ trial_name="name-value",
+ experiment_name="experiment-name-value",
+ sagemaker_session=sagemaker_session,
+ )
+ assert trial_obj.trial_name == "name-value"
+ client.create_trial.assert_called_with(
+ TrialName="name-value", ExperimentName="experiment-name-value"
+ )
+
+
+def test_create_with_tags(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ client.create_trial.return_value = {
+ "Arn": "arn:aws:1234",
+ "TrialName": "name-value",
+ }
+ tags = [{"Key": "foo", "Value": "bar"}]
+ trial_obj = _Trial.create(
+ trial_name="name-value",
+ experiment_name="experiment-name-value",
+ sagemaker_session=sagemaker_session,
+ tags=tags,
+ )
+ assert trial_obj.trial_name == "name-value"
+ client.create_trial.assert_called_with(
+ TrialName="name-value",
+ ExperimentName="experiment-name-value",
+ Tags=[{"Key": "foo", "Value": "bar"}],
+ )
+
+
+def test_delete(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ obj = _Trial(sagemaker_session, trial_name="foo")
+ client.delete_trial.return_value = {}
+ obj.delete()
+ client.delete_trial.assert_called_with(TrialName="foo")
+
+
+def test_save(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ obj = _Trial(
+ sagemaker_session,
+ trial_name="foo",
+ experiment_name="whizz",
+ display_name="bar",
+ tags=[{"Key": "foo", "Value": "bar"}],
+ )
+ client.update_trial.return_value = {}
+ obj.save()
+
+ client.update_trial.assert_called_with(
+ TrialName="foo",
+ DisplayName="bar",
+ )
+
+
+def test_add_trial_component(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ trial = _Trial(sagemaker_session=sagemaker_session)
+ trial.trial_name = "bar"
+ trial.add_trial_component("foo")
+ client.associate_trial_component.assert_called_with(TrialName="bar", TrialComponentName="foo")
+
+ tc = _TrialComponent(trial_component_name="tc-foo", sagemaker_session=sagemaker_session)
+ trial.add_trial_component(tc)
+ client.associate_trial_component.assert_called_with(
+ TrialName="bar", TrialComponentName=tc.trial_component_name
+ )
+
+
+def test_remove_trial_component(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ trial = _Trial(sagemaker_session=sagemaker_session)
+ trial.trial_name = "bar"
+ trial.remove_trial_component("foo")
+ client.disassociate_trial_component.assert_called_with(
+ TrialName="bar", TrialComponentName="foo"
+ )
+
+ tc = _TrialComponent(trial_component_name="tc-foo", sagemaker_session=sagemaker_session)
+ trial.remove_trial_component(tc)
+ client.disassociate_trial_component.assert_called_with(
+ TrialName="bar", TrialComponentName=tc.trial_component_name
+ )
+
+
+@patch("sagemaker.experiments.trial._Trial.load")
+def test_load_or_create_when_exist(mock_load):
+ sagemaker_session = Session()
+ trial_name = "trial_name"
+ exp_name = "exp_name"
+
+ # The trial exists and experiment matches
+ mock_load.return_value = _Trial(
+ trial_name=trial_name,
+ experiment_name=exp_name,
+ sagemaker_session=sagemaker_session,
+ )
+ _Trial._load_or_create(
+ trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session
+ )
+ mock_load.assert_called_once_with(trial_name, sagemaker_session)
+
+ # The trial exists but experiment does not match
+ mock_load.return_value = _Trial(
+ trial_name=trial_name,
+        experiment_name="another_exp_name",
+ sagemaker_session=sagemaker_session,
+ )
+ with pytest.raises(ValueError) as err:
+ _Trial._load_or_create(
+ trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session
+ )
+ assert "The given experiment_name {} does not match that in the loaded trial".format(
+ exp_name
+ ) in str(err)
+
+
+@patch("sagemaker.experiments.trial._Trial.load")
+@patch("sagemaker.experiments.trial._Trial.create")
+def test_load_or_create_when_not_exist(mock_create, mock_load):
+ sagemaker_session = Session()
+ client = sagemaker_session.sagemaker_client
+ trial_name = "trial_name"
+ exp_name = "exp_name"
+ not_found_err = client.exceptions.ResourceNotFound(
+ error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}},
+ operation_name="foo",
+ )
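+    # A ResourceNotFound from load should make _load_or_create fall back to create.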
+ mock_load.side_effect = not_found_err
+
+ _Trial._load_or_create(
+ trial_name=trial_name, experiment_name=exp_name, sagemaker_session=sagemaker_session
+ )
+
+ mock_create.assert_called_once_with(
+ trial_name=trial_name,
+ experiment_name=exp_name,
+ display_name=None,
+ tags=None,
+ sagemaker_session=sagemaker_session,
+ )
+
+
+def test_list_trials_without_experiment_name(sagemaker_session, datetime_obj):
+ client = sagemaker_session.sagemaker_client
+ client.list_trials.return_value = {
+ "TrialSummaries": [
+ {
+ "TrialName": "trial-1",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ {
+ "TrialName": "trial-2",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ ]
+ }
+ expected = [
+ TrialSummary(
+ trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj
+ ),
+ TrialSummary(
+ trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj
+ ),
+ ]
+ assert expected == list(_Trial.list(sagemaker_session=sagemaker_session))
+ client.list_trials.assert_called_with(**{})
+
+
+def test_list_trials_with_experiment_name(sagemaker_session, datetime_obj):
+ client = sagemaker_session.sagemaker_client
+ client.list_trials.return_value = {
+ "TrialSummaries": [
+ {
+ "TrialName": "trial-1",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ {
+ "TrialName": "trial-2",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ ]
+ }
+ expected = [
+ TrialSummary(
+ trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj
+ ),
+ TrialSummary(
+ trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj
+ ),
+ ]
+ assert expected == list(_Trial.list(experiment_name="foo", sagemaker_session=sagemaker_session))
+ client.list_trials.assert_called_with(ExperimentName="foo")
+
+
+def test_list_trials_with_trial_component_name(sagemaker_session, datetime_obj):
+ client = sagemaker_session.sagemaker_client
+ client.list_trials.return_value = {
+ "TrialSummaries": [
+ {
+ "TrialName": "trial-1",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ {
+ "TrialName": "trial-2",
+ "CreationTime": datetime_obj,
+ "LastModifiedTime": datetime_obj,
+ },
+ ]
+ }
+ expected = [
+ TrialSummary(
+ trial_name="trial-1", creation_time=datetime_obj, last_modified_time=datetime_obj
+ ),
+ TrialSummary(
+ trial_name="trial-2", creation_time=datetime_obj, last_modified_time=datetime_obj
+ ),
+ ]
+ assert expected == list(
+ _Trial.list(trial_component_name="tc-foo", sagemaker_session=sagemaker_session)
+ )
+ client.list_trials.assert_called_with(TrialComponentName="tc-foo")
diff --git a/tests/unit/sagemaker/experiments/test_trial_component.py b/tests/unit/sagemaker/experiments/test_trial_component.py
new file mode 100644
index 0000000000..c14663893e
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_trial_component.py
@@ -0,0 +1,384 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import datetime
+import unittest.mock
+
+from unittest.mock import patch
+
+from sagemaker import Session
+from sagemaker.experiments import _api_types
+from sagemaker.experiments._api_types import (
+ TrialComponentSearchResult,
+ Parent,
+ _TrialComponentStatusType,
+)
+from sagemaker.experiments.trial_component import _TrialComponent
+
+
+def test_create(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ client.create_trial_component.return_value = {
+ "TrialComponentArn": "bazz",
+ }
+ obj = _TrialComponent.create(
+ trial_component_name="foo", display_name="bar", sagemaker_session=sagemaker_session
+ )
+ client.create_trial_component.assert_called_with(TrialComponentName="foo", DisplayName="bar")
+ assert "foo" == obj.trial_component_name
+ assert "bar" == obj.display_name
+ assert "bazz" == obj.trial_component_arn
+
+
+def test_create_with_tags(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ client.create_trial_component.return_value = {
+ "TrialComponentArn": "bazz",
+ }
+ tags = [{"Key": "foo", "Value": "bar"}]
+ _TrialComponent.create(
+ trial_component_name="foo",
+ display_name="bar",
+ sagemaker_session=sagemaker_session,
+ tags=tags,
+ )
+ client.create_trial_component.assert_called_with(
+ TrialComponentName="foo", DisplayName="bar", Tags=tags
+ )
+
+
+def test_load(sagemaker_session):
+ now = datetime.datetime.now(datetime.timezone.utc)
+ client = sagemaker_session.sagemaker_client
+ client.describe_trial_component.return_value = {
+ "TrialComponentArn": "A",
+ "TrialComponentName": "B",
+ "DisplayName": "C",
+ "Status": {"PrimaryStatus": _TrialComponentStatusType.InProgress.value, "Message": "D"},
+ "Parameters": {"E": {"NumberValue": 1.0}, "F": {"StringValue": "G"}},
+ "InputArtifacts": {"H": {"Value": "s3://foo/bar", "MediaType": "text/plain"}},
+ "OutputArtifacts": {"I": {"Value": "s3://whizz/bang", "MediaType": "text/plain"}},
+ "Metrics": [
+ {
+ "MetricName": "J",
+ "Count": 1,
+ "Min": 1.0,
+ "Max": 2.0,
+ "Avg": 3.0,
+ "StdDev": 4.0,
+ "SourceArn": "K",
+ "Timestamp": now,
+ }
+ ],
+ }
+ obj = _TrialComponent.load(trial_component_name="foo", sagemaker_session=sagemaker_session)
+ client.describe_trial_component.assert_called_with(TrialComponentName="foo")
+ assert "A" == obj.trial_component_arn
+ assert "B" == obj.trial_component_name
+ assert "C" == obj.display_name
+ assert (
+ _api_types.TrialComponentStatus(
+ primary_status=_TrialComponentStatusType.InProgress.value, message="D"
+ )
+ == obj.status
+ )
+ assert {"E": 1.0, "F": "G"} == obj.parameters
+ assert {"H": _api_types.TrialComponentArtifact(value="s3://foo/bar", media_type="text/plain")}
+ assert {
+ "I": _api_types.TrialComponentArtifact(value="s3://whizz/bang", media_type="text/plain")
+ }
+ assert [
+ _api_types.TrialComponentMetricSummary(
+ metric_name="J",
+ count=1,
+ min=1.0,
+ max=2.0,
+ avg=3.0,
+ std_dev=4.0,
+ source_arn="K",
+ timestamp=now,
+ )
+ ]
+
+
+def test_save(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ obj = _TrialComponent(
+ sagemaker_session,
+ trial_component_name="foo",
+ display_name="bar",
+ parameters_to_remove=["E"],
+ input_artifacts_to_remove=["F"],
+ output_artifacts_to_remove=["G"],
+ )
+ client.update_trial_component.return_value = {}
+ obj.save()
+
+ client.update_trial_component.assert_called_with(
+ TrialComponentName="foo",
+ DisplayName="bar",
+ Parameters={},
+ ParametersToRemove=["E"],
+ InputArtifacts={},
+ InputArtifactsToRemove=["F"],
+ OutputArtifacts={},
+ OutputArtifactsToRemove=["G"],
+ )
+
+
+def test_delete(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ obj = _TrialComponent(sagemaker_session, trial_component_name="foo", display_name="bar")
+ client.delete_trial_component.return_value = {}
+ obj.delete()
+ client.delete_trial_component.assert_called_with(TrialComponentName="foo")
+
+
+def test_delete_with_force_disassociate(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ obj = _TrialComponent(sagemaker_session, trial_component_name="foo", display_name="bar")
+ client.delete_trial_component.return_value = {}
+
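+    # Two pages of associated trials, linked by NextToken, to exercise the
+    # disassociate-all pagination path.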
+ client.list_trials.side_effect = [
+ {"TrialSummaries": [{"TrialName": "trial-1"}, {"TrialName": "trial-2"}], "NextToken": "a"},
+ {"TrialSummaries": [{"TrialName": "trial-3"}, {"TrialName": "trial-4"}]},
+ ]
+
+ obj.delete(force_disassociate=True)
+ expected_calls = [
+ unittest.mock.call(TrialName="trial-1", TrialComponentName="foo"),
+ unittest.mock.call(TrialName="trial-2", TrialComponentName="foo"),
+ unittest.mock.call(TrialName="trial-3", TrialComponentName="foo"),
+ unittest.mock.call(TrialName="trial-4", TrialComponentName="foo"),
+ ]
+ assert expected_calls == client.disassociate_trial_component.mock_calls
+ client.delete_trial_component.assert_called_with(TrialComponentName="foo")
+
+
+def test_list(sagemaker_session):
+ start_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=1)
+ end_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=2)
+ creation_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=3)
+ last_modified_time = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(hours=4)
+
+ client = sagemaker_session.sagemaker_client
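+    # Return two pages of summaries; the second list call must echo back NextToken="100".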
+ client.list_trial_components.side_effect = [
+ {
+ "TrialComponentSummaries": [
+ {
+ "TrialComponentName": "A" + str(i),
+ "TrialComponentArn": "B" + str(i),
+ "DisplayName": "C" + str(i),
+ "SourceArn": "D" + str(i),
+ "Status": {
+ "PrimaryStatus": _TrialComponentStatusType.InProgress.value,
+ "Message": "E" + str(i),
+ },
+ "StartTime": start_time + datetime.timedelta(hours=i),
+ "EndTime": end_time + datetime.timedelta(hours=i),
+ "CreationTime": creation_time + datetime.timedelta(hours=i),
+ "LastModifiedTime": last_modified_time + datetime.timedelta(hours=i),
+ "LastModifiedBy": {},
+ }
+ for i in range(10)
+ ],
+ "NextToken": "100",
+ },
+ {
+ "TrialComponentSummaries": [
+ {
+ "TrialComponentName": "A" + str(i),
+ "TrialComponentArn": "B" + str(i),
+ "DisplayName": "C" + str(i),
+ "SourceArn": "D" + str(i),
+ "Status": {
+ "PrimaryStatus": _TrialComponentStatusType.InProgress.value,
+ "Message": "E" + str(i),
+ },
+ "StartTime": start_time + datetime.timedelta(hours=i),
+ "EndTime": end_time + datetime.timedelta(hours=i),
+ "CreationTime": creation_time + datetime.timedelta(hours=i),
+ "LastModifiedTime": last_modified_time + datetime.timedelta(hours=i),
+ "LastModifiedBy": {},
+ }
+ for i in range(10, 20)
+ ]
+ },
+ ]
+
+ expected = [
+ _api_types.TrialComponentSummary(
+ trial_component_name="A" + str(i),
+ trial_component_arn="B" + str(i),
+ display_name="C" + str(i),
+ source_arn="D" + str(i),
+ status=_api_types.TrialComponentStatus(
+ primary_status=_TrialComponentStatusType.InProgress.value, message="E" + str(i)
+ ),
+ start_time=start_time + datetime.timedelta(hours=i),
+ end_time=end_time + datetime.timedelta(hours=i),
+ creation_time=creation_time + datetime.timedelta(hours=i),
+ last_modified_time=last_modified_time + datetime.timedelta(hours=i),
+ last_modified_by={},
+ )
+ for i in range(20)
+ ]
+ result = list(
+ _TrialComponent.list(
+ sagemaker_session=sagemaker_session,
+ source_arn="foo",
+ sort_by="CreationTime",
+ sort_order="Ascending",
+ )
+ )
+
+ assert expected == result
+ expected_calls = [
+ unittest.mock.call(SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"),
+ unittest.mock.call(
+ NextToken="100", SortBy="CreationTime", SortOrder="Ascending", SourceArn="foo"
+ ),
+ ]
+ assert expected_calls == client.list_trial_components.mock_calls
+
+
+def test_list_empty(sagemaker_session):
+ sagemaker_session.sagemaker_client.list_trial_components.return_value = {
+ "TrialComponentSummaries": []
+ }
+ assert [] == list(_TrialComponent.list(sagemaker_session=sagemaker_session))
+
+
+def test_list_trial_components_call_args(sagemaker_session):
+ created_before = datetime.datetime(1999, 10, 12, 0, 0, 0)
+ created_after = datetime.datetime(1990, 10, 12, 0, 0, 0)
+ trial_name = "foo-trial"
+ experiment_name = "foo-experiment"
+ next_token = "thetoken"
+ max_results = 99
+
+ client = sagemaker_session.sagemaker_client
+ client.list_trial_components.return_value = {}
+ assert [] == list(
+ _TrialComponent.list(
+ sagemaker_session=sagemaker_session,
+ trial_name=trial_name,
+ experiment_name=experiment_name,
+ created_before=created_before,
+ created_after=created_after,
+ next_token=next_token,
+ max_results=max_results,
+ sort_by="CreationTime",
+ sort_order="Ascending",
+ )
+ )
+
+ expected_calls = [
+ unittest.mock.call(
+ TrialName="foo-trial",
+ ExperimentName="foo-experiment",
+ CreatedBefore=created_before,
+ CreatedAfter=created_after,
+ SortBy="CreationTime",
+ SortOrder="Ascending",
+ NextToken="thetoken",
+ MaxResults=99,
+ )
+ ]
+ assert expected_calls == client.list_trial_components.mock_calls
+
+
+@patch("sagemaker.experiments.trial_component._TrialComponent.load")
+def test_load_or_create_when_exist(mock_load, sagemaker_session):
+ tc_name = "tc_name"
+ _, is_existed = _TrialComponent._load_or_create(
+ trial_component_name=tc_name, sagemaker_session=sagemaker_session
+ )
+ assert is_existed
+ mock_load.assert_called_once_with(
+ tc_name,
+ sagemaker_session,
+ )
+
+
+@patch("sagemaker.experiments.trial_component._TrialComponent.load")
+@patch("sagemaker.experiments.trial_component._TrialComponent.create")
+def test_load_or_create_when_not_exist(mock_create, mock_load):
+ sagemaker_session = Session()
+ client = sagemaker_session.sagemaker_client
+ tc_name = "tc_name"
+ not_found_err = client.exceptions.ResourceNotFound(
+ error_response={"Error": {"Code": "ResourceNotFound", "Message": "Not Found"}},
+ operation_name="foo",
+ )
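+    # Simulate a missing trial component so _load_or_create creates one instead.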
+ mock_load.side_effect = not_found_err
+
+ _, is_existed = _TrialComponent._load_or_create(
+ trial_component_name=tc_name, sagemaker_session=sagemaker_session
+ )
+
+ assert not is_existed
+ mock_create.assert_called_once_with(
+ trial_component_name=tc_name,
+ display_name=None,
+ tags=None,
+ sagemaker_session=sagemaker_session,
+ )
+
+
+def test_search(sagemaker_session):
+ client = sagemaker_session.sagemaker_client
+ client.search.return_value = {
+ "Results": [
+ {
+ "TrialComponent": {
+ "TrialComponentName": "tc-1",
+ "TrialComponentArn": "arn::tc-1",
+ "DisplayName": "TC1",
+ "Parents": [
+ {
+ "ExperimentName": "e-1",
+ "TrialName": "t-1",
+ },
+ {
+ "ExperimentName": "e-2",
+ "TrialName": "t-2",
+ },
+ ],
+ }
+ },
+ {
+ "TrialComponent": {
+ "TrialComponentName": "tc-2",
+ "TrialComponentArn": "arn::tc-2",
+ "DisplayName": "TC2",
+ }
+ },
+ ]
+ }
+ expected = [
+ TrialComponentSearchResult(
+ trial_component_name="tc-1",
+ trial_component_arn="arn::tc-1",
+ display_name="TC1",
+ parents=[
+ Parent(experiment_name="e-1", trial_name="t-1"),
+ Parent(experiment_name="e-2", trial_name="t-2"),
+ ],
+ ),
+ TrialComponentSearchResult(
+ trial_component_name="tc-2", trial_component_arn="arn::tc-2", display_name="TC2"
+ ),
+ ]
+ assert expected == list(_TrialComponent.search(sagemaker_session=sagemaker_session))
diff --git a/tests/unit/sagemaker/experiments/test_utils.py b/tests/unit/sagemaker/experiments/test_utils.py
new file mode 100644
index 0000000000..a63c96c0fe
--- /dev/null
+++ b/tests/unit/sagemaker/experiments/test_utils.py
@@ -0,0 +1,36 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from sagemaker.experiments._utils import resolve_artifact_name, guess_media_type
+
+
+def test_resolve_artifact_name():
+ file_names = {
+ "a": "a",
+ "a.txt": "a.txt",
+ "b.": "b.",
+ ".c": ".c",
+ "/x/a/a.txt": "a.txt",
+ "/a/b/c.": "c.",
+ "./.a": ".a",
+ "../b.txt": "b.txt",
+ "~/a.txt": "a.txt",
+ "c/d.txt": "d.txt",
+ }
+ for file_name, artifact_name in file_names.items():
+ assert artifact_name == resolve_artifact_name(file_name)
+
+
+def test_guess_media_type():
+ assert "text/plain" == guess_media_type("foo.txt")
diff --git a/tests/unit/sagemaker/feature_store/test_dataset_builder.py b/tests/unit/sagemaker/feature_store/test_dataset_builder.py
new file mode 100644
index 0000000000..0e55b86bd0
--- /dev/null
+++ b/tests/unit/sagemaker/feature_store/test_dataset_builder.py
@@ -0,0 +1,612 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import datetime
+import os
+
+import pandas as pd
+import pytest
+from mock import Mock, patch
+
+from sagemaker.feature_store.dataset_builder import (
+ DatasetBuilder,
+ FeatureGroupToBeMerged,
+ TableType,
+)
+from sagemaker.feature_store.feature_group import (
+ FeatureDefinition,
+ FeatureGroup,
+ FeatureTypeEnum,
+)
+
+
+@pytest.fixture
+def sagemaker_session_mock():
+ return Mock()
+
+
+@pytest.fixture
+def feature_group_mock():
+ return Mock()
+
+
+@pytest.fixture
+def read_csv_mock():
+ return Mock()
+
+
+@pytest.fixture
+def to_csv_file_mock():
+ return Mock()
+
+
+@pytest.fixture
+def remove_mock():
+ return Mock()
+
+
+BASE = FeatureGroupToBeMerged(
+ ["target-feature", "other-feature"],
+ ["target-feature", "other-feature"],
+ ["target-feature", "other-feature"],
+ "catalog",
+ "database",
+ "base-table",
+ "target-feature",
+ FeatureDefinition("other-feature", FeatureTypeEnum.STRING),
+ None,
+ TableType.FEATURE_GROUP,
+)
+FEATURE_GROUP = FeatureGroupToBeMerged(
+ ["feature-1", "feature-2"],
+ ["feature-1", "feature-2"],
+ ["feature-1", "feature-2"],
+ "catalog",
+ "database",
+ "table-name",
+ "feature-1",
+ FeatureDefinition("feature-2", FeatureTypeEnum.FRACTIONAL),
+ "target-feature",
+ TableType.FEATURE_GROUP,
+)
+
+
+def test_with_feature_group_throw_runtime_error(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group,
+ output_path="file/to/path",
+ )
+ sagemaker_session_mock.describe_feature_group.return_value = {"OfflineStoreConfig": {}}
+ with pytest.raises(RuntimeError) as error:
+ dataset_builder.with_feature_group(
+ feature_group, "target-feature", ["feature-1", "feature-2"]
+ )
+ assert "No metastore is configured with FeatureGroup MyFeatureGroup." in str(error)
+
+
+def test_with_feature_group(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]})
+ feature_group.load_feature_definitions(dataframe)
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group,
+ output_path="file/to/path",
+ record_identifier_feature_name="target-feature",
+ )
+ sagemaker_session_mock.describe_feature_group.return_value = {
+ "OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}},
+ "RecordIdentifierFeatureName": "feature-1",
+ "EventTimeFeatureName": "feature-2",
+ "FeatureDefinitions": [
+ {"FeatureName": "feature-1", "FeatureType": "String"},
+ {"FeatureName": "feature-2", "FeatureType": "String"},
+ ],
+ }
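+    # with_feature_group should resolve the table, database, and identifier
+    # features from this DescribeFeatureGroup response.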
+ dataset_builder.with_feature_group(feature_group, "target-feature", ["feature-1", "feature-2"])
+ assert len(dataset_builder._feature_groups_to_be_merged) == 1
+ assert dataset_builder._feature_groups_to_be_merged[0].features == [
+ "feature-1",
+ "feature-2",
+ ]
+ assert dataset_builder._feature_groups_to_be_merged[0].included_feature_names == [
+ "feature-1",
+ "feature-2",
+ ]
+ assert dataset_builder._feature_groups_to_be_merged[0].database == "database"
+ assert dataset_builder._feature_groups_to_be_merged[0].table_name == "table"
+ assert (
+ dataset_builder._feature_groups_to_be_merged[0].record_identifier_feature_name
+ == "feature-1"
+ )
+ assert (
+ dataset_builder._feature_groups_to_be_merged[0].event_time_identifier_feature.feature_name
+ == "feature-2"
+ )
+ assert (
+ dataset_builder._feature_groups_to_be_merged[0].event_time_identifier_feature.feature_type
+ == FeatureTypeEnum.STRING
+ )
+ assert (
+ dataset_builder._feature_groups_to_be_merged[0].target_feature_name_in_base
+ == "target-feature"
+ )
+
+
+def test_point_in_time_accurate_join(sagemaker_session_mock, feature_group_mock):
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group_mock,
+ output_path="file/to/path",
+ )
+ dataset_builder.point_in_time_accurate_join()
+ assert dataset_builder._point_in_time_accurate_join
+
+
+def test_include_duplicated_records(sagemaker_session_mock, feature_group_mock):
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group_mock,
+ output_path="file/to/path",
+ )
+ dataset_builder.include_duplicated_records()
+ assert dataset_builder._include_duplicated_records
+
+
+def test_include_deleted_records(sagemaker_session_mock, feature_group_mock):
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group_mock,
+ output_path="file/to/path",
+ )
+ dataset_builder.include_deleted_records()
+ assert dataset_builder._include_deleted_records
+
+
+def test_with_number_of_recent_records_by_record_identifier(
+ sagemaker_session_mock, feature_group_mock
+):
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group_mock,
+ output_path="file/to/path",
+ )
+ dataset_builder.with_number_of_recent_records_by_record_identifier(5)
+ assert dataset_builder._number_of_recent_records == 5
+
+
+def test_with_number_of_records_from_query_results(sagemaker_session_mock, feature_group_mock):
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group_mock,
+ output_path="file/to/path",
+ )
+ dataset_builder.with_number_of_records_from_query_results(100)
+ assert dataset_builder._number_of_records == 100
+
+
+def test_with_event_time_range(sagemaker_session_mock, feature_group_mock):
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group_mock,
+ output_path="file/to/path",
+ )
+ start = datetime.datetime.now()
+ end = start + datetime.timedelta(minutes=1)
+ dataset_builder.with_event_time_range(start, end)
+ assert dataset_builder._event_time_starting_timestamp == start
+ assert dataset_builder._event_time_ending_timestamp == end
+
+
+def test_to_csv_file_not_support_base_type(sagemaker_session_mock, feature_group_mock):
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group_mock,
+ output_path="file/to/path",
+ )
+ with pytest.raises(ValueError) as error:
+ dataset_builder.to_csv_file()
+ assert "Base must be either a FeatureGroup or a DataFrame." in str(error)
+
+
+def test_to_csv_file_with_feature_group(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group,
+ output_path="file/to/path",
+ )
+ sagemaker_session_mock.describe_feature_group.return_value = {
+ "OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}},
+ "RecordIdentifierFeatureName": "feature-1",
+ "EventTimeFeatureName": "feature-2",
+ "FeatureDefinitions": [
+ {"FeatureName": "feature-1", "FeatureType": "String"},
+ {"FeatureName": "feature-2", "FeatureType": "String"},
+ ],
+ }
+ sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"}
+ sagemaker_session_mock.get_query_execution.return_value = {
+ "QueryExecution": {
+ "Status": {"State": "SUCCEEDED"},
+ "ResultConfiguration": {"OutputLocation": "s3-file-path"},
+ "Query": "query-string",
+ }
+ }
+ file_path, query_string = dataset_builder.to_csv_file()
+ assert file_path == "s3-file-path"
+ assert query_string == "query-string"
+
+
+@patch("pandas.DataFrame.to_csv")
+@patch("pandas.read_csv")
+@patch("os.remove")
+def test_to_dataframe_with_dataframe(
+ remove_mock, read_csv_mock, to_csv_file_mock, sagemaker_session_mock
+):
+ dataframe = pd.DataFrame({"feature-1": [420, 380.0, 390], "feature-2": [50, 40.0, 45]})
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=dataframe,
+ output_path="s3://file/to/path",
+ event_time_identifier_feature_name="feature-2",
+ )
+ sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"}
+ sagemaker_session_mock.get_query_execution.return_value = {
+ "QueryExecution": {
+ "Status": {"State": "SUCCEEDED"},
+ "ResultConfiguration": {"OutputLocation": "s3://s3-file-path"},
+ "Query": "query-string",
+ }
+ }
+ to_csv_file_mock.return_value = None
+ read_csv_mock.return_value = dataframe
+    remove_mock.return_value = None
+ df, query_string = dataset_builder.to_dataframe()
+ assert df.equals(dataframe)
+ assert query_string == "query-string"
+
+
+def test_construct_where_query_string(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group,
+ output_path="file/to/path",
+ )
+ time = datetime.datetime.now().replace(microsecond=0)
+ start = time + datetime.timedelta(minutes=1)
+ end = start + datetime.timedelta(minutes=1)
+ dataset_builder._write_time_ending_timestamp = time
+ dataset_builder._event_time_starting_timestamp = start
+ dataset_builder._event_time_ending_timestamp = end
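+    # Expect the extra filter plus a write_time upper bound and an event-time
+    # window in the generated WHERE clause.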
+ query_string = dataset_builder._construct_where_query_string(
+ "suffix",
+ FeatureDefinition("event-time", FeatureTypeEnum.STRING),
+ ["NOT is_deleted"],
+ )
+ assert (
+ query_string
+ == "WHERE NOT is_deleted\n"
+ + f"AND table_suffix.\"write_time\" <= to_timestamp('{time}', "
+ + "'yyyy-mm-dd hh24:mi:ss')\n"
+ + 'AND from_iso8601_timestamp(table_suffix."event-time") >= '
+ + f"from_unixtime({start.timestamp()})\n"
+ + 'AND from_iso8601_timestamp(table_suffix."event-time") <= '
+ + f"from_unixtime({end.timestamp()})"
+ )
+
+
+def test_construct_query_string_with_duplicated_records(sagemaker_session_mock, feature_group_mock):
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group_mock,
+ output_path="file/to/path",
+ )
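+ # With duplicated records included, the expected query skips the dedup CTE and
+ # keeps only the deleted-records CTE, so every non-deleted row is returned.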
+ dataset_builder._include_duplicated_records = True
+
+ dataset_builder._feature_groups_to_be_merged = [FEATURE_GROUP]
+ query_string = dataset_builder._construct_query_string(BASE)
+ assert (
+ query_string
+ == "WITH fg_base AS (WITH deleted_base AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + 'PARTITION BY origin_base."target-feature"\n'
+ + 'ORDER BY origin_base."other-feature" DESC, origin_base."api_invocation_time" DESC, '
+ + 'origin_base."write_time" DESC\n'
+ + ") AS deleted_row_base\n"
+ + 'FROM "database"."base-table" origin_base\n'
+ + "WHERE is_deleted\n"
+ + ")\n"
+ + "WHERE deleted_row_base = 1\n"
+ + ")\n"
+ + 'SELECT table_base."target-feature", table_base."other-feature"\n'
+ + "FROM (\n"
+ + 'SELECT table_base."target-feature", table_base."other-feature", '
+ + 'table_base."write_time"\n'
+ + 'FROM "database"."base-table" table_base\n'
+ + "LEFT JOIN deleted_base\n"
+ + 'ON table_base."target-feature" = deleted_base."target-feature"\n'
+ + 'WHERE deleted_base."target-feature" IS NULL\n'
+ + "UNION ALL\n"
+ + 'SELECT table_base."target-feature", table_base."other-feature", '
+ + 'table_base."write_time"\n'
+ + "FROM deleted_base\n"
+ + 'JOIN "database"."base-table" table_base\n'
+ + 'ON table_base."target-feature" = deleted_base."target-feature"\n'
+ + "AND (\n"
+ + 'table_base."other-feature" > deleted_base."other-feature"\n'
+ + 'OR (table_base."other-feature" = deleted_base."other-feature" AND '
+ + 'table_base."api_invocation_time" > deleted_base."api_invocation_time")\n'
+ + 'OR (table_base."other-feature" = deleted_base."other-feature" AND '
+ + 'table_base."api_invocation_time" = deleted_base."api_invocation_time" AND '
+ + 'table_base."write_time" > deleted_base."write_time")\n'
+ + ")\n"
+ + ") AS table_base\n"
+ + "),\n"
+ + "fg_0 AS (WITH deleted_0 AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + 'PARTITION BY origin_0."feature-1"\n'
+ + 'ORDER BY origin_0."feature-2" DESC, origin_0."api_invocation_time" DESC, '
+ + 'origin_0."write_time" DESC\n'
+ + ") AS deleted_row_0\n"
+ + 'FROM "database"."table-name" origin_0\n'
+ + "WHERE is_deleted\n"
+ + ")\n"
+ + "WHERE deleted_row_0 = 1\n"
+ + ")\n"
+ + 'SELECT table_0."feature-1", table_0."feature-2"\n'
+ + "FROM (\n"
+ + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n'
+ + 'FROM "database"."table-name" table_0\n'
+ + "LEFT JOIN deleted_0\n"
+ + 'ON table_0."feature-1" = deleted_0."feature-1"\n'
+ + 'WHERE deleted_0."feature-1" IS NULL\n'
+ + "UNION ALL\n"
+ + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n'
+ + "FROM deleted_0\n"
+ + 'JOIN "database"."table-name" table_0\n'
+ + 'ON table_0."feature-1" = deleted_0."feature-1"\n'
+ + "AND (\n"
+ + 'table_0."feature-2" > deleted_0."feature-2"\n'
+ + 'OR (table_0."feature-2" = deleted_0."feature-2" AND table_0."api_invocation_time" > '
+ + 'deleted_0."api_invocation_time")\n'
+ + 'OR (table_0."feature-2" = deleted_0."feature-2" AND table_0."api_invocation_time" = '
+ + 'deleted_0."api_invocation_time" AND table_0."write_time" > deleted_0."write_time")\n'
+ + ")\n"
+ + ") AS table_0\n"
+ + ")\n"
+ + 'SELECT target-feature, other-feature, "feature-1.1", "feature-2.1"\n'
+ + "FROM (\n"
+ + 'SELECT fg_base.target-feature, fg_base.other-feature, fg_0."feature-1" as '
+ + '"feature-1.1", fg_0."feature-2" as "feature-2.1", row_number() OVER (\n'
+ + 'PARTITION BY fg_base."target-feature"\n'
+ + 'ORDER BY fg_base."other-feature" DESC, fg_0."feature-2" DESC\n'
+ + ") AS row_recent\n"
+ + "FROM fg_base\n"
+ + "JOIN fg_0\n"
+ + 'ON fg_base."target-feature" = fg_0."feature-1"\n'
+ + ")\n"
+ )
+
+
+def test_construct_query_string(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group,
+ output_path="file/to/path",
+ )
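+ # Default path: a dedup CTE and a deleted-records CTE per table, with the
+ # point-in-time flag adding the trailing event-time join condition.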
+ dataset_builder._point_in_time_accurate_join = True
+ dataset_builder._event_time_identifier_feature_name = "target-feature"
+ dataset_builder._feature_groups_to_be_merged = [FEATURE_GROUP]
+ query_string = dataset_builder._construct_query_string(BASE)
+ assert (
+ query_string
+ == "WITH fg_base AS (WITH table_base AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + 'PARTITION BY origin_base."target-feature", origin_base."other-feature"\n'
+ + 'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n'
+ + ") AS dedup_row_base\n"
+ + 'FROM "database"."base-table" origin_base\n'
+ + ")\n"
+ + "WHERE dedup_row_base = 1\n"
+ + "),\n"
+ + "deleted_base AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + 'PARTITION BY origin_base."target-feature"\n'
+ + 'ORDER BY origin_base."other-feature" DESC, origin_base."api_invocation_time" '
+ + 'DESC, origin_base."write_time" DESC\n'
+ + ") AS deleted_row_base\n"
+ + 'FROM "database"."base-table" origin_base\n'
+ + "WHERE is_deleted\n"
+ + ")\n"
+ + "WHERE deleted_row_base = 1\n"
+ + ")\n"
+ + 'SELECT table_base."target-feature", table_base."other-feature"\n'
+ + "FROM (\n"
+ + 'SELECT table_base."target-feature", table_base."other-feature", '
+ + 'table_base."write_time"\n'
+ + "FROM table_base\n"
+ + "LEFT JOIN deleted_base\n"
+ + 'ON table_base."target-feature" = deleted_base."target-feature"\n'
+ + 'WHERE deleted_base."target-feature" IS NULL\n'
+ + "UNION ALL\n"
+ + 'SELECT table_base."target-feature", table_base."other-feature", '
+ + 'table_base."write_time"\n'
+ + "FROM deleted_base\n"
+ + "JOIN table_base\n"
+ + 'ON table_base."target-feature" = deleted_base."target-feature"\n'
+ + "AND (\n"
+ + 'table_base."other-feature" > deleted_base."other-feature"\n'
+ + 'OR (table_base."other-feature" = deleted_base."other-feature" AND '
+ + 'table_base."api_invocation_time" > deleted_base."api_invocation_time")\n'
+ + 'OR (table_base."other-feature" = deleted_base."other-feature" AND '
+ + 'table_base."api_invocation_time" = deleted_base."api_invocation_time" AND '
+ + 'table_base."write_time" > deleted_base."write_time")\n'
+ + ")\n"
+ + ") AS table_base\n"
+ + "),\n"
+ + "fg_0 AS (WITH table_0 AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + 'PARTITION BY origin_0."feature-1", origin_0."feature-2"\n'
+ + 'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n'
+ + ") AS dedup_row_0\n"
+ + 'FROM "database"."table-name" origin_0\n'
+ + ")\n"
+ + "WHERE dedup_row_0 = 1\n"
+ + "),\n"
+ + "deleted_0 AS (\n"
+ + "SELECT *\n"
+ + "FROM (\n"
+ + "SELECT *, row_number() OVER (\n"
+ + 'PARTITION BY origin_0."feature-1"\n'
+ + 'ORDER BY origin_0."feature-2" DESC, origin_0."api_invocation_time" DESC, '
+ + 'origin_0."write_time" DESC\n'
+ + ") AS deleted_row_0\n"
+ + 'FROM "database"."table-name" origin_0\n'
+ + "WHERE is_deleted\n"
+ + ")\n"
+ + "WHERE deleted_row_0 = 1\n"
+ + ")\n"
+ + 'SELECT table_0."feature-1", table_0."feature-2"\n'
+ + "FROM (\n"
+ + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n'
+ + "FROM table_0\n"
+ + "LEFT JOIN deleted_0\n"
+ + 'ON table_0."feature-1" = deleted_0."feature-1"\n'
+ + 'WHERE deleted_0."feature-1" IS NULL\n'
+ + "UNION ALL\n"
+ + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n'
+ + "FROM deleted_0\n"
+ + "JOIN table_0\n"
+ + 'ON table_0."feature-1" = deleted_0."feature-1"\n'
+ + "AND (\n"
+ + 'table_0."feature-2" > deleted_0."feature-2"\n'
+ + 'OR (table_0."feature-2" = deleted_0."feature-2" AND '
+ + 'table_0."api_invocation_time" > deleted_0."api_invocation_time")\n'
+ + 'OR (table_0."feature-2" = deleted_0."feature-2" AND '
+ + 'table_0."api_invocation_time" = deleted_0."api_invocation_time" AND '
+ + 'table_0."write_time" > deleted_0."write_time")\n'
+ + ")\n"
+ + ") AS table_0\n"
+ + ")\n"
+ + 'SELECT target-feature, other-feature, "feature-1.1", "feature-2.1"\n'
+ + "FROM (\n"
+ + 'SELECT fg_base.target-feature, fg_base.other-feature, fg_0."feature-1" as '
+ + '"feature-1.1", fg_0."feature-2" as "feature-2.1", row_number() OVER (\n'
+ + 'PARTITION BY fg_base."target-feature"\n'
+ + 'ORDER BY fg_base."other-feature" DESC, fg_0."feature-2" DESC\n'
+ + ") AS row_recent\n"
+ + "FROM fg_base\n"
+ + "JOIN fg_0\n"
+ + 'ON fg_base."target-feature" = fg_0."feature-1"\n'
+ + 'AND from_unixtime(fg_base."target-feature") >= from_unixtime(fg_0."feature-2")\n'
+ + ")\n"
+ )
+
+
+def test_create_temp_table(sagemaker_session_mock):
+ dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]})
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=dataframe,
+ output_path="file/to/path",
+ )
+ sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"}
+ sagemaker_session_mock.get_query_execution.return_value = {
+ "QueryExecution": {"Status": {"State": "SUCCEEDED"}}
+ }
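+ # A DataFrame base is exposed through a CSV-backed external Athena table; the
+ # expected DDL below mirrors the OpenCSVSerde settings the builder emits.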
+ dataset_builder._create_temp_table("table-name", "s3-folder")
+ assert sagemaker_session_mock.start_query_execution.call_count == 1
+ sagemaker_session_mock.start_query_execution.assert_called_once_with(
+ catalog="AwsDataCatalog",
+ database="sagemaker_featurestore",
+ query_string="CREATE EXTERNAL TABLE table-name (feature-1 INT, feature-2 INT) "
+ + "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' "
+ + 'WITH SERDEPROPERTIES ("separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\") '
+ + "LOCATION 's3-folder';",
+ output_location="file/to/path",
+ kms_key=None,
+ )
+
+
+@pytest.mark.parametrize(
+ "column, expected",
+ [
+ ("feature-1", "feature-1 STRING"),
+ ("feature-2", "feature-2 INT"),
+ ("feature-3", "feature-3 DOUBLE"),
+ ("feature-4", "feature-4 BOOLEAN"),
+ ("feature-5", "feature-5 TIMESTAMP"),
+ ],
+)
+def test_construct_athena_table_column_string(column, expected, sagemaker_session_mock):
+ dataframe = pd.DataFrame(
+ {
+ "feature-1": ["420"],
+ "feature-2": [50],
+ "feature-3": [5.0],
+ "feature-4": [True],
+ "feature-5": [pd.Timestamp(1513393355)],
+ }
+ )
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=dataframe,
+ output_path="file/to/path",
+ )
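+ # Each pandas dtype in the frame should map to the matching Athena column type.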
+ query_string = dataset_builder._construct_athena_table_column_string(column)
+ assert query_string == expected
+
+
+def test_construct_athena_table_column_string_not_support_column_type(
+ sagemaker_session_mock,
+):
+ dataframe = pd.DataFrame({"feature": pd.Series([1] * 3, dtype="int8")})
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=dataframe,
+ output_path="file/to/path",
+ )
+ with pytest.raises(RuntimeError) as error:
+ dataset_builder._construct_athena_table_column_string("feature")
+ assert "The dataframe type int8 is not supported yet." in str(error)
+
+
+def test_run_query_throw_runtime_error(sagemaker_session_mock, feature_group_mock):
+ dataset_builder = DatasetBuilder(
+ sagemaker_session=sagemaker_session_mock,
+ base=feature_group_mock,
+ output_path="file/to/path",
+ )
+ sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"}
+ sagemaker_session_mock.get_query_execution.return_value = {
+ "QueryExecution": {"Status": {"State": "FAILED"}}
+ }
+ with pytest.raises(RuntimeError) as error:
+ dataset_builder._run_query("query-string", "catalog", "database")
+ assert "Failed to execute query query-id." in str(error)
diff --git a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py
new file mode 100644
index 0000000000..dce38fe426
--- /dev/null
+++ b/tests/unit/sagemaker/feature_store/test_feature_group.py
@@ -0,0 +1,580 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+
+import pandas as pd
+import pytest
+from mock import Mock, patch, MagicMock
+from botocore.exceptions import ProfileNotFound
+
+from sagemaker.feature_store.feature_definition import (
+ FractionalFeatureDefinition,
+ IntegralFeatureDefinition,
+ StringFeatureDefinition,
+ FeatureTypeEnum,
+)
+from sagemaker.feature_store.feature_group import (
+ FeatureGroup,
+ IngestionManagerPandas,
+ AthenaQuery,
+ IngestionError,
+)
+from sagemaker.feature_store.inputs import FeatureParameter
+
+
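+# A Mock subclass that survives pickling: the multi-process ingestion tests
+# hand mocked methods across process boundaries, which a plain Mock cannot do.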
+class PicklableMock(Mock):
+ def __reduce__(self):
+ return (Mock, ())
+
+
+@pytest.fixture
+def role_arn():
+ return "arn:role"
+
+
+@pytest.fixture
+def s3_uri():
+ return "s3://some/uri"
+
+
+@pytest.fixture
+def sagemaker_session_mock():
+ return Mock()
+
+
+@pytest.fixture
+def fs_runtime_client_config_mock():
+ return PicklableMock()
+
+
+@pytest.fixture
+def feature_group_dummy_definitions():
+ return [
+ FractionalFeatureDefinition(feature_name="feature1"),
+ IntegralFeatureDefinition(feature_name="feature2"),
+ StringFeatureDefinition(feature_name="feature3"),
+ ]
+
+
+@pytest.fixture
+def create_table_ddl():
+ return (
+ "CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n"
+ " feature1 FLOAT\n"
+ " feature2 INT\n"
+ " feature3 STRING\n"
+ " write_time TIMESTAMP\n"
+ " event_time TIMESTAMP\n"
+ " is_deleted BOOLEAN\n"
+ ")\n"
+ "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n"
+ " STORED AS\n"
+ " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n"
+ " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n"
+ "LOCATION 's3://resolved_output_s3_uri'"
+ )
+
+
+def test_feature_store_create(
+ sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri
+):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ feature_group.feature_definitions = feature_group_dummy_definitions
+ feature_group.create(
+ s3_uri=s3_uri,
+ record_identifier_name="feature1",
+ event_time_feature_name="feature2",
+ role_arn=role_arn,
+ enable_online_store=True,
+ )
+ sagemaker_session_mock.create_feature_group.assert_called_with(
+ feature_group_name="MyFeatureGroup",
+ record_identifier_name="feature1",
+ event_time_feature_name="feature2",
+ feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions],
+ role_arn=role_arn,
+ description=None,
+ tags=None,
+ online_store_config={"EnableOnlineStore": True},
+ offline_store_config={
+ "DisableGlueTableCreation": False,
+ "S3StorageConfig": {"S3Uri": s3_uri},
+ },
+ )
+
+
+def test_feature_store_create_online_only(
+ sagemaker_session_mock, role_arn, feature_group_dummy_definitions
+):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ feature_group.feature_definitions = feature_group_dummy_definitions
+ feature_group.create(
+ s3_uri=False,
+ record_identifier_name="feature1",
+ event_time_feature_name="feature2",
+ role_arn=role_arn,
+ enable_online_store=True,
+ )
+ sagemaker_session_mock.create_feature_group.assert_called_with(
+ feature_group_name="MyFeatureGroup",
+ record_identifier_name="feature1",
+ event_time_feature_name="feature2",
+ feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions],
+ role_arn=role_arn,
+ description=None,
+ tags=None,
+ online_store_config={"EnableOnlineStore": True},
+ )
+
+
+def test_feature_store_delete(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ feature_group.delete()
+ sagemaker_session_mock.delete_feature_group.assert_called_with(
+ feature_group_name="MyFeatureGroup"
+ )
+
+
+def test_feature_store_describe(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ feature_group.describe()
+ sagemaker_session_mock.describe_feature_group.assert_called_with(
+ feature_group_name="MyFeatureGroup", next_token=None
+ )
+
+
+def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ feature_group.update(feature_group_dummy_definitions)
+ sagemaker_session_mock.update_feature_group.assert_called_with(
+ feature_group_name="MyFeatureGroup",
+ feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions],
+ )
+
+
+def test_feature_metadata_update(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+
+ parameter_additions = [FeatureParameter(key="key1", value="value1")]
+ parameter_removals = ["key2"]
+
+ feature_group.update_feature_metadata(
+ feature_name="Feature1",
+ description="TestDescription",
+ parameter_additions=parameter_additions,
+ parameter_removals=parameter_removals,
+ )
+ sagemaker_session_mock.update_feature_metadata.assert_called_with(
+ feature_group_name="MyFeatureGroup",
+ feature_name="Feature1",
+ description="TestDescription",
+ parameter_additions=[pa.to_dict() for pa in parameter_additions],
+ parameter_removals=parameter_removals,
+ )
+ feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription")
+ sagemaker_session_mock.update_feature_metadata.assert_called_with(
+ feature_group_name="MyFeatureGroup",
+ feature_name="Feature1",
+ description="TestDescription",
+ parameter_additions=[],
+ parameter_removals=[],
+ )
+
+
+def test_feature_metadata_describe(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ feature_group.describe_feature_metadata(feature_name="Feature1")
+ sagemaker_session_mock.describe_feature_metadata.assert_called_with(
+ feature_group_name="MyFeatureGroup", feature_name="Feature1"
+ )
+
+
+def test_get_record(sagemaker_session_mock):
+ feature_group_name = "MyFeatureGroup"
+ feature_names = ["MyFeature1", "MyFeature2"]
+ record_identifier_value_as_string = "1.0"
+ feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session_mock)
+ feature_group.get_record(
+ record_identifier_value_as_string=record_identifier_value_as_string,
+ feature_names=feature_names,
+ )
+ sagemaker_session_mock.get_record.assert_called_with(
+ feature_group_name=feature_group_name,
+ record_identifier_value_as_string=record_identifier_value_as_string,
+ feature_names=feature_names,
+ )
+
+
+def test_put_record(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ feature_group.put_record(record=[])
+ sagemaker_session_mock.put_record.assert_called_with(
+ feature_group_name="MyFeatureGroup", record=[]
+ )
+
+
+def test_delete_record(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
+ record_identifier_value_as_string = "1.0"
+ event_time = "2022-09-14"
+ feature_group.delete_record(
+ record_identifier_value_as_string=record_identifier_value_as_string,
+ event_time=event_time,
+ )
+ sagemaker_session_mock.delete_record.assert_called_with(
+ feature_group_name="MyFeatureGroup",
+ record_identifier_value_as_string=record_identifier_value_as_string,
+ event_time=event_time,
+ )
+
+
+def test_load_feature_definition(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock)
+ df = pd.DataFrame(
+ {
+ "float": pd.Series([2.0], dtype="float64"),
+ "int": pd.Series([2], dtype="int64"),
+ "string": pd.Series(["f1"], dtype="string"),
+ }
+ )
+ feature_definitions = feature_group.load_feature_definitions(data_frame=df)
+ names = [fd.feature_name for fd in feature_definitions]
+ types = [fd.feature_type for fd in feature_definitions]
+ assert names == ["float", "int", "string"]
+ assert types == [
+ FeatureTypeEnum.FRACTIONAL,
+ FeatureTypeEnum.INTEGRAL,
+ FeatureTypeEnum.STRING,
+ ]
+
+
+def test_load_feature_definition_unsupported_types(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock)
+ df = pd.DataFrame(
+ {
+ "float": pd.Series([2.0], dtype="float64"),
+ "int": pd.Series([2], dtype="int64"),
+ "bool": pd.Series([True], dtype="bool"),
+ }
+ )
+ with pytest.raises(ValueError) as error:
+ feature_group.load_feature_definitions(data_frame=df)
+ assert "Failed to infer Feature type based on dtype bool for column bool." in str(error)
+
+
+def test_ingest_zero_processes(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+ df = Mock()
+ with pytest.raises(RuntimeError) as error:
+ feature_group.ingest(data_frame=df, max_workers=1, max_processes=0)
+
+ assert "max_processes must be greater than 0." in str(error)
+
+
+def test_ingest_zero_workers(sagemaker_session_mock):
+ feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+ df = Mock()
+ with pytest.raises(RuntimeError) as error:
+ feature_group.ingest(data_frame=df, max_workers=0, max_processes=1)
+
+ assert "max_workers must be greater than 0." in str(error)
+
+
+@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
+def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock):
+ sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = (
+ fs_runtime_client_config_mock
+ )
+
+ feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+ df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)))
+
+ mock_ingestion_manager_instance = Mock()
+ ingestion_manager_init.return_value = mock_ingestion_manager_instance
+ feature_group.ingest(data_frame=df, max_workers=10)
+
+ ingestion_manager_init.assert_called_once_with(
+ feature_group_name="MyGroup",
+ sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
+ max_workers=10,
+ max_processes=1,
+ profile_name=None,
+ )
+ mock_ingestion_manager_instance.run.assert_called_once_with(
+ data_frame=df, wait=True, timeout=None
+ )
+
+
+@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
+def test_ingest_with_profile_name(
+ ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock
+):
+ sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = (
+ fs_runtime_client_config_mock
+ )
+
+ feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+ df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)))
+
+ mock_ingestion_manager_instance = Mock()
+ ingestion_manager_init.return_value = mock_ingestion_manager_instance
+ feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name")
+
+ ingestion_manager_init.assert_called_once_with(
+ feature_group_name="MyGroup",
+ sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
+ max_workers=10,
+ max_processes=1,
+ profile_name="profile_name",
+ )
+ mock_ingestion_manager_instance.run.assert_called_once_with(
+ data_frame=df, wait=True, timeout=None
+ )
+
+
+def test_as_hive_ddl_with_default_values(
+ create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock
+):
+ sagemaker_session_mock.describe_feature_group.return_value = {
+ "OfflineStoreConfig": {
+ "S3StorageConfig": {
+ "S3Uri": "s3://some-bucket",
+ "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri",
+ }
+ }
+ }
+ sagemaker_session_mock.account_id.return_value = "1234"
+ sagemaker_session_mock.boto_session.region_name = "us-west-2"
+
+ feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+ feature_group.feature_definitions = feature_group_dummy_definitions
+ assert (
+ create_table_ddl.format(
+ database="sagemaker_featurestore",
+ table_name="MyGroup",
+ account="1234",
+ region="us-west-2",
+ feature_group_name="MyGroup",
+ )
+ == feature_group.as_hive_ddl()
+ )
+
+
+def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock):
+ sagemaker_session_mock.describe_feature_group.return_value = {
+ "OfflineStoreConfig": {
+ "S3StorageConfig": {
+ "S3Uri": "s3://some-bucket",
+ "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri",
+ }
+ }
+ }
+ sagemaker_session_mock.account_id.return_value = "1234"
+ sagemaker_session_mock.boto_session.region_name = "us-west-2"
+
+ feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+ feature_group.feature_definitions = feature_group_dummy_definitions
+ assert create_table_ddl.format(
+ database="MyDatabase",
+ table_name="MyTable",
+ account="1234",
+ region="us-west-2",
+ feature_group_name="MyGroup",
+ ) == feature_group.as_hive_ddl(database="MyDatabase", table_name="MyTable")
+
+
+@patch(
+ "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_process",
+ MagicMock(),
+)
+def test_ingestion_manager_run_success(fs_runtime_client_config_mock):
+ df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
+ manager = IngestionManagerPandas(
+ feature_group_name="MyGroup",
+ sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
+ max_workers=10,
+ )
+ manager.run(df)
+
+ manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None)
+
+
+@patch(
+ "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_threaded",
+ PicklableMock(return_value=[]),
+)
+def test_ingestion_manager_run_multi_process_with_multi_thread_success(
+ fs_runtime_client_config_mock,
+):
+ df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
+ manager = IngestionManagerPandas(
+ feature_group_name="MyGroup",
+ sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
+ max_workers=2,
+ max_processes=2,
+ )
+ manager.run(df)
+
+
+@patch(
+ "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
+ MagicMock(return_value=[1]),
+)
+def test_ingestion_manager_run_failure(fs_runtime_client_config_mock):
+ df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
+ manager = IngestionManagerPandas(
+ feature_group_name="MyGroup",
+ sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
+ max_workers=1,
+ )
+
+ with pytest.raises(IngestionError) as error:
+ manager.run(df)
+
+ assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error)
+ assert error.value.failed_rows == [1]
+ assert manager.failed_rows == [1]
+
+
+@patch(
+ "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
+ MagicMock(side_effect=ProfileNotFound(profile="non_exist")),
+)
+def test_ingestion_manager_with_profile_name_run_failure(
+ fs_runtime_client_config_mock,
+):
+ df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
+ manager = IngestionManagerPandas(
+ feature_group_name="MyGroup",
+ sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
+ max_workers=1,
+ profile_name="non_exist",
+ )
+
+ with pytest.raises(ProfileNotFound) as error:
+ manager.run(df)
+
+ assert "The config profile (non_exist) could not be found" in str(error)
+
+
+@patch(
+ "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
+ PicklableMock(return_value=[1]),
+)
+def test_ingestion_manager_run_multi_process_failure():
+ df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
+ manager = IngestionManagerPandas(
+ feature_group_name="MyGroup",
+ sagemaker_fs_runtime_client_config=None,
+ max_workers=2,
+ max_processes=2,
+ )
+
+ with pytest.raises(IngestionError) as error:
+ manager.run(df)
+
+ assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error)
+ assert error.value.failed_rows == [1, 1, 1, 1]
+ assert manager.failed_rows == [1, 1, 1, 1]
+
+
+@pytest.fixture
+def query(sagemaker_session_mock):
+ return AthenaQuery(
+ catalog="catalog",
+ database="database",
+ table_name="table_name",
+ sagemaker_session=sagemaker_session_mock,
+ )
+
+
+def test_athena_query_run(sagemaker_session_mock, query):
+ sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"}
+ query.run(
+ query_string="query", output_location="s3://some-bucket/some-path", workgroup="workgroup"
+ )
+ sagemaker_session_mock.start_query_execution.assert_called_with(
+ catalog="catalog",
+ database="database",
+ query_string="query",
+ output_location="s3://some-bucket/some-path",
+ kms_key=None,
+ workgroup="workgroup",
+ )
+ assert "some-bucket" == query._result_bucket
+ assert "some-path" == query._result_file_prefix
+ assert "query_id" == query._current_query_execution_id
+
+
+def test_athena_query_wait(sagemaker_session_mock, query):
+ query._current_query_execution_id = "query_id"
+ query.wait()
+ sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id")
+
+
+def test_athena_query_get_query_execution(sagemaker_session_mock, query):
+ query._current_query_execution_id = "query_id"
+ query.get_query_execution()
+ sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id")
+
+
+@patch("tempfile.gettempdir", Mock(return_value="tmp"))
+@patch("pandas.read_csv")
+def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query):
+ sagemaker_session_mock.get_query_execution.return_value = {
+ "QueryExecution": {"Status": {"State": "SUCCEEDED"}}
+ }
+ query._current_query_execution_id = "query_id"
+ query._result_bucket = "bucket"
+ query._result_file_prefix = "prefix"
+ query.as_dataframe()
+ sagemaker_session_mock.download_athena_query_result.assert_called_with(
+ bucket="bucket",
+ prefix="prefix",
+ query_execution_id="query_id",
+ filename="tmp/query_id.csv",
+ )
+ read_csv.assert_called_with("tmp/query_id.csv", delimiter=",")
+
+
+@patch("tempfile.gettempdir", Mock(return_value="tmp"))
+def test_athena_query_as_dataframe_query_failed(sagemaker_session_mock, query):
+ sagemaker_session_mock.get_query_execution.return_value = {
+ "QueryExecution": {"Status": {"State": "FAILED"}}
+ }
+ query._current_query_execution_id = "query_id"
+ with pytest.raises(RuntimeError) as error:
+ query.as_dataframe()
+ assert "Failed to execute query query_id" in str(error)
+
+
+@patch("tempfile.gettempdir", Mock(return_value="tmp"))
+def test_athena_query_as_dataframe_query_queued(sagemaker_session_mock, query):
+ sagemaker_session_mock.get_query_execution.return_value = {
+ "QueryExecution": {"Status": {"State": "QUEUED"}}
+ }
+ query._current_query_execution_id = "query_id"
+ with pytest.raises(RuntimeError) as error:
+ query.as_dataframe()
+ assert "Current query query_id is still being executed" in str(error)
+
+
+@patch("tempfile.gettempdir", Mock(return_value="tmp"))
+def test_athena_query_as_dataframe_query_running(sagemaker_session_mock, query):
+ sagemaker_session_mock.get_query_execution.return_value = {
+ "QueryExecution": {"Status": {"State": "RUNNING"}}
+ }
+ query._current_query_execution_id = "query_id"
+ with pytest.raises(RuntimeError) as error:
+ query.as_dataframe()
+ assert "Current query query_id is still being executed" in str(error)
diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py
index 92ba35573c..073daca9ea 100644
--- a/tests/unit/sagemaker/feature_store/test_feature_store.py
+++ b/tests/unit/sagemaker/feature_store/test_feature_store.py
@@ -10,46 +10,17 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
-# language governing permissions and limitations under the License.
from __future__ import absolute_import
+import datetime
import pandas as pd
import pytest
-from mock import Mock, patch, MagicMock
-from botocore.exceptions import ProfileNotFound
-
-from sagemaker.feature_store.feature_definition import (
- FractionalFeatureDefinition,
- IntegralFeatureDefinition,
- StringFeatureDefinition,
- FeatureTypeEnum,
-)
-from sagemaker.feature_store.feature_group import (
- FeatureGroup,
- IngestionManagerPandas,
- AthenaQuery,
- IngestionError,
-)
-from sagemaker.feature_store.inputs import (
- FeatureParameter,
- TableFormatEnum,
-)
-
+from mock import Mock
-class PicklableMock(Mock):
- def __reduce__(self):
- return (Mock, ())
+from sagemaker.feature_store.feature_store import FeatureStore
-
-@pytest.fixture
-def role_arn():
- return "arn:role"
-
-
-@pytest.fixture
-def s3_uri():
- return "s3://some/uri"
+DATAFRAME = pd.DataFrame({"feature_1": [420, 380, 390], "feature_2": [50, 40, 45]})
@pytest.fixture
@@ -58,558 +29,108 @@ def sagemaker_session_mock():
@pytest.fixture
-def fs_runtime_client_config_mock():
- return PicklableMock()
-
-
-@pytest.fixture
-def feature_group_dummy_definitions():
- return [
- FractionalFeatureDefinition(feature_name="feature1"),
- IntegralFeatureDefinition(feature_name="feature2"),
- StringFeatureDefinition(feature_name="feature3"),
- ]
-
-
-@pytest.fixture
-def create_table_ddl():
- return (
- "CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n"
- " feature1 FLOAT\n"
- " feature2 INT\n"
- " feature3 STRING\n"
- " write_time TIMESTAMP\n"
- " event_time TIMESTAMP\n"
- " is_deleted BOOLEAN\n"
- ")\n"
- "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n"
- " STORED AS\n"
- " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n"
- " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n"
- "LOCATION 's3://resolved_output_s3_uri'"
- )
-
-
-def test_feature_store_create(
- sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri
-):
- feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
- feature_group.feature_definitions = feature_group_dummy_definitions
- feature_group.create(
- s3_uri=s3_uri,
- record_identifier_name="feature1",
- event_time_feature_name="feature2",
- role_arn=role_arn,
- enable_online_store=True,
- )
- sagemaker_session_mock.create_feature_group.assert_called_with(
- feature_group_name="MyFeatureGroup",
- record_identifier_name="feature1",
- event_time_feature_name="feature2",
- feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions],
- role_arn=role_arn,
- description=None,
- tags=None,
- online_store_config={"EnableOnlineStore": True},
- offline_store_config={
- "DisableGlueTableCreation": False,
- "S3StorageConfig": {"S3Uri": s3_uri},
- },
- )
-
-
-def test_feature_store_create_iceberg_table_format(
- sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri
-):
- feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
- feature_group.feature_definitions = feature_group_dummy_definitions
- feature_group.create(
- s3_uri=s3_uri,
- record_identifier_name="feature1",
- event_time_feature_name="feature2",
- role_arn=role_arn,
- enable_online_store=True,
- disable_glue_table_creation=False,
- table_format=TableFormatEnum.ICEBERG,
- )
- sagemaker_session_mock.create_feature_group.assert_called_with(
- feature_group_name="MyFeatureGroup",
- record_identifier_name="feature1",
- event_time_feature_name="feature2",
- feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions],
- role_arn=role_arn,
- description=None,
- tags=None,
- online_store_config={"EnableOnlineStore": True},
- offline_store_config={
- "DisableGlueTableCreation": False,
- "TableFormat": "Iceberg",
- "S3StorageConfig": {"S3Uri": s3_uri},
- },
- )
-
-
-def test_feature_store_create_glue_table_format(
- sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri
-):
- feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
- feature_group.feature_definitions = feature_group_dummy_definitions
- feature_group.create(
- s3_uri=s3_uri,
- record_identifier_name="feature1",
- event_time_feature_name="feature2",
- role_arn=role_arn,
- enable_online_store=True,
- disable_glue_table_creation=False,
- table_format=TableFormatEnum.GLUE,
- )
- sagemaker_session_mock.create_feature_group.assert_called_with(
- feature_group_name="MyFeatureGroup",
- record_identifier_name="feature1",
- event_time_feature_name="feature2",
- feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions],
- role_arn=role_arn,
- description=None,
- tags=None,
- online_store_config={"EnableOnlineStore": True},
- offline_store_config={
- "DisableGlueTableCreation": False,
- "TableFormat": "Glue",
- "S3StorageConfig": {"S3Uri": s3_uri},
- },
- )
-
-
-def test_feature_store_create_online_only(
- sagemaker_session_mock, role_arn, feature_group_dummy_definitions
-):
- feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
- feature_group.feature_definitions = feature_group_dummy_definitions
- feature_group.create(
- s3_uri=False,
- record_identifier_name="feature1",
- event_time_feature_name="feature2",
- role_arn=role_arn,
- enable_online_store=True,
- )
- sagemaker_session_mock.create_feature_group.assert_called_with(
- feature_group_name="MyFeatureGroup",
- record_identifier_name="feature1",
- event_time_feature_name="feature2",
- feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions],
- role_arn=role_arn,
- description=None,
- tags=None,
- online_store_config={"EnableOnlineStore": True},
- )
-
-
-def test_feature_store_delete(sagemaker_session_mock):
- feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
- feature_group.delete()
- sagemaker_session_mock.delete_feature_group.assert_called_with(
- feature_group_name="MyFeatureGroup"
- )
-
-
-def test_feature_store_describe(sagemaker_session_mock):
- feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
- feature_group.describe()
- sagemaker_session_mock.describe_feature_group.assert_called_with(
- feature_group_name="MyFeatureGroup", next_token=None
- )
-
-
-def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions):
- feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
- feature_group.update(feature_group_dummy_definitions)
- sagemaker_session_mock.update_feature_group.assert_called_with(
- feature_group_name="MyFeatureGroup",
- feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions],
- )
-
-
-def test_feature_metadata_update(sagemaker_session_mock):
- feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
-
- parameter_additions = [FeatureParameter(key="key1", value="value1")]
- parameter_removals = ["key2"]
-
- feature_group.update_feature_metadata(
- feature_name="Feature1",
- description="TestDescription",
- parameter_additions=parameter_additions,
- parameter_removals=parameter_removals,
- )
- sagemaker_session_mock.update_feature_metadata.assert_called_with(
- feature_group_name="MyFeatureGroup",
- feature_name="Feature1",
- description="TestDescription",
- parameter_additions=[pa.to_dict() for pa in parameter_additions],
- parameter_removals=parameter_removals,
- )
- feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription")
- sagemaker_session_mock.update_feature_metadata.assert_called_with(
- feature_group_name="MyFeatureGroup",
- feature_name="Feature1",
- description="TestDescription",
- parameter_additions=[],
- parameter_removals=[],
- )
-
-
-def test_feature_metadata_describe(sagemaker_session_mock):
- feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
- feature_group.describe_feature_metadata(feature_name="Feature1")
- sagemaker_session_mock.describe_feature_metadata.assert_called_with(
- feature_group_name="MyFeatureGroup", feature_name="Feature1"
- )
-
-
-def test_put_record(sagemaker_session_mock):
- feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock)
- feature_group.put_record(record=[])
- sagemaker_session_mock.put_record.assert_called_with(
- feature_group_name="MyFeatureGroup", record=[]
- )
-
-
-def test_load_feature_definition(sagemaker_session_mock):
- feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock)
- df = pd.DataFrame(
- {
- "float": pd.Series([2.0], dtype="float64"),
- "int": pd.Series([2], dtype="int64"),
- "string": pd.Series(["f1"], dtype="string"),
- }
- )
- feature_definitions = feature_group.load_feature_definitions(data_frame=df)
- names = [fd.feature_name for fd in feature_definitions]
- types = [fd.feature_type for fd in feature_definitions]
- assert names == ["float", "int", "string"]
- assert types == [
- FeatureTypeEnum.FRACTIONAL,
- FeatureTypeEnum.INTEGRAL,
- FeatureTypeEnum.STRING,
- ]
+def feature_group_mock():
+ return Mock()
-def test_load_feature_definition_unsupported_types(sagemaker_session_mock):
- feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock)
- df = pd.DataFrame(
- {
- "float": pd.Series([2.0], dtype="float64"),
- "int": pd.Series([2], dtype="int64"),
- "object": pd.Series(["f1"], dtype="object"),
- }
- )
+def test_minimal_create_dataset(sagemaker_session_mock, feature_group_mock):
+ feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
+ dataset_builder = feature_store.create_dataset(
+ base=feature_group_mock,
+ output_path="file/to/path",
+ )
+ assert dataset_builder._sagemaker_session == sagemaker_session_mock
+ assert dataset_builder._base == feature_group_mock
+ assert dataset_builder._output_path == "file/to/path"
+
+
+def test_complete_create_dataset(sagemaker_session_mock, feature_group_mock):
+ feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
+ dataset_builder = feature_store.create_dataset(
+ base=feature_group_mock,
+ included_feature_names=["feature_1", "feature_2"],
+ output_path="file/to/path",
+ kms_key_id="kms-key-id",
+ )
+ assert dataset_builder._sagemaker_session == sagemaker_session_mock
+ assert dataset_builder._base == feature_group_mock
+ assert dataset_builder._included_feature_names == ["feature_1", "feature_2"]
+ assert dataset_builder._output_path == "file/to/path"
+ assert dataset_builder._kms_key_id == "kms-key-id"
+
+
+def test_create_dataset_with_dataframe(sagemaker_session_mock):
+ feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
+ dataset_builder = feature_store.create_dataset(
+ base=DATAFRAME,
+ record_identifier_feature_name="feature_1",
+ event_time_identifier_feature_name="feature_2",
+ included_feature_names=["feature_1", "feature_2"],
+ output_path="file/to/path",
+ kms_key_id="kms-key-id",
+ )
+ assert dataset_builder._sagemaker_session == sagemaker_session_mock
+ assert dataset_builder._base.equals(DATAFRAME)
+ assert dataset_builder._record_identifier_feature_name == "feature_1"
+ assert dataset_builder._event_time_identifier_feature_name == "feature_2"
+ assert dataset_builder._included_feature_names == ["feature_1", "feature_2"]
+ assert dataset_builder._output_path == "file/to/path"
+ assert dataset_builder._kms_key_id == "kms-key-id"
+
+
+def test_create_dataset_with_dataframe_value_error(sagemaker_session_mock):
+ feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
with pytest.raises(ValueError) as error:
- feature_group.load_feature_definitions(data_frame=df)
- assert "Failed to infer Feature type based on dtype object for column object." in str(error)
-
-
-def test_ingest_zero_processes():
- feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
- df = Mock()
- with pytest.raises(RuntimeError) as error:
- feature_group.ingest(data_frame=df, max_workers=1, max_processes=0)
-
- assert "max_processes must be greater than 0." in str(error)
-
-
-def test_ingest_zero_workers():
- feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
- df = Mock()
- with pytest.raises(RuntimeError) as error:
- feature_group.ingest(data_frame=df, max_workers=0, max_processes=1)
-
- assert "max_workers must be greater than 0." in str(error)
-
-
-@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
-def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock):
- sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = (
- fs_runtime_client_config_mock
- )
-
- feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
- df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)))
-
- mock_ingestion_manager_instance = Mock()
- ingestion_manager_init.return_value = mock_ingestion_manager_instance
- feature_group.ingest(data_frame=df, max_workers=10)
-
- ingestion_manager_init.assert_called_once_with(
- feature_group_name="MyGroup",
- sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
- max_workers=10,
- max_processes=1,
- profile_name=None,
- )
- mock_ingestion_manager_instance.run.assert_called_once_with(
- data_frame=df, wait=True, timeout=None
- )
-
-
-@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas")
-def test_ingest_with_profile_name(
- ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock
-):
- sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = (
- fs_runtime_client_config_mock
- )
-
- feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
- df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300)))
-
- mock_ingestion_manager_instance = Mock()
- ingestion_manager_init.return_value = mock_ingestion_manager_instance
- feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name")
-
- ingestion_manager_init.assert_called_once_with(
- feature_group_name="MyGroup",
- sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
- max_workers=10,
- max_processes=1,
- profile_name="profile_name",
- )
- mock_ingestion_manager_instance.run.assert_called_once_with(
- data_frame=df, wait=True, timeout=None
- )
-
-
-def test_as_hive_ddl_with_default_values(
- create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock
-):
- sagemaker_session_mock.describe_feature_group.return_value = {
- "OfflineStoreConfig": {
- "S3StorageConfig": {
- "S3Uri": "s3://some-bucket",
- "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri",
- }
- }
- }
- sagemaker_session_mock.account_id.return_value = "1234"
- sagemaker_session_mock.boto_session.region_name = "us-west-2"
-
- feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
- feature_group.feature_definitions = feature_group_dummy_definitions
- assert (
- create_table_ddl.format(
- database="sagemaker_featurestore",
- table_name="MyGroup",
- account="1234",
- region="us-west-2",
- feature_group_name="MyGroup",
+ feature_store.create_dataset(
+ base=DATAFRAME,
+ included_feature_names=["feature_1", "feature_2"],
+ output_path="file/to/path",
+ kms_key_id="kms-key-id",
)
- == feature_group.as_hive_ddl()
- )
-
-
-def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock):
- sagemaker_session_mock.describe_feature_group.return_value = {
- "OfflineStoreConfig": {
- "S3StorageConfig": {
- "S3Uri": "s3://some-bucket",
- "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri",
- }
- }
- }
- sagemaker_session_mock.account_id.return_value = "1234"
- sagemaker_session_mock.boto_session.region_name = "us-west-2"
-
- feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
- feature_group.feature_definitions = feature_group_dummy_definitions
- assert create_table_ddl.format(
- database="MyDatabase",
- table_name="MyTable",
- account="1234",
- region="us-west-2",
- feature_group_name="MyGroup",
- ) == feature_group.as_hive_ddl(database="MyDatabase", table_name="MyTable")
-
-
-@patch(
- "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_process",
- MagicMock(),
-)
-def test_ingestion_manager_run_success():
- df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
- manager = IngestionManagerPandas(
- feature_group_name="MyGroup",
- sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
- max_workers=10,
- )
- manager.run(df)
-
- manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None)
-
-
-@patch(
- "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_threaded",
- PicklableMock(return_value=[]),
-)
-def test_ingestion_manager_run_multi_process_with_multi_thread_success(
- fs_runtime_client_config_mock,
-):
- df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
- manager = IngestionManagerPandas(
- feature_group_name="MyGroup",
- sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
- max_workers=2,
- max_processes=2,
- )
- manager.run(df)
-
-
-@patch(
- "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
- MagicMock(return_value=[1]),
-)
-def test_ingestion_manager_run_failure():
- df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
- manager = IngestionManagerPandas(
- feature_group_name="MyGroup",
- sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
- max_workers=1,
- )
-
- with pytest.raises(IngestionError) as error:
- manager.run(df)
-
- assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error)
- assert error.value.failed_rows == [1]
- assert manager.failed_rows == [1]
-
-
-@patch(
- "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
- MagicMock(side_effect=ProfileNotFound(profile="non_exist")),
-)
-def test_ingestion_manager_with_profile_name_run_failure():
- df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
- manager = IngestionManagerPandas(
- feature_group_name="MyGroup",
- sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock,
- max_workers=1,
- profile_name="non_exist",
- )
-
- try:
- manager.run(df)
- except Exception as e:
- assert "The config profile (non_exist) could not be found" in str(e)
-
-
-@patch(
- "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch",
- PicklableMock(return_value=[1]),
-)
-def test_ingestion_manager_run_multi_process_failure():
- df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")})
- manager = IngestionManagerPandas(
- feature_group_name="MyGroup",
- sagemaker_fs_runtime_client_config=None,
- max_workers=2,
- max_processes=2,
- )
-
- with pytest.raises(IngestionError) as error:
- manager.run(df)
-
- assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error)
- assert error.value.failed_rows == [1, 1, 1, 1]
- assert manager.failed_rows == [1, 1, 1, 1]
-
-
-@pytest.fixture
-def query(sagemaker_session_mock):
- return AthenaQuery(
- catalog="catalog",
- database="database",
- table_name="table_name",
- sagemaker_session=sagemaker_session_mock,
- )
-
-
-def test_athena_query_run(sagemaker_session_mock, query):
- WORKGROUP = "workgroup"
- sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"}
- query.run(
- query_string="query", output_location="s3://some-bucket/some-path", workgroup=WORKGROUP
- )
- sagemaker_session_mock.start_query_execution.assert_called_with(
- catalog="catalog",
- database="database",
- query_string="query",
- output_location="s3://some-bucket/some-path",
- kms_key=None,
- workgroup=WORKGROUP,
- )
- assert "some-bucket" == query._result_bucket
- assert "some-path" == query._result_file_prefix
- assert "query_id" == query._current_query_execution_id
-
-
-def test_athena_query_wait(sagemaker_session_mock, query):
- query._current_query_execution_id = "query_id"
- query.wait()
- sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id")
-
-
-def test_athena_query_get_query_execution(sagemaker_session_mock, query):
- query._current_query_execution_id = "query_id"
- query.get_query_execution()
- sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id")
-
-
-@patch("tempfile.gettempdir", Mock(return_value="tmp"))
-@patch("pandas.read_csv")
-def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query):
- sagemaker_session_mock.get_query_execution.return_value = {
- "QueryExecution": {"Status": {"State": "SUCCEEDED"}}
- }
- query._current_query_execution_id = "query_id"
- query._result_bucket = "bucket"
- query._result_file_prefix = "prefix"
- query.as_dataframe()
- sagemaker_session_mock.download_athena_query_result.assert_called_with(
- bucket="bucket",
- prefix="prefix",
- query_execution_id="query_id",
- filename="tmp/query_id.csv",
+ assert (
+ "You must provide a record identifier feature name and an event time identifier feature "
+ + "name if specify DataFrame as base."
+ in str(error)
+ )
+
+
+def test_list_feature_groups_with_no_filter(sagemaker_session_mock):
+ feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
+ feature_store.list_feature_groups()
+ sagemaker_session_mock.list_feature_groups.assert_called_with(
+ name_contains=None,
+ feature_group_status_equals=None,
+ offline_store_status_equals=None,
+ creation_time_after=None,
+ creation_time_before=None,
+ sort_order=None,
+ sort_by=None,
+ max_results=None,
+ next_token=None,
+ )
+
+
+def test_list_feature_groups_with_all_filters(sagemaker_session_mock):
+ feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock)
+ feature_store.list_feature_groups(
+ name_contains="MyFeatureGroup",
+ feature_group_status_equals="Created",
+ offline_store_status_equals="Active",
+ creation_time_after=datetime.datetime(2020, 12, 1),
+ creation_time_before=datetime.datetime(2022, 7, 1),
+ sort_order="Ascending",
+ sort_by="Name",
+ max_results=50,
+ next_token="token",
+ )
+ sagemaker_session_mock.list_feature_groups.assert_called_with(
+ name_contains="MyFeatureGroup",
+ feature_group_status_equals="Created",
+ offline_store_status_equals="Active",
+ creation_time_after=datetime.datetime(2020, 12, 1),
+ creation_time_before=datetime.datetime(2022, 7, 1),
+ sort_order="Ascending",
+ sort_by="Name",
+ max_results=50,
+ next_token="token",
)
- read_csv.assert_called_with("tmp/query_id.csv", delimiter=",")
-
-
-@patch("tempfile.gettempdir", Mock(return_value="tmp"))
-def test_athena_query_as_dataframe_query_failed(sagemaker_session_mock, query):
- sagemaker_session_mock.get_query_execution.return_value = {
- "QueryExecution": {"Status": {"State": "FAILED"}}
- }
- query._current_query_execution_id = "query_id"
- with pytest.raises(RuntimeError) as error:
- query.as_dataframe()
- assert "Failed to execute query query_id" in str(error)
-
-
-@patch("tempfile.gettempdir", Mock(return_value="tmp"))
-def test_athena_query_as_dataframe_query_queued(sagemaker_session_mock, query):
- sagemaker_session_mock.get_query_execution.return_value = {
- "QueryExecution": {"Status": {"State": "QUEUED"}}
- }
- query._current_query_execution_id = "query_id"
- with pytest.raises(RuntimeError) as error:
- query.as_dataframe()
- assert "Current query query_id is still being executed" in str(error)
-
-
-@patch("tempfile.gettempdir", Mock(return_value="tmp"))
-def test_athena_query_as_dataframe_query_running(sagemaker_session_mock, query):
- sagemaker_session_mock.get_query_execution.return_value = {
- "QueryExecution": {"Status": {"State": "RUNNING"}}
- }
- query._current_query_execution_id = "query_id"
- with pytest.raises(RuntimeError) as error:
- query.as_dataframe()
- assert "Current query query_id is still being executed" in str(error)
diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py
index c391d45382..0088e34c58 100644
--- a/tests/unit/sagemaker/huggingface/test_estimator.py
+++ b/tests/unit/sagemaker/huggingface/test_estimator.py
@@ -48,6 +48,7 @@
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
}
diff --git a/tests/unit/sagemaker/image_uris/test_algos.py b/tests/unit/sagemaker/image_uris/test_algos.py
index 454d375b4b..443727094a 100644
--- a/tests/unit/sagemaker/image_uris/test_algos.py
+++ b/tests/unit/sagemaker/image_uris/test_algos.py
@@ -68,10 +68,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "382416733822",
"us-east-2": "404615174143",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "174872318107",
@@ -155,10 +157,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "433757028032",
diff --git a/tests/unit/sagemaker/image_uris/test_sklearn.py b/tests/unit/sagemaker/image_uris/test_sklearn.py
index d0fcbdb300..8563753e8c 100644
--- a/tests/unit/sagemaker/image_uris/test_sklearn.py
+++ b/tests/unit/sagemaker/image_uris/test_sklearn.py
@@ -37,10 +37,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249",
diff --git a/tests/unit/sagemaker/image_uris/test_xgboost.py b/tests/unit/sagemaker/image_uris/test_xgboost.py
index 78ab7e10ee..4d0f9f1dc3 100644
--- a/tests/unit/sagemaker/image_uris/test_xgboost.py
+++ b/tests/unit/sagemaker/image_uris/test_xgboost.py
@@ -35,10 +35,12 @@
"eu-west-3": "749696950732",
"eu-south-1": "257386234256",
"me-south-1": "249704162688",
+ "me-central-1": "272398656194",
"sa-east-1": "855470959533",
"us-east-1": "811284229777",
"us-east-2": "825641698319",
"us-gov-west-1": "226302683700",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "490574956308",
"us-west-1": "632365934929",
"us-west-2": "433757028032",
@@ -67,10 +69,12 @@
"eu-west-3": "659782779980",
"eu-south-1": "978288397137",
"me-south-1": "801668240914",
+ "me-central-1": "272398656194",
"sa-east-1": "737474898029",
"us-east-1": "683313688378",
"us-east-2": "257758044811",
"us-gov-west-1": "414596584902",
+ "us-gov-east-1": "237065988967",
"us-iso-east-1": "833128469047",
"us-west-1": "746614075791",
"us-west-2": "246618743249",
diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py
index 2e7576421f..fea80b7ea9 100644
--- a/tests/unit/sagemaker/tensorflow/test_estimator.py
+++ b/tests/unit/sagemaker/tensorflow/test_estimator.py
@@ -56,6 +56,7 @@
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
}
diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py
index af46cf4360..d35c0a51dd 100644
--- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py
+++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py
@@ -52,6 +52,7 @@
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
}
diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py
index 5aef9316da..7645c4fe23 100644
--- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py
+++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py
@@ -50,6 +50,7 @@
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
}
diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py
new file mode 100644
index 0000000000..0fe2402695
--- /dev/null
+++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py
@@ -0,0 +1,616 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+import logging
+
+import json
+import os
+
+import pytest
+from mock import MagicMock, Mock, patch, ANY
+from packaging.version import Version
+
+from sagemaker import image_uris
+from sagemaker.pytorch import PyTorch, TrainingCompilerConfig
+from sagemaker.pytorch.model import PyTorchModel
+from sagemaker.instance_group import InstanceGroup
+
+from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES
+
+
+DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "..", "data")
+SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
+SERVING_SCRIPT_FILE = "another_dummy_script.py"
+MODEL_DATA = "s3://some/data.tar.gz"
+ENV = {"DUMMY_ENV_VAR": "dummy_value"}
+TIMESTAMP = "2017-11-06-14:14:15.672"
+TIME = 1510006209.073025
+BUCKET_NAME = "mybucket"
+INSTANCE_COUNT = 1
+INSTANCE_TYPE = "ml.p3.2xlarge"
+IMAGE_URI = "pytorch"
+JOB_NAME = "{}-{}".format(IMAGE_URI, TIMESTAMP)
+ROLE = "Dummy"
+REGION = "us-east-1"
+GPU = "ml.p3.2xlarge"
+SUPPORTED_GPU_INSTANCE_CLASSES = {"p3", "p3dn", "g4dn", "p4d", "g5"}
+UNSUPPORTED_GPU_INSTANCE_CLASSES = EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES
+
+LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}
+
+EXPERIMENT_CONFIG = {
+ "ExperimentName": "exp",
+ "TrialName": "trial",
+ "TrialComponentDisplayName": "tc",
+}
+
+
+@pytest.fixture(scope="module")
+def cpu_instance_type():
+ return "ml.m5.xlarge"
+
+
+@pytest.fixture(name="sagemaker_session", scope="function")
+def fixture_sagemaker_session():
+ boto_mock = Mock(name="boto_session", region_name=REGION)
+ session = Mock(
+ name="sagemaker_session",
+ boto_session=boto_mock,
+ boto_region_name=REGION,
+ config=None,
+ local_mode=False,
+ s3_resource=None,
+ s3_client=None,
+ )
+
+ describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
+ session.sagemaker_client.describe_training_job = Mock(return_value=describe)
+ session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT)
+ session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
+ session.expand_role = Mock(name="expand_role", return_value=ROLE)
+ return session
+
+
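+# Helper: resolve the pytorch-training-compiler DLC image URI for the test
+# region, so expected job arguments use the same image the estimator resolves.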
+def _get_full_gpu_image_uri(version, instance_type, training_compiler_config):
+ return image_uris.retrieve(
+ "pytorch-training-compiler",
+ REGION,
+ version=version,
+ py_version="py38",
+ instance_type=instance_type,
+ image_scope="training",
+ container_version=None,
+ training_compiler_config=training_compiler_config,
+ )
+
+
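+# Helper: build the keyword arguments the estimator is expected to pass to
+# sagemaker_session.train(), for comparison against the captured mock call.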
+def _create_train_job(version, instance_type, training_compiler_config, instance_count=1):
+ return {
+ "image_uri": _get_full_gpu_image_uri(version, instance_type, training_compiler_config),
+ "input_mode": "File",
+ "input_config": [
+ {
+ "ChannelName": "training",
+ "DataSource": {
+ "S3DataSource": {
+ "S3DataDistributionType": "FullyReplicated",
+ "S3DataType": "S3Prefix",
+ }
+ },
+ }
+ ],
+ "role": ROLE,
+ "job_name": JOB_NAME,
+ "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)},
+ "resource_config": {
+ "InstanceType": instance_type,
+ "InstanceCount": instance_count,
+ "VolumeSizeInGB": 30,
+ },
+ "hyperparameters": {
+ "sagemaker_program": json.dumps("dummy_script.py"),
+ "sagemaker_container_log_level": str(logging.INFO),
+ "sagemaker_job_name": json.dumps(JOB_NAME),
+ "sagemaker_submit_directory": json.dumps(
+ "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME)
+ ),
+ "sagemaker_region": '"us-east-1"',
+ },
+ "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+ "tags": None,
+ "vpc_config": None,
+ "metric_definitions": None,
+ "environment": None,
+ "retry_strategy": None,
+ "experiment_config": EXPERIMENT_CONFIG,
+ "debugger_hook_config": {
+ "CollectionConfigurations": [],
+ "S3OutputPath": "s3://{}/".format(BUCKET_NAME),
+ },
+ "profiler_rule_configs": [
+ {
+ "RuleConfigurationName": "ProfilerReport-1510006209",
+ "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest",
+ "RuleParameters": {"rule_to_invoke": "ProfilerReport"},
+ }
+ ],
+ "profiler_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)},
+ }
+
+
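+# A bring-your-own-container image URI cannot be combined with a training
+# compiler config; fit() should fail fast with ValueError.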
+def test_unsupported_BYOC(
+ pytorch_training_compiler_version,
+):
+ byoc = (
+ "1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:"
+ "1.12.0-"
+ "gpu-"
+ "py38-cu113-ubuntu20.04"
+ )
+ with pytest.raises(ValueError):
+ PyTorch(
+ image_uri=byoc,
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ instance_count=INSTANCE_COUNT,
+ instance_type=INSTANCE_TYPE,
+ framework_version=pytorch_training_compiler_version,
+ enable_sagemaker_metrics=False,
+ compiler_config=TrainingCompilerConfig(),
+ ).fit()
+
+
+def test_unsupported_cpu_instance(cpu_instance_type, pytorch_training_compiler_version):
+ with pytest.raises(ValueError):
+ PyTorch(
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ instance_count=INSTANCE_COUNT,
+ instance_type=cpu_instance_type,
+ framework_version=pytorch_training_compiler_version,
+ enable_sagemaker_metrics=False,
+ compiler_config=TrainingCompilerConfig(),
+ ).fit()
+
+
+@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES)
+def test_unsupported_gpu_instance(
+ unsupported_gpu_instance_class, pytorch_training_compiler_version
+):
+ with pytest.raises(ValueError):
+ PyTorch(
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ instance_count=INSTANCE_COUNT,
+ instance_type=f"ml.{unsupported_gpu_instance_class}.xlarge",
+ framework_version=pytorch_training_compiler_version,
+ enable_sagemaker_metrics=False,
+ compiler_config=TrainingCompilerConfig(),
+ ).fit()
+
+
+@pytest.mark.xfail(reason="With only 1 supported version, user input is ignored.")
+def test_unsupported_framework_version():
+ with pytest.raises(ValueError):
+ PyTorch(
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ instance_count=INSTANCE_COUNT,
+ instance_type=INSTANCE_TYPE,
+ framework_version="99.99.99",
+ enable_sagemaker_metrics=False,
+ compiler_config=TrainingCompilerConfig(),
+ ).fit()
+
+
+def test_unsupported_python_2(
+ pytorch_training_compiler_version,
+):
+ with pytest.raises(ValueError):
+ PyTorch(
+ py_version="py27",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ instance_count=INSTANCE_COUNT,
+ instance_type=INSTANCE_TYPE,
+ framework_version=pytorch_training_compiler_version,
+ enable_sagemaker_metrics=False,
+ compiler_config=TrainingCompilerConfig(),
+ ).fit()
+
+
+def test_unsupported_instance_group(
+ pytorch_training_compiler_version,
+):
+ if Version(pytorch_training_compiler_version) < Version("1.12"):
+ pytest.skip("This test is intended for PyTorch 1.12 and above")
+ with pytest.raises(ValueError):
+ PyTorch(
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ instance_groups=[
+ InstanceGroup("ml.p3dn.24xlarge", "ml.p3dn.24xlarge", 16),
+ InstanceGroup("ml.p4d.24xlarge", "ml.p4d.24xlarge", 16),
+ ],
+ framework_version=pytorch_training_compiler_version,
+ enable_sagemaker_metrics=False,
+ compiler_config=TrainingCompilerConfig(),
+ ).fit()
+
+
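+# Only the compiler's native PyTorch/XLA launcher is supported; other
+# distribution strategies and invalid version combinations must raise
+# ValueError.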
+def test_unsupported_distribution(
+ pytorch_training_compiler_version,
+):
+ if Version(pytorch_training_compiler_version) < Version("1.12"):
+ pytest.skip("This test is intended for PyTorch 1.12 and above")
+ with pytest.raises(ValueError):
+ PyTorch(
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ instance_count=2,
+ instance_type=INSTANCE_TYPE,
+ framework_version=pytorch_training_compiler_version,
+ enable_sagemaker_metrics=False,
+ compiler_config=TrainingCompilerConfig(),
+ distribution={"smdistributed": {"dataparallel": {"enabled": True}}},
+ ).fit()
+
+ with pytest.raises(ValueError):
+ PyTorch(
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ instance_count=2,
+ instance_type=INSTANCE_TYPE,
+ transformers_version="4.17",
+ pytorch_version="1.10",
+ enable_sagemaker_metrics=False,
+ compiler_config=TrainingCompilerConfig(),
+ distribution={"pytorchxla": {"enabled": True}},
+ ).fit()
+
+ with pytest.raises(ValueError):
+ PyTorch(
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ instance_count=2,
+ instance_type=INSTANCE_TYPE,
+ framework_version=pytorch_training_compiler_version,
+ enable_sagemaker_metrics=False,
+ compiler_config=TrainingCompilerConfig(),
+ distribution={"mpi": {"enabled": True}},
+ ).fit()
+
+
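+# Happy path: pytorchxla distribution should set the XLA launcher
+# hyperparameter in addition to the compiler flags.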
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch("sagemaker.utils.create_tar_file", MagicMock())
+@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
+@patch("time.time", return_value=TIME)
+@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
+def test_pytorchxla_distribution(
+ time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
+):
+ if Version(pytorch_training_compiler_version) < Version("1.12"):
+ pytest.skip("This test is intended for PyTorch 1.12 and above")
+ compiler_config = TrainingCompilerConfig()
+ instance_type = f"ml.{instance_class}.xlarge"
+
+ pt = PyTorch(
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ sagemaker_session=sagemaker_session,
+ instance_count=2,
+ instance_type=instance_type,
+ framework_version=pytorch_training_compiler_version,
+ enable_sagemaker_metrics=False,
+ compiler_config=compiler_config,
+ distribution={"pytorchxla": {"enabled": True}},
+ )
+
+ inputs = "s3://mybucket/train"
+
+ pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG)
+
+ sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
+ assert sagemaker_call_names == ["train", "logs_for_job"]
+ boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
+ assert boto_call_names == ["resource"]
+
+ expected_train_args = _create_train_job(
+ pytorch_training_compiler_version, instance_type, compiler_config, instance_count=2
+ )
+ expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+ expected_train_args["enable_sagemaker_metrics"] = False
+ expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+ True
+ )
+ expected_train_args["hyperparameters"][PyTorch.LAUNCH_PT_XLA_ENV_NAME] = json.dumps(True)
+ expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+ False
+ )
+
+ actual_train_args = sagemaker_session.method_calls[0][2]
+ assert (
+ actual_train_args == expected_train_args
+ ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
+
+
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch("sagemaker.utils.create_tar_file", MagicMock())
+@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
+@patch("time.time", return_value=TIME)
+@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES)
+def test_default_compiler_config(
+ time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class
+):
+ compiler_config = TrainingCompilerConfig()
+ instance_type = f"ml.{instance_class}.xlarge"
+
+ pt = PyTorch(
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ sagemaker_session=sagemaker_session,
+ instance_count=INSTANCE_COUNT,
+ instance_type=instance_type,
+ framework_version=pytorch_training_compiler_version,
+ enable_sagemaker_metrics=False,
+ compiler_config=compiler_config,
+ )
+
+ inputs = "s3://mybucket/train"
+
+ pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG)
+
+ sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
+ assert sagemaker_call_names == ["train", "logs_for_job"]
+ boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
+ assert boto_call_names == ["resource"]
+
+ expected_train_args = _create_train_job(
+ pytorch_training_compiler_version, instance_type, compiler_config
+ )
+ expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+ expected_train_args["enable_sagemaker_metrics"] = False
+ expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+ True
+ )
+ expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+ False
+ )
+
+ actual_train_args = sagemaker_session.method_calls[0][2]
+ assert (
+ actual_train_args == expected_train_args
+ ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
+
+
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch("sagemaker.utils.create_tar_file", MagicMock())
+@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
+@patch("time.time", return_value=TIME)
+def test_debug_compiler_config(
+ time, name_from_base, sagemaker_session, pytorch_training_compiler_version
+):
+ compiler_config = TrainingCompilerConfig(debug=True)
+
+ pt = PyTorch(
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ sagemaker_session=sagemaker_session,
+ instance_count=INSTANCE_COUNT,
+ instance_type=INSTANCE_TYPE,
+ framework_version=pytorch_training_compiler_version,
+ enable_sagemaker_metrics=False,
+ compiler_config=compiler_config,
+ )
+
+ inputs = "s3://mybucket/train"
+
+ pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG)
+
+ sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
+ assert sagemaker_call_names == ["train", "logs_for_job"]
+ boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
+ assert boto_call_names == ["resource"]
+
+ expected_train_args = _create_train_job(
+ pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
+ )
+ expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+ expected_train_args["enable_sagemaker_metrics"] = False
+ expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+ True
+ )
+ expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+ True
+ )
+
+ actual_train_args = sagemaker_session.method_calls[0][2]
+ assert (
+ actual_train_args == expected_train_args
+ ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
+
+
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch("sagemaker.utils.create_tar_file", MagicMock())
+@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME)
+@patch("time.time", return_value=TIME)
+def test_disable_compiler_config(
+ time, name_from_base, sagemaker_session, pytorch_training_compiler_version
+):
+ compiler_config = TrainingCompilerConfig(enabled=False)
+
+ pt = PyTorch(
+ py_version="py38",
+ entry_point=SCRIPT_PATH,
+ role=ROLE,
+ sagemaker_session=sagemaker_session,
+ instance_count=INSTANCE_COUNT,
+ instance_type=INSTANCE_TYPE,
+ framework_version=pytorch_training_compiler_version,
+ enable_sagemaker_metrics=False,
+ compiler_config=compiler_config,
+ )
+
+ inputs = "s3://mybucket/train"
+
+ pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG)
+
+ sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls]
+ assert sagemaker_call_names == ["train", "logs_for_job"]
+ boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls]
+ assert boto_call_names == ["resource"]
+
+ expected_train_args = _create_train_job(
+ pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config
+ )
+ expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
+ expected_train_args["enable_sagemaker_metrics"] = False
+ expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps(
+ False
+ )
+ expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps(
+ False
+ )
+
+ actual_train_args = sagemaker_session.method_calls[0][2]
+ assert (
+ actual_train_args == expected_train_args
+ ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}"
+
+
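+# Attaching to a completed training job should reconstruct the estimator,
+# including the compiler hyperparameters, from the job description alone.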
+@pytest.mark.parametrize(
+ ["compiler_enabled", "debug_enabled"], [(True, False), (True, True), (False, False)]
+)
+def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
+ training_image = (
+ "1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:"
+ "1.12.0-"
+ "gpu-"
+ "py38-cu113-ubuntu20.04"
+ )
+ returned_job_description = {
+ "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image},
+ "HyperParameters": {
+ "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"',
+ "sagemaker_program": '"iris-dnn-classifier.py"',
+ "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"',
+ "sagemaker_container_log_level": '"logging.INFO"',
+ "sagemaker_job_name": '"trcomp"',
+ "training_steps": "100",
+ "sagemaker_region": '"us-east-1"',
+ TrainingCompilerConfig.HP_ENABLE_COMPILER: json.dumps(compiler_enabled),
+ TrainingCompilerConfig.HP_ENABLE_DEBUG: json.dumps(debug_enabled),
+ },
+ "RoleArn": "arn:aws:iam::366:role/SageMakerRole",
+ "ResourceConfig": {
+ "VolumeSizeInGB": 30,
+ "InstanceCount": 1,
+ "InstanceType": "ml.p3.2xlarge",
+ },
+ "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60},
+ "TrainingJobName": "trcomp",
+ "TrainingJobStatus": "Completed",
+ "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/trcomp",
+ "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/trcomp"},
+ "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
+ }
+ sagemaker_session.sagemaker_client.describe_training_job = Mock(
+ name="describe_training_job", return_value=returned_job_description
+ )
+
+ estimator = PyTorch.attach(training_job_name="trcomp", sagemaker_session=sagemaker_session)
+ assert estimator.latest_training_job.job_name == "trcomp"
+ assert estimator.py_version == "py38"
+ assert estimator.framework_version == "1.12.0"
+ assert estimator.role == "arn:aws:iam::366:role/SageMakerRole"
+ assert estimator.instance_count == 1
+ assert estimator.max_run == 24 * 60 * 60
+ assert estimator.input_mode == "File"
+ assert estimator.base_job_name == "trcomp"
+ assert estimator.output_path == "s3://place/output/trcomp"
+ assert estimator.output_kms_key == ""
+ assert estimator.hyperparameters()["training_steps"] == "100"
+ assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_COMPILER] == json.dumps(
+ compiler_enabled
+ )
+ assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_DEBUG] == json.dumps(
+ debug_enabled
+ )
+ assert estimator.source_dir == "s3://some/sourcedir.tar.gz"
+ assert estimator.entry_point == "iris-dnn-classifier.py"
+
+
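+# register() without an explicit framework should infer "PYTORCH" and the
+# framework version when building the model package request.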
+@patch("sagemaker.utils.repack_model", MagicMock())
+@patch("sagemaker.utils.create_tar_file", MagicMock())
+def test_register_pytorch_model_auto_infer_framework(
+ sagemaker_session, pytorch_training_compiler_version
+):
+
+ model_package_group_name = "test-pt-register-model"
+ content_types = ["application/json"]
+ response_types = ["application/json"]
+ inference_instances = ["ml.m4.xlarge"]
+ transform_instances = ["ml.m4.xlarge"]
+ image_uri = "fakeimage"
+
+ pt_model = PyTorchModel(
+ model_data="s3://some/data.tar.gz",
+ role=ROLE,
+ entry_point=SCRIPT_PATH,
+ framework_version=pytorch_training_compiler_version,
+ py_version="py38",
+ sagemaker_session=sagemaker_session,
+ )
+
+ pt_model.register(
+ content_types,
+ response_types,
+ inference_instances,
+ transform_instances,
+ model_package_group_name=model_package_group_name,
+ marketplace_cert=True,
+ image_uri=image_uri,
+ )
+
+ expected_create_model_package_request = {
+ "containers": [
+ {
+ "Image": image_uri,
+ "Environment": ANY,
+ "ModelDataUrl": ANY,
+ "Framework": "PYTORCH",
+ "FrameworkVersion": pytorch_training_compiler_version,
+ }
+ ],
+ "content_types": content_types,
+ "response_types": response_types,
+ "inference_instances": inference_instances,
+ "transform_instances": transform_instances,
+ "model_package_group_name": model_package_group_name,
+ "marketplace_cert": True,
+ }
+
+ sagemaker_session.create_model_package_from_containers.assert_called_with(
+ **expected_create_model_package_request
+ )
diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py
index 7517f3a641..1ce58a19b4 100644
--- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py
+++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py
@@ -50,6 +50,7 @@
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
}
diff --git a/tests/unit/sagemaker/utilities/test_search_expression.py b/tests/unit/sagemaker/utilities/test_search_expression.py
new file mode 100644
index 0000000000..98a52a992a
--- /dev/null
+++ b/tests/unit/sagemaker/utilities/test_search_expression.py
@@ -0,0 +1,80 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import pytest
+
+from sagemaker.utilities.search_expression import (
+ Filter,
+ Operator,
+ NestedFilter,
+ SearchExpression,
+ BooleanOperator,
+)
+
+
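+# Each search-expression entity serializes to its boto request shape via
+# to_boto(); unset fields are omitted rather than sent as None.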
+def test_filters():
+ search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1")
+
+ assert {
+ "Name": "learning_rate",
+ "Operator": "Equals",
+ "Value": "0.1",
+ } == search_filter.to_boto()
+
+
+def test_partial_filters():
+ search_filter = Filter(name="learning_rate")
+
+ assert {"Name": "learning_rate"} == search_filter.to_boto()
+
+
+def test_nested_filters():
+ search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1")
+ filters = [search_filter]
+ nested_filters = NestedFilter(property_name="hyper_param", filters=filters)
+
+ assert {
+ "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}],
+ "NestedPropertyName": "hyper_param",
+ } == nested_filters.to_boto()
+
+
+def test_search_expression():
+ search_filter = Filter(name="learning_rate", operator=Operator.EQUALS, value="0.1")
+ nested_filter = NestedFilter(property_name="hyper_param", filters=[search_filter])
+ search_expression = SearchExpression(
+ filters=[search_filter],
+ nested_filters=[nested_filter],
+ sub_expressions=[],
+ boolean_operator=BooleanOperator.AND,
+ )
+
+ assert {
+ "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}],
+ "NestedFilters": [
+ {
+ "Filters": [{"Name": "learning_rate", "Operator": "Equals", "Value": "0.1"}],
+ "NestedPropertyName": "hyper_param",
+ }
+ ],
+ "SubExpressions": [],
+ "Operator": "And",
+ } == search_expression.to_boto()
+
+
+def test_illegal_search_expression():
+ with pytest.raises(
+ ValueError, match="You must specify at least one subexpression, filter, or nested filter"
+ ):
+ SearchExpression()
diff --git a/tests/unit/sagemaker/workflow/conftest.py b/tests/unit/sagemaker/workflow/conftest.py
new file mode 100644
index 0000000000..9ea3d0bcac
--- /dev/null
+++ b/tests/unit/sagemaker/workflow/conftest.py
@@ -0,0 +1,75 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+# http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+from unittest.mock import Mock, PropertyMock
+
+import pytest
+
+from sagemaker import Session
+from sagemaker.workflow.pipeline_context import PipelineSession
+
+REGION = "us-west-2"
+BUCKET = "my-bucket"
+ROLE = "DummyRole"
+IMAGE_URI = "fakeimage"
+
+
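+# Shared, module-scoped mock fixtures used by the workflow unit tests so every
+# test module gets the same boto/SageMaker session test doubles.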
+@pytest.fixture(scope="module")
+def client():
+ """Mock client.
+
+ Considerations when appropriate:
+
+ * utilize botocore.stub.Stubber
+ * separate runtime client from client
+ """
+ client_mock = Mock()
+ client_mock._client_config.user_agent = (
+ "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
+ )
+ return client_mock
+
+
+@pytest.fixture(scope="module")
+def boto_session(client):
+ role_mock = Mock()
+ type(role_mock).arn = PropertyMock(return_value=ROLE)
+
+ resource_mock = Mock()
+ resource_mock.Role.return_value = role_mock
+
+ session_mock = Mock(region_name=REGION)
+ session_mock.resource.return_value = resource_mock
+ session_mock.client.return_value = client
+
+ return session_mock
+
+
+@pytest.fixture(scope="module")
+def pipeline_session(boto_session, client):
+ return PipelineSession(
+ boto_session=boto_session,
+ sagemaker_client=client,
+ default_bucket=BUCKET,
+ )
+
+
+@pytest.fixture(scope="module")
+def sagemaker_session(boto_session, client):
+ return Session(
+ boto_session=boto_session,
+ sagemaker_client=client,
+ sagemaker_runtime_client=client,
+ default_bucket=BUCKET,
+ )
diff --git a/tests/unit/sagemaker/workflow/test_clarify_check_step.py b/tests/unit/sagemaker/workflow/test_clarify_check_step.py
index feadaa03dc..54b354b71e 100644
--- a/tests/unit/sagemaker/workflow/test_clarify_check_step.py
+++ b/tests/unit/sagemaker/workflow/test_clarify_check_step.py
@@ -16,10 +16,6 @@
import re
import pytest
-import sagemaker
-
-from mock import Mock, PropertyMock
-
from sagemaker.clarify import (
DataConfig,
BiasConfig,
@@ -50,46 +46,6 @@
_S3_ANALYSIS_CONFIG_OUTPUT_PATH = "s3://my_bucket/analysis_cfg_output"
-@pytest.fixture
-def boto_session():
- role_mock = Mock()
- type(role_mock).arn = PropertyMock(return_value=_ROLE)
-
- resource_mock = Mock()
- resource_mock.Role.return_value = role_mock
-
- session_mock = Mock(region_name=_REGION)
- session_mock.resource.return_value = resource_mock
-
- return session_mock
-
-
-@pytest.fixture
-def client():
- """Mock client.
-
- Considerations when appropriate:
-
- * utilize botocore.stub.Stubber
- * separate runtime client from client
- """
- client_mock = Mock()
- client_mock._client_config.user_agent = (
- "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
- )
- return client_mock
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
- return sagemaker.session.Session(
- boto_session=boto_session,
- sagemaker_client=client,
- sagemaker_runtime_client=client,
- default_bucket=_DEFAULT_BUCKET,
- )
-
-
_expected_data_bias_dsl = {
"Name": "DataBiasCheckStep",
"Type": "ClarifyCheck",
diff --git a/tests/unit/sagemaker/workflow/test_entities.py b/tests/unit/sagemaker/workflow/test_entities.py
index 6f0be2ccca..a36207b241 100644
--- a/tests/unit/sagemaker/workflow/test_entities.py
+++ b/tests/unit/sagemaker/workflow/test_entities.py
@@ -19,9 +19,6 @@
from enum import Enum
-from mock.mock import Mock, PropertyMock
-
-import sagemaker
from sagemaker.workflow.condition_step import ConditionStep
from sagemaker.workflow.conditions import ConditionGreaterThan
from sagemaker.workflow.entities import (
@@ -58,46 +55,6 @@ def custom_entity_list():
return [CustomEntity(1), CustomEntity(2)]
-@pytest.fixture
-def boto_session():
- role_mock = Mock()
- type(role_mock).arn = PropertyMock(return_value="role")
-
- resource_mock = Mock()
- resource_mock.Role.return_value = role_mock
-
- session_mock = Mock(region_name="us-west-2")
- session_mock.resource.return_value = resource_mock
-
- return session_mock
-
-
-@pytest.fixture
-def client():
- """Mock client.
-
- Considerations when appropriate:
-
- * utilize botocore.stub.Stubber
- * separate runtime client from client
- """
- client_mock = Mock()
- client_mock._client_config.user_agent = (
- "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
- )
- return client_mock
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
- return sagemaker.session.Session(
- boto_session=boto_session,
- sagemaker_client=client,
- sagemaker_runtime_client=client,
- default_bucket="my-bucket",
- )
-
-
def test_entity(custom_entity):
request_struct = {"foo": 1}
assert custom_entity.to_request() == request_struct
diff --git a/tests/unit/sagemaker/workflow/test_model_step.py b/tests/unit/sagemaker/workflow/test_model_step.py
index 080e70ca62..2216299d3b 100644
--- a/tests/unit/sagemaker/workflow/test_model_step.py
+++ b/tests/unit/sagemaker/workflow/test_model_step.py
@@ -15,7 +15,7 @@
import json
import os
-from mock import Mock, PropertyMock, patch
+from mock import patch
import pytest
@@ -43,7 +43,6 @@
)
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
-from sagemaker.workflow.pipeline_context import PipelineSession
from sagemaker.workflow.retry import (
StepRetryPolicy,
StepExceptionTypeEnum,
@@ -55,11 +54,9 @@
from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum
from tests.unit import DATA_DIR
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered
+from tests.unit.sagemaker.workflow.conftest import BUCKET, ROLE
_IMAGE_URI = "fakeimage"
-_REGION = "us-west-2"
-_BUCKET = "my-bucket"
-_ROLE = "DummyRole"
_INSTANCE_TYPE = "ml.m4.xlarge"
_SAGEMAKER_PROGRAM = SCRIPT_PARAM_NAME.upper()
@@ -69,60 +66,10 @@
_XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone")
_TENSORFLOW_PATH = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies")
_REPACK_OUTPUT_KEY_PREFIX = "code-output"
-_MODEL_CODE_LOCATION = f"s3://{_BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}"
+_MODEL_CODE_LOCATION = f"s3://{BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}"
_MODEL_CODE_LOCATION_TRAILING_SLASH = _MODEL_CODE_LOCATION + "/"
-@pytest.fixture
-def client():
- """Mock client.
-
- Considerations when appropriate:
-
- * utilize botocore.stub.Stubber
- * separate runtime client from client
- """
- client_mock = Mock()
- client_mock._client_config.user_agent = (
- "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
- )
- return client_mock
-
-
-@pytest.fixture
-def boto_session(client):
- role_mock = Mock()
- type(role_mock).arn = PropertyMock(return_value=_ROLE)
-
- resource_mock = Mock()
- resource_mock.Role.return_value = role_mock
-
- session_mock = Mock(region_name=_REGION)
- session_mock.resource.return_value = resource_mock
- session_mock.client.return_value = client
-
- return session_mock
-
-
-@pytest.fixture
-def pipeline_session(boto_session, client):
- return PipelineSession(
- boto_session=boto_session,
- sagemaker_client=client,
- default_bucket=_BUCKET,
- )
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
- return Session(
- boto_session=boto_session,
- sagemaker_client=client,
- sagemaker_runtime_client=client,
- default_bucket=_BUCKET,
- )
-
-
@pytest.fixture
def model_data_param():
return ParameterString(name="ModelData", default_value="s3://my-bucket/file")
@@ -137,7 +84,7 @@ def model(pipeline_session, model_data_param):
sagemaker_session=pipeline_session,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
source_dir=f"{DATA_DIR}",
- role=_ROLE,
+ role=ROLE,
)
@@ -322,13 +269,13 @@ def test_create_pipeline_model_with_runtime_repack(pipeline_session, model_data_
sparkml_model = SparkMLModel(
name="MySparkMLModel",
model_data=model_data_param,
- role=_ROLE,
+ role=ROLE,
sagemaker_session=pipeline_session,
env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
)
 # The model needs runtime repack
ppl_model = PipelineModel(
- models=[sparkml_model, model], role=_ROLE, sagemaker_session=pipeline_session
+ models=[sparkml_model, model], role=ROLE, sagemaker_session=pipeline_session
)
step_args = ppl_model.create(
instance_type="c4.4xlarge",
@@ -417,7 +364,7 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat
 # The model does not need runtime repack, since source_dir is missing
sparkml_model = SparkMLModel(
model_data=model_data_param,
- role=_ROLE,
+ role=ROLE,
sagemaker_session=pipeline_session,
env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
@@ -429,11 +376,11 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat
sagemaker_session=pipeline_session,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
source_dir=f"{DATA_DIR}",
- role=_ROLE,
+ role=ROLE,
env={"k": "v"},
)
model = PipelineModel(
- models=[sparkml_model, model], role=_ROLE, sagemaker_session=pipeline_session
+ models=[sparkml_model, model], role=ROLE, sagemaker_session=pipeline_session
)
step_args = model.register(
content_types=["text/csv"],
@@ -516,7 +463,7 @@ def test_register_model_without_repack(pipeline_session):
model_data=model_data,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
sagemaker_session=pipeline_session,
- role=_ROLE,
+ role=ROLE,
)
step_args = model.register(
content_types=["text/csv"],
@@ -547,7 +494,7 @@ def test_register_model_without_repack(pipeline_session):
assert containers[0]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME
assert (
containers[0]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY]
- == f"s3://{_BUCKET}/{model_name}/sourcedir.tar.gz"
+ == f"s3://{BUCKET}/{model_name}/sourcedir.tar.gz"
)
adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
assert ordered(adjacency_list) == ordered({"MyModelStep-RegisterModel": []})
@@ -560,11 +507,11 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session):
model = Model(
name=model_name,
image_uri=_IMAGE_URI,
- model_data=f"s3://{_BUCKET}/model.tar.gz",
+ model_data=f"s3://{BUCKET}/model.tar.gz",
sagemaker_session=pipeline_session,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
source_dir=f"{DATA_DIR}",
- role=_ROLE,
+ role=ROLE,
)
step_args = model.create(
instance_type="c4.4xlarge",
@@ -582,7 +529,7 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session):
arguments = step_dsl_list[0]["Arguments"]
assert arguments["PrimaryContainer"]["Image"] == _IMAGE_URI
assert (
- arguments["PrimaryContainer"]["ModelDataUrl"] == f"s3://{_BUCKET}/{model_name}/model.tar.gz"
+ arguments["PrimaryContainer"]["ModelDataUrl"] == f"s3://{BUCKET}/{model_name}/model.tar.gz"
)
assert arguments["PrimaryContainer"]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME
assert arguments["PrimaryContainer"]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] == _DIR_NAME
@@ -700,7 +647,7 @@ def test_conditional_model_create_and_regis(
model_data="dummy_model_data",
image_uri=_IMAGE_URI,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
- role=_ROLE,
+ role=ROLE,
enable_network_isolation=True,
code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH,
),
@@ -713,7 +660,7 @@ def test_conditional_model_create_and_regis(
framework_version="1.11.0",
image_uri=_IMAGE_URI,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
- role=_ROLE,
+ role=ROLE,
enable_network_isolation=False,
),
1,
@@ -724,7 +671,7 @@ def test_conditional_model_create_and_regis(
model_data="dummy_model_data",
image_uri=_IMAGE_URI,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
- role=_ROLE,
+ role=ROLE,
framework_version="1.5.0",
code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH,
),
@@ -736,7 +683,7 @@ def test_conditional_model_create_and_regis(
model_data="dummy_model_data",
image_uri=_IMAGE_URI,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
- role=_ROLE,
+ role=ROLE,
framework_version="1.2.0",
),
1,
@@ -747,7 +694,7 @@ def test_conditional_model_create_and_regis(
model_data="dummy_model_data",
image_uri=_IMAGE_URI,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
- role=_ROLE,
+ role=ROLE,
),
2,
),
@@ -757,7 +704,7 @@ def test_conditional_model_create_and_regis(
model_data="dummy_model_data",
image_uri=_IMAGE_URI,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
- role=_ROLE,
+ role=ROLE,
code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH,
),
2,
@@ -768,7 +715,7 @@ def test_conditional_model_create_and_regis(
model_data="dummy_model_data",
image_uri=_IMAGE_URI,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
- role=_ROLE,
+ role=ROLE,
),
1,
),
@@ -789,7 +736,7 @@ def assert_test_result(steps: list):
)
else:
assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == (
- f"s3://{_BUCKET}/{model.name}"
+ f"s3://{BUCKET}/{model.name}"
)
model, expected_step_num = test_input
@@ -828,7 +775,7 @@ def assert_test_result(steps: list):
XGBoostModel(
model_data="dummy_model_step",
framework_version="1.3-1",
- role=_ROLE,
+ role=ROLE,
entry_point=os.path.join(_XGBOOST_PATH, "inference.py"),
enable_network_isolation=True,
),
@@ -845,7 +792,7 @@ def assert_test_result(steps: list):
XGBoostModel(
model_data="dummy_model_step",
framework_version="1.3-1",
- role=_ROLE,
+ role=ROLE,
entry_point=os.path.join(_XGBOOST_PATH, "inference.py"),
),
{
@@ -861,7 +808,7 @@ def assert_test_result(steps: list):
XGBoostModel(
model_data="dummy_model_step",
framework_version="1.3-1",
- role=_ROLE,
+ role=ROLE,
entry_point=None,
),
{
@@ -876,9 +823,8 @@ def assert_test_result(steps: list):
(
TensorFlowModel(
model_data="dummy_model_step",
- role=_ROLE,
+ role=ROLE,
image_uri=_IMAGE_URI,
- sagemaker_session=pipeline_session,
entry_point=os.path.join(_TENSORFLOW_PATH, "inference.py"),
),
{
@@ -893,9 +839,8 @@ def assert_test_result(steps: list):
(
TensorFlowModel(
model_data="dummy_model_step",
- role=_ROLE,
+ role=ROLE,
image_uri=_IMAGE_URI,
- sagemaker_session=pipeline_session,
),
{
"expected_step_num": 1,
@@ -941,7 +886,7 @@ def test_request_compare_of_register_model_under_different_sessions(
_verify_register_model_container_definition(regis_step_arg, expect, dict)
# Get create model package request under Session
- model.model_data = f"s3://{_BUCKET}"
+ model.model_data = f"s3://{BUCKET}"
model.sagemaker_session = sagemaker_session
with patch.object(
Session, "_intercept_create_request", return_value=dict(ModelPackageArn="arn:aws")
@@ -996,7 +941,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
model_data=lambda_step.properties.Outputs["model_artifact"],
sagemaker_session=pipeline_session,
entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}",
- role=_ROLE,
+ role=ROLE,
)
step_create_model = ModelStep(name="mymodelstep", step_args=model.create())
@@ -1031,7 +976,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
(
Processor(
image_uri=_IMAGE_URI,
- role=_ROLE,
+ role=ROLE,
instance_count=1,
instance_type=_INSTANCE_TYPE,
),
@@ -1052,7 +997,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
(
HyperparameterTuner(
estimator=Estimator(
- role=_ROLE,
+ role=ROLE,
instance_count=1,
instance_type=_INSTANCE_TYPE,
image_uri=_IMAGE_URI,
@@ -1064,7 +1009,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session):
),
(
Estimator(
- role=_ROLE,
+ role=ROLE,
instance_count=1,
instance_type=_INSTANCE_TYPE,
image_uri=_IMAGE_URI,
@@ -1128,3 +1073,31 @@ def test_pass_in_wrong_type_of_retry_policies(pipeline_session, model):
),
)
assert "SageMakerJobStepRetryPolicy is not allowed for a create/registe" in str(error.value)
+
+
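+# A user-supplied ModelPackageName is expected to be popped out of the
+# generated pipeline step arguments.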
+def test_register_model_step_with_model_package_name(pipeline_session):
+ model = Model(
+ name="MyModel",
+ image_uri="my-image",
+ model_data="s3://",
+ sagemaker_session=pipeline_session,
+ )
+ step_args = model.register(
+ content_types=["text/csv"],
+ response_types=["text/csv"],
+ inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
+ transform_instances=["ml.m5.xlarge"],
+ model_package_name="model-pkg-name-will-be-popped-out",
+ )
+ regis_model_step = ModelStep(
+ name="MyModelStep",
+ step_args=step_args,
+ )
+ pipeline = Pipeline(
+ name="MyPipeline",
+ steps=[regis_model_step],
+ sagemaker_session=pipeline_session,
+ )
+ steps = json.loads(pipeline.definition())["Steps"]
+ assert len(steps) == 1
+ assert "ModelPackageName" not in steps[0]["Arguments"]
diff --git a/tests/unit/sagemaker/workflow/test_pipeline.py b/tests/unit/sagemaker/workflow/test_pipeline.py
index 327443aee7..f0cb2e5234 100644
--- a/tests/unit/sagemaker/workflow/test_pipeline.py
+++ b/tests/unit/sagemaker/workflow/test_pipeline.py
@@ -17,7 +17,7 @@
import pytest
-from mock import Mock
+from mock import Mock, patch
from sagemaker import s3
from sagemaker.workflow.condition_step import ConditionStep
@@ -78,6 +78,7 @@ def test_pipeline_create_with_parallelism_config(sagemaker_session_mock, role_ar
)
+@patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
def test_large_pipeline_create(sagemaker_session_mock, role_arn):
parameter = ParameterString("MyStr")
pipeline = Pipeline(
@@ -87,8 +88,6 @@ def test_large_pipeline_create(sagemaker_session_mock, role_arn):
sagemaker_session=sagemaker_session_mock,
)
- s3.S3Uploader.upload_string_as_file_body = Mock()
-
pipeline.create(role_arn=role_arn)
assert s3.S3Uploader.upload_string_as_file_body.called_with(
@@ -151,6 +150,7 @@ def test_pipeline_update_with_parallelism_config(sagemaker_session_mock, role_ar
)
+@patch("sagemaker.s3.S3Uploader.upload_string_as_file_body")
def test_large_pipeline_update(sagemaker_session_mock, role_arn):
parameter = ParameterString("MyStr")
pipeline = Pipeline(
@@ -160,8 +160,6 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn):
sagemaker_session=sagemaker_session_mock,
)
- s3.S3Uploader.upload_string_as_file_body = Mock()
-
 pipeline.update(role_arn=role_arn)
assert s3.S3Uploader.upload_string_as_file_body.called_with(
diff --git a/tests/unit/sagemaker/workflow/test_processing_step.py b/tests/unit/sagemaker/workflow/test_processing_step.py
index fd32fd7c73..9ba242b9b2 100644
--- a/tests/unit/sagemaker/workflow/test_processing_step.py
+++ b/tests/unit/sagemaker/workflow/test_processing_step.py
@@ -13,7 +13,8 @@
from __future__ import absolute_import
import json
-from mock import Mock, PropertyMock
+import os
+from mock import Mock, PropertyMock, patch
import pytest
import warnings
@@ -45,9 +46,11 @@
from sagemaker.workflow.steps import CacheConfig, ProcessingStep
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
+from sagemaker.workflow.pipeline_context import _PipelineConfig
from sagemaker.workflow.properties import PropertyFile
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.functions import Join
+from sagemaker.workflow.utilities import hash_files_or_dirs
from sagemaker.workflow import is_pipeline_variable
from sagemaker.network import NetworkConfig
@@ -63,6 +66,7 @@
SHAPConfig,
)
from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered, get_step_args_helper
+from tests.unit import DATA_DIR
REGION = "us-west-2"
BUCKET = "my-bucket"
@@ -70,7 +74,20 @@
IMAGE_URI = "fakeimage"
MODEL_NAME = "gisele"
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
+LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "workflow/abalone/preprocessing.py")
+SPARK_APP_JAR_PATH = os.path.join(
+ DATA_DIR, "spark/code/java/hello-java-spark/HelloJavaSparkApp.jar"
+)
+SPARK_DEP_JAR = os.path.join(DATA_DIR, "spark/code/java/TestJarFile.jar")
+SPARK_APP_PY_PATH = os.path.join(DATA_DIR, "spark/code/python/hello_py_spark/hello_py_spark_app.py")
+SPARK_PY_FILE1 = os.path.join(DATA_DIR, "spark/code/python/hello_py_spark/__init__.py")
+SPARK_PY_FILE2 = os.path.join(DATA_DIR, "spark/code/python/hello_py_spark/hello_py_spark_udfs.py")
+SPARK_SUBMIT_FILE1 = os.path.join(DATA_DIR, "spark/files/data.jsonl")
+SPARK_SUBMIT_FILE2 = os.path.join(DATA_DIR, "spark/files/sample_spark_event_logs")
INSTANCE_TYPE = "ml.m4.xlarge"
+MOCKED_PIPELINE_CONFIG = _PipelineConfig(
+ "MyPipeline", "MyProcessingStep", hash_files_or_dirs([LOCAL_SCRIPT_PATH]), "config-hash-abcdefg"
+)
FRAMEWORK_PROCESSOR = [
(
@@ -154,6 +171,19 @@
),
]
+FRAMEWORK_PROCESSOR_LOCAL_CODE = [
+ (
+ FrameworkProcessor(
+ framework_version="1.8",
+ instance_type=INSTANCE_TYPE,
+ instance_count=1,
+ role=ROLE,
+ estimator_cls=PyTorch,
+ ),
+ {"code": LOCAL_SCRIPT_PATH},
+ ),
+]
+
PROCESSING_INPUT = [
ProcessingInput(source="s3://my-bucket/processing_manifest", destination="processing_manifest"),
ProcessingInput(
@@ -318,7 +348,8 @@ def test_processing_step_with_processor(
else:
expected_step_arguments["ExperimentConfig"] = expected_experiment_config
- assert json.loads(pipeline.definition())["Steps"][0] == {
+ step_def = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == {
"Name": "MyProcessingStep",
"Description": "ProcessingStep description",
"DisplayName": "MyProcessingStep",
@@ -346,6 +377,10 @@ def test_processing_step_with_processor(
}
)
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == step_def2
+
@pytest.mark.parametrize(
"image_uri",
@@ -387,7 +422,11 @@ def test_processing_step_with_processor_and_step_args(
assert isinstance(e, ValueError)
-def test_processing_step_with_script_processor(pipeline_session, processing_input, network_config):
+@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
+@pytest.mark.parametrize("code_artifact", [DUMMY_S3_SCRIPT_PATH, LOCAL_SCRIPT_PATH])
+def test_processing_step_with_script_processor(
+ pipeline_session, processing_input, network_config, code_artifact
+):
processor = ScriptProcessor(
role=ROLE,
image_uri=IMAGE_URI,
@@ -406,7 +445,7 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu
)
step_args = processor.run(
- inputs=processing_input, code=DUMMY_S3_SCRIPT_PATH, job_name="my-processing-job"
+ inputs=processing_input, code=code_artifact, job_name="my-processing-job"
)
step = ProcessingStep(
@@ -420,11 +459,13 @@ def test_processing_step_with_script_processor(pipeline_session, processing_inpu
sagemaker_session=pipeline_session,
)
- assert json.loads(pipeline.definition())["Steps"][0] == {
- "Name": "MyProcessingStep",
- "Type": "Processing",
- "Arguments": get_step_args_helper(step_args, "Processing"),
- }
+ step_args = get_step_args_helper(step_args, "Processing")
+ step_def = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == {"Name": "MyProcessingStep", "Type": "Processing", "Arguments": step_args}
+
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == step_def2
@pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR)
@@ -477,6 +518,66 @@ def test_processing_step_with_framework_processor(
"Arguments": step_args,
}
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ del step_def2["Arguments"]["ProcessingInputs"][0]["S3Input"]["S3Uri"]
+ del step_def2["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+ assert step_def == step_def2
+
+
+@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
+@pytest.mark.parametrize("framework_processor", FRAMEWORK_PROCESSOR_LOCAL_CODE)
+def test_processing_step_with_framework_processor_local_code(
+ framework_processor, pipeline_session, network_config
+):
+ processor, run_inputs = framework_processor
+ processor.sagemaker_session = pipeline_session
+ processor.role = ROLE
+
+ processor.volume_kms_key = "volume-kms-key"
+ processor.network_config = network_config
+
+ processing_input = ProcessingInput(
+ source="s3://my-bucket/processing_manifest",
+ destination="processing_manifest",
+ input_name="manifest",
+ )
+ processing_output = ProcessingOutput(
+ output_name="framework_output", source="/opt/ml/processing/framework_output"
+ )
+
+ run_inputs["inputs"] = [processing_input]
+ run_inputs["outputs"] = [processing_output]
+
+ step_args = processor.run(**run_inputs)
+
+ step = ProcessingStep(
+ name="MyProcessingStep",
+ step_args=step_args,
+ )
+ pipeline = Pipeline(
+ name="MyPipeline",
+ steps=[step],
+ sagemaker_session=pipeline_session,
+ )
+
+ step_args = get_step_args_helper(step_args, "Processing")
+ step_def = json.loads(pipeline.definition())["Steps"][0]
+
+ del step_args["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+ del step_def["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+
+ assert step_def == {
+ "Name": "MyProcessingStep",
+ "Type": "Processing",
+ "Arguments": step_args,
+ }
+
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ del step_def2["Arguments"]["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
+ assert step_def == step_def2
+
def test_processing_step_with_clarify_processor(pipeline_session):
def headers():
@@ -530,12 +631,17 @@ def verify(step_args):
steps=[step],
sagemaker_session=pipeline_session,
)
- assert json.loads(pipeline.definition())["Steps"][0] == {
+ step_def = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == {
"Name": "MyProcessingStep",
"Type": "Processing",
"Arguments": get_step_args_helper(step_args, "Processing"),
}
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == step_def2
+
test_run = utils.unique_name_from_base("test_run")
output_path = "s3://{}/{}/{}".format(
pipeline_session.default_bucket(), "linear_learner_analysis_result", test_run
@@ -852,4 +958,153 @@ def test_spark_processor(spark_processor, processing_input, pipeline_session):
steps=[step],
sagemaker_session=pipeline_session,
)
- pipeline.definition()
+
+ # test for idempotency
+ step_def = json.loads(pipeline.definition())["Steps"][0]
+ step_def_2 = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == step_def_2
+
+
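+# Verifies how local Spark artifacts map to container paths in the
+# smspark-submit entrypoint, and that the step definition stays idempotent.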
+@pytest.mark.parametrize(
+ "spark_processor",
+ [
+ (
+ SparkJarProcessor(
+ role=ROLE,
+ framework_version="2.4",
+ instance_count=1,
+ instance_type=INSTANCE_TYPE,
+ ),
+ {
+ "submit_app": SPARK_APP_JAR_PATH,
+ "submit_class": "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+ "arguments": [
+ "--input",
+ "input-data-uri",
+ "--output",
+ ParameterString("MyArgOutput"),
+ ],
+ "submit_jars": [
+ SPARK_DEP_JAR,
+ ],
+ "submit_files": [
+ SPARK_SUBMIT_FILE1,
+ SPARK_SUBMIT_FILE2,
+ ],
+ "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
+ "configuration": {
+ "Classification": "core-site",
+ "Properties": {"hadoop.security.groups.cache.secs": "250"},
+ },
+ },
+ ),
+ (
+ PySparkProcessor(
+ role=ROLE,
+ framework_version="2.4",
+ instance_count=1,
+ instance_type=INSTANCE_TYPE,
+ ),
+ {
+ "submit_app": SPARK_APP_PY_PATH,
+ "arguments": [
+ "--input",
+ "input-data-uri",
+ "--output",
+ ParameterString("MyArgOutput"),
+ ],
+ "submit_py_files": [
+ SPARK_PY_FILE1,
+ SPARK_PY_FILE2,
+ ],
+ "submit_jars": [SPARK_DEP_JAR],
+ "submit_files": [SPARK_SUBMIT_FILE1, SPARK_SUBMIT_FILE2],
+ "spark_event_logs_s3_uri": ParameterString("MySparkEventLogS3Uri"),
+ "configuration": {
+ "Classification": "core-site",
+ "Properties": {"hadoop.security.groups.cache.secs": "250"},
+ },
+ },
+ ),
+ ],
+)
+def test_spark_processor_local_code(spark_processor, processing_input, pipeline_session):
+ processor, run_inputs = spark_processor
+ processor.sagemaker_session = pipeline_session
+ processor.role = ROLE
+
+ run_inputs["inputs"] = processing_input
+
+ step_args = processor.run(**run_inputs)
+ step = ProcessingStep(
+ name="MyProcessingStep",
+ step_args=step_args,
+ )
+
+ step_args = get_step_args_helper(step_args, "Processing")
+
+ assert step_args["AppSpecification"]["ContainerArguments"] == run_inputs["arguments"]
+
+ entry_points = step_args["AppSpecification"]["ContainerEntrypoint"]
+ entry_points_expr = []
+ for entry_point in entry_points:
+ if is_pipeline_variable(entry_point):
+ entry_points_expr.append(entry_point.expr)
+ else:
+ entry_points_expr.append(entry_point)
+
+ if "submit_py_files" in run_inputs:
+ expected = [
+ "smspark-submit",
+ "--py-files",
+ "/opt/ml/processing/input/py-files",
+ "--jars",
+ "/opt/ml/processing/input/jars",
+ "--files",
+ "/opt/ml/processing/input/files",
+ "--local-spark-event-logs-dir",
+ "/opt/ml/processing/spark-events/",
+ "/opt/ml/processing/input/code/hello_py_spark_app.py",
+ ]
+ # spark jar
+ else:
+ expected = [
+ "smspark-submit",
+ "--class",
+ "com.amazonaws.sagemaker.spark.test.HelloJavaSparkApp",
+ "--jars",
+ "/opt/ml/processing/input/jars",
+ "--files",
+ "/opt/ml/processing/input/files",
+ "--local-spark-event-logs-dir",
+ "/opt/ml/processing/spark-events/",
+ "/opt/ml/processing/input/code/HelloJavaSparkApp.jar",
+ ]
+
+ assert entry_points_expr == expected
+ for output in step_args["ProcessingOutputConfig"]["Outputs"]:
+ if is_pipeline_variable(output["S3Output"]["S3Uri"]):
+ output["S3Output"]["S3Uri"] = output["S3Output"]["S3Uri"].expr
+
+ assert step_args["ProcessingOutputConfig"]["Outputs"] == [
+ {
+ "OutputName": "output-1",
+ "AppManaged": False,
+ "S3Output": {
+ "S3Uri": {"Get": "Parameters.MySparkEventLogS3Uri"},
+ "LocalPath": "/opt/ml/processing/spark-events/",
+ "S3UploadMode": "Continuous",
+ },
+ }
+ ]
+
+ pipeline = Pipeline(
+ name="MyPipeline",
+ steps=[step],
+ sagemaker_session=pipeline_session,
+ )
+
+    # test idempotency
+ step_def = json.loads(pipeline.definition())["Steps"][0]
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == step_def2
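The expected entrypoint lists above encode how the Spark processors map run() inputs onto fixed container paths. A rough sketch of that mapping, built from the paths this test asserts (test expectations, not a documented contract; the function name is hypothetical):

    def build_entrypoint(run_inputs, app_container_path):
        # Mirror the smspark-submit command the test expects: each kind of
        # submitted artifact lands under a fixed /opt/ml/processing input path.
        cmd = ["smspark-submit"]
        if "submit_class" in run_inputs:
            cmd += ["--class", run_inputs["submit_class"]]
        if "submit_py_files" in run_inputs:
            cmd += ["--py-files", "/opt/ml/processing/input/py-files"]
        if "submit_jars" in run_inputs:
            cmd += ["--jars", "/opt/ml/processing/input/jars"]
        if "spark_event_logs_s3_uri" in run_inputs:
            event_logs = ["--local-spark-event-logs-dir", "/opt/ml/processing/spark-events/"]
        else:
            event_logs = []
        if "submit_files" in run_inputs:
            cmd += ["--files", "/opt/ml/processing/input/files"]
        cmd += event_logs
        cmd.append(app_container_path)  # e.g. /opt/ml/processing/input/code/<app>
        return cmd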
diff --git a/tests/unit/sagemaker/workflow/test_quality_check_step.py b/tests/unit/sagemaker/workflow/test_quality_check_step.py
index b60e2de8fa..dc104d71df 100644
--- a/tests/unit/sagemaker/workflow/test_quality_check_step.py
+++ b/tests/unit/sagemaker/workflow/test_quality_check_step.py
@@ -15,10 +15,6 @@
import json
import pytest
-import sagemaker
-
-from mock import Mock, PropertyMock
-
from sagemaker.model_monitor import DatasetFormat
from sagemaker.workflow.parameters import ParameterString
from sagemaker.workflow.pipeline import Pipeline
@@ -31,49 +27,7 @@
from sagemaker.workflow.steps import CacheConfig
from sagemaker.workflow.check_job_config import CheckJobConfig
-_REGION = "us-west-2"
_ROLE = "DummyRole"
-_BUCKET = "my-bucket"
-
-
-@pytest.fixture
-def boto_session():
- role_mock = Mock()
- type(role_mock).arn = PropertyMock(return_value=_ROLE)
-
- resource_mock = Mock()
- resource_mock.Role.return_value = role_mock
-
- session_mock = Mock(region_name=_REGION)
- session_mock.resource.return_value = resource_mock
-
- return session_mock
-
-
-@pytest.fixture
-def client():
- """Mock client.
-
- Considerations when appropriate:
-
- * utilize botocore.stub.Stubber
- * separate runtime client from client
- """
- client_mock = Mock()
- client_mock._client_config.user_agent = (
- "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
- )
- return client_mock
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
- return sagemaker.session.Session(
- boto_session=boto_session,
- sagemaker_client=client,
- sagemaker_runtime_client=client,
- default_bucket=_BUCKET,
- )
_expected_data_quality_dsl = {
diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py
index 9887d43078..ba712d11d7 100644
--- a/tests/unit/sagemaker/workflow/test_steps.py
+++ b/tests/unit/sagemaker/workflow/test_steps.py
@@ -16,15 +16,10 @@
import json
import pytest
-import sagemaker
import os
import warnings
-from mock import (
- Mock,
- PropertyMock,
- patch,
-)
+from mock import patch
from sagemaker.debugger import ProfilerConfig
from sagemaker.estimator import Estimator
@@ -94,46 +89,6 @@ def create_predictor(self, endpoint_name):
return Predictor(endpoint_name, self.sagemaker_session)
-@pytest.fixture
-def boto_session():
- role_mock = Mock()
- type(role_mock).arn = PropertyMock(return_value=ROLE)
-
- resource_mock = Mock()
- resource_mock.Role.return_value = role_mock
-
- session_mock = Mock(region_name=REGION)
- session_mock.resource.return_value = resource_mock
-
- return session_mock
-
-
-@pytest.fixture
-def client():
- """Mock client.
-
- Considerations when appropriate:
-
- * utilize botocore.stub.Stubber
- * separate runtime client from client
- """
- client_mock = Mock()
- client_mock._client_config.user_agent = (
- "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
- )
- return client_mock
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
- return sagemaker.session.Session(
- boto_session=boto_session,
- sagemaker_client=client,
- sagemaker_runtime_client=client,
- default_bucket=BUCKET,
- )
-
-
@pytest.fixture
def script_processor(sagemaker_session):
return ScriptProcessor(
diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py
index 4133343c93..3e8b57b069 100644
--- a/tests/unit/sagemaker/workflow/test_training_step.py
+++ b/tests/unit/sagemaker/workflow/test_training_step.py
@@ -14,7 +14,7 @@
import os
import json
-from mock import Mock, PropertyMock
+from mock import Mock, PropertyMock, patch
import pytest
import warnings
@@ -25,11 +25,12 @@
from sagemaker.parameter import IntegerParameter
from sagemaker.transformer import Transformer
from sagemaker.tuner import HyperparameterTuner
-from sagemaker.workflow.pipeline_context import PipelineSession
+from sagemaker.workflow.pipeline_context import PipelineSession, _PipelineConfig
from sagemaker.workflow.parameters import ParameterString, ParameterBoolean
from sagemaker.workflow.steps import TrainingStep
from sagemaker.workflow.pipeline import Pipeline, PipelineGraph
+from sagemaker.workflow.utilities import hash_files_or_dirs
from sagemaker.workflow.functions import Join
from sagemaker.estimator import Estimator
@@ -66,9 +67,19 @@
ROLE = "DummyRole"
IMAGE_URI = "fakeimage"
MODEL_NAME = "gisele"
-DUMMY_LOCAL_SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py")
+LOCAL_ENTRY_POINT = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-with-handler/training.py")
+LOCAL_SOURCE_DIR = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-with-handler")
+LOCAL_DEPS = [
+ os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies"),
+]
DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
INSTANCE_TYPE = "ml.m4.xlarge"
+MOCKED_PIPELINE_CONFIG = _PipelineConfig(
+ "MyPipeline",
+ "MyTrainingStep",
+ hash_files_or_dirs([LOCAL_SOURCE_DIR] + LOCAL_DEPS),
+ "config-hash-abcdefg",
+)
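MOCKED_PIPELINE_CONFIG pins the code hash that hash_files_or_dirs computes over the local source dir and dependencies. Conceptually, that helper only needs to be a deterministic digest over file paths and contents, so unchanged code yields an unchanged hash and, in turn, an idempotent pipeline definition. An illustrative stand-in under that assumption (not the SDK's actual implementation):

    import hashlib
    import os

    def hash_files_or_dirs_sketch(paths):
        # Walk files in a stable (sorted) order so the digest is
        # deterministic across runs.
        digest = hashlib.sha256()
        for path in sorted(paths):
            if os.path.isdir(path):
                for root, _, files in sorted(os.walk(path)):
                    for name in sorted(files):
                        file_path = os.path.join(root, name)
                        digest.update(file_path.encode())
                        with open(file_path, "rb") as f:
                            digest.update(f.read())
            else:
                digest.update(path.encode())
                with open(path, "rb") as f:
                    digest.update(f.read())
        return digest.hexdigest()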
ESTIMATOR_LISTS = [
SKLearn(
@@ -152,6 +163,88 @@
),
]
+ESTIMATOR_LISTS_LOCAL_CODE = [
+ SKLearn(
+ framework_version="0.23-1",
+ py_version="py3",
+ instance_type=INSTANCE_TYPE,
+ instance_count=1,
+ role=ROLE,
+ entry_point=LOCAL_ENTRY_POINT,
+ source_dir=LOCAL_SOURCE_DIR,
+ ),
+ PyTorch(
+ role=ROLE,
+ instance_type=INSTANCE_TYPE,
+ instance_count=1,
+ framework_version="1.8.0",
+ py_version="py36",
+ entry_point=LOCAL_ENTRY_POINT,
+ source_dir=LOCAL_SOURCE_DIR,
+ ),
+ TensorFlow(
+ role=ROLE,
+ entry_point=LOCAL_ENTRY_POINT,
+ source_dir=LOCAL_SOURCE_DIR,
+ instance_type=INSTANCE_TYPE,
+ instance_count=1,
+ framework_version="2.0",
+ py_version="py3",
+ ),
+ HuggingFace(
+ transformers_version="4.6",
+ pytorch_version="1.7",
+ role=ROLE,
+ instance_type="ml.p3.2xlarge",
+ instance_count=1,
+ py_version="py36",
+ entry_point=LOCAL_ENTRY_POINT,
+ source_dir=LOCAL_SOURCE_DIR,
+ ),
+ XGBoost(
+ framework_version="1.3-1",
+ py_version="py3",
+ role=ROLE,
+ instance_type=INSTANCE_TYPE,
+ instance_count=1,
+ entry_point=LOCAL_ENTRY_POINT,
+ source_dir=LOCAL_SOURCE_DIR,
+ ),
+ MXNet(
+ framework_version="1.4.1",
+ py_version="py3",
+ role=ROLE,
+ instance_type=INSTANCE_TYPE,
+ instance_count=1,
+ entry_point=LOCAL_ENTRY_POINT,
+ source_dir=LOCAL_SOURCE_DIR,
+ toolkit=RLToolkit.RAY,
+ framework=RLFramework.TENSORFLOW,
+ toolkit_version="0.8.5",
+ ),
+ RLEstimator(
+ entry_point=LOCAL_ENTRY_POINT,
+ source_dir=LOCAL_SOURCE_DIR,
+ toolkit=RLToolkit.RAY,
+ framework=RLFramework.TENSORFLOW,
+ toolkit_version="0.8.5",
+ role=ROLE,
+ instance_type=INSTANCE_TYPE,
+ instance_count=1,
+ ),
+ Chainer(
+ role=ROLE,
+ entry_point=LOCAL_ENTRY_POINT,
+ source_dir=LOCAL_SOURCE_DIR,
+ use_mpi=True,
+ num_processes=4,
+ framework_version="5.0.0",
+ instance_type=INSTANCE_TYPE,
+ instance_count=1,
+ py_version="py3",
+ ),
+]
+
INPUT_PARAM_LISTS = [
"s3://my-bucket/my-training-input",
@@ -209,6 +302,7 @@ def hyperparameters():
return {"test-key": "test-val"}
+@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
@pytest.mark.parametrize(
"experiment_config, expected_experiment_config",
[
@@ -250,6 +344,9 @@ def test_training_step_with_estimator(
hyperparameters=hyperparameters,
enable_network_isolation=enable_network_isolation,
encrypt_inter_container_traffic=encrypt_container_traffic,
+ entry_point=LOCAL_ENTRY_POINT,
+ source_dir=LOCAL_SOURCE_DIR,
+ dependencies=LOCAL_DEPS,
)
with warnings.catch_warnings(record=True) as w:
@@ -289,18 +386,15 @@ def test_training_step_with_estimator(
sagemaker_session=pipeline_session,
)
step_args = get_step_args_helper(step_args, "Training")
- expected_step_arguments = deepcopy(step_args)
- expected_step_arguments["EnableInterContainerTrafficEncryption"] = {
+ step_args["EnableInterContainerTrafficEncryption"] = {
"Get": "Parameters.encrypt_container_traffic"
}
- expected_step_arguments["EnableNetworkIsolation"] = {
- "Get": "Parameters.enable_network_isolation"
- }
+ step_args["EnableNetworkIsolation"] = {"Get": "Parameters.enable_network_isolation"}
if expected_experiment_config is None:
- expected_step_arguments.pop("ExperimentConfig", None)
+ step_args.pop("ExperimentConfig", None)
else:
- expected_step_arguments["ExperimentConfig"] = expected_experiment_config
+ step_args["ExperimentConfig"] = expected_experiment_config
assert step_condition.conditions[0].left.expr == {
"Get": "Steps.MyTrainingStep.FinalMetricDataList['val:acc'].Value"
@@ -309,7 +403,7 @@ def test_training_step_with_estimator(
# delete profiler rule configurations because of timestamp collision
del step_definition["Arguments"]["ProfilerRuleConfigurations"]
- del expected_step_arguments["ProfilerRuleConfigurations"]
+ del step_args["ProfilerRuleConfigurations"]
assert step_definition == {
"Name": "MyTrainingStep",
@@ -317,7 +411,7 @@ def test_training_step_with_estimator(
"DisplayName": "MyTrainingStep",
"Type": "Training",
"DependsOn": ["TestStep"],
- "Arguments": expected_step_arguments,
+ "Arguments": step_args,
}
assert step_train.properties.TrainingJobName.expr == {
"Get": "Steps.MyTrainingStep.TrainingJobName"
@@ -332,6 +426,11 @@ def test_training_step_with_estimator(
}
)
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ del step_def2["Arguments"]["ProfilerRuleConfigurations"]
+ assert step_definition == step_def2
+
def test_training_step_estimator_with_param_code_input(
pipeline_session, training_input, hyperparameters
@@ -374,7 +473,8 @@ def test_training_step_estimator_with_param_code_input(
step_args = get_step_args_helper(step_args, "Training")
step_args["HyperParameters"]["sagemaker_program"] = {"Get": "Parameters.EntryPoint"}
step_args["HyperParameters"]["sagemaker_submit_directory"] = {"Get": "Parameters.SourceDir"}
- assert json.loads(pipeline.definition())["Steps"][0] == {
+ step_def = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == {
"Name": "MyTrainingStep",
"Description": "TrainingStep description",
"DisplayName": "MyTrainingStep",
@@ -382,6 +482,10 @@ def test_training_step_estimator_with_param_code_input(
"Arguments": step_args,
}
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == step_def2
+
@pytest.mark.parametrize("estimator", ESTIMATOR_LISTS)
@pytest.mark.parametrize("training_input", INPUT_PARAM_LISTS)
@@ -414,41 +518,119 @@ def test_training_step_with_framework_estimator(
)
step_args = get_step_args_helper(step_args, "Training")
+ expected_step_args = deepcopy(step_args)
step_def = json.loads(pipeline.definition())["Steps"][0]
- assert step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == training_input
- assert step_args["OutputDataConfig"]["S3OutputPath"] == output_path
- step_args["HyperParameters"]["sagemaker_program"] = {"Get": "Parameters.EntryPoint"}
- step_args["HyperParameters"]["sagemaker_submit_directory"] = {"Get": "Parameters.SourceDir"}
+ assert (
+ expected_step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+ == training_input
+ )
+ assert expected_step_args["OutputDataConfig"]["S3OutputPath"] == output_path
+ expected_step_args["HyperParameters"]["sagemaker_program"] = {"Get": "Parameters.EntryPoint"}
+ expected_step_args["HyperParameters"]["sagemaker_submit_directory"] = {
+ "Get": "Parameters.SourceDir"
+ }
- del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+ del expected_step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
- del step_args["OutputDataConfig"]["S3OutputPath"]
+ del expected_step_args["OutputDataConfig"]["S3OutputPath"]
del step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"]
- # trim timestamp so RuleConfigurationName will match
- rule_config_name_step_args = step_args["ProfilerRuleConfigurations"][0]["RuleConfigurationName"]
- step_args["ProfilerRuleConfigurations"][0][
- "RuleConfigurationName"
- ] = rule_config_name_step_args[:-11]
- rule_config_name_step_def = step_def["Arguments"]["ProfilerRuleConfigurations"][0][
- "RuleConfigurationName"
- ]
- step_def["Arguments"]["ProfilerRuleConfigurations"][0][
- "RuleConfigurationName"
- ] = rule_config_name_step_def[:-11]
+ # delete profiler rule configurations because of timestamp collision
+ del step_def["Arguments"]["ProfilerRuleConfigurations"]
+ del expected_step_args["ProfilerRuleConfigurations"]
if "sagemaker_s3_output" in step_args["HyperParameters"]:
- del step_args["HyperParameters"]["sagemaker_s3_output"]
+ del expected_step_args["HyperParameters"]["sagemaker_s3_output"]
del step_def["Arguments"]["HyperParameters"]["sagemaker_s3_output"]
assert step_def == {
"Name": "MyTrainingStep",
"Type": "Training",
- "Arguments": step_args,
+ "Arguments": expected_step_args,
}
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+ del step_def2["Arguments"]["OutputDataConfig"]["S3OutputPath"]
+ del step_def2["Arguments"]["ProfilerRuleConfigurations"]
+ if "sagemaker_s3_output" in step_def2["Arguments"]["HyperParameters"]:
+ del step_def2["Arguments"]["HyperParameters"]["sagemaker_s3_output"]
+ assert step_def == step_def2
+
+
+@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
+@pytest.mark.parametrize("estimator", ESTIMATOR_LISTS_LOCAL_CODE)
+@pytest.mark.parametrize("training_input", INPUT_PARAM_LISTS)
+@pytest.mark.parametrize(
+ "output_path", ["s3://my-bucket/my-output-path", ParameterString(name="OutputPath")]
+)
+def test_training_step_with_framework_estimator_local_code(
+ estimator, pipeline_session, training_input, output_path, hyperparameters
+):
+ estimator.set_hyperparameters(**hyperparameters)
+ estimator.volume_kms_key = "volume-kms-key"
+ estimator.output_kms_key = "output-kms-key"
+ estimator.dependencies = LOCAL_DEPS
+ estimator.output_path = output_path
+ # TODO: remove job_name once we merge
+ # https://github.com/aws/sagemaker-python-sdk/pull/3158/files
+ estimator.base_job_name = "TestJob"
+
+ estimator.sagemaker_session = pipeline_session
+ step_args = estimator.fit(inputs=TrainingInput(s3_data=training_input))
+
+ step = TrainingStep(
+ name="MyTrainingStep",
+ step_args=step_args,
+ )
+ pipeline = Pipeline(
+ name="MyPipeline",
+ steps=[step],
+ sagemaker_session=pipeline_session,
+ )
+
+ step_args = get_step_args_helper(step_args, "Training")
+ expected_step_args = deepcopy(step_args)
+ step_def = json.loads(pipeline.definition())["Steps"][0]
+
+ assert (
+ expected_step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+ == training_input
+ )
+ assert expected_step_args["OutputDataConfig"]["S3OutputPath"] == output_path
+
+ del expected_step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+ del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+
+ del expected_step_args["OutputDataConfig"]["S3OutputPath"]
+ del step_def["Arguments"]["OutputDataConfig"]["S3OutputPath"]
+
+ # delete profiler rule configurations because of timestamp collision
+ del step_def["Arguments"]["ProfilerRuleConfigurations"]
+ del expected_step_args["ProfilerRuleConfigurations"]
+
+ if "sagemaker_s3_output" in step_args["HyperParameters"]:
+ del expected_step_args["HyperParameters"]["sagemaker_s3_output"]
+ del step_def["Arguments"]["HyperParameters"]["sagemaker_s3_output"]
+
+ assert step_def == {
+ "Name": "MyTrainingStep",
+ "Type": "Training",
+ "Arguments": expected_step_args,
+ }
+
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+ del step_def2["Arguments"]["OutputDataConfig"]["S3OutputPath"]
+ del step_def2["Arguments"]["ProfilerRuleConfigurations"]
+ if "sagemaker_s3_output" in step_def2["Arguments"]["HyperParameters"]:
+ del step_def2["Arguments"]["HyperParameters"]["sagemaker_s3_output"]
+ assert step_def == step_def2
+
@pytest.mark.parametrize(
"algo_estimator",
@@ -519,17 +701,9 @@ def test_training_step_with_algorithm_base(algo_estimator, training_input, pipel
del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
- # trim timestamp so RuleConfigurationName will match
- rule_config_name_step_args = step_args["ProfilerRuleConfigurations"][0]["RuleConfigurationName"]
- step_args["ProfilerRuleConfigurations"][0][
- "RuleConfigurationName"
- ] = rule_config_name_step_args[:-11]
- rule_config_name_step_def = step_def["Arguments"]["ProfilerRuleConfigurations"][0][
- "RuleConfigurationName"
- ]
- step_def["Arguments"]["ProfilerRuleConfigurations"][0][
- "RuleConfigurationName"
- ] = rule_config_name_step_def[:-11]
+ # delete profiler rule configurations because of timestamp collision
+ del step_def["Arguments"]["ProfilerRuleConfigurations"]
+ del step_args["ProfilerRuleConfigurations"]
assert step_def == {
"Name": "MyTrainingStep",
@@ -537,6 +711,100 @@ def test_training_step_with_algorithm_base(algo_estimator, training_input, pipel
"Arguments": step_args,
}
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+ del step_def2["Arguments"]["ProfilerRuleConfigurations"]
+ assert step_def == step_def2
+
+
+@patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG)
+@pytest.mark.parametrize(
+ "algo_estimator",
+ [
+ KNN,
+ KMeans,
+ LinearLearner,
+ RandomCutForest,
+ LDA,
+ Object2Vec,
+ NTM,
+ PCA,
+ FactorizationMachines,
+ IPInsights,
+ ],
+)
+@pytest.mark.parametrize(
+ "training_input",
+ INPUT_PARAM_LISTS,
+)
+def test_training_step_with_algorithm_base_local_code(
+ algo_estimator, training_input, pipeline_session
+):
+ estimator = algo_estimator(
+ role=ROLE,
+ instance_type=INSTANCE_TYPE,
+ instance_count=1,
+ sagemaker_session=pipeline_session,
+ entry_point=LOCAL_ENTRY_POINT,
+ source_dir=LOCAL_SOURCE_DIR,
+ dependencies=LOCAL_DEPS,
+ # TODO: remove job_name once we merge
+ # https://github.com/aws/sagemaker-python-sdk/pull/3158/files
+ base_job_name="TestJob",
+ )
+ data = RecordSet(
+ s3_data=training_input,
+ num_records=1000,
+ feature_dim=128,
+ channel="train",
+ )
+
+ with warnings.catch_warnings(record=True) as w:
+ step_args = estimator.fit(
+ records=data,
+ mini_batch_size=1000,
+ )
+ assert len(w) == 1
+ assert issubclass(w[-1].category, UserWarning)
+ assert "Running within a PipelineSession" in str(w[-1].message)
+
+ with warnings.catch_warnings(record=True) as w:
+ step = TrainingStep(
+ name="MyTrainingStep",
+ step_args=step_args,
+ )
+ assert len(w) == 0
+
+ pipeline = Pipeline(
+ name="MyPipeline",
+ steps=[step],
+ sagemaker_session=pipeline_session,
+ )
+
+ step_args = get_step_args_helper(step_args, "Training")
+
+ step_def = json.loads(pipeline.definition())["Steps"][0]
+ assert step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"] == training_input
+ del step_args["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+ del step_def["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+
+ # delete profiler rule configurations because of timestamp collision
+ del step_def["Arguments"]["ProfilerRuleConfigurations"]
+ del step_args["ProfilerRuleConfigurations"]
+
+ assert step_def == {
+ "Name": "MyTrainingStep",
+ "Type": "Training",
+ "Arguments": step_args,
+ }
+
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ del step_def2["Arguments"]["InputDataConfig"][0]["DataSource"]["S3DataSource"]["S3Uri"]
+ del step_def2["Arguments"]["ProfilerRuleConfigurations"]
+ assert step_def == step_def2
+
@pytest.mark.parametrize(
"inputs",
diff --git a/tests/unit/sagemaker/workflow/test_transform_step.py b/tests/unit/sagemaker/workflow/test_transform_step.py
index 5699f13538..ffc901bf5c 100644
--- a/tests/unit/sagemaker/workflow/test_transform_step.py
+++ b/tests/unit/sagemaker/workflow/test_transform_step.py
@@ -176,6 +176,10 @@ def test_transform_step_with_transformer(model_name, data, output_path, pipeline
"Arguments": expected_step_arguments,
}
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == step_def2
+
@pytest.mark.parametrize(
"experiment_config, expected_experiment_config",
@@ -261,6 +265,10 @@ def test_transform_step_with_transformer_experiment_config(
adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
assert adjacency_list == {"MyTransformStep": []}
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ assert step_def == step_def2
+
@pytest.mark.parametrize(
"inputs",
diff --git a/tests/unit/sagemaker/workflow/test_tuning_step.py b/tests/unit/sagemaker/workflow/test_tuning_step.py
index 9c7b764c3b..6c022bb255 100644
--- a/tests/unit/sagemaker/workflow/test_tuning_step.py
+++ b/tests/unit/sagemaker/workflow/test_tuning_step.py
@@ -159,22 +159,14 @@ def test_tuning_step_with_single_algo_tuner(pipeline_session, training_input, en
"S3DataSource"
]["S3Uri"]
- # trim timestamp so sagemaker_job_name will still match
- step_args_sm_job_name = step_args["TrainingJobDefinition"]["StaticHyperParameters"][
- "sagemaker_job_name"
- ]
- step_args["TrainingJobDefinition"]["StaticHyperParameters"][
- "sagemaker_job_name"
- ] = step_args_sm_job_name[:-24]
- step_def_sm_job_name = step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
+    # delete sagemaker_job_name because of timestamp collision
+ del step_args["TrainingJobDefinition"]["StaticHyperParameters"]["sagemaker_job_name"]
+ del step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
"sagemaker_job_name"
]
- step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
- "sagemaker_job_name"
- ] = step_def_sm_job_name[:-24]
- # delete S3 path assertions for now because job name is included with timestamp. These will be re-enabled once
- # next PRs are submitted with s3 path updates, removing the job name.
+ # delete S3 path assertions for now because job name is included with timestamp. These will be re-enabled after
+ # caching improvements phase 2.
del step_args["TrainingJobDefinition"]["StaticHyperParameters"]["sagemaker_submit_directory"]
del step_def["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
"sagemaker_submit_directory"
@@ -188,6 +180,16 @@ def test_tuning_step_with_single_algo_tuner(pipeline_session, training_input, en
adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
assert adjacency_list == {"MyTuningStep": []}
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ del step_def2["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
+ "sagemaker_job_name"
+ ]
+ del step_def2["Arguments"]["TrainingJobDefinition"]["StaticHyperParameters"][
+ "sagemaker_submit_directory"
+ ]
+ assert step_def == step_def2
+
def test_tuning_step_with_multi_algo_tuner(pipeline_session, entry_point):
pytorch_estimator = PyTorch(
@@ -249,20 +251,14 @@ def test_tuning_step_with_multi_algo_tuner(pipeline_session, entry_point):
step_def = json.loads(pipeline.definition())["Steps"][0]
for i, step in enumerate(step_args["TrainingJobDefinitions"]):
- # trim timestamp so sagemaker_job_name will still match
- step_args_sm_job_name = step["StaticHyperParameters"]["sagemaker_job_name"]
- step_args["TrainingJobDefinitions"][i]["StaticHyperParameters"][
- "sagemaker_job_name"
- ] = step_args_sm_job_name[:-24]
- step_def_sm_job_name = step_def["Arguments"]["TrainingJobDefinitions"][i][
- "StaticHyperParameters"
- ]["sagemaker_job_name"]
- step_def["Arguments"]["TrainingJobDefinitions"][i]["StaticHyperParameters"][
+        # delete sagemaker_job_name because of timestamp collision
+ del step_args["TrainingJobDefinitions"][i]["StaticHyperParameters"]["sagemaker_job_name"]
+ del step_def["Arguments"]["TrainingJobDefinitions"][i]["StaticHyperParameters"][
"sagemaker_job_name"
- ] = step_def_sm_job_name[:-24]
+ ]
- # delete S3 path assertions for now because job name is included with timestamp. These will be re-enabled once
- # next PRs are submitted with s3 path updates, removing the job name.
+ # delete S3 path assertions for now because job name is included with timestamp. These will be re-enabled after
+ # caching improvements phase 2.
del step_args["TrainingJobDefinitions"][i]["StaticHyperParameters"][
"sagemaker_submit_directory"
]
@@ -278,6 +274,18 @@ def test_tuning_step_with_multi_algo_tuner(pipeline_session, entry_point):
adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
assert adjacency_list == {"MyTuningStep": []}
+ # test idempotency
+ step_def2 = json.loads(pipeline.definition())["Steps"][0]
+ for i, step in enumerate(step_def2["Arguments"]["TrainingJobDefinitions"]):
+ del step_def2["Arguments"]["TrainingJobDefinitions"][i]["StaticHyperParameters"][
+ "sagemaker_job_name"
+ ]
+
+ del step_def2["Arguments"]["TrainingJobDefinitions"][i]["StaticHyperParameters"][
+ "sagemaker_submit_directory"
+ ]
+ assert step_def == step_def2
+
@pytest.mark.parametrize(
"inputs",
diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py
index dcbf5a6421..c8d86c5866 100644
--- a/tests/unit/sagemaker/workflow/test_utils.py
+++ b/tests/unit/sagemaker/workflow/test_utils.py
@@ -18,12 +18,6 @@
import tempfile
import pytest
-import sagemaker
-
-from mock import (
- Mock,
- PropertyMock,
-)
from sagemaker.estimator import Estimator
from sagemaker.workflow._utils import (
@@ -35,51 +29,7 @@
from sagemaker.workflow.properties import Properties
from tests.unit.test_utils import FakeS3, list_tar_files
from tests.unit import DATA_DIR
-
-REGION = "us-west-2"
-BUCKET = "my-bucket"
-IMAGE_URI = "fakeimage"
-ROLE = "DummyRole"
-
-
-@pytest.fixture
-def boto_session():
- role_mock = Mock()
- type(role_mock).arn = PropertyMock(return_value=ROLE)
-
- resource_mock = Mock()
- resource_mock.Role.return_value = role_mock
-
- session_mock = Mock(region_name=REGION)
- session_mock.resource.return_value = resource_mock
-
- return session_mock
-
-
-@pytest.fixture
-def client():
- """Mock client.
-
- Considerations when appropriate:
-
- * utilize botocore.stub.Stubber
- * separate runtime client from client
- """
- client_mock = Mock()
- client_mock._client_config.user_agent = (
- "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource"
- )
- return client_mock
-
-
-@pytest.fixture
-def sagemaker_session(boto_session, client):
- return sagemaker.session.Session(
- boto_session=boto_session,
- sagemaker_client=client,
- sagemaker_runtime_client=client,
- default_bucket=BUCKET,
- )
+from tests.unit.sagemaker.workflow.conftest import ROLE, IMAGE_URI, BUCKET
@pytest.fixture
@@ -171,7 +121,7 @@ def test_repack_model_step(estimator):
}
-def test_repack_model_step_with_invalid_input():
+def test_register_model_step_with_invalid_input():
# without both step_args and any of the old required arguments
with pytest.raises(ValueError) as error:
_RegisterModelStep(
diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py
index 82b154317d..44b5818fc8 100644
--- a/tests/unit/test_amazon_estimator.py
+++ b/tests/unit/test_amazon_estimator.py
@@ -225,6 +225,9 @@ def test_fit_ndarray(time, sagemaker_session):
assert mock_object.put.call_count == 4
+ called_args = sagemaker_session.train.call_args
+ assert not called_args[1]["experiment_config"]
+
def test_fit_pass_experiment_config(sagemaker_session):
kwargs = dict(COMMON_ARGS)
@@ -239,12 +242,18 @@ def test_fit_pass_experiment_config(sagemaker_session):
labels = [99, 85, 87, 2]
pca.fit(
pca.record_set(np.array(train), np.array(labels)),
- experiment_config={"ExperimentName": "exp"},
+ experiment_config={
+ "ExperimentName": "exp",
+ "RunName": "rn",
+ },
)
called_args = sagemaker_session.train.call_args
- assert called_args[1]["experiment_config"] == {"ExperimentName": "exp"}
+ assert called_args[1]["experiment_config"] == {
+ "ExperimentName": "exp",
+ "RunName": "rn",
+ }
def test_build_shards():
diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py
index 34e6a43fcf..868da88d78 100644
--- a/tests/unit/test_estimator.py
+++ b/tests/unit/test_estimator.py
@@ -2489,7 +2489,12 @@ def test_start_new(sagemaker_session):
hyperparameters=hyperparameters,
)
- exp_config = {"ExperimentName": "exp", "TrialName": "t", "TrialComponentDisplayName": "tc"}
+ exp_config = {
+ "ExperimentName": "exp",
+ "TrialName": "t",
+ "TrialComponentDisplayName": "tc",
+ "RunName": "rn",
+ }
started_training_job = training_job.start_new(estimator, inputs, experiment_config=exp_config)
called_args = sagemaker_session.train.call_args
@@ -2680,6 +2685,7 @@ def test_unsupported_type_in_dict():
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
}
}
)
@@ -2884,6 +2890,7 @@ def test_generic_to_fit_with_experiment_config(time, sagemaker_session):
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
},
)
diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py
index 99b0e839b7..9ba3e17ff3 100644
--- a/tests/unit/test_mxnet.py
+++ b/tests/unit/test_mxnet.py
@@ -62,6 +62,7 @@
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
}
MODEL_PKG_RESPONSE = {"ModelPackageArn": "arn:model-pkg-arn"}
diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py
index 082f699d63..c8aad13774 100644
--- a/tests/unit/test_pytorch.py
+++ b/tests/unit/test_pytorch.py
@@ -54,6 +54,7 @@
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
}
DISTRIBUTION_PYTORCH_DDP_ENABLED = {"pytorchddp": {"enabled": True}}
diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py
index 4efc2e5bf8..2035636e76 100644
--- a/tests/unit/test_rl.py
+++ b/tests/unit/test_rl.py
@@ -49,6 +49,7 @@
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
}
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 8958210092..ec4a21cbc9 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -588,11 +588,16 @@ def test_user_agent_injected(boto_session):
assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent
assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent
+ assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent
assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent
assert (
"AWS-SageMaker-Notebook-Instance"
not in sess.sagemaker_runtime_client._client_config.user_agent
)
+ assert (
+ "AWS-SageMaker-Notebook-Instance"
+ not in sess.sagemaker_metrics_client._client_config.user_agent
+ )
def test_user_agent_injected_with_nbi(boto_session):
@@ -607,10 +612,14 @@ def test_user_agent_injected_with_nbi(boto_session):
assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent
assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent
+ assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent
assert "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_client._client_config.user_agent
assert (
"AWS-SageMaker-Notebook-Instance" in sess.sagemaker_runtime_client._client_config.user_agent
)
+ assert (
+ "AWS-SageMaker-Notebook-Instance" in sess.sagemaker_metrics_client._client_config.user_agent
+ )
def test_user_agent_injected_with_nbi_ioerror(boto_session):
@@ -625,11 +634,16 @@ def test_user_agent_injected_with_nbi_ioerror(boto_session):
assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_client._client_config.user_agent
assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_runtime_client._client_config.user_agent
+ assert "AWS-SageMaker-Python-SDK" in sess.sagemaker_metrics_client._client_config.user_agent
assert "AWS-SageMaker-Notebook-Instance" not in sess.sagemaker_client._client_config.user_agent
assert (
"AWS-SageMaker-Notebook-Instance"
not in sess.sagemaker_runtime_client._client_config.user_agent
)
+ assert (
+ "AWS-SageMaker-Notebook-Instance"
+ not in sess.sagemaker_metrics_client._client_config.user_agent
+ )
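These assertions extend the user-agent checks to the new sagemaker_metrics_client. For orientation, appending such a suffix to a client's user agent is what botocore's user_agent_extra option does; a hedged illustration (the suffix string and region are made up, only the mechanism matters):

    import boto3
    from botocore.config import Config

    # Create a client whose user agent carries an extra suffix, similar in
    # spirit to the strings asserted above.
    client = boto3.client(
        "sagemaker",
        region_name="us-west-2",
        config=Config(user_agent_extra="AWS-SageMaker-Python-SDK/2.122.0"),
    )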
def test_training_input_all_defaults():
@@ -700,6 +714,7 @@ def test_training_input_all_arguments():
"ExperimentName": "dummyExp",
"TrialName": "dummyT",
"TrialComponentDisplayName": "dummyTC",
+ "RunName": "dummyRN",
}
MODEL_CLIENT_CONFIG = {"InvocationsMaxRetries": 2, "InvocationsTimeoutInSeconds": 60}
@@ -941,6 +956,13 @@ def test_train_pack_to_request(sagemaker_session):
],
}
+SAMPLE_HYPERBAND_STRATEGY_CONFIG = {
+ "HyperbandStrategyConfig": {
+ "MinResource": 1,
+ "MaxResource": 10,
+ }
+}
+
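For orientation, HyperbandStrategyConfig nests under StrategyConfig in the CreateHyperParameterTuningJob request that the test below intercepts; MinResource and MaxResource bound the per-training-job resource budget Hyperband may allocate. A sketch of just the strategy-related fields (per the SageMaker API shape; everything else elided):

    tuning_job_config = {
        "HyperParameterTuningJobConfig": {
            "Strategy": "Hyperband",  # the strategy that consumes this config
            "StrategyConfig": {
                "HyperbandStrategyConfig": {"MinResource": 1, "MaxResource": 10}
            },
            # ... objective, parameter ranges, resource limits, etc.
        }
    }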
@pytest.mark.parametrize(
"warm_start_type, parents",
@@ -1167,6 +1189,47 @@ def assert_create_tuning_job_request(**kwrags):
)
+def test_tune_with_strategy_config(sagemaker_session):
+    def assert_create_tuning_job_request(**kwargs):
+ assert (
+ kwrags["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][
+ "MinResource"
+ ]
+ == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MinResource"]
+ )
+ assert (
+ kwrags["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][
+ "MaxResource"
+ ]
+ == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MaxResource"]
+ )
+
+ sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = (
+ assert_create_tuning_job_request
+ )
+ sagemaker_session.tune(
+ job_name="dummy-tuning-1",
+ strategy="Bayesian",
+ objective_type="Maximize",
+ objective_metric_name="val-score",
+ max_jobs=100,
+ max_parallel_jobs=5,
+ parameter_ranges=SAMPLE_PARAM_RANGES,
+ static_hyperparameters=STATIC_HPs,
+ image_uri="dummy-image-1",
+ input_mode="File",
+ metric_definitions=SAMPLE_METRIC_DEF,
+ role=EXPANDED_ROLE,
+ input_config=SAMPLE_INPUT,
+ output_config=SAMPLE_OUTPUT,
+ resource_config=RESOURCE_CONFIG,
+ stop_condition=SAMPLE_STOPPING_CONDITION,
+ tags=None,
+ warm_start_config=None,
+ strategy_config=SAMPLE_HYPERBAND_STRATEGY_CONFIG,
+ )
+
+
def test_tune_with_encryption_flag(sagemaker_session):
def assert_create_tuning_job_request(**kwrags):
assert (
@@ -2739,6 +2802,35 @@ def test_feature_metadata_describe(sagemaker_session):
)
+def test_list_feature_groups(sagemaker_session):
+ expected_list_feature_groups_args = {
+ "NameContains": "MyFeatureGroup",
+ "FeatureGroupStatusEquals": "Created",
+ "OfflineStoreStatusEquals": "Active",
+ "CreationTimeAfter": datetime.datetime(2020, 12, 1),
+ "CreationTimeBefore": datetime.datetime(2022, 7, 1),
+ "SortOrder": "Ascending",
+ "SortBy": "Name",
+ "MaxResults": 50,
+ "NextToken": "token",
+ }
+ sagemaker_session.list_feature_groups(
+ name_contains="MyFeatureGroup",
+ feature_group_status_equals="Created",
+ offline_store_status_equals="Active",
+ creation_time_after=datetime.datetime(2020, 12, 1),
+ creation_time_before=datetime.datetime(2022, 7, 1),
+ sort_order="Ascending",
+ sort_by="Name",
+ max_results=50,
+ next_token="token",
+ )
+    sagemaker_session.sagemaker_client.list_feature_groups.assert_called_once()
+    sagemaker_session.sagemaker_client.list_feature_groups.assert_called_with(
+ **expected_list_feature_groups_args
+ )
+
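Because list_feature_groups forwards NextToken, a caller can page through every feature group. A hedged usage sketch, assuming the boto3 ListFeatureGroups response shape (FeatureGroupSummaries plus an optional NextToken); the helper name is hypothetical:

    def list_all_feature_groups(sagemaker_client, name_contains=None):
        # Accumulate summaries across pages until the service stops
        # returning a NextToken.
        summaries, token = [], None
        while True:
            kwargs = {"NameContains": name_contains, "NextToken": token}
            response = sagemaker_client.list_feature_groups(
                **{k: v for k, v in kwargs.items() if v is not None}
            )
            summaries.extend(response.get("FeatureGroupSummaries", []))
            token = response.get("NextToken")
            if not token:
                return summaries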
+
def test_start_query_execution(sagemaker_session):
athena_mock = Mock()
sagemaker_session.boto_session.client(
diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py
index 13cc755336..c3e984e0b7 100644
--- a/tests/unit/test_sklearn.py
+++ b/tests/unit/test_sklearn.py
@@ -51,6 +51,7 @@
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
}
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 0eb81be584..8bcbed41c2 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -25,10 +25,12 @@
from boto3 import exceptions
import botocore
import pytest
-from mock import call, patch, Mock, MagicMock
+from mock import call, patch, Mock, MagicMock, PropertyMock
import sagemaker
+from sagemaker.experiments._run_context import _RunContext
from sagemaker.session_settings import SessionSettings
+from sagemaker.utils import retry_with_backoff, check_and_get_run_experiment_config
from tests.unit.sagemaker.workflow.helpers import CustomStep
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
@@ -795,3 +797,63 @@ def test_start_waiting(capfd):
out, _ = capfd.readouterr()
assert "." * sagemaker.utils.WAITING_DOT_NUMBER in out
+
+
+def test_retry_with_backoff():
+ callable_func = Mock()
+
+ # Invalid input
+ with pytest.raises(ValueError) as value_err:
+ retry_with_backoff(callable_func, 0)
+ assert "The num_attempts must be >= 1" in str(value_err)
+ callable_func.assert_not_called()
+
+ # All retries fail
+ run_err_msg = "Test Retry Error"
+ callable_func.side_effect = RuntimeError(run_err_msg)
+ with pytest.raises(RuntimeError) as run_err:
+ retry_with_backoff(callable_func, 2)
+ assert run_err_msg in str(run_err)
+
+ # One retry passes
+ func_return_val = "Test Return"
+ callable_func.side_effect = [RuntimeError(run_err_msg), func_return_val]
+ assert retry_with_backoff(callable_func, 2) == func_return_val
+
+ # No retry
+ callable_func.side_effect = None
+ callable_func.return_value = func_return_val
+ assert retry_with_backoff(callable_func, 2) == func_return_val
+
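The assertions above pin down retry_with_backoff's contract: num_attempts must be at least 1, intermediate failures are retried, and the last error is re-raised once attempts are exhausted. A minimal sketch consistent with that contract (the backoff schedule is an assumption; the SDK's actual implementation may differ):

    import time

    def retry_with_backoff_sketch(callable_func, num_attempts=3):
        if num_attempts < 1:
            raise ValueError(f"The num_attempts must be >= 1, got {num_attempts}")
        for attempt in range(num_attempts):
            try:
                return callable_func()
            except Exception:
                if attempt == num_attempts - 1:
                    raise  # attempts exhausted; surface the last error
                time.sleep(2 ** attempt)  # assumed schedule: 1s, 2s, 4s, ...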
+
+def test_check_and_get_run_experiment_config():
+ supplied_exp_cfg = {"ExperimentName": "my-supplied-exp-name", "RunName": "my-supplied-run-name"}
+ run_exp_cfg = {"ExperimentName": "my-run-exp-name", "RunName": "my-run-run-name"}
+
+ # No user supplied exp config and no current Run
+ assert not _RunContext.get_current_run()
+ exp_cfg1 = check_and_get_run_experiment_config(None)
+ assert exp_cfg1 is None
+
+ # With user supplied exp config and no current Run
+ assert not _RunContext.get_current_run()
+ exp_cfg2 = check_and_get_run_experiment_config(supplied_exp_cfg)
+ assert exp_cfg2 == supplied_exp_cfg
+
+ run = Mock()
+ type(run).experiment_config = PropertyMock(return_value=run_exp_cfg)
+ _RunContext.add_run_object(run)
+
+ try:
+ # No user supplied exp config and with current Run
+ assert _RunContext.get_current_run().experiment_config == run_exp_cfg
+ exp_cfg3 = check_and_get_run_experiment_config(None)
+ assert exp_cfg3 == run_exp_cfg
+
+ # With user supplied exp config and current Run
+ assert _RunContext.get_current_run().experiment_config == run_exp_cfg
+ exp_cfg4 = check_and_get_run_experiment_config(supplied_exp_cfg)
+ assert exp_cfg4 == supplied_exp_cfg
+ finally:
+ # Clean up the global static variable in case it affects other tests
+ _RunContext.drop_current_run()
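The four cases above reduce to one precedence rule: an explicitly supplied experiment config always wins, and the current Run's config is only a fallback. As a sketch (hypothetical helper, mirroring the behavior under test rather than the SDK's exact code):

    def resolve_experiment_config(supplied_config, current_run):
        # Caller-supplied config takes precedence; otherwise inherit from
        # the ambient Run context, if any.
        if supplied_config is not None:
            return supplied_config
        return current_run.experiment_config if current_run else None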
diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py
index 82f27c19ae..d58c4992cd 100644
--- a/tests/unit/test_xgboost.py
+++ b/tests/unit/test_xgboost.py
@@ -54,6 +54,7 @@
"ExperimentName": "exp",
"TrialName": "trial",
"TrialComponentDisplayName": "tc",
+ "RunName": "rn",
}
diff --git a/tox.ini b/tox.ini
index 2d5fdf0b40..3a398ca51d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -73,6 +73,8 @@ passenv =
# Can be used to specify which tests to run, e.g.: tox -- -s
commands =
python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')"
+ pip install 'apache-airflow==2.4.1' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.4.1/constraints-3.10.txt"
+
pytest --cov=sagemaker --cov-append {posargs}
{env:IGNORE_COVERAGE:} coverage report -i --fail-under=86
deps = .[test]