Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .github/workflows/ci-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,15 @@ jobs:
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-

- name: Install package & dependencies on Ubuntu
if: matrix.os == 'ubuntu-latest'
run: |
pip --version
pip install -e '.[extras]' -r requirements/test.txt -U -q --find-links $TORCH_URL
pip list

- name: Install package & dependencies
if: matrix.os != 'ubuntu-latest'
run: |
pip --version
pip install -e . -r requirements/test.txt -U -q --find-links $TORCH_URL
Expand Down
8 changes: 1 addition & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
lightning-cloud == 0.5.65 # Must be pinned to ensure compatibility
lightning-utilities >=0.8.0, <0.12.0
torch >=2.1.0
filelock
tqdm
numpy
torchvision
pillow
viztracer
pyarrow
boto3[crt]
requests
6 changes: 6 additions & 0 deletions requirements/extras.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
torchvision
pillow
viztracer
pyarrow
tqdm
lightning-cloud == 0.5.65 # Must be pinned to ensure compatibility
1 change: 1 addition & 0 deletions requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ pytest-rerunfailures ==14.0
pytest-random-order ==1.1.1
pandas
lightning
lightning-cloud == 0.5.65 # Must be pinned to ensure compatibility
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

_PATH_ROOT = os.path.dirname(__file__)
_PATH_SOURCE = os.path.join(_PATH_ROOT, "src")
_PATH_REQUIRES = os.path.join(_PATH_ROOT, "_requirements")
_PATH_REQUIRES = os.path.join(_PATH_ROOT, "requirements")


def _load_py_module(fname, pkg="litdata"):
Expand Down
14 changes: 13 additions & 1 deletion src/litdata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
from lightning_utilities.core.imports import RequirementCache
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from litdata.__about__ import * # noqa: F403
from litdata.imports import RequirementCache
from litdata.processing.functions import map, optimize, walk
from litdata.streaming.combined import CombinedStreamingDataset
from litdata.streaming.dataloader import StreamingDataLoader
Expand Down
5 changes: 3 additions & 2 deletions src/litdata/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@

import numpy as np
import torch
from lightning_utilities.core.imports import RequirementCache

from litdata.imports import RequirementCache

_INDEX_FILENAME = "index.json"
_DEFAULT_CHUNK_BYTES = 1 << 26 # 64M B
Expand All @@ -26,7 +27,7 @@
# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
_VIZ_TRACKER_AVAILABLE = RequirementCache("viztracer")
_LIGHTNING_CLOUD_LATEST = RequirementCache("lightning-cloud>=0.5.64")
_LIGHTNING_CLOUD_AVAILABLE = RequirementCache("lightning-cloud")
_BOTO3_AVAILABLE = RequirementCache("boto3")
_TORCH_AUDIO_AVAILABLE = RequirementCache("torchaudio")
_ZSTD_AVAILABLE = RequirementCache("zstd")
Expand Down
121 changes: 121 additions & 0 deletions src/litdata/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from functools import lru_cache
from importlib.util import find_spec
from typing import Optional, TypeVar

import pkg_resources
from typing_extensions import ParamSpec

T = TypeVar("T")
P = ParamSpec("P")


@lru_cache
def package_available(package_name: str) -> bool:
    """Report whether ``package_name`` can be located by the import system.

    The lookup goes through ``importlib.util.find_spec`` and the answer is memoized per name by ``lru_cache``.

    Args:
        package_name: Name of the package (or dotted module path) to look up.

    Returns:
        ``True`` if a spec was found, ``False`` if ``find_spec`` returned ``None`` or raised
        ``ModuleNotFoundError``.

    >>> package_available('os')
    True
    >>> package_available('bla')
    False

    """
    try:
        spec = find_spec(package_name)
    except ModuleNotFoundError:
        # Raised by ``find_spec`` for dotted names when a parent package cannot be imported.
        return False
    return spec is not None


@lru_cache
def module_available(module_path: str) -> bool:
    """Report whether the dotted ``module_path`` can actually be imported.

    The top-level package is first checked with ``package_available``; only if it is found is the full
    module path imported. The answer is memoized per path by ``lru_cache``.

    Args:
        module_path: Dotted path of the module to test, e.g. ``"os.path"``.

    Returns:
        ``True`` if the import succeeded, ``False`` if the top-level package is missing or the import
        raised ``ImportError``.

    >>> module_available('os')
    True
    >>> module_available('os.bla')
    False
    >>> module_available('bla.bla')
    False

    """
    top_level_name, _, _ = module_path.partition(".")
    if package_available(top_level_name):
        try:
            importlib.import_module(module_path)
        except ImportError:
            return False
        return True
    return False


class RequirementCache:
    """Lazy, boolean-like check for an installed requirement or an importable module.

    Nothing is checked at construction time. The first conversion to ``bool`` or ``str`` runs the check
    and stores the outcome on the instance (``available`` / ``message``); later conversions reuse it.

    Args:
        requirement: The requirement to check, version specifiers are allowed.
        module: The optional module to try to import if the requirement check fails.

    >>> RequirementCache("torch>=0.1")
    Requirement 'torch>=0.1' met
    >>> bool(RequirementCache("torch>=0.1"))
    True
    >>> bool(RequirementCache("torch>100.0"))
    False
    >>> RequirementCache("torch")
    Requirement 'torch' met
    >>> bool(RequirementCache("torch"))
    True
    >>> bool(RequirementCache("unknown_package"))
    False

    """

    def __init__(self, requirement: str, module: Optional[str] = None) -> None:
        self.requirement = requirement
        self.module = module

    def _check_requirement(self) -> None:
        """Run the availability check once and store ``available`` and ``message`` on the instance."""
        # The presence of ``available`` marks that the check has already been performed.
        if hasattr(self, "available"):
            return

        try:
            # first ask pkg_resources whether the requirement is satisfied
            pkg_resources.require(self.requirement)
        except Exception as ex:
            self.available = False
            self.message = f"{ex.__class__.__name__}: {ex}. HINT: Try running `pip install -U {self.requirement!r}`"

            has_version_specifier = any(symbol in self.requirement for symbol in "=<>")
            if has_version_specifier and self.module is None:
                # A versioned requirement without an explicit module has no import fallback.
                return

            fallback_module = self.requirement if self.module is None else self.module
            # sometimes `pkg_resources.require()` fails but the module is importable
            self.available = module_available(fallback_module)
            if self.available:
                self.message = f"Module {fallback_module!r} available"
        else:
            self.available = True
            self.message = f"Requirement {self.requirement!r} met"

    def __bool__(self) -> bool:
        """Return whether the requirement (or its fallback module) is available."""
        self._check_requirement()
        return self.available

    def __str__(self) -> str:
        """Return the human-readable outcome of the check."""
        self._check_requirement()
        return self.message

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return str(self)
12 changes: 12 additions & 0 deletions src/litdata/processing/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
49 changes: 35 additions & 14 deletions src/litdata/processing/data_processor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,16 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import concurrent
import json
import logging
Expand All @@ -19,16 +32,16 @@

import numpy as np
import torch
from tqdm.auto import tqdm as _tqdm

from litdata.constants import (
_BOTO3_AVAILABLE,
_DEFAULT_FAST_DEV_RUN_ITEMS,
_INDEX_FILENAME,
_IS_IN_STUDIO,
_LIGHTNING_CLOUD_LATEST,
_LIGHTNING_CLOUD_AVAILABLE,
_TORCH_GREATER_EQUAL_2_1_0,
)
from litdata.imports import RequirementCache
from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
from litdata.processing.utilities import _create_dataset
from litdata.streaming import Cache
Expand All @@ -39,10 +52,15 @@
from litdata.utilities.broadcast import broadcast_object
from litdata.utilities.packing import _pack_greedily

_TQDM_AVAILABLE = RequirementCache("tqdm")

if _TQDM_AVAILABLE:
from tqdm.auto import tqdm as _tqdm

if _TORCH_GREATER_EQUAL_2_1_0:
from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads

if _LIGHTNING_CLOUD_LATEST:
if _LIGHTNING_CLOUD_AVAILABLE:
from lightning_cloud.openapi import V1DatasetType


Expand Down Expand Up @@ -944,15 +962,16 @@ def run(self, data_recipe: DataRecipe) -> None:
print("Workers are ready ! Starting data processing...")

current_total = 0
pbar = _tqdm(
desc="Progress",
total=num_items,
smoothing=0,
position=-1,
mininterval=1,
leave=True,
dynamic_ncols=True,
)
if _TQDM_AVAILABLE:
pbar = _tqdm(
desc="Progress",
total=num_items,
smoothing=0,
position=-1,
mininterval=1,
leave=True,
dynamic_ncols=True,
)
num_nodes = _get_num_nodes()
node_rank = _get_node_rank()
total_num_items = len(user_items)
Expand All @@ -970,7 +989,8 @@ def run(self, data_recipe: DataRecipe) -> None:
self.workers_tracker[index] = counter
new_total = sum(self.workers_tracker.values())

pbar.update(new_total - current_total)
if _TQDM_AVAILABLE:
pbar.update(new_total - current_total)

current_total = new_total
if current_total == num_items:
Expand All @@ -985,7 +1005,8 @@ def run(self, data_recipe: DataRecipe) -> None:
if all(not w.is_alive() for w in self.workers):
raise RuntimeError("One of the worker has failed")

pbar.close()
if _TQDM_AVAILABLE:
pbar.close()

# TODO: Understand why it hangs.
if num_nodes == 1:
Expand Down
27 changes: 23 additions & 4 deletions src/litdata/processing/readers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,33 @@
# Copyright The Lightning AI team.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import os
from abc import ABC, abstractmethod
from typing import Any, List

from lightning_utilities.core.imports import RequirementCache
from tqdm import tqdm

from litdata.imports import RequirementCache
from litdata.streaming.dataloader import StreamingDataLoader

_PYARROW_AVAILABLE = RequirementCache("pyarrow")
_TQDM_AVAILABLE = RequirementCache("tqdm")

# tqdm is an optional dependency: use the real progress bar when it is installed,
# otherwise fall back to a silent pass-through with a compatible call shape.
if _TQDM_AVAILABLE:
    from tqdm.auto import tqdm as _tqdm
else:

    def _tqdm(iterator: Any, *_args: Any, **_kwargs: Any) -> Any:
        """Stand-in for ``tqdm`` used when the optional dependency is missing.

        Yields the items of ``iterator`` unchanged and displays no progress bar.

        Args:
            iterator: The iterable to pass through.
            *_args: Extra positional arguments, accepted and ignored.
            **_kwargs: Extra keyword arguments (e.g. ``desc``, ``total``), accepted and ignored so that
                call sites can pass tqdm options without checking whether tqdm is installed.

        """
        yield from iterator


class BaseReader(ABC):
Expand Down Expand Up @@ -79,7 +98,7 @@ def remap_items(self, filepaths: List[str], _: int) -> List[str]:
table = None
parquet_filename = os.path.basename(filepath)

for start in tqdm(range(0, num_rows, self.num_rows)):
for start in _tqdm(range(0, num_rows, self.num_rows)):
end = min(start + self.num_rows, num_rows)
chunk_filepath = os.path.join(cache_folder, f"{start}_{end}_{parquet_filename}")
new_items.append(chunk_filepath)
Expand Down
Loading