Feat/channel wise transforms #8898

Open
ugbotueferhire wants to merge 6 commits into Project-MONAI:dev from ugbotueferhire:feat/channel-wise-transforms

@ugbotueferhire
Contributor

Fixes #8311.

Description

Adds new wrapper transforms ChannelWise, RandChannelWise, ChannelWised, and RandChannelWised to independently apply an array-based transform to each channel of an input array. This resolves issues surrounding applying data augmentations channel-wise, which is a common requirement for early fusion models where different 3D volumes or modalities are concatenated along the channel axis.

The ChannelWise transform passes each channel to the inner transform as a slice with a singleton channel dimension, so the inner transform still sees the channel-first shape it expects. For random augmentations, an independent PRNG state is maintained per channel.
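
The per-channel behaviour described above can be sketched in plain NumPy. This is an illustration only, not the PR's implementation: `apply_channel_wise` and `scale_to_unit` are hypothetical names, and a channel-first layout `(C, ...)` is assumed.

```python
import numpy as np


def scale_to_unit(x: np.ndarray) -> np.ndarray:
    """Hypothetical stand-in for an array transform: min-max scale to [0, 1]."""
    lo, hi = x.min(), x.max()
    if hi == lo:  # constant input: avoid division by zero
        return np.zeros_like(x, dtype=float)
    return (x - lo) / (hi - lo)


def apply_channel_wise(img: np.ndarray, transform) -> np.ndarray:
    """Apply `transform` to each channel separately, keeping a singleton channel axis."""
    # img[[i], ...] keeps the channel axis, so every slice has shape (1, ...).
    results = [transform(img[[i], ...]) for i in range(img.shape[0])]
    return np.concatenate(results, axis=0)


# Two channels with very different intensity ranges (0..3 and 0..30).
base = np.arange(4.0).reshape(2, 2)
img = np.stack([base, 10.0 * base])

out = apply_channel_wise(img, scale_to_unit)
whole = scale_to_unit(img)  # for comparison: one scaling over all channels
```

Scaled per channel, both channels span the full `[0, 1]` range; scaled as a whole, the first channel is compressed into the bottom tenth of the range.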

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

ugbotueferhire and others added 5 commits May 25, 2026 01:09
Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com>
I, ugbotueferhire <ugbotueferhire@gmail.com>, hereby add my Signed-off-by to this commit: 8f95fb1

Signed-off-by: ugbotueferhire <ugbotueferhire@gmail.com>
@coderabbitai
Contributor

coderabbitai Bot commented Jun 5, 2026

📝 Walkthrough

The PR introduces two independent changes: (1) scheduler-compatibility refactoring in diffusion inferers, centralizing scheduler handling with static helper methods for parameter detection, config access, and posterior computation, while refactoring sampling/likelihood loops in both DiffusionInferer and ControlNetDiffusionInferer, and (2) new channel-wise transforms (ChannelWise/RandChannelWise) that apply a provided callable independently per-channel and concatenate results, with dictionary-based wrappers and package exports, validated by comprehensive test suites.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Out of Scope Changes check | ⚠️ Warning | Changes to diffusion inferer (scheduler compatibility centralization) and test additions appear tangential; only channel-wise transforms directly address issue #8311. | Remove diffusion inferer refactoring or clarify its necessity in the description. Focus PR scope on channel-wise transforms per issue #8311. |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 26.32% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | Title accurately describes the primary feature: channel-wise transforms are being added as new wrapper transforms. |
| Description check | ✅ Passed | Description covers main changes, links to issue #8311, includes required sections with appropriate checkboxes, and documents new tests and docstrings. |
| Linked Issues check | ✅ Passed | All coding requirements from issue #8311 are met: ChannelWise, RandChannelWise, ChannelWised, RandChannelWised transforms implemented with proper per-channel application and PRNG state management. |

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 12

🧹 Nitpick comments (2)
monai/transforms/utility/dictionary.py (1)

349-411: ⚡ Quick win

Complete docstrings on ChannelWised and RandChannelWised methods.

__call__ and set_random_state should include full Google-style Args/Returns/Raises to match project standards.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/transforms/utility/dictionary.py` around lines 349 - 411, The
docstrings for ChannelWised.__call__, RandChannelWised.__call__, and
RandChannelWised.set_random_state are incomplete; update them to full
Google-style docstrings including Args (describe parameters like data, seed,
state, their types and behavior), Returns (describe returned dict[Hashable,
NdarrayOrTensor] and its contents), and Raises (document possible exceptions,
e.g., KeyError for missing keys when allow_missing_keys is False, TypeError for
invalid input types). Ensure ChannelWised and RandChannelWised class docstrings
mention their converter attributes and any randomness behavior, and for
RandChannelWised.set_random_state specify it returns self (RandChannelWised) and
that it delegates to converter.set_random_state when available.
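
As a reference for the layout being requested, here is a minimal sketch of a Google-style docstring with Args, Returns, and Raises sections. `SeededNoise` is a hypothetical class written for this illustration; it is not one of the PR's transforms and does not reproduce their signatures.

```python
import random
from typing import Optional


class SeededNoise:
    """Hypothetical random transform, used only to illustrate docstring layout."""

    def __init__(self) -> None:
        self.rng = random.Random()

    def set_random_state(self, seed: Optional[int] = None) -> "SeededNoise":
        """Reseed the internal random number generator.

        Args:
            seed: Seed for the generator. If None, the generator is seeded
                from the system time or an OS entropy source.

        Returns:
            This instance, so that calls can be chained.

        Raises:
            TypeError: If `seed` is neither None nor an int.
        """
        if seed is not None and not isinstance(seed, int):
            raise TypeError(f"seed must be an int or None, got {type(seed).__name__}.")
        self.rng.seed(seed)
        return self


noise = SeededNoise()
chained = noise.set_random_state(0)
```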
monai/transforms/utility/array.py (1)

293-366: ⚡ Quick win

Add complete Google-style docstrings for new definitions.

The new class/method docstrings are minimal and omit full Args/Returns/Raises coverage for definitions like __call__ and set_random_state.

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/transforms/utility/array.py` around lines 293 - 366, Update the minimal
docstrings for ChannelWise and RandChannelWise to full Google-style docstrings:
for the classes add a short description and Args describing transform (callable)
and prob (float) where applicable; for ChannelWise.__call__ and
RandChannelWise.__call__ add Args (img: NdarrayOrTensor, randomize: bool for
RandChannelWise.__call__), Returns (NdarrayOrTensor) and Raises (e.g.,
ValueError if input shape invalid) sections with types and brief meanings; for
RandChannelWise.set_random_state add Args (seed, state), Returns (self /
RandChannelWise) and any raised exceptions; ensure wording matches existing type
hints (NdarrayOrTensor, np.random.RandomState) and keep examples/notes optional
but consistent with project Google-style docstring conventions, placing the
updated docstrings in the definitions of ChannelWise, ChannelWise.__call__,
RandChannelWise, RandChannelWise.set_random_state, and RandChannelWise.__call__.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@monai/inferers/inferer.py`:
- Around line 936-954: Add a docstring to the _scheduler_step method explaining
that it wraps a Scheduler.step call and normalizes its output into a previous
sample Tensor; document parameters (scheduler: Scheduler, model_output:
torch.Tensor, timestep: int|torch.Tensor, sample: torch.Tensor, next_timestep:
int|torch.Tensor|None) and mention that step_kwargs may include return_dict
based on _scheduler_step_supports_kwarg, and that RFlowScheduler is handled with
a different call signature (RFlowScheduler.step includes next_timestep). State
the return type (torch.Tensor) and note that the method returns the previous
sample via _get_previous_sample_from_step_output.
- Around line 914-935: The method _get_posterior_variance is missing a
docstring; add a concise docstring above the function explaining what it
computes (the posterior variance used in the diffusion reverse step), describing
parameters (scheduler: Scheduler, timestep: int|torch.Tensor,
predicted_variance: torch.Tensor|None) and return type (torch.Tensor), and
briefly note behavior for each variance_type branch ("fixed_small",
"fixed_large", "learned", "learned_range") including how predicted_variance is
used; keep it short, follow existing project style (one-line summary + short
param/return descriptions).
- Around line 865-871: Add a Google-style docstring to the static method
_scheduler_step_supports_kwarg explaining the parameter introspection: describe
the parameters (scheduler: Scheduler, kwarg: str), the return value (bool
indicating whether scheduler.step accepts kwarg), and the exceptions handled
(TypeError and ValueError caught from inspect.signature). Mention that the
function uses inspect.signature(scheduler.step).parameters to check for the
kwarg and that inspected errors are swallowed (returning False).
- Around line 902-913: Add a docstring to the _get_posterior_mean function that
briefly states what the function computes (the posterior mean used in the
diffusion/scheduler step), describes the inputs (scheduler, timestep, x_0, x_t)
and their expected types/shapes, documents the mathematical formula being
implemented (coefficients for x_0 and x_t using scheduler.alphas,
scheduler.alphas_cumprod and scheduler.betas) and states the return type
(torch.Tensor); keep it concise, one-line summary plus parameter and return
sections, and mention edge-case behavior for timestep==0 where scheduler.one is
used.
- Around line 883-887: Add a docstring to the helper function
_get_scheduler_name(scheduler: Scheduler) that succinctly describes its purpose
(returns a human-readable name for a Scheduler instance), documents the
parameter (scheduler: Scheduler) and the return type (str), and notes the lookup
order (prefers scheduler._get_name() if present, otherwise uses
scheduler.__class__.__name__); update the function definition for
_get_scheduler_name to include this docstring directly above the implementation.
- Around line 873-881: Add a Google-style docstring to
_get_previous_sample_from_step_output that documents the parameter step_output
(types expected), the returned torch.Tensor, and the TypeError raised; also
change the TypeError message to include the actual type encountered (e.g., using
type(step_output)) so the error reports the unsupported type. Ensure the
docstring briefly explains the three supported shapes (tuple where [0] is prev
sample, Mapping with "prev_sample" key, or object attribute prev_sample) and
mentions the raised TypeError when none match.
- Around line 889-900: Add a descriptive docstring to the helper function
_get_scheduler_config_value(scheduler, name, default) explaining its purpose
(resolve a configuration value by first checking scheduler.config mapping or
attributes, then scheduler attributes, and returning default if not found), and
document the parameters (scheduler: Scheduler, name: str, default: Any) and
return type (Any) plus any behavior/edge-cases (e.g., handles Mapping config and
attribute lookup order). Keep it concise and follow project's docstring style
(short summary, params, returns).

In `@monai/transforms/utility/array.py`:
- Around line 315-321: The per-channel loop in the method calling self.transform
collects results and blindly concatenates them (torch.cat / np.concatenate),
which can corrupt layout if a wrapped transform drops the singleton channel
dimension; before appending each res, validate that its dimensionality and
leading channel size preserve the expected singleton channel (e.g., for
torch.Tensor res.ndim and res.shape[0]==1, for np.ndarray res.ndim and
res.shape[0]==1) and either raise a clear error or reintroduce the missing
channel axis (unsqueeze or np.expand_dims) so that all items in results are
consistent for concatenation; apply the same guard logic to the analogous loop
around lines 359-365.

In `@tests/test_channel_wise.py`:
- Around line 10-46: Add equivalent tests that exercise the torch backend by
re-running the same scenarios with torch.Tensor inputs: create torch tensors for
data in test_channel_wise_deterministic, test_rand_channel_wise, and
test_prob_zero and invoke ChannelWise, RandChannelWise with ScaleIntensity and
RandGaussianNoise respectively (use set_determinism(seed=0) to control
randomness for torch too). For assertions use torch.allclose (or convert outputs
to numpy with .cpu().numpy()) to check per-channel scaling and inequality of
random channels, and assert tensor shapes match; ensure the
RandChannelWise(prob=0.0) case returns an identical torch tensor. Reference
ChannelWise, RandChannelWise, ScaleIntensity, RandGaussianNoise, and
set_determinism to locate code to test.
- Around line 28-31: The test sets global determinism via
set_determinism(seed=0) but never restores it, so wrap the deterministic section
(where you instantiate RandChannelWise and apply it to data) in a try/finally
and in the finally call set_determinism(None) (or the API call that disables
determinism) to restore global state; specifically modify the block around
set_determinism(seed=0), transform = RandChannelWise(...), and out =
transform(data) to ensure set_determinism is undone after the test.

In `@tests/test_channel_wised.py`:
- Around line 10-46: Add tests that exercise ChannelWised and RandChannelWised
with torch.Tensor inputs (not only numpy arrays) so the torch code paths are
covered: for test_channel_wise_deterministic create data as torch.tensor with
the same values and call ChannelWised(keys=["image"],
transform=ScaleIntensity()) then assert the per-channel scaled results and shape
using torch.allclose; for test_rand_channel_wise use set_determinism(seed=0) and
pass a torch.zeros tensor into RandChannelWised(keys=["image"],
transform=RandGaussianNoise(prob=1.0, std=1.0)) and assert channels differ with
torch comparisons and shape equality; similarly add a torch variant of
test_prob_zero using RandChannelWised(..., prob=0.0) to assert output equals
input. Ensure you import torch and use ChannelWised, RandChannelWised,
ScaleIntensity, RandGaussianNoise, and set_determinism names from the module
under test.
- Around line 28-31: The test sets global determinism with
set_determinism(seed=0) but never restores it; after running the randomized
transform (the RandChannelWised/RandGaussianNoise call in
tests/test_channel_wised.py where transform(data) is executed) call
set_determinism(None) (or the library's provided reset/unset call) to restore
the global RNG/determinism state so other tests are not affected; place that
call immediately after out = transform(data).
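
The restore-in-`finally` pattern requested for the tests can be illustrated without MONAI. The sketch below uses the standard library's global `random` generator as a stand-in for the global state that `set_determinism` changes; `run_seeded` is a hypothetical helper, and in the actual tests the `finally` branch would call `set_determinism(None)` as the comments above suggest.

```python
import random


def run_seeded(seed: int, fn):
    """Run `fn` with the global RNG seeded, then restore the previous RNG state."""
    saved_state = random.getstate()
    random.seed(seed)
    try:
        return fn()
    finally:
        # Runs even if `fn` raises, so later code never sees the seeded state.
        random.setstate(saved_state)


state_before = random.getstate()
first = run_seeded(0, random.random)
second = run_seeded(0, random.random)
state_after = random.getstate()
```

Both seeded calls return the same value, and the global state after the calls equals the state before them, which is the property that keeps one test from influencing the next.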

---

Nitpick comments:
In `@monai/transforms/utility/array.py`:
- Around line 293-366: Update the minimal docstrings for ChannelWise and
RandChannelWise to full Google-style docstrings: for the classes add a short
description and Args describing transform (callable) and prob (float) where
applicable; for ChannelWise.__call__ and RandChannelWise.__call__ add Args (img:
NdarrayOrTensor, randomize: bool for RandChannelWise.__call__), Returns
(NdarrayOrTensor) and Raises (e.g., ValueError if input shape invalid) sections
with types and brief meanings; for RandChannelWise.set_random_state add Args
(seed, state), Returns (self / RandChannelWise) and any raised exceptions;
ensure wording matches existing type hints (NdarrayOrTensor,
np.random.RandomState) and keep examples/notes optional but consistent with
project Google-style docstring conventions, placing the updated docstrings in
the definitions of ChannelWise, ChannelWise.__call__, RandChannelWise,
RandChannelWise.set_random_state, and RandChannelWise.__call__.

In `@monai/transforms/utility/dictionary.py`:
- Around line 349-411: The docstrings for ChannelWised.__call__,
RandChannelWised.__call__, and RandChannelWised.set_random_state are incomplete;
update them to full Google-style docstrings including Args (describe parameters
like data, seed, state, their types and behavior), Returns (describe returned
dict[Hashable, NdarrayOrTensor] and its contents), and Raises (document possible
exceptions, e.g., KeyError for missing keys when allow_missing_keys is False,
TypeError for invalid input types). Ensure ChannelWised and RandChannelWised
class docstrings mention their converter attributes and any randomness behavior,
and for RandChannelWised.set_random_state specify it returns self
(RandChannelWised) and that it delegates to converter.set_random_state when
available.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 63a24f4b-29bb-4185-ac59-093521eb8abb

📥 Commits

Reviewing files that changed from the base of the PR and between 2a7d0cf and dfe632f.

📒 Files selected for processing (8)
  • monai/inferers/inferer.py
  • monai/transforms/__init__.py
  • monai/transforms/utility/array.py
  • monai/transforms/utility/dictionary.py
  • tests/inferers/test_diffusion_inferer.py
  • tests/inferers/test_latent_diffusion_inferer.py
  • tests/test_channel_wise.py
  • tests/test_channel_wised.py

Comment thread monai/inferers/inferer.py
Comment on lines +865 to +871
@staticmethod
def _scheduler_step_supports_kwarg(scheduler: Scheduler, kwarg: str) -> bool:
    try:
        return kwarg in inspect.signature(scheduler.step).parameters
    except (TypeError, ValueError):
        return False

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Add docstring describing parameter introspection logic.

All new methods lack Google-style docstrings. Per coding guidelines, docstrings must describe parameters, return values, and exceptions.

📝 Suggested docstring
     `@staticmethod`
     def _scheduler_step_supports_kwarg(scheduler: Scheduler, kwarg: str) -> bool:
+        """
+        Check if a scheduler's step method accepts a specific keyword argument.
+
+        Args:
+            scheduler: Scheduler instance to inspect.
+            kwarg: Name of the keyword argument to check.
+
+        Returns:
+            True if the scheduler's step method accepts the kwarg, False otherwise.
+        """
         try:

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/inferers/inferer.py` around lines 865 - 871, Add a Google-style
docstring to the static method _scheduler_step_supports_kwarg explaining the
parameter introspection: describe the parameters (scheduler: Scheduler, kwarg:
str), the return value (bool indicating whether scheduler.step accepts kwarg),
and the exceptions handled (TypeError and ValueError caught from
inspect.signature). Mention that the function uses
inspect.signature(scheduler.step).parameters to check for the kwarg and that
inspected errors are swallowed (returning False).
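
To make the introspection concrete, here is a small standalone sketch of the same check. The two scheduler classes are hypothetical stand-ins, not MONAI schedulers, and `step_supports_kwarg` is a free-function copy of the logic shown in the snippet above.

```python
import inspect


class StepWithReturnDict:
    """Hypothetical scheduler whose step declares `return_dict`."""

    def step(self, model_output, timestep, sample, return_dict=True):
        return (sample,)


class StepWithoutReturnDict:
    """Hypothetical scheduler whose step has no `return_dict` parameter."""

    def step(self, model_output, timestep, sample):
        return (sample,)


def step_supports_kwarg(scheduler, kwarg: str) -> bool:
    """Return True if `scheduler.step` declares a parameter named `kwarg`."""
    try:
        return kwarg in inspect.signature(scheduler.step).parameters
    except (TypeError, ValueError):
        # inspect.signature can fail for some callables (e.g. certain builtins).
        return False


supported = step_supports_kwarg(StepWithReturnDict(), "return_dict")
unsupported = step_supports_kwarg(StepWithoutReturnDict(), "return_dict")
```

One consequence of this design: a `step` that only declares `**kwargs` reports False for any specific name, because the check looks at declared parameter names rather than at what the call would accept.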

Comment thread monai/inferers/inferer.py
Comment on lines +873 to +881
def _get_previous_sample_from_step_output(step_output: Any) -> torch.Tensor:
    if isinstance(step_output, tuple):
        return step_output[0]
    if isinstance(step_output, Mapping):
        return step_output["prev_sample"]
    if hasattr(step_output, "prev_sample"):
        return step_output.prev_sample
    raise TypeError("Unsupported scheduler.step output. Expected a tuple or an object with `prev_sample`.")

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Add docstring and include actual type in error message.

Missing docstring and the TypeError doesn't report the actual unsupported type encountered.

📝 Proposed fix
     `@staticmethod`
     def _get_previous_sample_from_step_output(step_output: Any) -> torch.Tensor:
+        """
+        Extract the previous sample tensor from various scheduler step output formats.
+
+        Args:
+            step_output: Output from scheduler.step(), which may be a tuple, dict, or object.
+
+        Returns:
+            The previous sample tensor.
+
+        Raises:
+            TypeError: If the output format is not recognized.
+        """
         if isinstance(step_output, tuple):
             return step_output[0]
         if isinstance(step_output, Mapping):
             return step_output["prev_sample"]
         if hasattr(step_output, "prev_sample"):
             return step_output.prev_sample
-        raise TypeError("Unsupported scheduler.step output. Expected a tuple or an object with `prev_sample`.")
+        raise TypeError(
+            f"Unsupported scheduler.step output type: {type(step_output).__name__}. "
+            "Expected a tuple, mapping with 'prev_sample' key, or object with prev_sample attribute."
+        )

As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/inferers/inferer.py` around lines 873 - 881, Add a Google-style
docstring to _get_previous_sample_from_step_output that documents the parameter
step_output (types expected), the returned torch.Tensor, and the TypeError
raised; also change the TypeError message to include the actual type encountered
(e.g., using type(step_output)) so the error reports the unsupported type.
Ensure the docstring briefly explains the three supported shapes (tuple where
[0] is prev sample, Mapping with "prev_sample" key, or object attribute
prev_sample) and mentions the raised TypeError when none match.
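
The three supported output shapes and the improved error message can be exercised in isolation. This sketch assumes nothing about MONAI: plain floats stand in for tensors, and `previous_sample_from` is a hypothetical free-function version of the helper under review.

```python
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any


def previous_sample_from(step_output: Any) -> Any:
    """Extract the previous sample from a tuple, a mapping, or an attribute-style object."""
    if isinstance(step_output, tuple):
        return step_output[0]
    if isinstance(step_output, Mapping):
        return step_output["prev_sample"]
    if hasattr(step_output, "prev_sample"):
        return step_output.prev_sample
    raise TypeError(
        f"Unsupported scheduler.step output type: {type(step_output).__name__}. "
        "Expected a tuple, a mapping with a 'prev_sample' key, or an object with a prev_sample attribute."
    )


# Floats stand in for tensors in this illustration.
from_tuple = previous_sample_from((1.0, "extra"))
from_mapping = previous_sample_from({"prev_sample": 2.0})
from_object = previous_sample_from(SimpleNamespace(prev_sample=3.0))

try:
    previous_sample_from(42)
    error_message = ""
except TypeError as exc:
    error_message = str(exc)  # names the offending type, here "int"
```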

Comment thread monai/inferers/inferer.py
Comment on lines +883 to +887
def _get_scheduler_name(scheduler: Scheduler) -> str:
    if hasattr(scheduler, "_get_name"):
        return scheduler._get_name()
    return scheduler.__class__.__name__

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Add docstring.

📝 Suggested docstring
     `@staticmethod`
     def _get_scheduler_name(scheduler: Scheduler) -> str:
+        """
+        Get the name of a scheduler instance.
+
+        Args:
+            scheduler: Scheduler instance.
+
+        Returns:
+            The scheduler's name.
+        """
         if hasattr(scheduler, "_get_name"):

As per coding guidelines, docstrings are required for all definitions.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/inferers/inferer.py` around lines 883 - 887, Add a docstring to the
helper function _get_scheduler_name(scheduler: Scheduler) that succinctly
describes its purpose (returns a human-readable name for a Scheduler instance),
documents the parameter (scheduler: Scheduler) and the return type (str), and
notes the lookup order (prefers scheduler._get_name() if present, otherwise uses
scheduler.__class__.__name__); update the function definition for
_get_scheduler_name to include this docstring directly above the implementation.

Comment thread monai/inferers/inferer.py
Comment on lines +889 to +900
def _get_scheduler_config_value(scheduler: Scheduler, name: str, default: Any = None) -> Any:
    config = getattr(scheduler, "config", None)
    if isinstance(config, Mapping):
        if name in config:
            return config[name]
    elif config is not None and hasattr(config, name):
        return getattr(config, name)

    if hasattr(scheduler, name):
        return getattr(scheduler, name)
    return default

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Add docstring.

📝 Suggested docstring
     `@staticmethod`
     def _get_scheduler_config_value(scheduler: Scheduler, name: str, default: Any = None) -> Any:
+        """
+        Read a configuration value from a scheduler.
+
+        Args:
+            scheduler: Scheduler instance.
+            name: Configuration parameter name.
+            default: Value to return if the parameter is not found.
+
+        Returns:
+            The configuration value or default if not found.
+        """
         config = getattr(scheduler, "config", None)

As per coding guidelines, docstrings are required for all definitions.

Suggested change
-    def _get_scheduler_config_value(scheduler: Scheduler, name: str, default: Any = None) -> Any:
-        config = getattr(scheduler, "config", None)
-        if isinstance(config, Mapping):
-            if name in config:
-                return config[name]
-        elif config is not None and hasattr(config, name):
-            return getattr(config, name)
-
-        if hasattr(scheduler, name):
-            return getattr(scheduler, name)
-        return default
+    `@staticmethod`
+    def _get_scheduler_config_value(scheduler: Scheduler, name: str, default: Any = None) -> Any:
+        """
+        Read a configuration value from a scheduler.
+
+        Args:
+            scheduler: Scheduler instance.
+            name: Configuration parameter name.
+            default: Value to return if the parameter is not found.
+
+        Returns:
+            The configuration value or default if not found.
+        """
+        config = getattr(scheduler, "config", None)
+        if isinstance(config, Mapping):
+            if name in config:
+                return config[name]
+        elif config is not None and hasattr(config, name):
+            return getattr(config, name)
+
+        if hasattr(scheduler, name):
+            return getattr(scheduler, name)
+        return default

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/inferers/inferer.py` around lines 889 - 900, Add a descriptive
docstring to the helper function _get_scheduler_config_value(scheduler, name,
default) explaining its purpose (resolve a configuration value by first checking
scheduler.config mapping or attributes, then scheduler attributes, and returning
default if not found), and document the parameters (scheduler: Scheduler, name:
str, default: Any) and return type (Any) plus any behavior/edge-cases (e.g.,
handles Mapping config and attribute lookup order). Keep it concise and follow
project's docstring style (short summary, params, returns).
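
The lookup order described above (config mapping or config attributes first, then scheduler attributes, then the default) can be checked with a few stand-in objects. Everything here is a sketch: `get_config_value` is a hypothetical free-function copy of the helper, and the `SimpleNamespace` objects are not real schedulers.

```python
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any


def get_config_value(scheduler: Any, name: str, default: Any = None) -> Any:
    """Resolve `name` from scheduler.config first, then from the scheduler itself."""
    config = getattr(scheduler, "config", None)
    if isinstance(config, Mapping):
        if name in config:
            return config[name]
    elif config is not None and hasattr(config, name):
        return getattr(config, name)
    if hasattr(scheduler, name):
        return getattr(scheduler, name)
    return default


# Config given as a mapping; it takes precedence over the same-named attribute.
mapping_cfg = SimpleNamespace(config={"variance_type": "learned"}, variance_type="fixed_small")
# Config given as an object with attributes.
object_cfg = SimpleNamespace(config=SimpleNamespace(variance_type="fixed_large"))
# No config at all: fall back to the scheduler attribute, then to the default.
no_cfg = SimpleNamespace(variance_type="fixed_small")

from_mapping = get_config_value(mapping_cfg, "variance_type")
from_object = get_config_value(object_cfg, "variance_type")
from_attribute = get_config_value(no_cfg, "variance_type")
from_default = get_config_value(no_cfg, "missing", default="fallback")
```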

Comment thread monai/inferers/inferer.py
Comment on lines +902 to +913
def _get_posterior_mean(
    scheduler: Scheduler, timestep: int | torch.Tensor, x_0: torch.Tensor, x_t: torch.Tensor
) -> torch.Tensor:
    alpha_t = scheduler.alphas[timestep]
    alpha_prod_t = scheduler.alphas_cumprod[timestep]
    alpha_prod_t_prev = scheduler.alphas_cumprod[timestep - 1] if timestep > 0 else scheduler.one

    x_0_coefficient = alpha_prod_t_prev.sqrt() * scheduler.betas[timestep] / (1 - alpha_prod_t)
    x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)

    return x_0_coefficient * x_0 + x_t_coefficient * x_t

🛠️ Refactor suggestion | 🟠 Major | ⚡ Quick win

Add docstring describing posterior mean calculation.

📝 Suggested docstring
     `@staticmethod`
     def _get_posterior_mean(
         scheduler: Scheduler, timestep: int | torch.Tensor, x_0: torch.Tensor, x_t: torch.Tensor
     ) -> torch.Tensor:
+        """
+        Compute the posterior mean for the diffusion process at a given timestep.
+
+        Implements formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+
+        Args:
+            scheduler: Scheduler with alphas, alphas_cumprod, betas, and one attributes.
+            timestep: Current timestep (scalar or tensor).
+            x_0: Noise-free input.
+            x_t: Input noised to timestep t.
+
+        Returns:
+            The posterior mean tensor.
+        """
         alpha_t = scheduler.alphas[timestep]

As per coding guidelines, docstrings are required for all definitions.

Suggested change
-    def _get_posterior_mean(
-        scheduler: Scheduler, timestep: int | torch.Tensor, x_0: torch.Tensor, x_t: torch.Tensor
-    ) -> torch.Tensor:
-        alpha_t = scheduler.alphas[timestep]
-        alpha_prod_t = scheduler.alphas_cumprod[timestep]
-        alpha_prod_t_prev = scheduler.alphas_cumprod[timestep - 1] if timestep > 0 else scheduler.one
-
-        x_0_coefficient = alpha_prod_t_prev.sqrt() * scheduler.betas[timestep] / (1 - alpha_prod_t)
-        x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)
-
-        return x_0_coefficient * x_0 + x_t_coefficient * x_t
+    def _get_posterior_mean(
+        scheduler: Scheduler, timestep: int | torch.Tensor, x_0: torch.Tensor, x_t: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Compute the posterior mean for the diffusion process at a given timestep.
+
+        Implements formula (7) from https://arxiv.org/pdf/2006.11239.pdf
+
+        Args:
+            scheduler: Scheduler with alphas, alphas_cumprod, betas, and one attributes.
+            timestep: Current timestep (scalar or tensor).
+            x_0: Noise-free input.
+            x_t: Input noised to timestep t.
+
+        Returns:
+            The posterior mean tensor.
+        """
+        alpha_t = scheduler.alphas[timestep]
+        alpha_prod_t = scheduler.alphas_cumprod[timestep]
+        alpha_prod_t_prev = scheduler.alphas_cumprod[timestep - 1] if timestep > 0 else scheduler.one
+
+        x_0_coefficient = alpha_prod_t_prev.sqrt() * scheduler.betas[timestep] / (1 - alpha_prod_t)
+        x_t_coefficient = alpha_t.sqrt() * (1 - alpha_prod_t_prev) / (1 - alpha_prod_t)
+
+        return x_0_coefficient * x_0 + x_t_coefficient * x_t

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/inferers/inferer.py` around lines 902 - 913, Add a docstring to the
_get_posterior_mean function that briefly states what the function computes (the
posterior mean used in the diffusion/scheduler step), describes the inputs
(scheduler, timestep, x_0, x_t) and their expected types/shapes, documents the
mathematical formula being implemented (coefficients for x_0 and x_t using
scheduler.alphas, scheduler.alphas_cumprod and scheduler.betas) and states the
return type (torch.Tensor); keep it concise, one-line summary plus parameter and
return sections, and mention edge-case behavior for timestep==0 where
scheduler.one is used.
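
For reference, the two coefficients in the snippet correspond term by term to the posterior mean that the suggested docstring cites as formula (7). The notation is an assumption made here, not taken from the PR: $\alpha_t$ stands for `scheduler.alphas[timestep]`, $\bar{\alpha}_t$ for `scheduler.alphas_cumprod[timestep]`, and $\beta_t$ for `scheduler.betas[timestep]`.

```latex
\tilde{\mu}_t(x_t, x_0)
  = \underbrace{\frac{\sqrt{\bar{\alpha}_{t-1}}\,\beta_t}{1 - \bar{\alpha}_t}}_{\texttt{x\_0\_coefficient}}\; x_0
  \;+\; \underbrace{\frac{\sqrt{\alpha_t}\,\bigl(1 - \bar{\alpha}_{t-1}\bigr)}{1 - \bar{\alpha}_t}}_{\texttt{x\_t\_coefficient}}\; x_t ,
\qquad \bar{\alpha}_{t-1} := 1 \ \text{for } t = 0 .
```

The convention $\bar{\alpha}_{t-1} := 1$ at $t = 0$ mirrors the `scheduler.one` branch in the code: the $x_t$ coefficient then vanishes, and, under the usual assumption $\alpha_t = 1 - \beta_t$, the $x_0$ coefficient reduces to $\beta_0 / (1 - \bar{\alpha}_0) = 1$.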

Comment on lines +315 to +321
for i in range(img.shape[0]):
    res = self.transform(img[[i], ...])
    results.append(res)

if isinstance(img, torch.Tensor):
    return torch.cat(results, dim=0)
return np.concatenate(results, axis=0)

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Guard per-channel output shape invariants before concatenation.

If the wrapped transform removes the singleton channel dimension, concatenation can silently corrupt layout (e.g., flattening channel semantics into axis 0). Validate per-channel outputs before appending.

Suggested patch
 class ChannelWise(Transform):
@@
     def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
@@
-        results = []
+        results = []
+        expected_ndim = img.ndim
+        expected_shape_wo_channel: tuple[int, ...] | None = None
         for i in range(img.shape[0]):
             res = self.transform(img[[i], ...])
+            if res.ndim != expected_ndim:
+                raise ValueError(
+                    f"channel transform must preserve ndim={expected_ndim}, got ndim={res.ndim} for channel {i}."
+                )
+            current_shape_wo_channel = tuple(res.shape[1:])
+            if expected_shape_wo_channel is None:
+                expected_shape_wo_channel = current_shape_wo_channel
+            elif current_shape_wo_channel != expected_shape_wo_channel:
+                raise ValueError(
+                    "channel transform outputs must have consistent non-channel shape; "
+                    f"got {current_shape_wo_channel} and {expected_shape_wo_channel}."
+                )
             results.append(res)
@@
 class RandChannelWise(RandomizableTransform):
@@
     def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor:
@@
-        results = []
+        results = []
+        expected_ndim = img.ndim
+        expected_shape_wo_channel: tuple[int, ...] | None = None
         for i in range(img.shape[0]):
             res = self.transform(img[[i], ...])
+            if res.ndim != expected_ndim:
+                raise ValueError(
+                    f"channel transform must preserve ndim={expected_ndim}, got ndim={res.ndim} for channel {i}."
+                )
+            current_shape_wo_channel = tuple(res.shape[1:])
+            if expected_shape_wo_channel is None:
+                expected_shape_wo_channel = current_shape_wo_channel
+            elif current_shape_wo_channel != expected_shape_wo_channel:
+                raise ValueError(
+                    "channel transform outputs must have consistent non-channel shape; "
+                    f"got {current_shape_wo_channel} and {expected_shape_wo_channel}."
+                )
             results.append(res)

Also applies to: 359-365

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/transforms/utility/array.py` around lines 315 - 321, The per-channel
loop in the method calling self.transform collects results and blindly
concatenates them (torch.cat / np.concatenate), which can corrupt layout if a
wrapped transform drops the singleton channel dimension; before appending each
res, validate that its dimensionality and leading channel size preserve the
expected singleton channel (e.g., for torch.Tensor res.ndim and res.shape[0]==1,
for np.ndarray res.ndim and res.shape[0]==1) and either raise a clear error or
reintroduce the missing channel axis (unsqueeze or np.expand_dims) so that all
items in results are consistent for concatenation; apply the same guard logic to
the analogous loop around lines 359-365.
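
The guard described above can be sketched outside MONAI as a small NumPy-only function. This is a minimal illustration, not the PR's code: `apply_channel_wise` is a hypothetical name, and the choice to raise rather than re-insert the missing axis is one of the two options the comment mentions.

```python
import numpy as np


def apply_channel_wise(img, transform):
    """Hypothetical helper: apply `transform` per channel and re-stack on axis 0."""
    results = []
    expected_spatial = None
    for i in range(img.shape[0]):
        # img[[i], ...] keeps a singleton channel axis: shape (1, *spatial)
        res = transform(img[[i], ...])
        if res.ndim != img.ndim or res.shape[0] != 1:
            raise ValueError(
                f"channel {i}: expected shape (1, ...) with ndim={img.ndim}, got {tuple(res.shape)}"
            )
        spatial = tuple(res.shape[1:])
        if expected_spatial is None:
            expected_spatial = spatial
        elif spatial != expected_spatial:
            raise ValueError(f"channel {i}: spatial shape {spatial} != {expected_spatial}")
        results.append(res)
    return np.concatenate(results, axis=0)


data = np.arange(8, dtype=float).reshape(2, 2, 2)
doubled = apply_channel_wise(data, lambda x: x * 2)  # channel axis preserved

try:
    # A transform that drops the channel axis is rejected instead of concatenated.
    apply_channel_wise(data, lambda x: x[0])
    squeeze_rejected = False
except ValueError:
    squeeze_rejected = True
```

Without the check, the second call would concatenate two `(2, 2)` arrays into a `(4, 2)` result, which is the silent layout corruption the comment warns about.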

Comment on lines +10 to +46
def test_channel_wise_deterministic(self):
    # Test applying a deterministic transform channel-wise
    data = np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]]) # shape (2, 2, 2)

    # ScaleIntensity applies to the whole input array independently
    transform = ChannelWise(transform=ScaleIntensity())
    out = transform(data)

    # Channel 0 scaled
    np.testing.assert_allclose(out[0], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5)
    # Channel 1 scaled
    np.testing.assert_allclose(out[1], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5)
    self.assertEqual(out.shape, data.shape)

def test_rand_channel_wise(self):
    # Test applying a randomized transform channel-wise
    data = np.zeros((3, 4, 4))

    set_determinism(seed=0)
    # Apply random noise with high standard deviation to see the difference
    transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0))
    out = transform(data)

    # All channels should have different noise values
    self.assertFalse(np.allclose(out[0], out[1]))
    self.assertFalse(np.allclose(out[1], out[2]))
    self.assertFalse(np.allclose(out[0], out[2]))

    # Output shape should be exactly the same
    self.assertEqual(out.shape, data.shape)

def test_prob_zero(self):
    # Test when RandChannelWise prob is 0.0
    data = np.zeros((2, 2, 2))
    transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0), prob=0.0)
    out = transform(data)
    np.testing.assert_allclose(out, data)

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add torch-backend test coverage for the new wrappers.

Current coverage exercises numpy only; the tensor branch is part of the public backend contract and should be explicitly tested.

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_channel_wise.py` around lines 10 - 46, Add equivalent tests that
exercise the torch backend by re-running the same scenarios with torch.Tensor
inputs: create torch tensors for data in test_channel_wise_deterministic,
test_rand_channel_wise, and test_prob_zero and invoke ChannelWise,
RandChannelWise with ScaleIntensity and RandGaussianNoise respectively (use
set_determinism(seed=0) to control randomness for torch too). For assertions use
torch.allclose (or convert outputs to numpy with .cpu().numpy()) to check
per-channel scaling and inequality of random channels, and assert tensor shapes
match; ensure the RandChannelWise(prob=0.0) case returns an identical torch
tensor. Reference ChannelWise, RandChannelWise, ScaleIntensity,
RandGaussianNoise, and set_determinism to locate code to test.
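
One way to structure such coverage is to run the same assertions over a list of array constructors, one per backend. The sketch below is self-contained and does not import MONAI: `min_max_scale` and `channel_wise` are hypothetical stand-ins for `ScaleIntensity` and `ChannelWise`, and torch is treated as optional so the sketch still runs where only NumPy is installed.

```python
import numpy as np

try:  # torch is optional in this sketch
    import torch
except ImportError:
    torch = None


def min_max_scale(x):
    """Stand-in for an intensity transform: rescale the input to [0, 1]."""
    return (x - x.min()) / (x.max() - x.min())


def channel_wise(img, fn):
    """Stand-in wrapper: apply `fn` per channel, concatenate with the matching backend."""
    results = [fn(img[[i], ...]) for i in range(img.shape[0])]
    if torch is not None and isinstance(img, torch.Tensor):
        return torch.cat(results, dim=0)
    return np.concatenate(results, axis=0)


values = [[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]]
expected = np.array([[0.0, 1 / 3], [2 / 3, 1.0]])

backends = [np.array]
if torch is not None:
    backends.append(torch.tensor)

checked = 0
for make in backends:
    data = make(values)
    out = channel_wise(data, min_max_scale)
    assert type(out) is type(data)  # the input backend is preserved
    assert tuple(out.shape) == tuple(data.shape)
    for channel in np.asarray(out):
        assert np.allclose(channel, expected, atol=1e-5)
    checked += 1
```

In the real test module the same loop would wrap the actual transforms, so that both the `torch.cat` and the `np.concatenate` branches are exercised.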

Comment on lines +28 to +31
set_determinism(seed=0)
# Apply random noise with high standard deviation to see the difference
transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0))
out = transform(data)

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Reset determinism after this test to avoid global state bleed.

set_determinism(seed=0) mutates global test state and should be restored.

Suggested patch
     def test_rand_channel_wise(self):
@@
-        set_determinism(seed=0)
-        # Apply random noise with high standard deviation to see the difference
-        transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0))
-        out = transform(data)
+        set_determinism(seed=0)
+        try:
+            # Apply random noise with high standard deviation to see the difference
+            transform = RandChannelWise(transform=RandGaussianNoise(prob=1.0, std=1.0))
+            out = transform(data)
+        finally:
+            set_determinism(seed=None)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_channel_wise.py` around lines 28 - 31, The test sets global
determinism via set_determinism(seed=0) but never restores it, so wrap the
deterministic section (where you instantiate RandChannelWise and apply it to
data) in a try/finally and in the finally call set_determinism(None) (or the API
call that disables determinism) to restore global state; specifically modify the
block around set_determinism(seed=0), transform = RandChannelWise(...), and out
= transform(data) to ensure set_determinism is undone after the test.
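
The try/finally pattern can also be packaged as a context manager. The sketch below is a related but different technique from the suggested patch: it saves and restores the previous state instead of calling `set_determinism(seed=None)`, and it only covers NumPy's global RNG, whereas `set_determinism` is expected to touch more than that. `seeded` is a hypothetical name, not a MONAI utility.

```python
import contextlib

import numpy as np


@contextlib.contextmanager
def seeded(seed):
    """Hypothetical helper: seed NumPy's global RNG, restore the old state on exit."""
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        # Runs even if the body raises, so later tests see the previous state.
        np.random.set_state(state)


before = np.random.get_state()
with seeded(0):
    first = np.random.rand(3)
with seeded(0):
    second = np.random.rand(3)
after = np.random.get_state()
```

Both blocks draw the same numbers, and the global state after the blocks matches the state before them, which is the property the review asks the test to guarantee.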

Comment on lines +10 to +46
def test_channel_wise_deterministic(self):
    # Test applying a deterministic transform channel-wise
    data = {"image": np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]])} # shape (2, 2, 2)

    # ScaleIntensity applies to the whole input array independently
    transform = ChannelWised(keys=["image"], transform=ScaleIntensity())
    out = transform(data)

    # Channel 0 scaled
    np.testing.assert_allclose(out["image"][0], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5)
    # Channel 1 scaled
    np.testing.assert_allclose(out["image"][1], np.array([[0.0, 0.3333333], [0.6666667, 1.0]]), atol=1e-5)
    self.assertEqual(out["image"].shape, data["image"].shape)

def test_rand_channel_wise(self):
    # Test applying a randomized transform channel-wise
    data = {"image": np.zeros((3, 4, 4))}

    set_determinism(seed=0)
    # Apply random noise with high standard deviation to see the difference
    transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0))
    out = transform(data)

    # All channels should have different noise values
    self.assertFalse(np.allclose(out["image"][0], out["image"][1]))
    self.assertFalse(np.allclose(out["image"][1], out["image"][2]))
    self.assertFalse(np.allclose(out["image"][0], out["image"][2]))

    # Output shape should be exactly the same
    self.assertEqual(out["image"].shape, data["image"].shape)

def test_prob_zero(self):
    # Test when RandChannelWised prob is 0.0
    data = {"image": np.zeros((2, 2, 2))}
    transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0), prob=0.0)
    out = transform(data)
    np.testing.assert_allclose(out["image"], data["image"])

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Cover torch input path for ChannelWised and RandChannelWised.

These tests currently validate numpy dictionaries only, leaving tensor behavior unverified.

As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_channel_wised.py` around lines 10 - 46, Add tests that exercise
ChannelWised and RandChannelWised with torch.Tensor inputs (not only numpy
arrays) so the torch code paths are covered: for test_channel_wise_deterministic
create data as torch.tensor with the same values and call
ChannelWised(keys=["image"], transform=ScaleIntensity()) then assert the
per-channel scaled results and shape using torch.allclose; for
test_rand_channel_wise use set_determinism(seed=0) and pass a torch.zeros tensor
into RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0,
std=1.0)) and assert channels differ with torch comparisons and shape equality;
similarly add a torch variant of test_prob_zero using RandChannelWised(...,
prob=0.0) to assert output equals input. Ensure you import torch and use
ChannelWised, RandChannelWised, ScaleIntensity, RandGaussianNoise, and
set_determinism names from the module under test.

Comment on lines +28 to +31
set_determinism(seed=0)
# Apply random noise with high standard deviation to see the difference
transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0))
out = transform(data)

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Reset determinism state in this randomized test.

Global determinism should be restored to prevent inter-test coupling.

Suggested patch
     def test_rand_channel_wise(self):
@@
-        set_determinism(seed=0)
-        # Apply random noise with high standard deviation to see the difference
-        transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0))
-        out = transform(data)
+        set_determinism(seed=0)
+        try:
+            # Apply random noise with high standard deviation to see the difference
+            transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0))
+            out = transform(data)
+        finally:
+            set_determinism(seed=None)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
- set_determinism(seed=0)
- # Apply random noise with high standard deviation to see the difference
- transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0))
- out = transform(data)
+ set_determinism(seed=0)
+ try:
+     # Apply random noise with high standard deviation to see the difference
+     transform = RandChannelWised(keys=["image"], transform=RandGaussianNoise(prob=1.0, std=1.0))
+     out = transform(data)
+ finally:
+     set_determinism(seed=None)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_channel_wised.py` around lines 28 - 31, The test sets global
determinism with set_determinism(seed=0) but never restores it; after running
the randomized transform (the RandChannelWised/RandGaussianNoise call in
tests/test_channel_wised.py where transform(data) is executed) call
set_determinism(None) (or the library's provided reset/unset call) to restore
the global RNG/determinism state so other tests are not affected; place that
call immediately after out = transform(data).

@ugbotueferhire
Contributor Author

Hi @ericspod please can you check this out?

@aymuos15
Contributor

aymuos15 commented Jun 5, 2026

ChannelWise(ScaleIntensity()) duplicates the existing channel_wise=True flag that most intensity transforms already provide (ScaleIntensity, NormalizeIntensity, RandShiftIntensity, etc.).

I think that flag should be the base (it's also more performant than the loop approach here).
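
To make the comparison concrete, here is a minimal NumPy sketch of the two approaches for min-max scaling. It is not MONAI's implementation of `channel_wise=True`, and both function names are hypothetical; it only shows that reducing over the non-channel axes in one call gives the same result as the per-channel loop.

```python
import numpy as np


def scale_loop(img):
    """Per-channel min-max scaling with an explicit Python loop."""
    out = []
    for i in range(img.shape[0]):
        ch = img[[i], ...]
        out.append((ch - ch.min()) / (ch.max() - ch.min()))
    return np.concatenate(out, axis=0)


def scale_vectorized(img):
    """Same scaling, reducing over all non-channel axes in a single call."""
    axes = tuple(range(1, img.ndim))
    lo = img.min(axis=axes, keepdims=True)
    hi = img.max(axis=axes, keepdims=True)
    return (img - lo) / (hi - lo)


data = np.array([[[1.0, 2.0], [3.0, 4.0]], [[10.0, 20.0], [30.0, 40.0]]])
same = np.allclose(scale_loop(data), scale_vectorized(data))
```

The vectorized form avoids one Python-level call and one temporary array per channel; no timings were measured here, so any speed difference would depend on channel count and array size.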

Development

Successfully merging this pull request may close these issues.

Implementing Channel-Wise Transforms

3 participants