Skip to content

Commit

Permalink
Implements AwsBaseOperator and AwsBaseSensor (#34784)
Browse files Browse the repository at this point in the history
* Implements `AwsBaseOperator` and `AwsBaseSensor`

* Apply suggestions from code review

Co-authored-by: D. Ferruzzi <ferruzzi@amazon.com>

* Move suggestion into the AwsBaseSensor

* Static checks + tests

* Move to base classes

* Change to base python exceptions

* Add mixin helpers

* Update airflow/providers/amazon/aws/utils/mixins.py

Co-authored-by: Niko Oliveira <onikolas@amazon.com>

---------

Co-authored-by: D. Ferruzzi <ferruzzi@amazon.com>
Co-authored-by: Niko Oliveira <onikolas@amazon.com>
  • Loading branch information
3 people committed Oct 11, 2023
1 parent 27671fa commit 84a3dae
Show file tree
Hide file tree
Showing 9 changed files with 719 additions and 2 deletions.
4 changes: 3 additions & 1 deletion airflow/providers/amazon/aws/hooks/base_aws.py
Expand Up @@ -463,14 +463,16 @@ def __init__(
region_name: str | None = None,
client_type: str | None = None,
resource_type: str | None = None,
config: Config | None = None,
config: Config | dict[str, Any] | None = None,
) -> None:
super().__init__()
self.aws_conn_id = aws_conn_id
self.client_type = client_type
self.resource_type = resource_type

self._region_name = region_name
if isinstance(config, dict):
config = Config(**config)
self._config = config
self._verify = verify

Expand Down
97 changes: 97 additions & 0 deletions airflow/providers/amazon/aws/operators/base_aws.py
@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Sequence

from airflow.models import BaseOperator
from airflow.providers.amazon.aws.utils.mixins import AwsBaseHookMixin, AwsHookParams, AwsHookType


class AwsBaseOperator(BaseOperator, AwsBaseHookMixin[AwsHookType]):
"""
Base AWS (Amazon) Operator Class to build operators on top of AWS Hooks.
.. warning::
Only for internal usage, this class might be changed, renamed or removed in the future
without any further notice.
Examples:
.. code-block:: python
from airflow.providers.amazon.aws.hooks.foo_bar import FooBarThinHook, FooBarThickHook
class AwsFooBarOperator(AwsBaseOperator[FooBarThinHook]):
aws_hook_class = FooBarThinHook
def execute(self, context):
pass
class AwsFooBarOperator2(AwsBaseOperator[FooBarThickHook]):
aws_hook_class = FooBarThickHook
def __init__(self, *, spam: str, **kwargs):
super().__init__(**kwargs)
self.spam = spam
@property
def _hook_parameters(self):
return {**super()._hook_parameters, "spam": self.spam}
def execute(self, context):
pass
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
:meta private:
"""

template_fields: Sequence[str] = (
"aws_conn_id",
"region_name",
"botocore_config",
)

def __init__(
self,
*,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
**kwargs,
):
hook_params = AwsHookParams.from_constructor(
aws_conn_id, region_name, verify, botocore_config, additional_params=kwargs
)
super().__init__(**kwargs)
self.aws_conn_id = hook_params.aws_conn_id
self.region_name = hook_params.region_name
self.verify = hook_params.verify
self.botocore_config = hook_params.botocore_config
self.validate_attributes()
96 changes: 96 additions & 0 deletions airflow/providers/amazon/aws/sensors/base_aws.py
@@ -0,0 +1,96 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from typing import Sequence

from airflow.providers.amazon.aws.utils.mixins import AwsBaseHookMixin, AwsHookParams, AwsHookType
from airflow.sensors.base import BaseSensorOperator


class AwsBaseSensor(BaseSensorOperator, AwsBaseHookMixin[AwsHookType]):
"""Base AWS (Amazon) Sensor Class for build sensors in top of AWS Hooks.
.. warning::
Only for internal usage, this class might be changed, renamed or removed in the future
without any further notice.
Examples:
.. code-block:: python
from airflow.providers.amazon.aws.hooks.foo_bar import FooBarThinHook, FooBarThickHook
class AwsFooBarSensor(AwsBaseSensor[FooBarThinHook]):
aws_hook_class = FooBarThinHook
def poke(self, context):
pass
class AwsFooBarSensor(AwsBaseSensor[FooBarThickHook]):
aws_hook_class = FooBarThickHook
def __init__(self, *, spam: str, **kwargs):
super().__init__(**kwargs)
self.spam = spam
@property
def _hook_parameters(self):
return {**super()._hook_parameters, "spam": self.spam}
def poke(self, context):
pass
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is None or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
:meta private:
"""

template_fields: Sequence[str] = (
"aws_conn_id",
"region_name",
"botocore_config",
)

def __init__(
self,
*,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
**kwargs,
):
hook_params = AwsHookParams.from_constructor(
aws_conn_id, region_name, verify, botocore_config, additional_params=kwargs
)
super().__init__(**kwargs)
self.aws_conn_id = hook_params.aws_conn_id
self.region_name = hook_params.region_name
self.verify = hook_params.verify
self.botocore_config = hook_params.botocore_config
self.validate_attributes()
165 changes: 165 additions & 0 deletions airflow/providers/amazon/aws/utils/mixins.py
@@ -0,0 +1,165 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
This module contains different mixin classes for internal use within the Amazon provider.
.. warning::
Only for internal usage, this module and all classes might be changed, renamed or removed in the future
without any further notice.
:meta: private
"""

from __future__ import annotations

import warnings
from functools import cached_property
from typing import Any, Generic, NamedTuple, TypeVar

from typing_extensions import final

from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook

AwsHookType = TypeVar("AwsHookType", bound=AwsGenericHook)
REGION_MSG = "`region` is deprecated and will be removed in the future. Please use `region_name` instead."


class AwsHookParams(NamedTuple):
"""
Default Aws Hook Parameters storage class.
:meta private:
"""

aws_conn_id: str | None
region_name: str | None
verify: bool | str | None
botocore_config: dict[str, Any] | None

@classmethod
def from_constructor(
cls,
aws_conn_id: str | None,
region_name: str | None,
verify: bool | str | None,
botocore_config: dict[str, Any] | None,
additional_params: dict,
):
"""
Resolve generic AWS Hooks parameters in class constructor.
Examples:
.. code-block:: python
class AwsFooBarOperator(BaseOperator):
def __init__(
self,
*,
aws_conn_id: str | None = "aws_default",
region_name: str | None = None,
verify: bool | str | None = None,
botocore_config: dict | None = None,
foo: str = "bar",
**kwargs,
):
params = AwsHookParams.from_constructor(
aws_conn_id, region_name, verify, botocore_config, additional_params=kwargs
)
super().__init__(**kwargs)
self.aws_conn_id = params.aws_conn_id
self.region_name = params.region_name
self.verify = params.verify
self.botocore_config = params.botocore_config
self.foo = foo
"""
if region := additional_params.pop("region", None):
warnings.warn(REGION_MSG, AirflowProviderDeprecationWarning, stacklevel=3)
if region_name and region_name != region:
raise ValueError(
f"Conflicting `region_name` provided, region_name={region_name!r}, region={region!r}."
)
region_name = region
return cls(aws_conn_id, region_name, verify, botocore_config)


class AwsBaseHookMixin(Generic[AwsHookType]):
"""Mixin class for AWS Operators, Sensors, etc.
.. warning::
Only for internal usage, this class might be changed, renamed or removed in the future
without any further notice.
:meta private:
"""

# Should be assigned in child class
aws_hook_class: type[AwsHookType]

aws_conn_id: str | None
region_name: str | None
verify: bool | str | None
botocore_config: dict[str, Any] | None

def validate_attributes(self):
"""Validate class attributes."""
if hasattr(self, "aws_hook_class"): # Validate if ``aws_hook_class`` is properly set.
try:
if not issubclass(self.aws_hook_class, AwsGenericHook):
raise TypeError
except TypeError:
# Raise if ``aws_hook_class`` is not a class or not a subclass of Generic/Base AWS Hook
raise AttributeError(
f"Class attribute '{type(self).__name__}.aws_hook_class' "
f"is not a subclass of AwsGenericHook."
) from None
else:
raise AttributeError(f"Class attribute '{type(self).__name__}.aws_hook_class' should be set.")

@property
def _hook_parameters(self) -> dict[str, Any]:
"""
Mapping parameters to build boto3-related hooks.
Only required to be overwritten for thick-wrapped Hooks.
"""
return {
"aws_conn_id": self.aws_conn_id,
"region_name": self.region_name,
"verify": self.verify,
"config": self.botocore_config,
}

@cached_property
@final
def hook(self) -> AwsHookType:
"""
Return AWS Provider's hook based on ``aws_hook_class``.
This method implementation should be taken as a final for
thin-wrapped Hooks around boto3. For thick-wrapped Hooks developer
should consider to overwrite ``_hook_parameters`` method instead.
"""
return self.aws_hook_class(**self._hook_parameters)

@property
@final
def region(self) -> str | None:
"""Alias for ``region_name``, used for compatibility (deprecated)."""
warnings.warn(REGION_MSG, AirflowProviderDeprecationWarning, stacklevel=3)
return self.region_name
6 changes: 6 additions & 0 deletions airflow/providers/amazon/provider.yaml
Expand Up @@ -293,6 +293,9 @@ operators:
- integration-name: Amazon Athena
python-modules:
- airflow.providers.amazon.aws.operators.athena
- integration-name: Amazon Web Services
python-modules:
- airflow.providers.amazon.aws.operators.base_aws
- integration-name: AWS Batch
python-modules:
- airflow.providers.amazon.aws.operators.batch
Expand Down Expand Up @@ -366,6 +369,9 @@ sensors:
- integration-name: Amazon Athena
python-modules:
- airflow.providers.amazon.aws.sensors.athena
- integration-name: Amazon Web Services
python-modules:
- airflow.providers.amazon.aws.sensors.base_aws
- integration-name: AWS Batch
python-modules:
- airflow.providers.amazon.aws.sensors.batch
Expand Down

0 comments on commit 84a3dae

Please sign in to comment.