
Fix batch size broadcasting bug in GeneralizedWassersteinDiceLoss#8744

Open
hongjie-qiu wants to merge 1 commit into Project-MONAI:dev from hongjie-qiu:4650-fix-gwdl-batch-size

Conversation


@hongjie-qiu hongjie-qiu commented Feb 19, 2026

Fixes #4650

Description

When batch_size > 1, GeneralizedWassersteinDiceLoss produces incorrect loss values because of a tensor broadcasting issue in _compute_generalized_true_positive and _compute_denominator.

After torch.gather, alpha_extended has shape (B, 1, S) while wasserstein_distance_map has shape (B, S). The element-wise multiply silently broadcasts to (B, B, S), which mixes values across batch samples. This means the loss has always been wrong for any training run with batch_size > 1.
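To make the broadcast concrete, here is a minimal sketch with random tensors of the shapes described above (illustrative only, not the MONAI code itself):

import torch

B, S = 2, 16                                  # batch size, number of flattened voxels
alpha_extended = torch.rand(B, 1, S)          # shape after torch.gather: (B, 1, S)
wasserstein_distance_map = torch.rand(B, S)   # shape (B, S)

product = alpha_extended * wasserstein_distance_map
print(product.shape)  # torch.Size([2, 2, 16]): broadcast to (B, B, S), mixing samples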

The fix follows the reference implementation by the original paper's author — squeeze dim=1 after the gather so both tensors are (B, S), and reduce with dim=1 instead of dim=[1, 2].
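For reference, a self-contained sketch of the corrected computation (variable names and shapes are illustrative, not the exact dice.py source):

import torch

B, C, S = 2, 3, 16
alpha = torch.rand(B, C)                      # per-class weights
flat_target = torch.randint(0, C, (B, S))     # flattened target labels
wasserstein_distance_map = torch.rand(B, S)

alpha_extended = alpha.unsqueeze(2).expand(B, C, S)                                    # (B, C, S)
alpha_extended = torch.gather(alpha_extended, dim=1, index=flat_target.unsqueeze(1))   # (B, 1, S)
alpha_extended = alpha_extended.squeeze(1)                                             # (B, S), the added squeeze
true_pos = torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=1)         # reduce over dim=1: (B,)
print(true_pos.shape)  # torch.Size([2])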

I also noticed that reduction="none" was broken (never had test coverage) — it tried to reshape the per-sample loss (B,) into (B, C, 1, ...), but GWDL aggregates over classes internally so the class dimension doesn't exist in the output. Fixed that as well.
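With the fix, reduction="none" yields one loss value per batch element. A minimal usage sketch with random inputs (shapes are illustrative):

import numpy as np
import torch
from monai.losses import GeneralizedWassersteinDiceLoss

loss_fn = GeneralizedWassersteinDiceLoss(dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), reduction="none")
pred = torch.rand(4, 2, 8, 8)               # (B, C, H, W) logits
target = torch.randint(0, 2, (4, 8, 8))     # (B, H, W) integer labels
print(loss_fn(pred, target).shape)          # with this fix: torch.Size([4])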

Changes

  • monai/losses/dice.py: squeeze + dim fix in _compute_generalized_true_positive and _compute_denominator; fixed reduction="none" path
  • tests/losses/test_generalized_wasserstein_dice_loss.py: two new regression tests for batch consistency

Tests

All existing tests pass. The new regression tests fail on unpatched code and pass with the fix.


coderabbitai bot commented Feb 19, 2026

📝 Walkthrough

The PR fixes batch-size handling in GeneralizedWassersteinDiceLoss by keeping the forward output for reduction="none" as shape (B,) instead of broadcasting per-voxel losses. Internal helpers now squeeze the class dimension after gathering alpha_extended, changing reductions from dims [1, 2] to dim=1 so per-batch results are preserved. Two regression tests were added to verify consistency between batch and single-sample losses for identical and distinct samples across weighting modes.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks: ✅ 5 passed
  • Title check (✅ Passed): Title concisely describes the main fix: correcting a batch size broadcasting bug in GeneralizedWassersteinDiceLoss, directly matching the changeset.
  • Description check (✅ Passed): Description covers all required sections: linked issue reference, detailed explanation of the bug, fix details, changed files, and test information.
  • Linked Issues check (✅ Passed): All requirements from issue #4650 are met: batch-size broadcasting bug fixed in _compute_generalized_true_positive and _compute_denominator, reduction="none" fixed, regression tests added.
  • Out of Scope Changes check (✅ Passed): All changes are scoped to the linked issue: fixes are localized to the GeneralizedWassersteinDiceLoss loss function and its tests, no unrelated modifications detected.
  • Docstring Coverage (✅ Passed): Docstring coverage is 87.50%, which is sufficient; the required threshold is 80.00%.




@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (3)
monai/losses/dice.py (2)

509-514: Document the reduction="none" output shape in the docstring.

The forward docstring has no Returns: section. Since this PR changes the reduction="none" output shape from a broken reshape attempt to a well-defined (B,), callers need to know what to expect.

📝 Proposed docstring update
 def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
     """
     Args:
         input: the shape should be BNH[WD].
         target: the shape should be BNH[WD].

+    Returns:
+        Scalar when ``reduction`` is ``"mean"`` or ``"sum"``.
+        Tensor of shape ``(B,)`` when ``reduction`` is ``"none"``, one loss value
+        per sample (GWDL aggregates over classes and spatial dims internally).
     """

As per coding guidelines, Google-style docstrings should describe each return value in a Returns: section.

Also applies to: 550-553

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/dice.py` around lines 509 - 514, Update the forward docstring
for the Dice loss (method forward) to add a Google-style "Returns:" section that
clearly states the tensor shape and meaning for each reduction mode;
specifically document that when reduction="none" the output is a 1-D tensor of
shape (B,) containing per-batch loss values (and describe shapes for
"mean"/"sum" if applicable), and apply the same docstring clarification to the
other forward variant referenced around lines 550-553 so callers know to expect
a (B,) output instead of the previous reshape behavior.

596-630: Both helpers share identical alpha-mapping code — extract a private helper.

The five-line alpha-extension/gather/squeeze block is duplicated verbatim in _compute_generalized_true_positive and _compute_denominator. A private _map_alpha_to_voxels would remove the duplication and make future fixes a one-place change (as this PR illustrates — the squeeze had to be added in both places).

Additionally, both methods are missing Returns: sections in their docstrings. As per coding guidelines, Google-style docstrings must describe return values.

♻️ Proposed refactor
+    def _map_alpha_to_voxels(self, alpha: torch.Tensor, flat_target: torch.Tensor) -> torch.Tensor:
+        """Map per-class alpha weights to a per-voxel tensor via flat_target.
+
+        Args:
+            alpha: per-class weights of shape (B, C).
+            flat_target: flattened target labels of shape (B, S).
+
+        Returns:
+            Per-voxel alpha values of shape (B, S).
+        """
+        alpha_extended = torch.unsqueeze(alpha, dim=2)
+        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
+        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
+        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
+        return torch.squeeze(alpha_extended, dim=1)

     def _compute_generalized_true_positive(
         self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
             alpha: generalised number of true positives of target class.
             flat_target: the target tensor.
             wasserstein_distance_map: the map obtained from the above function.
+
+        Returns:
+            Per-sample generalised true positives of shape (B,).
         """
-        # Extend alpha to a map and select value at each voxel according to flat_target
-        alpha_extended = torch.unsqueeze(alpha, dim=2)
-        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
-        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
-        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
-        alpha_extended = torch.squeeze(alpha_extended, dim=1)
-        return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=1)
+        alpha_per_voxel = self._map_alpha_to_voxels(alpha, flat_target)
+        return torch.sum(alpha_per_voxel * (1.0 - wasserstein_distance_map), dim=1)

     def _compute_denominator(
         self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor
     ) -> torch.Tensor:
         """
         Args:
             alpha: generalised number of true positives of target class.
             flat_target: the target tensor.
             wasserstein_distance_map: the map obtained from the above function.
+
+        Returns:
+            Per-sample denominator of shape (B,).
         """
-        # Extend alpha to a map and select value at each voxel according to flat_target
-        alpha_extended = torch.unsqueeze(alpha, dim=2)
-        alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1)))
-        flat_target_extended = torch.unsqueeze(flat_target, dim=1)
-        alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1)
-        alpha_extended = torch.squeeze(alpha_extended, dim=1)
-        return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=1)
+        alpha_per_voxel = self._map_alpha_to_voxels(alpha, flat_target)
+        return torch.sum(alpha_per_voxel * (2.0 - wasserstein_distance_map), dim=1)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@monai/losses/dice.py` around lines 596 - 630, Both methods
_compute_generalized_true_positive and _compute_denominator duplicate the
alpha-extension/gather/squeeze block; extract that logic into a private helper
(e.g., _map_alpha_to_voxels(self, alpha: torch.Tensor, flat_target:
torch.Tensor) -> torch.Tensor) that returns the per-voxel alpha_extended tensor,
update both methods to call this helper and use its result in their sums, and
remove the duplicated code; also add a Returns: section to the docstrings of
_compute_generalized_true_positive and _compute_denominator describing the
returned tensor shape and meaning.
tests/losses/test_generalized_wasserstein_dice_loss.py (1)

293-295: float() on a reduction="none" output is fragile for future readers.

loss_fn(pred_a, target_a) returns shape (1,) here; float() only works because batch size is 1. Prefer .item() or index explicitly to make intent clear.

🔧 Suggested clarification
-            loss_a = float(loss_fn(pred_a, target_a))
-            loss_b = float(loss_fn(pred_b, target_b))
+            loss_a = loss_fn(pred_a, target_a)[0].item()
+            loss_b = loss_fn(pred_b, target_b)[0].item()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/losses/test_generalized_wasserstein_dice_loss.py` around lines 293 -
295, The tests call float(loss_fn(pred_a, target_a)) and float(loss_fn(pred_b,
target_b)) where loss_fn is configured with reduction="none" and returns a
1-element tensor; using float() is fragile and hides the intent. Replace
float(...) with .item() (e.g., loss_fn(pred_a, target_a).item()) or explicitly
index [0] to extract the scalar so the code clearly indicates you're converting
a single-element tensor to a Python float; update both occurrences referencing
loss_fn, pred_a, target_a, pred_b, and target_b.

Comment on lines 254 to 270
        # Also test with mean reduction: batch loss should equal single-sample loss
        for w_mode in ["default", "GDL"]:
            loss_fn = GeneralizedWassersteinDiceLoss(
                dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]),
                weighting_mode=w_mode,
                reduction="mean",
            )

            loss_single = float(loss_fn(pred_single, target_single))
            loss_batch = float(loss_fn(pred_batch, target_batch))

            self.assertAlmostEqual(
                loss_batch,
                loss_single,
                places=5,
                msg=f"Batch mean loss != single mean loss for weighting_mode={w_mode}",
            )

⚠️ Potential issue | 🟡 Minor

Mean-reduction sub-test is trivially weak.

pred_single = pred_very_good → loss ≈ 0 for both single and batch inputs. The assertion 0.0 ≈ 0.0 passes even if mean-reduction across a batch is broken. Use non-trivial predictions (e.g., pred_very_poor for one sample) so the expected mean is a non-zero, verifiable value.
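For illustration, one way to do that, reusing the test's existing fixture names (pred_single and target_single are assumed from the surrounding test, with target_single of shape (1, H, W)):

# Hypothetical sketch: make the single sample deliberately wrong so its loss is ~1.0,
# then replicate it into the batch; the mean over the batch should still equal the single loss.
pred_single_poor = 1000 * F.one_hot(1 - target_single, num_classes=2).permute(0, 3, 1, 2).float()
pred_batch_poor = pred_single_poor.repeat(2, 1, 1, 1)
target_batch_poor = target_single.repeat(2, 1, 1)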

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/losses/test_generalized_wasserstein_dice_loss.py` around lines 254 -
270, The mean-reduction subtest uses trivially-zero losses (pred_single from
pred_very_good), so it won't catch a broken mean reduction; update the test to
use a non-trivial prediction for at least one sample: when constructing
pred_single/pred_batch for the GeneralizedWassersteinDiceLoss checks, replace
one of the batch entries (or pred_single) with a clearly poor prediction (e.g.,
pred_very_poor) so loss_single and loss_batch produce a non-zero,
different-per-sample value; keep using loss_fn =
GeneralizedWassersteinDiceLoss(..., weighting_mode=w_mode, reduction="mean") and
then assert loss_batch equals loss_single to verify mean reduction behavior.

…oject-MONAI#4650)

After `torch.gather`, `alpha_extended` retains shape (B, 1, S) while
`wasserstein_distance_map` has shape (B, S). When batch size > 1 the
element-wise multiply broadcasts to (B, B, S), mixing values across
samples. Fixed by squeezing dim=1 after gather in both
`_compute_generalized_true_positive` and `_compute_denominator`, and
reducing with `dim=1` instead of `dim=[1, 2]`.

Also fixed the `reduction="none"` code path which incorrectly tried to
reshape the per-sample loss tensor (B,) to (B, C, 1, ...) — GWDL
aggregates over classes internally so the class dimension doesn't apply.

Added regression tests that verify batch consistency:
- identical samples in a batch produce the same loss as a single sample
- batched per-sample losses match individually computed losses

Signed-off-by: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com>
@hongjie-qiu hongjie-qiu force-pushed the 4650-fix-gwdl-batch-size branch from 063df92 to 4887d9d on February 19, 2026 at 17:57

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 268-297: The test test_batch_size_different_samples is currently
trivial because both pred_a and pred_b are perfect one-hot predictions; change
pred_b to produce a poor prediction (e.g., build pred_b from 1 - target_b before
one_hot) so that loss_b is ≈1.0 and the batch variant checks are meaningful;
keep pred_a as the perfect 1000*F.one_hot(target_a, ...) and ensure pred_batch =
torch.cat([pred_a, pred_b], dim=0) remains consistent so loss_a, loss_b, and
loss_batch indices compare correctly inside the weighting_mode loop that calls
GeneralizedWassersteinDiceLoss and asserts loss_batch[0]==loss_a and
loss_batch[1]==loss_b.

---

Duplicate comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 252-266: The mean-reduction sub-test is weak because
pred_single/pred_batch are perfect (loss ≈ 0), so it doesn't catch
mean-reduction bugs; update the test to use non-perfect predictions so
per-sample losses are non-zero and comparable: for
GeneralizedWassersteinDiceLoss with weighting_mode set from ["default","GDL"]
and reduction="mean", construct pred_single/pred_batch (or modify pred_single to
contain a small deliberate error and replicate into pred_batch as the batch of
samples) and/or use multiple distinct samples in pred_batch, compute loss_single
= float(loss_fn(pred_single, target_single)) and loss_batch =
float(loss_fn(pred_batch, target_batch)), then assert loss_batch ≈ loss_single
to verify the mean aggregation is correct; reference
GeneralizedWassersteinDiceLoss, pred_single, pred_batch, target_single,
target_batch, weighting_mode, and reduction="mean" when making the change.

Comment on lines +268 to +297
    def test_batch_size_different_samples(self):
        """
        Regression test for https://github.com/Project-MONAI/MONAI/issues/4650
        Verify loss is computed correctly when batch contains different samples.
        """
        target_a = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]).unsqueeze(0)
        target_b = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unsqueeze(0)

        pred_a = 1000 * F.one_hot(target_a, num_classes=2).permute(0, 3, 1, 2).float()
        pred_b = 1000 * F.one_hot(target_b, num_classes=2).permute(0, 3, 1, 2).float()

        # Combine into a batch
        target_batch = torch.cat([target_a, target_b], dim=0)
        pred_batch = torch.cat([pred_a, pred_b], dim=0)

        for w_mode in ["default", "GDL"]:
            loss_fn = GeneralizedWassersteinDiceLoss(
                dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none"
            )

            loss_a = float(loss_fn(pred_a, target_a))
            loss_b = float(loss_fn(pred_b, target_b))
            loss_batch = loss_fn(pred_batch, target_batch)

            self.assertAlmostEqual(
                float(loss_batch[0]), loss_a, places=5, msg=f"Batch loss[0] != loss_a for weighting_mode={w_mode}"
            )
            self.assertAlmostEqual(
                float(loss_batch[1]), loss_b, places=5, msg=f"Batch loss[1] != loss_b for weighting_mode={w_mode}"
            )

⚠️ Potential issue | 🟡 Minor

Both samples use perfect predictions — regression test is trivially weak.

pred_a = 1000 * F.one_hot(target_a, ...) and pred_b = 1000 * F.one_hot(target_b, ...) both produce a Wasserstein distance map ≈ 0 everywhere (identity cost matrix, correct class predicted). Every assertion reduces to 0 ≈ 0 and would pass on the unpatched code, defeating the stated regression purpose.

Use at least one poor prediction (e.g., pred_b = 1000 * F.one_hot(1 - target_b, ...)) so loss_b ≈ 1.0 and loss_batch[1] ≈ 1.0 is a non-trivial check.

🛠️ Suggested fix
-pred_b = 1000 * F.one_hot(target_b, num_classes=2).permute(0, 3, 1, 2).float()
+pred_b = 1000 * F.one_hot(1 - target_b, num_classes=2).permute(0, 3, 1, 2).float()

Then loss_b will be ≈ 1.0, giving a meaningful per-sample regression check.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/losses/test_generalized_wasserstein_dice_loss.py` around lines 268 -
297, The test test_batch_size_different_samples is currently trivial because
both pred_a and pred_b are perfect one-hot predictions; change pred_b to produce
a poor prediction (e.g., build pred_b from 1 - target_b before one_hot) so that
loss_b is ≈1.0 and the batch variant checks are meaningful; keep pred_a as the
perfect 1000*F.one_hot(target_a, ...) and ensure pred_batch = torch.cat([pred_a,
pred_b], dim=0) remains consistent so loss_a, loss_b, and loss_batch indices
compare correctly inside the weighting_mode loop that calls
GeneralizedWassersteinDiceLoss and asserts loss_batch[0]==loss_a and
loss_batch[1]==loss_b.



Development

Successfully merging this pull request may close these issues.

Bug in the generalized Wasserstein Dice loss

1 participant
