From 2e3e1c20d5e3163363a4a3d9d5b545de227e1f94 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 26 Sep 2023 16:11:25 +0000 Subject: [PATCH 01/84] update --- src/lightning/data/builder/__init__.py | 0 src/lightning/data/builder/base.py | 108 ++++++++++++++++++++++ src/lightning/data/builder/cache.py | 79 ++++++++++++++++ src/lightning/data/builder/reader.py | 31 +++++++ src/lightning/data/builder/serializers.py | 50 ++++++++++ src/lightning/data/builder/writer.py | 106 +++++++++++++++++++++ src/lightning/data/datasets/iterable.py | 1 + 7 files changed, 375 insertions(+) create mode 100644 src/lightning/data/builder/__init__.py create mode 100644 src/lightning/data/builder/base.py create mode 100644 src/lightning/data/builder/cache.py create mode 100644 src/lightning/data/builder/reader.py create mode 100644 src/lightning/data/builder/serializers.py create mode 100644 src/lightning/data/builder/writer.py diff --git a/src/lightning/data/builder/__init__.py b/src/lightning/data/builder/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/src/lightning/data/builder/base.py b/src/lightning/data/builder/base.py new file mode 100644 index 0000000000000..de6a32615f867 --- /dev/null +++ b/src/lightning/data/builder/base.py @@ -0,0 +1,108 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from abc import ABC, abstractmethod +from typing import Optional, Dict, Any +import json + + +class BaseWriter(ABC): + + def __init__( + self, + out_dir: str, + chunk_size: int = 1 << 26, + compression: Optional[str] = None, + name: Optional[str] = None, + ): + self._out_dir = out_dir + + if not os.path.exists(self._out_dir): + raise Exception(f"The provided output directory {self._out_dir} doesn't exists.") + + self._chunk_size = chunk_size + self._compression = compression + self._name = name + + self._current_chunk_size = 0 + self._counter = 0 + self._serialized_items = [] + self._serializers: List[Serializer] = [] + self._chunks = [] + + @property + def is_cached(self) -> bool: + return os.path.exists(os.path.join(self._out_dir, "index.json")) + + def get_config(self) -> Dict[str, Any]: + return { + 'compression': self._compression, + 'chunk_size': self._chunk_size + } + + @property + def available_serializers(self): + return self._serializers + + @abstractmethod + def serialize(self, data: any) -> bytes: + """Convert a given data type into its bytes format""" + + @abstractmethod + def write_chunk(self, rank: int) -> None: + """Write the current chunk to the filesystem""" + + def reset(self) -> None: + """Reset the writer to handle the next chunk""" + self._serialized_items = [] + self._current_chunk_size = 0 + + def write(self, items: any, rank): + serialized_items = self.serialize(items) + serialized_items_size = len(serialized_items) + + if self._chunk_size < self._current_chunk_size + serialized_items_size: + self.write_chunk(rank) + self.reset() + self._counter += 1 + + self._serialized_items.append(serialized_items) + self._current_chunk_size += serialized_items_size + + def write_file(self, raw_data: bytes, filename: str,) -> None: + filepath = os.path.join(self._out_dir, filename) + with open(filepath, 'wb') as out: + out.write(raw_data) + + def write_chunks_index(self, rank: int): + filepath = os.path.join(self._out_dir, f"{rank}.index.json") + with open(filepath, 'w') as out: + json.dump({'chunks': self._chunks}, out, sort_keys=True) + + def done(self, rank: int): + if self._serialized_items: + self.write_chunk(rank) + self.write_chunks_index(rank) + self.reset() + + +class Serializer(ABC): + + @abstractmethod + def serialize(self, data: any) -> bytes: + pass + + @abstractmethod + def deserialize(self, data: bytes) -> any: + pass \ No newline at end of file diff --git a/src/lightning/data/builder/cache.py b/src/lightning/data/builder/cache.py new file mode 100644 index 0000000000000..f614c6e16f5f2 --- /dev/null +++ b/src/lightning/data/builder/cache.py @@ -0,0 +1,79 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
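Reviewer note on the BaseWriter/Serializer contract above: write() buffers serialized samples until the next sample would push the buffer past chunk_size, write_chunk() flushes the buffer to a per-rank file, and done() writes the remaining samples plus a {rank}.index.json describing the chunks (also note the `List[Serializer]` annotation is used without importing List). A minimal Serializer sketch that satisfies the two abstract methods; the class name and the UTF-8 text format are illustrative, not part of this patch:

class StringSerializer(Serializer):
    # Round-trips plain Python strings as UTF-8 bytes.
    def serialize(self, data: str) -> bytes:
        return data.encode("utf-8")

    def deserialize(self, data: bytes) -> str:
        return data.decode("utf-8")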
+ +import os +from typing import Union, Dict, Optional +from lightning.data.builder.writer import Writer +from lightning.data.builder.reader import Reader +from torch.utils.data import get_worker_info +import torch +from torch.distributed import is_initialized, is_available, get_world_size +from lightning.data.datasets.env import _WorkerEnv, _DistributedEnv +from torch.utils.data.dataloader import DataLoader, _SingleProcessDataLoaderIter, _MultiProcessingDataLoaderIter + +class Cache: + + def __init__(self, cache_dir: str, data_format: Union[Dict[str, any], str], compression: Optional[str] = None, chunk_size: int = 2 << 26): + super().__init__() + self._writer = Writer(cache_dir, data_format, chunk_size) + self._reader = Reader(cache_dir) + self._cache_dir = cache_dir + + self._env = _DistributedEnv.detect() + self._worker_env = None + self._rank = None + + @property + def rank(self): + if self._rank is None: + self._worker_env = _WorkerEnv.detect() + self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank + return self._rank + + @property + def filled(self) -> bool: + files = os.listdir(self._cache_dir) + return any(f.endswith("index.json") for f in files) + + def __setitem__(self, index, data): + self._writer.write(data, self.rank) + + def __getitem__(self, index): + self._reader.read(index, self.rank) + + def done(self): + self._writer.done(self.rank) + + +class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter): + + def _next_data(self): + try: + return super()._next_data() + except StopIteration: + for v in self._dataset_fetcher.dataset.__dict__.values(): + if isinstance(v, Cache): + v.done() + raise StopIteration() + +class CacheDataLoader(DataLoader): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _get_iterator(self) -> '_BaseDataLoaderIter': + if self.num_workers == 0: + return _SingleProcessDataLoaderIterPatch(self) + else: + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIter(self) diff --git a/src/lightning/data/builder/reader.py b/src/lightning/data/builder/reader.py new file mode 100644 index 0000000000000..a8121475db6d0 --- /dev/null +++ b/src/lightning/data/builder/reader.py @@ -0,0 +1,31 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional +from lightning_utilities.core.imports import RequirementCache +from lightning.data.builder.serializers import _SERIALIZERS +from lightning.data.builder.base import BaseWriter +import numpy as np +import json +import os + +class Reader: + + def __init__(self, out_dir: str): + super().__init__() + + self.out_dir = out_dir + + def read(self, index: int, rank): + pass + diff --git a/src/lightning/data/builder/serializers.py b/src/lightning/data/builder/serializers.py new file mode 100644 index 0000000000000..539f38aea8600 --- /dev/null +++ b/src/lightning/data/builder/serializers.py @@ -0,0 +1,50 @@ +# Copyright The Lightning AI team. 
+# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict +from lightning_utilities.core.imports import RequirementCache +from lightning.data.builder.base import Serializer +from enum import Enum +import numpy as np + + +class PILSerializer(Serializer): + + def serialize(self, item: any) -> bytes: + mode = item.mode.encode('utf-8') + width, height = item.size + raw = item.tobytes() + ints = np.array([width, height, len(mode)], np.uint32) + return ints.tobytes() + mode + raw + + def deserialize(self, data: bytes) -> any: + idx = 3 * 4 + width, height, mode_size = np.frombuffer(data[:idx], np.uint32) + idx2 = idx + mode_size + mode = data[idx:idx2].decode('utf-8') + size = width, height + raw = data[idx2:] + return Image.frombytes(mode, size, raw) # pyright: ignore + + +class IntSerializer(Serializer): + def serialize(self, item: int) -> bytes: + return str(item).encode('utf-8') + + def deserialize(self, data: bytes) -> int: + return int(data.decode('utf-8')) + +_SERIALIZERS = { + "pil": PILSerializer(), + "int": IntSerializer(), +} \ No newline at end of file diff --git a/src/lightning/data/builder/writer.py b/src/lightning/data/builder/writer.py new file mode 100644 index 0000000000000..16766ec5d2afe --- /dev/null +++ b/src/lightning/data/builder/writer.py @@ -0,0 +1,106 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
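Reviewer note on PILSerializer above: serialize() packs a 12-byte header of three uint32 values (width, height, mode length), then the mode string, then the raw pixel buffer, and deserialize() unpacks them in the same order; as written, deserialize() references Image even though this file has no PIL import yet (the import only appears in a later commit of this series). A round-trip sketch, assuming Pillow and numpy are installed:

from PIL import Image

serializer = PILSerializer()
image = Image.new("RGB", (4, 4))
payload = serializer.serialize(image)       # 12-byte header + b"RGB" + 48 pixel bytes
restored = serializer.deserialize(payload)
assert restored.mode == "RGB" and restored.size == (4, 4)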
+ +from typing import Any, Dict, Optional +from lightning_utilities.core.imports import RequirementCache +from lightning.data.builder.serializers import _SERIALIZERS +from lightning.data.builder.base import BaseWriter +import numpy as np +import json +import os + +_PIL_AVAILABLE = RequirementCache("PIL") + +if _PIL_AVAILABLE: + from PIL import Image +else: + Image = Any + +class Writer(BaseWriter): + + def __init__( + self, + out_dir: str, + dict_format: Dict[str, str], + chunk_size: int = 1 << 26, + compression: Optional[str] = None, + name: Optional[str] = None, + ): + super().__init__(out_dir, chunk_size, compression, name) + + if not _PIL_AVAILABLE: + raise Exception("The ImageWriter requires pil to be installed") + + self._dict_format = {k.lower(): v for k, v in dict_format.items()} + self._dict_format_keys = sorted(self._dict_format.keys()) + self._serializers = _SERIALIZERS + + available_serializers = set(self._serializers.keys()) + selected_serializers = set(self._dict_format.values()) + if selected_serializers.difference(available_serializers): + raise Exception(f"The provided dict_format don't match the provided serializers. Should be selected from {available_serializers}.") + + obj = self.get_config() + text = json.dumps(obj, sort_keys=True) + self._config_data = text.encode('utf-8') + + def get_config(self) -> Dict[str, Any]: + out = super().get_config() + out.update(self._dict_format) + return out + + def serialize(self, items: Dict[str, Any]) -> bytes: + if not isinstance(items, dict): + raise Exception("The provided data should be a dictionary.") + + keys = sorted(items.keys()) + + if keys != self._dict_format_keys: + raise Exception(f"The provided keys don't match the provided format. Found {keys} instead of {self._dict_format_keys}.") + + sizes = [] + data = [] + + for key in self._dict_format_keys: + serializer_name = self._dict_format[key] + serializer = self._serializers[serializer_name] + serialized_data = serializer.serialize(items[key]) + + sizes.append(len(serialized_data)) + data.append(serialized_data) + + head = np.array(sizes, np.uint32).tobytes() + body = b''.join(data) + return head + body + + def _create_chunk(self, filename: str) -> bytes: + num_items = np.uint32(len(self._serialized_items)) + sizes = list(map(len, self._serialized_items)) + offsets = np.array([0] + sizes).cumsum().astype(np.uint32) + offsets += len(num_items.tobytes()) + len(offsets.tobytes()) + len(self._config_data) + sample_data = b''.join(self._serialized_items) + + self._chunks.append({ + 'samples': len(self._serialized_items), + "config": self.get_config(), + "filename": filename, + }) + + return num_items.tobytes() + offsets.tobytes() + self._config_data + sample_data + + def write_chunk(self, rank: int): + filename = f"chunk-{rank}-{self._counter}.bin" + self.write_file(self._create_chunk(filename), filename) + + def reset(self): + pass \ No newline at end of file diff --git a/src/lightning/data/datasets/iterable.py b/src/lightning/data/datasets/iterable.py index 54388138bcf79..d6d0d147f075e 100644 --- a/src/lightning/data/datasets/iterable.py +++ b/src/lightning/data/datasets/iterable.py @@ -165,6 +165,7 @@ def load_chunk(self, chunk: Any) -> Any: chunk: The chunk that should be currently loaded """ + @abstractmethod def load_sample_from_chunk(self, chunk: Any, index: int) -> Any: From 90bcd89dd96cea686c1156a1f6fe0f75bb1e1628 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 26 Sep 2023 17:48:47 +0000 Subject: [PATCH 02/84] update --- src/lightning/data/builder/cache.py | 103 
+++++++++++++++++++++++++-- src/lightning/data/builder/reader.py | 70 +++++++++++++++++- 2 files changed, 166 insertions(+), 7 deletions(-) diff --git a/src/lightning/data/builder/cache.py b/src/lightning/data/builder/cache.py index f614c6e16f5f2..5903eac927e69 100644 --- a/src/lightning/data/builder/cache.py +++ b/src/lightning/data/builder/cache.py @@ -12,14 +12,18 @@ # limitations under the License. import os -from typing import Union, Dict, Optional +import numpy as np +from typing import Union, Dict, Optional, Iterable, Iterator from lightning.data.builder.writer import Writer from lightning.data.builder.reader import Reader -from torch.utils.data import get_worker_info +from torch.utils.data import get_worker_info, IterableDataset import torch from torch.distributed import is_initialized, is_available, get_world_size from lightning.data.datasets.env import _WorkerEnv, _DistributedEnv from torch.utils.data.dataloader import DataLoader, _SingleProcessDataLoaderIter, _MultiProcessingDataLoaderIter +from torch.utils.data.sampler import BatchSampler, RandomSampler, SequentialSampler, Sampler, Sized +from torch.utils.data.distributed import DistributedSampler + class Cache: @@ -54,6 +58,12 @@ def __getitem__(self, index): def done(self): self._writer.done(self.rank) + def __len__(self): + return self._reader.get_length() + + def get_chunk_interval(self): + return self._reader.get_chunk_interval() + class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter): @@ -66,10 +76,95 @@ def _next_data(self): v.done() raise StopIteration() + +class CacheSampler(Sampler): + + def __init__(self, dataset, generator, shuffle): + super().__init__(dataset) + + if shuffle: + self._sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type] + else: + self._sampler = SequentialSampler(dataset) # type: ignore[arg-type] + + def __iter__(self): + return iter(self._sampler) + + def __len__(self) -> int: + return len(self._sampler) + + +class IteratorSampler(Sampler[int]): + r"""Samples elements sequentially, always in the same order. 
+ + Args: + data_source (Dataset): dataset to sample from + """ + data_source: Sized + + def __init__(self, data_source: Sized) -> None: + self.data_source = data_source + + def __iter__(self) -> Iterator[int]: + return iter(self.data_source) + + def __len__(self) -> int: + return len(self.data_source) + + +class CacheBatchSampler(BatchSampler): + + def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool, shuffle: bool, cache: Cache): + super().__init__(sampler, batch_size, drop_last) + self._cache = cache + self._shuffle = shuffle + + def __iter__(self): + if self._cache.filled and self._shuffle: + return self.__iter__cache__() + return super().__iter__() + + def __iter__cache__(self): + chunk_intervals = self._cache.get_chunk_interval()[:-1] + shuffled_chunk_intervals = np.random.permutation(chunk_intervals) + + dataset = [] + for interval in shuffled_chunk_intervals: + interval_indices = np.arange(interval[0], interval[1]) + shuffled_interval_indices = np.random.permutation(interval_indices) + dataset.extend(shuffled_interval_indices.tolist()) + + if len(dataset) != len(self.sampler): + raise Exception("The generated indices don't match the initial length of the sampler.") + + self.sampler = IteratorSampler(dataset) + + return super().__iter__() + + def __len__(self) -> int: + return super().__len__() + + class CacheDataLoader(DataLoader): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) + def __init__(self, dataset, *args, sampler=None, batch_sampler=None, shuffle: bool = False, generator=None, batch_size=None, drop_last=False, **kwargs): + if sampler: + raise Exception("Passing a sampler isn't supoprt with the CacheDataLoader yet.") + + if batch_sampler: + raise Exception("Passing a batch_sampler isn't supoprt with the CacheDataLoader yet.") + + if isinstance(dataset, IterableDataset): + raise Exception("Only map-based dataset are supported by the CacheDataLoader for now.") + + cache = [v for v in dataset.__dict__.values() if isinstance(v, Cache)] + + if not cache or len(cache) > 1: + raise Exception("The CacheDataloader should be used with a dataset using a single cache. Found {cache}.") + + cache = cache[0] + batch_sampler = CacheBatchSampler(CacheSampler(dataset, generator, shuffle), batch_size, drop_last, shuffle, cache) + super().__init__(dataset, *args, sampler=None, batch_sampler=batch_sampler, generator=generator, **kwargs) def _get_iterator(self) -> '_BaseDataLoaderIter': if self.num_workers == 0: diff --git a/src/lightning/data/builder/reader.py b/src/lightning/data/builder/reader.py index a8121475db6d0..b4744f07ac8bd 100644 --- a/src/lightning/data/builder/reader.py +++ b/src/lightning/data/builder/reader.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List from lightning_utilities.core.imports import RequirementCache from lightning.data.builder.serializers import _SERIALIZERS from lightning.data.builder.base import BaseWriter @@ -19,13 +19,77 @@ import json import os + class Reader: - def __init__(self, out_dir: str): + def __init__(self, _cache_dir: str): super().__init__() - self.out_dir = out_dir + self._cache_dir = _cache_dir + self._index = None + self._intervals = None + self._chunks = [] + + def _try_read_index(self): + files = os.listdir(self._cache_dir) + indexes_filepath = sorted([os.path.join(self._cache_dir, f) for f in files if f.endswith("index.json")]) + if not indexes_filepath: + return + + index = {"chunks": []} + for path in indexes_filepath: + with open(path, "r") as f: + data = json.load(f) + index['chunks'].extend(data["chunks"]) + + self._index = index + + self._intervals = [] + cumsum_samples = np.cumsum([0] + [v["samples"] for v in self._index['chunks']] + [1]) + for i in range(len(cumsum_samples) - 1): + self._intervals.append([cumsum_samples[i], cumsum_samples[i + 1]]) + + print(self._intervals) + + def _map_index_to_chunk_id(self, index): + for interval_index, internal in enumerate(self._intervals): + print(internal, index) + if internal[0] <= index and index < internal[1]: + return interval_index + return None def read(self, index: int, rank): + if self._index is None: + self._try_read_index() + + if self._index is None: + raise Exception("The reader index isn't defined.") + + chunk_id = self._map_index_to_chunk_id(index) + chunk_config = self._index['chunks'][chunk_id] + chunk_path = os.path.join(self._cache_dir, chunk_config['filename']) + if not os.path.exists(chunk_path): + download_chunk(chunk_path) + + return self.load_data_from_chunk(chunk_path) + + def load_data_from_chunk(self, chunk_path): pass + def get_length(self) -> int: + if self._index is None: + self._try_read_index() + + if self._index is None: + raise Exception("The reader index isn't defined.") + + return sum([v["samples"] for v in self._index['chunks']]) + + def get_chunk_interval(self): + if self._index is None: + self._try_read_index() + + if self._intervals is None: + raise Exception("The reader index isn't defined.") + + return self._intervals From 8d76988dffdf26eac5576ffc199ad67a52e130dd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Sep 2023 17:51:19 +0000 Subject: [PATCH 03/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/builder/base.py | 35 +++++++-------- src/lightning/data/builder/cache.py | 54 +++++++++++++++-------- src/lightning/data/builder/reader.py | 24 +++++----- src/lightning/data/builder/serializers.py | 18 ++++---- src/lightning/data/builder/writer.py | 45 +++++++++++-------- src/lightning/data/datasets/iterable.py | 1 - 6 files changed, 96 insertions(+), 81 deletions(-) diff --git a/src/lightning/data/builder/base.py b/src/lightning/data/builder/base.py index de6a32615f867..b4da32f912888 100644 --- a/src/lightning/data/builder/base.py +++ b/src/lightning/data/builder/base.py @@ -11,14 +11,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
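Reviewer note on the Reader changes above: _try_read_index merges every *.index.json it finds, then turns the per-chunk sample counts into half-open [start, end) intervals via a cumulative sum (the trailing `+ [1]` adds an extra interval that the batch sampler later drops with `[:-1]`, and the print calls look like debug leftovers), and _map_index_to_chunk_id linearly scans those intervals. An equivalent standalone sketch of that mapping, with illustrative sample counts:

import numpy as np

samples_per_chunk = [100, 100, 50]                # hypothetical chunk sizes
starts = np.cumsum([0] + samples_per_chunk)       # [0, 100, 200, 250]
intervals = list(zip(starts[:-1], starts[1:]))    # [(0, 100), (100, 200), (200, 250)]

def chunk_id_for(index):
    # Mirrors _map_index_to_chunk_id: the first interval containing the index wins.
    for chunk_id, (begin, end) in enumerate(intervals):
        if begin <= index < end:
            return chunk_id
    return None

assert chunk_id_for(0) == 0 and chunk_id_for(150) == 1 and chunk_id_for(249) == 2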
+import json import os from abc import ABC, abstractmethod -from typing import Optional, Dict, Any -import json +from typing import Any, Dict, Optional class BaseWriter(ABC): - def __init__( self, out_dir: str, @@ -46,10 +45,7 @@ def is_cached(self) -> bool: return os.path.exists(os.path.join(self._out_dir, "index.json")) def get_config(self) -> Dict[str, Any]: - return { - 'compression': self._compression, - 'chunk_size': self._chunk_size - } + return {"compression": self._compression, "chunk_size": self._chunk_size} @property def available_serializers(self): @@ -57,14 +53,14 @@ def available_serializers(self): @abstractmethod def serialize(self, data: any) -> bytes: - """Convert a given data type into its bytes format""" + """Convert a given data type into its bytes format.""" @abstractmethod def write_chunk(self, rank: int) -> None: - """Write the current chunk to the filesystem""" + """Write the current chunk to the filesystem.""" def reset(self) -> None: - """Reset the writer to handle the next chunk""" + """Reset the writer to handle the next chunk.""" self._serialized_items = [] self._current_chunk_size = 0 @@ -76,33 +72,36 @@ def write(self, items: any, rank): self.write_chunk(rank) self.reset() self._counter += 1 - + self._serialized_items.append(serialized_items) self._current_chunk_size += serialized_items_size - def write_file(self, raw_data: bytes, filename: str,) -> None: + def write_file( + self, + raw_data: bytes, + filename: str, + ) -> None: filepath = os.path.join(self._out_dir, filename) - with open(filepath, 'wb') as out: + with open(filepath, "wb") as out: out.write(raw_data) def write_chunks_index(self, rank: int): filepath = os.path.join(self._out_dir, f"{rank}.index.json") - with open(filepath, 'w') as out: - json.dump({'chunks': self._chunks}, out, sort_keys=True) + with open(filepath, "w") as out: + json.dump({"chunks": self._chunks}, out, sort_keys=True) def done(self, rank: int): - if self._serialized_items: + if self._serialized_items: self.write_chunk(rank) self.write_chunks_index(rank) self.reset() class Serializer(ABC): - @abstractmethod def serialize(self, data: any) -> bytes: pass @abstractmethod def deserialize(self, data: bytes) -> any: - pass \ No newline at end of file + pass diff --git a/src/lightning/data/builder/cache.py b/src/lightning/data/builder/cache.py index 5903eac927e69..a7e31887d1b0a 100644 --- a/src/lightning/data/builder/cache.py +++ b/src/lightning/data/builder/cache.py @@ -12,22 +12,26 @@ # limitations under the License. 
import os +from typing import Dict, Iterable, Iterator, Optional, Union + import numpy as np -from typing import Union, Dict, Optional, Iterable, Iterator -from lightning.data.builder.writer import Writer +from torch.utils.data import IterableDataset +from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter +from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized + from lightning.data.builder.reader import Reader -from torch.utils.data import get_worker_info, IterableDataset -import torch -from torch.distributed import is_initialized, is_available, get_world_size -from lightning.data.datasets.env import _WorkerEnv, _DistributedEnv -from torch.utils.data.dataloader import DataLoader, _SingleProcessDataLoaderIter, _MultiProcessingDataLoaderIter -from torch.utils.data.sampler import BatchSampler, RandomSampler, SequentialSampler, Sampler, Sized -from torch.utils.data.distributed import DistributedSampler +from lightning.data.builder.writer import Writer +from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv class Cache: - - def __init__(self, cache_dir: str, data_format: Union[Dict[str, any], str], compression: Optional[str] = None, chunk_size: int = 2 << 26): + def __init__( + self, + cache_dir: str, + data_format: Union[Dict[str, any], str], + compression: Optional[str] = None, + chunk_size: int = 2 << 26, + ): super().__init__() self._writer = Writer(cache_dir, data_format, chunk_size) self._reader = Reader(cache_dir) @@ -66,7 +70,6 @@ def get_chunk_interval(self): class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter): - def _next_data(self): try: return super()._next_data() @@ -78,7 +81,6 @@ def _next_data(self): class CacheSampler(Sampler): - def __init__(self, dataset, generator, shuffle): super().__init__(dataset) @@ -99,6 +101,7 @@ class IteratorSampler(Sampler[int]): Args: data_source (Dataset): dataset to sample from + """ data_source: Sized @@ -113,8 +116,9 @@ def __len__(self) -> int: class CacheBatchSampler(BatchSampler): - - def __init__(self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool, shuffle: bool, cache: Cache): + def __init__( + self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool, shuffle: bool, cache: Cache + ): super().__init__(sampler, batch_size, drop_last) self._cache = cache self._shuffle = shuffle @@ -146,8 +150,18 @@ def __len__(self) -> int: class CacheDataLoader(DataLoader): - - def __init__(self, dataset, *args, sampler=None, batch_sampler=None, shuffle: bool = False, generator=None, batch_size=None, drop_last=False, **kwargs): + def __init__( + self, + dataset, + *args, + sampler=None, + batch_sampler=None, + shuffle: bool = False, + generator=None, + batch_size=None, + drop_last=False, + **kwargs + ): if sampler: raise Exception("Passing a sampler isn't supoprt with the CacheDataLoader yet.") @@ -163,10 +177,12 @@ def __init__(self, dataset, *args, sampler=None, batch_sampler=None, shuffle: bo raise Exception("The CacheDataloader should be used with a dataset using a single cache. 
Found {cache}.") cache = cache[0] - batch_sampler = CacheBatchSampler(CacheSampler(dataset, generator, shuffle), batch_size, drop_last, shuffle, cache) + batch_sampler = CacheBatchSampler( + CacheSampler(dataset, generator, shuffle), batch_size, drop_last, shuffle, cache + ) super().__init__(dataset, *args, sampler=None, batch_sampler=batch_sampler, generator=generator, **kwargs) - def _get_iterator(self) -> '_BaseDataLoaderIter': + def _get_iterator(self) -> "_BaseDataLoaderIter": if self.num_workers == 0: return _SingleProcessDataLoaderIterPatch(self) else: diff --git a/src/lightning/data/builder/reader.py b/src/lightning/data/builder/reader.py index b4744f07ac8bd..fc3ed5e7cfe26 100644 --- a/src/lightning/data/builder/reader.py +++ b/src/lightning/data/builder/reader.py @@ -11,17 +11,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, List -from lightning_utilities.core.imports import RequirementCache -from lightning.data.builder.serializers import _SERIALIZERS -from lightning.data.builder.base import BaseWriter -import numpy as np import json import os +import numpy as np + class Reader: - def __init__(self, _cache_dir: str): super().__init__() @@ -38,14 +34,14 @@ def _try_read_index(self): index = {"chunks": []} for path in indexes_filepath: - with open(path, "r") as f: + with open(path) as f: data = json.load(f) - index['chunks'].extend(data["chunks"]) + index["chunks"].extend(data["chunks"]) self._index = index self._intervals = [] - cumsum_samples = np.cumsum([0] + [v["samples"] for v in self._index['chunks']] + [1]) + cumsum_samples = np.cumsum([0] + [v["samples"] for v in self._index["chunks"]] + [1]) for i in range(len(cumsum_samples) - 1): self._intervals.append([cumsum_samples[i], cumsum_samples[i + 1]]) @@ -66,24 +62,24 @@ def read(self, index: int, rank): raise Exception("The reader index isn't defined.") chunk_id = self._map_index_to_chunk_id(index) - chunk_config = self._index['chunks'][chunk_id] - chunk_path = os.path.join(self._cache_dir, chunk_config['filename']) + chunk_config = self._index["chunks"][chunk_id] + chunk_path = os.path.join(self._cache_dir, chunk_config["filename"]) if not os.path.exists(chunk_path): download_chunk(chunk_path) return self.load_data_from_chunk(chunk_path) - + def load_data_from_chunk(self, chunk_path): pass def get_length(self) -> int: if self._index is None: self._try_read_index() - + if self._index is None: raise Exception("The reader index isn't defined.") - return sum([v["samples"] for v in self._index['chunks']]) + return sum([v["samples"] for v in self._index["chunks"]]) def get_chunk_interval(self): if self._index is None: diff --git a/src/lightning/data/builder/serializers.py b/src/lightning/data/builder/serializers.py index 539f38aea8600..91efd4a0f1918 100644 --- a/src/lightning/data/builder/serializers.py +++ b/src/lightning/data/builder/serializers.py @@ -11,17 +11,14 @@ # See the License for the specific language governing permissions and # limitations under the License. 
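Reviewer note on CacheBatchSampler above: once the cache is filled and shuffling is requested, it permutes the chunk intervals and then the indices inside each chunk, so consecutive samples mostly come from the same chunk file rather than from random chunks. A standalone sketch of that two-level shuffle, with illustrative intervals:

import numpy as np

chunk_intervals = [(0, 100), (100, 200), (200, 250)]   # hypothetical [start, end) pairs
order = []
for begin, end in np.random.permutation(chunk_intervals):
    # Shuffle within the chunk while keeping the chunk's samples contiguous.
    order.extend(np.random.permutation(np.arange(begin, end)).tolist())
assert sorted(order) == list(range(250))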
-from typing import Any, Dict -from lightning_utilities.core.imports import RequirementCache -from lightning.data.builder.base import Serializer -from enum import Enum import numpy as np +from lightning.data.builder.base import Serializer -class PILSerializer(Serializer): +class PILSerializer(Serializer): def serialize(self, item: any) -> bytes: - mode = item.mode.encode('utf-8') + mode = item.mode.encode("utf-8") width, height = item.size raw = item.tobytes() ints = np.array([width, height, len(mode)], np.uint32) @@ -31,7 +28,7 @@ def deserialize(self, data: bytes) -> any: idx = 3 * 4 width, height, mode_size = np.frombuffer(data[:idx], np.uint32) idx2 = idx + mode_size - mode = data[idx:idx2].decode('utf-8') + mode = data[idx:idx2].decode("utf-8") size = width, height raw = data[idx2:] return Image.frombytes(mode, size, raw) # pyright: ignore @@ -39,12 +36,13 @@ def deserialize(self, data: bytes) -> any: class IntSerializer(Serializer): def serialize(self, item: int) -> bytes: - return str(item).encode('utf-8') + return str(item).encode("utf-8") def deserialize(self, data: bytes) -> int: - return int(data.decode('utf-8')) + return int(data.decode("utf-8")) + _SERIALIZERS = { "pil": PILSerializer(), "int": IntSerializer(), -} \ No newline at end of file +} diff --git a/src/lightning/data/builder/writer.py b/src/lightning/data/builder/writer.py index 16766ec5d2afe..70ac1ff7fca63 100644 --- a/src/lightning/data/builder/writer.py +++ b/src/lightning/data/builder/writer.py @@ -11,13 +11,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from typing import Any, Dict, Optional + +import numpy as np from lightning_utilities.core.imports import RequirementCache -from lightning.data.builder.serializers import _SERIALIZERS + from lightning.data.builder.base import BaseWriter -import numpy as np -import json -import os +from lightning.data.builder.serializers import _SERIALIZERS _PIL_AVAILABLE = RequirementCache("PIL") @@ -26,8 +27,8 @@ else: Image = Any + class Writer(BaseWriter): - def __init__( self, out_dir: str, @@ -48,11 +49,13 @@ def __init__( available_serializers = set(self._serializers.keys()) selected_serializers = set(self._dict_format.values()) if selected_serializers.difference(available_serializers): - raise Exception(f"The provided dict_format don't match the provided serializers. Should be selected from {available_serializers}.") - + raise Exception( + f"The provided dict_format don't match the provided serializers. Should be selected from {available_serializers}." + ) + obj = self.get_config() text = json.dumps(obj, sort_keys=True) - self._config_data = text.encode('utf-8') + self._config_data = text.encode("utf-8") def get_config(self) -> Dict[str, Any]: out = super().get_config() @@ -66,7 +69,9 @@ def serialize(self, items: Dict[str, Any]) -> bytes: keys = sorted(items.keys()) if keys != self._dict_format_keys: - raise Exception(f"The provided keys don't match the provided format. Found {keys} instead of {self._dict_format_keys}.") + raise Exception( + f"The provided keys don't match the provided format. Found {keys} instead of {self._dict_format_keys}." 
+ ) sizes = [] data = [] @@ -80,27 +85,29 @@ def serialize(self, items: Dict[str, Any]) -> bytes: data.append(serialized_data) head = np.array(sizes, np.uint32).tobytes() - body = b''.join(data) + body = b"".join(data) return head + body - + def _create_chunk(self, filename: str) -> bytes: num_items = np.uint32(len(self._serialized_items)) sizes = list(map(len, self._serialized_items)) offsets = np.array([0] + sizes).cumsum().astype(np.uint32) offsets += len(num_items.tobytes()) + len(offsets.tobytes()) + len(self._config_data) - sample_data = b''.join(self._serialized_items) + sample_data = b"".join(self._serialized_items) - self._chunks.append({ - 'samples': len(self._serialized_items), - "config": self.get_config(), - "filename": filename, - }) + self._chunks.append( + { + "samples": len(self._serialized_items), + "config": self.get_config(), + "filename": filename, + } + ) return num_items.tobytes() + offsets.tobytes() + self._config_data + sample_data def write_chunk(self, rank: int): - filename = f"chunk-{rank}-{self._counter}.bin" + filename = f"chunk-{rank}-{self._counter}.bin" self.write_file(self._create_chunk(filename), filename) def reset(self): - pass \ No newline at end of file + pass diff --git a/src/lightning/data/datasets/iterable.py b/src/lightning/data/datasets/iterable.py index d6d0d147f075e..54388138bcf79 100644 --- a/src/lightning/data/datasets/iterable.py +++ b/src/lightning/data/datasets/iterable.py @@ -165,7 +165,6 @@ def load_chunk(self, chunk: Any) -> Any: chunk: The chunk that should be currently loaded """ - @abstractmethod def load_sample_from_chunk(self, chunk: Any, index: int) -> Any: From fa8a5f304542450ac57667edef420131f41799b1 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 27 Sep 2023 12:59:24 +0000 Subject: [PATCH 04/84] update --- src/lightning/data/builder/base.py | 8 +++ src/lightning/data/builder/cache.py | 19 +++++-- src/lightning/data/builder/compression.py | 62 +++++++++++++++++++++++ src/lightning/data/builder/reader.py | 5 +- src/lightning/data/builder/serializers.py | 27 +++++++++- src/lightning/data/builder/writer.py | 13 +++-- 6 files changed, 122 insertions(+), 12 deletions(-) create mode 100644 src/lightning/data/builder/compression.py diff --git a/src/lightning/data/builder/base.py b/src/lightning/data/builder/base.py index b4da32f912888..25ea29cc98f9f 100644 --- a/src/lightning/data/builder/base.py +++ b/src/lightning/data/builder/base.py @@ -15,6 +15,7 @@ import os from abc import ABC, abstractmethod from typing import Any, Dict, Optional +from lightning.data.builder.compression import _COMPRESSORS class BaseWriter(ABC): @@ -34,6 +35,11 @@ def __init__( self._compression = compression self._name = name + if compression and compression not in _COMPRESSORS: + raise Exception(f"The provided compression {compression} isn't available in {sorted(_COMPRESSORS)}") + + self._compressor = _COMPRESSORS[compression] + self._current_chunk_size = 0 self._counter = 0 self._serialized_items = [] @@ -81,6 +87,8 @@ def write_file( raw_data: bytes, filename: str, ) -> None: + if self._compression: + raw_data = self._compressor.compress(raw_data) filepath = os.path.join(self._out_dir, filename) with open(filepath, "wb") as out: out.write(raw_data) diff --git a/src/lightning/data/builder/cache.py b/src/lightning/data/builder/cache.py index a7e31887d1b0a..4f6ab184f2fe7 100644 --- a/src/lightning/data/builder/cache.py +++ b/src/lightning/data/builder/cache.py @@ -18,7 +18,7 @@ from torch.utils.data import IterableDataset from 
torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized - +from torch.utils.data._utils.collate import default_collate from lightning.data.builder.reader import Reader from lightning.data.builder.writer import Writer from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv @@ -33,8 +33,8 @@ def __init__( chunk_size: int = 2 << 26, ): super().__init__() - self._writer = Writer(cache_dir, data_format, chunk_size) - self._reader = Reader(cache_dir) + self._writer = Writer(cache_dir, data_format, chunk_size=chunk_size, compression=compression) + self._reader = Reader(cache_dir, compression=compression) self._cache_dir = cache_dir self._env = _DistributedEnv.detect() @@ -149,6 +149,17 @@ def __len__(self) -> int: return super().__len__() +class CacheCollateFn: + + def __init__(self): + self.collate_fn = default_collate + + def __call__(self, items): + if all(item is None for item in items): + return None + return self.collate_fn(items) + + class CacheDataLoader(DataLoader): def __init__( self, @@ -180,7 +191,7 @@ def __init__( batch_sampler = CacheBatchSampler( CacheSampler(dataset, generator, shuffle), batch_size, drop_last, shuffle, cache ) - super().__init__(dataset, *args, sampler=None, batch_sampler=batch_sampler, generator=generator, **kwargs) + super().__init__(dataset, *args, sampler=None, batch_sampler=batch_sampler, generator=generator, collate_fn=CacheCollateFn(), **kwargs) def _get_iterator(self) -> "_BaseDataLoaderIter": if self.num_workers == 0: diff --git a/src/lightning/data/builder/compression.py b/src/lightning/data/builder/compression.py new file mode 100644 index 0000000000000..9f69bb1330aea --- /dev/null +++ b/src/lightning/data/builder/compression.py @@ -0,0 +1,62 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
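Reviewer note on CacheCollateFn above: it returns None for a batch whose items are all None — which is what a cache-filling pass is expected to produce, since the dataset only writes into the cache — instead of letting default_collate raise; consumers of the loader are then expected to skip those batches. A usage sketch (the dataset name is a hypothetical stand-in for a map-style dataset holding a Cache):

loader = CacheDataLoader(MyCachedDataset("/tmp/cache_dir"), batch_size=8, num_workers=0)
real_batches = [batch for batch in loader if batch is not None]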
+ +from typing import Dict, TypeVar +from abc import ABC, abstractmethod, abstractclassmethod +import zstd + +TCompressor = TypeVar("TCompressor", bound="Compressor") + +class Compressor(ABC): + + @abstractmethod + def compress(self, data: bytes) -> bytes: + pass + + @abstractmethod + def decompress(self, data: bytes) -> bytes: + pass + + @abstractclassmethod + def register(cls, compressors: Dict[str, TCompressor]): + pass + + +class ZSTDCompressor(Compressor): + + def __init__(self, level): + super().__init__() + self.level = level + self.extension = 'zstd' + + @property + def name(self): + return f"{self.extension}:{self.level}" + + def compress(self, data: bytes) -> bytes: + return zstd.compress(data, self.level) + + def decompress(self, data: bytes) -> bytes: + return zstd.decompress(data) + + @classmethod + def register(cls, compressors): + # default + compressors["zstd"] = ZSTDCompressor(4) + + for level in list(range(1, 23)): + compressors[f"zstd:{level}"] = ZSTDCompressor(level) + +_COMPRESSORS = {} + +ZSTDCompressor.register(_COMPRESSORS) diff --git a/src/lightning/data/builder/reader.py b/src/lightning/data/builder/reader.py index fc3ed5e7cfe26..eaf9b05ec3237 100644 --- a/src/lightning/data/builder/reader.py +++ b/src/lightning/data/builder/reader.py @@ -13,15 +13,16 @@ import json import os - +from typing import Optional import numpy as np class Reader: - def __init__(self, _cache_dir: str): + def __init__(self, _cache_dir: str, compression: Optional[str] = None): super().__init__() self._cache_dir = _cache_dir + self._compression = compression self._index = None self._intervals = None self._chunks = [] diff --git a/src/lightning/data/builder/serializers.py b/src/lightning/data/builder/serializers.py index 91efd4a0f1918..4e9a1c0b8befa 100644 --- a/src/lightning/data/builder/serializers.py +++ b/src/lightning/data/builder/serializers.py @@ -12,9 +12,18 @@ # limitations under the License. 
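Reviewer note on compression.py above: ZSTDCompressor.register installs a default level-4 compressor under the plain "zstd" key plus explicit "zstd:1" through "zstd:22" entries, so a writer or reader can resolve a compressor from the compression string alone. A round-trip sketch reusing the registry from this module (assumes the same zstd package used by the patch is installed):

payload = b"lightning" * 1024
compressor = _COMPRESSORS["zstd:3"]
assert compressor.decompress(compressor.compress(payload)) == payload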
import numpy as np - +from lightning_utilities.core.imports import RequirementCache from lightning.data.builder.base import Serializer +_PIL_AVAILABLE = RequirementCache("PIL") + +if _PIL_AVAILABLE: + from PIL import Image + from PIL.JpegImagePlugin import JpegImageFile +else: + Image = Any + JpegImageFile = None + class PILSerializer(Serializer): def serialize(self, item: any) -> bytes: @@ -42,7 +51,23 @@ def deserialize(self, data: bytes) -> int: return int(data.decode("utf-8")) +class JPEGSerializer(Serializer): + def serialize(self, obj: Image) -> bytes: + if isinstance(obj, JpegImageFile) and hasattr(obj, 'filename'): + with open(obj.filename, 'rb') as f: + return f.read() + else: + out = BytesIO() + obj.save(out, format='JPEG') + return out.getvalue() + + def deserialize(self, data: bytes) -> Image: + inp = BytesIO(data) + return Image.open(inp) + + _SERIALIZERS = { "pil": PILSerializer(), "int": IntSerializer(), + "jpeg": JPEGSerializer(), } diff --git a/src/lightning/data/builder/writer.py b/src/lightning/data/builder/writer.py index 70ac1ff7fca63..af8a7b91a672f 100644 --- a/src/lightning/data/builder/writer.py +++ b/src/lightning/data/builder/writer.py @@ -79,7 +79,10 @@ def serialize(self, items: Dict[str, Any]) -> bytes: for key in self._dict_format_keys: serializer_name = self._dict_format[key] serializer = self._serializers[serializer_name] - serialized_data = serializer.serialize(items[key]) + if not isinstance(items[key], bytes): + serialized_data = serializer.serialize(items[key]) + else: + serialized_data = items[key] sizes.append(len(serialized_data)) data.append(serialized_data) @@ -106,8 +109,8 @@ def _create_chunk(self, filename: str) -> bytes: return num_items.tobytes() + offsets.tobytes() + self._config_data + sample_data def write_chunk(self, rank: int): - filename = f"chunk-{rank}-{self._counter}.bin" + if self._compression: + filename = f"chunk-{rank}-{self._counter}.{self._compression}.bin" + else: + filename = f"chunk-{rank}-{self._counter}.bin" self.write_file(self._create_chunk(filename), filename) - - def reset(self): - pass From 70332a92a6c6c6653af64fdde271b57ae150a160 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 13:04:09 +0000 Subject: [PATCH 05/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/builder/base.py | 1 + src/lightning/data/builder/cache.py | 16 ++++++++++++---- src/lightning/data/builder/compression.py | 11 ++++++----- src/lightning/data/builder/reader.py | 1 + src/lightning/data/builder/serializers.py | 7 ++++--- src/lightning/data/builder/writer.py | 5 +---- 6 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/lightning/data/builder/base.py b/src/lightning/data/builder/base.py index 25ea29cc98f9f..e8e8ff5412421 100644 --- a/src/lightning/data/builder/base.py +++ b/src/lightning/data/builder/base.py @@ -15,6 +15,7 @@ import os from abc import ABC, abstractmethod from typing import Any, Dict, Optional + from lightning.data.builder.compression import _COMPRESSORS diff --git a/src/lightning/data/builder/cache.py b/src/lightning/data/builder/cache.py index 4f6ab184f2fe7..7b9711caf9793 100644 --- a/src/lightning/data/builder/cache.py +++ b/src/lightning/data/builder/cache.py @@ -16,9 +16,10 @@ import numpy as np from torch.utils.data import IterableDataset +from torch.utils.data._utils.collate import default_collate from torch.utils.data.dataloader import DataLoader, 
_MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized -from torch.utils.data._utils.collate import default_collate + from lightning.data.builder.reader import Reader from lightning.data.builder.writer import Writer from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv @@ -150,14 +151,13 @@ def __len__(self) -> int: class CacheCollateFn: - def __init__(self): self.collate_fn = default_collate def __call__(self, items): if all(item is None for item in items): return None - return self.collate_fn(items) + return self.collate_fn(items) class CacheDataLoader(DataLoader): @@ -191,7 +191,15 @@ def __init__( batch_sampler = CacheBatchSampler( CacheSampler(dataset, generator, shuffle), batch_size, drop_last, shuffle, cache ) - super().__init__(dataset, *args, sampler=None, batch_sampler=batch_sampler, generator=generator, collate_fn=CacheCollateFn(), **kwargs) + super().__init__( + dataset, + *args, + sampler=None, + batch_sampler=batch_sampler, + generator=generator, + collate_fn=CacheCollateFn(), + **kwargs + ) def _get_iterator(self) -> "_BaseDataLoaderIter": if self.num_workers == 0: diff --git a/src/lightning/data/builder/compression.py b/src/lightning/data/builder/compression.py index 9f69bb1330aea..ecae66dec3147 100644 --- a/src/lightning/data/builder/compression.py +++ b/src/lightning/data/builder/compression.py @@ -11,14 +11,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractclassmethod, abstractmethod from typing import Dict, TypeVar -from abc import ABC, abstractmethod, abstractclassmethod + import zstd TCompressor = TypeVar("TCompressor", bound="Compressor") -class Compressor(ABC): +class Compressor(ABC): @abstractmethod def compress(self, data: bytes) -> bytes: pass @@ -33,11 +34,10 @@ def register(cls, compressors: Dict[str, TCompressor]): class ZSTDCompressor(Compressor): - def __init__(self, level): super().__init__() self.level = level - self.extension = 'zstd' + self.extension = "zstd" @property def name(self): @@ -50,13 +50,14 @@ def decompress(self, data: bytes) -> bytes: return zstd.decompress(data) @classmethod - def register(cls, compressors): + def register(cls, compressors): # default compressors["zstd"] = ZSTDCompressor(4) for level in list(range(1, 23)): compressors[f"zstd:{level}"] = ZSTDCompressor(level) + _COMPRESSORS = {} ZSTDCompressor.register(_COMPRESSORS) diff --git a/src/lightning/data/builder/reader.py b/src/lightning/data/builder/reader.py index eaf9b05ec3237..8e0f6b4368ac0 100644 --- a/src/lightning/data/builder/reader.py +++ b/src/lightning/data/builder/reader.py @@ -14,6 +14,7 @@ import json import os from typing import Optional + import numpy as np diff --git a/src/lightning/data/builder/serializers.py b/src/lightning/data/builder/serializers.py index 4e9a1c0b8befa..6b132c89a0828 100644 --- a/src/lightning/data/builder/serializers.py +++ b/src/lightning/data/builder/serializers.py @@ -13,6 +13,7 @@ import numpy as np from lightning_utilities.core.imports import RequirementCache + from lightning.data.builder.base import Serializer _PIL_AVAILABLE = RequirementCache("PIL") @@ -53,12 +54,12 @@ def deserialize(self, data: bytes) -> int: class JPEGSerializer(Serializer): def serialize(self, obj: Image) -> bytes: - if isinstance(obj, JpegImageFile) and hasattr(obj, 'filename'): - with open(obj.filename, 'rb') as f: + if isinstance(obj, JpegImageFile) and 
hasattr(obj, "filename"): + with open(obj.filename, "rb") as f: return f.read() else: out = BytesIO() - obj.save(out, format='JPEG') + obj.save(out, format="JPEG") return out.getvalue() def deserialize(self, data: bytes) -> Image: diff --git a/src/lightning/data/builder/writer.py b/src/lightning/data/builder/writer.py index af8a7b91a672f..4710f04dbe304 100644 --- a/src/lightning/data/builder/writer.py +++ b/src/lightning/data/builder/writer.py @@ -79,10 +79,7 @@ def serialize(self, items: Dict[str, Any]) -> bytes: for key in self._dict_format_keys: serializer_name = self._dict_format[key] serializer = self._serializers[serializer_name] - if not isinstance(items[key], bytes): - serialized_data = serializer.serialize(items[key]) - else: - serialized_data = items[key] + serialized_data = serializer.serialize(items[key]) if not isinstance(items[key], bytes) else items[key] sizes.append(len(serialized_data)) data.append(serialized_data) From a894cc42b652cba09f589ccc80aca5e9d4fbac41 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 27 Sep 2023 14:40:35 +0000 Subject: [PATCH 06/84] update --- src/lightning/data/builder/base.py | 8 ++--- src/lightning/data/builder/cache.py | 47 +++++++++++++++++++++++++++-- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/src/lightning/data/builder/base.py b/src/lightning/data/builder/base.py index 25ea29cc98f9f..7e371bc7f3789 100644 --- a/src/lightning/data/builder/base.py +++ b/src/lightning/data/builder/base.py @@ -35,10 +35,10 @@ def __init__( self._compression = compression self._name = name - if compression and compression not in _COMPRESSORS: - raise Exception(f"The provided compression {compression} isn't available in {sorted(_COMPRESSORS)}") - - self._compressor = _COMPRESSORS[compression] + if compression: + if compression not in _COMPRESSORS: + raise Exception(f"The provided compression {compression} isn't available in {sorted(_COMPRESSORS)}") + self._compressor = _COMPRESSORS[compression] self._current_chunk_size = 0 self._counter = 0 diff --git a/src/lightning/data/builder/cache.py b/src/lightning/data/builder/cache.py index 4f6ab184f2fe7..758eccce1c7a4 100644 --- a/src/lightning/data/builder/cache.py +++ b/src/lightning/data/builder/cache.py @@ -11,9 +11,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
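Reviewer note on JPEGSerializer above: it returns the original file bytes when the input is an on-disk JpegImageFile (skipping a re-encode) and otherwise re-encodes through an in-memory buffer; as written, BytesIO is used without an `from io import BytesIO` import in serializers.py. A round-trip sketch, assuming Pillow is installed:

from io import BytesIO

from PIL import Image

image = Image.new("RGB", (8, 8))
buffer = BytesIO()
image.save(buffer, format="JPEG")           # what the serializer does for non-file images
restored = Image.open(BytesIO(buffer.getvalue()))
assert restored.size == (8, 8)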
+from time import sleep import os from typing import Dict, Iterable, Iterator, Optional, Union - +from enum import Enum import numpy as np from torch.utils.data import IterableDataset from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter @@ -22,6 +23,10 @@ from lightning.data.builder.reader import Reader from lightning.data.builder.writer import Writer from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv +import signal +import sys +from torch.utils.data import get_worker_info +from torch._utils import ExceptionWrapper class Cache: @@ -40,12 +45,19 @@ def __init__( self._env = _DistributedEnv.detect() self._worker_env = None self._rank = None + self._dataset_size = None + self._num_workers = None + + def setup(self, size, num_workers): + self._dataset_size = size + self._num_workers = num_workers @property def rank(self): if self._rank is None: self._worker_env = _WorkerEnv.detect() self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank + return self._rank @property @@ -160,6 +172,25 @@ def __call__(self, items): return self.collate_fn(items) +StopIterationEvent = "StopIterationEvent" + + +class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter): + + def _next_index(self): + try: + return super()._next_index() + except StopIteration as e: + for worker_queue_idx in range(self._num_workers): + self._index_queues[worker_queue_idx].put((worker_queue_idx + self._send_idx, [StopIterationEvent])) + self._task_info[self._send_idx] = (worker_queue_idx,) + + # Get enough time to receive termination event + sleep(1) + + raise StopIteration() + + class CacheDataLoader(DataLoader): def __init__( self, @@ -167,6 +198,7 @@ def __init__( *args, sampler=None, batch_sampler=None, + num_workers=1, shuffle: bool = False, generator=None, batch_size=None, @@ -188,14 +220,23 @@ def __init__( raise Exception("The CacheDataloader should be used with a dataset using a single cache. Found {cache}.") cache = cache[0] + cache.setup(len(dataset), num_workers) batch_sampler = CacheBatchSampler( CacheSampler(dataset, generator, shuffle), batch_size, drop_last, shuffle, cache ) - super().__init__(dataset, *args, sampler=None, batch_sampler=batch_sampler, generator=generator, collate_fn=CacheCollateFn(), **kwargs) + super().__init__( + dataset, *args, + sampler=None, + batch_sampler=batch_sampler, + generator=generator, + num_workers=num_workers, + collate_fn=CacheCollateFn(), + **kwargs + ) def _get_iterator(self) -> "_BaseDataLoaderIter": if self.num_workers == 0: return _SingleProcessDataLoaderIterPatch(self) else: self.check_worker_number_rationality() - return _MultiProcessingDataLoaderIter(self) + return _MultiProcessingDataLoaderIterPatch(self) From bf4741291aa1a4abd1434f4a375dbf1a3dcdcf4b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 14:43:59 +0000 Subject: [PATCH 07/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/builder/cache.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/lightning/data/builder/cache.py b/src/lightning/data/builder/cache.py index de6389c7a6b93..51131ad924e6a 100644 --- a/src/lightning/data/builder/cache.py +++ b/src/lightning/data/builder/cache.py @@ -11,10 +11,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
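Reviewer note on Cache.rank above: the distributed global rank and the dataloader worker id are flattened into a single writer rank (global_rank * workers_per_process + worker_id), so each worker writes its own chunk files and {rank}.index.json without any coordination. A quick check of that arithmetic with illustrative values:

global_rank, workers_per_process, worker_id = 2, 4, 3
writer_rank = global_rank * workers_per_process + worker_id
assert writer_rank == 11   # unique as long as every process runs the same number of workers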
-from time import sleep import os +from time import sleep from typing import Dict, Iterable, Iterator, Optional, Union -from enum import Enum + import numpy as np from torch.utils.data import IterableDataset from torch.utils.data._utils.collate import default_collate @@ -24,10 +24,6 @@ from lightning.data.builder.reader import Reader from lightning.data.builder.writer import Writer from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv -import signal -import sys -from torch.utils.data import get_worker_info -from torch._utils import ExceptionWrapper class Cache: @@ -176,12 +172,11 @@ def __call__(self, items): class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter): - def _next_index(self): try: return super()._next_index() - except StopIteration as e: - for worker_queue_idx in range(self._num_workers): + except StopIteration: + for worker_queue_idx in range(self._num_workers): self._index_queues[worker_queue_idx].put((worker_queue_idx + self._send_idx, [StopIterationEvent])) self._task_info[self._send_idx] = (worker_queue_idx,) From 35cae788298d94ce10144d3bdcf215368ae0fa8f Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 27 Sep 2023 19:31:55 +0000 Subject: [PATCH 08/84] update --- src/lightning/data/builder/cache.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/lightning/data/builder/cache.py b/src/lightning/data/builder/cache.py index 51131ad924e6a..aec033ac390ff 100644 --- a/src/lightning/data/builder/cache.py +++ b/src/lightning/data/builder/cache.py @@ -42,12 +42,6 @@ def __init__( self._env = _DistributedEnv.detect() self._worker_env = None self._rank = None - self._dataset_size = None - self._num_workers = None - - def setup(self, size, num_workers): - self._dataset_size = size - self._num_workers = num_workers @property def rank(self): @@ -179,11 +173,16 @@ def _next_index(self): for worker_queue_idx in range(self._num_workers): self._index_queues[worker_queue_idx].put((worker_queue_idx + self._send_idx, [StopIterationEvent])) self._task_info[self._send_idx] = (worker_queue_idx,) + raise StopIteration - # Get enough time to receive termination event - sleep(1) - - raise StopIteration() + def _next_data(self): + try: + return super()._next_data() + except (KeyError, AssertionError, ValueError): + self._shutdown_workers() + return + except Exception as e: + raise e class CacheDataLoader(DataLoader): @@ -212,10 +211,9 @@ def __init__( cache = [v for v in dataset.__dict__.values() if isinstance(v, Cache)] if not cache or len(cache) > 1: - raise Exception("The CacheDataloader should be used with a dataset using a single cache. Found {cache}.") + raise Exception(f"The CacheDataloader should be used with a dataset using a single cache. 
Found {cache}.") cache = cache[0] - cache.setup(len(dataset), num_workers) batch_sampler = CacheBatchSampler( CacheSampler(dataset, generator, shuffle), batch_size, drop_last, shuffle, cache ) @@ -226,6 +224,7 @@ def __init__( batch_sampler=batch_sampler, generator=generator, collate_fn=CacheCollateFn(), + num_workers=num_workers, **kwargs ) From 2376c3e2cb9284a6fba75bff7df09b819f2e6998 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Sep 2023 19:33:15 +0000 Subject: [PATCH 09/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/builder/cache.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lightning/data/builder/cache.py b/src/lightning/data/builder/cache.py index aec033ac390ff..63a47450f5179 100644 --- a/src/lightning/data/builder/cache.py +++ b/src/lightning/data/builder/cache.py @@ -12,7 +12,6 @@ # limitations under the License. import os -from time import sleep from typing import Dict, Iterable, Iterator, Optional, Union import numpy as np @@ -180,7 +179,7 @@ def _next_data(self): return super()._next_data() except (KeyError, AssertionError, ValueError): self._shutdown_workers() - return + return None except Exception as e: raise e @@ -197,7 +196,7 @@ def __init__( generator=None, batch_size=None, drop_last=False, - **kwargs + **kwargs, ): if sampler: raise Exception("Passing a sampler isn't supoprt with the CacheDataLoader yet.") @@ -225,7 +224,7 @@ def __init__( generator=generator, collate_fn=CacheCollateFn(), num_workers=num_workers, - **kwargs + **kwargs, ) def _get_iterator(self) -> "_BaseDataLoaderIter": From 28fab534c25b9bd88346eaac989ab9c4c7aa69c7 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 28 Sep 2023 12:34:21 +0100 Subject: [PATCH 10/84] update --- src/lightning/data/builder/base.py | 116 ----------- src/lightning/data/builder/writer.py | 113 ---------- src/lightning/data/cache/__init__.py | 3 + .../data/{builder => cache}/cache.py | 20 +- .../data/{builder => cache}/compression.py | 11 +- .../data/{builder => cache}/reader.py | 63 ++++-- .../data/{builder => cache}/serializers.py | 34 ++- src/lightning/data/cache/writer.py | 196 ++++++++++++++++++ .../tests_data/cache}/__init__.py | 0 tests/tests_data/cache/test_serializer.py | 82 ++++++++ tests/tests_data/cache/test_writer.py | 42 ++++ 11 files changed, 422 insertions(+), 258 deletions(-) delete mode 100644 src/lightning/data/builder/base.py delete mode 100644 src/lightning/data/builder/writer.py create mode 100644 src/lightning/data/cache/__init__.py rename src/lightning/data/{builder => cache}/cache.py (92%) rename src/lightning/data/{builder => cache}/compression.py (88%) rename src/lightning/data/{builder => cache}/reader.py (58%) rename src/lightning/data/{builder => cache}/serializers.py (74%) create mode 100644 src/lightning/data/cache/writer.py rename {src/lightning/data/builder => tests/tests_data/cache}/__init__.py (100%) create mode 100644 tests/tests_data/cache/test_serializer.py create mode 100644 tests/tests_data/cache/test_writer.py diff --git a/src/lightning/data/builder/base.py b/src/lightning/data/builder/base.py deleted file mode 100644 index 959c97c284601..0000000000000 --- a/src/lightning/data/builder/base.py +++ /dev/null @@ -1,116 +0,0 @@ -# Copyright The Lightning AI team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import os -from abc import ABC, abstractmethod -from typing import Any, Dict, Optional - -from lightning.data.builder.compression import _COMPRESSORS - - -class BaseWriter(ABC): - def __init__( - self, - out_dir: str, - chunk_size: int = 1 << 26, - compression: Optional[str] = None, - name: Optional[str] = None, - ): - self._out_dir = out_dir - - if not os.path.exists(self._out_dir): - raise Exception(f"The provided output directory {self._out_dir} doesn't exists.") - - self._chunk_size = chunk_size - self._compression = compression - self._name = name - - if compression: - if compression not in _COMPRESSORS: - raise Exception(f"The provided compression {compression} isn't available in {sorted(_COMPRESSORS)}") - self._compressor = _COMPRESSORS[compression] - - self._current_chunk_size = 0 - self._counter = 0 - self._serialized_items = [] - self._serializers: List[Serializer] = [] - self._chunks = [] - - @property - def is_cached(self) -> bool: - return os.path.exists(os.path.join(self._out_dir, "index.json")) - - def get_config(self) -> Dict[str, Any]: - return {"compression": self._compression, "chunk_size": self._chunk_size} - - @property - def available_serializers(self): - return self._serializers - - @abstractmethod - def serialize(self, data: any) -> bytes: - """Convert a given data type into its bytes format.""" - - @abstractmethod - def write_chunk(self, rank: int) -> None: - """Write the current chunk to the filesystem.""" - - def reset(self) -> None: - """Reset the writer to handle the next chunk.""" - self._serialized_items = [] - self._current_chunk_size = 0 - - def write(self, items: any, rank): - serialized_items = self.serialize(items) - serialized_items_size = len(serialized_items) - - if self._chunk_size < self._current_chunk_size + serialized_items_size: - self.write_chunk(rank) - self.reset() - self._counter += 1 - - self._serialized_items.append(serialized_items) - self._current_chunk_size += serialized_items_size - - def write_file( - self, - raw_data: bytes, - filename: str, - ) -> None: - if self._compression: - raw_data = self._compressor.compress(raw_data) - filepath = os.path.join(self._out_dir, filename) - with open(filepath, "wb") as out: - out.write(raw_data) - - def write_chunks_index(self, rank: int): - filepath = os.path.join(self._out_dir, f"{rank}.index.json") - with open(filepath, "w") as out: - json.dump({"chunks": self._chunks}, out, sort_keys=True) - - def done(self, rank: int): - if self._serialized_items: - self.write_chunk(rank) - self.write_chunks_index(rank) - self.reset() - - -class Serializer(ABC): - @abstractmethod - def serialize(self, data: any) -> bytes: - pass - - @abstractmethod - def deserialize(self, data: bytes) -> any: - pass diff --git a/src/lightning/data/builder/writer.py b/src/lightning/data/builder/writer.py deleted file mode 100644 index 4710f04dbe304..0000000000000 --- a/src/lightning/data/builder/writer.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright The Lightning AI team. 
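The BaseWriter deleted above captures the accumulate-then-flush policy the rest of the series builds on: samples are buffered until the next one would push the chunk past chunk_size, the buffer is flushed to a chunk-{rank}-{counter}.bin file, and done() writes a {rank}.index.json describing the chunks. A stripped-down sketch of just that policy follows; the file names mirror the diff, everything else is illustrative:

# Sketch of the accumulate-then-flush chunking policy, assuming pre-serialized bytes.
import json
import os
from typing import List


class TinyChunkWriter:
    def __init__(self, out_dir: str, chunk_size: int) -> None:
        self._out_dir = out_dir
        self._chunk_size = chunk_size
        self._buffer: List[bytes] = []
        self._buffer_size = 0
        self._counter = 0
        self._chunks = []

    def write(self, item: bytes, rank: int = 0) -> None:
        # Flush before appending if this item would push the chunk past the limit.
        if self._buffer and self._chunk_size < self._buffer_size + len(item):
            self._flush(rank)
        self._buffer.append(item)
        self._buffer_size += len(item)

    def _flush(self, rank: int) -> None:
        filename = f"chunk-{rank}-{self._counter}.bin"
        with open(os.path.join(self._out_dir, filename), "wb") as out:
            out.write(b"".join(self._buffer))
        self._chunks.append({"filename": filename, "samples": len(self._buffer)})
        self._buffer, self._buffer_size = [], 0
        self._counter += 1

    def done(self, rank: int = 0) -> None:
        # Flush the remainder and record the chunk list, as BaseWriter.done does.
        if self._buffer:
            self._flush(rank)
        with open(os.path.join(self._out_dir, f"{rank}.index.json"), "w") as out:
            json.dump({"chunks": self._chunks}, out, sort_keys=True)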
-# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from typing import Any, Dict, Optional - -import numpy as np -from lightning_utilities.core.imports import RequirementCache - -from lightning.data.builder.base import BaseWriter -from lightning.data.builder.serializers import _SERIALIZERS - -_PIL_AVAILABLE = RequirementCache("PIL") - -if _PIL_AVAILABLE: - from PIL import Image -else: - Image = Any - - -class Writer(BaseWriter): - def __init__( - self, - out_dir: str, - dict_format: Dict[str, str], - chunk_size: int = 1 << 26, - compression: Optional[str] = None, - name: Optional[str] = None, - ): - super().__init__(out_dir, chunk_size, compression, name) - - if not _PIL_AVAILABLE: - raise Exception("The ImageWriter requires pil to be installed") - - self._dict_format = {k.lower(): v for k, v in dict_format.items()} - self._dict_format_keys = sorted(self._dict_format.keys()) - self._serializers = _SERIALIZERS - - available_serializers = set(self._serializers.keys()) - selected_serializers = set(self._dict_format.values()) - if selected_serializers.difference(available_serializers): - raise Exception( - f"The provided dict_format don't match the provided serializers. Should be selected from {available_serializers}." - ) - - obj = self.get_config() - text = json.dumps(obj, sort_keys=True) - self._config_data = text.encode("utf-8") - - def get_config(self) -> Dict[str, Any]: - out = super().get_config() - out.update(self._dict_format) - return out - - def serialize(self, items: Dict[str, Any]) -> bytes: - if not isinstance(items, dict): - raise Exception("The provided data should be a dictionary.") - - keys = sorted(items.keys()) - - if keys != self._dict_format_keys: - raise Exception( - f"The provided keys don't match the provided format. Found {keys} instead of {self._dict_format_keys}." 
- ) - - sizes = [] - data = [] - - for key in self._dict_format_keys: - serializer_name = self._dict_format[key] - serializer = self._serializers[serializer_name] - serialized_data = serializer.serialize(items[key]) if not isinstance(items[key], bytes) else items[key] - - sizes.append(len(serialized_data)) - data.append(serialized_data) - - head = np.array(sizes, np.uint32).tobytes() - body = b"".join(data) - return head + body - - def _create_chunk(self, filename: str) -> bytes: - num_items = np.uint32(len(self._serialized_items)) - sizes = list(map(len, self._serialized_items)) - offsets = np.array([0] + sizes).cumsum().astype(np.uint32) - offsets += len(num_items.tobytes()) + len(offsets.tobytes()) + len(self._config_data) - sample_data = b"".join(self._serialized_items) - - self._chunks.append( - { - "samples": len(self._serialized_items), - "config": self.get_config(), - "filename": filename, - } - ) - - return num_items.tobytes() + offsets.tobytes() + self._config_data + sample_data - - def write_chunk(self, rank: int): - if self._compression: - filename = f"chunk-{rank}-{self._counter}.{self._compression}.bin" - else: - filename = f"chunk-{rank}-{self._counter}.bin" - self.write_file(self._create_chunk(filename), filename) diff --git a/src/lightning/data/cache/__init__.py b/src/lightning/data/cache/__init__.py new file mode 100644 index 0000000000000..bbfd93dfbc591 --- /dev/null +++ b/src/lightning/data/cache/__init__.py @@ -0,0 +1,3 @@ +from lightning.data.cache.cache import Cache, CacheDataLoader + +__all__ = ["Cache", "CacheDataLoader"] diff --git a/src/lightning/data/builder/cache.py b/src/lightning/data/cache/cache.py similarity index 92% rename from src/lightning/data/builder/cache.py rename to src/lightning/data/cache/cache.py index 63a47450f5179..67b33ee7c3759 100644 --- a/src/lightning/data/builder/cache.py +++ b/src/lightning/data/cache/cache.py @@ -17,11 +17,16 @@ import numpy as np from torch.utils.data import IterableDataset from torch.utils.data._utils.collate import default_collate -from torch.utils.data.dataloader import DataLoader, _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter +from torch.utils.data.dataloader import ( + _BaseDataLoaderIter, + _MultiProcessingDataLoaderIter, + _SingleProcessDataLoaderIter, + DataLoader, +) from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized -from lightning.data.builder.reader import Reader -from lightning.data.builder.writer import Writer +from lightning.data.cache.reader import BinaryReader +from lightning.data.cache.writer import BinaryWriter from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv @@ -34,8 +39,8 @@ def __init__( chunk_size: int = 2 << 26, ): super().__init__() - self._writer = Writer(cache_dir, data_format, chunk_size=chunk_size, compression=compression) - self._reader = Reader(cache_dir, compression=compression) + self._writer = BinaryWriter(cache_dir, data_format, chunk_size=chunk_size, compression=compression) + self._reader = BinaryReader(cache_dir, compression=compression) self._cache_dir = cache_dir self._env = _DistributedEnv.detect() @@ -230,6 +235,5 @@ def __init__( def _get_iterator(self) -> "_BaseDataLoaderIter": if self.num_workers == 0: return _SingleProcessDataLoaderIterPatch(self) - else: - self.check_worker_number_rationality() - return _MultiProcessingDataLoaderIterPatch(self) + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIterPatch(self) diff --git a/src/lightning/data/builder/compression.py 
b/src/lightning/data/cache/compression.py similarity index 88% rename from src/lightning/data/builder/compression.py rename to src/lightning/data/cache/compression.py index ecae66dec3147..5b36f9eb8bf60 100644 --- a/src/lightning/data/builder/compression.py +++ b/src/lightning/data/cache/compression.py @@ -14,7 +14,12 @@ from abc import ABC, abstractclassmethod, abstractmethod from typing import Dict, TypeVar -import zstd +from lightning_utilities.core.imports import RequirementCache, requires + +_ZSTD_AVAILABLE = RequirementCache("zstd") + +if _ZSTD_AVAILABLE: + import zstd TCompressor = TypeVar("TCompressor", bound="Compressor") @@ -34,6 +39,7 @@ def register(cls, compressors: Dict[str, TCompressor]): class ZSTDCompressor(Compressor): + @requires("zstd") def __init__(self, level): super().__init__() self.level = level @@ -51,6 +57,9 @@ def decompress(self, data: bytes) -> bytes: @classmethod def register(cls, compressors): + if not _ZSTD_AVAILABLE: + return + # default compressors["zstd"] = ZSTDCompressor(4) diff --git a/src/lightning/data/builder/reader.py b/src/lightning/data/cache/reader.py similarity index 58% rename from src/lightning/data/builder/reader.py rename to src/lightning/data/cache/reader.py index 8e0f6b4368ac0..3c47a0c23efe6 100644 --- a/src/lightning/data/builder/reader.py +++ b/src/lightning/data/cache/reader.py @@ -17,16 +17,18 @@ import numpy as np +from lightning.data.cache.serializers import _SERIALIZERS -class Reader: + +class BinaryReader: def __init__(self, _cache_dir: str, compression: Optional[str] = None): super().__init__() - self._cache_dir = _cache_dir self._compression = compression self._index = None self._intervals = None - self._chunks = [] + self._chunks_data = {} + self._serializers = _SERIALIZERS def _try_read_index(self): files = os.listdir(self._cache_dir) @@ -42,20 +44,24 @@ def _try_read_index(self): self._index = index + for chunk in self._index["chunks"]: + chunk["data"] = None + self._chunks_data[chunk["filename"]] = chunk + self._intervals = [] cumsum_samples = np.cumsum([0] + [v["samples"] for v in self._index["chunks"]] + [1]) for i in range(len(cumsum_samples) - 1): self._intervals.append([cumsum_samples[i], cumsum_samples[i + 1]]) - print(self._intervals) - def _map_index_to_chunk_id(self, index): for interval_index, internal in enumerate(self._intervals): - print(internal, index) if internal[0] <= index and index < internal[1]: return interval_index return None + def _should_keep_in_memory(self): + return True + def read(self, index: int, rank): if self._index is None: self._try_read_index() @@ -66,13 +72,44 @@ def read(self, index: int, rank): chunk_id = self._map_index_to_chunk_id(index) chunk_config = self._index["chunks"][chunk_id] chunk_path = os.path.join(self._cache_dir, chunk_config["filename"]) - if not os.path.exists(chunk_path): - download_chunk(chunk_path) - - return self.load_data_from_chunk(chunk_path) - - def load_data_from_chunk(self, chunk_path): - pass + raw_item_data, item_config = self.load_item_from_chunk( + index, chunk_path, keep_in_memory=self._should_keep_in_memory() + ) + return self.deserialize(raw_item_data, item_config) + + def deserialize(self, raw_item_data, item_config): + sizes = [] + idx = 0 + data_format = item_config["data_format"] + keys = sorted(data_format) + for key in keys: + (size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32) + sizes.append(size) + idx += 4 + sample = {} + for key, size in zip(keys, sizes): + value = raw_item_data[idx : idx + size] + serializer = 
self._serializers[data_format[key]] + sample[key] = serializer.deserialize(value) + idx += size + return sample + + def load_item_from_chunk(self, index: int, chunk_path: str, keep_in_memory: bool = False): + chunk_name = os.path.basename(chunk_path) + begin, end = self._chunks_data[chunk_name]["mapping"][str(index)] + config = self._chunks_data[chunk_name]["config"] + if self._chunks_data[chunk_name]["data"] is not None: + return self._chunks_data[chunk_name]["data"][begin:end], config + + if keep_in_memory: + with open(chunk_path, "rb", 0) as fp: + data = fp.read() + self._chunks_data[chunk_name]["data"] = data + return data[begin:end], config + with open(chunk_path, "rb", 0) as fp: + fp.seek(begin) + data = fp.read(end - begin) + return data, config def get_length(self) -> int: if self._index is None: diff --git a/src/lightning/data/builder/serializers.py b/src/lightning/data/cache/serializers.py similarity index 74% rename from src/lightning/data/builder/serializers.py rename to src/lightning/data/cache/serializers.py index 6b132c89a0828..e940e36a9d1fb 100644 --- a/src/lightning/data/builder/serializers.py +++ b/src/lightning/data/cache/serializers.py @@ -11,22 +11,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC, abstractmethod +from io import BytesIO + import numpy as np from lightning_utilities.core.imports import RequirementCache -from lightning.data.builder.base import Serializer - _PIL_AVAILABLE = RequirementCache("PIL") if _PIL_AVAILABLE: from PIL import Image from PIL.JpegImagePlugin import JpegImageFile else: - Image = Any + Image = None JpegImageFile = None +class Serializer(ABC): + """The base interface for any serializers. + + A Serializer serialize and deserialize to and from bytes. + + """ + + @abstractmethod + def serialize(self, data: any) -> bytes: + pass + + @abstractmethod + def deserialize(self, data: bytes) -> any: + pass + + class PILSerializer(Serializer): + """The PILSerializer serialize and deserialize PIL Image to and from bytes.""" + def serialize(self, item: any) -> bytes: mode = item.mode.encode("utf-8") width, height = item.size @@ -45,6 +64,8 @@ def deserialize(self, data: bytes) -> any: class IntSerializer(Serializer): + """The IntSerializer serialize and deserialize integer to and from bytes.""" + def serialize(self, item: int) -> bytes: return str(item).encode("utf-8") @@ -53,14 +74,13 @@ def deserialize(self, data: bytes) -> int: class JPEGSerializer(Serializer): + """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" + def serialize(self, obj: Image) -> bytes: if isinstance(obj, JpegImageFile) and hasattr(obj, "filename"): with open(obj.filename, "rb") as f: return f.read() - else: - out = BytesIO() - obj.save(out, format="JPEG") - return out.getvalue() + raise TypeError(f"The provided object should be of type {JpegImageFile}") def deserialize(self, data: bytes) -> Image: inp = BytesIO(data) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py new file mode 100644 index 0000000000000..7b043b0513e84 --- /dev/null +++ b/src/lightning/data/cache/writer.py @@ -0,0 +1,196 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
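The deserialize path above and the serialize path in the new BinaryWriter below agree on a simple per-sample layout: a header of uint32 sizes, one per key in sorted key order, followed by the serialized payloads concatenated in the same order. A small round-trip sketch of that layout; the real code dispatches through the registered serializers, while here the payloads are kept as plain bytes for brevity:

# Sketch of the per-sample byte layout: uint32 size header + concatenated payloads.
import numpy as np


def pack(sample: dict) -> bytes:
    keys = sorted(sample)
    payloads = [v if isinstance(v, bytes) else str(v).encode() for v in (sample[k] for k in keys)]
    head = np.array([len(p) for p in payloads], np.uint32).tobytes()
    return head + b"".join(payloads)


def unpack(raw: bytes, keys: list) -> dict:
    keys = sorted(keys)
    idx = 0
    sizes = []
    for _ in keys:
        (size,) = np.frombuffer(raw[idx : idx + 4], np.uint32)
        sizes.append(int(size))
        idx += 4
    out = {}
    for key, size in zip(keys, sizes):
        out[key] = raw[idx : idx + size]
        idx += size
    return out


# Round trip: the integer payload comes back as its raw bytes, as in the reader.
assert unpack(pack({"x": b"abc", "y": 7}), ["x", "y"]) == {"x": b"abc", "y": b"7"}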
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from typing import Any, Dict, Optional + +import numpy as np + +from lightning.data.cache.compression import _COMPRESSORS +from lightning.data.cache.serializers import _SERIALIZERS + + +class BinaryWriter: + def __init__( + self, + cache_dir: str, + data_format: Dict[str, str], + chunk_size: int = 1 << 26, + compression: Optional[str] = None, + ): + """The BinaryWriter enables to chunk dataset into an efficient streaming format for cloud training. + + Arguments: + cache_dir: The path to where the chunks will be saved. + data_format: The format of the provided data to cache. Only dictionary are supported for now. + chunk_size: The maximum number of bytes to store within a chunk. + compression: The compression algorithm to use. + + """ + self._cache_dir = cache_dir + self._data_format = {k.lower(): v for k, v in data_format.items()} + self._chunk_size = chunk_size + self._compression = compression + + if not os.path.exists(self._cache_dir): + raise FileNotFoundError(f"The provided cache directory `{self._cache_dir}` doesn't exist.") + + if len(self._data_format) == 0: + raise ValueError("The provided data format shouldn't be empty.") + + self._data_format_keys = sorted(self._data_format.keys()) + self._serializers = _SERIALIZERS + + available_serializers = set(self._serializers.keys()) + selected_serializers = set(self._data_format.values()) + if selected_serializers.difference(available_serializers): + raise ValueError( + "The provided data_format don't match the provided serializers." + " Should be selected from {sorted(available_serializers)}." + ) + + if self._compression: + if len(_COMPRESSORS) == 0: + raise ValueError("No compresion algorithms are installed.") + if self._compression not in _COMPRESSORS: + raise ValueError( + f"The provided compression {self._compression} isn't available in {sorted(_COMPRESSORS)}" + ) + self._compressor = _COMPRESSORS[self._compression] + + self._current_chunk_size = 0 + self._chunk_id = 0 + self._serialized_items = [] + self._chunks_info = [] + self._indexes = [] + obj = self.get_config() + text = json.dumps(obj, sort_keys=True) + self._config_data = text.encode("utf-8") + + def get_config(self) -> Dict[str, Any]: + out = super().get_config() + out.update(self._data_format) + return out + + def serialize(self, items: Dict[str, Any]) -> bytes: + if not isinstance(items, dict): + raise Exception("The provided data should be a dictionary.") + + keys = sorted(items.keys()) + + if keys != self._data_format_keys: + raise Exception( + f"The provided keys don't match the provided format. Found {keys} instead of {self._data_format_keys}." 
+ ) + + sizes = [] + data = [] + + for key in self._data_format_keys: + serializer_name = self._data_format[key] + serializer = self._serializers[serializer_name] + serialized_data = serializer.serialize(items[key]) if not isinstance(items[key], bytes) else items[key] + + sizes.append(len(serialized_data)) + data.append(serialized_data) + + head = np.array(sizes, np.uint32).tobytes() + body = b"".join(data) + return head + body + + def _create_chunk(self, filename: str) -> bytes: + num_items = np.uint32(len(self._serialized_items)) + sizes = list(map(len, self._serialized_items)) + offsets = np.array([0] + sizes).cumsum().astype(np.uint32) + offsets += len(num_items.tobytes()) + len(offsets.tobytes()) + len(self._config_data) + sample_data = b"".join(self._serialized_items) + + data = num_items.tobytes() + offsets.tobytes() + self._config_data + sample_data + + offsets = offsets.tolist() + mapping = {} + for i in range(len(self._indexes)): + mapping[self._indexes[i]] = [offsets[i], offsets[i + 1]] + + assert len(mapping) == len(self._indexes) + + chunk_info = { + "samples": len(self._serialized_items), + "config": self.get_config(), + "filename": filename, + "mapping": mapping, + } + + self._chunks_info.append(chunk_info) + + return data + + def write_chunk(self, rank: int): + if self._compression: + filename = f"chunk-{rank}-{self._chunk_id}.{self._compression}.bin" + else: + filename = f"chunk-{rank}-{self._chunk_id}.bin" + self.write_file(self._create_chunk(filename), filename) + + @property + def is_cached(self) -> bool: + return os.path.exists(os.path.join(self._cache_dir, "index.json")) + + def get_config(self) -> Dict[str, Any]: + return {"compression": self._compression, "chunk_size": self._chunk_size, "data_format": self._data_format} + + @property + def available_serializers(self): + return self._serializers + + def reset(self) -> None: + """Reset the writer to handle the next chunk.""" + self._serialized_items = [] + self._indexes = [] + self._current_chunk_size = 0 + + def __setitem__(self, index, items: any, rank=0): + serialized_items = self.serialize(items) + serialized_items_size = len(serialized_items) + + if self._chunk_size < self._current_chunk_size + serialized_items_size: + self.write_chunk(rank) + self.reset() + self._chunk_id += 1 + + self._serialized_items.append(serialized_items) + self._current_chunk_size += serialized_items_size + self._indexes.append(index) + + def write_file( + self, + raw_data: bytes, + filename: str, + ) -> None: + if self._compression: + raw_data = self._compressor.compress(raw_data) + filepath = os.path.join(self._cache_dir, filename) + with open(filepath, "wb") as out: + out.write(raw_data) + + def write_chunks_index(self, rank: int): + filepath = os.path.join(self._cache_dir, f"{rank}.index.json") + with open(filepath, "w") as out: + json.dump({"chunks": self._chunks_info}, out, sort_keys=True) + + def done(self, rank: int = 0): + if self._serialized_items: + self.write_chunk(rank) + self.write_chunks_index(rank) + self.reset() diff --git a/src/lightning/data/builder/__init__.py b/tests/tests_data/cache/__init__.py similarity index 100% rename from src/lightning/data/builder/__init__.py rename to tests/tests_data/cache/__init__.py diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py new file mode 100644 index 0000000000000..165dcde340140 --- /dev/null +++ b/tests/tests_data/cache/test_serializer.py @@ -0,0 +1,82 @@ +import os + +import numpy as np +import pytest +from 
lightning_utilities.core.imports import RequirementCache + +from lightning.data.cache.serializers import _SERIALIZERS, IntSerializer, JPEGSerializer, PILSerializer + +_PIL_AVAILABLE = RequirementCache("PIL") + + +def test_serializers(): + assert sorted(_SERIALIZERS) == ["int", "jpeg", "pil"] + + +def test_int_serializer(): + serializer = IntSerializer() + + for i in range(100): + data = serializer.serialize(i) + assert isinstance(data, bytes) + assert i == serializer.deserialize(data) + + +@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +@pytest.mark.parametrize("mode", ["L", "RGB"]) +def test_jpeg_serializer(mode, tmpdir): + serializer = JPEGSerializer() + + from PIL import Image + + path = os.path.join(tmpdir, "img.jpeg") + + size = {"RGB": (28, 28, 3), "L": (28, 28)}[mode] + np_data = np.random.randint(255, size=size, dtype=np.uint8) + img = Image.fromarray(np_data).convert(mode) + + np.testing.assert_array_equal(np_data, np.array(img)) + + with pytest.raises(TypeError, match="PIL.JpegImagePlugin.JpegImageFile"): + serializer.serialize(img) + + # from the JPEG image directly + img.save(path, format="jpeg", quality=100) + img = Image.open(path) + + data = serializer.serialize(img) + assert isinstance(data, bytes) + deserialized_img = np.asarray(serializer.deserialize(data)) + assert np.array_equal(np.asarray(img), np.array(deserialized_img)) + + # read bytes from the file + with open(path, "rb") as f: + data = f.read() + + assert isinstance(data, bytes) + deserialized_img = np.asarray(serializer.deserialize(data)) + + assert np.array_equal(np.asarray(img), np.array(deserialized_img)) + + +@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +@pytest.mark.parametrize("mode", ["I", "L", "RGB"]) +@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +def test_pil_serializer(mode): + serializer = PILSerializer() + + from PIL import Image + + np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32) + img = Image.fromarray(np_data).convert(mode) + + data = serializer.serialize(img) + assert isinstance(data, bytes) + + deserialized_img = serializer.deserialize(data) + deserialized_img = deserialized_img.convert("I") + np_dec_data = np.asarray(deserialized_img, dtype=np.uint32) + assert isinstance(deserialized_img, Image.Image) + + # Validate data content + assert np.array_equal(np_data, np_dec_data) diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py new file mode 100644 index 0000000000000..e3e6e39947b6a --- /dev/null +++ b/tests/tests_data/cache/test_writer.py @@ -0,0 +1,42 @@ +import json +import os + +import pytest + +from lightning.data.cache.reader import BinaryReader +from lightning.data.cache.writer import BinaryWriter + + +def test_binary_writer(tmpdir): + with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."): + BinaryWriter("dontexists", {}) + + with pytest.raises(ValueError, match="The provided data format shouldn't be empty."): + BinaryWriter(tmpdir, {}) + + with pytest.raises(ValueError, match="['int', 'jpeg', 'pil']"): + BinaryWriter(tmpdir, {"i": "random"}) + + with pytest.raises(ValueError, match="No compresion algorithms are installed."): + BinaryWriter(tmpdir, {"i": "int"}, compression="something_else") + + binary_writer = BinaryWriter(tmpdir, {"i": "int", "i+1": "int", "i+2": "int"}, chunk_size=90) + + for i in range(100): + binary_writer[i] = {"i": i, "i+1": i + 1, "i+2": i + 2} + + assert 
len(os.listdir(tmpdir)) == 19 + binary_writer.done(0) + assert len(os.listdir(tmpdir)) == 21 + + with open(os.path.join(tmpdir, "0.index.json")) as f: + data = json.load(f) + + assert data["chunks"][0]["samples"] == 6 + assert data["chunks"][1]["samples"] == 5 + assert data["chunks"][-1]["samples"] == 4 + + reader = BinaryReader(tmpdir) + for i in range(100): + data = reader.read(i, 0) + assert data == {"i": i, "i+1": i + 1, "i+2": i + 2} From 7f54886b1e7b7d89e58444a6c1cb3315218e0791 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 28 Sep 2023 12:53:03 +0100 Subject: [PATCH 11/84] update --- src/lightning/data/cache/reader.py | 3 +- src/lightning/data/cache/writer.py | 11 +++++- tests/tests_data/cache/test_serializer.py | 1 - tests/tests_data/cache/test_writer.py | 46 ++++++++++++++++++++++- 4 files changed, 55 insertions(+), 6 deletions(-) diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 3c47a0c23efe6..60747d4f12520 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -62,7 +62,7 @@ def _map_index_to_chunk_id(self, index): def _should_keep_in_memory(self): return True - def read(self, index: int, rank): + def read(self, index: int, rank: int = 0): if self._index is None: self._try_read_index() @@ -106,6 +106,7 @@ def load_item_from_chunk(self, index: int, chunk_path: str, keep_in_memory: bool data = fp.read() self._chunks_data[chunk_name]["data"] = data return data[begin:end], config + with open(chunk_path, "rb", 0) as fp: fp.seek(begin) data = fp.read(end - begin) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 7b043b0513e84..e38b76c89f300 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -115,9 +115,7 @@ def _create_chunk(self, filename: str) -> bytes: offsets = np.array([0] + sizes).cumsum().astype(np.uint32) offsets += len(num_items.tobytes()) + len(offsets.tobytes()) + len(self._config_data) sample_data = b"".join(self._serialized_items) - data = num_items.tobytes() + offsets.tobytes() + self._config_data + sample_data - offsets = offsets.tolist() mapping = {} for i in range(len(self._indexes)): @@ -165,12 +163,21 @@ def __setitem__(self, index, items: any, rank=0): serialized_items_size = len(serialized_items) if self._chunk_size < self._current_chunk_size + serialized_items_size: + if self._current_chunk_size == 0: + raise Exception( + f"The provided chunk_size {self._chunk_size} is too small." + f" You should use a multiple of {serialized_items_size} bytes." 
+ ) self.write_chunk(rank) self.reset() self._chunk_id += 1 self._serialized_items.append(serialized_items) self._current_chunk_size += serialized_items_size + + # The dataset should be indexed in a non-sorted manner + if self._indexes: + assert self._indexes[-1] + 1 == index self._indexes.append(index) def write_file( diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py index 165dcde340140..e4d8b89b44a4d 100644 --- a/tests/tests_data/cache/test_serializer.py +++ b/tests/tests_data/cache/test_serializer.py @@ -61,7 +61,6 @@ def test_jpeg_serializer(mode, tmpdir): @pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") @pytest.mark.parametrize("mode", ["I", "L", "RGB"]) -@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") def test_pil_serializer(mode): serializer = PILSerializer() diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index e3e6e39947b6a..f901aa42058dc 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -1,13 +1,17 @@ import json import os +import numpy as np import pytest +from lightning_utilities.core.imports import RequirementCache from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter +_PIL_AVAILABLE = RequirementCache("PIL") -def test_binary_writer(tmpdir): + +def test_binary_writer_with_ints(tmpdir): with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."): BinaryWriter("dontexists", {}) @@ -38,5 +42,43 @@ def test_binary_writer(tmpdir): reader = BinaryReader(tmpdir) for i in range(100): - data = reader.read(i, 0) + data = reader.read(i) assert data == {"i": i, "i+1": i + 1, "i+2": i + 2} + + +@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +def test_binary_writer_with_jpeg_and_int(tmpdir): + """Validate the writer and reader can serialize / deserialize a pair of image and label.""" + from PIL import Image + + cache_dir = os.path.join(tmpdir, "chunks") + os.makedirs(cache_dir, exist_ok=True) + binary_writer = BinaryWriter(cache_dir, {"x": "jpeg", "y": "int"}, chunk_size=2 << 12) + + imgs = [] + + for i in range(100): + path = os.path.join(tmpdir, f"img{i}.jpeg") + np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8) + img = Image.fromarray(np_data).convert("L") + img.save(path, format="jpeg", quality=100) + img = Image.open(path) + imgs.append(img) + binary_writer[i] = {"x": img, "y": i} + + assert len(os.listdir(cache_dir)) == 24 + binary_writer.done(0) + assert len(os.listdir(cache_dir)) == 26 + + with open(os.path.join(cache_dir, "0.index.json")) as f: + data = json.load(f) + + assert data["chunks"][0]["samples"] == 4 + assert data["chunks"][1]["samples"] == 4 + assert data["chunks"][-1]["samples"] == 4 + + reader = BinaryReader(cache_dir) + for i in range(100): + data = reader.read(i) + assert data["x"] == imgs[i] + assert data["y"] == i From c1b197f374f32805108eca1ba0365e2e78a09bff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Sep 2023 11:54:21 +0000 Subject: [PATCH 12/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/cache/cache.py | 2 +- tests/tests_data/cache/test_serializer.py | 3 +-- tests/tests_data/cache/test_writer.py | 3 +-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git 
a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 67b33ee7c3759..0d3860c976a2c 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -18,10 +18,10 @@ from torch.utils.data import IterableDataset from torch.utils.data._utils.collate import default_collate from torch.utils.data.dataloader import ( + DataLoader, _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter, - DataLoader, ) from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py index e4d8b89b44a4d..dd7a743fd11c7 100644 --- a/tests/tests_data/cache/test_serializer.py +++ b/tests/tests_data/cache/test_serializer.py @@ -2,9 +2,8 @@ import numpy as np import pytest -from lightning_utilities.core.imports import RequirementCache - from lightning.data.cache.serializers import _SERIALIZERS, IntSerializer, JPEGSerializer, PILSerializer +from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index f901aa42058dc..a2ea6261a85bc 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -3,10 +3,9 @@ import numpy as np import pytest -from lightning_utilities.core.imports import RequirementCache - from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter +from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") From 81f5a79a81066182721c2342d0fd244d79d7d83d Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 28 Sep 2023 16:27:52 +0100 Subject: [PATCH 13/84] update --- src/lightning/data/cache/cache.py | 153 ++++++---- src/lightning/data/cache/env.py | 49 ++++ src/lightning/data/cache/reader.py | 19 +- src/lightning/data/cache/serializers.py | 8 +- src/lightning/data/cache/worker.py | 366 ++++++++++++++++++++++++ src/lightning/data/cache/writer.py | 38 ++- tests/tests_data/cache/test_cache.py | 62 ++++ 7 files changed, 621 insertions(+), 74 deletions(-) create mode 100644 src/lightning/data/cache/env.py create mode 100644 src/lightning/data/cache/worker.py create mode 100644 tests/tests_data/cache/test_cache.py diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 67b33ee7c3759..2a4fe877c0536 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -11,8 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
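Before the cache.py rewrite below, it is worth recapping what the writer and its tests above actually put on disk: each rank writes chunk-{rank}-{id}.bin files plus a {rank}.index.json whose "chunks" entries carry "samples", "filename", "config" and a "mapping" from sample index to a [begin, end) byte range. A sketch of how such an index can be resolved back to one sample; the path and the fixed rank 0 are illustrative:

# Sketch: resolve a sample index to (chunk file, byte range) from 0.index.json.
import json
import os


def locate(cache_dir: str, index: int):
    with open(os.path.join(cache_dir, "0.index.json")) as f:
        chunks = json.load(f)["chunks"]
    # Chunks are laid out back to back; walk the cumulative sample counts to find
    # the chunk holding `index`, then look up its byte range in the mapping.
    begin_sample = 0
    for chunk in chunks:
        if index < begin_sample + chunk["samples"]:
            begin, end = chunk["mapping"][str(index)]
            return os.path.join(cache_dir, chunk["filename"]), begin, end
        begin_sample += chunk["samples"]
    raise IndexError(index)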
+import logging import os -from typing import Dict, Iterable, Iterator, Optional, Union +from typing import Dict, Iterator, List, Optional, Union import numpy as np from torch.utils.data import IterableDataset @@ -23,11 +24,12 @@ _SingleProcessDataLoaderIter, DataLoader, ) -from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized +from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler, Sized from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter -from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv + +logger = logging.Logger(__name__) class Cache: @@ -43,31 +45,20 @@ def __init__( self._reader = BinaryReader(cache_dir, compression=compression) self._cache_dir = cache_dir - self._env = _DistributedEnv.detect() - self._worker_env = None - self._rank = None - - @property - def rank(self): - if self._rank is None: - self._worker_env = _WorkerEnv.detect() - self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank - - return self._rank - + # TODO: Find a way to make this faster @property def filled(self) -> bool: files = os.listdir(self._cache_dir) return any(f.endswith("index.json") for f in files) def __setitem__(self, index, data): - self._writer.write(data, self.rank) + self._writer[index] = data def __getitem__(self, index): - self._reader.read(index, self.rank) + return self._reader.read(index) def done(self): - self._writer.done(self.rank) + self._writer.done() def __len__(self): return self._reader.get_length() @@ -87,22 +78,6 @@ def _next_data(self): raise StopIteration() -class CacheSampler(Sampler): - def __init__(self, dataset, generator, shuffle): - super().__init__(dataset) - - if shuffle: - self._sampler = RandomSampler(dataset, generator=generator) # type: ignore[arg-type] - else: - self._sampler = SequentialSampler(dataset) # type: ignore[arg-type] - - def __iter__(self): - return iter(self._sampler) - - def __len__(self) -> int: - return len(self._sampler) - - class IteratorSampler(Sampler[int]): r"""Samples elements sequentially, always in the same order. 
@@ -122,17 +97,90 @@ def __len__(self) -> int: return len(self.data_source) +class CacheSampler(Sampler): + def __init__(self, dataset_size: int, num_workers: int, batch_size: int): + super().__init__(None) + self.batch_size = batch_size + self.num_workers = num_workers + self.indices = range(dataset_size) + worker_size = dataset_size // self.num_workers + self.samplers = [] + for worker_idx in range(num_workers): + is_last = worker_idx == num_workers - 1 + worker_indices = self.indices[ + worker_idx * worker_size : dataset_size if is_last else (worker_idx + 1) * worker_size + ] + self.samplers.append(IteratorSampler(worker_indices)) + self.iterators = [] + self._done = set() + assert sum([len(s) for s in self.samplers]) == dataset_size + self.worker_id = 0 + self.indice_id = 0 + + @property + def done(self) -> bool: + return len(self._done) == len(self.iterators) + + def __iter__(self): + self._done = set() + + for sampler in self.samplers: + self.iterators.append(iter(sampler)) + + return self + + def __next__(self): + while len(self._done) != self.iterators: + try: + data = next(self.iterators[self.worker_id]) + self.indice_id += 1 + if self.indice_id == self.batch_size: + self.indice_id = 0 + self.worker_id = (self.worker_id + 1) % self.num_workers + return data + except StopIteration: + self._done.add(self.worker_id) + self.indice_id = 0 + self.worker_id = (self.worker_id + 1) % self.num_workers + raise StopIteration + + class CacheBatchSampler(BatchSampler): def __init__( - self, sampler: Union[Sampler[int], Iterable[int]], batch_size: int, drop_last: bool, shuffle: bool, cache: Cache + self, dataset_size: int, num_workers: int, batch_size: int, drop_last: bool, shuffle: bool, cache: Cache ): + if num_workers >= 1: + sampler = CacheSampler(dataset_size, num_workers, batch_size) + else: + sampler = SequentialSampler(range(dataset_size)) super().__init__(sampler, batch_size, drop_last) self._cache = cache self._shuffle = shuffle + self._num_workers = num_workers + + def __modified_iter__(self) -> Iterator[List[int]]: + # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 + iterator = iter(self.sampler) + batch = [] + while not self.sampler.done: + try: + idx = next(iterator) + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + except StopIteration: + if self.sampler.done: + yield batch + return + yield batch + batch = [] def __iter__(self): if self._cache.filled and self._shuffle: return self.__iter__cache__() + if self._num_workers >= 1: + return self.__modified_iter__() return super().__iter__() def __iter__cache__(self): @@ -170,23 +218,14 @@ def __call__(self, items): class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter): - def _next_index(self): - try: - return super()._next_index() - except StopIteration: - for worker_queue_idx in range(self._num_workers): - self._index_queues[worker_queue_idx].put((worker_queue_idx + self._send_idx, [StopIterationEvent])) - self._task_info[self._send_idx] = (worker_queue_idx,) - raise StopIteration + def __init__(self, loader): + # Patch PyTorch worker loop + from torch.utils.data._utils import worker - def _next_data(self): - try: - return super()._next_data() - except (KeyError, AssertionError, ValueError): - self._shutdown_workers() - return None - except Exception as e: - raise e + from lightning.data.cache.worker import _worker_loop + + worker._worker_loop = _worker_loop + super().__init__(loader) class CacheDataLoader(DataLoader): @@ -196,10 +235,10 @@ 
def __init__( *args, sampler=None, batch_sampler=None, - num_workers=1, + num_workers=0, shuffle: bool = False, generator=None, - batch_size=None, + batch_size=1, drop_last=False, **kwargs, ): @@ -218,14 +257,14 @@ def __init__( raise Exception(f"The CacheDataloader should be used with a dataset using a single cache. Found {cache}.") cache = cache[0] - batch_sampler = CacheBatchSampler( - CacheSampler(dataset, generator, shuffle), batch_size, drop_last, shuffle, cache - ) + if not cache.filled and shuffle: + logger.info("Shuffle is ignored during caching phase") + super().__init__( dataset, *args, sampler=None, - batch_sampler=batch_sampler, + batch_sampler=CacheBatchSampler(len(dataset), num_workers, batch_size, drop_last, shuffle, cache), generator=generator, collate_fn=CacheCollateFn(), num_workers=num_workers, diff --git a/src/lightning/data/cache/env.py b/src/lightning/data/cache/env.py new file mode 100644 index 0000000000000..95c203b879a57 --- /dev/null +++ b/src/lightning/data/cache/env.py @@ -0,0 +1,49 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lightning.data.cache.worker import get_worker_info + + +class _WorkerEnv: + """Contains the environment for the current dataloader within the current training process. + + Args: + world_size: The number of dataloader workers for the current training process + rank: The rank of the current worker within the number of workers + + """ + + def __init__(self, world_size: int, rank: int): + self.world_size = world_size + self.rank = rank + + @classmethod + def detect(cls) -> "_WorkerEnv": + """Automatically detects the number of workers and the current rank. + + Note: + This only works reliably within a dataloader worker as otherwise the necessary information won't be present. 
+ In such a case it will default to 1 worker + + """ + worker_info = get_worker_info() + num_workers = worker_info.num_workers if worker_info is not None else 1 + current_worker_rank = worker_info.id if worker_info is not None else 0 + + return cls(world_size=num_workers, rank=current_worker_rank) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(world_size: {self.world_size}, rank: {self.rank})" + + def __str__(self) -> str: + return repr(self) diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 60747d4f12520..6a25e7ce4ec8d 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -17,7 +17,9 @@ import numpy as np +from lightning.data.cache.env import _WorkerEnv from lightning.data.cache.serializers import _SERIALIZERS +from lightning.data.datasets.env import _DistributedEnv class BinaryReader: @@ -30,6 +32,18 @@ def __init__(self, _cache_dir: str, compression: Optional[str] = None): self._chunks_data = {} self._serializers = _SERIALIZERS + self._env = _DistributedEnv.detect() + self._worker_env = None + self._rank = None + + @property + def rank(self): + if self._rank is None: + self._worker_env = _WorkerEnv.detect() + self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank + + return self._rank + def _try_read_index(self): files = os.listdir(self._cache_dir) indexes_filepath = sorted([os.path.join(self._cache_dir, f) for f in files if f.endswith("index.json")]) @@ -49,7 +63,8 @@ def _try_read_index(self): self._chunks_data[chunk["filename"]] = chunk self._intervals = [] - cumsum_samples = np.cumsum([0] + [v["samples"] for v in self._index["chunks"]] + [1]) + num_samples = [v["samples"] for v in self._index["chunks"]] + cumsum_samples = np.cumsum([0] + num_samples) for i in range(len(cumsum_samples) - 1): self._intervals.append([cumsum_samples[i], cumsum_samples[i + 1]]) @@ -62,7 +77,7 @@ def _map_index_to_chunk_id(self, index): def _should_keep_in_memory(self): return True - def read(self, index: int, rank: int = 0): + def read(self, index: int): if self._index is None: self._try_read_index() diff --git a/src/lightning/data/cache/serializers.py b/src/lightning/data/cache/serializers.py index e940e36a9d1fb..2d74baf1b5a57 100644 --- a/src/lightning/data/cache/serializers.py +++ b/src/lightning/data/cache/serializers.py @@ -77,10 +77,14 @@ class JPEGSerializer(Serializer): """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" def serialize(self, obj: Image) -> bytes: - if isinstance(obj, JpegImageFile) and hasattr(obj, "filename"): + if isinstance(obj, JpegImageFile): + if not hasattr(obj, "filename"): + raise ValueError( + "The JPEG Image's filename isn't defined. HINT: Open the image in your Dataset __getitem__ method." + ) with open(obj.filename, "rb") as f: return f.read() - raise TypeError(f"The provided object should be of type {JpegImageFile}") + raise TypeError(f"The provided object should be of type {JpegImageFile}. Found {obj}.") def deserialize(self, data: bytes) -> Image: inp = BytesIO(data) diff --git a/src/lightning/data/cache/worker.py b/src/lightning/data/cache/worker.py new file mode 100644 index 0000000000000..04d7174cf3fcd --- /dev/null +++ b/src/lightning/data/cache/worker.py @@ -0,0 +1,366 @@ +r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. + +These **needs** to be in global scope since Py2 doesn't support serializing static methods. 
+ +""" + +# Taken from PyTorch + +import os +import queue +import random +from dataclasses import dataclass +from typing import Optional, TYPE_CHECKING, Union + +import torch +from torch._utils import ExceptionWrapper +from torch.utils.data._utils import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling + +if TYPE_CHECKING: + from torch.utils.data import Dataset + +if IS_WINDOWS: + import ctypes + from ctypes.wintypes import BOOL, DWORD, HANDLE + + # On Windows, the parent ID of the worker process remains unchanged when the manager process + # is gone, and the only way to check it through OS is to let the worker have a process handle + # of the manager and ask if the process status has changed. + class ManagerWatchdog: + def __init__(self): + self.manager_pid = os.getppid() + + # mypy cannot detect this code is windows only + self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined] + self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) + self.kernel32.OpenProcess.restype = HANDLE + self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) + self.kernel32.WaitForSingleObject.restype = DWORD + + # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx + SYNCHRONIZE = 0x00100000 + self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid) + + if not self.manager_handle: + raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined] + + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx + self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 + return not self.manager_dead + +else: + + class ManagerWatchdog: # type: ignore[no-redef] + def __init__(self): + self.manager_pid = os.getppid() + self.manager_dead = False + + def is_alive(self): + if not self.manager_dead: + self.manager_dead = os.getppid() != self.manager_pid + return not self.manager_dead + + +_worker_info = None + + +class WorkerInfo: + id: int + num_workers: int + seed: int + dataset: "Dataset" + __initialized = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + self.__keys = tuple(kwargs.keys()) + self.__initialized = True + + def __setattr__(self, key, val): + if self.__initialized: + raise RuntimeError(f"Cannot assign attributes to {self.__class__.__name__} objects") + return super().__setattr__(key, val) + + def __repr__(self): + items = [] + for k in self.__keys: + items.append(f"{k}={getattr(self, k)}") + return "{}({})".format(self.__class__.__name__, ", ".join(items)) + + +def get_worker_info() -> Optional[WorkerInfo]: + r"""Returns the information about the current + :class:`~torch.utils.data.DataLoader` iterator worker process. + + When called in a worker, this returns an object guaranteed to have the + following attributes: + + * :attr:`id`: the current worker id. + * :attr:`num_workers`: the total number of workers. + * :attr:`seed`: the random seed set for the current worker. This value is + determined by main process RNG and the worker id. See + :class:`~torch.utils.data.DataLoader`'s documentation for more details. + * :attr:`dataset`: the copy of the dataset object in **this** process. Note + that this will be a different object in a different process than the one + in the main process. + + When called in the main process, this returns ``None``. + + .. 
note:: + When used in a :attr:`worker_init_fn` passed over to + :class:`~torch.utils.data.DataLoader`, this method can be useful to + set up each worker process differently, for instance, using ``worker_id`` + to configure the ``dataset`` object to only read a specific fraction of a + sharded dataset, or use ``seed`` to seed other libraries used in dataset + code. + """ + return _worker_info + + +r"""Dummy class used to signal the end of an IterableDataset""" + + +@dataclass(frozen=True) +class _IterableDatasetStopIteration: + worker_id: int + + +r"""Dummy class used to resume the fetching when worker reuse is enabled""" + + +@dataclass(frozen=True) +class _ResumeIteration: + seed: Optional[int] = None + + +# The function `_generate_state` is adapted from `numpy.random.SeedSequence` +# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx +# It's MIT licensed, here is the copyright: + +# Copyright (c) 2015 Melissa E. O'Neill +# Copyright (c) 2019 NumPy Developers +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +# This function generates an array of int32 as the seed for +# `numpy.random`, in order to prevent state collision due to same +# seed and algorithm for `numpy.random` and `random` modules. +# TODO: Implement `SeedSequence` like object for `torch.random` +def _generate_state(base_seed, worker_id): + INIT_A = 0x43B0D7E5 + MULT_A = 0x931E8875 + INIT_B = 0x8B51F9DD + MULT_B = 0x58F38DED + MIX_MULT_L = 0xCA01F9DD + MIX_MULT_R = 0x4973F715 + XSHIFT = 4 * 8 // 2 + MASK32 = 0xFFFFFFFF + + entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0] + pool = [0] * 4 + + hash_const_A = INIT_A + + def hash(value): + nonlocal hash_const_A + value = (value ^ hash_const_A) & MASK32 + hash_const_A = (hash_const_A * MULT_A) & MASK32 + value = (value * hash_const_A) & MASK32 + value = (value ^ (value >> XSHIFT)) & MASK32 + return value + + def mix(x, y): + result_x = (MIX_MULT_L * x) & MASK32 + result_y = (MIX_MULT_R * y) & MASK32 + result = (result_x - result_y) & MASK32 + result = (result ^ (result >> XSHIFT)) & MASK32 + return result + + # Add in the entropy to the pool. + for i in range(len(pool)): + pool[i] = hash(entropy[i]) + + # Mix all bits together so late bits can affect earlier bits. 
+ for i_src in range(len(pool)): + for i_dst in range(len(pool)): + if i_src != i_dst: + pool[i_dst] = mix(pool[i_dst], hash(pool[i_src])) + + hash_const_B = INIT_B + state = [] + for i_dst in range(4): + data_val = pool[i_dst] + data_val = (data_val ^ hash_const_B) & MASK32 + hash_const_B = (hash_const_B * MULT_B) & MASK32 + data_val = (data_val * hash_const_B) & MASK32 + data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32 + state.append(data_val) + return state + + +def _worker_loop( + dataset_kind, + dataset, + index_queue, + data_queue, + done_event, + auto_collation, + collate_fn, + drop_last, + base_seed, + init_fn, + worker_id, + num_workers, + persistent_workers, + shared_seed, +): + # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the + # logic of this function. + + try: + # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal + # module's handlers are executed after Python returns from C low-level + # handlers, likely when the same fatal signal had already happened + # again. + # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers + signal_handling._set_worker_signal_handlers() + + torch.set_num_threads(1) + seed = base_seed + worker_id + random.seed(seed) + torch.manual_seed(seed) + if HAS_NUMPY: + np_seed = _generate_state(base_seed, worker_id) + import numpy as np + + np.random.seed(np_seed) + + from torch.utils.data import IterDataPipe + from torch.utils.data.graph_settings import apply_random_seed + + shared_rng = torch.Generator() + if isinstance(dataset, IterDataPipe): + assert shared_seed is not None + shared_rng.manual_seed(shared_seed) + dataset = apply_random_seed(dataset, shared_rng) + + global _worker_info + _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset) + + from torch.utils.data import _DatasetKind + + init_exception = None + + try: + if init_fn is not None: + init_fn(worker_id) + + fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last) + except Exception: + init_exception = ExceptionWrapper(where=f"in DataLoader worker process {worker_id}") + + # When using Iterable mode, some worker can exit earlier than others due + # to the IterableDataset behaving differently for different workers. + # When such things happen, an `_IterableDatasetStopIteration` object is + # sent over to the main process with the ID of this worker, so that the + # main process won't send more tasks to this worker, and will send + # `None` to this worker to properly exit it. + # + # Note that we cannot set `done_event` from a worker as it is shared + # among all processes. Instead, we set the `iteration_end` flag to + # signify that the iterator is exhausted. When either `done_event` or + # `iteration_end` is set, we skip all processing step and just wait for + # `None`. 
+ iteration_end = False + + watchdog = ManagerWatchdog() + + while watchdog.is_alive(): + try: + r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) + except queue.Empty: + continue + if isinstance(r, _ResumeIteration): + # Acknowledge the main process + data_queue.put((r, None)) + iteration_end = False + + if isinstance(dataset, IterDataPipe): + assert r.seed is not None + shared_rng.manual_seed(r.seed) + dataset = apply_random_seed(dataset, shared_rng) + + # Recreate the fetcher for worker-reuse policy + fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last) + continue + elif r is None: + # Received the final signal + assert done_event.is_set() or iteration_end + break + elif done_event.is_set() or iteration_end: + # `done_event` is set. But I haven't received the final signal + # (None) yet. I will keep continuing until get it, and skip the + # processing steps. + continue + idx, index = r + data: Union[_IterableDatasetStopIteration, ExceptionWrapper] + if init_exception is not None: + data = init_exception + init_exception = None + else: + try: + data = fetcher.fetch(index) + except Exception as e: + if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable: + data = _IterableDatasetStopIteration(worker_id) + # Set `iteration_end` + # (1) to save future `next(...)` calls, and + # (2) to avoid sending multiple `_IterableDatasetStopIteration`s. + iteration_end = True + else: + # It is important that we don't store exc_info in a variable. + # `ExceptionWrapper` does the correct thing. + # See NOTE [ Python Traceback Reference Cycle Problem ] + data = ExceptionWrapper(where=f"in DataLoader worker process {worker_id}") + data_queue.put((idx, data)) + del data, idx, index, r # save memory + except KeyboardInterrupt: + # Main process will raise KeyboardInterrupt anyways. 
+ pass + if done_event.is_set(): + ####### ADDTIONATIONAL CODE ####### + + from lightning.data.cache import Cache + + # required to ensure the cache is persisted + if dataset_kind == _DatasetKind.Map: + for v in fetcher.dataset.__dict__.values(): + if isinstance(v, Cache): + v.done() + + ####### ADDTIONATIONAL CODE ####### + + data_queue.cancel_join_thread() + data_queue.close() diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index e38b76c89f300..13763890edb2b 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -18,7 +18,9 @@ import numpy as np from lightning.data.cache.compression import _COMPRESSORS +from lightning.data.cache.env import _WorkerEnv from lightning.data.cache.serializers import _SERIALIZERS +from lightning.data.datasets.env import _DistributedEnv class BinaryWriter: @@ -78,6 +80,17 @@ def __init__( text = json.dumps(obj, sort_keys=True) self._config_data = text.encode("utf-8") + self._env = _DistributedEnv.detect() + self._worker_env = None + self._rank = None + + @property + def rank(self): + if self._rank is None: + self._worker_env = _WorkerEnv.detect() + self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank + return self._rank + def get_config(self) -> Dict[str, Any]: out = super().get_config() out.update(self._data_format) @@ -101,7 +114,6 @@ def serialize(self, items: Dict[str, Any]) -> bytes: serializer_name = self._data_format[key] serializer = self._serializers[serializer_name] serialized_data = serializer.serialize(items[key]) if not isinstance(items[key], bytes) else items[key] - sizes.append(len(serialized_data)) data.append(serialized_data) @@ -134,11 +146,11 @@ def _create_chunk(self, filename: str) -> bytes: return data - def write_chunk(self, rank: int): + def write_chunk(self): if self._compression: - filename = f"chunk-{rank}-{self._chunk_id}.{self._compression}.bin" + filename = f"chunk-{self.rank}-{self._chunk_id}.{self._compression}.bin" else: - filename = f"chunk-{rank}-{self._chunk_id}.bin" + filename = f"chunk-{self.rank}-{self._chunk_id}.bin" self.write_file(self._create_chunk(filename), filename) @property @@ -158,7 +170,7 @@ def reset(self) -> None: self._indexes = [] self._current_chunk_size = 0 - def __setitem__(self, index, items: any, rank=0): + def __setitem__(self, index, items: any): serialized_items = self.serialize(items) serialized_items_size = len(serialized_items) @@ -168,16 +180,16 @@ def __setitem__(self, index, items: any, rank=0): f"The provided chunk_size {self._chunk_size} is too small." f" You should use a multiple of {serialized_items_size} bytes." 
) - self.write_chunk(rank) + self.write_chunk() self.reset() self._chunk_id += 1 self._serialized_items.append(serialized_items) self._current_chunk_size += serialized_items_size - # The dataset should be indexed in a non-sorted manner if self._indexes: - assert self._indexes[-1] + 1 == index + assert self._indexes[-1] == index - 1 + self._indexes.append(index) def write_file( @@ -191,13 +203,13 @@ def write_file( with open(filepath, "wb") as out: out.write(raw_data) - def write_chunks_index(self, rank: int): - filepath = os.path.join(self._cache_dir, f"{rank}.index.json") + def write_chunks_index(self): + filepath = os.path.join(self._cache_dir, f"{self.rank}.index.json") with open(filepath, "w") as out: json.dump({"chunks": self._chunks_info}, out, sort_keys=True) - def done(self, rank: int = 0): + def done(self): if self._serialized_items: - self.write_chunk(rank) - self.write_chunks_index(rank) + self.write_chunk() + self.write_chunks_index() self.reset() diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py new file mode 100644 index 0000000000000..35fd138f61a3b --- /dev/null +++ b/tests/tests_data/cache/test_cache.py @@ -0,0 +1,62 @@ +import os + +import numpy as np +import pytest +from lightning_utilities.core.imports import RequirementCache +from torch.utils.data import Dataset + +from lightning import seed_everything +from lightning.data.cache import Cache, CacheDataLoader + +_PIL_AVAILABLE = RequirementCache("PIL") + + +class ImageDataset(Dataset): + def __init__(self, tmpdir, cache, size, num_classes): + from PIL import Image + + self.data = [] + self.cache = cache + + seed_everything(42) + + for i in range(size): + path = os.path.join(tmpdir, f"img{i}.jpeg") + np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8) + img = Image.fromarray(np_data).convert("L") + img.save(path, format="jpeg", quality=100) + # read bytes from the file + with open(path, "rb") as f: + data = f.read() + self.data.append({"image": data, "class": np.random.randint(num_classes)}) + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + if self.cache.filled: + return self.cache[index] + self.cache[index] = self.data[index] + return None + + +@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +@pytest.mark.parametrize("num_workers", [0, 1, 2]) +def test_cache_for_image_dataset(num_workers, tmpdir): + import io + + from PIL import Image + + cache_dir = os.path.join(tmpdir, "cache") + os.makedirs(cache_dir) + cache = Cache(cache_dir, data_format={"image": "jpeg", "class": "int"}, chunk_size=2 << 12) + dataset = ImageDataset(tmpdir, cache, 85, 10) + for _ in CacheDataLoader(dataset, num_workers=num_workers, batch_size=4): + pass + + for i in range(len(dataset)): + cached_data = dataset[i] + original_data = dataset.data[i] + assert cached_data["class"] == original_data["class"] + original_image = Image.open(io.BytesIO(original_data["image"])) + assert cached_data["image"] == original_image From c2ee47c693bf647cff56b58d0c6ab8a1e165e649 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Sep 2023 15:31:28 +0000 Subject: [PATCH 14/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/cache/worker.py | 2 +- tests/tests_data/cache/test_cache.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lightning/data/cache/worker.py 
b/src/lightning/data/cache/worker.py index 04d7174cf3fcd..92396157d8663 100644 --- a/src/lightning/data/cache/worker.py +++ b/src/lightning/data/cache/worker.py @@ -10,7 +10,7 @@ import queue import random from dataclasses import dataclass -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Optional, Union import torch from torch._utils import ExceptionWrapper diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index 35fd138f61a3b..15b392d3209a9 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -2,11 +2,10 @@ import numpy as np import pytest -from lightning_utilities.core.imports import RequirementCache -from torch.utils.data import Dataset - from lightning import seed_everything from lightning.data.cache import Cache, CacheDataLoader +from lightning_utilities.core.imports import RequirementCache +from torch.utils.data import Dataset _PIL_AVAILABLE = RequirementCache("PIL") From d0708e052e7b2a5fad91ac41e80cf1a62bd18c2a Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 28 Sep 2023 19:38:35 +0100 Subject: [PATCH 15/84] update --- src/lightning/data/cache/cache.py | 35 ++++++++++--------- src/lightning/data/cache/writer.py | 37 +++++++++++++++++++-- tests/tests_data/cache/test_cache.py | 48 +++++++++++++++++++++++---- tests/tests_data/cache/test_writer.py | 21 +++++++++++- 4 files changed, 116 insertions(+), 25 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index e812c23ff6164..29b2e1d6f8eb6 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -19,12 +19,12 @@ from torch.utils.data import IterableDataset from torch.utils.data._utils.collate import default_collate from torch.utils.data.dataloader import ( - DataLoader, _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter, + DataLoader, ) -from torch.utils.data.sampler import BatchSampler, Sampler, SequentialSampler, Sized +from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter @@ -103,6 +103,7 @@ def __init__(self, dataset_size: int, num_workers: int, batch_size: int): self.batch_size = batch_size self.num_workers = num_workers self.indices = range(dataset_size) + self.dataset_size = dataset_size worker_size = dataset_size // self.num_workers self.samplers = [] for worker_idx in range(num_workers): @@ -117,6 +118,9 @@ def __init__(self, dataset_size: int, num_workers: int, batch_size: int): self.worker_id = 0 self.indice_id = 0 + def __len__(self) -> int: + return self.dataset_size + @property def done(self) -> bool: return len(self._done) == len(self.iterators) @@ -149,8 +153,10 @@ class CacheBatchSampler(BatchSampler): def __init__( self, dataset_size: int, num_workers: int, batch_size: int, drop_last: bool, shuffle: bool, cache: Cache ): - if num_workers >= 1: + if not cache.filled and num_workers > 1: sampler = CacheSampler(dataset_size, num_workers, batch_size) + elif shuffle: + sampler = RandomSampler(range(dataset_size)) else: sampler = SequentialSampler(range(dataset_size)) super().__init__(sampler, batch_size, drop_last) @@ -158,7 +164,7 @@ def __init__( self._shuffle = shuffle self._num_workers = num_workers - def __modified_iter__(self) -> Iterator[List[int]]: + def __iter_ordered__(self) -> Iterator[List[int]]: # Implemented based on the benchmarking in 
https://github.com/pytorch/pytorch/pull/76951 iterator = iter(self.sampler) batch = [] @@ -178,25 +184,25 @@ def __modified_iter__(self) -> Iterator[List[int]]: def __iter__(self): if self._cache.filled and self._shuffle: - return self.__iter__cache__() - if self._num_workers >= 1: - return self.__modified_iter__() + return self.__iter_from_chunks__() + if self._num_workers > 1 and not self._cache.filled: + return self.__iter_ordered__() return super().__iter__() - def __iter__cache__(self): - chunk_intervals = self._cache.get_chunk_interval()[:-1] + def __iter_from_chunks__(self): + chunk_intervals = self._cache.get_chunk_interval() shuffled_chunk_intervals = np.random.permutation(chunk_intervals) - dataset = [] + indices = [] for interval in shuffled_chunk_intervals: interval_indices = np.arange(interval[0], interval[1]) shuffled_interval_indices = np.random.permutation(interval_indices) - dataset.extend(shuffled_interval_indices.tolist()) + indices.extend(shuffled_interval_indices.tolist()) - if len(dataset) != len(self.sampler): + if len(indices) != len(self.sampler): raise Exception("The generated indices don't match the initial length of the sampler.") - self.sampler = IteratorSampler(dataset) + self.sampler = IteratorSampler(indices) return super().__iter__() @@ -214,9 +220,6 @@ def __call__(self, items): return self.collate_fn(items) -StopIterationEvent = "StopIterationEvent" - - class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter): def __init__(self, loader): # Patch PyTorch worker loop diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 13763890edb2b..c06636d031c82 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -23,6 +23,16 @@ from lightning.data.datasets.env import _DistributedEnv +def cloud_path(cache_dir: str) -> Optional[str]: + cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) + project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) + cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) + + if cluster_id is None or project_id is None or cloud_space_id is None: + return None + return f"s3://{cluster_id}/projects/{project_id}/cloudspaces/{cloud_space_id}/content/{cache_dir}/" + + class BinaryWriter: def __init__( self, @@ -83,6 +93,7 @@ def __init__( self._env = _DistributedEnv.detect() self._worker_env = None self._rank = None + self._is_done = False @property def rank(self): @@ -94,6 +105,10 @@ def rank(self): def get_config(self) -> Dict[str, Any]: out = super().get_config() out.update(self._data_format) + + cloud_path = self.get_cloud_path(self._cache_dir) + if cloud_path: + out["cloud_path"] = cloud_path return out def serialize(self, items: Dict[str, Any]) -> bytes: @@ -209,7 +224,25 @@ def write_chunks_index(self): json.dump({"chunks": self._chunks_info}, out, sort_keys=True) def done(self): + if self._is_done: + return if self._serialized_items: self.write_chunk() - self.write_chunks_index() - self.reset() + self.write_chunks_index() + self.reset() + self._is_done = True + + @classmethod + def get_cloud_path(cls, cache_dir: str) -> Optional[str]: + cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) + project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) + cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) + + if cluster_id is None or project_id is None or cloud_space_id is None: + return None + cache_dir = cache_dir.replace("~/", "").replace("~", "").replace("/teamspace/studios/this_studio/", "") + if cache_dir.startswith("/"): + 
cache_dir = cache_dir[1:] + return os.path.join( + f"s3://{cluster_id}/projects/{project_id}/cloudspaces/{cloud_space_id}/code/content", cache_dir + ) diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index 35fd138f61a3b..c56dd98fd38b0 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -9,11 +9,13 @@ from lightning.data.cache import Cache, CacheDataLoader _PIL_AVAILABLE = RequirementCache("PIL") +_TORCH_VISION_AVAILABLE = RequirementCache("torchvision") class ImageDataset(Dataset): - def __init__(self, tmpdir, cache, size, num_classes): + def __init__(self, tmpdir, cache, size, num_classes, use_transform: bool = False): from PIL import Image + from torchvision import transforms as T self.data = [] self.cache = cache @@ -30,27 +32,37 @@ def __init__(self, tmpdir, cache, size, num_classes): data = f.read() self.data.append({"image": data, "class": np.random.randint(num_classes)}) + self.use_transform = use_transform + self.transform = T.Compose([T.ToTensor()]) + def __len__(self): return len(self.data) def __getitem__(self, index): if self.cache.filled: - return self.cache[index] - self.cache[index] = self.data[index] + data = self.cache[index] + if self.use_transform: + data["image"] = self.transform(data["image"]).unsqueeze(0) + return data + self.cache[index] = {**self.data[index], "index": index} return None -@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +@pytest.mark.skipif( + condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']" +) @pytest.mark.parametrize("num_workers", [0, 1, 2]) def test_cache_for_image_dataset(num_workers, tmpdir): import io from PIL import Image + dataset_size = 85 + cache_dir = os.path.join(tmpdir, "cache") os.makedirs(cache_dir) - cache = Cache(cache_dir, data_format={"image": "jpeg", "class": "int"}, chunk_size=2 << 12) - dataset = ImageDataset(tmpdir, cache, 85, 10) + cache = Cache(cache_dir, data_format={"image": "jpeg", "class": "int", "index": "int"}, chunk_size=2 << 12) + dataset = ImageDataset(tmpdir, cache, dataset_size, 10) for _ in CacheDataLoader(dataset, num_workers=num_workers, batch_size=4): pass @@ -60,3 +72,27 @@ def test_cache_for_image_dataset(num_workers, tmpdir): assert cached_data["class"] == original_data["class"] original_image = Image.open(io.BytesIO(original_data["image"])) assert cached_data["image"] == original_image + + dataset.use_transform = True + + indexes = [] + for batch in CacheDataLoader(dataset, num_workers=num_workers, batch_size=4): + indexes.extend(batch["index"].numpy().tolist()) + + assert indexes == list(range(dataset_size)) + + seed_everything(42) + + dataloader = CacheDataLoader(dataset, num_workers=num_workers, batch_size=4, shuffle=True) + + indexes = [] + for batch in dataloader: + indexes.extend(batch["index"].numpy().tolist()) + + assert len(indexes) == dataset_size + + indexes2 = [] + for batch in dataloader: + indexes2.extend(batch["index"].numpy().tolist()) + + assert indexes2 != indexes diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index a2ea6261a85bc..3aa12093784cd 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -3,9 +3,10 @@ import numpy as np import pytest +from lightning_utilities.core.imports import RequirementCache + from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter -from 
lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") @@ -81,3 +82,21 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): data = reader.read(i) assert data["x"] == imgs[i] assert data["y"] == i + + +def test_binary_writer_config(monkeypatch): + assert BinaryWriter.get_cloud_path("") is None + + monkeypatch.setenv("LIGHTNING_CLUSTER_ID", "cluster_id") + monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id") + monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "cloud_space_id") + + prefix = "s3://cluster_id/projects/project_id/cloudspaces/cloud_space_id/code/content/" + + assert BinaryWriter.get_cloud_path("") == prefix + assert BinaryWriter.get_cloud_path("~") == prefix + assert BinaryWriter.get_cloud_path("~/") == prefix + assert BinaryWriter.get_cloud_path("/") == prefix + assert BinaryWriter.get_cloud_path("/data") == f"{prefix}/data" + assert BinaryWriter.get_cloud_path("~/data") == f"{prefix}/data" + assert BinaryWriter.get_cloud_path("/teamspace/studios/this_studio/data") == f"{prefix}/data" From 6f6ce5fd6ba69a8ab48a8ff137379ca19d9538a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Sep 2023 18:40:30 +0000 Subject: [PATCH 16/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/cache/cache.py | 2 +- tests/tests_data/cache/test_writer.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 29b2e1d6f8eb6..576e04dab0e2e 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -19,10 +19,10 @@ from torch.utils.data import IterableDataset from torch.utils.data._utils.collate import default_collate from torch.utils.data.dataloader import ( + DataLoader, _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter, - DataLoader, ) from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index 3aa12093784cd..cfd5a993569dc 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -3,10 +3,9 @@ import numpy as np import pytest -from lightning_utilities.core.imports import RequirementCache - from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter +from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") From 23bc5c494e5de90b495414f25a543ab9a1ea5a61 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 28 Sep 2023 19:45:07 +0100 Subject: [PATCH 17/84] update --- src/lightning/data/cache/writer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index c06636d031c82..7f647ebc62b70 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -105,10 +105,12 @@ def rank(self): def get_config(self) -> Dict[str, Any]: out = super().get_config() out.update(self._data_format) - cloud_path = self.get_cloud_path(self._cache_dir) if cloud_path: out["cloud_path"] = cloud_path + user_id = os.getenv("LIGHTNING_USER_ID", None) + if user_id: + out["user_id"] = user_id return out def serialize(self, items: Dict[str, Any]) -> bytes: From f79a29249965ab1b8a2b16ae9b48b82403219f48 Mon Sep 17 
00:00:00 2001 From: thomas Date: Thu, 28 Sep 2023 20:43:22 +0100 Subject: [PATCH 18/84] update --- src/lightning/data/cache/cache.py | 122 ++++++++++++++++++++++++--- src/lightning/data/cache/reader.py | 5 +- tests/tests_data/cache/test_cache.py | 56 +++++++++--- 3 files changed, 158 insertions(+), 25 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 576e04dab0e2e..1956d5a9f8087 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -19,15 +19,17 @@ from torch.utils.data import IterableDataset from torch.utils.data._utils.collate import default_collate from torch.utils.data.dataloader import ( - DataLoader, _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter, + DataLoader, ) +from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter +from lightning.data.datasets.env import _DistributedEnv logger = logging.Logger(__name__) @@ -44,12 +46,16 @@ def __init__( self._writer = BinaryWriter(cache_dir, data_format, chunk_size=chunk_size, compression=compression) self._reader = BinaryReader(cache_dir, compression=compression) self._cache_dir = cache_dir + self._is_done = False # TODO: Find a way to make this faster @property def filled(self) -> bool: + if self._is_done: + return True files = os.listdir(self._cache_dir) - return any(f.endswith("index.json") for f in files) + self._is_done = any(f.endswith("index.json") for f in files) + return self._is_done def __setitem__(self, index, data): self._writer[index] = data @@ -149,17 +155,99 @@ def __next__(self): raise StopIteration +class DistributedCacheSampler(Sampler): + def __init__(self, dataset_size: int, num_replicas: int, rank: int, num_workers: int, batch_size: int): + super().__init__(None) + self.batch_size = batch_size + self.num_workers = num_workers + self.indices = range(dataset_size) + self.dataset_size = dataset_size + replica_size = dataset_size // num_replicas + worker_size = dataset_size // (num_replicas * self.num_workers) + self.samplers = [] + for replica_idx in range(num_replicas): + if replica_idx != rank: + continue + + is_last_replica = replica_idx == num_replicas - 1 + start_replica = replica_idx * replica_size + end_replica = dataset_size if is_last_replica else (replica_idx + 1) * replica_size + replica_indices = self.indices[start_replica:end_replica] + + replica_size = len(replica_indices) + + for worker_idx in range(num_workers): + is_last_worker = worker_idx == num_workers - 1 + start_worker = worker_idx * worker_size + end_worker = replica_size if is_last_worker else (worker_idx + 1) * worker_size + worker_indices = replica_indices[start_worker:end_worker] + self.samplers.append(IteratorSampler(worker_indices)) + + self.iterators = [] + self._done = set() + + assert sum([len(s) for s in self.samplers]) == replica_size + self.worker_id = 0 + self.indice_id = 0 + + def __len__(self) -> int: + return self.dataset_size + + @property + def done(self) -> bool: + return len(self._done) == len(self.iterators) + + def __iter__(self): + self._done = set() + + for sampler in self.samplers: + self.iterators.append(iter(sampler)) + + return self + + def __next__(self): + while len(self._done) != self.iterators: + try: + data = next(self.iterators[self.worker_id]) + self.indice_id += 1 + if self.indice_id == 
self.batch_size: + self.indice_id = 0 + self.worker_id = (self.worker_id + 1) % self.num_workers + return data + except StopIteration: + self._done.add(self.worker_id) + self.indice_id = 0 + self.worker_id = (self.worker_id + 1) % self.num_workers + raise StopIteration + + class CacheBatchSampler(BatchSampler): def __init__( - self, dataset_size: int, num_workers: int, batch_size: int, drop_last: bool, shuffle: bool, cache: Cache + self, + dataset_size: int, + num_replicas: int, + rank: int, + num_workers: int, + batch_size: int, + drop_last: bool, + shuffle: bool, + cache: Cache, ): - if not cache.filled and num_workers > 1: - sampler = CacheSampler(dataset_size, num_workers, batch_size) - elif shuffle: - sampler = RandomSampler(range(dataset_size)) + if num_replicas == 1: + if not cache.filled and num_workers > 1: + sampler = CacheSampler(dataset_size, num_workers, batch_size) + elif shuffle: + sampler = RandomSampler(range(dataset_size)) + else: + sampler = SequentialSampler(range(dataset_size)) else: - sampler = SequentialSampler(range(dataset_size)) + if not cache.filled: + sampler = DistributedCacheSampler(dataset_size, num_replicas, rank, num_workers, batch_size) + else: + sampler = DistributedSampler(range(dataset_size), num_replicas=num_replicas, rank=rank, shuffle=shuffle) super().__init__(sampler, batch_size, drop_last) + self._num_replicas = num_replicas + self._rank = rank self._cache = cache self._shuffle = shuffle self._num_workers = num_workers @@ -246,10 +334,10 @@ def __init__( **kwargs, ): if sampler: - raise Exception("Passing a sampler isn't supoprt with the CacheDataLoader yet.") + raise Exception("Passing a sampler isn't supoprt with the CacheDataLoader.") if batch_sampler: - raise Exception("Passing a batch_sampler isn't supoprt with the CacheDataLoader yet.") + raise Exception("Passing a batch_sampler isn't supoprt with the CacheDataLoader.") if isinstance(dataset, IterableDataset): raise Exception("Only map-based dataset are supported by the CacheDataLoader for now.") @@ -263,11 +351,23 @@ def __init__( if not cache.filled and shuffle: logger.info("Shuffle is ignored during caching phase") + distributed_env = _DistributedEnv.detect() + batch_sampler = CacheBatchSampler( + len(dataset), + distributed_env.world_size, + distributed_env.global_rank, + num_workers, + batch_size, + drop_last, + shuffle, + cache, + ) + super().__init__( dataset, *args, sampler=None, - batch_sampler=CacheBatchSampler(len(dataset), num_workers, batch_size, drop_last, shuffle, cache), + batch_sampler=batch_sampler, generator=generator, collate_fn=CacheCollateFn(), num_workers=num_workers, diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 6a25e7ce4ec8d..a193b71f385f0 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -85,7 +85,10 @@ def read(self, index: int): raise Exception("The reader index isn't defined.") chunk_id = self._map_index_to_chunk_id(index) - chunk_config = self._index["chunks"][chunk_id] + try: + chunk_config = self._index["chunks"][chunk_id] + except Exception as e: + raise Exception(f"Found {str(self._index['chunks'])} {chunk_id}" + str(e)) chunk_path = os.path.join(self._cache_dir, chunk_config["filename"]) raw_item_data, item_config = self.load_item_from_chunk( index, chunk_path, keep_in_memory=self._should_keep_in_memory() diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index 6dc983931873d..f5de28c041703 100644 --- 
a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -1,12 +1,16 @@ import os +from functools import partial import numpy as np import pytest -from lightning import seed_everything -from lightning.data.cache import Cache, CacheDataLoader from lightning_utilities.core.imports import RequirementCache from torch.utils.data import Dataset +from lightning import seed_everything +from lightning.data.cache import Cache, CacheDataLoader +from lightning.data.datasets.env import _DistributedEnv +from lightning.fabric import Fabric + _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") @@ -47,11 +51,7 @@ def __getitem__(self, index): return None -@pytest.mark.skipif( - condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']" -) -@pytest.mark.parametrize("num_workers", [0, 1, 2]) -def test_cache_for_image_dataset(num_workers, tmpdir): +def cache_for_image_dataset(num_workers, tmpdir): import io from PIL import Image @@ -59,10 +59,11 @@ def test_cache_for_image_dataset(num_workers, tmpdir): dataset_size = 85 cache_dir = os.path.join(tmpdir, "cache") - os.makedirs(cache_dir) + cache = Cache(cache_dir, data_format={"image": "jpeg", "class": "int", "index": "int"}, chunk_size=2 << 12) dataset = ImageDataset(tmpdir, cache, dataset_size, 10) - for _ in CacheDataLoader(dataset, num_workers=num_workers, batch_size=4): + dataloader = CacheDataLoader(dataset, num_workers=num_workers, batch_size=4) + for _ in dataloader: pass for i in range(len(dataset)): @@ -72,13 +73,15 @@ def test_cache_for_image_dataset(num_workers, tmpdir): original_image = Image.open(io.BytesIO(original_data["image"])) assert cached_data["image"] == original_image + distributed_env = _DistributedEnv.detect() dataset.use_transform = True - indexes = [] - for batch in CacheDataLoader(dataset, num_workers=num_workers, batch_size=4): - indexes.extend(batch["index"].numpy().tolist()) + if distributed_env.world_size == 1: + indexes = [] + for batch in CacheDataLoader(dataset, num_workers=num_workers, batch_size=4): + indexes.extend(batch["index"].numpy().tolist()) - assert indexes == list(range(dataset_size)) + assert len(indexes) == dataset_size seed_everything(42) @@ -95,3 +98,30 @@ def test_cache_for_image_dataset(num_workers, tmpdir): indexes2.extend(batch["index"].numpy().tolist()) assert indexes2 != indexes + + +@pytest.mark.skipif( + condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']" +) +@pytest.mark.parametrize("num_workers", [0, 1, 2]) +def test_cache_for_image_dataset(num_workers, tmpdir): + cache_dir = os.path.join(tmpdir, "cache") + os.makedirs(cache_dir) + + cache_for_image_dataset(num_workers, tmpdir) + + +def fabric_cache_for_image_dataset(_, num_workers, tmpdir): + cache_for_image_dataset(num_workers, tmpdir) + + +@pytest.mark.skipif( + condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']" +) +@pytest.mark.parametrize("num_workers", [2]) +def test_cache_for_image_dataset_distributed(num_workers, tmpdir): + cache_dir = os.path.join(tmpdir, "cache") + os.makedirs(cache_dir) + + fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn") + fabric.launch(partial(fabric_cache_for_image_dataset, num_workers=num_workers, tmpdir=tmpdir)) From 477003882c3daacdd13d9b7df8aab8e88360802c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Sep 
2023 19:44:45 +0000 Subject: [PATCH 19/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/cache/cache.py | 2 +- tests/tests_data/cache/test_cache.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 1956d5a9f8087..fa9906da141f4 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -19,10 +19,10 @@ from torch.utils.data import IterableDataset from torch.utils.data._utils.collate import default_collate from torch.utils.data.dataloader import ( + DataLoader, _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter, - DataLoader, ) from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index f5de28c041703..60bf8f1998c06 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -3,13 +3,12 @@ import numpy as np import pytest -from lightning_utilities.core.imports import RequirementCache -from torch.utils.data import Dataset - from lightning import seed_everything from lightning.data.cache import Cache, CacheDataLoader from lightning.data.datasets.env import _DistributedEnv from lightning.fabric import Fabric +from lightning_utilities.core.imports import RequirementCache +from torch.utils.data import Dataset _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") From dd0991c92881ba0ae9f513e979703bee6bf0d62b Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 28 Sep 2023 21:06:18 +0100 Subject: [PATCH 20/84] update --- src/lightning/data/cache/cache.py | 39 ++++++++++++++++++++++------ src/lightning/data/cache/reader.py | 12 ++++----- tests/tests_data/cache/test_cache.py | 23 +++++++++------- 3 files changed, 50 insertions(+), 24 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 1956d5a9f8087..ecf2d538d3dcc 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -47,14 +47,19 @@ def __init__( self._reader = BinaryReader(cache_dir, compression=compression) self._cache_dir = cache_dir self._is_done = False + self._distributed_env = _DistributedEnv.detect() + self._num_workers = None # TODO: Find a way to make this faster @property def filled(self) -> bool: + if self._num_workers is None: + raise Exception("The Cache wasn't setup properly. 
HINT: Did you use the CacheDataLoader ?") if self._is_done: return True files = os.listdir(self._cache_dir) - self._is_done = any(f.endswith("index.json") for f in files) + index_files = [f for f in files if f.endswith("index.json")] + self._is_done = len(index_files) == self._distributed_env.world_size * self._num_workers return self._is_done def __setitem__(self, index, data): @@ -281,14 +286,31 @@ def __iter_from_chunks__(self): chunk_intervals = self._cache.get_chunk_interval() shuffled_chunk_intervals = np.random.permutation(chunk_intervals) - indices = [] - for interval in shuffled_chunk_intervals: - interval_indices = np.arange(interval[0], interval[1]) - shuffled_interval_indices = np.random.permutation(interval_indices) - indices.extend(shuffled_interval_indices.tolist()) + if self._num_replicas == 1: + indices = [] + for interval in shuffled_chunk_intervals: + interval_indices = np.arange(interval[0], interval[1]) + shuffled_interval_indices = np.random.permutation(interval_indices) + indices.extend(shuffled_interval_indices.tolist()) - if len(indices) != len(self.sampler): - raise Exception("The generated indices don't match the initial length of the sampler.") + if len(indices) != len(self.sampler): + raise Exception("The generated indices don't match the initial length of the sampler.") + + else: + chunks_per_replica = len(shuffled_chunk_intervals) // self._num_replicas + for replica_idx in range(self._num_replicas): + if replica_idx != self._rank: + continue + is_last_replica = replica_idx == self._num_replicas - 1 + start_replica = replica_idx * chunks_per_replica + end_replica = len(chunk_intervals) if is_last_replica else (replica_idx + 1) * chunks_per_replica + shuffled_chunk_intervals_replica = shuffled_chunk_intervals[start_replica:end_replica] + + indices = [] + for interval in shuffled_chunk_intervals_replica: + interval_indices = np.arange(interval[0], interval[1]) + shuffled_interval_indices = np.random.permutation(interval_indices) + indices.extend(shuffled_interval_indices.tolist()) self.sampler = IteratorSampler(indices) @@ -348,6 +370,7 @@ def __init__( raise Exception(f"The CacheDataloader should be used with a dataset using a single cache. Found {cache}.") cache = cache[0] + cache._num_workers = num_workers if not cache.filled and shuffle: logger.info("Shuffle is ignored during caching phase") diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index a193b71f385f0..ed8e2983c022e 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -72,7 +72,7 @@ def _map_index_to_chunk_id(self, index): for interval_index, internal in enumerate(self._intervals): if internal[0] <= index and index < internal[1]: return interval_index - return None + raise Exception(f"The chunk interval weren't properly defined. 
Found {self._intervals} for inded {index}.") def _should_keep_in_memory(self): return True @@ -85,10 +85,7 @@ def read(self, index: int): raise Exception("The reader index isn't defined.") chunk_id = self._map_index_to_chunk_id(index) - try: - chunk_config = self._index["chunks"][chunk_id] - except Exception as e: - raise Exception(f"Found {str(self._index['chunks'])} {chunk_id}" + str(e)) + chunk_config = self._index["chunks"][chunk_id] chunk_path = os.path.join(self._cache_dir, chunk_config["filename"]) raw_item_data, item_config = self.load_item_from_chunk( index, chunk_path, keep_in_memory=self._should_keep_in_memory() @@ -114,7 +111,10 @@ def deserialize(self, raw_item_data, item_config): def load_item_from_chunk(self, index: int, chunk_path: str, keep_in_memory: bool = False): chunk_name = os.path.basename(chunk_path) - begin, end = self._chunks_data[chunk_name]["mapping"][str(index)] + try: + begin, end = self._chunks_data[chunk_name]["mapping"][str(index)] + except Exception as e: + raise Exception(f"Medata: ({self._chunks_data[chunk_name]}), Error: {e}") config = self._chunks_data[chunk_name]["config"] if self._chunks_data[chunk_name]["data"] is not None: return self._chunks_data[chunk_name]["data"][begin:end], config diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index f5de28c041703..f5cb7838f4754 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -1,3 +1,4 @@ +import io import os from functools import partial @@ -51,19 +52,20 @@ def __getitem__(self, index): return None -def cache_for_image_dataset(num_workers, tmpdir): - import io - +def cache_for_image_dataset(num_workers, tmpdir, fabric=None): from PIL import Image dataset_size = 85 cache_dir = os.path.join(tmpdir, "cache") + distributed_env = _DistributedEnv.detect() cache = Cache(cache_dir, data_format={"image": "jpeg", "class": "int", "index": "int"}, chunk_size=2 << 12) dataset = ImageDataset(tmpdir, cache, dataset_size, 10) dataloader = CacheDataLoader(dataset, num_workers=num_workers, batch_size=4) - for _ in dataloader: + dataloader_iter = iter(dataloader) + + for _ in dataloader_iter: pass for i in range(len(dataset)): @@ -73,7 +75,6 @@ def cache_for_image_dataset(num_workers, tmpdir): original_image = Image.open(io.BytesIO(original_data["image"])) assert cached_data["image"] == original_image - distributed_env = _DistributedEnv.detect() dataset.use_transform = True if distributed_env.world_size == 1: @@ -86,15 +87,17 @@ def cache_for_image_dataset(num_workers, tmpdir): seed_everything(42) dataloader = CacheDataLoader(dataset, num_workers=num_workers, batch_size=4, shuffle=True) + dataloader_iter = iter(dataloader) indexes = [] - for batch in dataloader: + for batch in dataloader_iter: indexes.extend(batch["index"].numpy().tolist()) - assert len(indexes) == dataset_size + if distributed_env.world_size == 1: + assert len(indexes) == dataset_size indexes2 = [] - for batch in dataloader: + for batch in dataloader_iter: indexes2.extend(batch["index"].numpy().tolist()) assert indexes2 != indexes @@ -111,8 +114,8 @@ def test_cache_for_image_dataset(num_workers, tmpdir): cache_for_image_dataset(num_workers, tmpdir) -def fabric_cache_for_image_dataset(_, num_workers, tmpdir): - cache_for_image_dataset(num_workers, tmpdir) +def fabric_cache_for_image_dataset(fabric, num_workers, tmpdir): + cache_for_image_dataset(num_workers, tmpdir, fabric=fabric) @pytest.mark.skipif( From 32fe81148f086ca3cbd54142a73b73dbf7b0280c Mon Sep 17 00:00:00 2001 
From: thomas Date: Thu, 28 Sep 2023 21:10:52 +0100 Subject: [PATCH 21/84] update --- src/lightning/data/cache/__init__.py | 13 +++++++++++++ src/lightning/data/cache/worker.py | 19 ++++++++++++++++--- tests/tests_data/cache/test_cache.py | 18 ++++++++++++++++-- tests/tests_data/cache/test_serializer.py | 16 +++++++++++++++- tests/tests_data/cache/test_writer.py | 16 +++++++++++++++- 5 files changed, 75 insertions(+), 7 deletions(-) diff --git a/src/lightning/data/cache/__init__.py b/src/lightning/data/cache/__init__.py index bbfd93dfbc591..08a18fab987c6 100644 --- a/src/lightning/data/cache/__init__.py +++ b/src/lightning/data/cache/__init__.py @@ -1,3 +1,16 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from lightning.data.cache.cache import Cache, CacheDataLoader __all__ = ["Cache", "CacheDataLoader"] diff --git a/src/lightning/data/cache/worker.py b/src/lightning/data/cache/worker.py index 92396157d8663..f971e5001f516 100644 --- a/src/lightning/data/cache/worker.py +++ b/src/lightning/data/cache/worker.py @@ -1,16 +1,29 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copy pasted from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/worker.py + lines 370-378 +# TODO: Delete me when this is addressed https://github.com/pytorch/pytorch/issues/110156 r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. These **needs** to be in global scope since Py2 doesn't support serializing static methods. """ -# Taken from PyTorch - import os import queue import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union +from typing import Optional, TYPE_CHECKING, Union import torch from torch._utils import ExceptionWrapper diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index de25e454e704b..64f122d139f33 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -1,15 +1,29 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + import io import os from functools import partial import numpy as np import pytest +from lightning_utilities.core.imports import RequirementCache +from torch.utils.data import Dataset + from lightning import seed_everything from lightning.data.cache import Cache, CacheDataLoader from lightning.data.datasets.env import _DistributedEnv from lightning.fabric import Fabric -from lightning_utilities.core.imports import RequirementCache -from torch.utils.data import Dataset _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py index dd7a743fd11c7..2a22d77a19c5e 100644 --- a/tests/tests_data/cache/test_serializer.py +++ b/tests/tests_data/cache/test_serializer.py @@ -1,10 +1,24 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import numpy as np import pytest -from lightning.data.cache.serializers import _SERIALIZERS, IntSerializer, JPEGSerializer, PILSerializer from lightning_utilities.core.imports import RequirementCache +from lightning.data.cache.serializers import _SERIALIZERS, IntSerializer, JPEGSerializer, PILSerializer + _PIL_AVAILABLE = RequirementCache("PIL") diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index cfd5a993569dc..3fb202624240e 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -1,11 +1,25 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import json import os import numpy as np import pytest +from lightning_utilities.core.imports import RequirementCache + from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter -from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") From 1e3d1ab8ca1bddf3ea5d42971db0d88ae1cb1713 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 28 Sep 2023 21:11:32 +0100 Subject: [PATCH 22/84] update --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 27334b81220e6..4783767af6df4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,6 +149,7 @@ exclude = [ "src/lightning/app/cli/pl-app-template", "src/lightning/app/cli/react-ui-template", "src/lightning/app/launcher", + "src/lightning/pytorch/data/cache", ] install_types = "True" non_interactive = "True" From d05e34ff2ab6a5db809bf8c5fc29c216818d33c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Sep 2023 20:12:35 +0000 Subject: [PATCH 23/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/cache/worker.py | 2 +- tests/tests_data/cache/test_cache.py | 5 ++--- tests/tests_data/cache/test_serializer.py | 3 +-- tests/tests_data/cache/test_writer.py | 3 +-- 4 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/lightning/data/cache/worker.py b/src/lightning/data/cache/worker.py index f971e5001f516..11c2c0ed0f2d4 100644 --- a/src/lightning/data/cache/worker.py +++ b/src/lightning/data/cache/worker.py @@ -23,7 +23,7 @@ import queue import random from dataclasses import dataclass -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, Optional, Union import torch from torch._utils import ExceptionWrapper diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index 64f122d139f33..dbe02e10b6946 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -17,13 +17,12 @@ import numpy as np import pytest -from lightning_utilities.core.imports import RequirementCache -from torch.utils.data import Dataset - from lightning import seed_everything from lightning.data.cache import Cache, CacheDataLoader from lightning.data.datasets.env import _DistributedEnv from lightning.fabric import Fabric +from lightning_utilities.core.imports import RequirementCache +from torch.utils.data import Dataset _PIL_AVAILABLE = RequirementCache("PIL") _TORCH_VISION_AVAILABLE = RequirementCache("torchvision") diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py index 2a22d77a19c5e..01b6e41b121e6 100644 --- a/tests/tests_data/cache/test_serializer.py +++ b/tests/tests_data/cache/test_serializer.py @@ -15,9 +15,8 @@ import numpy as np import pytest -from lightning_utilities.core.imports import RequirementCache - from lightning.data.cache.serializers import _SERIALIZERS, IntSerializer, JPEGSerializer, PILSerializer +from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index 3fb202624240e..2674ebea4225e 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -16,10 +16,9 @@ import numpy as np import pytest -from lightning_utilities.core.imports import 
RequirementCache - from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter +from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") From 6a631172900e962c0b1ee71bfc4081cb7ee3744d Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 28 Sep 2023 22:07:17 +0100 Subject: [PATCH 24/84] update --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4783767af6df4..ede062857a2dd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,7 +149,7 @@ exclude = [ "src/lightning/app/cli/pl-app-template", "src/lightning/app/cli/react-ui-template", "src/lightning/app/launcher", - "src/lightning/pytorch/data/cache", + "src/lightning/data/cache", ] install_types = "True" non_interactive = "True" From 41de12376d2c98bd237035775292a32740a3995e Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 28 Sep 2023 22:20:06 +0100 Subject: [PATCH 25/84] update --- src/lightning/data/cache/cache.py | 3 +- src/lightning/data/cache/worker.py | 44 +++------------------------ tests/tests_data/cache/test_cache.py | 3 +- tests/tests_data/cache/test_writer.py | 10 +++--- 4 files changed, 12 insertions(+), 48 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 45be4bd47faca..a02122af7d093 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -50,7 +50,6 @@ def __init__( self._distributed_env = _DistributedEnv.detect() self._num_workers = None - # TODO: Find a way to make this faster @property def filled(self) -> bool: if self._num_workers is None: @@ -59,7 +58,7 @@ def filled(self) -> bool: return True files = os.listdir(self._cache_dir) index_files = [f for f in files if f.endswith("index.json")] - self._is_done = len(index_files) == self._distributed_env.world_size * self._num_workers + self._is_done = len(index_files) == self._distributed_env.world_size * (self._num_workers or 1) return self._is_done def __setitem__(self, index, data): diff --git a/src/lightning/data/cache/worker.py b/src/lightning/data/cache/worker.py index 11c2c0ed0f2d4..b034f44813a8d 100644 --- a/src/lightning/data/cache/worker.py +++ b/src/lightning/data/cache/worker.py @@ -11,13 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Copy pasted from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/worker.py + lines 370-378 +# Copy pasted from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/worker.py + lines 337-347 # TODO: Delete me when this is addressed https://github.com/pytorch/pytorch/issues/110156 -r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers. -These **needs** to be in global scope since Py2 doesn't support serializing static methods. 
- -""" import os import queue @@ -135,48 +131,16 @@ def get_worker_info() -> Optional[WorkerInfo]: return _worker_info -r"""Dummy class used to signal the end of an IterableDataset""" - - @dataclass(frozen=True) class _IterableDatasetStopIteration: worker_id: int -r"""Dummy class used to resume the fetching when worker reuse is enabled""" - - @dataclass(frozen=True) class _ResumeIteration: seed: Optional[int] = None -# The function `_generate_state` is adapted from `numpy.random.SeedSequence` -# from https://github.com/numpy/numpy/blob/main/numpy/random/bit_generator.pyx -# It's MIT licensed, here is the copyright: - -# Copyright (c) 2015 Melissa E. O'Neill -# Copyright (c) 2019 NumPy Developers -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in -# all copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - - # This function generates an array of int32 as the seed for # `numpy.random`, in order to prevent state collision due to same # seed and algorithm for `numpy.random` and `random` modules. @@ -328,11 +292,13 @@ def _worker_loop( # Recreate the fetcher for worker-reuse policy fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last) continue - elif r is None: + + if r is None: # Received the final signal assert done_event.is_set() or iteration_end break - elif done_event.is_set() or iteration_end: + + if done_event.is_set() or iteration_end: # `done_event` is set. But I haven't received the final signal # (None) yet. I will keep continuing until get it, and skip the # processing steps. 
diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index dbe02e10b6946..67531fa9d3e89 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -75,9 +75,8 @@ def cache_for_image_dataset(num_workers, tmpdir, fabric=None): cache = Cache(cache_dir, data_format={"image": "jpeg", "class": "int", "index": "int"}, chunk_size=2 << 12) dataset = ImageDataset(tmpdir, cache, dataset_size, 10) dataloader = CacheDataLoader(dataset, num_workers=num_workers, batch_size=4) - dataloader_iter = iter(dataloader) - for _ in dataloader_iter: + for _ in dataloader: pass for i in range(len(dataset)): diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index 2674ebea4225e..7cba4ad86cb7e 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -42,7 +42,7 @@ def test_binary_writer_with_ints(tmpdir): binary_writer[i] = {"i": i, "i+1": i + 1, "i+2": i + 2} assert len(os.listdir(tmpdir)) == 19 - binary_writer.done(0) + binary_writer.done() assert len(os.listdir(tmpdir)) == 21 with open(os.path.join(tmpdir, "0.index.json")) as f: @@ -79,7 +79,7 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): binary_writer[i] = {"x": img, "y": i} assert len(os.listdir(cache_dir)) == 24 - binary_writer.done(0) + binary_writer.done() assert len(os.listdir(cache_dir)) == 26 with open(os.path.join(cache_dir, "0.index.json")) as f: @@ -109,6 +109,6 @@ def test_binary_writer_config(monkeypatch): assert BinaryWriter.get_cloud_path("~") == prefix assert BinaryWriter.get_cloud_path("~/") == prefix assert BinaryWriter.get_cloud_path("/") == prefix - assert BinaryWriter.get_cloud_path("/data") == f"{prefix}/data" - assert BinaryWriter.get_cloud_path("~/data") == f"{prefix}/data" - assert BinaryWriter.get_cloud_path("/teamspace/studios/this_studio/data") == f"{prefix}/data" + assert BinaryWriter.get_cloud_path("/data") == f"{prefix}data" + assert BinaryWriter.get_cloud_path("~/data") == f"{prefix}data" + assert BinaryWriter.get_cloud_path("/teamspace/studios/this_studio/data") == f"{prefix}data" From e4397d0460d7e8d936f61118f2c899fc0ea63003 Mon Sep 17 00:00:00 2001 From: thomas Date: Thu, 28 Sep 2023 22:20:20 +0100 Subject: [PATCH 26/84] update --- src/lightning/data/cache/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/data/cache/worker.py b/src/lightning/data/cache/worker.py index b034f44813a8d..f7c8ec5c45146 100644 --- a/src/lightning/data/cache/worker.py +++ b/src/lightning/data/cache/worker.py @@ -11,7 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
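The `get_cloud_path` assertions above pin down how a local cache directory maps onto an S3 prefix: "~", "/", and the studio root all collapse to the prefix itself, and the remainder of the path is appended without a leading slash. Below is a minimal, purely illustrative helper that reproduces that mapping; the real `BinaryWriter.get_cloud_path` builds its prefix from the Lightning cluster/project environment variables, and `prefix` here is only a placeholder:

    def to_cloud_path(path: str, prefix: str = "s3://placeholder-bucket/") -> str:
        # Strip the roots the tests treat as equivalent, then append the rest to the prefix.
        for root in ("~/", "~", "/teamspace/studios/this_studio/", "/"):
            if path.startswith(root):
                path = path[len(root):]
                break
        return prefix + path


    assert to_cloud_path("~") == to_cloud_path("/") == "s3://placeholder-bucket/"
    assert to_cloud_path("/data") == to_cloud_path("~/data") == "s3://placeholder-bucket/data"
    assert to_cloud_path("/teamspace/studios/this_studio/data") == "s3://placeholder-bucket/data"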
-# Copy pasted from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/worker.py + lines 337-347 +# Copy pasted from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/worker.py + lines 332-342 # TODO: Delete me when this is addressed https://github.com/pytorch/pytorch/issues/110156 From a442858eea1248c27c13fd23142027606b7969e1 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 09:29:16 +0100 Subject: [PATCH 27/84] update --- src/lightning/data/cache/cache.py | 76 ++++++++++++++++++--- src/lightning/data/cache/compression.py | 4 ++ src/lightning/data/cache/env.py | 5 +- src/lightning/data/cache/reader.py | 28 ++++++-- tests/tests_data/cache/test_cache.py | 16 +++-- tests/tests_data/cache/test_writer.py | 2 + tests/tests_data/datasets/test_get_index.py | 19 +++--- 7 files changed, 115 insertions(+), 35 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index a02122af7d093..6fbcb841f6eca 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -13,7 +13,7 @@ import logging import os -from typing import Dict, Iterator, List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional, Union import numpy as np from torch.utils.data import IterableDataset @@ -42,16 +42,30 @@ def __init__( compression: Optional[str] = None, chunk_size: int = 2 << 26, ): + """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements + together in order to accelerate fetching. + + Arguments: + cache_dir: The path to where the chunks will be stored. + data_format: The structure of the data to be serialized. + compression: The name of the algorithm to reduce the size of the chunks + chunk_size: The maximum byte size of chunk. + + """ super().__init__() self._writer = BinaryWriter(cache_dir, data_format, chunk_size=chunk_size, compression=compression) self._reader = BinaryReader(cache_dir, compression=compression) self._cache_dir = cache_dir self._is_done = False self._distributed_env = _DistributedEnv.detect() - self._num_workers = None + self._num_workers: Optional[int] = None + + def _setup(self, num_workers: int) -> None: + self._num_workers = num_workers @property def filled(self) -> bool: + """Returns whether the caching phase is done.""" if self._num_workers is None: raise Exception("The Cache wasn't setup properly. 
HINT: Did you use the CacheDataLoader ?") if self._is_done: @@ -61,16 +75,19 @@ def filled(self) -> bool: self._is_done = len(index_files) == self._distributed_env.world_size * (self._num_workers or 1) return self._is_done - def __setitem__(self, index, data): + def __setitem__(self, index, data) -> None: + """Store an item in the writer.""" self._writer[index] = data - def __getitem__(self, index): + def __getitem__(self, index) -> Dict[str, Any]: + """Read an item in the reader.""" return self._reader.read(index) - def done(self): + def done(self) -> None: + """Inform the writer the chunking phase is finished.""" self._writer.done() - def __len__(self): + def __len__(self) -> int: return self._reader.get_length() def get_chunk_interval(self): @@ -78,6 +95,8 @@ def get_chunk_interval(self): class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter): + """This is overriden to inform the cache is done chunking.""" + def _next_data(self): try: return super()._next_data() @@ -109,6 +128,16 @@ def __len__(self) -> int: class CacheSampler(Sampler): def __init__(self, dataset_size: int, num_workers: int, batch_size: int): + """The CacheSampler splits the dataset indices into ordered chunks and assign each one of them to a DataLoader + worker. The Cache Writer expects the index to be provided in an ordered fashion. + + Arguments: + dataset_size: The size of the dataset. + num_workers: The number of workers provided to the DataLoader + batch_size: The number of items in a batch + + """ + super().__init__(None) self.batch_size = batch_size self.num_workers = num_workers @@ -135,7 +164,7 @@ def __len__(self) -> int: def done(self) -> bool: return len(self._done) == len(self.iterators) - def __iter__(self): + def __iter__(self) -> "CacheSampler": self._done = set() for sampler in self.samplers: @@ -143,7 +172,7 @@ def __iter__(self): return self - def __next__(self): + def __next__(self) -> List[int]: while len(self._done) != self.iterators: try: data = next(self.iterators[self.worker_id]) @@ -161,6 +190,15 @@ def __next__(self): class DistributedCacheSampler(Sampler): def __init__(self, dataset_size: int, num_replicas: int, rank: int, num_workers: int, batch_size: int): + """The DistributedCacheSampler splits the dataset indices into ordered chunks along all the replicas and their + workers. The Cache Writer expects the index to be provided in an ordered fashion. + + Arguments: + dataset_size: The size of the dataset. + num_workers: The number of workers provided to the DataLoader + batch_size: The number of items in a batch + + """ super().__init__(None) self.batch_size = batch_size self.num_workers = num_workers @@ -201,7 +239,7 @@ def __len__(self) -> int: def done(self) -> bool: return len(self._done) == len(self.iterators) - def __iter__(self): + def __iter__(self) -> "DistributedCacheSampler": self._done = set() for sampler in self.samplers: @@ -209,7 +247,7 @@ def __iter__(self): return self - def __next__(self): + def __next__(self) -> List[str]: while len(self._done) != self.iterators: try: data = next(self.iterators[self.worker_id]) @@ -237,6 +275,22 @@ def __init__( shuffle: bool, cache: Cache, ): + """The CacheBatchSampler handles the generation of batch indices. + + If the cache isn't filled, the batch sampler alternates with ordered indices for the writer to chunk the dataset + If the cache is filled, it acts as normal BatchSampler. + + Arguments: + dataset_size: The size of the dataset. + num_replicas: The number of processes involves in the distributed training. 
+ rank: The rank of the given process + num_workers: The number of workers provided to the DataLoader. + batch_size: The number of items in a batch. + shuffle: Whether the data should be shuffled. + cache: The cache associated to the dataset. + + """ + if num_replicas == 1: if not cache.filled and num_workers > 1: sampler = CacheSampler(dataset_size, num_workers, batch_size) @@ -369,7 +423,7 @@ def __init__( raise Exception(f"The CacheDataloader should be used with a dataset using a single cache. Found {cache}.") cache = cache[0] - cache._num_workers = num_workers + cache._setup(num_workers) if not cache.filled and shuffle: logger.info("Shuffle is ignored during caching phase") diff --git a/src/lightning/data/cache/compression.py b/src/lightning/data/cache/compression.py index 5b36f9eb8bf60..731abb979ceb0 100644 --- a/src/lightning/data/cache/compression.py +++ b/src/lightning/data/cache/compression.py @@ -25,6 +25,8 @@ class Compressor(ABC): + """Base class for compression algorithm.""" + @abstractmethod def compress(self, data: bytes) -> bytes: pass @@ -39,6 +41,8 @@ def register(cls, compressors: Dict[str, TCompressor]): class ZSTDCompressor(Compressor): + """Compressor for the zstd package.""" + @requires("zstd") def __init__(self, level): super().__init__() diff --git a/src/lightning/data/cache/env.py b/src/lightning/data/cache/env.py index 95c203b879a57..897c9486c8f96 100644 --- a/src/lightning/data/cache/env.py +++ b/src/lightning/data/cache/env.py @@ -15,7 +15,10 @@ class _WorkerEnv: - """Contains the environment for the current dataloader within the current training process. + """ + Note: This is using our fork: `from lightning.data.cache.worker import get_worker_info` to get the worker_infor + + Contains the environment for the current dataloader within the current training process. Args: world_size: The number of dataloader workers for the current training process diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index ed8e2983c022e..d4806bcde45e7 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -13,7 +13,7 @@ import json import os -from typing import Optional +from typing import Any, Dict, Optional import numpy as np @@ -23,21 +23,35 @@ class BinaryReader: - def __init__(self, _cache_dir: str, compression: Optional[str] = None): + def __init__(self, cache_dir: str, compression: Optional[str] = None): + """The BinaryReader enables to read chunked dataset in an efficient way. + + Arguments: + cache_dir: The path to cache folder + compression: The algorithm to decompress the chunks. 
+ + """ + super().__init__() - self._cache_dir = _cache_dir + self._cache_dir = cache_dir + + if not os.path.exists(self._cache_dir): + raise FileNotFoundError(f"The provided cache_dir `{self._cache_dir}` doesn't exist.") + self._compression = compression self._index = None self._intervals = None + + # TODO: Use a chunk class self._chunks_data = {} self._serializers = _SERIALIZERS self._env = _DistributedEnv.detect() - self._worker_env = None - self._rank = None + self._worker_env: Optional[_WorkerEnv] = None + self._rank: Optional[int] = None @property - def rank(self): + def rank(self) -> Optional[int]: if self._rank is None: self._worker_env = _WorkerEnv.detect() self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank @@ -92,7 +106,7 @@ def read(self, index: int): ) return self.deserialize(raw_item_data, item_config) - def deserialize(self, raw_item_data, item_config): + def deserialize(self, raw_item_data: bytes, item_config: Dict[str, Any]) -> Any: sizes = [] idx = 0 data_format = item_config["data_format"] diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index 67531fa9d3e89..db0dce1c76e7e 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -13,6 +13,7 @@ import io import os +import sys from functools import partial import numpy as np @@ -64,7 +65,7 @@ def __getitem__(self, index): return None -def cache_for_image_dataset(num_workers, tmpdir, fabric=None): +def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): from PIL import Image dataset_size = 85 @@ -117,20 +118,21 @@ def cache_for_image_dataset(num_workers, tmpdir, fabric=None): @pytest.mark.skipif( condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']" ) -@pytest.mark.parametrize("num_workers", [0, 1, 2]) +@pytest.mark.parametrize("num_workers", [0]) def test_cache_for_image_dataset(num_workers, tmpdir): cache_dir = os.path.join(tmpdir, "cache") os.makedirs(cache_dir) - cache_for_image_dataset(num_workers, tmpdir) + _cache_for_image_dataset(num_workers, tmpdir) -def fabric_cache_for_image_dataset(fabric, num_workers, tmpdir): - cache_for_image_dataset(num_workers, tmpdir, fabric=fabric) +def _fabric_cache_for_image_dataset(fabric, num_workers, tmpdir): + _cache_for_image_dataset(num_workers, tmpdir, fabric=fabric) @pytest.mark.skipif( - condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']" + condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE or sys.platform == "win32", + reason="Requires: ['pil', 'torchvision']", ) @pytest.mark.parametrize("num_workers", [2]) def test_cache_for_image_dataset_distributed(num_workers, tmpdir): @@ -138,4 +140,4 @@ def test_cache_for_image_dataset_distributed(num_workers, tmpdir): os.makedirs(cache_dir) fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn") - fabric.launch(partial(fabric_cache_for_image_dataset, num_workers=num_workers, tmpdir=tmpdir)) + fabric.launch(partial(_fabric_cache_for_image_dataset, num_workers=num_workers, tmpdir=tmpdir)) diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index 7cba4ad86cb7e..7a16db693c748 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -13,6 +13,7 @@ import json import os +import sys import numpy as np import pytest @@ -96,6 +97,7 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): assert data["y"] == i 
+@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") def test_binary_writer_config(monkeypatch): assert BinaryWriter.get_cloud_path("") is None diff --git a/tests/tests_data/datasets/test_get_index.py b/tests/tests_data/datasets/test_get_index.py index 3dad62bf62771..b51f10966cbbe 100644 --- a/tests/tests_data/datasets/test_get_index.py +++ b/tests/tests_data/datasets/test_get_index.py @@ -17,8 +17,7 @@ def get_test_index_data(index_path): return list(dict.fromkeys([item.split("/")[-1] for item in data if "jpeg" in item])) -@pytest.fixture(scope="session") -def image_set(tmp_path_factory): +def image_set(tmpdir): from PIL import Image file_nums = [ @@ -45,11 +44,11 @@ def image_set(tmp_path_factory): img = img.astype(np.uint8) im = Image.fromarray(img) - for i in file_nums: - fn = tmp_path_factory.mktemp("test_data") / f"img-{i}.jpeg" - im.save(fn) + folder_path = os.path.join(tmpdir, "test_data") + os.makedirs(folder_path, exist_ok=True) - return tmp_path_factory.getbasetemp()._str + for i in file_nums: + im.save(os.path.join(folder_path, f"img-{i}.jpeg")) @pytest.mark.xfail(strict=False, reason="Need a valid AWS key and AWS secret key in CI for this to work") @@ -70,7 +69,6 @@ def test_get_index_generate_for_s3_bucket(monkeypatch): test_bucket = "s3://nohaspublictestbucket" index_path = os.path.join(os.getcwd(), "index_1.txt") - print(index_path) got_index = get_index(s3_connection_path=test_bucket, index_file_path=index_path) assert got_index @@ -80,13 +78,16 @@ def test_get_index_generate_for_s3_bucket(monkeypatch): assert len(test_index_data) == len(generated_index) assert test_index_data == generated_index + os.remove(index_path) @pytest.mark.skipif(not package_available("lightning"), reason="Supported only with mono-package") @mock.patch("lightning.data.datasets.index.LightningClient", MagicMock()) -def test_get_index_generate_for_local_folder(image_set, monkeypatch): +def test_get_index_generate_for_local_folder(monkeypatch, tmpdir): """Can generate an index for an s3 bucket.""" + image_set(tmpdir) + client = MagicMock() client.projects_service_list_project_cluster_bindings.return_value = None client.data_connection_service_list_data_connections.return_value = None @@ -100,7 +101,7 @@ def test_get_index_generate_for_local_folder(image_set, monkeypatch): # test_local_bucket = "data/test_dataset" index_path = os.path.join(THIS_DIR, "index_2.txt") - got_index = get_index(s3_connection_path=image_set, index_file_path=index_path) + got_index = get_index(s3_connection_path=str(tmpdir), index_file_path=index_path) assert got_index From a30b2777fa7c7ed9b0ce6d11821e66e19b70a46b Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 09:35:30 +0100 Subject: [PATCH 28/84] update --- src/lightning/data/cache/env.py | 52 ------------------------------ src/lightning/data/cache/reader.py | 6 ++-- src/lightning/data/cache/writer.py | 6 ++-- src/lightning/data/datasets/env.py | 6 ++-- 4 files changed, 9 insertions(+), 61 deletions(-) delete mode 100644 src/lightning/data/cache/env.py diff --git a/src/lightning/data/cache/env.py b/src/lightning/data/cache/env.py deleted file mode 100644 index 897c9486c8f96..0000000000000 --- a/src/lightning/data/cache/env.py +++ /dev/null @@ -1,52 +0,0 @@ -# Copyright The Lightning AI team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from lightning.data.cache.worker import get_worker_info - - -class _WorkerEnv: - """ - Note: This is using our fork: `from lightning.data.cache.worker import get_worker_info` to get the worker_infor - - Contains the environment for the current dataloader within the current training process. - - Args: - world_size: The number of dataloader workers for the current training process - rank: The rank of the current worker within the number of workers - - """ - - def __init__(self, world_size: int, rank: int): - self.world_size = world_size - self.rank = rank - - @classmethod - def detect(cls) -> "_WorkerEnv": - """Automatically detects the number of workers and the current rank. - - Note: - This only works reliably within a dataloader worker as otherwise the necessary information won't be present. - In such a case it will default to 1 worker - - """ - worker_info = get_worker_info() - num_workers = worker_info.num_workers if worker_info is not None else 1 - current_worker_rank = worker_info.id if worker_info is not None else 0 - - return cls(world_size=num_workers, rank=current_worker_rank) - - def __repr__(self) -> str: - return f"{self.__class__.__name__}(world_size: {self.world_size}, rank: {self.rank})" - - def __str__(self) -> str: - return repr(self) diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index d4806bcde45e7..7beb75c95f754 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -17,9 +17,9 @@ import numpy as np -from lightning.data.cache.env import _WorkerEnv from lightning.data.cache.serializers import _SERIALIZERS -from lightning.data.datasets.env import _DistributedEnv +from lightning.data.cache.worker import get_worker_info +from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv class BinaryReader: @@ -53,7 +53,7 @@ def __init__(self, cache_dir: str, compression: Optional[str] = None): @property def rank(self) -> Optional[int]: if self._rank is None: - self._worker_env = _WorkerEnv.detect() + self._worker_env = _WorkerEnv.detect(get_worker_info_fn=get_worker_info) self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank return self._rank diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 7f647ebc62b70..2eccfc5889fef 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -18,9 +18,9 @@ import numpy as np from lightning.data.cache.compression import _COMPRESSORS -from lightning.data.cache.env import _WorkerEnv from lightning.data.cache.serializers import _SERIALIZERS -from lightning.data.datasets.env import _DistributedEnv +from lightning.data.cache.worker import get_worker_info +from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv def cloud_path(cache_dir: str) -> Optional[str]: @@ -98,7 +98,7 @@ def __init__( @property def rank(self): if self._rank is None: - self._worker_env = _WorkerEnv.detect() + self._worker_env = _WorkerEnv.detect(get_worker_info_fn=get_worker_info) self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank 
return self._rank diff --git a/src/lightning/data/datasets/env.py b/src/lightning/data/datasets/env.py index 51a9f21271e81..b5be69c01d8a3 100644 --- a/src/lightning/data/datasets/env.py +++ b/src/lightning/data/datasets/env.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Callable, Optional import torch from torch.utils.data import get_worker_info @@ -60,7 +60,7 @@ def __init__(self, world_size: int, rank: int): self.rank = rank @classmethod - def detect(cls) -> "_WorkerEnv": + def detect(cls, get_worker_info_fn: Optional[Callable] = get_worker_info) -> "_WorkerEnv": """Automatically detects the number of workers and the current rank. Note: @@ -68,7 +68,7 @@ def detect(cls) -> "_WorkerEnv": In such a case it will default to 1 worker """ - worker_info = get_worker_info() + worker_info = get_worker_info_fn() num_workers = worker_info.num_workers if worker_info is not None else 1 current_worker_rank = worker_info.id if worker_info is not None else 0 From 6171812cd464ef5bb5f43592473c523662aaf41c Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 09:40:50 +0100 Subject: [PATCH 29/84] update --- src/lightning/data/cache/__init__.py | 3 +- src/lightning/data/cache/cache.py | 377 +------------------------ src/lightning/data/cache/dataloader.py | 126 +++++++++ src/lightning/data/cache/sampler.py | 289 +++++++++++++++++++ tests/tests_data/cache/test_sampler.py | 0 5 files changed, 418 insertions(+), 377 deletions(-) create mode 100644 src/lightning/data/cache/dataloader.py create mode 100644 src/lightning/data/cache/sampler.py create mode 100644 tests/tests_data/cache/test_sampler.py diff --git a/src/lightning/data/cache/__init__.py b/src/lightning/data/cache/__init__.py index 08a18fab987c6..996936c877edf 100644 --- a/src/lightning/data/cache/__init__.py +++ b/src/lightning/data/cache/__init__.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
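With `_WorkerEnv.detect` now accepting an injectable `get_worker_info_fn`, the reader and writer derive a single writer rank from the distributed global rank and the in-process worker id, as in the `rank` properties above. A tiny self-contained illustration of that arithmetic (plain integers, no Lightning imports), assuming every process runs the same number of dataloader workers:

    def writer_rank(process_global_rank: int, workers_per_process: int, worker_id: int) -> int:
        # Each (training process, dataloader worker) pair gets a unique rank, so the
        # "<rank>.index.json" and chunk files produced by different writers never collide.
        return process_global_rank * workers_per_process + worker_id


    # 2 processes x 3 workers -> ranks 0..5, each used exactly once
    assert [writer_rank(p, 3, w) for p in range(2) for w in range(3)] == [0, 1, 2, 3, 4, 5]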
-from lightning.data.cache.cache import Cache, CacheDataLoader +from lightning.data.cache.cache import Cache +from lightning.data.cache.dataloader import CacheDataLoader __all__ = ["Cache", "CacheDataLoader"] diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 6fbcb841f6eca..2729f665f52cf 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -13,19 +13,7 @@ import logging import os -from typing import Any, Dict, Iterator, List, Optional, Union - -import numpy as np -from torch.utils.data import IterableDataset -from torch.utils.data._utils.collate import default_collate -from torch.utils.data.dataloader import ( - DataLoader, - _BaseDataLoaderIter, - _MultiProcessingDataLoaderIter, - _SingleProcessDataLoaderIter, -) -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized +from typing import Any, Dict, Optional, Union from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter @@ -92,366 +80,3 @@ def __len__(self) -> int: def get_chunk_interval(self): return self._reader.get_chunk_interval() - - -class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter): - """This is overriden to inform the cache is done chunking.""" - - def _next_data(self): - try: - return super()._next_data() - except StopIteration: - for v in self._dataset_fetcher.dataset.__dict__.values(): - if isinstance(v, Cache): - v.done() - raise StopIteration() - - -class IteratorSampler(Sampler[int]): - r"""Samples elements sequentially, always in the same order. - - Args: - data_source (Dataset): dataset to sample from - - """ - data_source: Sized - - def __init__(self, data_source: Sized) -> None: - self.data_source = data_source - - def __iter__(self) -> Iterator[int]: - return iter(self.data_source) - - def __len__(self) -> int: - return len(self.data_source) - - -class CacheSampler(Sampler): - def __init__(self, dataset_size: int, num_workers: int, batch_size: int): - """The CacheSampler splits the dataset indices into ordered chunks and assign each one of them to a DataLoader - worker. The Cache Writer expects the index to be provided in an ordered fashion. - - Arguments: - dataset_size: The size of the dataset. 
- num_workers: The number of workers provided to the DataLoader - batch_size: The number of items in a batch - - """ - - super().__init__(None) - self.batch_size = batch_size - self.num_workers = num_workers - self.indices = range(dataset_size) - self.dataset_size = dataset_size - worker_size = dataset_size // self.num_workers - self.samplers = [] - for worker_idx in range(num_workers): - is_last = worker_idx == num_workers - 1 - worker_indices = self.indices[ - worker_idx * worker_size : dataset_size if is_last else (worker_idx + 1) * worker_size - ] - self.samplers.append(IteratorSampler(worker_indices)) - self.iterators = [] - self._done = set() - assert sum([len(s) for s in self.samplers]) == dataset_size - self.worker_id = 0 - self.indice_id = 0 - - def __len__(self) -> int: - return self.dataset_size - - @property - def done(self) -> bool: - return len(self._done) == len(self.iterators) - - def __iter__(self) -> "CacheSampler": - self._done = set() - - for sampler in self.samplers: - self.iterators.append(iter(sampler)) - - return self - - def __next__(self) -> List[int]: - while len(self._done) != self.iterators: - try: - data = next(self.iterators[self.worker_id]) - self.indice_id += 1 - if self.indice_id == self.batch_size: - self.indice_id = 0 - self.worker_id = (self.worker_id + 1) % self.num_workers - return data - except StopIteration: - self._done.add(self.worker_id) - self.indice_id = 0 - self.worker_id = (self.worker_id + 1) % self.num_workers - raise StopIteration - - -class DistributedCacheSampler(Sampler): - def __init__(self, dataset_size: int, num_replicas: int, rank: int, num_workers: int, batch_size: int): - """The DistributedCacheSampler splits the dataset indices into ordered chunks along all the replicas and their - workers. The Cache Writer expects the index to be provided in an ordered fashion. - - Arguments: - dataset_size: The size of the dataset. 
- num_workers: The number of workers provided to the DataLoader - batch_size: The number of items in a batch - - """ - super().__init__(None) - self.batch_size = batch_size - self.num_workers = num_workers - self.indices = range(dataset_size) - self.dataset_size = dataset_size - replica_size = dataset_size // num_replicas - worker_size = dataset_size // (num_replicas * self.num_workers) - self.samplers = [] - for replica_idx in range(num_replicas): - if replica_idx != rank: - continue - - is_last_replica = replica_idx == num_replicas - 1 - start_replica = replica_idx * replica_size - end_replica = dataset_size if is_last_replica else (replica_idx + 1) * replica_size - replica_indices = self.indices[start_replica:end_replica] - - replica_size = len(replica_indices) - - for worker_idx in range(num_workers): - is_last_worker = worker_idx == num_workers - 1 - start_worker = worker_idx * worker_size - end_worker = replica_size if is_last_worker else (worker_idx + 1) * worker_size - worker_indices = replica_indices[start_worker:end_worker] - self.samplers.append(IteratorSampler(worker_indices)) - - self.iterators = [] - self._done = set() - - assert sum([len(s) for s in self.samplers]) == replica_size - self.worker_id = 0 - self.indice_id = 0 - - def __len__(self) -> int: - return self.dataset_size - - @property - def done(self) -> bool: - return len(self._done) == len(self.iterators) - - def __iter__(self) -> "DistributedCacheSampler": - self._done = set() - - for sampler in self.samplers: - self.iterators.append(iter(sampler)) - - return self - - def __next__(self) -> List[str]: - while len(self._done) != self.iterators: - try: - data = next(self.iterators[self.worker_id]) - self.indice_id += 1 - if self.indice_id == self.batch_size: - self.indice_id = 0 - self.worker_id = (self.worker_id + 1) % self.num_workers - return data - except StopIteration: - self._done.add(self.worker_id) - self.indice_id = 0 - self.worker_id = (self.worker_id + 1) % self.num_workers - raise StopIteration - - -class CacheBatchSampler(BatchSampler): - def __init__( - self, - dataset_size: int, - num_replicas: int, - rank: int, - num_workers: int, - batch_size: int, - drop_last: bool, - shuffle: bool, - cache: Cache, - ): - """The CacheBatchSampler handles the generation of batch indices. - - If the cache isn't filled, the batch sampler alternates with ordered indices for the writer to chunk the dataset - If the cache is filled, it acts as normal BatchSampler. - - Arguments: - dataset_size: The size of the dataset. - num_replicas: The number of processes involves in the distributed training. - rank: The rank of the given process - num_workers: The number of workers provided to the DataLoader. - batch_size: The number of items in a batch. - shuffle: Whether the data should be shuffled. - cache: The cache associated to the dataset. 
- - """ - - if num_replicas == 1: - if not cache.filled and num_workers > 1: - sampler = CacheSampler(dataset_size, num_workers, batch_size) - elif shuffle: - sampler = RandomSampler(range(dataset_size)) - else: - sampler = SequentialSampler(range(dataset_size)) - else: - if not cache.filled: - sampler = DistributedCacheSampler(dataset_size, num_replicas, rank, num_workers, batch_size) - else: - sampler = DistributedSampler(range(dataset_size), num_replicas=num_replicas, rank=rank, shuffle=shuffle) - super().__init__(sampler, batch_size, drop_last) - self._num_replicas = num_replicas - self._rank = rank - self._cache = cache - self._shuffle = shuffle - self._num_workers = num_workers - - def __iter_ordered__(self) -> Iterator[List[int]]: - # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 - iterator = iter(self.sampler) - batch = [] - while not self.sampler.done: - try: - idx = next(iterator) - batch.append(idx) - if len(batch) == self.batch_size: - yield batch - batch = [] - except StopIteration: - if self.sampler.done: - yield batch - return - yield batch - batch = [] - - def __iter__(self): - if self._cache.filled and self._shuffle: - return self.__iter_from_chunks__() - if self._num_workers > 1 and not self._cache.filled: - return self.__iter_ordered__() - return super().__iter__() - - def __iter_from_chunks__(self): - chunk_intervals = self._cache.get_chunk_interval() - shuffled_chunk_intervals = np.random.permutation(chunk_intervals) - - if self._num_replicas == 1: - indices = [] - for interval in shuffled_chunk_intervals: - interval_indices = np.arange(interval[0], interval[1]) - shuffled_interval_indices = np.random.permutation(interval_indices) - indices.extend(shuffled_interval_indices.tolist()) - - if len(indices) != len(self.sampler): - raise Exception("The generated indices don't match the initial length of the sampler.") - - else: - chunks_per_replica = len(shuffled_chunk_intervals) // self._num_replicas - for replica_idx in range(self._num_replicas): - if replica_idx != self._rank: - continue - is_last_replica = replica_idx == self._num_replicas - 1 - start_replica = replica_idx * chunks_per_replica - end_replica = len(chunk_intervals) if is_last_replica else (replica_idx + 1) * chunks_per_replica - shuffled_chunk_intervals_replica = shuffled_chunk_intervals[start_replica:end_replica] - - indices = [] - for interval in shuffled_chunk_intervals_replica: - interval_indices = np.arange(interval[0], interval[1]) - shuffled_interval_indices = np.random.permutation(interval_indices) - indices.extend(shuffled_interval_indices.tolist()) - - self.sampler = IteratorSampler(indices) - - return super().__iter__() - - def __len__(self) -> int: - return super().__len__() - - -class CacheCollateFn: - def __init__(self): - self.collate_fn = default_collate - - def __call__(self, items): - if all(item is None for item in items): - return None - return self.collate_fn(items) - - -class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter): - def __init__(self, loader): - # Patch PyTorch worker loop - from torch.utils.data._utils import worker - - from lightning.data.cache.worker import _worker_loop - - worker._worker_loop = _worker_loop - super().__init__(loader) - - -class CacheDataLoader(DataLoader): - def __init__( - self, - dataset, - *args, - sampler=None, - batch_sampler=None, - num_workers=0, - shuffle: bool = False, - generator=None, - batch_size=1, - drop_last=False, - **kwargs, - ): - if sampler: - raise Exception("Passing a 
sampler isn't supoprt with the CacheDataLoader.") - - if batch_sampler: - raise Exception("Passing a batch_sampler isn't supoprt with the CacheDataLoader.") - - if isinstance(dataset, IterableDataset): - raise Exception("Only map-based dataset are supported by the CacheDataLoader for now.") - - cache = [v for v in dataset.__dict__.values() if isinstance(v, Cache)] - - if not cache or len(cache) > 1: - raise Exception(f"The CacheDataloader should be used with a dataset using a single cache. Found {cache}.") - - cache = cache[0] - cache._setup(num_workers) - if not cache.filled and shuffle: - logger.info("Shuffle is ignored during caching phase") - - distributed_env = _DistributedEnv.detect() - batch_sampler = CacheBatchSampler( - len(dataset), - distributed_env.world_size, - distributed_env.global_rank, - num_workers, - batch_size, - drop_last, - shuffle, - cache, - ) - - super().__init__( - dataset, - *args, - sampler=None, - batch_sampler=batch_sampler, - generator=generator, - collate_fn=CacheCollateFn(), - num_workers=num_workers, - **kwargs, - ) - - def _get_iterator(self) -> "_BaseDataLoaderIter": - if self.num_workers == 0: - return _SingleProcessDataLoaderIterPatch(self) - self.check_worker_number_rationality() - return _MultiProcessingDataLoaderIterPatch(self) diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py new file mode 100644 index 0000000000000..9cf4256aa559a --- /dev/null +++ b/src/lightning/data/cache/dataloader.py @@ -0,0 +1,126 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
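The code being moved into `dataloader.py` above wires the pieces together: the dataset owns exactly one `Cache`, the `CacheDataLoader` detects it, and the first pass over the data writes chunks while later passes read them back. The sketch below is a hedged usage example of that intended flow rather than a tested recipe; the dataset, directory, and data format are made up, and the two-phase `__getitem__` that checks `cache.filled` is one way a dataset can cooperate with the cache:

    import tempfile

    from lightning.data.cache import Cache, CacheDataLoader


    class RangeDataset:
        """A toy map-style dataset holding a single Cache attribute, as CacheDataLoader expects."""

        def __init__(self, size: int, cache_dir: str):
            self.size = size
            self.cache = Cache(cache_dir, data_format={"index": "int"}, chunk_size=2 << 12)

        def __len__(self) -> int:
            return self.size

        def __getitem__(self, index: int):
            if self.cache.filled:                 # serving phase: read back from the chunks
                return self.cache[index]
            self.cache[index] = {"index": index}  # caching phase: write the sample in order
            return None                           # CacheCollateFn tolerates all-None batches


    if __name__ == "__main__":
        dataset = RangeDataset(100, tempfile.mkdtemp())
        for _ in CacheDataLoader(dataset, num_workers=2, batch_size=4):  # first pass fills the cache
            pass
        for batch in CacheDataLoader(dataset, num_workers=2, batch_size=4, shuffle=True):
            pass                                  # later passes stream from the chunks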
+ +import logging + +from torch.utils.data import IterableDataset +from torch.utils.data._utils.collate import default_collate +from torch.utils.data.dataloader import ( + DataLoader, + _BaseDataLoaderIter, + _MultiProcessingDataLoaderIter, + _SingleProcessDataLoaderIter, +) + +from lightning.data.cache import Cache +from lightning.data.cache.sampler import CacheBatchSampler +from lightning.data.datasets.env import _DistributedEnv + +logger = logging.Logger(__name__) + + +class CacheCollateFn: + def __init__(self): + self.collate_fn = default_collate + + def __call__(self, items): + if all(item is None for item in items): + return None + return self.collate_fn(items) + + +class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter): + """This is overriden to inform the cache is done chunking.""" + + def _next_data(self): + try: + return super()._next_data() + except StopIteration: + for v in self._dataset_fetcher.dataset.__dict__.values(): + if isinstance(v, Cache): + v.done() + raise StopIteration() + + +class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter): + def __init__(self, loader): + # Patch PyTorch worker loop + from torch.utils.data._utils import worker + + from lightning.data.cache.worker import _worker_loop + + worker._worker_loop = _worker_loop + super().__init__(loader) + + +class CacheDataLoader(DataLoader): + def __init__( + self, + dataset, + *args, + sampler=None, + batch_sampler=None, + num_workers=0, + shuffle: bool = False, + generator=None, + batch_size=1, + drop_last=False, + **kwargs, + ): + if sampler: + raise Exception("Passing a sampler isn't supoprt with the CacheDataLoader.") + + if batch_sampler: + raise Exception("Passing a batch_sampler isn't supoprt with the CacheDataLoader.") + + if isinstance(dataset, IterableDataset): + raise Exception("Only map-based dataset are supported by the CacheDataLoader for now.") + + cache = [v for v in dataset.__dict__.values() if isinstance(v, Cache)] + + if not cache or len(cache) > 1: + raise Exception(f"The CacheDataloader should be used with a dataset using a single cache. Found {cache}.") + + cache = cache[0] + cache._setup(num_workers) + if not cache.filled and shuffle: + logger.info("Shuffle is ignored during caching phase") + + distributed_env = _DistributedEnv.detect() + batch_sampler = CacheBatchSampler( + len(dataset), + distributed_env.world_size, + distributed_env.global_rank, + num_workers, + batch_size, + drop_last, + shuffle, + cache, + ) + + super().__init__( + dataset, + *args, + sampler=None, + batch_sampler=batch_sampler, + generator=generator, + collate_fn=CacheCollateFn(), + num_workers=num_workers, + **kwargs, + ) + + def _get_iterator(self) -> "_BaseDataLoaderIter": + if self.num_workers == 0: + return _SingleProcessDataLoaderIterPatch(self) + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIterPatch(self) diff --git a/src/lightning/data/cache/sampler.py b/src/lightning/data/cache/sampler.py new file mode 100644 index 0000000000000..48b5b9acbb605 --- /dev/null +++ b/src/lightning/data/cache/sampler.py @@ -0,0 +1,289 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
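Once the cache is filled, the `CacheBatchSampler` defined in the sampler module that follows shuffles in a chunk-aware way (`__iter_from_chunks__`): the chunk intervals are permuted first, then the indices are permuted within each chunk, so the global order varies between epochs while reads stay local to one chunk at a time. A small illustration of that strategy with made-up intervals:

    import numpy as np

    chunk_intervals = [[0, 4], [4, 8], [8, 12]]  # illustrative [begin, end) intervals, one per chunk

    indices = []
    for begin, end in np.random.permutation(chunk_intervals):      # shuffle the chunks themselves
        indices.extend(np.random.permutation(np.arange(begin, end)).tolist())  # then shuffle within

    print(indices)  # e.g. [9, 11, 8, 10, 2, 0, 3, 1, 6, 5, 7, 4]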
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Iterator, List + +import numpy as np +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized + +from lightning.data.cache import Cache + +logger = logging.Logger(__name__) + + +class IteratorSampler(Sampler[int]): + r"""Samples elements sequentially, always in the same order. + + Args: + data_source (Dataset): dataset to sample from + + """ + data_source: Sized + + def __init__(self, data_source: Sized) -> None: + self.data_source = data_source + + def __iter__(self) -> Iterator[int]: + return iter(self.data_source) + + def __len__(self) -> int: + return len(self.data_source) + + +class CacheSampler(Sampler): + def __init__(self, dataset_size: int, num_workers: int, batch_size: int): + """The CacheSampler splits the dataset indices into ordered chunks and assign each one of them to a DataLoader + worker. The Cache Writer expects the index to be provided in an ordered fashion. + + Arguments: + dataset_size: The size of the dataset. + num_workers: The number of workers provided to the DataLoader + batch_size: The number of items in a batch + + """ + + super().__init__(None) + self.batch_size = batch_size + self.num_workers = num_workers + self.indices = range(dataset_size) + self.dataset_size = dataset_size + worker_size = dataset_size // self.num_workers + self.samplers = [] + for worker_idx in range(num_workers): + is_last = worker_idx == num_workers - 1 + worker_indices = self.indices[ + worker_idx * worker_size : dataset_size if is_last else (worker_idx + 1) * worker_size + ] + self.samplers.append(IteratorSampler(worker_indices)) + self.iterators = [] + self._done = set() + assert sum([len(s) for s in self.samplers]) == dataset_size + self.worker_id = 0 + self.indice_id = 0 + + def __len__(self) -> int: + return self.dataset_size + + @property + def done(self) -> bool: + return len(self._done) == len(self.iterators) + + def __iter__(self) -> "CacheSampler": + self._done = set() + + for sampler in self.samplers: + self.iterators.append(iter(sampler)) + + return self + + def __next__(self) -> List[int]: + while len(self._done) != self.iterators: + try: + data = next(self.iterators[self.worker_id]) + self.indice_id += 1 + if self.indice_id == self.batch_size: + self.indice_id = 0 + self.worker_id = (self.worker_id + 1) % self.num_workers + return data + except StopIteration: + self._done.add(self.worker_id) + self.indice_id = 0 + self.worker_id = (self.worker_id + 1) % self.num_workers + raise StopIteration + + +class DistributedCacheSampler(Sampler): + def __init__(self, dataset_size: int, num_replicas: int, rank: int, num_workers: int, batch_size: int): + """The DistributedCacheSampler splits the dataset indices into ordered chunks along all the replicas and their + workers. The Cache Writer expects the index to be provided in an ordered fashion. + + Arguments: + dataset_size: The size of the dataset. 
+ num_workers: The number of workers provided to the DataLoader + batch_size: The number of items in a batch + + """ + super().__init__(None) + self.batch_size = batch_size + self.num_workers = num_workers + self.indices = range(dataset_size) + self.dataset_size = dataset_size + replica_size = dataset_size // num_replicas + worker_size = dataset_size // (num_replicas * self.num_workers) + self.samplers = [] + for replica_idx in range(num_replicas): + if replica_idx != rank: + continue + + is_last_replica = replica_idx == num_replicas - 1 + start_replica = replica_idx * replica_size + end_replica = dataset_size if is_last_replica else (replica_idx + 1) * replica_size + replica_indices = self.indices[start_replica:end_replica] + + replica_size = len(replica_indices) + + for worker_idx in range(num_workers): + is_last_worker = worker_idx == num_workers - 1 + start_worker = worker_idx * worker_size + end_worker = replica_size if is_last_worker else (worker_idx + 1) * worker_size + worker_indices = replica_indices[start_worker:end_worker] + self.samplers.append(IteratorSampler(worker_indices)) + + self.iterators = [] + self._done = set() + + assert sum([len(s) for s in self.samplers]) == replica_size + self.worker_id = 0 + self.indice_id = 0 + + def __len__(self) -> int: + return self.dataset_size + + @property + def done(self) -> bool: + return len(self._done) == len(self.iterators) + + def __iter__(self) -> "DistributedCacheSampler": + self._done = set() + + for sampler in self.samplers: + self.iterators.append(iter(sampler)) + + return self + + def __next__(self) -> List[str]: + while len(self._done) != self.iterators: + try: + data = next(self.iterators[self.worker_id]) + self.indice_id += 1 + if self.indice_id == self.batch_size: + self.indice_id = 0 + self.worker_id = (self.worker_id + 1) % self.num_workers + return data + except StopIteration: + self._done.add(self.worker_id) + self.indice_id = 0 + self.worker_id = (self.worker_id + 1) % self.num_workers + raise StopIteration + + +class CacheBatchSampler(BatchSampler): + def __init__( + self, + dataset_size: int, + num_replicas: int, + rank: int, + num_workers: int, + batch_size: int, + drop_last: bool, + shuffle: bool, + cache: Cache, + ): + """The CacheBatchSampler handles the generation of batch indices. + + If the cache isn't filled, the batch sampler alternates with ordered indices for the writer to chunk the dataset + If the cache is filled, it acts as normal BatchSampler. + + Arguments: + dataset_size: The size of the dataset. + num_replicas: The number of processes involves in the distributed training. + rank: The rank of the given process + num_workers: The number of workers provided to the DataLoader. + batch_size: The number of items in a batch. + shuffle: Whether the data should be shuffled. + cache: The cache associated to the dataset. 
+ + """ + + if num_replicas == 1: + if not cache.filled and num_workers > 1: + sampler = CacheSampler(dataset_size, num_workers, batch_size) + elif shuffle: + sampler = RandomSampler(range(dataset_size)) + else: + sampler = SequentialSampler(range(dataset_size)) + else: + if not cache.filled: + sampler = DistributedCacheSampler(dataset_size, num_replicas, rank, num_workers, batch_size) + else: + sampler = DistributedSampler(range(dataset_size), num_replicas=num_replicas, rank=rank, shuffle=shuffle) + super().__init__(sampler, batch_size, drop_last) + self._num_replicas = num_replicas + self._rank = rank + self._cache = cache + self._shuffle = shuffle + self._num_workers = num_workers + + def __iter_ordered__(self) -> Iterator[List[int]]: + # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 + iterator = iter(self.sampler) + batch = [] + while not self.sampler.done: + try: + idx = next(iterator) + batch.append(idx) + if len(batch) == self.batch_size: + yield batch + batch = [] + except StopIteration: + if self.sampler.done: + yield batch + return + yield batch + batch = [] + + def __iter__(self): + if self._cache.filled and self._shuffle: + return self.__iter_from_chunks__() + if self._num_workers > 1 and not self._cache.filled: + return self.__iter_ordered__() + return super().__iter__() + + def __iter_from_chunks__(self): + chunk_intervals = self._cache.get_chunk_interval() + shuffled_chunk_intervals = np.random.permutation(chunk_intervals) + + if self._num_replicas == 1: + indices = [] + for interval in shuffled_chunk_intervals: + interval_indices = np.arange(interval[0], interval[1]) + shuffled_interval_indices = np.random.permutation(interval_indices) + indices.extend(shuffled_interval_indices.tolist()) + + if len(indices) != len(self.sampler): + raise Exception("The generated indices don't match the initial length of the sampler.") + + else: + chunks_per_replica = len(shuffled_chunk_intervals) // self._num_replicas + for replica_idx in range(self._num_replicas): + if replica_idx != self._rank: + continue + is_last_replica = replica_idx == self._num_replicas - 1 + start_replica = replica_idx * chunks_per_replica + end_replica = len(chunk_intervals) if is_last_replica else (replica_idx + 1) * chunks_per_replica + shuffled_chunk_intervals_replica = shuffled_chunk_intervals[start_replica:end_replica] + + indices = [] + for interval in shuffled_chunk_intervals_replica: + interval_indices = np.arange(interval[0], interval[1]) + shuffled_interval_indices = np.random.permutation(interval_indices) + indices.extend(shuffled_interval_indices.tolist()) + + self.sampler = IteratorSampler(indices) + + return super().__iter__() + + def __len__(self) -> int: + return super().__len__() diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py new file mode 100644 index 0000000000000..e69de29bb2d1d From f21a7d88bbd57e83c32276b4ab7950dbd4fe9eca Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 29 Sep 2023 11:23:30 +0100 Subject: [PATCH 30/84] Update src/lightning/data/cache/dataloader.py Co-authored-by: Ethan Harris --- src/lightning/data/cache/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index 9cf4256aa559a..7dccfa4826529 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -78,7 +78,7 @@ def __init__( **kwargs, ): if sampler: - raise Exception("Passing a sampler 
isn't supoprt with the CacheDataLoader.") + raise Exception("Passing a sampler isn't supported with the CacheDataLoader.") if batch_sampler: raise Exception("Passing a batch_sampler isn't supoprt with the CacheDataLoader.") From 9bc88110727cfa9c1a7db7c51d6f573c84d6188d Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 29 Sep 2023 11:23:38 +0100 Subject: [PATCH 31/84] Update src/lightning/data/cache/dataloader.py Co-authored-by: Ethan Harris --- src/lightning/data/cache/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index 7dccfa4826529..8df3e477ccf10 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -81,7 +81,7 @@ def __init__( raise Exception("Passing a sampler isn't supported with the CacheDataLoader.") if batch_sampler: - raise Exception("Passing a batch_sampler isn't supoprt with the CacheDataLoader.") + raise Exception("Passing a batch_sampler isn't supported with the CacheDataLoader.") if isinstance(dataset, IterableDataset): raise Exception("Only map-based dataset are supported by the CacheDataLoader for now.") From 5df5984118274263fe0bb7fb5e0b1f97bc0fd902 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 11:24:44 +0100 Subject: [PATCH 32/84] update --- src/lightning/data/cache/worker.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/src/lightning/data/cache/worker.py b/src/lightning/data/cache/worker.py index f7c8ec5c45146..74534dd09b85b 100644 --- a/src/lightning/data/cache/worker.py +++ b/src/lightning/data/cache/worker.py @@ -1,20 +1,6 @@ -# Copyright The Lightning AI team. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
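During the caching phase with several workers, the `CacheSampler` added in `sampler.py` above hands each DataLoader worker one contiguous, ordered shard of the indices and emits them `batch_size` at a time in round-robin worker order, which lines up with how a multi-worker DataLoader hands out successive batches, so each worker's writer sees its indices in order. The generator below is a simplified re-implementation of just that index ordering; it ignores the distributed variant and the uneven-shard edge cases:

    from typing import Iterator, List


    def cache_sampler_order(dataset_size: int, num_workers: int, batch_size: int) -> Iterator[int]:
        shard = dataset_size // num_workers
        shards: List[List[int]] = [
            list(range(w * shard, dataset_size if w == num_workers - 1 else (w + 1) * shard))
            for w in range(num_workers)
        ]
        positions = [0] * num_workers
        done = set()
        worker = 0
        while len(done) < num_workers:
            taken = 0
            while taken < batch_size and positions[worker] < len(shards[worker]):
                yield shards[worker][positions[worker]]
                positions[worker] += 1
                taken += 1
            if positions[worker] >= len(shards[worker]):
                done.add(worker)
            worker = (worker + 1) % num_workers


    # 10 samples, 2 workers, batch_size 2: worker 0 receives 0..4 in order, worker 1 receives 5..9
    assert list(cache_sampler_order(10, 2, 2)) == [0, 1, 5, 6, 2, 3, 7, 8, 4, 9]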
- -# Copy pasted from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/worker.py + lines 332-342 +# Copy pasted from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/worker.py + lines 318-328 # TODO: Delete me when this is addressed https://github.com/pytorch/pytorch/issues/110156 - import os import queue import random From ab470c3bcfb8cf9e14e6f25c879b948b41467261 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 11:34:30 +0100 Subject: [PATCH 33/84] update --- src/lightning/data/cache/cache.py | 1 + src/lightning/data/cache/dataloader.py | 1 + src/lightning/data/cache/reader.py | 29 +++++++++++--------------- src/lightning/data/cache/writer.py | 27 +++++++++++++++++------- 4 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 2729f665f52cf..499eddca441eb 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -49,6 +49,7 @@ def __init__( self._num_workers: Optional[int] = None def _setup(self, num_workers: int) -> None: + """Called by the CacheDataLoader to ensure the num_workers is known.""" self._num_workers = num_workers @property diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index 8df3e477ccf10..111509ca8cfe2 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -120,6 +120,7 @@ def __init__( ) def _get_iterator(self) -> "_BaseDataLoaderIter": + """Overriden to ensure the `Cache.done` method is triggered on iteration done.""" if self.num_workers == 0: return _SingleProcessDataLoaderIterPatch(self) self.check_worker_number_rationality() diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 7beb75c95f754..87cec2915a5e4 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -18,7 +18,6 @@ import numpy as np from lightning.data.cache.serializers import _SERIALIZERS -from lightning.data.cache.worker import get_worker_info from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv @@ -48,17 +47,9 @@ def __init__(self, cache_dir: str, compression: Optional[str] = None): self._env = _DistributedEnv.detect() self._worker_env: Optional[_WorkerEnv] = None - self._rank: Optional[int] = None - - @property - def rank(self) -> Optional[int]: - if self._rank is None: - self._worker_env = _WorkerEnv.detect(get_worker_info_fn=get_worker_info) - self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank - - return self._rank def _try_read_index(self): + """Try to read the chunks json index files if available.""" files = os.listdir(self._cache_dir) indexes_filepath = sorted([os.path.join(self._cache_dir, f) for f in files if f.endswith("index.json")]) if not indexes_filepath: @@ -82,16 +73,19 @@ def _try_read_index(self): for i in range(len(cumsum_samples) - 1): self._intervals.append([cumsum_samples[i], cumsum_samples[i + 1]]) - def _map_index_to_chunk_id(self, index): + def _map_index_to_chunk_id(self, index: int) -> int: + """Find the associated chunk in which the current index was stored.""" for interval_index, internal in enumerate(self._intervals): if internal[0] <= index and index < internal[1]: return interval_index raise Exception(f"The chunk interval weren't properly defined. 
Found {self._intervals} for inded {index}.") - def _should_keep_in_memory(self): - return True - def read(self, index: int): + """Read an item for the given from a chunk. + + If the chunk isn't available, it will be downloaded. + + """ if self._index is None: self._try_read_index() @@ -101,12 +95,11 @@ def read(self, index: int): chunk_id = self._map_index_to_chunk_id(index) chunk_config = self._index["chunks"][chunk_id] chunk_path = os.path.join(self._cache_dir, chunk_config["filename"]) - raw_item_data, item_config = self.load_item_from_chunk( - index, chunk_path, keep_in_memory=self._should_keep_in_memory() - ) + raw_item_data, item_config = self.load_item_from_chunk(index, chunk_path, keep_in_memory=True) return self.deserialize(raw_item_data, item_config) def deserialize(self, raw_item_data: bytes, item_config: Dict[str, Any]) -> Any: + """Deserialize the raw bytes into their python equivalent.""" sizes = [] idx = 0 data_format = item_config["data_format"] @@ -145,6 +138,7 @@ def load_item_from_chunk(self, index: int, chunk_path: str, keep_in_memory: bool return data, config def get_length(self) -> int: + """Get the number of samples across all chunks.""" if self._index is None: self._try_read_index() @@ -154,6 +148,7 @@ def get_length(self) -> int: return sum([v["samples"] for v in self._index["chunks"]]) def get_chunk_interval(self): + """Get the index interval of each chunks.""" if self._index is None: self._try_read_index() diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 2eccfc5889fef..1ff7a1f5024ec 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -97,13 +97,15 @@ def __init__( @property def rank(self): + """Returns the rank of the writer.""" if self._rank is None: self._worker_env = _WorkerEnv.detect(get_worker_info_fn=get_worker_info) self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank return self._rank def get_config(self) -> Dict[str, Any]: - out = super().get_config() + """Returns the config of the writer.""" + out = {"compression": self._compression, "chunk_size": self._chunk_size, "data_format": self._data_format} out.update(self._data_format) cloud_path = self.get_cloud_path(self._cache_dir) if cloud_path: @@ -114,6 +116,7 @@ def get_config(self) -> Dict[str, Any]: return out def serialize(self, items: Dict[str, Any]) -> bytes: + """Serialize a dictionary into its binary format.""" if not isinstance(items, dict): raise Exception("The provided data should be a dictionary.") @@ -139,6 +142,7 @@ def serialize(self, items: Dict[str, Any]) -> bytes: return head + body def _create_chunk(self, filename: str) -> bytes: + """Create a binary chunk from all the binarized items.""" num_items = np.uint32(len(self._serialized_items)) sizes = list(map(len, self._serialized_items)) offsets = np.array([0] + sizes).cumsum().astype(np.uint32) @@ -164,19 +168,13 @@ def _create_chunk(self, filename: str) -> bytes: return data def write_chunk(self): + """Write a chunk to the filesystem.""" if self._compression: filename = f"chunk-{self.rank}-{self._chunk_id}.{self._compression}.bin" else: filename = f"chunk-{self.rank}-{self._chunk_id}.bin" self.write_file(self._create_chunk(filename), filename) - @property - def is_cached(self) -> bool: - return os.path.exists(os.path.join(self._cache_dir, "index.json")) - - def get_config(self) -> Dict[str, Any]: - return {"compression": self._compression, "chunk_size": self._chunk_size, "data_format": self._data_format} - @property def 
available_serializers(self): return self._serializers @@ -188,6 +186,11 @@ def reset(self) -> None: self._current_chunk_size = 0 def __setitem__(self, index, items: any): + """Store an item to a chunk. + + The index needs to be provided in order. + + """ serialized_items = self.serialize(items) serialized_items_size = len(serialized_items) @@ -214,6 +217,7 @@ def write_file( raw_data: bytes, filename: str, ) -> None: + """Write chunk bytes to a file.""" if self._compression: raw_data = self._compressor.compress(raw_data) filepath = os.path.join(self._cache_dir, filename) @@ -221,11 +225,17 @@ def write_file( out.write(raw_data) def write_chunks_index(self): + """Write the chunks index to a JSON file.""" filepath = os.path.join(self._cache_dir, f"{self.rank}.index.json") with open(filepath, "w") as out: json.dump({"chunks": self._chunks_info}, out, sort_keys=True) def done(self): + """Called when StopIteration is triggered. + + It tries to save the last chunk and write the chunks index. + + """ if self._is_done: return if self._serialized_items: @@ -236,6 +246,7 @@ def done(self): @classmethod def get_cloud_path(cls, cache_dir: str) -> Optional[str]: + """Returns the s3 URL to the cache_dir.""" cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) From f6d8184f44547cd590aacfce679df85f70a9bfb7 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 12:20:28 +0100 Subject: [PATCH 34/84] update --- src/lightning/data/cache/sampler.py | 113 +++++++--------- src/lightning/data/datasets/env.py | 7 +- tests/tests_data/cache/test_sampler.py | 177 +++++++++++++++++++++++++ 3 files changed, 232 insertions(+), 65 deletions(-) diff --git a/src/lightning/data/cache/sampler.py b/src/lightning/data/cache/sampler.py index 48b5b9acbb605..5808bf93cf024 100644 --- a/src/lightning/data/cache/sampler.py +++ b/src/lightning/data/cache/sampler.py @@ -42,36 +42,14 @@ def __len__(self) -> int: return len(self.data_source) -class CacheSampler(Sampler): - def __init__(self, dataset_size: int, num_workers: int, batch_size: int): - """The CacheSampler splits the dataset indices into ordered chunks and assign each one of them to a DataLoader - worker. The Cache Writer expects the index to be provided in an ordered fashion. - - Arguments: - dataset_size: The size of the dataset. 
- num_workers: The number of workers provided to the DataLoader - batch_size: The number of items in a batch - - """ - +class BaseCacheSampler(Sampler): + def __init__(self, dataset_size: int): super().__init__(None) - self.batch_size = batch_size - self.num_workers = num_workers - self.indices = range(dataset_size) self.dataset_size = dataset_size - worker_size = dataset_size // self.num_workers - self.samplers = [] - for worker_idx in range(num_workers): - is_last = worker_idx == num_workers - 1 - worker_indices = self.indices[ - worker_idx * worker_size : dataset_size if is_last else (worker_idx + 1) * worker_size - ] - self.samplers.append(IteratorSampler(worker_indices)) - self.iterators = [] - self._done = set() - assert sum([len(s) for s in self.samplers]) == dataset_size self.worker_id = 0 self.indice_id = 0 + self.iterators = [] + self._done = set() def __len__(self) -> int: return self.dataset_size @@ -80,7 +58,7 @@ def __len__(self) -> int: def done(self) -> bool: return len(self._done) == len(self.iterators) - def __iter__(self) -> "CacheSampler": + def __iter__(self) -> "BaseCacheSampler": self._done = set() for sampler in self.samplers: @@ -88,6 +66,17 @@ def __iter__(self) -> "CacheSampler": return self + def _next_worker_id(self): + if self.done: + return + counter = 1 + while True: + next_worker_id = (self.worker_id + counter) % self.num_workers + if next_worker_id not in self._done: + self.worker_id = next_worker_id + break + counter += 1 + def __next__(self) -> List[int]: while len(self._done) != self.iterators: try: @@ -95,16 +84,47 @@ def __next__(self) -> List[int]: self.indice_id += 1 if self.indice_id == self.batch_size: self.indice_id = 0 - self.worker_id = (self.worker_id + 1) % self.num_workers + self._next_worker_id() return data except StopIteration: self._done.add(self.worker_id) self.indice_id = 0 - self.worker_id = (self.worker_id + 1) % self.num_workers + self._next_worker_id() raise StopIteration -class DistributedCacheSampler(Sampler): +class CacheSampler(BaseCacheSampler): + def __init__(self, dataset_size: int, num_workers: int, batch_size: int): + """The CacheSampler splits the dataset indices into ordered chunks and assign each one of them to a DataLoader + worker. The Cache Writer expects the index to be provided in an ordered fashion. + + Arguments: + dataset_size: The size of the dataset. + num_workers: The number of workers provided to the DataLoader + batch_size: The number of items in a batch + + """ + + super().__init__(dataset_size) + self.batch_size = batch_size + self.num_workers = num_workers + self.indices = range(dataset_size) + worker_size = dataset_size // self.num_workers + self.samplers = [] + for worker_idx in range(num_workers): + is_last = worker_idx == num_workers - 1 + worker_indices = self.indices[ + worker_idx * worker_size : dataset_size if is_last else (worker_idx + 1) * worker_size + ] + self.samplers.append(IteratorSampler(worker_indices)) + self.iterators = [] + self._done = set() + assert sum([len(s) for s in self.samplers]) == dataset_size + self.worker_id = 0 + self.indice_id = 0 + + +class DistributedCacheSampler(BaseCacheSampler): def __init__(self, dataset_size: int, num_replicas: int, rank: int, num_workers: int, batch_size: int): """The DistributedCacheSampler splits the dataset indices into ordered chunks along all the replicas and their workers. The Cache Writer expects the index to be provided in an ordered fashion. 
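
A minimal standalone sketch, not part of the patch itself, may help make the splitting scheme described in the sampler docstrings above concrete: each DataLoader worker receives one contiguous, ordered slice of the dataset indices, so the cache writer always sees indices in order. The helper name split_indices_across_workers is ours; in the patch this logic lives in the samplers' __init__.

from typing import List


def split_indices_across_workers(dataset_size: int, num_workers: int) -> List[range]:
    # Each worker gets a contiguous slice of size dataset_size // num_workers;
    # the last worker absorbs the remainder, mirroring CacheSampler.__init__ above.
    worker_size = dataset_size // num_workers
    ranges = []
    for worker_idx in range(num_workers):
        is_last = worker_idx == num_workers - 1
        end = dataset_size if is_last else (worker_idx + 1) * worker_size
        ranges.append(range(worker_idx * worker_size, end))
    return ranges


# Matches the expectations of the sampler tests added later in this patch,
# e.g. 21 samples over 3 workers -> [0, 7), [7, 14), [14, 21).
assert split_indices_across_workers(21, 3) == [range(0, 7), range(7, 14), range(14, 21)]
assert split_indices_across_workers(5, 3) == [range(0, 1), range(1, 2), range(2, 5)]

The batching side then visits these per-worker iterators round-robin, batch_size items at a time, skipping workers that are already exhausted (see _next_worker_id above).
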
@@ -115,11 +135,10 @@ def __init__(self, dataset_size: int, num_replicas: int, rank: int, num_workers: batch_size: The number of items in a batch """ - super().__init__(None) + super().__init__(dataset_size) self.batch_size = batch_size self.num_workers = num_workers self.indices = range(dataset_size) - self.dataset_size = dataset_size replica_size = dataset_size // num_replicas worker_size = dataset_size // (num_replicas * self.num_workers) self.samplers = [] @@ -148,36 +167,6 @@ def __init__(self, dataset_size: int, num_replicas: int, rank: int, num_workers: self.worker_id = 0 self.indice_id = 0 - def __len__(self) -> int: - return self.dataset_size - - @property - def done(self) -> bool: - return len(self._done) == len(self.iterators) - - def __iter__(self) -> "DistributedCacheSampler": - self._done = set() - - for sampler in self.samplers: - self.iterators.append(iter(sampler)) - - return self - - def __next__(self) -> List[str]: - while len(self._done) != self.iterators: - try: - data = next(self.iterators[self.worker_id]) - self.indice_id += 1 - if self.indice_id == self.batch_size: - self.indice_id = 0 - self.worker_id = (self.worker_id + 1) % self.num_workers - return data - except StopIteration: - self._done.add(self.worker_id) - self.indice_id = 0 - self.worker_id = (self.worker_id + 1) % self.num_workers - raise StopIteration - class CacheBatchSampler(BatchSampler): def __init__( diff --git a/src/lightning/data/datasets/env.py b/src/lightning/data/datasets/env.py index b5be69c01d8a3..e369448ff8cab 100644 --- a/src/lightning/data/datasets/env.py +++ b/src/lightning/data/datasets/env.py @@ -1,7 +1,7 @@ from typing import Callable, Optional import torch -from torch.utils.data import get_worker_info +from torch.utils.data import get_worker_info as torch_get_worker_info class _DistributedEnv: @@ -60,7 +60,7 @@ def __init__(self, world_size: int, rank: int): self.rank = rank @classmethod - def detect(cls, get_worker_info_fn: Optional[Callable] = get_worker_info) -> "_WorkerEnv": + def detect(cls, get_worker_info_fn: Optional[Callable] = None) -> "_WorkerEnv": """Automatically detects the number of workers and the current rank. 
Note: @@ -68,7 +68,8 @@ def detect(cls, get_worker_info_fn: Optional[Callable] = get_worker_info) -> "_W In such a case it will default to 1 worker """ - worker_info = get_worker_info_fn() + get_worker_info = get_worker_info_fn or torch_get_worker_info + worker_info = get_worker_info() num_workers = worker_info.num_workers if worker_info is not None else 1 current_worker_rank = worker_info.id if worker_info is not None else 0 diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py index e69de29bb2d1d..46ed5c8f81802 100644 --- a/tests/tests_data/cache/test_sampler.py +++ b/tests/tests_data/cache/test_sampler.py @@ -0,0 +1,177 @@ +import pytest +from lightning.data.cache.sampler import CacheSampler, DistributedCacheSampler + + +def test_cache_sampler_sampling(): + """Valides the CacheSampler can return batch of data in an ordered way.""" + dataset_size = 17 + sampler = CacheSampler(dataset_size, 3, 3) + iter_sampler = iter(sampler) + + all_indexes = [] + indexes = [] + while True: + try: + index = next(iter_sampler) + indexes.append(index) + all_indexes.append(index) + except StopIteration: + assert indexes == [0, 1, 2, 5, 6, 7, 10, 11, 12, 3, 4] + assert sampler._done == {0} + break + + indexes = [] + while True: + try: + index = next(iter_sampler) + indexes.append(index) + all_indexes.append(index) + except StopIteration: + assert indexes == [8, 9] + assert sampler._done == {0, 1} + break + + indexes = [] + while True: + try: + index = next(iter_sampler) + indexes.append(index) + all_indexes.append(index) + except StopIteration: + assert indexes == [13, 14, 15, 16] + assert sampler._done == {0, 1, 2} + break + + assert sorted(all_indexes) == list(range(dataset_size)) + + +@pytest.mark.parametrize( + "params", + [ + (21, range(0, 7), range(7, 14), range(14, 21)), + (23, range(0, 7), range(7, 14), range(14, 23)), + (33, range(0, 11), range(11, 22), range(22, 33)), + (49, range(0, 16), range(16, 32), range(32, 49)), + (5, range(0, 1), range(1, 2), range(2, 5)), + (12, range(0, 4), range(4, 8), range(8, 12)), + ], +) +def test_cache_sampler_samplers(params): + sampler = CacheSampler(params[0], 3, 3) + assert sampler.samplers[0].data_source == params[1] + assert sampler.samplers[1].data_source == params[2] + assert sampler.samplers[2].data_source == params[3] + + +@pytest.mark.parametrize( + "params", + [ + ( + 102, + 2, + [ + [range(0, 17), range(17, 34), range(34, 51)], + [range(51, 68), range(68, 85), range(85, 102)], + ], + ), + ( + 227, + 5, + [ + [range(0, 15), range(15, 30), range(30, 45)], + [range(45, 60), range(60, 75), range(75, 90)], + [range(90, 105), range(105, 120), range(120, 135)], + [range(135, 150), range(150, 165), range(165, 180)], + [range(180, 195), range(195, 210), range(210, 227)], + ], + ), + ( + 1025, + 7, + [ + [range(0, 48), range(48, 96), range(96, 146)], + [range(146, 194), range(194, 242), range(242, 292)], + [range(292, 340), range(340, 388), range(388, 438)], + [range(438, 486), range(486, 534), range(534, 584)], + [range(584, 632), range(632, 680), range(680, 730)], + [range(730, 778), range(778, 826), range(826, 876)], + [range(876, 924), range(924, 972), range(972, 1025)], + ], + ), + ( + 323, + 2, + [ + [range(0, 53), range(53, 106), range(106, 161)], + [range(161, 214), range(214, 267), range(267, 323)], + ], + ), + ( + 23, + 3, + [ + [range(0, 2), range(2, 4), range(4, 7)], + [range(7, 9), range(9, 11), range(11, 14)], + [range(14, 16), range(16, 18), range(18, 23)], + ], + ), + ( + 45, + 2, + [ + [range(0, 7), 
range(7, 14), range(14, 22)], + [range(22, 29), range(29, 36), range(36, 45)], + ], + ), + ], +) +def test_cache_distributed_sampler_samplers(params): + """This test validates the sub-samplers of the DistributedCacheSampler has the right sampling intervals.""" + for rank in range(params[1]): + sampler = DistributedCacheSampler(params[0], params[1], rank, 3, 3) + assert sampler.samplers[0].data_source == params[2][rank][0] + assert sampler.samplers[1].data_source == params[2][rank][1] + assert sampler.samplers[2].data_source == params[2][rank][2] + + +def test_cache_distributed_sampler_sampling(): + """Valides the DistributedCacheSampler can return batch of data in an ordered way.""" + dataset_size = 129 + sampler = DistributedCacheSampler(dataset_size, 5, 3, 3) + iter_sampler = iter(sampler) + + all_indexes = [] + indexes = [] + while True: + try: + index = next(iter_sampler) + indexes.append(index) + all_indexes.append(index) + except StopIteration: + assert indexes == [0, 1, 2, 5, 6, 7, 10, 11, 12, 3, 4] + assert sampler._done == {0} + break + + indexes = [] + while True: + try: + index = next(iter_sampler) + indexes.append(index) + all_indexes.append(index) + except StopIteration: + assert indexes == [8, 9] + assert sampler._done == {0, 1} + break + + indexes = [] + while True: + try: + index = next(iter_sampler) + indexes.append(index) + all_indexes.append(index) + except StopIteration: + assert indexes == [13, 14, 15, 16] + assert sampler._done == {0, 1, 2} + break + + assert sorted(all_indexes) == list(range(dataset_size)) From 41c2ba9790cd7691a36c01defbc73dffa9dc1d43 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 12:50:42 +0100 Subject: [PATCH 35/84] update --- tests/tests_data/cache/test_sampler.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py index 46ed5c8f81802..e4ba606d8a3cc 100644 --- a/tests/tests_data/cache/test_sampler.py +++ b/tests/tests_data/cache/test_sampler.py @@ -1,5 +1,7 @@ +from unittest import mock + import pytest -from lightning.data.cache.sampler import CacheSampler, DistributedCacheSampler +from lightning.data.cache.sampler import CacheBatchSampler, CacheSampler, DistributedCacheSampler def test_cache_sampler_sampling(): @@ -175,3 +177,25 @@ def test_cache_distributed_sampler_sampling(): break assert sorted(all_indexes) == list(range(dataset_size)) + + +@pytest.mark.parametrize( + "params", + [ + (21, [[0, 1, 2], [7, 8, 9], [14, 15, 16], [3, 4, 5], [10, 11, 12], [17, 18, 19], [6], [13], [20]]), + (11, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], [9, 10]]), + (8, [[0, 1], [2, 3], [4, 5, 6], [7]]), + (4, [[0], [1], [2, 3]]), + (9, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], []]), + (19, [[0, 1, 2], [6, 7, 8], [12, 13, 14], [3, 4, 5], [9, 10, 11], [15, 16, 17], [], [], [18]]), + ], +) +def test_cache_batch_sampler(params): + cache = mock.MagicMock() + cache.filled = False + batch_sampler = CacheBatchSampler(params[0], 1, 0, 3, 3, False, True, cache) + batches = [] + for batch in batch_sampler: + batches.append(batch) + print(batches) + assert batches == params[1] From b642fd35b249ddda6a7023007c52038ef02ed54a Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 15:00:00 +0100 Subject: [PATCH 36/84] updatte --- src/lightning/data/cache/cache.py | 2 +- src/lightning/data/cache/pytree.py | 567 ++++++++++++++++++++++++ src/lightning/data/cache/reader.py | 158 ++++--- src/lightning/data/cache/serializers.py | 13 + 
src/lightning/data/cache/writer.py | 121 +++-- tests/tests_data/cache/test_cache.py | 28 ++ 6 files changed, 775 insertions(+), 114 deletions(-) create mode 100644 src/lightning/data/cache/pytree.py diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 499eddca441eb..9616c979915c0 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -26,7 +26,7 @@ class Cache: def __init__( self, cache_dir: str, - data_format: Union[Dict[str, any], str], + data_format: Union[Dict[str, any], str] = None, compression: Optional[str] = None, chunk_size: int = 2 << 26, ): diff --git a/src/lightning/data/cache/pytree.py b/src/lightning/data/cache/pytree.py new file mode 100644 index 0000000000000..06b6fd44534f9 --- /dev/null +++ b/src/lightning/data/cache/pytree.py @@ -0,0 +1,567 @@ +# Taken from PyTorch https://github.com/pytorch/pytorch/blob/e9ebda29d87ce0916ab08c06ab26fd3766a870e5/torch/utils/_pytree.py +# This should be available in 2.0.2 +# TODO: Remove me when open sourced. + +import dataclasses +import json +import warnings +from collections import OrderedDict, namedtuple +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union, cast, overload + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") +R = TypeVar("R") + +DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1 + +Context = Any +PyTree = Any +FlattenFunc = Callable[[PyTree], Tuple[List, Context]] +UnflattenFunc = Callable[[List, Context], PyTree] +DumpableContext = Any # Any json dumpable text +ToDumpableContextFn = Callable[[Context], DumpableContext] +FromDumpableContextFn = Callable[[DumpableContext], Context] +ToStrFunc = Callable[["TreeSpec", List[str]], str] +MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]] + + +# A NodeDef holds two callables: +# - flatten_fn should take the collection and return a flat list of values. +# It can also return some context that is used in reconstructing the +# collection. +# - unflatten_fn should take a flat list of values and some context +# (returned by flatten_fn). It returns the collection by reconstructing +# it from the list and the context. +class NodeDef(NamedTuple): + type: Type[Any] + flatten_fn: FlattenFunc + unflatten_fn: UnflattenFunc + + +SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} + + +# _SerializeNodeDef holds the following: +# - typ: the type of the node (e.g., "Dict", "List", etc) +# - type_fqn: the fully qualified name of the type, e.g. 
"collections.OrderedDict" +# - to_dumpable_context takes a TreeSpec, and returns a serialized string format of the +# context, and the version number +# - from_dumpable_context takes in a string representation of the context, and the +# version, and returns the deserialized context +class _SerializeNodeDef(NamedTuple): + typ: Type[Any] + type_fqn: str + to_dumpable_context: Optional[ToDumpableContextFn] + from_dumpable_context: Optional[FromDumpableContextFn] + + +SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {} +SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {} + + +def _register_pytree_node( + typ: Any, + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + to_str_fn: Optional[ToStrFunc] = None, + maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, + *, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, +) -> None: + """ + Args: + typ: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattedn pytree. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + """ + if to_str_fn is not None or maybe_from_str_fn is not None: + warnings.warn( + "to_str_fn and maybe_from_str_fn is deprecated. " + "Please use to_dumpable_context and from_dumpable_context instead." 
+ ) + + node_def = NodeDef( + typ, + flatten_fn, + unflatten_fn, + ) + SUPPORTED_NODES[typ] = node_def + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError(f"Both to_dumpable_context and from_dumpable_context for {typ} must " "be None or registered.") + + type_fqn = f"{typ.__module__}.{typ.__name__}" + serialize_node_def = _SerializeNodeDef(typ, type_fqn, to_dumpable_context, from_dumpable_context) + SUPPORTED_SERIALIZED_TYPES[typ] = serialize_node_def + SERIALIZED_TYPE_TO_PYTHON_TYPE[type_fqn] = typ + + +def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: + return dict(zip(context, values)) + + +def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: + return d, None + + +def _list_unflatten(values: List[Any], context: Context) -> List[Any]: + return list(values) + + +def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]: + return list(d), None + + +def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]: + return tuple(values) + + +def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]: + return list(d), type(d) + + +def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple: + return cast(NamedTuple, context(*values)) + + +def _namedtuple_serialize(context: Context) -> DumpableContext: + json_namedtuple = { + "class_name": context.__name__, + "fields": context._fields, + } + return json_namedtuple + + +def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context: + class_name = dumpable_context["class_name"] + assert isinstance(class_name, str) + context = namedtuple(class_name, dumpable_context["fields"]) # type: ignore[misc] + return context + + +def _odict_flatten(d: "OrderedDict[Any, Any]") -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _odict_unflatten(values: List[Any], context: Context) -> "OrderedDict[Any, Any]": + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_register_pytree_node(dict, _dict_flatten, _dict_unflatten) +_register_pytree_node(list, _list_flatten, _list_unflatten) +_register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten) +_register_pytree_node( + namedtuple, + _namedtuple_flatten, + _namedtuple_unflatten, + to_dumpable_context=_namedtuple_serialize, + from_dumpable_context=_namedtuple_deserialize, +) +_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten) + + +# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple +def _is_namedtuple_instance(pytree: Any) -> bool: + typ = type(pytree) + bases = typ.__bases__ + if len(bases) != 1 or bases[0] != tuple: + return False + fields = getattr(typ, "_fields", None) + if not isinstance(fields, tuple): + return False + return all(type(entry) == str for entry in fields) + + +def _get_node_type(pytree: Any) -> Any: + if _is_namedtuple_instance(pytree): + return namedtuple + return type(pytree) + + +# A leaf is defined as anything that is not a Node. +def _is_leaf(pytree: PyTree) -> bool: + return _get_node_type(pytree) not in SUPPORTED_NODES + + +# A TreeSpec represents the structure of a pytree. 
It holds: +# "type": the type of root Node of the pytree +# context: some context that is useful in unflattening the pytree +# children_specs: specs for each child of the root Node +# num_leaves: the number of leaves +@dataclasses.dataclass +class TreeSpec: + type: Any + context: Context + children_specs: List["TreeSpec"] + + def __post_init__(self) -> None: + self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) + + def __repr__(self, indent: int = 0) -> str: + repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" + children_specs_str: str = "" + if len(self.children_specs): + indent += 2 + children_specs_str += self.children_specs[0].__repr__(indent) + children_specs_str += "," if len(self.children_specs) > 1 else "" + children_specs_str += ",".join( + ["\n" + " " * indent + child.__repr__(indent) for child in self.children_specs[1:]] + ) + repr_suffix: str = f"{children_specs_str}])" + return repr_prefix + repr_suffix + + +class LeafSpec(TreeSpec): + def __init__(self) -> None: + super().__init__(None, None, []) + self.num_leaves = 1 + + def __repr__(self, indent: int = 0) -> str: + return "*" + + +def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values and a TreeSpec that can be used to reconstruct the pytree.""" + if _is_leaf(pytree): + return [pytree], LeafSpec() + + node_type = _get_node_type(pytree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(pytree) + + # Recursively flatten the children + result: List[Any] = [] + children_specs: List[TreeSpec] = [] + for child in child_pytrees: + flat, child_spec = tree_flatten(child) + result += flat + children_specs.append(child_spec) + + return result, TreeSpec(node_type, context, children_specs) + + +def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: + """Given a list of values and a TreeSpec, builds a pytree. + + This is the inverse operation of `tree_flatten`. + + """ + if not isinstance(spec, TreeSpec): + raise TypeError( + f"tree_unflatten(values, spec): Expected `spec` to be instance of " + f"TreeSpec but got item of type {type(spec)}." + ) + if len(values) != spec.num_leaves: + raise ValueError( + f"tree_unflatten(values, spec): `values` has length {len(values)} " + f"but the spec refers to a pytree that holds {spec.num_leaves} " + f"items ({spec})." + ) + if isinstance(spec, LeafSpec): + return values[0] + + unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn + + # Recursively unflatten the children + start = 0 + end = 0 + child_pytrees = [] + for child_spec in spec.children_specs: + end += child_spec.num_leaves + child_pytrees.append(tree_unflatten(values[start:end], child_spec)) + start = end + + return unflatten_fn(child_pytrees, spec.context) + + +def tree_map(fn: Any, pytree: PyTree) -> PyTree: + flat_args, spec = tree_flatten(pytree) + return tree_unflatten([fn(i) for i in flat_args], spec) + + +Type2 = Tuple[Type[T], Type[S]] +Type3 = Tuple[Type[T], Type[S], Type[U]] +TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] + +Fn3 = Callable[[Union[T, S, U]], R] +Fn2 = Callable[[Union[T, S]], R] +Fn = Callable[[T], R] +FnAny = Callable[[Any], R] + +MapOnlyFn = Callable[[T], Callable[[Any], Any]] + + +# These specializations help with type inference on the lambda passed to this +# function +@overload +def map_only(ty: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]: + ... + + +@overload +def map_only(ty: Type[T]) -> MapOnlyFn[Fn[T, Any]]: + ... 
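
As a short aside, a round-trip sketch (ours, not part of the vendored file) shows how the reader and writer below use these helpers: tree_flatten turns an arbitrary nested sample into a flat list of leaves plus a TreeSpec, and tree_unflatten rebuilds the sample from them. The import path assumes this patch is applied; the same functions exist upstream in torch.utils._pytree.

from lightning.data.cache.pytree import tree_flatten, tree_unflatten

sample = {"image": b"\x00\x01", "label": 7, "meta": [1.0, 2.0]}

# The leaves come back in a deterministic order, together with a TreeSpec
# describing the container structure (dict keys, list length, and so on).
leaves, spec = tree_flatten(sample)
assert leaves == [b"\x00\x01", 7, 1.0, 2.0]

# The spec alone is enough to rebuild the original nested structure.
assert tree_unflatten(leaves, spec) == sample
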
+ + +# This specialization is needed for the implementations below that call +@overload +def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]: + ... + + +def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]: + """Suppose you are writing a tree_map over tensors, leaving everything else unchanged. Ordinarily you would have + to write: + + def go(t): + if isinstance(t, Tensor): + return ... + else: + return t + + With this function, you only need to write: + + @map_only(Tensor) + def go(t): + return ... + + You can also directly use 'tree_map_only' + + """ + + def deco(f: Callable[[T], Any]) -> Callable[[Any], Any]: + def inner(x: T) -> Any: + if isinstance(x, ty): + return f(x) + return x + + return inner + + return deco + + +@overload +def tree_map_only(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree: + ... + + +@overload +def tree_map_only(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree: + ... + + +@overload +def tree_map_only(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) -> PyTree: + ... + + +def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree: + return tree_map(map_only(ty)(fn), pytree) + + +def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool: + flat_args, _ = tree_flatten(pytree) + return all(map(pred, flat_args)) + + +def tree_any(pred: Callable[[Any], bool], pytree: PyTree) -> bool: + flat_args, _ = tree_flatten(pytree) + return any(map(pred, flat_args)) + + +@overload +def tree_all_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool: + ... + + +@overload +def tree_all_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool: + ... + + +@overload +def tree_all_only(ty: Type3[T, S, U], pred: Fn3[T, S, U, bool], pytree: PyTree) -> bool: + ... + + +def tree_all_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool: + flat_args, _ = tree_flatten(pytree) + return all(pred(x) for x in flat_args if isinstance(x, ty)) + + +@overload +def tree_any_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool: + ... + + +@overload +def tree_any_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool: + ... + + +def tree_any_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool: + flat_args, _ = tree_flatten(pytree) + return any(pred(x) for x in flat_args if isinstance(x, ty)) + + +# Broadcasts a pytree to the provided TreeSpec and returns the flattened +# values. If this is not possible, then this function returns None. +# +# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), +# would return [0, 0]. This is useful for part of the vmap implementation: +# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be +# broadcastable to the tree structure of `inputs` and we use +# _broadcast_to_and_flatten to check this. 
+def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]: + assert isinstance(spec, TreeSpec) + + if _is_leaf(pytree): + return [pytree] * spec.num_leaves + if isinstance(spec, LeafSpec): + return None + node_type = _get_node_type(pytree) + if node_type != spec.type: + return None + + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, ctx = flatten_fn(pytree) + + # Check if the Node is different from the spec + if len(child_pytrees) != len(spec.children_specs) or ctx != spec.context: + return None + + # Recursively flatten the children + result: List[Any] = [] + for child, child_spec in zip(child_pytrees, spec.children_specs): + flat = _broadcast_to_and_flatten(child, child_spec) + if flat is not None: + result += flat + else: + return None + + return result + + +@dataclasses.dataclass +class _TreeSpecSchema: + type: Optional[str] + context: DumpableContext + children_spec: List["_TreeSpecSchema"] + + +class _ProtocolFn(NamedTuple): + treespec_to_json: Callable[[TreeSpec], DumpableContext] + json_to_treespec: Callable[[DumpableContext], TreeSpec] + + +_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {} + + +def _treespec_to_json(spec: TreeSpec) -> _TreeSpecSchema: + if isinstance(spec, LeafSpec): + return _TreeSpecSchema(None, None, []) + + if spec.type not in SUPPORTED_SERIALIZED_TYPES: + raise NotImplementedError(f"Serializing {spec.type} in pytree is not registered.") + + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[spec.type] + + type_fqn = serialize_node_def.type_fqn + + if serialize_node_def.to_dumpable_context is None: + try: + serialized_context = json.dumps(spec.context) + except TypeError as e: + raise TypeError( + "Unable to serialize context. " + "Please make the context json dump-able, or register a " + "custom serializer using _register_pytree_node." + ) from e + else: + serialized_context = serialize_node_def.to_dumpable_context(spec.context) + + child_schemas = [_treespec_to_json(child) for child in spec.children_specs] + + return _TreeSpecSchema(type_fqn, serialized_context, child_schemas) + + +def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: + if json_schema["type"] is None and json_schema["context"] is None and len(json_schema["children_spec"]) == 0: + return LeafSpec() + + if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: + raise NotImplementedError(f'Deserializing {json_schema["type"]} in pytree is not registered.') + + typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ] + + if serialize_node_def.from_dumpable_context is None: + try: + context = json.loads(json_schema["context"]) + except TypeError: + raise TypeError( + "Unable to deserialize context. " + "Please make the context json load-able, or register a " + "custom serializer using _register_pytree_node." + ) + else: + context = serialize_node_def.from_dumpable_context(json_schema["context"]) + + children_spec = [] + for child_string in json_schema["children_spec"]: + children_spec.append(_json_to_treespec(child_string)) + + return TreeSpec(typ, context, children_spec) + + +_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec) + + +def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: + if protocol is None: + protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL + + if protocol in _SUPPORTED_PROTOCOLS: + json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec) + else: + raise ValueError(f"Unknown protocol {protocol}. 
Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}") + + str_spec = json.dumps((protocol, dataclasses.asdict(json_spec))) + return str_spec + + +def treespec_loads(data: str) -> TreeSpec: + protocol, json_schema = json.loads(data) + + if protocol in _SUPPORTED_PROTOCOLS: + return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema) + raise ValueError(f"Unknown protocol {protocol}. Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}") + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def pytree_to_str(spec: TreeSpec) -> str: + warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps") + return treespec_dumps(spec) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def str_to_pytree(json: str) -> TreeSpec: + warnings.warn("str_to_pytree is deprecated. Please use treespec_loads") + return treespec_loads(json) diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 87cec2915a5e4..19ad61435eb8c 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -13,14 +13,72 @@ import json import os -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple import numpy as np -from lightning.data.cache.serializers import _SERIALIZERS +from lightning.data.cache.pytree import tree_flatten, tree_unflatten, treespec_loads +from lightning.data.cache.serializers import _SERIALIZERS, Serializer from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv +class ChunksConfig: + def __init__(self, cache_dir: str, index_filenames: str): + self._cache_dir = cache_dir + self.index_filenames = sorted(index_filenames) + self._intervals = [] + self._config = None + self._chunks = [] + + for filename in self.index_filenames: + with open(os.path.join(self._cache_dir, filename)) as f: + data = json.load(f) + + if self._config is None: + self._config = data["config"] + self._config["data_spec"] = treespec_loads(self._config["data_spec"]) + flattened_data_format, _ = tree_flatten(self._config["data_format"]) + self._config["flattened_data_format"] = flattened_data_format + + elif self._config != data["config"]: + raise Exception("The config isn't consistent between chunks. This shouldn't have happened.") + + self._chunks.extend(data["chunks"]) + + for chunk in self._chunks: + start, end = chunk["interval"] + if (end - start + 1) != chunk["samples"]: + raise Exception( + "The config intervals doesn't match the number of samples. This shouldn't have happened." + ) + self._intervals.append(chunk["interval"]) + + @property + def intervals(self): + return self._intervals + + @property + def config(self): + return self._config + + def __getitem__(self, index: int) -> Tuple[str, int, int]: + """Find the associated chunk metadata.""" + for interval_index, internal in enumerate(self._intervals): + if internal[0] <= index and index <= internal[1]: + chunk = self._chunks[interval_index] + mapping = chunk["mapping"][str(index)] + return os.path.join(self._cache_dir, chunk["filename"]), *mapping + raise Exception(f"The chunk interval weren't properly defined. 
Found {self._intervals} for index {index}.") + + @classmethod + def load(cls, cache_dir: str) -> Optional["ChunksConfig"]: + files = os.listdir(cache_dir) + index_filenames = sorted([f for f in files if f.endswith("index.json")]) + if not index_filenames: + return None + return ChunksConfig(cache_dir, index_filenames) + + class BinaryReader: def __init__(self, cache_dir: str, compression: Optional[str] = None): """The BinaryReader enables to read chunked dataset in an efficient way. @@ -41,44 +99,17 @@ def __init__(self, cache_dir: str, compression: Optional[str] = None): self._index = None self._intervals = None - # TODO: Use a chunk class self._chunks_data = {} - self._serializers = _SERIALIZERS + self._serializers: Dict[str, Serializer] = _SERIALIZERS self._env = _DistributedEnv.detect() self._worker_env: Optional[_WorkerEnv] = None - def _try_read_index(self): - """Try to read the chunks json index files if available.""" - files = os.listdir(self._cache_dir) - indexes_filepath = sorted([os.path.join(self._cache_dir, f) for f in files if f.endswith("index.json")]) - if not indexes_filepath: - return - - index = {"chunks": []} - for path in indexes_filepath: - with open(path) as f: - data = json.load(f) - index["chunks"].extend(data["chunks"]) - - self._index = index + self._config: Optional[ChunksConfig] = None - for chunk in self._index["chunks"]: - chunk["data"] = None - self._chunks_data[chunk["filename"]] = chunk - - self._intervals = [] - num_samples = [v["samples"] for v in self._index["chunks"]] - cumsum_samples = np.cumsum([0] + num_samples) - for i in range(len(cumsum_samples) - 1): - self._intervals.append([cumsum_samples[i], cumsum_samples[i + 1]]) - - def _map_index_to_chunk_id(self, index: int) -> int: - """Find the associated chunk in which the current index was stored.""" - for interval_index, internal in enumerate(self._intervals): - if internal[0] <= index and index < internal[1]: - return interval_index - raise Exception(f"The chunk interval weren't properly defined. Found {self._intervals} for inded {index}.") + def _try_load_config(self): + """Try to load the chunks config if the index files are available.""" + self._config = ChunksConfig.load(self._cache_dir) def read(self, index: int): """Read an item for the given from a chunk. 
@@ -87,60 +118,51 @@ def read(self, index: int): """ if self._index is None: - self._try_read_index() + self._try_load_config() - if self._index is None: + if self._config is None: raise Exception("The reader index isn't defined.") - chunk_id = self._map_index_to_chunk_id(index) - chunk_config = self._index["chunks"][chunk_id] - chunk_path = os.path.join(self._cache_dir, chunk_config["filename"]) - raw_item_data, item_config = self.load_item_from_chunk(index, chunk_path, keep_in_memory=True) - return self.deserialize(raw_item_data, item_config) + chunk_filepath, begin, end = self._config[index] + raw_item_data = self.load_item_from_chunk(chunk_filepath, begin, end, keep_in_memory=True) + return self.deserialize(raw_item_data) - def deserialize(self, raw_item_data: bytes, item_config: Dict[str, Any]) -> Any: + def deserialize(self, raw_item_data: bytes) -> Any: """Deserialize the raw bytes into their python equivalent.""" sizes = [] idx = 0 - data_format = item_config["data_format"] - keys = sorted(data_format) - for key in keys: + data_format = self._config.config["flattened_data_format"] + for key in data_format: (size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32) sizes.append(size) idx += 4 - sample = {} - for key, size in zip(keys, sizes): - value = raw_item_data[idx : idx + size] - serializer = self._serializers[data_format[key]] - sample[key] = serializer.deserialize(value) + data = [] + for size, format in zip(sizes, data_format): + serializer = self._serializers[format] + data_bytes = raw_item_data[idx : idx + size] + data.append(serializer.deserialize(data_bytes)) idx += size - return sample - - def load_item_from_chunk(self, index: int, chunk_path: str, keep_in_memory: bool = False): - chunk_name = os.path.basename(chunk_path) - try: - begin, end = self._chunks_data[chunk_name]["mapping"][str(index)] - except Exception as e: - raise Exception(f"Medata: ({self._chunks_data[chunk_name]}), Error: {e}") - config = self._chunks_data[chunk_name]["config"] - if self._chunks_data[chunk_name]["data"] is not None: - return self._chunks_data[chunk_name]["data"][begin:end], config + return tree_unflatten(data, self._config.config["data_spec"]) + + def load_item_from_chunk(self, chunk_filepath: str, begin: int, end: int, keep_in_memory: bool = False): + if chunk_filepath in self._chunks_data: + return self._chunks_data[chunk_filepath][begin:end] if keep_in_memory: - with open(chunk_path, "rb", 0) as fp: + with open(chunk_filepath, "rb", 0) as fp: data = fp.read() - self._chunks_data[chunk_name]["data"] = data - return data[begin:end], config + self._chunks_data[chunk_filepath] = data + return data[begin:end] - with open(chunk_path, "rb", 0) as fp: + with open(chunk_filepath, "rb", 0) as fp: fp.seek(begin) data = fp.read(end - begin) - return data, config + return data def get_length(self) -> int: """Get the number of samples across all chunks.""" if self._index is None: - self._try_read_index() + self._try_load_config() if self._index is None: raise Exception("The reader index isn't defined.") @@ -150,7 +172,7 @@ def get_length(self) -> int: def get_chunk_interval(self): """Get the index interval of each chunks.""" if self._index is None: - self._try_read_index() + self._try_load_config() if self._intervals is None: raise Exception("The reader index isn't defined.") diff --git a/src/lightning/data/cache/serializers.py b/src/lightning/data/cache/serializers.py index 2d74baf1b5a57..aa59b46e4d846 100644 --- a/src/lightning/data/cache/serializers.py +++ 
b/src/lightning/data/cache/serializers.py @@ -42,6 +42,10 @@ def serialize(self, data: any) -> bytes: def deserialize(self, data: bytes) -> any: pass + @abstractmethod + def can_serialize(self, data: any) -> bool: + pass + class PILSerializer(Serializer): """The PILSerializer serialize and deserialize PIL Image to and from bytes.""" @@ -62,6 +66,9 @@ def deserialize(self, data: bytes) -> any: raw = data[idx2:] return Image.frombytes(mode, size, raw) # pyright: ignore + def can_serialize(self, item) -> bool: + pass + class IntSerializer(Serializer): """The IntSerializer serialize and deserialize integer to and from bytes.""" @@ -72,6 +79,9 @@ def serialize(self, item: int) -> bytes: def deserialize(self, data: bytes) -> int: return int(data.decode("utf-8")) + def can_serialize(self, item) -> bool: + return isinstance(item, int) + class JPEGSerializer(Serializer): """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" @@ -90,6 +100,9 @@ def deserialize(self, data: bytes) -> Image: inp = BytesIO(data) return Image.open(inp) + def can_serialize(self, item) -> bool: + pass + _SERIALIZERS = { "pil": PILSerializer(), diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 1ff7a1f5024ec..c4df4b164aaee 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -13,12 +13,13 @@ import json import os -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union import numpy as np from lightning.data.cache.compression import _COMPRESSORS -from lightning.data.cache.serializers import _SERIALIZERS +from lightning.data.cache.pytree import tree_flatten, tree_unflatten, treespec_dumps +from lightning.data.cache.serializers import _SERIALIZERS, Serializer from lightning.data.cache.worker import get_worker_info from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv @@ -37,7 +38,7 @@ class BinaryWriter: def __init__( self, cache_dir: str, - data_format: Dict[str, str], + data_format: Union[Dict[str, str], str] = None, chunk_size: int = 1 << 26, compression: Optional[str] = None, ): @@ -45,32 +46,43 @@ def __init__( Arguments: cache_dir: The path to where the chunks will be saved. - data_format: The format of the provided data to cache. Only dictionary are supported for now. + data_format: The format of the provided data to cache. chunk_size: The maximum number of bytes to store within a chunk. compression: The compression algorithm to use. """ self._cache_dir = cache_dir - self._data_format = {k.lower(): v for k, v in data_format.items()} - self._chunk_size = chunk_size - self._compression = compression if not os.path.exists(self._cache_dir): raise FileNotFoundError(f"The provided cache directory `{self._cache_dir}` doesn't exist.") - if len(self._data_format) == 0: - raise ValueError("The provided data format shouldn't be empty.") - - self._data_format_keys = sorted(self._data_format.keys()) - self._serializers = _SERIALIZERS + self._serializers: Dict[str, Serializer] = _SERIALIZERS + self._chunk_size = chunk_size + self._compression = compression - available_serializers = set(self._serializers.keys()) - selected_serializers = set(self._data_format.values()) - if selected_serializers.difference(available_serializers): - raise ValueError( - "The provided data_format don't match the provided serializers." - " Should be selected from {sorted(available_serializers)}." 
- ) + self._data_format = None + self._data_spec = None + + if data_format: + if isinstance(data_format, str): + self._data_format = data_format + self._data_format_key = None + elif isinstance(data_format, Dict): + if len(data_format) == 0: + raise ValueError("The provided data format shouldn't be empty.") + self._data_format = {k.lower(): v for k, v in data_format.items()} + self._data_format_keys = sorted(self._data_format.keys()) + + available_serializers = set(self._serializers.keys()) + selected_serializers = set(self._data_format.values()) + if selected_serializers.difference(available_serializers): + raise ValueError( + "The provided data_format don't match the provided serializers." + " Should be selected from {sorted(available_serializers)}." + ) + else: + self._data_format = None + self._data_format_key = None if self._compression: if len(_COMPRESSORS) == 0: @@ -86,10 +98,6 @@ def __init__( self._serialized_items = [] self._chunks_info = [] self._indexes = [] - obj = self.get_config() - text = json.dumps(obj, sort_keys=True) - self._config_data = text.encode("utf-8") - self._env = _DistributedEnv.detect() self._worker_env = None self._rank = None @@ -105,8 +113,12 @@ def rank(self): def get_config(self) -> Dict[str, Any]: """Returns the config of the writer.""" - out = {"compression": self._compression, "chunk_size": self._chunk_size, "data_format": self._data_format} - out.update(self._data_format) + out = { + "compression": self._compression, + "chunk_size": self._chunk_size, + "data_format": self._data_format, + "data_spec": treespec_dumps(self._data_spec), + } cloud_path = self.get_cloud_path(self._cache_dir) if cloud_path: out["cloud_path"] = cloud_path @@ -115,52 +127,71 @@ def get_config(self) -> Dict[str, Any]: out["user_id"] = user_id return out - def serialize(self, items: Dict[str, Any]) -> bytes: + def serialize(self, items: Any) -> bytes: """Serialize a dictionary into its binary format.""" - if not isinstance(items, dict): - raise Exception("The provided data should be a dictionary.") - - keys = sorted(items.keys()) - - if keys != self._data_format_keys: - raise Exception( - f"The provided keys don't match the provided format. Found {keys} instead of {self._data_format_keys}." - ) + flattened, data_spec = tree_flatten(items) sizes = [] data = [] - for key in self._data_format_keys: - serializer_name = self._data_format[key] - serializer = self._serializers[serializer_name] - serialized_data = serializer.serialize(items[key]) if not isinstance(items[key], bytes) else items[key] - sizes.append(len(serialized_data)) - data.append(serialized_data) + formats = [] + for item in flattened: + formats.append(self._serialize(item, sizes, data)) + + data_format = tree_unflatten(formats, data_spec) + + if self._data_format is None: + self._data_format = data_format + self._data_spec = data_spec + else: + if self._data_format != data_format: + raise Exception( + f"The data format changed between items. Found {data_format} instead of {self._data_format}." + ) + if self._data_spec != data_spec: + raise Exception( + f"The data format changed between items. Found {data_spec} instead of {self._data_spec}." 
+ ) head = np.array(sizes, np.uint32).tobytes() body = b"".join(data) return head + body + def _serialize(self, item, sizes, data) -> bytes: + if isinstance(item, bytes): + data.append(item) + sizes.append(len(item)) + return "bytes" + + for serializer_name, serializer in self._serializers.items(): + if serializer.can_serialize(item): + serialized_item = serializer.serialize(item) + data.append(serialized_item) + sizes.append(len(serialized_item)) + return serializer_name + raise ValueError(f"The provided item isn't serializable. Found {item}") + def _create_chunk(self, filename: str) -> bytes: """Create a binary chunk from all the binarized items.""" num_items = np.uint32(len(self._serialized_items)) sizes = list(map(len, self._serialized_items)) offsets = np.array([0] + sizes).cumsum().astype(np.uint32) - offsets += len(num_items.tobytes()) + len(offsets.tobytes()) + len(self._config_data) + offsets += len(num_items.tobytes()) + len(offsets.tobytes()) sample_data = b"".join(self._serialized_items) - data = num_items.tobytes() + offsets.tobytes() + self._config_data + sample_data + data = num_items.tobytes() + offsets.tobytes() + sample_data offsets = offsets.tolist() mapping = {} for i in range(len(self._indexes)): mapping[self._indexes[i]] = [offsets[i], offsets[i + 1]] assert len(mapping) == len(self._indexes) + assert (self._indexes[-1] - self._indexes[0] + 1) == len(self._serialized_items) chunk_info = { "samples": len(self._serialized_items), - "config": self.get_config(), "filename": filename, "mapping": mapping, + "interval": [self._indexes[0], self._indexes[-1]], } self._chunks_info.append(chunk_info) @@ -228,7 +259,7 @@ def write_chunks_index(self): """Write the chunks index to a JSON file.""" filepath = os.path.join(self._cache_dir, f"{self.rank}.index.json") with open(filepath, "w") as out: - json.dump({"chunks": self._chunks_info}, out, sort_keys=True) + json.dump({"chunks": self._chunks_info, "config": self.get_config()}, out, sort_keys=True) def done(self): """Called when StopIteration is triggered. 
diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index db0dce1c76e7e..dfc6725cbe634 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -141,3 +141,31 @@ def test_cache_for_image_dataset_distributed(num_workers, tmpdir): fabric = Fabric(accelerator="cpu", devices=2, strategy="ddp_spawn") fabric.launch(partial(_fabric_cache_for_image_dataset, num_workers=num_workers, tmpdir=tmpdir)) + + +def test_cache_with_simple_format(tmpdir): + cache_dir = os.path.join(tmpdir, "cache1") + os.makedirs(cache_dir) + + cache = Cache(cache_dir, chunk_size=90) + + for i in range(100): + cache[i] = i + + cache.done() + + for i in range(100): + assert i == cache[i] + + cache_dir = os.path.join(tmpdir, "cache2") + os.makedirs(cache_dir) + + cache = Cache(cache_dir, chunk_size=90) + + for i in range(100): + cache[i] = [i, {0: [i + 1]}] + + cache.done() + + for i in range(100): + assert [i, {0: [i + 1]}] == cache[i] From c63d0e137007b8eb273af84f1ea4594ad96b1d36 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 15:54:59 +0100 Subject: [PATCH 37/84] update --- src/lightning/data/cache/cache.py | 6 +- src/lightning/data/cache/reader.py | 48 ++++++++------- src/lightning/data/cache/serializers.py | 35 +++++++---- src/lightning/data/cache/writer.py | 75 +++++++++-------------- tests/tests_data/cache/test_cache.py | 16 +++-- tests/tests_data/cache/test_serializer.py | 2 +- tests/tests_data/cache/test_writer.py | 10 +-- 7 files changed, 94 insertions(+), 98 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 9616c979915c0..6d2e4093575d0 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -13,7 +13,7 @@ import logging import os -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter @@ -26,7 +26,6 @@ class Cache: def __init__( self, cache_dir: str, - data_format: Union[Dict[str, any], str] = None, compression: Optional[str] = None, chunk_size: int = 2 << 26, ): @@ -35,13 +34,12 @@ def __init__( Arguments: cache_dir: The path to where the chunks will be stored. - data_format: The structure of the data to be serialized. compression: The name of the algorithm to reduce the size of the chunks chunk_size: The maximum byte size of chunk. 
""" super().__init__() - self._writer = BinaryWriter(cache_dir, data_format, chunk_size=chunk_size, compression=compression) + self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, compression=compression) self._reader = BinaryReader(cache_dir, compression=compression) self._cache_dir = cache_dir self._is_done = False diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 19ad61435eb8c..7395639b2f402 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -17,9 +17,9 @@ import numpy as np -from lightning.data.cache.pytree import tree_flatten, tree_unflatten, treespec_loads +from lightning.data.cache.pytree import tree_unflatten, treespec_loads from lightning.data.cache.serializers import _SERIALIZERS, Serializer -from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv +from lightning.data.datasets.env import _DistributedEnv class ChunksConfig: @@ -36,36 +36,41 @@ def __init__(self, cache_dir: str, index_filenames: str): if self._config is None: self._config = data["config"] - self._config["data_spec"] = treespec_loads(self._config["data_spec"]) - flattened_data_format, _ = tree_flatten(self._config["data_format"]) - self._config["flattened_data_format"] = flattened_data_format elif self._config != data["config"]: raise Exception("The config isn't consistent between chunks. This shouldn't have happened.") self._chunks.extend(data["chunks"]) + self._config["data_spec"] = treespec_loads(self._config["data_spec"]) + for chunk in self._chunks: start, end = chunk["interval"] - if (end - start + 1) != chunk["samples"]: + if (end - start) != chunk["samples"]: raise Exception( "The config intervals doesn't match the number of samples. This shouldn't have happened." ) self._intervals.append(chunk["interval"]) + self._length = sum([chunk["samples"] for chunk in self._chunks]) + @property def intervals(self): return self._intervals + @property + def data_format(self): + return self._config["data_format"] + @property def config(self): return self._config def __getitem__(self, index: int) -> Tuple[str, int, int]: """Find the associated chunk metadata.""" - for interval_index, internal in enumerate(self._intervals): - if internal[0] <= index and index <= internal[1]: - chunk = self._chunks[interval_index] + for interval_config, internal in enumerate(self._intervals): + if internal[0] <= index and index < internal[1]: + chunk = self._chunks[interval_config] mapping = chunk["mapping"][str(index)] return os.path.join(self._cache_dir, chunk["filename"]), *mapping raise Exception(f"The chunk interval weren't properly defined. 
Found {self._intervals} for index {index}.") @@ -78,6 +83,9 @@ def load(cls, cache_dir: str) -> Optional["ChunksConfig"]: return None return ChunksConfig(cache_dir, index_filenames) + def __len__(self) -> int: + return self._length + class BinaryReader: def __init__(self, cache_dir: str, compression: Optional[str] = None): @@ -96,15 +104,13 @@ def __init__(self, cache_dir: str, compression: Optional[str] = None): raise FileNotFoundError(f"The provided cache_dir `{self._cache_dir}` doesn't exist.") self._compression = compression - self._index = None + self._config = None self._intervals = None self._chunks_data = {} self._serializers: Dict[str, Serializer] = _SERIALIZERS self._env = _DistributedEnv.detect() - self._worker_env: Optional[_WorkerEnv] = None - self._config: Optional[ChunksConfig] = None def _try_load_config(self): @@ -117,7 +123,7 @@ def read(self, index: int): If the chunk isn't available, it will be downloaded. """ - if self._index is None: + if self._config is None: self._try_load_config() if self._config is None: @@ -131,8 +137,8 @@ def deserialize(self, raw_item_data: bytes) -> Any: """Deserialize the raw bytes into their python equivalent.""" sizes = [] idx = 0 - data_format = self._config.config["flattened_data_format"] - for key in data_format: + data_format = self._config.data_format + for _ in data_format: (size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32) sizes.append(size) idx += 4 @@ -161,20 +167,20 @@ def load_item_from_chunk(self, chunk_filepath: str, begin: int, end: int, keep_i def get_length(self) -> int: """Get the number of samples across all chunks.""" - if self._index is None: + if self._config is None: self._try_load_config() - if self._index is None: + if self._config is None: raise Exception("The reader index isn't defined.") - return sum([v["samples"] for v in self._index["chunks"]]) + return len(self._config) def get_chunk_interval(self): """Get the index interval of each chunks.""" - if self._index is None: + if self._config is None: self._try_load_config() - if self._intervals is None: + if self._config is None: raise Exception("The reader index isn't defined.") - return self._intervals + return self._config.intervals diff --git a/src/lightning/data/cache/serializers.py b/src/lightning/data/cache/serializers.py index aa59b46e4d846..cdc2135d3f74d 100644 --- a/src/lightning/data/cache/serializers.py +++ b/src/lightning/data/cache/serializers.py @@ -50,7 +50,7 @@ def can_serialize(self, data: any) -> bool: class PILSerializer(Serializer): """The PILSerializer serialize and deserialize PIL Image to and from bytes.""" - def serialize(self, item: any) -> bytes: + def serialize(self, item: Image) -> bytes: mode = item.mode.encode("utf-8") width, height = item.size raw = item.tobytes() @@ -67,7 +67,7 @@ def deserialize(self, data: bytes) -> any: return Image.frombytes(mode, size, raw) # pyright: ignore def can_serialize(self, item) -> bool: - pass + return isinstance(item, Image.Image) and not isinstance(item, JpegImageFile) class IntSerializer(Serializer): @@ -86,26 +86,35 @@ def can_serialize(self, item) -> bool: class JPEGSerializer(Serializer): """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" - def serialize(self, obj: Image) -> bytes: - if isinstance(obj, JpegImageFile): - if not hasattr(obj, "filename"): + def serialize(self, item: Image) -> bytes: + if isinstance(item, JpegImageFile): + if not hasattr(item, "filename"): raise ValueError( "The JPEG Image's filename isn't defined. 
HINT: Open the image in your Dataset __getitem__ method." ) - with open(obj.filename, "rb") as f: + with open(item.filename, "rb") as f: return f.read() - raise TypeError(f"The provided object should be of type {JpegImageFile}. Found {obj}.") + raise TypeError(f"The provided itemect should be of type {JpegImageFile}. Found {item}.") def deserialize(self, data: bytes) -> Image: inp = BytesIO(data) return Image.open(inp) def can_serialize(self, item) -> bool: - pass + return isinstance(item, JpegImageFile) + + +class BytesSerializer(Serializer): + """The BytesSerializer serialize and deserialize integer to and from bytes.""" + + def serialize(self, item: bytes) -> bytes: + return item + + def deserialize(self, item: bytes) -> bytes: + return item + + def can_serialize(self, item: bytes) -> bool: + return isinstance(item, bytes) -_SERIALIZERS = { - "pil": PILSerializer(), - "int": IntSerializer(), - "jpeg": JPEGSerializer(), -} +_SERIALIZERS = {"pil": PILSerializer(), "int": IntSerializer(), "jpeg": JPEGSerializer(), "bytes": BytesSerializer()} diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index c4df4b164aaee..1dafac0c14247 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -13,12 +13,12 @@ import json import os -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional import numpy as np from lightning.data.cache.compression import _COMPRESSORS -from lightning.data.cache.pytree import tree_flatten, tree_unflatten, treespec_dumps +from lightning.data.cache.pytree import tree_flatten, treespec_dumps from lightning.data.cache.serializers import _SERIALIZERS, Serializer from lightning.data.cache.worker import get_worker_info from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv @@ -38,7 +38,6 @@ class BinaryWriter: def __init__( self, cache_dir: str, - data_format: Union[Dict[str, str], str] = None, chunk_size: int = 1 << 26, compression: Optional[str] = None, ): @@ -46,7 +45,6 @@ def __init__( Arguments: cache_dir: The path to where the chunks will be saved. - data_format: The format of the provided data to cache. chunk_size: The maximum number of bytes to store within a chunk. compression: The compression algorithm to use. @@ -63,27 +61,6 @@ def __init__( self._data_format = None self._data_spec = None - if data_format: - if isinstance(data_format, str): - self._data_format = data_format - self._data_format_key = None - elif isinstance(data_format, Dict): - if len(data_format) == 0: - raise ValueError("The provided data format shouldn't be empty.") - self._data_format = {k.lower(): v for k, v in data_format.items()} - self._data_format_keys = sorted(self._data_format.keys()) - - available_serializers = set(self._serializers.keys()) - selected_serializers = set(self._data_format.values()) - if selected_serializers.difference(available_serializers): - raise ValueError( - "The provided data_format don't match the provided serializers." - " Should be selected from {sorted(available_serializers)}." 
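For context on how the registry above is consumed: the writer walks `_SERIALIZERS` and asks each serializer whether it can handle a leaf via `can_serialize`, so the new `BytesSerializer` picks up raw `bytes` values that no other serializer claims. A minimal sketch of that selection loop, assuming only the behaviour shown in this patch (the helper name `pick_serializer` is illustrative):

from lightning.data.cache.serializers import _SERIALIZERS


def pick_serializer(item):
    # Mirrors the can_serialize() loop used by BinaryWriter._serialize.
    for name, serializer in _SERIALIZERS.items():
        if serializer.can_serialize(item):
            return name, serializer
    raise ValueError(f"No serializer found for {item!r}")


name, serializer = pick_serializer(b"raw payload")  # the BytesSerializer claims raw bytes
assert serializer.deserialize(serializer.serialize(b"raw payload")) == b"raw payload"
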
- ) - else: - self._data_format = None - self._data_format_key = None - if self._compression: if len(_COMPRESSORS) == 0: raise ValueError("No compresion algorithms are installed.") @@ -103,6 +80,17 @@ def __init__( self._rank = None self._is_done = False + @property + def filled(self) -> bool: + """Returns whether the caching phase is done.""" + if self._is_done: + return True + files = os.listdir(self._cache_dir) + index_files = [f for f in files if f.endswith("index.json")] + worker_end = _WorkerEnv.detect(get_worker_info_fn=get_worker_info) + self._is_done = len(index_files) == self._env.world_size * worker_end.world_size + return self._is_done + @property def rank(self): """Returns the rank of the writer.""" @@ -117,7 +105,7 @@ def get_config(self) -> Dict[str, Any]: "compression": self._compression, "chunk_size": self._chunk_size, "data_format": self._data_format, - "data_spec": treespec_dumps(self._data_spec), + "data_spec": treespec_dumps(self._data_spec) if self._data_spec else None, } cloud_path = self.get_cloud_path(self._cache_dir) if cloud_path: @@ -134,35 +122,27 @@ def serialize(self, items: Any) -> bytes: sizes = [] data = [] - formats = [] + data_format = [] for item in flattened: - formats.append(self._serialize(item, sizes, data)) - - data_format = tree_unflatten(formats, data_spec) + data_format.append(self._serialize(item, sizes, data)) if self._data_format is None: self._data_format = data_format + elif self._data_format != data_format: + raise Exception( + f"The data format changed between items. Found {data_format} instead of {self._data_format}." + ) + + if self._data_spec is None: self._data_spec = data_spec - else: - if self._data_format != data_format: - raise Exception( - f"The data format changed between items. Found {data_format} instead of {self._data_format}." - ) - if self._data_spec != data_spec: - raise Exception( - f"The data format changed between items. Found {data_spec} instead of {self._data_spec}." - ) + elif self._data_spec != data_spec: + raise Exception(f"The data format changed between items. Found {data_spec} instead of {self._data_spec}.") head = np.array(sizes, np.uint32).tobytes() body = b"".join(data) return head + body def _serialize(self, item, sizes, data) -> bytes: - if isinstance(item, bytes): - data.append(item) - sizes.append(len(item)) - return "bytes" - for serializer_name, serializer in self._serializers.items(): if serializer.can_serialize(item): serialized_item = serializer.serialize(item) @@ -191,7 +171,7 @@ def _create_chunk(self, filename: str) -> bytes: "samples": len(self._serialized_items), "filename": filename, "mapping": mapping, - "interval": [self._indexes[0], self._indexes[-1]], + "interval": [self._indexes[0], self._indexes[-1] + 1], } self._chunks_info.append(chunk_info) @@ -258,8 +238,9 @@ def write_file( def write_chunks_index(self): """Write the chunks index to a JSON file.""" filepath = os.path.join(self._cache_dir, f"{self.rank}.index.json") + config = self.get_config() with open(filepath, "w") as out: - json.dump({"chunks": self._chunks_info, "config": self.get_config()}, out, sort_keys=True) + json.dump({"chunks": self._chunks_info, "config": config}, out, sort_keys=True) def done(self): """Called when StopIteration is triggered. @@ -267,7 +248,7 @@ def done(self): It tries to save the last chunk and write the chunks index. 
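The `serialize` method above defines the on-disk layout of one item: a header with one `uint32` size per flattened field, followed by the concatenated field payloads. A self-contained sketch of that layout and of the decoding done on the reader side (field contents are arbitrary examples, not data from this patch):

import numpy as np

fields = [b"first-field", b"second"]  # already-serialized leaves of one item
head = np.array([len(f) for f in fields], np.uint32).tobytes()
item = head + b"".join(fields)

# Decoding mirrors BinaryReader.deserialize: read one uint32 per field,
# then slice the body with the accumulated offsets.
sizes, idx = [], 0
for _ in range(len(fields)):
    (size,) = np.frombuffer(item[idx : idx + 4], np.uint32)
    sizes.append(int(size))
    idx += 4

decoded, offset = [], idx
for size in sizes:
    decoded.append(item[offset : offset + size])
    offset += size

assert decoded == fields
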
""" - if self._is_done: + if self.filled: return if self._serialized_items: self.write_chunk() diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index dfc6725cbe634..28fc823838006 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -56,10 +56,12 @@ def __len__(self): return len(self.data) def __getitem__(self, index): + from PIL import Image + if self.cache.filled: data = self.cache[index] if self.use_transform: - data["image"] = self.transform(data["image"]).unsqueeze(0) + data["image"] = self.transform(Image.open(io.BytesIO(data["image"]))).unsqueeze(0) return data self.cache[index] = {**self.data[index], "index": index} return None @@ -73,19 +75,25 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): cache_dir = os.path.join(tmpdir, "cache") distributed_env = _DistributedEnv.detect() - cache = Cache(cache_dir, data_format={"image": "jpeg", "class": "int", "index": "int"}, chunk_size=2 << 12) + cache = Cache(cache_dir, chunk_size=2 << 12) dataset = ImageDataset(tmpdir, cache, dataset_size, 10) dataloader = CacheDataLoader(dataset, num_workers=num_workers, batch_size=4) for _ in dataloader: pass + # Not strictly required but added to avoid race condition + if distributed_env.world_size > 1: + fabric.barrier() + + assert cache.filled + for i in range(len(dataset)): cached_data = dataset[i] original_data = dataset.data[i] assert cached_data["class"] == original_data["class"] original_image = Image.open(io.BytesIO(original_data["image"])) - assert cached_data["image"] == original_image + assert Image.open(io.BytesIO(cached_data["image"])) == original_image dataset.use_transform = True @@ -118,7 +126,7 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): @pytest.mark.skipif( condition=not _PIL_AVAILABLE or not _TORCH_VISION_AVAILABLE, reason="Requires: ['pil', 'torchvision']" ) -@pytest.mark.parametrize("num_workers", [0]) +@pytest.mark.parametrize("num_workers", [0, 1, 2]) def test_cache_for_image_dataset(num_workers, tmpdir): cache_dir = os.path.join(tmpdir, "cache") os.makedirs(cache_dir) diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py index 01b6e41b121e6..11e24ea26478b 100644 --- a/tests/tests_data/cache/test_serializer.py +++ b/tests/tests_data/cache/test_serializer.py @@ -22,7 +22,7 @@ def test_serializers(): - assert sorted(_SERIALIZERS) == ["int", "jpeg", "pil"] + assert sorted(_SERIALIZERS) == ["bytes", "int", "jpeg", "pil"] def test_int_serializer(): diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index 7a16db693c748..da941150341c0 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -28,16 +28,10 @@ def test_binary_writer_with_ints(tmpdir): with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."): BinaryWriter("dontexists", {}) - with pytest.raises(ValueError, match="The provided data format shouldn't be empty."): - BinaryWriter(tmpdir, {}) - - with pytest.raises(ValueError, match="['int', 'jpeg', 'pil']"): - BinaryWriter(tmpdir, {"i": "random"}) - with pytest.raises(ValueError, match="No compresion algorithms are installed."): BinaryWriter(tmpdir, {"i": "int"}, compression="something_else") - binary_writer = BinaryWriter(tmpdir, {"i": "int", "i+1": "int", "i+2": "int"}, chunk_size=90) + binary_writer = BinaryWriter(tmpdir, chunk_size=90) for i in range(100): binary_writer[i] = {"i": i, 
"i+1": i + 1, "i+2": i + 2} @@ -66,7 +60,7 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): cache_dir = os.path.join(tmpdir, "chunks") os.makedirs(cache_dir, exist_ok=True) - binary_writer = BinaryWriter(cache_dir, {"x": "jpeg", "y": "int"}, chunk_size=2 << 12) + binary_writer = BinaryWriter(cache_dir, chunk_size=2 << 12) imgs = [] From 976d680baaa318311461c02ddab74541ee92b4b2 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 15:59:34 +0100 Subject: [PATCH 38/84] update --- tests/tests_data/cache/test_sampler.py | 43 -------------------------- 1 file changed, 43 deletions(-) diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py index e4ba606d8a3cc..88d71cdf753b2 100644 --- a/tests/tests_data/cache/test_sampler.py +++ b/tests/tests_data/cache/test_sampler.py @@ -136,49 +136,6 @@ def test_cache_distributed_sampler_samplers(params): assert sampler.samplers[2].data_source == params[2][rank][2] -def test_cache_distributed_sampler_sampling(): - """Valides the DistributedCacheSampler can return batch of data in an ordered way.""" - dataset_size = 129 - sampler = DistributedCacheSampler(dataset_size, 5, 3, 3) - iter_sampler = iter(sampler) - - all_indexes = [] - indexes = [] - while True: - try: - index = next(iter_sampler) - indexes.append(index) - all_indexes.append(index) - except StopIteration: - assert indexes == [0, 1, 2, 5, 6, 7, 10, 11, 12, 3, 4] - assert sampler._done == {0} - break - - indexes = [] - while True: - try: - index = next(iter_sampler) - indexes.append(index) - all_indexes.append(index) - except StopIteration: - assert indexes == [8, 9] - assert sampler._done == {0, 1} - break - - indexes = [] - while True: - try: - index = next(iter_sampler) - indexes.append(index) - all_indexes.append(index) - except StopIteration: - assert indexes == [13, 14, 15, 16] - assert sampler._done == {0, 1, 2} - break - - assert sorted(all_indexes) == list(range(dataset_size)) - - @pytest.mark.parametrize( "params", [ From c32c5aae9f5911e61e6168b101beedce75f9f7a9 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 16:53:29 +0100 Subject: [PATCH 39/84] update --- src/lightning/data/cache/serializers.py | 63 ++++++++++++++++++++++- tests/tests_data/cache/test_sampler.py | 1 - tests/tests_data/cache/test_serializer.py | 35 ++++++++++++- 3 files changed, 95 insertions(+), 4 deletions(-) diff --git a/src/lightning/data/cache/serializers.py b/src/lightning/data/cache/serializers.py index cdc2135d3f74d..90af6dfb3be2b 100644 --- a/src/lightning/data/cache/serializers.py +++ b/src/lightning/data/cache/serializers.py @@ -15,6 +15,7 @@ from io import BytesIO import numpy as np +import torch from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") @@ -117,4 +118,64 @@ def can_serialize(self, item: bytes) -> bool: return isinstance(item, bytes) -_SERIALIZERS = {"pil": PILSerializer(), "int": IntSerializer(), "jpeg": JPEGSerializer(), "bytes": BytesSerializer()} +_TORCH_DTYPES_MAPPING = { + 0: torch.float32, + 1: torch.float, + 2: torch.float64, + 3: torch.double, + 4: torch.complex64, + 5: torch.cfloat, + 6: torch.complex128, + 7: torch.cdouble, + 8: torch.float16, + 9: torch.half, + 10: torch.bfloat16, + 11: torch.uint8, + 12: torch.int8, + 13: torch.int16, + 14: torch.short, + 15: torch.int32, + 16: torch.int, + 17: torch.int64, + 18: torch.long, + 19: torch.bool, +} + + +class TensorSerializer(Serializer): + """The TensorSerializer serialize and deserialize tensor to and from bytes.""" + + def 
__init__(self): + super().__init__() + self._dtype_to_indice = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()} + + def serialize(self, item: torch.Tensor) -> bytes: + dtype_indice = self._dtype_to_indice[item.dtype] + data = [np.uint32(dtype_indice).tobytes()] + data.append(np.uint32(len(item.shape)).tobytes()) + for dim in item.shape: + data.append(np.uint32(dim).tobytes()) + data.append(item.numpy().tobytes()) + return b"".join(data) + + def deserialize(self, data: bytes) -> torch.Tensor: + dtype_indice = np.frombuffer(data[0:4], np.uint32).item() + dtype = _TORCH_DTYPES_MAPPING[dtype_indice] + shape_size = np.frombuffer(data[4:8], np.uint32).item() + shape = [] + for shape_idx in range(shape_size): + shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item()) + tensor = torch.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype) + return torch.reshape(tensor, torch.Size(shape)) + + def can_serialize(self, item: torch.Tensor) -> bool: + return isinstance(item, torch.Tensor) + + +_SERIALIZERS = { + "pil": PILSerializer(), + "int": IntSerializer(), + "jpeg": JPEGSerializer(), + "bytes": BytesSerializer(), + "tensor": TensorSerializer(), +} diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py index 88d71cdf753b2..320a3b4c6ae50 100644 --- a/tests/tests_data/cache/test_sampler.py +++ b/tests/tests_data/cache/test_sampler.py @@ -154,5 +154,4 @@ def test_cache_batch_sampler(params): batches = [] for batch in batch_sampler: batches.append(batch) - print(batches) assert batches == params[1] diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py index 11e24ea26478b..14ccc02fc9e10 100644 --- a/tests/tests_data/cache/test_serializer.py +++ b/tests/tests_data/cache/test_serializer.py @@ -15,14 +15,22 @@ import numpy as np import pytest -from lightning.data.cache.serializers import _SERIALIZERS, IntSerializer, JPEGSerializer, PILSerializer +import torch +from lightning.data.cache.serializers import ( + _SERIALIZERS, + _TORCH_DTYPES_MAPPING, + IntSerializer, + JPEGSerializer, + PILSerializer, + TensorSerializer, +) from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") def test_serializers(): - assert sorted(_SERIALIZERS) == ["bytes", "int", "jpeg", "pil"] + assert sorted(_SERIALIZERS) == ["bytes", "int", "jpeg", "pil", "tensor"] def test_int_serializer(): @@ -91,3 +99,26 @@ def test_pil_serializer(mode): # Validate data content assert np.array_equal(np_data, np_dec_data) + + +def test_tensor_serializer(): + serializer = TensorSerializer() + + shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)] + for dtype in _TORCH_DTYPES_MAPPING.values(): + for shape in shapes: + # Not serializable for some reasons + if dtype in [torch.bfloat16]: + continue + tensor = torch.ones(shape, dtype=dtype) + data = serializer.serialize(tensor) + deserialized_tensor = serializer.deserialize(data) + assert deserialized_tensor.dtype == dtype + assert torch.equal(tensor, deserialized_tensor) + + +def test_assert_bfloat16_tensor_serializer(): + serializer = TensorSerializer() + tensor = torch.ones((10,), dtype=torch.bfloat16) + with pytest.raises(TypeError, match="Got unsupported ScalarType BFloat16"): + serializer.serialize(tensor) From e3cd282279fc37b451c49c1fc336bde3d953a8ec Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 17:18:37 +0100 Subject: [PATCH 40/84] update --- src/lightning/data/cache/sampler.py | 9 
+++--- src/lightning/data/cache/serializers.py | 2 +- tests/tests_data/cache/test_sampler.py | 37 +++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/src/lightning/data/cache/sampler.py b/src/lightning/data/cache/sampler.py index 5808bf93cf024..26087cfddb528 100644 --- a/src/lightning/data/cache/sampler.py +++ b/src/lightning/data/cache/sampler.py @@ -214,6 +214,7 @@ def __init__( self._cache = cache self._shuffle = shuffle self._num_workers = num_workers + self._shuffled_chunk_intervals = None def __iter_ordered__(self) -> Iterator[List[int]]: # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 @@ -242,11 +243,11 @@ def __iter__(self): def __iter_from_chunks__(self): chunk_intervals = self._cache.get_chunk_interval() - shuffled_chunk_intervals = np.random.permutation(chunk_intervals) + self._shuffled_chunk_intervals = np.random.permutation(chunk_intervals) if self._num_replicas == 1: indices = [] - for interval in shuffled_chunk_intervals: + for interval in self._shuffled_chunk_intervals: interval_indices = np.arange(interval[0], interval[1]) shuffled_interval_indices = np.random.permutation(interval_indices) indices.extend(shuffled_interval_indices.tolist()) @@ -255,14 +256,14 @@ def __iter_from_chunks__(self): raise Exception("The generated indices don't match the initial length of the sampler.") else: - chunks_per_replica = len(shuffled_chunk_intervals) // self._num_replicas + chunks_per_replica = len(self._shuffled_chunk_intervals) // self._num_replicas for replica_idx in range(self._num_replicas): if replica_idx != self._rank: continue is_last_replica = replica_idx == self._num_replicas - 1 start_replica = replica_idx * chunks_per_replica end_replica = len(chunk_intervals) if is_last_replica else (replica_idx + 1) * chunks_per_replica - shuffled_chunk_intervals_replica = shuffled_chunk_intervals[start_replica:end_replica] + shuffled_chunk_intervals_replica = self._shuffled_chunk_intervals[start_replica:end_replica] indices = [] for interval in shuffled_chunk_intervals_replica: diff --git a/src/lightning/data/cache/serializers.py b/src/lightning/data/cache/serializers.py index 90af6dfb3be2b..7777f8b2ba198 100644 --- a/src/lightning/data/cache/serializers.py +++ b/src/lightning/data/cache/serializers.py @@ -129,7 +129,7 @@ def can_serialize(self, item: bytes) -> bool: 7: torch.cdouble, 8: torch.float16, 9: torch.half, - 10: torch.bfloat16, + 10: torch.bfloat16, # Not supported https://github.com/pytorch/pytorch/issues/110285 11: torch.uint8, 12: torch.int8, 13: torch.int16, diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py index 320a3b4c6ae50..57efaa757162d 100644 --- a/tests/tests_data/cache/test_sampler.py +++ b/tests/tests_data/cache/test_sampler.py @@ -1,6 +1,8 @@ from unittest import mock +import numpy as np import pytest +from lightning import seed_everything from lightning.data.cache.sampler import CacheBatchSampler, CacheSampler, DistributedCacheSampler @@ -155,3 +157,38 @@ def test_cache_batch_sampler(params): for batch in batch_sampler: batches.append(batch) assert batches == params[1] + + chunk_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)] + + cache.filled = True + cache.get_chunk_interval.return_value = chunk_interval + + seed_everything(42) + + batch_sampler = CacheBatchSampler(params[0], 1, 0, 3, 3, False, True, cache) + + batches_1 = [] + for batch in batch_sampler: + batches_1.extend(batch) + + size = 0 + for interval in 
batch_sampler._shuffled_chunk_intervals: + interval_indices = np.arange(interval[0], interval[1]) + for indice in interval_indices: + assert indice in batches_1[size : size + len(interval_indices)] + size += len(interval_indices) + + assert len(batches_1) == params[0] + + batches_2 = [] + for batch in batch_sampler: + batches_2.extend(batch) + + size = 0 + for interval in batch_sampler._shuffled_chunk_intervals: + interval_indices = np.arange(interval[0], interval[1]) + for indice in interval_indices: + assert indice in batches_2[size : size + len(interval_indices)] + size += len(interval_indices) + + assert batches_1 != batches_2 From 888cae4d525c528593fc7cfc7eb84607122879d7 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 17:28:52 +0100 Subject: [PATCH 41/84] update --- tests/tests_data/cache/test_sampler.py | 69 ++++++++++++++++---------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py index 57efaa757162d..dc7eb65db7e92 100644 --- a/tests/tests_data/cache/test_sampler.py +++ b/tests/tests_data/cache/test_sampler.py @@ -141,22 +141,23 @@ def test_cache_distributed_sampler_samplers(params): @pytest.mark.parametrize( "params", [ - (21, [[0, 1, 2], [7, 8, 9], [14, 15, 16], [3, 4, 5], [10, 11, 12], [17, 18, 19], [6], [13], [20]]), - (11, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], [9, 10]]), - (8, [[0, 1], [2, 3], [4, 5, 6], [7]]), - (4, [[0], [1], [2, 3]]), - (9, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], []]), - (19, [[0, 1, 2], [6, 7, 8], [12, 13, 14], [3, 4, 5], [9, 10, 11], [15, 16, 17], [], [], [18]]), + (21, 1, [[0, 1, 2], [7, 8, 9], [14, 15, 16], [3, 4, 5], [10, 11, 12], [17, 18, 19], [6], [13], [20]]), + (11, 1, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], [9, 10]]), + (8, 1, [[0, 1], [2, 3], [4, 5, 6], [7]]), + (4, 1, [[0], [1], [2, 3]]), + (9, 1, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], []]), + (19, 1, [[0, 1, 2], [6, 7, 8], [12, 13, 14], [3, 4, 5], [9, 10, 11], [15, 16, 17], [], [], [18]]), + (19, 2, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], []]), ], ) def test_cache_batch_sampler(params): cache = mock.MagicMock() cache.filled = False - batch_sampler = CacheBatchSampler(params[0], 1, 0, 3, 3, False, True, cache) + batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache) batches = [] for batch in batch_sampler: batches.append(batch) - assert batches == params[1] + assert batches == params[2] chunk_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)] @@ -165,30 +166,48 @@ def test_cache_batch_sampler(params): seed_everything(42) - batch_sampler = CacheBatchSampler(params[0], 1, 0, 3, 3, False, True, cache) + batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache) batches_1 = [] for batch in batch_sampler: batches_1.extend(batch) - size = 0 - for interval in batch_sampler._shuffled_chunk_intervals: - interval_indices = np.arange(interval[0], interval[1]) - for indice in interval_indices: - assert indice in batches_1[size : size + len(interval_indices)] - size += len(interval_indices) - - assert len(batches_1) == params[0] + def validate_batch(data): + chunks = batch_sampler._shuffled_chunk_intervals + if params[1] == 1: + size = 0 + for interval in chunks: + interval_indices = np.arange(interval[0], interval[1]) + for indice in interval_indices: + assert indice in data[size : size + len(interval_indices)] + size += len(interval_indices) + else: + chunks_per_replica = len(chunks) // params[1] + for 
replica_idx in range(params[1]): + if replica_idx != 0: + continue + is_last_replica = replica_idx == params[1] - 1 + start_replica = replica_idx * chunks_per_replica + end_replica = len(chunks) if is_last_replica else (replica_idx + 1) * chunks_per_replica + shuffled_chunk_intervals_replica = chunks[start_replica:end_replica] + + assert len(shuffled_chunk_intervals_replica) + + size = 0 + for interval in shuffled_chunk_intervals_replica: + interval_indices = np.arange(interval[0], interval[1]) + for indice in interval_indices: + assert indice in data[size : size + len(interval_indices)] + size += len(interval_indices) + + validate_batch(batches_1) + if params[1] == 1: + assert len(batches_1) == params[0] batches_2 = [] for batch in batch_sampler: batches_2.extend(batch) - size = 0 - for interval in batch_sampler._shuffled_chunk_intervals: - interval_indices = np.arange(interval[0], interval[1]) - for indice in interval_indices: - assert indice in batches_2[size : size + len(interval_indices)] - size += len(interval_indices) - - assert batches_1 != batches_2 + validate_batch(batches_2) + if params[1] == 1: + assert batches_1 != batches_2 From de58298ad879c8440ab0045b9bdc9643900b9d5d Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 17:46:51 +0100 Subject: [PATCH 42/84] update --- src/lightning/data/cache/dataloader.py | 42 ++++++++++++++++++++++++-- tests/tests_data/cache/test_cache.py | 10 ++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index 111509ca8cfe2..c602d2eff3140 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -12,8 +12,11 @@ # limitations under the License. import logging +import os +from datetime import datetime +from typing import Optional -from torch.utils.data import IterableDataset +from torch.utils.data import Dataset, IterableDataset from torch.utils.data._utils.collate import default_collate from torch.utils.data.dataloader import ( DataLoader, @@ -29,6 +32,27 @@ logger = logging.Logger(__name__) +class CacheDataset(Dataset): + def __init__( + self, dataset: Dataset, cache_dir: Optional[str], chunk_size: Optional[int], compression: Optional[str] + ): + self._datataset = dataset + if cache_dir is None: + cache_dir = os.path.join(os.getcwd(), "cache_dir", datetime.now().strftime("%m-%d-%Y-%H-%M")) + os.makedirs(cache_dir, exist_ok=True) + chunk_size = 2 << 26 + self.cache = Cache(cache_dir, chunk_size=chunk_size, compression=compression) + + def __len__(self) -> int: + return len(self.cache) if self.cache.filled else len(self._datataset) + + def __getitem__(self, index): + data = self.cache[index] if self.cache.filled else self._datataset[index] + if not self.cache.filled: + self.cache[index] = data + return data + + class CacheCollateFn: def __init__(self): self.collate_fn = default_collate @@ -75,6 +99,9 @@ def __init__( generator=None, batch_size=1, drop_last=False, + cache_dir: Optional[str] = None, + chunk_size: Optional[int] = 2 << 26, + compression: Optional[str] = None, **kwargs, ): if sampler: @@ -88,11 +115,20 @@ def __init__( cache = [v for v in dataset.__dict__.values() if isinstance(v, Cache)] - if not cache or len(cache) > 1: + if len(cache) > 1: raise Exception(f"The CacheDataloader should be used with a dataset using a single cache. 
Found {cache}.") - cache = cache[0] + if len(cache) == 0: + if cache_dir is None: + logger.info("You can provide a `cache_dir` filepath to the CacheDataLoader.") + + dataset = CacheDataset(dataset, cache_dir, chunk_size, compression) + cache = dataset.cache + else: + cache = cache[0] + cache._setup(num_workers) + if not cache.filled and shuffle: logger.info("Shuffle is ignored during caching phase") diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index 28fc823838006..47fb7a8bd494d 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -18,10 +18,12 @@ import numpy as np import pytest +import torch from lightning import seed_everything from lightning.data.cache import Cache, CacheDataLoader from lightning.data.datasets.env import _DistributedEnv from lightning.fabric import Fabric +from lightning.pytorch.demos.boring_classes import RandomDataset from lightning_utilities.core.imports import RequirementCache from torch.utils.data import Dataset @@ -177,3 +179,11 @@ def test_cache_with_simple_format(tmpdir): for i in range(100): assert [i, {0: [i + 1]}] == cache[i] + + +def test_cache_with_auto_wrapping(tmpdir): + dataset = RandomDataset(64, 64) + dataloader = CacheDataLoader(dataset, cache_dir=tmpdir) + for batch in dataloader: + assert isinstance(batch, torch.Tensor) + assert len(os.listdir(tmpdir)) == 2 From c13f948ff41bd7d0e465c73d5096985adce77a88 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 17:50:07 +0100 Subject: [PATCH 43/84] update --- tests/tests_data/cache/test_cache.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index 47fb7a8bd494d..9555a3b3e525e 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -186,4 +186,4 @@ def test_cache_with_auto_wrapping(tmpdir): dataloader = CacheDataLoader(dataset, cache_dir=tmpdir) for batch in dataloader: assert isinstance(batch, torch.Tensor) - assert len(os.listdir(tmpdir)) == 2 + assert sorted(os.listdir(tmpdir)) == ["0.index.json", "chunk-0-0.bin"] From b03afd034c9d74cf9de9e974fa248b852ebe1825 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 29 Sep 2023 17:58:33 +0100 Subject: [PATCH 44/84] update --- tests/tests_data/cache/test_cache.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index 9555a3b3e525e..1d60c7b22f8c3 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -183,7 +183,8 @@ def test_cache_with_simple_format(tmpdir): def test_cache_with_auto_wrapping(tmpdir): dataset = RandomDataset(64, 64) - dataloader = CacheDataLoader(dataset, cache_dir=tmpdir) + dataloader = CacheDataLoader(dataset, cache_dir=tmpdir, chunk_size=2 << 12) for batch in dataloader: assert isinstance(batch, torch.Tensor) - assert sorted(os.listdir(tmpdir)) == ["0.index.json", "chunk-0-0.bin"] + assert sorted(os.listdir(tmpdir)) == ["0.index.json", "chunk-0-0.bin", "chunk-0-1.bin", "chunk-0-2.bin"] + # Your dataset is optimised for the cloud From ab1c0e13d0f6607db93c0110a5867e589c37d848 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 2 Oct 2023 09:37:56 +0100 Subject: [PATCH 45/84] update --- src/lightning/data/cache/__init__.py | 4 +- src/lightning/data/cache/cache.py | 21 +++++++++- src/lightning/data/cache/dataloader.py | 57 ++++++++++++++++++++------ src/lightning/data/cache/reader.py | 27 
++++++++---- src/lightning/data/cache/sampler.py | 41 ++++++++++-------- src/lightning/data/cache/writer.py | 34 +++++++-------- tests/tests_data/cache/test_cache.py | 38 ++++++++++++++--- tests/tests_data/cache/test_sampler.py | 4 +- tests/tests_data/cache/test_writer.py | 18 ++++---- 9 files changed, 168 insertions(+), 76 deletions(-) diff --git a/src/lightning/data/cache/__init__.py b/src/lightning/data/cache/__init__.py index 996936c877edf..1f9601debc26a 100644 --- a/src/lightning/data/cache/__init__.py +++ b/src/lightning/data/cache/__init__.py @@ -12,6 +12,6 @@ # limitations under the License. from lightning.data.cache.cache import Cache -from lightning.data.cache.dataloader import CacheDataLoader +from lightning.data.cache.dataloader import LightningDataLoader -__all__ = ["Cache", "CacheDataLoader"] +__all__ = ["Cache", "LightningDataLoader"] diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 6d2e4093575d0..bd0f2fd697f61 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -15,6 +15,9 @@ import os from typing import Any, Dict, Optional +import torch + +from lightning.data.cache.pytree import tree_flatten from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter from lightning.data.datasets.env import _DistributedEnv @@ -46,15 +49,29 @@ def __init__( self._distributed_env = _DistributedEnv.detect() self._num_workers: Optional[int] = None + def _is_item_equal(self, data_1: Any, data_2: Any) -> bool: + data_1_flattened, _ = tree_flatten(data_1) + data_2_flattened, _ = tree_flatten(data_2) + + if len(data_1_flattened) != len(data_2_flattened): + return False + + return all(self._is_data_equal(d1, d2) for d1, d2 in zip(data_1_flattened, data_2_flattened)) + + def _is_data_equal(self, d1, d2) -> bool: + if isinstance(d1, torch.Tensor) and isinstance(d2, torch.Tensor): + return torch.equal(d1, d2) + return d1 == d2 + def _setup(self, num_workers: int) -> None: - """Called by the CacheDataLoader to ensure the num_workers is known.""" + """Called by the LightningDataLoader to ensure the num_workers is known.""" self._num_workers = num_workers @property def filled(self) -> bool: """Returns whether the caching phase is done.""" if self._num_workers is None: - raise Exception("The Cache wasn't setup properly. HINT: Did you use the CacheDataLoader ?") + raise Exception("The Cache wasn't setup properly. 
HINT: Did you use the LightningDataLoader ?") if self._is_done: return True files = os.listdir(self._cache_dir) diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index c602d2eff3140..4ee0e7b06eb7f 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -14,8 +14,9 @@ import logging import os from datetime import datetime -from typing import Optional +from typing import Any, Optional +import torch from torch.utils.data import Dataset, IterableDataset from torch.utils.data._utils.collate import default_collate from torch.utils.data.dataloader import ( @@ -26,12 +27,31 @@ ) from lightning.data.cache import Cache +from lightning.data.cache.pytree import tree_flatten from lightning.data.cache.sampler import CacheBatchSampler from lightning.data.datasets.env import _DistributedEnv logger = logging.Logger(__name__) +def _equal_items(data_1: Any, data_2: Any) -> bool: + data_1_flattened, _ = tree_flatten(data_1) + data_2_flattened, _ = tree_flatten(data_2) + + if len(data_1_flattened) != len(data_2_flattened): + return False + + return all(_equal_item(d1, d2) for d1, d2 in zip(data_1_flattened, data_2_flattened)) + + +def _equal_item(d1, d2) -> bool: + if not isinstance(d1, type(d2)): + raise False + if isinstance(d1, torch.Tensor) and isinstance(d2, torch.Tensor): + return torch.equal(d1, d2) + return d1 == d2 + + class CacheDataset(Dataset): def __init__( self, dataset: Dataset, cache_dir: Optional[str], chunk_size: Optional[int], compression: Optional[str] @@ -42,15 +62,24 @@ def __init__( os.makedirs(cache_dir, exist_ok=True) chunk_size = 2 << 26 self.cache = Cache(cache_dir, chunk_size=chunk_size, compression=compression) + self.is_deterministic = False def __len__(self) -> int: return len(self.cache) if self.cache.filled else len(self._datataset) def __getitem__(self, index): - data = self.cache[index] if self.cache.filled else self._datataset[index] + data_1 = self.cache[index] if self.cache.filled else self._datataset[index] if not self.cache.filled: - self.cache[index] = data - return data + if not self.is_deterministic: + data2 = self._datataset[index] + if not _equal_items(data_1, data2): + raise ValueError( + f"Your dataset items aren't deterministic. Found {data_1} and {data2} for index {index}." + " HINT: Use the `lightning.data.cache.Cache` directly within your dataset." + ) + self.is_deterministic = True + self.cache[index] = data_1 + return data_1 class CacheCollateFn: @@ -87,7 +116,7 @@ def __init__(self, loader): super().__init__(loader) -class CacheDataLoader(DataLoader): +class LightningDataLoader(DataLoader): def __init__( self, dataset, @@ -105,22 +134,26 @@ def __init__( **kwargs, ): if sampler: - raise Exception("Passing a sampler isn't supported with the CacheDataLoader.") + raise Exception( + "The LightningDataLoader relies on its own internal sampler. Passing a sampler isn't supported." + ) if batch_sampler: - raise Exception("Passing a batch_sampler isn't supported with the CacheDataLoader.") + raise Exception( + "The LightningDataLoader relies on its own internal sampler. Passing a batch_sampler isn't supported." 
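A note on the determinism guard above: before an item is written to the cache, the wrapped dataset re-reads the same index and compares the two results leaf by leaf after `tree_flatten`, using `torch.equal` for tensors. The sketch below illustrates that comparison with made-up samples; it assumes `_equal_items` stays importable from `lightning.data.cache.dataloader` as defined in this commit.

import torch

from lightning.data.cache.dataloader import _equal_items

# Two reads of a deterministic item compare equal leaf by leaf.
sample_a = {"image": torch.ones(2, 2), "label": 3}
sample_b = {"image": torch.ones(2, 2), "label": 3}
assert _equal_items(sample_a, sample_b)

# A dataset that re-samples on every call (e.g. torch.randn per __getitem__)
# fails the check, and LightningDataLoader raises a ValueError suggesting to
# use the Cache directly inside the dataset instead.
assert not _equal_items({"image": torch.randn(2, 2)}, {"image": torch.randn(2, 2)})
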
+ ) if isinstance(dataset, IterableDataset): - raise Exception("Only map-based dataset are supported by the CacheDataLoader for now.") + raise Exception("Only map-based dataset are supported by the LightningDataLoader for now.") cache = [v for v in dataset.__dict__.values() if isinstance(v, Cache)] if len(cache) > 1: - raise Exception(f"The CacheDataloader should be used with a dataset using a single cache. Found {cache}.") + raise Exception("We found several Cache used as attributes from your dataset. Only one is support for now.") if len(cache) == 0: if cache_dir is None: - logger.info("You can provide a `cache_dir` filepath to the CacheDataLoader.") + logger.info("You can provide a `cache_dir` filepath to the LightningDataLoader.") dataset = CacheDataset(dataset, cache_dir, chunk_size, compression) cache = dataset.cache @@ -130,7 +163,7 @@ def __init__( cache._setup(num_workers) if not cache.filled and shuffle: - logger.info("Shuffle is ignored during caching phase") + logger.info("Shuffle is ignored during the caching phase phase") distributed_env = _DistributedEnv.detect() batch_sampler = CacheBatchSampler( @@ -156,7 +189,7 @@ def __init__( ) def _get_iterator(self) -> "_BaseDataLoaderIter": - """Overriden to ensure the `Cache.done` method is triggered on iteration done.""" + """Overriden to ensure the `Cache.done()` method is triggered on iteration done.""" if self.num_workers == 0: return _SingleProcessDataLoaderIterPatch(self) self.check_worker_number_rationality() diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 7395639b2f402..4ffcaae5d1fee 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -13,11 +13,12 @@ import json import os -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union import numpy as np from lightning.data.cache.pytree import tree_unflatten, treespec_loads +from lightning.data.cache.sampler import BatchIndex from lightning.data.cache.serializers import _SERIALIZERS, Serializer from lightning.data.datasets.env import _DistributedEnv @@ -66,13 +67,19 @@ def data_format(self): def config(self): return self._config - def __getitem__(self, index: int) -> Tuple[str, int, int]: + def __getitem__(self, index: Union[int, BatchIndex]) -> Tuple[str, int, int]: """Find the associated chunk metadata.""" - for interval_config, internal in enumerate(self._intervals): - if internal[0] <= index and index < internal[1]: - chunk = self._chunks[interval_config] - mapping = chunk["mapping"][str(index)] - return os.path.join(self._cache_dir, chunk["filename"]), *mapping + if isinstance(index, int): + for interval_config, internal in enumerate(self._intervals): + if internal[0] <= index and index < internal[1]: + chunk = self._chunks[interval_config] + mapping = chunk["mapping"][str(index)] + return os.path.join(self._cache_dir, chunk["filename"]), *mapping + # Note: Optimisation to avoid doing the interval search. + elif isinstance(index, BatchIndex): + chunk = self._chunks[index.chunk_index] + mapping = chunk["mapping"][str(index.index)] + return os.path.join(self._cache_dir, chunk["filename"]), *mapping raise Exception(f"The chunk interval weren't properly defined. 
Found {self._intervals} for index {index}.") @classmethod @@ -117,10 +124,12 @@ def _try_load_config(self): """Try to load the chunks config if the index files are available.""" self._config = ChunksConfig.load(self._cache_dir) - def read(self, index: int): + def read(self, index: Union[int, BatchIndex]): """Read an item for the given from a chunk. - If the chunk isn't available, it will be downloaded. + If the chunk isn't available locally or in memory, it will be downloaded. + + Prefetching should reduce the wait time to be the batch available. """ if self._config is None: diff --git a/src/lightning/data/cache/sampler.py b/src/lightning/data/cache/sampler.py index 26087cfddb528..06df9b34eb3be 100644 --- a/src/lightning/data/cache/sampler.py +++ b/src/lightning/data/cache/sampler.py @@ -12,14 +12,13 @@ # limitations under the License. import logging +from dataclasses import dataclass from typing import Iterator, List import numpy as np from torch.utils.data.distributed import DistributedSampler from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized -from lightning.data.cache import Cache - logger = logging.Logger(__name__) @@ -47,7 +46,7 @@ def __init__(self, dataset_size: int): super().__init__(None) self.dataset_size = dataset_size self.worker_id = 0 - self.indice_id = 0 + self.index_id = 0 self.iterators = [] self._done = set() @@ -81,14 +80,14 @@ def __next__(self) -> List[int]: while len(self._done) != self.iterators: try: data = next(self.iterators[self.worker_id]) - self.indice_id += 1 - if self.indice_id == self.batch_size: - self.indice_id = 0 + self.index_id += 1 + if self.index_id == self.batch_size: + self.index_id = 0 self._next_worker_id() return data except StopIteration: self._done.add(self.worker_id) - self.indice_id = 0 + self.index_id = 0 self._next_worker_id() raise StopIteration @@ -121,7 +120,7 @@ def __init__(self, dataset_size: int, num_workers: int, batch_size: int): self._done = set() assert sum([len(s) for s in self.samplers]) == dataset_size self.worker_id = 0 - self.indice_id = 0 + self.index_id = 0 class DistributedCacheSampler(BaseCacheSampler): @@ -165,7 +164,13 @@ def __init__(self, dataset_size: int, num_replicas: int, rank: int, num_workers: assert sum([len(s) for s in self.samplers]) == replica_size self.worker_id = 0 - self.indice_id = 0 + self.index_id = 0 + + +@dataclass +class BatchIndex: + index: int + chunk_index: int class CacheBatchSampler(BatchSampler): @@ -178,7 +183,7 @@ def __init__( batch_size: int, drop_last: bool, shuffle: bool, - cache: Cache, + cache: any, ): """The CacheBatchSampler handles the generation of batch indices. 
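The `BatchIndex` introduced here lets the reader skip the interval scan entirely: the sampler already knows which chunk an index came from, so `ChunksConfig.__getitem__` can jump straight to that chunk's mapping. A standalone sketch of the two lookup paths, with made-up chunk metadata (filenames, intervals and begin/end byte offsets are illustrative):

from lightning.data.cache.sampler import BatchIndex

chunks = [
    {"filename": "chunk-0-0.bin", "interval": [0, 3], "mapping": {"0": [0, 8], "1": [8, 16], "2": [16, 24]}},
    {"filename": "chunk-0-1.bin", "interval": [3, 6], "mapping": {"3": [0, 8], "4": [8, 16], "5": [16, 24]}},
]


def locate(index):
    if isinstance(index, BatchIndex):
        chunk = chunks[index.chunk_index]  # direct lookup, no interval search
        return chunk["filename"], chunk["mapping"][str(index.index)]
    for chunk in chunks:  # plain ints fall back to the interval scan
        start, end = chunk["interval"]
        if start <= index < end:
            return chunk["filename"], chunk["mapping"][str(index)]
    raise IndexError(index)


assert locate(4) == locate(BatchIndex(index=4, chunk_index=1))
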
@@ -243,14 +248,15 @@ def __iter__(self): def __iter_from_chunks__(self): chunk_intervals = self._cache.get_chunk_interval() - self._shuffled_chunk_intervals = np.random.permutation(chunk_intervals) + shuffled_indices = np.random.permutation(range(len(chunk_intervals))) + self._shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indices] if self._num_replicas == 1: indices = [] - for interval in self._shuffled_chunk_intervals: + for interval, chunk_index in zip(self._shuffled_chunk_intervals, shuffled_indices): interval_indices = np.arange(interval[0], interval[1]) - shuffled_interval_indices = np.random.permutation(interval_indices) - indices.extend(shuffled_interval_indices.tolist()) + shuffled_interval_indices = np.random.permutation(interval_indices).tolist() + indices.extend([BatchIndex(index, chunk_index) for index in shuffled_interval_indices]) if len(indices) != len(self.sampler): raise Exception("The generated indices don't match the initial length of the sampler.") @@ -264,12 +270,13 @@ def __iter_from_chunks__(self): start_replica = replica_idx * chunks_per_replica end_replica = len(chunk_intervals) if is_last_replica else (replica_idx + 1) * chunks_per_replica shuffled_chunk_intervals_replica = self._shuffled_chunk_intervals[start_replica:end_replica] + shuffled_indices_replica = shuffled_indices[start_replica:end_replica] indices = [] - for interval in shuffled_chunk_intervals_replica: + for interval, chunk_index in zip(shuffled_chunk_intervals_replica, shuffled_indices_replica): interval_indices = np.arange(interval[0], interval[1]) - shuffled_interval_indices = np.random.permutation(interval_indices) - indices.extend(shuffled_interval_indices.tolist()) + shuffled_interval_indices = np.random.permutation(interval_indices).tolist() + indices.extend([BatchIndex(index, chunk_index) for index in shuffled_interval_indices]) self.sampler = IteratorSampler(indices) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 1dafac0c14247..7bc06c3feccc3 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -34,6 +34,20 @@ def cloud_path(cache_dir: str) -> Optional[str]: return f"s3://{cluster_id}/projects/{project_id}/cloudspaces/{cloud_space_id}/content/{cache_dir}/" +def get_cloud_path(cache_dir: str) -> Optional[str]: + """Returns the s3 URL to the cache_dir.""" + cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) + project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) + cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) + + if cluster_id is None or project_id is None or cloud_space_id is None: + return None + cache_dir = cache_dir.replace("~/", "").replace("~", "").replace("/teamspace/studios/this_studio/", "") + if cache_dir.startswith("/"): + cache_dir = cache_dir[1:] + return os.path.join(f"s3://{cluster_id}/projects/{project_id}/cloudspaces/{cloud_space_id}/code/content", cache_dir) + + class BinaryWriter: def __init__( self, @@ -107,7 +121,7 @@ def get_config(self) -> Dict[str, Any]: "data_format": self._data_format, "data_spec": treespec_dumps(self._data_spec) if self._data_spec else None, } - cloud_path = self.get_cloud_path(self._cache_dir) + cloud_path = get_cloud_path(self._cache_dir) if cloud_path: out["cloud_path"] = cloud_path user_id = os.getenv("LIGHTNING_USER_ID", None) @@ -201,6 +215,8 @@ def __setitem__(self, index, items: any): The index needs to be provided in order. + This is handled by the samplers automatically. 
This ensures we can map an index to a shard from an interval. + """ serialized_items = self.serialize(items) serialized_items_size = len(serialized_items) @@ -255,19 +271,3 @@ def done(self): self.write_chunks_index() self.reset() self._is_done = True - - @classmethod - def get_cloud_path(cls, cache_dir: str) -> Optional[str]: - """Returns the s3 URL to the cache_dir.""" - cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) - project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) - cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) - - if cluster_id is None or project_id is None or cloud_space_id is None: - return None - cache_dir = cache_dir.replace("~/", "").replace("~", "").replace("/teamspace/studios/this_studio/", "") - if cache_dir.startswith("/"): - cache_dir = cache_dir[1:] - return os.path.join( - f"s3://{cluster_id}/projects/{project_id}/cloudspaces/{cloud_space_id}/code/content", cache_dir - ) diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index 1d60c7b22f8c3..bcd02427a7cb8 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -20,7 +20,8 @@ import pytest import torch from lightning import seed_everything -from lightning.data.cache import Cache, CacheDataLoader +from lightning.data.cache import Cache +from lightning.data.cache.dataloader import LightningDataLoader from lightning.data.datasets.env import _DistributedEnv from lightning.fabric import Fabric from lightning.pytorch.demos.boring_classes import RandomDataset @@ -79,7 +80,7 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): cache = Cache(cache_dir, chunk_size=2 << 12) dataset = ImageDataset(tmpdir, cache, dataset_size, 10) - dataloader = CacheDataLoader(dataset, num_workers=num_workers, batch_size=4) + dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4) for _ in dataloader: pass @@ -101,14 +102,14 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): if distributed_env.world_size == 1: indexes = [] - for batch in CacheDataLoader(dataset, num_workers=num_workers, batch_size=4): + for batch in LightningDataLoader(dataset, num_workers=num_workers, batch_size=4): indexes.extend(batch["index"].numpy().tolist()) assert len(indexes) == dataset_size seed_everything(42) - dataloader = CacheDataLoader(dataset, num_workers=num_workers, batch_size=4, shuffle=True) + dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4, shuffle=True) dataloader_iter = iter(dataloader) indexes = [] @@ -182,9 +183,34 @@ def test_cache_with_simple_format(tmpdir): def test_cache_with_auto_wrapping(tmpdir): + os.makedirs(os.path.join(tmpdir, "cache_1"), exist_ok=True) + dataset = RandomDataset(64, 64) - dataloader = CacheDataLoader(dataset, cache_dir=tmpdir, chunk_size=2 << 12) + dataloader = LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_1"), chunk_size=2 << 12) for batch in dataloader: assert isinstance(batch, torch.Tensor) - assert sorted(os.listdir(tmpdir)) == ["0.index.json", "chunk-0-0.bin", "chunk-0-1.bin", "chunk-0-2.bin"] + assert sorted(os.listdir(os.path.join(tmpdir, "cache_1"))) == [ + "0.index.json", + "chunk-0-0.bin", + "chunk-0-1.bin", + "chunk-0-2.bin", + ] # Your dataset is optimised for the cloud + + class RandomDatasetAtRuntime(Dataset): + def __init__(self, size: int, length: int): + self.len = length + self.size = size + + def __getitem__(self, index: int) -> torch.Tensor: + return torch.randn(1, self.size) + + def __len__(self) -> 
int: + return self.len + + os.makedirs(os.path.join(tmpdir, "cache_2"), exist_ok=True) + dataset = RandomDatasetAtRuntime(64, 64) + dataloader = LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_2"), chunk_size=2 << 12) + with pytest.raises(ValueError, match="Your dataset items aren't deterministic"): + for batch in dataloader: + pass diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py index dc7eb65db7e92..b99a914f0083e 100644 --- a/tests/tests_data/cache/test_sampler.py +++ b/tests/tests_data/cache/test_sampler.py @@ -179,7 +179,7 @@ def validate_batch(data): for interval in chunks: interval_indices = np.arange(interval[0], interval[1]) for indice in interval_indices: - assert indice in data[size : size + len(interval_indices)] + assert indice in [b.index for b in data[size : size + len(interval_indices)]] size += len(interval_indices) else: chunks_per_replica = len(chunks) // params[1] @@ -197,7 +197,7 @@ def validate_batch(data): for interval in shuffled_chunk_intervals_replica: interval_indices = np.arange(interval[0], interval[1]) for indice in interval_indices: - assert indice in data[size : size + len(interval_indices)] + assert indice in [b.index for b in data[size : size + len(interval_indices)]] size += len(interval_indices) validate_batch(batches_1) diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index da941150341c0..eeb1a525bf403 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -18,7 +18,7 @@ import numpy as np import pytest from lightning.data.cache.reader import BinaryReader -from lightning.data.cache.writer import BinaryWriter +from lightning.data.cache.writer import BinaryWriter, get_cloud_path from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") @@ -93,7 +93,7 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): @pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") def test_binary_writer_config(monkeypatch): - assert BinaryWriter.get_cloud_path("") is None + assert get_cloud_path("") is None monkeypatch.setenv("LIGHTNING_CLUSTER_ID", "cluster_id") monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id") @@ -101,10 +101,10 @@ def test_binary_writer_config(monkeypatch): prefix = "s3://cluster_id/projects/project_id/cloudspaces/cloud_space_id/code/content/" - assert BinaryWriter.get_cloud_path("") == prefix - assert BinaryWriter.get_cloud_path("~") == prefix - assert BinaryWriter.get_cloud_path("~/") == prefix - assert BinaryWriter.get_cloud_path("/") == prefix - assert BinaryWriter.get_cloud_path("/data") == f"{prefix}data" - assert BinaryWriter.get_cloud_path("~/data") == f"{prefix}data" - assert BinaryWriter.get_cloud_path("/teamspace/studios/this_studio/data") == f"{prefix}data" + assert get_cloud_path("") == prefix + assert get_cloud_path("~") == prefix + assert get_cloud_path("~/") == prefix + assert get_cloud_path("/") == prefix + assert get_cloud_path("/data") == f"{prefix}data" + assert get_cloud_path("~/data") == f"{prefix}data" + assert get_cloud_path("/teamspace/studios/this_studio/data") == f"{prefix}data" From 675cd45fb5f80df0525a35942e97f3900c2ca4fa Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 2 Oct 2023 09:44:06 +0100 Subject: [PATCH 46/84] Update src/lightning/data/cache/reader.py Co-authored-by: Luca Antiga --- src/lightning/data/cache/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 4ffcaae5d1fee..817b2e515d0f3 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -185,7 +185,7 @@ def get_length(self) -> int: return len(self._config) def get_chunk_interval(self): - """Get the index interval of each chunks.""" + """Get the index interval of each chunk.""" if self._config is None: self._try_load_config() From aec12be6fbb40ae192274e184c0a3505804b19ea Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 2 Oct 2023 15:11:43 +0100 Subject: [PATCH 47/84] update --- src/lightning/data/cache/cache.py | 17 ---------- src/lightning/data/cache/dataloader.py | 43 +++++++++++++------------- src/lightning/data/cache/writer.py | 30 ------------------ tests/tests_data/cache/test_writer.py | 22 +------------ 4 files changed, 22 insertions(+), 90 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index bd0f2fd697f61..f1796139677ae 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -15,9 +15,6 @@ import os from typing import Any, Dict, Optional -import torch - -from lightning.data.cache.pytree import tree_flatten from lightning.data.cache.reader import BinaryReader from lightning.data.cache.writer import BinaryWriter from lightning.data.datasets.env import _DistributedEnv @@ -49,20 +46,6 @@ def __init__( self._distributed_env = _DistributedEnv.detect() self._num_workers: Optional[int] = None - def _is_item_equal(self, data_1: Any, data_2: Any) -> bool: - data_1_flattened, _ = tree_flatten(data_1) - data_2_flattened, _ = tree_flatten(data_2) - - if len(data_1_flattened) != len(data_2_flattened): - return False - - return all(self._is_data_equal(d1, d2) for d1, d2 in zip(data_1_flattened, data_2_flattened)) - - def _is_data_equal(self, d1, d2) -> bool: - if isinstance(d1, torch.Tensor) and isinstance(d2, torch.Tensor): - return torch.equal(d1, d2) - return d1 == d2 - def _setup(self, num_workers: int) -> None: """Called by the LightningDataLoader to ensure the num_workers is known.""" self._num_workers = num_workers diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index 4ee0e7b06eb7f..bee6ae45d50fc 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -12,8 +12,6 @@ # limitations under the License. 
import logging -import os -from datetime import datetime from typing import Any, Optional import torch @@ -47,9 +45,10 @@ def _equal_items(data_1: Any, data_2: Any) -> bool: def _equal_item(d1, d2) -> bool: if not isinstance(d1, type(d2)): raise False - if isinstance(d1, torch.Tensor) and isinstance(d2, torch.Tensor): - return torch.equal(d1, d2) - return d1 == d2 + equality = d1 == d2 + if isinstance(equality, torch.Tensor): + return equality.all() + return equality class CacheDataset(Dataset): @@ -57,28 +56,24 @@ def __init__( self, dataset: Dataset, cache_dir: Optional[str], chunk_size: Optional[int], compression: Optional[str] ): self._datataset = dataset - if cache_dir is None: - cache_dir = os.path.join(os.getcwd(), "cache_dir", datetime.now().strftime("%m-%d-%Y-%H-%M")) - os.makedirs(cache_dir, exist_ok=True) - chunk_size = 2 << 26 - self.cache = Cache(cache_dir, chunk_size=chunk_size, compression=compression) - self.is_deterministic = False + self._cache = Cache(cache_dir, chunk_size=chunk_size, compression=compression) + self._is_deterministic = False def __len__(self) -> int: - return len(self.cache) if self.cache.filled else len(self._datataset) + return len(self._cache) if self._cache.filled else len(self._datataset) def __getitem__(self, index): - data_1 = self.cache[index] if self.cache.filled else self._datataset[index] - if not self.cache.filled: - if not self.is_deterministic: + data_1 = self._cache[index] if self._cache.filled else self._datataset[index] + if not self._cache.filled: + if not self._is_deterministic: data2 = self._datataset[index] if not _equal_items(data_1, data2): raise ValueError( f"Your dataset items aren't deterministic. Found {data_1} and {data2} for index {index}." " HINT: Use the `lightning.data.cache.Cache` directly within your dataset." ) - self.is_deterministic = True - self.cache[index] = data_1 + self._is_deterministic = True + self._cache[index] = data_1 return data_1 @@ -117,6 +112,8 @@ def __init__(self, loader): class LightningDataLoader(DataLoader): + __doc__ = DataLoader.__doc__ + def __init__( self, dataset, @@ -134,26 +131,28 @@ def __init__( **kwargs, ): if sampler: - raise Exception( + raise ValueError( "The LightningDataLoader relies on its own internal sampler. Passing a sampler isn't supported." ) if batch_sampler: - raise Exception( + raise ValueError( "The LightningDataLoader relies on its own internal sampler. Passing a batch_sampler isn't supported." ) if isinstance(dataset, IterableDataset): - raise Exception("Only map-based dataset are supported by the LightningDataLoader for now.") + raise ValueError("Only map-based dataset are supported by the LightningDataLoader for now.") cache = [v for v in dataset.__dict__.values() if isinstance(v, Cache)] if len(cache) > 1: - raise Exception("We found several Cache used as attributes from your dataset. Only one is support for now.") + raise ValueError( + "We found several Cache used as attributes from your dataset. Only one is support for now." 
+ ) if len(cache) == 0: if cache_dir is None: - logger.info("You can provide a `cache_dir` filepath to the LightningDataLoader.") + raise ValueError("You can provide a `cache_dir` filepath to the LightningDataLoader.") dataset = CacheDataset(dataset, cache_dir, chunk_size, compression) cache = dataset.cache diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 7bc06c3feccc3..418a3968b197c 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -24,30 +24,6 @@ from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv -def cloud_path(cache_dir: str) -> Optional[str]: - cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) - project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) - cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) - - if cluster_id is None or project_id is None or cloud_space_id is None: - return None - return f"s3://{cluster_id}/projects/{project_id}/cloudspaces/{cloud_space_id}/content/{cache_dir}/" - - -def get_cloud_path(cache_dir: str) -> Optional[str]: - """Returns the s3 URL to the cache_dir.""" - cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) - project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) - cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) - - if cluster_id is None or project_id is None or cloud_space_id is None: - return None - cache_dir = cache_dir.replace("~/", "").replace("~", "").replace("/teamspace/studios/this_studio/", "") - if cache_dir.startswith("/"): - cache_dir = cache_dir[1:] - return os.path.join(f"s3://{cluster_id}/projects/{project_id}/cloudspaces/{cloud_space_id}/code/content", cache_dir) - - class BinaryWriter: def __init__( self, @@ -121,12 +97,6 @@ def get_config(self) -> Dict[str, Any]: "data_format": self._data_format, "data_spec": treespec_dumps(self._data_spec) if self._data_spec else None, } - cloud_path = get_cloud_path(self._cache_dir) - if cloud_path: - out["cloud_path"] = cloud_path - user_id = os.getenv("LIGHTNING_USER_ID", None) - if user_id: - out["user_id"] = user_id return out def serialize(self, items: Any) -> bytes: diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index eeb1a525bf403..d9e7ea2d0f8eb 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -13,12 +13,11 @@ import json import os -import sys import numpy as np import pytest from lightning.data.cache.reader import BinaryReader -from lightning.data.cache.writer import BinaryWriter, get_cloud_path +from lightning.data.cache.writer import BinaryWriter from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") @@ -89,22 +88,3 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): data = reader.read(i) assert data["x"] == imgs[i] assert data["y"] == i - - -@pytest.mark.skipif(condition=sys.platform == "win32", reason="Not supported on windows") -def test_binary_writer_config(monkeypatch): - assert get_cloud_path("") is None - - monkeypatch.setenv("LIGHTNING_CLUSTER_ID", "cluster_id") - monkeypatch.setenv("LIGHTNING_CLOUD_PROJECT_ID", "project_id") - monkeypatch.setenv("LIGHTNING_CLOUD_SPACE_ID", "cloud_space_id") - - prefix = "s3://cluster_id/projects/project_id/cloudspaces/cloud_space_id/code/content/" - - assert get_cloud_path("") == prefix - assert get_cloud_path("~") == prefix - assert get_cloud_path("~/") == prefix - assert get_cloud_path("/") == prefix - assert get_cloud_path("/data") == f"{prefix}data" - assert 
get_cloud_path("~/data") == f"{prefix}data" - assert get_cloud_path("/teamspace/studios/this_studio/data") == f"{prefix}data" From 48c0fdff40793df6e44a06d637a414cab5a86706 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 3 Oct 2023 10:37:49 +0000 Subject: [PATCH 48/84] update --- _notebooks | 1 - src/lightning/data/cache/cache.py | 11 +++++++---- src/lightning/data/cache/writer.py | 30 ++++++++++++++---------------- 3 files changed, 21 insertions(+), 21 deletions(-) delete mode 160000 _notebooks diff --git a/_notebooks b/_notebooks deleted file mode 160000 index 70821217ca0db..0000000000000 --- a/_notebooks +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 70821217ca0db0280af537002839dbb340f77d68 diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index bd0f2fd697f61..cc3099ef0f16c 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -30,19 +30,22 @@ def __init__( self, cache_dir: str, compression: Optional[str] = None, - chunk_size: int = 2 << 26, + chunk_size: Optional[int] = None, + chunk_bytes: Optional[int] = None, ): """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements together in order to accelerate fetching. Arguments: cache_dir: The path to where the chunks will be stored. - compression: The name of the algorithm to reduce the size of the chunks - chunk_size: The maximum byte size of chunk. + compression: The name of the algorithm to reduce the size of the chunks. + chunk_bytes: The maximum number of bytes within a chunk. + chunk_size: The maximum number of items within a chunk. + """ super().__init__() - self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, compression=compression) + self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression) self._reader = BinaryReader(cache_dir, compression=compression) self._cache_dir = cache_dir self._is_done = False diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 7bc06c3feccc3..3845d8c545ead 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -24,16 +24,6 @@ from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv -def cloud_path(cache_dir: str) -> Optional[str]: - cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) - project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID", None) - cloud_space_id = os.getenv("LIGHTNING_CLOUD_SPACE_ID", None) - - if cluster_id is None or project_id is None or cloud_space_id is None: - return None - return f"s3://{cluster_id}/projects/{project_id}/cloudspaces/{cloud_space_id}/content/{cache_dir}/" - - def get_cloud_path(cache_dir: str) -> Optional[str]: """Returns the s3 URL to the cache_dir.""" cluster_id = os.getenv("LIGHTNING_CLUSTER_ID", None) @@ -45,21 +35,23 @@ def get_cloud_path(cache_dir: str) -> Optional[str]: cache_dir = cache_dir.replace("~/", "").replace("~", "").replace("/teamspace/studios/this_studio/", "") if cache_dir.startswith("/"): cache_dir = cache_dir[1:] - return os.path.join(f"s3://{cluster_id}/projects/{project_id}/cloudspaces/{cloud_space_id}/code/content", cache_dir) + return os.path.join(f"s3://{cluster_id}/projects/{project_id}/cloudspaces/{cloud_space_id}/code/content", cache_dir, "/") class BinaryWriter: def __init__( self, cache_dir: str, - chunk_size: int = 1 << 26, + chunk_size: Optional[int] = None, + chunk_bytes: Optional[int] = None, compression: Optional[str] = None, ): """The BinaryWriter enables 
to chunk dataset into an efficient streaming format for cloud training. Arguments: cache_dir: The path to where the chunks will be saved. - chunk_size: The maximum number of bytes to store within a chunk. + chunk_bytes: The maximum number of bytes within a chunk. + chunk_size: The maximum number of items within a chunk. compression: The compression algorithm to use. """ @@ -68,8 +60,12 @@ def __init__( if not os.path.exists(self._cache_dir): raise FileNotFoundError(f"The provided cache directory `{self._cache_dir}` doesn't exist.") + if (chunk_size is None and chunk_bytes is None) or (chunk_size and chunk_bytes): + raise ValueError("Either one of the `chunk_size` or the `chunk_bytes` need to be provided.") + self._serializers: Dict[str, Serializer] = _SERIALIZERS self._chunk_size = chunk_size + self._chunk_bytes = chunk_bytes self._compression = compression self._data_format = None @@ -117,7 +113,7 @@ def get_config(self) -> Dict[str, Any]: """Returns the config of the writer.""" out = { "compression": self._compression, - "chunk_size": self._chunk_size, + "chunk_size": self._chunk_bytes, "data_format": self._data_format, "data_spec": treespec_dumps(self._data_spec) if self._data_spec else None, } @@ -221,10 +217,12 @@ def __setitem__(self, index, items: any): serialized_items = self.serialize(items) serialized_items_size = len(serialized_items) - if self._chunk_size < self._current_chunk_size + serialized_items_size: + should_write = (self._chunk_bytes and self._chunk_bytes < self._current_chunk_size + serialized_items_size) or (self._chunk_size and len(self._indexes) >= self._chunk_size) + + if should_write: if self._current_chunk_size == 0: raise Exception( - f"The provided chunk_size {self._chunk_size} is too small." + f"The provided chunk_size {self._chunk_bytes} is too small." f" You should use a multiple of {serialized_items_size} bytes." ) self.write_chunk() From a0f369658f4f89286af0b9b377dcd9d5683a980a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Oct 2023 10:39:43 +0000 Subject: [PATCH 49/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/cache/cache.py | 1 - src/lightning/data/cache/writer.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 479532ade9ce5..5b6c6d479dce9 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -39,7 +39,6 @@ def __init__( chunk_bytes: The maximum number of bytes within a chunk. chunk_size: The maximum number of items within a chunk. 
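The `should_write` condition introduced above accepts either budget: a byte limit (`chunk_bytes`) or an item-count limit (`chunk_size`). A condensed sketch of that decision, with illustrative argument names, is:

from typing import Optional


def should_flush(
    current_chunk_bytes: int,
    num_buffered_items: int,
    next_item_bytes: int,
    chunk_bytes: Optional[int],
    chunk_size: Optional[int],
) -> bool:
    # Flush when adding the next serialized item would exceed the byte budget ...
    over_bytes = chunk_bytes is not None and chunk_bytes < current_chunk_bytes + next_item_bytes
    # ... or when the number of buffered items has reached the item budget.
    over_items = chunk_size is not None and num_buffered_items >= chunk_size
    return over_bytes or over_items


assert should_flush(100, 3, 50, chunk_bytes=128, chunk_size=None)
assert should_flush(0, 25, 10, chunk_bytes=None, chunk_size=25)
assert not should_flush(10, 3, 10, chunk_bytes=128, chunk_size=25)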
- """ super().__init__() self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 97b74deaac8a7..dc99e1e7c2358 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -197,7 +197,9 @@ def __setitem__(self, index, items: any): serialized_items = self.serialize(items) serialized_items_size = len(serialized_items) - should_write = (self._chunk_bytes and self._chunk_bytes < self._current_chunk_size + serialized_items_size) or (self._chunk_size and len(self._indexes) >= self._chunk_size) + should_write = (self._chunk_bytes and self._chunk_bytes < self._current_chunk_size + serialized_items_size) or ( + self._chunk_size and len(self._indexes) >= self._chunk_size + ) if should_write: if self._current_chunk_size == 0: From 91e718cac470e574373977f2f2b8d341083412d0 Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 3 Oct 2023 11:45:51 +0100 Subject: [PATCH 50/84] update --- src/lightning/data/cache/writer.py | 4 ++- tests/tests_data/cache/test_cache.py | 10 ++++---- tests/tests_data/cache/test_writer.py | 35 ++++++++++++++++++++++++--- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index dc99e1e7c2358..349373be56f40 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -99,7 +99,8 @@ def get_config(self) -> Dict[str, Any]: """Returns the config of the writer.""" out = { "compression": self._compression, - "chunk_size": self._chunk_bytes, + "chunk_size": self._chunk_size, + "chunk_bytes": self._chunk_bytes, "data_format": self._data_format, "data_spec": treespec_dumps(self._data_spec) if self._data_spec else None, } @@ -158,6 +159,7 @@ def _create_chunk(self, filename: str) -> bytes: assert (self._indexes[-1] - self._indexes[0] + 1) == len(self._serialized_items) chunk_info = { + "chunk_bytes": self._current_chunk_size, "samples": len(self._serialized_items), "filename": filename, "mapping": mapping, diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index bcd02427a7cb8..47b1a27ee8fa7 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -78,7 +78,7 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): cache_dir = os.path.join(tmpdir, "cache") distributed_env = _DistributedEnv.detect() - cache = Cache(cache_dir, chunk_size=2 << 12) + cache = Cache(cache_dir, chunk_bytes=2 << 12) dataset = ImageDataset(tmpdir, cache, dataset_size, 10) dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4) @@ -158,7 +158,7 @@ def test_cache_with_simple_format(tmpdir): cache_dir = os.path.join(tmpdir, "cache1") os.makedirs(cache_dir) - cache = Cache(cache_dir, chunk_size=90) + cache = Cache(cache_dir, chunk_bytes=90) for i in range(100): cache[i] = i @@ -171,7 +171,7 @@ def test_cache_with_simple_format(tmpdir): cache_dir = os.path.join(tmpdir, "cache2") os.makedirs(cache_dir) - cache = Cache(cache_dir, chunk_size=90) + cache = Cache(cache_dir, chunk_bytes=90) for i in range(100): cache[i] = [i, {0: [i + 1]}] @@ -186,7 +186,7 @@ def test_cache_with_auto_wrapping(tmpdir): os.makedirs(os.path.join(tmpdir, "cache_1"), exist_ok=True) dataset = RandomDataset(64, 64) - dataloader = LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_1"), chunk_size=2 << 12) + dataloader = 
LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_1"), chunk_bytes=2 << 12) for batch in dataloader: assert isinstance(batch, torch.Tensor) assert sorted(os.listdir(os.path.join(tmpdir, "cache_1"))) == [ @@ -210,7 +210,7 @@ def __len__(self) -> int: os.makedirs(os.path.join(tmpdir, "cache_2"), exist_ok=True) dataset = RandomDatasetAtRuntime(64, 64) - dataloader = LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_2"), chunk_size=2 << 12) + dataloader = LightningDataLoader(dataset, cache_dir=os.path.join(tmpdir, "cache_2"), chunk_bytes=2 << 12) with pytest.raises(ValueError, match="Your dataset items aren't deterministic"): for batch in dataloader: pass diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index d9e7ea2d0f8eb..bfa57734583ef 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -23,14 +23,14 @@ _PIL_AVAILABLE = RequirementCache("PIL") -def test_binary_writer_with_ints(tmpdir): +def test_binary_writer_with_ints_and_chunk_bytes(tmpdir): with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."): BinaryWriter("dontexists", {}) with pytest.raises(ValueError, match="No compresion algorithms are installed."): BinaryWriter(tmpdir, {"i": "int"}, compression="something_else") - binary_writer = BinaryWriter(tmpdir, chunk_size=90) + binary_writer = BinaryWriter(tmpdir, chunk_bytes=90) for i in range(100): binary_writer[i] = {"i": i, "i+1": i + 1, "i+2": i + 2} @@ -52,6 +52,35 @@ def test_binary_writer_with_ints(tmpdir): assert data == {"i": i, "i+1": i + 1, "i+2": i + 2} +def test_binary_writer_with_ints_and_chunk_size(tmpdir): + with pytest.raises(FileNotFoundError, match="The provided cache directory `dontexists` doesn't exist."): + BinaryWriter("dontexists", {}) + + with pytest.raises(ValueError, match="No compresion algorithms are installed."): + BinaryWriter(tmpdir, {"i": "int"}, compression="something_else") + + binary_writer = BinaryWriter(tmpdir, chunk_size=25) + + for i in range(100): + binary_writer[i] = {"i": i, "i+1": i + 1, "i+2": i + 2} + + assert len(os.listdir(tmpdir)) == 3 + binary_writer.done() + assert len(os.listdir(tmpdir)) == 5 + + with open(os.path.join(tmpdir, "0.index.json")) as f: + data = json.load(f) + + assert data["chunks"][0]["samples"] == 25 + assert data["chunks"][1]["samples"] == 25 + assert data["chunks"][-1]["samples"] == 25 + + reader = BinaryReader(tmpdir) + for i in range(100): + data = reader.read(i) + assert data == {"i": i, "i+1": i + 1, "i+2": i + 2} + + @pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") def test_binary_writer_with_jpeg_and_int(tmpdir): """Validate the writer and reader can serialize / deserialize a pair of image and label.""" @@ -59,7 +88,7 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): cache_dir = os.path.join(tmpdir, "chunks") os.makedirs(cache_dir, exist_ok=True) - binary_writer = BinaryWriter(cache_dir, chunk_size=2 << 12) + binary_writer = BinaryWriter(cache_dir, chunk_bytes=2 << 12) imgs = [] From ff7547cb86d3f256006efe279c84eeb7600adc8a Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 3 Oct 2023 12:45:33 +0100 Subject: [PATCH 51/84] update --- src/lightning/data/cache/dataloader.py | 48 +++- src/lightning/data/cache/serializers.py | 32 ++- src/lightning/data/cache/worker.py | 331 ---------------------- src/lightning/data/cache/writer.py | 5 +- tests/tests_data/cache/test_serializer.py | 36 ++- 5 files changed, 99 insertions(+), 
353 deletions(-) delete mode 100644 src/lightning/data/cache/worker.py diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index bee6ae45d50fc..720f92ced9166 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -31,6 +31,8 @@ logger = logging.Logger(__name__) +_DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B + def _equal_items(data_1: Any, data_2: Any) -> bool: data_1_flattened, _ = tree_flatten(data_1) @@ -53,10 +55,15 @@ def _equal_item(d1, d2) -> bool: class CacheDataset(Dataset): def __init__( - self, dataset: Dataset, cache_dir: Optional[str], chunk_size: Optional[int], compression: Optional[str] + self, + dataset: Dataset, + cache_dir: Optional[str], + chunk_bytes: Optional[int], + chunk_size: int, + compression: Optional[str], ): self._datataset = dataset - self._cache = Cache(cache_dir, chunk_size=chunk_size, compression=compression) + self._cache = Cache(cache_dir, chunk_bytes=chunk_bytes, chunk_size=chunk_size, compression=compression) self._is_deterministic = False def __len__(self) -> int: @@ -100,14 +107,39 @@ def _next_data(self): raise StopIteration() +class WorkerLoop: + def __call__(self, dataset_kind, *args, **kwargs): + from torch.utils.data import _DatasetKind + from torch.utils.data._utils import worker + + from lightning.data.cache.cache import Cache + + create_fetcher = _DatasetKind.create_fetcher + + fetcher = None + + def create_fetcher_fn(*args, **kwargs): + nonlocal fetcher + fetcher = create_fetcher(*args, **kwargs) + return fetcher + + _DatasetKind.create_fetcher = create_fetcher_fn + + worker._worker_loop(dataset_kind, *args, **kwargs) + + if dataset_kind == _DatasetKind.Map: + for v in fetcher.dataset.__dict__.values(): + if isinstance(v, Cache): + v.done() + + class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter): def __init__(self, loader): - # Patch PyTorch worker loop + # Patch PyTorch worker loop to call the `cache.done()` method. from torch.utils.data._utils import worker - from lightning.data.cache.worker import _worker_loop - - worker._worker_loop = _worker_loop + worker._original_worker_loop = worker._worker_loop + worker._worker_loop = WorkerLoop() super().__init__(loader) @@ -126,7 +158,7 @@ def __init__( batch_size=1, drop_last=False, cache_dir: Optional[str] = None, - chunk_size: Optional[int] = 2 << 26, + chunk_bytes: Optional[int] = _DEFAULT_CHUNK_BYTES, compression: Optional[str] = None, **kwargs, ): @@ -154,7 +186,7 @@ def __init__( if cache_dir is None: raise ValueError("You can provide a `cache_dir` filepath to the LightningDataLoader.") - dataset = CacheDataset(dataset, cache_dir, chunk_size, compression) + dataset = CacheDataset(dataset, cache_dir, chunk_bytes, batch_size if chunk_bytes else None, compression) cache = dataset.cache else: cache = cache[0] diff --git a/src/lightning/data/cache/serializers.py b/src/lightning/data/cache/serializers.py index 7777f8b2ba198..47b3c87981c98 100644 --- a/src/lightning/data/cache/serializers.py +++ b/src/lightning/data/cache/serializers.py @@ -11,7 +11,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
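The serializer registry below becomes an `OrderedDict` and gains a pickle fallback. Assuming the writer picks the first registered serializer whose `can_serialize` accepts an item (that selection rule is an assumption of this sketch, and the class names below are illustrative, not the library's), the ordering matters: specific serializers must be registered before the catch-all pickle entry.

from collections import OrderedDict
import pickle


class IntSketchSerializer:
    def can_serialize(self, item):
        return isinstance(item, int)

    def serialize(self, item):
        return str(item).encode("utf-8")


class PickleSketchSerializer:
    def can_serialize(self, item):
        return True  # catch-all fallback, so it must stay last

    def serialize(self, item):
        return pickle.dumps(item)


SERIALIZERS = OrderedDict(int=IntSketchSerializer(), pickle=PickleSketchSerializer())


def serialize(item):
    # Pick the first serializer that accepts the item.
    for name, serializer in SERIALIZERS.items():
        if serializer.can_serialize(item):
            return name, serializer.serialize(item)
    raise ValueError(f"No serializer found for {item!r}")


assert serialize(5)[0] == "int"
assert serialize({"a": 1})[0] == "pickle"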
+import pickle from abc import ABC, abstractmethod +from collections import OrderedDict from io import BytesIO import numpy as np @@ -172,10 +174,26 @@ def can_serialize(self, item: torch.Tensor) -> bool: return isinstance(item, torch.Tensor) -_SERIALIZERS = { - "pil": PILSerializer(), - "int": IntSerializer(), - "jpeg": JPEGSerializer(), - "bytes": BytesSerializer(), - "tensor": TensorSerializer(), -} +class PickleSerializer(Serializer): + """The PickleSerializer serialize and deserialize python objects to and from bytes.""" + + def serialize(self, item: any) -> bytes: + return pickle.dumps(item) + + def deserialize(self, data: bytes) -> any: + return pickle.loads(data) + + def can_serialize(self, item: any) -> bool: + return isinstance(item, any) + + +_SERIALIZERS = OrderedDict( + **{ + "pil": PILSerializer(), + "int": IntSerializer(), + "jpeg": JPEGSerializer(), + "bytes": BytesSerializer(), + "tensor": TensorSerializer(), + "pickle": PickleSerializer(), + } +) diff --git a/src/lightning/data/cache/worker.py b/src/lightning/data/cache/worker.py deleted file mode 100644 index 74534dd09b85b..0000000000000 --- a/src/lightning/data/cache/worker.py +++ /dev/null @@ -1,331 +0,0 @@ -# Copy pasted from https://github.com/pytorch/pytorch/blob/main/torch/utils/data/_utils/worker.py + lines 318-328 -# TODO: Delete me when this is addressed https://github.com/pytorch/pytorch/issues/110156 - -import os -import queue -import random -from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union - -import torch -from torch._utils import ExceptionWrapper -from torch.utils.data._utils import HAS_NUMPY, IS_WINDOWS, MP_STATUS_CHECK_INTERVAL, signal_handling - -if TYPE_CHECKING: - from torch.utils.data import Dataset - -if IS_WINDOWS: - import ctypes - from ctypes.wintypes import BOOL, DWORD, HANDLE - - # On Windows, the parent ID of the worker process remains unchanged when the manager process - # is gone, and the only way to check it through OS is to let the worker have a process handle - # of the manager and ask if the process status has changed. 
- class ManagerWatchdog: - def __init__(self): - self.manager_pid = os.getppid() - - # mypy cannot detect this code is windows only - self.kernel32 = ctypes.WinDLL("kernel32", use_last_error=True) # type: ignore[attr-defined] - self.kernel32.OpenProcess.argtypes = (DWORD, BOOL, DWORD) - self.kernel32.OpenProcess.restype = HANDLE - self.kernel32.WaitForSingleObject.argtypes = (HANDLE, DWORD) - self.kernel32.WaitForSingleObject.restype = DWORD - - # Value obtained from https://msdn.microsoft.com/en-us/library/ms684880.aspx - SYNCHRONIZE = 0x00100000 - self.manager_handle = self.kernel32.OpenProcess(SYNCHRONIZE, 0, self.manager_pid) - - if not self.manager_handle: - raise ctypes.WinError(ctypes.get_last_error()) # type: ignore[attr-defined] - - self.manager_dead = False - - def is_alive(self): - if not self.manager_dead: - # Value obtained from https://msdn.microsoft.com/en-us/library/windows/desktop/ms687032.aspx - self.manager_dead = self.kernel32.WaitForSingleObject(self.manager_handle, 0) == 0 - return not self.manager_dead - -else: - - class ManagerWatchdog: # type: ignore[no-redef] - def __init__(self): - self.manager_pid = os.getppid() - self.manager_dead = False - - def is_alive(self): - if not self.manager_dead: - self.manager_dead = os.getppid() != self.manager_pid - return not self.manager_dead - - -_worker_info = None - - -class WorkerInfo: - id: int - num_workers: int - seed: int - dataset: "Dataset" - __initialized = False - - def __init__(self, **kwargs): - for k, v in kwargs.items(): - setattr(self, k, v) - self.__keys = tuple(kwargs.keys()) - self.__initialized = True - - def __setattr__(self, key, val): - if self.__initialized: - raise RuntimeError(f"Cannot assign attributes to {self.__class__.__name__} objects") - return super().__setattr__(key, val) - - def __repr__(self): - items = [] - for k in self.__keys: - items.append(f"{k}={getattr(self, k)}") - return "{}({})".format(self.__class__.__name__, ", ".join(items)) - - -def get_worker_info() -> Optional[WorkerInfo]: - r"""Returns the information about the current - :class:`~torch.utils.data.DataLoader` iterator worker process. - - When called in a worker, this returns an object guaranteed to have the - following attributes: - - * :attr:`id`: the current worker id. - * :attr:`num_workers`: the total number of workers. - * :attr:`seed`: the random seed set for the current worker. This value is - determined by main process RNG and the worker id. See - :class:`~torch.utils.data.DataLoader`'s documentation for more details. - * :attr:`dataset`: the copy of the dataset object in **this** process. Note - that this will be a different object in a different process than the one - in the main process. - - When called in the main process, this returns ``None``. - - .. note:: - When used in a :attr:`worker_init_fn` passed over to - :class:`~torch.utils.data.DataLoader`, this method can be useful to - set up each worker process differently, for instance, using ``worker_id`` - to configure the ``dataset`` object to only read a specific fraction of a - sharded dataset, or use ``seed`` to seed other libraries used in dataset - code. - """ - return _worker_info - - -@dataclass(frozen=True) -class _IterableDatasetStopIteration: - worker_id: int - - -@dataclass(frozen=True) -class _ResumeIteration: - seed: Optional[int] = None - - -# This function generates an array of int32 as the seed for -# `numpy.random`, in order to prevent state collision due to same -# seed and algorithm for `numpy.random` and `random` modules. 
-# TODO: Implement `SeedSequence` like object for `torch.random` -def _generate_state(base_seed, worker_id): - INIT_A = 0x43B0D7E5 - MULT_A = 0x931E8875 - INIT_B = 0x8B51F9DD - MULT_B = 0x58F38DED - MIX_MULT_L = 0xCA01F9DD - MIX_MULT_R = 0x4973F715 - XSHIFT = 4 * 8 // 2 - MASK32 = 0xFFFFFFFF - - entropy = [worker_id, base_seed & MASK32, base_seed >> 32, 0] - pool = [0] * 4 - - hash_const_A = INIT_A - - def hash(value): - nonlocal hash_const_A - value = (value ^ hash_const_A) & MASK32 - hash_const_A = (hash_const_A * MULT_A) & MASK32 - value = (value * hash_const_A) & MASK32 - value = (value ^ (value >> XSHIFT)) & MASK32 - return value - - def mix(x, y): - result_x = (MIX_MULT_L * x) & MASK32 - result_y = (MIX_MULT_R * y) & MASK32 - result = (result_x - result_y) & MASK32 - result = (result ^ (result >> XSHIFT)) & MASK32 - return result - - # Add in the entropy to the pool. - for i in range(len(pool)): - pool[i] = hash(entropy[i]) - - # Mix all bits together so late bits can affect earlier bits. - for i_src in range(len(pool)): - for i_dst in range(len(pool)): - if i_src != i_dst: - pool[i_dst] = mix(pool[i_dst], hash(pool[i_src])) - - hash_const_B = INIT_B - state = [] - for i_dst in range(4): - data_val = pool[i_dst] - data_val = (data_val ^ hash_const_B) & MASK32 - hash_const_B = (hash_const_B * MULT_B) & MASK32 - data_val = (data_val * hash_const_B) & MASK32 - data_val = (data_val ^ (data_val >> XSHIFT)) & MASK32 - state.append(data_val) - return state - - -def _worker_loop( - dataset_kind, - dataset, - index_queue, - data_queue, - done_event, - auto_collation, - collate_fn, - drop_last, - base_seed, - init_fn, - worker_id, - num_workers, - persistent_workers, - shared_seed, -): - # See NOTE [ Data Loader Multiprocessing Shutdown Logic ] for details on the - # logic of this function. - - try: - # Initialize C side signal handlers for SIGBUS and SIGSEGV. Python signal - # module's handlers are executed after Python returns from C low-level - # handlers, likely when the same fatal signal had already happened - # again. - # https://docs.python.org/3/library/signal.html#execution-of-python-signal-handlers - signal_handling._set_worker_signal_handlers() - - torch.set_num_threads(1) - seed = base_seed + worker_id - random.seed(seed) - torch.manual_seed(seed) - if HAS_NUMPY: - np_seed = _generate_state(base_seed, worker_id) - import numpy as np - - np.random.seed(np_seed) - - from torch.utils.data import IterDataPipe - from torch.utils.data.graph_settings import apply_random_seed - - shared_rng = torch.Generator() - if isinstance(dataset, IterDataPipe): - assert shared_seed is not None - shared_rng.manual_seed(shared_seed) - dataset = apply_random_seed(dataset, shared_rng) - - global _worker_info - _worker_info = WorkerInfo(id=worker_id, num_workers=num_workers, seed=seed, dataset=dataset) - - from torch.utils.data import _DatasetKind - - init_exception = None - - try: - if init_fn is not None: - init_fn(worker_id) - - fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last) - except Exception: - init_exception = ExceptionWrapper(where=f"in DataLoader worker process {worker_id}") - - # When using Iterable mode, some worker can exit earlier than others due - # to the IterableDataset behaving differently for different workers. 
- # When such things happen, an `_IterableDatasetStopIteration` object is - # sent over to the main process with the ID of this worker, so that the - # main process won't send more tasks to this worker, and will send - # `None` to this worker to properly exit it. - # - # Note that we cannot set `done_event` from a worker as it is shared - # among all processes. Instead, we set the `iteration_end` flag to - # signify that the iterator is exhausted. When either `done_event` or - # `iteration_end` is set, we skip all processing step and just wait for - # `None`. - iteration_end = False - - watchdog = ManagerWatchdog() - - while watchdog.is_alive(): - try: - r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) - except queue.Empty: - continue - if isinstance(r, _ResumeIteration): - # Acknowledge the main process - data_queue.put((r, None)) - iteration_end = False - - if isinstance(dataset, IterDataPipe): - assert r.seed is not None - shared_rng.manual_seed(r.seed) - dataset = apply_random_seed(dataset, shared_rng) - - # Recreate the fetcher for worker-reuse policy - fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last) - continue - - if r is None: - # Received the final signal - assert done_event.is_set() or iteration_end - break - - if done_event.is_set() or iteration_end: - # `done_event` is set. But I haven't received the final signal - # (None) yet. I will keep continuing until get it, and skip the - # processing steps. - continue - idx, index = r - data: Union[_IterableDatasetStopIteration, ExceptionWrapper] - if init_exception is not None: - data = init_exception - init_exception = None - else: - try: - data = fetcher.fetch(index) - except Exception as e: - if isinstance(e, StopIteration) and dataset_kind == _DatasetKind.Iterable: - data = _IterableDatasetStopIteration(worker_id) - # Set `iteration_end` - # (1) to save future `next(...)` calls, and - # (2) to avoid sending multiple `_IterableDatasetStopIteration`s. - iteration_end = True - else: - # It is important that we don't store exc_info in a variable. - # `ExceptionWrapper` does the correct thing. - # See NOTE [ Python Traceback Reference Cycle Problem ] - data = ExceptionWrapper(where=f"in DataLoader worker process {worker_id}") - data_queue.put((idx, data)) - del data, idx, index, r # save memory - except KeyboardInterrupt: - # Main process will raise KeyboardInterrupt anyways. 
- pass - if done_event.is_set(): - ####### ADDTIONATIONAL CODE ####### - - from lightning.data.cache import Cache - - # required to ensure the cache is persisted - if dataset_kind == _DatasetKind.Map: - for v in fetcher.dataset.__dict__.values(): - if isinstance(v, Cache): - v.done() - - ####### ADDTIONATIONAL CODE ####### - - data_queue.cancel_join_thread() - data_queue.close() diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 349373be56f40..be9a57a084d6e 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -20,7 +20,6 @@ from lightning.data.cache.compression import _COMPRESSORS from lightning.data.cache.pytree import tree_flatten, treespec_dumps from lightning.data.cache.serializers import _SERIALIZERS, Serializer -from lightning.data.cache.worker import get_worker_info from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv @@ -83,7 +82,7 @@ def filled(self) -> bool: return True files = os.listdir(self._cache_dir) index_files = [f for f in files if f.endswith("index.json")] - worker_end = _WorkerEnv.detect(get_worker_info_fn=get_worker_info) + worker_end = _WorkerEnv.detect() self._is_done = len(index_files) == self._env.world_size * worker_end.world_size return self._is_done @@ -91,7 +90,7 @@ def filled(self) -> bool: def rank(self): """Returns the rank of the writer.""" if self._rank is None: - self._worker_env = _WorkerEnv.detect(get_worker_info_fn=get_worker_info) + self._worker_env = _WorkerEnv.detect() self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank return self._rank diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py index 14ccc02fc9e10..1a690576b1ddf 100644 --- a/tests/tests_data/cache/test_serializer.py +++ b/tests/tests_data/cache/test_serializer.py @@ -12,15 +12,18 @@ # limitations under the License. 
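The writer's `rank` property above combines the distributed rank with the DataLoader worker rank so that every writer produces a distinct `{rank}.index.json`. A small sketch of that arithmetic:

def writer_rank(global_rank: int, num_dataloader_workers: int, worker_id: int) -> int:
    # Each distributed rank owns a contiguous block of DataLoader worker ranks.
    return global_rank * num_dataloader_workers + worker_id


# With 2 distributed ranks and 4 workers each, writers are numbered 0..7:
# rank 0 -> workers 0, 1, 2, 3 and rank 1 -> workers 4, 5, 6, 7.
assert writer_rank(1, 4, 2) == 6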
import os +from time import time import numpy as np import pytest import torch +from lightning import seed_everything from lightning.data.cache.serializers import ( _SERIALIZERS, _TORCH_DTYPES_MAPPING, IntSerializer, JPEGSerializer, + PickleSerializer, PILSerializer, TensorSerializer, ) @@ -30,7 +33,7 @@ def test_serializers(): - assert sorted(_SERIALIZERS) == ["bytes", "int", "jpeg", "pil", "tensor"] + assert list(_SERIALIZERS.keys()) == ["pil", "int", "jpeg", "bytes", "tensor", "pickle"] def test_int_serializer(): @@ -102,8 +105,13 @@ def test_pil_serializer(mode): def test_tensor_serializer(): - serializer = TensorSerializer() + seed_everything(42) + + serializer_tensor = TensorSerializer() + serializer_pickle = PickleSerializer() + ratio_times = [] + ratio_bytes = [] shapes = [(10,), (10, 10), (10, 10, 10), (10, 10, 10, 5), (10, 10, 10, 5, 4)] for dtype in _TORCH_DTYPES_MAPPING.values(): for shape in shapes: @@ -111,11 +119,31 @@ def test_tensor_serializer(): if dtype in [torch.bfloat16]: continue tensor = torch.ones(shape, dtype=dtype) - data = serializer.serialize(tensor) - deserialized_tensor = serializer.deserialize(data) + + t0 = time() + data = serializer_tensor.serialize(tensor) + deserialized_tensor = serializer_tensor.deserialize(data) + tensor_time = time() - t0 + tensor_bytes = len(data) + + assert deserialized_tensor.dtype == dtype + assert torch.equal(tensor, deserialized_tensor) + + t1 = time() + data = serializer_pickle.serialize(tensor) + deserialized_tensor = serializer_pickle.deserialize(data) + pickle_time = time() - t1 + pickle_bytes = len(data) + assert deserialized_tensor.dtype == dtype assert torch.equal(tensor, deserialized_tensor) + ratio_times.append(pickle_time / tensor_time) + ratio_bytes.append(pickle_bytes / tensor_bytes) + + assert np.mean(ratio_times) > 4 + assert np.mean(ratio_bytes) > 2 + def test_assert_bfloat16_tensor_serializer(): serializer = TensorSerializer() From 6dfd3fd957fb7e2071e1acb74e0f6c3a0a3b5c1d Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 3 Oct 2023 12:58:43 +0100 Subject: [PATCH 52/84] update --- src/lightning/data/cache/dataloader.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index 720f92ced9166..8d5d5173acdc8 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -12,6 +12,7 @@ # limitations under the License. import logging +from importlib import reload from typing import Any, Optional import torch @@ -114,6 +115,8 @@ def __call__(self, dataset_kind, *args, **kwargs): from lightning.data.cache.cache import Cache + reloaded_worker = reload(worker) + create_fetcher = _DatasetKind.create_fetcher fetcher = None @@ -125,7 +128,7 @@ def create_fetcher_fn(*args, **kwargs): _DatasetKind.create_fetcher = create_fetcher_fn - worker._worker_loop(dataset_kind, *args, **kwargs) + reloaded_worker._worker_loop(dataset_kind, *args, **kwargs) if dataset_kind == _DatasetKind.Map: for v in fetcher.dataset.__dict__.values(): @@ -138,7 +141,6 @@ def __init__(self, loader): # Patch PyTorch worker loop to call the `cache.done()` method. 
from torch.utils.data._utils import worker - worker._original_worker_loop = worker._worker_loop worker._worker_loop = WorkerLoop() super().__init__(loader) From 72e469d74dadabfacd5ca8a9e68b7a1c0e097dbd Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 3 Oct 2023 14:39:05 +0100 Subject: [PATCH 53/84] update --- src/lightning/data/cache/cache.py | 8 +++-- src/lightning/data/cache/dataloader.py | 2 ++ src/lightning/data/cache/writer.py | 44 ++++++++++++++++++++++++-- 3 files changed, 48 insertions(+), 6 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 5b6c6d479dce9..19478b708d463 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -59,9 +59,7 @@ def filled(self) -> bool: raise Exception("The Cache wasn't setup properly. HINT: Did you use the LightningDataLoader ?") if self._is_done: return True - files = os.listdir(self._cache_dir) - index_files = [f for f in files if f.endswith("index.json")] - self._is_done = len(index_files) == self._distributed_env.world_size * (self._num_workers or 1) + self._is_done = os.path.exists(os.path.join(self._cache_dir, "index.json")) return self._is_done def __setitem__(self, index, data) -> None: @@ -76,6 +74,10 @@ def done(self) -> None: """Inform the writer the chunking phase is finished.""" self._writer.done() + def merge(self) -> None: + """Inform the writer the chunks indexed can be merged.""" + self._writer.merge() + def __len__(self) -> int: return self._reader.get_length() diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index 8d5d5173acdc8..f6cc411a25724 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -105,6 +105,7 @@ def _next_data(self): for v in self._dataset_fetcher.dataset.__dict__.values(): if isinstance(v, Cache): v.done() + v.merge() raise StopIteration() @@ -134,6 +135,7 @@ def create_fetcher_fn(*args, **kwargs): for v in fetcher.dataset.__dict__.values(): if isinstance(v, Cache): v.done() + v.merge() class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter): diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index be9a57a084d6e..28309f50b1b7f 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -13,6 +13,7 @@ import json import os +from time import sleep from typing import Any, Dict, Optional import numpy as np @@ -55,6 +56,7 @@ def __init__( self._data_format = None self._data_spec = None + self._num_workers = None if self._compression: if len(_COMPRESSORS) == 0: @@ -70,10 +72,10 @@ def __init__( self._serialized_items = [] self._chunks_info = [] self._indexes = [] - self._env = _DistributedEnv.detect() self._worker_env = None self._rank = None self._is_done = False + self._distributed_env = _DistributedEnv.detect() @property def filled(self) -> bool: @@ -83,7 +85,7 @@ def filled(self) -> bool: files = os.listdir(self._cache_dir) index_files = [f for f in files if f.endswith("index.json")] worker_end = _WorkerEnv.detect() - self._is_done = len(index_files) == self._env.world_size * worker_end.world_size + self._is_done = len(index_files) == self._distributed_env.world_size * worker_end.world_size return self._is_done @property @@ -91,7 +93,7 @@ def rank(self): """Returns the rank of the writer.""" if self._rank is None: self._worker_env = _WorkerEnv.detect() - self._rank = self._env.global_rank * self._worker_env.world_size + self._worker_env.rank + self._rank = 
self._distributed_env.global_rank * self._worker_env.world_size + self._worker_env.rank return self._rank def get_config(self) -> Dict[str, Any]: @@ -252,3 +254,39 @@ def done(self): self.write_chunks_index() self.reset() self._is_done = True + + def merge(self): + if self.rank != 0: + while not os.path.exists(os.path.join(self._cache_dir, "index.json")): + sleep(0.001) + return + + num_workers = _WorkerEnv.detect().world_size + + is_done = False + while not is_done: + files = os.listdir(self._cache_dir) + if "index.json" in files: + return + index_files = [f for f in files if f.endswith("index.json") and f != "index.json"] + is_done = len(index_files) == self._distributed_env.world_size * num_workers + + chunks_info = [] + config = None + for index_filename in sorted(index_files): + chunk_path = os.path.join(self._cache_dir, index_filename) + with open(chunk_path) as f: + data = json.load(f) + + if config is None: + config = data["config"] + + elif config != data["config"]: + raise Exception("The config isn't consistent between chunks. This shouldn't have happened.") + + chunks_info.extend(data["chunks"]) + + os.remove(chunk_path) + + with open(os.path.join(self._cache_dir, "index.json"), "w") as f: + json.dump({"chunks": chunks_info, "config": config}, f, sort_keys=True) From d8c6fd7c707adf889f95db5c2b3d11740bf22dae Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 3 Oct 2023 14:40:19 +0100 Subject: [PATCH 54/84] update --- src/lightning/data/cache/cache.py | 3 --- src/lightning/data/cache/dataloader.py | 2 -- 2 files changed, 5 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 19478b708d463..90e3b6940ae45 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -73,9 +73,6 @@ def __getitem__(self, index) -> Dict[str, Any]: def done(self) -> None: """Inform the writer the chunking phase is finished.""" self._writer.done() - - def merge(self) -> None: - """Inform the writer the chunks indexed can be merged.""" self._writer.merge() def __len__(self) -> int: diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index f6cc411a25724..8d5d5173acdc8 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -105,7 +105,6 @@ def _next_data(self): for v in self._dataset_fetcher.dataset.__dict__.values(): if isinstance(v, Cache): v.done() - v.merge() raise StopIteration() @@ -135,7 +134,6 @@ def create_fetcher_fn(*args, **kwargs): for v in fetcher.dataset.__dict__.values(): if isinstance(v, Cache): v.done() - v.merge() class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter): From f3184ff4be750d95a0d582900de8e71d9683d2f9 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Tue, 3 Oct 2023 13:46:10 +0000 Subject: [PATCH 55/84] update --- src/lightning/data/cache/writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index be9a57a084d6e..98584ea5ba302 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -216,7 +216,7 @@ def __setitem__(self, index, items: any): self._current_chunk_size += serialized_items_size if self._indexes: - assert self._indexes[-1] == index - 1 + assert self._indexes[-1] == index - 1, (self._indexes, index -1) self._indexes.append(index) From 564fbab15a22ddfa2dcba3173ca51e6ceb07d8d0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" 
<66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 3 Oct 2023 13:47:50 +0000 Subject: [PATCH 56/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/cache/writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 15da6d54dfa69..1470db1124656 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -218,7 +218,7 @@ def __setitem__(self, index, items: any): self._current_chunk_size += serialized_items_size if self._indexes: - assert self._indexes[-1] == index - 1, (self._indexes, index -1) + assert self._indexes[-1] == index - 1, (self._indexes, index - 1) self._indexes.append(index) From 19198a23a4e9f3f3734670d4c5d5a1a6ea21f9ce Mon Sep 17 00:00:00 2001 From: thomas Date: Tue, 3 Oct 2023 15:39:42 +0100 Subject: [PATCH 57/84] update --- src/lightning/data/cache/sampler.py | 21 +++++++++++++++++++++ src/lightning/data/cache/writer.py | 2 -- tests/tests_data/cache/test_sampler.py | 11 +++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/src/lightning/data/cache/sampler.py b/src/lightning/data/cache/sampler.py index 06df9b34eb3be..fee11e63b63bc 100644 --- a/src/lightning/data/cache/sampler.py +++ b/src/lightning/data/cache/sampler.py @@ -221,6 +221,27 @@ def __init__( self._num_workers = num_workers self._shuffled_chunk_intervals = None + # self._validate() + + def _validate(self): + if self._num_workers > 1 and not self._cache.filled: + batches = {} + for batch_index, batch_indices in enumerate(self): + worker_index = batch_index % self._num_workers + if worker_index not in batches: + batches[worker_index] = [] + batches[worker_index].extend(batch_indices) + elif len(batch_indices) > 0: + if batches[worker_index][-1] != (batch_indices[0] - 1): + breakpoint() + batches[worker_index].extend(batch_indices) + + for indices in batches.values(): + indices = np.asarray(indices) + diff = indices[1:] - (indices[:-1] + 1) + if diff.sum() != 0: + raise RuntimeError("This shouldn't have happened. 
There is a bug in the CacheSampler.") + def __iter_ordered__(self) -> Iterator[List[int]]: # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 iterator = iter(self.sampler) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 28309f50b1b7f..f8730d491d2c0 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -286,7 +286,5 @@ def merge(self): chunks_info.extend(data["chunks"]) - os.remove(chunk_path) - with open(os.path.join(self._cache_dir, "index.json"), "w") as f: json.dump({"chunks": chunks_info, "config": config}, f, sort_keys=True) diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py index b99a914f0083e..cec4418105db2 100644 --- a/tests/tests_data/cache/test_sampler.py +++ b/tests/tests_data/cache/test_sampler.py @@ -211,3 +211,14 @@ def validate_batch(data): validate_batch(batches_2) if params[1] == 1: assert batches_1 != batches_2 + + +def test_batch_sampler_imagenet(): + dataset_size = 1281167 + world_size = 1 + rank = 0 + num_workers = 32 + batch_size = 8 + cache = mock.MagicMock() + cache.filled = False + CacheBatchSampler(dataset_size, world_size, rank, num_workers, batch_size, False, True, cache) From 39a08460d6a299fee6127c1aee06897f26a8b8e1 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Fri, 6 Oct 2023 13:18:14 +0100 Subject: [PATCH 58/84] New cache (#18706) Co-authored-by: thomas Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/lightning/data/cache/cache.py | 19 +- src/lightning/data/cache/config.py | 112 +++++++ src/lightning/data/cache/constants.py | 14 + src/lightning/data/cache/dataloader.py | 93 +++++- src/lightning/data/cache/downloader.py | 73 +++++ src/lightning/data/cache/reader.py | 199 ++++++------- src/lightning/data/cache/sampler.py | 343 +++++++++------------- src/lightning/data/cache/writer.py | 11 +- tests/tests_data/cache/test_cache.py | 13 +- tests/tests_data/cache/test_sampler.py | 233 ++++----------- tests/tests_data/cache/test_serializer.py | 2 +- tests/tests_data/cache/test_writer.py | 31 +- 12 files changed, 615 insertions(+), 528 deletions(-) create mode 100644 src/lightning/data/cache/config.py create mode 100644 src/lightning/data/cache/constants.py create mode 100644 src/lightning/data/cache/downloader.py diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 90e3b6940ae45..0e9edc4f9d288 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -13,9 +13,10 @@ import logging import os -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from lightning.data.cache.reader import BinaryReader +from lightning.data.cache.sampler import ChunkedIndex from lightning.data.cache.writer import BinaryWriter from lightning.data.datasets.env import _DistributedEnv @@ -26,6 +27,7 @@ class Cache: def __init__( self, cache_dir: str, + remote_dir: Optional[str] = None, compression: Optional[str] = None, chunk_size: Optional[int] = None, chunk_bytes: Optional[int] = None, @@ -35,6 +37,7 @@ def __init__( Arguments: cache_dir: The path to where the chunks will be stored. + remote_dir: The path to a remote folder where the data are located. compression: The name of the algorithm to reduce the size of the chunks. chunk_bytes: The maximum number of bytes within a chunk. chunk_size: The maximum number of items within a chunk. 
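The `_validate` helper added to the sampler above (its call is left commented out) checks that the indices handed to each worker stay contiguous. A standalone paraphrase of that check, assuming batches are assigned to workers round-robin as in the sampler:

import numpy as np


def worker_streams_are_contiguous(batches, num_workers):
    # Batches are handed to workers round-robin; gather each worker's index stream.
    streams = {worker_id: [] for worker_id in range(num_workers)}
    for batch_index, batch in enumerate(batches):
        streams[batch_index % num_workers].extend(batch)
    # Within a worker, the concatenated indices must increase by exactly 1.
    for indices in streams.values():
        indices = np.asarray(indices)
        if len(indices) > 1 and np.any(indices[1:] != indices[:-1] + 1):
            return False
    return True


assert worker_streams_are_contiguous([[0, 1], [4, 5], [2, 3], [6, 7]], num_workers=2)
assert not worker_streams_are_contiguous([[0, 1], [4, 5], [3, 2]], num_workers=2)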
@@ -42,7 +45,7 @@ def __init__( """ super().__init__() self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression) - self._reader = BinaryReader(cache_dir, compression=compression) + self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression) self._cache_dir = cache_dir self._is_done = False self._distributed_env = _DistributedEnv.detect() @@ -66,17 +69,25 @@ def __setitem__(self, index, data) -> None: """Store an item in the writer.""" self._writer[index] = data - def __getitem__(self, index) -> Dict[str, Any]: + def __getitem__(self, index: Union[int, ChunkedIndex]) -> Dict[str, Any]: """Read an item in the reader.""" + if isinstance(index, int): + index = ChunkedIndex(index, self._get_chunk_index_from_index(index)) return self._reader.read(index) def done(self) -> None: """Inform the writer the chunking phase is finished.""" self._writer.done() - self._writer.merge() + + def merge(self, num_workers: int = 1) -> None: + """Inform the writer the chunking phase is finished.""" + self._writer.merge(num_workers) def __len__(self) -> int: return self._reader.get_length() def get_chunk_interval(self): return self._reader.get_chunk_interval() + + def _get_chunk_index_from_index(self, index: int) -> int: + return self._reader._get_chunk_index_from_index(index) diff --git a/src/lightning/data/cache/config.py b/src/lightning/data/cache/config.py new file mode 100644 index 0000000000000..736eb1073972d --- /dev/null +++ b/src/lightning/data/cache/config.py @@ -0,0 +1,112 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from typing import Optional, Tuple + +from lightning.data.cache.constants import INDEX_FILENAME +from lightning.data.cache.downloader import get_downloader_cls +from lightning.data.cache.pytree import treespec_loads +from lightning.data.cache.sampler import ChunkedIndex + + +class ChunksConfig: + def __init__(self, cache_dir: str, remote_dir: Optional[str]): + """The ChunksConfig reads the index files associated a chunked dataset and enables to map an index to its + chunk. + + Arguments: + cache_dir: The path to cache folder. + remote_dir: The remote folder where the data are stored. + + """ + self._cache_dir = cache_dir + self._intervals = [] + self._config = None + self._chunks = [] + self._remote_dir = remote_dir + + with open(os.path.join(self._cache_dir, INDEX_FILENAME)) as f: + data = json.load(f) + + self._config = data["config"] + + self._chunks.extend(data["chunks"]) + + self._config["data_spec"] = treespec_loads(self._config["data_spec"]) + + for chunk in self._chunks: + start, end = chunk["interval"] + if (end - start) != chunk["chunk_size"]: + raise Exception( + "The config intervals doesn't match the number of samples. This shouldn't have happened." 
+ ) + self._intervals.append(chunk["interval"]) + + self._length = sum([chunk["chunk_size"] for chunk in self._chunks]) + + self._downloader = None + if remote_dir: + self._downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, self._chunks) + + def download_chunk_from_index(self, chunk_index: int) -> None: + chunk_filename = self._chunks[chunk_index]["filename"] + + local_chunkpath = os.path.join(self._cache_dir, chunk_filename) + + if os.path.exists(local_chunkpath): + return None + + return self._downloader.download_chunk_from_index(chunk_index) + + @property + def intervals(self): + return self._intervals + + @property + def data_format(self): + return self._config["data_format"] + + @property + def config(self): + return self._config + + def _get_chunk_index_from_index(self, index: int) -> int: + for chunk_index, internal in enumerate(self._intervals): + if internal[0] <= index < internal[1]: + return chunk_index + raise ValueError( + f"The provided index {index} didn't find a match within the chunk intervals {self._intervals}." + ) + + def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]: + """Find the associated chunk metadata.""" + chunk = self._chunks[index.chunk_index] + return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index] + + @classmethod + def load(cls, cache_dir: str, remote_dir: Optional[str] = None) -> Optional["ChunksConfig"]: + cache_index_filepath = os.path.join(cache_dir, INDEX_FILENAME) + + if isinstance(remote_dir, str): + downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, []) + downloader.download_file(os.path.join(remote_dir, INDEX_FILENAME), cache_index_filepath) + + if not os.path.exists(cache_index_filepath): + return None + + return ChunksConfig(cache_dir, remote_dir) + + def __len__(self) -> int: + return self._length diff --git a/src/lightning/data/cache/constants.py b/src/lightning/data/cache/constants.py new file mode 100644 index 0000000000000..ee7b6f59cb339 --- /dev/null +++ b/src/lightning/data/cache/constants.py @@ -0,0 +1,14 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +INDEX_FILENAME = "index.json" diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index 8d5d5173acdc8..d1b96a7dc3c5a 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -11,11 +11,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
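`ChunksConfig._get_chunk_index_from_index` above maps a flat sample index onto the chunk whose half-open interval contains it. A minimal sketch of the same lookup:

from typing import List, Tuple


def chunk_index_for(index: int, intervals: List[Tuple[int, int]]) -> int:
    for chunk_index, (start, end) in enumerate(intervals):
        if start <= index < end:
            return chunk_index
    raise ValueError(f"The index {index} is outside the chunk intervals {intervals}.")


# Example: three chunks of 25 items each.
assert chunk_index_for(60, [(0, 25), (25, 50), (50, 75)]) == 2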
+import asyncio +import inspect import logging +import os from importlib import reload -from typing import Any, Optional +from typing import Any, Callable, Optional import torch +from lightning_utilities.core.imports import RequirementCache from torch.utils.data import Dataset, IterableDataset from torch.utils.data._utils.collate import default_collate from torch.utils.data.dataloader import ( @@ -28,7 +32,9 @@ from lightning.data.cache import Cache from lightning.data.cache.pytree import tree_flatten from lightning.data.cache.sampler import CacheBatchSampler -from lightning.data.datasets.env import _DistributedEnv +from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv + +_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer") logger = logging.Logger(__name__) @@ -86,12 +92,19 @@ def __getitem__(self, index): class CacheCollateFn: - def __init__(self): - self.collate_fn = default_collate + def __init__(self, collate_fn: Optional[Callable] = None): + self.collate_fn = collate_fn or default_collate def __call__(self, items): if all(item is None for item in items): return None + + # If the __getitem__ method is asynchornous, collect all the items. + if all(inspect.iscoroutine(item) for item in items): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + items = loop.run_until_complete(asyncio.gather(*items)) + return self.collate_fn(items) @@ -100,25 +113,44 @@ class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter): def _next_data(self): try: - return super()._next_data() + data = None + while data is None: + data = super()._next_data() + return data except StopIteration: for v in self._dataset_fetcher.dataset.__dict__.values(): if isinstance(v, Cache): v.done() + if not v.filled: + v.merge(1) raise StopIteration() class WorkerLoop: + """Wrap the PyTorch DataLoader WorkerLoop to perform caching and profiling.""" + + def __init__(self, global_rank: int, profile: bool = False) -> None: + self._global_rank = global_rank + self._profile = profile + def __call__(self, dataset_kind, *args, **kwargs): from torch.utils.data import _DatasetKind from torch.utils.data._utils import worker from lightning.data.cache.cache import Cache - reloaded_worker = reload(worker) + rank = _WorkerEnv.detect().rank + enable_profiling = self._global_rank == 0 and rank == 0 and _VIZ_TRACKER_AVAILABLE and self._profile - create_fetcher = _DatasetKind.create_fetcher + if enable_profiling: + from viztracer import VizTracer + + tracer = VizTracer(output_file=os.path.join(os.getcwd(), "trace.json")) + tracer.start() + # Reload to remove the patching + reloaded_worker = reload(worker) + create_fetcher = _DatasetKind.create_fetcher fetcher = None def create_fetcher_fn(*args, **kwargs): @@ -135,15 +167,35 @@ def create_fetcher_fn(*args, **kwargs): if isinstance(v, Cache): v.done() + if enable_profiling: + tracer.stop() + tracer.save() + class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter): def __init__(self, loader): + self._cache = loader._cache + self._num_workers = loader.num_workers # Patch PyTorch worker loop to call the `cache.done()` method. 
from torch.utils.data._utils import worker - worker._worker_loop = WorkerLoop() + worker._worker_loop = WorkerLoop(loader._global_rank, loader._profile) super().__init__(loader) + def _shutdown_workers(self): + super()._shutdown_workers() + if not self._cache.filled: + self._cache.merge(self._num_workers) + + def _next_data(self): + try: + data = None + while data is None: + data = super()._next_data() + return data + except StopIteration as e: + raise e + class LightningDataLoader(DataLoader): __doc__ = DataLoader.__doc__ @@ -157,11 +209,13 @@ def __init__( num_workers=0, shuffle: bool = False, generator=None, - batch_size=1, + batch_size=None, drop_last=False, cache_dir: Optional[str] = None, chunk_bytes: Optional[int] = _DEFAULT_CHUNK_BYTES, compression: Optional[str] = None, + profile: bool = False, + collate_fn: Optional[Callable] = None, **kwargs, ): if sampler: @@ -177,6 +231,9 @@ def __init__( if isinstance(dataset, IterableDataset): raise ValueError("Only map-based dataset are supported by the LightningDataLoader for now.") + if profile and not _VIZ_TRACKER_AVAILABLE: + raise ModuleNotFoundError("To enable DataLoader profiling, run `pip install viztracer`.") + cache = [v for v in dataset.__dict__.values() if isinstance(v, Cache)] if len(cache) > 1: @@ -186,10 +243,10 @@ def __init__( if len(cache) == 0: if cache_dir is None: - raise ValueError("You can provide a `cache_dir` filepath to the LightningDataLoader.") + raise ValueError("You should provide a `cache_dir` filepath to the LightningDataLoader.") - dataset = CacheDataset(dataset, cache_dir, chunk_bytes, batch_size if chunk_bytes else None, compression) - cache = dataset.cache + dataset = CacheDataset(dataset, cache_dir, chunk_bytes, batch_size, compression) + cache = dataset._cache else: cache = cache[0] @@ -198,25 +255,31 @@ def __init__( if not cache.filled and shuffle: logger.info("Shuffle is ignored during the caching phase phase") + self._cache = cache + distributed_env = _DistributedEnv.detect() + self._global_rank = distributed_env.global_rank + batch_sampler = CacheBatchSampler( len(dataset), distributed_env.world_size, - distributed_env.global_rank, + self._global_rank, num_workers, - batch_size, + batch_size or 1, drop_last, shuffle, cache, ) + self._profile = profile + super().__init__( dataset, *args, sampler=None, batch_sampler=batch_sampler, generator=generator, - collate_fn=CacheCollateFn(), + collate_fn=CacheCollateFn(collate_fn), num_workers=num_workers, **kwargs, ) diff --git a/src/lightning/data/cache/downloader.py b/src/lightning/data/cache/downloader.py new file mode 100644 index 0000000000000..21e2abfbc8f7f --- /dev/null +++ b/src/lightning/data/cache/downloader.py @@ -0,0 +1,73 @@ +# Copyright The Lightning AI team. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
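A note on the CacheCollateFn change earlier in this dataloader patch: when every item in a batch is a coroutine, the collate function gathers them on a fresh event loop before collating, so a dataset may implement `__getitem__` as `async def` (for example to overlap remote reads). The sketch below reproduces only that gathering step under the assumption that a plain list passthrough stands in for `default_collate`; the helper names are illustrative and not part of the patch.

import asyncio
import inspect
from typing import Any, Callable, List, Optional


def collate_maybe_async(items: List[Any], collate_fn: Optional[Callable] = None) -> Any:
    # Passthrough collate keeps the sketch dependency-free; the loader uses default_collate.
    collate_fn = collate_fn or (lambda batch: batch)
    if all(item is None for item in items):
        return None
    if all(inspect.iscoroutine(item) for item in items):
        # Run the pending __getitem__ coroutines to completion, preserving order.
        loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(loop)
            items = loop.run_until_complete(asyncio.gather(*items))
        finally:
            loop.close()
    return collate_fn(items)


async def _fake_getitem(index: int) -> dict:
    # Hypothetical async __getitem__ body, e.g. awaiting a download.
    await asyncio.sleep(0)
    return {"index": index}


if __name__ == "__main__":
    print(collate_maybe_async([_fake_getitem(i) for i in range(4)]))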
+import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Type +from urllib import parse + + +class Downloader(ABC): + def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]): + self._remote_dir = remote_dir + self._cache_dir = cache_dir + self._chunks = chunks + + def download_chunk_from_index(self, chunk_index: int) -> None: + chunk_filename = self._chunks[chunk_index]["filename"] + local_chunkpath = os.path.join(self._cache_dir, chunk_filename) + remote_chunkpath = os.path.join(self._remote_dir, chunk_filename) + return self.download_file(remote_chunkpath, local_chunkpath) + + @abstractmethod + def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: + pass + + +class S3Downloader(Downloader): + @classmethod + def downldownload_fileoad_file_from_s3(cls, remote_filepath: str, local_filepath: str): + import boto3 + from boto3.s3.transfer import TransferConfig + from botocore.config import Config + + obj = parse.urlparse(remote_filepath) + + if obj.scheme != "s3": + raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}") + + extra_args = {} + + # Create a new session per thread + session = boto3.session.Session() + # Create a resource client using a thread's session object + s3 = session.client("s3", config=Config(read_timeout=None)) + # Threads calling S3 operations return RuntimeError (cannot schedule new futures after + # interpreter shutdown). Temporary solution is to have `use_threads` as `False`. + # Issue: https://github.com/boto/boto3/issues/3113 + s3.download_file( + obj.netloc, + obj.path.lstrip("/"), + local_filepath, + ExtraArgs=extra_args, + Config=TransferConfig(use_threads=False), + ) + + +_DOWNLOADERS = {"s3://": S3Downloader} + + +def get_downloader_cls(remote_dir: str) -> Type[Downloader]: + for k, cls in _DOWNLOADERS.items(): + if remote_dir.startswith(k): + return cls + raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.") diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 817b2e515d0f3..cdfc5f9c03f97 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -11,101 +11,62 @@ # See the License for the specific language governing permissions and # limitations under the License. 
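The `get_downloader_cls` helper above resolves a `Downloader` subclass from the prefix of the remote directory, with only `s3://` registered so far. The sketch below shows how such a prefix registry can be extended, here with a hypothetical `file://` downloader that simply copies from a local path; the simplified constructor (no cache directory or chunk bookkeeping) is an assumption made to keep the example self-contained and is not the patch's API.

import shutil
from abc import ABC, abstractmethod
from typing import Dict, Type


class SimpleDownloader(ABC):
    @abstractmethod
    def download_file(self, remote_path: str, local_path: str) -> None:
        """Copy a single remote file into the local cache directory."""


class LocalFileDownloader(SimpleDownloader):
    # Hypothetical downloader for file:// remotes, handy for local tests.
    def download_file(self, remote_path: str, local_path: str) -> None:
        shutil.copyfile(remote_path.replace("file://", "", 1), local_path)


_SIMPLE_DOWNLOADERS: Dict[str, Type[SimpleDownloader]] = {"file://": LocalFileDownloader}


def get_simple_downloader_cls(remote_dir: str) -> Type[SimpleDownloader]:
    # Return the first registered class whose prefix matches the remote directory.
    for prefix, cls in _SIMPLE_DOWNLOADERS.items():
        if remote_dir.startswith(prefix):
            return cls
    raise ValueError(f"No downloader is registered for {remote_dir}.")


if __name__ == "__main__":
    downloader = get_simple_downloader_cls("file:///tmp/remote_dataset")()
    print(type(downloader).__name__)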
-import json import os -from typing import Any, Dict, Optional, Tuple, Union +from contextlib import contextmanager +from threading import Lock, Thread +from time import sleep, time +from typing import Any, Dict, List, Optional import numpy as np -from lightning.data.cache.pytree import tree_unflatten, treespec_loads -from lightning.data.cache.sampler import BatchIndex +from lightning.data.cache.config import ChunksConfig +from lightning.data.cache.pytree import tree_unflatten +from lightning.data.cache.sampler import ChunkedIndex from lightning.data.cache.serializers import _SERIALIZERS, Serializer -from lightning.data.datasets.env import _DistributedEnv +from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv -class ChunksConfig: - def __init__(self, cache_dir: str, index_filenames: str): - self._cache_dir = cache_dir - self.index_filenames = sorted(index_filenames) - self._intervals = [] - self._config = None - self._chunks = [] - - for filename in self.index_filenames: - with open(os.path.join(self._cache_dir, filename)) as f: - data = json.load(f) +class PrepareChunksThread(Thread): + """This thread is responsible to download the chunks associated to a given worker.""" - if self._config is None: - self._config = data["config"] + def __init__(self, config: ChunksConfig): + super().__init__(daemon=True) + self._config = config + self._chunks_index_to_be_processed = [] + self._chunks_index_to_ready = [] + self._lock = Lock() - elif self._config != data["config"]: - raise Exception("The config isn't consistent between chunks. This shouldn't have happened.") + def add(self, chunk_indices: List[int]) -> None: + """Receive the list of the chunk indices to download for the current epoch.""" + with self._lock: + self._chunks_index_to_be_processed.extend(chunk_indices) - self._chunks.extend(data["chunks"]) + def run(self): + while True: + with self._lock: + if len(self._chunks_index_to_be_processed) == 0: + sleep(0.007) + continue + chunk_index = self._chunks_index_to_be_processed.pop(0) - self._config["data_spec"] = treespec_loads(self._config["data_spec"]) - - for chunk in self._chunks: - start, end = chunk["interval"] - if (end - start) != chunk["samples"]: - raise Exception( - "The config intervals doesn't match the number of samples. This shouldn't have happened." - ) - self._intervals.append(chunk["interval"]) - - self._length = sum([chunk["samples"] for chunk in self._chunks]) - - @property - def intervals(self): - return self._intervals - - @property - def data_format(self): - return self._config["data_format"] - - @property - def config(self): - return self._config - - def __getitem__(self, index: Union[int, BatchIndex]) -> Tuple[str, int, int]: - """Find the associated chunk metadata.""" - if isinstance(index, int): - for interval_config, internal in enumerate(self._intervals): - if internal[0] <= index and index < internal[1]: - chunk = self._chunks[interval_config] - mapping = chunk["mapping"][str(index)] - return os.path.join(self._cache_dir, chunk["filename"]), *mapping - # Note: Optimisation to avoid doing the interval search. - elif isinstance(index, BatchIndex): - chunk = self._chunks[index.chunk_index] - mapping = chunk["mapping"][str(index.index)] - return os.path.join(self._cache_dir, chunk["filename"]), *mapping - raise Exception(f"The chunk interval weren't properly defined. 
Found {self._intervals} for index {index}.") - - @classmethod - def load(cls, cache_dir: str) -> Optional["ChunksConfig"]: - files = os.listdir(cache_dir) - index_filenames = sorted([f for f in files if f.endswith("index.json")]) - if not index_filenames: - return None - return ChunksConfig(cache_dir, index_filenames) - - def __len__(self) -> int: - return self._length + # TODO: Implement eviction + self._config.download_chunk_from_index(chunk_index) + self._chunks_index_to_ready.append(chunk_index) class BinaryReader: - def __init__(self, cache_dir: str, compression: Optional[str] = None): + def __init__(self, cache_dir: str, remote_dir: Optional[str] = None, compression: Optional[str] = None): """The BinaryReader enables to read chunked dataset in an efficient way. Arguments: - cache_dir: The path to cache folder + cache_dir: The path to cache folder. + remote_dir: The path to a remote folder where the data are located. compression: The algorithm to decompress the chunks. """ - super().__init__() self._cache_dir = cache_dir + self._remote_dir = remote_dir if not os.path.exists(self._cache_dir): raise FileNotFoundError(f"The provided cache_dir `{self._cache_dir}` doesn't exist.") @@ -117,14 +78,44 @@ def __init__(self, cache_dir: str, compression: Optional[str] = None): self._chunks_data = {} self._serializers: Dict[str, Serializer] = _SERIALIZERS - self._env = _DistributedEnv.detect() + self._distributed_env = _DistributedEnv.detect() + self._rank = None self._config: Optional[ChunksConfig] = None + self._latest_chunk_index = None + self._executor = None + self._prepare_thread = None + + def _get_chunk_index_from_index(self, index: int): + # Load the config containing the index + if self._config is None: + self._try_load_config() + + if self._config is None: + raise Exception("The reader index isn't defined.") + + return self._config._get_chunk_index_from_index(index) def _try_load_config(self): """Try to load the chunks config if the index files are available.""" - self._config = ChunksConfig.load(self._cache_dir) + self._config = ChunksConfig.load(self._cache_dir, self._remote_dir) + return self._config - def read(self, index: Union[int, BatchIndex]): + @property + def rank(self): + """Returns the rank of the writer.""" + if self._rank is None: + self._worker_env = _WorkerEnv.detect() + self._rank = self._distributed_env.global_rank * self._worker_env.world_size + self._worker_env.rank + return self._rank + + @contextmanager + def measure_on_rank_0(self, msg: str): + if self.rank == 0: + t0 = time() + yield + print(msg, time() - t0) + + def read(self, index: ChunkedIndex): """Read an item for the given from a chunk. If the chunk isn't available locally or in memory, it will be downloaded. @@ -132,64 +123,60 @@ def read(self, index: Union[int, BatchIndex]): Prefetching should reduce the wait time to be the batch available. """ - if self._config is None: - self._try_load_config() + if not isinstance(index, ChunkedIndex): + raise ValueError("The Reader.read(...) 
method expects a chunked Index.") - if self._config is None: + # Load the config containing the index + if self._config is None and self._try_load_config() is None: raise Exception("The reader index isn't defined.") - chunk_filepath, begin, end = self._config[index] - raw_item_data = self.load_item_from_chunk(chunk_filepath, begin, end, keep_in_memory=True) + # Create and start the prepare chunks thread + if index.chunk_indexes is not None and self._prepare_thread is None: + self._prepare_thread = PrepareChunksThread(self._config) + self._prepare_thread.start() + self._prepare_thread.add(index.chunk_indexes) + + # Fetch the element + chunk_filepath, begin, _ = self._config[index] + raw_item_data = self.load_item_from_chunk(index.index, chunk_filepath, begin) return self.deserialize(raw_item_data) def deserialize(self, raw_item_data: bytes) -> Any: """Deserialize the raw bytes into their python equivalent.""" - sizes = [] - idx = 0 - data_format = self._config.data_format - for _ in data_format: - (size,) = np.frombuffer(raw_item_data[idx : idx + 4], np.uint32) - sizes.append(size) - idx += 4 + idx = len(self._config.data_format) * 4 + sizes = np.frombuffer(raw_item_data[:idx], np.uint32) data = [] - for size, format in zip(sizes, data_format): - serializer = self._serializers[format] + for size, data_format in zip(sizes, self._config.data_format): + serializer = self._serializers[data_format] data_bytes = raw_item_data[idx : idx + size] data.append(serializer.deserialize(data_bytes)) idx += size return tree_unflatten(data, self._config.config["data_spec"]) - def load_item_from_chunk(self, chunk_filepath: str, begin: int, end: int, keep_in_memory: bool = False): - if chunk_filepath in self._chunks_data: - return self._chunks_data[chunk_filepath][begin:end] + def load_item_from_chunk(self, index: int, chunk_filepath: str, begin: int): + offset = (1 + (index - begin)) * 4 - if keep_in_memory: - with open(chunk_filepath, "rb", 0) as fp: - data = fp.read() - self._chunks_data[chunk_filepath] = data - return data[begin:end] + while not os.path.exists(chunk_filepath): + sleep(0.0001) with open(chunk_filepath, "rb", 0) as fp: + fp.seek(offset) + pair = fp.read(8) + begin, end = np.frombuffer(pair, np.uint32) fp.seek(begin) data = fp.read(end - begin) return data def get_length(self) -> int: """Get the number of samples across all chunks.""" - if self._config is None: - self._try_load_config() - - if self._config is None: + if self._config is None and self._try_load_config() is None: raise Exception("The reader index isn't defined.") return len(self._config) def get_chunk_interval(self): """Get the index interval of each chunk.""" - if self._config is None: - self._try_load_config() - - if self._config is None: + if self._config is None and self._try_load_config() is None: raise Exception("The reader index isn't defined.") return self._config.intervals diff --git a/src/lightning/data/cache/sampler.py b/src/lightning/data/cache/sampler.py index fee11e63b63bc..e321febacec55 100644 --- a/src/lightning/data/cache/sampler.py +++ b/src/lightning/data/cache/sampler.py @@ -13,172 +13,26 @@ import logging from dataclasses import dataclass -from typing import Iterator, List +from typing import Any, List, Optional import numpy as np -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data.sampler import BatchSampler, RandomSampler, Sampler, SequentialSampler, Sized logger = logging.Logger(__name__) -class IteratorSampler(Sampler[int]): - r"""Samples elements sequentially, always in 
the same order. - - Args: - data_source (Dataset): dataset to sample from - - """ - data_source: Sized - - def __init__(self, data_source: Sized) -> None: - self.data_source = data_source - - def __iter__(self) -> Iterator[int]: - return iter(self.data_source) - - def __len__(self) -> int: - return len(self.data_source) - - -class BaseCacheSampler(Sampler): - def __init__(self, dataset_size: int): - super().__init__(None) - self.dataset_size = dataset_size - self.worker_id = 0 - self.index_id = 0 - self.iterators = [] - self._done = set() - - def __len__(self) -> int: - return self.dataset_size - - @property - def done(self) -> bool: - return len(self._done) == len(self.iterators) - - def __iter__(self) -> "BaseCacheSampler": - self._done = set() - - for sampler in self.samplers: - self.iterators.append(iter(sampler)) - - return self - - def _next_worker_id(self): - if self.done: - return - counter = 1 - while True: - next_worker_id = (self.worker_id + counter) % self.num_workers - if next_worker_id not in self._done: - self.worker_id = next_worker_id - break - counter += 1 - - def __next__(self) -> List[int]: - while len(self._done) != self.iterators: - try: - data = next(self.iterators[self.worker_id]) - self.index_id += 1 - if self.index_id == self.batch_size: - self.index_id = 0 - self._next_worker_id() - return data - except StopIteration: - self._done.add(self.worker_id) - self.index_id = 0 - self._next_worker_id() - raise StopIteration - - -class CacheSampler(BaseCacheSampler): - def __init__(self, dataset_size: int, num_workers: int, batch_size: int): - """The CacheSampler splits the dataset indices into ordered chunks and assign each one of them to a DataLoader - worker. The Cache Writer expects the index to be provided in an ordered fashion. - - Arguments: - dataset_size: The size of the dataset. - num_workers: The number of workers provided to the DataLoader - batch_size: The number of items in a batch - - """ - - super().__init__(dataset_size) - self.batch_size = batch_size - self.num_workers = num_workers - self.indices = range(dataset_size) - worker_size = dataset_size // self.num_workers - self.samplers = [] - for worker_idx in range(num_workers): - is_last = worker_idx == num_workers - 1 - worker_indices = self.indices[ - worker_idx * worker_size : dataset_size if is_last else (worker_idx + 1) * worker_size - ] - self.samplers.append(IteratorSampler(worker_indices)) - self.iterators = [] - self._done = set() - assert sum([len(s) for s in self.samplers]) == dataset_size - self.worker_id = 0 - self.index_id = 0 - - -class DistributedCacheSampler(BaseCacheSampler): - def __init__(self, dataset_size: int, num_replicas: int, rank: int, num_workers: int, batch_size: int): - """The DistributedCacheSampler splits the dataset indices into ordered chunks along all the replicas and their - workers. The Cache Writer expects the index to be provided in an ordered fashion. - - Arguments: - dataset_size: The size of the dataset. 
- num_workers: The number of workers provided to the DataLoader - batch_size: The number of items in a batch - - """ - super().__init__(dataset_size) - self.batch_size = batch_size - self.num_workers = num_workers - self.indices = range(dataset_size) - replica_size = dataset_size // num_replicas - worker_size = dataset_size // (num_replicas * self.num_workers) - self.samplers = [] - for replica_idx in range(num_replicas): - if replica_idx != rank: - continue - - is_last_replica = replica_idx == num_replicas - 1 - start_replica = replica_idx * replica_size - end_replica = dataset_size if is_last_replica else (replica_idx + 1) * replica_size - replica_indices = self.indices[start_replica:end_replica] - - replica_size = len(replica_indices) - - for worker_idx in range(num_workers): - is_last_worker = worker_idx == num_workers - 1 - start_worker = worker_idx * worker_size - end_worker = replica_size if is_last_worker else (worker_idx + 1) * worker_size - worker_indices = replica_indices[start_worker:end_worker] - self.samplers.append(IteratorSampler(worker_indices)) - - self.iterators = [] - self._done = set() - - assert sum([len(s) for s in self.samplers]) == replica_size - self.worker_id = 0 - self.index_id = 0 - - @dataclass -class BatchIndex: +class ChunkedIndex: index: int chunk_index: int + chunk_indexes: Optional[List[int]] = None -class CacheBatchSampler(BatchSampler): +class CacheBatchSampler: def __init__( self, dataset_size: int, num_replicas: int, - rank: int, + global_rank: int, num_workers: int, batch_size: int, drop_last: bool, @@ -193,33 +47,22 @@ def __init__( Arguments: dataset_size: The size of the dataset. num_replicas: The number of processes involves in the distributed training. - rank: The rank of the given process + global_rank: The global_rank of the given process num_workers: The number of workers provided to the DataLoader. batch_size: The number of items in a batch. + drop_last: Whether to drop the last batch of data. shuffle: Whether the data should be shuffled. cache: The cache associated to the dataset. 
""" - - if num_replicas == 1: - if not cache.filled and num_workers > 1: - sampler = CacheSampler(dataset_size, num_workers, batch_size) - elif shuffle: - sampler = RandomSampler(range(dataset_size)) - else: - sampler = SequentialSampler(range(dataset_size)) - else: - if not cache.filled: - sampler = DistributedCacheSampler(dataset_size, num_replicas, rank, num_workers, batch_size) - else: - sampler = DistributedSampler(range(dataset_size), num_replicas=num_replicas, rank=rank, shuffle=shuffle) - super().__init__(sampler, batch_size, drop_last) + self._dataset_size = dataset_size self._num_replicas = num_replicas - self._rank = rank + self._global_rank = global_rank self._cache = cache self._shuffle = shuffle - self._num_workers = num_workers + self._num_workers = num_workers or 1 self._shuffled_chunk_intervals = None + self._batch_size = batch_size # self._validate() @@ -261,47 +104,135 @@ def __iter_ordered__(self) -> Iterator[List[int]]: batch = [] def __iter__(self): - if self._cache.filled and self._shuffle: - return self.__iter_from_chunks__() - if self._num_workers > 1 and not self._cache.filled: - return self.__iter_ordered__() - return super().__iter__() + # When the cache is filled, we need to iterate though the chunks + if self._cache.filled: + if self._num_replicas == 1: + return self.__iter_from_chunks_non_distributed__() + return self.__iter_from_chunks_distributed__() - def __iter_from_chunks__(self): + # shuffle is ignored while building the binarized version of the dataset + if self._num_replicas == 1: + return self.__iter_non_distributed__() + return self.__iter_distributed__() + + def __iter_non_distributed__(self): + worker_size = self._dataset_size // self._num_workers + self.samplers = [] + indices = list(range(self._dataset_size)) + worker_indices = [] + for worker_idx in range(self._num_workers): + is_last = worker_idx == self._num_workers - 1 + start = worker_idx * worker_size + end = self._dataset_size if is_last else (worker_idx + 1) * worker_size + worker_indices.append(indices[start:end]) + + assert sum([len(s) for s in worker_indices]) == self._dataset_size + + worker_indices_batches = [self._chunk_list(indices, self._batch_size) for indices in worker_indices] + + yield from self.__iter_indices_per_workers__(worker_indices_batches) + + def __iter_distributed__(self): + self.indices = list(range(self._dataset_size)) + replica_size = self._dataset_size // self._num_replicas + worker_size = self._dataset_size // (self._num_replicas * self._num_workers) + self.samplers = [] + for rank in range(self._num_replicas): + if rank != self._global_rank: + continue + + is_last_replica = rank == self._num_replicas - 1 + start_replica = rank * replica_size + end_replica = self._dataset_size if is_last_replica else (rank + 1) * replica_size + replica_indices = self.indices[start_replica:end_replica] + + replica_size = len(replica_indices) + + worker_indices = [] + for worker_idx in range(self._num_workers): + is_last_worker = worker_idx == self._num_workers - 1 + start_worker = worker_idx * worker_size + end_worker = replica_size if is_last_worker else (worker_idx + 1) * worker_size + worker_indices.append(replica_indices[start_worker:end_worker]) + + assert sum([len(s) for s in worker_indices]) == len(replica_indices) + + worker_indices_batches = [self._chunk_list(indices, self._batch_size) for indices in worker_indices] + + yield from self.__iter_indices_per_workers__(worker_indices_batches) + + def __iter_from_chunks_non_distributed__(self): chunk_intervals = 
self._cache.get_chunk_interval() - shuffled_indices = np.random.permutation(range(len(chunk_intervals))) - self._shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indices] + shuffled_indexes = np.random.permutation(range(len(chunk_intervals))) + shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] + yield from self.__iter_from_shuffled_chunks(shuffled_indexes, shuffled_chunk_intervals) - if self._num_replicas == 1: - indices = [] - for interval, chunk_index in zip(self._shuffled_chunk_intervals, shuffled_indices): - interval_indices = np.arange(interval[0], interval[1]) - shuffled_interval_indices = np.random.permutation(interval_indices).tolist() - indices.extend([BatchIndex(index, chunk_index) for index in shuffled_interval_indices]) - - if len(indices) != len(self.sampler): - raise Exception("The generated indices don't match the initial length of the sampler.") - - else: - chunks_per_replica = len(self._shuffled_chunk_intervals) // self._num_replicas - for replica_idx in range(self._num_replicas): - if replica_idx != self._rank: - continue - is_last_replica = replica_idx == self._num_replicas - 1 - start_replica = replica_idx * chunks_per_replica - end_replica = len(chunk_intervals) if is_last_replica else (replica_idx + 1) * chunks_per_replica - shuffled_chunk_intervals_replica = self._shuffled_chunk_intervals[start_replica:end_replica] - shuffled_indices_replica = shuffled_indices[start_replica:end_replica] - - indices = [] - for interval, chunk_index in zip(shuffled_chunk_intervals_replica, shuffled_indices_replica): - interval_indices = np.arange(interval[0], interval[1]) - shuffled_interval_indices = np.random.permutation(interval_indices).tolist() - indices.extend([BatchIndex(index, chunk_index) for index in shuffled_interval_indices]) - - self.sampler = IteratorSampler(indices) - - return super().__iter__() + def __iter_from_chunks_distributed__(self): + chunk_intervals = self._cache.get_chunk_interval() + shuffled_indexes = np.random.permutation(range(len(chunk_intervals))) + shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] + + replica_chunks = [] + replica_intervals = [] + for index, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)): + if index % self._num_replicas == self._global_rank: + replica_chunks.append(chunk_index) + replica_intervals.append(chunk_interval) + + yield from self.__iter_from_shuffled_chunks(replica_chunks, replica_intervals) + + def __iter_from_shuffled_chunks(self, shuffled_indexes, shuffled_chunk_intervals): + chunks_per_workers = [[] for _ in range(self._num_workers)] + for i, chunk_index in enumerate(shuffled_indexes): + chunks_per_workers[i % self._num_workers].append(chunk_index) + + indices_per_workers = [[] for _ in range(self._num_workers)] + + for i, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)): + worker_id = i % self._num_workers + interval_indices = np.arange(chunk_interval[0], chunk_interval[1]) + shuffled_interval_indices = np.random.permutation(interval_indices).tolist() + is_empty = len(indices_per_workers[worker_id]) == 0 + indices_per_workers[worker_id].extend( + [ + ChunkedIndex( + index, + chunk_index, + chunk_indexes=chunks_per_workers[worker_id] if j == 0 and is_empty else None, + ) + for j, index in enumerate(shuffled_interval_indices) + ] + ) + + indices_per_workers_splitted = [self._chunk_list(indices, self._batch_size) for indices in indices_per_workers] + + yield from 
self.__iter_indices_per_workers__(indices_per_workers_splitted) def __len__(self) -> int: return super().__len__() + + def __iter_indices_per_workers__(self, indices_per_workers): + batches = [] + counter = 0 + while sum([len(v) for v in indices_per_workers]) != 0: + worker_indices = indices_per_workers[counter % self._num_workers] + if len(worker_indices) == 0: + batches.append([]) + else: + batches.append(worker_indices.pop(0)) + counter += 1 + + while True: + if len(batches[-1]) == 0: + batches.pop(-1) + else: + break + + yield from batches + + def _chunk_list(self, arr: List[Any], chunk_size: int) -> List[List[Any]]: + out = [] + for i in range(0, len(arr), chunk_size): + slice_item = slice(i, i + chunk_size, 1) + out.append(arr[slice_item]) + return out diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 6d8e2f5c0fc6e..47615349109e4 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -161,9 +161,8 @@ def _create_chunk(self, filename: str) -> bytes: chunk_info = { "chunk_bytes": self._current_chunk_size, - "samples": len(self._serialized_items), + "chunk_size": len(self._serialized_items), "filename": filename, - "mapping": mapping, "interval": [self._indexes[0], self._indexes[-1] + 1], } @@ -255,21 +254,21 @@ def done(self): self.reset() self._is_done = True - def merge(self): + def merge(self, num_workers: int): + num_workers = num_workers or 1 if self.rank != 0: while not os.path.exists(os.path.join(self._cache_dir, "index.json")): sleep(0.001) return - num_workers = _WorkerEnv.detect().world_size - is_done = False while not is_done: files = os.listdir(self._cache_dir) if "index.json" in files: return - index_files = [f for f in files if f.endswith("index.json") and f != "index.json"] + index_files = [f for f in files if f.endswith("index.json")] is_done = len(index_files) == self._distributed_env.world_size * num_workers + sleep(0.001) chunks_info = [] config = None diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index 47b1a27ee8fa7..c30a890ab0c5a 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -78,7 +78,7 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): cache_dir = os.path.join(tmpdir, "cache") distributed_env = _DistributedEnv.detect() - cache = Cache(cache_dir, chunk_bytes=2 << 12) + cache = Cache(cache_dir, chunk_size=10) dataset = ImageDataset(tmpdir, cache, dataset_size, 10) dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4) @@ -102,9 +102,10 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): if distributed_env.world_size == 1: indexes = [] - for batch in LightningDataLoader(dataset, num_workers=num_workers, batch_size=4): - indexes.extend(batch["index"].numpy().tolist()) - + dataloader = LightningDataLoader(dataset, num_workers=num_workers, batch_size=4) + for batch in dataloader: + if batch: + indexes.extend(batch["index"].numpy().tolist()) assert len(indexes) == dataset_size seed_everything(42) @@ -164,6 +165,7 @@ def test_cache_with_simple_format(tmpdir): cache[i] = i cache.done() + cache.merge() for i in range(100): assert i == cache[i] @@ -177,6 +179,7 @@ def test_cache_with_simple_format(tmpdir): cache[i] = [i, {0: [i + 1]}] cache.done() + cache.merge() for i in range(100): assert [i, {0: [i + 1]}] == cache[i] @@ -190,10 +193,10 @@ def test_cache_with_auto_wrapping(tmpdir): for batch in dataloader: assert isinstance(batch, torch.Tensor) 
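For context on the sampler methods above: `_chunk_list` splits each worker's indices into batches, and `__iter_indices_per_workers__` then emits one batch per worker in round-robin order, padding exhausted workers with empty batches, presumably so the DataLoader's own cycling over workers stays aligned, and finally dropping trailing empties. A minimal sketch, assuming plain integer indices instead of ChunkedIndex objects:

from typing import Any, List


def chunk_list(arr: List[Any], chunk_size: int) -> List[List[Any]]:
    # Split a flat list into consecutive batches of at most chunk_size items.
    return [arr[i : i + chunk_size] for i in range(0, len(arr), chunk_size)]


def interleave_per_worker(batches_per_worker: List[List[List[Any]]]) -> List[List[Any]]:
    # Pop one batch per worker in turn, inserting [] when a worker is exhausted,
    # then strip the trailing empty batches.
    num_workers = len(batches_per_worker)
    out: List[List[Any]] = []
    counter = 0
    while sum(len(batches) for batches in batches_per_worker) != 0:
        worker_batches = batches_per_worker[counter % num_workers]
        out.append(worker_batches.pop(0) if worker_batches else [])
        counter += 1
    while out and not out[-1]:
        out.pop()
    return out


if __name__ == "__main__":
    # Reproduces the (dataset_size=11, 3 workers, batch_size=3) case from the sampler tests.
    workers = [chunk_list(list(range(0, 3)), 3), chunk_list(list(range(3, 6)), 3), chunk_list(list(range(6, 11)), 3)]
    print(interleave_per_worker(workers))
    # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], [9, 10]]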
assert sorted(os.listdir(os.path.join(tmpdir, "cache_1"))) == [ - "0.index.json", "chunk-0-0.bin", "chunk-0-1.bin", "chunk-0-2.bin", + "index.json", ] # Your dataset is optimised for the cloud diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py index cec4418105db2..4f4196a7631de 100644 --- a/tests/tests_data/cache/test_sampler.py +++ b/tests/tests_data/cache/test_sampler.py @@ -1,168 +1,71 @@ from unittest import mock -import numpy as np import pytest from lightning import seed_everything -from lightning.data.cache.sampler import CacheBatchSampler, CacheSampler, DistributedCacheSampler - - -def test_cache_sampler_sampling(): - """Valides the CacheSampler can return batch of data in an ordered way.""" - dataset_size = 17 - sampler = CacheSampler(dataset_size, 3, 3) - iter_sampler = iter(sampler) - - all_indexes = [] - indexes = [] - while True: - try: - index = next(iter_sampler) - indexes.append(index) - all_indexes.append(index) - except StopIteration: - assert indexes == [0, 1, 2, 5, 6, 7, 10, 11, 12, 3, 4] - assert sampler._done == {0} - break - - indexes = [] - while True: - try: - index = next(iter_sampler) - indexes.append(index) - all_indexes.append(index) - except StopIteration: - assert indexes == [8, 9] - assert sampler._done == {0, 1} - break - - indexes = [] - while True: - try: - index = next(iter_sampler) - indexes.append(index) - all_indexes.append(index) - except StopIteration: - assert indexes == [13, 14, 15, 16] - assert sampler._done == {0, 1, 2} - break - - assert sorted(all_indexes) == list(range(dataset_size)) - - -@pytest.mark.parametrize( - "params", - [ - (21, range(0, 7), range(7, 14), range(14, 21)), - (23, range(0, 7), range(7, 14), range(14, 23)), - (33, range(0, 11), range(11, 22), range(22, 33)), - (49, range(0, 16), range(16, 32), range(32, 49)), - (5, range(0, 1), range(1, 2), range(2, 5)), - (12, range(0, 4), range(4, 8), range(8, 12)), - ], -) -def test_cache_sampler_samplers(params): - sampler = CacheSampler(params[0], 3, 3) - assert sampler.samplers[0].data_source == params[1] - assert sampler.samplers[1].data_source == params[2] - assert sampler.samplers[2].data_source == params[3] +from lightning.data.cache.sampler import CacheBatchSampler @pytest.mark.parametrize( "params", [ ( - 102, - 2, - [ - [range(0, 17), range(17, 34), range(34, 51)], - [range(51, 68), range(68, 85), range(85, 102)], - ], + 21, + 1, + [[0, 1, 2], [7, 8, 9], [14, 15, 16], [3, 4, 5], [10, 11, 12], [17, 18, 19], [6], [13], [20]], + [[7, 0, 0], [1, 1, 1], [5, 5, 5], [0, 4, 4], [8, 3, 3], [2, 2, 2], [4], [3], [6]], ), ( - 227, - 5, - [ - [range(0, 15), range(15, 30), range(30, 45)], - [range(45, 60), range(60, 75), range(75, 90)], - [range(90, 105), range(105, 120), range(120, 135)], - [range(135, 150), range(150, 165), range(165, 180)], - [range(180, 195), range(195, 210), range(210, 227)], - ], + 11, + 1, + [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], [9, 10]], + [[1, 1, 1], [3, 3], [0, 0, 0], [2, 2, 2]], ), + (8, 1, [[0, 1], [2, 3], [4, 5, 6], [], [], [7]], [[1, 1, 2], [3], [0, 0], [2, 2]]), + (4, 1, [[0], [1], [2, 3]], [[0], [1], [2, 2]]), ( - 1025, - 7, - [ - [range(0, 48), range(48, 96), range(96, 146)], - [range(146, 194), range(194, 242), range(242, 292)], - [range(292, 340), range(340, 388), range(388, 438)], - [range(438, 486), range(486, 534), range(534, 584)], - [range(584, 632), range(632, 680), range(680, 730)], - [range(730, 778), range(778, 826), range(826, 876)], - [range(876, 924), range(924, 972), range(972, 1025)], - ], + 9, 
+ 1, + [[0, 1, 2], [3, 4, 5], [6, 7, 8]], + [[0, 0, 0], [1, 1, 1], [2, 2, 2]], ), ( - 323, - 2, - [ - [range(0, 53), range(53, 106), range(106, 161)], - [range(161, 214), range(214, 267), range(267, 323)], - ], - ), - ( - 23, - 3, - [ - [range(0, 2), range(2, 4), range(4, 7)], - [range(7, 9), range(9, 11), range(11, 14)], - [range(14, 16), range(16, 18), range(18, 23)], - ], - ), - ( - 45, - 2, - [ - [range(0, 7), range(7, 14), range(14, 22)], - [range(22, 29), range(29, 36), range(36, 45)], - ], + 19, + 1, + [[0, 1, 2], [6, 7, 8], [12, 13, 14], [3, 4, 5], [9, 10, 11], [15, 16, 17], [], [], [18]], + [[0, 0, 0], [1, 1, 1], [5, 5, 5], [2, 2, 2], [4, 4, 4], [3, 3, 3], [6]], ), - ], -) -def test_cache_distributed_sampler_samplers(params): - """This test validates the sub-samplers of the DistributedCacheSampler has the right sampling intervals.""" - for rank in range(params[1]): - sampler = DistributedCacheSampler(params[0], params[1], rank, 3, 3) - assert sampler.samplers[0].data_source == params[2][rank][0] - assert sampler.samplers[1].data_source == params[2][rank][1] - assert sampler.samplers[2].data_source == params[2][rank][2] - - -@pytest.mark.parametrize( - "params", - [ - (21, 1, [[0, 1, 2], [7, 8, 9], [14, 15, 16], [3, 4, 5], [10, 11, 12], [17, 18, 19], [6], [13], [20]]), - (11, 1, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], [9, 10]]), - (8, 1, [[0, 1], [2, 3], [4, 5, 6], [7]]), - (4, 1, [[0], [1], [2, 3]]), - (9, 1, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], []]), - (19, 1, [[0, 1, 2], [6, 7, 8], [12, 13, 14], [3, 4, 5], [9, 10, 11], [15, 16, 17], [], [], [18]]), - (19, 2, [[0, 1, 2], [3, 4, 5], [6, 7, 8], [], [], []]), + (19, 2, [[0, 1, 2], [3, 4, 5], [6, 7, 8]], [[0, 0, 0], [5, 5, 5], [4, 4, 4], [6]]), ], ) def test_cache_batch_sampler(params): + seed_everything(42) + cache = mock.MagicMock() cache.filled = False - batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache) - batches = [] - for batch in batch_sampler: - batches.append(batch) - assert batches == params[2] - - chunk_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)] + if params[1] > 1: + batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache) + batches = [] + for batch in batch_sampler: + batches.append(batch) + assert batches == params[2], batches + + batch_sampler = CacheBatchSampler(params[0], 1, 0, 3, 3, False, True, cache) + batches = [] + for batch in batch_sampler: + batches.append(batch) + + chunks_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)] + else: + batch_sampler = CacheBatchSampler(params[0], params[1], 0, 3, 3, False, True, cache) + batches = [] + for batch in batch_sampler: + batches.append(batch) + assert batches == params[2], batches + + chunks_interval = [[batch[0], batch[-1] + 1] for batch in batches if len(batch)] cache.filled = True - cache.get_chunk_interval.return_value = chunk_interval + cache.get_chunk_interval.return_value = chunks_interval seed_everything(42) @@ -170,45 +73,29 @@ def test_cache_batch_sampler(params): batches_1 = [] for batch in batch_sampler: - batches_1.extend(batch) + batches_1.append(batch) - def validate_batch(data): - chunks = batch_sampler._shuffled_chunk_intervals + def validate_batch(data, check_values): if params[1] == 1: - size = 0 - for interval in chunks: - interval_indices = np.arange(interval[0], interval[1]) - for indice in interval_indices: - assert indice in [b.index for b in data[size : size + len(interval_indices)]] - size += len(interval_indices) + 
assert all(b[0].chunk_indexes is not None for b in data[:3]) + assert all(b[1].chunk_indexes is None if len(b) > 1 else True for b in data[:3]) + assert all(b[0].chunk_indexes is None if len(b) else True for b in data[3:]) + if check_values: + assert [[x.chunk_index for x in d] for d in data] == params[3] else: - chunks_per_replica = len(chunks) // params[1] - for replica_idx in range(params[1]): - if replica_idx != 0: - continue - is_last_replica = replica_idx == params[1] - 1 - start_replica = replica_idx * chunks_per_replica - end_replica = len(chunks) if is_last_replica else (replica_idx + 1) * chunks_per_replica - shuffled_chunk_intervals_replica = chunks[start_replica:end_replica] - - assert len(shuffled_chunk_intervals_replica) - - size = 0 - for interval in shuffled_chunk_intervals_replica: - interval_indices = np.arange(interval[0], interval[1]) - for indice in interval_indices: - assert indice in [b.index for b in data[size : size + len(interval_indices)]] - size += len(interval_indices) - - validate_batch(batches_1) - if params[1] == 1: - assert len(batches_1) == params[0] + assert all(b[0].chunk_indexes is not None for b in data[:3]) + assert all(b[1].chunk_indexes is None if len(b) > 1 else True for b in data[:3]) + assert all(b[0].chunk_indexes is None if len(b) else True for b in data[3:]) + if check_values: + assert [[x.chunk_index for x in d] for d in data] == params[3] + + validate_batch(batches_1, True) batches_2 = [] for batch in batch_sampler: - batches_2.extend(batch) + batches_2.append(batch) - validate_batch(batches_2) + validate_batch(batches_2, False) if params[1] == 1: assert batches_1 != batches_2 diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py index 1a690576b1ddf..ea4f447a6aeae 100644 --- a/tests/tests_data/cache/test_serializer.py +++ b/tests/tests_data/cache/test_serializer.py @@ -141,7 +141,7 @@ def test_tensor_serializer(): ratio_times.append(pickle_time / tensor_time) ratio_bytes.append(pickle_bytes / tensor_bytes) - assert np.mean(ratio_times) > 4 + assert np.mean(ratio_times) > 3.5 assert np.mean(ratio_bytes) > 2 diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index bfa57734583ef..57d541a1d3166 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -17,6 +17,7 @@ import numpy as np import pytest from lightning.data.cache.reader import BinaryReader +from lightning.data.cache.sampler import ChunkedIndex from lightning.data.cache.writer import BinaryWriter from lightning_utilities.core.imports import RequirementCache @@ -42,13 +43,19 @@ def test_binary_writer_with_ints_and_chunk_bytes(tmpdir): with open(os.path.join(tmpdir, "0.index.json")) as f: data = json.load(f) - assert data["chunks"][0]["samples"] == 6 - assert data["chunks"][1]["samples"] == 5 - assert data["chunks"][-1]["samples"] == 4 + assert data["chunks"][0]["chunk_size"] == 6 + assert data["chunks"][1]["chunk_size"] == 5 + assert data["chunks"][-1]["chunk_size"] == 4 + + chunk_sizes = np.cumsum([chunk["chunk_size"] for chunk in data["chunks"]]) reader = BinaryReader(tmpdir) for i in range(100): - data = reader.read(i) + for chunk_index, chunk_start in enumerate(chunk_sizes): + if i >= chunk_start: + continue + break + data = reader.read(ChunkedIndex(i, chunk_index=chunk_index)) assert data == {"i": i, "i+1": i + 1, "i+2": i + 2} @@ -71,13 +78,13 @@ def test_binary_writer_with_ints_and_chunk_size(tmpdir): with open(os.path.join(tmpdir, "0.index.json")) as f: 
data = json.load(f) - assert data["chunks"][0]["samples"] == 25 - assert data["chunks"][1]["samples"] == 25 - assert data["chunks"][-1]["samples"] == 25 + assert data["chunks"][0]["chunk_size"] == 25 + assert data["chunks"][1]["chunk_size"] == 25 + assert data["chunks"][-1]["chunk_size"] == 25 reader = BinaryReader(tmpdir) for i in range(100): - data = reader.read(i) + data = reader.read(ChunkedIndex(i, chunk_index=i // 25)) assert data == {"i": i, "i+1": i + 1, "i+2": i + 2} @@ -108,12 +115,12 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): with open(os.path.join(cache_dir, "0.index.json")) as f: data = json.load(f) - assert data["chunks"][0]["samples"] == 4 - assert data["chunks"][1]["samples"] == 4 - assert data["chunks"][-1]["samples"] == 4 + assert data["chunks"][0]["chunk_size"] == 4 + assert data["chunks"][1]["chunk_size"] == 4 + assert data["chunks"][-1]["chunk_size"] == 4 reader = BinaryReader(cache_dir) for i in range(100): - data = reader.read(i) + data = reader.read(ChunkedIndex(i, chunk_index=i // 4)) assert data["x"] == imgs[i] assert data["y"] == i From 200d6b54b7e243470ce165fc97b3b0ac9738a2e5 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 13:20:44 +0100 Subject: [PATCH 59/84] update --- src/lightning/data/cache/sampler.py | 7 +++---- tests/tests_data/cache/test_sampler.py | 1 + 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/lightning/data/cache/sampler.py b/src/lightning/data/cache/sampler.py index e321febacec55..d65acfe81bb7c 100644 --- a/src/lightning/data/cache/sampler.py +++ b/src/lightning/data/cache/sampler.py @@ -13,7 +13,7 @@ import logging from dataclasses import dataclass -from typing import Any, List, Optional +from typing import Any, Iterator, List, Optional import numpy as np @@ -64,7 +64,8 @@ def __init__( self._shuffled_chunk_intervals = None self._batch_size = batch_size - # self._validate() + # Before starting, ensures the chunk indices are properly defined. 
+ self._validate() def _validate(self): if self._num_workers > 1 and not self._cache.filled: @@ -75,8 +76,6 @@ def _validate(self): batches[worker_index] = [] batches[worker_index].extend(batch_indices) elif len(batch_indices) > 0: - if batches[worker_index][-1] != (batch_indices[0] - 1): - breakpoint() batches[worker_index].extend(batch_indices) for indices in batches.values(): diff --git a/tests/tests_data/cache/test_sampler.py b/tests/tests_data/cache/test_sampler.py index 4f4196a7631de..d6528794a33f5 100644 --- a/tests/tests_data/cache/test_sampler.py +++ b/tests/tests_data/cache/test_sampler.py @@ -101,6 +101,7 @@ def validate_batch(data, check_values): def test_batch_sampler_imagenet(): + """Validate the Imagenet dataset is valid.""" dataset_size = 1281167 world_size = 1 rank = 0 From 79994abf872229388ce3ae0ba11adfe4f999905a Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 13:21:58 +0100 Subject: [PATCH 60/84] update --- _notebooks | 1 + 1 file changed, 1 insertion(+) create mode 160000 _notebooks diff --git a/_notebooks b/_notebooks new file mode 160000 index 0000000000000..70821217ca0db --- /dev/null +++ b/_notebooks @@ -0,0 +1 @@ +Subproject commit 70821217ca0db0280af537002839dbb340f77d68 From cbb7487e2268e69a205b84a5e342b630450956fd Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 13:40:33 +0100 Subject: [PATCH 61/84] update --- src/lightning/data/cache/cache.py | 3 +- src/lightning/data/cache/config.py | 8 +-- src/lightning/data/cache/dataloader.py | 20 ++++++ src/lightning/data/cache/writer.py | 93 +++++++++++++++----------- 4 files changed, 79 insertions(+), 45 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 0e9edc4f9d288..e16c8046cba8b 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -15,6 +15,7 @@ import os from typing import Any, Dict, Optional, Union +from lightning.data.cache.constants import INDEX_FILENAME from lightning.data.cache.reader import BinaryReader from lightning.data.cache.sampler import ChunkedIndex from lightning.data.cache.writer import BinaryWriter @@ -62,7 +63,7 @@ def filled(self) -> bool: raise Exception("The Cache wasn't setup properly. 
HINT: Did you use the LightningDataLoader ?") if self._is_done: return True - self._is_done = os.path.exists(os.path.join(self._cache_dir, "index.json")) + self._is_done = os.path.exists(os.path.join(self._cache_dir, INDEX_FILENAME)) return self._is_done def __setitem__(self, index, data) -> None: diff --git a/src/lightning/data/cache/config.py b/src/lightning/data/cache/config.py index 736eb1073972d..7360d01aea2bf 100644 --- a/src/lightning/data/cache/config.py +++ b/src/lightning/data/cache/config.py @@ -13,7 +13,7 @@ import json import os -from typing import Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple from lightning.data.cache.constants import INDEX_FILENAME from lightning.data.cache.downloader import get_downloader_cls @@ -71,15 +71,15 @@ def download_chunk_from_index(self, chunk_index: int) -> None: return self._downloader.download_chunk_from_index(chunk_index) @property - def intervals(self): + def intervals(self) -> List[List[int]]: return self._intervals @property - def data_format(self): + def data_format(self) -> List[str]: return self._config["data_format"] @property - def config(self): + def config(self) -> Dict[str, Any]: return self._config def _get_chunk_index_from_index(self, index: int) -> int: diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index d1b96a7dc3c5a..4a1e7e693ecbd 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -69,6 +69,16 @@ def __init__( chunk_size: int, compression: Optional[str], ): + """ + The `CacheDataset` is a dataset wraper to provide a beginner experience with the Cache. + + Arguments: + dataset: The dataset of the user + cache_dir: The folder where the chunks are written to. + chunk_bytes: The maximal number of bytes to write within a chunk. + chunk_sie: The maximal number of items to write to a chunk. + compression: The compression algorithm to use to reduce the size of the chunk. + """ self._datataset = dataset self._cache = Cache(cache_dir, chunk_bytes=chunk_bytes, chunk_size=chunk_size, compression=compression) self._is_deterministic = False @@ -92,6 +102,14 @@ def __getitem__(self, index): class CacheCollateFn: + """This CacheCollateFn is used to accelerate the processing of the data generated using the Cache. + + During the chunking phase, there is no need to return any data from the DataLoader reducing some time. + + Additionally, if the user makes their __getitem__ asynchronous, the collate executes them in parallel. 
+ + """ + def __init__(self, collate_fn: Optional[Callable] = None): self.collate_fn = collate_fn or default_collate @@ -184,6 +202,8 @@ def __init__(self, loader): def _shutdown_workers(self): super()._shutdown_workers() + + # If the data isn't filled, we trigger an indedm merge if not self._cache.filled: self._cache.merge(self._num_workers) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 47615349109e4..8573b29d7ffee 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -14,11 +14,12 @@ import json import os from time import sleep -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional import numpy as np -from lightning.data.cache.compression import _COMPRESSORS +from lightning.data.cache.compression import _COMPRESSORS, Compressor +from lightning.data.cache.constants import INDEX_FILENAME from lightning.data.cache.pytree import tree_flatten, treespec_dumps from lightning.data.cache.serializers import _SERIALIZERS, Serializer from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv @@ -65,13 +66,13 @@ def __init__( raise ValueError( f"The provided compression {self._compression} isn't available in {sorted(_COMPRESSORS)}" ) - self._compressor = _COMPRESSORS[self._compression] + self._compressor: Compressor = _COMPRESSORS[self._compression] - self._current_chunk_size = 0 - self._chunk_id = 0 - self._serialized_items = [] + self._current_chunk_bytes = 0 + self._chunk_index = 0 + self._serialized_items: List[bytes] = [] self._chunks_info = [] - self._indexes = [] + self._indexes: List[int] = [] self._worker_env = None self._rank = None self._is_done = False @@ -83,7 +84,7 @@ def filled(self) -> bool: if self._is_done: return True files = os.listdir(self._cache_dir) - index_files = [f for f in files if f.endswith("index.json")] + index_files = [f for f in files if f.endswith(INDEX_FILENAME)] worker_end = _WorkerEnv.detect() self._is_done = len(index_files) == self._distributed_env.world_size * worker_end.world_size return self._is_done @@ -109,8 +110,11 @@ def get_config(self) -> Dict[str, Any]: def serialize(self, items: Any) -> bytes: """Serialize a dictionary into its binary format.""" + + # Flatten the items provided by the users flattened, data_spec = tree_flatten(items) + # Collect the sizes and associated bytes for each item sizes = [] data = [] @@ -130,11 +134,13 @@ def serialize(self, items: Any) -> bytes: elif self._data_spec != data_spec: raise Exception(f"The data format changed between items. 
Found {data_spec} instead of {self._data_spec}.") + # Concatenante into a single byte array head = np.array(sizes, np.uint32).tobytes() body = b"".join(data) return head + body - def _serialize(self, item, sizes, data) -> bytes: + def _serialize(self, item: Any, sizes: List[int], data: List[bytes]) -> None: + """Serialize a given item and append its size and bytes to the sizes and data array.""" for serializer_name, serializer in self._serializers.items(): if serializer.can_serialize(item): serialized_item = serializer.serialize(item) @@ -160,7 +166,7 @@ def _create_chunk(self, filename: str) -> bytes: assert (self._indexes[-1] - self._indexes[0] + 1) == len(self._serialized_items) chunk_info = { - "chunk_bytes": self._current_chunk_size, + "chunk_bytes": self._current_chunk_bytes, "chunk_size": len(self._serialized_items), "filename": filename, "interval": [self._indexes[0], self._indexes[-1] + 1], @@ -170,25 +176,21 @@ def _create_chunk(self, filename: str) -> bytes: return data - def write_chunk(self): + def write_chunk(self) -> None: """Write a chunk to the filesystem.""" if self._compression: - filename = f"chunk-{self.rank}-{self._chunk_id}.{self._compression}.bin" + filename = f"chunk-{self.rank}-{self._chunk_index}.{self._compression}.bin" else: - filename = f"chunk-{self.rank}-{self._chunk_id}.bin" - self.write_file(self._create_chunk(filename), filename) - - @property - def available_serializers(self): - return self._serializers + filename = f"chunk-{self.rank}-{self._chunk_index}.bin" + self.write_chunk_to_file(self._create_chunk(filename), filename) def reset(self) -> None: """Reset the writer to handle the next chunk.""" self._serialized_items = [] self._indexes = [] - self._current_chunk_size = 0 + self._current_chunk_bytes = 0 - def __setitem__(self, index, items: any): + def __setitem__(self, index: int, items: Any) -> None: """Store an item to a chunk. The index needs to be provided in order. @@ -196,56 +198,60 @@ def __setitem__(self, index, items: any): This is handled by the samplers automatically. This ensures we can map an index to a shard from an interval. """ + # Serialize the items serialized_items = self.serialize(items) serialized_items_size = len(serialized_items) - should_write = (self._chunk_bytes and self._chunk_bytes < self._current_chunk_size + serialized_items_size) or ( - self._chunk_size and len(self._indexes) >= self._chunk_size - ) + # Check whether it is time to write a chunk + should_write = ( + self._chunk_bytes and self._chunk_bytes < self._current_chunk_bytes + serialized_items_size + ) or (self._chunk_size and len(self._indexes) >= self._chunk_size) if should_write: - if self._current_chunk_size == 0: + if self._current_chunk_bytes == 0: raise Exception( f"The provided chunk_size {self._chunk_bytes} is too small." f" You should use a multiple of {serialized_items_size} bytes." ) self.write_chunk() self.reset() - self._chunk_id += 1 + self._chunk_index += 1 + # Store the serialized items into the chunk. 
self._serialized_items.append(serialized_items) - self._current_chunk_size += serialized_items_size + self._current_chunk_bytes += serialized_items_size + # Validate the index are provided in an incremental order + # This is required to ensure we can find efficiently a chunk index from an index using the chunk interval if self._indexes: assert self._indexes[-1] == index - 1, (self._indexes, index - 1) + # Store the index self._indexes.append(index) - def write_file( + def write_chunk_to_file( self, raw_data: bytes, filename: str, ) -> None: """Write chunk bytes to a file.""" + # Whether to compress the raw bytes if self._compression: raw_data = self._compressor.compress(raw_data) - filepath = os.path.join(self._cache_dir, filename) - with open(filepath, "wb") as out: + + # Write the binary chunk file + with open(os.path.join(self._cache_dir, filename), "wb") as out: out.write(raw_data) def write_chunks_index(self): """Write the chunks index to a JSON file.""" - filepath = os.path.join(self._cache_dir, f"{self.rank}.index.json") + filepath = os.path.join(self._cache_dir, f"{self.rank}.{INDEX_FILENAME}") config = self.get_config() with open(filepath, "w") as out: json.dump({"chunks": self._chunks_info, "config": config}, out, sort_keys=True) def done(self): - """Called when StopIteration is triggered. - - It tries to save the last chunk and write the chunks index. - - """ + """Called when StopIteration is triggered.""" if self.filled: return if self._serialized_items: @@ -254,22 +260,28 @@ def done(self): self.reset() self._is_done = True - def merge(self, num_workers: int): + def merge(self, num_workers: int) -> None: + """ "Once all the workers have written their own index, the merge function is responsible to read and merge them + into a single index.""" num_workers = num_workers or 1 + + # Only for non rank 0 if self.rank != 0: - while not os.path.exists(os.path.join(self._cache_dir, "index.json")): + while not os.path.exists(os.path.join(self._cache_dir, INDEX_FILENAME)): sleep(0.001) return + # Wait for all indexes to be available is_done = False while not is_done: files = os.listdir(self._cache_dir) - if "index.json" in files: + if INDEX_FILENAME in files: return - index_files = [f for f in files if f.endswith("index.json")] + index_files = [f for f in files if f.endswith(INDEX_FILENAME)] is_done = len(index_files) == self._distributed_env.world_size * num_workers sleep(0.001) + # Read the index and append the chunks together chunks_info = [] config = None for index_filename in sorted(index_files): @@ -285,5 +297,6 @@ def merge(self, num_workers: int): chunks_info.extend(data["chunks"]) - with open(os.path.join(self._cache_dir, "index.json"), "w") as f: + # Write down the collected index + with open(os.path.join(self._cache_dir, INDEX_FILENAME), "w") as f: json.dump({"chunks": chunks_info, "config": config}, f, sort_keys=True) From 17ce63ba544b13f3aea31e285f4b6e1be583d2bb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Oct 2023 12:41:50 +0000 Subject: [PATCH 62/84] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/data/cache/dataloader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index 4a1e7e693ecbd..fff1ca934f1cd 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -69,8 +69,7 @@ def __init__( 
chunk_size: int, compression: Optional[str], ): - """ - The `CacheDataset` is a dataset wraper to provide a beginner experience with the Cache. + """The `CacheDataset` is a dataset wraper to provide a beginner experience with the Cache. Arguments: dataset: The dataset of the user @@ -78,6 +77,7 @@ def __init__( chunk_bytes: The maximal number of bytes to write within a chunk. chunk_sie: The maximal number of items to write to a chunk. compression: The compression algorithm to use to reduce the size of the chunk. + """ self._datataset = dataset self._cache = Cache(cache_dir, chunk_bytes=chunk_bytes, chunk_size=chunk_size, compression=compression) From 1eac118ab3114350757f2fba6ad1e2b276b728cd Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 14:06:16 +0100 Subject: [PATCH 63/84] update --- src/lightning/data/cache/writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 8573b29d7ffee..bfd40708b6d7d 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -261,7 +261,7 @@ def done(self): self._is_done = True def merge(self, num_workers: int) -> None: - """ "Once all the workers have written their own index, the merge function is responsible to read and merge them + """Once all the workers have written their own index, the merge function is responsible to read and merge them into a single index.""" num_workers = num_workers or 1 From c683a6b526cb7391d8b6f4e6c4c9088d3ad240f7 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 15:14:02 +0100 Subject: [PATCH 64/84] update --- pyproject.toml | 1 - src/lightning/data/cache/cache.py | 6 +- src/lightning/data/cache/compression.py | 10 +-- src/lightning/data/cache/config.py | 22 +++++-- src/lightning/data/cache/dataloader.py | 83 +++++++++++++------------ src/lightning/data/cache/downloader.py | 6 +- src/lightning/data/cache/reader.py | 78 ++++++++++------------- src/lightning/data/cache/sampler.py | 61 ++++++++---------- src/lightning/data/cache/serializers.py | 25 ++++---- src/lightning/data/cache/writer.py | 27 ++++---- 10 files changed, 155 insertions(+), 164 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 50e80be9f5566..43ef2fc0195f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,7 +138,6 @@ exclude = [ "src/lightning/app/cli/pl-app-template", "src/lightning/app/cli/react-ui-template", "src/lightning/app/launcher", - "src/lightning/data/cache", ] install_types = "True" non_interactive = "True" diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index e16c8046cba8b..83fbe43b854ec 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -13,7 +13,7 @@ import logging import os -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from lightning.data.cache.constants import INDEX_FILENAME from lightning.data.cache.reader import BinaryReader @@ -66,7 +66,7 @@ def filled(self) -> bool: self._is_done = os.path.exists(os.path.join(self._cache_dir, INDEX_FILENAME)) return self._is_done - def __setitem__(self, index, data) -> None: + def __setitem__(self, index: int, data: Any) -> None: """Store an item in the writer.""" self._writer[index] = data @@ -87,7 +87,7 @@ def merge(self, num_workers: int = 1) -> None: def __len__(self) -> int: return self._reader.get_length() - def get_chunk_interval(self): + def get_chunk_interval(self) -> List[Tuple[int, int]]: return 
self._reader.get_chunk_interval() def _get_chunk_index_from_index(self, index: int) -> int: diff --git a/src/lightning/data/cache/compression.py b/src/lightning/data/cache/compression.py index 731abb979ceb0..68fbc2eaf3975 100644 --- a/src/lightning/data/cache/compression.py +++ b/src/lightning/data/cache/compression.py @@ -36,7 +36,7 @@ def decompress(self, data: bytes) -> bytes: pass @abstractclassmethod - def register(cls, compressors: Dict[str, TCompressor]): + def register(cls, compressors: Dict[str, "Compressor"]) -> None: pass @@ -44,13 +44,13 @@ class ZSTDCompressor(Compressor): """Compressor for the zstd package.""" @requires("zstd") - def __init__(self, level): + def __init__(self, level: int) -> None: super().__init__() self.level = level self.extension = "zstd" @property - def name(self): + def name(self) -> str: return f"{self.extension}:{self.level}" def compress(self, data: bytes) -> bytes: @@ -60,7 +60,7 @@ def decompress(self, data: bytes) -> bytes: return zstd.decompress(data) @classmethod - def register(cls, compressors): + def register(cls, compressors: Dict[str, "Compressor"]) -> None: # type: ignore if not _ZSTD_AVAILABLE: return @@ -71,6 +71,6 @@ def register(cls, compressors): compressors[f"zstd:{level}"] = ZSTDCompressor(level) -_COMPRESSORS = {} +_COMPRESSORS: Dict[str, Compressor] = {} ZSTDCompressor.register(_COMPRESSORS) diff --git a/src/lightning/data/cache/config.py b/src/lightning/data/cache/config.py index 7360d01aea2bf..0757622cd8353 100644 --- a/src/lightning/data/cache/config.py +++ b/src/lightning/data/cache/config.py @@ -32,7 +32,7 @@ def __init__(self, cache_dir: str, remote_dir: Optional[str]): """ self._cache_dir = cache_dir - self._intervals = [] + self._intervals: List[Tuple[int, int]] = [] self._config = None self._chunks = [] self._remote_dir = remote_dir @@ -52,11 +52,12 @@ def __init__(self, cache_dir: str, remote_dir: Optional[str]): raise Exception( "The config intervals doesn't match the number of samples. This shouldn't have happened." 
) - self._intervals.append(chunk["interval"]) + self._intervals.append((chunk["interval"][0], chunk["interval"][1])) self._length = sum([chunk["chunk_size"] for chunk in self._chunks]) self._downloader = None + if remote_dir: self._downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, self._chunks) @@ -66,20 +67,29 @@ def download_chunk_from_index(self, chunk_index: int) -> None: local_chunkpath = os.path.join(self._cache_dir, chunk_filename) if os.path.exists(local_chunkpath): - return None + return + + if self._downloader is None: + raise RuntimeError("The downloader should be defined.") - return self._downloader.download_chunk_from_index(chunk_index) + self._downloader.download_chunk_from_index(chunk_index) @property - def intervals(self) -> List[List[int]]: + def intervals(self) -> List[Tuple[int, int]]: + if self._intervals is None: + raise RuntimeError("The intervals should be defined.") return self._intervals @property - def data_format(self) -> List[str]: + def data_format(self) -> Any: + if self._config is None: + raise RuntimeError("The config should be defined.") return self._config["data_format"] @property def config(self) -> Dict[str, Any]: + if self._config is None: + raise RuntimeError("The config should be defined.") return self._config def _get_chunk_index_from_index(self, index: int) -> int: diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index fff1ca934f1cd..bb9f527ea24c9 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -16,18 +16,21 @@ import logging import os from importlib import reload -from typing import Any, Callable, Optional +from typing import Any, Callable, List, Optional import torch from lightning_utilities.core.imports import RequirementCache from torch.utils.data import Dataset, IterableDataset from torch.utils.data._utils.collate import default_collate +from torch.utils.data._utils.fetch import _BaseDatasetFetcher from torch.utils.data.dataloader import ( DataLoader, _BaseDataLoaderIter, + _DatasetKind, _MultiProcessingDataLoaderIter, _SingleProcessDataLoaderIter, ) +from torch.utils.data.sampler import BatchSampler, Sampler from lightning.data.cache import Cache from lightning.data.cache.pytree import tree_flatten @@ -51,22 +54,24 @@ def _equal_items(data_1: Any, data_2: Any) -> bool: return all(_equal_item(d1, d2) for d1, d2 in zip(data_1_flattened, data_2_flattened)) -def _equal_item(d1, d2) -> bool: +def _equal_item(d1: Any, d2: Any) -> bool: if not isinstance(d1, type(d2)): - raise False + return False equality = d1 == d2 if isinstance(equality, torch.Tensor): - return equality.all() - return equality + return bool(equality.all().item()) + if equality is True: + return True + return False class CacheDataset(Dataset): def __init__( self, - dataset: Dataset, - cache_dir: Optional[str], + dataset: Any, + cache_dir: str, chunk_bytes: Optional[int], - chunk_size: int, + chunk_size: Optional[int], compression: Optional[str], ): """The `CacheDataset` is a dataset wraper to provide a beginner experience with the Cache. @@ -79,18 +84,18 @@ def __init__( compression: The compression algorithm to use to reduce the size of the chunk. 
""" - self._datataset = dataset + self._dataset = dataset self._cache = Cache(cache_dir, chunk_bytes=chunk_bytes, chunk_size=chunk_size, compression=compression) self._is_deterministic = False def __len__(self) -> int: - return len(self._cache) if self._cache.filled else len(self._datataset) + return len(self._cache) if self._cache.filled else len(self._dataset) - def __getitem__(self, index): - data_1 = self._cache[index] if self._cache.filled else self._datataset[index] + def __getitem__(self, index: int) -> Any: + data_1 = self._cache[index] if self._cache.filled else self._dataset[index] if not self._cache.filled: if not self._is_deterministic: - data2 = self._datataset[index] + data2 = self._dataset[index] if not _equal_items(data_1, data2): raise ValueError( f"Your dataset items aren't deterministic. Found {data_1} and {data2} for index {index}." @@ -110,10 +115,10 @@ class CacheCollateFn: """ - def __init__(self, collate_fn: Optional[Callable] = None): + def __init__(self, collate_fn: Optional[Callable] = None) -> None: self.collate_fn = collate_fn or default_collate - def __call__(self, items): + def __call__(self, items: List[Any]) -> Any: if all(item is None for item in items): return None @@ -129,7 +134,7 @@ def __call__(self, items): class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter): """This is overriden to inform the cache is done chunking.""" - def _next_data(self): + def _next_data(self) -> Any: try: data = None while data is None: @@ -151,8 +156,7 @@ def __init__(self, global_rank: int, profile: bool = False) -> None: self._global_rank = global_rank self._profile = profile - def __call__(self, dataset_kind, *args, **kwargs): - from torch.utils.data import _DatasetKind + def __call__(self, dataset_kind: _DatasetKind, *args: Any, **kwargs: Any) -> None: from torch.utils.data._utils import worker from lightning.data.cache.cache import Cache @@ -171,16 +175,17 @@ def __call__(self, dataset_kind, *args, **kwargs): create_fetcher = _DatasetKind.create_fetcher fetcher = None - def create_fetcher_fn(*args, **kwargs): + def create_fetcher_fn(*args: Any, **kwargs: Any) -> "_BaseDatasetFetcher": nonlocal fetcher fetcher = create_fetcher(*args, **kwargs) return fetcher - _DatasetKind.create_fetcher = create_fetcher_fn + _DatasetKind.create_fetcher = create_fetcher_fn # type: ignore reloaded_worker._worker_loop(dataset_kind, *args, **kwargs) if dataset_kind == _DatasetKind.Map: + assert fetcher for v in fetcher.dataset.__dict__.values(): if isinstance(v, Cache): v.done() @@ -191,7 +196,7 @@ def create_fetcher_fn(*args, **kwargs): class _MultiProcessingDataLoaderIterPatch(_MultiProcessingDataLoaderIter): - def __init__(self, loader): + def __init__(self, loader: DataLoader) -> None: self._cache = loader._cache self._num_workers = loader.num_workers # Patch PyTorch worker loop to call the `cache.done()` method. 
@@ -200,14 +205,14 @@ def __init__(self, loader): worker._worker_loop = WorkerLoop(loader._global_rank, loader._profile) super().__init__(loader) - def _shutdown_workers(self): + def _shutdown_workers(self) -> None: super()._shutdown_workers() # If the data isn't filled, we trigger an indedm merge if not self._cache.filled: self._cache.merge(self._num_workers) - def _next_data(self): + def _next_data(self) -> Any: try: data = None while data is None: @@ -222,22 +227,22 @@ class LightningDataLoader(DataLoader): def __init__( self, - dataset, - *args, - sampler=None, - batch_sampler=None, - num_workers=0, + dataset: Any, + *args: Any, + sampler: Optional[Sampler] = None, + batch_sampler: Optional[BatchSampler] = None, + num_workers: int = 0, shuffle: bool = False, - generator=None, - batch_size=None, - drop_last=False, + generator: Optional[torch.Generator] = None, + batch_size: Optional[int] = None, + drop_last: bool = False, cache_dir: Optional[str] = None, chunk_bytes: Optional[int] = _DEFAULT_CHUNK_BYTES, compression: Optional[str] = None, profile: bool = False, collate_fn: Optional[Callable] = None, - **kwargs, - ): + **kwargs: Any, + ) -> None: if sampler: raise ValueError( "The LightningDataLoader relies on its own internal sampler. Passing a sampler isn't supported." @@ -254,21 +259,21 @@ def __init__( if profile and not _VIZ_TRACKER_AVAILABLE: raise ModuleNotFoundError("To enable DataLoader profiling, run `pip install viztracer`.") - cache = [v for v in dataset.__dict__.values() if isinstance(v, Cache)] + cache_list = [v for v in dataset.__dict__.values() if isinstance(v, Cache)] - if len(cache) > 1: + if len(cache_list) > 1: raise ValueError( "We found several Cache used as attributes from your dataset. Only one is support for now." ) - if len(cache) == 0: + if len(cache_list) == 0: if cache_dir is None: raise ValueError("You should provide a `cache_dir` filepath to the LightningDataLoader.") dataset = CacheDataset(dataset, cache_dir, chunk_bytes, batch_size, compression) cache = dataset._cache else: - cache = cache[0] + cache = cache_list[0] cache._setup(num_workers) @@ -296,9 +301,7 @@ def __init__( super().__init__( dataset, *args, - sampler=None, - batch_sampler=batch_sampler, - generator=generator, + batch_sampler=batch_sampler, # type: ignore collate_fn=CacheCollateFn(collate_fn), num_workers=num_workers, **kwargs, diff --git a/src/lightning/data/cache/downloader.py b/src/lightning/data/cache/downloader.py index 21e2abfbc8f7f..7e83d76859c45 100644 --- a/src/lightning/data/cache/downloader.py +++ b/src/lightning/data/cache/downloader.py @@ -26,7 +26,7 @@ def download_chunk_from_index(self, chunk_index: int) -> None: chunk_filename = self._chunks[chunk_index]["filename"] local_chunkpath = os.path.join(self._cache_dir, chunk_filename) remote_chunkpath = os.path.join(self._remote_dir, chunk_filename) - return self.download_file(remote_chunkpath, local_chunkpath) + self.download_file(remote_chunkpath, local_chunkpath) @abstractmethod def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: @@ -35,7 +35,7 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: class S3Downloader(Downloader): @classmethod - def downldownload_fileoad_file_from_s3(cls, remote_filepath: str, local_filepath: str): + def downldownload_fileoad_file_from_s3(cls, remote_filepath: str, local_filepath: str) -> None: import boto3 from boto3.s3.transfer import TransferConfig from botocore.config import Config @@ -45,7 +45,7 @@ def 
downldownload_fileoad_file_from_s3(cls, remote_filepath: str, local_filepath if obj.scheme != "s3": raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}") - extra_args = {} + extra_args: Dict[str, Any] = {} # Create a new session per thread session = boto3.session.Session() diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index cdfc5f9c03f97..7724512cb9de9 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -12,15 +12,14 @@ # limitations under the License. import os -from contextlib import contextmanager from threading import Lock, Thread -from time import sleep, time -from typing import Any, Dict, List, Optional +from time import sleep +from typing import Any, Dict, List, Optional, Tuple import numpy as np from lightning.data.cache.config import ChunksConfig -from lightning.data.cache.pytree import tree_unflatten +from lightning.data.cache.pytree import PyTree, tree_unflatten from lightning.data.cache.sampler import ChunkedIndex from lightning.data.cache.serializers import _SERIALIZERS, Serializer from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv @@ -29,11 +28,11 @@ class PrepareChunksThread(Thread): """This thread is responsible to download the chunks associated to a given worker.""" - def __init__(self, config: ChunksConfig): + def __init__(self, config: ChunksConfig) -> None: super().__init__(daemon=True) self._config = config - self._chunks_index_to_be_processed = [] - self._chunks_index_to_ready = [] + self._chunks_index_to_be_processed: List[int] = [] + self._chunks_index_to_ready: List[int] = [] self._lock = Lock() def add(self, chunk_indices: List[int]) -> None: @@ -41,7 +40,7 @@ def add(self, chunk_indices: List[int]) -> None: with self._lock: self._chunks_index_to_be_processed.extend(chunk_indices) - def run(self): + def run(self) -> None: while True: with self._lock: if len(self._chunks_index_to_be_processed) == 0: @@ -55,7 +54,7 @@ def run(self): class BinaryReader: - def __init__(self, cache_dir: str, remote_dir: Optional[str] = None, compression: Optional[str] = None): + def __init__(self, cache_dir: str, remote_dir: Optional[str] = None, compression: Optional[str] = None) -> None: """The BinaryReader enables to read chunked dataset in an efficient way. 
Arguments: @@ -72,50 +71,41 @@ def __init__(self, cache_dir: str, remote_dir: Optional[str] = None, compression raise FileNotFoundError(f"The provided cache_dir `{self._cache_dir}` doesn't exist.") self._compression = compression - self._config = None - self._intervals = None + self._intervals: Optional[List[str]] = None - self._chunks_data = {} self._serializers: Dict[str, Serializer] = _SERIALIZERS - self._distributed_env = _DistributedEnv.detect() - self._rank = None + self._rank: Optional[int] = None self._config: Optional[ChunksConfig] = None - self._latest_chunk_index = None - self._executor = None - self._prepare_thread = None + self._prepare_thread: Optional[PrepareChunksThread] = None - def _get_chunk_index_from_index(self, index: int): + def _get_chunk_index_from_index(self, index: int) -> int: # Load the config containing the index - if self._config is None: - self._try_load_config() - - if self._config is None: - raise Exception("The reader index isn't defined.") + if self._config is None and self._try_load_config() is None: + raise Exception("The reader index isn't defined.") - return self._config._get_chunk_index_from_index(index) + return self._config._get_chunk_index_from_index(index) # type: ignore - def _try_load_config(self): + def _try_load_config(self) -> Optional[ChunksConfig]: """Try to load the chunks config if the index files are available.""" self._config = ChunksConfig.load(self._cache_dir, self._remote_dir) return self._config @property - def rank(self): + def config(self) -> ChunksConfig: + if self._config is None: + raise RuntimeError("The config should be defined.") + return self._config + + @property + def rank(self) -> int: """Returns the rank of the writer.""" if self._rank is None: self._worker_env = _WorkerEnv.detect() self._rank = self._distributed_env.global_rank * self._worker_env.world_size + self._worker_env.rank return self._rank - @contextmanager - def measure_on_rank_0(self, msg: str): - if self.rank == 0: - t0 = time() - yield - print(msg, time() - t0) - - def read(self, index: ChunkedIndex): + def read(self, index: ChunkedIndex) -> Any: """Read an item for the given from a chunk. If the chunk isn't available locally or in memory, it will be downloaded. 
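The reader resolves a flat sample index to a chunk through the `interval` entries recorded in the index: each chunk covers a half-open range `[start, end)` of sample indices (written as `[self._indexes[0], self._indexes[-1] + 1]` by the writer earlier in this series), and `ChunksConfig._get_chunk_index_from_index` walks those ranges. The helper below is only a simplified illustration of that lookup, not the actual implementation:

    from typing import List, Tuple


    def find_chunk_index(intervals: List[Tuple[int, int]], index: int) -> int:
        # Each chunk covers a half-open interval [start, end) of sample indices.
        for chunk_index, (start, end) in enumerate(intervals):
            if start <= index < end:
                return chunk_index
        raise ValueError(f"The index {index} doesn't belong to any chunk interval.")


    # Example: three chunks holding 4 samples each.
    assert find_chunk_index([(0, 4), (4, 8), (8, 12)], 5) == 1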
@@ -131,29 +121,29 @@ def read(self, index: ChunkedIndex): raise Exception("The reader index isn't defined.") # Create and start the prepare chunks thread - if index.chunk_indexes is not None and self._prepare_thread is None: + if index.chunk_indexes is not None and self._prepare_thread is None and self._config: self._prepare_thread = PrepareChunksThread(self._config) self._prepare_thread.start() self._prepare_thread.add(index.chunk_indexes) # Fetch the element - chunk_filepath, begin, _ = self._config[index] + chunk_filepath, begin, _ = self.config[index] raw_item_data = self.load_item_from_chunk(index.index, chunk_filepath, begin) return self.deserialize(raw_item_data) - def deserialize(self, raw_item_data: bytes) -> Any: + def deserialize(self, raw_item_data: bytes) -> PyTree: """Deserialize the raw bytes into their python equivalent.""" - idx = len(self._config.data_format) * 4 + idx = len(self.config.data_format) * 4 sizes = np.frombuffer(raw_item_data[:idx], np.uint32) data = [] - for size, data_format in zip(sizes, self._config.data_format): + for size, data_format in zip(sizes, self.config.data_format): serializer = self._serializers[data_format] data_bytes = raw_item_data[idx : idx + size] data.append(serializer.deserialize(data_bytes)) idx += size - return tree_unflatten(data, self._config.config["data_spec"]) + return tree_unflatten(data, self.config.config["data_spec"]) - def load_item_from_chunk(self, index: int, chunk_filepath: str, begin: int): + def load_item_from_chunk(self, index: int, chunk_filepath: str, begin: int) -> bytes: offset = (1 + (index - begin)) * 4 while not os.path.exists(chunk_filepath): @@ -172,11 +162,11 @@ def get_length(self) -> int: if self._config is None and self._try_load_config() is None: raise Exception("The reader index isn't defined.") - return len(self._config) + return len(self.config) - def get_chunk_interval(self): + def get_chunk_interval(self) -> List[Tuple[int, int]]: """Get the index interval of each chunk.""" if self._config is None and self._try_load_config() is None: raise Exception("The reader index isn't defined.") - return self._config.intervals + return self.config.intervals diff --git a/src/lightning/data/cache/sampler.py b/src/lightning/data/cache/sampler.py index d65acfe81bb7c..fe88a2cf2c316 100644 --- a/src/lightning/data/cache/sampler.py +++ b/src/lightning/data/cache/sampler.py @@ -13,7 +13,7 @@ import logging from dataclasses import dataclass -from typing import Any, Iterator, List, Optional +from typing import Any, Dict, Iterator, List, Optional, Union import numpy as np @@ -37,7 +37,7 @@ def __init__( batch_size: int, drop_last: bool, shuffle: bool, - cache: any, + cache: Any, ): """The CacheBatchSampler handles the generation of batch indices. @@ -64,13 +64,18 @@ def __init__( self._shuffled_chunk_intervals = None self._batch_size = batch_size + self._drop_last = drop_last + self._length = 0 + # Before starting, ensures the chunk indices are properly defined. self._validate() - def _validate(self): + def _validate(self) -> None: + """Checks each worker is getting sucessive indices.""" if self._num_workers > 1 and not self._cache.filled: - batches = {} + batches: Dict[int, Any] = {} for batch_index, batch_indices in enumerate(self): + self._length += 1 worker_index = batch_index % self._num_workers if worker_index not in batches: batches[worker_index] = [] @@ -84,25 +89,7 @@ def _validate(self): if diff.sum() != 0: raise RuntimeError("This shouldn't have happened. 
There is a bug in the CacheSampler.") - def __iter_ordered__(self) -> Iterator[List[int]]: - # Implemented based on the benchmarking in https://github.com/pytorch/pytorch/pull/76951 - iterator = iter(self.sampler) - batch = [] - while not self.sampler.done: - try: - idx = next(iterator) - batch.append(idx) - if len(batch) == self.batch_size: - yield batch - batch = [] - except StopIteration: - if self.sampler.done: - yield batch - return - yield batch - batch = [] - - def __iter__(self): + def __iter__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: # When the cache is filled, we need to iterate though the chunks if self._cache.filled: if self._num_replicas == 1: @@ -114,9 +101,8 @@ def __iter__(self): return self.__iter_non_distributed__() return self.__iter_distributed__() - def __iter_non_distributed__(self): + def __iter_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: worker_size = self._dataset_size // self._num_workers - self.samplers = [] indices = list(range(self._dataset_size)) worker_indices = [] for worker_idx in range(self._num_workers): @@ -131,11 +117,10 @@ def __iter_non_distributed__(self): yield from self.__iter_indices_per_workers__(worker_indices_batches) - def __iter_distributed__(self): + def __iter_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: self.indices = list(range(self._dataset_size)) replica_size = self._dataset_size // self._num_replicas worker_size = self._dataset_size // (self._num_replicas * self._num_workers) - self.samplers = [] for rank in range(self._num_replicas): if rank != self._global_rank: continue @@ -160,13 +145,13 @@ def __iter_distributed__(self): yield from self.__iter_indices_per_workers__(worker_indices_batches) - def __iter_from_chunks_non_distributed__(self): + def __iter_from_chunks_non_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: chunk_intervals = self._cache.get_chunk_interval() shuffled_indexes = np.random.permutation(range(len(chunk_intervals))) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] - yield from self.__iter_from_shuffled_chunks(shuffled_indexes, shuffled_chunk_intervals) + yield from self.__iter_from_shuffled_chunks(shuffled_indexes.tolist(), shuffled_chunk_intervals) - def __iter_from_chunks_distributed__(self): + def __iter_from_chunks_distributed__(self) -> Iterator[List[Union[int, ChunkedIndex]]]: chunk_intervals = self._cache.get_chunk_interval() shuffled_indexes = np.random.permutation(range(len(chunk_intervals))) shuffled_chunk_intervals = np.asarray(chunk_intervals)[shuffled_indexes] @@ -180,12 +165,14 @@ def __iter_from_chunks_distributed__(self): yield from self.__iter_from_shuffled_chunks(replica_chunks, replica_intervals) - def __iter_from_shuffled_chunks(self, shuffled_indexes, shuffled_chunk_intervals): - chunks_per_workers = [[] for _ in range(self._num_workers)] + def __iter_from_shuffled_chunks( + self, shuffled_indexes: List[int], shuffled_chunk_intervals: List[List[int]] + ) -> Iterator[List[Union[int, ChunkedIndex]]]: + chunks_per_workers: List[List[int]] = [[] for _ in range(self._num_workers)] for i, chunk_index in enumerate(shuffled_indexes): chunks_per_workers[i % self._num_workers].append(chunk_index) - indices_per_workers = [[] for _ in range(self._num_workers)] + indices_per_workers: List[List[ChunkedIndex]] = [[] for _ in range(self._num_workers)] for i, (chunk_index, chunk_interval) in enumerate(zip(shuffled_indexes, shuffled_chunk_intervals)): worker_id = i % self._num_workers @@ -208,10 +195,12 @@ def 
__iter_from_shuffled_chunks(self, shuffled_indexes, shuffled_chunk_intervals yield from self.__iter_indices_per_workers__(indices_per_workers_splitted) def __len__(self) -> int: - return super().__len__() + return self._length - def __iter_indices_per_workers__(self, indices_per_workers): - batches = [] + def __iter_indices_per_workers__( + self, indices_per_workers: List[List[List[Union[int, ChunkedIndex]]]] + ) -> Iterator[List[Union[int, ChunkedIndex]]]: + batches: List[List[Union[int, ChunkedIndex]]] = [] counter = 0 while sum([len(v) for v in indices_per_workers]) != 0: worker_indices = indices_per_workers[counter % self._num_workers] diff --git a/src/lightning/data/cache/serializers.py b/src/lightning/data/cache/serializers.py index 47b3c87981c98..f85027201f3bc 100644 --- a/src/lightning/data/cache/serializers.py +++ b/src/lightning/data/cache/serializers.py @@ -15,6 +15,7 @@ from abc import ABC, abstractmethod from collections import OrderedDict from io import BytesIO +from typing import Any import numpy as np import torch @@ -38,15 +39,15 @@ class Serializer(ABC): """ @abstractmethod - def serialize(self, data: any) -> bytes: + def serialize(self, data: Any) -> bytes: pass @abstractmethod - def deserialize(self, data: bytes) -> any: + def deserialize(self, data: bytes) -> Any: pass @abstractmethod - def can_serialize(self, data: any) -> bool: + def can_serialize(self, data: Any) -> bool: pass @@ -60,7 +61,7 @@ def serialize(self, item: Image) -> bytes: ints = np.array([width, height, len(mode)], np.uint32) return ints.tobytes() + mode + raw - def deserialize(self, data: bytes) -> any: + def deserialize(self, data: bytes) -> Any: idx = 3 * 4 width, height, mode_size = np.frombuffer(data[:idx], np.uint32) idx2 = idx + mode_size @@ -69,7 +70,7 @@ def deserialize(self, data: bytes) -> any: raw = data[idx2:] return Image.frombytes(mode, size, raw) # pyright: ignore - def can_serialize(self, item) -> bool: + def can_serialize(self, item: Any) -> bool: return isinstance(item, Image.Image) and not isinstance(item, JpegImageFile) @@ -82,7 +83,7 @@ def serialize(self, item: int) -> bytes: def deserialize(self, data: bytes) -> int: return int(data.decode("utf-8")) - def can_serialize(self, item) -> bool: + def can_serialize(self, item: Any) -> bool: return isinstance(item, int) @@ -103,7 +104,7 @@ def deserialize(self, data: bytes) -> Image: inp = BytesIO(data) return Image.open(inp) - def can_serialize(self, item) -> bool: + def can_serialize(self, item: Any) -> bool: return isinstance(item, JpegImageFile) @@ -147,7 +148,7 @@ def can_serialize(self, item: bytes) -> bool: class TensorSerializer(Serializer): """The TensorSerializer serialize and deserialize tensor to and from bytes.""" - def __init__(self): + def __init__(self) -> None: super().__init__() self._dtype_to_indice = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()} @@ -177,14 +178,14 @@ def can_serialize(self, item: torch.Tensor) -> bool: class PickleSerializer(Serializer): """The PickleSerializer serialize and deserialize python objects to and from bytes.""" - def serialize(self, item: any) -> bytes: + def serialize(self, item: Any) -> bytes: return pickle.dumps(item) - def deserialize(self, data: bytes) -> any: + def deserialize(self, data: bytes) -> Any: return pickle.loads(data) - def can_serialize(self, item: any) -> bool: - return isinstance(item, any) + def can_serialize(self, _: Any) -> bool: + return True _SERIALIZERS = OrderedDict( diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 
bfd40708b6d7d..879e7cbac16e3 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -20,7 +20,7 @@ from lightning.data.cache.compression import _COMPRESSORS, Compressor from lightning.data.cache.constants import INDEX_FILENAME -from lightning.data.cache.pytree import tree_flatten, treespec_dumps +from lightning.data.cache.pytree import PyTree, tree_flatten, treespec_dumps from lightning.data.cache.serializers import _SERIALIZERS, Serializer from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv @@ -55,9 +55,8 @@ def __init__( self._chunk_bytes = chunk_bytes self._compression = compression - self._data_format = None - self._data_spec = None - self._num_workers = None + self._data_format: Optional[List[str]] = None + self._data_spec: Optional[PyTree] = None if self._compression: if len(_COMPRESSORS) == 0: @@ -71,10 +70,10 @@ def __init__( self._current_chunk_bytes = 0 self._chunk_index = 0 self._serialized_items: List[bytes] = [] - self._chunks_info = [] + self._chunks_info: List[Dict[str, Any]] = [] self._indexes: List[int] = [] - self._worker_env = None - self._rank = None + self._worker_env: Optional[_WorkerEnv] = None + self._rank: Optional[int] = None self._is_done = False self._distributed_env = _DistributedEnv.detect() @@ -90,7 +89,7 @@ def filled(self) -> bool: return self._is_done @property - def rank(self): + def rank(self) -> int: """Returns the rank of the writer.""" if self._rank is None: self._worker_env = _WorkerEnv.detect() @@ -115,10 +114,10 @@ def serialize(self, items: Any) -> bytes: flattened, data_spec = tree_flatten(items) # Collect the sizes and associated bytes for each item - sizes = [] - data = [] + sizes: List[int] = [] + data: List[bytes] = [] - data_format = [] + data_format: List[str] = [] for item in flattened: data_format.append(self._serialize(item, sizes, data)) @@ -139,7 +138,7 @@ def serialize(self, items: Any) -> bytes: body = b"".join(data) return head + body - def _serialize(self, item: Any, sizes: List[int], data: List[bytes]) -> None: + def _serialize(self, item: Any, sizes: List[int], data: List[bytes]) -> str: """Serialize a given item and append its size and bytes to the sizes and data array.""" for serializer_name, serializer in self._serializers.items(): if serializer.can_serialize(item): @@ -243,14 +242,14 @@ def write_chunk_to_file( with open(os.path.join(self._cache_dir, filename), "wb") as out: out.write(raw_data) - def write_chunks_index(self): + def write_chunks_index(self) -> None: """Write the chunks index to a JSON file.""" filepath = os.path.join(self._cache_dir, f"{self.rank}.{INDEX_FILENAME}") config = self.get_config() with open(filepath, "w") as out: json.dump({"chunks": self._chunks_info, "config": config}, out, sort_keys=True) - def done(self): + def done(self) -> None: """Called when StopIteration is triggered.""" if self.filled: return From c6db7d9a776ccfbcbe28414727be5d0f615016ad Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 15:17:00 +0100 Subject: [PATCH 65/84] update --- src/lightning/data/cache/dataloader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index bb9f527ea24c9..fbfadf6b2b842 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -275,6 +275,7 @@ def __init__( else: cache = cache_list[0] + # This is required in the main thread. 
cache._setup(num_workers) if not cache.filled and shuffle: From 5fdd0a1600e9e2e9dcc8dd0fdcb294d3206e6e13 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 15:17:25 +0100 Subject: [PATCH 66/84] update --- src/lightning/data/cache/cache.py | 7 ------- src/lightning/data/cache/dataloader.py | 3 --- 2 files changed, 10 deletions(-) diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 83fbe43b854ec..964dcd497f9f2 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -50,17 +50,10 @@ def __init__( self._cache_dir = cache_dir self._is_done = False self._distributed_env = _DistributedEnv.detect() - self._num_workers: Optional[int] = None - - def _setup(self, num_workers: int) -> None: - """Called by the LightningDataLoader to ensure the num_workers is known.""" - self._num_workers = num_workers @property def filled(self) -> bool: """Returns whether the caching phase is done.""" - if self._num_workers is None: - raise Exception("The Cache wasn't setup properly. HINT: Did you use the LightningDataLoader ?") if self._is_done: return True self._is_done = os.path.exists(os.path.join(self._cache_dir, INDEX_FILENAME)) diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index fbfadf6b2b842..e2acfa1d39d24 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -275,9 +275,6 @@ def __init__( else: cache = cache_list[0] - # This is required in the main thread. - cache._setup(num_workers) - if not cache.filled and shuffle: logger.info("Shuffle is ignored during the caching phase phase") From 3c63d30168743f671f0bfa2d4f532be3d0b58158 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 15:17:30 +0100 Subject: [PATCH 67/84] update --- src/lightning/data/cache/dataloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index e2acfa1d39d24..a6a31d1df78e2 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -276,7 +276,7 @@ def __init__( cache = cache_list[0] if not cache.filled and shuffle: - logger.info("Shuffle is ignored during the caching phase phase") + logger.info("Shuffle is ignored during the caching phase phase.") self._cache = cache From 6ca0b5bfa0d41ad6b396fc64001c483bef4da6c4 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 15:24:15 +0100 Subject: [PATCH 68/84] update --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 43ef2fc0195f7..b528dfa8488d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,6 +138,7 @@ exclude = [ "src/lightning/app/cli/pl-app-template", "src/lightning/app/cli/react-ui-template", "src/lightning/app/launcher", + "src/lightning/data/cache/pytree" ] install_types = "True" non_interactive = "True" From 0bdacbb894909b72f95ed15f8e76a028a9d4c906 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 15:25:20 +0100 Subject: [PATCH 69/84] update --- src/lightning/data/cache/reader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 7724512cb9de9..047cf135e603c 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -46,9 +46,9 @@ def run(self) -> None: if len(self._chunks_index_to_be_processed) == 0: sleep(0.007) continue - chunk_index = 
self._chunks_index_to_be_processed.pop(0) - # TODO: Implement eviction + # TODO: Implement chunk eviction + chunk_index = self._chunks_index_to_be_processed.pop(0) self._config.download_chunk_from_index(chunk_index) self._chunks_index_to_ready.append(chunk_index) From 91350047d5c6f99f0582fcf2bd074983962305e8 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 15:25:37 +0100 Subject: [PATCH 70/84] update --- src/lightning/data/cache/reader.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 047cf135e603c..e4f713b939d8b 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -47,8 +47,9 @@ def run(self) -> None: sleep(0.007) continue - # TODO: Implement chunk eviction - chunk_index = self._chunks_index_to_be_processed.pop(0) + chunk_index = self._chunks_index_to_be_processed.pop(0) + + # TODO: Implement eviction self._config.download_chunk_from_index(chunk_index) self._chunks_index_to_ready.append(chunk_index) From e356838f2f15d2b007759f83c0d4f7aef257c8f8 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 15:33:48 +0100 Subject: [PATCH 71/84] update --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b528dfa8488d0..49657b3c4a7a4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,7 +138,7 @@ exclude = [ "src/lightning/app/cli/pl-app-template", "src/lightning/app/cli/react-ui-template", "src/lightning/app/launcher", - "src/lightning/data/cache/pytree" + "src/lightning/data/cache/pytree.py", ] install_types = "True" non_interactive = "True" From c4ddeb8d47dbcf90b127417470ad6634e467be0f Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 15:37:53 +0100 Subject: [PATCH 72/84] update --- src/lightning/data/cache/pytree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/data/cache/pytree.py b/src/lightning/data/cache/pytree.py index 06b6fd44534f9..3e2692aa6156f 100644 --- a/src/lightning/data/cache/pytree.py +++ b/src/lightning/data/cache/pytree.py @@ -186,7 +186,7 @@ def _is_namedtuple_instance(pytree: Any) -> bool: fields = getattr(typ, "_fields", None) if not isinstance(fields, tuple): return False - return all(type(entry) == str for entry in fields) + return all(isinstance(entry, str) for entry in fields) def _get_node_type(pytree: Any) -> Any: From da985f245061cf156f6aca96961e838dcc48d2d8 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 17:50:29 +0100 Subject: [PATCH 73/84] update --- src/lightning/data/cache/writer.py | 4 +++- tests/tests_data/cache/test_writer.py | 9 ++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 879e7cbac16e3..02e6f8159d7a4 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -259,7 +259,7 @@ def done(self) -> None: self.reset() self._is_done = True - def merge(self, num_workers: int) -> None: + def merge(self, num_workers: int = 1) -> None: """Once all the workers have written their own index, the merge function is responsible to read and merge them into a single index.""" num_workers = num_workers or 1 @@ -296,6 +296,8 @@ def merge(self, num_workers: int) -> None: chunks_info.extend(data["chunks"]) + os.remove(chunk_path) + # Write down the collected index with open(os.path.join(self._cache_dir, INDEX_FILENAME), "w") as f: json.dump({"chunks": chunks_info, 
"config": config}, f, sort_keys=True) diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index 57d541a1d3166..67d14db450b04 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -38,9 +38,10 @@ def test_binary_writer_with_ints_and_chunk_bytes(tmpdir): assert len(os.listdir(tmpdir)) == 19 binary_writer.done() + binary_writer.merge() assert len(os.listdir(tmpdir)) == 21 - with open(os.path.join(tmpdir, "0.index.json")) as f: + with open(os.path.join(tmpdir, "index.json")) as f: data = json.load(f) assert data["chunks"][0]["chunk_size"] == 6 @@ -73,9 +74,10 @@ def test_binary_writer_with_ints_and_chunk_size(tmpdir): assert len(os.listdir(tmpdir)) == 3 binary_writer.done() + binary_writer.merge() assert len(os.listdir(tmpdir)) == 5 - with open(os.path.join(tmpdir, "0.index.json")) as f: + with open(os.path.join(tmpdir, "index.json")) as f: data = json.load(f) assert data["chunks"][0]["chunk_size"] == 25 @@ -110,9 +112,10 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): assert len(os.listdir(cache_dir)) == 24 binary_writer.done() + binary_writer.merge() assert len(os.listdir(cache_dir)) == 26 - with open(os.path.join(cache_dir, "0.index.json")) as f: + with open(os.path.join(cache_dir, "index.json")) as f: data = json.load(f) assert data["chunks"][0]["chunk_size"] == 4 From 0e5342c76a6b1f19cba272fc2c3a6e28925c5708 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 6 Oct 2023 18:47:17 +0100 Subject: [PATCH 74/84] update --- src/lightning/data/cache/downloader.py | 2 +- tests/tests_data/cache/test_serializer.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning/data/cache/downloader.py b/src/lightning/data/cache/downloader.py index 7e83d76859c45..9c4a91188155a 100644 --- a/src/lightning/data/cache/downloader.py +++ b/src/lightning/data/cache/downloader.py @@ -35,7 +35,7 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: class S3Downloader(Downloader): @classmethod - def downldownload_fileoad_file_from_s3(cls, remote_filepath: str, local_filepath: str) -> None: + def download_file(cls, remote_filepath: str, local_filepath: str) -> None: import boto3 from boto3.s3.transfer import TransferConfig from botocore.config import Config diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py index ea4f447a6aeae..d692f964874a4 100644 --- a/tests/tests_data/cache/test_serializer.py +++ b/tests/tests_data/cache/test_serializer.py @@ -12,6 +12,7 @@ # limitations under the License. 
import os +import sys from time import time import numpy as np @@ -104,6 +105,7 @@ def test_pil_serializer(mode): assert np.array_equal(np_data, np_dec_data) +@pytest.mark.skipif(sys.platform == "win32", reason="Not supported on windows") def test_tensor_serializer(): seed_everything(42) From 3aed4c5c2d83d7a5c2e15579b94611cf9e5df01d Mon Sep 17 00:00:00 2001 From: thomas Date: Sat, 7 Oct 2023 15:48:17 +0100 Subject: [PATCH 75/84] update --- src/lightning/data/cache/serializers.py | 12 +++++-- tests/tests_data/cache/test_serializer.py | 39 ----------------------- 2 files changed, 10 insertions(+), 41 deletions(-) diff --git a/src/lightning/data/cache/serializers.py b/src/lightning/data/cache/serializers.py index f85027201f3bc..d9492c7fac138 100644 --- a/src/lightning/data/cache/serializers.py +++ b/src/lightning/data/cache/serializers.py @@ -15,13 +15,14 @@ from abc import ABC, abstractmethod from collections import OrderedDict from io import BytesIO -from typing import Any +from typing import Any, Union import numpy as np import torch from lightning_utilities.core.imports import RequirementCache _PIL_AVAILABLE = RequirementCache("PIL") +_TORCH_VISION_AVAILABLE = RequirementCache("torchvision") if _PIL_AVAILABLE: from PIL import Image @@ -30,6 +31,9 @@ Image = None JpegImageFile = None +if _TORCH_VISION_AVAILABLE: + from torchvision.io import decode_jpeg + class Serializer(ABC): """The base interface for any serializers. @@ -100,7 +104,11 @@ def serialize(self, item: Image) -> bytes: return f.read() raise TypeError(f"The provided itemect should be of type {JpegImageFile}. Found {item}.") - def deserialize(self, data: bytes) -> Image: + def deserialize(self, data: bytes) -> Union[JpegImageFile, torch.Tensor]: + if _TORCH_VISION_AVAILABLE: + array = torch.frombuffer(data, dtype=torch.uint8) + return decode_jpeg(array) + inp = BytesIO(data) return Image.open(inp) diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py index d692f964874a4..6ea9f4402d2b9 100644 --- a/tests/tests_data/cache/test_serializer.py +++ b/tests/tests_data/cache/test_serializer.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
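The serializer changes and tests around this point all rely on the same round-trip contract: `can_serialize` selects a serializer for an item, `serialize` turns it into bytes, and `deserialize` restores it. A hypothetical user-defined serializer following that contract could look like the sketch below; it uses the plain-bytes `serialize` signature as it stands at this point in the series (a later commit extends it to also return an optional format name) and is not part of any commit:

    import json
    from typing import Any

    from lightning.data.cache.serializers import Serializer


    class JSONDictSerializer(Serializer):
        """Hypothetical serializer for dictionaries of JSON-compatible values."""

        def serialize(self, item: dict) -> bytes:
            return json.dumps(item, sort_keys=True).encode("utf-8")

        def deserialize(self, data: bytes) -> Any:
            return json.loads(data.decode("utf-8"))

        def can_serialize(self, item: Any) -> bool:
            return isinstance(item, dict)


    serializer = JSONDictSerializer()
    data = serializer.serialize({"a": 1})
    assert serializer.deserialize(data) == {"a": 1}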
-import os import sys from time import time @@ -23,7 +22,6 @@ _SERIALIZERS, _TORCH_DTYPES_MAPPING, IntSerializer, - JPEGSerializer, PickleSerializer, PILSerializer, TensorSerializer, @@ -46,43 +44,6 @@ def test_int_serializer(): assert i == serializer.deserialize(data) -@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") -@pytest.mark.parametrize("mode", ["L", "RGB"]) -def test_jpeg_serializer(mode, tmpdir): - serializer = JPEGSerializer() - - from PIL import Image - - path = os.path.join(tmpdir, "img.jpeg") - - size = {"RGB": (28, 28, 3), "L": (28, 28)}[mode] - np_data = np.random.randint(255, size=size, dtype=np.uint8) - img = Image.fromarray(np_data).convert(mode) - - np.testing.assert_array_equal(np_data, np.array(img)) - - with pytest.raises(TypeError, match="PIL.JpegImagePlugin.JpegImageFile"): - serializer.serialize(img) - - # from the JPEG image directly - img.save(path, format="jpeg", quality=100) - img = Image.open(path) - - data = serializer.serialize(img) - assert isinstance(data, bytes) - deserialized_img = np.asarray(serializer.deserialize(data)) - assert np.array_equal(np.asarray(img), np.array(deserialized_img)) - - # read bytes from the file - with open(path, "rb") as f: - data = f.read() - - assert isinstance(data, bytes) - deserialized_img = np.asarray(serializer.deserialize(data)) - - assert np.array_equal(np.asarray(img), np.array(deserialized_img)) - - @pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") @pytest.mark.parametrize("mode", ["I", "L", "RGB"]) def test_pil_serializer(mode): From 1f41d60a59b8d83fd2b4dcbdeedd407a94ec022a Mon Sep 17 00:00:00 2001 From: thomas Date: Sat, 7 Oct 2023 19:12:38 +0100 Subject: [PATCH 76/84] update --- tests/tests_data/cache/test_writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index 67d14db450b04..202a2f771665d 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -125,5 +125,5 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): reader = BinaryReader(cache_dir) for i in range(100): data = reader.read(ChunkedIndex(i, chunk_index=i // 4)) - assert data["x"] == imgs[i] + np.testing.assert_array_equal(np.asarray(data["x"]).squeeze(0), imgs[i]) assert data["y"] == i From bab1f1c02d5d754bcc081f7120e9f999fc5a93a1 Mon Sep 17 00:00:00 2001 From: thomas Date: Sat, 7 Oct 2023 19:30:53 +0100 Subject: [PATCH 77/84] update --- src/lightning/data/cache/dataloader.py | 2 +- src/lightning/data/cache/serializers.py | 43 +++++++++++++++-------- src/lightning/data/cache/writer.py | 4 +-- tests/tests_data/cache/test_serializer.py | 10 +++--- tests/tests_data/cache/test_writer.py | 39 ++++++++++++++++++++ 5 files changed, 76 insertions(+), 22 deletions(-) diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index a6a31d1df78e2..da50657d948b5 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -128,7 +128,7 @@ def __call__(self, items: List[Any]) -> Any: asyncio.set_event_loop(loop) items = loop.run_until_complete(asyncio.gather(*items)) - return self.collate_fn(items) + return self.collate_fn([item for item in items if item is not None]) class _SingleProcessDataLoaderIterPatch(_SingleProcessDataLoaderIter): diff --git a/src/lightning/data/cache/serializers.py b/src/lightning/data/cache/serializers.py index d9492c7fac138..7a2dae4b5ffec 100644 --- 
a/src/lightning/data/cache/serializers.py +++ b/src/lightning/data/cache/serializers.py @@ -11,11 +11,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import pickle from abc import ABC, abstractmethod from collections import OrderedDict from io import BytesIO -from typing import Any, Union +from typing import Any, Optional, Tuple, Union import numpy as np import torch @@ -43,7 +44,7 @@ class Serializer(ABC): """ @abstractmethod - def serialize(self, data: Any) -> bytes: + def serialize(self, data: Any) -> Tuple[bytes, Optional[str]]: pass @abstractmethod @@ -58,12 +59,12 @@ def can_serialize(self, data: Any) -> bool: class PILSerializer(Serializer): """The PILSerializer serialize and deserialize PIL Image to and from bytes.""" - def serialize(self, item: Image) -> bytes: + def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]: mode = item.mode.encode("utf-8") width, height = item.size raw = item.tobytes() ints = np.array([width, height, len(mode)], np.uint32) - return ints.tobytes() + mode + raw + return ints.tobytes() + mode + raw, None def deserialize(self, data: bytes) -> Any: idx = 3 * 4 @@ -81,8 +82,8 @@ def can_serialize(self, item: Any) -> bool: class IntSerializer(Serializer): """The IntSerializer serialize and deserialize integer to and from bytes.""" - def serialize(self, item: int) -> bytes: - return str(item).encode("utf-8") + def serialize(self, item: int) -> Tuple[bytes, Optional[str]]: + return str(item).encode("utf-8"), None def deserialize(self, data: bytes) -> int: return int(data.decode("utf-8")) @@ -94,14 +95,14 @@ def can_serialize(self, item: Any) -> bool: class JPEGSerializer(Serializer): """The JPEGSerializer serialize and deserialize JPEG image to and from bytes.""" - def serialize(self, item: Image) -> bytes: + def serialize(self, item: Image) -> Tuple[bytes, Optional[str]]: if isinstance(item, JpegImageFile): if not hasattr(item, "filename"): raise ValueError( "The JPEG Image's filename isn't defined. HINT: Open the image in your Dataset __getitem__ method." ) with open(item.filename, "rb") as f: - return f.read() + return f.read(), None raise TypeError(f"The provided itemect should be of type {JpegImageFile}. 
Found {item}.") def deserialize(self, data: bytes) -> Union[JpegImageFile, torch.Tensor]: @@ -119,8 +120,8 @@ def can_serialize(self, item: Any) -> bool: class BytesSerializer(Serializer): """The BytesSerializer serialize and deserialize integer to and from bytes.""" - def serialize(self, item: bytes) -> bytes: - return item + def serialize(self, item: bytes) -> Tuple[bytes, Optional[str]]: + return item, None def deserialize(self, item: bytes) -> bytes: return item @@ -160,14 +161,14 @@ def __init__(self) -> None: super().__init__() self._dtype_to_indice = {v: k for k, v in _TORCH_DTYPES_MAPPING.items()} - def serialize(self, item: torch.Tensor) -> bytes: + def serialize(self, item: torch.Tensor) -> Tuple[bytes, Optional[str]]: dtype_indice = self._dtype_to_indice[item.dtype] data = [np.uint32(dtype_indice).tobytes()] data.append(np.uint32(len(item.shape)).tobytes()) for dim in item.shape: data.append(np.uint32(dim).tobytes()) data.append(item.numpy().tobytes()) - return b"".join(data) + return b"".join(data), None def deserialize(self, data: bytes) -> torch.Tensor: dtype_indice = np.frombuffer(data[0:4], np.uint32).item() @@ -186,8 +187,8 @@ def can_serialize(self, item: torch.Tensor) -> bool: class PickleSerializer(Serializer): """The PickleSerializer serialize and deserialize python objects to and from bytes.""" - def serialize(self, item: Any) -> bytes: - return pickle.dumps(item) + def serialize(self, item: Any) -> Tuple[bytes, Optional[str]]: + return pickle.dumps(item), None def deserialize(self, data: bytes) -> Any: return pickle.loads(data) @@ -196,8 +197,22 @@ def can_serialize(self, _: Any) -> bool: return True +class FileSerializer(Serializer): + def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]: + _, file_extension = os.path.splitext(filepath) + with open(filepath, "rb") as f: + return f.read(), file_extension.replace(".", "") + + def deserialize(self, data: bytes) -> Any: + pass + + def can_serialize(self, data: Any) -> bool: + return isinstance(data, str) and os.path.exists(data) + + _SERIALIZERS = OrderedDict( **{ + "file": FileSerializer(), "pil": PILSerializer(), "int": IntSerializer(), "jpeg": JPEGSerializer(), diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index 02e6f8159d7a4..d71f9ff2afb63 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -142,10 +142,10 @@ def _serialize(self, item: Any, sizes: List[int], data: List[bytes]) -> str: """Serialize a given item and append its size and bytes to the sizes and data array.""" for serializer_name, serializer in self._serializers.items(): if serializer.can_serialize(item): - serialized_item = serializer.serialize(item) + serialized_item, name = serializer.serialize(item) data.append(serialized_item) sizes.append(len(serialized_item)) - return serializer_name + return name or serializer_name raise ValueError(f"The provided item isn't serializable. 
Found {item}") def _create_chunk(self, filename: str) -> bytes: diff --git a/tests/tests_data/cache/test_serializer.py b/tests/tests_data/cache/test_serializer.py index 6ea9f4402d2b9..cc4cefbead9d2 100644 --- a/tests/tests_data/cache/test_serializer.py +++ b/tests/tests_data/cache/test_serializer.py @@ -32,14 +32,14 @@ def test_serializers(): - assert list(_SERIALIZERS.keys()) == ["pil", "int", "jpeg", "bytes", "tensor", "pickle"] + assert list(_SERIALIZERS.keys()) == ["file", "pil", "int", "jpeg", "bytes", "tensor", "pickle"] def test_int_serializer(): serializer = IntSerializer() for i in range(100): - data = serializer.serialize(i) + data, _ = serializer.serialize(i) assert isinstance(data, bytes) assert i == serializer.deserialize(data) @@ -54,7 +54,7 @@ def test_pil_serializer(mode): np_data = np.random.randint(255, size=(28, 28), dtype=np.uint32) img = Image.fromarray(np_data).convert(mode) - data = serializer.serialize(img) + data, _ = serializer.serialize(img) assert isinstance(data, bytes) deserialized_img = serializer.deserialize(data) @@ -84,7 +84,7 @@ def test_tensor_serializer(): tensor = torch.ones(shape, dtype=dtype) t0 = time() - data = serializer_tensor.serialize(tensor) + data, _ = serializer_tensor.serialize(tensor) deserialized_tensor = serializer_tensor.deserialize(data) tensor_time = time() - t0 tensor_bytes = len(data) @@ -93,7 +93,7 @@ def test_tensor_serializer(): assert torch.equal(tensor, deserialized_tensor) t1 = time() - data = serializer_pickle.serialize(tensor) + data, _ = serializer_pickle.serialize(tensor) deserialized_tensor = serializer_pickle.deserialize(data) pickle_time = time() - t1 pickle_bytes = len(data) diff --git a/tests/tests_data/cache/test_writer.py b/tests/tests_data/cache/test_writer.py index 202a2f771665d..331be0b5fa112 100644 --- a/tests/tests_data/cache/test_writer.py +++ b/tests/tests_data/cache/test_writer.py @@ -127,3 +127,42 @@ def test_binary_writer_with_jpeg_and_int(tmpdir): data = reader.read(ChunkedIndex(i, chunk_index=i // 4)) np.testing.assert_array_equal(np.asarray(data["x"]).squeeze(0), imgs[i]) assert data["y"] == i + + +@pytest.mark.skipif(condition=not _PIL_AVAILABLE, reason="Requires: ['pil']") +def test_binary_writer_with_jpeg_filepath_and_int(tmpdir): + """Validate the writer and reader can serialize / deserialize a pair of image and label.""" + from PIL import Image + + cache_dir = os.path.join(tmpdir, "chunks") + os.makedirs(cache_dir, exist_ok=True) + binary_writer = BinaryWriter(cache_dir, chunk_bytes=2 << 12) + + imgs = [] + + for i in range(100): + path = os.path.join(tmpdir, f"img{i}.jpeg") + np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8) + img = Image.fromarray(np_data).convert("L") + img.save(path, format="jpeg", quality=100) + img = Image.open(path) + imgs.append(img) + binary_writer[i] = {"x": path, "y": i} + + assert len(os.listdir(cache_dir)) == 24 + binary_writer.done() + binary_writer.merge() + assert len(os.listdir(cache_dir)) == 26 + + with open(os.path.join(cache_dir, "index.json")) as f: + data = json.load(f) + + assert data["chunks"][0]["chunk_size"] == 4 + assert data["chunks"][1]["chunk_size"] == 4 + assert data["chunks"][-1]["chunk_size"] == 4 + + reader = BinaryReader(cache_dir) + for i in range(100): + data = reader.read(ChunkedIndex(i, chunk_index=i // 4)) + np.testing.assert_array_equal(np.asarray(data["x"]).squeeze(0), imgs[i]) + assert data["y"] == i From 2c0ee2f5999d7deddf0f0537055bcd4bdd2cec36 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 9 Oct 2023 09:56:02 +0100 
Subject: [PATCH 78/84] update --- src/lightning/data/cache/cache.py | 7 +- src/lightning/data/cache/config.py | 15 +- src/lightning/data/cache/constants.py | 9 +- src/lightning/data/cache/dataloader.py | 8 +- src/lightning/data/cache/downloader.py | 1 + src/lightning/data/cache/pytree.py | 567 ------------------------ src/lightning/data/cache/reader.py | 6 +- src/lightning/data/cache/serializers.py | 4 +- src/lightning/data/cache/writer.py | 18 +- tests/tests_data/cache/test_cache.py | 26 +- 10 files changed, 49 insertions(+), 612 deletions(-) delete mode 100644 src/lightning/data/cache/pytree.py diff --git a/src/lightning/data/cache/cache.py b/src/lightning/data/cache/cache.py index 964dcd497f9f2..1d9c2d6e69554 100644 --- a/src/lightning/data/cache/cache.py +++ b/src/lightning/data/cache/cache.py @@ -15,7 +15,7 @@ import os from typing import Any, Dict, List, Optional, Tuple, Union -from lightning.data.cache.constants import INDEX_FILENAME +from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE from lightning.data.cache.reader import BinaryReader from lightning.data.cache.sampler import ChunkedIndex from lightning.data.cache.writer import BinaryWriter @@ -39,12 +39,15 @@ def __init__( Arguments: cache_dir: The path to where the chunks will be stored. remote_dir: The path to a remote folder where the data are located. + The scheme needs to be added to the path. compression: The name of the algorithm to reduce the size of the chunks. chunk_bytes: The maximum number of bytes within a chunk. chunk_size: The maximum number of items within a chunk. """ super().__init__() + if not _TORCH_2_1_0_AVAILABLE: + raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.") self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression) self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression) self._cache_dir = cache_dir @@ -56,7 +59,7 @@ def filled(self) -> bool: """Returns whether the caching phase is done.""" if self._is_done: return True - self._is_done = os.path.exists(os.path.join(self._cache_dir, INDEX_FILENAME)) + self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME)) return self._is_done def __setitem__(self, index: int, data: Any) -> None: diff --git a/src/lightning/data/cache/config.py b/src/lightning/data/cache/config.py index 0757622cd8353..2e23a15dc2bda 100644 --- a/src/lightning/data/cache/config.py +++ b/src/lightning/data/cache/config.py @@ -15,11 +15,13 @@ import os from typing import Any, Dict, List, Optional, Tuple -from lightning.data.cache.constants import INDEX_FILENAME +from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE from lightning.data.cache.downloader import get_downloader_cls -from lightning.data.cache.pytree import treespec_loads from lightning.data.cache.sampler import ChunkedIndex +if _TORCH_2_1_0_AVAILABLE: + from torch.utils._pytree import treespec_loads + class ChunksConfig: def __init__(self, cache_dir: str, remote_dir: Optional[str]): @@ -28,7 +30,8 @@ def __init__(self, cache_dir: str, remote_dir: Optional[str]): Arguments: cache_dir: The path to cache folder. - remote_dir: The remote folder where the data are stored. + remote_dir: The path to a remote folder where the data are located. + The scheme needs to be added to the path. 
""" self._cache_dir = cache_dir @@ -37,7 +40,7 @@ def __init__(self, cache_dir: str, remote_dir: Optional[str]): self._chunks = [] self._remote_dir = remote_dir - with open(os.path.join(self._cache_dir, INDEX_FILENAME)) as f: + with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f: data = json.load(f) self._config = data["config"] @@ -107,11 +110,11 @@ def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]: @classmethod def load(cls, cache_dir: str, remote_dir: Optional[str] = None) -> Optional["ChunksConfig"]: - cache_index_filepath = os.path.join(cache_dir, INDEX_FILENAME) + cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME) if isinstance(remote_dir, str): downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, []) - downloader.download_file(os.path.join(remote_dir, INDEX_FILENAME), cache_index_filepath) + downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath) if not os.path.exists(cache_index_filepath): return None diff --git a/src/lightning/data/cache/constants.py b/src/lightning/data/cache/constants.py index ee7b6f59cb339..d9dfa136c4999 100644 --- a/src/lightning/data/cache/constants.py +++ b/src/lightning/data/cache/constants.py @@ -11,4 +11,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -INDEX_FILENAME = "index.json" +from lightning_utilities.core.imports import RequirementCache + +_INDEX_FILENAME = "index.json" +_DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B + +# This is required for full pytree serialization / deserialization support +_TORCH_2_1_0_AVAILABLE = RequirementCache("torch>=2.1.0") +_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer") diff --git a/src/lightning/data/cache/dataloader.py b/src/lightning/data/cache/dataloader.py index da50657d948b5..8f58b54c97dee 100644 --- a/src/lightning/data/cache/dataloader.py +++ b/src/lightning/data/cache/dataloader.py @@ -19,7 +19,6 @@ from typing import Any, Callable, List, Optional import torch -from lightning_utilities.core.imports import RequirementCache from torch.utils.data import Dataset, IterableDataset from torch.utils.data._utils.collate import default_collate from torch.utils.data._utils.fetch import _BaseDatasetFetcher @@ -33,16 +32,15 @@ from torch.utils.data.sampler import BatchSampler, Sampler from lightning.data.cache import Cache -from lightning.data.cache.pytree import tree_flatten +from lightning.data.cache.constants import _DEFAULT_CHUNK_BYTES, _TORCH_2_1_0_AVAILABLE, _VIZ_TRACKER_AVAILABLE from lightning.data.cache.sampler import CacheBatchSampler from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv -_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer") +if _TORCH_2_1_0_AVAILABLE: + from torch.utils._pytree import tree_flatten logger = logging.Logger(__name__) -_DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B - def _equal_items(data_1: Any, data_2: Any) -> bool: data_1_flattened, _ = tree_flatten(data_1) diff --git a/src/lightning/data/cache/downloader.py b/src/lightning/data/cache/downloader.py index 9c4a91188155a..460d0e576de87 100644 --- a/src/lightning/data/cache/downloader.py +++ b/src/lightning/data/cache/downloader.py @@ -63,6 +63,7 @@ def download_file(cls, remote_filepath: str, local_filepath: str) -> None: ) +# TODO: Add fsspec support _DOWNLOADERS = {"s3://": S3Downloader} diff --git a/src/lightning/data/cache/pytree.py b/src/lightning/data/cache/pytree.py deleted file mode 100644 index 3e2692aa6156f..0000000000000 --- a/src/lightning/data/cache/pytree.py +++ 
/dev/null @@ -1,567 +0,0 @@ -# Taken from PyTorch https://github.com/pytorch/pytorch/blob/e9ebda29d87ce0916ab08c06ab26fd3766a870e5/torch/utils/_pytree.py -# This should be available in 2.0.2 -# TODO: Remove me when open sourced. - -import dataclasses -import json -import warnings -from collections import OrderedDict, namedtuple -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, TypeVar, Union, cast, overload - -T = TypeVar("T") -S = TypeVar("S") -U = TypeVar("U") -R = TypeVar("R") - -DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1 - -Context = Any -PyTree = Any -FlattenFunc = Callable[[PyTree], Tuple[List, Context]] -UnflattenFunc = Callable[[List, Context], PyTree] -DumpableContext = Any # Any json dumpable text -ToDumpableContextFn = Callable[[Context], DumpableContext] -FromDumpableContextFn = Callable[[DumpableContext], Context] -ToStrFunc = Callable[["TreeSpec", List[str]], str] -MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]] - - -# A NodeDef holds two callables: -# - flatten_fn should take the collection and return a flat list of values. -# It can also return some context that is used in reconstructing the -# collection. -# - unflatten_fn should take a flat list of values and some context -# (returned by flatten_fn). It returns the collection by reconstructing -# it from the list and the context. -class NodeDef(NamedTuple): - type: Type[Any] - flatten_fn: FlattenFunc - unflatten_fn: UnflattenFunc - - -SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} - - -# _SerializeNodeDef holds the following: -# - typ: the type of the node (e.g., "Dict", "List", etc) -# - type_fqn: the fully qualified name of the type, e.g. "collections.OrderedDict" -# - to_dumpable_context takes a TreeSpec, and returns a serialized string format of the -# context, and the version number -# - from_dumpable_context takes in a string representation of the context, and the -# version, and returns the deserialized context -class _SerializeNodeDef(NamedTuple): - typ: Type[Any] - type_fqn: str - to_dumpable_context: Optional[ToDumpableContextFn] - from_dumpable_context: Optional[FromDumpableContextFn] - - -SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {} -SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {} - - -def _register_pytree_node( - typ: Any, - flatten_fn: FlattenFunc, - unflatten_fn: UnflattenFunc, - to_str_fn: Optional[ToStrFunc] = None, - maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, - *, - to_dumpable_context: Optional[ToDumpableContextFn] = None, - from_dumpable_context: Optional[FromDumpableContextFn] = None, -) -> None: - """ - Args: - typ: the type to register - flatten_fn: A callable that takes a pytree and returns a flattened - representation of the pytree and additional context to represent the - flattened pytree. - unflatten_fn: A callable that takes a flattened version of the pytree, - additional context, and returns an unflattedn pytree. - to_dumpable_context: An optional keyword argument to custom specify how - to convert the context of the pytree to a custom json dumpable - representation. This is used for json serialization, which is being - used in torch.export right now. - from_dumpable_context: An optional keyword argument to custom specify how - to convert the custom json dumpable representation of the context - back to the original context. This is used for json deserialization, - which is being used in torch.export right now. 
- """ - if to_str_fn is not None or maybe_from_str_fn is not None: - warnings.warn( - "to_str_fn and maybe_from_str_fn is deprecated. " - "Please use to_dumpable_context and from_dumpable_context instead." - ) - - node_def = NodeDef( - typ, - flatten_fn, - unflatten_fn, - ) - SUPPORTED_NODES[typ] = node_def - - if (to_dumpable_context is None) ^ (from_dumpable_context is None): - raise ValueError(f"Both to_dumpable_context and from_dumpable_context for {typ} must " "be None or registered.") - - type_fqn = f"{typ.__module__}.{typ.__name__}" - serialize_node_def = _SerializeNodeDef(typ, type_fqn, to_dumpable_context, from_dumpable_context) - SUPPORTED_SERIALIZED_TYPES[typ] = serialize_node_def - SERIALIZED_TYPE_TO_PYTHON_TYPE[type_fqn] = typ - - -def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]: - return dict(zip(context, values)) - - -def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: - return d, None - - -def _list_unflatten(values: List[Any], context: Context) -> List[Any]: - return list(values) - - -def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]: - return list(d), None - - -def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]: - return tuple(values) - - -def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]: - return list(d), type(d) - - -def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple: - return cast(NamedTuple, context(*values)) - - -def _namedtuple_serialize(context: Context) -> DumpableContext: - json_namedtuple = { - "class_name": context.__name__, - "fields": context._fields, - } - return json_namedtuple - - -def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context: - class_name = dumpable_context["class_name"] - assert isinstance(class_name, str) - context = namedtuple(class_name, dumpable_context["fields"]) # type: ignore[misc] - return context - - -def _odict_flatten(d: "OrderedDict[Any, Any]") -> Tuple[List[Any], Context]: - return list(d.values()), list(d.keys()) - - -def _odict_unflatten(values: List[Any], context: Context) -> "OrderedDict[Any, Any]": - return OrderedDict((key, value) for key, value in zip(context, values)) - - -_register_pytree_node(dict, _dict_flatten, _dict_unflatten) -_register_pytree_node(list, _list_flatten, _list_unflatten) -_register_pytree_node(tuple, _tuple_flatten, _tuple_unflatten) -_register_pytree_node( - namedtuple, - _namedtuple_flatten, - _namedtuple_unflatten, - to_dumpable_context=_namedtuple_serialize, - from_dumpable_context=_namedtuple_deserialize, -) -_register_pytree_node(OrderedDict, _odict_flatten, _odict_unflatten) - - -# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple -def _is_namedtuple_instance(pytree: Any) -> bool: - typ = type(pytree) - bases = typ.__bases__ - if len(bases) != 1 or bases[0] != tuple: - return False - fields = getattr(typ, "_fields", None) - if not isinstance(fields, tuple): - return False - return all(isinstance(entry, str) for entry in fields) - - -def _get_node_type(pytree: Any) -> Any: - if _is_namedtuple_instance(pytree): - return namedtuple - return type(pytree) - - -# A leaf is defined as anything that is not a Node. -def _is_leaf(pytree: PyTree) -> bool: - return _get_node_type(pytree) not in SUPPORTED_NODES - - -# A TreeSpec represents the structure of a pytree. 
It holds: -# "type": the type of root Node of the pytree -# context: some context that is useful in unflattening the pytree -# children_specs: specs for each child of the root Node -# num_leaves: the number of leaves -@dataclasses.dataclass -class TreeSpec: - type: Any - context: Context - children_specs: List["TreeSpec"] - - def __post_init__(self) -> None: - self.num_leaves: int = sum([spec.num_leaves for spec in self.children_specs]) - - def __repr__(self, indent: int = 0) -> str: - repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" - children_specs_str: str = "" - if len(self.children_specs): - indent += 2 - children_specs_str += self.children_specs[0].__repr__(indent) - children_specs_str += "," if len(self.children_specs) > 1 else "" - children_specs_str += ",".join( - ["\n" + " " * indent + child.__repr__(indent) for child in self.children_specs[1:]] - ) - repr_suffix: str = f"{children_specs_str}])" - return repr_prefix + repr_suffix - - -class LeafSpec(TreeSpec): - def __init__(self) -> None: - super().__init__(None, None, []) - self.num_leaves = 1 - - def __repr__(self, indent: int = 0) -> str: - return "*" - - -def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used to reconstruct the pytree.""" - if _is_leaf(pytree): - return [pytree], LeafSpec() - - node_type = _get_node_type(pytree) - flatten_fn = SUPPORTED_NODES[node_type].flatten_fn - child_pytrees, context = flatten_fn(pytree) - - # Recursively flatten the children - result: List[Any] = [] - children_specs: List[TreeSpec] = [] - for child in child_pytrees: - flat, child_spec = tree_flatten(child) - result += flat - children_specs.append(child_spec) - - return result, TreeSpec(node_type, context, children_specs) - - -def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree: - """Given a list of values and a TreeSpec, builds a pytree. - - This is the inverse operation of `tree_flatten`. - - """ - if not isinstance(spec, TreeSpec): - raise TypeError( - f"tree_unflatten(values, spec): Expected `spec` to be instance of " - f"TreeSpec but got item of type {type(spec)}." - ) - if len(values) != spec.num_leaves: - raise ValueError( - f"tree_unflatten(values, spec): `values` has length {len(values)} " - f"but the spec refers to a pytree that holds {spec.num_leaves} " - f"items ({spec})." - ) - if isinstance(spec, LeafSpec): - return values[0] - - unflatten_fn = SUPPORTED_NODES[spec.type].unflatten_fn - - # Recursively unflatten the children - start = 0 - end = 0 - child_pytrees = [] - for child_spec in spec.children_specs: - end += child_spec.num_leaves - child_pytrees.append(tree_unflatten(values[start:end], child_spec)) - start = end - - return unflatten_fn(child_pytrees, spec.context) - - -def tree_map(fn: Any, pytree: PyTree) -> PyTree: - flat_args, spec = tree_flatten(pytree) - return tree_unflatten([fn(i) for i in flat_args], spec) - - -Type2 = Tuple[Type[T], Type[S]] -Type3 = Tuple[Type[T], Type[S], Type[U]] -TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] - -Fn3 = Callable[[Union[T, S, U]], R] -Fn2 = Callable[[Union[T, S]], R] -Fn = Callable[[T], R] -FnAny = Callable[[Any], R] - -MapOnlyFn = Callable[[T], Callable[[Any], Any]] - - -# These specializations help with type inference on the lambda passed to this -# function -@overload -def map_only(ty: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]: - ... - - -@overload -def map_only(ty: Type[T]) -> MapOnlyFn[Fn[T, Any]]: - ... 
- - -# This specialization is needed for the implementations below that call -@overload -def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]: - ... - - -def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]: - """Suppose you are writing a tree_map over tensors, leaving everything else unchanged. Ordinarily you would have - to write: - - def go(t): - if isinstance(t, Tensor): - return ... - else: - return t - - With this function, you only need to write: - - @map_only(Tensor) - def go(t): - return ... - - You can also directly use 'tree_map_only' - - """ - - def deco(f: Callable[[T], Any]) -> Callable[[Any], Any]: - def inner(x: T) -> Any: - if isinstance(x, ty): - return f(x) - return x - - return inner - - return deco - - -@overload -def tree_map_only(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree: - ... - - -@overload -def tree_map_only(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree: - ... - - -@overload -def tree_map_only(ty: Type3[T, S, U], fn: Fn3[T, S, U, Any], pytree: PyTree) -> PyTree: - ... - - -def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree: - return tree_map(map_only(ty)(fn), pytree) - - -def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool: - flat_args, _ = tree_flatten(pytree) - return all(map(pred, flat_args)) - - -def tree_any(pred: Callable[[Any], bool], pytree: PyTree) -> bool: - flat_args, _ = tree_flatten(pytree) - return any(map(pred, flat_args)) - - -@overload -def tree_all_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool: - ... - - -@overload -def tree_all_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool: - ... - - -@overload -def tree_all_only(ty: Type3[T, S, U], pred: Fn3[T, S, U, bool], pytree: PyTree) -> bool: - ... - - -def tree_all_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool: - flat_args, _ = tree_flatten(pytree) - return all(pred(x) for x in flat_args if isinstance(x, ty)) - - -@overload -def tree_any_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool: - ... - - -@overload -def tree_any_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool: - ... - - -def tree_any_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool: - flat_args, _ = tree_flatten(pytree) - return any(pred(x) for x in flat_args if isinstance(x, ty)) - - -# Broadcasts a pytree to the provided TreeSpec and returns the flattened -# values. If this is not possible, then this function returns None. -# -# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), -# would return [0, 0]. This is useful for part of the vmap implementation: -# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be -# broadcastable to the tree structure of `inputs` and we use -# _broadcast_to_and_flatten to check this. 
-def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]: - assert isinstance(spec, TreeSpec) - - if _is_leaf(pytree): - return [pytree] * spec.num_leaves - if isinstance(spec, LeafSpec): - return None - node_type = _get_node_type(pytree) - if node_type != spec.type: - return None - - flatten_fn = SUPPORTED_NODES[node_type].flatten_fn - child_pytrees, ctx = flatten_fn(pytree) - - # Check if the Node is different from the spec - if len(child_pytrees) != len(spec.children_specs) or ctx != spec.context: - return None - - # Recursively flatten the children - result: List[Any] = [] - for child, child_spec in zip(child_pytrees, spec.children_specs): - flat = _broadcast_to_and_flatten(child, child_spec) - if flat is not None: - result += flat - else: - return None - - return result - - -@dataclasses.dataclass -class _TreeSpecSchema: - type: Optional[str] - context: DumpableContext - children_spec: List["_TreeSpecSchema"] - - -class _ProtocolFn(NamedTuple): - treespec_to_json: Callable[[TreeSpec], DumpableContext] - json_to_treespec: Callable[[DumpableContext], TreeSpec] - - -_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {} - - -def _treespec_to_json(spec: TreeSpec) -> _TreeSpecSchema: - if isinstance(spec, LeafSpec): - return _TreeSpecSchema(None, None, []) - - if spec.type not in SUPPORTED_SERIALIZED_TYPES: - raise NotImplementedError(f"Serializing {spec.type} in pytree is not registered.") - - serialize_node_def = SUPPORTED_SERIALIZED_TYPES[spec.type] - - type_fqn = serialize_node_def.type_fqn - - if serialize_node_def.to_dumpable_context is None: - try: - serialized_context = json.dumps(spec.context) - except TypeError as e: - raise TypeError( - "Unable to serialize context. " - "Please make the context json dump-able, or register a " - "custom serializer using _register_pytree_node." - ) from e - else: - serialized_context = serialize_node_def.to_dumpable_context(spec.context) - - child_schemas = [_treespec_to_json(child) for child in spec.children_specs] - - return _TreeSpecSchema(type_fqn, serialized_context, child_schemas) - - -def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: - if json_schema["type"] is None and json_schema["context"] is None and len(json_schema["children_spec"]) == 0: - return LeafSpec() - - if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: - raise NotImplementedError(f'Deserializing {json_schema["type"]} in pytree is not registered.') - - typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] - serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ] - - if serialize_node_def.from_dumpable_context is None: - try: - context = json.loads(json_schema["context"]) - except TypeError: - raise TypeError( - "Unable to deserialize context. " - "Please make the context json load-able, or register a " - "custom serializer using _register_pytree_node." - ) - else: - context = serialize_node_def.from_dumpable_context(json_schema["context"]) - - children_spec = [] - for child_string in json_schema["children_spec"]: - children_spec.append(_json_to_treespec(child_string)) - - return TreeSpec(typ, context, children_spec) - - -_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec) - - -def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: - if protocol is None: - protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL - - if protocol in _SUPPORTED_PROTOCOLS: - json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec) - else: - raise ValueError(f"Unknown protocol {protocol}. 
Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}") - - str_spec = json.dumps((protocol, dataclasses.asdict(json_spec))) - return str_spec - - -def treespec_loads(data: str) -> TreeSpec: - protocol, json_schema = json.loads(data) - - if protocol in _SUPPORTED_PROTOCOLS: - return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema) - raise ValueError(f"Unknown protocol {protocol}. Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}") - - -# TODO(angelayi): remove this function after OSS/internal stabilize -def pytree_to_str(spec: TreeSpec) -> str: - warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps") - return treespec_dumps(spec) - - -# TODO(angelayi): remove this function after OSS/internal stabilize -def str_to_pytree(json: str) -> TreeSpec: - warnings.warn("str_to_pytree is deprecated. Please use treespec_loads") - return treespec_loads(json) diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index e4f713b939d8b..3110145633939 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -19,11 +19,14 @@ import numpy as np from lightning.data.cache.config import ChunksConfig -from lightning.data.cache.pytree import PyTree, tree_unflatten +from lightning.data.cache.constants import _TORCH_2_1_0_AVAILABLE from lightning.data.cache.sampler import ChunkedIndex from lightning.data.cache.serializers import _SERIALIZERS, Serializer from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv +if _TORCH_2_1_0_AVAILABLE: + from torch.utils._pytree import PyTree, tree_unflatten + class PrepareChunksThread(Thread): """This thread is responsible to download the chunks associated to a given worker.""" @@ -61,6 +64,7 @@ def __init__(self, cache_dir: str, remote_dir: Optional[str] = None, compression Arguments: cache_dir: The path to cache folder. remote_dir: The path to a remote folder where the data are located. + The scheme needs to be added to the path. compression: The algorithm to decompress the chunks. 
""" diff --git a/src/lightning/data/cache/serializers.py b/src/lightning/data/cache/serializers.py index 7a2dae4b5ffec..6ba5d5a4dbcad 100644 --- a/src/lightning/data/cache/serializers.py +++ b/src/lightning/data/cache/serializers.py @@ -181,7 +181,7 @@ def deserialize(self, data: bytes) -> torch.Tensor: return torch.reshape(tensor, torch.Size(shape)) def can_serialize(self, item: torch.Tensor) -> bool: - return isinstance(item, torch.Tensor) + return isinstance(item, torch.Tensor) and type(item) == torch.Tensor class PickleSerializer(Serializer): @@ -201,7 +201,7 @@ class FileSerializer(Serializer): def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]: _, file_extension = os.path.splitext(filepath) with open(filepath, "rb") as f: - return f.read(), file_extension.replace(".", "") + return f.read(), file_extension.replace(".", "").lower() def deserialize(self, data: bytes) -> Any: pass diff --git a/src/lightning/data/cache/writer.py b/src/lightning/data/cache/writer.py index d71f9ff2afb63..29981a89f4e3a 100644 --- a/src/lightning/data/cache/writer.py +++ b/src/lightning/data/cache/writer.py @@ -19,11 +19,13 @@ import numpy as np from lightning.data.cache.compression import _COMPRESSORS, Compressor -from lightning.data.cache.constants import INDEX_FILENAME -from lightning.data.cache.pytree import PyTree, tree_flatten, treespec_dumps +from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE from lightning.data.cache.serializers import _SERIALIZERS, Serializer from lightning.data.datasets.env import _DistributedEnv, _WorkerEnv +if _TORCH_2_1_0_AVAILABLE: + from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps + class BinaryWriter: def __init__( @@ -83,7 +85,7 @@ def filled(self) -> bool: if self._is_done: return True files = os.listdir(self._cache_dir) - index_files = [f for f in files if f.endswith(INDEX_FILENAME)] + index_files = [f for f in files if f.endswith(_INDEX_FILENAME)] worker_end = _WorkerEnv.detect() self._is_done = len(index_files) == self._distributed_env.world_size * worker_end.world_size return self._is_done @@ -244,7 +246,7 @@ def write_chunk_to_file( def write_chunks_index(self) -> None: """Write the chunks index to a JSON file.""" - filepath = os.path.join(self._cache_dir, f"{self.rank}.{INDEX_FILENAME}") + filepath = os.path.join(self._cache_dir, f"{self.rank}.{_INDEX_FILENAME}") config = self.get_config() with open(filepath, "w") as out: json.dump({"chunks": self._chunks_info, "config": config}, out, sort_keys=True) @@ -266,7 +268,7 @@ def merge(self, num_workers: int = 1) -> None: # Only for non rank 0 if self.rank != 0: - while not os.path.exists(os.path.join(self._cache_dir, INDEX_FILENAME)): + while not os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME)): sleep(0.001) return @@ -274,9 +276,9 @@ def merge(self, num_workers: int = 1) -> None: is_done = False while not is_done: files = os.listdir(self._cache_dir) - if INDEX_FILENAME in files: + if _INDEX_FILENAME in files: return - index_files = [f for f in files if f.endswith(INDEX_FILENAME)] + index_files = [f for f in files if f.endswith(_INDEX_FILENAME)] is_done = len(index_files) == self._distributed_env.world_size * num_workers sleep(0.001) @@ -299,5 +301,5 @@ def merge(self, num_workers: int = 1) -> None: os.remove(chunk_path) # Write down the collected index - with open(os.path.join(self._cache_dir, INDEX_FILENAME), "w") as f: + with open(os.path.join(self._cache_dir, _INDEX_FILENAME), "w") as f: json.dump({"chunks": chunks_info, "config": config}, 
f, sort_keys=True) diff --git a/tests/tests_data/cache/test_cache.py b/tests/tests_data/cache/test_cache.py index c30a890ab0c5a..d1605123a02c8 100644 --- a/tests/tests_data/cache/test_cache.py +++ b/tests/tests_data/cache/test_cache.py @@ -11,7 +11,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import io import os import sys from functools import partial @@ -33,9 +32,8 @@ class ImageDataset(Dataset): - def __init__(self, tmpdir, cache, size, num_classes, use_transform: bool = False): + def __init__(self, tmpdir, cache, size, num_classes): from PIL import Image - from torchvision import transforms as T self.data = [] self.cache = cache @@ -47,31 +45,21 @@ def __init__(self, tmpdir, cache, size, num_classes, use_transform: bool = False np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8) img = Image.fromarray(np_data).convert("L") img.save(path, format="jpeg", quality=100) - # read bytes from the file - with open(path, "rb") as f: - data = f.read() - self.data.append({"image": data, "class": np.random.randint(num_classes)}) - - self.use_transform = use_transform - self.transform = T.Compose([T.ToTensor()]) + self.data.append({"image": path, "class": np.random.randint(num_classes)}) def __len__(self): return len(self.data) def __getitem__(self, index): - from PIL import Image - if self.cache.filled: - data = self.cache[index] - if self.use_transform: - data["image"] = self.transform(Image.open(io.BytesIO(data["image"]))).unsqueeze(0) - return data + return self.cache[index] self.cache[index] = {**self.data[index], "index": index} return None def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): from PIL import Image + from torchvision.transforms import PILToTensor dataset_size = 85 @@ -95,10 +83,8 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None): cached_data = dataset[i] original_data = dataset.data[i] assert cached_data["class"] == original_data["class"] - original_image = Image.open(io.BytesIO(original_data["image"])) - assert Image.open(io.BytesIO(cached_data["image"])) == original_image - - dataset.use_transform = True + original_array = PILToTensor()(Image.open(original_data["image"])) + assert torch.equal(original_array, cached_data["image"]) if distributed_env.world_size == 1: indexes = [] From 3d57af836408329de430fc4437db706b635f1b4b Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 9 Oct 2023 10:01:30 +0100 Subject: [PATCH 79/84] update --- src/lightning/data/cache/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/data/cache/reader.py b/src/lightning/data/cache/reader.py index 3110145633939..667360190bbc3 100644 --- a/src/lightning/data/cache/reader.py +++ b/src/lightning/data/cache/reader.py @@ -136,7 +136,7 @@ def read(self, index: ChunkedIndex) -> Any: raw_item_data = self.load_item_from_chunk(index.index, chunk_filepath, begin) return self.deserialize(raw_item_data) - def deserialize(self, raw_item_data: bytes) -> PyTree: + def deserialize(self, raw_item_data: bytes) -> "PyTree": """Deserialize the raw bytes into their python equivalent.""" idx = len(self.config.data_format) * 4 sizes = np.frombuffer(raw_item_data[:idx], np.uint32) From 9ff7dc41d9f3c6f3b00fd27ccd1421267773a14b Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 9 Oct 2023 12:37:07 +0100 Subject: [PATCH 80/84] update --- .github/workflows/ci-tests-data.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-tests-data.yml 
b/.github/workflows/ci-tests-data.yml index 6d303c26093b6..4de87f501ffdd 100644 --- a/.github/workflows/ci-tests-data.yml +++ b/.github/workflows/ci-tests-data.yml @@ -34,9 +34,9 @@ jobs: fail-fast: false matrix: include: - - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } - - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" } + - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } + - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } + - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" } # "oldest" versions tests, only on minimum Python # - {os: "macOS-11", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"} # - {os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"} From bfa57c0ab1ad97cde77bee5d6c2cec3e4364bf60 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 9 Oct 2023 12:49:30 +0100 Subject: [PATCH 81/84] update --- requirements/data/data.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/data/data.txt b/requirements/data/data.txt index 4fa81bd1a50e5..3ef97aaabbc27 100644 --- a/requirements/data/data.txt +++ b/requirements/data/data.txt @@ -5,4 +5,4 @@ lightning-utilities >=0.8.0, <0.10.0 # to be able to include also 0.6 and preserve `>` needed for CI min version bypass torchdata >0.5.9, <0.7.0 # to be able to include also PL 2.0 and preserve `>` needed for CI min version bypass -torch >0.14.0, <2.1.0 +torch >0.14.0, <=2.1.0 From e733e5954376f7bc35f0bbff57f690935d18dda4 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 9 Oct 2023 13:52:31 +0100 Subject: [PATCH 82/84] update --- requirements/data/data.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/data/data.txt b/requirements/data/data.txt index 3ef97aaabbc27..4813af9523aa2 100644 --- a/requirements/data/data.txt +++ b/requirements/data/data.txt @@ -3,6 +3,6 @@ lightning-utilities >=0.8.0, <0.10.0 # to be able to include also 0.6 and preserve `>` needed for CI min version bypass -torchdata >0.5.9, <0.7.0 +torchdata >0.5.9, <=0.7.0 # to be able to include also PL 2.0 and preserve `>` needed for CI min version bypass torch >0.14.0, <=2.1.0 From 1d0f5e4f99b91c7e0cf952bf17390d529907cec7 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 9 Oct 2023 15:46:31 +0100 Subject: [PATCH 83/84] update --- .github/checkgroup.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/checkgroup.yml b/.github/checkgroup.yml index a8bbe1d3413f9..0aef571eaa1f4 100644 --- a/.github/checkgroup.yml +++ b/.github/checkgroup.yml @@ -173,9 +173,9 @@ subprojects: - "!*.md" - "!**/*.md" checks: - - "data-cpu (macOS-11, lightning, 3.10, 2.0)" - - "data-cpu (ubuntu-20.04, lightning, 3.10, 2.0)" - - "data-cpu (windows-2022, lightning, 3.10, 2.0)" + - "data-cpu (macOS-11, lightning, 3.10, 2.1)" + - "data-cpu (ubuntu-20.04, lightning, 3.10, 2.1)" + - "data-cpu (windows-2022, lightning, 3.10, 2.1)" # SECTION: lightning_fabric From ff7b6293843bd55f01c6dd2caf24b39047ff9584 Mon Sep 17 00:00:00 2001 From: thomas Date: Mon, 9 Oct 2023 15:52:46 +0100 Subject: [PATCH 84/84] update --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 
49657b3c4a7a4..43ef2fc0195f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -138,7 +138,6 @@ exclude = [ "src/lightning/app/cli/pl-app-template", "src/lightning/app/cli/react-ui-template", "src/lightning/app/launcher", - "src/lightning/data/cache/pytree.py", ] install_types = "True" non_interactive = "True"
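
For context on how the pieces introduced in these patches fit together: each Serializer.serialize now returns a (bytes, Optional[str]) pair, and the optional string, when set, replaces the serializer's registry name as the recorded data format (see "return name or serializer_name" in BinaryWriter._serialize). This is what makes the new FileSerializer work end to end: it stores the raw file bytes and reports the file extension (e.g. "jpeg"), so the reader later picks the matching serializer to decode the item. The sketch below is illustrative only and mirrors tests/tests_data/cache/test_writer.py; it assumes torch >= 2.1 and Pillow are installed, and the scratch directories and random sample images are assumptions for the example, not part of the patches above.

import os

import numpy as np
from PIL import Image

from lightning.data.cache.reader import BinaryReader
from lightning.data.cache.sampler import ChunkedIndex
from lightning.data.cache.writer import BinaryWriter

data_dir = "/tmp/images"   # assumed scratch location for the sample JPEG files
cache_dir = "/tmp/chunks"  # assumed scratch location for the chunks
os.makedirs(data_dir, exist_ok=True)
os.makedirs(cache_dir, exist_ok=True)

# A small chunk_bytes so several chunk files are produced, as in the test above.
writer = BinaryWriter(cache_dir, chunk_bytes=2 << 12)

for i in range(10):
    path = os.path.join(data_dir, f"img{i}.jpeg")
    np_data = np.random.randint(255, size=(28, 28), dtype=np.uint8)
    Image.fromarray(np_data).convert("L").save(path, format="jpeg", quality=100)
    # "x" is an existing filepath, so the FileSerializer stores its raw bytes and
    # records the "jpeg" extension as the data format; "y" uses the IntSerializer.
    writer[i] = {"x": path, "y": i}

writer.done()   # flush the last partial chunk and write this rank's index file
writer.merge()  # collect the per-rank index files into a single index.json

reader = BinaryReader(cache_dir)
sample = reader.read(ChunkedIndex(0, chunk_index=0))
# sample["x"] is decoded by the JPEGSerializer, sample["y"] by the IntSerializer.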