Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 44 additions & 23 deletions awswrangler/timestream.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, Iterator, List, Optional, Union, cast

import boto3
import pandas as pd
Expand Down Expand Up @@ -103,6 +103,14 @@ def _process_row(schema: List[Dict[str, str]], row: Dict[str, Any]) -> List[Any]
return row_processed


def _rows_to_df(rows: List[List[Any]], schema: List[Dict[str, str]]) -> pd.DataFrame:
df = pd.DataFrame(data=rows, columns=[c["name"] for c in schema])
for col in schema:
if col["type"] == "VARCHAR":
df[col["name"]] = df[col["name"]].astype("string")
return df


def _process_schema(page: Dict[str, Any]) -> List[Dict[str, str]]:
schema: List[Dict[str, str]] = []
for col in page["ColumnInfo"]:
Expand All @@ -112,6 +120,29 @@ def _process_schema(page: Dict[str, Any]) -> List[Dict[str, str]]:
return schema


def _paginate_query(
    sql: str, pagination_config: Optional[Dict[str, Any]], boto3_session: Optional[boto3.Session] = None
) -> Iterator[pd.DataFrame]:
    """Run a query through the ``timestream-query`` paginator and yield one DataFrame per page.

    Parameters
    ----------
    sql : str
        Query string handed to the paginator as ``QueryString``.
    pagination_config : Dict[str, Any], optional
        Passed as ``PaginationConfig``; ``None`` is replaced by an empty dict.
    boto3_session : boto3.Session(), optional
        Session used to build the query client.

    Yields
    ------
    pd.DataFrame
        The rows of a single result page. Pages that carry no rows produce no frame,
        so a query without results yields nothing at all.
    """
    query_client: boto3.client = _utils.client(
        service_name="timestream-query",
        session=boto3_session,
        botocore_config=Config(read_timeout=60, retries={"max_attempts": 10}),
    )
    pages = query_client.get_paginator("query").paginate(
        QueryString=sql, PaginationConfig=pagination_config or {}
    )
    schema: List[Dict[str, str]] = []
    for page in pages:
        # The schema is derived once, from the first page that provides one, then reused.
        if not schema:
            schema = _process_schema(page=page)
            _logger.debug("schema: %s", schema)
        page_rows: List[List[Any]] = [_process_row(schema=schema, row=row) for row in page["Rows"]]
        if not page_rows:
            continue
        yield _rows_to_df(page_rows, schema)


def write(
df: pd.DataFrame,
database: str,
Expand Down Expand Up @@ -200,14 +231,19 @@ def write(


def query(
    sql: str,
    chunked: bool = False,
    pagination_config: Optional[Dict[str, Any]] = None,
    boto3_session: Optional[boto3.Session] = None,
) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
    """Run a query and retrieve the result as a Pandas DataFrame.

    Parameters
    ----------
    sql: str
        SQL query.
    chunked: bool
        If True returns dataframe iterator, and a single dataframe otherwise. False by default.
    pagination_config: Dict[str, Any], optional
        Pagination configuration dictionary of a form {'MaxItems': 10, 'PageSize': 10, 'StartingToken': '...'}
    boto3_session : boto3.Session(), optional
        Boto3 Session forwarded to the query paginator.

    Returns
    -------
    Union[pd.DataFrame, Iterator[pd.DataFrame]]
        With ``chunked=True``, a lazy iterator of DataFrames.
        Otherwise a single DataFrame with all chunks concatenated and a fresh index;
        when the query produces no rows this is an empty DataFrame without columns.

    Examples
    --------
    Run a query and return the result as a Pandas DataFrame or an iterable.

    >>> import awswrangler as wr
    >>> df = wr.timestream.query('SELECT * FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 10')

    >>> dfs = wr.timestream.query('SELECT * FROM "sampleDB"."sampleTable"', chunked=True)
    >>> for df in dfs:
    ...     print(df.shape)

    """
    result_iterator = _paginate_query(sql, pagination_config, boto3_session)
    if chunked:
        return result_iterator
    # pd.concat raises ValueError("No objects to concatenate") on an empty input, so a
    # query without results must be handled explicitly instead of being passed through.
    frames = list(result_iterator)
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)


def create_database(
Expand Down
35 changes: 35 additions & 0 deletions tests/test_timestream.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,41 @@ def test_basic_scenario(timestream_database_and_table, pagination):
assert df.shape == (3, 8)


def test_chunked_scenario(timestream_database_and_table):
    """Write five records, read them back in chunked mode and check every chunk's shape."""
    df = pd.DataFrame(
        {
            "time": [datetime.now() for _ in range(5)],
            "dim0": ["foo", "boo", "bar", "fizz", "buzz"],
            "dim1": [1, 2, 3, 4, 5],
            "measure": [1.0, 1.1, 1.2, 1.3, 1.4],
        }
    )
    rejected_records = wr.timestream.write(
        df=df,
        database=timestream_database_and_table,
        table=timestream_database_and_table,
        time_col="time",
        measure_col="measure",
        dimensions_cols=["dim0", "dim1"],
    )
    assert len(rejected_records) == 0
    expected_shapes = [(3, 5), (2, 5)]
    chunks = wr.timestream.query(
        f"""
        SELECT
            *
        FROM "{timestream_database_and_table}"."{timestream_database_and_table}"
        ORDER BY time ASC
        """,
        chunked=True,
        pagination_config={"MaxItems": 5, "PageSize": 3},
    )
    # Compare the complete list of shapes: zipping the chunks with the expectations
    # would stop at the shorter input and pass even if too few (or no) chunks came back.
    assert [chunk.shape for chunk in chunks] == expected_shapes


def test_versioned(timestream_database_and_table):
name = timestream_database_and_table
time = [datetime.now(), datetime.now(), datetime.now()]
Expand Down