Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions src/litdata/processing/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,12 +521,13 @@ class CopyInfo:
new_filename: str


def merge_datasets(input_dirs: List[str], output_dir: str) -> None:
def merge_datasets(input_dirs: List[str], output_dir: str, max_workers: Optional[int] = os.cpu_count()) -> None:
"""Enables to merge multiple existing optimized datasets into a single optimized dataset.

Args:
input_dirs: A list of directories pointing to the existing optimized datasets.
output_dir: The directory where the merged dataset would be stored.
max_workers: The maximum number of worker threads used to copy files concurrently.
            Defaults to ``os.cpu_count()``; falls back to 1 if that value is ``None``.

"""
if len(input_dirs) == 0:
Expand All @@ -537,6 +538,7 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None:

resolved_input_dirs = [_resolve_dir(input_dir) for input_dir in input_dirs]
resolved_output_dir = _resolve_dir(output_dir)
max_workers = max_workers or 1

if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs):
raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.")
Expand Down Expand Up @@ -580,8 +582,11 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None:

_tqdm = _get_tqdm_iterator_if_available()

for copy_info in _tqdm(copy_infos):
_apply_copy(copy_info, resolved_output_dir)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures: List[concurrent.futures.Future] = []
for copy_info in _tqdm(copy_infos):
future = executor.submit(_apply_copy, copy_info, resolved_output_dir)
futures.append(future)

_save_index(index_json, resolved_output_dir)

Expand Down
Loading