Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
42a45e0
Merge pull request #19 from Project-MONAI/master
Nic-Ma Feb 1, 2021
cd16a13
Merge pull request #32 from Project-MONAI/master
Nic-Ma Feb 24, 2021
6f87afd
Merge pull request #180 from Project-MONAI/dev
Nic-Ma Jul 22, 2021
f398298
Merge pull request #214 from Project-MONAI/dev
Nic-Ma Sep 8, 2021
2241fea
Merge pull request #359 from Project-MONAI/dev
Nic-Ma Jan 26, 2022
ecd3b76
[DLMED] refine TTA
Nic-Ma Jan 26, 2022
a9d8371
[DLMED] enhance TTA
Nic-Ma Jan 27, 2022
5016de5
[DLMED] fix flake8
Nic-Ma Jan 27, 2022
4e05fe8
[DLMED] refine example in doc-string
Nic-Ma Jan 28, 2022
463ffd4
Merge branch 'dev' into optimize-tta
Nic-Ma Jan 28, 2022
c280884
Merge branch 'dev' into optimize-tta
Nic-Ma Jan 28, 2022
b9c01bd
[DLMED] update according to comments
Nic-Ma Jan 28, 2022
4a87197
add to optimize tta
rijobro Jan 28, 2022
af4538e
Merge branch 'dev' into optimize-tta
rijobro Jan 31, 2022
1016d38
update
rijobro Jan 31, 2022
ae44ba3
Merge remote-tracking branch 'NicMa/optimize-tta' into add-to-optimiz…
rijobro Jan 31, 2022
a8a8a26
.item for np.ndarray too
rijobro Jan 31, 2022
b94b601
Merge pull request #362 from rijobro/add-to-optimize-tta
Nic-Ma Jan 31, 2022
d702e20
Merge branch 'dev' into optimize-tta
wyli Jan 31, 2022
7a89517
[DLMED] fix flake8
Nic-Ma Feb 1, 2022
304df0f
Merge branch 'dev' into optimize-tta
Nic-Ma Feb 1, 2022
a06d0b5
[DLMED] update according to comments
Nic-Ma Feb 1, 2022
3df5a71
Merge branch 'dev' into optimize-tta
Nic-Ma Feb 1, 2022
c262687
Merge branch 'dev' into optimize-tta
wyli Feb 1, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 50 additions & 68 deletions monai/data/test_time_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,16 @@
import numpy as np
import torch

from monai.config.type_definitions import NdarrayOrTensor
from monai.data.dataloader import DataLoader
from monai.data.dataset import Dataset
from monai.data.utils import list_data_collate, pad_list_data_collate
from monai.data.utils import decollate_batch, pad_list_data_collate
from monai.transforms.compose import Compose
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.inverse_batch_transform import BatchInverseTransform
from monai.transforms.post.dictionary import Invertd
from monai.transforms.transform import Randomizable
from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode
from monai.utils.enums import CommonKeys, PostFix, TraceKeys
from monai.utils.module import optional_import
from monai.utils.type_conversion import convert_data_type
from monai.transforms.utils_pytorch_numpy_unification import mode, stack
from monai.utils import CommonKeys, PostFix, optional_import

if TYPE_CHECKING:
from tqdm import tqdm
Expand Down Expand Up @@ -80,19 +79,24 @@ class TestTimeAugmentation:
For example, to handle key `image`, read/write affine matrices from the
metadata `image_meta_dict` dictionary's `affine` field.
this arg only works when `meta_keys=None`.
return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True` will return the
full data. Dimensions will be same size as when passing a single image through `inferrer_fn`, with a dimension appended
equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
output_device: if converted the inverted data to Tensor, move the inverted results to target device
before `post_func`, default to "cpu".
post_func: post processing for the inverted data, should be a callable function.
return_full_data: normally, metrics are returned (mode, mean, std, vvc). Setting this flag to `True`
will return the full data. Dimensions will be same size as when passing a single image through
`inferrer_fn`, with a dimension appended equal in size to `num_examples` (N), i.e., `[N,C,H,W,[D]]`.
progress: whether to display a progress bar.

Example:
.. code-block:: python

transform = RandAffined(keys, ...)
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
model = UNet(...).to(device)
transform = Compose([RandAffined(keys, ...), ...])
transform.set_random_state(seed=123) # ensure deterministic evaluation

tt_aug = TestTimeAugmentation(
transform, batch_size=5, num_workers=0, inferrer_fn=lambda x: post_trans(model(x)), device=device
transform, batch_size=5, num_workers=0, inferrer_fn=model, device=device
)
mode, mean, std, vvc = tt_aug(test_data)
"""
Expand All @@ -109,6 +113,9 @@ def __init__(
nearest_interp: bool = True,
orig_meta_keys: Optional[str] = None,
meta_key_postfix=DEFAULT_POST_FIX,
to_tensor: bool = True,
output_device: Union[str, torch.device] = "cpu",
post_func: Callable = _identity,
return_full_data: bool = False,
progress: bool = True,
) -> None:
Expand All @@ -118,12 +125,20 @@ def __init__(
self.inferrer_fn = inferrer_fn
self.device = device
self.image_key = image_key
self.orig_key = orig_key
self.nearest_interp = nearest_interp
self.orig_meta_keys = orig_meta_keys
self.meta_key_postfix = meta_key_postfix
self.return_full_data = return_full_data
self.progress = progress
self._pred_key = CommonKeys.PRED
self.inverter = Invertd(
keys=self._pred_key,
transform=transform,
orig_keys=orig_key,
orig_meta_keys=orig_meta_keys,
meta_key_postfix=meta_key_postfix,
nearest_interp=nearest_interp,
to_tensor=to_tensor,
device=output_device,
post_func=post_func,
)

# check that the transform has at least one random component, and that all random transforms are invertible
self._check_transforms()
Expand All @@ -135,8 +150,8 @@ def _check_transforms(self):
invertibles = np.array([isinstance(t, InvertibleTransform) for t in ts])
# check at least 1 random
if sum(randoms) == 0:
raise RuntimeError(
"Requires a `Randomizable` transform or a `Compose` containing at least one `Randomizable` transform."
warnings.warn(
"TTA usually has at least a `Randomizable` transform or `Compose` contains `Randomizable` transforms."
)
# check that whenever randoms is True, invertibles is also true
for r, i in zip(randoms, invertibles):
Expand All @@ -147,18 +162,19 @@ def _check_transforms(self):

def __call__(
self, data: Dict[str, Any], num_examples: int = 10
) -> Union[Tuple[np.ndarray, np.ndarray, np.ndarray, float], np.ndarray]:
) -> Union[Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor, float], NdarrayOrTensor]:
"""
Args:
data: dictionary data to be processed.
num_examples: number of realisations to be processed and results combined.

Returns:
- if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are calculated across
`num_examples` outputs at each voxel. The volume variation coefficient (VVC) is `std/mean` across the whole output,
including `num_examples`. See original paper for clarification.
- if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then concatenating across
the first dimension containing `num_examples`. This allows the user to perform their own analysis if desired.
- if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are
calculated across `num_examples` outputs at each voxel. The volume variation coefficient (VVC)
is `std/mean` across the whole output, including `num_examples`. See original paper for clarification.
- if `return_full_data==True`: data is returned as-is after applying the `inferrer_fn` and then
concatenating across the first dimension containing `num_examples`. This allows the user to perform
their own analysis if desired.
"""
d = dict(data)

Expand All @@ -171,56 +187,22 @@ def __call__(
ds = Dataset(data_in, self.transform)
dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate)

transform_key = InvertibleTransform.trace_key(self.orig_key)

# create inverter
inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate)
Comment thread
wyli marked this conversation as resolved.

outputs: List[np.ndarray] = []
outs: List = []

for batch_data in tqdm(dl) if has_tqdm and self.progress else dl:

batch_images = batch_data[self.image_key].to(self.device)

# do model forward pass
batch_output = self.inferrer_fn(batch_images)
if isinstance(batch_output, torch.Tensor):
batch_output = batch_output.detach().cpu()
if isinstance(batch_output, np.ndarray):
batch_output = torch.Tensor(batch_output)
transform_info = batch_data.get(transform_key, None)
if transform_info is None:
# no invertible transforms, adding dummy info for identity invertible
transform_info = [[TraceKeys.NONE] for _ in range(self.batch_size)]
if self.nearest_interp:
transform_info = convert_inverse_interp_mode(
trans_info=deepcopy(transform_info), mode="nearest", align_corners=None
)
batch_data[self._pred_key] = self.inferrer_fn(batch_data[self.image_key].to(self.device))
outs.extend([self.inverter(i)[self._pred_key] for i in decollate_batch(batch_data)])

# create a dictionary containing the inferred batch and their transforms
inferred_dict = {self.orig_key: batch_output, transform_key: transform_info}
# if meta dict is present, add that too (required for some inverse transforms)
meta_dict_key = self.orig_meta_keys or f"{self.orig_key}_{self.meta_key_postfix}"
if meta_dict_key in batch_data:
inferred_dict[meta_dict_key] = batch_data[meta_dict_key]

# do inverse transformation (allow missing keys as only inverting the orig_key)
with allow_missing_keys_mode(self.transform): # type: ignore
inv_batch = inverter(inferred_dict)

# append
outputs.append(inv_batch[self.orig_key])

# output
output: np.ndarray = np.concatenate(outputs)
output: NdarrayOrTensor = stack(outs, 0)

if self.return_full_data:
return output

# calculate metrics
output_t, *_ = convert_data_type(output, output_type=torch.Tensor, dtype=np.int64)
mode: np.ndarray = np.asarray(torch.mode(output_t, dim=0).values) # type: ignore
mean: np.ndarray = np.mean(output, axis=0) # type: ignore
std: np.ndarray = np.std(output, axis=0) # type: ignore
vvc: float = (np.std(output) / np.mean(output)).item()
return mode, mean, std, vvc
_mode = mode(output, dim=0)
mean = output.mean(0)
std = output.std(0)
vvc = (output.std() / output.mean()).item()

return _mode, mean, std, vvc
2 changes: 2 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,11 +562,13 @@
isfinite,
isnan,
maximum,
mode,
moveaxis,
nonzero,
percentile,
ravel,
repeat,
stack,
unravel_index,
where,
)
33 changes: 32 additions & 1 deletion monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from monai.config.type_definitions import NdarrayOrTensor
from monai.utils.misc import ensure_tuple, is_module_ver_at_least
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

__all__ = [
"moveaxis",
Expand All @@ -37,6 +38,8 @@
"repeat",
"isnan",
"ascontiguousarray",
"stack",
"mode",
]


Expand Down Expand Up @@ -355,7 +358,7 @@ def isnan(x: NdarrayOrTensor) -> NdarrayOrTensor:

Args:
x: array/tensor

dim: dimension along which to stack
"""
if isinstance(x, np.ndarray):
return np.isnan(x)
Expand All @@ -378,3 +381,31 @@ def ascontiguousarray(x: NdarrayOrTensor, **kwargs) -> NdarrayOrTensor:
if isinstance(x, torch.Tensor):
return x.contiguous(**kwargs)
return x


def stack(x: Sequence[NdarrayOrTensor], dim: int) -> NdarrayOrTensor:
    """`np.stack` with equivalent implementation for torch.

    Args:
        x: non-empty sequence of arrays/tensors, all with the same shape.
        dim: dimension along which to perform the stack (referred to as `axis` by numpy).

    Returns:
        the stacked result, of the same container type as the sequence elements.
    """
    # Dispatch on the first element only when it is a torch.Tensor; everything
    # else (ndarrays, plain lists, and the empty sequence) is handed to numpy,
    # which raises a clear ValueError for empty input instead of the
    # IndexError a bare `x[0]` lookup would produce here.
    if x and isinstance(x[0], torch.Tensor):
        return torch.stack(x, dim)  # type: ignore
    return np.stack(x, dim)  # type: ignore


def mode(x: NdarrayOrTensor, dim: int = -1, to_long: bool = True) -> NdarrayOrTensor:
    """`torch.mode` with equivalent implementation for numpy.

    Args:
        x: array/tensor
        dim: dimension along which to perform `mode` (referred to as `axis` by numpy)
        to_long: convert input to long before performing mode.
    """
    # compute in torch regardless of the input container, optionally as int64
    target_dtype = torch.int64 if to_long else None
    as_tensor: torch.Tensor
    as_tensor, *_ = convert_data_type(x, torch.Tensor, dtype=target_dtype)  # type: ignore
    most_frequent = torch.mode(as_tensor, dim).values
    # return the result in the same container type as the input
    out, *_ = convert_to_dst_type(most_frequent, x)
    return out
24 changes: 16 additions & 8 deletions tests/test_testtimeaugmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,12 @@ def tearDown(self) -> None:

def test_test_time_augmentation(self):
input_size = (20, 20)
device = "cuda" if torch.cuda.is_available() else "cpu"
keys = ["image", "label"]
num_training_ims = 10

train_data = self.get_data(num_training_ims, input_size)
test_data = self.get_data(1, input_size)
device = "cuda" if torch.cuda.is_available() else "cpu"

transforms = Compose(
[
Expand Down Expand Up @@ -125,21 +126,28 @@ def test_test_time_augmentation(self):

post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

def inferrer_fn(x):
return post_trans(model(x))

tt_aug = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=inferrer_fn, device=device)
tt_aug = TestTimeAugmentation(
transform=transforms,
batch_size=5,
num_workers=0,
inferrer_fn=model,
device=device,
to_tensor=True,
output_device="cpu",
post_func=post_trans,
)
mode, mean, std, vvc = tt_aug(test_data)
self.assertEqual(mode.shape, (1,) + input_size)
self.assertEqual(mean.shape, (1,) + input_size)
self.assertTrue(all(np.unique(mode) == (0, 1)))
self.assertEqual((mean.min(), mean.max()), (0.0, 1.0))
self.assertGreaterEqual(mean.min(), 0.0)
self.assertLessEqual(mean.max(), 1.0)
self.assertEqual(std.shape, (1,) + input_size)
self.assertIsInstance(vvc, float)

def test_fail_non_random(self):
def test_warn_non_random(self):
transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)])
with self.assertRaises(RuntimeError):
with self.assertWarns(UserWarning):
TestTimeAugmentation(transforms, None, None, None)

def test_warn_random_but_has_no_invertible(self):
Expand Down
14 changes: 13 additions & 1 deletion tests/test_utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,18 @@

import numpy as np
import torch
from parameterized import parameterized

from monai.transforms.utils_pytorch_numpy_unification import percentile
from monai.transforms.utils_pytorch_numpy_unification import mode, percentile
from monai.utils import set_determinism
from tests.utils import TEST_NDARRAYS, SkipIfBeforePyTorchVersion, assert_allclose

TEST_MODE = []
for p in TEST_NDARRAYS:
TEST_MODE.append([p(np.array([1, 2, 3, 4, 4, 5])), p(4), False])
TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4.1), False])
TEST_MODE.append([p(np.array([3.1, 4.1, 4.1, 5.1])), p(4), True])


class TestPytorchNumpyUnification(unittest.TestCase):
def setUp(self) -> None:
Expand Down Expand Up @@ -54,6 +61,11 @@ def test_dim(self):
atol = 0.5 if not hasattr(torch, "quantile") else 1e-4
assert_allclose(results[0], results[-1], type_test=False, atol=atol)

@parameterized.expand(TEST_MODE)
def test_mode(self, array, expected, to_long):
    # Check `mode` returns the most frequent value for each backend/case in
    # TEST_MODE; with to_long=True the float input is cast to int64 first,
    # so the expected value is the integer-truncated mode.
    res = mode(array, to_long=to_long)
    assert_allclose(res, expected)


if __name__ == "__main__":
unittest.main()