diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 997e96a1b3..13e9ebaad6 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -16,17 +16,16 @@ import numpy as np import torch +from monai.config.type_definitions import NdarrayOrTensor from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset -from monai.data.utils import list_data_collate, pad_list_data_collate +from monai.data.utils import decollate_batch, pad_list_data_collate from monai.transforms.compose import Compose from monai.transforms.inverse import InvertibleTransform -from monai.transforms.inverse_batch_transform import BatchInverseTransform +from monai.transforms.post.dictionary import Invertd from monai.transforms.transform import Randomizable -from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils.enums import CommonKeys, PostFix, TraceKeys -from monai.utils.module import optional_import -from monai.utils.type_conversion import convert_data_type +from monai.transforms.utils_pytorch_numpy_unification import mode, stack +from monai.utils import CommonKeys, PostFix, optional_import if TYPE_CHECKING: from tqdm import tqdm @@ -80,19 +79,24 @@ class TestTimeAugmentation: For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. this arg only works when `meta_keys=None`. - return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will return the - full data. Dimensions will be same size as when passing a single image through `inferrer_fn`, with a dimension appended - equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. + to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. + output_device: if converted the inverted data to Tensor, move the inverted results to target device + before `post_func`, default to "cpu". + post_func: post processing for the inverted data, should be a callable function. + return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` + will return the full data. Dimensions will be same size as when passing a single image through + `inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. progress: whether to display a progress bar. Example: .. code-block:: python - transform = RandAffined(keys, ...) - post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) + model = UNet(...).to(device) + transform = Compose([RandAffined(keys, ...), ...]) + transform.set_random_state(seed=123) # ensure deterministic evaluation tt_aug = TestTimeAugmentation( - transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device + transform, batch_size=5, num_workers=0, inferrer_fn=model, device=device ) mode, mean, std, vvc = tt_aug(test_data) """ @@ -109,6 +113,9 @@ def __init__( nearest_interp: bool = True, orig_meta_keys: Optional[str] = None, meta_key_postfix=DEFAULT_POST_FIX, + to_tensor: bool = True, + output_device: Union[str, torch.device] = "cpu", + post_func: Callable = _identity, return_full_data: bool = False, progress: bool = True, ) -> None: @@ -118,12 +125,20 @@ def __init__( self.inferrer_fn = inferrer_fn self.device = device self.image_key = image_key - self.orig_key = orig_key - self.nearest_interp = nearest_interp - self.orig_meta_keys = orig_meta_keys - self.meta_key_postfix = meta_key_postfix self.return_full_data = return_full_data self.progress = progress + self._pred_key = CommonKeys.PRED + self.inverter = Invertd( + keys=self._pred_key, + transform=transform, + orig_keys=orig_key, + orig_meta_keys=orig_meta_keys, + meta_key_postfix=meta_key_postfix, + nearest_interp=nearest_interp, + to_tensor=to_tensor, + device=output_device, + post_func=post_func, + ) # check that the transform has at least one random component, and that all random transforms are invertible self._check_transforms() @@ -135,8 +150,8 @@ def _check_transforms(self): invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts]) # check at least 1 random if sum(randoms) == 0: - raise RuntimeError( - "Requires a `Randomizable` transform or a `Compose` containing at least one `Randomizable` transform." + warnings.warn( + "TTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms." ) # check that whenever randoms is True, invertibles is also true for r, i in zip(randoms, invertibles): @@ -147,18 +162,19 @@ def _check_transforms(self): def __call__( self, data: Dict[str, Any], num_examples: int = 10 - ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]: + ) -> Union[Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float], NdarrayOrTensor]: """ Args: data: dictionary data to be processed. num_examples: number of realisations to be processed and results combined. Returns: - - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated across - `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean` across the whole output, - including `num_examples`. See original paper for clarification. - - if `return_full_data==False`: data is returned as-is after applying the `inferrer_fn` and then concatenating across - the first dimension containing `num_examples`. This allows the user to perform their own analysis if desired. + - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are + calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC) + is `std/mean` across the whole output, including `num_examples`. See original paper for clarification. + - if `return_full_data==False`: data is returned as-is after applying the `inferrer_fn` and then + concatenating across the first dimension containing `num_examples`. This allows the user to perform + their own analysis if desired. """ d = dict(data) @@ -171,56 +187,22 @@ def __call__( ds = Dataset(data_in, self.transform) dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - transform_key = InvertibleTransform.trace_key(self.orig_key) - - # create inverter - inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate) - - outputs: List[np.ndarray] = [] + outs: List = [] for batch_data in tqdm(dl) if has_tqdm and self.progress else dl: - - batch_images = batch_data[self.image_key].to(self.device) - # do model forward pass - batch_output = self.inferrer_fn(batch_images) - if isinstance(batch_output, torch.Tensor): - batch_output = batch_output.detach().cpu() - if isinstance(batch_output, np.ndarray): - batch_output = torch.Tensor(batch_output) - transform_info = batch_data.get(transform_key, None) - if transform_info is None: - # no invertible transforms, adding dummy info for identity invertible - transform_info = [[TraceKeys.NONE] for _ in range(self.batch_size)] - if self.nearest_interp: - transform_info = convert_inverse_interp_mode( - trans_info=deepcopy(transform_info), mode="nearest", align_corners=None - ) + batch_data[self._pred_key] = self.inferrer_fn(batch_data[self.image_key].to(self.device)) + outs.extend([self.inverter(i)[self._pred_key] for i in decollate_batch(batch_data)]) - # create a dictionary containing the inferred batch and their transforms - inferred_dict = {self.orig_key: batch_output, transform_key: transform_info} - # if meta dict is present, add that too (required for some inverse transforms) - meta_dict_key = self.orig_meta_keys or f"{self.orig_key}_{self.meta_key_postfix}" - if meta_dict_key in batch_data: - inferred_dict[meta_dict_key] = batch_data[meta_dict_key] - - # do inverse transformation (allow missing keys as only inverting the orig_key) - with allow_missing_keys_mode(self.transform): # type: ignore - inv_batch = inverter(inferred_dict) - - # append - outputs.append(inv_batch[self.orig_key]) - - # output - output: np.ndarray = np.concatenate(outputs) + output: NdarrayOrTensor = stack(outs, 0) if self.return_full_data: return output # calculate metrics - output_t, *_ = convert_data_type(output, output_type=torch.Tensor, dtype=np.int64) - mode: np.ndarray = np.asarray(torch.mode(output_t, dim=0).values) # type: ignore - mean: np.ndarray = np.mean(output, axis=0) # type: ignore - std: np.ndarray = np.std(output, axis=0) # type: ignore - vvc: float = (np.std(output) / np.mean(output)).item() - return mode, mean, std, vvc + _mode = mode(output, dim=0) + mean = output.mean(0) + std = output.std(0) + vvc = (output.std() / output.mean()).item() + + return _mode, mean, std, vvc diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9b779ed18e..408edd6d56 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -562,11 +562,13 @@ isfinite, isnan, maximum, + mode, moveaxis, nonzero, percentile, ravel, repeat, + stack, unravel_index, where, ) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 90932553c4..99d348b5df 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -16,6 +16,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.utils.misc import ensure_tuple, is_module_ver_at_least +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type __all__ = [ "moveaxis", @@ -37,6 +38,8 @@ "repeat", "isnan", "ascontiguousarray", + "stack", + "mode", ] @@ -355,7 +358,7 @@ def isnan(x: NdarrayOrTensor) -> NdarrayOrTensor: Args: x: array/tensor - + dim: dimension along which to stack """ if isinstance(x, np.ndarray): return np.isnan(x) @@ -378,3 +381,31 @@ def ascontiguousarray(x: NdarrayOrTensor, **kwargs) -> NdarrayOrTensor: if isinstance(x, torch.Tensor): return x.contiguous(**kwargs) return x + + +def stack(x: Sequence[NdarrayOrTensor], dim: int) -> NdarrayOrTensor: + """`np.stack` with equivalent implementation for torch. + + Args: + x: array/tensor + dim: dimension along which to perform the stack (referred to as `axis` by numpy) + """ + if isinstance(x[0], np.ndarray): + return np.stack(x, dim) # type: ignore + return torch.stack(x, dim) # type: ignore + + +def mode(x: NdarrayOrTensor, dim: int = -1, to_long: bool = True) -> NdarrayOrTensor: + """`torch.mode` with equivalent implementation for numpy. + + Args: + x: array/tensor + dim: dimension along which to perform `mode` (referred to as `axis` by numpy) + to_long: convert input to long before performing mode. + """ + x_t: torch.Tensor + dtype = torch.int64 if to_long else None + x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype) # type: ignore + o_t = torch.mode(x_t, dim).values + o, *_ = convert_to_dst_type(o_t, x) + return o diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 9e22a5609f..8815354052 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -75,11 +75,12 @@ def tearDown(self) -> None: def test_test_time_augmentation(self): input_size = (20, 20) - device = "cuda" if torch.cuda.is_available() else "cpu" keys = ["image", "label"] num_training_ims = 10 + train_data = self.get_data(num_training_ims, input_size) test_data = self.get_data(1, input_size) + device = "cuda" if torch.cuda.is_available() else "cpu" transforms = Compose( [ @@ -125,21 +126,28 @@ def test_test_time_augmentation(self): post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) - def inferrer_fn(x): - return post_trans(model(x)) - - tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device) + tt_aug = TestTimeAugmentation( + transform=transforms, + batch_size=5, + num_workers=0, + inferrer_fn=model, + device=device, + to_tensor=True, + output_device="cpu", + post_func=post_trans, + ) mode, mean, std, vvc = tt_aug(test_data) self.assertEqual(mode.shape, (1,) + input_size) self.assertEqual(mean.shape, (1,) + input_size) self.assertTrue(all(np.unique(mode) == (0, 1))) - self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) + self.assertGreaterEqual(mean.min(), 0.0) + self.assertLessEqual(mean.max(), 1.0) self.assertEqual(std.shape, (1,) + input_size) self.assertIsInstance(vvc, float) - def test_fail_non_random(self): + def test_warn_non_random(self): transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) - with self.assertRaises(RuntimeError): + with self.assertWarns(UserWarning): TestTimeAugmentation(transforms, None, None, None) def test_warn_random_but_has_no_invertible(self): diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 4db3056b7b..b13378debe 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -13,11 +13,18 @@ import numpy as np import torch +from parameterized import parameterized -from monai.transforms.utils_pytorch_numpy_unification import percentile +from monai.transforms.utils_pytorch_numpy_unification import mode, percentile from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose +TEST_MODE = [] +for p in TEST_NDARRAYS: + TEST_MODE.append([p(np.array([1, 2, 3, 4, 4, 5])), p(4), False]) + TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False]) + TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True]) + class TestPytorchNumpyUnification(unittest.TestCase): def setUp(self) -> None: @@ -54,6 +61,11 @@ def test_dim(self): atol = 0.5 if not hasattr(torch, "quantile") else 1e-4 assert_allclose(results[0], results[-1], type_test=False, atol=atol) + @parameterized.expand(TEST_MODE) + def test_mode(self, array, expected, to_long): + res = mode(array, to_long=to_long) + assert_allclose(res, expected) + if __name__ == "__main__": unittest.main()