5 changes: 4 additions & 1 deletion src/litdata/streaming/serializers.py
@@ -297,6 +297,7 @@ def can_serialize(self, _: Any) -> bool:
 
 class FileSerializer(Serializer):
     def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
+        print("FileSerializer will be removed in the future.")
         _, file_extension = os.path.splitext(filepath)
         with open(filepath, "rb") as f:
             file_extension = file_extension.replace(".", "").lower()
@@ -306,7 +307,9 @@ def deserialize(self, data: bytes) -> Any:
         return data
 
     def can_serialize(self, data: Any) -> bool:
-        return isinstance(data, str) and os.path.isfile(data)
+        # return isinstance(data, str) and os.path.isfile(data)
+        # FileSerializer will be removed in the future.
+        return False
 
 
 class VideoSerializer(Serializer):
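A minimal sketch of what returning False from can_serialize means in practice, assuming serializers are tried in order and the first one that accepts a value wins (the registry and selection loop below are illustrative stand-ins, not litdata's actual code): a string filepath now skips FileSerializer and falls through to the next candidate.

from typing import Any, List, Tuple


class NoOpFileSerializer:
    """Mirrors the patched FileSerializer: it never volunteers for a value."""

    def can_serialize(self, data: Any) -> bool:
        return False


class FallbackSerializer:
    """Illustrative catch-all serializer, only here to show the fall-through."""

    def can_serialize(self, data: Any) -> bool:
        return True


def pick_serializer(item: Any, serializers: List[Tuple[str, Any]]) -> str:
    # Hypothetical first-match selection; litdata's real registry may differ.
    for name, serializer in serializers:
        if serializer.can_serialize(item):
            return name
    raise ValueError(f"no serializer accepts {item!r}")


chosen = pick_serializer("photo.jpg", [("file", NoOpFileSerializer()), ("fallback", FallbackSerializer())])
assert chosen == "fallback"  # FileSerializer is skipped because can_serialize() now returns False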
4 changes: 3 additions & 1 deletion src/litdata/streaming/writer.py
@@ -65,7 +65,9 @@ def __init__(
 
         """
         self._cache_dir = cache_dir
-        os.makedirs(self._cache_dir, exist_ok=True)
+        if not os.path.exists(self._cache_dir):
+            os.makedirs(self._cache_dir, exist_ok=True)
+
         if (isinstance(self._cache_dir, str) and not os.path.exists(self._cache_dir)) or self._cache_dir is None:
             raise FileNotFoundError(f"The provided cache directory `{self._cache_dir}` doesn't exist.")
 
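For reference, a tiny standalone sketch of the guarded directory creation for a string path (the helper name and the temporary-directory usage are illustrative, not part of litdata): the directory is created on first use and left untouched afterwards.

import os
import tempfile


def ensure_cache_dir(cache_dir: str) -> None:
    # Same shape as the patched BinaryWriter.__init__: only call makedirs
    # when the directory is not already there.
    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir, exist_ok=True)


with tempfile.TemporaryDirectory() as root:
    cache_dir = os.path.join(root, "chunks")
    ensure_cache_dir(cache_dir)  # creates <root>/chunks
    ensure_cache_dir(cache_dir)  # no-op: the directory already exists
    assert os.path.isdir(cache_dir)

Behaviourally this matches the previous unconditional os.makedirs(..., exist_ok=True), which already tolerates an existing directory; the guard only skips the redundant call.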
6 changes: 4 additions & 2 deletions tests/streaming/test_cache.py
@@ -95,7 +95,8 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None):
         original_data = dataset.data[i]
         assert cached_data["class"] == original_data["class"]
         original_array = PILToTensor()(Image.open(original_data["image"]))
-        assert torch.equal(original_array, cached_data["image"])
+        cached_array = PILToTensor()(Image.open(cached_data["image"]))
+        assert torch.equal(original_array, cached_array)
 
     if distributed_env.world_size == 1:
         indexes = []
@@ -129,7 +130,8 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None):
         original_data = dataset.data[i]
         assert cached_data["class"] == original_data["class"]
         original_array = PILToTensor()(Image.open(original_data["image"]))
-        assert torch.equal(original_array, cached_data["image"])
+        cached_array = PILToTensor()(Image.open(cached_data["image"]))
+        assert torch.equal(original_array, cached_array)
 
     streaming_dataset_iter = iter(streaming_dataset)
     for _ in streaming_dataset_iter:
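The updated assertions compare images by decoded content rather than comparing a tensor against a file path: both the original and the cached entry are opened with PIL and converted to tensors first. A self-contained sketch of that comparison pattern (the file names and the throwaway image are made up for the example):

import os
import tempfile

import torch
from PIL import Image
from torchvision.transforms import PILToTensor


def assert_same_image(original_path: str, cached_path: str) -> None:
    # Decode both files and compare pixel tensors, as the updated test does.
    original_array = PILToTensor()(Image.open(original_path))
    cached_array = PILToTensor()(Image.open(cached_path))
    assert torch.equal(original_array, cached_array)


with tempfile.TemporaryDirectory() as root:
    img = Image.new("RGB", (8, 8), color=(255, 0, 0))
    a, b = os.path.join(root, "a.png"), os.path.join(root, "b.png")
    img.save(a)
    img.save(b)
    assert_same_image(a, b)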
18 changes: 10 additions & 8 deletions tests/streaming/test_writer.py
@@ -159,7 +159,7 @@ def test_binary_writer_with_jpeg_filepath_and_int(tmpdir):
 
     cache_dir = os.path.join(tmpdir, "chunks")
     os.makedirs(cache_dir, exist_ok=True)
-    binary_writer = BinaryWriter(cache_dir, chunk_bytes=2 << 12)
+    binary_writer = BinaryWriter(cache_dir, chunk_size=7)  # each chunk will have 7 items
 
     imgs = []
 
@@ -172,23 +172,25 @@ def test_binary_writer_with_jpeg_filepath_and_int(tmpdir):
         imgs.append(img)
         binary_writer[i] = {"x": path, "y": i}
 
-    assert len(os.listdir(cache_dir)) == 24
+    assert len(os.listdir(cache_dir)) == 14  # 100 items / 7 items per chunk = 14 chunks
     binary_writer.done()
     binary_writer.merge()
-    assert len(os.listdir(cache_dir)) == 26
+    assert len(os.listdir(cache_dir)) == 16  # 2 items in last chunk and index.json file
 
     with open(os.path.join(cache_dir, "index.json")) as f:
         data = json.load(f)
 
-    assert data["chunks"][0]["chunk_size"] == 4
-    assert data["chunks"][1]["chunk_size"] == 4
-    assert data["chunks"][-1]["chunk_size"] == 4
+    assert data["chunks"][0]["chunk_size"] == 7
+    assert data["chunks"][1]["chunk_size"] == 7
+    assert data["chunks"][-1]["chunk_size"] == 2
     assert sum([chunk["chunk_size"] for chunk in data["chunks"]]) == 100
 
     reader = BinaryReader(cache_dir, max_cache_size=10 ^ 9)
     for i in range(100):
-        data = reader.read(ChunkedIndex(i, chunk_index=i // 4))
-        np.testing.assert_array_equal(np.asarray(data["x"]).squeeze(0), imgs[i])
+        data = reader.read(ChunkedIndex(i, chunk_index=i // 7))
+        img_read = Image.open(data["x"])
+        print(f"{img_read.size=}")
+        np.testing.assert_array_equal(img_read, imgs[i])
         assert data["y"] == i
 
 
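The expected file counts follow from the chunking arithmetic the assertions above imply: 14 full chunks of 7 items are flushed during the writing loop, done() flushes the remaining 2-item chunk, and merge() adds index.json. A plain-Python sketch of that bookkeeping (no litdata calls):

num_items, chunk_size = 100, 7

full_chunks = num_items // chunk_size  # 14 chunks of 7 items flushed during the writing loop
leftover = num_items % chunk_size      # 2 items still buffered until done() is called
assert (full_chunks, leftover) == (14, 2)

files_before_done = full_chunks            # 14 chunk files on disk -> first listdir assertion
files_after_merge = full_chunks + 1 + 1    # + the 2-item chunk (done) + index.json (merge)
assert files_before_done == 14
assert files_after_merge == 16

chunk_sizes = [chunk_size] * full_chunks + [leftover]
assert chunk_sizes[0] == 7 and chunk_sizes[1] == 7 and chunk_sizes[-1] == 2
assert sum(chunk_sizes) == 100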