Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 161 additions & 0 deletions src/sagemaker_core/main/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -6467,10 +6467,12 @@ def get_all(
"CreationTimeBefore": creation_time_before,
"CreationTimeAfter": creation_time_after,
}

custom_key_mapping = {
"monitoring_job_definition_name": "job_definition_name",
"monitoring_job_definition_arn": "job_definition_arn",
}

# serialize the input request
operation_input_args = serialize(operation_input_args)
logger.debug(f"Serialized input request: {operation_input_args}")
Expand Down Expand Up @@ -12609,6 +12611,84 @@ def load(
region=region,
)

@classmethod
@Base.add_validate_call
def get_all(
cls,
hub_name: str,
hub_content_type: str,
name_contains: Optional[str] = Unassigned(),
max_schema_version: Optional[str] = Unassigned(),
creation_time_before: Optional[datetime.datetime] = Unassigned(),
creation_time_after: Optional[datetime.datetime] = Unassigned(),
sort_by: Optional[str] = Unassigned(),
sort_order: Optional[str] = Unassigned(),
session: Optional[Session] = None,
region: Optional[str] = None,
) -> ResourceIterator["HubContent"]:
"""
Get all HubContent resources

Parameters:
hub_name: The name of the hub to list the contents of.
hub_content_type: The type of hub content to list.
name_contains: Only list hub content if the name contains the specified string.
max_schema_version: The upper bound of the hub content schema verion.
creation_time_before: Only list hub content that was created before the time specified.
creation_time_after: Only list hub content that was created after the time specified.
sort_by: Sort hub content versions by either name or creation time.
sort_order: Sort hubs by ascending or descending order.
max_results: The maximum amount of hub content to list.
next_token: If the response to a previous ListHubContents request was truncated, the response includes a NextToken. To retrieve the next set of hub content, use the token in the next request.
session: Boto3 session.
region: Region name.

Returns:
Iterator for listed HubContent resources.

Raises:
botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
The error message and error code can be parsed from the exception as follows:
```
try:
# AWS service call here
except botocore.exceptions.ClientError as e:
error_message = e.response['Error']['Message']
error_code = e.response['Error']['Code']
```
ResourceNotFound: Resource being access is not found.
"""

client = Base.get_sagemaker_client(
session=session, region_name=region, service_name="sagemaker"
)

operation_input_args = {
"HubName": hub_name,
"HubContentType": hub_content_type,
"NameContains": name_contains,
"MaxSchemaVersion": max_schema_version,
"CreationTimeBefore": creation_time_before,
"CreationTimeAfter": creation_time_after,
"SortBy": sort_by,
"SortOrder": sort_order,
}
extract_name_mapping = {"hub_content_arn": ["hub-content/", "hub_name"]}

# serialize the input request
operation_input_args = serialize(operation_input_args)
logger.debug(f"Serialized input request: {operation_input_args}")

return ResourceIterator(
client=client,
list_method="list_hub_contents",
summaries_key="HubContentSummaries",
summary_name="HubContentInfo",
resource_cls=HubContent,
extract_name_mapping=extract_name_mapping,
list_method_kwargs=operation_input_args,
)

@Base.add_validate_call
def get_all_versions(
self,
Expand Down Expand Up @@ -14911,6 +14991,81 @@ def wait_for_delete(
raise e
time.sleep(poll)

@classmethod
@Base.add_validate_call
def get_all(
cls,
image_name: str,
creation_time_after: Optional[datetime.datetime] = Unassigned(),
creation_time_before: Optional[datetime.datetime] = Unassigned(),
last_modified_time_after: Optional[datetime.datetime] = Unassigned(),
last_modified_time_before: Optional[datetime.datetime] = Unassigned(),
sort_by: Optional[str] = Unassigned(),
sort_order: Optional[str] = Unassigned(),
session: Optional[Session] = None,
region: Optional[str] = None,
) -> ResourceIterator["ImageVersion"]:
"""
Get all ImageVersion resources

Parameters:
image_name: The name of the image to list the versions of.
creation_time_after: A filter that returns only versions created on or after the specified time.
creation_time_before: A filter that returns only versions created on or before the specified time.
last_modified_time_after: A filter that returns only versions modified on or after the specified time.
last_modified_time_before: A filter that returns only versions modified on or before the specified time.
max_results: The maximum number of versions to return in the response. The default value is 10.
next_token: If the previous call to ListImageVersions didn't return the full set of versions, the call returns a token for getting the next set of versions.
sort_by: The property used to sort results. The default value is CREATION_TIME.
sort_order: The sort order. The default value is DESCENDING.
session: Boto3 session.
region: Region name.

Returns:
Iterator for listed ImageVersion resources.

Raises:
botocore.exceptions.ClientError: This exception is raised for AWS service related errors.
The error message and error code can be parsed from the exception as follows:
```
try:
# AWS service call here
except botocore.exceptions.ClientError as e:
error_message = e.response['Error']['Message']
error_code = e.response['Error']['Code']
```
ResourceNotFound: Resource being access is not found.
"""

client = Base.get_sagemaker_client(
session=session, region_name=region, service_name="sagemaker"
)

operation_input_args = {
"CreationTimeAfter": creation_time_after,
"CreationTimeBefore": creation_time_before,
"ImageName": image_name,
"LastModifiedTimeAfter": last_modified_time_after,
"LastModifiedTimeBefore": last_modified_time_before,
"SortBy": sort_by,
"SortOrder": sort_order,
}
extract_name_mapping = {"image_version_arn": ["image-version/", "image_name"]}

# serialize the input request
operation_input_args = serialize(operation_input_args)
logger.debug(f"Serialized input request: {operation_input_args}")

return ResourceIterator(
client=client,
list_method="list_image_versions",
summaries_key="ImageVersions",
summary_name="ImageVersion",
resource_cls=ImageVersion,
extract_name_mapping=extract_name_mapping,
list_method_kwargs=operation_input_args,
)


class InferenceComponent(Base):
"""
Expand Down Expand Up @@ -18500,10 +18655,12 @@ def get_all(
"CreationTimeBefore": creation_time_before,
"CreationTimeAfter": creation_time_after,
}

custom_key_mapping = {
"monitoring_job_definition_name": "job_definition_name",
"monitoring_job_definition_arn": "job_definition_arn",
}

# serialize the input request
operation_input_args = serialize(operation_input_args)
logger.debug(f"Serialized input request: {operation_input_args}")
Expand Down Expand Up @@ -19729,10 +19886,12 @@ def get_all(
"CreationTimeBefore": creation_time_before,
"CreationTimeAfter": creation_time_after,
}

custom_key_mapping = {
"monitoring_job_definition_name": "job_definition_name",
"monitoring_job_definition_arn": "job_definition_arn",
}

# serialize the input request
operation_input_args = serialize(operation_input_args)
logger.debug(f"Serialized input request: {operation_input_args}")
Expand Down Expand Up @@ -21346,10 +21505,12 @@ def get_all(
"CreationTimeBefore": creation_time_before,
"CreationTimeAfter": creation_time_after,
}

custom_key_mapping = {
"monitoring_job_definition_name": "job_definition_name",
"monitoring_job_definition_arn": "job_definition_arn",
}

# serialize the input request
operation_input_args = serialize(operation_input_args)
logger.debug(f"Serialized input request: {operation_input_args}")
Expand Down
9 changes: 9 additions & 0 deletions src/sagemaker_core/main/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ def __init__(
list_method: str,
list_method_kwargs: dict = {},
custom_key_mapping: dict = None,
extract_name_mapping: dict = None,
):
"""Initialize a ResourceIterator object

Expand All @@ -398,13 +399,15 @@ def __init__(
list_method (str): The list method string used to make list calls to the client.
list_method_kwargs (dict, optional): The kwargs used to make list method calls. Defaults to {}.
custom_key_mapping (dict, optional): The custom key mapping used to map keys from summary object to those expected from resource object during initialization. Defaults to None.
extract_name_mapping (dict, optional): The extract name mapping used to extract names from arn in summary object and map to those expected from resource object during initialization. Defaults to None.
"""
self.summaries_key = summaries_key
self.summary_name = summary_name
self.client = client
self.list_method = list_method
self.list_method_kwargs = list_method_kwargs
self.custom_key_mapping = custom_key_mapping
self.extract_name_mapping = extract_name_mapping

self.resource_cls = resource_cls
self.index = 0
Expand Down Expand Up @@ -433,6 +436,12 @@ def __next__(self) -> T:
if self.custom_key_mapping:
init_data = {self.custom_key_mapping.get(k, k): v for k, v in init_data.items()}

# Extract name from arn. Currently implemented for HubContent and ImageVersion
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this only work for these 2 cases or applies to all arns?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could potentially work for other arns, but it depends on the structure of arn and what we want to extract. HubContentArn and ImageVersionArn shares the same structure where there is "hub-content/" and "image-version/" string before the hub name and image name we want to extract.

if self.extract_name_mapping:
for arn, target in self.extract_name_mapping.items():
name = init_data[arn].split(target[0])[1].split("/")[0]
init_data.update({target[1]: name})

# Filter out the fields that are not in the resource class
fields = self.resource_cls.__annotations__
init_data = {k: v for k, v in init_data.items() if k in fields}
Expand Down
33 changes: 30 additions & 3 deletions src/sagemaker_core/tools/resources_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,21 +1754,43 @@ def generate_get_all_method(self, resource_name: str) -> str:
get_operation_required_input = []

custom_key_mapping_str = ""
custom_key_mapping = {}
extract_name_mapping_str = ""
extract_name_mapping = {}
if any(member not in summary_members for member in get_operation_required_input):
if "MonitoringJobDefinitionSummary" == summary_name:
if summary_name == "MonitoringJobDefinitionSummary":
custom_key_mapping = {
"monitoring_job_definition_name": "job_definition_name",
"monitoring_job_definition_arn": "job_definition_arn",
}
custom_key_mapping_str = f"custom_key_mapping = {json.dumps(custom_key_mapping)}"
custom_key_mapping_str = add_indent(custom_key_mapping_str, 4)
elif summary_name == "HubContentInfo":
# HubContentArn -- arn:<partition>:sagemaker:<region>:<account-id>:hub-content/<hub-name>/<type>/<name>/<version>
# {source key from input: (target label in arn, target key in output)}
extract_name_mapping = {
"hub_content_arn": ("hub-content/", "hub_name"),
}
elif summary_name == "ImageVersion":
# ImageVersionArn -- arn:aws:sagemaker:<region>:<account>:image-version/<image-name>/<version-number>
# {source key from input: (target label in arn, target key in output)}
extract_name_mapping = {
"image_version_arn": ("image-version/", "image_name"),
}
else:
log.warning(
f"Resource {resource_name} summaries do not have required members to create object instance. Resource may require custom key mapping for get_all().\n"
f"List {summary_name} Members: {summary_members}, Object Required Members: {get_operation_required_input}"
)
return ""

if custom_key_mapping:
custom_key_mapping_str = add_indent(
f"custom_key_mapping = {json.dumps(custom_key_mapping)}", 4
)
if extract_name_mapping:
extract_name_mapping_str = add_indent(
f"extract_name_mapping = {json.dumps(extract_name_mapping)}", 4
)

resource_iterator_args_list = [
"client=client",
f"list_method='{operation}'",
Expand All @@ -1780,6 +1802,9 @@ def generate_get_all_method(self, resource_name: str) -> str:
if custom_key_mapping_str:
resource_iterator_args_list.append(f"custom_key_mapping=custom_key_mapping")

if extract_name_mapping_str:
resource_iterator_args_list.append(f"extract_name_mapping=extract_name_mapping")

exclude_list = ["next_token", "max_results"]
get_all_args = self._generate_method_args(operation_input_shape_name, exclude_list)

Expand All @@ -1792,6 +1817,7 @@ def generate_get_all_method(self, resource_name: str) -> str:
resource=resource_name,
operation=operation,
custom_key_mapping=custom_key_mapping_str,
extract_name_mapping=extract_name_mapping_str,
resource_iterator_args=resource_iterator_args,
)
return formatted_method
Expand Down Expand Up @@ -1822,6 +1848,7 @@ def generate_get_all_method(self, resource_name: str) -> str:
get_all_args=get_all_args,
operation_input_args=operation_input_args,
custom_key_mapping=custom_key_mapping_str,
extract_name_mapping=extract_name_mapping_str,
resource_iterator_args=resource_iterator_args,
)
return formatted_method
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker_core/tools/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,7 @@ def get_all(
operation_input_args = {{
{operation_input_args}
}}
{extract_name_mapping}
{custom_key_mapping}
# serialize the input request
operation_input_args = serialize(operation_input_args)
Expand Down Expand Up @@ -542,6 +543,7 @@ def get_all(

"""
client = Base.get_sagemaker_client(session=session, region_name=region, service_name="{service_name}")
{extract_name_mapping}
{custom_key_mapping}
return ResourceIterator(
{resource_iterator_args}
Expand Down
24 changes: 24 additions & 0 deletions tst/generated/test_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,30 @@ def test_resources(self, session, mock_transform):
f"{name}s": [summary],
f"Summaries": [summary],
}
if name == "HubContent":
extract_name_mapping = {
"HubContentArn": ("hub-content/", "HubName"),
}
summary.update(
{
"HubContentArn": "arn:aws:sagemaker:us-west-2:123456789012:hub-content/my-hub/Model/my-model/1.0"
}
)
for arn, target in extract_name_mapping.items():
extracted_name = summary[arn].split(target[0])[1].split("/")[0]
summary.update({target[1]: extracted_name})
if name == "ImageVersion":
extract_name_mapping = {
"ImageVersionArn": ("image-version/", "ImageName"),
}
summary.update(
{
"ImageVersionArn": "arn:aws:sagemaker:us-west-2:123456789012:image-version/my-image/Model/my-model/1.0"
}
)
for arn, target in extract_name_mapping.items():
extracted_name = summary[arn].split(target[0])[1].split("/")[0]
summary.update({target[1]: extracted_name})
if name == "MlflowTrackingServer":
summary_response = {"TrackingServerSummaries": [summary]}
with patch.object(
Expand Down
3 changes: 3 additions & 0 deletions tst/tools/test_resources_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1122,6 +1122,7 @@ def get_all(
'SpaceNameEquals': space_name_equals,
}


# serialize the input request
operation_input_args = serialize(operation_input_args)
logger.debug(f"Serialized input request: {operation_input_args}")
Expand Down Expand Up @@ -1159,6 +1160,7 @@ def get_all(
"""
client = Base.get_sagemaker_client(session=session, region_name=region, service_name="sagemaker")


return ResourceIterator(
client=client,
list_method='list_domains',
Expand Down Expand Up @@ -1224,6 +1226,7 @@ def get_all(
'CreationTimeBefore': creation_time_before,
'CreationTimeAfter': creation_time_after,
}

custom_key_mapping = {"monitoring_job_definition_name": "job_definition_name", "monitoring_job_definition_arn": "job_definition_arn"}
# serialize the input request
operation_input_args = serialize(operation_input_args)
Expand Down