Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def weighted_patch_samples(
idx = r_state.randint(0, len(v), size=n_samples)
else:
r, *_ = convert_to_dst_type(r_state.random(n_samples), v) # type: ignore
idx = searchsorted(v, r * v[-1], right=True)
idx = searchsorted(v, r * v[-1], right=True) # type: ignore
idx, *_ = convert_to_dst_type(idx, v, dtype=torch.int) # type: ignore
# compensate 'valid' mode
diff = np.minimum(win_size, img_size) // 2
Expand All @@ -411,7 +411,7 @@ def weighted_patch_samples(


def correct_crop_centers(
centers: List[Union[int, torch.Tensor]],
centers: Union[NdarrayOrTensor, List[NdarrayOrTensor]],
Comment thread
Nic-Ma marked this conversation as resolved.
spatial_size: Union[Sequence[int], int],
label_spatial_shape: Sequence[int],
allow_smaller: bool = False,
Expand Down
61 changes: 31 additions & 30 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import torch

from monai.config.type_definitions import NdarrayOrTensor
from monai.utils.misc import is_module_ver_at_least
from monai.utils.misc import ensure_tuple, is_module_ver_at_least

__all__ = [
"moveaxis",
Expand All @@ -40,28 +40,29 @@
]


def moveaxis(x: NdarrayOrTensor, src: int, dst: int) -> NdarrayOrTensor:
"""`moveaxis` for pytorch and numpy, using `permute` for pytorch ver < 1.8"""
def moveaxis(x: NdarrayOrTensor, src: Union[int, Sequence[int]], dst: Union[int, Sequence[int]]) -> NdarrayOrTensor:
"""`moveaxis` for pytorch and numpy, using `permute` for pytorch version < 1.7"""
if isinstance(x, torch.Tensor):
if hasattr(torch, "moveaxis"): # `moveaxis` is new in torch 1.8.0
return torch.moveaxis(x, src, dst)
if hasattr(torch, "movedim"): # `movedim` is new in torch 1.7.0
# torch.moveaxis is a recent alias since torch 1.8.0
return torch.movedim(x, src, dst) # type: ignore
return _moveaxis_with_permute(x, src, dst) # type: ignore
if isinstance(x, np.ndarray):
return np.moveaxis(x, src, dst)
raise RuntimeError()
return np.moveaxis(x, src, dst)


def _moveaxis_with_permute(x, src, dst):
def _moveaxis_with_permute(
x: torch.Tensor, src: Union[int, Sequence[int]], dst: Union[int, Sequence[int]]
) -> torch.Tensor:
# get original indices
indices = list(range(x.ndim))
# make src and dst positive
if src < 0:
src = len(indices) + src
if dst < 0:
dst = len(indices) + dst
# remove desired index and insert it in new position
indices.pop(src)
indices.insert(dst, src)
len_indices = len(indices)
for s, d in zip(ensure_tuple(src), ensure_tuple(dst)):
# make src and dst positive
# remove desired index and insert it in new position
pos_s = len_indices + s if s < 0 else s
pos_d = len_indices + d if d < 0 else d
indices.pop(pos_s)
indices.insert(pos_d, pos_s)
return x.permute(indices)


Expand Down Expand Up @@ -151,7 +152,7 @@ def where(condition: NdarrayOrTensor, x=None, y=None) -> NdarrayOrTensor:
return result


def nonzero(x: NdarrayOrTensor):
def nonzero(x: NdarrayOrTensor) -> NdarrayOrTensor:
"""`np.nonzero` with equivalent implementation for torch.

Args:
Expand Down Expand Up @@ -185,7 +186,7 @@ def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor:
return np.floor_divide(a, b)


def unravel_index(idx, shape):
def unravel_index(idx, shape) -> NdarrayOrTensor:
"""`np.unravel_index` with equivalent implementation for torch.

Args:
Expand All @@ -204,7 +205,7 @@ def unravel_index(idx, shape):
return np.asarray(np.unravel_index(idx, shape))


def unravel_indices(idx, shape):
def unravel_indices(idx, shape) -> NdarrayOrTensor:
"""Computing unravel coordinates from indices.

Args:
Expand All @@ -215,10 +216,10 @@ def unravel_indices(idx, shape):
Stacked indices unravelled for given shape
"""
lib_stack = torch.stack if isinstance(idx[0], torch.Tensor) else np.stack
return lib_stack([unravel_index(i, shape) for i in idx])
return lib_stack([unravel_index(i, shape) for i in idx]) # type: ignore


def ravel(x: NdarrayOrTensor):
def ravel(x: NdarrayOrTensor) -> NdarrayOrTensor:
"""`np.ravel` with equivalent implementation for torch.

Args:
Expand All @@ -234,7 +235,7 @@ def ravel(x: NdarrayOrTensor):
return np.ravel(x)


def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]):
def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]) -> NdarrayOrTensor:
"""`np.any` with equivalent implementation for torch.

For pytorch, convert to boolean for compatibility with older versions.
Expand All @@ -247,7 +248,7 @@ def any_np_pt(x: NdarrayOrTensor, axis: Union[int, Sequence[int]]):
Return a contiguous flattened array/tensor.
"""
if isinstance(x, np.ndarray):
return np.any(x, axis)
return np.any(x, axis) # type: ignore

# pytorch can't handle multiple dimensions to `any` so loop across them
axis = [axis] if not isinstance(axis, Sequence) else axis
Expand Down Expand Up @@ -287,7 +288,7 @@ def concatenate(to_cat: Sequence[NdarrayOrTensor], axis: int = 0, out=None) -> N
return torch.cat(to_cat, dim=axis, out=out) # type: ignore


def cumsum(a: NdarrayOrTensor, axis=None, **kwargs):
def cumsum(a: NdarrayOrTensor, axis=None, **kwargs) -> NdarrayOrTensor:
"""
`np.cumsum` with equivalent implementation for torch.

Expand All @@ -306,14 +307,14 @@ def cumsum(a: NdarrayOrTensor, axis=None, **kwargs):
return torch.cumsum(a, dim=axis, **kwargs)


def isfinite(x):
def isfinite(x: NdarrayOrTensor) -> NdarrayOrTensor:
"""`np.isfinite` with equivalent implementation for torch."""
if not isinstance(x, torch.Tensor):
return np.isfinite(x)
return torch.isfinite(x)


def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None, **kwargs):
def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None, **kwargs) -> NdarrayOrTensor:
"""
`np.searchsorted` with equivalent implementation for torch.

Expand All @@ -332,7 +333,7 @@ def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=Non
return torch.searchsorted(a, v, right=right, **kwargs) # type: ignore


def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None, **kwargs):
def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None, **kwargs) -> NdarrayOrTensor:
"""
`np.repeat` with equivalent implementation for torch (`repeat_interleave`).

Expand All @@ -349,7 +350,7 @@ def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None, **kwarg
return torch.repeat_interleave(a, repeats, dim=axis, **kwargs)


def isnan(x: NdarrayOrTensor):
def isnan(x: NdarrayOrTensor) -> NdarrayOrTensor:
"""`np.isnan` with equivalent implementation for torch.

Args:
Expand All @@ -361,7 +362,7 @@ def isnan(x: NdarrayOrTensor):
return torch.isnan(x)


def ascontiguousarray(x: NdarrayOrTensor, **kwargs):
def ascontiguousarray(x: NdarrayOrTensor, **kwargs) -> NdarrayOrTensor:
"""`np.ascontiguousarray` with equivalent implementation for torch (`contiguous`).

Args:
Expand Down