Skip to content

Commit

Permalink
Add batch_size to map, optimize (#19489)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Feb 16, 2024
1 parent bbc5488 commit bb35e8e
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/lightning/data/processing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def _get_input_dir(inputs: Sequence[Any]) -> Optional[str]:
return "/" + os.path.join(*str(absolute_path).split("/")[:4])


def _get_default_num_workers() -> int:
if torch.cuda.is_available():
return torch.cuda.device_count()
return os.cpu_count() or 1


class LambdaDataTransformRecipe(DataTransformRecipe):
def __init__(self, fn: Callable[[str, Any], None], inputs: Sequence[Any]):
super().__init__()
Expand Down Expand Up @@ -161,6 +167,7 @@ def map(
reorder_files: bool = True,
error_when_not_empty: bool = False,
reader: Optional[BaseReader] = None,
batch_size: Optional[int] = None,
) -> None:
"""This function map a callbable over a collection of files possibly in a distributed way.
Expand All @@ -178,6 +185,7 @@ def map(
reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
Set this to ``False`` if the order in which samples are processed should be preserved.
error_when_not_empty: Whether we should error if the output folder isn't empty.
batch_size: Group the inputs into batches of batch_size length.
"""
if not isinstance(inputs, Sequence):
Expand Down Expand Up @@ -212,10 +220,13 @@ def map(

input_dir = _resolve_dir(_get_input_dir(inputs))

if isinstance(batch_size, int) and batch_size > 1:
inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]

data_processor = DataProcessor(
input_dir=input_dir,
output_dir=_output_dir,
num_workers=num_workers or os.cpu_count(),
num_workers=num_workers or _get_default_num_workers(),
fast_dev_run=fast_dev_run,
num_downloaders=num_downloaders,
num_uploaders=num_uploaders,
Expand Down Expand Up @@ -247,6 +258,7 @@ def optimize(
num_uploaders: Optional[int] = None,
reorder_files: bool = True,
reader: Optional[BaseReader] = None,
batch_size: Optional[int] = None,
) -> None:
"""This function converts a dataset into chunks possibly in a distributed way.
Expand All @@ -266,6 +278,7 @@ def optimize(
num_uploaders: The numbers of uploaders per worker.
reorder_files: By default, reorders the files by file size to distribute work equally among all workers.
Set this to ``False`` if the order in which samples are processed should be preserved.
batch_size: Group the inputs into batches of batch_size length.
"""
if not isinstance(inputs, Sequence):
Expand Down Expand Up @@ -302,10 +315,13 @@ def optimize(

input_dir = _resolve_dir(_get_input_dir(inputs))

if isinstance(batch_size, int) and batch_size > 1:
inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]

data_processor = DataProcessor(
input_dir=input_dir,
output_dir=_output_dir,
num_workers=num_workers or os.cpu_count(),
num_workers=num_workers or _get_default_num_workers(),
fast_dev_run=fast_dev_run,
num_downloaders=num_downloaders,
num_uploaders=num_uploaders,
Expand Down
19 changes: 19 additions & 0 deletions tests/tests_data/processing/test_data_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,6 +1025,25 @@ def test_map_is_last(num_workers, expected, tmpdir):
assert sorted(os.listdir(tmpdir)) == expected


def map_batch_size_fn(indexes, output_dir):
path = os.path.join(output_dir, str(indexes))
with open(path, "w") as f:
f.write("hello world")


def test_map_batch_size(tmpdir):
map(
map_batch_size_fn,
list(range(5)),
output_dir=str(tmpdir),
error_when_not_empty=False,
num_workers=1,
batch_size=2,
)

assert sorted(os.listdir(tmpdir)) == ["[0, 1]", "[2, 3]", "[4]"]


def no_op(index):
pass

Expand Down

0 comments on commit bb35e8e

Please sign in to comment.