diff --git a/awswrangler/athena.py b/awswrangler/athena.py
index fb4b410ae..d8a845619 100644
--- a/awswrangler/athena.py
+++ b/awswrangler/athena.py
@@ -31,8 +31,10 @@ def _type_athena2pandas(dtype):
             return "bool"
         elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
             return "object"
-        elif dtype in ["timestamp", "date"]:
+        elif dtype == "timestamp":
             return "datetime64"
+        elif dtype == "date":
+            return "date"
         else:
             raise UnsupportedType(f"Unsupported Athena type: {dtype}")
 
@@ -40,16 +42,19 @@ def get_query_dtype(self, query_execution_id):
         cols_metadata = self.get_query_columns_metadata(
             query_execution_id=query_execution_id)
         dtype = {}
+        parse_timestamps = []
         parse_dates = []
         for col_name, col_type in cols_metadata.items():
             ptype = Athena._type_athena2pandas(dtype=col_type)
-            if ptype == "datetime64":
-                parse_dates.append(col_name)
+            if ptype in ["datetime64", "date"]:
+                parse_timestamps.append(col_name)
+                if ptype == "date":
+                    parse_dates.append(col_name)
             else:
                 dtype[col_name] = ptype
         logger.debug(f"dtype: {dtype}")
         logger.debug(f"parse_dates: {parse_dates}")
-        return dtype, parse_dates
+        return dtype, parse_timestamps, parse_dates
 
     def create_athena_bucket(self):
         """
diff --git a/awswrangler/glue.py b/awswrangler/glue.py
index 3bc816f6a..5aca6105c 100644
--- a/awswrangler/glue.py
+++ b/awswrangler/glue.py
@@ -3,6 +3,8 @@ import logging
 from datetime import datetime, date
 
+import pyarrow
+
 from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat
 
 logger = logging.getLogger(__name__)
@@ -43,6 +45,28 @@ def get_table_python_types(self, database, table):
         dtypes = self.get_table_athena_types(database=database, table=table)
         return {k: Glue.type_athena2python(v) for k, v in dtypes.items()}
 
+    @staticmethod
+    def type_pyarrow2athena(dtype):
+        dtype = str(dtype).lower()
+        if dtype == "int32":
+            return "int"
+        elif dtype == "int64":
+            return "bigint"
+        elif dtype == "float":
+            return "float"
+        elif dtype == "double":
+            return "double"
+        elif dtype == "bool":
+            return "boolean"
+        elif dtype == "string":
+            return "string"
+        elif dtype.startswith("timestamp"):
+            return "timestamp"
+        elif dtype.startswith("date"):
+            return "date"
+        else:
+            raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")
+
     @staticmethod
     def type_pandas2athena(dtype):
         dtype = dtype.lower()
@@ -58,7 +82,7 @@ def type_pandas2athena(dtype):
             return "boolean"
         elif dtype == "object":
             return "string"
-        elif dtype[:10] == "datetime64":
+        elif dtype.startswith("datetime64"):
             return "timestamp"
         else:
             raise UnsupportedType(f"Unsupported Pandas type: {dtype}")
@@ -113,8 +137,7 @@ def metadata_to_glue(self, extra_args=None):
         schema = Glue._build_schema(dataframe=dataframe,
                                     partition_cols=partition_cols,
-                                    preserve_index=preserve_index,
-                                    cast_columns=cast_columns)
+                                    preserve_index=preserve_index)
         table = table if table else Glue._parse_table_name(path)
         table = table.lower().replace(".", "_")
         if mode == "overwrite":
@@ -198,31 +221,38 @@ def get_connection_details(self, name):
             Name=name, HidePassword=False)["Connection"]
 
     @staticmethod
-    def _build_schema(dataframe,
-                      partition_cols,
-                      preserve_index,
-                      cast_columns=None):
+    def _extract_pyarrow_schema(dataframe, preserve_index):
+        cols = []
+        schema = []
+        for name, dtype in dataframe.dtypes.to_dict().items():
+            dtype = str(dtype)
+            if str(dtype) == "Int64":
+                schema.append((name, "int64"))
+            else:
+                cols.append(name)
+
+        # Convert pyarrow.Schema to list of tuples (e.g. [(name1, type1), (name2, type2)...])
+        schema += [(str(x.name), str(x.type))
+                   for x in pyarrow.Schema.from_pandas(
+                       df=dataframe[cols], preserve_index=preserve_index)]
+        logger.debug(f"schema: {schema}")
+        return schema
+
+    @staticmethod
+    def _build_schema(dataframe, partition_cols, preserve_index):
         logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
         if not partition_cols:
             partition_cols = []
+
+        pyarrow_schema = Glue._extract_pyarrow_schema(
+            dataframe=dataframe, preserve_index=preserve_index)
+
         schema_built = []
-        if preserve_index:
-            name = str(
-                dataframe.index.name) if dataframe.index.name else "index"
-            dataframe.index.name = "index"
-            dtype = str(dataframe.index.dtype)
-            if name not in partition_cols:
-                athena_type = Glue.type_pandas2athena(dtype)
-                schema_built.append((name, athena_type))
-        for col in dataframe.columns:
-            name = str(col)
-            if cast_columns and name in cast_columns:
-                dtype = cast_columns[name]
-            else:
-                dtype = str(dataframe[name].dtype)
+        for name, dtype in pyarrow_schema:
             if name not in partition_cols:
-                athena_type = Glue.type_pandas2athena(dtype)
+                athena_type = Glue.type_pyarrow2athena(dtype)
                 schema_built.append((name, athena_type))
+
         logger.debug(f"schema_built:\n{schema_built}")
         return schema_built
diff --git a/awswrangler/pandas.py b/awswrangler/pandas.py
index 909c172fe..578de70d7 100644
--- a/awswrangler/pandas.py
+++ b/awswrangler/pandas.py
@@ -419,14 +419,16 @@ def read_sql_athena(self,
             message_error = f"Query error: {reason}"
             raise AthenaQueryError(message_error)
         else:
-            dtype, parse_dates = self._session.athena.get_query_dtype(
+            dtype, parse_timestamps, parse_dates = self._session.athena.get_query_dtype(
                 query_execution_id=query_execution_id)
             path = f"{s3_output}{query_execution_id}.csv"
             ret = self.read_csv(path=path,
                                 dtype=dtype,
-                                parse_dates=parse_dates,
+                                parse_dates=parse_timestamps,
                                 quoting=csv.QUOTE_ALL,
                                 max_result_size=max_result_size)
+            for col in parse_dates:
+                ret[col] = ret[col].dt.date
             return ret
 
     def to_csv(
diff --git a/testing/test_awswrangler/test_pandas.py b/testing/test_awswrangler/test_pandas.py
index c33303e4d..6a63f3fdf 100644
--- a/testing/test_awswrangler/test_pandas.py
+++ b/testing/test_awswrangler/test_pandas.py
@@ -489,3 +489,56 @@ def test_to_csv_with_sep(
     assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
     assert dataframe[dataframe["id"] == 0].iloc[0]["name"] == dataframe2[
         dataframe2["id"] == 0].iloc[0]["name"]
+
+
+@pytest.mark.parametrize("index", [None, "default", "my_date", "my_timestamp"])
+def test_to_parquet_types(session, bucket, database, index):
+    dataframe = pandas.read_csv("data_samples/complex.csv",
+                                dtype={"my_int_with_null": "Int64"},
+                                parse_dates=["my_timestamp", "my_date"])
+    dataframe["my_date"] = dataframe["my_date"].dt.date
+    dataframe["my_bool"] = True
+
+    preserve_index = True
+    if not index:
+        preserve_index = False
+    elif index != "default":
+        dataframe["new_index"] = dataframe[index]
+        dataframe = dataframe.set_index("new_index")
+
+    session.pandas.to_parquet(dataframe=dataframe,
+                              database=database,
+                              path=f"s3://{bucket}/test/",
+                              preserve_index=preserve_index,
+                              mode="overwrite",
+                              procs_cpu_bound=1)
+    sleep(1)
+    dataframe2 = session.pandas.read_sql_athena(sql="select * from test",
+                                                database=database)
+    for row in dataframe2.itertuples():
+        if index:
+            if index == "default":
+                assert isinstance(row[8], numpy.int64)
+            elif index == "my_date":
+                assert isinstance(row.new_index, date)
+            elif index == "my_timestamp":
+                assert isinstance(row.new_index, datetime)
+        assert isinstance(row.my_timestamp, datetime)
+        assert type(row.my_date) == date
+        assert isinstance(row.my_float, float)
+        assert isinstance(row.my_int, numpy.int64)
+        assert isinstance(row.my_string, str)
+        assert isinstance(row.my_bool, bool)
+        assert str(row.my_timestamp) == "2018-01-01 04:03:02.001000"
+        assert str(row.my_date) == "2019-02-02"
+        assert str(row.my_float) == "12345.6789"
+        assert str(row.my_int) == "123456789"
+        assert str(row.my_bool) == "True"
+        assert str(
+            row.my_string
+        ) == "foo\nboo\nbar\nFOO\nBOO\nBAR\nxxxxx\nÁÃÀÂÇ\n汉字汉字汉字汉字汉字汉字汉字æøåæøåæøåæøåæøåæøåæøåæøåæøåæøå汉字汉字汉字汉字汉字汉字汉字æøåæøåæøåæøåæøåæøåæøåæøåæøåæøå"
+    assert len(dataframe.index) == len(dataframe2.index)
+    if index:
+        assert (len(list(dataframe.columns)) + 1) == len(list(dataframe2.columns))
+    else:
+        assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
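
Usage sketch (not part of the patch; it condenses what test_to_parquet_types
exercises). It assumes configured AWS credentials and an existing Glue
database and S3 bucket; "my_db" and "my-bucket" are placeholder names, not
values from this diff. With this change, an Athena DATE column round-trips
through Parquet as datetime.date objects, while TIMESTAMP columns stay
datetime64:

    from datetime import date, datetime
    from time import sleep

    import pandas
    import awswrangler

    session = awswrangler.Session()
    df = pandas.DataFrame({
        "my_timestamp": [pandas.Timestamp("2018-01-01 04:03:02.001")],
        "my_date": [pandas.Timestamp("2019-02-02").date()],  # datetime.date objects
    })
    session.pandas.to_parquet(dataframe=df,
                              database="my_db",             # placeholder
                              path="s3://my-bucket/test/",  # placeholder
                              preserve_index=False,
                              mode="overwrite",
                              procs_cpu_bound=1)
    sleep(1)  # as in the test, give the new table a moment to become queryable
    df2 = session.pandas.read_sql_athena(sql="select * from test",
                                         database="my_db")
    row = next(df2.itertuples())
    assert isinstance(row.my_timestamp, datetime)  # parsed via parse_timestamps
    assert type(row.my_date) == date               # exactly date, not datetime

The exact-type check on my_date mirrors the test: datetime is a subclass of
date, so isinstance alone would not prove the column was narrowed to a date.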