75 changes: 64 additions & 11 deletions awswrangler/s3/_write_parquet.py
@@ -1,6 +1,7 @@
"""Amazon PARQUET S3 Parquet Write Module (PRIVATE)."""

import logging
import math
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union

@@ -13,6 +14,9 @@

from awswrangler import _data_types, _utils, catalog, exceptions
from awswrangler._config import apply_configs
from awswrangler.s3._delete import delete_objects
from awswrangler.s3._describe import size_objects
from awswrangler.s3._list import does_object_exist
from awswrangler.s3._read_parquet import _read_parquet_metadata
from awswrangler.s3._write import _COMPRESSION_2_EXT, _apply_dtype, _sanitize, _validate_args
from awswrangler.s3._write_dataset import _to_dataset
@@ -32,6 +36,7 @@ def _to_parquet_file(
s3_additional_kwargs: Optional[Dict[str, str]],
path: Optional[str] = None,
path_root: Optional[str] = None,
max_file_size: Optional[int] = 0,
) -> str:
if path is None and path_root is not None:
file_path: str = f"{path_root}{uuid.uuid4().hex}{compression_ext}.parquet"
@@ -40,6 +45,7 @@
else:
raise RuntimeError("path and path_root received at the same time.")
_logger.debug("file_path: %s", file_path)
write_path = file_path
table: pa.Table = pyarrow.Table.from_pandas(df=df, schema=schema, nthreads=cpus, preserve_index=index, safe=True)
for col_name, col_type in dtype.items():
if col_name in table.column_names:
@@ -53,17 +59,58 @@
session=boto3_session,
s3_additional_kwargs=s3_additional_kwargs, # 32 MB (32 * 2**20)
)
with pyarrow.parquet.ParquetWriter(
where=file_path,
write_statistics=True,
use_dictionary=True,
filesystem=fs,
coerce_timestamps="ms",
compression=compression,
flavor="spark",
schema=table.schema,
) as writer:
writer.write_table(table)

file_counter, writer, chunks, chunk_size = 1, None, 1, df.shape[0]
if max_file_size is not None and max_file_size > 0:
chunk_size = int((max_file_size * df.shape[0]) / table.nbytes)
chunks = math.ceil(df.shape[0] / chunk_size)

for chunk in range(chunks):
offset = chunk * chunk_size

if writer is None:
writer = pyarrow.parquet.ParquetWriter(
where=write_path,
write_statistics=True,
use_dictionary=True,
filesystem=fs,
coerce_timestamps="ms",
compression=compression,
flavor="spark",
schema=table.schema,
)
# handle the case of overwriting an existing file
if does_object_exist(write_path):
delete_objects([write_path])

writer.write_table(table.slice(offset, chunk_size))

if max_file_size == 0 or max_file_size is None:
continue

file_size = writer.file_handle.buffer.__sizeof__()
if does_object_exist(write_path):
file_size += size_objects([write_path])[write_path]

if file_size >= max_file_size:
write_path = __get_file_path(file_counter, file_path)
file_counter += 1
writer.close()
writer = None

if writer is not None:
writer.close()

return file_path


def __get_file_path(file_counter, file_path):
dot_index = file_path.rfind(".")
file_index = "-" + str(file_counter)
if dot_index == -1:
file_path = file_path + file_index
else:
file_path = file_path[:dot_index] + file_index + file_path[dot_index:]
return file_path


@@ -95,6 +142,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
projection_values: Optional[Dict[str, str]] = None,
projection_intervals: Optional[Dict[str, str]] = None,
projection_digits: Optional[Dict[str, str]] = None,
max_file_size: Optional[int] = 0,
catalog_id: Optional[str] = None,
) -> Dict[str, Union[List[str], Dict[str, List[str]]]]:
"""Write Parquet file or dataset on Amazon S3.
@@ -197,6 +245,10 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
Dictionary of partitions names and Athena projections digits.
https://docs.aws.amazon.com/athena/latest/ug/partition-projection-supported-types.html
(e.g. {'col_name': '1', 'col2_name': '2'})
max_file_size : int, optional
If the file size exceeds the specified size in bytes, another file is created.
Default is 0, i.e. do not split the files.
(e.g. 33554432, 268435456, 0)
catalog_id : str, optional
The ID of the Data Catalog from which to retrieve Databases.
If none is provided, the AWS account ID is used by default.
@@ -361,6 +413,7 @@ def to_parquet(  # pylint: disable=too-many-arguments,too-many-locals
boto3_session=session,
s3_additional_kwargs=s3_additional_kwargs,
dtype=dtype,
max_file_size=max_file_size,
)
]
else:
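For readers skimming the diff, here is a minimal, self-contained sketch of the splitting heuristic introduced above: rows per chunk are estimated from the Arrow table's in-memory size (average bytes per row), and every file after the first gets a numeric suffix inserted before the extension. The helper names (`estimate_chunks`, `suffixed_path`) and the sample numbers are illustrative only, not part of awswrangler.

```python
import math
from typing import Tuple


def estimate_chunks(num_rows: int, table_nbytes: int, max_file_size: int) -> Tuple[int, int]:
    """Mirror the heuristic in _to_parquet_file: a file of roughly max_file_size
    bytes holds about max_file_size * num_rows / table_nbytes rows."""
    chunk_size = int((max_file_size * num_rows) / table_nbytes)
    chunks = math.ceil(num_rows / chunk_size)
    return chunk_size, chunks


def suffixed_path(file_path: str, file_counter: int) -> str:
    """Insert '-<counter>' before the extension, as __get_file_path does."""
    dot_index = file_path.rfind(".")
    suffix = f"-{file_counter}"
    if dot_index == -1:
        return file_path + suffix
    return file_path[:dot_index] + suffix + file_path[dot_index:]


# Example: a 1 GiB table with 10 million rows, split into ~128 MiB files.
rows, nbytes, limit = 10_000_000, 1 * 2**30, 128 * 2**20
print(estimate_chunks(rows, nbytes, limit))          # (1250000, 8)
print(suffixed_path("s3://bucket/test.parquet", 1))  # s3://bucket/test-1.parquet
```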
16 changes: 16 additions & 0 deletions tests/test_moto.py
@@ -309,6 +309,22 @@ def test_parquet(moto_s3):
assert df.shape == (3, 19)


def test_parquet_with_size(moto_s3):
path = "s3://bucket/test.parquet"
df = get_df_list()
for i in range(20):
df = pd.concat([df, get_df_list()])
wr.s3.to_parquet(df=df, path=path, index=False, dataset=False, max_file_size=1 * 2 ** 10)
df = wr.s3.read_parquet(path="s3://bucket/", dataset=False)
ensure_data_types(df, has_list=True)
assert df.shape == (63, 19)
file_objects = wr.s3.list_objects(path="s3://bucket/")
assert len(file_objects) == 9
for i in range(7):
assert f"s3://bucket/test-{i+1}.parquet" in file_objects
assert "s3://bucket/test.parquet" in file_objects


def test_s3_delete_object_success(moto_s3):
path = "s3://bucket/test.parquet"
wr.s3.to_parquet(df=get_df_list(), path=path, index=False, dataset=True, partition_cols=["par0", "par1"])
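As a usage note (not part of the PR), a hedged sketch of how a caller might exercise the new parameter; the bucket, key, and sizes below are placeholders, and the suffixing behavior mirrors the test above:

```python
import awswrangler as wr
import pandas as pd

df = pd.DataFrame({"id": range(100_000), "value": ["x" * 32] * 100_000})

# Split the output into roughly 32 MiB Parquet objects.
wr.s3.to_parquet(
    df=df,
    path="s3://my-bucket/out/data.parquet",  # placeholder bucket/key
    index=False,
    max_file_size=32 * 2**20,
)

# Additional parts are named data-1.parquet, data-2.parquet, ...
print(wr.s3.list_objects("s3://my-bucket/out/"))

# Reading the prefix back returns the full dataset.
df_back = wr.s3.read_parquet(path="s3://my-bucket/out/", dataset=False)
```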