From 1e2b8dbb70d1cf981b5fde3cef66b666c6488e73 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Nov 2021 00:02:08 +0800 Subject: [PATCH 1/9] [DLMED] fix 2 CSVIterableDataset issues Signed-off-by: Nic Ma --- monai/data/iterable_dataset.py | 19 +++++++----- tests/test_csv_iterable_dataset.py | 50 ++++++++++++++++++++++-------- 2 files changed, 48 insertions(+), 21 deletions(-) diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index fd52ac041f..58d6ca9ad4 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -74,9 +74,8 @@ def __init__(self, data, transform=None, buffer_size: int = 512, shuffle: bool = def _rand_pop(self, buffer: List, num_workers: int = 1, id: int = 0): length = len(buffer) for i in range(min(length, num_workers)): - if self.shuffle: - # randomly select an item for every worker and pop - self.randomize(length) + # randomly select an item for every worker and pop + self.randomize(length) # switch random index data and the last index data item, buffer[self._idx] = buffer[self._idx], buffer[-1] buffer.pop() @@ -92,12 +91,16 @@ def __iter__(self): id = info.id if info is not None else 0 _buffer = [] - for item in self.data: - if len(_buffer) >= self.size: - self._rand_pop(_buffer, num_workers=num_workers, id=id) - _buffer.append(item) + for i, item in enumerate(self.data): + if not self.shuffle: + if i % num_workers == id: + yield item + else: + if len(_buffer) >= self.size: + self._rand_pop(_buffer, num_workers=num_workers, id=id) + _buffer.append(item) while _buffer: - return self._rand_pop(_buffer, num_workers=num_workers, id=id) + yield from self._rand_pop(_buffer, num_workers=num_workers, id=id) def randomize(self, size: int) -> None: self._idx = self.R.randint(size) diff --git a/tests/test_csv_iterable_dataset.py b/tests/test_csv_iterable_dataset.py index b8e77364a6..3354c446f2 100644 --- a/tests/test_csv_iterable_dataset.py +++ b/tests/test_csv_iterable_dataset.py @@ -63,9 +63,11 @@ def prepare_csv_file(data, filepath): prepare_csv_file(test_data3, filepath3) # test single CSV file - dataset = CSVIterableDataset(filepath1) - for i, item in enumerate(dataset): - if i == 2: + dataset = CSVIterableDataset(filepath1, shuffle=False) + count = 0 + for item in dataset: + count += 1 + if count == 3: self.assertDictEqual( {k: round(v, 4) if not isinstance(v, str) else v for k, v in item.items()}, { @@ -78,16 +80,23 @@ def prepare_csv_file(data, filepath): }, ) break + self.assertEqual(count, 3) + # test reset iterables dataset.reset(filename=filepath3) + count = 0 for i, item in enumerate(dataset): - if i == 3: + count += 1 + if i == 4: self.assertEqual(item["meta_0"], False) + self.assertEqual(count, 5) # test multiple CSV files, join tables with kwargs - dataset = CSVIterableDataset([filepath1, filepath2, filepath3], on="subject_id") - for i, item in enumerate(dataset): - if i == 3: + dataset = CSVIterableDataset([filepath1, filepath2, filepath3], on="subject_id", shuffle=False) + count = 0 + for item in dataset: + count += 1 + if count == 4: self.assertDictEqual( {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in item.items()}, { @@ -110,15 +119,19 @@ def prepare_csv_file(data, filepath): "meta_2": True, }, ) + self.assertEqual(count, 5) # test selected columns and chunk size dataset = CSVIterableDataset( filename=[filepath1, filepath2, filepath3], chunksize=2, col_names=["subject_id", "image", "ehr_1", "ehr_7", "meta_1"], + shuffle=False, ) - for i, item in enumerate(dataset): - if i == 3: + count 
= 0 + for item in dataset: + count += 1 + if count == 4: self.assertDictEqual( {k: round(v, 4) if not isinstance(v, (str, np.bool_)) else v for k, v in item.items()}, { @@ -129,20 +142,25 @@ def prepare_csv_file(data, filepath): "meta_1": False, }, ) + self.assertEqual(count, 5) # test group columns dataset = CSVIterableDataset( filename=[filepath1, filepath2, filepath3], col_names=["subject_id", "image", *[f"ehr_{i}" for i in range(11)], "meta_0", "meta_1", "meta_2"], col_groups={"ehr": [f"ehr_{i}" for i in range(11)], "meta12": ["meta_1", "meta_2"]}, + shuffle=False, ) - for i, item in enumerate(dataset): - if i == 3: + count = 0 + for item in dataset: + count += 1 + if count == 4: np.testing.assert_allclose( [round(i, 4) for i in item["ehr"]], [3.3333, 3.2353, 3.4000, 3.1647, 3.0863, 3.7255, 3.6980, 3.6980, 3.7020, 3.3098, 3.7294], ) np.testing.assert_allclose(item["meta12"], [False, True]) + self.assertEqual(count, 5) # test transform dataset = CSVIterableDataset( @@ -158,23 +176,29 @@ def prepare_csv_file(data, filepath): [3.7725, 4.2118, 4.6353, 5.298, 9.5451], [2.0078, 2.2902, 2.0549, 3.0196, 3.8078], [6.4275, 6.2549, 5.9765, 6.2627, 7.7176], - [6.8392, 6.4745, 5.8627, 5.1922, 5.2745], + [3.3333, 3.2353, 3.4000, 3.1647, 3.0863], ] + count = 0 for item, exp in zip(dataset, expected): + count += 1 self.assertTrue(isinstance(item["ehr"], np.ndarray)) np.testing.assert_allclose(np.around(item["ehr"], 4), exp) + self.assertEqual(count, 4) # test multiple processes loading - dataset = CSVIterableDataset(filepath1, transform=ToNumpyd(keys="label")) + dataset = CSVIterableDataset(filepath1, transform=ToNumpyd(keys="label"), shuffle=False) # set num workers = 0 for mac / win num_workers = 2 if sys.platform == "linux" else 0 dataloader = DataLoader(dataset=dataset, num_workers=num_workers, batch_size=2) + count = 0 for item in dataloader: + count += 1 # test the last item which only has 1 data if len(item) == 1: self.assertListEqual(item["subject_id"], ["s000002"]) np.testing.assert_allclose(item["label"], [4]) self.assertListEqual(item["image"], ["./imgs/s000002.png"]) + self.assertEqual(count, 3) if __name__ == "__main__": From ea642322abf409be3032b31641bba90d32f68854 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Nov 2021 00:28:45 +0800 Subject: [PATCH 2/9] [DLMED] fix length issue Signed-off-by: Nic Ma --- monai/data/iterable_dataset.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index 58d6ca9ad4..92087efe98 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -72,10 +72,9 @@ def __init__(self, data, transform=None, buffer_size: int = 512, shuffle: bool = self._idx = 0 def _rand_pop(self, buffer: List, num_workers: int = 1, id: int = 0): - length = len(buffer) - for i in range(min(length, num_workers)): + for i in range(min(len(buffer), num_workers)): # randomly select an item for every worker and pop - self.randomize(length) + self.randomize(len(buffer)) # switch random index data and the last index data item, buffer[self._idx] = buffer[self._idx], buffer[-1] buffer.pop() From 7e3747d2ebeb6617031966ddd50a74ec11ce1c90 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Nov 2021 00:54:09 +0800 Subject: [PATCH 3/9] [DLMED] add transform Signed-off-by: Nic Ma --- monai/data/iterable_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index 92087efe98..de72dca230 100644 --- 
a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -93,6 +93,8 @@ def __iter__(self): for i, item in enumerate(self.data): if not self.shuffle: if i % num_workers == id: + if self.transform is not None: + item = apply_transform(self.transform, item) yield item else: if len(_buffer) >= self.size: From b5bc53568ecbae4f14dafa5c6181e65c6c7cec7a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Nov 2021 19:08:52 +0800 Subject: [PATCH 4/9] [DLMED] simplify the logic Signed-off-by: Nic Ma --- docs/source/data.rst | 6 +- monai/data/__init__.py | 2 +- monai/data/iterable_dataset.py | 113 +++++++++++++++++------------ tests/test_csv_iterable_dataset.py | 9 ++- 4 files changed, 75 insertions(+), 55 deletions(-) diff --git a/docs/source/data.rst b/docs/source/data.rst index 7e4a25c0ab..c528deeacc 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -21,9 +21,9 @@ Generic Interfaces :members: :special-members: __next__ -`IterableBuffer` -~~~~~~~~~~~~~~~~ -.. autoclass:: IterableBuffer +`ShuffleBuffer` +~~~~~~~~~~~~~~~ +.. autoclass:: ShuffleBuffer :members: :special-members: __next__ diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 20e50f3e9a..df436ba667 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -28,7 +28,7 @@ from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader -from .iterable_dataset import CSVIterableDataset, IterableBuffer, IterableDataset +from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti from .png_saver import PNGSaver diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index de72dca230..67cb69230c 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union +from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union from torch.utils.data import IterableDataset as _TorchIterableDataset from torch.utils.data import get_worker_info @@ -29,8 +29,11 @@ class IterableDataset(_TorchIterableDataset): https://pytorch.org/docs/stable/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset. For example, typical input data can be web data stream which can support multi-process access. - Note that when used with `DataLoader` and `num_workers > 0`, each worker process will have a - different copy of the dataset object, need to guarantee process-safe from data source or DataLoader. + To accelerate the loading process, it can support multi-processing based on PyTorch DataLoader workers, + every process executes transforms on part of every loaded data. + Note that the order of output data may not match data source in multi-processing mode. + And each worker process will have a different copy of the dataset object, need to guarantee + process-safe from data source or DataLoader. """ @@ -44,70 +47,79 @@ def __init__(self, data: Iterable, transform: Optional[Callable] = None) -> None self.transform = transform self.source = None + def _get_item(self, data): + """ + Utility function to fetch data item from the source. + Subclass can extend it for customized logic. 
+ + """ + for item in data: + yield item + def __iter__(self): + info = get_worker_info() + num_workers = info.num_workers if info is not None else 1 + id = info.id if info is not None else 0 + self.source = iter(self.data) - for data in self.source: - if self.transform is not None: - data = apply_transform(self.transform, data) - yield data + for i, item in enumerate(self._get_item(self.source)): + if i % num_workers == id: + if self.transform is not None: + item = apply_transform(self.transform, item) + yield item -class IterableBuffer(Randomizable, IterableDataset): +class ShuffleBuffer(Randomizable, IterableDataset): """ - Extend the IterableDataset with a buffer and support to randomly pop items. + Extend the IterableDataset with a buffer and randomly pop items. Args: data: input data source to load and transform to generate dataset for model. transform: a callable data transform on input data. buffer_size: size of the buffer to store items and randomly pop, default to 512. - shuffle: if True, randomly pop a item from the list, otherwise, pop from the beginning, - default to False. + seed: random seed to initialize the random state of all workers, set `seed += 1` in + every iter() call, refer to the PyTorch idea: + https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98. """ - def __init__(self, data, transform=None, buffer_size: int = 512, shuffle: bool = True) -> None: + def __init__(self, data, transform=None, buffer_size: int = 512, seed: int = 0) -> None: super().__init__(data=data, transform=transform) self.size = buffer_size - self.shuffle = shuffle + self.seed = seed self._idx = 0 - def _rand_pop(self, buffer: List, num_workers: int = 1, id: int = 0): - for i in range(min(len(buffer), num_workers)): - # randomly select an item for every worker and pop + def _get_item(self, data): + """ + Fetch data from the source, if buffer is not full, fill into buffer, otherwise, + randomly pop items from the buffer. + After loading all the data from source, randomly pop items from the buffer. + + """ + self.seed += 1 + self.set_random_state(seed=self.seed) # make all workers in sync + buffer = [] + + def _pop_item(): self.randomize(len(buffer)) # switch random index data and the last index data - item, buffer[self._idx] = buffer[self._idx], buffer[-1] + ret, buffer[self._idx] = buffer[self._idx], buffer[-1] buffer.pop() - if i == id: - if self.transform is not None: - item = apply_transform(self.transform, item) - yield item + return ret - def __iter__(self): - # pop items for multi-workers - info = get_worker_info() - num_workers = info.num_workers if info is not None else 1 - id = info.id if info is not None else 0 + for item in data: + if len(buffer) >= self.size: + yield _pop_item() + buffer.append(item) - _buffer = [] - for i, item in enumerate(self.data): - if not self.shuffle: - if i % num_workers == id: - if self.transform is not None: - item = apply_transform(self.transform, item) - yield item - else: - if len(_buffer) >= self.size: - self._rand_pop(_buffer, num_workers=num_workers, id=id) - _buffer.append(item) - while _buffer: - yield from self._rand_pop(_buffer, num_workers=num_workers, id=id) + while buffer: + yield _pop_item() def randomize(self, size: int) -> None: self._idx = self.R.randint(size) -class CSVIterableDataset(IterableDataset, Randomizable): +class CSVIterableDataset(IterableDataset): """ Iterable dataset to load CSV files and generate dictionary data. 
It is particularly useful when data come from a stream, inherits from PyTorch IterableDataset: @@ -118,10 +130,10 @@ class CSVIterableDataset(IterableDataset, Randomizable): Note that as a stream input, it can't get the length of dataset. To effectively shuffle the data in the big dataset, users can set a big buffer to continuously store - the loaded chunks, then randomly pick data from the buffer for following tasks. + the loaded data, then randomly pick data from the buffer for following tasks. To accelerate the loading process, it can support multi-processing based on PyTorch DataLoader workers, - every process executes transforms on part of every loaded chunk. + every process executes transforms on part of every loaded data. Note: the order of output data may not match data source in multi-processing mode. It can load data from multiple CSV files and join the tables with additional `kwargs` arg. @@ -160,6 +172,9 @@ class CSVIterableDataset(IterableDataset, Randomizable): `col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}` transform: transform to apply on the loaded items of a dictionary data. shuffle: whether to shuffle all the data in the buffer every time a new chunk loaded. + seed: random seed to initialize the random state for all the workers if `shuffle` is True, + set `seed += 1` in every iter() call, refer to the PyTorch idea: + https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98. kwargs: additional arguments for `pandas.merge()` API to join tables. """ @@ -174,6 +189,7 @@ def __init__( col_groups: Optional[Dict[str, Sequence[str]]] = None, transform: Optional[Callable] = None, shuffle: bool = False, + seed: int = 0, **kwargs, ): self.files = ensure_tuple(filename) @@ -183,6 +199,7 @@ def __init__( self.col_types = col_types self.col_groups = col_groups self.shuffle = shuffle + self.seed = seed self.kwargs = kwargs self.iters = self.reset() super().__init__(data=None, transform=transform) # type: ignore @@ -205,8 +222,10 @@ def _flattened(self): ) def __iter__(self): - buffer = IterableBuffer( - data=self._flattened(), transform=self.transform, buffer_size=self.buffer_size, shuffle=self.shuffle - ) - buffer.set_random_state(state=self.R) - yield from buffer + if self.shuffle: + self.seed += 1 + buffer = ShuffleBuffer( + data=self._flattened(), transform=self.transform, buffer_size=self.buffer_size, seed=self.seed + ) + yield from buffer + yield from IterableDataset(data=self._flattened(), transform=self.transform) diff --git a/tests/test_csv_iterable_dataset.py b/tests/test_csv_iterable_dataset.py index 3354c446f2..fae8e0ba8d 100644 --- a/tests/test_csv_iterable_dataset.py +++ b/tests/test_csv_iterable_dataset.py @@ -170,20 +170,21 @@ def prepare_csv_file(data, filepath): col_groups={"ehr": [f"ehr_{i}" for i in range(5)]}, transform=ToNumpyd(keys="ehr"), shuffle=True, + seed=123, ) - dataset.set_random_state(123) expected = [ + [6.8392, 6.4745, 5.8627, 5.1922, 5.2745], + [3.3333, 3.2353, 3.4000, 3.1647, 3.0863], [3.7725, 4.2118, 4.6353, 5.298, 9.5451], - [2.0078, 2.2902, 2.0549, 3.0196, 3.8078], [6.4275, 6.2549, 5.9765, 6.2627, 7.7176], - [3.3333, 3.2353, 3.4000, 3.1647, 3.0863], + [2.0078, 2.2902, 2.0549, 3.0196, 3.8078], ] count = 0 for item, exp in zip(dataset, expected): count += 1 self.assertTrue(isinstance(item["ehr"], np.ndarray)) np.testing.assert_allclose(np.around(item["ehr"], 4), exp) - self.assertEqual(count, 4) + self.assertEqual(count, 5) # test multiple processes loading dataset = 
CSVIterableDataset(filepath1, transform=ToNumpyd(keys="label"), shuffle=False) From 790c2669bb6778f86796429fb023a08100e0d48b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 1 Nov 2021 11:09:40 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/iterable_dataset.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index 67cb69230c..2770fed527 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -53,8 +53,7 @@ def _get_item(self, data): Subclass can extend it for customized logic. """ - for item in data: - yield item + yield from data def __iter__(self): info = get_worker_info() From 6eceb33e7f69b0b578dc97ee4b110a6a79f62992 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 1 Nov 2021 22:41:08 +0800 Subject: [PATCH 6/9] [DLMED] remove get_items Signed-off-by: Nic Ma --- monai/data/iterable_dataset.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index 2770fed527..2ad6d346e6 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -47,21 +47,13 @@ def __init__(self, data: Iterable, transform: Optional[Callable] = None) -> None self.transform = transform self.source = None - def _get_item(self, data): - """ - Utility function to fetch data item from the source. - Subclass can extend it for customized logic. - - """ - yield from data - def __iter__(self): info = get_worker_info() num_workers = info.num_workers if info is not None else 1 id = info.id if info is not None else 0 self.source = iter(self.data) - for i, item in enumerate(self._get_item(self.source)): + for i, item in enumerate(self.source): if i % num_workers == id: if self.transform is not None: item = apply_transform(self.transform, item) @@ -88,7 +80,7 @@ def __init__(self, data, transform=None, buffer_size: int = 512, seed: int = 0) self.seed = seed self._idx = 0 - def _get_item(self, data): + def __iter__(self): """ Fetch data from the source, if buffer is not full, fill into buffer, otherwise, randomly pop items from the buffer. 
@@ -98,6 +90,7 @@ def _get_item(self, data): self.seed += 1 self.set_random_state(seed=self.seed) # make all workers in sync buffer = [] + source = self.data def _pop_item(): self.randomize(len(buffer)) @@ -106,13 +99,17 @@ def _pop_item(): buffer.pop() return ret - for item in data: - if len(buffer) >= self.size: + def _get_item(): + for item in source: + if len(buffer) >= self.size: + yield _pop_item() + buffer.append(item) + + while buffer: yield _pop_item() - buffer.append(item) - while buffer: - yield _pop_item() + self.data = _get_item() + return super().__iter__() def randomize(self, size: int) -> None: self._idx = self.R.randint(size) From 3d394fc5f41292b51152364d2a8e37bfa7df7762 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Nov 2021 00:49:37 +0800 Subject: [PATCH 7/9] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/data/iterable_dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/monai/data/iterable_dataset.py b/monai/data/iterable_dataset.py index 2ad6d346e6..d1365fa220 100644 --- a/monai/data/iterable_dataset.py +++ b/monai/data/iterable_dataset.py @@ -11,6 +11,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, Sequence, Union +import numpy as np from torch.utils.data import IterableDataset as _TorchIterableDataset from torch.utils.data import get_worker_info @@ -88,7 +89,7 @@ def __iter__(self): """ self.seed += 1 - self.set_random_state(seed=self.seed) # make all workers in sync + super().set_random_state(seed=self.seed) # make all workers in sync buffer = [] source = self.data @@ -114,6 +115,9 @@ def _get_item(): def randomize(self, size: int) -> None: self._idx = self.R.randint(size) + def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None): + raise NotImplementedError(f"`set_random_state` is not available in {self.__class__.__name__}.") + class CSVIterableDataset(IterableDataset): """ From f4c8daa38e1d63653eb2df5d006d5a83b8ea340d Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Nov 2021 07:32:37 +0800 Subject: [PATCH 8/9] [DLMED] add dtype Signed-off-by: Nic Ma --- monai/transforms/intensity/array.py | 23 +++++++++++++++---- monai/transforms/intensity/dictionary.py | 8 +++++-- tests/test_scale_intensity_range.py | 8 ++++--- .../test_scale_intensity_range_percentiles.py | 5 ++-- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index c4ddb27a8c..0a6f87f273 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -725,16 +725,20 @@ class ScaleIntensityRange(Transform): b_min: intensity target range min. b_max: intensity target range max. clip: whether to perform clip after scaling. + dtype: output data type, defaults to float32. 
""" backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False) -> None: + def __init__( + self, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False, dtype: DtypeLike = np.float32 + ) -> None: self.a_min = a_min self.a_max = a_max self.b_min = b_min self.b_max = b_max self.clip = clip + self.dtype = dtype def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ @@ -748,7 +752,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img = img * (self.b_max - self.b_min) + self.b_min if self.clip: img = clip(img, self.b_min, self.b_max) - return img + ret, *_ = convert_data_type(img, dtype=self.dtype) + + return ret class AdjustContrast(Transform): @@ -883,12 +889,20 @@ class ScaleIntensityRangePercentiles(Transform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max]. + dtype: output data type, defaults to float32. """ backend = ScaleIntensityRange.backend def __init__( - self, lower: float, upper: float, b_min: float, b_max: float, clip: bool = False, relative: bool = False + self, + lower: float, + upper: float, + b_min: float, + b_max: float, + clip: bool = False, + relative: bool = False, + dtype: DtypeLike = np.float32, ) -> None: if lower < 0.0 or lower > 100.0: raise ValueError("Percentiles must be in the range [0, 100]") @@ -900,6 +914,7 @@ def __init__( self.b_max = b_max self.clip = clip self.relative = relative + self.dtype = dtype def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ @@ -914,7 +929,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: b_min = ((self.b_max - self.b_min) * (self.lower / 100.0)) + self.b_min b_max = ((self.b_max - self.b_min) * (self.upper / 100.0)) + self.b_min - scalar = ScaleIntensityRange(a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=False) + scalar = ScaleIntensityRange(a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=False, dtype=self.dtype) img = scalar(img) if self.clip: diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 8459d7fc02..0caa118652 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -712,6 +712,7 @@ class ScaleIntensityRanged(MapTransform): b_min: intensity target range min. b_max: intensity target range max. clip: whether to perform clip after scaling. + dtype: output data type, defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -725,10 +726,11 @@ def __init__( b_min: float, b_max: float, clip: bool = False, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip) + self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip, dtype) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -826,6 +828,7 @@ class ScaleIntensityRangePercentilesd(MapTransform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max] + dtype: output data type, defaults to float32. allow_missing_keys: don't raise exception if key is missing. 
""" @@ -840,10 +843,11 @@ def __init__( b_max: float, clip: bool = False, relative: bool = False, + dtype: DtypeLike = np.float32, allow_missing_keys: bool = False, ) -> None: super().__init__(keys, allow_missing_keys) - self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative) + self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, dtype) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) diff --git a/tests/test_scale_intensity_range.py b/tests/test_scale_intensity_range.py index d06bfd3596..f3e971b2ea 100644 --- a/tests/test_scale_intensity_range.py +++ b/tests/test_scale_intensity_range.py @@ -11,17 +11,19 @@ import unittest +import numpy as np + from monai.transforms import ScaleIntensityRange from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose class IntensityScaleIntensityRange(NumpyImageTestCase2D): def test_image_scale_intensity_range(self): - scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80) + scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80, dtype=np.uint8) for p in TEST_NDARRAYS: scaled = scaler(p(self.imt)) - expected = (self.imt - 20) / 88 - expected = expected * 30 + 50 + self.assertTrue(scaled.dtype, np.uint8) + expected = (((self.imt - 20) / 88) * 30 + 50).astype(np.uint8) assert_allclose(scaled, p(expected)) diff --git a/tests/test_scale_intensity_range_percentiles.py b/tests/test_scale_intensity_range_percentiles.py index 0024cb349d..786e4299b5 100644 --- a/tests/test_scale_intensity_range_percentiles.py +++ b/tests/test_scale_intensity_range_percentiles.py @@ -27,9 +27,8 @@ def test_scaling(self): a_min = np.percentile(img, lower) a_max = np.percentile(img, upper) - expected = (img - a_min) / (a_max - a_min) - expected = (expected * (b_max - b_min)) + b_min - scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max) + expected = (((img - a_min) / (a_max - a_min)) * (b_max - b_min) + b_min).astype(np.uint8) + scaler = ScaleIntensityRangePercentiles(lower=lower, upper=upper, b_min=b_min, b_max=b_max, dtype=np.uint8) for p in TEST_NDARRAYS: result = scaler(p(img)) assert_allclose(result, p(expected), rtol=1e-4) From be965af6e22e1ba6b06b874091dec5d489dabed4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 2 Nov 2021 12:15:52 +0800 Subject: [PATCH 9/9] [DLMED] enhance dtype Signed-off-by: Nic Ma --- monai/transforms/intensity/array.py | 53 ++++++++++++++---------- monai/transforms/intensity/dictionary.py | 18 ++++---- monai/transforms/utility/array.py | 2 +- monai/transforms/utility/dictionary.py | 4 +- monai/transforms/utils.py | 42 ++++++++----------- monai/utils/type_conversion.py | 2 +- tests/test_histogram_normalize.py | 13 +++--- tests/test_histogram_normalized.py | 13 +++--- 8 files changed, 73 insertions(+), 74 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index 0a6f87f273..8d7ebfbfe0 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -290,7 +290,7 @@ class StdShiftIntensity(Transform): nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. Please ensure that the first dimension represents the channel of the image if True. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. 
""" backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -323,7 +323,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - img, *_ = convert_data_type(img, dtype=self.dtype) + if self.dtype is not None: + img, *_ = convert_data_type(img, dtype=self.dtype) if self.channel_wise: for i, d in enumerate(img): img[i] = self._stdshift(d) # type: ignore @@ -355,7 +356,7 @@ def __init__( prob: probability of std shift. nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. """ RandomizableTransform.__init__(self, prob) @@ -416,7 +417,7 @@ def __init__( this parameter, please set `minv` and `maxv` into None. channel_wise: if True, scale on each channel separately. Please ensure that the first dimension represents the channel of the image if True. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. """ self.minv = minv self.maxv = maxv @@ -439,7 +440,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: return rescale_array(img, self.minv, self.maxv, dtype=self.dtype) if self.factor is not None: ret = img * (1 + self.factor) - ret, *_ = convert_data_type(ret, dtype=self.dtype) + if self.dtype is not None: + ret, *_ = convert_data_type(ret, dtype=self.dtype or img.dtype) return ret raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.") @@ -460,7 +462,7 @@ def __init__( factors: factor range to randomly scale by ``v = v * (1 + factor)``. if single number, factor value is picked from (-factors, factors). prob: probability of scale. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. """ RandomizableTransform.__init__(self, prob) @@ -507,7 +509,7 @@ class RandBiasField(RandomizableTransform): degree: degree of freedom of the polynomials. The value should be no less than 1. Defaults to 3. coeff_range: range of the random coefficients. Defaults to (0.0, 0.1). - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. prob: probability to do random bias field. """ @@ -580,7 +582,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen ) img_np, *_ = convert_data_type(img, np.ndarray) out = img_np * np.exp(_bias_fields) - out, *_ = convert_to_dst_type(src=out, dst=img, dtype=self.dtype) + out, *_ = convert_to_dst_type(src=out, dst=img, dtype=self.dtype or img.dtype) return out @@ -598,7 +600,7 @@ class NormalizeIntensity(Transform): nonzero: whether only normalize non-zero values. channel_wise: if using calculated mean and std, calculate on each channel separately or calculate on the entire image directly. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. 
""" backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -665,6 +667,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True, """ + dtype = self.dtype or img.dtype if self.channel_wise: if self.subtrahend is not None and len(self.subtrahend) != len(img): raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.") @@ -680,7 +683,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: else: img = self._normalize(img, self.subtrahend, self.divisor) - out, *_ = convert_data_type(img, dtype=self.dtype) + out, *_ = convert_data_type(img, dtype=dtype) return out @@ -725,7 +728,7 @@ class ScaleIntensityRange(Transform): b_min: intensity target range min. b_max: intensity target range max. clip: whether to perform clip after scaling. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] @@ -744,6 +747,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ + dtype = self.dtype or img.dtype if self.a_max - self.a_min == 0.0: warn("Divide by zero (a_min == a_max)", Warning) return img - self.a_min + self.b_min @@ -752,7 +756,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: img = img * (self.b_max - self.b_min) + self.b_min if self.clip: img = clip(img, self.b_min, self.b_max) - ret, *_ = convert_data_type(img, dtype=self.dtype) + ret, *_ = convert_data_type(img, dtype=dtype) return ret @@ -889,7 +893,7 @@ class ScaleIntensityRangePercentiles(Transform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max]. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. """ backend = ScaleIntensityRange.backend @@ -1983,7 +1987,7 @@ class HistogramNormalize(Transform): mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`. only points at which `mask==True` are used for the equalization. can also provide the mask along with img at runtime. - dtype: data type of the output, default to `float32`. + dtype: data type of the output, if None, same as input image. default to `float32`. 
""" @@ -2003,12 +2007,15 @@ def __init__( self.mask = mask self.dtype = dtype - def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None) -> np.ndarray: - return equalize_hist( - img=img, - mask=mask if mask is not None else self.mask, - num_bins=self.num_bins, - min=self.min, - max=self.max, - dtype=self.dtype, - ) + def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None) -> NdarrayOrTensor: + img_np: np.ndarray + img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore + mask = mask if mask is not None else self.mask + mask_np: Optional[np.ndarray] = None + if mask is not None: + mask_np, *_ = convert_data_type(mask, np.ndarray) # type: ignore + + ret = equalize_hist(img=img_np, mask=mask_np, num_bins=self.num_bins, min=self.min, max=self.max) + out, *_ = convert_to_dst_type(src=ret, dst=img, dtype=self.dtype or img.dtype) + + return out diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 0caa118652..3689d86689 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -412,7 +412,7 @@ def __init__( nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. Please ensure that the first dimension represents the channel of the image if True. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) @@ -451,7 +451,7 @@ def __init__( prob: probability of std shift. nonzero: whether only count non-zero values. channel_wise: if True, calculate on each channel separately. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ MapTransform.__init__(self, keys, allow_missing_keys) @@ -509,7 +509,7 @@ def __init__( this parameter, please set `minv` and `maxv` into None. channel_wise: if True, scale on each channel separately. Please ensure that the first dimension represents the channel of the image if True. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -546,7 +546,7 @@ def __init__( if single number, factor value is picked from (-factors, factors). prob: probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -597,7 +597,7 @@ def __init__( degree: degree of freedom of the polynomials. The value should be no less than 1. Defaults to 3. coeff_range: range of the random coefficients. Defaults to (0.0, 0.1). - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. prob: probability to do random bias field. allow_missing_keys: don't raise exception if key is missing. @@ -641,7 +641,7 @@ class NormalizeIntensityd(MapTransform): nonzero: whether only normalize non-zero values. channel_wise: if using calculated mean and std, calculate on each channel separately or calculate on the entire image directly. - dtype: output data type, defaults to float32. 
+ dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -712,7 +712,7 @@ class ScaleIntensityRanged(MapTransform): b_min: intensity target range min. b_max: intensity target range max. clip: whether to perform clip after scaling. - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -828,7 +828,7 @@ class ScaleIntensityRangePercentilesd(MapTransform): b_max: intensity target range max. clip: whether to perform clip after scaling. relative: whether to scale to the corresponding percentiles of [b_min, b_max] - dtype: output data type, defaults to float32. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: don't raise exception if key is missing. """ @@ -1553,7 +1553,7 @@ class HistogramNormalized(MapTransform): only points at which `mask==True` are used for the equalization. can also provide the mask by `mask_key` at runtime. mask_key: if mask is None, will try to get the mask with `mask_key`. - dtype: data type of the output, default to `float32`. + dtype: output data type, if None, same as input image. defaults to float32. allow_missing_keys: do not raise exception if key is missing. """ diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 4c54076056..50738978bb 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -403,7 +403,7 @@ class ToNumpy(Transform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, dtype: Optional[DtypeLike] = None) -> None: + def __init__(self, dtype: DtypeLike = None) -> None: super().__init__() self.dtype = dtype diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 87c6685512..1412790227 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -552,9 +552,7 @@ class ToNumpyd(MapTransform): backend = ToNumpy.backend - def __init__( - self, keys: KeysCollection, dtype: Optional[DtypeLike] = None, allow_missing_keys: bool = False - ) -> None: + def __init__(self, keys: KeysCollection, dtype: DtypeLike = None, allow_missing_keys: bool = False) -> None: """ Args: keys: keys of the corresponding items to be transformed. diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 21dc9422d3..1d3204b7a0 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -149,14 +149,15 @@ def zero_margins(img: np.ndarray, margin: int) -> bool: def rescale_array( - arr: NdarrayOrTensor, minv: float = 0.0, maxv: float = 1.0, dtype: Union[DtypeLike, torch.dtype] = np.float32 + arr: NdarrayOrTensor, + minv: float = 0.0, + maxv: float = 1.0, + dtype: Optional[Union[DtypeLike, torch.dtype]] = np.float32, ) -> NdarrayOrTensor: """ Rescale the values of numpy array `arr` to be from `minv` to `maxv`. 
""" - if dtype is not None: - arr, *_ = convert_data_type(arr, dtype=dtype) - + dtype_ = dtype or arr.dtype mina = arr.min() maxa = arr.max() @@ -164,7 +165,10 @@ def rescale_array( return arr * minv norm = (arr - mina) / (maxa - mina) # normalize the array first - return (norm * (maxv - minv)) + minv # rescale by minv and maxv, which is the normalized array by default + arr = (norm * (maxv - minv)) + minv # rescale by minv and maxv, which is the normalized array by default + + ret, *_ = convert_data_type(arr, dtype=dtype_) + return ret def rescale_instance_array( @@ -173,7 +177,7 @@ def rescale_instance_array( """ Rescale each array slice along the first dimension of `arr` independently. """ - out: np.ndarray = np.zeros(arr.shape, dtype) + out: np.ndarray = np.zeros(arr.shape, dtype or arr.dtype) for i in range(arr.shape[0]): out[i] = rescale_array(arr[i], minv, maxv, dtype) @@ -184,8 +188,8 @@ def rescale_array_int_max(arr: np.ndarray, dtype: DtypeLike = np.uint16) -> np.n """ Rescale the array `arr` to be between the minimum and maximum values of the type `dtype`. """ - info: np.iinfo = np.iinfo(dtype) - return np.asarray(rescale_array(arr, info.min, info.max), dtype=dtype) + info: np.iinfo = np.iinfo(dtype or arr.dtype) + return np.asarray(rescale_array(arr, info.min, info.max), dtype=dtype or arr.dtype) def copypaste_arrays( @@ -1234,12 +1238,7 @@ def compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Union[Sequen def equalize_hist( - img: NdarrayOrTensor, - mask: Optional[NdarrayOrTensor] = None, - num_bins: int = 256, - min: int = 0, - max: int = 255, - dtype: DtypeLike = np.float32, + img: np.ndarray, mask: Optional[np.ndarray] = None, num_bins: int = 256, min: int = 0, max: int = 255 ) -> np.ndarray: """ Utility to equalize input image based on the histogram. @@ -1254,17 +1253,11 @@ def equalize_hist( https://numpy.org/doc/stable/reference/generated/numpy.histogram.html. min: the min value to normalize input image, default to `0`. max: the max value to normalize input image, default to `255`. - dtype: data type of the output, default to `float32`. """ - img_np: np.ndarray - img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore - mask_np: Optional[np.ndarray] = None - if mask is not None: - mask_np, *_ = convert_data_type(mask, np.ndarray) # type: ignore - orig_shape = img_np.shape - hist_img = img_np[np.array(mask_np, dtype=bool)] if mask_np is not None else img_np + orig_shape = img.shape + hist_img = img[np.array(mask, dtype=bool)] if mask is not None else img if has_skimage: hist, bins = exposure.histogram(hist_img.flatten(), num_bins) else: @@ -1276,9 +1269,8 @@ def equalize_hist( cum = rescale_array(arr=cum, minv=min, maxv=max) # apply linear interpolation - img_np = np.interp(img_np.flatten(), bins, cum) - - return img_np.reshape(orig_shape).astype(dtype) + img = np.interp(img.flatten(), bins, cum) + return img.reshape(orig_shape) class Fourier: diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index d449cd8a03..42a9e247ad 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -130,7 +130,7 @@ def convert_to_tensor( return data -def convert_to_numpy(data, dtype: Optional[DtypeLike] = None, wrap_sequence: bool = False): +def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False): """ Utility to convert the input data to a numpy array. If passing a dictionary, list or tuple, recursively check every item and convert it to numpy array. 
diff --git a/tests/test_histogram_normalize.py b/tests/test_histogram_normalize.py index e0178166d9..06fe7e6956 100644 --- a/tests/test_histogram_normalize.py +++ b/tests/test_histogram_normalize.py @@ -15,7 +15,8 @@ from parameterized import parameterized from monai.transforms import HistogramNormalize -from tests.utils import TEST_NDARRAYS +from monai.utils import get_equivalent_dtype +from tests.utils import TEST_NDARRAYS, assert_allclose TESTS = [] for p in TEST_NDARRAYS: @@ -23,7 +24,7 @@ [ {"num_bins": 4, "min": 1, "max": 5, "mask": np.array([1, 1, 1, 1, 1, 0])}, p(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])), - np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), + p(np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0])), ] ) @@ -31,7 +32,7 @@ [ {"num_bins": 4, "max": 4, "dtype": np.uint8}, p(np.array([0.0, 1.0, 2.0, 3.0, 4.0])), - np.array([0, 0, 1, 3, 4]), + p(np.array([0, 0, 1, 3, 4])), ] ) @@ -39,7 +40,7 @@ [ {"num_bins": 256, "max": 255, "dtype": np.uint8}, p(np.array([[[100.0, 200.0], [150.0, 250.0]]])), - np.array([[[0, 170], [70, 255]]]), + p(np.array([[[0, 170], [70, 255]]])), ] ) @@ -48,8 +49,8 @@ class TestHistogramNormalize(unittest.TestCase): @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = HistogramNormalize(**argments)(image) - np.testing.assert_allclose(result, expected_data) - self.assertEqual(result.dtype, argments.get("dtype", np.float32)) + assert_allclose(result, expected_data) + self.assertEqual(get_equivalent_dtype(result.dtype, data_type=np.ndarray), argments.get("dtype", np.float32)) if __name__ == "__main__": diff --git a/tests/test_histogram_normalized.py b/tests/test_histogram_normalized.py index 314c7bd75b..e11ee77da5 100644 --- a/tests/test_histogram_normalized.py +++ b/tests/test_histogram_normalized.py @@ -15,7 +15,8 @@ from parameterized import parameterized from monai.transforms import HistogramNormalized -from tests.utils import TEST_NDARRAYS +from monai.utils import get_equivalent_dtype +from tests.utils import TEST_NDARRAYS, assert_allclose TESTS = [] for p in TEST_NDARRAYS: @@ -23,7 +24,7 @@ [ {"keys": "img", "num_bins": 4, "min": 1, "max": 5, "mask_key": "mask"}, {"img": p(np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])), "mask": p(np.array([1, 1, 1, 1, 1, 0]))}, - np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0]), + p(np.array([1.0, 1.5, 2.5, 4.0, 5.0, 5.0])), ] ) @@ -31,7 +32,7 @@ [ {"keys": "img", "num_bins": 4, "max": 4, "dtype": np.uint8}, {"img": p(np.array([0.0, 1.0, 2.0, 3.0, 4.0]))}, - np.array([0, 0, 1, 3, 4]), + p(np.array([0, 0, 1, 3, 4])), ] ) @@ -39,7 +40,7 @@ [ {"keys": "img", "num_bins": 256, "max": 255, "dtype": np.uint8}, {"img": p(np.array([[[100.0, 200.0], [150.0, 250.0]]]))}, - np.array([[[0, 170], [70, 255]]]), + p(np.array([[[0, 170], [70, 255]]])), ] ) @@ -48,8 +49,8 @@ class TestHistogramNormalized(unittest.TestCase): @parameterized.expand(TESTS) def test_value(self, argments, image, expected_data): result = HistogramNormalized(**argments)(image)["img"] - np.testing.assert_allclose(result, expected_data) - self.assertEqual(result.dtype, argments.get("dtype", np.float32)) + assert_allclose(result, expected_data) + self.assertEqual(get_equivalent_dtype(result.dtype, data_type=np.ndarray), argments.get("dtype", np.float32)) if __name__ == "__main__":
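Reviewer note (usage sketch, not part of the patches above): the snippet below exercises the reworked CSVIterableDataset from patches 1-7. The file name "data.csv" and its "label" column are placeholders chosen for illustration; any CSV readable by pandas works the same way.

    import sys

    from torch.utils.data import DataLoader

    from monai.data import CSVIterableDataset
    from monai.transforms import ToNumpyd

    # shuffle=False streams rows without random buffering; with num_workers > 0 each
    # worker picks every num_workers-th row, so the output order may be interleaved.
    ds = CSVIterableDataset("data.csv", shuffle=False, transform=ToNumpyd(keys="label"))
    loader = DataLoader(ds, num_workers=2 if sys.platform == "linux" else 0, batch_size=2)
    for batch in loader:
        print(batch["label"])

    # shuffle=True stores rows in a ShuffleBuffer and pops them at random positions;
    # the seed advances on every epoch, so runs with the same starting seed match.
    ds = CSVIterableDataset("data.csv", shuffle=True, seed=123, buffer_size=512)
    for row in ds:
        print(row)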
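Reviewer note (usage sketch): ShuffleBuffer, renamed from IterableBuffer in patch 4, can also wrap a plain Python iterable directly; range(20) below is an arbitrary stand-in for a data stream.

    from monai.data import ShuffleBuffer

    # the buffer fills up to buffer_size, then items are popped from random positions;
    # the seed is incremented on each __iter__ call so all DataLoader workers stay in sync.
    buffer = ShuffleBuffer(data=range(20), buffer_size=5, seed=0)
    for item in buffer:
        print(item)

Because the buffer consumes its source in a single pass, CSVIterableDataset builds a fresh ShuffleBuffer inside every __iter__ call rather than reusing one across epochs.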
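Reviewer note (usage sketch): the dtype argument introduced in patches 8 and 9, mirroring the updated tests; the input array below is arbitrary synthetic data.

    import numpy as np

    from monai.transforms import HistogramNormalize, ScaleIntensityRange

    img = np.arange(12, dtype=np.float32).reshape(1, 3, 4) * 20.0

    # rescale [20, 108] -> [50, 80], clip, and emit uint8 instead of the default float32
    scaler = ScaleIntensityRange(a_min=20, a_max=108, b_min=50, b_max=80, clip=True, dtype=np.uint8)
    print(scaler(img).dtype)  # uint8

    # histogram equalization now casts its output to the requested dtype as well
    norm = HistogramNormalize(num_bins=256, min=0, max=255, dtype=np.uint8)
    print(norm(img).dtype)  # uint8

With dtype=None these transforms now keep the input image's dtype instead of always converting to float32.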