From 9f9238a6878aaddc997c34ba20e89c9acc9e7dbb Mon Sep 17 00:00:00 2001 From: Molly He Date: Thu, 1 May 2025 11:52:12 -0700 Subject: [PATCH 01/12] add extract_name_mapping logic to fix get_all method --- src/sagemaker_core/main/resources.py | 157 ++++++++++++++++++ src/sagemaker_core/main/utils.py | 8 + src/sagemaker_core/tools/resources_codegen.py | 31 +++- src/sagemaker_core/tools/templates.py | 2 + 4 files changed, 195 insertions(+), 3 deletions(-) diff --git a/src/sagemaker_core/main/resources.py b/src/sagemaker_core/main/resources.py index c29f1199..6758eaab 100644 --- a/src/sagemaker_core/main/resources.py +++ b/src/sagemaker_core/main/resources.py @@ -6471,6 +6471,7 @@ def get_all( "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -12609,6 +12610,84 @@ def load( region=region, ) + @classmethod + @Base.add_validate_call + def get_all( + cls, + hub_name: str, + hub_content_type: str, + name_contains: Optional[str] = Unassigned(), + max_schema_version: Optional[str] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["HubContent"]: + """ + Get all HubContent resources + + Parameters: + hub_name: The name of the hub to list the contents of. + hub_content_type: The type of hub content to list. + name_contains: Only list hub content if the name contains the specified string. + max_schema_version: The upper bound of the hub content schema verion. + creation_time_before: Only list hub content that was created before the time specified. + creation_time_after: Only list hub content that was created after the time specified. + sort_by: Sort hub content versions by either name or creation time. + sort_order: Sort hubs by ascending or descending order. + max_results: The maximum amount of hub content to list. + next_token: If the response to a previous ListHubContents request was truncated, the response includes a NextToken. To retrieve the next set of hub content, use the token in the next request. + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed HubContent resources. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. + """ + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "HubName": hub_name, + "HubContentType": hub_content_type, + "NameContains": name_contains, + "MaxSchemaVersion": max_schema_version, + "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, + "SortBy": sort_by, + "SortOrder": sort_order, + } + + extract_name_mapping = {"HubContentArn": ["hub-content/", "HubName"]} + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + return ResourceIterator( + client=client, + list_method="list_hub_contents", + summaries_key="HubContentSummaries", + summary_name="HubContentInfo", + resource_cls=HubContent, + extract_name_mapping=extract_name_mapping, + list_method_kwargs=operation_input_args, + ) + @Base.add_validate_call def get_all_versions( self, @@ -14911,6 +14990,81 @@ def wait_for_delete( raise e time.sleep(poll) + @classmethod + @Base.add_validate_call + def get_all( + cls, + image_name: str, + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["ImageVersion"]: + """ + Get all ImageVersion resources + + Parameters: + image_name: The name of the image to list the versions of. + creation_time_after: A filter that returns only versions created on or after the specified time. + creation_time_before: A filter that returns only versions created on or before the specified time. + last_modified_time_after: A filter that returns only versions modified on or after the specified time. + last_modified_time_before: A filter that returns only versions modified on or before the specified time. + max_results: The maximum number of versions to return in the response. The default value is 10. + next_token: If the previous call to ListImageVersions didn't return the full set of versions, the call returns a token for getting the next set of versions. + sort_by: The property used to sort results. The default value is CREATION_TIME. + sort_order: The sort order. The default value is DESCENDING. + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed ImageVersion resources. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. + """ + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "ImageName": image_name, + "LastModifiedTimeAfter": last_modified_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + "SortBy": sort_by, + "SortOrder": sort_order, + } + + extract_name_mapping = {"ImageVersionArn": ["image-version/", "ImageName"]} + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + return ResourceIterator( + client=client, + list_method="list_image_versions", + summaries_key="ImageVersions", + summary_name="ImageVersion", + resource_cls=ImageVersion, + extract_name_mapping=extract_name_mapping, + list_method_kwargs=operation_input_args, + ) + class InferenceComponent(Base): """ @@ -18504,6 +18658,7 @@ def get_all( "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -19733,6 +19888,7 @@ def get_all( "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -21350,6 +21506,7 @@ def get_all( "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") diff --git a/src/sagemaker_core/main/utils.py b/src/sagemaker_core/main/utils.py index b3a54ff9..2c7d4359 100644 --- a/src/sagemaker_core/main/utils.py +++ b/src/sagemaker_core/main/utils.py @@ -387,6 +387,7 @@ def __init__( list_method: str, list_method_kwargs: dict = {}, custom_key_mapping: dict = None, + extract_name_mapping: dict = None, ): """Initialize a ResourceIterator object @@ -398,6 +399,7 @@ def __init__( list_method (str): The list method string used to make list calls to the client. list_method_kwargs (dict, optional): The kwargs used to make list method calls. Defaults to {}. custom_key_mapping (dict, optional): The custom key mapping used to map keys from summary object to those expected from resource object during initialization. Defaults to None. + extract_name_mapping (dict, optional): The extract name mapping used to extract names from arn in summary object and map to those expected from resource object during initialization. Defaults to None. """ self.summaries_key = summaries_key self.summary_name = summary_name @@ -405,6 +407,7 @@ def __init__( self.list_method = list_method self.list_method_kwargs = list_method_kwargs self.custom_key_mapping = custom_key_mapping + self.extract_name_mapping = extract_name_mapping self.resource_cls = resource_cls self.index = 0 @@ -433,6 +436,11 @@ def __next__(self) -> T: if self.custom_key_mapping: init_data = {self.custom_key_mapping.get(k, k): v for k, v in init_data.items()} + if self.extract_name_mapping: + for arn, target in self.extract_name_mapping.items(): + name = arn.split(target[0])[1].split("/")[0] + init_data.update({target[1]: name}) + # Filter out the fields that are not in the resource class fields = self.resource_cls.__annotations__ init_data = {k: v for k, v in init_data.items() if k in fields} diff --git a/src/sagemaker_core/tools/resources_codegen.py b/src/sagemaker_core/tools/resources_codegen.py index 523ef7cf..a8af5757 100644 --- a/src/sagemaker_core/tools/resources_codegen.py +++ b/src/sagemaker_core/tools/resources_codegen.py @@ -1754,14 +1754,25 @@ def generate_get_all_method(self, resource_name: str) -> str: get_operation_required_input = [] custom_key_mapping_str = "" + custom_key_mapping = {} + extract_name_mapping_str = "" + extract_name_mapping = {} if any(member not in summary_members for member in get_operation_required_input): - if "MonitoringJobDefinitionSummary" == summary_name: + if summary_name == "MonitoringJobDefinitionSummary": custom_key_mapping = { "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } - custom_key_mapping_str = f"custom_key_mapping = {json.dumps(custom_key_mapping)}" - custom_key_mapping_str = add_indent(custom_key_mapping_str, 4) + elif summary_name == "HubContentInfo": + # HubContentArn -- arn::sagemaker:::hub-content//// + extract_name_mapping = { + "HubContentArn": ("hub-content/", "HubName"), + } + elif summary_name == "ImageVersion": + # ImageVersionArn -- arn:aws:sagemaker:::image-version// + extract_name_mapping = { + "ImageVersionArn": ("image-version/", "ImageName"), + } else: log.warning( f"Resource {resource_name} summaries do not have required members to create object instance. Resource may require custom key mapping for get_all().\n" @@ -1769,6 +1780,15 @@ def generate_get_all_method(self, resource_name: str) -> str: ) return "" + if custom_key_mapping: + custom_key_mapping_str = add_indent( + f"custom_key_mapping = {json.dumps(custom_key_mapping)}", 4 + ) + if extract_name_mapping: + extract_name_mapping_str = add_indent( + f"extract_name_mapping = {json.dumps(extract_name_mapping)}", 4 + ) + resource_iterator_args_list = [ "client=client", f"list_method='{operation}'", @@ -1780,6 +1800,9 @@ def generate_get_all_method(self, resource_name: str) -> str: if custom_key_mapping_str: resource_iterator_args_list.append(f"custom_key_mapping=custom_key_mapping") + if extract_name_mapping_str: + resource_iterator_args_list.append(f"extract_name_mapping=extract_name_mapping") + exclude_list = ["next_token", "max_results"] get_all_args = self._generate_method_args(operation_input_shape_name, exclude_list) @@ -1792,6 +1815,7 @@ def generate_get_all_method(self, resource_name: str) -> str: resource=resource_name, operation=operation, custom_key_mapping=custom_key_mapping_str, + extract_name_mapping=extract_name_mapping_str, resource_iterator_args=resource_iterator_args, ) return formatted_method @@ -1822,6 +1846,7 @@ def generate_get_all_method(self, resource_name: str) -> str: get_all_args=get_all_args, operation_input_args=operation_input_args, custom_key_mapping=custom_key_mapping_str, + extract_name_mapping=extract_name_mapping_str, resource_iterator_args=resource_iterator_args, ) return formatted_method diff --git a/src/sagemaker_core/tools/templates.py b/src/sagemaker_core/tools/templates.py index dee63ce9..c8647456 100644 --- a/src/sagemaker_core/tools/templates.py +++ b/src/sagemaker_core/tools/templates.py @@ -513,6 +513,7 @@ def get_all( {operation_input_args} }} {custom_key_mapping} +{extract_name_mapping} # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") @@ -543,6 +544,7 @@ def get_all( """ client = Base.get_sagemaker_client(session=session, region_name=region, service_name="{service_name}") {custom_key_mapping} +{extract_name_mapping} return ResourceIterator( {resource_iterator_args} ) From b5d8d04dcac912cf37dc3d3664fb15374af81d1f Mon Sep 17 00:00:00 2001 From: Molly He Date: Thu, 1 May 2025 12:13:12 -0700 Subject: [PATCH 02/12] switch sequence between custom_key_mapping and extract_name_mapping --- src/sagemaker_core/main/resources.py | 12 ++++++------ src/sagemaker_core/tools/templates.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/sagemaker_core/main/resources.py b/src/sagemaker_core/main/resources.py index 6758eaab..f9d9df28 100644 --- a/src/sagemaker_core/main/resources.py +++ b/src/sagemaker_core/main/resources.py @@ -6467,11 +6467,11 @@ def get_all( "CreationTimeBefore": creation_time_before, "CreationTimeAfter": creation_time_after, } + custom_key_mapping = { "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -12672,8 +12672,8 @@ def get_all( "SortBy": sort_by, "SortOrder": sort_order, } - extract_name_mapping = {"HubContentArn": ["hub-content/", "HubName"]} + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -15049,8 +15049,8 @@ def get_all( "SortBy": sort_by, "SortOrder": sort_order, } - extract_name_mapping = {"ImageVersionArn": ["image-version/", "ImageName"]} + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -18654,11 +18654,11 @@ def get_all( "CreationTimeBefore": creation_time_before, "CreationTimeAfter": creation_time_after, } + custom_key_mapping = { "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -19884,11 +19884,11 @@ def get_all( "CreationTimeBefore": creation_time_before, "CreationTimeAfter": creation_time_after, } + custom_key_mapping = { "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -21502,11 +21502,11 @@ def get_all( "CreationTimeBefore": creation_time_before, "CreationTimeAfter": creation_time_after, } + custom_key_mapping = { "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") diff --git a/src/sagemaker_core/tools/templates.py b/src/sagemaker_core/tools/templates.py index c8647456..9de38e76 100644 --- a/src/sagemaker_core/tools/templates.py +++ b/src/sagemaker_core/tools/templates.py @@ -512,8 +512,8 @@ def get_all( operation_input_args = {{ {operation_input_args} }} -{custom_key_mapping} {extract_name_mapping} +{custom_key_mapping} # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") @@ -543,8 +543,8 @@ def get_all( """ client = Base.get_sagemaker_client(session=session, region_name=region, service_name="{service_name}") -{custom_key_mapping} {extract_name_mapping} +{custom_key_mapping} return ResourceIterator( {resource_iterator_args} ) From ced91c59325259447ce28e6b67b0ab35b952f581 Mon Sep 17 00:00:00 2001 From: Molly He Date: Thu, 1 May 2025 13:49:52 -0700 Subject: [PATCH 03/12] fixed minor bugs discovered in testing --- src/sagemaker_core/main/resources.py | 8 ++++++-- src/sagemaker_core/main/utils.py | 3 ++- src/sagemaker_core/tools/resources_codegen.py | 6 ++++-- src/sagemaker_core/tools/templates.py | 3 +++ 4 files changed, 15 insertions(+), 5 deletions(-) diff --git a/src/sagemaker_core/main/resources.py b/src/sagemaker_core/main/resources.py index f9d9df28..c015405e 100644 --- a/src/sagemaker_core/main/resources.py +++ b/src/sagemaker_core/main/resources.py @@ -6472,6 +6472,7 @@ def get_all( "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -12672,7 +12673,7 @@ def get_all( "SortBy": sort_by, "SortOrder": sort_order, } - extract_name_mapping = {"HubContentArn": ["hub-content/", "HubName"]} + extract_name_mapping = {"hub_content_arn": ["hub-content/", "hub_name"]} # serialize the input request operation_input_args = serialize(operation_input_args) @@ -15049,7 +15050,7 @@ def get_all( "SortBy": sort_by, "SortOrder": sort_order, } - extract_name_mapping = {"ImageVersionArn": ["image-version/", "ImageName"]} + extract_name_mapping = {"image_version_arn": ["image-version/", "image_name"]} # serialize the input request operation_input_args = serialize(operation_input_args) @@ -18659,6 +18660,7 @@ def get_all( "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -19889,6 +19891,7 @@ def get_all( "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -21507,6 +21510,7 @@ def get_all( "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") diff --git a/src/sagemaker_core/main/utils.py b/src/sagemaker_core/main/utils.py index 2c7d4359..39176128 100644 --- a/src/sagemaker_core/main/utils.py +++ b/src/sagemaker_core/main/utils.py @@ -436,9 +436,10 @@ def __next__(self) -> T: if self.custom_key_mapping: init_data = {self.custom_key_mapping.get(k, k): v for k, v in init_data.items()} + # Extract name from arn. Currently implemented for HubContent and ImageVersion if self.extract_name_mapping: for arn, target in self.extract_name_mapping.items(): - name = arn.split(target[0])[1].split("/")[0] + name = init_data[arn].split(target[0])[1].split("/")[0] init_data.update({target[1]: name}) # Filter out the fields that are not in the resource class diff --git a/src/sagemaker_core/tools/resources_codegen.py b/src/sagemaker_core/tools/resources_codegen.py index a8af5757..14bf46cc 100644 --- a/src/sagemaker_core/tools/resources_codegen.py +++ b/src/sagemaker_core/tools/resources_codegen.py @@ -1765,13 +1765,15 @@ def generate_get_all_method(self, resource_name: str) -> str: } elif summary_name == "HubContentInfo": # HubContentArn -- arn::sagemaker:::hub-content//// + # {source key from input: (target label in arn, target key in output)} extract_name_mapping = { - "HubContentArn": ("hub-content/", "HubName"), + "hub_content_arn": ("hub-content/", "hub_name"), } elif summary_name == "ImageVersion": # ImageVersionArn -- arn:aws:sagemaker:::image-version// + # {source key from input: (target label in arn, target key in output)} extract_name_mapping = { - "ImageVersionArn": ("image-version/", "ImageName"), + "image_version_arn": ("image-version/", "image_name"), } else: log.warning( diff --git a/src/sagemaker_core/tools/templates.py b/src/sagemaker_core/tools/templates.py index 9de38e76..d93b17ec 100644 --- a/src/sagemaker_core/tools/templates.py +++ b/src/sagemaker_core/tools/templates.py @@ -514,6 +514,7 @@ def get_all( }} {extract_name_mapping} {custom_key_mapping} + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") @@ -544,7 +545,9 @@ def get_all( """ client = Base.get_sagemaker_client(session=session, region_name=region, service_name="{service_name}") {extract_name_mapping} + {custom_key_mapping} + return ResourceIterator( {resource_iterator_args} ) From 05ec59a072e36ee439a993dd82701f94ff1dbc0f Mon Sep 17 00:00:00 2001 From: Molly He Date: Thu, 1 May 2025 13:53:06 -0700 Subject: [PATCH 04/12] remove extra white lines in templates.py --- src/sagemaker_core/tools/templates.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/sagemaker_core/tools/templates.py b/src/sagemaker_core/tools/templates.py index d93b17ec..9de38e76 100644 --- a/src/sagemaker_core/tools/templates.py +++ b/src/sagemaker_core/tools/templates.py @@ -514,7 +514,6 @@ def get_all( }} {extract_name_mapping} {custom_key_mapping} - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {{operation_input_args}}") @@ -545,9 +544,7 @@ def get_all( """ client = Base.get_sagemaker_client(session=session, region_name=region, service_name="{service_name}") {extract_name_mapping} - {custom_key_mapping} - return ResourceIterator( {resource_iterator_args} ) From 0aab9810cb6538c766dcc51a0b207cf263d376f4 Mon Sep 17 00:00:00 2001 From: Molly He Date: Thu, 1 May 2025 14:05:32 -0700 Subject: [PATCH 05/12] add extra line in unit test to align with new template --- tst/tools/test_resources_codegen.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tst/tools/test_resources_codegen.py b/tst/tools/test_resources_codegen.py index fa412420..7a8b45c6 100644 --- a/tst/tools/test_resources_codegen.py +++ b/tst/tools/test_resources_codegen.py @@ -1122,6 +1122,7 @@ def get_all( 'SpaceNameEquals': space_name_equals, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -1159,6 +1160,7 @@ def get_all( """ client = Base.get_sagemaker_client(session=session, region_name=region, service_name="sagemaker") + return ResourceIterator( client=client, list_method='list_domains', @@ -1224,6 +1226,7 @@ def get_all( 'CreationTimeBefore': creation_time_before, 'CreationTimeAfter': creation_time_after, } + custom_key_mapping = {"monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn"} # serialize the input request operation_input_args = serialize(operation_input_args) From c8ce5a3c2a5308392817c260b9a4dde40f3e4c96 Mon Sep 17 00:00:00 2001 From: Molly He Date: Thu, 1 May 2025 14:51:04 -0700 Subject: [PATCH 06/12] fix extra line unit test failure --- tst/tools/test_resources_codegen.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tst/tools/test_resources_codegen.py b/tst/tools/test_resources_codegen.py index 7a8b45c6..982211dc 100644 --- a/tst/tools/test_resources_codegen.py +++ b/tst/tools/test_resources_codegen.py @@ -1122,7 +1122,8 @@ def get_all( 'SpaceNameEquals': space_name_equals, } - + + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -1160,7 +1161,9 @@ def get_all( """ client = Base.get_sagemaker_client(session=session, region_name=region, service_name="sagemaker") - + + + return ResourceIterator( client=client, list_method='list_domains', @@ -1228,6 +1231,7 @@ def get_all( } custom_key_mapping = {"monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn"} + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") From 11f4f4211fcc5b9df670c39a6ff5d6320971f13b Mon Sep 17 00:00:00 2001 From: Molly He Date: Wed, 7 May 2025 12:04:13 -0700 Subject: [PATCH 07/12] test again without extra lines --- tst/tools/test_resources_codegen.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tst/tools/test_resources_codegen.py b/tst/tools/test_resources_codegen.py index 982211dc..87051e43 100644 --- a/tst/tools/test_resources_codegen.py +++ b/tst/tools/test_resources_codegen.py @@ -1123,7 +1123,6 @@ def get_all( } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -1162,8 +1161,6 @@ def get_all( client = Base.get_sagemaker_client(session=session, region_name=region, service_name="sagemaker") - - return ResourceIterator( client=client, list_method='list_domains', From b2e85e9cd246b75723c3b83ee44d06389561708f Mon Sep 17 00:00:00 2001 From: Molly He Date: Wed, 7 May 2025 12:09:01 -0700 Subject: [PATCH 08/12] test again without extra lines --- src/sagemaker_core/main/resources.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker_core/main/resources.py b/src/sagemaker_core/main/resources.py index c015405e..ef55dd35 100644 --- a/src/sagemaker_core/main/resources.py +++ b/src/sagemaker_core/main/resources.py @@ -1421,7 +1421,6 @@ def get_all( "UserProfileNameEquals": user_profile_name_equals, "SpaceNameEquals": space_name_equals, } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") From 308331b33157661f8f09321739b3b5105adf0b9d Mon Sep 17 00:00:00 2001 From: Molly He Date: Wed, 7 May 2025 12:11:35 -0700 Subject: [PATCH 09/12] update test_resources_codegen --- src/sagemaker_core/main/resources.py | 1 + tst/tools/test_resources_codegen.py | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker_core/main/resources.py b/src/sagemaker_core/main/resources.py index ef55dd35..c015405e 100644 --- a/src/sagemaker_core/main/resources.py +++ b/src/sagemaker_core/main/resources.py @@ -1421,6 +1421,7 @@ def get_all( "UserProfileNameEquals": user_profile_name_equals, "SpaceNameEquals": space_name_equals, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") diff --git a/tst/tools/test_resources_codegen.py b/tst/tools/test_resources_codegen.py index 87051e43..129859e2 100644 --- a/tst/tools/test_resources_codegen.py +++ b/tst/tools/test_resources_codegen.py @@ -1122,7 +1122,6 @@ def get_all( 'SpaceNameEquals': space_name_equals, } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") From 3b5d4d7769b5ac8614d01423f43d60a22be909c9 Mon Sep 17 00:00:00 2001 From: Molly He Date: Wed, 7 May 2025 12:13:57 -0700 Subject: [PATCH 10/12] update test_resources_codegen --- tst/tools/test_resources_codegen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tst/tools/test_resources_codegen.py b/tst/tools/test_resources_codegen.py index 129859e2..9e81da26 100644 --- a/tst/tools/test_resources_codegen.py +++ b/tst/tools/test_resources_codegen.py @@ -1122,6 +1122,7 @@ def get_all( 'SpaceNameEquals': space_name_equals, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -1227,7 +1228,6 @@ def get_all( } custom_key_mapping = {"monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn"} - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") From f7bcdc63f331463aecea3feb2d24bb5ee8fe88a5 Mon Sep 17 00:00:00 2001 From: Molly He Date: Wed, 7 May 2025 15:08:06 -0700 Subject: [PATCH 11/12] add extract_name_mapping logic to unit test: --- tst/generated/test_resources.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tst/generated/test_resources.py b/tst/generated/test_resources.py index 081247ec..50f6b705 100644 --- a/tst/generated/test_resources.py +++ b/tst/generated/test_resources.py @@ -114,6 +114,22 @@ def test_resources(self, session, mock_transform): f"{name}s": [summary], f"Summaries": [summary], } + if name == "HubContent": + extract_name_mapping = { + "HubContentArn": ("hub-content/", "HubName"), + } + summary.update({"HubContentArn": "arn:aws:sagemaker:us-west-2:123456789012:hub-content/my-hub/Model/my-model/1.0"}) + for arn, target in extract_name_mapping.items(): + extracted_name = summary[arn].split(target[0])[1].split("/")[0] + summary.update({target[1]: extracted_name}) + if name == "ImageVersion": + extract_name_mapping = { + "ImageVersionArn": ("image-version/", "ImageName"), + } + summary.update({"ImageVersionArn": "arn:aws:sagemaker:us-west-2:123456789012:image-version/my-image/Model/my-model/1.0"}) + for arn, target in extract_name_mapping.items(): + extracted_name = summary[arn].split(target[0])[1].split("/")[0] + summary.update({target[1]: extracted_name}) if name == "MlflowTrackingServer": summary_response = {"TrackingServerSummaries": [summary]} with patch.object( From 706f447517041be15c8d30286bb0d91d49d0144f Mon Sep 17 00:00:00 2001 From: Molly He Date: Wed, 7 May 2025 15:12:44 -0700 Subject: [PATCH 12/12] fix codestyle with black . --- tst/generated/test_resources.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tst/generated/test_resources.py b/tst/generated/test_resources.py index 50f6b705..19125299 100644 --- a/tst/generated/test_resources.py +++ b/tst/generated/test_resources.py @@ -118,7 +118,11 @@ def test_resources(self, session, mock_transform): extract_name_mapping = { "HubContentArn": ("hub-content/", "HubName"), } - summary.update({"HubContentArn": "arn:aws:sagemaker:us-west-2:123456789012:hub-content/my-hub/Model/my-model/1.0"}) + summary.update( + { + "HubContentArn": "arn:aws:sagemaker:us-west-2:123456789012:hub-content/my-hub/Model/my-model/1.0" + } + ) for arn, target in extract_name_mapping.items(): extracted_name = summary[arn].split(target[0])[1].split("/")[0] summary.update({target[1]: extracted_name}) @@ -126,7 +130,11 @@ def test_resources(self, session, mock_transform): extract_name_mapping = { "ImageVersionArn": ("image-version/", "ImageName"), } - summary.update({"ImageVersionArn": "arn:aws:sagemaker:us-west-2:123456789012:image-version/my-image/Model/my-model/1.0"}) + summary.update( + { + "ImageVersionArn": "arn:aws:sagemaker:us-west-2:123456789012:image-version/my-image/Model/my-model/1.0" + } + ) for arn, target in extract_name_mapping.items(): extracted_name = summary[arn].split(target[0])[1].split("/")[0] summary.update({target[1]: extracted_name})