Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 9 additions & 12 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def randomize(self, im_shape):

def __call__(self, img):
self.randomize(img.shape)
return img + self._noise if self._do_transform else img
return img + self._noise.astype(img.dtype) if self._do_transform else img


class ShiftIntensity(Transform):
Expand All @@ -55,7 +55,7 @@ def __init__(self, offset):
self.offset = offset

def __call__(self, img):
return img + self.offset
return (img + self.offset).astype(img.dtype)


class RandShiftIntensity(Randomizable, Transform):
Expand Down Expand Up @@ -92,24 +92,22 @@ class ScaleIntensity(Transform):
If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``.
"""

def __init__(self, minv=0.0, maxv=1.0, factor=None, dtype=np.float32):
def __init__(self, minv=0.0, maxv=1.0, factor=None):
"""
Args:
minv (int or float): minimum value of output data.
maxv (int or float): maximum value of output data.
factor (float): factor scale by ``v = v * (1 + factor)``.
dtype (np.dtype): expected output data type.
"""
self.minv = minv
self.maxv = maxv
self.factor = factor
self.dtype = dtype

def __call__(self, img):
if self.minv is not None and self.maxv is not None:
return rescale_array(img, self.minv, self.maxv, self.dtype)
return rescale_array(img, self.minv, self.maxv, img.dtype)
else:
return (img * (1 + self.factor)).astype(self.dtype)
return (img * (1 + self.factor)).astype(img.dtype)


class RandScaleIntensity(Randomizable, Transform):
Expand All @@ -118,18 +116,17 @@ class RandScaleIntensity(Randomizable, Transform):
is randomly picked from (factors[0], factors[0]).
"""

def __init__(self, factors, prob=0.1, dtype=np.float32):
def __init__(self, factors, prob=0.1):
"""
Args:
factors(float, tuple or list): factor range to randomly scale by ``v = v * (1 + factor)``.
if single number, factor value is picked from (-factors, factors).
prob (float): probability of scale.
dtype (np.dtype): expected output data type.

"""
self.factors = (-factors, factors) if not isinstance(factors, (list, tuple)) else factors
assert len(self.factors) == 2, "factors should be a number or pair of numbers."
self.prob = prob
self.dtype = dtype
self._do_transform = False

def randomize(self):
Expand All @@ -140,7 +137,7 @@ def __call__(self, img):
self.randomize()
if not self._do_transform:
return img
scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)
scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor)
return scaler(img)


Expand Down Expand Up @@ -205,7 +202,7 @@ def __init__(self, threshold, above=True, cval=0):
self.cval = cval

def __call__(self, img):
return np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval)
return np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval).astype(img.dtype)


class ScaleIntensityRange(Transform):
Expand Down
17 changes: 7 additions & 10 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
Class names are ended with 'd' to denote dictionary-based transforms.
"""

import numpy as np

from monai.transforms.compose import MapTransform, Randomizable
from monai.transforms.intensity.array import (
NormalizeIntensity,
Expand Down Expand Up @@ -60,7 +58,7 @@ def __call__(self, data):
if not self._do_transform:
return d
for key in self.keys:
d[key] = d[key] + self._noise
d[key] = d[key] + self._noise.astype(d[key].dtype)
return d


Expand Down Expand Up @@ -129,18 +127,18 @@ class ScaleIntensityd(MapTransform):
If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``.
"""

def __init__(self, keys, minv=0.0, maxv=1.0, factor=None, dtype=np.float32):
def __init__(self, keys, minv=0.0, maxv=1.0, factor=None):
"""
Args:
keys (hashable items): keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
minv (int or float): minimum value of output data.
maxv (int or float): maximum value of output data.
factor (float): factor scale by ``v = v * (1 + factor)``.
dtype (np.dtype): expected output data type.

"""
super().__init__(keys)
self.scaler = ScaleIntensity(minv, maxv, factor, dtype)
self.scaler = ScaleIntensity(minv, maxv, factor)

def __call__(self, data):
d = dict(data)
Expand All @@ -154,7 +152,7 @@ class RandScaleIntensityd(Randomizable, MapTransform):
dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`.
"""

def __init__(self, keys, factors, prob=0.1, dtype=np.float32):
def __init__(self, keys, factors, prob=0.1):
"""
Args:
keys (hashable items): keys of the corresponding items to be transformed.
Expand All @@ -163,13 +161,12 @@ def __init__(self, keys, factors, prob=0.1, dtype=np.float32):
if single number, factor value is picked from (-factors, factors).
prob (float): probability of rotating.
(Default 0.1, with 10% probability it returns a rotated array.)
dtype (np.dtype): expected output data type.

"""
super().__init__(keys)
self.factors = (-factors, factors) if not isinstance(factors, (list, tuple)) else factors
assert len(self.factors) == 2, "factors should be a number or pair of numbers."
self.prob = prob
self.dtype = dtype
self._do_transform = False

def randomize(self):
Expand All @@ -181,7 +178,7 @@ def __call__(self, data):
self.randomize()
if not self._do_transform:
return d
scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype)
scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor)
for key in self.keys:
d[key] = scaler(d[key])
return d
Expand Down
10 changes: 5 additions & 5 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def __call__(self, img):
flipped = list()
for channel in img:
flipped.append(np.flip(channel, self.spatial_axis))
return np.stack(flipped)
return np.stack(flipped).astype(img.dtype)


class Resize(Transform):
Expand Down Expand Up @@ -266,7 +266,7 @@ def __call__(self, img):
anti_aliasing_sigma=self.anti_aliasing_sigma,
)
)
return np.stack(resized).astype(np.float32)
return np.stack(resized).astype(img.dtype)


class Rotate(Transform):
Expand Down Expand Up @@ -316,7 +316,7 @@ def __call__(self, img):
prefilter=self.prefilter,
)
)
return np.stack(rotated).astype(np.float32)
return np.stack(rotated).astype(img.dtype)


class Zoom(Transform):
Expand Down Expand Up @@ -383,7 +383,7 @@ def __call__(self, img):
prefilter=self.prefilter,
)
)
zoomed = np.stack(zoomed).astype(np.float32)
zoomed = np.stack(zoomed).astype(img.dtype)

if not self.keep_size or np.allclose(img.shape, zoomed.shape):
return zoomed
Expand Down Expand Up @@ -424,7 +424,7 @@ def __call__(self, img):
rotated = list()
for channel in img:
rotated.append(np.rot90(channel, self.k, self.spatial_axes))
return np.stack(rotated)
return np.stack(rotated).astype(img.dtype)


class RandRotate90(Randomizable, Transform):
Expand Down