diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 208dc208b4..f31cfd938d 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -18,7 +18,7 @@ import os import uuid from abc import ABCMeta, abstractmethod -from typing import Any, Dict +from typing import Any, Dict, Union, Optional, List from six import string_types, with_metaclass from six.moves.urllib.parse import urlparse @@ -36,6 +36,7 @@ TensorBoardOutputConfig, get_default_profiler_rule, get_rule_container_image_uri, + RuleBase, ) from sagemaker.deprecations import removed_function, removed_kwargs, renamed_kwargs from sagemaker.fw_utils import ( @@ -46,7 +47,7 @@ tar_and_upload_dir, validate_source_dir, ) -from sagemaker.inputs import TrainingInput +from sagemaker.inputs import TrainingInput, FileSystemInput from sagemaker.job import _Job from sagemaker.jumpstart.utils import ( add_jumpstart_tags, @@ -75,6 +76,7 @@ name_from_base, ) from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import ( PipelineSession, runnable_by_pipeline, @@ -105,44 +107,44 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man def __init__( self, - role, - instance_count=None, - instance_type=None, - volume_size=30, - volume_kms_key=None, - max_run=24 * 60 * 60, - input_mode="File", - output_path=None, - output_kms_key=None, - base_job_name=None, - sagemaker_session=None, - tags=None, - subnets=None, - security_group_ids=None, - model_uri=None, - model_channel_name="model", - metric_definitions=None, - encrypt_inter_container_traffic=False, - use_spot_instances=False, - max_wait=None, - checkpoint_s3_uri=None, - checkpoint_local_path=None, - rules=None, - debugger_hook_config=None, - tensorboard_output_config=None, - enable_sagemaker_metrics=None, - enable_network_isolation=False, - profiler_config=None, - disable_profiler=False, - environment=None, - 
max_retry_attempts=None, - source_dir=None, - git_config=None, - hyperparameters=None, - container_log_level=logging.INFO, - code_location=None, - entry_point=None, - dependencies=None, + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + volume_size: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_run: Union[int, PipelineVariable] = 24 * 60 * 60, + input_mode: Union[str, PipelineVariable] = "File", + output_path: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + subnets: Optional[List[Union[str, PipelineVariable]]] = None, + security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, + model_uri: Optional[str] = None, + model_channel_name: Union[str, PipelineVariable] = "model", + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False, + use_spot_instances: Union[bool, PipelineVariable] = False, + max_wait: Optional[Union[int, PipelineVariable]] = None, + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, + rules: Optional[List[RuleBase]] = None, + debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None, + tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + profiler_config: Optional[ProfilerConfig] = None, + disable_profiler: bool = False, + environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + 
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, + source_dir: Optional[str] = None, + git_config: Optional[Dict[str, str]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + code_location: Optional[str] = None, + entry_point: Optional[str] = None, + dependencies: Optional[List[str]] = None, + **kwargs, ): """Initialize an ``EstimatorBase`` instance. @@ -922,7 +924,14 @@ def latest_job_profiler_artifacts_path(self): return None @runnable_by_pipeline - def fit(self, inputs=None, wait=True, logs="All", job_name=None, experiment_config=None): + def fit( + self, + inputs: Optional[Union[str, Dict, TrainingInput, FileSystemInput]] = None, + wait: bool = True, + logs: str = "All", + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + ): """Train a model using the input training dataset. The API calls the Amazon SageMaker CreateTrainingJob API to start @@ -1870,16 +1879,22 @@ def _get_train_args(cls, estimator, inputs, experiment_config): ) train_args["input_mode"] = inputs.config["InputMode"] + # enable_network_isolation may be a pipeline variable place holder object + # which is parsed in execution time if estimator.enable_network_isolation(): - train_args["enable_network_isolation"] = True + train_args["enable_network_isolation"] = estimator.enable_network_isolation() if estimator.max_retry_attempts is not None: train_args["retry_strategy"] = {"MaximumRetryAttempts": estimator.max_retry_attempts} else: train_args["retry_strategy"] = None + # encrypt_inter_container_traffic may be a pipeline variable place holder object + # which is parsed in execution time if estimator.encrypt_inter_container_traffic: - train_args["encrypt_inter_container_traffic"] = True + train_args[ + "encrypt_inter_container_traffic" + ] = estimator.encrypt_inter_container_traffic if isinstance(estimator,
sagemaker.algorithm.AlgorithmEstimator): train_args["algorithm_arn"] = estimator.algorithm_arn @@ -2025,45 +2040,45 @@ class Estimator(EstimatorBase): def __init__( self, - image_uri, - role, - instance_count=None, - instance_type=None, - volume_size=30, - volume_kms_key=None, - max_run=24 * 60 * 60, - input_mode="File", - output_path=None, - output_kms_key=None, - base_job_name=None, - sagemaker_session=None, - hyperparameters=None, - tags=None, - subnets=None, - security_group_ids=None, - model_uri=None, - model_channel_name="model", - metric_definitions=None, - encrypt_inter_container_traffic=False, - use_spot_instances=False, - max_wait=None, - checkpoint_s3_uri=None, - checkpoint_local_path=None, - enable_network_isolation=False, - rules=None, - debugger_hook_config=None, - tensorboard_output_config=None, - enable_sagemaker_metrics=None, - profiler_config=None, - disable_profiler=False, - environment=None, - max_retry_attempts=None, - source_dir=None, - git_config=None, - container_log_level=logging.INFO, - code_location=None, - entry_point=None, - dependencies=None, + image_uri: Union[str, PipelineVariable], + role: str, + instance_count: Optional[Union[int, PipelineVariable]] = None, + instance_type: Optional[Union[str, PipelineVariable]] = None, + volume_size: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_run: Union[int, PipelineVariable] = 24 * 60 * 60, + input_mode: Union[str, PipelineVariable] = "File", + output_path: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + subnets: Optional[List[Union[str, PipelineVariable]]] = None, + security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, 
+ model_uri: Optional[str] = None, + model_channel_name: Union[str, PipelineVariable] = "model", + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + encrypt_inter_container_traffic: Union[bool, PipelineVariable] = False, + use_spot_instances: Union[bool, PipelineVariable] = False, + max_wait: Optional[Union[int, PipelineVariable]] = None, + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + rules: Optional[List[RuleBase]] = None, + debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None, + tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, + profiler_config: Optional[ProfilerConfig] = None, + disable_profiler: bool = False, + environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, + source_dir: Optional[str] = None, + git_config: Optional[Dict[str, str]] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + code_location: Optional[str] = None, + entry_point: Optional[str] = None, + dependencies: Optional[List[str]] = None, **kwargs, ): """Initialize an ``Estimator`` instance. 
@@ -2488,18 +2503,18 @@ class Framework(EstimatorBase): def __init__( self, - entry_point, - source_dir=None, - hyperparameters=None, - container_log_level=logging.INFO, - code_location=None, - image_uri=None, - dependencies=None, - enable_network_isolation=False, - git_config=None, - checkpoint_s3_uri=None, - checkpoint_local_path=None, - enable_sagemaker_metrics=None, + entry_point: str, + source_dir: Optional[str] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + container_log_level: Union[int, PipelineVariable] = logging.INFO, + code_location: Optional[str] = None, + image_uri: Optional[Union[str, PipelineVariable]] = None, + dependencies: Optional[List[str]] = None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + git_config: Optional[Dict[str, str]] = None, + checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, + checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, + enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, **kwargs, ): """Base class initializer. diff --git a/src/sagemaker/network.py b/src/sagemaker/network.py index 1d2ae8c6ca..b3bf72a95a 100644 --- a/src/sagemaker/network.py +++ b/src/sagemaker/network.py @@ -16,6 +16,10 @@ """ from __future__ import absolute_import +from typing import Union, Optional, List + +from sagemaker.workflow.entities import PipelineVariable + class NetworkConfig(object): """Accepts network configuration parameters for conversion to request dict. 
@@ -25,10 +29,10 @@ class NetworkConfig(object): def __init__( self, - enable_network_isolation=False, - security_group_ids=None, - subnets=None, - encrypt_inter_container_traffic=None, + enable_network_isolation: Union[bool, PipelineVariable] = False, + security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, + subnets: Optional[List[Union[str, PipelineVariable]]] = None, + encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None, ): """Initialize a ``NetworkConfig`` instance. diff --git a/src/sagemaker/parameter.py b/src/sagemaker/parameter.py index 52efdeb7c6..79bbc62da2 100644 --- a/src/sagemaker/parameter.py +++ b/src/sagemaker/parameter.py @@ -14,8 +14,10 @@ from __future__ import absolute_import import json +from typing import Union from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.entities import PipelineVariable class ParameterRange(object): @@ -27,7 +29,12 @@ class ParameterRange(object): __all_types__ = ("Continuous", "Categorical", "Integer") - def __init__(self, min_value, max_value, scaling_type="Auto"): + def __init__( + self, + min_value: Union[int, float, PipelineVariable], + max_value: Union[int, float, PipelineVariable], + scaling_type: Union[str, PipelineVariable] = "Auto", + ): """Initialize a parameter range. 
Args: diff --git a/src/sagemaker/processing.py b/src/sagemaker/processing.py index cebe25dbab..1e4cfae4ff 100644 --- a/src/sagemaker/processing.py +++ b/src/sagemaker/processing.py @@ -22,7 +22,7 @@ import pathlib import logging from textwrap import dedent -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union import attr @@ -31,10 +31,12 @@ from sagemaker import s3 from sagemaker.job import _Job from sagemaker.local import LocalSession +from sagemaker.network import NetworkConfig from sagemaker.utils import base_name_from_image, get_config_value, name_from_base from sagemaker.session import Session from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.pipeline_context import runnable_by_pipeline +from sagemaker.workflow.entities import PipelineVariable from sagemaker.dataset_definition.inputs import S3Input, DatasetDefinition from sagemaker.apiutils._base_types import ApiObject from sagemaker.s3 import S3Uploader @@ -47,20 +49,20 @@ class Processor(object): def __init__( self, - role, - image_uri, - instance_count, - instance_type, - entrypoint=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + role: str, + image_uri: Union[str, PipelineVariable], + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + entrypoint: Optional[List[Union[str, PipelineVariable]]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, 
PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """Initializes a ``Processor`` instance. @@ -133,14 +135,14 @@ def __init__( @runnable_by_pipeline def run( self, - inputs=None, - outputs=None, - arguments=None, - wait=True, - logs=True, - job_name=None, - experiment_config=None, - kms_key=None, + inputs: Optional[List["ProcessingInput"]] = None, + outputs: Optional[List["ProcessingOutput"]] = None, + arguments: Optional[List[Union[str, PipelineVariable]]] = None, + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + kms_key: Optional[str] = None, ): """Runs a processing job. @@ -388,20 +390,20 @@ class ScriptProcessor(Processor): def __init__( self, - role, - image_uri, - command, - instance_count, - instance_type, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + role: str, + image_uri: Union[str, PipelineVariable], + command: List[str], + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = None, ): """Initializes a ``ScriptProcessor`` instance. 
@@ -498,15 +500,15 @@ def get_run_args( @runnable_by_pipeline def run( self, - code, - inputs=None, - outputs=None, - arguments=None, - wait=True, - logs=True, - job_name=None, - experiment_config=None, - kms_key=None, + code: str, + inputs: Optional[List["ProcessingInput"]] = None, + outputs: Optional[List["ProcessingOutput"]] = None, + arguments: Optional[List[Union[str, PipelineVariable]]] = None, + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + kms_key: Optional[str] = None, ): """Runs a processing job. @@ -537,6 +539,8 @@ def run( * If both `ExperimentName` and `TrialName` are not supplied the trial component will be unassociated. * `TrialComponentDisplayName` is used for display in Studio. + kms_key (str): The ARN of the KMS key that is used to encrypt the + user code file (default: None). """ normalized_inputs, normalized_outputs = self._normalize_args( job_name=job_name, @@ -1072,16 +1076,16 @@ class ProcessingInput(object): def __init__( self, - source=None, - destination=None, - input_name=None, - s3_data_type="S3Prefix", - s3_input_mode="File", - s3_data_distribution_type="FullyReplicated", - s3_compression_type="None", - s3_input=None, - dataset_definition=None, - app_managed=False, + source: Optional[Union[str, PipelineVariable]] = None, + destination: Optional[Union[str, PipelineVariable]] = None, + input_name: Optional[Union[str, PipelineVariable]] = None, + s3_data_type: Union[str, PipelineVariable] = "S3Prefix", + s3_input_mode: Union[str, PipelineVariable] = "File", + s3_data_distribution_type: Union[str, PipelineVariable] = "FullyReplicated", + s3_compression_type: Union[str, PipelineVariable] = "None", + s3_input: Optional[S3Input] = None, + dataset_definition: Optional[DatasetDefinition] = None, + app_managed: Union[bool, PipelineVariable] = False, ): """Initializes a ``ProcessingInput`` instance. 
@@ -1179,12 +1183,12 @@ class ProcessingOutput(object): def __init__( self, - source=None, - destination=None, - output_name=None, - s3_upload_mode="EndOfJob", - app_managed=False, - feature_store_output=None, + source: Optional[Union[str, PipelineVariable]] = None, + destination: Optional[Union[str, PipelineVariable]] = None, + output_name: Optional[Union[str, PipelineVariable]] = None, + s3_upload_mode: Union[str, PipelineVariable] = "EndOfJob", + app_managed: Union[bool, PipelineVariable] = False, + feature_store_output: Optional["FeatureStoreOutput"] = None, ): """Initializes a ``ProcessingOutput`` instance. @@ -1277,24 +1281,24 @@ class FrameworkProcessor(ScriptProcessor): # Added new (kw)args for estimator. The rest are from ScriptProcessor with same defaults. def __init__( self, - estimator_cls, - framework_version, - role, - instance_count, - instance_type, - py_version="py3", - image_uri=None, - command=None, - volume_size_in_gb=30, - volume_kms_key=None, - output_kms_key=None, - code_location=None, - max_runtime_in_seconds=None, - base_job_name=None, - sagemaker_session=None, - env=None, - tags=None, - network_config=None, + estimator_cls: type, + framework_version: str, + role: str, + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, PipelineVariable], + py_version: str = "py3", + image_uri: Optional[Union[str, PipelineVariable]] = None, + command: Optional[List[str]] = None, + volume_size_in_gb: Union[int, PipelineVariable] = 30, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + code_location: Optional[str] = None, + max_runtime_in_seconds: Optional[Union[int, PipelineVariable]] = None, + base_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + network_config: Optional[NetworkConfig] = 
None, ): """Initializes a ``FrameworkProcessor`` instance. @@ -1486,18 +1490,18 @@ def get_run_args( def run( # type: ignore[override] self, - code, - source_dir=None, - dependencies=None, - git_config=None, - inputs=None, - outputs=None, - arguments=None, - wait=True, - logs=True, - job_name=None, - experiment_config=None, - kms_key=None, + code: str, + source_dir: Optional[str] = None, + dependencies: Optional[List[str]] = None, + git_config: Optional[Dict[str, str]] = None, + inputs: Optional[List[ProcessingInput]] = None, + outputs: Optional[List[ProcessingOutput]] = None, + arguments: Optional[List[Union[str, PipelineVariable]]] = None, + wait: bool = True, + logs: bool = True, + job_name: Optional[str] = None, + experiment_config: Optional[Dict[str, str]] = None, + kms_key: Optional[str] = None, ): """Runs a processing job. diff --git a/src/sagemaker/transformer.py b/src/sagemaker/transformer.py index 36fb86a90b..7bd2f09063 100644 --- a/src/sagemaker/transformer.py +++ b/src/sagemaker/transformer.py @@ -13,10 +13,13 @@ """Placeholder docstring""" from __future__ import absolute_import +from typing import Union, Optional, List, Dict + from botocore import exceptions from sagemaker.job import _Job from sagemaker.session import Session +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline from sagemaker.workflow import is_pipeline_variable from sagemaker.utils import base_name_from_image, name_from_base @@ -27,21 +30,21 @@ class Transformer(object): def __init__( self, - model_name, - instance_count, - instance_type, - strategy=None, - assemble_with=None, - output_path=None, - output_kms_key=None, - accept=None, - max_concurrent_transforms=None, - max_payload=None, - tags=None, - env=None, - base_transform_job_name=None, - sagemaker_session=None, - volume_kms_key=None, + model_name: Union[str, PipelineVariable], + instance_count: Union[int, PipelineVariable], + instance_type: Union[str, 
PipelineVariable], + strategy: Optional[Union[str, PipelineVariable]] = None, + assemble_with: Optional[Union[str, PipelineVariable]] = None, + output_path: Optional[Union[str, PipelineVariable]] = None, + output_kms_key: Optional[Union[str, PipelineVariable]] = None, + accept: Optional[Union[str, PipelineVariable]] = None, + max_concurrent_transforms: Optional[Union[int, PipelineVariable]] = None, + max_payload: Optional[Union[int, PipelineVariable]] = None, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + env: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + base_transform_job_name: Optional[str] = None, + sagemaker_session: Optional[Session] = None, + volume_kms_key: Optional[Union[str, PipelineVariable]] = None, ): """Initialize a ``Transformer``. @@ -111,19 +114,19 @@ def __init__( @runnable_by_pipeline def transform( self, - data, - data_type="S3Prefix", - content_type=None, - compression_type=None, - split_type=None, - job_name=None, - input_filter=None, - output_filter=None, - join_source=None, - experiment_config=None, - model_client_config=None, - wait=True, - logs=True, + data: Union[str, PipelineVariable], + data_type: Union[str, PipelineVariable] = "S3Prefix", + content_type: Optional[Union[str, PipelineVariable]] = None, + compression_type: Optional[Union[str, PipelineVariable]] = None, + split_type: Optional[Union[str, PipelineVariable]] = None, + job_name: Optional[str] = None, + input_filter: Optional[Union[str, PipelineVariable]] = None, + output_filter: Optional[Union[str, PipelineVariable]] = None, + join_source: Optional[Union[str, PipelineVariable]] = None, + experiment_config: Optional[Dict[str, str]] = None, + model_client_config: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + wait: bool = True, + logs: bool = True, ): """Start a new transform job. 
diff --git a/src/sagemaker/tuner.py b/src/sagemaker/tuner.py index f6229172c8..76337b8b4f 100644 --- a/src/sagemaker/tuner.py +++ b/src/sagemaker/tuner.py @@ -19,6 +19,7 @@ import logging from enum import Enum +from typing import Union, Dict, Optional, List, Set import sagemaker from sagemaker.amazon.amazon_estimator import ( @@ -29,8 +30,8 @@ from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa from sagemaker.analytics import HyperparameterTuningJobAnalytics from sagemaker.deprecations import removed_function -from sagemaker.estimator import Framework -from sagemaker.inputs import TrainingInput +from sagemaker.estimator import Framework, EstimatorBase +from sagemaker.inputs import TrainingInput, FileSystemInput from sagemaker.job import _Job from sagemaker.jumpstart.utils import add_jumpstart_tags, get_jumpstart_base_name_if_jumpstart_model from sagemaker.parameter import ( @@ -39,6 +40,7 @@ IntegerParameter, ParameterRange, ) +from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow.pipeline_context import runnable_by_pipeline from sagemaker.session import Session @@ -95,7 +97,11 @@ class WarmStartConfig(object): {"p1","p2"} """ - def __init__(self, warm_start_type, parents): + def __init__( + self, + warm_start_type: WarmStartTypes, + parents: Set[Union[str, PipelineVariable]], + ): """Creates a ``WarmStartConfig`` with provided ``WarmStartTypes`` and parents. 
Args: @@ -208,19 +214,19 @@ class HyperparameterTuner(object): def __init__( self, - estimator, - objective_metric_name, - hyperparameter_ranges, - metric_definitions=None, - strategy="Bayesian", - objective_type="Maximize", - max_jobs=1, - max_parallel_jobs=1, - tags=None, - base_tuning_job_name=None, - warm_start_config=None, - early_stopping_type="Off", - estimator_name=None, + estimator: EstimatorBase, + objective_metric_name: Union[str, PipelineVariable], + hyperparameter_ranges: Dict[str, ParameterRange], + metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + strategy: Union[str, PipelineVariable] = "Bayesian", + objective_type: Union[str, PipelineVariable] = "Maximize", + max_jobs: Union[int, PipelineVariable] = 1, + max_parallel_jobs: Union[int, PipelineVariable] = 1, + tags: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, + base_tuning_job_name: Optional[str] = None, + warm_start_config: Optional[WarmStartConfig] = None, + early_stopping_type: Union[str, PipelineVariable] = "Off", + estimator_name: Optional[str] = None, ): """Creates a ``HyperparameterTuner`` instance. @@ -427,11 +433,13 @@ def _prepare_static_hyperparameters( @runnable_by_pipeline def fit( self, - inputs=None, - job_name=None, - include_cls_metadata=False, - estimator_kwargs=None, - wait=True, + inputs: Optional[ + Union[str, Dict, List, TrainingInput, FileSystemInput, RecordSet, FileSystemRecordSet] + ] = None, + job_name: Optional[str] = None, + include_cls_metadata: Union[bool, Dict[str, bool]] = False, + estimator_kwargs: Optional[Dict[str, dict]] = None, + wait: bool = True, **kwargs ): """Start a hyperparameter tuning job. 
diff --git a/tests/unit/sagemaker/workflow/test_training_step.py b/tests/unit/sagemaker/workflow/test_training_step.py index 0c6a6e34df..397e65f867 100644 --- a/tests/unit/sagemaker/workflow/test_training_step.py +++ b/tests/unit/sagemaker/workflow/test_training_step.py @@ -24,7 +24,7 @@ from sagemaker.transformer import Transformer from sagemaker.tuner import HyperparameterTuner from sagemaker.workflow.pipeline_context import PipelineSession -from sagemaker.workflow.parameters import ParameterString +from sagemaker.workflow.parameters import ParameterString, ParameterBoolean from sagemaker.workflow.steps import TrainingStep from sagemaker.workflow.pipeline import Pipeline, PipelineGraph @@ -203,6 +203,8 @@ def hyperparameters(): def test_training_step_with_estimator(pipeline_session, training_input, hyperparameters): custom_step1 = CustomStep("TestStep") custom_step2 = CustomStep("SecondTestStep") + enable_network_isolation = ParameterBoolean(name="enable_network_isolation") + encrypt_container_traffic = ParameterBoolean(name="encrypt_container_traffic") estimator = Estimator( role=ROLE, instance_count=1, @@ -210,6 +212,8 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar sagemaker_session=pipeline_session, image_uri=IMAGE_URI, hyperparameters=hyperparameters, + enable_network_isolation=enable_network_isolation, + encrypt_inter_container_traffic=encrypt_container_traffic, ) with warnings.catch_warnings(record=True) as w: @@ -231,8 +235,13 @@ def test_training_step_with_estimator(pipeline_session, training_input, hyperpar pipeline = Pipeline( name="MyPipeline", steps=[step, custom_step1, custom_step2], + parameters=[enable_network_isolation, encrypt_container_traffic], sagemaker_session=pipeline_session, ) + step_args.args["EnableInterContainerTrafficEncryption"] = { + "Get": "Parameters.encrypt_container_traffic" + } + step_args.args["EnableNetworkIsolation"] = {"Get": "Parameters.enable_network_isolation"} assert
json.loads(pipeline.definition())["Steps"][0] == { "Name": "MyTrainingStep", "Description": "TrainingStep description",