Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,16 @@ outputs = optimize(
)
```

## Network Drive On-Prem Support

On-prem compute nodes can mount and use a network drive. To reduce the load on that drive, `StreamingDataset` supports caching the chunks locally: prefix the input directory with `local:` to enable it.

```python
from lightning.data import StreamingDataset

dataset = StreamingDataset(input_dir="local:/data/shared-drive/some-data")
```

# ⚡ Contributors

We welcome any contributions, pull requests, or issues. If you use the Streaming Dataset for your own project, please reach out to us on Slack or Discord.
8 changes: 7 additions & 1 deletion litdata/streaming/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,13 @@ def download_file(self, remote_filepath: str, local_filepath: str) -> None:
shutil.copy(remote_filepath, local_filepath)


_DOWNLOADERS = {"s3://": S3Downloader, "": LocalDownloader}
class LocalDownloaderWithCache(LocalDownloader):
    """Downloader for file paths that carry the ``local:`` scheme prefix.

    The prefix is removed and the remaining path is handed to
    ``LocalDownloader.download_file``.
    """

    def download_file(self, remote_filepath: str, local_filepath: str) -> None:
        """Fetch ``remote_filepath`` into ``local_filepath`` via the parent class.

        Args:
            remote_filepath: Source path, optionally prefixed with ``local:``
                (for example ``local:/data/shared-drive/chunk.bin``).
            local_filepath: Destination path, forwarded unchanged.
        """
        prefix = "local:"
        # Strip only the leading scheme marker. A blanket ``str.replace`` would
        # also delete any later "local:" that is part of the real path.
        if remote_filepath.startswith(prefix):
            remote_filepath = remote_filepath[len(prefix):]
        super().download_file(remote_filepath, local_filepath)


# Registry mapping a path prefix to the downloader class that handles it.
# NOTE(review): the empty prefix "" matches every string, so it presumably has to
# stay last if the lookup is a first-match prefix scan — confirm against get_downloader_cls.
_DOWNLOADERS = {"s3://": S3Downloader, "local:": LocalDownloaderWithCache, "": LocalDownloader}


def get_downloader_cls(remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]) -> Downloader:
Expand Down
3 changes: 3 additions & 0 deletions litdata/streaming/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def _resolve_dir(dir_path: Optional[Union[str, Dir]]) -> Dir:
if dir_path.startswith("s3://"):
return Dir(path=None, url=dir_path)

if dir_path.startswith("local:"):
return Dir(path=None, url=dir_path)

dir_path = _resolve_time_template(dir_path)

dir_path_absolute = str(Path(dir_path).absolute().resolve())
Expand Down
17 changes: 16 additions & 1 deletion tests/streaming/test_downloader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from unittest.mock import MagicMock

from litdata.streaming.downloader import S3Downloader, subprocess
from litdata.streaming.downloader import LocalDownloaderWithCache, S3Downloader, shutil, subprocess


def test_s3_downloader_fast(tmpdir, monkeypatch):
Expand All @@ -11,3 +11,18 @@ def test_s3_downloader_fast(tmpdir, monkeypatch):
downloader = S3Downloader(tmpdir, tmpdir, [])
downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt"))
popen_mock.wait.assert_called()


def test_download_with_cache(tmpdir, monkeypatch):
    """The ``local:`` prefix is stripped before the copy is delegated to ``shutil.copy``."""
    # Keep the fixture file inside pytest's tmpdir rather than the current working
    # directory: nothing leaks into the checkout and no manual cleanup is needed.
    source_path = os.path.join(tmpdir, "source_a.txt")
    with open(source_path, "w") as f:
        f.write("hello")
    # Distinct, not-yet-existing destination so the copy is actually attempted.
    destination_path = os.path.join(tmpdir, "a.txt")

    shutil_mock = MagicMock()
    monkeypatch.setattr(shutil, "copy", shutil_mock)

    local_downloader = LocalDownloaderWithCache(tmpdir, tmpdir, [])
    local_downloader.download_file(f"local:{source_path}", destination_path)

    # Verify the prefix was removed, not merely that some copy happened.
    shutil_mock.assert_called_once_with(source_path, destination_path)