diff --git a/src/sagemaker_core/main/resources.py b/src/sagemaker_core/main/resources.py index c29f1199..c015405e 100644 --- a/src/sagemaker_core/main/resources.py +++ b/src/sagemaker_core/main/resources.py @@ -6467,10 +6467,12 @@ def get_all( "CreationTimeBefore": creation_time_before, "CreationTimeAfter": creation_time_after, } + custom_key_mapping = { "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -12609,6 +12611,84 @@ def load( region=region, ) + @classmethod + @Base.add_validate_call + def get_all( + cls, + hub_name: str, + hub_content_type: str, + name_contains: Optional[str] = Unassigned(), + max_schema_version: Optional[str] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["HubContent"]: + """ + Get all HubContent resources + + Parameters: + hub_name: The name of the hub to list the contents of. + hub_content_type: The type of hub content to list. + name_contains: Only list hub content if the name contains the specified string. + max_schema_version: The upper bound of the hub content schema verion. + creation_time_before: Only list hub content that was created before the time specified. + creation_time_after: Only list hub content that was created after the time specified. + sort_by: Sort hub content versions by either name or creation time. + sort_order: Sort hubs by ascending or descending order. + max_results: The maximum amount of hub content to list. + next_token: If the response to a previous ListHubContents request was truncated, the response includes a NextToken. To retrieve the next set of hub content, use the token in the next request. + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed HubContent resources. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. + """ + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "HubName": hub_name, + "HubContentType": hub_content_type, + "NameContains": name_contains, + "MaxSchemaVersion": max_schema_version, + "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, + "SortBy": sort_by, + "SortOrder": sort_order, + } + extract_name_mapping = {"hub_content_arn": ["hub-content/", "hub_name"]} + + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + return ResourceIterator( + client=client, + list_method="list_hub_contents", + summaries_key="HubContentSummaries", + summary_name="HubContentInfo", + resource_cls=HubContent, + extract_name_mapping=extract_name_mapping, + list_method_kwargs=operation_input_args, + ) + @Base.add_validate_call def get_all_versions( self, @@ -14911,6 +14991,81 @@ def wait_for_delete( raise e time.sleep(poll) + @classmethod + @Base.add_validate_call + def get_all( + cls, + image_name: str, + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["ImageVersion"]: + """ + Get all ImageVersion resources + + Parameters: + image_name: The name of the image to list the versions of. + creation_time_after: A filter that returns only versions created on or after the specified time. + creation_time_before: A filter that returns only versions created on or before the specified time. + last_modified_time_after: A filter that returns only versions modified on or after the specified time. + last_modified_time_before: A filter that returns only versions modified on or before the specified time. + max_results: The maximum number of versions to return in the response. The default value is 10. + next_token: If the previous call to ListImageVersions didn't return the full set of versions, the call returns a token for getting the next set of versions. + sort_by: The property used to sort results. The default value is CREATION_TIME. + sort_order: The sort order. The default value is DESCENDING. + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed ImageVersion resources. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. + """ + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "ImageName": image_name, + "LastModifiedTimeAfter": last_modified_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + "SortBy": sort_by, + "SortOrder": sort_order, + } + extract_name_mapping = {"image_version_arn": ["image-version/", "image_name"]} + + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + return ResourceIterator( + client=client, + list_method="list_image_versions", + summaries_key="ImageVersions", + summary_name="ImageVersion", + resource_cls=ImageVersion, + extract_name_mapping=extract_name_mapping, + list_method_kwargs=operation_input_args, + ) + class InferenceComponent(Base): """ @@ -18500,10 +18655,12 @@ def get_all( "CreationTimeBefore": creation_time_before, "CreationTimeAfter": creation_time_after, } + custom_key_mapping = { "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -19729,10 +19886,12 @@ def get_all( "CreationTimeBefore": creation_time_before, "CreationTimeAfter": creation_time_after, } + custom_key_mapping = { "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -21346,10 +21505,12 @@ def get_all( "CreationTimeBefore": creation_time_before, "CreationTimeAfter": creation_time_after, } + custom_key_mapping = { "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") diff --git a/src/sagemaker_core/main/utils.py b/src/sagemaker_core/main/utils.py index b3a54ff9..39176128 100644 --- a/src/sagemaker_core/main/utils.py +++ b/src/sagemaker_core/main/utils.py @@ -387,6 +387,7 @@ def __init__( list_method: str, list_method_kwargs: dict = {}, custom_key_mapping: dict = None, + extract_name_mapping: dict = None, ): """Initialize a ResourceIterator object @@ -398,6 +399,7 @@ def __init__( list_method (str): The list method string used to make list calls to the client. list_method_kwargs (dict, optional): The kwargs used to make list method calls. Defaults to {}. custom_key_mapping (dict, optional): The custom key mapping used to map keys from summary object to those expected from resource object during initialization. Defaults to None. + extract_name_mapping (dict, optional): The extract name mapping used to extract names from arn in summary object and map to those expected from resource object during initialization. Defaults to None. """ self.summaries_key = summaries_key self.summary_name = summary_name @@ -405,6 +407,7 @@ def __init__( self.list_method = list_method self.list_method_kwargs = list_method_kwargs self.custom_key_mapping = custom_key_mapping + self.extract_name_mapping = extract_name_mapping self.resource_cls = resource_cls self.index = 0 @@ -433,6 +436,12 @@ def __next__(self) -> T: if self.custom_key_mapping: init_data = {self.custom_key_mapping.get(k, k): v for k, v in init_data.items()} + # Extract name from arn. Currently implemented for HubContent and ImageVersion + if self.extract_name_mapping: + for arn, target in self.extract_name_mapping.items(): + name = init_data[arn].split(target[0])[1].split("/")[0] + init_data.update({target[1]: name}) + # Filter out the fields that are not in the resource class fields = self.resource_cls.__annotations__ init_data = {k: v for k, v in init_data.items() if k in fields} diff --git a/src/sagemaker_core/tools/resources_codegen.py b/src/sagemaker_core/tools/resources_codegen.py index 523ef7cf..14bf46cc 100644 --- a/src/sagemaker_core/tools/resources_codegen.py +++ b/src/sagemaker_core/tools/resources_codegen.py @@ -1754,14 +1754,27 @@ def generate_get_all_method(self, resource_name: str) -> str: get_operation_required_input = [] custom_key_mapping_str = "" + custom_key_mapping = {} + extract_name_mapping_str = "" + extract_name_mapping = {} if any(member not in summary_members for member in get_operation_required_input): - if "MonitoringJobDefinitionSummary" == summary_name: + if summary_name == "MonitoringJobDefinitionSummary": custom_key_mapping = { "monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn", } - custom_key_mapping_str = f"custom_key_mapping = {json.dumps(custom_key_mapping)}" - custom_key_mapping_str = add_indent(custom_key_mapping_str, 4) + elif summary_name == "HubContentInfo": + # HubContentArn -- arn::sagemaker:::hub-content//// + # {source key from input: (target label in arn, target key in output)} + extract_name_mapping = { + "hub_content_arn": ("hub-content/", "hub_name"), + } + elif summary_name == "ImageVersion": + # ImageVersionArn -- arn:aws:sagemaker:::image-version// + # {source key from input: (target label in arn, target key in output)} + extract_name_mapping = { + "image_version_arn": ("image-version/", "image_name"), + } else: log.warning( f"Resource {resource_name} summaries do not have required members to create object instance. Resource may require custom key mapping for get_all().\n" @@ -1769,6 +1782,15 @@ def generate_get_all_method(self, resource_name: str) -> str: ) return "" + if custom_key_mapping: + custom_key_mapping_str = add_indent( + f"custom_key_mapping = {json.dumps(custom_key_mapping)}", 4 + ) + if extract_name_mapping: + extract_name_mapping_str = add_indent( + f"extract_name_mapping = {json.dumps(extract_name_mapping)}", 4 + ) + resource_iterator_args_list = [ "client=client", f"list_method='{operation}'", @@ -1780,6 +1802,9 @@ def generate_get_all_method(self, resource_name: str) -> str: if custom_key_mapping_str: resource_iterator_args_list.append(f"custom_key_mapping=custom_key_mapping") + if extract_name_mapping_str: + resource_iterator_args_list.append(f"extract_name_mapping=extract_name_mapping") + exclude_list = ["next_token", "max_results"] get_all_args = self._generate_method_args(operation_input_shape_name, exclude_list) @@ -1792,6 +1817,7 @@ def generate_get_all_method(self, resource_name: str) -> str: resource=resource_name, operation=operation, custom_key_mapping=custom_key_mapping_str, + extract_name_mapping=extract_name_mapping_str, resource_iterator_args=resource_iterator_args, ) return formatted_method @@ -1822,6 +1848,7 @@ def generate_get_all_method(self, resource_name: str) -> str: get_all_args=get_all_args, operation_input_args=operation_input_args, custom_key_mapping=custom_key_mapping_str, + extract_name_mapping=extract_name_mapping_str, resource_iterator_args=resource_iterator_args, ) return formatted_method diff --git a/src/sagemaker_core/tools/templates.py b/src/sagemaker_core/tools/templates.py index dee63ce9..9de38e76 100644 --- a/src/sagemaker_core/tools/templates.py +++ b/src/sagemaker_core/tools/templates.py @@ -512,6 +512,7 @@ def get_all( operation_input_args = {{ {operation_input_args} }} +{extract_name_mapping} {custom_key_mapping} # serialize the input request operation_input_args = serialize(operation_input_args) @@ -542,6 +543,7 @@ def get_all( """ client = Base.get_sagemaker_client(session=session, region_name=region, service_name="{service_name}") +{extract_name_mapping} {custom_key_mapping} return ResourceIterator( {resource_iterator_args} diff --git a/tst/generated/test_resources.py b/tst/generated/test_resources.py index 081247ec..19125299 100644 --- a/tst/generated/test_resources.py +++ b/tst/generated/test_resources.py @@ -114,6 +114,30 @@ def test_resources(self, session, mock_transform): f"{name}s": [summary], f"Summaries": [summary], } + if name == "HubContent": + extract_name_mapping = { + "HubContentArn": ("hub-content/", "HubName"), + } + summary.update( + { + "HubContentArn": "arn:aws:sagemaker:us-west-2:123456789012:hub-content/my-hub/Model/my-model/1.0" + } + ) + for arn, target in extract_name_mapping.items(): + extracted_name = summary[arn].split(target[0])[1].split("/")[0] + summary.update({target[1]: extracted_name}) + if name == "ImageVersion": + extract_name_mapping = { + "ImageVersionArn": ("image-version/", "ImageName"), + } + summary.update( + { + "ImageVersionArn": "arn:aws:sagemaker:us-west-2:123456789012:image-version/my-image/Model/my-model/1.0" + } + ) + for arn, target in extract_name_mapping.items(): + extracted_name = summary[arn].split(target[0])[1].split("/")[0] + summary.update({target[1]: extracted_name}) if name == "MlflowTrackingServer": summary_response = {"TrackingServerSummaries": [summary]} with patch.object( diff --git a/tst/tools/test_resources_codegen.py b/tst/tools/test_resources_codegen.py index fa412420..9e81da26 100644 --- a/tst/tools/test_resources_codegen.py +++ b/tst/tools/test_resources_codegen.py @@ -1122,6 +1122,7 @@ def get_all( 'SpaceNameEquals': space_name_equals, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") @@ -1159,6 +1160,7 @@ def get_all( """ client = Base.get_sagemaker_client(session=session, region_name=region, service_name="sagemaker") + return ResourceIterator( client=client, list_method='list_domains', @@ -1224,6 +1226,7 @@ def get_all( 'CreationTimeBefore': creation_time_before, 'CreationTimeAfter': creation_time_after, } + custom_key_mapping = {"monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn"} # serialize the input request operation_input_args = serialize(operation_input_args)