diff --git a/src/litdata/__init__.py b/src/litdata/__init__.py
index 9245ac76a..3ab6fb356 100644
--- a/src/litdata/__init__.py
+++ b/src/litdata/__init__.py
@@ -13,7 +13,7 @@
 from litdata.__about__ import *  # noqa: F403
 from litdata.imports import RequirementCache
-from litdata.processing.functions import map, optimize, walk
+from litdata.processing.functions import map, merge_datasets, optimize, walk
 from litdata.streaming.combined import CombinedStreamingDataset
 from litdata.streaming.dataloader import StreamingDataLoader
 from litdata.streaming.dataset import StreamingDataset
@@ -29,6 +29,7 @@
     "optimize",
     "walk",
     "train_test_split",
+    "merge_datasets",
 ]
 if RequirementCache("lightning_sdk"):
     from lightning_sdk import Machine  # noqa: F401
diff --git a/src/litdata/constants.py b/src/litdata/constants.py
index 17c6482db..8befa208b 100644
--- a/src/litdata/constants.py
+++ b/src/litdata/constants.py
@@ -32,6 +32,7 @@
 _TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
 _ZSTD_AVAILABLE = RequirementCache("zstd")
 _GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
+_TQDM_AVAILABLE = RequirementCache("tqdm")
 
 
 # DON'T CHANGE ORDER
diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py
index 8906fe2ac..30ef688ea 100644
--- a/src/litdata/processing/data_processor.py
+++ b/src/litdata/processing/data_processor.py
@@ -39,8 +39,8 @@
     _INDEX_FILENAME,
     _IS_IN_STUDIO,
     _LIGHTNING_CLOUD_AVAILABLE,
+    _TQDM_AVAILABLE,
 )
-from litdata.imports import RequirementCache
 from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
 from litdata.processing.utilities import _create_dataset
 from litdata.streaming import Cache
@@ -52,8 +52,6 @@
 from litdata.utilities.broadcast import broadcast_object
 from litdata.utilities.packing import _pack_greedily
 
-_TQDM_AVAILABLE = RequirementCache("tqdm")
-
 if _TQDM_AVAILABLE:
     from tqdm.auto import tqdm as _tqdm
 
diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py
index c38db2b3e..6fe3199c1 100644
--- a/src/litdata/processing/functions.py
+++ b/src/litdata/processing/functions.py
@@ -13,7 +13,11 @@
 
 import concurrent.futures
 import inspect
+import json
 import os
+import shutil
+import tempfile
+from dataclasses import dataclass
 from datetime import datetime
 from functools import partial
 from pathlib import Path
@@ -23,7 +27,7 @@
 
 import torch
 
-from litdata.constants import _IS_IN_STUDIO
+from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _TQDM_AVAILABLE
 from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
 from litdata.processing.readers import BaseReader
 from litdata.processing.utilities import (
@@ -31,6 +35,7 @@
     optimize_dns_context,
     read_index_file_content,
 )
+from litdata.streaming.client import S3Client
 from litdata.streaming.dataloader import StreamingDataLoader
 from litdata.streaming.resolver import (
     Dir,
@@ -41,6 +46,13 @@
 )
 from litdata.utilities._pytree import tree_flatten
 
+if _TQDM_AVAILABLE:
+    from tqdm.auto import tqdm as _tqdm
+else:
+
+    def _tqdm(iterator: Any) -> Any:
+        yield from iterator
+
 
 def _is_remote_file(path: str) -> bool:
     obj = parse.urlparse(path)
@@ -470,3 +482,115 @@ def __iter__(self) -> Any:
                 future = executor.submit(_listdir, folder)
                 self.futures.append(future)
         return
+
+
+@dataclass
+class CopyInfo:
+    input_dir: Dir
+    old_filename: str
+    new_filename: str
+
+
+def merge_datasets(input_dirs: List[str], output_dir: str) -> None:
+    """The merge_datasets utility enables merging multiple existing optimized datasets into a single optimized
+    dataset.
+
+    Arguments:
+        input_dirs: A list of directories pointing to the existing optimized datasets.
+        output_dir: The directory where the merged dataset will be stored.
+
+    """
+    if len(input_dirs) == 0:
+        raise ValueError("The input directories need to be defined.")
+
+    if len(input_dirs) == 1:
+        raise ValueError("There should be more than one input directory.")
+
+    resolved_input_dirs = [_resolve_dir(input_dir) for input_dir in input_dirs]
+    resolved_output_dir = _resolve_dir(output_dir)
+
+    if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs):
+        raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.")
+
+    input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs]
+
+    if any(file_content is None for file_content in input_dirs_file_content):
+        raise ValueError("One of the provided input_dirs doesn't have an index file.")
+
+    output_dir_file_content = read_index_file_content(resolved_output_dir)
+
+    if output_dir_file_content is not None:
+        raise ValueError("The output_dir already contains an optimized dataset.")
+
+    assert input_dirs_file_content
+
+    for input_dir_file_content in input_dirs_file_content[1:]:
+        if input_dirs_file_content[0]["config"]["data_format"] != input_dir_file_content["config"]["data_format"]:  # type: ignore
+            raise ValueError("You are trying to merge datasets with different data formats.")
+
+        if input_dirs_file_content[0]["config"]["compression"] != input_dir_file_content["config"]["compression"]:  # type: ignore
+            raise ValueError("You are trying to merge datasets with different compression configurations.")
+
+    chunks = []
+    copy_infos: List[CopyInfo] = []
+    counter = 0
+    for input_dir, input_dir_file_content in zip(resolved_input_dirs, input_dirs_file_content):
+        for chunk in input_dir_file_content["chunks"]:  # type: ignore
+            assert isinstance(chunk, dict)
+            old_filename = chunk["filename"]
+            new_filename = f"chunk-0-{counter}.bin"
+            copy_infos.append(CopyInfo(input_dir=input_dir, old_filename=old_filename, new_filename=new_filename))
+            chunk["filename"] = new_filename
+            chunks.append(chunk)
+            counter += 1
+
+    index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks}  # type: ignore
+
+    for copy_info in _tqdm(copy_infos):
+        _apply_copy(copy_info, resolved_output_dir)
+
+    _save_index(index_json, resolved_output_dir)
+
+
+def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None:
+    if output_dir.url is None and copy_info.input_dir.url is None:
+        assert copy_info.input_dir.path
+        assert output_dir.path
+        input_filepath = os.path.join(copy_info.input_dir.path, copy_info.old_filename)
+        output_filepath = os.path.join(output_dir.path, copy_info.new_filename)
+        os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
+        shutil.copyfile(input_filepath, output_filepath)
+
+    elif output_dir.url and copy_info.input_dir.url:
+        input_obj = parse.urlparse(os.path.join(copy_info.input_dir.url, copy_info.old_filename))
+        output_obj = parse.urlparse(os.path.join(output_dir.url, copy_info.new_filename))
+
+        s3 = S3Client()
+        s3.client.copy(
+            {"Bucket": input_obj.netloc, "Key": input_obj.path.lstrip("/")},
+            output_obj.netloc,
+            output_obj.path.lstrip("/"),
+        )
+    else:
+        raise NotImplementedError
+
+
+def _save_index(index_json: Dict, output_dir: Dir) -> None:
+    if output_dir.url is None:
+        assert output_dir.path
+        with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f:
+            json.dump(index_json, f)
+    else:
+        with tempfile.NamedTemporaryFile("w") as f:
+            json.dump(index_json, f)
+
+            f.flush()
+
+            obj = parse.urlparse(os.path.join(output_dir.url, _INDEX_FILENAME))
+
+            s3 = S3Client()
+            s3.client.upload_file(
+                f.name,
+                obj.netloc,
+                obj.path.lstrip("/"),
+            )
diff --git a/src/litdata/processing/readers.py b/src/litdata/processing/readers.py
index f87c536c8..eb3c1511f 100644
--- a/src/litdata/processing/readers.py
+++ b/src/litdata/processing/readers.py
@@ -16,11 +16,11 @@
 from abc import ABC, abstractmethod
 from typing import Any, List
 
+from litdata.constants import _TQDM_AVAILABLE
 from litdata.imports import RequirementCache
 from litdata.streaming.dataloader import StreamingDataLoader
 
 _PYARROW_AVAILABLE = RequirementCache("pyarrow")
-_TQDM_AVAILABLE = RequirementCache("tqdm")
 
 if _TQDM_AVAILABLE:
     from tqdm.auto import tqdm as _tqdm
diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py
index c7bb11b1b..adabf76ea 100644
--- a/tests/processing/test_functions.py
+++ b/tests/processing/test_functions.py
@@ -3,8 +3,9 @@
 from unittest import mock
 
 import pytest
-from litdata import StreamingDataset, optimize, walk
+from litdata import StreamingDataset, merge_datasets, optimize, walk
 from litdata.processing.functions import _get_input_dir, _resolve_dir
+from litdata.streaming.cache import Cache
 
 
 @pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
@@ -154,3 +155,36 @@
 
     assert len(ds) == 5
     assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)]
+
+
+def test_merge_datasets(tmpdir):
+    folder_1 = os.path.join(tmpdir, "folder_1")
+    folder_2 = os.path.join(tmpdir, "folder_2")
+    folder_3 = os.path.join(tmpdir, "folder_3")
+
+    os.makedirs(folder_1, exist_ok=True)
+    os.makedirs(folder_2, exist_ok=True)
+
+    cache_1 = Cache(input_dir=folder_1, chunk_bytes="64MB")
+    for i in range(10):
+        cache_1[i] = i
+
+    cache_1.done()
+    cache_1.merge()
+
+    cache_2 = Cache(input_dir=folder_2, chunk_bytes="64MB")
+    for i in range(10, 20):
+        cache_2[i] = i
+
+    cache_2.done()
+    cache_2.merge()
+
+    merge_datasets(
+        input_dirs=[folder_1, folder_2],
+        output_dir=folder_3,
+    )
+
+    ds = StreamingDataset(input_dir=folder_3)
+
+    assert len(ds) == 20
+    assert ds[:] == list(range(20))
diff --git a/tests/streaming/test_reader.py b/tests/streaming/test_reader.py
index e9f2395c9..69fac570e 100644
--- a/tests/streaming/test_reader.py
+++ b/tests/streaming/test_reader.py
@@ -102,5 +102,7 @@ def test_prepare_chunks_thread_eviction(tmpdir, monkeypatch):
 
     assert thread._pre_download_counter <= 2
     assert len(os.listdir(cache_dir)) == 9
 
-    assert thread._has_exited
+    thread.join()
+    sleep(0.1)
+    assert thread._has_exited
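
A minimal usage sketch of the new merge_datasets entry point (not part of the patch). The local paths below are hypothetical; each input directory is assumed to already contain an optimized dataset (an index file plus its chunk files), for example one produced by litdata.optimize, and the output directory is assumed to be empty.

import litdata as ld

# Hypothetical paths: every entry in input_dirs must already hold an optimized dataset,
# and output_dir must not contain one yet, otherwise merge_datasets raises a ValueError.
ld.merge_datasets(
    input_dirs=["/data/optimized_part_1", "/data/optimized_part_2"],
    output_dir="/data/optimized_merged",
)

# The merged output can be read back like any other optimized dataset.
ds = ld.StreamingDataset(input_dir="/data/optimized_merged")
print(len(ds))  # total number of samples across both inputs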