3 changes: 2 additions & 1 deletion src/litdata/__init__.py
@@ -13,7 +13,7 @@

from litdata.__about__ import * # noqa: F403
from litdata.imports import RequirementCache
from litdata.processing.functions import map, optimize, walk
from litdata.processing.functions import map, merge_datasets, optimize, walk
from litdata.streaming.combined import CombinedStreamingDataset
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.dataset import StreamingDataset
@@ -29,6 +29,7 @@
"optimize",
"walk",
"train_test_split",
"merge_datasets",
]
if RequirementCache("lightning_sdk"):
from lightning_sdk import Machine # noqa: F401
1 change: 1 addition & 0 deletions src/litdata/constants.py
@@ -32,6 +32,7 @@
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
_ZSTD_AVAILABLE = RequirementCache("zstd")
_GOOGLE_STORAGE_AVAILABLE = RequirementCache("google.cloud.storage")
_TQDM_AVAILABLE = RequirementCache("tqdm")


# DON'T CHANGE ORDER
4 changes: 1 addition & 3 deletions src/litdata/processing/data_processor.py
@@ -39,8 +39,8 @@
_INDEX_FILENAME,
_IS_IN_STUDIO,
_LIGHTNING_CLOUD_AVAILABLE,
_TQDM_AVAILABLE,
)
from litdata.imports import RequirementCache
from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
from litdata.processing.utilities import _create_dataset
from litdata.streaming import Cache
@@ -52,8 +52,6 @@
from litdata.utilities.broadcast import broadcast_object
from litdata.utilities.packing import _pack_greedily

_TQDM_AVAILABLE = RequirementCache("tqdm")

if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm

126 changes: 125 additions & 1 deletion src/litdata/processing/functions.py
@@ -13,7 +13,11 @@

import concurrent.futures
import inspect
import json
import os
import shutil
import tempfile
from dataclasses import dataclass
from datetime import datetime
from functools import partial
from pathlib import Path
@@ -23,14 +27,15 @@

import torch

from litdata.constants import _IS_IN_STUDIO
from litdata.constants import _INDEX_FILENAME, _IS_IN_STUDIO, _TQDM_AVAILABLE
from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from litdata.processing.readers import BaseReader
from litdata.processing.utilities import (
extract_rank_and_index_from_filename,
optimize_dns_context,
read_index_file_content,
)
from litdata.streaming.client import S3Client
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.resolver import (
Dir,
@@ -41,6 +46,13 @@
)
from litdata.utilities._pytree import tree_flatten

if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm
else:

def _tqdm(iterator: Any) -> Any:
yield from iterator


def _is_remote_file(path: str) -> bool:
obj = parse.urlparse(path)
@@ -470,3 +482,115 @@ def __iter__(self) -> Any:
future = executor.submit(_listdir, folder)
self.futures.append(future)
return


@dataclass
class CopyInfo:
input_dir: Dir
old_filename: str
new_filename: str


def merge_datasets(input_dirs: List[str], output_dir: str) -> None:
"""The merge_datasets utility enables to merge multiple existing optimized datasets into a single optimized
dataset.

Arguments:
input_dirs: A list of directories pointing to the existing optimized datasets.
output_dir: The directory where the merged dataset will be stored.

"""
if len(input_dirs) == 0:
raise ValueError("The input directories needs to be defined.")

if len(input_dirs) == 1:
raise ValueError("There should be more than 1 input directory")

resolved_input_dirs = [_resolve_dir(input_dir) for input_dir in input_dirs]
resolved_output_dir = _resolve_dir(output_dir)

if any(input_dir == resolved_output_dir for input_dir in resolved_input_dirs):
raise ValueError("The provided output_dir was found within the input_dirs. This isn't supported.")

input_dirs_file_content = [read_index_file_content(input_dir) for input_dir in resolved_input_dirs]

if any(file_content is None for file_content in input_dirs_file_content):
raise ValueError("One of the provided input_dir doesn't have an index file.")

output_dir_file_content = read_index_file_content(resolved_output_dir)

if output_dir_file_content is not None:
raise ValueError("The output_dir already contains an optimized dataset")

assert input_dirs_file_content

for input_dir_file_content in input_dirs_file_content[1:]:
if input_dirs_file_content[0]["config"]["data_format"] != input_dir_file_content["config"]["data_format"]: # type: ignore
raise ValueError("Your are trying to merge datasets with different data formats")

if input_dirs_file_content[0]["config"]["compression"] != input_dir_file_content["config"]["compression"]: # type: ignore
raise ValueError("Your are trying to merge datasets with different compression configuration.")

chunks = []
copy_infos: List[CopyInfo] = []
counter = 0
for input_dir, input_dir_file_content in zip(resolved_input_dirs, input_dirs_file_content):
for chunk in input_dir_file_content["chunks"]: # type: ignore
assert isinstance(chunk, dict)
old_filename = chunk["filename"]
new_filename = f"chunk-0-{counter}.bin"
Review thread on this line:

Contributor (@ouj): Nice work! This is exactly what I needed. I notice that this filename is incorrect if compression == zstd. The correct filename is "chunk-0-{counter}.{compression}.bin".

Collaborator (author): Hey @ouj, you are right! Do you want to contribute a fix?

Contributor (@ouj), Jul 28, 2024: Sure, will do when I get a chance. Should be a two-line change.

Collaborator (author): BTW @ouj, you can join our Discord to follow LitData development: https://discord.gg/BH765hvQ

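A minimal sketch of the follow-up fix discussed above (hypothetical, not part of this PR), reusing the compression value already read from the first index file:

# Sketch only: pick the chunk filename suffix based on the dataset's compression setting.
compression = input_dirs_file_content[0]["config"]["compression"]
new_filename = f"chunk-0-{counter}.{compression}.bin" if compression else f"chunk-0-{counter}.bin"
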
copy_infos.append(CopyInfo(input_dir=input_dir, old_filename=old_filename, new_filename=new_filename))
chunk["filename"] = new_filename
chunks.append(chunk)
counter += 1

index_json = {"config": input_dirs_file_content[0]["config"], "chunks": chunks} # type: ignore

for copy_info in _tqdm(copy_infos):
_apply_copy(copy_info, resolved_output_dir)

_save_index(index_json, resolved_output_dir)


def _apply_copy(copy_info: CopyInfo, output_dir: Dir) -> None:
if output_dir.url is None and copy_info.input_dir.url is None:
assert copy_info.input_dir.path
assert output_dir.path
input_filepath = os.path.join(copy_info.input_dir.path, copy_info.old_filename)
output_filepath = os.path.join(output_dir.path, copy_info.new_filename)
os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
shutil.copyfile(input_filepath, output_filepath)

elif output_dir.url and copy_info.input_dir.url:
input_obj = parse.urlparse(os.path.join(copy_info.input_dir.url, copy_info.old_filename))
output_obj = parse.urlparse(os.path.join(output_dir.url, copy_info.new_filename))

s3 = S3Client()
s3.client.copy(
{"Bucket": input_obj.netloc, "Key": input_obj.path.lstrip("/")},
output_obj.netloc,
output_obj.path.lstrip("/"),
)
else:
raise NotImplementedError


def _save_index(index_json: Dict, output_dir: Dir) -> None:
if output_dir.url is None:
assert output_dir.path
with open(os.path.join(output_dir.path, _INDEX_FILENAME), "w") as f:
json.dump(index_json, f)
else:
with tempfile.NamedTemporaryFile("w") as f:
json.dump(index_json, f)

f.flush()

obj = parse.urlparse(os.path.join(output_dir.url, _INDEX_FILENAME))

s3 = S3Client()
s3.client.upload_file(
f.name,
obj.netloc,
obj.path.lstrip("/"),
)
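
For reference, a minimal usage sketch of the new merge_datasets API (the dataset paths below are hypothetical):

from litdata import merge_datasets

# Merge two previously optimized datasets into one streamable dataset.
merge_datasets(
    input_dirs=["./dataset_a", "./dataset_b"],
    output_dir="./dataset_merged",
)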
2 changes: 1 addition & 1 deletion src/litdata/processing/readers.py
@@ -16,11 +16,11 @@
from abc import ABC, abstractmethod
from typing import Any, List

from litdata.constants import _TQDM_AVAILABLE
from litdata.imports import RequirementCache
from litdata.streaming.dataloader import StreamingDataLoader

_PYARROW_AVAILABLE = RequirementCache("pyarrow")
_TQDM_AVAILABLE = RequirementCache("tqdm")

if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm
36 changes: 35 additions & 1 deletion tests/processing/test_functions.py
@@ -3,8 +3,9 @@
from unittest import mock

import pytest
from litdata import StreamingDataset, optimize, walk
from litdata import StreamingDataset, merge_datasets, optimize, walk
from litdata.processing.functions import _get_input_dir, _resolve_dir
from litdata.streaming.cache import Cache


@pytest.mark.skipif(sys.platform == "win32", reason="currently not supported for windows.")
@@ -154,3 +155,36 @@ def test_optimize_append_overwrite(tmpdir):

assert len(ds) == 5
assert ds[:] == [(i, i**2, i**3) for i in range(0, 5)]


def test_merge_datasets(tmpdir):
folder_1 = os.path.join(tmpdir, "folder_1")
folder_2 = os.path.join(tmpdir, "folder_2")
folder_3 = os.path.join(tmpdir, "folder_3")

os.makedirs(folder_1, exist_ok=True)
os.makedirs(folder_2, exist_ok=True)

cache_1 = Cache(input_dir=folder_1, chunk_bytes="64MB")
for i in range(10):
cache_1[i] = i

cache_1.done()
cache_1.merge()

cache_2 = Cache(input_dir=folder_2, chunk_bytes="64MB")
for i in range(10, 20):
cache_2[i] = i

cache_2.done()
cache_2.merge()

merge_datasets(
input_dirs=[folder_1, folder_2],
output_dir=folder_3,
)

ds = StreamingDataset(input_dir=folder_3)

assert len(ds) == 20
assert ds[:] == list(range(20))
4 changes: 3 additions & 1 deletion tests/streaming/test_reader.py
@@ -102,5 +102,7 @@ def test_prepare_chunks_thread_eviction(tmpdir, monkeypatch):
assert thread._pre_download_counter <= 2

assert len(os.listdir(cache_dir)) == 9
assert thread._has_exited

thread.join()
sleep(0.1)
assert thread._has_exited