StreamingDataset and StreamingDataLoader hang after cancelling once #452
🐛 Bug
I have an optimized dataset on a Slurm HPC cluster, created with the latest litdata using zstd compression and a 128 MB chunk size. Right after creating the optimized dataset, I can successfully train a model with a StreamingDataset and a StreamingDataLoader via Fabric. But once that first run is cancelled, the dataset always gets stuck at a specific item on every subsequent run. Explicitly setting the cache directory to a new, empty directory doesn't seem to help.

I tried optimizing the dataset again, and interestingly the same behavior repeated: I could train once, but cancelling that run left all further usages stuck.
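For context, here is roughly how the dataset was optimized (a minimal sketch; `to_sample`, the inputs, and the output path are placeholders, while the zstd compression and 128 MB chunk size are the actual settings):

```python
from litdata import optimize

dataset_path = "/gpfs/path/to/optimized_dataset"  # placeholder

def to_sample(index):
    # Placeholder transform; the real function builds the actual training samples.
    return {"index": index}

if __name__ == "__main__":
    optimize(
        fn=to_sample,
        inputs=list(range(100_000)),  # placeholder inputs
        output_dir=dataset_path,
        chunk_bytes="128MB",          # 128 MB chunk size, as described above
        compression="zstd",           # zstd compression, as described above
    )
```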
The way I use the dataset is:

```python
from litdata import StreamingDataLoader, StreamingDataset

fabric.setup()  # model and optimizer setup elided here
train_loader = StreamingDataLoader(
    StreamingDataset(
        str(dataset_path),
        cache_dir=str(local_cache_dir),
        shuffle=True,
    ),
    num_workers=data_config_dict["num_workers"],
    batch_size=data_config_dict["train_batch_size"],
    drop_last=True,
    pin_memory=True,
)
train_loader = fabric.setup_dataloaders(train_loader)
for batch in train_loader:
    ...
```

Any ideas on what other things to test?
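One diagnostic I can try on my side is registering faulthandler in every worker so a hung worker's Python stack can be dumped on demand. This is just a sketch: it assumes StreamingDataLoader forwards worker_init_fn to the torch DataLoader it subclasses, and the path and loader arguments are placeholders:

```python
import faulthandler
import signal

from litdata import StreamingDataLoader, StreamingDataset

def dump_stack_on_sigusr1(worker_id: int) -> None:
    # Each DataLoader worker prints its Python traceback to stderr
    # when it receives SIGUSR1 (e.g. `kill -USR1 <worker_pid>` while hung).
    faulthandler.register(signal.SIGUSR1)

train_loader = StreamingDataLoader(
    StreamingDataset("/gpfs/path/to/optimized_dataset", shuffle=True),  # placeholder path
    num_workers=4,  # placeholder
    batch_size=8,   # placeholder
    worker_init_fn=dump_stack_on_sigusr1,
)
```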
To Reproduce
I tried manually iterating over all the samples in a plain Python loop, without the DataLoader; that runs to completion. However, iterating through a StreamingDataLoader by hand does get stuck at some point. By "stuck" I mean making no progress for more than 30 minutes.
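A minimal sketch of the two access patterns (the path, worker count, and batch size are placeholders):

```python
from tqdm import tqdm

from litdata import StreamingDataLoader, StreamingDataset

dataset = StreamingDataset("/gpfs/path/to/optimized_dataset", shuffle=True)  # placeholder path

# Plain Python loop over all samples: runs to completion.
for i in tqdm(range(len(dataset))):
    _ = dataset[i]

# The same dataset through a StreamingDataLoader: gets stuck partway through.
loader = StreamingDataLoader(dataset, num_workers=4, batch_size=8)  # placeholder values
for batch in tqdm(loader):
    pass
```

If I interrupt at this point, the trace looks like this: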
File "<stdin>", line 1, in <module>
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/tqdm/std.py", line 1181, in __iter__
for obj in iterable:
^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/litdata/streaming/dataloader.py", line 622, in __iter__
for batch in super().__iter__():
^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
data = self._next_data()
^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 757, in _next_data
data = self._dataset_fetcher.fetch(index) # may raise StopIteration
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
data.append(next(self.dataset_iter))
^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/litdata/streaming/dataset.py", line 381, in __next__
data = self.__getitem__(
^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/litdata/streaming/dataset.py", line 349, in __getitem__
return self.cache[index]
~~~~~~~~~~^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/litdata/streaming/cache.py", line 145, in __getitem__
return self._reader.read(index)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/litdata/streaming/reader.py", line 287, in read
item = self._item_loader.load_item_from_chunk(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/litdata/streaming/item_loader.py", line 145, in load_item_from_chunk
sleep(0.1)or this:
```
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
               ^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/litdata/streaming/dataloader.py", line 622, in __iter__
    for batch in super().__iter__():
                 ^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1448, in _next_data
    idx, data = self._get_data()
                ^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1412, in _get_data
    success, data = self._try_get_data()
                    ^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1243, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/multiprocessing/queues.py", line 113, in get
    if not self._poll(timeout):
           ^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
           ^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/multiprocessing/connection.py", line 440, in _poll
    r = wait([self], timeout)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/multiprocessing/connection.py", line 1136, in wait
    ready = selector.select(timeout)
            ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/neuro/lib/python3.12/selectors.py", line 415, in select
    fd_event_list = self._selector.poll(timeout)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
```

The first trace suggests the reader is polling in load_item_from_chunk, presumably waiting for a chunk file that never becomes available; the second just shows the main process blocked on the worker result queue.

Expected behavior
The second and subsequent runs using the dataset should load normally.