This repository has been archived by the owner on Mar 17, 2021. It is now read-only.

Dice losses computed for the whole batch? #22

Closed
FabianIsensee opened this issue Jan 3, 2018 · 5 comments
@FabianIsensee

Hey there,
when looking into the implementation of the Dice losses I noticed that they do not treat each image in the batch individually but are instead computed over the entire batch (if I understood that correctly). This is not mentioned in the corresponding paper ("Generalised Dice overlap as a deep learning loss function for highly unbalanced segmentations" by C. H. Sudre et al.).
Is that behavior intentional? Due to the nature of the Dice loss, computing it over the entire batch is not equivalent to computing it for each sample individually and then taking the mean.
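To illustrate with made-up numbers, a minimal NumPy sketch (not NiftyNet code) showing that the two reductions differ:

    import numpy as np

    # Hypothetical per-sample statistics for a batch of two images.
    intersect = np.array([9.0, 1.0])    # sum of pred * ref for each sample
    vol_sum = np.array([20.0, 10.0])    # sum of pred + ref for each sample

    # Dice over the whole batch: pool the sums first, then take one ratio.
    batch_dice = 2 * intersect.sum() / vol_sum.sum()  # 20/30 ~= 0.667

    # Dice per sample, then averaged over the batch.
    mean_dice = np.mean(2 * intersect / vol_sum)      # mean(0.9, 0.2) = 0.55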
Regards,
Fabian

@Zach-ER
Collaborator

Zach-ER commented Jan 8, 2018

Hi Fabian,
I believe that this behaviour is not intentional -- good spot!
I will work on correcting this as soon as possible.

@Zach-ER
Collaborator

Zach-ER commented Jan 12, 2018

This behaviour is now fixed in commit 196024b.

Thanks again for pointing it out!

@wyli
Member

wyli commented Jan 12, 2018

Thanks both, I'm closing this issue.

@wyli wyli closed this as completed Jan 12, 2018
@wyli wyli added the bug label Jan 23, 2018
@shilpa-ananth

Hello, I'm having a little trouble understanding this. In the generalised_dice_loss function, the score is calculated as:

    generalised_dice_numerator = 2 * tf.reduce_sum(tf.multiply(weights, intersect))
    generalised_dice_denominator = tf.reduce_sum(tf.multiply(weights, seg_vol + ref_vol))
    generalised_dice_score = generalised_dice_numerator / generalised_dice_denominator

  1. Does this mean you calculate the numerator (reduce the sum) over the entire batch, and likewise the denominator over the entire batch? Or am I not reading the shapes correctly?
     From my understanding, weights is of size [batch_size, classes]. Wouldn't you have to sum over the labels first, divide by the denominator per sample, and then take the mean over the batch (see the sketch after this list)?

  2. Another question I had was whether the function is supposed to take in logits or scaled probabilities.
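
For what it's worth, the per-sample reduction described in (1) would look roughly like the sketch below. This is illustrative only: the shapes of weights, intersect, seg_vol and ref_vol are assumed to be [batch_size, n_classes], and the tensors are random stand-ins.

    import tensorflow as tf

    # Illustrative tensors with the assumed shape [batch_size, n_classes].
    batch_size, n_classes = 2, 3
    weights = tf.random.uniform([batch_size, n_classes])
    intersect = tf.random.uniform([batch_size, n_classes])
    seg_vol = tf.random.uniform([batch_size, n_classes])
    ref_vol = tf.random.uniform([batch_size, n_classes])

    # Sum over the class axis for each sample, form one ratio per sample,
    # then average the per-sample scores over the batch.
    numerator = 2.0 * tf.reduce_sum(weights * intersect, axis=1)        # [batch_size]
    denominator = tf.reduce_sum(weights * (seg_vol + ref_vol), axis=1)  # [batch_size]
    per_sample_score = numerator / denominator                          # [batch_size]
    generalised_dice_score = tf.reduce_mean(per_sample_score)           # scalar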

@wyli
Member

wyli commented Mar 5, 2018

Hi @shilpa-ananth ,
in the latest version (0.2.2) the Dice loss is computed for each sample in a mini-batch:
https://github.com/NifTK/NiftyNet/blob/v0.2.2/niftynet/layer/loss_segmentation.py#L79
and the user could choose to use either logits or probs:
https://github.com/NifTK/NiftyNet/blob/v0.2.2/niftynet/layer/loss_segmentation.py#L84
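
In other words, the choice amounts to whether a softmax is applied to the prediction inside the loss. Roughly (an illustrative helper, not NiftyNet's exact API):

    import tensorflow as tf

    # Illustrative helper: apply softmax only when the network output is raw
    # logits; pass pre-scaled probabilities through unchanged.
    def to_probabilities(prediction, from_logits=True):
        return tf.nn.softmax(prediction, axis=-1) if from_logits else prediction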
