diff --git a/awswrangler/glue.py b/awswrangler/glue.py
index 5aca6105c..617228fd9 100644
--- a/awswrangler/glue.py
+++ b/awswrangler/glue.py
@@ -135,9 +135,10 @@ def metadata_to_glue(self,
                          mode="append",
                          cast_columns=None,
                          extra_args=None):
-        schema = Glue._build_schema(dataframe=dataframe,
-                                    partition_cols=partition_cols,
-                                    preserve_index=preserve_index)
+        schema, partition_cols_schema = Glue._build_schema(
+            dataframe=dataframe,
+            partition_cols=partition_cols,
+            preserve_index=preserve_index)
         table = table if table else Glue._parse_table_name(path)
         table = table.lower().replace(".", "_")
         if mode == "overwrite":
@@ -147,7 +148,7 @@ def metadata_to_glue(self,
             self.create_table(database=database,
                               table=table,
                               schema=schema,
-                              partition_cols=partition_cols,
+                              partition_cols_schema=partition_cols_schema,
                               path=path,
                               file_format=file_format,
                               extra_args=extra_args)
@@ -180,14 +181,13 @@ def create_table(self,
                      schema,
                      path,
                      file_format,
-                     partition_cols=None,
+                     partition_cols_schema=None,
                      extra_args=None):
         if file_format == "parquet":
-            table_input = Glue.parquet_table_definition(
-                table, partition_cols, schema, path)
+            table_input = Glue.parquet_table_definition(table, partition_cols_schema, schema, path)
         elif file_format == "csv":
             table_input = Glue.csv_table_definition(table,
-                                                    partition_cols,
+                                                    partition_cols_schema,
                                                     schema,
                                                     path,
                                                     extra_args=extra_args)
@@ -248,13 +248,20 @@ def _build_schema(dataframe, partition_cols, preserve_index):
             dataframe=dataframe, preserve_index=preserve_index)
         schema_built = []
+        partition_cols_types = {}
         for name, dtype in pyarrow_schema:
-            if name not in partition_cols:
-                athena_type = Glue.type_pyarrow2athena(dtype)
+            athena_type = Glue.type_pyarrow2athena(dtype)
+            if name in partition_cols:
+                partition_cols_types[name] = athena_type
+            else:
                 schema_built.append((name, athena_type))
+        partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]
+
         logger.debug(f"schema_built:\n{schema_built}")
-        return schema_built
+        logger.debug(
+            f"partition_cols_schema_built:\n{partition_cols_schema_built}")
+        return schema_built, partition_cols_schema_built
 
     @staticmethod
     def _parse_table_name(path):
@@ -263,17 +270,17 @@ def _parse_table_name(path):
         return path.rpartition("/")[2]
 
     @staticmethod
-    def csv_table_definition(table, partition_cols, schema, path, extra_args):
+    def csv_table_definition(table, partition_cols_schema, schema, path, extra_args):
         sep = extra_args["sep"] if "sep" in extra_args else ","
-        if not partition_cols:
-            partition_cols = []
+        if not partition_cols_schema:
+            partition_cols_schema = []
         return {
             "Name": table,
             "PartitionKeys": [{
-                "Name": x,
-                "Type": "string"
-            } for x in partition_cols],
+                "Name": x[0],
+                "Type": x[1]
+            } for x in partition_cols_schema],
             "TableType": "EXTERNAL_TABLE",
             "Parameters": {
@@ -334,16 +341,17 @@ def csv_partition_definition(partition):
         }
 
     @staticmethod
-    def parquet_table_definition(table, partition_cols, schema, path):
-        if not partition_cols:
-            partition_cols = []
+    def parquet_table_definition(table, partition_cols_schema,
+                                 schema, path):
+        if not partition_cols_schema:
+            partition_cols_schema = []
         return {
             "Name": table,
             "PartitionKeys": [{
-                "Name": x,
-                "Type": "string"
-            } for x in partition_cols],
+                "Name": x[0],
+                "Type": x[1]
+            } for x in partition_cols_schema],
             "TableType": "EXTERNAL_TABLE",
             "Parameters": {
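The substantive change above: `Glue._build_schema` now returns a second value, `partition_cols_schema`, pairing each partition column with the Athena type inferred via `type_pyarrow2athena`, and both table definitions build their `PartitionKeys` from those pairs instead of hard-coding every partition key as "string". A minimal runnable sketch of that contract follows; `split_schema` is a hypothetical stand-in for the new `Glue._build_schema` behavior, it takes already-converted `(name, athena_type)` pairs rather than pyarrow dtypes, and the column names are made up:

def split_schema(pyarrow_schema, partition_cols):
    """Split (name, athena_type) pairs into regular and partition columns,
    keeping the inferred Athena type for partition columns too."""
    schema_built = []
    partition_cols_types = {}
    for name, athena_type in pyarrow_schema:
        if name in partition_cols:
            partition_cols_types[name] = athena_type
        else:
            schema_built.append((name, athena_type))
    # Emit partition columns in the caller-supplied order, as the patch does.
    partition_cols_schema = [(name, partition_cols_types[name])
                             for name in partition_cols]
    return schema_built, partition_cols_schema


schema, partition_cols_schema = split_schema(
    [("my_int", "bigint"), ("my_float", "double"), ("my_date", "date")],
    partition_cols=["my_date"])
assert schema == [("my_int", "bigint"), ("my_float", "double")]
# PartitionKeys now carry real Athena types instead of a hard-coded "string":
assert [{"Name": n, "Type": t} for n, t in partition_cols_schema] == [
    {"Name": "my_date", "Type": "date"}]

Querying a table partitioned on a date or timestamp column through Athena should therefore yield typed values rather than strings, which is what the expanded test matrix below exercises.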
diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py
index 6a63f3fdf..5625c3eda 100644
--- a/testing/test_awswrangler/test_pandas.py
+++ b/testing/test_awswrangler/test_pandas.py
@@ -491,8 +491,41 @@ def test_to_csv_with_sep(
         dataframe2["id"] == 0].iloc[0]["name"]
 
 
-@pytest.mark.parametrize("index", [None, "default", "my_date", "my_timestamp"])
-def test_to_parquet_types(session, bucket, database, index):
+@pytest.mark.parametrize("index, partition_cols", [
+    (None, []),
+    ("default", []),
+    ("my_date", []),
+    ("my_timestamp", []),
+    (None, ["my_int"]),
+    ("default", ["my_int"]),
+    ("my_date", ["my_int"]),
+    ("my_timestamp", ["my_int"]),
+    (None, ["my_float"]),
+    ("default", ["my_float"]),
+    ("my_date", ["my_float"]),
+    ("my_timestamp", ["my_float"]),
+    (None, ["my_bool"]),
+    ("default", ["my_bool"]),
+    ("my_date", ["my_bool"]),
+    ("my_timestamp", ["my_bool"]),
+    (None, ["my_date"]),
+    ("default", ["my_date"]),
+    ("my_date", ["my_date"]),
+    ("my_timestamp", ["my_date"]),
+    (None, ["my_timestamp"]),
+    ("default", ["my_timestamp"]),
+    ("my_date", ["my_timestamp"]),
+    ("my_timestamp", ["my_timestamp"]),
+    (None, ["my_timestamp", "my_date"]),
+    ("default", ["my_date", "my_timestamp"]),
+    ("my_date", ["my_timestamp", "my_date"]),
+    ("my_timestamp", ["my_date", "my_timestamp"]),
+    (None, ["my_bool", "my_timestamp", "my_date"]),
+    ("default", ["my_date", "my_timestamp", "my_int"]),
+    ("my_date", ["my_timestamp", "my_float", "my_date"]),
+    ("my_timestamp", ["my_int", "my_float", "my_bool", "my_date", "my_timestamp"]),
+])
+def test_to_parquet_types(session, bucket, database, index, partition_cols):
     dataframe = pandas.read_csv("data_samples/complex.csv",
                                 dtype={"my_int_with_null": "Int64"},
                                 parse_dates=["my_timestamp", "my_date"])
@@ -510,6 +543,7 @@ def test_to_parquet_types(session, bucket, database, index):
                               database=database,
                               path=f"s3://{bucket}/test/",
                               preserve_index=preserve_index,
+                              partition_cols=partition_cols,
                               mode="overwrite",
                               procs_cpu_bound=1)
     sleep(1)
@@ -518,7 +552,8 @@ def test_to_parquet_types(session, bucket, database, index):
     for row in dataframe2.itertuples():
         if index:
             if index == "default":
-                assert isinstance(row[8], numpy.int64)
+                ex_index_col = 8 - len(partition_cols)
+                assert isinstance(row[ex_index_col], numpy.int64)
             elif index == "my_date":
                 assert isinstance(row.new_index, date)
             elif index == "my_timestamp":
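The `ex_index_col = 8 - len(partition_cols)` arithmetic in the last hunk deserves a gloss: with `preserve_index` on, the written dataframe gains an extra index column, which the unpartitioned test expects at itertuples() position 8. Athena returns partition columns after all regular columns, so each partition column pushes the preserved index one slot to the left. A small sketch of that reordering, under an assumed column layout (the actual columns of complex.csv are not visible in this diff):

# Hypothetical column layout standing in for data_samples/complex.csv.
columns = ["my_int", "my_float", "my_bool", "my_date", "my_timestamp",
           "my_string", "my_int_with_null", "index"]
partition_cols = ["my_date", "my_timestamp"]

# Athena moves partition columns after the regular ones in query results,
# so the preserved "index" column shifts one slot left per partition column.
reordered = [c for c in columns if c not in partition_cols] + partition_cols

# itertuples() rows start with the frame index at position 0, hence the
# off-by-one between row positions and column positions.
ex_index_col = 8 - len(partition_cols)
assert reordered[ex_index_col - 1] == "index"

The same offset holds for any subset of partition columns, which is why a single expression covers the whole parametrized matrix above.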