From 67e86b76bcd3f17c6957184df565ab24a399abb9 Mon Sep 17 00:00:00 2001 From: martinRenou Date: Tue, 28 Nov 2023 10:09:33 +0100 Subject: [PATCH 1/6] Change: Use pydantic type validation --- setup.py | 1 + src/sagemaker/estimator.py | 8 ++-- src/sagemaker/pytorch/estimator.py | 20 ++++++---- src/sagemaker/pytorch/model.py | 39 +++++++++++------- src/sagemaker/pytorch/processing.py | 3 +- src/sagemaker/sklearn/estimator.py | 27 ++++++++----- src/sagemaker/sklearn/model.py | 40 ++++++++++++------- src/sagemaker/sklearn/processing.py | 3 +- src/sagemaker/spark/processing.py | 61 +++++++++++++++-------------- src/sagemaker/utils.py | 21 +++++++++- tests/unit/test_pytorch.py | 16 ++++++++ tests/unit/test_sklearn.py | 27 +++++++++++++ tests/unit/tuner_test_utils.py | 4 +- 13 files changed, 187 insertions(+), 83 deletions(-) diff --git a/setup.py b/setup.py index ff833b7a5c..1ff8163dcd 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ def read_requirements(filename): "packaging>=20.0", "pandas", "pathos", + "pydantic", "schema", "PyYAML~=6.0", "jsonschema", diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index f899570775..7ccbb244ef 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -100,6 +100,7 @@ resolve_value_from_config, format_tags, Tags, + validate_call_inputs, ) from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable @@ -132,6 +133,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz" JOB_CLASS_NAME = "training-job" + @validate_call_inputs def __init__( self, role: str = None, @@ -152,7 +154,7 @@ def __init__( model_uri: Optional[str] = None, model_channel_name: Union[str, PipelineVariable] = "model", metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, - encrypt_inter_container_traffic: Union[bool, PipelineVariable] = None, + encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None, use_spot_instances: Union[bool, PipelineVariable] = False, max_wait: Optional[Union[int, PipelineVariable]] = None, checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, @@ -161,7 +163,7 @@ def __init__( debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, - enable_network_isolation: Union[bool, PipelineVariable] = None, + enable_network_isolation: Optional[Union[bool, PipelineVariable]] = None, profiler_config: Optional[ProfilerConfig] = None, disable_profiler: bool = None, environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None, @@ -2728,7 +2730,7 @@ def __init__( model_uri: Optional[str] = None, model_channel_name: Union[str, PipelineVariable] = "model", metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None, - encrypt_inter_container_traffic: Union[bool, PipelineVariable] = None, + encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None, use_spot_instances: Union[bool, PipelineVariable] = False, max_wait: Optional[Union[int, PipelineVariable]] = None, checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index a4e24d1ff0..8c440629f2 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -14,7 +14,8 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, Dict +from numbers import Number +from typing import Union, Optional, Dict, List from packaging.version import Version @@ -32,6 +33,7 @@ from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable +from sagemaker.utils import validate_call_inputs logger = logging.getLogger("sagemaker") @@ -44,13 +46,14 @@ class PyTorch(Framework): LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled" INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type" + @validate_call_inputs def __init__( self, entry_point: Union[str, PipelineVariable], framework_version: Optional[str] = None, py_version: Optional[str] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, distribution: Optional[Dict] = None, compiler_config: Optional[TrainingCompilerConfig] = None, @@ -362,14 +365,15 @@ def hyperparameters(self): return hyperparameters + @validate_call_inputs def create_model( self, - model_server_workers=None, - role=None, - vpc_config_override=VPC_CONFIG_DEFAULT, - entry_point=None, - source_dir=None, - dependencies=None, + model_server_workers: Optional[int] = None, + role: Optional[str] = None, + vpc_config_override: Optional[Dict[str, List[str]]] = VPC_CONFIG_DEFAULT, + entry_point: Optional[str] = None, + source_dir: Optional[str] = None, + dependencies: Optional[List[str]] = None, **kwargs, ): """Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``. diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index fb731cabf4..617a83171d 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -20,7 +20,7 @@ import sagemaker from sagemaker import image_uris, ModelMetrics -from sagemaker.deserializers import NumpyDeserializer +from sagemaker.deserializers import BaseDeserializer, NumpyDeserializer from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import ( model_code_key_prefix, @@ -31,8 +31,10 @@ from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.pytorch import defaults from sagemaker.predictor import Predictor -from sagemaker.serializers import NumpySerializer -from sagemaker.utils import to_string +from sagemaker.serializers import BaseSerializer, NumpySerializer +from sagemaker.session import Session +from sagemaker.serverless import ServerlessInferenceConfig +from sagemaker.utils import to_string, validate_call_inputs from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable @@ -46,13 +48,14 @@ class PyTorchPredictor(Predictor): multidimensional tensors for PyTorch inference. """ + @validate_call_inputs def __init__( self, - endpoint_name, - sagemaker_session=None, - serializer=NumpySerializer(), - deserializer=NumpyDeserializer(), - component_name=None, + endpoint_name: str, + sagemaker_session: Optional[Session] = None, + serializer: BaseSerializer = NumpySerializer(), + deserializer: BaseDeserializer = NumpyDeserializer(), + component_name: Optional[str] = None, ): """Initialize an ``PyTorchPredictor``. @@ -86,12 +89,13 @@ class PyTorchModel(FrameworkModel): _framework_name = "pytorch" _LOWEST_MMS_VERSION = "1.2" + @validate_call_inputs def __init__( self, model_data: Union[str, PipelineVariable], role: Optional[str] = None, entry_point: Optional[str] = None, - framework_version: str = "1.3", + framework_version: Optional[str] = "1.3", py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, predictor_cls: callable = PyTorchPredictor, @@ -154,6 +158,7 @@ def __init__( self.model_server_workers = model_server_workers + @validate_call_inputs def register( self, content_types: List[Union[str, PipelineVariable]] = None, @@ -268,12 +273,13 @@ def register( skip_model_validation=skip_model_validation, ) + @validate_call_inputs def prepare_container_def( self, - instance_type=None, - accelerator_type=None, - serverless_inference_config=None, - accept_eula=None, + instance_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + serverless_inference_config: Optional[ServerlessInferenceConfig] = None, + accept_eula: Optional[bool] = None, ): """A container definition with framework configuration set in model environment variables. @@ -327,8 +333,13 @@ def prepare_container_def( accept_eula=accept_eula, ) + @validate_call_inputs def serving_image_uri( - self, region_name, instance_type, accelerator_type=None, serverless_inference_config=None + self, + region_name: str, + instance_type: str, + accelerator_type: Optional[str] = None, + serverless_inference_config: Optional[ServerlessInferenceConfig] = None, ): """Create a URI for the serving image. diff --git a/src/sagemaker/pytorch/processing.py b/src/sagemaker/pytorch/processing.py index e04e4ba65a..8ceb6662f7 100644 --- a/src/sagemaker/pytorch/processing.py +++ b/src/sagemaker/pytorch/processing.py @@ -24,7 +24,7 @@ from sagemaker.processing import FrameworkProcessor from sagemaker.pytorch.estimator import PyTorch from sagemaker.workflow.entities import PipelineVariable -from sagemaker.utils import format_tags, Tags +from sagemaker.utils import format_tags, Tags, validate_call_inputs class PyTorchProcessor(FrameworkProcessor): @@ -32,6 +32,7 @@ class PyTorchProcessor(FrameworkProcessor): estimator_cls = PyTorch + @validate_call_inputs def __init__( self, framework_version: str, # New arg diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index 9f4b25f214..48cc155011 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -14,7 +14,8 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, Dict +from numbers import Number +from typing import Union, Optional, Dict, List from sagemaker import image_uris from sagemaker.deprecations import renamed_kwargs @@ -29,6 +30,7 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow import is_pipeline_variable +from sagemaker.utils import validate_call_inputs logger = logging.getLogger("sagemaker") @@ -38,13 +40,14 @@ class SKLearn(Framework): _framework_name = defaults.SKLEARN_NAME + @validate_call_inputs def __init__( self, entry_point: Union[str, PipelineVariable], framework_version: Optional[str] = None, - py_version: str = "py3", + py_version: Optional[str] = "py3", source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + hyperparameters: Optional[Dict[str, Optional[Union[str, PipelineVariable, Number]]]] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, image_uri_region: Optional[str] = None, **kwargs @@ -166,14 +169,15 @@ def __init__( instance_type=instance_type, ) + @validate_call_inputs def create_model( self, - model_server_workers=None, - role=None, - vpc_config_override=VPC_CONFIG_DEFAULT, - entry_point=None, - source_dir=None, - dependencies=None, + model_server_workers: Optional[int] = None, + role: Optional[str] = None, + vpc_config_override: Optional[Union[str, Dict[str, List[str]]]] = VPC_CONFIG_DEFAULT, + entry_point: Optional[str] = None, + source_dir: Optional[str] = None, + dependencies: Optional[List[str]] = None, **kwargs ): """Create a SageMaker ``SKLearnModel`` object that can be deployed to an ``Endpoint``. @@ -233,7 +237,10 @@ def create_model( ) @classmethod - def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None): + @validate_call_inputs + def _prepare_init_params_from_job_description( + cls, job_details, model_channel_name: Optional[str] = None + ): """Convert the job description to init params that can be handled by the class constructor. Args: diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 195a6a3a57..99925344ba 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -18,15 +18,17 @@ import sagemaker from sagemaker import image_uris, ModelMetrics -from sagemaker.deserializers import NumpyDeserializer +from sagemaker.deserializers import BaseDeserializer, NumpyDeserializer from sagemaker.drift_check_baselines import DriftCheckBaselines from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args from sagemaker.metadata_properties import MetadataProperties from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME from sagemaker.predictor import Predictor -from sagemaker.serializers import NumpySerializer +from sagemaker.serializers import BaseSerializer, NumpySerializer +from sagemaker.serverless import ServerlessInferenceConfig +from sagemaker.session import Session from sagemaker.sklearn import defaults -from sagemaker.utils import to_string +from sagemaker.utils import to_string, validate_call_inputs from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.entities import PipelineVariable @@ -40,13 +42,14 @@ class SKLearnPredictor(Predictor): multidimensional tensors for Scikit-learn inference. """ + @validate_call_inputs def __init__( self, - endpoint_name, - sagemaker_session=None, - serializer=NumpySerializer(), - deserializer=NumpyDeserializer(), - component_name=None, + endpoint_name: str, + sagemaker_session: Optional[Session] = None, + serializer: BaseSerializer = NumpySerializer(), + deserializer: BaseDeserializer = NumpyDeserializer(), + component_name: Optional[str] = None, ): """Initialize an ``SKLearnPredictor``. @@ -79,13 +82,14 @@ class SKLearnModel(FrameworkModel): _framework_name = defaults.SKLEARN_NAME + @validate_call_inputs def __init__( self, model_data: Union[str, PipelineVariable], role: Optional[str] = None, entry_point: Optional[str] = None, framework_version: Optional[str] = None, - py_version: str = "py3", + py_version: Optional[str] = "py3", image_uri: Optional[Union[str, PipelineVariable]] = None, predictor_cls: callable = SKLearnPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, @@ -147,6 +151,7 @@ def __init__( self.model_server_workers = model_server_workers + @validate_call_inputs def register( self, content_types: List[Union[str, PipelineVariable]] = None, @@ -261,12 +266,13 @@ def register( skip_model_validation=skip_model_validation, ) + @validate_call_inputs def prepare_container_def( self, - instance_type=None, - accelerator_type=None, - serverless_inference_config=None, - accept_eula=None, + instance_type: Optional[str] = None, + accelerator_type: Optional[str] = None, + serverless_inference_config: Optional[ServerlessInferenceConfig] = None, + accept_eula: Optional[bool] = None, ): """Container definition with framework configuration set in model environment variables. @@ -318,7 +324,13 @@ def prepare_container_def( accept_eula=accept_eula, ) - def serving_image_uri(self, region_name, instance_type, serverless_inference_config=None): + @validate_call_inputs + def serving_image_uri( + self, + region_name: str, + instance_type: Optional[str] = None, + serverless_inference_config: Optional[ServerlessInferenceConfig] = None, + ): """Create a URI for the serving image. Args: diff --git a/src/sagemaker/sklearn/processing.py b/src/sagemaker/sklearn/processing.py index ff209b3740..13df40f55a 100644 --- a/src/sagemaker/sklearn/processing.py +++ b/src/sagemaker/sklearn/processing.py @@ -24,12 +24,13 @@ from sagemaker.processing import ScriptProcessor from sagemaker.sklearn import defaults from sagemaker.workflow.entities import PipelineVariable -from sagemaker.utils import format_tags, Tags +from sagemaker.utils import format_tags, Tags, validate_call_inputs class SKLearnProcessor(ScriptProcessor): """Handles Amazon SageMaker processing tasks for jobs using scikit-learn.""" + @validate_call_inputs def __init__( self, framework_version: str, # New arg diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index 82634071cc..941f824c96 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -41,7 +41,7 @@ from sagemaker.session import Session from sagemaker.network import NetworkConfig from sagemaker.spark import defaults -from sagemaker.utils import format_tags, Tags +from sagemaker.utils import format_tags, Tags, validate_call_inputs from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.pipeline_context import runnable_by_pipeline @@ -686,6 +686,7 @@ def _handle_script_dependencies(self, inputs, submit_files, file_type): class PySparkProcessor(_SparkProcessorBase): """Handles Amazon SageMaker processing tasks for jobs using PySpark.""" + @validate_call_inputs def __init__( self, role: str = None, @@ -776,18 +777,19 @@ def __init__( network_config=network_config, ) + @validate_call_inputs def get_run_args( self, - submit_app, - submit_py_files=None, - submit_jars=None, - submit_files=None, - inputs=None, - outputs=None, - arguments=None, - job_name=None, - configuration=None, - spark_event_logs_s3_uri=None, + submit_app: str, + submit_py_files: Optional[List[str]] = None, + submit_jars: Optional[List[str]] = None, + submit_files: Optional[List[str]] = None, + inputs: Optional[List[ProcessingInput]] = None, + outputs: Optional[List[ProcessingInput]] = None, + arguments: Optional[List[str]] = None, + job_name: Optional[str] = None, + configuration: Optional[Union[List[dict], dict]] = None, + spark_event_logs_s3_uri: Optional[str] = None, ): """Returns a RunArgs object. @@ -823,9 +825,6 @@ def get_run_args( """ self._current_job_name = self._generate_current_job_name(job_name=job_name) - if not submit_app: - raise ValueError("submit_app is required") - extended_inputs, extended_outputs = self._extend_processing_args( inputs=inputs, outputs=outputs, @@ -843,6 +842,7 @@ def get_run_args( arguments=arguments, ) + @validate_call_inputs @runnable_by_pipeline def run( self, @@ -963,6 +963,7 @@ def _extend_processing_args(self, inputs, outputs, **kwargs): class SparkJarProcessor(_SparkProcessorBase): """Handles Amazon SageMaker processing tasks for jobs using Spark with Java or Scala Jars.""" + @validate_call_inputs def __init__( self, role: str = None, @@ -1052,18 +1053,19 @@ def __init__( network_config=network_config, ) + @validate_call_inputs def get_run_args( self, - submit_app, - submit_class=None, - submit_jars=None, - submit_files=None, - inputs=None, - outputs=None, - arguments=None, - job_name=None, - configuration=None, - spark_event_logs_s3_uri=None, + submit_app: str, + submit_class: Optional[str] = None, + submit_jars: Optional[List[str]] = None, + submit_files: Optional[List[str]] = None, + inputs: Optional[List[ProcessingInput]] = None, + outputs: Optional[List[ProcessingInput]] = None, + arguments: Optional[List[str]] = None, + job_name: Optional[str] = None, + configuration: Optional[Union[List[dict], dict]] = None, + spark_event_logs_s3_uri: Optional[str] = None, ): """Returns a RunArgs object. @@ -1099,9 +1101,6 @@ def get_run_args( """ self._current_job_name = self._generate_current_job_name(job_name=job_name) - if not submit_app: - raise ValueError("submit_app is required") - extended_inputs, extended_outputs = self._extend_processing_args( inputs=inputs, outputs=outputs, @@ -1119,6 +1118,7 @@ def get_run_args( arguments=arguments, ) + @validate_call_inputs @runnable_by_pipeline def run( self, @@ -1316,15 +1316,16 @@ class SparkConfigUtils: ] @staticmethod - def validate_configuration(configuration: Dict): + @validate_call_inputs + def validate_configuration(configuration: Union[Dict, list]): """Validates the user-provided Hadoop/Spark/Hive configuration. This ensures that the list or dictionary the user provides will serialize to JSON matching the schema of EMR's application configuration Args: - configuration (Dict): A dict that contains the configuration overrides to - the default values. For more information, please visit: + configuration (Dict or List): A dict or a list of dicts that contains the configuration + overrides to the default values. For more information, please visit: https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html """ emr_configure_apps_url = ( diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index e203693f84..0fecde4646 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -34,6 +34,7 @@ from importlib import import_module import botocore from botocore.utils import merge_dicts +from pydantic import validate_call, ConfigDict from six.moves.urllib import parse from sagemaker import deprecations @@ -1482,10 +1483,28 @@ def create_paginator_config(max_items: int = None, page_size: int = None) -> Dic "PageSize": page_size if page_size else PAGE_SIZE, } - def format_tags(tags: Tags) -> List[TagsDict]: """Process tags to turn them into the expected format for Sagemaker.""" if isinstance(tags, dict): return [{"Key": str(k), "Value": str(v)} for k, v in tags.items()] return tags + + +def validate_call_inputs( + __func: callable, + *args, + config: Optional[ConfigDict] = None, + validate_return: bool = False, +): + """Decorator for function input types using pydantic. + + This calls pydantic.validate_call under the hood, with "arbitrary_types_allowed" enabled. + See its documentation for more information. + """ + if config is None: + config = ConfigDict() + + config.setdefault("arbitrary_types_allowed", True) + + return validate_call(__func, *args, config=config, validate_return=validate_return) diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 9624f7612f..744066e3ca 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -19,11 +19,14 @@ from mock import ANY, MagicMock, Mock, patch from packaging.version import Version +from pydantic import ValidationError + from sagemaker import image_uris from sagemaker.pytorch import defaults from sagemaker.pytorch import PyTorch, PyTorchPredictor, PyTorchModel from sagemaker.instance_group import InstanceGroup from sagemaker.session_settings import SessionSettings +from sagemaker.session import Session DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") @@ -66,8 +69,10 @@ def fixture_sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( name="sagemaker_session", + spec=Session, boto_session=boto_mock, boto_region_name=REGION, + sagemaker_client=Mock(), config=None, local_mode=False, s3_resource=None, @@ -228,6 +233,17 @@ def test_create_model( name_from_base.assert_called_with(base_job_name) + with pytest.raises(ValidationError): + PyTorch(entry_point=3, py_version="py3", framework_version=pytorch_inference_version) + + with pytest.raises(ValidationError): + PyTorch( + entry_point="", + py_version="py3", + framework_version=pytorch_inference_version, + role=5, + ) + def test_create_model_with_optional_params( sagemaker_session, pytorch_inference_version, pytorch_inference_py_version diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index 9745c4ea26..43151876aa 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -20,7 +20,10 @@ from mock import Mock from mock import patch +from pydantic import ValidationError + from sagemaker.fw_utils import UploadedCode +from sagemaker.session import Session from sagemaker.session_settings import SessionSettings from sagemaker.sklearn import SKLearn, SKLearnModel, SKLearnPredictor @@ -60,7 +63,9 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( + spec=Session, name="sagemaker_session", + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, config=None, @@ -171,6 +176,17 @@ def test_training_image_uri(sagemaker_session, sklearn_version): assert _get_full_cpu_image_uri(sklearn_version) == sklearn.training_image_uri() +def test_ctor_wrong_parameters(sagemaker_session, sklearn_version): + with pytest.raises(ValidationError): + SKLearn() + + with pytest.raises(ValidationError): + SKLearn(entry_point=3) + + with pytest.raises(ValidationError): + SKLearn(entry_point="", image_uri_region=3) + + def test_create_model(sagemaker_session, sklearn_version): source_dir = "s3://mybucket/source" @@ -186,6 +202,17 @@ def test_create_model(sagemaker_session, sklearn_version): assert model_values["Image"] == image_uri +def test_create_model_wrong_parameters(sagemaker_session, sklearn_version): + with pytest.raises(ValidationError): + SKLearnModel(model_data=1) + + with pytest.raises(ValidationError): + SKLearnModel(model_data="", role=2) + + with pytest.raises(ValidationError): + SKLearnModel(model_data="", entry_point=3) + + @patch("sagemaker.model.FrameworkModel._upload_code") def test_create_model_with_network_isolation(upload, sagemaker_session, sklearn_version): source_dir = "s3://mybucket/source" diff --git a/tests/unit/tuner_test_utils.py b/tests/unit/tuner_test_utils.py index c7b1abcbb2..8fcc9d881f 100644 --- a/tests/unit/tuner_test_utils.py +++ b/tests/unit/tuner_test_utils.py @@ -15,6 +15,7 @@ import os from mock import Mock +from sagemaker.session import Session from sagemaker.amazon.pca import PCA from sagemaker.estimator import Estimator from sagemaker.parameter import CategoricalParameter, ContinuousParameter, IntegerParameter @@ -71,11 +72,12 @@ ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"} -SAGEMAKER_SESSION = Mock() +SAGEMAKER_SESSION = Mock(spec=Session) # For tests which doesn't verify config file injection, operate with empty config SAGEMAKER_SESSION.sagemaker_config = {} SAGEMAKER_SESSION.default_bucket = Mock(return_value=BUCKET_NAME) SAGEMAKER_SESSION.default_bucket_prefix = None +SAGEMAKER_SESSION.local_mode = False ESTIMATOR = Estimator( From d65923c4af58ff7cde662789ff1b68c98dc2d080 Mon Sep 17 00:00:00 2001 From: martinRenou Date: Wed, 6 Dec 2023 10:37:38 +0100 Subject: [PATCH 2/6] Remove fastapi dependency --- setup.py | 1 - src/sagemaker/estimator.py | 2 +- src/sagemaker/pytorch/estimator.py | 3 ++- src/sagemaker/pytorch/model.py | 3 ++- src/sagemaker/pytorch/processing.py | 4 +++- src/sagemaker/sklearn/estimator.py | 3 ++- src/sagemaker/sklearn/model.py | 3 ++- src/sagemaker/sklearn/processing.py | 3 ++- src/sagemaker/spark/processing.py | 3 ++- src/sagemaker/utils.py | 1 + 10 files changed, 17 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index 1ff8163dcd..bc4475c81b 100644 --- a/setup.py +++ b/setup.py @@ -67,7 +67,6 @@ def read_requirements(filename): "tblib>=1.7.0,<3", "urllib3<1.27", "uvicorn==0.22.0", - "fastapi==0.95.2", "requests", "docker", "tqdm", diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 7ccbb244ef..c7548e8ef5 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -136,7 +136,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man @validate_call_inputs def __init__( self, - role: str = None, + role: Union[str, ParameterString] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None, diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 8c440629f2..81c3d471d8 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -33,6 +33,7 @@ from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.parameters import ParameterString from sagemaker.utils import validate_call_inputs logger = logging.getLogger("sagemaker") @@ -369,7 +370,7 @@ def hyperparameters(self): def create_model( self, model_server_workers: Optional[int] = None, - role: Optional[str] = None, + role: Optional[Union[str, ParameterString]] = None, vpc_config_override: Optional[Dict[str, List[str]]] = VPC_CONFIG_DEFAULT, entry_point: Optional[str] = None, source_dir: Optional[str] = None, diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 617a83171d..7b521720e9 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -36,6 +36,7 @@ from sagemaker.serverless import ServerlessInferenceConfig from sagemaker.utils import to_string, validate_call_inputs from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -93,7 +94,7 @@ class PyTorchModel(FrameworkModel): def __init__( self, model_data: Union[str, PipelineVariable], - role: Optional[str] = None, + role: Optional[Union[str, ParameterString]] = None, entry_point: Optional[str] = None, framework_version: Optional[str] = "1.3", py_version: Optional[str] = None, diff --git a/src/sagemaker/pytorch/processing.py b/src/sagemaker/pytorch/processing.py index 8ceb6662f7..5f884f8e02 100644 --- a/src/sagemaker/pytorch/processing.py +++ b/src/sagemaker/pytorch/processing.py @@ -25,6 +25,8 @@ from sagemaker.pytorch.estimator import PyTorch from sagemaker.workflow.entities import PipelineVariable from sagemaker.utils import format_tags, Tags, validate_call_inputs +from sagemaker.workflow.parameters import ParameterString +from sagemaker.utils import validate_call_inputs class PyTorchProcessor(FrameworkProcessor): @@ -36,7 +38,7 @@ class PyTorchProcessor(FrameworkProcessor): def __init__( self, framework_version: str, # New arg - role: Optional[Union[str, PipelineVariable]] = None, + role: Optional[Union[str, PipelineVariable, ParameterString]] = None, instance_count: Union[int, PipelineVariable] = None, instance_type: Union[str, PipelineVariable] = None, py_version: str = "py3", # New kwarg diff --git a/src/sagemaker/sklearn/estimator.py b/src/sagemaker/sklearn/estimator.py index 48cc155011..3881b488e4 100644 --- a/src/sagemaker/sklearn/estimator.py +++ b/src/sagemaker/sklearn/estimator.py @@ -30,6 +30,7 @@ from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT from sagemaker.workflow.entities import PipelineVariable from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.parameters import ParameterString from sagemaker.utils import validate_call_inputs logger = logging.getLogger("sagemaker") @@ -173,7 +174,7 @@ def __init__( def create_model( self, model_server_workers: Optional[int] = None, - role: Optional[str] = None, + role: Optional[Union[str, ParameterString]] = None, vpc_config_override: Optional[Union[str, Dict[str, List[str]]]] = VPC_CONFIG_DEFAULT, entry_point: Optional[str] = None, source_dir: Optional[str] = None, diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 99925344ba..bf028fd5b7 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -30,6 +30,7 @@ from sagemaker.sklearn import defaults from sagemaker.utils import to_string, validate_call_inputs from sagemaker.workflow import is_pipeline_variable +from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.entities import PipelineVariable logger = logging.getLogger("sagemaker") @@ -86,7 +87,7 @@ class SKLearnModel(FrameworkModel): def __init__( self, model_data: Union[str, PipelineVariable], - role: Optional[str] = None, + role: Optional[Union[str, ParameterString]] = None, entry_point: Optional[str] = None, framework_version: Optional[str] = None, py_version: Optional[str] = "py3", diff --git a/src/sagemaker/sklearn/processing.py b/src/sagemaker/sklearn/processing.py index 13df40f55a..e04b1e0cc7 100644 --- a/src/sagemaker/sklearn/processing.py +++ b/src/sagemaker/sklearn/processing.py @@ -25,6 +25,7 @@ from sagemaker.sklearn import defaults from sagemaker.workflow.entities import PipelineVariable from sagemaker.utils import format_tags, Tags, validate_call_inputs +from sagemaker.workflow.parameters import ParameterString class SKLearnProcessor(ScriptProcessor): @@ -34,7 +35,7 @@ class SKLearnProcessor(ScriptProcessor): def __init__( self, framework_version: str, # New arg - role: Optional[Union[str, PipelineVariable]] = None, + role: Optional[Union[str, PipelineVariable, ParameterString]] = None, instance_count: Union[int, PipelineVariable] = None, instance_type: Union[str, PipelineVariable] = None, command: Optional[List[str]] = None, diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index 941f824c96..40f45530a6 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -46,6 +46,7 @@ from sagemaker.workflow import is_pipeline_variable from sagemaker.workflow.pipeline_context import runnable_by_pipeline from sagemaker.workflow.entities import PipelineVariable +from sagemaker.workflow.parameters import ParameterString from sagemaker.workflow.functions import Join logger = logging.getLogger(__name__) @@ -689,7 +690,7 @@ class PySparkProcessor(_SparkProcessorBase): @validate_call_inputs def __init__( self, - role: str = None, + role: Union[str, ParameterString] = None, instance_type: Union[str, PipelineVariable] = None, instance_count: Union[int, PipelineVariable] = None, framework_version: Optional[str] = None, diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 0fecde4646..a55f0e9454 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -1483,6 +1483,7 @@ def create_paginator_config(max_items: int = None, page_size: int = None) -> Dic "PageSize": page_size if page_size else PAGE_SIZE, } + def format_tags(tags: Tags) -> List[TagsDict]: """Process tags to turn them into the expected format for Sagemaker.""" if isinstance(tags, dict): From cf367660587aafeac83e4051b924fb7b5749e877 Mon Sep 17 00:00:00 2001 From: martinRenou Date: Wed, 6 Dec 2023 14:21:04 +0100 Subject: [PATCH 3/6] Fix callable types --- src/sagemaker/chainer/model.py | 4 ++-- src/sagemaker/djl_inference/model.py | 4 ++-- src/sagemaker/huggingface/model.py | 4 ++-- src/sagemaker/mxnet/model.py | 4 ++-- src/sagemaker/pytorch/model.py | 4 ++-- src/sagemaker/remote_function/job.py | 2 +- src/sagemaker/serve/builder/schema_builder.py | 3 ++- src/sagemaker/serve/utils/tuning.py | 3 ++- src/sagemaker/sklearn/model.py | 4 ++-- src/sagemaker/spark/processing.py | 3 +-- src/sagemaker/tensorflow/model.py | 4 ++-- src/sagemaker/utils.py | 4 ++-- src/sagemaker/xgboost/model.py | 4 ++-- 13 files changed, 24 insertions(+), 23 deletions(-) diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index bafcfde3a8..ecf78d72f9 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -91,7 +91,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, framework_version: Optional[str] = None, py_version: Optional[str] = None, - predictor_cls: callable = ChainerPredictor, + predictor_cls: Callable = ChainerPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): diff --git a/src/sagemaker/djl_inference/model.py b/src/sagemaker/djl_inference/model.py index 8308215e81..40446f6a53 100644 --- a/src/sagemaker/djl_inference/model.py +++ b/src/sagemaker/djl_inference/model.py @@ -20,7 +20,7 @@ from json import JSONDecodeError from urllib.error import HTTPError, URLError from enum import Enum -from typing import Optional, Union, Dict, Any, List +from typing import Callable, Optional, Union, Dict, Any, List import sagemaker from sagemaker import s3, Predictor, image_uris, fw_utils @@ -313,7 +313,7 @@ def __init__( prediction_timeout: Optional[int] = None, entry_point: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = DJLPredictor, + predictor_cls: Callable = DJLPredictor, **kwargs, ): """Initialize a DJLModel. diff --git a/src/sagemaker/huggingface/model.py b/src/sagemaker/huggingface/model.py index efe6a85288..73e4996e67 100644 --- a/src/sagemaker/huggingface/model.py +++ b/src/sagemaker/huggingface/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -118,7 +118,7 @@ def __init__( pytorch_version: Optional[str] = None, py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = HuggingFacePredictor, + predictor_cls: Callable = HuggingFacePredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs, ): diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index 8cd0ac6b65..56091476bf 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, List, Dict +from typing import Callable, Union, Optional, List, Dict import packaging.version @@ -93,7 +93,7 @@ def __init__( framework_version: str = _LOWEST_MMS_VERSION, py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = MXNetPredictor, + predictor_cls: Callable = MXNetPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 7b521720e9..2201f88ccb 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import packaging.version @@ -99,7 +99,7 @@ def __init__( framework_version: Optional[str] = "1.3", py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = PyTorchPredictor, + predictor_cls: Callable = PyTorchPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 205a2adf41..ee66e1991c 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -692,7 +692,7 @@ def compile( job_settings: _JobSettings, job_name: str, s3_base_uri: str, - func: callable, + func: Callable, func_args: tuple, func_kwargs: dict, run_info=None, diff --git a/src/sagemaker/serve/builder/schema_builder.py b/src/sagemaker/serve/builder/schema_builder.py index 24900a5dc8..90f77bd10d 100644 --- a/src/sagemaker/serve/builder/schema_builder.py +++ b/src/sagemaker/serve/builder/schema_builder.py @@ -3,6 +3,7 @@ import io import logging from pathlib import Path +from typing import Callable import numpy as np from pandas import DataFrame @@ -266,7 +267,7 @@ def _is_path_to_file(data: object) -> bool: def _validate_translations( - payload: object, serialize_callable: callable, deserialize_callable: callable + payload: object, serialize_callable: Callable, deserialize_callable: Callable ) -> None: """Placeholder docstring""" try: diff --git a/src/sagemaker/serve/utils/tuning.py b/src/sagemaker/serve/utils/tuning.py index c095791ef9..c83e7890d3 100644 --- a/src/sagemaker/serve/utils/tuning.py +++ b/src/sagemaker/serve/utils/tuning.py @@ -2,6 +2,7 @@ from __future__ import absolute_import import logging from time import perf_counter +from typing import Callable import collections from multiprocessing.pool import ThreadPool from math import ceil @@ -104,7 +105,7 @@ def _tokens_per_second(generated_text: str, max_token_length: int, latency: floa return min(est_tokens, max_token_length) / latency -def _timed_invoke(predict: callable, sample_input: object) -> tuple: +def _timed_invoke(predict: Callable, sample_input: object) -> tuple: """Placeholder docstring""" start_timer = perf_counter() response = predict(sample_input) diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index bf028fd5b7..98a7bbab47 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, List, Dict +from typing import Callable, Union, Optional, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -92,7 +92,7 @@ def __init__( framework_version: Optional[str] = None, py_version: Optional[str] = "py3", image_uri: Optional[Union[str, PipelineVariable]] = None, - predictor_cls: callable = SKLearnPredictor, + predictor_cls: Callable = SKLearnPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): diff --git a/src/sagemaker/spark/processing.py b/src/sagemaker/spark/processing.py index 40f45530a6..edc7c1b07d 100644 --- a/src/sagemaker/spark/processing.py +++ b/src/sagemaker/spark/processing.py @@ -1317,8 +1317,7 @@ class SparkConfigUtils: ] @staticmethod - @validate_call_inputs - def validate_configuration(configuration: Union[Dict, list]): + def validate_configuration(configuration: Union[Dict, List]): """Validates the user-provided Hadoop/Spark/Hive configuration. This ensures that the list or dictionary the user provides will serialize to diff --git a/src/sagemaker/tensorflow/model.py b/src/sagemaker/tensorflow/model.py index 1b35afbe7c..10aedd6b55 100644 --- a/src/sagemaker/tensorflow/model.py +++ b/src/sagemaker/tensorflow/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Union, Optional, List, Dict +from typing import Callable, Union, Optional, List, Dict import sagemaker from sagemaker import image_uris, s3, ModelMetrics @@ -141,7 +141,7 @@ def __init__( image_uri: Optional[Union[str, PipelineVariable]] = None, framework_version: Optional[str] = None, container_log_level: Optional[int] = None, - predictor_cls: callable = TensorFlowPredictor, + predictor_cls: Callable = TensorFlowPredictor, **kwargs, ): """Initialize a Model. diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index a55f0e9454..15eddf4450 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -25,7 +25,7 @@ import tarfile import tempfile import time -from typing import Union, Any, List, Optional, Dict +from typing import Union, Any, Callable, List, Optional, Dict import json import abc import uuid @@ -1493,7 +1493,7 @@ def format_tags(tags: Tags) -> List[TagsDict]: def validate_call_inputs( - __func: callable, + __func: Callable, *args, config: Optional[ConfigDict] = None, validate_return: bool = False, diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 74776f8f72..51206b5843 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -14,7 +14,7 @@ from __future__ import absolute_import import logging -from typing import Optional, Union, List, Dict +from typing import Callable, Optional, Union, List, Dict import sagemaker from sagemaker import image_uris, ModelMetrics @@ -86,7 +86,7 @@ def __init__( framework_version: str = None, image_uri: Optional[Union[str, PipelineVariable]] = None, py_version: str = "py3", - predictor_cls: callable = XGBoostPredictor, + predictor_cls: Callable = XGBoostPredictor, model_server_workers: Optional[Union[int, PipelineVariable]] = None, **kwargs ): From 4b101df180c177e00efc5acbf4e83f3ee129fbd3 Mon Sep 17 00:00:00 2001 From: martinRenou Date: Wed, 6 Dec 2023 17:33:16 +0100 Subject: [PATCH 4/6] Missing optional --- src/sagemaker/pytorch/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 2201f88ccb..3668716847 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -338,7 +338,7 @@ def prepare_container_def( def serving_image_uri( self, region_name: str, - instance_type: str, + instance_type: Optional[str] = None, accelerator_type: Optional[str] = None, serverless_inference_config: Optional[ServerlessInferenceConfig] = None, ): From 071616436f6087883f528227ea061ad5fdf3451c Mon Sep 17 00:00:00 2001 From: martinRenou Date: Thu, 7 Dec 2023 10:51:04 +0100 Subject: [PATCH 5/6] Fix sagemaker session mock --- src/sagemaker/algorithm.py | 3 +- src/sagemaker/chainer/estimator.py | 3 +- src/sagemaker/estimator.py | 15 +-- src/sagemaker/huggingface/estimator.py | 3 +- src/sagemaker/jumpstart/estimator.py | 6 +- src/sagemaker/jumpstart/factory/estimator.py | 5 +- src/sagemaker/pytorch/model.py | 2 +- tests/unit/__init__.py | 11 ++ .../sagemaker/huggingface/test_estimator.py | 3 + .../sagemaker/huggingface/test_processing.py | 4 + .../image_uris/jumpstart/test_catboost.py | 4 +- .../image_uris/jumpstart/test_huggingface.py | 5 +- .../image_uris/jumpstart/test_lightgbm.py | 5 +- .../image_uris/jumpstart/test_mxnet.py | 5 +- .../image_uris/jumpstart/test_pytorch.py | 5 +- .../image_uris/jumpstart/test_sklearn.py | 5 +- .../image_uris/jumpstart/test_tensorflow.py | 5 +- .../image_uris/jumpstart/test_xgboost.py | 5 +- .../jumpstart/estimator/test_estimator.py | 35 +++++- tests/unit/sagemaker/model/test_model.py | 5 + tests/unit/sagemaker/spark/test_processing.py | 107 +++++++++--------- .../sagemaker/tensorflow/test_estimator.py | 3 + .../tensorflow/test_estimator_attach.py | 4 +- .../tensorflow/test_estimator_init.py | 4 + tests/unit/sagemaker/tensorflow/test_tfs.py | 5 + .../test_huggingface_pytorch_compiler.py | 4 +- .../test_huggingface_tensorflow_compiler.py | 4 +- .../test_pytorch_compiler.py | 4 +- .../test_tensorflow_compiler.py | 4 +- tests/unit/sagemaker/workflow/test_airflow.py | 89 ++++++++------- tests/unit/sagemaker/workflow/test_utils.py | 3 + tests/unit/test_algorithm.py | 98 ++++++++++++---- tests/unit/test_amazon_estimator.py | 3 + tests/unit/test_chainer.py | 3 + tests/unit/test_djl_inference.py | 8 +- tests/unit/test_estimator.py | 34 ++++-- tests/unit/test_fm.py | 3 + tests/unit/test_ipinsights.py | 4 +- tests/unit/test_job.py | 5 +- tests/unit/test_kmeans.py | 4 +- tests/unit/test_knn.py | 4 +- tests/unit/test_lda.py | 4 +- tests/unit/test_linear_learner.py | 4 +- tests/unit/test_model_card.py | 65 +++++++---- tests/unit/test_mxnet.py | 3 + tests/unit/test_ntm.py | 3 + tests/unit/test_object2vec.py | 3 + tests/unit/test_pca.py | 4 +- tests/unit/test_processing.py | 4 + tests/unit/test_randomcutforest.py | 3 + tests/unit/test_rl.py | 3 + tests/unit/test_tuner.py | 10 +- tests/unit/test_utils.py | 7 +- tests/unit/test_xgboost.py | 3 + 54 files changed, 458 insertions(+), 189 deletions(-) diff --git a/src/sagemaker/algorithm.py b/src/sagemaker/algorithm.py index a177b93f03..b44862e610 100644 --- a/src/sagemaker/algorithm.py +++ b/src/sagemaker/algorithm.py @@ -13,6 +13,7 @@ """Test docstring""" from __future__ import absolute_import +from numbers import Number from typing import Optional, Union, Dict, List import sagemaker @@ -58,7 +59,7 @@ def __init__( output_kms_key: Optional[Union[str, PipelineVariable]] = None, base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, - hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None, tags: Optional[Tags] = None, subnets: Optional[List[Union[str, PipelineVariable]]] = None, security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, diff --git a/src/sagemaker/chainer/estimator.py b/src/sagemaker/chainer/estimator.py index 09addf9910..183b62275d 100644 --- a/src/sagemaker/chainer/estimator.py +++ b/src/sagemaker/chainer/estimator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import logging +from numbers import Number from typing import Union, Optional, Dict from sagemaker.estimator import Framework, EstimatorBase @@ -50,7 +51,7 @@ def __init__( process_slots_per_host: Optional[Union[int, PipelineVariable]] = None, additional_mpi_options: Optional[Union[str, PipelineVariable]] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None, framework_version: Optional[str] = None, py_version: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index c7548e8ef5..fabd6cd0dd 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -19,6 +19,7 @@ import re import uuid from abc import ABCMeta, abstractmethod +from numbers import Number from typing import Any, Dict, Union, Optional, List from packaging.specifiers import SpecifierSet from packaging.version import Version @@ -136,7 +137,7 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man @validate_call_inputs def __init__( self, - role: Union[str, ParameterString] = None, + role: Optional[Union[str, ParameterString]] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None, @@ -160,7 +161,7 @@ def __init__( checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None, checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, rules: Optional[List[RuleBase]] = None, - debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None, + debugger_hook_config: Optional[Union[bool, DebuggerHookConfig, Dict]] = None, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, enable_network_isolation: Optional[Union[bool, PipelineVariable]] = None, @@ -170,7 +171,7 @@ def __init__( max_retry_attempts: Optional[Union[int, PipelineVariable]] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, git_config: Optional[Dict[str, str]] = None, - hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None, container_log_level: Union[int, PipelineVariable] = logging.INFO, code_location: Optional[str] = None, entry_point: Optional[Union[str, PipelineVariable]] = None, @@ -2711,7 +2712,7 @@ class Estimator(EstimatorBase): def __init__( self, image_uri: Union[str, PipelineVariable], - role: str = None, + role: Optional[str] = None, instance_count: Optional[Union[int, PipelineVariable]] = None, instance_type: Optional[Union[str, PipelineVariable]] = None, keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None, @@ -2723,7 +2724,7 @@ def __init__( output_kms_key: Optional[Union[str, PipelineVariable]] = None, base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, - hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None, tags: Optional[Tags] = None, subnets: Optional[List[Union[str, PipelineVariable]]] = None, security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, @@ -2737,7 +2738,7 @@ def __init__( checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, enable_network_isolation: Union[bool, PipelineVariable] = None, rules: Optional[List[RuleBase]] = None, - debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None, + debugger_hook_config: Optional[Union[DebuggerHookConfig, bool, Dict]] = None, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, profiler_config: Optional[ProfilerConfig] = None, @@ -3290,7 +3291,7 @@ def __init__( self, entry_point: Union[str, PipelineVariable], source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None, container_log_level: Union[int, PipelineVariable] = logging.INFO, code_location: Optional[str] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, diff --git a/src/sagemaker/huggingface/estimator.py b/src/sagemaker/huggingface/estimator.py index 86df43d4e9..a981616721 100644 --- a/src/sagemaker/huggingface/estimator.py +++ b/src/sagemaker/huggingface/estimator.py @@ -15,6 +15,7 @@ import logging import re +from numbers import Number from typing import Optional, Union, Dict from sagemaker.estimator import Framework, EstimatorBase @@ -47,7 +48,7 @@ def __init__( tensorflow_version: Optional[str] = None, pytorch_version: Optional[str] = None, source_dir: Optional[Union[str, PipelineVariable]] = None, - hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None, image_uri: Optional[Union[str, PipelineVariable]] = None, distribution: Optional[Dict] = None, compiler_config: Optional[TrainingCompilerConfig] = None, diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 36a188ed55..d94d5983c6 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -13,7 +13,7 @@ """This module stores JumpStart implementation of Estimator class.""" from __future__ import absolute_import - +from numbers import Number from typing import Dict, List, Optional, Union from sagemaker import session from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig @@ -72,7 +72,7 @@ def __init__( output_kms_key: Optional[Union[str, PipelineVariable]] = None, base_job_name: Optional[str] = None, sagemaker_session: Optional[session.Session] = None, - hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None, tags: Optional[Tags] = None, subnets: Optional[List[Union[str, PipelineVariable]]] = None, security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, @@ -86,7 +86,7 @@ def __init__( checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, enable_network_isolation: Union[bool, PipelineVariable] = None, rules: Optional[List[RuleBase]] = None, - debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None, + debugger_hook_config: Optional[Union[DebuggerHookConfig, bool, Dict]] = None, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, profiler_config: Optional[ProfilerConfig] = None, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index 7ccf57983b..9cc3c3915d 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -14,6 +14,7 @@ from __future__ import absolute_import +from numbers import Number from typing import Dict, List, Optional, Union from sagemaker import ( environment_variables, @@ -93,7 +94,7 @@ def get_init_kwargs( output_kms_key: Optional[Union[str, PipelineVariable]] = None, base_job_name: Optional[str] = None, sagemaker_session: Optional[Session] = None, - hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None, + hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None, tags: Optional[Tags] = None, subnets: Optional[List[Union[str, PipelineVariable]]] = None, security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None, @@ -107,7 +108,7 @@ def get_init_kwargs( checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None, enable_network_isolation: Union[bool, PipelineVariable] = None, rules: Optional[List[RuleBase]] = None, - debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None, + debugger_hook_config: Optional[Union[DebuggerHookConfig, bool, Dict]] = None, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None, enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None, profiler_config: Optional[ProfilerConfig] = None, diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index 3668716847..dcddccfba8 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -93,7 +93,7 @@ class PyTorchModel(FrameworkModel): @validate_call_inputs def __init__( self, - model_data: Union[str, PipelineVariable], + model_data: Union[str, PipelineVariable, dict], role: Optional[Union[str, ParameterString]] = None, entry_point: Optional[str] = None, framework_version: Optional[str] = "1.3", diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py index 17dcf06503..039f80a15d 100644 --- a/tests/unit/__init__.py +++ b/tests/unit/__init__.py @@ -88,6 +88,7 @@ ESTIMATOR, DEBUG_HOOK_CONFIG, ) +from sagemaker import Session DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") PY_VERSION = "py3" @@ -430,7 +431,12 @@ def _test_default_bucket_and_prefix_combinations( expected__with_user_input__with_default_bucket_only=None, session_with_bucket_and_prefix=Mock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), + local_mode=False, + s3_resource=None, sagemaker_config={}, + boto_session=Mock(name="boto_session"), default_bucket=Mock(name="default_bucket", return_value=DEFAULT_S3_BUCKET_NAME), default_bucket_prefix=DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, config=None, @@ -438,7 +444,12 @@ def _test_default_bucket_and_prefix_combinations( ), session_with_bucket_and_no_prefix=Mock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), + local_mode=False, + s3_resource=None, sagemaker_config={}, + boto_session=Mock(name="boto_session"), default_bucket_prefix=None, default_bucket=Mock(name="default_bucket", return_value=DEFAULT_S3_BUCKET_NAME), config=None, diff --git a/tests/unit/sagemaker/huggingface/test_estimator.py b/tests/unit/sagemaker/huggingface/test_estimator.py index 418a8d63db..74061f28e8 100644 --- a/tests/unit/sagemaker/huggingface/test_estimator.py +++ b/tests/unit/sagemaker/huggingface/test_estimator.py @@ -19,6 +19,7 @@ import pytest from mock import MagicMock, Mock, patch +from sagemaker import Session from sagemaker.huggingface import HuggingFace, HuggingFaceModel from sagemaker.session_settings import SessionSettings @@ -59,6 +60,8 @@ def fixture_sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, config=None, diff --git a/tests/unit/sagemaker/huggingface/test_processing.py b/tests/unit/sagemaker/huggingface/test_processing.py index 491f4ab5df..db22db24f6 100644 --- a/tests/unit/sagemaker/huggingface/test_processing.py +++ b/tests/unit/sagemaker/huggingface/test_processing.py @@ -15,6 +15,7 @@ import pytest from mock import Mock, patch, MagicMock +from sagemaker import Session from sagemaker.huggingface.processing import HuggingFaceProcessor from sagemaker.fw_utils import UploadedCode from sagemaker.session_settings import SessionSettings @@ -39,9 +40,12 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session_mock = MagicMock( name="sagemaker_session", + spec=Session, + sagemaker_client=MagicMock(), boto_session=boto_mock, boto_region_name=REGION, config=None, + s3_resource=None, local_mode=False, settings=SessionSettings(), default_bucket_prefix=None, diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py b/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py index 9261fd561e..284795ba6a 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_catboost.py @@ -14,7 +14,7 @@ from mock.mock import patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.jumpstart import accessors from sagemaker.pytorch.estimator import PyTorch from sagemaker.pytorch.model import PyTorchModel @@ -22,10 +22,12 @@ from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec +@patch("sagemaker.Session", spec=Session) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_catboost_image_uri(patched_get_model_specs, session): # For tests which doesn't verify config file injection, operate with empty config session.sagemaker_config = {} + session.local_mode = False patched_get_model_specs.side_effect = get_prototype_model_spec diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py b/tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py index 1ce213cd27..60312be353 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_huggingface.py @@ -14,7 +14,7 @@ from mock.mock import patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.huggingface.estimator import HuggingFace from sagemaker.jumpstart import accessors from sagemaker.huggingface.model import HuggingFaceModel @@ -22,8 +22,11 @@ from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec +@patch("sagemaker.Session", spec=Session) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_huggingface_image_uri(patched_get_model_specs, session): + session.sagemaker_config = {} + session.local_mode = False patched_get_model_specs.side_effect = get_prototype_model_spec diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py b/tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py index e907a19b51..3d3364e660 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_lightgbm.py @@ -14,7 +14,7 @@ from mock.mock import patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.jumpstart import accessors from sagemaker.pytorch.estimator import PyTorch from sagemaker.pytorch.model import PyTorchModel @@ -22,8 +22,11 @@ from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec +@patch("sagemaker.Session", spec=Session) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_lightgbm_image_uri(patched_get_model_specs, session): + session.sagemaker_config = {} + session.local_mode = False patched_get_model_specs.side_effect = get_prototype_model_spec diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py b/tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py index 9fd09d47d9..78cb99ca32 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_mxnet.py @@ -14,7 +14,7 @@ from mock.mock import patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.jumpstart import accessors from sagemaker.mxnet.estimator import MXNet from sagemaker.mxnet.model import MXNetModel @@ -22,8 +22,11 @@ from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec +@patch("sagemaker.Session", spec=Session) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_mxnet_image_uri(patched_get_model_specs, session): + session.sagemaker_config = {} + session.local_mode = False patched_get_model_specs.side_effect = get_prototype_model_spec diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py b/tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py index a94801da10..0669859c63 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_pytorch.py @@ -14,7 +14,7 @@ from mock.mock import patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.jumpstart import accessors from sagemaker.pytorch.estimator import PyTorch from sagemaker.pytorch.model import PyTorchModel @@ -22,8 +22,11 @@ from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec +@patch("sagemaker.Session", spec=Session) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_pytorch_image_uri(patched_get_model_specs, session): + session.sagemaker_config = {} + session.local_mode = False patched_get_model_specs.side_effect = get_prototype_model_spec diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py b/tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py index 1410c59bb6..9f0af56327 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_sklearn.py @@ -15,7 +15,7 @@ from mock.mock import patch import pytest -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.jumpstart import accessors from sagemaker.sklearn.estimator import SKLearn from sagemaker.sklearn.model import SKLearnModel @@ -23,8 +23,11 @@ from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec +@patch("sagemaker.Session", spec=Session) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_sklearn_image_uri(patched_get_model_specs, session): + session.sagemaker_config = {} + session.local_mode = False patched_get_model_specs.side_effect = get_prototype_model_spec diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py b/tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py index c924615212..96bfe8d418 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_tensorflow.py @@ -14,7 +14,7 @@ from mock.mock import patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.jumpstart import accessors from sagemaker.tensorflow.model import TensorFlowModel from sagemaker.tensorflow.estimator import TensorFlow @@ -22,8 +22,11 @@ from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec +@patch("sagemaker.Session", spec=Session) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_tensorflow_image_uri(patched_get_model_specs, session): + session.sagemaker_config = {} + session.local_mode = False patched_get_model_specs.side_effect = get_prototype_model_spec diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py b/tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py index 5da3b71176..398f6906c4 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_xgboost.py @@ -14,7 +14,7 @@ from mock.mock import patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.jumpstart import accessors from sagemaker.xgboost.model import XGBoostModel from sagemaker.xgboost.estimator import XGBoost @@ -22,8 +22,11 @@ from tests.unit.sagemaker.jumpstart.utils import get_prototype_model_spec +@patch("sagemaker.Session", spec=Session) @patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs") def test_jumpstart_xgboost_image_uri(patched_get_model_specs, session): + session.sagemaker_config = {} + session.local_mode = False patched_get_model_specs.side_effect = get_prototype_model_spec diff --git a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py index 29eff40461..531c3f8368 100644 --- a/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py +++ b/tests/unit/sagemaker/jumpstart/estimator/test_estimator.py @@ -36,6 +36,7 @@ from sagemaker.jumpstart.estimator import JumpStartEstimator from sagemaker.jumpstart.utils import get_jumpstart_content_bucket +from sagemaker.session import Session from sagemaker.session_settings import SessionSettings from tests.integ.sagemaker.jumpstart.utils import get_training_dataset_for_model_and_version from sagemaker.model import Model @@ -925,11 +926,24 @@ def test_jumpstart_estimator_tags_disabled( settings = SessionSettings(include_jumpstart_tags=False) + boto_mock = mock.MagicMock(name="boto_session", region_name=region) mock_session = mock.MagicMock( - sagemaker_config={}, boto_region_name="us-west-2", settings=settings + name="sagemaker_session", + spec=Session, + sagemaker_client=mock.MagicMock(), + sagemaker_config={}, + boto_session=boto_mock, + boto_region_name=region, + config=None, + local_mode=False, + s3_resource=None, + s3_client=None, + settings=settings, + default_bucket_prefix=None, ) estimator = JumpStartEstimator( + role="mock_role", model_id=model_id, sagemaker_session=mock_session, tags=[{"Key": "blah", "Value": "blahagain"}], @@ -962,9 +976,26 @@ def test_jumpstart_estimator_tags( mock_get_model_specs.side_effect = get_special_model_spec - mock_session = mock.MagicMock(sagemaker_config={}, boto_region_name="us-west-2") + settings = SessionSettings() + + boto_mock = mock.MagicMock(name="boto_session", region_name=region) + mock_session = mock.MagicMock( + name="sagemaker_session", + spec=Session, + sagemaker_client=mock.MagicMock(), + sagemaker_config={}, + boto_session=boto_mock, + boto_region_name=region, + config=None, + local_mode=False, + s3_resource=None, + s3_client=None, + settings=settings, + default_bucket_prefix=None, + ) estimator = JumpStartEstimator( + role="mock_role", model_id=model_id, sagemaker_session=mock_session, tags=[{"Key": "blah", "Value": "blahagain"}], diff --git a/tests/unit/sagemaker/model/test_model.py b/tests/unit/sagemaker/model/test_model.py index de86fcf99a..24b87d2f02 100644 --- a/tests/unit/sagemaker/model/test_model.py +++ b/tests/unit/sagemaker/model/test_model.py @@ -35,6 +35,7 @@ from sagemaker.enums import EndpointType from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements from sagemaker.workflow.properties import Properties +from sagemaker import Session from tests.unit import ( _test_default_bucket_and_prefix_combinations, DEFAULT_S3_BUCKET_NAME, @@ -115,8 +116,11 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) sms = MagicMock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, + settings=Mock(), config=None, local_mode=False, s3_client=None, @@ -924,6 +928,7 @@ def test_all_framework_models_inference_component_based_endpoint_deploy_path( model_data=source_dir, **kwargs, ).deploy( + endpoint_name="test_endpoint", instance_type="ml.m2.xlarge", initial_instance_count=INSTANCE_COUNT, endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED, diff --git a/tests/unit/sagemaker/spark/test_processing.py b/tests/unit/sagemaker/spark/test_processing.py index c20f16c35c..d6662bf0cd 100644 --- a/tests/unit/sagemaker/spark/test_processing.py +++ b/tests/unit/sagemaker/spark/test_processing.py @@ -17,6 +17,7 @@ import pytest +from sagemaker import Session from sagemaker.processing import ProcessingInput, ProcessingOutput from sagemaker.session_settings import SessionSettings from sagemaker.spark.processing import ( @@ -31,6 +32,10 @@ SPARK_EVENT_LOGS_S3_URI = "s3://bucket/spark-events" REGION = "us-east-1" BUCKET_NAME = "bucket" +PROCESSING_INPUT = ProcessingInput( + source="s3_uri", + destination="destination", +) @pytest.fixture @@ -42,19 +47,13 @@ def processing_output(): ) -@pytest.fixture -def processing_input(): - return ProcessingInput( - source="s3_uri", - destination="destination", - ) - - @pytest.fixture() def sagemaker_session(): boto_mock = MagicMock(name="boto_session", region_name=REGION) session_mock = MagicMock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, config=None, @@ -225,7 +224,7 @@ def test_spark_processor_base_run(mock_super_run, spark_processor_base): "inputs": None, "outputs": None, }, - {"inputs": [processing_input], "outputs": None}, + {"inputs": [PROCESSING_INPUT], "outputs": None}, ), ( { @@ -234,7 +233,7 @@ def test_spark_processor_base_run(mock_super_run, spark_processor_base): "inputs": [], "outputs": None, }, - {"inputs": [processing_input], "outputs": None}, + {"inputs": [PROCESSING_INPUT], "outputs": None}, ), ], ) @@ -248,7 +247,7 @@ def test_spark_processor_base_extend_processing_args( expected, sagemaker_session, ): - mock_stage_configuration.return_value = processing_input + mock_stage_configuration.return_value = PROCESSING_INPUT mock_processing_output.return_value = processing_output extended_inputs, extended_outputs = spark_processor_base._extend_processing_args( @@ -732,7 +731,7 @@ def test_check_history_server( { "inputs": None, "submit_files": None, - "files_input": [processing_input], + "files_input": [PROCESSING_INPUT], "files_opt": "opt", "file_type": FileType.JAR, }, @@ -742,34 +741,34 @@ def test_check_history_server( { "inputs": None, "submit_files": ["file1"], - "files_input": processing_input, + "files_input": PROCESSING_INPUT, "files_opt": "opt", "file_type": FileType.JAR, }, - {"command": ["smspark-submit", "--jars", "opt"], "inputs": [processing_input]}, + {"command": ["smspark-submit", "--jars", "opt"], "inputs": [PROCESSING_INPUT]}, ), ( { - "inputs": [processing_input], + "inputs": [PROCESSING_INPUT], "submit_files": ["file1"], - "files_input": processing_input, + "files_input": PROCESSING_INPUT, "files_opt": "opt", "file_type": FileType.PYTHON, }, { "command": ["smspark-submit", "--py-files", "opt"], - "inputs": [processing_input, processing_input], + "inputs": [PROCESSING_INPUT, PROCESSING_INPUT], }, ), ( { - "inputs": [processing_input], + "inputs": [PROCESSING_INPUT], "submit_files": ["file1"], "files_input": None, "files_opt": "", "file_type": FileType.PYTHON, }, - {"command": ["smspark-submit"], "inputs": [processing_input]}, + {"command": ["smspark-submit"], "inputs": [PROCESSING_INPUT]}, ), ], ) @@ -829,25 +828,25 @@ def test_config_aws_credentials(py_spark_processor): [ ({"submit_app": None, "files": ["test"], "inputs": [], "opt": None}, ValueError), ( - {"submit_app": "test.py", "files": None, "inputs": [processing_input], "opt": None}, - [processing_input], + {"submit_app": "test.py", "files": None, "inputs": [PROCESSING_INPUT], "opt": None}, + [PROCESSING_INPUT], ), ( { "submit_app": "test.py", "files": ["test"], - "inputs": [processing_input], + "inputs": [PROCESSING_INPUT], "opt": None, }, - [processing_input, processing_input, processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT, PROCESSING_INPUT, PROCESSING_INPUT], ), ( {"submit_app": "test.py", "files": ["test"], "inputs": None, "opt": None}, - [processing_input, processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT, PROCESSING_INPUT], ), ( {"submit_app": "test.py", "files": ["test"], "inputs": None, "opt": "opt"}, - [processing_input, processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT, PROCESSING_INPUT], ), ], ) @@ -862,7 +861,7 @@ def test_py_spark_processor_run( config, expected, ): - mock_stage_submit_deps.return_value = (processing_input, "opt") + mock_stage_submit_deps.return_value = (PROCESSING_INPUT, "opt") mock_generate_current_job_name.return_value = "jobName" if expected is ValueError: @@ -913,21 +912,21 @@ def test_py_spark_processor_run( { "submit_app": "test.py", "files": None, - "inputs": [processing_input], + "inputs": [PROCESSING_INPUT], "opt": None, "arguments": ["arg1"], }, - [processing_input], + [PROCESSING_INPUT], ), ( { "submit_app": "test.py", "files": ["test"], - "inputs": [processing_input], + "inputs": [PROCESSING_INPUT], "opt": None, "arguments": ["arg1"], }, - [processing_input, processing_input, processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT, PROCESSING_INPUT, PROCESSING_INPUT], ), ( { @@ -937,7 +936,7 @@ def test_py_spark_processor_run( "opt": None, "arguments": ["arg1"], }, - [processing_input, processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT, PROCESSING_INPUT], ), ( { @@ -947,7 +946,7 @@ def test_py_spark_processor_run( "opt": "opt", "arguments": ["arg1"], }, - [processing_input, processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT, PROCESSING_INPUT], ), ], ) @@ -962,7 +961,7 @@ def test_py_spark_processor_get_run_args( config, expected, ): - mock_stage_submit_deps.return_value = (processing_input, "opt") + mock_stage_submit_deps.return_value = (PROCESSING_INPUT, "opt") mock_generate_current_job_name.return_value = "jobName" if expected is ValueError: @@ -999,22 +998,22 @@ def test_py_spark_processor_get_run_args( def test_py_spark_processor_run_twice( mock_generate_current_job_name, mock_stage_submit_deps, mock_super_run, py_spark_processor ): - mock_stage_submit_deps.return_value = (processing_input, "opt") + mock_stage_submit_deps.return_value = (PROCESSING_INPUT, "opt") mock_generate_current_job_name.return_value = "jobName" expected_command = ["smspark-submit", "--py-files", "opt", "--jars", "opt", "--files", "opt"] py_spark_processor.run( submit_app="submit_app", - submit_py_files="files", - submit_jars="test", - submit_files="test", + submit_py_files=["files"], + submit_jars=["test"], + submit_files=["test"], inputs=[], ) py_spark_processor.run( submit_app="submit_app", - submit_py_files="files", - submit_jars="test", - submit_files="test", + submit_py_files=["files"], + submit_jars=["test"], + submit_files=["test"], inputs=[], ) @@ -1049,20 +1048,20 @@ def test_py_spark_processor_run_twice( "submit_app": "test.py", "submit_class": "_class", "files": None, - "inputs": [processing_input], + "inputs": [PROCESSING_INPUT], "opt": None, }, - [processing_input], + [PROCESSING_INPUT], ), ( { "submit_app": "test.py", "submit_class": "_class", "files": ["test"], - "inputs": [processing_input], + "inputs": [PROCESSING_INPUT], "opt": None, }, - [processing_input, processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT, PROCESSING_INPUT], ), ( { @@ -1072,7 +1071,7 @@ def test_py_spark_processor_run_twice( "inputs": None, "opt": None, }, - [processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT], ), ( { @@ -1082,7 +1081,7 @@ def test_py_spark_processor_run_twice( "inputs": None, "opt": "opt", }, - [processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT], ), ], ) @@ -1097,7 +1096,7 @@ def test_spark_jar_processor_run( expected, sagemaker_session, ): - mock_stage_submit_deps.return_value = (processing_input, "opt") + mock_stage_submit_deps.return_value = (PROCESSING_INPUT, "opt") mock_generate_current_job_name.return_value = "jobName" spark_jar_processor = SparkJarProcessor( @@ -1173,24 +1172,24 @@ def test_spark_jar_processor_run( "submit_app": "test.py", "submit_class": "_class", "files": None, - "inputs": [processing_input], + "inputs": [PROCESSING_INPUT], "opt": None, "arguments": ["arg1"], "kms_key": "test_kms_key", }, - [processing_input], + [PROCESSING_INPUT], ), ( { "submit_app": "test.py", "submit_class": "_class", "files": ["test"], - "inputs": [processing_input], + "inputs": [PROCESSING_INPUT], "opt": None, "arguments": ["arg1"], "kms_key": "test_kms_key", }, - [processing_input, processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT, PROCESSING_INPUT], ), ( { @@ -1202,7 +1201,7 @@ def test_spark_jar_processor_run( "arguments": ["arg1"], "kms_key": "test_kms_key", }, - [processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT], ), ( { @@ -1214,7 +1213,7 @@ def test_spark_jar_processor_run( "arguments": ["arg1"], "kms_key": "test_kms_key", }, - [processing_input, processing_input], + [PROCESSING_INPUT, PROCESSING_INPUT], ), ], ) @@ -1229,7 +1228,7 @@ def test_spark_jar_processor_get_run_args( expected, sagemaker_session, ): - mock_stage_submit_deps.return_value = (processing_input, "opt") + mock_stage_submit_deps.return_value = (PROCESSING_INPUT, "opt") mock_generate_current_job_name.return_value = "jobName" spark_jar_processor = SparkJarProcessor( diff --git a/tests/unit/sagemaker/tensorflow/test_estimator.py b/tests/unit/sagemaker/tensorflow/test_estimator.py index d6eaf74012..d030a6f0d3 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator.py @@ -25,6 +25,7 @@ from sagemaker.tensorflow import TensorFlow from sagemaker.instance_group import InstanceGroup from sagemaker.workflow.parameters import ParameterString, ParameterBoolean +from sagemaker import Session from tests.unit import DATA_DIR SCRIPT_FILE = "dummy_script.py" @@ -70,6 +71,8 @@ def sagemaker_session(): boto_session=boto_mock, boto_region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_resource=None, s3_client=None, diff --git a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py index b8ec3af69d..828cfb1fca 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator_attach.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_attach.py @@ -16,7 +16,7 @@ from mock import MagicMock, Mock, patch from packaging.version import Version -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.tensorflow import TensorFlow BUCKET_NAME = "mybucket" @@ -29,7 +29,9 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( + spec=Session, name="sagemaker_session", + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, config=None, diff --git a/tests/unit/sagemaker/tensorflow/test_estimator_init.py b/tests/unit/sagemaker/tensorflow/test_estimator_init.py index 3ea09d5b10..acb233a740 100644 --- a/tests/unit/sagemaker/tensorflow/test_estimator_init.py +++ b/tests/unit/sagemaker/tensorflow/test_estimator_init.py @@ -16,6 +16,7 @@ from packaging import version import pytest +from sagemaker import Session from sagemaker.tensorflow import TensorFlow REGION = "us-west-2" @@ -26,7 +27,10 @@ @pytest.fixture() def sagemaker_session(): session_mock = Mock( + spec=Session, name="sagemaker_session", + sagemaker_client=Mock(), + local_mode=False, boto_region_name=REGION, default_bucket_prefix=None, ) diff --git a/tests/unit/sagemaker/tensorflow/test_tfs.py b/tests/unit/sagemaker/tensorflow/test_tfs.py index eaa5b3c947..6eb2758369 100644 --- a/tests/unit/sagemaker/tensorflow/test_tfs.py +++ b/tests/unit/sagemaker/tensorflow/test_tfs.py @@ -20,6 +20,7 @@ import pytest from mock import Mock, patch, ANY +from sagemaker import Session from sagemaker.serializers import CSVSerializer, IdentitySerializer from sagemaker.tensorflow import TensorFlow, TensorFlowModel, TensorFlowPredictor @@ -54,8 +55,12 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( name="sagemaker_session", + spec=Session, boto_session=boto_mock, boto_region_name=REGION, + sagemaker_client=Mock(), + sagemaker_runtime_client=Mock(), + settings=Mock(), config=None, local_mode=False, s3_resource=None, diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py index 96f6998af6..d8059677e3 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_pytorch_compiler.py @@ -20,7 +20,7 @@ from mock import MagicMock, Mock, patch, ANY from packaging.version import Version -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig from sagemaker.huggingface.model import HuggingFaceModel from sagemaker.instance_group import InstanceGroup @@ -67,6 +67,8 @@ def fixture_sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, config=None, diff --git a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py index a650379dfd..6dcccc76af 100644 --- a/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_huggingface_tensorflow_compiler.py @@ -19,7 +19,7 @@ import pytest from mock import MagicMock, Mock, patch, ANY -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.huggingface import HuggingFace, TrainingCompilerConfig from sagemaker.huggingface.model import HuggingFaceModel from sagemaker.session_settings import SessionSettings @@ -65,6 +65,8 @@ def fixture_sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, config=None, diff --git a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py index 9a7ba698f3..dc0e878f44 100644 --- a/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_pytorch_compiler.py @@ -20,7 +20,7 @@ from mock import MagicMock, Mock, patch, ANY from packaging.version import Version -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.pytorch import PyTorch, TrainingCompilerConfig from sagemaker.pytorch.model import PyTorchModel from sagemaker.instance_group import InstanceGroup @@ -66,6 +66,8 @@ def fixture_sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, config=None, diff --git a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py index 67530bc288..2801fec737 100644 --- a/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py +++ b/tests/unit/sagemaker/training_compiler/test_tensorflow_compiler.py @@ -20,7 +20,7 @@ import pytest from mock import MagicMock, Mock, patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.session_settings import SessionSettings from sagemaker.tensorflow import TensorFlow, TrainingCompilerConfig @@ -73,9 +73,11 @@ def fixture_sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( name="sagemaker_session", + spec=Session, boto_session=boto_mock, boto_region_name=REGION, config=None, + sagemaker_client=Mock(), local_mode=False, s3_resource=None, s3_client=None, diff --git a/tests/unit/sagemaker/workflow/test_airflow.py b/tests/unit/sagemaker/workflow/test_airflow.py index 742320cfb8..b6cf7bfee1 100644 --- a/tests/unit/sagemaker/workflow/test_airflow.py +++ b/tests/unit/sagemaker/workflow/test_airflow.py @@ -15,7 +15,17 @@ import pytest from mock import Mock, MagicMock, patch -from sagemaker import chainer, estimator, model, mxnet, tensorflow, transformer, tuner, processing +from sagemaker import ( + Session, + chainer, + estimator, + model, + mxnet, + tensorflow, + transformer, + tuner, + processing, +) from sagemaker.network import NetworkConfig from sagemaker.processing import ProcessingInput, ProcessingOutput from sagemaker.workflow import airflow @@ -32,9 +42,12 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, config=None, + settings=Mock(), local_mode=False, s3_resource=None, s3_client=None, @@ -53,7 +66,7 @@ def test_byo_training_config_required_args(sagemaker_session): byo = estimator.Estimator( image_uri="byo", role="{{ role }}", - instance_count="{{ instance_count }}", + instance_count=1, instance_type="ml.c4.2xlarge", sagemaker_session=sagemaker_session, ) @@ -69,7 +82,7 @@ def test_byo_training_config_required_args(sagemaker_session): "TrainingJobName": "byo-%s" % TIME_STAMP, "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, "ResourceConfig": { - "InstanceCount": "{{ instance_count }}", + "InstanceCount": 1, "InstanceType": "ml.c4.2xlarge", "VolumeSizeInGB": 30, }, @@ -96,11 +109,11 @@ def test_byo_training_config_all_args(sagemaker_session): byo = estimator.Estimator( image_uri="byo", role="{{ role }}", - instance_count="{{ instance_count }}", + instance_count=1, instance_type="ml.c4.2xlarge", - volume_size="{{ volume_size }}", + volume_size=1024, volume_kms_key="{{ volume_kms_key }}", - max_run="{{ max_run }}", + max_run=1000, input_mode="Pipe", output_path="{{ output_path }}", output_kms_key="{{ output_volume_kms_key }}", @@ -126,11 +139,11 @@ def test_byo_training_config_all_args(sagemaker_session): "KmsKeyId": "{{ output_volume_kms_key }}", }, "TrainingJobName": "{{ base_job_name }}-%s" % TIME_STAMP, - "StoppingCondition": {"MaxRuntimeInSeconds": "{{ max_run }}"}, + "StoppingCondition": {"MaxRuntimeInSeconds": 1000}, "ResourceConfig": { - "InstanceCount": "{{ instance_count }}", + "InstanceCount": 1, "InstanceType": "ml.c4.2xlarge", - "VolumeSizeInGB": "{{ volume_size }}", + "VolumeSizeInGB": 1024, "VolumeKmsKeyId": "{{ volume_kms_key }}", }, "RoleArn": "{{ role }}", @@ -188,7 +201,7 @@ def test_framework_training_config_required_args(retrieve_image_uri, sagemaker_s framework_version="1.15.2", py_version="py3", role="{{ role }}", - instance_count="{{ instance_count }}", + instance_count=1, instance_type="ml.c4.2xlarge", sagemaker_session=sagemaker_session, ) @@ -205,7 +218,7 @@ def test_framework_training_config_required_args(retrieve_image_uri, sagemaker_s "TrainingJobName": "tensorflow-training-%s" % TIME_STAMP, "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, "ResourceConfig": { - "InstanceCount": "{{ instance_count }}", + "InstanceCount": 1, "InstanceType": "ml.c4.2xlarge", "VolumeSizeInGB": 30, }, @@ -268,9 +281,9 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio role="{{ role }}", instance_count=1, instance_type="ml.c4.2xlarge", - volume_size="{{ volume_size }}", + volume_size=1024, volume_kms_key="{{ volume_kms_key }}", - max_run="{{ max_run }}", + max_run=1000, input_mode="Pipe", output_path="{{ output_path }}", output_kms_key="{{ output_volume_kms_key }}", @@ -298,11 +311,11 @@ def test_framework_training_config_all_args(retrieve_image_uri, sagemaker_sessio "KmsKeyId": "{{ output_volume_kms_key }}", }, "TrainingJobName": "{{ base_job_name }}-%s" % TIME_STAMP, - "StoppingCondition": {"MaxRuntimeInSeconds": "{{ max_run }}"}, + "StoppingCondition": {"MaxRuntimeInSeconds": 1000}, "ResourceConfig": { "InstanceCount": 1, "InstanceType": "ml.c4.2xlarge", - "VolumeSizeInGB": "{{ volume_size }}", + "VolumeSizeInGB": 1024, "VolumeKmsKeyId": "{{ volume_kms_key }}", }, "RoleArn": "{{ role }}", @@ -357,7 +370,7 @@ def test_amazon_alg_training_config_required_args(sagemaker_session): ntm_estimator = ntm.NTM( role="{{ role }}", num_topics=10, - instance_count="{{ instance_count }}", + instance_count=1, instance_type="ml.c4.2xlarge", sagemaker_session=sagemaker_session, ) @@ -376,7 +389,7 @@ def test_amazon_alg_training_config_required_args(sagemaker_session): "TrainingJobName": "ntm-%s" % TIME_STAMP, "StoppingCondition": {"MaxRuntimeInSeconds": 86400}, "ResourceConfig": { - "InstanceCount": "{{ instance_count }}", + "InstanceCount": 1, "InstanceType": "ml.c4.2xlarge", "VolumeSizeInGB": 30, }, @@ -408,11 +421,11 @@ def test_amazon_alg_training_config_all_args(sagemaker_session): ntm_estimator = ntm.NTM( role="{{ role }}", num_topics=10, - instance_count="{{ instance_count }}", + instance_count=1, instance_type="ml.c4.2xlarge", - volume_size="{{ volume_size }}", + volume_size=1024, volume_kms_key="{{ volume_kms_key }}", - max_run="{{ max_run }}", + max_run=1000, input_mode="Pipe", output_path="{{ output_path }}", output_kms_key="{{ output_volume_kms_key }}", @@ -438,11 +451,11 @@ def test_amazon_alg_training_config_all_args(sagemaker_session): "KmsKeyId": "{{ output_volume_kms_key }}", }, "TrainingJobName": "{{ base_job_name }}-%s" % TIME_STAMP, - "StoppingCondition": {"MaxRuntimeInSeconds": "{{ max_run }}"}, + "StoppingCondition": {"MaxRuntimeInSeconds": 1000}, "ResourceConfig": { - "InstanceCount": "{{ instance_count }}", + "InstanceCount": 1, "InstanceType": "ml.c4.2xlarge", - "VolumeSizeInGB": "{{ volume_size }}", + "VolumeSizeInGB": 1024, "VolumeKmsKeyId": "{{ volume_kms_key }}", }, "RoleArn": "{{ role }}", @@ -1092,7 +1105,7 @@ def test_model_config_from_framework_estimator(retrieve_image_uri, sagemaker_ses def test_model_config_from_amazon_alg_estimator(sagemaker_session): knn_estimator = knn.KNN( role="{{ role }}", - instance_count="{{ instance_count }}", + instance_count=1, instance_type="ml.m4.xlarge", k=16, sample_size=128, @@ -1126,7 +1139,7 @@ def test_model_config_from_amazon_alg_estimator(sagemaker_session): def test_transform_config(sagemaker_session): tf_transformer = transformer.Transformer( model_name="tensorflow-model", - instance_count="{{ instance_count }}", + instance_count=1, instance_type="ml.p2.xlarge", strategy="SingleRecord", assemble_with="Line", @@ -1173,7 +1186,7 @@ def test_transform_config(sagemaker_session): "Accept": "{{ accept }}", }, "TransformResources": { - "InstanceCount": "{{ instance_count }}", + "InstanceCount": 1, "InstanceType": "ml.p2.xlarge", "VolumeKmsKeyId": "{{ kms_key }}", }, @@ -1232,7 +1245,7 @@ def test_transform_config_from_framework_estimator(retrieve_image_uri, sagemaker estimator=mxnet_estimator, task_id="task_id", task_type="training", - instance_count="{{ instance_count }}", + instance_count=1, instance_type="ml.p2.xlarge", data=transform_data, input_filter="{{ input_filter }}", @@ -1267,7 +1280,7 @@ def test_transform_config_from_framework_estimator(retrieve_image_uri, sagemaker }, "TransformOutput": {"S3OutputPath": "s3://output/{{ base_job_name }}-%s" % TIME_STAMP}, "TransformResources": { - "InstanceCount": "{{ instance_count }}", + "InstanceCount": 1, "InstanceType": "ml.p2.xlarge", }, "Environment": {}, @@ -1286,7 +1299,7 @@ def test_transform_config_from_framework_estimator(retrieve_image_uri, sagemaker def test_transform_config_from_amazon_alg_estimator(sagemaker_session): knn_estimator = knn.KNN( role="{{ role }}", - instance_count="{{ instance_count }}", + instance_count=1, instance_type="ml.m4.xlarge", k=16, sample_size=128, @@ -1304,7 +1317,7 @@ def test_transform_config_from_amazon_alg_estimator(sagemaker_session): estimator=knn_estimator, task_id="task_id", task_type="training", - instance_count="{{ instance_count }}", + instance_count=1, instance_type="ml.p2.xlarge", data=transform_data, ) @@ -1329,7 +1342,7 @@ def test_transform_config_from_amazon_alg_estimator(sagemaker_session): }, "TransformOutput": {"S3OutputPath": "s3://output/knn-%s" % TIME_STAMP}, "TransformResources": { - "InstanceCount": "{{ instance_count }}", + "InstanceCount": 1, "InstanceType": "ml.p2.xlarge", }, }, @@ -1353,7 +1366,7 @@ def test_deploy_framework_model_config(sagemaker_session): ) config = airflow.deploy_config( - chainer_model, initial_instance_count="{{ instance_count }}", instance_type="ml.m4.xlarge" + chainer_model, initial_instance_count=1, instance_type="ml.m4.xlarge" ) expected_config = { "Model": { @@ -1377,7 +1390,7 @@ def test_deploy_framework_model_config(sagemaker_session): "ProductionVariants": [ { "InstanceType": "ml.m4.xlarge", - "InitialInstanceCount": "{{ instance_count }}", + "InitialInstanceCount": 1, "ModelName": "sagemaker-chainer-%s" % TIME_STAMP, "VariantName": "AllTraffic", "InitialVariantWeight": 1, @@ -1410,7 +1423,7 @@ def test_deploy_amazon_alg_model_config(sagemaker_session): ) config = airflow.deploy_config( - pca_model, initial_instance_count="{{ instance_count }}", instance_type="ml.c4.xlarge" + pca_model, initial_instance_count=1, instance_type="ml.c4.xlarge" ) expected_config = { "Model": { @@ -1427,7 +1440,7 @@ def test_deploy_amazon_alg_model_config(sagemaker_session): "ProductionVariants": [ { "InstanceType": "ml.c4.xlarge", - "InitialInstanceCount": "{{ instance_count }}", + "InitialInstanceCount": 1, "ModelName": "pca-%s" % TIME_STAMP, "VariantName": "AllTraffic", "InitialVariantWeight": 1, @@ -1528,7 +1541,7 @@ def test_deploy_config_from_framework_estimator(retrieve_image_uri, sagemaker_se def test_deploy_config_from_amazon_alg_estimator(sagemaker_session): knn_estimator = knn.KNN( role="{{ role }}", - instance_count="{{ instance_count }}", + instance_count=1, instance_type="ml.m4.xlarge", k=16, sample_size=128, @@ -1545,7 +1558,7 @@ def test_deploy_config_from_amazon_alg_estimator(sagemaker_session): estimator=knn_estimator, task_id="task_id", task_type="tuning", - initial_instance_count="{{ instance_count }}", + initial_instance_count=1, instance_type="ml.p2.xlarge", ) expected_config = { @@ -1564,7 +1577,7 @@ def test_deploy_config_from_amazon_alg_estimator(sagemaker_session): "ProductionVariants": [ { "InstanceType": "ml.p2.xlarge", - "InitialInstanceCount": "{{ instance_count }}", + "InitialInstanceCount": 1, "ModelName": "knn-%s" % TIME_STAMP, "VariantName": "AllTraffic", "InitialVariantWeight": 1, diff --git a/tests/unit/sagemaker/workflow/test_utils.py b/tests/unit/sagemaker/workflow/test_utils.py index 48b1d762c3..30898e6421 100644 --- a/tests/unit/sagemaker/workflow/test_utils.py +++ b/tests/unit/sagemaker/workflow/test_utils.py @@ -31,6 +31,9 @@ from tests.unit import DATA_DIR from tests.unit.sagemaker.workflow.conftest import ROLE, IMAGE_URI, BUCKET +REGION = "us-west-2" +BUCKET_NAME = "output" + @pytest.fixture def estimator(sagemaker_session): diff --git a/tests/unit/test_algorithm.py b/tests/unit/test_algorithm.py index 0e15981b24..a2cb0dfd77 100644 --- a/tests/unit/test_algorithm.py +++ b/tests/unit/test_algorithm.py @@ -18,6 +18,7 @@ import pytest from mock import Mock, patch +from sagemaker.session import Session from sagemaker.algorithm import AlgorithmEstimator from sagemaker.estimator import _TrainingJob from sagemaker.transformer import Transformer @@ -154,11 +155,14 @@ } -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_supported_input_mode_with_valid_input_types(session): # verify that the Estimator verifies the # input mode that an Algorithm supports. session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False + session.local_mode = False file_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) file_mode_algo["TrainingSpecification"]["TrainingChannels"] = [ @@ -256,11 +260,13 @@ def test_algorithm_supported_input_mode_with_valid_input_types(session): ) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_supported_input_mode_with_bad_input_types(session): # verify that the Estimator verifies raises exceptions when # attempting to train with an incorrect input type session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False file_mode_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) file_mode_algo["TrainingSpecification"]["TrainingChannels"] = [ @@ -329,9 +335,11 @@ def test_algorithm_supported_input_mode_with_bad_input_types(session): @patch("sagemaker.estimator.EstimatorBase.fit", Mock()) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_trainining_channels_with_expected_channels(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False training_channels = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) training_channels["TrainingSpecification"]["TrainingChannels"] = [ @@ -371,9 +379,11 @@ def test_algorithm_trainining_channels_with_expected_channels(session): @patch("sagemaker.estimator.EstimatorBase.fit", Mock()) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_trainining_channels_with_invalid_channels(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False training_channels = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) training_channels["TrainingSpecification"]["TrainingChannels"] = [ @@ -414,9 +424,11 @@ def test_algorithm_trainining_channels_with_invalid_channels(session): estimator.fit({"training": "s3://some/data", "training2": "s3://some/other/data"}) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_train_instance_types_valid_instance_types(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False describe_algo_response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) instance_types = ["ml.m4.xlarge", "ml.m5.2xlarge"] @@ -443,9 +455,11 @@ def test_algorithm_train_instance_types_valid_instance_types(session): ) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_train_instance_types_invalid_instance_types(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False describe_algo_response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) instance_types = ["ml.m4.xlarge", "ml.m5.2xlarge"] @@ -466,9 +480,11 @@ def test_algorithm_train_instance_types_invalid_instance_types(session): ) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_distributed_training_validation(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False distributed_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) distributed_algo["TrainingSpecification"]["SupportsDistributedTraining"] = True @@ -507,9 +523,11 @@ def test_algorithm_distributed_training_validation(session): ) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_hyperparameter_integer_range_valid_range(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False hyperparameters = [ { "Description": "Grow a tree with max_leaf_nodes in best-first fashion.", @@ -541,9 +559,11 @@ def test_algorithm_hyperparameter_integer_range_valid_range(session): estimator.set_hyperparameters(max_leaf_nodes=100000) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_hyperparameter_integer_range_invalid_range(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False hyperparameters = [ { "Description": "Grow a tree with max_leaf_nodes in best-first fashion.", @@ -578,9 +598,11 @@ def test_algorithm_hyperparameter_integer_range_invalid_range(session): estimator.set_hyperparameters(max_leaf_nodes=100001) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_hyperparameter_continuous_range_valid_range(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False hyperparameters = [ { "Description": "A continuous hyperparameter", @@ -614,9 +636,11 @@ def test_algorithm_hyperparameter_continuous_range_valid_range(session): estimator.set_hyperparameters(max_leaf_nodes=1) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_hyperparameter_continuous_range_invalid_range(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False hyperparameters = [ { "Description": "A continuous hyperparameter", @@ -651,9 +675,11 @@ def test_algorithm_hyperparameter_continuous_range_invalid_range(session): estimator.set_hyperparameters(max_leaf_nodes=-0.1) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_hyperparameter_categorical_range(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False hyperparameters = [ { "Description": "A continuous hyperparameter", @@ -689,9 +715,11 @@ def test_algorithm_hyperparameter_categorical_range(session): estimator.set_hyperparameters(hp1="MxNET") -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_required_hyperparameters_not_provided(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False hyperparameters = [ { "Description": "A continuous hyperparameter", @@ -733,10 +761,12 @@ def test_algorithm_required_hyperparameters_not_provided(session): estimator.fit({"training": "s3://some/place"}) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) @patch("sagemaker.estimator.EstimatorBase.fit", Mock()) def test_algorithm_required_hyperparameters_are_provided(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False hyperparameters = [ { "Description": "A categorical hyperparameter", @@ -779,9 +809,11 @@ def test_algorithm_required_hyperparameters_are_provided(session): estimator.set_hyperparameters(hp1="TF", hp2="TF2", free_text_hp1="Hello!") -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_required_free_text_hyperparameter_not_provided(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False hyperparameters = [ { "Name": "free_text_hp1", @@ -822,10 +854,12 @@ def test_algorithm_required_free_text_hyperparameter_not_provided(session): estimator.set_hyperparameters(free_text_hp2="some text") -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) @patch("sagemaker.algorithm.AlgorithmEstimator.create_model") def test_algorithm_create_transformer(create_model, session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -848,9 +882,11 @@ def test_algorithm_create_transformer(create_model, session): assert transformer.model_name == "my-model" -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_create_transformer_without_completed_training_job(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -866,10 +902,12 @@ def test_algorithm_create_transformer_without_completed_training_job(session): assert "No finished training job found associated with this estimator" in str(error) +@patch("sagemaker.Session", spec=Session) @patch("sagemaker.algorithm.AlgorithmEstimator.create_model") -@patch("sagemaker.Session") def test_algorithm_create_transformer_with_product_id(create_model, session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response["ProductId"] = "some-product-id" session.sagemaker_client.describe_algorithm = Mock(return_value=response) @@ -891,9 +929,11 @@ def test_algorithm_create_transformer_with_product_id(create_model, session): assert transformer.env is None -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_enable_network_isolation_no_product_id(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE) estimator = AlgorithmEstimator( @@ -908,9 +948,11 @@ def test_algorithm_enable_network_isolation_no_product_id(session): assert network_isolation is False -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_enable_network_isolation_with_product_id(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response["ProductId"] = "some-product-id" session.sagemaker_client.describe_algorithm = Mock(return_value=response) @@ -927,9 +969,11 @@ def test_algorithm_enable_network_isolation_with_product_id(session): assert network_isolation is True -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_encrypt_inter_container_traffic(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False response = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) response["encrypt_inter_container_traffic"] = True session.sagemaker_client.describe_algorithm = Mock(return_value=response) @@ -947,9 +991,11 @@ def test_algorithm_encrypt_inter_container_traffic(session): assert encrypt_inter_container_traffic is True -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_no_required_hyperparameters(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False some_algo = copy.deepcopy(DESCRIBE_ALGORITHM_RESPONSE) del some_algo["TrainingSpecification"]["SupportedHyperParameters"] @@ -968,8 +1014,10 @@ def test_algorithm_no_required_hyperparameters(session): def test_algorithm_attach_from_hyperparameter_tuning(): - session = Mock() + session = Mock(spec=Session) session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False job_name = "training-job-that-is-part-of-a-tuning-job" algo_arn = "arn:aws:sagemaker:us-east-2:000000000000:algorithm/scikit-decision-trees" role_arn = "arn:aws:iam::123412341234:role/SageMakerRole" @@ -1040,9 +1088,11 @@ def test_algorithm_attach_from_hyperparameter_tuning(): assert estimator.sagemaker_session == session -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_algorithm_supported_with_spot_instances(session): session.sagemaker_config = {} + session.sagemaker_client = Mock() + session.local_mode = False session.sagemaker_client.describe_algorithm = Mock(return_value=DESCRIBE_ALGORITHM_RESPONSE) assert AlgorithmEstimator( diff --git a/tests/unit/test_amazon_estimator.py b/tests/unit/test_amazon_estimator.py index 8b00b68dd9..0d01636cde 100644 --- a/tests/unit/test_amazon_estimator.py +++ b/tests/unit/test_amazon_estimator.py @@ -23,6 +23,7 @@ _build_shards, FileSystemRecordSet, ) +from sagemaker import Session from sagemaker.session_settings import SessionSettings from tests.unit import ( DEFAULT_S3_OBJECT_KEY_PREFIX_NAME, @@ -43,6 +44,8 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) sms = Mock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), boto_session=boto_mock, region_name=REGION, config=None, diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 8ae318cb83..61b8c3d036 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -25,6 +25,7 @@ from sagemaker.chainer import Chainer from sagemaker.chainer import ChainerPredictor, ChainerModel from sagemaker.session_settings import SessionSettings +from sagemaker import Session DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") @@ -57,6 +58,8 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, config=None, diff --git a/tests/unit/test_djl_inference.py b/tests/unit/test_djl_inference.py index cc8a99cf1c..bcea093cf0 100644 --- a/tests/unit/test_djl_inference.py +++ b/tests/unit/test_djl_inference.py @@ -18,8 +18,7 @@ from json import JSONDecodeError import pytest -from mock import Mock, MagicMock -from mock import patch, mock_open +from mock import Mock, MagicMock, patch, mock_open from sagemaker.djl_inference import ( defaults, @@ -31,6 +30,7 @@ from sagemaker.djl_inference.model import DJLServingEngineEntryPointDefaults from sagemaker.s3_utils import s3_path_join from sagemaker.session_settings import SessionSettings +from sagemaker import Session from tests.unit import ( _test_default_bucket_and_prefix_combinations, DEFAULT_S3_BUCKET_NAME, @@ -54,7 +54,9 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session = Mock( - "sagemaker_session", + name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, config=None, diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 3d8b0c454d..a41f9306a7 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -54,6 +54,7 @@ from sagemaker.mxnet.estimator import MXNet from sagemaker.predictor import Predictor from sagemaker.pytorch.estimator import PyTorch +from sagemaker.session import Session from sagemaker.session_settings import SessionSettings from sagemaker.sklearn.estimator import SKLearn from sagemaker.tensorflow.estimator import TensorFlow @@ -282,6 +283,8 @@ def sagemaker_session(): boto_session=boto_mock, boto_region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, @@ -1337,10 +1340,13 @@ def test_framework_with_no_default_profiler_in_unsupported_region(region): boto_session=boto_mock, boto_region_name=region, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.sagemaker_config = {} f = DummyFramework( @@ -1368,10 +1374,13 @@ def test_framework_with_debugger_config_set_up_in_unsupported_region(region): boto_session=boto_mock, boto_region_name=region, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.sagemaker_config = {} f = DummyFramework( @@ -1396,10 +1405,13 @@ def test_framework_enable_profiling_in_unsupported_region(region): boto_session=boto_mock, boto_region_name=region, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.sagemaker_config = {} f = DummyFramework( @@ -1424,10 +1436,13 @@ def test_framework_update_profiling_in_unsupported_region(region): boto_session=boto_mock, boto_region_name=region, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.sagemaker_config = {} f = DummyFramework( @@ -1452,10 +1467,13 @@ def test_framework_disable_profiling_in_unsupported_region(region): boto_session=boto_mock, boto_region_name=region, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, settings=SessionSettings(), + default_bucket_prefix=None, ) sms.sagemaker_config = {} f = DummyFramework( @@ -4867,7 +4885,7 @@ def test_script_mode_estimator(patched_stage_user_code, sagemaker_session): @patch("time.time", return_value=TIME) @patch("sagemaker.estimator.tar_and_upload_dir") def test_script_mode_estimator_same_calls_as_framework( - patched_tar_and_upload_dir, sagemaker_session + patched_tar_and_upload_dir, time_patched, sagemaker_session ): patched_tar_and_upload_dir.return_value = UploadedCode( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" @@ -4936,7 +4954,7 @@ def test_script_mode_estimator_same_calls_as_framework( @patch("sagemaker.estimator.tar_and_upload_dir") @patch("sagemaker.model.Model._upload_code") def test_script_mode_estimator_tags_jumpstart_estimators_and_models( - patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_upload_code, patched_tar_and_upload_dir, time_patched, sagemaker_session ): patched_tar_and_upload_dir.return_value = UploadedCode( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" @@ -5014,7 +5032,7 @@ def test_script_mode_estimator_tags_jumpstart_estimators_and_models( @patch("sagemaker.estimator.tar_and_upload_dir") @patch("sagemaker.model.Model._upload_code") def test_script_mode_estimator_tags_jumpstart_models( - patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_upload_code, patched_tar_and_upload_dir, time_patched, sagemaker_session ): patched_tar_and_upload_dir.return_value = UploadedCode( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" @@ -5076,7 +5094,7 @@ def test_script_mode_estimator_tags_jumpstart_models( @patch("sagemaker.estimator.tar_and_upload_dir") @patch("sagemaker.model.Model._upload_code") def test_script_mode_estimator_tags_jumpstart_models_with_no_estimator_js_tags( - patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_upload_code, patched_tar_and_upload_dir, time_patched, sagemaker_session ): patched_tar_and_upload_dir.return_value = UploadedCode( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" @@ -5308,7 +5326,7 @@ def test_all_framework_estimators_support_disabling_jumpstart_uri_tags( @patch("sagemaker.estimator.tar_and_upload_dir") @patch("sagemaker.model.Model._upload_code") def test_script_mode_estimator_uses_jumpstart_base_name_with_js_models( - patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_upload_code, patched_tar_and_upload_dir, time_patched, sagemaker_session ): patched_tar_and_upload_dir.return_value = UploadedCode( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" @@ -5523,7 +5541,7 @@ def test_insert_invalid_source_code_args(): @patch("sagemaker.estimator.tar_and_upload_dir") @patch("sagemaker.model.Model._upload_code") def test_script_mode_estimator_escapes_hyperparameters_as_json( - patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_upload_code, patched_tar_and_upload_dir, time_patched, sagemaker_session ): patched_tar_and_upload_dir.return_value = UploadedCode( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" @@ -5574,7 +5592,7 @@ def test_script_mode_estimator_escapes_hyperparameters_as_json( @patch("sagemaker.estimator.tar_and_upload_dir") @patch("sagemaker.model.Model._upload_code") def test_estimator_local_download_dir( - patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_upload_code, patched_tar_and_upload_dir, time_patched, sagemaker_session ): patched_tar_and_upload_dir.return_value = UploadedCode( s3_prefix="s3://%s/%s" % ("bucket", "key"), script_name="script_name" @@ -5584,7 +5602,7 @@ def test_estimator_local_download_dir( local_download_dir = "some/download/dir" - sagemaker_session.settings.local_download_dir = local_download_dir + sagemaker_session.settings = SessionSettings(local_download_dir=local_download_dir) instance_type = "ml.p2.xlarge" instance_count = 1 diff --git a/tests/unit/test_fm.py b/tests/unit/test_fm.py index ebac0dfbb9..8dac17c3b3 100644 --- a/tests/unit/test_fm.py +++ b/tests/unit/test_fm.py @@ -22,6 +22,7 @@ ) from sagemaker.amazon.amazon_estimator import RecordSet from sagemaker.session_settings import SessionSettings +from sagemaker.session import Session ROLE = "myrole" INSTANCE_COUNT = 1 @@ -56,6 +57,8 @@ def sagemaker_session(): boto_session=boto_mock, region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=False, s3_resource=False, diff --git a/tests/unit/test_ipinsights.py b/tests/unit/test_ipinsights.py index da4b8a9477..52375b5301 100644 --- a/tests/unit/test_ipinsights.py +++ b/tests/unit/test_ipinsights.py @@ -15,7 +15,7 @@ import pytest from mock import Mock, patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.amazon.ipinsights import IPInsights, IPInsightsPredictor from sagemaker.amazon.amazon_estimator import RecordSet from sagemaker.session_settings import SessionSettings @@ -55,6 +55,8 @@ def sagemaker_session(): boto_session=boto_mock, region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, settings=SessionSettings(), default_bucket_prefix=None, diff --git a/tests/unit/test_job.py b/tests/unit/test_job.py index 603b494e5a..755471d021 100644 --- a/tests/unit/test_job.py +++ b/tests/unit/test_job.py @@ -16,7 +16,7 @@ import os from mock import Mock -from sagemaker import TrainingInput +from sagemaker import TrainingInput, Session from sagemaker.amazon.amazon_estimator import RecordSet, FileSystemRecordSet from sagemaker.estimator import Estimator, Framework from sagemaker.inputs import FileSystemInput @@ -80,6 +80,9 @@ def sagemaker_session(): mock_session = Mock( name="sagemaker_session", boto_session=boto_mock, + spec=Session, + sagemaker_client=Mock(), + local_mode=False, s3_client=None, s3_resource=None, default_bucket_prefix=None, diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index 3d91726478..3816ced2f5 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -15,7 +15,7 @@ import pytest from mock import Mock, patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.amazon.kmeans import KMeans, KMeansPredictor from sagemaker.amazon.amazon_estimator import RecordSet from sagemaker.session_settings import SessionSettings @@ -50,6 +50,8 @@ def sagemaker_session(): boto_session=boto_mock, region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, diff --git a/tests/unit/test_knn.py b/tests/unit/test_knn.py index 0480d1891c..a75e8bddd8 100644 --- a/tests/unit/test_knn.py +++ b/tests/unit/test_knn.py @@ -15,7 +15,7 @@ import pytest from mock import Mock, patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.amazon.knn import KNN, KNNPredictor from sagemaker.amazon.amazon_estimator import RecordSet from sagemaker.session_settings import SessionSettings @@ -56,6 +56,8 @@ def sagemaker_session(): boto_session=boto_mock, region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py index f39df24d75..a64dd9cc84 100644 --- a/tests/unit/test_lda.py +++ b/tests/unit/test_lda.py @@ -15,7 +15,7 @@ import pytest from mock import Mock, patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.amazon.lda import LDA, LDAPredictor from sagemaker.amazon.amazon_estimator import RecordSet from sagemaker.session_settings import SessionSettings @@ -45,6 +45,8 @@ def sagemaker_session(): name="sagemaker_session", boto_session=boto_mock, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py index 3e45d76784..6dbf382f23 100644 --- a/tests/unit/test_linear_learner.py +++ b/tests/unit/test_linear_learner.py @@ -15,7 +15,7 @@ import pytest from mock import Mock, patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.amazon.linear_learner import LinearLearner, LinearLearnerPredictor from sagemaker.amazon.amazon_estimator import RecordSet from sagemaker.session_settings import SessionSettings @@ -51,6 +51,8 @@ def sagemaker_session(): boto_session=boto_mock, region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, diff --git a/tests/unit/test_model_card.py b/tests/unit/test_model_card.py index e179cc02c4..e72ac64028 100644 --- a/tests/unit/test_model_card.py +++ b/tests/unit/test_model_card.py @@ -23,6 +23,7 @@ from botocore.exceptions import ClientError import botocore.response +from sagemaker import Session from sagemaker.model_card import schema_constraints from sagemaker.model_card import ( Environment, @@ -1001,7 +1002,7 @@ def fixture_additional_information_example(): return test_example -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_create_model_card( session, model_overview_example, @@ -1011,6 +1012,7 @@ def test_create_model_card( evaluation_details_example, additional_information_example, ): + session.sagemaker_client = Mock() session.sagemaker_client.create_model_card = Mock(return_value=CREATE_MODEL_CARD_RETURN_EXAMPLE) session.sagemaker_client.describe_model_card = Mock(return_value=LOAD_MODEL_CARD_EXMPLE) @@ -1031,10 +1033,11 @@ def test_create_model_card( assert card.arn == MODEL_CARD_ARN -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_create_model_card_with_model_package( session, model_package_example, training_details_example, caplog ): + session.sagemaker_client = Mock() session.sagemaker_client.create_model_card = Mock(return_value=CREATE_MODEL_CARD_RETURN_EXAMPLE) session.sagemaker_client.describe_model_card = Mock( return_value=MODEL_CARD_WITH_MODEL_PACKAGE_MOCK_RESPONSE @@ -1076,7 +1079,7 @@ def test_create_model_card_with_model_package( ) in caplog.text -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_create_model_card_with_multiple_models( session, model_package_example, model_overview_example ): @@ -1097,8 +1100,9 @@ def test_create_model_card_with_multiple_models( card.model_package_details = model_package_example -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_create_model_card_duplicate(session): + session.sagemaker_client = Mock() session.sagemaker_client.create_model_card.side_effect = [ CREATE_MODEL_CARD_RETURN_EXAMPLE, GENERAL_CLIENT_ERROR, @@ -1117,8 +1121,9 @@ def test_create_model_card_duplicate(session): card2.create() -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_create_multiple_model_cards_with_same_model(session, model_overview_example): + session.sagemaker_client = Mock() session.sagemaker_client.create_model_card.side_effect = [ CREATE_SIMPLE_MODEL_CARD_RETURN_EXAMPLE, GENERAL_CLIENT_ERROR, @@ -1142,8 +1147,9 @@ def test_create_multiple_model_cards_with_same_model(session, model_overview_exa card2.create() -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_create_model_card_with_too_long_string(session, model_overview_example): + session.sagemaker_client = Mock() too_long_string_client_error = ClientError( error_response={ "Error": { @@ -1170,8 +1176,9 @@ def test_create_model_card_with_too_long_string(session, model_overview_example) card.create() -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_carry_over_additional_content_from_model_package_group(session, model_package_example): + session.sagemaker_client = Mock() session.sagemaker_client.describe_model_card = Mock( return_value=DESCRIBE_MODEL_CARD_WITH_ADDITONAL_CONTENT ) @@ -1369,8 +1376,9 @@ def __init__(self, attr1): # pylint: disable=C0116 ) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_load_model_card(session): + session.sagemaker_client = Mock() session.sagemaker_client.describe_model_card = Mock(return_value=LOAD_SIMPLE_MODEL_CARD_EXMPLE) card = ModelCard.load(name=SIMPLE_MODEL_CARD_NAME, sagemaker_session=session) @@ -1379,11 +1387,12 @@ def test_load_model_card(session): assert card.additional_information -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_update_model_card( session, additional_information_example, ): + session.sagemaker_client = Mock() session.sagemaker_client.create_model_card = Mock( return_value=CREATE_SIMPLE_MODEL_CARD_RETURN_EXAMPLE ) @@ -1404,8 +1413,9 @@ def test_update_model_card( assert card.update() -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_delete_model_card(session): + session.sagemaker_client = Mock() session.sagemaker_client.describe_model_card = Mock(return_value=LOAD_SIMPLE_MODEL_CARD_EXMPLE) session.sagemaker_client.delete_model_card = Mock( return_value=DELETE_SIMPLE_MODEL_CARD_RETURN_EXAMPLE @@ -1467,8 +1477,9 @@ def test_hash_content_str(): assert _hash_content_str(content1) != _hash_content_str(content2) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_model_details_autodiscovery(session): + session.sagemaker_client = Mock() session.sagemaker_client.describe_model.side_effect = [ DESCRIBE_MODEL_EXAMPLE, DESCRIBE_MODEL_EXAMPLE, @@ -1501,8 +1512,9 @@ def test_model_details_autodiscovery(session): ModelOverview.from_model_name(MODEL_NAME, sagemaker_session=session) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_model_package_autodiscovery(session, model_overview_example, training_details_example): + session.sagemaker_client = Mock() session.sagemaker_client.describe_model_package.side_effect = [ DESCRIBE_MODEL_PACKAGE_EXAMPLE, DESCRIBE_MODEL_PACKAGE_EXAMPLE, @@ -1557,10 +1569,11 @@ def test_model_package_autodiscovery(session, model_overview_example, training_d ModelPackage.from_model_package_arn(MODEL_PACKAGE_ARN, sagemaker_session=session) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_training_details_autodiscovery_from_model_overview( session, model_overview_example, caplog ): + session.sagemaker_client = Mock() session.sagemaker_client.search.side_effect = [ SEARCH_TRAINING_JOB_EXAMPLE, SEARCH_IAM_PERMISSION_CLIENT_ERROR, @@ -1618,10 +1631,11 @@ def test_training_details_autodiscovery_from_model_overview( ) in caplog.text -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_training_details_autodiscovery_from_model_package_details( session, model_package_example, caplog ): + session.sagemaker_client = Mock() session.sagemaker_client.search.side_effect = [ SEARCH_TRAINING_JOB_EXAMPLE, ] @@ -1658,10 +1672,12 @@ def test_training_details_autodiscovery_from_model_package_details( ) in caplog.text -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_evaluation_details_autodiscovery_from_model_package_details( session, model_package_example, caplog ): + session.sagemaker_client = Mock() + session.boto_session = Mock() with open(CLARIFY_BIAS_JSON_PATH, "r", encoding="utf-8") as istr: data = json.dumps(json.load(istr)) response = { @@ -1692,10 +1708,11 @@ def test_evaluation_details_autodiscovery_from_model_package_details( assert len(evaluation_details[0].metric_groups) == 3 -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_training_details_autodiscovery_from_model_overview_autopilot( session, model_overview_example, caplog ): + session.sagemaker_client = Mock() session.sagemaker_client.search.side_effect = [ SEARCH_TRAINING_JOB_AUTOPILOT_EXAMPLE, ] @@ -1711,8 +1728,9 @@ def test_training_details_autodiscovery_from_model_overview_autopilot( assert len(training_details.training_job_details.hyper_parameters) == 3 -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_training_details_autodiscovery_from_job_name(session): + session.sagemaker_client = Mock() session.sagemaker_client.describe_training_job.side_effect = [ DESCRIBE_TRAINING_JOB_EXAMPLE, MISSING_TRAINING_JOB_CLIENT_ERROR, @@ -1809,6 +1827,7 @@ def test_add_evaluation_metrics_from_json(): @patch("boto3.session.Session") def test_add_evauation_metrics_from_s3(session, caplog): + session.sagemaker_client = Mock() json_path = os.path.join(DATA_DIR, "evaluation_metrics/clarify_bias.json") with open(json_path, "r", encoding="utf-8") as istr: data = json.dumps(json.load(istr)) @@ -1935,8 +1954,9 @@ def test_metrics_model_monitor_model_quality_regression(): assert json.dumps(result, sort_keys=True) == json.dumps(expected_translation, sort_keys=True) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_create_export_model_card(session, caplog): + session.sagemaker_client = Mock() session.sagemaker_client.create_model_card_export_job.side_effect = [ CREATE_EXPORT_MODEL_CARD_EXAMPLE, CREATE_EXPORT_MODEL_CARD_EXAMPLE, @@ -1971,8 +1991,9 @@ def test_create_export_model_card(session, caplog): assert "Failed to export model card" in caplog.text -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_list_export_model_cards(session): + session.sagemaker_client = Mock() session.sagemaker_client.list_model_card_export_jobs.side_effect = [ LIST_MODEL_CARD_EXPORT_JOB_EXAMPLE ] @@ -1986,8 +2007,9 @@ def test_list_export_model_cards(session): ) -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_model_card_export_pdf(session, caplog): + session.sagemaker_client = Mock() session.sagemaker_client.create_model_card_export_job.side_effect = [ CREATE_EXPORT_MODEL_CARD_EXAMPLE, CREATE_EXPORT_MODEL_CARD_EXAMPLE, @@ -2009,8 +2031,9 @@ def test_model_card_export_pdf(session, caplog): assert "Failed to export model card" in caplog.text -@patch("sagemaker.Session") +@patch("sagemaker.Session", spec=Session) def test_list_model_card_version_history(session): + session.sagemaker_client = Mock() session.sagemaker_client.list_model_card_versions.side_effect = [ LIST_MODEL_CARD_VERSION_HISTORY_EXAMPLE ] diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index ea7f6ebcd7..f9b1f8c1b9 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -29,6 +29,7 @@ from sagemaker.mxnet import MXNet from sagemaker.mxnet import MXNetPredictor, MXNetModel from sagemaker.session_settings import SessionSettings +from sagemaker.session import Session DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_NAME = "dummy_script.py" @@ -81,6 +82,8 @@ def sagemaker_session(): boto_session=boto_mock, boto_region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_resource=None, s3_client=None, diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index 8db3688f84..bb02715273 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -19,6 +19,7 @@ from sagemaker.amazon.ntm import NTM, NTMPredictor from sagemaker.amazon.amazon_estimator import RecordSet from sagemaker.session_settings import SessionSettings +from sagemaker.session import Session ROLE = "myrole" INSTANCE_COUNT = 1 @@ -50,6 +51,8 @@ def sagemaker_session(): boto_session=boto_mock, region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, diff --git a/tests/unit/test_object2vec.py b/tests/unit/test_object2vec.py index d26faa4bb2..13d04c4b08 100644 --- a/tests/unit/test_object2vec.py +++ b/tests/unit/test_object2vec.py @@ -20,6 +20,7 @@ from sagemaker.predictor import Predictor from sagemaker.amazon.amazon_estimator import RecordSet from sagemaker.session_settings import SessionSettings +from sagemaker.session import Session ROLE = "myrole" INSTANCE_COUNT = 1 @@ -58,6 +59,8 @@ def sagemaker_session(): boto_session=boto_mock, region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_client=None, s3_resource=None, diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index 1f9460293d..4991636233 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -15,7 +15,7 @@ import pytest from mock import Mock, patch -from sagemaker import image_uris +from sagemaker import image_uris, Session from sagemaker.amazon.pca import PCA, PCAPredictor from sagemaker.amazon.amazon_estimator import RecordSet from sagemaker.session_settings import SessionSettings @@ -47,6 +47,8 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) sms = Mock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), boto_session=boto_mock, region_name=REGION, config=None, diff --git a/tests/unit/test_processing.py b/tests/unit/test_processing.py index 93e3d91f87..6957e7e16a 100644 --- a/tests/unit/test_processing.py +++ b/tests/unit/test_processing.py @@ -32,6 +32,7 @@ ScriptProcessor, ProcessingJob, ) +from sagemaker.session import Session from sagemaker.session_settings import SessionSettings from sagemaker.spark.processing import PySparkProcessor from sagemaker.sklearn.processing import SKLearnProcessor @@ -75,9 +76,12 @@ def mock_create_tar_file(): def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) session_mock = MagicMock( + spec=Session, name="sagemaker_session", + sagemaker_client=Mock(), boto_session=boto_mock, boto_region_name=REGION, + s3_resource=None, config=None, local_mode=False, settings=SessionSettings(), diff --git a/tests/unit/test_randomcutforest.py b/tests/unit/test_randomcutforest.py index d9884f9cde..0afd96d79e 100644 --- a/tests/unit/test_randomcutforest.py +++ b/tests/unit/test_randomcutforest.py @@ -19,6 +19,7 @@ from sagemaker.amazon.randomcutforest import RandomCutForest, RandomCutForestPredictor from sagemaker.amazon.amazon_estimator import RecordSet from sagemaker.session_settings import SessionSettings +from sagemaker import Session ROLE = "myrole" INSTANCE_COUNT = 1 @@ -52,6 +53,8 @@ def sagemaker_session(): boto_session=boto_mock, region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, settings=SessionSettings(), default_bucket_prefix=None, diff --git a/tests/unit/test_rl.py b/tests/unit/test_rl.py index 06e6387dd1..5cce04aede 100644 --- a/tests/unit/test_rl.py +++ b/tests/unit/test_rl.py @@ -23,6 +23,7 @@ from sagemaker.rl import RLEstimator, RLFramework, RLToolkit, TOOLKIT_FRAMEWORK_VERSION_MAP from sagemaker.session_settings import SessionSettings from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor +from sagemaker import Session DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") @@ -62,6 +63,8 @@ def fixture_sagemaker_session(): boto_session=boto_mock, boto_region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_resource=None, s3_client=None, diff --git a/tests/unit/test_tuner.py b/tests/unit/test_tuner.py index ff39535adf..433465cf21 100644 --- a/tests/unit/test_tuner.py +++ b/tests/unit/test_tuner.py @@ -54,6 +54,9 @@ def sagemaker_session(): boto_mock = Mock(name="boto_session", region_name=REGION) sms = Mock( name="sagemaker_session", + spec=Session, + sagemaker_client=Mock(), + local_mode=False, boto_session=boto_mock, s3_client=None, s3_resource=None, @@ -1893,7 +1896,7 @@ def _convert_tuning_job_details(job_details, estimator_name): @patch("sagemaker.estimator.tar_and_upload_dir") @patch("sagemaker.model.Model._upload_code") def test_tags_prefixes_jumpstart_models( - patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_upload_code, patched_tar_and_upload_dir, time_patched, sagemaker_session ): jumpstart_source_dir = f"s3://{list(JUMPSTART_BUCKET_NAME_SET)[0]}/source_dirs/source.tar.gz" @@ -1921,6 +1924,9 @@ def test_tags_prefixes_jumpstart_models( sagemaker_session.boto_region_name = REGION sagemaker_session.sagemaker_config = {} + sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job.return_value = { + "BestTrainingJob": {"TrainingJobName": "some-name"} + } sagemaker_session.sagemaker_client.describe_training_job.return_value = { "AlgorithmSpecification": { "TrainingInputMode": "File", @@ -2034,7 +2040,7 @@ def test_tags_prefixes_jumpstart_models( @patch("sagemaker.estimator.tar_and_upload_dir") @patch("sagemaker.model.Model._upload_code") def test_no_tags_prefixes_non_jumpstart_models( - patched_upload_code, patched_tar_and_upload_dir, sagemaker_session + patched_upload_code, patched_tar_and_upload_dir, time_patched, sagemaker_session ): non_jumpstart_source_dir = "s3://blah1/source_dirs/source.tar.gz" diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 30332741aa..d1d713775f 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1017,7 +1017,12 @@ def test_repack_model_with_same_inference_file_name(tmp, fake_s3): class FakeS3(object): def __init__(self, tmp): self.tmp = tmp - self.sagemaker_session = MagicMock(settings=SessionSettings()) + self.sagemaker_session = MagicMock( + spec=sagemaker.Session, + boto_session=MagicMock(), + local_mode=False, + settings=SessionSettings(), + ) self.location_map = {} self.current_bucket = None self.object_mock = MagicMock() diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index e54da2d862..9214c84fa3 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -25,6 +25,7 @@ from sagemaker.fw_utils import UploadedCode from sagemaker.session_settings import SessionSettings from sagemaker.xgboost import XGBoost, XGBoostModel, XGBoostPredictor +from sagemaker import Session DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") @@ -67,6 +68,8 @@ def sagemaker_session(): boto_session=boto_mock, boto_region_name=REGION, config=None, + spec=Session, + sagemaker_client=Mock(), local_mode=False, s3_resource=None, s3_client=None, From bcb78deb693fceb2c5aeb9512651bc9c44b32a7e Mon Sep 17 00:00:00 2001 From: martinRenou Date: Tue, 2 Jan 2024 10:56:02 +0100 Subject: [PATCH 6/6] Linting --- src/sagemaker/pytorch/processing.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/sagemaker/pytorch/processing.py b/src/sagemaker/pytorch/processing.py index 5f884f8e02..3182c1693e 100644 --- a/src/sagemaker/pytorch/processing.py +++ b/src/sagemaker/pytorch/processing.py @@ -26,7 +26,6 @@ from sagemaker.workflow.entities import PipelineVariable from sagemaker.utils import format_tags, Tags, validate_call_inputs from sagemaker.workflow.parameters import ParameterString -from sagemaker.utils import validate_call_inputs class PyTorchProcessor(FrameworkProcessor):