diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py index 4109dbe1c..354e76cdf 100644 --- a/awswrangler/s3/_write_text.py +++ b/awswrangler/s3/_write_text.py @@ -456,7 +456,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state if database and table: quoting: Optional[int] = csv.QUOTE_NONE escapechar: Optional[str] = "\\" - header: Union[bool, List[str]] = False + header: Union[bool, List[str]] = pandas_kwargs.get("header", False) date_format: Optional[str] = "%Y-%m-%d %H:%M:%S.%f" pd_kwargs: Dict[str, Any] = {} compression: Optional[str] = pandas_kwargs.get("compression", None) @@ -529,7 +529,7 @@ def to_csv( # pylint: disable=too-many-arguments,too-many-locals,too-many-state catalog_table_input=catalog_table_input, catalog_id=catalog_id, compression=pandas_kwargs.get("compression"), - skip_header_line_count=None, + skip_header_line_count=True if header else None, serde_library=serde_library, serde_parameters=serde_parameters, ) diff --git a/tests/test_s3_text.py b/tests/test_s3_text.py index 185d48d0d..e730caa93 100644 --- a/tests/test_s3_text.py +++ b/tests/test_s3_text.py @@ -119,6 +119,52 @@ def test_csv(path): wr.s3.read_csv(path=paths, iterator=True) +@pytest.mark.parametrize("header", [True, ["identifier"]]) +def test_csv_dataset_header(path, header, glue_database, glue_table): + path0 = f"{path}test_csv_dataset0.csv" + df0 = pd.DataFrame({"id": [1, 2, 3]}) + wr.s3.to_csv( + df=df0, + path=path0, + dataset=True, + database=glue_database, + table=glue_table, + index=False, + header=header, + ) + df1 = wr.s3.read_csv(path=path0) + if isinstance(header, list): + df0.columns = header + assert df0.equals(df1) + + +@pytest.mark.parametrize("mode", ["append", "overwrite"]) +def test_csv_dataset_header_modes(path, mode, glue_database, glue_table): + path0 = f"{path}test_csv_dataset0.csv" + dfs = [ + pd.DataFrame({"id": [1, 2, 3]}), + pd.DataFrame({"id": [4, 5, 6]}), + ] + for df in dfs: + wr.s3.to_csv( + df=df, + 
+ path=path0, + dataset=True, + database=glue_database, + table=glue_table, + mode=mode, + index=False, + header=True, + ) + dfs_conc = pd.concat(dfs) + df_res = wr.s3.read_csv(path=path0) + + if mode == "append": + assert len(df_res) == len(dfs_conc) + else: + assert df_res.equals(dfs[-1]) + + def test_json(path): df0 = pd.DataFrame({"id": [1, 2, 3]}) path0 = f"{path}test_json0.json"