From 52bc8db13f40b70882e250ba6f821cda47cff247 Mon Sep 17 00:00:00 2001
From: Stijn De Haes
Date: Thu, 19 Sep 2019 10:23:36 +0200
Subject: [PATCH 1/2] Partition columns now have correct type

Previously all partition columns were of type string. The same logic used for
non-partition columns is now reused to determine the type of each partition
column.
---
 awswrangler/exceptions.py |  4 +++
 awswrangler/glue.py       | 62 +++++++++++++++++++++++----------------
 2 files changed, 41 insertions(+), 25 deletions(-)

diff --git a/awswrangler/exceptions.py b/awswrangler/exceptions.py
index de5dc56a6..fd2911fa4 100644
--- a/awswrangler/exceptions.py
+++ b/awswrangler/exceptions.py
@@ -64,3 +64,7 @@ class QueryCancelled(Exception):
 
 class QueryFailed(Exception):
     pass
+
+
+class PartitionColumnTypeNotFound(Exception):
+    pass
diff --git a/awswrangler/glue.py b/awswrangler/glue.py
index 3bc816f6a..23a8f8c99 100644
--- a/awswrangler/glue.py
+++ b/awswrangler/glue.py
@@ -3,7 +3,7 @@
 import logging
 from datetime import datetime, date
 
-from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat
+from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat, PartitionColumnTypeNotFound
 
 logger = logging.getLogger(__name__)
 
@@ -111,10 +111,11 @@ def metadata_to_glue(self,
                          mode="append",
                          cast_columns=None,
                          extra_args=None):
-        schema = Glue._build_schema(dataframe=dataframe,
-                                    partition_cols=partition_cols,
-                                    preserve_index=preserve_index,
-                                    cast_columns=cast_columns)
+        schema, partition_cols_schema = Glue._build_schema(
+            dataframe=dataframe,
+            partition_cols=partition_cols,
+            preserve_index=preserve_index,
+            cast_columns=cast_columns)
         table = table if table else Glue._parse_table_name(path)
         table = table.lower().replace(".", "_")
         if mode == "overwrite":
@@ -124,13 +125,14 @@ def metadata_to_glue(self,
         self.create_table(database=database,
                           table=table,
                           schema=schema,
-                          partition_cols=partition_cols,
+                          partition_cols_schema=partition_cols_schema,
                           path=path,
                           file_format=file_format,
                           extra_args=extra_args)
         if partition_cols:
             partitions_tuples = Glue._parse_partitions_tuples(
                 objects_paths=objects_paths, partition_cols=partition_cols)
+            print(partitions_tuples)
             self.add_partitions(
                 database=database,
                 table=table,
@@ -157,14 +159,13 @@ def create_table(self,
                      schema,
                      path,
                      file_format,
-                     partition_cols=None,
+                     partition_cols_schema=None,
                      extra_args=None):
         if file_format == "parquet":
-            table_input = Glue.parquet_table_definition(
-                table, partition_cols, schema, path)
+            table_input = Glue.parquet_table_definition(table, partition_cols_schema, schema, path)
         elif file_format == "csv":
             table_input = Glue.csv_table_definition(table,
-                                                    partition_cols,
+                                                    partition_cols_schema,
                                                     schema,
                                                     path,
                                                     extra_args=extra_args)
@@ -189,6 +190,9 @@ def add_partitions(self, database, table, partition_paths, file_format):
         for _ in range(pages_num):
             page = partitions[:100]
             del partitions[:100]
+            print(database)
+            print(table)
+            print(page)
             self._client_glue.batch_create_partition(DatabaseName=database,
                                                      TableName=table,
                                                      PartitionInputList=page)
@@ -206,25 +210,32 @@ def _build_schema(dataframe,
         if not partition_cols:
             partition_cols = []
         schema_built = []
+        partition_cols_schema_built = []
         if preserve_index:
             name = str(
                 dataframe.index.name) if dataframe.index.name else "index"
             dataframe.index.name = "index"
             dtype = str(dataframe.index.dtype)
+            athena_type = Glue.type_pandas2athena(dtype)
             if name not in partition_cols:
-                athena_type = Glue.type_pandas2athena(dtype)
                 schema_built.append((name, athena_type))
+            else:
+                partition_cols_schema_built.append((name, athena_type))
         for col in dataframe.columns:
             name = str(col)
             if cast_columns and name in cast_columns:
                 dtype = cast_columns[name]
             else:
                 dtype = str(dataframe[name].dtype)
+            athena_type = Glue.type_pandas2athena(dtype)
             if name not in partition_cols:
-                athena_type = Glue.type_pandas2athena(dtype)
                 schema_built.append((name, athena_type))
+            else:
+                partition_cols_schema_built.append((name, athena_type))
         logger.debug(f"schema_built:\n{schema_built}")
-        return schema_built
+        logger.debug(
+            f"partition_cols_schema_built:\n{partition_cols_schema_built}")
+        return schema_built, partition_cols_schema_built
 
     @staticmethod
     def _parse_table_name(path):
@@ -233,17 +244,17 @@ def _parse_table_name(path):
         return path.rpartition("/")[2]
 
     @staticmethod
-    def csv_table_definition(table, partition_cols, schema, path, extra_args):
+    def csv_table_definition(table, partition_cols_schema, schema, path, extra_args):
         sep = extra_args["sep"] if "sep" in extra_args else ","
-        if not partition_cols:
-            partition_cols = []
+        if not partition_cols_schema:
+            partition_cols_schema = []
         return {
             "Name":
             table,
             "PartitionKeys": [{
-                "Name": x,
-                "Type": "string"
-            } for x in partition_cols],
+                "Name": x[0],
+                "Type": x[1]
+            } for x in partition_cols_schema],
             "TableType":
             "EXTERNAL_TABLE",
             "Parameters": {
@@ -304,16 +315,17 @@ def csv_partition_definition(partition):
         }
 
     @staticmethod
-    def parquet_table_definition(table, partition_cols, schema, path):
-        if not partition_cols:
-            partition_cols = []
+    def parquet_table_definition(table, partition_cols_schema,
+                                 schema, path):
+        if not partition_cols_schema:
+            partition_cols_schema = []
         return {
             "Name":
             table,
             "PartitionKeys": [{
-                "Name": x,
-                "Type": "string"
-            } for x in partition_cols],
+                "Name": x[0],
+                "Type": x[1]
+            } for x in partition_cols_schema],
             "TableType":
             "EXTERNAL_TABLE",
             "Parameters": {

From 6df511daeac65d81705f6e03bc00485f74b0d3ae Mon Sep 17 00:00:00 2001
From: Igor Tavares
Date: Thu, 19 Sep 2019 10:25:59 -0300
Subject: [PATCH 2/2] Updating branch, resolving conflicts, adding more tests

---
 awswrangler/athena.py                   | 13 ++--
 awswrangler/exceptions.py               |  4 --
 awswrangler/glue.py                     | 90 ++++++++++++++++---------
 awswrangler/pandas.py                   |  6 +-
 testing/test_awswrangler/test_pandas.py | 88 ++++++++++++++++++++++++
 5 files changed, 159 insertions(+), 42 deletions(-)

diff --git a/awswrangler/athena.py b/awswrangler/athena.py
index fb4b410ae..d8a845619 100644
--- a/awswrangler/athena.py
+++ b/awswrangler/athena.py
@@ -31,8 +31,10 @@ def _type_athena2pandas(dtype):
             return "bool"
         elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
             return "object"
-        elif dtype in ["timestamp", "date"]:
+        elif dtype == "timestamp":
             return "datetime64"
+        elif dtype == "date":
+            return "date"
         else:
             raise UnsupportedType(f"Unsupported Athena type: {dtype}")
 
@@ -40,16 +42,19 @@ def get_query_dtype(self, query_execution_id):
         cols_metadata = self.get_query_columns_metadata(
             query_execution_id=query_execution_id)
         dtype = {}
+        parse_timestamps = []
         parse_dates = []
         for col_name, col_type in cols_metadata.items():
             ptype = Athena._type_athena2pandas(dtype=col_type)
-            if ptype == "datetime64":
-                parse_dates.append(col_name)
+            if ptype in ["datetime64", "date"]:
+                parse_timestamps.append(col_name)
+                if ptype == "date":
+                    parse_dates.append(col_name)
             else:
                 dtype[col_name] = ptype
         logger.debug(f"dtype: {dtype}")
         logger.debug(f"parse_dates: {parse_dates}")
-        return dtype, parse_dates
+        return dtype, parse_timestamps, parse_dates
 
     def create_athena_bucket(self):
         """
diff --git a/awswrangler/exceptions.py b/awswrangler/exceptions.py
index fd2911fa4..de5dc56a6 100644
--- a/awswrangler/exceptions.py
+++ b/awswrangler/exceptions.py
@@ -64,7 +64,3 @@ class QueryCancelled(Exception):
 
 class QueryFailed(Exception):
     pass
-
-
-class PartitionColumnTypeNotFound(Exception):
-    pass
diff --git a/awswrangler/glue.py b/awswrangler/glue.py
index 23a8f8c99..617228fd9 100644
--- a/awswrangler/glue.py
+++ b/awswrangler/glue.py
@@ -3,7 +3,9 @@
 import logging
 from datetime import datetime, date
 
-from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat, PartitionColumnTypeNotFound
+import pyarrow
+
+from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat
 
 logger = logging.getLogger(__name__)
 
@@ -43,6 +45,28 @@ def get_table_python_types(self, database, table):
         dtypes = self.get_table_athena_types(database=database, table=table)
         return {k: Glue.type_athena2python(v) for k, v in dtypes.items()}
 
+    @staticmethod
+    def type_pyarrow2athena(dtype):
+        dtype = str(dtype).lower()
+        if dtype == "int32":
+            return "int"
+        elif dtype == "int64":
+            return "bigint"
+        elif dtype == "float":
+            return "float"
+        elif dtype == "double":
+            return "double"
+        elif dtype == "bool":
+            return "boolean"
+        elif dtype == "string":
+            return "string"
+        elif dtype.startswith("timestamp"):
+            return "timestamp"
+        elif dtype.startswith("date"):
+            return "date"
+        else:
+            raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
+
     @staticmethod
     def type_pandas2athena(dtype):
         dtype = dtype.lower()
@@ -58,7 +82,7 @@ def type_pandas2athena(dtype):
             return "boolean"
         elif dtype == "object":
             return "string"
-        elif dtype[:10] == "datetime64":
+        elif dtype.startswith("datetime64"):
             return "timestamp"
         else:
             raise UnsupportedType(f"Unsupported Pandas type: {dtype}")
@@ -114,8 +138,7 @@ def metadata_to_glue(self,
         schema, partition_cols_schema = Glue._build_schema(
             dataframe=dataframe,
             partition_cols=partition_cols,
-            preserve_index=preserve_index,
-            cast_columns=cast_columns)
+            preserve_index=preserve_index)
         table = table if table else Glue._parse_table_name(path)
         table = table.lower().replace(".", "_")
         if mode == "overwrite":
@@ -132,7 +155,6 @@ def metadata_to_glue(self,
         if partition_cols:
             partitions_tuples = Glue._parse_partitions_tuples(
                 objects_paths=objects_paths, partition_cols=partition_cols)
-            print(partitions_tuples)
             self.add_partitions(
                 database=database,
                 table=table,
@@ -190,9 +212,6 @@ def add_partitions(self, database, table, partition_paths, file_format):
         for _ in range(pages_num):
             page = partitions[:100]
             del partitions[:100]
-            print(database)
-            print(table)
-            print(page)
             self._client_glue.batch_create_partition(DatabaseName=database,
                                                      TableName=table,
                                                      PartitionInputList=page)
@@ -202,36 +221,43 @@ def get_connection_details(self, name):
             Name=name, HidePassword=False)["Connection"]
 
     @staticmethod
-    def _build_schema(dataframe,
-                      partition_cols,
-                      preserve_index,
-                      cast_columns=None):
+    def _extract_pyarrow_schema(dataframe, preserve_index):
+        cols = []
+        schema = []
+        for name, dtype in dataframe.dtypes.to_dict().items():
+            dtype = str(dtype)
+            if str(dtype) == "Int64":
+                schema.append((name, "int64"))
+            else:
+                cols.append(name)
+
+        # Convert pyarrow.Schema to list of tuples (e.g. [(name1, type1), (name2, type2)...])
+        schema += [(str(x.name), str(x.type))
+                   for x in pyarrow.Schema.from_pandas(
+                       df=dataframe[cols], preserve_index=preserve_index)]
+        logger.debug(f"schema: {schema}")
+        return schema
+
+    @staticmethod
+    def _build_schema(dataframe, partition_cols, preserve_index):
         logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
         if not partition_cols:
             partition_cols = []
+
+        pyarrow_schema = Glue._extract_pyarrow_schema(
+            dataframe=dataframe, preserve_index=preserve_index)
+
         schema_built = []
-        partition_cols_schema_built = []
-        if preserve_index:
-            name = str(
-                dataframe.index.name) if dataframe.index.name else "index"
-            dataframe.index.name = "index"
-            dtype = str(dataframe.index.dtype)
-            athena_type = Glue.type_pandas2athena(dtype)
-            if name not in partition_cols:
-                schema_built.append((name, athena_type))
-            else:
-                partition_cols_schema_built.append((name, athena_type))
-        for col in dataframe.columns:
-            name = str(col)
-            if cast_columns and name in cast_columns:
-                dtype = cast_columns[name]
+        partition_cols_types = {}
+        for name, dtype in pyarrow_schema:
+            athena_type = Glue.type_pyarrow2athena(dtype)
+            if name in partition_cols:
+                partition_cols_types[name] = athena_type
             else:
-                dtype = str(dataframe[name].dtype)
-                athena_type = Glue.type_pandas2athena(dtype)
-            if name not in partition_cols:
                 schema_built.append((name, athena_type))
-            else:
-                partition_cols_schema_built.append((name, athena_type))
+
+        partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]
+
         logger.debug(f"schema_built:\n{schema_built}")
         logger.debug(
             f"partition_cols_schema_built:\n{partition_cols_schema_built}")
diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py
index 909c172fe..578de70d7 100644
--- a/awswrangler/pandas.py
+++ b/awswrangler/pandas.py
@@ -419,14 +419,16 @@ def read_sql_athena(self,
                 message_error = f"Query error: {reason}"
                 raise AthenaQueryError(message_error)
             else:
-                dtype, parse_dates = self._session.athena.get_query_dtype(
+                dtype, parse_timestamps, parse_dates = self._session.athena.get_query_dtype(
                     query_execution_id=query_execution_id)
                 path = f"{s3_output}{query_execution_id}.csv"
                 ret = self.read_csv(path=path,
                                     dtype=dtype,
-                                    parse_dates=parse_dates,
+                                    parse_dates=parse_timestamps,
                                     quoting=csv.QUOTE_ALL,
                                     max_result_size=max_result_size)
+                for col in parse_dates:
+                    ret[col] = ret[col].dt.date
                 return ret
 
     def to_csv(
diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py
index c33303e4d..5625c3eda 100644
--- a/testing/test_awswrangler/test_pandas.py
+++ b/testing/test_awswrangler/test_pandas.py
@@ -489,3 +489,91 @@ def test_to_csv_with_sep(
     assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
     assert dataframe[dataframe["id"] == 0].iloc[0]["name"] == dataframe2[
         dataframe2["id"] == 0].iloc[0]["name"]
+
+
+@pytest.mark.parametrize("index, partition_cols", [
+    (None, []),
+    ("default", []),
+    ("my_date", []),
+    ("my_timestamp", []),
+    (None, ["my_int"]),
+    ("default", ["my_int"]),
+    ("my_date", ["my_int"]),
+    ("my_timestamp", ["my_int"]),
+    (None, ["my_float"]),
+    ("default", ["my_float"]),
+    ("my_date", ["my_float"]),
+    ("my_timestamp", ["my_float"]),
+    (None, ["my_bool"]),
+    ("default", ["my_bool"]),
+    ("my_date", ["my_bool"]),
+    ("my_timestamp", ["my_bool"]),
+    (None, ["my_date"]),
+    ("default", ["my_date"]),
+    ("my_date", ["my_date"]),
+    ("my_timestamp", ["my_date"]),
+    (None, ["my_timestamp"]),
+    ("default", ["my_timestamp"]),
+    ("my_date", ["my_timestamp"]),
+    ("my_timestamp", ["my_timestamp"]),
+    (None, ["my_timestamp", "my_date"]),
+    ("default", ["my_date", "my_timestamp"]),
+    ("my_date", ["my_timestamp", "my_date"]),
+    ("my_timestamp", ["my_date", "my_timestamp"]),
+    (None, ["my_bool", "my_timestamp", "my_date"]),
+    ("default", ["my_date", "my_timestamp", "my_int"]),
+    ("my_date", ["my_timestamp", "my_float", "my_date"]),
+    ("my_timestamp", ["my_int", "my_float", "my_bool", "my_date", "my_timestamp"]),
+])
+def test_to_parquet_types(session, bucket, database, index, partition_cols):
+    dataframe = pandas.read_csv("data_samples/complex.csv",
+                                dtype={"my_int_with_null": "Int64"},
+                                parse_dates=["my_timestamp", "my_date"])
+    dataframe["my_date"] = dataframe["my_date"].dt.date
+    dataframe["my_bool"] = True
+
+    preserve_index = True
+    if not index:
+        preserve_index = False
+    elif index != "default":
+        dataframe["new_index"] = dataframe[index]
+        dataframe = dataframe.set_index("new_index")
+
+    session.pandas.to_parquet(dataframe=dataframe,
+                              database=database,
+                              path=f"s3://{bucket}/test/",
+                              preserve_index=preserve_index,
+                              partition_cols=partition_cols,
+                              mode="overwrite",
+                              procs_cpu_bound=1)
+    sleep(1)
+    dataframe2 = session.pandas.read_sql_athena(sql="select * from test",
+                                                database=database)
+    for row in dataframe2.itertuples():
+        if index:
+            if index == "default":
+                ex_index_col = 8 - len(partition_cols)
+                assert isinstance(row[ex_index_col], numpy.int64)
+            elif index == "my_date":
+                assert isinstance(row.new_index, date)
+            elif index == "my_timestamp":
+                assert isinstance(row.new_index, datetime)
+        assert isinstance(row.my_timestamp, datetime)
+        assert type(row.my_date) == date
+        assert isinstance(row.my_float, float)
+        assert isinstance(row.my_int, numpy.int64)
+        assert isinstance(row.my_string, str)
+        assert isinstance(row.my_bool, bool)
+        assert str(row.my_timestamp) == "2018-01-01 04:03:02.001000"
+        assert str(row.my_date) == "2019-02-02"
+        assert str(row.my_float) == "12345.6789"
+        assert str(row.my_int) == "123456789"
+        assert str(row.my_bool) == "True"
+        assert str(
+            row.my_string
+        ) == "foo\nboo\nbar\nFOO\nBOO\nBAR\nxxxxx\nÁÃÀÂÇ\n汉字汉字汉字汉字汉字汉字汉字æøåæøåæøåæøåæøåæøåæøåæøåæøåæøå汉字汉字汉字汉字汉字汉字汉字æøåæøåæøåæøåæøåæøåæøåæøåæøåæøå"
+    assert len(dataframe.index) == len(dataframe2.index)
+    if index:
+        assert (len(list(dataframe.columns)) + 1) == len(list(dataframe2.columns))
+    else:
+        assert len(list(dataframe.columns)) == len(list(dataframe2.columns))