From ecd3b76525ec2c1fa0c04866131666b4d7dc79b4 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 26 Jan 2022 17:45:02 +0800 Subject: [PATCH 01/10] [DLMED] refine TTA Signed-off-by: Nic Ma --- monai/data/test_time_augmentation.py | 99 +++++++++++----------------- 1 file changed, 40 insertions(+), 59 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 997e96a1b3..875e5c7210 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -16,17 +16,17 @@ import numpy as np import torch +from monai.config.type_definitions import NdarrayOrTensor from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset -from monai.data.utils import list_data_collate, pad_list_data_collate +from monai.data.utils import decollate_batch, pad_list_data_collate +from monai.handlers.utils import from_engine from monai.transforms.compose import Compose from monai.transforms.inverse import InvertibleTransform -from monai.transforms.inverse_batch_transform import BatchInverseTransform +from monai.transforms.post.dictionary import Invertd from monai.transforms.transform import Randomizable -from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils.enums import CommonKeys, PostFix, TraceKeys +from monai.utils.enums import CommonKeys, PostFix from monai.utils.module import optional_import -from monai.utils.type_conversion import convert_data_type if TYPE_CHECKING: from tqdm import tqdm @@ -80,19 +80,22 @@ class TestTimeAugmentation: For example, to handle key `image`, read/write affine matrices from the metadata `image_meta_dict` dictionary's `affine` field. this arg only works when `meta_keys=None`. - return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will return the - full data. Dimensions will be same size as when passing a single image through `inferrer_fn`, with a dimension appended - equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. + to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. + output_device: if converted the inverted data to Tensor, move the inverted results to target device + before `post_func`, default to "cpu". + post_func: post processing for the inverted data, should be a callable function. + return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` + will return the full data. Dimensions will be same size as when passing a single image through + `inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`. progress: whether to display a progress bar. Example: .. code-block:: python transform = RandAffined(keys, ...) - post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) tt_aug = TestTimeAugmentation( - transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device + transform, batch_size=5, num_workers=0, inferrer_fn=model, device=device ) mode, mean, std, vvc = tt_aug(test_data) """ @@ -109,6 +112,9 @@ def __init__( nearest_interp: bool = True, orig_meta_keys: Optional[str] = None, meta_key_postfix=DEFAULT_POST_FIX, + to_tensor: bool = True, + output_device: Union[str, torch.device] = "cpu", + post_func: Callable = lambda x: x, return_full_data: bool = False, progress: bool = True, ) -> None: @@ -118,12 +124,20 @@ def __init__( self.inferrer_fn = inferrer_fn self.device = device self.image_key = image_key - self.orig_key = orig_key - self.nearest_interp = nearest_interp - self.orig_meta_keys = orig_meta_keys - self.meta_key_postfix = meta_key_postfix self.return_full_data = return_full_data self.progress = progress + self._pred_key = CommonKeys.PRED + self.inverter = Invertd( + keys=self._pred_key, + transform=transform, + orig_keys=orig_key, + orig_meta_keys=orig_meta_keys, + meta_key_postfix=meta_key_postfix, + nearest_interp=nearest_interp, + to_tensor=to_tensor, + device=output_device, + post_func=post_func, + ) # check that the transform has at least one random component, and that all random transforms are invertible self._check_transforms() @@ -154,11 +168,12 @@ def __call__( num_examples: number of realisations to be processed and results combined. Returns: - - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated across - `num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean` across the whole output, - including `num_examples`. See original paper for clarification. - - if `return_full_data==False`: data is returned as-is after applying the `inferrer_fn` and then concatenating across - the first dimension containing `num_examples`. This allows the user to perform their own analysis if desired. + - if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are + calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC) + is `std/mean` across the whole output, including `num_examples`. See original paper for clarification. + - if `return_full_data==False`: data is returned as-is after applying the `inferrer_fn` and then + concatenating across the first dimension containing `num_examples`. This allows the user to perform + their own analysis if desired. """ d = dict(data) @@ -171,55 +186,21 @@ def __call__( ds = Dataset(data_in, self.transform) dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - transform_key = InvertibleTransform.trace_key(self.orig_key) - - # create inverter - inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate) - - outputs: List[np.ndarray] = [] + outs: List[NdarrayOrTensor] = [] for batch_data in tqdm(dl) if has_tqdm and self.progress else dl: - - batch_images = batch_data[self.image_key].to(self.device) - # do model forward pass - batch_output = self.inferrer_fn(batch_images) - if isinstance(batch_output, torch.Tensor): - batch_output = batch_output.detach().cpu() - if isinstance(batch_output, np.ndarray): - batch_output = torch.Tensor(batch_output) - transform_info = batch_data.get(transform_key, None) - if transform_info is None: - # no invertible transforms, adding dummy info for identity invertible - transform_info = [[TraceKeys.NONE] for _ in range(self.batch_size)] - if self.nearest_interp: - transform_info = convert_inverse_interp_mode( - trans_info=deepcopy(transform_info), mode="nearest", align_corners=None - ) - - # create a dictionary containing the inferred batch and their transforms - inferred_dict = {self.orig_key: batch_output, transform_key: transform_info} - # if meta dict is present, add that too (required for some inverse transforms) - meta_dict_key = self.orig_meta_keys or f"{self.orig_key}_{self.meta_key_postfix}" - if meta_dict_key in batch_data: - inferred_dict[meta_dict_key] = batch_data[meta_dict_key] - - # do inverse transformation (allow missing keys as only inverting the orig_key) - with allow_missing_keys_mode(self.transform): # type: ignore - inv_batch = inverter(inferred_dict) - - # append - outputs.append(inv_batch[self.orig_key]) + batch_data[self._pred_key] = self.inferrer_fn(batch_data[self.image_key].to(self.device)) + result = [self.inverter(i) for i in decollate_batch(batch_data)] + outs.append(from_engine(keys=self._pred_key)(result)) - # output - output: np.ndarray = np.concatenate(outputs) + output: NdarrayOrTensor = np.stack(outs, 0) if isinstance(outs[0], np.np.ndarray) else torch.stack(outs, 0) if self.return_full_data: return output # calculate metrics - output_t, *_ = convert_data_type(output, output_type=torch.Tensor, dtype=np.int64) - mode: np.ndarray = np.asarray(torch.mode(output_t, dim=0).values) # type: ignore + mode: np.ndarray = np.asarray(torch.mode(output, dim=0).values) # type: ignore mean: np.ndarray = np.mean(output, axis=0) # type: ignore std: np.ndarray = np.std(output, axis=0) # type: ignore vvc: float = (np.std(output) / np.mean(output)).item() From a9d837138cc7087bbb3c2c433cfb3ee0ddb01d2c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 27 Jan 2022 23:33:30 +0800 Subject: [PATCH 02/10] [DLMED] enhance TTA Signed-off-by: Nic Ma --- monai/data/test_time_augmentation.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 875e5c7210..934bb20341 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -20,13 +20,11 @@ from monai.data.dataloader import DataLoader from monai.data.dataset import Dataset from monai.data.utils import decollate_batch, pad_list_data_collate -from monai.handlers.utils import from_engine from monai.transforms.compose import Compose from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.dictionary import Invertd from monai.transforms.transform import Randomizable -from monai.utils.enums import CommonKeys, PostFix -from monai.utils.module import optional_import +from monai.utils import CommonKeys, PostFix, convert_data_type, convert_to_dst_type, optional_import if TYPE_CHECKING: from tqdm import tqdm @@ -191,17 +189,17 @@ def __call__( for batch_data in tqdm(dl) if has_tqdm and self.progress else dl: # do model forward pass batch_data[self._pred_key] = self.inferrer_fn(batch_data[self.image_key].to(self.device)) - result = [self.inverter(i) for i in decollate_batch(batch_data)] - outs.append(from_engine(keys=self._pred_key)(result)) + outs.extend([self.inverter(i)[self._pred_key] for i in decollate_batch(batch_data)]) - output: NdarrayOrTensor = np.stack(outs, 0) if isinstance(outs[0], np.np.ndarray) else torch.stack(outs, 0) + output: NdarrayOrTensor = np.stack(outs, 0) if isinstance(outs[0], np.ndarray) else torch.stack(outs, 0) if self.return_full_data: return output # calculate metrics - mode: np.ndarray = np.asarray(torch.mode(output, dim=0).values) # type: ignore - mean: np.ndarray = np.mean(output, axis=0) # type: ignore - std: np.ndarray = np.std(output, axis=0) # type: ignore - vvc: float = (np.std(output) / np.mean(output)).item() + output_t, *_ = convert_data_type(output, output_type=torch.Tensor) + mode, *_ = convert_to_dst_type(torch.mode(output_t.long(), dim=0).values, output, dtype=torch.int64) + mean, *_ = convert_to_dst_type(torch.mean(output_t, dim=0), output) + std, *_ = convert_to_dst_type(torch.std(output_t, dim=0), output) + vvc: float = (torch.std(output_t) / torch.mean(output_t)).item() return mode, mean, std, vvc From 5016de5bbd860b753a7853c3fb3013f53fea3063 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 28 Jan 2022 00:26:36 +0800 Subject: [PATCH 03/10] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/data/test_time_augmentation.py | 8 +++++--- tests/test_testtimeaugmentation.py | 10 +++++++++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 934bb20341..62eccb16c7 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -90,6 +90,7 @@ class TestTimeAugmentation: Example: .. code-block:: python + model = UNet(...).to(device) transform = RandAffined(keys, ...) tt_aug = TestTimeAugmentation( @@ -159,7 +160,7 @@ def _check_transforms(self): def __call__( self, data: Dict[str, Any], num_examples: int = 10 - ) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]: + ) -> Union[Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float], NdarrayOrTensor]: """ Args: data: dictionary data to be processed. @@ -184,7 +185,7 @@ def __call__( ds = Dataset(data_in, self.transform) dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - outs: List[NdarrayOrTensor] = [] + outs: List = [] for batch_data in tqdm(dl) if has_tqdm and self.progress else dl: # do model forward pass @@ -197,7 +198,8 @@ def __call__( return output # calculate metrics - output_t, *_ = convert_data_type(output, output_type=torch.Tensor) + output_t: torch.Tensor + output_t, *_ = convert_data_type(output, output_type=torch.Tensor) # type: ignore mode, *_ = convert_to_dst_type(torch.mode(output_t.long(), dim=0).values, output, dtype=torch.int64) mean, *_ = convert_to_dst_type(torch.mean(output_t, dim=0), output) std, *_ = convert_to_dst_type(torch.std(output_t, dim=0), output) diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 9e22a5609f..a0ceb4fa55 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -128,7 +128,15 @@ def test_test_time_augmentation(self): def inferrer_fn(x): return post_trans(model(x)) - tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device) + tt_aug = TestTimeAugmentation( + transform=transforms, + batch_size=5, + num_workers=0, + inferrer_fn=inferrer_fn, + device=device, + to_tensor=True, + output_device="cpu", + ) mode, mean, std, vvc = tt_aug(test_data) self.assertEqual(mode.shape, (1,) + input_size) self.assertEqual(mean.shape, (1,) + input_size) From 4e05fe8305070475b15784f5b6dd16958419c3f9 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 28 Jan 2022 10:04:15 +0800 Subject: [PATCH 04/10] [DLMED] refine example in doc-string Signed-off-by: Nic Ma --- monai/data/test_time_augmentation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 62eccb16c7..73937661f3 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -91,7 +91,8 @@ class TestTimeAugmentation: .. code-block:: python model = UNet(...).to(device) - transform = RandAffined(keys, ...) + transform = Compose([RandAffined(keys, ...), ...]) + transform.set_random_state(seed=123) # ensure deterministic evaluation tt_aug = TestTimeAugmentation( transform, batch_size=5, num_workers=0, inferrer_fn=model, device=device From b9c01bdb21ecd4baf3d6127fe478ca9e8e693799 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 28 Jan 2022 22:58:40 +0800 Subject: [PATCH 05/10] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/data/test_time_augmentation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 73937661f3..b1085abff6 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -114,7 +114,7 @@ def __init__( meta_key_postfix=DEFAULT_POST_FIX, to_tensor: bool = True, output_device: Union[str, torch.device] = "cpu", - post_func: Callable = lambda x: x, + post_func: Callable = _identity, return_full_data: bool = False, progress: bool = True, ) -> None: From 4a87197c3b7b9f0241cf52a46a56444ae9c415e2 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 28 Jan 2022 15:47:13 +0000 Subject: [PATCH 06/10] add to optimize tta Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 14 +-- tests/test_testtimeaugmentation.py | 148 ++++++++++++++------------- 2 files changed, 84 insertions(+), 78 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index b1085abff6..48e1a73b5e 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -24,7 +24,7 @@ from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.dictionary import Invertd from monai.transforms.transform import Randomizable -from monai.utils import CommonKeys, PostFix, convert_data_type, convert_to_dst_type, optional_import +from monai.utils import CommonKeys, PostFix, optional_import if TYPE_CHECKING: from tqdm import tqdm @@ -199,10 +199,10 @@ def __call__( return output # calculate metrics - output_t: torch.Tensor - output_t, *_ = convert_data_type(output, output_type=torch.Tensor) # type: ignore - mode, *_ = convert_to_dst_type(torch.mode(output_t.long(), dim=0).values, output, dtype=torch.int64) - mean, *_ = convert_to_dst_type(torch.mean(output_t, dim=0), output) - std, *_ = convert_to_dst_type(torch.std(output_t, dim=0), output) - vvc: float = (torch.std(output_t) / torch.mean(output_t)).item() + mode = output.long().mode(0).values if isinstance(output, torch.Tensor) else output.astype(np.int64).mode(0) # type: ignore + mean = output.mean(0) + std = output.std(0) + vvc = output.std() / output.mean() + vvc = vvc.item() if isinstance(vvc, torch.Tensor) else vvc + return mode, mean, std, vvc diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index a0ceb4fa55..4e46837427 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -35,7 +35,7 @@ from monai.transforms.spatial.dictionary import RandFlipd, Spacingd from monai.utils import optional_import, set_determinism from monai.utils.enums import PostFix -from tests.utils import TEST_NDARRAYS +from tests.utils import TEST_NDARRAYS, assert_allclose if TYPE_CHECKING: import tqdm @@ -74,76 +74,82 @@ def tearDown(self) -> None: set_determinism(None) def test_test_time_augmentation(self): - input_size = (20, 20) - device = "cuda" if torch.cuda.is_available() else "cpu" - keys = ["image", "label"] - num_training_ims = 10 - train_data = self.get_data(num_training_ims, input_size) - test_data = self.get_data(1, input_size) - - transforms = Compose( - [ - AddChanneld(keys), - RandAffined( - keys, - prob=1.0, - spatial_size=(30, 30), - rotate_range=(np.pi / 3, np.pi / 3), - translate_range=(3, 3), - scale_range=((0.8, 1), (0.8, 1)), - padding_mode="zeros", - mode=("bilinear", "nearest"), - as_tensor_output=False, - ), - CropForegroundd(keys, source_key="image"), - DivisiblePadd(keys, 4), - ] - ) - - train_ds = CacheDataset(train_data, transforms) - # output might be different size, so pad so that they match - train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate) - - model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) - loss_function = DiceLoss(sigmoid=True) - optimizer = torch.optim.Adam(model.parameters(), 1e-3) - - num_epochs = 10 - for _ in trange(num_epochs): - epoch_loss = 0 - - for batch_data in train_loader: - inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_function(outputs, labels) - loss.backward() - optimizer.step() - epoch_loss += loss.item() - - epoch_loss /= len(train_loader) - - post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) - - def inferrer_fn(x): - return post_trans(model(x)) - - tt_aug = TestTimeAugmentation( - transform=transforms, - batch_size=5, - num_workers=0, - inferrer_fn=inferrer_fn, - device=device, - to_tensor=True, - output_device="cpu", - ) - mode, mean, std, vvc = tt_aug(test_data) - self.assertEqual(mode.shape, (1,) + input_size) - self.assertEqual(mean.shape, (1,) + input_size) - self.assertTrue(all(np.unique(mode) == (0, 1))) - self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) - self.assertEqual(std.shape, (1,) + input_size) - self.assertIsInstance(vvc, float) + results = [] + for data_type in TEST_NDARRAYS: + input_size = (20, 20) + keys = ["image", "label"] + num_training_ims = 10 + train_data = self.get_data(num_training_ims, input_size, data_type) + test_data = self.get_data(1, input_size, data_type) + device = test_data.device if isinstance(test_data, torch.Tensor) else "cpu" + + transforms = Compose( + [ + AddChanneld(keys), + RandAffined( + keys, + prob=1.0, + spatial_size=(30, 30), + rotate_range=(np.pi / 3, np.pi / 3), + translate_range=(3, 3), + scale_range=((0.8, 1), (0.8, 1)), + padding_mode="zeros", + mode=("bilinear", "nearest"), + as_tensor_output=False, + ), + CropForegroundd(keys, source_key="image"), + DivisiblePadd(keys, 4), + ] + ) + + train_ds = CacheDataset(train_data, transforms) + # output might be different size, so pad so that they match + train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate) + + model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) + loss_function = DiceLoss(sigmoid=True) + optimizer = torch.optim.Adam(model.parameters(), 1e-3) + + num_epochs = 10 + for _ in trange(num_epochs): + epoch_loss = 0 + + for batch_data in train_loader: + inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + + epoch_loss /= len(train_loader) + + post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) + + tt_aug = TestTimeAugmentation( + transform=transforms, + batch_size=5, + num_workers=0, + inferrer_fn=model, + device=device, + to_tensor=True, + output_device="cpu", + post_func=post_trans, + ) + result = tt_aug(test_data) + mode, mean, std, vvc = result + self.assertEqual(mode.shape, (1,) + input_size) + self.assertEqual(mean.shape, (1,) + input_size) + self.assertTrue(all(np.unique(mode) == (0, 1))) + self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) + self.assertEqual(std.shape, (1,) + input_size) + self.assertIsInstance(vvc, float) + results.append(result) + + for r in results: + for r1, r2 in zip(results[0], r): + assert_allclose(r1, r2) def test_fail_non_random(self): transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) From 1016d38a48f694e32684ab31a31559d404b10145 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 31 Jan 2022 12:36:11 +0000 Subject: [PATCH 07/10] update Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 7 +- monai/transforms/__init__.py | 2 + .../utils_pytorch_numpy_unification.py | 33 +++- tests/test_testtimeaugmentation.py | 148 +++++++++--------- tests/test_utils_pytorch_numpy_unification.py | 14 +- 5 files changed, 122 insertions(+), 82 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 48e1a73b5e..72526e38ea 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -24,6 +24,7 @@ from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.dictionary import Invertd from monai.transforms.transform import Randomizable +from monai.transforms.utils_pytorch_numpy_unification import mode, stack from monai.utils import CommonKeys, PostFix, optional_import if TYPE_CHECKING: @@ -193,16 +194,16 @@ def __call__( batch_data[self._pred_key] = self.inferrer_fn(batch_data[self.image_key].to(self.device)) outs.extend([self.inverter(i)[self._pred_key] for i in decollate_batch(batch_data)]) - output: NdarrayOrTensor = np.stack(outs, 0) if isinstance(outs[0], np.ndarray) else torch.stack(outs, 0) + output: NdarrayOrTensor = stack(outs, 0) if self.return_full_data: return output # calculate metrics - mode = output.long().mode(0).values if isinstance(output, torch.Tensor) else output.astype(np.int64).mode(0) # type: ignore + _mode = mode(output, dim=0) mean = output.mean(0) std = output.std(0) vvc = output.std() / output.mean() vvc = vvc.item() if isinstance(vvc, torch.Tensor) else vvc - return mode, mean, std, vvc + return _mode, mean, std, vvc diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 9b779ed18e..408edd6d56 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -562,11 +562,13 @@ isfinite, isnan, maximum, + mode, moveaxis, nonzero, percentile, ravel, repeat, + stack, unravel_index, where, ) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 90932553c4..4d8803db23 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -16,6 +16,7 @@ from monai.config.type_definitions import NdarrayOrTensor from monai.utils.misc import ensure_tuple, is_module_ver_at_least +from monai.utils.type_conversion import convert_data_type, convert_to_dst_type __all__ = [ "moveaxis", @@ -37,6 +38,8 @@ "repeat", "isnan", "ascontiguousarray", + "stack", + "mode", ] @@ -355,7 +358,7 @@ def isnan(x: NdarrayOrTensor) -> NdarrayOrTensor: Args: x: array/tensor - + dim: dimension along which to stack """ if isinstance(x, np.ndarray): return np.isnan(x) @@ -378,3 +381,31 @@ def ascontiguousarray(x: NdarrayOrTensor, **kwargs) -> NdarrayOrTensor: if isinstance(x, torch.Tensor): return x.contiguous(**kwargs) return x + + +def stack(x: Sequence[NdarrayOrTensor], dim: int) -> NdarrayOrTensor: + """`np.stack` with equivalent implementation for torch. + + Args: + x: array/tensor + dim: dimension along which to perform the stack (referred to as `axis` by numpy) + """ + if isinstance(x[0], np.ndarray): + return np.stack(x, dim) + return torch.stack(x, dim) # type: ignore + + +def mode(x: NdarrayOrTensor, dim: int = -1, to_long: bool = True) -> NdarrayOrTensor: + """`torch.mode` with equivalent implementation for numpy. + + Args: + x: array/tensor + dim: dimension along which to perform `mode` (referred to as `axis` by numpy) + to_long: convert input to long before performing mode. + """ + x_t: torch.Tensor + dtype = torch.int64 if to_long else None + x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype) # type: ignore + o_t = torch.mode(x_t, dim).values + o, *_ = convert_to_dst_type(o_t, x) + return o diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 4e46837427..24642c0510 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -35,7 +35,7 @@ from monai.transforms.spatial.dictionary import RandFlipd, Spacingd from monai.utils import optional_import, set_determinism from monai.utils.enums import PostFix -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS if TYPE_CHECKING: import tqdm @@ -74,82 +74,76 @@ def tearDown(self) -> None: set_determinism(None) def test_test_time_augmentation(self): - results = [] - for data_type in TEST_NDARRAYS: - input_size = (20, 20) - keys = ["image", "label"] - num_training_ims = 10 - train_data = self.get_data(num_training_ims, input_size, data_type) - test_data = self.get_data(1, input_size, data_type) - device = test_data.device if isinstance(test_data, torch.Tensor) else "cpu" - - transforms = Compose( - [ - AddChanneld(keys), - RandAffined( - keys, - prob=1.0, - spatial_size=(30, 30), - rotate_range=(np.pi / 3, np.pi / 3), - translate_range=(3, 3), - scale_range=((0.8, 1), (0.8, 1)), - padding_mode="zeros", - mode=("bilinear", "nearest"), - as_tensor_output=False, - ), - CropForegroundd(keys, source_key="image"), - DivisiblePadd(keys, 4), - ] - ) - - train_ds = CacheDataset(train_data, transforms) - # output might be different size, so pad so that they match - train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate) - - model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) - loss_function = DiceLoss(sigmoid=True) - optimizer = torch.optim.Adam(model.parameters(), 1e-3) - - num_epochs = 10 - for _ in trange(num_epochs): - epoch_loss = 0 - - for batch_data in train_loader: - inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device) - optimizer.zero_grad() - outputs = model(inputs) - loss = loss_function(outputs, labels) - loss.backward() - optimizer.step() - epoch_loss += loss.item() - - epoch_loss /= len(train_loader) - - post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) - - tt_aug = TestTimeAugmentation( - transform=transforms, - batch_size=5, - num_workers=0, - inferrer_fn=model, - device=device, - to_tensor=True, - output_device="cpu", - post_func=post_trans, - ) - result = tt_aug(test_data) - mode, mean, std, vvc = result - self.assertEqual(mode.shape, (1,) + input_size) - self.assertEqual(mean.shape, (1,) + input_size) - self.assertTrue(all(np.unique(mode) == (0, 1))) - self.assertEqual((mean.min(), mean.max()), (0.0, 1.0)) - self.assertEqual(std.shape, (1,) + input_size) - self.assertIsInstance(vvc, float) - results.append(result) - - for r in results: - for r1, r2 in zip(results[0], r): - assert_allclose(r1, r2) + input_size = (20, 20) + keys = ["image", "label"] + num_training_ims = 10 + + train_data = self.get_data(num_training_ims, input_size) + test_data = self.get_data(1, input_size) + device = "cuda" if torch.cuda.is_available() else "cpu" + + transforms = Compose( + [ + AddChanneld(keys), + RandAffined( + keys, + prob=1.0, + spatial_size=(30, 30), + rotate_range=(np.pi / 3, np.pi / 3), + translate_range=(3, 3), + scale_range=((0.8, 1), (0.8, 1)), + padding_mode="zeros", + mode=("bilinear", "nearest"), + as_tensor_output=False, + ), + CropForegroundd(keys, source_key="image"), + DivisiblePadd(keys, 4), + ] + ) + + train_ds = CacheDataset(train_data, transforms) + # output might be different size, so pad so that they match + train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate) + + model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device) + loss_function = DiceLoss(sigmoid=True) + optimizer = torch.optim.Adam(model.parameters(), 1e-3) + + num_epochs = 10 + for _ in trange(num_epochs): + epoch_loss = 0 + + for batch_data in train_loader: + inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device) + optimizer.zero_grad() + outputs = model(inputs) + loss = loss_function(outputs, labels) + loss.backward() + optimizer.step() + epoch_loss += loss.item() + + epoch_loss /= len(train_loader) + + post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)]) + + tt_aug = TestTimeAugmentation( + transform=transforms, + batch_size=5, + num_workers=0, + inferrer_fn=model, + device=device, + to_tensor=True, + output_device="cpu", + post_func=post_trans, + ) + mode, mean, std, vvc = tt_aug(test_data) + self.assertEqual(mode.shape, (1,) + input_size) + self.assertEqual(mean.shape, (1,) + input_size) + self.assertTrue(all(np.unique(mode) == (0, 1))) + self.assertGreaterEqual(mean.min(), 0.0) + self.assertLessEqual(mean.max(), 1.0) + self.assertEqual(std.shape, (1,) + input_size) + self.assertIsInstance(vvc, float) def test_fail_non_random(self): transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) diff --git a/tests/test_utils_pytorch_numpy_unification.py b/tests/test_utils_pytorch_numpy_unification.py index 4db3056b7b..b13378debe 100644 --- a/tests/test_utils_pytorch_numpy_unification.py +++ b/tests/test_utils_pytorch_numpy_unification.py @@ -13,11 +13,18 @@ import numpy as np import torch +from parameterized import parameterized -from monai.transforms.utils_pytorch_numpy_unification import percentile +from monai.transforms.utils_pytorch_numpy_unification import mode, percentile from monai.utils import set_determinism from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose +TEST_MODE = [] +for p in TEST_NDARRAYS: + TEST_MODE.append([p(np.array([1, 2, 3, 4, 4, 5])), p(4), False]) + TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False]) + TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True]) + class TestPytorchNumpyUnification(unittest.TestCase): def setUp(self) -> None: @@ -54,6 +61,11 @@ def test_dim(self): atol = 0.5 if not hasattr(torch, "quantile") else 1e-4 assert_allclose(results[0], results[-1], type_test=False, atol=atol) + @parameterized.expand(TEST_MODE) + def test_mode(self, array, expected, to_long): + res = mode(array, to_long=to_long) + assert_allclose(res, expected) + if __name__ == "__main__": unittest.main() From a8a8a26b1791292b962a706eec5642f7d73f31de Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 31 Jan 2022 14:07:07 +0000 Subject: [PATCH 08/10] .item for np.ndarray too Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/test_time_augmentation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 72526e38ea..8a6947231f 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -203,7 +203,6 @@ def __call__( _mode = mode(output, dim=0) mean = output.mean(0) std = output.std(0) - vvc = output.std() / output.mean() - vvc = vvc.item() if isinstance(vvc, torch.Tensor) else vvc + vvc = (output.std() / output.mean()).item() return _mode, mean, std, vvc From 7a89517d900c0a78b177acb765313b682232e555 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 1 Feb 2022 09:58:55 +0800 Subject: [PATCH 09/10] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/transforms/utils_pytorch_numpy_unification.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 4d8803db23..99d348b5df 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -391,7 +391,7 @@ def stack(x: Sequence[NdarrayOrTensor], dim: int) -> NdarrayOrTensor: dim: dimension along which to perform the stack (referred to as `axis` by numpy) """ if isinstance(x[0], np.ndarray): - return np.stack(x, dim) + return np.stack(x, dim) # type: ignore return torch.stack(x, dim) # type: ignore From a06d0b50c57a7eeb4a362a8eb1390cf602a9be1b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 1 Feb 2022 21:40:35 +0800 Subject: [PATCH 10/10] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/data/test_time_augmentation.py | 4 ++-- tests/test_testtimeaugmentation.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 8a6947231f..13e9ebaad6 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -150,8 +150,8 @@ def _check_transforms(self): invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts]) # check at least 1 random if sum(randoms) == 0: - raise RuntimeError( - "Requires a `Randomizable` transform or a `Compose` containing at least one `Randomizable` transform." + warnings.warn( + "TTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms." ) # check that whenever randoms is True, invertibles is also true for r, i in zip(randoms, invertibles): diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py index 24642c0510..8815354052 100644 --- a/tests/test_testtimeaugmentation.py +++ b/tests/test_testtimeaugmentation.py @@ -145,9 +145,9 @@ def test_test_time_augmentation(self): self.assertEqual(std.shape, (1,) + input_size) self.assertIsInstance(vvc, float) - def test_fail_non_random(self): + def test_warn_non_random(self): transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)]) - with self.assertRaises(RuntimeError): + with self.assertWarns(UserWarning): TestTimeAugmentation(transforms, None, None, None) def test_warn_random_but_has_no_invertible(self):