13 changes: 12 additions & 1 deletion awswrangler/catalog/_add.py
@@ -48,10 +48,12 @@ def add_csv_partitions(
catalog_id: Optional[str] = None,
compression: Optional[str] = None,
sep: str = ",",
serde_library: Optional[str] = None,
serde_parameters: Optional[Dict[str, str]] = None,
boto3_session: Optional[boto3.Session] = None,
columns_types: Optional[Dict[str, str]] = None,
) -> None:
"""Add partitions (metadata) to a CSV Table in the AWS Glue Catalog.
r"""Add partitions (metadata) to a CSV Table in the AWS Glue Catalog.

Parameters
----------
@@ -73,6 +75,13 @@ def add_csv_partitions(
Compression style (``None``, ``gzip``, etc).
sep : str
String of length 1. Field delimiter for the output file.
serde_library : Optional[str]
Specifies the SerDe serialization library to use, provided as a fully qualified class name (a string).
If no library is provided, the default is `org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe`.
serde_parameters : Optional[Dict[str, str]]
Dictionary of initialization parameters for the SerDe.
The default is `{"field.delim": sep, "escape.delim": "\\"}`.
boto3_session : boto3.Session(), optional
Boto3 Session. The default boto3 session will be used if boto3_session receives None.
columns_types: Optional[Dict[str, str]]
@@ -107,6 +116,8 @@ def add_csv_partitions(
compression=compression,
sep=sep,
columns_types=columns_types,
serde_library=serde_library,
serde_parameters=serde_parameters,
)
for k, v in partitions_values.items()
]
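For reference, a minimal sketch of how the extended `add_csv_partitions` call might look; the database, table, bucket, and partition values below are placeholders for illustration, not taken from this PR:

```python
import awswrangler as wr

# Hypothetical names/paths for illustration only.
wr.catalog.add_csv_partitions(
    database="my_db",
    table="my_table",
    # Maps each partition's S3 location to its list of partition values.
    partitions_values={"s3://my-bucket/my_table/col2=A/": ["A"]},
    # New optional arguments from this PR; both fall back to the
    # LazySimpleSerDe defaults when omitted.
    serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde",
    serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"},
)
```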
13 changes: 9 additions & 4 deletions awswrangler/catalog/_definitions.py
@@ -152,19 +152,24 @@ def _csv_partition_definition(
bucketing_info: Optional[Tuple[List[str], int]],
compression: Optional[str],
sep: str,
serde_library: Optional[str],
serde_parameters: Optional[Dict[str, str]],
columns_types: Optional[Dict[str, str]],
) -> Dict[str, Any]:
compressed: bool = compression is not None
serde_info = {
"SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
if serde_library is None
else serde_library,
"Parameters": {"field.delim": sep, "escape.delim": "\\"} if serde_parameters is None else serde_parameters,
}
Comment on lines +160 to +165
Contributor Author: I believe we missed this method in the previous PR.

Contributor Author: @maxispeicher, apologies, I am not sure whether I already tagged you on this PR for review?

Contributor: I'm not sure either. If so, I completely missed it 🙈 I also thought that I had included it, but it seems I forgot about it eventually 😓.

definition: Dict[str, Any] = {
"StorageDescriptor": {
"InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
"OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
"Location": location,
"Compressed": compressed,
"SerdeInfo": {
"Parameters": {"field.delim": sep, "escape.delim": "\\"},
"SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
},
"SerdeInfo": serde_info,
"StoredAsSubDirectories": False,
"NumberOfBuckets": -1 if bucketing_info is None else bucketing_info[1],
"BucketColumns": [] if bucketing_info is None else bucketing_info[0],
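The fallback above (use the caller's SerDe if given, otherwise Hive's LazySimpleSerDe with delimiter-based parameters) can be summarized in a standalone sketch; `build_serde_info` is a hypothetical helper name, not part of the library:

```python
from typing import Any, Dict, Optional


def build_serde_info(
    sep: str,
    serde_library: Optional[str] = None,
    serde_parameters: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    # Mirrors the fallback in _csv_partition_definition: caller-supplied
    # values win; otherwise Glue gets LazySimpleSerDe plus delimiter parameters.
    return {
        "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
        if serde_library is None
        else serde_library,
        "Parameters": {"field.delim": sep, "escape.delim": "\\"}
        if serde_parameters is None
        else serde_parameters,
    }


# Defaults apply when nothing is passed:
assert build_serde_info(",")["Parameters"] == {"field.delim": ",", "escape.delim": "\\"}
```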
11 changes: 9 additions & 2 deletions awswrangler/s3/_write_text.py
@@ -501,6 +501,11 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
)
serde_info: Dict[str, Any] = {}
if catalog_table_input:
serde_info = catalog_table_input["StorageDescriptor"]["SerdeInfo"]
serde_library: Optional[str] = serde_info.get("SerializationLibrary", None)
serde_parameters: Optional[Dict[str, str]] = serde_info.get("Parameters", None)
Comment on lines +504 to +508
Contributor Author: In the to_csv method, instead of passing None to the _create_csv_table and add_csv_partitions methods, we extract the SerDe info from the existing catalog table, if it exists.

catalog._create_csv_table( # pylint: disable=protected-access
database=database,
table=table,
@@ -525,8 +530,8 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
catalog_id=catalog_id,
compression=pandas_kwargs.get("compression"),
skip_header_line_count=None,
serde_library=None,
serde_parameters=None,
serde_library=serde_library,
serde_parameters=serde_parameters,
)
if partitions_values and (regular_partitions is True):
_logger.debug("partitions_values:\n%s", partitions_values)
@@ -537,6 +542,8 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-statements
bucketing_info=bucketing_info,
boto3_session=session,
sep=sep,
serde_library=serde_library,
serde_parameters=serde_parameters,
catalog_id=catalog_id,
columns_types=columns_types,
compression=pandas_kwargs.get("compression"),
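In effect, a table registered with a custom SerDe should keep that SerDe when `wr.s3.to_csv` writes to it again. A hedged end-to-end sketch, assuming placeholder bucket, database, and table names:

```python
import awswrangler as wr
import pandas as pd

path = "s3://my-bucket/my_table/"  # placeholder

# Register the table once with a non-default SerDe.
wr.catalog.create_csv_table(
    database="my_db",
    table="my_table",
    path=path,
    columns_types={"col": "string"},
    serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde",
    serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"},
)

# With this PR, to_csv reads SerdeInfo from the existing catalog table input
# and passes it back, instead of overwriting it with the None defaults.
wr.s3.to_csv(
    df=pd.DataFrame({"col": ["1", "2"]}),
    path=path,
    dataset=True,
    database="my_db",
    table="my_table",
    mode="append",
    index=False,
    header=False,
)
```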
24 changes: 20 additions & 4 deletions tests/test_athena_csv.py
@@ -451,15 +451,31 @@ def test_csv_compressed(path, glue_table, glue_database, use_threads, concurrent
@pytest.mark.parametrize("use_threads", [True, False])
@pytest.mark.parametrize("ctas_approach", [True, False])
def test_opencsv_serde(path, glue_table, glue_database, use_threads, ctas_approach):
df = pd.DataFrame({"c0": ['"1"', '"2"', '"3"'], "c1": ['"4"', '"5"', '"6"'], "c2": ['"a"', '"b"', '"c"']})
wr.s3.to_csv(
df=df, path=f"{path}0.csv", sep=",", index=False, header=False, use_threads=use_threads, quoting=csv.QUOTE_NONE
df = pd.DataFrame({"col": ["1", "2", "3"], "col2": ["A", "A", "B"]})
response = wr.s3.to_csv(
df=df,
path=path,
dataset=True,
partition_cols=["col2"],
sep=",",
index=False,
header=False,
use_threads=use_threads,
quoting=csv.QUOTE_NONE,
)
wr.catalog.create_csv_table(
database=glue_database,
table=glue_table,
path=path,
columns_types={"c0": "string", "c1": "string", "c2": "string"},
columns_types={"col": "string"},
partitions_types={"col2": "string"},
serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde",
serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"},
)
wr.catalog.add_csv_partitions(
database=glue_database,
table=glue_table,
partitions_values=response["partitions_values"],
Comment on lines +475 to +478
Contributor Author: Extended the test to include add_csv_partitions.

serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde",
serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"},
)
40 changes: 29 additions & 11 deletions tests/test_catalog.py
@@ -157,11 +157,11 @@ def test_catalog_get_databases(glue_database):
assert db["Description"] == "AWS Data Wrangler Test Arena - Glue Database"


def test_catalog_versioning(path, glue_database, glue_table):
def test_catalog_versioning(path, glue_database, glue_table, glue_table2):
wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table)
wr.s3.delete_objects(path=path)

# Version 0
# Version 1 - Parquet
df = pd.DataFrame({"c0": [1, 2]})
wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table, mode="overwrite")[
"paths"
@@ -172,7 +172,7 @@ def test_catalog_versioning(path, glue_database, glue_table):
assert len(df.columns) == 1
assert str(df.c0.dtype).startswith("Int")

# Version 1
# Version 2 - Parquet
df = pd.DataFrame({"c1": ["foo", "boo"]})
wr.s3.to_parquet(
df=df,
@@ -189,38 +189,56 @@ def test_catalog_versioning(path, glue_database, glue_table):
assert len(df.columns) == 1
assert str(df.c1.dtype) == "string"

# Version 2
# Version 1 - CSV
df = pd.DataFrame({"c1": [1.0, 2.0]})
wr.s3.to_csv(
df=df,
path=path,
dataset=True,
database=glue_database,
table=glue_table,
table=glue_table2,
mode="overwrite",
catalog_versioning=True,
index=False,
)
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 3
df = wr.athena.read_sql_table(table=glue_table, database=glue_database)
assert wr.catalog.get_table_number_of_versions(table=glue_table2, database=glue_database) == 1
df = wr.athena.read_sql_table(table=glue_table2, database=glue_database)
assert len(df.index) == 2
assert len(df.columns) == 1
assert str(df.c1.dtype).startswith("float")

# Version 3 (removing version 2)
# Version 1 - CSV (No evolution)
df = pd.DataFrame({"c1": [True, False]})
wr.s3.to_csv(
df=df,
path=path,
dataset=True,
database=glue_database,
table=glue_table,
table=glue_table2,
mode="overwrite",
catalog_versioning=False,
index=False,
)
assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 3
df = wr.athena.read_sql_table(table=glue_table, database=glue_database)
assert wr.catalog.get_table_number_of_versions(table=glue_table2, database=glue_database) == 1
df = wr.athena.read_sql_table(table=glue_table2, database=glue_database)
assert len(df.index) == 2
assert len(df.columns) == 1
assert str(df.c1.dtype).startswith("boolean")

# Version 2 - CSV
df = pd.DataFrame({"c1": [True, False]})
wr.s3.to_csv(
df=df,
path=path,
dataset=True,
database=glue_database,
table=glue_table2,
mode="overwrite",
catalog_versioning=True,
index=False,
)
assert wr.catalog.get_table_number_of_versions(table=glue_table2, database=glue_database) == 2
df = wr.athena.read_sql_table(table=glue_table2, database=glue_database)
assert len(df.index) == 2
assert len(df.columns) == 1
assert str(df.c1.dtype).startswith("boolean")
Expand Down