Use base aws classes in Amazon Athena Operators/Sensors/Triggers (#35133)

* Use base aws classes in Amazon Athena Operators/Sensors/Triggers

* Fix positional arguments in AthenaTrigger
Taragolis committed Oct 24, 2023
1 parent f457228 commit da45606
Showing 6 changed files with 135 additions and 65 deletions.
47 changes: 31 additions & 16 deletions airflow/providers/amazon/aws/operators/athena.py
@@ -17,22 +17,22 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context


class AthenaOperator(BaseOperator):
class AthenaOperator(AwsBaseOperator[AthenaHook]):
"""
An operator that submits a presto query to athena.
An operator that submits a Trino/Presto query to Amazon Athena.
.. note:: if the task is killed while it runs, it'll cancel the athena query that was launched,
EXCEPT if running in deferrable mode.
@@ -41,11 +41,10 @@ class AthenaOperator(BaseOperator):
For more information on how to use this operator, take a look at the guide:
:ref:`howto/operator:AthenaOperator`
:param query: Presto to be run on athena. (templated)
:param query: Trino/Presto query to be run on Amazon Athena. (templated)
:param database: Database to select. (templated)
:param catalog: Catalog to select. (templated)
:param output_location: s3 path to write the query results into. (templated)
:param aws_conn_id: aws connection to use
:param client_request_token: Unique token created by user to avoid multiple executions of same query
:param workgroup: Athena workgroup in which query will be run. (templated)
:param query_execution_context: Context in which query need to be run
@@ -55,10 +54,23 @@ class AthenaOperator(BaseOperator):
To limit task execution time, use execution_timeout.
:param log_query: Whether to log athena query and other execution params when it's executed.
Defaults to *True*.
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

aws_hook_class = AthenaHook
ui_color = "#44b5e2"
template_fields: Sequence[str] = ("query", "database", "output_location", "workgroup", "catalog")
template_fields: Sequence[str] = aws_template_fields(
"query", "database", "output_location", "workgroup", "catalog"
)
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"query": "sql"}

@@ -68,7 +80,6 @@ def __init__(
query: str,
database: str,
output_location: str,
aws_conn_id: str = "aws_default",
client_request_token: str | None = None,
workgroup: str = "primary",
query_execution_context: dict[str, str] | None = None,
@@ -84,7 +95,6 @@ def __init__(
self.query = query
self.database = database
self.output_location = output_location
self.aws_conn_id = aws_conn_id
self.client_request_token = client_request_token
self.workgroup = workgroup
self.query_execution_context = query_execution_context or {}
@@ -96,13 +106,12 @@ def __init__(
self.deferrable = deferrable
self.catalog: str = catalog

@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
return AthenaHook(self.aws_conn_id, log_query=self.log_query)
@property
def _hook_parameters(self) -> dict[str, Any]:
return {**super()._hook_parameters, "log_query": self.log_query}

def execute(self, context: Context) -> str | None:
"""Run Presto Query on Athena."""
"""Run Trino/Presto Query on Amazon Athena."""
self.query_execution_context["Database"] = self.database
self.query_execution_context["Catalog"] = self.catalog
self.result_configuration["OutputLocation"] = self.output_location
@@ -117,7 +126,13 @@ def execute(self, context: Context) -> str | None:
if self.deferrable:
self.defer(
trigger=AthenaTrigger(
self.query_execution_id, self.sleep_time, self.max_polling_attempts, self.aws_conn_id
query_execution_id=self.query_execution_id,
waiter_delay=self.sleep_time,
waiter_max_attempts=self.max_polling_attempts,
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
botocore_config=self.botocore_config,
),
method_name="execute_complete",
)
@@ -148,7 +163,7 @@ def execute_complete(self, context, event=None):
return event["value"]

def on_kill(self) -> None:
"""Cancel the submitted athena query."""
"""Cancel the submitted Amazon Athena query."""
if self.query_execution_id:
self.log.info("Received a kill signal.")
response = self.hook.stop_query(self.query_execution_id)
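Because AthenaOperator now extends AwsBaseOperator[AthenaHook], the generic AWS parameters documented above (aws_conn_id, region_name, verify, botocore_config) can be passed directly to the operator. The following is a minimal usage sketch, not part of this commit; the database, bucket, and region names are placeholders:

from datetime import datetime

from airflow import DAG
from airflow.providers.amazon.aws.operators.athena import AthenaOperator

with DAG(dag_id="example_athena", start_date=datetime(2023, 10, 1), schedule=None):
    run_query = AthenaOperator(
        task_id="run_athena_query",
        query="SELECT * FROM test_table LIMIT 10",
        database="test_database",
        output_location="s3://example-athena-results/",
        # Generic parameters handled by AwsBaseOperator (placeholder values):
        aws_conn_id="aws_default",
        region_name="eu-west-1",
        verify=True,
        botocore_config={"read_timeout": 42},
    )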
31 changes: 18 additions & 13 deletions airflow/providers/amazon/aws/sensors/athena.py
@@ -17,18 +17,19 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any, Sequence

from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields

if TYPE_CHECKING:
from airflow.utils.context import Context

from airflow.exceptions import AirflowException, AirflowSkipException
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.sensors.base import BaseSensorOperator


class AthenaSensor(BaseSensorOperator):
class AthenaSensor(AwsBaseSensor[AthenaHook]):
"""
Poll the state of the Query until it reaches a terminal state; fails if the query fails.
@@ -40,9 +41,18 @@ class AthenaSensor(BaseSensorOperator):
:param query_execution_id: query_execution_id to check the state of
:param max_retries: Number of times to poll for query state before
returning the current state, defaults to None
:param aws_conn_id: aws connection to use, defaults to 'aws_default'
:param sleep_time: Time in seconds to wait between two consecutive call to
check query status on athena, defaults to 10
:param aws_conn_id: The Airflow connection used for AWS credentials.
If this is ``None`` or empty then the default boto3 behaviour is used. If
running Airflow in a distributed manner and aws_conn_id is None or
empty, then default boto3 configuration would be used (and must be
maintained on each worker node).
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
:param verify: Whether or not to verify SSL certificates. See:
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
"""

INTERMEDIATE_STATES = (
@@ -55,21 +65,21 @@ class AthenaSensor(BaseSensorOperator):
)
SUCCESS_STATES = ("SUCCEEDED",)

template_fields: Sequence[str] = ("query_execution_id",)
template_ext: Sequence[str] = ()
aws_hook_class = AthenaHook
template_fields: Sequence[str] = aws_template_fields(
"query_execution_id",
)
ui_color = "#66c3ff"

def __init__(
self,
*,
query_execution_id: str,
max_retries: int | None = None,
aws_conn_id: str = "aws_default",
sleep_time: int = 10,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.aws_conn_id = aws_conn_id
self.query_execution_id = query_execution_id
self.sleep_time = sleep_time
self.max_retries = max_retries
@@ -87,8 +97,3 @@ def poke(self, context: Context) -> bool:
if state in self.INTERMEDIATE_STATES:
return False
return True

@cached_property
def hook(self) -> AthenaHook:
"""Create and return an AthenaHook."""
return AthenaHook(self.aws_conn_id)
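AthenaSensor inherits the same generic parameters from AwsBaseSensor. A short sketch, with a placeholder query execution id (in a real DAG it would usually come from an upstream AthenaOperator via XCom):

from airflow.providers.amazon.aws.sensors.athena import AthenaSensor

wait_for_query = AthenaSensor(
    task_id="wait_for_athena_query",
    query_execution_id="00000000-0000-0000-0000-000000000000",
    max_retries=3,
    sleep_time=10,
    # Generic parameters handled by AwsBaseSensor (placeholder values):
    region_name="eu-west-1",
    botocore_config={"read_timeout": 42},
)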
11 changes: 9 additions & 2 deletions airflow/providers/amazon/aws/triggers/athena.py
@@ -43,7 +43,8 @@ def __init__(
query_execution_id: str,
waiter_delay: int,
waiter_max_attempts: int,
aws_conn_id: str,
aws_conn_id: str | None,
**kwargs,
):
super().__init__(
serialized_fields={"query_execution_id": query_execution_id},
@@ -56,7 +57,13 @@ def __init__(
waiter_delay=waiter_delay,
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
**kwargs,
)

def hook(self) -> AwsGenericHook:
return AthenaHook(self.aws_conn_id)
return AthenaHook(
aws_conn_id=self.aws_conn_id,
region_name=self.region_name,
verify=self.verify,
config=self.botocore_config,
)
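Since the trigger now forwards region_name, verify, and botocore_config to its hook, a deferrable run keeps the same AWS configuration that was passed to the operator. A sketch of the deferrable path, reusing the placeholder names from the operator example above:

deferred_query = AthenaOperator(
    task_id="run_athena_query_deferrable",
    query="SELECT 1",
    database="test_database",
    output_location="s3://example-athena-results/",
    region_name="eu-west-1",
    deferrable=True,  # polling is handed off to AthenaTrigger on the triggerer
)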
5 changes: 5 additions & 0 deletions docs/apache-airflow-providers-amazon/operators/athena.rst
@@ -30,6 +30,11 @@ Prerequisite Tasks

.. include:: ../_partials/prerequisite_tasks.rst

Generic Parameters
------------------

.. include:: ../_partials/generic_parameters.rst

Operators
---------

30 changes: 26 additions & 4 deletions tests/providers/amazon/aws/operators/test_athena.py
@@ -53,25 +53,47 @@ def setup_method(self):
"start_date": DEFAULT_DATE,
}

self.dag = DAG(f"{TEST_DAG_ID}test_schedule_dag_once", default_args=args, schedule="@once")
self.dag = DAG(TEST_DAG_ID, default_args=args, schedule="@once")

self.athena = AthenaOperator(
self.default_op_kwargs = dict(
task_id="test_athena_operator",
query="SELECT * FROM TEST_TABLE",
database="TEST_DATABASE",
output_location="s3://test_s3_bucket/",
client_request_token="eac427d0-1c6d-4dfb-96aa-2835d3ac6595",
sleep_time=0,
max_polling_attempts=3,
dag=self.dag,
)
self.athena = AthenaOperator(**self.default_op_kwargs, aws_conn_id=None, dag=self.dag)

def test_base_aws_op_attributes(self):
op = AthenaOperator(**self.default_op_kwargs)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None
assert op.hook.log_query is True

op = AthenaOperator(
**self.default_op_kwargs,
aws_conn_id="aws-test-custom-conn",
region_name="eu-west-1",
verify=False,
botocore_config={"read_timeout": 42},
log_query=False,
)
assert op.hook.aws_conn_id == "aws-test-custom-conn"
assert op.hook._region_name == "eu-west-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42
assert op.hook.log_query is False

def test_init(self):
assert self.athena.task_id == MOCK_DATA["task_id"]
assert self.athena.query == MOCK_DATA["query"]
assert self.athena.database == MOCK_DATA["database"]
assert self.athena.catalog == MOCK_DATA["catalog"]
assert self.athena.aws_conn_id == "aws_default"
assert self.athena.client_request_token == MOCK_DATA["client_request_token"]
assert self.athena.sleep_time == 0

76 changes: 46 additions & 30 deletions tests/providers/amazon/aws/sensors/test_athena.py
@@ -26,48 +26,64 @@
from airflow.providers.amazon.aws.sensors.athena import AthenaSensor


@pytest.fixture
def mock_poll_query_status():
with mock.patch.object(AthenaHook, "poll_query_status") as m:
yield m


class TestAthenaSensor:
def setup_method(self):
self.sensor = AthenaSensor(
self.default_op_kwargs = dict(
task_id="test_athena_sensor",
query_execution_id="abc",
sleep_time=5,
max_retries=1,
aws_conn_id="aws_default",
)
self.sensor = AthenaSensor(**self.default_op_kwargs, aws_conn_id=None)

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("SUCCEEDED",))
def test_poke_success(self, mock_poll_query_status):
assert self.sensor.poke({}) is True

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("RUNNING",))
def test_poke_running(self, mock_poll_query_status):
assert self.sensor.poke({}) is False
def test_base_aws_op_attributes(self):
op = AthenaSensor(**self.default_op_kwargs)
assert op.hook.aws_conn_id == "aws_default"
assert op.hook._region_name is None
assert op.hook._verify is None
assert op.hook._config is None
assert op.hook.log_query is True

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("QUEUED",))
def test_poke_queued(self, mock_poll_query_status):
assert self.sensor.poke({}) is False
op = AthenaSensor(
**self.default_op_kwargs,
aws_conn_id="aws-test-custom-conn",
region_name="eu-west-1",
verify=False,
botocore_config={"read_timeout": 42},
)
assert op.hook.aws_conn_id == "aws-test-custom-conn"
assert op.hook._region_name == "eu-west-1"
assert op.hook._verify is False
assert op.hook._config is not None
assert op.hook._config.read_timeout == 42

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("FAILED",))
def test_poke_failed(self, mock_poll_query_status):
with pytest.raises(AirflowException) as ctx:
self.sensor.poke({})
assert "Athena sensor failed" in str(ctx.value)
@pytest.mark.parametrize("state", ["SUCCEEDED"])
def test_poke_success_states(self, state, mock_poll_query_status):
mock_poll_query_status.side_effect = [state]
assert self.sensor.poke({}) is True

@mock.patch.object(AthenaHook, "poll_query_status", side_effect=("CANCELLED",))
def test_poke_cancelled(self, mock_poll_query_status):
with pytest.raises(AirflowException) as ctx:
self.sensor.poke({})
assert "Athena sensor failed" in str(ctx.value)
@pytest.mark.parametrize("state", ["RUNNING", "QUEUED"])
def test_poke_intermediate_states(self, state, mock_poll_query_status):
mock_poll_query_status.side_effect = [state]
assert self.sensor.poke({}) is False

@pytest.mark.parametrize(
"soft_fail, expected_exception", ((False, AirflowException), (True, AirflowSkipException))
"soft_fail, expected_exception",
[
pytest.param(False, AirflowException, id="not-soft-fail"),
pytest.param(True, AirflowSkipException, id="soft-fail"),
],
)
def test_fail_poke(self, soft_fail, expected_exception):
self.sensor.soft_fail = soft_fail
@pytest.mark.parametrize("state", ["FAILED", "CANCELLED"])
def test_poke_failure_states(self, state, soft_fail, expected_exception, mock_poll_query_status):
mock_poll_query_status.side_effect = [state]
sensor = AthenaSensor(**self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail)
message = "Athena sensor failed"
with pytest.raises(expected_exception, match=message), mock.patch(
"airflow.providers.amazon.aws.hooks.athena.AthenaHook.poll_query_status"
) as poll_query_status:
poll_query_status.return_value = "FAILED"
self.sensor.poke(context={})
with pytest.raises(expected_exception, match=message):
sensor.poke({})
