Skip to content

Commit

Permalink
[AIRFLOW-3266] Add AWS Athena Hook and Operator (#4111)
Browse files Browse the repository at this point in the history
Provides AWS Athena hook and operator to submit Athena(presto) queries on AWS.

Authored-by: Phanindhra <phani8996@gmail.com>
  • Loading branch information
phanindhra876 authored and kaxil committed Dec 2, 2018
1 parent 02e294f commit c98d48d
Show file tree
Hide file tree
Showing 4 changed files with 357 additions and 0 deletions.
150 changes: 150 additions & 0 deletions airflow/contrib/hooks/aws_athena_hook.py
@@ -0,0 +1,150 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from time import sleep
from airflow.contrib.hooks.aws_hook import AwsHook


class AWSAthenaHook(AwsHook):
    """
    Interact with AWS Athena to run, poll queries and return query results.

    :param aws_conn_id: aws connection to use.
    :type aws_conn_id: str
    :param sleep_time: Time in seconds to wait between two consecutive calls to check query status on athena
    :type sleep_time: int
    """

    # Query states as reported by Athena, grouped by what the caller should do next.
    INTERMEDIATE_STATES = ('QUEUED', 'RUNNING',)
    FAILURE_STATES = ('FAILED', 'CANCELLED',)
    SUCCESS_STATES = ('SUCCEEDED',)

    def __init__(self, aws_conn_id='aws_default', sleep_time=30, *args, **kwargs):
        # Forward extra positional arguments as well; they used to be silently dropped.
        super(AWSAthenaHook, self).__init__(aws_conn_id, *args, **kwargs)
        self.sleep_time = sleep_time
        self.conn = None  # created lazily by get_conn()

    def get_conn(self):
        """
        Return the cached athena client, creating it on first use.

        :return: client object obtained from ``get_client_type('athena')``
        """
        if not self.conn:
            self.conn = self.get_client_type('athena')
        return self.conn

    def run_query(self, query, query_context, result_configuration, client_request_token=None):
        """
        Run Presto query on athena with provided config and return submitted query_execution_id.

        :param query: Presto query to run
        :type query: str
        :param query_context: Context in which query need to be run
        :type query_context: dict
        :param result_configuration: Dict with path to store results in and config related to encryption
        :type result_configuration: dict
        :param client_request_token: Unique token created by user to avoid multiple executions of same query.
            Omitted from the request when empty or None.
        :type client_request_token: str
        :return: str
        """
        params = {
            'QueryString': query,
            'QueryExecutionContext': query_context,
            'ResultConfiguration': result_configuration,
        }
        # Only send the token when one was supplied: passing None for a string
        # parameter is rejected by the client's parameter validation.
        if client_request_token:
            params['ClientRequestToken'] = client_request_token
        response = self.get_conn().start_query_execution(**params)
        return response['QueryExecutionId']

    def check_query_status(self, query_execution_id):
        """
        Fetch the status of submitted athena query. Returns None or one of valid query states.

        Errors raised by the API call itself propagate to the caller; only a response
        that lacks the expected ``QueryExecution -> Status -> State`` entry yields None.

        :param query_execution_id: Id of submitted athena query
        :type query_execution_id: str
        :return: str
        """
        response = self.get_conn().get_query_execution(QueryExecutionId=query_execution_id)
        try:
            return response['QueryExecution']['Status']['State']
        except Exception as ex:
            # Malformed response: report it and let the caller decide whether to retry.
            self.log.error('Exception while getting query state: %s', ex)
            return None

    def get_query_results(self, query_execution_id):
        """
        Fetch submitted athena query results. Returns None if the query state is unknown,
        intermediate or failed/cancelled, else dict of query output.

        :param query_execution_id: Id of submitted athena query
        :type query_execution_id: str
        :return: dict
        """
        query_state = self.check_query_status(query_execution_id)
        if query_state is None:
            self.log.error('Invalid Query state')
            return None
        if query_state in self.INTERMEDIATE_STATES or query_state in self.FAILURE_STATES:
            self.log.error('Query is in %s state. Cannot fetch results', query_state)
            return None
        return self.get_conn().get_query_results(QueryExecutionId=query_execution_id)

    def poll_query_status(self, query_execution_id, max_tries=None):
        """
        Poll the status of submitted athena query until query state reaches final state.

        Returns the final state, or the last observed state (possibly None or an
        intermediate state) when ``max_tries`` polls were made without reaching one.

        :param query_execution_id: Id of submitted athena query
        :type query_execution_id: str
        :param max_tries: Number of times to poll for query state before function exits.
            A falsy value (None or 0) means poll without limit.
        :type max_tries: int
        :return: str
        """
        try_number = 1
        while True:
            query_state = self.check_query_status(query_execution_id)
            if query_state is None:
                self.log.info('Trial %s: Invalid query state. Retrying again', try_number)
            elif query_state in self.INTERMEDIATE_STATES:
                self.log.info('Trial %s: Query is still in an intermediate state - %s',
                              try_number, query_state)
            else:
                self.log.info('Trial %s: Query execution completed. Final state is %s',
                              try_number, query_state)
                return query_state
            if max_tries and try_number >= max_tries:  # Give up once max_tries is reached
                return query_state
            try_number += 1
            sleep(self.sleep_time)

    def stop_query(self, query_execution_id):
        """
        Cancel the submitted athena query.

        :param query_execution_id: Id of submitted athena query
        :type query_execution_id: str
        :return: dict
        """
        return self.get_conn().stop_query_execution(QueryExecutionId=query_execution_id)
98 changes: 98 additions & 0 deletions airflow/contrib/operators/aws_athena_operator.py
@@ -0,0 +1,98 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

from uuid import uuid4

from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults
from airflow.contrib.hooks.aws_athena_hook import AWSAthenaHook


class AWSAthenaOperator(BaseOperator):
    """
    An operator that submits a presto query to athena.

    :param query: Presto query to be run on athena. (templated)
    :type query: str
    :param database: Database to select. (templated)
    :type database: str
    :param output_location: s3 path to write the query results into. (templated)
    :type output_location: str
    :param aws_conn_id: aws connection to use
    :type aws_conn_id: str
    :param client_request_token: Unique token passed along with the query submission.
        A random UUID4 string is generated when none is given.
    :type client_request_token: str
    :param query_execution_context: Extra query execution context; its ``Database`` entry
        is overwritten with ``database`` when the operator executes.
    :type query_execution_context: dict
    :param result_configuration: Extra result configuration; its ``OutputLocation`` entry
        is overwritten with ``output_location`` when the operator executes.
    :type result_configuration: dict
    :param sleep_time: Time to wait between two consecutive call to check query status on athena
    :type sleep_time: int
    """

    ui_color = '#44b5e2'
    template_fields = ('query', 'database', 'output_location')

    @apply_defaults
    def __init__(self, query, database, output_location, aws_conn_id='aws_default', client_request_token=None,
                 query_execution_context=None, result_configuration=None, sleep_time=30, *args, **kwargs):
        super(AWSAthenaOperator, self).__init__(*args, **kwargs)
        self.query = query
        self.database = database
        self.output_location = output_location
        self.aws_conn_id = aws_conn_id
        self.client_request_token = client_request_token or str(uuid4())
        # Copy the dicts: execute() writes into them and must not modify objects
        # that still belong to the caller.
        self.query_execution_context = dict(query_execution_context or {})
        self.result_configuration = dict(result_configuration or {})
        self.sleep_time = sleep_time
        self.query_execution_id = None  # set once the query has been submitted
        self.hook = None  # set by execute()

    def get_hook(self):
        """
        Build the Athena hook for this operator's connection id and polling interval.

        :return: AWSAthenaHook
        """
        return AWSAthenaHook(self.aws_conn_id, self.sleep_time)

    def execute(self, context):
        """
        Run Presto Query on Athena and wait until it reaches a final state.
        """
        self.hook = self.get_hook()
        self.hook.get_conn()

        self.query_execution_context['Database'] = self.database
        self.result_configuration['OutputLocation'] = self.output_location
        self.query_execution_id = self.hook.run_query(self.query, self.query_execution_context,
                                                      self.result_configuration, self.client_request_token)
        final_state = self.hook.poll_query_status(self.query_execution_id)
        if final_state in AWSAthenaHook.FAILURE_STATES:
            # NOTE(review): the task still finishes normally here; decide whether a
            # failed/cancelled query should raise and fail the task instead.
            self.log.error('Athena query %s finished in state %s', self.query_execution_id, final_state)

    def on_kill(self):
        """
        Cancel the submitted athena query and wait for it to reach a final state.
        """
        if not self.query_execution_id:
            # Nothing was submitted, so there is nothing to cancel.
            return
        self.log.info('⚰️⚰️⚰️ Received a kill Signal. Time to Die')
        self.log.info('Stopping Query with executionId - %s', self.query_execution_id)
        response = self.hook.stop_query(self.query_execution_id)
        http_status_code = None
        try:
            http_status_code = response['ResponseMetadata']['HTTPStatusCode']
        except Exception as ex:
            self.log.error('Exception while cancelling query: %s', ex)
        if http_status_code != 200:
            self.log.error('Unable to request query cancel on athena. Exiting')
        else:
            self.log.info('Polling Athena for query with id %s to reach final state',
                          self.query_execution_id)
            self.hook.poll_query_status(self.query_execution_id)
2 changes: 2 additions & 0 deletions docs/code.rst
Expand Up @@ -130,6 +130,7 @@ Operators
.. Alphabetize this list
.. autoclass:: airflow.contrib.operators.adls_list_operator.AzureDataLakeStorageListOperator
.. autoclass:: airflow.contrib.operators.aws_athena_operator.AWSAthenaOperator
.. autoclass:: airflow.contrib.operators.awsbatch_operator.AWSBatchOperator
.. autoclass:: airflow.contrib.operators.bigquery_check_operator.BigQueryCheckOperator
.. autoclass:: airflow.contrib.operators.bigquery_check_operator.BigQueryValueCheckOperator
Expand Down Expand Up @@ -403,6 +404,7 @@ interface when possible and acting as building blocks for operators.
Community contributed hooks
'''''''''''''''''''''''''''
.. Alphabetize this list
.. autoclass:: airflow.contrib.hooks.aws_athena_hook.AWSAthenaHook
.. autoclass:: airflow.contrib.hooks.aws_dynamodb_hook.AwsDynamoDBHook
.. autoclass:: airflow.contrib.hooks.aws_firehose_hook.AwsFirehoseHook
.. autoclass:: airflow.contrib.hooks.aws_hook.AwsHook
Expand Down
107 changes: 107 additions & 0 deletions tests/contrib/operators/test_aws_athena_operator.py
@@ -0,0 +1,107 @@
# -*- coding: utf-8 -*-
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# 'License'); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# 'AS IS' BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#

import unittest

from airflow.contrib.operators.aws_athena_operator import AWSAthenaOperator
from airflow.contrib.hooks.aws_athena_hook import AWSAthenaHook
from airflow import configuration

try:
from unittest import mock
except ImportError:
try:
import mock
except ImportError:
mock = None

# Expected operator attribute values, checked by the assertions in the tests below.
MOCK_DATA = dict(
    task_id='test_aws_athena_operator',
    query='SELECT * FROM TEST_TABLE',
    database='TEST_DATABASE',
    outputLocation='s3://test_s3_bucket/',
    client_request_token='eac427d0-1c6d-4dfb-96aa-2835d3ac6595',
)

# Arguments the tests expect the operator to hand to AWSAthenaHook.run_query.
query_context = dict(Database=MOCK_DATA['database'])
result_configuration = dict(OutputLocation=MOCK_DATA['outputLocation'])


class TestAWSAthenaOperator(unittest.TestCase):
    """
    Tests for AWSAthenaOperator with AWSAthenaHook's connection, query submission
    and status check patched out, so no call reaches AWS.
    """

    def setUp(self):
        configuration.load_test_config()

        # Build the operator from MOCK_DATA so fixture and assertions cannot drift apart.
        self.athena = AWSAthenaOperator(task_id=MOCK_DATA['task_id'], query=MOCK_DATA['query'],
                                        database=MOCK_DATA['database'],
                                        output_location=MOCK_DATA['outputLocation'],
                                        client_request_token=MOCK_DATA['client_request_token'],
                                        sleep_time=1)

    def test_init(self):
        self.assertEqual(self.athena.task_id, MOCK_DATA['task_id'])
        self.assertEqual(self.athena.query, MOCK_DATA['query'])
        self.assertEqual(self.athena.database, MOCK_DATA['database'])
        self.assertEqual(self.athena.output_location, MOCK_DATA['outputLocation'])
        self.assertEqual(self.athena.aws_conn_id, 'aws_default')
        self.assertEqual(self.athena.client_request_token, MOCK_DATA['client_request_token'])
        self.assertEqual(self.athena.sleep_time, 1)

    # Patches are applied bottom-up, so the mocks arrive in the order
    # get_conn, run_query, check_query_status.
    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("SUCCEEDED",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_small_success_query(self, mock_conn, mock_run_query, mock_check_query_status):
        # Query is already finished on the first status check.
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 1)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "SUCCEEDED",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_big_success_query(self, mock_conn, mock_run_query, mock_check_query_status):
        # Query needs two extra polls before it succeeds.
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 3)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "FAILED",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_failure_query(self, mock_conn, mock_run_query, mock_check_query_status):
        # Polling stops as soon as the query reports FAILED.
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 2)

    @mock.patch.object(AWSAthenaHook, 'check_query_status', side_effect=("RUNNING", "RUNNING", "CANCELLED",))
    @mock.patch.object(AWSAthenaHook, 'run_query', return_value='1234')
    @mock.patch.object(AWSAthenaHook, 'get_conn')
    def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_query_status):
        # Polling stops as soon as the query reports CANCELLED.
        self.athena.execute(None)
        mock_run_query.assert_called_once_with(MOCK_DATA['query'], query_context, result_configuration,
                                               MOCK_DATA['client_request_token'])
        self.assertEqual(mock_check_query_status.call_count, 3)


if __name__ == '__main__':
    # Run this module's test cases when the file is executed directly.
    unittest.main()

0 comments on commit c98d48d

Please sign in to comment.