# Patch summary
#   awswrangler/timestream.py
#     - `write` gains a `version: int = 1` parameter, documented as the
#       version number used for upserts.
#     - `write` forwards `version` to every `_write_batch` call through
#       `itertools.repeat(version)`.
#     - `_write_batch` gains a `version: int` parameter and sets it as the
#       "Version" field of each record it builds.
#   tests/test_timestream.py
#     - adds `test_versioned`, which writes three DataFrames with versions
#       1, 1, 2 and expects 0, 1, 0 rejected records respectively.
#
# NOTE(review): line breaks and leading whitespace were reconstructed from a
# whitespace-collapsed copy of this patch. Hunk line counts match the @@
# headers, but indentation (including whether the final query/assert of
# `test_versioned` sits inside the `for` loop) should be confirmed against
# the upstream files before applying.
diff --git a/awswrangler/timestream.py b/awswrangler/timestream.py
index 4bbc337c7..de760baeb 100644
--- a/awswrangler/timestream.py
+++ b/awswrangler/timestream.py
@@ -32,6 +32,7 @@ def _write_batch(
     table: str,
     cols_names: List[str],
     measure_type: str,
+    version: int,
     batch: List[Any],
     boto3_primitives: _utils.Boto3PrimitivesType,
 ) -> List[Dict[str, str]]:
@@ -59,6 +60,7 @@ def _write_batch(
                     "MeasureValue": str(rec[1]),
                     "Time": str(round(rec[0].timestamp() * 1_000)),
                     "TimeUnit": "MILLISECONDS",
+                    "Version": version,
                 }
                 for rec in batch
             ],
@@ -117,6 +119,7 @@ def write(
     time_col: str,
     measure_col: str,
     dimensions_cols: List[str],
+    version: int = 1,
     num_threads: int = 32,
     boto3_session: Optional[boto3.Session] = None,
 ) -> List[Dict[str, str]]:
@@ -136,6 +139,9 @@ def write(
         DataFrame column name to be used as measure.
     dimensions_cols : List[str]
         List of DataFrame column names to be used as dimensions.
+    version : int
+        Version number used for upserts.
+        Documentation https://docs.aws.amazon.com/timestream/latest/developerguide/API_WriteRecords.html.
     num_threads : str
         Number of thread to be used for concurrent writing.
     boto3_session : boto3.Session(), optional
@@ -185,6 +191,7 @@ def write(
                 itertools.repeat(table),
                 itertools.repeat(cols_names),
                 itertools.repeat(measure_type),
+                itertools.repeat(version),
                 batches,
                 itertools.repeat(_utils.boto3_to_primitives(boto3_session=boto3_session)),
             )
diff --git a/tests/test_timestream.py b/tests/test_timestream.py
index 769c6fea2..1d19370dc 100644
--- a/tests/test_timestream.py
+++ b/tests/test_timestream.py
@@ -46,6 +46,59 @@ def test_basic_scenario(timestream_database_and_table):
     assert df.shape == (3, 8)
 
 
+def test_versioned(timestream_database_and_table):
+    name = timestream_database_and_table
+    time = [datetime.now(), datetime.now(), datetime.now()]
+    dfs = [
+        pd.DataFrame(
+            {
+                "time": time,
+                "dim0": ["foo", "boo", "bar"],
+                "dim1": [1, 2, 3],
+                "measure": [1.0, 1.1, 1.2],
+            }
+        ),
+        pd.DataFrame(
+            {
+                "time": time,
+                "dim0": ["foo", "boo", "bar"],
+                "dim1": [1, 2, 3],
+                "measure": [1.0, 1.1, 1.9],
+            }
+        ),
+        pd.DataFrame(
+            {
+                "time": time,
+                "dim0": ["foo", "boo", "bar"],
+                "dim1": [1, 2, 3],
+                "measure": [1.0, 1.1, 1.9],
+            }
+        ),
+    ]
+    versions = [1, 1, 2]
+    rejected_rec_nums = [0, 1, 0]
+    for df, version, rejected_rec_num in zip(dfs, versions, rejected_rec_nums):
+        rejected_records = wr.timestream.write(
+            df=df,
+            database=name,
+            table=name,
+            time_col="time",
+            measure_col="measure",
+            dimensions_cols=["dim0", "dim1"],
+            version=version,
+        )
+        assert len(rejected_records) == rejected_rec_num
+        df_out = wr.timestream.query(
+            f"""
+            SELECT
+                *
+            FROM "{name}"."{name}"
+            DESC LIMIT 10
+            """
+        )
+        assert df_out.shape == (3, 5)
+
+
 def test_real_csv_load_scenario(timestream_database_and_table):
     name = timestream_database_and_table
     df = pd.read_csv(