From d912568a12cdb493d6b6eda0abd4011b6f64151b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Thu, 7 May 2020 23:29:29 +0800 Subject: [PATCH] [DLMED] fix all the data type issues --- monai/transforms/intensity/array.py | 21 +++++++++------------ monai/transforms/intensity/dictionary.py | 17 +++++++---------- monai/transforms/spatial/array.py | 10 +++++----- 3 files changed, 21 insertions(+), 27 deletions(-) diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index be24505ac3..9156e9b764 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -41,7 +41,7 @@ def randomize(self, im_shape): def __call__(self, img): self.randomize(img.shape) - return img + self._noise if self._do_transform else img + return img + self._noise.astype(img.dtype) if self._do_transform else img class ShiftIntensity(Transform): @@ -55,7 +55,7 @@ def __init__(self, offset): self.offset = offset def __call__(self, img): - return img + self.offset + return (img + self.offset).astype(img.dtype) class RandShiftIntensity(Randomizable, Transform): @@ -92,24 +92,22 @@ class ScaleIntensity(Transform): If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``. """ - def __init__(self, minv=0.0, maxv=1.0, factor=None, dtype=np.float32): + def __init__(self, minv=0.0, maxv=1.0, factor=None): """ Args: minv (int or float): minimum value of output data. maxv (int or float): maximum value of output data. factor (float): factor scale by ``v = v * (1 + factor)``. - dtype (np.dtype): expected output data type. """ self.minv = minv self.maxv = maxv self.factor = factor - self.dtype = dtype def __call__(self, img): if self.minv is not None and self.maxv is not None: - return rescale_array(img, self.minv, self.maxv, self.dtype) + return rescale_array(img, self.minv, self.maxv, img.dtype) else: - return (img * (1 + self.factor)).astype(self.dtype) + return (img * (1 + self.factor)).astype(img.dtype) class RandScaleIntensity(Randomizable, Transform): @@ -118,18 +116,17 @@ class RandScaleIntensity(Randomizable, Transform): is randomly picked from (factors[0], factors[0]). """ - def __init__(self, factors, prob=0.1, dtype=np.float32): + def __init__(self, factors, prob=0.1): """ Args: factors(float, tuple or list): factor range to randomly scale by ``v = v * (1 + factor)``. if single number, factor value is picked from (-factors, factors). prob (float): probability of scale. - dtype (np.dtype): expected output data type. + """ self.factors = (-factors, factors) if not isinstance(factors, (list, tuple)) else factors assert len(self.factors) == 2, "factors should be a number or pair of numbers." self.prob = prob - self.dtype = dtype self._do_transform = False def randomize(self): @@ -140,7 +137,7 @@ def __call__(self, img): self.randomize() if not self._do_transform: return img - scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype) + scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor) return scaler(img) @@ -205,7 +202,7 @@ def __init__(self, threshold, above=True, cval=0): self.cval = cval def __call__(self, img): - return np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval) + return np.where(img > self.threshold if self.above else img < self.threshold, img, self.cval).astype(img.dtype) class ScaleIntensityRange(Transform): diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 2235110563..7813ad500d 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -15,8 +15,6 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -import numpy as np - from monai.transforms.compose import MapTransform, Randomizable from monai.transforms.intensity.array import ( NormalizeIntensity, @@ -60,7 +58,7 @@ def __call__(self, data): if not self._do_transform: return d for key in self.keys: - d[key] = d[key] + self._noise + d[key] = d[key] + self._noise.astype(d[key].dtype) return d @@ -129,7 +127,7 @@ class ScaleIntensityd(MapTransform): If `minv` and `maxv` not provided, use `factor` to scale image by ``v = v * (1 + factor)``. """ - def __init__(self, keys, minv=0.0, maxv=1.0, factor=None, dtype=np.float32): + def __init__(self, keys, minv=0.0, maxv=1.0, factor=None): """ Args: keys (hashable items): keys of the corresponding items to be transformed. @@ -137,10 +135,10 @@ def __init__(self, keys, minv=0.0, maxv=1.0, factor=None, dtype=np.float32): minv (int or float): minimum value of output data. maxv (int or float): maximum value of output data. factor (float): factor scale by ``v = v * (1 + factor)``. - dtype (np.dtype): expected output data type. + """ super().__init__(keys) - self.scaler = ScaleIntensity(minv, maxv, factor, dtype) + self.scaler = ScaleIntensity(minv, maxv, factor) def __call__(self, data): d = dict(data) @@ -154,7 +152,7 @@ class RandScaleIntensityd(Randomizable, MapTransform): dictionary-based version :py:class:`monai.transforms.RandScaleIntensity`. """ - def __init__(self, keys, factors, prob=0.1, dtype=np.float32): + def __init__(self, keys, factors, prob=0.1): """ Args: keys (hashable items): keys of the corresponding items to be transformed. @@ -163,13 +161,12 @@ def __init__(self, keys, factors, prob=0.1, dtype=np.float32): if single number, factor value is picked from (-factors, factors). prob (float): probability of rotating. (Default 0.1, with 10% probability it returns a rotated array.) - dtype (np.dtype): expected output data type. + """ super().__init__(keys) self.factors = (-factors, factors) if not isinstance(factors, (list, tuple)) else factors assert len(self.factors) == 2, "factors should be a number or pair of numbers." self.prob = prob - self.dtype = dtype self._do_transform = False def randomize(self): @@ -181,7 +178,7 @@ def __call__(self, data): self.randomize() if not self._do_transform: return d - scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor, dtype=self.dtype) + scaler = ScaleIntensity(minv=None, maxv=None, factor=self.factor) for key in self.keys: d[key] = scaler(d[key]) return d diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index b35bf519ba..93a9733fc9 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -202,7 +202,7 @@ def __call__(self, img): flipped = list() for channel in img: flipped.append(np.flip(channel, self.spatial_axis)) - return np.stack(flipped) + return np.stack(flipped).astype(img.dtype) class Resize(Transform): @@ -266,7 +266,7 @@ def __call__(self, img): anti_aliasing_sigma=self.anti_aliasing_sigma, ) ) - return np.stack(resized).astype(np.float32) + return np.stack(resized).astype(img.dtype) class Rotate(Transform): @@ -316,7 +316,7 @@ def __call__(self, img): prefilter=self.prefilter, ) ) - return np.stack(rotated).astype(np.float32) + return np.stack(rotated).astype(img.dtype) class Zoom(Transform): @@ -383,7 +383,7 @@ def __call__(self, img): prefilter=self.prefilter, ) ) - zoomed = np.stack(zoomed).astype(np.float32) + zoomed = np.stack(zoomed).astype(img.dtype) if not self.keep_size or np.allclose(img.shape, zoomed.shape): return zoomed @@ -424,7 +424,7 @@ def __call__(self, img): rotated = list() for channel in img: rotated.append(np.rot90(channel, self.k, self.spatial_axes)) - return np.stack(rotated) + return np.stack(rotated).astype(img.dtype) class RandRotate90(Randomizable, Transform):