Skip to content

Add sigmoid/softmax activation to AsymmetricUnifiedFocalLoss#8863

Open
AlexanderSanin wants to merge 5 commits into
Project-MONAI:devfrom
AlexanderSanin:feat/unified-focal-loss-activation-8603
Open

Add sigmoid/softmax activation to AsymmetricUnifiedFocalLoss#8863
AlexanderSanin wants to merge 5 commits into
Project-MONAI:devfrom
AlexanderSanin:feat/unified-focal-loss-activation-8603

Conversation

@AlexanderSanin
Copy link
Copy Markdown
Contributor

Summary

Fixes #8603

Adds `use_softmax` and `use_sigmoid` parameters to `AsymmetricUnifiedFocalLoss`, following the same pattern as `FocalLoss`. This allows users to pass raw logits directly without manually applying activations beforehand.

  • `use_softmax=True`: applies softmax along channel dim (for multi-class)
  • `use_sigmoid=True`: applies sigmoid (for binary)
  • Both `False` (default): input assumed to be probabilities — fully backward compatible
  • Mutually exclusive validation with clear error message
  • Removed stale TODO comment and added missing docstrings for `reduction` parameter

Test plan

  • Existing tests pass unchanged (backward compatible defaults)
  • New `test_use_sigmoid`: passes logits with sigmoid activation
  • New `test_use_softmax`: passes logits with softmax activation
  • New `test_mutually_exclusive`: validates that setting both raises `ValueError`

Signed-off-by: Oleksandr Sanin alexaaander.sanin@gmail.com

)

Catching BaseException inadvertently suppresses KeyboardInterrupt,
SystemExit, and GeneratorExit, which should nearly always propagate.
All 17 occurrences across monai/ and tests/ are replaced with
Exception, which is the appropriate base class for catchable errors.

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
…-MONAI#8603)

Add use_softmax and use_sigmoid parameters so users can pass raw
logits directly. When both are False (default), the input is assumed
to already be probabilities, preserving backward compatibility.

Also removes the stale TODO comment about multi-class support and
adds proper docstrings for the reduction parameter.

Signed-off-by: Oleksandr Sanin <alexaaander.sanin@gmail.com>
@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented May 18, 2026

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 1117667e-e2b4-4414-b4a9-c41d8d168dd5

📥 Commits

Reviewing files that changed from the base of the PR and between ec11cc3 and 7ed88e8.

📒 Files selected for processing (1)
  • monai/losses/unified_focal_loss.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • monai/losses/unified_focal_loss.py

📝 Walkthrough

Walkthrough

This PR systematically narrows many catch/suppress blocks from catching BaseException to Exception across multiple modules and tests, and adds two constructor flags (use_softmax, use_sigmoid) to AsymmetricUnifiedFocalLoss with runtime mutual-exclusivity enforcement and accompanying unit tests that validate activations and the exclusivity check.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

Check name Status Explanation Resolution
Out of Scope Changes check ⚠️ Warning Most changes are BaseException→Exception narrowing in unrelated modules. Only core feature change (unified_focal_loss.py) is in-scope. Narrowing is reasonable safety hardening but represents significant scope creep beyond #8603. Separate BaseException→Exception refactoring into a distinct PR. Keep this PR focused solely on AsymmetricUnifiedFocalLoss activation parameters.
Docstring Coverage ⚠️ Warning Docstring coverage is 47.62% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Title check ✅ Passed Title clearly and concisely summarizes the primary change: adding sigmoid/softmax activation parameters to AsymmetricUnifiedFocalLoss.
Description check ✅ Passed Description covers main changes, backward compatibility, test plan, and follows template structure with issue reference. Test checkboxes properly indicate actual test additions.
Linked Issues check ✅ Passed PR implements all requirements from #8603: adds use_softmax/use_sigmoid parameters matching FocalLoss interface, validates mutual exclusivity, enables logit input, and includes comprehensive tests.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

♻️ Duplicate comments (1)
tests/losses/test_unified_focal_loss.py (1)

71-76: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Same issues as test_use_sigmoid.

Single-channel logits will fail one-hot conversion, and the assertion is too weak to verify softmax application.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/losses/test_unified_focal_loss.py` around lines 71 - 76,
test_use_softmax fails because single-channel logits break one-hot conversion
and the current assertion is too weak; update the test to feed multi-channel
logits (C=2) and matching one-hot y_true so
AsymmetricUnifiedFocalLoss(use_softmax=True) can perform softmax, then
strengthen the assertion by comparing the softmax-enabled loss to a baseline
(e.g., compute loss_softmax = loss(y_pred, y_true) with
AsymmetricUnifiedFocalLoss(use_softmax=True) and loss_sigmoid =
AsymmetricUnifiedFocalLoss(use_softmax=False)(same y_pred, y_true)) and assert
both are finite and that loss_softmax differs from loss_sigmoid (or is less
than, depending on expected behavior) to verify softmax was applied; reference
test_use_softmax, AsymmetricUnifiedFocalLoss, use_softmax, y_pred and y_true
when making the changes.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@monai/losses/unified_focal_loss.py`:
- Around line 241-244: The activation is applied after converting single-channel
y_pred to one-hot causing integer-index errors for single-channel logits; update
the UnifiedFocalLoss forward (or the method handling y_pred) to either apply
softmax/sigmoid before the one_hot conversion or add an explicit input
validation that raises a clear error when y_pred has a single channel while
use_softmax or use_sigmoid is True (referencing variables y_pred, use_softmax,
use_sigmoid and the one_hot conversion block) so users get a helpful message
rather than a silent failure.
- Around line 241-244: The code enables softmax via use_softmax but the
component losses AsymmetricFocalLoss and AsymmetricFocalTverskyLoss index
hardcoded class channels ([:, 0] and [:, 1]) so softmax with >2 channels will be
broken; update the functions to validate and reject multi-class inputs (e.g.,
check y_pred.shape[1] or num_classes and raise a clear ValueError if != 2) or
explicitly document in the use_softmax docstring that only binary (2-class)
predictions are supported, and ensure the error message references use_softmax,
AsymmetricFocalLoss, and AsymmetricFocalTverskyLoss so callers know why their
multi-channel input is unsupported.

In `@tests/losses/test_unified_focal_loss.py`:
- Around line 64-69: The test crashes because AsymmetricUnifiedFocalLoss's
internal one_hot conversion is called when y_pred is single-channel float logits
([2,1,2,2]) and those logits are being cast to long, producing out-of-range
indices; change the test to provide multi-channel logits (e.g. shape [2,2,2,2])
when use_sigmoid=True or else supply integer class indices for y_true so one_hot
isn't fed raw logits; update the test_use_sigmoid to create y_pred with two
channels and matching y_true (or use class indices) so one_hot/scatter_ receives
valid class indices.

---

Duplicate comments:
In `@tests/losses/test_unified_focal_loss.py`:
- Around line 71-76: test_use_softmax fails because single-channel logits break
one-hot conversion and the current assertion is too weak; update the test to
feed multi-channel logits (C=2) and matching one-hot y_true so
AsymmetricUnifiedFocalLoss(use_softmax=True) can perform softmax, then
strengthen the assertion by comparing the softmax-enabled loss to a baseline
(e.g., compute loss_softmax = loss(y_pred, y_true) with
AsymmetricUnifiedFocalLoss(use_softmax=True) and loss_sigmoid =
AsymmetricUnifiedFocalLoss(use_softmax=False)(same y_pred, y_true)) and assert
both are finite and that loss_softmax differs from loss_sigmoid (or is less
than, depending on expected behavior) to verify softmax was applied; reference
test_use_softmax, AsymmetricUnifiedFocalLoss, use_softmax, y_pred and y_true
when making the changes.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 0e334f54-a695-4d0b-bc39-462797493bc0

📥 Commits

Reviewing files that changed from the base of the PR and between ef2acfb and ec11cc3.

📒 Files selected for processing (15)
  • monai/__init__.py
  • monai/apps/auto3dseg/data_analyzer.py
  • monai/apps/auto3dseg/ensemble_builder.py
  • monai/apps/auto3dseg/utils.py
  • monai/apps/detection/metrics/coco.py
  • monai/apps/nnunet/nnunetv2_runner.py
  • monai/config/deviceconfig.py
  • monai/data/__init__.py
  • monai/inferers/inferer.py
  • monai/losses/unified_focal_loss.py
  • monai/utils/tf32.py
  • tests/apps/detection/networks/test_retinanet.py
  • tests/losses/test_unified_focal_loss.py
  • tests/networks/nets/test_resnet.py
  • tests/test_utils.py

Comment thread monai/losses/unified_focal_loss.py Outdated
Comment on lines +241 to +244
if self.use_softmax:
y_pred = torch.softmax(y_pred, dim=1)
elif self.use_sigmoid:
y_pred = torch.sigmoid(y_pred)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Activation applied after one-hot conversion breaks single-channel logit inputs.

Lines 227-229 convert single-channel y_pred to one-hot before activation is applied here. This means:

  • Single-channel inputs are treated as class indices (discrete values), not logits
  • Passing single-channel logits with use_sigmoid=True will fail because one_hot() expects integers
  • Users must pass 2-channel inputs to use the activation flags

Document this requirement or add validation to raise a clear error for single-channel inputs when activation flags are set.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/unified_focal_loss.py` around lines 241 - 244, The activation is
applied after converting single-channel y_pred to one-hot causing integer-index
errors for single-channel logits; update the UnifiedFocalLoss forward (or the
method handling y_pred) to either apply softmax/sigmoid before the one_hot
conversion or add an explicit input validation that raises a clear error when
y_pred has a single channel while use_softmax or use_sigmoid is True
(referencing variables y_pred, use_softmax, use_sigmoid and the one_hot
conversion block) so users get a helpful message rather than a silent failure.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Multi-class limitation: softmax won't work beyond binary despite being added.

The component losses (AsymmetricFocalLoss lines 135-138 and AsymmetricFocalTverskyLoss lines 79-80) hardcode indices [:, 0] and [:, 1], limiting support to exactly 2 classes. Even with use_softmax=True, inputs with >2 channels will fail or produce incorrect results.

Consider documenting this limitation in the use_softmax docstring or adding validation to reject num_classes != 2.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@monai/losses/unified_focal_loss.py` around lines 241 - 244, The code enables
softmax via use_softmax but the component losses AsymmetricFocalLoss and
AsymmetricFocalTverskyLoss index hardcoded class channels ([:, 0] and [:, 1]) so
softmax with >2 channels will be broken; update the functions to validate and
reject multi-class inputs (e.g., check y_pred.shape[1] or num_classes and raise
a clear ValueError if != 2) or explicitly document in the use_softmax docstring
that only binary (2-class) predictions are supported, and ensure the error
message references use_softmax, AsymmetricFocalLoss, and
AsymmetricFocalTverskyLoss so callers know why their multi-channel input is
unsupported.

Comment thread tests/losses/test_unified_focal_loss.py
AlexanderSanin and others added 3 commits May 19, 2026 09:29
…@gmail.com>

I, Oleksandr Yizchak Sanin <alexaaander.sanin@gmail.com>, hereby add my Signed-off-by to this commit: ec11cc3

Signed-off-by: Oleksandr Yizchak Sanin <alexaaander.sanin@gmail.com>
Move the sigmoid/softmax activation step before the one_hot conversion
in AsymmetricUnifiedFocalLoss.forward(). The one_hot function uses input
values as scatter indices, so passing raw logits (e.g. 10.0) causes
"index out of bounds" errors. Activation must convert logits to
probabilities first.

Signed-off-by: Oleksandr Yizchak Sanin <alexaaander.sanin@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add sigmoid/softmax interface for AsymmetricUnifiedFocalLoss

2 participants