Fix batch size broadcasting bug in GeneralizedWassersteinDiceLoss#8744
Fix batch size broadcasting bug in GeneralizedWassersteinDiceLoss#8744hongjie-qiu wants to merge 1 commit intoProject-MONAI:devfrom
Conversation
📝 WalkthroughWalkthroughThe PR fixes batch-size handling in GeneralizedWassersteinDiceLoss by keeping the forward output for Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (3)
monai/losses/dice.py (2)
509-514: Document thereduction="none"output shape in the docstring.The
forwarddocstring has noReturns:section. Since this PR changes thereduction="none"output shape from a broken reshape attempt to a well-defined(B,), callers need to know what to expect.📝 Proposed docstring update
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: input: the shape should be BNH[WD]. target: the shape should be BNH[WD]. + Returns: + Scalar when ``reduction`` is ``"mean"`` or ``"sum"``. + Tensor of shape ``(B,)`` when ``reduction`` is ``"none"``, one loss value + per sample (GWDL aggregates over classes and spatial dims internally). """As per coding guidelines, Google-style docstrings should describe each return value in a
Returns:section.Also applies to: 550-553
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/losses/dice.py` around lines 509 - 514, Update the forward docstring for the Dice loss (method forward) to add a Google-style "Returns:" section that clearly states the tensor shape and meaning for each reduction mode; specifically document that when reduction="none" the output is a 1-D tensor of shape (B,) containing per-batch loss values (and describe shapes for "mean"/"sum" if applicable), and apply the same docstring clarification to the other forward variant referenced around lines 550-553 so callers know to expect a (B,) output instead of the previous reshape behavior.
596-630: Both helpers share identical alpha-mapping code — extract a private helper.The five-line alpha-extension/gather/squeeze block is duplicated verbatim in
_compute_generalized_true_positiveand_compute_denominator. A private_map_alpha_to_voxelswould remove the duplication and make future fixes a one-place change (as this PR illustrates — the squeeze had to be added in both places).Additionally, both methods are missing
Returns:sections in their docstrings. As per coding guidelines, Google-style docstrings must describe return values.♻️ Proposed refactor
+ def _map_alpha_to_voxels(self, alpha: torch.Tensor, flat_target: torch.Tensor) -> torch.Tensor: + """Map per-class alpha weights to a per-voxel tensor via flat_target. + + Args: + alpha: per-class weights of shape (B, C). + flat_target: flattened target labels of shape (B, S). + + Returns: + Per-voxel alpha values of shape (B, S). + """ + alpha_extended = torch.unsqueeze(alpha, dim=2) + alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1))) + flat_target_extended = torch.unsqueeze(flat_target, dim=1) + alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) + return torch.squeeze(alpha_extended, dim=1) def _compute_generalized_true_positive( self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor ) -> torch.Tensor: """ Args: alpha: generalised number of true positives of target class. flat_target: the target tensor. wasserstein_distance_map: the map obtained from the above function. + + Returns: + Per-sample generalised true positives of shape (B,). """ - # Extend alpha to a map and select value at each voxel according to flat_target - alpha_extended = torch.unsqueeze(alpha, dim=2) - alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1))) - flat_target_extended = torch.unsqueeze(flat_target, dim=1) - alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) - alpha_extended = torch.squeeze(alpha_extended, dim=1) - return torch.sum(alpha_extended * (1.0 - wasserstein_distance_map), dim=1) + alpha_per_voxel = self._map_alpha_to_voxels(alpha, flat_target) + return torch.sum(alpha_per_voxel * (1.0 - wasserstein_distance_map), dim=1) def _compute_denominator( self, alpha: torch.Tensor, flat_target: torch.Tensor, wasserstein_distance_map: torch.Tensor ) -> torch.Tensor: """ Args: alpha: generalised number of true positives of target class. flat_target: the target tensor. wasserstein_distance_map: the map obtained from the above function. + + Returns: + Per-sample denominator of shape (B,). """ - # Extend alpha to a map and select value at each voxel according to flat_target - alpha_extended = torch.unsqueeze(alpha, dim=2) - alpha_extended = alpha_extended.expand((flat_target.size(0), self.num_classes, flat_target.size(1))) - flat_target_extended = torch.unsqueeze(flat_target, dim=1) - alpha_extended = torch.gather(alpha_extended, index=flat_target_extended, dim=1) - alpha_extended = torch.squeeze(alpha_extended, dim=1) - return torch.sum(alpha_extended * (2.0 - wasserstein_distance_map), dim=1) + alpha_per_voxel = self._map_alpha_to_voxels(alpha, flat_target) + return torch.sum(alpha_per_voxel * (2.0 - wasserstein_distance_map), dim=1)🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/losses/dice.py` around lines 596 - 630, Both methods _compute_generalized_true_positive and _compute_denominator duplicate the alpha-extension/gather/squeeze block; extract that logic into a private helper (e.g., _map_alpha_to_voxels(self, alpha: torch.Tensor, flat_target: torch.Tensor) -> torch.Tensor) that returns the per-voxel alpha_extended tensor, update both methods to call this helper and use its result in their sums, and remove the duplicated code; also add a Returns: section to the docstrings of _compute_generalized_true_positive and _compute_denominator describing the returned tensor shape and meaning.tests/losses/test_generalized_wasserstein_dice_loss.py (1)
293-295:float()on areduction="none"output is fragile for future readers.
loss_fn(pred_a, target_a)returns shape(1,)here;float()only works because batch size is 1. Prefer.item()or index explicitly to make intent clear.🔧 Suggested clarification
- loss_a = float(loss_fn(pred_a, target_a)) - loss_b = float(loss_fn(pred_b, target_b)) + loss_a = loss_fn(pred_a, target_a)[0].item() + loss_b = loss_fn(pred_b, target_b)[0].item()🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/losses/test_generalized_wasserstein_dice_loss.py` around lines 293 - 295, The tests call float(loss_fn(pred_a, target_a)) and float(loss_fn(pred_b, target_b)) where loss_fn is configured with reduction="none" and returns a 1-element tensor; using float() is fragile and hides the intent. Replace float(...) with .item() (e.g., loss_fn(pred_a, target_a).item()) or explicitly index [0] to extract the scalar so the code clearly indicates you're converting a single-element tensor to a Python float; update both occurrences referencing loss_fn, pred_a, target_a, pred_b, and target_b.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 254-270: The mean-reduction subtest uses trivially-zero losses
(pred_single from pred_very_good), so it won't catch a broken mean reduction;
update the test to use a non-trivial prediction for at least one sample: when
constructing pred_single/pred_batch for the GeneralizedWassersteinDiceLoss
checks, replace one of the batch entries (or pred_single) with a clearly poor
prediction (e.g., pred_very_poor) so loss_single and loss_batch produce a
non-zero, different-per-sample value; keep using loss_fn =
GeneralizedWassersteinDiceLoss(..., weighting_mode=w_mode, reduction="mean") and
then assert loss_batch equals loss_single to verify mean reduction behavior.
---
Nitpick comments:
In `@monai/losses/dice.py`:
- Around line 509-514: Update the forward docstring for the Dice loss (method
forward) to add a Google-style "Returns:" section that clearly states the tensor
shape and meaning for each reduction mode; specifically document that when
reduction="none" the output is a 1-D tensor of shape (B,) containing per-batch
loss values (and describe shapes for "mean"/"sum" if applicable), and apply the
same docstring clarification to the other forward variant referenced around
lines 550-553 so callers know to expect a (B,) output instead of the previous
reshape behavior.
- Around line 596-630: Both methods _compute_generalized_true_positive and
_compute_denominator duplicate the alpha-extension/gather/squeeze block; extract
that logic into a private helper (e.g., _map_alpha_to_voxels(self, alpha:
torch.Tensor, flat_target: torch.Tensor) -> torch.Tensor) that returns the
per-voxel alpha_extended tensor, update both methods to call this helper and use
its result in their sums, and remove the duplicated code; also add a Returns:
section to the docstrings of _compute_generalized_true_positive and
_compute_denominator describing the returned tensor shape and meaning.
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 293-295: The tests call float(loss_fn(pred_a, target_a)) and
float(loss_fn(pred_b, target_b)) where loss_fn is configured with
reduction="none" and returns a 1-element tensor; using float() is fragile and
hides the intent. Replace float(...) with .item() (e.g., loss_fn(pred_a,
target_a).item()) or explicitly index [0] to extract the scalar so the code
clearly indicates you're converting a single-element tensor to a Python float;
update both occurrences referencing loss_fn, pred_a, target_a, pred_b, and
target_b.
| # Also test with mean reduction: batch loss should equal single-sample loss | ||
| for w_mode in ["default", "GDL"]: | ||
| loss_fn = GeneralizedWassersteinDiceLoss( | ||
| dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), | ||
| weighting_mode=w_mode, | ||
| reduction="mean", | ||
| ) | ||
|
|
||
| loss_single = float(loss_fn(pred_single, target_single)) | ||
| loss_batch = float(loss_fn(pred_batch, target_batch)) | ||
|
|
||
| self.assertAlmostEqual( | ||
| loss_batch, | ||
| loss_single, | ||
| places=5, | ||
| msg=f"Batch mean loss != single mean loss for weighting_mode={w_mode}", | ||
| ) |
There was a problem hiding this comment.
Mean-reduction sub-test is trivially weak.
pred_single = pred_very_good → loss ≈ 0 for both single and batch inputs. The assertion 0.0 ≈ 0.0 passes even if mean-reduction across a batch is broken. Use non-trivial predictions (e.g., pred_very_poor for one sample) so the expected mean is a non-zero, verifiable value.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/losses/test_generalized_wasserstein_dice_loss.py` around lines 254 -
270, The mean-reduction subtest uses trivially-zero losses (pred_single from
pred_very_good), so it won't catch a broken mean reduction; update the test to
use a non-trivial prediction for at least one sample: when constructing
pred_single/pred_batch for the GeneralizedWassersteinDiceLoss checks, replace
one of the batch entries (or pred_single) with a clearly poor prediction (e.g.,
pred_very_poor) so loss_single and loss_batch produce a non-zero,
different-per-sample value; keep using loss_fn =
GeneralizedWassersteinDiceLoss(..., weighting_mode=w_mode, reduction="mean") and
then assert loss_batch equals loss_single to verify mean reduction behavior.
…oject-MONAI#4650) After `torch.gather`, `alpha_extended` retains shape (B, 1, S) while `wasserstein_distance_map` has shape (B, S). When batch size > 1 the element-wise multiply broadcasts to (B, B, S), mixing values across samples. Fixed by squeezing dim=1 after gather in both `_compute_generalized_true_positive` and `_compute_denominator`, and reducing with `dim=1` instead of `dim=[1, 2]`. Also fixed the `reduction="none"` code path which incorrectly tried to reshape the per-sample loss tensor (B,) to (B, C, 1, ...) — GWDL aggregates over classes internally so the class dimension doesn't apply. Added regression tests that verify batch consistency: - identical samples in a batch produce the same loss as a single sample - batched per-sample losses match individually computed losses Signed-off-by: hongjie-qiu <77599736+hongjie-qiu@users.noreply.github.com>
063df92 to
4887d9d
Compare
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 268-297: The test test_batch_size_different_samples is currently
trivial because both pred_a and pred_b are perfect one-hot predictions; change
pred_b to produce a poor prediction (e.g., build pred_b from 1 - target_b before
one_hot) so that loss_b is ≈1.0 and the batch variant checks are meaningful;
keep pred_a as the perfect 1000*F.one_hot(target_a, ...) and ensure pred_batch =
torch.cat([pred_a, pred_b], dim=0) remains consistent so loss_a, loss_b, and
loss_batch indices compare correctly inside the weighting_mode loop that calls
GeneralizedWassersteinDiceLoss and asserts loss_batch[0]==loss_a and
loss_batch[1]==loss_b.
---
Duplicate comments:
In `@tests/losses/test_generalized_wasserstein_dice_loss.py`:
- Around line 252-266: The mean-reduction sub-test is weak because
pred_single/pred_batch are perfect (loss ≈ 0), so it doesn't catch
mean-reduction bugs; update the test to use non-perfect predictions so
per-sample losses are non-zero and comparable: for
GeneralizedWassersteinDiceLoss with weighting_mode set from ["default","GDL"]
and reduction="mean", construct pred_single/pred_batch (or modify pred_single to
contain a small deliberate error and replicate into pred_batch as the batch of
samples) and/or use multiple distinct samples in pred_batch, compute loss_single
= float(loss_fn(pred_single, target_single)) and loss_batch =
float(loss_fn(pred_batch, target_batch)), then assert loss_batch ≈ loss_single
to verify the mean aggregation is correct; reference
GeneralizedWassersteinDiceLoss, pred_single, pred_batch, target_single,
target_batch, weighting_mode, and reduction="mean" when making the change.
| def test_batch_size_different_samples(self): | ||
| """ | ||
| Regression test for https://github.com/Project-MONAI/MONAI/issues/4650 | ||
| Verify loss is computed correctly when batch contains different samples. | ||
| """ | ||
| target_a = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]).unsqueeze(0) | ||
| target_b = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]).unsqueeze(0) | ||
|
|
||
| pred_a = 1000 * F.one_hot(target_a, num_classes=2).permute(0, 3, 1, 2).float() | ||
| pred_b = 1000 * F.one_hot(target_b, num_classes=2).permute(0, 3, 1, 2).float() | ||
|
|
||
| # Combine into a batch | ||
| target_batch = torch.cat([target_a, target_b], dim=0) | ||
| pred_batch = torch.cat([pred_a, pred_b], dim=0) | ||
|
|
||
| for w_mode in ["default", "GDL"]: | ||
| loss_fn = GeneralizedWassersteinDiceLoss( | ||
| dist_matrix=np.array([[0.0, 1.0], [1.0, 0.0]]), weighting_mode=w_mode, reduction="none" | ||
| ) | ||
|
|
||
| loss_a = float(loss_fn(pred_a, target_a)) | ||
| loss_b = float(loss_fn(pred_b, target_b)) | ||
| loss_batch = loss_fn(pred_batch, target_batch) | ||
|
|
||
| self.assertAlmostEqual( | ||
| float(loss_batch[0]), loss_a, places=5, msg=f"Batch loss[0] != loss_a for weighting_mode={w_mode}" | ||
| ) | ||
| self.assertAlmostEqual( | ||
| float(loss_batch[1]), loss_b, places=5, msg=f"Batch loss[1] != loss_b for weighting_mode={w_mode}" | ||
| ) |
There was a problem hiding this comment.
Both samples use perfect predictions — regression test is trivially weak.
pred_a = 1000 * F.one_hot(target_a, ...) and pred_b = 1000 * F.one_hot(target_b, ...) both produce a Wasserstein distance map ≈ 0 everywhere (identity cost matrix, correct class predicted). Every assertion reduces to 0 ≈ 0 and would pass on the unpatched code, defeating the stated regression purpose.
Use at least one poor prediction (e.g., pred_b = 1000 * F.one_hot(1 - target_b, ...)) so loss_b ≈ 1.0 and loss_batch[1] ≈ 1.0 is a non-trivial check.
🛠️ Suggested fix
-pred_b = 1000 * F.one_hot(target_b, num_classes=2).permute(0, 3, 1, 2).float()
+pred_b = 1000 * F.one_hot(1 - target_b, num_classes=2).permute(0, 3, 1, 2).float()Then loss_b will be ≈ 1.0, giving a meaningful per-sample regression check.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/losses/test_generalized_wasserstein_dice_loss.py` around lines 268 -
297, The test test_batch_size_different_samples is currently trivial because
both pred_a and pred_b are perfect one-hot predictions; change pred_b to produce
a poor prediction (e.g., build pred_b from 1 - target_b before one_hot) so that
loss_b is ≈1.0 and the batch variant checks are meaningful; keep pred_a as the
perfect 1000*F.one_hot(target_a, ...) and ensure pred_batch = torch.cat([pred_a,
pred_b], dim=0) remains consistent so loss_a, loss_b, and loss_batch indices
compare correctly inside the weighting_mode loop that calls
GeneralizedWassersteinDiceLoss and asserts loss_batch[0]==loss_a and
loss_batch[1]==loss_b.
Fixes #4650
Description
When
batch_size > 1,GeneralizedWassersteinDiceLossproduces incorrect loss values because of a tensor broadcasting issue in_compute_generalized_true_positiveand_compute_denominator.After
torch.gather,alpha_extendedhas shape(B, 1, S)whilewasserstein_distance_maphas shape(B, S). The element-wise multiply silently broadcasts to(B, B, S), which mixes values across batch samples. This means the loss has always been wrong for any training run withbatch_size > 1.The fix follows the reference implementation by the original paper's author — squeeze
dim=1after the gather so both tensors are(B, S), and reduce withdim=1instead ofdim=[1, 2].I also noticed that
reduction="none"was broken (never had test coverage) — it tried to reshape the per-sample loss(B,)into(B, C, 1, ...), but GWDL aggregates over classes internally so the class dimension doesn't exist in the output. Fixed that as well.Changes
monai/losses/dice.py: squeeze + dim fix in_compute_generalized_true_positiveand_compute_denominator; fixedreduction="none"pathtests/losses/test_generalized_wasserstein_dice_loss.py: two new regression tests for batch consistencyTests
All existing tests pass. The new regression tests fail on unpatched code and pass with the fix.