31 changes: 31 additions & 0 deletions awswrangler/_config.py
@@ -29,6 +29,8 @@ class _ConfigArg(NamedTuple):
"database": _ConfigArg(dtype=str, nullable=True),
"max_cache_query_inspections": _ConfigArg(dtype=int, nullable=False),
"max_cache_seconds": _ConfigArg(dtype=int, nullable=False),
"max_remote_cache_entries": _ConfigArg(dtype=int, nullable=False),
"max_local_cache_entries": _ConfigArg(dtype=int, nullable=False),
"s3_block_size": _ConfigArg(dtype=int, nullable=False, enforced=True),
"workgroup": _ConfigArg(dtype=str, nullable=False, enforced=True),
# Endpoints URLs
@@ -226,6 +228,35 @@ def max_cache_seconds(self) -> int:
def max_cache_seconds(self, value: int) -> None:
self._set_config_value(key="max_cache_seconds", value=value)

@property
def max_local_cache_entries(self) -> int:
"""Property max_local_cache_entries."""
return cast(int, self["max_local_cache_entries"])

@max_local_cache_entries.setter
def max_local_cache_entries(self, value: int) -> None:
try:
max_remote_cache_entries = cast(int, self["max_remote_cache_entries"])
except AttributeError:
max_remote_cache_entries = 50
if value < max_remote_cache_entries:
_logger.warning(
"max_remote_cache_entries shouldn't be greater than max_local_cache_entries. "
"Therefore max_remote_cache_entries will be set to %s as well.",
value,
)
self._set_config_value(key="max_remote_cache_entries", value=value)
self._set_config_value(key="max_local_cache_entries", value=value)

@property
def max_remote_cache_entries(self) -> int:
"""Property max_remote_cache_entries."""
return cast(int, self["max_remote_cache_entries"])

@max_remote_cache_entries.setter
def max_remote_cache_entries(self, value: int) -> None:
self._set_config_value(key="max_remote_cache_entries", value=value)

@property
def s3_block_size(self) -> int:
"""Property s3_block_size."""
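For orientation, a minimal usage sketch (not part of the diff, and assuming the usual `wr.config` global configuration object that backs these keys): the two new entries can be tuned alongside the existing caching options.

```python
import awswrangler as wr

# Illustrative values only; result caching itself is controlled by max_cache_seconds.
wr.config.max_cache_seconds = 900          # enable Athena result caching
wr.config.max_remote_cache_entries = 200   # query executions fetched from Athena per cache check
wr.config.max_local_cache_entries = 400    # metadata entries kept in the process-local cache
```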
143 changes: 90 additions & 53 deletions awswrangler/athena/_read.py
@@ -20,6 +20,7 @@
_get_query_metadata,
_get_s3_output,
_get_workgroup_config,
_LocalMetadataCacheManager,
_QueryMetadata,
_start_query_execution,
_WorkGroupConfig,
@@ -96,33 +97,37 @@ def _compare_query_string(sql: str, other: str) -> bool:
return False


def _get_last_query_executions(
boto3_session: Optional[boto3.Session] = None, workgroup: Optional[str] = None
) -> Iterator[List[Dict[str, Any]]]:
def _get_last_query_infos(
max_remote_cache_entries: int,
boto3_session: Optional[boto3.Session] = None,
workgroup: Optional[str] = None,
) -> List[Dict[str, Any]]:
"""Return an iterator of `query_execution_info`s run by the workgroup in Athena."""
client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
args: Dict[str, Union[str, Dict[str, int]]] = {"PaginationConfig": {"MaxItems": 50, "PageSize": 50}}
page_size = 50
args: Dict[str, Union[str, Dict[str, int]]] = {
"PaginationConfig": {"MaxItems": max_remote_cache_entries, "PageSize": page_size}
}
if workgroup is not None:
args["WorkGroup"] = workgroup
paginator = client_athena.get_paginator("list_query_executions")
uncached_ids = []
for page in paginator.paginate(**args):
_logger.debug("paginating Athena's queries history...")
query_execution_id_list: List[str] = page["QueryExecutionIds"]
execution_data = client_athena.batch_get_query_execution(QueryExecutionIds=query_execution_id_list)
yield execution_data.get("QueryExecutions")


def _sort_successful_executions_data(query_executions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Sorts `_get_last_query_executions`'s results based on query Completion DateTime.

This is useful to guarantee LRU caching rules.
"""
filtered: List[Dict[str, Any]] = []
for query in query_executions:
if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
filtered.append(query)
return sorted(filtered, key=lambda e: str(e["Status"]["CompletionDateTime"]), reverse=True)
for query_execution_id in query_execution_id_list:
if query_execution_id not in _cache_manager:
uncached_ids.append(query_execution_id)
if uncached_ids:
new_execution_data = []
for i in range(0, len(uncached_ids), page_size):
new_execution_data.extend(
client_athena.batch_get_query_execution(QueryExecutionIds=uncached_ids[i : i + page_size]).get(
"QueryExecutions"
)
)
_cache_manager.update_cache(new_execution_data)
return _cache_manager.sorted_successful_generator()


def _parse_select_query_from_possible_ctas(possible_ctas: str) -> Optional[str]:
@@ -150,6 +155,7 @@ def _check_for_cached_results(
workgroup: Optional[str],
max_cache_seconds: int,
max_cache_query_inspections: int,
max_remote_cache_entries: int,
) -> _CacheInfo:
"""
Check whether `sql` has been run before, within the `max_cache_seconds` window, by the `workgroup`.
@@ -162,45 +168,41 @@
comparable_sql: str = _prepare_query_string_for_comparison(sql)
current_timestamp: datetime.datetime = datetime.datetime.now(datetime.timezone.utc)
_logger.debug("current_timestamp: %s", current_timestamp)
for query_executions in _get_last_query_executions(boto3_session=boto3_session, workgroup=workgroup):
_logger.debug("len(query_executions): %s", len(query_executions))
cached_queries: List[Dict[str, Any]] = _sort_successful_executions_data(query_executions=query_executions)
_logger.debug("len(cached_queries): %s", len(cached_queries))
for query_info in cached_queries:
query_execution_id: str = query_info["QueryExecutionId"]
query_timestamp: datetime.datetime = query_info["Status"]["CompletionDateTime"]
_logger.debug("query_timestamp: %s", query_timestamp)

if (current_timestamp - query_timestamp).total_seconds() > max_cache_seconds:
return _CacheInfo(
has_valid_cache=False, query_execution_id=query_execution_id, query_execution_payload=query_info
)

statement_type: Optional[str] = query_info.get("StatementType")
if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
if parsed_query is not None:
if _compare_query_string(sql=comparable_sql, other=parsed_query):
return _CacheInfo(
has_valid_cache=True,
file_format="parquet",
query_execution_id=query_execution_id,
query_execution_payload=query_info,
)
elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
for query_info in _get_last_query_infos(
max_remote_cache_entries=max_remote_cache_entries,
boto3_session=boto3_session,
workgroup=workgroup,
):
query_execution_id: str = query_info["QueryExecutionId"]
query_timestamp: datetime.datetime = query_info["Status"]["CompletionDateTime"]
_logger.debug("query_timestamp: %s", query_timestamp)
if (current_timestamp - query_timestamp).total_seconds() > max_cache_seconds:
return _CacheInfo(
has_valid_cache=False, query_execution_id=query_execution_id, query_execution_payload=query_info
)
statement_type: Optional[str] = query_info.get("StatementType")
if statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE"):
parsed_query: Optional[str] = _parse_select_query_from_possible_ctas(possible_ctas=query_info["Query"])
if parsed_query is not None:
if _compare_query_string(sql=comparable_sql, other=parsed_query):
return _CacheInfo(
has_valid_cache=True,
file_format="csv",
file_format="parquet",
query_execution_id=query_execution_id,
query_execution_payload=query_info,
)

num_executions_inspected += 1
_logger.debug("num_executions_inspected: %s", num_executions_inspected)
if num_executions_inspected >= max_cache_query_inspections:
return _CacheInfo(has_valid_cache=False)

elif statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
if _compare_query_string(sql=comparable_sql, other=query_info["Query"]):
return _CacheInfo(
has_valid_cache=True,
file_format="csv",
query_execution_id=query_execution_id,
query_execution_payload=query_info,
)
num_executions_inspected += 1
_logger.debug("num_executions_inspected: %s", num_executions_inspected)
if num_executions_inspected >= max_cache_query_inspections:
return _CacheInfo(has_valid_cache=False)
return _CacheInfo(has_valid_cache=False)


@@ -302,6 +304,7 @@ def _resolve_query_with_cache(
boto3_session=session,
categories=categories,
query_execution_payload=cache_info.query_execution_payload,
metadata_cache_manager=_cache_manager,
)
if cache_info.file_format == "parquet":
return _fetch_parquet_result(
@@ -380,6 +383,7 @@ def _resolve_query_without_cache_ctas(
query_execution_id=query_id,
boto3_session=boto3_session,
categories=categories,
metadata_cache_manager=_cache_manager,
)
except exceptions.QueryFailed as ex:
msg: str = str(ex)
@@ -439,6 +443,7 @@ def _resolve_query_without_cache_regular(
query_execution_id=query_id,
boto3_session=boto3_session,
categories=categories,
metadata_cache_manager=_cache_manager,
)
return _fetch_csv_result(
query_metadata=query_metadata,
@@ -532,6 +537,8 @@ def read_sql_query(
boto3_session: Optional[boto3.Session] = None,
max_cache_seconds: int = 0,
max_cache_query_inspections: int = 50,
max_remote_cache_entries: int = 50,
max_local_cache_entries: int = 100,
data_source: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
@@ -678,6 +685,15 @@ def read_sql_query(
Max number of queries that will be inspected from the history to try to find some result to reuse.
The bigger the number of inspections, the higher the latency for queries without a cached result.
Only takes effect if max_cache_seconds > 0.
max_remote_cache_entries : int
Max number of query executions that will be retrieved from AWS for cache inspection.
The bigger this number, the higher the latency for queries without a cached result.
Only takes effect if max_cache_seconds > 0. Defaults to 50.
max_local_cache_entries : int
Max number of queries for which metadata is cached locally. This reduces latency and also allows
keeping more than `max_remote_cache_entries` entries available for cache inspection. This value
should not be smaller than max_remote_cache_entries.
Only takes effect if max_cache_seconds > 0. Defaults to 100.
data_source : str, optional
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
params: Dict[str, any], optional
@@ -718,12 +734,17 @@ def read_sql_query(
for key, value in params.items():
sql = sql.replace(f":{key};", str(value))

if max_remote_cache_entries > max_local_cache_entries:
max_remote_cache_entries = max_local_cache_entries

_cache_manager.max_cache_size = max_local_cache_entries
cache_info: _CacheInfo = _check_for_cached_results(
sql=sql,
boto3_session=session,
workgroup=workgroup,
max_cache_seconds=max_cache_seconds,
max_cache_query_inspections=max_cache_query_inspections,
max_remote_cache_entries=max_remote_cache_entries,
)
_logger.debug("cache_info:\n%s", cache_info)
if cache_info.has_valid_cache is True:
@@ -774,6 +795,8 @@ def read_sql_table(
boto3_session: Optional[boto3.Session] = None,
max_cache_seconds: int = 0,
max_cache_query_inspections: int = 50,
max_remote_cache_entries: int = 50,
max_local_cache_entries: int = 100,
data_source: Optional[str] = None,
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
"""Extract the full table AWS Athena and return the results as a Pandas DataFrame.
@@ -914,6 +937,15 @@ def read_sql_table(
Max number of queries that will be inspected from the history to try to find some result to reuse.
The bigger the number of inspections, the higher the latency for queries without a cached result.
Only takes effect if max_cache_seconds > 0.
max_remote_cache_entries : int
Max number of query executions that will be retrieved from AWS for cache inspection.
The bigger this number, the higher the latency for queries without a cached result.
Only takes effect if max_cache_seconds > 0. Defaults to 50.
max_local_cache_entries : int
Max number of queries for which metadata is cached locally. This reduces latency and also allows
keeping more than `max_remote_cache_entries` entries available for cache inspection. This value
should not be smaller than max_remote_cache_entries.
Only takes effect if max_cache_seconds > 0. Defaults to 100.
data_source : str, optional
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.

@@ -947,4 +979,9 @@ def read_sql_table(
boto3_session=boto3_session,
max_cache_seconds=max_cache_seconds,
max_cache_query_inspections=max_cache_query_inspections,
max_remote_cache_entries=max_remote_cache_entries,
max_local_cache_entries=max_local_cache_entries,
)


_cache_manager = _LocalMetadataCacheManager()
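As a hedged usage sketch (not part of the diff), the new parameters would be passed to `wr.athena.read_sql_query` next to the existing caching knobs; the table and database names below are hypothetical and the values are examples rather than defaults:

```python
import awswrangler as wr

df = wr.athena.read_sql_query(
    sql="SELECT col, value FROM my_table",   # hypothetical query
    database="my_database",                  # hypothetical database
    max_cache_seconds=900,                   # caching must be enabled for the new knobs to matter
    max_cache_query_inspections=50,
    max_remote_cache_entries=50,             # executions listed/fetched from Athena per cache check
    max_local_cache_entries=100,             # metadata entries kept by _LocalMetadataCacheManager
)
```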
72 changes: 71 additions & 1 deletion awswrangler/athena/_utils.py
@@ -1,11 +1,13 @@
"""Utilities Module for Amazon Athena."""
import csv
import datetime
import logging
import pprint
import time
import warnings
from decimal import Decimal
from typing import Any, Dict, Generator, List, NamedTuple, Optional, Union, cast
from heapq import heappop, heappush
from typing import Any, Dict, Generator, List, NamedTuple, Optional, Tuple, Union, cast

import boto3
import botocore.exceptions
@@ -39,6 +41,71 @@ class _WorkGroupConfig(NamedTuple):
kms_key: Optional[str]


class _LocalMetadataCacheManager:
def __init__(self) -> None:
self._cache: Dict[str, Any] = dict()
self._pqueue: List[Tuple[datetime.datetime, str]] = []
self._max_cache_size = 100

def update_cache(self, items: List[Dict[str, Any]]) -> None:
"""
Update the local metadata cache with new query metadata.

Parameters
----------
items : List[Dict[str, Any]]
List of query execution metadata which is returned by boto3 `batch_get_query_execution()`.

Returns
-------
None
None.
"""
if self._pqueue:
oldest_item = self._cache[self._pqueue[0][1]]
items = list(
filter(lambda x: x["Status"]["SubmissionDateTime"] > oldest_item["Status"]["SubmissionDateTime"], items)
)

cache_oversize = len(self._cache) + len(items) - self._max_cache_size
for _ in range(cache_oversize):
_, query_execution_id = heappop(self._pqueue)
del self._cache[query_execution_id]

for item in items[: self._max_cache_size]:
heappush(self._pqueue, (item["Status"]["SubmissionDateTime"], item["QueryExecutionId"]))
self._cache[item["QueryExecutionId"]] = item

def sorted_successful_generator(self) -> List[Dict[str, Any]]:
"""
Sorts the entries in the local cache based on query Completion DateTime.

This is useful to guarantee LRU caching rules.

Returns
-------
List[Dict[str, Any]]
Returns successful DDL and DML queries sorted by query completion time.
"""
filtered: List[Dict[str, Any]] = []
for query in self._cache.values():
if (query["Status"].get("State") == "SUCCEEDED") and (query.get("StatementType") in ["DDL", "DML"]):
filtered.append(query)
return sorted(filtered, key=lambda e: str(e["Status"]["CompletionDateTime"]), reverse=True)

def __contains__(self, key: str) -> bool:
return key in self._cache

@property
def max_cache_size(self) -> int:
"""Property max_cache_size."""
return self._max_cache_size

@max_cache_size.setter
def max_cache_size(self, value: int) -> None:
self._max_cache_size = value


def _get_s3_output(s3_output: Optional[str], wg_config: _WorkGroupConfig, boto3_session: boto3.Session) -> str:
if wg_config.enforced and wg_config.s3_output is not None:
return wg_config.s3_output
@@ -171,6 +238,7 @@ def _get_query_metadata(  # pylint: disable=too-many-statements
boto3_session: boto3.Session,
categories: Optional[List[str]] = None,
query_execution_payload: Optional[Dict[str, Any]] = None,
metadata_cache_manager: Optional[_LocalMetadataCacheManager] = None,
) -> _QueryMetadata:
"""Get query metadata."""
if (query_execution_payload is not None) and (query_execution_payload["Status"]["State"] in _QUERY_FINAL_STATES):
@@ -224,6 +292,8 @@ def _get_query_metadata(  # pylint: disable=too-many-statements
athena_statistics: Dict[str, Union[int, str]] = _query_execution_payload.get("Statistics", {})
manifest_location: Optional[str] = str(athena_statistics.get("DataManifestLocation"))

if metadata_cache_manager is not None and query_execution_id not in metadata_cache_manager:
metadata_cache_manager.update_cache(items=[_query_execution_payload])
query_metadata: _QueryMetadata = _QueryMetadata(
execution_id=query_execution_id,
dtype=dtype,
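To make the eviction strategy of `_LocalMetadataCacheManager` concrete, here is a small self-contained sketch (not from the diff): entries are tracked in a min-heap keyed by `SubmissionDateTime`, so when the cache exceeds its capacity the oldest query executions are dropped first.

```python
import datetime
from heapq import heappop, heappush

cache: dict = {}   # QueryExecutionId -> metadata
pqueue: list = []  # min-heap of (SubmissionDateTime, QueryExecutionId)
MAX_SIZE = 2

def put(execution_id: str, submitted: datetime.datetime) -> None:
    # Evict the entry with the oldest submission time once the cache is full.
    if len(cache) >= MAX_SIZE:
        _, oldest_id = heappop(pqueue)
        del cache[oldest_id]
    heappush(pqueue, (submitted, execution_id))
    cache[execution_id] = {"QueryExecutionId": execution_id, "SubmissionDateTime": submitted}

now = datetime.datetime.now(datetime.timezone.utc)
put("q1", now - datetime.timedelta(minutes=2))
put("q2", now - datetime.timedelta(minutes=1))
put("q3", now)  # "q1" (oldest submission) is evicted
assert "q1" not in cache and set(cache) == {"q2", "q3"}
```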