Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions awswrangler/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,25 +31,30 @@ def _type_athena2pandas(dtype):
return "bool"
elif dtype in ["string", "char", "varchar", "array", "row", "map"]:
return "object"
elif dtype in ["timestamp", "date"]:
elif dtype == "timestamp":
return "datetime64"
elif dtype == "date":
return "date"
else:
raise UnsupportedType(f"Unsupported Athena type: {dtype}")

def get_query_dtype(self, query_execution_id):
    """Map a finished Athena query's result columns to Pandas read options.

    :param query_execution_id: Athena query execution ID whose result
        schema is inspected via get_query_columns_metadata
    :return: tuple (dtype, parse_timestamps, parse_dates) where
        ``dtype`` maps non-temporal column names to Pandas dtypes,
        ``parse_timestamps`` lists every temporal column (both Athena
        ``timestamp`` and ``date``) so read_csv parses them as datetimes,
        and ``parse_dates`` lists only the ``date`` columns so the caller
        can downcast them back to ``datetime.date`` afterwards.
    """
    cols_metadata = self.get_query_columns_metadata(
        query_execution_id=query_execution_id)
    dtype = {}
    parse_timestamps = []
    parse_dates = []
    for col_name, col_type in cols_metadata.items():
        ptype = Athena._type_athena2pandas(dtype=col_type)
        if ptype in ["datetime64", "date"]:
            # Both temporal kinds are parsed as timestamps by read_csv;
            # date columns are additionally tracked so the caller can
            # strip the time component after parsing.
            parse_timestamps.append(col_name)
            if ptype == "date":
                parse_dates.append(col_name)
        else:
            dtype[col_name] = ptype
    logger.debug(f"dtype: {dtype}")
    logger.debug(f"parse_timestamps: {parse_timestamps}")
    logger.debug(f"parse_dates: {parse_dates}")
    return dtype, parse_timestamps, parse_dates

def create_athena_bucket(self):
"""
Expand Down
74 changes: 52 additions & 22 deletions awswrangler/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import logging
from datetime import datetime, date

import pyarrow

from awswrangler.exceptions import UnsupportedType, UnsupportedFileFormat

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -43,6 +45,28 @@ def get_table_python_types(self, database, table):
dtypes = self.get_table_athena_types(database=database, table=table)
return {k: Glue.type_athena2python(v) for k, v in dtypes.items()}

@staticmethod
def type_pyarrow2athena(dtype):
dtype = str(dtype).lower()
if dtype == "int32":
return "int"
elif dtype == "int64":
return "bigint"
elif dtype == "float":
return "float"
elif dtype == "double":
return "double"
elif dtype == "bool":
return "boolean"
elif dtype == "string":
return "string"
elif dtype.startswith("timestamp"):
return "timestamp"
elif dtype.startswith("date"):
return "date"
else:
raise UnsupportedType(f"Unsupported Pyarrow type: {dtype}")

@staticmethod
def type_pandas2athena(dtype):
dtype = dtype.lower()
Expand All @@ -58,7 +82,7 @@ def type_pandas2athena(dtype):
return "boolean"
elif dtype == "object":
return "string"
elif dtype[:10] == "datetime64":
elif dtype.startswith("datetime64"):
return "timestamp"
else:
raise UnsupportedType(f"Unsupported Pandas type: {dtype}")
Expand Down Expand Up @@ -113,8 +137,7 @@ def metadata_to_glue(self,
extra_args=None):
schema = Glue._build_schema(dataframe=dataframe,
partition_cols=partition_cols,
preserve_index=preserve_index,
cast_columns=cast_columns)
preserve_index=preserve_index)
table = table if table else Glue._parse_table_name(path)
table = table.lower().replace(".", "_")
if mode == "overwrite":
Expand Down Expand Up @@ -198,31 +221,38 @@ def get_connection_details(self, name):
Name=name, HidePassword=False)["Connection"]

@staticmethod
def _extract_pyarrow_schema(dataframe, preserve_index):
    """Infer the dataframe's schema as a list of (name, type_string) tuples.

    Pandas' nullable "Int64" extension dtype is handled manually (mapped
    straight to "int64") and excluded from the Pyarrow inference pass,
    since pyarrow.Schema.from_pandas is given only the remaining columns.

    :param dataframe: Pandas DataFrame to inspect
    :param preserve_index: forwarded to Pyarrow so the index is included
        in the inferred schema
    :return: list of (column_name, pyarrow_type_string) tuples
    """
    cols = []
    schema = []
    for name, dtype in dataframe.dtypes.to_dict().items():
        if str(dtype) == "Int64":
            # Nullable integer extension dtype: map it directly.
            schema.append((name, "int64"))
        else:
            cols.append(name)

    # Convert pyarrow.Schema to list of tuples (e.g. [(name1, type1), (name2, type2)...])
    schema += [(str(x.name), str(x.type))
               for x in pyarrow.Schema.from_pandas(
                   df=dataframe[cols], preserve_index=preserve_index)]
    logger.debug(f"schema: {schema}")
    return schema

@staticmethod
def _build_schema(dataframe, partition_cols, preserve_index):
    """Build the Athena/Glue schema for *dataframe*.

    Types are inferred through Pyarrow (see _extract_pyarrow_schema) and
    translated with type_pyarrow2athena. Partition columns are excluded
    from the result because Glue stores them separately from the regular
    table columns.

    :param dataframe: Pandas DataFrame to inspect
    :param partition_cols: iterable of partition column names (may be
        None/empty, meaning no partitions)
    :param preserve_index: include the DataFrame index in the schema
    :return: list of (column_name, athena_type) tuples
    """
    logger.debug(f"dataframe.dtypes:\n{dataframe.dtypes}")
    if not partition_cols:
        partition_cols = []

    pyarrow_schema = Glue._extract_pyarrow_schema(
        dataframe=dataframe, preserve_index=preserve_index)

    schema_built = []
    for name, dtype in pyarrow_schema:
        if name not in partition_cols:
            athena_type = Glue.type_pyarrow2athena(dtype)
            schema_built.append((name, athena_type))

    logger.debug(f"schema_built:\n{schema_built}")
    return schema_built

Expand Down
6 changes: 4 additions & 2 deletions awswrangler/pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,14 +419,16 @@ def read_sql_athena(self,
message_error = f"Query error: {reason}"
raise AthenaQueryError(message_error)
else:
dtype, parse_dates = self._session.athena.get_query_dtype(
dtype, parse_timestamps, parse_dates = self._session.athena.get_query_dtype(
query_execution_id=query_execution_id)
path = f"{s3_output}{query_execution_id}.csv"
ret = self.read_csv(path=path,
dtype=dtype,
parse_dates=parse_dates,
parse_dates=parse_timestamps,
quoting=csv.QUOTE_ALL,
max_result_size=max_result_size)
for col in parse_dates:
ret[col] = ret[col].dt.date
return ret

def to_csv(
Expand Down
53 changes: 53 additions & 0 deletions testing/test_awswrangler/test_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,56 @@ def test_to_csv_with_sep(
assert len(list(dataframe.columns)) == len(list(dataframe2.columns))
assert dataframe[dataframe["id"] == 0].iloc[0]["name"] == dataframe2[
dataframe2["id"] == 0].iloc[0]["name"]


# Integration test (requires live AWS session/bucket/database fixtures):
# round-trips a DataFrame through Parquet + Glue and reads it back via
# Athena, checking that each Athena type maps back to the expected
# Python type — in particular that "date" columns survive as datetime.date.
@pytest.mark.parametrize("index", [None, "default", "my_date", "my_timestamp"])
def test_to_parquet_types(session, bucket, database, index):
    dataframe = pandas.read_csv("data_samples/complex.csv",
                                dtype={"my_int_with_null": "Int64"},
                                parse_dates=["my_timestamp", "my_date"])
    # parse_dates yields datetime64; downcast my_date to plain datetime.date
    dataframe["my_date"] = dataframe["my_date"].dt.date
    dataframe["my_bool"] = True

    preserve_index = True
    if not index:
        preserve_index = False
    elif index != "default":
        # Promote the chosen column to the index so its type is exercised
        # through the preserve_index path as well.
        dataframe["new_index"] = dataframe[index]
        dataframe = dataframe.set_index("new_index")

    session.pandas.to_parquet(dataframe=dataframe,
                              database=database,
                              path=f"s3://{bucket}/test/",
                              preserve_index=preserve_index,
                              mode="overwrite",
                              procs_cpu_bound=1)
    sleep(1)  # NOTE(review): presumably lets Glue/Athena register the table — confirm
    dataframe2 = session.pandas.read_sql_athena(sql="select * from test",
                                                database=database)
    for row in dataframe2.itertuples():
        if index:
            if index == "default":
                # positional access: assumes the preserved default index is
                # the 8th itertuples field — TODO confirm column order
                assert isinstance(row[8], numpy.int64)
            elif index == "my_date":
                assert isinstance(row.new_index, date)
            elif index == "my_timestamp":
                assert isinstance(row.new_index, datetime)
        assert isinstance(row.my_timestamp, datetime)
        # type(...) == date (not isinstance) to reject datetime subclasses:
        # the value must be a plain date, not a datetime.
        assert type(row.my_date) == date
        assert isinstance(row.my_float, float)
        assert isinstance(row.my_int, numpy.int64)
        assert isinstance(row.my_string, str)
        assert isinstance(row.my_bool, bool)
        assert str(row.my_timestamp) == "2018-01-01 04:03:02.001000"
        assert str(row.my_date) == "2019-02-02"
        assert str(row.my_float) == "12345.6789"
        assert str(row.my_int) == "123456789"
        assert str(row.my_bool) == "True"
        assert str(
            row.my_string
        ) == "foo\nboo\nbar\nFOO\nBOO\nBAR\nxxxxx\nÁÃÀÂÇ\n汉字汉字汉字汉字汉字汉字汉字æøåæøåæøåæøåæøåæøåæøåæøåæøåæøå汉字汉字汉字汉字汉字汉字汉字æøåæøåæøåæøåæøåæøåæøåæøåæøåæøå"
    assert len(dataframe.index) == len(dataframe2.index)
    if index:
        # preserving the index adds one extra column on the Athena side
        assert (len(list(dataframe.columns)) + 1) == len(list(dataframe2.columns))
    else:
        assert len(list(dataframe.columns)) == len(list(dataframe2.columns))