Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -4206,7 +4206,7 @@ def get_model_package_args(
description (str): Model Package description (default: None).
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs
(default: None).
container_def_list (list): A list of container defintiions (default: None).
container_def_list (list): A list of container definitions (default: None).
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
Expand Down
8 changes: 7 additions & 1 deletion src/sagemaker/workflow/_repack_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,15 @@
# we'll go ahead and use the copy_tree function anyways because this
# repacking is some short-lived hackery, right??
from distutils.dir_util import copy_tree
from typing import Optional


def repack(inference_script, model_archive, dependencies=None, source_dir=None): # pragma: no cover
def repack(
inference_script: str,
model_archive: str,
dependencies: Optional[str] = None,
source_dir: Optional[str] = None,
): # pragma: no cover
"""Repack custom dependencies and code into an existing model TAR archive

Args:
Expand Down
71 changes: 39 additions & 32 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
import shutil
import tarfile
import tempfile
from typing import List, Union
from typing import List, Union, Dict, Optional
from sagemaker.session import Session
from sagemaker import image_uris
from sagemaker.inputs import TrainingInput
from sagemaker.estimator import EstimatorBase
Expand All @@ -33,6 +34,9 @@
)
from sagemaker.utils import _save_model, download_file_from_url
from sagemaker.workflow.retry import RetryPolicy
from sagemaker.model_metrics import ModelMetrics
from sagemaker.metadata_properties import MetadataProperties
from sagemaker.drift_check_baselines import DriftCheckBaselines

FRAMEWORK_VERSION = "0.23-1"
INSTANCE_TYPE = "ml.m5.large"
Expand All @@ -49,18 +53,18 @@ class _RepackModelStep(TrainingStep):
def __init__(
self,
name: str,
sagemaker_session,
role,
sagemaker_session: Session,
role: str,
model_data: str,
entry_point: str,
display_name: str = None,
description: str = None,
source_dir: str = None,
dependencies: List = None,
depends_on: Union[List[str], List[Step]] = None,
retry_policies: List[RetryPolicy] = None,
subnets=None,
security_group_ids=None,
display_name: Optional[str] = None,
description: Optional[str] = None,
source_dir: Optional[str] = None,
dependencies: Optional[List] = None,
depends_on: Optional[Union[List[str], List[Step]]] = None,
retry_policies: Optional[List[RetryPolicy]] = None,
subnets: Optional[List[str]] = None,
security_group_ids: Optional[List[str]] = None,
**kwargs,
):
"""Base class initializer.
Expand Down Expand Up @@ -237,7 +241,7 @@ def _inject_repack_script(self):
def arguments(self) -> RequestType:
"""The arguments dict that are used to call `create_training_job`.

This first prepares the source bundle for repackinglby placing artifacts
This first prepares the source bundle for repacking by placing artifacts
in locations which the training container will make available to the
repacking script and then gets the arguments for the training job.
"""
Expand Down Expand Up @@ -278,26 +282,26 @@ class _RegisterModelStep(ConfigurableRetryStep):
def __init__(
self,
name: str,
content_types,
response_types,
inference_instances,
transform_instances,
estimator: EstimatorBase = None,
model_data=None,
model_package_group_name=None,
model_metrics=None,
metadata_properties=None,
approval_status="PendingManualApproval",
image_uri=None,
compile_model_family=None,
display_name: str = None,
description=None,
depends_on: Union[List[str], List[Step]] = None,
retry_policies: List[RetryPolicy] = None,
tags=None,
container_def_list=None,
drift_check_baselines=None,
customer_metadata_properties=None,
content_types: List,
response_types: List,
inference_instances: List,
transform_instances: List,
estimator: Optional[EstimatorBase] = None,
model_data: Optional[str] = None,
model_package_group_name: Optional[str] = None,
model_metrics: Optional[ModelMetrics] = None,
metadata_properties: Optional[MetadataProperties] = None,
approval_status: str = "PendingManualApproval",
image_uri: Optional[str] = None,
compile_model_family: Optional[str] = None,
display_name: Optional[str] = None,
description: Optional[str] = None,
depends_on: Optional[Union[List[str], List[Step]]] = None,
retry_policies: Optional[List[RetryPolicy]] = None,
tags: Optional[List[Dict[str, str]]] = None,
container_def_list: Optional[List] = None,
drift_check_baselines: Optional[DriftCheckBaselines] = None,
customer_metadata_properties: Optional[Dict[str, str]] = None,
**kwargs,
):
"""Constructor of a register model step.
Expand Down Expand Up @@ -334,6 +338,9 @@ def __init__(
depends_on (List[str] or List[Step]): A list of step names or instances
this step depends on
retry_policies (List[RetryPolicy]): The list of retry policies for the current step
tags (List[dict[str, str]]): A list of dictionaries containing key-value pairs
(default: None).
container_def_list (list): A list of container definitions (default: None).
drift_check_baselines (DriftCheckBaselines): DriftCheckBaselines object (default: None).
customer_metadata_properties (dict[str, str]): A dictionary of key-value paired
metadata properties (default: None).
Expand Down
16 changes: 8 additions & 8 deletions src/sagemaker/workflow/callback_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for workflow."""
"""The CallbackStep definitions for workflow."""
from __future__ import absolute_import

from typing import List, Dict, Union
from typing import List, Dict, Union, Optional
from enum import Enum

import attr
Expand Down Expand Up @@ -81,12 +81,12 @@ def __init__(
self,
name: str,
sqs_queue_url: str,
inputs: dict,
inputs: Dict,
outputs: List[CallbackOutput],
display_name: str = None,
description: str = None,
cache_config: CacheConfig = None,
depends_on: Union[List[str], List[Step]] = None,
display_name: Optional[str] = None,
description: Optional[str] = None,
cache_config: Optional[CacheConfig] = None,
depends_on: Optional[Union[List[str], List[Step]]] = None,
):
"""Constructs a CallbackStep.

Expand All @@ -99,7 +99,7 @@ def __init__(
display_name (str): The display name of the callback step.
description (str): The description of the callback step.
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.
depends_on (List[str] or List[Step]): A list of step names or step instances
depends_on (Union[List[str], List[Step]]): A list of step names or step instances
this `sagemaker.workflow.steps.CallbackStep` depends on
"""
super(CallbackStep, self).__init__(
Expand Down
32 changes: 17 additions & 15 deletions src/sagemaker/workflow/check_job_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@
from __future__ import absolute_import

import logging
from typing import Optional
from typing import Optional, Dict, List

from sagemaker import Session
from sagemaker.network import NetworkConfig
from sagemaker.model_monitor import (
ModelMonitor,
DefaultModelMonitor,
Expand All @@ -31,18 +32,18 @@ class CheckJobConfig:

def __init__(
self,
role,
instance_count=1,
instance_type="ml.m5.xlarge",
volume_size_in_gb=30,
volume_kms_key=None,
output_kms_key=None,
max_runtime_in_seconds=None,
base_job_name=None,
sagemaker_session=None,
env=None,
tags=None,
network_config=None,
role: str,
instance_count: int = 1,
instance_type: str = "ml.m5.xlarge",
volume_size_in_gb: int = 30,
volume_kms_key: str = None,
output_kms_key: str = None,
max_runtime_in_seconds: int = None,
base_job_name: str = None,
sagemaker_session: Session = None,
env: Dict[str, str] = None,
tags: List[Dict[str, str]] = None,
network_config: NetworkConfig = None,
):
"""Constructs a CheckJobConfig instance.

Expand All @@ -65,8 +66,9 @@ def __init__(
manages interactions with Amazon SageMaker APIs and any other
AWS services needed (default: None). If not specified, one is
created using the default AWS configuration chain.
env (dict): Environment variables to be passed to the job (default: None).
tags ([dict]): List of tags to be passed to the job (default: None).
env (Dict[str, str]): Environment variables to be passed to the job (default: None).
tags (List[Dict[str, str]]): A list of dictionaries containing key-value pairs
(default: None).
network_config (sagemaker.network.NetworkConfig): A NetworkConfig
object that configures network isolation, encryption of
inter-container traffic, security group IDs, and subnets (default: None).
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/workflow/clarify_check_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for workflow."""
"""The ClarifyCheckStep definitions for workflow."""
from __future__ import absolute_import

import copy
Expand Down Expand Up @@ -180,7 +180,7 @@ def __init__(
description (str): The description of the ClarifyCheckStep step (default: None).
cache_config (CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance
(default: None).
depends_on (List[str] or List[Step]): A list of step names or step instances
depends_on (Union[List[str], List[Step]]): A list of step names or step instances
this `sagemaker.workflow.steps.ClarifyCheckStep` depends on (default: None).
"""
if (
Expand Down
18 changes: 10 additions & 8 deletions src/sagemaker/workflow/condition_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for workflow."""
"""The ConditionStep definitions for workflow."""
from __future__ import absolute_import

from typing import List, Union
from typing import List, Union, Optional

import attr

Expand Down Expand Up @@ -41,12 +41,12 @@ class ConditionStep(Step):
def __init__(
self,
name: str,
depends_on: Union[List[str], List[Step]] = None,
display_name: str = None,
description: str = None,
conditions: List[Condition] = None,
if_steps: List[Union[Step, StepCollection]] = None,
else_steps: List[Union[Step, StepCollection]] = None,
display_name: Optional[str] = None,
description: Optional[str] = None,
conditions: Optional[List[Condition]] = None,
depends_on: Optional[Union[List[str], List[Step]]] = None,
if_steps: Optional[List[Union[Step, StepCollection]]] = None,
else_steps: Optional[List[Union[Step, StepCollection]]] = None,
):
"""Construct a ConditionStep for pipelines to support conditional branching.

Expand All @@ -60,6 +60,8 @@ def __init__(
description (str): The description of the condition step.
conditions (List[Condition]): A list of `sagemaker.workflow.conditions.Condition`
instances.
depends_on (Union[List[str], List[Step]]): A list of step names or step instances
this `sagemaker.workflow.steps.ConditionStep` depends on (default: None).
if_steps (List[Union[Step, StepCollection]]): A list of `sagemaker.workflow.steps.Step`
or `sagemaker.workflow.step_collections.StepCollection` instances that are
marked as ready for execution if the list of conditions evaluates to True.
Expand Down
16 changes: 10 additions & 6 deletions src/sagemaker/workflow/emr_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for workflow."""
"""The EMRStep definitions for workflow."""
from __future__ import absolute_import

from typing import List
from typing import List, Union

from sagemaker.workflow.entities import (
RequestType,
Expand All @@ -28,17 +28,21 @@ class EMRStepConfig:
"""Config for a Hadoop Jar step."""

def __init__(
self, jar, args: List[str] = None, main_class: str = None, properties: List[dict] = None
self,
jar: str,
args: List[str] = None,
main_class: str = None,
properties: List[dict] = None,
):
"""Create a definition for input data used by an EMR cluster(job flow) step.

See AWS documentation on the ``StepConfig`` API for more details on the parameters.

Args:
jar(str): A path to a JAR file run during the step.
args(List[str]):
A list of command line arguments passed to
the JAR file's main function when executed.
jar(str): A path to a JAR file run during the step.
main_class(str): The name of the main class in the specified Java file.
properties(List[dict]): A list of key-value pairs that are set when the step runs.
"""
Expand Down Expand Up @@ -70,7 +74,7 @@ def __init__(
description: str,
cluster_id: str,
step_config: EMRStepConfig,
depends_on: List[str] = None,
depends_on: Union[List[str], List[Step]] = None,
cache_config: CacheConfig = None,
):
"""Constructs an EMRStep.
Expand All @@ -81,7 +85,7 @@ def __init__(
description(str): The description of the EMR step.
cluster_id(str): The ID of the running EMR cluster.
step_config(EMRStepConfig): One StepConfig to be executed by the job flow.
depends_on(List[str]):
depends_on(Union[List[str], List[Step]]):
A list of step names or step instances this `sagemaker.workflow.steps.EMRStep` depends on
cache_config(CacheConfig): A `sagemaker.workflow.steps.CacheConfig` instance.

Expand Down
7 changes: 3 additions & 4 deletions src/sagemaker/workflow/fail_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The `Step` definitions for SageMaker Pipelines Workflows."""
"""The FailStep definitions for workflow."""
from __future__ import absolute_import

from typing import List, Union
Expand Down Expand Up @@ -45,9 +45,8 @@ def __init__(
display_name (str): The display name of the `FailStep`.
The display name provides better UI readability. (default: None).
description (str): The description of the `FailStep` (default: None).
depends_on (List[str] or List[Step]): A list of `Step` names or `Step` instances
that this `FailStep` depends on.
If a listed `Step` name does not exist, an error is returned (default: None).
depends_on (Union[List[str], List[Step]]): A list of step names or step instances
this `sagemaker.workflow.steps.FailStep` depends on
"""
super(FailStep, self).__init__(
name, display_name, description, StepTypeEnum.FAIL, depends_on
Expand Down
8 changes: 4 additions & 4 deletions src/sagemaker/workflow/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for workflow."""
"""The functions for workflow."""
from __future__ import absolute_import

from typing import List, Union
Expand All @@ -36,13 +36,13 @@ class Join(PipelineVariable):
content_type="text/csv")

Attributes:
values (List[Union[PrimitiveType, Parameter, Expression]]):
The primitive type values, parameters, step properties, expressions to join.
on (str): The string to join the values on (Defaults to "").
values (List[PipelineVariable]):
The PipelineVariable(s) to join.
"""

on: str = attr.ib(factory=str)
values: List = attr.ib(factory=list)
values: List[PipelineVariable] = attr.ib(factory=list)

def to_string(self) -> PipelineVariable:
"""Prompt the pipeline to convert the pipeline variable to String in runtime
Expand Down
Loading