Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions monai/data/thread_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ class ThreadDataLoader(DataLoader):
on the same batch will still produce good training with minimal short-term overfitting while allowing a slow batch
generation process more time to produce a result.

Another typical usage is to accelerate light-weight preprocessing (usually with all the deterministic transforms
cached and no IO operations), because it executes the preprocessing in a separate thread and so avoids unnecessary
IPC between the multiple workers of DataLoader. Also, as CUDA may not work well with the multi-processing of
DataLoader, `ThreadDataLoader` can be useful for GPU transforms. For more details, see:
https://github.com/Project-MONAI/tutorials/blob/master/acceleration/fast_model_training_guide.md.

See:
* Fischetti et al. "Faster SGD training by minibatch persistency." ArXiv (2018) https://arxiv.org/abs/1806.07353
* Choi et al., "Faster Neural Network Training with Data Echoing" ArXiv (2020) https://arxiv.org/abs/1907.05550
Expand All @@ -99,20 +105,15 @@ class ThreadDataLoader(DataLoader):
dataset: input dataset.
buffer_size: number of items to buffer from the data source.
buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items.
num_workers: number of the multi-processing workers in PyTorch DataLoader.
repeats: number of times to yield the same batch
repeats: number of times to yield the same batch.
kwargs: other arguments for `DataLoader` except for `dataset`.

"""

def __init__(
self,
dataset: Dataset,
buffer_size: int = 1,
buffer_timeout: float = 0.01,
num_workers: int = 0,
repeats: int = 1,
**kwargs,
self, dataset: Dataset, buffer_size: int = 1, buffer_timeout: float = 0.01, repeats: int = 1, **kwargs
):
super().__init__(dataset, num_workers, **kwargs)
super().__init__(dataset, **kwargs)
self.buffer_size = buffer_size
self.buffer_timeout = buffer_timeout
self.repeats = repeats
Expand Down