Merged
Commits
33 commits
e140f6a
Add support for direct upload to r2 buckets
Sep 4, 2025
32efe93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 4, 2025
903c719
Create a dedicated R2Client and move get_r2_bucket_credentials into t…
Sep 5, 2025
79fd5ea
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2025
62f8327
Move back to R2FsProvider inheriting from FsProvider to avoid self.cl…
Sep 5, 2025
46b8086
Use existing patterns when calling requests.post/get
Sep 5, 2025
2568398
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 5, 2025
5807a31
Rename lightning_data_connection_id to data_connection_id
Sep 5, 2025
b9c7a7c
Add tests
Sep 5, 2025
861916e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 8, 2025
34eef2a
Fix type in rebase
Sep 8, 2025
0d9f8e8
Prevent potential null pointer
Sep 8, 2025
4943755
Add test for lightning_storage resolver
Sep 8, 2025
23d4d61
Move code for constructing storage options into a util function
Sep 8, 2025
f006976
Add retry logic to client for fetching temp creds
Sep 8, 2025
09699ce
R2 client and fs provider should inherit from the S3 counterparts
Sep 8, 2025
97576ca
update
tchaton Sep 8, 2025
bcbee75
update
tchaton Sep 8, 2025
052d3f1
update
tchaton Sep 8, 2025
8891e0f
update
tchaton Sep 8, 2025
81de16e
update
tchaton Sep 8, 2025
932e284
update
tchaton Sep 8, 2025
5346a2a
update
tchaton Sep 8, 2025
9a59d89
update
tchaton Sep 8, 2025
e3e126a
update
tchaton Sep 8, 2025
fc89b31
update
tchaton Sep 8, 2025
09da903
update
tchaton Sep 8, 2025
a41e54a
update
tchaton Sep 8, 2025
7f5d6e6
update
tchaton Sep 8, 2025
e7cd014
update
tchaton Sep 8, 2025
d226cb7
update
tchaton Sep 8, 2025
04e690a
update
tchaton Sep 8, 2025
3336fe1
Address PR comments
Sep 8, 2025
8 changes: 4 additions & 4 deletions .github/workflows/ci-testing.yml
@@ -48,16 +48,16 @@ jobs:
- name: Run fast tests in parallel
run: |
pytest tests \
--ignore=tests/processing \
--ignore=tests/raw \
-n 2 --cov=litdata --durations=120
--ignore=tests/processing \
--ignore=tests/raw \
-n 2 --cov=litdata --durations=0 --timeout=120 --capture=no --verbose

- name: Run processing tests sequentially
run: |
# note that the listed test should match ignored in the previous step
pytest \
tests/processing tests/raw \
--cov=litdata --cov-append --durations=90
--cov=litdata --cov-append --durations=0 --timeout=120 --capture=no --verbose

- name: Statistics
continue-on-error: true
6 changes: 5 additions & 1 deletion src/litdata/constants.py
@@ -24,7 +24,11 @@
_DEFAULT_CACHE_DIR = os.path.join(Path.home(), ".lightning", "chunks")
_DEFAULT_LIGHTNING_CACHE_DIR = os.path.join("/cache", "chunks")
_LITDATA_CACHE_DIR = os.getenv("LITDATA_CACHE_DIR", None)
_SUPPORTED_PROVIDERS = ("s3", "gs") # cloud providers supported by litdata for uploading (optimize, map, merge, etc)
_SUPPORTED_PROVIDERS = (
"s3",
"gs",
"r2",
) # cloud providers supported by litdata for uploading (optimize, map, merge, etc)

# This is required for full pytree serialization / deserialization support
_TORCH_GREATER_EQUAL_2_1_0 = RequirementCache("torch>=2.1.0")
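Adding `"r2"` to `_SUPPORTED_PROVIDERS` is what lets `r2://` URLs pass the provider checks used throughout the processing code. A minimal sketch of that check, with a made-up URL:

```python
from urllib import parse

from litdata.constants import _SUPPORTED_PROVIDERS

# Hypothetical output location; any r2://bucket/prefix now counts as a supported provider.
url = "r2://my-bucket/optimized-dataset"
if parse.urlparse(url).scheme not in _SUPPORTED_PROVIDERS:
    raise ValueError(f"Unsupported provider for {url}")
```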
27 changes: 16 additions & 11 deletions src/litdata/processing/data_processor.py
@@ -45,7 +45,7 @@
_TQDM_AVAILABLE,
)
from litdata.processing.readers import BaseReader, StreamingDataLoaderReader
from litdata.processing.utilities import _create_dataset, remove_uuid_from_filename
from litdata.processing.utilities import _create_dataset, construct_storage_options, remove_uuid_from_filename
from litdata.streaming import Cache
from litdata.streaming.cache import Dir
from litdata.streaming.dataloader import StreamingDataLoader
@@ -168,7 +168,8 @@ def _download_data_target(
dirpath = os.path.dirname(local_path)
os.makedirs(dirpath, exist_ok=True)
if fs_provider is None:
fs_provider = _get_fs_provider(input_dir.url, storage_options)
merged_storage_options = construct_storage_options(storage_options, input_dir)
fs_provider = _get_fs_provider(input_dir.url, merged_storage_options)
fs_provider.download_file(path, local_path)

elif os.path.isfile(path):
@@ -233,7 +234,8 @@ def _upload_fn(
obj = parse.urlparse(output_dir.url if output_dir.url else output_dir.path)

if obj.scheme in _SUPPORTED_PROVIDERS:
fs_provider = _get_fs_provider(output_dir.url, storage_options)
merged_storage_options = construct_storage_options(storage_options, output_dir)
fs_provider = _get_fs_provider(output_dir.url, merged_storage_options)

while True:
data: Optional[Union[str, tuple[str, str]]] = upload_queue.get()
@@ -1022,7 +1024,8 @@ def _upload_index(self, output_dir: Dir, cache_dir: str, num_nodes: int, node_ra
local_filepath = os.path.join(cache_dir, _INDEX_FILENAME)

if obj.scheme in _SUPPORTED_PROVIDERS:
fs_provider = _get_fs_provider(output_dir.url, self.storage_options)
merged_storage_options = construct_storage_options(self.storage_options, output_dir)
fs_provider = _get_fs_provider(output_dir.url, merged_storage_options)
fs_provider.upload_file(
local_filepath,
os.path.join(output_dir.url, os.path.basename(local_filepath)),
@@ -1044,8 +1047,9 @@
remote_filepath = os.path.join(output_dir_path, f"{node_rank}-{_INDEX_FILENAME}")
node_index_filepath = os.path.join(cache_dir, os.path.basename(remote_filepath))
if obj.scheme in _SUPPORTED_PROVIDERS:
_wait_for_file_to_exist(remote_filepath, storage_options=self.storage_options)
fs_provider = _get_fs_provider(remote_filepath, self.storage_options)
merged_storage_options = construct_storage_options(self.storage_options, output_dir)
_wait_for_file_to_exist(remote_filepath, storage_options=merged_storage_options)
fs_provider = _get_fs_provider(remote_filepath, merged_storage_options)
fs_provider.download_file(remote_filepath, node_index_filepath)
elif output_dir.path and os.path.isdir(output_dir.path):
shutil.copyfile(remote_filepath, node_index_filepath)
@@ -1499,8 +1503,8 @@ def _cleanup_checkpoints(self) -> None:

prefix = self.output_dir.url.rstrip("/") + "/"
checkpoint_prefix = os.path.join(prefix, ".checkpoints")

fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options)
merged_storage_options = construct_storage_options(self.storage_options, self.output_dir)
fs_provider = _get_fs_provider(self.output_dir.url, merged_storage_options)
fs_provider.delete_file_or_directory(checkpoint_prefix)

def _save_current_config(self, workers_user_items: list[list[Any]]) -> None:
@@ -1529,8 +1533,8 @@ def _save_current_config(self, workers_user_items: list[list[Any]]) -> None:

if obj.scheme not in _SUPPORTED_PROVIDERS:
not_supported_provider(self.output_dir.url)

fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options)
merged_storage_options = construct_storage_options(self.storage_options, self.output_dir)
fs_provider = _get_fs_provider(self.output_dir.url, merged_storage_options)

prefix = self.output_dir.url.rstrip("/") + "/" + ".checkpoints/"

@@ -1601,7 +1605,8 @@ def _load_checkpoint_config(self, workers_user_items: list[list[Any]]) -> None:

# download all the checkpoint files in tempdir and read them
with tempfile.TemporaryDirectory() as temp_dir:
fs_provider = _get_fs_provider(self.output_dir.url, self.storage_options)
merged_storage_options = construct_storage_options(self.storage_options, self.output_dir)
fs_provider = _get_fs_provider(self.output_dir.url, merged_storage_options)
saved_file_dir = fs_provider.download_directory(prefix, temp_dir)

if not os.path.exists(os.path.join(saved_file_dir, "config.json")):
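Every touched call site in `data_processor.py` follows the same two-step pattern: merge the `Dir`'s connection id into the caller's `storage_options`, then resolve the filesystem provider from the merged dict. A rough sketch of the pattern with illustrative values (a `Dir` resolved from a lightning_storage path is assumed to carry a `data_connection_id` attribute):

```python
from litdata.processing.utilities import construct_storage_options
from litdata.streaming.cache import Dir
from litdata.streaming.fs_provider import _get_fs_provider

# Illustrative output directory and empty user-provided storage options.
output_dir = Dir(path="/cache/output", url="r2://my-bucket/output")
storage_options: dict = {}

merged_storage_options = construct_storage_options(storage_options, output_dir)
fs_provider = _get_fs_provider(output_dir.url, merged_storage_options)  # -> R2FsProvider for r2:// URLs
fs_provider.upload_file("/cache/output/index.json", "r2://my-bucket/output/index.json")
```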
7 changes: 7 additions & 0 deletions src/litdata/processing/utilities.py
@@ -272,3 +272,10 @@ def remove_uuid_from_filename(filepath: str) -> str:

# uuid is of 32 characters, '.json' is 5 characters and '-' is 1 character
return filepath[:-38] + ".json"


def construct_storage_options(storage_options: dict[str, Any], input_dir: Dir) -> dict[str, Any]:
merged_storage_options = storage_options.copy()
if hasattr(input_dir, "data_connection_id") and input_dir.data_connection_id:
merged_storage_options["data_connection_id"] = input_dir.data_connection_id
return merged_storage_options
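The helper copies the caller's options and layers the connection id on top, so the original dict is never mutated. A quick illustration of the expected behaviour with made-up values:

```python
from litdata.processing.utilities import construct_storage_options

# Stand-in for a resolved Dir; only the data_connection_id attribute matters here.
class _FakeDir:
    data_connection_id = "dc-123"

opts = {"endpoint_url": "https://example.com"}
merged = construct_storage_options(opts, _FakeDir())

assert merged == {"endpoint_url": "https://example.com", "data_connection_id": "dc-123"}
assert "data_connection_id" not in opts  # the input dict is left untouched
```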
141 changes: 141 additions & 0 deletions src/litdata/streaming/client.py
@@ -11,17 +11,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from time import time
from typing import Any, Optional

import boto3
import botocore
import requests
from botocore.credentials import InstanceMetadataProvider
from botocore.utils import InstanceMetadataFetcher
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from litdata.constants import _IS_IN_STUDIO

# Constants for the retry adapter. Docs: https://urllib3.readthedocs.io/en/stable/reference/urllib3.util.html
# Maximum number of total connection retry attempts (e.g., 2880 retries = 24 hours with 30s timeout per request)
_CONNECTION_RETRY_TOTAL = 2880
# Backoff factor for connection retries (wait time increases by this factor after each failure)
_CONNECTION_RETRY_BACKOFF_FACTOR = 0.5
# Default timeout for each HTTP request in seconds
_DEFAULT_REQUEST_TIMEOUT = 30 # seconds


class _CustomRetryAdapter(HTTPAdapter):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.timeout = kwargs.pop("timeout", _DEFAULT_REQUEST_TIMEOUT)
super().__init__(*args, **kwargs)

def send(self, request: Any, *args: Any, **kwargs: Any) -> Any:
kwargs["timeout"] = kwargs.get("timeout", self.timeout)
return super().send(request, **kwargs)


class S3Client:
# TODO: Generalize to support more cloud providers.
@@ -76,3 +98,122 @@ def client(self) -> Any:
self._last_time = time()

return self._client


class R2Client(S3Client):
"""R2 client with refreshable credentials for Cloudflare R2 storage."""

def __init__(
self,
refetch_interval: int = 3600, # 1 hour - this is the default refresh interval for R2 credentials
storage_options: Optional[dict] = {},
session_options: Optional[dict] = {},
) -> None:
# Store R2-specific options before calling super()
self._base_storage_options: dict = storage_options or {}

# Call parent constructor with R2-specific refetch interval
super().__init__(
refetch_interval=refetch_interval,
storage_options={}, # storage options handled in _create_client
session_options=session_options,
)

def get_r2_bucket_credentials(self, data_connection_id: str) -> dict[str, str]:
"""Fetch temporary R2 credentials for the current lightning storage connection."""
# Create session with retry logic
retry_strategy = Retry(
total=_CONNECTION_RETRY_TOTAL,
backoff_factor=_CONNECTION_RETRY_BACKOFF_FACTOR,
status_forcelist=[
408, # Request Timeout
429, # Too Many Requests
500, # Internal Server Error
502, # Bad Gateway
503, # Service Unavailable
504, # Gateway Timeout
],
)
adapter = _CustomRetryAdapter(max_retries=retry_strategy, timeout=_DEFAULT_REQUEST_TIMEOUT)
session = requests.Session()
session.mount("http://", adapter)
session.mount("https://", adapter)

try:
# Get Lightning Cloud API token
cloud_url = os.getenv("LIGHTNING_CLOUD_URL", "https://lightning.ai")
api_key = os.getenv("LIGHTNING_API_KEY")
username = os.getenv("LIGHTNING_USERNAME")
project_id = os.getenv("LIGHTNING_CLOUD_PROJECT_ID")

if not all([api_key, username, project_id]):
raise RuntimeError("Missing required environment variables")

# Login to get token
payload = {"apiKey": api_key, "username": username}
login_url = f"{cloud_url}/v1/auth/login"
response = session.post(login_url, data=json.dumps(payload))

if "token" not in response.json():
raise RuntimeError("Failed to get authentication token")

token = response.json()["token"]

# Get temporary bucket credentials
headers = {"Authorization": f"Bearer {token}", "Content-Type": "application/json"}
credentials_url = (
f"{cloud_url}/v1/projects/{project_id}/data-connections/{data_connection_id}/temp-bucket-credentials"
)

credentials_response = session.get(credentials_url, headers=headers, timeout=10)

if credentials_response.status_code != 200:
raise RuntimeError(f"Failed to get credentials: {credentials_response.status_code}")

temp_credentials = credentials_response.json()

endpoint_url = f"https://{temp_credentials['accountId']}.r2.cloudflarestorage.com"

# Format credentials for S3Client
return {
"aws_access_key_id": temp_credentials["accessKeyId"],
"aws_secret_access_key": temp_credentials["secretAccessKey"],
"aws_session_token": temp_credentials["sessionToken"],
"endpoint_url": endpoint_url,
}

except Exception as e:
# No fallback credentials are available for R2; log and re-raise so the caller can handle it
print(f"Failed to get R2 credentials from API: {e}")
raise RuntimeError(f"Failed to get R2 credentials and no fallback available: {e}")

def _create_client(self) -> None:
"""Create a new R2 client with fresh credentials."""
# Get data connection ID from storage options
data_connection_id = self._base_storage_options.get("data_connection_id")
if not data_connection_id:
raise RuntimeError("data_connection_id is required in storage_options for R2 client")

# Get fresh R2 credentials
r2_credentials = self.get_r2_bucket_credentials(data_connection_id)

# Filter out metadata keys that shouldn't be passed to boto3
filtered_storage_options = {
k: v for k, v in self._base_storage_options.items() if k not in ["data_connection_id"]
}

# Combine filtered storage options with fresh credentials
combined_storage_options = {**filtered_storage_options, **r2_credentials}

# Update the inherited storage options with R2 credentials
self._storage_options = combined_storage_options

# Create session and client
session = boto3.Session(**self._session_options)
self._client = session.client(
"s3",
**{
"config": botocore.config.Config(retries={"max_attempts": 1000, "mode": "adaptive"}),
**combined_storage_options,
},
)
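Since `R2Client` reuses `S3Client`'s lazy `.client` property and only overrides `_create_client`, callers pass the connection id through `storage_options` and get a boto3 client pointed at the account's Cloudflare R2 endpoint; once the refetch interval elapses, the next access is expected to rebuild the client with fresh temporary credentials. A hedged usage sketch (the connection id and bucket are placeholders, and the `LIGHTNING_*` environment variables must already be set):

```python
from litdata.streaming.client import R2Client

# Placeholder connection id; in practice it comes from the resolved lightning_storage Dir.
r2 = R2Client(storage_options={"data_connection_id": "dc-123"})

# The .client property builds the boto3 client on first use and refreshes it
# after the refetch_interval (1 hour by default for R2).
r2.client.list_objects_v2(Bucket="my-bucket", Prefix="optimized/")
```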
1 change: 1 addition & 0 deletions src/litdata/streaming/config.py
@@ -134,6 +134,7 @@ def download_chunk_from_index(self, chunk_index: int, skip_lock: bool = False) -

if os.path.exists(local_chunkpath):
self.try_decompress(local_chunkpath)

if self._downloader is not None and not skip_lock:
# We don't want to redownload the base, but we should mark
# it as having been requested by something
75 changes: 74 additions & 1 deletion src/litdata/streaming/fs_provider.py
@@ -16,7 +16,7 @@
from urllib import parse

from litdata.constants import _GOOGLE_STORAGE_AVAILABLE, _SUPPORTED_PROVIDERS
from litdata.streaming.client import S3Client
from litdata.streaming.client import R2Client, S3Client


class FsProvider(ABC):
@@ -224,6 +224,77 @@ def is_empty(self, path: str) -> bool:
return not objects["KeyCount"] > 0


class R2FsProvider(S3FsProvider):
def __init__(self, storage_options: Optional[dict[str, Any]] = {}):
super().__init__(storage_options=storage_options)

# Create R2Client with refreshable credentials
self.client = R2Client(storage_options=storage_options)

def upload_file(self, local_path: str, remote_path: str) -> None:
bucket_name, blob_path = get_bucket_and_path(remote_path, "r2")
self.client.client.upload_file(local_path, bucket_name, blob_path)

def download_file(self, remote_path: str, local_path: str) -> None:
bucket_name, blob_path = get_bucket_and_path(remote_path, "r2")
with open(local_path, "wb") as f:
self.client.client.download_fileobj(bucket_name, blob_path, f)

def download_directory(self, remote_path: str, local_directory_name: str) -> str:
"""Download all objects under a given S3 prefix (directory) using the existing client."""
bucket_name, remote_directory_name = get_bucket_and_path(remote_path, "r2")

# Ensure local directory exists
local_directory_name = os.path.abspath(local_directory_name)
os.makedirs(local_directory_name, exist_ok=True)

saved_file_dir = "."

# List objects under the given prefix
objects = self.client.client.list_objects_v2(Bucket=bucket_name, Prefix=remote_directory_name)

# Check if objects exist
if "Contents" in objects:
for obj in objects["Contents"]:
local_filename = os.path.join(local_directory_name, obj["Key"])

# Ensure parent directories exist
os.makedirs(os.path.dirname(local_filename), exist_ok=True)

# Download each file
with open(local_filename, "wb") as f:
self.client.client.download_fileobj(bucket_name, obj["Key"], f)
saved_file_dir = os.path.dirname(local_filename)

return saved_file_dir

def delete_file_or_directory(self, path: str) -> None:
"""Delete the file or the directory."""
bucket_name, blob_path = get_bucket_and_path(path, "r2")

# List objects under the given path
objects = self.client.client.list_objects_v2(Bucket=bucket_name, Prefix=blob_path)

# Check if objects exist
if "Contents" in objects:
for obj in objects["Contents"]:
self.client.client.delete_object(Bucket=bucket_name, Key=obj["Key"])

def exists(self, path: str) -> bool:
import botocore

bucket_name, blob_path = get_bucket_and_path(path, "r2")
try:
_ = self.client.client.head_object(Bucket=bucket_name, Key=blob_path)
return True
except botocore.exceptions.ClientError as e:
if "the HeadObject operation: Not Found" in str(e):
return False
raise e
except Exception as e:
raise e


def get_bucket_and_path(remote_filepath: str, expected_scheme: str = "s3") -> tuple[str, str]:
"""Parse the remote filepath and return the bucket name and the blob path.

@@ -259,6 +330,8 @@ def _get_fs_provider(remote_filepath: str, storage_options: Optional[dict[str, A
return GCPFsProvider(storage_options=storage_options)
if obj.scheme == "s3":
return S3FsProvider(storage_options=storage_options)
if obj.scheme == "r2":
return R2FsProvider(storage_options=storage_options)
raise ValueError(f"Unsupported scheme: {obj.scheme}")


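With the new branch in `_get_fs_provider`, `r2://` paths resolve to `R2FsProvider` exactly as `s3://` paths resolve to `S3FsProvider`. A short sketch, assuming the referenced data connection exists and the Lightning credentials are available in the environment:

```python
from litdata.streaming.fs_provider import _get_fs_provider

# Hypothetical bucket, prefix and connection id.
fs = _get_fs_provider("r2://my-bucket/dataset", {"data_connection_id": "dc-123"})

fs.upload_file("index.json", "r2://my-bucket/dataset/index.json")
print(fs.exists("r2://my-bucket/dataset/index.json"))   # True once the upload lands
fs.delete_file_or_directory("r2://my-bucket/dataset")   # removes every object under the prefix
```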