@@ -7,7 +7,7 @@
 import pyarrow
 import ray
 from ray.data._internal.delegating_block_builder import DelegatingBlockBuilder
-from ray.data.block import Block, BlockAccessor
+from ray.data.block import Block, BlockAccessor, BlockMetadata
 from ray.data.datasource.datasource import WriteResult
 from ray.data.datasource.file_based_datasource import (
     BlockWritePathProvider,
@@ -64,11 +64,24 @@ def __init__(self) -> None:
     def _read_file(self, f: pyarrow.NativeFile, path: str, **reader_args: Any) -> pd.DataFrame:
         raise NotImplementedError()
 
-    def do_write(
+    def do_write(  # pylint: disable=arguments-differ
         self,
         blocks: List[ObjectRef[pd.DataFrame]],
-        *args: Any,
-        **kwargs: Any,
+        metadata: List[BlockMetadata],
+        path: str,
+        dataset_uuid: str,
+        filesystem: Optional[pyarrow.fs.FileSystem] = None,
+        try_create_dir: bool = True,
+        open_stream_args: Optional[Dict[str, Any]] = None,
+        block_path_provider: BlockWritePathProvider = DefaultBlockWritePathProvider(),
+        write_args_fn: Callable[[], Dict[str, Any]] = lambda: {},
+        _block_udf: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
+        ray_remote_args: Optional[Dict[str, Any]] = None,
+        s3_additional_kwargs: Optional[Dict[str, str]] = None,
+        pandas_kwargs: Optional[Dict[str, Any]] = None,
+        compression: Optional[str] = None,
+        mode: str = "wb",
+        **write_args: Any,
     ) -> List[ObjectRef[WriteResult]]:
         """Create and return write tasks for a file-based datasource.
 
@@ -77,21 +90,53 @@ def do_write(
         plan allowing query optimisation ("fuse" with other operations). The change is not backward-compatible
         with earlier versions still attempting to call do_write().
         """
-        write_tasks = []
-        path: str = kwargs.pop("path")
-        dataset_uuid: str = kwargs.pop("dataset_uuid")
-        ray_remote_args: Dict[str, Any] = kwargs.pop("ray_remote_args") or {}
+        _write_block_to_file = self._write_block
 
-        _write = ray_remote(**ray_remote_args)(self.write)
+        if ray_remote_args is None:
+            ray_remote_args = {}
+
+        if pandas_kwargs is None:
+            pandas_kwargs = {}
+
+        if not compression:
+            compression = pandas_kwargs.get("compression")
+
+        def write_block(write_path: str, block: pd.DataFrame) -> str:
+            if _block_udf is not None:
+                block = _block_udf(block)
+
+            with open_s3_object(
+                path=write_path,
+                mode=mode,
+                use_threads=False,
+                s3_additional_kwargs=s3_additional_kwargs,
+                encoding=write_args.get("encoding"),
+                newline=write_args.get("newline"),
+            ) as f:
+                _write_block_to_file(
+                    f,
+                    BlockAccessor.for_block(block),
+                    pandas_kwargs=pandas_kwargs,
+                    compression=compression,
+                    **write_args,
+                )
+            return write_path
+
+        write_block_fn = ray_remote(**ray_remote_args)(write_block)
+
+        file_suffix = self._get_file_suffix(self._FILE_EXTENSION, compression)
+        write_tasks = []
 
         for block_idx, block in enumerate(blocks):
-            write_task = _write(
-                [block],
-                TaskContext(task_idx=block_idx),
+            write_path = block_path_provider(
                 path,
-                dataset_uuid,
-                **kwargs,
+                filesystem=filesystem,
+                dataset_uuid=dataset_uuid,
+                block=block,
+                block_index=block_idx,
+                file_format=file_suffix,
             )
+            write_task = write_block_fn(write_path, block)
            write_tasks.append(write_task)
 
         return write_tasks
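For context on the new body: it is a standard Ray fan-out. Define a plain function that writes a single block, wrap it with ray.remote (here via the ray_remote helper so that ray_remote_args are applied), submit one task per block, and hand the resulting ObjectRefs back to the caller to wait on. Below is a minimal, self-contained sketch of that pattern with the S3 and pandas details stripped out; write_one_block, the toy blocks, and the /tmp paths are illustrative stand-ins, not awswrangler or Ray Data APIs.

import ray

ray.init(ignore_reinit_error=True)

# Stand-in for write_block(): write one block to its target path and
# return the path, just as the real task returns write_path.
@ray.remote
def write_one_block(write_path: str, rows: list) -> str:
    with open(write_path, "w") as f:
        for row in rows:
            f.write(f"{row}\n")
    return write_path

# Three toy "blocks"; the real code iterates over ObjectRef[pd.DataFrame].
blocks = [[1, 2], [3, 4], [5, 6]]

# One task per block, with a deterministic per-block path standing in for
# block_path_provider(path, ..., block_index=block_idx, file_format=...).
write_tasks = [
    write_one_block.remote(f"/tmp/demo_block_{idx}.txt", block)
    for idx, block in enumerate(blocks)
]

# do_write() returns the ObjectRefs; the caller blocks on them like this.
print(ray.get(write_tasks))

The real implementation differs in that the paths come from block_path_provider, an optional _block_udf transforms each block first, and the writer streams through open_s3_object, but the task-per-block shape is the same.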