[Data] Add Datasource.on_write_start (ray-project#38298)
Currently, we attempt to create a directory in every write task. This can cause rate-limiting issues with S3. To address this problem, this PR adds an on_write_start method that is executed once per write job.

Signed-off-by: Balaji Veeramani <balaji@anyscale.com>
Signed-off-by: NripeshN <nn2012@hw.ac.uk>
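
For illustration only (not part of this commit), a minimal sketch of a custom Datasource using the new hook, assuming the Ray 2.x write_datasource API shown in this diff; the class name, output path, and file naming are hypothetical:

import os
from typing import Any, Iterable, List

import ray
from ray.data.block import Block
from ray.data.datasource import Datasource


class DirDatasource(Datasource):
    """Hypothetical datasource illustrating the write-job lifecycle."""

    def on_write_start(self, path: str, **write_args) -> None:
        # Runs once on the driver before any write tasks start, so the
        # directory is created once per job rather than once per task.
        os.makedirs(path, exist_ok=True)

    def write(self, blocks: Iterable[Block], ctx, path: str, **write_args) -> Any:
        # Runs once per write task; it can assume the directory already exists.
        for i, block in enumerate(blocks):
            with open(os.path.join(path, f"{ctx.task_idx}-{i}.txt"), "w") as f:
                f.write(str(block))
        return "ok"

    def on_write_complete(self, write_results: List[Any], **write_args) -> None:
        # Runs once on the driver after all write tasks have finished.
        print("write results:", write_results)


ray.data.range(8).write_datasource(DirDatasource(), path="/tmp/dir_datasource_demo")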
bveeramani authored and NripeshN committed Aug 15, 2023
1 parent d161fe1 commit 790fc39
Showing 5 changed files with 107 additions and 10 deletions.
8 changes: 8 additions & 0 deletions python/ray/data/BUILD
@@ -169,6 +169,14 @@ py_test(
    deps = ["//:ray_lib", ":conftest"],
)

py_test(
    name = "test_file_based_datasource",
    size = "small",
    srcs = ["tests/test_file_based_datasource.py"],
    tags = ["team:data", "exclusive"],
    deps = ["//:ray_lib", ":conftest"],
)

py_test(
    name = "test_image",
    size = "small",
3 changes: 2 additions & 1 deletion python/ray/data/dataset.py
@@ -3317,6 +3317,7 @@ def write_fn_wrapper(blocks: Iterator[Block], ctx, fn) -> Iterator[Block]:
        try:
            import pandas as pd

            datasource.on_write_start(**write_args)
            self._write_ds = Dataset(
                plan, self._epoch, self._lazy, logical_plan
            ).materialize()
@@ -3326,7 +3327,7 @@ def write_fn_wrapper(blocks: Iterator[Block], ctx, fn) -> Iterator[Block]:
                for block in blocks
            )
            write_results = [block["write_result"][0] for block in blocks]
-            datasource.on_write_complete(write_results)
+            datasource.on_write_complete(write_results, **write_args)
        except Exception as e:
            datasource.on_write_failed([], e)
            raise
11 changes: 11 additions & 0 deletions python/ray/data/datasource/datasource.py
@@ -51,6 +51,17 @@ def prepare_read(self, parallelism: int, **read_args) -> List["ReadTask"]:
"""Deprecated: Please implement create_reader() instead."""
raise NotImplementedError

    def on_write_start(self, **write_args) -> None:
        """Callback for when a write job starts.

        Use this method to perform setup for write tasks. For example, creating a
        staging bucket in S3.

        Args:
            write_args: Additional kwargs to pass to the datasource impl.
        """
        pass

    def write(
        self,
        blocks: Iterable[Block],
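To make the docstring's S3 example concrete, here is a hedged sketch of an override; it is not part of this commit, the bucket and prefix names are hypothetical, and a pyarrow S3 filesystem configured with credentials is assumed:

from typing import Optional

import pyarrow.fs

from ray.data.datasource import Datasource


class S3StagingDatasource(Datasource):
    """Hypothetical datasource that stages output under an S3 prefix."""

    def on_write_start(
        self,
        path: str = "my-bucket/staging",
        filesystem: Optional[pyarrow.fs.FileSystem] = None,
        **write_args,
    ) -> None:
        # Executed once per write job, so S3 is contacted once here rather
        # than once in every write task.
        fs = filesystem or pyarrow.fs.S3FileSystem()
        fs.create_dir(path, recursive=True)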
52 changes: 43 additions & 9 deletions python/ray/data/datasource/file_based_datasource.py
@@ -261,6 +261,32 @@ def _read_file(self, f: "pyarrow.NativeFile", path: str, **reader_args) -> Block
"Subclasses of FileBasedDatasource must implement _read_file()."
)

    def on_write_start(
        self,
        path: str,
        try_create_dir: bool = True,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        **write_args,
    ) -> None:
        """Create a directory to write files to.

        If ``try_create_dir`` is ``False``, this method is a no-op.
        """
        from pyarrow.fs import FileType

        self.has_created_dir = False
        if try_create_dir:
            paths, filesystem = _resolve_paths_and_filesystem(path, filesystem)
            assert len(paths) == 1, len(paths)
            path = paths[0]

            if filesystem.get_file_info(path).type is FileType.NotFound:
                # Arrow's S3FileSystem doesn't allow creating buckets by default, so we
                # add a query arg enabling bucket creation if an S3 URI is provided.
                tmp = _add_creatable_buckets_param_if_s3_uri(path)
                filesystem.create_dir(tmp, recursive=True)
                self.has_created_dir = True

    def write(
        self,
        blocks: Iterable[Block],
@@ -306,15 +332,6 @@ def write(
            if block.num_rows() == 0:
                continue

-            if block_idx == 0:
-                # On the first non-empty block, try to create the directory.
-                if try_create_dir:
-                    # Arrow's S3FileSystem doesn't allow creating buckets by
-                    # default, so we add a query arg enabling bucket creation
-                    # if an S3 URI is provided.
-                    tmp = _add_creatable_buckets_param_if_s3_uri(path)
-                    filesystem.create_dir(tmp, recursive=True)

            fs = _unwrap_s3_serialization_workaround(filesystem)

            if self._WRITE_FILE_PER_ROW:
@@ -367,6 +384,23 @@ def write(
        # succeeds.
        return "ok"

    def on_write_complete(
        self,
        write_results: List[WriteResult],
        path: Optional[str] = None,
        filesystem: Optional["pyarrow.fs.FileSystem"] = None,
        **kwargs,
    ) -> None:
        if not self.has_created_dir:
            return

        paths, filesystem = _resolve_paths_and_filesystem(path, filesystem)
        assert len(paths) == 1, len(paths)
        path = paths[0]

        if all(write_result == "skip" for write_result in write_results):
            filesystem.delete_dir(path)

    def _write_block(
        self,
        f: "pyarrow.NativeFile",
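At the user-facing level, this setup and cleanup now happens once per write job. A brief usage sketch, assuming a local path and that the built-in Parquet writer forwards try_create_dir to FileBasedDatasource as shown above:

import os

import ray

ds = ray.data.range(100)

# Default: the output directory is created once in on_write_start, before any
# write tasks run, instead of once per write task.
ds.write_parquet("/tmp/ray_write_demo")

# If the destination already exists, directory creation can be skipped entirely.
os.makedirs("/tmp/ray_existing_dir", exist_ok=True)
ds.write_parquet("/tmp/ray_existing_dir", try_create_dir=False)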
43 changes: 43 additions & 0 deletions python/ray/data/tests/test_file_based_datasource.py
@@ -0,0 +1,43 @@
import os

import pyarrow
import pytest

import ray
from ray.data.block import BlockAccessor
from ray.data.datasource import FileBasedDatasource


class MockFileBasedDatasource(FileBasedDatasource):
    def _write_block(
        self, f: "pyarrow.NativeFile", block: BlockAccessor, **writer_args
    ):
        f.write(b"")


@pytest.mark.parametrize("num_rows", [0, 1])
def test_write_preserves_user_directory(num_rows, tmp_path, ray_start_regular_shared):
    ds = ray.data.range(num_rows)
    path = os.path.join(tmp_path, "test")
    os.mkdir(path)  # User-created directory

    ds.write_datasource(MockFileBasedDatasource(), dataset_uuid=ds._uuid, path=path)

    assert os.path.isdir(path)


def test_write_creates_dir(tmp_path, ray_start_regular_shared):
    ds = ray.data.range(1)
    path = os.path.join(tmp_path, "test")

    ds.write_datasource(
        MockFileBasedDatasource(), dataset_uuid=ds._uuid, path=path, try_create_dir=True
    )

    assert os.path.isdir(path)


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(["-v", __file__]))
