15 changes: 15 additions & 0 deletions README.md
@@ -209,6 +209,21 @@ for batch in dataloader:

```


Additionally, you can pass client connection settings for [S3](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session.client) or GCP when initializing your dataset. This is useful for specifying custom endpoints and credentials on a per-dataset basis.

```python
from litdata import StreamingDataset

storage_options = {
"endpoint_url": "your_endpoint_url",
"aws_access_key_id": "your_access_key_id",
"aws_secret_access_key": "your_secret_access_key",
}

dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options)
```
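
For GCS (`gs://` paths), the same `storage_options` argument is forwarded to the `google.cloud.storage.Client` constructor. A minimal sketch, assuming a `gs://` dataset and an illustrative `project` value:

```python
from litdata import StreamingDataset

# Keys in storage_options are passed through to google.cloud.storage.Client(**storage_options);
# "my-gcp-project" is a placeholder value.
storage_options = {"project": "my-gcp-project"}

dataset = StreamingDataset("gs://my-bucket/my-data", storage_options=storage_options)
```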

</details>

<details>
3 changes: 3 additions & 0 deletions src/litdata/streaming/cache.py
@@ -45,6 +45,7 @@ def __init__(
max_cache_size: Union[int, str] = "100GB",
serializers: Optional[Dict[str, Serializer]] = None,
writer_chunk_index: Optional[int] = None,
storage_options: Optional[Dict] = {},
):
"""The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
together in order to accelerate fetching.
@@ -60,6 +61,7 @@ def __init__(
max_cache_size: The maximum cache size used by the reader when fetching the chunks.
serializers: Provide your own serializers.
writer_chunk_index: The index of the chunk to start from when writing.
storage_options: Additional connection options for accessing storage services.

"""
super().__init__()
@@ -85,6 +87,7 @@ def __init__(
encryption=encryption,
item_loader=item_loader,
serializers=serializers,
storage_options=storage_options,
)
self._is_done = False
self._distributed_env = _DistributedEnv.detect()
11 changes: 8 additions & 3 deletions src/litdata/streaming/client.py
@@ -13,7 +13,7 @@

import os
from time import time
from typing import Any, Optional
from typing import Any, Dict, Optional

import boto3
import botocore
@@ -26,10 +26,11 @@
class S3Client:
# TODO: Generalize to support more cloud providers.

def __init__(self, refetch_interval: int = 3300) -> None:
def __init__(self, refetch_interval: int = 3300, storage_options: Optional[Dict] = {}) -> None:
self._refetch_interval = refetch_interval
self._last_time: Optional[float] = None
self._client: Optional[Any] = None
self._storage_options: dict = storage_options or {}

def _create_client(self) -> None:
has_shared_credentials_file = (
@@ -38,7 +39,11 @@ def _create_client(self) -> None:

if has_shared_credentials_file or not _IS_IN_STUDIO:
self._client = boto3.client(
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
"s3",
**{
"config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
**self._storage_options,
},
)
else:
provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5))
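
For context, a short sketch of how these per-dataset settings reach boto3 (the endpoint and credentials below are placeholders, not a prescribed configuration). `S3Client` keeps the default adaptive-retry `config` and merges `storage_options` on top, so a user-supplied `"config"` key would override the default:

```python
from litdata.streaming.client import S3Client

# Placeholder endpoint and credentials for a private S3-compatible store.
storage_options = {
    "endpoint_url": "https://my-minio.internal:9000",
    "aws_access_key_id": "my-access-key",
    "aws_secret_access_key": "my-secret-key",
}

s3 = S3Client(storage_options=storage_options)
# Accessing .client builds the underlying boto3 client with the default retry config
# plus the options above (a "config" key in storage_options would replace the default).
boto3_client = s3.client
```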
12 changes: 9 additions & 3 deletions src/litdata/streaming/config.py
@@ -33,6 +33,7 @@ def __init__(
item_loader: Optional[BaseItemLoader] = None,
subsampled_files: Optional[List[str]] = None,
region_of_interest: Optional[List[Tuple[int, int]]] = None,
storage_options: Optional[Dict] = {},
) -> None:
"""The ChunksConfig reads the index files associated a chunked dataset and enables to map an index to its
chunk.
@@ -44,6 +45,7 @@ def __init__(
The scheme needs to be added to the path.
subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file.
region_of_interest: List of tuples of {start,end} of region of interest for each chunk.
storage_options: Additional connection options for accessing storage services.

"""
self._cache_dir = cache_dir
@@ -52,6 +54,7 @@ def __init__(
self._chunks = None
self._remote_dir = remote_dir
self._item_loader = item_loader or PyTreeLoader()
self._storage_options = storage_options

# load data from `index.json` file
data = load_index_file(self._cache_dir)
@@ -75,7 +78,7 @@ def __init__(
self._downloader = None

if remote_dir:
self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks)
self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks, self._storage_options)

self._compressor_name = self._config["compression"]
self._compressor: Optional[Compressor] = None
@@ -234,17 +237,20 @@ def load(
item_loader: Optional[BaseItemLoader] = None,
subsampled_files: Optional[List[str]] = None,
region_of_interest: Optional[List[Tuple[int, int]]] = None,
storage_options: Optional[dict] = {},
) -> Optional["ChunksConfig"]:
cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

if isinstance(remote_dir, str):
downloader = get_downloader_cls(remote_dir, cache_dir, [])
downloader = get_downloader_cls(remote_dir, cache_dir, [], storage_options)
downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath)

if not os.path.exists(cache_index_filepath):
return None

return ChunksConfig(cache_dir, serializers, remote_dir, item_loader, subsampled_files, region_of_interest)
return ChunksConfig(
cache_dir, serializers, remote_dir, item_loader, subsampled_files, region_of_interest, storage_options
)

def __len__(self) -> int:
return self._length
6 changes: 5 additions & 1 deletion src/litdata/streaming/dataset.py
@@ -54,6 +54,7 @@ def __init__(
max_cache_size: Union[int, str] = "100GB",
subsample: float = 1.0,
encryption: Optional[Encryption] = None,
storage_options: Optional[Dict] = {},
) -> None:
"""The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class.

@@ -70,6 +71,7 @@ def __init__(
max_cache_size: The maximum cache size used by the StreamingDataset.
subsample: Float representing fraction of the dataset to be randomly sampled (e.g., 0.1 => 10% of dataset).
encryption: The encryption object to use for decrypting the data.
storage_options: Additional connection options for accessing storage services.

"""
super().__init__()
@@ -85,7 +87,7 @@ def __init__(
self.subsampled_files: List[str] = []
self.region_of_interest: List[Tuple[int, int]] = []
self.subsampled_files, self.region_of_interest = subsample_streaming_dataset(
self.input_dir, item_loader, subsample, shuffle, seed
self.input_dir, item_loader, subsample, shuffle, seed, storage_options
)

self.item_loader = item_loader
@@ -128,6 +130,7 @@ def __init__(
self.num_workers: int = 1
self.batch_size: int = 1
self._encryption = encryption
self.storage_options = storage_options

def set_shuffle(self, shuffle: bool) -> None:
self.shuffle = shuffle
@@ -163,6 +166,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache:
serializers=self.serializers,
max_cache_size=self.max_cache_size,
encryption=self._encryption,
storage_options=self.storage_options,
)
cache._reader._try_load_config()

29 changes: 19 additions & 10 deletions src/litdata/streaming/downloader.py
@@ -15,7 +15,7 @@
import shutil
import subprocess
from abc import ABC
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from urllib import parse

from filelock import FileLock, Timeout
@@ -25,10 +25,13 @@


class Downloader(ABC):
def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
def __init__(
self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
):
self._remote_dir = remote_dir
self._cache_dir = cache_dir
self._chunks = chunks
self._storage_options = storage_options or {}

def download_chunk_from_index(self, chunk_index: int) -> None:
chunk_filename = self._chunks[chunk_index]["filename"]
@@ -41,12 +44,14 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:


class S3Downloader(Downloader):
def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
super().__init__(remote_dir, cache_dir, chunks)
def __init__(
self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
):
super().__init__(remote_dir, cache_dir, chunks, storage_options)
self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0

if not self._s5cmd_available:
self._client = S3Client()
self._client = S3Client(storage_options=self._storage_options)

def download_file(self, remote_filepath: str, local_filepath: str) -> None:
obj = parse.urlparse(remote_filepath)
@@ -88,11 +93,13 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:


class GCPDownloader(Downloader):
def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
def __init__(
self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
):
if not _GOOGLE_STORAGE_AVAILABLE:
raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE))

super().__init__(remote_dir, cache_dir, chunks)
super().__init__(remote_dir, cache_dir, chunks, storage_options)

def download_file(self, remote_filepath: str, local_filepath: str) -> None:
from google.cloud import storage
@@ -113,7 +120,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
if key[0] == "/":
key = key[1:]

client = storage.Client()
client = storage.Client(**self._storage_options)
bucket = client.bucket(bucket_name)
blob = bucket.blob(key)
blob.download_to_filename(local_filepath)
@@ -140,8 +147,10 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
_DOWNLOADERS = {"s3://": S3Downloader, "gs://": GCPDownloader, "local:": LocalDownloaderWithCache, "": LocalDownloader}


def get_downloader_cls(remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]) -> Downloader:
def get_downloader_cls(
remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
) -> Downloader:
for k, cls in _DOWNLOADERS.items():
if str(remote_dir).startswith(k):
return cls(remote_dir, cache_dir, chunks)
return cls(remote_dir, cache_dir, chunks, storage_options)
raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.")
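
A small sketch of the resolution path above (the bucket, cache directory, and file names are hypothetical): `get_downloader_cls` picks the downloader class from the URI prefix and now threads `storage_options` into it.

```python
from litdata.streaming.downloader import get_downloader_cls

# "s3://" selects S3Downloader; a "gs://" remote_dir would select GCPDownloader instead.
downloader = get_downloader_cls(
    "s3://my-bucket/optimized-data",              # hypothetical remote directory
    "/tmp/litdata-cache",                         # hypothetical local cache directory
    [],                                           # no chunk metadata needed to fetch a single file
    {"endpoint_url": "https://custom.endpoint"},  # forwarded to the underlying S3 client
)
downloader.download_file(
    "s3://my-bucket/optimized-data/index.json",
    "/tmp/litdata-cache/index.json",
)
```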
4 changes: 4 additions & 0 deletions src/litdata/streaming/reader.py
@@ -169,6 +169,7 @@ def __init__(
encryption: Optional[Encryption] = None,
item_loader: Optional[BaseItemLoader] = None,
serializers: Optional[Dict[str, Serializer]] = None,
storage_options: Optional[dict] = {},
) -> None:
"""The BinaryReader enables to read chunked dataset in an efficient way.

@@ -183,6 +184,7 @@ def __init__(
item_loader: The chunk sampler to create sub arrays from a chunk.
max_cache_size: The maximum cache size used by the reader when fetching the chunks.
serializers: Provide your own serializers.
storage_options: Additional connection options for accessing storage services.

"""
super().__init__()
@@ -207,6 +209,7 @@ def __init__(
self._item_loader = item_loader or PyTreeLoader()
self._last_chunk_index: Optional[int] = None
self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size or 0))
self._storage_options = storage_options

def _get_chunk_index_from_index(self, index: int) -> Tuple[int, int]:
# Load the config containing the index
@@ -224,6 +227,7 @@ def _try_load_config(self) -> Optional[ChunksConfig]:
self._item_loader,
self.subsampled_files,
self.region_of_interest,
self._storage_options,
)
return self._config

3 changes: 2 additions & 1 deletion src/litdata/utilities/dataset_utilities.py
@@ -19,6 +19,7 @@ def subsample_streaming_dataset(
subsample: float = 1.0,
shuffle: bool = False,
seed: int = 42,
storage_options: Optional[Dict] = {},
) -> Tuple[List[str], List[Tuple[int, int]]]:
"""Subsample streaming dataset.

@@ -46,7 +47,7 @@
# Check if `index.json` file exists in cache path
if not os.path.exists(cache_index_filepath) and isinstance(input_dir.url, str):
assert input_dir.url is not None
downloader = get_downloader_cls(input_dir.url, input_dir.path, [])
downloader = get_downloader_cls(input_dir.url, input_dir.path, [], storage_options)
downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), cache_index_filepath)

if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):
32 changes: 32 additions & 0 deletions tests/streaming/test_client.py
@@ -6,6 +6,38 @@
from litdata.streaming import client


def test_s3_client_with_storage_options(monkeypatch):
boto3 = mock.MagicMock()
monkeypatch.setattr(client, "boto3", boto3)

botocore = mock.MagicMock()
monkeypatch.setattr(client, "botocore", botocore)

storage_options = {
"region_name": "us-west-2",
"endpoint_url": "https://custom.endpoint",
"config": botocore.config.Config(retries={"max_attempts": 100}),
}
s3_client = client.S3Client(storage_options=storage_options)

assert s3_client.client

boto3.client.assert_called_with(
"s3",
region_name="us-west-2",
endpoint_url="https://custom.endpoint",
config=botocore.config.Config(retries={"max_attempts": 100}),
)

s3_client = client.S3Client()

assert s3_client.client

boto3.client.assert_called_with(
"s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
)


def test_s3_client_without_cloud_space_id(monkeypatch):
boto3 = mock.MagicMock()
monkeypatch.setattr(client, "boto3", boto3)
4 changes: 3 additions & 1 deletion tests/streaming/test_downloader.py
@@ -36,11 +36,13 @@ def test_gcp_downloader(tmpdir, monkeypatch, google_mock):
mock_bucket.blob = MagicMock(return_value=mock_blob)

# Initialize the downloader
downloader = GCPDownloader("gs://random_bucket", tmpdir, [])
storage_options = {"project": "DUMMY_PROJECT"}
downloader = GCPDownloader("gs://random_bucket", tmpdir, [], storage_options)
local_filepath = os.path.join(tmpdir, "a.txt")
downloader.download_file("gs://random_bucket/a.txt", local_filepath)

# Assert that the correct methods were called
google_mock.cloud.storage.Client.assert_called_with(**storage_options)
mock_client.bucket.assert_called_with("random_bucket")
mock_bucket.blob.assert_called_with("a.txt")
mock_blob.download_to_filename.assert_called_with(local_filepath)