Skip to content

Commit

Permalink
caching athena client (#815)
Browse files Browse the repository at this point in the history
  • Loading branch information
ryandeivert committed Sep 14, 2018
1 parent 9e86f70 commit 777f423
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions stream_alert/athena_partition_refresh/main.py
Expand Up @@ -45,6 +45,8 @@ class AthenaRefresher(object):
STREAMALERT_DATABASE = '{}_streamalert'
ATHENA_S3_PREFIX = 'athena_partition_refresh'

_ATHENA_CLIENT = None

def __init__(self):
config = load_config(include={'lambda.json', 'global.json'})
prefix = config['global']['account']['prefix']
Expand All @@ -63,14 +65,21 @@ def __init__(self):
's3://{}.streamalert.athena-results'.format(prefix)
)

self._athena_client = AthenaClient(
db_name,
results_bucket,
self.ATHENA_S3_PREFIX
)

self._s3_buckets_and_keys = defaultdict(set)

self._create_client(db_name, results_bucket)

@classmethod
def _create_client(cls, db_name, results_bucket):
    """Create and cache the shared Athena client for this class, if not cached yet.

    The client is stored on the class (``cls._ATHENA_CLIENT``) so that it is
    built only once and reused by every subsequent instance. It is cached
    only AFTER the database existence check succeeds; otherwise a failed
    check would leave an unvalidated client cached and all later calls
    would return early without ever re-checking the database.

    Args:
        db_name (str): Name of the Athena database the client should use
        results_bucket (str): Location passed to AthenaClient for query
            results (presumably an s3:// bucket path; verify against caller)

    Raises:
        AthenaRefreshError: If the database does not exist
    """
    if cls._ATHENA_CLIENT is not None:
        return  # Client already created/cached

    # Build the client locally first; do not publish it to the class-level
    # cache until it has been validated
    client = AthenaClient(db_name, results_bucket, cls.ATHENA_S3_PREFIX)

    # Check if the database exists when the client is created
    if not client.check_database_exists():
        raise AthenaRefreshError('The \'{}\' database does not exist'.format(db_name))

    cls._ATHENA_CLIENT = client

def _get_partitions_from_keys(self):
"""Get the partitions that need to be added for the Athena tables
Expand Down Expand Up @@ -151,7 +160,7 @@ def _add_partitions(self):
athena_table=athena_table,
partition_statement=partition_statement))

success = self._athena_client.run_query(query=query)
success = self._ATHENA_CLIENT.run_query(query=query)
if not success:
raise AthenaRefreshError(
'The add hive partition query has failed:\n{}'.format(query)
Expand All @@ -169,11 +178,6 @@ def run(self, event):
should contain one (or maybe more) S3 bucket notification message.
"""
# Check that the database being used exists before running queries
if not self._athena_client.check_database_exists():
raise AthenaRefreshError(
'The \'{}\' database does not exist'.format(self._athena_client.database)
)

for sqs_rec in event['Records']:
LOGGER.debug('Processing event with message ID \'%s\' and SentTimestamp %s',
sqs_rec['messageId'],
Expand Down

0 comments on commit 777f423

Please sign in to comment.