
Numerical issues for (logit)binarycrossentropy #914

Open
xukai92 opened this issue Nov 1, 2019 · 14 comments

@xukai92 commented Nov 1, 2019

When the logit is large, the two functions can behave quite differently.

@info Flux.logitbinarycrossentropy(50.0, 0)
@info Flux.binarycrossentropy(Flux.sigmoid(50.0), 0)

┌ Info: 50.0
└ @ Main In[86]:1
┌ Info: 36.04365338911715
└ @ Main In[86]:2
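A possible reading of the second number (an aside; it assumes binarycrossentropy clamps its log with eps(ŷ), which is how Flux's scalar definition reads): sigmoid(50.0) rounds to exactly 1.0 in Float64, so the y = 0 term -log(1 - ŷ + ϵ) collapses to -log(eps(1.0)).

Flux.sigmoid(50.0) == 1.0  # true, since exp(-50) is far below eps(1.0)
-log(eps(1.0))             # 36.04365338911715, the value reported above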

I met this issue when training a GAN using Zygote (which was fine using Tracker before). Switching from logitBCE to BCE stops my training from diverging. This might also be related to the weird training behaviour with Zygote recently reported in other issues.

@MikeInnes
Member

I could be missing the point, but isn't this why logitbinarycrossentropy exists – it's a lot more numerically stable than the unfused version?

Of course if the behaviour is better with Tracker, that suggests we can improve something in Zygote; would be good to see what the divergence is there.

@xukai92
Author

xukai92 commented Nov 6, 2019

Sorry, I wasn't clear here. I should have said that the gradients for these two cases are very different with Zygote.
For the MWE below, I'd expect the gradients to be close for both and near zero.

gradient((x) -> Flux.binarycrossentropy(sigmoid(x), 0), 50)
# (0.0,)
gradient((x) -> Flux.logitbinarycrossentropy(x, 0), 50)
# (1.0,)

More information about where it starts to behave weirdly can be seen from the sweep below:
[image: sweep of the gradient over input values]
Hope it's clear!

@MikeInnes
Member

That does seem suspicious. Do you know what Tracker or ReverseDiff does here?

@xukai92
Author

xukai92 commented Nov 7, 2019

[image: the same gradient sweep computed with Tracker]
This is how Tracker behaves. A bit unexpected though, because as I said my training was stable with Tracker... Not sure what exactly happens here then.

@jessebett
Contributor

This is possibly exactly what I'm observing in #876.

@matsueushi
Contributor

When x is large, exp(-x) << 1 and 1 + exp(-x) becomes exactly 1 in floating point. The threshold is 16.63553f0 for Float32 and 36.7368005696771 for Float64.
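(For reference, these thresholds are simply where exp(-x) drops below half a unit in the last place of 1.0, i.e. roughly where x exceeds log(2/eps(T)).)

log(2 / eps(Float32))  # 16.635532f0, matching the Float32 threshold above
log(2 / eps(Float64))  # 36.7368005696771, the Float64 threshold above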

If x exceeds this number, σ(x) = 1 and the gradient σ'(x) becomes zero, because it is defined as σ(x)(1 - σ(x)) in Zygote (https://github.com/FluxML/Zygote.jl/blob/84bf62ea18330389c64d0d918c91d7b897e1a5d8/src/lib/nnlib.jl#L8-L11).
This is why the behaviour of binarycrossentropy(sigmoid(x), 0) looks weird.

using Flux

for x in (16.63553f0, 36.7368005696771)
    around_x = nextfloat.(x, -3:3)
    display(collect(zip(around_x, σ.(around_x))))
end
7-element Array{Tuple{Float32,Float32},1}:
 (16.635525, 0.9999999)
 (16.635527, 0.9999999)
 (16.635529, 0.9999999)
 (16.63553, 0.9999999) 
 (16.635532, 1.0)      
 (16.635534, 1.0)      
 (16.635536, 1.0)      

7-element Array{Tuple{Float64,Float64},1}:
 (36.73680056967708, 0.9999999999999998) 
 (36.73680056967709, 0.9999999999999998) 
 (36.736800569677094, 0.9999999999999998)
 (36.7368005696771, 0.9999999999999998)  
 (36.73680056967711, 1.0)                
 (36.736800569677115, 1.0)               
 (36.73680056967712, 1.0)                

for x in (16.63553f0, 36.7368005696771)
    around_x = nextfloat.(x, -3:3)
    display(collect(zip(around_x, σ'.(around_x))))
end
7-element Array{Tuple{Float32,Float32},1}:
 (16.635525, 1.19209275e-7)
 (16.635527, 1.19209275e-7)
 (16.635529, 1.19209275e-7)
 (16.63553, 1.19209275e-7) 
 (16.635532, 0.0)          
 (16.635534, 0.0)          
 (16.635536, 0.0)          

7-element Array{Tuple{Float64,Float64},1}:
 (36.73680056967708, 2.2204460492503126e-16) 
 (36.73680056967709, 2.2204460492503126e-16) 
 (36.736800569677094, 2.2204460492503126e-16)
 (36.7368005696771, 2.2204460492503126e-16)  
 (36.73680056967711, 0.0)                    
 (36.736800569677115, 0.0)                   
 (36.73680056967712, 0.0)                    
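For contrast, here is a minimal sketch of why the fused loss does not run into this. It uses the standard numerically stable BCE-from-logits formulation (illustrative only, not necessarily Flux's exact source); its analytic gradient with respect to x is σ(x) - y, which for y = 0 tends to 1 rather than underflowing to 0.

using Flux

# standard numerically stable BCE-from-logits (illustrative)
stable_logitbce(x, y) = max(x, zero(x)) - x * y + log1p(exp(-abs(x)))

stable_logitbce(50.0, 0)                    # ≈ 50.0, same as logitbinarycrossentropy
gradient(x -> stable_logitbce(x, 0), 50.0)  # ≈ (1.0,), no underflow to zero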

@matsueushi
Contributor

matsueushi commented Jan 23, 2020

If we assign a small positive gradient, we can avoid zero gradients in binarycrossentropy(sigmoid(x), 0).

mysigmoid(x) = one(x) / (one(x) + exp(-x))

Flux.@adjoint function mysigmoid(x)
    y = mysigmoid(x)
    z = ifelse(y == one(y), prevfloat(one(x)), y)
    w = z * (1 - z)
    return y, Δ -> (Δ * w,)
end
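A quick sanity check of the patched adjoint (a sketch only, using the mysigmoid helper defined above):

gradient(mysigmoid, 50.0)  # ≈ (1.1e-16,): tiny but nonzero, instead of (0.0,)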

[screenshot: gradient sweep from the linked notebook]

https://nbviewer.jupyter.org/gist/matsueushi/666c7e7e62c093d998a839017869a519

@xukai92
Author

xukai92 commented Feb 18, 2020

I think what @matsueushi explained in #914 (comment) makes sense. Any idea on how we should approach fixing this, @MikeInnes?

@matsueushi
Contributor

Once x exceeds the threshold, sigmoid(x) remains constant. It means binarycrossentropy(sigmoid(x), 0) is also constant and its numerical derivative becomes zero, as we have seen in the behavior of Zygote.

In my opinion, we cannot rely on the value (or gradient) of binarycrossentropy(sigmoid(x), 0) for large x in floating-point arithmetic, because whatever changes we make, sigmoid eventually becomes constant.

I would recommend using logitBCE instead of sigmoid + BCE, as in tensorflow/tensorflow#2462.

@CarloLucibello
Member

We should update all models in Flux's docs and in the model-zoo to use a linear layer as the last layer and logitcrossentropy as the loss.
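For illustration, the pattern would look something like this (hypothetical model, not taken from the model-zoo):

using Flux

# end the chain with a plain Dense layer (logits out, no softmax/sigmoid) ...
model = Chain(Dense(784, 32, relu), Dense(32, 10))
loss(x, y) = Flux.logitcrossentropy(model(x), y)

# ... instead of ending with softmax and pairing it with crossentropy:
# model = Chain(Dense(784, 32, relu), Dense(32, 10), softmax)
# loss(x, y) = Flux.crossentropy(model(x), y)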

@ToucheSir
Member

For reference, here's PyTorch with Float32:

import torch
import torch.nn.functional as F

zero = torch.tensor(0, dtype=torch.float)

def get_grads(i):
    x = torch.tensor(i, dtype=torch.float, requires_grad=True)
    bce = F.binary_cross_entropy(torch.sigmoid(x), zero)
    lbce = F.binary_cross_entropy_with_logits(x, zero)
    return torch.autograd.grad(bce, x), torch.autograd.grad(lbce, x)

for i in range(50):
    print(get_grads(i + 1))

### output:

((tensor(0.7311),), (tensor(0.7311),))
((tensor(0.8808),), (tensor(0.8808),))
((tensor(0.9526),), (tensor(0.9526),))
((tensor(0.9820),), (tensor(0.9820),))
((tensor(0.9933),), (tensor(0.9933),))
((tensor(0.9975),), (tensor(0.9975),))
((tensor(0.9991),), (tensor(0.9991),))
((tensor(0.9997),), (tensor(0.9997),))
((tensor(0.9999),), (tensor(0.9999),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(1.0000),), (tensor(1.0000),))
((tensor(0.),), (tensor(1.),))
((tensor(0.),), (tensor(1.),))
... (all remaining iterations up to i = 50 are identical: BCE gradient 0., logitBCE gradient 1.)

@CarloLucibello
Member

Is there anything actionable here, or can we close?

@ToucheSir
Member

Possibly, if we want to implement @matsueushi's suggestion around sigmoid. Otherwise I don't think there's anything actionable.

@ToucheSir reopened this Sep 14, 2021
@CarloLucibello
Member

Possibly, if we want to implement @matsueushi's suggestion around sigmoid

That asymptotic 0.5 is still wrong and is possibly a harder-to-detect problem. Performance and CUDA compatibility should also be assessed. Maybe it's better to leave things as they are, @matsueushi?
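(For context, a rough sketch of where that 0.5 comes from, assuming binarycrossentropy clamps its log with eps(ŷ): the patched sigmoid derivative saturates at prevfloat(1.0) * (1 - prevfloat(1.0)) ≈ eps()/2, while the BCE factor saturates at 1/eps(), so the chain rule yields about 0.5 instead of the true gradient σ(x) ≈ 1.)

dsig = prevfloat(1.0) * (1 - prevfloat(1.0))  # ≈ eps()/2 ≈ 1.11e-16
dbce = 1 / (1 - 1.0 + eps(1.0))               # d/dŷ of -log(1 - ŷ + eps(ŷ)) at ŷ = 1
dsig * dbce                                   # ≈ 0.5, vs. the true gradient σ(50) ≈ 1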
