From 9eec03529da9b74b11b35221fe727e057239a1ed Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 20 Dec 2021 22:53:26 +0800 Subject: [PATCH 1/3] [DLMED] add missing args Signed-off-by: Nic Ma --- monai/apps/datasets.py | 34 ++++++++++++++++++++-- monai/apps/pathology/data/datasets.py | 6 ++++ tests/test_decathlondataset.py | 1 + tests/test_mednistdataset.py | 4 ++- tests/test_smartcache_patch_wsi_dataset.py | 1 + 5 files changed, 43 insertions(+), 3 deletions(-) diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index 2b2f48f5d0..fa50dceb00 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -51,6 +51,12 @@ class MedNISTDataset(Randomizable, CacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. if 0 a single thread will be used. Default is 0. + progress: whether to display a progress bar. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy_cache=False` for better performance. Raises: ValueError: When ``root_dir`` is not a directory. 
@@ -75,6 +81,8 @@ def __init__( cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0, + progress: bool = True, + copy_cache: bool = True, ) -> None: root_dir = Path(root_dir) if not root_dir.is_dir(): @@ -97,7 +105,14 @@ def __init__( if transform == (): transform = LoadImaged("image") CacheDataset.__init__( - self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers + self, + data=data, + transform=transform, + cache_num=cache_num, + cache_rate=cache_rate, + num_workers=num_workers, + progress=progress, + copy_cache=copy_cache, ) def randomize(self, data: List[int]) -> None: @@ -177,6 +192,12 @@ class DecathlonDataset(Randomizable, CacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. if 0 a single thread will be used. Default is 0. + progress: whether to display a progress bar. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cached content + (for example, randomly crop from the cached image and deepcopy the crop region) + or if every cache item is only used once in a `multi-processing` environment, + may set `copy_cache=False` for better performance. Raises: ValueError: When ``root_dir`` is not a directory. 
@@ -241,6 +262,8 @@ def __init__( cache_num: int = sys.maxsize, cache_rate: float = 1.0, num_workers: int = 0, + progress: bool = True, + copy_cache: bool = True, ) -> None: root_dir = Path(root_dir) if not root_dir.is_dir(): @@ -277,7 +300,14 @@ def __init__( if transform == (): transform = LoadImaged(["image", "label"]) CacheDataset.__init__( - self, data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers + self, + data=data, + transform=transform, + cache_num=cache_num, + cache_rate=cache_rate, + num_workers=num_workers, + progress=progress, + copy_cache=copy_cache, ) def get_indices(self) -> np.ndarray: diff --git a/monai/apps/pathology/data/datasets.py b/monai/apps/pathology/data/datasets.py index c9521b1201..10f31fbec8 100644 --- a/monai/apps/pathology/data/datasets.py +++ b/monai/apps/pathology/data/datasets.py @@ -122,6 +122,10 @@ class SmartCachePatchWSIDataset(SmartCacheDataset): num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch. If num_replace_workers is None then the number returned by os.cpu_count() is used. progress: whether to display a progress bar when caching for the first epoch. + copy_cache: whether to `deepcopy` the cache content before applying the random transforms, + default to `True`. if the random transforms don't modify the cache content + or every cache item is only used once in a `multi-processing` environment, + may set `copy_cache=False` for better performance. 
""" @@ -139,6 +143,7 @@ def __init__( num_init_workers: Optional[int] = None, num_replace_workers: Optional[int] = None, progress: bool = True, + copy_cache: bool = True, ): patch_wsi_dataset = PatchWSIDataset( data=data, @@ -157,6 +162,7 @@ def __init__( num_replace_workers=num_replace_workers, progress=progress, shuffle=False, + copy_cache=copy_cache, ) diff --git a/tests/test_decathlondataset.py b/tests/test_decathlondataset.py index 0756902385..ee2d92318c 100644 --- a/tests/test_decathlondataset.py +++ b/tests/test_decathlondataset.py @@ -47,6 +47,7 @@ def _test_dataset(dataset): transform=transform, section="validation", download=True, + copy_cache=False, ) except (ContentTooShortError, HTTPError, RuntimeError) as e: print(str(e)) diff --git a/tests/test_mednistdataset.py b/tests/test_mednistdataset.py index a833ab75f3..18f9f0192b 100644 --- a/tests/test_mednistdataset.py +++ b/tests/test_mednistdataset.py @@ -43,7 +43,9 @@ def _test_dataset(dataset): self.assertTupleEqual(dataset[0]["image"].shape, (1, 64, 64)) try: # will start downloading if testing_dir doesn't have the MedNIST files - data = MedNISTDataset(root_dir=testing_dir, transform=transform, section="test", download=True) + data = MedNISTDataset( + root_dir=testing_dir, transform=transform, section="test", download=True, copy_cache=False + ) except (ContentTooShortError, HTTPError, RuntimeError) as e: print(str(e)) if isinstance(e, RuntimeError): diff --git a/tests/test_smartcache_patch_wsi_dataset.py b/tests/test_smartcache_patch_wsi_dataset.py index 2150ede51c..5fe471e0d2 100644 --- a/tests/test_smartcache_patch_wsi_dataset.py +++ b/tests/test_smartcache_patch_wsi_dataset.py @@ -45,6 +45,7 @@ "cache_num": 2, "num_init_workers": 1, "num_replace_workers": 1, + "copy_cache": False, }, [ {"image": np.array([[[239]], [[239]], [[239]]], dtype=np.uint8), "label": np.array([[[0]]])}, From 62279581082d923b30c24a2b2792a22a4abd5bf8 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Dec 2021 14:46:18 +0800 
Subject: [PATCH 2/3] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/apps/datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index 497d42cf0c..c5ef821b69 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -51,7 +51,7 @@ class MedNISTDataset(Randomizable, CacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. if 0 a single thread will be used. Default is 0. - progress: whether to display a progress bar. + progress: whether to display a progress bar when computing the transform cache content. copy_cache: whether to `deepcopy` the cache content before applying the random transforms, default to `True`. if the random transforms don't modify the cached content (for example, randomly crop from the cached image and deepcopy the crop region) @@ -192,7 +192,7 @@ class DecathlonDataset(Randomizable, CacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. if 0 a single thread will be used. Default is 0. - progress: whether to display a progress bar. + progress: whether to display a progress bar when computing the transform cache content. copy_cache: whether to `deepcopy` the cache content before applying the random transforms, default to `True`. 
if the random transforms don't modify the cached content (for example, randomly crop from the cached image and deepcopy the crop region) From c2bc9a30007ff504ad6b09ad4c34333d5fa7ad38 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Dec 2021 16:22:55 +0800 Subject: [PATCH 3/3] [DLMED] update progress arg Signed-off-by: Nic Ma --- monai/apps/datasets.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index c5ef821b69..90a0f95ced 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -51,7 +51,7 @@ class MedNISTDataset(Randomizable, CacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. if 0 a single thread will be used. Default is 0. - progress: whether to display a progress bar when computing the transform cache content. + progress: whether to display a progress bar when downloading dataset and computing the transform cache content. copy_cache: whether to `deepcopy` the cache content before applying the random transforms, default to `True`. if the random transforms don't modify the cached content (for example, randomly crop from the cached image and deepcopy the crop region) @@ -95,7 +95,14 @@ def __init__( dataset_dir = root_dir / self.dataset_folder_name self.num_class = 0 if download: - download_and_extract(self.resource, tarfile_name, root_dir, self.md5) + download_and_extract( + url=self.resource, + filepath=tarfile_name, + output_dir=root_dir, + hash_val=self.md5, + hash_type="md5", + progress=progress, + ) if not dataset_dir.is_dir(): raise RuntimeError( @@ -192,7 +199,7 @@ class DecathlonDataset(Randomizable, CacheDataset): will take the minimum of (cache_num, data_length x cache_rate, data_length). num_workers: the number of worker threads to use. if 0 a single thread will be used. Default is 0. 
- progress: whether to display a progress bar when computing the transform cache content. + progress: whether to display a progress bar when downloading dataset and computing the transform cache content. copy_cache: whether to `deepcopy` the cache content before applying the random transforms, default to `True`. if the random transforms don't modify the cached content (for example, randomly crop from the cached image and deepcopy the crop region) @@ -276,7 +283,14 @@ def __init__( dataset_dir = root_dir / task tarfile_name = f"{dataset_dir}.tar" if download: - download_and_extract(self.resource[task], tarfile_name, root_dir, self.md5[task]) + download_and_extract( + url=self.resource[task], + filepath=tarfile_name, + output_dir=root_dir, + hash_val=self.md5[task], + hash_type="md5", + progress=progress, + ) if not dataset_dir.exists(): raise RuntimeError(