2 changes: 1 addition & 1 deletion monai/data/dataloader.py
@@ -83,7 +83,7 @@ def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs) -> None:
# disable unnecessary multiprocessing caching
from monai.data.dataset import CacheDataset # avoid circular import

-if isinstance(dataset, CacheDataset) and dataset.runtime_cache:
+if isinstance(dataset, CacheDataset):
dataset.disable_share_memory_cache()

_g.manual_seed(init_seed)
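The change above moves the `runtime_cache` check out of `DataLoader.__init__` and into `disable_share_memory_cache()` itself, so the loader can unconditionally ask any `CacheDataset` to drop its shared-memory cache and let the method decide whether there is anything to do. A minimal single-process sketch; the toy arrays and transform chain are illustrative assumptions, not part of this PR:

```python
import numpy as np
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, ScaleIntensity

# toy in-memory data; a real pipeline would load image files instead
data = [np.random.rand(1, 32, 32).astype("float32") for _ in range(8)]
ds = CacheDataset(data=data, transform=Compose([ScaleIntensity()]), runtime_cache=True)

# DataLoader now calls ds.disable_share_memory_cache() for every CacheDataset,
# not only when runtime_cache is set
loader = DataLoader(ds, batch_size=2, num_workers=0)
for batch in loader:  # the first epoch fills the runtime cache lazily
    pass
```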
24 changes: 17 additions & 7 deletions monai/data/dataset.py
@@ -763,7 +763,7 @@ def __init__(
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker threads if computing cache in the initialization.
If num_workers is None then the number returned by os.cpu_count() is used.
-If a value less than 1 is speficied, 1 will be used instead.
+If a value less than 1 is specified, 1 will be used instead.
progress: whether to display a progress bar.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cached content
@@ -778,15 +778,18 @@
hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached.
defaults to `monai.data.utils.pickle_hashing`.
runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
-the cache content at initializaiton, if `True`, it will cache during the first epoch
+the cache content at initialization, if `True`, it will cache during the first epoch
of model training, so it can start the first mini-batch earlier. please note that:
1. when using this option in multi-gpu distributed training,
`torch.cuda.set_device()` must be called before initializing this class.
-2. to execute `runtime cache` on GPU memory, must co-work with
+2. if caching data that is in GPU memory during multi-gpu distributed training, this option
+should not be used, since the underlying shared cache only works for CPU shared memory.
+3. to execute `runtime cache` on GPU memory, must co-work with
`monai.data.DataLoader`, and can't work with `monai.data.DistributedSampler`
as GPU Tensor usually can't be shared in the multiprocessing context.
+(try ``cache_dataset.disable_share_memory_cache()`` in case of GPU caching issues.)


"""
if not isinstance(transform, Compose):
transform = Compose(transform)
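The reworked notes 2 and 3 separate two GPU scenarios: data cached in GPU memory under multi-gpu distributed training (where `runtime_cache` should not be used) versus single-process GPU caching (which must pair with `monai.data.DataLoader`). A sketch of the supported single-process pattern, assuming a CUDA device and an illustrative transform chain:

```python
import torch
from monai.data import CacheDataset, DataLoader
from monai.transforms import Compose, EnsureTyped, ToDeviced

data = [{"image": torch.rand(1, 16, 16)} for _ in range(4)]
transform = Compose([
    EnsureTyped(keys="image"),
    ToDeviced(keys="image", device="cuda:0"),  # cached items stay on the GPU
])

# GPU tensors cannot live in a multiprocessing ListProxy, so this pattern works
# only with monai.data.DataLoader, no DistributedSampler and no worker processes
ds = CacheDataset(data=data, transform=transform, runtime_cache=True)
loader = DataLoader(ds, batch_size=2, num_workers=0)
```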
@@ -827,7 +827,7 @@ def _compute_cache(indices=None):
cache = Manager().list([None for _ in range(self.cache_num)])
if self._is_dist:
obj_list = [cache]
-# broadcast the ProxyList to all the ranks, then share the same cache content at runtime
+# broadcast the ListProxy to all the ranks, then share the same cache content at runtime
dist.broadcast_object_list(obj_list, src=0)
cache = obj_list[0]
else:
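For context on the renamed comment: `Manager().list(...)` returns a `multiprocessing.managers.ListProxy`, and `dist.broadcast_object_list` pickles it, which transfers a handle to the rank-0 manager process rather than copying the data. A standalone sketch of the same pattern, assuming an initialized process group on a single node:

```python
from multiprocessing import Manager
import torch.distributed as dist

def make_shared_runtime_cache(cache_num: int):
    # every rank builds a local proxy list first
    cache = Manager().list([None] * cache_num)
    if dist.is_available() and dist.is_initialized():
        obj_list = [cache]
        # pickling the ListProxy ships a handle to rank 0's manager process,
        # so after the broadcast all ranks read and write the same backing list
        dist.broadcast_object_list(obj_list, src=0)
        cache = obj_list[0]
    return cache
```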
@@ -848,11 +848,18 @@

def disable_share_memory_cache(self):
"""
-If the cache content is multiprocessing share memory list, convert it to a regular ptython list.
-Because multiprocessing ProxyList is not supported for the GPU caching, may need to explicitly diasble it.
+If the cache content is a multiprocessing shared memory ListProxy, convert it to a regular python list.
+Because multiprocessing ListProxy is not supported for the GPU caching, explicitly disable it.

"""
-self._cache = list(self._cache)
+if self.runtime_cache:
+    if not self._is_dist:
+        self._cache = list(self._cache)
+    else:
+        warnings.warn(
+            "Unable to disable shared cache in DDP, when runtime_cache==True. "
+            "Please use runtime_cache=False option to explicitly not use the shared cache."
+        )

def _fill_cache(self, indices=None) -> List:
"""
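A short illustration of the new guard; the no-op dataset is an assumption for the example. In a single process the shared `ListProxy` is materialized into a plain list, while under an initialized DDP group the call now warns and keeps the shared cache instead of silently desynchronizing the ranks:

```python
from monai.data import CacheDataset
from monai.transforms import Compose

ds = CacheDataset(
    data=[{"x": i} for i in range(4)],
    transform=Compose([]),  # identity transform, just for the example
    runtime_cache=True,
)
ds.disable_share_memory_cache()  # single process: ListProxy -> regular list
# under torch.distributed the same call emits the warning above; pass
# runtime_cache=False at construction to avoid the shared cache entirely
```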