Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 149 additions & 36 deletions monai/inferers/inferer.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,7 @@
ApplyTransformToPoints,
AsChannelLast,
CastToType,
ChannelWise,
ClassesToIndices,
ConvertToMultiChannelBasedOnBratsClasses,
CuCIM,
Expand All @@ -536,6 +537,7 @@
RandIdentity,
RandImageFilter,
RandLambda,
RandChannelWise,
RandTorchIO,
RandTorchVision,
RemoveRepeatedChannel,
Expand Down Expand Up @@ -568,6 +570,9 @@
CastToTyped,
CastToTypeD,
CastToTypeDict,
ChannelWised,
ChannelWiseD,
ChannelWiseDict,
ClassesToIndicesd,
ClassesToIndicesD,
ClassesToIndicesDict,
Expand Down Expand Up @@ -631,6 +636,9 @@
RandLambdad,
RandLambdaD,
RandLambdaDict,
RandChannelWised,
RandChannelWiseD,
RandChannelWiseDict,
RandTorchIOd,
RandTorchIOD,
RandTorchIODict,
Expand Down
78 changes: 78 additions & 0 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
"EnsureType",
"RepeatChannel",
"RemoveRepeatedChannel",
"ChannelWise",
"RandChannelWise",
"SplitDim",
"CastToType",
"ToTensor",
Expand Down Expand Up @@ -288,6 +290,82 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
return out


class ChannelWise(Transform):
    """
    Apply a given transform to each channel of the input array independently and
    concatenate the results back along the channel dimension.

    The input is split along dimension 0 and every slice handed to ``transform``
    keeps a leading channel dimension of size 1. Each per-channel output must
    keep the number of dimensions of the input, and all outputs must agree on
    their non-channel shape, so that concatenating them along dimension 0 yields
    a consistent channel-first result.

    Args:
        transform: a callable transform to apply to each channel.

    Raises:
        ValueError: when a per-channel output has a different number of
            dimensions than the input, or when the per-channel outputs do not
            share the same non-channel shape.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self, transform: Callable) -> None:
        self.transform = transform

    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """
        Apply the transform to `img`.

        Args:
            img: input data, split along dimension 0 into single-channel slices.

        Returns:
            The per-channel outputs concatenated along dimension 0. An input with
            zero channels is returned unchanged.

        Raises:
            ValueError: when the wrapped transform changes the number of
                dimensions, or returns different non-channel shapes for
                different channels.
        """
        if img.shape[0] == 0:
            return img

        results = []
        expected_ndim = img.ndim
        expected_shape_wo_channel: tuple[int, ...] | None = None
        for i in range(img.shape[0]):
            # indexing with a list keeps the leading channel axis (size 1)
            res = self.transform(img[[i], ...])
            # a dropped or added axis would make the concatenation below join the
            # outputs along the wrong dimension, so fail loudly instead
            if res.ndim != expected_ndim:
                raise ValueError(
                    f"channel transform must preserve ndim={expected_ndim}, got ndim={res.ndim} for channel {i}."
                )
            current_shape_wo_channel = tuple(res.shape[1:])
            if expected_shape_wo_channel is None:
                expected_shape_wo_channel = current_shape_wo_channel
            elif current_shape_wo_channel != expected_shape_wo_channel:
                raise ValueError(
                    "channel transform outputs must have consistent non-channel shape; "
                    f"got {current_shape_wo_channel} and {expected_shape_wo_channel}."
                )
            results.append(res)

        # the concatenation backend follows the type of the input
        if isinstance(img, torch.Tensor):
            return torch.cat(results, dim=0)
        return np.concatenate(results, axis=0)
Comment on lines +315 to +321
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard per-channel output shape invariants before concatenation.

If the wrapped transform drops the singleton channel dimension, concatenating along axis 0 silently corrupts the layout: the per-channel results are joined along what is now their first spatial axis instead of the channel axis. Validate each per-channel output before appending it.

Suggested patch
 class ChannelWise(Transform):
@@
     def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
@@
-        results = []
+        results = []
+        expected_ndim = img.ndim
+        expected_shape_wo_channel: tuple[int, ...] | None = None
         for i in range(img.shape[0]):
             res = self.transform(img[[i], ...])
+            if res.ndim != expected_ndim:
+                raise ValueError(
+                    f"channel transform must preserve ndim={expected_ndim}, got ndim={res.ndim} for channel {i}."
+                )
+            current_shape_wo_channel = tuple(res.shape[1:])
+            if expected_shape_wo_channel is None:
+                expected_shape_wo_channel = current_shape_wo_channel
+            elif current_shape_wo_channel != expected_shape_wo_channel:
+                raise ValueError(
+                    "channel transform outputs must have consistent non-channel shape; "
+                    f"got {current_shape_wo_channel} and {expected_shape_wo_channel}."
+                )
             results.append(res)
@@
 class RandChannelWise(RandomizableTransform):
@@
     def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
@@
-        results = []
+        results = []
+        expected_ndim = img.ndim
+        expected_shape_wo_channel: tuple[int, ...] | None = None
         for i in range(img.shape[0]):
             res = self.transform(img[[i], ...])
+            if res.ndim != expected_ndim:
+                raise ValueError(
+                    f"channel transform must preserve ndim={expected_ndim}, got ndim={res.ndim} for channel {i}."
+                )
+            current_shape_wo_channel = tuple(res.shape[1:])
+            if expected_shape_wo_channel is None:
+                expected_shape_wo_channel = current_shape_wo_channel
+            elif current_shape_wo_channel != expected_shape_wo_channel:
+                raise ValueError(
+                    "channel transform outputs must have consistent non-channel shape; "
+                    f"got {current_shape_wo_channel} and {expected_shape_wo_channel}."
+                )
             results.append(res)

Also applies to: 359-365

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/transforms/utility/array.py` around lines 315 - 321, the per-channel
loop that calls self.transform collects the results and concatenates them
blindly (torch.cat / np.concatenate). This can corrupt the layout if the wrapped
transform drops the singleton channel dimension. Before appending each res,
check that it preserves the expected shape: res.ndim == img.ndim and
res.shape[0] == 1 (the same check applies to both torch.Tensor and np.ndarray
outputs). If the check fails, either raise a clear error or reintroduce the
missing channel axis (unsqueeze or np.expand_dims), so that all items in results
are consistent for concatenation. Apply the same guard to the analogous loop
around lines 359-365.



class RandChannelWise(RandomizableTransform):
    """
    Randomizable version of :py:class:`monai.transforms.ChannelWise`, the input
    `transform` will be applied independently to each channel.

    The decision whether to transform is checked once per call, so either every
    channel is passed through `transform` or the image is returned untouched.
    Each per-channel output must keep the number of dimensions of the input, and
    all outputs must agree on their non-channel shape, so that concatenating
    them along dimension 0 yields a consistent channel-first result.

    Args:
        transform: a callable transform to apply to each channel.
        prob: probability of applying the transform to the entire image.

    Raises:
        ValueError: when a per-channel output has a different number of
            dimensions than the input, or when the per-channel outputs do not
            share the same non-channel shape.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self, transform: Callable, prob: float = 1.0) -> None:
        RandomizableTransform.__init__(self, prob)
        self.transform = transform

    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandChannelWise:
        """
        Set the random state of this transform and forward the same `seed` and
        `state` to the wrapped transform when it exposes `set_random_state`.

        Returns:
            This instance, to allow chained calls.
        """
        super().set_random_state(seed, state)
        if hasattr(self.transform, "set_random_state"):
            self.transform.set_random_state(seed, state)
        return self

    def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
        """
        Apply the transform to `img`.

        Args:
            img: input data, split along dimension 0 into single-channel slices.
            randomize: whether to call `self.randomize` before checking
                `self._do_transform`; pass False to reuse the current decision.

        Returns:
            `img` unchanged when the transform is not applied or when `img` has
            zero channels; otherwise the per-channel outputs concatenated along
            dimension 0.

        Raises:
            ValueError: when the wrapped transform changes the number of
                dimensions, or returns different non-channel shapes for
                different channels.
        """
        if randomize:
            self.randomize(None)
        if not self._do_transform:
            return img

        if img.shape[0] == 0:
            return img

        results = []
        expected_ndim = img.ndim
        expected_shape_wo_channel: tuple[int, ...] | None = None
        for i in range(img.shape[0]):
            # indexing with a list keeps the leading channel axis (size 1)
            res = self.transform(img[[i], ...])
            # a dropped or added axis would make the concatenation below join the
            # outputs along the wrong dimension, so fail loudly instead
            if res.ndim != expected_ndim:
                raise ValueError(
                    f"channel transform must preserve ndim={expected_ndim}, got ndim={res.ndim} for channel {i}."
                )
            current_shape_wo_channel = tuple(res.shape[1:])
            if expected_shape_wo_channel is None:
                expected_shape_wo_channel = current_shape_wo_channel
            elif current_shape_wo_channel != expected_shape_wo_channel:
                raise ValueError(
                    "channel transform outputs must have consistent non-channel shape; "
                    f"got {current_shape_wo_channel} and {expected_shape_wo_channel}."
                )
            results.append(res)

        # the concatenation backend follows the type of the input
        if isinstance(img, torch.Tensor):
            return torch.cat(results, dim=0)
        return np.concatenate(results, axis=0)



class SplitDim(Transform, MultiSampleTrait):
"""
Given an image of size X along a certain dimension, return a list of length X containing
Expand Down
74 changes: 74 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
ApplyTransformToPoints,
AsChannelLast,
CastToType,
ChannelWise,
ClassesToIndices,
ConvertToMultiChannelBasedOnBratsClasses,
CuCIM,
Expand All @@ -52,6 +53,7 @@
LabelToMask,
Lambda,
MapLabelValue,
RandChannelWise,
RemoveRepeatedChannel,
RepeatChannel,
SimulateDelay,
Expand Down Expand Up @@ -88,6 +90,9 @@
"ConcatItemsD",
"ConcatItemsDict",
"ConcatItemsd",
"ChannelWiseD",
"ChannelWiseDict",
"ChannelWised",
"ConvertToMultiChannelBasedOnBratsClassesD",
"ConvertToMultiChannelBasedOnBratsClassesDict",
"ConvertToMultiChannelBasedOnBratsClassesd",
Expand Down Expand Up @@ -131,6 +136,9 @@
"FlattenSubKeysd",
"FlattenSubKeysD",
"FlattenSubKeysDict",
"RandChannelWiseD",
"RandChannelWiseDict",
"RandChannelWised",
"RandCuCIMd",
"RandCuCIMD",
"RandCuCIMDict",
Expand Down Expand Up @@ -338,6 +346,70 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
return d


class ChannelWised(MapTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.ChannelWise`.

    Every selected item of the input mapping is replaced by the output of a
    shared ``ChannelWise`` instance built around ``transform``.
    """

    backend = ChannelWise.backend

    def __init__(self, keys: KeysCollection, transform: Callable, allow_missing_keys: bool = False) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            transform: a callable transform to apply to each channel.
            allow_missing_keys: don't raise exception if key is missing.
        """
        super().__init__(keys, allow_missing_keys)
        self.converter = ChannelWise(transform=transform)

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        """
        Return a shallow copy of `data` in which every selected item has been
        passed through the channel-wise converter.
        """
        channelwise = self.converter
        transformed = dict(data)
        for name in self.key_iterator(transformed):
            transformed[name] = channelwise(transformed[name])
        return transformed


class RandChannelWised(MapTransform, RandomizableTransform):
    """
    Dictionary-based wrapper of :py:class:`monai.transforms.RandChannelWise`.

    The random decision is taken once per call by this wrapper and shared by all
    selected keys; the wrapped array transform is invoked with ``randomize=False``.
    """

    backend = RandChannelWise.backend

    def __init__(self, keys: KeysCollection, transform: Callable, prob: float = 1.0, allow_missing_keys: bool = False) -> None:
        """
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            transform: a callable transform to apply to each channel.
            prob: probability of applying the transform to the entire image.
            allow_missing_keys: don't raise exception if key is missing.
        """
        MapTransform.__init__(self, keys, allow_missing_keys)
        RandomizableTransform.__init__(self, prob)
        # `prob` is handled by this wrapper, hence the fixed 1.0 for the converter
        self.converter = RandChannelWise(transform=transform, prob=1.0)

    def set_random_state(
        self, seed: int | None = None, state: np.random.RandomState | None = None
    ) -> RandChannelWised:
        """
        Set the random state of this wrapper and forward the same `seed` and
        `state` to the converter.

        Returns:
            This instance, to allow chained calls.
        """
        super().set_random_state(seed, state)
        if hasattr(self.converter, "set_random_state"):
            self.converter.set_random_state(seed, state)
        return self

    def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
        """
        Return a shallow copy of `data`; when the transform is applied, every
        selected item is replaced by the converter output.
        """
        out = dict(data)
        self.randomize(None)
        if self._do_transform:
            for name in self.key_iterator(out):
                # NOTE(review): with randomize=False the converter relies on its own
                # current `_do_transform` flag - presumably left True by its
                # constructor; confirm against RandomizableTransform.
                out[name] = self.converter(out[name], randomize=False)
        return out


class SplitDimd(MapTransform, MultiSampleTrait):
backend = SplitDim.backend

Expand Down Expand Up @@ -2032,6 +2104,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
AsChannelLastD = AsChannelLastDict = AsChannelLastd
EnsureChannelFirstD = EnsureChannelFirstDict = EnsureChannelFirstd
RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld
ChannelWiseD = ChannelWiseDict = ChannelWised
RandChannelWiseD = RandChannelWiseDict = RandChannelWised
RepeatChannelD = RepeatChannelDict = RepeatChanneld
SplitDimD = SplitDimDict = SplitDimd
CastToTypeD = CastToTypeDict = CastToTyped
Expand Down
58 changes: 58 additions & 0 deletions tests/inferers/test_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

_, has_scipy = optional_import("scipy")
_, has_einops = optional_import("einops")
DiffusersDDPMScheduler, has_diffusers = optional_import("diffusers", name="DDPMScheduler")

TEST_CASES = [
[
Expand Down Expand Up @@ -126,6 +127,63 @@ def test_ddpm_sampler(self, model_params, input_shape):
)
self.assertEqual(len(intermediates), 10)

@skipUnless(has_einops and has_diffusers, "Requires einops and diffusers")
def test_diffusers_ddpm_call(self):
    # Exercise DiffusionInferer with a diffusers DDPM scheduler: one direct call
    # followed by a short sampling run, checking output shapes only.
    run_device = "cpu"
    if torch.cuda.is_available():
        run_device = "cuda:0"

    unet = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=[32, 64],
        attention_levels=[False, True],
        num_res_blocks=1,
        num_head_channels=32,
    )
    unet.to(run_device)
    unet.eval()

    ddpm = DiffusersDDPMScheduler(num_train_timesteps=1000, beta_schedule="linear", prediction_type="epsilon")
    ddpm.set_timesteps(num_inference_steps=50)
    inferer = DiffusionInferer(scheduler=ddpm)

    n_samples, side = 2, 32
    images = torch.randn(n_samples, 1, side, side).to(run_device)
    eps = torch.randn_like(images)
    steps = torch.randint(0, ddpm.config.num_train_timesteps, (n_samples,)).long().to(run_device)

    # direct call: the prediction must have the shape of the input batch
    with torch.no_grad():
        predicted = inferer(inputs=images, diffusion_model=unet, noise=eps, timesteps=steps)
    self.assertEqual(predicted.shape, images.shape)

    # sampling with only two inference steps keeps the test fast
    ddpm.set_timesteps(num_inference_steps=2)
    generated = inferer.sample(input_noise=eps, diffusion_model=unet, scheduler=ddpm, verbose=False)
    self.assertEqual(generated.shape, images.shape)

@skipUnless(has_einops and has_diffusers, "Requires einops and diffusers")
def test_diffusers_ddpm_get_likelihood(self):
    # Check get_likelihood with a diffusers DDPM scheduler: number and shape of
    # the saved intermediates, and one likelihood entry per input sample.
    run_device = "cpu"
    if torch.cuda.is_available():
        run_device = "cuda:0"

    unet = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=[8],
        norm_num_groups=8,
        attention_levels=[True],
        num_res_blocks=1,
        num_head_channels=8,
    )
    unet.to(run_device)
    unet.eval()

    images = torch.randn(2, 1, 8, 8).to(run_device)
    ddpm = DiffusersDDPMScheduler(num_train_timesteps=10, beta_schedule="linear", prediction_type="epsilon")
    inferer = DiffusionInferer(scheduler=ddpm)
    ddpm.set_timesteps(num_inference_steps=10)

    result = inferer.get_likelihood(
        inputs=images, diffusion_model=unet, scheduler=ddpm, save_intermediates=True
    )
    likelihood, saved_steps = result

    self.assertEqual(len(saved_steps), 10)
    self.assertEqual(saved_steps[0].shape, images.shape)
    self.assertEqual(likelihood.shape[0], images.shape[0])

@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_ddim_sampler(self, model_params, input_shape):
Expand Down
41 changes: 41 additions & 0 deletions tests/inferers/test_latent_diffusion_inferer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from monai.utils import optional_import

_, has_einops = optional_import("einops")
DiffusersDDPMScheduler, has_diffusers = optional_import("diffusers", name="DDPMScheduler")
TEST_CASES = [
[
"AutoencoderKL",
Expand Down Expand Up @@ -414,6 +415,46 @@ def test_sample_shape(
)
self.assertEqual(sample.shape, input_shape)

@skipUnless(has_einops and has_diffusers, "Requires einops and diffusers")
def test_diffusers_ddpm_sample_shape(self):
    # Sample through LatentDiffusionInferer with a diffusers DDPM scheduler and
    # check the shape of the decoded output.
    run_device = "cpu"
    if torch.cuda.is_available():
        run_device = "cuda:0"

    autoencoder = AutoencoderKL(
        spatial_dims=2,
        in_channels=1,
        out_channels=1,
        channels=(4, 4),
        latent_channels=3,
        attention_levels=[False, False],
        num_res_blocks=1,
        with_encoder_nonlocal_attn=False,
        with_decoder_nonlocal_attn=False,
        norm_num_groups=4,
    )
    unet = DiffusionModelUNet(
        spatial_dims=2,
        in_channels=3,
        out_channels=3,
        channels=[4, 4],
        norm_num_groups=4,
        attention_levels=[False, False],
        num_res_blocks=1,
        num_head_channels=4,
    )
    autoencoder.to(run_device)
    unet.to(run_device)
    autoencoder.eval()
    unet.eval()

    latent_noise = torch.randn(1, 3, 4, 4).to(run_device)
    ddpm = DiffusersDDPMScheduler(num_train_timesteps=10, beta_schedule="linear", prediction_type="epsilon")
    inferer = LatentDiffusionInferer(scheduler=ddpm, scale_factor=1.0)
    ddpm.set_timesteps(num_inference_steps=10)

    decoded = inferer.sample(
        input_noise=latent_noise, autoencoder_model=autoencoder, diffusion_model=unet, scheduler=ddpm
    )
    self.assertEqual(decoded.shape, (1, 1, 8, 8))

@parameterized.expand(TEST_CASES)
@skipUnless(has_einops, "Requires einops")
def test_sample_shape_with_cfg(
Expand Down
Loading
Loading