Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 43 additions & 23 deletions awswrangler/timestream.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def _write_batch(
database: str,
table: str,
cols_names: List[str],
measure_type: str,
measure_cols_names: List[str],
measure_types: List[str],
version: int,
batch: List[Any],
boto3_primitives: _utils.Boto3PrimitivesType,
Expand All @@ -43,27 +44,41 @@ def _write_batch(
botocore_config=Config(read_timeout=20, max_pool_connections=5000, retries={"max_attempts": 10}),
)
try:
time_loc = 0
measure_cols_loc = 1
dimensions_cols_loc = 1 + len(measure_cols_names)
records: List[Dict[str, Any]] = []
for rec in batch:
record: Dict[str, Any] = {
"Dimensions": [
{"Name": name, "DimensionValueType": "VARCHAR", "Value": str(value)}
for name, value in zip(cols_names[dimensions_cols_loc:], rec[dimensions_cols_loc:])
],
"Time": str(round(rec[time_loc].timestamp() * 1_000)),
"TimeUnit": "MILLISECONDS",
"Version": version,
}
if len(measure_cols_names) == 1:
record["MeasureName"] = measure_cols_names[0]
record["MeasureValueType"] = measure_types[0]
record["MeasureValue"] = str(rec[measure_cols_loc])
else:
record["MeasureName"] = measure_cols_names[0]
record["MeasureValueType"] = "MULTI"
record["MeasureValues"] = [
{"Name": measure_name, "Value": str(measure_value), "Type": measure_value_type}
for measure_name, measure_value, measure_value_type in zip(
measure_cols_names, rec[measure_cols_loc:dimensions_cols_loc], measure_types
)
]
records.append(record)
_utils.try_it(
f=client.write_records,
ex=(client.exceptions.ThrottlingException, client.exceptions.InternalServerException),
max_num_tries=5,
DatabaseName=database,
TableName=table,
Records=[
{
"Dimensions": [
{"Name": name, "DimensionValueType": "VARCHAR", "Value": str(value)}
for name, value in zip(cols_names[2:], rec[2:])
],
"MeasureName": cols_names[1],
"MeasureValueType": measure_type,
"MeasureValue": str(rec[1]),
"Time": str(round(rec[0].timestamp() * 1_000)),
"TimeUnit": "MILLISECONDS",
"Version": version,
}
for rec in batch
],
Records=records,
)
except client.exceptions.RejectedRecordsException as ex:
return cast(List[Dict[str, str]], ex.response["RejectedRecords"])
Expand Down Expand Up @@ -148,7 +163,7 @@ def write(
database: str,
table: str,
time_col: str,
measure_col: str,
measure_col: Union[str, List[str]],
dimensions_cols: List[str],
version: int = 1,
num_threads: int = 32,
Expand All @@ -166,8 +181,8 @@ def write(
Amazon Timestream table name.
time_col : str
DataFrame column name to be used as time. MUST be a timestamp column.
measure_col : str
DataFrame column name to be used as measure.
measure_col : Union[str, List[str]]
DataFrame column name(s) to be used as measure.
dimensions_cols : List[str]
List of DataFrame column names to be used as dimensions.
version : int
Expand Down Expand Up @@ -208,9 +223,13 @@ def write(
>>> assert len(rejected_records) == 0

"""
measure_type: str = _data_types.timestream_type_from_pandas(df[[measure_col]])
_logger.debug("measure_type: %s", measure_type)
cols_names: List[str] = [time_col, measure_col] + dimensions_cols
measure_cols_names: List[str] = measure_col if isinstance(measure_col, list) else [measure_col]
_logger.debug("measure_cols_names: %s", measure_cols_names)
measure_types: List[str] = [
_data_types.timestream_type_from_pandas(df[[measure_col_name]]) for measure_col_name in measure_cols_names
]
_logger.debug("measure_types: %s", measure_types)
cols_names: List[str] = [time_col] + measure_cols_names + dimensions_cols
_logger.debug("cols_names: %s", cols_names)
batches: List[List[Any]] = _utils.chunkify(lst=_df2list(df=df[cols_names]), max_length=100)
_logger.debug("len(batches): %s", len(batches))
Expand All @@ -221,7 +240,8 @@ def write(
itertools.repeat(database),
itertools.repeat(table),
itertools.repeat(cols_names),
itertools.repeat(measure_type),
itertools.repeat(measure_cols_names),
itertools.repeat(measure_types),
itertools.repeat(version),
batches,
itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)),
Expand Down
31 changes: 31 additions & 0 deletions tests/test_timestream.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,34 @@ def test_real_csv_load_scenario(timestream_database_and_table):
assert len(rejected_records) == 0
df = wr.timestream.query(f'SELECT COUNT(*) AS counter FROM "{name}"."{name}"')
assert df["counter"].iloc[0] == 126_000


def test_multimeasure_scenario(timestream_database_and_table):
df = pd.DataFrame(
{
"time": [datetime.now(), datetime.now(), datetime.now()],
"dim0": ["foo", "boo", "bar"],
"dim1": [1, 2, 3],
"measure1": [1.0, 1.1, 1.2],
"measure2": [2.0, 2.1, 2.2],
}
)
rejected_records = wr.timestream.write(
df=df,
database=timestream_database_and_table,
table=timestream_database_and_table,
time_col="time",
measure_col=["measure1", "measure2"],
dimensions_cols=["dim0", "dim1"],
)
assert len(rejected_records) == 0
df = wr.timestream.query(
f"""
SELECT
*
FROM "{timestream_database_and_table}"."{timestream_database_and_table}"
ORDER BY time
DESC LIMIT 10
""",
)
assert df.shape == (3, 6)