diff --git a/sample/sagemaker/2017-07-24/service-2.json b/sample/sagemaker/2017-07-24/service-2.json index 10ddbecb..516e0925 100644 --- a/sample/sagemaker/2017-07-24/service-2.json +++ b/sample/sagemaker/2017-07-24/service-2.json @@ -182,6 +182,20 @@ ], "documentation":"
Creates a SageMaker HyperPod cluster. SageMaker HyperPod is a capability of SageMaker for creating and managing persistent clusters for developing large machine learning models, such as large language models (LLMs) and diffusion models. To learn more, see Amazon SageMaker HyperPod in the Amazon SageMaker Developer Guide.
" }, + "CreateClusterSchedulerConfig":{ + "name":"CreateClusterSchedulerConfig", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"CreateClusterSchedulerConfigRequest"}, + "output":{"shape":"CreateClusterSchedulerConfigResponse"}, + "errors":[ + {"shape":"ResourceLimitExceeded"}, + {"shape":"ConflictException"} + ], + "documentation":"Create cluster policy configuration. This policy is used for task prioritization and fair-share allocation of idle compute. This helps prioritize critical workloads and distributes idle compute across entities.
" + }, "CreateCodeRepository":{ "name":"CreateCodeRepository", "http":{ @@ -206,6 +220,20 @@ ], "documentation":"Starts a model compilation job. After the model has been compiled, Amazon SageMaker saves the resulting model artifacts to an Amazon Simple Storage Service (Amazon S3) bucket that you specify.
If you choose to host your model using Amazon SageMaker hosting services, you can use the resulting model artifacts as part of the model. You can also use the artifacts with Amazon Web Services IoT Greengrass. In that case, deploy them as an ML resource.
In the request body, you provide the following:
A name for the compilation job
Information about the input model artifacts
The output location for the compiled model and the device (target) that the model runs on
The Amazon Resource Name (ARN) of the IAM role that Amazon SageMaker assumes to perform the model compilation job.
You can also provide a Tag to track the model compilation job's resource use and costs. The response body contains the CompilationJobArn for the compiled job.
To stop a model compilation job, use StopCompilationJob. To get information about a particular model compilation job, use DescribeCompilationJob. To get information about multiple model compilation jobs, use ListCompilationJobs.
" }, + "CreateComputeQuota":{ + "name":"CreateComputeQuota", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"CreateComputeQuotaRequest"}, + "output":{"shape":"CreateComputeQuotaResponse"}, + "errors":[ + {"shape":"ResourceLimitExceeded"}, + {"shape":"ConflictException"} + ], + "documentation":"Create compute allocation definition. This defines how compute is allocated, shared, and borrowed for specified entities. Specifically, how to lend and borrow idle compute and assign a fair-share weight to the specified entities.
" + }, "CreateContext":{ "name":"CreateContext", "http":{ @@ -683,6 +711,33 @@ ], "documentation":"Creates a job that optimizes a model for inference performance. To create the job, you provide the location of a source model, and you provide the settings for the optimization techniques that you want the job to apply. When the job completes successfully, SageMaker uploads the new optimized model to the output destination that you specify.
For more information about how to use this action, and about the supported optimization techniques, see Optimize model inference with Amazon SageMaker.
" }, + "CreatePartnerApp":{ + "name":"CreatePartnerApp", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"CreatePartnerAppRequest"}, + "output":{"shape":"CreatePartnerAppResponse"}, + "errors":[ + {"shape":"ResourceLimitExceeded"}, + {"shape":"ConflictException"} + ], + "documentation":"Creates an Amazon SageMaker Partner AI App.
" + }, + "CreatePartnerAppPresignedUrl":{ + "name":"CreatePartnerAppPresignedUrl", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"CreatePartnerAppPresignedUrlRequest"}, + "output":{"shape":"CreatePartnerAppPresignedUrlResponse"}, + "errors":[ + {"shape":"ResourceNotFound"} + ], + "documentation":"Creates a presigned URL to access an Amazon SageMaker Partner AI App.
" + }, "CreatePipeline":{ "name":"CreatePipeline", "http":{ @@ -804,6 +859,21 @@ ], "documentation":"Starts a model training job. After training completes, SageMaker saves the resulting model artifacts to an Amazon S3 location that you specify.
If you choose to host your model using SageMaker hosting services, you can use the resulting model artifacts as part of the model. You can also use the artifacts in a machine learning service other than SageMaker, provided that you know how to use them for inference.
In the request body, you provide the following:
AlgorithmSpecification - Identifies the training algorithm to use.
HyperParameters - Specify these algorithm-specific parameters to enable the estimation of model parameters during training. Hyperparameters can be tuned to optimize this learning process. For a list of hyperparameters for each training algorithm provided by SageMaker, see Algorithms.
Do not include any security-sensitive information including account access IDs, secrets or tokens in any hyperparameter field. If the use of security-sensitive credentials is detected, SageMaker will reject your training job request and return an exception error.
InputDataConfig - Describes the input required by the training job and the Amazon S3, EFS, or FSx location where it is stored.
OutputDataConfig - Identifies the Amazon S3 bucket where you want SageMaker to save the results of model training.
ResourceConfig - Identifies the resources, ML compute instances, and ML storage volumes to deploy for model training. In distributed training, you specify more than one instance.
EnableManagedSpotTraining - Optimize the cost of training machine learning models by up to 80% by using Amazon EC2 Spot instances. For more information, see Managed Spot Training.
RoleArn - The Amazon Resource Name (ARN) that SageMaker assumes to perform tasks on your behalf during model training. You must grant this role the necessary permissions so that SageMaker can successfully complete model training.
StoppingCondition - To help cap training costs, use MaxRuntimeInSeconds to set a time limit for training. Use MaxWaitTimeInSeconds to specify how long a managed spot training job has to complete.
Environment - The environment variables to set in the Docker container.
RetryStrategy - The number of times to retry the job when the job fails due to an InternalServerError.
For more information about SageMaker, see How It Works.
" }, + "CreateTrainingPlan":{ + "name":"CreateTrainingPlan", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"CreateTrainingPlanRequest"}, + "output":{"shape":"CreateTrainingPlanResponse"}, + "errors":[ + {"shape":"ResourceLimitExceeded"}, + {"shape":"ResourceNotFound"}, + {"shape":"ResourceInUse"} + ], + "documentation":"Creates a new training plan in SageMaker to reserve compute capacity.
Amazon SageMaker Training Plan is a capability within SageMaker that allows customers to reserve and manage GPU capacity for large-scale AI model training. It provides a way to secure predictable access to computational resources within specific timelines and budgets, without the need to manage underlying infrastructure.
How it works
Plans can be created for specific resources such as SageMaker Training Jobs or SageMaker HyperPod clusters, automatically provisioning resources, setting up infrastructure, executing workloads, and handling infrastructure failures.
Plan creation workflow
Users search for available plan offerings based on their requirements (e.g., instance type, count, start time, duration) using the SearchTrainingPlanOfferings API operation.
They create a plan that best matches their needs using the ID of the plan offering they want to use.
After successful upfront payment, the plan's status becomes Scheduled.
The plan can be used to:
Queue training jobs.
Allocate to an instance group of a SageMaker HyperPod cluster.
When the plan start date arrives, it becomes Active. Based on available reserved capacity:
Training jobs are launched.
Instance groups are provisioned.
Plan composition
A plan can consist of one or more Reserved Capacities, each defined by a specific instance type, quantity, Availability Zone, duration, and start and end times. For more information about Reserved Capacity, see ReservedCapacitySummary.
Delete a SageMaker HyperPod cluster.
" }, + "DeleteClusterSchedulerConfig":{ + "name":"DeleteClusterSchedulerConfig", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"DeleteClusterSchedulerConfigRequest"}, + "errors":[ + {"shape":"ResourceNotFound"} + ], + "documentation":"Deletes the cluster policy of the cluster.
" + }, "DeleteCodeRepository":{ "name":"DeleteCodeRepository", "http":{ @@ -995,6 +1077,18 @@ ], "documentation":"Deletes the specified compilation job. This action deletes only the compilation job resource in Amazon SageMaker. It doesn't delete other resources that are related to that job, such as the model artifacts that the job creates, the compilation logs in CloudWatch, the compiled model, or the IAM role.
You can delete a compilation job only if its current status is COMPLETED, FAILED, or STOPPED. If the job status is STARTING or INPROGRESS, stop the job, and then delete it after its status becomes STOPPED.
Deletes the compute allocation from the cluster.
" + }, "DeleteContext":{ "name":"DeleteContext", "http":{ @@ -1383,6 +1477,20 @@ ], "documentation":"Deletes an optimization job.
" }, + "DeletePartnerApp":{ + "name":"DeletePartnerApp", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"DeletePartnerAppRequest"}, + "output":{"shape":"DeletePartnerAppResponse"}, + "errors":[ + {"shape":"ResourceNotFound"}, + {"shape":"ConflictException"} + ], + "documentation":"Deletes a SageMaker Partner AI App.
" + }, "DeletePipeline":{ "name":"DeletePipeline", "http":{ @@ -1630,6 +1738,19 @@ ], "documentation":"Retrieves information of a node (also called a instance interchangeably) of a SageMaker HyperPod cluster.
" }, + "DescribeClusterSchedulerConfig":{ + "name":"DescribeClusterSchedulerConfig", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"DescribeClusterSchedulerConfigRequest"}, + "output":{"shape":"DescribeClusterSchedulerConfigResponse"}, + "errors":[ + {"shape":"ResourceNotFound"} + ], + "documentation":"Description of the cluster policy. This policy is used for task prioritization and fair-share allocation. This helps prioritize critical workloads and distributes idle compute across entities.
" + }, "DescribeCodeRepository":{ "name":"DescribeCodeRepository", "http":{ @@ -1653,6 +1774,19 @@ ], "documentation":"Returns information about a model compilation job.
To create a model compilation job, use CreateCompilationJob. To get information about multiple model compilation jobs, use ListCompilationJobs.
" }, + "DescribeComputeQuota":{ + "name":"DescribeComputeQuota", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"DescribeComputeQuotaRequest"}, + "output":{"shape":"DescribeComputeQuotaResponse"}, + "errors":[ + {"shape":"ResourceNotFound"} + ], + "documentation":"Description of the compute allocation definition.
" + }, "DescribeContext":{ "name":"DescribeContext", "http":{ @@ -2110,6 +2244,19 @@ ], "documentation":"Provides the properties of the specified optimization job.
" }, + "DescribePartnerApp":{ + "name":"DescribePartnerApp", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"DescribePartnerAppRequest"}, + "output":{"shape":"DescribePartnerAppResponse"}, + "errors":[ + {"shape":"ResourceNotFound"} + ], + "documentation":"Gets information about a SageMaker Partner AI App.
" + }, "DescribePipeline":{ "name":"DescribePipeline", "http":{ @@ -2221,6 +2368,19 @@ ], "documentation":"Returns information about a training job.
Some of the attributes below only appear if the training job successfully starts. If the training job fails, TrainingJobStatus is Failed and, depending on the FailureReason, attributes like TrainingStartTime, TrainingTimeInSeconds, TrainingEndTime, and BillableTimeInSeconds may not be present in the response.
Retrieves detailed information about a specific training plan.
" + }, "DescribeTransformJob":{ "name":"DescribeTransformJob", "http":{ @@ -2526,6 +2686,16 @@ ], "documentation":"Retrieves the list of instances (also called nodes interchangeably) in a SageMaker HyperPod cluster.
" }, + "ListClusterSchedulerConfigs":{ + "name":"ListClusterSchedulerConfigs", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"ListClusterSchedulerConfigsRequest"}, + "output":{"shape":"ListClusterSchedulerConfigsResponse"}, + "documentation":"List the cluster policy configurations.
" + }, "ListClusters":{ "name":"ListClusters", "http":{ @@ -2556,6 +2726,16 @@ "output":{"shape":"ListCompilationJobsResponse"}, "documentation":"Lists model compilation jobs that satisfy various filters.
To create a model compilation job, use CreateCompilationJob. To get information about a particular model compilation job you have created, use DescribeCompilationJob.
" }, + "ListComputeQuotas":{ + "name":"ListComputeQuotas", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"ListComputeQuotasRequest"}, + "output":{"shape":"ListComputeQuotasResponse"}, + "documentation":"List the resource allocation definitions.
" + }, "ListContexts":{ "name":"ListContexts", "http":{ @@ -3023,6 +3203,16 @@ "output":{"shape":"ListOptimizationJobsResponse"}, "documentation":"Lists the optimization jobs in your account and their properties.
" }, + "ListPartnerApps":{ + "name":"ListPartnerApps", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"ListPartnerAppsRequest"}, + "output":{"shape":"ListPartnerAppsResponse"}, + "documentation":"Lists all of the SageMaker Partner AI Apps in an account.
" + }, "ListPipelineExecutionSteps":{ "name":"ListPipelineExecutionSteps", "http":{ @@ -3178,6 +3368,16 @@ ], "documentation":"Gets a list of TrainingJobSummary objects that describe the training jobs that a hyperparameter tuning job launched.
" }, + "ListTrainingPlans":{ + "name":"ListTrainingPlans", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"ListTrainingPlansRequest"}, + "output":{"shape":"ListTrainingPlansResponse"}, + "documentation":"Retrieves a list of training plans for the current account.
" + }, "ListTransformJobs":{ "name":"ListTransformJobs", "http":{ @@ -3320,6 +3520,19 @@ "output":{"shape":"SearchResponse"}, "documentation":"Finds SageMaker resources that match a search query. Matching resources are returned as a list of SearchRecord objects in the response. You can sort the search results by any resource property in a ascending or descending order.
You can query against the following value types: numeric, text, Boolean, and timestamp.
The Search API may provide access to otherwise restricted data. See Amazon SageMaker API Permissions: Actions, Permissions, and Resources Reference for more information.
Searches for available training plan offerings based on specified criteria.
Users search for available plan offerings based on their requirements (e.g., instance type, count, start time, duration).
And then, they create a plan that best matches their needs using the ID of the plan offering they want to use.
For more information about how to reserve GPU capacity for your SageMaker training jobs or SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan.
Updates a SageMaker HyperPod cluster.
" }, + "UpdateClusterSchedulerConfig":{ + "name":"UpdateClusterSchedulerConfig", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"UpdateClusterSchedulerConfigRequest"}, + "output":{"shape":"UpdateClusterSchedulerConfigResponse"}, + "errors":[ + {"shape":"ResourceNotFound"}, + {"shape":"ResourceLimitExceeded"}, + {"shape":"ConflictException"} + ], + "documentation":"Update the cluster policy configuration.
" + }, "UpdateClusterSoftware":{ "name":"UpdateClusterSoftware", "http":{ @@ -3698,6 +3926,21 @@ ], "documentation":"Updates the specified Git repository with the specified values.
" }, + "UpdateComputeQuota":{ + "name":"UpdateComputeQuota", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"UpdateComputeQuotaRequest"}, + "output":{"shape":"UpdateComputeQuotaResponse"}, + "errors":[ + {"shape":"ResourceNotFound"}, + {"shape":"ResourceLimitExceeded"}, + {"shape":"ConflictException"} + ], + "documentation":"Update the compute allocation definition.
" + }, "UpdateContext":{ "name":"UpdateContext", "http":{ @@ -3992,6 +4235,20 @@ ], "documentation":"Updates a notebook instance lifecycle configuration created with the CreateNotebookInstanceLifecycleConfig API.
" }, + "UpdatePartnerApp":{ + "name":"UpdatePartnerApp", + "http":{ + "method":"POST", + "requestUri":"/" + }, + "input":{"shape":"UpdatePartnerAppRequest"}, + "output":{"shape":"UpdatePartnerAppResponse"}, + "errors":[ + {"shape":"ResourceNotFound"}, + {"shape":"ConflictException"} + ], + "documentation":"Updates all of the SageMaker Partner AI Apps in an account.
" + }, "UpdatePipeline":{ "name":"UpdatePipeline", "http":{ @@ -4218,6 +4475,13 @@ }, "documentation":"Lists the properties of an action. An action represents an action or activity. Some examples are a workflow step and a model deployment. Generally, an action involves at least one input artifact or output artifact.
" }, + "ActivationState":{ + "type":"string", + "enum":[ + "Enabled", + "Disabled" + ] + }, "AddAssociationRequest":{ "type":"structure", "required":[ @@ -6165,6 +6429,16 @@ "type":"string", "enum":["Enabled"] }, + "AvailabilityZone":{ + "type":"string", + "max":32, + "min":1, + "pattern":"[a-z]+\\-[0-9a-z\\-]+" + }, + "AvailableInstanceCount":{ + "type":"integer", + "min":0 + }, "AwsManagedHumanLoopRequestSource":{ "type":"string", "enum":[ @@ -6485,6 +6759,11 @@ "Or" ] }, + "BorrowLimit":{ + "type":"integer", + "max":500, + "min":1 + }, "Branch":{ "type":"string", "max":1024, @@ -7359,6 +7638,18 @@ "shape":"OnStartDeepHealthChecks", "documentation":"A flag indicating whether deep health checks should be performed when the cluster instance group is created or updated.
" }, + "Status":{ + "shape":"InstanceGroupStatus", + "documentation":"The current status of the cluster instance group.
InService: The instance group is active and healthy.
Creating: The instance group is being provisioned.
Updating: The instance group is being updated.
Failed: The instance group has failed to provision or is no longer healthy.
Degraded: The instance group is degraded, meaning that some instances have failed to provision or are no longer healthy.
Deleting: The instance group is being deleted.
The Amazon Resource Name (ARN) of the training plan associated with this cluster instance group.
For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan.
The current status of the training plan associated with this cluster instance group.
" + }, "OverrideVpcConfig":{"shape":"VpcConfig"} }, "documentation":"Details of an instance group in a SageMaker HyperPod cluster.
" @@ -7415,6 +7706,10 @@ "shape":"OnStartDeepHealthChecks", "documentation":"A flag indicating whether deep health checks should be performed when the cluster instance group is created or updated.
" }, + "TrainingPlanArn":{ + "shape":"TrainingPlanArn", + "documentation":"The Amazon Resource Name (ARN); of the training plan to use for this cluster instance group.
For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan.
The specifications of an instance group that you need to define.
" @@ -7540,7 +7835,9 @@ "ml.g6e.12xlarge", "ml.g6e.24xlarge", "ml.g6e.48xlarge", - "ml.p5e.48xlarge" + "ml.p5e.48xlarge", + "ml.p5en.48xlarge", + "ml.trn2.48xlarge" ] }, "ClusterLifeCycleConfig":{ @@ -7719,6 +8016,71 @@ "type":"string", "pattern":"^((25[0-5]|(2[0-4]|1\\d|[1-9]|)\\d)\\.?\\b){4}$" }, + "ClusterSchedulerConfigArn":{ + "type":"string", + "max":256, + "pattern":"^arn:aws[a-z\\-]*:sagemaker:[a-z0-9\\-]{9,16}:[0-9]{12}:cluster-scheduler-config/[a-z0-9]{12}$" + }, + "ClusterSchedulerConfigId":{ + "type":"string", + "max":12, + "pattern":"^[a-z0-9]{12}$" + }, + "ClusterSchedulerConfigSummary":{ + "type":"structure", + "required":[ + "ClusterSchedulerConfigArn", + "ClusterSchedulerConfigId", + "Name", + "CreationTime", + "Status" + ], + "members":{ + "ClusterSchedulerConfigArn":{ + "shape":"ClusterSchedulerConfigArn", + "documentation":"ARN of the cluster policy.
" + }, + "ClusterSchedulerConfigId":{ + "shape":"ClusterSchedulerConfigId", + "documentation":"ID of the cluster policy.
" + }, + "ClusterSchedulerConfigVersion":{ + "shape":"Integer", + "documentation":"Version of the cluster policy.
" + }, + "Name":{ + "shape":"EntityName", + "documentation":"Name of the cluster policy.
" + }, + "CreationTime":{ + "shape":"Timestamp", + "documentation":"Creation time of the cluster policy.
" + }, + "LastModifiedTime":{ + "shape":"Timestamp", + "documentation":"Last modified time of the cluster policy.
" + }, + "Status":{ + "shape":"SchedulerResourceStatus", + "documentation":"Status of the cluster policy.
" + }, + "ClusterArn":{ + "shape":"ClusterArn", + "documentation":"ARN of the cluster.
" + } + }, + "documentation":"Summary of the cluster policy.
" + }, + "ClusterSchedulerConfigSummaryList":{ + "type":"list", + "member":{"shape":"ClusterSchedulerConfigSummary"}, + "max":100, + "min":0 + }, + "ClusterSchedulerPriorityClassName":{ + "type":"string", + "pattern":"^[a-z0-9]([-a-z0-9]*[a-z0-9]){0,39}?$" + }, "ClusterSortBy":{ "type":"string", "enum":[ @@ -7766,6 +8128,10 @@ "ClusterStatus":{ "shape":"ClusterStatus", "documentation":"The status of the SageMaker HyperPod cluster.
" + }, + "TrainingPlanArns":{ + "shape":"TrainingPlanArns", + "documentation":"A list of Amazon Resource Names (ARNs) of the training plans associated with this cluster.
For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan .
Lists a summary of the properties of a SageMaker HyperPod cluster.
" @@ -8100,6 +8466,140 @@ "type":"list", "member":{"shape":"CompressionType"} }, + "ComputeQuotaArn":{ + "type":"string", + "max":2048, + "pattern":"arn:aws[a-z\\-]*:sagemaker:[a-z0-9\\-]*:[0-9]{12}:compute-quota/[a-z0-9]{12}$" + }, + "ComputeQuotaConfig":{ + "type":"structure", + "members":{ + "ComputeQuotaResources":{ + "shape":"ComputeQuotaResourceConfigList", + "documentation":"Allocate compute resources by instance types.
" + }, + "ResourceSharingConfig":{ + "shape":"ResourceSharingConfig", + "documentation":"Resource sharing configuration. This defines how an entity can lend and borrow idle compute with other entities within the cluster.
" + }, + "PreemptTeamTasks":{ + "shape":"PreemptTeamTasks", + "documentation":"Allows workloads from within an entity to preempt same-team workloads. When set to LowerPriority, the entity's lower priority tasks are preempted by their own higher priority tasks.
Default is LowerPriority.
Configuration of the compute allocation definition for an entity. This includes the resource sharing option and the setting to preempt low priority tasks.
" + }, + "ComputeQuotaId":{ + "type":"string", + "pattern":"^[a-z0-9]{12}$" + }, + "ComputeQuotaResourceConfig":{ + "type":"structure", + "required":[ + "InstanceType", + "Count" + ], + "members":{ + "InstanceType":{ + "shape":"ClusterInstanceType", + "documentation":"The instance type of the instance group for the cluster.
" + }, + "Count":{ + "shape":"InstanceCount", + "documentation":"The number of instances to add to the instance group of a SageMaker HyperPod cluster.
" + } + }, + "documentation":"Configuration of the resources used for the compute allocation definition.
" + }, + "ComputeQuotaResourceConfigList":{ + "type":"list", + "member":{"shape":"ComputeQuotaResourceConfig"}, + "max":15, + "min":0 + }, + "ComputeQuotaSummary":{ + "type":"structure", + "required":[ + "ComputeQuotaArn", + "ComputeQuotaId", + "Name", + "Status", + "ComputeQuotaTarget", + "CreationTime" + ], + "members":{ + "ComputeQuotaArn":{ + "shape":"ComputeQuotaArn", + "documentation":"ARN of the compute allocation definition.
" + }, + "ComputeQuotaId":{ + "shape":"ComputeQuotaId", + "documentation":"ID of the compute allocation definition.
" + }, + "Name":{ + "shape":"EntityName", + "documentation":"Name of the compute allocation definition.
" + }, + "ComputeQuotaVersion":{ + "shape":"Integer", + "documentation":"Version of the compute allocation definition.
" + }, + "Status":{ + "shape":"SchedulerResourceStatus", + "documentation":"Status of the compute allocation definition.
" + }, + "ClusterArn":{ + "shape":"ClusterArn", + "documentation":"ARN of the cluster.
" + }, + "ComputeQuotaConfig":{ + "shape":"ComputeQuotaConfig", + "documentation":"Configuration of the compute allocation definition. This includes the resource sharing option, and the setting to preempt low priority tasks.
" + }, + "ComputeQuotaTarget":{ + "shape":"ComputeQuotaTarget", + "documentation":"The target entity to allocate compute resources to.
" + }, + "ActivationState":{ + "shape":"ActivationState", + "documentation":"The state of the compute allocation being described. Use to enable or disable compute allocation.
Default is Enabled.
Creation time of the compute allocation definition.
" + }, + "LastModifiedTime":{ + "shape":"Timestamp", + "documentation":"Last modified time of the compute allocation definition.
" + } + }, + "documentation":"Summary of the compute allocation definition.
" + }, + "ComputeQuotaSummaryList":{ + "type":"list", + "member":{"shape":"ComputeQuotaSummary"}, + "max":100, + "min":0 + }, + "ComputeQuotaTarget":{ + "type":"structure", + "required":["TeamName"], + "members":{ + "TeamName":{ + "shape":"ComputeQuotaTargetTeamName", + "documentation":"Name of the team to allocate compute resources to.
" + }, + "FairShareWeight":{ + "shape":"FairShareWeight", + "documentation":"Assigned entity fair-share weight. Idle compute will be shared across entities based on these assigned weights. This weight is only used when FairShare is enabled.
A weight of 0 is the lowest priority and 100 is the highest. Weight 0 is the default.
" + } + }, + "documentation":"The target entity to allocate compute resources to.
" + }, + "ComputeQuotaTargetTeamName":{ + "type":"string", + "pattern":"^[a-z0-9]([-a-z0-9]*[a-z0-9]){0,39}?$" + }, "ConditionOutcome":{ "type":"string", "enum":[ @@ -8799,6 +9299,53 @@ } } }, + "CreateClusterSchedulerConfigRequest":{ + "type":"structure", + "required":[ + "Name", + "ClusterArn", + "SchedulerConfig" + ], + "members":{ + "Name":{ + "shape":"EntityName", + "documentation":"Name for the cluster policy.
" + }, + "ClusterArn":{ + "shape":"ClusterArn", + "documentation":"ARN of the cluster.
" + }, + "SchedulerConfig":{ + "shape":"SchedulerConfig", + "documentation":"Configuration about the monitoring schedule.
" + }, + "Description":{ + "shape":"EntityDescription", + "documentation":"Description of the cluster policy.
" + }, + "Tags":{ + "shape":"TagList", + "documentation":"Tags of the cluster policy.
" + } + } + }, + "CreateClusterSchedulerConfigResponse":{ + "type":"structure", + "required":[ + "ClusterSchedulerConfigArn", + "ClusterSchedulerConfigId" + ], + "members":{ + "ClusterSchedulerConfigArn":{ + "shape":"ClusterSchedulerConfigArn", + "documentation":"ARN of the cluster policy.
" + }, + "ClusterSchedulerConfigId":{ + "shape":"ClusterSchedulerConfigId", + "documentation":"ID of the cluster policy.
" + } + } + }, "CreateCodeRepositoryInput":{ "type":"structure", "required":[ @@ -8883,6 +9430,62 @@ } } }, + "CreateComputeQuotaRequest":{ + "type":"structure", + "required":[ + "Name", + "ClusterArn", + "ComputeQuotaConfig", + "ComputeQuotaTarget" + ], + "members":{ + "Name":{ + "shape":"EntityName", + "documentation":"Name to the compute allocation definition.
" + }, + "Description":{ + "shape":"EntityDescription", + "documentation":"Description of the compute allocation definition.
" + }, + "ClusterArn":{ + "shape":"ClusterArn", + "documentation":"ARN of the cluster.
" + }, + "ComputeQuotaConfig":{ + "shape":"ComputeQuotaConfig", + "documentation":"Configuration of the compute allocation definition. This includes the resource sharing option, and the setting to preempt low priority tasks.
" + }, + "ComputeQuotaTarget":{ + "shape":"ComputeQuotaTarget", + "documentation":"The target entity to allocate compute resources to.
" + }, + "ActivationState":{ + "shape":"ActivationState", + "documentation":"The state of the compute allocation being described. Use to enable or disable compute allocation.
Default is Enabled.
Tags of the compute allocation definition.
" + } + } + }, + "CreateComputeQuotaResponse":{ + "type":"structure", + "required":[ + "ComputeQuotaArn", + "ComputeQuotaId" + ], + "members":{ + "ComputeQuotaArn":{ + "shape":"ComputeQuotaArn", + "documentation":"ARN of the compute allocation definition.
" + }, + "ComputeQuotaId":{ + "shape":"ComputeQuotaId", + "documentation":"ID of the compute allocation definition.
" + } + } + }, "CreateContextRequest":{ "type":"structure", "required":[ @@ -10585,6 +11188,95 @@ } } }, + "CreatePartnerAppPresignedUrlRequest":{ + "type":"structure", + "required":["Arn"], + "members":{ + "Arn":{ + "shape":"PartnerAppArn", + "documentation":"The ARN of the SageMaker Partner AI App to create the presigned URL for.
" + }, + "ExpiresInSeconds":{ + "shape":"ExpiresInSeconds", + "documentation":"The time that will pass before the presigned URL expires.
" + }, + "SessionExpirationDurationInSeconds":{ + "shape":"SessionExpirationDurationInSeconds", + "documentation":"Indicates how long the Amazon SageMaker Partner AI App session can be accessed for after logging in.
" + } + } + }, + "CreatePartnerAppPresignedUrlResponse":{ + "type":"structure", + "members":{ + "Url":{ + "shape":"String2048", + "documentation":"The presigned URL that you can use to access the SageMaker Partner AI App.
" + } + } + }, + "CreatePartnerAppRequest":{ + "type":"structure", + "required":[ + "Name", + "Type", + "ExecutionRoleArn", + "Tier", + "AuthType" + ], + "members":{ + "Name":{ + "shape":"PartnerAppName", + "documentation":"The name to give the SageMaker Partner AI App.
" + }, + "Type":{ + "shape":"PartnerAppType", + "documentation":"The type of SageMaker Partner AI App to create. Must be one of the following: lakera-guard, comet, deepchecks-llm-evaluation, or fiddler.
The ARN of the IAM role that the partner application uses.
" + }, + "MaintenanceConfig":{ + "shape":"PartnerAppMaintenanceConfig", + "documentation":"Maintenance configuration settings for the SageMaker Partner AI App.
" + }, + "Tier":{ + "shape":"NonEmptyString64", + "documentation":"Indicates the instance type and size of the cluster attached to the SageMaker Partner AI App.
" + }, + "ApplicationConfig":{ + "shape":"PartnerAppConfig", + "documentation":"Configuration settings for the SageMaker Partner AI App.
" + }, + "AuthType":{ + "shape":"PartnerAppAuthType", + "documentation":"The authorization type that users use to access the SageMaker Partner AI App.
" + }, + "EnableIamSessionBasedIdentity":{ + "shape":"Boolean", + "documentation":"When set to TRUE, the SageMaker Partner AI App sets the Amazon Web Services IAM session name or the authenticated IAM user as the identity of the SageMaker Partner AI App user.
A unique token that guarantees that the call to this API is idempotent.
", + "idempotencyToken":true + }, + "Tags":{ + "shape":"TagList", + "documentation":"Each tag consists of a key and an optional value. Tag keys must be unique per resource.
" + } + } + }, + "CreatePartnerAppResponse":{ + "type":"structure", + "members":{ + "Arn":{ + "shape":"PartnerAppArn", + "documentation":"The ARN of the SageMaker Partner AI App.
" + } + } + }, "CreatePipelineRequest":{ "type":"structure", "required":[ @@ -11029,6 +11721,37 @@ } } }, + "CreateTrainingPlanRequest":{ + "type":"structure", + "required":[ + "TrainingPlanName", + "TrainingPlanOfferingId" + ], + "members":{ + "TrainingPlanName":{ + "shape":"TrainingPlanName", + "documentation":"The name of the training plan to create.
" + }, + "TrainingPlanOfferingId":{ + "shape":"TrainingPlanOfferingId", + "documentation":"The unique identifier of the training plan offering to use for creating this plan.
" + }, + "Tags":{ + "shape":"TagList", + "documentation":"An array of key-value pairs to apply to this training plan.
" + } + } + }, + "CreateTrainingPlanResponse":{ + "type":"structure", + "required":["TrainingPlanArn"], + "members":{ + "TrainingPlanArn":{ + "shape":"TrainingPlanArn", + "documentation":"The Amazon Resource Name (ARN); of the created training plan.
" + } + } + }, "CreateTransformJobRequest":{ "type":"structure", "required":[ @@ -11337,12 +12060,17 @@ "max":10, "min":1 }, + "CurrencyCode":{"type":"string"}, "CustomFileSystem":{ "type":"structure", "members":{ "EFSFileSystem":{ "shape":"EFSFileSystem", "documentation":"A custom file system in Amazon EFS.
" + }, + "FSxLustreFileSystem":{ + "shape":"FSxLustreFileSystem", + "documentation":"A custom file system in Amazon FSx for Lustre.
" } }, "documentation":"A file system, created by you, that you assign to a user profile or space for an Amazon SageMaker Domain. Permitted users can access this file system in Amazon SageMaker Studio.
", @@ -11354,6 +12082,10 @@ "EFSFileSystemConfig":{ "shape":"EFSFileSystemConfig", "documentation":"The settings for a custom Amazon EFS file system.
" + }, + "FSxLustreFileSystemConfig":{ + "shape":"FSxLustreFileSystemConfig", + "documentation":"The settings for a custom Amazon FSx for Lustre file system.
" } }, "documentation":"The settings for assigning a custom file system to a user profile or space for an Amazon SageMaker Domain. Permitted users can access this file system in Amazon SageMaker Studio.
", @@ -12007,6 +12739,16 @@ } } }, + "DeleteClusterSchedulerConfigRequest":{ + "type":"structure", + "required":["ClusterSchedulerConfigId"], + "members":{ + "ClusterSchedulerConfigId":{ + "shape":"ClusterSchedulerConfigId", + "documentation":"ID of the cluster policy.
" + } + } + }, "DeleteCodeRepositoryInput":{ "type":"structure", "required":["CodeRepositoryName"], @@ -12027,6 +12769,16 @@ } } }, + "DeleteComputeQuotaRequest":{ + "type":"structure", + "required":["ComputeQuotaId"], + "members":{ + "ComputeQuotaId":{ + "shape":"ComputeQuotaId", + "documentation":"ID of the compute allocation definition.
" + } + } + }, "DeleteContextRequest":{ "type":"structure", "required":["ContextName"], @@ -12462,6 +13214,30 @@ } } }, + "DeletePartnerAppRequest":{ + "type":"structure", + "required":["Arn"], + "members":{ + "Arn":{ + "shape":"PartnerAppArn", + "documentation":"The ARN of the SageMaker Partner AI App to delete.
" + }, + "ClientToken":{ + "shape":"ClientToken", + "documentation":"A unique token that guarantees that the call to this API is idempotent.
", + "idempotencyToken":true + } + } + }, + "DeletePartnerAppResponse":{ + "type":"structure", + "members":{ + "Arn":{ + "shape":"PartnerAppArn", + "documentation":"The ARN of the SageMaker Partner AI App that was deleted.
" + } + } + }, "DeletePipelineRequest":{ "type":"structure", "required":[ @@ -13402,6 +14178,79 @@ } } }, + "DescribeClusterSchedulerConfigRequest":{ + "type":"structure", + "required":["ClusterSchedulerConfigId"], + "members":{ + "ClusterSchedulerConfigId":{ + "shape":"ClusterSchedulerConfigId", + "documentation":"ID of the cluster policy.
" + }, + "ClusterSchedulerConfigVersion":{ + "shape":"Integer", + "documentation":"Version of the cluster policy.
" + } + } + }, + "DescribeClusterSchedulerConfigResponse":{ + "type":"structure", + "required":[ + "ClusterSchedulerConfigArn", + "ClusterSchedulerConfigId", + "Name", + "ClusterSchedulerConfigVersion", + "Status", + "CreationTime" + ], + "members":{ + "ClusterSchedulerConfigArn":{ + "shape":"ClusterSchedulerConfigArn", + "documentation":"ARN of the cluster policy.
" + }, + "ClusterSchedulerConfigId":{ + "shape":"ClusterSchedulerConfigId", + "documentation":"ID of the cluster policy.
" + }, + "Name":{ + "shape":"EntityName", + "documentation":"Name of the cluster policy.
" + }, + "ClusterSchedulerConfigVersion":{ + "shape":"Integer", + "documentation":"Version of the cluster policy.
" + }, + "Status":{ + "shape":"SchedulerResourceStatus", + "documentation":"Status of the cluster policy.
" + }, + "FailureReason":{ + "shape":"FailureReason", + "documentation":"Failure reason of the cluster policy.
" + }, + "ClusterArn":{ + "shape":"ClusterArn", + "documentation":"ARN of the cluster where the cluster policy is applied.
" + }, + "SchedulerConfig":{ + "shape":"SchedulerConfig", + "documentation":"Cluster policy configuration. This policy is used for task prioritization and fair-share allocation. This helps prioritize critical workloads and distributes idle compute across entities.
" + }, + "Description":{ + "shape":"EntityDescription", + "documentation":"Description of the cluster policy.
" + }, + "CreationTime":{ + "shape":"Timestamp", + "documentation":"Creation time of the cluster policy.
" + }, + "CreatedBy":{"shape":"UserContext"}, + "LastModifiedTime":{ + "shape":"Timestamp", + "documentation":"Last modified time of the cluster policy.
" + }, + "LastModifiedBy":{"shape":"UserContext"} + } + }, "DescribeCodeRepositoryInput":{ "type":"structure", "required":["CodeRepositoryName"], @@ -13543,6 +14392,88 @@ } } }, + "DescribeComputeQuotaRequest":{ + "type":"structure", + "required":["ComputeQuotaId"], + "members":{ + "ComputeQuotaId":{ + "shape":"ComputeQuotaId", + "documentation":"ID of the compute allocation definition.
" + }, + "ComputeQuotaVersion":{ + "shape":"Integer", + "documentation":"Version of the compute allocation definition.
" + } + } + }, + "DescribeComputeQuotaResponse":{ + "type":"structure", + "required":[ + "ComputeQuotaArn", + "ComputeQuotaId", + "Name", + "ComputeQuotaVersion", + "Status", + "ComputeQuotaTarget", + "CreationTime" + ], + "members":{ + "ComputeQuotaArn":{ + "shape":"ComputeQuotaArn", + "documentation":"ARN of the compute allocation definition.
" + }, + "ComputeQuotaId":{ + "shape":"ComputeQuotaId", + "documentation":"ID of the compute allocation definition.
" + }, + "Name":{ + "shape":"EntityName", + "documentation":"Name of the compute allocation definition.
" + }, + "Description":{ + "shape":"EntityDescription", + "documentation":"Description of the compute allocation definition.
" + }, + "ComputeQuotaVersion":{ + "shape":"Integer", + "documentation":"Version of the compute allocation definition.
" + }, + "Status":{ + "shape":"SchedulerResourceStatus", + "documentation":"Status of the compute allocation definition.
" + }, + "FailureReason":{ + "shape":"FailureReason", + "documentation":"Failure reason of the compute allocation definition.
" + }, + "ClusterArn":{ + "shape":"ClusterArn", + "documentation":"ARN of the cluster.
" + }, + "ComputeQuotaConfig":{ + "shape":"ComputeQuotaConfig", + "documentation":"Configuration of the compute allocation definition. This includes the resource sharing option, and the setting to preempt low priority tasks.
" + }, + "ComputeQuotaTarget":{ + "shape":"ComputeQuotaTarget", + "documentation":"The target entity to allocate compute resources to.
" + }, + "ActivationState":{ + "shape":"ActivationState", + "documentation":"The state of the compute allocation being described. Use to enable or disable compute allocation.
Default is Enabled.
Creation time of the compute allocation configuration.
" + }, + "CreatedBy":{"shape":"UserContext"}, + "LastModifiedTime":{ + "shape":"Timestamp", + "documentation":"Last modified time of the compute allocation configuration.
" + }, + "LastModifiedBy":{"shape":"UserContext"} + } + }, "DescribeContextRequest":{ "type":"structure", "required":["ContextName"], @@ -16199,6 +17130,77 @@ } } }, + "DescribePartnerAppRequest":{ + "type":"structure", + "required":["Arn"], + "members":{ + "Arn":{ + "shape":"PartnerAppArn", + "documentation":"The ARN of the SageMaker Partner AI App to describe.
" + } + } + }, + "DescribePartnerAppResponse":{ + "type":"structure", + "members":{ + "Arn":{ + "shape":"PartnerAppArn", + "documentation":"The ARN of the SageMaker Partner AI App that was described.
" + }, + "Name":{ + "shape":"PartnerAppName", + "documentation":"The name of the SageMaker Partner AI App.
" + }, + "Type":{ + "shape":"PartnerAppType", + "documentation":"The type of SageMaker Partner AI App. Must be one of the following: lakera-guard, comet, deepchecks-llm-evaluation, or fiddler.
The status of the SageMaker Partner AI App.
" + }, + "CreationTime":{ + "shape":"Timestamp", + "documentation":"The time that the SageMaker Partner AI App was created.
" + }, + "ExecutionRoleArn":{ + "shape":"RoleArn", + "documentation":"The ARN of the IAM role associated with the SageMaker Partner AI App.
" + }, + "BaseUrl":{ + "shape":"String2048", + "documentation":"The URL of the SageMaker Partner AI App that the Application SDK uses to support in-app calls for the user.
" + }, + "MaintenanceConfig":{ + "shape":"PartnerAppMaintenanceConfig", + "documentation":"Maintenance configuration settings for the SageMaker Partner AI App.
" + }, + "Tier":{ + "shape":"NonEmptyString64", + "documentation":"The instance type and size of the cluster attached to the SageMaker Partner AI App.
" + }, + "Version":{ + "shape":"NonEmptyString64", + "documentation":"The version of the SageMaker Partner AI App.
" + }, + "ApplicationConfig":{ + "shape":"PartnerAppConfig", + "documentation":"Configuration settings for the SageMaker Partner AI App.
" + }, + "AuthType":{ + "shape":"PartnerAppAuthType", + "documentation":"The authorization type that users use to access the SageMaker Partner AI App.
" + }, + "EnableIamSessionBasedIdentity":{ + "shape":"Boolean", + "documentation":"When set to TRUE, the SageMaker Partner AI App sets the Amazon Web Services IAM session name or the authenticated IAM user as the identity of the SageMaker Partner AI App user.
This is an error field object that contains the error code and the reason for an operation failure.
" + } + } + }, "DescribePipelineDefinitionForExecutionRequest":{ "type":"structure", "required":["PipelineExecutionArn"], @@ -16825,6 +17827,86 @@ } } }, + "DescribeTrainingPlanRequest":{ + "type":"structure", + "required":["TrainingPlanName"], + "members":{ + "TrainingPlanName":{ + "shape":"TrainingPlanName", + "documentation":"The name of the training plan to describe.
" + } + } + }, + "DescribeTrainingPlanResponse":{ + "type":"structure", + "required":[ + "TrainingPlanArn", + "TrainingPlanName", + "Status" + ], + "members":{ + "TrainingPlanArn":{ + "shape":"TrainingPlanArn", + "documentation":"The Amazon Resource Name (ARN); of the training plan.
" + }, + "TrainingPlanName":{ + "shape":"TrainingPlanName", + "documentation":"The name of the training plan.
" + }, + "Status":{ + "shape":"TrainingPlanStatus", + "documentation":"The current status of the training plan (e.g., Pending, Active, Expired). To see the complete list of status values available for a training plan, refer to the Status attribute within the TrainingPlanSummary object.
A message providing additional information about the current status of the training plan.
" + }, + "DurationHours":{ + "shape":"TrainingPlanDurationHours", + "documentation":"The number of whole hours in the total duration for this training plan.
" + }, + "DurationMinutes":{ + "shape":"TrainingPlanDurationMinutes", + "documentation":"The additional minutes beyond whole hours in the total duration for this training plan.
" + }, + "StartTime":{ + "shape":"Timestamp", + "documentation":"The start time of the training plan.
" + }, + "EndTime":{ + "shape":"Timestamp", + "documentation":"The end time of the training plan.
" + }, + "UpfrontFee":{ + "shape":"String256", + "documentation":"The upfront fee for the training plan.
" + }, + "CurrencyCode":{ + "shape":"CurrencyCode", + "documentation":"The currency code for the upfront fee (e.g., USD).
" + }, + "TotalInstanceCount":{ + "shape":"TotalInstanceCount", + "documentation":"The total number of instances reserved in this training plan.
" + }, + "AvailableInstanceCount":{ + "shape":"AvailableInstanceCount", + "documentation":"The number of instances currently available for use in this training plan.
" + }, + "InUseInstanceCount":{ + "shape":"InUseInstanceCount", + "documentation":"The number of instances currently in use from this training plan.
" + }, + "TargetResources":{ + "shape":"SageMakerResourceNames", + "documentation":"The target resources (e.g., SageMaker Training Jobs, SageMaker HyperPod) that can use this training plan.
Training plans are specific to their target resource.
A training plan designed for SageMaker training jobs can only be used to schedule and run training jobs.
A training plan for HyperPod clusters can be used exclusively to provide compute resources to a cluster's instance group.
The list of Reserved Capacity providing the underlying compute resources of the plan.
" + } + } + }, "DescribeTransformJobRequest":{ "type":"structure", "required":["TransformJobName"], @@ -18755,6 +19837,20 @@ "max":1024, "pattern":"[\\S\\s]*" }, + "ErrorInfo":{ + "type":"structure", + "members":{ + "Code":{ + "shape":"NonEmptyString64", + "documentation":"The error code for an invalid or failed operation.
" + }, + "Reason":{ + "shape":"NonEmptyString256", + "documentation":"The failure reason for the operation.
" + } + }, + "documentation":"This is an error field object that contains the error code and the reason for an operation failure.
" + }, "ExcludeFeaturesAttribute":{ "type":"string", "max":100 @@ -18952,6 +20048,32 @@ }, "documentation":"A parameter to activate explainers.
" }, + "FSxLustreFileSystem":{ + "type":"structure", + "required":["FileSystemId"], + "members":{ + "FileSystemId":{ + "shape":"FileSystemId", + "documentation":"Amazon FSx for Lustre file system ID.
" + } + }, + "documentation":"A custom file system in Amazon FSx for Lustre.
" + }, + "FSxLustreFileSystemConfig":{ + "type":"structure", + "required":["FileSystemId"], + "members":{ + "FileSystemId":{ + "shape":"FileSystemId", + "documentation":"The globally unique, 17-digit, ID of the file system, assigned by Amazon FSx for Lustre.
" + }, + "FileSystemPath":{ + "shape":"FileSystemPath", + "documentation":"The path to the file system directory that is accessible in Amazon SageMaker Studio. Permitted users can access only this directory and below.
" + } + }, + "documentation":"The settings for assigning a custom Amazon FSx for Lustre file system to a user profile or space for an Amazon SageMaker Domain.
" + }, "FailStepMetadata":{ "type":"structure", "members":{ @@ -18973,6 +20095,18 @@ "type":"string", "max":1024 }, + "FairShare":{ + "type":"string", + "enum":[ + "Enabled", + "Disabled" + ] + }, + "FairShareWeight":{ + "type":"integer", + "max":100, + "min":0 + }, "FeatureAdditions":{ "type":"list", "member":{"shape":"FeatureDefinition"}, @@ -21471,6 +22605,10 @@ } } }, + "InUseInstanceCount":{ + "type":"integer", + "min":0 + }, "InferenceComponentArn":{ "type":"string", "max":2048, @@ -22131,6 +23269,10 @@ "member":{"shape":"TrainingInputMode"}, "min":1 }, + "InstanceCount":{ + "type":"integer", + "min":1 + }, "InstanceGroup":{ "type":"structure", "required":[ @@ -22165,6 +23307,23 @@ "member":{"shape":"InstanceGroupName"}, "max":5 }, + "InstanceGroupStatus":{ + "type":"string", + "enum":[ + "InService", + "Creating", + "Updating", + "Failed", + "Degraded", + "SystemUpdating", + "Deleting" + ] + }, + "InstanceGroupTrainingPlanStatus":{ + "type":"string", + "max":63, + "min":1 + }, "InstanceGroups":{ "type":"list", "member":{"shape":"InstanceGroup"}, @@ -23583,6 +24742,60 @@ } } }, + "ListClusterSchedulerConfigsRequest":{ + "type":"structure", + "members":{ + "CreatedAfter":{ + "shape":"Timestamp", + "documentation":"Filter for after this creation time. The input for this parameter is a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter.
" + }, + "CreatedBefore":{ + "shape":"Timestamp", + "documentation":"Filter for before this creation time. The input for this parameter is a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter.
" + }, + "NameContains":{ + "shape":"EntityName", + "documentation":"Filter for name containing this string.
" + }, + "ClusterArn":{ + "shape":"ClusterArn", + "documentation":"Filter for ARN of the cluster.
" + }, + "Status":{ + "shape":"SchedulerResourceStatus", + "documentation":"Filter for status.
" + }, + "SortBy":{ + "shape":"SortClusterSchedulerConfigBy", + "documentation":"Filter for sorting the list by a given value. For example, sort by name, creation time, or status.
" + }, + "SortOrder":{ + "shape":"SortOrder", + "documentation":"The order of the list. By default, listed in Descending order according to by SortBy. To change the list order, you can specify SortOrder to be Ascending.
If the previous response was truncated, you will receive this token. Use it in your next request to receive the next set of results.
" + }, + "MaxResults":{ + "shape":"MaxResults", + "documentation":"The maximum number of cluster policies to list.
" + } + } + }, + "ListClusterSchedulerConfigsResponse":{ + "type":"structure", + "members":{ + "ClusterSchedulerConfigSummaries":{ + "shape":"ClusterSchedulerConfigSummaryList", + "documentation":"Summaries of the cluster policies.
" + }, + "NextToken":{ + "shape":"NextToken", + "documentation":"If the previous response was truncated, you will receive this token. Use it in your next request to receive the next set of results.
" + } + } + }, "ListClustersRequest":{ "type":"structure", "members":{ @@ -23613,6 +24826,10 @@ "SortOrder":{ "shape":"SortOrder", "documentation":"The sort order for results. The default value is Ascending.
The Amazon Resource Name (ARN) of the training plan to filter clusters by. For more information about reserving GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan .
Filter for after this creation time. The input for this parameter is a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter.
" + }, + "CreatedBefore":{ + "shape":"Timestamp", + "documentation":"Filter for before this creation time. The input for this parameter is a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter.
" + }, + "NameContains":{ + "shape":"EntityName", + "documentation":"Filter for name containing this string.
" + }, + "Status":{ + "shape":"SchedulerResourceStatus", + "documentation":"Filter for status.
" + }, + "ClusterArn":{ + "shape":"ClusterArn", + "documentation":"Filter for ARN of the cluster.
" + }, + "SortBy":{ + "shape":"SortQuotaBy", + "documentation":"Filter for sorting the list by a given value. For example, sort by name, creation time, or status.
" + }, + "SortOrder":{ + "shape":"SortOrder", + "documentation":"The order of the list. By default, listed in Descending order according to by SortBy. To change the list order, you can specify SortOrder to be Ascending.
If the previous response was truncated, you will receive this token. Use it in your next request to receive the next set of results.
" + }, + "MaxResults":{ + "shape":"MaxResults", + "documentation":"The maximum number of compute allocation definitions to list.
" + } + } + }, + "ListComputeQuotasResponse":{ + "type":"structure", + "members":{ + "ComputeQuotaSummaries":{ + "shape":"ComputeQuotaSummaryList", + "documentation":"Summaries of the compute allocation definitions.
" + }, + "NextToken":{ + "shape":"NextToken", + "documentation":"If the previous response was truncated, you will receive this token. Use it in your next request to receive the next set of results.
" + } + } + }, "ListContextsRequest":{ "type":"structure", "members":{ @@ -26160,6 +27431,32 @@ "Status" ] }, + "ListPartnerAppsRequest":{ + "type":"structure", + "members":{ + "MaxResults":{ + "shape":"MaxResults", + "documentation":"This parameter defines the maximum number of results that can be returned in a single response. The MaxResults parameter is an upper bound, not a target. If there are more results available than the value specified, a NextToken is provided in the response. The NextToken indicates that the user should get the next set of results by providing this token as a part of a subsequent call. The default value for MaxResults is 10.
If the previous response was truncated, you will receive this token. Use it in your next request to receive the next set of results.
" + } + } + }, + "ListPartnerAppsResponse":{ + "type":"structure", + "members":{ + "Summaries":{ + "shape":"PartnerAppSummaries", + "documentation":"The information related to each of the SageMaker Partner AI Apps in an account.
" + }, + "NextToken":{ + "shape":"NextToken", + "documentation":"If the previous response was truncated, you will receive this token. Use it in your next request to receive the next set of results.
" + } + } + }, "ListPipelineExecutionStepsRequest":{ "type":"structure", "members":{ @@ -26773,6 +28070,10 @@ "WarmPoolStatusEquals":{ "shape":"WarmPoolResourceStatus", "documentation":"A filter that retrieves only training jobs with a specific warm pool status.
" + }, + "TrainingPlanArnEquals":{ + "shape":"TrainingPlanArn", + "documentation":"The Amazon Resource Name (ARN); of the training plan to filter training jobs by. For more information about reserving GPU capacity for your SageMaker training jobs using Amazon SageMaker Training Plan, see CreateTrainingPlan .
A token to continue pagination if more results are available.
" + }, + "MaxResults":{ + "shape":"MaxResults", + "documentation":"The maximum number of results to return in the response.
", + "box":true + }, + "StartTimeAfter":{ + "shape":"Timestamp", + "documentation":"Filter to list only training plans with an actual start time after this date.
" + }, + "StartTimeBefore":{ + "shape":"Timestamp", + "documentation":"Filter to list only training plans with an actual start time before this date.
" + }, + "SortBy":{ + "shape":"TrainingPlanSortBy", + "documentation":"The training plan field to sort the results by (e.g., StartTime, Status).
" + }, + "SortOrder":{ + "shape":"TrainingPlanSortOrder", + "documentation":"The order to sort the results (Ascending or Descending).
" + }, + "Filters":{ + "shape":"TrainingPlanFilters", + "documentation":"Additional filters to apply to the list of training plans.
" + } + } + }, + "ListTrainingPlansResponse":{ + "type":"structure", + "required":["TrainingPlanSummaries"], + "members":{ + "NextToken":{ + "shape":"NextToken", + "documentation":"A token to continue pagination if more results are available.
" + }, + "TrainingPlanSummaries":{ + "shape":"TrainingPlanSummaries", + "documentation":"A list of summary information for the training plans.
" + } + } + }, "ListTransformJobsRequest":{ "type":"structure", "members":{ @@ -27415,7 +28764,12 @@ "Endpoints", "Projects", "InferenceOptimization", - "PerformanceEvaluation" + "PerformanceEvaluation", + "HyperPodClusters", + "LakeraGuard", + "Comet", + "DeepchecksLLMEvaluation", + "Fiddler" ] }, "MlflowVersion":{ @@ -30936,6 +32290,110 @@ "type":"list", "member":{"shape":"Parent"} }, + "PartnerAppAdminUserList":{ + "type":"list", + "member":{"shape":"NonEmptyString256"}, + "max":5, + "min":0 + }, + "PartnerAppArguments":{ + "type":"map", + "key":{"shape":"NonEmptyString256"}, + "value":{"shape":"String1024"}, + "max":5, + "min":0 + }, + "PartnerAppArn":{ + "type":"string", + "max":128, + "min":1, + "pattern":"arn:aws[a-z\\-]*:sagemaker:[a-z0-9\\-]*:[0-9]{12}:partner-app\\/app-[A-Z0-9]{12}$" + }, + "PartnerAppAuthType":{ + "type":"string", + "enum":["IAM"] + }, + "PartnerAppConfig":{ + "type":"structure", + "members":{ + "AdminUsers":{ + "shape":"PartnerAppAdminUserList", + "documentation":"The list of users that are given admin access to the SageMaker Partner AI App.
" + }, + "Arguments":{ + "shape":"PartnerAppArguments", + "documentation":"This is a map of required inputs for a SageMaker Partner AI App. Based on the application type, the map is populated with a key and value pair that is specific to the user and application.
" + } + }, + "documentation":"Configuration settings for the SageMaker Partner AI App.
" + }, + "PartnerAppMaintenanceConfig":{ + "type":"structure", + "members":{ + "MaintenanceWindowStart":{ + "shape":"WeeklyScheduleTimeFormat", + "documentation":"The day and time of the week in Coordinated Universal Time (UTC) 24-hour standard time that weekly maintenance updates are scheduled. This value must take the following format: 3-letter-day:24-h-hour:minute. For example: TUE:03:30.
Maintenance configuration settings for the SageMaker Partner AI App.
" + }, + "PartnerAppName":{ + "type":"string", + "max":256, + "min":1, + "pattern":"^[a-zA-Z0-9]+" + }, + "PartnerAppStatus":{ + "type":"string", + "enum":[ + "Creating", + "Updating", + "Deleting", + "Available", + "Failed", + "UpdateFailed", + "Deleted" + ] + }, + "PartnerAppSummaries":{ + "type":"list", + "member":{"shape":"PartnerAppSummary"} + }, + "PartnerAppSummary":{ + "type":"structure", + "members":{ + "Arn":{ + "shape":"PartnerAppArn", + "documentation":"The ARN of the SageMaker Partner AI App.
" + }, + "Name":{ + "shape":"PartnerAppName", + "documentation":"The name of the SageMaker Partner AI App.
" + }, + "Type":{ + "shape":"PartnerAppType", + "documentation":"The type of SageMaker Partner AI App to create. Must be one of the following: lakera-guard, comet, deepchecks-llm-evaluation, or fiddler.
The status of the SageMaker Partner AI App.
" + }, + "CreationTime":{ + "shape":"Timestamp", + "documentation":"The creation time of the SageMaker Partner AI App.
" + } + }, + "documentation":"A subset of information related to a SageMaker Partner AI App. This information is used as part of the ListPartnerApps API response.
A specification for a predefined metric.
" }, + "PreemptTeamTasks":{ + "type":"string", + "enum":[ + "Never", + "LowerPriority" + ] + }, "PresignedDomainUrl":{"type":"string"}, + "PriorityClass":{ + "type":"structure", + "required":[ + "Name", + "Weight" + ], + "members":{ + "Name":{ + "shape":"ClusterSchedulerPriorityClassName", + "documentation":"Name of the priority class.
" + }, + "Weight":{ + "shape":"PriorityWeight", + "documentation":"Weight of the priority class. The value is within a range from 0 to 100, where 0 is the default.
A weight of 0 is the lowest priority and 100 is the highest. Weight 0 is the default.
" + } + }, + "documentation":"Priority class configuration. When included in PriorityClasses, these class configurations define how tasks are queued.
The instance type for the reserved capacity offering.
" + }, + "InstanceCount":{ + "shape":"ReservedCapacityInstanceCount", + "documentation":"The number of instances in the reserved capacity offering.
" + }, + "AvailabilityZone":{ + "shape":"AvailabilityZone", + "documentation":"The availability zone for the reserved capacity offering.
" + }, + "DurationHours":{ + "shape":"ReservedCapacityDurationHours", + "documentation":"The number of whole hours in the total duration for this reserved capacity offering.
" + }, + "DurationMinutes":{ + "shape":"ReservedCapacityDurationMinutes", + "documentation":"The additional minutes beyond whole hours in the total duration for this reserved capacity offering.
" + }, + "StartTime":{ + "shape":"Timestamp", + "documentation":"The start time of the reserved capacity offering.
" + }, + "EndTime":{ + "shape":"Timestamp", + "documentation":"The end time of the reserved capacity offering.
" + } + }, + "documentation":"Details about a reserved capacity offering for a training plan offering.
For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan .
The Amazon Resource Name (ARN) of the reserved capacity.
" + }, + "InstanceType":{ + "shape":"ReservedCapacityInstanceType", + "documentation":"The instance type for the reserved capacity.
" + }, + "TotalInstanceCount":{ + "shape":"TotalInstanceCount", + "documentation":"The total number of instances in the reserved capacity.
" + }, + "Status":{ + "shape":"ReservedCapacityStatus", + "documentation":"The current status of the reserved capacity.
" + }, + "AvailabilityZone":{ + "shape":"AvailabilityZone", + "documentation":"The availability zone for the reserved capacity.
" + }, + "DurationHours":{ + "shape":"ReservedCapacityDurationHours", + "documentation":"The number of whole hours in the total duration for this reserved capacity.
" + }, + "DurationMinutes":{ + "shape":"ReservedCapacityDurationMinutes", + "documentation":"The additional minutes beyond whole hours in the total duration for this reserved capacity.
" + }, + "StartTime":{ + "shape":"Timestamp", + "documentation":"The start time of the reserved capacity.
" + }, + "EndTime":{ + "shape":"Timestamp", + "documentation":"The end time of the reserved capacity.
" + } + }, + "documentation":"Details of a reserved capacity for the training plan.
For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan .
The configuration of a heterogeneous cluster in JSON format.
" + }, + "TrainingPlanArn":{ + "shape":"TrainingPlanArn", + "documentation":"The Amazon Resource Name (ARN); of the training plan to use for this resource configuration.
" } }, "documentation":"Describes the resources, including machine learning (ML) compute instances and ML storage volumes, to use for model training.
" @@ -33941,6 +35578,29 @@ "documentation":"Optional. Indicates how many seconds the resource stayed in ResourceRetained state. Populated only after resource reaches ResourceReused or ResourceReleased state.", "min":0 }, + "ResourceSharingConfig":{ + "type":"structure", + "required":["Strategy"], + "members":{ + "Strategy":{ + "shape":"ResourceSharingStrategy", + "documentation":"The strategy of how idle compute is shared within the cluster. The following are the options of strategies.
DontLend: entities do not lend idle compute.
Lend: entities can lend idle compute to entities that can borrow.
LendAndBorrow: entities can lend idle compute and borrow idle compute from other entities.
Default is LendAndBorrow.
The limit on how much idle compute can be borrowed. The values can be 1 - 500 percent of idle compute that the team is allowed to borrow.
Default is 50.
Resource sharing configuration.
" + }, + "ResourceSharingStrategy":{ + "type":"string", + "enum":[ + "Lend", + "DontLend", + "LendAndBorrow" + ] + }, "ResourceSpec":{ "type":"structure", "members":{ @@ -34277,6 +35937,18 @@ "max":255, "pattern":"^arn:[a-z0-9-\\.]{1,63}:sagemaker:\\w+(?:-\\w+)+:aws:hub-content\\/[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}\\/Model\\/[a-zA-Z0-9](-*[a-zA-Z0-9]){0,63}$" }, + "SageMakerResourceName":{ + "type":"string", + "enum":[ + "training-job", + "hyperpod-cluster" + ] + }, + "SageMakerResourceNames":{ + "type":"list", + "member":{"shape":"SageMakerResourceName"}, + "min":1 + }, "SagemakerServicecatalogStatus":{ "type":"string", "enum":[ @@ -34371,6 +36043,37 @@ "Stopped" ] }, + "SchedulerConfig":{ + "type":"structure", + "members":{ + "PriorityClasses":{ + "shape":"PriorityClassList", + "documentation":"List of the priority classes, PriorityClass, of the cluster policy. When specified, these class configurations define how tasks are queued.
When enabled, entities borrow idle compute based on their assigned FairShareWeight.
When disabled, entities borrow idle compute based on a first-come first-serve basis.
Default is Enabled.
Cluster policy configuration. This policy is used for task prioritization and fair-share allocation. This helps prioritize critical workloads and distributes idle compute across entities.
" + }, + "SchedulerResourceStatus":{ + "type":"string", + "enum":[ + "Creating", + "CreateFailed", + "CreateRollbackFailed", + "Created", + "Updating", + "UpdateFailed", + "UpdateRollbackFailed", + "Updated", + "Deleting", + "DeleteFailed", + "DeleteRollbackFailed", + "Deleted" + ] + }, "Scope":{ "type":"string", "max":1024, @@ -34512,6 +36215,50 @@ "Descending" ] }, + "SearchTrainingPlanOfferingsRequest":{ + "type":"structure", + "required":[ + "InstanceType", + "InstanceCount", + "TargetResources" + ], + "members":{ + "InstanceType":{ + "shape":"ReservedCapacityInstanceType", + "documentation":"The type of instance you want to search for in the available training plan offerings. This field allows you to filter the search results based on the specific compute resources you require for your SageMaker training jobs or SageMaker HyperPod clusters. When searching for training plan offerings, specifying the instance type helps you find Reserved Instances that match your computational needs.
" + }, + "InstanceCount":{ + "shape":"ReservedCapacityInstanceCount", + "documentation":"The number of instances you want to reserve in the training plan offerings. This allows you to specify the quantity of compute resources needed for your SageMaker training jobs or SageMaker HyperPod clusters, helping you find reserved capacity offerings that match your requirements.
" + }, + "StartTimeAfter":{ + "shape":"Timestamp", + "documentation":"A filter to search for training plan offerings with a start time after a specified date.
" + }, + "EndTimeBefore":{ + "shape":"Timestamp", + "documentation":"A filter to search for reserved capacity offerings with an end time before a specified date.
" + }, + "DurationHours":{ + "shape":"TrainingPlanDurationHoursInput", + "documentation":"The desired duration in hours for the training plan offerings.
" + }, + "TargetResources":{ + "shape":"SageMakerResourceNames", + "documentation":"The target resources (e.g., SageMaker Training Jobs, SageMaker HyperPod) to search for in the offerings.
Training plans are specific to their target resource.
A training plan designed for SageMaker training jobs can only be used to schedule and run training jobs.
A training plan for HyperPod clusters can be used exclusively to provide compute resources to a cluster's instance group.
A list of training plan offerings that match the search criteria.
" + } + } + }, "SecondaryStatus":{ "type":"string", "enum":[ @@ -34893,6 +36640,14 @@ "Status" ] }, + "SortClusterSchedulerConfigBy":{ + "type":"string", + "enum":[ + "Name", + "CreationTime", + "Status" + ] + }, "SortContextsBy":{ "type":"string", "enum":[ @@ -34943,6 +36698,15 @@ "CreationTime" ] }, + "SortQuotaBy":{ + "type":"string", + "enum":[ + "Name", + "CreationTime", + "Status", + "ClusterArn" + ] + }, "SortTrackingServerBy":{ "type":"string", "enum":[ @@ -35700,6 +37464,10 @@ "min":1, "pattern":".+" }, + "String2048":{ + "type":"string", + "max":2048 + }, "String256":{ "type":"string", "max":256 @@ -36424,6 +38192,10 @@ "max":256, "min":1 }, + "TotalInstanceCount":{ + "type":"integer", + "min":0 + }, "TrackingServerArn":{ "type":"string", "max":2048, @@ -36673,6 +38445,7 @@ "ml.p4de.24xlarge", "ml.p5.48xlarge", "ml.p5e.48xlarge", + "ml.p5en.48xlarge", "ml.c5.xlarge", "ml.c5.2xlarge", "ml.c5.4xlarge", @@ -36710,6 +38483,7 @@ "ml.trn1.2xlarge", "ml.trn1.32xlarge", "ml.trn1n.32xlarge", + "ml.trn2.48xlarge", "ml.m6i.large", "ml.m6i.xlarge", "ml.m6i.2xlarge", @@ -37051,10 +38825,235 @@ "WarmPoolStatus":{ "shape":"WarmPoolStatus", "documentation":"The status of the warm pool associated with the training job.
" + }, + "TrainingPlanArn":{ + "shape":"TrainingPlanArn", + "documentation":"The Amazon Resource Name (ARN); of the training plan associated with this training job.
For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan .
Provides summary information about a training job.
" }, + "TrainingPlanArn":{ + "type":"string", + "max":2048, + "min":50, + "pattern":"arn:aws[a-z\\-]*:sagemaker:[a-z0-9\\-]*:[0-9]{12}:training-plan/.*" + }, + "TrainingPlanArns":{ + "type":"list", + "member":{"shape":"TrainingPlanArn"} + }, + "TrainingPlanDurationHours":{ + "type":"long", + "max":87600, + "min":0 + }, + "TrainingPlanDurationHoursInput":{ + "type":"long", + "max":87600, + "min":1 + }, + "TrainingPlanDurationMinutes":{ + "type":"long", + "max":59, + "min":0 + }, + "TrainingPlanFilter":{ + "type":"structure", + "required":[ + "Name", + "Value" + ], + "members":{ + "Name":{ + "shape":"TrainingPlanFilterName", + "documentation":"The name of the filter field (e.g., Status, InstanceType).
" + }, + "Value":{ + "shape":"String64", + "documentation":"The value to filter by for the specified field.
" + } + }, + "documentation":"A filter to apply when listing or searching for training plans.
For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan .
The unique identifier for this training plan offering.
" + }, + "TargetResources":{ + "shape":"SageMakerResourceNames", + "documentation":"The target resources (e.g., SageMaker Training Jobs, SageMaker HyperPod) for this training plan offering.
Training plans are specific to their target resource.
A training plan designed for SageMaker training jobs can only be used to schedule and run training jobs.
A training plan for HyperPod clusters can be used exclusively to provide compute resources to a cluster's instance group.
The requested start time that the user specified when searching for the training plan offering.
" + }, + "RequestedEndTimeBefore":{ + "shape":"Timestamp", + "documentation":"The requested end time that the user specified when searching for the training plan offering.
" + }, + "DurationHours":{ + "shape":"TrainingPlanDurationHours", + "documentation":"The number of whole hours in the total duration for this training plan offering.
" + }, + "DurationMinutes":{ + "shape":"TrainingPlanDurationMinutes", + "documentation":"The additional minutes beyond whole hours in the total duration for this training plan offering.
" + }, + "UpfrontFee":{ + "shape":"String256", + "documentation":"The upfront fee for this training plan offering.
" + }, + "CurrencyCode":{ + "shape":"CurrencyCode", + "documentation":"The currency code for the upfront fee (e.g., USD).
" + }, + "ReservedCapacityOfferings":{ + "shape":"ReservedCapacityOfferings", + "documentation":"A list of reserved capacity offerings associated with this training plan offering.
" + } + }, + "documentation":"Details about a training plan offering.
For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan .
The Amazon Resource Name (ARN); of the training plan.
" + }, + "TrainingPlanName":{ + "shape":"TrainingPlanName", + "documentation":"The name of the training plan.
" + }, + "Status":{ + "shape":"TrainingPlanStatus", + "documentation":"The current status of the training plan (e.g., Pending, Active, Expired). To see the complete list of status values available for a training plan, refer to the Status attribute within the TrainingPlanSummary object.
A message providing additional information about the current status of the training plan.
" + }, + "DurationHours":{ + "shape":"TrainingPlanDurationHours", + "documentation":"The number of whole hours in the total duration for this training plan.
" + }, + "DurationMinutes":{ + "shape":"TrainingPlanDurationMinutes", + "documentation":"The additional minutes beyond whole hours in the total duration for this training plan.
" + }, + "StartTime":{ + "shape":"Timestamp", + "documentation":"The start time of the training plan.
" + }, + "EndTime":{ + "shape":"Timestamp", + "documentation":"The end time of the training plan.
" + }, + "UpfrontFee":{ + "shape":"String256", + "documentation":"The upfront fee for the training plan.
" + }, + "CurrencyCode":{ + "shape":"CurrencyCode", + "documentation":"The currency code for the upfront fee (e.g., USD).
" + }, + "TotalInstanceCount":{ + "shape":"TotalInstanceCount", + "documentation":"The total number of instances reserved in this training plan.
" + }, + "AvailableInstanceCount":{ + "shape":"AvailableInstanceCount", + "documentation":"The number of instances currently available for use in this training plan.
" + }, + "InUseInstanceCount":{ + "shape":"InUseInstanceCount", + "documentation":"The number of instances currently in use from this training plan.
" + }, + "TargetResources":{ + "shape":"SageMakerResourceNames", + "documentation":"The target resources (e.g., training jobs, HyperPod clusters) that can use this training plan.
Training plans are specific to their target resource.
A training plan designed for SageMaker training jobs can only be used to schedule and run training jobs.
A training plan for HyperPod clusters can be used exclusively to provide compute resources to a cluster's instance group.
A list of reserved capacities associated with this training plan, including details such as instance types, counts, and availability zones.
" + } + }, + "documentation":"Details of the training plan.
For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan .
ID of the cluster policy.
" + }, + "TargetVersion":{ + "shape":"Integer", + "documentation":"Target version.
" + }, + "SchedulerConfig":{ + "shape":"SchedulerConfig", + "documentation":"Cluster policy configuration.
" + }, + "Description":{ + "shape":"EntityDescription", + "documentation":"Description of the cluster policy.
" + } + } + }, + "UpdateClusterSchedulerConfigResponse":{ + "type":"structure", + "required":[ + "ClusterSchedulerConfigArn", + "ClusterSchedulerConfigVersion" + ], + "members":{ + "ClusterSchedulerConfigArn":{ + "shape":"ClusterSchedulerConfigArn", + "documentation":"ARN of the cluster policy.
" + }, + "ClusterSchedulerConfigVersion":{ + "shape":"Integer", + "documentation":"Version of the cluster policy.
" + } + } + }, "UpdateClusterSoftwareRequest":{ "type":"structure", "required":["ClusterName"], @@ -38281,6 +40322,56 @@ } } }, + "UpdateComputeQuotaRequest":{ + "type":"structure", + "required":[ + "ComputeQuotaId", + "TargetVersion" + ], + "members":{ + "ComputeQuotaId":{ + "shape":"ComputeQuotaId", + "documentation":"ID of the compute allocation definition.
" + }, + "TargetVersion":{ + "shape":"Integer", + "documentation":"Target version.
" + }, + "ComputeQuotaConfig":{ + "shape":"ComputeQuotaConfig", + "documentation":"Configuration of the compute allocation definition. This includes the resource sharing option, and the setting to preempt low priority tasks.
" + }, + "ComputeQuotaTarget":{ + "shape":"ComputeQuotaTarget", + "documentation":"The target entity to allocate compute resources to.
" + }, + "ActivationState":{ + "shape":"ActivationState", + "documentation":"The state of the compute allocation being described. Use to enable or disable compute allocation.
Default is Enabled.
Description of the compute allocation definition.
" + } + } + }, + "UpdateComputeQuotaResponse":{ + "type":"structure", + "required":[ + "ComputeQuotaArn", + "ComputeQuotaVersion" + ], + "members":{ + "ComputeQuotaArn":{ + "shape":"ComputeQuotaArn", + "documentation":"ARN of the compute allocation definition.
" + }, + "ComputeQuotaVersion":{ + "shape":"Integer", + "documentation":"Version of the compute allocation definition.
" + } + } + }, "UpdateContextRequest":{ "type":"structure", "required":["ContextName"], @@ -39066,6 +41157,51 @@ "members":{ } }, + "UpdatePartnerAppRequest":{ + "type":"structure", + "required":["Arn"], + "members":{ + "Arn":{ + "shape":"PartnerAppArn", + "documentation":"The ARN of the SageMaker Partner AI App to update.
" + }, + "MaintenanceConfig":{ + "shape":"PartnerAppMaintenanceConfig", + "documentation":"Maintenance configuration settings for the SageMaker Partner AI App.
" + }, + "Tier":{ + "shape":"NonEmptyString64", + "documentation":"Indicates the instance type and size of the cluster attached to the SageMaker Partner AI App.
" + }, + "ApplicationConfig":{ + "shape":"PartnerAppConfig", + "documentation":"Configuration settings for the SageMaker Partner AI App.
" + }, + "EnableIamSessionBasedIdentity":{ + "shape":"Boolean", + "documentation":"When set to TRUE, the SageMaker Partner AI App sets the Amazon Web Services IAM session name or the authenticated IAM user as the identity of the SageMaker Partner AI App user.
A unique token that guarantees that the call to this API is idempotent.
", + "idempotencyToken":true + }, + "Tags":{ + "shape":"TagList", + "documentation":"Each tag consists of a key and an optional value. Tag keys must be unique per resource.
" + } + } + }, + "UpdatePartnerAppResponse":{ + "type":"structure", + "members":{ + "Arn":{ + "shape":"PartnerAppArn", + "documentation":"The ARN of the SageMaker Partner AI App that was updated.
" + } + } + }, "UpdatePipelineExecutionRequest":{ "type":"structure", "required":["PipelineExecutionArn"], @@ -39820,6 +41956,11 @@ "max":9, "pattern":"(Mon|Tue|Wed|Thu|Fri|Sat|Sun):([01]\\d|2[0-3]):([0-5]\\d)" }, + "WeeklyScheduleTimeFormat":{ + "type":"string", + "max":9, + "pattern":"(Mon|Tue|Wed|Thu|Fri|Sat|Sun):([01]\\d|2[0-3]):([0-5]\\d)" + }, "WorkerAccessConfiguration":{ "type":"structure", "members":{ diff --git a/src/sagemaker_core/main/code_injection/shape_dag.py b/src/sagemaker_core/main/code_injection/shape_dag.py index e7bbcd1c..c249266b 100644 --- a/src/sagemaker_core/main/code_injection/shape_dag.py +++ b/src/sagemaker_core/main/code_injection/shape_dag.py @@ -1263,6 +1263,13 @@ "type": "list", }, {"name": "OnStartDeepHealthChecks", "shape": "OnStartDeepHealthChecks", "type": "list"}, + {"name": "Status", "shape": "InstanceGroupStatus", "type": "string"}, + {"name": "TrainingPlanArn", "shape": "TrainingPlanArn", "type": "string"}, + { + "name": "TrainingPlanStatus", + "shape": "InstanceGroupTrainingPlanStatus", + "type": "string", + }, {"name": "OverrideVpcConfig", "shape": "VpcConfig", "type": "structure"}, ], "type": "structure", @@ -1286,6 +1293,7 @@ "type": "list", }, {"name": "OnStartDeepHealthChecks", "shape": "OnStartDeepHealthChecks", "type": "list"}, + {"name": "TrainingPlanArn", "shape": "TrainingPlanArn", "type": "string"}, {"name": "OverrideVpcConfig", "shape": "VpcConfig", "type": "structure"}, ], "type": "structure", @@ -1380,6 +1388,32 @@ "members": [{"name": "ClusterArn", "shape": "EksClusterArn", "type": "string"}], "type": "structure", }, + "ClusterSchedulerConfigSummary": { + "members": [ + { + "name": "ClusterSchedulerConfigArn", + "shape": "ClusterSchedulerConfigArn", + "type": "string", + }, + { + "name": "ClusterSchedulerConfigId", + "shape": "ClusterSchedulerConfigId", + "type": "string", + }, + {"name": "ClusterSchedulerConfigVersion", "shape": "Integer", "type": "integer"}, + {"name": "Name", "shape": "EntityName", "type": 
"string"}, + {"name": "CreationTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "LastModifiedTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "Status", "shape": "SchedulerResourceStatus", "type": "string"}, + {"name": "ClusterArn", "shape": "ClusterArn", "type": "string"}, + ], + "type": "structure", + }, + "ClusterSchedulerConfigSummaryList": { + "member_shape": "ClusterSchedulerConfigSummary", + "member_type": "structure", + "type": "list", + }, "ClusterSummaries": { "member_shape": "ClusterSummary", "member_type": "structure", @@ -1391,6 +1425,7 @@ {"name": "ClusterName", "shape": "ClusterName", "type": "string"}, {"name": "CreationTime", "shape": "Timestamp", "type": "timestamp"}, {"name": "ClusterStatus", "shape": "ClusterStatus", "type": "string"}, + {"name": "TrainingPlanArns", "shape": "TrainingPlanArns", "type": "list"}, ], "type": "structure", }, @@ -1515,6 +1550,62 @@ "member_type": "string", "type": "list", }, + "ComputeQuotaConfig": { + "members": [ + { + "name": "ComputeQuotaResources", + "shape": "ComputeQuotaResourceConfigList", + "type": "list", + }, + { + "name": "ResourceSharingConfig", + "shape": "ResourceSharingConfig", + "type": "structure", + }, + {"name": "PreemptTeamTasks", "shape": "PreemptTeamTasks", "type": "string"}, + ], + "type": "structure", + }, + "ComputeQuotaResourceConfig": { + "members": [ + {"name": "InstanceType", "shape": "ClusterInstanceType", "type": "string"}, + {"name": "Count", "shape": "InstanceCount", "type": "integer"}, + ], + "type": "structure", + }, + "ComputeQuotaResourceConfigList": { + "member_shape": "ComputeQuotaResourceConfig", + "member_type": "structure", + "type": "list", + }, + "ComputeQuotaSummary": { + "members": [ + {"name": "ComputeQuotaArn", "shape": "ComputeQuotaArn", "type": "string"}, + {"name": "ComputeQuotaId", "shape": "ComputeQuotaId", "type": "string"}, + {"name": "Name", "shape": "EntityName", "type": "string"}, + {"name": "ComputeQuotaVersion", "shape": "Integer", 
"type": "integer"}, + {"name": "Status", "shape": "SchedulerResourceStatus", "type": "string"}, + {"name": "ClusterArn", "shape": "ClusterArn", "type": "string"}, + {"name": "ComputeQuotaConfig", "shape": "ComputeQuotaConfig", "type": "structure"}, + {"name": "ComputeQuotaTarget", "shape": "ComputeQuotaTarget", "type": "structure"}, + {"name": "ActivationState", "shape": "ActivationState", "type": "string"}, + {"name": "CreationTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "LastModifiedTime", "shape": "Timestamp", "type": "timestamp"}, + ], + "type": "structure", + }, + "ComputeQuotaSummaryList": { + "member_shape": "ComputeQuotaSummary", + "member_type": "structure", + "type": "list", + }, + "ComputeQuotaTarget": { + "members": [ + {"name": "TeamName", "shape": "ComputeQuotaTargetTeamName", "type": "string"}, + {"name": "FairShareWeight", "shape": "FairShareWeight", "type": "integer"}, + ], + "type": "structure", + }, "ConditionStepMetadata": { "members": [{"name": "Outcome", "shape": "ConditionOutcome", "type": "string"}], "type": "structure", @@ -1811,6 +1902,31 @@ "members": [{"name": "ClusterArn", "shape": "ClusterArn", "type": "string"}], "type": "structure", }, + "CreateClusterSchedulerConfigRequest": { + "members": [ + {"name": "Name", "shape": "EntityName", "type": "string"}, + {"name": "ClusterArn", "shape": "ClusterArn", "type": "string"}, + {"name": "SchedulerConfig", "shape": "SchedulerConfig", "type": "structure"}, + {"name": "Description", "shape": "EntityDescription", "type": "string"}, + {"name": "Tags", "shape": "TagList", "type": "list"}, + ], + "type": "structure", + }, + "CreateClusterSchedulerConfigResponse": { + "members": [ + { + "name": "ClusterSchedulerConfigArn", + "shape": "ClusterSchedulerConfigArn", + "type": "string", + }, + { + "name": "ClusterSchedulerConfigId", + "shape": "ClusterSchedulerConfigId", + "type": "string", + }, + ], + "type": "structure", + }, "CreateCodeRepositoryInput": { "members": [ {"name": 
"CodeRepositoryName", "shape": "EntityName", "type": "string"}, @@ -1840,6 +1956,25 @@ "members": [{"name": "CompilationJobArn", "shape": "CompilationJobArn", "type": "string"}], "type": "structure", }, + "CreateComputeQuotaRequest": { + "members": [ + {"name": "Name", "shape": "EntityName", "type": "string"}, + {"name": "Description", "shape": "EntityDescription", "type": "string"}, + {"name": "ClusterArn", "shape": "ClusterArn", "type": "string"}, + {"name": "ComputeQuotaConfig", "shape": "ComputeQuotaConfig", "type": "structure"}, + {"name": "ComputeQuotaTarget", "shape": "ComputeQuotaTarget", "type": "structure"}, + {"name": "ActivationState", "shape": "ActivationState", "type": "string"}, + {"name": "Tags", "shape": "TagList", "type": "list"}, + ], + "type": "structure", + }, + "CreateComputeQuotaResponse": { + "members": [ + {"name": "ComputeQuotaArn", "shape": "ComputeQuotaArn", "type": "string"}, + {"name": "ComputeQuotaId", "shape": "ComputeQuotaId", "type": "string"}, + ], + "type": "structure", + }, "CreateContextRequest": { "members": [ {"name": "ContextName", "shape": "ContextName", "type": "string"}, @@ -2648,6 +2783,45 @@ ], "type": "structure", }, + "CreatePartnerAppPresignedUrlRequest": { + "members": [ + {"name": "Arn", "shape": "PartnerAppArn", "type": "string"}, + {"name": "ExpiresInSeconds", "shape": "ExpiresInSeconds", "type": "integer"}, + { + "name": "SessionExpirationDurationInSeconds", + "shape": "SessionExpirationDurationInSeconds", + "type": "integer", + }, + ], + "type": "structure", + }, + "CreatePartnerAppPresignedUrlResponse": { + "members": [{"name": "Url", "shape": "String2048", "type": "string"}], + "type": "structure", + }, + "CreatePartnerAppRequest": { + "members": [ + {"name": "Name", "shape": "PartnerAppName", "type": "string"}, + {"name": "Type", "shape": "PartnerAppType", "type": "string"}, + {"name": "ExecutionRoleArn", "shape": "RoleArn", "type": "string"}, + { + "name": "MaintenanceConfig", + "shape": 
"PartnerAppMaintenanceConfig", + "type": "structure", + }, + {"name": "Tier", "shape": "NonEmptyString64", "type": "string"}, + {"name": "ApplicationConfig", "shape": "PartnerAppConfig", "type": "structure"}, + {"name": "AuthType", "shape": "PartnerAppAuthType", "type": "string"}, + {"name": "EnableIamSessionBasedIdentity", "shape": "Boolean", "type": "boolean"}, + {"name": "ClientToken", "shape": "ClientToken", "type": "string"}, + {"name": "Tags", "shape": "TagList", "type": "list"}, + ], + "type": "structure", + }, + "CreatePartnerAppResponse": { + "members": [{"name": "Arn", "shape": "PartnerAppArn", "type": "string"}], + "type": "structure", + }, "CreatePipelineRequest": { "members": [ {"name": "PipelineName", "shape": "PipelineName", "type": "string"}, @@ -2873,6 +3047,18 @@ "members": [{"name": "TrainingJobArn", "shape": "TrainingJobArn", "type": "string"}], "type": "structure", }, + "CreateTrainingPlanRequest": { + "members": [ + {"name": "TrainingPlanName", "shape": "TrainingPlanName", "type": "string"}, + {"name": "TrainingPlanOfferingId", "shape": "TrainingPlanOfferingId", "type": "string"}, + {"name": "Tags", "shape": "TagList", "type": "list"}, + ], + "type": "structure", + }, + "CreateTrainingPlanResponse": { + "members": [{"name": "TrainingPlanArn", "shape": "TrainingPlanArn", "type": "string"}], + "type": "structure", + }, "CreateTransformJobRequest": { "members": [ {"name": "TransformJobName", "shape": "TransformJobName", "type": "string"}, @@ -2997,12 +3183,20 @@ }, "CsvContentTypes": {"member_shape": "CsvContentType", "member_type": "string", "type": "list"}, "CustomFileSystem": { - "members": [{"name": "EFSFileSystem", "shape": "EFSFileSystem", "type": "structure"}], + "members": [ + {"name": "EFSFileSystem", "shape": "EFSFileSystem", "type": "structure"}, + {"name": "FSxLustreFileSystem", "shape": "FSxLustreFileSystem", "type": "structure"}, + ], "type": "structure", }, "CustomFileSystemConfig": { "members": [ - {"name": "EFSFileSystemConfig", 
"shape": "EFSFileSystemConfig", "type": "structure"} + {"name": "EFSFileSystemConfig", "shape": "EFSFileSystemConfig", "type": "structure"}, + { + "name": "FSxLustreFileSystemConfig", + "shape": "FSxLustreFileSystemConfig", + "type": "structure", + }, ], "type": "structure", }, @@ -3333,6 +3527,16 @@ "members": [{"name": "ClusterArn", "shape": "ClusterArn", "type": "string"}], "type": "structure", }, + "DeleteClusterSchedulerConfigRequest": { + "members": [ + { + "name": "ClusterSchedulerConfigId", + "shape": "ClusterSchedulerConfigId", + "type": "string", + } + ], + "type": "structure", + }, "DeleteCodeRepositoryInput": { "members": [{"name": "CodeRepositoryName", "shape": "EntityName", "type": "string"}], "type": "structure", @@ -3341,6 +3545,10 @@ "members": [{"name": "CompilationJobName", "shape": "EntityName", "type": "string"}], "type": "structure", }, + "DeleteComputeQuotaRequest": { + "members": [{"name": "ComputeQuotaId", "shape": "ComputeQuotaId", "type": "string"}], + "type": "structure", + }, "DeleteContextRequest": { "members": [{"name": "ContextName", "shape": "ContextName", "type": "string"}], "type": "structure", @@ -3546,6 +3754,17 @@ "members": [{"name": "OptimizationJobName", "shape": "EntityName", "type": "string"}], "type": "structure", }, + "DeletePartnerAppRequest": { + "members": [ + {"name": "Arn", "shape": "PartnerAppArn", "type": "string"}, + {"name": "ClientToken", "shape": "ClientToken", "type": "string"}, + ], + "type": "structure", + }, + "DeletePartnerAppResponse": { + "members": [{"name": "Arn", "shape": "PartnerAppArn", "type": "string"}], + "type": "structure", + }, "DeletePipelineRequest": { "members": [ {"name": "PipelineName", "shape": "PipelineName", "type": "string"}, @@ -3985,6 +4204,43 @@ ], "type": "structure", }, + "DescribeClusterSchedulerConfigRequest": { + "members": [ + { + "name": "ClusterSchedulerConfigId", + "shape": "ClusterSchedulerConfigId", + "type": "string", + }, + {"name": "ClusterSchedulerConfigVersion", 
"shape": "Integer", "type": "integer"}, + ], + "type": "structure", + }, + "DescribeClusterSchedulerConfigResponse": { + "members": [ + { + "name": "ClusterSchedulerConfigArn", + "shape": "ClusterSchedulerConfigArn", + "type": "string", + }, + { + "name": "ClusterSchedulerConfigId", + "shape": "ClusterSchedulerConfigId", + "type": "string", + }, + {"name": "Name", "shape": "EntityName", "type": "string"}, + {"name": "ClusterSchedulerConfigVersion", "shape": "Integer", "type": "integer"}, + {"name": "Status", "shape": "SchedulerResourceStatus", "type": "string"}, + {"name": "FailureReason", "shape": "FailureReason", "type": "string"}, + {"name": "ClusterArn", "shape": "ClusterArn", "type": "string"}, + {"name": "SchedulerConfig", "shape": "SchedulerConfig", "type": "structure"}, + {"name": "Description", "shape": "EntityDescription", "type": "string"}, + {"name": "CreationTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "CreatedBy", "shape": "UserContext", "type": "structure"}, + {"name": "LastModifiedTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "LastModifiedBy", "shape": "UserContext", "type": "structure"}, + ], + "type": "structure", + }, "DescribeCodeRepositoryInput": { "members": [{"name": "CodeRepositoryName", "shape": "EntityName", "type": "string"}], "type": "structure", @@ -4026,6 +4282,33 @@ ], "type": "structure", }, + "DescribeComputeQuotaRequest": { + "members": [ + {"name": "ComputeQuotaId", "shape": "ComputeQuotaId", "type": "string"}, + {"name": "ComputeQuotaVersion", "shape": "Integer", "type": "integer"}, + ], + "type": "structure", + }, + "DescribeComputeQuotaResponse": { + "members": [ + {"name": "ComputeQuotaArn", "shape": "ComputeQuotaArn", "type": "string"}, + {"name": "ComputeQuotaId", "shape": "ComputeQuotaId", "type": "string"}, + {"name": "Name", "shape": "EntityName", "type": "string"}, + {"name": "Description", "shape": "EntityDescription", "type": "string"}, + {"name": "ComputeQuotaVersion", "shape": 
"Integer", "type": "integer"}, + {"name": "Status", "shape": "SchedulerResourceStatus", "type": "string"}, + {"name": "FailureReason", "shape": "FailureReason", "type": "string"}, + {"name": "ClusterArn", "shape": "ClusterArn", "type": "string"}, + {"name": "ComputeQuotaConfig", "shape": "ComputeQuotaConfig", "type": "structure"}, + {"name": "ComputeQuotaTarget", "shape": "ComputeQuotaTarget", "type": "structure"}, + {"name": "ActivationState", "shape": "ActivationState", "type": "string"}, + {"name": "CreationTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "CreatedBy", "shape": "UserContext", "type": "structure"}, + {"name": "LastModifiedTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "LastModifiedBy", "shape": "UserContext", "type": "structure"}, + ], + "type": "structure", + }, "DescribeContextRequest": { "members": [{"name": "ContextName", "shape": "ContextNameOrArn", "type": "string"}], "type": "structure", @@ -5188,6 +5471,33 @@ ], "type": "structure", }, + "DescribePartnerAppRequest": { + "members": [{"name": "Arn", "shape": "PartnerAppArn", "type": "string"}], + "type": "structure", + }, + "DescribePartnerAppResponse": { + "members": [ + {"name": "Arn", "shape": "PartnerAppArn", "type": "string"}, + {"name": "Name", "shape": "PartnerAppName", "type": "string"}, + {"name": "Type", "shape": "PartnerAppType", "type": "string"}, + {"name": "Status", "shape": "PartnerAppStatus", "type": "string"}, + {"name": "CreationTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "ExecutionRoleArn", "shape": "RoleArn", "type": "string"}, + {"name": "BaseUrl", "shape": "String2048", "type": "string"}, + { + "name": "MaintenanceConfig", + "shape": "PartnerAppMaintenanceConfig", + "type": "structure", + }, + {"name": "Tier", "shape": "NonEmptyString64", "type": "string"}, + {"name": "Version", "shape": "NonEmptyString64", "type": "string"}, + {"name": "ApplicationConfig", "shape": "PartnerAppConfig", "type": "structure"}, + {"name": 
"AuthType", "shape": "PartnerAppAuthType", "type": "string"}, + {"name": "EnableIamSessionBasedIdentity", "shape": "Boolean", "type": "boolean"}, + {"name": "Error", "shape": "ErrorInfo", "type": "structure"}, + ], + "type": "structure", + }, "DescribePipelineDefinitionForExecutionRequest": { "members": [ {"name": "PipelineExecutionArn", "shape": "PipelineExecutionArn", "type": "string"} @@ -5493,6 +5803,38 @@ ], "type": "structure", }, + "DescribeTrainingPlanRequest": { + "members": [{"name": "TrainingPlanName", "shape": "TrainingPlanName", "type": "string"}], + "type": "structure", + }, + "DescribeTrainingPlanResponse": { + "members": [ + {"name": "TrainingPlanArn", "shape": "TrainingPlanArn", "type": "string"}, + {"name": "TrainingPlanName", "shape": "TrainingPlanName", "type": "string"}, + {"name": "Status", "shape": "TrainingPlanStatus", "type": "string"}, + {"name": "StatusMessage", "shape": "TrainingPlanStatusMessage", "type": "string"}, + {"name": "DurationHours", "shape": "TrainingPlanDurationHours", "type": "long"}, + {"name": "DurationMinutes", "shape": "TrainingPlanDurationMinutes", "type": "long"}, + {"name": "StartTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "EndTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "UpfrontFee", "shape": "String256", "type": "string"}, + {"name": "CurrencyCode", "shape": "CurrencyCode", "type": "string"}, + {"name": "TotalInstanceCount", "shape": "TotalInstanceCount", "type": "integer"}, + { + "name": "AvailableInstanceCount", + "shape": "AvailableInstanceCount", + "type": "integer", + }, + {"name": "InUseInstanceCount", "shape": "InUseInstanceCount", "type": "integer"}, + {"name": "TargetResources", "shape": "SageMakerResourceNames", "type": "list"}, + { + "name": "ReservedCapacitySummaries", + "shape": "ReservedCapacitySummaries", + "type": "list", + }, + ], + "type": "structure", + }, "DescribeTransformJobRequest": { "members": [{"name": "TransformJobName", "shape": "TransformJobName", 
"type": "string"}], "type": "structure", @@ -6198,6 +6540,13 @@ "member_type": "structure", "type": "list", }, + "ErrorInfo": { + "members": [ + {"name": "Code", "shape": "NonEmptyString64", "type": "string"}, + {"name": "Reason", "shape": "NonEmptyString256", "type": "string"}, + ], + "type": "structure", + }, "ExecutionRoleArns": {"member_shape": "RoleArn", "member_type": "string", "type": "list"}, "Experiment": { "members": [ @@ -6264,6 +6613,17 @@ ], "type": "structure", }, + "FSxLustreFileSystem": { + "members": [{"name": "FileSystemId", "shape": "FileSystemId", "type": "string"}], + "type": "structure", + }, + "FSxLustreFileSystemConfig": { + "members": [ + {"name": "FileSystemId", "shape": "FileSystemId", "type": "string"}, + {"name": "FileSystemPath", "shape": "FileSystemPath", "type": "string"}, + ], + "type": "structure", + }, "FailStepMetadata": { "members": [{"name": "ErrorMessage", "shape": "String3072", "type": "string"}], "type": "structure", @@ -8187,6 +8547,31 @@ ], "type": "structure", }, + "ListClusterSchedulerConfigsRequest": { + "members": [ + {"name": "CreatedAfter", "shape": "Timestamp", "type": "timestamp"}, + {"name": "CreatedBefore", "shape": "Timestamp", "type": "timestamp"}, + {"name": "NameContains", "shape": "EntityName", "type": "string"}, + {"name": "ClusterArn", "shape": "ClusterArn", "type": "string"}, + {"name": "Status", "shape": "SchedulerResourceStatus", "type": "string"}, + {"name": "SortBy", "shape": "SortClusterSchedulerConfigBy", "type": "string"}, + {"name": "SortOrder", "shape": "SortOrder", "type": "string"}, + {"name": "NextToken", "shape": "NextToken", "type": "string"}, + {"name": "MaxResults", "shape": "MaxResults", "type": "integer"}, + ], + "type": "structure", + }, + "ListClusterSchedulerConfigsResponse": { + "members": [ + { + "name": "ClusterSchedulerConfigSummaries", + "shape": "ClusterSchedulerConfigSummaryList", + "type": "list", + }, + {"name": "NextToken", "shape": "NextToken", "type": "string"}, + ], + 
"type": "structure", + }, "ListClustersRequest": { "members": [ {"name": "CreationTimeAfter", "shape": "Timestamp", "type": "timestamp"}, @@ -8196,6 +8581,7 @@ {"name": "NextToken", "shape": "NextToken", "type": "string"}, {"name": "SortBy", "shape": "ClusterSortBy", "type": "string"}, {"name": "SortOrder", "shape": "SortOrder", "type": "string"}, + {"name": "TrainingPlanArn", "shape": "TrainingPlanArn", "type": "string"}, ], "type": "structure", }, @@ -8253,6 +8639,27 @@ ], "type": "structure", }, + "ListComputeQuotasRequest": { + "members": [ + {"name": "CreatedAfter", "shape": "Timestamp", "type": "timestamp"}, + {"name": "CreatedBefore", "shape": "Timestamp", "type": "timestamp"}, + {"name": "NameContains", "shape": "EntityName", "type": "string"}, + {"name": "Status", "shape": "SchedulerResourceStatus", "type": "string"}, + {"name": "ClusterArn", "shape": "ClusterArn", "type": "string"}, + {"name": "SortBy", "shape": "SortQuotaBy", "type": "string"}, + {"name": "SortOrder", "shape": "SortOrder", "type": "string"}, + {"name": "NextToken", "shape": "NextToken", "type": "string"}, + {"name": "MaxResults", "shape": "MaxResults", "type": "integer"}, + ], + "type": "structure", + }, + "ListComputeQuotasResponse": { + "members": [ + {"name": "ComputeQuotaSummaries", "shape": "ComputeQuotaSummaryList", "type": "list"}, + {"name": "NextToken", "shape": "NextToken", "type": "string"}, + ], + "type": "structure", + }, "ListContextsRequest": { "members": [ {"name": "SourceUri", "shape": "SourceUri", "type": "string"}, @@ -9290,6 +9697,20 @@ ], "type": "structure", }, + "ListPartnerAppsRequest": { + "members": [ + {"name": "MaxResults", "shape": "MaxResults", "type": "integer"}, + {"name": "NextToken", "shape": "NextToken", "type": "string"}, + ], + "type": "structure", + }, + "ListPartnerAppsResponse": { + "members": [ + {"name": "Summaries", "shape": "PartnerAppSummaries", "type": "list"}, + {"name": "NextToken", "shape": "NextToken", "type": "string"}, + ], + "type": 
"structure", + }, "ListPipelineExecutionStepsRequest": { "members": [ {"name": "PipelineExecutionArn", "shape": "PipelineExecutionArn", "type": "string"}, @@ -9561,6 +9982,7 @@ {"name": "SortBy", "shape": "SortBy", "type": "string"}, {"name": "SortOrder", "shape": "SortOrder", "type": "string"}, {"name": "WarmPoolStatusEquals", "shape": "WarmPoolResourceStatus", "type": "string"}, + {"name": "TrainingPlanArnEquals", "shape": "TrainingPlanArn", "type": "string"}, ], "type": "structure", }, @@ -9571,6 +9993,25 @@ ], "type": "structure", }, + "ListTrainingPlansRequest": { + "members": [ + {"name": "NextToken", "shape": "NextToken", "type": "string"}, + {"name": "MaxResults", "shape": "MaxResults", "type": "integer"}, + {"name": "StartTimeAfter", "shape": "Timestamp", "type": "timestamp"}, + {"name": "StartTimeBefore", "shape": "Timestamp", "type": "timestamp"}, + {"name": "SortBy", "shape": "TrainingPlanSortBy", "type": "string"}, + {"name": "SortOrder", "shape": "TrainingPlanSortOrder", "type": "string"}, + {"name": "Filters", "shape": "TrainingPlanFilters", "type": "list"}, + ], + "type": "structure", + }, + "ListTrainingPlansResponse": { + "members": [ + {"name": "NextToken", "shape": "NextToken", "type": "string"}, + {"name": "TrainingPlanSummaries", "shape": "TrainingPlanSummaries", "type": "list"}, + ], + "type": "structure", + }, "ListTransformJobsRequest": { "members": [ {"name": "CreationTimeAfter", "shape": "Timestamp", "type": "timestamp"}, @@ -11304,6 +11745,50 @@ "type": "list", }, "Parents": {"member_shape": "Parent", "member_type": "structure", "type": "list"}, + "PartnerAppAdminUserList": { + "member_shape": "NonEmptyString256", + "member_type": "string", + "type": "list", + }, + "PartnerAppArguments": { + "key_shape": "NonEmptyString256", + "key_type": "string", + "type": "map", + "value_shape": "String1024", + "value_type": "string", + }, + "PartnerAppConfig": { + "members": [ + {"name": "AdminUsers", "shape": "PartnerAppAdminUserList", "type": 
"list"}, + {"name": "Arguments", "shape": "PartnerAppArguments", "type": "map"}, + ], + "type": "structure", + }, + "PartnerAppMaintenanceConfig": { + "members": [ + { + "name": "MaintenanceWindowStart", + "shape": "WeeklyScheduleTimeFormat", + "type": "string", + } + ], + "type": "structure", + }, + "PartnerAppSummaries": { + "member_shape": "PartnerAppSummary", + "member_type": "structure", + "type": "list", + }, + "PartnerAppSummary": { + "members": [ + {"name": "Arn", "shape": "PartnerAppArn", "type": "string"}, + {"name": "Name", "shape": "PartnerAppName", "type": "string"}, + {"name": "Type", "shape": "PartnerAppType", "type": "string"}, + {"name": "Status", "shape": "PartnerAppStatus", "type": "string"}, + {"name": "CreationTime", "shape": "Timestamp", "type": "timestamp"}, + ], + "type": "structure", + }, "PayloadPart": { "members": [{"name": "Bytes", "shape": "PartBlob", "type": "blob"}], "type": "structure", @@ -11553,6 +12038,18 @@ "members": [{"name": "PredefinedMetricType", "shape": "String", "type": "string"}], "type": "structure", }, + "PriorityClass": { + "members": [ + {"name": "Name", "shape": "ClusterSchedulerPriorityClassName", "type": "string"}, + {"name": "Weight", "shape": "PriorityWeight", "type": "integer"}, + ], + "type": "structure", + }, + "PriorityClassList": { + "member_shape": "PriorityClass", + "member_type": "structure", + "type": "list", + }, "ProcessingClusterConfig": { "members": [ {"name": "InstanceCount", "shape": "ProcessingInstanceCount", "type": "integer"}, @@ -12420,6 +12917,42 @@ ], "type": "structure", }, + "ReservedCapacityOffering": { + "members": [ + {"name": "InstanceType", "shape": "ReservedCapacityInstanceType", "type": "string"}, + {"name": "InstanceCount", "shape": "ReservedCapacityInstanceCount", "type": "integer"}, + {"name": "AvailabilityZone", "shape": "AvailabilityZone", "type": "string"}, + {"name": "DurationHours", "shape": "ReservedCapacityDurationHours", "type": "long"}, + {"name": "DurationMinutes", 
"shape": "ReservedCapacityDurationMinutes", "type": "long"}, + {"name": "StartTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "EndTime", "shape": "Timestamp", "type": "timestamp"}, + ], + "type": "structure", + }, + "ReservedCapacityOfferings": { + "member_shape": "ReservedCapacityOffering", + "member_type": "structure", + "type": "list", + }, + "ReservedCapacitySummaries": { + "member_shape": "ReservedCapacitySummary", + "member_type": "structure", + "type": "list", + }, + "ReservedCapacitySummary": { + "members": [ + {"name": "ReservedCapacityArn", "shape": "ReservedCapacityArn", "type": "string"}, + {"name": "InstanceType", "shape": "ReservedCapacityInstanceType", "type": "string"}, + {"name": "TotalInstanceCount", "shape": "TotalInstanceCount", "type": "integer"}, + {"name": "Status", "shape": "ReservedCapacityStatus", "type": "string"}, + {"name": "AvailabilityZone", "shape": "AvailabilityZone", "type": "string"}, + {"name": "DurationHours", "shape": "ReservedCapacityDurationHours", "type": "long"}, + {"name": "DurationMinutes", "shape": "ReservedCapacityDurationMinutes", "type": "long"}, + {"name": "StartTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "EndTime", "shape": "Timestamp", "type": "timestamp"}, + ], + "type": "structure", + }, "ResolvedAttributes": { "members": [ {"name": "AutoMLJobObjective", "shape": "AutoMLJobObjective", "type": "structure"}, @@ -12458,6 +12991,7 @@ "type": "integer", }, {"name": "InstanceGroups", "shape": "InstanceGroups", "type": "list"}, + {"name": "TrainingPlanArn", "shape": "TrainingPlanArn", "type": "string"}, ], "type": "structure", }, @@ -12503,6 +13037,13 @@ "members": [{"name": "Message", "shape": "FailureReason", "type": "string"}], "type": "structure", }, + "ResourceSharingConfig": { + "members": [ + {"name": "Strategy", "shape": "ResourceSharingStrategy", "type": "string"}, + {"name": "BorrowLimit", "shape": "BorrowLimit", "type": "integer"}, + ], + "type": "structure", + }, 
"ResourceSpec": { "members": [ {"name": "SageMakerImageArn", "shape": "ImageArn", "type": "string"}, @@ -12618,6 +13159,11 @@ "member_type": "string", "type": "list", }, + "SageMakerResourceNames": { + "member_shape": "SageMakerResourceName", + "member_type": "string", + "type": "list", + }, "ScalingPolicies": { "member_shape": "ScalingPolicy", "member_type": "structure", @@ -12655,6 +13201,13 @@ ], "type": "structure", }, + "SchedulerConfig": { + "members": [ + {"name": "PriorityClasses", "shape": "PriorityClassList", "type": "list"}, + {"name": "FairShare", "shape": "FairShare", "type": "string"}, + ], + "type": "structure", + }, "SearchExpression": { "members": [ {"name": "Filters", "shape": "FilterList", "type": "list"}, @@ -12722,6 +13275,23 @@ "member_type": "structure", "type": "list", }, + "SearchTrainingPlanOfferingsRequest": { + "members": [ + {"name": "InstanceType", "shape": "ReservedCapacityInstanceType", "type": "string"}, + {"name": "InstanceCount", "shape": "ReservedCapacityInstanceCount", "type": "integer"}, + {"name": "StartTimeAfter", "shape": "Timestamp", "type": "timestamp"}, + {"name": "EndTimeBefore", "shape": "Timestamp", "type": "timestamp"}, + {"name": "DurationHours", "shape": "TrainingPlanDurationHoursInput", "type": "long"}, + {"name": "TargetResources", "shape": "SageMakerResourceNames", "type": "list"}, + ], + "type": "structure", + }, + "SearchTrainingPlanOfferingsResponse": { + "members": [ + {"name": "TrainingPlanOfferings", "shape": "TrainingPlanOfferings", "type": "list"} + ], + "type": "structure", + }, "SecondaryStatusTransition": { "members": [ {"name": "Status", "shape": "SecondaryStatus", "type": "string"}, @@ -13621,6 +14191,80 @@ {"name": "TrainingJobStatus", "shape": "TrainingJobStatus", "type": "string"}, {"name": "SecondaryStatus", "shape": "SecondaryStatus", "type": "string"}, {"name": "WarmPoolStatus", "shape": "WarmPoolStatus", "type": "structure"}, + {"name": "TrainingPlanArn", "shape": "TrainingPlanArn", "type": 
"string"}, + ], + "type": "structure", + }, + "TrainingPlanArns": { + "member_shape": "TrainingPlanArn", + "member_type": "string", + "type": "list", + }, + "TrainingPlanFilter": { + "members": [ + {"name": "Name", "shape": "TrainingPlanFilterName", "type": "string"}, + {"name": "Value", "shape": "String64", "type": "string"}, + ], + "type": "structure", + }, + "TrainingPlanFilters": { + "member_shape": "TrainingPlanFilter", + "member_type": "structure", + "type": "list", + }, + "TrainingPlanOffering": { + "members": [ + {"name": "TrainingPlanOfferingId", "shape": "TrainingPlanOfferingId", "type": "string"}, + {"name": "TargetResources", "shape": "SageMakerResourceNames", "type": "list"}, + {"name": "RequestedStartTimeAfter", "shape": "Timestamp", "type": "timestamp"}, + {"name": "RequestedEndTimeBefore", "shape": "Timestamp", "type": "timestamp"}, + {"name": "DurationHours", "shape": "TrainingPlanDurationHours", "type": "long"}, + {"name": "DurationMinutes", "shape": "TrainingPlanDurationMinutes", "type": "long"}, + {"name": "UpfrontFee", "shape": "String256", "type": "string"}, + {"name": "CurrencyCode", "shape": "CurrencyCode", "type": "string"}, + { + "name": "ReservedCapacityOfferings", + "shape": "ReservedCapacityOfferings", + "type": "list", + }, + ], + "type": "structure", + }, + "TrainingPlanOfferings": { + "member_shape": "TrainingPlanOffering", + "member_type": "structure", + "type": "list", + }, + "TrainingPlanSummaries": { + "member_shape": "TrainingPlanSummary", + "member_type": "structure", + "type": "list", + }, + "TrainingPlanSummary": { + "members": [ + {"name": "TrainingPlanArn", "shape": "TrainingPlanArn", "type": "string"}, + {"name": "TrainingPlanName", "shape": "TrainingPlanName", "type": "string"}, + {"name": "Status", "shape": "TrainingPlanStatus", "type": "string"}, + {"name": "StatusMessage", "shape": "TrainingPlanStatusMessage", "type": "string"}, + {"name": "DurationHours", "shape": "TrainingPlanDurationHours", "type": "long"}, + 
{"name": "DurationMinutes", "shape": "TrainingPlanDurationMinutes", "type": "long"}, + {"name": "StartTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "EndTime", "shape": "Timestamp", "type": "timestamp"}, + {"name": "UpfrontFee", "shape": "String256", "type": "string"}, + {"name": "CurrencyCode", "shape": "CurrencyCode", "type": "string"}, + {"name": "TotalInstanceCount", "shape": "TotalInstanceCount", "type": "integer"}, + { + "name": "AvailableInstanceCount", + "shape": "AvailableInstanceCount", + "type": "integer", + }, + {"name": "InUseInstanceCount", "shape": "InUseInstanceCount", "type": "integer"}, + {"name": "TargetResources", "shape": "SageMakerResourceNames", "type": "list"}, + { + "name": "ReservedCapacitySummaries", + "shape": "ReservedCapacitySummaries", + "type": "list", + }, ], "type": "structure", }, @@ -14094,6 +14738,30 @@ "members": [{"name": "ClusterArn", "shape": "ClusterArn", "type": "string"}], "type": "structure", }, + "UpdateClusterSchedulerConfigRequest": { + "members": [ + { + "name": "ClusterSchedulerConfigId", + "shape": "ClusterSchedulerConfigId", + "type": "string", + }, + {"name": "TargetVersion", "shape": "Integer", "type": "integer"}, + {"name": "SchedulerConfig", "shape": "SchedulerConfig", "type": "structure"}, + {"name": "Description", "shape": "EntityDescription", "type": "string"}, + ], + "type": "structure", + }, + "UpdateClusterSchedulerConfigResponse": { + "members": [ + { + "name": "ClusterSchedulerConfigArn", + "shape": "ClusterSchedulerConfigArn", + "type": "string", + }, + {"name": "ClusterSchedulerConfigVersion", "shape": "Integer", "type": "integer"}, + ], + "type": "structure", + }, "UpdateClusterSoftwareRequest": { "members": [{"name": "ClusterName", "shape": "ClusterNameOrArn", "type": "string"}], "type": "structure", @@ -14113,6 +14781,24 @@ "members": [{"name": "CodeRepositoryArn", "shape": "CodeRepositoryArn", "type": "string"}], "type": "structure", }, + "UpdateComputeQuotaRequest": { + "members": 
[ + {"name": "ComputeQuotaId", "shape": "ComputeQuotaId", "type": "string"}, + {"name": "TargetVersion", "shape": "Integer", "type": "integer"}, + {"name": "ComputeQuotaConfig", "shape": "ComputeQuotaConfig", "type": "structure"}, + {"name": "ComputeQuotaTarget", "shape": "ComputeQuotaTarget", "type": "structure"}, + {"name": "ActivationState", "shape": "ActivationState", "type": "string"}, + {"name": "Description", "shape": "EntityDescription", "type": "string"}, + ], + "type": "structure", + }, + "UpdateComputeQuotaResponse": { + "members": [ + {"name": "ComputeQuotaArn", "shape": "ComputeQuotaArn", "type": "string"}, + {"name": "ComputeQuotaVersion", "shape": "Integer", "type": "integer"}, + ], + "type": "structure", + }, "UpdateContextRequest": { "members": [ {"name": "ContextName", "shape": "ContextName", "type": "string"}, @@ -14517,6 +15203,26 @@ }, "UpdateNotebookInstanceLifecycleConfigOutput": {"members": [], "type": "structure"}, "UpdateNotebookInstanceOutput": {"members": [], "type": "structure"}, + "UpdatePartnerAppRequest": { + "members": [ + {"name": "Arn", "shape": "PartnerAppArn", "type": "string"}, + { + "name": "MaintenanceConfig", + "shape": "PartnerAppMaintenanceConfig", + "type": "structure", + }, + {"name": "Tier", "shape": "NonEmptyString64", "type": "string"}, + {"name": "ApplicationConfig", "shape": "PartnerAppConfig", "type": "structure"}, + {"name": "EnableIamSessionBasedIdentity", "shape": "Boolean", "type": "boolean"}, + {"name": "ClientToken", "shape": "ClientToken", "type": "string"}, + {"name": "Tags", "shape": "TagList", "type": "list"}, + ], + "type": "structure", + }, + "UpdatePartnerAppResponse": { + "members": [{"name": "Arn", "shape": "PartnerAppArn", "type": "string"}], + "type": "structure", + }, "UpdatePipelineExecutionRequest": { "members": [ {"name": "PipelineExecutionArn", "shape": "PipelineExecutionArn", "type": "string"}, diff --git a/src/sagemaker_core/main/config_schema.py b/src/sagemaker_core/main/config_schema.py 
index ae410807..0396f854 100644 --- a/src/sagemaker_core/main/config_schema.py +++ b/src/sagemaker_core/main/config_schema.py @@ -754,6 +754,10 @@ }, }, }, + "PartnerApp": { + "type": "object", + "properties": {"execution_role_arn": {"type": "string"}}, + }, "Pipeline": { "type": "object", "properties": {"role_arn": {"type": "string"}}, diff --git a/src/sagemaker_core/main/resources.py b/src/sagemaker_core/main/resources.py index 214b1ffd..a612679a 100644 --- a/src/sagemaker_core/main/resources.py +++ b/src/sagemaker_core/main/resources.py @@ -3640,6 +3640,7 @@ def get_all( name_contains: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), + training_plan_arn: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, ) -> ResourceIterator["Cluster"]: @@ -3654,6 +3655,7 @@ def get_all( next_token: Set the next token to retrieve the list of SageMaker HyperPod clusters. sort_by: The field by which to sort results. The default value is CREATION_TIME. sort_order: The sort order for results. The default value is Ascending. + training_plan_arn: The Amazon Resource Name (ARN); of the training plan to filter clusters by. For more information about reserving GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan . session: Boto3 session. region: Region name. @@ -3682,6 +3684,7 @@ def get_all( "NameContains": name_contains, "SortBy": sort_by, "SortOrder": sort_order, + "TrainingPlanArn": training_plan_arn, } # serialize the input request @@ -3906,28 +3909,44 @@ def batch_delete_nodes( return BatchDeleteClusterNodesResponse(**transformed_response) -class CodeRepository(Base): +class ClusterSchedulerConfig(Base): """ - Class representing resource CodeRepository + Class representing resource ClusterSchedulerConfig Attributes: - code_repository_name: The name of the Git repository. 
- code_repository_arn: The Amazon Resource Name (ARN) of the Git repository. - creation_time: The date and time that the repository was created. - last_modified_time: The date and time that the repository was last changed. - git_config: Configuration details about the repository, including the URL where the repository is located, the default branch, and the Amazon Resource Name (ARN) of the Amazon Web Services Secrets Manager secret that contains the credentials used to access the repository. + cluster_scheduler_config_arn: ARN of the cluster policy. + cluster_scheduler_config_id: ID of the cluster policy. + name: Name of the cluster policy. + cluster_scheduler_config_version: Version of the cluster policy. + status: Status of the cluster policy. + creation_time: Creation time of the cluster policy. + failure_reason: Failure reason of the cluster policy. + cluster_arn: ARN of the cluster where the cluster policy is applied. + scheduler_config: Cluster policy configuration. This policy is used for task prioritization and fair-share allocation. This helps prioritize critical workloads and distributes idle compute across entities. + description: Description of the cluster policy. + created_by: + last_modified_time: Last modified time of the cluster policy. 
+ last_modified_by: """ - code_repository_name: str - code_repository_arn: Optional[str] = Unassigned() + cluster_scheduler_config_id: str + cluster_scheduler_config_arn: Optional[str] = Unassigned() + name: Optional[str] = Unassigned() + cluster_scheduler_config_version: Optional[int] = Unassigned() + status: Optional[str] = Unassigned() + failure_reason: Optional[str] = Unassigned() + cluster_arn: Optional[str] = Unassigned() + scheduler_config: Optional[SchedulerConfig] = Unassigned() + description: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() + created_by: Optional[UserContext] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() - git_config: Optional[GitConfig] = Unassigned() + last_modified_by: Optional[UserContext] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "code_repository_name" + resource_name = "cluster_scheduler_config_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -3938,31 +3957,35 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object code_repository") + logger.error("Name attribute not found for object cluster_scheduler_config") return None @classmethod @Base.add_validate_call def create( cls, - code_repository_name: str, - git_config: GitConfig, + name: str, + cluster_arn: str, + scheduler_config: SchedulerConfig, + description: Optional[str] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["CodeRepository"]: + ) -> Optional["ClusterSchedulerConfig"]: """ - Create a CodeRepository resource + Create a ClusterSchedulerConfig resource Parameters: - code_repository_name: The name of the Git repository. The name must have 1 to 63 characters. 
Valid characters are a-z, A-Z, 0-9, and - (hyphen). - git_config: Specifies details about the repository, including the URL where the repository is located, the default branch, and credentials to use to access the repository. - tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. + name: Name for the cluster policy. + cluster_arn: ARN of the cluster. + scheduler_config: Configuration about the monitoring schedule. + description: Description of the cluster policy. + tags: Tags of the cluster policy. session: Boto3 session. region: Region name. Returns: - The CodeRepository resource. + The ClusterSchedulerConfig resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -3974,24 +3997,28 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating code_repository resource.") + logger.info("Creating cluster_scheduler_config resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "CodeRepositoryName": code_repository_name, - "GitConfig": git_config, + "Name": name, + "ClusterArn": cluster_arn, + "SchedulerConfig": scheduler_config, + "Description": description, "Tags": tags, } operation_input_args = Base.populate_chained_attributes( - resource_name="CodeRepository", operation_input_args=operation_input_args + resource_name="ClusterSchedulerConfig", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -4000,29 +4027,35 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_code_repository(**operation_input_args) + response = client.create_cluster_scheduler_config(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(code_repository_name=code_repository_name, session=session, region=region) + return cls.get( + cluster_scheduler_config_id=response["ClusterSchedulerConfigId"], + session=session, + region=region, + ) @classmethod @Base.add_validate_call def get( cls, - code_repository_name: str, + cluster_scheduler_config_id: str, + cluster_scheduler_config_version: Optional[int] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["CodeRepository"]: + ) -> Optional["ClusterSchedulerConfig"]: """ - Get a CodeRepository resource + Get a ClusterSchedulerConfig resource Parameters: - code_repository_name: The name of the Git repository to describe. 
+ cluster_scheduler_config_id: ID of the cluster policy. + cluster_scheduler_config_version: Version of the cluster policy. session: Boto3 session. region: Region name. Returns: - The CodeRepository resource. + The ClusterSchedulerConfig resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4034,10 +4067,12 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "CodeRepositoryName": code_repository_name, + "ClusterSchedulerConfigId": cluster_scheduler_config_id, + "ClusterSchedulerConfigVersion": cluster_scheduler_config_version, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -4046,24 +4081,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_code_repository(**operation_input_args) + response = client.describe_cluster_scheduler_config(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeCodeRepositoryOutput") - code_repository = cls(**transformed_response) - return code_repository + transformed_response = transform(response, "DescribeClusterSchedulerConfigResponse") + cluster_scheduler_config = cls(**transformed_response) + return cluster_scheduler_config @Base.add_validate_call def refresh( self, - ) -> Optional["CodeRepository"]: + ) -> Optional["ClusterSchedulerConfig"]: """ - Refresh a CodeRepository resource + Refresh a ClusterSchedulerConfig resource Returns: - The CodeRepository resource. + The ClusterSchedulerConfig resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -4075,32 +4110,39 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "CodeRepositoryName": self.code_repository_name, + "ClusterSchedulerConfigId": self.cluster_scheduler_config_id, + "ClusterSchedulerConfigVersion": self.cluster_scheduler_config_version, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_code_repository(**operation_input_args) + response = client.describe_cluster_scheduler_config(**operation_input_args) # deserialize response and update self - transform(response, "DescribeCodeRepositoryOutput", self) + transform(response, "DescribeClusterSchedulerConfigResponse", self) return self @Base.add_validate_call def update( self, - git_config: Optional[GitConfigForUpdate] = Unassigned(), - ) -> Optional["CodeRepository"]: + target_version: int, + scheduler_config: Optional[SchedulerConfig] = Unassigned(), + description: Optional[str] = Unassigned(), + ) -> Optional["ClusterSchedulerConfig"]: """ - Update a CodeRepository resource + Update a ClusterSchedulerConfig resource + + Parameters: + target_version: Target version. Returns: - The CodeRepository resource. + The ClusterSchedulerConfig resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4113,14 +4155,18 @@ def update( error_code = e.response['Error']['Code'] ``` ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. 
""" - logger.info("Updating code_repository resource.") + logger.info("Updating cluster_scheduler_config resource.") client = Base.get_sagemaker_client() operation_input_args = { - "CodeRepositoryName": self.code_repository_name, - "GitConfig": git_config, + "ClusterSchedulerConfigId": self.cluster_scheduler_config_id, + "TargetVersion": target_version, + "SchedulerConfig": scheduler_config, + "Description": description, } logger.debug(f"Input request: {operation_input_args}") # serialize the input request @@ -4128,7 +4174,7 @@ def update( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.update_code_repository(**operation_input_args) + response = client.update_cluster_scheduler_config(**operation_input_args) logger.debug(f"Response: {response}") self.refresh() @@ -4139,7 +4185,7 @@ def delete( self, ) -> None: """ - Delete a CodeRepository resource + Delete a ClusterSchedulerConfig resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4151,53 +4197,196 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
""" client = Base.get_sagemaker_client() operation_input_args = { - "CodeRepositoryName": self.code_repository_name, + "ClusterSchedulerConfigId": self.cluster_scheduler_config_id, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_code_repository(**operation_input_args) + client.delete_cluster_scheduler_config(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + @Base.add_validate_call + def wait_for_status( + self, + target_status: Literal[ + "Creating", + "CreateFailed", + "CreateRollbackFailed", + "Created", + "Updating", + "UpdateFailed", + "UpdateRollbackFailed", + "Updated", + "Deleting", + "DeleteFailed", + "DeleteRollbackFailed", + "Deleted", + ], + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a ClusterSchedulerConfig resource to reach certain status. + + Parameters: + target_status: The status to wait for. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. + """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task( + f"Waiting for ClusterSchedulerConfig to reach [bold]{target_status} status..." 
+ ) + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.status + status.update(f"Current status: [bold]{current_status}") + + if target_status == current_status: + logger.info(f"Final Resource Status: [bold]{current_status}") + return + + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="ClusterSchedulerConfig", + status=current_status, + reason=self.failure_reason, + ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="ClusterSchedulerConfig", status=current_status + ) + time.sleep(poll) + + @Base.add_validate_call + def wait_for_delete( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a ClusterSchedulerConfig resource to be deleted. + + Parameters: + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + DeleteFailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
+ """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task("Waiting for ClusterSchedulerConfig to be deleted...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ) + ): + while True: + try: + self.refresh() + current_status = self.status + status.update(f"Current status: [bold]{current_status}") + + if current_status.lower() == "deleted": + print("Resource was deleted.") + return + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="ClusterSchedulerConfig", status=current_status + ) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] + + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. It may have been deleted.") + return + raise e + time.sleep(poll) + @classmethod @Base.add_validate_call def get_all( cls, - creation_time_after: Optional[datetime.datetime] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + created_after: Optional[datetime.datetime] = Unassigned(), + created_before: Optional[datetime.datetime] = Unassigned(), name_contains: Optional[str] = Unassigned(), + cluster_arn: Optional[str] = Unassigned(), + status: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["CodeRepository"]: + ) -> ResourceIterator["ClusterSchedulerConfig"]: """ - Gets a list of the Git repositories in your account. 
+ Get all ClusterSchedulerConfig resources Parameters: - creation_time_after: A filter that returns only Git repositories that were created after the specified time. - creation_time_before: A filter that returns only Git repositories that were created before the specified time. - last_modified_time_after: A filter that returns only Git repositories that were last modified after the specified time. - last_modified_time_before: A filter that returns only Git repositories that were last modified before the specified time. - max_results: The maximum number of Git repositories to return in the response. - name_contains: A string in the Git repositories name. This filter returns only repositories whose name contains the specified string. - next_token: If the result of a ListCodeRepositoriesOutput request was truncated, the response includes a NextToken. To get the next set of Git repositories, use the token in the next request. - sort_by: The field to sort results by. The default is Name. - sort_order: The sort order for results. The default is Ascending. + created_after: Filter for after this creation time. The input for this parameter is a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter. + created_before: Filter for before this creation time. The input for this parameter is a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter. + name_contains: Filter for name containing this string. + cluster_arn: Filter for ARN of the cluster. + status: Filter for status. + sort_by: Filter for sorting the list by a given value. For example, sort by name, creation time, or status. + sort_order: The order of the list. By default, listed in Descending order according to by SortBy. To change the list order, you can specify SortOrder to be Ascending. + next_token: If the previous response was truncated, you will receive this token. Use it in your next request to receive the next set of results. 
+ max_results: The maximum number of cluster policies to list. session: Boto3 session. region: Region name. Returns: - Iterator for listed CodeRepository. + Iterator for listed ClusterSchedulerConfig resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4211,81 +4400,56 @@ def get_all( ``` """ + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + operation_input_args = { - "CreationTimeAfter": creation_time_after, - "CreationTimeBefore": creation_time_before, - "LastModifiedTimeAfter": last_modified_time_after, - "LastModifiedTimeBefore": last_modified_time_before, + "CreatedAfter": created_after, + "CreatedBefore": created_before, "NameContains": name_contains, + "ClusterArn": cluster_arn, + "Status": status, "SortBy": sort_by, "SortOrder": sort_order, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - return ResourceIterator( client=client, - list_method="list_code_repositories", - summaries_key="CodeRepositorySummaryList", - summary_name="CodeRepositorySummary", - resource_cls=CodeRepository, + list_method="list_cluster_scheduler_configs", + summaries_key="ClusterSchedulerConfigSummaries", + summary_name="ClusterSchedulerConfigSummary", + resource_cls=ClusterSchedulerConfig, list_method_kwargs=operation_input_args, ) -class CompilationJob(Base): +class CodeRepository(Base): """ - Class representing resource CompilationJob + Class representing resource CodeRepository Attributes: - compilation_job_name: The name of the model compilation job. - compilation_job_arn: The Amazon Resource Name (ARN) of the model compilation job. - compilation_job_status: The status of the model compilation job. 
- stopping_condition: Specifies a limit to how long a model compilation job can run. When the job reaches the time limit, Amazon SageMaker ends the compilation job. Use this API to cap model training costs. - creation_time: The time that the model compilation job was created. - last_modified_time: The time that the status of the model compilation job was last modified. - failure_reason: If a model compilation job failed, the reason it failed. - model_artifacts: Information about the location in Amazon S3 that has been configured for storing the model artifacts used in the compilation job. - role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker assumes to perform the model compilation job. - input_config: Information about the location in Amazon S3 of the input model artifacts, the name and shape of the expected data inputs, and the framework in which the model was trained. - output_config: Information about the output location for the compiled model and the target device that the model runs on. - compilation_start_time: The time when the model compilation job started the CompilationJob instances. You are billed for the time between this timestamp and the timestamp in the CompilationEndTime field. In Amazon CloudWatch Logs, the start time might be later than this time. That's because it takes time to download the compilation job, which depends on the size of the compilation job container. - compilation_end_time: The time when the model compilation job on a compilation job instance ended. For a successful or stopped job, this is when the job's model artifacts have finished uploading. For a failed job, this is when Amazon SageMaker detected that the job failed. - inference_image: The inference image to use when compiling a model. Specify an image only if the target device is a cloud instance. 
- model_package_version_arn: The Amazon Resource Name (ARN) of the versioned model package that was provided to SageMaker Neo when you initiated a compilation job. - model_digests: Provides a BLAKE2 hash value that identifies the compiled model artifacts in Amazon S3. - vpc_config: A VpcConfig object that specifies the VPC that you want your compilation job to connect to. Control access to your models by configuring the VPC. For more information, see Protect Compilation Jobs by Using an Amazon Virtual Private Cloud. - derived_information: Information that SageMaker Neo automatically derived about the model. + code_repository_name: The name of the Git repository. + code_repository_arn: The Amazon Resource Name (ARN) of the Git repository. + creation_time: The date and time that the repository was created. + last_modified_time: The date and time that the repository was last changed. + git_config: Configuration details about the repository, including the URL where the repository is located, the default branch, and the Amazon Resource Name (ARN) of the Amazon Web Services Secrets Manager secret that contains the credentials used to access the repository. 
""" - compilation_job_name: str - compilation_job_arn: Optional[str] = Unassigned() - compilation_job_status: Optional[str] = Unassigned() - compilation_start_time: Optional[datetime.datetime] = Unassigned() - compilation_end_time: Optional[datetime.datetime] = Unassigned() - stopping_condition: Optional[StoppingCondition] = Unassigned() - inference_image: Optional[str] = Unassigned() - model_package_version_arn: Optional[str] = Unassigned() + code_repository_name: str + code_repository_arn: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() - failure_reason: Optional[str] = Unassigned() - model_artifacts: Optional[ModelArtifacts] = Unassigned() - model_digests: Optional[ModelDigests] = Unassigned() - role_arn: Optional[str] = Unassigned() - input_config: Optional[InputConfig] = Unassigned() - output_config: Optional[OutputConfig] = Unassigned() - vpc_config: Optional[NeoVpcConfig] = Unassigned() - derived_information: Optional[DerivedInformation] = Unassigned() + git_config: Optional[GitConfig] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "compilation_job_name" + resource_name = "code_repository_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -4296,67 +4460,31 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object compilation_job") + logger.error("Name attribute not found for object code_repository") return None - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = { - "model_artifacts": {"s3_model_artifacts": {"type": "string"}}, - "role_arn": {"type": "string"}, - "input_config": {"s3_uri": {"type": "string"}}, - "output_config": { - "s3_output_location": 
{"type": "string"}, - "kms_key_id": {"type": "string"}, - }, - "vpc_config": { - "security_group_ids": {"type": "array", "items": {"type": "string"}}, - "subnets": {"type": "array", "items": {"type": "string"}}, - }, - } - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "CompilationJob", **kwargs - ), - ) - - return wrapper - @classmethod - @populate_inputs_decorator @Base.add_validate_call def create( cls, - compilation_job_name: str, - role_arn: str, - output_config: OutputConfig, - stopping_condition: StoppingCondition, - model_package_version_arn: Optional[str] = Unassigned(), - input_config: Optional[InputConfig] = Unassigned(), - vpc_config: Optional[NeoVpcConfig] = Unassigned(), + code_repository_name: str, + git_config: GitConfig, tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["CompilationJob"]: + ) -> Optional["CodeRepository"]: """ - Create a CompilationJob resource + Create a CodeRepository resource Parameters: - compilation_job_name: A name for the model compilation job. The name must be unique within the Amazon Web Services Region and within your Amazon Web Services account. - role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to perform tasks on your behalf. During model compilation, Amazon SageMaker needs your permission to: Read input data from an S3 bucket Write model artifacts to an S3 bucket Write logs to Amazon CloudWatch Logs Publish metrics to Amazon CloudWatch You grant permissions for all of these tasks to an IAM role. To pass this role to Amazon SageMaker, the caller of this API must have the iam:PassRole permission. For more information, see Amazon SageMaker Roles. - output_config: Provides information about the output location for the compiled model and the target device the model runs on. 
- stopping_condition: Specifies a limit to how long a model compilation job can run. When the job reaches the time limit, Amazon SageMaker ends the compilation job. Use this API to cap model training costs. - model_package_version_arn: The Amazon Resource Name (ARN) of a versioned model package. Provide either a ModelPackageVersionArn or an InputConfig object in the request syntax. The presence of both objects in the CreateCompilationJob request will return an exception. - input_config: Provides information about the location of input model artifacts, the name and shape of the expected data inputs, and the framework in which the model was trained. - vpc_config: A VpcConfig object that specifies the VPC that you want your compilation job to connect to. Control access to your models by configuring the VPC. For more information, see Protect Compilation Jobs by Using an Amazon Virtual Private Cloud. + code_repository_name: The name of the Git repository. The name must have 1 to 63 characters. Valid characters are a-z, A-Z, 0-9, and - (hyphen). + git_config: Specifies details about the repository, including the URL where the repository is located, the default branch, and credentials to use to access the repository. tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. session: Boto3 session. region: Region name. Returns: - The CompilationJob resource. + The CodeRepository resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4368,31 +4496,24 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating compilation_job resource.") + logger.info("Creating code_repository resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "CompilationJobName": compilation_job_name, - "RoleArn": role_arn, - "ModelPackageVersionArn": model_package_version_arn, - "InputConfig": input_config, - "OutputConfig": output_config, - "VpcConfig": vpc_config, - "StoppingCondition": stopping_condition, + "CodeRepositoryName": code_repository_name, + "GitConfig": git_config, "Tags": tags, } operation_input_args = Base.populate_chained_attributes( - resource_name="CompilationJob", operation_input_args=operation_input_args + resource_name="CodeRepository", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -4401,29 +4522,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_compilation_job(**operation_input_args) + response = client.create_code_repository(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(compilation_job_name=compilation_job_name, session=session, region=region) + return cls.get(code_repository_name=code_repository_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - compilation_job_name: str, + code_repository_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["CompilationJob"]: + ) -> Optional["CodeRepository"]: """ - Get a CompilationJob resource + Get a CodeRepository resource Parameters: - compilation_job_name: The name of the model compilation job that you want information 
about. + code_repository_name: The name of the Git repository to describe. session: Boto3 session. region: Region name. Returns: - The CompilationJob resource. + The CodeRepository resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4435,11 +4556,10 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "CompilationJobName": compilation_job_name, + "CodeRepositoryName": code_repository_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -4448,24 +4568,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_compilation_job(**operation_input_args) + response = client.describe_code_repository(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeCompilationJobResponse") - compilation_job = cls(**transformed_response) - return compilation_job + transformed_response = transform(response, "DescribeCodeRepositoryOutput") + code_repository = cls(**transformed_response) + return code_repository @Base.add_validate_call def refresh( self, - ) -> Optional["CompilationJob"]: + ) -> Optional["CodeRepository"]: """ - Refresh a CompilationJob resource + Refresh a CodeRepository resource Returns: - The CompilationJob resource. + The CodeRepository resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4477,29 +4597,32 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "CompilationJobName": self.compilation_job_name, + "CodeRepositoryName": self.code_repository_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_compilation_job(**operation_input_args) + response = client.describe_code_repository(**operation_input_args) # deserialize response and update self - transform(response, "DescribeCompilationJobResponse", self) + transform(response, "DescribeCodeRepositoryOutput", self) return self @Base.add_validate_call - def delete( + def update( self, - ) -> None: + git_config: Optional[GitConfigForUpdate] = Unassigned(), + ) -> Optional["CodeRepository"]: """ - Delete a CompilationJob resource + Update a CodeRepository resource + + Returns: + The CodeRepository resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4511,26 +4634,34 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. 
""" + logger.info("Updating code_repository resource.") client = Base.get_sagemaker_client() operation_input_args = { - "CompilationJobName": self.compilation_job_name, + "CodeRepositoryName": self.code_repository_name, + "GitConfig": git_config, } + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_compilation_job(**operation_input_args) + # create the resource + response = client.update_code_repository(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + return self @Base.add_validate_call - def stop(self) -> None: + def delete( + self, + ) -> None: """ - Stop a CompilationJob resource + Delete a CodeRepository resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4542,80 +4673,20 @@ def stop(self) -> None: error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client() operation_input_args = { - "CompilationJobName": self.compilation_job_name, + "CodeRepositoryName": self.code_repository_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.stop_compilation_job(**operation_input_args) - - logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") - - @Base.add_validate_call - def wait( - self, - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a CompilationJob resource. - - Parameters: - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. 
- - Raises: - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - FailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. - - """ - terminal_states = ["COMPLETED", "FAILED", "STOPPED"] - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task("Waiting for CompilationJob...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ), - transient=True, - ): - while True: - self.refresh() - current_status = self.compilation_job_status - status.update(f"Current status: [bold]{current_status}") - - if current_status in terminal_states: - logger.info(f"Final Resource Status: [bold]{current_status}") - - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="CompilationJob", - status=current_status, - reason=self.failure_reason, - ) - - return + client.delete_code_repository(**operation_input_args) - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="CompilationJob", status=current_status) - time.sleep(poll) + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @classmethod @Base.add_validate_call @@ -4626,31 +4697,29 @@ def get_all( last_modified_time_after: Optional[datetime.datetime] = Unassigned(), last_modified_time_before: Optional[datetime.datetime] = Unassigned(), name_contains: Optional[str] = Unassigned(), - status_equals: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["CompilationJob"]: + ) -> ResourceIterator["CodeRepository"]: """ - Get all CompilationJob resources + Gets a list of the Git 
repositories in your account. Parameters: - next_token: If the result of the previous ListCompilationJobs request was truncated, the response includes a NextToken. To retrieve the next set of model compilation jobs, use the token in the next request. - max_results: The maximum number of model compilation jobs to return in the response. - creation_time_after: A filter that returns the model compilation jobs that were created after a specified time. - creation_time_before: A filter that returns the model compilation jobs that were created before a specified time. - last_modified_time_after: A filter that returns the model compilation jobs that were modified after a specified time. - last_modified_time_before: A filter that returns the model compilation jobs that were modified before a specified time. - name_contains: A filter that returns the model compilation jobs whose name contains a specified string. - status_equals: A filter that retrieves model compilation jobs with a specific CompilationJobStatus status. - sort_by: The field by which to sort results. The default is CreationTime. + creation_time_after: A filter that returns only Git repositories that were created after the specified time. + creation_time_before: A filter that returns only Git repositories that were created before the specified time. + last_modified_time_after: A filter that returns only Git repositories that were last modified after the specified time. + last_modified_time_before: A filter that returns only Git repositories that were last modified before the specified time. + max_results: The maximum number of Git repositories to return in the response. + name_contains: A string in the Git repositories name. This filter returns only repositories whose name contains the specified string. + next_token: If the result of a ListCodeRepositoriesOutput request was truncated, the response includes a NextToken. To get the next set of Git repositories, use the token in the next request. 
+ sort_by: The field to sort results by. The default is Name. sort_order: The sort order for results. The default is Ascending. session: Boto3 session. region: Region name. Returns: - Iterator for listed CompilationJob resources. + Iterator for listed CodeRepository. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4664,69 +4733,81 @@ def get_all( ``` """ - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - operation_input_args = { "CreationTimeAfter": creation_time_after, "CreationTimeBefore": creation_time_before, "LastModifiedTimeAfter": last_modified_time_after, "LastModifiedTimeBefore": last_modified_time_before, "NameContains": name_contains, - "StatusEquals": status_equals, "SortBy": sort_by, "SortOrder": sort_order, } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + return ResourceIterator( client=client, - list_method="list_compilation_jobs", - summaries_key="CompilationJobSummaries", - summary_name="CompilationJobSummary", - resource_cls=CompilationJob, + list_method="list_code_repositories", + summaries_key="CodeRepositorySummaryList", + summary_name="CodeRepositorySummary", + resource_cls=CodeRepository, list_method_kwargs=operation_input_args, ) -class Context(Base): +class CompilationJob(Base): """ - Class representing resource Context + Class representing resource CompilationJob Attributes: - context_name: The name of the context. - context_arn: The Amazon Resource Name (ARN) of the context. - source: The source of the context. - context_type: The type of the context. - description: The description of the context. - properties: A list of the context's properties. - creation_time: When the context was created. 
- created_by: - last_modified_time: When the context was last modified. - last_modified_by: - lineage_group_arn: The Amazon Resource Name (ARN) of the lineage group. + compilation_job_name: The name of the model compilation job. + compilation_job_arn: The Amazon Resource Name (ARN) of the model compilation job. + compilation_job_status: The status of the model compilation job. + stopping_condition: Specifies a limit to how long a model compilation job can run. When the job reaches the time limit, Amazon SageMaker ends the compilation job. Use this API to cap model training costs. + creation_time: The time that the model compilation job was created. + last_modified_time: The time that the status of the model compilation job was last modified. + failure_reason: If a model compilation job failed, the reason it failed. + model_artifacts: Information about the location in Amazon S3 that has been configured for storing the model artifacts used in the compilation job. + role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker assumes to perform the model compilation job. + input_config: Information about the location in Amazon S3 of the input model artifacts, the name and shape of the expected data inputs, and the framework in which the model was trained. + output_config: Information about the output location for the compiled model and the target device that the model runs on. + compilation_start_time: The time when the model compilation job started the CompilationJob instances. You are billed for the time between this timestamp and the timestamp in the CompilationEndTime field. In Amazon CloudWatch Logs, the start time might be later than this time. That's because it takes time to download the compilation job, which depends on the size of the compilation job container. + compilation_end_time: The time when the model compilation job on a compilation job instance ended. 
For a successful or stopped job, this is when the job's model artifacts have finished uploading. For a failed job, this is when Amazon SageMaker detected that the job failed. + inference_image: The inference image to use when compiling a model. Specify an image only if the target device is a cloud instance. + model_package_version_arn: The Amazon Resource Name (ARN) of the versioned model package that was provided to SageMaker Neo when you initiated a compilation job. + model_digests: Provides a BLAKE2 hash value that identifies the compiled model artifacts in Amazon S3. + vpc_config: A VpcConfig object that specifies the VPC that you want your compilation job to connect to. Control access to your models by configuring the VPC. For more information, see Protect Compilation Jobs by Using an Amazon Virtual Private Cloud. + derived_information: Information that SageMaker Neo automatically derived about the model. """ - context_name: str - context_arn: Optional[str] = Unassigned() - source: Optional[ContextSource] = Unassigned() - context_type: Optional[str] = Unassigned() - description: Optional[str] = Unassigned() - properties: Optional[Dict[str, str]] = Unassigned() + compilation_job_name: str + compilation_job_arn: Optional[str] = Unassigned() + compilation_job_status: Optional[str] = Unassigned() + compilation_start_time: Optional[datetime.datetime] = Unassigned() + compilation_end_time: Optional[datetime.datetime] = Unassigned() + stopping_condition: Optional[StoppingCondition] = Unassigned() + inference_image: Optional[str] = Unassigned() + model_package_version_arn: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - created_by: Optional[UserContext] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() - last_modified_by: Optional[UserContext] = Unassigned() - lineage_group_arn: Optional[str] = Unassigned() + failure_reason: Optional[str] = Unassigned() + model_artifacts: Optional[ModelArtifacts] = 
Unassigned() + model_digests: Optional[ModelDigests] = Unassigned() + role_arn: Optional[str] = Unassigned() + input_config: Optional[InputConfig] = Unassigned() + output_config: Optional[OutputConfig] = Unassigned() + vpc_config: Optional[NeoVpcConfig] = Unassigned() + derived_information: Optional[DerivedInformation] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "context_name" + resource_name = "compilation_job_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -4737,37 +4818,67 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object context") + logger.error("Name attribute not found for object compilation_job") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "model_artifacts": {"s3_model_artifacts": {"type": "string"}}, + "role_arn": {"type": "string"}, + "input_config": {"s3_uri": {"type": "string"}}, + "output_config": { + "s3_output_location": {"type": "string"}, + "kms_key_id": {"type": "string"}, + }, + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + }, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "CompilationJob", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call def create( cls, - context_name: str, - source: ContextSource, - context_type: str, - description: Optional[str] = Unassigned(), - properties: Optional[Dict[str, str]] = Unassigned(), + compilation_job_name: str, + role_arn: str, + output_config: OutputConfig, + stopping_condition: StoppingCondition, + model_package_version_arn: Optional[str] 
= Unassigned(), + input_config: Optional[InputConfig] = Unassigned(), + vpc_config: Optional[NeoVpcConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Context"]: + ) -> Optional["CompilationJob"]: """ - Create a Context resource + Create a CompilationJob resource Parameters: - context_name: The name of the context. Must be unique to your account in an Amazon Web Services Region. - source: The source type, ID, and URI. - context_type: The context type. - description: The description of the context. - properties: A list of properties to add to the context. - tags: A list of tags to apply to the context. + compilation_job_name: A name for the model compilation job. The name must be unique within the Amazon Web Services Region and within your Amazon Web Services account. + role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to perform tasks on your behalf. During model compilation, Amazon SageMaker needs your permission to: Read input data from an S3 bucket Write model artifacts to an S3 bucket Write logs to Amazon CloudWatch Logs Publish metrics to Amazon CloudWatch You grant permissions for all of these tasks to an IAM role. To pass this role to Amazon SageMaker, the caller of this API must have the iam:PassRole permission. For more information, see Amazon SageMaker Roles. + output_config: Provides information about the output location for the compiled model and the target device the model runs on. + stopping_condition: Specifies a limit to how long a model compilation job can run. When the job reaches the time limit, Amazon SageMaker ends the compilation job. Use this API to cap model training costs. + model_package_version_arn: The Amazon Resource Name (ARN) of a versioned model package. Provide either a ModelPackageVersionArn or an InputConfig object in the request syntax. 
The presence of both objects in the CreateCompilationJob request will return an exception. + input_config: Provides information about the location of input model artifacts, the name and shape of the expected data inputs, and the framework in which the model was trained. + vpc_config: A VpcConfig object that specifies the VPC that you want your compilation job to connect to. Control access to your models by configuring the VPC. For more information, see Protect Compilation Jobs by Using an Amazon Virtual Private Cloud. + tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. session: Boto3 session. region: Region name. Returns: - The Context resource. + The CompilationJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4779,28 +4890,31 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceInUse: Resource being accessed is in use. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating context resource.") + logger.info("Creating compilation_job resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "ContextName": context_name, - "Source": source, - "ContextType": context_type, - "Description": description, - "Properties": properties, + "CompilationJobName": compilation_job_name, + "RoleArn": role_arn, + "ModelPackageVersionArn": model_package_version_arn, + "InputConfig": input_config, + "OutputConfig": output_config, + "VpcConfig": vpc_config, + "StoppingCondition": stopping_condition, "Tags": tags, } operation_input_args = Base.populate_chained_attributes( - resource_name="Context", operation_input_args=operation_input_args + resource_name="CompilationJob", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -4809,29 +4923,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_context(**operation_input_args) + response = client.create_compilation_job(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(context_name=context_name, session=session, region=region) + return cls.get(compilation_job_name=compilation_job_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - context_name: str, + compilation_job_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Context"]: + ) -> Optional["CompilationJob"]: """ - Get a Context resource + Get a CompilationJob resource Parameters: - context_name: The name of the context to describe. 
+ compilation_job_name: The name of the model compilation job that you want information about. session: Boto3 session. region: Region name. Returns: - The Context resource. + The CompilationJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4847,7 +4961,7 @@ def get( """ operation_input_args = { - "ContextName": context_name, + "CompilationJobName": compilation_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -4856,24 +4970,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_context(**operation_input_args) + response = client.describe_compilation_job(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeContextResponse") - context = cls(**transformed_response) - return context + transformed_response = transform(response, "DescribeCompilationJobResponse") + compilation_job = cls(**transformed_response) + return compilation_job @Base.add_validate_call def refresh( self, - ) -> Optional["Context"]: + ) -> Optional["CompilationJob"]: """ - Refresh a Context resource + Refresh a CompilationJob resource Returns: - The Context resource. + The CompilationJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -4889,34 +5003,25 @@ def refresh( """ operation_input_args = { - "ContextName": self.context_name, + "CompilationJobName": self.compilation_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_context(**operation_input_args) + response = client.describe_compilation_job(**operation_input_args) # deserialize response and update self - transform(response, "DescribeContextResponse", self) + transform(response, "DescribeCompilationJobResponse", self) return self @Base.add_validate_call - def update( + def delete( self, - description: Optional[str] = Unassigned(), - properties: Optional[Dict[str, str]] = Unassigned(), - properties_to_remove: Optional[List[str]] = Unassigned(), - ) -> Optional["Context"]: + ) -> None: """ - Update a Context resource - - Parameters: - properties_to_remove: A list of properties to remove. - - Returns: - The Context resource. + Delete a CompilationJob resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4928,37 +5033,26 @@ def update( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. ResourceNotFound: Resource being access is not found. 
""" - logger.info("Updating context resource.") client = Base.get_sagemaker_client() operation_input_args = { - "ContextName": self.context_name, - "Description": description, - "Properties": properties, - "PropertiesToRemove": properties_to_remove, + "CompilationJobName": self.compilation_job_name, } - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - # create the resource - response = client.update_context(**operation_input_args) - logger.debug(f"Response: {response}") - self.refresh() + client.delete_compilation_job(**operation_input_args) - return self + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call - def delete( - self, - ) -> None: + def stop(self) -> None: """ - Delete a Context resource + Stop a CompilationJob resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -4973,49 +5067,112 @@ def delete( ResourceNotFound: Resource being access is not found. 
""" - client = Base.get_sagemaker_client() + client = SageMakerClient().client operation_input_args = { - "ContextName": self.context_name, + "CompilationJobName": self.compilation_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_context(**operation_input_args) + client.stop_compilation_job(**operation_input_args) - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") - @classmethod @Base.add_validate_call - def get_all( - cls, - source_uri: Optional[str] = Unassigned(), - context_type: Optional[str] = Unassigned(), - created_after: Optional[datetime.datetime] = Unassigned(), - created_before: Optional[datetime.datetime] = Unassigned(), + def wait( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a CompilationJob resource. + + Parameters: + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
+ + """ + terminal_states = ["COMPLETED", "FAILED", "STOPPED"] + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task("Waiting for CompilationJob...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.compilation_job_status + status.update(f"Current status: [bold]{current_status}") + + if current_status in terminal_states: + logger.info(f"Final Resource Status: [bold]{current_status}") + + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="CompilationJob", + status=current_status, + reason=self.failure_reason, + ) + + return + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="CompilationJob", status=current_status) + time.sleep(poll) + + @classmethod + @Base.add_validate_call + def get_all( + cls, + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + status_equals: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["Context"]: + ) -> ResourceIterator["CompilationJob"]: """ - Get all Context resources + Get all CompilationJob resources Parameters: - source_uri: A filter that returns only contexts with the specified source URI. - context_type: A filter that returns only contexts of the specified type. 
- created_after: A filter that returns only contexts created on or after the specified time. - created_before: A filter that returns only contexts created on or before the specified time. - sort_by: The property used to sort results. The default value is CreationTime. - sort_order: The sort order. The default value is Descending. - next_token: If the previous call to ListContexts didn't return the full set of contexts, the call returns a token for getting the next set of contexts. - max_results: The maximum number of contexts to return in the response. The default value is 10. + next_token: If the result of the previous ListCompilationJobs request was truncated, the response includes a NextToken. To retrieve the next set of model compilation jobs, use the token in the next request. + max_results: The maximum number of model compilation jobs to return in the response. + creation_time_after: A filter that returns the model compilation jobs that were created after a specified time. + creation_time_before: A filter that returns the model compilation jobs that were created before a specified time. + last_modified_time_after: A filter that returns the model compilation jobs that were modified after a specified time. + last_modified_time_before: A filter that returns the model compilation jobs that were modified before a specified time. + name_contains: A filter that returns the model compilation jobs whose name contains a specified string. + status_equals: A filter that retrieves model compilation jobs with a specific CompilationJobStatus status. + sort_by: The field by which to sort results. The default is CreationTime. + sort_order: The sort order for results. The default is Ascending. session: Boto3 session. region: Region name. Returns: - Iterator for listed Context resources. + Iterator for listed CompilationJob resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -5027,7 +5184,6 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. """ client = Base.get_sagemaker_client( @@ -5035,10 +5191,12 @@ def get_all( ) operation_input_args = { - "SourceUri": source_uri, - "ContextType": context_type, - "CreatedAfter": created_after, - "CreatedBefore": created_before, + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "LastModifiedTimeAfter": last_modified_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + "NameContains": name_contains, + "StatusEquals": status_equals, "SortBy": sort_by, "SortOrder": sort_order, } @@ -5049,48 +5207,56 @@ def get_all( return ResourceIterator( client=client, - list_method="list_contexts", - summaries_key="ContextSummaries", - summary_name="ContextSummary", - resource_cls=Context, + list_method="list_compilation_jobs", + summaries_key="CompilationJobSummaries", + summary_name="CompilationJobSummary", + resource_cls=CompilationJob, list_method_kwargs=operation_input_args, ) -class DataQualityJobDefinition(Base): +class ComputeQuota(Base): """ - Class representing resource DataQualityJobDefinition + Class representing resource ComputeQuota Attributes: - job_definition_arn: The Amazon Resource Name (ARN) of the data quality monitoring job definition. - job_definition_name: The name of the data quality monitoring job definition. - creation_time: The time that the data quality monitoring job definition was created. - data_quality_app_specification: Information about the container that runs the data quality monitoring job. - data_quality_job_input: The list of inputs for the data quality monitoring job. Currently endpoints are supported. - data_quality_job_output_config: - job_resources: - role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. 
- data_quality_baseline_config: The constraints and baselines for the data quality monitoring job definition. - network_config: The networking configuration for the data quality monitoring job. - stopping_condition: + compute_quota_arn: ARN of the compute allocation definition. + compute_quota_id: ID of the compute allocation definition. + name: Name of the compute allocation definition. + compute_quota_version: Version of the compute allocation definition. + status: Status of the compute allocation definition. + compute_quota_target: The target entity to allocate compute resources to. + creation_time: Creation time of the compute allocation configuration. + description: Description of the compute allocation definition. + failure_reason: Failure reason of the compute allocation definition. + cluster_arn: ARN of the cluster. + compute_quota_config: Configuration of the compute allocation definition. This includes the resource sharing option, and the setting to preempt low priority tasks. + activation_state: The state of the compute allocation being described. Use to enable or disable compute allocation. Default is Enabled. + created_by: + last_modified_time: Last modified time of the compute allocation configuration. 
+ last_modified_by: """ - job_definition_name: str - job_definition_arn: Optional[str] = Unassigned() + compute_quota_id: str + compute_quota_arn: Optional[str] = Unassigned() + name: Optional[str] = Unassigned() + description: Optional[str] = Unassigned() + compute_quota_version: Optional[int] = Unassigned() + status: Optional[str] = Unassigned() + failure_reason: Optional[str] = Unassigned() + cluster_arn: Optional[str] = Unassigned() + compute_quota_config: Optional[ComputeQuotaConfig] = Unassigned() + compute_quota_target: Optional[ComputeQuotaTarget] = Unassigned() + activation_state: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - data_quality_baseline_config: Optional[DataQualityBaselineConfig] = Unassigned() - data_quality_app_specification: Optional[DataQualityAppSpecification] = Unassigned() - data_quality_job_input: Optional[DataQualityJobInput] = Unassigned() - data_quality_job_output_config: Optional[MonitoringOutputConfig] = Unassigned() - job_resources: Optional[MonitoringResources] = Unassigned() - network_config: Optional[MonitoringNetworkConfig] = Unassigned() - role_arn: Optional[str] = Unassigned() - stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned() + created_by: Optional[UserContext] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + last_modified_by: Optional[UserContext] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "data_quality_job_definition_name" + resource_name = "compute_quota_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -5101,84 +5267,39 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object data_quality_job_definition") + logger.error("Name attribute not found for object compute_quota") return None - def 
populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = { - "data_quality_job_input": { - "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, - "batch_transform_input": { - "data_captured_destination_s3_uri": {"type": "string"}, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, - }, - "data_quality_job_output_config": {"kms_key_id": {"type": "string"}}, - "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, - "role_arn": {"type": "string"}, - "data_quality_baseline_config": { - "constraints_resource": {"s3_uri": {"type": "string"}}, - "statistics_resource": {"s3_uri": {"type": "string"}}, - }, - "network_config": { - "vpc_config": { - "security_group_ids": {"type": "array", "items": {"type": "string"}}, - "subnets": {"type": "array", "items": {"type": "string"}}, - } - }, - } - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "DataQualityJobDefinition", **kwargs - ), - ) - - return wrapper - @classmethod - @populate_inputs_decorator @Base.add_validate_call def create( cls, - job_definition_name: str, - data_quality_app_specification: DataQualityAppSpecification, - data_quality_job_input: DataQualityJobInput, - data_quality_job_output_config: MonitoringOutputConfig, - job_resources: MonitoringResources, - role_arn: str, - data_quality_baseline_config: Optional[DataQualityBaselineConfig] = Unassigned(), - network_config: Optional[MonitoringNetworkConfig] = Unassigned(), - stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), + name: str, + cluster_arn: str, + compute_quota_config: ComputeQuotaConfig, + compute_quota_target: ComputeQuotaTarget, + description: Optional[str] = Unassigned(), + activation_state: Optional[str] = Unassigned(), tags: Optional[List[Tag]] = 
Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["DataQualityJobDefinition"]: + ) -> Optional["ComputeQuota"]: """ - Create a DataQualityJobDefinition resource + Create a ComputeQuota resource Parameters: - job_definition_name: The name for the monitoring job definition. - data_quality_app_specification: Specifies the container that runs the monitoring job. - data_quality_job_input: A list of inputs for the monitoring job. Currently endpoints are supported as monitoring inputs. - data_quality_job_output_config: - job_resources: - role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. - data_quality_baseline_config: Configures the constraints and baselines for the monitoring job. - network_config: Specifies networking configuration for the monitoring job. - stopping_condition: - tags: (Optional) An array of key-value pairs. For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide. + name: Name to the compute allocation definition. + cluster_arn: ARN of the cluster. + compute_quota_config: Configuration of the compute allocation definition. This includes the resource sharing option, and the setting to preempt low priority tasks. + compute_quota_target: The target entity to allocate compute resources to. + description: Description of the compute allocation definition. + activation_state: The state of the compute allocation being described. Use to enable or disable compute allocation. Default is Enabled. + tags: Tags of the compute allocation definition. session: Boto3 session. region: Region name. Returns: - The DataQualityJobDefinition resource. + The ComputeQuota resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -5190,33 +5311,30 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating data_quality_job_definition resource.") + logger.info("Creating compute_quota resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "JobDefinitionName": job_definition_name, - "DataQualityBaselineConfig": data_quality_baseline_config, - "DataQualityAppSpecification": data_quality_app_specification, - "DataQualityJobInput": data_quality_job_input, - "DataQualityJobOutputConfig": data_quality_job_output_config, - "JobResources": job_resources, - "NetworkConfig": network_config, - "RoleArn": role_arn, - "StoppingCondition": stopping_condition, + "Name": name, + "Description": description, + "ClusterArn": cluster_arn, + "ComputeQuotaConfig": compute_quota_config, + "ComputeQuotaTarget": compute_quota_target, + "ActivationState": activation_state, "Tags": tags, } operation_input_args = Base.populate_chained_attributes( - resource_name="DataQualityJobDefinition", operation_input_args=operation_input_args + resource_name="ComputeQuota", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -5225,29 +5343,31 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = 
client.create_data_quality_job_definition(**operation_input_args) + response = client.create_compute_quota(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(job_definition_name=job_definition_name, session=session, region=region) + return cls.get(compute_quota_id=response["ComputeQuotaId"], session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - job_definition_name: str, + compute_quota_id: str, + compute_quota_version: Optional[int] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["DataQualityJobDefinition"]: + ) -> Optional["ComputeQuota"]: """ - Get a DataQualityJobDefinition resource + Get a ComputeQuota resource Parameters: - job_definition_name: The name of the data quality monitoring job definition to describe. + compute_quota_id: ID of the compute allocation definition. + compute_quota_version: Version of the compute allocation definition. session: Boto3 session. region: Region name. Returns: - The DataQualityJobDefinition resource. + The ComputeQuota resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -5263,7 +5383,8 @@ def get( """ operation_input_args = { - "JobDefinitionName": job_definition_name, + "ComputeQuotaId": compute_quota_id, + "ComputeQuotaVersion": compute_quota_version, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -5272,24 +5393,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_data_quality_job_definition(**operation_input_args) + response = client.describe_compute_quota(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeDataQualityJobDefinitionResponse") - data_quality_job_definition = cls(**transformed_response) - return data_quality_job_definition + transformed_response = transform(response, "DescribeComputeQuotaResponse") + compute_quota = cls(**transformed_response) + return compute_quota @Base.add_validate_call def refresh( self, - ) -> Optional["DataQualityJobDefinition"]: + ) -> Optional["ComputeQuota"]: """ - Refresh a DataQualityJobDefinition resource + Refresh a ComputeQuota resource Returns: - The DataQualityJobDefinition resource. + The ComputeQuota resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -5305,25 +5426,37 @@ def refresh( """ operation_input_args = { - "JobDefinitionName": self.job_definition_name, + "ComputeQuotaId": self.compute_quota_id, + "ComputeQuotaVersion": self.compute_quota_version, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_data_quality_job_definition(**operation_input_args) + response = client.describe_compute_quota(**operation_input_args) # deserialize response and update self - transform(response, "DescribeDataQualityJobDefinitionResponse", self) + transform(response, "DescribeComputeQuotaResponse", self) return self @Base.add_validate_call - def delete( + def update( self, - ) -> None: + target_version: int, + compute_quota_config: Optional[ComputeQuotaConfig] = Unassigned(), + compute_quota_target: Optional[ComputeQuotaTarget] = Unassigned(), + activation_state: Optional[str] = Unassigned(), + description: Optional[str] = Unassigned(), + ) -> Optional["ComputeQuota"]: """ - Delete a DataQualityJobDefinition resource + Update a ComputeQuota resource + + Parameters: + target_version: Target version. + + Returns: + The ComputeQuota resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -5335,52 +5468,237 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. ResourceNotFound: Resource being access is not found. 
""" + logger.info("Updating compute_quota resource.") client = Base.get_sagemaker_client() operation_input_args = { - "JobDefinitionName": self.job_definition_name, + "ComputeQuotaId": self.compute_quota_id, + "TargetVersion": target_version, + "ComputeQuotaConfig": compute_quota_config, + "ComputeQuotaTarget": compute_quota_target, + "ActivationState": activation_state, + "Description": description, } + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_data_quality_job_definition(**operation_input_args) + # create the resource + response = client.update_compute_quota(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() + + return self + + @Base.add_validate_call + def delete( + self, + ) -> None: + """ + Delete a ComputeQuota resource + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. 
+ """ + + client = Base.get_sagemaker_client() + + operation_input_args = { + "ComputeQuotaId": self.compute_quota_id, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_compute_quota(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + @Base.add_validate_call + def wait_for_status( + self, + target_status: Literal[ + "Creating", + "CreateFailed", + "CreateRollbackFailed", + "Created", + "Updating", + "UpdateFailed", + "UpdateRollbackFailed", + "Updated", + "Deleting", + "DeleteFailed", + "DeleteRollbackFailed", + "Deleted", + ], + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a ComputeQuota resource to reach certain status. + + Parameters: + target_status: The status to wait for. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
+ """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task(f"Waiting for ComputeQuota to reach [bold]{target_status} status...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.status + status.update(f"Current status: [bold]{current_status}") + + if target_status == current_status: + logger.info(f"Final Resource Status: [bold]{current_status}") + return + + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="ComputeQuota", + status=current_status, + reason=self.failure_reason, + ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="ComputeQuota", status=current_status) + time.sleep(poll) + + @Base.add_validate_call + def wait_for_delete( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a ComputeQuota resource to be deleted. + + Parameters: + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + DeleteFailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
+ """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task("Waiting for ComputeQuota to be deleted...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ) + ): + while True: + try: + self.refresh() + current_status = self.status + status.update(f"Current status: [bold]{current_status}") + + if current_status.lower() == "deleted": + print("Resource was deleted.") + return + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="ComputeQuota", status=current_status + ) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] + + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. It may have been deleted.") + return + raise e + time.sleep(poll) + @classmethod @Base.add_validate_call def get_all( cls, - endpoint_name: Optional[str] = Unassigned(), + created_after: Optional[datetime.datetime] = Unassigned(), + created_before: Optional[datetime.datetime] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + status: Optional[str] = Unassigned(), + cluster_arn: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["DataQualityJobDefinition"]: + ) -> ResourceIterator["ComputeQuota"]: """ - Get all DataQualityJobDefinition resources + Get all ComputeQuota resources Parameters: - endpoint_name: A filter that lists the data quality job definitions associated with the 
specified endpoint. - sort_by: The field to sort results by. The default is CreationTime. - sort_order: Whether to sort the results in Ascending or Descending order. The default is Descending. - next_token: If the result of the previous ListDataQualityJobDefinitions request was truncated, the response includes a NextToken. To retrieve the next set of transform jobs, use the token in the next request.> - max_results: The maximum number of data quality monitoring job definitions to return in the response. - name_contains: A string in the data quality monitoring job definition name. This filter returns only data quality monitoring job definitions whose name contains the specified string. - creation_time_before: A filter that returns only data quality monitoring job definitions created before the specified time. - creation_time_after: A filter that returns only data quality monitoring job definitions created after the specified time. + created_after: Filter for after this creation time. The input for this parameter is a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter. + created_before: Filter for before this creation time. The input for this parameter is a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter. + name_contains: Filter for name containing this string. + status: Filter for status. + cluster_arn: Filter for ARN of the cluster. + sort_by: Filter for sorting the list by a given value. For example, sort by name, creation time, or status. + sort_order: The order of the list. By default, listed in Descending order according to by SortBy. To change the list order, you can specify SortOrder to be Ascending. + next_token: If the previous response was truncated, you will receive this token. Use it in your next request to receive the next set of results. + max_results: The maximum number of compute allocation definitions to list. session: Boto3 session. region: Region name. 
Returns: - Iterator for listed DataQualityJobDefinition resources. + Iterator for listed ComputeQuota resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -5399,66 +5717,63 @@ def get_all( ) operation_input_args = { - "EndpointName": endpoint_name, + "CreatedAfter": created_after, + "CreatedBefore": created_before, + "NameContains": name_contains, + "Status": status, + "ClusterArn": cluster_arn, "SortBy": sort_by, "SortOrder": sort_order, - "NameContains": name_contains, - "CreationTimeBefore": creation_time_before, - "CreationTimeAfter": creation_time_after, - } - custom_key_mapping = { - "monitoring_job_definition_name": "job_definition_name", - "monitoring_job_definition_arn": "job_definition_arn", } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") return ResourceIterator( client=client, - list_method="list_data_quality_job_definitions", - summaries_key="JobDefinitionSummaries", - summary_name="MonitoringJobDefinitionSummary", - resource_cls=DataQualityJobDefinition, - custom_key_mapping=custom_key_mapping, + list_method="list_compute_quotas", + summaries_key="ComputeQuotaSummaries", + summary_name="ComputeQuotaSummary", + resource_cls=ComputeQuota, list_method_kwargs=operation_input_args, ) -class Device(Base): +class Context(Base): """ - Class representing resource Device + Class representing resource Context Attributes: - device_name: The unique identifier of the device. - device_fleet_name: The name of the fleet the device belongs to. - registration_time: The timestamp of the last registration or de-reregistration. - device_arn: The Amazon Resource Name (ARN) of the device. - description: A description of the device. - iot_thing_name: The Amazon Web Services Internet of Things (IoT) object thing name associated with the device. - latest_heartbeat: The last heartbeat received from the device. 
- models: Models on the device. - max_models: The maximum number of models. - next_token: The response from the last list when returning a list large enough to need tokening. - agent_version: Edge Manager agent version. + context_name: The name of the context. + context_arn: The Amazon Resource Name (ARN) of the context. + source: The source of the context. + context_type: The type of the context. + description: The description of the context. + properties: A list of the context's properties. + creation_time: When the context was created. + created_by: + last_modified_time: When the context was last modified. + last_modified_by: + lineage_group_arn: The Amazon Resource Name (ARN) of the lineage group. """ - device_name: str - device_fleet_name: str - device_arn: Optional[str] = Unassigned() + context_name: str + context_arn: Optional[str] = Unassigned() + source: Optional[ContextSource] = Unassigned() + context_type: Optional[str] = Unassigned() description: Optional[str] = Unassigned() - iot_thing_name: Optional[str] = Unassigned() - registration_time: Optional[datetime.datetime] = Unassigned() - latest_heartbeat: Optional[datetime.datetime] = Unassigned() - models: Optional[List[EdgeModel]] = Unassigned() - max_models: Optional[int] = Unassigned() - next_token: Optional[str] = Unassigned() - agent_version: Optional[str] = Unassigned() + properties: Optional[Dict[str, str]] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + created_by: Optional[UserContext] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + last_modified_by: Optional[UserContext] = Unassigned() + lineage_group_arn: Optional[str] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "device_name" + resource_name = "context_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -5469,31 +5784,101 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute 
== "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object device") + logger.error("Name attribute not found for object context") return None @classmethod @Base.add_validate_call - def get( + def create( cls, - device_name: str, - device_fleet_name: str, - next_token: Optional[str] = Unassigned(), - session: Optional[Session] = None, + context_name: str, + source: ContextSource, + context_type: str, + description: Optional[str] = Unassigned(), + properties: Optional[Dict[str, str]] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), + session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Device"]: + ) -> Optional["Context"]: """ - Get a Device resource + Create a Context resource Parameters: - device_name: The unique ID of the device. - device_fleet_name: The name of the fleet the devices belong to. - next_token: Next token of device description. + context_name: The name of the context. Must be unique to your account in an Amazon Web Services Region. + source: The source type, ID, and URI. + context_type: The context type. + description: The description of the context. + properties: A list of properties to add to the context. + tags: A list of tags to apply to the context. session: Boto3 session. region: Region name. Returns: - The Device resource. + The Context resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
+ ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + """ + + logger.info("Creating context resource.") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "ContextName": context_name, + "Source": source, + "ContextType": context_type, + "Description": description, + "Properties": properties, + "Tags": tags, + } + + operation_input_args = Base.populate_chained_attributes( + resource_name="Context", operation_input_args=operation_input_args + ) + + logger.debug(f"Input request: {operation_input_args}") + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + # create the resource + response = client.create_context(**operation_input_args) + logger.debug(f"Response: {response}") + + return cls.get(context_name=context_name, session=session, region=region) + + @classmethod + @Base.add_validate_call + def get( + cls, + context_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["Context"]: + """ + Get a Context resource + + Parameters: + context_name: The name of the context to describe. + session: Boto3 session. + region: Region name. + + Returns: + The Context resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -5509,9 +5894,7 @@ def get( """ operation_input_args = { - "NextToken": next_token, - "DeviceName": device_name, - "DeviceFleetName": device_fleet_name, + "ContextName": context_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -5520,24 +5903,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_device(**operation_input_args) + response = client.describe_context(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeDeviceResponse") - device = cls(**transformed_response) - return device + transformed_response = transform(response, "DescribeContextResponse") + context = cls(**transformed_response) + return context @Base.add_validate_call def refresh( self, - ) -> Optional["Device"]: + ) -> Optional["Context"]: """ - Refresh a Device resource + Refresh a Context resource Returns: - The Device resource. + The Context resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -5553,45 +5936,133 @@ def refresh( """ operation_input_args = { - "NextToken": self.next_token, - "DeviceName": self.device_name, - "DeviceFleetName": self.device_fleet_name, + "ContextName": self.context_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_device(**operation_input_args) + response = client.describe_context(**operation_input_args) # deserialize response and update self - transform(response, "DescribeDeviceResponse", self) + transform(response, "DescribeContextResponse", self) + return self + + @Base.add_validate_call + def update( + self, + description: Optional[str] = Unassigned(), + properties: Optional[Dict[str, str]] = Unassigned(), + properties_to_remove: Optional[List[str]] = Unassigned(), + ) -> Optional["Context"]: + """ + Update a Context resource + + Parameters: + properties_to_remove: A list of properties to remove. + + Returns: + The Context resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceNotFound: Resource being access is not found. 
+ """ + + logger.info("Updating context resource.") + client = Base.get_sagemaker_client() + + operation_input_args = { + "ContextName": self.context_name, + "Description": description, + "Properties": properties, + "PropertiesToRemove": properties_to_remove, + } + logger.debug(f"Input request: {operation_input_args}") + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + # create the resource + response = client.update_context(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() + return self + @Base.add_validate_call + def delete( + self, + ) -> None: + """ + Delete a Context resource + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. 
+ """ + + client = Base.get_sagemaker_client() + + operation_input_args = { + "ContextName": self.context_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_context(**operation_input_args) + + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + @classmethod @Base.add_validate_call def get_all( cls, - latest_heartbeat_after: Optional[datetime.datetime] = Unassigned(), - model_name: Optional[str] = Unassigned(), - device_fleet_name: Optional[str] = Unassigned(), + source_uri: Optional[str] = Unassigned(), + context_type: Optional[str] = Unassigned(), + created_after: Optional[datetime.datetime] = Unassigned(), + created_before: Optional[datetime.datetime] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["Device"]: + ) -> ResourceIterator["Context"]: """ - Get all Device resources + Get all Context resources Parameters: - next_token: The response from the last list when returning a list large enough to need tokening. - max_results: Maximum number of results to select. - latest_heartbeat_after: Select fleets where the job was updated after X - model_name: A filter that searches devices that contains this name in any of their models. - device_fleet_name: Filter for fleets containing this name in their device fleet name. + source_uri: A filter that returns only contexts with the specified source URI. + context_type: A filter that returns only contexts of the specified type. + created_after: A filter that returns only contexts created on or after the specified time. + created_before: A filter that returns only contexts created on or before the specified time. + sort_by: The property used to sort results. The default value is CreationTime. + sort_order: The sort order. 
The default value is Descending. + next_token: If the previous call to ListContexts didn't return the full set of contexts, the call returns a token for getting the next set of contexts. + max_results: The maximum number of contexts to return in the response. The default value is 10. session: Boto3 session. region: Region name. Returns: - Iterator for listed Device resources. + Iterator for listed Context resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -5603,6 +6074,7 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ client = Base.get_sagemaker_client( @@ -5610,9 +6082,12 @@ def get_all( ) operation_input_args = { - "LatestHeartbeatAfter": latest_heartbeat_after, - "ModelName": model_name, - "DeviceFleetName": device_fleet_name, + "SourceUri": source_uri, + "ContextType": context_type, + "CreatedAfter": created_after, + "CreatedBefore": created_before, + "SortBy": sort_by, + "SortOrder": sort_order, } # serialize the input request @@ -5621,42 +6096,48 @@ def get_all( return ResourceIterator( client=client, - list_method="list_devices", - summaries_key="DeviceSummaries", - summary_name="DeviceSummary", - resource_cls=Device, + list_method="list_contexts", + summaries_key="ContextSummaries", + summary_name="ContextSummary", + resource_cls=Context, list_method_kwargs=operation_input_args, ) -class DeviceFleet(Base): +class DataQualityJobDefinition(Base): """ - Class representing resource DeviceFleet + Class representing resource DataQualityJobDefinition Attributes: - device_fleet_name: The name of the fleet. - device_fleet_arn: The The Amazon Resource Name (ARN) of the fleet. - output_config: The output configuration for storing sampled data. - creation_time: Timestamp of when the device fleet was created. - last_modified_time: Timestamp of when the device fleet was last updated. 
- description: A description of the fleet. - role_arn: The Amazon Resource Name (ARN) that has access to Amazon Web Services Internet of Things (IoT). - iot_role_alias: The Amazon Resource Name (ARN) alias created in Amazon Web Services Internet of Things (IoT). - - """ + job_definition_arn: The Amazon Resource Name (ARN) of the data quality monitoring job definition. + job_definition_name: The name of the data quality monitoring job definition. + creation_time: The time that the data quality monitoring job definition was created. + data_quality_app_specification: Information about the container that runs the data quality monitoring job. + data_quality_job_input: The list of inputs for the data quality monitoring job. Currently endpoints are supported. + data_quality_job_output_config: + job_resources: + role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. + data_quality_baseline_config: The constraints and baselines for the data quality monitoring job definition. + network_config: The networking configuration for the data quality monitoring job. 
+ stopping_condition: - device_fleet_name: str - device_fleet_arn: Optional[str] = Unassigned() - output_config: Optional[EdgeOutputConfig] = Unassigned() - description: Optional[str] = Unassigned() + """ + + job_definition_name: str + job_definition_arn: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() + data_quality_baseline_config: Optional[DataQualityBaselineConfig] = Unassigned() + data_quality_app_specification: Optional[DataQualityAppSpecification] = Unassigned() + data_quality_job_input: Optional[DataQualityJobInput] = Unassigned() + data_quality_job_output_config: Optional[MonitoringOutputConfig] = Unassigned() + job_resources: Optional[MonitoringResources] = Unassigned() + network_config: Optional[MonitoringNetworkConfig] = Unassigned() role_arn: Optional[str] = Unassigned() - iot_role_alias: Optional[str] = Unassigned() + stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "device_fleet_name" + resource_name = "data_quality_job_definition_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -5667,24 +6148,42 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object device_fleet") + logger.error("Name attribute not found for object data_quality_job_definition") return None def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): config_schema_for_resource = { - "output_config": { - "s3_output_location": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "data_quality_job_input": { + "endpoint_input": { + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + "batch_transform_input": { + 
"data_captured_destination_s3_uri": {"type": "string"}, + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, }, + "data_quality_job_output_config": {"kms_key_id": {"type": "string"}}, + "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, "role_arn": {"type": "string"}, - "iot_role_alias": {"type": "string"}, + "data_quality_baseline_config": { + "constraints_resource": {"s3_uri": {"type": "string"}}, + "statistics_resource": {"s3_uri": {"type": "string"}}, + }, + "network_config": { + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + } + }, } return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "DeviceFleet", **kwargs + config_schema_for_resource, "DataQualityJobDefinition", **kwargs ), ) @@ -5695,30 +6194,38 @@ def wrapper(*args, **kwargs): @Base.add_validate_call def create( cls, - device_fleet_name: str, - output_config: EdgeOutputConfig, - role_arn: Optional[str] = Unassigned(), - description: Optional[str] = Unassigned(), + job_definition_name: str, + data_quality_app_specification: DataQualityAppSpecification, + data_quality_job_input: DataQualityJobInput, + data_quality_job_output_config: MonitoringOutputConfig, + job_resources: MonitoringResources, + role_arn: str, + data_quality_baseline_config: Optional[DataQualityBaselineConfig] = Unassigned(), + network_config: Optional[MonitoringNetworkConfig] = Unassigned(), + stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - enable_iot_role_alias: Optional[bool] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["DeviceFleet"]: + ) -> Optional["DataQualityJobDefinition"]: """ - Create a DeviceFleet resource + Create a DataQualityJobDefinition resource Parameters: - 
device_fleet_name: The name of the fleet that the device belongs to. - output_config: The output configuration for storing sample data collected by the fleet. - role_arn: The Amazon Resource Name (ARN) that has access to Amazon Web Services Internet of Things (IoT). - description: A description of the fleet. - tags: Creates tags for the specified fleet. - enable_iot_role_alias: Whether to create an Amazon Web Services IoT Role Alias during device fleet creation. The name of the role alias generated will match this pattern: "SageMakerEdge-{DeviceFleetName}". For example, if your device fleet is called "demo-fleet", the name of the role alias will be "SageMakerEdge-demo-fleet". + job_definition_name: The name for the monitoring job definition. + data_quality_app_specification: Specifies the container that runs the monitoring job. + data_quality_job_input: A list of inputs for the monitoring job. Currently endpoints are supported as monitoring inputs. + data_quality_job_output_config: + job_resources: + role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. + data_quality_baseline_config: Configures the constraints and baselines for the monitoring job. + network_config: Specifies networking configuration for the monitoring job. + stopping_condition: + tags: (Optional) An array of key-value pairs. For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide. session: Boto3 session. region: Region name. Returns: - The DeviceFleet resource. + The DataQualityJobDefinition resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -5737,22 +6244,26 @@ def create( S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating device_fleet resource.") + logger.info("Creating data_quality_job_definition resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "DeviceFleetName": device_fleet_name, + "JobDefinitionName": job_definition_name, + "DataQualityBaselineConfig": data_quality_baseline_config, + "DataQualityAppSpecification": data_quality_app_specification, + "DataQualityJobInput": data_quality_job_input, + "DataQualityJobOutputConfig": data_quality_job_output_config, + "JobResources": job_resources, + "NetworkConfig": network_config, "RoleArn": role_arn, - "Description": description, - "OutputConfig": output_config, + "StoppingCondition": stopping_condition, "Tags": tags, - "EnableIotRoleAlias": enable_iot_role_alias, } operation_input_args = Base.populate_chained_attributes( - resource_name="DeviceFleet", operation_input_args=operation_input_args + resource_name="DataQualityJobDefinition", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -5761,29 +6272,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_device_fleet(**operation_input_args) + response = client.create_data_quality_job_definition(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(device_fleet_name=device_fleet_name, session=session, region=region) + return cls.get(job_definition_name=job_definition_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - device_fleet_name: str, + job_definition_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["DeviceFleet"]: + ) -> Optional["DataQualityJobDefinition"]: """ - Get a DeviceFleet resource + Get a 
DataQualityJobDefinition resource Parameters: - device_fleet_name: The name of the fleet. + job_definition_name: The name of the data quality monitoring job definition to describe. session: Boto3 session. region: Region name. Returns: - The DeviceFleet resource. + The DataQualityJobDefinition resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -5799,7 +6310,7 @@ def get( """ operation_input_args = { - "DeviceFleetName": device_fleet_name, + "JobDefinitionName": job_definition_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -5808,24 +6319,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_device_fleet(**operation_input_args) + response = client.describe_data_quality_job_definition(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeDeviceFleetResponse") - device_fleet = cls(**transformed_response) - return device_fleet + transformed_response = transform(response, "DescribeDataQualityJobDefinitionResponse") + data_quality_job_definition = cls(**transformed_response) + return data_quality_job_definition @Base.add_validate_call def refresh( self, - ) -> Optional["DeviceFleet"]: + ) -> Optional["DataQualityJobDefinition"]: """ - Refresh a DeviceFleet resource + Refresh a DataQualityJobDefinition resource Returns: - The DeviceFleet resource. + The DataQualityJobDefinition resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -5841,70 +6352,17 @@ def refresh( """ operation_input_args = { - "DeviceFleetName": self.device_fleet_name, + "JobDefinitionName": self.job_definition_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_device_fleet(**operation_input_args) + response = client.describe_data_quality_job_definition(**operation_input_args) # deserialize response and update self - transform(response, "DescribeDeviceFleetResponse", self) - return self - - @populate_inputs_decorator - @Base.add_validate_call - def update( - self, - output_config: EdgeOutputConfig, - role_arn: Optional[str] = Unassigned(), - description: Optional[str] = Unassigned(), - enable_iot_role_alias: Optional[bool] = Unassigned(), - ) -> Optional["DeviceFleet"]: - """ - Update a DeviceFleet resource - - Parameters: - enable_iot_role_alias: Whether to create an Amazon Web Services IoT Role Alias during device fleet creation. The name of the role alias generated will match this pattern: "SageMakerEdge-{DeviceFleetName}". For example, if your device fleet is called "demo-fleet", the name of the role alias will be "SageMakerEdge-demo-fleet". - - Returns: - The DeviceFleet resource. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceInUse: Resource being accessed is in use. 
- """ - - logger.info("Updating device_fleet resource.") - client = Base.get_sagemaker_client() - - operation_input_args = { - "DeviceFleetName": self.device_fleet_name, - "RoleArn": role_arn, - "Description": description, - "OutputConfig": output_config, - "EnableIotRoleAlias": enable_iot_role_alias, - } - logger.debug(f"Input request: {operation_input_args}") - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - # create the resource - response = client.update_device_fleet(**operation_input_args) - logger.debug(f"Response: {response}") - self.refresh() - + transform(response, "DescribeDataQualityJobDefinitionResponse", self) return self @Base.add_validate_call @@ -5912,7 +6370,7 @@ def delete( self, ) -> None: """ - Delete a DeviceFleet resource + Delete a DataQualityJobDefinition resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -5924,19 +6382,19 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. + ResourceNotFound: Resource being access is not found. 
""" client = Base.get_sagemaker_client() operation_input_args = { - "DeviceFleetName": self.device_fleet_name, + "JobDefinitionName": self.job_definition_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_device_fleet(**operation_input_args) + client.delete_data_quality_job_definition(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @@ -5944,34 +6402,32 @@ def delete( @Base.add_validate_call def get_all( cls, - creation_time_after: Optional[datetime.datetime] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - last_modified_time_before: Optional[datetime.datetime] = Unassigned(), - name_contains: Optional[str] = Unassigned(), + endpoint_name: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["DeviceFleet"]: + ) -> ResourceIterator["DataQualityJobDefinition"]: """ - Get all DeviceFleet resources + Get all DataQualityJobDefinition resources Parameters: - next_token: The response from the last list when returning a list large enough to need tokening. - max_results: The maximum number of results to select. - creation_time_after: Filter fleets where packaging job was created after specified time. - creation_time_before: Filter fleets where the edge packaging job was created before specified time. 
- last_modified_time_after: Select fleets where the job was updated after X - last_modified_time_before: Select fleets where the job was updated before X - name_contains: Filter for fleets containing this name in their fleet device name. - sort_by: The column to sort by. - sort_order: What direction to sort in. + endpoint_name: A filter that lists the data quality job definitions associated with the specified endpoint. + sort_by: The field to sort results by. The default is CreationTime. + sort_order: Whether to sort the results in Ascending or Descending order. The default is Descending. + next_token: If the result of the previous ListDataQualityJobDefinitions request was truncated, the response includes a NextToken. To retrieve the next set of transform jobs, use the token in the next request.> + max_results: The maximum number of data quality monitoring job definitions to return in the response. + name_contains: A string in the data quality monitoring job definition name. This filter returns only data quality monitoring job definitions whose name contains the specified string. + creation_time_before: A filter that returns only data quality monitoring job definitions created before the specified time. + creation_time_after: A filter that returns only data quality monitoring job definitions created after the specified time. session: Boto3 session. region: Region name. Returns: - Iterator for listed DeviceFleet resources. + Iterator for listed DataQualityJobDefinition resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -5990,86 +6446,101 @@ def get_all( ) operation_input_args = { - "CreationTimeAfter": creation_time_after, - "CreationTimeBefore": creation_time_before, - "LastModifiedTimeAfter": last_modified_time_after, - "LastModifiedTimeBefore": last_modified_time_before, - "NameContains": name_contains, + "EndpointName": endpoint_name, "SortBy": sort_by, "SortOrder": sort_order, + "NameContains": name_contains, + "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, + } + custom_key_mapping = { + "monitoring_job_definition_name": "job_definition_name", + "monitoring_job_definition_arn": "job_definition_arn", } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") return ResourceIterator( client=client, - list_method="list_device_fleets", - summaries_key="DeviceFleetSummaries", - summary_name="DeviceFleetSummary", - resource_cls=DeviceFleet, + list_method="list_data_quality_job_definitions", + summaries_key="JobDefinitionSummaries", + summary_name="MonitoringJobDefinitionSummary", + resource_cls=DataQualityJobDefinition, + custom_key_mapping=custom_key_mapping, list_method_kwargs=operation_input_args, ) - @Base.add_validate_call - def deregister_devices( - self, - device_names: List[str], - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> None: - """ - Deregisters the specified devices. - Parameters: - device_names: The unique IDs of the devices. - session: Boto3 session. - region: Region name. +class Device(Base): + """ + Class representing resource Device - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
- The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - """ + Attributes: + device_name: The unique identifier of the device. + device_fleet_name: The name of the fleet the device belongs to. + registration_time: The timestamp of the last registration or de-reregistration. + device_arn: The Amazon Resource Name (ARN) of the device. + description: A description of the device. + iot_thing_name: The Amazon Web Services Internet of Things (IoT) object thing name associated with the device. + latest_heartbeat: The last heartbeat received from the device. + models: Models on the device. + max_models: The maximum number of models. + next_token: The response from the last list when returning a list large enough to need tokening. + agent_version: Edge Manager agent version. - operation_input_args = { - "DeviceFleetName": self.device_fleet_name, - "DeviceNames": device_names, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") + """ - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) + device_name: str + device_fleet_name: str + device_arn: Optional[str] = Unassigned() + description: Optional[str] = Unassigned() + iot_thing_name: Optional[str] = Unassigned() + registration_time: Optional[datetime.datetime] = Unassigned() + latest_heartbeat: Optional[datetime.datetime] = Unassigned() + models: Optional[List[EdgeModel]] = Unassigned() + max_models: Optional[int] = Unassigned() + next_token: Optional[str] = Unassigned() + agent_version: Optional[str] = Unassigned() - logger.debug(f"Calling deregister_devices API") - response = client.deregister_devices(**operation_input_args) - logger.debug(f"Response: 
{response}") + def get_name(self) -> str: + attributes = vars(self) + resource_name = "device_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object device") + return None + @classmethod @Base.add_validate_call - def get_report( - self, + def get( + cls, + device_name: str, + device_fleet_name: str, + next_token: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional[GetDeviceFleetReportResponse]: + ) -> Optional["Device"]: """ - Describes a fleet. + Get a Device resource Parameters: + device_name: The unique ID of the device. + device_fleet_name: The name of the fleet the devices belong to. + next_token: Next token of device description. session: Boto3 session. region: Region name. Returns: - GetDeviceFleetReportResponse + The Device resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -6081,10 +6552,13 @@ def get_report( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "DeviceFleetName": self.device_fleet_name, + "NextToken": next_token, + "DeviceName": device_name, + "DeviceFleetName": device_fleet_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -6093,30 +6567,24 @@ def get_report( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) + response = client.describe_device(**operation_input_args) - logger.debug(f"Calling get_device_fleet_report API") - response = client.get_device_fleet_report(**operation_input_args) - logger.debug(f"Response: {response}") + logger.debug(response) - transformed_response = transform(response, "GetDeviceFleetReportResponse") - return GetDeviceFleetReportResponse(**transformed_response) + # deserialize the response + transformed_response = transform(response, "DescribeDeviceResponse") + device = cls(**transformed_response) + return device @Base.add_validate_call - def register_devices( + def refresh( self, - devices: List[Device], - tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> None: + ) -> Optional["Device"]: """ - Register devices. + Refresh a Device resource - Parameters: - devices: A list of devices to register with SageMaker Edge Manager. - tags: The tags associated with devices. - session: Boto3 session. - region: Region name. + Returns: + The Device resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -6128,41 +6596,50 @@ def register_devices( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { + "NextToken": self.next_token, + "DeviceName": self.device_name, "DeviceFleetName": self.device_fleet_name, - "Devices": devices, - "Tags": tags, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) + client = Base.get_sagemaker_client() + response = client.describe_device(**operation_input_args) - logger.debug(f"Calling register_devices API") - response = client.register_devices(**operation_input_args) - logger.debug(f"Response: {response}") + # deserialize response and update self + transform(response, "DescribeDeviceResponse", self) + return self + @classmethod @Base.add_validate_call - def update_devices( - self, - devices: List[Device], + def get_all( + cls, + latest_heartbeat_after: Optional[datetime.datetime] = Unassigned(), + model_name: Optional[str] = Unassigned(), + device_fleet_name: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> None: + ) -> ResourceIterator["Device"]: """ - Updates one or more devices in a fleet. + Get all Device resources Parameters: - devices: List of devices to register with Edge Manager agent. + next_token: The response from the last list when returning a list large enough to need tokening. + max_results: Maximum number of results to select. + latest_heartbeat_after: Select fleets where the job was updated after X + model_name: A filter that searches devices that contains this name in any of their models. + device_fleet_name: Filter for fleets containing this name in their device fleet name. session: Boto3 session. region: Region name. + Returns: + Iterator for listed Device resources. + Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
The error message and error code can be parsed from the exception as follows: @@ -6175,81 +6652,58 @@ def update_devices( ``` """ + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + operation_input_args = { - "DeviceFleetName": self.device_fleet_name, - "Devices": devices, + "LatestHeartbeatAfter": latest_heartbeat_after, + "ModelName": model_name, + "DeviceFleetName": device_fleet_name, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" + return ResourceIterator( + client=client, + list_method="list_devices", + summaries_key="DeviceSummaries", + summary_name="DeviceSummary", + resource_cls=Device, + list_method_kwargs=operation_input_args, ) - logger.debug(f"Calling update_devices API") - response = client.update_devices(**operation_input_args) - logger.debug(f"Response: {response}") - -class Domain(Base): +class DeviceFleet(Base): """ - Class representing resource Domain + Class representing resource DeviceFleet Attributes: - domain_arn: The domain's Amazon Resource Name (ARN). - domain_id: The domain ID. - domain_name: The domain name. - home_efs_file_system_id: The ID of the Amazon Elastic File System managed by this Domain. - single_sign_on_managed_application_instance_id: The IAM Identity Center managed application instance ID. - single_sign_on_application_arn: The ARN of the application managed by SageMaker in IAM Identity Center. This value is only returned for domains created after October 1, 2023. - status: The status. - creation_time: The creation time. - last_modified_time: The last modified time. - failure_reason: The failure reason. - security_group_id_for_domain_boundary: The ID of the security group that authorizes traffic between the RSessionGateway apps and the RStudioServerPro app. 
- auth_mode: The domain's authentication mode. - default_user_settings: Settings which are applied to UserProfiles in this domain if settings are not explicitly specified in a given UserProfile. - domain_settings: A collection of Domain settings. - app_network_access_type: Specifies the VPC used for non-EFS traffic. The default value is PublicInternetOnly. PublicInternetOnly - Non-EFS traffic is through a VPC managed by Amazon SageMaker, which allows direct internet access VpcOnly - All traffic is through the specified VPC and subnets - home_efs_file_system_kms_key_id: Use KmsKeyId. - subnet_ids: The VPC subnets that the domain uses for communication. - url: The domain's URL. - vpc_id: The ID of the Amazon Virtual Private Cloud (VPC) that the domain uses for communication. - kms_key_id: The Amazon Web Services KMS customer managed key used to encrypt the EFS volume attached to the domain. - app_security_group_management: The entity that creates and manages the required security groups for inter-app communication in VPCOnly mode. Required when CreateDomain.AppNetworkAccessType is VPCOnly and DomainSettings.RStudioServerProDomainSettings.DomainExecutionRoleArn is provided. - tag_propagation: Indicates whether custom tag propagation is supported for the domain. - default_space_settings: The default settings for shared spaces that users create in the domain. + device_fleet_name: The name of the fleet. + device_fleet_arn: The The Amazon Resource Name (ARN) of the fleet. + output_config: The output configuration for storing sampled data. + creation_time: Timestamp of when the device fleet was created. + last_modified_time: Timestamp of when the device fleet was last updated. + description: A description of the fleet. + role_arn: The Amazon Resource Name (ARN) that has access to Amazon Web Services Internet of Things (IoT). + iot_role_alias: The Amazon Resource Name (ARN) alias created in Amazon Web Services Internet of Things (IoT). 
""" - domain_id: str - domain_arn: Optional[str] = Unassigned() - domain_name: Optional[str] = Unassigned() - home_efs_file_system_id: Optional[str] = Unassigned() - single_sign_on_managed_application_instance_id: Optional[str] = Unassigned() - single_sign_on_application_arn: Optional[str] = Unassigned() - status: Optional[str] = Unassigned() + device_fleet_name: str + device_fleet_arn: Optional[str] = Unassigned() + output_config: Optional[EdgeOutputConfig] = Unassigned() + description: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() - failure_reason: Optional[str] = Unassigned() - security_group_id_for_domain_boundary: Optional[str] = Unassigned() - auth_mode: Optional[str] = Unassigned() - default_user_settings: Optional[UserSettings] = Unassigned() - domain_settings: Optional[DomainSettings] = Unassigned() - app_network_access_type: Optional[str] = Unassigned() - home_efs_file_system_kms_key_id: Optional[str] = Unassigned() - subnet_ids: Optional[List[str]] = Unassigned() - url: Optional[str] = Unassigned() - vpc_id: Optional[str] = Unassigned() - kms_key_id: Optional[str] = Unassigned() - app_security_group_management: Optional[str] = Unassigned() - tag_propagation: Optional[str] = Unassigned() - default_space_settings: Optional[DefaultSpaceSettings] = Unassigned() + role_arn: Optional[str] = Unassigned() + iot_role_alias: Optional[str] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "domain_name" + resource_name = "device_fleet_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -6260,68 +6714,24 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object domain") + logger.error("Name attribute not found for object device_fleet") return None def 
populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): config_schema_for_resource = { - "security_group_id_for_domain_boundary": {"type": "string"}, - "default_user_settings": { - "execution_role": {"type": "string"}, - "security_groups": {"type": "array", "items": {"type": "string"}}, - "sharing_settings": { - "s3_output_path": {"type": "string"}, - "s3_kms_key_id": {"type": "string"}, - }, - "canvas_app_settings": { - "time_series_forecasting_settings": { - "amazon_forecast_role_arn": {"type": "string"} - }, - "model_register_settings": { - "cross_account_model_register_role_arn": {"type": "string"} - }, - "workspace_settings": { - "s3_artifact_path": {"type": "string"}, - "s3_kms_key_id": {"type": "string"}, - }, - "generative_ai_settings": {"amazon_bedrock_role_arn": {"type": "string"}}, - "emr_serverless_settings": {"execution_role_arn": {"type": "string"}}, - }, - "jupyter_lab_app_settings": { - "emr_settings": { - "assumable_role_arns": {"type": "array", "items": {"type": "string"}}, - "execution_role_arns": {"type": "array", "items": {"type": "string"}}, - } - }, - }, - "domain_settings": { - "security_group_ids": {"type": "array", "items": {"type": "string"}}, - "r_studio_server_pro_domain_settings": { - "domain_execution_role_arn": {"type": "string"} - }, - "execution_role_identity_config": {"type": "string"}, - }, - "home_efs_file_system_kms_key_id": {"type": "string"}, - "subnet_ids": {"type": "array", "items": {"type": "string"}}, - "kms_key_id": {"type": "string"}, - "app_security_group_management": {"type": "string"}, - "default_space_settings": { - "execution_role": {"type": "string"}, - "security_groups": {"type": "array", "items": {"type": "string"}}, - "jupyter_lab_app_settings": { - "emr_settings": { - "assumable_role_arns": {"type": "array", "items": {"type": "string"}}, - "execution_role_arns": {"type": "array", "items": {"type": "string"}}, - } - }, + "output_config": { + "s3_output_location": 
{"type": "string"}, + "kms_key_id": {"type": "string"}, }, + "role_arn": {"type": "string"}, + "iot_role_alias": {"type": "string"}, } return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "Domain", **kwargs + config_schema_for_resource, "DeviceFleet", **kwargs ), ) @@ -6332,44 +6742,30 @@ def wrapper(*args, **kwargs): @Base.add_validate_call def create( cls, - domain_name: str, - auth_mode: str, - default_user_settings: UserSettings, - subnet_ids: List[str], - vpc_id: str, - domain_settings: Optional[DomainSettings] = Unassigned(), + device_fleet_name: str, + output_config: EdgeOutputConfig, + role_arn: Optional[str] = Unassigned(), + description: Optional[str] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - app_network_access_type: Optional[str] = Unassigned(), - home_efs_file_system_kms_key_id: Optional[str] = Unassigned(), - kms_key_id: Optional[str] = Unassigned(), - app_security_group_management: Optional[str] = Unassigned(), - tag_propagation: Optional[str] = Unassigned(), - default_space_settings: Optional[DefaultSpaceSettings] = Unassigned(), + enable_iot_role_alias: Optional[bool] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Domain"]: + ) -> Optional["DeviceFleet"]: """ - Create a Domain resource + Create a DeviceFleet resource Parameters: - domain_name: A name for the domain. - auth_mode: The mode of authentication that members use to access the domain. - default_user_settings: The default settings to use to create a user profile when UserSettings isn't specified in the call to the CreateUserProfile API. SecurityGroups is aggregated when specified in both calls. For all other settings in UserSettings, the values specified in CreateUserProfile take precedence over those specified in CreateDomain. - subnet_ids: The VPC subnets that the domain uses for communication. 
- vpc_id: The ID of the Amazon Virtual Private Cloud (VPC) that the domain uses for communication. - domain_settings: A collection of Domain settings. - tags: Tags to associated with the Domain. Each tag consists of a key and an optional value. Tag keys must be unique per resource. Tags are searchable using the Search API. Tags that you specify for the Domain are also added to all Apps that the Domain launches. - app_network_access_type: Specifies the VPC used for non-EFS traffic. The default value is PublicInternetOnly. PublicInternetOnly - Non-EFS traffic is through a VPC managed by Amazon SageMaker, which allows direct internet access VpcOnly - All traffic is through the specified VPC and subnets - home_efs_file_system_kms_key_id: Use KmsKeyId. - kms_key_id: SageMaker uses Amazon Web Services KMS to encrypt EFS and EBS volumes attached to the domain with an Amazon Web Services managed key by default. For more control, specify a customer managed key. - app_security_group_management: The entity that creates and manages the required security groups for inter-app communication in VPCOnly mode. Required when CreateDomain.AppNetworkAccessType is VPCOnly and DomainSettings.RStudioServerProDomainSettings.DomainExecutionRoleArn is provided. If setting up the domain for use with RStudio, this value must be set to Service. - tag_propagation: Indicates whether custom tag propagation is supported for the domain. Defaults to DISABLED. - default_space_settings: The default settings for shared spaces that users create in the domain. + device_fleet_name: The name of the fleet that the device belongs to. + output_config: The output configuration for storing sample data collected by the fleet. + role_arn: The Amazon Resource Name (ARN) that has access to Amazon Web Services Internet of Things (IoT). + description: A description of the fleet. + tags: Creates tags for the specified fleet. 
+ enable_iot_role_alias: Whether to create an Amazon Web Services IoT Role Alias during device fleet creation. The name of the role alias generated will match this pattern: "SageMakerEdge-{DeviceFleetName}". For example, if your device fleet is called "demo-fleet", the name of the role alias will be "SageMakerEdge-demo-fleet". session: Boto3 session. region: Region name. Returns: - The Domain resource. + The DeviceFleet resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -6388,29 +6784,22 @@ def create( S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating domain resource.") + logger.info("Creating device_fleet resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "DomainName": domain_name, - "AuthMode": auth_mode, - "DefaultUserSettings": default_user_settings, - "DomainSettings": domain_settings, - "SubnetIds": subnet_ids, - "VpcId": vpc_id, + "DeviceFleetName": device_fleet_name, + "RoleArn": role_arn, + "Description": description, + "OutputConfig": output_config, "Tags": tags, - "AppNetworkAccessType": app_network_access_type, - "HomeEfsFileSystemKmsKeyId": home_efs_file_system_kms_key_id, - "KmsKeyId": kms_key_id, - "AppSecurityGroupManagement": app_security_group_management, - "TagPropagation": tag_propagation, - "DefaultSpaceSettings": default_space_settings, + "EnableIotRoleAlias": enable_iot_role_alias, } operation_input_args = Base.populate_chained_attributes( - resource_name="Domain", operation_input_args=operation_input_args + resource_name="DeviceFleet", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -6419,29 +6808,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_domain(**operation_input_args) + response = 
client.create_device_fleet(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(domain_id=response["DomainId"], session=session, region=region) + return cls.get(device_fleet_name=device_fleet_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - domain_id: str, + device_fleet_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Domain"]: + ) -> Optional["DeviceFleet"]: """ - Get a Domain resource + Get a DeviceFleet resource Parameters: - domain_id: The domain ID. + device_fleet_name: The name of the fleet. session: Boto3 session. region: Region name. Returns: - The Domain resource. + The DeviceFleet resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -6457,7 +6846,7 @@ def get( """ operation_input_args = { - "DomainId": domain_id, + "DeviceFleetName": device_fleet_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -6466,24 +6855,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_domain(**operation_input_args) + response = client.describe_device_fleet(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeDomainResponse") - domain = cls(**transformed_response) - return domain + transformed_response = transform(response, "DescribeDeviceFleetResponse") + device_fleet = cls(**transformed_response) + return device_fleet @Base.add_validate_call def refresh( self, - ) -> Optional["Domain"]: + ) -> Optional["DeviceFleet"]: """ - Refresh a Domain resource + Refresh a DeviceFleet resource Returns: - The Domain resource. + The DeviceFleet resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -6499,39 +6888,36 @@ def refresh( """ operation_input_args = { - "DomainId": self.domain_id, + "DeviceFleetName": self.device_fleet_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_domain(**operation_input_args) + response = client.describe_device_fleet(**operation_input_args) # deserialize response and update self - transform(response, "DescribeDomainResponse", self) + transform(response, "DescribeDeviceFleetResponse", self) return self @populate_inputs_decorator @Base.add_validate_call def update( self, - default_user_settings: Optional[UserSettings] = Unassigned(), - domain_settings_for_update: Optional[DomainSettingsForUpdate] = Unassigned(), - app_security_group_management: Optional[str] = Unassigned(), - default_space_settings: Optional[DefaultSpaceSettings] = Unassigned(), - subnet_ids: Optional[List[str]] = Unassigned(), - app_network_access_type: Optional[str] = Unassigned(), - tag_propagation: Optional[str] = Unassigned(), - ) -> Optional["Domain"]: + output_config: EdgeOutputConfig, + role_arn: Optional[str] = Unassigned(), + description: Optional[str] = Unassigned(), + enable_iot_role_alias: Optional[bool] = Unassigned(), + ) -> Optional["DeviceFleet"]: """ - Update a Domain resource + Update a DeviceFleet resource Parameters: - domain_settings_for_update: A collection of DomainSettings configuration values to update. + enable_iot_role_alias: Whether to create an Amazon Web Services IoT Role Alias during device fleet creation. The name of the role alias generated will match this pattern: "SageMakerEdge-{DeviceFleetName}". For example, if your device fleet is called "demo-fleet", the name of the role alias will be "SageMakerEdge-demo-fleet". Returns: - The Domain resource. + The DeviceFleet resource. 
Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -6544,22 +6930,17 @@ def update( error_code = e.response['Error']['Code'] ``` ResourceInUse: Resource being accessed is in use. - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ResourceNotFound: Resource being access is not found. """ - logger.info("Updating domain resource.") + logger.info("Updating device_fleet resource.") client = Base.get_sagemaker_client() operation_input_args = { - "DomainId": self.domain_id, - "DefaultUserSettings": default_user_settings, - "DomainSettingsForUpdate": domain_settings_for_update, - "AppSecurityGroupManagement": app_security_group_management, - "DefaultSpaceSettings": default_space_settings, - "SubnetIds": subnet_ids, - "AppNetworkAccessType": app_network_access_type, - "TagPropagation": tag_propagation, + "DeviceFleetName": self.device_fleet_name, + "RoleArn": role_arn, + "Description": description, + "OutputConfig": output_config, + "EnableIotRoleAlias": enable_iot_role_alias, } logger.debug(f"Input request: {operation_input_args}") # serialize the input request @@ -6567,7 +6948,7 @@ def update( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.update_domain(**operation_input_args) + response = client.update_device_fleet(**operation_input_args) logger.debug(f"Response: {response}") self.refresh() @@ -6576,10 +6957,9 @@ def update( @Base.add_validate_call def delete( self, - retention_policy: Optional[RetentionPolicy] = Unassigned(), ) -> None: """ - Delete a Domain resource + Delete a DeviceFleet resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -6592,99 +6972,107 @@ def delete( error_code = e.response['Error']['Code'] ``` ResourceInUse: Resource being accessed is in use. - ResourceNotFound: Resource being access is not found. 
""" client = Base.get_sagemaker_client() operation_input_args = { - "DomainId": self.domain_id, - "RetentionPolicy": retention_policy, + "DeviceFleetName": self.device_fleet_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_domain(**operation_input_args) + client.delete_device_fleet(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + @classmethod @Base.add_validate_call - def wait_for_status( - self, - target_status: Literal[ - "Deleting", - "Failed", - "InService", - "Pending", - "Updating", - "Update_Failed", - "Delete_Failed", - ], - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: + def get_all( + cls, + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["DeviceFleet"]: """ - Wait for a Domain resource to reach certain status. + Get all DeviceFleet resources Parameters: - target_status: The status to wait for. - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. + next_token: The response from the last list when returning a list large enough to need tokening. + max_results: The maximum number of results to select. + creation_time_after: Filter fleets where packaging job was created after specified time. + creation_time_before: Filter fleets where the edge packaging job was created before specified time. 
+ last_modified_time_after: Select fleets where the job was updated after X + last_modified_time_before: Select fleets where the job was updated before X + name_contains: Filter for fleets containing this name in their fleet device name. + sort_by: The column to sort by. + sort_order: What direction to sort in. + session: Boto3 session. + region: Region name. - Raises: - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - FailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. + Returns: + Iterator for listed DeviceFleet resources. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` """ - start_time = time.time() - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" ) - progress.add_task(f"Waiting for Domain to reach [bold]{target_status} status...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ), - transient=True, - ): - while True: - self.refresh() - current_status = self.status - status.update(f"Current status: [bold]{current_status}") - if target_status == current_status: - logger.info(f"Final Resource Status: [bold]{current_status}") - return + operation_input_args = { + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "LastModifiedTimeAfter": last_modified_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + 
"NameContains": name_contains, + "SortBy": sort_by, + "SortOrder": sort_order, + } - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="Domain", status=current_status, reason=self.failure_reason - ) + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="Domain", status=current_status) - time.sleep(poll) + return ResourceIterator( + client=client, + list_method="list_device_fleets", + summaries_key="DeviceFleetSummaries", + summary_name="DeviceFleetSummary", + resource_cls=DeviceFleet, + list_method_kwargs=operation_input_args, + ) @Base.add_validate_call - def wait_for_delete( + def deregister_devices( self, - poll: int = 5, - timeout: Optional[int] = None, + device_names: List[str], + session: Optional[Session] = None, + region: Optional[str] = None, ) -> None: """ - Wait for a Domain resource to be deleted. + Deregisters the specified devices. Parameters: - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. + device_names: The unique IDs of the devices. + session: Boto3 session. + region: Region name. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -6696,157 +7084,87 @@ def wait_for_delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - DeleteFailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. 
""" - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task("Waiting for Domain to be deleted...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ) - ): - while True: - try: - self.refresh() - current_status = self.status - status.update(f"Current status: [bold]{current_status}") - if ( - "delete_failed" in current_status.lower() - or "deletefailed" in current_status.lower() - ): - raise DeleteFailedStatusError( - resource_type="Domain", reason=self.failure_reason - ) + operation_input_args = { + "DeviceFleetName": self.device_fleet_name, + "DeviceNames": device_names, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="Domain", status=current_status) - except botocore.exceptions.ClientError as e: - error_code = e.response["Error"]["Code"] + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) - if "ResourceNotFound" in error_code or "ValidationException" in error_code: - logger.info("Resource was not found. It may have been deleted.") - return - raise e - time.sleep(poll) + logger.debug(f"Calling deregister_devices API") + response = client.deregister_devices(**operation_input_args) + logger.debug(f"Response: {response}") - @classmethod @Base.add_validate_call - def get_all( - cls, + def get_report( + self, session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["Domain"]: + ) -> Optional[GetDeviceFleetReportResponse]: """ - Get all Domain resources. + Describes a fleet. Parameters: session: Boto3 session. region: Region name. 
Returns: - Iterator for listed Domain resources. + GetDeviceFleetReportResponse + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` """ + + operation_input_args = { + "DeviceFleetName": self.device_fleet_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - return ResourceIterator( - client=client, - list_method="list_domains", - summaries_key="Domains", - summary_name="DomainDetails", - resource_cls=Domain, - ) - - -class EdgeDeploymentPlan(Base): - """ - Class representing resource EdgeDeploymentPlan - - Attributes: - edge_deployment_plan_arn: The ARN of edge deployment plan. - edge_deployment_plan_name: The name of the edge deployment plan. - model_configs: List of models associated with the edge deployment plan. - device_fleet_name: The device fleet used for this edge deployment plan. - stages: List of stages in the edge deployment plan. - edge_deployment_success: The number of edge devices with the successful deployment. - edge_deployment_pending: The number of edge devices yet to pick up deployment, or in progress. - edge_deployment_failed: The number of edge devices that failed the deployment. - next_token: Token to use when calling the next set of stages in the edge deployment plan. - creation_time: The time when the edge deployment plan was created. - last_modified_time: The time when the edge deployment plan was last updated. 
- - """ - - edge_deployment_plan_name: str - edge_deployment_plan_arn: Optional[str] = Unassigned() - model_configs: Optional[List[EdgeDeploymentModelConfig]] = Unassigned() - device_fleet_name: Optional[str] = Unassigned() - edge_deployment_success: Optional[int] = Unassigned() - edge_deployment_pending: Optional[int] = Unassigned() - edge_deployment_failed: Optional[int] = Unassigned() - stages: Optional[List[DeploymentStageStatusSummary]] = Unassigned() - next_token: Optional[str] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - - def get_name(self) -> str: - attributes = vars(self) - resource_name = "edge_deployment_plan_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] - - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) + logger.debug(f"Calling get_device_fleet_report API") + response = client.get_device_fleet_report(**operation_input_args) + logger.debug(f"Response: {response}") - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object edge_deployment_plan") - return None + transformed_response = transform(response, "GetDeviceFleetReportResponse") + return GetDeviceFleetReportResponse(**transformed_response) - @classmethod @Base.add_validate_call - def create( - cls, - edge_deployment_plan_name: str, - model_configs: List[EdgeDeploymentModelConfig], - device_fleet_name: Union[str, object], - stages: Optional[List[DeploymentStage]] = Unassigned(), + def register_devices( + self, + devices: List[Device], tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["EdgeDeploymentPlan"]: + ) -> None: """ - Create a EdgeDeploymentPlan resource + Register devices. 
Parameters: - edge_deployment_plan_name: The name of the edge deployment plan. - model_configs: List of models associated with the edge deployment plan. - device_fleet_name: The device fleet used for this edge deployment plan. - stages: List of stages of the edge deployment plan. The number of stages is limited to 10 per deployment. - tags: List of tags with which to tag the edge deployment plan. + devices: A list of devices to register with SageMaker Edge Manager. + tags: The tags associated with devices. session: Boto3 session. region: Region name. - Returns: - The EdgeDeploymentPlan resource. - Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. The error message and error code can be parsed from the exception as follows: @@ -6858,64 +7176,40 @@ def create( error_code = e.response['Error']['Code'] ``` ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating edge_deployment_plan resource.") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - operation_input_args = { - "EdgeDeploymentPlanName": edge_deployment_plan_name, - "ModelConfigs": model_configs, - "DeviceFleetName": device_fleet_name, - "Stages": stages, + "DeviceFleetName": self.device_fleet_name, + "Devices": devices, "Tags": tags, } - - operation_input_args = Base.populate_chained_attributes( - resource_name="EdgeDeploymentPlan", operation_input_args=operation_input_args - ) - - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized 
input request: {operation_input_args}") - # create the resource - response = client.create_edge_deployment_plan(**operation_input_args) - logger.debug(f"Response: {response}") - - return cls.get( - edge_deployment_plan_name=edge_deployment_plan_name, session=session, region=region + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" ) - @classmethod + logger.debug(f"Calling register_devices API") + response = client.register_devices(**operation_input_args) + logger.debug(f"Response: {response}") + @Base.add_validate_call - def get( - cls, - edge_deployment_plan_name: str, - next_token: Optional[str] = Unassigned(), - max_results: Optional[int] = Unassigned(), + def update_devices( + self, + devices: List[Device], session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["EdgeDeploymentPlan"]: + ) -> None: """ - Get a EdgeDeploymentPlan resource + Updates one or more devices in a fleet. Parameters: - edge_deployment_plan_name: The name of the deployment plan to describe. - next_token: If the edge deployment plan has enough stages to require tokening, then this is the response from the last list of stages returned. - max_results: The maximum number of results to select (50 by default). + devices: List of devices to register with Edge Manager agent. session: Boto3 session. region: Region name. - Returns: - The EdgeDeploymentPlan resource. - Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. The error message and error code can be parsed from the exception as follows: @@ -6926,13 +7220,11 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "EdgeDeploymentPlanName": edge_deployment_plan_name, - "NextToken": next_token, - "MaxResults": max_results, + "DeviceFleetName": self.device_fleet_name, + "Devices": devices, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -6941,25 +7233,190 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_edge_deployment_plan(**operation_input_args) - logger.debug(response) + logger.debug(f"Calling update_devices API") + response = client.update_devices(**operation_input_args) + logger.debug(f"Response: {response}") - # deserialize the response - transformed_response = transform(response, "DescribeEdgeDeploymentPlanResponse") - edge_deployment_plan = cls(**transformed_response) - return edge_deployment_plan +class Domain(Base): + """ + Class representing resource Domain + + Attributes: + domain_arn: The domain's Amazon Resource Name (ARN). + domain_id: The domain ID. + domain_name: The domain name. + home_efs_file_system_id: The ID of the Amazon Elastic File System managed by this Domain. + single_sign_on_managed_application_instance_id: The IAM Identity Center managed application instance ID. + single_sign_on_application_arn: The ARN of the application managed by SageMaker in IAM Identity Center. This value is only returned for domains created after October 1, 2023. + status: The status. + creation_time: The creation time. + last_modified_time: The last modified time. + failure_reason: The failure reason. + security_group_id_for_domain_boundary: The ID of the security group that authorizes traffic between the RSessionGateway apps and the RStudioServerPro app. + auth_mode: The domain's authentication mode. + default_user_settings: Settings which are applied to UserProfiles in this domain if settings are not explicitly specified in a given UserProfile. + domain_settings: A collection of Domain settings. 
+ app_network_access_type: Specifies the VPC used for non-EFS traffic. The default value is PublicInternetOnly. PublicInternetOnly - Non-EFS traffic is through a VPC managed by Amazon SageMaker, which allows direct internet access VpcOnly - All traffic is through the specified VPC and subnets + home_efs_file_system_kms_key_id: Use KmsKeyId. + subnet_ids: The VPC subnets that the domain uses for communication. + url: The domain's URL. + vpc_id: The ID of the Amazon Virtual Private Cloud (VPC) that the domain uses for communication. + kms_key_id: The Amazon Web Services KMS customer managed key used to encrypt the EFS volume attached to the domain. + app_security_group_management: The entity that creates and manages the required security groups for inter-app communication in VPCOnly mode. Required when CreateDomain.AppNetworkAccessType is VPCOnly and DomainSettings.RStudioServerProDomainSettings.DomainExecutionRoleArn is provided. + tag_propagation: Indicates whether custom tag propagation is supported for the domain. + default_space_settings: The default settings for shared spaces that users create in the domain. 
+ + """ + + domain_id: str + domain_arn: Optional[str] = Unassigned() + domain_name: Optional[str] = Unassigned() + home_efs_file_system_id: Optional[str] = Unassigned() + single_sign_on_managed_application_instance_id: Optional[str] = Unassigned() + single_sign_on_application_arn: Optional[str] = Unassigned() + status: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + failure_reason: Optional[str] = Unassigned() + security_group_id_for_domain_boundary: Optional[str] = Unassigned() + auth_mode: Optional[str] = Unassigned() + default_user_settings: Optional[UserSettings] = Unassigned() + domain_settings: Optional[DomainSettings] = Unassigned() + app_network_access_type: Optional[str] = Unassigned() + home_efs_file_system_kms_key_id: Optional[str] = Unassigned() + subnet_ids: Optional[List[str]] = Unassigned() + url: Optional[str] = Unassigned() + vpc_id: Optional[str] = Unassigned() + kms_key_id: Optional[str] = Unassigned() + app_security_group_management: Optional[str] = Unassigned() + tag_propagation: Optional[str] = Unassigned() + default_space_settings: Optional[DefaultSpaceSettings] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "domain_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object domain") + return None + + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "security_group_id_for_domain_boundary": {"type": "string"}, + "default_user_settings": { + "execution_role": {"type": 
"string"}, + "security_groups": {"type": "array", "items": {"type": "string"}}, + "sharing_settings": { + "s3_output_path": {"type": "string"}, + "s3_kms_key_id": {"type": "string"}, + }, + "canvas_app_settings": { + "time_series_forecasting_settings": { + "amazon_forecast_role_arn": {"type": "string"} + }, + "model_register_settings": { + "cross_account_model_register_role_arn": {"type": "string"} + }, + "workspace_settings": { + "s3_artifact_path": {"type": "string"}, + "s3_kms_key_id": {"type": "string"}, + }, + "generative_ai_settings": {"amazon_bedrock_role_arn": {"type": "string"}}, + "emr_serverless_settings": {"execution_role_arn": {"type": "string"}}, + }, + "jupyter_lab_app_settings": { + "emr_settings": { + "assumable_role_arns": {"type": "array", "items": {"type": "string"}}, + "execution_role_arns": {"type": "array", "items": {"type": "string"}}, + } + }, + }, + "domain_settings": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "r_studio_server_pro_domain_settings": { + "domain_execution_role_arn": {"type": "string"} + }, + "execution_role_identity_config": {"type": "string"}, + }, + "home_efs_file_system_kms_key_id": {"type": "string"}, + "subnet_ids": {"type": "array", "items": {"type": "string"}}, + "kms_key_id": {"type": "string"}, + "app_security_group_management": {"type": "string"}, + "default_space_settings": { + "execution_role": {"type": "string"}, + "security_groups": {"type": "array", "items": {"type": "string"}}, + "jupyter_lab_app_settings": { + "emr_settings": { + "assumable_role_arns": {"type": "array", "items": {"type": "string"}}, + "execution_role_arns": {"type": "array", "items": {"type": "string"}}, + } + }, + }, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "Domain", **kwargs + ), + ) + + return wrapper + + @classmethod + @populate_inputs_decorator @Base.add_validate_call - def refresh( - self, - max_results: Optional[int] = 
Unassigned(), - ) -> Optional["EdgeDeploymentPlan"]: + def create( + cls, + domain_name: str, + auth_mode: str, + default_user_settings: UserSettings, + subnet_ids: List[str], + vpc_id: str, + domain_settings: Optional[DomainSettings] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), + app_network_access_type: Optional[str] = Unassigned(), + home_efs_file_system_kms_key_id: Optional[str] = Unassigned(), + kms_key_id: Optional[str] = Unassigned(), + app_security_group_management: Optional[str] = Unassigned(), + tag_propagation: Optional[str] = Unassigned(), + default_space_settings: Optional[DefaultSpaceSettings] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["Domain"]: """ - Refresh a EdgeDeploymentPlan resource + Create a Domain resource + + Parameters: + domain_name: A name for the domain. + auth_mode: The mode of authentication that members use to access the domain. + default_user_settings: The default settings to use to create a user profile when UserSettings isn't specified in the call to the CreateUserProfile API. SecurityGroups is aggregated when specified in both calls. For all other settings in UserSettings, the values specified in CreateUserProfile take precedence over those specified in CreateDomain. + subnet_ids: The VPC subnets that the domain uses for communication. + vpc_id: The ID of the Amazon Virtual Private Cloud (VPC) that the domain uses for communication. + domain_settings: A collection of Domain settings. + tags: Tags to associated with the Domain. Each tag consists of a key and an optional value. Tag keys must be unique per resource. Tags are searchable using the Search API. Tags that you specify for the Domain are also added to all Apps that the Domain launches. + app_network_access_type: Specifies the VPC used for non-EFS traffic. The default value is PublicInternetOnly. 
PublicInternetOnly - Non-EFS traffic is through a VPC managed by Amazon SageMaker, which allows direct internet access VpcOnly - All traffic is through the specified VPC and subnets + home_efs_file_system_kms_key_id: Use KmsKeyId. + kms_key_id: SageMaker uses Amazon Web Services KMS to encrypt EFS and EBS volumes attached to the domain with an Amazon Web Services managed key by default. For more control, specify a customer managed key. + app_security_group_management: The entity that creates and manages the required security groups for inter-app communication in VPCOnly mode. Required when CreateDomain.AppNetworkAccessType is VPCOnly and DomainSettings.RStudioServerProDomainSettings.DomainExecutionRoleArn is provided. If setting up the domain for use with RStudio, this value must be set to Service. + tag_propagation: Indicates whether custom tag propagation is supported for the domain. Defaults to DISABLED. + default_space_settings: The default settings for shared spaces that users create in the domain. + session: Boto3 session. + region: Region name. Returns: - The EdgeDeploymentPlan resource. + The Domain resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -6971,92 +7428,67 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
+ ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - operation_input_args = { - "EdgeDeploymentPlanName": self.edge_deployment_plan_name, - "NextToken": self.next_token, - "MaxResults": max_results, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client = Base.get_sagemaker_client() - response = client.describe_edge_deployment_plan(**operation_input_args) - - # deserialize response and update self - transform(response, "DescribeEdgeDeploymentPlanResponse", self) - return self - - @Base.add_validate_call - def delete( - self, - ) -> None: - """ - Delete a EdgeDeploymentPlan resource - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceInUse: Resource being accessed is in use. 
- """ - - client = Base.get_sagemaker_client() + logger.info("Creating domain resource.") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) operation_input_args = { - "EdgeDeploymentPlanName": self.edge_deployment_plan_name, + "DomainName": domain_name, + "AuthMode": auth_mode, + "DefaultUserSettings": default_user_settings, + "DomainSettings": domain_settings, + "SubnetIds": subnet_ids, + "VpcId": vpc_id, + "Tags": tags, + "AppNetworkAccessType": app_network_access_type, + "HomeEfsFileSystemKmsKeyId": home_efs_file_system_kms_key_id, + "KmsKeyId": kms_key_id, + "AppSecurityGroupManagement": app_security_group_management, + "TagPropagation": tag_propagation, + "DefaultSpaceSettings": default_space_settings, } + + operation_input_args = Base.populate_chained_attributes( + resource_name="Domain", operation_input_args=operation_input_args + ) + + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_edge_deployment_plan(**operation_input_args) + # create the resource + response = client.create_domain(**operation_input_args) + logger.debug(f"Response: {response}") - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + return cls.get(domain_id=response["DomainId"], session=session, region=region) @classmethod @Base.add_validate_call - def get_all( + def get( cls, - creation_time_after: Optional[datetime.datetime] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - last_modified_time_before: Optional[datetime.datetime] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - device_fleet_name_contains: Optional[str] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), + 
domain_id: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["EdgeDeploymentPlan"]: + ) -> Optional["Domain"]: """ - Get all EdgeDeploymentPlan resources + Get a Domain resource Parameters: - next_token: The response from the last list when returning a list large enough to need tokening. - max_results: The maximum number of results to select (50 by default). - creation_time_after: Selects edge deployment plans created after this time. - creation_time_before: Selects edge deployment plans created before this time. - last_modified_time_after: Selects edge deployment plans that were last updated after this time. - last_modified_time_before: Selects edge deployment plans that were last updated before this time. - name_contains: Selects edge deployment plans with names containing this name. - device_fleet_name_contains: Selects edge deployment plans with a device fleet name containing this name. - sort_by: The column by which to sort the edge deployment plans. Can be one of NAME, DEVICEFLEETNAME, CREATIONTIME, LASTMODIFIEDTIME. - sort_order: The direction of the sorting (ascending or descending). + domain_id: The domain ID. session: Boto3 session. region: Region name. Returns: - Iterator for listed EdgeDeploymentPlan resources. + The Domain resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -7068,48 +7500,37 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
""" - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - operation_input_args = { - "CreationTimeAfter": creation_time_after, - "CreationTimeBefore": creation_time_before, - "LastModifiedTimeAfter": last_modified_time_after, - "LastModifiedTimeBefore": last_modified_time_before, - "NameContains": name_contains, - "DeviceFleetNameContains": device_fleet_name_contains, - "SortBy": sort_by, - "SortOrder": sort_order, + "DomainId": domain_id, } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - return ResourceIterator( - client=client, - list_method="list_edge_deployment_plans", - summaries_key="EdgeDeploymentPlanSummaries", - summary_name="EdgeDeploymentPlanSummary", - resource_cls=EdgeDeploymentPlan, - list_method_kwargs=operation_input_args, + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" ) + response = client.describe_domain(**operation_input_args) + + logger.debug(response) + + # deserialize the response + transformed_response = transform(response, "DescribeDomainResponse") + domain = cls(**transformed_response) + return domain @Base.add_validate_call - def create_stage( + def refresh( self, - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> None: + ) -> Optional["Domain"]: """ - Creates a new stage in an existing edge deployment plan. + Refresh a Domain resource - Parameters: - session: Boto3 session. - region: Region name. + Returns: + The Domain resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -7121,39 +7542,43 @@ def create_stage( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
+ ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "EdgeDeploymentPlanName": self.edge_deployment_plan_name, - "Stages": self.stages, + "DomainId": self.domain_id, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) + client = Base.get_sagemaker_client() + response = client.describe_domain(**operation_input_args) - logger.debug(f"Calling create_edge_deployment_stage API") - response = client.create_edge_deployment_stage(**operation_input_args) - logger.debug(f"Response: {response}") + # deserialize response and update self + transform(response, "DescribeDomainResponse", self) + return self + @populate_inputs_decorator @Base.add_validate_call - def delete_stage( + def update( self, - stage_name: str, - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> None: + default_user_settings: Optional[UserSettings] = Unassigned(), + domain_settings_for_update: Optional[DomainSettingsForUpdate] = Unassigned(), + app_security_group_management: Optional[str] = Unassigned(), + default_space_settings: Optional[DefaultSpaceSettings] = Unassigned(), + subnet_ids: Optional[List[str]] = Unassigned(), + app_network_access_type: Optional[str] = Unassigned(), + tag_propagation: Optional[str] = Unassigned(), + ) -> Optional["Domain"]: """ - Delete a stage in an edge deployment plan if (and only if) the stage is inactive. + Update a Domain resource Parameters: - stage_name: The name of the stage. - session: Boto3 session. - region: Region name. + domain_settings_for_update: A collection of DomainSettings configuration values to update. + + Returns: + The Domain resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -7166,38 +7591,42 @@ def delete_stage( error_code = e.response['Error']['Code'] ``` ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. """ + logger.info("Updating domain resource.") + client = Base.get_sagemaker_client() + operation_input_args = { - "EdgeDeploymentPlanName": self.edge_deployment_plan_name, - "StageName": stage_name, + "DomainId": self.domain_id, + "DefaultUserSettings": default_user_settings, + "DomainSettingsForUpdate": domain_settings_for_update, + "AppSecurityGroupManagement": app_security_group_management, + "DefaultSpaceSettings": default_space_settings, + "SubnetIds": subnet_ids, + "AppNetworkAccessType": app_network_access_type, + "TagPropagation": tag_propagation, } + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - - logger.debug(f"Calling delete_edge_deployment_stage API") - response = client.delete_edge_deployment_stage(**operation_input_args) + # create the resource + response = client.update_domain(**operation_input_args) logger.debug(f"Response: {response}") + self.refresh() + + return self @Base.add_validate_call - def start_stage( + def delete( self, - stage_name: str, - session: Optional[Session] = None, - region: Optional[str] = None, + retention_policy: Optional[RetentionPolicy] = Unassigned(), ) -> None: """ - Starts a stage in an edge deployment plan. - - Parameters: - stage_name: The name of the stage to start. - session: Boto3 session. - region: Region name. 
+ Delete a Domain resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -7209,87 +7638,100 @@ def start_stage( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceInUse: Resource being accessed is in use. + ResourceNotFound: Resource being access is not found. """ + client = Base.get_sagemaker_client() + operation_input_args = { - "EdgeDeploymentPlanName": self.edge_deployment_plan_name, - "StageName": stage_name, + "DomainId": self.domain_id, + "RetentionPolicy": retention_policy, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) + client.delete_domain(**operation_input_args) - logger.debug(f"Calling start_edge_deployment_stage API") - response = client.start_edge_deployment_stage(**operation_input_args) - logger.debug(f"Response: {response}") + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call - def stop_stage( + def wait_for_status( self, - stage_name: str, - session: Optional[Session] = None, - region: Optional[str] = None, + target_status: Literal[ + "Deleting", + "Failed", + "InService", + "Pending", + "Updating", + "Update_Failed", + "Delete_Failed", + ], + poll: int = 5, + timeout: Optional[int] = None, ) -> None: """ - Stops a stage in an edge deployment plan. + Wait for a Domain resource to reach certain status. Parameters: - stage_name: The name of the stage to stop. - session: Boto3 session. - region: Region name. + target_status: The status to wait for. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
- The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. """ + start_time = time.time() - operation_input_args = { - "EdgeDeploymentPlanName": self.edge_deployment_plan_name, - "StageName": stage_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), ) + progress.add_task(f"Waiting for Domain to reach [bold]{target_status} status...") + status = Status("Current status:") - logger.debug(f"Calling stop_edge_deployment_stage API") - response = client.stop_edge_deployment_stage(**operation_input_args) - logger.debug(f"Response: {response}") + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.status + status.update(f"Current status: [bold]{current_status}") + + if target_status == current_status: + logger.info(f"Final Resource Status: [bold]{current_status}") + return + + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="Domain", status=current_status, reason=self.failure_reason + ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="Domain", status=current_status) + time.sleep(poll) 
@Base.add_validate_call - def get_all_stage_devices( + def wait_for_delete( self, - stage_name: str, - exclude_devices_deployed_in_other_stage: Optional[bool] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> ResourceIterator[DeviceDeploymentSummary]: + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: """ - Lists devices allocated to the stage, containing detailed device information and deployment status. + Wait for a Domain resource to be deleted. Parameters: - stage_name: The name of the stage in the deployment. - max_results: The maximum number of requests to select. - exclude_devices_deployed_in_other_stage: Toggle for excluding devices deployed in other stages. - session: Boto3 session. - region: Region name. - - Returns: - Iterator for listed DeviceDeploymentSummary. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -7301,73 +7743,117 @@ def get_all_stage_devices( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + DeleteFailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
""" + start_time = time.time() - operation_input_args = { - "EdgeDeploymentPlanName": self.edge_deployment_plan_name, - "ExcludeDevicesDeployedInOtherStage": exclude_devices_deployed_in_other_stage, - "StageName": stage_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task("Waiting for Domain to be deleted...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ) + ): + while True: + try: + self.refresh() + current_status = self.status + status.update(f"Current status: [bold]{current_status}") + + if ( + "delete_failed" in current_status.lower() + or "deletefailed" in current_status.lower() + ): + raise DeleteFailedStatusError( + resource_type="Domain", reason=self.failure_reason + ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="Domain", status=current_status) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] + + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. It may have been deleted.") + return + raise e + time.sleep(poll) + + @classmethod + @Base.add_validate_call + def get_all( + cls, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["Domain"]: + """ + Get all Domain resources. + + Parameters: + session: Boto3 session. + region: Region name. + Returns: + Iterator for listed Domain resources. 
+ + """ client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) return ResourceIterator( client=client, - list_method="list_stage_devices", - summaries_key="DeviceDeploymentSummaries", - summary_name="DeviceDeploymentSummary", - resource_cls=DeviceDeploymentSummary, - list_method_kwargs=operation_input_args, + list_method="list_domains", + summaries_key="Domains", + summary_name="DomainDetails", + resource_cls=Domain, ) -class EdgePackagingJob(Base): +class EdgeDeploymentPlan(Base): """ - Class representing resource EdgePackagingJob + Class representing resource EdgeDeploymentPlan Attributes: - edge_packaging_job_arn: The Amazon Resource Name (ARN) of the edge packaging job. - edge_packaging_job_name: The name of the edge packaging job. - edge_packaging_job_status: The current status of the packaging job. - compilation_job_name: The name of the SageMaker Neo compilation job that is used to locate model artifacts that are being packaged. - model_name: The name of the model. - model_version: The version of the model. - role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to download and upload the model, and to contact Neo. - output_config: The output configuration for the edge packaging job. - resource_key: The Amazon Web Services KMS key to use when encrypting the EBS volume the job run on. - edge_packaging_job_status_message: Returns a message describing the job status and error messages. - creation_time: The timestamp of when the packaging job was created. - last_modified_time: The timestamp of when the job was last updated. - model_artifact: The Amazon Simple Storage (S3) URI where model artifacts ares stored. - model_signature: The signature document of files in the model artifact. - preset_deployment_output: The output of a SageMaker Edge Manager deployable resource. + edge_deployment_plan_arn: The ARN of edge deployment plan. 
+ edge_deployment_plan_name: The name of the edge deployment plan. + model_configs: List of models associated with the edge deployment plan. + device_fleet_name: The device fleet used for this edge deployment plan. + stages: List of stages in the edge deployment plan. + edge_deployment_success: The number of edge devices with the successful deployment. + edge_deployment_pending: The number of edge devices yet to pick up deployment, or in progress. + edge_deployment_failed: The number of edge devices that failed the deployment. + next_token: Token to use when calling the next set of stages in the edge deployment plan. + creation_time: The time when the edge deployment plan was created. + last_modified_time: The time when the edge deployment plan was last updated. """ - edge_packaging_job_name: str - edge_packaging_job_arn: Optional[str] = Unassigned() - compilation_job_name: Optional[str] = Unassigned() - model_name: Optional[str] = Unassigned() - model_version: Optional[str] = Unassigned() - role_arn: Optional[str] = Unassigned() - output_config: Optional[EdgeOutputConfig] = Unassigned() - resource_key: Optional[str] = Unassigned() - edge_packaging_job_status: Optional[str] = Unassigned() - edge_packaging_job_status_message: Optional[str] = Unassigned() + edge_deployment_plan_name: str + edge_deployment_plan_arn: Optional[str] = Unassigned() + model_configs: Optional[List[EdgeDeploymentModelConfig]] = Unassigned() + device_fleet_name: Optional[str] = Unassigned() + edge_deployment_success: Optional[int] = Unassigned() + edge_deployment_pending: Optional[int] = Unassigned() + edge_deployment_failed: Optional[int] = Unassigned() + stages: Optional[List[DeploymentStageStatusSummary]] = Unassigned() + next_token: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() - model_artifact: Optional[str] = Unassigned() - model_signature: Optional[str] = Unassigned() - 
preset_deployment_output: Optional[EdgePresetDeploymentOutput] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "edge_packaging_job_name" + resource_name = "edge_deployment_plan_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -7378,61 +7864,35 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object edge_packaging_job") + logger.error("Name attribute not found for object edge_deployment_plan") return None - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = { - "role_arn": {"type": "string"}, - "output_config": { - "s3_output_location": {"type": "string"}, - "kms_key_id": {"type": "string"}, - }, - } - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "EdgePackagingJob", **kwargs - ), - ) - - return wrapper - @classmethod - @populate_inputs_decorator @Base.add_validate_call def create( cls, - edge_packaging_job_name: str, - compilation_job_name: Union[str, object], - model_name: Union[str, object], - model_version: str, - role_arn: str, - output_config: EdgeOutputConfig, - resource_key: Optional[str] = Unassigned(), + edge_deployment_plan_name: str, + model_configs: List[EdgeDeploymentModelConfig], + device_fleet_name: Union[str, object], + stages: Optional[List[DeploymentStage]] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["EdgePackagingJob"]: + ) -> Optional["EdgeDeploymentPlan"]: """ - Create a EdgePackagingJob resource + Create a EdgeDeploymentPlan resource Parameters: - edge_packaging_job_name: The name of the edge packaging job. 
- compilation_job_name: The name of the SageMaker Neo compilation job that will be used to locate model artifacts for packaging. - model_name: The name of the model. - model_version: The version of the model. - role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to download and upload the model, and to contact SageMaker Neo. - output_config: Provides information about the output location for the packaged model. - resource_key: The Amazon Web Services KMS key to use when encrypting the EBS volume the edge packaging job runs on. - tags: Creates tags for the packaging job. + edge_deployment_plan_name: The name of the edge deployment plan. + model_configs: List of models associated with the edge deployment plan. + device_fleet_name: The device fleet used for this edge deployment plan. + stages: List of stages of the edge deployment plan. The number of stages is limited to 10 per deployment. + tags: List of tags with which to tag the edge deployment plan. session: Boto3 session. region: Region name. Returns: - The EdgePackagingJob resource. + The EdgeDeploymentPlan resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -7450,24 +7910,21 @@ def create( S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating edge_packaging_job resource.") + logger.info("Creating edge_deployment_plan resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "EdgePackagingJobName": edge_packaging_job_name, - "CompilationJobName": compilation_job_name, - "ModelName": model_name, - "ModelVersion": model_version, - "RoleArn": role_arn, - "OutputConfig": output_config, - "ResourceKey": resource_key, + "EdgeDeploymentPlanName": edge_deployment_plan_name, + "ModelConfigs": model_configs, + "DeviceFleetName": device_fleet_name, + "Stages": stages, "Tags": tags, } operation_input_args = Base.populate_chained_attributes( - resource_name="EdgePackagingJob", operation_input_args=operation_input_args + resource_name="EdgeDeploymentPlan", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -7476,31 +7933,35 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_edge_packaging_job(**operation_input_args) + response = client.create_edge_deployment_plan(**operation_input_args) logger.debug(f"Response: {response}") return cls.get( - edge_packaging_job_name=edge_packaging_job_name, session=session, region=region + edge_deployment_plan_name=edge_deployment_plan_name, session=session, region=region ) @classmethod @Base.add_validate_call def get( cls, - edge_packaging_job_name: str, + edge_deployment_plan_name: str, + next_token: Optional[str] = Unassigned(), + max_results: Optional[int] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["EdgePackagingJob"]: + ) -> Optional["EdgeDeploymentPlan"]: """ - Get a EdgePackagingJob resource + Get a EdgeDeploymentPlan resource Parameters: - edge_packaging_job_name: The name of the 
edge packaging job. + edge_deployment_plan_name: The name of the deployment plan to describe. + next_token: If the edge deployment plan has enough stages to require tokening, then this is the response from the last list of stages returned. + max_results: The maximum number of results to select (50 by default). session: Boto3 session. region: Region name. Returns: - The EdgePackagingJob resource. + The EdgeDeploymentPlan resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -7516,7 +7977,9 @@ def get( """ operation_input_args = { - "EdgePackagingJobName": edge_packaging_job_name, + "EdgeDeploymentPlanName": edge_deployment_plan_name, + "NextToken": next_token, + "MaxResults": max_results, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -7525,24 +7988,25 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_edge_packaging_job(**operation_input_args) + response = client.describe_edge_deployment_plan(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeEdgePackagingJobResponse") - edge_packaging_job = cls(**transformed_response) - return edge_packaging_job + transformed_response = transform(response, "DescribeEdgeDeploymentPlanResponse") + edge_deployment_plan = cls(**transformed_response) + return edge_deployment_plan @Base.add_validate_call def refresh( self, - ) -> Optional["EdgePackagingJob"]: + max_results: Optional[int] = Unassigned(), + ) -> Optional["EdgeDeploymentPlan"]: """ - Refresh a EdgePackagingJob resource + Refresh a EdgeDeploymentPlan resource Returns: - The EdgePackagingJob resource. + The EdgeDeploymentPlan resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -7558,23 +8022,27 @@ def refresh( """ operation_input_args = { - "EdgePackagingJobName": self.edge_packaging_job_name, + "EdgeDeploymentPlanName": self.edge_deployment_plan_name, + "NextToken": self.next_token, + "MaxResults": max_results, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_edge_packaging_job(**operation_input_args) + response = client.describe_edge_deployment_plan(**operation_input_args) # deserialize response and update self - transform(response, "DescribeEdgePackagingJobResponse", self) + transform(response, "DescribeEdgeDeploymentPlanResponse", self) return self @Base.add_validate_call - def stop(self) -> None: + def delete( + self, + ) -> None: """ - Stop a EdgePackagingJob resource + Delete a EdgeDeploymentPlan resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -7586,81 +8054,21 @@ def stop(self) -> None: error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceInUse: Resource being accessed is in use. """ - client = SageMakerClient().client + client = Base.get_sagemaker_client() operation_input_args = { - "EdgePackagingJobName": self.edge_packaging_job_name, + "EdgeDeploymentPlanName": self.edge_deployment_plan_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.stop_edge_packaging_job(**operation_input_args) - - logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") - - @Base.add_validate_call - def wait( - self, - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a EdgePackagingJob resource. - - Parameters: - poll: The number of seconds to wait between each poll. 
- timeout: The maximum number of seconds to wait before timing out. - - Raises: - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - FailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. - - """ - terminal_states = ["COMPLETED", "FAILED", "STOPPED"] - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task("Waiting for EdgePackagingJob...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ), - transient=True, - ): - while True: - self.refresh() - current_status = self.edge_packaging_job_status - status.update(f"Current status: [bold]{current_status}") - - if current_status in terminal_states: - logger.info(f"Final Resource Status: [bold]{current_status}") - - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="EdgePackagingJob", - status=current_status, - reason=self.edge_packaging_job_status_message, - ) - - return + client.delete_edge_deployment_plan(**operation_input_args) - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="EdgePackagingJob", status=current_status - ) - time.sleep(poll) + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @classmethod @Base.add_validate_call @@ -7671,33 +8079,31 @@ def get_all( last_modified_time_after: Optional[datetime.datetime] = Unassigned(), last_modified_time_before: Optional[datetime.datetime] = Unassigned(), name_contains: Optional[str] = Unassigned(), - model_name_contains: Optional[str] = Unassigned(), - status_equals: Optional[str] = Unassigned(), + device_fleet_name_contains: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), 
session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["EdgePackagingJob"]: + ) -> ResourceIterator["EdgeDeploymentPlan"]: """ - Get all EdgePackagingJob resources + Get all EdgeDeploymentPlan resources Parameters: next_token: The response from the last list when returning a list large enough to need tokening. - max_results: Maximum number of results to select. - creation_time_after: Select jobs where the job was created after specified time. - creation_time_before: Select jobs where the job was created before specified time. - last_modified_time_after: Select jobs where the job was updated after specified time. - last_modified_time_before: Select jobs where the job was updated before specified time. - name_contains: Filter for jobs containing this name in their packaging job name. - model_name_contains: Filter for jobs where the model name contains this string. - status_equals: The job status to filter for. - sort_by: Use to specify what column to sort by. - sort_order: What direction to sort by. + max_results: The maximum number of results to select (50 by default). + creation_time_after: Selects edge deployment plans created after this time. + creation_time_before: Selects edge deployment plans created before this time. + last_modified_time_after: Selects edge deployment plans that were last updated after this time. + last_modified_time_before: Selects edge deployment plans that were last updated before this time. + name_contains: Selects edge deployment plans with names containing this name. + device_fleet_name_contains: Selects edge deployment plans with a device fleet name containing this name. + sort_by: The column by which to sort the edge deployment plans. Can be one of NAME, DEVICEFLEETNAME, CREATIONTIME, LASTMODIFIEDTIME. + sort_order: The direction of the sorting (ascending or descending). session: Boto3 session. region: Region name. Returns: - Iterator for listed EdgePackagingJob resources. 
+ Iterator for listed EdgeDeploymentPlan resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -7721,8 +8127,7 @@ def get_all( "LastModifiedTimeAfter": last_modified_time_after, "LastModifiedTimeBefore": last_modified_time_before, "NameContains": name_contains, - "ModelNameContains": model_name_contains, - "StatusEquals": status_equals, + "DeviceFleetNameContains": device_fleet_name_contains, "SortBy": sort_by, "SortOrder": sort_order, } @@ -7733,118 +8138,70 @@ def get_all( return ResourceIterator( client=client, - list_method="list_edge_packaging_jobs", - summaries_key="EdgePackagingJobSummaries", - summary_name="EdgePackagingJobSummary", - resource_cls=EdgePackagingJob, + list_method="list_edge_deployment_plans", + summaries_key="EdgeDeploymentPlanSummaries", + summary_name="EdgeDeploymentPlanSummary", + resource_cls=EdgeDeploymentPlan, list_method_kwargs=operation_input_args, ) + @Base.add_validate_call + def create_stage( + self, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> None: + """ + Creates a new stage in an existing edge deployment plan. -class Endpoint(Base): - """ - Class representing resource Endpoint - - Attributes: - endpoint_name: Name of the endpoint. - endpoint_arn: The Amazon Resource Name (ARN) of the endpoint. - endpoint_status: The status of the endpoint. OutOfService: Endpoint is not available to take incoming requests. Creating: CreateEndpoint is executing. Updating: UpdateEndpoint or UpdateEndpointWeightsAndCapacities is executing. SystemUpdating: Endpoint is undergoing maintenance and cannot be updated or deleted or re-scaled until it has completed. This maintenance operation does not change any customer-specified values such as VPC config, KMS encryption, model, instance type, or instance count. RollingBack: Endpoint fails to scale up or down or change its variant weight and is in the process of rolling back to its previous configuration. 
Once the rollback completes, endpoint returns to an InService status. This transitional status only applies to an endpoint that has autoscaling enabled and is undergoing variant weight or capacity changes as part of an UpdateEndpointWeightsAndCapacities call or when the UpdateEndpointWeightsAndCapacities operation is called explicitly. InService: Endpoint is available to process incoming requests. Deleting: DeleteEndpoint is executing. Failed: Endpoint could not be created, updated, or re-scaled. Use the FailureReason value returned by DescribeEndpoint for information about the failure. DeleteEndpoint is the only operation that can be performed on a failed endpoint. UpdateRollbackFailed: Both the rolling deployment and auto-rollback failed. Your endpoint is in service with a mix of the old and new endpoint configurations. For information about how to remedy this issue and restore the endpoint's status to InService, see Rolling Deployments. - creation_time: A timestamp that shows when the endpoint was created. - last_modified_time: A timestamp that shows when the endpoint was last modified. - endpoint_config_name: The name of the endpoint configuration associated with this endpoint. - production_variants: An array of ProductionVariantSummary objects, one for each model hosted behind this endpoint. - data_capture_config: - failure_reason: If the status of the endpoint is Failed, the reason why it failed. - last_deployment_config: The most recent deployment configuration for the endpoint. - async_inference_config: Returns the description of an endpoint configuration created using the CreateEndpointConfig API. - pending_deployment_summary: Returns the summary of an in-progress deployment. This field is only returned when the endpoint is creating or updating with a new endpoint configuration. - explainer_config: The configuration parameters for an explainer. 
- shadow_production_variants: An array of ProductionVariantSummary objects, one for each model that you want to host at this endpoint in shadow mode with production traffic replicated from the model specified on ProductionVariants. + Parameters: + session: Boto3 session. + region: Region name. - """ + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + """ - endpoint_name: str - endpoint_arn: Optional[str] = Unassigned() - endpoint_config_name: Optional[str] = Unassigned() - production_variants: Optional[List[ProductionVariantSummary]] = Unassigned() - data_capture_config: Optional[DataCaptureConfigSummary] = Unassigned() - endpoint_status: Optional[str] = Unassigned() - failure_reason: Optional[str] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - last_deployment_config: Optional[DeploymentConfig] = Unassigned() - async_inference_config: Optional[AsyncInferenceConfig] = Unassigned() - pending_deployment_summary: Optional[PendingDeploymentSummary] = Unassigned() - explainer_config: Optional[ExplainerConfig] = Unassigned() - shadow_production_variants: Optional[List[ProductionVariantSummary]] = Unassigned() + operation_input_args = { + "EdgeDeploymentPlanName": self.edge_deployment_plan_name, + "Stages": self.stages, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") - def get_name(self) -> str: - attributes = vars(self) - 
resource_name = "endpoint_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) + logger.debug(f"Calling create_edge_deployment_stage API") + response = client.create_edge_deployment_stage(**operation_input_args) + logger.debug(f"Response: {response}") - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object endpoint") - return None - - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = { - "data_capture_config": { - "destination_s3_uri": {"type": "string"}, - "kms_key_id": {"type": "string"}, - }, - "async_inference_config": { - "output_config": { - "kms_key_id": {"type": "string"}, - "s3_output_path": {"type": "string"}, - "s3_failure_path": {"type": "string"}, - } - }, - } - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "Endpoint", **kwargs - ), - ) - - return wrapper - - @classmethod - @populate_inputs_decorator @Base.add_validate_call - def create( - cls, - endpoint_name: str, - endpoint_config_name: Union[str, object], - deployment_config: Optional[DeploymentConfig] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), + def delete_stage( + self, + stage_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Endpoint"]: + ) -> None: """ - Create a Endpoint resource + Delete a stage in an edge deployment plan if (and only if) the stage is inactive. 
Parameters: - endpoint_name: The name of the endpoint.The name must be unique within an Amazon Web Services Region in your Amazon Web Services account. The name is case-insensitive in CreateEndpoint, but the case is preserved and must be matched in InvokeEndpoint. - endpoint_config_name: The name of an endpoint configuration. For more information, see CreateEndpointConfig. - deployment_config: - tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. + stage_name: The name of the stage. session: Boto3 session. region: Region name. - Returns: - The Endpoint resource. - Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. The error message and error code can be parsed from the exception as follows: @@ -7855,58 +8212,40 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + ResourceInUse: Resource being accessed is in use. 
""" - logger.info("Creating endpoint resource.") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - operation_input_args = { - "EndpointName": endpoint_name, - "EndpointConfigName": endpoint_config_name, - "DeploymentConfig": deployment_config, - "Tags": tags, + "EdgeDeploymentPlanName": self.edge_deployment_plan_name, + "StageName": stage_name, } - - operation_input_args = Base.populate_chained_attributes( - resource_name="Endpoint", operation_input_args=operation_input_args - ) - - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - # create the resource - response = client.create_endpoint(**operation_input_args) - logger.debug(f"Response: {response}") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) - return cls.get(endpoint_name=endpoint_name, session=session, region=region) + logger.debug(f"Calling delete_edge_deployment_stage API") + response = client.delete_edge_deployment_stage(**operation_input_args) + logger.debug(f"Response: {response}") - @classmethod @Base.add_validate_call - def get( - cls, - endpoint_name: str, + def start_stage( + self, + stage_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Endpoint"]: + ) -> None: """ - Get a Endpoint resource + Starts a stage in an edge deployment plan. Parameters: - endpoint_name: The name of the endpoint. + stage_name: The name of the stage to start. session: Boto3 session. region: Region name. - Returns: - The Endpoint resource. - Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
The error message and error code can be parsed from the exception as follows: @@ -7920,7 +8259,8 @@ def get( """ operation_input_args = { - "EndpointName": endpoint_name, + "EdgeDeploymentPlanName": self.edge_deployment_plan_name, + "StageName": stage_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -7929,24 +8269,25 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_endpoint(**operation_input_args) - - logger.debug(response) - # deserialize the response - transformed_response = transform(response, "DescribeEndpointOutput") - endpoint = cls(**transformed_response) - return endpoint + logger.debug(f"Calling start_edge_deployment_stage API") + response = client.start_edge_deployment_stage(**operation_input_args) + logger.debug(f"Response: {response}") @Base.add_validate_call - def refresh( + def stop_stage( self, - ) -> Optional["Endpoint"]: + stage_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> None: """ - Refresh a Endpoint resource + Stops a stage in an edge deployment plan. - Returns: - The Endpoint resource. + Parameters: + stage_name: The name of the stage to stop. + session: Boto3 session. + region: Region name. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -7961,39 +8302,41 @@ def refresh( """ operation_input_args = { - "EndpointName": self.endpoint_name, + "EdgeDeploymentPlanName": self.edge_deployment_plan_name, + "StageName": stage_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() - response = client.describe_endpoint(**operation_input_args) + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) - # deserialize response and update self - transform(response, "DescribeEndpointOutput", self) - return self + logger.debug(f"Calling stop_edge_deployment_stage API") + response = client.stop_edge_deployment_stage(**operation_input_args) + logger.debug(f"Response: {response}") - @populate_inputs_decorator @Base.add_validate_call - def update( + def get_all_stage_devices( self, - retain_all_variant_properties: Optional[bool] = Unassigned(), - exclude_retained_variant_properties: Optional[List[VariantProperty]] = Unassigned(), - deployment_config: Optional[DeploymentConfig] = Unassigned(), - retain_deployment_config: Optional[bool] = Unassigned(), - ) -> Optional["Endpoint"]: + stage_name: str, + exclude_devices_deployed_in_other_stage: Optional[bool] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator[DeviceDeploymentSummary]: """ - Update a Endpoint resource + Lists devices allocated to the stage, containing detailed device information and deployment status. Parameters: - retain_all_variant_properties: When updating endpoint resources, enables or disables the retention of variant properties, such as the instance count or the variant weight. To retain the variant properties of an endpoint when updating it, set RetainAllVariantProperties to true. 
To use the variant properties specified in a new EndpointConfig call when updating an endpoint, set RetainAllVariantProperties to false. The default is false. - exclude_retained_variant_properties: When you are updating endpoint resources with RetainAllVariantProperties, whose value is set to true, ExcludeRetainedVariantProperties specifies the list of type VariantProperty to override with the values provided by EndpointConfig. If you don't specify a value for ExcludeRetainedVariantProperties, no variant properties are overridden. - deployment_config: The deployment configuration for an endpoint, which contains the desired deployment strategy and rollback configurations. - retain_deployment_config: Specifies whether to reuse the last deployment configuration. The default value is false (the configuration is not reused). + stage_name: The name of the stage in the deployment. + max_results: The maximum number of requests to select. + exclude_devices_deployed_in_other_stage: Toggle for excluding devices deployed in other stages. + session: Boto3 session. + region: Region name. Returns: - The Endpoint resource. + Iterator for listed DeviceDeploymentSummary. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -8005,225 +8348,138 @@ def update( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
""" - logger.info("Updating endpoint resource.") - client = Base.get_sagemaker_client() - operation_input_args = { - "EndpointName": self.endpoint_name, - "EndpointConfigName": self.endpoint_config_name, - "RetainAllVariantProperties": retain_all_variant_properties, - "ExcludeRetainedVariantProperties": exclude_retained_variant_properties, - "DeploymentConfig": deployment_config, - "RetainDeploymentConfig": retain_deployment_config, + "EdgeDeploymentPlanName": self.edge_deployment_plan_name, + "ExcludeDevicesDeployedInOtherStage": exclude_devices_deployed_in_other_stage, + "StageName": stage_name, } - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - # create the resource - response = client.update_endpoint(**operation_input_args) - logger.debug(f"Response: {response}") - self.refresh() + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) - return self + return ResourceIterator( + client=client, + list_method="list_stage_devices", + summaries_key="DeviceDeploymentSummaries", + summary_name="DeviceDeploymentSummary", + resource_cls=DeviceDeploymentSummary, + list_method_kwargs=operation_input_args, + ) - @Base.add_validate_call - def delete( - self, - ) -> None: - """ - Delete a Endpoint resource - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
- The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - """ +class EdgePackagingJob(Base): + """ + Class representing resource EdgePackagingJob - client = Base.get_sagemaker_client() + Attributes: + edge_packaging_job_arn: The Amazon Resource Name (ARN) of the edge packaging job. + edge_packaging_job_name: The name of the edge packaging job. + edge_packaging_job_status: The current status of the packaging job. + compilation_job_name: The name of the SageMaker Neo compilation job that is used to locate model artifacts that are being packaged. + model_name: The name of the model. + model_version: The version of the model. + role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to download and upload the model, and to contact Neo. + output_config: The output configuration for the edge packaging job. + resource_key: The Amazon Web Services KMS key to use when encrypting the EBS volume the job run on. + edge_packaging_job_status_message: Returns a message describing the job status and error messages. + creation_time: The timestamp of when the packaging job was created. + last_modified_time: The timestamp of when the job was last updated. + model_artifact: The Amazon Simple Storage (S3) URI where model artifacts ares stored. + model_signature: The signature document of files in the model artifact. + preset_deployment_output: The output of a SageMaker Edge Manager deployable resource. 
- operation_input_args = { - "EndpointName": self.endpoint_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") + """ - client.delete_endpoint(**operation_input_args) + edge_packaging_job_name: str + edge_packaging_job_arn: Optional[str] = Unassigned() + compilation_job_name: Optional[str] = Unassigned() + model_name: Optional[str] = Unassigned() + model_version: Optional[str] = Unassigned() + role_arn: Optional[str] = Unassigned() + output_config: Optional[EdgeOutputConfig] = Unassigned() + resource_key: Optional[str] = Unassigned() + edge_packaging_job_status: Optional[str] = Unassigned() + edge_packaging_job_status_message: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + model_artifact: Optional[str] = Unassigned() + model_signature: Optional[str] = Unassigned() + preset_deployment_output: Optional[EdgePresetDeploymentOutput] = Unassigned() - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") - - @Base.add_validate_call - def wait_for_status( - self, - target_status: Literal[ - "OutOfService", - "Creating", - "Updating", - "SystemUpdating", - "RollingBack", - "InService", - "Deleting", - "Failed", - "UpdateRollbackFailed", - ], - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a Endpoint resource to reach certain status. - - Parameters: - target_status: The status to wait for. - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. - - Raises: - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - FailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. 
- """ - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task(f"Waiting for Endpoint to reach [bold]{target_status} status...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ), - transient=True, - ): - while True: - self.refresh() - current_status = self.endpoint_status - status.update(f"Current status: [bold]{current_status}") - - if target_status == current_status: - logger.info(f"Final Resource Status: [bold]{current_status}") - return - - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="Endpoint", status=current_status, reason=self.failure_reason - ) - - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="Endpoint", status=current_status) - time.sleep(poll) - - @Base.add_validate_call - def wait_for_delete( - self, - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a Endpoint resource to be deleted. - - Parameters: - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. + def get_name(self) -> str: + attributes = vars(self) + resource_name = "edge_packaging_job_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. 
- DeleteFailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. - """ - start_time = time.time() + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task("Waiting for Endpoint to be deleted...") - status = Status("Current status:") + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object edge_packaging_job") + return None - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "role_arn": {"type": "string"}, + "output_config": { + "s3_output_location": {"type": "string"}, + "kms_key_id": {"type": "string"}, + }, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "EdgePackagingJob", **kwargs + ), ) - ): - while True: - try: - self.refresh() - current_status = self.endpoint_status - status.update(f"Current status: [bold]{current_status}") - - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="Endpoint", status=current_status) - except botocore.exceptions.ClientError as e: - error_code = e.response["Error"]["Code"] - if "ResourceNotFound" in error_code or "ValidationException" in error_code: - logger.info("Resource was not found. 
It may have been deleted.") - return - raise e - time.sleep(poll) + return wrapper @classmethod + @populate_inputs_decorator @Base.add_validate_call - def get_all( + def create( cls, - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), - last_modified_time_before: Optional[datetime.datetime] = Unassigned(), - last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - status_equals: Optional[str] = Unassigned(), + edge_packaging_job_name: str, + compilation_job_name: Union[str, object], + model_name: Union[str, object], + model_version: str, + role_arn: str, + output_config: EdgeOutputConfig, + resource_key: Optional[str] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["Endpoint"]: + ) -> Optional["EdgePackagingJob"]: """ - Get all Endpoint resources + Create a EdgePackagingJob resource Parameters: - sort_by: Sorts the list of results. The default is CreationTime. - sort_order: The sort order for results. The default is Descending. - next_token: If the result of a ListEndpoints request was truncated, the response includes a NextToken. To retrieve the next set of endpoints, use the token in the next request. - max_results: The maximum number of endpoints to return in the response. This value defaults to 10. - name_contains: A string in endpoint names. This filter returns only endpoints whose name contains the specified string. - creation_time_before: A filter that returns only endpoints that were created before the specified time (timestamp). - creation_time_after: A filter that returns only endpoints with a creation time greater than or equal to the specified time (timestamp). 
- last_modified_time_before: A filter that returns only endpoints that were modified before the specified timestamp. - last_modified_time_after: A filter that returns only endpoints that were modified after the specified timestamp. - status_equals: A filter that returns only endpoints with the specified status. + edge_packaging_job_name: The name of the edge packaging job. + compilation_job_name: The name of the SageMaker Neo compilation job that will be used to locate model artifacts for packaging. + model_name: The name of the model. + model_version: The version of the model. + role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to download and upload the model, and to contact SageMaker Neo. + output_config: Provides information about the output location for the packaged model. + resource_key: The Amazon Web Services KMS key to use when encrypting the EBS volume the edge packaging job runs on. + tags: Creates tags for the packaging job. session: Boto3 session. region: Region name. Returns: - Iterator for listed Endpoint resources. + The EdgePackagingJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -8235,51 +8491,64 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
+ ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ + logger.info("Creating edge_packaging_job resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "SortBy": sort_by, - "SortOrder": sort_order, - "NameContains": name_contains, - "CreationTimeBefore": creation_time_before, - "CreationTimeAfter": creation_time_after, - "LastModifiedTimeBefore": last_modified_time_before, - "LastModifiedTimeAfter": last_modified_time_after, - "StatusEquals": status_equals, + "EdgePackagingJobName": edge_packaging_job_name, + "CompilationJobName": compilation_job_name, + "ModelName": model_name, + "ModelVersion": model_version, + "RoleArn": role_arn, + "OutputConfig": output_config, + "ResourceKey": resource_key, + "Tags": tags, } + operation_input_args = Base.populate_chained_attributes( + resource_name="EdgePackagingJob", operation_input_args=operation_input_args + ) + + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - return ResourceIterator( - client=client, - list_method="list_endpoints", - summaries_key="Endpoints", - summary_name="EndpointSummary", - resource_cls=Endpoint, - list_method_kwargs=operation_input_args, + # create the resource + response = client.create_edge_packaging_job(**operation_input_args) + logger.debug(f"Response: {response}") + + return cls.get( + edge_packaging_job_name=edge_packaging_job_name, session=session, region=region ) + @classmethod @Base.add_validate_call - def update_weights_and_capacities( - self, - desired_weights_and_capacities: List[DesiredWeightAndCapacity], + def get( + cls, + 
edge_packaging_job_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> None: + ) -> Optional["EdgePackagingJob"]: """ - Updates variant weight of one or more variants associated with an existing endpoint, or capacity of one variant associated with an existing endpoint. + Get a EdgePackagingJob resource Parameters: - desired_weights_and_capacities: An object that provides new capacity and weight values for a variant. + edge_packaging_job_name: The name of the edge packaging job. session: Boto3 session. region: Region name. + Returns: + The EdgePackagingJob resource. + Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. The error message and error code can be parsed from the exception as follows: @@ -8290,12 +8559,11 @@ def update_weights_and_capacities( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "EndpointName": self.endpoint_name, - "DesiredWeightsAndCapacities": desired_weights_and_capacities, + "EdgePackagingJobName": edge_packaging_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -8304,48 +8572,24 @@ def update_weights_and_capacities( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) + response = client.describe_edge_packaging_job(**operation_input_args) - logger.debug(f"Calling update_endpoint_weights_and_capacities API") - response = client.update_endpoint_weights_and_capacities(**operation_input_args) - logger.debug(f"Response: {response}") + logger.debug(response) + + # deserialize the response + transformed_response = transform(response, "DescribeEdgePackagingJobResponse") + edge_packaging_job = cls(**transformed_response) + return edge_packaging_job @Base.add_validate_call - def invoke( + def refresh( self, - body: Any, - content_type: Optional[str] = Unassigned(), - accept: Optional[str] = Unassigned(), - custom_attributes: Optional[str] = Unassigned(), - target_model: Optional[str] = Unassigned(), - target_variant: Optional[str] = Unassigned(), - target_container_hostname: Optional[str] = Unassigned(), - inference_id: Optional[str] = Unassigned(), - enable_explanations: Optional[str] = Unassigned(), - inference_component_name: Optional[str] = Unassigned(), - session_id: Optional[str] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional[InvokeEndpointOutput]: + ) -> Optional["EdgePackagingJob"]: """ - After you deploy a model into production using Amazon SageMaker hosting services, your client applications use this API to get inferences from the model hosted at the specified endpoint. - - Parameters: - body: Provides input data, in the format specified in the ContentType request header. Amazon SageMaker passes all of the data in the body to the model. 
For information about the format of the request body, see Common Data Formats-Inference. - content_type: The MIME type of the input data in the request body. - accept: The desired MIME type of the inference response from the model container. - custom_attributes: Provides additional information about a request for an inference submitted to a model hosted at an Amazon SageMaker endpoint. The information is an opaque value that is forwarded verbatim. You could use this value, for example, to provide an ID that you can use to track a request or to provide other metadata that a service endpoint was programmed to process. The value must consist of no more than 1024 visible US-ASCII characters as specified in Section 3.3.6. Field Value Components of the Hypertext Transfer Protocol (HTTP/1.1). The code in your model is responsible for setting or updating any custom attributes in the response. If your code does not set this value in the response, an empty value is returned. For example, if a custom attribute represents the trace ID, your model can prepend the custom attribute with Trace ID: in your post-processing function. This feature is currently supported in the Amazon Web Services SDKs but not in the Amazon SageMaker Python SDK. - target_model: The model to request for inference when invoking a multi-model endpoint. - target_variant: Specify the production variant to send the inference request to when invoking an endpoint that is running two or more variants. Note that this parameter overrides the default behavior for the endpoint, which is to distribute the invocation traffic based on the variant weights. For information about how to use variant targeting to perform a/b testing, see Test models in production - target_container_hostname: If the endpoint hosts multiple containers and is configured to use direct invocation, this parameter specifies the host name of the container to invoke. 
- inference_id: If you provide a value, it is added to the captured data when you enable data capture on the endpoint. For information about data capture, see Capture Data. - enable_explanations: An optional JMESPath expression used to override the EnableExplanations parameter of the ClarifyExplainerConfig API. See the EnableExplanations section in the developer guide for more information. - inference_component_name: If the endpoint hosts one or more inference components, this parameter specifies the name of inference component to invoke. - session_id: Creates a stateful session or identifies an existing one. You can do one of the following: Create a stateful session by specifying the value NEW_SESSION. Send your request to an existing stateful session by specifying the ID of that session. With a stateful session, you can send multiple requests to a stateful model. When you create a session with a stateful model, the model must create the session ID and set the expiration time. The model must also provide that information in the response to your request. You can get the ID and timestamp from the NewSessionId response parameter. For any subsequent request where you specify that session ID, SageMaker routes the request to the same instance that supports the session. - session: Boto3 session. - region: Region name. + Refresh a EdgePackagingJob resource Returns: - InvokeEndpointOutput + The EdgePackagingJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -8357,72 +8601,27 @@ def invoke( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - InternalDependencyException: Your request caused an exception with an internal dependency. Contact customer support. - InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. 
- ModelError: Model (owned by the customer in the container) returned 4xx or 5xx error code. - ModelNotReadyException: Either a serverless endpoint variant's resources are still being provisioned, or a multi-model endpoint is still downloading or loading the target model. Wait and try your request again. - ServiceUnavailable: The service is currently unavailable. - ValidationError: There was an error validating your request. + ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "EndpointName": self.endpoint_name, - "Body": body, - "ContentType": content_type, - "Accept": accept, - "CustomAttributes": custom_attributes, - "TargetModel": target_model, - "TargetVariant": target_variant, - "TargetContainerHostname": target_container_hostname, - "InferenceId": inference_id, - "EnableExplanations": enable_explanations, - "InferenceComponentName": inference_component_name, - "SessionId": session_id, + "EdgePackagingJobName": self.edge_packaging_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker-runtime" - ) - - logger.debug(f"Calling invoke_endpoint API") - response = client.invoke_endpoint(**operation_input_args) - logger.debug(f"Response: {response}") + client = Base.get_sagemaker_client() + response = client.describe_edge_packaging_job(**operation_input_args) - transformed_response = transform(response, "InvokeEndpointOutput") - return InvokeEndpointOutput(**transformed_response) + # deserialize response and update self + transform(response, "DescribeEdgePackagingJobResponse", self) + return self @Base.add_validate_call - def invoke_async( - self, - input_location: str, - content_type: Optional[str] = Unassigned(), - accept: Optional[str] = Unassigned(), - custom_attributes: Optional[str] = Unassigned(), - inference_id: 
Optional[str] = Unassigned(), - request_ttl_seconds: Optional[int] = Unassigned(), - invocation_timeout_seconds: Optional[int] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional[InvokeEndpointAsyncOutput]: + def stop(self) -> None: """ - After you deploy a model into production using Amazon SageMaker hosting services, your client applications use this API to get inferences from the model hosted at the specified endpoint in an asynchronous manner. - - Parameters: - input_location: The Amazon S3 URI where the inference request payload is stored. - content_type: The MIME type of the input data in the request body. - accept: The desired MIME type of the inference response from the model container. - custom_attributes: Provides additional information about a request for an inference submitted to a model hosted at an Amazon SageMaker endpoint. The information is an opaque value that is forwarded verbatim. You could use this value, for example, to provide an ID that you can use to track a request or to provide other metadata that a service endpoint was programmed to process. The value must consist of no more than 1024 visible US-ASCII characters as specified in Section 3.3.6. Field Value Components of the Hypertext Transfer Protocol (HTTP/1.1). The code in your model is responsible for setting or updating any custom attributes in the response. If your code does not set this value in the response, an empty value is returned. For example, if a custom attribute represents the trace ID, your model can prepend the custom attribute with Trace ID: in your post-processing function. This feature is currently supported in the Amazon Web Services SDKs but not in the Amazon SageMaker Python SDK. - inference_id: The identifier for the inference request. Amazon SageMaker will generate an identifier for you if none is specified. - request_ttl_seconds: Maximum age in seconds a request can be in the queue before it is marked as expired. 
The default is 6 hours, or 21,600 seconds. - invocation_timeout_seconds: Maximum amount of time in seconds a request can be processed before it is marked as expired. The default is 15 minutes, or 900 seconds. - session: Boto3 session. - region: Region name. - - Returns: - InvokeEndpointAsyncOutput + Stop a EdgePackagingJob resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -8434,69 +8633,118 @@ def invoke_async( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. - ServiceUnavailable: The service is currently unavailable. - ValidationError: There was an error validating your request. """ + client = SageMakerClient().client + operation_input_args = { - "EndpointName": self.endpoint_name, - "ContentType": content_type, - "Accept": accept, - "CustomAttributes": custom_attributes, - "InferenceId": inference_id, - "InputLocation": input_location, - "RequestTTLSeconds": request_ttl_seconds, - "InvocationTimeoutSeconds": invocation_timeout_seconds, + "EdgePackagingJobName": self.edge_packaging_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker-runtime" + client.stop_edge_packaging_job(**operation_input_args) + + logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") + + @Base.add_validate_call + def wait( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a EdgePackagingJob resource. + + Parameters: + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. 
+ + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. + + """ + terminal_states = ["COMPLETED", "FAILED", "STOPPED"] + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), ) + progress.add_task("Waiting for EdgePackagingJob...") + status = Status("Current status:") - logger.debug(f"Calling invoke_endpoint_async API") - response = client.invoke_endpoint_async(**operation_input_args) - logger.debug(f"Response: {response}") + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.edge_packaging_job_status + status.update(f"Current status: [bold]{current_status}") - transformed_response = transform(response, "InvokeEndpointAsyncOutput") - return InvokeEndpointAsyncOutput(**transformed_response) + if current_status in terminal_states: + logger.info(f"Final Resource Status: [bold]{current_status}") + + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="EdgePackagingJob", + status=current_status, + reason=self.edge_packaging_job_status_message, + ) + + return + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="EdgePackagingJob", status=current_status + ) + time.sleep(poll) + + @classmethod @Base.add_validate_call - def invoke_with_response_stream( - self, - body: Any, - content_type: Optional[str] = Unassigned(), - accept: Optional[str] = Unassigned(), - custom_attributes: Optional[str] = Unassigned(), - target_variant: Optional[str] = Unassigned(), - target_container_hostname: Optional[str] = Unassigned(), - inference_id: Optional[str] = Unassigned(), - inference_component_name: 
Optional[str] = Unassigned(), - session_id: Optional[str] = Unassigned(), + def get_all( + cls, + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + model_name_contains: Optional[str] = Unassigned(), + status_equals: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional[InvokeEndpointWithResponseStreamOutput]: + ) -> ResourceIterator["EdgePackagingJob"]: """ - Invokes a model at the specified endpoint to return the inference response as a stream. + Get all EdgePackagingJob resources Parameters: - body: Provides input data, in the format specified in the ContentType request header. Amazon SageMaker passes all of the data in the body to the model. For information about the format of the request body, see Common Data Formats-Inference. - content_type: The MIME type of the input data in the request body. - accept: The desired MIME type of the inference response from the model container. - custom_attributes: Provides additional information about a request for an inference submitted to a model hosted at an Amazon SageMaker endpoint. The information is an opaque value that is forwarded verbatim. You could use this value, for example, to provide an ID that you can use to track a request or to provide other metadata that a service endpoint was programmed to process. The value must consist of no more than 1024 visible US-ASCII characters as specified in Section 3.3.6. Field Value Components of the Hypertext Transfer Protocol (HTTP/1.1). The code in your model is responsible for setting or updating any custom attributes in the response. 
If your code does not set this value in the response, an empty value is returned. For example, if a custom attribute represents the trace ID, your model can prepend the custom attribute with Trace ID: in your post-processing function. This feature is currently supported in the Amazon Web Services SDKs but not in the Amazon SageMaker Python SDK. - target_variant: Specify the production variant to send the inference request to when invoking an endpoint that is running two or more variants. Note that this parameter overrides the default behavior for the endpoint, which is to distribute the invocation traffic based on the variant weights. For information about how to use variant targeting to perform a/b testing, see Test models in production - target_container_hostname: If the endpoint hosts multiple containers and is configured to use direct invocation, this parameter specifies the host name of the container to invoke. - inference_id: An identifier that you assign to your request. - inference_component_name: If the endpoint hosts one or more inference components, this parameter specifies the name of inference component to invoke for a streaming response. - session_id: The ID of a stateful session to handle your request. You can't create a stateful session by using the InvokeEndpointWithResponseStream action. Instead, you can create one by using the InvokeEndpoint action. In your request, you specify NEW_SESSION for the SessionId request parameter. The response to that request provides the session ID for the NewSessionId response parameter. + next_token: The response from the last list when returning a list large enough to need tokening. + max_results: Maximum number of results to select. + creation_time_after: Select jobs where the job was created after specified time. + creation_time_before: Select jobs where the job was created before specified time. + last_modified_time_after: Select jobs where the job was updated after specified time. 
+ last_modified_time_before: Select jobs where the job was updated before specified time. + name_contains: Filter for jobs containing this name in their packaging job name. + model_name_contains: Filter for jobs where the model name contains this string. + status_equals: The job status to filter for. + sort_by: Use to specify what column to sort by. + sort_order: What direction to sort by. session: Boto3 session. region: Region name. Returns: - InvokeEndpointWithResponseStreamOutput + Iterator for listed EdgePackagingJob resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -8508,78 +8756,78 @@ def invoke_with_response_stream( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. - InternalStreamFailure: The stream processing failed because of an unknown error, exception or failure. Try your request again. - ModelError: Model (owned by the customer in the container) returned 4xx or 5xx error code. - ModelStreamError: An error occurred while streaming the response body. This error can have the following error codes: ModelInvocationTimeExceeded The model failed to finish sending the response within the timeout period allowed by Amazon SageMaker. StreamBroken The Transmission Control Protocol (TCP) connection between the client and the model was reset or closed. - ServiceUnavailable: The service is currently unavailable. - ValidationError: There was an error validating your request. 
""" + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + operation_input_args = { - "EndpointName": self.endpoint_name, - "Body": body, - "ContentType": content_type, - "Accept": accept, - "CustomAttributes": custom_attributes, - "TargetVariant": target_variant, - "TargetContainerHostname": target_container_hostname, - "InferenceId": inference_id, - "InferenceComponentName": inference_component_name, - "SessionId": session_id, + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "LastModifiedTimeAfter": last_modified_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + "NameContains": name_contains, + "ModelNameContains": model_name_contains, + "StatusEquals": status_equals, + "SortBy": sort_by, + "SortOrder": sort_order, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker-runtime" + return ResourceIterator( + client=client, + list_method="list_edge_packaging_jobs", + summaries_key="EdgePackagingJobSummaries", + summary_name="EdgePackagingJobSummary", + resource_cls=EdgePackagingJob, + list_method_kwargs=operation_input_args, ) - logger.debug(f"Calling invoke_endpoint_with_response_stream API") - response = client.invoke_endpoint_with_response_stream(**operation_input_args) - logger.debug(f"Response: {response}") - - transformed_response = transform(response, "InvokeEndpointWithResponseStreamOutput") - return InvokeEndpointWithResponseStreamOutput(**transformed_response) - -class EndpointConfig(Base): +class Endpoint(Base): """ - Class representing resource EndpointConfig + Class representing resource Endpoint Attributes: - endpoint_config_name: Name of the SageMaker endpoint configuration. 
- endpoint_config_arn: The Amazon Resource Name (ARN) of the endpoint configuration. - production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint. - creation_time: A timestamp that shows when the endpoint configuration was created. + endpoint_name: Name of the endpoint. + endpoint_arn: The Amazon Resource Name (ARN) of the endpoint. + endpoint_status: The status of the endpoint. OutOfService: Endpoint is not available to take incoming requests. Creating: CreateEndpoint is executing. Updating: UpdateEndpoint or UpdateEndpointWeightsAndCapacities is executing. SystemUpdating: Endpoint is undergoing maintenance and cannot be updated or deleted or re-scaled until it has completed. This maintenance operation does not change any customer-specified values such as VPC config, KMS encryption, model, instance type, or instance count. RollingBack: Endpoint fails to scale up or down or change its variant weight and is in the process of rolling back to its previous configuration. Once the rollback completes, endpoint returns to an InService status. This transitional status only applies to an endpoint that has autoscaling enabled and is undergoing variant weight or capacity changes as part of an UpdateEndpointWeightsAndCapacities call or when the UpdateEndpointWeightsAndCapacities operation is called explicitly. InService: Endpoint is available to process incoming requests. Deleting: DeleteEndpoint is executing. Failed: Endpoint could not be created, updated, or re-scaled. Use the FailureReason value returned by DescribeEndpoint for information about the failure. DeleteEndpoint is the only operation that can be performed on a failed endpoint. UpdateRollbackFailed: Both the rolling deployment and auto-rollback failed. Your endpoint is in service with a mix of the old and new endpoint configurations. For information about how to remedy this issue and restore the endpoint's status to InService, see Rolling Deployments. 
+ creation_time: A timestamp that shows when the endpoint was created. + last_modified_time: A timestamp that shows when the endpoint was last modified. + endpoint_config_name: The name of the endpoint configuration associated with this endpoint. + production_variants: An array of ProductionVariantSummary objects, one for each model hosted behind this endpoint. data_capture_config: - kms_key_id: Amazon Web Services KMS key ID Amazon SageMaker uses to encrypt data when storing it on the ML storage volume attached to the instance. + failure_reason: If the status of the endpoint is Failed, the reason why it failed. + last_deployment_config: The most recent deployment configuration for the endpoint. async_inference_config: Returns the description of an endpoint configuration created using the CreateEndpointConfig API. + pending_deployment_summary: Returns the summary of an in-progress deployment. This field is only returned when the endpoint is creating or updating with a new endpoint configuration. explainer_config: The configuration parameters for an explainer. - shadow_production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint in shadow mode with production traffic replicated from the model specified on ProductionVariants. - execution_role_arn: The Amazon Resource Name (ARN) of the IAM role that you assigned to the endpoint configuration. - vpc_config: - enable_network_isolation: Indicates whether all model containers deployed to the endpoint are isolated. If they are, no inbound or outbound network calls can be made to or from the model containers. + shadow_production_variants: An array of ProductionVariantSummary objects, one for each model that you want to host at this endpoint in shadow mode with production traffic replicated from the model specified on ProductionVariants. 
""" - endpoint_config_name: str - endpoint_config_arn: Optional[str] = Unassigned() - production_variants: Optional[List[ProductionVariant]] = Unassigned() - data_capture_config: Optional[DataCaptureConfig] = Unassigned() - kms_key_id: Optional[str] = Unassigned() + endpoint_name: str + endpoint_arn: Optional[str] = Unassigned() + endpoint_config_name: Optional[str] = Unassigned() + production_variants: Optional[List[ProductionVariantSummary]] = Unassigned() + data_capture_config: Optional[DataCaptureConfigSummary] = Unassigned() + endpoint_status: Optional[str] = Unassigned() + failure_reason: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + last_deployment_config: Optional[DeploymentConfig] = Unassigned() async_inference_config: Optional[AsyncInferenceConfig] = Unassigned() + pending_deployment_summary: Optional[PendingDeploymentSummary] = Unassigned() explainer_config: Optional[ExplainerConfig] = Unassigned() - shadow_production_variants: Optional[List[ProductionVariant]] = Unassigned() - execution_role_arn: Optional[str] = Unassigned() - vpc_config: Optional[VpcConfig] = Unassigned() - enable_network_isolation: Optional[bool] = Unassigned() + shadow_production_variants: Optional[List[ProductionVariantSummary]] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "endpoint_config_name" + resource_name = "endpoint_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -8590,7 +8838,7 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object endpoint_config") + logger.error("Name attribute not found for object endpoint") return None def populate_inputs_decorator(create_func): @@ -8601,7 +8849,6 @@ def wrapper(*args, **kwargs): "destination_s3_uri": 
{"type": "string"}, "kms_key_id": {"type": "string"}, }, - "kms_key_id": {"type": "string"}, "async_inference_config": { "output_config": { "kms_key_id": {"type": "string"}, @@ -8609,16 +8856,11 @@ def wrapper(*args, **kwargs): "s3_failure_path": {"type": "string"}, } }, - "execution_role_arn": {"type": "string"}, - "vpc_config": { - "security_group_ids": {"type": "array", "items": {"type": "string"}}, - "subnets": {"type": "array", "items": {"type": "string"}}, - }, } return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "EndpointConfig", **kwargs + config_schema_for_resource, "Endpoint", **kwargs ), ) @@ -8629,40 +8871,26 @@ def wrapper(*args, **kwargs): @Base.add_validate_call def create( cls, - endpoint_config_name: str, - production_variants: List[ProductionVariant], - data_capture_config: Optional[DataCaptureConfig] = Unassigned(), + endpoint_name: str, + endpoint_config_name: Union[str, object], + deployment_config: Optional[DeploymentConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - kms_key_id: Optional[str] = Unassigned(), - async_inference_config: Optional[AsyncInferenceConfig] = Unassigned(), - explainer_config: Optional[ExplainerConfig] = Unassigned(), - shadow_production_variants: Optional[List[ProductionVariant]] = Unassigned(), - execution_role_arn: Optional[str] = Unassigned(), - vpc_config: Optional[VpcConfig] = Unassigned(), - enable_network_isolation: Optional[bool] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["EndpointConfig"]: + ) -> Optional["Endpoint"]: """ - Create a EndpointConfig resource + Create a Endpoint resource Parameters: - endpoint_config_name: The name of the endpoint configuration. You specify this name in a CreateEndpoint request. - production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint. 
- data_capture_config: + endpoint_name: The name of the endpoint.The name must be unique within an Amazon Web Services Region in your Amazon Web Services account. The name is case-insensitive in CreateEndpoint, but the case is preserved and must be matched in InvokeEndpoint. + endpoint_config_name: The name of an endpoint configuration. For more information, see CreateEndpointConfig. + deployment_config: tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. - kms_key_id: The Amazon Resource Name (ARN) of a Amazon Web Services Key Management Service key that SageMaker uses to encrypt data on the storage volume attached to the ML compute instance that hosts the endpoint. The KmsKeyId can be any of the following formats: Key ID: 1234abcd-12ab-34cd-56ef-1234567890ab Key ARN: arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab Alias name: alias/ExampleAlias Alias name ARN: arn:aws:kms:us-west-2:111122223333:alias/ExampleAlias The KMS key policy must grant permission to the IAM role that you specify in your CreateEndpoint, UpdateEndpoint requests. For more information, refer to the Amazon Web Services Key Management Service section Using Key Policies in Amazon Web Services KMS Certain Nitro-based instances include local storage, dependent on the instance type. Local storage volumes are encrypted using a hardware module on the instance. You can't request a KmsKeyId when using an instance type with local storage. If any of the models that you specify in the ProductionVariants parameter use nitro-based instances with local storage, do not specify a value for the KmsKeyId parameter. If you specify a value for KmsKeyId when using any nitro-based instances with local storage, the call to CreateEndpointConfig fails. 
For a list of instance types that support local instance storage, see Instance Store Volumes. For more information about local instance storage encryption, see SSD Instance Store Volumes. - async_inference_config: Specifies configuration for how an endpoint performs asynchronous inference. This is a required field in order for your Endpoint to be invoked using InvokeEndpointAsync. - explainer_config: A member of CreateEndpointConfig that enables explainers. - shadow_production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint in shadow mode with production traffic replicated from the model specified on ProductionVariants. If you use this field, you can only specify one variant for ProductionVariants and one variant for ShadowProductionVariants. - execution_role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform actions on your behalf. For more information, see SageMaker Roles. To be able to pass this role to Amazon SageMaker, the caller of this action must have the iam:PassRole permission. - vpc_config: - enable_network_isolation: Sets whether all model containers deployed to the endpoint are isolated. If they are, no inbound or outbound network calls can be made to or from the model containers. session: Boto3 session. region: Region name. Returns: - The EndpointConfig resource. + The Endpoint resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -8680,27 +8908,20 @@ def create( S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating endpoint_config resource.") + logger.info("Creating endpoint resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { + "EndpointName": endpoint_name, "EndpointConfigName": endpoint_config_name, - "ProductionVariants": production_variants, - "DataCaptureConfig": data_capture_config, + "DeploymentConfig": deployment_config, "Tags": tags, - "KmsKeyId": kms_key_id, - "AsyncInferenceConfig": async_inference_config, - "ExplainerConfig": explainer_config, - "ShadowProductionVariants": shadow_production_variants, - "ExecutionRoleArn": execution_role_arn, - "VpcConfig": vpc_config, - "EnableNetworkIsolation": enable_network_isolation, } operation_input_args = Base.populate_chained_attributes( - resource_name="EndpointConfig", operation_input_args=operation_input_args + resource_name="Endpoint", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -8709,29 +8930,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_endpoint_config(**operation_input_args) + response = client.create_endpoint(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(endpoint_config_name=endpoint_config_name, session=session, region=region) + return cls.get(endpoint_name=endpoint_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - endpoint_config_name: str, + endpoint_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["EndpointConfig"]: + ) -> Optional["Endpoint"]: """ - Get a EndpointConfig resource + Get a Endpoint resource Parameters: - endpoint_config_name: The name of the endpoint configuration. + endpoint_name: The name of the endpoint. 
session: Boto3 session. region: Region name. Returns: - The EndpointConfig resource. + The Endpoint resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -8746,7 +8967,7 @@ def get( """ operation_input_args = { - "EndpointConfigName": endpoint_config_name, + "EndpointName": endpoint_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -8755,24 +8976,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_endpoint_config(**operation_input_args) + response = client.describe_endpoint(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeEndpointConfigOutput") - endpoint_config = cls(**transformed_response) - return endpoint_config + transformed_response = transform(response, "DescribeEndpointOutput") + endpoint = cls(**transformed_response) + return endpoint @Base.add_validate_call def refresh( self, - ) -> Optional["EndpointConfig"]: + ) -> Optional["Endpoint"]: """ - Refresh a EndpointConfig resource + Refresh a Endpoint resource Returns: - The EndpointConfig resource. + The Endpoint resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -8787,25 +9008,39 @@ def refresh( """ operation_input_args = { - "EndpointConfigName": self.endpoint_config_name, + "EndpointName": self.endpoint_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_endpoint_config(**operation_input_args) + response = client.describe_endpoint(**operation_input_args) # deserialize response and update self - transform(response, "DescribeEndpointConfigOutput", self) + transform(response, "DescribeEndpointOutput", self) return self + @populate_inputs_decorator @Base.add_validate_call - def delete( + def update( self, - ) -> None: + retain_all_variant_properties: Optional[bool] = Unassigned(), + exclude_retained_variant_properties: Optional[List[VariantProperty]] = Unassigned(), + deployment_config: Optional[DeploymentConfig] = Unassigned(), + retain_deployment_config: Optional[bool] = Unassigned(), + ) -> Optional["Endpoint"]: """ - Delete a EndpointConfig resource + Update a Endpoint resource + + Parameters: + retain_all_variant_properties: When updating endpoint resources, enables or disables the retention of variant properties, such as the instance count or the variant weight. To retain the variant properties of an endpoint when updating it, set RetainAllVariantProperties to true. To use the variant properties specified in a new EndpointConfig call when updating an endpoint, set RetainAllVariantProperties to false. The default is false. + exclude_retained_variant_properties: When you are updating endpoint resources with RetainAllVariantProperties, whose value is set to true, ExcludeRetainedVariantProperties specifies the list of type VariantProperty to override with the values provided by EndpointConfig. If you don't specify a value for ExcludeRetainedVariantProperties, no variant properties are overridden. 
+ deployment_config: The deployment configuration for an endpoint, which contains the desired deployment strategy and rollback configurations. + retain_deployment_config: Specifies whether to reuse the last deployment configuration. The default value is false (the configuration is not reused). + + Returns: + The Endpoint resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -8817,49 +9052,38 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. """ + logger.info("Updating endpoint resource.") client = Base.get_sagemaker_client() operation_input_args = { + "EndpointName": self.endpoint_name, "EndpointConfigName": self.endpoint_config_name, + "RetainAllVariantProperties": retain_all_variant_properties, + "ExcludeRetainedVariantProperties": exclude_retained_variant_properties, + "DeploymentConfig": deployment_config, + "RetainDeploymentConfig": retain_deployment_config, } + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_endpoint_config(**operation_input_args) + # create the resource + response = client.update_endpoint(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + return self - @classmethod @Base.add_validate_call - def get_all( - cls, - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), - session: Optional[Session] = None, - region: 
Optional[str] = None, - ) -> ResourceIterator["EndpointConfig"]: + def delete( + self, + ) -> None: """ - Get all EndpointConfig resources - - Parameters: - sort_by: The field to sort results by. The default is CreationTime. - sort_order: The sort order for results. The default is Descending. - next_token: If the result of the previous ListEndpointConfig request was truncated, the response includes a NextToken. To retrieve the next set of endpoint configurations, use the token in the next request. - max_results: The maximum number of training jobs to return in the response. - name_contains: A string in the endpoint configuration name. This filter returns only endpoint configurations whose name contains the specified string. - creation_time_before: A filter that returns only endpoint configurations created before the specified time (timestamp). - creation_time_after: A filter that returns only endpoint configurations with a creation time greater than or equal to the specified time (timestamp). - session: Boto3 session. - region: Region name. - - Returns: - Iterator for listed EndpointConfig resources. + Delete a Endpoint resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -8873,99 +9097,97 @@ def get_all( ``` """ - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) + client = Base.get_sagemaker_client() operation_input_args = { - "SortBy": sort_by, - "SortOrder": sort_order, - "NameContains": name_contains, - "CreationTimeBefore": creation_time_before, - "CreationTimeAfter": creation_time_after, + "EndpointName": self.endpoint_name, } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - return ResourceIterator( - client=client, - list_method="list_endpoint_configs", - summaries_key="EndpointConfigs", - summary_name="EndpointConfigSummary", - resource_cls=EndpointConfig, - list_method_kwargs=operation_input_args, - ) + client.delete_endpoint(**operation_input_args) + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") -class Experiment(Base): - """ - Class representing resource Experiment + @Base.add_validate_call + def wait_for_status( + self, + target_status: Literal[ + "OutOfService", + "Creating", + "Updating", + "SystemUpdating", + "RollingBack", + "InService", + "Deleting", + "Failed", + "UpdateRollbackFailed", + ], + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a Endpoint resource to reach certain status. - Attributes: - experiment_name: The name of the experiment. - experiment_arn: The Amazon Resource Name (ARN) of the experiment. - display_name: The name of the experiment as displayed. If DisplayName isn't specified, ExperimentName is displayed. - source: The Amazon Resource Name (ARN) of the source and, optionally, the type. - description: The description of the experiment. - creation_time: When the experiment was created. - created_by: Who created the experiment. - last_modified_time: When the experiment was last modified. - last_modified_by: Who last modified the experiment. 
+ Parameters: + target_status: The status to wait for. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. - """ + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. + """ + start_time = time.time() - experiment_name: str - experiment_arn: Optional[str] = Unassigned() - display_name: Optional[str] = Unassigned() - source: Optional[ExperimentSource] = Unassigned() - description: Optional[str] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - created_by: Optional[UserContext] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - last_modified_by: Optional[UserContext] = Unassigned() + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task(f"Waiting for Endpoint to reach [bold]{target_status} status...") + status = Status("Current status:") - def get_name(self) -> str: - attributes = vars(self) - resource_name = "experiment_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.endpoint_status + status.update(f"Current status: [bold]{current_status}") - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) + if target_status == current_status: + logger.info(f"Final Resource Status: [bold]{current_status}") + return - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object 
experiment") - return None + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="Endpoint", status=current_status, reason=self.failure_reason + ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="Endpoint", status=current_status) + time.sleep(poll) - @classmethod @Base.add_validate_call - def create( - cls, - experiment_name: str, - display_name: Optional[str] = Unassigned(), - description: Optional[str] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional["Experiment"]: + def wait_for_delete( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: """ - Create a Experiment resource + Wait for a Endpoint resource to be deleted. Parameters: - experiment_name: The name of the experiment. The name must be unique in your Amazon Web Services account and is not case-sensitive. - display_name: The name of the experiment as displayed. The name doesn't need to be unique. If you don't specify DisplayName, the value in ExperimentName is displayed. - description: The description of the experiment. - tags: A list of tags to associate with the experiment. You can use Search API to search on the tags. - session: Boto3 session. - region: Region name. - - Returns: - The Experiment resource. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -8977,57 +9199,78 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
- ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + DeleteFailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. """ + start_time = time.time() - logger.info("Creating experiment resource.") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - - operation_input_args = { - "ExperimentName": experiment_name, - "DisplayName": display_name, - "Description": description, - "Tags": tags, - } - - operation_input_args = Base.populate_chained_attributes( - resource_name="Experiment", operation_input_args=operation_input_args + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), ) + progress.add_task("Waiting for Endpoint to be deleted...") + status = Status("Current status:") - logger.debug(f"Input request: {operation_input_args}") - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ) + ): + while True: + try: + self.refresh() + current_status = self.endpoint_status + status.update(f"Current status: [bold]{current_status}") - # create the resource - response = client.create_experiment(**operation_input_args) - logger.debug(f"Response: {response}") + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="Endpoint", status=current_status) + except botocore.exceptions.ClientError as e: + error_code = 
e.response["Error"]["Code"] - return cls.get(experiment_name=experiment_name, session=session, region=region) + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. It may have been deleted.") + return + raise e + time.sleep(poll) @classmethod @Base.add_validate_call - def get( + def get_all( cls, - experiment_name: str, + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + status_equals: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Experiment"]: + ) -> ResourceIterator["Endpoint"]: """ - Get a Experiment resource + Get all Endpoint resources Parameters: - experiment_name: The name of the experiment to describe. + sort_by: Sorts the list of results. The default is CreationTime. + sort_order: The sort order for results. The default is Descending. + next_token: If the result of a ListEndpoints request was truncated, the response includes a NextToken. To retrieve the next set of endpoints, use the token in the next request. + max_results: The maximum number of endpoints to return in the response. This value defaults to 10. + name_contains: A string in endpoint names. This filter returns only endpoints whose name contains the specified string. + creation_time_before: A filter that returns only endpoints that were created before the specified time (timestamp). + creation_time_after: A filter that returns only endpoints with a creation time greater than or equal to the specified time (timestamp). 
+ last_modified_time_before: A filter that returns only endpoints that were modified before the specified timestamp. + last_modified_time_after: A filter that returns only endpoints that were modified after the specified timestamp. + status_equals: A filter that returns only endpoints with the specified status. session: Boto3 session. region: Region name. Returns: - The Experiment resource. + Iterator for listed Endpoint resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9039,37 +9282,50 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. """ + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + operation_input_args = { - "ExperimentName": experiment_name, + "SortBy": sort_by, + "SortOrder": sort_order, + "NameContains": name_contains, + "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + "LastModifiedTimeAfter": last_modified_time_after, + "StatusEquals": status_equals, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" + return ResourceIterator( + client=client, + list_method="list_endpoints", + summaries_key="Endpoints", + summary_name="EndpointSummary", + resource_cls=Endpoint, + list_method_kwargs=operation_input_args, ) - response = client.describe_experiment(**operation_input_args) - - logger.debug(response) - - # deserialize the response - transformed_response = transform(response, "DescribeExperimentResponse") - experiment = cls(**transformed_response) - return experiment @Base.add_validate_call - def refresh( + def 
update_weights_and_capacities( self, - ) -> Optional["Experiment"]: + desired_weights_and_capacities: List[DesiredWeightAndCapacity], + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> None: """ - Refresh a Experiment resource + Updates variant weight of one or more variants associated with an existing endpoint, or capacity of one variant associated with an existing endpoint. - Returns: - The Experiment resource. + Parameters: + desired_weights_and_capacities: An object that provides new capacity and weight values for a variant. + session: Boto3 session. + region: Region name. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9081,34 +9337,62 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
""" operation_input_args = { - "ExperimentName": self.experiment_name, + "EndpointName": self.endpoint_name, + "DesiredWeightsAndCapacities": desired_weights_and_capacities, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() - response = client.describe_experiment(**operation_input_args) + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) - # deserialize response and update self - transform(response, "DescribeExperimentResponse", self) - return self + logger.debug(f"Calling update_endpoint_weights_and_capacities API") + response = client.update_endpoint_weights_and_capacities(**operation_input_args) + logger.debug(f"Response: {response}") @Base.add_validate_call - def update( + def invoke( self, - display_name: Optional[str] = Unassigned(), - description: Optional[str] = Unassigned(), - ) -> Optional["Experiment"]: + body: Any, + content_type: Optional[str] = Unassigned(), + accept: Optional[str] = Unassigned(), + custom_attributes: Optional[str] = Unassigned(), + target_model: Optional[str] = Unassigned(), + target_variant: Optional[str] = Unassigned(), + target_container_hostname: Optional[str] = Unassigned(), + inference_id: Optional[str] = Unassigned(), + enable_explanations: Optional[str] = Unassigned(), + inference_component_name: Optional[str] = Unassigned(), + session_id: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional[InvokeEndpointOutput]: """ - Update a Experiment resource + After you deploy a model into production using Amazon SageMaker hosting services, your client applications use this API to get inferences from the model hosted at the specified endpoint. + + Parameters: + body: Provides input data, in the format specified in the ContentType request header. 
Amazon SageMaker passes all of the data in the body to the model. For information about the format of the request body, see Common Data Formats-Inference. + content_type: The MIME type of the input data in the request body. + accept: The desired MIME type of the inference response from the model container. + custom_attributes: Provides additional information about a request for an inference submitted to a model hosted at an Amazon SageMaker endpoint. The information is an opaque value that is forwarded verbatim. You could use this value, for example, to provide an ID that you can use to track a request or to provide other metadata that a service endpoint was programmed to process. The value must consist of no more than 1024 visible US-ASCII characters as specified in Section 3.3.6. Field Value Components of the Hypertext Transfer Protocol (HTTP/1.1). The code in your model is responsible for setting or updating any custom attributes in the response. If your code does not set this value in the response, an empty value is returned. For example, if a custom attribute represents the trace ID, your model can prepend the custom attribute with Trace ID: in your post-processing function. This feature is currently supported in the Amazon Web Services SDKs but not in the Amazon SageMaker Python SDK. + target_model: The model to request for inference when invoking a multi-model endpoint. + target_variant: Specify the production variant to send the inference request to when invoking an endpoint that is running two or more variants. Note that this parameter overrides the default behavior for the endpoint, which is to distribute the invocation traffic based on the variant weights. For information about how to use variant targeting to perform a/b testing, see Test models in production + target_container_hostname: If the endpoint hosts multiple containers and is configured to use direct invocation, this parameter specifies the host name of the container to invoke. 
+ inference_id: If you provide a value, it is added to the captured data when you enable data capture on the endpoint. For information about data capture, see Capture Data. + enable_explanations: An optional JMESPath expression used to override the EnableExplanations parameter of the ClarifyExplainerConfig API. See the EnableExplanations section in the developer guide for more information. + inference_component_name: If the endpoint hosts one or more inference components, this parameter specifies the name of inference component to invoke. + session_id: Creates a stateful session or identifies an existing one. You can do one of the following: Create a stateful session by specifying the value NEW_SESSION. Send your request to an existing stateful session by specifying the ID of that session. With a stateful session, you can send multiple requests to a stateful model. When you create a session with a stateful model, the model must create the session ID and set the expiration time. The model must also provide that information in the response to your request. You can get the ID and timestamp from the NewSessionId response parameter. For any subsequent request where you specify that session ID, SageMaker routes the request to the same instance that supports the session. + session: Boto3 session. + region: Region name. Returns: - The Experiment resource. + InvokeEndpointOutput Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9120,36 +9404,72 @@ def update( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. - ResourceNotFound: Resource being access is not found. + InternalDependencyException: Your request caused an exception with an internal dependency. Contact customer support. + InternalFailure: An internal failure occurred. Try your request again. 
If the problem persists, contact Amazon Web Services customer support. + ModelError: Model (owned by the customer in the container) returned 4xx or 5xx error code. + ModelNotReadyException: Either a serverless endpoint variant's resources are still being provisioned, or a multi-model endpoint is still downloading or loading the target model. Wait and try your request again. + ServiceUnavailable: The service is currently unavailable. + ValidationError: There was an error validating your request. """ - logger.info("Updating experiment resource.") - client = Base.get_sagemaker_client() - operation_input_args = { - "ExperimentName": self.experiment_name, - "DisplayName": display_name, - "Description": description, + "EndpointName": self.endpoint_name, + "Body": body, + "ContentType": content_type, + "Accept": accept, + "CustomAttributes": custom_attributes, + "TargetModel": target_model, + "TargetVariant": target_variant, + "TargetContainerHostname": target_container_hostname, + "InferenceId": inference_id, + "EnableExplanations": enable_explanations, + "InferenceComponentName": inference_component_name, + "SessionId": session_id, } - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - # create the resource - response = client.update_experiment(**operation_input_args) + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker-runtime" + ) + + logger.debug(f"Calling invoke_endpoint API") + response = client.invoke_endpoint(**operation_input_args) logger.debug(f"Response: {response}") - self.refresh() - return self + transformed_response = transform(response, "InvokeEndpointOutput") + return InvokeEndpointOutput(**transformed_response) @Base.add_validate_call - def delete( + def invoke_async( self, - ) -> None: + input_location: str, + content_type: Optional[str] = 
Unassigned(), + accept: Optional[str] = Unassigned(), + custom_attributes: Optional[str] = Unassigned(), + inference_id: Optional[str] = Unassigned(), + request_ttl_seconds: Optional[int] = Unassigned(), + invocation_timeout_seconds: Optional[int] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional[InvokeEndpointAsyncOutput]: """ - Delete a Experiment resource + After you deploy a model into production using Amazon SageMaker hosting services, your client applications use this API to get inferences from the model hosted at the specified endpoint in an asynchronous manner. + + Parameters: + input_location: The Amazon S3 URI where the inference request payload is stored. + content_type: The MIME type of the input data in the request body. + accept: The desired MIME type of the inference response from the model container. + custom_attributes: Provides additional information about a request for an inference submitted to a model hosted at an Amazon SageMaker endpoint. The information is an opaque value that is forwarded verbatim. You could use this value, for example, to provide an ID that you can use to track a request or to provide other metadata that a service endpoint was programmed to process. The value must consist of no more than 1024 visible US-ASCII characters as specified in Section 3.3.6. Field Value Components of the Hypertext Transfer Protocol (HTTP/1.1). The code in your model is responsible for setting or updating any custom attributes in the response. If your code does not set this value in the response, an empty value is returned. For example, if a custom attribute represents the trace ID, your model can prepend the custom attribute with Trace ID: in your post-processing function. This feature is currently supported in the Amazon Web Services SDKs but not in the Amazon SageMaker Python SDK. + inference_id: The identifier for the inference request. 
Amazon SageMaker will generate an identifier for you if none is specified. + request_ttl_seconds: Maximum age in seconds a request can be in the queue before it is marked as expired. The default is 6 hours, or 21,600 seconds. + invocation_timeout_seconds: Maximum amount of time in seconds a request can be processed before it is marked as expired. The default is 15 minutes, or 900 seconds. + session: Boto3 session. + region: Region name. + + Returns: + InvokeEndpointAsyncOutput Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9161,48 +9481,69 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. + ServiceUnavailable: The service is currently unavailable. + ValidationError: There was an error validating your request. 
""" - client = Base.get_sagemaker_client() - operation_input_args = { - "ExperimentName": self.experiment_name, + "EndpointName": self.endpoint_name, + "ContentType": content_type, + "Accept": accept, + "CustomAttributes": custom_attributes, + "InferenceId": inference_id, + "InputLocation": input_location, + "RequestTTLSeconds": request_ttl_seconds, + "InvocationTimeoutSeconds": invocation_timeout_seconds, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_experiment(**operation_input_args) + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker-runtime" + ) - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + logger.debug(f"Calling invoke_endpoint_async API") + response = client.invoke_endpoint_async(**operation_input_args) + logger.debug(f"Response: {response}") + + transformed_response = transform(response, "InvokeEndpointAsyncOutput") + return InvokeEndpointAsyncOutput(**transformed_response) - @classmethod @Base.add_validate_call - def get_all( - cls, - created_after: Optional[datetime.datetime] = Unassigned(), - created_before: Optional[datetime.datetime] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), + def invoke_with_response_stream( + self, + body: Any, + content_type: Optional[str] = Unassigned(), + accept: Optional[str] = Unassigned(), + custom_attributes: Optional[str] = Unassigned(), + target_variant: Optional[str] = Unassigned(), + target_container_hostname: Optional[str] = Unassigned(), + inference_id: Optional[str] = Unassigned(), + inference_component_name: Optional[str] = Unassigned(), + session_id: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["Experiment"]: + ) -> Optional[InvokeEndpointWithResponseStreamOutput]: """ - Get all Experiment 
resources + Invokes a model at the specified endpoint to return the inference response as a stream. Parameters: - created_after: A filter that returns only experiments created after the specified time. - created_before: A filter that returns only experiments created before the specified time. - sort_by: The property used to sort results. The default value is CreationTime. - sort_order: The sort order. The default value is Descending. - next_token: If the previous call to ListExperiments didn't return the full set of experiments, the call returns a token for getting the next set of experiments. - max_results: The maximum number of experiments to return in the response. The default value is 10. + body: Provides input data, in the format specified in the ContentType request header. Amazon SageMaker passes all of the data in the body to the model. For information about the format of the request body, see Common Data Formats-Inference. + content_type: The MIME type of the input data in the request body. + accept: The desired MIME type of the inference response from the model container. + custom_attributes: Provides additional information about a request for an inference submitted to a model hosted at an Amazon SageMaker endpoint. The information is an opaque value that is forwarded verbatim. You could use this value, for example, to provide an ID that you can use to track a request or to provide other metadata that a service endpoint was programmed to process. The value must consist of no more than 1024 visible US-ASCII characters as specified in Section 3.3.6. Field Value Components of the Hypertext Transfer Protocol (HTTP/1.1). The code in your model is responsible for setting or updating any custom attributes in the response. If your code does not set this value in the response, an empty value is returned. For example, if a custom attribute represents the trace ID, your model can prepend the custom attribute with Trace ID: in your post-processing function. 
This feature is currently supported in the Amazon Web Services SDKs but not in the Amazon SageMaker Python SDK. + target_variant: Specify the production variant to send the inference request to when invoking an endpoint that is running two or more variants. Note that this parameter overrides the default behavior for the endpoint, which is to distribute the invocation traffic based on the variant weights. For information about how to use variant targeting to perform a/b testing, see Test models in production + target_container_hostname: If the endpoint hosts multiple containers and is configured to use direct invocation, this parameter specifies the host name of the container to invoke. + inference_id: An identifier that you assign to your request. + inference_component_name: If the endpoint hosts one or more inference components, this parameter specifies the name of inference component to invoke for a streaming response. + session_id: The ID of a stateful session to handle your request. You can't create a stateful session by using the InvokeEndpointWithResponseStream action. Instead, you can create one by using the InvokeEndpoint action. In your request, you specify NEW_SESSION for the SessionId request parameter. The response to that request provides the session ID for the NewSessionId response parameter. session: Boto3 session. region: Region name. Returns: - Iterator for listed Experiment resources. + InvokeEndpointWithResponseStreamOutput Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9214,81 +9555,78 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. + InternalStreamFailure: The stream processing failed because of an unknown error, exception or failure. Try your request again. 
+ ModelError: Model (owned by the customer in the container) returned 4xx or 5xx error code. + ModelStreamError: An error occurred while streaming the response body. This error can have the following error codes: ModelInvocationTimeExceeded The model failed to finish sending the response within the timeout period allowed by Amazon SageMaker. StreamBroken The Transmission Control Protocol (TCP) connection between the client and the model was reset or closed. + ServiceUnavailable: The service is currently unavailable. + ValidationError: There was an error validating your request. """ - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - operation_input_args = { - "CreatedAfter": created_after, - "CreatedBefore": created_before, - "SortBy": sort_by, - "SortOrder": sort_order, + "EndpointName": self.endpoint_name, + "Body": body, + "ContentType": content_type, + "Accept": accept, + "CustomAttributes": custom_attributes, + "TargetVariant": target_variant, + "TargetContainerHostname": target_container_hostname, + "InferenceId": inference_id, + "InferenceComponentName": inference_component_name, + "SessionId": session_id, } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - return ResourceIterator( - client=client, - list_method="list_experiments", - summaries_key="ExperimentSummaries", - summary_name="ExperimentSummary", - resource_cls=Experiment, - list_method_kwargs=operation_input_args, + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker-runtime" ) + logger.debug(f"Calling invoke_endpoint_with_response_stream API") + response = client.invoke_endpoint_with_response_stream(**operation_input_args) + logger.debug(f"Response: {response}") -class FeatureGroup(Base): - """ - Class representing resource FeatureGroup + transformed_response = transform(response, 
"InvokeEndpointWithResponseStreamOutput") + return InvokeEndpointWithResponseStreamOutput(**transformed_response) - Attributes: - feature_group_arn: The Amazon Resource Name (ARN) of the FeatureGroup. - feature_group_name: he name of the FeatureGroup. - record_identifier_feature_name: The name of the Feature used for RecordIdentifier, whose value uniquely identifies a record stored in the feature store. - event_time_feature_name: The name of the feature that stores the EventTime of a Record in a FeatureGroup. An EventTime is a point in time when a new event occurs that corresponds to the creation or update of a Record in a FeatureGroup. All Records in the FeatureGroup have a corresponding EventTime. - feature_definitions: A list of the Features in the FeatureGroup. Each feature is defined by a FeatureName and FeatureType. - creation_time: A timestamp indicating when SageMaker created the FeatureGroup. - next_token: A token to resume pagination of the list of Features (FeatureDefinitions). - last_modified_time: A timestamp indicating when the feature group was last updated. - online_store_config: The configuration for the OnlineStore. - offline_store_config: The configuration of the offline store. It includes the following configurations: Amazon S3 location of the offline store. Configuration of the Glue data catalog. Table format of the offline store. Option to disable the automatic creation of a Glue table for the offline store. Encryption configuration. - throughput_config: - role_arn: The Amazon Resource Name (ARN) of the IAM execution role used to persist data into the OfflineStore if an OfflineStoreConfig is provided. - feature_group_status: The status of the feature group. - offline_store_status: The status of the OfflineStore. Notifies you if replicating data into the OfflineStore has failed. Returns either: Active or Blocked - last_update_status: A value indicating whether the update made to the feature group was successful. 
- failure_reason: The reason that the FeatureGroup failed to be replicated in the OfflineStore. This is failure can occur because: The FeatureGroup could not be created in the OfflineStore. The FeatureGroup could not be deleted from the OfflineStore. - description: A free form description of the feature group. - online_store_total_size_bytes: The size of the OnlineStore in bytes. +class EndpointConfig(Base): """ + Class representing resource EndpointConfig - feature_group_name: str - feature_group_arn: Optional[str] = Unassigned() - record_identifier_feature_name: Optional[str] = Unassigned() - event_time_feature_name: Optional[str] = Unassigned() - feature_definitions: Optional[List[FeatureDefinition]] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - online_store_config: Optional[OnlineStoreConfig] = Unassigned() - offline_store_config: Optional[OfflineStoreConfig] = Unassigned() - throughput_config: Optional[ThroughputConfigDescription] = Unassigned() - role_arn: Optional[str] = Unassigned() - feature_group_status: Optional[str] = Unassigned() - offline_store_status: Optional[OfflineStoreStatus] = Unassigned() - last_update_status: Optional[LastUpdateStatus] = Unassigned() - failure_reason: Optional[str] = Unassigned() - description: Optional[str] = Unassigned() - next_token: Optional[str] = Unassigned() - online_store_total_size_bytes: Optional[int] = Unassigned() + Attributes: + endpoint_config_name: Name of the SageMaker endpoint configuration. + endpoint_config_arn: The Amazon Resource Name (ARN) of the endpoint configuration. + production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint. + creation_time: A timestamp that shows when the endpoint configuration was created. 
+ data_capture_config: + kms_key_id: Amazon Web Services KMS key ID Amazon SageMaker uses to encrypt data when storing it on the ML storage volume attached to the instance. + async_inference_config: Returns the description of an endpoint configuration created using the CreateEndpointConfig API. + explainer_config: The configuration parameters for an explainer. + shadow_production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint in shadow mode with production traffic replicated from the model specified on ProductionVariants. + execution_role_arn: The Amazon Resource Name (ARN) of the IAM role that you assigned to the endpoint configuration. + vpc_config: + enable_network_isolation: Indicates whether all model containers deployed to the endpoint are isolated. If they are, no inbound or outbound network calls can be made to or from the model containers. + + """ + + endpoint_config_name: str + endpoint_config_arn: Optional[str] = Unassigned() + production_variants: Optional[List[ProductionVariant]] = Unassigned() + data_capture_config: Optional[DataCaptureConfig] = Unassigned() + kms_key_id: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + async_inference_config: Optional[AsyncInferenceConfig] = Unassigned() + explainer_config: Optional[ExplainerConfig] = Unassigned() + shadow_production_variants: Optional[List[ProductionVariant]] = Unassigned() + execution_role_arn: Optional[str] = Unassigned() + vpc_config: Optional[VpcConfig] = Unassigned() + enable_network_isolation: Optional[bool] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "feature_group_name" + resource_name = "endpoint_config_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -9299,27 +9637,35 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return 
value - logger.error("Name attribute not found for object feature_group") + logger.error("Name attribute not found for object endpoint_config") return None def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): config_schema_for_resource = { - "online_store_config": {"security_config": {"kms_key_id": {"type": "string"}}}, - "offline_store_config": { - "s3_storage_config": { - "s3_uri": {"type": "string"}, + "data_capture_config": { + "destination_s3_uri": {"type": "string"}, + "kms_key_id": {"type": "string"}, + }, + "kms_key_id": {"type": "string"}, + "async_inference_config": { + "output_config": { "kms_key_id": {"type": "string"}, - "resolved_output_s3_uri": {"type": "string"}, + "s3_output_path": {"type": "string"}, + "s3_failure_path": {"type": "string"}, } }, - "role_arn": {"type": "string"}, + "execution_role_arn": {"type": "string"}, + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + }, } return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "FeatureGroup", **kwargs + config_schema_for_resource, "EndpointConfig", **kwargs ), ) @@ -9330,38 +9676,40 @@ def wrapper(*args, **kwargs): @Base.add_validate_call def create( cls, - feature_group_name: str, - record_identifier_feature_name: str, - event_time_feature_name: str, - feature_definitions: List[FeatureDefinition], - online_store_config: Optional[OnlineStoreConfig] = Unassigned(), - offline_store_config: Optional[OfflineStoreConfig] = Unassigned(), - throughput_config: Optional[ThroughputConfig] = Unassigned(), - role_arn: Optional[str] = Unassigned(), - description: Optional[str] = Unassigned(), + endpoint_config_name: str, + production_variants: List[ProductionVariant], + data_capture_config: Optional[DataCaptureConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), + kms_key_id: 
Optional[str] = Unassigned(), + async_inference_config: Optional[AsyncInferenceConfig] = Unassigned(), + explainer_config: Optional[ExplainerConfig] = Unassigned(), + shadow_production_variants: Optional[List[ProductionVariant]] = Unassigned(), + execution_role_arn: Optional[str] = Unassigned(), + vpc_config: Optional[VpcConfig] = Unassigned(), + enable_network_isolation: Optional[bool] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["FeatureGroup"]: + ) -> Optional["EndpointConfig"]: """ - Create a FeatureGroup resource + Create a EndpointConfig resource Parameters: - feature_group_name: The name of the FeatureGroup. The name must be unique within an Amazon Web Services Region in an Amazon Web Services account. The name: Must start with an alphanumeric character. Can only include alphanumeric characters, underscores, and hyphens. Spaces are not allowed. - record_identifier_feature_name: The name of the Feature whose value uniquely identifies a Record defined in the FeatureStore. Only the latest record per identifier value will be stored in the OnlineStore. RecordIdentifierFeatureName must be one of feature definitions' names. You use the RecordIdentifierFeatureName to access data in a FeatureStore. This name: Must start with an alphanumeric character. Can only contains alphanumeric characters, hyphens, underscores. Spaces are not allowed. - event_time_feature_name: The name of the feature that stores the EventTime of a Record in a FeatureGroup. An EventTime is a point in time when a new event occurs that corresponds to the creation or update of a Record in a FeatureGroup. All Records in the FeatureGroup must have a corresponding EventTime. An EventTime can be a String or Fractional. Fractional: EventTime feature values must be a Unix timestamp in seconds. String: EventTime feature values must be an ISO-8601 string in the format. 
The following formats are supported yyyy-MM-dd'T'HH:mm:ssZ and yyyy-MM-dd'T'HH:mm:ss.SSSZ where yyyy, MM, and dd represent the year, month, and day respectively and HH, mm, ss, and if applicable, SSS represent the hour, month, second and milliseconds respsectively. 'T' and Z are constants. - feature_definitions: A list of Feature names and types. Name and Type is compulsory per Feature. Valid feature FeatureTypes are Integral, Fractional and String. FeatureNames cannot be any of the following: is_deleted, write_time, api_invocation_time You can create up to 2,500 FeatureDefinitions per FeatureGroup. - online_store_config: You can turn the OnlineStore on or off by specifying True for the EnableOnlineStore flag in OnlineStoreConfig. You can also include an Amazon Web Services KMS key ID (KMSKeyId) for at-rest encryption of the OnlineStore. The default value is False. - offline_store_config: Use this to configure an OfflineFeatureStore. This parameter allows you to specify: The Amazon Simple Storage Service (Amazon S3) location of an OfflineStore. A configuration for an Amazon Web Services Glue or Amazon Web Services Hive data catalog. An KMS encryption key to encrypt the Amazon S3 location used for OfflineStore. If KMS encryption key is not specified, by default we encrypt all data at rest using Amazon Web Services KMS key. By defining your bucket-level key for SSE, you can reduce Amazon Web Services KMS requests costs by up to 99 percent. Format for the offline store table. Supported formats are Glue (Default) and Apache Iceberg. To learn more about this parameter, see OfflineStoreConfig. - throughput_config: - role_arn: The Amazon Resource Name (ARN) of the IAM execution role used to persist data into the OfflineStore if an OfflineStoreConfig is provided. - description: A free-form description of a FeatureGroup. - tags: Tags used to identify Features in each FeatureGroup. + endpoint_config_name: The name of the endpoint configuration. 
You specify this name in a CreateEndpoint request. + production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint. + data_capture_config: + tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. + kms_key_id: The Amazon Resource Name (ARN) of a Amazon Web Services Key Management Service key that SageMaker uses to encrypt data on the storage volume attached to the ML compute instance that hosts the endpoint. The KmsKeyId can be any of the following formats: Key ID: 1234abcd-12ab-34cd-56ef-1234567890ab Key ARN: arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab Alias name: alias/ExampleAlias Alias name ARN: arn:aws:kms:us-west-2:111122223333:alias/ExampleAlias The KMS key policy must grant permission to the IAM role that you specify in your CreateEndpoint, UpdateEndpoint requests. For more information, refer to the Amazon Web Services Key Management Service section Using Key Policies in Amazon Web Services KMS Certain Nitro-based instances include local storage, dependent on the instance type. Local storage volumes are encrypted using a hardware module on the instance. You can't request a KmsKeyId when using an instance type with local storage. If any of the models that you specify in the ProductionVariants parameter use nitro-based instances with local storage, do not specify a value for the KmsKeyId parameter. If you specify a value for KmsKeyId when using any nitro-based instances with local storage, the call to CreateEndpointConfig fails. For a list of instance types that support local instance storage, see Instance Store Volumes. For more information about local instance storage encryption, see SSD Instance Store Volumes. 
+ async_inference_config: Specifies configuration for how an endpoint performs asynchronous inference. This is a required field in order for your Endpoint to be invoked using InvokeEndpointAsync. + explainer_config: A member of CreateEndpointConfig that enables explainers. + shadow_production_variants: An array of ProductionVariant objects, one for each model that you want to host at this endpoint in shadow mode with production traffic replicated from the model specified on ProductionVariants. If you use this field, you can only specify one variant for ProductionVariants and one variant for ShadowProductionVariants. + execution_role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform actions on your behalf. For more information, see SageMaker Roles. To be able to pass this role to Amazon SageMaker, the caller of this action must have the iam:PassRole permission. + vpc_config: + enable_network_isolation: Sets whether all model containers deployed to the endpoint are isolated. If they are, no inbound or outbound network calls can be made to or from the model containers. session: Boto3 session. region: Region name. Returns: - The FeatureGroup resource. + The EndpointConfig resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9373,33 +9721,33 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating feature_group resource.") + logger.info("Creating endpoint_config resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "FeatureGroupName": feature_group_name, - "RecordIdentifierFeatureName": record_identifier_feature_name, - "EventTimeFeatureName": event_time_feature_name, - "FeatureDefinitions": feature_definitions, - "OnlineStoreConfig": online_store_config, - "OfflineStoreConfig": offline_store_config, - "ThroughputConfig": throughput_config, - "RoleArn": role_arn, - "Description": description, + "EndpointConfigName": endpoint_config_name, + "ProductionVariants": production_variants, + "DataCaptureConfig": data_capture_config, "Tags": tags, + "KmsKeyId": kms_key_id, + "AsyncInferenceConfig": async_inference_config, + "ExplainerConfig": explainer_config, + "ShadowProductionVariants": shadow_production_variants, + "ExecutionRoleArn": execution_role_arn, + "VpcConfig": vpc_config, + "EnableNetworkIsolation": enable_network_isolation, } operation_input_args = Base.populate_chained_attributes( - resource_name="FeatureGroup", operation_input_args=operation_input_args + resource_name="EndpointConfig", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -9408,31 +9756,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_feature_group(**operation_input_args) + response = client.create_endpoint_config(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(feature_group_name=feature_group_name, session=session, region=region) + return 
cls.get(endpoint_config_name=endpoint_config_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - feature_group_name: str, - next_token: Optional[str] = Unassigned(), + endpoint_config_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["FeatureGroup"]: + ) -> Optional["EndpointConfig"]: """ - Get a FeatureGroup resource + Get a EndpointConfig resource Parameters: - feature_group_name: The name or Amazon Resource Name (ARN) of the FeatureGroup you want described. - next_token: A token to resume pagination of the list of Features (FeatureDefinitions). 2,500 Features are returned by default. + endpoint_config_name: The name of the endpoint configuration. session: Boto3 session. region: Region name. Returns: - The FeatureGroup resource. + The EndpointConfig resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9444,12 +9790,10 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "FeatureGroupName": feature_group_name, - "NextToken": next_token, + "EndpointConfigName": endpoint_config_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -9458,24 +9802,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_feature_group(**operation_input_args) + response = client.describe_endpoint_config(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeFeatureGroupResponse") - feature_group = cls(**transformed_response) - return feature_group + transformed_response = transform(response, "DescribeEndpointConfigOutput") + endpoint_config = cls(**transformed_response) + return endpoint_config @Base.add_validate_call def refresh( self, - ) -> Optional["FeatureGroup"]: + ) -> Optional["EndpointConfig"]: """ - Refresh a FeatureGroup resource + Refresh a EndpointConfig resource Returns: - The FeatureGroup resource. + The EndpointConfig resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9487,40 +9831,28 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "FeatureGroupName": self.feature_group_name, - "NextToken": self.next_token, + "EndpointConfigName": self.endpoint_config_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_feature_group(**operation_input_args) + response = client.describe_endpoint_config(**operation_input_args) # deserialize response and update self - transform(response, "DescribeFeatureGroupResponse", self) + transform(response, "DescribeEndpointConfigOutput", self) return self - @populate_inputs_decorator @Base.add_validate_call - def update( + def delete( self, - feature_additions: Optional[List[FeatureDefinition]] = Unassigned(), - online_store_config: Optional[OnlineStoreConfigUpdate] = Unassigned(), - throughput_config: Optional[ThroughputConfigUpdate] = Unassigned(), - ) -> Optional["FeatureGroup"]: + ) -> None: """ - Update a FeatureGroup resource - - Parameters: - feature_additions: Updates the feature group. Updating a feature group is an asynchronous operation. When you get an HTTP 200 response, you've made a valid request. It takes some time after you've made a valid request for Feature Store to update the feature group. - - Returns: - The FeatureGroup resource. + Delete a EndpointConfig resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9532,37 +9864,49 @@ def update( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ResourceNotFound: Resource being access is not found. 
""" - logger.info("Updating feature_group resource.") client = Base.get_sagemaker_client() operation_input_args = { - "FeatureGroupName": self.feature_group_name, - "FeatureAdditions": feature_additions, - "OnlineStoreConfig": online_store_config, - "ThroughputConfig": throughput_config, + "EndpointConfigName": self.endpoint_config_name, } - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - # create the resource - response = client.update_feature_group(**operation_input_args) - logger.debug(f"Response: {response}") - self.refresh() + client.delete_endpoint_config(**operation_input_args) - return self + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + @classmethod @Base.add_validate_call - def delete( - self, - ) -> None: + def get_all( + cls, + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["EndpointConfig"]: """ - Delete a FeatureGroup resource + Get all EndpointConfig resources + + Parameters: + sort_by: The field to sort results by. The default is CreationTime. + sort_order: The sort order for results. The default is Descending. + next_token: If the result of the previous ListEndpointConfig request was truncated, the response includes a NextToken. To retrieve the next set of endpoint configurations, use the token in the next request. + max_results: The maximum number of training jobs to return in the response. + name_contains: A string in the endpoint configuration name. This filter returns only endpoint configurations whose name contains the specified string. 
+ creation_time_before: A filter that returns only endpoint configurations created before the specified time (timestamp). + creation_time_after: A filter that returns only endpoint configurations with a creation time greater than or equal to the specified time (timestamp). + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed EndpointConfig resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9574,175 +9918,101 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) operation_input_args = { - "FeatureGroupName": self.feature_group_name, + "SortBy": sort_by, + "SortOrder": sort_order, + "NameContains": name_contains, + "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_feature_group(**operation_input_args) - - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") - - @Base.add_validate_call - def wait_for_status( - self, - target_status: Literal["Creating", "Created", "CreateFailed", "Deleting", "DeleteFailed"], - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a FeatureGroup resource to reach certain status. - - Parameters: - target_status: The status to wait for. - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. - - Raises: - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - FailedStatusError: If the resource reaches a failed state. 
- WaiterError: Raised when an error occurs while waiting. - """ - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), + return ResourceIterator( + client=client, + list_method="list_endpoint_configs", + summaries_key="EndpointConfigs", + summary_name="EndpointConfigSummary", + resource_cls=EndpointConfig, + list_method_kwargs=operation_input_args, ) - progress.add_task(f"Waiting for FeatureGroup to reach [bold]{target_status} status...") - status = Status("Current status:") - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ), - transient=True, - ): - while True: - self.refresh() - current_status = self.feature_group_status - status.update(f"Current status: [bold]{current_status}") - - if target_status == current_status: - logger.info(f"Final Resource Status: [bold]{current_status}") - return - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="FeatureGroup", - status=current_status, - reason=self.failure_reason, - ) - - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="FeatureGroup", status=current_status) - time.sleep(poll) - - @Base.add_validate_call - def wait_for_delete( - self, - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a FeatureGroup resource to be deleted. +class Experiment(Base): + """ + Class representing resource Experiment - Parameters: - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. + Attributes: + experiment_name: The name of the experiment. + experiment_arn: The Amazon Resource Name (ARN) of the experiment. + display_name: The name of the experiment as displayed. If DisplayName isn't specified, ExperimentName is displayed. 
+ source: The Amazon Resource Name (ARN) of the source and, optionally, the type. + description: The description of the experiment. + creation_time: When the experiment was created. + created_by: Who created the experiment. + last_modified_time: When the experiment was last modified. + last_modified_by: Who last modified the experiment. - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - DeleteFailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. - """ - start_time = time.time() + """ - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task("Waiting for FeatureGroup to be deleted...") - status = Status("Current status:") + experiment_name: str + experiment_arn: Optional[str] = Unassigned() + display_name: Optional[str] = Unassigned() + source: Optional[ExperimentSource] = Unassigned() + description: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + created_by: Optional[UserContext] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + last_modified_by: Optional[UserContext] = Unassigned() - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ) - ): - while True: - try: - self.refresh() - current_status = self.feature_group_status - status.update(f"Current status: [bold]{current_status}") + def get_name(self) -> str: + attributes = vars(self) + resource_name = "experiment_name" + 
resource_name_split = resource_name.split("_") + attribute_name_candidates = [] - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="FeatureGroup", status=current_status - ) - except botocore.exceptions.ClientError as e: - error_code = e.response["Error"]["Code"] + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) - if "ResourceNotFound" in error_code or "ValidationException" in error_code: - logger.info("Resource was not found. It may have been deleted.") - return - raise e - time.sleep(poll) + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object experiment") + return None @classmethod @Base.add_validate_call - def get_all( + def create( cls, - name_contains: Optional[str] = Unassigned(), - feature_group_status_equals: Optional[str] = Unassigned(), - offline_store_status_equals: Optional[str] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - sort_by: Optional[str] = Unassigned(), + experiment_name: str, + display_name: Optional[str] = Unassigned(), + description: Optional[str] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["FeatureGroup"]: + ) -> Optional["Experiment"]: """ - Get all FeatureGroup resources + Create a Experiment resource Parameters: - name_contains: A string that partially matches one or more FeatureGroups names. Filters FeatureGroups by name. - feature_group_status_equals: A FeatureGroup status. Filters by FeatureGroup status. - offline_store_status_equals: An OfflineStore status. Filters by OfflineStore status. 
- creation_time_after: Use this parameter to search for FeatureGroupss created after a specific date and time. - creation_time_before: Use this parameter to search for FeatureGroupss created before a specific date and time. - sort_order: The order in which feature groups are listed. - sort_by: The value on which the feature group list is sorted. - max_results: The maximum number of results returned by ListFeatureGroups. - next_token: A token to resume pagination of ListFeatureGroups results. + experiment_name: The name of the experiment. The name must be unique in your Amazon Web Services account and is not case-sensitive. + display_name: The name of the experiment as displayed. The name doesn't need to be unique. If you don't specify DisplayName, the value in ExperimentName is displayed. + description: The description of the experiment. + tags: A list of tags to associate with the experiment. You can use Search API to search on the tags. session: Boto3 session. region: Region name. Returns: - Iterator for listed FeatureGroup resources. + The Experiment resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9754,56 +10024,57 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
+ ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ + logger.info("Creating experiment resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "NameContains": name_contains, - "FeatureGroupStatusEquals": feature_group_status_equals, - "OfflineStoreStatusEquals": offline_store_status_equals, - "CreationTimeAfter": creation_time_after, - "CreationTimeBefore": creation_time_before, - "SortOrder": sort_order, - "SortBy": sort_by, + "ExperimentName": experiment_name, + "DisplayName": display_name, + "Description": description, + "Tags": tags, } + operation_input_args = Base.populate_chained_attributes( + resource_name="Experiment", operation_input_args=operation_input_args + ) + + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - return ResourceIterator( - client=client, - list_method="list_feature_groups", - summaries_key="FeatureGroupSummaries", - summary_name="FeatureGroupSummary", - resource_cls=FeatureGroup, - list_method_kwargs=operation_input_args, - ) + # create the resource + response = client.create_experiment(**operation_input_args) + logger.debug(f"Response: {response}") + + return cls.get(experiment_name=experiment_name, session=session, region=region) + @classmethod @Base.add_validate_call - def get_record( - self, - record_identifier_value_as_string: str, - feature_names: Optional[List[str]] = Unassigned(), - expiration_time_response: Optional[str] = Unassigned(), + def get( + cls, + experiment_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional[GetRecordResponse]: 
+ ) -> Optional["Experiment"]: """ - Use for OnlineStore serving from a FeatureStore. + Get a Experiment resource Parameters: - record_identifier_value_as_string: The value that corresponds to RecordIdentifier type and uniquely identifies the record in the FeatureGroup. - feature_names: List of names of Features to be retrieved. If not specified, the latest value for all the Features are returned. - expiration_time_response: Parameter to request ExpiresAt in response. If Enabled, GetRecord will return the value of ExpiresAt, if it is not null. If Disabled and null, GetRecord will return null. + experiment_name: The name of the experiment to describe. session: Boto3 session. region: Region name. Returns: - GetRecordResponse + The Experiment resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9815,52 +10086,37 @@ def get_record( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - AccessForbidden: You do not have permission to perform an action. - InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. ResourceNotFound: Resource being access is not found. - ServiceUnavailable: The service is currently unavailable. - ValidationError: There was an error validating your request. 
""" operation_input_args = { - "FeatureGroupName": self.feature_group_name, - "RecordIdentifierValueAsString": record_identifier_value_as_string, - "FeatureNames": feature_names, - "ExpirationTimeResponse": expiration_time_response, + "ExperimentName": experiment_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker-featurestore-runtime" + session=session, region_name=region, service_name="sagemaker" ) + response = client.describe_experiment(**operation_input_args) - logger.debug(f"Calling get_record API") - response = client.get_record(**operation_input_args) - logger.debug(f"Response: {response}") - - transformed_response = transform(response, "GetRecordResponse") - return GetRecordResponse(**transformed_response) + logger.debug(response) - @Base.add_validate_call - def put_record( + # deserialize the response + transformed_response = transform(response, "DescribeExperimentResponse") + experiment = cls(**transformed_response) + return experiment + + @Base.add_validate_call + def refresh( self, - record: List[FeatureValue], - target_stores: Optional[List[str]] = Unassigned(), - ttl_duration: Optional[TtlDuration] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> None: + ) -> Optional["Experiment"]: """ - The PutRecord API is used to ingest a list of Records into your feature group. + Refresh a Experiment resource - Parameters: - record: List of FeatureValues to be inserted. This will be a full over-write. If you only want to update few of the feature values, do the following: Use GetRecord to retrieve the latest record. Update the record returned from GetRecord. Use PutRecord to update feature values. - target_stores: A list of stores to which you're adding the record. 
By default, Feature Store adds the record to all of the stores that you're using for the FeatureGroup. - ttl_duration: Time to live duration, where the record is hard deleted after the expiration time is reached; ExpiresAt = EventTime + TtlDuration. For information on HardDelete, see the DeleteRecord API in the Amazon SageMaker API Reference guide. - session: Boto3 session. - region: Region name. + Returns: + The Experiment resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9872,50 +10128,34 @@ def put_record( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - AccessForbidden: You do not have permission to perform an action. - InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. - ServiceUnavailable: The service is currently unavailable. - ValidationError: There was an error validating your request. + ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "FeatureGroupName": self.feature_group_name, - "Record": record, - "TargetStores": target_stores, - "TtlDuration": ttl_duration, + "ExperimentName": self.experiment_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker-featurestore-runtime" - ) + client = Base.get_sagemaker_client() + response = client.describe_experiment(**operation_input_args) - logger.debug(f"Calling put_record API") - response = client.put_record(**operation_input_args) - logger.debug(f"Response: {response}") + # deserialize response and update self + transform(response, "DescribeExperimentResponse", self) + return self @Base.add_validate_call - def delete_record( + def update( self, - record_identifier_value_as_string: str, - event_time: str, - target_stores: Optional[List[str]] = Unassigned(), - deletion_mode: Optional[str] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> None: + display_name: Optional[str] = Unassigned(), + description: Optional[str] = Unassigned(), + ) -> Optional["Experiment"]: """ - Deletes a Record from a FeatureGroup in the OnlineStore. + Update a Experiment resource - Parameters: - record_identifier_value_as_string: The value for the RecordIdentifier that uniquely identifies the record, in string format. - event_time: Timestamp indicating when the deletion event occurred. EventTime can be used to query data at a certain point in time. - target_stores: A list of stores from which you're deleting the record. By default, Feature Store deletes the record from all of the stores that you're using for the FeatureGroup. - deletion_mode: The name of the deletion mode for deleting the record. By default, the deletion mode is set to SoftDelete. - session: Boto3 session. - region: Region name. 
+ Returns: + The Experiment resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9927,50 +10167,36 @@ def delete_record( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - AccessForbidden: You do not have permission to perform an action. - InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. - ServiceUnavailable: The service is currently unavailable. - ValidationError: There was an error validating your request. + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceNotFound: Resource being access is not found. """ + logger.info("Updating experiment resource.") + client = Base.get_sagemaker_client() + operation_input_args = { - "FeatureGroupName": self.feature_group_name, - "RecordIdentifierValueAsString": record_identifier_value_as_string, - "EventTime": event_time, - "TargetStores": target_stores, - "DeletionMode": deletion_mode, + "ExperimentName": self.experiment_name, + "DisplayName": display_name, + "Description": description, } + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker-featurestore-runtime" - ) - - logger.debug(f"Calling delete_record API") - response = client.delete_record(**operation_input_args) + # create the resource + response = client.update_experiment(**operation_input_args) logger.debug(f"Response: {response}") + self.refresh() + + return self @Base.add_validate_call - def batch_get_record( + def delete( self, - identifiers: List[BatchGetRecordIdentifier], - expiration_time_response: Optional[str] = Unassigned(), - 
session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional[BatchGetRecordResponse]: + ) -> None: """ - Retrieves a batch of Records from a FeatureGroup. - - Parameters: - identifiers: A list containing the name or Amazon Resource Name (ARN) of the FeatureGroup, the list of names of Features to be retrieved, and the corresponding RecordIdentifier values as strings. - expiration_time_response: Parameter to request ExpiresAt in response. If Enabled, BatchGetRecord will return the value of ExpiresAt, if it is not null. If Disabled and null, BatchGetRecord will return null. - session: Boto3 session. - region: Region name. - - Returns: - BatchGetRecordResponse + Delete a Experiment resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -9982,93 +10208,48 @@ def batch_get_record( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - AccessForbidden: You do not have permission to perform an action. - InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. - ServiceUnavailable: The service is currently unavailable. - ValidationError: There was an error validating your request. + ResourceNotFound: Resource being access is not found. 
""" + client = Base.get_sagemaker_client() + operation_input_args = { - "Identifiers": identifiers, - "ExpirationTimeResponse": expiration_time_response, + "ExperimentName": self.experiment_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker-featurestore-runtime" - ) - - logger.debug(f"Calling batch_get_record API") - response = client.batch_get_record(**operation_input_args) - logger.debug(f"Response: {response}") - - transformed_response = transform(response, "BatchGetRecordResponse") - return BatchGetRecordResponse(**transformed_response) - - -class FeatureMetadata(Base): - """ - Class representing resource FeatureMetadata - - Attributes: - feature_group_arn: The Amazon Resource Number (ARN) of the feature group that contains the feature. - feature_group_name: The name of the feature group that you've specified. - feature_name: The name of the feature that you've specified. - feature_type: The data type of the feature. - creation_time: A timestamp indicating when the feature was created. - last_modified_time: A timestamp indicating when the metadata for the feature group was modified. For example, if you add a parameter describing the feature, the timestamp changes to reflect the last time you - description: The description you added to describe the feature. - parameters: The key-value pairs that you added to describe the feature. 
- - """ - - feature_group_name: str - feature_name: str - feature_group_arn: Optional[str] = Unassigned() - feature_type: Optional[str] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - description: Optional[str] = Unassigned() - parameters: Optional[List[FeatureParameter]] = Unassigned() - - def get_name(self) -> str: - attributes = vars(self) - resource_name = "feature_metadata_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] - - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) + client.delete_experiment(**operation_input_args) - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object feature_metadata") - return None + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @classmethod @Base.add_validate_call - def get( + def get_all( cls, - feature_group_name: str, - feature_name: str, + created_after: Optional[datetime.datetime] = Unassigned(), + created_before: Optional[datetime.datetime] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["FeatureMetadata"]: + ) -> ResourceIterator["Experiment"]: """ - Get a FeatureMetadata resource + Get all Experiment resources Parameters: - feature_group_name: The name or Amazon Resource Name (ARN) of the feature group containing the feature. - feature_name: The name of the feature. + created_after: A filter that returns only experiments created after the specified time. + created_before: A filter that returns only experiments created before the specified time. + sort_by: The property used to sort results. The default value is CreationTime. 
+ sort_order: The sort order. The default value is Descending. + next_token: If the previous call to ListExperiments didn't return the full set of experiments, the call returns a token for getting the next set of experiments. + max_results: The maximum number of experiments to return in the response. The default value is 10. session: Boto3 session. region: Region name. Returns: - The FeatureMetadata resource. + Iterator for listed Experiment resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -10080,152 +10261,81 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. """ - operation_input_args = { - "FeatureGroupName": feature_group_name, - "FeatureName": feature_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_feature_metadata(**operation_input_args) - - logger.debug(response) - - # deserialize the response - transformed_response = transform(response, "DescribeFeatureMetadataResponse") - feature_metadata = cls(**transformed_response) - return feature_metadata - - @Base.add_validate_call - def refresh( - self, - ) -> Optional["FeatureMetadata"]: - """ - Refresh a FeatureMetadata resource - - Returns: - The FeatureMetadata resource. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceNotFound: Resource being access is not found. 
- """ operation_input_args = { - "FeatureGroupName": self.feature_group_name, - "FeatureName": self.feature_name, + "CreatedAfter": created_after, + "CreatedBefore": created_before, + "SortBy": sort_by, + "SortOrder": sort_order, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() - response = client.describe_feature_metadata(**operation_input_args) - - # deserialize response and update self - transform(response, "DescribeFeatureMetadataResponse", self) - return self - - @Base.add_validate_call - def update( - self, - description: Optional[str] = Unassigned(), - parameter_additions: Optional[List[FeatureParameter]] = Unassigned(), - parameter_removals: Optional[List[str]] = Unassigned(), - ) -> Optional["FeatureMetadata"]: - """ - Update a FeatureMetadata resource - - Parameters: - parameter_additions: A list of key-value pairs that you can add to better describe the feature. - parameter_removals: A list of parameter keys that you can specify to remove parameters that describe your feature. - - Returns: - The FeatureMetadata resource. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceNotFound: Resource being access is not found. 
- """ - - logger.info("Updating feature_metadata resource.") - client = Base.get_sagemaker_client() - - operation_input_args = { - "FeatureGroupName": self.feature_group_name, - "FeatureName": self.feature_name, - "Description": description, - "ParameterAdditions": parameter_additions, - "ParameterRemovals": parameter_removals, - } - logger.debug(f"Input request: {operation_input_args}") - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - # create the resource - response = client.update_feature_metadata(**operation_input_args) - logger.debug(f"Response: {response}") - self.refresh() - - return self + return ResourceIterator( + client=client, + list_method="list_experiments", + summaries_key="ExperimentSummaries", + summary_name="ExperimentSummary", + resource_cls=Experiment, + list_method_kwargs=operation_input_args, + ) -class FlowDefinition(Base): +class FeatureGroup(Base): """ - Class representing resource FlowDefinition + Class representing resource FeatureGroup Attributes: - flow_definition_arn: The Amazon Resource Name (ARN) of the flow defintion. - flow_definition_name: The Amazon Resource Name (ARN) of the flow definition. - flow_definition_status: The status of the flow definition. Valid values are listed below. - creation_time: The timestamp when the flow definition was created. - output_config: An object containing information about the output file. - role_arn: The Amazon Resource Name (ARN) of the Amazon Web Services Identity and Access Management (IAM) execution role for the flow definition. - human_loop_request_source: Container for configuring the source of human task requests. Used to specify if Amazon Rekognition or Amazon Textract is used as an integration source. - human_loop_activation_config: An object containing information about what triggers a human review workflow. 
- human_loop_config: An object containing information about who works on the task, the workforce task price, and other task details. - failure_reason: The reason your flow definition failed. + feature_group_arn: The Amazon Resource Name (ARN) of the FeatureGroup. + feature_group_name: The name of the FeatureGroup. + record_identifier_feature_name: The name of the Feature used for RecordIdentifier, whose value uniquely identifies a record stored in the feature store. + event_time_feature_name: The name of the feature that stores the EventTime of a Record in a FeatureGroup. An EventTime is a point in time when a new event occurs that corresponds to the creation or update of a Record in a FeatureGroup. All Records in the FeatureGroup have a corresponding EventTime. + feature_definitions: A list of the Features in the FeatureGroup. Each feature is defined by a FeatureName and FeatureType. + creation_time: A timestamp indicating when SageMaker created the FeatureGroup. + next_token: A token to resume pagination of the list of Features (FeatureDefinitions). + last_modified_time: A timestamp indicating when the feature group was last updated. + online_store_config: The configuration for the OnlineStore. + offline_store_config: The configuration of the offline store. It includes the following configurations: Amazon S3 location of the offline store. Configuration of the Glue data catalog. Table format of the offline store. Option to disable the automatic creation of a Glue table for the offline store. Encryption configuration. + throughput_config: + role_arn: The Amazon Resource Name (ARN) of the IAM execution role used to persist data into the OfflineStore if an OfflineStoreConfig is provided. + feature_group_status: The status of the feature group. + offline_store_status: The status of the OfflineStore. Notifies you if replicating data into the OfflineStore has failed. 
Returns either: Active or Blocked + last_update_status: A value indicating whether the update made to the feature group was successful. + failure_reason: The reason that the FeatureGroup failed to be replicated in the OfflineStore. This is failure can occur because: The FeatureGroup could not be created in the OfflineStore. The FeatureGroup could not be deleted from the OfflineStore. + description: A free form description of the feature group. + online_store_total_size_bytes: The size of the OnlineStore in bytes. """ - flow_definition_name: str - flow_definition_arn: Optional[str] = Unassigned() - flow_definition_status: Optional[str] = Unassigned() + feature_group_name: str + feature_group_arn: Optional[str] = Unassigned() + record_identifier_feature_name: Optional[str] = Unassigned() + event_time_feature_name: Optional[str] = Unassigned() + feature_definitions: Optional[List[FeatureDefinition]] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - human_loop_request_source: Optional[HumanLoopRequestSource] = Unassigned() - human_loop_activation_config: Optional[HumanLoopActivationConfig] = Unassigned() - human_loop_config: Optional[HumanLoopConfig] = Unassigned() - output_config: Optional[FlowDefinitionOutputConfig] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + online_store_config: Optional[OnlineStoreConfig] = Unassigned() + offline_store_config: Optional[OfflineStoreConfig] = Unassigned() + throughput_config: Optional[ThroughputConfigDescription] = Unassigned() role_arn: Optional[str] = Unassigned() + feature_group_status: Optional[str] = Unassigned() + offline_store_status: Optional[OfflineStoreStatus] = Unassigned() + last_update_status: Optional[LastUpdateStatus] = Unassigned() failure_reason: Optional[str] = Unassigned() + description: Optional[str] = Unassigned() + next_token: Optional[str] = Unassigned() + online_store_total_size_bytes: Optional[int] = Unassigned() def get_name(self) -> str: 
attributes = vars(self) - resource_name = "flow_definition_name" + resource_name = "feature_group_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -10236,23 +10346,27 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object flow_definition") + logger.error("Name attribute not found for object feature_group") return None def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): config_schema_for_resource = { - "output_config": { - "s3_output_path": {"type": "string"}, - "kms_key_id": {"type": "string"}, + "online_store_config": {"security_config": {"kms_key_id": {"type": "string"}}}, + "offline_store_config": { + "s3_storage_config": { + "s3_uri": {"type": "string"}, + "kms_key_id": {"type": "string"}, + "resolved_output_s3_uri": {"type": "string"}, + } }, "role_arn": {"type": "string"}, } return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "FlowDefinition", **kwargs + config_schema_for_resource, "FeatureGroup", **kwargs ), ) @@ -10263,32 +10377,38 @@ def wrapper(*args, **kwargs): @Base.add_validate_call def create( cls, - flow_definition_name: str, - output_config: FlowDefinitionOutputConfig, - role_arn: str, - human_loop_request_source: Optional[HumanLoopRequestSource] = Unassigned(), - human_loop_activation_config: Optional[HumanLoopActivationConfig] = Unassigned(), - human_loop_config: Optional[HumanLoopConfig] = Unassigned(), + feature_group_name: str, + record_identifier_feature_name: str, + event_time_feature_name: str, + feature_definitions: List[FeatureDefinition], + online_store_config: Optional[OnlineStoreConfig] = Unassigned(), + offline_store_config: Optional[OfflineStoreConfig] = Unassigned(), + throughput_config: Optional[ThroughputConfig] = Unassigned(), + 
role_arn: Optional[str] = Unassigned(), + description: Optional[str] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["FlowDefinition"]: + ) -> Optional["FeatureGroup"]: """ - Create a FlowDefinition resource + Create a FeatureGroup resource Parameters: - flow_definition_name: The name of your flow definition. - output_config: An object containing information about where the human review results will be uploaded. - role_arn: The Amazon Resource Name (ARN) of the role needed to call other services on your behalf. For example, arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole-20180111T151298. - human_loop_request_source: Container for configuring the source of human task requests. Use to specify if Amazon Rekognition or Amazon Textract is used as an integration source. - human_loop_activation_config: An object containing information about the events that trigger a human workflow. - human_loop_config: An object containing information about the tasks the human reviewers will perform. - tags: An array of key-value pairs that contain metadata to help you categorize and organize a flow definition. Each tag consists of a key and a value, both of which you define. + feature_group_name: The name of the FeatureGroup. The name must be unique within an Amazon Web Services Region in an Amazon Web Services account. The name: Must start with an alphanumeric character. Can only include alphanumeric characters, underscores, and hyphens. Spaces are not allowed. + record_identifier_feature_name: The name of the Feature whose value uniquely identifies a Record defined in the FeatureStore. Only the latest record per identifier value will be stored in the OnlineStore. RecordIdentifierFeatureName must be one of feature definitions' names. You use the RecordIdentifierFeatureName to access data in a FeatureStore. This name: Must start with an alphanumeric character. 
 Can only contain alphanumeric characters, hyphens, underscores. Spaces are not allowed. + event_time_feature_name: The name of the feature that stores the EventTime of a Record in a FeatureGroup. An EventTime is a point in time when a new event occurs that corresponds to the creation or update of a Record in a FeatureGroup. All Records in the FeatureGroup must have a corresponding EventTime. An EventTime can be a String or Fractional. Fractional: EventTime feature values must be a Unix timestamp in seconds. String: EventTime feature values must be an ISO-8601 string in the format. The following formats are supported yyyy-MM-dd'T'HH:mm:ssZ and yyyy-MM-dd'T'HH:mm:ss.SSSZ where yyyy, MM, and dd represent the year, month, and day respectively and HH, mm, ss, and if applicable, SSS represent the hour, minute, second and milliseconds respectively. 'T' and Z are constants. + feature_definitions: A list of Feature names and types. Name and Type is compulsory per Feature. Valid feature FeatureTypes are Integral, Fractional and String. FeatureNames cannot be any of the following: is_deleted, write_time, api_invocation_time You can create up to 2,500 FeatureDefinitions per FeatureGroup. + online_store_config: You can turn the OnlineStore on or off by specifying True for the EnableOnlineStore flag in OnlineStoreConfig. You can also include an Amazon Web Services KMS key ID (KMSKeyId) for at-rest encryption of the OnlineStore. The default value is False. + offline_store_config: Use this to configure an OfflineFeatureStore. This parameter allows you to specify: The Amazon Simple Storage Service (Amazon S3) location of an OfflineStore. A configuration for an Amazon Web Services Glue or Amazon Web Services Hive data catalog. A KMS encryption key to encrypt the Amazon S3 location used for OfflineStore. If KMS encryption key is not specified, by default we encrypt all data at rest using Amazon Web Services KMS key. 
By defining your bucket-level key for SSE, you can reduce Amazon Web Services KMS requests costs by up to 99 percent. Format for the offline store table. Supported formats are Glue (Default) and Apache Iceberg. To learn more about this parameter, see OfflineStoreConfig. + throughput_config: + role_arn: The Amazon Resource Name (ARN) of the IAM execution role used to persist data into the OfflineStore if an OfflineStoreConfig is provided. + description: A free-form description of a FeatureGroup. + tags: Tags used to identify Features in each FeatureGroup. session: Boto3 session. region: Region name. Returns: - The FlowDefinition resource. + The FeatureGroup resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -10307,23 +10427,26 @@ def create( S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating flow_definition resource.") + logger.info("Creating feature_group resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "FlowDefinitionName": flow_definition_name, - "HumanLoopRequestSource": human_loop_request_source, - "HumanLoopActivationConfig": human_loop_activation_config, - "HumanLoopConfig": human_loop_config, - "OutputConfig": output_config, + "FeatureGroupName": feature_group_name, + "RecordIdentifierFeatureName": record_identifier_feature_name, + "EventTimeFeatureName": event_time_feature_name, + "FeatureDefinitions": feature_definitions, + "OnlineStoreConfig": online_store_config, + "OfflineStoreConfig": offline_store_config, + "ThroughputConfig": throughput_config, "RoleArn": role_arn, + "Description": description, "Tags": tags, } operation_input_args = Base.populate_chained_attributes( - resource_name="FlowDefinition", operation_input_args=operation_input_args + resource_name="FeatureGroup", operation_input_args=operation_input_args ) logger.debug(f"Input request: 
{operation_input_args}") @@ -10332,29 +10455,31 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_flow_definition(**operation_input_args) + response = client.create_feature_group(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(flow_definition_name=flow_definition_name, session=session, region=region) + return cls.get(feature_group_name=feature_group_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - flow_definition_name: str, + feature_group_name: str, + next_token: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["FlowDefinition"]: + ) -> Optional["FeatureGroup"]: """ - Get a FlowDefinition resource + Get a FeatureGroup resource Parameters: - flow_definition_name: The name of the flow definition. + feature_group_name: The name or Amazon Resource Name (ARN) of the FeatureGroup you want described. + next_token: A token to resume pagination of the list of Features (FeatureDefinitions). 2,500 Features are returned by default. session: Boto3 session. region: Region name. Returns: - The FlowDefinition resource. + The FeatureGroup resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -10370,7 +10495,8 @@ def get( """ operation_input_args = { - "FlowDefinitionName": flow_definition_name, + "FeatureGroupName": feature_group_name, + "NextToken": next_token, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -10379,24 +10505,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_flow_definition(**operation_input_args) + response = client.describe_feature_group(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeFlowDefinitionResponse") - flow_definition = cls(**transformed_response) - return flow_definition + transformed_response = transform(response, "DescribeFeatureGroupResponse") + feature_group = cls(**transformed_response) + return feature_group @Base.add_validate_call def refresh( self, - ) -> Optional["FlowDefinition"]: + ) -> Optional["FeatureGroup"]: """ - Refresh a FlowDefinition resource + Refresh a FeatureGroup resource Returns: - The FlowDefinition resource. + The FeatureGroup resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -10412,25 +10538,36 @@ def refresh( """ operation_input_args = { - "FlowDefinitionName": self.flow_definition_name, + "FeatureGroupName": self.feature_group_name, + "NextToken": self.next_token, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_flow_definition(**operation_input_args) + response = client.describe_feature_group(**operation_input_args) # deserialize response and update self - transform(response, "DescribeFlowDefinitionResponse", self) + transform(response, "DescribeFeatureGroupResponse", self) return self + @populate_inputs_decorator @Base.add_validate_call - def delete( + def update( self, - ) -> None: + feature_additions: Optional[List[FeatureDefinition]] = Unassigned(), + online_store_config: Optional[OnlineStoreConfigUpdate] = Unassigned(), + throughput_config: Optional[ThroughputConfigUpdate] = Unassigned(), + ) -> Optional["FeatureGroup"]: """ - Delete a FlowDefinition resource + Update a FeatureGroup resource + + Parameters: + feature_additions: Updates the feature group. Updating a feature group is an asynchronous operation. When you get an HTTP 200 response, you've made a valid request. It takes some time after you've made a valid request for Feature Store to update the feature group. + + Returns: + The FeatureGroup resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -10442,32 +10579,73 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. ResourceNotFound: Resource being access is not found. 
""" + logger.info("Updating feature_group resource.") client = Base.get_sagemaker_client() operation_input_args = { - "FlowDefinitionName": self.flow_definition_name, + "FeatureGroupName": self.feature_group_name, + "FeatureAdditions": feature_additions, + "OnlineStoreConfig": online_store_config, + "ThroughputConfig": throughput_config, } + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_flow_definition(**operation_input_args) + # create the resource + response = client.update_feature_group(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + return self + + @Base.add_validate_call + def delete( + self, + ) -> None: + """ + Delete a FeatureGroup resource + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. 
+ """ + + client = Base.get_sagemaker_client() + + operation_input_args = { + "FeatureGroupName": self.feature_group_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_feature_group(**operation_input_args) + + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call def wait_for_status( self, - target_status: Literal["Initializing", "Active", "Failed", "Deleting"], + target_status: Literal["Creating", "Created", "CreateFailed", "Deleting", "DeleteFailed"], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a FlowDefinition resource to reach certain status. + Wait for a FeatureGroup resource to reach certain status. Parameters: target_status: The status to wait for. @@ -10486,7 +10664,7 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task(f"Waiting for FlowDefinition to reach [bold]{target_status} status...") + progress.add_task(f"Waiting for FeatureGroup to reach [bold]{target_status} status...") status = Status("Current status:") with Live( @@ -10499,7 +10677,7 @@ def wait_for_status( ): while True: self.refresh() - current_status = self.flow_definition_status + current_status = self.feature_group_status status.update(f"Current status: [bold]{current_status}") if target_status == current_status: @@ -10508,13 +10686,13 @@ def wait_for_status( if "failed" in current_status.lower(): raise FailedStatusError( - resource_type="FlowDefinition", + resource_type="FeatureGroup", status=current_status, reason=self.failure_reason, ) if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="FlowDefinition", status=current_status) + raise TimeoutExceededError(resouce_type="FeatureGroup", status=current_status) time.sleep(poll) @Base.add_validate_call @@ -10524,7 +10702,7 @@ def 
wait_for_delete( timeout: Optional[int] = None, ) -> None: """ - Wait for a FlowDefinition resource to be deleted. + Wait for a FeatureGroup resource to be deleted. Parameters: poll: The number of seconds to wait between each poll. @@ -10551,7 +10729,7 @@ def wait_for_delete( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task("Waiting for FlowDefinition to be deleted...") + progress.add_task("Waiting for FeatureGroup to be deleted...") status = Status("Current status:") with Live( @@ -10564,12 +10742,12 @@ def wait_for_delete( while True: try: self.refresh() - current_status = self.flow_definition_status + current_status = self.feature_group_status status.update(f"Current status: [bold]{current_status}") if timeout is not None and time.time() - start_time >= timeout: raise TimeoutExceededError( - resouce_type="FlowDefinition", status=current_status + resouce_type="FeatureGroup", status=current_status ) except botocore.exceptions.ClientError as e: error_code = e.response["Error"]["Code"] @@ -10584,26 +10762,34 @@ def wait_for_delete( @Base.add_validate_call def get_all( cls, + name_contains: Optional[str] = Unassigned(), + feature_group_status_equals: Optional[str] = Unassigned(), + offline_store_status_equals: Optional[str] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), sort_order: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["FlowDefinition"]: + ) -> ResourceIterator["FeatureGroup"]: """ - Get all FlowDefinition resources + Get all FeatureGroup resources Parameters: - creation_time_after: A filter that returns only flow definitions with a creation time greater than or equal to the specified timestamp. - creation_time_before: A filter that returns only flow definitions that were created before the specified timestamp. 
- sort_order: An optional value that specifies whether you want the results sorted in Ascending or Descending order. - next_token: A token to resume pagination. - max_results: The total number of items to return. If the total number of available items is more than the value specified in MaxResults, then a NextToken will be provided in the output that you can use to resume pagination. + name_contains: A string that partially matches one or more FeatureGroups names. Filters FeatureGroups by name. + feature_group_status_equals: A FeatureGroup status. Filters by FeatureGroup status. + offline_store_status_equals: An OfflineStore status. Filters by OfflineStore status. + creation_time_after: Use this parameter to search for FeatureGroups created after a specific date and time. + creation_time_before: Use this parameter to search for FeatureGroups created before a specific date and time. + sort_order: The order in which feature groups are listed. + sort_by: The value on which the feature group list is sorted. + max_results: The maximum number of results returned by ListFeatureGroups. + next_token: A token to resume pagination of ListFeatureGroups results. session: Boto3 session. region: Region name. Returns: - Iterator for listed FlowDefinition resources. + Iterator for listed FeatureGroup resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -10622,9 +10808,13 @@ def get_all( ) operation_input_args = { + "NameContains": name_contains, + "FeatureGroupStatusEquals": feature_group_status_equals, + "OfflineStoreStatusEquals": offline_store_status_equals, "CreationTimeAfter": creation_time_after, "CreationTimeBefore": creation_time_before, "SortOrder": sort_order, + "SortBy": sort_by, } # serialize the input request @@ -10633,103 +10823,34 @@ def get_all( return ResourceIterator( client=client, - list_method="list_flow_definitions", - summaries_key="FlowDefinitionSummaries", - summary_name="FlowDefinitionSummary", - resource_cls=FlowDefinition, + list_method="list_feature_groups", + summaries_key="FeatureGroupSummaries", + summary_name="FeatureGroupSummary", + resource_cls=FeatureGroup, list_method_kwargs=operation_input_args, ) - -class Hub(Base): - """ - Class representing resource Hub - - Attributes: - hub_name: The name of the hub. - hub_arn: The Amazon Resource Name (ARN) of the hub. - hub_status: The status of the hub. - creation_time: The date and time that the hub was created. - last_modified_time: The date and time that the hub was last modified. - hub_display_name: The display name of the hub. - hub_description: A description of the hub. - hub_search_keywords: The searchable keywords for the hub. - s3_storage_config: The Amazon S3 storage configuration for the hub. - failure_reason: The failure reason if importing hub content failed. 
- - """ - - hub_name: str - hub_arn: Optional[str] = Unassigned() - hub_display_name: Optional[str] = Unassigned() - hub_description: Optional[str] = Unassigned() - hub_search_keywords: Optional[List[str]] = Unassigned() - s3_storage_config: Optional[HubS3StorageConfig] = Unassigned() - hub_status: Optional[str] = Unassigned() - failure_reason: Optional[str] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - - def get_name(self) -> str: - attributes = vars(self) - resource_name = "hub_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] - - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) - - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object hub") - return None - - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = { - "s3_storage_config": {"s3_output_path": {"type": "string"}} - } - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "Hub", **kwargs - ), - ) - - return wrapper - - @classmethod - @populate_inputs_decorator @Base.add_validate_call - def create( - cls, - hub_name: str, - hub_description: str, - hub_display_name: Optional[str] = Unassigned(), - hub_search_keywords: Optional[List[str]] = Unassigned(), - s3_storage_config: Optional[HubS3StorageConfig] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), + def get_record( + self, + record_identifier_value_as_string: str, + feature_names: Optional[List[str]] = Unassigned(), + expiration_time_response: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Hub"]: + ) -> 
Optional[GetRecordResponse]: """ - Create a Hub resource + Use for OnlineStore serving from a FeatureStore. Parameters: - hub_name: The name of the hub to create. - hub_description: A description of the hub. - hub_display_name: The display name of the hub. - hub_search_keywords: The searchable keywords for the hub. - s3_storage_config: The Amazon S3 storage configuration for the hub. - tags: Any tags to associate with the hub. + record_identifier_value_as_string: The value that corresponds to RecordIdentifier type and uniquely identifies the record in the FeatureGroup. + feature_names: List of names of Features to be retrieved. If not specified, the latest value for all the Features are returned. + expiration_time_response: Parameter to request ExpiresAt in response. If Enabled, GetRecord will return the value of ExpiresAt, if it is not null. If Disabled and null, GetRecord will return null. session: Boto3 session. region: Region name. Returns: - The Hub resource. + GetRecordResponse Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -10741,61 +10862,53 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + AccessForbidden: You do not have permission to perform an action. + InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. + ResourceNotFound: Resource being access is not found. 
+ ServiceUnavailable: The service is currently unavailable. + ValidationError: There was an error validating your request. """ - logger.info("Creating hub resource.") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - operation_input_args = { - "HubName": hub_name, - "HubDescription": hub_description, - "HubDisplayName": hub_display_name, - "HubSearchKeywords": hub_search_keywords, - "S3StorageConfig": s3_storage_config, - "Tags": tags, + "FeatureGroupName": self.feature_group_name, + "RecordIdentifierValueAsString": record_identifier_value_as_string, + "FeatureNames": feature_names, + "ExpirationTimeResponse": expiration_time_response, } - - operation_input_args = Base.populate_chained_attributes( - resource_name="Hub", operation_input_args=operation_input_args - ) - - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - # create the resource - response = client.create_hub(**operation_input_args) + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker-featurestore-runtime" + ) + + logger.debug(f"Calling get_record API") + response = client.get_record(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(hub_name=hub_name, session=session, region=region) + transformed_response = transform(response, "GetRecordResponse") + return GetRecordResponse(**transformed_response) - @classmethod @Base.add_validate_call - def get( - cls, - hub_name: str, + def put_record( + self, + record: List[FeatureValue], + target_stores: Optional[List[str]] = Unassigned(), + ttl_duration: Optional[TtlDuration] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Hub"]: + ) -> None: """ - Get a Hub resource + The PutRecord API is used to ingest a list of Records into 
your feature group. Parameters: - hub_name: The name of the hub to describe. + record: List of FeatureValues to be inserted. This will be a full over-write. If you only want to update few of the feature values, do the following: Use GetRecord to retrieve the latest record. Update the record returned from GetRecord. Use PutRecord to update feature values. + target_stores: A list of stores to which you're adding the record. By default, Feature Store adds the record to all of the stores that you're using for the FeatureGroup. + ttl_duration: Time to live duration, where the record is hard deleted after the expiration time is reached; ExpiresAt = EventTime + TtlDuration. For information on HardDelete, see the DeleteRecord API in the Amazon SageMaker API Reference guide. session: Boto3 session. region: Region name. - Returns: - The Hub resource. - Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. The error message and error code can be parsed from the exception as follows: @@ -10806,37 +10919,50 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + AccessForbidden: You do not have permission to perform an action. + InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. + ServiceUnavailable: The service is currently unavailable. + ValidationError: There was an error validating your request. 
""" operation_input_args = { - "HubName": hub_name, + "FeatureGroupName": self.feature_group_name, + "Record": record, + "TargetStores": target_stores, + "TtlDuration": ttl_duration, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" + session=session, region_name=region, service_name="sagemaker-featurestore-runtime" ) - response = client.describe_hub(**operation_input_args) - logger.debug(response) - - # deserialize the response - transformed_response = transform(response, "DescribeHubResponse") - hub = cls(**transformed_response) - return hub + logger.debug(f"Calling put_record API") + response = client.put_record(**operation_input_args) + logger.debug(f"Response: {response}") @Base.add_validate_call - def refresh( + def delete_record( self, - ) -> Optional["Hub"]: + record_identifier_value_as_string: str, + event_time: str, + target_stores: Optional[List[str]] = Unassigned(), + deletion_mode: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> None: """ - Refresh a Hub resource + Deletes a Record from a FeatureGroup in the OnlineStore. - Returns: - The Hub resource. + Parameters: + record_identifier_value_as_string: The value for the RecordIdentifier that uniquely identifies the record, in string format. + event_time: Timestamp indicating when the deletion event occurred. EventTime can be used to query data at a certain point in time. + target_stores: A list of stores from which you're deleting the record. By default, Feature Store deletes the record from all of the stores that you're using for the FeatureGroup. + deletion_mode: The name of the deletion mode for deleting the record. By default, the deletion mode is set to SoftDelete. + session: Boto3 session. + region: Region name. 
Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -10848,36 +10974,50 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + AccessForbidden: You do not have permission to perform an action. + InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. + ServiceUnavailable: The service is currently unavailable. + ValidationError: There was an error validating your request. """ operation_input_args = { - "HubName": self.hub_name, + "FeatureGroupName": self.feature_group_name, + "RecordIdentifierValueAsString": record_identifier_value_as_string, + "EventTime": event_time, + "TargetStores": target_stores, + "DeletionMode": deletion_mode, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() - response = client.describe_hub(**operation_input_args) + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker-featurestore-runtime" + ) - # deserialize response and update self - transform(response, "DescribeHubResponse", self) - return self + logger.debug(f"Calling delete_record API") + response = client.delete_record(**operation_input_args) + logger.debug(f"Response: {response}") - @populate_inputs_decorator @Base.add_validate_call - def update( + def batch_get_record( self, - hub_description: Optional[str] = Unassigned(), - hub_display_name: Optional[str] = Unassigned(), - hub_search_keywords: Optional[List[str]] = Unassigned(), - ) -> Optional["Hub"]: + identifiers: List[BatchGetRecordIdentifier], + expiration_time_response: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> 
Optional[BatchGetRecordResponse]: """ - Update a Hub resource + Retrieves a batch of Records from a FeatureGroup. + + Parameters: + identifiers: A list containing the name or Amazon Resource Name (ARN) of the FeatureGroup, the list of names of Features to be retrieved, and the corresponding RecordIdentifier values as strings. + expiration_time_response: Parameter to request ExpiresAt in response. If Enabled, BatchGetRecord will return the value of ExpiresAt, if it is not null. If Disabled and null, BatchGetRecord will return null. + session: Boto3 session. + region: Region name. Returns: - The Hub resource. + BatchGetRecordResponse Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -10889,36 +11029,93 @@ def update( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + AccessForbidden: You do not have permission to perform an action. + InternalFailure: An internal failure occurred. Try your request again. If the problem persists, contact Amazon Web Services customer support. + ServiceUnavailable: The service is currently unavailable. + ValidationError: There was an error validating your request. 
""" - logger.info("Updating hub resource.") - client = Base.get_sagemaker_client() - operation_input_args = { - "HubName": self.hub_name, - "HubDescription": hub_description, - "HubDisplayName": hub_display_name, - "HubSearchKeywords": hub_search_keywords, + "Identifiers": identifiers, + "ExpirationTimeResponse": expiration_time_response, } - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - # create the resource - response = client.update_hub(**operation_input_args) + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker-featurestore-runtime" + ) + + logger.debug(f"Calling batch_get_record API") + response = client.batch_get_record(**operation_input_args) logger.debug(f"Response: {response}") - self.refresh() - return self + transformed_response = transform(response, "BatchGetRecordResponse") + return BatchGetRecordResponse(**transformed_response) + + +class FeatureMetadata(Base): + """ + Class representing resource FeatureMetadata + + Attributes: + feature_group_arn: The Amazon Resource Number (ARN) of the feature group that contains the feature. + feature_group_name: The name of the feature group that you've specified. + feature_name: The name of the feature that you've specified. + feature_type: The data type of the feature. + creation_time: A timestamp indicating when the feature was created. + last_modified_time: A timestamp indicating when the metadata for the feature group was modified. For example, if you add a parameter describing the feature, the timestamp changes to reflect the last time you + description: The description you added to describe the feature. + parameters: The key-value pairs that you added to describe the feature. 
+ + """ + + feature_group_name: str + feature_name: str + feature_group_arn: Optional[str] = Unassigned() + feature_type: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + description: Optional[str] = Unassigned() + parameters: Optional[List[FeatureParameter]] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "feature_metadata_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object feature_metadata") + return None + + @classmethod @Base.add_validate_call - def delete( - self, - ) -> None: + def get( + cls, + feature_group_name: str, + feature_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["FeatureMetadata"]: """ - Delete a Hub resource + Get a FeatureMetadata resource + + Parameters: + feature_group_name: The name or Amazon Resource Name (ARN) of the feature group containing the feature. + feature_name: The name of the feature. + session: Boto3 session. + region: Region name. + + Returns: + The FeatureMetadata resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -10930,99 +11127,38 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. ResourceNotFound: Resource being access is not found. 
""" - client = Base.get_sagemaker_client() - operation_input_args = { - "HubName": self.hub_name, + "FeatureGroupName": feature_group_name, + "FeatureName": feature_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_hub(**operation_input_args) - - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") - - @Base.add_validate_call - def wait_for_status( - self, - target_status: Literal[ - "InService", - "Creating", - "Updating", - "Deleting", - "CreateFailed", - "UpdateFailed", - "DeleteFailed", - ], - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a Hub resource to reach certain status. - - Parameters: - target_status: The status to wait for. - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. - - Raises: - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - FailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. 
- """ - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" ) - progress.add_task(f"Waiting for Hub to reach [bold]{target_status} status...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ), - transient=True, - ): - while True: - self.refresh() - current_status = self.hub_status - status.update(f"Current status: [bold]{current_status}") - - if target_status == current_status: - logger.info(f"Final Resource Status: [bold]{current_status}") - return + response = client.describe_feature_metadata(**operation_input_args) - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="Hub", status=current_status, reason=self.failure_reason - ) + logger.debug(response) - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="Hub", status=current_status) - time.sleep(poll) + # deserialize the response + transformed_response = transform(response, "DescribeFeatureMetadataResponse") + feature_metadata = cls(**transformed_response) + return feature_metadata @Base.add_validate_call - def wait_for_delete( + def refresh( self, - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: + ) -> Optional["FeatureMetadata"]: """ - Wait for a Hub resource to be deleted. + Refresh a FeatureMetadata resource - Parameters: - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. + Returns: + The FeatureMetadata resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -11034,76 +11170,40 @@ def wait_for_delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - DeleteFailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. + ResourceNotFound: Resource being access is not found. """ - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task("Waiting for Hub to be deleted...") - status = Status("Current status:") - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ) - ): - while True: - try: - self.refresh() - current_status = self.hub_status - status.update(f"Current status: [bold]{current_status}") + operation_input_args = { + "FeatureGroupName": self.feature_group_name, + "FeatureName": self.feature_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="Hub", status=current_status) - except botocore.exceptions.ClientError as e: - error_code = e.response["Error"]["Code"] + client = Base.get_sagemaker_client() + response = client.describe_feature_metadata(**operation_input_args) - if "ResourceNotFound" in error_code or "ValidationException" in error_code: - logger.info("Resource was not found. 
It may have been deleted.") - return - raise e - time.sleep(poll) + # deserialize response and update self + transform(response, "DescribeFeatureMetadataResponse", self) + return self - @classmethod @Base.add_validate_call - def get_all( - cls, - name_contains: Optional[str] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), - last_modified_time_before: Optional[datetime.datetime] = Unassigned(), - last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> ResourceIterator["Hub"]: + def update( + self, + description: Optional[str] = Unassigned(), + parameter_additions: Optional[List[FeatureParameter]] = Unassigned(), + parameter_removals: Optional[List[str]] = Unassigned(), + ) -> Optional["FeatureMetadata"]: """ - Get all Hub resources + Update a FeatureMetadata resource Parameters: - name_contains: Only list hubs with names that contain the specified string. - creation_time_before: Only list hubs that were created before the time specified. - creation_time_after: Only list hubs that were created after the time specified. - last_modified_time_before: Only list hubs that were last modified before the time specified. - last_modified_time_after: Only list hubs that were last modified after the time specified. - sort_by: Sort hubs by either name or creation time. - sort_order: Sort hubs by ascending or descending order. - max_results: The maximum number of hubs to list. - next_token: If the response to a previous ListHubs request was truncated, the response includes a NextToken. To retrieve the next set of hubs, use the token in the next request. - session: Boto3 session. - region: Region name. + parameter_additions: A list of key-value pairs that you can add to better describe the feature. 
+ parameter_removals: A list of parameter keys that you can specify to remove parameters that describe your feature. Returns: - Iterator for listed Hub resources. + The FeatureMetadata resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -11115,86 +11215,64 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) + logger.info("Updating feature_metadata resource.") + client = Base.get_sagemaker_client() operation_input_args = { - "NameContains": name_contains, - "CreationTimeBefore": creation_time_before, - "CreationTimeAfter": creation_time_after, - "LastModifiedTimeBefore": last_modified_time_before, - "LastModifiedTimeAfter": last_modified_time_after, - "SortBy": sort_by, - "SortOrder": sort_order, + "FeatureGroupName": self.feature_group_name, + "FeatureName": self.feature_name, + "Description": description, + "ParameterAdditions": parameter_additions, + "ParameterRemovals": parameter_removals, } - + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - return ResourceIterator( - client=client, - list_method="list_hubs", - summaries_key="HubSummaries", - summary_name="HubInfo", - resource_cls=Hub, - list_method_kwargs=operation_input_args, - ) + # create the resource + response = client.update_feature_metadata(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() + + return self -class HubContent(Base): +class FlowDefinition(Base): """ - Class representing resource HubContent + Class representing resource FlowDefinition Attributes: - hub_content_name: The name of the hub content. 
- hub_content_arn: The Amazon Resource Name (ARN) of the hub content. - hub_content_version: The version of the hub content. - hub_content_type: The type of hub content. - document_schema_version: The document schema version for the hub content. - hub_name: The name of the hub that contains the content. - hub_arn: The Amazon Resource Name (ARN) of the hub that contains the content. - hub_content_document: The hub content document that describes information about the hub content such as type, associated containers, scripts, and more. - hub_content_status: The status of the hub content. - creation_time: The date and time that hub content was created. - hub_content_display_name: The display name of the hub content. - hub_content_description: A description of the hub content. - hub_content_markdown: A string that provides a description of the hub content. This string can include links, tables, and standard markdown formating. - sage_maker_public_hub_content_arn: The ARN of the public hub content. - reference_min_version: The minimum version of the hub content. - support_status: The support status of the hub content. - hub_content_search_keywords: The searchable keywords for the hub content. - hub_content_dependencies: The location of any dependencies that the hub content has, such as scripts, model artifacts, datasets, or notebooks. - failure_reason: The failure reason if importing hub content failed. + flow_definition_arn: The Amazon Resource Name (ARN) of the flow defintion. + flow_definition_name: The Amazon Resource Name (ARN) of the flow definition. + flow_definition_status: The status of the flow definition. Valid values are listed below. + creation_time: The timestamp when the flow definition was created. + output_config: An object containing information about the output file. + role_arn: The Amazon Resource Name (ARN) of the Amazon Web Services Identity and Access Management (IAM) execution role for the flow definition. 
+ human_loop_request_source: Container for configuring the source of human task requests. Used to specify if Amazon Rekognition or Amazon Textract is used as an integration source. + human_loop_activation_config: An object containing information about what triggers a human review workflow. + human_loop_config: An object containing information about who works on the task, the workforce task price, and other task details. + failure_reason: The reason your flow definition failed. """ - hub_content_type: str - hub_content_name: str - hub_content_arn: Optional[str] = Unassigned() - hub_content_version: Optional[str] = Unassigned() - document_schema_version: Optional[str] = Unassigned() - hub_arn: Optional[str] = Unassigned() - hub_content_display_name: Optional[str] = Unassigned() - hub_content_description: Optional[str] = Unassigned() - hub_content_markdown: Optional[str] = Unassigned() - hub_content_document: Optional[str] = Unassigned() - sage_maker_public_hub_content_arn: Optional[str] = Unassigned() - reference_min_version: Optional[str] = Unassigned() - support_status: Optional[str] = Unassigned() - hub_content_search_keywords: Optional[List[str]] = Unassigned() - hub_content_dependencies: Optional[List[HubContentDependency]] = Unassigned() - hub_content_status: Optional[str] = Unassigned() - failure_reason: Optional[str] = Unassigned() + flow_definition_name: str + flow_definition_arn: Optional[str] = Unassigned() + flow_definition_status: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - hub_name: Optional[str] = Unassigned() + human_loop_request_source: Optional[HumanLoopRequestSource] = Unassigned() + human_loop_activation_config: Optional[HumanLoopActivationConfig] = Unassigned() + human_loop_config: Optional[HumanLoopConfig] = Unassigned() + output_config: Optional[FlowDefinitionOutputConfig] = Unassigned() + role_arn: Optional[str] = Unassigned() + failure_reason: Optional[str] = Unassigned() def get_name(self) -> str: 
attributes = vars(self) - resource_name = "hub_content_name" + resource_name = "flow_definition_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -11205,33 +11283,59 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object hub_content") + logger.error("Name attribute not found for object flow_definition") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "output_config": { + "s3_output_path": {"type": "string"}, + "kms_key_id": {"type": "string"}, + }, + "role_arn": {"type": "string"}, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "FlowDefinition", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call - def get( + def create( cls, - hub_name: str, - hub_content_type: str, - hub_content_name: str, - hub_content_version: Optional[str] = Unassigned(), + flow_definition_name: str, + output_config: FlowDefinitionOutputConfig, + role_arn: str, + human_loop_request_source: Optional[HumanLoopRequestSource] = Unassigned(), + human_loop_activation_config: Optional[HumanLoopActivationConfig] = Unassigned(), + human_loop_config: Optional[HumanLoopConfig] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["HubContent"]: + ) -> Optional["FlowDefinition"]: """ - Get a HubContent resource + Create a FlowDefinition resource Parameters: - hub_name: The name of the hub that contains the content to describe. - hub_content_type: The type of content in the hub. - hub_content_name: The name of the content to describe. - hub_content_version: The version of the content to describe. 
+ flow_definition_name: The name of your flow definition. + output_config: An object containing information about where the human review results will be uploaded. + role_arn: The Amazon Resource Name (ARN) of the role needed to call other services on your behalf. For example, arn:aws:iam::1234567890:role/service-role/AmazonSageMaker-ExecutionRole-20180111T151298. + human_loop_request_source: Container for configuring the source of human task requests. Use to specify if Amazon Rekognition or Amazon Textract is used as an integration source. + human_loop_activation_config: An object containing information about the events that trigger a human workflow. + human_loop_config: An object containing information about the tasks the human reviewers will perform. + tags: An array of key-value pairs that contain metadata to help you categorize and organize a flow definition. Each tag consists of a key and a value, both of which you define. session: Boto3 session. region: Region name. Returns: - The HubContent resource. + The FlowDefinition resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -11243,40 +11347,103 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
+ ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - operation_input_args = { - "HubName": hub_name, - "HubContentType": hub_content_type, - "HubContentName": hub_content_name, - "HubContentVersion": hub_content_version, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - + logger.info("Creating flow_definition resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_hub_content(**operation_input_args) - - logger.debug(response) - - # deserialize the response - transformed_response = transform(response, "DescribeHubContentResponse") - hub_content = cls(**transformed_response) - return hub_content - @Base.add_validate_call - def refresh( + operation_input_args = { + "FlowDefinitionName": flow_definition_name, + "HumanLoopRequestSource": human_loop_request_source, + "HumanLoopActivationConfig": human_loop_activation_config, + "HumanLoopConfig": human_loop_config, + "OutputConfig": output_config, + "RoleArn": role_arn, + "Tags": tags, + } + + operation_input_args = Base.populate_chained_attributes( + resource_name="FlowDefinition", operation_input_args=operation_input_args + ) + + logger.debug(f"Input request: {operation_input_args}") + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + # create the resource + response = client.create_flow_definition(**operation_input_args) + logger.debug(f"Response: {response}") + + return cls.get(flow_definition_name=flow_definition_name, session=session, region=region) + + @classmethod + @Base.add_validate_call + def 
get( + cls, + flow_definition_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["FlowDefinition"]: + """ + Get a FlowDefinition resource + + Parameters: + flow_definition_name: The name of the flow definition. + session: Boto3 session. + region: Region name. + + Returns: + The FlowDefinition resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. + """ + + operation_input_args = { + "FlowDefinitionName": flow_definition_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + response = client.describe_flow_definition(**operation_input_args) + + logger.debug(response) + + # deserialize the response + transformed_response = transform(response, "DescribeFlowDefinitionResponse") + flow_definition = cls(**transformed_response) + return flow_definition + + @Base.add_validate_call + def refresh( self, - ) -> Optional["HubContent"]: + ) -> Optional["FlowDefinition"]: """ - Refresh a HubContent resource + Refresh a FlowDefinition resource Returns: - The HubContent resource. + The FlowDefinition resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -11292,20 +11459,17 @@ def refresh( """ operation_input_args = { - "HubName": self.hub_name, - "HubContentType": self.hub_content_type, - "HubContentName": self.hub_content_name, - "HubContentVersion": self.hub_content_version, + "FlowDefinitionName": self.flow_definition_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_hub_content(**operation_input_args) + response = client.describe_flow_definition(**operation_input_args) # deserialize response and update self - transform(response, "DescribeHubContentResponse", self) + transform(response, "DescribeFlowDefinitionResponse", self) return self @Base.add_validate_call @@ -11313,7 +11477,7 @@ def delete( self, ) -> None: """ - Delete a HubContent resource + Delete a FlowDefinition resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -11332,28 +11496,25 @@ def delete( client = Base.get_sagemaker_client() operation_input_args = { - "HubName": self.hub_name, - "HubContentType": self.hub_content_type, - "HubContentName": self.hub_content_name, - "HubContentVersion": self.hub_content_version, + "FlowDefinitionName": self.flow_definition_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_hub_content(**operation_input_args) + client.delete_flow_definition(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call def wait_for_status( self, - target_status: Literal["Supported", "Deprecated"], + target_status: Literal["Initializing", "Active", "Failed", "Deleting"], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a HubContent resource to reach certain status. 
+ Wait for a FlowDefinition resource to reach certain status. Parameters: target_status: The status to wait for. @@ -11372,7 +11533,7 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task(f"Waiting for HubContent to reach [bold]{target_status} status...") + progress.add_task(f"Waiting for FlowDefinition to reach [bold]{target_status} status...") status = Status("Current status:") with Live( @@ -11385,55 +11546,111 @@ def wait_for_status( ): while True: self.refresh() - current_status = self.support_status + current_status = self.flow_definition_status status.update(f"Current status: [bold]{current_status}") if target_status == current_status: logger.info(f"Final Resource Status: [bold]{current_status}") return + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="FlowDefinition", + status=current_status, + reason=self.failure_reason, + ) + if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="HubContent", status=current_status) + raise TimeoutExceededError(resouce_type="FlowDefinition", status=current_status) + time.sleep(poll) + + @Base.add_validate_call + def wait_for_delete( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a FlowDefinition resource to be deleted. + + Parameters: + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. 
+ DeleteFailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. + """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task("Waiting for FlowDefinition to be deleted...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ) + ): + while True: + try: + self.refresh() + current_status = self.flow_definition_status + status.update(f"Current status: [bold]{current_status}") + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="FlowDefinition", status=current_status + ) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] + + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. 
It may have been deleted.") + return + raise e time.sleep(poll) @classmethod @Base.add_validate_call - def load( + def get_all( cls, - hub_content_name: str, - hub_content_type: str, - document_schema_version: str, - hub_name: str, - hub_content_document: str, - hub_content_version: Optional[str] = Unassigned(), - hub_content_display_name: Optional[str] = Unassigned(), - hub_content_description: Optional[str] = Unassigned(), - hub_content_markdown: Optional[str] = Unassigned(), - hub_content_search_keywords: Optional[List[str]] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["HubContent"]: + ) -> ResourceIterator["FlowDefinition"]: """ - Import a HubContent resource + Get all FlowDefinition resources Parameters: - hub_content_name: The name of the hub content to import. - hub_content_type: The type of hub content to import. - document_schema_version: The version of the hub content schema to import. - hub_name: The name of the hub to import content into. - hub_content_document: The hub content document that describes information about the hub content such as type, associated containers, scripts, and more. - hub_content_version: The version of the hub content to import. - hub_content_display_name: The display name of the hub content to import. - hub_content_description: A description of the hub content to import. - hub_content_markdown: A string that provides a description of the hub content. This string can include links, tables, and standard markdown formating. - hub_content_search_keywords: The searchable keywords of the hub content. - tags: Any tags associated with the hub content. 
+ creation_time_after: A filter that returns only flow definitions with a creation time greater than or equal to the specified timestamp. + creation_time_before: A filter that returns only flow definitions that were created before the specified timestamp. + sort_order: An optional value that specifies whether you want the results sorted in Ascending or Descending order. + next_token: A token to resume pagination. + max_results: The total number of items to return. If the total number of available items is more than the value specified in MaxResults, then a NextToken will be provided in the output that you can use to resume pagination. session: Boto3 session. region: Region name. Returns: - The HubContent resource. + Iterator for listed FlowDefinition resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -11445,76 +11662,121 @@ def load( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ResourceNotFound: Resource being access is not found. 
""" - logger.info(f"Importing hub_content resource.") - client = SageMakerClient( + client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" - ).client + ) operation_input_args = { - "HubContentName": hub_content_name, - "HubContentVersion": hub_content_version, - "HubContentType": hub_content_type, - "DocumentSchemaVersion": document_schema_version, - "HubName": hub_name, - "HubContentDisplayName": hub_content_display_name, - "HubContentDescription": hub_content_description, - "HubContentMarkdown": hub_content_markdown, - "HubContentDocument": hub_content_document, - "HubContentSearchKeywords": hub_content_search_keywords, - "Tags": tags, + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "SortOrder": sort_order, } - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - # import the resource - response = client.import_hub_content(**operation_input_args) - logger.debug(f"Response: {response}") + return ResourceIterator( + client=client, + list_method="list_flow_definitions", + summaries_key="FlowDefinitionSummaries", + summary_name="FlowDefinitionSummary", + resource_cls=FlowDefinition, + list_method_kwargs=operation_input_args, + ) - return cls.get( - hub_name=hub_name, - hub_content_type=hub_content_type, - hub_content_name=hub_content_name, - session=session, - region=region, - ) +class Hub(Base): + """ + Class representing resource Hub + + Attributes: + hub_name: The name of the hub. + hub_arn: The Amazon Resource Name (ARN) of the hub. + hub_status: The status of the hub. + creation_time: The date and time that the hub was created. + last_modified_time: The date and time that the hub was last modified. + hub_display_name: The display name of the hub. + hub_description: A description of the hub. 
+ hub_search_keywords: The searchable keywords for the hub. + s3_storage_config: The Amazon S3 storage configuration for the hub. + failure_reason: The failure reason if importing hub content failed. + + """ + + hub_name: str + hub_arn: Optional[str] = Unassigned() + hub_display_name: Optional[str] = Unassigned() + hub_description: Optional[str] = Unassigned() + hub_search_keywords: Optional[List[str]] = Unassigned() + s3_storage_config: Optional[HubS3StorageConfig] = Unassigned() + hub_status: Optional[str] = Unassigned() + failure_reason: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "hub_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object hub") + return None + + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "s3_storage_config": {"s3_output_path": {"type": "string"}} + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "Hub", **kwargs + ), + ) + + return wrapper + + @classmethod + @populate_inputs_decorator @Base.add_validate_call - def get_all_versions( - self, - min_version: Optional[str] = Unassigned(), - max_schema_version: Optional[str] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), + def 
create( + cls, + hub_name: str, + hub_description: str, + hub_display_name: Optional[str] = Unassigned(), + hub_search_keywords: Optional[List[str]] = Unassigned(), + s3_storage_config: Optional[HubS3StorageConfig] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["HubContent"]: + ) -> Optional["Hub"]: """ - List hub content versions. + Create a Hub resource Parameters: - min_version: The lower bound of the hub content versions to list. - max_schema_version: The upper bound of the hub content schema version. - creation_time_before: Only list hub content versions that were created before the time specified. - creation_time_after: Only list hub content versions that were created after the time specified. - sort_by: Sort hub content versions by either name or creation time. - sort_order: Sort hub content versions by ascending or descending order. - max_results: The maximum number of hub content versions to list. - next_token: If the response to a previous ListHubContentVersions request was truncated, the response includes a NextToken. To retrieve the next set of hub content versions, use the token in the next request. + hub_name: The name of the hub to create. + hub_description: A description of the hub. + hub_display_name: The display name of the hub. + hub_search_keywords: The searchable keywords for the hub. + s3_storage_config: The Amazon S3 storage configuration for the hub. + tags: Any tags to associate with the hub. session: Boto3 session. region: Region name. Returns: - Iterator for listed HubContent. + The Hub resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -11526,101 +11788,60 @@ def get_all_versions( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + ResourceInUse: Resource being accessed is in use. 
+ ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - operation_input_args = { - "HubName": self.hub_name, - "HubContentType": self.hub_content_type, - "HubContentName": self.hub_content_name, - "MinVersion": min_version, - "MaxSchemaVersion": max_schema_version, - "CreationTimeBefore": creation_time_before, - "CreationTimeAfter": creation_time_after, - "SortBy": sort_by, - "SortOrder": sort_order, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - + logger.info("Creating hub resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - return ResourceIterator( - client=client, - list_method="list_hub_content_versions", - summaries_key="HubContentSummaries", - summary_name="HubContentInfo", - resource_cls=HubContent, - list_method_kwargs=operation_input_args, - ) - - -class HubContentReference(Base): - """ - Class representing resource HubContentReference - - Attributes: - hub_name: The name of the hub to add the hub content reference to. - sage_maker_public_hub_content_arn: The ARN of the public hub content to reference. - hub_arn: The ARN of the hub that the hub content reference was added to. - hub_content_arn: The ARN of the hub content. - hub_content_name: The name of the hub content to reference. - min_version: The minimum version of the hub content to reference. - tags: Any tags associated with the hub content to reference. 
- - """ + operation_input_args = { + "HubName": hub_name, + "HubDescription": hub_description, + "HubDisplayName": hub_display_name, + "HubSearchKeywords": hub_search_keywords, + "S3StorageConfig": s3_storage_config, + "Tags": tags, + } - hub_name: Union[str, object] - sage_maker_public_hub_content_arn: str - hub_arn: str - hub_content_arn: str - hub_content_name: Optional[Union[str, object]] = Unassigned() - min_version: Optional[str] = Unassigned() - tags: Optional[List[Tag]] = Unassigned() + operation_input_args = Base.populate_chained_attributes( + resource_name="Hub", operation_input_args=operation_input_args + ) - def get_name(self) -> str: - attributes = vars(self) - resource_name = "hub_content_reference_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] + logger.debug(f"Input request: {operation_input_args}") + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) + # create the resource + response = client.create_hub(**operation_input_args) + logger.debug(f"Response: {response}") - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object hub_content_reference") - return None + return cls.get(hub_name=hub_name, session=session, region=region) @classmethod @Base.add_validate_call - def create( + def get( cls, - hub_name: Union[str, object], - sage_maker_public_hub_content_arn: str, - hub_content_name: Optional[Union[str, object]] = Unassigned(), - min_version: Optional[str] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), - ) -> Optional["HubContentReference"]: + hub_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> 
Optional["Hub"]: """ - Create a HubContentReference resource + Get a Hub resource Parameters: - hub_name: The name of the hub to add the hub content reference to. - sage_maker_public_hub_content_arn: The ARN of the public hub content to reference. - hub_content_name: The name of the hub content to reference. - min_version: The minimum version of the hub content to reference. - tags: Any tags associated with the hub content to reference. + hub_name: The name of the hub to describe. session: Boto3 session. region: Region name. Returns: - The HubContentReference resource. + The Hub resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -11632,20 +11853,11 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. ResourceNotFound: Resource being access is not found. 
- ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ operation_input_args = { "HubName": hub_name, - "SageMakerPublicHubContentArn": sage_maker_public_hub_content_arn, - "HubContentName": hub_content_name, - "MinVersion": min_version, - "Tags": tags, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -11654,21 +11866,24 @@ def create( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) + response = client.describe_hub(**operation_input_args) - logger.debug(f"Calling create_hub_content_reference API") - response = client.create_hub_content_reference(**operation_input_args) - logger.debug(f"Response: {response}") + logger.debug(response) - transformed_response = transform(response, "CreateHubContentReferenceResponse") - return cls(**operation_input_args, **transformed_response) + # deserialize the response + transformed_response = transform(response, "DescribeHubResponse") + hub = cls(**transformed_response) + return hub @Base.add_validate_call - def delete( + def refresh( self, - hub_content_type: str, - ) -> None: + ) -> Optional["Hub"]: """ - Delete a HubContentReference resource + Refresh a Hub resource + + Returns: + The Hub resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -11683,79 +11898,33 @@ def delete( ResourceNotFound: Resource being access is not found. 
""" - client = Base.get_sagemaker_client() - operation_input_args = { "HubName": self.hub_name, - "HubContentType": hub_content_type, - "HubContentName": self.hub_content_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_hub_content_reference(**operation_input_args) - - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") - + client = Base.get_sagemaker_client() + response = client.describe_hub(**operation_input_args) -class HumanTaskUi(Base): - """ - Class representing resource HumanTaskUi + # deserialize response and update self + transform(response, "DescribeHubResponse", self) + return self - Attributes: - human_task_ui_arn: The Amazon Resource Name (ARN) of the human task user interface (worker task template). - human_task_ui_name: The name of the human task user interface (worker task template). - creation_time: The timestamp when the human task user interface was created. - ui_template: - human_task_ui_status: The status of the human task user interface (worker task template). Valid values are listed below. 
- - """ - - human_task_ui_name: str - human_task_ui_arn: Optional[str] = Unassigned() - human_task_ui_status: Optional[str] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - ui_template: Optional[UiTemplateInfo] = Unassigned() - - def get_name(self) -> str: - attributes = vars(self) - resource_name = "human_task_ui_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] - - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) - - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object human_task_ui") - return None - - @classmethod + @populate_inputs_decorator @Base.add_validate_call - def create( - cls, - human_task_ui_name: str, - ui_template: UiTemplate, - tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional["HumanTaskUi"]: + def update( + self, + hub_description: Optional[str] = Unassigned(), + hub_display_name: Optional[str] = Unassigned(), + hub_search_keywords: Optional[List[str]] = Unassigned(), + ) -> Optional["Hub"]: """ - Create a HumanTaskUi resource - - Parameters: - human_task_ui_name: The name of the user interface you are creating. - ui_template: - tags: An array of key-value pairs that contain metadata to help you categorize and organize a human review workflow user interface. Each tag consists of a key and a value, both of which you define. - session: Boto3 session. - region: Region name. + Update a Hub resource Returns: - The HumanTaskUi resource. + The Hub resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -11767,125 +11936,28 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + ResourceNotFound: Resource being access is not found. """ - logger.info("Creating human_task_ui resource.") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) + logger.info("Updating hub resource.") + client = Base.get_sagemaker_client() operation_input_args = { - "HumanTaskUiName": human_task_ui_name, - "UiTemplate": ui_template, - "Tags": tags, + "HubName": self.hub_name, + "HubDescription": hub_description, + "HubDisplayName": hub_display_name, + "HubSearchKeywords": hub_search_keywords, } - - operation_input_args = Base.populate_chained_attributes( - resource_name="HumanTaskUi", operation_input_args=operation_input_args - ) - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_human_task_ui(**operation_input_args) + response = client.update_hub(**operation_input_args) logger.debug(f"Response: {response}") + self.refresh() - return cls.get(human_task_ui_name=human_task_ui_name, session=session, region=region) - - @classmethod - @Base.add_validate_call - def get( - cls, - human_task_ui_name: str, - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional["HumanTaskUi"]: - """ - Get a HumanTaskUi resource - - 
Parameters: - human_task_ui_name: The name of the human task user interface (worker task template) you want information about. - session: Boto3 session. - region: Region name. - - Returns: - The HumanTaskUi resource. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceNotFound: Resource being access is not found. - """ - - operation_input_args = { - "HumanTaskUiName": human_task_ui_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - response = client.describe_human_task_ui(**operation_input_args) - - logger.debug(response) - - # deserialize the response - transformed_response = transform(response, "DescribeHumanTaskUiResponse") - human_task_ui = cls(**transformed_response) - return human_task_ui - - @Base.add_validate_call - def refresh( - self, - ) -> Optional["HumanTaskUi"]: - """ - Refresh a HumanTaskUi resource - - Returns: - The HumanTaskUi resource. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceNotFound: Resource being access is not found. 
- """ - - operation_input_args = { - "HumanTaskUiName": self.human_task_ui_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client = Base.get_sagemaker_client() - response = client.describe_human_task_ui(**operation_input_args) - - # deserialize response and update self - transform(response, "DescribeHumanTaskUiResponse", self) return self @Base.add_validate_call @@ -11893,7 +11965,7 @@ def delete( self, ) -> None: """ - Delete a HumanTaskUi resource + Delete a Hub resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -11905,31 +11977,40 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceInUse: Resource being accessed is in use. ResourceNotFound: Resource being access is not found. """ client = Base.get_sagemaker_client() operation_input_args = { - "HumanTaskUiName": self.human_task_ui_name, + "HubName": self.hub_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_human_task_ui(**operation_input_args) + client.delete_hub(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call def wait_for_status( self, - target_status: Literal["Active", "Deleting"], + target_status: Literal[ + "InService", + "Creating", + "Updating", + "Deleting", + "CreateFailed", + "UpdateFailed", + "DeleteFailed", + ], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a HumanTaskUi resource to reach certain status. + Wait for a Hub resource to reach certain status. Parameters: target_status: The status to wait for. 
@@ -11948,7 +12029,7 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task(f"Waiting for HumanTaskUi to reach [bold]{target_status} status...") + progress.add_task(f"Waiting for Hub to reach [bold]{target_status} status...") status = Status("Current status:") with Live( @@ -11961,15 +12042,20 @@ def wait_for_status( ): while True: self.refresh() - current_status = self.human_task_ui_status + current_status = self.hub_status status.update(f"Current status: [bold]{current_status}") if target_status == current_status: logger.info(f"Final Resource Status: [bold]{current_status}") return + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="Hub", status=current_status, reason=self.failure_reason + ) + if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="HumanTaskUi", status=current_status) + raise TimeoutExceededError(resouce_type="Hub", status=current_status) time.sleep(poll) @Base.add_validate_call @@ -11979,7 +12065,7 @@ def wait_for_delete( timeout: Optional[int] = None, ) -> None: """ - Wait for a HumanTaskUi resource to be deleted. + Wait for a Hub resource to be deleted. Parameters: poll: The number of seconds to wait between each poll. 
@@ -12006,7 +12092,7 @@ def wait_for_delete( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task("Waiting for HumanTaskUi to be deleted...") + progress.add_task("Waiting for Hub to be deleted...") status = Status("Current status:") with Live( @@ -12019,13 +12105,11 @@ def wait_for_delete( while True: try: self.refresh() - current_status = self.human_task_ui_status + current_status = self.hub_status status.update(f"Current status: [bold]{current_status}") if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="HumanTaskUi", status=current_status - ) + raise TimeoutExceededError(resouce_type="Hub", status=current_status) except botocore.exceptions.ClientError as e: error_code = e.response["Error"]["Code"] @@ -12039,26 +12123,34 @@ def wait_for_delete( @Base.add_validate_call def get_all( cls, - creation_time_after: Optional[datetime.datetime] = Unassigned(), + name_contains: Optional[str] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["HumanTaskUi"]: + ) -> ResourceIterator["Hub"]: """ - Get all HumanTaskUi resources + Get all Hub resources Parameters: - creation_time_after: A filter that returns only human task user interfaces with a creation time greater than or equal to the specified timestamp. - creation_time_before: A filter that returns only human task user interfaces that were created before the specified timestamp. - sort_order: An optional value that specifies whether you want the results sorted in Ascending or Descending order. 
- next_token: A token to resume pagination. - max_results: The total number of items to return. If the total number of available items is more than the value specified in MaxResults, then a NextToken will be provided in the output that you can use to resume pagination. + name_contains: Only list hubs with names that contain the specified string. + creation_time_before: Only list hubs that were created before the time specified. + creation_time_after: Only list hubs that were created after the time specified. + last_modified_time_before: Only list hubs that were last modified before the time specified. + last_modified_time_after: Only list hubs that were last modified after the time specified. + sort_by: Sort hubs by either name or creation time. + sort_order: Sort hubs by ascending or descending order. + max_results: The maximum number of hubs to list. + next_token: If the response to a previous ListHubs request was truncated, the response includes a NextToken. To retrieve the next set of hubs, use the token in the next request. session: Boto3 session. region: Region name. Returns: - Iterator for listed HumanTaskUi resources. + Iterator for listed Hub resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -12077,8 +12169,12 @@ def get_all( ) operation_input_args = { - "CreationTimeAfter": creation_time_after, + "NameContains": name_contains, "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + "LastModifiedTimeAfter": last_modified_time_after, + "SortBy": sort_by, "SortOrder": sort_order, } @@ -12088,62 +12184,64 @@ def get_all( return ResourceIterator( client=client, - list_method="list_human_task_uis", - summaries_key="HumanTaskUiSummaries", - summary_name="HumanTaskUiSummary", - resource_cls=HumanTaskUi, + list_method="list_hubs", + summaries_key="HubSummaries", + summary_name="HubInfo", + resource_cls=Hub, list_method_kwargs=operation_input_args, ) -class HyperParameterTuningJob(Base): +class HubContent(Base): """ - Class representing resource HyperParameterTuningJob + Class representing resource HubContent Attributes: - hyper_parameter_tuning_job_name: The name of the hyperparameter tuning job. - hyper_parameter_tuning_job_arn: The Amazon Resource Name (ARN) of the tuning job. - hyper_parameter_tuning_job_config: The HyperParameterTuningJobConfig object that specifies the configuration of the tuning job. - hyper_parameter_tuning_job_status: The status of the tuning job. - creation_time: The date and time that the tuning job started. - training_job_status_counters: The TrainingJobStatusCounters object that specifies the number of training jobs, categorized by status, that this tuning job launched. - objective_status_counters: The ObjectiveStatusCounters object that specifies the number of training jobs, categorized by the status of their final objective metric, that this tuning job launched. - training_job_definition: The HyperParameterTrainingJobDefinition object that specifies the definition of the training jobs that this tuning job launches. - training_job_definitions: A list of the HyperParameterTrainingJobDefinition objects launched for this tuning job. 
- hyper_parameter_tuning_end_time: The date and time that the tuning job ended. - last_modified_time: The date and time that the status of the tuning job was modified. - best_training_job: A TrainingJobSummary object that describes the training job that completed with the best current HyperParameterTuningJobObjective. - overall_best_training_job: If the hyperparameter tuning job is an warm start tuning job with a WarmStartType of IDENTICAL_DATA_AND_ALGORITHM, this is the TrainingJobSummary for the training job with the best objective metric value of all training jobs launched by this tuning job and all parent jobs specified for the warm start tuning job. - warm_start_config: The configuration for starting the hyperparameter parameter tuning job using one or more previous tuning jobs as a starting point. The results of previous tuning jobs are used to inform which combinations of hyperparameters to search over in the new tuning job. - autotune: A flag to indicate if autotune is enabled for the hyperparameter tuning job. - failure_reason: If the tuning job failed, the reason it failed. - tuning_job_completion_details: Tuning job completion information returned as the response from a hyperparameter tuning job. This information tells if your tuning job has or has not converged. It also includes the number of training jobs that have not improved model performance as evaluated against the objective function. - consumed_resources: + hub_content_name: The name of the hub content. + hub_content_arn: The Amazon Resource Name (ARN) of the hub content. + hub_content_version: The version of the hub content. + hub_content_type: The type of hub content. + document_schema_version: The document schema version for the hub content. + hub_name: The name of the hub that contains the content. + hub_arn: The Amazon Resource Name (ARN) of the hub that contains the content. 
+ hub_content_document: The hub content document that describes information about the hub content such as type, associated containers, scripts, and more. + hub_content_status: The status of the hub content. + creation_time: The date and time that hub content was created. + hub_content_display_name: The display name of the hub content. + hub_content_description: A description of the hub content. + hub_content_markdown: A string that provides a description of the hub content. This string can include links, tables, and standard markdown formating. + sage_maker_public_hub_content_arn: The ARN of the public hub content. + reference_min_version: The minimum version of the hub content. + support_status: The support status of the hub content. + hub_content_search_keywords: The searchable keywords for the hub content. + hub_content_dependencies: The location of any dependencies that the hub content has, such as scripts, model artifacts, datasets, or notebooks. + failure_reason: The failure reason if importing hub content failed. 
""" - hyper_parameter_tuning_job_name: str - hyper_parameter_tuning_job_arn: Optional[str] = Unassigned() - hyper_parameter_tuning_job_config: Optional[HyperParameterTuningJobConfig] = Unassigned() - training_job_definition: Optional[HyperParameterTrainingJobDefinition] = Unassigned() - training_job_definitions: Optional[List[HyperParameterTrainingJobDefinition]] = Unassigned() - hyper_parameter_tuning_job_status: Optional[str] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - hyper_parameter_tuning_end_time: Optional[datetime.datetime] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - training_job_status_counters: Optional[TrainingJobStatusCounters] = Unassigned() - objective_status_counters: Optional[ObjectiveStatusCounters] = Unassigned() - best_training_job: Optional[HyperParameterTrainingJobSummary] = Unassigned() - overall_best_training_job: Optional[HyperParameterTrainingJobSummary] = Unassigned() - warm_start_config: Optional[HyperParameterTuningJobWarmStartConfig] = Unassigned() - autotune: Optional[Autotune] = Unassigned() + hub_content_type: str + hub_content_name: str + hub_content_arn: Optional[str] = Unassigned() + hub_content_version: Optional[str] = Unassigned() + document_schema_version: Optional[str] = Unassigned() + hub_arn: Optional[str] = Unassigned() + hub_content_display_name: Optional[str] = Unassigned() + hub_content_description: Optional[str] = Unassigned() + hub_content_markdown: Optional[str] = Unassigned() + hub_content_document: Optional[str] = Unassigned() + sage_maker_public_hub_content_arn: Optional[str] = Unassigned() + reference_min_version: Optional[str] = Unassigned() + support_status: Optional[str] = Unassigned() + hub_content_search_keywords: Optional[List[str]] = Unassigned() + hub_content_dependencies: Optional[List[HubContentDependency]] = Unassigned() + hub_content_status: Optional[str] = Unassigned() failure_reason: Optional[str] = Unassigned() - 
tuning_job_completion_details: Optional[HyperParameterTuningJobCompletionDetails] = Unassigned() - consumed_resources: Optional[HyperParameterTuningJobConsumedResources] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + hub_name: Optional[str] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "hyper_parameter_tuning_job_name" + resource_name = "hub_content_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -12154,142 +12252,33 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object hyper_parameter_tuning_job") + logger.error("Name attribute not found for object hub_content") return None - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = { - "training_job_definition": { - "role_arn": {"type": "string"}, - "output_data_config": { - "s3_output_path": {"type": "string"}, - "kms_key_id": {"type": "string"}, - }, - "vpc_config": { - "security_group_ids": {"type": "array", "items": {"type": "string"}}, - "subnets": {"type": "array", "items": {"type": "string"}}, - }, - "resource_config": {"volume_kms_key_id": {"type": "string"}}, - "hyper_parameter_tuning_resource_config": { - "volume_kms_key_id": {"type": "string"} - }, - "checkpoint_config": {"s3_uri": {"type": "string"}}, - } - } - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "HyperParameterTuningJob", **kwargs - ), - ) - - return wrapper - - @classmethod - @populate_inputs_decorator - @Base.add_validate_call - def create( - cls, - hyper_parameter_tuning_job_name: str, - hyper_parameter_tuning_job_config: HyperParameterTuningJobConfig, - training_job_definition: Optional[HyperParameterTrainingJobDefinition] = 
Unassigned(), - training_job_definitions: Optional[ - List[HyperParameterTrainingJobDefinition] - ] = Unassigned(), - warm_start_config: Optional[HyperParameterTuningJobWarmStartConfig] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), - autotune: Optional[Autotune] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional["HyperParameterTuningJob"]: - """ - Create a HyperParameterTuningJob resource - - Parameters: - hyper_parameter_tuning_job_name: The name of the tuning job. This name is the prefix for the names of all training jobs that this tuning job launches. The name must be unique within the same Amazon Web Services account and Amazon Web Services Region. The name must have 1 to 32 characters. Valid characters are a-z, A-Z, 0-9, and : + = @ _ % - (hyphen). The name is not case sensitive. - hyper_parameter_tuning_job_config: The HyperParameterTuningJobConfig object that describes the tuning job, including the search strategy, the objective metric used to evaluate training jobs, ranges of parameters to search, and resource limits for the tuning job. For more information, see How Hyperparameter Tuning Works. - training_job_definition: The HyperParameterTrainingJobDefinition object that describes the training jobs that this tuning job launches, including static hyperparameters, input data configuration, output data configuration, resource configuration, and stopping condition. - training_job_definitions: A list of the HyperParameterTrainingJobDefinition objects launched for this tuning job. - warm_start_config: Specifies the configuration for starting the hyperparameter tuning job using one or more previous tuning jobs as a starting point. The results of previous tuning jobs are used to inform which combinations of hyperparameters to search over in the new tuning job. All training jobs launched by the new hyperparameter tuning job are evaluated by using the objective metric. 
If you specify IDENTICAL_DATA_AND_ALGORITHM as the WarmStartType value for the warm start configuration, the training job that performs the best in the new tuning job is compared to the best training jobs from the parent tuning jobs. From these, the training job that performs the best as measured by the objective metric is returned as the overall best training job. All training jobs launched by parent hyperparameter tuning jobs and the new hyperparameter tuning jobs count against the limit of training jobs for the tuning job. - tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. Tags that you specify for the tuning job are also added to all training jobs that the tuning job launches. - autotune: Configures SageMaker Automatic model tuning (AMT) to automatically find optimal parameters for the following fields: ParameterRanges: The names and ranges of parameters that a hyperparameter tuning job can optimize. ResourceLimits: The maximum resources that can be used for a training job. These resources include the maximum number of training jobs, the maximum runtime of a tuning job, and the maximum number of training jobs to run at the same time. TrainingJobEarlyStoppingType: A flag that specifies whether or not to use early stopping for training jobs launched by a hyperparameter tuning job. RetryStrategy: The number of times to retry a training job. Strategy: Specifies how hyperparameter tuning chooses the combinations of hyperparameter values to use for the training jobs that it launches. ConvergenceDetected: A flag to indicate that Automatic model tuning (AMT) has detected model convergence. - session: Boto3 session. - region: Region name. - - Returns: - The HyperParameterTuningJob resource. 
- - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceInUse: Resource being accessed is in use. - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 - """ - - logger.info("Creating hyper_parameter_tuning_job resource.") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - - operation_input_args = { - "HyperParameterTuningJobName": hyper_parameter_tuning_job_name, - "HyperParameterTuningJobConfig": hyper_parameter_tuning_job_config, - "TrainingJobDefinition": training_job_definition, - "TrainingJobDefinitions": training_job_definitions, - "WarmStartConfig": warm_start_config, - "Tags": tags, - "Autotune": autotune, - } - - operation_input_args = Base.populate_chained_attributes( - resource_name="HyperParameterTuningJob", operation_input_args=operation_input_args - ) - - logger.debug(f"Input request: {operation_input_args}") - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - # create the resource - response = client.create_hyper_parameter_tuning_job(**operation_input_args) - logger.debug(f"Response: {response}") - - return cls.get( - hyper_parameter_tuning_job_name=hyper_parameter_tuning_job_name, - session=session, - region=region, - ) - 
@classmethod @Base.add_validate_call def get( cls, - hyper_parameter_tuning_job_name: str, + hub_name: str, + hub_content_type: str, + hub_content_name: str, + hub_content_version: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["HyperParameterTuningJob"]: + ) -> Optional["HubContent"]: """ - Get a HyperParameterTuningJob resource + Get a HubContent resource Parameters: - hyper_parameter_tuning_job_name: The name of the tuning job. + hub_name: The name of the hub that contains the content to describe. + hub_content_type: The type of content in the hub. + hub_content_name: The name of the content to describe. + hub_content_version: The version of the content to describe. session: Boto3 session. region: Region name. Returns: - The HyperParameterTuningJob resource. + The HubContent resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -12305,7 +12294,10 @@ def get( """ operation_input_args = { - "HyperParameterTuningJobName": hyper_parameter_tuning_job_name, + "HubName": hub_name, + "HubContentType": hub_content_type, + "HubContentName": hub_content_name, + "HubContentVersion": hub_content_version, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -12314,24 +12306,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_hyper_parameter_tuning_job(**operation_input_args) + response = client.describe_hub_content(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeHyperParameterTuningJobResponse") - hyper_parameter_tuning_job = cls(**transformed_response) - return hyper_parameter_tuning_job + transformed_response = transform(response, "DescribeHubContentResponse") + hub_content = cls(**transformed_response) + return hub_content @Base.add_validate_call 
def refresh( self, - ) -> Optional["HyperParameterTuningJob"]: + ) -> Optional["HubContent"]: """ - Refresh a HyperParameterTuningJob resource + Refresh a HubContent resource Returns: - The HyperParameterTuningJob resource. + The HubContent resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -12347,17 +12339,20 @@ def refresh( """ operation_input_args = { - "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, + "HubName": self.hub_name, + "HubContentType": self.hub_content_type, + "HubContentName": self.hub_content_name, + "HubContentVersion": self.hub_content_version, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_hyper_parameter_tuning_job(**operation_input_args) + response = client.describe_hub_content(**operation_input_args) # deserialize response and update self - transform(response, "DescribeHyperParameterTuningJobResponse", self) + transform(response, "DescribeHubContentResponse", self) return self @Base.add_validate_call @@ -12365,7 +12360,7 @@ def delete( self, ) -> None: """ - Delete a HyperParameterTuningJob resource + Delete a HubContent resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -12377,62 +12372,38 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceInUse: Resource being accessed is in use. + ResourceNotFound: Resource being access is not found. 
""" client = Base.get_sagemaker_client() operation_input_args = { - "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, + "HubName": self.hub_name, + "HubContentType": self.hub_content_type, + "HubContentName": self.hub_content_name, + "HubContentVersion": self.hub_content_version, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_hyper_parameter_tuning_job(**operation_input_args) + client.delete_hub_content(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call - def stop(self) -> None: + def wait_for_status( + self, + target_status: Literal["Supported", "Deprecated"], + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: """ - Stop a HyperParameterTuningJob resource - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceNotFound: Resource being access is not found. - """ - - client = SageMakerClient().client - - operation_input_args = { - "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client.stop_hyper_parameter_tuning_job(**operation_input_args) - - logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") - - @Base.add_validate_call - def wait( - self, - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a HyperParameterTuningJob resource. + Wait for a HubContent resource to reach certain status. 
Parameters: + target_status: The status to wait for. poll: The number of seconds to wait between each poll. timeout: The maximum number of seconds to wait before timing out. @@ -12440,9 +12411,7 @@ def wait( TimeoutExceededError: If the resource does not reach a terminal state before the timeout. FailedStatusError: If the resource reaches a failed state. WaiterError: Raised when an error occurs while waiting. - """ - terminal_states = ["Completed", "Failed", "Stopped", "DeleteFailed"] start_time = time.time() progress = Progress( @@ -12450,7 +12419,7 @@ def wait( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task("Waiting for HyperParameterTuningJob...") + progress.add_task(f"Waiting for HubContent to reach [bold]{target_status} status...") status = Status("Current status:") with Live( @@ -12463,124 +12432,55 @@ def wait( ): while True: self.refresh() - current_status = self.hyper_parameter_tuning_job_status + current_status = self.support_status status.update(f"Current status: [bold]{current_status}") - if current_status in terminal_states: + if target_status == current_status: logger.info(f"Final Resource Status: [bold]{current_status}") - - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="HyperParameterTuningJob", - status=current_status, - reason=self.failure_reason, - ) - return if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="HyperParameterTuningJob", status=current_status - ) - time.sleep(poll) - - @Base.add_validate_call - def wait_for_delete( - self, - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a HyperParameterTuningJob resource to be deleted. - - Parameters: - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
- The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - DeleteFailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. - """ - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task("Waiting for HyperParameterTuningJob to be deleted...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ) - ): - while True: - try: - self.refresh() - current_status = self.hyper_parameter_tuning_job_status - status.update(f"Current status: [bold]{current_status}") - - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="HyperParameterTuningJob", status=current_status - ) - except botocore.exceptions.ClientError as e: - error_code = e.response["Error"]["Code"] - - if "ResourceNotFound" in error_code or "ValidationException" in error_code: - logger.info("Resource was not found. 
It may have been deleted.") - return - raise e + raise TimeoutExceededError(resouce_type="HubContent", status=current_status) time.sleep(poll) @classmethod @Base.add_validate_call - def get_all( + def load( cls, - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - last_modified_time_before: Optional[datetime.datetime] = Unassigned(), - status_equals: Optional[str] = Unassigned(), + hub_content_name: str, + hub_content_type: str, + document_schema_version: str, + hub_name: str, + hub_content_document: str, + hub_content_version: Optional[str] = Unassigned(), + hub_content_display_name: Optional[str] = Unassigned(), + hub_content_description: Optional[str] = Unassigned(), + hub_content_markdown: Optional[str] = Unassigned(), + hub_content_search_keywords: Optional[List[str]] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["HyperParameterTuningJob"]: + ) -> Optional["HubContent"]: """ - Get all HyperParameterTuningJob resources + Import a HubContent resource Parameters: - next_token: If the result of the previous ListHyperParameterTuningJobs request was truncated, the response includes a NextToken. To retrieve the next set of tuning jobs, use the token in the next request. - max_results: The maximum number of tuning jobs to return. The default value is 10. - sort_by: The field to sort results by. The default is Name. - sort_order: The sort order for results. The default is Ascending. - name_contains: A string in the tuning job name. This filter returns only tuning jobs whose name contains the specified string. 
- creation_time_after: A filter that returns only tuning jobs that were created after the specified time. - creation_time_before: A filter that returns only tuning jobs that were created before the specified time. - last_modified_time_after: A filter that returns only tuning jobs that were modified after the specified time. - last_modified_time_before: A filter that returns only tuning jobs that were modified before the specified time. - status_equals: A filter that returns only tuning jobs with the specified status. + hub_content_name: The name of the hub content to import. + hub_content_type: The type of hub content to import. + document_schema_version: The version of the hub content schema to import. + hub_name: The name of the hub to import content into. + hub_content_document: The hub content document that describes information about the hub content such as type, associated containers, scripts, and more. + hub_content_version: The version of the hub content to import. + hub_content_display_name: The display name of the hub content to import. + hub_content_description: A description of the hub content to import. + hub_content_markdown: A string that provides a description of the hub content. This string can include links, tables, and standard markdown formating. + hub_content_search_keywords: The searchable keywords of the hub content. + tags: Any tags associated with the hub content. session: Boto3 session. region: Region name. Returns: - Iterator for listed HyperParameterTuningJob resources. + The HubContent resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -12592,59 +12492,76 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
+ ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client( + logger.info(f"Importing hub_content resource.") + client = SageMakerClient( session=session, region_name=region, service_name="sagemaker" - ) + ).client operation_input_args = { - "SortBy": sort_by, - "SortOrder": sort_order, - "NameContains": name_contains, - "CreationTimeAfter": creation_time_after, - "CreationTimeBefore": creation_time_before, - "LastModifiedTimeAfter": last_modified_time_after, - "LastModifiedTimeBefore": last_modified_time_before, - "StatusEquals": status_equals, + "HubContentName": hub_content_name, + "HubContentVersion": hub_content_version, + "HubContentType": hub_content_type, + "DocumentSchemaVersion": document_schema_version, + "HubName": hub_name, + "HubContentDisplayName": hub_content_display_name, + "HubContentDescription": hub_content_description, + "HubContentMarkdown": hub_content_markdown, + "HubContentDocument": hub_content_document, + "HubContentSearchKeywords": hub_content_search_keywords, + "Tags": tags, } + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - return ResourceIterator( - client=client, - list_method="list_hyper_parameter_tuning_jobs", - summaries_key="HyperParameterTuningJobSummaries", - summary_name="HyperParameterTuningJobSummary", - resource_cls=HyperParameterTuningJob, - list_method_kwargs=operation_input_args, + # import the resource + response = client.import_hub_content(**operation_input_args) + logger.debug(f"Response: {response}") + + return cls.get( + hub_name=hub_name, + hub_content_type=hub_content_type, + hub_content_name=hub_content_name, + session=session, + region=region, ) @Base.add_validate_call - def get_all_training_jobs( + def get_all_versions( self, - status_equals: Optional[str] = Unassigned(), + min_version: Optional[str] = Unassigned(), 
+ max_schema_version: Optional[str] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator[HyperParameterTrainingJobSummary]: + ) -> ResourceIterator["HubContent"]: """ - Gets a list of TrainingJobSummary objects that describe the training jobs that a hyperparameter tuning job launched. + List hub content versions. Parameters: - next_token: If the result of the previous ListTrainingJobsForHyperParameterTuningJob request was truncated, the response includes a NextToken. To retrieve the next set of training jobs, use the token in the next request. - max_results: The maximum number of training jobs to return. The default value is 10. - status_equals: A filter that returns only training jobs with the specified status. - sort_by: The field to sort results by. The default is Name. If the value of this field is FinalObjectiveMetricValue, any training jobs that did not return an objective metric are not listed. - sort_order: The sort order for results. The default is Ascending. + min_version: The lower bound of the hub content versions to list. + max_schema_version: The upper bound of the hub content schema version. + creation_time_before: Only list hub content versions that were created before the time specified. + creation_time_after: Only list hub content versions that were created after the time specified. + sort_by: Sort hub content versions by either name or creation time. + sort_order: Sort hub content versions by ascending or descending order. + max_results: The maximum number of hub content versions to list. + next_token: If the response to a previous ListHubContentVersions request was truncated, the response includes a NextToken. 
To retrieve the next set of hub content versions, use the token in the next request. session: Boto3 session. region: Region name. Returns: - Iterator for listed HyperParameterTrainingJobSummary. + Iterator for listed HubContent. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -12660,8 +12577,13 @@ def get_all_training_jobs( """ operation_input_args = { - "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, - "StatusEquals": status_equals, + "HubName": self.hub_name, + "HubContentType": self.hub_content_type, + "HubContentName": self.hub_content_name, + "MinVersion": min_version, + "MaxSchemaVersion": max_schema_version, + "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, "SortBy": sort_by, "SortOrder": sort_order, } @@ -12675,44 +12597,40 @@ def get_all_training_jobs( return ResourceIterator( client=client, - list_method="list_training_jobs_for_hyper_parameter_tuning_job", - summaries_key="TrainingJobSummaries", - summary_name="HyperParameterTrainingJobSummary", - resource_cls=HyperParameterTrainingJobSummary, + list_method="list_hub_content_versions", + summaries_key="HubContentSummaries", + summary_name="HubContentInfo", + resource_cls=HubContent, list_method_kwargs=operation_input_args, ) -class Image(Base): +class HubContentReference(Base): """ - Class representing resource Image + Class representing resource HubContentReference Attributes: - creation_time: When the image was created. - description: The description of the image. - display_name: The name of the image as displayed. - failure_reason: When a create, update, or delete operation fails, the reason for the failure. - image_arn: The ARN of the image. - image_name: The name of the image. - image_status: The status of the image. - last_modified_time: When the image was last modified. - role_arn: The ARN of the IAM role that enables Amazon SageMaker to perform tasks on your behalf. 
+ hub_name: The name of the hub to add the hub content reference to. + sage_maker_public_hub_content_arn: The ARN of the public hub content to reference. + hub_arn: The ARN of the hub that the hub content reference was added to. + hub_content_arn: The ARN of the hub content. + hub_content_name: The name of the hub content to reference. + min_version: The minimum version of the hub content to reference. + tags: Any tags associated with the hub content to reference. """ - image_name: str - creation_time: Optional[datetime.datetime] = Unassigned() - description: Optional[str] = Unassigned() - display_name: Optional[str] = Unassigned() - failure_reason: Optional[str] = Unassigned() - image_arn: Optional[str] = Unassigned() - image_status: Optional[str] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - role_arn: Optional[str] = Unassigned() + hub_name: Union[str, object] + sage_maker_public_hub_content_arn: str + hub_arn: str + hub_content_arn: str + hub_content_name: Optional[Union[str, object]] = Unassigned() + min_version: Optional[str] = Unassigned() + tags: Optional[List[Tag]] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "image_name" + resource_name = "hub_content_reference_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -12723,49 +12641,168 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object image") + logger.error("Name attribute not found for object hub_content_reference") return None - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = {"role_arn": {"type": "string"}} - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "Image", **kwargs - ), - ) + 
@classmethod + @Base.add_validate_call + def create( + cls, + hub_name: Union[str, object], + sage_maker_public_hub_content_arn: str, + hub_content_name: Optional[Union[str, object]] = Unassigned(), + min_version: Optional[str] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), + ) -> Optional["HubContentReference"]: + """ + Create a HubContentReference resource - return wrapper + Parameters: + hub_name: The name of the hub to add the hub content reference to. + sage_maker_public_hub_content_arn: The ARN of the public hub content to reference. + hub_content_name: The name of the hub content to reference. + min_version: The minimum version of the hub content to reference. + tags: Any tags associated with the hub content to reference. + session: Boto3 session. + region: Region name. + + Returns: + The HubContentReference resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. 
+ ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + """ + + operation_input_args = { + "HubName": hub_name, + "SageMakerPublicHubContentArn": sage_maker_public_hub_content_arn, + "HubContentName": hub_content_name, + "MinVersion": min_version, + "Tags": tags, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + logger.debug(f"Calling create_hub_content_reference API") + response = client.create_hub_content_reference(**operation_input_args) + logger.debug(f"Response: {response}") + + transformed_response = transform(response, "CreateHubContentReferenceResponse") + return cls(**operation_input_args, **transformed_response) + + @Base.add_validate_call + def delete( + self, + hub_content_type: str, + ) -> None: + """ + Delete a HubContentReference resource + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. 
+ """ + + client = Base.get_sagemaker_client() + + operation_input_args = { + "HubName": self.hub_name, + "HubContentType": hub_content_type, + "HubContentName": self.hub_content_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_hub_content_reference(**operation_input_args) + + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + + +class HumanTaskUi(Base): + """ + Class representing resource HumanTaskUi + + Attributes: + human_task_ui_arn: The Amazon Resource Name (ARN) of the human task user interface (worker task template). + human_task_ui_name: The name of the human task user interface (worker task template). + creation_time: The timestamp when the human task user interface was created. + ui_template: + human_task_ui_status: The status of the human task user interface (worker task template). Valid values are listed below. + + """ + + human_task_ui_name: str + human_task_ui_arn: Optional[str] = Unassigned() + human_task_ui_status: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + ui_template: Optional[UiTemplateInfo] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "human_task_ui_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object human_task_ui") + return None @classmethod - @populate_inputs_decorator @Base.add_validate_call def create( cls, - image_name: str, - role_arn: str, - description: Optional[str] = Unassigned(), - display_name: Optional[str] = Unassigned(), + human_task_ui_name: 
str, + ui_template: UiTemplate, tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Image"]: + ) -> Optional["HumanTaskUi"]: """ - Create a Image resource + Create a HumanTaskUi resource Parameters: - image_name: The name of the image. Must be unique to your account. - role_arn: The ARN of an IAM role that enables Amazon SageMaker to perform tasks on your behalf. - description: The description of the image. - display_name: The display name of the image. If not provided, ImageName is displayed. - tags: A list of tags to apply to the image. + human_task_ui_name: The name of the user interface you are creating. + ui_template: + tags: An array of key-value pairs that contain metadata to help you categorize and organize a human review workflow user interface. Each tag consists of a key and a value, both of which you define. session: Boto3 session. region: Region name. Returns: - The Image resource. + The HumanTaskUi resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -12784,21 +12821,19 @@ def create( S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating image resource.") + logger.info("Creating human_task_ui resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "Description": description, - "DisplayName": display_name, - "ImageName": image_name, - "RoleArn": role_arn, + "HumanTaskUiName": human_task_ui_name, + "UiTemplate": ui_template, "Tags": tags, } operation_input_args = Base.populate_chained_attributes( - resource_name="Image", operation_input_args=operation_input_args + resource_name="HumanTaskUi", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -12807,29 +12842,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_image(**operation_input_args) + response = client.create_human_task_ui(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(image_name=image_name, session=session, region=region) + return cls.get(human_task_ui_name=human_task_ui_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - image_name: str, + human_task_ui_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Image"]: + ) -> Optional["HumanTaskUi"]: """ - Get a Image resource + Get a HumanTaskUi resource Parameters: - image_name: The name of the image to describe. + human_task_ui_name: The name of the human task user interface (worker task template) you want information about. session: Boto3 session. region: Region name. Returns: - The Image resource. + The HumanTaskUi resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -12845,7 +12880,7 @@ def get( """ operation_input_args = { - "ImageName": image_name, + "HumanTaskUiName": human_task_ui_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -12854,24 +12889,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_image(**operation_input_args) + response = client.describe_human_task_ui(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeImageResponse") - image = cls(**transformed_response) - return image + transformed_response = transform(response, "DescribeHumanTaskUiResponse") + human_task_ui = cls(**transformed_response) + return human_task_ui @Base.add_validate_call def refresh( self, - ) -> Optional["Image"]: + ) -> Optional["HumanTaskUi"]: """ - Refresh a Image resource + Refresh a HumanTaskUi resource Returns: - The Image resource. + The HumanTaskUi resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -12887,79 +12922,25 @@ def refresh( """ operation_input_args = { - "ImageName": self.image_name, + "HumanTaskUiName": self.human_task_ui_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_image(**operation_input_args) + response = client.describe_human_task_ui(**operation_input_args) # deserialize response and update self - transform(response, "DescribeImageResponse", self) + transform(response, "DescribeHumanTaskUiResponse", self) return self - @populate_inputs_decorator @Base.add_validate_call - def update( - self, - delete_properties: Optional[List[str]] = Unassigned(), - description: Optional[str] = Unassigned(), - display_name: Optional[str] = Unassigned(), - role_arn: Optional[str] = Unassigned(), - ) -> Optional["Image"]: - """ - Update a Image resource - - Parameters: - delete_properties: A list of properties to delete. Only the Description and DisplayName properties can be deleted. - - Returns: - The Image resource. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceInUse: Resource being accessed is in use. - ResourceNotFound: Resource being access is not found. 
- """ - - logger.info("Updating image resource.") - client = Base.get_sagemaker_client() - - operation_input_args = { - "DeleteProperties": delete_properties, - "Description": description, - "DisplayName": display_name, - "ImageName": self.image_name, - "RoleArn": role_arn, - } - logger.debug(f"Input request: {operation_input_args}") - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - # create the resource - response = client.update_image(**operation_input_args) - logger.debug(f"Response: {response}") - self.refresh() - - return self - - @Base.add_validate_call - def delete( + def delete( self, ) -> None: """ - Delete a Image resource + Delete a HumanTaskUi resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -12971,40 +12952,31 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. ResourceNotFound: Resource being access is not found. """ client = Base.get_sagemaker_client() operation_input_args = { - "ImageName": self.image_name, + "HumanTaskUiName": self.human_task_ui_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_image(**operation_input_args) + client.delete_human_task_ui(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call def wait_for_status( self, - target_status: Literal[ - "CREATING", - "CREATED", - "CREATE_FAILED", - "UPDATING", - "UPDATE_FAILED", - "DELETING", - "DELETE_FAILED", - ], + target_status: Literal["Active", "Deleting"], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a Image resource to reach certain status. 
+ Wait for a HumanTaskUi resource to reach certain status. Parameters: target_status: The status to wait for. @@ -13023,7 +12995,7 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task(f"Waiting for Image to reach [bold]{target_status} status...") + progress.add_task(f"Waiting for HumanTaskUi to reach [bold]{target_status} status...") status = Status("Current status:") with Live( @@ -13036,20 +13008,15 @@ def wait_for_status( ): while True: self.refresh() - current_status = self.image_status + current_status = self.human_task_ui_status status.update(f"Current status: [bold]{current_status}") if target_status == current_status: logger.info(f"Final Resource Status: [bold]{current_status}") return - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="Image", status=current_status, reason=self.failure_reason - ) - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="Image", status=current_status) + raise TimeoutExceededError(resouce_type="HumanTaskUi", status=current_status) time.sleep(poll) @Base.add_validate_call @@ -13059,7 +13026,7 @@ def wait_for_delete( timeout: Optional[int] = None, ) -> None: """ - Wait for a Image resource to be deleted. + Wait for a HumanTaskUi resource to be deleted. Parameters: poll: The number of seconds to wait between each poll. 
@@ -13086,7 +13053,7 @@ def wait_for_delete( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task("Waiting for Image to be deleted...") + progress.add_task("Waiting for HumanTaskUi to be deleted...") status = Status("Current status:") with Live( @@ -13099,19 +13066,13 @@ def wait_for_delete( while True: try: self.refresh() - current_status = self.image_status + current_status = self.human_task_ui_status status.update(f"Current status: [bold]{current_status}") - if ( - "delete_failed" in current_status.lower() - or "deletefailed" in current_status.lower() - ): - raise DeleteFailedStatusError( - resource_type="Image", reason=self.failure_reason - ) - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="Image", status=current_status) + raise TimeoutExceededError( + resouce_type="HumanTaskUi", status=current_status + ) except botocore.exceptions.ClientError as e: error_code = e.response["Error"]["Code"] @@ -13127,32 +13088,24 @@ def get_all( cls, creation_time_after: Optional[datetime.datetime] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), - last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - last_modified_time_before: Optional[datetime.datetime] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["Image"]: + ) -> ResourceIterator["HumanTaskUi"]: """ - Get all Image resources + Get all HumanTaskUi resources Parameters: - creation_time_after: A filter that returns only images created on or after the specified time. - creation_time_before: A filter that returns only images created on or before the specified time. - last_modified_time_after: A filter that returns only images modified on or after the specified time. 
- last_modified_time_before: A filter that returns only images modified on or before the specified time. - max_results: The maximum number of images to return in the response. The default value is 10. - name_contains: A filter that returns only images whose name contains the specified string. - next_token: If the previous call to ListImages didn't return the full set of images, the call returns a token for getting the next set of images. - sort_by: The property used to sort results. The default value is CREATION_TIME. - sort_order: The sort order. The default value is DESCENDING. + creation_time_after: A filter that returns only human task user interfaces with a creation time greater than or equal to the specified timestamp. + creation_time_before: A filter that returns only human task user interfaces that were created before the specified timestamp. + sort_order: An optional value that specifies whether you want the results sorted in Ascending or Descending order. + next_token: A token to resume pagination. + max_results: The total number of items to return. If the total number of available items is more than the value specified in MaxResults, then a NextToken will be provided in the output that you can use to resume pagination. session: Boto3 session. region: Region name. Returns: - Iterator for listed Image resources. + Iterator for listed HumanTaskUi resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -13173,10 +13126,6 @@ def get_all( operation_input_args = { "CreationTimeAfter": creation_time_after, "CreationTimeBefore": creation_time_before, - "LastModifiedTimeAfter": last_modified_time_after, - "LastModifiedTimeBefore": last_modified_time_before, - "NameContains": name_contains, - "SortBy": sort_by, "SortOrder": sort_order, } @@ -13186,116 +13135,62 @@ def get_all( return ResourceIterator( client=client, - list_method="list_images", - summaries_key="Images", - summary_name="Image", - resource_cls=Image, - list_method_kwargs=operation_input_args, - ) - - @Base.add_validate_call - def get_all_aliases( - self, - alias: Optional[str] = Unassigned(), - version: Optional[int] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> ResourceIterator[str]: - """ - Lists the aliases of a specified image or image version. - - Parameters: - alias: The alias of the image version. - version: The version of the image. If image version is not specified, the aliases of all versions of the image are listed. - max_results: The maximum number of aliases to return. - next_token: If the previous call to ListAliases didn't return the full set of aliases, the call returns a token for retrieving the next set of aliases. - session: Boto3 session. - region: Region name. - - Returns: - Iterator for listed str. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceNotFound: Resource being access is not found. 
- """ - - operation_input_args = { - "ImageName": self.image_name, - "Alias": alias, - "Version": version, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - - return ResourceIterator( - client=client, - list_method="list_aliases", - summaries_key="SageMakerImageVersionAliases", - summary_name="SageMakerImageVersionAlias", - resource_cls=str, + list_method="list_human_task_uis", + summaries_key="HumanTaskUiSummaries", + summary_name="HumanTaskUiSummary", + resource_cls=HumanTaskUi, list_method_kwargs=operation_input_args, ) -class ImageVersion(Base): +class HyperParameterTuningJob(Base): """ - Class representing resource ImageVersion + Class representing resource HyperParameterTuningJob Attributes: - base_image: The registry path of the container image on which this image version is based. - container_image: The registry path of the container image that contains this image version. - creation_time: When the version was created. - failure_reason: When a create or delete operation fails, the reason for the failure. - image_arn: The ARN of the image the version is based on. - image_version_arn: The ARN of the version. - image_version_status: The status of the version. - last_modified_time: When the version was last modified. - version: The version number. - vendor_guidance: The stability of the image version specified by the maintainer. NOT_PROVIDED: The maintainers did not provide a status for image version stability. STABLE: The image version is stable. TO_BE_ARCHIVED: The image version is set to be archived. Custom image versions that are set to be archived are automatically archived after three months. ARCHIVED: The image version is archived. Archived image versions are not searchable and are no longer actively supported. 
- job_type: Indicates SageMaker job type compatibility. TRAINING: The image version is compatible with SageMaker training jobs. INFERENCE: The image version is compatible with SageMaker inference jobs. NOTEBOOK_KERNEL: The image version is compatible with SageMaker notebook kernels. - ml_framework: The machine learning framework vended in the image version. - programming_lang: The supported programming language and its version. - processor: Indicates CPU or GPU compatibility. CPU: The image version is compatible with CPU. GPU: The image version is compatible with GPU. - horovod: Indicates Horovod compatibility. - release_notes: The maintainer description of the image version. + hyper_parameter_tuning_job_name: The name of the hyperparameter tuning job. + hyper_parameter_tuning_job_arn: The Amazon Resource Name (ARN) of the tuning job. + hyper_parameter_tuning_job_config: The HyperParameterTuningJobConfig object that specifies the configuration of the tuning job. + hyper_parameter_tuning_job_status: The status of the tuning job. + creation_time: The date and time that the tuning job started. + training_job_status_counters: The TrainingJobStatusCounters object that specifies the number of training jobs, categorized by status, that this tuning job launched. + objective_status_counters: The ObjectiveStatusCounters object that specifies the number of training jobs, categorized by the status of their final objective metric, that this tuning job launched. + training_job_definition: The HyperParameterTrainingJobDefinition object that specifies the definition of the training jobs that this tuning job launches. + training_job_definitions: A list of the HyperParameterTrainingJobDefinition objects launched for this tuning job. + hyper_parameter_tuning_end_time: The date and time that the tuning job ended. + last_modified_time: The date and time that the status of the tuning job was modified. 
+ best_training_job: A TrainingJobSummary object that describes the training job that completed with the best current HyperParameterTuningJobObjective. + overall_best_training_job: If the hyperparameter tuning job is an warm start tuning job with a WarmStartType of IDENTICAL_DATA_AND_ALGORITHM, this is the TrainingJobSummary for the training job with the best objective metric value of all training jobs launched by this tuning job and all parent jobs specified for the warm start tuning job. + warm_start_config: The configuration for starting the hyperparameter parameter tuning job using one or more previous tuning jobs as a starting point. The results of previous tuning jobs are used to inform which combinations of hyperparameters to search over in the new tuning job. + autotune: A flag to indicate if autotune is enabled for the hyperparameter tuning job. + failure_reason: If the tuning job failed, the reason it failed. + tuning_job_completion_details: Tuning job completion information returned as the response from a hyperparameter tuning job. This information tells if your tuning job has or has not converged. It also includes the number of training jobs that have not improved model performance as evaluated against the objective function. 
+ consumed_resources: """ - image_name: str - base_image: Optional[str] = Unassigned() - container_image: Optional[str] = Unassigned() + hyper_parameter_tuning_job_name: str + hyper_parameter_tuning_job_arn: Optional[str] = Unassigned() + hyper_parameter_tuning_job_config: Optional[HyperParameterTuningJobConfig] = Unassigned() + training_job_definition: Optional[HyperParameterTrainingJobDefinition] = Unassigned() + training_job_definitions: Optional[List[HyperParameterTrainingJobDefinition]] = Unassigned() + hyper_parameter_tuning_job_status: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - failure_reason: Optional[str] = Unassigned() - image_arn: Optional[str] = Unassigned() - image_version_arn: Optional[str] = Unassigned() - image_version_status: Optional[str] = Unassigned() + hyper_parameter_tuning_end_time: Optional[datetime.datetime] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() - version: Optional[int] = Unassigned() - vendor_guidance: Optional[str] = Unassigned() - job_type: Optional[str] = Unassigned() - ml_framework: Optional[str] = Unassigned() - programming_lang: Optional[str] = Unassigned() - processor: Optional[str] = Unassigned() - horovod: Optional[bool] = Unassigned() - release_notes: Optional[str] = Unassigned() - - def get_name(self) -> str: - attributes = vars(self) - resource_name = "image_version_name" + training_job_status_counters: Optional[TrainingJobStatusCounters] = Unassigned() + objective_status_counters: Optional[ObjectiveStatusCounters] = Unassigned() + best_training_job: Optional[HyperParameterTrainingJobSummary] = Unassigned() + overall_best_training_job: Optional[HyperParameterTrainingJobSummary] = Unassigned() + warm_start_config: Optional[HyperParameterTuningJobWarmStartConfig] = Unassigned() + autotune: Optional[Autotune] = Unassigned() + failure_reason: Optional[str] = Unassigned() + tuning_job_completion_details: 
Optional[HyperParameterTuningJobCompletionDetails] = Unassigned() + consumed_resources: Optional[HyperParameterTuningJobConsumedResources] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "hyper_parameter_tuning_job_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -13306,47 +13201,72 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object image_version") + logger.error("Name attribute not found for object hyper_parameter_tuning_job") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "training_job_definition": { + "role_arn": {"type": "string"}, + "output_data_config": { + "s3_output_path": {"type": "string"}, + "kms_key_id": {"type": "string"}, + }, + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + }, + "resource_config": {"volume_kms_key_id": {"type": "string"}}, + "hyper_parameter_tuning_resource_config": { + "volume_kms_key_id": {"type": "string"} + }, + "checkpoint_config": {"s3_uri": {"type": "string"}}, + } + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "HyperParameterTuningJob", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call def create( cls, - base_image: str, - client_token: str, - image_name: Union[str, object], - aliases: Optional[List[str]] = Unassigned(), - vendor_guidance: Optional[str] = Unassigned(), - job_type: Optional[str] = Unassigned(), - ml_framework: Optional[str] = Unassigned(), - programming_lang: Optional[str] = Unassigned(), - processor: Optional[str] = Unassigned(), - horovod: 
Optional[bool] = Unassigned(), - release_notes: Optional[str] = Unassigned(), + hyper_parameter_tuning_job_name: str, + hyper_parameter_tuning_job_config: HyperParameterTuningJobConfig, + training_job_definition: Optional[HyperParameterTrainingJobDefinition] = Unassigned(), + training_job_definitions: Optional[ + List[HyperParameterTrainingJobDefinition] + ] = Unassigned(), + warm_start_config: Optional[HyperParameterTuningJobWarmStartConfig] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), + autotune: Optional[Autotune] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ImageVersion"]: + ) -> Optional["HyperParameterTuningJob"]: """ - Create a ImageVersion resource + Create a HyperParameterTuningJob resource Parameters: - base_image: The registry path of the container image to use as the starting point for this version. The path is an Amazon ECR URI in the following format: <acct-id>.dkr.ecr.<region>.amazonaws.com/<repo-name[:tag] or [@digest]> - client_token: A unique ID. If not specified, the Amazon Web Services CLI and Amazon Web Services SDKs, such as the SDK for Python (Boto3), add a unique value to the call. - image_name: The ImageName of the Image to create a version of. - aliases: A list of aliases created with the image version. - vendor_guidance: The stability of the image version, specified by the maintainer. NOT_PROVIDED: The maintainers did not provide a status for image version stability. STABLE: The image version is stable. TO_BE_ARCHIVED: The image version is set to be archived. Custom image versions that are set to be archived are automatically archived after three months. ARCHIVED: The image version is archived. Archived image versions are not searchable and are no longer actively supported. - job_type: Indicates SageMaker job type compatibility. TRAINING: The image version is compatible with SageMaker training jobs. INFERENCE: The image version is compatible with SageMaker inference jobs. 
NOTEBOOK_KERNEL: The image version is compatible with SageMaker notebook kernels. - ml_framework: The machine learning framework vended in the image version. - programming_lang: The supported programming language and its version. - processor: Indicates CPU or GPU compatibility. CPU: The image version is compatible with CPU. GPU: The image version is compatible with GPU. - horovod: Indicates Horovod compatibility. - release_notes: The maintainer description of the image version. + hyper_parameter_tuning_job_name: The name of the tuning job. This name is the prefix for the names of all training jobs that this tuning job launches. The name must be unique within the same Amazon Web Services account and Amazon Web Services Region. The name must have 1 to 32 characters. Valid characters are a-z, A-Z, 0-9, and : + = @ _ % - (hyphen). The name is not case sensitive. + hyper_parameter_tuning_job_config: The HyperParameterTuningJobConfig object that describes the tuning job, including the search strategy, the objective metric used to evaluate training jobs, ranges of parameters to search, and resource limits for the tuning job. For more information, see How Hyperparameter Tuning Works. + training_job_definition: The HyperParameterTrainingJobDefinition object that describes the training jobs that this tuning job launches, including static hyperparameters, input data configuration, output data configuration, resource configuration, and stopping condition. + training_job_definitions: A list of the HyperParameterTrainingJobDefinition objects launched for this tuning job. + warm_start_config: Specifies the configuration for starting the hyperparameter tuning job using one or more previous tuning jobs as a starting point. The results of previous tuning jobs are used to inform which combinations of hyperparameters to search over in the new tuning job. All training jobs launched by the new hyperparameter tuning job are evaluated by using the objective metric. 
If you specify IDENTICAL_DATA_AND_ALGORITHM as the WarmStartType value for the warm start configuration, the training job that performs the best in the new tuning job is compared to the best training jobs from the parent tuning jobs. From these, the training job that performs the best as measured by the objective metric is returned as the overall best training job. All training jobs launched by parent hyperparameter tuning jobs and the new hyperparameter tuning jobs count against the limit of training jobs for the tuning job. + tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. Tags that you specify for the tuning job are also added to all training jobs that the tuning job launches. + autotune: Configures SageMaker Automatic model tuning (AMT) to automatically find optimal parameters for the following fields: ParameterRanges: The names and ranges of parameters that a hyperparameter tuning job can optimize. ResourceLimits: The maximum resources that can be used for a training job. These resources include the maximum number of training jobs, the maximum runtime of a tuning job, and the maximum number of training jobs to run at the same time. TrainingJobEarlyStoppingType: A flag that specifies whether or not to use early stopping for training jobs launched by a hyperparameter tuning job. RetryStrategy: The number of times to retry a training job. Strategy: Specifies how hyperparameter tuning chooses the combinations of hyperparameter values to use for the training jobs that it launches. ConvergenceDetected: A flag to indicate that Automatic model tuning (AMT) has detected model convergence. session: Boto3 session. region: Region name. Returns: - The ImageVersion resource. + The HyperParameterTuningJob resource. 
Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -13360,33 +13280,28 @@ def create( ``` ResourceInUse: Resource being accessed is in use. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ResourceNotFound: Resource being access is not found. ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating image_version resource.") + logger.info("Creating hyper_parameter_tuning_job resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "BaseImage": base_image, - "ClientToken": client_token, - "ImageName": image_name, - "Aliases": aliases, - "VendorGuidance": vendor_guidance, - "JobType": job_type, - "MLFramework": ml_framework, - "ProgrammingLang": programming_lang, - "Processor": processor, - "Horovod": horovod, - "ReleaseNotes": release_notes, + "HyperParameterTuningJobName": hyper_parameter_tuning_job_name, + "HyperParameterTuningJobConfig": hyper_parameter_tuning_job_config, + "TrainingJobDefinition": training_job_definition, + "TrainingJobDefinitions": training_job_definitions, + "WarmStartConfig": warm_start_config, + "Tags": tags, + "Autotune": autotune, } operation_input_args = Base.populate_chained_attributes( - resource_name="ImageVersion", operation_input_args=operation_input_args + resource_name="HyperParameterTuningJob", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -13395,33 +13310,33 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_image_version(**operation_input_args) + response = 
client.create_hyper_parameter_tuning_job(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(image_name=image_name, session=session, region=region) + return cls.get( + hyper_parameter_tuning_job_name=hyper_parameter_tuning_job_name, + session=session, + region=region, + ) @classmethod @Base.add_validate_call def get( cls, - image_name: str, - version: Optional[int] = Unassigned(), - alias: Optional[str] = Unassigned(), + hyper_parameter_tuning_job_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ImageVersion"]: + ) -> Optional["HyperParameterTuningJob"]: """ - Get a ImageVersion resource + Get a HyperParameterTuningJob resource Parameters: - image_name: The name of the image. - version: The version of the image. If not specified, the latest version is described. - alias: The alias of the image version. + hyper_parameter_tuning_job_name: The name of the tuning job. session: Boto3 session. region: Region name. Returns: - The ImageVersion resource. + The HyperParameterTuningJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -13437,9 +13352,7 @@ def get( """ operation_input_args = { - "ImageName": image_name, - "Version": version, - "Alias": alias, + "HyperParameterTuningJobName": hyper_parameter_tuning_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -13448,25 +13361,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_image_version(**operation_input_args) + response = client.describe_hyper_parameter_tuning_job(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeImageVersionResponse") - image_version = cls(**transformed_response) - return image_version + transformed_response = transform(response, "DescribeHyperParameterTuningJobResponse") + hyper_parameter_tuning_job = cls(**transformed_response) + return hyper_parameter_tuning_job @Base.add_validate_call def refresh( self, - alias: Optional[str] = Unassigned(), - ) -> Optional["ImageVersion"]: + ) -> Optional["HyperParameterTuningJob"]: """ - Refresh a ImageVersion resource + Refresh a HyperParameterTuningJob resource Returns: - The ImageVersion resource. + The HyperParameterTuningJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -13482,46 +13394,25 @@ def refresh( """ operation_input_args = { - "ImageName": self.image_name, - "Version": self.version, - "Alias": alias, + "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_image_version(**operation_input_args) + response = client.describe_hyper_parameter_tuning_job(**operation_input_args) # deserialize response and update self - transform(response, "DescribeImageVersionResponse", self) + transform(response, "DescribeHyperParameterTuningJobResponse", self) return self @Base.add_validate_call - def update( + def delete( self, - alias: Optional[str] = Unassigned(), - version: Optional[int] = Unassigned(), - aliases_to_add: Optional[List[str]] = Unassigned(), - aliases_to_delete: Optional[List[str]] = Unassigned(), - vendor_guidance: Optional[str] = Unassigned(), - job_type: Optional[str] = Unassigned(), - ml_framework: Optional[str] = Unassigned(), - programming_lang: Optional[str] = Unassigned(), - processor: Optional[str] = Unassigned(), - horovod: Optional[bool] = Unassigned(), - release_notes: Optional[str] = Unassigned(), - ) -> Optional["ImageVersion"]: + ) -> None: """ - Update a ImageVersion resource - - Parameters: - alias: The alias of the image version. - aliases_to_add: A list of aliases to add. - aliases_to_delete: A list of aliases to delete. - - Returns: - The ImageVersion resource. + Delete a HyperParameterTuningJob resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -13533,46 +13424,25 @@ def update( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. - ResourceNotFound: Resource being access is not found. 
""" - logger.info("Updating image_version resource.") client = Base.get_sagemaker_client() operation_input_args = { - "ImageName": self.image_name, - "Alias": alias, - "Version": version, - "AliasesToAdd": aliases_to_add, - "AliasesToDelete": aliases_to_delete, - "VendorGuidance": vendor_guidance, - "JobType": job_type, - "MLFramework": ml_framework, - "ProgrammingLang": programming_lang, - "Processor": processor, - "Horovod": horovod, - "ReleaseNotes": release_notes, + "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, } - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - # create the resource - response = client.update_image_version(**operation_input_args) - logger.debug(f"Response: {response}") - self.refresh() + client.delete_hyper_parameter_tuning_job(**operation_input_args) - return self + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call - def delete( - self, - alias: Optional[str] = Unassigned(), - ) -> None: + def stop(self) -> None: """ - Delete a ImageVersion resource + Stop a HyperParameterTuningJob resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -13584,37 +13454,32 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. ResourceNotFound: Resource being access is not found. 
""" - client = Base.get_sagemaker_client() + client = SageMakerClient().client operation_input_args = { - "ImageName": self.image_name, - "Version": self.version, - "Alias": alias, + "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_image_version(**operation_input_args) + client.stop_hyper_parameter_tuning_job(**operation_input_args) - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call - def wait_for_status( + def wait( self, - target_status: Literal["CREATING", "CREATED", "CREATE_FAILED", "DELETING", "DELETE_FAILED"], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a ImageVersion resource to reach certain status. + Wait for a HyperParameterTuningJob resource. Parameters: - target_status: The status to wait for. poll: The number of seconds to wait between each poll. timeout: The maximum number of seconds to wait before timing out. @@ -13622,7 +13487,9 @@ def wait_for_status( TimeoutExceededError: If the resource does not reach a terminal state before the timeout. FailedStatusError: If the resource reaches a failed state. WaiterError: Raised when an error occurs while waiting. 
+ """ + terminal_states = ["Completed", "Failed", "Stopped", "DeleteFailed"] start_time = time.time() progress = Progress( @@ -13630,7 +13497,7 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task(f"Waiting for ImageVersion to reach [bold]{target_status} status...") + progress.add_task("Waiting for HyperParameterTuningJob...") status = Status("Current status:") with Live( @@ -13643,22 +13510,25 @@ def wait_for_status( ): while True: self.refresh() - current_status = self.image_version_status + current_status = self.hyper_parameter_tuning_job_status status.update(f"Current status: [bold]{current_status}") - if target_status == current_status: + if current_status in terminal_states: logger.info(f"Final Resource Status: [bold]{current_status}") + + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="HyperParameterTuningJob", + status=current_status, + reason=self.failure_reason, + ) + return - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="ImageVersion", - status=current_status, - reason=self.failure_reason, - ) - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="ImageVersion", status=current_status) + raise TimeoutExceededError( + resouce_type="HyperParameterTuningJob", status=current_status + ) time.sleep(poll) @Base.add_validate_call @@ -13668,7 +13538,7 @@ def wait_for_delete( timeout: Optional[int] = None, ) -> None: """ - Wait for a ImageVersion resource to be deleted. + Wait for a HyperParameterTuningJob resource to be deleted. Parameters: poll: The number of seconds to wait between each poll. 
@@ -13695,7 +13565,7 @@ def wait_for_delete( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task("Waiting for ImageVersion to be deleted...") + progress.add_task("Waiting for HyperParameterTuningJob to be deleted...") status = Status("Current status:") with Live( @@ -13708,20 +13578,12 @@ def wait_for_delete( while True: try: self.refresh() - current_status = self.image_version_status + current_status = self.hyper_parameter_tuning_job_status status.update(f"Current status: [bold]{current_status}") - if ( - "delete_failed" in current_status.lower() - or "deletefailed" in current_status.lower() - ): - raise DeleteFailedStatusError( - resource_type="ImageVersion", reason=self.failure_reason - ) - if timeout is not None and time.time() - start_time >= timeout: raise TimeoutExceededError( - resouce_type="ImageVersion", status=current_status + resouce_type="HyperParameterTuningJob", status=current_status ) except botocore.exceptions.ClientError as e: error_code = e.response["Error"]["Code"] @@ -13732,82 +13594,40 @@ def wait_for_delete( raise e time.sleep(poll) - -class InferenceComponent(Base): - """ - Class representing resource InferenceComponent - - Attributes: - inference_component_name: The name of the inference component. - inference_component_arn: The Amazon Resource Name (ARN) of the inference component. - endpoint_name: The name of the endpoint that hosts the inference component. - endpoint_arn: The Amazon Resource Name (ARN) of the endpoint that hosts the inference component. - creation_time: The time when the inference component was created. - last_modified_time: The time when the inference component was last updated. - variant_name: The name of the production variant that hosts the inference component. - failure_reason: If the inference component status is Failed, the reason for the failure. - specification: Details about the resources that are deployed with this inference component. 
- runtime_config: Details about the runtime settings for the model that is deployed with the inference component. - inference_component_status: The status of the inference component. - - """ - - inference_component_name: str - inference_component_arn: Optional[str] = Unassigned() - endpoint_name: Optional[str] = Unassigned() - endpoint_arn: Optional[str] = Unassigned() - variant_name: Optional[str] = Unassigned() - failure_reason: Optional[str] = Unassigned() - specification: Optional[InferenceComponentSpecificationSummary] = Unassigned() - runtime_config: Optional[InferenceComponentRuntimeConfigSummary] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - inference_component_status: Optional[str] = Unassigned() - - def get_name(self) -> str: - attributes = vars(self) - resource_name = "inference_component_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] - - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) - - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object inference_component") - return None - @classmethod @Base.add_validate_call - def create( + def get_all( cls, - inference_component_name: str, - endpoint_name: Union[str, object], - specification: InferenceComponentSpecification, - variant_name: Optional[str] = Unassigned(), - runtime_config: Optional[InferenceComponentRuntimeConfig] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + last_modified_time_after: 
Optional[datetime.datetime] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + status_equals: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["InferenceComponent"]: + ) -> ResourceIterator["HyperParameterTuningJob"]: """ - Create a InferenceComponent resource + Get all HyperParameterTuningJob resources Parameters: - inference_component_name: A unique name to assign to the inference component. - endpoint_name: The name of an existing endpoint where you host the inference component. - specification: Details about the resources to deploy with this inference component, including the model, container, and compute resources. - variant_name: The name of an existing production variant where you host the inference component. - runtime_config: Runtime settings for a model that is deployed with an inference component. - tags: A list of key-value pairs associated with the model. For more information, see Tagging Amazon Web Services resources in the Amazon Web Services General Reference. + next_token: If the result of the previous ListHyperParameterTuningJobs request was truncated, the response includes a NextToken. To retrieve the next set of tuning jobs, use the token in the next request. + max_results: The maximum number of tuning jobs to return. The default value is 10. + sort_by: The field to sort results by. The default is Name. + sort_order: The sort order for results. The default is Ascending. + name_contains: A string in the tuning job name. This filter returns only tuning jobs whose name contains the specified string. + creation_time_after: A filter that returns only tuning jobs that were created after the specified time. + creation_time_before: A filter that returns only tuning jobs that were created before the specified time. + last_modified_time_after: A filter that returns only tuning jobs that were modified after the specified time. 
+ last_modified_time_before: A filter that returns only tuning jobs that were modified before the specified time. + status_equals: A filter that returns only tuning jobs with the specified status. session: Boto3 session. region: Region name. Returns: - The InferenceComponent resource. + Iterator for listed HyperParameterTuningJob resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -13819,61 +13639,59 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating inference_component resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "InferenceComponentName": inference_component_name, - "EndpointName": endpoint_name, - "VariantName": variant_name, - "Specification": specification, - "RuntimeConfig": runtime_config, - "Tags": tags, + "SortBy": sort_by, + "SortOrder": sort_order, + "NameContains": name_contains, + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "LastModifiedTimeAfter": last_modified_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + "StatusEquals": status_equals, } - operation_input_args = Base.populate_chained_attributes( - resource_name="InferenceComponent", operation_input_args=operation_input_args - ) - - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: 
{operation_input_args}") - # create the resource - response = client.create_inference_component(**operation_input_args) - logger.debug(f"Response: {response}") - - return cls.get( - inference_component_name=inference_component_name, session=session, region=region + return ResourceIterator( + client=client, + list_method="list_hyper_parameter_tuning_jobs", + summaries_key="HyperParameterTuningJobSummaries", + summary_name="HyperParameterTuningJobSummary", + resource_cls=HyperParameterTuningJob, + list_method_kwargs=operation_input_args, ) - @classmethod @Base.add_validate_call - def get( - cls, - inference_component_name: str, + def get_all_training_jobs( + self, + status_equals: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["InferenceComponent"]: + ) -> ResourceIterator[HyperParameterTrainingJobSummary]: """ - Get a InferenceComponent resource + Gets a list of TrainingJobSummary objects that describe the training jobs that a hyperparameter tuning job launched. Parameters: - inference_component_name: The name of the inference component. + next_token: If the result of the previous ListTrainingJobsForHyperParameterTuningJob request was truncated, the response includes a NextToken. To retrieve the next set of training jobs, use the token in the next request. + max_results: The maximum number of training jobs to return. The default value is 10. + status_equals: A filter that returns only training jobs with the specified status. + sort_by: The field to sort results by. The default is Name. If the value of this field is FinalObjectiveMetricValue, any training jobs that did not return an objective metric are not listed. + sort_order: The sort order for results. The default is Ascending. session: Boto3 session. region: Region name. Returns: - The InferenceComponent resource. 
+ Iterator for listed HyperParameterTrainingJobSummary. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -13885,10 +13703,14 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "InferenceComponentName": inference_component_name, + "HyperParameterTuningJobName": self.hyper_parameter_tuning_job_name, + "StatusEquals": status_equals, + "SortBy": sort_by, + "SortOrder": sort_order, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -13897,62 +13719,100 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_inference_component(**operation_input_args) - logger.debug(response) + return ResourceIterator( + client=client, + list_method="list_training_jobs_for_hyper_parameter_tuning_job", + summaries_key="TrainingJobSummaries", + summary_name="HyperParameterTrainingJobSummary", + resource_cls=HyperParameterTrainingJobSummary, + list_method_kwargs=operation_input_args, + ) - # deserialize the response - transformed_response = transform(response, "DescribeInferenceComponentOutput") - inference_component = cls(**transformed_response) - return inference_component - @Base.add_validate_call - def refresh( - self, - ) -> Optional["InferenceComponent"]: - """ - Refresh a InferenceComponent resource +class Image(Base): + """ + Class representing resource Image - Returns: - The InferenceComponent resource. + Attributes: + creation_time: When the image was created. + description: The description of the image. + display_name: The name of the image as displayed. + failure_reason: When a create, update, or delete operation fails, the reason for the failure. + image_arn: The ARN of the image. + image_name: The name of the image. + image_status: The status of the image. 
+ last_modified_time: When the image was last modified. + role_arn: The ARN of the IAM role that enables Amazon SageMaker to perform tasks on your behalf. - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - """ + """ - operation_input_args = { - "InferenceComponentName": self.inference_component_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") + image_name: str + creation_time: Optional[datetime.datetime] = Unassigned() + description: Optional[str] = Unassigned() + display_name: Optional[str] = Unassigned() + failure_reason: Optional[str] = Unassigned() + image_arn: Optional[str] = Unassigned() + image_status: Optional[str] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + role_arn: Optional[str] = Unassigned() - client = Base.get_sagemaker_client() - response = client.describe_inference_component(**operation_input_args) + def get_name(self) -> str: + attributes = vars(self) + resource_name = "image_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] - # deserialize response and update self - transform(response, "DescribeInferenceComponentOutput", self) - return self + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) - @Base.add_validate_call - def update( - self, - specification: Optional[InferenceComponentSpecification] = Unassigned(), - runtime_config: Optional[InferenceComponentRuntimeConfig] = Unassigned(), - ) -> Optional["InferenceComponent"]: - """ - Update a InferenceComponent resource + 
for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object image") + return None + + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = {"role_arn": {"type": "string"}} + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "Image", **kwargs + ), + ) + + return wrapper + + @classmethod + @populate_inputs_decorator + @Base.add_validate_call + def create( + cls, + image_name: str, + role_arn: str, + description: Optional[str] = Unassigned(), + display_name: Optional[str] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["Image"]: + """ + Create a Image resource + + Parameters: + image_name: The name of the image. Must be unique to your account. + role_arn: The ARN of an IAM role that enables Amazon SageMaker to perform tasks on your behalf. + description: The description of the image. + display_name: The display name of the image. If not provided, ImageName is displayed. + tags: A list of tags to apply to the image. + session: Boto3 session. + region: Region name. Returns: - The InferenceComponent resource. + The Image resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -13964,16 +13824,170 @@ def update( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceInUse: Resource being accessed is in use. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
+ ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Updating inference_component resource.") + logger.info("Creating image resource.") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "Description": description, + "DisplayName": display_name, + "ImageName": image_name, + "RoleArn": role_arn, + "Tags": tags, + } + + operation_input_args = Base.populate_chained_attributes( + resource_name="Image", operation_input_args=operation_input_args + ) + + logger.debug(f"Input request: {operation_input_args}") + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + # create the resource + response = client.create_image(**operation_input_args) + logger.debug(f"Response: {response}") + + return cls.get(image_name=image_name, session=session, region=region) + + @classmethod + @Base.add_validate_call + def get( + cls, + image_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["Image"]: + """ + Get a Image resource + + Parameters: + image_name: The name of the image to describe. + session: Boto3 session. + region: Region name. + + Returns: + The Image resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. 
+ """ + + operation_input_args = { + "ImageName": image_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + response = client.describe_image(**operation_input_args) + + logger.debug(response) + + # deserialize the response + transformed_response = transform(response, "DescribeImageResponse") + image = cls(**transformed_response) + return image + + @Base.add_validate_call + def refresh( + self, + ) -> Optional["Image"]: + """ + Refresh a Image resource + + Returns: + The Image resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. + """ + + operation_input_args = { + "ImageName": self.image_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client() + response = client.describe_image(**operation_input_args) + + # deserialize response and update self + transform(response, "DescribeImageResponse", self) + return self + + @populate_inputs_decorator + @Base.add_validate_call + def update( + self, + delete_properties: Optional[List[str]] = Unassigned(), + description: Optional[str] = Unassigned(), + display_name: Optional[str] = Unassigned(), + role_arn: Optional[str] = Unassigned(), + ) -> Optional["Image"]: + """ + Update a Image resource + + Parameters: + delete_properties: A list of properties to delete. 
Only the Description and DisplayName properties can be deleted. + + Returns: + The Image resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceInUse: Resource being accessed is in use. + ResourceNotFound: Resource being access is not found. + """ + + logger.info("Updating image resource.") client = Base.get_sagemaker_client() operation_input_args = { - "InferenceComponentName": self.inference_component_name, - "Specification": specification, - "RuntimeConfig": runtime_config, + "DeleteProperties": delete_properties, + "Description": description, + "DisplayName": display_name, + "ImageName": self.image_name, + "RoleArn": role_arn, } logger.debug(f"Input request: {operation_input_args}") # serialize the input request @@ -13981,7 +13995,7 @@ def update( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.update_inference_component(**operation_input_args) + response = client.update_image(**operation_input_args) logger.debug(f"Response: {response}") self.refresh() @@ -13992,7 +14006,7 @@ def delete( self, ) -> None: """ - Delete a InferenceComponent resource + Delete a Image resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -14004,30 +14018,40 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceInUse: Resource being accessed is in use. + ResourceNotFound: Resource being access is not found. 
""" client = Base.get_sagemaker_client() operation_input_args = { - "InferenceComponentName": self.inference_component_name, + "ImageName": self.image_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_inference_component(**operation_input_args) + client.delete_image(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call def wait_for_status( self, - target_status: Literal["InService", "Creating", "Updating", "Failed", "Deleting"], + target_status: Literal[ + "CREATING", + "CREATED", + "CREATE_FAILED", + "UPDATING", + "UPDATE_FAILED", + "DELETING", + "DELETE_FAILED", + ], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a InferenceComponent resource to reach certain status. + Wait for a Image resource to reach certain status. Parameters: target_status: The status to wait for. @@ -14046,9 +14070,7 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task( - f"Waiting for InferenceComponent to reach [bold]{target_status} status..." 
- ) + progress.add_task(f"Waiting for Image to reach [bold]{target_status} status...") status = Status("Current status:") with Live( @@ -14061,7 +14083,7 @@ def wait_for_status( ): while True: self.refresh() - current_status = self.inference_component_status + current_status = self.image_status status.update(f"Current status: [bold]{current_status}") if target_status == current_status: @@ -14070,15 +14092,11 @@ def wait_for_status( if "failed" in current_status.lower(): raise FailedStatusError( - resource_type="InferenceComponent", - status=current_status, - reason=self.failure_reason, + resource_type="Image", status=current_status, reason=self.failure_reason ) if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="InferenceComponent", status=current_status - ) + raise TimeoutExceededError(resouce_type="Image", status=current_status) time.sleep(poll) @Base.add_validate_call @@ -14088,7 +14106,7 @@ def wait_for_delete( timeout: Optional[int] = None, ) -> None: """ - Wait for a InferenceComponent resource to be deleted. + Wait for a Image resource to be deleted. Parameters: poll: The number of seconds to wait between each poll. 
@@ -14115,7 +14133,7 @@ def wait_for_delete( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task("Waiting for InferenceComponent to be deleted...") + progress.add_task("Waiting for Image to be deleted...") status = Status("Current status:") with Live( @@ -14128,16 +14146,22 @@ def wait_for_delete( while True: try: self.refresh() - current_status = self.inference_component_status + current_status = self.image_status status.update(f"Current status: [bold]{current_status}") - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="InferenceComponent", status=current_status - ) - except botocore.exceptions.ClientError as e: - error_code = e.response["Error"]["Code"] - + if ( + "delete_failed" in current_status.lower() + or "deletefailed" in current_status.lower() + ): + raise DeleteFailedStatusError( + resource_type="Image", reason=self.failure_reason + ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="Image", status=current_status) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] + if "ResourceNotFound" in error_code or "ValidationException" in error_code: logger.info("Resource was not found. 
It may have been deleted.") return @@ -14148,40 +14172,34 @@ def wait_for_delete( @Base.add_validate_call def get_all( cls, - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), - last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - status_equals: Optional[str] = Unassigned(), - endpoint_name_equals: Optional[str] = Unassigned(), - variant_name_equals: Optional[str] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["InferenceComponent"]: + ) -> ResourceIterator["Image"]: """ - Get all InferenceComponent resources + Get all Image resources Parameters: - sort_by: The field by which to sort the inference components in the response. The default is CreationTime. - sort_order: The sort order for results. The default is Descending. - next_token: A token that you use to get the next set of results following a truncated response. If the response to the previous request was truncated, that response provides the value for this token. - max_results: The maximum number of inference components to return in the response. This value defaults to 10. - name_contains: Filters the results to only those inference components with a name that contains the specified string. - creation_time_before: Filters the results to only those inference components that were created before the specified time. 
- creation_time_after: Filters the results to only those inference components that were created after the specified time. - last_modified_time_before: Filters the results to only those inference components that were updated before the specified time. - last_modified_time_after: Filters the results to only those inference components that were updated after the specified time. - status_equals: Filters the results to only those inference components with the specified status. - endpoint_name_equals: An endpoint name to filter the listed inference components. The response includes only those inference components that are hosted at the specified endpoint. - variant_name_equals: A production variant name to filter the listed inference components. The response includes only those inference components that are hosted at the specified variant. + creation_time_after: A filter that returns only images created on or after the specified time. + creation_time_before: A filter that returns only images created on or before the specified time. + last_modified_time_after: A filter that returns only images modified on or after the specified time. + last_modified_time_before: A filter that returns only images modified on or before the specified time. + max_results: The maximum number of images to return in the response. The default value is 10. + name_contains: A filter that returns only images whose name contains the specified string. + next_token: If the previous call to ListImages didn't return the full set of images, the call returns a token for getting the next set of images. + sort_by: The property used to sort results. The default value is CREATION_TIME. + sort_order: The sort order. The default value is DESCENDING. session: Boto3 session. region: Region name. Returns: - Iterator for listed InferenceComponent resources. + Iterator for listed Image resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -14200,16 +14218,13 @@ def get_all( ) operation_input_args = { - "SortBy": sort_by, - "SortOrder": sort_order, - "NameContains": name_contains, - "CreationTimeBefore": creation_time_before, "CreationTimeAfter": creation_time_after, - "LastModifiedTimeBefore": last_modified_time_before, + "CreationTimeBefore": creation_time_before, "LastModifiedTimeAfter": last_modified_time_after, - "StatusEquals": status_equals, - "EndpointNameEquals": endpoint_name_equals, - "VariantNameEquals": variant_name_equals, + "LastModifiedTimeBefore": last_modified_time_before, + "NameContains": name_contains, + "SortBy": sort_by, + "SortOrder": sort_order, } # serialize the input request @@ -14218,28 +14233,35 @@ def get_all( return ResourceIterator( client=client, - list_method="list_inference_components", - summaries_key="InferenceComponents", - summary_name="InferenceComponentSummary", - resource_cls=InferenceComponent, + list_method="list_images", + summaries_key="Images", + summary_name="Image", + resource_cls=Image, list_method_kwargs=operation_input_args, ) @Base.add_validate_call - def update_runtime_configs( + def get_all_aliases( self, - desired_runtime_config: InferenceComponentRuntimeConfig, + alias: Optional[str] = Unassigned(), + version: Optional[int] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> None: + ) -> ResourceIterator[str]: """ - Runtime settings for a model that is deployed with an inference component. + Lists the aliases of a specified image or image version. Parameters: - desired_runtime_config: Runtime settings for a model that is deployed with an inference component. + alias: The alias of the image version. + version: The version of the image. If image version is not specified, the aliases of all versions of the image are listed. + max_results: The maximum number of aliases to return. 
+ next_token: If the previous call to ListAliases didn't return the full set of aliases, the call returns a token for retrieving the next set of aliases. session: Boto3 session. region: Region name. + Returns: + Iterator for listed str. + Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. The error message and error code can be parsed from the exception as follows: @@ -14250,12 +14272,13 @@ def update_runtime_configs( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "InferenceComponentName": self.inference_component_name, - "DesiredRuntimeConfig": desired_runtime_config, + "ImageName": self.image_name, + "Alias": alias, + "Version": version, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -14265,55 +14288,61 @@ def update_runtime_configs( session=session, region_name=region, service_name="sagemaker" ) - logger.debug(f"Calling update_inference_component_runtime_config API") - response = client.update_inference_component_runtime_config(**operation_input_args) - logger.debug(f"Response: {response}") + return ResourceIterator( + client=client, + list_method="list_aliases", + summaries_key="SageMakerImageVersionAliases", + summary_name="SageMakerImageVersionAlias", + resource_cls=str, + list_method_kwargs=operation_input_args, + ) -class InferenceExperiment(Base): +class ImageVersion(Base): """ - Class representing resource InferenceExperiment + Class representing resource ImageVersion Attributes: - arn: The ARN of the inference experiment being described. - name: The name of the inference experiment. - type: The type of the inference experiment. - status: The status of the inference experiment. 
The following are the possible statuses for an inference experiment: Creating - Amazon SageMaker is creating your experiment. Created - Amazon SageMaker has finished the creation of your experiment and will begin the experiment at the scheduled time. Updating - When you make changes to your experiment, your experiment shows as updating. Starting - Amazon SageMaker is beginning your experiment. Running - Your experiment is in progress. Stopping - Amazon SageMaker is stopping your experiment. Completed - Your experiment has completed. Cancelled - When you conclude your experiment early using the StopInferenceExperiment API, or if any operation fails with an unexpected error, it shows as cancelled. - endpoint_metadata: The metadata of the endpoint on which the inference experiment ran. - model_variants: An array of ModelVariantConfigSummary objects. There is one for each variant in the inference experiment. Each ModelVariantConfigSummary object in the array describes the infrastructure configuration for deploying the corresponding variant. - schedule: The duration for which the inference experiment ran or will run. - status_reason: The error message or client-specified Reason from the StopInferenceExperiment API, that explains the status of the inference experiment. - description: The description of the inference experiment. - creation_time: The timestamp at which you created the inference experiment. - completion_time: The timestamp at which the inference experiment was completed. - last_modified_time: The timestamp at which you last modified the inference experiment. - role_arn: The ARN of the IAM role that Amazon SageMaker can assume to access model artifacts and container images, and manage Amazon SageMaker Inference endpoints for model deployment. - data_storage_config: The Amazon S3 location and configuration for storing inference request and response data. 
- shadow_mode_config: The configuration of ShadowMode inference experiment type, which shows the production variant that takes all the inference requests, and the shadow variant to which Amazon SageMaker replicates a percentage of the inference requests. For the shadow variant it also shows the percentage of requests that Amazon SageMaker replicates. - kms_key: The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance that hosts the endpoint. For more information, see CreateInferenceExperiment. + base_image: The registry path of the container image on which this image version is based. + container_image: The registry path of the container image that contains this image version. + creation_time: When the version was created. + failure_reason: When a create or delete operation fails, the reason for the failure. + image_arn: The ARN of the image the version is based on. + image_version_arn: The ARN of the version. + image_version_status: The status of the version. + last_modified_time: When the version was last modified. + version: The version number. + vendor_guidance: The stability of the image version specified by the maintainer. NOT_PROVIDED: The maintainers did not provide a status for image version stability. STABLE: The image version is stable. TO_BE_ARCHIVED: The image version is set to be archived. Custom image versions that are set to be archived are automatically archived after three months. ARCHIVED: The image version is archived. Archived image versions are not searchable and are no longer actively supported. + job_type: Indicates SageMaker job type compatibility. TRAINING: The image version is compatible with SageMaker training jobs. INFERENCE: The image version is compatible with SageMaker inference jobs. NOTEBOOK_KERNEL: The image version is compatible with SageMaker notebook kernels. 
+ ml_framework: The machine learning framework vended in the image version. + programming_lang: The supported programming language and its version. + processor: Indicates CPU or GPU compatibility. CPU: The image version is compatible with CPU. GPU: The image version is compatible with GPU. + horovod: Indicates Horovod compatibility. + release_notes: The maintainer description of the image version. """ - name: str - arn: Optional[str] = Unassigned() - type: Optional[str] = Unassigned() - schedule: Optional[InferenceExperimentSchedule] = Unassigned() - status: Optional[str] = Unassigned() - status_reason: Optional[str] = Unassigned() - description: Optional[str] = Unassigned() + image_name: str + base_image: Optional[str] = Unassigned() + container_image: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - completion_time: Optional[datetime.datetime] = Unassigned() + failure_reason: Optional[str] = Unassigned() + image_arn: Optional[str] = Unassigned() + image_version_arn: Optional[str] = Unassigned() + image_version_status: Optional[str] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() - role_arn: Optional[str] = Unassigned() - endpoint_metadata: Optional[EndpointMetadata] = Unassigned() - model_variants: Optional[List[ModelVariantConfigSummary]] = Unassigned() - data_storage_config: Optional[InferenceExperimentDataStorageConfig] = Unassigned() - shadow_mode_config: Optional[ShadowModeConfig] = Unassigned() - kms_key: Optional[str] = Unassigned() + version: Optional[int] = Unassigned() + vendor_guidance: Optional[str] = Unassigned() + job_type: Optional[str] = Unassigned() + ml_framework: Optional[str] = Unassigned() + programming_lang: Optional[str] = Unassigned() + processor: Optional[str] = Unassigned() + horovod: Optional[bool] = Unassigned() + release_notes: Optional[str] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "inference_experiment_name" + 
resource_name = "image_version_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -14324,65 +14353,47 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object inference_experiment") + logger.error("Name attribute not found for object image_version") return None - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = { - "role_arn": {"type": "string"}, - "data_storage_config": {"kms_key": {"type": "string"}}, - "kms_key": {"type": "string"}, - } - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "InferenceExperiment", **kwargs - ), - ) - - return wrapper - @classmethod - @populate_inputs_decorator @Base.add_validate_call def create( cls, - name: str, - type: str, - role_arn: str, - endpoint_name: Union[str, object], - model_variants: List[ModelVariantConfig], - shadow_mode_config: ShadowModeConfig, - schedule: Optional[InferenceExperimentSchedule] = Unassigned(), - description: Optional[str] = Unassigned(), - data_storage_config: Optional[InferenceExperimentDataStorageConfig] = Unassigned(), - kms_key: Optional[str] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), + base_image: str, + client_token: str, + image_name: Union[str, object], + aliases: Optional[List[str]] = Unassigned(), + vendor_guidance: Optional[str] = Unassigned(), + job_type: Optional[str] = Unassigned(), + ml_framework: Optional[str] = Unassigned(), + programming_lang: Optional[str] = Unassigned(), + processor: Optional[str] = Unassigned(), + horovod: Optional[bool] = Unassigned(), + release_notes: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["InferenceExperiment"]: + ) -> 
Optional["ImageVersion"]: """ - Create a InferenceExperiment resource + Create a ImageVersion resource Parameters: - name: The name for the inference experiment. - type: The type of the inference experiment that you want to run. The following types of experiments are possible: ShadowMode: You can use this type to validate a shadow variant. For more information, see Shadow tests. - role_arn: The ARN of the IAM role that Amazon SageMaker can assume to access model artifacts and container images, and manage Amazon SageMaker Inference endpoints for model deployment. - endpoint_name: The name of the Amazon SageMaker endpoint on which you want to run the inference experiment. - model_variants: An array of ModelVariantConfig objects. There is one for each variant in the inference experiment. Each ModelVariantConfig object in the array describes the infrastructure configuration for the corresponding variant. - shadow_mode_config: The configuration of ShadowMode inference experiment type. Use this field to specify a production variant which takes all the inference requests, and a shadow variant to which Amazon SageMaker replicates a percentage of the inference requests. For the shadow variant also specify the percentage of requests that Amazon SageMaker replicates. - schedule: The duration for which you want the inference experiment to run. If you don't specify this field, the experiment automatically starts immediately upon creation and concludes after 7 days. - description: A description for the inference experiment. - data_storage_config: The Amazon S3 location and configuration for storing inference request and response data. This is an optional parameter that you can use for data capture. For more information, see Capture data. - kms_key: The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance that hosts the endpoint. 
The KmsKey can be any of the following formats: KMS key ID "1234abcd-12ab-34cd-56ef-1234567890ab" Amazon Resource Name (ARN) of a KMS key "arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab" KMS key Alias "alias/ExampleAlias" Amazon Resource Name (ARN) of a KMS key Alias "arn:aws:kms:us-west-2:111122223333:alias/ExampleAlias" If you use a KMS key ID or an alias of your KMS key, the Amazon SageMaker execution role must include permissions to call kms:Encrypt. If you don't provide a KMS key ID, Amazon SageMaker uses the default KMS key for Amazon S3 for your role's account. Amazon SageMaker uses server-side encryption with KMS managed keys for OutputDataConfig. If you use a bucket policy with an s3:PutObject permission that only allows objects with server-side encryption, set the condition key of s3:x-amz-server-side-encryption to "aws:kms". For more information, see KMS managed Encryption Keys in the Amazon Simple Storage Service Developer Guide. The KMS key policy must grant permission to the IAM role that you specify in your CreateEndpoint and UpdateEndpoint requests. For more information, see Using Key Policies in Amazon Web Services KMS in the Amazon Web Services Key Management Service Developer Guide. - tags: Array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging your Amazon Web Services Resources. + base_image: The registry path of the container image to use as the starting point for this version. The path is an Amazon ECR URI in the following format: <acct-id>.dkr.ecr.<region>.amazonaws.com/<repo-name[:tag] or [@digest]> + client_token: A unique ID. If not specified, the Amazon Web Services CLI and Amazon Web Services SDKs, such as the SDK for Python (Boto3), add a unique value to the call. + image_name: The ImageName of the Image to create a version of. 
+ aliases: A list of aliases created with the image version. + vendor_guidance: The stability of the image version, specified by the maintainer. NOT_PROVIDED: The maintainers did not provide a status for image version stability. STABLE: The image version is stable. TO_BE_ARCHIVED: The image version is set to be archived. Custom image versions that are set to be archived are automatically archived after three months. ARCHIVED: The image version is archived. Archived image versions are not searchable and are no longer actively supported. + job_type: Indicates SageMaker job type compatibility. TRAINING: The image version is compatible with SageMaker training jobs. INFERENCE: The image version is compatible with SageMaker inference jobs. NOTEBOOK_KERNEL: The image version is compatible with SageMaker notebook kernels. + ml_framework: The machine learning framework vended in the image version. + programming_lang: The supported programming language and its version. + processor: Indicates CPU or GPU compatibility. CPU: The image version is compatible with CPU. GPU: The image version is compatible with GPU. + horovod: Indicates Horovod compatibility. + release_notes: The maintainer description of the image version. session: Boto3 session. region: Region name. Returns: - The InferenceExperiment resource. + The ImageVersion resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -14396,32 +14407,33 @@ def create( ``` ResourceInUse: Resource being accessed is in use. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. 
ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating inference_experiment resource.") + logger.info("Creating image_version resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "Name": name, - "Type": type, - "Schedule": schedule, - "Description": description, - "RoleArn": role_arn, - "EndpointName": endpoint_name, - "ModelVariants": model_variants, - "DataStorageConfig": data_storage_config, - "ShadowModeConfig": shadow_mode_config, - "KmsKey": kms_key, - "Tags": tags, + "BaseImage": base_image, + "ClientToken": client_token, + "ImageName": image_name, + "Aliases": aliases, + "VendorGuidance": vendor_guidance, + "JobType": job_type, + "MLFramework": ml_framework, + "ProgrammingLang": programming_lang, + "Processor": processor, + "Horovod": horovod, + "ReleaseNotes": release_notes, } operation_input_args = Base.populate_chained_attributes( - resource_name="InferenceExperiment", operation_input_args=operation_input_args + resource_name="ImageVersion", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -14430,29 +14442,33 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_inference_experiment(**operation_input_args) + response = client.create_image_version(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(name=name, session=session, region=region) + return cls.get(image_name=image_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - name: str, + image_name: str, + version: Optional[int] = Unassigned(), + alias: Optional[str] = Unassigned(), session: 
Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["InferenceExperiment"]: + ) -> Optional["ImageVersion"]: """ - Get a InferenceExperiment resource + Get a ImageVersion resource Parameters: - name: The name of the inference experiment to describe. + image_name: The name of the image. + version: The version of the image. If not specified, the latest version is described. + alias: The alias of the image version. session: Boto3 session. region: Region name. Returns: - The InferenceExperiment resource. + The ImageVersion resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -14468,7 +14484,9 @@ def get( """ operation_input_args = { - "Name": name, + "ImageName": image_name, + "Version": version, + "Alias": alias, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -14477,24 +14495,25 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_inference_experiment(**operation_input_args) + response = client.describe_image_version(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeInferenceExperimentResponse") - inference_experiment = cls(**transformed_response) - return inference_experiment + transformed_response = transform(response, "DescribeImageVersionResponse") + image_version = cls(**transformed_response) + return image_version @Base.add_validate_call def refresh( self, - ) -> Optional["InferenceExperiment"]: + alias: Optional[str] = Unassigned(), + ) -> Optional["ImageVersion"]: """ - Refresh a InferenceExperiment resource + Refresh a ImageVersion resource Returns: - The InferenceExperiment resource. + The ImageVersion resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -14510,34 +14529,46 @@ def refresh( """ operation_input_args = { - "Name": self.name, + "ImageName": self.image_name, + "Version": self.version, + "Alias": alias, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_inference_experiment(**operation_input_args) + response = client.describe_image_version(**operation_input_args) # deserialize response and update self - transform(response, "DescribeInferenceExperimentResponse", self) + transform(response, "DescribeImageVersionResponse", self) return self - @populate_inputs_decorator @Base.add_validate_call def update( self, - schedule: Optional[InferenceExperimentSchedule] = Unassigned(), - description: Optional[str] = Unassigned(), - model_variants: Optional[List[ModelVariantConfig]] = Unassigned(), - data_storage_config: Optional[InferenceExperimentDataStorageConfig] = Unassigned(), - shadow_mode_config: Optional[ShadowModeConfig] = Unassigned(), - ) -> Optional["InferenceExperiment"]: + alias: Optional[str] = Unassigned(), + version: Optional[int] = Unassigned(), + aliases_to_add: Optional[List[str]] = Unassigned(), + aliases_to_delete: Optional[List[str]] = Unassigned(), + vendor_guidance: Optional[str] = Unassigned(), + job_type: Optional[str] = Unassigned(), + ml_framework: Optional[str] = Unassigned(), + programming_lang: Optional[str] = Unassigned(), + processor: Optional[str] = Unassigned(), + horovod: Optional[bool] = Unassigned(), + release_notes: Optional[str] = Unassigned(), + ) -> Optional["ImageVersion"]: """ - Update a InferenceExperiment resource + Update a ImageVersion resource + + Parameters: + alias: The alias of the image version. + aliases_to_add: A list of aliases to add. + aliases_to_delete: A list of aliases to delete. Returns: - The InferenceExperiment resource. + The ImageVersion resource. 
Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -14549,20 +14580,26 @@ def update( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceInUse: Resource being accessed is in use. ResourceNotFound: Resource being access is not found. """ - logger.info("Updating inference_experiment resource.") + logger.info("Updating image_version resource.") client = Base.get_sagemaker_client() operation_input_args = { - "Name": self.name, - "Schedule": schedule, - "Description": description, - "ModelVariants": model_variants, - "DataStorageConfig": data_storage_config, - "ShadowModeConfig": shadow_mode_config, + "ImageName": self.image_name, + "Alias": alias, + "Version": version, + "AliasesToAdd": aliases_to_add, + "AliasesToDelete": aliases_to_delete, + "VendorGuidance": vendor_guidance, + "JobType": job_type, + "MLFramework": ml_framework, + "ProgrammingLang": programming_lang, + "Processor": processor, + "Horovod": horovod, + "ReleaseNotes": release_notes, } logger.debug(f"Input request: {operation_input_args}") # serialize the input request @@ -14570,7 +14607,7 @@ def update( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.update_inference_experiment(**operation_input_args) + response = client.update_image_version(**operation_input_args) logger.debug(f"Response: {response}") self.refresh() @@ -14579,9 +14616,10 @@ def update( @Base.add_validate_call def delete( self, + alias: Optional[str] = Unassigned(), ) -> None: """ - Delete a InferenceExperiment resource + Delete a ImageVersion resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -14593,77 +14631,34 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceInUse: Resource being accessed is in use. ResourceNotFound: Resource being access is not found. """ client = Base.get_sagemaker_client() operation_input_args = { - "Name": self.name, + "ImageName": self.image_name, + "Version": self.version, + "Alias": alias, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_inference_experiment(**operation_input_args) + client.delete_image_version(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") - @Base.add_validate_call - def stop(self) -> None: - """ - Stop a InferenceExperiment resource - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. - ResourceNotFound: Resource being access is not found. 
- """ - - client = SageMakerClient().client - - operation_input_args = { - "Name": self.name, - "ModelVariantActions": self.model_variant_actions, - "DesiredModelVariants": self.desired_model_variants, - "DesiredState": self.desired_state, - "Reason": self.reason, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client.stop_inference_experiment(**operation_input_args) - - logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") - @Base.add_validate_call def wait_for_status( self, - target_status: Literal[ - "Creating", - "Created", - "Updating", - "Running", - "Starting", - "Stopping", - "Completed", - "Cancelled", - ], + target_status: Literal["CREATING", "CREATED", "CREATE_FAILED", "DELETING", "DELETE_FAILED"], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a InferenceExperiment resource to reach certain status. + Wait for a ImageVersion resource to reach certain status. Parameters: target_status: The status to wait for. @@ -14682,9 +14677,7 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task( - f"Waiting for InferenceExperiment to reach [bold]{target_status} status..." 
- ) + progress.add_task(f"Waiting for ImageVersion to reach [bold]{target_status} status...") status = Status("Current status:") with Live( @@ -14697,55 +14690,36 @@ def wait_for_status( ): while True: self.refresh() - current_status = self.status + current_status = self.image_version_status status.update(f"Current status: [bold]{current_status}") if target_status == current_status: logger.info(f"Final Resource Status: [bold]{current_status}") return - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="InferenceExperiment", status=current_status + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="ImageVersion", + status=current_status, + reason=self.failure_reason, ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="ImageVersion", status=current_status) time.sleep(poll) - @classmethod @Base.add_validate_call - def get_all( - cls, - name_contains: Optional[str] = Unassigned(), - type: Optional[str] = Unassigned(), - status_equals: Optional[str] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - last_modified_time_before: Optional[datetime.datetime] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> ResourceIterator["InferenceExperiment"]: + def wait_for_delete( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: """ - Get all InferenceExperiment resources + Wait for a ImageVersion resource to be deleted. Parameters: - name_contains: Selects inference experiments whose names contain this name. - type: Selects inference experiments of this type. 
For the possible types of inference experiments, see CreateInferenceExperiment. - status_equals: Selects inference experiments which are in this status. For the possible statuses, see DescribeInferenceExperiment. - creation_time_after: Selects inference experiments which were created after this timestamp. - creation_time_before: Selects inference experiments which were created before this timestamp. - last_modified_time_after: Selects inference experiments which were last modified after this timestamp. - last_modified_time_before: Selects inference experiments which were last modified before this timestamp. - sort_by: The column by which to sort the listed inference experiments. - sort_order: The direction of sorting (ascending or descending). - next_token: The response from the last list when returning a list large enough to need tokening. - max_results: The maximum number of results to select. - session: Boto3 session. - region: Region name. - - Returns: - Iterator for listed InferenceExperiment resources. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -14757,78 +14731,89 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + DeleteFailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
""" + start_time = time.time() - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), ) + progress.add_task("Waiting for ImageVersion to be deleted...") + status = Status("Current status:") - operation_input_args = { - "NameContains": name_contains, - "Type": type, - "StatusEquals": status_equals, - "CreationTimeAfter": creation_time_after, - "CreationTimeBefore": creation_time_before, - "LastModifiedTimeAfter": last_modified_time_after, - "LastModifiedTimeBefore": last_modified_time_before, - "SortBy": sort_by, - "SortOrder": sort_order, - } - - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ) + ): + while True: + try: + self.refresh() + current_status = self.image_version_status + status.update(f"Current status: [bold]{current_status}") - return ResourceIterator( - client=client, - list_method="list_inference_experiments", - summaries_key="InferenceExperiments", - summary_name="InferenceExperimentSummary", - resource_cls=InferenceExperiment, - list_method_kwargs=operation_input_args, - ) + if ( + "delete_failed" in current_status.lower() + or "deletefailed" in current_status.lower() + ): + raise DeleteFailedStatusError( + resource_type="ImageVersion", reason=self.failure_reason + ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="ImageVersion", status=current_status + ) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] + + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. 
It may have been deleted.") + return + raise e + time.sleep(poll) -class InferenceRecommendationsJob(Base): +class InferenceComponent(Base): """ - Class representing resource InferenceRecommendationsJob + Class representing resource InferenceComponent Attributes: - job_name: The name of the job. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. - job_type: The job type that you provided when you initiated the job. - job_arn: The Amazon Resource Name (ARN) of the job. - role_arn: The Amazon Resource Name (ARN) of the Amazon Web Services Identity and Access Management (IAM) role you provided when you initiated the job. - status: The status of the job. - creation_time: A timestamp that shows when the job was created. - last_modified_time: A timestamp that shows when the job was last modified. - input_config: Returns information about the versioned model package Amazon Resource Name (ARN), the traffic pattern, and endpoint configurations you provided when you initiated the job. - job_description: The job description that you provided when you initiated the job. - completion_time: A timestamp that shows when the job completed. - failure_reason: If the job fails, provides information why the job failed. - stopping_conditions: The stopping conditions that you provided when you initiated the job. - inference_recommendations: The recommendations made by Inference Recommender. - endpoint_performances: The performance results from running an Inference Recommender job on an existing endpoint. + inference_component_name: The name of the inference component. + inference_component_arn: The Amazon Resource Name (ARN) of the inference component. + endpoint_name: The name of the endpoint that hosts the inference component. + endpoint_arn: The Amazon Resource Name (ARN) of the endpoint that hosts the inference component. + creation_time: The time when the inference component was created. 
+ last_modified_time: The time when the inference component was last updated. + variant_name: The name of the production variant that hosts the inference component. + failure_reason: If the inference component status is Failed, the reason for the failure. + specification: Details about the resources that are deployed with this inference component. + runtime_config: Details about the runtime settings for the model that is deployed with the inference component. + inference_component_status: The status of the inference component. """ - job_name: str - job_description: Optional[str] = Unassigned() - job_type: Optional[str] = Unassigned() - job_arn: Optional[str] = Unassigned() - role_arn: Optional[str] = Unassigned() - status: Optional[str] = Unassigned() + inference_component_name: str + inference_component_arn: Optional[str] = Unassigned() + endpoint_name: Optional[str] = Unassigned() + endpoint_arn: Optional[str] = Unassigned() + variant_name: Optional[str] = Unassigned() + failure_reason: Optional[str] = Unassigned() + specification: Optional[InferenceComponentSpecificationSummary] = Unassigned() + runtime_config: Optional[InferenceComponentRuntimeConfigSummary] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - completion_time: Optional[datetime.datetime] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() - failure_reason: Optional[str] = Unassigned() - input_config: Optional[RecommendationJobInputConfig] = Unassigned() - stopping_conditions: Optional[RecommendationJobStoppingConditions] = Unassigned() - inference_recommendations: Optional[List[InferenceRecommendation]] = Unassigned() - endpoint_performances: Optional[List[EndpointPerformance]] = Unassigned() + inference_component_status: Optional[str] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "inference_recommendations_job_name" + resource_name = "inference_component_name" resource_name_split = resource_name.split("_") 
attribute_name_candidates = [] @@ -14839,64 +14824,37 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object inference_recommendations_job") + logger.error("Name attribute not found for object inference_component") return None - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = { - "role_arn": {"type": "string"}, - "input_config": { - "volume_kms_key_id": {"type": "string"}, - "vpc_config": { - "security_group_ids": {"type": "array", "items": {"type": "string"}}, - "subnets": {"type": "array", "items": {"type": "string"}}, - }, - }, - } - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "InferenceRecommendationsJob", **kwargs - ), - ) - - return wrapper - @classmethod - @populate_inputs_decorator @Base.add_validate_call def create( cls, - job_name: str, - job_type: str, - role_arn: str, - input_config: RecommendationJobInputConfig, - job_description: Optional[str] = Unassigned(), - stopping_conditions: Optional[RecommendationJobStoppingConditions] = Unassigned(), - output_config: Optional[RecommendationJobOutputConfig] = Unassigned(), + inference_component_name: str, + endpoint_name: Union[str, object], + specification: InferenceComponentSpecification, + variant_name: Optional[str] = Unassigned(), + runtime_config: Optional[InferenceComponentRuntimeConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["InferenceRecommendationsJob"]: + ) -> Optional["InferenceComponent"]: """ - Create a InferenceRecommendationsJob resource + Create a InferenceComponent resource Parameters: - job_name: A name for the recommendation job. 
The name must be unique within the Amazon Web Services Region and within your Amazon Web Services account. The job name is passed down to the resources created by the recommendation job. The names of resources (such as the model, endpoint configuration, endpoint, and compilation) that are prefixed with the job name are truncated at 40 characters. - job_type: Defines the type of recommendation job. Specify Default to initiate an instance recommendation and Advanced to initiate a load test. If left unspecified, Amazon SageMaker Inference Recommender will run an instance recommendation (DEFAULT) job. - role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to perform tasks on your behalf. - input_config: Provides information about the versioned model package Amazon Resource Name (ARN), the traffic pattern, and endpoint configurations. - job_description: Description of the recommendation job. - stopping_conditions: A set of conditions for stopping a recommendation job. If any of the conditions are met, the job is automatically stopped. - output_config: Provides information about the output artifacts and the KMS key to use for Amazon S3 server-side encryption. - tags: The metadata that you apply to Amazon Web Services resources to help you categorize and organize them. Each tag consists of a key and a value, both of which you define. For more information, see Tagging Amazon Web Services Resources in the Amazon Web Services General Reference. + inference_component_name: A unique name to assign to the inference component. + endpoint_name: The name of an existing endpoint where you host the inference component. + specification: Details about the resources to deploy with this inference component, including the model, container, and compute resources. + variant_name: The name of an existing production variant where you host the inference component. + runtime_config: Runtime settings for a model that is deployed with an inference component. 
+ tags: A list of key-value pairs associated with the model. For more information, see Tagging Amazon Web Services resources in the Amazon Web Services General Reference. session: Boto3 session. region: Region name. Returns: - The InferenceRecommendationsJob resource. + The InferenceComponent resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -14908,31 +14866,28 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating inference_recommendations_job resource.") + logger.info("Creating inference_component resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "JobName": job_name, - "JobType": job_type, - "RoleArn": role_arn, - "InputConfig": input_config, - "JobDescription": job_description, - "StoppingConditions": stopping_conditions, - "OutputConfig": output_config, + "InferenceComponentName": inference_component_name, + "EndpointName": endpoint_name, + "VariantName": variant_name, + "Specification": specification, + "RuntimeConfig": runtime_config, "Tags": tags, } operation_input_args = Base.populate_chained_attributes( - resource_name="InferenceRecommendationsJob", operation_input_args=operation_input_args + resource_name="InferenceComponent", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -14941,29 +14896,31 @@ def create( logger.debug(f"Serialized input 
request: {operation_input_args}") # create the resource - response = client.create_inference_recommendations_job(**operation_input_args) + response = client.create_inference_component(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(job_name=job_name, session=session, region=region) + return cls.get( + inference_component_name=inference_component_name, session=session, region=region + ) @classmethod @Base.add_validate_call def get( cls, - job_name: str, + inference_component_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["InferenceRecommendationsJob"]: + ) -> Optional["InferenceComponent"]: """ - Get a InferenceRecommendationsJob resource + Get a InferenceComponent resource Parameters: - job_name: The name of the job. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + inference_component_name: The name of the inference component. session: Boto3 session. region: Region name. Returns: - The InferenceRecommendationsJob resource. + The InferenceComponent resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -14975,11 +14932,10 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "JobName": job_name, + "InferenceComponentName": inference_component_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -14988,24 +14944,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_inference_recommendations_job(**operation_input_args) + response = client.describe_inference_component(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeInferenceRecommendationsJobResponse") - inference_recommendations_job = cls(**transformed_response) - return inference_recommendations_job + transformed_response = transform(response, "DescribeInferenceComponentOutput") + inference_component = cls(**transformed_response) + return inference_component @Base.add_validate_call def refresh( self, - ) -> Optional["InferenceRecommendationsJob"]: + ) -> Optional["InferenceComponent"]: """ - Refresh a InferenceRecommendationsJob resource + Refresh a InferenceComponent resource Returns: - The InferenceRecommendationsJob resource. + The InferenceComponent resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -15017,27 +14973,33 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "JobName": self.job_name, + "InferenceComponentName": self.inference_component_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_inference_recommendations_job(**operation_input_args) + response = client.describe_inference_component(**operation_input_args) # deserialize response and update self - transform(response, "DescribeInferenceRecommendationsJobResponse", self) + transform(response, "DescribeInferenceComponentOutput", self) return self @Base.add_validate_call - def stop(self) -> None: + def update( + self, + specification: Optional[InferenceComponentSpecification] = Unassigned(), + runtime_config: Optional[InferenceComponentRuntimeConfig] = Unassigned(), + ) -> Optional["InferenceComponent"]: """ - Stop a InferenceRecommendationsJob resource + Update a InferenceComponent resource + + Returns: + The InferenceComponent resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -15049,50 +15011,91 @@ def stop(self) -> None: error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
""" - client = SageMakerClient().client + logger.info("Updating inference_component resource.") + client = Base.get_sagemaker_client() operation_input_args = { - "JobName": self.job_name, + "InferenceComponentName": self.inference_component_name, + "Specification": specification, + "RuntimeConfig": runtime_config, } + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.stop_inference_recommendations_job(**operation_input_args) + # create the resource + response = client.update_inference_component(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() - logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") + return self @Base.add_validate_call - def wait( + def delete( self, - poll: int = 5, - timeout: Optional[int] = None, ) -> None: """ - Wait for a InferenceRecommendationsJob resource. - - Parameters: - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. + Delete a InferenceComponent resource Raises: - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - FailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. - - """ - terminal_states = ["COMPLETED", "FAILED", "STOPPED", "DELETED"] - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
+ The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + client = Base.get_sagemaker_client() + + operation_input_args = { + "InferenceComponentName": self.inference_component_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_inference_component(**operation_input_args) + + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + + @Base.add_validate_call + def wait_for_status( + self, + target_status: Literal["InService", "Creating", "Updating", "Failed", "Deleting"], + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a InferenceComponent resource to reach certain status. + + Parameters: + target_status: The status to wait for. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. + """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task("Waiting for InferenceRecommendationsJob...") + progress.add_task( + f"Waiting for InferenceComponent to reach [bold]{target_status} status..." 
+ ) status = Status("Current status:") with Live( @@ -15105,24 +15108,23 @@ def wait( ): while True: self.refresh() - current_status = self.status + current_status = self.inference_component_status status.update(f"Current status: [bold]{current_status}") - if current_status in terminal_states: + if target_status == current_status: logger.info(f"Final Resource Status: [bold]{current_status}") - - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="InferenceRecommendationsJob", - status=current_status, - reason=self.failure_reason, - ) - return + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="InferenceComponent", + status=current_status, + reason=self.failure_reason, + ) + if timeout is not None and time.time() - start_time >= timeout: raise TimeoutExceededError( - resouce_type="InferenceRecommendationsJob", status=current_status + resouce_type="InferenceComponent", status=current_status ) time.sleep(poll) @@ -15133,7 +15135,7 @@ def wait_for_delete( timeout: Optional[int] = None, ) -> None: """ - Wait for a InferenceRecommendationsJob resource to be deleted. + Wait for a InferenceComponent resource to be deleted. Parameters: poll: The number of seconds to wait between each poll. 
@@ -15160,7 +15162,7 @@ def wait_for_delete( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task("Waiting for InferenceRecommendationsJob to be deleted...") + progress.add_task("Waiting for InferenceComponent to be deleted...") status = Status("Current status:") with Live( @@ -15173,16 +15175,12 @@ def wait_for_delete( while True: try: self.refresh() - current_status = self.status + current_status = self.inference_component_status status.update(f"Current status: [bold]{current_status}") - if current_status.lower() == "deleted": - print("Resource was deleted.") - return - if timeout is not None and time.time() - start_time >= timeout: raise TimeoutExceededError( - resouce_type="InferenceRecommendationsJob", status=current_status + resouce_type="InferenceComponent", status=current_status ) except botocore.exceptions.ClientError as e: error_code = e.response["Error"]["Code"] @@ -15197,40 +15195,40 @@ def wait_for_delete( @Base.add_validate_call def get_all( cls, - creation_time_after: Optional[datetime.datetime] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + name_contains: Optional[str] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), - last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), last_modified_time_before: Optional[datetime.datetime] = Unassigned(), - name_contains: Optional[str] = Unassigned(), + last_modified_time_after: Optional[datetime.datetime] = Unassigned(), status_equals: Optional[str] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - model_name_equals: Optional[str] = Unassigned(), - model_package_version_arn_equals: Optional[str] = Unassigned(), + endpoint_name_equals: Optional[str] = Unassigned(), + variant_name_equals: Optional[str] = Unassigned(), session: Optional[Session] = None, region: 
Optional[str] = None, - ) -> ResourceIterator["InferenceRecommendationsJob"]: + ) -> ResourceIterator["InferenceComponent"]: """ - Get all InferenceRecommendationsJob resources + Get all InferenceComponent resources Parameters: - creation_time_after: A filter that returns only jobs created after the specified time (timestamp). - creation_time_before: A filter that returns only jobs created before the specified time (timestamp). - last_modified_time_after: A filter that returns only jobs that were last modified after the specified time (timestamp). - last_modified_time_before: A filter that returns only jobs that were last modified before the specified time (timestamp). - name_contains: A string in the job name. This filter returns only recommendations whose name contains the specified string. - status_equals: A filter that retrieves only inference recommendations jobs with a specific status. - sort_by: The parameter by which to sort the results. - sort_order: The sort order for the results. - next_token: If the response to a previous ListInferenceRecommendationsJobsRequest request was truncated, the response includes a NextToken. To retrieve the next set of recommendations, use the token in the next request. - max_results: The maximum number of recommendations to return in the response. - model_name_equals: A filter that returns only jobs that were created for this model. - model_package_version_arn_equals: A filter that returns only jobs that were created for this versioned model package. + sort_by: The field by which to sort the inference components in the response. The default is CreationTime. + sort_order: The sort order for results. The default is Descending. + next_token: A token that you use to get the next set of results following a truncated response. If the response to the previous request was truncated, that response provides the value for this token. + max_results: The maximum number of inference components to return in the response. 
This value defaults to 10. + name_contains: Filters the results to only those inference components with a name that contains the specified string. + creation_time_before: Filters the results to only those inference components that were created before the specified time. + creation_time_after: Filters the results to only those inference components that were created after the specified time. + last_modified_time_before: Filters the results to only those inference components that were updated before the specified time. + last_modified_time_after: Filters the results to only those inference components that were updated after the specified time. + status_equals: Filters the results to only those inference components with the specified status. + endpoint_name_equals: An endpoint name to filter the listed inference components. The response includes only those inference components that are hosted at the specified endpoint. + variant_name_equals: A production variant name to filter the listed inference components. The response includes only those inference components that are hosted at the specified variant. session: Boto3 session. region: Region name. Returns: - Iterator for listed InferenceRecommendationsJob resources. + Iterator for listed InferenceComponent resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -15249,16 +15247,16 @@ def get_all( ) operation_input_args = { - "CreationTimeAfter": creation_time_after, + "SortBy": sort_by, + "SortOrder": sort_order, + "NameContains": name_contains, "CreationTimeBefore": creation_time_before, - "LastModifiedTimeAfter": last_modified_time_after, + "CreationTimeAfter": creation_time_after, "LastModifiedTimeBefore": last_modified_time_before, - "NameContains": name_contains, + "LastModifiedTimeAfter": last_modified_time_after, "StatusEquals": status_equals, - "SortBy": sort_by, - "SortOrder": sort_order, - "ModelNameEquals": model_name_equals, - "ModelPackageVersionArnEquals": model_package_version_arn_equals, + "EndpointNameEquals": endpoint_name_equals, + "VariantNameEquals": variant_name_equals, } # serialize the input request @@ -15267,33 +15265,28 @@ def get_all( return ResourceIterator( client=client, - list_method="list_inference_recommendations_jobs", - summaries_key="InferenceRecommendationsJobs", - summary_name="InferenceRecommendationsJob", - resource_cls=InferenceRecommendationsJob, + list_method="list_inference_components", + summaries_key="InferenceComponents", + summary_name="InferenceComponentSummary", + resource_cls=InferenceComponent, list_method_kwargs=operation_input_args, ) @Base.add_validate_call - def get_all_steps( + def update_runtime_configs( self, - step_type: Optional[str] = Unassigned(), + desired_runtime_config: InferenceComponentRuntimeConfig, session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator[InferenceRecommendationsJobStep]: + ) -> None: """ - Returns a list of the subtasks for an Inference Recommender job. + Runtime settings for a model that is deployed with an inference component. Parameters: - step_type: A filter to return details about the specified type of subtask. BENCHMARK: Evaluate the performance of your model on different instance types. - max_results: The maximum number of results to return. 
- next_token: A token that you can specify to return more results from the list. Specify this field if you have a token that was returned from a previous request. + desired_runtime_config: Runtime settings for a model that is deployed with an inference component. session: Boto3 session. region: Region name. - Returns: - Iterator for listed InferenceRecommendationsJobStep. - Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. The error message and error code can be parsed from the exception as follows: @@ -15304,13 +15297,12 @@ def get_all_steps( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. """ operation_input_args = { - "JobName": self.job_name, - "Status": self.status, - "StepType": step_type, + "InferenceComponentName": self.inference_component_name, + "DesiredRuntimeConfig": desired_runtime_config, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -15320,64 +15312,55 @@ def get_all_steps( session=session, region_name=region, service_name="sagemaker" ) - return ResourceIterator( - client=client, - list_method="list_inference_recommendations_job_steps", - summaries_key="Steps", - summary_name="InferenceRecommendationsJobStep", - resource_cls=InferenceRecommendationsJobStep, - list_method_kwargs=operation_input_args, - ) + logger.debug(f"Calling update_inference_component_runtime_config API") + response = client.update_inference_component_runtime_config(**operation_input_args) + logger.debug(f"Response: {response}") -class LabelingJob(Base): +class InferenceExperiment(Base): """ - Class representing resource LabelingJob + Class representing resource InferenceExperiment Attributes: - labeling_job_status: The processing status of the labeling job. 
- label_counters: Provides a breakdown of the number of data objects labeled by humans, the number of objects labeled by machine, the number of objects than couldn't be labeled, and the total number of objects labeled. - creation_time: The date and time that the labeling job was created. - last_modified_time: The date and time that the labeling job was last updated. - job_reference_code: A unique identifier for work done as part of a labeling job. - labeling_job_name: The name assigned to the labeling job when it was created. - labeling_job_arn: The Amazon Resource Name (ARN) of the labeling job. - input_config: Input configuration information for the labeling job, such as the Amazon S3 location of the data objects and the location of the manifest file that describes the data objects. - output_config: The location of the job's output data and the Amazon Web Services Key Management Service key ID for the key used to encrypt the output data, if any. - role_arn: The Amazon Resource Name (ARN) that SageMaker assumes to perform tasks on your behalf during data labeling. - human_task_config: Configuration information required for human workers to complete a labeling task. - failure_reason: If the job failed, the reason that it failed. - label_attribute_name: The attribute used as the label in the output manifest file. - label_category_config_s3_uri: The S3 location of the JSON file that defines the categories used to label data objects. Please note the following label-category limits: Semantic segmentation labeling jobs using automated labeling: 20 labels Box bounding labeling jobs (all): 10 labels The file is a JSON structure in the following format: { "document-version": "2018-11-28" "labels": [ { "label": "label 1" }, { "label": "label 2" }, ... { "label": "label n" } ] } - stopping_conditions: A set of conditions for stopping a labeling job. If any of the conditions are met, the job is automatically stopped. 
- labeling_job_algorithms_config: Configuration information for automated data labeling. - tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. - labeling_job_output: The location of the output produced by the labeling job. + arn: The ARN of the inference experiment being described. + name: The name of the inference experiment. + type: The type of the inference experiment. + status: The status of the inference experiment. The following are the possible statuses for an inference experiment: Creating - Amazon SageMaker is creating your experiment. Created - Amazon SageMaker has finished the creation of your experiment and will begin the experiment at the scheduled time. Updating - When you make changes to your experiment, your experiment shows as updating. Starting - Amazon SageMaker is beginning your experiment. Running - Your experiment is in progress. Stopping - Amazon SageMaker is stopping your experiment. Completed - Your experiment has completed. Cancelled - When you conclude your experiment early using the StopInferenceExperiment API, or if any operation fails with an unexpected error, it shows as cancelled. + endpoint_metadata: The metadata of the endpoint on which the inference experiment ran. + model_variants: An array of ModelVariantConfigSummary objects. There is one for each variant in the inference experiment. Each ModelVariantConfigSummary object in the array describes the infrastructure configuration for deploying the corresponding variant. + schedule: The duration for which the inference experiment ran or will run. + status_reason: The error message or client-specified Reason from the StopInferenceExperiment API, that explains the status of the inference experiment. + description: The description of the inference experiment. 
+ creation_time: The timestamp at which you created the inference experiment. + completion_time: The timestamp at which the inference experiment was completed. + last_modified_time: The timestamp at which you last modified the inference experiment. + role_arn: The ARN of the IAM role that Amazon SageMaker can assume to access model artifacts and container images, and manage Amazon SageMaker Inference endpoints for model deployment. + data_storage_config: The Amazon S3 location and configuration for storing inference request and response data. + shadow_mode_config: The configuration of ShadowMode inference experiment type, which shows the production variant that takes all the inference requests, and the shadow variant to which Amazon SageMaker replicates a percentage of the inference requests. For the shadow variant it also shows the percentage of requests that Amazon SageMaker replicates. + kms_key: The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance that hosts the endpoint. For more information, see CreateInferenceExperiment. 
""" - labeling_job_name: str - labeling_job_status: Optional[str] = Unassigned() - label_counters: Optional[LabelCounters] = Unassigned() - failure_reason: Optional[str] = Unassigned() + name: str + arn: Optional[str] = Unassigned() + type: Optional[str] = Unassigned() + schedule: Optional[InferenceExperimentSchedule] = Unassigned() + status: Optional[str] = Unassigned() + status_reason: Optional[str] = Unassigned() + description: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() + completion_time: Optional[datetime.datetime] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() - job_reference_code: Optional[str] = Unassigned() - labeling_job_arn: Optional[str] = Unassigned() - label_attribute_name: Optional[str] = Unassigned() - input_config: Optional[LabelingJobInputConfig] = Unassigned() - output_config: Optional[LabelingJobOutputConfig] = Unassigned() role_arn: Optional[str] = Unassigned() - label_category_config_s3_uri: Optional[str] = Unassigned() - stopping_conditions: Optional[LabelingJobStoppingConditions] = Unassigned() - labeling_job_algorithms_config: Optional[LabelingJobAlgorithmsConfig] = Unassigned() - human_task_config: Optional[HumanTaskConfig] = Unassigned() - tags: Optional[List[Tag]] = Unassigned() - labeling_job_output: Optional[LabelingJobOutput] = Unassigned() + endpoint_metadata: Optional[EndpointMetadata] = Unassigned() + model_variants: Optional[List[ModelVariantConfigSummary]] = Unassigned() + data_storage_config: Optional[InferenceExperimentDataStorageConfig] = Unassigned() + shadow_mode_config: Optional[ShadowModeConfig] = Unassigned() + kms_key: Optional[str] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "labeling_job_name" + resource_name = "inference_experiment_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -15388,38 +15371,21 @@ def get_name(self) -> str: for attribute, value in 
attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object labeling_job") + logger.error("Name attribute not found for object inference_experiment") return None def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): config_schema_for_resource = { - "input_config": { - "data_source": {"s3_data_source": {"manifest_s3_uri": {"type": "string"}}} - }, - "output_config": { - "s3_output_path": {"type": "string"}, - "kms_key_id": {"type": "string"}, - }, "role_arn": {"type": "string"}, - "human_task_config": {"ui_config": {"ui_template_s3_uri": {"type": "string"}}}, - "label_category_config_s3_uri": {"type": "string"}, - "labeling_job_algorithms_config": { - "labeling_job_resource_config": { - "volume_kms_key_id": {"type": "string"}, - "vpc_config": { - "security_group_ids": {"type": "array", "items": {"type": "string"}}, - "subnets": {"type": "array", "items": {"type": "string"}}, - }, - } - }, - "labeling_job_output": {"output_dataset_s3_uri": {"type": "string"}}, + "data_storage_config": {"kms_key": {"type": "string"}}, + "kms_key": {"type": "string"}, } return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "LabelingJob", **kwargs + config_schema_for_resource, "InferenceExperiment", **kwargs ), ) @@ -15430,38 +15396,40 @@ def wrapper(*args, **kwargs): @Base.add_validate_call def create( cls, - labeling_job_name: str, - label_attribute_name: str, - input_config: LabelingJobInputConfig, - output_config: LabelingJobOutputConfig, + name: str, + type: str, role_arn: str, - human_task_config: HumanTaskConfig, - label_category_config_s3_uri: Optional[str] = Unassigned(), - stopping_conditions: Optional[LabelingJobStoppingConditions] = Unassigned(), - labeling_job_algorithms_config: Optional[LabelingJobAlgorithmsConfig] = Unassigned(), + endpoint_name: Union[str, object], + 
model_variants: List[ModelVariantConfig], + shadow_mode_config: ShadowModeConfig, + schedule: Optional[InferenceExperimentSchedule] = Unassigned(), + description: Optional[str] = Unassigned(), + data_storage_config: Optional[InferenceExperimentDataStorageConfig] = Unassigned(), + kms_key: Optional[str] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["LabelingJob"]: + ) -> Optional["InferenceExperiment"]: """ - Create a LabelingJob resource + Create a InferenceExperiment resource Parameters: - labeling_job_name: The name of the labeling job. This name is used to identify the job in a list of labeling jobs. Labeling job names must be unique within an Amazon Web Services account and region. LabelingJobName is not case sensitive. For example, Example-job and example-job are considered the same labeling job name by Ground Truth. - label_attribute_name: The attribute name to use for the label in the output manifest file. This is the key for the key/value pair formed with the label that a worker assigns to the object. The LabelAttributeName must meet the following requirements. The name can't end with "-metadata". If you are using one of the following built-in task types, the attribute name must end with "-ref". If the task type you are using is not listed below, the attribute name must not end with "-ref". Image semantic segmentation (SemanticSegmentation), and adjustment (AdjustmentSemanticSegmentation) and verification (VerificationSemanticSegmentation) labeling jobs for this task type. Video frame object detection (VideoObjectDetection), and adjustment and verification (AdjustmentVideoObjectDetection) labeling jobs for this task type. Video frame object tracking (VideoObjectTracking), and adjustment and verification (AdjustmentVideoObjectTracking) labeling jobs for this task type. 
3D point cloud semantic segmentation (3DPointCloudSemanticSegmentation), and adjustment and verification (Adjustment3DPointCloudSemanticSegmentation) labeling jobs for this task type. 3D point cloud object tracking (3DPointCloudObjectTracking), and adjustment and verification (Adjustment3DPointCloudObjectTracking) labeling jobs for this task type. If you are creating an adjustment or verification labeling job, you must use a different LabelAttributeName than the one used in the original labeling job. The original labeling job is the Ground Truth labeling job that produced the labels that you want verified or adjusted. To learn more about adjustment and verification labeling jobs, see Verify and Adjust Labels. - input_config: Input data for the labeling job, such as the Amazon S3 location of the data objects and the location of the manifest file that describes the data objects. You must specify at least one of the following: S3DataSource or SnsDataSource. Use SnsDataSource to specify an SNS input topic for a streaming labeling job. If you do not specify and SNS input topic ARN, Ground Truth will create a one-time labeling job that stops after all data objects in the input manifest file have been labeled. Use S3DataSource to specify an input manifest file for both streaming and one-time labeling jobs. Adding an S3DataSource is optional if you use SnsDataSource to create a streaming labeling job. If you use the Amazon Mechanical Turk workforce, your input data should not include confidential information, personal information or protected health information. Use ContentClassifiers to specify that your data is free of personally identifiable information and adult content. - output_config: The location of the output data and the Amazon Web Services Key Management Service key ID for the key used to encrypt the output data, if any. - role_arn: The Amazon Resource Number (ARN) that Amazon SageMaker assumes to perform tasks on your behalf during data labeling. 
You must grant this role the necessary permissions so that Amazon SageMaker can successfully complete data labeling. - human_task_config: Configures the labeling task and how it is presented to workers; including, but not limited to price, keywords, and batch size (task count). - label_category_config_s3_uri: The S3 URI of the file, referred to as a label category configuration file, that defines the categories used to label the data objects. For 3D point cloud and video frame task types, you can add label category attributes and frame attributes to your label category configuration file. To learn how, see Create a Labeling Category Configuration File for 3D Point Cloud Labeling Jobs. For named entity recognition jobs, in addition to "labels", you must provide worker instructions in the label category configuration file using the "instructions" parameter: "instructions": {"shortInstruction":"<h1>Add header</h1><p>Add Instructions</p>", "fullInstruction":"<p>Add additional instructions.</p>"}. For details and an example, see Create a Named Entity Recognition Labeling Job (API) . For all other built-in task types and custom tasks, your label category configuration file must be a JSON file in the following format. Identify the labels you want to use by replacing label_1, label_2,...,label_n with your label categories. { "document-version": "2018-11-28", "labels": [{"label": "label_1"},{"label": "label_2"},...{"label": "label_n"}] } Note the following about the label category configuration file: For image classification and text classification (single and multi-label) you must specify at least two label categories. For all other task types, the minimum number of label categories required is one. Each label category must be unique, you cannot specify duplicate label categories. If you create a 3D point cloud or video frame adjustment or verification labeling job, you must include auditLabelAttributeName in the label category configuration. 
Use this parameter to enter the LabelAttributeName of the labeling job you want to adjust or verify annotations of. - stopping_conditions: A set of conditions for stopping the labeling job. If any of the conditions are met, the job is automatically stopped. You can use these conditions to control the cost of data labeling. - labeling_job_algorithms_config: Configures the information required to perform automated data labeling. - tags: An array of key/value pairs. For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide. + name: The name for the inference experiment. + type: The type of the inference experiment that you want to run. The following types of experiments are possible: ShadowMode: You can use this type to validate a shadow variant. For more information, see Shadow tests. + role_arn: The ARN of the IAM role that Amazon SageMaker can assume to access model artifacts and container images, and manage Amazon SageMaker Inference endpoints for model deployment. + endpoint_name: The name of the Amazon SageMaker endpoint on which you want to run the inference experiment. + model_variants: An array of ModelVariantConfig objects. There is one for each variant in the inference experiment. Each ModelVariantConfig object in the array describes the infrastructure configuration for the corresponding variant. + shadow_mode_config: The configuration of ShadowMode inference experiment type. Use this field to specify a production variant which takes all the inference requests, and a shadow variant to which Amazon SageMaker replicates a percentage of the inference requests. For the shadow variant also specify the percentage of requests that Amazon SageMaker replicates. + schedule: The duration for which you want the inference experiment to run. If you don't specify this field, the experiment automatically starts immediately upon creation and concludes after 7 days. 
+ description: A description for the inference experiment. + data_storage_config: The Amazon S3 location and configuration for storing inference request and response data. This is an optional parameter that you can use for data capture. For more information, see Capture data. + kms_key: The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt data on the storage volume attached to the ML compute instance that hosts the endpoint. The KmsKey can be any of the following formats: KMS key ID "1234abcd-12ab-34cd-56ef-1234567890ab" Amazon Resource Name (ARN) of a KMS key "arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab" KMS key Alias "alias/ExampleAlias" Amazon Resource Name (ARN) of a KMS key Alias "arn:aws:kms:us-west-2:111122223333:alias/ExampleAlias" If you use a KMS key ID or an alias of your KMS key, the Amazon SageMaker execution role must include permissions to call kms:Encrypt. If you don't provide a KMS key ID, Amazon SageMaker uses the default KMS key for Amazon S3 for your role's account. Amazon SageMaker uses server-side encryption with KMS managed keys for OutputDataConfig. If you use a bucket policy with an s3:PutObject permission that only allows objects with server-side encryption, set the condition key of s3:x-amz-server-side-encryption to "aws:kms". For more information, see KMS managed Encryption Keys in the Amazon Simple Storage Service Developer Guide. The KMS key policy must grant permission to the IAM role that you specify in your CreateEndpoint and UpdateEndpoint requests. For more information, see Using Key Policies in Amazon Web Services KMS in the Amazon Web Services Key Management Service Developer Guide. + tags: Array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging your Amazon Web Services Resources. 
session: Boto3 session. region: Region name. Returns: - The LabelingJob resource. + The InferenceExperiment resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -15480,26 +15448,27 @@ def create( S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating labeling_job resource.") + logger.info("Creating inference_experiment resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "LabelingJobName": labeling_job_name, - "LabelAttributeName": label_attribute_name, - "InputConfig": input_config, - "OutputConfig": output_config, + "Name": name, + "Type": type, + "Schedule": schedule, + "Description": description, "RoleArn": role_arn, - "LabelCategoryConfigS3Uri": label_category_config_s3_uri, - "StoppingConditions": stopping_conditions, - "LabelingJobAlgorithmsConfig": labeling_job_algorithms_config, - "HumanTaskConfig": human_task_config, + "EndpointName": endpoint_name, + "ModelVariants": model_variants, + "DataStorageConfig": data_storage_config, + "ShadowModeConfig": shadow_mode_config, + "KmsKey": kms_key, "Tags": tags, } operation_input_args = Base.populate_chained_attributes( - resource_name="LabelingJob", operation_input_args=operation_input_args + resource_name="InferenceExperiment", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -15508,29 +15477,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_labeling_job(**operation_input_args) + response = client.create_inference_experiment(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(labeling_job_name=labeling_job_name, session=session, region=region) + return cls.get(name=name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - 
labeling_job_name: str, + name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["LabelingJob"]: + ) -> Optional["InferenceExperiment"]: """ - Get a LabelingJob resource + Get a InferenceExperiment resource Parameters: - labeling_job_name: The name of the labeling job to return information for. + name: The name of the inference experiment to describe. session: Boto3 session. region: Region name. Returns: - The LabelingJob resource. + The InferenceExperiment resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -15546,7 +15515,7 @@ def get( """ operation_input_args = { - "LabelingJobName": labeling_job_name, + "Name": name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -15555,24 +15524,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_labeling_job(**operation_input_args) + response = client.describe_inference_experiment(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeLabelingJobResponse") - labeling_job = cls(**transformed_response) - return labeling_job + transformed_response = transform(response, "DescribeInferenceExperimentResponse") + inference_experiment = cls(**transformed_response) + return inference_experiment @Base.add_validate_call def refresh( self, - ) -> Optional["LabelingJob"]: + ) -> Optional["InferenceExperiment"]: """ - Refresh a LabelingJob resource + Refresh a InferenceExperiment resource Returns: - The LabelingJob resource. + The InferenceExperiment resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -15588,23 +15557,110 @@ def refresh( """ operation_input_args = { - "LabelingJobName": self.labeling_job_name, + "Name": self.name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_labeling_job(**operation_input_args) + response = client.describe_inference_experiment(**operation_input_args) # deserialize response and update self - transform(response, "DescribeLabelingJobResponse", self) + transform(response, "DescribeInferenceExperimentResponse", self) + return self + + @populate_inputs_decorator + @Base.add_validate_call + def update( + self, + schedule: Optional[InferenceExperimentSchedule] = Unassigned(), + description: Optional[str] = Unassigned(), + model_variants: Optional[List[ModelVariantConfig]] = Unassigned(), + data_storage_config: Optional[InferenceExperimentDataStorageConfig] = Unassigned(), + shadow_mode_config: Optional[ShadowModeConfig] = Unassigned(), + ) -> Optional["InferenceExperiment"]: + """ + Update a InferenceExperiment resource + + Returns: + The InferenceExperiment resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceNotFound: Resource being access is not found. 
+ """ + + logger.info("Updating inference_experiment resource.") + client = Base.get_sagemaker_client() + + operation_input_args = { + "Name": self.name, + "Schedule": schedule, + "Description": description, + "ModelVariants": model_variants, + "DataStorageConfig": data_storage_config, + "ShadowModeConfig": shadow_mode_config, + } + logger.debug(f"Input request: {operation_input_args}") + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + # create the resource + response = client.update_inference_experiment(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() + return self + @Base.add_validate_call + def delete( + self, + ) -> None: + """ + Delete a InferenceExperiment resource + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceNotFound: Resource being access is not found. + """ + + client = Base.get_sagemaker_client() + + operation_input_args = { + "Name": self.name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_inference_experiment(**operation_input_args) + + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + @Base.add_validate_call def stop(self) -> None: """ - Stop a LabelingJob resource + Stop a InferenceExperiment resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -15616,32 +15672,48 @@ def stop(self) -> None: error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. ResourceNotFound: Resource being access is not found. """ client = SageMakerClient().client operation_input_args = { - "LabelingJobName": self.labeling_job_name, + "Name": self.name, + "ModelVariantActions": self.model_variant_actions, + "DesiredModelVariants": self.desired_model_variants, + "DesiredState": self.desired_state, + "Reason": self.reason, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.stop_labeling_job(**operation_input_args) + client.stop_inference_experiment(**operation_input_args) logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call - def wait( + def wait_for_status( self, + target_status: Literal[ + "Creating", + "Created", + "Updating", + "Running", + "Starting", + "Stopping", + "Completed", + "Cancelled", + ], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a LabelingJob resource. + Wait for a InferenceExperiment resource to reach certain status. Parameters: + target_status: The status to wait for. poll: The number of seconds to wait between each poll. timeout: The maximum number of seconds to wait before timing out. @@ -15649,9 +15721,7 @@ def wait( TimeoutExceededError: If the resource does not reach a terminal state before the timeout. FailedStatusError: If the resource reaches a failed state. WaiterError: Raised when an error occurs while waiting. 
- """ - terminal_states = ["Completed", "Failed", "Stopped"] start_time = time.time() progress = Progress( @@ -15659,7 +15729,9 @@ def wait( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task("Waiting for LabelingJob...") + progress.add_task( + f"Waiting for InferenceExperiment to reach [bold]{target_status} status..." + ) status = Status("Current status:") with Live( @@ -15672,59 +15744,55 @@ def wait( ): while True: self.refresh() - current_status = self.labeling_job_status + current_status = self.status status.update(f"Current status: [bold]{current_status}") - if current_status in terminal_states: + if target_status == current_status: logger.info(f"Final Resource Status: [bold]{current_status}") - - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="LabelingJob", - status=current_status, - reason=self.failure_reason, - ) - return if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="LabelingJob", status=current_status) - time.sleep(poll) - + raise TimeoutExceededError( + resouce_type="InferenceExperiment", status=current_status + ) + time.sleep(poll) + @classmethod @Base.add_validate_call def get_all( cls, + name_contains: Optional[str] = Unassigned(), + type: Optional[str] = Unassigned(), + status_equals: Optional[str] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), last_modified_time_after: Optional[datetime.datetime] = Unassigned(), last_modified_time_before: Optional[datetime.datetime] = Unassigned(), - name_contains: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), - status_equals: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["LabelingJob"]: + ) -> ResourceIterator["InferenceExperiment"]: """ - Get all LabelingJob 
resources + Get all InferenceExperiment resources Parameters: - creation_time_after: A filter that returns only labeling jobs created after the specified time (timestamp). - creation_time_before: A filter that returns only labeling jobs created before the specified time (timestamp). - last_modified_time_after: A filter that returns only labeling jobs modified after the specified time (timestamp). - last_modified_time_before: A filter that returns only labeling jobs modified before the specified time (timestamp). - max_results: The maximum number of labeling jobs to return in each page of the response. - next_token: If the result of the previous ListLabelingJobs request was truncated, the response includes a NextToken. To retrieve the next set of labeling jobs, use the token in the next request. - name_contains: A string in the labeling job name. This filter returns only labeling jobs whose name contains the specified string. - sort_by: The field to sort results by. The default is CreationTime. - sort_order: The sort order for results. The default is Ascending. - status_equals: A filter that retrieves only labeling jobs with a specific status. + name_contains: Selects inference experiments whose names contain this name. + type: Selects inference experiments of this type. For the possible types of inference experiments, see CreateInferenceExperiment. + status_equals: Selects inference experiments which are in this status. For the possible statuses, see DescribeInferenceExperiment. + creation_time_after: Selects inference experiments which were created after this timestamp. + creation_time_before: Selects inference experiments which were created before this timestamp. + last_modified_time_after: Selects inference experiments which were last modified after this timestamp. + last_modified_time_before: Selects inference experiments which were last modified before this timestamp. + sort_by: The column by which to sort the listed inference experiments. 
+ sort_order: The direction of sorting (ascending or descending). + next_token: The response from the last list when returning a list large enough to need tokening. + max_results: The maximum number of results to select. session: Boto3 session. region: Region name. Returns: - Iterator for listed LabelingJob resources. + Iterator for listed InferenceExperiment resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -15743,14 +15811,15 @@ def get_all( ) operation_input_args = { + "NameContains": name_contains, + "Type": type, + "StatusEquals": status_equals, "CreationTimeAfter": creation_time_after, "CreationTimeBefore": creation_time_before, "LastModifiedTimeAfter": last_modified_time_after, "LastModifiedTimeBefore": last_modified_time_before, - "NameContains": name_contains, "SortBy": sort_by, "SortOrder": sort_order, - "StatusEquals": status_equals, } # serialize the input request @@ -15759,42 +15828,54 @@ def get_all( return ResourceIterator( client=client, - list_method="list_labeling_jobs", - summaries_key="LabelingJobSummaryList", - summary_name="LabelingJobSummary", - resource_cls=LabelingJob, + list_method="list_inference_experiments", + summaries_key="InferenceExperiments", + summary_name="InferenceExperimentSummary", + resource_cls=InferenceExperiment, list_method_kwargs=operation_input_args, ) -class LineageGroup(Base): +class InferenceRecommendationsJob(Base): """ - Class representing resource LineageGroup + Class representing resource InferenceRecommendationsJob Attributes: - lineage_group_name: The name of the lineage group. - lineage_group_arn: The Amazon Resource Name (ARN) of the lineage group. - display_name: The display name of the lineage group. - description: The description of the lineage group. - creation_time: The creation time of lineage group. - created_by: - last_modified_time: The last modified time of the lineage group. - last_modified_by: + job_name: The name of the job. 
The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + job_type: The job type that you provided when you initiated the job. + job_arn: The Amazon Resource Name (ARN) of the job. + role_arn: The Amazon Resource Name (ARN) of the Amazon Web Services Identity and Access Management (IAM) role you provided when you initiated the job. + status: The status of the job. + creation_time: A timestamp that shows when the job was created. + last_modified_time: A timestamp that shows when the job was last modified. + input_config: Returns information about the versioned model package Amazon Resource Name (ARN), the traffic pattern, and endpoint configurations you provided when you initiated the job. + job_description: The job description that you provided when you initiated the job. + completion_time: A timestamp that shows when the job completed. + failure_reason: If the job fails, provides information why the job failed. + stopping_conditions: The stopping conditions that you provided when you initiated the job. + inference_recommendations: The recommendations made by Inference Recommender. + endpoint_performances: The performance results from running an Inference Recommender job on an existing endpoint. 
""" - lineage_group_name: str - lineage_group_arn: Optional[str] = Unassigned() - display_name: Optional[str] = Unassigned() - description: Optional[str] = Unassigned() + job_name: str + job_description: Optional[str] = Unassigned() + job_type: Optional[str] = Unassigned() + job_arn: Optional[str] = Unassigned() + role_arn: Optional[str] = Unassigned() + status: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - created_by: Optional[UserContext] = Unassigned() + completion_time: Optional[datetime.datetime] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() - last_modified_by: Optional[UserContext] = Unassigned() + failure_reason: Optional[str] = Unassigned() + input_config: Optional[RecommendationJobInputConfig] = Unassigned() + stopping_conditions: Optional[RecommendationJobStoppingConditions] = Unassigned() + inference_recommendations: Optional[List[InferenceRecommendation]] = Unassigned() + endpoint_performances: Optional[List[EndpointPerformance]] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "lineage_group_name" + resource_name = "inference_recommendations_job_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -15805,27 +15886,64 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object lineage_group") + logger.error("Name attribute not found for object inference_recommendations_job") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "role_arn": {"type": "string"}, + "input_config": { + "volume_kms_key_id": {"type": "string"}, + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, 
+ }, + }, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "InferenceRecommendationsJob", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call - def get( + def create( cls, - lineage_group_name: str, + job_name: str, + job_type: str, + role_arn: str, + input_config: RecommendationJobInputConfig, + job_description: Optional[str] = Unassigned(), + stopping_conditions: Optional[RecommendationJobStoppingConditions] = Unassigned(), + output_config: Optional[RecommendationJobOutputConfig] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["LineageGroup"]: + ) -> Optional["InferenceRecommendationsJob"]: """ - Get a LineageGroup resource + Create a InferenceRecommendationsJob resource Parameters: - lineage_group_name: The name of the lineage group. + job_name: A name for the recommendation job. The name must be unique within the Amazon Web Services Region and within your Amazon Web Services account. The job name is passed down to the resources created by the recommendation job. The names of resources (such as the model, endpoint configuration, endpoint, and compilation) that are prefixed with the job name are truncated at 40 characters. + job_type: Defines the type of recommendation job. Specify Default to initiate an instance recommendation and Advanced to initiate a load test. If left unspecified, Amazon SageMaker Inference Recommender will run an instance recommendation (DEFAULT) job. + role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to perform tasks on your behalf. + input_config: Provides information about the versioned model package Amazon Resource Name (ARN), the traffic pattern, and endpoint configurations. + job_description: Description of the recommendation job. 
+ stopping_conditions: A set of conditions for stopping a recommendation job. If any of the conditions are met, the job is automatically stopped. + output_config: Provides information about the output artifacts and the KMS key to use for Amazon S3 server-side encryption. + tags: The metadata that you apply to Amazon Web Services resources to help you categorize and organize them. Each tag consists of a key and a value, both of which you define. For more information, see Tagging Amazon Web Services Resources in the Amazon Web Services General Reference. session: Boto3 session. region: Region name. Returns: - The LineageGroup resource. + The InferenceRecommendationsJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -15837,37 +15955,62 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
+ ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ + logger.info("Creating inference_recommendations_job resource.") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + operation_input_args = { - "LineageGroupName": lineage_group_name, + "JobName": job_name, + "JobType": job_type, + "RoleArn": role_arn, + "InputConfig": input_config, + "JobDescription": job_description, + "StoppingConditions": stopping_conditions, + "OutputConfig": output_config, + "Tags": tags, } + + operation_input_args = Base.populate_chained_attributes( + resource_name="InferenceRecommendationsJob", operation_input_args=operation_input_args + ) + + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - response = client.describe_lineage_group(**operation_input_args) - - logger.debug(response) + # create the resource + response = client.create_inference_recommendations_job(**operation_input_args) + logger.debug(f"Response: {response}") - # deserialize the response - transformed_response = transform(response, "DescribeLineageGroupResponse") - lineage_group = cls(**transformed_response) - return lineage_group + return cls.get(job_name=job_name, session=session, region=region) + @classmethod @Base.add_validate_call - def refresh( - self, - ) -> Optional["LineageGroup"]: + def get( + cls, + job_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["InferenceRecommendationsJob"]: """ - Refresh a LineageGroup resource + Get a 
InferenceRecommendationsJob resource + + Parameters: + job_name: The name of the job. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + session: Boto3 session. + region: Region name. Returns: - The LineageGroup resource. + The InferenceRecommendationsJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -15883,45 +16026,33 @@ def refresh( """ operation_input_args = { - "LineageGroupName": self.lineage_group_name, + "JobName": job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() - response = client.describe_lineage_group(**operation_input_args) + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + response = client.describe_inference_recommendations_job(**operation_input_args) - # deserialize response and update self - transform(response, "DescribeLineageGroupResponse", self) - return self + logger.debug(response) + + # deserialize the response + transformed_response = transform(response, "DescribeInferenceRecommendationsJobResponse") + inference_recommendations_job = cls(**transformed_response) + return inference_recommendations_job - @classmethod @Base.add_validate_call - def get_all( - cls, - created_after: Optional[datetime.datetime] = Unassigned(), - created_before: Optional[datetime.datetime] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> ResourceIterator["LineageGroup"]: + def refresh( + self, + ) -> Optional["InferenceRecommendationsJob"]: """ - Get all LineageGroup resources - - Parameters: - created_after: A timestamp to filter against lineage groups created after a certain point in time. 
- created_before: A timestamp to filter against lineage groups created before a certain point in time. - sort_by: The parameter by which to sort the results. The default is CreationTime. - sort_order: The sort order for the results. The default is Ascending. - next_token: If the response is truncated, SageMaker returns this token. To retrieve the next set of algorithms, use it in the subsequent request. - max_results: The maximum number of endpoints to return in the response. This value defaults to 10. - session: Boto3 session. - region: Region name. + Refresh a InferenceRecommendationsJob resource Returns: - Iterator for listed LineageGroup resources. + The InferenceRecommendationsJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -15933,47 +16064,27 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - operation_input_args = { - "CreatedAfter": created_after, - "CreatedBefore": created_before, - "SortBy": sort_by, - "SortOrder": sort_order, + "JobName": self.job_name, } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - return ResourceIterator( - client=client, - list_method="list_lineage_groups", - summaries_key="LineageGroupSummaries", - summary_name="LineageGroupSummary", - resource_cls=LineageGroup, - list_method_kwargs=operation_input_args, - ) + client = Base.get_sagemaker_client() + response = client.describe_inference_recommendations_job(**operation_input_args) + + # deserialize response and update self + transform(response, "DescribeInferenceRecommendationsJobResponse", self) + return self @Base.add_validate_call - def get_policy( - self, - session: Optional[Session] = 
None, - region: Optional[str] = None, - ) -> Optional[GetLineageGroupPolicyResponse]: + def stop(self) -> None: """ - The resource policy for the lineage group. - - Parameters: - session: Boto3 session. - region: Region name. - - Returns: - GetLineageGroupPolicyResponse + Stop a InferenceRecommendationsJob resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -15988,126 +16099,92 @@ def get_policy( ResourceNotFound: Resource being access is not found. """ + client = SageMakerClient().client + operation_input_args = { - "LineageGroupName": self.lineage_group_name, + "JobName": self.job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - - logger.debug(f"Calling get_lineage_group_policy API") - response = client.get_lineage_group_policy(**operation_input_args) - logger.debug(f"Response: {response}") + client.stop_inference_recommendations_job(**operation_input_args) - transformed_response = transform(response, "GetLineageGroupPolicyResponse") - return GetLineageGroupPolicyResponse(**transformed_response) + logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") + @Base.add_validate_call + def wait( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a InferenceRecommendationsJob resource. -class MlflowTrackingServer(Base): - """ - Class representing resource MlflowTrackingServer + Parameters: + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. - Attributes: - tracking_server_arn: The ARN of the described tracking server. - tracking_server_name: The name of the described tracking server. 
- artifact_store_uri: The S3 URI of the general purpose bucket used as the MLflow Tracking Server artifact store. - tracking_server_size: The size of the described tracking server. - mlflow_version: The MLflow version used for the described tracking server. - role_arn: The Amazon Resource Name (ARN) for an IAM role in your account that the described MLflow Tracking Server uses to access the artifact store in Amazon S3. - tracking_server_status: The current creation status of the described MLflow Tracking Server. - is_active: Whether the described MLflow Tracking Server is currently active. - tracking_server_url: The URL to connect to the MLflow user interface for the described tracking server. - weekly_maintenance_window_start: The day and time of the week when weekly maintenance occurs on the described tracking server. - automatic_model_registration: Whether automatic registration of new MLflow models to the SageMaker Model Registry is enabled. - creation_time: The timestamp of when the described MLflow Tracking Server was created. - created_by: - last_modified_time: The timestamp of when the described MLflow Tracking Server was last modified. - last_modified_by: + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
- """ + """ + terminal_states = ["COMPLETED", "FAILED", "STOPPED", "DELETED"] + start_time = time.time() - tracking_server_name: str - tracking_server_arn: Optional[str] = Unassigned() - artifact_store_uri: Optional[str] = Unassigned() - tracking_server_size: Optional[str] = Unassigned() - mlflow_version: Optional[str] = Unassigned() - role_arn: Optional[str] = Unassigned() - tracking_server_status: Optional[str] = Unassigned() - is_active: Optional[str] = Unassigned() - tracking_server_url: Optional[str] = Unassigned() - weekly_maintenance_window_start: Optional[str] = Unassigned() - automatic_model_registration: Optional[bool] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - created_by: Optional[UserContext] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - last_modified_by: Optional[UserContext] = Unassigned() + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task("Waiting for InferenceRecommendationsJob...") + status = Status("Current status:") - def get_name(self) -> str: - attributes = vars(self) - resource_name = "mlflow_tracking_server_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.status + status.update(f"Current status: [bold]{current_status}") - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) + if current_status in terminal_states: + logger.info(f"Final Resource Status: [bold]{current_status}") - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object mlflow_tracking_server") - 
return None + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="InferenceRecommendationsJob", + status=current_status, + reason=self.failure_reason, + ) - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = {"role_arn": {"type": "string"}} - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "MlflowTrackingServer", **kwargs - ), - ) + return - return wrapper + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="InferenceRecommendationsJob", status=current_status + ) + time.sleep(poll) - @classmethod - @populate_inputs_decorator @Base.add_validate_call - def create( - cls, - tracking_server_name: str, - artifact_store_uri: str, - role_arn: str, - tracking_server_size: Optional[str] = Unassigned(), - mlflow_version: Optional[str] = Unassigned(), - automatic_model_registration: Optional[bool] = Unassigned(), - weekly_maintenance_window_start: Optional[str] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional["MlflowTrackingServer"]: + def wait_for_delete( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: """ - Create a MlflowTrackingServer resource + Wait for a InferenceRecommendationsJob resource to be deleted. Parameters: - tracking_server_name: A unique string identifying the tracking server name. This string is part of the tracking server ARN. - artifact_store_uri: The S3 URI for a general purpose bucket to use as the MLflow Tracking Server artifact store. - role_arn: The Amazon Resource Name (ARN) for an IAM role in your account that the MLflow Tracking Server uses to access the artifact store in Amazon S3. The role should have AmazonS3FullAccess permissions. 
For more information on IAM permissions for tracking server creation, see Set up IAM permissions for MLflow. - tracking_server_size: The size of the tracking server you want to create. You can choose between "Small", "Medium", and "Large". The default MLflow Tracking Server configuration size is "Small". You can choose a size depending on the projected use of the tracking server such as the volume of data logged, number of users, and frequency of use. We recommend using a small tracking server for teams of up to 25 users, a medium tracking server for teams of up to 50 users, and a large tracking server for teams of up to 100 users. - mlflow_version: The version of MLflow that the tracking server uses. To see which MLflow versions are available to use, see How it works. - automatic_model_registration: Whether to enable or disable automatic registration of new MLflow models to the SageMaker Model Registry. To enable automatic model registration, set this value to True. To disable automatic model registration, set this value to False. If not specified, AutomaticModelRegistration defaults to False. - weekly_maintenance_window_start: The day and time of the week in Coordinated Universal Time (UTC) 24-hour standard time that weekly maintenance updates are scheduled. For example: TUE:03:30. - tags: Tags consisting of key-value pairs used to manage metadata for the tracking server. - session: Boto3 session. - region: Region name. - - Returns: - The MlflowTrackingServer resource. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -16119,61 +16196,88 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
- ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + DeleteFailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. """ + start_time = time.time() - logger.info("Creating mlflow_tracking_server resource.") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), ) + progress.add_task("Waiting for InferenceRecommendationsJob to be deleted...") + status = Status("Current status:") - operation_input_args = { - "TrackingServerName": tracking_server_name, - "ArtifactStoreUri": artifact_store_uri, - "TrackingServerSize": tracking_server_size, - "MlflowVersion": mlflow_version, - "RoleArn": role_arn, - "AutomaticModelRegistration": automatic_model_registration, - "WeeklyMaintenanceWindowStart": weekly_maintenance_window_start, - "Tags": tags, - } - - operation_input_args = Base.populate_chained_attributes( - resource_name="MlflowTrackingServer", operation_input_args=operation_input_args - ) + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ) + ): + while True: + try: + self.refresh() + current_status = self.status + status.update(f"Current status: [bold]{current_status}") - logger.debug(f"Input request: {operation_input_args}") - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") + if current_status.lower() == "deleted": + print("Resource was deleted.") + return - # create the 
resource - response = client.create_mlflow_tracking_server(**operation_input_args) - logger.debug(f"Response: {response}") + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="InferenceRecommendationsJob", status=current_status + ) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] - return cls.get(tracking_server_name=tracking_server_name, session=session, region=region) + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. It may have been deleted.") + return + raise e + time.sleep(poll) @classmethod @Base.add_validate_call - def get( + def get_all( cls, - tracking_server_name: str, + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + status_equals: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + model_name_equals: Optional[str] = Unassigned(), + model_package_version_arn_equals: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["MlflowTrackingServer"]: + ) -> ResourceIterator["InferenceRecommendationsJob"]: """ - Get a MlflowTrackingServer resource + Get all InferenceRecommendationsJob resources Parameters: - tracking_server_name: The name of the MLflow Tracking Server to describe. + creation_time_after: A filter that returns only jobs created after the specified time (timestamp). + creation_time_before: A filter that returns only jobs created before the specified time (timestamp). + last_modified_time_after: A filter that returns only jobs that were last modified after the specified time (timestamp). 
+ last_modified_time_before: A filter that returns only jobs that were last modified before the specified time (timestamp). + name_contains: A string in the job name. This filter returns only recommendations whose name contains the specified string. + status_equals: A filter that retrieves only inference recommendations jobs with a specific status. + sort_by: The parameter by which to sort the results. + sort_order: The sort order for the results. + next_token: If the response to a previous ListInferenceRecommendationsJobsRequest request was truncated, the response includes a NextToken. To retrieve the next set of recommendations, use the token in the next request. + max_results: The maximum number of recommendations to return in the response. + model_name_equals: A filter that returns only jobs that were created for this model. + model_package_version_arn_equals: A filter that returns only jobs that were created for this versioned model package. session: Boto3 session. region: Region name. Returns: - The MlflowTrackingServer resource. + Iterator for listed InferenceRecommendationsJob resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -16185,37 +16289,57 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
""" + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + operation_input_args = { - "TrackingServerName": tracking_server_name, + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "LastModifiedTimeAfter": last_modified_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + "NameContains": name_contains, + "StatusEquals": status_equals, + "SortBy": sort_by, + "SortOrder": sort_order, + "ModelNameEquals": model_name_equals, + "ModelPackageVersionArnEquals": model_package_version_arn_equals, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" + return ResourceIterator( + client=client, + list_method="list_inference_recommendations_jobs", + summaries_key="InferenceRecommendationsJobs", + summary_name="InferenceRecommendationsJob", + resource_cls=InferenceRecommendationsJob, + list_method_kwargs=operation_input_args, ) - response = client.describe_mlflow_tracking_server(**operation_input_args) - - logger.debug(response) - - # deserialize the response - transformed_response = transform(response, "DescribeMlflowTrackingServerResponse") - mlflow_tracking_server = cls(**transformed_response) - return mlflow_tracking_server @Base.add_validate_call - def refresh( + def get_all_steps( self, - ) -> Optional["MlflowTrackingServer"]: + step_type: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator[InferenceRecommendationsJobStep]: """ - Refresh a MlflowTrackingServer resource + Returns a list of the subtasks for an Inference Recommender job. + + Parameters: + step_type: A filter to return details about the specified type of subtask. 
BENCHMARK: Evaluate the performance of your model on different instance types. + max_results: The maximum number of results to return. + next_token: A token that you can specify to return more results from the list. Specify this field if you have a token that was returned from a previous request. + session: Boto3 session. + region: Region name. Returns: - The MlflowTrackingServer resource. + Iterator for listed InferenceRecommendationsJobStep. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -16231,33 +16355,160 @@ def refresh( """ operation_input_args = { - "TrackingServerName": self.tracking_server_name, + "JobName": self.job_name, + "Status": self.status, + "StepType": step_type, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() - response = client.describe_mlflow_tracking_server(**operation_input_args) + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) - # deserialize response and update self - transform(response, "DescribeMlflowTrackingServerResponse", self) - return self + return ResourceIterator( + client=client, + list_method="list_inference_recommendations_job_steps", + summaries_key="Steps", + summary_name="InferenceRecommendationsJobStep", + resource_cls=InferenceRecommendationsJobStep, + list_method_kwargs=operation_input_args, + ) + + +class LabelingJob(Base): + """ + Class representing resource LabelingJob + + Attributes: + labeling_job_status: The processing status of the labeling job. + label_counters: Provides a breakdown of the number of data objects labeled by humans, the number of objects labeled by machine, the number of objects than couldn't be labeled, and the total number of objects labeled. + creation_time: The date and time that the labeling job was created. 
+ last_modified_time: The date and time that the labeling job was last updated. + job_reference_code: A unique identifier for work done as part of a labeling job. + labeling_job_name: The name assigned to the labeling job when it was created. + labeling_job_arn: The Amazon Resource Name (ARN) of the labeling job. + input_config: Input configuration information for the labeling job, such as the Amazon S3 location of the data objects and the location of the manifest file that describes the data objects. + output_config: The location of the job's output data and the Amazon Web Services Key Management Service key ID for the key used to encrypt the output data, if any. + role_arn: The Amazon Resource Name (ARN) that SageMaker assumes to perform tasks on your behalf during data labeling. + human_task_config: Configuration information required for human workers to complete a labeling task. + failure_reason: If the job failed, the reason that it failed. + label_attribute_name: The attribute used as the label in the output manifest file. + label_category_config_s3_uri: The S3 location of the JSON file that defines the categories used to label data objects. Please note the following label-category limits: Semantic segmentation labeling jobs using automated labeling: 20 labels Box bounding labeling jobs (all): 10 labels The file is a JSON structure in the following format: { "document-version": "2018-11-28" "labels": [ { "label": "label 1" }, { "label": "label 2" }, ... { "label": "label n" } ] } + stopping_conditions: A set of conditions for stopping a labeling job. If any of the conditions are met, the job is automatically stopped. + labeling_job_algorithms_config: Configuration information for automated data labeling. + tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. 
+ labeling_job_output: The location of the output produced by the labeling job. + + """ + + labeling_job_name: str + labeling_job_status: Optional[str] = Unassigned() + label_counters: Optional[LabelCounters] = Unassigned() + failure_reason: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + job_reference_code: Optional[str] = Unassigned() + labeling_job_arn: Optional[str] = Unassigned() + label_attribute_name: Optional[str] = Unassigned() + input_config: Optional[LabelingJobInputConfig] = Unassigned() + output_config: Optional[LabelingJobOutputConfig] = Unassigned() + role_arn: Optional[str] = Unassigned() + label_category_config_s3_uri: Optional[str] = Unassigned() + stopping_conditions: Optional[LabelingJobStoppingConditions] = Unassigned() + labeling_job_algorithms_config: Optional[LabelingJobAlgorithmsConfig] = Unassigned() + human_task_config: Optional[HumanTaskConfig] = Unassigned() + tags: Optional[List[Tag]] = Unassigned() + labeling_job_output: Optional[LabelingJobOutput] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "labeling_job_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object labeling_job") + return None + + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "input_config": { + "data_source": {"s3_data_source": {"manifest_s3_uri": {"type": "string"}}} + }, + "output_config": { + "s3_output_path": {"type": "string"}, + "kms_key_id": {"type": "string"}, + }, + "role_arn": 
{"type": "string"}, + "human_task_config": {"ui_config": {"ui_template_s3_uri": {"type": "string"}}}, + "label_category_config_s3_uri": {"type": "string"}, + "labeling_job_algorithms_config": { + "labeling_job_resource_config": { + "volume_kms_key_id": {"type": "string"}, + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + }, + } + }, + "labeling_job_output": {"output_dataset_s3_uri": {"type": "string"}}, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "LabelingJob", **kwargs + ), + ) + + return wrapper + @classmethod @populate_inputs_decorator @Base.add_validate_call - def update( - self, - artifact_store_uri: Optional[str] = Unassigned(), - tracking_server_size: Optional[str] = Unassigned(), - automatic_model_registration: Optional[bool] = Unassigned(), - weekly_maintenance_window_start: Optional[str] = Unassigned(), - ) -> Optional["MlflowTrackingServer"]: + def create( + cls, + labeling_job_name: str, + label_attribute_name: str, + input_config: LabelingJobInputConfig, + output_config: LabelingJobOutputConfig, + role_arn: str, + human_task_config: HumanTaskConfig, + label_category_config_s3_uri: Optional[str] = Unassigned(), + stopping_conditions: Optional[LabelingJobStoppingConditions] = Unassigned(), + labeling_job_algorithms_config: Optional[LabelingJobAlgorithmsConfig] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["LabelingJob"]: """ - Update a MlflowTrackingServer resource + Create a LabelingJob resource + + Parameters: + labeling_job_name: The name of the labeling job. This name is used to identify the job in a list of labeling jobs. Labeling job names must be unique within an Amazon Web Services account and region. LabelingJobName is not case sensitive. 
For example, Example-job and example-job are considered the same labeling job name by Ground Truth. + label_attribute_name: The attribute name to use for the label in the output manifest file. This is the key for the key/value pair formed with the label that a worker assigns to the object. The LabelAttributeName must meet the following requirements. The name can't end with "-metadata". If you are using one of the following built-in task types, the attribute name must end with "-ref". If the task type you are using is not listed below, the attribute name must not end with "-ref". Image semantic segmentation (SemanticSegmentation), and adjustment (AdjustmentSemanticSegmentation) and verification (VerificationSemanticSegmentation) labeling jobs for this task type. Video frame object detection (VideoObjectDetection), and adjustment and verification (AdjustmentVideoObjectDetection) labeling jobs for this task type. Video frame object tracking (VideoObjectTracking), and adjustment and verification (AdjustmentVideoObjectTracking) labeling jobs for this task type. 3D point cloud semantic segmentation (3DPointCloudSemanticSegmentation), and adjustment and verification (Adjustment3DPointCloudSemanticSegmentation) labeling jobs for this task type. 3D point cloud object tracking (3DPointCloudObjectTracking), and adjustment and verification (Adjustment3DPointCloudObjectTracking) labeling jobs for this task type. If you are creating an adjustment or verification labeling job, you must use a different LabelAttributeName than the one used in the original labeling job. The original labeling job is the Ground Truth labeling job that produced the labels that you want verified or adjusted. To learn more about adjustment and verification labeling jobs, see Verify and Adjust Labels. + input_config: Input data for the labeling job, such as the Amazon S3 location of the data objects and the location of the manifest file that describes the data objects. 
You must specify at least one of the following: S3DataSource or SnsDataSource. Use SnsDataSource to specify an SNS input topic for a streaming labeling job. If you do not specify and SNS input topic ARN, Ground Truth will create a one-time labeling job that stops after all data objects in the input manifest file have been labeled. Use S3DataSource to specify an input manifest file for both streaming and one-time labeling jobs. Adding an S3DataSource is optional if you use SnsDataSource to create a streaming labeling job. If you use the Amazon Mechanical Turk workforce, your input data should not include confidential information, personal information or protected health information. Use ContentClassifiers to specify that your data is free of personally identifiable information and adult content. + output_config: The location of the output data and the Amazon Web Services Key Management Service key ID for the key used to encrypt the output data, if any. + role_arn: The Amazon Resource Number (ARN) that Amazon SageMaker assumes to perform tasks on your behalf during data labeling. You must grant this role the necessary permissions so that Amazon SageMaker can successfully complete data labeling. + human_task_config: Configures the labeling task and how it is presented to workers; including, but not limited to price, keywords, and batch size (task count). + label_category_config_s3_uri: The S3 URI of the file, referred to as a label category configuration file, that defines the categories used to label the data objects. For 3D point cloud and video frame task types, you can add label category attributes and frame attributes to your label category configuration file. To learn how, see Create a Labeling Category Configuration File for 3D Point Cloud Labeling Jobs. 
For named entity recognition jobs, in addition to "labels", you must provide worker instructions in the label category configuration file using the "instructions" parameter: "instructions": {"shortInstruction":"<h1>Add header</h1><p>Add Instructions</p>", "fullInstruction":"<p>Add additional instructions.</p>"}. For details and an example, see Create a Named Entity Recognition Labeling Job (API) . For all other built-in task types and custom tasks, your label category configuration file must be a JSON file in the following format. Identify the labels you want to use by replacing label_1, label_2,...,label_n with your label categories. { "document-version": "2018-11-28", "labels": [{"label": "label_1"},{"label": "label_2"},...{"label": "label_n"}] } Note the following about the label category configuration file: For image classification and text classification (single and multi-label) you must specify at least two label categories. For all other task types, the minimum number of label categories required is one. Each label category must be unique, you cannot specify duplicate label categories. If you create a 3D point cloud or video frame adjustment or verification labeling job, you must include auditLabelAttributeName in the label category configuration. Use this parameter to enter the LabelAttributeName of the labeling job you want to adjust or verify annotations of. + stopping_conditions: A set of conditions for stopping the labeling job. If any of the conditions are met, the job is automatically stopped. You can use these conditions to control the cost of data labeling. + labeling_job_algorithms_config: Configures the information required to perform automated data labeling. + tags: An array of key/value pairs. For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide. + session: Boto3 session. + region: Region name. Returns: - The MlflowTrackingServer resource. + The LabelingJob resource. 
Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -16269,39 +16520,64 @@ def update( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceInUse: Resource being accessed is in use. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ResourceNotFound: Resource being access is not found. + ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Updating mlflow_tracking_server resource.") - client = Base.get_sagemaker_client() - - operation_input_args = { - "TrackingServerName": self.tracking_server_name, - "ArtifactStoreUri": artifact_store_uri, - "TrackingServerSize": tracking_server_size, - "AutomaticModelRegistration": automatic_model_registration, - "WeeklyMaintenanceWindowStart": weekly_maintenance_window_start, + logger.info("Creating labeling_job resource.") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "LabelingJobName": labeling_job_name, + "LabelAttributeName": label_attribute_name, + "InputConfig": input_config, + "OutputConfig": output_config, + "RoleArn": role_arn, + "LabelCategoryConfigS3Uri": label_category_config_s3_uri, + "StoppingConditions": stopping_conditions, + "LabelingJobAlgorithmsConfig": labeling_job_algorithms_config, + "HumanTaskConfig": human_task_config, + "Tags": tags, } + + operation_input_args = Base.populate_chained_attributes( + resource_name="LabelingJob", operation_input_args=operation_input_args + ) + 
logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.update_mlflow_tracking_server(**operation_input_args) + response = client.create_labeling_job(**operation_input_args) logger.debug(f"Response: {response}") - self.refresh() - return self + return cls.get(labeling_job_name=labeling_job_name, session=session, region=region) + @classmethod @Base.add_validate_call - def delete( - self, - ) -> None: + def get( + cls, + labeling_job_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["LabelingJob"]: """ - Delete a MlflowTrackingServer resource + Get a LabelingJob resource + + Parameters: + labeling_job_name: The name of the labeling job to return information for. + session: Boto3 session. + region: Region name. + + Returns: + The LabelingJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -16316,23 +16592,66 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() + operation_input_args = { + "LabelingJobName": labeling_job_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + response = client.describe_labeling_job(**operation_input_args) + + logger.debug(response) + + # deserialize the response + transformed_response = transform(response, "DescribeLabelingJobResponse") + labeling_job = cls(**transformed_response) + return labeling_job + + @Base.add_validate_call + def refresh( + self, + ) -> Optional["LabelingJob"]: + """ + Refresh a LabelingJob resource + + Returns: + The LabelingJob resource. 
+ + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. + """ operation_input_args = { - "TrackingServerName": self.tracking_server_name, + "LabelingJobName": self.labeling_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_mlflow_tracking_server(**operation_input_args) + client = Base.get_sagemaker_client() + response = client.describe_labeling_job(**operation_input_args) - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + # deserialize response and update self + transform(response, "DescribeLabelingJobResponse", self) + return self @Base.add_validate_call def stop(self) -> None: """ - Stop a MlflowTrackingServer resource + Stop a LabelingJob resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -16344,53 +16663,32 @@ def stop(self) -> None: error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. ResourceNotFound: Resource being access is not found. 
""" client = SageMakerClient().client operation_input_args = { - "TrackingServerName": self.tracking_server_name, + "LabelingJobName": self.labeling_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.stop_mlflow_tracking_server(**operation_input_args) + client.stop_labeling_job(**operation_input_args) logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call - def wait_for_status( + def wait( self, - target_status: Literal[ - "Creating", - "Created", - "CreateFailed", - "Updating", - "Updated", - "UpdateFailed", - "Deleting", - "DeleteFailed", - "Stopping", - "Stopped", - "StopFailed", - "Starting", - "Started", - "StartFailed", - "MaintenanceInProgress", - "MaintenanceComplete", - "MaintenanceFailed", - ], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a MlflowTrackingServer resource to reach certain status. + Wait for a LabelingJob resource. Parameters: - target_status: The status to wait for. poll: The number of seconds to wait between each poll. timeout: The maximum number of seconds to wait before timing out. @@ -16398,7 +16696,9 @@ def wait_for_status( TimeoutExceededError: If the resource does not reach a terminal state before the timeout. FailedStatusError: If the resource reaches a failed state. WaiterError: Raised when an error occurs while waiting. + """ + terminal_states = ["Completed", "Failed", "Stopped"] start_time = time.time() progress = Progress( @@ -16406,9 +16706,7 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task( - f"Waiting for MlflowTrackingServer to reach [bold]{target_status} status..." 
- ) + progress.add_task("Waiting for LabelingJob...") status = Status("Current status:") with Live( @@ -16421,38 +16719,59 @@ def wait_for_status( ): while True: self.refresh() - current_status = self.tracking_server_status + current_status = self.labeling_job_status status.update(f"Current status: [bold]{current_status}") - if target_status == current_status: + if current_status in terminal_states: logger.info(f"Final Resource Status: [bold]{current_status}") - return - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="MlflowTrackingServer", - status=current_status, - reason="(Unknown)", - ) + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="LabelingJob", + status=current_status, + reason=self.failure_reason, + ) + + return if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="MlflowTrackingServer", status=current_status - ) + raise TimeoutExceededError(resouce_type="LabelingJob", status=current_status) time.sleep(poll) + @classmethod @Base.add_validate_call - def wait_for_delete( - self, - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: + def get_all( + cls, + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + status_equals: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["LabelingJob"]: """ - Wait for a MlflowTrackingServer resource to be deleted. + Get all LabelingJob resources Parameters: - poll: The number of seconds to wait between each poll. 
- timeout: The maximum number of seconds to wait before timing out. + creation_time_after: A filter that returns only labeling jobs created after the specified time (timestamp). + creation_time_before: A filter that returns only labeling jobs created before the specified time (timestamp). + last_modified_time_after: A filter that returns only labeling jobs modified after the specified time (timestamp). + last_modified_time_before: A filter that returns only labeling jobs modified before the specified time (timestamp). + max_results: The maximum number of labeling jobs to return in each page of the response. + next_token: If the result of the previous ListLabelingJobs request was truncated, the response includes a NextToken. To retrieve the next set of labeling jobs, use the token in the next request. + name_contains: A string in the labeling job name. This filter returns only labeling jobs whose name contains the specified string. + sort_by: The field to sort results by. The default is CreationTime. + sort_order: The sort order for results. The default is Ascending. + status_equals: A filter that retrieves only labeling jobs with a specific status. + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed LabelingJob resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -16464,148 +16783,65 @@ def wait_for_delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - DeleteFailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. 
""" - start_time = time.time() - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" ) - progress.add_task("Waiting for MlflowTrackingServer to be deleted...") - status = Status("Current status:") - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ) - ): - while True: - try: - self.refresh() - current_status = self.tracking_server_status - status.update(f"Current status: [bold]{current_status}") + operation_input_args = { + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "LastModifiedTimeAfter": last_modified_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + "NameContains": name_contains, + "SortBy": sort_by, + "SortOrder": sort_order, + "StatusEquals": status_equals, + } - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="MlflowTrackingServer", status=current_status - ) - except botocore.exceptions.ClientError as e: - error_code = e.response["Error"]["Code"] + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") - if "ResourceNotFound" in error_code or "ValidationException" in error_code: - logger.info("Resource was not found. 
It may have been deleted.") - return - raise e - time.sleep(poll) + return ResourceIterator( + client=client, + list_method="list_labeling_jobs", + summaries_key="LabelingJobSummaryList", + summary_name="LabelingJobSummary", + resource_cls=LabelingJob, + list_method_kwargs=operation_input_args, + ) - @classmethod - @Base.add_validate_call - def get_all( - cls, - created_after: Optional[datetime.datetime] = Unassigned(), - created_before: Optional[datetime.datetime] = Unassigned(), - tracking_server_status: Optional[str] = Unassigned(), - mlflow_version: Optional[str] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> ResourceIterator["MlflowTrackingServer"]: - """ - Get all MlflowTrackingServer resources - - Parameters: - created_after: Use the CreatedAfter filter to only list tracking servers created after a specific date and time. Listed tracking servers are shown with a date and time such as "2024-03-16T01:46:56+00:00". The CreatedAfter parameter takes in a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter. - created_before: Use the CreatedBefore filter to only list tracking servers created before a specific date and time. Listed tracking servers are shown with a date and time such as "2024-03-16T01:46:56+00:00". The CreatedBefore parameter takes in a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter. - tracking_server_status: Filter for tracking servers with a specified creation status. - mlflow_version: Filter for tracking servers using the specified MLflow version. - sort_by: Filter for trackings servers sorting by name, creation time, or creation status. - sort_order: Change the order of the listed tracking servers. By default, tracking servers are listed in Descending order by creation time. To change the list order, you can specify SortOrder to be Ascending. 
- next_token: If the previous response was truncated, you will receive this token. Use it in your next request to receive the next set of results. - max_results: The maximum number of tracking servers to list. - session: Boto3 session. - region: Region name. - - Returns: - Iterator for listed MlflowTrackingServer resources. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - """ - - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - - operation_input_args = { - "CreatedAfter": created_after, - "CreatedBefore": created_before, - "TrackingServerStatus": tracking_server_status, - "MlflowVersion": mlflow_version, - "SortBy": sort_by, - "SortOrder": sort_order, - } - - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - return ResourceIterator( - client=client, - list_method="list_mlflow_tracking_servers", - summaries_key="TrackingServerSummaries", - summary_name="TrackingServerSummary", - resource_cls=MlflowTrackingServer, - list_method_kwargs=operation_input_args, - ) - -class Model(Base): +class LineageGroup(Base): """ - Class representing resource Model + Class representing resource LineageGroup Attributes: - model_name: Name of the SageMaker model. - creation_time: A timestamp that shows when the model was created. - model_arn: The Amazon Resource Name (ARN) of the model. - primary_container: The location of the primary inference code, associated artifacts, and custom environment map that the inference code uses when it is deployed in production. 
- containers: The containers in the inference pipeline. - inference_execution_config: Specifies details of how containers in a multi-container endpoint are called. - execution_role_arn: The Amazon Resource Name (ARN) of the IAM role that you specified for the model. - vpc_config: A VpcConfig object that specifies the VPC that this model has access to. For more information, see Protect Endpoints by Using an Amazon Virtual Private Cloud - enable_network_isolation: If True, no inbound or outbound network calls can be made to or from the model container. - deployment_recommendation: A set of recommended deployment configurations for the model. + lineage_group_name: The name of the lineage group. + lineage_group_arn: The Amazon Resource Name (ARN) of the lineage group. + display_name: The display name of the lineage group. + description: The description of the lineage group. + creation_time: The creation time of lineage group. + created_by: + last_modified_time: The last modified time of the lineage group. 
+ last_modified_by: """ - model_name: str - primary_container: Optional[ContainerDefinition] = Unassigned() - containers: Optional[List[ContainerDefinition]] = Unassigned() - inference_execution_config: Optional[InferenceExecutionConfig] = Unassigned() - execution_role_arn: Optional[str] = Unassigned() - vpc_config: Optional[VpcConfig] = Unassigned() + lineage_group_name: str + lineage_group_arn: Optional[str] = Unassigned() + display_name: Optional[str] = Unassigned() + description: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - model_arn: Optional[str] = Unassigned() - enable_network_isolation: Optional[bool] = Unassigned() - deployment_recommendation: Optional[DeploymentRecommendation] = Unassigned() + created_by: Optional[UserContext] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + last_modified_by: Optional[UserContext] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "model_name" + resource_name = "lineage_group_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -16616,136 +16852,27 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object model") + logger.error("Name attribute not found for object lineage_group") return None - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = { - "primary_container": { - "model_data_source": { - "s3_data_source": { - "s3_uri": {"type": "string"}, - "s3_data_type": {"type": "string"}, - "manifest_s3_uri": {"type": "string"}, - } - } - }, - "execution_role_arn": {"type": "string"}, - "vpc_config": { - "security_group_ids": {"type": "array", "items": {"type": "string"}}, - "subnets": {"type": "array", "items": {"type": "string"}}, - }, - } - return 
create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "Model", **kwargs - ), - ) - - return wrapper - - @classmethod - @populate_inputs_decorator - @Base.add_validate_call - def create( - cls, - model_name: str, - primary_container: Optional[ContainerDefinition] = Unassigned(), - containers: Optional[List[ContainerDefinition]] = Unassigned(), - inference_execution_config: Optional[InferenceExecutionConfig] = Unassigned(), - execution_role_arn: Optional[str] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), - vpc_config: Optional[VpcConfig] = Unassigned(), - enable_network_isolation: Optional[bool] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional["Model"]: - """ - Create a Model resource - - Parameters: - model_name: The name of the new model. - primary_container: The location of the primary docker image containing inference code, associated artifacts, and custom environment map that the inference code uses when the model is deployed for predictions. - containers: Specifies the containers in the inference pipeline. - inference_execution_config: Specifies details of how containers in a multi-container endpoint are called. - execution_role_arn: The Amazon Resource Name (ARN) of the IAM role that SageMaker can assume to access model artifacts and docker image for deployment on ML compute instances or for batch transform jobs. Deploying on ML compute instances is part of model hosting. For more information, see SageMaker Roles. To be able to pass this role to SageMaker, the caller of this API must have the iam:PassRole permission. - tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. - vpc_config: A VpcConfig object that specifies the VPC that you want your model to connect to. 
Control access to and from your model container by configuring the VPC. VpcConfig is used in hosting services and in batch transform. For more information, see Protect Endpoints by Using an Amazon Virtual Private Cloud and Protect Data in Batch Transform Jobs by Using an Amazon Virtual Private Cloud. - enable_network_isolation: Isolates the model container. No inbound or outbound network calls can be made to or from the model container. - session: Boto3 session. - region: Region name. - - Returns: - The Model resource. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
- ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 - """ - - logger.info("Creating model resource.") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - - operation_input_args = { - "ModelName": model_name, - "PrimaryContainer": primary_container, - "Containers": containers, - "InferenceExecutionConfig": inference_execution_config, - "ExecutionRoleArn": execution_role_arn, - "Tags": tags, - "VpcConfig": vpc_config, - "EnableNetworkIsolation": enable_network_isolation, - } - - operation_input_args = Base.populate_chained_attributes( - resource_name="Model", operation_input_args=operation_input_args - ) - - logger.debug(f"Input request: {operation_input_args}") - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - # create the resource - response = client.create_model(**operation_input_args) - logger.debug(f"Response: {response}") - - return cls.get(model_name=model_name, session=session, region=region) - @classmethod @Base.add_validate_call def get( cls, - model_name: str, + lineage_group_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["Model"]: + ) -> Optional["LineageGroup"]: """ - Get a Model resource + Get a LineageGroup resource Parameters: - model_name: The name of the model. + lineage_group_name: The name of the lineage group. session: Boto3 session. region: Region name. Returns: - The Model resource. + The LineageGroup resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -16757,10 +16884,11 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "ModelName": model_name, + "LineageGroupName": lineage_group_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -16769,24 +16897,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_model(**operation_input_args) + response = client.describe_lineage_group(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeModelOutput") - model = cls(**transformed_response) - return model + transformed_response = transform(response, "DescribeLineageGroupResponse") + lineage_group = cls(**transformed_response) + return lineage_group @Base.add_validate_call def refresh( self, - ) -> Optional["Model"]: + ) -> Optional["LineageGroup"]: """ - Refresh a Model resource + Refresh a LineageGroup resource Returns: - The Model resource. + The LineageGroup resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -16798,82 +16926,49 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "ModelName": self.model_name, + "LineageGroupName": self.lineage_group_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_model(**operation_input_args) + response = client.describe_lineage_group(**operation_input_args) # deserialize response and update self - transform(response, "DescribeModelOutput", self) + transform(response, "DescribeLineageGroupResponse", self) return self - @Base.add_validate_call - def delete( - self, - ) -> None: - """ - Delete a Model resource - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - """ - - client = Base.get_sagemaker_client() - - operation_input_args = { - "ModelName": self.model_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client.delete_model(**operation_input_args) - - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") - - @classmethod + @classmethod @Base.add_validate_call def get_all( cls, + created_after: Optional[datetime.datetime] = Unassigned(), + created_before: Optional[datetime.datetime] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["Model"]: + 
) -> ResourceIterator["LineageGroup"]: """ - Get all Model resources + Get all LineageGroup resources Parameters: - sort_by: Sorts the list of results. The default is CreationTime. - sort_order: The sort order for results. The default is Descending. - next_token: If the response to a previous ListModels request was truncated, the response includes a NextToken. To retrieve the next set of models, use the token in the next request. - max_results: The maximum number of models to return in the response. - name_contains: A string in the model name. This filter returns only models whose name contains the specified string. - creation_time_before: A filter that returns only models created before the specified time (timestamp). - creation_time_after: A filter that returns only models with a creation time greater than or equal to the specified time (timestamp). + created_after: A timestamp to filter against lineage groups created after a certain point in time. + created_before: A timestamp to filter against lineage groups created before a certain point in time. + sort_by: The parameter by which to sort the results. The default is CreationTime. + sort_order: The sort order for the results. The default is Ascending. + next_token: If the response is truncated, SageMaker returns this token. To retrieve the next set of algorithms, use it in the subsequent request. + max_results: The maximum number of endpoints to return in the response. This value defaults to 10. session: Boto3 session. region: Region name. Returns: - Iterator for listed Model resources. + Iterator for listed LineageGroup resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -16892,11 +16987,10 @@ def get_all( ) operation_input_args = { + "CreatedAfter": created_after, + "CreatedBefore": created_before, "SortBy": sort_by, "SortOrder": sort_order, - "NameContains": name_contains, - "CreationTimeBefore": creation_time_before, - "CreationTimeAfter": creation_time_after, } # serialize the input request @@ -16905,32 +16999,28 @@ def get_all( return ResourceIterator( client=client, - list_method="list_models", - summaries_key="Models", - summary_name="ModelSummary", - resource_cls=Model, + list_method="list_lineage_groups", + summaries_key="LineageGroupSummaries", + summary_name="LineageGroupSummary", + resource_cls=LineageGroup, list_method_kwargs=operation_input_args, ) @Base.add_validate_call - def get_all_metadata( + def get_policy( self, - search_expression: Optional[ModelMetadataSearchExpression] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator[ModelMetadataSummary]: + ) -> Optional[GetLineageGroupPolicyResponse]: """ - Lists the domain, framework, task, and model name of standard machine learning models found in common model zoos. + The resource policy for the lineage group. Parameters: - search_expression: One or more filters that searches for the specified resource or resources in a search. All resource objects that satisfy the expression's condition are included in the search results. Specify the Framework, FrameworkVersion, Domain or Task to filter supported. Filter names and values are case-sensitive. - next_token: If the response to a previous ListModelMetadataResponse request was truncated, the response includes a NextToken. To retrieve the next set of model metadata, use the token in the next request. - max_results: The maximum number of models to return in the response. session: Boto3 session. region: Region name. Returns: - Iterator for listed ModelMetadataSummary. 
+ GetLineageGroupPolicyResponse Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -16942,10 +17032,11 @@ def get_all_metadata( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "SearchExpression": search_expression, + "LineageGroupName": self.lineage_group_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -16955,50 +17046,56 @@ def get_all_metadata( session=session, region_name=region, service_name="sagemaker" ) - return ResourceIterator( - client=client, - list_method="list_model_metadata", - summaries_key="ModelMetadataSummaries", - summary_name="ModelMetadataSummary", - resource_cls=ModelMetadataSummary, - list_method_kwargs=operation_input_args, - ) + logger.debug(f"Calling get_lineage_group_policy API") + response = client.get_lineage_group_policy(**operation_input_args) + logger.debug(f"Response: {response}") + + transformed_response = transform(response, "GetLineageGroupPolicyResponse") + return GetLineageGroupPolicyResponse(**transformed_response) -class ModelBiasJobDefinition(Base): +class MlflowTrackingServer(Base): """ - Class representing resource ModelBiasJobDefinition + Class representing resource MlflowTrackingServer Attributes: - job_definition_arn: The Amazon Resource Name (ARN) of the model bias job. - job_definition_name: The name of the bias job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. - creation_time: The time at which the model bias job was created. - model_bias_app_specification: Configures the model bias job to run a specified Docker container image. - model_bias_job_input: Inputs for the model bias job. 
- model_bias_job_output_config: - job_resources: - role_arn: The Amazon Resource Name (ARN) of the IAM role that has read permission to the input data location and write permission to the output data location in Amazon S3. - model_bias_baseline_config: The baseline configuration for a model bias job. - network_config: Networking options for a model bias job. - stopping_condition: + tracking_server_arn: The ARN of the described tracking server. + tracking_server_name: The name of the described tracking server. + artifact_store_uri: The S3 URI of the general purpose bucket used as the MLflow Tracking Server artifact store. + tracking_server_size: The size of the described tracking server. + mlflow_version: The MLflow version used for the described tracking server. + role_arn: The Amazon Resource Name (ARN) for an IAM role in your account that the described MLflow Tracking Server uses to access the artifact store in Amazon S3. + tracking_server_status: The current creation status of the described MLflow Tracking Server. + is_active: Whether the described MLflow Tracking Server is currently active. + tracking_server_url: The URL to connect to the MLflow user interface for the described tracking server. + weekly_maintenance_window_start: The day and time of the week when weekly maintenance occurs on the described tracking server. + automatic_model_registration: Whether automatic registration of new MLflow models to the SageMaker Model Registry is enabled. + creation_time: The timestamp of when the described MLflow Tracking Server was created. + created_by: + last_modified_time: The timestamp of when the described MLflow Tracking Server was last modified. 
+ last_modified_by: """ - job_definition_name: str - job_definition_arn: Optional[str] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - model_bias_baseline_config: Optional[ModelBiasBaselineConfig] = Unassigned() - model_bias_app_specification: Optional[ModelBiasAppSpecification] = Unassigned() - model_bias_job_input: Optional[ModelBiasJobInput] = Unassigned() - model_bias_job_output_config: Optional[MonitoringOutputConfig] = Unassigned() - job_resources: Optional[MonitoringResources] = Unassigned() - network_config: Optional[MonitoringNetworkConfig] = Unassigned() + tracking_server_name: str + tracking_server_arn: Optional[str] = Unassigned() + artifact_store_uri: Optional[str] = Unassigned() + tracking_server_size: Optional[str] = Unassigned() + mlflow_version: Optional[str] = Unassigned() role_arn: Optional[str] = Unassigned() - stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned() + tracking_server_status: Optional[str] = Unassigned() + is_active: Optional[str] = Unassigned() + tracking_server_url: Optional[str] = Unassigned() + weekly_maintenance_window_start: Optional[str] = Unassigned() + automatic_model_registration: Optional[bool] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + created_by: Optional[UserContext] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + last_modified_by: Optional[UserContext] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "model_bias_job_definition_name" + resource_name = "mlflow_tracking_server_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -17009,42 +17106,17 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object model_bias_job_definition") + logger.error("Name attribute not found for object 
mlflow_tracking_server") return None def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): - config_schema_for_resource = { - "model_bias_job_input": { - "ground_truth_s3_input": {"s3_uri": {"type": "string"}}, - "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, - "batch_transform_input": { - "data_captured_destination_s3_uri": {"type": "string"}, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, - }, - }, - "model_bias_job_output_config": {"kms_key_id": {"type": "string"}}, - "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, - "role_arn": {"type": "string"}, - "model_bias_baseline_config": { - "constraints_resource": {"s3_uri": {"type": "string"}} - }, - "network_config": { - "vpc_config": { - "security_group_ids": {"type": "array", "items": {"type": "string"}}, - "subnets": {"type": "array", "items": {"type": "string"}}, - } - }, - } + config_schema_for_resource = {"role_arn": {"type": "string"}} return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "ModelBiasJobDefinition", **kwargs + config_schema_for_resource, "MlflowTrackingServer", **kwargs ), ) @@ -17055,38 +17127,34 @@ def wrapper(*args, **kwargs): @Base.add_validate_call def create( cls, - job_definition_name: str, - model_bias_app_specification: ModelBiasAppSpecification, - model_bias_job_input: ModelBiasJobInput, - model_bias_job_output_config: MonitoringOutputConfig, - job_resources: MonitoringResources, + tracking_server_name: str, + artifact_store_uri: str, role_arn: str, - model_bias_baseline_config: Optional[ModelBiasBaselineConfig] = Unassigned(), - network_config: Optional[MonitoringNetworkConfig] = Unassigned(), - stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), + tracking_server_size: Optional[str] = Unassigned(), + mlflow_version: 
Optional[str] = Unassigned(), + automatic_model_registration: Optional[bool] = Unassigned(), + weekly_maintenance_window_start: Optional[str] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ModelBiasJobDefinition"]: + ) -> Optional["MlflowTrackingServer"]: """ - Create a ModelBiasJobDefinition resource + Create a MlflowTrackingServer resource Parameters: - job_definition_name: The name of the bias job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. - model_bias_app_specification: Configures the model bias job to run a specified Docker container image. - model_bias_job_input: Inputs for the model bias job. - model_bias_job_output_config: - job_resources: - role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. - model_bias_baseline_config: The baseline configuration for a model bias job. - network_config: Networking options for a model bias job. - stopping_condition: - tags: (Optional) An array of key-value pairs. For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide. + tracking_server_name: A unique string identifying the tracking server name. This string is part of the tracking server ARN. + artifact_store_uri: The S3 URI for a general purpose bucket to use as the MLflow Tracking Server artifact store. + role_arn: The Amazon Resource Name (ARN) for an IAM role in your account that the MLflow Tracking Server uses to access the artifact store in Amazon S3. The role should have AmazonS3FullAccess permissions. For more information on IAM permissions for tracking server creation, see Set up IAM permissions for MLflow. + tracking_server_size: The size of the tracking server you want to create. You can choose between "Small", "Medium", and "Large". 
The default MLflow Tracking Server configuration size is "Small". You can choose a size depending on the projected use of the tracking server such as the volume of data logged, number of users, and frequency of use. We recommend using a small tracking server for teams of up to 25 users, a medium tracking server for teams of up to 50 users, and a large tracking server for teams of up to 100 users. + mlflow_version: The version of MLflow that the tracking server uses. To see which MLflow versions are available to use, see How it works. + automatic_model_registration: Whether to enable or disable automatic registration of new MLflow models to the SageMaker Model Registry. To enable automatic model registration, set this value to True. To disable automatic model registration, set this value to False. If not specified, AutomaticModelRegistration defaults to False. + weekly_maintenance_window_start: The day and time of the week in Coordinated Universal Time (UTC) 24-hour standard time that weekly maintenance updates are scheduled. For example: TUE:03:30. + tags: Tags consisting of key-value pairs used to manage metadata for the tracking server. session: Boto3 session. region: Region name. Returns: - The ModelBiasJobDefinition resource. + The MlflowTrackingServer resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -17098,33 +17166,30 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating model_bias_job_definition resource.") + logger.info("Creating mlflow_tracking_server resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "JobDefinitionName": job_definition_name, - "ModelBiasBaselineConfig": model_bias_baseline_config, - "ModelBiasAppSpecification": model_bias_app_specification, - "ModelBiasJobInput": model_bias_job_input, - "ModelBiasJobOutputConfig": model_bias_job_output_config, - "JobResources": job_resources, - "NetworkConfig": network_config, + "TrackingServerName": tracking_server_name, + "ArtifactStoreUri": artifact_store_uri, + "TrackingServerSize": tracking_server_size, + "MlflowVersion": mlflow_version, "RoleArn": role_arn, - "StoppingCondition": stopping_condition, + "AutomaticModelRegistration": automatic_model_registration, + "WeeklyMaintenanceWindowStart": weekly_maintenance_window_start, "Tags": tags, } operation_input_args = Base.populate_chained_attributes( - resource_name="ModelBiasJobDefinition", operation_input_args=operation_input_args + resource_name="MlflowTrackingServer", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -17133,29 +17198,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_model_bias_job_definition(**operation_input_args) + response = client.create_mlflow_tracking_server(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(job_definition_name=job_definition_name, session=session, region=region) + return cls.get(tracking_server_name=tracking_server_name, session=session, region=region) 
@classmethod @Base.add_validate_call def get( cls, - job_definition_name: str, + tracking_server_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ModelBiasJobDefinition"]: + ) -> Optional["MlflowTrackingServer"]: """ - Get a ModelBiasJobDefinition resource + Get a MlflowTrackingServer resource Parameters: - job_definition_name: The name of the model bias job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + tracking_server_name: The name of the MLflow Tracking Server to describe. session: Boto3 session. region: Region name. Returns: - The ModelBiasJobDefinition resource. + The MlflowTrackingServer resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -17171,7 +17236,7 @@ def get( """ operation_input_args = { - "JobDefinitionName": job_definition_name, + "TrackingServerName": tracking_server_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -17180,24 +17245,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_model_bias_job_definition(**operation_input_args) + response = client.describe_mlflow_tracking_server(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeModelBiasJobDefinitionResponse") - model_bias_job_definition = cls(**transformed_response) - return model_bias_job_definition + transformed_response = transform(response, "DescribeMlflowTrackingServerResponse") + mlflow_tracking_server = cls(**transformed_response) + return mlflow_tracking_server @Base.add_validate_call def refresh( self, - ) -> Optional["ModelBiasJobDefinition"]: + ) -> Optional["MlflowTrackingServer"]: """ - Refresh a ModelBiasJobDefinition resource + Refresh a MlflowTrackingServer resource Returns: - The 
ModelBiasJobDefinition resource. + The MlflowTrackingServer resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -17213,25 +17278,33 @@ def refresh( """ operation_input_args = { - "JobDefinitionName": self.job_definition_name, + "TrackingServerName": self.tracking_server_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_model_bias_job_definition(**operation_input_args) + response = client.describe_mlflow_tracking_server(**operation_input_args) # deserialize response and update self - transform(response, "DescribeModelBiasJobDefinitionResponse", self) + transform(response, "DescribeMlflowTrackingServerResponse", self) return self + @populate_inputs_decorator @Base.add_validate_call - def delete( + def update( self, - ) -> None: + artifact_store_uri: Optional[str] = Unassigned(), + tracking_server_size: Optional[str] = Unassigned(), + automatic_model_registration: Optional[bool] = Unassigned(), + weekly_maintenance_window_start: Optional[str] = Unassigned(), + ) -> Optional["MlflowTrackingServer"]: """ - Delete a ModelBiasJobDefinition resource + Update a MlflowTrackingServer resource + + Returns: + The MlflowTrackingServer resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -17243,52 +17316,39 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. ResourceNotFound: Resource being access is not found. 
""" + logger.info("Updating mlflow_tracking_server resource.") client = Base.get_sagemaker_client() operation_input_args = { - "JobDefinitionName": self.job_definition_name, + "TrackingServerName": self.tracking_server_name, + "ArtifactStoreUri": artifact_store_uri, + "TrackingServerSize": tracking_server_size, + "AutomaticModelRegistration": automatic_model_registration, + "WeeklyMaintenanceWindowStart": weekly_maintenance_window_start, } + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_model_bias_job_definition(**operation_input_args) + # create the resource + response = client.update_mlflow_tracking_server(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + return self - @classmethod @Base.add_validate_call - def get_all( - cls, - endpoint_name: Optional[str] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> ResourceIterator["ModelBiasJobDefinition"]: + def delete( + self, + ) -> None: """ - Get all ModelBiasJobDefinition resources - - Parameters: - endpoint_name: Name of the endpoint to monitor for model bias. - sort_by: Whether to sort results by the Name or CreationTime field. The default is CreationTime. - sort_order: Whether to sort the results in Ascending or Descending order. The default is Descending. - next_token: The token returned if the response is truncated. To retrieve the next set of job executions, use it in the next request. 
- max_results: The maximum number of model bias jobs to return in the response. The default value is 10. - name_contains: Filter for model bias jobs whose name contains a specified string. - creation_time_before: A filter that returns only model bias jobs created before a specified time. - creation_time_after: A filter that returns only model bias jobs created after a specified time. - session: Boto3 session. - region: Region name. - - Returns: - Iterator for listed ModelBiasJobDefinition resources. + Delete a MlflowTrackingServer resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -17300,126 +17360,26 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) + client = Base.get_sagemaker_client() operation_input_args = { - "EndpointName": endpoint_name, - "SortBy": sort_by, - "SortOrder": sort_order, - "NameContains": name_contains, - "CreationTimeBefore": creation_time_before, - "CreationTimeAfter": creation_time_after, - } - custom_key_mapping = { - "monitoring_job_definition_name": "job_definition_name", - "monitoring_job_definition_arn": "job_definition_arn", + "TrackingServerName": self.tracking_server_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - return ResourceIterator( - client=client, - list_method="list_model_bias_job_definitions", - summaries_key="JobDefinitionSummaries", - summary_name="MonitoringJobDefinitionSummary", - resource_cls=ModelBiasJobDefinition, - custom_key_mapping=custom_key_mapping, - list_method_kwargs=operation_input_args, - ) - - -class ModelCard(Base): - """ - Class representing resource ModelCard - - Attributes: - model_card_arn: The Amazon Resource 
Name (ARN) of the model card. - model_card_name: The name of the model card. - model_card_version: The version of the model card. - content: The content of the model card. - model_card_status: The approval status of the model card within your organization. Different organizations might have different criteria for model card review and approval. Draft: The model card is a work in progress. PendingReview: The model card is pending review. Approved: The model card is approved. Archived: The model card is archived. No more updates should be made to the model card, but it can still be exported. - creation_time: The date and time the model card was created. - created_by: - security_config: The security configuration used to protect model card content. - last_modified_time: The date and time the model card was last modified. - last_modified_by: - model_card_processing_status: The processing status of model card deletion. The ModelCardProcessingStatus updates throughout the different deletion steps. DeletePending: Model card deletion request received. DeleteInProgress: Model card deletion is in progress. ContentDeleted: Deleted model card content. ExportJobsDeleted: Deleted all export jobs associated with the model card. DeleteCompleted: Successfully deleted the model card. DeleteFailed: The model card failed to delete. 
- - """ - - model_card_name: str - model_card_arn: Optional[str] = Unassigned() - model_card_version: Optional[int] = Unassigned() - content: Optional[str] = Unassigned() - model_card_status: Optional[str] = Unassigned() - security_config: Optional[ModelCardSecurityConfig] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - created_by: Optional[UserContext] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - last_modified_by: Optional[UserContext] = Unassigned() - model_card_processing_status: Optional[str] = Unassigned() - - def get_name(self) -> str: - attributes = vars(self) - resource_name = "model_card_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] - - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) - - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object model_card") - return None - - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = {"security_config": {"kms_key_id": {"type": "string"}}} - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "ModelCard", **kwargs - ), - ) + client.delete_mlflow_tracking_server(**operation_input_args) - return wrapper + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") - @classmethod - @populate_inputs_decorator @Base.add_validate_call - def create( - cls, - model_card_name: str, - content: str, - model_card_status: str, - security_config: Optional[ModelCardSecurityConfig] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional["ModelCard"]: + def stop(self) -> None: """ - 
Create a ModelCard resource - - Parameters: - model_card_name: The unique name of the model card. - content: The content of the model card. Content must be in model card JSON schema and provided as a string. - model_card_status: The approval status of the model card within your organization. Different organizations might have different criteria for model card review and approval. Draft: The model card is a work in progress. PendingReview: The model card is pending review. Approved: The model card is approved. Archived: The model card is archived. No more updates should be made to the model card, but it can still be exported. - security_config: An optional Key Management Service key to encrypt, decrypt, and re-encrypt model card content for regulated workloads with highly sensitive data. - tags: Key-value pairs used to manage metadata for model cards. - session: Boto3 session. - region: Region name. - - Returns: - The ModelCard resource. + Stop a MlflowTrackingServer resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -17432,223 +17392,49 @@ def create( error_code = e.response['Error']['Code'] ``` ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema - LocalConfigNotFoundError: Raised when a configuration file is not found in local file system - S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + ResourceNotFound: Resource being access is not found. 
""" - logger.info("Creating model_card resource.") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) + client = SageMakerClient().client operation_input_args = { - "ModelCardName": model_card_name, - "SecurityConfig": security_config, - "Content": content, - "ModelCardStatus": model_card_status, - "Tags": tags, + "TrackingServerName": self.tracking_server_name, } - - operation_input_args = Base.populate_chained_attributes( - resource_name="ModelCard", operation_input_args=operation_input_args - ) - - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - # create the resource - response = client.create_model_card(**operation_input_args) - logger.debug(f"Response: {response}") + client.stop_mlflow_tracking_server(**operation_input_args) - return cls.get(model_card_name=model_card_name, session=session, region=region) - - @classmethod - @Base.add_validate_call - def get( - cls, - model_card_name: str, - model_card_version: Optional[int] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional["ModelCard"]: - """ - Get a ModelCard resource - - Parameters: - model_card_name: The name or Amazon Resource Name (ARN) of the model card to describe. - model_card_version: The version of the model card to describe. If a version is not provided, then the latest version of the model card is described. - session: Boto3 session. - region: Region name. - - Returns: - The ModelCard resource. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
- The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceNotFound: Resource being access is not found. - """ - - operation_input_args = { - "ModelCardName": model_card_name, - "ModelCardVersion": model_card_version, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - response = client.describe_model_card(**operation_input_args) - - logger.debug(response) - - # deserialize the response - transformed_response = transform(response, "DescribeModelCardResponse") - model_card = cls(**transformed_response) - return model_card - - @Base.add_validate_call - def refresh( - self, - ) -> Optional["ModelCard"]: - """ - Refresh a ModelCard resource - - Returns: - The ModelCard resource. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ResourceNotFound: Resource being access is not found. 
- """ - - operation_input_args = { - "ModelCardName": self.model_card_name, - "ModelCardVersion": self.model_card_version, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client = Base.get_sagemaker_client() - response = client.describe_model_card(**operation_input_args) - - # deserialize response and update self - transform(response, "DescribeModelCardResponse", self) - return self - - @populate_inputs_decorator - @Base.add_validate_call - def update( - self, - content: Optional[str] = Unassigned(), - model_card_status: Optional[str] = Unassigned(), - ) -> Optional["ModelCard"]: - """ - Update a ModelCard resource - - Returns: - The ModelCard resource. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ResourceNotFound: Resource being access is not found. 
- """ - - logger.info("Updating model_card resource.") - client = Base.get_sagemaker_client() - - operation_input_args = { - "ModelCardName": self.model_card_name, - "Content": content, - "ModelCardStatus": model_card_status, - } - logger.debug(f"Input request: {operation_input_args}") - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - # create the resource - response = client.update_model_card(**operation_input_args) - logger.debug(f"Response: {response}") - self.refresh() - - return self - - @Base.add_validate_call - def delete( - self, - ) -> None: - """ - Delete a ModelCard resource - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. - ResourceNotFound: Resource being access is not found. 
- """ - - client = Base.get_sagemaker_client() - - operation_input_args = { - "ModelCardName": self.model_card_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client.delete_model_card(**operation_input_args) - - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call def wait_for_status( self, - target_status: Literal["Draft", "PendingReview", "Approved", "Archived"], + target_status: Literal[ + "Creating", + "Created", + "CreateFailed", + "Updating", + "Updated", + "UpdateFailed", + "Deleting", + "DeleteFailed", + "Stopping", + "Stopped", + "StopFailed", + "Starting", + "Started", + "StartFailed", + "MaintenanceInProgress", + "MaintenanceComplete", + "MaintenanceFailed", + ], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a ModelCard resource to reach certain status. + Wait for a MlflowTrackingServer resource to reach certain status. Parameters: target_status: The status to wait for. @@ -17667,7 +17453,9 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task(f"Waiting for ModelCard to reach [bold]{target_status} status...") + progress.add_task( + f"Waiting for MlflowTrackingServer to reach [bold]{target_status} status..." 
+ ) status = Status("Current status:") with Live( @@ -17680,47 +17468,38 @@ def wait_for_status( ): while True: self.refresh() - current_status = self.model_card_status + current_status = self.tracking_server_status status.update(f"Current status: [bold]{current_status}") if target_status == current_status: logger.info(f"Final Resource Status: [bold]{current_status}") return + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="MlflowTrackingServer", + status=current_status, + reason="(Unknown)", + ) + if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="ModelCard", status=current_status) + raise TimeoutExceededError( + resouce_type="MlflowTrackingServer", status=current_status + ) time.sleep(poll) - @classmethod @Base.add_validate_call - def get_all( - cls, - creation_time_after: Optional[datetime.datetime] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - model_card_status: Optional[str] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> ResourceIterator["ModelCard"]: + def wait_for_delete( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: """ - Get all ModelCard resources + Wait for a MlflowTrackingServer resource to be deleted. Parameters: - creation_time_after: Only list model cards that were created after the time specified. - creation_time_before: Only list model cards that were created before the time specified. - max_results: The maximum number of model cards to list. - name_contains: Only list model cards with names that contain the specified string. - model_card_status: Only list model cards with the specified approval status. - next_token: If the response to a previous ListModelCards request was truncated, the response includes a NextToken. 
To retrieve the next set of model cards, use the token in the next request. - sort_by: Sort model cards by either name or creation time. Sorts by creation time by default. - sort_order: Sort model cards by ascending or descending order. - session: Boto3 session. - region: Region name. - - Returns: - Iterator for listed ModelCard resources. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -17732,59 +17511,76 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + DeleteFailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. """ + start_time = time.time() - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), ) + progress.add_task("Waiting for MlflowTrackingServer to be deleted...") + status = Status("Current status:") - operation_input_args = { - "CreationTimeAfter": creation_time_after, - "CreationTimeBefore": creation_time_before, - "NameContains": name_contains, - "ModelCardStatus": model_card_status, - "SortBy": sort_by, - "SortOrder": sort_order, - } - - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ) + ): + while True: + try: + self.refresh() + current_status = self.tracking_server_status + status.update(f"Current status: [bold]{current_status}") - return ResourceIterator( - client=client, - 
list_method="list_model_cards", - summaries_key="ModelCardSummaries", - summary_name="ModelCardSummary", - resource_cls=ModelCard, - list_method_kwargs=operation_input_args, - ) + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="MlflowTrackingServer", status=current_status + ) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] + + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. It may have been deleted.") + return + raise e + time.sleep(poll) + @classmethod @Base.add_validate_call - def get_all_versions( - self, - creation_time_after: Optional[datetime.datetime] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), + def get_all( + cls, + created_after: Optional[datetime.datetime] = Unassigned(), + created_before: Optional[datetime.datetime] = Unassigned(), + tracking_server_status: Optional[str] = Unassigned(), + mlflow_version: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator[ModelCardVersionSummary]: + ) -> ResourceIterator["MlflowTrackingServer"]: """ - List existing versions of an Amazon SageMaker Model Card. + Get all MlflowTrackingServer resources Parameters: - creation_time_after: Only list model card versions that were created after the time specified. - creation_time_before: Only list model card versions that were created before the time specified. - max_results: The maximum number of model card versions to list. - next_token: If the response to a previous ListModelCardVersions request was truncated, the response includes a NextToken. To retrieve the next set of model card versions, use the token in the next request. - sort_by: Sort listed model card versions by version. Sorts by version by default. 
- sort_order: Sort model card versions by ascending or descending order. + created_after: Use the CreatedAfter filter to only list tracking servers created after a specific date and time. Listed tracking servers are shown with a date and time such as "2024-03-16T01:46:56+00:00". The CreatedAfter parameter takes in a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter. + created_before: Use the CreatedBefore filter to only list tracking servers created before a specific date and time. Listed tracking servers are shown with a date and time such as "2024-03-16T01:46:56+00:00". The CreatedBefore parameter takes in a Unix timestamp. To convert a date and time into a Unix timestamp, see EpochConverter. + tracking_server_status: Filter for tracking servers with a specified creation status. + mlflow_version: Filter for tracking servers using the specified MLflow version. + sort_by: Filter for trackings servers sorting by name, creation time, or creation status. + sort_order: Change the order of the listed tracking servers. By default, tracking servers are listed in Descending order by creation time. To change the list order, you can specify SortOrder to be Ascending. + next_token: If the previous response was truncated, you will receive this token. Use it in your next request to receive the next set of results. + max_results: The maximum number of tracking servers to list. session: Boto3 session. region: Region name. Returns: - Iterator for listed ModelCardVersionSummary. + Iterator for listed MlflowTrackingServer resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -17796,67 +17592,67 @@ def get_all_versions( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
""" + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + operation_input_args = { - "CreationTimeAfter": creation_time_after, - "CreationTimeBefore": creation_time_before, - "ModelCardName": self.model_card_name, - "ModelCardStatus": self.model_card_status, + "CreatedAfter": created_after, + "CreatedBefore": created_before, + "TrackingServerStatus": tracking_server_status, + "MlflowVersion": mlflow_version, "SortBy": sort_by, "SortOrder": sort_order, } + # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - return ResourceIterator( client=client, - list_method="list_model_card_versions", - summaries_key="ModelCardVersionSummaryList", - summary_name="ModelCardVersionSummary", - resource_cls=ModelCardVersionSummary, + list_method="list_mlflow_tracking_servers", + summaries_key="TrackingServerSummaries", + summary_name="TrackingServerSummary", + resource_cls=MlflowTrackingServer, list_method_kwargs=operation_input_args, ) -class ModelCardExportJob(Base): +class Model(Base): """ - Class representing resource ModelCardExportJob + Class representing resource Model Attributes: - model_card_export_job_name: The name of the model card export job to describe. - model_card_export_job_arn: The Amazon Resource Name (ARN) of the model card export job. - status: The completion status of the model card export job. InProgress: The model card export job is in progress. Completed: The model card export job is complete. Failed: The model card export job failed. To see the reason for the failure, see the FailureReason field in the response to a DescribeModelCardExportJob call. - model_card_name: The name or Amazon Resource Name (ARN) of the model card that the model export job exports. 
- model_card_version: The version of the model card that the model export job exports. - output_config: The export output details for the model card. - created_at: The date and time that the model export job was created. - last_modified_at: The date and time that the model export job was last modified. - failure_reason: The failure reason if the model export job fails. - export_artifacts: The exported model card artifacts. + model_name: Name of the SageMaker model. + creation_time: A timestamp that shows when the model was created. + model_arn: The Amazon Resource Name (ARN) of the model. + primary_container: The location of the primary inference code, associated artifacts, and custom environment map that the inference code uses when it is deployed in production. + containers: The containers in the inference pipeline. + inference_execution_config: Specifies details of how containers in a multi-container endpoint are called. + execution_role_arn: The Amazon Resource Name (ARN) of the IAM role that you specified for the model. + vpc_config: A VpcConfig object that specifies the VPC that this model has access to. For more information, see Protect Endpoints by Using an Amazon Virtual Private Cloud + enable_network_isolation: If True, no inbound or outbound network calls can be made to or from the model container. + deployment_recommendation: A set of recommended deployment configurations for the model. 
""" - model_card_export_job_arn: str - model_card_export_job_name: Optional[str] = Unassigned() - status: Optional[str] = Unassigned() - model_card_name: Optional[str] = Unassigned() - model_card_version: Optional[int] = Unassigned() - output_config: Optional[ModelCardExportOutputConfig] = Unassigned() - created_at: Optional[datetime.datetime] = Unassigned() - last_modified_at: Optional[datetime.datetime] = Unassigned() - failure_reason: Optional[str] = Unassigned() - export_artifacts: Optional[ModelCardExportArtifacts] = Unassigned() + model_name: str + primary_container: Optional[ContainerDefinition] = Unassigned() + containers: Optional[List[ContainerDefinition]] = Unassigned() + inference_execution_config: Optional[InferenceExecutionConfig] = Unassigned() + execution_role_arn: Optional[str] = Unassigned() + vpc_config: Optional[VpcConfig] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + model_arn: Optional[str] = Unassigned() + enable_network_isolation: Optional[bool] = Unassigned() + deployment_recommendation: Optional[DeploymentRecommendation] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "model_card_export_job_name" + resource_name = "model_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -17867,20 +17663,32 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object model_card_export_job") + logger.error("Name attribute not found for object model") return None def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): config_schema_for_resource = { - "output_config": {"s3_output_path": {"type": "string"}}, - "export_artifacts": {"s3_export_artifacts": {"type": "string"}}, + "primary_container": { + "model_data_source": { + "s3_data_source": { + "s3_uri": {"type": 
"string"}, + "s3_data_type": {"type": "string"}, + "manifest_s3_uri": {"type": "string"}, + } + } + }, + "execution_role_arn": {"type": "string"}, + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + }, } return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "ModelCardExportJob", **kwargs + config_schema_for_resource, "Model", **kwargs ), ) @@ -17891,26 +17699,34 @@ def wrapper(*args, **kwargs): @Base.add_validate_call def create( cls, - model_card_name: Union[str, object], - model_card_export_job_name: str, - output_config: ModelCardExportOutputConfig, - model_card_version: Optional[int] = Unassigned(), + model_name: str, + primary_container: Optional[ContainerDefinition] = Unassigned(), + containers: Optional[List[ContainerDefinition]] = Unassigned(), + inference_execution_config: Optional[InferenceExecutionConfig] = Unassigned(), + execution_role_arn: Optional[str] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), + vpc_config: Optional[VpcConfig] = Unassigned(), + enable_network_isolation: Optional[bool] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ModelCardExportJob"]: + ) -> Optional["Model"]: """ - Create a ModelCardExportJob resource + Create a Model resource Parameters: - model_card_name: The name or Amazon Resource Name (ARN) of the model card to export. - model_card_export_job_name: The name of the model card export job. - output_config: The model card output configuration that specifies the Amazon S3 path for exporting. - model_card_version: The version of the model card to export. If a version is not provided, then the latest version of the model card is exported. + model_name: The name of the new model. 
+ primary_container: The location of the primary docker image containing inference code, associated artifacts, and custom environment map that the inference code uses when the model is deployed for predictions. + containers: Specifies the containers in the inference pipeline. + inference_execution_config: Specifies details of how containers in a multi-container endpoint are called. + execution_role_arn: The Amazon Resource Name (ARN) of the IAM role that SageMaker can assume to access model artifacts and docker image for deployment on ML compute instances or for batch transform jobs. Deploying on ML compute instances is part of model hosting. For more information, see SageMaker Roles. To be able to pass this role to SageMaker, the caller of this API must have the iam:PassRole permission. + tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. + vpc_config: A VpcConfig object that specifies the VPC that you want your model to connect to. Control access to and from your model container by configuring the VPC. VpcConfig is used in hosting services and in batch transform. For more information, see Protect Endpoints by Using an Amazon Virtual Private Cloud and Protect Data in Batch Transform Jobs by Using an Amazon Virtual Private Cloud. + enable_network_isolation: Isolates the model container. No inbound or outbound network calls can be made to or from the model container. session: Boto3 session. region: Region name. Returns: - The ModelCardExportJob resource. + The Model resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -17922,28 +17738,30 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ResourceNotFound: Resource being access is not found. ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating model_card_export_job resource.") + logger.info("Creating model resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "ModelCardName": model_card_name, - "ModelCardVersion": model_card_version, - "ModelCardExportJobName": model_card_export_job_name, - "OutputConfig": output_config, + "ModelName": model_name, + "PrimaryContainer": primary_container, + "Containers": containers, + "InferenceExecutionConfig": inference_execution_config, + "ExecutionRoleArn": execution_role_arn, + "Tags": tags, + "VpcConfig": vpc_config, + "EnableNetworkIsolation": enable_network_isolation, } operation_input_args = Base.populate_chained_attributes( - resource_name="ModelCardExportJob", operation_input_args=operation_input_args + resource_name="Model", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -17952,33 +17770,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_model_card_export_job(**operation_input_args) + response = client.create_model(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get( - 
model_card_export_job_arn=response["ModelCardExportJobArn"], - session=session, - region=region, - ) + return cls.get(model_name=model_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - model_card_export_job_arn: str, + model_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ModelCardExportJob"]: + ) -> Optional["Model"]: """ - Get a ModelCardExportJob resource + Get a Model resource Parameters: - model_card_export_job_arn: The Amazon Resource Name (ARN) of the model card export job to describe. + model_name: The name of the model. session: Boto3 session. region: Region name. Returns: - The ModelCardExportJob resource. + The Model resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -17990,11 +17804,10 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "ModelCardExportJobArn": model_card_export_job_arn, + "ModelName": model_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -18003,24 +17816,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_model_card_export_job(**operation_input_args) + response = client.describe_model(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeModelCardExportJobResponse") - model_card_export_job = cls(**transformed_response) - return model_card_export_job + transformed_response = transform(response, "DescribeModelOutput") + model = cls(**transformed_response) + return model @Base.add_validate_call def refresh( self, - ) -> Optional["ModelCardExportJob"]: + ) -> Optional["Model"]: """ - Refresh a ModelCardExportJob resource + Refresh a Model resource Returns: - The ModelCardExportJob resource. + The Model resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -18032,118 +17845,82 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "ModelCardExportJobArn": self.model_card_export_job_arn, + "ModelName": self.model_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_model_card_export_job(**operation_input_args) + response = client.describe_model(**operation_input_args) # deserialize response and update self - transform(response, "DescribeModelCardExportJobResponse", self) + transform(response, "DescribeModelOutput", self) return self @Base.add_validate_call - def wait( + def delete( self, - poll: int = 5, - timeout: Optional[int] = None, ) -> None: """ - Wait for a ModelCardExportJob resource. - - Parameters: - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. + Delete a Model resource Raises: - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - FailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. - + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
+ The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` """ - terminal_states = ["Completed", "Failed"] - start_time = time.time() - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task("Waiting for ModelCardExportJob...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ), - transient=True, - ): - while True: - self.refresh() - current_status = self.status - status.update(f"Current status: [bold]{current_status}") - - if current_status in terminal_states: - logger.info(f"Final Resource Status: [bold]{current_status}") + client = Base.get_sagemaker_client() - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="ModelCardExportJob", - status=current_status, - reason=self.failure_reason, - ) + operation_input_args = { + "ModelName": self.model_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") - return + client.delete_model(**operation_input_args) - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="ModelCardExportJob", status=current_status - ) - time.sleep(poll) + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @classmethod @Base.add_validate_call def get_all( cls, - model_card_name: str, - model_card_version: Optional[int] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - model_card_export_job_name_contains: Optional[str] = Unassigned(), - 
status_equals: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["ModelCardExportJob"]: + ) -> ResourceIterator["Model"]: """ - Get all ModelCardExportJob resources + Get all Model resources Parameters: - model_card_name: List export jobs for the model card with the specified name. - model_card_version: List export jobs for the model card with the specified version. - creation_time_after: Only list model card export jobs that were created after the time specified. - creation_time_before: Only list model card export jobs that were created before the time specified. - model_card_export_job_name_contains: Only list model card export jobs with names that contain the specified string. - status_equals: Only list model card export jobs with the specified status. - sort_by: Sort model card export jobs by either name or creation time. Sorts by creation time by default. - sort_order: Sort model card export jobs by ascending or descending order. - next_token: If the response to a previous ListModelCardExportJobs request was truncated, the response includes a NextToken. To retrieve the next set of model card export jobs, use the token in the next request. - max_results: The maximum number of model card export jobs to list. + sort_by: Sorts the list of results. The default is CreationTime. + sort_order: The sort order for results. The default is Descending. + next_token: If the response to a previous ListModels request was truncated, the response includes a NextToken. To retrieve the next set of models, use the token in the next request. + max_results: The maximum number of models to return in the response. + name_contains: A string in the model name. 
This filter returns only models whose name contains the specified string. + creation_time_before: A filter that returns only models created before the specified time (timestamp). + creation_time_after: A filter that returns only models with a creation time greater than or equal to the specified time (timestamp). session: Boto3 session. region: Region name. Returns: - Iterator for listed ModelCardExportJob resources. + Iterator for listed Model resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -18162,14 +17939,11 @@ def get_all( ) operation_input_args = { - "ModelCardName": model_card_name, - "ModelCardVersion": model_card_version, - "CreationTimeAfter": creation_time_after, - "CreationTimeBefore": creation_time_before, - "ModelCardExportJobNameContains": model_card_export_job_name_contains, - "StatusEquals": status_equals, "SortBy": sort_by, "SortOrder": sort_order, + "NameContains": name_contains, + "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, } # serialize the input request @@ -18178,29 +17952,81 @@ def get_all( return ResourceIterator( client=client, - list_method="list_model_card_export_jobs", - summaries_key="ModelCardExportJobSummaries", - summary_name="ModelCardExportJobSummary", - resource_cls=ModelCardExportJob, + list_method="list_models", + summaries_key="Models", + summary_name="ModelSummary", + resource_cls=Model, list_method_kwargs=operation_input_args, ) + @Base.add_validate_call + def get_all_metadata( + self, + search_expression: Optional[ModelMetadataSearchExpression] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator[ModelMetadataSummary]: + """ + Lists the domain, framework, task, and model name of standard machine learning models found in common model zoos. + + Parameters: + search_expression: One or more filters that searches for the specified resource or resources in a search. 
All resource objects that satisfy the expression's condition are included in the search results. Specify the Framework, FrameworkVersion, Domain or Task to filter supported. Filter names and values are case-sensitive. + next_token: If the response to a previous ListModelMetadataResponse request was truncated, the response includes a NextToken. To retrieve the next set of model metadata, use the token in the next request. + max_results: The maximum number of models to return in the response. + session: Boto3 session. + region: Region name. -class ModelExplainabilityJobDefinition(Base): + Returns: + Iterator for listed ModelMetadataSummary. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + operation_input_args = { + "SearchExpression": search_expression, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + return ResourceIterator( + client=client, + list_method="list_model_metadata", + summaries_key="ModelMetadataSummaries", + summary_name="ModelMetadataSummary", + resource_cls=ModelMetadataSummary, + list_method_kwargs=operation_input_args, + ) + + +class ModelBiasJobDefinition(Base): """ - Class representing resource ModelExplainabilityJobDefinition + Class representing resource ModelBiasJobDefinition Attributes: - job_definition_arn: The Amazon Resource Name (ARN) of the model explainability job. - job_definition_name: The name of the explainability job definition. 
The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. - creation_time: The time at which the model explainability job was created. - model_explainability_app_specification: Configures the model explainability job to run a specified Docker container image. - model_explainability_job_input: Inputs for the model explainability job. - model_explainability_job_output_config: + job_definition_arn: The Amazon Resource Name (ARN) of the model bias job. + job_definition_name: The name of the bias job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + creation_time: The time at which the model bias job was created. + model_bias_app_specification: Configures the model bias job to run a specified Docker container image. + model_bias_job_input: Inputs for the model bias job. + model_bias_job_output_config: job_resources: role_arn: The Amazon Resource Name (ARN) of the IAM role that has read permission to the input data location and write permission to the output data location in Amazon S3. - model_explainability_baseline_config: The baseline configuration for a model explainability job. - network_config: Networking options for a model explainability job. + model_bias_baseline_config: The baseline configuration for a model bias job. + network_config: Networking options for a model bias job. 
stopping_condition: """ @@ -18208,12 +18034,10 @@ class ModelExplainabilityJobDefinition(Base): job_definition_name: str job_definition_arn: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - model_explainability_baseline_config: Optional[ModelExplainabilityBaselineConfig] = Unassigned() - model_explainability_app_specification: Optional[ModelExplainabilityAppSpecification] = ( - Unassigned() - ) - model_explainability_job_input: Optional[ModelExplainabilityJobInput] = Unassigned() - model_explainability_job_output_config: Optional[MonitoringOutputConfig] = Unassigned() + model_bias_baseline_config: Optional[ModelBiasBaselineConfig] = Unassigned() + model_bias_app_specification: Optional[ModelBiasAppSpecification] = Unassigned() + model_bias_job_input: Optional[ModelBiasJobInput] = Unassigned() + model_bias_job_output_config: Optional[MonitoringOutputConfig] = Unassigned() job_resources: Optional[MonitoringResources] = Unassigned() network_config: Optional[MonitoringNetworkConfig] = Unassigned() role_arn: Optional[str] = Unassigned() @@ -18221,7 +18045,7 @@ class ModelExplainabilityJobDefinition(Base): def get_name(self) -> str: attributes = vars(self) - resource_name = "model_explainability_job_definition_name" + resource_name = "model_bias_job_definition_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -18232,14 +18056,15 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object model_explainability_job_definition") + logger.error("Name attribute not found for object model_bias_job_definition") return None def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): config_schema_for_resource = { - "model_explainability_job_input": { + "model_bias_job_input": { + "ground_truth_s3_input": {"s3_uri": 
{"type": "string"}}, "endpoint_input": { "s3_input_mode": {"type": "string"}, "s3_data_distribution_type": {"type": "string"}, @@ -18250,10 +18075,10 @@ def wrapper(*args, **kwargs): "s3_data_distribution_type": {"type": "string"}, }, }, - "model_explainability_job_output_config": {"kms_key_id": {"type": "string"}}, + "model_bias_job_output_config": {"kms_key_id": {"type": "string"}}, "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, "role_arn": {"type": "string"}, - "model_explainability_baseline_config": { + "model_bias_baseline_config": { "constraints_resource": {"s3_uri": {"type": "string"}} }, "network_config": { @@ -18266,7 +18091,7 @@ def wrapper(*args, **kwargs): return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "ModelExplainabilityJobDefinition", **kwargs + config_schema_for_resource, "ModelBiasJobDefinition", **kwargs ), ) @@ -18278,39 +18103,37 @@ def wrapper(*args, **kwargs): def create( cls, job_definition_name: str, - model_explainability_app_specification: ModelExplainabilityAppSpecification, - model_explainability_job_input: ModelExplainabilityJobInput, - model_explainability_job_output_config: MonitoringOutputConfig, + model_bias_app_specification: ModelBiasAppSpecification, + model_bias_job_input: ModelBiasJobInput, + model_bias_job_output_config: MonitoringOutputConfig, job_resources: MonitoringResources, role_arn: str, - model_explainability_baseline_config: Optional[ - ModelExplainabilityBaselineConfig - ] = Unassigned(), + model_bias_baseline_config: Optional[ModelBiasBaselineConfig] = Unassigned(), network_config: Optional[MonitoringNetworkConfig] = Unassigned(), stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ModelExplainabilityJobDefinition"]: + ) -> Optional["ModelBiasJobDefinition"]: """ - Create a 
ModelExplainabilityJobDefinition resource + Create a ModelBiasJobDefinition resource Parameters: - job_definition_name: The name of the model explainability job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. - model_explainability_app_specification: Configures the model explainability job to run a specified Docker container image. - model_explainability_job_input: Inputs for the model explainability job. - model_explainability_job_output_config: + job_definition_name: The name of the bias job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + model_bias_app_specification: Configures the model bias job to run a specified Docker container image. + model_bias_job_input: Inputs for the model bias job. + model_bias_job_output_config: job_resources: role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. - model_explainability_baseline_config: The baseline configuration for a model explainability job. - network_config: Networking options for a model explainability job. + model_bias_baseline_config: The baseline configuration for a model bias job. + network_config: Networking options for a model bias job. stopping_condition: tags: (Optional) An array of key-value pairs. For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide. session: Boto3 session. region: Region name. Returns: - The ModelExplainabilityJobDefinition resource. + The ModelBiasJobDefinition resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -18329,17 +18152,17 @@ def create( S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating model_explainability_job_definition resource.") + logger.info("Creating model_bias_job_definition resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { "JobDefinitionName": job_definition_name, - "ModelExplainabilityBaselineConfig": model_explainability_baseline_config, - "ModelExplainabilityAppSpecification": model_explainability_app_specification, - "ModelExplainabilityJobInput": model_explainability_job_input, - "ModelExplainabilityJobOutputConfig": model_explainability_job_output_config, + "ModelBiasBaselineConfig": model_bias_baseline_config, + "ModelBiasAppSpecification": model_bias_app_specification, + "ModelBiasJobInput": model_bias_job_input, + "ModelBiasJobOutputConfig": model_bias_job_output_config, "JobResources": job_resources, "NetworkConfig": network_config, "RoleArn": role_arn, @@ -18348,8 +18171,7 @@ def create( } operation_input_args = Base.populate_chained_attributes( - resource_name="ModelExplainabilityJobDefinition", - operation_input_args=operation_input_args, + resource_name="ModelBiasJobDefinition", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -18358,7 +18180,7 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_model_explainability_job_definition(**operation_input_args) + response = client.create_model_bias_job_definition(**operation_input_args) logger.debug(f"Response: {response}") return cls.get(job_definition_name=job_definition_name, session=session, region=region) @@ -18370,17 +18192,17 @@ def get( job_definition_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ModelExplainabilityJobDefinition"]: + ) -> Optional["ModelBiasJobDefinition"]: 
""" - Get a ModelExplainabilityJobDefinition resource + Get a ModelBiasJobDefinition resource Parameters: - job_definition_name: The name of the model explainability job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + job_definition_name: The name of the model bias job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. session: Boto3 session. region: Region name. Returns: - The ModelExplainabilityJobDefinition resource. + The ModelBiasJobDefinition resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -18405,26 +18227,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_model_explainability_job_definition(**operation_input_args) + response = client.describe_model_bias_job_definition(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform( - response, "DescribeModelExplainabilityJobDefinitionResponse" - ) - model_explainability_job_definition = cls(**transformed_response) - return model_explainability_job_definition + transformed_response = transform(response, "DescribeModelBiasJobDefinitionResponse") + model_bias_job_definition = cls(**transformed_response) + return model_bias_job_definition @Base.add_validate_call def refresh( self, - ) -> Optional["ModelExplainabilityJobDefinition"]: + ) -> Optional["ModelBiasJobDefinition"]: """ - Refresh a ModelExplainabilityJobDefinition resource + Refresh a ModelBiasJobDefinition resource Returns: - The ModelExplainabilityJobDefinition resource. + The ModelBiasJobDefinition resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -18447,10 +18267,10 @@ def refresh( logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_model_explainability_job_definition(**operation_input_args) + response = client.describe_model_bias_job_definition(**operation_input_args) # deserialize response and update self - transform(response, "DescribeModelExplainabilityJobDefinitionResponse", self) + transform(response, "DescribeModelBiasJobDefinitionResponse", self) return self @Base.add_validate_call @@ -18458,7 +18278,7 @@ def delete( self, ) -> None: """ - Delete a ModelExplainabilityJobDefinition resource + Delete a ModelBiasJobDefinition resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -18482,7 +18302,7 @@ def delete( operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_model_explainability_job_definition(**operation_input_args) + client.delete_model_bias_job_definition(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @@ -18498,24 +18318,24 @@ def get_all( creation_time_after: Optional[datetime.datetime] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["ModelExplainabilityJobDefinition"]: + ) -> ResourceIterator["ModelBiasJobDefinition"]: """ - Get all ModelExplainabilityJobDefinition resources + Get all ModelBiasJobDefinition resources Parameters: - endpoint_name: Name of the endpoint to monitor for model explainability. + endpoint_name: Name of the endpoint to monitor for model bias. sort_by: Whether to sort results by the Name or CreationTime field. The default is CreationTime. sort_order: Whether to sort the results in Ascending or Descending order. The default is Descending. next_token: The token returned if the response is truncated. 
To retrieve the next set of job executions, use it in the next request. - max_results: The maximum number of jobs to return in the response. The default value is 10. - name_contains: Filter for model explainability jobs whose name contains a specified string. - creation_time_before: A filter that returns only model explainability jobs created before a specified time. - creation_time_after: A filter that returns only model explainability jobs created after a specified time. + max_results: The maximum number of model bias jobs to return in the response. The default value is 10. + name_contains: Filter for model bias jobs whose name contains a specified string. + creation_time_before: A filter that returns only model bias jobs created before a specified time. + creation_time_after: A filter that returns only model bias jobs created after a specified time. session: Boto3 session. region: Region name. Returns: - Iterator for listed ModelExplainabilityJobDefinition resources. + Iterator for listed ModelBiasJobDefinition resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -18551,89 +18371,49 @@ def get_all( return ResourceIterator( client=client, - list_method="list_model_explainability_job_definitions", + list_method="list_model_bias_job_definitions", summaries_key="JobDefinitionSummaries", summary_name="MonitoringJobDefinitionSummary", - resource_cls=ModelExplainabilityJobDefinition, + resource_cls=ModelBiasJobDefinition, custom_key_mapping=custom_key_mapping, list_method_kwargs=operation_input_args, ) -class ModelPackage(Base): +class ModelCard(Base): """ - Class representing resource ModelPackage + Class representing resource ModelCard Attributes: - model_package_name: The name of the model package being described. - model_package_arn: The Amazon Resource Name (ARN) of the model package. - creation_time: A timestamp specifying when the model package was created. 
- model_package_status: The current status of the model package. - model_package_status_details: Details about the current status of the model package. - model_package_group_name: If the model is a versioned model, the name of the model group that the versioned model belongs to. - model_package_version: The version of the model package. - model_package_description: A brief summary of the model package. - inference_specification: Details about inference jobs that you can run with models based on this model package. - source_algorithm_specification: Details about the algorithm that was used to create the model package. - validation_specification: Configurations for one or more transform jobs that SageMaker runs to test the model package. - certify_for_marketplace: Whether the model package is certified for listing on Amazon Web Services Marketplace. - model_approval_status: The approval status of the model package. + model_card_arn: The Amazon Resource Name (ARN) of the model card. + model_card_name: The name of the model card. + model_card_version: The version of the model card. + content: The content of the model card. + model_card_status: The approval status of the model card within your organization. Different organizations might have different criteria for model card review and approval. Draft: The model card is a work in progress. PendingReview: The model card is pending review. Approved: The model card is approved. Archived: The model card is archived. No more updates should be made to the model card, but it can still be exported. + creation_time: The date and time the model card was created. created_by: - metadata_properties: - model_metrics: Metrics for the model. - last_modified_time: The last time that the model package was modified. + security_config: The security configuration used to protect model card content. + last_modified_time: The date and time the model card was last modified. 
last_modified_by: - approval_description: A description provided for the model approval. - domain: The machine learning domain of the model package you specified. Common machine learning domains include computer vision and natural language processing. - task: The machine learning task you specified that your model package accomplishes. Common machine learning tasks include object detection and image classification. - sample_payload_url: The Amazon Simple Storage Service (Amazon S3) path where the sample payload are stored. This path points to a single gzip compressed tar archive (.tar.gz suffix). - customer_metadata_properties: The metadata properties associated with the model package versions. - drift_check_baselines: Represents the drift check baselines that can be used when the model monitor is set using the model package. For more information, see the topic on Drift Detection against Previous Baselines in SageMaker Pipelines in the Amazon SageMaker Developer Guide. - additional_inference_specifications: An array of additional Inference Specification objects. Each additional Inference Specification specifies artifacts based on this model package that can be used on inference endpoints. Generally used with SageMaker Neo to store the compiled artifacts. - skip_model_validation: Indicates if you want to skip model validation. - source_uri: The URI of the source for the model package. - security_config: The KMS Key ID (KMSKeyId) used for encryption of model package information. - model_card: The model card associated with the model package. Since ModelPackageModelCard is tied to a model package, it is a specific usage of a model card and its schema is simplified compared to the schema of ModelCard. The ModelPackageModelCard schema does not include model_package_details, and model_overview is composed of the model_creator and model_artifact properties. For more information about the model package model card schema, see Model package model card schema. 
For more information about the model card associated with the model package, see View the Details of a Model Version. - model_life_cycle: A structure describing the current state of the model in its life cycle. + model_card_processing_status: The processing status of model card deletion. The ModelCardProcessingStatus updates throughout the different deletion steps. DeletePending: Model card deletion request received. DeleteInProgress: Model card deletion is in progress. ContentDeleted: Deleted model card content. ExportJobsDeleted: Deleted all export jobs associated with the model card. DeleteCompleted: Successfully deleted the model card. DeleteFailed: The model card failed to delete. """ - model_package_name: str - model_package_group_name: Optional[str] = Unassigned() - model_package_version: Optional[int] = Unassigned() - model_package_arn: Optional[str] = Unassigned() - model_package_description: Optional[str] = Unassigned() + model_card_name: str + model_card_arn: Optional[str] = Unassigned() + model_card_version: Optional[int] = Unassigned() + content: Optional[str] = Unassigned() + model_card_status: Optional[str] = Unassigned() + security_config: Optional[ModelCardSecurityConfig] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - inference_specification: Optional[InferenceSpecification] = Unassigned() - source_algorithm_specification: Optional[SourceAlgorithmSpecification] = Unassigned() - validation_specification: Optional[ModelPackageValidationSpecification] = Unassigned() - model_package_status: Optional[str] = Unassigned() - model_package_status_details: Optional[ModelPackageStatusDetails] = Unassigned() - certify_for_marketplace: Optional[bool] = Unassigned() - model_approval_status: Optional[str] = Unassigned() created_by: Optional[UserContext] = Unassigned() - metadata_properties: Optional[MetadataProperties] = Unassigned() - model_metrics: Optional[ModelMetrics] = Unassigned() last_modified_time: Optional[datetime.datetime] 
= Unassigned() last_modified_by: Optional[UserContext] = Unassigned() - approval_description: Optional[str] = Unassigned() - domain: Optional[str] = Unassigned() - task: Optional[str] = Unassigned() - sample_payload_url: Optional[str] = Unassigned() - customer_metadata_properties: Optional[Dict[str, str]] = Unassigned() - drift_check_baselines: Optional[DriftCheckBaselines] = Unassigned() - additional_inference_specifications: Optional[ - List[AdditionalInferenceSpecificationDefinition] - ] = Unassigned() - skip_model_validation: Optional[str] = Unassigned() - source_uri: Optional[str] = Unassigned() - security_config: Optional[ModelPackageSecurityConfig] = Unassigned() - model_card: Optional[ModelPackageModelCard] = Unassigned() - model_life_cycle: Optional[ModelLifeCycle] = Unassigned() + model_card_processing_status: Optional[str] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "model_package_name" + resource_name = "model_card_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -18644,55 +18424,17 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object model_package") + logger.error("Name attribute not found for object model_card") return None def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): - config_schema_for_resource = { - "validation_specification": {"validation_role": {"type": "string"}}, - "model_metrics": { - "model_quality": { - "statistics": {"s3_uri": {"type": "string"}}, - "constraints": {"s3_uri": {"type": "string"}}, - }, - "model_data_quality": { - "statistics": {"s3_uri": {"type": "string"}}, - "constraints": {"s3_uri": {"type": "string"}}, - }, - "bias": { - "report": {"s3_uri": {"type": "string"}}, - "pre_training_report": {"s3_uri": {"type": "string"}}, - 
"post_training_report": {"s3_uri": {"type": "string"}}, - }, - "explainability": {"report": {"s3_uri": {"type": "string"}}}, - }, - "drift_check_baselines": { - "bias": { - "config_file": {"s3_uri": {"type": "string"}}, - "pre_training_constraints": {"s3_uri": {"type": "string"}}, - "post_training_constraints": {"s3_uri": {"type": "string"}}, - }, - "explainability": { - "constraints": {"s3_uri": {"type": "string"}}, - "config_file": {"s3_uri": {"type": "string"}}, - }, - "model_quality": { - "statistics": {"s3_uri": {"type": "string"}}, - "constraints": {"s3_uri": {"type": "string"}}, - }, - "model_data_quality": { - "statistics": {"s3_uri": {"type": "string"}}, - "constraints": {"s3_uri": {"type": "string"}}, - }, - }, - "security_config": {"kms_key_id": {"type": "string"}}, - } + config_schema_for_resource = {"security_config": {"kms_key_id": {"type": "string"}}} return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "ModelPackage", **kwargs + config_schema_for_resource, "ModelCard", **kwargs ), ) @@ -18703,66 +18445,28 @@ def wrapper(*args, **kwargs): @Base.add_validate_call def create( cls, - model_package_name: Optional[str] = Unassigned(), - model_package_group_name: Optional[Union[str, object]] = Unassigned(), - model_package_description: Optional[str] = Unassigned(), - inference_specification: Optional[InferenceSpecification] = Unassigned(), - validation_specification: Optional[ModelPackageValidationSpecification] = Unassigned(), - source_algorithm_specification: Optional[SourceAlgorithmSpecification] = Unassigned(), - certify_for_marketplace: Optional[bool] = Unassigned(), + model_card_name: str, + content: str, + model_card_status: str, + security_config: Optional[ModelCardSecurityConfig] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - model_approval_status: Optional[str] = Unassigned(), - metadata_properties: Optional[MetadataProperties] = Unassigned(), - model_metrics: 
Optional[ModelMetrics] = Unassigned(), - client_token: Optional[str] = Unassigned(), - domain: Optional[str] = Unassigned(), - task: Optional[str] = Unassigned(), - sample_payload_url: Optional[str] = Unassigned(), - customer_metadata_properties: Optional[Dict[str, str]] = Unassigned(), - drift_check_baselines: Optional[DriftCheckBaselines] = Unassigned(), - additional_inference_specifications: Optional[ - List[AdditionalInferenceSpecificationDefinition] - ] = Unassigned(), - skip_model_validation: Optional[str] = Unassigned(), - source_uri: Optional[str] = Unassigned(), - security_config: Optional[ModelPackageSecurityConfig] = Unassigned(), - model_card: Optional[ModelPackageModelCard] = Unassigned(), - model_life_cycle: Optional[ModelLifeCycle] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ModelPackage"]: + ) -> Optional["ModelCard"]: """ - Create a ModelPackage resource + Create a ModelCard resource Parameters: - model_package_name: The name of the model package. The name must have 1 to 63 characters. Valid characters are a-z, A-Z, 0-9, and - (hyphen). This parameter is required for unversioned models. It is not applicable to versioned models. - model_package_group_name: The name or Amazon Resource Name (ARN) of the model package group that this model version belongs to. This parameter is required for versioned models, and does not apply to unversioned models. - model_package_description: A description of the model package. - inference_specification: Specifies details about inference jobs that you can run with models based on this model package, including the following information: The Amazon ECR paths of containers that contain the inference code and model artifacts. The instance types that the model package supports for transform jobs and real-time endpoints used for inference. The input and output content formats that the model package supports for inference. 
- validation_specification: Specifies configurations for one or more transform jobs that SageMaker runs to test the model package. - source_algorithm_specification: Details about the algorithm that was used to create the model package. - certify_for_marketplace: Whether to certify the model package for listing on Amazon Web Services Marketplace. This parameter is optional for unversioned models, and does not apply to versioned models. - tags: A list of key value pairs associated with the model. For more information, see Tagging Amazon Web Services resources in the Amazon Web Services General Reference Guide. If you supply ModelPackageGroupName, your model package belongs to the model group you specify and uses the tags associated with the model group. In this case, you cannot supply a tag argument. - model_approval_status: Whether the model is approved for deployment. This parameter is optional for versioned models, and does not apply to unversioned models. For versioned models, the value of this parameter must be set to Approved to deploy the model. - metadata_properties: - model_metrics: A structure that contains model metrics reports. - client_token: A unique token that guarantees that the call to this API is idempotent. - domain: The machine learning domain of your model package and its components. Common machine learning domains include computer vision and natural language processing. - task: The machine learning task your model package accomplishes. Common machine learning tasks include object detection and image classification. The following tasks are supported by Inference Recommender: "IMAGE_CLASSIFICATION" \| "OBJECT_DETECTION" \| "TEXT_GENERATION" \|"IMAGE_SEGMENTATION" \| "FILL_MASK" \| "CLASSIFICATION" \| "REGRESSION" \| "OTHER". Specify "OTHER" if none of the tasks listed fit your use case. - sample_payload_url: The Amazon Simple Storage Service (Amazon S3) path where the sample payload is stored. 
This path must point to a single gzip compressed tar archive (.tar.gz suffix). This archive can hold multiple files that are all equally used in the load test. Each file in the archive must satisfy the size constraints of the InvokeEndpoint call. - customer_metadata_properties: The metadata properties associated with the model package versions. - drift_check_baselines: Represents the drift check baselines that can be used when the model monitor is set using the model package. For more information, see the topic on Drift Detection against Previous Baselines in SageMaker Pipelines in the Amazon SageMaker Developer Guide. - additional_inference_specifications: An array of additional Inference Specification objects. Each additional Inference Specification specifies artifacts based on this model package that can be used on inference endpoints. Generally used with SageMaker Neo to store the compiled artifacts. - skip_model_validation: Indicates if you want to skip model validation. - source_uri: The URI of the source for the model package. If you want to clone a model package, set it to the model package Amazon Resource Name (ARN). If you want to register a model, set it to the model ARN. - security_config: The KMS Key ID (KMSKeyId) used for encryption of model package information. - model_card: The model card associated with the model package. Since ModelPackageModelCard is tied to a model package, it is a specific usage of a model card and its schema is simplified compared to the schema of ModelCard. The ModelPackageModelCard schema does not include model_package_details, and model_overview is composed of the model_creator and model_artifact properties. For more information about the model package model card schema, see Model package model card schema. For more information about the model card associated with the model package, see View the Details of a Model Version. - model_life_cycle: A structure describing the current state of the model in its life cycle. 
+ model_card_name: The unique name of the model card. + content: The content of the model card. Content must be in model card JSON schema and provided as a string. + model_card_status: The approval status of the model card within your organization. Different organizations might have different criteria for model card review and approval. Draft: The model card is a work in progress. PendingReview: The model card is pending review. Approved: The model card is approved. Archived: The model card is archived. No more updates should be made to the model card, but it can still be exported. + security_config: An optional Key Management Service key to encrypt, decrypt, and re-encrypt model card content for regulated workloads with highly sensitive data. + tags: Key-value pairs used to manage metadata for model cards. session: Boto3 session. region: Region name. Returns: - The ModelPackage resource. + The ModelCard resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -18781,39 +18485,21 @@ def create( S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating model_package resource.") + logger.info("Creating model_card resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "ModelPackageName": model_package_name, - "ModelPackageGroupName": model_package_group_name, - "ModelPackageDescription": model_package_description, - "InferenceSpecification": inference_specification, - "ValidationSpecification": validation_specification, - "SourceAlgorithmSpecification": source_algorithm_specification, - "CertifyForMarketplace": certify_for_marketplace, - "Tags": tags, - "ModelApprovalStatus": model_approval_status, - "MetadataProperties": metadata_properties, - "ModelMetrics": model_metrics, - "ClientToken": client_token, - "Domain": domain, - "Task": task, - "SamplePayloadUrl": sample_payload_url, - "CustomerMetadataProperties": customer_metadata_properties, - "DriftCheckBaselines": drift_check_baselines, - "AdditionalInferenceSpecifications": additional_inference_specifications, - "SkipModelValidation": skip_model_validation, - "SourceUri": source_uri, + "ModelCardName": model_card_name, "SecurityConfig": security_config, - "ModelCard": model_card, - "ModelLifeCycle": model_life_cycle, + "Content": content, + "ModelCardStatus": model_card_status, + "Tags": tags, } operation_input_args = Base.populate_chained_attributes( - resource_name="ModelPackage", operation_input_args=operation_input_args + resource_name="ModelCard", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -18822,31 +18508,31 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_model_package(**operation_input_args) + response = client.create_model_card(**operation_input_args) logger.debug(f"Response: {response}") - 
return cls.get( - model_package_name=response["ModelPackageName"], session=session, region=region - ) + return cls.get(model_card_name=model_card_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - model_package_name: str, + model_card_name: str, + model_card_version: Optional[int] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ModelPackage"]: + ) -> Optional["ModelCard"]: """ - Get a ModelPackage resource + Get a ModelCard resource Parameters: - model_package_name: The name or Amazon Resource Name (ARN) of the model package to describe. When you specify a name, the name must have 1 to 63 characters. Valid characters are a-z, A-Z, 0-9, and - (hyphen). - session: Boto3 session. + model_card_name: The name or Amazon Resource Name (ARN) of the model card to describe. + model_card_version: The version of the model card to describe. If a version is not provided, then the latest version of the model card is described. + session: Boto3 session. region: Region name. Returns: - The ModelPackage resource. + The ModelCard resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -18858,10 +18544,12 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "ModelPackageName": model_package_name, + "ModelCardName": model_card_name, + "ModelCardVersion": model_card_version, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -18870,24 +18558,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_model_package(**operation_input_args) + response = client.describe_model_card(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeModelPackageOutput") - model_package = cls(**transformed_response) - return model_package + transformed_response = transform(response, "DescribeModelCardResponse") + model_card = cls(**transformed_response) + return model_card @Base.add_validate_call def refresh( self, - ) -> Optional["ModelPackage"]: + ) -> Optional["ModelCard"]: """ - Refresh a ModelPackage resource + Refresh a ModelCard resource Returns: - The ModelPackage resource. + The ModelCard resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -18899,49 +18587,36 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "ModelPackageName": self.model_package_name, + "ModelCardName": self.model_card_name, + "ModelCardVersion": self.model_card_version, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_model_package(**operation_input_args) + response = client.describe_model_card(**operation_input_args) # deserialize response and update self - transform(response, "DescribeModelPackageOutput", self) + transform(response, "DescribeModelCardResponse", self) return self @populate_inputs_decorator @Base.add_validate_call def update( self, - model_approval_status: Optional[str] = Unassigned(), - approval_description: Optional[str] = Unassigned(), - customer_metadata_properties: Optional[Dict[str, str]] = Unassigned(), - customer_metadata_properties_to_remove: Optional[List[str]] = Unassigned(), - additional_inference_specifications_to_add: Optional[ - List[AdditionalInferenceSpecificationDefinition] - ] = Unassigned(), - inference_specification: Optional[InferenceSpecification] = Unassigned(), - source_uri: Optional[str] = Unassigned(), - model_card: Optional[ModelPackageModelCard] = Unassigned(), - model_life_cycle: Optional[ModelLifeCycle] = Unassigned(), - client_token: Optional[str] = Unassigned(), - ) -> Optional["ModelPackage"]: + content: Optional[str] = Unassigned(), + model_card_status: Optional[str] = Unassigned(), + ) -> Optional["ModelCard"]: """ - Update a ModelPackage resource - - Parameters: - customer_metadata_properties_to_remove: The metadata properties associated with the model package versions to remove. - additional_inference_specifications_to_add: An array of additional Inference Specification objects to be added to the existing array additional Inference Specification. Total number of additional Inference Specifications can not exceed 15. 
Each additional Inference Specification specifies artifacts based on this model package that can be used on inference endpoints. Generally used with SageMaker Neo to store the compiled artifacts. - client_token: A unique token that guarantees that the call to this API is idempotent. + Update a ModelCard resource Returns: - The ModelPackage resource. + The ModelCard resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -18954,23 +18629,17 @@ def update( error_code = e.response['Error']['Code'] ``` ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. """ - logger.info("Updating model_package resource.") + logger.info("Updating model_card resource.") client = Base.get_sagemaker_client() operation_input_args = { - "ModelPackageArn": self.model_package_arn, - "ModelApprovalStatus": model_approval_status, - "ApprovalDescription": approval_description, - "CustomerMetadataProperties": customer_metadata_properties, - "CustomerMetadataPropertiesToRemove": customer_metadata_properties_to_remove, - "AdditionalInferenceSpecificationsToAdd": additional_inference_specifications_to_add, - "InferenceSpecification": inference_specification, - "SourceUri": source_uri, - "ModelCard": model_card, - "ModelLifeCycle": model_life_cycle, - "ClientToken": client_token, + "ModelCardName": self.model_card_name, + "Content": content, + "ModelCardStatus": model_card_status, } logger.debug(f"Input request: {operation_input_args}") # serialize the input request @@ -18978,7 +18647,7 @@ def update( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.update_model_package(**operation_input_args) + response = 
client.update_model_card(**operation_input_args) logger.debug(f"Response: {response}") self.refresh() @@ -18989,7 +18658,7 @@ def delete( self, ) -> None: """ - Delete a ModelPackage resource + Delete a ModelCard resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -19002,30 +18671,31 @@ def delete( error_code = e.response['Error']['Code'] ``` ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + ResourceNotFound: Resource being access is not found. """ client = Base.get_sagemaker_client() operation_input_args = { - "ModelPackageName": self.model_package_name, + "ModelCardName": self.model_card_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_model_package(**operation_input_args) + client.delete_model_card(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call def wait_for_status( self, - target_status: Literal["Pending", "InProgress", "Completed", "Failed", "Deleting"], + target_status: Literal["Draft", "PendingReview", "Approved", "Archived"], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a ModelPackage resource to reach certain status. + Wait for a ModelCard resource to reach certain status. Parameters: target_status: The status to wait for. 
@@ -19044,7 +18714,7 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task(f"Waiting for ModelPackage to reach [bold]{target_status} status...") + progress.add_task(f"Waiting for ModelCard to reach [bold]{target_status} status...") status = Status("Current status:") with Live( @@ -19057,83 +18727,15 @@ def wait_for_status( ): while True: self.refresh() - current_status = self.model_package_status + current_status = self.model_card_status status.update(f"Current status: [bold]{current_status}") if target_status == current_status: logger.info(f"Final Resource Status: [bold]{current_status}") return - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="ModelPackage", status=current_status, reason="(Unknown)" - ) - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError(resouce_type="ModelPackage", status=current_status) - time.sleep(poll) - - @Base.add_validate_call - def wait_for_delete( - self, - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a ModelPackage resource to be deleted. - - Parameters: - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - DeleteFailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. 
- """ - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task("Waiting for ModelPackage to be deleted...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ) - ): - while True: - try: - self.refresh() - current_status = self.model_package_status - status.update(f"Current status: [bold]{current_status}") - - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="ModelPackage", status=current_status - ) - except botocore.exceptions.ClientError as e: - error_code = e.response["Error"]["Code"] - - if "ResourceNotFound" in error_code or "ValidationException" in error_code: - logger.info("Resource was not found. It may have been deleted.") - return - raise e + raise TimeoutExceededError(resouce_type="ModelCard", status=current_status) time.sleep(poll) @classmethod @@ -19143,33 +18745,29 @@ def get_all( creation_time_after: Optional[datetime.datetime] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), name_contains: Optional[str] = Unassigned(), - model_approval_status: Optional[str] = Unassigned(), - model_package_group_name: Optional[str] = Unassigned(), - model_package_type: Optional[str] = Unassigned(), + model_card_status: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["ModelPackage"]: + ) -> ResourceIterator["ModelCard"]: """ - Get all ModelPackage resources + Get all ModelCard resources Parameters: - creation_time_after: A filter that returns only model packages created after the specified time (timestamp). 
- creation_time_before: A filter that returns only model packages created before the specified time (timestamp). - max_results: The maximum number of model packages to return in the response. - name_contains: A string in the model package name. This filter returns only model packages whose name contains the specified string. - model_approval_status: A filter that returns only the model packages with the specified approval status. - model_package_group_name: A filter that returns only model versions that belong to the specified model group. - model_package_type: A filter that returns only the model packages of the specified type. This can be one of the following values. UNVERSIONED - List only unversioined models. This is the default value if no ModelPackageType is specified. VERSIONED - List only versioned models. BOTH - List both versioned and unversioned models. - next_token: If the response to a previous ListModelPackages request was truncated, the response includes a NextToken. To retrieve the next set of model packages, use the token in the next request. - sort_by: The parameter by which to sort the results. The default is CreationTime. - sort_order: The sort order for the results. The default is Ascending. + creation_time_after: Only list model cards that were created after the time specified. + creation_time_before: Only list model cards that were created before the time specified. + max_results: The maximum number of model cards to list. + name_contains: Only list model cards with names that contain the specified string. + model_card_status: Only list model cards with the specified approval status. + next_token: If the response to a previous ListModelCards request was truncated, the response includes a NextToken. To retrieve the next set of model cards, use the token in the next request. + sort_by: Sort model cards by either name or creation time. Sorts by creation time by default. + sort_order: Sort model cards by ascending or descending order. 
session: Boto3 session. region: Region name. Returns: - Iterator for listed ModelPackage resources. + Iterator for listed ModelCard resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -19191,9 +18789,7 @@ def get_all( "CreationTimeAfter": creation_time_after, "CreationTimeBefore": creation_time_before, "NameContains": name_contains, - "ModelApprovalStatus": model_approval_status, - "ModelPackageGroupName": model_package_group_name, - "ModelPackageType": model_package_type, + "ModelCardStatus": model_card_status, "SortBy": sort_by, "SortOrder": sort_order, } @@ -19204,30 +18800,38 @@ def get_all( return ResourceIterator( client=client, - list_method="list_model_packages", - summaries_key="ModelPackageSummaryList", - summary_name="ModelPackageSummary", - resource_cls=ModelPackage, + list_method="list_model_cards", + summaries_key="ModelCardSummaries", + summary_name="ModelCardSummary", + resource_cls=ModelCard, list_method_kwargs=operation_input_args, ) @Base.add_validate_call - def batch_get( + def get_all_versions( self, - model_package_arn_list: List[str], + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional[BatchDescribeModelPackageOutput]: + ) -> ResourceIterator[ModelCardVersionSummary]: """ - This action batch describes a list of versioned model packages. + List existing versions of an Amazon SageMaker Model Card. Parameters: - model_package_arn_list: The list of Amazon Resource Name (ARN) of the model package groups. + creation_time_after: Only list model card versions that were created after the time specified. + creation_time_before: Only list model card versions that were created before the time specified. 
+ max_results: The maximum number of model card versions to list. + next_token: If the response to a previous ListModelCardVersions request was truncated, the response includes a NextToken. To retrieve the next set of model card versions, use the token in the next request. + sort_by: Sort listed model card versions by version. Sorts by version by default. + sort_order: Sort model card versions by ascending or descending order. session: Boto3 session. region: Region name. Returns: - BatchDescribeModelPackageOutput + Iterator for listed ModelCardVersionSummary. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -19239,10 +18843,16 @@ def batch_get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "ModelPackageArnList": model_package_arn_list, + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "ModelCardName": self.model_card_name, + "ModelCardStatus": self.model_card_status, + "SortBy": sort_by, + "SortOrder": sort_order, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -19252,38 +18862,48 @@ def batch_get( session=session, region_name=region, service_name="sagemaker" ) - logger.debug(f"Calling batch_describe_model_package API") - response = client.batch_describe_model_package(**operation_input_args) - logger.debug(f"Response: {response}") - - transformed_response = transform(response, "BatchDescribeModelPackageOutput") - return BatchDescribeModelPackageOutput(**transformed_response) + return ResourceIterator( + client=client, + list_method="list_model_card_versions", + summaries_key="ModelCardVersionSummaryList", + summary_name="ModelCardVersionSummary", + resource_cls=ModelCardVersionSummary, + list_method_kwargs=operation_input_args, + ) -class ModelPackageGroup(Base): +class ModelCardExportJob(Base): """ - 
Class representing resource ModelPackageGroup + Class representing resource ModelCardExportJob Attributes: - model_package_group_name: The name of the model group. - model_package_group_arn: The Amazon Resource Name (ARN) of the model group. - creation_time: The time that the model group was created. - created_by: - model_package_group_status: The status of the model group. - model_package_group_description: A description of the model group. - - """ - - model_package_group_name: str - model_package_group_arn: Optional[str] = Unassigned() - model_package_group_description: Optional[str] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() - created_by: Optional[UserContext] = Unassigned() - model_package_group_status: Optional[str] = Unassigned() + model_card_export_job_name: The name of the model card export job to describe. + model_card_export_job_arn: The Amazon Resource Name (ARN) of the model card export job. + status: The completion status of the model card export job. InProgress: The model card export job is in progress. Completed: The model card export job is complete. Failed: The model card export job failed. To see the reason for the failure, see the FailureReason field in the response to a DescribeModelCardExportJob call. + model_card_name: The name or Amazon Resource Name (ARN) of the model card that the model export job exports. + model_card_version: The version of the model card that the model export job exports. + output_config: The export output details for the model card. + created_at: The date and time that the model export job was created. + last_modified_at: The date and time that the model export job was last modified. + failure_reason: The failure reason if the model export job fails. + export_artifacts: The exported model card artifacts. 
+ + """ + + model_card_export_job_arn: str + model_card_export_job_name: Optional[str] = Unassigned() + status: Optional[str] = Unassigned() + model_card_name: Optional[str] = Unassigned() + model_card_version: Optional[int] = Unassigned() + output_config: Optional[ModelCardExportOutputConfig] = Unassigned() + created_at: Optional[datetime.datetime] = Unassigned() + last_modified_at: Optional[datetime.datetime] = Unassigned() + failure_reason: Optional[str] = Unassigned() + export_artifacts: Optional[ModelCardExportArtifacts] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "model_package_group_name" + resource_name = "model_card_export_job_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -19294,31 +18914,50 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object model_package_group") + logger.error("Name attribute not found for object model_card_export_job") return None + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "output_config": {"s3_output_path": {"type": "string"}}, + "export_artifacts": {"s3_export_artifacts": {"type": "string"}}, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "ModelCardExportJob", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call def create( cls, - model_package_group_name: str, - model_package_group_description: Optional[str] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), + model_card_name: Union[str, object], + model_card_export_job_name: str, + output_config: ModelCardExportOutputConfig, + model_card_version: Optional[int] = Unassigned(), session: Optional[Session] = None, region: 
Optional[str] = None, - ) -> Optional["ModelPackageGroup"]: + ) -> Optional["ModelCardExportJob"]: """ - Create a ModelPackageGroup resource + Create a ModelCardExportJob resource Parameters: - model_package_group_name: The name of the model group. - model_package_group_description: A description for the model group. - tags: A list of key value pairs associated with the model group. For more information, see Tagging Amazon Web Services resources in the Amazon Web Services General Reference Guide. + model_card_name: The name or Amazon Resource Name (ARN) of the model card to export. + model_card_export_job_name: The name of the model card export job. + output_config: The model card output configuration that specifies the Amazon S3 path for exporting. + model_card_version: The version of the model card to export. If a version is not provided, then the latest version of the model card is exported. session: Boto3 session. region: Region name. Returns: - The ModelPackageGroup resource. + The ModelCardExportJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -19330,25 +18969,28 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. 
ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating model_package_group resource.") + logger.info("Creating model_card_export_job resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "ModelPackageGroupName": model_package_group_name, - "ModelPackageGroupDescription": model_package_group_description, - "Tags": tags, + "ModelCardName": model_card_name, + "ModelCardVersion": model_card_version, + "ModelCardExportJobName": model_card_export_job_name, + "OutputConfig": output_config, } operation_input_args = Base.populate_chained_attributes( - resource_name="ModelPackageGroup", operation_input_args=operation_input_args + resource_name="ModelCardExportJob", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -19357,31 +18999,33 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_model_package_group(**operation_input_args) + response = client.create_model_card_export_job(**operation_input_args) logger.debug(f"Response: {response}") return cls.get( - model_package_group_name=model_package_group_name, session=session, region=region + model_card_export_job_arn=response["ModelCardExportJobArn"], + session=session, + region=region, ) @classmethod @Base.add_validate_call def get( cls, - model_package_group_name: str, + model_card_export_job_arn: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ModelPackageGroup"]: + ) -> Optional["ModelCardExportJob"]: """ - Get a ModelPackageGroup resource + Get a ModelCardExportJob resource Parameters: - model_package_group_name: The name of the model group to 
describe. + model_card_export_job_arn: The Amazon Resource Name (ARN) of the model card export job to describe. session: Boto3 session. region: Region name. Returns: - The ModelPackageGroup resource. + The ModelCardExportJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -19393,10 +19037,11 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "ModelPackageGroupName": model_package_group_name, + "ModelCardExportJobArn": model_card_export_job_arn, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -19405,24 +19050,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_model_package_group(**operation_input_args) + response = client.describe_model_card_export_job(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeModelPackageGroupOutput") - model_package_group = cls(**transformed_response) - return model_package_group + transformed_response = transform(response, "DescribeModelCardExportJobResponse") + model_card_export_job = cls(**transformed_response) + return model_card_export_job @Base.add_validate_call def refresh( self, - ) -> Optional["ModelPackageGroup"]: + ) -> Optional["ModelCardExportJob"]: """ - Refresh a ModelPackageGroup resource + Refresh a ModelCardExportJob resource Returns: - The ModelPackageGroup resource. + The ModelCardExportJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -19434,69 +19079,33 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "ModelPackageGroupName": self.model_package_group_name, + "ModelCardExportJobArn": self.model_card_export_job_arn, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_model_package_group(**operation_input_args) + response = client.describe_model_card_export_job(**operation_input_args) # deserialize response and update self - transform(response, "DescribeModelPackageGroupOutput", self) + transform(response, "DescribeModelCardExportJobResponse", self) return self @Base.add_validate_call - def delete( - self, - ) -> None: - """ - Delete a ModelPackageGroup resource - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. - """ - - client = Base.get_sagemaker_client() - - operation_input_args = { - "ModelPackageGroupName": self.model_package_group_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client.delete_model_package_group(**operation_input_args) - - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") - - @Base.add_validate_call - def wait_for_status( + def wait( self, - target_status: Literal[ - "Pending", "InProgress", "Completed", "Failed", "Deleting", "DeleteFailed" - ], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a ModelPackageGroup resource to reach certain status. 
+ Wait for a ModelCardExportJob resource. Parameters: - target_status: The status to wait for. poll: The number of seconds to wait between each poll. timeout: The maximum number of seconds to wait before timing out. @@ -19504,7 +19113,9 @@ def wait_for_status( TimeoutExceededError: If the resource does not reach a terminal state before the timeout. FailedStatusError: If the resource reaches a failed state. WaiterError: Raised when an error occurs while waiting. + """ + terminal_states = ["Completed", "Failed"] start_time = time.time() progress = Progress( @@ -19512,7 +19123,7 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task(f"Waiting for ModelPackageGroup to reach [bold]{target_status} status...") + progress.add_task("Waiting for ModelCardExportJob...") status = Status("Current status:") with Live( @@ -19525,36 +19136,61 @@ def wait_for_status( ): while True: self.refresh() - current_status = self.model_package_group_status + current_status = self.status status.update(f"Current status: [bold]{current_status}") - if target_status == current_status: + if current_status in terminal_states: logger.info(f"Final Resource Status: [bold]{current_status}") - return - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="ModelPackageGroup", status=current_status, reason="(Unknown)" - ) + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="ModelCardExportJob", + status=current_status, + reason=self.failure_reason, + ) + + return if timeout is not None and time.time() - start_time >= timeout: raise TimeoutExceededError( - resouce_type="ModelPackageGroup", status=current_status + resouce_type="ModelCardExportJob", status=current_status ) time.sleep(poll) + @classmethod @Base.add_validate_call - def wait_for_delete( - self, - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: + def get_all( + cls, + model_card_name: str, + model_card_version: Optional[int] = 
Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + model_card_export_job_name_contains: Optional[str] = Unassigned(), + status_equals: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["ModelCardExportJob"]: """ - Wait for a ModelPackageGroup resource to be deleted. + Get all ModelCardExportJob resources Parameters: - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. + model_card_name: List export jobs for the model card with the specified name. + model_card_version: List export jobs for the model card with the specified version. + creation_time_after: Only list model card export jobs that were created after the time specified. + creation_time_before: Only list model card export jobs that were created before the time specified. + model_card_export_job_name_contains: Only list model card export jobs with names that contain the specified string. + status_equals: Only list model card export jobs with the specified status. + sort_by: Sort model card export jobs by either name or creation time. Sorts by creation time by default. + sort_order: Sort model card export jobs by ascending or descending order. + next_token: If the response to a previous ListModelCardExportJobs request was truncated, the response includes a NextToken. To retrieve the next set of model card export jobs, use the token in the next request. + max_results: The maximum number of model card export jobs to list. + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed ModelCardExportJob resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -19566,76 +19202,162 @@ def wait_for_delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - DeleteFailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. """ - start_time = time.time() - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" ) - progress.add_task("Waiting for ModelPackageGroup to be deleted...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ) - ): - while True: - try: - self.refresh() - current_status = self.model_package_group_status - status.update(f"Current status: [bold]{current_status}") - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="ModelPackageGroup", status=current_status - ) - except botocore.exceptions.ClientError as e: - error_code = e.response["Error"]["Code"] + operation_input_args = { + "ModelCardName": model_card_name, + "ModelCardVersion": model_card_version, + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "ModelCardExportJobNameContains": model_card_export_job_name_contains, + "StatusEquals": status_equals, + "SortBy": sort_by, + "SortOrder": sort_order, + } - if "ResourceNotFound" in error_code or "ValidationException" in error_code: - logger.info("Resource was not found. 
It may have been deleted.") - return - raise e - time.sleep(poll) + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + return ResourceIterator( + client=client, + list_method="list_model_card_export_jobs", + summaries_key="ModelCardExportJobSummaries", + summary_name="ModelCardExportJobSummary", + resource_cls=ModelCardExportJob, + list_method_kwargs=operation_input_args, + ) + + +class ModelExplainabilityJobDefinition(Base): + """ + Class representing resource ModelExplainabilityJobDefinition + + Attributes: + job_definition_arn: The Amazon Resource Name (ARN) of the model explainability job. + job_definition_name: The name of the explainability job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + creation_time: The time at which the model explainability job was created. + model_explainability_app_specification: Configures the model explainability job to run a specified Docker container image. + model_explainability_job_input: Inputs for the model explainability job. + model_explainability_job_output_config: + job_resources: + role_arn: The Amazon Resource Name (ARN) of the IAM role that has read permission to the input data location and write permission to the output data location in Amazon S3. + model_explainability_baseline_config: The baseline configuration for a model explainability job. + network_config: Networking options for a model explainability job. 
+ stopping_condition: + + """ + + job_definition_name: str + job_definition_arn: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + model_explainability_baseline_config: Optional[ModelExplainabilityBaselineConfig] = Unassigned() + model_explainability_app_specification: Optional[ModelExplainabilityAppSpecification] = ( + Unassigned() + ) + model_explainability_job_input: Optional[ModelExplainabilityJobInput] = Unassigned() + model_explainability_job_output_config: Optional[MonitoringOutputConfig] = Unassigned() + job_resources: Optional[MonitoringResources] = Unassigned() + network_config: Optional[MonitoringNetworkConfig] = Unassigned() + role_arn: Optional[str] = Unassigned() + stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "model_explainability_job_definition_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object model_explainability_job_definition") + return None + + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "model_explainability_job_input": { + "endpoint_input": { + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + "batch_transform_input": { + "data_captured_destination_s3_uri": {"type": "string"}, + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + }, + "model_explainability_job_output_config": {"kms_key_id": {"type": "string"}}, + "job_resources": {"cluster_config": {"volume_kms_key_id": 
{"type": "string"}}}, + "role_arn": {"type": "string"}, + "model_explainability_baseline_config": { + "constraints_resource": {"s3_uri": {"type": "string"}} + }, + "network_config": { + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + } + }, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "ModelExplainabilityJobDefinition", **kwargs + ), + ) + + return wrapper @classmethod + @populate_inputs_decorator @Base.add_validate_call - def get_all( + def create( cls, - creation_time_after: Optional[datetime.datetime] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - cross_account_filter_option: Optional[str] = Unassigned(), + job_definition_name: str, + model_explainability_app_specification: ModelExplainabilityAppSpecification, + model_explainability_job_input: ModelExplainabilityJobInput, + model_explainability_job_output_config: MonitoringOutputConfig, + job_resources: MonitoringResources, + role_arn: str, + model_explainability_baseline_config: Optional[ + ModelExplainabilityBaselineConfig + ] = Unassigned(), + network_config: Optional[MonitoringNetworkConfig] = Unassigned(), + stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["ModelPackageGroup"]: + ) -> Optional["ModelExplainabilityJobDefinition"]: """ - Get all ModelPackageGroup resources + Create a ModelExplainabilityJobDefinition resource Parameters: - creation_time_after: A filter that returns only model groups created after the specified time. 
- creation_time_before: A filter that returns only model groups created before the specified time. - max_results: The maximum number of results to return in the response. - name_contains: A string in the model group name. This filter returns only model groups whose name contains the specified string. - next_token: If the result of the previous ListModelPackageGroups request was truncated, the response includes a NextToken. To retrieve the next set of model groups, use the token in the next request. - sort_by: The field to sort results by. The default is CreationTime. - sort_order: The sort order for results. The default is Ascending. - cross_account_filter_option: A filter that returns either model groups shared with you or model groups in your own account. When the value is CrossAccount, the results show the resources made discoverable to you from other accounts. When the value is SameAccount or null, the results show resources from your account. The default is SameAccount. + job_definition_name: The name of the model explainability job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + model_explainability_app_specification: Configures the model explainability job to run a specified Docker container image. + model_explainability_job_input: Inputs for the model explainability job. + model_explainability_job_output_config: + job_resources: + role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. + model_explainability_baseline_config: The baseline configuration for a model explainability job. + network_config: Networking options for a model explainability job. + stopping_condition: + tags: (Optional) An array of key-value pairs. For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide. session: Boto3 session. region: Region name. 
Returns: - Iterator for listed ModelPackageGroup resources. + The ModelExplainabilityJobDefinition resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -19647,49 +19369,65 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ + logger.info("Creating model_explainability_job_definition resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "CreationTimeAfter": creation_time_after, - "CreationTimeBefore": creation_time_before, - "NameContains": name_contains, - "SortBy": sort_by, - "SortOrder": sort_order, - "CrossAccountFilterOption": cross_account_filter_option, + "JobDefinitionName": job_definition_name, + "ModelExplainabilityBaselineConfig": model_explainability_baseline_config, + "ModelExplainabilityAppSpecification": model_explainability_app_specification, + "ModelExplainabilityJobInput": model_explainability_job_input, + "ModelExplainabilityJobOutputConfig": model_explainability_job_output_config, + "JobResources": job_resources, + "NetworkConfig": network_config, + "RoleArn": role_arn, + "StoppingCondition": stopping_condition, + "Tags": tags, } + operation_input_args = Base.populate_chained_attributes( + resource_name="ModelExplainabilityJobDefinition", + operation_input_args=operation_input_args, + ) + + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = 
serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - return ResourceIterator( - client=client, - list_method="list_model_package_groups", - summaries_key="ModelPackageGroupSummaryList", - summary_name="ModelPackageGroupSummary", - resource_cls=ModelPackageGroup, - list_method_kwargs=operation_input_args, - ) + # create the resource + response = client.create_model_explainability_job_definition(**operation_input_args) + logger.debug(f"Response: {response}") + + return cls.get(job_definition_name=job_definition_name, session=session, region=region) + @classmethod @Base.add_validate_call - def get_policy( - self, + def get( + cls, + job_definition_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional[str]: + ) -> Optional["ModelExplainabilityJobDefinition"]: """ - Gets a resource policy that manages access for a model group. + Get a ModelExplainabilityJobDefinition resource Parameters: + job_definition_name: The name of the model explainability job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. session: Boto3 session. region: Region name. Returns: - str + The ModelExplainabilityJobDefinition resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -19701,10 +19439,11 @@ def get_policy( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "ModelPackageGroupName": self.model_package_group_name, + "JobDefinitionName": job_definition_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -19713,25 +19452,26 @@ def get_policy( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) + response = client.describe_model_explainability_job_definition(**operation_input_args) - logger.debug(f"Calling get_model_package_group_policy API") - response = client.get_model_package_group_policy(**operation_input_args) - logger.debug(f"Response: {response}") + logger.debug(response) - return list(response.values())[0] + # deserialize the response + transformed_response = transform( + response, "DescribeModelExplainabilityJobDefinitionResponse" + ) + model_explainability_job_definition = cls(**transformed_response) + return model_explainability_job_definition @Base.add_validate_call - def delete_policy( + def refresh( self, - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> None: + ) -> Optional["ModelExplainabilityJobDefinition"]: """ - Deletes a model group resource policy. + Refresh a ModelExplainabilityJobDefinition resource - Parameters: - session: Boto3 session. - region: Region name. + Returns: + The ModelExplainabilityJobDefinition resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -19743,38 +19483,87 @@ def delete_policy( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
""" operation_input_args = { - "ModelPackageGroupName": self.model_package_group_name, + "JobDefinitionName": self.job_definition_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) + client = Base.get_sagemaker_client() + response = client.describe_model_explainability_job_definition(**operation_input_args) - logger.debug(f"Calling delete_model_package_group_policy API") - response = client.delete_model_package_group_policy(**operation_input_args) - logger.debug(f"Response: {response}") + # deserialize response and update self + transform(response, "DescribeModelExplainabilityJobDefinitionResponse", self) + return self @Base.add_validate_call - def put_policy( + def delete( self, - resource_policy: str, + ) -> None: + """ + Delete a ModelExplainabilityJobDefinition resource + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. 
+ """ + + client = Base.get_sagemaker_client() + + operation_input_args = { + "JobDefinitionName": self.job_definition_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_model_explainability_job_definition(**operation_input_args) + + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + + @classmethod + @Base.add_validate_call + def get_all( + cls, + endpoint_name: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> None: + ) -> ResourceIterator["ModelExplainabilityJobDefinition"]: """ - Adds a resouce policy to control access to a model group. + Get all ModelExplainabilityJobDefinition resources Parameters: - resource_policy: The resource policy for the model group. + endpoint_name: Name of the endpoint to monitor for model explainability. + sort_by: Whether to sort results by the Name or CreationTime field. The default is CreationTime. + sort_order: Whether to sort the results in Ascending or Descending order. The default is Descending. + next_token: The token returned if the response is truncated. To retrieve the next set of job executions, use it in the next request. + max_results: The maximum number of jobs to return in the response. The default value is 10. + name_contains: Filter for model explainability jobs whose name contains a specified string. + creation_time_before: A filter that returns only model explainability jobs created before a specified time. + creation_time_after: A filter that returns only model explainability jobs created after a specified time. session: Boto3 session. 
region: Region name. + Returns: + Iterator for listed ModelExplainabilityJobDefinition resources. + Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. The error message and error code can be parsed from the exception as follows: @@ -19785,60 +19574,113 @@ def put_policy( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. """ + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + operation_input_args = { - "ModelPackageGroupName": self.model_package_group_name, - "ResourcePolicy": resource_policy, + "EndpointName": endpoint_name, + "SortBy": sort_by, + "SortOrder": sort_order, + "NameContains": name_contains, + "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, + } + custom_key_mapping = { + "monitoring_job_definition_name": "job_definition_name", + "monitoring_job_definition_arn": "job_definition_arn", } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" + return ResourceIterator( + client=client, + list_method="list_model_explainability_job_definitions", + summaries_key="JobDefinitionSummaries", + summary_name="MonitoringJobDefinitionSummary", + resource_cls=ModelExplainabilityJobDefinition, + custom_key_mapping=custom_key_mapping, + list_method_kwargs=operation_input_args, ) - logger.debug(f"Calling put_model_package_group_policy API") - response = client.put_model_package_group_policy(**operation_input_args) - logger.debug(f"Response: {response}") - -class ModelQualityJobDefinition(Base): +class ModelPackage(Base): """ - Class representing resource 
ModelQualityJobDefinition + Class representing resource ModelPackage Attributes: - job_definition_arn: The Amazon Resource Name (ARN) of the model quality job. - job_definition_name: The name of the quality job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. - creation_time: The time at which the model quality job was created. - model_quality_app_specification: Configures the model quality job to run a specified Docker container image. - model_quality_job_input: Inputs for the model quality job. - model_quality_job_output_config: - job_resources: - role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. - model_quality_baseline_config: The baseline configuration for a model quality job. - network_config: Networking options for a model quality job. - stopping_condition: + model_package_name: The name of the model package being described. + model_package_arn: The Amazon Resource Name (ARN) of the model package. + creation_time: A timestamp specifying when the model package was created. + model_package_status: The current status of the model package. + model_package_status_details: Details about the current status of the model package. + model_package_group_name: If the model is a versioned model, the name of the model group that the versioned model belongs to. + model_package_version: The version of the model package. + model_package_description: A brief summary of the model package. + inference_specification: Details about inference jobs that you can run with models based on this model package. + source_algorithm_specification: Details about the algorithm that was used to create the model package. + validation_specification: Configurations for one or more transform jobs that SageMaker runs to test the model package. + certify_for_marketplace: Whether the model package is certified for listing on Amazon Web Services Marketplace. 
+ model_approval_status: The approval status of the model package. + created_by: + metadata_properties: + model_metrics: Metrics for the model. + last_modified_time: The last time that the model package was modified. + last_modified_by: + approval_description: A description provided for the model approval. + domain: The machine learning domain of the model package you specified. Common machine learning domains include computer vision and natural language processing. + task: The machine learning task you specified that your model package accomplishes. Common machine learning tasks include object detection and image classification. + sample_payload_url: The Amazon Simple Storage Service (Amazon S3) path where the sample payload are stored. This path points to a single gzip compressed tar archive (.tar.gz suffix). + customer_metadata_properties: The metadata properties associated with the model package versions. + drift_check_baselines: Represents the drift check baselines that can be used when the model monitor is set using the model package. For more information, see the topic on Drift Detection against Previous Baselines in SageMaker Pipelines in the Amazon SageMaker Developer Guide. + additional_inference_specifications: An array of additional Inference Specification objects. Each additional Inference Specification specifies artifacts based on this model package that can be used on inference endpoints. Generally used with SageMaker Neo to store the compiled artifacts. + skip_model_validation: Indicates if you want to skip model validation. + source_uri: The URI of the source for the model package. + security_config: The KMS Key ID (KMSKeyId) used for encryption of model package information. + model_card: The model card associated with the model package. Since ModelPackageModelCard is tied to a model package, it is a specific usage of a model card and its schema is simplified compared to the schema of ModelCard. 
The ModelPackageModelCard schema does not include model_package_details, and model_overview is composed of the model_creator and model_artifact properties. For more information about the model package model card schema, see Model package model card schema. For more information about the model card associated with the model package, see View the Details of a Model Version. + model_life_cycle: A structure describing the current state of the model in its life cycle. """ - job_definition_name: str - job_definition_arn: Optional[str] = Unassigned() + model_package_name: str + model_package_group_name: Optional[str] = Unassigned() + model_package_version: Optional[int] = Unassigned() + model_package_arn: Optional[str] = Unassigned() + model_package_description: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - model_quality_baseline_config: Optional[ModelQualityBaselineConfig] = Unassigned() - model_quality_app_specification: Optional[ModelQualityAppSpecification] = Unassigned() - model_quality_job_input: Optional[ModelQualityJobInput] = Unassigned() - model_quality_job_output_config: Optional[MonitoringOutputConfig] = Unassigned() - job_resources: Optional[MonitoringResources] = Unassigned() - network_config: Optional[MonitoringNetworkConfig] = Unassigned() - role_arn: Optional[str] = Unassigned() - stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned() + inference_specification: Optional[InferenceSpecification] = Unassigned() + source_algorithm_specification: Optional[SourceAlgorithmSpecification] = Unassigned() + validation_specification: Optional[ModelPackageValidationSpecification] = Unassigned() + model_package_status: Optional[str] = Unassigned() + model_package_status_details: Optional[ModelPackageStatusDetails] = Unassigned() + certify_for_marketplace: Optional[bool] = Unassigned() + model_approval_status: Optional[str] = Unassigned() + created_by: Optional[UserContext] = Unassigned() + 
metadata_properties: Optional[MetadataProperties] = Unassigned() + model_metrics: Optional[ModelMetrics] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + last_modified_by: Optional[UserContext] = Unassigned() + approval_description: Optional[str] = Unassigned() + domain: Optional[str] = Unassigned() + task: Optional[str] = Unassigned() + sample_payload_url: Optional[str] = Unassigned() + customer_metadata_properties: Optional[Dict[str, str]] = Unassigned() + drift_check_baselines: Optional[DriftCheckBaselines] = Unassigned() + additional_inference_specifications: Optional[ + List[AdditionalInferenceSpecificationDefinition] + ] = Unassigned() + skip_model_validation: Optional[str] = Unassigned() + source_uri: Optional[str] = Unassigned() + security_config: Optional[ModelPackageSecurityConfig] = Unassigned() + model_card: Optional[ModelPackageModelCard] = Unassigned() + model_life_cycle: Optional[ModelLifeCycle] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "model_quality_job_definition_name" + resource_name = "model_package_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -19849,42 +19691,55 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object model_quality_job_definition") + logger.error("Name attribute not found for object model_package") return None def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): config_schema_for_resource = { - "model_quality_job_input": { - "ground_truth_s3_input": {"s3_uri": {"type": "string"}}, - "endpoint_input": { - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "validation_specification": {"validation_role": {"type": "string"}}, + "model_metrics": { + "model_quality": { + "statistics": 
{"s3_uri": {"type": "string"}}, + "constraints": {"s3_uri": {"type": "string"}}, }, - "batch_transform_input": { - "data_captured_destination_s3_uri": {"type": "string"}, - "s3_input_mode": {"type": "string"}, - "s3_data_distribution_type": {"type": "string"}, + "model_data_quality": { + "statistics": {"s3_uri": {"type": "string"}}, + "constraints": {"s3_uri": {"type": "string"}}, }, + "bias": { + "report": {"s3_uri": {"type": "string"}}, + "pre_training_report": {"s3_uri": {"type": "string"}}, + "post_training_report": {"s3_uri": {"type": "string"}}, + }, + "explainability": {"report": {"s3_uri": {"type": "string"}}}, }, - "model_quality_job_output_config": {"kms_key_id": {"type": "string"}}, - "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, - "role_arn": {"type": "string"}, - "model_quality_baseline_config": { - "constraints_resource": {"s3_uri": {"type": "string"}} - }, - "network_config": { - "vpc_config": { - "security_group_ids": {"type": "array", "items": {"type": "string"}}, - "subnets": {"type": "array", "items": {"type": "string"}}, - } + "drift_check_baselines": { + "bias": { + "config_file": {"s3_uri": {"type": "string"}}, + "pre_training_constraints": {"s3_uri": {"type": "string"}}, + "post_training_constraints": {"s3_uri": {"type": "string"}}, + }, + "explainability": { + "constraints": {"s3_uri": {"type": "string"}}, + "config_file": {"s3_uri": {"type": "string"}}, + }, + "model_quality": { + "statistics": {"s3_uri": {"type": "string"}}, + "constraints": {"s3_uri": {"type": "string"}}, + }, + "model_data_quality": { + "statistics": {"s3_uri": {"type": "string"}}, + "constraints": {"s3_uri": {"type": "string"}}, + }, }, + "security_config": {"kms_key_id": {"type": "string"}}, } return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "ModelQualityJobDefinition", **kwargs + config_schema_for_resource, "ModelPackage", **kwargs ), ) @@ -19895,38 +19750,66 @@ def 
wrapper(*args, **kwargs): @Base.add_validate_call def create( cls, - job_definition_name: str, - model_quality_app_specification: ModelQualityAppSpecification, - model_quality_job_input: ModelQualityJobInput, - model_quality_job_output_config: MonitoringOutputConfig, - job_resources: MonitoringResources, - role_arn: str, - model_quality_baseline_config: Optional[ModelQualityBaselineConfig] = Unassigned(), - network_config: Optional[MonitoringNetworkConfig] = Unassigned(), - stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), + model_package_name: Optional[str] = Unassigned(), + model_package_group_name: Optional[Union[str, object]] = Unassigned(), + model_package_description: Optional[str] = Unassigned(), + inference_specification: Optional[InferenceSpecification] = Unassigned(), + validation_specification: Optional[ModelPackageValidationSpecification] = Unassigned(), + source_algorithm_specification: Optional[SourceAlgorithmSpecification] = Unassigned(), + certify_for_marketplace: Optional[bool] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional["ModelQualityJobDefinition"]: - """ - Create a ModelQualityJobDefinition resource - + model_approval_status: Optional[str] = Unassigned(), + metadata_properties: Optional[MetadataProperties] = Unassigned(), + model_metrics: Optional[ModelMetrics] = Unassigned(), + client_token: Optional[str] = Unassigned(), + domain: Optional[str] = Unassigned(), + task: Optional[str] = Unassigned(), + sample_payload_url: Optional[str] = Unassigned(), + customer_metadata_properties: Optional[Dict[str, str]] = Unassigned(), + drift_check_baselines: Optional[DriftCheckBaselines] = Unassigned(), + additional_inference_specifications: Optional[ + List[AdditionalInferenceSpecificationDefinition] + ] = Unassigned(), + skip_model_validation: Optional[str] = Unassigned(), + source_uri: Optional[str] = Unassigned(), + security_config: 
Optional[ModelPackageSecurityConfig] = Unassigned(), + model_card: Optional[ModelPackageModelCard] = Unassigned(), + model_life_cycle: Optional[ModelLifeCycle] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["ModelPackage"]: + """ + Create a ModelPackage resource + Parameters: - job_definition_name: The name of the monitoring job definition. - model_quality_app_specification: The container that runs the monitoring job. - model_quality_job_input: A list of the inputs that are monitored. Currently endpoints are supported. - model_quality_job_output_config: - job_resources: - role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. - model_quality_baseline_config: Specifies the constraints and baselines for the monitoring job. - network_config: Specifies the network configuration for the monitoring job. - stopping_condition: - tags: (Optional) An array of key-value pairs. For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide. + model_package_name: The name of the model package. The name must have 1 to 63 characters. Valid characters are a-z, A-Z, 0-9, and - (hyphen). This parameter is required for unversioned models. It is not applicable to versioned models. + model_package_group_name: The name or Amazon Resource Name (ARN) of the model package group that this model version belongs to. This parameter is required for versioned models, and does not apply to unversioned models. + model_package_description: A description of the model package. + inference_specification: Specifies details about inference jobs that you can run with models based on this model package, including the following information: The Amazon ECR paths of containers that contain the inference code and model artifacts. 
The instance types that the model package supports for transform jobs and real-time endpoints used for inference. The input and output content formats that the model package supports for inference. + validation_specification: Specifies configurations for one or more transform jobs that SageMaker runs to test the model package. + source_algorithm_specification: Details about the algorithm that was used to create the model package. + certify_for_marketplace: Whether to certify the model package for listing on Amazon Web Services Marketplace. This parameter is optional for unversioned models, and does not apply to versioned models. + tags: A list of key value pairs associated with the model. For more information, see Tagging Amazon Web Services resources in the Amazon Web Services General Reference Guide. If you supply ModelPackageGroupName, your model package belongs to the model group you specify and uses the tags associated with the model group. In this case, you cannot supply a tag argument. + model_approval_status: Whether the model is approved for deployment. This parameter is optional for versioned models, and does not apply to unversioned models. For versioned models, the value of this parameter must be set to Approved to deploy the model. + metadata_properties: + model_metrics: A structure that contains model metrics reports. + client_token: A unique token that guarantees that the call to this API is idempotent. + domain: The machine learning domain of your model package and its components. Common machine learning domains include computer vision and natural language processing. + task: The machine learning task your model package accomplishes. Common machine learning tasks include object detection and image classification. The following tasks are supported by Inference Recommender: "IMAGE_CLASSIFICATION" \| "OBJECT_DETECTION" \| "TEXT_GENERATION" \|"IMAGE_SEGMENTATION" \| "FILL_MASK" \| "CLASSIFICATION" \| "REGRESSION" \| "OTHER". 
Specify "OTHER" if none of the tasks listed fit your use case. + sample_payload_url: The Amazon Simple Storage Service (Amazon S3) path where the sample payload is stored. This path must point to a single gzip compressed tar archive (.tar.gz suffix). This archive can hold multiple files that are all equally used in the load test. Each file in the archive must satisfy the size constraints of the InvokeEndpoint call. + customer_metadata_properties: The metadata properties associated with the model package versions. + drift_check_baselines: Represents the drift check baselines that can be used when the model monitor is set using the model package. For more information, see the topic on Drift Detection against Previous Baselines in SageMaker Pipelines in the Amazon SageMaker Developer Guide. + additional_inference_specifications: An array of additional Inference Specification objects. Each additional Inference Specification specifies artifacts based on this model package that can be used on inference endpoints. Generally used with SageMaker Neo to store the compiled artifacts. + skip_model_validation: Indicates if you want to skip model validation. + source_uri: The URI of the source for the model package. If you want to clone a model package, set it to the model package Amazon Resource Name (ARN). If you want to register a model, set it to the model ARN. + security_config: The KMS Key ID (KMSKeyId) used for encryption of model package information. + model_card: The model card associated with the model package. Since ModelPackageModelCard is tied to a model package, it is a specific usage of a model card and its schema is simplified compared to the schema of ModelCard. The ModelPackageModelCard schema does not include model_package_details, and model_overview is composed of the model_creator and model_artifact properties. For more information about the model package model card schema, see Model package model card schema. 
For more information about the model card associated with the model package, see View the Details of a Model Version. + model_life_cycle: A structure describing the current state of the model in its life cycle. session: Boto3 session. region: Region name. Returns: - The ModelQualityJobDefinition resource. + The ModelPackage resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -19938,33 +19821,46 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating model_quality_job_definition resource.") + logger.info("Creating model_package resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "JobDefinitionName": job_definition_name, - "ModelQualityBaselineConfig": model_quality_baseline_config, - "ModelQualityAppSpecification": model_quality_app_specification, - "ModelQualityJobInput": model_quality_job_input, - "ModelQualityJobOutputConfig": model_quality_job_output_config, - "JobResources": job_resources, - "NetworkConfig": network_config, - "RoleArn": role_arn, - "StoppingCondition": stopping_condition, + "ModelPackageName": model_package_name, + "ModelPackageGroupName": model_package_group_name, + "ModelPackageDescription": model_package_description, + "InferenceSpecification": 
inference_specification, + "ValidationSpecification": validation_specification, + "SourceAlgorithmSpecification": source_algorithm_specification, + "CertifyForMarketplace": certify_for_marketplace, "Tags": tags, + "ModelApprovalStatus": model_approval_status, + "MetadataProperties": metadata_properties, + "ModelMetrics": model_metrics, + "ClientToken": client_token, + "Domain": domain, + "Task": task, + "SamplePayloadUrl": sample_payload_url, + "CustomerMetadataProperties": customer_metadata_properties, + "DriftCheckBaselines": drift_check_baselines, + "AdditionalInferenceSpecifications": additional_inference_specifications, + "SkipModelValidation": skip_model_validation, + "SourceUri": source_uri, + "SecurityConfig": security_config, + "ModelCard": model_card, + "ModelLifeCycle": model_life_cycle, } operation_input_args = Base.populate_chained_attributes( - resource_name="ModelQualityJobDefinition", operation_input_args=operation_input_args + resource_name="ModelPackage", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -19973,29 +19869,31 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_model_quality_job_definition(**operation_input_args) + response = client.create_model_package(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(job_definition_name=job_definition_name, session=session, region=region) + return cls.get( + model_package_name=response["ModelPackageName"], session=session, region=region + ) @classmethod @Base.add_validate_call def get( cls, - job_definition_name: str, + model_package_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["ModelQualityJobDefinition"]: + ) -> Optional["ModelPackage"]: """ - Get a ModelQualityJobDefinition resource + Get a ModelPackage resource Parameters: - job_definition_name: The name of the model quality job. 
The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + model_package_name: The name or Amazon Resource Name (ARN) of the model package to describe. When you specify a name, the name must have 1 to 63 characters. Valid characters are a-z, A-Z, 0-9, and - (hyphen). session: Boto3 session. region: Region name. Returns: - The ModelQualityJobDefinition resource. + The ModelPackage resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20007,11 +19905,10 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "JobDefinitionName": job_definition_name, + "ModelPackageName": model_package_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -20020,24 +19917,1766 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_model_quality_job_definition(**operation_input_args) + response = client.describe_model_package(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeModelQualityJobDefinitionResponse") - model_quality_job_definition = cls(**transformed_response) - return model_quality_job_definition + transformed_response = transform(response, "DescribeModelPackageOutput") + model_package = cls(**transformed_response) + return model_package @Base.add_validate_call def refresh( self, - ) -> Optional["ModelQualityJobDefinition"]: + ) -> Optional["ModelPackage"]: """ - Refresh a ModelQualityJobDefinition resource + Refresh a ModelPackage resource + + Returns: + The ModelPackage resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
+ The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + operation_input_args = { + "ModelPackageName": self.model_package_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client() + response = client.describe_model_package(**operation_input_args) + + # deserialize response and update self + transform(response, "DescribeModelPackageOutput", self) + return self + + @populate_inputs_decorator + @Base.add_validate_call + def update( + self, + model_approval_status: Optional[str] = Unassigned(), + approval_description: Optional[str] = Unassigned(), + customer_metadata_properties: Optional[Dict[str, str]] = Unassigned(), + customer_metadata_properties_to_remove: Optional[List[str]] = Unassigned(), + additional_inference_specifications_to_add: Optional[ + List[AdditionalInferenceSpecificationDefinition] + ] = Unassigned(), + inference_specification: Optional[InferenceSpecification] = Unassigned(), + source_uri: Optional[str] = Unassigned(), + model_card: Optional[ModelPackageModelCard] = Unassigned(), + model_life_cycle: Optional[ModelLifeCycle] = Unassigned(), + client_token: Optional[str] = Unassigned(), + ) -> Optional["ModelPackage"]: + """ + Update a ModelPackage resource + + Parameters: + customer_metadata_properties_to_remove: The metadata properties associated with the model package versions to remove. + additional_inference_specifications_to_add: An array of additional Inference Specification objects to be added to the existing array additional Inference Specification. Total number of additional Inference Specifications can not exceed 15. 
Each additional Inference Specification specifies artifacts based on this model package that can be used on inference endpoints. Generally used with SageMaker Neo to store the compiled artifacts. + client_token: A unique token that guarantees that the call to this API is idempotent. + + Returns: + The ModelPackage resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + """ + + logger.info("Updating model_package resource.") + client = Base.get_sagemaker_client() + + operation_input_args = { + "ModelPackageArn": self.model_package_arn, + "ModelApprovalStatus": model_approval_status, + "ApprovalDescription": approval_description, + "CustomerMetadataProperties": customer_metadata_properties, + "CustomerMetadataPropertiesToRemove": customer_metadata_properties_to_remove, + "AdditionalInferenceSpecificationsToAdd": additional_inference_specifications_to_add, + "InferenceSpecification": inference_specification, + "SourceUri": source_uri, + "ModelCard": model_card, + "ModelLifeCycle": model_life_cycle, + "ClientToken": client_token, + } + logger.debug(f"Input request: {operation_input_args}") + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + # create the resource + response = client.update_model_package(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() + + return self + + @Base.add_validate_call + def delete( + self, + ) -> None: + """ + Delete a ModelPackage resource + + Raises: + 
botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. + """ + + client = Base.get_sagemaker_client() + + operation_input_args = { + "ModelPackageName": self.model_package_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_model_package(**operation_input_args) + + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + + @Base.add_validate_call + def wait_for_status( + self, + target_status: Literal["Pending", "InProgress", "Completed", "Failed", "Deleting"], + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a ModelPackage resource to reach certain status. + + Parameters: + target_status: The status to wait for. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
+ """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task(f"Waiting for ModelPackage to reach [bold]{target_status} status...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.model_package_status + status.update(f"Current status: [bold]{current_status}") + + if target_status == current_status: + logger.info(f"Final Resource Status: [bold]{current_status}") + return + + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="ModelPackage", status=current_status, reason="(Unknown)" + ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="ModelPackage", status=current_status) + time.sleep(poll) + + @Base.add_validate_call + def wait_for_delete( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a ModelPackage resource to be deleted. + + Parameters: + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + DeleteFailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
+ """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task("Waiting for ModelPackage to be deleted...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ) + ): + while True: + try: + self.refresh() + current_status = self.model_package_status + status.update(f"Current status: [bold]{current_status}") + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="ModelPackage", status=current_status + ) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] + + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. It may have been deleted.") + return + raise e + time.sleep(poll) + + @classmethod + @Base.add_validate_call + def get_all( + cls, + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + model_approval_status: Optional[str] = Unassigned(), + model_package_group_name: Optional[str] = Unassigned(), + model_package_type: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["ModelPackage"]: + """ + Get all ModelPackage resources + + Parameters: + creation_time_after: A filter that returns only model packages created after the specified time (timestamp). + creation_time_before: A filter that returns only model packages created before the specified time (timestamp). + max_results: The maximum number of model packages to return in the response. + name_contains: A string in the model package name. 
This filter returns only model packages whose name contains the specified string. + model_approval_status: A filter that returns only the model packages with the specified approval status. + model_package_group_name: A filter that returns only model versions that belong to the specified model group. + model_package_type: A filter that returns only the model packages of the specified type. This can be one of the following values. UNVERSIONED - List only unversioined models. This is the default value if no ModelPackageType is specified. VERSIONED - List only versioned models. BOTH - List both versioned and unversioned models. + next_token: If the response to a previous ListModelPackages request was truncated, the response includes a NextToken. To retrieve the next set of model packages, use the token in the next request. + sort_by: The parameter by which to sort the results. The default is CreationTime. + sort_order: The sort order for the results. The default is Ascending. + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed ModelPackage resources. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
+ The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "NameContains": name_contains, + "ModelApprovalStatus": model_approval_status, + "ModelPackageGroupName": model_package_group_name, + "ModelPackageType": model_package_type, + "SortBy": sort_by, + "SortOrder": sort_order, + } + + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + return ResourceIterator( + client=client, + list_method="list_model_packages", + summaries_key="ModelPackageSummaryList", + summary_name="ModelPackageSummary", + resource_cls=ModelPackage, + list_method_kwargs=operation_input_args, + ) + + @Base.add_validate_call + def batch_get( + self, + model_package_arn_list: List[str], + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional[BatchDescribeModelPackageOutput]: + """ + This action batch describes a list of versioned model packages. + + Parameters: + model_package_arn_list: The list of Amazon Resource Name (ARN) of the model package groups. + session: Boto3 session. + region: Region name. + + Returns: + BatchDescribeModelPackageOutput + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
+ The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + operation_input_args = { + "ModelPackageArnList": model_package_arn_list, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + logger.debug(f"Calling batch_describe_model_package API") + response = client.batch_describe_model_package(**operation_input_args) + logger.debug(f"Response: {response}") + + transformed_response = transform(response, "BatchDescribeModelPackageOutput") + return BatchDescribeModelPackageOutput(**transformed_response) + + +class ModelPackageGroup(Base): + """ + Class representing resource ModelPackageGroup + + Attributes: + model_package_group_name: The name of the model group. + model_package_group_arn: The Amazon Resource Name (ARN) of the model group. + creation_time: The time that the model group was created. + created_by: + model_package_group_status: The status of the model group. + model_package_group_description: A description of the model group. 
+ + """ + + model_package_group_name: str + model_package_group_arn: Optional[str] = Unassigned() + model_package_group_description: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + created_by: Optional[UserContext] = Unassigned() + model_package_group_status: Optional[str] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "model_package_group_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object model_package_group") + return None + + @classmethod + @Base.add_validate_call + def create( + cls, + model_package_group_name: str, + model_package_group_description: Optional[str] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["ModelPackageGroup"]: + """ + Create a ModelPackageGroup resource + + Parameters: + model_package_group_name: The name of the model group. + model_package_group_description: A description for the model group. + tags: A list of key value pairs associated with the model group. For more information, see Tagging Amazon Web Services resources in the Amazon Web Services General Reference Guide. + session: Boto3 session. + region: Region name. + + Returns: + The ModelPackageGroup resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
+ The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + """ + + logger.info("Creating model_package_group resource.") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "ModelPackageGroupName": model_package_group_name, + "ModelPackageGroupDescription": model_package_group_description, + "Tags": tags, + } + + operation_input_args = Base.populate_chained_attributes( + resource_name="ModelPackageGroup", operation_input_args=operation_input_args + ) + + logger.debug(f"Input request: {operation_input_args}") + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + # create the resource + response = client.create_model_package_group(**operation_input_args) + logger.debug(f"Response: {response}") + + return cls.get( + model_package_group_name=model_package_group_name, session=session, region=region + ) + + @classmethod + @Base.add_validate_call + def get( + cls, + model_package_group_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["ModelPackageGroup"]: + """ + Get a ModelPackageGroup resource + + Parameters: + model_package_group_name: The name of the model group to describe. + session: Boto3 session. + region: Region name. 
+ + Returns: + The ModelPackageGroup resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + operation_input_args = { + "ModelPackageGroupName": model_package_group_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + response = client.describe_model_package_group(**operation_input_args) + + logger.debug(response) + + # deserialize the response + transformed_response = transform(response, "DescribeModelPackageGroupOutput") + model_package_group = cls(**transformed_response) + return model_package_group + + @Base.add_validate_call + def refresh( + self, + ) -> Optional["ModelPackageGroup"]: + """ + Refresh a ModelPackageGroup resource + + Returns: + The ModelPackageGroup resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
+ The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + operation_input_args = { + "ModelPackageGroupName": self.model_package_group_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client() + response = client.describe_model_package_group(**operation_input_args) + + # deserialize response and update self + transform(response, "DescribeModelPackageGroupOutput", self) + return self + + @Base.add_validate_call + def delete( + self, + ) -> None: + """ + Delete a ModelPackageGroup resource + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. 
+ """ + + client = Base.get_sagemaker_client() + + operation_input_args = { + "ModelPackageGroupName": self.model_package_group_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_model_package_group(**operation_input_args) + + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + + @Base.add_validate_call + def wait_for_status( + self, + target_status: Literal[ + "Pending", "InProgress", "Completed", "Failed", "Deleting", "DeleteFailed" + ], + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a ModelPackageGroup resource to reach certain status. + + Parameters: + target_status: The status to wait for. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
+ """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task(f"Waiting for ModelPackageGroup to reach [bold]{target_status} status...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.model_package_group_status + status.update(f"Current status: [bold]{current_status}") + + if target_status == current_status: + logger.info(f"Final Resource Status: [bold]{current_status}") + return + + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="ModelPackageGroup", status=current_status, reason="(Unknown)" + ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="ModelPackageGroup", status=current_status + ) + time.sleep(poll) + + @Base.add_validate_call + def wait_for_delete( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a ModelPackageGroup resource to be deleted. + + Parameters: + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + DeleteFailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
+ """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task("Waiting for ModelPackageGroup to be deleted...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ) + ): + while True: + try: + self.refresh() + current_status = self.model_package_group_status + status.update(f"Current status: [bold]{current_status}") + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="ModelPackageGroup", status=current_status + ) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] + + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. It may have been deleted.") + return + raise e + time.sleep(poll) + + @classmethod + @Base.add_validate_call + def get_all( + cls, + creation_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + cross_account_filter_option: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["ModelPackageGroup"]: + """ + Get all ModelPackageGroup resources + + Parameters: + creation_time_after: A filter that returns only model groups created after the specified time. + creation_time_before: A filter that returns only model groups created before the specified time. + max_results: The maximum number of results to return in the response. + name_contains: A string in the model group name. This filter returns only model groups whose name contains the specified string. 
+ next_token: If the result of the previous ListModelPackageGroups request was truncated, the response includes a NextToken. To retrieve the next set of model groups, use the token in the next request. + sort_by: The field to sort results by. The default is CreationTime. + sort_order: The sort order for results. The default is Ascending. + cross_account_filter_option: A filter that returns either model groups shared with you or model groups in your own account. When the value is CrossAccount, the results show the resources made discoverable to you from other accounts. When the value is SameAccount or null, the results show resources from your account. The default is SameAccount. + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed ModelPackageGroup resources. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "CreationTimeAfter": creation_time_after, + "CreationTimeBefore": creation_time_before, + "NameContains": name_contains, + "SortBy": sort_by, + "SortOrder": sort_order, + "CrossAccountFilterOption": cross_account_filter_option, + } + + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + return ResourceIterator( + client=client, + list_method="list_model_package_groups", + summaries_key="ModelPackageGroupSummaryList", + summary_name="ModelPackageGroupSummary", + resource_cls=ModelPackageGroup, + list_method_kwargs=operation_input_args, + ) + + @Base.add_validate_call + def 
get_policy( + self, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional[str]: + """ + Gets a resource policy that manages access for a model group. + + Parameters: + session: Boto3 session. + region: Region name. + + Returns: + str + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + operation_input_args = { + "ModelPackageGroupName": self.model_package_group_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + logger.debug(f"Calling get_model_package_group_policy API") + response = client.get_model_package_group_policy(**operation_input_args) + logger.debug(f"Response: {response}") + + return list(response.values())[0] + + @Base.add_validate_call + def delete_policy( + self, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> None: + """ + Deletes a model group resource policy. + + Parameters: + session: Boto3 session. + region: Region name. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
+ The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + operation_input_args = { + "ModelPackageGroupName": self.model_package_group_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + logger.debug(f"Calling delete_model_package_group_policy API") + response = client.delete_model_package_group_policy(**operation_input_args) + logger.debug(f"Response: {response}") + + @Base.add_validate_call + def put_policy( + self, + resource_policy: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> None: + """ + Adds a resouce policy to control access to a model group. + + Parameters: + resource_policy: The resource policy for the model group. + session: Boto3 session. + region: Region name. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. 
+ """ + + operation_input_args = { + "ModelPackageGroupName": self.model_package_group_name, + "ResourcePolicy": resource_policy, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + logger.debug(f"Calling put_model_package_group_policy API") + response = client.put_model_package_group_policy(**operation_input_args) + logger.debug(f"Response: {response}") + + +class ModelQualityJobDefinition(Base): + """ + Class representing resource ModelQualityJobDefinition + + Attributes: + job_definition_arn: The Amazon Resource Name (ARN) of the model quality job. + job_definition_name: The name of the quality job definition. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + creation_time: The time at which the model quality job was created. + model_quality_app_specification: Configures the model quality job to run a specified Docker container image. + model_quality_job_input: Inputs for the model quality job. + model_quality_job_output_config: + job_resources: + role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. + model_quality_baseline_config: The baseline configuration for a model quality job. + network_config: Networking options for a model quality job. 
+ stopping_condition: + + """ + + job_definition_name: str + job_definition_arn: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + model_quality_baseline_config: Optional[ModelQualityBaselineConfig] = Unassigned() + model_quality_app_specification: Optional[ModelQualityAppSpecification] = Unassigned() + model_quality_job_input: Optional[ModelQualityJobInput] = Unassigned() + model_quality_job_output_config: Optional[MonitoringOutputConfig] = Unassigned() + job_resources: Optional[MonitoringResources] = Unassigned() + network_config: Optional[MonitoringNetworkConfig] = Unassigned() + role_arn: Optional[str] = Unassigned() + stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "model_quality_job_definition_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object model_quality_job_definition") + return None + + def populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "model_quality_job_input": { + "ground_truth_s3_input": {"s3_uri": {"type": "string"}}, + "endpoint_input": { + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + "batch_transform_input": { + "data_captured_destination_s3_uri": {"type": "string"}, + "s3_input_mode": {"type": "string"}, + "s3_data_distribution_type": {"type": "string"}, + }, + }, + "model_quality_job_output_config": {"kms_key_id": {"type": "string"}}, + "job_resources": {"cluster_config": {"volume_kms_key_id": {"type": "string"}}}, + 
"role_arn": {"type": "string"}, + "model_quality_baseline_config": { + "constraints_resource": {"s3_uri": {"type": "string"}} + }, + "network_config": { + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + } + }, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "ModelQualityJobDefinition", **kwargs + ), + ) + + return wrapper + + @classmethod + @populate_inputs_decorator + @Base.add_validate_call + def create( + cls, + job_definition_name: str, + model_quality_app_specification: ModelQualityAppSpecification, + model_quality_job_input: ModelQualityJobInput, + model_quality_job_output_config: MonitoringOutputConfig, + job_resources: MonitoringResources, + role_arn: str, + model_quality_baseline_config: Optional[ModelQualityBaselineConfig] = Unassigned(), + network_config: Optional[MonitoringNetworkConfig] = Unassigned(), + stopping_condition: Optional[MonitoringStoppingCondition] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["ModelQualityJobDefinition"]: + """ + Create a ModelQualityJobDefinition resource + + Parameters: + job_definition_name: The name of the monitoring job definition. + model_quality_app_specification: The container that runs the monitoring job. + model_quality_job_input: A list of the inputs that are monitored. Currently endpoints are supported. + model_quality_job_output_config: + job_resources: + role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon SageMaker can assume to perform tasks on your behalf. + model_quality_baseline_config: Specifies the constraints and baselines for the monitoring job. + network_config: Specifies the network configuration for the monitoring job. + stopping_condition: + tags: (Optional) An array of key-value pairs. 
For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide. + session: Boto3 session. + region: Region name. + + Returns: + The ModelQualityJobDefinition resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + """ + + logger.info("Creating model_quality_job_definition resource.") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "JobDefinitionName": job_definition_name, + "ModelQualityBaselineConfig": model_quality_baseline_config, + "ModelQualityAppSpecification": model_quality_app_specification, + "ModelQualityJobInput": model_quality_job_input, + "ModelQualityJobOutputConfig": model_quality_job_output_config, + "JobResources": job_resources, + "NetworkConfig": network_config, + "RoleArn": role_arn, + "StoppingCondition": stopping_condition, + "Tags": tags, + } + + operation_input_args = Base.populate_chained_attributes( + resource_name="ModelQualityJobDefinition", operation_input_args=operation_input_args + ) + + logger.debug(f"Input request: {operation_input_args}") + # serialize the input request + operation_input_args = serialize(operation_input_args) + 
logger.debug(f"Serialized input request: {operation_input_args}") + + # create the resource + response = client.create_model_quality_job_definition(**operation_input_args) + logger.debug(f"Response: {response}") + + return cls.get(job_definition_name=job_definition_name, session=session, region=region) + + @classmethod + @Base.add_validate_call + def get( + cls, + job_definition_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["ModelQualityJobDefinition"]: + """ + Get a ModelQualityJobDefinition resource + + Parameters: + job_definition_name: The name of the model quality job. The name must be unique within an Amazon Web Services Region in the Amazon Web Services account. + session: Boto3 session. + region: Region name. + + Returns: + The ModelQualityJobDefinition resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. 
+ """ + + operation_input_args = { + "JobDefinitionName": job_definition_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + response = client.describe_model_quality_job_definition(**operation_input_args) + + logger.debug(response) + + # deserialize the response + transformed_response = transform(response, "DescribeModelQualityJobDefinitionResponse") + model_quality_job_definition = cls(**transformed_response) + return model_quality_job_definition + + @Base.add_validate_call + def refresh( + self, + ) -> Optional["ModelQualityJobDefinition"]: + """ + Refresh a ModelQualityJobDefinition resource + + Returns: + The ModelQualityJobDefinition resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. 
+ """ + + operation_input_args = { + "JobDefinitionName": self.job_definition_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client() + response = client.describe_model_quality_job_definition(**operation_input_args) + + # deserialize response and update self + transform(response, "DescribeModelQualityJobDefinitionResponse", self) + return self + + @Base.add_validate_call + def delete( + self, + ) -> None: + """ + Delete a ModelQualityJobDefinition resource + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. 
+ """ + + client = Base.get_sagemaker_client() + + operation_input_args = { + "JobDefinitionName": self.job_definition_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_model_quality_job_definition(**operation_input_args) + + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + + @classmethod + @Base.add_validate_call + def get_all( + cls, + endpoint_name: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["ModelQualityJobDefinition"]: + """ + Get all ModelQualityJobDefinition resources + + Parameters: + endpoint_name: A filter that returns only model quality monitoring job definitions that are associated with the specified endpoint. + sort_by: The field to sort results by. The default is CreationTime. + sort_order: Whether to sort the results in Ascending or Descending order. The default is Descending. + next_token: If the result of the previous ListModelQualityJobDefinitions request was truncated, the response includes a NextToken. To retrieve the next set of model quality monitoring job definitions, use the token in the next request. + max_results: The maximum number of results to return in a call to ListModelQualityJobDefinitions. + name_contains: A string in the transform job name. This filter returns only model quality monitoring job definitions whose name contains the specified string. + creation_time_before: A filter that returns only model quality monitoring job definitions created before the specified time. 
+ creation_time_after: A filter that returns only model quality monitoring job definitions created after the specified time. + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed ModelQualityJobDefinition resources. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "EndpointName": endpoint_name, + "SortBy": sort_by, + "SortOrder": sort_order, + "NameContains": name_contains, + "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, + } + custom_key_mapping = { + "monitoring_job_definition_name": "job_definition_name", + "monitoring_job_definition_arn": "job_definition_arn", + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + return ResourceIterator( + client=client, + list_method="list_model_quality_job_definitions", + summaries_key="JobDefinitionSummaries", + summary_name="MonitoringJobDefinitionSummary", + resource_cls=ModelQualityJobDefinition, + custom_key_mapping=custom_key_mapping, + list_method_kwargs=operation_input_args, + ) + + +class MonitoringAlert(Base): + """ + Class representing resource MonitoringAlert + + Attributes: + monitoring_alert_name: The name of a monitoring alert. + creation_time: A timestamp that indicates when a monitor alert was created. + last_modified_time: A timestamp that indicates when a monitor alert was last updated. + alert_status: The current status of an alert. 
+ datapoints_to_alert: Within EvaluationPeriod, how many execution failures will raise an alert. + evaluation_period: The number of most recent monitoring executions to consider when evaluating alert status. + actions: A list of alert actions taken in response to an alert going into InAlert status. + + """ + + monitoring_alert_name: str + creation_time: datetime.datetime + last_modified_time: datetime.datetime + alert_status: str + datapoints_to_alert: int + evaluation_period: int + actions: MonitoringAlertActions + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "monitoring_alert_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object monitoring_alert") + return None + + @Base.add_validate_call + def update( + self, + monitoring_schedule_name: str, + datapoints_to_alert: int, + evaluation_period: int, + ) -> Optional["MonitoringAlert"]: + """ + Update a MonitoringAlert resource + + Parameters: + monitoring_schedule_name: The name of a monitoring schedule. + + Returns: + The MonitoringAlert resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. 
+ """ + + logger.info("Updating monitoring_alert resource.") + client = Base.get_sagemaker_client() + + operation_input_args = { + "MonitoringScheduleName": monitoring_schedule_name, + "MonitoringAlertName": self.monitoring_alert_name, + "DatapointsToAlert": datapoints_to_alert, + "EvaluationPeriod": evaluation_period, + } + logger.debug(f"Input request: {operation_input_args}") + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + # create the resource + response = client.update_monitoring_alert(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() + + return self + + @classmethod + @Base.add_validate_call + def get_all( + cls, + monitoring_schedule_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["MonitoringAlert"]: + """ + Get all MonitoringAlert resources + + Parameters: + monitoring_schedule_name: The name of a monitoring schedule. + next_token: If the result of the previous ListMonitoringAlerts request was truncated, the response includes a NextToken. To retrieve the next set of alerts in the history, use the token in the next request. + max_results: The maximum number of results to display. The default is 100. + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed MonitoringAlert resources. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. 
+ """ + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "MonitoringScheduleName": monitoring_schedule_name, + } + + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + return ResourceIterator( + client=client, + list_method="list_monitoring_alerts", + summaries_key="MonitoringAlertSummaries", + summary_name="MonitoringAlertSummary", + resource_cls=MonitoringAlert, + list_method_kwargs=operation_input_args, + ) + + @Base.add_validate_call + def list_history( + self, + monitoring_schedule_name: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + next_token: Optional[str] = Unassigned(), + max_results: Optional[int] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = Unassigned(), + status_equals: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional[MonitoringAlertHistorySummary]: + """ + Gets a list of past alerts in a model monitoring schedule. + + Parameters: + monitoring_schedule_name: The name of a monitoring schedule. + sort_by: The field used to sort results. The default is CreationTime. + sort_order: The sort order, whether Ascending or Descending, of the alert history. The default is Descending. + next_token: If the result of the previous ListMonitoringAlertHistory request was truncated, the response includes a NextToken. To retrieve the next set of alerts in the history, use the token in the next request. + max_results: The maximum number of results to display. The default is 100. + creation_time_before: A filter that returns only alerts created on or before the specified time. 
+ creation_time_after: A filter that returns only alerts created on or after the specified time. + status_equals: A filter that retrieves only alerts with a specific status. + session: Boto3 session. + region: Region name. + + Returns: + MonitoringAlertHistorySummary + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. + """ + + operation_input_args = { + "MonitoringScheduleName": monitoring_schedule_name, + "MonitoringAlertName": self.monitoring_alert_name, + "SortBy": sort_by, + "SortOrder": sort_order, + "NextToken": next_token, + "MaxResults": max_results, + "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, + "StatusEquals": status_equals, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + logger.debug(f"Calling list_monitoring_alert_history API") + response = client.list_monitoring_alert_history(**operation_input_args) + logger.debug(f"Response: {response}") + + transformed_response = transform(response, "ListMonitoringAlertHistoryResponse") + return MonitoringAlertHistorySummary(**transformed_response) + + +class MonitoringExecution(Base): + """ + Class representing resource MonitoringExecution + + Attributes: + monitoring_schedule_name: The name of the monitoring schedule. + scheduled_time: The time the monitoring job was scheduled. + creation_time: The time at which the monitoring job was created. 
+ last_modified_time: A timestamp that indicates the last time the monitoring job was modified. + monitoring_execution_status: The status of the monitoring job. + processing_job_arn: The Amazon Resource Name (ARN) of the monitoring job. + endpoint_name: The name of the endpoint used to run the monitoring job. + failure_reason: Contains the reason a monitoring job failed, if it failed. + monitoring_job_definition_name: The name of the monitoring job. + monitoring_type: The type of the monitoring job. + + """ + + monitoring_schedule_name: str + scheduled_time: datetime.datetime + creation_time: datetime.datetime + last_modified_time: datetime.datetime + monitoring_execution_status: str + processing_job_arn: Optional[str] = Unassigned() + endpoint_name: Optional[str] = Unassigned() + failure_reason: Optional[str] = Unassigned() + monitoring_job_definition_name: Optional[str] = Unassigned() + monitoring_type: Optional[str] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "monitoring_execution_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object monitoring_execution") + return None + + @classmethod + @Base.add_validate_call + def get_all( + cls, + monitoring_schedule_name: Optional[str] = Unassigned(), + endpoint_name: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + scheduled_time_before: Optional[datetime.datetime] = Unassigned(), + scheduled_time_after: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_after: Optional[datetime.datetime] = 
Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + status_equals: Optional[str] = Unassigned(), + monitoring_job_definition_name: Optional[str] = Unassigned(), + monitoring_type_equals: Optional[str] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["MonitoringExecution"]: + """ + Get all MonitoringExecution resources + + Parameters: + monitoring_schedule_name: Name of a specific schedule to fetch jobs for. + endpoint_name: Name of a specific endpoint to fetch jobs for. + sort_by: Whether to sort the results by the Status, CreationTime, or ScheduledTime field. The default is CreationTime. + sort_order: Whether to sort the results in Ascending or Descending order. The default is Descending. + next_token: The token returned if the response is truncated. To retrieve the next set of job executions, use it in the next request. + max_results: The maximum number of jobs to return in the response. The default value is 10. + scheduled_time_before: Filter for jobs scheduled before a specified time. + scheduled_time_after: Filter for jobs scheduled after a specified time. + creation_time_before: A filter that returns only jobs created before a specified time. + creation_time_after: A filter that returns only jobs created after a specified time. + last_modified_time_before: A filter that returns only jobs modified after a specified time. + last_modified_time_after: A filter that returns only jobs modified before a specified time. + status_equals: A filter that retrieves only jobs with a specific status. + monitoring_job_definition_name: Gets a list of the monitoring job runs of the specified monitoring job definitions. + monitoring_type_equals: A filter that returns only the monitoring job runs of the specified monitoring type. + session: Boto3 session. + region: Region name. 
+ + Returns: + Iterator for listed MonitoringExecution resources. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "MonitoringScheduleName": monitoring_schedule_name, + "EndpointName": endpoint_name, + "SortBy": sort_by, + "SortOrder": sort_order, + "ScheduledTimeBefore": scheduled_time_before, + "ScheduledTimeAfter": scheduled_time_after, + "CreationTimeBefore": creation_time_before, + "CreationTimeAfter": creation_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + "LastModifiedTimeAfter": last_modified_time_after, + "StatusEquals": status_equals, + "MonitoringJobDefinitionName": monitoring_job_definition_name, + "MonitoringTypeEquals": monitoring_type_equals, + } + + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + return ResourceIterator( + client=client, + list_method="list_monitoring_executions", + summaries_key="MonitoringExecutionSummaries", + summary_name="MonitoringExecutionSummary", + resource_cls=MonitoringExecution, + list_method_kwargs=operation_input_args, + ) + + +class MonitoringSchedule(Base): + """ + Class representing resource MonitoringSchedule + + Attributes: + monitoring_schedule_arn: The Amazon Resource Name (ARN) of the monitoring schedule. + monitoring_schedule_name: Name of the monitoring schedule. + monitoring_schedule_status: The status of an monitoring job. + creation_time: The time at which the monitoring job was created. 
+ last_modified_time: The time at which the monitoring job was last modified. + monitoring_schedule_config: The configuration object that specifies the monitoring schedule and defines the monitoring job. + monitoring_type: The type of the monitoring job that this schedule runs. This is one of the following values. DATA_QUALITY - The schedule is for a data quality monitoring job. MODEL_QUALITY - The schedule is for a model quality monitoring job. MODEL_BIAS - The schedule is for a bias monitoring job. MODEL_EXPLAINABILITY - The schedule is for an explainability monitoring job. + failure_reason: A string, up to one KB in size, that contains the reason a monitoring job failed, if it failed. + endpoint_name: The name of the endpoint for the monitoring job. + last_monitoring_execution_summary: Describes metadata on the last execution to run, if there was one. + + """ + + monitoring_schedule_name: str + monitoring_schedule_arn: Optional[str] = Unassigned() + monitoring_schedule_status: Optional[str] = Unassigned() + monitoring_type: Optional[str] = Unassigned() + failure_reason: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + monitoring_schedule_config: Optional[MonitoringScheduleConfig] = Unassigned() + endpoint_name: Optional[str] = Unassigned() + last_monitoring_execution_summary: Optional[MonitoringExecutionSummary] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "monitoring_schedule_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object monitoring_schedule") + return None + + def 
populate_inputs_decorator(create_func): + @functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "monitoring_schedule_config": { + "monitoring_job_definition": { + "monitoring_output_config": {"kms_key_id": {"type": "string"}}, + "monitoring_resources": { + "cluster_config": {"volume_kms_key_id": {"type": "string"}} + }, + "role_arn": {"type": "string"}, + "baseline_config": { + "constraints_resource": {"s3_uri": {"type": "string"}}, + "statistics_resource": {"s3_uri": {"type": "string"}}, + }, + "network_config": { + "vpc_config": { + "security_group_ids": { + "type": "array", + "items": {"type": "string"}, + }, + "subnets": {"type": "array", "items": {"type": "string"}}, + } + }, + } + } + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "MonitoringSchedule", **kwargs + ), + ) + + return wrapper + + @classmethod + @populate_inputs_decorator + @Base.add_validate_call + def create( + cls, + monitoring_schedule_name: str, + monitoring_schedule_config: MonitoringScheduleConfig, + tags: Optional[List[Tag]] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["MonitoringSchedule"]: + """ + Create a MonitoringSchedule resource + + Parameters: + monitoring_schedule_name: The name of the monitoring schedule. The name must be unique within an Amazon Web Services Region within an Amazon Web Services account. + monitoring_schedule_config: The configuration object that specifies the monitoring schedule and defines the monitoring job. + tags: (Optional) An array of key-value pairs. For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide. + session: Boto3 session. + region: Region name. Returns: - The ModelQualityJobDefinition resource. + The MonitoringSchedule resource. 
Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20049,29 +21688,59 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. + ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ + logger.info("Creating monitoring_schedule resource.") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + operation_input_args = { - "JobDefinitionName": self.job_definition_name, + "MonitoringScheduleName": monitoring_schedule_name, + "MonitoringScheduleConfig": monitoring_schedule_config, + "Tags": tags, } + + operation_input_args = Base.populate_chained_attributes( + resource_name="MonitoringSchedule", operation_input_args=operation_input_args + ) + + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client() - response = client.describe_model_quality_job_definition(**operation_input_args) + # create the resource + response = client.create_monitoring_schedule(**operation_input_args) + logger.debug(f"Response: {response}") - # deserialize response and update self - transform(response, "DescribeModelQualityJobDefinitionResponse", self) - return self + return cls.get( + monitoring_schedule_name=monitoring_schedule_name, session=session, region=region + ) + @classmethod @Base.add_validate_call - def 
delete( - self, - ) -> None: + def get( + cls, + monitoring_schedule_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["MonitoringSchedule"]: """ - Delete a ModelQualityJobDefinition resource + Get a MonitoringSchedule resource + + Parameters: + monitoring_schedule_name: Name of a previously created monitoring schedule. + session: Boto3 session. + region: Region name. + + Returns: + The MonitoringSchedule resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20086,49 +21755,34 @@ def delete( ResourceNotFound: Resource being access is not found. """ - client = Base.get_sagemaker_client() - operation_input_args = { - "JobDefinitionName": self.job_definition_name, + "MonitoringScheduleName": monitoring_schedule_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_model_quality_job_definition(**operation_input_args) + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + response = client.describe_monitoring_schedule(**operation_input_args) - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + logger.debug(response) + + # deserialize the response + transformed_response = transform(response, "DescribeMonitoringScheduleResponse") + monitoring_schedule = cls(**transformed_response) + return monitoring_schedule - @classmethod @Base.add_validate_call - def get_all( - cls, - endpoint_name: Optional[str] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> 
ResourceIterator["ModelQualityJobDefinition"]: + def refresh( + self, + ) -> Optional["MonitoringSchedule"]: """ - Get all ModelQualityJobDefinition resources - - Parameters: - endpoint_name: A filter that returns only model quality monitoring job definitions that are associated with the specified endpoint. - sort_by: The field to sort results by. The default is CreationTime. - sort_order: Whether to sort the results in Ascending or Descending order. The default is Descending. - next_token: If the result of the previous ListModelQualityJobDefinitions request was truncated, the response includes a NextToken. To retrieve the next set of model quality monitoring job definitions, use the token in the next request. - max_results: The maximum number of results to return in a call to ListModelQualityJobDefinitions. - name_contains: A string in the transform job name. This filter returns only model quality monitoring job definitions whose name contains the specified string. - creation_time_before: A filter that returns only model quality monitoring job definitions created before the specified time. - creation_time_after: A filter that returns only model quality monitoring job definitions created after the specified time. - session: Boto3 session. - region: Region name. + Refresh a MonitoringSchedule resource Returns: - Iterator for listed ModelQualityJobDefinition resources. + The MonitoringSchedule resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20140,93 +21794,34 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
""" - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - operation_input_args = { - "EndpointName": endpoint_name, - "SortBy": sort_by, - "SortOrder": sort_order, - "NameContains": name_contains, - "CreationTimeBefore": creation_time_before, - "CreationTimeAfter": creation_time_after, - } - custom_key_mapping = { - "monitoring_job_definition_name": "job_definition_name", - "monitoring_job_definition_arn": "job_definition_arn", + "MonitoringScheduleName": self.monitoring_schedule_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - return ResourceIterator( - client=client, - list_method="list_model_quality_job_definitions", - summaries_key="JobDefinitionSummaries", - summary_name="MonitoringJobDefinitionSummary", - resource_cls=ModelQualityJobDefinition, - custom_key_mapping=custom_key_mapping, - list_method_kwargs=operation_input_args, - ) - - -class MonitoringAlert(Base): - """ - Class representing resource MonitoringAlert - - Attributes: - monitoring_alert_name: The name of a monitoring alert. - creation_time: A timestamp that indicates when a monitor alert was created. - last_modified_time: A timestamp that indicates when a monitor alert was last updated. - alert_status: The current status of an alert. - datapoints_to_alert: Within EvaluationPeriod, how many execution failures will raise an alert. - evaluation_period: The number of most recent monitoring executions to consider when evaluating alert status. - actions: A list of alert actions taken in response to an alert going into InAlert status. 
- - """ - - monitoring_alert_name: str - creation_time: datetime.datetime - last_modified_time: datetime.datetime - alert_status: str - datapoints_to_alert: int - evaluation_period: int - actions: MonitoringAlertActions - - def get_name(self) -> str: - attributes = vars(self) - resource_name = "monitoring_alert_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] - - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) + client = Base.get_sagemaker_client() + response = client.describe_monitoring_schedule(**operation_input_args) - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object monitoring_alert") - return None + # deserialize response and update self + transform(response, "DescribeMonitoringScheduleResponse", self) + return self + @populate_inputs_decorator @Base.add_validate_call def update( self, - monitoring_schedule_name: str, - datapoints_to_alert: int, - evaluation_period: int, - ) -> Optional["MonitoringAlert"]: + monitoring_schedule_config: MonitoringScheduleConfig, + ) -> Optional["MonitoringSchedule"]: """ - Update a MonitoringAlert resource - - Parameters: - monitoring_schedule_name: The name of a monitoring schedule. + Update a MonitoringSchedule resource Returns: - The MonitoringAlert resource. + The MonitoringSchedule resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20242,14 +21837,12 @@ def update( ResourceNotFound: Resource being access is not found. 
""" - logger.info("Updating monitoring_alert resource.") + logger.info("Updating monitoring_schedule resource.") client = Base.get_sagemaker_client() operation_input_args = { - "MonitoringScheduleName": monitoring_schedule_name, - "MonitoringAlertName": self.monitoring_alert_name, - "DatapointsToAlert": datapoints_to_alert, - "EvaluationPeriod": evaluation_period, + "MonitoringScheduleName": self.monitoring_schedule_name, + "MonitoringScheduleConfig": monitoring_schedule_config, } logger.debug(f"Input request: {operation_input_args}") # serialize the input request @@ -20257,32 +21850,18 @@ def update( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.update_monitoring_alert(**operation_input_args) + response = client.update_monitoring_schedule(**operation_input_args) logger.debug(f"Response: {response}") self.refresh() return self - @classmethod @Base.add_validate_call - def get_all( - cls, - monitoring_schedule_name: str, - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> ResourceIterator["MonitoringAlert"]: + def delete( + self, + ) -> None: """ - Get all MonitoringAlert resources - - Parameters: - monitoring_schedule_name: The name of a monitoring schedule. - next_token: If the result of the previous ListMonitoringAlerts request was truncated, the response includes a NextToken. To retrieve the next set of alerts in the history, use the token in the next request. - max_results: The maximum number of results to display. The default is 100. - session: Boto3 session. - region: Region name. - - Returns: - Iterator for listed MonitoringAlert resources. + Delete a MonitoringSchedule resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20294,61 +21873,26 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
- """ - - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - - operation_input_args = { - "MonitoringScheduleName": monitoring_schedule_name, - } - - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - return ResourceIterator( - client=client, - list_method="list_monitoring_alerts", - summaries_key="MonitoringAlertSummaries", - summary_name="MonitoringAlertSummary", - resource_cls=MonitoringAlert, - list_method_kwargs=operation_input_args, - ) - - @Base.add_validate_call - def list_history( - self, - monitoring_schedule_name: Optional[str] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - next_token: Optional[str] = Unassigned(), - max_results: Optional[int] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - creation_time_after: Optional[datetime.datetime] = Unassigned(), - status_equals: Optional[str] = Unassigned(), - session: Optional[Session] = None, - region: Optional[str] = None, - ) -> Optional[MonitoringAlertHistorySummary]: + ResourceNotFound: Resource being access is not found. """ - Gets a list of past alerts in a model monitoring schedule. - Parameters: - monitoring_schedule_name: The name of a monitoring schedule. - sort_by: The field used to sort results. The default is CreationTime. - sort_order: The sort order, whether Ascending or Descending, of the alert history. The default is Descending. - next_token: If the result of the previous ListMonitoringAlertHistory request was truncated, the response includes a NextToken. To retrieve the next set of alerts in the history, use the token in the next request. - max_results: The maximum number of results to display. The default is 100. - creation_time_before: A filter that returns only alerts created on or before the specified time. 
- creation_time_after: A filter that returns only alerts created on or after the specified time. - status_equals: A filter that retrieves only alerts with a specific status. - session: Boto3 session. - region: Region name. + client = Base.get_sagemaker_client() - Returns: - MonitoringAlertHistorySummary + operation_input_args = { + "MonitoringScheduleName": self.monitoring_schedule_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_monitoring_schedule(**operation_input_args) + + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + + @Base.add_validate_call + def stop(self) -> None: + """ + Stop a MonitoringSchedule resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20363,88 +21907,89 @@ def list_history( ResourceNotFound: Resource being access is not found. """ + client = SageMakerClient().client + operation_input_args = { - "MonitoringScheduleName": monitoring_schedule_name, - "MonitoringAlertName": self.monitoring_alert_name, - "SortBy": sort_by, - "SortOrder": sort_order, - "NextToken": next_token, - "MaxResults": max_results, - "CreationTimeBefore": creation_time_before, - "CreationTimeAfter": creation_time_after, - "StatusEquals": status_equals, + "MonitoringScheduleName": self.monitoring_schedule_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - - logger.debug(f"Calling list_monitoring_alert_history API") - response = client.list_monitoring_alert_history(**operation_input_args) - logger.debug(f"Response: {response}") + client.stop_monitoring_schedule(**operation_input_args) - transformed_response = transform(response, 
"ListMonitoringAlertHistoryResponse") - return MonitoringAlertHistorySummary(**transformed_response) + logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") + @Base.add_validate_call + def wait_for_status( + self, + target_status: Literal["Pending", "Failed", "Scheduled", "Stopped"], + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a MonitoringSchedule resource to reach certain status. -class MonitoringExecution(Base): - """ - Class representing resource MonitoringExecution + Parameters: + target_status: The status to wait for. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. - Attributes: - monitoring_schedule_name: The name of the monitoring schedule. - scheduled_time: The time the monitoring job was scheduled. - creation_time: The time at which the monitoring job was created. - last_modified_time: A timestamp that indicates the last time the monitoring job was modified. - monitoring_execution_status: The status of the monitoring job. - processing_job_arn: The Amazon Resource Name (ARN) of the monitoring job. - endpoint_name: The name of the endpoint used to run the monitoring job. - failure_reason: Contains the reason a monitoring job failed, if it failed. - monitoring_job_definition_name: The name of the monitoring job. - monitoring_type: The type of the monitoring job. + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. + """ + start_time = time.time() - """ + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task( + f"Waiting for MonitoringSchedule to reach [bold]{target_status} status..." 
+ ) + status = Status("Current status:") - monitoring_schedule_name: str - scheduled_time: datetime.datetime - creation_time: datetime.datetime - last_modified_time: datetime.datetime - monitoring_execution_status: str - processing_job_arn: Optional[str] = Unassigned() - endpoint_name: Optional[str] = Unassigned() - failure_reason: Optional[str] = Unassigned() - monitoring_job_definition_name: Optional[str] = Unassigned() - monitoring_type: Optional[str] = Unassigned() + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.monitoring_schedule_status + status.update(f"Current status: [bold]{current_status}") - def get_name(self) -> str: - attributes = vars(self) - resource_name = "monitoring_execution_name" - resource_name_split = resource_name.split("_") - attribute_name_candidates = [] + if target_status == current_status: + logger.info(f"Final Resource Status: [bold]{current_status}") + return - l = len(resource_name_split) - for i in range(0, l): - attribute_name_candidates.append("_".join(resource_name_split[i:l])) + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="MonitoringSchedule", + status=current_status, + reason=self.failure_reason, + ) - for attribute, value in attributes.items(): - if attribute == "name" or attribute in attribute_name_candidates: - return value - logger.error("Name attribute not found for object monitoring_execution") - return None + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="MonitoringSchedule", status=current_status + ) + time.sleep(poll) @classmethod @Base.add_validate_call def get_all( cls, - monitoring_schedule_name: Optional[str] = Unassigned(), endpoint_name: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), - 
scheduled_time_before: Optional[datetime.datetime] = Unassigned(), - scheduled_time_after: Optional[datetime.datetime] = Unassigned(), + name_contains: Optional[str] = Unassigned(), creation_time_before: Optional[datetime.datetime] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), last_modified_time_before: Optional[datetime.datetime] = Unassigned(), @@ -20454,31 +21999,29 @@ def get_all( monitoring_type_equals: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["MonitoringExecution"]: + ) -> ResourceIterator["MonitoringSchedule"]: """ - Get all MonitoringExecution resources + Get all MonitoringSchedule resources Parameters: - monitoring_schedule_name: Name of a specific schedule to fetch jobs for. - endpoint_name: Name of a specific endpoint to fetch jobs for. + endpoint_name: Name of a specific endpoint to fetch schedules for. sort_by: Whether to sort the results by the Status, CreationTime, or ScheduledTime field. The default is CreationTime. sort_order: Whether to sort the results in Ascending or Descending order. The default is Descending. next_token: The token returned if the response is truncated. To retrieve the next set of job executions, use it in the next request. max_results: The maximum number of jobs to return in the response. The default value is 10. - scheduled_time_before: Filter for jobs scheduled before a specified time. - scheduled_time_after: Filter for jobs scheduled after a specified time. - creation_time_before: A filter that returns only jobs created before a specified time. - creation_time_after: A filter that returns only jobs created after a specified time. - last_modified_time_before: A filter that returns only jobs modified after a specified time. - last_modified_time_after: A filter that returns only jobs modified before a specified time. - status_equals: A filter that retrieves only jobs with a specific status. 
- monitoring_job_definition_name: Gets a list of the monitoring job runs of the specified monitoring job definitions. - monitoring_type_equals: A filter that returns only the monitoring job runs of the specified monitoring type. + name_contains: Filter for monitoring schedules whose name contains a specified string. + creation_time_before: A filter that returns only monitoring schedules created before a specified time. + creation_time_after: A filter that returns only monitoring schedules created after a specified time. + last_modified_time_before: A filter that returns only monitoring schedules modified before a specified time. + last_modified_time_after: A filter that returns only monitoring schedules modified after a specified time. + status_equals: A filter that returns only monitoring schedules modified before a specified time. + monitoring_job_definition_name: Gets a list of the monitoring schedules for the specified monitoring job definition. + monitoring_type_equals: A filter that returns only the monitoring schedules for the specified monitoring type. session: Boto3 session. region: Region name. Returns: - Iterator for listed MonitoringExecution resources. + Iterator for listed MonitoringSchedule resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -20497,12 +22040,10 @@ def get_all( ) operation_input_args = { - "MonitoringScheduleName": monitoring_schedule_name, "EndpointName": endpoint_name, "SortBy": sort_by, "SortOrder": sort_order, - "ScheduledTimeBefore": scheduled_time_before, - "ScheduledTimeAfter": scheduled_time_after, + "NameContains": name_contains, "CreationTimeBefore": creation_time_before, "CreationTimeAfter": creation_time_after, "LastModifiedTimeBefore": last_modified_time_before, @@ -20518,46 +22059,72 @@ def get_all( return ResourceIterator( client=client, - list_method="list_monitoring_executions", - summaries_key="MonitoringExecutionSummaries", - summary_name="MonitoringExecutionSummary", - resource_cls=MonitoringExecution, + list_method="list_monitoring_schedules", + summaries_key="MonitoringScheduleSummaries", + summary_name="MonitoringScheduleSummary", + resource_cls=MonitoringSchedule, list_method_kwargs=operation_input_args, ) -class MonitoringSchedule(Base): +class NotebookInstance(Base): """ - Class representing resource MonitoringSchedule + Class representing resource NotebookInstance Attributes: - monitoring_schedule_arn: The Amazon Resource Name (ARN) of the monitoring schedule. - monitoring_schedule_name: Name of the monitoring schedule. - monitoring_schedule_status: The status of an monitoring job. - creation_time: The time at which the monitoring job was created. - last_modified_time: The time at which the monitoring job was last modified. - monitoring_schedule_config: The configuration object that specifies the monitoring schedule and defines the monitoring job. - monitoring_type: The type of the monitoring job that this schedule runs. This is one of the following values. DATA_QUALITY - The schedule is for a data quality monitoring job. MODEL_QUALITY - The schedule is for a model quality monitoring job. MODEL_BIAS - The schedule is for a bias monitoring job. MODEL_EXPLAINABILITY - The schedule is for an explainability monitoring job. 
- failure_reason: A string, up to one KB in size, that contains the reason a monitoring job failed, if it failed. - endpoint_name: The name of the endpoint for the monitoring job. - last_monitoring_execution_summary: Describes metadata on the last execution to run, if there was one. + notebook_instance_arn: The Amazon Resource Name (ARN) of the notebook instance. + notebook_instance_name: The name of the SageMaker notebook instance. + notebook_instance_status: The status of the notebook instance. + failure_reason: If status is Failed, the reason it failed. + url: The URL that you use to connect to the Jupyter notebook that is running in your notebook instance. + instance_type: The type of ML compute instance running on the notebook instance. + subnet_id: The ID of the VPC subnet. + security_groups: The IDs of the VPC security groups. + role_arn: The Amazon Resource Name (ARN) of the IAM role associated with the instance. + kms_key_id: The Amazon Web Services KMS key ID SageMaker uses to encrypt data when storing it on the ML storage volume attached to the instance. + network_interface_id: The network interface IDs that SageMaker created at the time of creating the instance. + last_modified_time: A timestamp. Use this parameter to retrieve the time when the notebook instance was last modified. + creation_time: A timestamp. Use this parameter to return the time when the notebook instance was created + notebook_instance_lifecycle_config_name: Returns the name of a notebook instance lifecycle configuration. For information about notebook instance lifestyle configurations, see Step 2.1: (Optional) Customize a Notebook Instance + direct_internet_access: Describes whether SageMaker provides internet access to the notebook instance. If this value is set to Disabled, the notebook instance does not have internet access, and cannot connect to SageMaker training and endpoint services. For more information, see Notebook Instances Are Internet-Enabled by Default. 
+ volume_size_in_gb: The size, in GB, of the ML storage volume attached to the notebook instance. + accelerator_types: This parameter is no longer supported. Elastic Inference (EI) is no longer available. This parameter was used to specify a list of the EI instance types associated with this notebook instance. + default_code_repository: The Git repository associated with the notebook instance as its default code repository. This can be either the name of a Git repository stored as a resource in your account, or the URL of a Git repository in Amazon Web Services CodeCommit or in any other Git repository. When you open a notebook instance, it opens in the directory that contains this repository. For more information, see Associating Git Repositories with SageMaker Notebook Instances. + additional_code_repositories: An array of up to three Git repositories associated with the notebook instance. These can be either the names of Git repositories stored as resources in your account, or the URL of Git repositories in Amazon Web Services CodeCommit or in any other Git repository. These repositories are cloned at the same level as the default repository of your notebook instance. For more information, see Associating Git Repositories with SageMaker Notebook Instances. + root_access: Whether root access is enabled or disabled for users of the notebook instance. Lifecycle configurations need root access to be able to set up a notebook instance. Because of this, lifecycle configurations associated with a notebook instance always run with root access even if you disable root access for users. + platform_identifier: The platform identifier of the notebook instance runtime environment. 
+ instance_metadata_service_configuration: Information on the IMDS configuration of the notebook instance """ - monitoring_schedule_name: str - monitoring_schedule_arn: Optional[str] = Unassigned() - monitoring_schedule_status: Optional[str] = Unassigned() - monitoring_type: Optional[str] = Unassigned() + notebook_instance_name: str + notebook_instance_arn: Optional[str] = Unassigned() + notebook_instance_status: Optional[str] = Unassigned() failure_reason: Optional[str] = Unassigned() - creation_time: Optional[datetime.datetime] = Unassigned() + url: Optional[str] = Unassigned() + instance_type: Optional[str] = Unassigned() + subnet_id: Optional[str] = Unassigned() + security_groups: Optional[List[str]] = Unassigned() + role_arn: Optional[str] = Unassigned() + kms_key_id: Optional[str] = Unassigned() + network_interface_id: Optional[str] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() - monitoring_schedule_config: Optional[MonitoringScheduleConfig] = Unassigned() - endpoint_name: Optional[str] = Unassigned() - last_monitoring_execution_summary: Optional[MonitoringExecutionSummary] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + notebook_instance_lifecycle_config_name: Optional[str] = Unassigned() + direct_internet_access: Optional[str] = Unassigned() + volume_size_in_gb: Optional[int] = Unassigned() + accelerator_types: Optional[List[str]] = Unassigned() + default_code_repository: Optional[str] = Unassigned() + additional_code_repositories: Optional[List[str]] = Unassigned() + root_access: Optional[str] = Unassigned() + platform_identifier: Optional[str] = Unassigned() + instance_metadata_service_configuration: Optional[InstanceMetadataServiceConfiguration] = ( + Unassigned() + ) def get_name(self) -> str: attributes = vars(self) - resource_name = "monitoring_schedule_name" + resource_name = "notebook_instance_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -20568,40 
+22135,22 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object monitoring_schedule") + logger.error("Name attribute not found for object notebook_instance") return None def populate_inputs_decorator(create_func): @functools.wraps(create_func) def wrapper(*args, **kwargs): config_schema_for_resource = { - "monitoring_schedule_config": { - "monitoring_job_definition": { - "monitoring_output_config": {"kms_key_id": {"type": "string"}}, - "monitoring_resources": { - "cluster_config": {"volume_kms_key_id": {"type": "string"}} - }, - "role_arn": {"type": "string"}, - "baseline_config": { - "constraints_resource": {"s3_uri": {"type": "string"}}, - "statistics_resource": {"s3_uri": {"type": "string"}}, - }, - "network_config": { - "vpc_config": { - "security_group_ids": { - "type": "array", - "items": {"type": "string"}, - }, - "subnets": {"type": "array", "items": {"type": "string"}}, - } - }, - } - } + "subnet_id": {"type": "string"}, + "security_groups": {"type": "array", "items": {"type": "string"}}, + "role_arn": {"type": "string"}, + "kms_key_id": {"type": "string"}, } return create_func( *args, **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "MonitoringSchedule", **kwargs + config_schema_for_resource, "NotebookInstance", **kwargs ), ) @@ -20612,24 +22161,52 @@ def wrapper(*args, **kwargs): @Base.add_validate_call def create( cls, - monitoring_schedule_name: str, - monitoring_schedule_config: MonitoringScheduleConfig, + notebook_instance_name: str, + instance_type: str, + role_arn: str, + subnet_id: Optional[str] = Unassigned(), + security_group_ids: Optional[List[str]] = Unassigned(), + kms_key_id: Optional[str] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), + lifecycle_config_name: Optional[str] = Unassigned(), + direct_internet_access: Optional[str] = Unassigned(), + 
volume_size_in_gb: Optional[int] = Unassigned(), + accelerator_types: Optional[List[str]] = Unassigned(), + default_code_repository: Optional[str] = Unassigned(), + additional_code_repositories: Optional[List[str]] = Unassigned(), + root_access: Optional[str] = Unassigned(), + platform_identifier: Optional[str] = Unassigned(), + instance_metadata_service_configuration: Optional[ + InstanceMetadataServiceConfiguration + ] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["MonitoringSchedule"]: + ) -> Optional["NotebookInstance"]: """ - Create a MonitoringSchedule resource + Create a NotebookInstance resource Parameters: - monitoring_schedule_name: The name of the monitoring schedule. The name must be unique within an Amazon Web Services Region within an Amazon Web Services account. - monitoring_schedule_config: The configuration object that specifies the monitoring schedule and defines the monitoring job. - tags: (Optional) An array of key-value pairs. For more information, see Using Cost Allocation Tags in the Amazon Web Services Billing and Cost Management User Guide. + notebook_instance_name: The name of the new notebook instance. + instance_type: The type of ML compute instance to launch for the notebook instance. + role_arn: When you send any requests to Amazon Web Services resources from the notebook instance, SageMaker assumes this role to perform tasks on your behalf. You must grant this role necessary permissions so SageMaker can perform these tasks. The policy must allow the SageMaker service principal (sagemaker.amazonaws.com) permissions to assume this role. For more information, see SageMaker Roles. To be able to pass this role to SageMaker, the caller of this API must have the iam:PassRole permission. + subnet_id: The ID of the subnet in a VPC to which you would like to have a connectivity from your ML compute instance. + security_group_ids: The VPC security group IDs, in the form sg-xxxxxxxx. 
The security groups must be for the same VPC as specified in the subnet. + kms_key_id: The Amazon Resource Name (ARN) of a Amazon Web Services Key Management Service key that SageMaker uses to encrypt data on the storage volume attached to your notebook instance. The KMS key you provide must be enabled. For information, see Enabling and Disabling Keys in the Amazon Web Services Key Management Service Developer Guide. + tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. + lifecycle_config_name: The name of a lifecycle configuration to associate with the notebook instance. For information about lifestyle configurations, see Step 2.1: (Optional) Customize a Notebook Instance. + direct_internet_access: Sets whether SageMaker provides internet access to the notebook instance. If you set this to Disabled this notebook instance is able to access resources only in your VPC, and is not be able to connect to SageMaker training and endpoint services unless you configure a NAT Gateway in your VPC. For more information, see Notebook Instances Are Internet-Enabled by Default. You can set the value of this parameter to Disabled only if you set a value for the SubnetId parameter. + volume_size_in_gb: The size, in GB, of the ML storage volume to attach to the notebook instance. The default value is 5 GB. + accelerator_types: This parameter is no longer supported. Elastic Inference (EI) is no longer available. This parameter was used to specify a list of EI instance types to associate with this notebook instance. + default_code_repository: A Git repository to associate with the notebook instance as its default code repository. 
This can be either the name of a Git repository stored as a resource in your account, or the URL of a Git repository in Amazon Web Services CodeCommit or in any other Git repository. When you open a notebook instance, it opens in the directory that contains this repository. For more information, see Associating Git Repositories with SageMaker Notebook Instances. + additional_code_repositories: An array of up to three Git repositories to associate with the notebook instance. These can be either the names of Git repositories stored as resources in your account, or the URL of Git repositories in Amazon Web Services CodeCommit or in any other Git repository. These repositories are cloned at the same level as the default repository of your notebook instance. For more information, see Associating Git Repositories with SageMaker Notebook Instances. + root_access: Whether root access is enabled or disabled for users of the notebook instance. The default value is Enabled. Lifecycle configurations need root access to be able to set up a notebook instance. Because of this, lifecycle configurations associated with a notebook instance always run with root access even if you disable root access for users. + platform_identifier: The platform identifier of the notebook instance runtime environment. + instance_metadata_service_configuration: Information on the IMDS configuration of the notebook instance session: Boto3 session. region: Region name. Returns: - The MonitoringSchedule resource. + The NotebookInstance resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20641,26 +22218,38 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating monitoring_schedule resource.") + logger.info("Creating notebook_instance resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "MonitoringScheduleName": monitoring_schedule_name, - "MonitoringScheduleConfig": monitoring_schedule_config, + "NotebookInstanceName": notebook_instance_name, + "InstanceType": instance_type, + "SubnetId": subnet_id, + "SecurityGroupIds": security_group_ids, + "RoleArn": role_arn, + "KmsKeyId": kms_key_id, "Tags": tags, + "LifecycleConfigName": lifecycle_config_name, + "DirectInternetAccess": direct_internet_access, + "VolumeSizeInGB": volume_size_in_gb, + "AcceleratorTypes": accelerator_types, + "DefaultCodeRepository": default_code_repository, + "AdditionalCodeRepositories": additional_code_repositories, + "RootAccess": root_access, + "PlatformIdentifier": platform_identifier, + "InstanceMetadataServiceConfiguration": instance_metadata_service_configuration, } operation_input_args = Base.populate_chained_attributes( - resource_name="MonitoringSchedule", operation_input_args=operation_input_args + resource_name="NotebookInstance", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -20669,31 +22258,31 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_monitoring_schedule(**operation_input_args) + response = client.create_notebook_instance(**operation_input_args) logger.debug(f"Response: {response}") return cls.get( - monitoring_schedule_name=monitoring_schedule_name, session=session, region=region + notebook_instance_name=notebook_instance_name, 
session=session, region=region ) @classmethod @Base.add_validate_call def get( cls, - monitoring_schedule_name: str, + notebook_instance_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["MonitoringSchedule"]: + ) -> Optional["NotebookInstance"]: """ - Get a MonitoringSchedule resource + Get a NotebookInstance resource Parameters: - monitoring_schedule_name: Name of a previously created monitoring schedule. + notebook_instance_name: The name of the notebook instance that you want information about. session: Boto3 session. region: Region name. Returns: - The MonitoringSchedule resource. + The NotebookInstance resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20705,11 +22294,10 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "MonitoringScheduleName": monitoring_schedule_name, + "NotebookInstanceName": notebook_instance_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -20718,24 +22306,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_monitoring_schedule(**operation_input_args) + response = client.describe_notebook_instance(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeMonitoringScheduleResponse") - monitoring_schedule = cls(**transformed_response) - return monitoring_schedule + transformed_response = transform(response, "DescribeNotebookInstanceOutput") + notebook_instance = cls(**transformed_response) + return notebook_instance @Base.add_validate_call def refresh( self, - ) -> Optional["MonitoringSchedule"]: + ) -> Optional["NotebookInstance"]: """ - Refresh a MonitoringSchedule resource + Refresh a 
NotebookInstance resource Returns: - The MonitoringSchedule resource. + The NotebookInstance resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20747,34 +22335,54 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "MonitoringScheduleName": self.monitoring_schedule_name, + "NotebookInstanceName": self.notebook_instance_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_monitoring_schedule(**operation_input_args) + response = client.describe_notebook_instance(**operation_input_args) # deserialize response and update self - transform(response, "DescribeMonitoringScheduleResponse", self) + transform(response, "DescribeNotebookInstanceOutput", self) return self @populate_inputs_decorator @Base.add_validate_call def update( self, - monitoring_schedule_config: MonitoringScheduleConfig, - ) -> Optional["MonitoringSchedule"]: + instance_type: Optional[str] = Unassigned(), + role_arn: Optional[str] = Unassigned(), + lifecycle_config_name: Optional[str] = Unassigned(), + disassociate_lifecycle_config: Optional[bool] = Unassigned(), + volume_size_in_gb: Optional[int] = Unassigned(), + default_code_repository: Optional[str] = Unassigned(), + additional_code_repositories: Optional[List[str]] = Unassigned(), + accelerator_types: Optional[List[str]] = Unassigned(), + disassociate_accelerator_types: Optional[bool] = Unassigned(), + disassociate_default_code_repository: Optional[bool] = Unassigned(), + disassociate_additional_code_repositories: Optional[bool] = Unassigned(), + root_access: Optional[str] = Unassigned(), + instance_metadata_service_configuration: Optional[ + InstanceMetadataServiceConfiguration + 
] = Unassigned(), + ) -> Optional["NotebookInstance"]: """ - Update a MonitoringSchedule resource + Update a NotebookInstance resource + + Parameters: + lifecycle_config_name: The name of a lifecycle configuration to associate with the notebook instance. For information about lifestyle configurations, see Step 2.1: (Optional) Customize a Notebook Instance. + disassociate_lifecycle_config: Set to true to remove the notebook instance lifecycle configuration currently associated with the notebook instance. This operation is idempotent. If you specify a lifecycle configuration that is not associated with the notebook instance when you call this method, it does not throw an error. + disassociate_accelerator_types: This parameter is no longer supported. Elastic Inference (EI) is no longer available. This parameter was used to specify a list of the EI instance types to remove from this notebook instance. + disassociate_default_code_repository: The name or URL of the default Git repository to remove from this notebook instance. This operation is idempotent. If you specify a Git repository that is not associated with the notebook instance when you call this method, it does not throw an error. + disassociate_additional_code_repositories: A list of names or URLs of the default Git repositories to remove from this notebook instance. This operation is idempotent. If you specify a Git repository that is not associated with the notebook instance when you call this method, it does not throw an error. Returns: - The MonitoringSchedule resource. + The NotebookInstance resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20787,15 +22395,26 @@ def update( error_code = e.response['Error']['Code'] ``` ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. - ResourceNotFound: Resource being access is not found. 
""" - logger.info("Updating monitoring_schedule resource.") + logger.info("Updating notebook_instance resource.") client = Base.get_sagemaker_client() operation_input_args = { - "MonitoringScheduleName": self.monitoring_schedule_name, - "MonitoringScheduleConfig": monitoring_schedule_config, + "NotebookInstanceName": self.notebook_instance_name, + "InstanceType": instance_type, + "RoleArn": role_arn, + "LifecycleConfigName": lifecycle_config_name, + "DisassociateLifecycleConfig": disassociate_lifecycle_config, + "VolumeSizeInGB": volume_size_in_gb, + "DefaultCodeRepository": default_code_repository, + "AdditionalCodeRepositories": additional_code_repositories, + "AcceleratorTypes": accelerator_types, + "DisassociateAcceleratorTypes": disassociate_accelerator_types, + "DisassociateDefaultCodeRepository": disassociate_default_code_repository, + "DisassociateAdditionalCodeRepositories": disassociate_additional_code_repositories, + "RootAccess": root_access, + "InstanceMetadataServiceConfiguration": instance_metadata_service_configuration, } logger.debug(f"Input request: {operation_input_args}") # serialize the input request @@ -20803,7 +22422,7 @@ def update( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.update_monitoring_schedule(**operation_input_args) + response = client.update_notebook_instance(**operation_input_args) logger.debug(f"Response: {response}") self.refresh() @@ -20814,7 +22433,7 @@ def delete( self, ) -> None: """ - Delete a MonitoringSchedule resource + Delete a NotebookInstance resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20826,26 +22445,25 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. 
""" client = Base.get_sagemaker_client() operation_input_args = { - "MonitoringScheduleName": self.monitoring_schedule_name, + "NotebookInstanceName": self.notebook_instance_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_monitoring_schedule(**operation_input_args) + client.delete_notebook_instance(**operation_input_args) logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call def stop(self) -> None: """ - Stop a MonitoringSchedule resource + Stop a NotebookInstance resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20857,40 +22475,108 @@ def stop(self) -> None: error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceNotFound: Resource being access is not found. """ client = SageMakerClient().client operation_input_args = { - "MonitoringScheduleName": self.monitoring_schedule_name, + "NotebookInstanceName": self.notebook_instance_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.stop_monitoring_schedule(**operation_input_args) + client.stop_notebook_instance(**operation_input_args) logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call - def wait_for_status( + def wait_for_status( + self, + target_status: Literal[ + "Pending", "InService", "Stopping", "Stopped", "Failed", "Deleting", "Updating" + ], + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a NotebookInstance resource to reach certain status. + + Parameters: + target_status: The status to wait for. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. 
+ + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. + """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task(f"Waiting for NotebookInstance to reach [bold]{target_status} status...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.notebook_instance_status + status.update(f"Current status: [bold]{current_status}") + + if target_status == current_status: + logger.info(f"Final Resource Status: [bold]{current_status}") + return + + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="NotebookInstance", + status=current_status, + reason=self.failure_reason, + ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="NotebookInstance", status=current_status + ) + time.sleep(poll) + + @Base.add_validate_call + def wait_for_delete( self, - target_status: Literal["Pending", "Failed", "Scheduled", "Stopped"], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a MonitoringSchedule resource to reach certain status. + Wait for a NotebookInstance resource to be deleted. Parameters: - target_status: The status to wait for. poll: The number of seconds to wait between each poll. timeout: The maximum number of seconds to wait before timing out. Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
+ The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - FailedStatusError: If the resource reaches a failed state. + DeleteFailedStatusError: If the resource reaches a failed state. WaiterError: Raised when an error occurs while waiting. """ start_time = time.time() @@ -20900,9 +22586,7 @@ def wait_for_status( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task( - f"Waiting for MonitoringSchedule to reach [bold]{target_status} status..." - ) + progress.add_task("Waiting for NotebookInstance to be deleted...") status = Status("Current status:") with Live( @@ -20910,36 +22594,31 @@ def wait_for_status( Group(progress, status), title="Wait Log Panel", border_style=Style(color=Color.BLUE.value), - ), - transient=True, + ) ): while True: - self.refresh() - current_status = self.monitoring_schedule_status - status.update(f"Current status: [bold]{current_status}") - - if target_status == current_status: - logger.info(f"Final Resource Status: [bold]{current_status}") - return + try: + self.refresh() + current_status = self.notebook_instance_status + status.update(f"Current status: [bold]{current_status}") - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="MonitoringSchedule", - status=current_status, - reason=self.failure_reason, - ) + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="NotebookInstance", status=current_status + ) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="MonitoringSchedule", 
status=current_status - ) + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. It may have been deleted.") + return + raise e time.sleep(poll) @classmethod @Base.add_validate_call def get_all( cls, - endpoint_name: Optional[str] = Unassigned(), sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), name_contains: Optional[str] = Unassigned(), @@ -20948,33 +22627,34 @@ def get_all( last_modified_time_before: Optional[datetime.datetime] = Unassigned(), last_modified_time_after: Optional[datetime.datetime] = Unassigned(), status_equals: Optional[str] = Unassigned(), - monitoring_job_definition_name: Optional[str] = Unassigned(), - monitoring_type_equals: Optional[str] = Unassigned(), + notebook_instance_lifecycle_config_name_contains: Optional[str] = Unassigned(), + default_code_repository_contains: Optional[str] = Unassigned(), + additional_code_repository_equals: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["MonitoringSchedule"]: + ) -> ResourceIterator["NotebookInstance"]: """ - Get all MonitoringSchedule resources + Get all NotebookInstance resources Parameters: - endpoint_name: Name of a specific endpoint to fetch schedules for. - sort_by: Whether to sort the results by the Status, CreationTime, or ScheduledTime field. The default is CreationTime. - sort_order: Whether to sort the results in Ascending or Descending order. The default is Descending. - next_token: The token returned if the response is truncated. To retrieve the next set of job executions, use it in the next request. - max_results: The maximum number of jobs to return in the response. The default value is 10. - name_contains: Filter for monitoring schedules whose name contains a specified string. - creation_time_before: A filter that returns only monitoring schedules created before a specified time. 
- creation_time_after: A filter that returns only monitoring schedules created after a specified time.
- last_modified_time_before: A filter that returns only monitoring schedules modified before a specified time.
- last_modified_time_after: A filter that returns only monitoring schedules modified after a specified time.
- status_equals: A filter that returns only monitoring schedules modified before a specified time.
- monitoring_job_definition_name: Gets a list of the monitoring schedules for the specified monitoring job definition.
- monitoring_type_equals: A filter that returns only the monitoring schedules for the specified monitoring type.
+ next_token: If the previous call to the ListNotebookInstances is truncated, the response includes a NextToken. You can use this token in your subsequent ListNotebookInstances request to fetch the next set of notebook instances. You might specify a filter or a sort order in your request. When response is truncated, you must use the same values for the filter and sort order in the next request.
+ max_results: The maximum number of notebook instances to return.
+ sort_by: The field to sort results by. The default is Name.
+ sort_order: The sort order for results.
+ name_contains: A string in the notebook instances' name. This filter returns only notebook instances whose name contains the specified string.
+ creation_time_before: A filter that returns only notebook instances that were created before the specified time (timestamp).
+ creation_time_after: A filter that returns only notebook instances that were created after the specified time (timestamp).
+ last_modified_time_before: A filter that returns only notebook instances that were modified before the specified time (timestamp).
+ last_modified_time_after: A filter that returns only notebook instances that were modified after the specified time (timestamp).
+ status_equals: A filter that returns only notebook instances with the specified status. 
+ notebook_instance_lifecycle_config_name_contains: A string in the name of a notebook instances lifecycle configuration associated with this notebook instance. This filter returns only notebook instances associated with a lifecycle configuration with a name that contains the specified string. + default_code_repository_contains: A string in the name or URL of a Git repository associated with this notebook instance. This filter returns only notebook instances associated with a git repository with a name that contains the specified string. + additional_code_repository_equals: A filter that returns only notebook instances with associated with the specified git repository. session: Boto3 session. region: Region name. Returns: - Iterator for listed MonitoringSchedule resources. + Iterator for listed NotebookInstance resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -20993,7 +22673,6 @@ def get_all( ) operation_input_args = { - "EndpointName": endpoint_name, "SortBy": sort_by, "SortOrder": sort_order, "NameContains": name_contains, @@ -21002,8 +22681,9 @@ def get_all( "LastModifiedTimeBefore": last_modified_time_before, "LastModifiedTimeAfter": last_modified_time_after, "StatusEquals": status_equals, - "MonitoringJobDefinitionName": monitoring_job_definition_name, - "MonitoringTypeEquals": monitoring_type_equals, + "NotebookInstanceLifecycleConfigNameContains": notebook_instance_lifecycle_config_name_contains, + "DefaultCodeRepositoryContains": default_code_repository_contains, + "AdditionalCodeRepositoryEquals": additional_code_repository_equals, } # serialize the input request @@ -21012,72 +22692,38 @@ def get_all( return ResourceIterator( client=client, - list_method="list_monitoring_schedules", - summaries_key="MonitoringScheduleSummaries", - summary_name="MonitoringScheduleSummary", - resource_cls=MonitoringSchedule, + list_method="list_notebook_instances", + summaries_key="NotebookInstances", + 
summary_name="NotebookInstanceSummary", + resource_cls=NotebookInstance, list_method_kwargs=operation_input_args, ) -class NotebookInstance(Base): +class NotebookInstanceLifecycleConfig(Base): """ - Class representing resource NotebookInstance + Class representing resource NotebookInstanceLifecycleConfig Attributes: - notebook_instance_arn: The Amazon Resource Name (ARN) of the notebook instance. - notebook_instance_name: The name of the SageMaker notebook instance. - notebook_instance_status: The status of the notebook instance. - failure_reason: If status is Failed, the reason it failed. - url: The URL that you use to connect to the Jupyter notebook that is running in your notebook instance. - instance_type: The type of ML compute instance running on the notebook instance. - subnet_id: The ID of the VPC subnet. - security_groups: The IDs of the VPC security groups. - role_arn: The Amazon Resource Name (ARN) of the IAM role associated with the instance. - kms_key_id: The Amazon Web Services KMS key ID SageMaker uses to encrypt data when storing it on the ML storage volume attached to the instance. - network_interface_id: The network interface IDs that SageMaker created at the time of creating the instance. - last_modified_time: A timestamp. Use this parameter to retrieve the time when the notebook instance was last modified. - creation_time: A timestamp. Use this parameter to return the time when the notebook instance was created - notebook_instance_lifecycle_config_name: Returns the name of a notebook instance lifecycle configuration. For information about notebook instance lifestyle configurations, see Step 2.1: (Optional) Customize a Notebook Instance - direct_internet_access: Describes whether SageMaker provides internet access to the notebook instance. If this value is set to Disabled, the notebook instance does not have internet access, and cannot connect to SageMaker training and endpoint services. 
For more information, see Notebook Instances Are Internet-Enabled by Default. - volume_size_in_gb: The size, in GB, of the ML storage volume attached to the notebook instance. - accelerator_types: This parameter is no longer supported. Elastic Inference (EI) is no longer available. This parameter was used to specify a list of the EI instance types associated with this notebook instance. - default_code_repository: The Git repository associated with the notebook instance as its default code repository. This can be either the name of a Git repository stored as a resource in your account, or the URL of a Git repository in Amazon Web Services CodeCommit or in any other Git repository. When you open a notebook instance, it opens in the directory that contains this repository. For more information, see Associating Git Repositories with SageMaker Notebook Instances. - additional_code_repositories: An array of up to three Git repositories associated with the notebook instance. These can be either the names of Git repositories stored as resources in your account, or the URL of Git repositories in Amazon Web Services CodeCommit or in any other Git repository. These repositories are cloned at the same level as the default repository of your notebook instance. For more information, see Associating Git Repositories with SageMaker Notebook Instances. - root_access: Whether root access is enabled or disabled for users of the notebook instance. Lifecycle configurations need root access to be able to set up a notebook instance. Because of this, lifecycle configurations associated with a notebook instance always run with root access even if you disable root access for users. - platform_identifier: The platform identifier of the notebook instance runtime environment. - instance_metadata_service_configuration: Information on the IMDS configuration of the notebook instance + notebook_instance_lifecycle_config_arn: The Amazon Resource Name (ARN) of the lifecycle configuration. 
+ notebook_instance_lifecycle_config_name: The name of the lifecycle configuration. + on_create: The shell script that runs only once, when you create a notebook instance. + on_start: The shell script that runs every time you start a notebook instance, including when you create the notebook instance. + last_modified_time: A timestamp that tells when the lifecycle configuration was last modified. + creation_time: A timestamp that tells when the lifecycle configuration was created. """ - notebook_instance_name: str - notebook_instance_arn: Optional[str] = Unassigned() - notebook_instance_status: Optional[str] = Unassigned() - failure_reason: Optional[str] = Unassigned() - url: Optional[str] = Unassigned() - instance_type: Optional[str] = Unassigned() - subnet_id: Optional[str] = Unassigned() - security_groups: Optional[List[str]] = Unassigned() - role_arn: Optional[str] = Unassigned() - kms_key_id: Optional[str] = Unassigned() - network_interface_id: Optional[str] = Unassigned() + notebook_instance_lifecycle_config_name: str + notebook_instance_lifecycle_config_arn: Optional[str] = Unassigned() + on_create: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned() + on_start: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned() last_modified_time: Optional[datetime.datetime] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - notebook_instance_lifecycle_config_name: Optional[str] = Unassigned() - direct_internet_access: Optional[str] = Unassigned() - volume_size_in_gb: Optional[int] = Unassigned() - accelerator_types: Optional[List[str]] = Unassigned() - default_code_repository: Optional[str] = Unassigned() - additional_code_repositories: Optional[List[str]] = Unassigned() - root_access: Optional[str] = Unassigned() - platform_identifier: Optional[str] = Unassigned() - instance_metadata_service_configuration: Optional[InstanceMetadataServiceConfiguration] = ( - Unassigned() - ) def get_name(self) -> str: attributes = vars(self) - 
resource_name = "notebook_instance_name" + resource_name = "notebook_instance_lifecycle_config_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -21088,78 +22734,31 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object notebook_instance") + logger.error("Name attribute not found for object notebook_instance_lifecycle_config") return None - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = { - "subnet_id": {"type": "string"}, - "security_groups": {"type": "array", "items": {"type": "string"}}, - "role_arn": {"type": "string"}, - "kms_key_id": {"type": "string"}, - } - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "NotebookInstance", **kwargs - ), - ) - - return wrapper - @classmethod - @populate_inputs_decorator @Base.add_validate_call def create( cls, - notebook_instance_name: str, - instance_type: str, - role_arn: str, - subnet_id: Optional[str] = Unassigned(), - security_group_ids: Optional[List[str]] = Unassigned(), - kms_key_id: Optional[str] = Unassigned(), - tags: Optional[List[Tag]] = Unassigned(), - lifecycle_config_name: Optional[str] = Unassigned(), - direct_internet_access: Optional[str] = Unassigned(), - volume_size_in_gb: Optional[int] = Unassigned(), - accelerator_types: Optional[List[str]] = Unassigned(), - default_code_repository: Optional[str] = Unassigned(), - additional_code_repositories: Optional[List[str]] = Unassigned(), - root_access: Optional[str] = Unassigned(), - platform_identifier: Optional[str] = Unassigned(), - instance_metadata_service_configuration: Optional[ - InstanceMetadataServiceConfiguration - ] = Unassigned(), + notebook_instance_lifecycle_config_name: str, + on_create: 
Optional[List[NotebookInstanceLifecycleHook]] = Unassigned(), + on_start: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["NotebookInstance"]: + ) -> Optional["NotebookInstanceLifecycleConfig"]: """ - Create a NotebookInstance resource + Create a NotebookInstanceLifecycleConfig resource Parameters: - notebook_instance_name: The name of the new notebook instance. - instance_type: The type of ML compute instance to launch for the notebook instance. - role_arn: When you send any requests to Amazon Web Services resources from the notebook instance, SageMaker assumes this role to perform tasks on your behalf. You must grant this role necessary permissions so SageMaker can perform these tasks. The policy must allow the SageMaker service principal (sagemaker.amazonaws.com) permissions to assume this role. For more information, see SageMaker Roles. To be able to pass this role to SageMaker, the caller of this API must have the iam:PassRole permission. - subnet_id: The ID of the subnet in a VPC to which you would like to have a connectivity from your ML compute instance. - security_group_ids: The VPC security group IDs, in the form sg-xxxxxxxx. The security groups must be for the same VPC as specified in the subnet. - kms_key_id: The Amazon Resource Name (ARN) of a Amazon Web Services Key Management Service key that SageMaker uses to encrypt data on the storage volume attached to your notebook instance. The KMS key you provide must be enabled. For information, see Enabling and Disabling Keys in the Amazon Web Services Key Management Service Developer Guide. - tags: An array of key-value pairs. You can use tags to categorize your Amazon Web Services resources in different ways, for example, by purpose, owner, or environment. For more information, see Tagging Amazon Web Services Resources. 
- lifecycle_config_name: The name of a lifecycle configuration to associate with the notebook instance. For information about lifestyle configurations, see Step 2.1: (Optional) Customize a Notebook Instance. - direct_internet_access: Sets whether SageMaker provides internet access to the notebook instance. If you set this to Disabled this notebook instance is able to access resources only in your VPC, and is not be able to connect to SageMaker training and endpoint services unless you configure a NAT Gateway in your VPC. For more information, see Notebook Instances Are Internet-Enabled by Default. You can set the value of this parameter to Disabled only if you set a value for the SubnetId parameter. - volume_size_in_gb: The size, in GB, of the ML storage volume to attach to the notebook instance. The default value is 5 GB. - accelerator_types: This parameter is no longer supported. Elastic Inference (EI) is no longer available. This parameter was used to specify a list of EI instance types to associate with this notebook instance. - default_code_repository: A Git repository to associate with the notebook instance as its default code repository. This can be either the name of a Git repository stored as a resource in your account, or the URL of a Git repository in Amazon Web Services CodeCommit or in any other Git repository. When you open a notebook instance, it opens in the directory that contains this repository. For more information, see Associating Git Repositories with SageMaker Notebook Instances. - additional_code_repositories: An array of up to three Git repositories to associate with the notebook instance. These can be either the names of Git repositories stored as resources in your account, or the URL of Git repositories in Amazon Web Services CodeCommit or in any other Git repository. These repositories are cloned at the same level as the default repository of your notebook instance. 
For more information, see Associating Git Repositories with SageMaker Notebook Instances. - root_access: Whether root access is enabled or disabled for users of the notebook instance. The default value is Enabled. Lifecycle configurations need root access to be able to set up a notebook instance. Because of this, lifecycle configurations associated with a notebook instance always run with root access even if you disable root access for users. - platform_identifier: The platform identifier of the notebook instance runtime environment. - instance_metadata_service_configuration: Information on the IMDS configuration of the notebook instance + notebook_instance_lifecycle_config_name: The name of the lifecycle configuration. + on_create: A shell script that runs only once, when you create a notebook instance. The shell script must be a base64-encoded string. + on_start: A shell script that runs every time you start a notebook instance, including when you create the notebook instance. The shell script must be a base64-encoded string. session: Boto3 session. region: Region name. Returns: - The NotebookInstance resource. + The NotebookInstanceLifecycleConfig resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -21177,32 +22776,20 @@ def create( S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating notebook_instance resource.") + logger.info("Creating notebook_instance_lifecycle_config resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "NotebookInstanceName": notebook_instance_name, - "InstanceType": instance_type, - "SubnetId": subnet_id, - "SecurityGroupIds": security_group_ids, - "RoleArn": role_arn, - "KmsKeyId": kms_key_id, - "Tags": tags, - "LifecycleConfigName": lifecycle_config_name, - "DirectInternetAccess": direct_internet_access, - "VolumeSizeInGB": volume_size_in_gb, - "AcceleratorTypes": accelerator_types, - "DefaultCodeRepository": default_code_repository, - "AdditionalCodeRepositories": additional_code_repositories, - "RootAccess": root_access, - "PlatformIdentifier": platform_identifier, - "InstanceMetadataServiceConfiguration": instance_metadata_service_configuration, + "NotebookInstanceLifecycleConfigName": notebook_instance_lifecycle_config_name, + "OnCreate": on_create, + "OnStart": on_start, } operation_input_args = Base.populate_chained_attributes( - resource_name="NotebookInstance", operation_input_args=operation_input_args + resource_name="NotebookInstanceLifecycleConfig", + operation_input_args=operation_input_args, ) logger.debug(f"Input request: {operation_input_args}") @@ -21211,31 +22798,33 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_notebook_instance(**operation_input_args) + response = client.create_notebook_instance_lifecycle_config(**operation_input_args) logger.debug(f"Response: {response}") return cls.get( - notebook_instance_name=notebook_instance_name, session=session, region=region + notebook_instance_lifecycle_config_name=notebook_instance_lifecycle_config_name, + session=session, + region=region, ) @classmethod 
@Base.add_validate_call def get( cls, - notebook_instance_name: str, + notebook_instance_lifecycle_config_name: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["NotebookInstance"]: + ) -> Optional["NotebookInstanceLifecycleConfig"]: """ - Get a NotebookInstance resource + Get a NotebookInstanceLifecycleConfig resource Parameters: - notebook_instance_name: The name of the notebook instance that you want information about. + notebook_instance_lifecycle_config_name: The name of the lifecycle configuration to describe. session: Boto3 session. region: Region name. Returns: - The NotebookInstance resource. + The NotebookInstanceLifecycleConfig resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -21250,7 +22839,7 @@ def get( """ operation_input_args = { - "NotebookInstanceName": notebook_instance_name, + "NotebookInstanceLifecycleConfigName": notebook_instance_lifecycle_config_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -21259,24 +22848,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_notebook_instance(**operation_input_args) + response = client.describe_notebook_instance_lifecycle_config(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeNotebookInstanceOutput") - notebook_instance = cls(**transformed_response) - return notebook_instance + transformed_response = transform(response, "DescribeNotebookInstanceLifecycleConfigOutput") + notebook_instance_lifecycle_config = cls(**transformed_response) + return notebook_instance_lifecycle_config @Base.add_validate_call def refresh( self, - ) -> Optional["NotebookInstance"]: + ) -> Optional["NotebookInstanceLifecycleConfig"]: """ - Refresh a NotebookInstance resource + Refresh a NotebookInstanceLifecycleConfig 
resource Returns: - The NotebookInstance resource. + The NotebookInstanceLifecycleConfig resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -21291,51 +22880,30 @@ def refresh( """ operation_input_args = { - "NotebookInstanceName": self.notebook_instance_name, + "NotebookInstanceLifecycleConfigName": self.notebook_instance_lifecycle_config_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_notebook_instance(**operation_input_args) + response = client.describe_notebook_instance_lifecycle_config(**operation_input_args) # deserialize response and update self - transform(response, "DescribeNotebookInstanceOutput", self) + transform(response, "DescribeNotebookInstanceLifecycleConfigOutput", self) return self - @populate_inputs_decorator @Base.add_validate_call def update( self, - instance_type: Optional[str] = Unassigned(), - role_arn: Optional[str] = Unassigned(), - lifecycle_config_name: Optional[str] = Unassigned(), - disassociate_lifecycle_config: Optional[bool] = Unassigned(), - volume_size_in_gb: Optional[int] = Unassigned(), - default_code_repository: Optional[str] = Unassigned(), - additional_code_repositories: Optional[List[str]] = Unassigned(), - accelerator_types: Optional[List[str]] = Unassigned(), - disassociate_accelerator_types: Optional[bool] = Unassigned(), - disassociate_default_code_repository: Optional[bool] = Unassigned(), - disassociate_additional_code_repositories: Optional[bool] = Unassigned(), - root_access: Optional[str] = Unassigned(), - instance_metadata_service_configuration: Optional[ - InstanceMetadataServiceConfiguration - ] = Unassigned(), - ) -> Optional["NotebookInstance"]: + on_create: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned(), + on_start: Optional[List[NotebookInstanceLifecycleHook]] = 
Unassigned(), + ) -> Optional["NotebookInstanceLifecycleConfig"]: """ - Update a NotebookInstance resource - - Parameters: - lifecycle_config_name: The name of a lifecycle configuration to associate with the notebook instance. For information about lifestyle configurations, see Step 2.1: (Optional) Customize a Notebook Instance. - disassociate_lifecycle_config: Set to true to remove the notebook instance lifecycle configuration currently associated with the notebook instance. This operation is idempotent. If you specify a lifecycle configuration that is not associated with the notebook instance when you call this method, it does not throw an error. - disassociate_accelerator_types: This parameter is no longer supported. Elastic Inference (EI) is no longer available. This parameter was used to specify a list of the EI instance types to remove from this notebook instance. - disassociate_default_code_repository: The name or URL of the default Git repository to remove from this notebook instance. This operation is idempotent. If you specify a Git repository that is not associated with the notebook instance when you call this method, it does not throw an error. - disassociate_additional_code_repositories: A list of names or URLs of the default Git repositories to remove from this notebook instance. This operation is idempotent. If you specify a Git repository that is not associated with the notebook instance when you call this method, it does not throw an error. + Update a NotebookInstanceLifecycleConfig resource Returns: - The NotebookInstance resource. + The NotebookInstanceLifecycleConfig resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -21350,24 +22918,13 @@ def update( ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
""" - logger.info("Updating notebook_instance resource.") + logger.info("Updating notebook_instance_lifecycle_config resource.") client = Base.get_sagemaker_client() operation_input_args = { - "NotebookInstanceName": self.notebook_instance_name, - "InstanceType": instance_type, - "RoleArn": role_arn, - "LifecycleConfigName": lifecycle_config_name, - "DisassociateLifecycleConfig": disassociate_lifecycle_config, - "VolumeSizeInGB": volume_size_in_gb, - "DefaultCodeRepository": default_code_repository, - "AdditionalCodeRepositories": additional_code_repositories, - "AcceleratorTypes": accelerator_types, - "DisassociateAcceleratorTypes": disassociate_accelerator_types, - "DisassociateDefaultCodeRepository": disassociate_default_code_repository, - "DisassociateAdditionalCodeRepositories": disassociate_additional_code_repositories, - "RootAccess": root_access, - "InstanceMetadataServiceConfiguration": instance_metadata_service_configuration, + "NotebookInstanceLifecycleConfigName": self.notebook_instance_lifecycle_config_name, + "OnCreate": on_create, + "OnStart": on_start, } logger.debug(f"Input request: {operation_input_args}") # serialize the input request @@ -21375,7 +22932,7 @@ def update( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.update_notebook_instance(**operation_input_args) + response = client.update_notebook_instance_lifecycle_config(**operation_input_args) logger.debug(f"Response: {response}") self.refresh() @@ -21386,187 +22943,32 @@ def delete( self, ) -> None: """ - Delete a NotebookInstance resource - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
- The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - """ - - client = Base.get_sagemaker_client() - - operation_input_args = { - "NotebookInstanceName": self.notebook_instance_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client.delete_notebook_instance(**operation_input_args) - - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") - - @Base.add_validate_call - def stop(self) -> None: - """ - Stop a NotebookInstance resource - - Raises: - botocore.exceptions.ClientError: This exception is raised for AWS service related errors. - The error message and error code can be parsed from the exception as follows: - ``` - try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - """ - - client = SageMakerClient().client - - operation_input_args = { - "NotebookInstanceName": self.notebook_instance_name, - } - # serialize the input request - operation_input_args = serialize(operation_input_args) - logger.debug(f"Serialized input request: {operation_input_args}") - - client.stop_notebook_instance(**operation_input_args) - - logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") - - @Base.add_validate_call - def wait_for_status( - self, - target_status: Literal[ - "Pending", "InService", "Stopping", "Stopped", "Failed", "Deleting", "Updating" - ], - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a NotebookInstance resource to reach certain status. - - Parameters: - target_status: The status to wait for. - poll: The number of seconds to wait between each poll. 
- timeout: The maximum number of seconds to wait before timing out. - - Raises: - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - FailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. - """ - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task(f"Waiting for NotebookInstance to reach [bold]{target_status} status...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ), - transient=True, - ): - while True: - self.refresh() - current_status = self.notebook_instance_status - status.update(f"Current status: [bold]{current_status}") - - if target_status == current_status: - logger.info(f"Final Resource Status: [bold]{current_status}") - return - - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="NotebookInstance", - status=current_status, - reason=self.failure_reason, - ) - - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="NotebookInstance", status=current_status - ) - time.sleep(poll) - - @Base.add_validate_call - def wait_for_delete( - self, - poll: int = 5, - timeout: Optional[int] = None, - ) -> None: - """ - Wait for a NotebookInstance resource to be deleted. - - Parameters: - poll: The number of seconds to wait between each poll. - timeout: The maximum number of seconds to wait before timing out. + Delete a NotebookInstanceLifecycleConfig resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
The error message and error code can be parsed from the exception as follows: ``` try: - # AWS service call here - except botocore.exceptions.ClientError as e: - error_message = e.response['Error']['Message'] - error_code = e.response['Error']['Code'] - ``` - TimeoutExceededError: If the resource does not reach a terminal state before the timeout. - DeleteFailedStatusError: If the resource reaches a failed state. - WaiterError: Raised when an error occurs while waiting. - """ - start_time = time.time() - - progress = Progress( - SpinnerColumn("bouncingBar"), - TextColumn("{task.description}"), - TimeElapsedColumn(), - ) - progress.add_task("Waiting for NotebookInstance to be deleted...") - status = Status("Current status:") - - with Live( - Panel( - Group(progress, status), - title="Wait Log Panel", - border_style=Style(color=Color.BLUE.value), - ) - ): - while True: - try: - self.refresh() - current_status = self.notebook_instance_status - status.update(f"Current status: [bold]{current_status}") - - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="NotebookInstance", status=current_status - ) + # AWS service call here except botocore.exceptions.ClientError as e: - error_code = e.response["Error"]["Code"] + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ - if "ResourceNotFound" in error_code or "ValidationException" in error_code: - logger.info("Resource was not found. 
It may have been deleted.") - return - raise e - time.sleep(poll) + client = Base.get_sagemaker_client() + + operation_input_args = { + "NotebookInstanceLifecycleConfigName": self.notebook_instance_lifecycle_config_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client.delete_notebook_instance_lifecycle_config(**operation_input_args) + + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @classmethod @Base.add_validate_call @@ -21579,35 +22981,27 @@ def get_all( creation_time_after: Optional[datetime.datetime] = Unassigned(), last_modified_time_before: Optional[datetime.datetime] = Unassigned(), last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - status_equals: Optional[str] = Unassigned(), - notebook_instance_lifecycle_config_name_contains: Optional[str] = Unassigned(), - default_code_repository_contains: Optional[str] = Unassigned(), - additional_code_repository_equals: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["NotebookInstance"]: + ) -> ResourceIterator["NotebookInstanceLifecycleConfig"]: """ - Get all NotebookInstance resources + Get all NotebookInstanceLifecycleConfig resources Parameters: - next_token: If the previous call to the ListNotebookInstances is truncated, the response includes a NextToken. You can use this token in your subsequent ListNotebookInstances request to fetch the next set of notebook instances. You might specify a filter or a sort order in your request. When response is truncated, you must use the same values for the filer and sort order in the next request. - max_results: The maximum number of notebook instances to return. - sort_by: The field to sort results by. The default is Name. 
+ next_token: If the result of a ListNotebookInstanceLifecycleConfigs request was truncated, the response includes a NextToken. To get the next set of lifecycle configurations, use the token in the next request. + max_results: The maximum number of lifecycle configurations to return in the response. + sort_by: Sorts the list of results. The default is CreationTime. sort_order: The sort order for results. - name_contains: A string in the notebook instances' name. This filter returns only notebook instances whose name contains the specified string. - creation_time_before: A filter that returns only notebook instances that were created before the specified time (timestamp). - creation_time_after: A filter that returns only notebook instances that were created after the specified time (timestamp). - last_modified_time_before: A filter that returns only notebook instances that were modified before the specified time (timestamp). - last_modified_time_after: A filter that returns only notebook instances that were modified after the specified time (timestamp). - status_equals: A filter that returns only notebook instances with the specified status. - notebook_instance_lifecycle_config_name_contains: A string in the name of a notebook instances lifecycle configuration associated with this notebook instance. This filter returns only notebook instances associated with a lifecycle configuration with a name that contains the specified string. - default_code_repository_contains: A string in the name or URL of a Git repository associated with this notebook instance. This filter returns only notebook instances associated with a git repository with a name that contains the specified string. - additional_code_repository_equals: A filter that returns only notebook instances with associated with the specified git repository. + name_contains: A string in the lifecycle configuration name. This filter returns only lifecycle configurations whose name contains the specified string. 
+ creation_time_before: A filter that returns only lifecycle configurations that were created before the specified time (timestamp). + creation_time_after: A filter that returns only lifecycle configurations that were created after the specified time (timestamp). + last_modified_time_before: A filter that returns only lifecycle configurations that were modified before the specified time (timestamp). + last_modified_time_after: A filter that returns only lifecycle configurations that were modified after the specified time (timestamp). session: Boto3 session. region: Region name. Returns: - Iterator for listed NotebookInstance resources. + Iterator for listed NotebookInstanceLifecycleConfig resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -21633,10 +23027,6 @@ def get_all( "CreationTimeAfter": creation_time_after, "LastModifiedTimeBefore": last_modified_time_before, "LastModifiedTimeAfter": last_modified_time_after, - "StatusEquals": status_equals, - "NotebookInstanceLifecycleConfigNameContains": notebook_instance_lifecycle_config_name_contains, - "DefaultCodeRepositoryContains": default_code_repository_contains, - "AdditionalCodeRepositoryEquals": additional_code_repository_equals, } # serialize the input request @@ -21645,38 +23035,60 @@ def get_all( return ResourceIterator( client=client, - list_method="list_notebook_instances", - summaries_key="NotebookInstances", - summary_name="NotebookInstanceSummary", - resource_cls=NotebookInstance, + list_method="list_notebook_instance_lifecycle_configs", + summaries_key="NotebookInstanceLifecycleConfigs", + summary_name="NotebookInstanceLifecycleConfigSummary", + resource_cls=NotebookInstanceLifecycleConfig, list_method_kwargs=operation_input_args, ) -class NotebookInstanceLifecycleConfig(Base): +class OptimizationJob(Base): """ - Class representing resource NotebookInstanceLifecycleConfig + Class representing resource OptimizationJob Attributes: - 
notebook_instance_lifecycle_config_arn: The Amazon Resource Name (ARN) of the lifecycle configuration. - notebook_instance_lifecycle_config_name: The name of the lifecycle configuration. - on_create: The shell script that runs only once, when you create a notebook instance. - on_start: The shell script that runs every time you start a notebook instance, including when you create the notebook instance. - last_modified_time: A timestamp that tells when the lifecycle configuration was last modified. - creation_time: A timestamp that tells when the lifecycle configuration was created. + optimization_job_arn: The Amazon Resource Name (ARN) of the optimization job. + optimization_job_status: The current status of the optimization job. + creation_time: The time when you created the optimization job. + last_modified_time: The time when the optimization job was last updated. + optimization_job_name: The name that you assigned to the optimization job. + model_source: The location of the source model to optimize with an optimization job. + deployment_instance_type: The type of instance that hosts the optimized model that you create with the optimization job. + optimization_configs: Settings for each of the optimization techniques that the job applies. + output_config: Details for where to store the optimized model that you create with the optimization job. + role_arn: The ARN of the IAM role that you assigned to the optimization job. + stopping_condition: + optimization_start_time: The time when the optimization job started. + optimization_end_time: The time when the optimization job finished processing. + failure_reason: If the optimization job status is FAILED, the reason for the failure. + optimization_environment: The environment variables to set in the model container. + optimization_output: Output values produced by an optimization job. + vpc_config: A VPC in Amazon VPC that your optimized model has access to. 
""" - notebook_instance_lifecycle_config_name: str - notebook_instance_lifecycle_config_arn: Optional[str] = Unassigned() - on_create: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned() - on_start: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() + optimization_job_name: str + optimization_job_arn: Optional[str] = Unassigned() + optimization_job_status: Optional[str] = Unassigned() + optimization_start_time: Optional[datetime.datetime] = Unassigned() + optimization_end_time: Optional[datetime.datetime] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + failure_reason: Optional[str] = Unassigned() + model_source: Optional[OptimizationJobModelSource] = Unassigned() + optimization_environment: Optional[Dict[str, str]] = Unassigned() + deployment_instance_type: Optional[str] = Unassigned() + optimization_configs: Optional[List[OptimizationConfig]] = Unassigned() + output_config: Optional[OptimizationJobOutputConfig] = Unassigned() + optimization_output: Optional[OptimizationOutput] = Unassigned() + role_arn: Optional[str] = Unassigned() + stopping_condition: Optional[StoppingCondition] = Unassigned() + vpc_config: Optional[OptimizationVpcConfig] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "notebook_instance_lifecycle_config_name" + resource_name = "optimization_job_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -21687,31 +23099,70 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - logger.error("Name attribute not found for object notebook_instance_lifecycle_config") + logger.error("Name attribute not found for object optimization_job") return None + def populate_inputs_decorator(create_func): + 
@functools.wraps(create_func) + def wrapper(*args, **kwargs): + config_schema_for_resource = { + "model_source": {"s3": {"s3_uri": {"type": "string"}}}, + "output_config": { + "s3_output_location": {"type": "string"}, + "kms_key_id": {"type": "string"}, + }, + "role_arn": {"type": "string"}, + "vpc_config": { + "security_group_ids": {"type": "array", "items": {"type": "string"}}, + "subnets": {"type": "array", "items": {"type": "string"}}, + }, + } + return create_func( + *args, + **Base.get_updated_kwargs_with_configured_attributes( + config_schema_for_resource, "OptimizationJob", **kwargs + ), + ) + + return wrapper + @classmethod + @populate_inputs_decorator @Base.add_validate_call def create( cls, - notebook_instance_lifecycle_config_name: str, - on_create: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned(), - on_start: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned(), + optimization_job_name: str, + role_arn: str, + model_source: OptimizationJobModelSource, + deployment_instance_type: str, + optimization_configs: List[OptimizationConfig], + output_config: OptimizationJobOutputConfig, + stopping_condition: StoppingCondition, + optimization_environment: Optional[Dict[str, str]] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), + vpc_config: Optional[OptimizationVpcConfig] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["NotebookInstanceLifecycleConfig"]: + ) -> Optional["OptimizationJob"]: """ - Create a NotebookInstanceLifecycleConfig resource + Create a OptimizationJob resource Parameters: - notebook_instance_lifecycle_config_name: The name of the lifecycle configuration. - on_create: A shell script that runs only once, when you create a notebook instance. The shell script must be a base64-encoded string. - on_start: A shell script that runs every time you start a notebook instance, including when you create the notebook instance. The shell script must be a base64-encoded string. 
+ optimization_job_name: A custom name for the new optimization job. + role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to perform tasks on your behalf. During model optimization, Amazon SageMaker needs your permission to: Read input data from an S3 bucket Write model artifacts to an S3 bucket Write logs to Amazon CloudWatch Logs Publish metrics to Amazon CloudWatch You grant permissions for all of these tasks to an IAM role. To pass this role to Amazon SageMaker, the caller of this API must have the iam:PassRole permission. For more information, see Amazon SageMaker Roles. + model_source: The location of the source model to optimize with an optimization job. + deployment_instance_type: The type of instance that hosts the optimized model that you create with the optimization job. + optimization_configs: Settings for each of the optimization techniques that the job applies. + output_config: Details for where to store the optimized model that you create with the optimization job. + stopping_condition: + optimization_environment: The environment variables to set in the model container. + tags: A list of key-value pairs associated with the optimization job. For more information, see Tagging Amazon Web Services resources in the Amazon Web Services General Reference Guide. + vpc_config: A VPC in Amazon VPC that your optimized model has access to. session: Boto3 session. region: Region name. Returns: - The NotebookInstanceLifecycleConfig resource. + The OptimizationJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -21723,26 +23174,33 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceInUse: Resource being accessed is in use. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - logger.info("Creating notebook_instance_lifecycle_config resource.") + logger.info("Creating optimization_job resource.") client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) operation_input_args = { - "NotebookInstanceLifecycleConfigName": notebook_instance_lifecycle_config_name, - "OnCreate": on_create, - "OnStart": on_start, + "OptimizationJobName": optimization_job_name, + "RoleArn": role_arn, + "ModelSource": model_source, + "DeploymentInstanceType": deployment_instance_type, + "OptimizationEnvironment": optimization_environment, + "OptimizationConfigs": optimization_configs, + "OutputConfig": output_config, + "StoppingCondition": stopping_condition, + "Tags": tags, + "VpcConfig": vpc_config, } operation_input_args = Base.populate_chained_attributes( - resource_name="NotebookInstanceLifecycleConfig", - operation_input_args=operation_input_args, + resource_name="OptimizationJob", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -21751,33 +23209,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_notebook_instance_lifecycle_config(**operation_input_args) + response = client.create_optimization_job(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get( - notebook_instance_lifecycle_config_name=notebook_instance_lifecycle_config_name, - session=session, - region=region, - ) + return cls.get(optimization_job_name=optimization_job_name, session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - notebook_instance_lifecycle_config_name: str, + optimization_job_name: str, session: 
Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["NotebookInstanceLifecycleConfig"]: + ) -> Optional["OptimizationJob"]: """ - Get a NotebookInstanceLifecycleConfig resource + Get a OptimizationJob resource Parameters: - notebook_instance_lifecycle_config_name: The name of the lifecycle configuration to describe. + optimization_job_name: The name that you assigned to the optimization job. session: Boto3 session. region: Region name. Returns: - The NotebookInstanceLifecycleConfig resource. + The OptimizationJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -21789,10 +23243,11 @@ def get( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "NotebookInstanceLifecycleConfigName": notebook_instance_lifecycle_config_name, + "OptimizationJobName": optimization_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -21801,24 +23256,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_notebook_instance_lifecycle_config(**operation_input_args) + response = client.describe_optimization_job(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeNotebookInstanceLifecycleConfigOutput") - notebook_instance_lifecycle_config = cls(**transformed_response) - return notebook_instance_lifecycle_config + transformed_response = transform(response, "DescribeOptimizationJobResponse") + optimization_job = cls(**transformed_response) + return optimization_job @Base.add_validate_call def refresh( self, - ) -> Optional["NotebookInstanceLifecycleConfig"]: + ) -> Optional["OptimizationJob"]: """ - Refresh a NotebookInstanceLifecycleConfig resource + Refresh a OptimizationJob 
resource Returns: - The NotebookInstanceLifecycleConfig resource. + The OptimizationJob resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -21830,33 +23285,29 @@ def refresh( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. """ operation_input_args = { - "NotebookInstanceLifecycleConfigName": self.notebook_instance_lifecycle_config_name, + "OptimizationJobName": self.optimization_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_notebook_instance_lifecycle_config(**operation_input_args) + response = client.describe_optimization_job(**operation_input_args) # deserialize response and update self - transform(response, "DescribeNotebookInstanceLifecycleConfigOutput", self) + transform(response, "DescribeOptimizationJobResponse", self) return self @Base.add_validate_call - def update( + def delete( self, - on_create: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned(), - on_start: Optional[List[NotebookInstanceLifecycleHook]] = Unassigned(), - ) -> Optional["NotebookInstanceLifecycleConfig"]: + ) -> None: """ - Update a NotebookInstanceLifecycleConfig resource - - Returns: - The NotebookInstanceLifecycleConfig resource. + Delete a OptimizationJob resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -21868,35 +23319,26 @@ def update( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. 
""" - logger.info("Updating notebook_instance_lifecycle_config resource.") client = Base.get_sagemaker_client() operation_input_args = { - "NotebookInstanceLifecycleConfigName": self.notebook_instance_lifecycle_config_name, - "OnCreate": on_create, - "OnStart": on_start, + "OptimizationJobName": self.optimization_job_name, } - logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - # create the resource - response = client.update_notebook_instance_lifecycle_config(**operation_input_args) - logger.debug(f"Response: {response}") - self.refresh() + client.delete_optimization_job(**operation_input_args) - return self + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call - def delete( - self, - ) -> None: + def stop(self) -> None: """ - Delete a NotebookInstanceLifecycleConfig resource + Stop a OptimizationJob resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -21908,53 +23350,119 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
""" - client = Base.get_sagemaker_client() + client = SageMakerClient().client operation_input_args = { - "NotebookInstanceLifecycleConfigName": self.notebook_instance_lifecycle_config_name, + "OptimizationJobName": self.optimization_job_name, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_notebook_instance_lifecycle_config(**operation_input_args) + client.stop_optimization_job(**operation_input_args) - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") + + @Base.add_validate_call + def wait( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a OptimizationJob resource. + + Parameters: + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. 
+ + """ + terminal_states = ["COMPLETED", "FAILED", "STOPPED"] + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task("Waiting for OptimizationJob...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.optimization_job_status + status.update(f"Current status: [bold]{current_status}") + + if current_status in terminal_states: + logger.info(f"Final Resource Status: [bold]{current_status}") + + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="OptimizationJob", + status=current_status, + reason=self.failure_reason, + ) + + return + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError( + resouce_type="OptimizationJob", status=current_status + ) + time.sleep(poll) @classmethod @Base.add_validate_call def get_all( cls, - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), creation_time_after: Optional[datetime.datetime] = Unassigned(), - last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + creation_time_before: Optional[datetime.datetime] = Unassigned(), last_modified_time_after: Optional[datetime.datetime] = Unassigned(), + last_modified_time_before: Optional[datetime.datetime] = Unassigned(), + optimization_contains: Optional[str] = Unassigned(), + name_contains: Optional[str] = Unassigned(), + status_equals: Optional[str] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> 
ResourceIterator["NotebookInstanceLifecycleConfig"]: + ) -> ResourceIterator["OptimizationJob"]: """ - Get all NotebookInstanceLifecycleConfig resources + Get all OptimizationJob resources Parameters: - next_token: If the result of a ListNotebookInstanceLifecycleConfigs request was truncated, the response includes a NextToken. To get the next set of lifecycle configurations, use the token in the next request. - max_results: The maximum number of lifecycle configurations to return in the response. - sort_by: Sorts the list of results. The default is CreationTime. - sort_order: The sort order for results. - name_contains: A string in the lifecycle configuration name. This filter returns only lifecycle configurations whose name contains the specified string. - creation_time_before: A filter that returns only lifecycle configurations that were created before the specified time (timestamp). - creation_time_after: A filter that returns only lifecycle configurations that were created after the specified time (timestamp). - last_modified_time_before: A filter that returns only lifecycle configurations that were modified before the specified time (timestamp). - last_modified_time_after: A filter that returns only lifecycle configurations that were modified after the specified time (timestamp). + next_token: A token that you use to get the next set of results following a truncated response. If the response to the previous request was truncated, that response provides the value for this token. + max_results: The maximum number of optimization jobs to return in the response. The default is 50. + creation_time_after: Filters the results to only those optimization jobs that were created after the specified time. + creation_time_before: Filters the results to only those optimization jobs that were created before the specified time. + last_modified_time_after: Filters the results to only those optimization jobs that were updated after the specified time. 
+ last_modified_time_before: Filters the results to only those optimization jobs that were updated before the specified time. + optimization_contains: Filters the results to only those optimization jobs that apply the specified optimization techniques. You can specify either Quantization or Compilation. + name_contains: Filters the results to only those optimization jobs with a name that contains the specified string. + status_equals: Filters the results to only those optimization jobs with the specified status. + sort_by: The field by which to sort the optimization jobs in the response. The default is CreationTime + sort_order: The sort order for results. The default is Ascending session: Boto3 session. region: Region name. Returns: - Iterator for listed NotebookInstanceLifecycleConfig resources. + Iterator for listed OptimizationJob resources. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -21973,13 +23481,15 @@ def get_all( ) operation_input_args = { - "SortBy": sort_by, - "SortOrder": sort_order, - "NameContains": name_contains, - "CreationTimeBefore": creation_time_before, "CreationTimeAfter": creation_time_after, - "LastModifiedTimeBefore": last_modified_time_before, + "CreationTimeBefore": creation_time_before, "LastModifiedTimeAfter": last_modified_time_after, + "LastModifiedTimeBefore": last_modified_time_before, + "OptimizationContains": optimization_contains, + "NameContains": name_contains, + "StatusEquals": status_equals, + "SortBy": sort_by, + "SortOrder": sort_order, } # serialize the input request @@ -21988,60 +23498,54 @@ def get_all( return ResourceIterator( client=client, - list_method="list_notebook_instance_lifecycle_configs", - summaries_key="NotebookInstanceLifecycleConfigs", - summary_name="NotebookInstanceLifecycleConfigSummary", - resource_cls=NotebookInstanceLifecycleConfig, + list_method="list_optimization_jobs", + summaries_key="OptimizationJobSummaries", + 
summary_name="OptimizationJobSummary", + resource_cls=OptimizationJob, list_method_kwargs=operation_input_args, ) -class OptimizationJob(Base): +class PartnerApp(Base): """ - Class representing resource OptimizationJob + Class representing resource PartnerApp Attributes: - optimization_job_arn: The Amazon Resource Name (ARN) of the optimization job. - optimization_job_status: The current status of the optimization job. - creation_time: The time when you created the optimization job. - last_modified_time: The time when the optimization job was last updated. - optimization_job_name: The name that you assigned to the optimization job. - model_source: The location of the source model to optimize with an optimization job. - deployment_instance_type: The type of instance that hosts the optimized model that you create with the optimization job. - optimization_configs: Settings for each of the optimization techniques that the job applies. - output_config: Details for where to store the optimized model that you create with the optimization job. - role_arn: The ARN of the IAM role that you assigned to the optimization job. - stopping_condition: - optimization_start_time: The time when the optimization job started. - optimization_end_time: The time when the optimization job finished processing. - failure_reason: If the optimization job status is FAILED, the reason for the failure. - optimization_environment: The environment variables to set in the model container. - optimization_output: Output values produced by an optimization job. - vpc_config: A VPC in Amazon VPC that your optimized model has access to. + arn: The ARN of the SageMaker Partner AI App that was described. + name: The name of the SageMaker Partner AI App. + type: The type of SageMaker Partner AI App. Must be one of the following: lakera-guard, comet, deepchecks-llm-evaluation, or fiddler. + status: The status of the SageMaker Partner AI App. 
+ creation_time: The time that the SageMaker Partner AI App was created. + execution_role_arn: The ARN of the IAM role associated with the SageMaker Partner AI App. + base_url: The URL of the SageMaker Partner AI App that the Application SDK uses to support in-app calls for the user. + maintenance_config: Maintenance configuration settings for the SageMaker Partner AI App. + tier: The instance type and size of the cluster attached to the SageMaker Partner AI App. + version: The version of the SageMaker Partner AI App. + application_config: Configuration settings for the SageMaker Partner AI App. + auth_type: The authorization type that users use to access the SageMaker Partner AI App. + enable_iam_session_based_identity: When set to TRUE, the SageMaker Partner AI App sets the Amazon Web Services IAM session name or the authenticated IAM user as the identity of the SageMaker Partner AI App user. + error: This is an error field object that contains the error code and the reason for an operation failure. 
""" - optimization_job_name: str - optimization_job_arn: Optional[str] = Unassigned() - optimization_job_status: Optional[str] = Unassigned() - optimization_start_time: Optional[datetime.datetime] = Unassigned() - optimization_end_time: Optional[datetime.datetime] = Unassigned() + arn: str + name: Optional[str] = Unassigned() + type: Optional[str] = Unassigned() + status: Optional[str] = Unassigned() creation_time: Optional[datetime.datetime] = Unassigned() - last_modified_time: Optional[datetime.datetime] = Unassigned() - failure_reason: Optional[str] = Unassigned() - model_source: Optional[OptimizationJobModelSource] = Unassigned() - optimization_environment: Optional[Dict[str, str]] = Unassigned() - deployment_instance_type: Optional[str] = Unassigned() - optimization_configs: Optional[List[OptimizationConfig]] = Unassigned() - output_config: Optional[OptimizationJobOutputConfig] = Unassigned() - optimization_output: Optional[OptimizationOutput] = Unassigned() - role_arn: Optional[str] = Unassigned() - stopping_condition: Optional[StoppingCondition] = Unassigned() - vpc_config: Optional[OptimizationVpcConfig] = Unassigned() + execution_role_arn: Optional[str] = Unassigned() + base_url: Optional[str] = Unassigned() + maintenance_config: Optional[PartnerAppMaintenanceConfig] = Unassigned() + tier: Optional[str] = Unassigned() + version: Optional[str] = Unassigned() + application_config: Optional[PartnerAppConfig] = Unassigned() + auth_type: Optional[str] = Unassigned() + enable_iam_session_based_identity: Optional[bool] = Unassigned() + error: Optional[ErrorInfo] = Unassigned() def get_name(self) -> str: attributes = vars(self) - resource_name = "optimization_job_name" + resource_name = "partner_app_name" resource_name_split = resource_name.split("_") attribute_name_candidates = [] @@ -22052,70 +23556,45 @@ def get_name(self) -> str: for attribute, value in attributes.items(): if attribute == "name" or attribute in attribute_name_candidates: return value - 
logger.error("Name attribute not found for object optimization_job") + logger.error("Name attribute not found for object partner_app") return None - def populate_inputs_decorator(create_func): - @functools.wraps(create_func) - def wrapper(*args, **kwargs): - config_schema_for_resource = { - "model_source": {"s3": {"s3_uri": {"type": "string"}}}, - "output_config": { - "s3_output_location": {"type": "string"}, - "kms_key_id": {"type": "string"}, - }, - "role_arn": {"type": "string"}, - "vpc_config": { - "security_group_ids": {"type": "array", "items": {"type": "string"}}, - "subnets": {"type": "array", "items": {"type": "string"}}, - }, - } - return create_func( - *args, - **Base.get_updated_kwargs_with_configured_attributes( - config_schema_for_resource, "OptimizationJob", **kwargs - ), - ) - - return wrapper - @classmethod - @populate_inputs_decorator @Base.add_validate_call def create( cls, - optimization_job_name: str, - role_arn: str, - model_source: OptimizationJobModelSource, - deployment_instance_type: str, - optimization_configs: List[OptimizationConfig], - output_config: OptimizationJobOutputConfig, - stopping_condition: StoppingCondition, - optimization_environment: Optional[Dict[str, str]] = Unassigned(), + name: str, + type: str, + execution_role_arn: str, + tier: str, + auth_type: str, + maintenance_config: Optional[PartnerAppMaintenanceConfig] = Unassigned(), + application_config: Optional[PartnerAppConfig] = Unassigned(), + enable_iam_session_based_identity: Optional[bool] = Unassigned(), + client_token: Optional[str] = Unassigned(), tags: Optional[List[Tag]] = Unassigned(), - vpc_config: Optional[OptimizationVpcConfig] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["OptimizationJob"]: + ) -> Optional["PartnerApp"]: """ - Create a OptimizationJob resource + Create a PartnerApp resource Parameters: - optimization_job_name: A custom name for the new optimization job. 
- role_arn: The Amazon Resource Name (ARN) of an IAM role that enables Amazon SageMaker to perform tasks on your behalf. During model optimization, Amazon SageMaker needs your permission to: Read input data from an S3 bucket Write model artifacts to an S3 bucket Write logs to Amazon CloudWatch Logs Publish metrics to Amazon CloudWatch You grant permissions for all of these tasks to an IAM role. To pass this role to Amazon SageMaker, the caller of this API must have the iam:PassRole permission. For more information, see Amazon SageMaker Roles. - model_source: The location of the source model to optimize with an optimization job. - deployment_instance_type: The type of instance that hosts the optimized model that you create with the optimization job. - optimization_configs: Settings for each of the optimization techniques that the job applies. - output_config: Details for where to store the optimized model that you create with the optimization job. - stopping_condition: - optimization_environment: The environment variables to set in the model container. - tags: A list of key-value pairs associated with the optimization job. For more information, see Tagging Amazon Web Services resources in the Amazon Web Services General Reference Guide. - vpc_config: A VPC in Amazon VPC that your optimized model has access to. + name: The name to give the SageMaker Partner AI App. + type: The type of SageMaker Partner AI App to create. Must be one of the following: lakera-guard, comet, deepchecks-llm-evaluation, or fiddler. + execution_role_arn: The ARN of the IAM role that the partner application uses. + tier: Indicates the instance type and size of the cluster attached to the SageMaker Partner AI App. + auth_type: The authorization type that users use to access the SageMaker Partner AI App. + maintenance_config: Maintenance configuration settings for the SageMaker Partner AI App. + application_config: Configuration settings for the SageMaker Partner AI App. 
+ enable_iam_session_based_identity: When set to TRUE, the SageMaker Partner AI App sets the Amazon Web Services IAM session name or the authenticated IAM user as the identity of the SageMaker Partner AI App user. + client_token: A unique token that guarantees that the call to this API is idempotent. + tags: Each tag consists of a key and an optional value. Tag keys must be unique per resource. session: Boto3 session. region: Region name. Returns: - The OptimizationJob resource. + The PartnerApp resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -22127,33 +23606,33 @@ def create( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` - ResourceInUse: Resource being accessed is in use. + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. 
ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema LocalConfigNotFoundError: Raised when a configuration file is not found in local file system S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - - logger.info("Creating optimization_job resource.") - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - - operation_input_args = { - "OptimizationJobName": optimization_job_name, - "RoleArn": role_arn, - "ModelSource": model_source, - "DeploymentInstanceType": deployment_instance_type, - "OptimizationEnvironment": optimization_environment, - "OptimizationConfigs": optimization_configs, - "OutputConfig": output_config, - "StoppingCondition": stopping_condition, + + logger.info("Creating partner_app resource.") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "Name": name, + "Type": type, + "ExecutionRoleArn": execution_role_arn, + "MaintenanceConfig": maintenance_config, + "Tier": tier, + "ApplicationConfig": application_config, + "AuthType": auth_type, + "EnableIamSessionBasedIdentity": enable_iam_session_based_identity, + "ClientToken": client_token, "Tags": tags, - "VpcConfig": vpc_config, } operation_input_args = Base.populate_chained_attributes( - resource_name="OptimizationJob", operation_input_args=operation_input_args + resource_name="PartnerApp", operation_input_args=operation_input_args ) logger.debug(f"Input request: {operation_input_args}") @@ -22162,29 +23641,29 @@ def create( logger.debug(f"Serialized input request: {operation_input_args}") # create the resource - response = client.create_optimization_job(**operation_input_args) + response = client.create_partner_app(**operation_input_args) logger.debug(f"Response: {response}") - return cls.get(optimization_job_name=optimization_job_name, session=session, region=region) + return 
cls.get(arn=response["Arn"], session=session, region=region) @classmethod @Base.add_validate_call def get( cls, - optimization_job_name: str, + arn: str, session: Optional[Session] = None, region: Optional[str] = None, - ) -> Optional["OptimizationJob"]: + ) -> Optional["PartnerApp"]: """ - Get a OptimizationJob resource + Get a PartnerApp resource Parameters: - optimization_job_name: The name that you assigned to the optimization job. + arn: The ARN of the SageMaker Partner AI App to describe. session: Boto3 session. region: Region name. Returns: - The OptimizationJob resource. + The PartnerApp resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -22200,7 +23679,7 @@ def get( """ operation_input_args = { - "OptimizationJobName": optimization_job_name, + "Arn": arn, } # serialize the input request operation_input_args = serialize(operation_input_args) @@ -22209,24 +23688,24 @@ def get( client = Base.get_sagemaker_client( session=session, region_name=region, service_name="sagemaker" ) - response = client.describe_optimization_job(**operation_input_args) + response = client.describe_partner_app(**operation_input_args) logger.debug(response) # deserialize the response - transformed_response = transform(response, "DescribeOptimizationJobResponse") - optimization_job = cls(**transformed_response) - return optimization_job + transformed_response = transform(response, "DescribePartnerAppResponse") + partner_app = cls(**transformed_response) + return partner_app @Base.add_validate_call def refresh( self, - ) -> Optional["OptimizationJob"]: + ) -> Optional["PartnerApp"]: """ - Refresh a OptimizationJob resource + Refresh a PartnerApp resource Returns: - The OptimizationJob resource. + The PartnerApp resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
@@ -22242,25 +23721,38 @@ def refresh( """ operation_input_args = { - "OptimizationJobName": self.optimization_job_name, + "Arn": self.arn, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") client = Base.get_sagemaker_client() - response = client.describe_optimization_job(**operation_input_args) + response = client.describe_partner_app(**operation_input_args) # deserialize response and update self - transform(response, "DescribeOptimizationJobResponse", self) + transform(response, "DescribePartnerAppResponse", self) return self @Base.add_validate_call - def delete( + def update( self, - ) -> None: + maintenance_config: Optional[PartnerAppMaintenanceConfig] = Unassigned(), + tier: Optional[str] = Unassigned(), + application_config: Optional[PartnerAppConfig] = Unassigned(), + enable_iam_session_based_identity: Optional[bool] = Unassigned(), + client_token: Optional[str] = Unassigned(), + tags: Optional[List[Tag]] = Unassigned(), + ) -> Optional["PartnerApp"]: """ - Delete a OptimizationJob resource + Update a PartnerApp resource + + Parameters: + client_token: A unique token that guarantees that the call to this API is idempotent. + tags: Each tag consists of a key and an optional value. Tag keys must be unique per resource. + + Returns: + The PartnerApp resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -22272,26 +23764,41 @@ def delete( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. ResourceNotFound: Resource being access is not found. 
""" + logger.info("Updating partner_app resource.") client = Base.get_sagemaker_client() operation_input_args = { - "OptimizationJobName": self.optimization_job_name, + "Arn": self.arn, + "MaintenanceConfig": maintenance_config, + "Tier": tier, + "ApplicationConfig": application_config, + "EnableIamSessionBasedIdentity": enable_iam_session_based_identity, + "ClientToken": client_token, + "Tags": tags, } + logger.debug(f"Input request: {operation_input_args}") # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.delete_optimization_job(**operation_input_args) + # create the resource + response = client.update_partner_app(**operation_input_args) + logger.debug(f"Response: {response}") + self.refresh() - logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") + return self @Base.add_validate_call - def stop(self) -> None: + def delete( + self, + client_token: Optional[str] = Unassigned(), + ) -> None: """ - Stop a OptimizationJob resource + Delete a PartnerApp resource Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -22303,32 +23810,38 @@ def stop(self) -> None: error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ConflictException: There was a conflict when you attempted to modify a SageMaker entity such as an Experiment or Artifact. ResourceNotFound: Resource being access is not found. 
""" - client = SageMakerClient().client + client = Base.get_sagemaker_client() operation_input_args = { - "OptimizationJobName": self.optimization_job_name, + "Arn": self.arn, + "ClientToken": client_token, } # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - client.stop_optimization_job(**operation_input_args) + client.delete_partner_app(**operation_input_args) - logger.info(f"Stopping {self.__class__.__name__} - {self.get_name()}") + logger.info(f"Deleting {self.__class__.__name__} - {self.get_name()}") @Base.add_validate_call - def wait( + def wait_for_status( self, + target_status: Literal[ + "Creating", "Updating", "Deleting", "Available", "Failed", "UpdateFailed", "Deleted" + ], poll: int = 5, timeout: Optional[int] = None, ) -> None: """ - Wait for a OptimizationJob resource. + Wait for a PartnerApp resource to reach certain status. Parameters: + target_status: The status to wait for. poll: The number of seconds to wait between each poll. timeout: The maximum number of seconds to wait before timing out. @@ -22336,9 +23849,7 @@ def wait( TimeoutExceededError: If the resource does not reach a terminal state before the timeout. FailedStatusError: If the resource reaches a failed state. WaiterError: Raised when an error occurs while waiting. 
- """ - terminal_states = ["COMPLETED", "FAILED", "STOPPED"] start_time = time.time() progress = Progress( @@ -22346,7 +23857,7 @@ def wait( TextColumn("{task.description}"), TimeElapsedColumn(), ) - progress.add_task("Waiting for OptimizationJob...") + progress.add_task(f"Waiting for PartnerApp to reach [bold]{target_status} status...") status = Status("Current status:") with Live( @@ -22359,63 +23870,171 @@ def wait( ): while True: self.refresh() - current_status = self.optimization_job_status + current_status = self.status status.update(f"Current status: [bold]{current_status}") - if current_status in terminal_states: + if target_status == current_status: logger.info(f"Final Resource Status: [bold]{current_status}") - - if "failed" in current_status.lower(): - raise FailedStatusError( - resource_type="OptimizationJob", - status=current_status, - reason=self.failure_reason, - ) - return - if timeout is not None and time.time() - start_time >= timeout: - raise TimeoutExceededError( - resouce_type="OptimizationJob", status=current_status + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="PartnerApp", status=current_status, reason="(Unknown)" ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="PartnerApp", status=current_status) + time.sleep(poll) + + @Base.add_validate_call + def wait_for_delete( + self, + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a PartnerApp resource to be deleted. + + Parameters: + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. 
+ The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + DeleteFailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. + """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task("Waiting for PartnerApp to be deleted...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ) + ): + while True: + try: + self.refresh() + current_status = self.status + status.update(f"Current status: [bold]{current_status}") + + if current_status.lower() == "deleted": + print("Resource was deleted.") + return + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="PartnerApp", status=current_status) + except botocore.exceptions.ClientError as e: + error_code = e.response["Error"]["Code"] + + if "ResourceNotFound" in error_code or "ValidationException" in error_code: + logger.info("Resource was not found. 
It may have been deleted.") + return + raise e time.sleep(poll) @classmethod @Base.add_validate_call def get_all( cls, - creation_time_after: Optional[datetime.datetime] = Unassigned(), - creation_time_before: Optional[datetime.datetime] = Unassigned(), - last_modified_time_after: Optional[datetime.datetime] = Unassigned(), - last_modified_time_before: Optional[datetime.datetime] = Unassigned(), - optimization_contains: Optional[str] = Unassigned(), - name_contains: Optional[str] = Unassigned(), - status_equals: Optional[str] = Unassigned(), - sort_by: Optional[str] = Unassigned(), - sort_order: Optional[str] = Unassigned(), session: Optional[Session] = None, region: Optional[str] = None, - ) -> ResourceIterator["OptimizationJob"]: + ) -> ResourceIterator["PartnerApp"]: """ - Get all OptimizationJob resources + Get all PartnerApp resources. Parameters: - next_token: A token that you use to get the next set of results following a truncated response. If the response to the previous request was truncated, that response provides the value for this token. - max_results: The maximum number of optimization jobs to return in the response. The default is 50. - creation_time_after: Filters the results to only those optimization jobs that were created after the specified time. - creation_time_before: Filters the results to only those optimization jobs that were created before the specified time. - last_modified_time_after: Filters the results to only those optimization jobs that were updated after the specified time. - last_modified_time_before: Filters the results to only those optimization jobs that were updated before the specified time. - optimization_contains: Filters the results to only those optimization jobs that apply the specified optimization techniques. You can specify either Quantization or Compilation. - name_contains: Filters the results to only those optimization jobs with a name that contains the specified string. 
- status_equals: Filters the results to only those optimization jobs with the specified status. - sort_by: The field by which to sort the optimization jobs in the response. The default is CreationTime - sort_order: The sort order for results. The default is Ascending session: Boto3 session. region: Region name. Returns: - Iterator for listed OptimizationJob resources. + Iterator for listed PartnerApp resources. + + """ + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + return ResourceIterator( + client=client, + list_method="list_partner_apps", + summaries_key="Summaries", + summary_name="PartnerAppSummary", + resource_cls=PartnerApp, + ) + + +class PartnerAppPresignedUrl(Base): + """ + Class representing resource PartnerAppPresignedUrl + + Attributes: + arn: The ARN of the SageMaker Partner AI App to create the presigned URL for. + expires_in_seconds: The time that will pass before the presigned URL expires. + session_expiration_duration_in_seconds: Indicates how long the Amazon SageMaker Partner AI App session can be accessed for after logging in. + url: The presigned URL that you can use to access the SageMaker Partner AI App. 
+ + """ + + arn: str + expires_in_seconds: Optional[int] = Unassigned() + session_expiration_duration_in_seconds: Optional[int] = Unassigned() + url: Optional[str] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "partner_app_presigned_url_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object partner_app_presigned_url") + return None + + @classmethod + @Base.add_validate_call + def create( + cls, + arn: str, + expires_in_seconds: Optional[int] = Unassigned(), + session_expiration_duration_in_seconds: Optional[int] = Unassigned(), + ) -> Optional["PartnerAppPresignedUrl"]: + """ + Create a PartnerAppPresignedUrl resource + + Parameters: + arn: The ARN of the SageMaker Partner AI App to create the presigned URL for. + expires_in_seconds: The time that will pass before the presigned URL expires. + session_expiration_duration_in_seconds: Indicates how long the Amazon SageMaker Partner AI App session can be accessed for after logging in. + session: Boto3 session. + region: Region name. + + Returns: + The PartnerAppPresignedUrl resource. Raises: botocore.exceptions.ClientError: This exception is raised for AWS service related errors. @@ -22427,37 +24046,32 @@ def get_all( error_message = e.response['Error']['Message'] error_code = e.response['Error']['Code'] ``` + ResourceNotFound: Resource being access is not found. 
+ ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 """ - client = Base.get_sagemaker_client( - session=session, region_name=region, service_name="sagemaker" - ) - operation_input_args = { - "CreationTimeAfter": creation_time_after, - "CreationTimeBefore": creation_time_before, - "LastModifiedTimeAfter": last_modified_time_after, - "LastModifiedTimeBefore": last_modified_time_before, - "OptimizationContains": optimization_contains, - "NameContains": name_contains, - "StatusEquals": status_equals, - "SortBy": sort_by, - "SortOrder": sort_order, + "Arn": arn, + "ExpiresInSeconds": expires_in_seconds, + "SessionExpirationDurationInSeconds": session_expiration_duration_in_seconds, } - # serialize the input request operation_input_args = serialize(operation_input_args) logger.debug(f"Serialized input request: {operation_input_args}") - - return ResourceIterator( - client=client, - list_method="list_optimization_jobs", - summaries_key="OptimizationJobSummaries", - summary_name="OptimizationJobSummary", - resource_cls=OptimizationJob, - list_method_kwargs=operation_input_args, + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" ) + logger.debug(f"Calling create_partner_app_presigned_url API") + response = client.create_partner_app_presigned_url(**operation_input_args) + logger.debug(f"Response: {response}") + + transformed_response = transform(response, "CreatePartnerAppPresignedUrlResponse") + return cls(**operation_input_args, **transformed_response) + class Pipeline(Base): """ @@ -26680,6 +28294,7 @@ def get_all( sort_by: Optional[str] = Unassigned(), sort_order: Optional[str] = Unassigned(), warm_pool_status_equals: Optional[str] = Unassigned(), + training_plan_arn_equals: Optional[str] = Unassigned(), 
session: Optional[Session] = None, region: Optional[str] = None, ) -> ResourceIterator["TrainingJob"]: @@ -26698,6 +28313,7 @@ def get_all( sort_by: The field to sort results by. The default is CreationTime. sort_order: The sort order for results. The default is Ascending. warm_pool_status_equals: A filter that retrieves only training jobs with a specific warm pool status. + training_plan_arn_equals: The Amazon Resource Name (ARN); of the training plan to filter training jobs by. For more information about reserving GPU capacity for your SageMaker training jobs using Amazon SageMaker Training Plan, see CreateTrainingPlan . session: Boto3 session. region: Region name. @@ -26730,6 +28346,7 @@ def get_all( "SortBy": sort_by, "SortOrder": sort_order, "WarmPoolStatusEquals": warm_pool_status_equals, + "TrainingPlanArnEquals": training_plan_arn_equals, } # serialize the input request @@ -26746,6 +28363,341 @@ def get_all( ) +class TrainingPlan(Base): + """ + Class representing resource TrainingPlan + + Attributes: + training_plan_arn: The Amazon Resource Name (ARN); of the training plan. + training_plan_name: The name of the training plan. + status: The current status of the training plan (e.g., Pending, Active, Expired). To see the complete list of status values available for a training plan, refer to the Status attribute within the TrainingPlanSummary object. + status_message: A message providing additional information about the current status of the training plan. + duration_hours: The number of whole hours in the total duration for this training plan. + duration_minutes: The additional minutes beyond whole hours in the total duration for this training plan. + start_time: The start time of the training plan. + end_time: The end time of the training plan. + upfront_fee: The upfront fee for the training plan. + currency_code: The currency code for the upfront fee (e.g., USD). + total_instance_count: The total number of instances reserved in this training plan. 
+ available_instance_count: The number of instances currently available for use in this training plan. + in_use_instance_count: The number of instances currently in use from this training plan. + target_resources: The target resources (e.g., SageMaker Training Jobs, SageMaker HyperPod) that can use this training plan. Training plans are specific to their target resource. A training plan designed for SageMaker training jobs can only be used to schedule and run training jobs. A training plan for HyperPod clusters can be used exclusively to provide compute resources to a cluster's instance group. + reserved_capacity_summaries: The list of Reserved Capacity providing the underlying compute resources of the plan. + + """ + + training_plan_name: str + training_plan_arn: Optional[str] = Unassigned() + status: Optional[str] = Unassigned() + status_message: Optional[str] = Unassigned() + duration_hours: Optional[int] = Unassigned() + duration_minutes: Optional[int] = Unassigned() + start_time: Optional[datetime.datetime] = Unassigned() + end_time: Optional[datetime.datetime] = Unassigned() + upfront_fee: Optional[str] = Unassigned() + currency_code: Optional[str] = Unassigned() + total_instance_count: Optional[int] = Unassigned() + available_instance_count: Optional[int] = Unassigned() + in_use_instance_count: Optional[int] = Unassigned() + target_resources: Optional[List[str]] = Unassigned() + reserved_capacity_summaries: Optional[List[ReservedCapacitySummary]] = Unassigned() + + def get_name(self) -> str: + attributes = vars(self) + resource_name = "training_plan_name" + resource_name_split = resource_name.split("_") + attribute_name_candidates = [] + + l = len(resource_name_split) + for i in range(0, l): + attribute_name_candidates.append("_".join(resource_name_split[i:l])) + + for attribute, value in attributes.items(): + if attribute == "name" or attribute in attribute_name_candidates: + return value + logger.error("Name attribute not found for object training_plan") + 
return None + + @classmethod + @Base.add_validate_call + def create( + cls, + training_plan_name: str, + training_plan_offering_id: str, + tags: Optional[List[Tag]] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["TrainingPlan"]: + """ + Create a TrainingPlan resource + + Parameters: + training_plan_name: The name of the training plan to create. + training_plan_offering_id: The unique identifier of the training plan offering to use for creating this plan. + tags: An array of key-value pairs to apply to this training plan. + session: Boto3 session. + region: Region name. + + Returns: + The TrainingPlan resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceInUse: Resource being accessed is in use. + ResourceLimitExceeded: You have exceeded an SageMaker resource limit. For example, you might have too many training jobs created. + ResourceNotFound: Resource being access is not found. 
+ ConfigSchemaValidationError: Raised when a configuration file does not adhere to the schema + LocalConfigNotFoundError: Raised when a configuration file is not found in local file system + S3ConfigNotFoundError: Raised when a configuration file is not found in S3 + """ + + logger.info("Creating training_plan resource.") + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "TrainingPlanName": training_plan_name, + "TrainingPlanOfferingId": training_plan_offering_id, + "Tags": tags, + } + + operation_input_args = Base.populate_chained_attributes( + resource_name="TrainingPlan", operation_input_args=operation_input_args + ) + + logger.debug(f"Input request: {operation_input_args}") + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + # create the resource + response = client.create_training_plan(**operation_input_args) + logger.debug(f"Response: {response}") + + return cls.get(training_plan_name=training_plan_name, session=session, region=region) + + @classmethod + @Base.add_validate_call + def get( + cls, + training_plan_name: str, + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> Optional["TrainingPlan"]: + """ + Get a TrainingPlan resource + + Parameters: + training_plan_name: The name of the training plan to describe. + session: Boto3 session. + region: Region name. + + Returns: + The TrainingPlan resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. 
+ """ + + operation_input_args = { + "TrainingPlanName": training_plan_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + response = client.describe_training_plan(**operation_input_args) + + logger.debug(response) + + # deserialize the response + transformed_response = transform(response, "DescribeTrainingPlanResponse") + training_plan = cls(**transformed_response) + return training_plan + + @Base.add_validate_call + def refresh( + self, + ) -> Optional["TrainingPlan"]: + """ + Refresh a TrainingPlan resource + + Returns: + The TrainingPlan resource. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + ResourceNotFound: Resource being access is not found. + """ + + operation_input_args = { + "TrainingPlanName": self.training_plan_name, + } + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + client = Base.get_sagemaker_client() + response = client.describe_training_plan(**operation_input_args) + + # deserialize response and update self + transform(response, "DescribeTrainingPlanResponse", self) + return self + + @Base.add_validate_call + def wait_for_status( + self, + target_status: Literal["Pending", "Active", "Scheduled", "Expired", "Failed"], + poll: int = 5, + timeout: Optional[int] = None, + ) -> None: + """ + Wait for a TrainingPlan resource to reach certain status. 
+ + Parameters: + target_status: The status to wait for. + poll: The number of seconds to wait between each poll. + timeout: The maximum number of seconds to wait before timing out. + + Raises: + TimeoutExceededError: If the resource does not reach a terminal state before the timeout. + FailedStatusError: If the resource reaches a failed state. + WaiterError: Raised when an error occurs while waiting. + """ + start_time = time.time() + + progress = Progress( + SpinnerColumn("bouncingBar"), + TextColumn("{task.description}"), + TimeElapsedColumn(), + ) + progress.add_task(f"Waiting for TrainingPlan to reach [bold]{target_status} status...") + status = Status("Current status:") + + with Live( + Panel( + Group(progress, status), + title="Wait Log Panel", + border_style=Style(color=Color.BLUE.value), + ), + transient=True, + ): + while True: + self.refresh() + current_status = self.status + status.update(f"Current status: [bold]{current_status}") + + if target_status == current_status: + logger.info(f"Final Resource Status: [bold]{current_status}") + return + + if "failed" in current_status.lower(): + raise FailedStatusError( + resource_type="TrainingPlan", + status=current_status, + reason=self.status_message, + ) + + if timeout is not None and time.time() - start_time >= timeout: + raise TimeoutExceededError(resouce_type="TrainingPlan", status=current_status) + time.sleep(poll) + + @classmethod + @Base.add_validate_call + def get_all( + cls, + start_time_after: Optional[datetime.datetime] = Unassigned(), + start_time_before: Optional[datetime.datetime] = Unassigned(), + sort_by: Optional[str] = Unassigned(), + sort_order: Optional[str] = Unassigned(), + filters: Optional[List[TrainingPlanFilter]] = Unassigned(), + session: Optional[Session] = None, + region: Optional[str] = None, + ) -> ResourceIterator["TrainingPlan"]: + """ + Get all TrainingPlan resources + + Parameters: + next_token: A token to continue pagination if more results are available. 
+ max_results: The maximum number of results to return in the response. + start_time_after: Filter to list only training plans with an actual start time after this date. + start_time_before: Filter to list only training plans with an actual start time before this date. + sort_by: The training plan field to sort the results by (e.g., StartTime, Status). + sort_order: The order to sort the results (Ascending or Descending). + filters: Additional filters to apply to the list of training plans. + session: Boto3 session. + region: Region name. + + Returns: + Iterator for listed TrainingPlan resources. + + Raises: + botocore.exceptions.ClientError: This exception is raised for AWS service related errors. + The error message and error code can be parsed from the exception as follows: + ``` + try: + # AWS service call here + except botocore.exceptions.ClientError as e: + error_message = e.response['Error']['Message'] + error_code = e.response['Error']['Code'] + ``` + """ + + client = Base.get_sagemaker_client( + session=session, region_name=region, service_name="sagemaker" + ) + + operation_input_args = { + "StartTimeAfter": start_time_after, + "StartTimeBefore": start_time_before, + "SortBy": sort_by, + "SortOrder": sort_order, + "Filters": filters, + } + + # serialize the input request + operation_input_args = serialize(operation_input_args) + logger.debug(f"Serialized input request: {operation_input_args}") + + return ResourceIterator( + client=client, + list_method="list_training_plans", + summaries_key="TrainingPlanSummaries", + summary_name="TrainingPlanSummary", + resource_cls=TrainingPlan, + list_method_kwargs=operation_input_args, + ) + + class TransformJob(Base): """ Class representing resource TransformJob diff --git a/src/sagemaker_core/main/shapes.py b/src/sagemaker_core/main/shapes.py index d1cba353..36c34d51 100644 --- a/src/sagemaker_core/main/shapes.py +++ b/src/sagemaker_core/main/shapes.py @@ -967,6 +967,7 @@ class ResourceConfig(Base): 
volume_kms_key_id: The Amazon Web Services KMS key that SageMaker uses to encrypt data on the storage volume attached to the ML compute instance(s) that run the training job. Certain Nitro-based instances include local storage, dependent on the instance type. Local storage volumes are encrypted using a hardware module on the instance. You can't request a VolumeKmsKeyId when using an instance type with local storage. For a list of instance types that support local instance storage, see Instance Store Volumes. For more information about local instance storage encryption, see SSD Instance Store Volumes. The VolumeKmsKeyId can be in any of the following formats: // KMS Key ID "1234abcd-12ab-34cd-56ef-1234567890ab" // Amazon Resource Name (ARN) of a KMS Key "arn:aws:kms:us-west-2:111122223333:key/1234abcd-12ab-34cd-56ef-1234567890ab" keep_alive_period_in_seconds: The duration of time in seconds to retain configured resources in a warm pool for subsequent training jobs. instance_groups: The configuration of a heterogeneous cluster in JSON format. + training_plan_arn: The Amazon Resource Name (ARN); of the training plan to use for this resource configuration. """ volume_size_in_gb: int @@ -975,6 +976,7 @@ class ResourceConfig(Base): volume_kms_key_id: Optional[str] = Unassigned() keep_alive_period_in_seconds: Optional[int] = Unassigned() instance_groups: Optional[List[InstanceGroup]] = Unassigned() + training_plan_arn: Optional[str] = Unassigned() class StoppingCondition(Base): @@ -3140,6 +3142,9 @@ class ClusterInstanceGroupDetails(Base): threads_per_core: The number you specified to TreadsPerCore in CreateCluster for enabling or disabling multithreading. For instance types that support multithreading, you can specify 1 for disabling multithreading and 2 for enabling multithreading. For more information, see the reference table of CPU cores and threads per CPU core per instance type in the Amazon Elastic Compute Cloud User Guide. 
instance_storage_configs: The additional storage configurations for the instances in the SageMaker HyperPod cluster instance group. on_start_deep_health_checks: A flag indicating whether deep health checks should be performed when the cluster instance group is created or updated. + status: The current status of the cluster instance group. InService: The instance group is active and healthy. Creating: The instance group is being provisioned. Updating: The instance group is being updated. Failed: The instance group has failed to provision or is no longer healthy. Degraded: The instance group is degraded, meaning that some instances have failed to provision or are no longer healthy. Deleting: The instance group is being deleted. + training_plan_arn: The Amazon Resource Name (ARN); of the training plan associated with this cluster instance group. For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan . + training_plan_status: The current status of the training plan associated with this cluster instance group. override_vpc_config """ @@ -3152,6 +3157,9 @@ class ClusterInstanceGroupDetails(Base): threads_per_core: Optional[int] = Unassigned() instance_storage_configs: Optional[List[ClusterInstanceStorageConfig]] = Unassigned() on_start_deep_health_checks: Optional[List[str]] = Unassigned() + status: Optional[str] = Unassigned() + training_plan_arn: Optional[str] = Unassigned() + training_plan_status: Optional[str] = Unassigned() override_vpc_config: Optional[VpcConfig] = Unassigned() @@ -3170,6 +3178,7 @@ class ClusterInstanceGroupSpecification(Base): threads_per_core: Specifies the value for Threads per core. For instance types that support multithreading, you can specify 1 for disabling multithreading and 2 for enabling multithreading. For instance types that doesn't support multithreading, specify 1. 
For more information, see the reference table of CPU cores and threads per CPU core per instance type in the Amazon Elastic Compute Cloud User Guide. instance_storage_configs: Specifies the additional storage configurations for the instances in the SageMaker HyperPod cluster instance group. on_start_deep_health_checks: A flag indicating whether deep health checks should be performed when the cluster instance group is created or updated. + training_plan_arn: The Amazon Resource Name (ARN); of the training plan to use for this cluster instance group. For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan . override_vpc_config """ @@ -3181,6 +3190,7 @@ class ClusterInstanceGroupSpecification(Base): threads_per_core: Optional[int] = Unassigned() instance_storage_configs: Optional[List[ClusterInstanceStorageConfig]] = Unassigned() on_start_deep_health_checks: Optional[List[str]] = Unassigned() + training_plan_arn: Optional[str] = Unassigned() override_vpc_config: Optional[VpcConfig] = Unassigned() @@ -3296,6 +3306,33 @@ class ClusterOrchestrator(Base): eks: ClusterOrchestratorEksConfig +class ClusterSchedulerConfigSummary(Base): + """ + ClusterSchedulerConfigSummary + Summary of the cluster policy. + + Attributes + ---------------------- + cluster_scheduler_config_arn: ARN of the cluster policy. + cluster_scheduler_config_id: ID of the cluster policy. + cluster_scheduler_config_version: Version of the cluster policy. + name: Name of the cluster policy. + creation_time: Creation time of the cluster policy. + last_modified_time: Last modified time of the cluster policy. + status: Status of the cluster policy. + cluster_arn: ARN of the cluster. 
+ """ + + cluster_scheduler_config_arn: str + cluster_scheduler_config_id: str + name: str + creation_time: datetime.datetime + status: str + cluster_scheduler_config_version: Optional[int] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + cluster_arn: Optional[str] = Unassigned() + + class ClusterSummary(Base): """ ClusterSummary @@ -3307,12 +3344,14 @@ class ClusterSummary(Base): cluster_name: The name of the SageMaker HyperPod cluster. creation_time: The time when the SageMaker HyperPod cluster is created. cluster_status: The status of the SageMaker HyperPod cluster. + training_plan_arns: A list of Amazon Resource Names (ARNs) of the training plans associated with this cluster. For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan . """ cluster_arn: str cluster_name: Union[str, object] creation_time: datetime.datetime cluster_status: str + training_plan_arns: Optional[List[str]] = Unassigned() class CustomImage(Base): @@ -3510,6 +3549,101 @@ class CompilationJobSummary(Base): last_modified_time: Optional[datetime.datetime] = Unassigned() +class ComputeQuotaResourceConfig(Base): + """ + ComputeQuotaResourceConfig + Configuration of the resources used for the compute allocation definition. + + Attributes + ---------------------- + instance_type: The instance type of the instance group for the cluster. + count: The number of instances to add to the instance group of a SageMaker HyperPod cluster. + """ + + instance_type: str + count: int + + +class ResourceSharingConfig(Base): + """ + ResourceSharingConfig + Resource sharing configuration. + + Attributes + ---------------------- + strategy: The strategy of how idle compute is shared within the cluster. The following are the options of strategies. DontLend: entities do not lend idle compute. Lend: entities can lend idle compute to entities that can borrow. 
LendandBorrow: entities can lend idle compute and borrow idle compute from other entities. Default is LendandBorrow. + borrow_limit: The limit on how much idle compute can be borrowed.The values can be 1 - 500 percent of idle compute that the team is allowed to borrow. Default is 50. + """ + + strategy: str + borrow_limit: Optional[int] = Unassigned() + + +class ComputeQuotaConfig(Base): + """ + ComputeQuotaConfig + Configuration of the compute allocation definition for an entity. This includes the resource sharing option and the setting to preempt low priority tasks. + + Attributes + ---------------------- + compute_quota_resources: Allocate compute resources by instance types. + resource_sharing_config: Resource sharing configuration. This defines how an entity can lend and borrow idle compute with other entities within the cluster. + preempt_team_tasks: Allows workloads from within an entity to preempt same-team workloads. When set to LowerPriority, the entity's lower priority tasks are preempted by their own higher priority tasks. Default is LowerPriority. + """ + + compute_quota_resources: Optional[List[ComputeQuotaResourceConfig]] = Unassigned() + resource_sharing_config: Optional[ResourceSharingConfig] = Unassigned() + preempt_team_tasks: Optional[str] = Unassigned() + + +class ComputeQuotaTarget(Base): + """ + ComputeQuotaTarget + The target entity to allocate compute resources to. + + Attributes + ---------------------- + team_name: Name of the team to allocate compute resources to. + fair_share_weight: Assigned entity fair-share weight. Idle compute will be shared across entities based on these assigned weights. This weight is only used when FairShare is enabled. A weight of 0 is the lowest priority and 100 is the highest. Weight 0 is the default. + """ + + team_name: str + fair_share_weight: Optional[int] = Unassigned() + + +class ComputeQuotaSummary(Base): + """ + ComputeQuotaSummary + Summary of the compute allocation definition. 
+ + Attributes + ---------------------- + compute_quota_arn: ARN of the compute allocation definition. + compute_quota_id: ID of the compute allocation definition. + name: Name of the compute allocation definition. + compute_quota_version: Version of the compute allocation definition. + status: Status of the compute allocation definition. + cluster_arn: ARN of the cluster. + compute_quota_config: Configuration of the compute allocation definition. This includes the resource sharing option, and the setting to preempt low priority tasks. + compute_quota_target: The target entity to allocate compute resources to. + activation_state: The state of the compute allocation being described. Use to enable or disable compute allocation. Default is Enabled. + creation_time: Creation time of the compute allocation definition. + last_modified_time: Last modified time of the compute allocation definition. + """ + + compute_quota_arn: str + compute_quota_id: str + name: str + status: str + compute_quota_target: ComputeQuotaTarget + creation_time: datetime.datetime + compute_quota_version: Optional[int] = Unassigned() + cluster_arn: Optional[str] = Unassigned() + compute_quota_config: Optional[ComputeQuotaConfig] = Unassigned() + activation_state: Optional[str] = Unassigned() + last_modified_time: Optional[datetime.datetime] = Unassigned() + + class ConditionStepMetadata(Base): """ ConditionStepMetadata @@ -3840,6 +3974,36 @@ class ModelDeployConfig(Base): endpoint_name: Optional[Union[str, object]] = Unassigned() +class PriorityClass(Base): + """ + PriorityClass + Priority class configuration. When included in PriorityClasses, these class configurations define how tasks are queued. + + Attributes + ---------------------- + name: Name of the priority class. + weight: Weight of the priority class. The value is within a range from 0 to 100, where 0 is the default. A weight of 0 is the lowest priority and 100 is the highest. Weight 0 is the default. 
+ """ + + name: str + weight: int + + +class SchedulerConfig(Base): + """ + SchedulerConfig + Cluster policy configuration. This policy is used for task prioritization and fair-share allocation. This helps prioritize critical workloads and distributes idle compute across entities. + + Attributes + ---------------------- + priority_classes: List of the priority classes, PriorityClass, of the cluster policy. When specified, these class configurations define how tasks are queued. + fair_share: When enabled, entities borrow idle compute based on their assigned FairShareWeight. When disabled, entities borrow idle compute based on a first-come first-serve basis. Default is Enabled. + """ + + priority_classes: Optional[List[PriorityClass]] = Unassigned() + fair_share: Optional[str] = Unassigned() + + class InputConfig(Base): """ InputConfig @@ -4344,6 +4508,21 @@ class EFSFileSystemConfig(Base): file_system_path: Optional[str] = Unassigned() +class FSxLustreFileSystemConfig(Base): + """ + FSxLustreFileSystemConfig + The settings for assigning a custom Amazon FSx for Lustre file system to a user profile or space for an Amazon SageMaker Domain. + + Attributes + ---------------------- + file_system_id: The globally unique, 17-digit, ID of the file system, assigned by Amazon FSx for Lustre. + file_system_path: The path to the file system directory that is accessible in Amazon SageMaker Studio. Permitted users can access only this directory and below. + """ + + file_system_id: str + file_system_path: Optional[str] = Unassigned() + + class CustomFileSystemConfig(Base): """ CustomFileSystemConfig @@ -4352,9 +4531,11 @@ class CustomFileSystemConfig(Base): Attributes ---------------------- efs_file_system_config: The settings for a custom Amazon EFS file system. + f_sx_lustre_file_system_config: The settings for a custom Amazon FSx for Lustre file system. 
""" efs_file_system_config: Optional[EFSFileSystemConfig] = Unassigned() + f_sx_lustre_file_system_config: Optional[FSxLustreFileSystemConfig] = Unassigned() class HiddenSageMakerImage(Base): @@ -6701,6 +6882,34 @@ class OptimizationVpcConfig(Base): subnets: List[str] +class PartnerAppMaintenanceConfig(Base): + """ + PartnerAppMaintenanceConfig + Maintenance configuration settings for the SageMaker Partner AI App. + + Attributes + ---------------------- + maintenance_window_start: The day and time of the week in Coordinated Universal Time (UTC) 24-hour standard time that weekly maintenance updates are scheduled. This value must take the following format: 3-letter-day:24-h-hour:minute. For example: TUE:03:30. + """ + + maintenance_window_start: Optional[str] = Unassigned() + + +class PartnerAppConfig(Base): + """ + PartnerAppConfig + Configuration settings for the SageMaker Partner AI App. + + Attributes + ---------------------- + admin_users: The list of users that are given admin access to the SageMaker Partner AI App. + arguments: This is a map of required inputs for a SageMaker Partner AI App. Based on the application type, the map is populated with a key and value pair that is specific to the user and application. + """ + + admin_users: Optional[List[str]] = Unassigned() + arguments: Optional[Dict[str, str]] = Unassigned() + + class PipelineDefinitionS3Location(Base): """ PipelineDefinitionS3Location @@ -7082,6 +7291,19 @@ class EFSFileSystem(Base): file_system_id: str +class FSxLustreFileSystem(Base): + """ + FSxLustreFileSystem + A custom file system in Amazon FSx for Lustre. + + Attributes + ---------------------- + file_system_id: Amazon FSx for Lustre file system ID. + """ + + file_system_id: str + + class CustomFileSystem(Base): """ CustomFileSystem @@ -7090,9 +7312,11 @@ class CustomFileSystem(Base): Attributes ---------------------- efs_file_system: A custom file system in Amazon EFS. 
+ f_sx_lustre_file_system: A custom file system in Amazon FSx for Lustre. """ efs_file_system: Optional[EFSFileSystem] = Unassigned() + f_sx_lustre_file_system: Optional[FSxLustreFileSystem] = Unassigned() class SpaceSettings(Base): @@ -8464,6 +8688,21 @@ class OptimizationOutput(Base): recommended_inference_image: Optional[str] = Unassigned() +class ErrorInfo(Base): + """ + ErrorInfo + This is an error field object that contains the error code and the reason for an operation failure. + + Attributes + ---------------------- + code: The error code for an invalid or failed operation. + reason: The failure reason for the operation. + """ + + code: Optional[str] = Unassigned() + reason: Optional[str] = Unassigned() + + class DescribePipelineDefinitionForExecutionResponse(Base): """ DescribePipelineDefinitionForExecutionResponse @@ -8631,6 +8870,35 @@ class ProfilerRuleEvaluationStatus(Base): last_modified_time: Optional[datetime.datetime] = Unassigned() +class ReservedCapacitySummary(Base): + """ + ReservedCapacitySummary + Details of a reserved capacity for the training plan. For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan . + + Attributes + ---------------------- + reserved_capacity_arn: The Amazon Resource Name (ARN); of the reserved capacity. + instance_type: The instance type for the reserved capacity. + total_instance_count: The total number of instances in the reserved capacity. + status: The current status of the reserved capacity. + availability_zone: The availability zone for the reserved capacity. + duration_hours: The number of whole hours in the total duration for this reserved capacity. + duration_minutes: The additional minutes beyond whole hours in the total duration for this reserved capacity. + start_time: The start time of the reserved capacity. + end_time: The end time of the reserved capacity. 
+ """ + + reserved_capacity_arn: str + instance_type: str + total_instance_count: int + status: str + availability_zone: Optional[str] = Unassigned() + duration_hours: Optional[int] = Unassigned() + duration_minutes: Optional[int] = Unassigned() + start_time: Optional[datetime.datetime] = Unassigned() + end_time: Optional[datetime.datetime] = Unassigned() + + class TrialComponentSource(Base): """ TrialComponentSource @@ -10537,6 +10805,27 @@ class OptimizationJobSummary(Base): last_modified_time: Optional[datetime.datetime] = Unassigned() +class PartnerAppSummary(Base): + """ + PartnerAppSummary + A subset of information related to a SageMaker Partner AI App. This information is used as part of the ListPartnerApps API response. + + Attributes + ---------------------- + arn: The ARN of the SageMaker Partner AI App. + name: The name of the SageMaker Partner AI App. + type: The type of SageMaker Partner AI App to create. Must be one of the following: lakera-guard, comet, deepchecks-llm-evaluation, or fiddler. + status: The status of the SageMaker Partner AI App. + creation_time: The creation time of the SageMaker Partner AI App. + """ + + arn: Optional[str] = Unassigned() + name: Optional[str] = Unassigned() + type: Optional[str] = Unassigned() + status: Optional[str] = Unassigned() + creation_time: Optional[datetime.datetime] = Unassigned() + + class TrainingJobStepMetadata(Base): """ TrainingJobStepMetadata @@ -10975,6 +11264,7 @@ class TrainingJobSummary(Base): training_job_status: The status of the training job. secondary_status: The secondary status of the training job. warm_pool_status: The status of the warm pool associated with the training job. + training_plan_arn: The Amazon Resource Name (ARN); of the training plan associated with this training job. For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan . 
""" training_job_name: Union[str, object] @@ -10985,6 +11275,63 @@ class TrainingJobSummary(Base): last_modified_time: Optional[datetime.datetime] = Unassigned() secondary_status: Optional[str] = Unassigned() warm_pool_status: Optional[WarmPoolStatus] = Unassigned() + training_plan_arn: Optional[str] = Unassigned() + + +class TrainingPlanFilter(Base): + """ + TrainingPlanFilter + A filter to apply when listing or searching for training plans. For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan . + + Attributes + ---------------------- + name: The name of the filter field (e.g., Status, InstanceType). + value: The value to filter by for the specified field. + """ + + name: str + value: str + + +class TrainingPlanSummary(Base): + """ + TrainingPlanSummary + Details of the training plan. For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan . + + Attributes + ---------------------- + training_plan_arn: The Amazon Resource Name (ARN); of the training plan. + training_plan_name: The name of the training plan. + status: The current status of the training plan (e.g., Pending, Active, Expired). To see the complete list of status values available for a training plan, refer to the Status attribute within the TrainingPlanSummary object. + status_message: A message providing additional information about the current status of the training plan. + duration_hours: The number of whole hours in the total duration for this training plan. + duration_minutes: The additional minutes beyond whole hours in the total duration for this training plan. + start_time: The start time of the training plan. + end_time: The end time of the training plan. + upfront_fee: The upfront fee for the training plan. + currency_code: The currency code for the upfront fee (e.g., USD). 
+ total_instance_count: The total number of instances reserved in this training plan. + available_instance_count: The number of instances currently available for use in this training plan. + in_use_instance_count: The number of instances currently in use from this training plan. + target_resources: The target resources (e.g., training jobs, HyperPod clusters) that can use this training plan. Training plans are specific to their target resource. A training plan designed for SageMaker training jobs can only be used to schedule and run training jobs. A training plan for HyperPod clusters can be used exclusively to provide compute resources to a cluster's instance group. + reserved_capacity_summaries: A list of reserved capacities associated with this training plan, including details such as instance types, counts, and availability zones. + """ + + training_plan_arn: str + training_plan_name: Union[str, object] + status: str + status_message: Optional[str] = Unassigned() + duration_hours: Optional[int] = Unassigned() + duration_minutes: Optional[int] = Unassigned() + start_time: Optional[datetime.datetime] = Unassigned() + end_time: Optional[datetime.datetime] = Unassigned() + upfront_fee: Optional[str] = Unassigned() + currency_code: Optional[str] = Unassigned() + total_instance_count: Optional[int] = Unassigned() + available_instance_count: Optional[int] = Unassigned() + in_use_instance_count: Optional[int] = Unassigned() + target_resources: Optional[List[str]] = Unassigned() + reserved_capacity_summaries: Optional[List[ReservedCapacitySummary]] = Unassigned() class TransformJobSummary(Base): @@ -11741,6 +12088,31 @@ class RenderingError(Base): message: str +class ReservedCapacityOffering(Base): + """ + ReservedCapacityOffering + Details about a reserved capacity offering for a training plan offering. For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan . 
+ + Attributes + ---------------------- + instance_type: The instance type for the reserved capacity offering. + instance_count: The number of instances in the reserved capacity offering. + availability_zone: The availability zone for the reserved capacity offering. + duration_hours: The number of whole hours in the total duration for this reserved capacity offering. + duration_minutes: The additional minutes beyond whole hours in the total duration for this reserved capacity offering. + start_time: The start time of the reserved capacity offering. + end_time: The end time of the reserved capacity offering. + """ + + instance_type: str + instance_count: int + availability_zone: Optional[str] = Unassigned() + duration_hours: Optional[int] = Unassigned() + duration_minutes: Optional[int] = Unassigned() + start_time: Optional[datetime.datetime] = Unassigned() + end_time: Optional[datetime.datetime] = Unassigned() + + class ResourceConfigForUpdate(Base): """ ResourceConfigForUpdate @@ -12070,6 +12442,35 @@ class VisibilityConditions(Base): value: Optional[str] = Unassigned() +class TrainingPlanOffering(Base): + """ + TrainingPlanOffering + Details about a training plan offering. For more information about how to reserve GPU capacity for your SageMaker HyperPod clusters using Amazon SageMaker Training Plan, see CreateTrainingPlan . + + Attributes + ---------------------- + training_plan_offering_id: The unique identifier for this training plan offering. + target_resources: The target resources (e.g., SageMaker Training Jobs, SageMaker HyperPod) for this training plan offering. Training plans are specific to their target resource. A training plan designed for SageMaker training jobs can only be used to schedule and run training jobs. A training plan for HyperPod clusters can be used exclusively to provide compute resources to a cluster's instance group. 
+ requested_start_time_after: The requested start time that the user specified when searching for the training plan offering. + requested_end_time_before: The requested end time that the user specified when searching for the training plan offering. + duration_hours: The number of whole hours in the total duration for this training plan offering. + duration_minutes: The additional minutes beyond whole hours in the total duration for this training plan offering. + upfront_fee: The upfront fee for this training plan offering. + currency_code: The currency code for the upfront fee (e.g., USD). + reserved_capacity_offerings: A list of reserved capacity offerings associated with this training plan offering. + """ + + training_plan_offering_id: str + target_resources: List[str] + requested_start_time_after: Optional[datetime.datetime] = Unassigned() + requested_end_time_before: Optional[datetime.datetime] = Unassigned() + duration_hours: Optional[int] = Unassigned() + duration_minutes: Optional[int] = Unassigned() + upfront_fee: Optional[str] = Unassigned() + currency_code: Optional[str] = Unassigned() + reserved_capacity_offerings: Optional[List[ReservedCapacityOffering]] = Unassigned() + + class ServiceCatalogProvisioningUpdateDetails(Base): """ ServiceCatalogProvisioningUpdateDetails diff --git a/src/sagemaker_core/tools/api_coverage.json b/src/sagemaker_core/tools/api_coverage.json index 4d6fdbd0..833ea347 100644 --- a/src/sagemaker_core/tools/api_coverage.json +++ b/src/sagemaker_core/tools/api_coverage.json @@ -1 +1 @@ -{"SupportedAPIs": 340, "UnsupportedAPIs": 5} \ No newline at end of file +{"SupportedAPIs": 359, "UnsupportedAPIs": 6} \ No newline at end of file diff --git a/tst/generated/test_resources.py b/tst/generated/test_resources.py index 90b14d47..081247ec 100644 --- a/tst/generated/test_resources.py +++ b/tst/generated/test_resources.py @@ -112,6 +112,7 @@ def test_resources(self, session, mock_transform): "JobDefinitionSummaries": [summary], 
f"{name}SummaryList": [summary], f"{name}s": [summary], + f"Summaries": [summary], } if name == "MlflowTrackingServer": summary_response = {"TrackingServerSummaries": [summary]}