Merged
214 changes: 214 additions & 0 deletions awswrangler/athena/_cache.py
@@ -0,0 +1,214 @@
"""Cache Module for Amazon Athena."""
import datetime
import logging
import re
from heapq import heappop, heappush
from typing import Any, Dict, List, Match, NamedTuple, Optional, Tuple, Union

import boto3

from awswrangler import _utils

_logger: logging.Logger = logging.getLogger(__name__)


class _CacheInfo(NamedTuple):
has_valid_cache: bool
file_format: Optional[str] = None
query_execution_id: Optional[str] = None
query_execution_payload: Optional[Dict[str, Any]] = None


class _LocalMetadataCacheManager:
def __init__(self) -> None:
self._cache: Dict[str, Any] = {}
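        # Min-heap of (SubmissionDateTime, QueryExecutionId) tuples: the heap
        # root is always the oldest cached entry, so eviction pops oldest first.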
self._pqueue: List[Tuple[datetime.datetime, str]] = []
self._max_cache_size = 100

def update_cache(self, items: List[Dict[str, Any]]) -> None:
"""
Update the local metadata cache with new query metadata.

Parameters
----------
items : List[Dict[str, Any]]
            List of query execution metadata dictionaries, as returned by boto3's `batch_get_query_execution()`.

Returns
-------
None
None.
"""
if self._pqueue:
oldest_item = self._cache[self._pqueue[0][1]]
items = list(
filter(lambda x: x["Status"]["SubmissionDateTime"] > oldest_item["Status"]["SubmissionDateTime"], items)
)

cache_oversize = len(self._cache) + len(items) - self._max_cache_size
for _ in range(cache_oversize):
_, query_execution_id = heappop(self._pqueue)
del self._cache[query_execution_id]

for item in items[: self._max_cache_size]:
heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
self._cache[item["QueryExecutionId"]] = item
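        # Illustrative walk-through: with a full cache of 100 entries and 5
        # strictly newer items, cache_oversize == 5, so the 5 oldest entries are
        # heap-popped above before the new items are pushed.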

def sorted_successful_generator(self) -> List[Dict[str, Any]]:
"""
Sorts the entries in the local cache based on query Completion DateTime.

This is useful to guarantee LRU caching rules.

Returns
-------
List[Dict[str, Any]]
Returns successful DDL and DML queries sorted by query completion time.
"""
filtered: List[Dict[str, Any]] = []
for query in self._cache.values():
if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
filtered.append(query)
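        # Sorting on str(CompletionDateTime) tracks chronological order because
        # str() of a datetime yields ISO-8601-style text (assuming all entries
        # share one timezone offset, as boto3 timestamps within a process do).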
return sorted(filtered, key=lambda e: str(e["Status"]["CompletionDateTime"]), reverse=True)

def __contains__(self, key: str) -> bool:
return key in self._cache

@property
def max_cache_size(self) -> int:
"""Property max_cache_size."""
return self._max_cache_size

@max_cache_size.setter
def max_cache_size(self, value: int) -> None:
self._max_cache_size = value


def _parse_select_query_from_possible_ctas(possible_ctas: str) -> Optional[str]:
"""Check if `possible_ctas` is a valid parquet-generating CTAS and returns the full SELECT statement."""
possible_ctas = possible_ctas.lower()
parquet_format_regex: str = r"format\s*=\s*\'parquet\'\s*,"
is_parquet_format: Optional[Match[str]] = re.search(pattern=parquet_format_regex, string=possible_ctas)
if is_parquet_format is not None:
unstripped_select_statement_regex: str = r"\s+as\s+\(*(select|with).*"
unstripped_select_statement_match: Optional[Match[str]] = re.search(
unstripped_select_statement_regex, possible_ctas, re.DOTALL
)
if unstripped_select_statement_match is not None:
stripped_select_statement_match: Optional[Match[str]] = re.search(
r"(select|with).*", unstripped_select_statement_match.group(0), re.DOTALL
)
if stripped_select_statement_match is not None:
return stripped_select_statement_match.group(0)
return None
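# Illustrative example: for the lowered query
#   "create table t with (format = 'parquet', ...) as (select * from src)"
# the function returns "select * from src)"; the trailing parenthesis is
# stripped later by _prepare_query_string_for_comparison's strip("()").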


def _compare_query_string(sql: str, other: str) -> bool:
comparison_query = _prepare_query_string_for_comparison(query_string=other)
_logger.debug("sql: %s", sql)
_logger.debug("comparison_query: %s", comparison_query)
    return sql == comparison_query


def _prepare_query_string_for_comparison(query_string: str) -> str:
"""To use cached data, we need to compare queries. Returns a query string in canonical form."""
# for now this is a simple complete strip, but it could grow into much more sophisticated
# query comparison data structures
query_string = "".join(query_string.split()).strip("()").lower()
query_string = query_string[:-1] if query_string.endswith(";") else query_string
return query_string
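# Illustrative example:
#   _prepare_query_string_for_comparison("SELECT *\nFROM my_table ;")
#   returns "select*frommy_table" -- whitespace, case, outer parentheses and a
#   trailing semicolon never affect cache hits.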


def _get_last_query_infos(
max_remote_cache_entries: int,
boto3_session: Optional[boto3.Session] = None,
workgroup: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
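    # BatchGetQueryExecution accepts at most 50 query IDs per request, which is
    # why both the pagination page size and the fetch loop below chunk by 50.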
page_size = 50
args: Dict[str, Union[str, Dict[str, int]]] = {
"PaginationConfig": {"MaxItems": max_remote_cache_entries, "PageSize": page_size}
}
if workgroup is not None:
args["WorkGroup"] = workgroup
paginator = client_athena.get_paginator("list_query_executions")
uncached_ids = []
for page in paginator.paginate(**args):
_logger.debug("paginating Athena's queries history...")
query_execution_id_list: List[str] = page["QueryExecutionIds"]
for query_execution_id in query_execution_id_list:
if query_execution_id not in _cache_manager:
uncached_ids.append(query_execution_id)
if uncached_ids:
new_execution_data = []
for i in range(0, len(uncached_ids), page_size):
new_execution_data.extend(
client_athena.batch_get_query_execution(QueryExecutionIds=uncached_ids[i : i + page_size]).get(
"QueryExecutions"
)
)
_cache_manager.update_cache(new_execution_data)
return _cache_manager.sorted_successful_generator()


def _check_for_cached_results(
sql: str,
boto3_session: boto3.Session,
workgroup: Optional[str],
max_cache_seconds: int,
max_cache_query_inspections: int,
max_remote_cache_entries: int,
) -> _CacheInfo:
"""
Check whether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.

    If so, returns a `_CacheInfo` carrying Athena's `query_execution_info` and the result data format.
"""
if max_cache_seconds <= 0:
return _CacheInfo(has_valid_cache=False)
num_executions_inspected: int = 0
comparable_sql: str = _prepare_query_string_for_comparison(sql)
current_timestamp: datetime.datetime = datetime.datetime.now(datetime.timezone.utc)
_logger.debug("current_timestamp: %s", current_timestamp)
for query_info in _get_last_query_infos(
max_remote_cache_entries=max_remote_cache_entries,
boto3_session=boto3_session,
workgroup=workgroup,
):
query_execution_id: str = query_info["QueryExecutionId"]
query_timestamp: datetime.datetime = query_info["Status"]["CompletionDateTime"]
_logger.debug("query_timestamp: %s", query_timestamp)
if (current_timestamp - query_timestamp).total_seconds() > max_cache_seconds:
return _CacheInfo(
has_valid_cache=False, query_execution_id=query_execution_id, query_execution_payload=query_info
)
statement_type: Optional[str] = query_info.get("StatementType")
if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
if parsed_query is not None:
if _compare_query_string(sql=comparable_sql, other=parsed_query):
return _CacheInfo(
has_valid_cache=True,
file_format="parquet",
query_execution_id=query_execution_id,
query_execution_payload=query_info,
)
elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
return _CacheInfo(
has_valid_cache=True,
file_format="csv",
query_execution_id=query_execution_id,
query_execution_payload=query_info,
)
num_executions_inspected += 1
_logger.debug("num_executions_inspected: %s", num_executions_inspected)
if num_executions_inspected >= max_cache_query_inspections:
return _CacheInfo(has_valid_cache=False)
return _CacheInfo(has_valid_cache=False)


_cache_manager = _LocalMetadataCacheManager()
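
Taken together, the read path is expected to consult this module before launching a new query: canonicalize the incoming SQL, scan recent workgroup executions, and reuse a prior result on a match. A minimal sketch of that call pattern (the parameter values and the `session` variable are illustrative; the real plumbing lives in `read_sql_query`):

cache_info = _check_for_cached_results(
    sql="SELECT * FROM my_table",
    boto3_session=session,  # illustrative: an existing boto3.Session
    workgroup=None,
    max_cache_seconds=900,
    max_cache_query_inspections=50,
    max_remote_cache_entries=50,
)
if cache_info.has_valid_cache:
    # Parquet results come from CTAS queries, CSV from plain SELECTs.
    print(cache_info.file_format, cache_info.query_execution_id)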
144 changes: 3 additions & 141 deletions awswrangler/athena/_read.py
@@ -1,12 +1,10 @@
"""Amazon Athena Module gathering all read_sql_* function."""

import csv
import datetime
import logging
import re
import sys
import uuid
from typing import Any, Dict, Iterator, List, Match, NamedTuple, Optional, Tuple, Union
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union

import boto3
import botocore.exceptions
@@ -21,20 +19,14 @@
_get_query_metadata,
_get_s3_output,
_get_workgroup_config,
_LocalMetadataCacheManager,
_QueryMetadata,
_start_query_execution,
_WorkGroupConfig,
)

_logger: logging.Logger = logging.getLogger(__name__)

from ._cache import _cache_manager, _CacheInfo, _check_for_cached_results

class _CacheInfo(NamedTuple):
has_valid_cache: bool
file_format: Optional[str] = None
query_execution_id: Optional[str] = None
query_execution_payload: Optional[Dict[str, Any]] = None
_logger: logging.Logger = logging.getLogger(__name__)


def _extract_ctas_manifest_paths(path: str, boto3_session: Optional[boto3.Session] = None) -> List[str]:
@@ -86,133 +78,6 @@ def _delete_after_iterate(
)


def _prepare_query_string_for_comparison(query_string: str) -> str:
"""To use cached data, we need to compare queries. Returns a query string in canonical form."""
# for now this is a simple complete strip, but it could grow into much more sophisticated
# query comparison data structures
query_string = "".join(query_string.split()).strip("()").lower()
query_string = query_string[:-1] if query_string.endswith(";") else query_string
return query_string


def _compare_query_string(sql: str, other: str) -> bool:
comparison_query = _prepare_query_string_for_comparison(query_string=other)
_logger.debug("sql: %s", sql)
_logger.debug("comparison_query: %s", comparison_query)
if sql == comparison_query:
return True
return False


def _get_last_query_infos(
max_remote_cache_entries: int,
boto3_session: Optional[boto3.Session] = None,
workgroup: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
page_size = 50
args: Dict[str, Union[str, Dict[str, int]]] = {
"PaginationConfig": {"MaxItems": max_remote_cache_entries, "PageSize": page_size}
}
if workgroup is not None:
args["WorkGroup"] = workgroup
paginator = client_athena.get_paginator("list_query_executions")
uncached_ids = []
for page in paginator.paginate(**args):
_logger.debug("paginating Athena's queries history...")
query_execution_id_list: List[str] = page["QueryExecutionIds"]
for query_execution_id in query_execution_id_list:
if query_execution_id not in _cache_manager:
uncached_ids.append(query_execution_id)
if uncached_ids:
new_execution_data = []
for i in range(0, len(uncached_ids), page_size):
new_execution_data.extend(
client_athena.batch_get_query_execution(QueryExecutionIds=uncached_ids[i : i + page_size]).get(
"QueryExecutions"
)
)
_cache_manager.update_cache(new_execution_data)
return _cache_manager.sorted_successful_generator()


def _parse_select_query_from_possible_ctas(possible_ctas: str) -> Optional[str]:
"""Check if `possible_ctas` is a valid parquet-generating CTAS and returns the full SELECT statement."""
possible_ctas = possible_ctas.lower()
parquet_format_regex: str = r"format\s*=\s*\'parquet\'\s*,"
is_parquet_format: Optional[Match[str]] = re.search(pattern=parquet_format_regex, string=possible_ctas)
if is_parquet_format is not None:
unstripped_select_statement_regex: str = r"\s+as\s+\(*(select|with).*"
unstripped_select_statement_match: Optional[Match[str]] = re.search(
unstripped_select_statement_regex, possible_ctas, re.DOTALL
)
if unstripped_select_statement_match is not None:
stripped_select_statement_match: Optional[Match[str]] = re.search(
r"(select|with).*", unstripped_select_statement_match.group(0), re.DOTALL
)
if stripped_select_statement_match is not None:
return stripped_select_statement_match.group(0)
return None


def _check_for_cached_results(
sql: str,
boto3_session: boto3.Session,
workgroup: Optional[str],
max_cache_seconds: int,
max_cache_query_inspections: int,
max_remote_cache_entries: int,
) -> _CacheInfo:
"""
Check whether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.

If so, returns a dict with Athena's `query_execution_info` and the data format.
"""
if max_cache_seconds <= 0:
return _CacheInfo(has_valid_cache=False)
num_executions_inspected: int = 0
comparable_sql: str = _prepare_query_string_for_comparison(sql)
current_timestamp: datetime.datetime = datetime.datetime.now(datetime.timezone.utc)
_logger.debug("current_timestamp: %s", current_timestamp)
for query_info in _get_last_query_infos(
max_remote_cache_entries=max_remote_cache_entries,
boto3_session=boto3_session,
workgroup=workgroup,
):
query_execution_id: str = query_info["QueryExecutionId"]
query_timestamp: datetime.datetime = query_info["Status"]["CompletionDateTime"]
_logger.debug("query_timestamp: %s", query_timestamp)
if (current_timestamp - query_timestamp).total_seconds() > max_cache_seconds:
return _CacheInfo(
has_valid_cache=False, query_execution_id=query_execution_id, query_execution_payload=query_info
)
statement_type: Optional[str] = query_info.get("StatementType")
if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
if parsed_query is not None:
if _compare_query_string(sql=comparable_sql, other=parsed_query):
return _CacheInfo(
has_valid_cache=True,
file_format="parquet",
query_execution_id=query_execution_id,
query_execution_payload=query_info,
)
elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
return _CacheInfo(
has_valid_cache=True,
file_format="csv",
query_execution_id=query_execution_id,
query_execution_payload=query_info,
)
num_executions_inspected += 1
_logger.debug("num_executions_inspected: %s", num_executions_inspected)
if num_executions_inspected >= max_cache_query_inspections:
return _CacheInfo(has_valid_cache=False)
return _CacheInfo(has_valid_cache=False)


def _fetch_parquet_result(
query_metadata: _QueryMetadata,
keep_files: bool,
@@ -1114,6 +979,3 @@ def read_sql_table(
s3_additional_kwargs=s3_additional_kwargs,
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
)


_cache_manager = _LocalMetadataCacheManager()
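
One detail survives the move unchanged: `_cache_manager` is a module-level singleton, so cached metadata persists across calls within a process, and its capacity is tunable through the `max_cache_size` property. A hypothetical tuning snippet (note `_cache` is a private module, so the import path below is internal API):

from awswrangler.athena._cache import _cache_manager

_cache_manager.max_cache_size = 500  # keep metadata for more queries locally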