diff --git a/awswrangler/timestream.py b/awswrangler/timestream.py index 8c51bf2d1..630332f52 100644 --- a/awswrangler/timestream.py +++ b/awswrangler/timestream.py @@ -4,7 +4,7 @@ import itertools import logging from datetime import datetime -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, Iterator, List, Optional, Union, cast import boto3 import pandas as pd @@ -103,6 +103,14 @@ def _process_row(schema: List[Dict[str, str]], row: Dict[str, Any]) -> List[Any] return row_processed +def _rows_to_df(rows: List[List[Any]], schema: List[Dict[str, str]]) -> pd.DataFrame: + df = pd.DataFrame(data=rows, columns=[c["name"] for c in schema]) + for col in schema: + if col["type"] == "VARCHAR": + df[col["name"]] = df[col["name"]].astype("string") + return df + + def _process_schema(page: Dict[str, Any]) -> List[Dict[str, str]]: schema: List[Dict[str, str]] = [] for col in page["ColumnInfo"]: @@ -112,6 +120,29 @@ def _process_schema(page: Dict[str, Any]) -> List[Dict[str, str]]: return schema +def _paginate_query( + sql: str, pagination_config: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session] = None +) -> Iterator[pd.DataFrame]: + client: boto3.client = _utils.client( + service_name="timestream-query", + session=boto3_session, + botocore_config=Config(read_timeout=60, retries={"max_attempts": 10}), + ) + paginator = client.get_paginator("query") + rows: List[List[Any]] = [] + schema: List[Dict[str, str]] = [] + page_iterator = paginator.paginate(QueryString=sql, PaginationConfig=pagination_config or {}) + for page in page_iterator: + if not schema: + schema = _process_schema(page=page) + _logger.debug("schema: %s", schema) + for row in page["Rows"]: + rows.append(_process_row(schema=schema, row=row)) + if len(rows) > 0: + yield _rows_to_df(rows, schema) + rows = [] + + def write( df: pd.DataFrame, database: str, @@ -200,14 +231,19 @@ def write( def query( - sql: str, pagination_config: Optional[Dict[str, Any]] = None, boto3_session: Optional[boto3.Session] = None -) -> pd.DataFrame: + sql: str, + chunked: bool = False, + pagination_config: Optional[Dict[str, Any]] = None, + boto3_session: Optional[boto3.Session] = None, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: """Run a query and retrieve the result as a Pandas DataFrame. Parameters ---------- sql: str SQL query. + chunked: bool + If True returns dataframe iterator, and a single dataframe otherwise. False by default. pagination_config: Dict[str, Any], optional Pagination configuration dictionary of a form {'MaxItems': 10, 'PageSize': 10, 'StartingToken': '...'} boto3_session : boto3.Session(), optional @@ -220,31 +256,16 @@ def query( Examples -------- - Running a query and storing the result as a Pandas DataFrame + Run a query and return the result as a Pandas DataFrame or an iterable. >>> import awswrangler as wr >>> df = wr.timestream.query('SELECT * FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 10') """ - client: boto3.client = _utils.client( - service_name="timestream-query", - session=boto3_session, - botocore_config=Config(read_timeout=60, retries={"max_attempts": 10}), - ) - paginator = client.get_paginator("query") - rows: List[List[Any]] = [] - schema: List[Dict[str, str]] = [] - for page in paginator.paginate(QueryString=sql, PaginationConfig=pagination_config or {}): - if not schema: - schema = _process_schema(page=page) - for row in page["Rows"]: - rows.append(_process_row(schema=schema, row=row)) - _logger.debug("schema: %s", schema) - df = pd.DataFrame(data=rows, columns=[c["name"] for c in schema]) - for col in schema: - if col["type"] == "VARCHAR": - df[col["name"]] = df[col["name"]].astype("string") - return df + result_iterator = _paginate_query(sql, pagination_config, boto3_session) + if chunked: + return result_iterator + return pd.concat(result_iterator, ignore_index=True) def create_database( diff --git a/tests/test_timestream.py b/tests/test_timestream.py index 0e4f26fd1..2c9a01132 100644 --- a/tests/test_timestream.py +++ b/tests/test_timestream.py @@ -49,6 +49,41 @@ def test_basic_scenario(timestream_database_and_table, pagination): assert df.shape == (3, 8) +def test_chunked_scenario(timestream_database_and_table): + df = pd.DataFrame( + { + "time": [datetime.now() for _ in range(5)], + "dim0": ["foo", "boo", "bar", "fizz", "buzz"], + "dim1": [1, 2, 3, 4, 5], + "measure": [1.0, 1.1, 1.2, 1.3, 1.4], + } + ) + rejected_records = wr.timestream.write( + df=df, + database=timestream_database_and_table, + table=timestream_database_and_table, + time_col="time", + measure_col="measure", + dimensions_cols=["dim0", "dim1"], + ) + assert len(rejected_records) == 0 + shapes = [(3, 5), (2, 5)] + for df, shape in zip( + wr.timestream.query( + f""" + SELECT + * + FROM "{timestream_database_and_table}"."{timestream_database_and_table}" + ORDER BY time ASC + """, + chunked=True, + pagination_config={"MaxItems": 5, "PageSize": 3}, + ), + shapes, + ): + assert df.shape == shape + + def test_versioned(timestream_database_and_table): name = timestream_database_and_table time = [datetime.now(), datetime.now(), datetime.now()]