From 09ec0007e5810b195eae0f45f1d6f9789acd892c Mon Sep 17 00:00:00 2001
From: Abdel Jaidi
Date: Mon, 3 May 2021 21:53:48 +0100
Subject: [PATCH 1/2] Extending logic to add_csv_partitions and leveraging catalog_table_input

---
 awswrangler/catalog/_add.py         | 13 ++++++++++++-
 awswrangler/catalog/_definitions.py | 13 +++++++++----
 awswrangler/s3/_write_text.py       | 11 +++++++++--
 tests/test_athena_csv.py            | 24 ++++++++++++++++++++----
 4 files changed, 50 insertions(+), 11 deletions(-)

diff --git a/awswrangler/catalog/_add.py b/awswrangler/catalog/_add.py
index 8d9407221..01f30200e 100644
--- a/awswrangler/catalog/_add.py
+++ b/awswrangler/catalog/_add.py
@@ -48,10 +48,12 @@ def add_csv_partitions(
     catalog_id: Optional[str] = None,
     compression: Optional[str] = None,
     sep: str = ",",
+    serde_library: Optional[str] = None,
+    serde_parameters: Optional[Dict[str, str]] = None,
     boto3_session: Optional[boto3.Session] = None,
     columns_types: Optional[Dict[str, str]] = None,
 ) -> None:
-    """Add partitions (metadata) to a CSV Table in the AWS Glue Catalog.
+    r"""Add partitions (metadata) to a CSV Table in the AWS Glue Catalog.

     Parameters
     ----------
@@ -73,6 +75,13 @@ def add_csv_partitions(
         Compression style (``None``, ``gzip``, etc).
     sep : str
         String of length 1. Field delimiter for the output file.
+    serde_library : Optional[str]
+        Specifies the SerDe serialization library to use, given as the fully
+        qualified class name (a string).
+        If no library is provided, the default is `org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe`.
+    serde_parameters : Optional[Dict[str, str]]
+        Dictionary of initialization parameters for the SerDe.
+        The default is `{"field.delim": sep, "escape.delim": "\\"}`.
     boto3_session : boto3.Session(), optional
         Boto3 Session. The default boto3 session will be used if boto3_session receive None.
     columns_types: Optional[Dict[str, str]]
@@ -107,6 +116,8 @@ def add_csv_partitions(
             compression=compression,
             sep=sep,
             columns_types=columns_types,
+            serde_library=serde_library,
+            serde_parameters=serde_parameters,
         )
         for k, v in partitions_values.items()
     ]
diff --git a/awswrangler/catalog/_definitions.py b/awswrangler/catalog/_definitions.py
index 60d617cf1..b1ee71bfd 100644
--- a/awswrangler/catalog/_definitions.py
+++ b/awswrangler/catalog/_definitions.py
@@ -152,19 +152,24 @@ def _csv_partition_definition(
     bucketing_info: Optional[Tuple[List[str], int]],
     compression: Optional[str],
     sep: str,
+    serde_library: Optional[str],
+    serde_parameters: Optional[Dict[str, str]],
     columns_types: Optional[Dict[str, str]],
 ) -> Dict[str, Any]:
     compressed: bool = compression is not None
+    serde_info = {
+        "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
+        if serde_library is None
+        else serde_library,
+        "Parameters": {"field.delim": sep, "escape.delim": "\\"} if serde_parameters is None else serde_parameters,
+    }
     definition: Dict[str, Any] = {
         "StorageDescriptor": {
             "InputFormat": "org.apache.hadoop.mapred.TextInputFormat",
             "OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
             "Location": location,
             "Compressed": compressed,
-            "SerdeInfo": {
-                "Parameters": {"field.delim": sep, "escape.delim": "\\"},
-                "SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
-            },
+            "SerdeInfo": serde_info,
             "StoredAsSubDirectories": False,
             "NumberOfBuckets": -1 if bucketing_info is None else bucketing_info[1],
             "BucketColumns": [] if bucketing_info is None else bucketing_info[0],
diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py
index 48a8256c6..67e184433 100644
--- a/awswrangler/s3/_write_text.py
+++ b/awswrangler/s3/_write_text.py
@@ -501,6 +501,11 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
         columns_types, partitions_types = _data_types.athena_types_from_pandas_partitioned(
             df=df, index=index, partition_cols=partition_cols, dtype=dtype, index_left=True
         )
+        serde_info: Dict[str, Any] = {}
+        if catalog_table_input:
+            serde_info = catalog_table_input["StorageDescriptor"]["SerdeInfo"]
+        serde_library: Optional[str] = serde_info.get("SerializationLibrary", None)
+        serde_parameters: Optional[Dict[str, str]] = serde_info.get("Parameters", None)
         catalog._create_csv_table(  # pylint: disable=protected-access
             database=database,
             table=table,
@@ -525,8 +530,8 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
             catalog_id=catalog_id,
             compression=pandas_kwargs.get("compression"),
             skip_header_line_count=None,
-            serde_library=None,
-            serde_parameters=None,
+            serde_library=serde_library,
+            serde_parameters=serde_parameters,
         )
         if partitions_values and (regular_partitions is True):
             _logger.debug("partitions_values:\n%s", partitions_values)
@@ -537,6 +542,8 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
                 bucketing_info=bucketing_info,
                 boto3_session=session,
                 sep=sep,
+                serde_library=serde_library,
+                serde_parameters=serde_parameters,
                 catalog_id=catalog_id,
                 columns_types=columns_types,
                 compression=pandas_kwargs.get("compression"),
diff --git a/tests/test_athena_csv.py b/tests/test_athena_csv.py
index 4af697421..84bb7c52d 100644
--- a/tests/test_athena_csv.py
+++ b/tests/test_athena_csv.py
@@ -451,15 +451,31 @@ def test_csv_compressed(path, glue_table, glue_database, use_threads, concurrent
 @pytest.mark.parametrize("use_threads", [True, False])
@pytest.mark.parametrize("ctas_approach", [True, False]) def test_opencsv_serde(path, glue_table, glue_database, use_threads, ctas_approach): - df = pd.DataFrame({"c0": ['"1"', '"2"', '"3"'], "c1": ['"4"', '"5"', '"6"'], "c2": ['"a"', '"b"', '"c"']}) - wr.s3.to_csv( - df=df, path=f"{path}0.csv", sep=",", index=False, header=False, use_threads=use_threads, quoting=csv.QUOTE_NONE + df = pd.DataFrame({"col": ["1", "2", "3"], "col2": ["A", "A", "B"]}) + response = wr.s3.to_csv( + df=df, + path=path, + dataset=True, + partition_cols=["col2"], + sep=",", + index=False, + header=False, + use_threads=use_threads, + quoting=csv.QUOTE_NONE, ) wr.catalog.create_csv_table( database=glue_database, table=glue_table, path=path, - columns_types={"c0": "string", "c1": "string", "c2": "string"}, + columns_types={"col": "string"}, + partitions_types={"col2": "string"}, + serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde", + serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"}, + ) + wr.catalog.add_csv_partitions( + database=glue_database, + table=glue_table, + partitions_values=response["partitions_values"], serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde", serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"}, ) From 979f728cf3efdc8ac76f264ed66c8445f005158c Mon Sep 17 00:00:00 2001 From: Abdel Jaidi Date: Mon, 3 May 2021 22:57:43 +0100 Subject: [PATCH 2/2] Adapting catalog versioning test --- tests/test_catalog.py | 40 +++++++++++++++++++++++++++++----------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/tests/test_catalog.py b/tests/test_catalog.py index ac0982169..6db17a2d0 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -157,11 +157,11 @@ def test_catalog_get_databases(glue_database): assert db["Description"] == "AWS Data Wrangler Test Arena - Glue Database" -def test_catalog_versioning(path, glue_database, glue_table): +def test_catalog_versioning(path, glue_database, glue_table, glue_table2): wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) wr.s3.delete_objects(path=path) - # Version 0 + # Version 1 - Parquet df = pd.DataFrame({"c0": [1, 2]}) wr.s3.to_parquet(df=df, path=path, dataset=True, database=glue_database, table=glue_table, mode="overwrite")[ "paths" @@ -172,7 +172,7 @@ def test_catalog_versioning(path, glue_database, glue_table): assert len(df.columns) == 1 assert str(df.c0.dtype).startswith("Int") - # Version 1 + # Version 2 - Parquet df = pd.DataFrame({"c1": ["foo", "boo"]}) wr.s3.to_parquet( df=df, @@ -189,38 +189,56 @@ def test_catalog_versioning(path, glue_database, glue_table): assert len(df.columns) == 1 assert str(df.c1.dtype) == "string" - # Version 2 + # Version 1 - CSV df = pd.DataFrame({"c1": [1.0, 2.0]}) wr.s3.to_csv( df=df, path=path, dataset=True, database=glue_database, - table=glue_table, + table=glue_table2, mode="overwrite", catalog_versioning=True, index=False, ) - assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 3 - df = wr.athena.read_sql_table(table=glue_table, database=glue_database) + assert wr.catalog.get_table_number_of_versions(table=glue_table2, database=glue_database) == 1 + df = wr.athena.read_sql_table(table=glue_table2, database=glue_database) assert len(df.index) == 2 assert len(df.columns) == 1 assert str(df.c1.dtype).startswith("float") - # Version 3 (removing version 2) + # Version 1 - CSV (No evolution) df = pd.DataFrame({"c1": [True, False]}) wr.s3.to_csv( df=df, path=path, 
         dataset=True,
         database=glue_database,
-        table=glue_table,
+        table=glue_table2,
         mode="overwrite",
         catalog_versioning=False,
         index=False,
     )
-    assert wr.catalog.get_table_number_of_versions(table=glue_table, database=glue_database) == 3
-    df = wr.athena.read_sql_table(table=glue_table, database=glue_database)
+    assert wr.catalog.get_table_number_of_versions(table=glue_table2, database=glue_database) == 1
+    df = wr.athena.read_sql_table(table=glue_table2, database=glue_database)
+    assert len(df.index) == 2
+    assert len(df.columns) == 1
+    assert str(df.c1.dtype).startswith("boolean")
+
+    # Version 2 - CSV
+    df = pd.DataFrame({"c1": [True, False]})
+    wr.s3.to_csv(
+        df=df,
+        path=path,
+        dataset=True,
+        database=glue_database,
+        table=glue_table2,
+        mode="overwrite",
+        catalog_versioning=True,
+        index=False,
+    )
+    assert wr.catalog.get_table_number_of_versions(table=glue_table2, database=glue_database) == 2
+    df = wr.athena.read_sql_table(table=glue_table2, database=glue_database)
     assert len(df.index) == 2
     assert len(df.columns) == 1
     assert str(df.c1.dtype).startswith("boolean")
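
Taken together, the two patches enable the flow below — a minimal sketch that mirrors test_opencsv_serde above. The bucket, database, and table names (s3://my-bucket/my-prefix/, my_database, my_table) are placeholders, not names from this change:

    import csv

    import awswrangler as wr
    import pandas as pd

    df = pd.DataFrame({"col": ["1", "2", "3"], "col2": ["A", "A", "B"]})

    # Write a partitioned CSV dataset; the response maps each partition's
    # S3 prefix to its partition values.
    response = wr.s3.to_csv(
        df=df,
        path="s3://my-bucket/my-prefix/",  # placeholder path
        dataset=True,
        partition_cols=["col2"],
        sep=",",
        index=False,
        header=False,
        quoting=csv.QUOTE_NONE,
    )

    # Register the table with OpenCSVSerde instead of the default
    # LazySimpleSerDe, then register the partitions with the same SerDe.
    serde_library = "org.apache.hadoop.hive.serde2.OpenCSVSerde"
    serde_parameters = {"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"}
    wr.catalog.create_csv_table(
        database="my_database",  # placeholder database
        table="my_table",  # placeholder table
        path="s3://my-bucket/my-prefix/",
        columns_types={"col": "string"},
        partitions_types={"col2": "string"},
        serde_library=serde_library,
        serde_parameters=serde_parameters,
    )
    wr.catalog.add_csv_partitions(
        database="my_database",
        table="my_table",
        partitions_values=response["partitions_values"],
        serde_library=serde_library,
        serde_parameters=serde_parameters,
    )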
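
And because to_csv now reads SerdeInfo back from catalog_table_input, a later append to the same table should carry the custom SerDe over to the table and its new partitions instead of resetting them to LazySimpleSerDe. A sketch under the same placeholder names:

    # Append more rows to the table registered above; the patched to_csv
    # looks up the existing SerdeInfo and reuses it for the new partitions.
    wr.s3.to_csv(
        df=pd.DataFrame({"col": ["4"], "col2": ["C"]}),
        path="s3://my-bucket/my-prefix/",  # placeholder path
        dataset=True,
        partition_cols=["col2"],
        database="my_database",  # existing table registered above
        table="my_table",
        mode="append",
        index=False,
        header=False,
    )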