Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change: Type checking with Pydantic #4265

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ def read_requirements(filename):
"packaging>=20.0",
"pandas",
"pathos",
"pydantic",
"schema",
"PyYAML~=6.0",
"jsonschema",
"platformdirs",
"tblib>=1.7.0,<3",
"urllib3<1.27",
"uvicorn==0.22.0",
"fastapi==0.95.2",
"requests",
"docker",
"tqdm",
Expand Down
3 changes: 2 additions & 1 deletion src/sagemaker/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"""Test docstring"""
from __future__ import absolute_import

from numbers import Number
from typing import Optional, Union, Dict, List

import sagemaker
Expand Down Expand Up @@ -58,7 +59,7 @@ def __init__(
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
base_job_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
tags: Optional[Tags] = None,
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
Expand Down
3 changes: 2 additions & 1 deletion src/sagemaker/chainer/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import

import logging
from numbers import Number
from typing import Union, Optional, Dict

from sagemaker.estimator import Framework, EstimatorBase
Expand Down Expand Up @@ -50,7 +51,7 @@ def __init__(
process_slots_per_host: Optional[Union[int, PipelineVariable]] = None,
additional_mpi_options: Optional[Union[str, PipelineVariable]] = None,
source_dir: Optional[Union[str, PipelineVariable]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
framework_version: Optional[str] = None,
py_version: Optional[str] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/chainer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import

import logging
from typing import Optional, Union, List, Dict
from typing import Callable, Optional, Union, List, Dict

import sagemaker
from sagemaker import image_uris, ModelMetrics
Expand Down Expand Up @@ -91,7 +91,7 @@ def __init__(
image_uri: Optional[Union[str, PipelineVariable]] = None,
framework_version: Optional[str] = None,
py_version: Optional[str] = None,
predictor_cls: callable = ChainerPredictor,
predictor_cls: Callable = ChainerPredictor,
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
**kwargs
):
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/djl_inference/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from json import JSONDecodeError
from urllib.error import HTTPError, URLError
from enum import Enum
from typing import Optional, Union, Dict, Any, List
from typing import Callable, Optional, Union, Dict, Any, List

import sagemaker
from sagemaker import s3, Predictor, image_uris, fw_utils
Expand Down Expand Up @@ -313,7 +313,7 @@ def __init__(
prediction_timeout: Optional[int] = None,
entry_point: Optional[str] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
predictor_cls: callable = DJLPredictor,
predictor_cls: Callable = DJLPredictor,
**kwargs,
):
"""Initialize a DJLModel.
Expand Down
23 changes: 13 additions & 10 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import re
import uuid
from abc import ABCMeta, abstractmethod
from numbers import Number
from typing import Any, Dict, Union, Optional, List
from packaging.specifiers import SpecifierSet
from packaging.version import Version
Expand Down Expand Up @@ -100,6 +101,7 @@
resolve_value_from_config,
format_tags,
Tags,
validate_call_inputs,
)
from sagemaker.workflow import is_pipeline_variable
from sagemaker.workflow.entities import PipelineVariable
Expand Down Expand Up @@ -132,9 +134,10 @@ class EstimatorBase(with_metaclass(ABCMeta, object)): # pylint: disable=too-man
CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = "/opt/ml/input/data/code/sourcedir.tar.gz"
JOB_CLASS_NAME = "training-job"

@validate_call_inputs
def __init__(
self,
role: str = None,
role: Optional[Union[str, ParameterString]] = None,
instance_count: Optional[Union[int, PipelineVariable]] = None,
instance_type: Optional[Union[str, PipelineVariable]] = None,
keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None,
Expand All @@ -152,23 +155,23 @@ def __init__(
model_uri: Optional[str] = None,
model_channel_name: Union[str, PipelineVariable] = "model",
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = None,
encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None,
use_spot_instances: Union[bool, PipelineVariable] = False,
max_wait: Optional[Union[int, PipelineVariable]] = None,
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
rules: Optional[List[RuleBase]] = None,
debugger_hook_config: Optional[Union[bool, DebuggerHookConfig]] = None,
debugger_hook_config: Optional[Union[bool, DebuggerHookConfig, Dict]] = None,
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
enable_network_isolation: Union[bool, PipelineVariable] = None,
enable_network_isolation: Optional[Union[bool, PipelineVariable]] = None,
profiler_config: Optional[ProfilerConfig] = None,
disable_profiler: bool = None,
environment: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
max_retry_attempts: Optional[Union[int, PipelineVariable]] = None,
source_dir: Optional[Union[str, PipelineVariable]] = None,
git_config: Optional[Dict[str, str]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
container_log_level: Union[int, PipelineVariable] = logging.INFO,
code_location: Optional[str] = None,
entry_point: Optional[Union[str, PipelineVariable]] = None,
Expand Down Expand Up @@ -2709,7 +2712,7 @@ class Estimator(EstimatorBase):
def __init__(
self,
image_uri: Union[str, PipelineVariable],
role: str = None,
role: Optional[str] = None,
instance_count: Optional[Union[int, PipelineVariable]] = None,
instance_type: Optional[Union[str, PipelineVariable]] = None,
keep_alive_period_in_seconds: Optional[Union[int, PipelineVariable]] = None,
Expand All @@ -2721,21 +2724,21 @@ def __init__(
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
base_job_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
tags: Optional[Tags] = None,
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
model_uri: Optional[str] = None,
model_channel_name: Union[str, PipelineVariable] = "model",
metric_definitions: Optional[List[Dict[str, Union[str, PipelineVariable]]]] = None,
encrypt_inter_container_traffic: Union[bool, PipelineVariable] = None,
encrypt_inter_container_traffic: Optional[Union[bool, PipelineVariable]] = None,
use_spot_instances: Union[bool, PipelineVariable] = False,
max_wait: Optional[Union[int, PipelineVariable]] = None,
checkpoint_s3_uri: Optional[Union[str, PipelineVariable]] = None,
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
enable_network_isolation: Union[bool, PipelineVariable] = None,
rules: Optional[List[RuleBase]] = None,
debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None,
debugger_hook_config: Optional[Union[DebuggerHookConfig, bool, Dict]] = None,
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
profiler_config: Optional[ProfilerConfig] = None,
Expand Down Expand Up @@ -3288,7 +3291,7 @@ def __init__(
self,
entry_point: Union[str, PipelineVariable],
source_dir: Optional[Union[str, PipelineVariable]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
container_log_level: Union[int, PipelineVariable] = logging.INFO,
code_location: Optional[str] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
Expand Down
3 changes: 2 additions & 1 deletion src/sagemaker/huggingface/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import logging
import re
from numbers import Number
from typing import Optional, Union, Dict

from sagemaker.estimator import Framework, EstimatorBase
Expand Down Expand Up @@ -47,7 +48,7 @@ def __init__(
tensorflow_version: Optional[str] = None,
pytorch_version: Optional[str] = None,
source_dir: Optional[Union[str, PipelineVariable]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
distribution: Optional[Dict] = None,
compiler_config: Optional[TrainingCompilerConfig] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/huggingface/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import

import logging
from typing import Optional, Union, List, Dict
from typing import Callable, Optional, Union, List, Dict

import sagemaker
from sagemaker import image_uris, ModelMetrics
Expand Down Expand Up @@ -118,7 +118,7 @@ def __init__(
pytorch_version: Optional[str] = None,
py_version: Optional[str] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
predictor_cls: callable = HuggingFacePredictor,
predictor_cls: Callable = HuggingFacePredictor,
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
**kwargs,
):
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""This module stores JumpStart implementation of Estimator class."""
from __future__ import absolute_import


from numbers import Number
from typing import Dict, List, Optional, Union
from sagemaker import session
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
base_job_name: Optional[str] = None,
sagemaker_session: Optional[session.Session] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
tags: Optional[Tags] = None,
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
Expand All @@ -86,7 +86,7 @@ def __init__(
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
enable_network_isolation: Union[bool, PipelineVariable] = None,
rules: Optional[List[RuleBase]] = None,
debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None,
debugger_hook_config: Optional[Union[DebuggerHookConfig, bool, Dict]] = None,
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
profiler_config: Optional[ProfilerConfig] = None,
Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from __future__ import absolute_import


from numbers import Number
from typing import Dict, List, Optional, Union
from sagemaker import (
environment_variables,
Expand Down Expand Up @@ -93,7 +94,7 @@ def get_init_kwargs(
output_kms_key: Optional[Union[str, PipelineVariable]] = None,
base_job_name: Optional[str] = None,
sagemaker_session: Optional[Session] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
tags: Optional[Tags] = None,
subnets: Optional[List[Union[str, PipelineVariable]]] = None,
security_group_ids: Optional[List[Union[str, PipelineVariable]]] = None,
Expand All @@ -107,7 +108,7 @@ def get_init_kwargs(
checkpoint_local_path: Optional[Union[str, PipelineVariable]] = None,
enable_network_isolation: Union[bool, PipelineVariable] = None,
rules: Optional[List[RuleBase]] = None,
debugger_hook_config: Optional[Union[DebuggerHookConfig, bool]] = None,
debugger_hook_config: Optional[Union[DebuggerHookConfig, bool, Dict]] = None,
tensorboard_output_config: Optional[TensorBoardOutputConfig] = None,
enable_sagemaker_metrics: Optional[Union[bool, PipelineVariable]] = None,
profiler_config: Optional[ProfilerConfig] = None,
Expand Down
4 changes: 2 additions & 2 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import

import logging
from typing import Union, Optional, List, Dict
from typing import Callable, Union, Optional, List, Dict

import packaging.version

Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(
framework_version: str = _LOWEST_MMS_VERSION,
py_version: Optional[str] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
predictor_cls: callable = MXNetPredictor,
predictor_cls: Callable = MXNetPredictor,
model_server_workers: Optional[Union[int, PipelineVariable]] = None,
**kwargs
):
Expand Down
21 changes: 13 additions & 8 deletions src/sagemaker/pytorch/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from __future__ import absolute_import

import logging
from typing import Union, Optional, Dict
from numbers import Number
from typing import Union, Optional, Dict, List

from packaging.version import Version

Expand All @@ -32,6 +33,8 @@
from sagemaker.pytorch.training_compiler.config import TrainingCompilerConfig
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT
from sagemaker.workflow.entities import PipelineVariable
from sagemaker.workflow.parameters import ParameterString
from sagemaker.utils import validate_call_inputs

logger = logging.getLogger("sagemaker")

Expand All @@ -44,13 +47,14 @@ class PyTorch(Framework):
LAUNCH_TORCH_DISTRIBUTED_ENV_NAME = "sagemaker_torch_distributed_enabled"
INSTANCE_TYPE_ENV_NAME = "sagemaker_instance_type"

@validate_call_inputs
def __init__(
self,
entry_point: Union[str, PipelineVariable],
framework_version: Optional[str] = None,
py_version: Optional[str] = None,
source_dir: Optional[Union[str, PipelineVariable]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable]]] = None,
hyperparameters: Optional[Dict[str, Union[str, PipelineVariable, Number]]] = None,
image_uri: Optional[Union[str, PipelineVariable]] = None,
distribution: Optional[Dict] = None,
compiler_config: Optional[TrainingCompilerConfig] = None,
Expand Down Expand Up @@ -362,14 +366,15 @@ def hyperparameters(self):

return hyperparameters

@validate_call_inputs
def create_model(
self,
model_server_workers=None,
role=None,
vpc_config_override=VPC_CONFIG_DEFAULT,
entry_point=None,
source_dir=None,
dependencies=None,
model_server_workers: Optional[int] = None,
role: Optional[Union[str, ParameterString]] = None,
vpc_config_override: Optional[Dict[str, List[str]]] = VPC_CONFIG_DEFAULT,
entry_point: Optional[str] = None,
source_dir: Optional[str] = None,
dependencies: Optional[List[str]] = None,
**kwargs,
):
"""Create a SageMaker ``PyTorchModel`` object that can be deployed to an ``Endpoint``.
Expand Down