Skip to content
52 changes: 48 additions & 4 deletions monai/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ class MedNISTDataset(Randomizable, CacheDataset):
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker threads to use.
if 0 a single thread will be used. Default is 0.
progress: whether to display a progress bar when downloading dataset and computing the transform cache content.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cached content
(for example, randomly crop from the cached image and deepcopy the crop region)
or if every cache item is only used once in a `multi-processing` environment,
may set `copy_cache=False` for better performance.

Raises:
ValueError: When ``root_dir`` is not a directory.
Expand All @@ -75,6 +81,8 @@ def __init__(
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_workers: int = 0,
progress: bool = True,
copy_cache: bool = True,
) -> None:
root_dir = Path(root_dir)
if not root_dir.is_dir():
Expand All @@ -87,7 +95,14 @@ def __init__(
dataset_dir = root_dir / self.dataset_folder_name
self.num_class = 0
if download:
download_and_extract(self.resource, tarfile_name, root_dir, self.md5)
download_and_extract(
url=self.resource,
filepath=tarfile_name,
output_dir=root_dir,
hash_val=self.md5,
hash_type="md5",
progress=progress,
)

if not dataset_dir.is_dir():
raise RuntimeError(
Expand All @@ -97,7 +112,14 @@ def __init__(
if transform == ():
transform = LoadImaged("image")
CacheDataset.__init__(
self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
self,
data=data,
transform=transform,
cache_num=cache_num,
cache_rate=cache_rate,
num_workers=num_workers,
progress=progress,
copy_cache=copy_cache,
)

def randomize(self, data: List[int]) -> None:
Expand Down Expand Up @@ -177,6 +199,12 @@ class DecathlonDataset(Randomizable, CacheDataset):
will take the minimum of (cache_num, data_length x cache_rate, data_length).
num_workers: the number of worker threads to use.
if 0 a single thread will be used. Default is 0.
progress: whether to display a progress bar when downloading dataset and computing the transform cache content.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cached content
(for example, randomly crop from the cached image and deepcopy the crop region)
or if every cache item is only used once in a `multi-processing` environment,
may set `copy_cache=False` for better performance.

Raises:
ValueError: When ``root_dir`` is not a directory.
Expand Down Expand Up @@ -241,6 +269,8 @@ def __init__(
cache_num: int = sys.maxsize,
cache_rate: float = 1.0,
num_workers: int = 0,
progress: bool = True,
copy_cache: bool = True,
) -> None:
root_dir = Path(root_dir)
if not root_dir.is_dir():
Expand All @@ -253,7 +283,14 @@ def __init__(
dataset_dir = root_dir / task
tarfile_name = f"{dataset_dir}.tar"
if download:
download_and_extract(self.resource[task], tarfile_name, root_dir, self.md5[task])
download_and_extract(
url=self.resource[task],
filepath=tarfile_name,
output_dir=root_dir,
hash_val=self.md5[task],
hash_type="md5",
progress=progress,
)

if not dataset_dir.exists():
raise RuntimeError(
Expand All @@ -277,7 +314,14 @@ def __init__(
if transform == ():
transform = LoadImaged(["image", "label"])
CacheDataset.__init__(
self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers
self,
data=data,
transform=transform,
cache_num=cache_num,
cache_rate=cache_rate,
num_workers=num_workers,
progress=progress,
copy_cache=copy_cache,
)

def get_indices(self) -> np.ndarray:
Expand Down
6 changes: 6 additions & 0 deletions monai/apps/pathology/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ class SmartCachePatchWSIDataset(SmartCacheDataset):
num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch.
If num_replace_workers is None then the number returned by os.cpu_count() is used.
progress: whether to display a progress bar when caching for the first epoch.
copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
default to `True`. if the random transforms don't modify the cache content
or every cache item is only used once in a `multi-processing` environment,
may set `copy_cache=False` for better performance.

"""

Expand All @@ -139,6 +143,7 @@ def __init__(
num_init_workers: Optional[int] = None,
num_replace_workers: Optional[int] = None,
progress: bool = True,
copy_cache: bool = True,
):
patch_wsi_dataset = PatchWSIDataset(
data=data,
Expand All @@ -157,6 +162,7 @@ def __init__(
num_replace_workers=num_replace_workers,
progress=progress,
shuffle=False,
copy_cache=copy_cache,
)


Expand Down
1 change: 1 addition & 0 deletions tests/test_decathlondataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def _test_dataset(dataset):
transform=transform,
section="validation",
download=True,
copy_cache=False,
)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
Expand Down
4 changes: 3 additions & 1 deletion tests/test_mednistdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ def _test_dataset(dataset):
self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64))

try: # will start downloading if testing_dir doesn't have the MedNIST files
data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=True)
data = MedNISTDataset(
root_dir=testing_dir, transform=transform, section="test", download=True, copy_cache=False
)
except (ContentTooShortError, HTTPError, RuntimeError) as e:
print(str(e))
if isinstance(e, RuntimeError):
Expand Down
1 change: 1 addition & 0 deletions tests/test_smartcache_patch_wsi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
"cache_num": 2,
"num_init_workers": 1,
"num_replace_workers": 1,
"copy_cache": False,
},
[
{"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0]]])},
Expand Down