From e8bffc6a509df525b1f6964fbc57e2e1f19d5326 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 27 Jun 2024 14:59:38 +0000 Subject: [PATCH 1/8] update --- src/litdata/__init__.py | 3 +- src/litdata/processing/functions.py | 104 +++++++++++++++++++++++++++- 2 files changed, 103 insertions(+), 4 deletions(-) diff --git a/src/litdata/__init__.py b/src/litdata/__init__.py index 9245ac76a..406fc383b 100644 --- a/src/litdata/__init__.py +++ b/src/litdata/__init__.py @@ -13,7 +13,7 @@ from litdata.__about__ import * # noqa: F403 from litdata.imports import RequirementCache -from litdata.processing.functions import map, optimize, walk +from litdata.processing.functions import map, optimize, walk, merge_datasets from litdata.streaming.combined import CombinedStreamingDataset from litdata.streaming.dataloader import StreamingDataLoader from litdata.streaming.dataset import StreamingDataset @@ -29,6 +29,7 @@ "optimize", "walk", "train_test_split", + "merge_datasets", ] if RequirementCache("lightning_sdk"): from lightning_sdk import Machine # noqa: F401 diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index c38db2b3e..524b2e1de 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -20,10 +20,10 @@ from types import FunctionType from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union from urllib import parse - +from dataclasses import dataclass import torch - -from litdata.constants import _IS_IN_STUDIO +import shutil +from litdata.constants import _IS_IN_STUDIO, _INDEX_FILENAME from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe from litdata.processing.readers import BaseReader from litdata.processing.utilities import ( @@ -31,6 +31,9 @@ optimize_dns_context, read_index_file_content, ) +import tempfile +from litdata.streaming.client import S3Client +import json from litdata.streaming.dataloader import StreamingDataLoader from litdata.streaming.resolver import ( Dir, @@ -39,7 +42,9 @@ _execute, _resolve_dir, ) +from io import BytesIO from litdata.utilities._pytree import tree_flatten +from tqdm.auto import tqdm def _is_remote_file(path: str) -> bool: @@ -470,3 +475,96 @@ def __iter__(self) -> Any: future = executor.submit(_listdir, folder) self.futures.append(future) return + + +@dataclass +class CopyInfo: + input_dir: Dir + old_filename: str + new_filename: str + + +def merge_datasets(input_dirs: List[str], output_dir: str) -> None: + if len(input_dirs) == 0: + raise ValueError("The input directories needs to be defined.") + + if len(input_dirs) == 1: + raise ValueError("There should be more than 1 input directory") + + resolved_input_dirs = [_resolve_dir(input_dir) for input_dir in input_dirs] + resolved_output_dir = _resolve_dir(output_dir) + + input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs] + output_dir_file_content = read_index_file_content(resolved_output_dir) + + if output_dir_file_content is not None: + raise ValueError("The output_dir already contains an optimized dataset") + + for input_dir_file_content in input_dirs_file_content[1:]: + if input_dirs_file_content[0]['config']['data_format'] != input_dir_file_content['config']['data_format']: + raise ValueError("Your are trying to merge datasets with different data formats") + + if input_dirs_file_content[0]['config']['compression'] != input_dir_file_content['config']['compression']: + raise ValueError("Your are trying to merge 
datasets with different compression configuration.") + + chunks = [] + copy_infos: List[CopyInfo] = [] + counter = 0 + for input_dir, input_dir_file_content in zip(resolved_input_dirs, input_dirs_file_content): + for chunk in input_dir_file_content["chunks"]: + old_filename = chunk["filename"] + new_filename = f"chunk-0-{counter}.bin" + copy_infos.append(CopyInfo(input_dir=input_dir, old_filename=old_filename, new_filename=new_filename)) + chunk['filename'] = new_filename + chunks.append(chunk) + counter += 1 + + index_json = { + "config": input_dirs_file_content[0]['config'], + "chunks": chunks + } + + for copy_info in tqdm(copy_infos): + _apply_copy(copy_info, resolved_output_dir) + + _save_index(index_json, resolved_output_dir) + + +def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: + if output_dir.url is None and copy_info.input_dir.url is None: + input_filepath = os.path.join(copy_info.input_dir.path, copy_info.old_filename) + output_filepath = os.path.join(output_dir.path, copy_info.new_filename) + os.makedirs(os.path.dirname(output_filepath), exist_ok=True) + shutil.copyfile(input_filepath, output_filepath) + + elif output_dir.url and copy_info.input_dir.url: + input_obj = parse.urlparse(os.path.join(copy_info.input_dir.url, copy_info.old_filename)) + output_obj = parse.urlparse(os.path.join(output_dir.url, copy_info.new_filename)) + + s3 = S3Client() + s3.client.copy( + {"Bucket": input_obj.netloc, "Key": input_obj.path.lstrip("/")}, + output_obj.netloc, + output_obj.path.lstrip("/"), + ) + else: + raise NotImplementedError + +def _save_index(index_json: Dict, output_dir: Dir) -> None: + if output_dir.url is None: + with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f: + json.dump(index_json, f) + else: + with tempfile.NamedTemporaryFile("w") as f: + json.dump(index_json, f) + + f.flush() + + obj = parse.urlparse(os.path.join(output_dir.url, _INDEX_FILENAME)) + + s3 = S3Client() + s3.client.upload_file( + f.name, + obj.netloc, + obj.path.lstrip("/"), + ) \ No newline at end of file From d9a731bb70e8591c6decb78648c6ea837bd1db51 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 27 Jun 2024 15:03:26 +0000 Subject: [PATCH 2/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litdata/__init__.py | 2 +- src/litdata/processing/functions.py | 31 ++++++++++++++--------------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/src/litdata/__init__.py b/src/litdata/__init__.py index 406fc383b..3ab6fb356 100644 --- a/src/litdata/__init__.py +++ b/src/litdata/__init__.py @@ -13,7 +13,7 @@ from litdata.__about__ import * # noqa: F403 from litdata.imports import RequirementCache -from litdata.processing.functions import map, optimize, walk, merge_datasets +from litdata.processing.functions import map, merge_datasets, optimize, walk from litdata.streaming.combined import CombinedStreamingDataset from litdata.streaming.dataloader import StreamingDataLoader from litdata.streaming.dataset import StreamingDataset diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 524b2e1de..ec312e7b0 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -13,17 +13,22 @@ import concurrent.futures import inspect +import json import os +import shutil +import tempfile +from dataclasses import dataclass from datetime import datetime from functools import partial from pathlib 
import Path from types import FunctionType from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union from urllib import parse -from dataclasses import dataclass + import torch -import shutil -from litdata.constants import _IS_IN_STUDIO, _INDEX_FILENAME +from tqdm.auto import tqdm + +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe from litdata.processing.readers import BaseReader from litdata.processing.utilities import ( @@ -31,9 +36,7 @@ optimize_dns_context, read_index_file_content, ) -import tempfile from litdata.streaming.client import S3Client -import json from litdata.streaming.dataloader import StreamingDataLoader from litdata.streaming.resolver import ( Dir, @@ -42,9 +45,7 @@ _execute, _resolve_dir, ) -from io import BytesIO from litdata.utilities._pytree import tree_flatten -from tqdm.auto import tqdm def _is_remote_file(path: str) -> bool: @@ -496,15 +497,15 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs] output_dir_file_content = read_index_file_content(resolved_output_dir) - + if output_dir_file_content is not None: raise ValueError("The output_dir already contains an optimized dataset") for input_dir_file_content in input_dirs_file_content[1:]: - if input_dirs_file_content[0]['config']['data_format'] != input_dir_file_content['config']['data_format']: + if input_dirs_file_content[0]["config"]["data_format"] != input_dir_file_content["config"]["data_format"]: raise ValueError("Your are trying to merge datasets with different data formats") - if input_dirs_file_content[0]['config']['compression'] != input_dir_file_content['config']['compression']: + if input_dirs_file_content[0]["config"]["compression"] != input_dir_file_content["config"]["compression"]: raise ValueError("Your are trying to merge datasets with different compression configuration.") chunks = [] @@ -515,14 +516,11 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: old_filename = chunk["filename"] new_filename = f"chunk-0-{counter}.bin" copy_infos.append(CopyInfo(input_dir=input_dir, old_filename=old_filename, new_filename=new_filename)) - chunk['filename'] = new_filename + chunk["filename"] = new_filename chunks.append(chunk) counter += 1 - index_json = { - "config": input_dirs_file_content[0]['config'], - "chunks": chunks - } + index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks} for copy_info in tqdm(copy_infos): _apply_copy(copy_info, resolved_output_dir) @@ -550,6 +548,7 @@ def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: else: raise NotImplementedError + def _save_index(index_json: Dict, output_dir: Dir) -> None: if output_dir.url is None: with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f: @@ -567,4 +566,4 @@ def _save_index(index_json: Dict, output_dir: Dir) -> None: f.name, obj.netloc, obj.path.lstrip("/"), - ) \ No newline at end of file + ) From 55356ddb1cbf9ac057feaf3777628e3592cc6cc1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 27 Jun 2024 16:19:34 +0100 Subject: [PATCH 3/8] update --- tests/processing/test_functions.py | 36 +++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/tests/processing/test_functions.py b/tests/processing/test_functions.py index c7bb11b1b..adabf76ea 100644 --- a/tests/processing/test_functions.py 
+++ b/tests/processing/test_functions.py @@ -3,8 +3,9 @@ from unittest import mock import pytest -from litdata import StreamingDataset, optimize, walk +from litdata import StreamingDataset, merge_datasets, optimize, walk from litdata.processing.functions import _get_input_dir, _resolve_dir +from litdata.streaming.cache import Cache @pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.") @@ -154,3 +155,36 @@ def test_optimize_append_overwrite(tmpdir): assert len(ds) == 5 assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)] + + +def test_merge_datasets(tmpdir): + folder_1 = os.path.join(tmpdir, "folder_1") + folder_2 = os.path.join(tmpdir, "folder_2") + folder_3 = os.path.join(tmpdir, "folder_3") + + os.makedirs(folder_1, exist_ok=True) + os.makedirs(folder_2, exist_ok=True) + + cache_1 = Cache(input_dir=folder_1, chunk_bytes="64MB") + for i in range(10): + cache_1[i] = i + + cache_1.done() + cache_1.merge() + + cache_2 = Cache(input_dir=folder_2, chunk_bytes="64MB") + for i in range(10, 20): + cache_2[i] = i + + cache_2.done() + cache_2.merge() + + merge_datasets( + input_dirs=[folder_1, folder_2], + output_dir=folder_3, + ) + + ds = StreamingDataset(input_dir=folder_3) + + assert len(ds) == 20 + assert ds[:] == list(range(20)) From 3344dfcc2a8da3092e6cf80acb3eb066023ad067 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 27 Jun 2024 16:25:33 +0100 Subject: [PATCH 4/8] update --- src/litdata/processing/functions.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index ec312e7b0..09a463bd4 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -486,6 +486,14 @@ class CopyInfo: def merge_datasets(input_dirs: List[str], output_dir: str) -> None: + """The merge_datasets utility enables to merge multiple existing optimized datasets into a single optimized + dataset. + + Arguments: + input_dirs: A list of directories pointing to the existing optimized datasets. + output_dir: The directory where the merged dataset would be stored. 
+ + """ if len(input_dirs) == 0: raise ValueError("The input directories needs to be defined.") @@ -496,6 +504,10 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: resolved_output_dir = _resolve_dir(output_dir) input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs] + + if any(file_content is None for file_content in input_dirs_file_content): + raise ValueError("One of the provided input_dir doesn't have an index file.") + output_dir_file_content = read_index_file_content(resolved_output_dir) if output_dir_file_content is not None: From d6c3f38ac892279107d5bb2de70b733aeb1c6f8c Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 27 Jun 2024 16:34:29 +0100 Subject: [PATCH 5/8] update --- src/litdata/constants.py | 1 + src/litdata/processing/data_processor.py | 4 +--- src/litdata/processing/functions.py | 12 +++++++++--- src/litdata/processing/readers.py | 2 +- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/litdata/constants.py b/src/litdata/constants.py index 17c6482db..8befa208b 100644 --- a/src/litdata/constants.py +++ b/src/litdata/constants.py @@ -32,6 +32,7 @@ _TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio") _ZSTD_AVAILABLE = RequirementCache("zstd") _GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage") +_TQDM_AVAILABLE = RequirementCache("tqdm") # DON'T CHANGE ORDER diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 8906fe2ac..30ef688ea 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -39,8 +39,8 @@ _INDEX_FILENAME, _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE, + _TQDM_AVAILABLE, ) -from litdata.imports import RequirementCache from litdata.processing.readers import BaseReader, StreamingDataLoaderReader from litdata.processing.utilities import _create_dataset from litdata.streaming import Cache @@ -52,8 +52,6 @@ from litdata.utilities.broadcast import broadcast_object from litdata.utilities.packing import _pack_greedily -_TQDM_AVAILABLE = RequirementCache("tqdm") - if _TQDM_AVAILABLE: from tqdm.auto import tqdm as _tqdm diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 09a463bd4..b282f514a 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -26,9 +26,8 @@ from urllib import parse import torch -from tqdm.auto import tqdm -from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO +from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _TQDM_AVAILABLE from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe from litdata.processing.readers import BaseReader from litdata.processing.utilities import ( @@ -47,6 +46,13 @@ ) from litdata.utilities._pytree import tree_flatten +if _TQDM_AVAILABLE: + from tqdm.auto import tqdm as _tqdm +else: + + def _tqdm(iterator: Any) -> Any: + yield from iterator + def _is_remote_file(path: str) -> bool: obj = parse.urlparse(path) @@ -534,7 +540,7 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks} - for copy_info in tqdm(copy_infos): + for copy_info in _tqdm(copy_infos): _apply_copy(copy_info, resolved_output_dir) _save_index(index_json, resolved_output_dir) diff --git a/src/litdata/processing/readers.py b/src/litdata/processing/readers.py index f87c536c8..eb3c1511f 100644 --- a/src/litdata/processing/readers.py +++ 
b/src/litdata/processing/readers.py @@ -16,11 +16,11 @@ from abc import ABC, abstractmethod from typing import Any, List +from litdata.constants import _TQDM_AVAILABLE from litdata.imports import RequirementCache from litdata.streaming.dataloader import StreamingDataLoader _PYARROW_AVAILABLE = RequirementCache("pyarrow") -_TQDM_AVAILABLE = RequirementCache("tqdm") if _TQDM_AVAILABLE: from tqdm.auto import tqdm as _tqdm From 4eb7d50cbd9cdcfaa85585272d53e151e2c959fa Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 27 Jun 2024 16:42:08 +0100 Subject: [PATCH 6/8] update --- src/litdata/processing/functions.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index b282f514a..3f798ca1b 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -519,18 +519,21 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: if output_dir_file_content is not None: raise ValueError("The output_dir already contains an optimized dataset") + assert input_dirs_file_content + for input_dir_file_content in input_dirs_file_content[1:]: - if input_dirs_file_content[0]["config"]["data_format"] != input_dir_file_content["config"]["data_format"]: + if input_dirs_file_content[0]["config"]["data_format"] != input_dir_file_content["config"]["data_format"]: # type: ignore raise ValueError("Your are trying to merge datasets with different data formats") - if input_dirs_file_content[0]["config"]["compression"] != input_dir_file_content["config"]["compression"]: + if input_dirs_file_content[0]["config"]["compression"] != input_dir_file_content["config"]["compression"]: # type: ignore raise ValueError("Your are trying to merge datasets with different compression configuration.") chunks = [] copy_infos: List[CopyInfo] = [] counter = 0 for input_dir, input_dir_file_content in zip(resolved_input_dirs, input_dirs_file_content): - for chunk in input_dir_file_content["chunks"]: + for chunk in input_dir_file_content["chunks"]: # type: ignore + assert isinstance(chunk, dict) old_filename = chunk["filename"] new_filename = f"chunk-0-{counter}.bin" copy_infos.append(CopyInfo(input_dir=input_dir, old_filename=old_filename, new_filename=new_filename)) @@ -538,7 +541,7 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: chunks.append(chunk) counter += 1 - index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks} + index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks} # type: ignore for copy_info in _tqdm(copy_infos): _apply_copy(copy_info, resolved_output_dir) @@ -548,6 +551,8 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: if output_dir.url is None and copy_info.input_dir.url is None: + assert copy_info.input_dir.path + assert output_dir.path input_filepath = os.path.join(copy_info.input_dir.path, copy_info.old_filename) output_filepath = os.path.join(output_dir.path, copy_info.new_filename) os.makedirs(os.path.dirname(output_filepath), exist_ok=True) @@ -569,6 +574,7 @@ def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None: def _save_index(index_json: Dict, output_dir: Dir) -> None: if output_dir.url is None: + assert output_dir.path with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f: json.dump(index_json, f) else: From a0a8d2f2cee23882e624e93c280861ce35a4bba0 Mon Sep 17 00:00:00 2001 From: tchaton Date: 
Thu, 27 Jun 2024 17:07:07 +0100 Subject: [PATCH 7/8] update --- src/litdata/processing/functions.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 3f798ca1b..6fe3199c1 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -509,6 +509,9 @@ def merge_datasets(input_dirs: List[str], output_dir: str) -> None: resolved_input_dirs = [_resolve_dir(input_dir) for input_dir in input_dirs] resolved_output_dir = _resolve_dir(output_dir) + if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs): + raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.") + input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs] if any(file_content is None for file_content in input_dirs_file_content): From 92cea4e4086164bd107eed37cb8dad792b629150 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 27 Jun 2024 17:20:06 +0100 Subject: [PATCH 8/8] update --- tests/streaming/test_reader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/streaming/test_reader.py b/tests/streaming/test_reader.py index e9f2395c9..69fac570e 100644 --- a/tests/streaming/test_reader.py +++ b/tests/streaming/test_reader.py @@ -102,5 +102,7 @@ def test_prepare_chunks_thread_eviction(tmpdir, monkeypatch): assert thread._pre_download_counter <= 2 assert len(os.listdir(cache_dir)) == 9 - assert thread._has_exited + thread.join() + sleep(0.1) + assert thread._has_exited
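
For context, a minimal usage sketch of the API this series adds. It is assembled from test_merge_datasets (patch 3) plus the existing litdata.optimize entry point; the directory names, the identity fn, and the chunk_bytes / num_workers values are illustrative assumptions, not values mandated by the patches. Per patch 1, the input datasets must share the same data_format and compression, output_dir must not already contain an optimized dataset, and mixing local and remote (s3://) locations between inputs and output raises NotImplementedError.

    # Minimal sketch: build two optimized datasets, merge them, and read the result back.
    # Assumes a litdata build that includes these patches; paths and settings are illustrative.
    from litdata import StreamingDataset, merge_datasets, optimize


    def fn(i):
        # Identity transform, mirroring the samples written in test_merge_datasets.
        return i


    if __name__ == "__main__":
        optimize(fn=fn, inputs=list(range(10)), output_dir="data/part_1", chunk_bytes="64MB", num_workers=1)
        optimize(fn=fn, inputs=list(range(10, 20)), output_dir="data/part_2", chunk_bytes="64MB", num_workers=1)

        # Chunks are copied into output_dir as chunk-0-<counter>.bin and a fresh index.json is written.
        merge_datasets(input_dirs=["data/part_1", "data/part_2"], output_dir="data/merged")

        ds = StreamingDataset(input_dir="data/merged")
        assert len(ds) == 20
        assert ds[:] == list(range(20))

Fully remote datasets work the same way with s3:// URLs for both input_dirs and output_dir, in which case the chunk copies are performed server-side via S3Client.copy rather than downloaded and re-uploaded.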