Add support for using the streaming dataloader in map or optimize for large scale inference (#19510)
tchaton committed Feb 22, 2024
1 parent 4175e1a commit eb0bbde
Showing 6 changed files with 108 additions and 19 deletions.
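
In short, `map` and `optimize` now accept a `StreamingDataLoader` directly as their `inputs`, so a pre-optimized dataset can be streamed back through a user function for large-scale inference. A minimal usage sketch, modeled on the new test at the bottom of this diff (the directories and the `run_inference` function are placeholders, not part of the commit):

import os

from lightning.data.processing.functions import map
from lightning.data.streaming import StreamingDataLoader, StreamingDataset


def run_inference(batch, output_dir):
    # `map` passes each dataloader batch to the function; here each batch is a
    # tensor of integer items produced by the StreamingDataset.
    for index in batch.numpy().tolist():
        with open(os.path.join(output_dir, f"{index}.txt"), "w") as f:
            f.write("Hello World!")


dataset = StreamingDataset(input_dir="/data/optimized_dataset")  # placeholder path

map(
    fn=run_inference,
    # Batching is configured on the dataloader itself, not via `map(batch_size=...)`.
    inputs=StreamingDataLoader(dataset, num_workers=1, batch_size=2),
    output_dir="/data/inference_outputs",  # placeholder path
    num_workers=2,
)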
13 changes: 6 additions & 7 deletions src/lightning/data/processing/data_processor.py
@@ -28,11 +28,12 @@
_LIGHTNING_CLOUD_LATEST,
_TORCH_GREATER_EQUAL_2_1_0,
)
from lightning.data.processing.readers import BaseReader
from lightning.data.processing.readers import BaseReader, StreamingDataLoaderReader
from lightning.data.processing.utilities import _create_dataset
from lightning.data.streaming import Cache
from lightning.data.streaming.cache import Dir
from lightning.data.streaming.client import S3Client
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.resolver import _resolve_dir
from lightning.data.utilities.broadcast import broadcast_object
from lightning.data.utilities.packing import _pack_greedily
@@ -65,11 +66,6 @@ def _get_fast_dev_run() -> int:
return bool(int(os.getenv("DATA_OPTIMIZER_FAST_DEV_RUN", 1)))


def _get_home_folder() -> str:
"""Returns whether cache folder for the filepaths."""
return os.getenv("DATA_OPTIMIZER_HOME_FOLDER", os.path.expanduser("~"))


def _get_default_cache() -> str:
return "/cache" if _IS_IN_STUDIO else tempfile.gettempdir()

@@ -892,9 +888,12 @@ def run(self, data_recipe: DataRecipe) -> None:
# Call the setup method of the user
user_items: List[Any] = data_recipe.prepare_structure(self.input_dir.path if self.input_dir else None)

if not isinstance(user_items, list):
if not isinstance(user_items, (list, StreamingDataLoader)):
raise ValueError("The `prepare_structure` should return a list of item metadata.")

if isinstance(user_items, StreamingDataLoader):
self.reader = StreamingDataLoaderReader(user_items)

if self.reader:
user_items = self.reader.remap_items(user_items, self.num_workers)

42 changes: 32 additions & 10 deletions src/lightning/data/processing/functions.py
@@ -26,6 +26,7 @@
from lightning.data.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe
from lightning.data.processing.readers import BaseReader
from lightning.data.processing.utilities import optimize_dns_context
from lightning.data.streaming.dataloader import StreamingDataLoader
from lightning.data.streaming.resolver import (
Dir,
_assert_dir_has_index_file,
@@ -176,6 +177,7 @@ def map(
inputs: A sequence of input to be processed by the `fn` function.
Each input should contain at least a valid filepath.
output_dir: The folder where the processed data should be stored.
weights: Provide an associated weight to each input. This is used to balance work among workers.
num_workers: The number of workers to use during processing
fast_dev_run: Whether to process only a subset of the inputs
num_nodes: When doing remote execution, the number of nodes to use. Only supported on https://lightning.ai/.
@@ -188,8 +190,14 @@ def map(
batch_size: Group the inputs into batches of batch_size length.
"""
if not isinstance(inputs, Sequence):
raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.")
if isinstance(inputs, StreamingDataLoader) and batch_size is not None:
raise ValueError("When providing a streaming dataloader, pass the batch_size to the dataloader directly.")

if isinstance(inputs, StreamingDataLoader) and weights is not None:
raise ValueError("When providing a streaming dataloader, weights isn't supported.")

if not isinstance(inputs, (Sequence, StreamingDataLoader)):
raise ValueError(f"The provided inputs should be non empty sequence or a streaming dataloader. Found {inputs}.")

if len(inputs) == 0:
raise ValueError(f"The provided inputs should be non empty. Found {inputs}.")
@@ -218,10 +226,13 @@ def map(
if error_when_not_empty:
_assert_dir_is_empty(_output_dir)

input_dir = _resolve_dir(_get_input_dir(inputs))
if not isinstance(inputs, StreamingDataLoader):
input_dir = _resolve_dir(_get_input_dir(inputs))

if isinstance(batch_size, int) and batch_size > 1:
inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
if isinstance(batch_size, int) and batch_size > 1:
inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
else:
input_dir = Dir()

data_processor = DataProcessor(
input_dir=input_dir,
@@ -247,6 +258,7 @@ def optimize(
fn: Callable[[Any], Any],
inputs: Sequence[Any],
output_dir: str,
weights: Optional[List[int]] = None,
chunk_size: Optional[int] = None,
chunk_bytes: Optional[Union[int, str]] = None,
compression: Optional[str] = None,
@@ -267,6 +279,7 @@
inputs: A sequence of input to be processed by the `fn` function.
Each input should contain at least a valid filepath.
output_dir: The folder where the processed data should be stored.
weights: Provide an associated weight to each input. This is used to balance work among workers.
chunk_size: The maximum number of elements to hold within a chunk.
chunk_bytes: The maximum number of bytes to hold within a chunk.
compression: The compression algorithm to use over the chunks.
@@ -281,8 +294,14 @@
batch_size: Group the inputs into batches of batch_size length.
"""
if not isinstance(inputs, Sequence):
raise ValueError(f"The provided inputs should be non empty sequence. Found {inputs}.")
if isinstance(inputs, StreamingDataLoader) and batch_size is not None:
raise ValueError("When providing a streaming dataloader, pass the batch_size to the dataloader directly.")

if isinstance(inputs, StreamingDataLoader) and weights is not None:
raise ValueError("When providing a streaming dataloader, weights isn't supported.")

if not isinstance(inputs, (Sequence, StreamingDataLoader)):
raise ValueError(f"The provided inputs should be non empty sequence or a streaming dataloader. Found {inputs}.")

if len(inputs) == 0:
raise ValueError(f"The provided inputs should be non empty. Found {inputs}.")
@@ -313,10 +332,13 @@ def optimize(

_assert_dir_has_index_file(_output_dir)

input_dir = _resolve_dir(_get_input_dir(inputs))
if not isinstance(inputs, StreamingDataLoader):
input_dir = _resolve_dir(_get_input_dir(inputs))

if isinstance(batch_size, int) and batch_size > 1:
inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
if isinstance(batch_size, int) and batch_size > 1:
inputs = [inputs[pos : pos + batch_size] for pos in range(0, len(inputs), batch_size)]
else:
input_dir = Dir()

data_processor = DataProcessor(
input_dir=input_dir,
19 changes: 18 additions & 1 deletion src/lightning/data/processing/readers.py
@@ -6,6 +6,8 @@
from lightning_utilities.core.imports import RequirementCache
from tqdm import tqdm

from lightning.data.streaming.dataloader import StreamingDataLoader

_PYARROW_AVAILABLE = RequirementCache("pyarrow")


@@ -17,7 +19,7 @@ def get_node_rank(self) -> int:
return int(os.getenv("DATA_OPTIMIZER_NODE_RANK", 0))

@abstractmethod
def remap_items(self, items: List[Any], num_workers: int) -> List[Any]:
def remap_items(self, items: Any, num_workers: int) -> List[Any]:
"""This method is meant to remap the items provided by the users into items more adapted to be distributed."""
pass

@@ -93,3 +95,18 @@ def remap_items(self, filepaths: List[str], _: int) -> List[str]:
print("Finished resharding the parquet files for optimized processing.")

return new_items


class StreamingDataLoaderReader(BaseReader):
def __init__(self, dataloader: StreamingDataLoader) -> None:
super().__init__()
self.dataloader = dataloader
self.dataloader_iter: Any = None

def read(self, _: int) -> Any:
if self.dataloader_iter is None:
self.dataloader_iter = iter(self.dataloader)
return next(self.dataloader_iter)

def remap_items(self, dataloader: StreamingDataLoader, _: int) -> List[Any]:
return list(range(len(dataloader)))
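
Taken together, the reader replaces the dataloader with one synthetic item per batch, and `read` then pulls batches from a single lazily created iterator, ignoring the item value itself. A small sketch of that behavior (the chunk directory is a placeholder):

from lightning.data.processing.readers import StreamingDataLoaderReader
from lightning.data.streaming import StreamingDataLoader, StreamingDataset

dataloader = StreamingDataLoader(
    StreamingDataset(input_dir="/data/optimized_dataset"),  # placeholder path
    batch_size=2,
)

reader = StreamingDataLoaderReader(dataloader)

# One integer index per batch; the DataProcessor distributes these like any item list.
items = reader.remap_items(dataloader, 2)
assert items == list(range(len(dataloader)))

# Each `read` call returns the next batch in order, regardless of the index passed.
first_batch = reader.read(items[0])
second_batch = reader.read(items[1])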
4 changes: 4 additions & 0 deletions src/lightning/data/streaming/dataset.py
@@ -162,6 +162,10 @@ def __len__(self) -> int:
return self.shuffler.get_len(self.distributed_env, self.current_epoch)

def __iter__(self) -> "StreamingDataset":
# When the StreamingDataset is used within map or optimize, let's refetch the distributed env.
if os.getenv("DATA_OPTIMIZER_GLOBAL_RANK"):
self.distributed_env = _DistributedEnv.detect()

self.worker_env = _WorkerEnv.detect()
self.cache = self._create_cache(worker_env=self.worker_env)
self.shuffler = self._create_shuffler(self.cache)
15 changes: 15 additions & 0 deletions src/lightning/data/utilities/env.py
@@ -1,3 +1,4 @@
import os
from typing import Callable, Optional

import torch
@@ -28,6 +29,9 @@ def detect(cls) -> "_DistributedEnv":
It will default to 1 distributed process in this case.
"""
if _is_in_map_or_optimize():
return cls._instantiate_in_map_or_optimize()

if torch.distributed.is_available() and torch.distributed.is_initialized():
world_size = torch.distributed.get_world_size()
global_rank = torch.distributed.get_rank()
@@ -45,6 +49,13 @@ def detect(cls) -> "_DistributedEnv":

return cls(world_size=world_size, global_rank=global_rank, num_nodes=num_nodes)

@classmethod
def _instantiate_in_map_or_optimize(cls) -> "_DistributedEnv":
global_rank = int(os.getenv("DATA_OPTIMIZER_GLOBAL_RANK", "0"))
num_workers = int(os.getenv("DATA_OPTIMIZER_NUM_WORKERS", "0"))
num_nodes = int(os.getenv("DATA_OPTIMIZER_NUM_NODES", 1))
return cls(world_size=num_workers * num_nodes, global_rank=int(global_rank), num_nodes=num_nodes)

def __repr__(self) -> str:
return f"{self.__class__.__name__}(world_size: {self.world_size}, global_rank: {self.global_rank}\n)"

@@ -165,3 +176,7 @@ def __str__(self) -> str:

def _is_in_dataloader_worker() -> bool:
return torch_get_worker_info() is not None


def _is_in_map_or_optimize() -> bool:
return os.getenv("DATA_OPTIMIZER_GLOBAL_RANK") is not None
34 changes: 33 additions & 1 deletion tests/tests_data/processing/test_data_processor.py
@@ -29,7 +29,7 @@
_wait_for_file_to_exist,
)
from lightning.data.processing.functions import LambdaDataTransformRecipe, map, optimize
from lightning.data.streaming import StreamingDataset, resolver
from lightning.data.streaming import StreamingDataLoader, StreamingDataset, resolver
from lightning.data.streaming.cache import Cache, Dir
from lightning_utilities.core.imports import RequirementCache

@@ -1158,3 +1158,35 @@ def test_to_path(tmpdir):

assert _to_path("/teamspace/studios/this_studio/a.png") == "/teamspace/studios/this_studio/a.png"
assert _to_path(filepath) == filepath


def fetch_from_dataset(batch, output_dir):
for index in batch.numpy().tolist():
filepath = os.path.join(output_dir, f"{index}.txt")
with open(filepath, "w") as f:
f.write("Hello World!")


@pytest.mark.skipif(sys.platform == "win32", reason="skip windows")
def test_streaming_dataset_in_map(tmpdir):
seed_everything(42)

output_dir = os.path.join(tmpdir, "output_dir")

cache = Cache(input_dir=str(tmpdir), chunk_size=10)
for i in range(107):
cache[i] = i

cache.done()
cache.merge()

dataset = StreamingDataset(input_dir=str(tmpdir))

map(
fn=fetch_from_dataset,
inputs=StreamingDataLoader(dataset, num_workers=1, batch_size=2),
output_dir=output_dir,
num_workers=2,
)

assert sorted(os.listdir(output_dir)) == sorted([f"{i}.txt" for i in range(107)])
