diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py
index acc50d2b6..17a2d5342 100644
--- a/src/litdata/streaming/item_loader.py
+++ b/src/litdata/streaming/item_loader.py
@@ -360,6 +360,14 @@ def delete(self, chunk_index: int, chunk_filepath: str) -> None:
                 del self._mmaps[chunk_index]
             os.remove(chunk_filepath)
 
+    def close(self, chunk_index: int) -> None:
+        """Release the memory-mapped file for a specific chunk index."""
+        if chunk_index in self._mmaps:
+            self._mmaps[chunk_index]._mmap.close()
+            del self._mmaps[chunk_index]
+        if chunk_index in self._buffers:
+            del self._buffers[chunk_index]
+
     @classmethod
     def encode_data(cls, data: List[bytes], _: List[int], flattened: List[Any]) -> Tuple[bytes, Optional[int]]:
         return data[0], flattened[0].shape[0]
diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py
index 60cd29658..8cb28ef74 100644
--- a/src/litdata/streaming/reader.py
+++ b/src/litdata/streaming/reader.py
@@ -20,7 +20,7 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
 from litdata.streaming.config import ChunksConfig, Interval
-from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader
+from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader
 from litdata.streaming.sampler import ChunkedIndex
 from litdata.streaming.serializers import Serializer, _get_serializers
 from litdata.utilities.encryption import Encryption
@@ -288,7 +288,6 @@ def read(self, index: ChunkedIndex) -> Any:
             item = self._item_loader.load_item_from_chunk(
                 index.index, index.chunk_index, chunk_filepath, begin, chunk_bytes
             )
-
         # We need to request deletion after the latest element has been loaded.
         # Otherwise, this could trigger segmentation fault error depending on the item loader used.
         if (
@@ -302,6 +301,11 @@ def read(self, index: ChunkedIndex) -> Any:
             # inform the chunk has been completely consumed
             self._prepare_thread.delete([self._last_chunk_index])
 
+        if index.chunk_index != self._last_chunk_index:
+            # Close the memory-mapped file for the last chunk index
+            if isinstance(self._item_loader, TokensLoader) and self._last_chunk_index is not None:
+                self._item_loader.close(self._last_chunk_index)
+
             # track the new chunk index as the latest one
             self._last_chunk_index = index.chunk_index
 
diff --git a/tests/streaming/test_dataset.py b/tests/streaming/test_dataset.py
index ef93021c9..3b0010996 100644
--- a/tests/streaming/test_dataset.py
+++ b/tests/streaming/test_dataset.py
@@ -507,6 +507,28 @@ def test_dataset_for_text_tokens(tmpdir):
         break
 
 
+@pytest.mark.skipif(sys.platform == "win32", reason="windows isn't supported")
+def test_dataset_for_text_tokens_with_large_num_chunks(tmpdir):
+    import resource
+
+    resource.setrlimit(resource.RLIMIT_NOFILE, (1024, 1024))
+
+    block_size = 1024
+    cache = Cache(input_dir=str(tmpdir), chunk_bytes="10KB", item_loader=TokensLoader(block_size))
+
+    for i in range(10000):
+        text_ids = torch.randint(0, 10001, (torch.randint(100, 1001, (1,)).item(),)).numpy()
+        cache._add_item(i, text_ids)
+
+    cache.done()
+    cache.merge()
+
+    dataset = StreamingDataset(input_dir=str(tmpdir), item_loader=TokensLoader(block_size), shuffle=True)
+
+    for _ in dataset:
+        pass
+
+
 def test_dataset_with_1d_array(tmpdir):
     seed_everything(42)