Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion awswrangler/catalog/_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,8 @@ def _create_csv_table(
catalog_versioning: bool,
sep: str,
skip_header_line_count: Optional[int],
serde_library: Optional[str],
serde_parameters: Optional[Dict[str, str]],
boto3_session: Optional[boto3.Session],
projection_enabled: bool,
projection_types: Optional[Dict[str, str]],
Expand Down Expand Up @@ -329,6 +331,8 @@ def _create_csv_table(
compression=compression,
sep=sep,
skip_header_line_count=skip_header_line_count,
serde_library=serde_library,
serde_parameters=serde_parameters,
)
table_exist: bool = catalog_table_input is not None
_logger.debug("table_exist: %s", table_exist)
Expand Down Expand Up @@ -670,6 +674,8 @@ def create_csv_table(
catalog_versioning: bool = False,
sep: str = ",",
skip_header_line_count: Optional[int] = None,
serde_library: Optional[str] = None,
serde_parameters: Optional[Dict[str, str]] = None,
boto3_session: Optional[boto3.Session] = None,
projection_enabled: bool = False,
projection_types: Optional[Dict[str, str]] = None,
Expand All @@ -679,7 +685,7 @@ def create_csv_table(
projection_digits: Optional[Dict[str, str]] = None,
catalog_id: Optional[str] = None,
) -> None:
"""Create a CSV Table (Metadata Only) in the AWS Glue Catalog.
r"""Create a CSV Table (Metadata Only) in the AWS Glue Catalog.

'https://docs.aws.amazon.com/athena/latest/ug/data-types.html'

Expand Down Expand Up @@ -715,6 +721,13 @@ def create_csv_table(
String of length 1. Field delimiter for the output file.
skip_header_line_count : Optional[int]
Number of lines to skip at the start of each file to account for the header.
serde_library : Optional[str]
Specifies the SerDe Serialization library which will be used. You need to provide the Class library name
as a string.
If no library is provided the default is `org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe`.
serde_parameters : Optional[Dict[str, str]]
Dictionary of initialization parameters for the SerDe.
The default is `{"field.delim": sep, "escape.delim": "\\"}`.
projection_enabled : bool
Enable Partition Projection on Athena (https://docs.aws.amazon.com/athena/latest/ug/partition-projection.html)
projection_types : Optional[Dict[str, str]]
Expand Down Expand Up @@ -793,4 +806,6 @@ def create_csv_table(
catalog_table_input=catalog_table_input,
sep=sep,
skip_header_line_count=skip_header_line_count,
serde_library=serde_library,
serde_parameters=serde_parameters,
)
24 changes: 11 additions & 13 deletions awswrangler/catalog/_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def _csv_table_definition(
compression: Optional[str],
sep: str,
skip_header_line_count: Optional[int],
serde_library: Optional[str],
serde_parameters: Optional[Dict[str, str]],
) -> Dict[str, Any]:
compressed: bool = compression is not None
parameters: Dict[str, str] = {
Expand All @@ -116,7 +118,13 @@ def _csv_table_definition(
"areColumnsQuoted": "false",
}
if skip_header_line_count is not None:
parameters["skip.header.line.count"] = "1"
parameters["skip.header.line.count"] = str(skip_header_line_count)
serde_info = {
"SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"
if serde_library is None
else serde_library,
"Parameters": {"field.delim": sep, "escape.delim": "\\"} if serde_parameters is None else serde_parameters,
}
return {
"Name": table,
"PartitionKeys": [{"Name": cname, "Type": dtype} for cname, dtype in partitions_types.items()],
Expand All @@ -129,21 +137,11 @@ def _csv_table_definition(
"OutputFormat": "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat",
"Compressed": compressed,
"NumberOfBuckets": -1 if bucketing_info is None else bucketing_info[1],
"SerdeInfo": {
"Parameters": {"field.delim": sep, "escape.delim": "\\"},
"SerializationLibrary": "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe",
},
"SerdeInfo": serde_info,
"BucketColumns": [] if bucketing_info is None else bucketing_info[0],
"StoredAsSubDirectories": False,
"SortColumns": [],
"Parameters": {
"classification": "csv",
"compressionType": str(compression).lower(),
"typeOfData": "file",
"delimiter": sep,
"columnsOrdered": "true",
"areColumnsQuoted": "false",
},
"Parameters": parameters,
},
}

Expand Down
2 changes: 2 additions & 0 deletions awswrangler/s3/_write_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,8 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state
catalog_id=catalog_id,
compression=pandas_kwargs.get("compression"),
skip_header_line_count=None,
serde_library=None,
serde_parameters=None,
)
if partitions_values and (regular_partitions is True):
_logger.debug("partitions_values:\n%s", partitions_values)
Expand Down
30 changes: 27 additions & 3 deletions tests/test_athena_csv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import csv
import logging
from sys import version_info

Expand Down Expand Up @@ -337,7 +338,8 @@ def test_athena_csv_types(path, glue_database, glue_table):

@pytest.mark.parametrize("use_threads", [True, False])
@pytest.mark.parametrize("ctas_approach", [True, False])
def test_skip_header(path, glue_database, glue_table, use_threads, ctas_approach):
@pytest.mark.parametrize("line_count", [1, 2])
def test_skip_header(path, glue_database, glue_table, use_threads, ctas_approach, line_count):
df = pd.DataFrame({"c0": [1, 2], "c1": [3.3, 4.4], "c2": ["foo", "boo"]})
df["c0"] = df["c0"].astype("Int64")
df["c2"] = df["c2"].astype("string")
Expand All @@ -347,10 +349,10 @@ def test_skip_header(path, glue_database, glue_table, use_threads, ctas_approach
table=glue_table,
path=path,
columns_types={"c0": "bigint", "c1": "double", "c2": "string"},
skip_header_line_count=1,
skip_header_line_count=line_count,
)
df2 = wr.athena.read_sql_table(glue_table, glue_database, use_threads=use_threads, ctas_approach=ctas_approach)
assert df.equals(df2)
assert df.iloc[line_count - 1 :].reset_index(drop=True).equals(df2)


@pytest.mark.parametrize("use_threads", [True, False])
Expand Down Expand Up @@ -444,3 +446,25 @@ def test_csv_compressed(path, glue_table, glue_database, use_threads, concurrent
assert df2["id"].sum() == 6
ensure_data_types_csv(df2)
assert wr.catalog.delete_table_if_exists(database=glue_database, table=glue_table) is True


@pytest.mark.parametrize("use_threads", [True, False])
@pytest.mark.parametrize("ctas_approach", [True, False])
def test_opencsv_serde(path, glue_table, glue_database, use_threads, ctas_approach):
    """Round-trip a quoted CSV through Glue/Athena using a custom SerDe (OpenCSVSerde).

    Writes a CSV whose cell values carry literal double quotes, registers the
    table via ``create_csv_table`` with ``serde_library``/``serde_parameters``
    overriding the default LazySimpleSerDe, then reads it back through Athena
    and checks that the quotes were treated as field quoting (i.e. stripped).
    """
    df = pd.DataFrame({"c0": ['"1"', '"2"', '"3"'], "c1": ['"4"', '"5"', '"6"'], "c2": ['"a"', '"b"', '"c"']})
    # QUOTE_NONE: pandas writes the embedded double quotes verbatim to the file
    # instead of escaping/re-quoting them.
    wr.s3.to_csv(
        df=df, path=f"{path}0.csv", sep=",", index=False, header=False, use_threads=use_threads, quoting=csv.QUOTE_NONE
    )
    # Register the table with OpenCSVSerde so Athena interprets the embedded
    # quotes as field quoting rather than as part of the data.
    wr.catalog.create_csv_table(
        database=glue_database,
        table=glue_table,
        path=path,
        columns_types={"c0": "string", "c1": "string", "c2": "string"},
        serde_library="org.apache.hadoop.hive.serde2.OpenCSVSerde",
        serde_parameters={"separatorChar": ",", "quoteChar": '"', "escapeChar": "\\"},
    )
    df2 = wr.athena.read_sql_table(
        table=glue_table, database=glue_database, use_threads=use_threads, ctas_approach=ctas_approach
    )
    # The SerDe strips the surrounding quotes on read, so strip them from the
    # local frame too before comparing.
    df = df.applymap(lambda x: x.replace('"', "")).convert_dtypes()
    assert df.equals(df2)