
Dice loss - mean over labels not in patch #23

Closed
elitap opened this issue Jan 3, 2018 · 8 comments


@elitap
Contributor

elitap commented Jan 3, 2018

Related to issue #22, I think there is also a problem with the Dice loss. E.g., for a two-class problem, the loss function returns a loss of 0.5 if just one of the two labels is present in the patch, even though the prediction correctly segments that one label. The paper referenced in #22 (Generalized Dice overlap as a deep learning loss function for highly unbalanced segmentations) mentions that for the loss functions to work correctly, the presence of each label is required (I also noticed that the Dice loss formulation in the paper, which has epsilon in both the numerator and the denominator, can't be right). Anyway, I would argue that having all labels inside one patch is rarely the case when large image volumes are to be segmented and no special sampling strategy is used.
To fix that, I would suggest calculating the mean only over the labels that are present in the patch.

E.g., fixing the Dice loss in loss_segmentation.py as follows:

one_hot_summed = tf.sparse_reduce_sum(one_hot, reduction_axes=[0])
return 1.0 - (tf.reduce_sum(dice_score) / tf.to_float(tf.count_nonzero(one_hot_summed)))
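
For context, a self-contained sketch of the proposed variant; the dense tensors and the helper name `dice_loss_present_labels` are illustrative here, while the actual NiftyNet code operates on a SparseTensor:

```python
import tensorflow as tf  # TF 1.x APIs, matching the snippet above

def dice_loss_present_labels(prediction, one_hot, epsilon=1e-5):
    """Soft Dice loss averaged only over the labels present in the patch.

    prediction: [n_voxels, n_classes] softmax probabilities
    one_hot:    [n_voxels, n_classes] one-hot ground truth (dense here)
    """
    intersect = tf.reduce_sum(prediction * one_hot, axis=0)
    denominator = tf.reduce_sum(prediction, axis=0) + tf.reduce_sum(one_hot, axis=0)
    dice_score = 2.0 * intersect / (denominator + epsilon)
    # absent labels contribute a Dice of ~0 to the sum, so dividing by the
    # number of present labels yields the mean over present labels only
    one_hot_summed = tf.reduce_sum(one_hot, axis=0)
    n_present = tf.to_float(tf.count_nonzero(one_hot_summed))
    return 1.0 - tf.reduce_sum(dice_score) / n_present
```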

Once more thank you very much for the great work you are doing here.
Best, elias

@FabianIsensee

FabianIsensee commented Jan 3, 2018

for a two-class problem, the loss function returns a loss of 0.5 if just one of the two labels is present in the patch, even though the prediction correctly segments that one label

I believe this is indeed intended behavior. If one label is not present, the intersection, and thus the Dice, is 0 by definition. Since the Dice loss takes the mean over all classes, the Dice scores are 1 and 0, respectively. If you subtract the mean Dice of 0.5 from 1, you end up with 0.5.
Of course, depending on which specific implementation you are referring to, this may be a bit different (the generalized Dice loss, for instance).
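
To make the arithmetic concrete, a toy example (the numbers are made up):

```python
import numpy as np

# two-class patch: class 0 present and perfectly segmented, class 1 absent
dice_per_class = np.array([1.0, 0.0])  # absent class: intersection 0 -> Dice 0

loss = 1.0 - dice_per_class.mean()     # 1 - 0.5 = 0.5, as described above
```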

@elitap
Contributor Author

elitap commented Jan 3, 2018

Hm, agreed, as it stands that is what happens. But why would I want a loss function whose best attainable value is 0.5 (in the binary use case; it may differ for more classes) for a 100% correctly segmented result? Again, agreed, I am talking about the Dice loss, not the generalized Dice.

@FabianIsensee

Keep in mind that this is just the loss function, not the Dice score. You only want it to provide gradients, and the gradients the network gets come from two classes. If one class does not exist and the corresponding Dice is zero, the gradients for that class will be zero as well, and nothing will be done about the network output for this class. If you want Dice scores, I suggest you compute them separately.
If you have only one class (foreground vs. background), you can design your network to have only one output, and only that output will be used for the loss. Then you will get a loss of 1 in the case you described (the mean Dice for an empty slice is 0; the Dice scores here are 1 - mean_dice_loss).
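
A minimal sketch of the single-output design described above, assuming a sigmoid foreground channel and soft Dice; the function name is illustrative:

```python
import tensorflow as tf  # TF 1.x style, as elsewhere in this thread

def binary_dice_loss(foreground_prob, foreground_gt, epsilon=1e-5):
    """Soft Dice loss for a single foreground output channel.

    foreground_prob: [n_voxels] sigmoid output of the network
    foreground_gt:   [n_voxels] binary ground truth
    """
    intersect = tf.reduce_sum(foreground_prob * foreground_gt)
    denominator = tf.reduce_sum(foreground_prob) + tf.reduce_sum(foreground_gt)
    dice = 2.0 * intersect / (denominator + epsilon)
    # on an empty patch the ground-truth sum is 0, so dice -> 0 and loss -> 1
    return 1.0 - dice
```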

@elitap
Contributor Author

elitap commented Jan 3, 2018

Really interesting, very nice explanation. Closing this issue as intended behavior. Thank you!

@elitap elitap closed this as completed Jan 3, 2018
elitap added a commit to elitap/NiftyNet that referenced this issue Jan 16, 2018
…tly sampled patch, contrary to the discussion (NifTK#23) we believe that the loss function should not have an upper boundary that depends on the sampling
@elitap elitap reopened this Jul 19, 2018
@elitap
Contributor Author

elitap commented Jul 19, 2018

I recently had some time to follow up on this. I am working on an approach to segment organs in the head and neck area, and it is just too often the case that my patches do not contain labels for all organs. Anyway, I directly compared my updated version of the Dice loss to the original NiftyNet one, and I would still argue that this should be changed. Although the differences on my test data can be explained by a "better" local minimum (test_50k(csv).txt), I still argue that a loss converging to some arbitrary value depending on the patch sampling is not very interpretable.

[image: loss_comparsion — loss curves of the two Dice loss versions]

In my case both versions achieve the same results (the proposed version even slightly better), but the loss of my version is easier to interpret; it is also identical to the original one when all labels are present. For that reason, I would like to reopen this issue for discussion.

Thank you :-)

@FabianIsensee

Interesting finding. I will implement your loss as well and see what it does on my data. It may take a few days though.
Intuitively, I would say this is just a cosmetic improvement, and I do not expect it to actually work better (statistically significantly). Even though it may look better in your plots, it still does not reflect what you want to see and is still far from interpretable. Whenever you do patch-based training, the Dice scores AND the Dice loss you get on patches mean nothing, because Dice is a global metric and can only be computed properly on entire patients. What you are seeing here is basically the background Dice most of the time (I am assuming that most of your patches contain only background? Does NiftyNet optimize the background Dice as well? I never used NiftyNet 😄)
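
A small NumPy illustration of the point that Dice is a global metric: averaging per-patch Dice scores generally differs from the Dice computed once over the whole volume (the masks here are made up):

```python
import numpy as np

def dice(pred, gt, eps=1e-5):
    # hard Dice between two binary masks
    return 2.0 * np.sum(pred * gt) / (np.sum(pred) + np.sum(gt) + eps)

# toy "patient" split into two patches; the second patch is all background
gt   = np.array([1, 1, 1, 1, 0, 0, 0, 0])
pred = np.array([1, 1, 0, 0, 0, 0, 0, 1])

patch_dices = [dice(pred[:4], gt[:4]), dice(pred[4:], gt[4:])]
print(np.mean(patch_dices))  # ~0.33 (the empty patch contributes Dice ~0)
print(dice(pred, gt))        # ~0.57 for the whole volume
```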

@elitap
Contributor Author

elitap commented Jul 19, 2018

Totally agree; I guess it's mostly cosmetic, but having a loss that converges to zero just makes more sense in my world than a loss that converges to I-don't-know-what :-) I had plots for each organ individually, so yes, it is mostly the background in that plot. Hm, and speaking of significance tests, can you recommend one? Different ones are used in the segmentation literature, and I somehow get the feeling that usually the one that underlines the statement of the work is chosen. My Dice scores (and HD distances) are clearly not normally distributed, but I use a hierarchical approach, so I usually stick with a paired Wilcoxon test. In this case, I would argue the data is unpaired. So, long story short, if you are interested I'll run one, though I am no mathematician and therefore not a true believer in what I read regarding those tests :-)
Really cool that you are trying it on your data, I am curious about the outcome :-) Thanks

@wyli
Member

wyli commented Sep 14, 2018

Just to clarify...

niftynet's version:
1.0 - tf.reduce_mean(dice_score)
'cosmetic' version proposed in this thread:
1.0 - (tf.reduce_sum(dice_score) / tf.to_float(tf.count_nonzero(one_hot_summed)))

Since the sum of the Dice scores is scaled by a different factor in each version, the gradients are scaled differently too. But depending on the learning rate, the number of classes, and the task, the difference can be tiny.
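
A toy comparison of the two reductions to make the scaling explicit (the values are made up):

```python
import tensorflow as tf  # TF 1.x, as in the snippets above

# per-class Dice scores for a patch where the third class is absent
dice_score = tf.constant([0.9, 0.8, 0.0])
one_hot_summed = tf.constant([120.0, 40.0, 0.0])  # voxels per class in the patch

n_classes = tf.to_float(tf.size(dice_score))               # 3
n_present = tf.to_float(tf.count_nonzero(one_hot_summed))  # 2

loss_niftynet = 1.0 - tf.reduce_sum(dice_score) / n_classes  # 1 - 1.7/3 ~ 0.43
loss_proposed = 1.0 - tf.reduce_sum(dice_score) / n_present  # 1 - 1.7/2 = 0.15

# d(loss)/d(dice_score[i]) is -1/n_classes vs. -1/n_present: the same
# gradient direction, rescaled by a factor of n_classes / n_present
```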

@wyli wyli closed this as completed Sep 14, 2018