Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 50 additions & 13 deletions monai/transforms/utils_pytorch_numpy_unification.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def clip(a: NdarrayOrTensor, a_min, a_max) -> NdarrayOrTensor:
return result


def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[NdarrayOrTensor, float, int]:
def percentile(
x: NdarrayOrTensor, q, dim: Optional[int] = None, keepdim: bool = False, **kwargs
) -> Union[NdarrayOrTensor, float, int]:
"""`np.percentile` with equivalent implementation for torch.

Pytorch uses `quantile`, but this functionality is only available from v1.7.
Expand All @@ -97,6 +99,9 @@ def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[Ndarra
q: percentile to compute (should in range 0 <= q <= 100)
dim: the dim along which the percentiles are computed. default is to compute the percentile
along a flattened version of the array. only work for numpy array or Tensor with PyTorch >= 1.7.0.
keepdim: whether the output data has dim retained or not.
kwargs: if `x` is numpy array, additional args for `np.percentile`, more details:
https://numpy.org/doc/stable/reference/generated/numpy.percentile.html.

Returns:
Resulting value (scalar)
Expand All @@ -108,11 +113,11 @@ def percentile(x: NdarrayOrTensor, q, dim: Optional[int] = None) -> Union[Ndarra
raise ValueError
result: Union[NdarrayOrTensor, float, int]
if isinstance(x, np.ndarray):
result = np.percentile(x, q, axis=dim)
result = np.percentile(x, q, axis=dim, keepdims=keepdim, **kwargs)
else:
q = torch.tensor(q, device=x.device)
if hasattr(torch, "quantile"): # `quantile` is new in torch 1.7.0
result = torch.quantile(x, q / 100.0, dim=dim)
result = torch.quantile(x, q / 100.0, dim=dim, keepdim=keepdim)
else:
# Note that ``kthvalue()`` works one-based, i.e., the first sorted value
# corresponds to k=1, not k=0. Thus, we need the `1 +`.
Expand Down Expand Up @@ -282,13 +287,23 @@ def concatenate(to_cat: Sequence[NdarrayOrTensor], axis: int = 0, out=None) -> N
return torch.cat(to_cat, dim=axis, out=out) # type: ignore


def cumsum(a: NdarrayOrTensor, axis=None):
"""`np.cumsum` with equivalent implementation for torch."""
def cumsum(a: NdarrayOrTensor, axis=None, **kwargs):
"""
`np.cumsum` with equivalent implementation for torch.

Args:
a: input data to compute cumsum.
axis: expected axis to compute cumsum.
kwargs: if `a` is PyTorch Tensor, additional args for `torch.cumsum`, more details:
https://pytorch.org/docs/stable/generated/torch.cumsum.html.

"""

if isinstance(a, np.ndarray):
return np.cumsum(a, axis)
if axis is None:
return torch.cumsum(a[:], 0)
return torch.cumsum(a, dim=axis)
return torch.cumsum(a[:], 0, **kwargs)
return torch.cumsum(a, dim=axis, **kwargs)


def isfinite(x):
Expand All @@ -298,18 +313,40 @@ def isfinite(x):
return torch.isfinite(x)


def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None):
def searchsorted(a: NdarrayOrTensor, v: NdarrayOrTensor, right=False, sorter=None, **kwargs):
"""
`np.searchsorted` with equivalent implementation for torch.

Args:
a: numpy array or tensor, containing monotonically increasing sequence on the innermost dimension.
v: containing the search values.
right: if False, return the first suitable location that is found, if True, return the last such index.
sorter: if `a` is numpy array, optional array of integer indices that sort array `a` into ascending order.
kwargs: if `a` is PyTorch Tensor, additional args for `torch.searchsorted`, more details:
https://pytorch.org/docs/stable/generated/torch.searchsorted.html.

"""
side = "right" if right else "left"
if isinstance(a, np.ndarray):
return np.searchsorted(a, v, side, sorter) # type: ignore
return torch.searchsorted(a, v, right=right) # type: ignore
return torch.searchsorted(a, v, right=right, **kwargs) # type: ignore


def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None, **kwargs):
"""
`np.repeat` with equivalent implementation for torch (`repeat_interleave`).

Args:
a: input data to repeat.
repeats: number of repetitions for each element, repeats is broadcasted to fit the shape of the given axis.
axis: axis along which to repeat values.
kwargs: if `a` is PyTorch Tensor, additional args for `torch.repeat_interleave`, more details:
https://pytorch.org/docs/stable/generated/torch.repeat_interleave.html.

def repeat(a: NdarrayOrTensor, repeats: int, axis: Optional[int] = None):
"""`np.repeat` with equivalent implementation for torch (`repeat_interleave`)."""
"""
if isinstance(a, np.ndarray):
return np.repeat(a, repeats, axis)
return torch.repeat_interleave(a, repeats, dim=axis)
return torch.repeat_interleave(a, repeats, dim=axis, **kwargs)


def isnan(x: NdarrayOrTensor):
Expand All @@ -330,7 +367,7 @@ def ascontiguousarray(x: NdarrayOrTensor, **kwargs):
Args:
x: array/tensor
kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous.
https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html.

"""
if isinstance(x, np.ndarray):
Expand Down