From e520c3cca36b8f0a5e6131c11d7c084381ef727a Mon Sep 17 00:00:00 2001 From: Leo Toikka Date: Fri, 4 Nov 2022 10:18:49 +0200 Subject: [PATCH 01/43] fix: add environment key to model container only if it is not empty This caused a problem when trying to deploy a model package from the model registry to an endpoint. No environment variables were provided to the model package, but the environment was still registered as an empty mapping. When trying to deploy the model package as an endpoint and providing some environment variables to the endpoint, the API returned an error: Environment variable map cannot be specified both at Model Package {} and Model {...}. Fixes: aws/sagemaker-python-sdk#3397 --- src/sagemaker/session.py | 6 +++--- tests/unit/sagemaker/workflow/test_model_step.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 00797c9ea0..75841b0a83 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -4792,9 +4792,9 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None, dict[str, str]: A complete container definition object usable with the CreateModel API if passed via `PrimaryContainers` field. """ - if env is None: - env = {} - c_def = {"Image": image_uri, "Environment": env} + c_def = {"Image": image_uri} + if env: + c_def["Environment"] = env if model_data_url: c_def["ModelDataUrl"] = model_data_url if container_mode: diff --git a/tests/unit/sagemaker/workflow/test_model_step.py b/tests/unit/sagemaker/workflow/test_model_step.py index 080e70ca62..b2c049be62 100644 --- a/tests/unit/sagemaker/workflow/test_model_step.py +++ b/tests/unit/sagemaker/workflow/test_model_step.py @@ -674,7 +674,7 @@ def test_conditional_model_create_and_regis( container = arguments["PrimaryContainer"] assert container["Image"] == _IMAGE_URI assert container["ModelDataUrl"] == {"Get": "Parameters.ModelData"} - assert not container.get("Environment", {}) + assert container.get("Environment") is None else: raise Exception("A step exists in the collection of an invalid type.") adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list
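A minimal sketch of the behavior change above (illustrative only, not part of the patch; it assumes the patched sagemaker.session module is importable, and the image name and env values are made-up placeholders):

    from sagemaker.session import container_def

    # With env omitted (or empty), the "Environment" key is now left out
    # entirely instead of being registered as an empty mapping:
    assert container_def("my-image:latest") == {"Image": "my-image:latest"}

    # A non-empty env mapping is still included as before:
    assert container_def("my-image:latest", env={"LOG_LEVEL": "debug"}) == {
        "Image": "my-image:latest",
        "Environment": {"LOG_LEVEL": "debug"},
    }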
= str(kwargs["MaxPayloadInMB"]) @@ -591,18 +591,15 @@ def serve(self): instance_count = self.production_variant["InitialInstanceCount"] accelerator_type = self.production_variant.get("AcceleratorType") + environment = self.primary_container.get("Environment", {}) if accelerator_type == "local_sagemaker_notebook": - self.primary_container["Environment"][ - "SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT" - ] = "true" + environment["SAGEMAKER_INFERENCE_ACCELERATOR_PRESENT"] = "true" self.create_time = datetime.datetime.now() self.container = _SageMakerContainer( instance_type, instance_count, image, self.local_session ) - self.container.serve( - self.primary_container["ModelDataUrl"], self.primary_container["Environment"] - ) + self.container.serve(self.primary_container["ModelDataUrl"], environment) serving_port = get_config_value("local.serving_port", self.local_session.config) or 8080 _wait_for_serving_container(serving_port) diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index 0b04d3c8bc..dd27e01962 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -134,7 +134,7 @@ def test_prepare_container_def_with_model_data(): model = Model(MODEL_IMAGE) container_def = model.prepare_container_def(INSTANCE_TYPE, "ml.eia.medium") - expected = {"Image": MODEL_IMAGE, "Environment": {}} + expected = {"Image": MODEL_IMAGE} assert expected == container_def @@ -158,7 +158,6 @@ def test_prepare_container_def_with_image_config(): expected = { "Image": MODEL_IMAGE, "ImageConfig": {"RepositoryAccessMode": "Vpc"}, - "Environment": {}, } container_def = model.prepare_container_def() diff --git a/tests/unit/sagemaker/tensorflow/test_tfs.py b/tests/unit/sagemaker/tensorflow/test_tfs.py index 67b69efc44..5bd6198edd 100644 --- a/tests/unit/sagemaker/tensorflow/test_tfs.py +++ b/tests/unit/sagemaker/tensorflow/test_tfs.py @@ -89,7 +89,7 @@ def test_tfs_model(retrieve_image_uri, sagemaker_session, tensorflow_inference_v serverless_inference_config=None, ) assert IMAGE == cdef["Image"] - assert {} == cdef["Environment"] + assert cdef.get("Environment") is None predictor = model.deploy(INSTANCE_COUNT, INSTANCE_TYPE) assert isinstance(predictor, TensorFlowPredictor) @@ -485,7 +485,6 @@ def test_register_tfs_model_auto_infer_framework(sagemaker_session, tensorflow_i "containers": [ { "Image": image_uri, - "Environment": ANY, "ModelDataUrl": ANY, "Framework": "TENSORFLOW", "FrameworkVersion": tensorflow_inference_version, diff --git a/tests/unit/sagemaker/workflow/test_airflow.py b/tests/unit/sagemaker/workflow/test_airflow.py index fa4b4d2e55..287a273441 100644 --- a/tests/unit/sagemaker/workflow/test_airflow.py +++ b/tests/unit/sagemaker/workflow/test_airflow.py @@ -1014,7 +1014,6 @@ def test_amazon_alg_model_config(sagemaker_session): "ModelName": "pca-%s" % TIME_STAMP, "PrimaryContainer": { "Image": "174872318107.dkr.ecr.us-west-2.amazonaws.com/pca:1", - "Environment": {}, "ModelDataUrl": "{{ model_data }}", }, "ExecutionRoleArn": "{{ role }}", @@ -1108,7 +1107,6 @@ def test_model_config_from_amazon_alg_estimator(sagemaker_session): "ModelName": "knn-%s" % TIME_STAMP, "PrimaryContainer": { "Image": "174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1", - "Environment": {}, "ModelDataUrl": "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Tuning']['BestTrainingJob']" "['TrainingJobName'] }}/output/model.tar.gz", }, @@ -1309,7 +1307,6 @@ def test_transform_config_from_amazon_alg_estimator(sagemaker_session): "ModelName": 
"knn-%s" % TIME_STAMP, "PrimaryContainer": { "Image": "174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1", - "Environment": {}, "ModelDataUrl": "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Training']['TrainingJobName'] }}" "/output/model.tar.gz", }, @@ -1413,7 +1410,6 @@ def test_deploy_amazon_alg_model_config(sagemaker_session): "ModelName": "pca-%s" % TIME_STAMP, "PrimaryContainer": { "Image": "174872318107.dkr.ecr.us-west-2.amazonaws.com/pca:1", - "Environment": {}, "ModelDataUrl": "{{ model_data }}", }, "ExecutionRoleArn": "{{ role }}", @@ -1549,7 +1545,6 @@ def test_deploy_config_from_amazon_alg_estimator(sagemaker_session): "ModelName": "knn-%s" % TIME_STAMP, "PrimaryContainer": { "Image": "174872318107.dkr.ecr.us-west-2.amazonaws.com/knn:1", - "Environment": {}, "ModelDataUrl": "s3://output/{{ ti.xcom_pull(task_ids='task_id')['Tuning']['BestTrainingJob']" "['TrainingJobName'] }}/output/model.tar.gz", }, diff --git a/tests/unit/sagemaker/workflow/test_model_step.py b/tests/unit/sagemaker/workflow/test_model_step.py index b2c049be62..e428750132 100644 --- a/tests/unit/sagemaker/workflow/test_model_step.py +++ b/tests/unit/sagemaker/workflow/test_model_step.py @@ -968,7 +968,7 @@ def _verify_register_model_container_definition( containers = request["InferenceSpecification"]["Containers"] assert len(containers) == 1 isinstance(containers[0].pop("ModelDataUrl"), expected_model_data_type) - container_env = containers[0]["Environment"] + container_env = containers[0].get("Environment", {}) assert container_env.pop(_SAGEMAKER_PROGRAM, None) == expected_program submit_dir = container_env.pop(_SAGEMAKER_SUBMIT_DIRECTORY, None) if submit_dir and not submit_dir.startswith("s3://"): diff --git a/tests/unit/sagemaker/workflow/test_step_collections.py b/tests/unit/sagemaker/workflow/test_step_collections.py index 2bf47a79d0..1de56af998 100644 --- a/tests/unit/sagemaker/workflow/test_step_collections.py +++ b/tests/unit/sagemaker/workflow/test_step_collections.py @@ -1173,7 +1173,6 @@ def test_estimator_transformer(estimator): "Arguments": { "ExecutionRoleArn": "DummyRole", "PrimaryContainer": { - "Environment": {}, "Image": "fakeimage", "ModelDataUrl": "s3://my-bucket/model.tar.gz", }, @@ -1288,7 +1287,6 @@ def test_estimator_transformer_with_model_repack_with_estimator(estimator): assert arguments == { "ExecutionRoleArn": "DummyRole", "PrimaryContainer": { - "Environment": {}, "Image": "fakeimage", }, } diff --git a/tests/unit/sagemaker/workflow/test_steps.py b/tests/unit/sagemaker/workflow/test_steps.py index 6161537220..40f3ab40a0 100644 --- a/tests/unit/sagemaker/workflow/test_steps.py +++ b/tests/unit/sagemaker/workflow/test_steps.py @@ -847,7 +847,7 @@ def test_create_model_step(sagemaker_session): "DependsOn": ["TestStep", "SecondTestStep"], "Arguments": { "ExecutionRoleArn": "DummyRole", - "PrimaryContainer": {"Environment": {}, "Image": "fakeimage"}, + "PrimaryContainer": {"Image": "fakeimage"}, }, } assert step.properties.ModelName.expr == {"Get": "Steps.MyCreateModelStep.ModelName"} diff --git a/tests/unit/test_create_deploy_entities.py b/tests/unit/test_create_deploy_entities.py index be59296245..cdae80774f 100644 --- a/tests/unit/test_create_deploy_entities.py +++ b/tests/unit/test_create_deploy_entities.py @@ -23,7 +23,7 @@ ROLE = "myimrole" EXPANDED_ROLE = "arn:aws:iam::111111111111:role/ExpandedRole" IMAGE = "myimage" -FULL_CONTAINER_DEF = {"Environment": {}, "Image": IMAGE, "ModelDataUrl": "s3://mybucket/mymodel"} +FULL_CONTAINER_DEF = {"Image": IMAGE, "ModelDataUrl": 
"s3://mybucket/mymodel"} VPC_CONFIG = {"Subnets": ["subnet-foo"], "SecurityGroups": ["sg-foo"]} INITIAL_INSTANCE_COUNT = 1 INSTANCE_TYPE = "ml.c4.xlarge" @@ -57,7 +57,7 @@ def test_create_model_expand_primary_container(sagemaker_session): sagemaker_session.create_model(name=MODEL_NAME, role=ROLE, container_defs=IMAGE) _1, _2, create_model_kwargs = sagemaker_session.sagemaker_client.create_model.mock_calls[0] - assert create_model_kwargs["PrimaryContainer"] == {"Environment": {}, "Image": IMAGE} + assert create_model_kwargs["PrimaryContainer"] == {"Image": IMAGE} def test_create_endpoint_config(sagemaker_session): diff --git a/tests/unit/test_endpoint_from_model_data.py b/tests/unit/test_endpoint_from_model_data.py index 64804e2f7d..67db4c2e91 100644 --- a/tests/unit/test_endpoint_from_model_data.py +++ b/tests/unit/test_endpoint_from_model_data.py @@ -25,7 +25,7 @@ ACCELERATOR_TYPE = "ml.eia.medium" S3_MODEL_ARTIFACTS = "s3://mybucket/mymodel" DEPLOY_IMAGE = "mydeployimage" -CONTAINER_DEF = {"Environment": {}, "Image": DEPLOY_IMAGE, "ModelDataUrl": S3_MODEL_ARTIFACTS} +CONTAINER_DEF = {"Image": DEPLOY_IMAGE, "ModelDataUrl": S3_MODEL_ARTIFACTS} VPC_CONFIG = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]} DEPLOY_ROLE = "mydeployrole" ENV_VARS = {"PYTHONUNBUFFERED": "TRUE", "some": "nonsense"} diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 34e6a43fcf..24e7daf242 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -2731,7 +2731,6 @@ def test_fit_deploy_tags_in_estimator(name_from_base, sagemaker_session): role="DummyRole", container_defs={ "ModelDataUrl": "s3://bucket/model.tar.gz", - "Environment": {}, "Image": "fakeimage", }, enable_network_isolation=False, @@ -2781,7 +2780,6 @@ def test_fit_deploy_tags(name_from_base, sagemaker_session): role="DummyRole", container_defs={ "ModelDataUrl": "s3://bucket/model.tar.gz", - "Environment": {}, "Image": "fakeimage", }, enable_network_isolation=False, diff --git a/tests/unit/test_pipeline_model.py b/tests/unit/test_pipeline_model.py index f4fb892d21..2feca292b1 100644 --- a/tests/unit/test_pipeline_model.py +++ b/tests/unit/test_pipeline_model.py @@ -339,7 +339,6 @@ def test_network_isolation(tfo, time, sagemaker_session): }, { "Image": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4", - "Environment": {}, "ModelDataUrl": "s3://bucket/model_2.tar.gz", }, ], diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8958210092..ddf0ae688c 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1684,7 +1684,6 @@ def test_logs_for_transform_job_full_lifecycle(time, cw, sagemaker_session_full_ MODEL_NAME = "some-model" PRIMARY_CONTAINER = { - "Environment": {}, "Image": IMAGE, "ModelDataUrl": "s3://sagemaker-123/output/jobname/model/model.tar.gz", } From c97c4678b89ff89bda6f1ea4573803d7e1bcb947 Mon Sep 17 00:00:00 2001 From: Kevin Date: Fri, 2 Dec 2022 12:48:09 -0800 Subject: [PATCH 03/43] fix: type hint of PySparkProcessor __init__ (#3297) From de589419595fbf7bf76e55745f454864cc5998be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Perez?= Date: Fri, 2 Dec 2022 22:01:39 +0100 Subject: [PATCH 04/43] fix: fix PySparkProcessor __init__ params type (#3354) From 41dd3305c2673a4f85e54eec9858f37393c89431 Mon Sep 17 00:00:00 2001 From: Shreya Pandit Date: Fri, 2 Dec 2022 13:18:14 -0800 Subject: [PATCH 05/43] fix: Allow Py 3.7 for MMS Test Docker env (#3080) Co-authored-by: Mufaddal Rohawala --- 
tests/data/multimodel/container/Dockerfile | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/data/multimodel/container/Dockerfile b/tests/data/multimodel/container/Dockerfile index 4792a429c1..71c38a6605 100644 --- a/tests/data/multimodel/container/Dockerfile +++ b/tests/data/multimodel/container/Dockerfile @@ -1,4 +1,5 @@ -FROM public.ecr.aws/ubuntu/ubuntu:18.04 +# added latest image from https://gallery.ecr.aws/lts/ubuntu +FROM public.ecr.aws/ubuntu/ubuntu:22.04 # Set a docker label to advertise multi-model support on the container LABEL com.amazonaws.sagemaker.capabilities.multi-models=true @@ -15,7 +16,7 @@ RUN apt-get update && \ curl \ vim \ && rm -rf /var/lib/apt/lists/* \ - && curl -O https://bootstrap.pypa.io/pip/3.6/get-pip.py \ + && curl -O https://bootstrap.pypa.io/pip/get-pip.py \ && python3 get-pip.py RUN update-alternatives --install /usr/bin/python python /usr/bin/python3 1 From 1e23a3f6a7cf554aa537c5c4e21e35548053a6ee Mon Sep 17 00:00:00 2001 From: maldil Date: Fri, 2 Dec 2022 13:19:59 -0800 Subject: [PATCH 06/43] refactoring: using with statement (#3286) --- src/sagemaker/git_utils.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/git_utils.py b/src/sagemaker/git_utils.py index 80bd62d5be..c424753286 100644 --- a/src/sagemaker/git_utils.py +++ b/src/sagemaker/git_utils.py @@ -279,9 +279,8 @@ def _run_clone_command(repo_url, dest_dir): subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) elif repo_url.startswith("git@"): with tempfile.NamedTemporaryFile() as sshnoprompt: - write_pipe = open(sshnoprompt.name, "w") - write_pipe.write("ssh -oBatchMode=yes $@") - write_pipe.close() + with open(sshnoprompt.name, "w") as write_pipe: + write_pipe.write("ssh -oBatchMode=yes $@") os.chmod(sshnoprompt.name, 0o511) my_env["GIT_SSH"] = sshnoprompt.name subprocess.check_call(["git", "clone", repo_url, dest_dir], env=my_env) From 19efadf043678a6c7da4122368d6141e1ec2df10 Mon Sep 17 00:00:00 2001 From: Shreya Pandit Date: Fri, 2 Dec 2022 13:21:34 -0800 Subject: [PATCH 07/43] Update local_requirements.txt PyYAML version (#3095) Co-authored-by: Basil Beirouti Co-authored-by: Kalyani Nikure <110067132+knikure@users.noreply.github.com> --- requirements/extras/local_requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/extras/local_requirements.txt b/requirements/extras/local_requirements.txt index 5304d82b2a..5f2c85c2fe 100644 --- a/requirements/extras/local_requirements.txt +++ b/requirements/extras/local_requirements.txt @@ -1,4 +1,4 @@ urllib3==1.26.8 docker-compose==1.29.2 docker>=5.0.2,<7.0.0 -PyYAML==5.4.1 +PyYAML==6.0.0 From 76f7782db112b38cb7e058dffb1508f2d34fb50b Mon Sep 17 00:00:00 2001 From: arjkesh <33526713+arjkesh@users.noreply.github.com> Date: Fri, 2 Dec 2022 13:22:35 -0800 Subject: [PATCH 08/43] feature: Update TF 2.9 and TF 2.10 inference DLCs (#3465) --- .../image_uri_config/tensorflow.json | 66 ++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 6a01c3e3e6..0122dcd3ca 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -285,7 +285,9 @@ "2.5": "2.5.1", "2.6": "2.6.3", "2.7": "2.7.0", - "2.8": "2.8.0" + "2.8": "2.8.0", + "2.9": "2.9.2", + "2.10": "2.10.0" }, "versions": { "1.10.0": { @@ -1468,6 +1470,68 @@ "us-west-2": "763104351884" }, "repository":
"tensorflow-inference" + }, + "2.9.2": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, + "2.10.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" } } }, From fde07388dc26cb270a0a0dfba91439c64e87751a Mon Sep 17 00:00:00 2001 From: Keshav Chandak Date: Sat, 3 Dec 2022 03:41:10 +0530 Subject: [PATCH 09/43] feature: Added transform with monitoring pipeline step in transformer (#3438) Co-authored-by: Keshav Chandak --- src/sagemaker/transformer.py | 158 +++++++++++++++++++++++++++++++- tests/integ/test_transformer.py | 66 ++++++++++++- 2 files changed, 220 insertions(+), 4 deletions(-) diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index cfcc637b99..97278abdd0 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -14,14 +14,17 @@ from __future__ import absolute_import from typing import Union, Optional, List, Dict -from botocore import exceptions +import logging +import copy +import time +from botocore import exceptions from sagemaker.job import _Job -from sagemaker.session import Session +from sagemaker.session import Session, get_execution_role from sagemaker.inputs import BatchDataCaptureConfig from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.functions import Join -from sagemaker.workflow.pipeline_context import runnable_by_pipeline +from sagemaker.workflow.pipeline_context import runnable_by_pipeline, PipelineSession from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.utils import base_name_from_image, name_from_base @@ -266,6 +269,155 @@ def transform( if wait: self.latest_transform_job.wait(logs=logs) + def transform_with_monitoring( + self, + monitoring_config, + monitoring_resource_config, + data: 
str, + data_type: str = "S3Prefix", + content_type: str = None, + compression_type: str = None, + split_type: str = None, + input_filter: str = None, + output_filter: str = None, + join_source: str = None, + model_client_config: Dict[str, str] = None, + batch_data_capture_config: BatchDataCaptureConfig = None, + monitor_before_transform: bool = False, + supplied_baseline_statistics: str = None, + supplied_baseline_constraints: str = None, + wait: bool = True, + pipeline_name: str = None, + role: str = None, + ): + """Runs a transform job with a monitoring job. + + Note that this function will not start a transform job immediately; + instead, it will create a SageMaker Pipeline and execute it. + If you provide an existing pipeline_name, no new pipeline will be created; otherwise, + each transform_with_monitoring call will create and execute a new pipeline. + + Args: + monitoring_config (Union[ + `sagemaker.workflow.quality_check_step.QualityCheckConfig`, + `sagemaker.workflow.quality_check_step.ClarifyCheckConfig` + ]): the monitoring configuration used to run model monitoring. + monitoring_resource_config (`sagemaker.workflow.check_job_config.CheckJobConfig`): + the check job (processing job) cluster resource configuration. + data (str): Input data location in S3 for the transform job. + data_type (str): What the S3 location defines (default: 'S3Prefix'). + Valid values: + * 'S3Prefix' - the S3 URI defines a key name prefix. All objects with this prefix + will be used as inputs for the transform job. + * 'ManifestFile' - the S3 URI points to a single manifest file listing each S3 + object to use as an input for the transform job. + content_type (str): MIME type of the input data (default: None). + compression_type (str): Compression type of the input data, if + compressed (default: None). Valid values: 'Gzip', None. + split_type (str): The record delimiter for the input object + (default: 'None'). Valid values: 'None', 'Line', 'RecordIO', and + 'TFRecord'. + input_filter (str): A JSONPath to select a portion of the input to + pass to the algorithm container for inference. If you omit the + field, it gets the value '$', representing the entire input. + For CSV data, each row is taken as a JSON array, + so only index-based JSONPaths can be applied, e.g. $[0], $[1:]. + CSV data should follow the `RFC format `_. + See `Supported JSONPath Operators + `_ + for a table of supported JSONPath operators. + For more information, see the SageMaker API documentation for + `CreateTransformJob + `_. + Some examples: "$[1:]", "$.features" (default: None). + output_filter (str): A JSONPath to select a portion of the + joined/original output to return as the output. + For more information, see the SageMaker API documentation for + `CreateTransformJob + `_. + Some examples: "$[1:]", "$.prediction" (default: None). + join_source (str): The source of data to be joined to the transform + output. It can be set to 'Input' meaning the entire input record + will be joined to the inference result. You can use OutputFilter + to select the useful portion before uploading to S3. (default: + None). Valid values: Input, None. + model_client_config (dict[str, str]): Model configuration. + Dictionary contains two optional keys, + 'InvocationsTimeoutInSeconds', and 'InvocationsMaxRetries'. + (default: ``None``).
+ batch_data_capture_config (BatchDataCaptureConfig): Configuration object which + specifies the configurations related to the batch data capture for the transform job + (default: ``None``). + monitor_before_transform (bool): If True, runs the check step before the transform + job when using a data quality or model explainability monitoring type. + supplied_baseline_statistics (Union[str, PipelineVariable]): The S3 path + to the supplied statistics object representing the statistics JSON file + which will be used for the drift check (default: None). + supplied_baseline_constraints (Union[str, PipelineVariable]): The S3 path + to the supplied constraints object representing the constraints JSON file + which will be used for the drift check (default: None). + wait (bool): Whether to wait for the pipeline execution to complete. + pipeline_name (str): The name of the Pipeline for the monitoring and transform steps. + role (str): Execution role ARN. + """ + + transformer = self + if not isinstance(self.sagemaker_session, PipelineSession): + sagemaker_session = self.sagemaker_session + self.sagemaker_session = None + transformer = copy.deepcopy(self) + transformer.sagemaker_session = PipelineSession() + self.sagemaker_session = sagemaker_session + + transform_step_args = transformer.transform( + data=data, + data_type=data_type, + content_type=content_type, + compression_type=compression_type, + split_type=split_type, + input_filter=input_filter, + output_filter=output_filter, + batch_data_capture_config=batch_data_capture_config, + join_source=join_source, + model_client_config=model_client_config, + ) + + from sagemaker.workflow.monitor_batch_transform_step import MonitorBatchTransformStep + + monitoring_batch_step = MonitorBatchTransformStep( + name="MonitorBatchTransformStep", + display_name="MonitorBatchTransformStep", + description="", + transform_step_args=transform_step_args, + monitor_configuration=monitoring_config, + check_job_configuration=monitoring_resource_config, + monitor_before_transform=monitor_before_transform, + supplied_baseline_constraints=supplied_baseline_constraints, + supplied_baseline_statistics=supplied_baseline_statistics, + ) + + pipeline_name = ( + pipeline_name if pipeline_name else f"TransformWithMonitoring{int(time.time())}" + ) + # if pipeline exists, just start the execution + from sagemaker.workflow.pipeline import Pipeline + + pipeline = Pipeline( + name=pipeline_name, + steps=[monitoring_batch_step], + sagemaker_session=transformer.sagemaker_session, + ) + pipeline.upsert(role_arn=role if role else get_execution_role()) + execution = pipeline.start() + if wait: + logging.info("Waiting for transform with monitoring to execute ...") + execution.wait() + return execution + def delete_model(self): """Delete the corresponding SageMaker model for this Transformer.""" self.sagemaker_session.delete_model(self.model_name) diff --git a/tests/integ/test_transformer.py b/tests/integ/test_transformer.py index a0e37ffc77..1de333b987 100644 --- a/tests/integ/test_transformer.py +++ b/tests/integ/test_transformer.py @@ -25,6 +25,7 @@ from sagemaker.transformer import Transformer from sagemaker.estimator import Estimator from sagemaker.inputs import BatchDataCaptureConfig +from sagemaker.xgboost import XGBoostModel from sagemaker.utils import unique_name_from_base from tests.integ import ( datasets, @@ -36,7 +37,7 @@ from tests.integ.timeout import timeout, timeout_and_delete_model_with_transformer from tests.integ.vpc_test_utils import get_or_create_vpc_resources -from sagemaker.model_monitor import DatasetFormat, Statistics +from sagemaker.model_monitor import DatasetFormat, Statistics, Constraints from sagemaker.workflow.check_job_config import CheckJobConfig from sagemaker.workflow.quality_check_step import ( @@ -645,3 +646,66 @@ def _create_transformer_and_transform_job( job_name=unique_name_from_base("test-transform"), ) return transformer + + +def test_transformer_and_monitoring_job( + pipeline_session, + sagemaker_session, + role, + pipeline_name, + check_job_config, + data_bias_check_config, +): + xgb_model_data_s3 = pipeline_session.upload_data( + path=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "xgb_model.tar.gz"), + key_prefix="integ-test-data/xgboost/model", + ) + data_bias_supplied_baseline_constraints = Constraints.from_file_path( + constraints_file_path=os.path.join( + DATA_DIR, "pipeline/clarify_check_step/data_bias/good_cases/analysis.json" + ), + sagemaker_session=sagemaker_session, + ).file_s3_uri + + xgb_model = XGBoostModel( + model_data=xgb_model_data_s3, + framework_version="1.3-1", + role=role, + sagemaker_session=sagemaker_session, + entry_point=os.path.join(os.path.join(DATA_DIR, "xgboost_abalone"), "inference.py"), + enable_network_isolation=True, + ) + + xgb_model.deploy(_INSTANCE_COUNT, _INSTANCE_TYPE) + + transform_output = f"s3://{sagemaker_session.default_bucket()}/{pipeline_name}Transform" + transformer = Transformer( + model_name=xgb_model.name, + strategy="SingleRecord", + instance_type="ml.m5.xlarge", + instance_count=1, + output_path=transform_output, + sagemaker_session=pipeline_session, + ) + + transform_input = pipeline_session.upload_data( + path=os.path.join(DATA_DIR, "xgboost_abalone", "abalone"), + key_prefix="integ-test-data/xgboost_abalone/abalone", + ) + + execution = transformer.transform_with_monitoring( + monitoring_config=data_bias_check_config, + monitoring_resource_config=check_job_config, + data=transform_input, + content_type="text/libsvm", + supplied_baseline_constraints=data_bias_supplied_baseline_constraints, + role=role, + ) + + execution_steps = execution.list_steps() + assert len(execution_steps) == 2 + + for execution_step in execution_steps: + assert execution_step["StepStatus"] == "Succeeded" + + xgb_model.delete_model()
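A condensed usage sketch of the new transform_with_monitoring API (illustrative only, not part of the patch), mirroring the integration test above; the transformer, check_job_config, and data_bias_check_config objects, the bucket, and the role ARN are assumed placeholders:

    # Builds a one-off SageMaker Pipeline containing a MonitorBatchTransformStep,
    # upserts and starts it, then blocks until it finishes (wait defaults to True).
    execution = transformer.transform_with_monitoring(
        monitoring_config=data_bias_check_config,
        monitoring_resource_config=check_job_config,
        data="s3://my-bucket/batch-input",  # assumed input location
        content_type="text/csv",
        role="arn:aws:iam::111122223333:role/SageMakerRole",  # assumed role ARN
    )
    assert all(step["StepStatus"] == "Succeeded" for step in execution.list_steps())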
From 7f9f3b04b6704a4d2378b5d9aa3d37de9db45729 Mon Sep 17 00:00:00 2001 From: Clayton Parnell <42805768+claytonparnell@users.noreply.github.com> Date: Fri, 2 Dec 2022 17:12:34 -0500 Subject: [PATCH 10/43] fix: Fix bug forcing uploaded tar to be named sourcedir (#3412) --- src/sagemaker/processing.py | 19 +++++++++++-------- tests/integ/test_xgboost.py | 20 ++++++++++++++++++++ 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index db6ce2badd..308783578d 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -1587,13 +1587,13 @@ def run( # type: ignore[override] framework script to run. Path (absolute or relative) to the local Python source file which should be executed as the entry point to training. When `code` is an S3 URI, ignore `source_dir`, - `dependencies, and `git_config`. If ``source_dir`` is specified, + `dependencies`, and `git_config`. If ``source_dir`` is specified, then ``code`` must point to a file located at the root of ``source_dir``. source_dir (str): Path (absolute, relative or an S3 URI) to a directory with any other processing source code dependencies aside from the entry point file (default: None). If ``source_dir`` is an S3 URI, it must - point to a tar.gz file. Structure within this directory are preserved - when processing on Amazon SageMaker (default: None). + point to a file named `sourcedir.tar.gz`. Structure within this directory + is preserved when processing on Amazon SageMaker (default: None). dependencies (list[str]): A list of paths to directories (absolute or relative) with any additional libraries that will be exported to the container (default: []). The library folders will be @@ -1730,12 +1730,15 @@ def _pack_and_upload_code( "sagemaker_session unspecified when creating your Processor to have one set up " "automatically." ) + if "/sourcedir.tar.gz" in estimator.uploaded_code.s3_prefix: + # Upload the bootstrapping code as s3://.../jobname/source/runproc.sh. + entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace( + "sourcedir.tar.gz", + "runproc.sh", + ) + else: + raise RuntimeError("S3 source_dir file must be named `sourcedir.tar.gz`.") - # Upload the bootstrapping code as s3://.../jobname/source/runproc.sh. - entrypoint_s3_uri = estimator.uploaded_code.s3_prefix.replace( - "sourcedir.tar.gz", - "runproc.sh", - ) script = estimator.uploaded_code.script_name s3_runproc_sh = S3Uploader.upload_string_as_file_body( self._generate_framework_script(script), diff --git a/tests/integ/test_xgboost.py b/tests/integ/test_xgboost.py index 733ab4665a..df06a8863a 100644 --- a/tests/integ/test_xgboost.py +++ b/tests/integ/test_xgboost.py @@ -40,6 +40,26 @@ def xgboost_training_job( ) +def test_sourcedir_naming( + sagemaker_session, + xgboost_latest_version, + xgboost_latest_py_version, + cpu_instance_type, +): + with pytest.raises(RuntimeError): + processor = XGBoostProcessor( + framework_version=xgboost_latest_version, + role=ROLE, + instance_count=1, + instance_type=cpu_instance_type, + sagemaker_session=sagemaker_session, + ) + processor.run( + source_dir="s3://bucket/deps.tar.gz", + code="main_script.py", + ) + + @pytest.mark.release def test_framework_processing_job_with_deps( sagemaker_session,
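A short sketch (illustrative only, not part of the patch) of the naming constraint this fix enforces; processor stands in for any configured FrameworkProcessor subclass, such as the XGBoostProcessor used in the test above, and the bucket and file names are made up:

    # Rejected after this patch: the S3 archive is not named sourcedir.tar.gz,
    # so run() raises RuntimeError instead of silently rewriting the upload path.
    processor.run(source_dir="s3://my-bucket/deps.tar.gz", code="main_script.py")

    # Accepted: the bootstrap entrypoint is uploaded alongside it as runproc.sh.
    processor.run(source_dir="s3://my-bucket/code/sourcedir.tar.gz", code="main_script.py")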
From 5d5976726cb8e0cf7143d86b4abb4b665842fd14 Mon Sep 17 00:00:00 2001 From: Navin Soni Date: Fri, 2 Dec 2022 14:32:01 -0800 Subject: [PATCH 11/43] feature: Add Code Owners file (#3503) Co-authored-by: Navin Soni --- CODEOWNERS | 1 + requirements/extras/local_requirements.txt | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 CODEOWNERS diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000000..7f7ac28644 --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @aws/sagemaker-ml-frameworks diff --git a/requirements/extras/local_requirements.txt b/requirements/extras/local_requirements.txt index 5f2c85c2fe..5304d82b2a 100644 --- a/requirements/extras/local_requirements.txt +++ b/requirements/extras/local_requirements.txt @@ -1,4 +1,4 @@ urllib3==1.26.8 docker-compose==1.29.2 docker>=5.0.2,<7.0.0 -PyYAML==6.0.0 +PyYAML==5.4.1 From 0f5cf1824c0b116c9b218c803f3b94a85e09fd45 Mon Sep 17 00:00:00 2001 From: ci Date: Sat, 3 Dec 2022 03:22:39 +0000 Subject: [PATCH 12/43] prepare release v2.119.0 --- CHANGELOG.md | 28 ++++++++++++++++++++++++++++ VERSION | 2 +- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 95e4a7b9cf..b8b3155231 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,33 @@ # Changelog +## v2.119.0 (2022-12-03) + +###
Features + + * Add Code Owners file + * Added transform with monitoring pipeline step in transformer + * Update TF 2.9 and TF 2.10 inference DLCs + * make estimator accept json file as modelparallel config + * SageMaker Training Compiler does not support p4de instances + * Add support for SparkML v3.3 + +### Bug Fixes and Other Changes + + * Fix bug forcing uploaded tar to be named sourcedir + * Update local_requirements.txt PyYAML version + * refactoring : using with statement + * Allow Py 3.7 for MMS Test Docker env + * fix PySparkProcessor __init__ params type + * type hint of PySparkProcessor __init__ + * Return ARM XGB/SKLearn tags if `image_scope` is `inference_graviton` + * Update scipy to 1.7.3 to support M1 development envs + * Fixing type hints for Spark processor that has instance type/count params in reverse order + * Add DeepAR ap-northeast-3 repository. + * Fix AsyncInferenceConfig documentation typo + * fix ml_inf to ml_inf1 in Neo multi-version support + * Fix type annotations + * add neo mvp region accounts + ## v2.118.0 (2022-12-01) ### Features diff --git a/VERSION b/VERSION index 34d47b7f52..23fe2bf317 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.118.1.dev0 +2.119.0 From f1f0013dc0375aa22805b3a59b82cd2b1a08d40a Mon Sep 17 00:00:00 2001 From: ci Date: Sat, 3 Dec 2022 03:22:41 +0000 Subject: [PATCH 13/43] update development version to v2.119.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 23fe2bf317..dda4128cf2 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.119.0 +2.119.1.dev0 From b7512bcccf7be59db0b14638db8bce0aaefd2a4d Mon Sep 17 00:00:00 2001 From: Leo Toikka Date: Mon, 5 Dec 2022 08:44:06 +0200 Subject: [PATCH 14/43] Attempt to fix multidatamodel.py "Environment" KeyError --- src/sagemaker/multidatamodel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/multidatamodel.py b/src/sagemaker/multidatamodel.py index 2cb6674ffd..c6dcf7a7df 100644 --- a/src/sagemaker/multidatamodel.py +++ b/src/sagemaker/multidatamodel.py @@ -139,7 +139,7 @@ def prepare_container_def( if self.model: container_definition = self.model.prepare_container_def(instance_type, accelerator_type) image_uri = container_definition["Image"] - environment = container_definition["Environment"] + environment = container_definition.get("Environment", {}) else: image_uri = self.image_uri environment = self.env From bb4b6897971a4e5ae0cbde948ef1682a64232b41 Mon Sep 17 00:00:00 2001 From: Radhika Bhat <78102284+RadhikaB-97@users.noreply.github.com> Date: Mon, 5 Dec 2022 10:06:58 -0800 Subject: [PATCH 15/43] feature: Add DXB region to frameworks by DLC (#3387) * Add DXB region * Remove change from neuron * Adding DXB to TF 2.1.0 and 2.1.1 --- src/sagemaker/image_uri_config/autogluon.json | 12 ++++ .../huggingface-training-compiler.json | 3 + .../image_uri_config/huggingface.json | 31 +++++++++ src/sagemaker/image_uri_config/mxnet.json | 13 ++++ src/sagemaker/image_uri_config/pytorch.json | 28 ++++++++ .../image_uri_config/tensorflow.json | 65 +++++++++++++++++++ 6 files changed, 152 insertions(+) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 3cc488c55d..0963520e02 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -26,6 +26,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", 
"us-east-1": "763104351884", "us-east-2": "763104351884", @@ -56,6 +57,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -86,6 +88,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -116,6 +119,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -146,6 +150,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -176,6 +181,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -217,6 +223,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -250,6 +257,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -283,6 +291,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -316,6 +325,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -349,6 +359,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -382,6 +393,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-training-compiler.json b/src/sagemaker/image_uri_config/huggingface-training-compiler.json index e771e2a548..482264b773 100644 --- a/src/sagemaker/image_uri_config/huggingface-training-compiler.json +++ b/src/sagemaker/image_uri_config/huggingface-training-compiler.json @@ -60,6 +60,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -89,6 +90,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -123,6 +125,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": 
"763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 317c17030a..e995c6e8ea 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -38,6 +38,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -70,6 +71,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -108,6 +110,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -140,6 +143,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -180,6 +184,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -213,6 +218,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -246,6 +252,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -279,6 +286,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -320,6 +328,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -353,6 +362,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -386,6 +396,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -419,6 +430,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -458,6 +470,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -491,6 +504,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -530,6 +544,7 @@ "eu-west-3": 
"763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -563,6 +578,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -602,6 +618,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -635,6 +652,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -687,6 +705,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -720,6 +739,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -753,6 +773,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -794,6 +815,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -827,6 +849,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -860,6 +883,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -893,6 +917,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -932,6 +957,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -965,6 +991,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1004,6 +1031,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1037,6 +1065,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1076,6 +1105,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": 
"914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1109,6 +1139,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 12bc40fccf..14bb74f6a6 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -245,6 +245,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -277,6 +278,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -309,6 +311,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -341,6 +344,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -373,6 +377,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -632,6 +637,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -664,6 +670,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -696,6 +703,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -728,6 +736,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -760,6 +769,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -865,6 +875,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -897,6 +908,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -929,6 +941,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", 
diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 3bf8016ba8..e1de6ca663 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -195,6 +195,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -230,6 +231,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -264,6 +266,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -298,6 +301,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -333,6 +337,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -368,6 +373,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -403,6 +409,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -438,6 +445,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -472,6 +480,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -506,6 +515,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -540,6 +550,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -574,6 +585,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -608,6 +620,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -642,6 +655,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -879,6 +893,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": 
"217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -914,6 +929,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -949,6 +965,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -983,6 +1000,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1018,6 +1036,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1053,6 +1072,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1088,6 +1108,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1123,6 +1144,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1157,6 +1179,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1191,6 +1214,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1225,6 +1249,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1259,6 +1284,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1293,6 +1319,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1327,6 +1354,7 @@ "eu-west-3": "763104351884", "eu-south-1": "692866216735", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index 0122dcd3ca..bb05682f67 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -154,6 +154,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": 
"763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -185,6 +186,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -216,6 +218,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -247,6 +250,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -401,6 +405,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -432,6 +437,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -463,6 +469,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -494,6 +501,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -525,6 +533,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -556,6 +565,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -587,6 +597,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -810,6 +821,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -841,6 +853,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -872,6 +885,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -903,6 +917,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -934,6 +949,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -965,6 +981,7 @@ "eu-west-2": 
"763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -996,6 +1013,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1027,6 +1045,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1058,6 +1077,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1089,6 +1109,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1120,6 +1141,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1151,6 +1173,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1182,6 +1205,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1213,6 +1237,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1244,6 +1269,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1275,6 +1301,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1306,6 +1333,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1337,6 +1365,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1368,6 +1397,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1399,6 +1429,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1430,6 +1461,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + 
"me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1461,6 +1493,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1760,6 +1793,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1796,6 +1830,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1831,6 +1866,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1867,6 +1903,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1903,6 +1940,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1939,6 +1977,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -1975,6 +2014,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2202,6 +2242,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2237,6 +2278,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2272,6 +2314,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2306,6 +2349,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2340,6 +2384,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2375,6 +2420,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2410,6 +2456,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", 
"us-east-2": "763104351884", @@ -2444,6 +2491,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2478,6 +2526,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2512,6 +2561,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2546,6 +2596,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2580,6 +2631,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2614,6 +2666,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2648,6 +2701,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2682,6 +2736,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2716,6 +2771,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2750,6 +2806,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2784,6 +2841,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2818,6 +2876,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2852,6 +2911,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2886,6 +2946,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2920,6 +2981,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2954,6 +3016,7 @@ "eu-west-2": "763104351884", 
"eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -2988,6 +3051,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", @@ -3022,6 +3086,7 @@ "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", From b68bcd9344deba8e3bedf7ccb0adb31498735b13 Mon Sep 17 00:00:00 2001 From: Brock Wade Date: Mon, 5 Dec 2022 14:11:34 -0800 Subject: [PATCH 16/43] fix: support idempotency for framework and spark processors (#3460) Co-authored-by: Brock Wade Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> --- src/sagemaker/processing.py | 8 +- src/sagemaker/spark/processing.py | 37 +- src/sagemaker/workflow/utilities.py | 7 +- tests/data/spark/code/java/TestJarFile.jar | Bin 0 -> 1714 bytes .../hello-java-spark/HelloJavaSparkApp.jar | Bin 0 -> 1714 bytes .../unit/sagemaker/workflow/test_pipeline.py | 8 +- .../workflow/test_processing_step.py | 277 +++++++++++++- .../sagemaker/workflow/test_training_step.py | 354 +++++++++++++++--- .../sagemaker/workflow/test_transform_step.py | 8 + .../sagemaker/workflow/test_tuning_step.py | 58 +-- 10 files changed, 661 insertions(+), 96 deletions(-) create mode 100644 tests/data/spark/code/java/TestJarFile.jar create mode 100644 tests/data/spark/code/java/hello-java-spark/HelloJavaSparkApp.jar diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index 308783578d..81e3d34b1d 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -23,6 +23,7 @@ import logging from textwrap import dedent from typing import Dict, List, Optional, Union +from copy import copy import attr @@ -1830,14 +1831,17 @@ def _patch_inputs_with_payload(self, inputs, s3_payload) -> List[ProcessingInput # a7399455f5386d83ddc5cb15c0db00c04bd518ec/src/sagemaker/processing.py#L425-L426 if inputs is None: inputs = [] - inputs.append( + + # make a shallow copy of user inputs + patched_inputs = copy(inputs) + patched_inputs.append( ProcessingInput( input_name="code", source=s3_payload, destination="/opt/ml/processing/input/code/", ) ) - return inputs + return patched_inputs def _set_entrypoint(self, command, user_script_name): """Framework processor override for setting processing job entrypoint. 
diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index dc3d26a355..912bc90d80 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -30,6 +30,7 @@ from enum import Enum from io import BytesIO from urllib.parse import urlparse +from copy import copy from typing import Union, List, Dict, Optional @@ -279,6 +280,10 @@ def run( def _extend_processing_args(self, inputs, outputs, **kwargs): """Extends processing job args such as inputs.""" + # make a shallow copy of user outputs + outputs = outputs or [] + extended_outputs = copy(outputs) + if kwargs.get("spark_event_logs_s3_uri"): spark_event_logs_s3_uri = kwargs.get("spark_event_logs_s3_uri") self._validate_s3_uri(spark_event_logs_s3_uri) @@ -297,16 +302,21 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): s3_upload_mode="Continuous", ) - outputs = outputs or [] - outputs.append(output) + extended_outputs.append(output) + + # make a shallow copy of user inputs + inputs = inputs or [] + extended_inputs = copy(inputs) if kwargs.get("configuration"): configuration = kwargs.get("configuration") self._validate_configuration(configuration) - inputs = inputs or [] - inputs.append(self._stage_configuration(configuration)) + extended_inputs.append(self._stage_configuration(configuration)) - return inputs, outputs + return ( + extended_inputs if extended_inputs else None, + extended_outputs if extended_outputs else None, + ) def start_history_server(self, spark_event_logs_s3_uri=None): """Starts a Spark history server. @@ -940,9 +950,16 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): outputs: Processing outputs. kwargs: Additional keyword arguments passed to `super()`. """ + + if inputs is None: + inputs = [] + + # make a shallow copy of user inputs + extended_inputs = copy(inputs) + self.command = [_SparkProcessorBase._default_command] extended_inputs = self._handle_script_dependencies( - inputs, kwargs.get("submit_py_files"), FileType.PYTHON + extended_inputs, kwargs.get("submit_py_files"), FileType.PYTHON ) extended_inputs = self._handle_script_dependencies( extended_inputs, kwargs.get("submit_jars"), FileType.JAR @@ -1199,8 +1216,14 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): else: raise ValueError("submit_class is required") + if inputs is None: + inputs = [] + + # make a shallow copy of user inputs + extended_inputs = copy(inputs) + extended_inputs = self._handle_script_dependencies( - inputs, kwargs.get("submit_jars"), FileType.JAR + extended_inputs, kwargs.get("submit_jars"), FileType.JAR ) extended_inputs = self._handle_script_dependencies( extended_inputs, kwargs.get("submit_files"), FileType.FILE diff --git a/src/sagemaker/workflow/utilities.py b/src/sagemaker/workflow/utilities.py index 89d7c5dfd9..08c170d424 100644 --- a/src/sagemaker/workflow/utilities.py +++ b/src/sagemaker/workflow/utilities.py @@ -114,11 +114,12 @@ def get_code_hash(step: Entity) -> str: if isinstance(step, ProcessingStep) and step.step_args: kwargs = step.step_args.func_kwargs source_dir = kwargs.get("source_dir") + submit_class = kwargs.get("submit_class") dependencies = get_processing_dependencies( [ kwargs.get("dependencies"), kwargs.get("submit_py_files"), - kwargs.get("submit_class"), + [submit_class] if submit_class else None, kwargs.get("submit_jars"), kwargs.get("submit_files"), ] @@ -168,7 +169,7 @@ def get_processing_code_hash(code: str, source_dir: str, dependencies: List[str] str: A hash string representing the unique code artifact(s) 
for the step """ - # FrameworkProcessor + # If FrameworkProcessor contains source_dir if source_dir: source_dir_url = urlparse(source_dir) if source_dir_url.scheme == "" or source_dir_url.scheme == "file": @@ -400,5 +401,5 @@ def execute_job_functions(step_args: _StepArguments): """ chained_args = step_args.func(*step_args.func_args, **step_args.func_kwargs) - if chained_args: + if isinstance(chained_args, _StepArguments): execute_job_functions(chained_args) diff --git a/tests/data/spark/code/java/TestJarFile.jar b/tests/data/spark/code/java/TestJarFile.jar new file mode 100644 index 0000000000000000000000000000000000000000..d528331d557da00908e31c46b2a0dd3dc250a2bf GIT binary patch literal 1714 zcmWIWW@Zs#;Nak32&_&EWk3R)3@i-3t|5-Po_=on|4uP5Ff#;rvvYt{FhP|C;M6Pv zQ~}rQ>*(j{<{BKL=j-;__snS@Z(Y5MyxzK6=gyqp9At3C_`%a6JuhD!Pv48Bt5`TA zUPvC1mX_4auz02x_VoEn)#uN(DxRsn&iqvLv4|1u2Dgc~Y@C2LfH24nTnr3AcNxV* zp?E+PD4UU*lasHTl~|UjTU?M>l&znfpR12si##qZiMfeY`FV-u#dtJp64qRtn4X%O zn4MaL#~6K5jDdIxw}(tfH>@PJxCHDxNU}f=RWCA4^Z><#7ce4%LGj>NP@o5nmEPT4 zha3c4fB)?=3}v}1?~$s$O;hJc(w#MhC*L&B^>k7F|4!|RkwJmw^vM@^ZbTbg%>MGD zm+}6Zogs4UvrW~{G>e_Sq%rjvyCrkm1j%#LG~x;ll{xxp9`tuS$+mI96m!?>jqc*F zmUDYt@G4ul|1M+k`QN#lmi*livHr1!m8grxm8+{7vd>(Rb%^U*cxRE#x^uBx^RN9~ zllGMzl-M^*=l9J9diW#|BN97$kjP>SlHA0+%rsz7>XlTK08{;$%cbW$b@a9cd7L|c z)%%R^np5Y!b@Z=k`}z2v^*!UKdr4c*L+8{rZBJm{`0R1^1h54%;aZXMFvtWh2HbfL zVZuQm6GsljZ3HL}BET0Q6RQ!(ITE*Fpgf5HhKvLaL(ZYNjRoaV1gIdzSXhq5Z8#{; zBEV774Tt7nL=pidSmdM(%EJgC4oo=&f*27h5a)w!z@DR#(-+8I*(j{<{BKL=j-;__snS@Z(Y5MyxzK6=gyqp9At3C_`%a6JuhD!Pv48Bt5`TA zUPvC1mX_4auz02x_VoEn)#uN(DxRsn&iqvLv4|1u2Dgc~Y@C2LfH24nTnr3AcNxV* zp?E+PD4UU*lasHTl~|UjTU?M>l&znfpR12si##qZiMfeY`FV-u#dtJp64qRtn4X%O zn4MaL#~6K5jDdIxw}(tfH>@PJxCHDxNU}f=RWCA4^Z><#7ce4%LGj>NP@o5nmEPT4 zha3c4fB)?=3}v}1?~$s$O;hJc(w#MhC*L&B^>k7F|4!|RkwJmw^vM@^ZbTbg%>MGD zm+}6Zogs4UvrW~{G>e_Sq%rjvyCrkm1j%#LG~x;ll{xxp9`tuS$+mI96m!?>jqc*F zmUDYt@G4ul|1M+k`QN#lmi*livHr1!m8grxm8+{7vd>(Rb%^U*cxRE#x^uBx^RN9~ zllGMzl-M^*=l9J9diW#|BN97$kjP>SlHA0+%rsz7>XlTK08{;$%cbW$b@a9cd7L|c z)%%R^np5Y!b@Z=k`}z2v^*!UKdr4c*L+8{rZBJm{`0R1^1h54%;aZXMFvtWh2HbfL zVZuQm6GsljZ3HL}BET0Q6RQ!(ITE*Fpgf5HhKvLaL(ZYNjRoaV1gIdzSXhq5Z8#{; zBEV774Tt7nL=pidSmdM(%EJgC4oo=&f*27h5a)w!z@DR#(-+8I Date: Mon, 5 Dec 2022 18:18:10 -0600 Subject: [PATCH 17/43] feature: Update registries with new region account number mappings. 
(#3492) --- src/sagemaker/image_uri_config/autogluon.json | 18 ++++ .../image_uri_config/huggingface-neuron.json | 3 + .../image_uri_config/huggingface.json | 39 +++++++ src/sagemaker/image_uri_config/mxnet.json | 24 +++++ src/sagemaker/image_uri_config/pytorch.json | 54 ++++++++++ .../image_uri_config/tensorflow.json | 102 ++++++++++++++++++ 6 files changed, 240 insertions(+) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 0963520e02..3a9f02142c 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -210,6 +210,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -217,11 +218,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -244,6 +247,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -251,11 +255,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -278,6 +284,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -285,11 +292,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -312,6 +321,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -319,11 +329,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -346,6 +358,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", 
"ap-southeast-3": "907027046896", @@ -353,11 +366,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -380,6 +395,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -387,11 +403,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-neuron.json b/src/sagemaker/image_uri_config/huggingface-neuron.json index 1e2246cb11..47d6dbd1dc 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuron.json +++ b/src/sagemaker/image_uri_config/huggingface-neuron.json @@ -15,17 +15,20 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ca-central-1": "763104351884", "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index e995c6e8ea..5b98fc0d02 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -692,6 +692,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -699,11 +700,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -726,6 +729,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -733,11 +737,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + 
"eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -760,6 +766,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -767,8 +774,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -802,6 +811,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -809,11 +819,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -836,6 +848,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -843,11 +856,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -870,6 +885,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -877,8 +893,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -904,6 +922,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -911,8 +930,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -944,6 +965,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -951,11 +973,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", 
"eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -978,6 +1002,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -985,8 +1010,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1018,6 +1045,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1025,11 +1053,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -1052,6 +1082,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1059,8 +1090,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1092,6 +1125,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1099,11 +1133,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -1126,6 +1162,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1133,8 +1170,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 
14bb74f6a6..8d8733e480 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -624,6 +624,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -631,11 +632,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -657,6 +660,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -664,11 +668,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -690,6 +696,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -697,11 +704,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -723,6 +732,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -730,11 +740,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -756,6 +768,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -763,11 +776,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -862,6 +877,7 @@ "ap-northeast-2": "763104351884", 
"ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -869,11 +885,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -895,6 +913,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -902,11 +921,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -928,6 +949,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -935,11 +957,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index e1de6ca663..18a382e591 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -17,6 +17,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -25,7 +26,9 @@ "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", "eu-north-1": "763104351884", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", + "eu-south-2": "503227376785", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-west-2": "763104351884" @@ -39,8 +42,11 @@ "registries": { "ap-northeast-1": "763104351884", "ap-northeast-2": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-3": "907027046896", + "eu-central-2": "380420809688", "eu-west-1": "763104351884", + "eu-south-2": "503227376785", "us-east-1": "763104351884", "us-east-2": "763104351884", "us-west-2": "763104351884" @@ -182,6 +188,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -189,11 +196,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": 
"763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -218,6 +227,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -225,11 +235,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -253,6 +265,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -260,11 +273,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -288,6 +303,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -295,11 +311,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -324,6 +342,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -331,11 +350,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -360,6 +381,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -367,11 +389,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": 
"914824155844", "sa-east-1": "763104351884", @@ -396,6 +420,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -403,11 +428,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -432,6 +459,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -439,11 +467,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -467,6 +497,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -474,11 +505,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -502,6 +535,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -509,11 +543,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -537,6 +573,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -544,11 +581,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -572,6 +611,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": 
"772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -579,11 +619,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -607,6 +649,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -614,11 +657,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -642,6 +687,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -649,11 +695,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "me-central-1": "914824155844", "sa-east-1": "763104351884", @@ -677,6 +725,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -684,11 +733,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", @@ -721,6 +772,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -728,11 +780,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index bb05682f67..a0f2bba014 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -141,6 +141,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": 
"364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -148,12 +149,14 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "me-south-1": "217643126080", + "eu-south-2": "503227376785", "me-central-1": "914824155844", "sa-east-1": "763104351884", "us-east-1": "763104351884", @@ -173,6 +176,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -180,8 +184,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -205,6 +211,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -212,8 +219,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -237,6 +246,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -244,8 +254,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -392,6 +404,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -399,8 +412,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -424,6 +439,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -431,8 +447,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -456,6 
+474,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -463,8 +482,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -488,6 +509,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -495,8 +517,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -520,6 +544,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -527,8 +552,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -552,6 +579,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -559,8 +587,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -584,6 +614,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -591,8 +622,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -808,6 +841,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -815,8 +849,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -840,6 +876,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": 
"364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -847,8 +884,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -872,6 +911,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -879,8 +919,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -904,6 +946,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -911,8 +954,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -936,6 +981,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -943,8 +989,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -968,6 +1016,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -975,8 +1024,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1000,6 +1051,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1007,8 +1059,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1032,6 +1086,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + 
"ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1039,8 +1094,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1064,6 +1121,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1071,8 +1129,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1096,6 +1156,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1103,8 +1164,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1128,6 +1191,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1135,8 +1199,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1160,6 +1226,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1167,8 +1234,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1192,6 +1261,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1199,8 +1269,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1224,6 +1296,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", 
"ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1231,8 +1304,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1256,6 +1331,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1263,8 +1339,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1288,6 +1366,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1295,8 +1374,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1320,6 +1401,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1327,8 +1409,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1352,6 +1436,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1359,8 +1444,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1384,6 +1471,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1391,8 +1479,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1416,6 +1506,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", 
"ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1423,8 +1514,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1448,6 +1541,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1455,8 +1549,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1480,6 +1576,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1487,8 +1584,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1587,6 +1686,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1594,11 +1694,13 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "me-south-1": "217643126080", "sa-east-1": "763104351884", "us-east-1": "763104351884", From 767da0afc5cfb11eb96b324debf9a310abaafbcc Mon Sep 17 00:00:00 2001 From: Loki Date: Wed, 7 Dec 2022 06:06:34 +0530 Subject: [PATCH 18/43] feature: Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 (#3500) Co-authored-by: Ubuntu --- src/sagemaker/fw_utils.py | 2 +- .../pytorch-training-compiler.json | 41 ++ src/sagemaker/image_uris.py | 2 +- src/sagemaker/pytorch/__init__.py | 2 + src/sagemaker/pytorch/estimator.py | 60 +- .../pytorch/training_compiler/__init__.py | 0 .../pytorch/training_compiler/config.py | 151 +++++ tests/conftest.py | 1 + tests/data/huggingface_byoc/requirements.txt | 2 + tests/data/huggingface_byoc/run_glue.py | 568 ++++++++++++++++ tests/data/huggingface_byoc/train/dummy.csv | 1 + tests/integ/__init__.py | 2 +- tests/integ/test_training_compiler.py | 50 +- .../test_pytorch_compiler.py | 616 ++++++++++++++++++ 14 files changed, 1467 insertions(+), 31 deletions(-) create mode 100644 src/sagemaker/image_uri_config/pytorch-training-compiler.json create mode 100644 src/sagemaker/pytorch/training_compiler/__init__.py create mode 100644 src/sagemaker/pytorch/training_compiler/config.py create mode 100644 tests/data/huggingface_byoc/requirements.txt create mode 100644 tests/data/huggingface_byoc/run_glue.py create mode 100644 
tests/data/huggingface_byoc/train/dummy.csv create mode 100644 tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index d82d3596ac..5efe530396 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -493,7 +493,7 @@ def framework_name_from_image(image_uri): # We must support both the legacy and current image name format. name_pattern = re.compile( r"""^(?:sagemaker(?:-rl)?-)? - (tensorflow|mxnet|chainer|pytorch|scikit-learn|xgboost + (tensorflow|mxnet|chainer|pytorch|pytorch-trcomp|scikit-learn|xgboost |huggingface-tensorflow|huggingface-pytorch |huggingface-tensorflow-trcomp|huggingface-pytorch-trcomp)(?:-)? (scriptmode|training)? diff --git a/src/sagemaker/image_uri_config/pytorch-training-compiler.json b/src/sagemaker/image_uri_config/pytorch-training-compiler.json new file mode 100644 index 0000000000..892ff4237d --- /dev/null +++ b/src/sagemaker/image_uri_config/pytorch-training-compiler.json @@ -0,0 +1,41 @@ +{ + "training": { + "processors": [ + "gpu" + ], + "version_aliases": { + "1.12": "1.12.0" + }, + "versions": { + "1.12.0": { + "py_versions": [ + "py38" + ], + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ca-central-1": "763104351884", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "pytorch-training" + } + } + } +} diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 7d1d3bd835..c42ce02188 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -146,7 +146,7 @@ def retrieve( tolerate_deprecated_model, ) - if training_compiler_config and (framework == HUGGING_FACE_FRAMEWORK): + if training_compiler_config and (framework in [HUGGING_FACE_FRAMEWORK, "pytorch"]): final_image_scope = image_scope config = _config_for_framework_and_scope( framework + "-training-compiler", final_image_scope, accelerator_type diff --git a/src/sagemaker/pytorch/__init__.py b/src/sagemaker/pytorch/__init__.py index cac5f94b9a..e2d14f4163 100644 --- a/src/sagemaker/pytorch/__init__.py +++ b/src/sagemaker/pytorch/__init__.py @@ -16,3 +16,5 @@ from sagemaker.pytorch.estimator import PyTorch # noqa: F401 from sagemaker.pytorch.model import PyTorchModel, PyTorchPredictor # noqa: F401 from sagemaker.pytorch.processing import PyTorchProcessor # noqa: F401 + +from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig # noqa: F401 diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 686de4a78c..29e254662f 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -28,6 +28,7 @@ ) from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel +from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable @@ -51,7 +52,8 @@ def __init__( 
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, distribution: Optional[Dict] = None, - **kwargs + compiler_config: Optional[TrainingCompilerConfig] = None, + **kwargs, ): """This ``Estimator`` executes a PyTorch script in a managed PyTorch execution environment. @@ -208,6 +210,31 @@ def __init__( To learn more, see `Training with parameter servers `_. + **To enable distributed training with + `SageMaker Training Compiler `_ + for PyTorch:** + + .. code:: python + + { + "pytorchxla": { + "enabled": True + } + } + + To learn more, see `SageMaker Training Compiler + `_ + in the *Amazon SageMaker Developer Guide*. + + .. note:: + + When you use this PyTorch XLA option for distributed training strategy, + you must add the ``compiler_config`` parameter and activate SageMaker + Training Compiler. + + compiler_config (:class:`~sagemaker.pytorch.TrainingCompilerConfig`): + Configures SageMaker Training Compiler to accelerate training. + **kwargs: Additional kwargs passed to the :class:`~sagemaker.estimator.Framework` constructor. @@ -250,6 +277,25 @@ def __init__( self.distribution = distribution or {} + if compiler_config is not None: + if not isinstance(compiler_config, TrainingCompilerConfig): + error_string = ( + f"Expected instance of type {TrainingCompilerConfig}" + f"for argument compiler_config. " + f"Instead got {type(compiler_config)}" + ) + raise ValueError(error_string) + if compiler_config: + compiler_config.validate(self) + elif distribution is not None and "pytorchxla" in distribution: + raise ValueError( + "Distributed training through PyTorch XLA is currently only supported " + "when SageMaker Training Compiler is enabled. To learn more, " + "see Enable SageMaker Training Compiler at " + "https://docs.aws.amazon.com/sagemaker/latest/dg/training-compiler-enable.html." + ) + self.compiler_config = compiler_config + def _pytorch_distribution_configuration(self, distribution): """Returns a dict of distribution config for PyTorch training @@ -289,6 +335,12 @@ def hyperparameters(self): hyperparameters.update( EstimatorBase._json_encode_hyperparameters(additional_hyperparameters) ) + if self.compiler_config: + training_compiler_hyperparameters = self.compiler_config._to_hyperparameter_dict() + hyperparameters.update( + EstimatorBase._json_encode_hyperparameters(training_compiler_hyperparameters) + ) + return hyperparameters def create_model( @@ -299,7 +351,7 @@ def create_model( entry_point=None, source_dir=None, dependencies=None, - **kwargs + **kwargs, ): """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``. 
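A minimal usage sketch for the ``compiler_config`` parameter introduced above, assuming a hypothetical ``train.py`` entry point, role, and S3 prefix; the instance settings are placeholders, not taken from this patch:

.. code:: python

    from sagemaker.pytorch import PyTorch, TrainingCompilerConfig

    pt_estimator = PyTorch(
        entry_point="train.py",           # placeholder training script
        role="SageMakerRole",
        framework_version="1.12.0",
        py_version="py38",
        instance_count=2,
        instance_type="ml.p4d.24xlarge",  # EFA-capable, per SUPPORTED_INSTANCE_TYPES_WITH_EFA
        compiler_config=TrainingCompilerConfig(),
        # Multi-node training with the compiler enabled is routed through PyTorch XLA.
        distribution={"pytorchxla": {"enabled": True}},
    )
    pt_estimator.fit("s3://my-bucket/train")  # hypothetical S3 prefix

Single-node jobs can omit ``distribution``; for multiple instances the validator below nudges you toward ``pytorchxla``.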
@@ -350,7 +402,7 @@ def create_model( sagemaker_session=self.sagemaker_session, vpc_config=self.get_vpc_config(vpc_config_override), dependencies=(dependencies or self.dependencies), - **kwargs + **kwargs, ) @classmethod @@ -371,6 +423,8 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na ) image_uri = init_params.pop("image_uri") framework, py_version, tag, _ = framework_name_from_image(image_uri) + if framework: + framework = framework.split("-")[0] if tag is None: framework_version = None diff --git a/src/sagemaker/pytorch/training_compiler/__init__.py b/src/sagemaker/pytorch/training_compiler/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/sagemaker/pytorch/training_compiler/config.py b/src/sagemaker/pytorch/training_compiler/config.py new file mode 100644 index 0000000000..7faf8acbbd --- /dev/null +++ b/src/sagemaker/pytorch/training_compiler/config.py @@ -0,0 +1,151 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Configuration for the SageMaker Training Compiler.""" +from __future__ import absolute_import +import logging +from typing import Union +from packaging.specifiers import SpecifierSet +from packaging.version import Version + +from sagemaker.training_compiler.config import TrainingCompilerConfig as BaseConfig +from sagemaker.workflow.entities import PipelineVariable + +logger = logging.getLogger(__name__) + + +class TrainingCompilerConfig(BaseConfig): + """The SageMaker Training Compiler configuration class.""" + + SUPPORTED_INSTANCE_CLASS_PREFIXES = ["p3", "p3dn", "g4dn", "p4d", "g5"] + SUPPORTED_INSTANCE_TYPES_WITH_EFA = [ + "ml.g4dn.8xlarge", + "ml.g4dn.12xlarge", + "ml.g5.48xlarge", + "ml.p3dn.24xlarge", + "ml.p4d.24xlarge", + ] + + def __init__( + self, + enabled: Union[bool, PipelineVariable] = True, + debug: Union[bool, PipelineVariable] = False, + ): + """This class initializes a ``TrainingCompilerConfig`` instance. + + `Amazon SageMaker Training Compiler + `_ + is a feature of SageMaker Training + and speeds up training jobs by optimizing model execution graphs. + + You can compile PyTorch models + by passing the object of this configuration class to the ``compiler_config`` + parameter of the :class:`~sagemaker.pytorch.PyTorch` + estimator. + + Args: + enabled (bool or PipelineVariable): Optional. Switch to enable SageMaker + Training Compiler. The default is ``True``. + debug (bool or PipelineVariable): Optional. Whether to dump detailed logs + for debugging. This comes with a potential performance slowdown. + The default is ``False``. + + **Example**: The following code shows the basic usage of the + :class:`sagemaker.pytorch.TrainingCompilerConfig()` class + to run a PyTorch training job with the compiler. + + .. code-block:: python + + from sagemaker.pytorch import PyTorch, TrainingCompilerConfig + + pytorch_estimator=PyTorch( + ... + compiler_config=TrainingCompilerConfig() + ) + + .. 
seealso:: + + For more information about how to enable SageMaker Training Compiler + for various training settings such as distributed training, + see `Enable SageMaker Training Compiler + `_ + in the `Amazon SageMaker Training Compiler developer guide + `_. + + """ + + super(TrainingCompilerConfig, self).__init__(enabled=enabled, debug=debug) + + @classmethod + def validate( + cls, + estimator, + ): + """Checks if SageMaker Training Compiler is configured correctly. + + Args: + estimator (:class:`sagemaker.pytorch.PyTorch`): An estimator object. + If SageMaker Training Compiler is enabled, it will validate whether + the estimator is configured to be compatible with Training Compiler. + + Raises: + ValueError: Raised if the requested configuration is not compatible + with SageMaker Training Compiler. + """ + + super(TrainingCompilerConfig, cls).validate(estimator) + + if estimator.image_uri: + error_helper_string = ( + "Overriding the image URI is currently not supported " + "for SageMaker Training Compiler." + "Specify the following parameters to run the PyTorch training job " + "with SageMaker Training Compiler enabled: " + "framework_version, and compiler_config." + ) + raise ValueError(error_helper_string) + + if estimator.distribution: + pt_xla_present = "pytorchxla" in estimator.distribution + pt_xla_enabled = estimator.distribution.get("pytorchxla", {}).get("enabled", False) + if pt_xla_enabled: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet("< 1.12"): + error_helper_string = ( + "Distribution mechanism 'pytorchxla' is currently only supported for " + "PyTorch >= 1.12 when SageMaker Training Compiler is enabled." + " Received framework_version={} which is unsupported." + ) + raise ValueError(error_helper_string.format(estimator.framework_version)) + if estimator.instance_type not in cls.SUPPORTED_INSTANCE_TYPES_WITH_EFA: + logger.warning( + "Consider using instances with EFA support when " + "training with PyTorch >= 1.12 and SageMaker Training Compiler " + "enabled. SageMaker Training Compiler leverages EFA to provide better " + "performance for distributed training." + ) + if not pt_xla_present: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet(">= 1.12"): + error_helper_string = ( + "'pytorchxla' is the only distribution mechanism currently supported " + "for PyTorch >= 1.12 when SageMaker Training Compiler is enabled." + " Received distribution={} which is unsupported." + ) + raise ValueError(error_helper_string.format(estimator.distribution)) + elif estimator.instance_count and estimator.instance_count > 1: + if estimator.framework_version: + if Version(estimator.framework_version) in SpecifierSet(">= 1.12"): + logger.warning( + "Consider setting 'distribution' to 'pytorchxla' for distributed " + "training with PyTorch >= 1.12 and SageMaker Training Compiler enabled." 
+ ) diff --git a/tests/conftest.py b/tests/conftest.py index e92d98112b..f6682ebb8c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,6 +73,7 @@ "neo_pytorch", "neo_tensorflow", "pytorch", + "pytorch_training_compiler", "ray_pytorch", "ray_tensorflow", "sklearn", diff --git a/tests/data/huggingface_byoc/requirements.txt b/tests/data/huggingface_byoc/requirements.txt new file mode 100644 index 0000000000..462542f1c1 --- /dev/null +++ b/tests/data/huggingface_byoc/requirements.txt @@ -0,0 +1,2 @@ +transformers +datasets diff --git a/tests/data/huggingface_byoc/run_glue.py b/tests/data/huggingface_byoc/run_glue.py new file mode 100644 index 0000000000..1060398fa4 --- /dev/null +++ b/tests/data/huggingface_byoc/run_glue.py @@ -0,0 +1,568 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2020 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" Finetuning the library models for sequence classification on GLUE.""" +# You can also adapt this script on your own text classification task. Pointers for this are left as comments. + +import logging +import os +import random +import sys +from dataclasses import dataclass, field +from typing import Optional + +import numpy as np +from datasets import load_dataset, load_metric + +import transformers +from transformers import ( + AutoConfig, + AutoModelForSequenceClassification, + AutoTokenizer, + DataCollatorWithPadding, + EvalPrediction, + HfArgumentParser, + PretrainedConfig, + Trainer, + TrainingArguments, + default_data_collator, + set_seed, +) +from transformers.trainer_utils import get_last_checkpoint, is_main_process +from transformers.utils import check_min_version + + +# Will error if the minimal version of Transformers is not installed. Remove at your own risks. +check_min_version("4.5.0") + +task_to_keys = { + "cola": ("sentence", None), + "mnli": ("premise", "hypothesis"), + "mrpc": ("sentence1", "sentence2"), + "qnli": ("question", "sentence"), + "qqp": ("question1", "question2"), + "rte": ("sentence1", "sentence2"), + "sst2": ("sentence", None), + "stsb": ("sentence1", "sentence2"), + "wnli": ("sentence1", "sentence2"), +} + +logger = logging.getLogger(__name__) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + + Using `HfArgumentParser` we can turn this class + into argparse arguments to be able to specify them on + the command line. + """ + + task_name: Optional[str] = field( + default=None, + metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())}, + ) + max_seq_length: int = field( + default=128, + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer " + "than this will be truncated, sequences shorter will be padded." 
+ }, + ) + overwrite_cache: bool = field( + default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."} + ) + pad_to_max_length: bool = field( + default=True, + metadata={ + "help": "Whether to pad all samples to `max_seq_length`. " + "If False, will pad the samples dynamically when batching to the maximum length in the batch." + }, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + }, + ) + max_val_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of validation examples to this " + "value if set." + }, + ) + max_test_samples: Optional[int] = field( + default=None, + metadata={ + "help": "For debugging purposes or quicker training, truncate the number of test examples to this " + "value if set." + }, + ) + train_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the training data."} + ) + validation_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the validation data."} + ) + test_file: Optional[str] = field( + default=None, metadata={"help": "A csv or a json file containing the test data."} + ) + + def __post_init__(self): + if self.task_name is not None: + self.task_name = self.task_name.lower() + if self.task_name not in task_to_keys.keys(): + raise ValueError( + "Unknown task, you should pick one in " + ",".join(task_to_keys.keys()) + ) + elif self.train_file is None or self.validation_file is None: + raise ValueError("Need either a GLUE task or a training/validation file.") + else: + train_extension = self.train_file.split(".")[-1] + assert train_extension in [ + "csv", + "json", + ], "`train_file` should be a csv or a json file." + validation_extension = self.validation_file.split(".")[-1] + assert ( + validation_extension == train_extension + ), "`validation_file` should have the same extension (csv or json) as `train_file`." + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. + """ + + model_name_or_path: str = field( + metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} + ) + config_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained config name or path if not the same as model_name"}, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where do you want to store the pretrained models downloaded from huggingface.co" + }, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={ + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + }, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, + ) + use_auth_token: bool = field( + default=False, + metadata={ + "help": "Will use the token generated when running `transformers-cli login` (necessary to use this script " + "with private models)." + }, + ) + + +def main(): + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. 
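+    # An illustrative invocation (hypothetical output path; the flags come from the
+    # dataclasses above plus transformers.TrainingArguments):
+    #   python run_glue.py --model_name_or_path distilbert-base-cased --task_name mrpc \
+    #       --do_train --do_eval --output_dir /tmp/mrpc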
+ # We now keep distinct sets of args, for a cleaner separation of concerns. + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Detecting last checkpoint. + last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif last_checkpoint is not None: + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + logger.info(f"Training/evaluation parameters {training_args}") + + # Set seed before initializing model. + set_seed(training_args.seed) + + # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) + # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub). + # + # For CSV/JSON files, this script will use as labels the column called 'label' and as pair of sentences the + # sentences in columns called 'sentence1' and 'sentence2' if such column exists or the first two columns not named + # label if at least two columns are provided. + # + # If the CSVs/JSONs contain only one non-label column, the script does single sentence classification on this + # single column. You can easily tweak this behavior (see below) + # + # In distributed training, the load_dataset function guarantee that only one local process can concurrently + # download the dataset. + if data_args.task_name is not None: + # Downloading and loading a dataset from the hub. + datasets = load_dataset("glue", data_args.task_name) + else: + # Loading a dataset from your local files. + # CSV/JSON training and evaluation files are needed. 
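+        # data_files maps split names to local paths, e.g. (hypothetical files)
+        # {"train": "train.csv", "validation": "dev.csv"}; __post_init__ above has
+        # already checked that both files share the same csv/json extension.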
+ data_files = {"train": data_args.train_file, "validation": data_args.validation_file} + + # Get the test dataset: you can provide your own CSV/JSON test file (see below) + # when you use `do_predict` without specifying a GLUE benchmark task. + if training_args.do_predict: + if data_args.test_file is not None: + train_extension = data_args.train_file.split(".")[-1] + test_extension = data_args.test_file.split(".")[-1] + assert ( + test_extension == train_extension + ), "`test_file` should have the same extension (csv or json) as `train_file`." + data_files["test"] = data_args.test_file + else: + raise ValueError("Need either a GLUE task or a test file for `do_predict`.") + + for key in data_files.keys(): + logger.info(f"load a local file for {key}: {data_files[key]}") + + if data_args.train_file.endswith(".csv"): + # Loading a dataset from local csv files + datasets = load_dataset("csv", data_files=data_files) + else: + # Loading a dataset from local json files + datasets = load_dataset("json", data_files=data_files) + # See more about loading any type of standard or custom dataset at + # https://huggingface.co/docs/datasets/loading_datasets.html. + + # Labels + if data_args.task_name is not None: + is_regression = data_args.task_name == "stsb" + if not is_regression: + label_list = datasets["train"].features["label"].names + num_labels = len(label_list) + else: + num_labels = 1 + else: + # Trying to have good defaults here, don't hesitate to tweak to your needs. + is_regression = datasets["train"].features["label"].dtype in ["float32", "float64"] + if is_regression: + num_labels = 1 + else: + # A useful fast method: + # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.unique + label_list = datasets["train"].unique("label") + label_list.sort() # Let's sort it for determinism + num_labels = len(label_list) + + # Load pretrained model and tokenizer + # + # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently + # download model & vocab. + config = AutoConfig.from_pretrained( + model_args.config_name if model_args.config_name else model_args.model_name_or_path, + num_labels=num_labels, + finetuning_task=data_args.task_name, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + model = AutoModelForSequenceClassification.from_pretrained( + model_args.model_name_or_path, + from_tf=bool(".ckpt" in model_args.model_name_or_path), + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Preprocessing the datasets + if data_args.task_name is not None: + sentence1_key, sentence2_key = task_to_keys[data_args.task_name] + else: + # Again, we try to have some nice defaults but don't hesitate to tweak to your use case. 
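+        # Column heuristic: use explicit "sentence1"/"sentence2" columns when present;
+        # otherwise treat the first two non-label columns (or the single remaining
+        # column) as the text inputs.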
+ non_label_column_names = [ + name for name in datasets["train"].column_names if name != "label" + ] + if "sentence1" in non_label_column_names and "sentence2" in non_label_column_names: + sentence1_key, sentence2_key = "sentence1", "sentence2" + else: + if len(non_label_column_names) >= 2: + sentence1_key, sentence2_key = non_label_column_names[:2] + else: + sentence1_key, sentence2_key = non_label_column_names[0], None + + # Padding strategy + if data_args.pad_to_max_length: + padding = "max_length" + else: + # We will pad later, dynamically at batch creation, to the max sequence length in each batch + padding = False + + # Some models have set the order of the labels to use, so let's make sure we do use it. + label_to_id = None + if ( + model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id + and data_args.task_name is not None + and not is_regression + ): + # Some have all caps in their config, some don't. + label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()} + if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)): + label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)} + else: + logger.warn( + "Your model seems to have been trained with labels, but they don't match the dataset: ", + f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}." + "\nIgnoring the model labels as a result.", + ) + elif data_args.task_name is None and not is_regression: + label_to_id = {v: i for i, v in enumerate(label_list)} + + if data_args.max_seq_length > tokenizer.model_max_length: + logger.warn( + f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" + f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}." 
+ ) + max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length) + + def preprocess_function(examples): + # Tokenize the texts + args = ( + (examples[sentence1_key],) + if sentence2_key is None + else (examples[sentence1_key], examples[sentence2_key]) + ) + result = tokenizer(*args, padding=padding, max_length=max_seq_length, truncation=True) + + # Map labels to IDs (not necessary for GLUE tasks) + if label_to_id is not None and "label" in examples: + result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]] + return result + + datasets = datasets.map( + preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache + ) + if training_args.do_train: + if "train" not in datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = datasets["train"] + if data_args.max_train_samples is not None: + train_dataset = train_dataset.select(range(data_args.max_train_samples)) + + if training_args.do_eval: + if "validation" not in datasets and "validation_matched" not in datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = datasets[ + "validation_matched" if data_args.task_name == "mnli" else "validation" + ] + if data_args.max_val_samples is not None: + eval_dataset = eval_dataset.select(range(data_args.max_val_samples)) + + if ( + training_args.do_predict + or data_args.task_name is not None + or data_args.test_file is not None + ): + if "test" not in datasets and "test_matched" not in datasets: + raise ValueError("--do_predict requires a test dataset") + test_dataset = datasets["test_matched" if data_args.task_name == "mnli" else "test"] + if data_args.max_test_samples is not None: + test_dataset = test_dataset.select(range(data_args.max_test_samples)) + + # Log a few random samples from the training set: + if training_args.do_train: + for index in random.sample(range(len(train_dataset)), 3): + logger.info(f"Sample {index} of the training set: {train_dataset[index]}.") + + # Get the metric function + if data_args.task_name is not None: + metric = load_metric("glue", data_args.task_name) + # TODO: When datasets metrics include regular accuracy, make an else here and remove special branch from + # compute_metrics + + # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a + # predictions and label_ids field) and has to return a dictionary string to float. + def compute_metrics(p: EvalPrediction): + preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions + preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1) + if data_args.task_name is not None: + result = metric.compute(predictions=preds, references=p.label_ids) + if len(result) > 1: + result["combined_score"] = np.mean(list(result.values())).item() + return result + elif is_regression: + return {"mse": ((preds - p.label_ids) ** 2).mean().item()} + else: + return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()} + + # Data collator will default to DataCollatorWithPadding, so we change it if we already did the padding. 
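+    # default_data_collator is sufficient when every sample was already padded to
+    # max_seq_length up front; under fp16, padding to a multiple of 8 keeps shapes
+    # aligned for GPU Tensor Cores.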
+ if data_args.pad_to_max_length: + data_collator = default_data_collator + elif training_args.fp16: + data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8) + else: + data_collator = None + + # Initialize our Trainer + trainer = Trainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + compute_metrics=compute_metrics, + tokenizer=tokenizer, + data_collator=data_collator, + ) + + # Training + if training_args.do_train: + checkpoint = None + if last_checkpoint is not None: + checkpoint = last_checkpoint + elif os.path.isdir(model_args.model_name_or_path): + # Check the config from that potential checkpoint has the right number of labels before using it as a + # checkpoint. + if AutoConfig.from_pretrained(model_args.model_name_or_path).num_labels == num_labels: + checkpoint = model_args.model_name_or_path + + train_result = trainer.train(resume_from_checkpoint=checkpoint) + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.save_model() # Saves the tokenizer too for easy upload + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + if training_args.do_eval: + logger.info("*** Evaluate ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + eval_datasets = [eval_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + eval_datasets.append(datasets["validation_mismatched"]) + + for eval_dataset, task in zip(eval_datasets, tasks): + metrics = trainer.evaluate(eval_dataset=eval_dataset) + + max_val_samples = ( + data_args.max_val_samples + if data_args.max_val_samples is not None + else len(eval_dataset) + ) + metrics["eval_samples"] = min(max_val_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.do_predict: + logger.info("*** Test ***") + + # Loop to handle MNLI double evaluation (matched, mis-matched) + tasks = [data_args.task_name] + test_datasets = [test_dataset] + if data_args.task_name == "mnli": + tasks.append("mnli-mm") + test_datasets.append(datasets["test_mismatched"]) + + for test_dataset, task in zip(test_datasets, tasks): + # Removing the `label` columns because it contains -1 and Trainer won't like that. 
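+            # (remove_columns_ is the in-place variant from older `datasets` releases;
+            # newer releases use test_dataset = test_dataset.remove_columns("label").)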
+ test_dataset.remove_columns_("label") + predictions = trainer.predict(test_dataset=test_dataset).predictions + predictions = ( + np.squeeze(predictions) if is_regression else np.argmax(predictions, axis=1) + ) + + output_test_file = os.path.join(training_args.output_dir, f"test_results_{task}.txt") + if trainer.is_world_process_zero(): + with open(output_test_file, "w") as writer: + logger.info(f"***** Test results {task} *****") + writer.write("index\tprediction\n") + for index, item in enumerate(predictions): + if is_regression: + writer.write(f"{index}\t{item:3.3f}\n") + else: + item = label_list[item] + writer.write(f"{index}\t{item}\n") + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/tests/data/huggingface_byoc/train/dummy.csv b/tests/data/huggingface_byoc/train/dummy.csv new file mode 100644 index 0000000000..fb1539d552 --- /dev/null +++ b/tests/data/huggingface_byoc/train/dummy.csv @@ -0,0 +1 @@ +# dummy data \ No newline at end of file diff --git a/tests/integ/__init__.py b/tests/integ/__init__.py index 00ed09577b..9133fc8904 100644 --- a/tests/integ/__init__.py +++ b/tests/integ/__init__.py @@ -158,7 +158,7 @@ "ap-northeast-1", "eu-central-1", ] -# TODO: SM Training Compiler team to add all supported regions. + TRAINING_COMPILER_SUPPORTED_REGIONS = [ "af-south-1", "ap-east-1", diff --git a/tests/integ/test_training_compiler.py b/tests/integ/test_training_compiler.py index 67de050ed1..724cd8890c 100644 --- a/tests/integ/test_training_compiler.py +++ b/tests/integ/test_training_compiler.py @@ -20,6 +20,8 @@ from sagemaker.huggingface import TrainingCompilerConfig as HFTrainingCompilerConfig from sagemaker.tensorflow import TensorFlow from sagemaker.tensorflow import TrainingCompilerConfig as TFTrainingCompilerConfig +from sagemaker.pytorch import PyTorch +from sagemaker.pytorch import TrainingCompilerConfig as PTTrainingCompilerConfig from tests import integ from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES @@ -48,8 +50,7 @@ def imagenet_val_set(request, sagemaker_session, tmpdir_factory): key_prefix="Imagenet/TFRecords/validation", ) train_input = sagemaker_session.upload_data( - path=local_path, - key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val", + path=local_path, key_prefix="integ-test-data/trcomp/tensorflow/imagenet/val" ) return train_input @@ -84,8 +85,8 @@ def skip_if_incompatible(gpu_instance_type, request): @pytest.mark.parametrize( "gpu_instance_type,instance_count", [ - ("ml.p3.2xlarge", 1), - ("ml.p3.16xlarge", 2), + pytest.param("ml.p3.2xlarge", 1, marks=pytest.mark.release), + pytest.param("ml.p3.16xlarge", 2), ], ) def test_huggingface_pytorch( @@ -129,27 +130,32 @@ def test_huggingface_pytorch( hf.fit(huggingface_dummy_dataset) -@pytest.mark.release -def test_huggingface_pytorch_release( +@pytest.mark.parametrize( + "gpu_instance_type,instance_count", + [ + pytest.param("ml.p3.2xlarge", 1, marks=pytest.mark.release), + pytest.param("ml.p3.16xlarge", 2), + ], +) +def test_pytorch( sagemaker_session, gpu_instance_type, - huggingface_training_compiler_latest_version, - huggingface_training_compiler_pytorch_latest_version, + instance_count, + pytorch_training_compiler_latest_version, huggingface_dummy_dataset, ): """ - Test the HuggingFace estimator with PyTorch + Test the PyTorch estimator """ with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES): - data_path = os.path.join(DATA_DIR, "huggingface") - hf = HuggingFace( + hf = PyTorch( py_version="py38", - 
entry_point=os.path.join(data_path, "run_glue.py"), + source_dir=os.path.join(DATA_DIR, "huggingface_byoc"), + entry_point="run_glue.py", role="SageMakerRole", - transformers_version=huggingface_training_compiler_latest_version, - pytorch_version=huggingface_training_compiler_pytorch_latest_version, - instance_count=1, + framework_version=pytorch_training_compiler_latest_version, + instance_count=instance_count, instance_type=gpu_instance_type, hyperparameters={ "model_name_or_path": "distilbert-base-cased", @@ -163,7 +169,8 @@ def test_huggingface_pytorch_release( }, sagemaker_session=sagemaker_session, disable_profiler=True, - compiler_config=HFTrainingCompilerConfig(), + compiler_config=PTTrainingCompilerConfig(), + distribution={"pytorchxla": {"enabled": True}} if instance_count > 1 else None, ) hf.fit(huggingface_dummy_dataset) @@ -209,10 +216,7 @@ def test_huggingface_tensorflow( @pytest.mark.release def test_tensorflow( - sagemaker_session, - gpu_instance_type, - tensorflow_training_latest_version, - imagenet_val_set, + sagemaker_session, gpu_instance_type, tensorflow_training_latest_version, imagenet_val_set ): """ Test the TensorFlow estimator @@ -264,8 +268,4 @@ def test_tensorflow( compiler_config=TFTrainingCompilerConfig(), ) - tf.fit( - inputs=imagenet_val_set, - logs=True, - wait=True, - ) + tf.fit(inputs=imagenet_val_set, logs=True, wait=True) diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py new file mode 100644 index 0000000000..0fe2402695 --- /dev/null +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -0,0 +1,616 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
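+# The tests below exercise the validation rules added in
+# sagemaker/pytorch/training_compiler/config.py: overriding the image URI,
+# CPU or non-qualifying GPU instance types, Python 2, and heterogeneous
+# instance groups are each expected to raise ValueError before any job starts.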
+from __future__ import absolute_import +import logging + +import json +import os + +import pytest +from mock import MagicMock, Mock, patch, ANY +from packaging.version import Version + +from sagemaker import image_uris +from sagemaker.pytorch import PyTorch, TrainingCompilerConfig +from sagemaker.pytorch.model import PyTorchModel +from sagemaker.instance_group import InstanceGroup + +from tests.unit.sagemaker.training_compiler import EC2_GPU_INSTANCE_CLASSES + + +DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "..", "..", "data") +SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") +SERVING_SCRIPT_FILE = "another_dummy_script.py" +MODEL_DATA = "s3://some/data.tar.gz" +ENV = {"DUMMY_ENV_VAR": "dummy_value"} +TIMESTAMP = "2017-11-06-14:14:15.672" +TIME = 1510006209.073025 +BUCKET_NAME = "mybucket" +INSTANCE_COUNT = 1 +INSTANCE_TYPE = "ml.p3.2xlarge" +IMAGE_URI = "pytorch" +JOB_NAME = "{}-{}".format(IMAGE_URI, TIMESTAMP) +ROLE = "Dummy" +REGION = "us-east-1" +GPU = "ml.p3.2xlarge" +SUPPORTED_GPU_INSTANCE_CLASSES = {"p3", "p3dn", "g4dn", "p4d", "g5"} +UNSUPPORTED_GPU_INSTANCE_CLASSES = EC2_GPU_INSTANCE_CLASSES - SUPPORTED_GPU_INSTANCE_CLASSES + +LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]} + +EXPERIMENT_CONFIG = { + "ExperimentName": "exp", + "TrialName": "trial", + "TrialComponentDisplayName": "tc", +} + + +@pytest.fixture(scope="module") +def cpu_instance_type(): + return "ml.m5.xlarge" + + +@pytest.fixture(name="sagemaker_session", scope="function") +def fixture_sagemaker_session(): + boto_mock = Mock(name="boto_session", region_name=REGION) + session = Mock( + name="sagemaker_session", + boto_session=boto_mock, + boto_region_name=REGION, + config=None, + local_mode=False, + s3_resource=None, + s3_client=None, + ) + + describe = {"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} + session.sagemaker_client.describe_training_job = Mock(return_value=describe) + session.sagemaker_client.list_tags = Mock(return_value=LIST_TAGS_RESULT) + session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME) + session.expand_role = Mock(name="expand_role", return_value=ROLE) + return session + + +def _get_full_gpu_image_uri(version, instance_type, training_compiler_config): + return image_uris.retrieve( + "pytorch-training-compiler", + REGION, + version=version, + py_version="py38", + instance_type=instance_type, + image_scope="training", + container_version=None, + training_compiler_config=training_compiler_config, + ) + + +def _create_train_job(version, instance_type, training_compiler_config, instance_count=1): + return { + "image_uri": _get_full_gpu_image_uri(version, instance_type, training_compiler_config), + "input_mode": "File", + "input_config": [ + { + "ChannelName": "training", + "DataSource": { + "S3DataSource": { + "S3DataDistributionType": "FullyReplicated", + "S3DataType": "S3Prefix", + } + }, + } + ], + "role": ROLE, + "job_name": JOB_NAME, + "output_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + "resource_config": { + "InstanceType": instance_type, + "InstanceCount": instance_count, + "VolumeSizeInGB": 30, + }, + "hyperparameters": { + "sagemaker_program": json.dumps("dummy_script.py"), + "sagemaker_container_log_level": str(logging.INFO), + "sagemaker_job_name": json.dumps(JOB_NAME), + "sagemaker_submit_directory": json.dumps( + "s3://{}/{}/source/sourcedir.tar.gz".format(BUCKET_NAME, JOB_NAME) + ), + "sagemaker_region": '"us-east-1"', + }, + "stop_condition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + 
"tags": None, + "vpc_config": None, + "metric_definitions": None, + "environment": None, + "retry_strategy": None, + "experiment_config": EXPERIMENT_CONFIG, + "debugger_hook_config": { + "CollectionConfigurations": [], + "S3OutputPath": "s3://{}/".format(BUCKET_NAME), + }, + "profiler_rule_configs": [ + { + "RuleConfigurationName": "ProfilerReport-1510006209", + "RuleEvaluatorImage": "503895931360.dkr.ecr.us-east-1.amazonaws.com/sagemaker-debugger-rules:latest", + "RuleParameters": {"rule_to_invoke": "ProfilerReport"}, + } + ], + "profiler_config": {"S3OutputPath": "s3://{}/".format(BUCKET_NAME)}, + } + + +def test_unsupported_BYOC( + pytorch_training_compiler_version, +): + byoc = ( + "1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:" + "1.12.0-" + "gpu-" + "py38-cu113-ubuntu20.04" + ) + with pytest.raises(ValueError): + PyTorch( + image_uri=byoc, + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def test_unsupported_cpu_instance(cpu_instance_type, pytorch_training_compiler_version): + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=cpu_instance_type, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +@pytest.mark.parametrize("unsupported_gpu_instance_class", UNSUPPORTED_GPU_INSTANCE_CLASSES) +def test_unsupported_gpu_instance( + unsupported_gpu_instance_class, pytorch_training_compiler_version +): + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=f"ml.{unsupported_gpu_instance_class}.xlarge", + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +@pytest.mark.xfail(reason="With only 1 supported version, user input is ignored.") +def test_unsupported_framework_version(): + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version="99.99.99", + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def test_unsupported_python_2( + pytorch_training_compiler_version, +): + with pytest.raises(ValueError): + PyTorch( + py_version="py27", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def test_unsupported_instance_group( + pytorch_training_compiler_version, +): + if Version(pytorch_training_compiler_version) < Version("1.12"): + pytest.skip("This test is intended for PyTorch 1.12 and above") + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_groups=[ + InstanceGroup("ml.p3dn.24xlarge", "ml.p3dn.24xlarge", 16), + InstanceGroup("ml.p4d.24xlarge", "ml.p4d.24xlarge", 16), + ], + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + ).fit() + + +def 
test_unsupported_distribution( + pytorch_training_compiler_version, +): + if Version(pytorch_training_compiler_version) < Version("1.12"): + pytest.skip("This test is intended for PyTorch 1.12 and above") + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=2, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"smdistributed": {"dataparallel": {"enabled": True}}}, + ).fit() + + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=2, + instance_type=INSTANCE_TYPE, + transformers_version="4.17", + pytorch_version="1.10", + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"pytorchxla": {"enabled": True}}, + ).fit() + + with pytest.raises(ValueError): + PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + instance_count=2, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"mpi": {"enabled": True}}, + ).fit() + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) +def test_pytorchxla_distribution( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class +): + if Version(pytorch_training_compiler_version) < Version("1.12"): + pytest.skip("This test is intended for PyTorch 1.12 and above") + compiler_config = TrainingCompilerConfig() + instance_type = f"ml.{instance_class}.xlarge" + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=2, + instance_type=instance_type, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(), + distribution={"pytorchxla": {"enabled": True}}, + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, instance_type, compiler_config, instance_count=2 + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + True + ) + expected_train_args["hyperparameters"][PyTorch.LAUNCH_PT_XLA_ENV_NAME] = json.dumps(True) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + False + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) 
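+# name_from_base and time.time are patched below so the generated job name and
+# the "ProfilerReport-1510006209" rule name stay deterministic, letting the
+# expected dict from _create_train_job be compared verbatim to the train call.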
+@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +@pytest.mark.parametrize("instance_class", SUPPORTED_GPU_INSTANCE_CLASSES) +def test_default_compiler_config( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version, instance_class +): + compiler_config = TrainingCompilerConfig() + instance_type = f"ml.{instance_class}.xlarge" + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=instance_type, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=compiler_config, + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, instance_type, compiler_config + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + True + ) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + False + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +def test_debug_compiler_config( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version +): + compiler_config = TrainingCompilerConfig(debug=True) + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=compiler_config, + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + True + ) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + True + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@patch("sagemaker.utils.repack_model", 
MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.estimator.name_from_base", return_value=JOB_NAME) +@patch("time.time", return_value=TIME) +def test_disable_compiler_config( + time, name_from_base, sagemaker_session, pytorch_training_compiler_version +): + compiler_config = TrainingCompilerConfig(enabled=False) + + pt = PyTorch( + py_version="py38", + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + instance_count=INSTANCE_COUNT, + instance_type=INSTANCE_TYPE, + framework_version=pytorch_training_compiler_version, + enable_sagemaker_metrics=False, + compiler_config=TrainingCompilerConfig(enabled=False), + ) + + inputs = "s3://mybucket/train" + + pt.fit(inputs=inputs, experiment_config=EXPERIMENT_CONFIG) + + sagemaker_call_names = [c[0] for c in sagemaker_session.method_calls] + assert sagemaker_call_names == ["train", "logs_for_job"] + boto_call_names = [c[0] for c in sagemaker_session.boto_session.method_calls] + assert boto_call_names == ["resource"] + + expected_train_args = _create_train_job( + pytorch_training_compiler_version, INSTANCE_TYPE, compiler_config + ) + expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs + expected_train_args["enable_sagemaker_metrics"] = False + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_COMPILER] = json.dumps( + False + ) + expected_train_args["hyperparameters"][TrainingCompilerConfig.HP_ENABLE_DEBUG] = json.dumps( + False + ) + + actual_train_args = sagemaker_session.method_calls[0][2] + assert ( + actual_train_args == expected_train_args + ), f"{json.dumps(actual_train_args, indent=2)} != {json.dumps(expected_train_args, indent=2)}" + + +@pytest.mark.parametrize( + ["compiler_enabled", "debug_enabled"], [(True, False), (True, True), (False, False)] +) +def test_attach(sagemaker_session, compiler_enabled, debug_enabled): + training_image = ( + "1.dkr.ecr.us-east-1.amazonaws.com/pytorch-trcomp-training:" + "1.12.0-" + "gpu-" + "py38-cu113-ubuntu20.04" + ) + returned_job_description = { + "AlgorithmSpecification": {"TrainingInputMode": "File", "TrainingImage": training_image}, + "HyperParameters": { + "sagemaker_submit_directory": '"s3://some/sourcedir.tar.gz"', + "sagemaker_program": '"iris-dnn-classifier.py"', + "sagemaker_s3_uri_training": '"sagemaker-3/integ-test-data/tf_iris"', + "sagemaker_container_log_level": '"logging.INFO"', + "sagemaker_job_name": '"trcomp"', + "training_steps": "100", + "sagemaker_region": '"us-east-1"', + TrainingCompilerConfig.HP_ENABLE_COMPILER: json.dumps(compiler_enabled), + TrainingCompilerConfig.HP_ENABLE_DEBUG: json.dumps(debug_enabled), + }, + "RoleArn": "arn:aws:iam::366:role/SageMakerRole", + "ResourceConfig": { + "VolumeSizeInGB": 30, + "InstanceCount": 1, + "InstanceType": "ml.p3.2xlarge", + }, + "StoppingCondition": {"MaxRuntimeInSeconds": 24 * 60 * 60}, + "TrainingJobName": "trcomp", + "TrainingJobStatus": "Completed", + "TrainingJobArn": "arn:aws:sagemaker:us-west-2:336:training-job/trcomp", + "OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/trcomp"}, + "TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"}, + } + sagemaker_session.sagemaker_client.describe_training_job = Mock( + name="describe_training_job", return_value=returned_job_description + ) + + estimator = PyTorch.attach(training_job_name="trcomp", sagemaker_session=sagemaker_session) + assert estimator.latest_training_job.job_name == "trcomp" + assert estimator.py_version == "py38" 
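+ # py_version above and framework_version below are recovered by attach()
+ # from the image tag "1.12.0-gpu-py38-cu113-ubuntu20.04" set in
+ # returned_job_description.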
+ assert estimator.framework_version == "1.12.0" + assert estimator.role == "arn:aws:iam::366:role/SageMakerRole" + assert estimator.instance_count == 1 + assert estimator.max_run == 24 * 60 * 60 + assert estimator.input_mode == "File" + assert estimator.base_job_name == "trcomp" + assert estimator.output_path == "s3://place/output/trcomp" + assert estimator.output_kms_key == "" + assert estimator.hyperparameters()["training_steps"] == "100" + assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_COMPILER] == json.dumps( + compiler_enabled + ) + assert estimator.hyperparameters()[TrainingCompilerConfig.HP_ENABLE_DEBUG] == json.dumps( + debug_enabled + ) + assert estimator.source_dir == "s3://some/sourcedir.tar.gz" + assert estimator.entry_point == "iris-dnn-classifier.py" + + +@patch("sagemaker.utils.repack_model", MagicMock()) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +def test_register_pytorch_model_auto_infer_framework( + sagemaker_session, pytorch_training_compiler_version +): + + model_package_group_name = "test-pt-register-model" + content_types = ["application/json"] + response_types = ["application/json"] + inference_instances = ["ml.m4.xlarge"] + transform_instances = ["ml.m4.xlarge"] + image_uri = "fakeimage" + + pt_model = PyTorchModel( + model_data="s3://some/data.tar.gz", + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=pytorch_training_compiler_version, + py_version="py38", + sagemaker_session=sagemaker_session, + ) + + pt_model.register( + content_types, + response_types, + inference_instances, + transform_instances, + model_package_group_name=model_package_group_name, + marketplace_cert=True, + image_uri=image_uri, + ) + + expected_create_model_package_request = { + "containers": [ + { + "Image": image_uri, + "Environment": ANY, + "ModelDataUrl": ANY, + "Framework": "PYTORCH", + "FrameworkVersion": pytorch_training_compiler_version, + } + ], + "content_types": content_types, + "response_types": response_types, + "inference_instances": inference_instances, + "transform_instances": transform_instances, + "model_package_group_name": model_package_group_name, + "marketplace_cert": True, + } + + sagemaker_session.create_model_package_from_containers.assert_called_with( + **expected_create_model_package_request + ) From d779d1b8296242eb15637e85272a1a50a7ee897b Mon Sep 17 00:00:00 2001 From: HappyAmazonian <91216626+HappyAmazonian@users.noreply.github.com> Date: Tue, 6 Dec 2022 16:37:16 -0800 Subject: [PATCH 19/43] feature: Add Neo image uri config for Pytorch 1.12 (#3507) --- .../image_uri_config/neo-pytorch.json | 36 ++++++++++++++++++- tests/data/pytorch_neo/code/inference.py | 4 +-- 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/image_uri_config/neo-pytorch.json b/src/sagemaker/image_uri_config/neo-pytorch.json index bd15a6450e..c46dd3de5d 100644 --- a/src/sagemaker/image_uri_config/neo-pytorch.json +++ b/src/sagemaker/image_uri_config/neo-pytorch.json @@ -11,7 +11,9 @@ "1.7.0": "1.7", "1.7.1": "1.7", "1.8.0": "1.8", - "1.8.1": "1.8" + "1.8.1": "1.8", + "1.12.0": "1.12", + "1.12.1": "1.12" }, "versions": { "1.4": { @@ -173,6 +175,38 @@ "us-west-2": "301217895009" }, "repository": "sagemaker-inference-pytorch" + }, + "1.12": { + "py_versions": ["py3"], + "registries": { + "af-south-1": "774647643957", + "ap-east-1": "110948597952", + "ap-northeast-1": "941853720454", + "ap-northeast-2": "151534178276", + "ap-northeast-3": "925152966179", + "ap-south-1": "763008648453", + "ap-southeast-1": "324986816169", + 
"ap-southeast-2": "355873309152", + "ca-central-1": "464438896020", + "cn-north-1": "472730292857", + "cn-northwest-1": "474822919863", + "eu-central-1": "746233611703", + "eu-north-1": "601324751636", + "eu-south-1": "966458181534", + "eu-west-1": "802834080501", + "eu-west-2": "205493899709", + "eu-west-3": "254080097072", + "me-south-1": "836785723513", + "sa-east-1": "756306329178", + "us-east-1": "785573368785", + "us-east-2": "007439368137", + "us-gov-west-1": "263933020539", + "us-iso-east-1": "167761179201", + "us-isob-east-1": "406031935815", + "us-west-1": "710691900526", + "us-west-2": "301217895009" + }, + "repository": "sagemaker-inference-pytorch" } } } diff --git a/tests/data/pytorch_neo/code/inference.py b/tests/data/pytorch_neo/code/inference.py index 5b89c2bebc..79fe66d716 100644 --- a/tests/data/pytorch_neo/code/inference.py +++ b/tests/data/pytorch_neo/code/inference.py @@ -71,8 +71,8 @@ def model_fn(model_dir): logger.info("model_fn") neopytorch.config(model_dir=model_dir, neo_runtime=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - # The compiled model is saved as "model.pth" - model = torch.jit.load(os.path.join(model_dir, "model.pth"), map_location=device) + # The compiled model is saved as "model.pth" or "model.pt" + model = torch.jit.load(os.path.join(model_dir, "model.pt"), map_location=device) # It is recommended to run warm-up inference during model load sample_input_path = os.path.join(model_dir, "sample_input.pkl") From 83327fb9ef5eb5f44c9fd3f8925c7791576c9a37 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 7 Dec 2022 03:20:15 +0000 Subject: [PATCH 20/43] prepare release v2.120.0 --- CHANGELOG.md | 13 +++++++++++++ VERSION | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b8b3155231..71894ff29d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## v2.120.0 (2022-12-07) + +### Features + + * Add Neo image uri config for Pytorch 1.12 + * Adding support for SageMaker Training Compiler in PyTorch estimator starting 1.12 + * Update registries with new region account number mappings. 
+ * Add DXB region to frameworks by DLC + +### Bug Fixes and Other Changes + + * support idempotency for framework and spark processors + ## v2.119.0 (2022-12-03) ### Features diff --git a/VERSION b/VERSION index dda4128cf2..7de9d18b4e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.119.1.dev0 +2.120.0 From 5bffb04b78e8cd6422654008511aa61ca6f66efb Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 7 Dec 2022 03:20:17 +0000 Subject: [PATCH 21/43] update development version to v2.120.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 7de9d18b4e..73c4cd6968 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.120.0 +2.120.1.dev0 From b828396c55082bc5f06092be41555729d775874a Mon Sep 17 00:00:00 2001 From: Malav Shastri <57682969+malav-shastri@users.noreply.github.com> Date: Wed, 7 Dec 2022 20:58:37 +0530 Subject: [PATCH 22/43] feature: Algorithms Region Expansion OSU/DXB (#3508) Co-authored-by: Malav Shastri --- .../image_uri_config/blazingtext.json | 2 ++ .../factorization-machines.json | 2 ++ .../image_uri_config/forecasting-deepar.json | 2 ++ .../image-classification.json | 2 ++ .../image_uri_config/ipinsights.json | 2 ++ src/sagemaker/image_uri_config/kmeans.json | 2 ++ src/sagemaker/image_uri_config/knn.json | 2 ++ .../image_uri_config/linear-learner.json | 2 ++ src/sagemaker/image_uri_config/ntm.json | 2 ++ .../image_uri_config/object-detection.json | 2 ++ .../image_uri_config/object2vec.json | 2 ++ src/sagemaker/image_uri_config/pca.json | 2 ++ .../image_uri_config/randomcutforest.json | 2 ++ .../semantic-segmentation.json | 2 ++ src/sagemaker/image_uri_config/seq2seq.json | 2 ++ src/sagemaker/image_uri_config/sklearn.json | 14 ++++++++ src/sagemaker/image_uri_config/xgboost.json | 36 +++++++++++++++++++ tests/unit/sagemaker/image_uris/test_algos.py | 4 +++ .../unit/sagemaker/image_uris/test_sklearn.py | 2 ++ .../unit/sagemaker/image_uris/test_xgboost.py | 4 +++ 20 files changed, 90 insertions(+) diff --git a/src/sagemaker/image_uri_config/blazingtext.json b/src/sagemaker/image_uri_config/blazingtext.json index c588d65c73..ae4295c59a 100644 --- a/src/sagemaker/image_uri_config/blazingtext.json +++ b/src/sagemaker/image_uri_config/blazingtext.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/factorization-machines.json b/src/sagemaker/image_uri_config/factorization-machines.json index 0f9930357f..8fb1895707 100644 --- a/src/sagemaker/image_uri_config/factorization-machines.json +++ b/src/sagemaker/image_uri_config/factorization-machines.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/forecasting-deepar.json b/src/sagemaker/image_uri_config/forecasting-deepar.json index 1acc96ed3e..e9beb7acb6 100644 --- a/src/sagemaker/image_uri_config/forecasting-deepar.json +++ 
b/src/sagemaker/image_uri_config/forecasting-deepar.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "522234722520", "us-east-2": "566113047672", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "156387875391" diff --git a/src/sagemaker/image_uri_config/image-classification.json b/src/sagemaker/image_uri_config/image-classification.json index 44ccb3f08d..61e037da08 100644 --- a/src/sagemaker/image_uri_config/image-classification.json +++ b/src/sagemaker/image_uri_config/image-classification.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/ipinsights.json b/src/sagemaker/image_uri_config/ipinsights.json index 4e56c149dc..cf3c70194f 100644 --- a/src/sagemaker/image_uri_config/ipinsights.json +++ b/src/sagemaker/image_uri_config/ipinsights.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/kmeans.json b/src/sagemaker/image_uri_config/kmeans.json index 952724ce11..e8e947f094 100644 --- a/src/sagemaker/image_uri_config/kmeans.json +++ b/src/sagemaker/image_uri_config/kmeans.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/knn.json b/src/sagemaker/image_uri_config/knn.json index 79b239966d..89e8ef6224 100644 --- a/src/sagemaker/image_uri_config/knn.json +++ b/src/sagemaker/image_uri_config/knn.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/linear-learner.json b/src/sagemaker/image_uri_config/linear-learner.json index bb027284ab..606edd3791 100644 --- a/src/sagemaker/image_uri_config/linear-learner.json +++ b/src/sagemaker/image_uri_config/linear-learner.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": 
"237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/ntm.json b/src/sagemaker/image_uri_config/ntm.json index 115264b346..16f9565405 100644 --- a/src/sagemaker/image_uri_config/ntm.json +++ b/src/sagemaker/image_uri_config/ntm.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/object-detection.json b/src/sagemaker/image_uri_config/object-detection.json index 6a7ba03695..67b60fe587 100644 --- a/src/sagemaker/image_uri_config/object-detection.json +++ b/src/sagemaker/image_uri_config/object-detection.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/object2vec.json b/src/sagemaker/image_uri_config/object2vec.json index 39614d1273..b166cc96ff 100644 --- a/src/sagemaker/image_uri_config/object2vec.json +++ b/src/sagemaker/image_uri_config/object2vec.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/pca.json b/src/sagemaker/image_uri_config/pca.json index 5f87d8528c..11982e2197 100644 --- a/src/sagemaker/image_uri_config/pca.json +++ b/src/sagemaker/image_uri_config/pca.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/randomcutforest.json b/src/sagemaker/image_uri_config/randomcutforest.json index ae7a3574be..15dc84dfc5 100644 --- a/src/sagemaker/image_uri_config/randomcutforest.json +++ b/src/sagemaker/image_uri_config/randomcutforest.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107" diff --git a/src/sagemaker/image_uri_config/semantic-segmentation.json b/src/sagemaker/image_uri_config/semantic-segmentation.json index 866dd606b4..f49bc43109 100644 --- a/src/sagemaker/image_uri_config/semantic-segmentation.json +++ 
b/src/sagemaker/image_uri_config/semantic-segmentation.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/seq2seq.json b/src/sagemaker/image_uri_config/seq2seq.json index bb3daf93b6..87810ad09d 100644 --- a/src/sagemaker/image_uri_config/seq2seq.json +++ b/src/sagemaker/image_uri_config/seq2seq.json @@ -22,10 +22,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" diff --git a/src/sagemaker/image_uri_config/sklearn.json b/src/sagemaker/image_uri_config/sklearn.json index 7961fde282..4d093f5f62 100644 --- a/src/sagemaker/image_uri_config/sklearn.json +++ b/src/sagemaker/image_uri_config/sklearn.json @@ -24,10 +24,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -57,10 +59,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -90,10 +94,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -127,10 +133,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -160,10 +168,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -193,10 +203,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": 
"237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -230,10 +242,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" diff --git a/src/sagemaker/image_uri_config/xgboost.json b/src/sagemaker/image_uri_config/xgboost.json index a809083c4a..946e78ecc4 100644 --- a/src/sagemaker/image_uri_config/xgboost.json +++ b/src/sagemaker/image_uri_config/xgboost.json @@ -25,10 +25,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" @@ -58,10 +60,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -91,10 +95,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -124,10 +130,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -155,10 +163,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -186,10 +196,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -217,10 +229,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -248,10 +262,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": 
"801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -286,10 +302,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032" @@ -319,10 +337,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -352,10 +372,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -385,10 +407,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -416,10 +440,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -447,10 +473,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -478,10 +506,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -509,10 +539,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -544,10 +576,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": 
"272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" @@ -575,10 +609,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249" diff --git a/tests/unit/sagemaker/image_uris/test_algos.py b/tests/unit/sagemaker/image_uris/test_algos.py index 454d375b4b..443727094a 100644 --- a/tests/unit/sagemaker/image_uris/test_algos.py +++ b/tests/unit/sagemaker/image_uris/test_algos.py @@ -68,10 +68,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "382416733822", "us-east-2": "404615174143", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "174872318107", @@ -155,10 +157,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032", diff --git a/tests/unit/sagemaker/image_uris/test_sklearn.py b/tests/unit/sagemaker/image_uris/test_sklearn.py index d0fcbdb300..8563753e8c 100644 --- a/tests/unit/sagemaker/image_uris/test_sklearn.py +++ b/tests/unit/sagemaker/image_uris/test_sklearn.py @@ -37,10 +37,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249", diff --git a/tests/unit/sagemaker/image_uris/test_xgboost.py b/tests/unit/sagemaker/image_uris/test_xgboost.py index 78ab7e10ee..4d0f9f1dc3 100644 --- a/tests/unit/sagemaker/image_uris/test_xgboost.py +++ b/tests/unit/sagemaker/image_uris/test_xgboost.py @@ -35,10 +35,12 @@ "eu-west-3": "749696950732", "eu-south-1": "257386234256", "me-south-1": "249704162688", + "me-central-1": "272398656194", "sa-east-1": "855470959533", "us-east-1": "811284229777", "us-east-2": "825641698319", "us-gov-west-1": "226302683700", + "us-gov-east-1": "237065988967", "us-iso-east-1": "490574956308", "us-west-1": "632365934929", "us-west-2": "433757028032", @@ -67,10 +69,12 @@ "eu-west-3": "659782779980", "eu-south-1": "978288397137", "me-south-1": "801668240914", + "me-central-1": "272398656194", "sa-east-1": "737474898029", "us-east-1": "683313688378", "us-east-2": "257758044811", "us-gov-west-1": "414596584902", + "us-gov-east-1": "237065988967", "us-iso-east-1": "833128469047", "us-west-1": "746614075791", "us-west-2": "246618743249", From 357f73226c9c5fe651ea74169cafe585e1092ad0 Mon Sep 17 00:00:00 2001 From: Navin Soni Date: Wed, 7 Dec 2022 10:36:33 -0800 Subject: [PATCH 23/43] fix: Add constraints file for apache-airflow 
(#3510)

---
 requirements/extras/test_requirements.txt | 1 +
 tox.ini | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt
index b52f394bd0..fe93fd4d0e 100644
--- a/requirements/extras/test_requirements.txt
+++ b/requirements/extras/test_requirements.txt
@@ -11,6 +11,7 @@ contextlib2==21.6.0
 awslogs==0.14.0
 black==22.3.0
 stopit==1.1.2
+# Update tox.ini to have correct version of airflow constraints file
 apache-airflow==2.4.1
 apache-airflow-providers-amazon==4.0.0
 attrs==22.1.0
diff --git a/tox.ini b/tox.ini
index 2d5fdf0b40..3a398ca51d 100644
--- a/tox.ini
+++ b/tox.ini
@@ -73,6 +73,8 @@ passenv =
 # Can be used to specify which tests to run, e.g.: tox -- -s
 commands =
 python -c "import os; os.system('install-custom-pkgs --install-boto-wheels')"
+ pip install 'apache-airflow==2.4.1' --constraint "https://raw.githubusercontent.com/apache/airflow/constraints-2.4.1/constraints-3.10.txt"
+
 pytest --cov=sagemaker --cov-append {posargs}
 {env:IGNORE_COVERAGE:} coverage report -i --fail-under=86
 deps = .[test]

From a28d1dd129ecceb612d5e8927b6be72937711722 Mon Sep 17 00:00:00 2001
From: Brock Wade
Date: Wed, 7 Dec 2022 19:14:12 -0800
Subject: [PATCH 24/43] fix: FrameworkProcessor S3 uploads (#3493)

Co-authored-by: Brock Wade
Co-authored-by: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com>
---
 src/sagemaker/processing.py | 47 +++-
 .../data/pipeline/test_source_dir/script_1.py | 11 +
 .../data/pipeline/test_source_dir/script_2.py | 9 +
 .../pipeline/test_source_dir_2/script_2.py | 9 +
 .../workflow/test_processing_steps.py | 249 +++++++++++++++++-
 .../integ/sagemaker/workflow/test_workflow.py | 8 +-
 6 files changed, 322 insertions(+), 11 deletions(-)
 create mode 100644 tests/data/pipeline/test_source_dir/script_1.py
 create mode 100644 tests/data/pipeline/test_source_dir/script_2.py
 create mode 100644 tests/data/pipeline/test_source_dir_2/script_2.py

diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py
index 81e3d34b1d..01d4361197 100644
--- a/src/sagemaker/processing.py
+++ b/src/sagemaker/processing.py
@@ -1741,13 +1741,7 @@ def _pack_and_upload_code(
 raise RuntimeError("S3 source_dir file must be named `sourcedir.tar.gz.`")
 script = estimator.uploaded_code.script_name
- s3_runproc_sh = S3Uploader.upload_string_as_file_body(
- self._generate_framework_script(script),
- desired_s3_uri=entrypoint_s3_uri,
- kms_key=kms_key,
- sagemaker_session=self.sagemaker_session,
- )
- logger.info("runproc.sh uploaded to %s", s3_runproc_sh)
+ s3_runproc_sh = self._create_and_upload_runproc(script, kms_key, entrypoint_s3_uri)

 return s3_runproc_sh, inputs, job_name

@@ -1857,3 +1851,42 @@ def _set_entrypoint(self, command, user_script_name):
 )
 )
 self.entrypoint = self.framework_entrypoint_command + [user_script_location]
+
+ def _create_and_upload_runproc(self, user_script, kms_key, entrypoint_s3_uri):
+ """Create runproc shell script and upload to S3 bucket.
+
+ If leveraging a pipeline session with optimized S3 artifact paths,
+ the runproc.sh file is hashed and uploaded to a separate S3 location.
+
+
+ Args:
+ user_script (str): Relative path to ```code``` in the source bundle
+ - e.g. 'process.py'.
+ kms_key (str): The KMS key used for encryption.
+ entrypoint_s3_uri (str): The S3 upload path for the runproc script.
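+
+ A rough illustration of the pipeline-optimized path (hypothetical
+ names): for a pipeline named "my-pipeline", the generated runproc.sh
+ body is hashed and uploaded to
+ "s3://<default-bucket>/my-pipeline/code/<hash>/runproc.sh", so an
+ unchanged script keeps a stable S3 key and step caching remains
+ effective; otherwise the script is uploaded to ``entrypoint_s3_uri``.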
+ """
+ from sagemaker.workflow.utilities import _pipeline_config, hash_object
+
+ if _pipeline_config and _pipeline_config.pipeline_name:
+ runproc_file_str = self._generate_framework_script(user_script)
+ runproc_file_hash = hash_object(runproc_file_str)
+ s3_uri = (
+ f"s3://{self.sagemaker_session.default_bucket()}/{_pipeline_config.pipeline_name}/"
+ f"code/{runproc_file_hash}/runproc.sh"
+ )
+ s3_runproc_sh = S3Uploader.upload_string_as_file_body(
+ runproc_file_str,
+ desired_s3_uri=s3_uri,
+ kms_key=kms_key,
+ sagemaker_session=self.sagemaker_session,
+ )
+ else:
+ s3_runproc_sh = S3Uploader.upload_string_as_file_body(
+ self._generate_framework_script(user_script),
+ desired_s3_uri=entrypoint_s3_uri,
+ kms_key=kms_key,
+ sagemaker_session=self.sagemaker_session,
+ )
+ logger.info("runproc.sh uploaded to %s", s3_runproc_sh)
+
+ return s3_runproc_sh
diff --git a/tests/data/pipeline/test_source_dir/script_1.py b/tests/data/pipeline/test_source_dir/script_1.py
new file mode 100644
index 0000000000..4a427b1898
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir/script_1.py
@@ -0,0 +1,11 @@
+"""
+Integ test file script_1.py
+"""
+import pathlib
+
+if __name__ == "__main__":
+
+ print("writing file to /opt/ml/processing/test/test.py...")
+ pathlib.Path("/opt/ml/processing/test").mkdir(parents=True, exist_ok=True)
+ with open("/opt/ml/processing/test/test.py", "w") as f:
+ f.write('print("test...")')
diff --git a/tests/data/pipeline/test_source_dir/script_2.py b/tests/data/pipeline/test_source_dir/script_2.py
new file mode 100644
index 0000000000..6245dac987
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir/script_2.py
@@ -0,0 +1,9 @@
+"""
+Integ test file script_2.py
+"""
+
+if __name__ == "__main__":
+
+ print("reading file: /opt/ml/processing/test/test.py")
+ with open("/opt/ml/processing/test/test.py", "r") as f:
+ print(f.read())
diff --git a/tests/data/pipeline/test_source_dir_2/script_2.py b/tests/data/pipeline/test_source_dir_2/script_2.py
new file mode 100644
index 0000000000..6245dac987
--- /dev/null
+++ b/tests/data/pipeline/test_source_dir_2/script_2.py
@@ -0,0 +1,9 @@
+"""
+Integ test file script_2.py
+"""
+
+if __name__ == "__main__":
+
+ print("reading file: /opt/ml/processing/test/test.py")
+ with open("/opt/ml/processing/test/test.py", "r") as f:
+ print(f.read())
diff --git a/tests/integ/sagemaker/workflow/test_processing_steps.py b/tests/integ/sagemaker/workflow/test_processing_steps.py
index 781bce85a7..238eff6123 100644
--- a/tests/integ/sagemaker/workflow/test_processing_steps.py
+++ b/tests/integ/sagemaker/workflow/test_processing_steps.py
@@ -17,15 +17,18 @@
 import re
 import subprocess
 from datetime import datetime
+from pathlib import Path

 import pytest
 from botocore.exceptions import WaiterError
+from sagemaker.workflow.utilities import hash_files_or_dirs, hash_object

 from sagemaker import image_uris, get_execution_role, utils
 from sagemaker.dataset_definition import DatasetDefinition, AthenaDatasetDefinition
-from sagemaker.processing import ProcessingInput, ProcessingOutput
-from sagemaker.s3 import S3Uploader
-from sagemaker.sklearn import SKLearnProcessor
+from sagemaker.processing import ProcessingInput, ProcessingOutput, FrameworkProcessor
+from sagemaker.s3 import S3Uploader, S3Downloader
+from sagemaker.sklearn import SKLearnProcessor, SKLearn
+from sagemaker.tensorflow import TensorFlow
 from sagemaker.workflow.parameters import ParameterInteger, ParameterString
 from sagemaker.workflow.pipeline import Pipeline
 from sagemaker.workflow.steps import 
( @@ -379,6 +382,203 @@ def test_one_step_framework_processing_pipeline( pass +def test_multi_step_framework_processing_pipeline_same_source_dir( + pipeline_session, role, pipeline_name +): + default_bucket = pipeline_session.default_bucket() + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + + SOURCE_DIR = "/pipeline/test_source_dir" + + framework_processor_tf = FrameworkProcessor( + role=role, + instance_type="ml.m5.xlarge", + instance_count=1, + estimator_cls=TensorFlow, + framework_version="2.9", + py_version="py39", + sagemaker_session=pipeline_session, + ) + + framework_processor_sk = FrameworkProcessor( + framework_version="1.0-1", + instance_type="ml.m5.xlarge", + instance_count=1, + base_job_name="my-job", + role=role, + estimator_cls=SKLearn, + sagemaker_session=pipeline_session, + ) + + step_1 = ProcessingStep( + name="Step-1", + step_args=framework_processor_tf.run( + code="script_1.py", + source_dir=DATA_DIR + SOURCE_DIR, + outputs=[ProcessingOutput(output_name="test", source="/opt/ml/processing/test")], + ), + cache_config=cache_config, + ) + + step_2 = ProcessingStep( + name="Step-2", + step_args=framework_processor_sk.run( + code="script_2.py", + source_dir=DATA_DIR + SOURCE_DIR, + inputs=[ + ProcessingInput( + source=step_1.properties.ProcessingOutputConfig.Outputs["test"].S3Output.S3Uri, + destination="/opt/ml/processing/test", + ), + ], + ), + cache_config=cache_config, + ) + + pipeline = Pipeline( + name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session + ) + try: + pipeline.create(role) + definition = json.loads(pipeline.definition()) + + source_dir_1_s3_uri, entry_point_1 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_tf, + default_bucket, + pipeline_name, + definition["Steps"][0], + SOURCE_DIR, + "script_1.py", + ) + source_dir_2_s3_uri, entry_point_2 = _verify_code_artifacts_of_framework_processing_step( + pipeline_session, + framework_processor_sk, + default_bucket, + pipeline_name, + definition["Steps"][1], + SOURCE_DIR, + "script_2.py", + ) + + # the same local source_dirs should have the same s3 paths + assert source_dir_1_s3_uri == source_dir_2_s3_uri + + # verify different entry_point paths + assert entry_point_1 != entry_point_2 + + execution = pipeline.start(parameters={}) + try: + execution.wait(delay=540, max_attempts=3) + except WaiterError: + pass + + execution_steps = execution.list_steps() + assert len(execution_steps) == 2 + for step in execution_steps: + assert step["StepStatus"] == "Succeeded" + + finally: + try: + pipeline.delete() + except Exception: + pass + + +def test_multi_step_framework_processing_pipeline_different_source_dir( + pipeline_session, role, pipeline_name +): + default_bucket = pipeline_session.default_bucket() + cache_config = CacheConfig(enable_caching=True, expire_after="PT1H") + + SOURCE_DIR_1 = "/pipeline/test_source_dir" + SOURCE_DIR_2 = "/pipeline/test_source_dir_2" + + framework_processor_tf = FrameworkProcessor( + role=role, + instance_type="ml.m5.xlarge", + instance_count=1, + estimator_cls=TensorFlow, + framework_version="2.9", + py_version="py39", + sagemaker_session=pipeline_session, + ) + + step_1 = ProcessingStep( + name="Step-1", + step_args=framework_processor_tf.run( + code="script_1.py", + source_dir=DATA_DIR + SOURCE_DIR_1, + outputs=[ProcessingOutput(output_name="test", source="/opt/ml/processing/test")], + ), + cache_config=cache_config, + ) + + step_2 = ProcessingStep( + name="Step-2", + 
step_args=framework_processor_tf.run(
+            code="script_2.py",
+            source_dir=DATA_DIR + SOURCE_DIR_2,
+            inputs=[
+                ProcessingInput(
+                    source=step_1.properties.ProcessingOutputConfig.Outputs["test"].S3Output.S3Uri,
+                    destination="/opt/ml/processing/test",
+                ),
+            ],
+        ),
+        cache_config=cache_config,
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name, steps=[step_1, step_2], sagemaker_session=pipeline_session
+    )
+    try:
+        pipeline.create(role)
+        definition = json.loads(pipeline.definition())
+
+        source_dir_1_s3_uri, entry_point_1 = _verify_code_artifacts_of_framework_processing_step(
+            pipeline_session,
+            framework_processor_tf,
+            default_bucket,
+            pipeline_name,
+            definition["Steps"][0],
+            SOURCE_DIR_1,
+            "script_1.py",
+        )
+        source_dir_2_s3_uri, entry_point_2 = _verify_code_artifacts_of_framework_processing_step(
+            pipeline_session,
+            framework_processor_tf,
+            default_bucket,
+            pipeline_name,
+            definition["Steps"][1],
+            SOURCE_DIR_2,
+            "script_2.py",
+        )
+
+        # different local source_dirs should have different s3 paths
+        assert source_dir_1_s3_uri != source_dir_2_s3_uri
+
+        # verify different entry_point paths
+        assert entry_point_1 != entry_point_2
+
+        execution = pipeline.start(parameters={})
+        try:
+            execution.wait(delay=540, max_attempts=3)
+        except WaiterError:
+            pass
+
+        execution_steps = execution.list_steps()
+        assert len(execution_steps) == 2
+        for step in execution_steps:
+            assert step["StepStatus"] == "Succeeded"
+
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
 def test_one_step_pyspark_processing_pipeline(
     sagemaker_session,
     role,
@@ -796,3 +996,46 @@ def test_two_processing_job_depends_on(
             pipeline.delete()
         except Exception:
             pass
+
+
+def _verify_code_artifacts_of_framework_processing_step(
+    pipeline_session, processor, bucket, pipeline_name, step_definition, source_dir, entry_point
+):
+
+    source_dir_s3_uri = (
+        f"s3://{bucket}/{pipeline_name}" f"/code/{hash_files_or_dirs([DATA_DIR + source_dir])}"
+    )
+
+    # verify runproc.sh prefix is different from code artifact prefix
+    runprocs = []
+    for input_obj in step_definition["Arguments"]["ProcessingInputs"]:
+        if input_obj["InputName"] == "entrypoint":
+            s3_uri = input_obj["S3Input"]["S3Uri"]
+            runprocs.append(s3_uri)
+
+            assert Path(s3_uri).parent != Path(source_dir_s3_uri)
+
+    # verify only one entrypoint generated per step
+    assert len(runprocs) == 1
+
+    expected_source_dir_tar = (
+        f"{pipeline_name}"
+        f"/code/{hash_files_or_dirs([DATA_DIR + source_dir])}/sourcedir.tar.gz"
+    )
+
+    step_script = processor._generate_framework_script(entry_point)
+    expected_step_artifact = f"{pipeline_name}/code/{hash_object(step_script)}/runproc.sh"
+
+    expected_prefix = f"{pipeline_name}/code"
+    s3_code_objects = pipeline_session.list_s3_files(bucket=bucket, key_prefix=expected_prefix)
+
+    # verify all distinct artifacts were uploaded
+    assert expected_source_dir_tar in s3_code_objects
+    assert expected_step_artifact in s3_code_objects
+
+    # verify runprocs contain the correct commands
+    step_runproc = S3Downloader.read_file(
+        f"s3://{bucket}/{expected_step_artifact}", pipeline_session
+    )
+    assert f"python {entry_point}" in step_runproc
+    return source_dir_s3_uri, expected_step_artifact
diff --git a/tests/integ/sagemaker/workflow/test_workflow.py b/tests/integ/sagemaker/workflow/test_workflow.py
index 634ef752d6..44f4e2d26e 100644
--- a/tests/integ/sagemaker/workflow/test_workflow.py
+++ b/tests/integ/sagemaker/workflow/test_workflow.py
@@ -1168,7 +1168,13 @@ def walk():

 def test_caching_behavior(
-
pipeline_session, role, cpu_instance_type, pipeline_name, script_dir, athena_dataset_definition + pipeline_session, + role, + cpu_instance_type, + pipeline_name, + script_dir, + athena_dataset_definition, + region_name, ): default_bucket = pipeline_session.default_bucket() data_path = os.path.join(DATA_DIR, "workflow") From 11d24754b0a8228893f6663ac1ca5048b8a6e794 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 8 Dec 2022 06:16:54 +0000 Subject: [PATCH 25/43] prepare release v2.121.0 --- CHANGELOG.md | 11 +++++++++++ VERSION | 2 +- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 71894ff29d..29dad5f19f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,16 @@ # Changelog +## v2.121.0 (2022-12-08) + +### Features + + * Algorithms Region Expansion OSU/DXB + +### Bug Fixes and Other Changes + + * FrameworkProcessor S3 uploads + * Add constraints file for apache-airflow + ## v2.120.0 (2022-12-07) ### Features diff --git a/VERSION b/VERSION index 73c4cd6968..7f1e14b5a9 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.120.1.dev0 +2.121.0 From 24171b5efcb9c528f159334d6252835ef10bbcb2 Mon Sep 17 00:00:00 2001 From: ci Date: Thu, 8 Dec 2022 06:16:55 +0000 Subject: [PATCH 26/43] update development version to v2.121.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 7f1e14b5a9..28b52ee8d5 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.0 +2.121.1.dev0 From d5847d5ebad840c5f47204742302d91064904be8 Mon Sep 17 00:00:00 2001 From: Loki Date: Fri, 9 Dec 2022 03:10:14 +0530 Subject: [PATCH 27/43] Fix: Differentiate SageMaker Training Compiler's PT DLCs from base PT DLC (#3515) --- src/sagemaker/image_uri_config/pytorch-training-compiler.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/image_uri_config/pytorch-training-compiler.json b/src/sagemaker/image_uri_config/pytorch-training-compiler.json index 892ff4237d..fd7df875a3 100644 --- a/src/sagemaker/image_uri_config/pytorch-training-compiler.json +++ b/src/sagemaker/image_uri_config/pytorch-training-compiler.json @@ -34,7 +34,7 @@ "us-west-1": "763104351884", "us-west-2": "763104351884" }, - "repository": "pytorch-training" + "repository": "pytorch-trcomp-training" } } } From 3f6ea884a564090f826fab46270429db553c7b3b Mon Sep 17 00:00:00 2001 From: evakravi <69981223+evakravi@users.noreply.github.com> Date: Thu, 8 Dec 2022 17:17:44 -0500 Subject: [PATCH 28/43] fix: Fix failing jumpstart cache unit tests (#3514) --- setup.py | 2 +- src/sagemaker/jumpstart/cache.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 4327045760..f366b147b8 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ def read_requirements(filename): "protobuf3-to-dict>=0.1.5,<1.0", "smdebug_rulesconfig==1.0.1", "importlib-metadata>=1.4.0,<5.0", - "packaging>=20.0", + "packaging==20.9", "pandas", "pathos", "schema", diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 202edff9ad..db607770a7 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -20,7 +20,7 @@ import boto3 import botocore from packaging.version import Version -from packaging.specifiers import SpecifierSet +from packaging.specifiers import SpecifierSet, InvalidSpecifier from sagemaker.jumpstart.constants import ( ENV_VARIABLE_JUMPSTART_MANIFEST_LOCAL_ROOT_DIR_OVERRIDE, ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE, @@ -371,7 +371,10 @@ def 
_select_version(
             return None
         return str(max(available_versions))

-        spec = SpecifierSet(f"=={semantic_version_str}")
+        try:
+            spec = SpecifierSet(f"=={semantic_version_str}")
+        except InvalidSpecifier:
+            raise KeyError(f"Bad semantic version: {semantic_version_str}")
         available_versions_filtered = list(spec.filter(available_versions))
         return (
             str(max(available_versions_filtered)) if available_versions_filtered != [] else None

From 4570aa6078e75ba0d259f8196891b7856790a435 Mon Sep 17 00:00:00 2001
From: qidewenwhen <32910701+qidewenwhen@users.noreply.github.com>
Date: Thu, 8 Dec 2022 19:00:48 -0800
Subject: [PATCH 29/43] fix: Pop out ModelPackageName from pipeline definition
 (#3472)

Co-authored-by: Dewen Qi
---
 src/sagemaker/workflow/_utils.py              |  12 ++
 .../sagemaker/workflow/test_model_steps.py    |   1 +
 tests/unit/sagemaker/workflow/conftest.py     |  75 +++++++++
 .../sagemaker/workflow/test_model_step.py     | 147 +++++++-----------
 tests/unit/sagemaker/workflow/test_utils.py   |  54 +------
 5 files changed, 150 insertions(+), 139 deletions(-)
 create mode 100644 tests/unit/sagemaker/workflow/conftest.py

diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py
index 8ba65f1eee..cdef9537c1 100644
--- a/src/sagemaker/workflow/_utils.py
+++ b/src/sagemaker/workflow/_utils.py
@@ -13,6 +13,7 @@
 """Scrapper utilities to support repacking of models."""
 from __future__ import absolute_import

+import logging
 import os
 import shutil
 import tarfile
@@ -37,6 +38,8 @@
 if TYPE_CHECKING:
     from sagemaker.workflow.step_collections import StepCollection

+logger = logging.getLogger(__name__)
+
 FRAMEWORK_VERSION = "0.23-1"
 INSTANCE_TYPE = "ml.m5.large"
 REPACK_SCRIPT = "_repack_model.py"
@@ -479,10 +482,19 @@ def arguments(self) -> RequestType:
         request_dict = get_create_model_package_request(**model_package_args)
         # these are not available in the workflow service and will cause rejection
+        warn_msg_template = (
+            "Popping out '%s' from the pipeline definition "
+            "since it will be overridden at pipeline execution time."
+        )
         if "CertifyForMarketplace" in request_dict:
             request_dict.pop("CertifyForMarketplace")
+            logger.warning(warn_msg_template, "CertifyForMarketplace")
         if "Description" in request_dict:
             request_dict.pop("Description")
+            logger.warning(warn_msg_template, "Description")
+        if "ModelPackageName" in request_dict:
+            request_dict.pop("ModelPackageName")
+            logger.warning(warn_msg_template, "ModelPackageName")

         return request_dict

diff --git a/tests/integ/sagemaker/workflow/test_model_steps.py b/tests/integ/sagemaker/workflow/test_model_steps.py
index 31c518b100..f25723c440 100644
--- a/tests/integ/sagemaker/workflow/test_model_steps.py
+++ b/tests/integ/sagemaker/workflow/test_model_steps.py
@@ -112,6 +112,7 @@ def test_pytorch_training_model_registration_and_creation_without_custom_inferen
         inference_instances=["ml.m5.xlarge"],
         transform_instances=["ml.m5.xlarge"],
         description="test-description",
+        model_package_name="model-pkg-name-will-be-popped-out",
     )
     step_model_regis = ModelStep(
         name="pytorch-register-model",
diff --git a/tests/unit/sagemaker/workflow/conftest.py b/tests/unit/sagemaker/workflow/conftest.py
new file mode 100644
index 0000000000..9ea3d0bcac
--- /dev/null
+++ b/tests/unit/sagemaker/workflow/conftest.py
@@ -0,0 +1,75 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License.
A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +from unittest.mock import Mock, PropertyMock + +import pytest + +from sagemaker import Session +from sagemaker.workflow.pipeline_context import PipelineSession + +REGION = "us-west-2" +BUCKET = "my-bucket" +ROLE = "DummyRole" +IMAGE_URI = "fakeimage" + + +@pytest.fixture(scope="module") +def client(): + """Mock client. + + Considerations when appropriate: + + * utilize botocore.stub.Stubber + * separate runtime client from client + """ + client_mock = Mock() + client_mock._client_config.user_agent = ( + "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" + ) + return client_mock + + +@pytest.fixture(scope="module") +def boto_session(client): + role_mock = Mock() + type(role_mock).arn = PropertyMock(return_value=ROLE) + + resource_mock = Mock() + resource_mock.Role.return_value = role_mock + + session_mock = Mock(region_name=REGION) + session_mock.resource.return_value = resource_mock + session_mock.client.return_value = client + + return session_mock + + +@pytest.fixture(scope="module") +def pipeline_session(boto_session, client): + return PipelineSession( + boto_session=boto_session, + sagemaker_client=client, + default_bucket=BUCKET, + ) + + +@pytest.fixture(scope="module") +def sagemaker_session(boto_session, client): + return Session( + boto_session=boto_session, + sagemaker_client=client, + sagemaker_runtime_client=client, + default_bucket=BUCKET, + ) diff --git a/tests/unit/sagemaker/workflow/test_model_step.py b/tests/unit/sagemaker/workflow/test_model_step.py index 080e70ca62..2216299d3b 100644 --- a/tests/unit/sagemaker/workflow/test_model_step.py +++ b/tests/unit/sagemaker/workflow/test_model_step.py @@ -15,7 +15,7 @@ import json import os -from mock import Mock, PropertyMock, patch +from mock import patch import pytest @@ -43,7 +43,6 @@ ) from sagemaker.workflow.parameters import ParameterString, ParameterInteger from sagemaker.workflow.pipeline import Pipeline, PipelineGraph -from sagemaker.workflow.pipeline_context import PipelineSession from sagemaker.workflow.retry import ( StepRetryPolicy, StepExceptionTypeEnum, @@ -55,11 +54,9 @@ from sagemaker.workflow.lambda_step import LambdaStep, LambdaOutput, LambdaOutputTypeEnum from tests.unit import DATA_DIR from tests.unit.sagemaker.workflow.helpers import CustomStep, ordered +from tests.unit.sagemaker.workflow.conftest import BUCKET, ROLE _IMAGE_URI = "fakeimage" -_REGION = "us-west-2" -_BUCKET = "my-bucket" -_ROLE = "DummyRole" _INSTANCE_TYPE = "ml.m4.xlarge" _SAGEMAKER_PROGRAM = SCRIPT_PARAM_NAME.upper() @@ -69,60 +66,10 @@ _XGBOOST_PATH = os.path.join(DATA_DIR, "xgboost_abalone") _TENSORFLOW_PATH = os.path.join(DATA_DIR, "tfs/tfs-test-entrypoint-and-dependencies") _REPACK_OUTPUT_KEY_PREFIX = "code-output" -_MODEL_CODE_LOCATION = f"s3://{_BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}" +_MODEL_CODE_LOCATION = f"s3://{BUCKET}/{_REPACK_OUTPUT_KEY_PREFIX}" _MODEL_CODE_LOCATION_TRAILING_SLASH = _MODEL_CODE_LOCATION + "/" -@pytest.fixture -def client(): - """Mock client. 
- - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def boto_session(client): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=_ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=_REGION) - session_mock.resource.return_value = resource_mock - session_mock.client.return_value = client - - return session_mock - - -@pytest.fixture -def pipeline_session(boto_session, client): - return PipelineSession( - boto_session=boto_session, - sagemaker_client=client, - default_bucket=_BUCKET, - ) - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=_BUCKET, - ) - - @pytest.fixture def model_data_param(): return ParameterString(name="ModelData", default_value="s3://my-bucket/file") @@ -137,7 +84,7 @@ def model(pipeline_session, model_data_param): sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", source_dir=f"{DATA_DIR}", - role=_ROLE, + role=ROLE, ) @@ -322,13 +269,13 @@ def test_create_pipeline_model_with_runtime_repack(pipeline_session, model_data_ sparkml_model = SparkMLModel( name="MySparkMLModel", model_data=model_data_param, - role=_ROLE, + role=ROLE, sagemaker_session=pipeline_session, env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"}, ) # The model need to runtime repack ppl_model = PipelineModel( - models=[sparkml_model, model], role=_ROLE, sagemaker_session=pipeline_session + models=[sparkml_model, model], role=ROLE, sagemaker_session=pipeline_session ) step_args = ppl_model.create( instance_type="c4.4xlarge", @@ -417,7 +364,7 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat # The model no need to runtime repack, since source_dir is missing sparkml_model = SparkMLModel( model_data=model_data_param, - role=_ROLE, + role=ROLE, sagemaker_session=pipeline_session, env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"}, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", @@ -429,11 +376,11 @@ def test_register_pipeline_model_with_runtime_repack(pipeline_session, model_dat sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", source_dir=f"{DATA_DIR}", - role=_ROLE, + role=ROLE, env={"k": "v"}, ) model = PipelineModel( - models=[sparkml_model, model], role=_ROLE, sagemaker_session=pipeline_session + models=[sparkml_model, model], role=ROLE, sagemaker_session=pipeline_session ) step_args = model.register( content_types=["text/csv"], @@ -516,7 +463,7 @@ def test_register_model_without_repack(pipeline_session): model_data=model_data, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", sagemaker_session=pipeline_session, - role=_ROLE, + role=ROLE, ) step_args = model.register( content_types=["text/csv"], @@ -547,7 +494,7 @@ def test_register_model_without_repack(pipeline_session): assert containers[0]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME assert ( containers[0]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] - == f"s3://{_BUCKET}/{model_name}/sourcedir.tar.gz" + == f"s3://{BUCKET}/{model_name}/sourcedir.tar.gz" ) adjacency_list = PipelineGraph.from_pipeline(pipeline).adjacency_list assert ordered(adjacency_list) == 
ordered({"MyModelStep-RegisterModel": []}) @@ -560,11 +507,11 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session): model = Model( name=model_name, image_uri=_IMAGE_URI, - model_data=f"s3://{_BUCKET}/model.tar.gz", + model_data=f"s3://{BUCKET}/model.tar.gz", sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", source_dir=f"{DATA_DIR}", - role=_ROLE, + role=ROLE, ) step_args = model.create( instance_type="c4.4xlarge", @@ -582,7 +529,7 @@ def test_create_model_with_compile_time_repack(mock_repack, pipeline_session): arguments = step_dsl_list[0]["Arguments"] assert arguments["PrimaryContainer"]["Image"] == _IMAGE_URI assert ( - arguments["PrimaryContainer"]["ModelDataUrl"] == f"s3://{_BUCKET}/{model_name}/model.tar.gz" + arguments["PrimaryContainer"]["ModelDataUrl"] == f"s3://{BUCKET}/{model_name}/model.tar.gz" ) assert arguments["PrimaryContainer"]["Environment"][_SAGEMAKER_PROGRAM] == _SCRIPT_NAME assert arguments["PrimaryContainer"]["Environment"][_SAGEMAKER_SUBMIT_DIRECTORY] == _DIR_NAME @@ -700,7 +647,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, enable_network_isolation=True, code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH, ), @@ -713,7 +660,7 @@ def test_conditional_model_create_and_regis( framework_version="1.11.0", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, enable_network_isolation=False, ), 1, @@ -724,7 +671,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, framework_version="1.5.0", code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH, ), @@ -736,7 +683,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, framework_version="1.2.0", ), 1, @@ -747,7 +694,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, ), 2, ), @@ -757,7 +704,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, code_location=_MODEL_CODE_LOCATION_TRAILING_SLASH, ), 2, @@ -768,7 +715,7 @@ def test_conditional_model_create_and_regis( model_data="dummy_model_data", image_uri=_IMAGE_URI, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, ), 1, ), @@ -789,7 +736,7 @@ def assert_test_result(steps: list): ) else: assert steps[0]["Arguments"]["OutputDataConfig"]["S3OutputPath"] == ( - f"s3://{_BUCKET}/{model.name}" + f"s3://{BUCKET}/{model.name}" ) model, expected_step_num = test_input @@ -828,7 +775,7 @@ def assert_test_result(steps: list): XGBoostModel( model_data="dummy_model_step", framework_version="1.3-1", - role=_ROLE, + role=ROLE, entry_point=os.path.join(_XGBOOST_PATH, "inference.py"), enable_network_isolation=True, ), @@ -845,7 +792,7 @@ def assert_test_result(steps: list): XGBoostModel( model_data="dummy_model_step", framework_version="1.3-1", - role=_ROLE, + role=ROLE, entry_point=os.path.join(_XGBOOST_PATH, "inference.py"), ), { @@ -861,7 +808,7 @@ def assert_test_result(steps: list): XGBoostModel( model_data="dummy_model_step", framework_version="1.3-1", - role=_ROLE, + role=ROLE, entry_point=None, 
), { @@ -876,9 +823,8 @@ def assert_test_result(steps: list): ( TensorFlowModel( model_data="dummy_model_step", - role=_ROLE, + role=ROLE, image_uri=_IMAGE_URI, - sagemaker_session=pipeline_session, entry_point=os.path.join(_TENSORFLOW_PATH, "inference.py"), ), { @@ -893,9 +839,8 @@ def assert_test_result(steps: list): ( TensorFlowModel( model_data="dummy_model_step", - role=_ROLE, + role=ROLE, image_uri=_IMAGE_URI, - sagemaker_session=pipeline_session, ), { "expected_step_num": 1, @@ -941,7 +886,7 @@ def test_request_compare_of_register_model_under_different_sessions( _verify_register_model_container_definition(regis_step_arg, expect, dict) # Get create model package request under Session - model.model_data = f"s3://{_BUCKET}" + model.model_data = f"s3://{BUCKET}" model.sagemaker_session = sagemaker_session with patch.object( Session, "_intercept_create_request", return_value=dict(ModelPackageArn="arn:aws") @@ -996,7 +941,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): model_data=lambda_step.properties.Outputs["model_artifact"], sagemaker_session=pipeline_session, entry_point=f"{DATA_DIR}/{_SCRIPT_NAME}", - role=_ROLE, + role=ROLE, ) step_create_model = ModelStep(name="mymodelstep", step_args=model.create()) @@ -1031,7 +976,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): ( Processor( image_uri=_IMAGE_URI, - role=_ROLE, + role=ROLE, instance_count=1, instance_type=_INSTANCE_TYPE, ), @@ -1052,7 +997,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): ( HyperparameterTuner( estimator=Estimator( - role=_ROLE, + role=ROLE, instance_count=1, instance_type=_INSTANCE_TYPE, image_uri=_IMAGE_URI, @@ -1064,7 +1009,7 @@ def test_model_step_with_lambda_property_reference(pipeline_session): ), ( Estimator( - role=_ROLE, + role=ROLE, instance_count=1, instance_type=_INSTANCE_TYPE, image_uri=_IMAGE_URI, @@ -1128,3 +1073,31 @@ def test_pass_in_wrong_type_of_retry_policies(pipeline_session, model): ), ) assert "SageMakerJobStepRetryPolicy is not allowed for a create/registe" in str(error.value) + + +def test_register_model_step_with_model_package_name(pipeline_session): + model = Model( + name="MyModel", + image_uri="my-image", + model_data="s3://", + sagemaker_session=pipeline_session, + ) + step_args = model.register( + content_types=["text/csv"], + response_types=["text/csv"], + inference_instances=["ml.t2.medium", "ml.m5.xlarge"], + transform_instances=["ml.m5.xlarge"], + model_package_name="model-pkg-name-will-be-popped-out", + ) + regis_model_step = ModelStep( + name="MyModelStep", + step_args=step_args, + ) + pipeline = Pipeline( + name="MyPipeline", + steps=[regis_model_step], + sagemaker_session=pipeline_session, + ) + steps = json.loads(pipeline.definition())["Steps"] + assert len(steps) == 1 + assert "ModelPackageName" not in steps[0]["Arguments"] diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index dcbf5a6421..c8d86c5866 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -18,12 +18,6 @@ import tempfile import pytest -import sagemaker - -from mock import ( - Mock, - PropertyMock, -) from sagemaker.estimator import Estimator from sagemaker.workflow._utils import ( @@ -35,51 +29,7 @@ from sagemaker.workflow.properties import Properties from tests.unit.test_utils import FakeS3, list_tar_files from tests.unit import DATA_DIR - -REGION = "us-west-2" -BUCKET = "my-bucket" -IMAGE_URI = "fakeimage" -ROLE = 
"DummyRole" - - -@pytest.fixture -def boto_session(): - role_mock = Mock() - type(role_mock).arn = PropertyMock(return_value=ROLE) - - resource_mock = Mock() - resource_mock.Role.return_value = role_mock - - session_mock = Mock(region_name=REGION) - session_mock.resource.return_value = resource_mock - - return session_mock - - -@pytest.fixture -def client(): - """Mock client. - - Considerations when appropriate: - - * utilize botocore.stub.Stubber - * separate runtime client from client - """ - client_mock = Mock() - client_mock._client_config.user_agent = ( - "Boto3/1.14.24 Python/3.8.5 Linux/5.4.0-42-generic Botocore/1.17.24 Resource" - ) - return client_mock - - -@pytest.fixture -def sagemaker_session(boto_session, client): - return sagemaker.session.Session( - boto_session=boto_session, - sagemaker_client=client, - sagemaker_runtime_client=client, - default_bucket=BUCKET, - ) +from tests.unit.sagemaker.workflow.conftest import ROLE, IMAGE_URI, BUCKET @pytest.fixture @@ -171,7 +121,7 @@ def test_repack_model_step(estimator): } -def test_repack_model_step_with_invalid_input(): +def test_register_model_step_with_invalid_input(): # without both step_args and any of the old required arguments with pytest.raises(ValueError) as error: _RegisterModelStep( From 959ea1a485db702f361ddebda2e80779bfd20e43 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 9 Dec 2022 06:20:46 +0000 Subject: [PATCH 30/43] prepare release v2.121.1 --- CHANGELOG.md | 7 +++++++ VERSION | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29dad5f19f..472a25feb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## v2.121.1 (2022-12-09) + +### Bug Fixes and Other Changes + + * Pop out ModelPackageName from pipeline definition + * Fix failing jumpstart cache unit tests + ## v2.121.0 (2022-12-08) ### Features diff --git a/VERSION b/VERSION index 28b52ee8d5..f73c7f057e 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.1.dev0 +2.121.1 From b2e8b66016c09a3898123725bf1c01d1a87b05d0 Mon Sep 17 00:00:00 2001 From: ci Date: Fri, 9 Dec 2022 06:20:47 +0000 Subject: [PATCH 31/43] update development version to v2.121.2.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index f73c7f057e..d866b235cc 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.1 +2.121.2.dev0 From 355975d4d2d45088eeb13681f8d99e48a00909c9 Mon Sep 17 00:00:00 2001 From: amzn-choeric <105388439+amzn-choeric@users.noreply.github.com> Date: Fri, 9 Dec 2022 13:53:28 -0500 Subject: [PATCH 32/43] fix: Skip Bad Transform Test (#3521) --- tests/integ/test_inference_pipeline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/integ/test_inference_pipeline.py b/tests/integ/test_inference_pipeline.py index 53d966fe9b..a26d8c9101 100644 --- a/tests/integ/test_inference_pipeline.py +++ b/tests/integ/test_inference_pipeline.py @@ -50,6 +50,7 @@ ) +@pytest.mark.skip(reason="Test has likely been failing for a while. 
Suspected bad XGB model.") def test_inference_pipeline_batch_transform(sagemaker_session, cpu_instance_type): sparkml_model_data = sagemaker_session.upload_data( path=os.path.join(SPARKML_DATA_PATH, "mleap_model.tar.gz"), From fadc817c7557f5fea5e414d51b500a6b7cd02065 Mon Sep 17 00:00:00 2001 From: Mufaddal Rohawala <89424143+mufaddal-rohawala@users.noreply.github.com> Date: Fri, 9 Dec 2022 12:07:32 -0800 Subject: [PATCH 33/43] fix: Revert "fix: type hint of PySparkProcessor __init__" (#3524) From c5fc93feea798df1713db6707737a2f24738c4c7 Mon Sep 17 00:00:00 2001 From: hballuru <113142824+hballuru@users.noreply.github.com> Date: Fri, 9 Dec 2022 16:36:12 -0600 Subject: [PATCH 34/43] change: Update for Tensorflow Serving 2.11 inference DLCs (#3509) --- .../image_uri_config/tensorflow.json | 43 ++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index a0f2bba014..aaca927ba4 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -303,7 +303,8 @@ "2.7": "2.7.0", "2.8": "2.8.0", "2.9": "2.9.2", - "2.10": "2.10.0" + "2.10": "2.10.0", + "2.11": "2.11.0" }, "versions": { "1.10.0": { @@ -1611,6 +1612,7 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1618,8 +1620,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", @@ -1642,6 +1646,41 @@ "ap-northeast-2": "763104351884", "ap-northeast-3": "364406365360", "ap-south-1": "763104351884", + "ap-south-2": "772153158452", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-central-2": "380420809688", + "eu-north-1": "763104351884", + "eu-south-1": "692866216735", + "eu-south-2": "503227376785", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-gov-west-1": "442386744353", + "us-iso-east-1": "886529160074", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "tensorflow-inference" + }, + "2.11.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-south-2": "772153158452", "ap-southeast-1": "763104351884", "ap-southeast-2": "763104351884", "ap-southeast-3": "907027046896", @@ -1649,8 +1688,10 @@ "cn-north-1": "727897471807", "cn-northwest-1": "727897471807", "eu-central-1": "763104351884", + "eu-central-2": "380420809688", "eu-north-1": "763104351884", "eu-south-1": "692866216735", + "eu-south-2": "503227376785", "eu-west-1": "763104351884", "eu-west-2": "763104351884", "eu-west-3": "763104351884", From ec8da98a9a7cae848e8bf1af06bdaaabd1ebb382 Mon Sep 17 00:00:00 
2001 From: ci Date: Mon, 12 Dec 2022 18:18:58 +0000 Subject: [PATCH 35/43] prepare release v2.121.2 --- CHANGELOG.md | 8 ++++++++ VERSION | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 472a25feb8..8b66e85f54 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## v2.121.2 (2022-12-12) + +### Bug Fixes and Other Changes + + * Update for Tensorflow Serving 2.11 inference DLCs + * Revert "fix: type hint of PySparkProcessor __init__" + * Skip Bad Transform Test + ## v2.121.1 (2022-12-09) ### Bug Fixes and Other Changes diff --git a/VERSION b/VERSION index d866b235cc..3b02379cd3 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.2.dev0 +2.121.2 From 03521222d324ed752174038309828ed8183c5aea Mon Sep 17 00:00:00 2001 From: ci Date: Mon, 12 Dec 2022 18:19:00 +0000 Subject: [PATCH 36/43] update development version to v2.121.3.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 3b02379cd3..8fde5e282f 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.2 +2.121.3.dev0 From d6c021404586d4df601a6115add87fcbf75b6d65 Mon Sep 17 00:00:00 2001 From: Kristopher Siman Date: Mon, 12 Dec 2022 17:21:49 -0500 Subject: [PATCH 37/43] feature: Add OSU region to frameworks for DLC (#3532) --- src/sagemaker/image_uri_config/autogluon.json | 12 ++++ .../image_uri_config/huggingface-neuron.json | 1 + .../image_uri_config/huggingface.json | 31 ++++++++ src/sagemaker/image_uri_config/mxnet.json | 13 ++++ .../image_uri_config/pytorch-neuron.json | 1 + src/sagemaker/image_uri_config/pytorch.json | 31 ++++++++ .../image_uri_config/tensorflow.json | 70 +++++++++++++++++++ 7 files changed, 159 insertions(+) diff --git a/src/sagemaker/image_uri_config/autogluon.json b/src/sagemaker/image_uri_config/autogluon.json index 3a9f02142c..590b6e5f82 100644 --- a/src/sagemaker/image_uri_config/autogluon.json +++ b/src/sagemaker/image_uri_config/autogluon.json @@ -30,6 +30,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -61,6 +62,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -92,6 +94,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -123,6 +126,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -154,6 +158,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -185,6 +190,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -230,6 +236,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", 
"us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -267,6 +274,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -304,6 +312,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -341,6 +350,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -378,6 +388,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -415,6 +426,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface-neuron.json b/src/sagemaker/image_uri_config/huggingface-neuron.json index 47d6dbd1dc..980dceed17 100644 --- a/src/sagemaker/image_uri_config/huggingface-neuron.json +++ b/src/sagemaker/image_uri_config/huggingface-neuron.json @@ -33,6 +33,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/huggingface.json b/src/sagemaker/image_uri_config/huggingface.json index 5b98fc0d02..a0caa59a55 100644 --- a/src/sagemaker/image_uri_config/huggingface.json +++ b/src/sagemaker/image_uri_config/huggingface.json @@ -42,6 +42,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -75,6 +76,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -114,6 +116,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -147,6 +150,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -188,6 +192,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -222,6 +227,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -256,6 +262,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", 
"us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -290,6 +297,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -332,6 +340,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -366,6 +375,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -400,6 +410,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -434,6 +445,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -474,6 +486,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -508,6 +521,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -548,6 +562,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -582,6 +597,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -622,6 +638,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -656,6 +673,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -712,6 +730,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -749,6 +768,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -786,6 +806,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -831,6 +852,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", 
"us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -868,6 +890,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -905,6 +928,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -942,6 +966,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -985,6 +1010,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1022,6 +1048,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1065,6 +1092,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1102,6 +1130,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1145,6 +1174,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1182,6 +1212,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/mxnet.json b/src/sagemaker/image_uri_config/mxnet.json index 8d8733e480..588a03a76e 100644 --- a/src/sagemaker/image_uri_config/mxnet.json +++ b/src/sagemaker/image_uri_config/mxnet.json @@ -249,6 +249,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -282,6 +283,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -315,6 +317,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -348,6 +351,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -381,6 +385,7 @@ "sa-east-1": "763104351884", "us-east-1": 
"763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -644,6 +649,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -680,6 +686,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -716,6 +723,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -752,6 +760,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -788,6 +797,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -897,6 +907,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -933,6 +944,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -969,6 +981,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch-neuron.json b/src/sagemaker/image_uri_config/pytorch-neuron.json index b116a8a36b..5b29406955 100644 --- a/src/sagemaker/image_uri_config/pytorch-neuron.json +++ b/src/sagemaker/image_uri_config/pytorch-neuron.json @@ -28,6 +28,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/pytorch.json b/src/sagemaker/image_uri_config/pytorch.json index 18a382e591..85681a3423 100644 --- a/src/sagemaker/image_uri_config/pytorch.json +++ b/src/sagemaker/image_uri_config/pytorch.json @@ -208,6 +208,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -247,6 +248,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -285,6 +287,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -323,6 +326,7 @@ 
"sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -362,6 +366,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -401,6 +406,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -440,6 +446,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -479,6 +486,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -517,6 +525,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -555,6 +564,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -593,6 +603,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -631,6 +642,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -669,6 +681,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -707,6 +720,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -744,6 +758,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -791,6 +806,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -951,6 +967,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -987,6 +1004,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1023,6 +1041,7 
@@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1058,6 +1077,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1094,6 +1114,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1130,6 +1151,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1166,6 +1188,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1202,6 +1225,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1237,6 +1261,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1272,6 +1297,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1307,6 +1333,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1342,6 +1369,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1377,6 +1405,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1412,6 +1441,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1446,6 +1476,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", diff --git a/src/sagemaker/image_uri_config/tensorflow.json b/src/sagemaker/image_uri_config/tensorflow.json index aaca927ba4..a900aa4fe5 100644 --- a/src/sagemaker/image_uri_config/tensorflow.json +++ b/src/sagemaker/image_uri_config/tensorflow.json @@ -161,6 +161,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", 
"us-west-1": "763104351884", @@ -196,6 +197,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -231,6 +233,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -266,6 +269,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -425,6 +429,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -460,6 +465,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -495,6 +501,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -530,6 +537,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -565,6 +573,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -600,6 +609,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -635,6 +645,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -862,6 +873,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -897,6 +909,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -932,6 +945,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -967,6 +981,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1002,6 +1017,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": 
"886529160074", "us-west-1": "763104351884", @@ -1037,6 +1053,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1072,6 +1089,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1107,6 +1125,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1142,6 +1161,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1177,6 +1197,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1212,6 +1233,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1247,6 +1269,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1282,6 +1305,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1317,6 +1341,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1352,6 +1377,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1387,6 +1413,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1422,6 +1449,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1457,6 +1485,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1492,6 +1521,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1527,6 +1557,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": 
"442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1562,6 +1593,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1597,6 +1629,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1631,6 +1664,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1665,6 +1699,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1699,6 +1734,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1746,6 +1782,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1940,6 +1977,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -1977,6 +2015,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2013,6 +2052,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2050,6 +2090,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2087,6 +2128,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2124,6 +2166,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2161,6 +2204,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2389,6 +2433,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2425,6 +2470,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": 
"446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2461,6 +2507,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2496,6 +2543,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2531,6 +2579,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2567,6 +2616,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2603,6 +2653,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2638,6 +2689,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2673,6 +2725,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2708,6 +2761,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2743,6 +2797,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2778,6 +2833,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2813,6 +2869,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2848,6 +2905,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2883,6 +2941,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2918,6 +2977,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2953,6 +3013,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": 
"763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -2988,6 +3049,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3023,6 +3085,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3058,6 +3121,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3093,6 +3157,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3128,6 +3193,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3163,6 +3229,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3198,6 +3265,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3233,6 +3301,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", @@ -3267,6 +3336,7 @@ "sa-east-1": "763104351884", "us-east-1": "763104351884", "us-east-2": "763104351884", + "us-gov-east-1": "446045086412", "us-gov-west-1": "442386744353", "us-iso-east-1": "886529160074", "us-west-1": "763104351884", From 5af4feb57d950358dcf5dd15aad7f7d59ae11b31 Mon Sep 17 00:00:00 2001 From: Xiaoguang Chen <68292680+xgchena@users.noreply.github.com> Date: Mon, 12 Dec 2022 15:59:33 -0800 Subject: [PATCH 38/43] fix: Remove content type image/jpg from analysis configuration schema (#3530) Currently the analysis configuration schema of SageMaker Clarify API allows the content_type configuration "image/jpeg" and "image/jpg", but the service side validation only accepts the former which is the registered MIME type for JPEG (see rfc3745 and JPEG specification). The commit removes the latter from the schema to avoid confusion and enable early API validation. 
--- src/sagemaker/clarify.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker/clarify.py b/src/sagemaker/clarify.py index 4765630ce8..f082679401 100644 --- a/src/sagemaker/clarify.py +++ b/src/sagemaker/clarify.py @@ -282,7 +282,6 @@ "text/csv", "application/jsonlines", "image/jpeg", - "image/jpg", "image/png", "application/x-npy", ), From 438984754a8f44b34d70154197a3bbeb0272f052 Mon Sep 17 00:00:00 2001 From: Clayton Parnell <42805768+claytonparnell@users.noreply.github.com> Date: Mon, 12 Dec 2022 22:37:35 -0500 Subject: [PATCH 39/43] fix: unpin packaging version (#3533) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index f366b147b8..4327045760 100644 --- a/setup.py +++ b/setup.py @@ -55,7 +55,7 @@ def read_requirements(filename): "protobuf3-to-dict>=0.1.5,<1.0", "smdebug_rulesconfig==1.0.1", "importlib-metadata>=1.4.0,<5.0", - "packaging==20.9", + "packaging>=20.0", "pandas", "pathos", "schema", From a3efddf6d6a4e89861f2ae1eca9d7fd7712a691b Mon Sep 17 00:00:00 2001 From: Anton Repushko Date: Tue, 13 Dec 2022 20:45:06 +0100 Subject: [PATCH 40/43] fix: the Hyperband support fix for the HPO (#3516) Co-authored-by: Anton Repushko --- src/sagemaker/session.py | 9 +++++++ src/sagemaker/tuner.py | 14 +++++------ tests/unit/test_session.py | 48 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index 00797c9ea0..3fc4fc1256 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -2121,6 +2121,7 @@ def tune( # noqa: C901 stop_condition, tags, warm_start_config, + strategy_config=None, enable_network_isolation=False, image_uri=None, algorithm_arn=None, @@ -2136,6 +2137,8 @@ def tune( # noqa: C901 Args: job_name (str): Name of the tuning job being created. strategy (str): Strategy to be used for hyperparameter estimations. + strategy_config (dict): A configuration for the hyperparameter tuning + job optimisation strategy. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize'. objective_metric_name (str): Name of the metric for evaluating training jobs. @@ -2220,6 +2223,7 @@ def tune( # noqa: C901 objective_metric_name=objective_metric_name, parameter_ranges=parameter_ranges, early_stopping_type=early_stopping_type, + strategy_config=strategy_config, ), "TrainingJobDefinition": self._map_training_config( static_hyperparameters=static_hyperparameters, @@ -2375,6 +2379,7 @@ def _map_tuning_config( objective_type=None, objective_metric_name=None, parameter_ranges=None, + strategy_config=None, ): """Construct tuning job configuration dictionary. @@ -2392,6 +2397,8 @@ def _map_tuning_config( objective_metric_name (str): Name of the metric for evaluating training jobs. parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be one of three types: Continuous, Integer, or Categorical. + strategy_config (dict): A configuration for the hyperparameter tuning job optimisation + strategy. Returns: A dictionary of tuning job configuration. 
For format details, please refer to @@ -2415,6 +2422,8 @@ def _map_tuning_config( if parameter_ranges is not None: tuning_config["ParameterRanges"] = parameter_ranges + if strategy_config is not None: + tuning_config["StrategyConfig"] = strategy_config return tuning_config @classmethod diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index 52b9d81d0d..9a694cbec9 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -282,8 +282,8 @@ def from_job_desc(cls, hyperband_strategy_config): Returns: sagemaker.tuner.HyperbandStrategyConfig: De-serialized instance of - HyperbandStrategyConfig containing the max_resource and min_resource provided as part of - ``hyperband_strategy_config``. + ``HyperbandStrategyConfig`` containing the max_resource + and min_resource provided as part of ``hyperband_strategy_config``. """ return cls( min_resource=hyperband_strategy_config[HYPERBAND_MIN_RESOURCE], @@ -306,7 +306,7 @@ def to_input_req(self): Returns: dict: Containing the "MaxResource" and - "MinResource" as the first class fields. + "MinResource" as the first class fields. """ return { HYPERBAND_MIN_RESOURCE: self.min_resource, @@ -330,7 +330,7 @@ def __init__( Args: hyperband_strategy_config (sagemaker.tuner.HyperbandStrategyConfig): The configuration - for the object that specifies the Hyperband strategy. + for the object that specifies the Hyperband strategy. This parameter is only supported for the Hyperband selection for Strategy within the HyperParameterTuningJobConfig. """ @@ -461,7 +461,7 @@ def __init__( ``WarmStartConfig`` object that has been initialized with the configuration defining the nature of warm start tuning job. strategy_config (sagemaker.tuner.StrategyConfig): A configuration for "Hyperparameter" - tuning job optimisation strategy. + tuning job optimisation strategy. early_stopping_type (str or PipelineVariable): Specifies whether early stopping is enabled for the job. Can be either 'Auto' or 'Off' (default: 'Off'). If set to 'Off', early stopping will not be attempted. @@ -1569,7 +1569,7 @@ def create( strategy (str): Strategy to be used for hyperparameter estimations (default: 'Bayesian'). strategy_config (dict): The configuration for a training job launched by a - hyperparameter tuning job. + hyperparameter tuning job. objective_type (str): The type of the objective metric for evaluating training jobs. This value can be either 'Minimize' or 'Maximize' (default: 'Maximize'). 
            max_jobs (int): Maximum total number of training jobs to start for the hyperparameter
@@ -1776,7 +1776,7 @@ _get_tuner_args(cls, tuner, inputs):
         }
 
         if tuner.strategy_config is not None:
-            tuning_config["strategy_config"] = tuner.strategy_config
+            tuning_config["strategy_config"] = tuner.strategy_config.to_input_req()
 
         if tuner.objective_metric_name is not None:
             tuning_config["objective_type"] = tuner.objective_type
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 8958210092..bf81283177 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -941,6 +941,13 @@ def test_train_pack_to_request(sagemaker_session):
     ],
 }
 
+SAMPLE_HYPERBAND_STRATEGY_CONFIG = {
+    "HyperbandStrategyConfig": {
+        "MinResource": 1,
+        "MaxResource": 10,
+    }
+}
+
 
 @pytest.mark.parametrize(
     "warm_start_type, parents",
@@ -1167,6 +1174,47 @@ def assert_create_tuning_job_request(**kwrags):
     )
 
 
+def test_tune_with_strategy_config(sagemaker_session):
+    def assert_create_tuning_job_request(**kwargs):
+        assert (
+            kwargs["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][
+                "MinResource"
+            ]
+            == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MinResource"]
+        )
+        assert (
+            kwargs["HyperParameterTuningJobConfig"]["StrategyConfig"]["HyperbandStrategyConfig"][
+                "MaxResource"
+            ]
+            == SAMPLE_HYPERBAND_STRATEGY_CONFIG["HyperbandStrategyConfig"]["MaxResource"]
+        )
+
+    sagemaker_session.sagemaker_client.create_hyper_parameter_tuning_job.side_effect = (
+        assert_create_tuning_job_request
+    )
+    sagemaker_session.tune(
+        job_name="dummy-tuning-1",
+        strategy="Bayesian",
+        objective_type="Maximize",
+        objective_metric_name="val-score",
+        max_jobs=100,
+        max_parallel_jobs=5,
+        parameter_ranges=SAMPLE_PARAM_RANGES,
+        static_hyperparameters=STATIC_HPs,
+        image_uri="dummy-image-1",
+        input_mode="File",
+        metric_definitions=SAMPLE_METRIC_DEF,
+        role=EXPANDED_ROLE,
+        input_config=SAMPLE_INPUT,
+        output_config=SAMPLE_OUTPUT,
+        resource_config=RESOURCE_CONFIG,
+        stop_condition=SAMPLE_STOPPING_CONDITION,
+        tags=None,
+        warm_start_config=None,
+        strategy_config=SAMPLE_HYPERBAND_STRATEGY_CONFIG,
+    )
+
+
 def test_tune_with_encryption_flag(sagemaker_session):
     def assert_create_tuning_job_request(**kwrags):
         assert (
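For context, a minimal sketch of how the Hyperband configuration serialized by ``to_input_req()`` above is expected to be supplied through the tuner; the estimator, metric name, and parameter range below are placeholders, not part of this patch:

    from sagemaker.tuner import (
        ContinuousParameter,
        HyperbandStrategyConfig,
        HyperparameterTuner,
        StrategyConfig,
    )

    # `estimator` stands in for any fully configured sagemaker.estimator.Estimator.
    tuner = HyperparameterTuner(
        estimator=estimator,
        objective_metric_name="val-score",
        hyperparameter_ranges={"learning_rate": ContinuousParameter(0.01, 0.2)},
        strategy="Hyperband",
        strategy_config=StrategyConfig(
            hyperband_strategy_config=HyperbandStrategyConfig(min_resource=1, max_resource=10)
        ),
    )
    # _get_tuner_args now serializes the object via to_input_req(), sending
    # {"HyperbandStrategyConfig": {"MinResource": 1, "MaxResource": 10}} to the API.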
From bd96ec5c585217bdec31951d632247f4b0d9f91b Mon Sep 17 00:00:00 2001
From: Md Mizanur Rahman <105268921+mizanfiu@users.noreply.github.com>
Date: Tue, 13 Dec 2022 16:06:08 -0800
Subject: [PATCH 41/43] feature: Feature Store dataset builder, delete_record,
 get_record, list_feature_group (#3534)

Co-authored-by: Eric Zou
Co-authored-by: Yiming Zou
Co-authored-by: Brandon Chatham
Co-authored-by: jiapinw <95885824+jiapinw@users.noreply.github.com>
---
 .../feature_store/dataset_builder.py          | 990 ++++++++++++++++++
 src/sagemaker/feature_store/feature_group.py  |  45 +-
 src/sagemaker/feature_store/feature_store.py  | 130 +++
 src/sagemaker/session.py                      |  94 +-
 tests/integ/test_feature_store.py             | 400 +++++++
 .../feature_store/test_dataset_builder.py     | 612 +++++++++++
 .../feature_store/test_feature_group.py       | 580 ++++++++++
 .../feature_store/test_feature_store.py       | 687 ++----------
 tests/unit/test_session.py                    |  29 +
 9 files changed, 2979 insertions(+), 588 deletions(-)
 create mode 100644 src/sagemaker/feature_store/dataset_builder.py
 create mode 100644 src/sagemaker/feature_store/feature_store.py
 create mode 100644 tests/unit/sagemaker/feature_store/test_dataset_builder.py
 create mode 100644 tests/unit/sagemaker/feature_store/test_feature_group.py

diff --git a/src/sagemaker/feature_store/dataset_builder.py b/src/sagemaker/feature_store/dataset_builder.py
new file mode 100644
index 0000000000..fc82997379
--- /dev/null
+++ b/src/sagemaker/feature_store/dataset_builder.py
@@ -0,0 +1,990 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Dataset Builder
+
+A Dataset Builder is a builder class for generating a dataset by providing conditions.
+"""
+from __future__ import absolute_import
+
+import datetime
+from enum import Enum
+import os
+from typing import Any, Dict, List, Tuple, Union
+
+import attr
+import pandas as pd
+
+from sagemaker import Session, s3, utils
+from sagemaker.feature_store.feature_group import FeatureDefinition, FeatureGroup, FeatureTypeEnum
+
+
+_DEFAULT_CATALOG = "AwsDataCatalog"
+_DEFAULT_DATABASE = "sagemaker_featurestore"
+
+
+@attr.s
+class TableType(Enum):
+    """Enum of Table types.
+
+    The data type of a table can be FeatureGroup or DataFrame.
+    """
+
+    FEATURE_GROUP = "FeatureGroup"
+    DATA_FRAME = "DataFrame"
+
+
+@attr.s
+class FeatureGroupToBeMerged:
+    """FeatureGroup metadata which will be used for SQL join.
+
+    This class instantiates a FeatureGroupToBeMerged object that comprises a list of feature names,
+    a list of feature names which will be included in SQL query, a database, an Athena table name,
+    a feature name of record identifier, a feature name of event time identifier and a feature name
+    of base which is the target join key.
+
+    Attributes:
+        features (List[str]): A list of strings representing feature names of this FeatureGroup.
+        included_feature_names (List[str]): A list of strings representing features to be
+            included in the sql join.
+        projected_feature_names (List[str]): A list of strings representing features to be
+            included for final projection in output.
+        catalog (str): A string representing the catalog.
+        database (str): A string representing the database.
+        table_name (str): A string representing the Athena table name of this FeatureGroup.
+        record_identifier_feature_name (str): A string representing the record identifier feature.
+        event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the
+            event time identifier feature.
+        target_feature_name_in_base (str): A string representing the feature name in base which will
+            be used as target join key (default: None).
+        table_type (TableType): A TableType representing the type of table if it is Feature Group or
+            pandas DataFrame (default: None).
+ """ + + features: List[str] = attr.ib() + included_feature_names: List[str] = attr.ib() + projected_feature_names: List[str] = attr.ib() + catalog: str = attr.ib() + database: str = attr.ib() + table_name: str = attr.ib() + record_identifier_feature_name: str = attr.ib() + event_time_identifier_feature: FeatureDefinition = attr.ib() + target_feature_name_in_base: str = attr.ib(default=None) + table_type: TableType = attr.ib(default=None) + + +def construct_feature_group_to_be_merged( + feature_group: FeatureGroup, + included_feature_names: List[str], + target_feature_name_in_base: str = None, +) -> FeatureGroupToBeMerged: + """Construct a FeatureGroupToBeMerged object by provided parameters. + + Args: + feature_group (FeatureGroup): A FeatureGroup object. + included_feature_names (List[str]): A list of strings representing features to be + included in the output. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as target join key (default: None). + Returns: + A FeatureGroupToBeMerged object. + + Raises: + ValueError: Invalid feature name(s) in included_feature_names. + """ + feature_group_metadata = feature_group.describe() + data_catalog_config = feature_group_metadata.get("OfflineStoreConfig", {}).get( + "DataCatalogConfig", None + ) + if not data_catalog_config: + raise RuntimeError(f"No metastore is configured with FeatureGroup {feature_group.name}.") + + record_identifier_feature_name = feature_group_metadata.get("RecordIdentifierFeatureName", None) + feature_definitions = feature_group_metadata.get("FeatureDefinitions", []) + event_time_identifier_feature_name = feature_group_metadata.get("EventTimeFeatureName", None) + event_time_identifier_feature_type = FeatureTypeEnum( + next( + filter( + lambda f: f.get("FeatureName", None) == event_time_identifier_feature_name, + feature_definitions, + ), + {}, + ).get("FeatureType", None) + ) + table_name = data_catalog_config.get("TableName", None) + database = data_catalog_config.get("Database", None) + disable_glue = feature_group_metadata.get("DisableGlueTableCreation", False) + catalog = data_catalog_config.get("Catalog", None) if disable_glue else _DEFAULT_CATALOG + features = [feature.get("FeatureName", None) for feature in feature_definitions] + + for included_feature in included_feature_names or []: + if included_feature not in features: + raise ValueError( + f"Feature {included_feature} not found in FeatureGroup {feature_group.name}" + ) + if not included_feature_names: + included_feature_names = features + projected_feature_names = features.copy() + else: + projected_feature_names = included_feature_names.copy() + if record_identifier_feature_name not in included_feature_names: + included_feature_names.append(record_identifier_feature_name) + if event_time_identifier_feature_name not in included_feature_names: + included_feature_names.append(event_time_identifier_feature_name) + return FeatureGroupToBeMerged( + features, + included_feature_names, + projected_feature_names, + catalog, + database, + table_name, + record_identifier_feature_name, + FeatureDefinition(event_time_identifier_feature_name, event_time_identifier_feature_type), + target_feature_name_in_base, + TableType.FEATURE_GROUP, + ) + + +@attr.s +class DatasetBuilder: + """DatasetBuilder definition. + + This class instantiates a DatasetBuilder object that comprises a base, a list of feature names, + an output path and a KMS key ID. 
+
+    Attributes:
+        _sagemaker_session (Session): Session instance to perform boto calls.
+        _base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
+            pandas.DataFrame and will be used to merge other FeatureGroups and generate a Dataset.
+        _output_path (str): An S3 URI which stores the output .csv file.
+        _record_identifier_feature_name (str): A string representing the record identifier feature
+            if base is a DataFrame (default: None).
+        _event_time_identifier_feature_name (str): A string representing the event time identifier
+            feature if base is a DataFrame (default: None).
+        _included_feature_names (List[str]): A list of strings representing features to be
+            included in the output (default: None).
+        _kms_key_id (str): A KMS key ID. If set, it will be used to encrypt the result file
+            (default: None).
+        _point_in_time_accurate_join (bool): A boolean indicating whether a point-in-time join
+            is used (default: False).
+        _include_duplicated_records (bool): A boolean indicating whether duplicated records are
+            included (default: False).
+        _include_deleted_records (bool): A boolean indicating whether deleted records are
+            included (default: False).
+        _number_of_recent_records (int): An int for how many recent records will be returned for
+            each record identifier (default: 1).
+        _number_of_records (int): An int for how many records will be returned (default: None).
+        _write_time_ending_timestamp (datetime.datetime): A datetime; all records' write time in
+            the dataset will be before it (default: None).
+        _event_time_starting_timestamp (datetime.datetime): A datetime; all records' event time
+            in the dataset will be after it (default: None).
+        _event_time_ending_timestamp (datetime.datetime): A datetime; all records' event time in
+            the dataset will be before it (default: None).
+        _feature_groups_to_be_merged (List[FeatureGroupToBeMerged]): A list of
+            FeatureGroupToBeMerged which will be joined to base (default: []).
+        _event_time_identifier_feature_type (FeatureTypeEnum): A FeatureTypeEnum representing the
+            type of the event time identifier feature (default: None).
+ """ + + _sagemaker_session: Session = attr.ib() + _base: Union[FeatureGroup, pd.DataFrame] = attr.ib() + _output_path: str = attr.ib() + _record_identifier_feature_name: str = attr.ib(default=None) + _event_time_identifier_feature_name: str = attr.ib(default=None) + _included_feature_names: List[str] = attr.ib(default=None) + _kms_key_id: str = attr.ib(default=None) + + _point_in_time_accurate_join: bool = attr.ib(init=False, default=False) + _include_duplicated_records: bool = attr.ib(init=False, default=False) + _include_deleted_records: bool = attr.ib(init=False, default=False) + _number_of_recent_records: int = attr.ib(init=False, default=None) + _number_of_records: int = attr.ib(init=False, default=None) + _write_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None) + _event_time_starting_timestamp: datetime.datetime = attr.ib(init=False, default=None) + _event_time_ending_timestamp: datetime.datetime = attr.ib(init=False, default=None) + _feature_groups_to_be_merged: List[FeatureGroupToBeMerged] = attr.ib(init=False, factory=list) + _event_time_identifier_feature_type: FeatureTypeEnum = attr.ib(default=None) + + _DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP = { + "object": "STRING", + "int64": "INT", + "float64": "DOUBLE", + "bool": "BOOLEAN", + "datetime64[ns]": "TIMESTAMP", + } + + def with_feature_group( + self, + feature_group: FeatureGroup, + target_feature_name_in_base: str = None, + included_feature_names: List[str] = None, + ): + """Join FeatureGroup with base. + + Args: + feature_group (FeatureGroup): A FeatureGroup which will be joined to base. + target_feature_name_in_base (str): A string representing the feature name in base which + will be used as target join key (default: None). + included_feature_names (List[str]): A list of strings representing features to be + included in the output (default: None). + Returns: + This DatasetBuilder object. + """ + self._feature_groups_to_be_merged.append( + construct_feature_group_to_be_merged( + feature_group, included_feature_names, target_feature_name_in_base + ) + ) + return self + + def point_in_time_accurate_join(self): + """Set join type as point in time accurate join. + + Returns: + This DatasetBuilder object. + """ + self._point_in_time_accurate_join = True + return self + + def include_duplicated_records(self): + """Include duplicated records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_duplicated_records = True + return self + + def include_deleted_records(self): + """Include deleted records in dataset. + + Returns: + This DatasetBuilder object. + """ + self._include_deleted_records = True + return self + + def with_number_of_recent_records_by_record_identifier(self, number_of_recent_records: int): + """Set number_of_recent_records field with provided input. + + Args: + number_of_recent_records (int): An int that how many recent records will be returned for + each record identifier. + Returns: + This DatasetBuilder object. + """ + self._number_of_recent_records = number_of_recent_records + return self + + def with_number_of_records_from_query_results(self, number_of_records: int): + """Set number_of_records field with provided input. + + Args: + number_of_records (int): An int that how many records will be returned. + Returns: + This DatasetBuilder object. + """ + self._number_of_records = number_of_records + return self + + def as_of(self, timestamp: datetime.datetime): + """Set write_time_ending_timestamp field with provided input. 
+ + Args: + timestamp (datetime.datetime): A datetime that all records' write time in dataset will + be before it. + Returns: + This DatasetBuilder object. + """ + self._write_time_ending_timestamp = timestamp + return self + + def with_event_time_range( + self, + starting_timestamp: datetime.datetime = None, + ending_timestamp: datetime.datetime = None, + ): + """Set event_time_starting_timestamp and event_time_ending_timestamp with provided inputs. + + Args: + starting_timestamp (datetime.datetime): A datetime that all records' event time in + dataset will be after it (default: None). + ending_timestamp (datetime.datetime): A datetime that all records' event time in dataset + will be before it (default: None). + Returns: + This DatasetBuilder object. + """ + self._event_time_starting_timestamp = starting_timestamp + self._event_time_ending_timestamp = ending_timestamp + return self + + def to_csv_file(self) -> Tuple[str, str]: + """Get query string and result in .csv format file + + Returns: + The S3 path of the .csv file. + The query string executed. + """ + if isinstance(self._base, pd.DataFrame): + temp_id = utils.unique_name_from_base("dataframe-base") + local_file_name = f"{temp_id}.csv" + desired_s3_folder = f"{self._output_path}/{temp_id}" + self._base.to_csv(local_file_name, index=False, header=False) + s3.S3Uploader.upload( + local_path=local_file_name, + desired_s3_uri=desired_s3_folder, + sagemaker_session=self._sagemaker_session, + kms_key=self._kms_key_id, + ) + os.remove(local_file_name) + temp_table_name = f'dataframe_{temp_id.replace("-", "_")}' + self._create_temp_table(temp_table_name, desired_s3_folder) + base_features = list(self._base.columns) + event_time_identifier_feature_dtype = self._base[ + self._event_time_identifier_feature_name + ].dtypes + self._event_time_identifier_feature_type = ( + FeatureGroup.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get( + str(event_time_identifier_feature_dtype), None + ) + ) + query_string = self._construct_query_string( + FeatureGroupToBeMerged( + base_features, + self._included_feature_names if self._included_feature_names else base_features, + self._included_feature_names if self._included_feature_names else base_features, + _DEFAULT_CATALOG, + _DEFAULT_DATABASE, + temp_table_name, + self._record_identifier_feature_name, + FeatureDefinition( + self._event_time_identifier_feature_name, + self._event_time_identifier_feature_type, + ), + None, + TableType.DATA_FRAME, + ) + ) + query_result = self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + # TODO: cleanup temp table, need more clarification, keep it for now + return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get( + "OutputLocation", None + ), query_result.get("QueryExecution", {}).get("Query", None) + if isinstance(self._base, FeatureGroup): + base_feature_group = construct_feature_group_to_be_merged( + self._base, self._included_feature_names + ) + self._record_identifier_feature_name = base_feature_group.record_identifier_feature_name + self._event_time_identifier_feature_name = ( + base_feature_group.event_time_identifier_feature.feature_name + ) + self._event_time_identifier_feature_type = ( + base_feature_group.event_time_identifier_feature.feature_type + ) + query_string = self._construct_query_string(base_feature_group) + query_result = self._run_query( + query_string, + base_feature_group.catalog, + base_feature_group.database, + ) + return query_result.get("QueryExecution", {}).get("ResultConfiguration", {}).get( + "OutputLocation", 
None + ), query_result.get("QueryExecution", {}).get("Query", None) + raise ValueError("Base must be either a FeatureGroup or a DataFrame.") + + def to_dataframe(self) -> Tuple[pd.DataFrame, str]: + """Get query string and result in pandas.Dataframe + + Returns: + The pandas.DataFrame object. + The query string executed. + """ + csv_file, query_string = self.to_csv_file() + s3.S3Downloader.download( + s3_uri=csv_file, + local_path="./", + kms_key=self._kms_key_id, + sagemaker_session=self._sagemaker_session, + ) + local_file_name = csv_file.split("/")[-1] + df = pd.read_csv(local_file_name) + os.remove(local_file_name) + + local_metadata_file_name = local_file_name + ".metadata" + if os.path.exists(local_metadata_file_name): + os.remove(local_file_name + ".metadata") + + if "row_recent" in df: + df = df.drop("row_recent", axis="columns") + return df, query_string + + def _construct_event_time_conditions( + self, + table_name: str, + event_time_identifier_feature: FeatureDefinition, + ) -> List[str]: + """Internal method for constructing event time range sql range as string. + + Args: + table_name (str): name of the table. + event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the + event time identifier feature. + Returns: + The list of query strings. + """ + event_time_conditions = [] + timestamp_cast_function_name = "from_unixtime" + if event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING: + timestamp_cast_function_name = "from_iso8601_timestamp" + if self._event_time_starting_timestamp: + event_time_conditions.append( + f"{timestamp_cast_function_name}({table_name}." + + f'"{event_time_identifier_feature.feature_name}") >= ' + + f"from_unixtime({self._event_time_starting_timestamp.timestamp()})" + ) + if self._event_time_ending_timestamp: + event_time_conditions.append( + f"{timestamp_cast_function_name}({table_name}." + + f'"{event_time_identifier_feature.feature_name}") <= ' + + f"from_unixtime({self._event_time_ending_timestamp.timestamp()})" + ) + return event_time_conditions + + def _construct_write_time_condition( + self, + table_name: str, + ) -> str: + """Internal method for constructing write time condition. + + Args: + table_name (str): name of the table. + Returns: + string of write time condition. + """ + write_time_condition = ( + f'{table_name}."write_time" <= ' + f"to_timestamp('{self._write_time_ending_timestamp.replace(microsecond=0)}', " + f"'yyyy-mm-dd hh24:mi:ss')" + ) + return write_time_condition + + def _construct_where_query_string( + self, + suffix: str, + event_time_identifier_feature: FeatureDefinition, + where_conditions: List[str], + ) -> str: + """Internal method for constructing SQL WHERE query string by parameters. + + Args: + suffix (str): A temp identifier of the FeatureGroup. + event_time_identifier_feature (FeatureDefinition): A FeatureDefinition representing the + event time identifier feature. + where_conditions (List[str]): A list of strings representing existing where clauses. + Returns: + The WHERE query string. + + Raises: + ValueError: FeatureGroup not provided while using as_of(). Only found pandas.DataFrame. + """ + if self._number_of_recent_records: + if self._number_of_recent_records < 0: + raise ValueError( + "Please provide non-negative integer for number_of_recent_records." 
+ ) + if self._number_of_records: + if self._number_of_records < 0: + raise ValueError("Please provide non-negative integer for number_of_records.") + if self._include_deleted_records: + if isinstance(self._base, pd.DataFrame): + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "include_deleted_records() only works for FeatureGroup," + " if there is no join operation." + ) + if self._include_duplicated_records: + if isinstance(self._base, pd.DataFrame): + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "include_duplicated_records() only works for FeatureGroup," + " if there is no join operation." + ) + if self._point_in_time_accurate_join: + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "point_in_time_accurate_join() this operation only works when there is " + "more than one feature group to join." + ) + if self._write_time_ending_timestamp: + if isinstance(self._base, pd.DataFrame): + if len(self._feature_groups_to_be_merged) == 0: + raise ValueError( + "as_of() only works for FeatureGroup," " if there is no join operation." + ) + if isinstance(self._base, FeatureGroup): + if self._write_time_ending_timestamp: + where_conditions.append(self._construct_write_time_condition(f"table_{suffix}")) + + event_time_conditions = self._construct_event_time_conditions( + f"table_{suffix}", event_time_identifier_feature + ) + where_conditions.extend(event_time_conditions) + + if len(where_conditions) == 0: + return "" + return "WHERE " + "\nAND ".join(where_conditions) + + def _construct_dedup_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing removing duplicate records SQL query string. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The SQL query string. 
+ """ + record_feature_name = feature_group.record_identifier_feature_name + event_time_identifier_feature = feature_group.event_time_identifier_feature + event_time_feature_name = feature_group.event_time_identifier_feature.feature_name + rank_query_string = "" + where_conditions = [] + where_conditions_str = "" + is_dedup_enabled = False + + if feature_group.table_type is TableType.FEATURE_GROUP: + is_dedup_enabled = True + rank_query_string = ( + f'ORDER BY origin_{suffix}."api_invocation_time" DESC, ' + + f'origin_{suffix}."write_time" DESC\n' + ) + + if self._write_time_ending_timestamp: + where_conditions.append(self._construct_write_time_condition(f"origin_{suffix}")) + + event_time_conditions = self._construct_event_time_conditions( + f"origin_{suffix}", event_time_identifier_feature + ) + where_conditions.extend(event_time_conditions) + + if len(where_conditions) != 0: + where_conditions_str = "WHERE " + "\nAND ".join(where_conditions) + "\n" + + dedup_where_clause = f"WHERE dedup_row_{suffix} = 1\n" if is_dedup_enabled else "" + return ( + f"table_{suffix} AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + f'PARTITION BY origin_{suffix}."{record_feature_name}", ' + + f'origin_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + f") AS dedup_row_{suffix}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n' + + where_conditions_str + + ")\n" + + dedup_where_clause + + ")" + ) + + def _construct_deleted_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing removing deleted records SQL query string. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The SQL query string. 
+ """ + record_feature_name = feature_group.record_identifier_feature_name + event_time_identifier_feature = feature_group.event_time_identifier_feature + event_time_feature_name = feature_group.event_time_identifier_feature.feature_name + rank_query_string = f'ORDER BY origin_{suffix}."{event_time_feature_name}" DESC' + write_time_condition = "\n" + event_time_starting_condition = "" + event_time_ending_condition = "" + + if feature_group.table_type is TableType.FEATURE_GROUP: + rank_query_string += ( + f', origin_{suffix}."api_invocation_time" DESC, ' + + f'origin_{suffix}."write_time" DESC\n' + ) + + if self._write_time_ending_timestamp: + write_time_condition += " AND " + write_time_condition += self._construct_write_time_condition(f"origin_{suffix}") + write_time_condition += "\n" + + if self._event_time_starting_timestamp and self._event_time_ending_timestamp: + event_time_conditions = self._construct_event_time_conditions( + f"origin_{suffix}", event_time_identifier_feature + ) + event_time_starting_condition = "AND " + event_time_conditions[0] + "\n" + event_time_ending_condition = "AND " + event_time_conditions[1] + "\n" + + return ( + f"deleted_{suffix} AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + f'PARTITION BY origin_{suffix}."{record_feature_name}"\n' + + rank_query_string + + f") AS deleted_row_{suffix}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n' + + "WHERE is_deleted" + + write_time_condition + + event_time_starting_condition + + event_time_ending_condition + + ")\n" + + f"WHERE deleted_row_{suffix} = 1\n" + + ")" + ) + + def _construct_table_included_features( + self, feature_group: FeatureGroupToBeMerged, suffix: str + ) -> str: + """Internal method for constructing included features string of table. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object + which has the metadata. + suffix (str): A temp identifier of the table. + Returns: + The string that includes all feature to be included of table. + """ + + included_features = ", ".join( + [ + f'table_{suffix}."{include_feature_name}"' + for include_feature_name in feature_group.included_feature_names + ] + ) + return included_features + + def _construct_table_query(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing SQL query string by parameters. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The query string. 
+ """ + included_features = self._construct_table_included_features(feature_group, suffix) + + # If base is a FeatureGroup then included_features_write_time will have a write_time column + # Or included_features_write_time is same as included_features + included_features_write_time = included_features + + if feature_group.table_type is TableType.FEATURE_GROUP: + included_features_write_time += f', table_{suffix}."write_time"' + record_feature_name = feature_group.record_identifier_feature_name + event_time_feature_name = feature_group.event_time_identifier_feature.feature_name + if self._include_duplicated_records and self._include_deleted_records: + return ( + f"SELECT {included_features}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" table_{suffix}\n' + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, ["NOT is_deleted"] + ) + ) + if feature_group.table_type is TableType.FEATURE_GROUP and self._include_deleted_records: + rank_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + rank_query_string = ( + f'ORDER BY origin_{suffix}."api_invocation_time" DESC, ' + + f'origin_{suffix}."write_time" DESC\n' + ) + return ( + f"SELECT {included_features}\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + f'PARTITION BY origin_{suffix}."{record_feature_name}", ' + + f'origin_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + f") AS row_{suffix}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}" origin_{suffix}\n' + + "WHERE NOT is_deleted" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, + feature_group.event_time_identifier_feature, + [f"row_{suffix} = 1"], + ) + ) + rank_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + rank_query_string = ( + f'OR (table_{suffix}."{event_time_feature_name}" = ' + + f'deleted_{suffix}."{event_time_feature_name}" ' + + f'AND table_{suffix}."api_invocation_time" > ' + + f'deleted_{suffix}."api_invocation_time")\n' + + f'OR (table_{suffix}."{event_time_feature_name}" = ' + + f'deleted_{suffix}."{event_time_feature_name}" ' + + f'AND table_{suffix}."api_invocation_time" = ' + + f'deleted_{suffix}."api_invocation_time" ' + + f'AND table_{suffix}."write_time" > deleted_{suffix}."write_time")\n' + ) + + final_query_string = "" + if feature_group.table_type is TableType.FEATURE_GROUP: + if self._include_duplicated_records: + final_query_string = ( + f"WITH {self._construct_deleted_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f'FROM "{feature_group.database}"."{feature_group.table_name}"' + + f" table_{suffix}\n" + + f"LEFT JOIN deleted_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + f'WHERE deleted_{suffix}."{record_feature_name}" IS NULL\n' + + "UNION ALL\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM deleted_{suffix}\n" + + f'JOIN "{feature_group.database}"."{feature_group.table_name}"' + + f" table_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + "AND (\n" + + f'table_{suffix}."{event_time_feature_name}" > ' + + f'deleted_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + ")\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + else: + 
final_query_string = ( + f"WITH {self._construct_dedup_query(feature_group, suffix)},\n" + + f"{self._construct_deleted_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM table_{suffix}\n" + + f"LEFT JOIN deleted_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + f'WHERE deleted_{suffix}."{record_feature_name}" IS NULL\n' + + "UNION ALL\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM deleted_{suffix}\n" + + f"JOIN table_{suffix}\n" + + f'ON table_{suffix}."{record_feature_name}" = ' + + f'deleted_{suffix}."{record_feature_name}"\n' + + "AND (\n" + + f'table_{suffix}."{event_time_feature_name}" > ' + + f'deleted_{suffix}."{event_time_feature_name}"\n' + + rank_query_string + + ")\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + else: + final_query_string = ( + f"WITH {self._construct_dedup_query(feature_group, suffix)}\n" + + f"SELECT {included_features}\n" + + "FROM (\n" + + f"SELECT {included_features_write_time}\n" + + f"FROM table_{suffix}\n" + + f") AS table_{suffix}\n" + + self._construct_where_query_string( + suffix, feature_group.event_time_identifier_feature, [] + ) + ) + return final_query_string + + def _construct_query_string(self, base: FeatureGroupToBeMerged) -> str: + """Internal method for constructing SQL query string by parameters. + + Args: + base (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the metadata. + Returns: + The query string. + + Raises: + ValueError: target_feature_name_in_base is an invalid feature name. + """ + base_table_query_string = self._construct_table_query(base, "base") + query_string = f"WITH fg_base AS ({base_table_query_string})" + if len(self._feature_groups_to_be_merged) > 0: + with_subquery_string = "".join( + [ + f",\nfg_{i} AS ({self._construct_table_query(feature_group, str(i))})" + for i, feature_group in enumerate(self._feature_groups_to_be_merged) + ] + ) + query_string += with_subquery_string + + selected_features = "" + selected_features += ", ".join(map("fg_base.{0}".format, base.projected_feature_names)) + if len(self._feature_groups_to_be_merged) > 0: + for i, feature_group in enumerate(self._feature_groups_to_be_merged): + selected_features += ", " + selected_features += ", ".join( + [ + f'fg_{i}."{feature_name}" as "{feature_name}.{(i+1)}"' + for feature_name in feature_group.projected_feature_names + ] + ) + + selected_features_final = "" + selected_features_final += ", ".join(base.projected_feature_names) + if len(self._feature_groups_to_be_merged) > 0: + for i, feature_group in enumerate(self._feature_groups_to_be_merged): + selected_features_final += ", " + selected_features_final += ", ".join( + [ + '"{0}.{1}"'.format(feature_name, (i + 1)) + for feature_name in feature_group.projected_feature_names + ] + ) + + query_string += ( + f"\nSELECT {selected_features_final}\n" + + "FROM (\n" + + f"SELECT {selected_features}, row_number() OVER (\n" + + f'PARTITION BY fg_base."{base.record_identifier_feature_name}"\n' + + f'ORDER BY fg_base."{base.event_time_identifier_feature.feature_name}" DESC' + ) + + recent_record_where_clause = "" + if self._number_of_recent_records is not None and self._number_of_recent_records >= 0: + recent_record_where_clause = f"WHERE row_recent <= {self._number_of_recent_records}" + + join_subquery_strings = [] + if 
len(self._feature_groups_to_be_merged) > 0: + for i, feature_group in enumerate(self._feature_groups_to_be_merged): + if not feature_group.target_feature_name_in_base: + feature_group.target_feature_name_in_base = self._record_identifier_feature_name + else: + if feature_group.target_feature_name_in_base not in base.features: + raise ValueError( + f"Feature {feature_group.target_feature_name_in_base} not found in base" + ) + query_string += ( + f', fg_{i}."{feature_group.event_time_identifier_feature.feature_name}" DESC' + ) + join_subquery_strings.append(self._construct_join_condition(feature_group, str(i))) + + query_string += ( + "\n) AS row_recent\n" + + "FROM fg_base" + + "".join(join_subquery_strings) + + "\n)\n" + + f"{recent_record_where_clause}" + ) + + if self._number_of_records is not None and self._number_of_records >= 0: + query_string += f"\nLIMIT {self._number_of_records}" + return query_string + + def _construct_join_condition(self, feature_group: FeatureGroupToBeMerged, suffix: str) -> str: + """Internal method for constructing SQL JOIN query string by parameters. + + Args: + feature_group (FeatureGroupToBeMerged): A FeatureGroupToBeMerged object which has the + FeatureGroup metadata. + suffix (str): A temp identifier of the FeatureGroup. + Returns: + The JOIN query string. + """ + join_condition_string = ( + f"\nJOIN fg_{suffix}\n" + + f'ON fg_base."{feature_group.target_feature_name_in_base}" = ' + + f'fg_{suffix}."{feature_group.record_identifier_feature_name}"' + ) + base_timestamp_cast_function_name = "from_unixtime" + if self._event_time_identifier_feature_type == FeatureTypeEnum.STRING: + base_timestamp_cast_function_name = "from_iso8601_timestamp" + timestamp_cast_function_name = "from_unixtime" + if feature_group.event_time_identifier_feature.feature_type == FeatureTypeEnum.STRING: + timestamp_cast_function_name = "from_iso8601_timestamp" + if self._point_in_time_accurate_join: + join_condition_string += ( + f"\nAND {base_timestamp_cast_function_name}(fg_base." + + f'"{self._event_time_identifier_feature_name}") >= ' + + f"{timestamp_cast_function_name}(fg_{suffix}." + + f'"{feature_group.event_time_identifier_feature.feature_name}")' + ) + return join_condition_string + + def _create_temp_table(self, temp_table_name: str, desired_s3_folder: str): + """Internal method for creating a temp Athena table for the base pandas.Dataframe. + + Args: + temp_table_name (str): The Athena table name of base pandas.DataFrame. + desired_s3_folder (str): The S3 URI of the folder of the data. + """ + columns_string = ", ".join( + [self._construct_athena_table_column_string(column) for column in self._base.columns] + ) + serde_properties = '"separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\"' + query_string = ( + f"CREATE EXTERNAL TABLE {temp_table_name} ({columns_string}) " + + "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' " + + f"WITH SERDEPROPERTIES ({serde_properties}) " + + f"LOCATION '{desired_s3_folder}';" + ) + self._run_query(query_string, _DEFAULT_CATALOG, _DEFAULT_DATABASE) + + def _construct_athena_table_column_string(self, column: str) -> str: + """Internal method for constructing string of Athena column. + + Args: + column (str): The column name from pandas.Dataframe. + Returns: + The Athena column string. + + Raises: + RuntimeError: The type of pandas.Dataframe column is not support yet. 
+ """ + dataframe_type = self._base[column].dtypes + if str(dataframe_type) not in self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.keys(): + raise RuntimeError(f"The dataframe type {dataframe_type} is not supported yet.") + return f"{column} {self._DATAFRAME_TYPE_TO_COLUMN_TYPE_MAP.get(str(dataframe_type), None)}" + + def _run_query(self, query_string: str, catalog: str, database: str) -> Dict[str, Any]: + """Internal method for execute Athena query, wait for query finish and get query result. + + Args: + query_string (str): The SQL query statements to be executed. + catalog (str): The name of the data catalog used in the query execution. + database (str): The name of the database used in the query execution. + Returns: + The query result. + + Raises: + RuntimeError: Athena query failed. + """ + query = self._sagemaker_session.start_query_execution( + catalog=catalog, + database=database, + query_string=query_string, + output_location=self._output_path, + kms_key=self._kms_key_id, + ) + query_id = query.get("QueryExecutionId", None) + self._sagemaker_session.wait_for_athena_query(query_execution_id=query_id) + query_result = self._sagemaker_session.get_query_execution(query_execution_id=query_id) + query_state = query_result.get("QueryExecution", {}).get("Status", {}).get("State", None) + + if query_state != "SUCCEEDED": + raise RuntimeError(f"Failed to execute query {query_id}.") + return query_result diff --git a/src/sagemaker/feature_store/feature_group.py b/src/sagemaker/feature_store/feature_group.py index d486ab8a01..855e11488f 100644 --- a/src/sagemaker/feature_store/feature_group.py +++ b/src/sagemaker/feature_store/feature_group.py @@ -435,13 +435,14 @@ class FeatureGroup: "uint64", ] _FLOAT_TYPES = ["float_", "float16", "float32", "float64"] - _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = { + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP: Dict[str, FeatureTypeEnum] = { type: FeatureTypeEnum.INTEGRAL for type in _INTEGER_TYPES } - _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update( + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.update( {type: FeatureTypeEnum.FRACTIONAL for type in _FLOAT_TYPES} ) - _DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["string"] = FeatureTypeEnum.STRING + DTYPE_TO_FEATURE_DEFINITION_CLS_MAP["object"] = FeatureTypeEnum.STRING _FEATURE_TYPE_TO_DDL_DATA_TYPE_MAP = { FeatureTypeEnum.INTEGRAL.value: "INT", @@ -629,7 +630,7 @@ def load_feature_definitions( """ feature_definitions = [] for column in data_frame: - feature_type = self._DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get( + feature_type = self.DTYPE_TO_FEATURE_DEFINITION_CLS_MAP.get( str(data_frame[column].dtype), None ) if feature_type: @@ -644,6 +645,23 @@ def load_feature_definitions( self.feature_definitions = feature_definitions return self.feature_definitions + def get_record( + self, record_identifier_value_as_string: str, feature_names: Sequence[str] = None + ) -> Sequence[Dict[str, str]]: + """Get a single record in a FeatureGroup + + Args: + record_identifier_value_as_string (String): + a String representing the value of the record identifier. + feature_names (Sequence[String]): + a list of Strings representing feature names. + """ + return self.sagemaker_session.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + feature_group_name=self.name, + feature_names=feature_names, + ).get("Record") + def put_record(self, record: Sequence[FeatureValue]): """Put a single record in the FeatureGroup. 
@@ -654,6 +672,25 @@ def put_record(self, record: Sequence[FeatureValue]):
             feature_group_name=self.name, record=[value.to_dict() for value in record]
         )
 
+    def delete_record(
+        self,
+        record_identifier_value_as_string: str,
+        event_time: str,
+    ):
+        """Delete a single record from a FeatureGroup.
+
+        Args:
+            record_identifier_value_as_string (String):
+                a String representing the value of the record identifier.
+            event_time (String):
+                a timestamp format String indicating when the deletion event occurred.
+        """
+        return self.sagemaker_session.delete_record(
+            feature_group_name=self.name,
+            record_identifier_value_as_string=record_identifier_value_as_string,
+            event_time=event_time,
+        )
+
     def ingest(
         self,
         data_frame: DataFrame,
diff --git a/src/sagemaker/feature_store/feature_store.py b/src/sagemaker/feature_store/feature_store.py
new file mode 100644
index 0000000000..def8b2b2da
--- /dev/null
+++ b/src/sagemaker/feature_store/feature_store.py
@@ -0,0 +1,130 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Feature Store.
+
+Amazon SageMaker Feature Store is a fully managed, purpose-built repository to store, share, and
+manage features for machine learning (ML) models.
+"""
+from __future__ import absolute_import
+
+import datetime
+from typing import Any, Dict, Sequence, Union
+
+import attr
+import pandas as pd
+
+from sagemaker import Session
+from sagemaker.feature_store.dataset_builder import DatasetBuilder
+from sagemaker.feature_store.feature_group import FeatureGroup
+
+
+@attr.s
+class FeatureStore:
+    """FeatureStore definition.
+
+    This class instantiates a FeatureStore object that comprises a SageMaker session instance.
+
+    Attributes:
+        sagemaker_session (Session): session instance to perform boto calls.
+    """
+
+    sagemaker_session: Session = attr.ib(default=Session)
+
+    def create_dataset(
+        self,
+        base: Union[FeatureGroup, pd.DataFrame],
+        output_path: str,
+        record_identifier_feature_name: str = None,
+        event_time_identifier_feature_name: str = None,
+        included_feature_names: Sequence[str] = None,
+        kms_key_id: str = None,
+    ) -> DatasetBuilder:
+        """Create a Dataset Builder for generating a Dataset.
+
+        Args:
+            base (Union[FeatureGroup, DataFrame]): A base which can be either a FeatureGroup or a
+                pandas.DataFrame and will be used to merge other FeatureGroups and generate a
+                Dataset.
+            output_path (str): An S3 URI which stores the output .csv file.
+            record_identifier_feature_name (str): A string representing the record identifier
+                feature if base is a DataFrame (default: None).
+            event_time_identifier_feature_name (str): A string representing the event time
+                identifier feature if base is a DataFrame (default: None).
+            included_feature_names (List[str]): A list of features to be included in the output
+                (default: None).
+            kms_key_id (str): A KMS key ID. If set, it will be used to encrypt the result file
+                (default: None).
+
+        Raises:
+            ValueError: Base is a pandas DataFrame but no record identifier feature name or event
+                time identifier feature name is provided.
+        """
+        if isinstance(base, pd.DataFrame):
+            if record_identifier_feature_name is None or event_time_identifier_feature_name is None:
+                raise ValueError(
+                    "You must provide a record identifier feature name and an event time "
+                    + "identifier feature name if a DataFrame is specified as the base."
+                )
+        return DatasetBuilder(
+            self.sagemaker_session,
+            base,
+            output_path,
+            record_identifier_feature_name,
+            event_time_identifier_feature_name,
+            included_feature_names,
+            kms_key_id,
+        )
+
+    def list_feature_groups(
+        self,
+        name_contains: str = None,
+        feature_group_status_equals: str = None,
+        offline_store_status_equals: str = None,
+        creation_time_after: datetime.datetime = None,
+        creation_time_before: datetime.datetime = None,
+        sort_order: str = None,
+        sort_by: str = None,
+        max_results: int = None,
+        next_token: str = None,
+    ) -> Dict[str, Any]:
+        """List all FeatureGroups satisfying given filters.
+
+        Args:
+            name_contains (str): A string that partially matches one or more FeatureGroups' names.
+                Filters FeatureGroups by name.
+            feature_group_status_equals (str): A FeatureGroup status.
+                Filters FeatureGroups by FeatureGroup status.
+            offline_store_status_equals (str): An OfflineStore status.
+                Filters FeatureGroups by OfflineStore status.
+            creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups
+                created after a specific date and time.
+            creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups
+                created before a specific date and time.
+            sort_order (str): The order in which FeatureGroups are listed.
+            sort_by (str): The value on which the FeatureGroup list is sorted.
+            max_results (int): The maximum number of results returned by ListFeatureGroups.
+            next_token (str): A token to resume pagination of ListFeatureGroups results.
+        Returns:
+            Response dict from service.
+        """
+        return self.sagemaker_session.list_feature_groups(
+            name_contains=name_contains,
+            feature_group_status_equals=feature_group_status_equals,
+            offline_store_status_equals=offline_store_status_equals,
+            creation_time_after=creation_time_after,
+            creation_time_before=creation_time_before,
+            sort_order=sort_order,
+            sort_by=sort_by,
+            max_results=max_results,
+            next_token=next_token,
+        )
diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py
index 3fc4fc1256..72df570496 100644
--- a/src/sagemaker/session.py
+++ b/src/sagemaker/session.py
@@ -312,7 +312,7 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
         # For each object key, create the directory on the local machine if needed, and then
         # download the file.
         for key in keys:
-            tail_s3_uri_path = os.path.basename(key_prefix)
+            tail_s3_uri_path = os.path.basename(key)
             if not os.path.splitext(key_prefix)[1]:
                 tail_s3_uri_path = os.path.relpath(key, key_prefix)
             destination_path = os.path.join(path, tail_s3_uri_path)
@@ -4341,6 +4341,56 @@ def update_feature_group(
             FeatureGroupName=feature_group_name, FeatureAdditions=feature_additions
         )
 
+    def list_feature_groups(
+        self,
+        name_contains,
+        feature_group_status_equals,
+        offline_store_status_equals,
+        creation_time_after,
+        creation_time_before,
+        sort_order,
+        sort_by,
+        max_results,
+        next_token,
+    ) -> Dict[str, Any]:
+        """List all FeatureGroups satisfying given filters.
+
+        Args:
+            name_contains (str): A string that partially matches one or more FeatureGroups' names.
+                Filters FeatureGroups by name.
+            feature_group_status_equals (str): A FeatureGroup status.
+                Filters FeatureGroups by FeatureGroup status.
+            offline_store_status_equals (str): An OfflineStore status.
+                Filters FeatureGroups by OfflineStore status.
+            creation_time_after (datetime.datetime): Use this parameter to search for FeatureGroups
+                created after a specific date and time.
+            creation_time_before (datetime.datetime): Use this parameter to search for FeatureGroups
+                created before a specific date and time.
+            sort_order (str): The order in which FeatureGroups are listed.
+            sort_by (str): The value on which the FeatureGroup list is sorted.
+            max_results (int): The maximum number of results returned by ListFeatureGroups.
+            next_token (str): A token to resume pagination of ListFeatureGroups results.
+        Returns:
+            Response dict from service.
+        """
+        list_feature_groups_args = {}
+
+        def check_object(key, value):
+            if value is not None:
+                list_feature_groups_args[key] = value
+
+        check_object("NameContains", name_contains)
+        check_object("FeatureGroupStatusEquals", feature_group_status_equals)
+        check_object("OfflineStoreStatusEquals", offline_store_status_equals)
+        check_object("CreationTimeAfter", creation_time_after)
+        check_object("CreationTimeBefore", creation_time_before)
+        check_object("SortOrder", sort_order)
+        check_object("SortBy", sort_by)
+        check_object("MaxResults", max_results)
+        check_object("NextToken", next_token)
+
+        return self.sagemaker_client.list_feature_groups(**list_feature_groups_args)
+
     def update_feature_metadata(
         self,
         feature_group_name: str,
@@ -4408,6 +4458,48 @@ def put_record(
             Record=record,
         )
 
+    def delete_record(
+        self,
+        feature_group_name: str,
+        record_identifier_value_as_string: str,
+        event_time: str,
+    ):
+        """Deletes a single record from the FeatureGroup.
+
+        Args:
+            feature_group_name (str): name of the FeatureGroup.
+            record_identifier_value_as_string (str): value of the record identifier as a string.
+            event_time (str): a timestamp indicating when the deletion event occurred.
+        """
+        return self.sagemaker_featurestore_runtime_client.delete_record(
+            FeatureGroupName=feature_group_name,
+            RecordIdentifierValueAsString=record_identifier_value_as_string,
+            EventTime=event_time,
+        )
+
+    def get_record(
+        self,
+        record_identifier_value_as_string: str,
+        feature_group_name: str,
+        feature_names: Sequence[str],
+    ) -> Dict[str, Sequence[Dict[str, str]]]:
+        """Gets a single record in the FeatureGroup.
+
+        Args:
+            record_identifier_value_as_string (str): value of the record identifier as a string.
+            feature_group_name (str): name of the FeatureGroup.
+            feature_names (Sequence[str]): list of feature names.
+ """ + get_record_args = { + "FeatureGroupName": feature_group_name, + "RecordIdentifierValueAsString": record_identifier_value_as_string, + } + + if feature_names: + get_record_args["FeatureNames"] = feature_names + + return self.sagemaker_featurestore_runtime_client.get_record(**get_record_args) + def start_query_execution( self, catalog: str, diff --git a/tests/integ/test_feature_store.py b/tests/integ/test_feature_store.py index c1b84117c3..e19cebdca4 100644 --- a/tests/integ/test_feature_store.py +++ b/tests/integ/test_feature_store.py @@ -14,6 +14,7 @@ import json import time +import datetime from contextlib import contextmanager import boto3 @@ -24,6 +25,7 @@ from sagemaker.feature_store.feature_definition import FractionalFeatureDefinition from sagemaker.feature_store.feature_group import FeatureGroup +from sagemaker.feature_store.feature_store import FeatureStore from sagemaker.feature_store.inputs import FeatureValue, FeatureParameter, TableFormatEnum from sagemaker.session import get_execution_role, Session from tests.integ.timeout import timeout @@ -80,6 +82,11 @@ def feature_group_name(): return f"my-feature-group-{int(time.time() * 10**7)}" +@pytest.fixture +def base_name(): + return f"my-base-{int(time.time() * 10**7)}" + + @pytest.fixture def offline_store_s3_uri(feature_store_session, region_name): bucket = f"sagemaker-test-featurestore-{region_name}-{feature_store_session.account_id()}" @@ -107,6 +114,32 @@ def pandas_data_frame(): return df +@pytest.fixture +def base_dataframe(): + base_data = [ + [1, 187512346.0, 123, 128], + [2, 187512347.0, 168, 258], + [3, 187512348.0, 125, 184], + [1, 187512349.0, 195, 206], + ] + return pd.DataFrame( + base_data, columns=["base_id", "base_time", "base_feature_1", "base_feature_2"] + ) + + +@pytest.fixture +def feature_group_dataframe(): + feature_group_data = [ + [1, 187512246.0, 456, 325], + [2, 187512247.0, 729, 693], + [3, 187512348.0, 129, 901], + [1, 187512449.0, 289, 286], + ] + return pd.DataFrame( + feature_group_data, columns=["fg_id", "fg_time", "fg_feature_1", "fg_feature_2"] + ) + + @pytest.fixture def pandas_data_frame_without_string(): df = pd.DataFrame( @@ -288,6 +321,92 @@ def test_create_feature_group_glue_table_format( assert table_format == "Glue" +def test_get_record( + feature_store_session, + role, + feature_group_name, + pandas_data_frame, + record, +): + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session) + feature_group.load_feature_definitions(data_frame=pandas_data_frame) + + record_identifier_value_as_string = record[0].value_as_string + with cleanup_feature_group(feature_group): + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature3", + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + # Ingest data + feature_group.put_record(record=record) + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + ) + record_names = list(map(lambda r: r.feature_name, record)) + assert len(retrieved_record) == len(record_names) + for feature in retrieved_record: + assert feature["FeatureName"] in record_names + removed_feature_name = record_names.pop() + # Retrieve data + retrieved_record = feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=record_names, + ) + assert len(retrieved_record) == len(record_names) + for feature in 
retrieved_record:
+        assert feature["FeatureName"] in record_names
+        assert feature["FeatureName"] != removed_feature_name
+    # Retrieve data
+    retrieved_record = feature_group.get_record(
+        record_identifier_value_as_string="1.0",
+    )
+    assert retrieved_record is None
+
+
+def test_delete_record(
+    feature_store_session,
+    role,
+    feature_group_name,
+    pandas_data_frame,
+    record,
+):
+    feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
+    feature_group.load_feature_definitions(data_frame=pandas_data_frame)
+
+    record_identifier_value_as_string = record[0].value_as_string
+    with cleanup_feature_group(feature_group):
+        feature_group.create(
+            s3_uri=False,
+            record_identifier_name="feature1",
+            event_time_feature_name="feature3",
+            role_arn=role,
+            enable_online_store=True,
+        )
+        _wait_for_feature_group_create(feature_group)
+        # Ingest data
+        feature_group.put_record(record=record)
+        # Retrieve data
+        retrieved_record = feature_group.get_record(
+            record_identifier_value_as_string=record_identifier_value_as_string,
+        )
+        assert retrieved_record is not None
+        # Delete data
+        feature_group.delete_record(
+            record_identifier_value_as_string=record_identifier_value_as_string,
+            event_time=datetime.datetime.now().replace(microsecond=0).isoformat() + "Z",
+        )
+        # Retrieve data
+        retrieved_record = feature_group.get_record(
+            record_identifier_value_as_string=record_identifier_value_as_string,
+        )
+        assert retrieved_record is None
+
+
 def test_update_feature_group(
     feature_store_session,
     role,
@@ -316,6 +435,25 @@ def test_update_feature_group(
     assert any([True for elem in feature_definitions if new_feature_name in elem.values()])
 
 
+def test_list_feature_groups(feature_store_session, role, feature_group_name, pandas_data_frame):
+    feature_store = FeatureStore(sagemaker_session=feature_store_session)
+    feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
+    feature_group.load_feature_definitions(data_frame=pandas_data_frame)
+
+    with cleanup_feature_group(feature_group):
+        feature_group.create(
+            s3_uri=False,
+            record_identifier_name="feature1",
+            event_time_feature_name="feature3",
+            role_arn=role,
+            enable_online_store=True,
+        )
+        _wait_for_feature_group_create(feature_group)
+        output = feature_store.list_feature_groups(name_contains=feature_group_name)
+
+    assert output["FeatureGroupSummaries"][0]["FeatureGroupName"] == feature_group_name
+
+
 def test_feature_metadata(
     feature_store_session,
     role,
@@ -420,6 +558,242 @@ def test_ingest_multi_process(
     assert output["FeatureGroupArn"].endswith(f"feature-group/{feature_group_name}")
 
 
+def test_create_dataset_with_feature_group_base(
+    feature_store_session,
+    region_name,
+    role,
+    base_name,
+    feature_group_name,
+    offline_store_s3_uri,
+    base_dataframe,
+    feature_group_dataframe,
+):
+    base = FeatureGroup(name=base_name, sagemaker_session=feature_store_session)
+    feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=feature_store_session)
+    with cleanup_feature_group(base), cleanup_feature_group(feature_group):
+        _create_feature_group_and_ingest_data(
+            base, base_dataframe, offline_store_s3_uri, "base_id", "base_time", role
+        )
+        _create_feature_group_and_ingest_data(
+            feature_group, feature_group_dataframe, offline_store_s3_uri, "fg_id", "fg_time", role
+        )
+        base_table_name = _get_athena_table_name_after_data_replication(
+            feature_store_session, base, offline_store_s3_uri
+        )
+        feature_group_table_name = 
_get_athena_table_name_after_data_replication( + feature_store_session, feature_group, offline_store_s3_uri + ) + + with timeout(minutes=10) and cleanup_offline_store( + base_table_name, feature_store_session + ) and cleanup_offline_store(feature_group_table_name, feature_store_session): + feature_store = FeatureStore(sagemaker_session=feature_store_session) + df, query_string = ( + feature_store.create_dataset(base=base, output_path=offline_store_s3_uri) + .with_number_of_recent_records_by_record_identifier(4) + .with_feature_group(feature_group) + .to_dataframe() + ) + sorted_df = df.sort_values(by=list(df.columns)).reset_index(drop=True) + merged_df = base_dataframe.merge( + feature_group_dataframe, left_on="base_id", right_on="fg_id" + ) + + expect_df = merged_df.sort_values(by=list(merged_df.columns)).reset_index(drop=True) + + expect_df.rename( + columns={ + "fg_id": "fg_id.1", + "fg_time": "fg_time.1", + "fg_feature_1": "fg_feature_1.1", + "fg_feature_2": "fg_feature_2.1", + }, + inplace=True, + ) + + assert sorted_df.equals(expect_df) + assert ( + query_string + == "WITH fg_base AS (WITH table_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."base_id", origin_base."base_time"\n' + + 'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n' + + ") AS dedup_row_base\n" + + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n' + + ")\n" + + "WHERE dedup_row_base = 1\n" + + "),\n" + + "deleted_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."base_id"\n' + + 'ORDER BY origin_base."base_time" DESC,' + ' origin_base."api_invocation_time" DESC,' + ' origin_base."write_time" DESC\n' + + ") AS deleted_row_base\n" + + f'FROM "sagemaker_featurestore"."{base_table_name}" origin_base\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_base = 1\n" + + ")\n" + + 'SELECT table_base."base_id", table_base."base_time",' + ' table_base."base_feature_1", table_base."base_feature_2"\n' + + "FROM (\n" + + 'SELECT table_base."base_id", table_base."base_time",' + ' table_base."base_feature_1", table_base."base_feature_2",' + ' table_base."write_time"\n' + + "FROM table_base\n" + + "LEFT JOIN deleted_base\n" + + 'ON table_base."base_id" = deleted_base."base_id"\n' + + 'WHERE deleted_base."base_id" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_base."base_id", table_base."base_time",' + ' table_base."base_feature_1", table_base."base_feature_2",' + ' table_base."write_time"\n' + + "FROM deleted_base\n" + + "JOIN table_base\n" + + 'ON table_base."base_id" = deleted_base."base_id"\n' + + "AND (\n" + + 'table_base."base_time" > deleted_base."base_time"\n' + + 'OR (table_base."base_time" = deleted_base."base_time" AND' + ' table_base."api_invocation_time" >' + ' deleted_base."api_invocation_time")\n' + + 'OR (table_base."base_time" = deleted_base."base_time" AND' + ' table_base."api_invocation_time" =' + ' deleted_base."api_invocation_time" AND' + ' table_base."write_time" > deleted_base."write_time")\n' + + ")\n" + + ") AS table_base\n" + + "),\n" + + "fg_0 AS (WITH table_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."fg_id", origin_0."fg_time"\n' + + 'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n' + + ") AS dedup_row_0\n" + + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n' + + ")\n" + + "WHERE dedup_row_0 = 1\n" + + "),\n" + 
+ "deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."fg_id"\n' + + 'ORDER BY origin_0."fg_time" DESC, origin_0."api_invocation_time" DESC,' + ' origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + f'FROM "sagemaker_featurestore"."{feature_group_table_name}" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."fg_id", table_0."fg_time", table_0."fg_feature_1",' + ' table_0."fg_feature_2"\n' + + "FROM (\n" + + 'SELECT table_0."fg_id", table_0."fg_time",' + ' table_0."fg_feature_1", table_0."fg_feature_2",' + ' table_0."write_time"\n' + + "FROM table_0\n" + + "LEFT JOIN deleted_0\n" + + 'ON table_0."fg_id" = deleted_0."fg_id"\n' + + 'WHERE deleted_0."fg_id" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_0."fg_id", table_0."fg_time",' + ' table_0."fg_feature_1", table_0."fg_feature_2",' + ' table_0."write_time"\n' + + "FROM deleted_0\n" + + "JOIN table_0\n" + + 'ON table_0."fg_id" = deleted_0."fg_id"\n' + + "AND (\n" + + 'table_0."fg_time" > deleted_0."fg_time"\n' + + 'OR (table_0."fg_time" = deleted_0."fg_time" AND' + ' table_0."api_invocation_time" >' + ' deleted_0."api_invocation_time")\n' + + 'OR (table_0."fg_time" = deleted_0."fg_time" AND' + ' table_0."api_invocation_time" =' + ' deleted_0."api_invocation_time" AND table_0."write_time" >' + ' deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + "SELECT base_id, base_time, base_feature_1, base_feature_2," + ' "fg_id.1", "fg_time.1", "fg_feature_1.1",' + ' "fg_feature_2.1"\n' + "FROM (\n" + "SELECT fg_base.base_id, fg_base.base_time," + " fg_base.base_feature_1, fg_base.base_feature_2," + ' fg_0."fg_id" as "fg_id.1", fg_0."fg_time" as "fg_time.1",' + ' fg_0."fg_feature_1" as "fg_feature_1.1",' + ' fg_0."fg_feature_2" as "fg_feature_2.1", row_number()' + " OVER (\n" + + 'PARTITION BY fg_base."base_id"\n' + + 'ORDER BY fg_base."base_time" DESC, fg_0."fg_time" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."base_id" = fg_0."fg_id"\n' + + ")\n" + + "WHERE row_recent <= 4" + ) + + +def _create_feature_group_and_ingest_data( + feature_group: FeatureGroup, + dataframe: DataFrame, + offline_store_s3_uri: str, + record_identifier_name: str, + event_time_name: str, + role: str, +): + feature_group.load_feature_definitions(data_frame=dataframe) + feature_group.create( + s3_uri=offline_store_s3_uri, + record_identifier_name=record_identifier_name, + event_time_feature_name=event_time_name, + role_arn=role, + enable_online_store=True, + ) + _wait_for_feature_group_create(feature_group) + + ingestion_manager = feature_group.ingest(data_frame=dataframe, max_workers=3, wait=False) + ingestion_manager.wait() + assert 0 == len(ingestion_manager.failed_rows) + + +def _get_athena_table_name_after_data_replication( + feature_store_session, feature_group: FeatureGroup, offline_store_s3_uri +): + feature_group_metadata = feature_group.describe() + resolved_output_s3_uri = ( + feature_group_metadata.get("OfflineStoreConfig", None) + .get("S3StorageConfig", None) + .get("ResolvedOutputS3Uri", None) + ) + s3_prefix = resolved_output_s3_uri.replace(f"{offline_store_s3_uri}/", "") + region_name = feature_store_session.boto_session.region_name + s3_client = feature_store_session.boto_session.client( + service_name="s3", region_name=region_name + ) + while True: + objects_in_bucket = s3_client.list_objects( + Bucket=offline_store_s3_uri.replace("s3://", ""), Prefix=s3_prefix + ) + 
if "Contents" in objects_in_bucket and len(objects_in_bucket["Contents"]) > 1: + break + else: + print(f"Waiting for {feature_group.name} data in offline store...") + time.sleep(60) + print(f"{feature_group.name} data available.") + return ( + feature_group_metadata.get("OfflineStoreConfig", None) + .get("DataCatalogConfig", None) + .get("TableName", None) + ) + + def _wait_for_feature_group_create(feature_group: FeatureGroup): status = feature_group.describe().get("FeatureGroupStatus") while status == "Creating": @@ -451,5 +825,31 @@ def cleanup_feature_group(feature_group: FeatureGroup): finally: try: feature_group.delete() + print(f"{feature_group.name} is deleted") except Exception: raise RuntimeError(f"Failed to delete feature group with name {feature_group.name}") + + +@contextmanager +def cleanup_offline_store(table_name: str, feature_store_session: Session): + try: + yield + finally: + try: + region_name = feature_store_session.boto_session.region_name + s3_client = feature_store_session.boto_session.client( + service_name="s3", region_name=region_name + ) + account_id = feature_store_session.account_id() + bucket_name = f"sagemaker-test-featurestore-{region_name}-{account_id}" + response = s3_client.list_objects_v2( + Bucket=bucket_name, + Prefix=f"{account_id}/sagemaker/{region_name}/offline-store/{table_name}/", + ) + files_in_folder = response["Contents"] + files_to_delete = [] + for f in files_in_folder: + files_to_delete.append({"Key": f["Key"]}) + s3_client.delete_objects(Bucket=bucket_name, Delete={"Objects": files_to_delete}) + except Exception: + raise RuntimeError(f"Failed to delete data under {table_name}") diff --git a/tests/unit/sagemaker/feature_store/test_dataset_builder.py b/tests/unit/sagemaker/feature_store/test_dataset_builder.py new file mode 100644 index 0000000000..0e55b86bd0 --- /dev/null +++ b/tests/unit/sagemaker/feature_store/test_dataset_builder.py @@ -0,0 +1,612 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import datetime + +import pandas as pd +import pytest +import os +from mock import Mock, patch + +from sagemaker.feature_store.dataset_builder import ( + DatasetBuilder, + FeatureGroupToBeMerged, + TableType, +) +from sagemaker.feature_store.feature_group import ( + FeatureDefinition, + FeatureGroup, + FeatureTypeEnum, +) + + +@pytest.fixture +def sagemaker_session_mock(): + return Mock() + + +@pytest.fixture +def feature_group_mock(): + return Mock() + + +@pytest.fixture +def read_csv_mock(): + return Mock() + + +@pytest.fixture +def to_csv_file_mock(): + return Mock() + + +@pytest.fixture +def remove_mock(): + return Mock() + + +BASE = FeatureGroupToBeMerged( + ["target-feature", "other-feature"], + ["target-feature", "other-feature"], + ["target-feature", "other-feature"], + "catalog", + "database", + "base-table", + "target-feature", + FeatureDefinition("other-feature", FeatureTypeEnum.STRING), + None, + TableType.FEATURE_GROUP, +) +FEATURE_GROUP = FeatureGroupToBeMerged( + ["feature-1", "feature-2"], + ["feature-1", "feature-2"], + ["feature-1", "feature-2"], + "catalog", + "database", + "table-name", + "feature-1", + FeatureDefinition("feature-2", FeatureTypeEnum.FRACTIONAL), + "target-feature", + TableType.FEATURE_GROUP, +) + + +def test_with_feature_group_throw_runtime_error(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + sagemaker_session_mock.describe_feature_group.return_value = {"OfflineStoreConfig": {}} + with pytest.raises(RuntimeError) as error: + dataset_builder.with_feature_group( + feature_group, "target-feature", ["feature-1", "feature-2"] + ) + assert "No metastore is configured with FeatureGroup MyFeatureGroup." 
in str(error) + + +def test_with_feature_group(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]}) + feature_group.load_feature_definitions(dataframe) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + record_identifier_feature_name="target-feature", + ) + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}}, + "RecordIdentifierFeatureName": "feature-1", + "EventTimeFeatureName": "feature-2", + "FeatureDefinitions": [ + {"FeatureName": "feature-1", "FeatureType": "String"}, + {"FeatureName": "feature-2", "FeatureType": "String"}, + ], + } + dataset_builder.with_feature_group(feature_group, "target-feature", ["feature-1", "feature-2"]) + assert len(dataset_builder._feature_groups_to_be_merged) == 1 + assert dataset_builder._feature_groups_to_be_merged[0].features == [ + "feature-1", + "feature-2", + ] + assert dataset_builder._feature_groups_to_be_merged[0].included_feature_names == [ + "feature-1", + "feature-2", + ] + assert dataset_builder._feature_groups_to_be_merged[0].database == "database" + assert dataset_builder._feature_groups_to_be_merged[0].table_name == "table" + assert ( + dataset_builder._feature_groups_to_be_merged[0].record_identifier_feature_name + == "feature-1" + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].event_time_identifier_feature.feature_name + == "feature-2" + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].event_time_identifier_feature.feature_type + == FeatureTypeEnum.STRING + ) + assert ( + dataset_builder._feature_groups_to_be_merged[0].target_feature_name_in_base + == "target-feature" + ) + + +def test_point_in_time_accurate_join(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.point_in_time_accurate_join() + assert dataset_builder._point_in_time_accurate_join + + +def test_include_duplicated_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.include_duplicated_records() + assert dataset_builder._include_duplicated_records + + +def test_include_deleted_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.include_deleted_records() + assert dataset_builder._include_deleted_records + + +def test_with_number_of_recent_records_by_record_identifier( + sagemaker_session_mock, feature_group_mock +): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder.with_number_of_recent_records_by_record_identifier(5) + assert dataset_builder._number_of_recent_records == 5 + + +def test_with_number_of_records_from_query_results(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + 
dataset_builder.with_number_of_records_from_query_results(100) + assert dataset_builder._number_of_records == 100 + + +def test_with_event_time_range(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + start = datetime.datetime.now() + end = start + datetime.timedelta(minutes=1) + dataset_builder.with_event_time_range(start, end) + assert dataset_builder._event_time_starting_timestamp == start + assert dataset_builder._event_time_ending_timestamp == end + + +def test_to_csv_file_not_support_base_type(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + with pytest.raises(ValueError) as error: + dataset_builder.to_csv_file() + assert "Base must be either a FeatureGroup or a DataFrame." in str(error) + + +def test_to_csv_file_with_feature_group(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": {"DataCatalogConfig": {"TableName": "table", "Database": "database"}}, + "RecordIdentifierFeatureName": "feature-1", + "EventTimeFeatureName": "feature-2", + "FeatureDefinitions": [ + {"FeatureName": "feature-1", "FeatureType": "String"}, + {"FeatureName": "feature-2", "FeatureType": "String"}, + ], + } + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": { + "Status": {"State": "SUCCEEDED"}, + "ResultConfiguration": {"OutputLocation": "s3-file-path"}, + "Query": "query-string", + } + } + file_path, query_string = dataset_builder.to_csv_file() + assert file_path == "s3-file-path" + assert query_string == "query-string" + + +@patch("pandas.DataFrame.to_csv") +@patch("pandas.read_csv") +@patch("os.remove") +def test_to_dataframe_with_dataframe( + remove_mock, read_csv_mock, to_csv_file_mock, sagemaker_session_mock +): + dataframe = pd.DataFrame({"feature-1": [420, 380.0, 390], "feature-2": [50, 40.0, 45]}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="s3://file/to/path", + event_time_identifier_feature_name="feature-2", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": { + "Status": {"State": "SUCCEEDED"}, + "ResultConfiguration": {"OutputLocation": "s3://s3-file-path"}, + "Query": "query-string", + } + } + to_csv_file_mock.return_value = None + read_csv_mock.return_value = dataframe + os.remove.return_value = None + df, query_string = dataset_builder.to_dataframe() + assert df.equals(dataframe) + assert query_string == "query-string" + + +def test_construct_where_query_string(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + time = datetime.datetime.now().replace(microsecond=0) + start = time + datetime.timedelta(minutes=1) + 
end = start + datetime.timedelta(minutes=1) + dataset_builder._write_time_ending_timestamp = time + dataset_builder._event_time_starting_timestamp = start + dataset_builder._event_time_ending_timestamp = end + query_string = dataset_builder._construct_where_query_string( + "suffix", + FeatureDefinition("event-time", FeatureTypeEnum.STRING), + ["NOT is_deleted"], + ) + assert ( + query_string + == "WHERE NOT is_deleted\n" + + f"AND table_suffix.\"write_time\" <= to_timestamp('{time}', " + + "'yyyy-mm-dd hh24:mi:ss')\n" + + 'AND from_iso8601_timestamp(table_suffix."event-time") >= ' + + f"from_unixtime({start.timestamp()})\n" + + 'AND from_iso8601_timestamp(table_suffix."event-time") <= ' + + f"from_unixtime({end.timestamp()})" + ) + + +def test_construct_query_string_with_duplicated_records(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + dataset_builder._include_duplicated_records = True + + dataset_builder._feature_groups_to_be_merged = [FEATURE_GROUP] + query_string = dataset_builder._construct_query_string(BASE) + assert ( + query_string + == "WITH fg_base AS (WITH deleted_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature"\n' + + 'ORDER BY origin_base."other-feature" DESC, origin_base."api_invocation_time" DESC, ' + + 'origin_base."write_time" DESC\n' + + ") AS deleted_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_base = 1\n" + + ")\n" + + 'SELECT table_base."target-feature", table_base."other-feature"\n' + + "FROM (\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + 'FROM "database"."base-table" table_base\n' + + "LEFT JOIN deleted_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + 'WHERE deleted_base."target-feature" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM deleted_base\n" + + 'JOIN "database"."base-table" table_base\n' + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + "AND (\n" + + 'table_base."other-feature" > deleted_base."other-feature"\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" > deleted_base."api_invocation_time")\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" = deleted_base."api_invocation_time" AND ' + + 'table_base."write_time" > deleted_base."write_time")\n' + + ")\n" + + ") AS table_base\n" + + "),\n" + + "fg_0 AS (WITH deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1"\n' + + 'ORDER BY origin_0."feature-2" DESC, origin_0."api_invocation_time" DESC, ' + + 'origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."feature-1", table_0."feature-2"\n' + + "FROM (\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + 'FROM "database"."table-name" table_0\n' + + "LEFT JOIN deleted_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + 'WHERE deleted_0."feature-1" IS NULL\n' + + "UNION ALL\n" 
+ + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM deleted_0\n" + + 'JOIN "database"."table-name" table_0\n' + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + "AND (\n" + + 'table_0."feature-2" > deleted_0."feature-2"\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND table_0."api_invocation_time" > ' + + 'deleted_0."api_invocation_time")\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND table_0."api_invocation_time" = ' + + 'deleted_0."api_invocation_time" AND table_0."write_time" > deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + 'SELECT target-feature, other-feature, "feature-1.1", "feature-2.1"\n' + + "FROM (\n" + + 'SELECT fg_base.target-feature, fg_base.other-feature, fg_0."feature-1" as ' + + '"feature-1.1", fg_0."feature-2" as "feature-2.1", row_number() OVER (\n' + + 'PARTITION BY fg_base."target-feature"\n' + + 'ORDER BY fg_base."other-feature" DESC, fg_0."feature-2" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."target-feature" = fg_0."feature-1"\n' + + ")\n" + ) + + +def test_construct_query_string(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group, + output_path="file/to/path", + ) + dataset_builder._point_in_time_accurate_join = True + dataset_builder._event_time_identifier_feature_name = "target-feature" + dataset_builder._feature_groups_to_be_merged = [FEATURE_GROUP] + query_string = dataset_builder._construct_query_string(BASE) + assert ( + query_string + == "WITH fg_base AS (WITH table_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature", origin_base."other-feature"\n' + + 'ORDER BY origin_base."api_invocation_time" DESC, origin_base."write_time" DESC\n' + + ") AS dedup_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + ")\n" + + "WHERE dedup_row_base = 1\n" + + "),\n" + + "deleted_base AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_base."target-feature"\n' + + 'ORDER BY origin_base."other-feature" DESC, origin_base."api_invocation_time" ' + + 'DESC, origin_base."write_time" DESC\n' + + ") AS deleted_row_base\n" + + 'FROM "database"."base-table" origin_base\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_base = 1\n" + + ")\n" + + 'SELECT table_base."target-feature", table_base."other-feature"\n' + + "FROM (\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM table_base\n" + + "LEFT JOIN deleted_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + 'WHERE deleted_base."target-feature" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_base."target-feature", table_base."other-feature", ' + + 'table_base."write_time"\n' + + "FROM deleted_base\n" + + "JOIN table_base\n" + + 'ON table_base."target-feature" = deleted_base."target-feature"\n' + + "AND (\n" + + 'table_base."other-feature" > deleted_base."other-feature"\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" > deleted_base."api_invocation_time")\n' + + 'OR (table_base."other-feature" = deleted_base."other-feature" AND ' + + 'table_base."api_invocation_time" = deleted_base."api_invocation_time" AND ' + + 
'table_base."write_time" > deleted_base."write_time")\n' + + ")\n" + + ") AS table_base\n" + + "),\n" + + "fg_0 AS (WITH table_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1", origin_0."feature-2"\n' + + 'ORDER BY origin_0."api_invocation_time" DESC, origin_0."write_time" DESC\n' + + ") AS dedup_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + ")\n" + + "WHERE dedup_row_0 = 1\n" + + "),\n" + + "deleted_0 AS (\n" + + "SELECT *\n" + + "FROM (\n" + + "SELECT *, row_number() OVER (\n" + + 'PARTITION BY origin_0."feature-1"\n' + + 'ORDER BY origin_0."feature-2" DESC, origin_0."api_invocation_time" DESC, ' + + 'origin_0."write_time" DESC\n' + + ") AS deleted_row_0\n" + + 'FROM "database"."table-name" origin_0\n' + + "WHERE is_deleted\n" + + ")\n" + + "WHERE deleted_row_0 = 1\n" + + ")\n" + + 'SELECT table_0."feature-1", table_0."feature-2"\n' + + "FROM (\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM table_0\n" + + "LEFT JOIN deleted_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + 'WHERE deleted_0."feature-1" IS NULL\n' + + "UNION ALL\n" + + 'SELECT table_0."feature-1", table_0."feature-2", table_0."write_time"\n' + + "FROM deleted_0\n" + + "JOIN table_0\n" + + 'ON table_0."feature-1" = deleted_0."feature-1"\n' + + "AND (\n" + + 'table_0."feature-2" > deleted_0."feature-2"\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND ' + + 'table_0."api_invocation_time" > deleted_0."api_invocation_time")\n' + + 'OR (table_0."feature-2" = deleted_0."feature-2" AND ' + + 'table_0."api_invocation_time" = deleted_0."api_invocation_time" AND ' + + 'table_0."write_time" > deleted_0."write_time")\n' + + ")\n" + + ") AS table_0\n" + + ")\n" + + 'SELECT target-feature, other-feature, "feature-1.1", "feature-2.1"\n' + + "FROM (\n" + + 'SELECT fg_base.target-feature, fg_base.other-feature, fg_0."feature-1" as ' + + '"feature-1.1", fg_0."feature-2" as "feature-2.1", row_number() OVER (\n' + + 'PARTITION BY fg_base."target-feature"\n' + + 'ORDER BY fg_base."other-feature" DESC, fg_0."feature-2" DESC\n' + + ") AS row_recent\n" + + "FROM fg_base\n" + + "JOIN fg_0\n" + + 'ON fg_base."target-feature" = fg_0."feature-1"\n' + + 'AND from_unixtime(fg_base."target-feature") >= from_unixtime(fg_0."feature-2")\n' + + ")\n" + ) + + +def test_create_temp_table(sagemaker_session_mock): + dataframe = pd.DataFrame({"feature-1": [420, 380, 390], "feature-2": [50, 40, 45]}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "SUCCEEDED"}} + } + dataset_builder._create_temp_table("table-name", "s3-folder") + assert sagemaker_session_mock.start_query_execution.call_count == 1 + sagemaker_session_mock.start_query_execution.assert_called_once_with( + catalog="AwsDataCatalog", + database="sagemaker_featurestore", + query_string="CREATE EXTERNAL TABLE table-name (feature-1 INT, feature-2 INT) " + + "ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.OpenCSVSerde' " + + 'WITH SERDEPROPERTIES ("separatorChar" = ",", "quoteChar" = "`", "escapeChar" = "\\\\") ' + + "LOCATION 's3-folder';", + output_location="file/to/path", + kms_key=None, + ) + + +@pytest.mark.parametrize( + "column, expected", + [ + ("feature-1", "feature-1 
STRING"), + ("feature-2", "feature-2 INT"), + ("feature-3", "feature-3 DOUBLE"), + ("feature-4", "feature-4 BOOLEAN"), + ("feature-5", "feature-5 TIMESTAMP"), + ], +) +def test_construct_athena_table_column_string(column, expected, sagemaker_session_mock): + dataframe = pd.DataFrame( + { + "feature-1": ["420"], + "feature-2": [50], + "feature-3": [5.0], + "feature-4": [True], + "feature-5": [pd.Timestamp(1513393355)], + } + ) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + query_string = dataset_builder._construct_athena_table_column_string(column) + assert query_string == expected + + +def test_construct_athena_table_column_string_not_support_column_type( + sagemaker_session_mock, +): + dataframe = pd.DataFrame({"feature": pd.Series([1] * 3, dtype="int8")}) + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=dataframe, + output_path="file/to/path", + ) + with pytest.raises(RuntimeError) as error: + dataset_builder._construct_athena_table_column_string("feature") + assert "The dataframe type int8 is not supported yet." in str(error) + + +def test_run_query_throw_runtime_error(sagemaker_session_mock, feature_group_mock): + dataset_builder = DatasetBuilder( + sagemaker_session=sagemaker_session_mock, + base=feature_group_mock, + output_path="file/to/path", + ) + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query-id"} + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "FAILED"}} + } + with pytest.raises(RuntimeError) as error: + dataset_builder._run_query("query-string", "catalog", "database") + assert "Failed to execute query query-id." in str(error) diff --git a/tests/unit/sagemaker/feature_store/test_feature_group.py b/tests/unit/sagemaker/feature_store/test_feature_group.py new file mode 100644 index 0000000000..dce38fe426 --- /dev/null +++ b/tests/unit/sagemaker/feature_store/test_feature_group.py @@ -0,0 +1,580 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + + +import pandas as pd +import pytest +from mock import Mock, patch, MagicMock +from botocore.exceptions import ProfileNotFound + +from sagemaker.feature_store.feature_definition import ( + FractionalFeatureDefinition, + IntegralFeatureDefinition, + StringFeatureDefinition, + FeatureTypeEnum, +) +from sagemaker.feature_store.feature_group import ( + FeatureGroup, + IngestionManagerPandas, + AthenaQuery, + IngestionError, +) +from sagemaker.feature_store.inputs import FeatureParameter + + +class PicklableMock(Mock): + def __reduce__(self): + return (Mock, ()) + + +@pytest.fixture +def role_arn(): + return "arn:role" + + +@pytest.fixture +def s3_uri(): + return "s3://some/uri" + + +@pytest.fixture +def sagemaker_session_mock(): + return Mock() + + +@pytest.fixture +def fs_runtime_client_config_mock(): + return PicklableMock() + + +@pytest.fixture +def feature_group_dummy_definitions(): + return [ + FractionalFeatureDefinition(feature_name="feature1"), + IntegralFeatureDefinition(feature_name="feature2"), + StringFeatureDefinition(feature_name="feature3"), + ] + + +@pytest.fixture +def create_table_ddl(): + return ( + "CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n" + " feature1 FLOAT\n" + " feature2 INT\n" + " feature3 STRING\n" + " write_time TIMESTAMP\n" + " event_time TIMESTAMP\n" + " is_deleted BOOLEAN\n" + ")\n" + "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n" + " STORED AS\n" + " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n" + " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n" + "LOCATION 's3://resolved_output_s3_uri'" + ) + + +def test_feature_store_create( + sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri +): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + feature_group.create( + s3_uri=s3_uri, + record_identifier_name="feature1", + event_time_feature_name="feature2", + role_arn=role_arn, + enable_online_store=True, + ) + sagemaker_session_mock.create_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], + role_arn=role_arn, + description=None, + tags=None, + online_store_config={"EnableOnlineStore": True}, + offline_store_config={ + "DisableGlueTableCreation": False, + "S3StorageConfig": {"S3Uri": s3_uri}, + }, + ) + + +def test_feature_store_create_online_only( + sagemaker_session_mock, role_arn, feature_group_dummy_definitions +): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + feature_group.create( + s3_uri=False, + record_identifier_name="feature1", + event_time_feature_name="feature2", + role_arn=role_arn, + enable_online_store=True, + ) + sagemaker_session_mock.create_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + record_identifier_name="feature1", + event_time_feature_name="feature2", + feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], + role_arn=role_arn, + description=None, + tags=None, + online_store_config={"EnableOnlineStore": True}, + ) + + +def test_feature_store_delete(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", 
sagemaker_session=sagemaker_session_mock) + feature_group.delete() + sagemaker_session_mock.delete_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup" + ) + + +def test_feature_store_describe(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.describe() + sagemaker_session_mock.describe_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", next_token=None + ) + + +def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.update(feature_group_dummy_definitions) + sagemaker_session_mock.update_feature_group.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions], + ) + + +def test_feature_metadata_update(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + + parameter_additions = [FeatureParameter(key="key1", value="value1")] + parameter_removals = ["key2"] + + feature_group.update_feature_metadata( + feature_name="Feature1", + description="TestDescription", + parameter_additions=parameter_additions, + parameter_removals=parameter_removals, + ) + sagemaker_session_mock.update_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_name="Feature1", + description="TestDescription", + parameter_additions=[pa.to_dict() for pa in parameter_additions], + parameter_removals=parameter_removals, + ) + feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription") + sagemaker_session_mock.update_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", + feature_name="Feature1", + description="TestDescription", + parameter_additions=[], + parameter_removals=[], + ) + + +def test_feature_metadata_describe(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.describe_feature_metadata(feature_name="Feature1") + sagemaker_session_mock.describe_feature_metadata.assert_called_with( + feature_group_name="MyFeatureGroup", feature_name="Feature1" + ) + + +def test_get_record(sagemaker_session_mock): + feature_group_name = "MyFeatureGroup" + feature_names = ["MyFeature1", "MyFeature2"] + record_identifier_value_as_string = "1.0" + feature_group = FeatureGroup(name=feature_group_name, sagemaker_session=sagemaker_session_mock) + feature_group.get_record( + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=feature_names, + ) + sagemaker_session_mock.get_record.assert_called_with( + feature_group_name=feature_group_name, + record_identifier_value_as_string=record_identifier_value_as_string, + feature_names=feature_names, + ) + + +def test_put_record(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + feature_group.put_record(record=[]) + sagemaker_session_mock.put_record.assert_called_with( + feature_group_name="MyFeatureGroup", record=[] + ) + + +def test_delete_record(sagemaker_session_mock): + feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) + record_identifier_value_as_string = "1.0" + event_time = "2022-09-14" + feature_group.delete_record( + 
record_identifier_value_as_string=record_identifier_value_as_string,
+        event_time=event_time,
+    )
+    sagemaker_session_mock.delete_record.assert_called_with(
+        feature_group_name="MyFeatureGroup",
+        record_identifier_value_as_string=record_identifier_value_as_string,
+        event_time=event_time,
+    )
+
+
+def test_load_feature_definition(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock)
+    df = pd.DataFrame(
+        {
+            "float": pd.Series([2.0], dtype="float64"),
+            "int": pd.Series([2], dtype="int64"),
+            "string": pd.Series(["f1"], dtype="string"),
+        }
+    )
+    feature_definitions = feature_group.load_feature_definitions(data_frame=df)
+    names = [fd.feature_name for fd in feature_definitions]
+    types = [fd.feature_type for fd in feature_definitions]
+    assert names == ["float", "int", "string"]
+    assert types == [
+        FeatureTypeEnum.FRACTIONAL,
+        FeatureTypeEnum.INTEGRAL,
+        FeatureTypeEnum.STRING,
+    ]
+
+
+def test_load_feature_definition_unsupported_types(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock)
+    df = pd.DataFrame(
+        {
+            "float": pd.Series([2.0], dtype="float64"),
+            "int": pd.Series([2], dtype="int64"),
+            "bool": pd.Series([True], dtype="bool"),
+        }
+    )
+    with pytest.raises(ValueError) as error:
+        feature_group.load_feature_definitions(data_frame=df)
+    assert "Failed to infer Feature type based on dtype bool for column bool." in str(error)
+
+
+def test_ingest_zero_processes(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+    df = Mock()
+    with pytest.raises(RuntimeError) as error:
+        feature_group.ingest(data_frame=df, max_workers=1, max_processes=0)
+
+    assert "max_processes must be greater than 0." in str(error)
+
+
+def test_ingest_zero_workers(sagemaker_session_mock):
+    feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock)
+    df = Mock()
+    with pytest.raises(RuntimeError) as error:
+        feature_group.ingest(data_frame=df, max_workers=0, max_processes=1)
+
+    assert "max_workers must be greater than 0." 
in str(error) + + +@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") +def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock): + sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( + fs_runtime_client_config_mock + ) + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) + + mock_ingestion_manager_instance = Mock() + ingestion_manager_init.return_value = mock_ingestion_manager_instance + feature_group.ingest(data_frame=df, max_workers=10) + + ingestion_manager_init.assert_called_once_with( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + max_processes=1, + profile_name=None, + ) + mock_ingestion_manager_instance.run.assert_called_once_with( + data_frame=df, wait=True, timeout=None + ) + + +@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") +def test_ingest_with_profile_name( + ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock +): + sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( + fs_runtime_client_config_mock + ) + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) + + mock_ingestion_manager_instance = Mock() + ingestion_manager_init.return_value = mock_ingestion_manager_instance + feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name") + + ingestion_manager_init.assert_called_once_with( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + max_processes=1, + profile_name="profile_name", + ) + mock_ingestion_manager_instance.run.assert_called_once_with( + data_frame=df, wait=True, timeout=None + ) + + +def test_as_hive_ddl_with_default_values( + create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock +): + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://some-bucket", + "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", + } + } + } + sagemaker_session_mock.account_id.return_value = "1234" + sagemaker_session_mock.boto_session.region_name = "us-west-2" + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + assert ( + create_table_ddl.format( + database="sagemaker_featurestore", + table_name="MyGroup", + account="1234", + region="us-west-2", + feature_group_name="MyGroup", + ) + == feature_group.as_hive_ddl() + ) + + +def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock): + sagemaker_session_mock.describe_feature_group.return_value = { + "OfflineStoreConfig": { + "S3StorageConfig": { + "S3Uri": "s3://some-bucket", + "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", + } + } + } + sagemaker_session_mock.account_id.return_value = "1234" + sagemaker_session_mock.boto_session.region_name = "us-west-2" + + feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) + feature_group.feature_definitions = feature_group_dummy_definitions + assert create_table_ddl.format( + database="MyDatabase", + table_name="MyTable", + account="1234", + 
region="us-west-2", + feature_group_name="MyGroup", + ) == feature_group.as_hive_ddl(database="MyDatabase", table_name="MyTable") + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_process", + MagicMock(), +) +def test_ingestion_manager_run_success(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=10, + ) + manager.run(df) + + manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None) + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_threaded", + PicklableMock(return_value=[]), +) +def test_ingestion_manager_run_multi_process_with_multi_thread_success( + fs_runtime_client_config_mock, +): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=2, + max_processes=2, + ) + manager.run(df) + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", + MagicMock(return_value=[1]), +) +def test_ingestion_manager_run_failure(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=1, + ) + + with pytest.raises(IngestionError) as error: + manager.run(df) + + assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) + assert error.value.failed_rows == [1] + assert manager.failed_rows == [1] + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", + MagicMock(side_effect=ProfileNotFound(profile="non_exist")), +) +def test_ingestion_manager_with_profile_name_run_failure(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, + max_workers=1, + profile_name="non_exist", + ) + + try: + manager.run(df) + except Exception as e: + assert "The config profile (non_exist) could not be found" in str(e) + + +@patch( + "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", + PicklableMock(return_value=[1]), +) +def test_ingestion_manager_run_multi_process_failure(): + df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) + manager = IngestionManagerPandas( + feature_group_name="MyGroup", + sagemaker_fs_runtime_client_config=None, + max_workers=2, + max_processes=2, + ) + + with pytest.raises(IngestionError) as error: + manager.run(df) + + assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) + assert error.value.failed_rows == [1, 1, 1, 1] + assert manager.failed_rows == [1, 1, 1, 1] + + +@pytest.fixture +def query(sagemaker_session_mock): + return AthenaQuery( + catalog="catalog", + database="database", + table_name="table_name", + sagemaker_session=sagemaker_session_mock, + ) + + +def test_athena_query_run(sagemaker_session_mock, query): + sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": "query_id"} + query.run( + query_string="query", output_location="s3://some-bucket/some-path", workgroup="workgroup" + ) + sagemaker_session_mock.start_query_execution.assert_called_with( + 
catalog="catalog", + database="database", + query_string="query", + output_location="s3://some-bucket/some-path", + kms_key=None, + workgroup="workgroup", + ) + assert "some-bucket" == query._result_bucket + assert "some-path" == query._result_file_prefix + assert "query_id" == query._current_query_execution_id + + +def test_athena_query_wait(sagemaker_session_mock, query): + query._current_query_execution_id = "query_id" + query.wait() + sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id") + + +def test_athena_query_get_query_execution(sagemaker_session_mock, query): + query._current_query_execution_id = "query_id" + query.get_query_execution() + sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id") + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +@patch("pandas.read_csv") +def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "SUCCEEDED"}} + } + query._current_query_execution_id = "query_id" + query._result_bucket = "bucket" + query._result_file_prefix = "prefix" + query.as_dataframe() + sagemaker_session_mock.download_athena_query_result.assert_called_with( + bucket="bucket", + prefix="prefix", + query_execution_id="query_id", + filename="tmp/query_id.csv", + ) + read_csv.assert_called_with("tmp/query_id.csv", delimiter=",") + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_failed(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "FAILED"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Failed to execute query query_id" in str(error) + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_queued(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "QUEUED"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Current query query_id is still being executed" in str(error) + + +@patch("tempfile.gettempdir", Mock(return_value="tmp")) +def test_athena_query_as_dataframe_query_running(sagemaker_session_mock, query): + sagemaker_session_mock.get_query_execution.return_value = { + "QueryExecution": {"Status": {"State": "RUNNING"}} + } + query._current_query_execution_id = "query_id" + with pytest.raises(RuntimeError) as error: + query.as_dataframe() + assert "Current query query_id is still being executed" in str(error) diff --git a/tests/unit/sagemaker/feature_store/test_feature_store.py b/tests/unit/sagemaker/feature_store/test_feature_store.py index 92ba35573c..073daca9ea 100644 --- a/tests/unit/sagemaker/feature_store/test_feature_store.py +++ b/tests/unit/sagemaker/feature_store/test_feature_store.py @@ -10,46 +10,17 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -# language governing permissions and limitations under the License. 
from __future__ import absolute_import +import datetime import pandas as pd import pytest -from mock import Mock, patch, MagicMock -from botocore.exceptions import ProfileNotFound - -from sagemaker.feature_store.feature_definition import ( - FractionalFeatureDefinition, - IntegralFeatureDefinition, - StringFeatureDefinition, - FeatureTypeEnum, -) -from sagemaker.feature_store.feature_group import ( - FeatureGroup, - IngestionManagerPandas, - AthenaQuery, - IngestionError, -) -from sagemaker.feature_store.inputs import ( - FeatureParameter, - TableFormatEnum, -) - +from mock import Mock -class PicklableMock(Mock): - def __reduce__(self): - return (Mock, ()) +from sagemaker.feature_store.feature_store import FeatureStore - -@pytest.fixture -def role_arn(): - return "arn:role" - - -@pytest.fixture -def s3_uri(): - return "s3://some/uri" +DATAFRAME = pd.DataFrame({"feature_1": [420, 380, 390], "feature_2": [50, 40, 45]}) @pytest.fixture @@ -58,558 +29,108 @@ def sagemaker_session_mock(): @pytest.fixture -def fs_runtime_client_config_mock(): - return PicklableMock() - - -@pytest.fixture -def feature_group_dummy_definitions(): - return [ - FractionalFeatureDefinition(feature_name="feature1"), - IntegralFeatureDefinition(feature_name="feature2"), - StringFeatureDefinition(feature_name="feature3"), - ] - - -@pytest.fixture -def create_table_ddl(): - return ( - "CREATE EXTERNAL TABLE IF NOT EXISTS {database}.{table_name} (\n" - " feature1 FLOAT\n" - " feature2 INT\n" - " feature3 STRING\n" - " write_time TIMESTAMP\n" - " event_time TIMESTAMP\n" - " is_deleted BOOLEAN\n" - ")\n" - "ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'\n" - " STORED AS\n" - " INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat'\n" - " OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat'\n" - "LOCATION 's3://resolved_output_s3_uri'" - ) - - -def test_feature_store_create( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_iceberg_table_format( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - disable_glue_table_creation=False, - table_format=TableFormatEnum.ICEBERG, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - 
feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "TableFormat": "Iceberg", - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_glue_table_format( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions, s3_uri -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=s3_uri, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - disable_glue_table_creation=False, - table_format=TableFormatEnum.GLUE, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - offline_store_config={ - "DisableGlueTableCreation": False, - "TableFormat": "Glue", - "S3StorageConfig": {"S3Uri": s3_uri}, - }, - ) - - -def test_feature_store_create_online_only( - sagemaker_session_mock, role_arn, feature_group_dummy_definitions -): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - feature_group.create( - s3_uri=False, - record_identifier_name="feature1", - event_time_feature_name="feature2", - role_arn=role_arn, - enable_online_store=True, - ) - sagemaker_session_mock.create_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - record_identifier_name="feature1", - event_time_feature_name="feature2", - feature_definitions=[fd.to_dict() for fd in feature_group_dummy_definitions], - role_arn=role_arn, - description=None, - tags=None, - online_store_config={"EnableOnlineStore": True}, - ) - - -def test_feature_store_delete(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.delete() - sagemaker_session_mock.delete_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup" - ) - - -def test_feature_store_describe(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.describe() - sagemaker_session_mock.describe_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", next_token=None - ) - - -def test_feature_store_update(sagemaker_session_mock, feature_group_dummy_definitions): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.update(feature_group_dummy_definitions) - sagemaker_session_mock.update_feature_group.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_additions=[fd.to_dict() for fd in feature_group_dummy_definitions], - ) - - -def test_feature_metadata_update(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - - parameter_additions = [FeatureParameter(key="key1", value="value1")] - parameter_removals = ["key2"] - - feature_group.update_feature_metadata( - 
feature_name="Feature1", - description="TestDescription", - parameter_additions=parameter_additions, - parameter_removals=parameter_removals, - ) - sagemaker_session_mock.update_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_name="Feature1", - description="TestDescription", - parameter_additions=[pa.to_dict() for pa in parameter_additions], - parameter_removals=parameter_removals, - ) - feature_group.update_feature_metadata(feature_name="Feature1", description="TestDescription") - sagemaker_session_mock.update_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", - feature_name="Feature1", - description="TestDescription", - parameter_additions=[], - parameter_removals=[], - ) - - -def test_feature_metadata_describe(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.describe_feature_metadata(feature_name="Feature1") - sagemaker_session_mock.describe_feature_metadata.assert_called_with( - feature_group_name="MyFeatureGroup", feature_name="Feature1" - ) - - -def test_put_record(sagemaker_session_mock): - feature_group = FeatureGroup(name="MyFeatureGroup", sagemaker_session=sagemaker_session_mock) - feature_group.put_record(record=[]) - sagemaker_session_mock.put_record.assert_called_with( - feature_group_name="MyFeatureGroup", record=[] - ) - - -def test_load_feature_definition(sagemaker_session_mock): - feature_group = FeatureGroup(name="SomeGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame( - { - "float": pd.Series([2.0], dtype="float64"), - "int": pd.Series([2], dtype="int64"), - "string": pd.Series(["f1"], dtype="string"), - } - ) - feature_definitions = feature_group.load_feature_definitions(data_frame=df) - names = [fd.feature_name for fd in feature_definitions] - types = [fd.feature_type for fd in feature_definitions] - assert names == ["float", "int", "string"] - assert types == [ - FeatureTypeEnum.FRACTIONAL, - FeatureTypeEnum.INTEGRAL, - FeatureTypeEnum.STRING, - ] +def feature_group_mock(): + return Mock() -def test_load_feature_definition_unsupported_types(sagemaker_session_mock): - feature_group = FeatureGroup(name="FailedGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame( - { - "float": pd.Series([2.0], dtype="float64"), - "int": pd.Series([2], dtype="int64"), - "object": pd.Series(["f1"], dtype="object"), - } - ) +def test_minimal_create_dataset(sagemaker_session_mock, feature_group_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + dataset_builder = feature_store.create_dataset( + base=feature_group_mock, + output_path="file/to/path", + ) + assert dataset_builder._sagemaker_session == sagemaker_session_mock + assert dataset_builder._base == feature_group_mock + assert dataset_builder._output_path == "file/to/path" + + +def test_complete_create_dataset(sagemaker_session_mock, feature_group_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + dataset_builder = feature_store.create_dataset( + base=feature_group_mock, + included_feature_names=["feature_1", "feature_2"], + output_path="file/to/path", + kms_key_id="kms-key-id", + ) + assert dataset_builder._sagemaker_session == sagemaker_session_mock + assert dataset_builder._base == feature_group_mock + assert dataset_builder._included_feature_names == ["feature_1", "feature_2"] + assert dataset_builder._output_path == "file/to/path" + assert dataset_builder._kms_key_id == "kms-key-id" + + 
+def test_create_dataset_with_dataframe(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + dataset_builder = feature_store.create_dataset( + base=DATAFRAME, + record_identifier_feature_name="feature_1", + event_time_identifier_feature_name="feature_2", + included_feature_names=["feature_1", "feature_2"], + output_path="file/to/path", + kms_key_id="kms-key-id", + ) + assert dataset_builder._sagemaker_session == sagemaker_session_mock + assert dataset_builder._base.equals(DATAFRAME) + assert dataset_builder._record_identifier_feature_name == "feature_1" + assert dataset_builder._event_time_identifier_feature_name == "feature_2" + assert dataset_builder._included_feature_names == ["feature_1", "feature_2"] + assert dataset_builder._output_path == "file/to/path" + assert dataset_builder._kms_key_id == "kms-key-id" + + +def test_create_dataset_with_dataframe_value_error(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) with pytest.raises(ValueError) as error: - feature_group.load_feature_definitions(data_frame=df) - assert "Failed to infer Feature type based on dtype object for column object." in str(error) - - -def test_ingest_zero_processes(): - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = Mock() - with pytest.raises(RuntimeError) as error: - feature_group.ingest(data_frame=df, max_workers=1, max_processes=0) - - assert "max_processes must be greater than 0." in str(error) - - -def test_ingest_zero_workers(): - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = Mock() - with pytest.raises(RuntimeError) as error: - feature_group.ingest(data_frame=df, max_workers=0, max_processes=1) - - assert "max_workers must be greater than 0." 
in str(error) - - -@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") -def test_ingest(ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock): - sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( - fs_runtime_client_config_mock - ) - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) - - mock_ingestion_manager_instance = Mock() - ingestion_manager_init.return_value = mock_ingestion_manager_instance - feature_group.ingest(data_frame=df, max_workers=10) - - ingestion_manager_init.assert_called_once_with( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=10, - max_processes=1, - profile_name=None, - ) - mock_ingestion_manager_instance.run.assert_called_once_with( - data_frame=df, wait=True, timeout=None - ) - - -@patch("sagemaker.feature_store.feature_group.IngestionManagerPandas") -def test_ingest_with_profile_name( - ingestion_manager_init, sagemaker_session_mock, fs_runtime_client_config_mock -): - sagemaker_session_mock.sagemaker_featurestore_runtime_client.meta.config = ( - fs_runtime_client_config_mock - ) - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - df = pd.DataFrame(dict((f"float{i}", pd.Series([2.0], dtype="float64")) for i in range(300))) - - mock_ingestion_manager_instance = Mock() - ingestion_manager_init.return_value = mock_ingestion_manager_instance - feature_group.ingest(data_frame=df, max_workers=10, profile_name="profile_name") - - ingestion_manager_init.assert_called_once_with( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=10, - max_processes=1, - profile_name="profile_name", - ) - mock_ingestion_manager_instance.run.assert_called_once_with( - data_frame=df, wait=True, timeout=None - ) - - -def test_as_hive_ddl_with_default_values( - create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock -): - sagemaker_session_mock.describe_feature_group.return_value = { - "OfflineStoreConfig": { - "S3StorageConfig": { - "S3Uri": "s3://some-bucket", - "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", - } - } - } - sagemaker_session_mock.account_id.return_value = "1234" - sagemaker_session_mock.boto_session.region_name = "us-west-2" - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - feature_group.feature_definitions = feature_group_dummy_definitions - assert ( - create_table_ddl.format( - database="sagemaker_featurestore", - table_name="MyGroup", - account="1234", - region="us-west-2", - feature_group_name="MyGroup", + feature_store.create_dataset( + base=DATAFRAME, + included_feature_names=["feature_1", "feature_2"], + output_path="file/to/path", + kms_key_id="kms-key-id", ) - == feature_group.as_hive_ddl() - ) - - -def test_as_hive_ddl(create_table_ddl, feature_group_dummy_definitions, sagemaker_session_mock): - sagemaker_session_mock.describe_feature_group.return_value = { - "OfflineStoreConfig": { - "S3StorageConfig": { - "S3Uri": "s3://some-bucket", - "ResolvedOutputS3Uri": "s3://resolved_output_s3_uri", - } - } - } - sagemaker_session_mock.account_id.return_value = "1234" - sagemaker_session_mock.boto_session.region_name = "us-west-2" - - feature_group = FeatureGroup(name="MyGroup", sagemaker_session=sagemaker_session_mock) - 
feature_group.feature_definitions = feature_group_dummy_definitions - assert create_table_ddl.format( - database="MyDatabase", - table_name="MyTable", - account="1234", - region="us-west-2", - feature_group_name="MyGroup", - ) == feature_group.as_hive_ddl(database="MyDatabase", table_name="MyTable") - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_process", - MagicMock(), -) -def test_ingestion_manager_run_success(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=10, - ) - manager.run(df) - - manager._run_multi_process.assert_called_once_with(data_frame=df, wait=True, timeout=None) - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._run_multi_threaded", - PicklableMock(return_value=[]), -) -def test_ingestion_manager_run_multi_process_with_multi_thread_success( - fs_runtime_client_config_mock, -): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=2, - max_processes=2, - ) - manager.run(df) - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", - MagicMock(return_value=[1]), -) -def test_ingestion_manager_run_failure(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=1, - ) - - with pytest.raises(IngestionError) as error: - manager.run(df) - - assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) - assert error.value.failed_rows == [1] - assert manager.failed_rows == [1] - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", - MagicMock(side_effect=ProfileNotFound(profile="non_exist")), -) -def test_ingestion_manager_with_profile_name_run_failure(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=fs_runtime_client_config_mock, - max_workers=1, - profile_name="non_exist", - ) - - try: - manager.run(df) - except Exception as e: - assert "The config profile (non_exist) could not be found" in str(e) - - -@patch( - "sagemaker.feature_store.feature_group.IngestionManagerPandas._ingest_single_batch", - PicklableMock(return_value=[1]), -) -def test_ingestion_manager_run_multi_process_failure(): - df = pd.DataFrame({"float": pd.Series([2.0], dtype="float64")}) - manager = IngestionManagerPandas( - feature_group_name="MyGroup", - sagemaker_fs_runtime_client_config=None, - max_workers=2, - max_processes=2, - ) - - with pytest.raises(IngestionError) as error: - manager.run(df) - - assert "Failed to ingest some data into FeatureGroup MyGroup" in str(error) - assert error.value.failed_rows == [1, 1, 1, 1] - assert manager.failed_rows == [1, 1, 1, 1] - - -@pytest.fixture -def query(sagemaker_session_mock): - return AthenaQuery( - catalog="catalog", - database="database", - table_name="table_name", - sagemaker_session=sagemaker_session_mock, - ) - - -def test_athena_query_run(sagemaker_session_mock, query): - WORKGROUP = "workgroup" - sagemaker_session_mock.start_query_execution.return_value = {"QueryExecutionId": 
"query_id"} - query.run( - query_string="query", output_location="s3://some-bucket/some-path", workgroup=WORKGROUP - ) - sagemaker_session_mock.start_query_execution.assert_called_with( - catalog="catalog", - database="database", - query_string="query", - output_location="s3://some-bucket/some-path", - kms_key=None, - workgroup=WORKGROUP, - ) - assert "some-bucket" == query._result_bucket - assert "some-path" == query._result_file_prefix - assert "query_id" == query._current_query_execution_id - - -def test_athena_query_wait(sagemaker_session_mock, query): - query._current_query_execution_id = "query_id" - query.wait() - sagemaker_session_mock.wait_for_athena_query.assert_called_with(query_execution_id="query_id") - - -def test_athena_query_get_query_execution(sagemaker_session_mock, query): - query._current_query_execution_id = "query_id" - query.get_query_execution() - sagemaker_session_mock.get_query_execution.assert_called_with(query_execution_id="query_id") - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -@patch("pandas.read_csv") -def test_athena_query_as_dataframe(read_csv, sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "SUCCEEDED"}} - } - query._current_query_execution_id = "query_id" - query._result_bucket = "bucket" - query._result_file_prefix = "prefix" - query.as_dataframe() - sagemaker_session_mock.download_athena_query_result.assert_called_with( - bucket="bucket", - prefix="prefix", - query_execution_id="query_id", - filename="tmp/query_id.csv", + assert ( + "You must provide a record identifier feature name and an event time identifier feature " + + "name if specify DataFrame as base." + in str(error) + ) + + +def test_list_feature_groups_with_no_filter(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + feature_store.list_feature_groups() + sagemaker_session_mock.list_feature_groups.assert_called_with( + name_contains=None, + feature_group_status_equals=None, + offline_store_status_equals=None, + creation_time_after=None, + creation_time_before=None, + sort_order=None, + sort_by=None, + max_results=None, + next_token=None, + ) + + +def test_list_feature_groups_with_all_filters(sagemaker_session_mock): + feature_store = FeatureStore(sagemaker_session=sagemaker_session_mock) + feature_store.list_feature_groups( + name_contains="MyFeatureGroup", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after=datetime.datetime(2020, 12, 1), + creation_time_before=datetime.datetime(2022, 7, 1), + sort_order="Ascending", + sort_by="Name", + max_results=50, + next_token="token", + ) + sagemaker_session_mock.list_feature_groups.assert_called_with( + name_contains="MyFeatureGroup", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after=datetime.datetime(2020, 12, 1), + creation_time_before=datetime.datetime(2022, 7, 1), + sort_order="Ascending", + sort_by="Name", + max_results=50, + next_token="token", ) - read_csv.assert_called_with("tmp/query_id.csv", delimiter=",") - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_failed(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "FAILED"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert 
"Failed to execute query query_id" in str(error) - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_queued(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "QUEUED"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert "Current query query_id is still being executed" in str(error) - - -@patch("tempfile.gettempdir", Mock(return_value="tmp")) -def test_athena_query_as_dataframe_query_running(sagemaker_session_mock, query): - sagemaker_session_mock.get_query_execution.return_value = { - "QueryExecution": {"Status": {"State": "RUNNING"}} - } - query._current_query_execution_id = "query_id" - with pytest.raises(RuntimeError) as error: - query.as_dataframe() - assert "Current query query_id is still being executed" in str(error) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index bf81283177..d7c94470f5 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -2787,6 +2787,35 @@ def test_feature_metadata_describe(sagemaker_session): ) +def test_list_feature_groups(sagemaker_session): + expected_list_feature_groups_args = { + "NameContains": "MyFeatureGroup", + "FeatureGroupStatusEquals": "Created", + "OfflineStoreStatusEquals": "Active", + "CreationTimeAfter": datetime.datetime(2020, 12, 1), + "CreationTimeBefore": datetime.datetime(2022, 7, 1), + "SortOrder": "Ascending", + "SortBy": "Name", + "MaxResults": 50, + "NextToken": "token", + } + sagemaker_session.list_feature_groups( + name_contains="MyFeatureGroup", + feature_group_status_equals="Created", + offline_store_status_equals="Active", + creation_time_after=datetime.datetime(2020, 12, 1), + creation_time_before=datetime.datetime(2022, 7, 1), + sort_order="Ascending", + sort_by="Name", + max_results=50, + next_token="token", + ) + assert sagemaker_session.sagemaker_client.list_feature_groups.called_once() + assert sagemaker_session.sagemaker_client.list_feature_groups.called_with( + **expected_list_feature_groups_args + ) + + def test_start_query_execution(sagemaker_session): athena_mock = Mock() sagemaker_session.boto_session.client( From fb3880f804854d8456682c4aa17de321cb5a89f9 Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 14 Dec 2022 03:40:14 +0000 Subject: [PATCH 42/43] prepare release v2.122.0 --- CHANGELOG.md | 13 +++++++++++++ VERSION | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b66e85f54..de20a8a0df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # Changelog +## v2.122.0 (2022-12-14) + +### Features + + * Feature Store dataset builder, delete_record, get_record, list_feature_group + * Add OSU region to frameworks for DLC + +### Bug Fixes and Other Changes + + * the Hyperband support fix for the HPO + * unpin packaging version + * Remove content type image/jpg from analysis configuration schema + ## v2.121.2 (2022-12-12) ### Bug Fixes and Other Changes diff --git a/VERSION b/VERSION index 8fde5e282f..202f672bab 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -2.121.3.dev0 +2.122.0 From a584ea5ff73ea5b6df8eec749069ec86adf2e8fc Mon Sep 17 00:00:00 2001 From: ci Date: Wed, 14 Dec 2022 03:40:15 +0000 Subject: [PATCH 43/43] update development version to v2.122.1.dev0 --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 202f672bab..6d7f044fa2 100644 --- a/VERSION +++ 
b/VERSION @@ -1 +1 @@ -2.122.0 +2.122.1.dev0