Conversation

@wyli
Contributor

@wyli wyli commented Jul 2, 2021

Signed-off-by: Wenqi Li wenqil@nvidia.com

Fixes #2509

Description

Using a sigmoid activation to align with the literature.

The change is verified:
  • previously, gamma=0 reduced the loss to nn.CrossEntropyLoss
  • now, gamma=0 reduces it to nn.BCEWithLogitsLoss
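The gamma=0 reduction can be illustrated with a minimal pure-Python sketch of the per-element sigmoid focal loss (following Lin et al. 2017); this is an illustration only, not MONAI's actual implementation:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def sigmoid_focal_loss(logit, target, gamma=2.0, alpha=None):
    """Per-element sigmoid focal loss: FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t).

    `target` is 0 or 1; `logit` is the raw score before the sigmoid.
    """
    p = sigmoid(logit)
    p_t = p if target == 1 else 1.0 - p
    loss = -((1.0 - p_t) ** gamma) * math.log(p_t)
    if alpha is not None:
        loss *= alpha if target == 1 else 1.0 - alpha
    return loss

def bce_with_logits(logit, target):
    p = sigmoid(logit)
    return -(target * math.log(p) + (1 - target) * math.log(1.0 - p))

# gamma=0 (and no alpha) recovers binary cross-entropy with logits
assert abs(sigmoid_focal_loss(0.7, 1, gamma=0.0) - bce_with_logits(0.7, 1)) < 1e-12
```

For gamma > 0, the `(1 - p_t) ** gamma` factor down-weights well-classified elements, which is the whole point of the focal term.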

Status

Ready

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@wyli wyli requested review from Nic-Ma and yiheng-wang-nv July 2, 2021 19:02
@wyli wyli force-pushed the 2509-sigmoid-focal-loss branch from fc195d0 to f29ebce Compare July 2, 2021 19:17
@wyli wyli requested review from ericspod and rijobro July 2, 2021 19:18
@Nic-Ma
Contributor

Nic-Ma commented Jul 2, 2021

Thanks for the quick fix.
@yiheng-wang-nv could you please help review it?

Thanks.

@yiheng-wang-nv
Contributor

Hi @wyli, in the paper (https://arxiv.org/pdf/1708.02002.pdf) a sigmoid operation is used to compute p, but it also describes the two-class case, for instance:
[screenshot: excerpt from the paper on the binary formulation of p_t]
However, for a multi-class task, I think it still needs a softmax operation, just as the cross entropy loss does:
https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html#CrossEntropyLoss

Hi @ristoh @ericspod , could you please help to double check it?

@wyli
Contributor Author

wyli commented Jul 5, 2021

https://arxiv.org/pdf/1708.02002.pdf

Hi @yiheng-wang-nv, the footnote you cited describes the formulation. Section 4 (classification subnet) is clear: "for each of the A anchors and K object classes... Finally sigmoid activations are attached to output the KA binary predictions per spatial location."

Replacing the sigmoid with a softmax should work as well (for classification with mutually exclusive classes); for example, their source code supports both:
https://github.com/facebookresearch/Detectron/blob/1809dd41c1ffc881c0d6b1c16ea38d08894f8b6d/detectron/modeling/retinanet_heads.py#L281-L295

TensorFlow Addons implements the sigmoid version: https://github.com/tensorflow/addons/blob/e83e71cf07f65773d0f3ba02b6de66ec3b190db7/tensorflow_addons/losses/focal_loss.py
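For contrast with the sigmoid version, the softmax variant for mutually exclusive classes can be sketched as follows (a hypothetical illustration, not the Detectron or TensorFlow code):

```python
import math

def softmax(xs):
    # subtract the max for numerical stability
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def softmax_focal_loss(logits, target_idx, gamma=2.0):
    # The focal term modulates the log-probability of the single true
    # class; with gamma=0 this is plain softmax cross-entropy.
    p_t = softmax(logits)[target_idx]
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

# softmax probabilities compete: they always sum to 1
assert abs(sum(softmax([2.0, 1.0, 0.1])) - 1.0) < 1e-12
```

Unlike the sigmoid version, only one class can "win" per prediction, which is why it suits mutually exclusive labels.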

@wyli wyli force-pushed the 2509-sigmoid-focal-loss branch from f939af1 to c39edcf Compare July 5, 2021 10:42
wyli added 3 commits July 5, 2021 11:42
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
@wyli wyli force-pushed the 2509-sigmoid-focal-loss branch from c39edcf to db64d08 Compare July 5, 2021 10:42
@sandylaker

> However, for multi-class task, I think it still needs to use softmax operation, just as the cross entropy loss does:
> https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html#CrossEntropyLoss
>
> Hi @ristoh @ericspod , could you please help to double check it?

I would like to add an extra comment. For anchor-based detectors, sigmoid is more widely adopted, so the classification branch actually performs multi-label classification. At inference time, each anchor box is assigned to the class with the highest score, even though the sigmoid scores do not sum to 1. Hope it helps.
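The inference behaviour described above can be sketched with hypothetical per-class logits for one anchor box (illustration only):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# Hypothetical per-class logits for a single anchor. The sigmoid scores
# are independent binary predictions, so they need not sum to 1; the
# anchor is still simply assigned the highest-scoring class.
logits = [1.2, -0.3, 2.5]
scores = [sigmoid(x) for x in logits]
predicted = max(range(len(scores)), key=lambda i: scores[i])

assert predicted == 2
assert abs(sum(scores) - 1.0) > 0.5  # not a probability distribution
```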

@wyli wyli requested a review from ericspod July 5, 2021 19:12
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
@wyli wyli force-pushed the 2509-sigmoid-focal-loss branch from 57a6011 to dfb3318 Compare July 5, 2021 19:54
@wyli wyli requested a review from ericspod July 6, 2021 14:10
@deepib deepib changed the title 2509 update focalloss to use sigmoid 2509 update focalloss to use sigmoid (7/July) Jul 6, 2021
@ericspod ericspod self-requested a review July 6, 2021 20:08
Member

@ericspod ericspod left a comment

I think it's good to go, though perhaps some more explanation of the computation would help to clarify why this differs from the "standard" definitions of focal loss most commonly seen, for instance why things are done in log space.
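The log-space point can be sketched as follows (a hypothetical illustration of the general numerical-stability trick, not the PR's actual code). Using the identity 1 - sigmoid(z) = sigmoid(-z), both factors of the focal loss can be built from log(sigmoid(.)):

```python
import math

def logsigmoid(x):
    # numerically stable log(sigmoid(x)) = -softplus(-x)
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def stable_focal_loss(logit, target, gamma=2.0):
    # Work in log space: with z = logit for target 1 and z = -logit for
    # target 0, we have log(p_t) = logsigmoid(z) and
    # (1 - p_t)^gamma = exp(gamma * logsigmoid(-z)).
    # This avoids forming p_t directly, which can underflow for very
    # negative z.
    z = logit if target == 1 else -logit
    return -math.exp(gamma * logsigmoid(-z)) * logsigmoid(z)

# extreme logits stay finite instead of producing log(0)
assert math.isfinite(stable_focal_loss(-100.0, 1))
```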

@wyli wyli enabled auto-merge (squash) July 6, 2021 23:41
@wyli wyli merged commit 17dde67 into Project-MONAI:dev Jul 7, 2021
@wyli wyli deleted the 2509-sigmoid-focal-loss branch November 16, 2021 14:28


Development

Successfully merging this pull request may close these issues.

Focal Loss returning 0.0 with include_background=False (7/July)

5 participants