From 1a377221c911d25540626113a9e9c26a70d1d5d5 Mon Sep 17 00:00:00 2001
From: igorborgest
Date: Thu, 7 May 2020 23:38:30 -0300
Subject: [PATCH] Add support for uint8, uint16, uint32 and uint64. #76

---
 awswrangler/_data_types.py                 |  8 +++---
 testing/test_awswrangler/test_data_lake.py | 29 ++++++++++++++++++++++
 2 files changed, 34 insertions(+), 3 deletions(-)

diff --git a/awswrangler/_data_types.py b/awswrangler/_data_types.py
index fac82a37b..819ce0f98 100644
--- a/awswrangler/_data_types.py
+++ b/awswrangler/_data_types.py
@@ -114,12 +114,14 @@ def pyarrow2athena(dtype: pa.DataType) -> str:  # pylint: disable=too-many-branc
     """Pyarrow to Athena data types conversion."""
     if pa.types.is_int8(dtype):
         return "tinyint"
-    if pa.types.is_int16(dtype):
+    if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
         return "smallint"
-    if pa.types.is_int32(dtype):
+    if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
         return "int"
-    if pa.types.is_int64(dtype):
+    if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
         return "bigint"
+    if pa.types.is_uint64(dtype):
+        raise exceptions.UnsupportedType("There is no support for uint64, please consider int64 or uint32.")
     if pa.types.is_float32(dtype):
         return "float"
     if pa.types.is_float64(dtype):
diff --git a/testing/test_awswrangler/test_data_lake.py b/testing/test_awswrangler/test_data_lake.py
index f95c691fb..4fb4bf68f 100644
--- a/testing/test_awswrangler/test_data_lake.py
+++ b/testing/test_awswrangler/test_data_lake.py
@@ -1367,3 +1367,32 @@ def test_copy_replacing_filename(bucket):
     assert objs[0] == expected_file
     wr.s3.delete_objects(path=path)
     wr.s3.delete_objects(path=path2)
+
+
+def test_unsigned_parquet(bucket, database):
+    path = f"s3://{bucket}/test_unsigned_parquet/"
+    table = "test_unsigned_parquet"
+    wr.s3.delete_objects(path=path)
+    df = pd.DataFrame({"c0": [0, 0, (2 ** 8) - 1], "c1": [0, 0, (2 ** 16) - 1], "c2": [0, 0, (2 ** 32) - 1]})
+    df["c0"] = df.c0.astype("uint8")
+    df["c1"] = df.c1.astype("uint16")
+    df["c2"] = df.c2.astype("uint32")
+    paths = wr.s3.to_parquet(df=df, path=path, dataset=True, database=database, table=table, mode="overwrite")["paths"]
+    wr.s3.wait_objects_exist(paths=paths, use_threads=False)
+    df = wr.athena.read_sql_table(table=table, database=database)
+    assert df.c0.sum() == (2 ** 8) - 1
+    assert df.c1.sum() == (2 ** 16) - 1
+    assert df.c2.sum() == (2 ** 32) - 1
+    schema = wr.s3.read_parquet_metadata(path=path)[0]
+    assert schema["c0"] == "smallint"
+    assert schema["c1"] == "int"
+    assert schema["c2"] == "bigint"
+    df = wr.s3.read_parquet(path=path)
+    assert df.c0.sum() == (2 ** 8) - 1
+    assert df.c1.sum() == (2 ** 16) - 1
+    assert df.c2.sum() == (2 ** 32) - 1
+
+    df = pd.DataFrame({"c0": [0, 0, (2 ** 64) - 1]})
+    df["c0"] = df.c0.astype("uint64")
+    with pytest.raises(wr.exceptions.UnsupportedType):
+        wr.s3.to_parquet(df=df, path=path, dataset=True, database=database, table=table, mode="overwrite")
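
For reviewers, a minimal sketch of what this change enables from the user side, assuming the patch is applied and a Glue database already exists; the bucket, database, and table names below are placeholders:

    import pandas as pd
    import awswrangler as wr

    # With this patch, unsigned columns are widened to the smallest signed
    # Athena type that covers their full range: uint8 -> smallint,
    # uint16 -> int, uint32 -> bigint. uint64 has no signed Athena
    # counterpart and raises wr.exceptions.UnsupportedType instead.
    df = pd.DataFrame({"c0": [0, 255]}).astype({"c0": "uint8"})
    wr.s3.to_parquet(
        df=df,
        path="s3://my-bucket/uint-demo/",  # placeholder bucket
        dataset=True,
        database="my_database",  # placeholder database
        table="uint_demo",  # placeholder table
        mode="overwrite",
    )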