From 99e92cdec54224e31480fd882420ff303c34c6ac Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Sat, 1 May 2021 14:51:10 +0200 Subject: [PATCH] Add serde parameters to csv table creation --- awswrangler/catalog/_create.py | 17 +++++++++++++++- awswrangler/catalog/_definitions.py | 24 +++++++++++------------ awswrangler/s3/_write_text.py | 2 ++ tests/test_athena_csv.py | 30 ++++++++++++++++++++++++++--- 4 files changed, 56 insertions(+), 17 deletions(-) diff --git a/awswrangler/catalog/_create.py b/awswrangler/catalog/_create.py index 3ea83e62b..262632472 100644 --- a/awswrangler/catalog/_create.py +++ b/awswrangler/catalog/_create.py @@ -296,6 +296,8 @@ def _create_csv_table( catalog_versioning: bool, sep: str, skip_header_line_count: Optional[int], + serde_library: Optional[str], + serde_parameters: Optional[Dict[str, str]], boto3_session: Optional[boto3.Session], projection_enabled: bool, projection_types: Optional[Dict[str, str]], @@ -329,6 +331,8 @@ def _create_csv_table( compression=compression, sep=sep, skip_header_line_count=skip_header_line_count, + serde_library=serde_library, + serde_parameters=serde_parameters, ) table_exist: bool = catalog_table_input is not None _logger.debug("table_exist: %s", table_exist) @@ -670,6 +674,8 @@ def create_csv_table( catalog_versioning: bool = False, sep: str = ",", skip_header_line_count: Optional[int] = None, + serde_library: Optional[str] = None, + serde_parameters: Optional[Dict[str, str]] = None, boto3_session: Optional[boto3.Session] = None, projection_enabled: bool = False, projection_types: Optional[Dict[str, str]] = None, @@ -679,7 +685,7 @@ def create_csv_table( projection_digits: Optional[Dict[str, str]] = None, catalog_id: Optional[str] = None, ) -> None: - """Create a CSV Table (Metadata Only) in the AWS Glue Catalog. + r"""Create a CSV Table (Metadata Only) in the AWS Glue Catalog. 
'https://docs.aws.amazon.com/athena/latest/ug/data-types.html' @@ -715,6 +721,13 @@ def create_csv_table( String of length 1. Field delimiter for the output file. skip_header_line_count : Optional[int] Number of Lines to skip regarding to the header. + serde_library : Optional[str] + Specifies the SerDe Serialization library which will be used. You need to provide the Class library name + as a string. + If no library is provided the default is `org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe`. + serde_parameters : Optional[Dict[str, str]] + Dictionary of initialization parameters for the SerDe. + The default is `{"field.delim": sep, "escape.delim": "\\"}`. projection_enabled : bool Enable Partition Projection on Athena (https://docs.aws.amazon.com/athena/latest/ug/partition-projection.html) projection_types : Optional[Dict[str, str]] @@ -793,4 +806,6 @@ def create_csv_table( catalog_table_input=catalog_table_input, sep=sep, skip_header_line_count=skip_header_line_count, + serde_library=serde_library, + serde_parameters=serde_parameters, ) diff --git a/awswrangler/catalog/_definitions.py b/awswrangler/catalog/_definitions.py index 778d428dd..60d617cf1 100644 --- a/awswrangler/catalog/_definitions.py +++ b/awswrangler/catalog/_definitions.py @@ -105,6 +105,8 @@ def _csv_table_definition( compression: Optional[str], sep: str, skip_header_line_count: Optional[int], + serde_library: Optional[str], + serde_parameters: Optional[Dict[str, str]], ) -> Dict[str, Any]: compressed: bool = compression is not None parameters: Dict[str, str] = { @@ -116,7 +118,13 @@ def _csv_table_definition( "areColumnsQuoted": "false", } if skip_header_line_count is not None: - parameters["skip.header.line.count"] = "1" + parameters["skip.header.line.count"] = str(skip_header_line_count) + serde_info = { + "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + if serde_library is None + else serde_library, + "Parameters": {"field.delim": sep, "escape.delim": "\\"} if 
serde_parameters is None else serde_parameters, + } return { "Name": table, "PartitionKeys": [{"Name": cname, "Type": dtype} for cname, dtype in partitions_types.items()], @@ -129,21 +137,11 @@ def _csv_table_definition( "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat", "Compressed": compressed, "NumberOfBuckets": -1 if bucketing_info is None else bucketing_info[1], - "SerdeInfo": { - "Parameters": {"field.delim": sep, "escape.delim": "\\"}, - "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe", - }, + "SerdeInfo": serde_info, "BucketColumns": [] if bucketing_info is None else bucketing_info[0], "StoredAsSubDirectories": False, "SortColumns": [], - "Parameters": { - "classification": "csv", - "compressionType": str(compression).lower(), - "typeOfData": "file", - "delimiter": sep, - "columnsOrdered": "true", - "areColumnsQuoted": "false", - }, + "Parameters": parameters, }, } diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py index e2d832e35..48a8256c6 100644 --- a/awswrangler/s3/_write_text.py +++ b/awswrangler/s3/_write_text.py @@ -525,6 +525,8 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state catalog_id=catalog_id, compression=pandas_kwargs.get("compression"), skip_header_line_count=None, + serde_library=None, + serde_parameters=None, ) if partitions_values and (regular_partitions is True): _logger.debug("partitions_values:\n%s", partitions_values) diff --git a/tests/test_athena_csv.py b/tests/test_athena_csv.py index f62613421..4af697421 100644 --- a/tests/test_athena_csv.py +++ b/tests/test_athena_csv.py @@ -1,3 +1,4 @@ +import csv import logging from sys import version_info @@ -337,7 +338,8 @@ def test_athena_csv_types(path, glue_database, glue_table): @pytest.mark.parametrize("use_threads", [True, False]) @pytest.mark.parametrize("ctas_approach", [True, False]) -def test_skip_header(path, glue_database, glue_table, use_threads, ctas_approach): 
+@pytest.mark.parametrize("line_count", [1, 2]) +def test_skip_header(path, glue_database, glue_table, use_threads, ctas_approach, line_count): df = pd.DataFrame({"c0": [1, 2], "c1": [3.3, 4.4], "c2": ["foo", "boo"]}) df["c0"] = df["c0"].astype("Int64") df["c2"] = df["c2"].astype("string") @@ -347,10 +349,10 @@ def test_skip_header(path, glue_database, glue_table, use_threads, ctas_approach table=glue_table, path=path, columns_types={"c0": "bigint", "c1": "double", "c2": "string"}, - skip_header_line_count=1, + skip_header_line_count=line_count, ) df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads, ctas_approach=ctas_approach) - assert df.equals(df2) + assert df.iloc[line_count - 1 :].reset_index(drop=True).equals(df2) @pytest.mark.parametrize("use_threads", [True, False]) @@ -444,3 +446,25 @@ def test_csv_compressed(path, glue_table, glue_database, use_threads, concurrent assert df2["id"].sum() == 6 ensure_data_types_csv(df2) assert wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) is True + + +@pytest.mark.parametrize("use_threads", [True, False]) +@pytest.mark.parametrize("ctas_approach", [True, False]) +def test_opencsv_serde(path, glue_table, glue_database, use_threads, ctas_approach): + df = pd.DataFrame({"c0": ['"1"', '"2"', '"3"'], "c1": ['"4"', '"5"', '"6"'], "c2": ['"a"', '"b"', '"c"']}) + wr.s3.to_csv( + df=df, path=f"{path}0.csv", sep=",", index=False, header=False, use_threads=use_threads, quoting=csv.QUOTE_NONE + ) + wr.catalog.create_csv_table( + database=glue_database, + table=glue_table, + path=path, + columns_types={"c0": "string", "c1": "string", "c2": "string"}, + serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde", + serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"}, + ) + df2 = wr.athena.read_sql_table( + table=glue_table, database=glue_database, use_threads=use_threads, ctas_approach=ctas_approach + ) + df = df.applymap(lambda x: x.replace('"', 
"")).convert_dtypes() + assert df.equals(df2)