Merged
102 commits
2e3e1c2
update
tchaton Sep 26, 2023
90bcd89
update
tchaton Sep 26, 2023
8d76988
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 26, 2023
fa8a5f3
update
tchaton Sep 27, 2023
70332a9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2023
a894cc4
update
tchaton Sep 27, 2023
e1ebe37
update
tchaton Sep 27, 2023
bf47412
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2023
35cae78
update
tchaton Sep 27, 2023
2376c3e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 27, 2023
28fab53
update
Sep 28, 2023
7f54886
update
Sep 28, 2023
c1b197f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 28, 2023
81f5a79
update
Sep 28, 2023
019b1bd
Merge branch 'introduce_cache' of https://github.com/Lightning-AI/lig…
Sep 28, 2023
c2ee47c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 28, 2023
d0708e0
update
Sep 28, 2023
7ddba5c
Merge branch 'introduce_cache' of https://github.com/Lightning-AI/lig…
Sep 28, 2023
6f6ce5f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 28, 2023
23bc5c4
update
Sep 28, 2023
f3430d7
Merge branch 'introduce_cache' of https://github.com/Lightning-AI/lig…
Sep 28, 2023
f79a292
update
Sep 28, 2023
4770038
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 28, 2023
dd0991c
update
Sep 28, 2023
a3a9ff7
Merge branch 'introduce_cache' of https://github.com/Lightning-AI/lig…
Sep 28, 2023
32fe811
update
Sep 28, 2023
1e3d1ab
update
Sep 28, 2023
7859601
Merge branch 'master' into introduce_cache
tchaton Sep 28, 2023
d05e34f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 28, 2023
6a63117
update
Sep 28, 2023
d7fb2d1
Merge branch 'introduce_cache' of https://github.com/Lightning-AI/lig…
Sep 28, 2023
41de123
update
Sep 28, 2023
e4397d0
update
Sep 28, 2023
a442858
update
Sep 29, 2023
81f68a8
Merge branch 'master' into introduce_cache
tchaton Sep 29, 2023
a30b277
update
Sep 29, 2023
f1dc0b2
Merge branch 'introduce_cache' of https://github.com/Lightning-AI/lig…
Sep 29, 2023
6171812
update
Sep 29, 2023
f21a7d8
Update src/lightning/data/cache/dataloader.py
tchaton Sep 29, 2023
9bc8811
Update src/lightning/data/cache/dataloader.py
tchaton Sep 29, 2023
5df5984
update
Sep 29, 2023
ec4d7d8
Merge branch 'introduce_cache' of https://github.com/Lightning-AI/lig…
Sep 29, 2023
ab470c3
update
Sep 29, 2023
f6d8184
update
Sep 29, 2023
41c2ba9
update
Sep 29, 2023
b642fd3
updatte
Sep 29, 2023
c63d0e1
update
Sep 29, 2023
75cc6fb
Merge branch 'master' into introduce_cache
tchaton Sep 29, 2023
976d680
update
Sep 29, 2023
77a909a
Merge branch 'introduce_cache' of https://github.com/Lightning-AI/lig…
Sep 29, 2023
c32c5aa
update
Sep 29, 2023
e3cd282
update
Sep 29, 2023
888cae4
update
Sep 29, 2023
de58298
update
Sep 29, 2023
c13f948
update
Sep 29, 2023
b03afd0
update
Sep 29, 2023
ab1c0e1
update
Oct 2, 2023
675cd45
Update src/lightning/data/cache/reader.py
tchaton Oct 2, 2023
aec12be
update
Oct 2, 2023
48c0fdf
update
tchaton Oct 3, 2023
6f396fa
update
tchaton Oct 3, 2023
a0f3696
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2023
91e718c
update
Oct 3, 2023
ff7547c
update
Oct 3, 2023
6dfd3fd
update
Oct 3, 2023
72e469d
update
Oct 3, 2023
d8c6fd7
update
Oct 3, 2023
f3184ff
update
tchaton Oct 3, 2023
2aa8551
Merge branch 'introduce_cache' of https://github.com/Lightning-AI/lig…
tchaton Oct 3, 2023
564fbab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 3, 2023
19198a2
update
Oct 3, 2023
54685ac
Merge branch 'introduce_cache' of https://github.com/Lightning-AI/lig…
Oct 3, 2023
39a0846
New cache (#18706)
tchaton Oct 6, 2023
200d6b5
update
Oct 6, 2023
79994ab
update
Oct 6, 2023
cbb7487
update
Oct 6, 2023
17ce63b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 6, 2023
1eac118
update
Oct 6, 2023
81de71e
update
Oct 6, 2023
c683a6b
update
Oct 6, 2023
b5425b8
Merge branch 'master' into introduce_cache
tchaton Oct 6, 2023
c6db7d9
update
Oct 6, 2023
5fdd0a1
update
Oct 6, 2023
3c63d30
update
Oct 6, 2023
b4e991d
Merge branch 'introduce_cache' of https://github.com/Lightning-AI/lig…
Oct 6, 2023
6ca0b5b
update
Oct 6, 2023
0bdacbb
update
Oct 6, 2023
9135004
update
Oct 6, 2023
e356838
update
Oct 6, 2023
c4ddeb8
update
Oct 6, 2023
da985f2
update
Oct 6, 2023
0e5342c
update
Oct 6, 2023
3aed4c5
update
Oct 7, 2023
1f41d60
update
Oct 7, 2023
bab1f1c
update
Oct 7, 2023
2c0ee2f
update
Oct 9, 2023
3d57af8
update
Oct 9, 2023
9ff7dc4
update
Oct 9, 2023
bfa57c0
update
Oct 9, 2023
e733e59
update
Oct 9, 2023
1d0f5e4
update
Oct 9, 2023
ff7b629
update
Oct 9, 2023
6 changes: 3 additions & 3 deletions .github/checkgroup.yml
@@ -173,9 +173,9 @@ subprojects:
  - "!*.md"
  - "!**/*.md"
  checks:
- - "data-cpu (macOS-11, lightning, 3.10, 2.0)"
- - "data-cpu (ubuntu-20.04, lightning, 3.10, 2.0)"
- - "data-cpu (windows-2022, lightning, 3.10, 2.0)"
+ - "data-cpu (macOS-11, lightning, 3.10, 2.1)"
+ - "data-cpu (ubuntu-20.04, lightning, 3.10, 2.1)"
+ - "data-cpu (windows-2022, lightning, 3.10, 2.1)"

  # SECTION: lightning_fabric

6 changes: 3 additions & 3 deletions .github/workflows/ci-tests-data.yml
@@ -34,9 +34,9 @@ jobs:
  fail-fast: false
  matrix:
  include:
- - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
- - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.0" }
+ - { os: "macOS-11", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
+ - { os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
+ - { os: "windows-2022", pkg-name: "lightning", python-version: "3.10", pytorch-version: "2.1" }
  # "oldest" versions tests, only on minimum Python
  # - {os: "macOS-11", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"}
  # - {os: "ubuntu-20.04", pkg-name: "lightning", python-version: "3.8", pytorch-version: "2.0", requires: "oldest"}
4 changes: 2 additions & 2 deletions requirements/data/data.txt
@@ -3,6 +3,6 @@

  lightning-utilities >=0.8.0, <0.10.0
  # to be able to include also 0.6 and preserve `>` needed for CI min version bypass
- torchdata >0.5.9, <0.7.0
+ torchdata >0.5.9, <=0.7.0
  # to be able to include also PL 2.0 and preserve `>` needed for CI min version bypass
- torch >0.14.0, <2.1.0
+ torch >0.14.0, <=2.1.0
17 changes: 17 additions & 0 deletions src/lightning/data/cache/__init__.py
@@ -0,0 +1,17 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lightning.data.cache.cache import Cache
from lightning.data.cache.dataloader import LightningDataLoader

__all__ = ["Cache", "LightningDataLoader"]
90 changes: 90 additions & 0 deletions src/lightning/data/cache/cache.py
@@ -0,0 +1,90 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Any, Dict, List, Optional, Tuple, Union

from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
from lightning.data.cache.reader import BinaryReader
from lightning.data.cache.sampler import ChunkedIndex
from lightning.data.cache.writer import BinaryWriter
from lightning.data.datasets.env import _DistributedEnv

logger = logging.getLogger(__name__)


class Cache:
    def __init__(
        self,
        cache_dir: str,
        remote_dir: Optional[str] = None,
        compression: Optional[str] = None,
        chunk_size: Optional[int] = None,
        chunk_bytes: Optional[int] = None,
    ):
        """The Cache optimises the dataset format for cloud training by grouping several elements together
        into chunks, which accelerates fetching.

        Arguments:
            cache_dir: The path where the chunks will be stored.
            remote_dir: The path to a remote folder where the data is located.
                The scheme needs to be added to the path.
            compression: The name of the algorithm used to reduce the size of the chunks.
            chunk_bytes: The maximum number of bytes within a chunk.
            chunk_size: The maximum number of items within a chunk.

        """
        super().__init__()
        if not _TORCH_2_1_0_AVAILABLE:
            raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.")
        self._writer = BinaryWriter(cache_dir, chunk_size=chunk_size, chunk_bytes=chunk_bytes, compression=compression)
        self._reader = BinaryReader(cache_dir, remote_dir=remote_dir, compression=compression)
        self._cache_dir = cache_dir
        self._is_done = False
        self._distributed_env = _DistributedEnv.detect()

    @property
    def filled(self) -> bool:
        """Returns whether the caching phase is done."""
        if self._is_done:
            return True
        # The index file is only present once the chunks have been fully written.
        self._is_done = os.path.exists(os.path.join(self._cache_dir, _INDEX_FILENAME))
        return self._is_done

    def __setitem__(self, index: int, data: Any) -> None:
        """Store an item in the writer."""
        self._writer[index] = data

    def __getitem__(self, index: Union[int, ChunkedIndex]) -> Dict[str, Any]:
        """Read an item from the reader."""
        if isinstance(index, int):
            index = ChunkedIndex(index, self._get_chunk_index_from_index(index))
        return self._reader.read(index)

    def done(self) -> None:
        """Inform the writer that the chunking phase is finished."""
        self._writer.done()

    def merge(self, num_workers: int = 1) -> None:
        """Ask the writer to merge the per-worker index files into a single index."""
        self._writer.merge(num_workers)

    def __len__(self) -> int:
        return self._reader.get_length()

    def get_chunk_interval(self) -> List[Tuple[int, int]]:
        return self._reader.get_chunk_interval()

    def _get_chunk_index_from_index(self, index: int) -> int:
        return self._reader._get_chunk_index_from_index(index)
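The `chunk_size` and `chunk_bytes` arguments documented above cap a chunk by item count and by byte size. The sketch below shows one way such grouping could work. It is not the `BinaryWriter` implementation, which is not part of the files shown here, and `group_into_chunks` with its return format is a hypothetical helper invented for illustration.

```python
from typing import List, Optional, Tuple


def group_into_chunks(
    items: List[bytes],
    chunk_size: Optional[int] = None,
    chunk_bytes: Optional[int] = None,
) -> List[Tuple[int, int]]:
    """Return half-open (start, end) index intervals, one per chunk.

    Hypothetical helper: a chunk is closed as soon as adding the next item
    would exceed `chunk_size` items or `chunk_bytes` bytes.
    """
    if chunk_size is None and chunk_bytes is None:
        raise ValueError("Provide at least one of chunk_size or chunk_bytes.")

    intervals: List[Tuple[int, int]] = []
    start = 0  # index of the first item in the current chunk
    current_bytes = 0  # bytes accumulated in the current chunk

    for index, item in enumerate(items):
        count = index - start
        too_many_items = chunk_size is not None and count >= chunk_size
        # A single oversized item still gets its own chunk (count > 0 guard).
        too_many_bytes = chunk_bytes is not None and count > 0 and current_bytes + len(item) > chunk_bytes
        if too_many_items or too_many_bytes:
            intervals.append((start, index))
            start, current_bytes = index, 0
        current_bytes += len(item)

    # Flush the last, possibly partial, chunk.
    if start < len(items):
        intervals.append((start, len(items)))
    return intervals


items = [b"a" * 10, b"b" * 10, b"c" * 25, b"d" * 5]
print(group_into_chunks(items, chunk_bytes=30))  # [(0, 2), (2, 4)]
print(group_into_chunks(items, chunk_size=3))  # [(0, 3), (3, 4)]
```

The half-open intervals mirror the `interval` entries that `ChunksConfig` later reads back from the index file.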
76 changes: 76 additions & 0 deletions src/lightning/data/cache/compression.py
@@ -0,0 +1,76 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Dict, TypeVar

from lightning_utilities.core.imports import RequirementCache, requires

_ZSTD_AVAILABLE = RequirementCache("zstd")

if _ZSTD_AVAILABLE:
    import zstd

TCompressor = TypeVar("TCompressor", bound="Compressor")


class Compressor(ABC):
    """Base class for compression algorithms."""

    @abstractmethod
    def compress(self, data: bytes) -> bytes:
        pass

    @abstractmethod
    def decompress(self, data: bytes) -> bytes:
        pass

    # `abstractclassmethod` is deprecated; stack `classmethod` on top of `abstractmethod` instead.
    @classmethod
    @abstractmethod
    def register(cls, compressors: Dict[str, "Compressor"]) -> None:
        pass


class ZSTDCompressor(Compressor):
    """Compressor for the zstd package."""

    @requires("zstd")
    def __init__(self, level: int) -> None:
        super().__init__()
        self.level = level
        self.extension = "zstd"

    @property
    def name(self) -> str:
        return f"{self.extension}:{self.level}"

    def compress(self, data: bytes) -> bytes:
        return zstd.compress(data, self.level)

    def decompress(self, data: bytes) -> bytes:
        return zstd.decompress(data)

    @classmethod
    def register(cls, compressors: Dict[str, "Compressor"]) -> None:  # type: ignore
        if not _ZSTD_AVAILABLE:
            return

        # default compression level
        compressors["zstd"] = ZSTDCompressor(4)

        # one entry per explicit level, addressable as "zstd:<level>"
        for level in range(1, 23):
            compressors[f"zstd:{level}"] = ZSTDCompressor(level)


_COMPRESSORS: Dict[str, Compressor] = {}

ZSTDCompressor.register(_COMPRESSORS)
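The registry pattern above maps a string such as `"zstd"` or `"zstd:7"` to a ready-made compressor instance. The following sketch reproduces the pattern with the standard-library `zlib` module, so it runs without the optional `zstd` package. `ZlibCompressor`, the `"zlib:<level>"` key format, and the default level of 6 are assumptions made for illustration and are not part of the diff.

```python
import zlib
from abc import ABC, abstractmethod
from typing import Dict


class Compressor(ABC):
    """Minimal copy of the base interface, kept here so the sketch is self-contained."""

    @abstractmethod
    def compress(self, data: bytes) -> bytes:
        ...

    @abstractmethod
    def decompress(self, data: bytes) -> bytes:
        ...


class ZlibCompressor(Compressor):
    """Hypothetical stand-in for ZSTDCompressor, backed by the standard library."""

    def __init__(self, level: int) -> None:
        self.level = level

    def compress(self, data: bytes) -> bytes:
        return zlib.compress(data, self.level)

    def decompress(self, data: bytes) -> bytes:
        return zlib.decompress(data)

    @classmethod
    def register(cls, compressors: Dict[str, Compressor]) -> None:
        compressors["zlib"] = cls(6)  # assumed default level
        for level in range(1, 10):  # zlib accepts levels 1 to 9
            compressors[f"zlib:{level}"] = cls(level)


registry: Dict[str, Compressor] = {}
ZlibCompressor.register(registry)

payload = b"lightning " * 100
compressed = registry["zlib:9"].compress(payload)
# Any registered instance can decompress, because the level only affects compression.
restored = registry["zlib"].decompress(compressed)
print(len(payload), len(compressed), restored == payload)
```

Registering instances up front keeps the lookup at read and write time to a single dictionary access keyed by the name given as `compression`.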
125 changes: 125 additions & 0 deletions src/lightning/data/cache/config.py
@@ -0,0 +1,125 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import Any, Dict, List, Optional, Tuple

from lightning.data.cache.constants import _INDEX_FILENAME, _TORCH_2_1_0_AVAILABLE
from lightning.data.cache.downloader import get_downloader_cls
from lightning.data.cache.sampler import ChunkedIndex

if _TORCH_2_1_0_AVAILABLE:
    from torch.utils._pytree import treespec_loads


class ChunksConfig:
    def __init__(self, cache_dir: str, remote_dir: Optional[str]):
        """The ChunksConfig reads the index file associated with a chunked dataset and maps an index to its
        chunk.

        Arguments:
            cache_dir: The path to the cache folder.
            remote_dir: The path to a remote folder where the data is located.
                The scheme needs to be added to the path.

        """
        self._cache_dir = cache_dir
        self._intervals: List[Tuple[int, int]] = []
        self._config = None
        self._chunks = []
        self._remote_dir = remote_dir

        with open(os.path.join(self._cache_dir, _INDEX_FILENAME)) as f:
            data = json.load(f)

            self._config = data["config"]

            self._chunks.extend(data["chunks"])

        self._config["data_spec"] = treespec_loads(self._config["data_spec"])

        for chunk in self._chunks:
            start, end = chunk["interval"]
            # Each interval is half-open, so its width must equal the number of items in the chunk.
            if (end - start) != chunk["chunk_size"]:
                raise Exception(
                    "The config intervals don't match the number of samples. This shouldn't have happened."
                )
            self._intervals.append((start, end))

        self._length = sum(chunk["chunk_size"] for chunk in self._chunks)

        self._downloader = None

        if remote_dir:
            self._downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, self._chunks)

    def download_chunk_from_index(self, chunk_index: int) -> None:
        chunk_filename = self._chunks[chunk_index]["filename"]

        local_chunkpath = os.path.join(self._cache_dir, chunk_filename)

        # Skip the download when the chunk is already available locally.
        if os.path.exists(local_chunkpath):
            return

        if self._downloader is None:
            raise RuntimeError("The downloader should be defined.")

        self._downloader.download_chunk_from_index(chunk_index)

    @property
    def intervals(self) -> List[Tuple[int, int]]:
        if self._intervals is None:
            raise RuntimeError("The intervals should be defined.")
        return self._intervals

    @property
    def data_format(self) -> Any:
        if self._config is None:
            raise RuntimeError("The config should be defined.")
        return self._config["data_format"]

    @property
    def config(self) -> Dict[str, Any]:
        if self._config is None:
            raise RuntimeError("The config should be defined.")
        return self._config

    def _get_chunk_index_from_index(self, index: int) -> int:
        for chunk_index, interval in enumerate(self._intervals):
            if interval[0] <= index < interval[1]:
                return chunk_index
        raise ValueError(
            f"The provided index {index} didn't find a match within the chunk intervals {self._intervals}."
        )

    def __getitem__(self, index: ChunkedIndex) -> Tuple[str, int, int]:
        """Find the associated chunk metadata."""
        chunk = self._chunks[index.chunk_index]
        return os.path.join(self._cache_dir, chunk["filename"]), *self._intervals[index.chunk_index]

    @classmethod
    def load(cls, cache_dir: str, remote_dir: Optional[str] = None) -> Optional["ChunksConfig"]:
        cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

        # When a remote folder is given, fetch the index file into the local cache first.
        if isinstance(remote_dir, str):
            downloader = get_downloader_cls(remote_dir)(remote_dir, cache_dir, [])
            downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)

        if not os.path.exists(cache_index_filepath):
            return None

        return ChunksConfig(cache_dir, remote_dir)

    def __len__(self) -> int:
        return self._length
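`ChunksConfig` validates each chunk's interval against its `chunk_size` and then resolves an item index with a linear scan in `_get_chunk_index_from_index`. The sketch below shows the same two steps on plain dictionaries, with a binary search swapped in for the scan. The binary search is an alternative and not what the diff implements. It assumes the intervals are sorted and non-overlapping. The function names and chunk filenames are hypothetical.

```python
import bisect
from typing import Any, Dict, List, Tuple


def build_intervals(chunks: List[Dict[str, Any]]) -> List[Tuple[int, int]]:
    """Validate each chunk entry and return its half-open (start, end) interval."""
    intervals: List[Tuple[int, int]] = []
    for chunk in chunks:
        start, end = chunk["interval"]
        if end - start != chunk["chunk_size"]:
            raise ValueError(f"Interval {chunk['interval']} doesn't match chunk_size {chunk['chunk_size']}.")
        intervals.append((start, end))
    return intervals


def find_chunk_index(intervals: List[Tuple[int, int]], index: int) -> int:
    """Binary-search lookup; assumes `intervals` is sorted and non-overlapping."""
    starts = [start for start, _ in intervals]
    # Position of the last interval whose start is <= index.
    position = bisect.bisect_right(starts, index) - 1
    if position >= 0 and intervals[position][0] <= index < intervals[position][1]:
        return position
    raise ValueError(f"The provided index {index} didn't find a match within {intervals}.")


# Illustrative index content; the filenames are made up.
chunks = [
    {"filename": "chunk-0.bin", "interval": [0, 3], "chunk_size": 3},
    {"filename": "chunk-1.bin", "interval": [3, 5], "chunk_size": 2},
    {"filename": "chunk-2.bin", "interval": [5, 9], "chunk_size": 4},
]
intervals = build_intervals(chunks)
print(intervals)  # [(0, 3), (3, 5), (5, 9)]
print(find_chunk_index(intervals, 3))  # 1
print(find_chunk_index(intervals, 8))  # 2
```

The linear scan is simpler and fine for a handful of chunks. A binary search only starts to matter once a dataset has many thousands of chunks and lookups happen per item.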
21 changes: 21 additions & 0 deletions src/lightning/data/cache/constants.py
@@ -0,0 +1,21 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lightning_utilities.core.imports import RequirementCache

_INDEX_FILENAME = "index.json"
_DEFAULT_CHUNK_BYTES = 1 << 26  # 64 MiB

# This is required for full pytree serialization / deserialization support
_TORCH_2_1_0_AVAILABLE = RequirementCache("torch>=2.1.0")
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")