diff --git a/awswrangler/timestream.py b/awswrangler/timestream.py index 660b7edb5..fc89b80f7 100644 --- a/awswrangler/timestream.py +++ b/awswrangler/timestream.py @@ -31,7 +31,8 @@ def _write_batch( database: str, table: str, cols_names: List[str], - measure_type: str, + measure_cols_names: List[str], + measure_types: List[str], version: int, batch: List[Any], boto3_primitives: _utils.Boto3PrimitivesType, @@ -43,27 +44,41 @@ def _write_batch( botocore_config=Config(read_timeout=20, max_pool_connections=5000, retries={"max_attempts": 10}), ) try: + time_loc = 0 + measure_cols_loc = 1 + dimensions_cols_loc = 1 + len(measure_cols_names) + records: List[Dict[str, Any]] = [] + for rec in batch: + record: Dict[str, Any] = { + "Dimensions": [ + {"Name": name, "DimensionValueType": "VARCHAR", "Value": str(value)} + for name, value in zip(cols_names[dimensions_cols_loc:], rec[dimensions_cols_loc:]) + ], + "Time": str(round(rec[time_loc].timestamp() * 1_000)), + "TimeUnit": "MILLISECONDS", + "Version": version, + } + if len(measure_cols_names) == 1: + record["MeasureName"] = measure_cols_names[0] + record["MeasureValueType"] = measure_types[0] + record["MeasureValue"] = str(rec[measure_cols_loc]) + else: + record["MeasureName"] = measure_cols_names[0] + record["MeasureValueType"] = "MULTI" + record["MeasureValues"] = [ + {"Name": measure_name, "Value": str(measure_value), "Type": measure_value_type} + for measure_name, measure_value, measure_value_type in zip( + measure_cols_names, rec[measure_cols_loc:dimensions_cols_loc], measure_types + ) + ] + records.append(record) _utils.try_it( f=client.write_records, ex=(client.exceptions.ThrottlingException, client.exceptions.InternalServerException), max_num_tries=5, DatabaseName=database, TableName=table, - Records=[ - { - "Dimensions": [ - {"Name": name, "DimensionValueType": "VARCHAR", "Value": str(value)} - for name, value in zip(cols_names[2:], rec[2:]) - ], - "MeasureName": cols_names[1], - "MeasureValueType": measure_type, - "MeasureValue": str(rec[1]), - "Time": str(round(rec[0].timestamp() * 1_000)), - "TimeUnit": "MILLISECONDS", - "Version": version, - } - for rec in batch - ], + Records=records, ) except client.exceptions.RejectedRecordsException as ex: return cast(List[Dict[str, str]], ex.response["RejectedRecords"]) @@ -148,7 +163,7 @@ def write( database: str, table: str, time_col: str, - measure_col: str, + measure_col: Union[str, List[str]], dimensions_cols: List[str], version: int = 1, num_threads: int = 32, @@ -166,8 +181,8 @@ def write( Amazon Timestream table name. time_col : str DataFrame column name to be used as time. MUST be a timestamp column. - measure_col : str - DataFrame column name to be used as measure. + measure_col : Union[str, List[str]] + DataFrame column name(s) to be used as measure. dimensions_cols : List[str] List of DataFrame column names to be used as dimensions. version : int @@ -208,9 +223,13 @@ def write( >>> assert len(rejected_records) == 0 """ - measure_type: str = _data_types.timestream_type_from_pandas(df[[measure_col]]) - _logger.debug("measure_type: %s", measure_type) - cols_names: List[str] = [time_col, measure_col] + dimensions_cols + measure_cols_names: List[str] = measure_col if isinstance(measure_col, list) else [measure_col] + _logger.debug("measure_cols_names: %s", measure_cols_names) + measure_types: List[str] = [ + _data_types.timestream_type_from_pandas(df[[measure_col_name]]) for measure_col_name in measure_cols_names + ] + _logger.debug("measure_types: %s", measure_types) + cols_names: List[str] = [time_col] + measure_cols_names + dimensions_cols _logger.debug("cols_names: %s", cols_names) batches: List[List[Any]] = _utils.chunkify(lst=_df2list(df=df[cols_names]), max_length=100) _logger.debug("len(batches): %s", len(batches)) @@ -221,7 +240,8 @@ def write( itertools.repeat(database), itertools.repeat(table), itertools.repeat(cols_names), - itertools.repeat(measure_type), + itertools.repeat(measure_cols_names), + itertools.repeat(measure_types), itertools.repeat(version), batches, itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)), diff --git a/tests/test_timestream.py b/tests/test_timestream.py index 2c9a01132..a2d35de6e 100644 --- a/tests/test_timestream.py +++ b/tests/test_timestream.py @@ -180,3 +180,34 @@ def test_real_csv_load_scenario(timestream_database_and_table): assert len(rejected_records) == 0 df = wr.timestream.query(f'SELECT COUNT(*) AS counter FROM "{name}"."{name}"') assert df["counter"].iloc[0] == 126_000 + + +def test_multimeasure_scenario(timestream_database_and_table): + df = pd.DataFrame( + { + "time": [datetime.now(), datetime.now(), datetime.now()], + "dim0": ["foo", "boo", "bar"], + "dim1": [1, 2, 3], + "measure1": [1.0, 1.1, 1.2], + "measure2": [2.0, 2.1, 2.2], + } + ) + rejected_records = wr.timestream.write( + df=df, + database=timestream_database_and_table, + table=timestream_database_and_table, + time_col="time", + measure_col=["measure1", "measure2"], + dimensions_cols=["dim0", "dim1"], + ) + assert len(rejected_records) == 0 + df = wr.timestream.query( + f""" + SELECT + * + FROM "{timestream_database_and_table}"."{timestream_database_and_table}" + ORDER BY time + DESC LIMIT 10 + """, + ) + assert df.shape == (3, 6)