
ValueError with tree_unflatten when trying to read local cache #62

@mads-oestergaard

Description


🐛 Bug

For some reason, I get this error when iterating over a local dataset. Each sample contains the raw bytes of a WAV file (read with audio_path.read_bytes(), where audio_path is a pathlib.Path object) and a filename (str). I optimize it with:

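import os
from io import BytesIO
from pathlib import Path

import torch
import torchaudio
from torch.utils.data import DataLoader
from tqdm import tqdm
from litdata import optimize, StreamingDataset


def fn(path: Path):
    # One sample per WAV file: the raw file bytes plus the file name without extension.
    return {"data": path.read_bytes(), "filename": path.stem}


class Dataset(StreamingDataset):
    def __getitem__(self, idx: int) -> tuple[torch.Tensor, int, str]:
        obj = super().__getitem__(idx)
        # Decode the stored WAV bytes back into a waveform and its sample rate.
        x, sr = torchaudio.load(BytesIO(obj["data"]))
        return x, sr, obj["filename"]


if __name__ == "__main__":
    # Python path APIs do not expand "~" on their own, so expand it explicitly.
    output_dir = os.path.expanduser("~/my-special-place")
    input_dir = Path("path/to/wav/files")
    inputs = list(input_dir.rglob("*.wav"))

    optimize(fn, inputs, output_dir, chunk_bytes="64MB", num_workers=4)

    dataset = Dataset("local:" + output_dir)
    dataloader = DataLoader(dataset, batch_size=1, num_workers=4)
    for x, sr, filename in tqdm(dataloader, total=len(inputs), desc="Verifying"):
        pass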
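For context on the error that follows: each sample that fn returns is a dict with two leaves ("data" and "filename"), and the traceback shows litdata rebuilding samples with tree_unflatten(data, self._config["data_spec"]). The sketch below is a simplified, pure-Python stand-in for that round trip, not the torch.utils._pytree or litdata implementation; flatten_sample and unflatten_sample are hypothetical helpers, and the bytes are a dummy placeholder. It only illustrates why an empty leaf list cannot be rebuilt into a spec that expects two items.

```python
from typing import Any


def flatten_sample(sample: dict[str, Any]) -> tuple[list[Any], list[str]]:
    """Split a flat dict into its leaf values and a 'spec' (the ordered keys)."""
    keys = list(sample.keys())
    leaves = [sample[k] for k in keys]
    return leaves, keys


def unflatten_sample(leaves: list[Any], spec: list[str]) -> dict[str, Any]:
    """Rebuild the dict, failing loudly if the leaf count does not match the spec."""
    if len(leaves) != len(spec):
        raise ValueError(
            f"`leaves` has length {len(leaves)} but the spec refers to "
            f"a structure that holds {len(spec)} items ({spec})"
        )
    return dict(zip(spec, leaves))


# Dummy sample shaped like the ones produced by fn (placeholder bytes, not real audio).
sample = {"data": b"RIFF....WAVE", "filename": "clip_001"}
leaves, spec = flatten_sample(sample)
assert unflatten_sample(leaves, spec) == sample

# The situation the traceback reports: deserialization produced no leaves at all.
try:
    unflatten_sample([], spec)
except ValueError as err:
    print(err)
```

In other words, the spec stored at optimize time is fine; the problem is that the bytes read back for an item yielded nothing to fill it with.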

However, iterating over it always fails with this ValueError:

ValueError: Caught ValueError in DataLoader worker process 3.
Original Traceback (most recent call last):
  File "/home/mads/Repos/dev-mads/my-repo/.venv/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/mads/Repos/dev-mads/my-repo/.venv/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 32, in fetch
    data.append(next(self.dataset_iter))
  File "/home/mads/Repos/dev-mads/my-repo/.venv/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 298, in __next__
    data = self.__getitem__(
  File "/home/mads/Repos/dev-mads/my-repo/scripts/upload_data_for_streaming-litdata.py", line 20, in __getitem__
    obj = super().__getitem__(idx)
  File "/home/mads/Repos/dev-mads/my-repo/.venv/lib/python3.10/site-packages/litdata/streaming/dataset.py", line 268, in __getitem__
    return self.cache[index]
  File "/home/mads/Repos/dev-mads/my-repo/.venv/lib/python3.10/site-packages/litdata/streaming/cache.py", line 135, in __getitem__
    return self._reader.read(index)
  File "/home/mads/Repos/dev-mads/my-repo/.venv/lib/python3.10/site-packages/litdata/streaming/reader.py", line 252, in read
    item = self._item_loader.load_item_from_chunk(index.index, index.chunk_index, chunk_filepath, begin)
  File "/home/mads/Repos/dev-mads/my-repo/.venv/lib/python3.10/site-packages/litdata/streaming/item_loader.py", line 106, in load_item_from_chunk
    return self.deserialize(data)
  File "/home/mads/Repos/dev-mads/my-repo/.venv/lib/python3.10/site-packages/litdata/streaming/item_loader.py", line 127, in deserialize
    return tree_unflatten(data, self._config["data_spec"])
  File "/home/mads/Repos/dev-mads/my-repo/.venv/lib/python3.10/site-packages/torch/utils/_pytree.py", line 552, in tree_unflatten
    raise ValueError(
ValueError: tree_unflatten(leaves, treespec): `leaves` has length 0 but the spec refers to a pytree that holds 2 items (TreeSpec(dict, ['data', 'filename'], [*,
  *])).
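The message means the reader got zero leaves back for an item whose spec expects two. One hedged first check is whether the optimized output on disk is intact. The sketch below assumes the litdata output layout of an index.json with a "chunks" list whose entries carry "filename" and "chunk_size" (items per chunk) keys; check_optimized_dir is a hypothetical helper, not part of litdata. A clean result only shows that no chunk files are missing or empty, not that their contents are well formed.

```python
import json
from pathlib import Path


def check_optimized_dir(output_dir: str) -> list[str]:
    """Return a list of problems found in an optimized dataset directory.

    Assumes an index.json with a "chunks" list whose entries have
    "filename" and "chunk_size" keys (assumed litdata layout).
    """
    root = Path(output_dir).expanduser()
    index = json.loads((root / "index.json").read_text())
    problems = []
    for chunk in index["chunks"]:
        path = root / chunk["filename"]
        # A chunk listed in the index but absent on disk cannot be read back.
        if not path.exists():
            problems.append(f"missing chunk file: {chunk['filename']}")
        # A zero-byte chunk file would yield no item data at all.
        elif path.stat().st_size == 0:
            problems.append(f"empty chunk file: {chunk['filename']}")
        # A chunk that claims to hold no items is suspicious as well.
        if chunk.get("chunk_size", 0) == 0:
            problems.append(f"chunk with zero items: {chunk['filename']}")
    return problems
```

Usage: print(check_optimized_dir("~/my-special-place")). An empty list means no missing or empty chunk files were found, which would point the investigation toward the read path (for example, how chunks are located or read back when iterating with several DataLoader workers) rather than toward optimize.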


Metadata


Assignees: no one assigned

Labels: bug (Something isn't working), help wanted (Extra attention is needed)
