Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 46 additions & 24 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ class StdShiftIntensity(Transform):
nonzero: whether only count non-zero values.
channel_wise: if True, calculate on each channel separately. Please ensure
that the first dimension represents the channel of the image if True.
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
Expand Down Expand Up @@ -323,7 +323,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
img, *_ = convert_data_type(img, dtype=self.dtype)
if self.dtype is not None:
img, *_ = convert_data_type(img, dtype=self.dtype)
if self.channel_wise:
for i, d in enumerate(img):
img[i] = self._stdshift(d) # type: ignore
Expand Down Expand Up @@ -355,7 +356,7 @@ def __init__(
prob: probability of std shift.
nonzero: whether only count non-zero values.
channel_wise: if True, calculate on each channel separately.
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.

"""
RandomizableTransform.__init__(self, prob)
Expand Down Expand Up @@ -416,7 +417,7 @@ def __init__(
this parameter, please set `minv` and `maxv` into None.
channel_wise: if True, scale on each channel separately. Please ensure
that the first dimension represents the channel of the image if True.
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.
"""
self.minv = minv
self.maxv = maxv
Expand All @@ -439,7 +440,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
return rescale_array(img, self.minv, self.maxv, dtype=self.dtype)
if self.factor is not None:
ret = img * (1 + self.factor)
ret, *_ = convert_data_type(ret, dtype=self.dtype)
if self.dtype is not None:
ret, *_ = convert_data_type(ret, dtype=self.dtype or img.dtype)
return ret
raise ValueError("Incompatible values: minv=None or maxv=None and factor=None.")

Expand All @@ -460,7 +462,7 @@ def __init__(
factors: factor range to randomly scale by ``v = v * (1 + factor)``.
if single number, factor value is picked from (-factors, factors).
prob: probability of scale.
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.

"""
RandomizableTransform.__init__(self, prob)
Expand Down Expand Up @@ -507,7 +509,7 @@ class RandBiasField(RandomizableTransform):
degree: degree of freedom of the polynomials. The value should be no less than 1.
Defaults to 3.
coeff_range: range of the random coefficients. Defaults to (0.0, 0.1).
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.
prob: probability to do random bias field.

"""
Expand Down Expand Up @@ -580,7 +582,7 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen
)
img_np, *_ = convert_data_type(img, np.ndarray)
out = img_np * np.exp(_bias_fields)
out, *_ = convert_to_dst_type(src=out, dst=img, dtype=self.dtype)
out, *_ = convert_to_dst_type(src=out, dst=img, dtype=self.dtype or img.dtype)
return out


Expand All @@ -598,7 +600,7 @@ class NormalizeIntensity(Transform):
nonzero: whether only normalize non-zero values.
channel_wise: if using calculated mean and std, calculate on each channel separately
or calculate on the entire image directly.
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
Expand Down Expand Up @@ -665,6 +667,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`, assuming `img` is a channel-first array if `self.channel_wise` is True,
"""
dtype = self.dtype or img.dtype
if self.channel_wise:
if self.subtrahend is not None and len(self.subtrahend) != len(img):
raise ValueError(f"img has {len(img)} channels, but subtrahend has {len(self.subtrahend)} components.")
Expand All @@ -680,7 +683,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
else:
img = self._normalize(img, self.subtrahend, self.divisor)

out, *_ = convert_data_type(img, dtype=self.dtype)
out, *_ = convert_data_type(img, dtype=dtype)
return out


Expand Down Expand Up @@ -725,21 +728,26 @@ class ScaleIntensityRange(Transform):
b_min: intensity target range min.
b_max: intensity target range max.
clip: whether to perform clip after scaling.
dtype: output data type; if None, same as the input image. Defaults to float32.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False) -> None:
def __init__(
self, a_min: float, a_max: float, b_min: float, b_max: float, clip: bool = False, dtype: DtypeLike = np.float32
) -> None:
self.a_min = a_min
self.a_max = a_max
self.b_min = b_min
self.b_max = b_max
self.clip = clip
self.dtype = dtype

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
dtype = self.dtype or img.dtype
if self.a_max - self.a_min == 0.0:
warn("Divide by zero (a_min == a_max)", Warning)
return img - self.a_min + self.b_min
Expand All @@ -748,7 +756,9 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
img = img * (self.b_max - self.b_min) + self.b_min
if self.clip:
img = clip(img, self.b_min, self.b_max)
return img
ret, *_ = convert_data_type(img, dtype=dtype)

return ret


class AdjustContrast(Transform):
Expand Down Expand Up @@ -883,12 +893,20 @@ class ScaleIntensityRangePercentiles(Transform):
b_max: intensity target range max.
clip: whether to perform clip after scaling.
relative: whether to scale to the corresponding percentiles of [b_min, b_max].
dtype: output data type; if None, same as the input image. Defaults to float32.
"""

backend = ScaleIntensityRange.backend

def __init__(
self, lower: float, upper: float, b_min: float, b_max: float, clip: bool = False, relative: bool = False
self,
lower: float,
upper: float,
b_min: float,
b_max: float,
clip: bool = False,
relative: bool = False,
dtype: DtypeLike = np.float32,
) -> None:
if lower < 0.0 or lower > 100.0:
raise ValueError("Percentiles must be in the range [0, 100]")
Expand All @@ -900,6 +918,7 @@ def __init__(
self.b_max = b_max
self.clip = clip
self.relative = relative
self.dtype = dtype

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Expand All @@ -914,7 +933,7 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
b_min = ((self.b_max - self.b_min) * (self.lower / 100.0)) + self.b_min
b_max = ((self.b_max - self.b_min) * (self.upper / 100.0)) + self.b_min

scalar = ScaleIntensityRange(a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=False)
scalar = ScaleIntensityRange(a_min=a_min, a_max=a_max, b_min=b_min, b_max=b_max, clip=False, dtype=self.dtype)
img = scalar(img)

if self.clip:
Expand Down Expand Up @@ -1968,7 +1987,7 @@ class HistogramNormalize(Transform):
mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`.
only points at which `mask==True` are used for the equalization.
can also provide the mask along with img at runtime.
dtype: data type of the output, default to `float32`.
dtype: data type of the output; if None, same as the input image. Defaults to `float32`.

"""

Expand All @@ -1988,12 +2007,15 @@ def __init__(
self.mask = mask
self.dtype = dtype

def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None) -> np.ndarray:
return equalize_hist(
img=img,
mask=mask if mask is not None else self.mask,
num_bins=self.num_bins,
min=self.min,
max=self.max,
dtype=self.dtype,
)
def __call__(self, img: NdarrayOrTensor, mask: Optional[NdarrayOrTensor] = None) -> NdarrayOrTensor:
img_np: np.ndarray
img_np, *_ = convert_data_type(img, np.ndarray) # type: ignore
mask = mask if mask is not None else self.mask
mask_np: Optional[np.ndarray] = None
if mask is not None:
mask_np, *_ = convert_data_type(mask, np.ndarray) # type: ignore

ret = equalize_hist(img=img_np, mask=mask_np, num_bins=self.num_bins, min=self.min, max=self.max)
out, *_ = convert_to_dst_type(src=ret, dst=img, dtype=self.dtype or img.dtype)

return out
22 changes: 13 additions & 9 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def __init__(
nonzero: whether only count non-zero values.
channel_wise: if True, calculate on each channel separately. Please ensure
that the first dimension represents the channel of the image if True.
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
Expand Down Expand Up @@ -451,7 +451,7 @@ def __init__(
prob: probability of std shift.
nonzero: whether only count non-zero values.
channel_wise: if True, calculate on each channel separately.
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.
allow_missing_keys: don't raise exception if key is missing.
"""
MapTransform.__init__(self, keys, allow_missing_keys)
Expand Down Expand Up @@ -509,7 +509,7 @@ def __init__(
this parameter, please set `minv` and `maxv` into None.
channel_wise: if True, scale on each channel separately. Please ensure
that the first dimension represents the channel of the image if True.
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.
allow_missing_keys: don't raise exception if key is missing.

"""
Expand Down Expand Up @@ -546,7 +546,7 @@ def __init__(
if single number, factor value is picked from (-factors, factors).
prob: probability of rotating.
(Default 0.1, with 10% probability it returns a rotated array.)
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.
allow_missing_keys: don't raise exception if key is missing.

"""
Expand Down Expand Up @@ -597,7 +597,7 @@ def __init__(
degree: degree of freedom of the polynomials. The value should be no less than 1.
Defaults to 3.
coeff_range: range of the random coefficients. Defaults to (0.0, 0.1).
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.
prob: probability to do random bias field.
allow_missing_keys: don't raise exception if key is missing.

Expand Down Expand Up @@ -641,7 +641,7 @@ class NormalizeIntensityd(MapTransform):
nonzero: whether only normalize non-zero values.
channel_wise: if using calculated mean and std, calculate on each channel separately
or calculate on the entire image directly.
dtype: output data type, defaults to float32.
dtype: output data type; if None, same as the input image. Defaults to float32.
allow_missing_keys: don't raise exception if key is missing.
"""

Expand Down Expand Up @@ -712,6 +712,7 @@ class ScaleIntensityRanged(MapTransform):
b_min: intensity target range min.
b_max: intensity target range max.
clip: whether to perform clip after scaling.
dtype: output data type; if None, same as the input image. Defaults to float32.
allow_missing_keys: don't raise exception if key is missing.
"""

Expand All @@ -725,10 +726,11 @@ def __init__(
b_min: float,
b_max: float,
clip: bool = False,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip)
self.scaler = ScaleIntensityRange(a_min, a_max, b_min, b_max, clip, dtype)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
Expand Down Expand Up @@ -826,6 +828,7 @@ class ScaleIntensityRangePercentilesd(MapTransform):
b_max: intensity target range max.
clip: whether to perform clip after scaling.
relative: whether to scale to the corresponding percentiles of [b_min, b_max]
dtype: output data type; if None, same as the input image. Defaults to float32.
allow_missing_keys: don't raise exception if key is missing.
"""

Expand All @@ -840,10 +843,11 @@ def __init__(
b_max: float,
clip: bool = False,
relative: bool = False,
dtype: DtypeLike = np.float32,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative)
self.scaler = ScaleIntensityRangePercentiles(lower, upper, b_min, b_max, clip, relative, dtype)

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
Expand Down Expand Up @@ -1549,7 +1553,7 @@ class HistogramNormalized(MapTransform):
only points at which `mask==True` are used for the equalization.
can also provide the mask by `mask_key` at runtime.
mask_key: if mask is None, will try to get the mask with `mask_key`.
dtype: data type of the output, default to `float32`.
dtype: output data type; if None, same as the input image. Defaults to float32.
allow_missing_keys: do not raise exception if key is missing.

"""
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ class ToNumpy(Transform):

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, dtype: Optional[DtypeLike] = None) -> None:
def __init__(self, dtype: DtypeLike = None) -> None:
super().__init__()
self.dtype = dtype

Expand Down
4 changes: 1 addition & 3 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,9 +552,7 @@ class ToNumpyd(MapTransform):

backend = ToNumpy.backend

def __init__(
self, keys: KeysCollection, dtype: Optional[DtypeLike] = None, allow_missing_keys: bool = False
) -> None:
def __init__(self, keys: KeysCollection, dtype: DtypeLike = None, allow_missing_keys: bool = False) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
Expand Down
Loading