diff --git a/README.md b/README.md
index 6ae91bea3..155837597 100644
--- a/README.md
+++ b/README.md
@@ -209,6 +209,21 @@ for batch in dataloader:
 ```
+
+Additionally, you can inject client connection settings for [S3](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html#boto3.session.Session.client) or GCP when initializing your dataset. This is useful for specifying custom endpoints and credentials per dataset.
+
+```python
+from litdata import StreamingDataset
+
+storage_options = {
+    "endpoint_url": "your_endpoint_url",
+    "aws_access_key_id": "your_access_key_id",
+    "aws_secret_access_key": "your_secret_access_key",
+}
+
+dataset = StreamingDataset('s3://my-bucket/my-data', storage_options=storage_options)
+```
+
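The README example above covers S3 only. Because the new `GCPDownloader` forwards `storage_options` straight to `google.cloud.storage.Client`, the same pattern should work for GCS-backed datasets. The sketch below is illustrative only; the project name and service-account key path are placeholders, not part of this patch.

```python
from google.oauth2 import service_account  # from the google-auth package

from litdata import StreamingDataset

# Placeholder project/key values; every key in storage_options is forwarded
# as a keyword argument to google.cloud.storage.Client(...).
storage_options = {
    "project": "your-gcp-project",
    "credentials": service_account.Credentials.from_service_account_file("path/to/key.json"),
}

dataset = StreamingDataset("gs://my-bucket/my-data", storage_options=storage_options)
```
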
diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py
index 17cf5534a..e647db216 100644
--- a/src/litdata/streaming/cache.py
+++ b/src/litdata/streaming/cache.py
@@ -45,6 +45,7 @@ def __init__(
         max_cache_size: Union[int, str] = "100GB",
         serializers: Optional[Dict[str, Serializer]] = None,
         writer_chunk_index: Optional[int] = None,
+        storage_options: Optional[Dict] = {},
     ):
         """The Cache enables to optimise dataset format for cloud training. This is done by grouping several elements
         together in order to accelerate fetching.
@@ -60,6 +61,7 @@ def __init__(
             max_cache_size: The maximum cache size used by the reader when fetching the chunks.
             serializers: Provide your own serializers.
             writer_chunk_index: The index of the chunk to start from when writing.
+            storage_options: Additional connection options for accessing storage services.

         """
         super().__init__()
@@ -85,6 +87,7 @@ def __init__(
             encryption=encryption,
             item_loader=item_loader,
             serializers=serializers,
+            storage_options=storage_options,
         )
         self._is_done = False
         self._distributed_env = _DistributedEnv.detect()
diff --git a/src/litdata/streaming/client.py b/src/litdata/streaming/client.py
index 6185e189d..d24803c3c 100644
--- a/src/litdata/streaming/client.py
+++ b/src/litdata/streaming/client.py
@@ -13,7 +13,7 @@
 import os
 from time import time
-from typing import Any, Optional
+from typing import Any, Dict, Optional

 import boto3
 import botocore

@@ -26,10 +26,11 @@
 class S3Client:
     # TODO: Generalize to support more cloud providers.

-    def __init__(self, refetch_interval: int = 3300) -> None:
+    def __init__(self, refetch_interval: int = 3300, storage_options: Optional[Dict] = {}) -> None:
         self._refetch_interval = refetch_interval
         self._last_time: Optional[float] = None
         self._client: Optional[Any] = None
+        self._storage_options: dict = storage_options or {}

     def _create_client(self) -> None:
         has_shared_credentials_file = (
@@ -38,7 +39,11 @@ def _create_client(self) -> None:

         if has_shared_credentials_file or not _IS_IN_STUDIO:
             self._client = boto3.client(
-                "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
+                "s3",
+                **{
+                    "config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
+                    **self._storage_options,
+                },
             )
         else:
             provider = InstanceMetadataProvider(iam_role_fetcher=InstanceMetadataFetcher(timeout=3600, num_attempts=5))
diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py
index b9744f0ac..d62bc13a0 100644
--- a/src/litdata/streaming/config.py
+++ b/src/litdata/streaming/config.py
@@ -33,6 +33,7 @@ def __init__(
         item_loader: Optional[BaseItemLoader] = None,
         subsampled_files: Optional[List[str]] = None,
         region_of_interest: Optional[List[Tuple[int, int]]] = None,
+        storage_options: Optional[Dict] = {},
     ) -> None:
         """The ChunksConfig reads the index files associated a chunked dataset and enables to map an index to its
         chunk.

@@ -44,6 +45,7 @@ def __init__(
                 The scheme needs to be added to the path.
             subsampled_files: List of subsampled chunk files loaded from `input_dir/index.json` file.
             region_of_interest: List of tuples of {start,end} of region of interest for each chunk.
+            storage_options: Additional connection options for accessing storage services.
""" self._cache_dir = cache_dir @@ -52,6 +54,7 @@ def __init__( self._chunks = None self._remote_dir = remote_dir self._item_loader = item_loader or PyTreeLoader() + self._storage_options = storage_options # load data from `index.json` file data = load_index_file(self._cache_dir) @@ -75,7 +78,7 @@ def __init__( self._downloader = None if remote_dir: - self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks) + self._downloader = get_downloader_cls(remote_dir, cache_dir, self._chunks, self._storage_options) self._compressor_name = self._config["compression"] self._compressor: Optional[Compressor] = None @@ -234,17 +237,20 @@ def load( item_loader: Optional[BaseItemLoader] = None, subsampled_files: Optional[List[str]] = None, region_of_interest: Optional[List[Tuple[int, int]]] = None, + storage_options: Optional[dict] = {}, ) -> Optional["ChunksConfig"]: cache_index_filepath = os.path.join(cache_dir, _INDEX_FILENAME) if isinstance(remote_dir, str): - downloader = get_downloader_cls(remote_dir, cache_dir, []) + downloader = get_downloader_cls(remote_dir, cache_dir, [], storage_options) downloader.download_file(os.path.join(remote_dir, _INDEX_FILENAME), cache_index_filepath) if not os.path.exists(cache_index_filepath): return None - return ChunksConfig(cache_dir, serializers, remote_dir, item_loader, subsampled_files, region_of_interest) + return ChunksConfig( + cache_dir, serializers, remote_dir, item_loader, subsampled_files, region_of_interest, storage_options + ) def __len__(self) -> int: return self._length diff --git a/src/litdata/streaming/dataset.py b/src/litdata/streaming/dataset.py index 69526a370..b56300cbe 100644 --- a/src/litdata/streaming/dataset.py +++ b/src/litdata/streaming/dataset.py @@ -54,6 +54,7 @@ def __init__( max_cache_size: Union[int, str] = "100GB", subsample: float = 1.0, encryption: Optional[Encryption] = None, + storage_options: Optional[Dict] = {}, ) -> None: """The streaming dataset can be used once your data have been optimised using the DatasetOptimiser class. @@ -70,6 +71,7 @@ def __init__( max_cache_size: The maximum cache size used by the StreamingDataset. subsample: Float representing fraction of the dataset to be randomly sampled (e.g., 0.1 => 10% of dataset). encryption: The encryption object to use for decrypting the data. + storage_options: Additional connection options for accessing storage services. 
""" super().__init__() @@ -85,7 +87,7 @@ def __init__( self.subsampled_files: List[str] = [] self.region_of_interest: List[Tuple[int, int]] = [] self.subsampled_files, self.region_of_interest = subsample_streaming_dataset( - self.input_dir, item_loader, subsample, shuffle, seed + self.input_dir, item_loader, subsample, shuffle, seed, storage_options ) self.item_loader = item_loader @@ -128,6 +130,7 @@ def __init__( self.num_workers: int = 1 self.batch_size: int = 1 self._encryption = encryption + self.storage_options = storage_options def set_shuffle(self, shuffle: bool) -> None: self.shuffle = shuffle @@ -163,6 +166,7 @@ def _create_cache(self, worker_env: _WorkerEnv) -> Cache: serializers=self.serializers, max_cache_size=self.max_cache_size, encryption=self._encryption, + storage_options=self.storage_options, ) cache._reader._try_load_config() diff --git a/src/litdata/streaming/downloader.py b/src/litdata/streaming/downloader.py index eba85c5e9..c5138b997 100644 --- a/src/litdata/streaming/downloader.py +++ b/src/litdata/streaming/downloader.py @@ -15,7 +15,7 @@ import shutil import subprocess from abc import ABC -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from urllib import parse from filelock import FileLock, Timeout @@ -25,10 +25,13 @@ class Downloader(ABC): - def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]): + def __init__( + self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} + ): self._remote_dir = remote_dir self._cache_dir = cache_dir self._chunks = chunks + self._storage_options = storage_options or {} def download_chunk_from_index(self, chunk_index: int) -> None: chunk_filename = self._chunks[chunk_index]["filename"] @@ -41,12 +44,14 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None: class S3Downloader(Downloader): - def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]): - super().__init__(remote_dir, cache_dir, chunks) + def __init__( + self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} + ): + super().__init__(remote_dir, cache_dir, chunks, storage_options) self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0 if not self._s5cmd_available: - self._client = S3Client() + self._client = S3Client(storage_options=self._storage_options) def download_file(self, remote_filepath: str, local_filepath: str) -> None: obj = parse.urlparse(remote_filepath) @@ -88,11 +93,13 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: class GCPDownloader(Downloader): - def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]): + def __init__( + self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {} + ): if not _GOOGLE_STORAGE_AVAILABLE: raise ModuleNotFoundError(str(_GOOGLE_STORAGE_AVAILABLE)) - super().__init__(remote_dir, cache_dir, chunks) + super().__init__(remote_dir, cache_dir, chunks, storage_options) def download_file(self, remote_filepath: str, local_filepath: str) -> None: from google.cloud import storage @@ -113,7 +120,7 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None: if key[0] == "/": key = key[1:] - client = storage.Client() + client = storage.Client(**self._storage_options) bucket = client.bucket(bucket_name) blob = bucket.blob(key) blob.download_to_filename(local_filepath) @@ -140,8 +147,10 @@ def 
 _DOWNLOADERS = {"s3://": S3Downloader, "gs://": GCPDownloader, "local:": LocalDownloaderWithCache, "": LocalDownloader}


-def get_downloader_cls(remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]) -> Downloader:
+def get_downloader_cls(
+    remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]], storage_options: Optional[Dict] = {}
+) -> Downloader:
     for k, cls in _DOWNLOADERS.items():
         if str(remote_dir).startswith(k):
-            return cls(remote_dir, cache_dir, chunks)
+            return cls(remote_dir, cache_dir, chunks, storage_options)
     raise ValueError(f"The provided `remote_dir` {remote_dir} doesn't have a downloader associated.")
diff --git a/src/litdata/streaming/reader.py b/src/litdata/streaming/reader.py
index 0361f5737..9b59ce7b0 100644
--- a/src/litdata/streaming/reader.py
+++ b/src/litdata/streaming/reader.py
@@ -169,6 +169,7 @@ def __init__(
         encryption: Optional[Encryption] = None,
         item_loader: Optional[BaseItemLoader] = None,
         serializers: Optional[Dict[str, Serializer]] = None,
+        storage_options: Optional[dict] = {},
     ) -> None:
         """The BinaryReader enables to read chunked dataset in an efficient way.

@@ -183,6 +184,7 @@ def __init__(
             item_loader: The chunk sampler to create sub arrays from a chunk.
             max_cache_size: The maximum cache size used by the reader when fetching the chunks.
             serializers: Provide your own serializers.
+            storage_options: Additional connection options for accessing storage services.

         """
         super().__init__()
@@ -207,6 +209,7 @@ def __init__(
         self._item_loader = item_loader or PyTreeLoader()
         self._last_chunk_index: Optional[int] = None
         self._max_cache_size = int(os.getenv("MAX_CACHE_SIZE", max_cache_size or 0))
+        self._storage_options = storage_options

     def _get_chunk_index_from_index(self, index: int) -> Tuple[int, int]:
         # Load the config containing the index
@@ -224,6 +227,7 @@ def _try_load_config(self) -> Optional[ChunksConfig]:
                 self._item_loader,
                 self.subsampled_files,
                 self.region_of_interest,
+                self._storage_options,
             )
         return self._config
diff --git a/src/litdata/utilities/dataset_utilities.py b/src/litdata/utilities/dataset_utilities.py
index a45b2f514..57e690ff5 100644
--- a/src/litdata/utilities/dataset_utilities.py
+++ b/src/litdata/utilities/dataset_utilities.py
@@ -19,6 +19,7 @@ def subsample_streaming_dataset(
     subsample: float = 1.0,
     shuffle: bool = False,
     seed: int = 42,
+    storage_options: Optional[Dict] = {},
 ) -> Tuple[List[str], List[Tuple[int, int]]]:
     """Subsample streaming dataset.
@@ -46,7 +47,7 @@ def subsample_streaming_dataset(
     # Check if `index.json` file exists in cache path
     if not os.path.exists(cache_index_filepath) and isinstance(input_dir.url, str):
         assert input_dir.url is not None
-        downloader = get_downloader_cls(input_dir.url, input_dir.path, [])
+        downloader = get_downloader_cls(input_dir.url, input_dir.path, [], storage_options)
         downloader.download_file(os.path.join(input_dir.url, _INDEX_FILENAME), cache_index_filepath)

     if os.path.exists(os.path.join(input_dir.path, _INDEX_FILENAME)):
diff --git a/tests/streaming/test_client.py b/tests/streaming/test_client.py
index 8af8e080d..78ea919d2 100644
--- a/tests/streaming/test_client.py
+++ b/tests/streaming/test_client.py
@@ -6,6 +6,38 @@
 from litdata.streaming import client


+def test_s3_client_with_storage_options(monkeypatch):
+    boto3 = mock.MagicMock()
+    monkeypatch.setattr(client, "boto3", boto3)
+
+    botocore = mock.MagicMock()
+    monkeypatch.setattr(client, "botocore", botocore)
+
+    storage_options = {
+        "region_name": "us-west-2",
+        "endpoint_url": "https://custom.endpoint",
+        "config": botocore.config.Config(retries={"max_attempts": 100}),
+    }
+    s3_client = client.S3Client(storage_options=storage_options)
+
+    assert s3_client.client
+
+    boto3.client.assert_called_with(
+        "s3",
+        region_name="us-west-2",
+        endpoint_url="https://custom.endpoint",
+        config=botocore.config.Config(retries={"max_attempts": 100}),
+    )
+
+    s3_client = client.S3Client()
+
+    assert s3_client.client
+
+    boto3.client.assert_called_with(
+        "s3", config=botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})
+    )
+
+
 def test_s3_client_without_cloud_space_id(monkeypatch):
     boto3 = mock.MagicMock()
     monkeypatch.setattr(client, "boto3", boto3)
diff --git a/tests/streaming/test_downloader.py b/tests/streaming/test_downloader.py
index 1c2d34a3a..686ba4cbe 100644
--- a/tests/streaming/test_downloader.py
+++ b/tests/streaming/test_downloader.py
@@ -36,11 +36,13 @@ def test_gcp_downloader(tmpdir, monkeypatch, google_mock):
     mock_bucket.blob = MagicMock(return_value=mock_blob)

     # Initialize the downloader
-    downloader = GCPDownloader("gs://random_bucket", tmpdir, [])
+    storage_options = {"project": "DUMMY_PROJECT"}
+    downloader = GCPDownloader("gs://random_bucket", tmpdir, [], storage_options)
     local_filepath = os.path.join(tmpdir, "a.txt")
     downloader.download_file("gs://random_bucket/a.txt", local_filepath)

     # Assert that the correct methods were called
+    google_mock.cloud.storage.Client.assert_called_with(**storage_options)
     mock_client.bucket.assert_called_with("random_bucket")
     mock_bucket.blob.assert_called_with("a.txt")
     mock_blob.download_to_filename.assert_called_with(local_filepath)
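One behaviour worth noting, exercised by `test_s3_client_with_storage_options` above: `S3Client._create_client` builds its boto3 keyword arguments as `{**defaults, **storage_options}`, so a caller-supplied key such as `config` replaces the default retry configuration rather than being merged with it. The standalone sketch below only illustrates that dict-unpacking rule; the endpoint URL is a placeholder.

```python
import botocore.config

# Defaults used by S3Client, with the caller's storage_options unpacked last;
# when a key appears in both dicts, the later (caller-provided) value wins.
default_kwargs = {"config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"})}
storage_options = {
    "endpoint_url": "https://custom.endpoint",  # placeholder endpoint
    "config": botocore.config.Config(retries={"max_attempts": 100}),
}

merged = {**default_kwargs, **storage_options}
assert merged["config"] is storage_options["config"]  # caller-provided config replaces the default
assert merged["endpoint_url"] == "https://custom.endpoint"
```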