diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py
index 3114c5427..87928fa51 100644
--- a/src/litdata/streaming/serializers.py
+++ b/src/litdata/streaming/serializers.py
@@ -297,6 +297,7 @@ def can_serialize(self, _: Any) -> bool:
 
 class FileSerializer(Serializer):
     def serialize(self, filepath: str) -> Tuple[bytes, Optional[str]]:
+        print("FileSerializer will be removed in the future.")
         _, file_extension = os.path.splitext(filepath)
         with open(filepath, "rb") as f:
             file_extension = file_extension.replace(".", "").lower()
@@ -306,7 +307,9 @@ def deserialize(self, data: bytes) -> Any:
         return data
 
     def can_serialize(self, data: Any) -> bool:
-        return isinstance(data, str) and os.path.isfile(data)
+        # return isinstance(data, str) and os.path.isfile(data)
+        # FileSerializer will be removed in the future.
+        return False
 
 
 class VideoSerializer(Serializer):
diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py
index dec72d0f1..568910472 100644
--- a/src/litdata/streaming/writer.py
+++ b/src/litdata/streaming/writer.py
@@ -65,7 +65,9 @@ def __init__(
         """
         self._cache_dir = cache_dir
 
-        os.makedirs(self._cache_dir, exist_ok=True)
+        if not os.path.exists(self._cache_dir):
+            os.makedirs(self._cache_dir, exist_ok=True)
+
         if (isinstance(self._cache_dir, str) and not os.path.exists(self._cache_dir)) or self._cache_dir is None:
             raise FileNotFoundError(f"The provided cache directory `{self._cache_dir}` doesn't exist.")
 
diff --git a/tests/streaming/test_cache.py b/tests/streaming/test_cache.py
index ae338c884..9ad9c136b 100644
--- a/tests/streaming/test_cache.py
+++ b/tests/streaming/test_cache.py
@@ -95,7 +95,8 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None):
         original_data = dataset.data[i]
         assert cached_data["class"] == original_data["class"]
         original_array = PILToTensor()(Image.open(original_data["image"]))
-        assert torch.equal(original_array, cached_data["image"])
+        cached_array = PILToTensor()(Image.open(cached_data["image"]))
+        assert torch.equal(original_array, cached_array)
 
     if distributed_env.world_size == 1:
         indexes = []
@@ -129,7 +130,8 @@ def _cache_for_image_dataset(num_workers, tmpdir, fabric=None):
             original_data = dataset.data[i]
             assert cached_data["class"] == original_data["class"]
             original_array = PILToTensor()(Image.open(original_data["image"]))
-            assert torch.equal(original_array, cached_data["image"])
+            cached_array = PILToTensor()(Image.open(cached_data["image"]))
+            assert torch.equal(original_array, cached_array)
 
     streaming_dataset_iter = iter(streaming_dataset)
     for _ in streaming_dataset_iter:
diff --git a/tests/streaming/test_writer.py b/tests/streaming/test_writer.py
index 04ae209fc..bb198528f 100644
--- a/tests/streaming/test_writer.py
+++ b/tests/streaming/test_writer.py
@@ -159,7 +159,7 @@ def test_binary_writer_with_jpeg_filepath_and_int(tmpdir):
     cache_dir = os.path.join(tmpdir, "chunks")
    os.makedirs(cache_dir, exist_ok=True)
 
-    binary_writer = BinaryWriter(cache_dir, chunk_bytes=2 << 12)
+    binary_writer = BinaryWriter(cache_dir, chunk_size=7)  # each chunk will have 7 items
 
     imgs = []
 
@@ -172,23 +172,25 @@ def test_binary_writer_with_jpeg_filepath_and_int(tmpdir):
         imgs.append(img)
         binary_writer[i] = {"x": path, "y": i}
 
-    assert len(os.listdir(cache_dir)) == 24
+    assert len(os.listdir(cache_dir)) == 14  # 100 items / 7 items per chunk = 14 full chunks so far
     binary_writer.done()
     binary_writer.merge()
-    assert len(os.listdir(cache_dir)) == 26
+    assert len(os.listdir(cache_dir)) == 16  # 2 items in the last chunk, plus the index.json file
 
     with open(os.path.join(cache_dir, "index.json")) as f:
         data = json.load(f)
 
-    assert data["chunks"][0]["chunk_size"] == 4
-    assert data["chunks"][1]["chunk_size"] == 4
-    assert data["chunks"][-1]["chunk_size"] == 4
+    assert data["chunks"][0]["chunk_size"] == 7
+    assert data["chunks"][1]["chunk_size"] == 7
+    assert data["chunks"][-1]["chunk_size"] == 2
     assert sum([chunk["chunk_size"] for chunk in data["chunks"]]) == 100
 
     reader = BinaryReader(cache_dir, max_cache_size=10 ^ 9)
     for i in range(100):
-        data = reader.read(ChunkedIndex(i, chunk_index=i // 4))
-        np.testing.assert_array_equal(np.asarray(data["x"]).squeeze(0), imgs[i])
+        data = reader.read(ChunkedIndex(i, chunk_index=i // 7))
+        img_read = Image.open(data["x"])
+        print(f"{img_read.size=}")
+        np.testing.assert_array_equal(img_read, imgs[i])
         assert data["y"] == i
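
The deprecation notice in FileSerializer.serialize is emitted with a bare print. A common stdlib alternative, shown here only as a hedged sketch and not as what this diff does, is warnings.warn with DeprecationWarning:

    import warnings

    def serialize(self, filepath):
        # stacklevel=2 points the warning at the caller rather than this frame.
        warnings.warn(
            "FileSerializer will be removed in the future.",
            DeprecationWarning,
            stacklevel=2,
        )
        ...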
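
Both test_cache.py hunks apply the same fix: the cached sample now comes back as an image file path rather than a decoded tensor (consistent with FileSerializer being disabled above, though the diff does not spell out the mechanism), so the cached image must be decoded before comparison. The repeated pattern could be factored into a helper; a sketch with a hypothetical name, assert_same_image:

    import torch
    from PIL import Image
    from torchvision.transforms import PILToTensor

    def assert_same_image(original_path: str, cached_path: str) -> None:
        # Decode both files to uint8 tensors and require an exact pixel match.
        original_array = PILToTensor()(Image.open(original_path))
        cached_array = PILToTensor()(Image.open(cached_path))
        assert torch.equal(original_array, cached_array)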
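
For reference, the file counts asserted in test_binary_writer_with_jpeg_filepath_and_int follow directly from chunk_size=7. A minimal arithmetic sketch (plain Python, not part of the diff):

    num_items, chunk_size = 100, 7

    # 14 full chunks are flushed while writing; 2 items stay buffered until done().
    full_chunks, remainder = divmod(num_items, chunk_size)
    assert (full_chunks, remainder) == (14, 2)

    # done()/merge() flush the final partial chunk and write index.json.
    files_after_merge = full_chunks + (1 if remainder else 0) + 1
    assert files_after_merge == 16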