From 8acbc520cfa054c5c9a9f68b7e5e0b96d891ba18 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 10 Jan 2022 18:26:50 +0800 Subject: [PATCH 1/5] [DLMED] update GridPatchDataset Signed-off-by: Nic Ma --- monai/data/grid_dataset.py | 66 +++++++++++++++----------------------- tests/test_grid_dataset.py | 17 +++++----- 2 files changed, 35 insertions(+), 48 deletions(-) diff --git a/monai/data/grid_dataset.py b/monai/data/grid_dataset.py index 7a0f79d00e..9eb84a58c9 100644 --- a/monai/data/grid_dataset.py +++ b/monai/data/grid_dataset.py @@ -9,16 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, Optional, Sequence, Union - -import numpy as np -import torch -from torch.utils.data import IterableDataset +from typing import Callable, Dict, Iterable, Optional, Sequence, Union from monai.data.dataset import Dataset +from monai.data.iterable_dataset import IterableDataset from monai.data.utils import iter_patch from monai.transforms import apply_transform -from monai.utils import NumpyPadMode, ensure_tuple, look_up_option +from monai.utils import NumpyPadMode, deprecated_arg, ensure_tuple, look_up_option __all__ = ["PatchDataset", "GridPatchDataset", "PatchIter"] @@ -96,7 +93,7 @@ class GridPatchDataset(IterableDataset): patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0) # construct the dataset - ds = GridPatchDataset(dataset=images, + ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity) # use the grid patch dataset @@ -108,49 +105,34 @@ class GridPatchDataset(IterableDataset): # coordinates: tensor([[[0, 1], [0, 2], [0, 2]], # [[0, 1], [2, 4], [0, 2]]]) + Args: + data: the data source to read image data from. + patch_iter: converts an input image (item from dataset) into a iterable of image patches. + `patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates). + see also: :py:class:`monai.data.PatchIter`. 
+ transform: a callable data transform that operates on the patches. + with_coordinates: whether to yield the coordinates of each patch, defaults to `True`. + + .. deprecated:: 0.8.0 + ``dataset`` is deprecated, use ``data`` instead. + """ + @deprecated_arg(name="dataset", new_name="data", since="0.8", msg_suffix="please use `data` instead.") def __init__( self, - dataset: Sequence, + data: Union[Iterable, Sequence], patch_iter: Callable, transform: Optional[Callable] = None, with_coordinates: bool = True, ) -> None: - """ - Initializes this dataset in terms of the image dataset, patch generator, and an optional transform. - - Args: - dataset: the dataset to read image data from. - patch_iter: converts an input image (item from dataset) into a iterable of image patches. - `patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates). - see also: :py:class:`monai.data.PatchIter`. - transform: a callable data transform operates on the patches. - with_coordinates: whether to yield the coordinates of each patch, default to `True`. 
- - """ - - self.dataset = dataset + super().__init__(data=data, transform=None) self.patch_iter = patch_iter self.transform = transform self.with_coordinates = with_coordinates def __iter__(self): - worker_info = torch.utils.data.get_worker_info() - iter_start, iter_end = 0, 1 - try: - iter_end = len(self.dataset) # TODO: support iterable self.dataset - except TypeError: - raise NotImplementedError("image dataset must implement `len()`.") from None - - if worker_info is not None: - # split workload - per_worker = int(np.ceil((iter_end - iter_start) / float(worker_info.num_workers))) - iter_start += worker_info.id * per_worker - iter_end = min(iter_start + per_worker, iter_end) - - for index in range(iter_start, iter_end): - image = self.dataset[index] + for image in super().__iter__(): if not self.with_coordinates: for patch, *_ in self.patch_iter(image): # patch_iter to yield at least 1 item: patch out_patch = ( @@ -204,20 +186,24 @@ class PatchDataset(Dataset): >>> torch.Size([2, 1, 3, 3]) + .. deprecated:: 0.8.0 + ``dataset`` is deprecated, use ``data`` instead. + """ + @deprecated_arg(name="dataset", new_name="data", since="0.8", msg_suffix="please use `data` instead.") def __init__( - self, dataset: Sequence, patch_func: Callable, samples_per_image: int = 1, transform: Optional[Callable] = None + self, data: Sequence, patch_func: Callable, samples_per_image: int = 1, transform: Optional[Callable] = None ) -> None: """ Args: - dataset: an image dataset to extract patches from. + data: an image dataset to extract patches from. patch_func: converts an input image (item from dataset) into a sequence of image patches. patch_func(dataset[idx]) must return a sequence of patches (length `samples_per_image`). samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements. transform: transform applied to each patch. 
""" - super().__init__(data=dataset, transform=transform) + super().__init__(data=data, transform=transform) self.patch_func = patch_func if samples_per_image <= 0: diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 9c4bcc52ae..9f24ce4463 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -33,31 +33,32 @@ def tearDown(self): set_determinism(None) def test_shape(self): - test_dataset = ["vwxyz", "helloworld", "worldfoobar"] - result = GridPatchDataset(dataset=test_dataset, patch_iter=identity_generator, with_coordinates=False) + # test Iterable input data + test_dataset = iter(["vwxyz", "helloworld", "worldfoobar"]) + result = GridPatchDataset(data=test_dataset, patch_iter=identity_generator, with_coordinates=False) output = [] n_workers = 0 if sys.platform == "win32" else 2 for item in DataLoader(result, batch_size=3, num_workers=n_workers): output.append("".join(item)) - expected = ["vwx", "wor", "yzh", "ldf", "ell", "oob", "owo", "ar", "rld"] + expected = ["d", "dfo", "hel", "low", "oba", "orl", "orl", "r", "vwx", "yzw"] self.assertEqual(sorted(output), sorted(expected)) self.assertEqual(len("".join(expected)), len("".join(test_dataset))) def test_loading_array(self): set_determinism(seed=1234) - # image dataset + # test sequence input data with images images = [np.arange(16, dtype=float).reshape(1, 4, 4), np.arange(16, dtype=float).reshape(1, 4, 4)] # image level patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0) patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0)) - ds = GridPatchDataset(dataset=images, patch_iter=patch_iter, transform=patch_intensity) + ds = GridPatchDataset(data=images, patch_iter=patch_iter, transform=patch_intensity) # use the grid patch dataset for item in DataLoader(ds, batch_size=2, shuffle=False, num_workers=0): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array([[[[2.0577, 3.0577], [6.0577, 7.0577]]], [[[10.5540, 
11.5540], [14.5540, 15.5540]]]]), - rtol=1e-5, + np.array([[[[1.4965, 2.4965], [5.4965, 6.4965]]], [[[11.3584, 12.3584], [15.3584, 16.3584]]]]), + rtol=1e-4, ) np.testing.assert_allclose(item[1], np.array([[[0, 1], [0, 2], [2, 4]], [[0, 1], [2, 4], [2, 4]]]), rtol=1e-5) if sys.platform != "win32": @@ -65,7 +66,7 @@ def test_loading_array(self): np.testing.assert_equal(tuple(item[0].shape), (2, 1, 2, 2)) np.testing.assert_allclose( item[0], - np.array([[[[1.6533, 2.6533], [5.6533, 6.6533]]], [[[9.8524, 10.8524], [13.8524, 14.8524]]]]), + np.array([[[[1.2548, 2.2548], [5.2548, 6.2548]]], [[[9.1106, 10.1106], [13.1106, 14.1106]]]]), rtol=1e-3, ) np.testing.assert_allclose( From 12edcac4107798b83707ed75e4c2a00968baae09 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 11 Jan 2022 07:59:39 +0800 Subject: [PATCH 2/5] [DLMED] fix test typo Signed-off-by: Nic Ma --- tests/test_patch_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_patch_dataset.py b/tests/test_patch_dataset.py index 1796ad4f23..a46c117b75 100644 --- a/tests/test_patch_dataset.py +++ b/tests/test_patch_dataset.py @@ -29,7 +29,7 @@ def test_shape(self): test_dataset = ["vwxyz", "hello", "world"] n_per_image = len(test_dataset[0]) - result = PatchDataset(dataset=test_dataset, patch_func=identity, samples_per_image=n_per_image) + result = PatchDataset(data=test_dataset, patch_func=identity, samples_per_image=n_per_image) output = [] n_workers = 0 if sys.platform == "win32" else 2 @@ -50,7 +50,7 @@ def test_loading_array(self): patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0) image_ds = Dataset(images, transform=patch_intensity) # patch level - ds = PatchDataset(dataset=image_ds, patch_func=sampler, samples_per_image=n_samples, transform=patch_intensity) + ds = PatchDataset(data=image_ds, patch_func=sampler, samples_per_image=n_samples, transform=patch_intensity) np.testing.assert_equal(len(ds), n_samples * len(images)) # use the patch dataset, length: 
len(images) x samplers_per_image From dee13b7dc06f25782b2b7f36ab3d58ee47a0fee6 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 11 Jan 2022 16:29:46 +0800 Subject: [PATCH 3/5] [DLMED] update for windows Signed-off-by: Nic Ma --- tests/test_grid_dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 9f24ce4463..ff3af21355 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -40,7 +40,11 @@ def test_shape(self): n_workers = 0 if sys.platform == "win32" else 2 for item in DataLoader(result, batch_size=3, num_workers=n_workers): output.append("".join(item)) - expected = ["d", "dfo", "hel", "low", "oba", "orl", "orl", "r", "vwx", "yzw"] + if sys.platform == "win32": + expected = ["ar", "ell", "ldf", "oob", "owo", "rld", "vwx", "wor", "yzh"] + else: + expected = ["d", "dfo", "hel", "low", "oba", "orl", "orl", "r", "vwx", "yzw"] + self.assertEqual(sorted(output), sorted(expected)) self.assertEqual(len("".join(expected)), len("".join(test_dataset))) From c0fedd97a3f7dc2706571c12db37a4b85fe04d41 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 11 Jan 2022 20:37:38 +0800 Subject: [PATCH 4/5] [DLMED] fix windows test Signed-off-by: Nic Ma --- tests/test_grid_dataset.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index ff3af21355..16edb0609d 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -34,19 +34,16 @@ def tearDown(self): def test_shape(self): # test Iterable input data - test_dataset = iter(["vwxyz", "helloworld", "worldfoobar"]) + test_dataset = iter(["vwxyz", "worldfoobar", "helloworld"]) result = GridPatchDataset(data=test_dataset, patch_iter=identity_generator, with_coordinates=False) output = [] n_workers = 0 if sys.platform == "win32" else 2 for item in DataLoader(result, batch_size=3, num_workers=n_workers): output.append("".join(item)) - if 
sys.platform == "win32": - expected = ["ar", "ell", "ldf", "oob", "owo", "rld", "vwx", "wor", "yzh"] - else: - expected = ["d", "dfo", "hel", "low", "oba", "orl", "orl", "r", "vwx", "yzw"] + expected = ["vwx", "wor", "yzh", "ldf", "ell", "oob", "owo", "ar", "rld"] self.assertEqual(sorted(output), sorted(expected)) - self.assertEqual(len("".join(expected)), len("".join(test_dataset))) + self.assertEqual(len("".join(expected)), len("".join(list(test_dataset)))) def test_loading_array(self): set_determinism(seed=1234) From a2eea6b45ae21dca79c57886b6ece2cf0fe453ae Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 11 Jan 2022 22:00:56 +0800 Subject: [PATCH 5/5] [DLMED] update test Signed-off-by: Nic Ma --- tests/test_grid_dataset.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_grid_dataset.py b/tests/test_grid_dataset.py index 16edb0609d..9361d82cdf 100644 --- a/tests/test_grid_dataset.py +++ b/tests/test_grid_dataset.py @@ -34,16 +34,18 @@ def tearDown(self): def test_shape(self): # test Iterable input data - test_dataset = iter(["vwxyz", "worldfoobar", "helloworld"]) + test_dataset = iter(["vwxyz", "helloworld", "worldfoobar"]) result = GridPatchDataset(data=test_dataset, patch_iter=identity_generator, with_coordinates=False) output = [] n_workers = 0 if sys.platform == "win32" else 2 for item in DataLoader(result, batch_size=3, num_workers=n_workers): output.append("".join(item)) - expected = ["vwx", "wor", "yzh", "ldf", "ell", "oob", "owo", "ar", "rld"] - + if sys.platform == "win32": + expected = ["ar", "ell", "ldf", "oob", "owo", "rld", "vwx", "wor", "yzh"] + else: + expected = ["d", "dfo", "hel", "low", "oba", "orl", "orl", "r", "vwx", "yzw"] + self.assertEqual(len("".join(expected)), len("".join(list(test_dataset)))) self.assertEqual(sorted(output), sorted(expected)) - self.assertEqual(len("".join(expected)), len("".join(list(test_dataset)))) def test_loading_array(self): set_determinism(seed=1234)