Skip to content

Commit

Permalink
Add Amazon Athena query results extra link (#36447)
Browse files Browse the repository at this point in the history
  • Loading branch information
Taragolis committed Dec 27, 2023
1 parent 9e55f51 commit d73bef2
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 1 deletion.
30 changes: 30 additions & 0 deletions airflow/providers/amazon/aws/links/athena.py
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink


class AthenaQueryResultsLink(BaseAwsLink):
"""Helper class for constructing Amazon Athena query results."""

name = "Query Results"
key = "_athena_query_results"
format_str = (
BASE_AWS_CONSOLE_LINK + "/athena/home?region={region_name}#"
"/query-editor/history/{query_execution_id}"
)
9 changes: 9 additions & 0 deletions airflow/providers/amazon/aws/operators/athena.py
Expand Up @@ -23,6 +23,7 @@
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.athena import AthenaHook
from airflow.providers.amazon.aws.links.athena import AthenaQueryResultsLink
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
from airflow.providers.amazon.aws.triggers.athena import AthenaTrigger
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
Expand Down Expand Up @@ -82,6 +83,7 @@ class AthenaOperator(AwsBaseOperator[AthenaHook]):
)
template_ext: Sequence[str] = (".sql",)
template_fields_renderers = {"query": "sql"}
operator_extra_links = (AthenaQueryResultsLink(),)

def __init__(
self,
Expand Down Expand Up @@ -132,6 +134,13 @@ def execute(self, context: Context) -> str | None:
self.client_request_token,
self.workgroup,
)
AthenaQueryResultsLink.persist(
context=context,
operator=self,
region_name=self.hook.conn_region_name,
aws_partition=self.hook.conn_partition,
query_execution_id=self.query_execution_id,
)

if self.deferrable:
self.defer(
Expand Down
1 change: 1 addition & 0 deletions airflow/providers/amazon/provider.yaml
Expand Up @@ -711,6 +711,7 @@ transfers:
python-module: airflow.providers.amazon.aws.transfers.azure_blob_to_s3

extra-links:
- airflow.providers.amazon.aws.links.athena.AthenaQueryResultsLink
- airflow.providers.amazon.aws.links.batch.BatchJobDefinitionLink
- airflow.providers.amazon.aws.links.batch.BatchJobDetailsLink
- airflow.providers.amazon.aws.links.batch.BatchJobQueueLink
Expand Down
35 changes: 35 additions & 0 deletions tests/providers/amazon/aws/links/test_athena.py
@@ -0,0 +1,35 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from airflow.providers.amazon.aws.links.athena import AthenaQueryResultsLink
from tests.providers.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase


class TestAthenaQueryResultsLink(BaseAwsLinksTestCase):
link_class = AthenaQueryResultsLink

def test_extra_link(self):
self.assert_extra_link_url(
expected_url=(
"https://console.aws.amazon.com/athena/home"
"?region=eu-west-1#/query-editor/history/00000000-0000-0000-0000-000000000000"
),
region_name="eu-west-1",
aws_partition="aws",
query_execution_id="00000000-0000-0000-0000-000000000000",
)
25 changes: 24 additions & 1 deletion tests/providers/amazon/aws/operators/test_athena.py
Expand Up @@ -57,7 +57,8 @@


class TestAthenaOperator:
def setup_method(self):
@pytest.fixture(autouse=True)
def setup_test_cases(self):
args = {
"owner": "airflow",
"start_date": DEFAULT_DATE,
Expand All @@ -77,6 +78,10 @@ def setup_method(self):
**self.default_op_kwargs, output_location="s3://test_s3_bucket/", aws_conn_id=None, dag=self.dag
)

with mock.patch("airflow.providers.amazon.aws.links.athena.AthenaQueryResultsLink.persist") as m:
self.mocked_athena_result_link = m
yield

def test_base_aws_op_attributes(self):
op = AthenaOperator(**self.default_op_kwargs)
assert op.hook.aws_conn_id == "aws_default"
Expand Down Expand Up @@ -138,6 +143,15 @@ def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_chec
)
assert mock_check_query_status.call_count == 1

# Validate call persist Athena Query result link
self.mocked_athena_result_link.assert_called_once_with(
aws_partition=mock.ANY,
context=mock.ANY,
operator=mock.ANY,
region_name=mock.ANY,
query_execution_id=ATHENA_QUERY_ID,
)

@mock.patch.object(
AthenaHook,
"check_query_status",
Expand Down Expand Up @@ -241,6 +255,15 @@ def test_is_deferred(self, mock_run_query):

assert isinstance(deferred.value.trigger, AthenaTrigger)

# Validate call persist Athena Query result link
self.mocked_athena_result_link.assert_called_once_with(
aws_partition=mock.ANY,
context=mock.ANY,
operator=mock.ANY,
region_name=mock.ANY,
query_execution_id=ATHENA_QUERY_ID,
)

@mock.patch.object(AthenaHook, "region_name", new_callable=mock.PropertyMock)
@mock.patch.object(AthenaHook, "get_conn")
def test_operator_openlineage_data(self, mock_conn, mock_region_name):
Expand Down

0 comments on commit d73bef2

Please sign in to comment.