From d307e275f85e41be5b7daf6a619d5bba0640fd6e Mon Sep 17 00:00:00 2001
From: Khue Ngoc Dang
Date: Sat, 6 Aug 2022 07:46:54 +0700
Subject: [PATCH 1/2] Add get_query_results function to Athena module

---
 awswrangler/athena/__init__.py |  3 +-
 awswrangler/athena/_read.py    | 99 ++++++++++++++++++++++++++++++++++
 2 files changed, 101 insertions(+), 1 deletion(-)

diff --git a/awswrangler/athena/__init__.py b/awswrangler/athena/__init__.py
index 4273b4fa5..7bbbb8cf1 100644
--- a/awswrangler/athena/__init__.py
+++ b/awswrangler/athena/__init__.py
@@ -1,6 +1,6 @@
 """Amazon Athena Module."""
 
-from awswrangler.athena._read import read_sql_query, read_sql_table, unload  # noqa
+from awswrangler.athena._read import get_query_results, read_sql_query, read_sql_table, unload  # noqa
 from awswrangler.athena._utils import (  # noqa
     create_athena_bucket,
     create_ctas_table,
@@ -23,6 +23,7 @@
     "describe_table",
     "get_query_columns_types",
     "get_query_execution",
+    "get_query_results",
     "get_named_query_statement",
     "get_work_group",
     "repair_table",
diff --git a/awswrangler/athena/_read.py b/awswrangler/athena/_read.py
index 850ba0392..cc486f0c0 100644
--- a/awswrangler/athena/_read.py
+++ b/awswrangler/athena/_read.py
@@ -559,6 +559,105 @@ def _unload(
     return query_metadata
 
 
+@apply_configs
+def get_query_results(
+    query_execution_id: str,
+    use_threads: Union[bool, int] = True,
+    boto3_session: Optional[boto3.Session] = None,
+    categories: Optional[List[str]] = None,
+    chunksize: Optional[Union[int, bool]] = None,
+    s3_additional_kwargs: Optional[Dict[str, Any]] = None,
+    pyarrow_additional_kwargs: Optional[Dict[str, Any]] = None,
+) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
+    """Get AWS Athena SQL query results as a Pandas DataFrame.
+
+    Parameters
+    ----------
+    query_execution_id : str
+        SQL query's execution_id on AWS Athena.
+    use_threads : bool, int
+        True to enable concurrent requests, False to disable multiple threads.
+        If enabled, os.cpu_count() is used as the maximum number of threads.
+        If an integer is provided, that number is used instead.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session is used if boto3_session receives None.
+    categories : List[str], optional
+        List of column names that should be returned as pandas.Categorical.
+        Recommended for memory-restricted environments.
+    chunksize : Union[int, bool], optional
+        If passed, the data is split into an Iterable of DataFrames (memory friendly).
+        If `True`, Wrangler iterates on the data by files in the most efficient way, without guaranteeing the chunksize.
+        If an `INTEGER` is passed, Wrangler iterates on the data by number of rows equal to the received INTEGER.
+    s3_additional_kwargs : Optional[Dict[str, Any]]
+        Forwarded to botocore requests.
+        e.g. s3_additional_kwargs={'RequestPayer': 'requester'}
+    pyarrow_additional_kwargs : Optional[Dict[str, Any]]
+        Forwarded to the ParquetFile class or when converting an Arrow table to Pandas; currently only the
+        "coerce_int96_timestamp_unit" and "timestamp_as_object" arguments are considered. If reading parquet
+        files where you cannot convert a timestamp to pandas Timestamp[ns], consider setting timestamp_as_object=True
+        to allow for timestamp units larger than "ns". If reading parquet data that still uses INT96 (like Athena
+        outputs), you can use coerce_int96_timestamp_unit to specify what timestamp unit to encode INT96 to (by
+        default this is "ns"); if you know the output parquet came from a system that encodes timestamps to a
+        particular unit, set this to that same unit, e.g. coerce_int96_timestamp_unit="ms".
+
+    Returns
+    -------
+    Union[pd.DataFrame, Iterator[pd.DataFrame]]
+        Pandas DataFrame or Generator of Pandas DataFrames if chunksize is passed.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> res = wr.athena.get_query_results(
+    ...     query_execution_id="cbae5b41-8103-4709-95bb-887f88edd4f2"
+    ... )
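+
+    Reading the same results in chunks of 1,000 rows (the execution id is illustrative):
+
+    >>> dfs = wr.athena.get_query_results(
+    ...     query_execution_id="cbae5b41-8103-4709-95bb-887f88edd4f2",
+    ...     chunksize=1000,
+    ... )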
+
+    """
+    query_metadata: _QueryMetadata = _get_query_metadata(
+        query_execution_id=query_execution_id,
+        boto3_session=boto3_session,
+        categories=categories,
+        metadata_cache_manager=_cache_manager,
+    )
+    client_athena: boto3.client = _utils.client(service_name="athena", session=boto3_session)
+    query_info: Dict[str, Any] = client_athena.get_query_execution(QueryExecutionId=query_execution_id)[
+        "QueryExecution"
+    ]
+    statement_type: Optional[str] = query_info.get("StatementType")
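+    # CTAS ("CREATE TABLE ..." DDL) and UNLOAD statements store their results
+    # as Parquet; any other non-INSERT DML statement stores CSV results.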
+    if (statement_type == "DDL" and query_info["Query"].startswith("CREATE TABLE")) or (
+        statement_type == "DML" and query_info["Query"].startswith("UNLOAD")
+    ):
+        return _fetch_parquet_result(
+            query_metadata=query_metadata,
+            keep_files=True,
+            categories=categories,
+            chunksize=chunksize,
+            use_threads=use_threads,
+            boto3_session=boto3_session,
+            s3_additional_kwargs=s3_additional_kwargs,
+            pyarrow_additional_kwargs=pyarrow_additional_kwargs,
+        )
+    if statement_type == "DML" and not query_info["Query"].startswith("INSERT"):
+        return _fetch_csv_result(
+            query_metadata=query_metadata,
+            keep_files=True,
+            chunksize=chunksize,
+            use_threads=use_threads,
+            boto3_session=boto3_session,
+            s3_additional_kwargs=s3_additional_kwargs,
+        )
+    raise exceptions.UndetectedType(f"""Unable to get results for: {query_info["Query"]}.""")
+
+
 @apply_configs
 def read_sql_query(
     sql: str,

From 914ce24b91d0ef03c728999fbfe7f15a708ec61b Mon Sep 17 00:00:00 2001
From: Khue Ngoc Dang
Date: Sat, 6 Aug 2022 07:47:24 +0700
Subject: [PATCH 2/2] Add test_get_query_results to test_athena

---
 tests/test_athena.py | 46 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/tests/test_athena.py b/tests/test_athena.py
index d084924e9..4b8a42144 100644
--- a/tests/test_athena.py
+++ b/tests/test_athena.py
@@ -1097,3 +1097,49 @@ def test_start_query_execution_wait(path, glue_database, glue_table):
     assert query_execution_result["Query"] == sql
     assert query_execution_result["StatementType"] == "DML"
     assert query_execution_result["QueryExecutionContext"]["Database"] == glue_database
+
+
+def test_get_query_results(path, glue_table, glue_database):
+
+    sql = (
+        "SELECT CAST("
+        "    ROW(1, ROW(2, ROW(3, '4'))) AS"
+        "    ROW(field0 BIGINT, field1 ROW(field2 BIGINT, field3 ROW(field4 BIGINT, field5 VARCHAR)))"
+        ") AS col0"
+    )
+
+    df_ctas: pd.DataFrame = wr.athena.read_sql_query(
+        sql=sql, database=glue_database, ctas_approach=True, unload_approach=False
+    )
+    query_id_ctas = df_ctas.query_metadata["QueryExecutionId"]
+    df_get_query_results_ctas = wr.athena.get_query_results(query_execution_id=query_id_ctas)
+    pd.testing.assert_frame_equal(df_get_query_results_ctas, df_ctas)
+
+    df_unload: pd.DataFrame = wr.athena.read_sql_query(
+        sql=sql, database=glue_database, ctas_approach=False, unload_approach=True, s3_output=path
+    )
+    query_id_unload = df_unload.query_metadata["QueryExecutionId"]
+    df_get_query_results_df_unload = wr.athena.get_query_results(query_execution_id=query_id_unload)
+    pd.testing.assert_frame_equal(df_get_query_results_df_unload, df_unload)
+
+    wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
+    wr.s3.to_parquet(
+        df=get_df(),
+        path=path,
+        index=True,
+        use_threads=True,
+        dataset=True,
+        mode="overwrite",
+        database=glue_database,
+        table=glue_table,
+        partition_cols=["par0", "par1"],
+    )
+
+    reg_sql = f"SELECT * FROM {glue_table}"
+
+    df_regular: pd.DataFrame = wr.athena.read_sql_query(
+        sql=reg_sql, database=glue_database, ctas_approach=False, unload_approach=False
+    )
+    query_id_regular = df_regular.query_metadata["QueryExecutionId"]
+    df_get_query_results_df_regular = wr.athena.get_query_results(query_execution_id=query_id_regular)
+    pd.testing.assert_frame_equal(df_get_query_results_df_regular, df_regular)