Skip to content

Commit

Permalink
Organize Sagemaker classes in Amazon provider (#20370)
Browse files Browse the repository at this point in the history
Organize Sagemaker classes in Amazon provider (#20370)
  • Loading branch information
bhavaniravi committed Dec 21, 2021
1 parent 302efad commit d557965
Show file tree
Hide file tree
Showing 31 changed files with 1,085 additions and 927 deletions.
643 changes: 643 additions & 0 deletions airflow/providers/amazon/aws/operators/sagemaker.py

Large diffs are not rendered by default.

95 changes: 8 additions & 87 deletions airflow/providers/amazon/aws/operators/sagemaker_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,93 +16,14 @@
# specific language governing permissions and limitations
# under the License.

import json
import sys
from typing import Iterable
"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""

if sys.version_info >= (3, 8):
from functools import cached_property
else:
from cached_property import cached_property
import warnings

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerBaseOperator # noqa

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook


class SageMakerBaseOperator(BaseOperator):
"""
This is the base operator for all SageMaker operators.
:param config: The configuration necessary to start a training job (templated)
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
"""

template_fields = ['config']
template_ext = ()
template_fields_renderers = {"config": "json"}
ui_color = '#ededed'

integer_fields = [] # type: Iterable[Iterable[str]]

def __init__(self, *, config: dict, aws_conn_id: str = 'aws_default', **kwargs):
super().__init__(**kwargs)

self.aws_conn_id = aws_conn_id
self.config = config

def parse_integer(self, config, field):
"""Recursive method for parsing string fields holding integer values to integers."""
if len(field) == 1:
if isinstance(config, list):
for sub_config in config:
self.parse_integer(sub_config, field)
return
head = field[0]
if head in config:
config[head] = int(config[head])
return

if isinstance(config, list):
for sub_config in config:
self.parse_integer(sub_config, field)
return

head, tail = field[0], field[1:]
if head in config:
self.parse_integer(config[head], tail)
return

def parse_config_integers(self):
"""
Parse the integer fields of training config to integers in case the config is rendered by Jinja and
all fields are str.
"""
for field in self.integer_fields:
self.parse_integer(self.config, field)

def expand_role(self):
"""Placeholder for calling boto3's `expand_role`, which expands an IAM role name into an ARN."""

def preprocess_config(self):
"""Process the config into a usable form."""
self.log.info('Preprocessing the config and doing required s3_operations')

self.hook.configure_s3_resources(self.config)
self.parse_config_integers()
self.expand_role()

self.log.info(
"After preprocessing the config is:\n %s",
json.dumps(self.config, sort_keys=True, indent=4, separators=(",", ": ")),
)

def execute(self, context):
raise NotImplementedError('Please implement execute() in sub class!')

@cached_property
def hook(self):
"""Return SageMakerHook"""
return SageMakerHook(aws_conn_id=self.aws_conn_id)
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
DeprecationWarning,
stacklevel=2,
)
143 changes: 8 additions & 135 deletions airflow/providers/amazon/aws/operators/sagemaker_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,142 +15,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Optional

from botocore.exceptions import ClientError
"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
import warnings

from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointOperator # noqa

class SageMakerEndpointOperator(SageMakerBaseOperator):
"""
Create a SageMaker endpoint.
This operator returns The ARN of the endpoint created in Amazon SageMaker
:param config:
The configuration necessary to create an endpoint.
If you need to create a SageMaker endpoint based on an existed
SageMaker model and an existed SageMaker endpoint config::
config = endpoint_configuration;
If you need to create all of SageMaker model, SageMaker endpoint-config and SageMaker endpoint::
config = {
'Model': model_configuration,
'EndpointConfig': endpoint_config_configuration,
'Endpoint': endpoint_configuration
}
For details of the configuration parameter of model_configuration see
:py:meth:`SageMaker.Client.create_model`
For details of the configuration parameter of endpoint_config_configuration see
:py:meth:`SageMaker.Client.create_endpoint_config`
For details of the configuration parameter of endpoint_configuration see
:py:meth:`SageMaker.Client.create_endpoint`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
:param wait_for_completion: Whether the operator should wait until the endpoint creation finishes.
:type wait_for_completion: bool
:param check_interval: If wait is set to True, this is the time interval, in seconds, that this operation
waits before polling the status of the endpoint creation.
:type check_interval: int
:param max_ingestion_time: If wait is set to True, this operation fails if the endpoint creation doesn't
finish within max_ingestion_time seconds. If you set this parameter to None it never times out.
:type max_ingestion_time: int
:param operation: Whether to create an endpoint or update an endpoint. Must be either 'create or 'update'.
:type operation: str
"""

def __init__(
self,
*,
config: dict,
wait_for_completion: bool = True,
check_interval: int = 30,
max_ingestion_time: Optional[int] = None,
operation: str = 'create',
**kwargs,
):
super().__init__(config=config, **kwargs)

self.config = config
self.wait_for_completion = wait_for_completion
self.check_interval = check_interval
self.max_ingestion_time = max_ingestion_time
self.operation = operation.lower()
if self.operation not in ['create', 'update']:
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')
self.create_integer_fields()

def create_integer_fields(self) -> None:
"""Set fields which should be casted to integers."""
if 'EndpointConfig' in self.config:
self.integer_fields = [['EndpointConfig', 'ProductionVariants', 'InitialInstanceCount']]

def expand_role(self) -> None:
if 'Model' not in self.config:
return
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
config = self.config['Model']
if 'ExecutionRoleArn' in config:
config['ExecutionRoleArn'] = hook.expand_role(config['ExecutionRoleArn'])

def execute(self, context) -> dict:
self.preprocess_config()

model_info = self.config.get('Model')
endpoint_config_info = self.config.get('EndpointConfig')
endpoint_info = self.config.get('Endpoint', self.config)

if model_info:
self.log.info('Creating SageMaker model %s.', model_info['ModelName'])
self.hook.create_model(model_info)

if endpoint_config_info:
self.log.info('Creating endpoint config %s.', endpoint_config_info['EndpointConfigName'])
self.hook.create_endpoint_config(endpoint_config_info)

if self.operation == 'create':
sagemaker_operation = self.hook.create_endpoint
log_str = 'Creating'
elif self.operation == 'update':
sagemaker_operation = self.hook.update_endpoint
log_str = 'Updating'
else:
raise ValueError('Invalid value! Argument operation has to be one of "create" and "update"')

self.log.info('%s SageMaker endpoint %s.', log_str, endpoint_info['EndpointName'])
try:
response = sagemaker_operation(
endpoint_info,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)
except ClientError: # Botocore throws a ClientError if the endpoint is already created
self.operation = 'update'
sagemaker_operation = self.hook.update_endpoint
log_str = 'Updating'
response = sagemaker_operation(
endpoint_info,
wait_for_completion=self.wait_for_completion,
check_interval=self.check_interval,
max_ingestion_time=self.max_ingestion_time,
)

if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(f'Sagemaker endpoint creation failed: {response}')
else:
return {
'EndpointConfig': self.hook.describe_endpoint_config(endpoint_info['EndpointConfigName']),
'Endpoint': self.hook.describe_endpoint(endpoint_info['EndpointName']),
}
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
DeprecationWarning,
stacklevel=2,
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,14 @@
# specific language governing permissions and limitations
# under the License.

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""

import warnings

class SageMakerEndpointConfigOperator(SageMakerBaseOperator):
"""
Create a SageMaker endpoint config.
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerEndpointConfigOperator # noqa

This operator returns The ARN of the endpoint config created in Amazon SageMaker
:param config: The configuration necessary to create an endpoint config.
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_endpoint_config`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
"""

integer_fields = [['ProductionVariants', 'InitialInstanceCount']]

def __init__(self, *, config: dict, **kwargs):
super().__init__(config=config, **kwargs)

self.config = config

def execute(self, context) -> dict:
self.preprocess_config()

self.log.info('Creating SageMaker Endpoint Config %s.', self.config['EndpointConfigName'])
response = self.hook.create_endpoint_config(self.config)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(f'Sagemaker endpoint config creation failed: {response}')
else:
return {'EndpointConfig': self.hook.describe_endpoint_config(self.config['EndpointConfigName'])}
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
DeprecationWarning,
stacklevel=2,
)
43 changes: 8 additions & 35 deletions airflow/providers/amazon/aws/operators/sagemaker_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,41 +16,14 @@
# specific language governing permissions and limitations
# under the License.

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
from airflow.providers.amazon.aws.operators.sagemaker_base import SageMakerBaseOperator
"""This module is deprecated. Please use :mod:`airflow.providers.amazon.aws.operators.sagemaker`."""

import warnings

class SageMakerModelOperator(SageMakerBaseOperator):
"""
Create a SageMaker model.
from airflow.providers.amazon.aws.operators.sagemaker import SageMakerModelOperator # noqa

This operator returns The ARN of the model created in Amazon SageMaker
:param config: The configuration necessary to create a model.
For details of the configuration parameter see :py:meth:`SageMaker.Client.create_model`
:type config: dict
:param aws_conn_id: The AWS connection ID to use.
:type aws_conn_id: str
"""

def __init__(self, *, config, **kwargs):
super().__init__(config=config, **kwargs)

self.config = config

def expand_role(self) -> None:
if 'ExecutionRoleArn' in self.config:
hook = AwsBaseHook(self.aws_conn_id, client_type='iam')
self.config['ExecutionRoleArn'] = hook.expand_role(self.config['ExecutionRoleArn'])

def execute(self, context) -> dict:
self.preprocess_config()

self.log.info('Creating SageMaker Model %s.', self.config['ModelName'])
response = self.hook.create_model(self.config)
if response['ResponseMetadata']['HTTPStatusCode'] != 200:
raise AirflowException(f'Sagemaker model creation failed: {response}')
else:
return {'Model': self.hook.describe_model(self.config['ModelName'])}
warnings.warn(
"This module is deprecated. Please use `airflow.providers.amazon.aws.operators.sagemaker`.",
DeprecationWarning,
stacklevel=2,
)
Loading

0 comments on commit d557965

Please sign in to comment.