NaN gradients for sqrt #1101

Open
JordiBolibar opened this issue Oct 13, 2021 · 7 comments · May be fixed by JuliaDiff/ChainRules.jl#599

Comments

@JordiBolibar
Contributor

After a long time hunting a bug with @facusapienza21, we have realized that Zygote fails to provide a gradient for the basic sqrt function. This has been discussed at length in this Discourse thread.

Here's a MWE to reproduce the issue:

using Zygote 
using Flux

A₀ = [[1,0] [0,3]]
A₁ = [[0,0] [0,0]]

function loss(θ)
    A = A₀.^θ
    A = sqrt.(A)
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end

θ = 4.0
loss_θ, back_θ = Zygote.pullback(loss, θ) 

In this case, the value of back_θ(1.0) is NaN. However, if we avoid the use of sqrt() by defining the loss function as

function loss(θ)
    A = A₀.^(θ/2)
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end

then Zygote provides the right gradient.
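
One way to sanity-check that claim is a central finite difference (a sketch; the step size h and the tolerance are arbitrary choices, not from the issue):

h = 1e-6                                  # arbitrary step size
fd = (loss(θ + h) - loss(θ - h)) / (2h)   # central finite difference at θ = 4.0
grad, = Zygote.gradient(loss, θ)
isapprox(grad, fd; rtol=1e-4)             # should be true if the gradient is right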

According to @mcabbott, "the reason we get NaN is that the slope of sqrt at zero is infinite. That infinity multiplies the slope of 0^x at 4, which is zero. Whereas with the 0^(x/2) version, the slope is simply zero".
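
A minimal sketch of that product (assuming Zygote's usual rule for sqrt, i.e. a slope of 1/(2√x)):

using Zygote

gradient(sqrt, 0.0)   # (Inf,): the slope of sqrt at zero
Inf * 0.0             # NaN: the product the chain rule forms for the zero entries of A₀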

Since sqrt is such a basic function, this bug can potentially impact a large number of users.

@mcabbott
Member

I don't think this is really a bug in sqrt that can be solved. But avoiding such issues is one reason to define gradient rules for larger functions -- this is something like norm, and we can choose to give that smoother behaviour.

It's a little like sin(x)/x, which has an obvious definition at zero to make it continuous, but the computer does not know this and gives you NaN. That is a reason to wrap it in a function sinc(x) = iszero(x) ? zero(x) : sin(x)/x to help out. But the derivative still goes wrong: abs(ForwardDiff.derivative(sinc, 1e-40)) > 1e20, which we could smooth out with further rules.
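
For reference, a runnable form of that example (renamed mysinc here so it does not clash with Base.sinc):

using ForwardDiff

mysinc(x) = iszero(x) ? zero(x) : sin(x) / x       # patches the value at zero, not the derivative

abs(ForwardDiff.derivative(mysinc, 1e-40)) > 1e20  # true: the derivative still blows up near zero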

@mcabbott
Member

Xref also discussion here: #1036 . It might be possible to regularise all Inf gradients; this is likely to sometimes lead to wrong finite gradients, but perhaps they are acceptable, and perhaps such smoothed gradients would be more useful?
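
As an illustration of what such a regularisation could look like, here is a hedged sketch of a custom ChainRules rule that replaces the Inf slope of sqrt at zero with a finite (zero) subgradient. The function name reg_sqrt and the choice of zero are made up for this example; this is not what Zygote or ChainRules actually ships:

using ChainRulesCore, Zygote

reg_sqrt(x) = sqrt(x)   # hypothetical wrapper with a regularised gradient

function ChainRulesCore.rrule(::typeof(reg_sqrt), x)
    Ω = sqrt(x)
    function reg_sqrt_pullback(ΔΩ)
        # replace the Inf slope at x == 0 with an arbitrary finite subgradient (here: zero)
        slope = iszero(x) ? zero(Ω) : inv(2 * Ω)
        return NoTangent(), slope * ΔΩ
    end
    return Ω, reg_sqrt_pullback
end

Zygote.gradient(reg_sqrt, 0.0)   # (0.0,) instead of (Inf,)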

@JordiBolibar
Contributor Author

That's a good point. To be honest, I'm not sure wrong gradients are better than an error. If no perfect solution is available, the best option might be an informative error: something pointing out where the issue comes from and proposing a workaround (e.g. use ^(1/2) instead of sqrt), like the one I used.

@mcabbott
Member

I agree that an error is often better than a NaN. Looks like this has been discussed a bit:

https://discourse.julialang.org/t/treating-nan-as-error-helping-debugging/36933

JuliaLang/julia#27705

Less ambitiously, something like this could potentially be added only to AD. For instance, inserting a function which is by default check_nan() = false into @scalar_rule would, I think, let you recompile all rules to have a check in them, for debugging.
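
As a much simpler stop-gap for debugging, one could also check gradients after the fact (check_nan_gradient below is purely illustrative and not an existing Zygote API):

using Zygote

function check_nan_gradient(f, args...)
    grads = Zygote.gradient(f, args...)
    for g in grads
        g === nothing && continue                         # skip non-differentiable arguments
        any(isnan, g) && error("NaN gradient for $f at $args")
    end
    return grads
end

check_nan_gradient(x -> x * sqrt(x), 0.0)   # throws instead of silently returning (NaN,)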

@Alexander-Barth

Zygote and PyTorch seem to behave similarly in these cases:

gradient(x -> x * sqrt(x), 0)
# (NaN,)
gradient(x -> x^(1.5), 0)
# (0.0,)

PyTorch:

import torch

x = torch.tensor([0.], requires_grad=True); f = x*torch.sqrt(x); f.backward(); x.grad
# tensor([nan])
x = torch.tensor([0.], requires_grad=True); f = x**(1.5); f.backward(); x.grad
# tensor([0.])

To me (and to my math professors, as far as I remember :-)), √x and x / √x are just two different functions: √x = 0 for x = 0, but x / √x is undefined at x = 0 (they are equal almost everywhere but still different).

Given that the derivative of sqrt is undefined (in the mathematical sense) at x = 0, having NaN as a result seems quite logical to me. I would not expect any symbolic transformation from Zygote (or PyTorch) to lift this pathological case.

@mcabbott
Member

There might be more clever ways; ForwardDiff's NaN-safe mode works around some cases where the simple conclusion would be NaN. Today's discussion here: JuliaDiff/ChainRules.jl#576

@MariusDrulea

The derivative of sqrt(x) is 1/(2*sqrt(x)), so it has to be Inf at 0, since 1/0 returns Inf in Julia. Zygote, ForwardDiff and ReverseDiff are right here. It would be a terrible mistake if these AD tools returned something else.

Possible solutions to avoid Inf gradients for sqrt (both sketched below):

  1. sqrte(x) = sqrt(x+e), where e is a small positive number.
  2. sqrt_(x) = x > e ? sqrt(x) : sqrt(x+e), for the case where you want to keep the exact sqrt behaviour for most x values.
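
A short sketch of both workarounds (the constant's name and its value 1e-12 are arbitrary choices):

using Zygote

const e = 1e-12                           # small positive shift; the value is arbitrary

sqrte(x) = sqrt(x + e)                    # 1. always shift
sqrt_(x) = x > e ? sqrt(x) : sqrt(x + e)  # 2. shift only near zero

Zygote.gradient(sqrte, 0.0)   # finite (≈ 5e5 for this e) instead of Inf
Zygote.gradient(sqrt_, 0.0)   # finite as well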
