diff --git a/airflow/contrib/hooks/aws_athena_hook.py b/airflow/contrib/hooks/aws_athena_hook.py new file mode 100644 index 0000000000000..f11ff23c515f4 --- /dev/null +++ b/airflow/contrib/hooks/aws_athena_hook.py @@ -0,0 +1,150 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from time import sleep +from airflow.contrib.hooks.aws_hook import AwsHook + + +class AWSAthenaHook(AwsHook): + """ + Interact with AWS Athena to run, poll queries and return query results + + :param aws_conn_id: aws connection to use. + :type aws_conn_id: str + :param sleep_time: Time to wait between two consecutive call to check query status on athena + :type sleep_time: int + """ + + INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',) + FAILURE_STATES = ('FAILED', 'CANCELLED',) + SUCCESS_STATES = ('SUCCEEDED',) + + def __init__(self, aws_conn_id='aws_default', sleep_time=30, *args, **kwargs): + super(AWSAthenaHook, self).__init__(aws_conn_id, **kwargs) + self.sleep_time = sleep_time + self.conn = None + + def get_conn(self): + """ + check if aws conn exists already or create one and return it + + :return: boto3 session + """ + if not self.conn: + self.conn = self.get_client_type('athena') + return self.conn + + def run_query(self, query, query_context, result_configuration, client_request_token=None): + """ + Run Presto query on athena with provided config and return submitted query_execution_id + + :param query: Presto query to run + :type query: str + :param query_context: Context in which query need to be run + :type query_context: dict + :param result_configuration: Dict with path to store results in and config related to encryption + :type result_configuration: dict + :param client_request_token: Unique token created by user to avoid multiple executions of same query + :type client_request_token: str + :return: str + """ + response = self.conn.start_query_execution(QueryString=query, + ClientRequestToken=client_request_token, + QueryExecutionContext=query_context, + ResultConfiguration=result_configuration) + query_execution_id = response['QueryExecutionId'] + return query_execution_id + + def check_query_status(self, query_execution_id): + """ + Fetch the status of submitted athena query. Returns None or one of valid query states. + + :param query_execution_id: Id of submitted athena query + :type query_execution_id: str + :return: str + """ + response = self.conn.get_query_execution(QueryExecutionId=query_execution_id) + state = None + try: + state = response['QueryExecution']['Status']['State'] + except Exception as ex: + self.log.error('Exception while getting query state', ex) + finally: + return state + + def get_query_results(self, query_execution_id): + """ + Fetch submitted athena query results. returns none if query is in intermediate state or + failed/cancelled state else dict of query output + + :param query_execution_id: Id of submitted athena query + :type query_execution_id: str + :return: dict + """ + query_state = self.check_query_status(query_execution_id) + if query_state is None: + self.log.error('Invalid Query state') + return None + elif query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES: + self.log.error('Query is in {state} state. Cannot fetch results'.format(state=query_state)) + return None + return self.conn.get_query_results(QueryExecutionId=query_execution_id) + + def poll_query_status(self, query_execution_id, max_tries=None): + """ + Poll the status of submitted athena query until query state reaches final state. + Returns one of the final states + + :param query_execution_id: Id of submitted athena query + :type query_execution_id: str + :param max_tries: Number of times to poll for query state before function exits + :type max_tries: int + :return: str + """ + try_number = 1 + final_query_state = None # Query state when query reaches final state or max_tries reached + while True: + query_state = self.check_query_status(query_execution_id) + if query_state is None: + self.log.info('Trial {try_number}: Invalid query state. Retrying again'.format( + try_number=try_number)) + elif query_state in self.INTERMEDIATE_STATES: + self.log.info('Trial {try_number}: Query is still in an intermediate state - {state}' + .format(try_number=try_number, state=query_state)) + else: + self.log.info('Trial {try_number}: Query execution completed. Final state is {state}' + .format(try_number=try_number, state=query_state)) + final_query_state = query_state + break + if max_tries and try_number >= max_tries: # Break loop if max_tries reached + final_query_state = query_state + break + try_number += 1 + sleep(self.sleep_time) + return final_query_state + + def stop_query(self, query_execution_id): + """ + Cancel the submitted athena query + + :param query_execution_id: Id of submitted athena query + :type query_execution_id: str + :return: dict + """ + return self.conn.stop_query_execution(QueryExecutionId=query_execution_id) diff --git a/airflow/contrib/operators/aws_athena_operator.py b/airflow/contrib/operators/aws_athena_operator.py new file mode 100644 index 0000000000000..432410e31100a --- /dev/null +++ b/airflow/contrib/operators/aws_athena_operator.py @@ -0,0 +1,98 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +from uuid import uuid4 + +from airflow.models import BaseOperator +from airflow.utils.decorators import apply_defaults +from airflow.contrib.hooks.aws_athena_hook import AWSAthenaHook + + +class AWSAthenaOperator(BaseOperator): + """ + An operator that submit presto query to athena. + + :param query: Presto to be run on athena. (templated) + :type query: str + :param database: Database to select. (templated) + :type database: str + :param output_location: s3 path to write the query results into. (templated) + :type output_location: str + :param aws_conn_id: aws connection to use + :type aws_conn_id: str + :param sleep_time: Time to wait between two consecutive call to check query status on athena + :type sleep_time: int + """ + + ui_color = '#44b5e2' + template_fields = ('query', 'database', 'output_location') + + @apply_defaults + def __init__(self, query, database, output_location, aws_conn_id='aws_default', client_request_token=None, + query_execution_context=None, result_configuration=None, sleep_time=30, *args, **kwargs): + super(AWSAthenaOperator, self).__init__(*args, **kwargs) + self.query = query + self.database = database + self.output_location = output_location + self.aws_conn_id = aws_conn_id + self.client_request_token = client_request_token or str(uuid4()) + self.query_execution_context = query_execution_context or {} + self.result_configuration = result_configuration or {} + self.sleep_time = sleep_time + self.query_execution_id = None + self.hook = None + + def get_hook(self): + return AWSAthenaHook(self.aws_conn_id, self.sleep_time) + + def execute(self, context): + """ + Run Presto Query on Athena + """ + self.hook = self.get_hook() + self.hook.get_conn() + + self.query_execution_context['Database'] = self.database + self.result_configuration['OutputLocation'] = self.output_location + self.query_execution_id = self.hook.run_query(self.query, self.query_execution_context, + self.result_configuration, self.client_request_token) + self.hook.poll_query_status(self.query_execution_id) + + def on_kill(self): + """ + Cancel the submitted athena query + """ + if self.query_execution_id: + self.log.info('⚰️⚰️⚰️ Received a kill Signal. Time to Die') + self.log.info('Stopping Query with executionId - {queryId}'.format( + queryId=self.query_execution_id)) + response = self.hook.stop_query(self.query_execution_id) + http_status_code = None + try: + http_status_code = response['ResponseMetadata']['HTTPStatusCode'] + except Exception as ex: + self.log.error('Exception while cancelling query', ex) + finally: + if http_status_code is None or http_status_code != 200: + self.log.error('Unable to request query cancel on athena. Exiting') + else: + self.log.info('Polling Athena for query with id {queryId} to reach final state'.format( + queryId=self.query_execution_id)) + self.hook.poll_query_status(self.query_execution_id) diff --git a/docs/code.rst b/docs/code.rst index 817c5046ea29d..5c74b0ce3fb24 100644 --- a/docs/code.rst +++ b/docs/code.rst @@ -113,6 +113,7 @@ Operators .. Alphabetize this list .. autoclass:: airflow.contrib.operators.adls_list_operator.AzureDataLakeStorageListOperator +.. autoclass:: airflow.contrib.operators.aws_athena_operator.AWSAthenaOperator .. autoclass:: airflow.contrib.operators.awsbatch_operator.AWSBatchOperator .. autoclass:: airflow.contrib.operators.bigquery_check_operator.BigQueryCheckOperator .. autoclass:: airflow.contrib.operators.bigquery_check_operator.BigQueryValueCheckOperator @@ -384,6 +385,7 @@ interface when possible and acting as building blocks for operators. Community contributed hooks ''''''''''''''''''''''''''' .. Alphabetize this list +.. autoclass:: airflow.contrib.hooks.aws_athena_hook.AWSAthenaHook .. autoclass:: airflow.contrib.hooks.aws_dynamodb_hook.AwsDynamoDBHook .. autoclass:: airflow.contrib.hooks.aws_hook.AwsHook .. autoclass:: airflow.contrib.hooks.aws_lambda_hook.AwsLambdaHook diff --git a/tests/contrib/operators/test_aws_athena_operator.py b/tests/contrib/operators/test_aws_athena_operator.py new file mode 100644 index 0000000000000..ecfb0d2890100 --- /dev/null +++ b/tests/contrib/operators/test_aws_athena_operator.py @@ -0,0 +1,107 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# 'License'); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# + +import unittest + +from airflow.contrib.operators.aws_athena_operator import AWSAthenaOperator +from airflow.contrib.hooks.aws_athena_hook import AWSAthenaHook +from airflow import configuration + +try: + from unittest import mock +except ImportError: + try: + import mock + except ImportError: + mock = None + +MOCK_DATA = { + 'task_id': 'test_aws_athena_operator', + 'query': 'SELECT * FROM TEST_TABLE', + 'database': 'TEST_DATABASE', + 'outputLocation': 's3://test_s3_bucket/', + 'client_request_token': 'eac427d0-1c6d-4dfb-96aa-2835d3ac6595' +} + +query_context = { + 'Database': MOCK_DATA['database'] +} +result_configuration = { + 'OutputLocation': MOCK_DATA['outputLocation'] +} + + +class TestAWSAthenaOperator(unittest.TestCase): + + def setUp(self): + configuration.load_test_config() + + self.athena = AWSAthenaOperator(task_id='test_aws_athena_operator', query='SELECT * FROM TEST_TABLE', + database='TEST_DATABASE', output_location='s3://test_s3_bucket/', + client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595', + sleep_time=1) + + def test_init(self): + self.assertEqual(self.athena.task_id, MOCK_DATA['task_id']) + self.assertEqual(self.athena.query, MOCK_DATA['query']) + self.assertEqual(self.athena.database, MOCK_DATA['database']) + self.assertEqual(self.athena.aws_conn_id, 'aws_default') + self.assertEqual(self.athena.client_request_token, MOCK_DATA['client_request_token']) + self.assertEqual(self.athena.sleep_time, 1) + + @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCESS",)) + @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234') + @mock.patch.object(AWSAthenaHook, 'get_conn') + def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_check_query_status): + self.athena.execute(None) + mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration, + MOCK_DATA['client_request_token']) + self.assertEqual(mock_check_query_status.call_count, 1) + + @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "SUCCESS",)) + @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234') + @mock.patch.object(AWSAthenaHook, 'get_conn') + def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_query_status): + self.athena.execute(None) + mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration, + MOCK_DATA['client_request_token']) + self.assertEqual(mock_check_query_status.call_count, 3) + + @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "FAILED",)) + @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234') + @mock.patch.object(AWSAthenaHook, 'get_conn') + def test_hook_run_failure_query(self, mock_conn, mock_run_query, mock_check_query_status): + self.athena.execute(None) + mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration, + MOCK_DATA['client_request_token']) + self.assertEqual(mock_check_query_status.call_count, 2) + + @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "CANCELLED",)) + @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234') + @mock.patch.object(AWSAthenaHook, 'get_conn') + def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_query_status): + self.athena.execute(None) + mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration, + MOCK_DATA['client_request_token']) + self.assertEqual(mock_check_query_status.call_count, 3) + + +if __name__ == '__main__': + unittest.main()