diff --git a/awswrangler/s3/_write_dataset.py b/awswrangler/s3/_write_dataset.py
index bf3a7a1f4..977ff152c 100644
--- a/awswrangler/s3/_write_dataset.py
+++ b/awswrangler/s3/_write_dataset.py
@@ -1,7 +1,6 @@
 """Amazon S3 Write Dataset (PRIVATE)."""
 
 import logging
-import uuid
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import boto3
@@ -24,12 +23,12 @@ def _to_partitions(
     mode: str,
     partition_cols: List[str],
     bucketing_info: Optional[Tuple[List[str], int]],
+    filename_prefix: str,
     boto3_session: boto3.Session,
     **func_kwargs: Any,
 ) -> Tuple[List[str], Dict[str, List[str]]]:
     partitions_values: Dict[str, List[str]] = {}
     proxy: _WriteProxy = _WriteProxy(use_threads=concurrent_partitioning)
-    filename_prefix = uuid.uuid4().hex
 
     for keys, subgroup in df.groupby(by=partition_cols, observed=True):
         subgroup = subgroup.drop(partition_cols, axis="columns")
@@ -60,6 +59,7 @@ def _to_partitions(
                 func=func,
                 df=subgroup,
                 path_root=prefix,
+                filename_prefix=filename_prefix,
                 boto3_session=boto3_session,
                 use_threads=use_threads,
                 **func_kwargs,
@@ -74,10 +74,10 @@ def _to_buckets(
     df: pd.DataFrame,
     path_root: str,
     bucketing_info: Tuple[List[str], int],
+    filename_prefix: str,
     boto3_session: boto3.Session,
     use_threads: bool,
     proxy: Optional[_WriteProxy] = None,
-    filename_prefix: Optional[str] = None,
     **func_kwargs: Any,
 ) -> List[str]:
     _proxy: _WriteProxy = proxy if proxy else _WriteProxy(use_threads=False)
@@ -85,14 +85,12 @@ def _to_buckets(
         lambda row: _get_bucket_number(bucketing_info[1], [row[col_name] for col_name in bucketing_info[0]]),
         axis="columns",
     )
-    if filename_prefix is None:
-        filename_prefix = uuid.uuid4().hex
     for bucket_number, subgroup in df.groupby(by=bucket_number_series, observed=True):
         _proxy.write(
             func=func,
             df=subgroup,
             path_root=path_root,
-            filename=f"{filename_prefix}_bucket-{bucket_number:05d}",
+            filename_prefix=f"{filename_prefix}_bucket-{bucket_number:05d}",
             boto3_session=boto3_session,
             use_threads=use_threads,
             **func_kwargs,
@@ -133,6 +131,7 @@ def _to_dataset(
     concurrent_partitioning: bool,
     df: pd.DataFrame,
     path_root: str,
+    filename_prefix: str,
     index: bool,
     use_threads: bool,
     mode: str,
@@ -168,6 +167,7 @@ def _to_dataset(
             use_threads=use_threads,
             mode=mode,
             bucketing_info=bucketing_info,
+            filename_prefix=filename_prefix,
             partition_cols=partition_cols,
             boto3_session=boto3_session,
             index=index,
@@ -180,13 +180,20 @@ def _to_dataset(
             path_root=path_root,
             use_threads=use_threads,
             bucketing_info=bucketing_info,
+            filename_prefix=filename_prefix,
             boto3_session=boto3_session,
             index=index,
             **func_kwargs,
         )
     else:
         paths = func(
-            df=df, path_root=path_root, use_threads=use_threads, boto3_session=boto3_session, index=index, **func_kwargs
+            df=df,
+            path_root=path_root,
+            filename_prefix=filename_prefix,
+            use_threads=use_threads,
+            boto3_session=boto3_session,
+            index=index,
+            **func_kwargs,
         )
     _logger.debug("paths: %s", paths)
     _logger.debug("partitions_values: %s", partitions_values)
diff --git a/awswrangler/s3/_write_parquet.py b/awswrangler/s3/_write_parquet.py
index 21cd492e4..329288b53 100644
--- a/awswrangler/s3/_write_parquet.py
+++ b/awswrangler/s3/_write_parquet.py
@@ -150,13 +150,11 @@ def _to_parquet(
     use_threads: bool,
     path: Optional[str] = None,
     path_root: Optional[str] = None,
-    filename: Optional[str] = None,
+    filename_prefix: Optional[str] = uuid.uuid4().hex,
    max_rows_by_file: Optional[int] = 0,
 ) -> List[str]:
     if path is None and path_root is not None:
-        if filename is None:
-            filename = uuid.uuid4().hex
-        file_path: str = f"{path_root}{filename}{compression_ext}.parquet"
+        file_path: str = f"{path_root}{filename_prefix}{compression_ext}.parquet"
     elif path is not None and path_root is None:
         file_path = path
     else:
@@ -207,6 +205,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
     s3_additional_kwargs: Optional[Dict[str, Any]] = None,
     sanitize_columns: bool = False,
     dataset: bool = False,
+    filename_prefix: Optional[str] = None,
     partition_cols: Optional[List[str]] = None,
     bucketing_info: Optional[Tuple[List[str], int]] = None,
     concurrent_partitioning: bool = False,
@@ -283,6 +282,8 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
         partition_cols, mode, database, table, description, parameters, columns_comments, concurrent_partitioning,
         catalog_versioning, projection_enabled, projection_types, projection_ranges, projection_values,
         projection_intervals, projection_digits, catalog_id, schema_evolution.
+    filename_prefix: str, optional
+        If dataset=True, add a filename prefix to the output files.
     partition_cols: List[str], optional
         List of column names that will be used to create partitions. Only takes effect if dataset=True.
     bucketing_info: Tuple[List[str], int], optional
@@ -499,6 +500,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
     dtype = dtype if dtype else {}
     partitions_values: Dict[str, List[str]] = {}
     mode = "append" if mode is None else mode
+    filename_prefix = filename_prefix + uuid.uuid4().hex if filename_prefix else uuid.uuid4().hex
     cpus: int = _utils.ensure_cpu_count(use_threads=use_threads)
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
 
@@ -560,6 +562,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
             concurrent_partitioning=concurrent_partitioning,
             df=df,
             path_root=path,  # type: ignore
+            filename_prefix=filename_prefix,
             index=index,
             compression=compression,
             compression_ext=compression_ext,
diff --git a/awswrangler/s3/_write_text.py b/awswrangler/s3/_write_text.py
index 71f435536..864e9774c 100644
--- a/awswrangler/s3/_write_text.py
+++ b/awswrangler/s3/_write_text.py
@@ -37,16 +37,14 @@ def _to_text(
     s3_additional_kwargs: Optional[Dict[str, str]],
     path: Optional[str] = None,
     path_root: Optional[str] = None,
-    filename: Optional[str] = None,
+    filename_prefix: Optional[str] = uuid.uuid4().hex,
     **pandas_kwargs: Any,
 ) -> List[str]:
     if df.empty is True:
         raise exceptions.EmptyDataFrame()
     if path is None and path_root is not None:
-        if filename is None:
-            filename = uuid.uuid4().hex
         file_path: str = (
-            f"{path_root}{filename}.{file_format}{_COMPRESSION_2_EXT.get(pandas_kwargs.get('compression'))}"
+            f"{path_root}{filename_prefix}.{file_format}{_COMPRESSION_2_EXT.get(pandas_kwargs.get('compression'))}"
         )
     elif path is not None and path_root is None:
         file_path = path
@@ -83,6 +81,7 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
     s3_additional_kwargs: Optional[Dict[str, Any]] = None,
     sanitize_columns: bool = False,
     dataset: bool = False,
+    filename_prefix: Optional[str] = None,
     partition_cols: Optional[List[str]] = None,
     bucketing_info: Optional[Tuple[List[str], int]] = None,
     concurrent_partitioning: bool = False,
@@ -165,6 +164,8 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
         partition_cols, mode, database, table, description, parameters, columns_comments, concurrent_partitioning,
         catalog_versioning, projection_enabled, projection_types, projection_ranges, projection_values,
         projection_intervals, projection_digits, catalog_id, schema_evolution.
+    filename_prefix: str, optional
+        If dataset=True, add a filename prefix to the output files.
     partition_cols: List[str], optional
         List of column names that will be used to create partitions. Only takes effect if dataset=True.
     bucketing_info: Tuple[List[str], int], optional
@@ -403,6 +404,7 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
     dtype = dtype if dtype else {}
     partitions_values: Dict[str, List[str]] = {}
     mode = "append" if mode is None else mode
+    filename_prefix = filename_prefix + uuid.uuid4().hex if filename_prefix else uuid.uuid4().hex
     session: boto3.Session = _utils.ensure_session(session=boto3_session)
 
     # Sanitize table to respect Athena's standards
@@ -480,6 +482,7 @@ def to_csv(  # pylint: disable=too-many-arguments,too-many-locals,too-many-state
             index=index,
             sep=sep,
             compression=compression,
+            filename_prefix=filename_prefix,
             use_threads=use_threads,
             partition_cols=partition_cols,
             bucketing_info=bucketing_info,
diff --git a/tests/conftest.py b/tests/conftest.py
index 7f44ff12c..93e83ec0d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -271,3 +271,14 @@ def timestream_database_and_table():
     yield name
     wr.timestream.delete_table(name, name)
     wr.timestream.delete_database(name)
+
+
+@pytest.fixture(scope="function")
+def compare_filename_prefix():
+    def assert_filename_prefix(filename, filename_prefix, test_prefix):
+        if filename_prefix:
+            assert filename.startswith(test_prefix)
+        else:
+            assert not filename.startswith(test_prefix)
+
+    return assert_filename_prefix
diff --git a/tests/test_s3_parquet.py b/tests/test_s3_parquet.py
index a7eeda567..1a82c4c99 100644
--- a/tests/test_s3_parquet.py
+++ b/tests/test_s3_parquet.py
@@ -192,6 +192,49 @@ def test_to_parquet_file_dtype(path, use_threads):
     assert str(df2.c1.dtype) == "string"
 
 
+@pytest.mark.parametrize("filename_prefix", [None, "my_prefix"])
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_to_parquet_filename_prefix(compare_filename_prefix, path, filename_prefix, use_threads):
+    test_prefix = "my_prefix"
+    df = pd.DataFrame({"col": [1, 2, 3], "col2": ["A", "A", "B"]})
+    file_path = f"{path}0.parquet"
+
+    # If Dataset is False, parquet file should never start with prefix
+    filename = wr.s3.to_parquet(
+        df=df, path=file_path, dataset=False, filename_prefix=filename_prefix, use_threads=use_threads
+    )["paths"][0].split("/")[-1]
+    assert not filename.startswith(test_prefix)
+
+    # If Dataset is True, parquet file starts with prefix if one is supplied
+    filename = wr.s3.to_parquet(
+        df=df, path=path, dataset=True, filename_prefix=filename_prefix, use_threads=use_threads
+    )["paths"][0].split("/")[-1]
+    compare_filename_prefix(filename, filename_prefix, test_prefix)
+
+    # Partitioned
+    filename = wr.s3.to_parquet(
+        df=df,
+        path=path,
+        dataset=True,
+        filename_prefix=filename_prefix,
+        partition_cols=["col2"],
+        use_threads=use_threads,
+    )["paths"][0].split("/")[-1]
+    compare_filename_prefix(filename, filename_prefix, test_prefix)
+
+    # Bucketing
+    filename = wr.s3.to_parquet(
+        df=df,
+        path=path,
+        dataset=True,
+        filename_prefix=filename_prefix,
+        bucketing_info=(["col2"], 2),
+        use_threads=use_threads,
+    )["paths"][0].split("/")[-1]
+    compare_filename_prefix(filename, filename_prefix, test_prefix)
+    assert filename.endswith("bucket-00000.snappy.parquet")
+
+
 def test_read_parquet_map_types(path):
     df = pd.DataFrame({"c0": [0, 1, 1, 2]}, dtype=np.int8)
     file_path = f"{path}0.parquet"
diff --git a/tests/test_s3_text.py b/tests/test_s3_text.py
index 55aaa883c..21e441df2 100644
--- a/tests/test_s3_text.py
+++ b/tests/test_s3_text.py
@@ -130,6 +130,52 @@ def test_json(path):
     assert df1.equals(wr.s3.read_json(path=[path0, path1], use_threads=True))
 
 
+@pytest.mark.parametrize("filename_prefix", [None, "my_prefix"])
+@pytest.mark.parametrize("use_threads", [True, False])
+def test_to_text_filename_prefix(compare_filename_prefix, path, filename_prefix, use_threads):
+    test_prefix = "my_prefix"
+    df = pd.DataFrame({"col": [1, 2, 3], "col2": ["A", "A", "B"]})
+
+    # If Dataset is False, csv/json file should never start with prefix
+    file_path = f"{path}0.json"
+    filename = wr.s3.to_json(df=df, path=file_path, use_threads=use_threads)[0].split("/")[-1]
+    assert not filename.startswith(test_prefix)
+    file_path = f"{path}0.csv"
+    filename = wr.s3.to_csv(
+        df=df, path=file_path, dataset=False, filename_prefix=filename_prefix, use_threads=use_threads
+    )["paths"][0].split("/")[-1]
+    assert not filename.startswith(test_prefix)
+
+    # If Dataset is True, csv file starts with prefix if one is supplied
+    filename = wr.s3.to_csv(df=df, path=path, dataset=True, filename_prefix=filename_prefix, use_threads=use_threads)[
+        "paths"
+    ][0].split("/")[-1]
+    compare_filename_prefix(filename, filename_prefix, test_prefix)
+
+    # Partitioned
+    filename = wr.s3.to_csv(
+        df=df,
+        path=path,
+        dataset=True,
+        filename_prefix=filename_prefix,
+        partition_cols=["col2"],
+        use_threads=use_threads,
+    )["paths"][0].split("/")[-1]
+    compare_filename_prefix(filename, filename_prefix, test_prefix)
+
+    # Bucketing
+    filename = wr.s3.to_csv(
+        df=df,
+        path=path,
+        dataset=True,
+        filename_prefix=filename_prefix,
+        bucketing_info=(["col2"], 2),
+        use_threads=use_threads,
+    )["paths"][0].split("/")[-1]
+    compare_filename_prefix(filename, filename_prefix, test_prefix)
+    assert filename.endswith("bucket-00000.csv")
+
+
 def test_fwf(path):
     text = "1 Herfelingen27-12-18\n2 Lambusart14-06-18\n3Spormaggiore15-04-18"
     client_s3 = boto3.client("s3")
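
Usage sketch (not part of the diff): the calls below mirror the new tests and show how the added filename_prefix argument surfaces to callers. The bucket and prefix names are illustrative placeholders, and per the prefix handling above the writers still append a uuid4 suffix so object keys stay unique.

import awswrangler as wr
import pandas as pd

df = pd.DataFrame({"col": [1, 2, 3], "col2": ["A", "A", "B"]})

# Plain dataset write: objects land under the path as my_prefix<uuid>.snappy.parquet
wr.s3.to_parquet(
    df=df,
    path="s3://my-bucket/parquet-dataset/",  # placeholder bucket/prefix
    dataset=True,
    filename_prefix="my_prefix",
)

# Bucketed CSV dataset: file names start with my_prefix and end with _bucket-0000N.csv
wr.s3.to_csv(
    df=df,
    path="s3://my-bucket/csv-dataset/",  # placeholder bucket/prefix
    dataset=True,
    filename_prefix="my_prefix",
    bucketing_info=(["col2"], 2),
)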