54 changes: 31 additions & 23 deletions awswrangler/glue.py
@@ -135,9 +135,10 @@ def metadata_to_glue(self,
mode="append",
cast_columns=None,
extra_args=None):
schema = Glue._build_schema(dataframe=dataframe,
partition_cols=partition_cols,
preserve_index=preserve_index)
schema, partition_cols_schema = Glue._build_schema(
dataframe=dataframe,
partition_cols=partition_cols,
preserve_index=preserve_index)
table = table if table else Glue._parse_table_name(path)
table = table.lower().replace(".", "_")
if mode == "overwrite":
@@ -147,7 +148,7 @@ def metadata_to_glue(self,
self.create_table(database=database,
table=table,
schema=schema,
partition_cols=partition_cols,
partition_cols_schema=partition_cols_schema,
path=path,
file_format=file_format,
extra_args=extra_args)
@@ -180,14 +181,13 @@ def create_table(self,
schema,
path,
file_format,
partition_cols=None,
partition_cols_schema=None,
extra_args=None):
if file_format == "parquet":
table_input = Glue.parquet_table_definition(
table, partition_cols, schema, path)
table_input = Glue.parquet_table_definition(table, partition_cols_schema, schema, path)
elif file_format == "csv":
table_input = Glue.csv_table_definition(table,
partition_cols,
partition_cols_schema,
schema,
path,
extra_args=extra_args)
@@ -248,13 +248,20 @@ def _build_schema(dataframe, partition_cols, preserve_index):
dataframe=dataframe, preserve_index=preserve_index)

schema_built = []
partition_cols_types = {}
for name, dtype in pyarrow_schema:
if name not in partition_cols:
athena_type = Glue.type_pyarrow2athena(dtype)
athena_type = Glue.type_pyarrow2athena(dtype)
if name in partition_cols:
partition_cols_types[name] = athena_type
else:
schema_built.append((name, athena_type))

partition_cols_schema_built = [(name, partition_cols_types[name]) for name in partition_cols]

logger.debug(f"schema_built:\n{schema_built}")
return schema_built
logger.debug(
f"partition_cols_schema_built:\n{partition_cols_schema_built}")
return schema_built, partition_cols_schema_built

@staticmethod
def _parse_table_name(path):
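To make the new contract concrete, here is a minimal standalone sketch of the split that _build_schema now performs; the function name and sample columns below are illustrative only, and it assumes the pyarrow types have already been mapped to Athena type strings:

# Hypothetical sketch, not part of awswrangler: returns the regular columns and the
# partition columns as two lists of (name, athena_type) pairs, keeping the partition
# columns in the order the caller declared them.
def split_schema(columns, partition_cols):
    partition_types = {}
    schema_built = []
    for name, athena_type in columns:
        if name in partition_cols:
            partition_types[name] = athena_type
        else:
            schema_built.append((name, athena_type))
    partition_cols_schema = [(name, partition_types[name]) for name in partition_cols]
    return schema_built, partition_cols_schema

# Example: the partition column keeps its inferred type instead of defaulting to "string".
print(split_schema([("id", "bigint"), ("name", "string"), ("my_date", "date")], ["my_date"]))
# ([('id', 'bigint'), ('name', 'string')], [('my_date', 'date')])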
@@ -263,17 +270,17 @@ def _parse_table_name(path):
return path.rpartition("/")[2]

@staticmethod
def csv_table_definition(table, partition_cols, schema, path, extra_args):
def csv_table_definition(table, partition_cols_schema, schema, path, extra_args):
sep = extra_args["sep"] if "sep" in extra_args else ","
if not partition_cols:
partition_cols = []
if not partition_cols_schema:
partition_cols_schema = []
return {
"Name":
table,
"PartitionKeys": [{
"Name": x,
"Type": "string"
} for x in partition_cols],
"Name": x[0],
"Type": x[1]
} for x in partition_cols_schema],
"TableType":
"EXTERNAL_TABLE",
"Parameters": {
@@ -334,16 +341,17 @@ def csv_partition_definition(partition):
}

@staticmethod
def parquet_table_definition(table, partition_cols, schema, path):
if not partition_cols:
partition_cols = []
def parquet_table_definition(table, partition_cols_schema,
schema, path):
if not partition_cols_schema:
partition_cols_schema = []
return {
"Name":
table,
"PartitionKeys": [{
"Name": x,
"Type": "string"
} for x in partition_cols],
"Name": x[0],
"Type": x[1]
} for x in partition_cols_schema],
"TableType":
"EXTERNAL_TABLE",
"Parameters": {
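The net effect on the generated Glue table input, sketched below with hypothetical values rather than anything taken from this diff: partition keys are now registered with the Athena type inferred from the dataframe, where previously every partition key was hard-coded as "string".

# Illustrative only: how csv_table_definition / parquet_table_definition now build PartitionKeys.
partition_cols_schema = [("my_date", "date"), ("my_int", "bigint")]

# Old behaviour: every partition column became a string partition key.
old_partition_keys = [{"Name": name, "Type": "string"} for name, _ in partition_cols_schema]

# New behaviour: the inferred Athena type is preserved per partition column.
new_partition_keys = [{"Name": name, "Type": athena_type}
                      for name, athena_type in partition_cols_schema]

print(old_partition_keys)  # [{'Name': 'my_date', 'Type': 'string'}, {'Name': 'my_int', 'Type': 'string'}]
print(new_partition_keys)  # [{'Name': 'my_date', 'Type': 'date'}, {'Name': 'my_int', 'Type': 'bigint'}]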
41 changes: 38 additions & 3 deletions testing/test_awswrangler/test_pandas.py
@@ -491,8 +491,41 @@ def test_to_csv_with_sep(
dataframe2["id"] == 0].iloc[0]["name"]


@pytest.mark.parametrize("index", [None, "default", "my_date", "my_timestamp"])
def test_to_parquet_types(session, bucket, database, index):
@pytest.mark.parametrize("index, partition_cols", [
(None, []),
("default", []),
("my_date", []),
("my_timestamp", []),
(None, ["my_int"]),
("default", ["my_int"]),
("my_date", ["my_int"]),
("my_timestamp", ["my_int"]),
(None, ["my_float"]),
("default", ["my_float"]),
("my_date", ["my_float"]),
("my_timestamp", ["my_float"]),
(None, ["my_bool"]),
("default", ["my_bool"]),
("my_date", ["my_bool"]),
("my_timestamp", ["my_bool"]),
(None, ["my_date"]),
("default", ["my_date"]),
("my_date", ["my_date"]),
("my_timestamp", ["my_date"]),
(None, ["my_timestamp"]),
("default", ["my_timestamp"]),
("my_date", ["my_timestamp"]),
("my_timestamp", ["my_timestamp"]),
(None, ["my_timestamp", "my_date"]),
("default", ["my_date", "my_timestamp"]),
("my_date", ["my_timestamp", "my_date"]),
("my_timestamp", ["my_date", "my_timestamp"]),
(None, ["my_bool", "my_timestamp", "my_date"]),
("default", ["my_date", "my_timestamp", "my_int"]),
("my_date", ["my_timestamp", "my_float", "my_date"]),
("my_timestamp", ["my_int", "my_float", "my_bool", "my_date", "my_timestamp"]),
])
def test_to_parquet_types(session, bucket, database, index, partition_cols):
dataframe = pandas.read_csv("data_samples/complex.csv",
dtype={"my_int_with_null": "Int64"},
parse_dates=["my_timestamp", "my_date"])
@@ -510,6 +543,7 @@ def test_to_parquet_types(session, bucket, database, index):
database=database,
path=f"s3://{bucket}/test/",
preserve_index=preserve_index,
partition_cols=partition_cols,
mode="overwrite",
procs_cpu_bound=1)
sleep(1)
@@ -518,7 +552,8 @@ def test_to_parquet_types(session, bucket, database, index):
for row in dataframe2.itertuples():
if index:
if index == "default":
assert isinstance(row[8], numpy.int64)
ex_index_col = 8 - len(partition_cols)
assert isinstance(row[ex_index_col], numpy.int64)
elif index == "my_date":
assert isinstance(row.new_index, date)
elif index == "my_timestamp":
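One note on the adjusted assertion above: when the table is read back, partition columns come after the regular data columns (they belong to the table's partition schema rather than its column schema), so the auto-generated index column shifts one position to the left for each column promoted to a partition key, which is why the test computes its position as 8 - len(partition_cols). A trivial sketch of that arithmetic, with hypothetical names:

# Hypothetical helper, not part of the test suite: position of the generated index
# column in DataFrame.itertuples() rows, where row[0] is the pandas Index and the
# partition columns trail the regular data columns in the read-back dataframe.
def expected_index_position(position_without_partitions, partition_cols):
    return position_without_partitions - len(partition_cols)

assert expected_index_position(8, []) == 8
assert expected_index_position(8, ["my_int", "my_float"]) == 6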