Add support for text (#18807)
* update

* update

* update

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------

Co-authored-by: thomas <thomas@thomass-MacBook-Pro.local>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people committed Oct 19, 2023
1 parent 3f86ad7 commit c68ff64
Showing 15 changed files with 383 additions and 105 deletions.
6 changes: 3 additions & 3 deletions .github/CODEOWNERS
@@ -42,9 +42,9 @@
/src/lightning/pytorch/core/module.py @williamfalcon @tchaton @awaelchli @carmocca

# Data Utilities
/examples/data/ @nohalon @justusschock @lantiga
/src/lightning/data/ @nohalon @justusschock @lantiga
/tests/tests_data @nohalon @justusschock @lantiga
/examples/data/ @tchaton @nohalon @justusschock @lantiga
/src/lightning/data/ @tchaton @nohalon @justusschock @lantiga
/tests/tests_data @tchaton @nohalon @justusschock @lantiga

# Lightning Fabric
/src/lightning/fabric @awaelchli @carmocca @justusschock
8 changes: 7 additions & 1 deletion src/lightning/data/streaming/cache.py
@@ -21,6 +21,7 @@
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42,
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.data.streaming.item_loader import BaseItemLoader
from lightning.data.streaming.reader import BinaryReader
from lightning.data.streaming.sampler import ChunkedIndex
from lightning.data.streaming.writer import BinaryWriter
@@ -41,6 +42,7 @@ def __init__(
compression: Optional[str] = None,
chunk_size: Optional[int] = None,
chunk_bytes: Optional[int] = None,
item_loader: Optional[BaseItemLoader] = None,
):
"""The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
together in order to accelerate fetching.
@@ -54,6 +56,7 @@ def __init__(
compression: The name of the algorithm to reduce the size of the chunks.
chunk_bytes: The maximum number of bytes within a chunk.
chunk_size: The maximum number of items within a chunk.
item_loader: The object responsible for generating the chunk intervals and loading an item from a chunk.
"""
super().__init__()
@@ -71,7 +74,10 @@ def __init__(
str(cache_dir), chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression
)
self._reader = BinaryReader(
str(cache_dir), remote_dir=remote_dir, compression=compression, name=name, version=version
str(cache_dir),
remote_dir=remote_dir,
compression=compression,
item_loader=item_loader,
)
self._cache_dir = str(cache_dir)
self._is_done = False
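Together with the TokensLoader hint in config.py below, a minimal usage sketch of the new argument could look as follows. The cache directory, chunk budget, and block size are invented for illustration, and the first positional argument is assumed to be the cache directory:

from lightning.data.streaming.cache import Cache
from lightning.data.streaming.item_loader import TokensLoader

# Sketch only: paths and sizes below are assumptions, not from this diff.
# TokensLoader serves fixed-size token blocks; leaving item_loader unset
# falls back to the default PyTreeLoader (see config.py below).
cache = Cache(
    "/tmp/token_cache",  # assumed cache directory
    chunk_bytes=1 << 26,  # ~64 MB per chunk (illustrative)
    item_loader=TokensLoader(block_size=2048),  # block size is an assumption
)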
36 changes: 20 additions & 16 deletions src/lightning/data/streaming/config.py
@@ -17,14 +17,15 @@

from lightning.data.streaming.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0
from lightning.data.streaming.downloader import get_downloader_cls
from lightning.data.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
from lightning.data.streaming.sampler import ChunkedIndex

if _TORCH_GREATER_EQUAL_2_1_0:
from torch.utils._pytree import treespec_loads


class ChunksConfig:
def __init__(self, cache_dir: str, remote_dir: Optional[str]):
def __init__(self, cache_dir: str, remote_dir: Optional[str], item_loader: Optional[BaseItemLoader] = None) -> None:
"""The ChunksConfig reads the index files associated a chunked dataset and enables to map an index to its
chunk.
@@ -39,27 +40,19 @@ def __init__(self, cache_dir: str, remote_dir: Optional[str]):
self._config = None
self._chunks = []
self._remote_dir = remote_dir
self._item_loader = item_loader or PyTreeLoader()

with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f:
data = json.load(f)

self._config = data["config"]

self._validate_item_loader()
self._chunks.extend(data["chunks"])

self._config["data_spec"] = treespec_loads(self._config["data_spec"])

for chunk in self._chunks:
start, end = chunk["interval"]
if (end - start) != chunk["chunk_size"]:
raise Exception(
"The config intervals doesn't match the number of samples. This shouldn't have happened."
f" Found {end} {start} {chunk['chunk_size']}"
)
self._intervals.append((chunk["interval"][0], chunk["interval"][1]))

self._length = sum([chunk["chunk_size"] for chunk in self._chunks])

self._item_loader.setup(self._config, self._chunks)
self._intervals = self._item_loader.generate_intervals()
self._length = self._intervals[-1][-1]
self._downloader = None

if remote_dir:
@@ -110,7 +103,9 @@ def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index]

@classmethod
def load(cls, cache_dir: str, remote_dir: Optional[str] = None) -> Optional["ChunksConfig"]:
def load(
cls, cache_dir: str, remote_dir: Optional[str] = None, item_loader: Optional[BaseItemLoader] = None
) -> Optional["ChunksConfig"]:
cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

if isinstance(remote_dir, str):
@@ -120,7 +115,16 @@ def load(cls, cache_dir: str, remote_dir: Optional[str] = None) -> Optional["ChunksConfig"]:
if not os.path.exists(cache_index_filepath):
return None

return ChunksConfig(cache_dir, remote_dir)
return ChunksConfig(cache_dir, remote_dir, item_loader)

def __len__(self) -> int:
return self._length

def _validate_item_loader(self) -> None:
assert self._config
if (
len(self._config["data_format"]) == 1
and self._config["data_format"][0].startswith("no_header_tensor")
and not isinstance(self._item_loader, TokensLoader)
):
raise ValueError("Please, use Cache(..., item_loader=TokensLoader(block_size=...))")
25 changes: 25 additions & 0 deletions src/lightning/data/streaming/constants.py
@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from lightning_utilities.core.imports import RequirementCache

_INDEX_FILENAME = "index.json"
@@ -22,3 +23,27 @@
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
_LIGHTNING_CLOUD_GREATER_EQUAL_0_5_42 = RequirementCache("lightning-cloud>=0.5.42")
_BOTO3_AVAILABLE = RequirementCache("boto3")

# DON'T CHANGE ORDER
_TORCH_DTYPES_MAPPING = {
0: torch.float32,
1: torch.float,
2: torch.float64,
3: torch.double,
4: torch.complex64,
5: torch.cfloat,
6: torch.complex128,
7: torch.cdouble,
8: torch.float16,
9: torch.half,
10: torch.bfloat16, # Not supported https://github.com/pytorch/pytorch/issues/110285
11: torch.uint8,
12: torch.int8,
13: torch.int16,
14: torch.short,
15: torch.int32,
16: torch.int,
17: torch.int64,
18: torch.long,
19: torch.bool,
}
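The "DON'T CHANGE ORDER" warning suggests these integer keys are persisted inside cached chunks as a stable dtype encoding, so reordering them would corrupt existing caches. A hedged sketch of the round-trip the table enables; the reverse-lookup helpers are illustrative, not part of this commit:

import torch

from lightning.data.streaming.constants import _TORCH_DTYPES_MAPPING

# Illustrative inverse table. Aliases such as torch.float/torch.float32 are
# the same object, so the later index wins for an aliased dtype.
_DTYPE_TO_INDEX = {dtype: index for index, dtype in _TORCH_DTYPES_MAPPING.items()}

def encode_dtype(dtype: torch.dtype) -> int:
    return _DTYPE_TO_INDEX[dtype]

def decode_dtype(index: int) -> torch.dtype:
    return _TORCH_DTYPES_MAPPING[index]

assert decode_dtype(encode_dtype(torch.int16)) is torch.int16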
33 changes: 21 additions & 12 deletions src/lightning/data/streaming/dataset_optimizer.py
@@ -2,6 +2,7 @@
import os
import signal
import traceback
import types
from abc import ABC, abstractmethod
from enum import Enum
from multiprocessing import Process, Queue
@@ -166,7 +167,7 @@ def __init__(
start_index: int,
dataset_name: str,
node_rank: int,
prepare_item: Callable,
dataset_optimizer: "DatasetOptimizer",
src_dir: str,
remote_src_dir: str,
remote_dst_dir: Optional[str],
@@ -185,7 +186,7 @@ def __init__(
self.start_index = start_index
self.dataset_name = dataset_name
self.node_rank = node_rank
self.prepare_item = prepare_item
self.prepare_item = dataset_optimizer.prepare_item
self.src_dir = src_dir
self.remote_src_dir = remote_src_dir
self.remote_dst_dir = remote_dst_dir
@@ -207,6 +208,7 @@ def __init__(
self.uploader: Optional[Process] = None
self._collected_items = 0
self._counter = 0
self._index_counter = 0

def run(self) -> None:
try:
@@ -250,14 +252,17 @@ def _loop(self) -> None:
return
continue

item_index = index + self.start_index
item_data = self.prepare_item(self.items[index]) if self.prepare_item else self.items[index] # type: ignore
chunk_filepath = self.cache._add_item(item_index, item_data)

self._try_upload(chunk_filepath)
item_data_or_generator = self.prepare_item(self.items[index]) if self.prepare_item else self.items[index] # type: ignore
if isinstance(item_data_or_generator, types.GeneratorType):
for item_data in item_data_or_generator:
chunk_filepath = self.cache._add_item(self._index_counter, item_data)
self._try_upload(chunk_filepath)
self._index_counter += 1
else:
chunk_filepath = self.cache._add_item(index + self.start_index, item_data_or_generator)
self._try_upload(chunk_filepath)

self._counter += 1

if self.progress_queue:
self.progress_queue.put((self.worker_index, self._counter))
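The new branch lets prepare_item fan one input out into several cached items: when it returns a generator, every yielded element is written under the worker-local _index_counter instead of the input's position. A sketch of a generator-style prepare_item under invented assumptions (the block size and the character-level "tokenizer" stand-in are not from this diff):

from typing import Generator, List

def prepare_item(filepath: str) -> Generator[List[int], None, None]:
    # Sketch: split one text file into fixed-size blocks, each yielded
    # element becoming its own cached item via the worker's _index_counter.
    block_size = 2048  # assumed block size
    with open(filepath) as f:
        tokens = [ord(c) for c in f.read()]  # stand-in for a real tokenizer
    for begin in range(0, len(tokens), block_size):
        yield tokens[begin : begin + block_size]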

@@ -623,7 +628,7 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None:
begins[worker_idx],
self.name,
_get_node_rank(),
self.prepare_item,
self,
self.src_dir,
self.remote_src_dir,
self.remote_dst_dir,
Expand All @@ -632,7 +637,9 @@ def _create_thread_workers(self, begins: List[int], workers_user_items: List[Lis
self.error_queue,
self.num_downloaders,
self.delete_cached_files,
2 if self.fast_dev_run else self.chunk_size, # In dev run, create chunks with 2 items
(self.chunk_size if self.chunk_size else 2)
if self.fast_dev_run
else self.chunk_size,  # In dev run, default to 2-item chunks unless chunk_size is set
None if self.fast_dev_run else self.chunk_bytes,
self.compression,
)
@@ -657,7 +664,7 @@ def _create_process_workers(self, begins: List[int], workers_user_items: List[List[Any]]) -> None:
begins[worker_idx],
self.name,
_get_node_rank(),
self.prepare_item,
self,
self.src_dir,
self.remote_src_dir,
self.remote_dst_dir,
Expand All @@ -666,7 +673,9 @@ def _create_process_workers(self, begins: List[int], workers_user_items: List[Li
self.error_queue,
self.num_downloaders,
self.delete_cached_files,
2 if self.fast_dev_run else self.chunk_size, # In dev run, create chunks with 2 items
(self.chunk_size if self.chunk_size else 2)
if self.fast_dev_run
else self.chunk_size,  # In dev run, default to 2-item chunks unless chunk_size is set
None if self.fast_dev_run else self.chunk_bytes,
self.compression,
)
