Skip to content
Merged
4 changes: 2 additions & 2 deletions monai/config/type_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@
# container must be iterable.
IndexSelection = Union[Iterable[int], int]

#: Type of datatypes: Adapted from https://github.com/numpy/numpy/blob/master/numpy/typing/_dtype_like.py
DtypeLike = Union[np.dtype, type, None]
#: Type of datatypes: Adapted from https://github.com/numpy/numpy/blob/v1.21.4/numpy/typing/_dtype_like.py#L121
DtypeLike = Union[np.dtype, type, str, None]

#: NdarrayTensor
#
Expand Down
4 changes: 3 additions & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,14 +446,16 @@ class ToCupy(Transform):

Args:
dtype: data type specifier. It is inferred from the input by default.
if not None, must be an argument of `numpy.dtype`, for more details:
https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`.

"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, dtype=None, wrap_sequence: bool = True) -> None:
def __init__(self, dtype: Optional[np.dtype] = None, wrap_sequence: bool = True) -> None:
super().__init__()
self.dtype = dtype
self.wrap_sequence = wrap_sequence
Expand Down
8 changes: 7 additions & 1 deletion monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,8 @@ class ToCupyd(MapTransform):
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
dtype: data type specifier. It is inferred from the input by default.
if not None, must be an argument of `numpy.dtype`, for more details:
https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html.
wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`.
allow_missing_keys: don't raise exception if key is missing.
Expand All @@ -600,7 +602,11 @@ class ToCupyd(MapTransform):
backend = ToCupy.backend

def __init__(
self, keys: KeysCollection, dtype=None, wrap_sequence: bool = True, allow_missing_keys: bool = False
self,
keys: KeysCollection,
dtype: Optional[np.dtype] = None,
wrap_sequence: bool = True,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.converter = ToCupy(dtype=dtype, wrap_sequence=wrap_sequence)
Expand Down
22 changes: 11 additions & 11 deletions monai/utils/type_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def dtype_torch_to_numpy(dtype):
def dtype_numpy_to_torch(dtype):
"""Convert a numpy dtype to its torch equivalent."""
# np dtypes can be given as np.float32 and np.dtype(np.float32) so unify them
dtype = np.dtype(dtype) if isinstance(dtype, type) else dtype
dtype = np.dtype(dtype) if isinstance(dtype, (type, str)) else dtype
return look_up_option(dtype, _np_to_torch_dtype)


Expand Down Expand Up @@ -151,8 +151,8 @@ def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False)
will convert Tensor, Numpy array, float, int, bool to numpy arrays, strings and objects keep the original.
for dictionary, list or tuple, convert every item to a numpy array if applicable.
dtype: target data type when converting to numpy array.
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
If `True`, then `[1, 2]` -> `array([1, 2])`.
wrap_sequence: if `False`, then lists will recursively call this function.
E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
"""
if isinstance(data, torch.Tensor):
data = data.detach().to(dtype=get_equivalent_dtype(dtype, torch.Tensor), device="cpu").numpy()
Expand All @@ -175,19 +175,19 @@ def convert_to_numpy(data, dtype: DtypeLike = None, wrap_sequence: bool = False)
return data


def convert_to_cupy(data, dtype, wrap_sequence: bool = False):
def convert_to_cupy(data, dtype: Optional[np.dtype] = None, wrap_sequence: bool = False):
"""
Utility to convert the input data to a cupy array. If passing a dictionary, list or tuple,
recursively check every item and convert it to cupy array.

Args:
data: input data can be PyTorch Tensor, numpy array, cupy array, list, dictionary, int, float, bool, str, etc.
Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays

Tensor, numpy array, cupy array, float, int, bool are converted to cupy arrays,
for dictionary, list or tuple, convert every item to a numpy array if applicable.
dtype: target data type when converting to Cupy array.
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
If `True`, then `[1, 2]` -> `array([1, 2])`.
dtype: target data type when converting to Cupy array, tt must be an argument of `numpy.dtype`,
for more details: https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html.
wrap_sequence: if `False`, then lists will recursively call this function.
E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
"""

# direct calls
Expand Down Expand Up @@ -227,8 +227,8 @@ def convert_data_type(
dtype: dtype of output data. Converted to correct library type (e.g.,
`np.float32` is converted to `torch.float32` if output type is `torch.Tensor`).
If left blank, it remains unchanged.
wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[array(1), array(2)]`.
If `True`, then `[1, 2]` -> `array([1, 2])`.
wrap_sequence: if `False`, then lists will recursively call this function.
E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
Returns:
modified data, orig_type, orig_device

Expand Down
6 changes: 3 additions & 3 deletions tests/test_normalize_intensity.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,23 @@
TESTS.append(
[
p,
{"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3]},
{"nonzero": False, "channel_wise": True, "subtrahend": [1, 2, 3], "dtype": np.float32},
p(np.ones((3, 2, 2))),
p(np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-2.0, -2.0], [-2.0, -2.0]]])),
]
)
TESTS.append(
[
p,
{"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2]},
{"nonzero": True, "channel_wise": True, "subtrahend": [1, 2, 3], "divisor": [0, 0, 2], "dtype": "float32"},
p(np.ones((3, 2, 2))),
p(np.array([[[0.0, 0.0], [0.0, 0.0]], [[-1.0, -1.0], [-1.0, -1.0]], [[-1.0, -1.0], [-1.0, -1.0]]])),
]
)
TESTS.append(
[
p,
{"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0},
{"nonzero": True, "channel_wise": False, "subtrahend": 2, "divisor": 0, "dtype": torch.float32},
p(np.ones((3, 2, 2))),
p(np.ones((3, 2, 2)) * -1.0),
]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_tensor_input(self):
test_data = torch.tensor([[1, 2], [3, 4]])
test_data = test_data.rot90()
self.assertFalse(test_data.is_contiguous())
result = ToNumpy()(test_data)
result = ToNumpy(dtype=torch.uint8)(test_data)
self.assertTrue(isinstance(result, np.ndarray))
self.assertTrue(result.flags["C_CONTIGUOUS"])
assert_allclose(result, test_data, type_test=False)
Expand All @@ -73,7 +73,7 @@ def test_list_tuple(self):

def test_single_value(self):
for test_data in [5, np.array(5), torch.tensor(5)]:
result = ToNumpy()(test_data)
result = ToNumpy(dtype=np.uint8)(test_data)
self.assertTrue(isinstance(result, np.ndarray))
assert_allclose(result, np.asarray(test_data), type_test=False)
self.assertEqual(result.ndim, 0)
Expand Down