diff --git a/awswrangler/s3/_write_parquet.py b/awswrangler/s3/_write_parquet.py
index 398941618..6cf60319c 100644
--- a/awswrangler/s3/_write_parquet.py
+++ b/awswrangler/s3/_write_parquet.py
@@ -1,6 +1,7 @@
 """Amazon PARQUET S3 Parquet Write Module (PRIVATE)."""
 
 import logging
+import math
 import uuid
 from typing import Any, Dict, List, Optional, Tuple, Union
 
@@ -13,6 +14,9 @@
 
 from awswrangler import _data_types, _utils, catalog, exceptions
 from awswrangler._config import apply_configs
+from awswrangler.s3._delete import delete_objects
+from awswrangler.s3._describe import size_objects
+from awswrangler.s3._list import does_object_exist
 from awswrangler.s3._read_parquet import _read_parquet_metadata
 from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _sanitize, _validate_args
 from awswrangler.s3._write_dataset import _to_dataset
@@ -32,6 +36,7 @@ def _to_parquet_file(
     s3_additional_kwargs: Optional[Dict[str, str]],
     path: Optional[str] = None,
     path_root: Optional[str] = None,
+    max_file_size: Optional[int] = 0,
 ) -> str:
     if path is None and path_root is not None:
         file_path: str = f"{path_root}{uuid.uuid4().hex}{compression_ext}.parquet"
@@ -40,6 +45,7 @@
     else:
         raise RuntimeError("path and path_root received at the same time.")
     _logger.debug("file_path: %s", file_path)
+    write_path = file_path
     table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=True)
     for col_name, col_type in dtype.items():
         if col_name in table.column_names:
@@ -53,17 +59,58 @@
         session=boto3_session, s3_additional_kwargs=s3_additional_kwargs,  # 32 MB (32 * 2**20)
     )
-    with pyarrow.parquet.ParquetWriter(
-        where=file_path,
-        write_statistics=True,
-        use_dictionary=True,
-        filesystem=fs,
-        coerce_timestamps="ms",
-        compression=compression,
-        flavor="spark",
-        schema=table.schema,
-    ) as writer:
-        writer.write_table(table)
+
+    file_counter, writer, chunks, chunk_size = 1, None, 1, df.shape[0]
+    if max_file_size is not None and max_file_size > 0:
+        chunk_size = int((max_file_size * df.shape[0]) / table.nbytes)
+        chunks = math.ceil(df.shape[0] / chunk_size)
+
+    for chunk in range(chunks):
+        offset = chunk * chunk_size
+
+        if writer is None:
+            writer = pyarrow.parquet.ParquetWriter(
+                where=write_path,
+                write_statistics=True,
+                use_dictionary=True,
+                filesystem=fs,
+                coerce_timestamps="ms",
+                compression=compression,
+                flavor="spark",
+                schema=table.schema,
+            )
+            # handle the case of overwriting an existing file
+            if does_object_exist(write_path):
+                delete_objects([write_path])
+
+        writer.write_table(table.slice(offset, chunk_size))
+
+        if max_file_size == 0 or max_file_size is None:
+            continue
+
+        file_size = writer.file_handle.buffer.__sizeof__()
+        if does_object_exist(write_path):
+            file_size += size_objects([write_path])[write_path]
+
+        if file_size >= max_file_size:
+            write_path = __get_file_path(file_counter, file_path)
+            file_counter += 1
+            writer.close()
+            writer = None
+
+    if writer is not None:
+        writer.close()
+
+    return file_path
+
+
+def __get_file_path(file_counter, file_path):
+    dot_index = file_path.rfind(".")
+    file_index = "-" + str(file_counter)
+    if dot_index == -1:
+        file_path = file_path + file_index
+    else:
+        file_path = file_path[:dot_index] + file_index + file_path[dot_index:]
     return file_path
 
 
 @apply_configs
@@ -95,6 +142,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
     projection_values: Optional[Dict[str, str]] = None,
     projection_intervals: Optional[Dict[str, str]] = None,
     projection_digits: Optional[Dict[str, str]] = None,
+    max_file_size: Optional[int] = 0,
     catalog_id: Optional[str] = None,
 ) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
     """Write Parquet file or dataset on Amazon S3.
@@ -197,6 +245,10 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
         Dictionary of partitions names and Athena projections digits.
         https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
         (e.g. {'col_name': '1', 'col2_name': '2'})
+    max_file_size : int, optional
+        If the file size exceeds the specified size in bytes, another file is created.
+        Default is 0, i.e. do not split the files.
+        (e.g. 33554432, 268435456, 0)
     catalog_id : str, optional
         The ID of the Data Catalog from which to retrieve Databases.
         If none is provided, the AWS account ID is used by default.
@@ -361,6 +413,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
                 boto3_session=session,
                 s3_additional_kwargs=s3_additional_kwargs,
                 dtype=dtype,
+                max_file_size=max_file_size,
             )
         ]
     else:
diff --git a/tests/test_moto.py b/tests/test_moto.py
index ae739249f..c917cb2d7 100644
--- a/tests/test_moto.py
+++ b/tests/test_moto.py
@@ -309,6 +309,22 @@ def test_parquet(moto_s3):
     assert df.shape == (3, 19)
 
 
+def test_parquet_with_size(moto_s3):
+    path = "s3://bucket/test.parquet"
+    df = get_df_list()
+    for i in range(20):
+        df = pd.concat([df, get_df_list()])
+    wr.s3.to_parquet(df=df, path=path, index=False, dataset=False, max_file_size=1 * 2 ** 10)
+    df = wr.s3.read_parquet(path="s3://bucket/", dataset=False)
+    ensure_data_types(df, has_list=True)
+    assert df.shape == (63, 19)
+    file_objects = wr.s3.list_objects(path="s3://bucket/")
+    assert len(file_objects) == 9
+    for i in range(7):
+        assert f"s3://bucket/test-{i+1}.parquet" in file_objects
+    assert "s3://bucket/test.parquet" in file_objects
+
+
 def test_s3_delete_object_success(moto_s3):
     path = "s3://bucket/test.parquet"
     wr.s3.to_parquet(df=get_df_list(), path=path, index=False, dataset=True, partition_cols=["par0", "par1"])
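
A minimal usage sketch of the new `max_file_size` parameter, assuming a build with this patch installed; the bucket name, DataFrame contents, and the 32 MiB threshold below are illustrative and not part of the change:

```python
import awswrangler as wr
import pandas as pd

# Build a DataFrame large enough that a single Parquet object would exceed the threshold.
df = pd.concat(
    [pd.DataFrame({"id": range(100_000), "payload": ["x" * 64] * 100_000})] * 10,
    ignore_index=True,
)

# With max_file_size > 0, _to_parquet_file slices the Arrow table into chunks and rolls
# over to test-1.parquet, test-2.parquet, ... whenever the current object reaches the limit.
wr.s3.to_parquet(
    df=df,
    path="s3://my-bucket/test.parquet",  # hypothetical bucket/key
    index=False,
    dataset=False,
    max_file_size=32 * 2 ** 20,  # 32 MiB per object; 0 (the default) keeps a single file
)

# Reading the prefix that now holds the split objects returns all rows.
df_out = wr.s3.read_parquet(path="s3://my-bucket/", dataset=False)
assert len(df_out) == len(df)
```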
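
The rollover arithmetic and the `-N` naming rule can be sanity-checked locally without S3. This sketch reuses the same formulas as the new `_to_parquet_file` code and mirrors `__get_file_path` in a hypothetical `numbered_path` helper; the sample DataFrame and the 1 KiB threshold (the value used in `test_parquet_with_size`) are made up for illustration:

```python
import math

import pandas as pd
import pyarrow as pa


def numbered_path(file_counter: int, file_path: str) -> str:
    """Same suffix rule as __get_file_path: insert '-<counter>' before the extension."""
    dot_index = file_path.rfind(".")
    file_index = "-" + str(file_counter)
    if dot_index == -1:
        return file_path + file_index
    return file_path[:dot_index] + file_index + file_path[dot_index:]


df = pd.DataFrame({"c0": range(1_000), "c1": ["row"] * 1_000})
table = pa.Table.from_pandas(df)

max_file_size = 1 * 2 ** 10                                     # 1 KiB
chunk_size = int((max_file_size * df.shape[0]) / table.nbytes)  # rows per table.slice()
chunks = math.ceil(df.shape[0] / chunk_size)                    # number of write_table calls

print(chunk_size, chunks)
print([numbered_path(i, "s3://bucket/test.parquet") for i in range(1, 3)])
# ['s3://bucket/test-1.parquet', 's3://bucket/test-2.parquet']
```

Note that the writer only rolls over after a slice has been written, so the first object keeps the original key (`test.parquet`) and only the later objects get numbered keys, which is what the new test asserts.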