Add feature to slice, subsample and split dataset (#161)

deependujha committed Jun 17, 2024
1 parent b51b597 commit 5c242b4
Showing 25 changed files with 900 additions and 124 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -112,3 +112,7 @@ lightning_logs

# Ruff
.ruff_cache/


# status.json file
status.json
40 changes: 38 additions & 2 deletions README.md
@@ -121,6 +121,7 @@ dataloader = StreamingDataLoader(dataset)
# Key Features

- [Multi-GPU / Multi-Node Support](#multi-gpu--multi-node-support)
- [Subsample and split your datasets](#subsample-and-split-your-datasets)
- [Access any item](#access-any-item)
- [Use any data transforms](#use-any-data-transforms)
- [The Map Operator](#the-map-operator)
@@ -131,6 +132,7 @@ dataloader = StreamingDataLoader(dataset)
- [Configure Cache Size Limit](#configure-cache-size-limit)
- [On-Prem Optimizations](#on-prem-optimizations)


## Multi-GPU / Multi-Node Support

The `StreamingDataset` and `StreamingDataLoader` automatically make sure each rank receives the same quantity of varied batches of data, so it works out of the box with your favorite frameworks ([PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/), [Lightning Fabric](https://lightning.ai/docs/fabric/stable/), or [PyTorch](https://pytorch.org/docs/stable/index.html)) to do distributed training.
@@ -139,6 +141,41 @@ Here you can see an illustration showing how the Streaming Dataset works with multi node.

![An illustration showing how the Streaming Dataset works with multi node.](https://pl-flash-data.s3.amazonaws.com/streaming_dataset.gif)

## Subsample and split your datasets

You can easily split your dataset with `train_test_split`.

```python
from litdata import StreamingDataset, train_test_split

dataset = StreamingDataset("s3://my-bucket/my-data") # data are stored in the cloud

print(len(dataset)) # display the length of your data
# out: 100,000

train_dataset, val_dataset, test_dataset = train_test_split(dataset, splits=[0.3, 0.2, 0.5])

print(len(train_dataset))
# out: 30,000

print(len(val_dataset))
# out: 20,000

print(len(test_dataset))
# out: 50,000
```

Or simply subsample them:

```python
from litdata import StreamingDataset

dataset = StreamingDataset("s3://my-bucket/my-data", subsample=0.01) # data are stored in the cloud

print(len(dataset)) # display the length of your data
# out: 1000
```
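
The two features compose. Below is a quick sketch (paths and numbers are illustrative, and it assumes `train_test_split` accepts any list of fractions, as the three-way example above suggests) that subsamples first and then splits the result:

```python
from litdata import StreamingDataset, train_test_split

# Keep ~10% of the data, then split that subset 80/20 into train and val.
dataset = StreamingDataset("s3://my-bucket/my-data", subsample=0.1)

train_dataset, val_dataset = train_test_split(dataset, splits=[0.8, 0.2])

print(len(train_dataset), len(val_dataset))
# out: 8000 2000  (for the 100,000-item dataset above)
```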

## Access any item

Access the data you need, whenever you need it, regardless of where it is stored.
@@ -209,8 +246,7 @@ Easily experiment with dataset mixtures using the `CombinedStreamingDataset` class.
As an example, this mixture of [Slimpajama](https://huggingface.co/datasets/cerebras/SlimPajama-627B) & [StarCoder](https://huggingface.co/datasets/bigcode/starcoderdata) was used in the [TinyLLAMA](https://github.com/jzhang38/TinyLlama) project to pretrain a 1.1B Llama model on 3 trillion tokens.

```python
from litdata import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader
from litdata.streaming.item_loader import TokensLoader
from litdata import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader, TokensLoader
from tqdm import tqdm
import os

2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,5 +1,5 @@
torch
filelock
numpy
numpy < 2.0.0
boto3
requests
4 changes: 4 additions & 0 deletions src/litdata/__init__.py
@@ -17,14 +17,18 @@
from litdata.streaming.combined import CombinedStreamingDataset
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.dataset import StreamingDataset
from litdata.streaming.item_loader import TokensLoader
from litdata.utilities.train_test_split import train_test_split

__all__ = [
"StreamingDataset",
"CombinedStreamingDataset",
"StreamingDataLoader",
"TokensLoader",
"map",
"optimize",
"walk",
"train_test_split",
]
if RequirementCache("lightning_sdk"):
from lightning_sdk import Machine # noqa: F401
1 change: 1 addition & 0 deletions src/litdata/constants.py
@@ -63,3 +63,4 @@

_TIME_FORMAT = "%Y-%m-%d_%H-%M-%S.%fZ"
_IS_IN_STUDIO = bool(os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None)) and bool(os.getenv("LIGHTNING_CLUSTER_ID", None))
_ENABLE_STATUS = bool(int(os.getenv("ENABLE_STATUS_REPORT", "0")))
3 changes: 2 additions & 1 deletion src/litdata/processing/data_processor.py
@@ -35,6 +35,7 @@
from litdata.constants import (
_BOTO3_AVAILABLE,
_DEFAULT_FAST_DEV_RUN_ITEMS,
_ENABLE_STATUS,
_INDEX_FILENAME,
_IS_IN_STUDIO,
_LIGHTNING_CLOUD_AVAILABLE,
@@ -995,7 +996,7 @@ def run(self, data_recipe: DataRecipe) -> None:
if current_total == num_items:
break

if _IS_IN_STUDIO and node_rank == 0:
if _IS_IN_STUDIO and node_rank == 0 and _ENABLE_STATUS:
with open("status.json", "w") as f:
json.dump({"progress": str(100 * current_total * num_nodes / total_num_items) + "%"}, f)
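
For context, a minimal sketch of how the new opt-in flag could be exercised (the env var name comes from `constants.py` above; everything else is illustrative). Note that the report is only written inside a Studio (`_IS_IN_STUDIO`) on node rank 0, and that the flag is read when `litdata.constants` is imported, so it must be set before importing litdata:

```python
import json
import os

# The flag is read at import time of litdata.constants, so set it before
# importing litdata and launching the processing job.
os.environ["ENABLE_STATUS_REPORT"] = "1"

# ... run the litdata optimize()/map() job here ...

# Inside a Studio, node rank 0 periodically writes the overall progress
# to status.json in the current working directory.
with open("status.json") as f:
    print(json.load(f)["progress"])  # e.g. "37.5%"
```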

10 changes: 8 additions & 2 deletions src/litdata/streaming/cache.py
@@ -18,7 +18,7 @@
from litdata.constants import (
_INDEX_FILENAME,
)
from litdata.streaming.item_loader import BaseItemLoader
from litdata.streaming.item_loader import BaseItemLoader, Interval
from litdata.streaming.reader import BinaryReader
from litdata.streaming.resolver import Dir, _resolve_dir
from litdata.streaming.sampler import ChunkedIndex
@@ -34,6 +34,8 @@ class Cache:
def __init__(
self,
input_dir: Optional[Union[str, Dir]],
subsampled_files: Optional[List[str]] = None,
region_of_interest: Optional[List[Tuple[int, int]]] = None,
compression: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_bytes: Optional[Union[int, str]] = None,
@@ -46,6 +48,8 @@ def __init__(
Arguments:
input_dir: The path to where the chunks will be or are stored.
subsampled_files: List of subsampled chunk filenames loaded from the `input_dir/index.json` file.
region_of_interest: List of (start, end) tuples marking the region of interest for each chunk.
compression: The name of the algorithm to reduce the size of the chunks.
chunk_bytes: The maximum number of bytes within a chunk.
chunk_size: The maximum number of items within a chunk.
@@ -67,6 +71,8 @@ def __init__(
)
self._reader = BinaryReader(
self._cache_dir,
subsampled_files=subsampled_files,
region_of_interest=region_of_interest,
max_cache_size=_convert_bytes_to_int(max_cache_size) if isinstance(max_cache_size, str) else max_cache_size,
remote_input_dir=input_dir.url,
compression=compression,
@@ -138,7 +144,7 @@ def _merge_no_wait(self, node_rank: Optional[int] = None) -> None:
def __len__(self) -> int:
return self._reader.get_length()

def get_chunk_intervals(self) -> List[Tuple[int, int]]:
def get_chunk_intervals(self) -> List[Interval]:
return self._reader.get_chunk_intervals()

def _get_chunk_index_from_index(self, index: int) -> int:
63 changes: 55 additions & 8 deletions src/litdata/streaming/config.py
@@ -18,7 +18,7 @@
from litdata.constants import _INDEX_FILENAME
from litdata.streaming.compression import _COMPRESSORS, Compressor
from litdata.streaming.downloader import get_downloader_cls
from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
from litdata.streaming.item_loader import BaseItemLoader, Interval, PyTreeLoader, TokensLoader
from litdata.streaming.sampler import ChunkedIndex
from litdata.streaming.serializers import Serializer
from litdata.utilities._pytree import tree_unflatten, treespec_loads
@@ -31,6 +31,8 @@ def __init__(
serializers: Dict[str, Serializer],
remote_dir: Optional[str],
item_loader: Optional[BaseItemLoader] = None,
subsampled_files: Optional[List[str]] = None,
region_of_interest: Optional[List[Tuple[int, int]]] = None,
) -> None:
"""The ChunksConfig reads the index files associated with a chunked dataset and enables mapping an index to its
chunk.
@@ -40,24 +42,34 @@ def __init__(
serializers: The serializers used to serialize and deserialize the chunks.
remote_dir: The path to a remote folder where the data are located.
The scheme needs to be added to the path.
subsampled_files: List of subsampled chunk filenames loaded from the `input_dir/index.json` file.
region_of_interest: List of (start, end) tuples marking the region of interest for each chunk.
"""
self._cache_dir = cache_dir
self._intervals: List[Tuple[int, int]] = []
self._intervals: List[Interval] = []
self._config = None
self._chunks = []
self._chunks = None
self._remote_dir = remote_dir
self._item_loader = item_loader or PyTreeLoader()

with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f:
data = json.load(f)
_original_chunks = data["chunks"]
self._config = data["config"]
self._validate_item_loader()
self._chunks.extend(data["chunks"])

assert _original_chunks is not None

if subsampled_files is None:
self._chunks = _original_chunks
else:
self._chunks = load_subsampled_chunks(subsampled_files, _original_chunks)

self._config["data_spec"] = treespec_loads(self._config["data_spec"])

self._item_loader.setup(self._config, self._chunks, serializers)
assert self._chunks is not None
self._item_loader.setup(self._config, self._chunks, serializers, region_of_interest)
self._intervals = self._item_loader.generate_intervals()
self._length = self._intervals[-1][-1]
self._downloader = None
@@ -87,6 +99,7 @@ def skip_chunk_indexes_deletion(self, skip_chunk_indexes_deletion: List[int]) ->
self._skip_chunk_indexes_deletion = skip_chunk_indexes_deletion

def download_chunk_from_index(self, chunk_index: int) -> None:
assert self._chunks is not None
chunk_filename = self._chunks[chunk_index]["filename"]

local_chunkpath = os.path.join(self._cache_dir, chunk_filename)
@@ -124,7 +137,7 @@ def try_decompress(self, local_chunkpath: str) -> None:
f.write(data)

@property
def intervals(self) -> List[Tuple[int, int]]:
def intervals(self) -> List[Interval]:
if self._intervals is None:
raise RuntimeError("The intervals should be defined.")
return self._intervals
@@ -133,6 +146,7 @@ def intervals(self) -> List[Interval]:
def num_bytes(self) -> int:
if self._config is None:
raise RuntimeError("The config should be defined.")
assert self._chunks is not None
return sum(c["chunk_bytes"] for c in self._chunks)

@property
@@ -167,14 +181,15 @@ def config(self) -> Dict[str, Any]:

def _get_chunk_index_from_index(self, index: int) -> int:
for chunk_index, internal in enumerate(self._intervals):
if internal[0] <= index < internal[1]:
if internal[0] <= index < internal[-1]:
return chunk_index
raise ValueError(
f"The provided index {index} didn't find a match within the chunk intervals {self._intervals}."
)

def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
"""Find the associated chunk metadata."""
assert self._chunks is not None
chunk = self._chunks[index.chunk_index]

local_chunkpath = os.path.join(self._cache_dir, chunk["filename"])
@@ -188,6 +203,7 @@ def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:

def _get_chunk_index_from_filename(self, chunk_filename: str) -> int:
"""Retrieves the associated chunk_index for a given chunk filename."""
assert self._chunks is not None
for chunk_index, chunk in enumerate(self._chunks):
if chunk["filename"] == chunk_filename:
return chunk_index
@@ -200,6 +216,8 @@ def load(
serializers: Dict[str, Serializer],
remote_dir: Optional[str] = None,
item_loader: Optional[BaseItemLoader] = None,
subsampled_files: Optional[List[str]] = None,
region_of_interest: Optional[List[Tuple[int, int]]] = None,
) -> Optional["ChunksConfig"]:
cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

@@ -210,7 +228,7 @@ def load(
if not os.path.exists(cache_index_filepath):
return None

return ChunksConfig(cache_dir, serializers, remote_dir, item_loader)
return ChunksConfig(cache_dir, serializers, remote_dir, item_loader, subsampled_files, region_of_interest)

def __len__(self) -> int:
return self._length
@@ -223,3 +241,32 @@ def _validate_item_loader(self) -> None:
and not isinstance(self._item_loader, TokensLoader)
):
raise ValueError("Please, use Cache(..., item_loader=TokensLoader(block_size=...))")


def load_subsampled_chunks(subsampled_files: List[str], original_chunks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Return the chunks listed in `subsampled_files`, in that order, built from the metadata in `original_chunks`."""
_subsampled_chunks: List[Dict[str, Any]] = [{} for _ in range(len(subsampled_files))]

assert len(_subsampled_chunks) == len(subsampled_files)

filename_dict = {}

# Populate the dictionary with filenames and their indices
for index, filename in enumerate(subsampled_files):
filename_dict[filename] = index

for curr_chunk in original_chunks:
if curr_chunk["filename"] in filename_dict:
idx = filename_dict[curr_chunk["filename"]]
_subsampled_chunks[idx] = curr_chunk

# If any entry of _subsampled_chunks is still empty, some files in subsampled_files
# were not part of the original chunks, so raise an error.
if any(not _subsampled_chunk for _subsampled_chunk in _subsampled_chunks):
raise ValueError(
"Mismatch between the subsampled files and the loaded chunks. "
"Make sure the subsampled chunks are part of the original chunks."
)

return _subsampled_chunks
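
For illustration, a tiny sketch of what `load_subsampled_chunks` returns (filenames and metadata are made up; only the `filename` and `chunk_bytes` keys read elsewhere in this diff are shown). The result follows the order of `subsampled_files`, not of `original_chunks`:

```python
original_chunks = [
    {"filename": "chunk-0-0.bin", "chunk_bytes": 128},
    {"filename": "chunk-0-1.bin", "chunk_bytes": 256},
    {"filename": "chunk-0-2.bin", "chunk_bytes": 64},
]
subsampled_files = ["chunk-0-2.bin", "chunk-0-0.bin"]

print(load_subsampled_chunks(subsampled_files, original_chunks))
# [{'filename': 'chunk-0-2.bin', 'chunk_bytes': 64},
#  {'filename': 'chunk-0-0.bin', 'chunk_bytes': 128}]
```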