`quantile` gives wrong gradient on Julia 1.6 #973
Looks like the primal computation itself got messed up:

```julia
julia> f(x)=sum(quantile([1.0, x], [0.7]))
f (generic function with 1 method)

julia> f(2.0)
1.7

julia> x=1.0
1.0

julia> quantile([1.0, x], [0.7])
1-element Vector{Float64}:
 1.0

julia> autodiff(Forward, f, Duplicated,Duplicated(2.0, 1.0))
(0.0, 0.0)
```

The forward-mode call returns `(0.0, 0.0)`, so the primal value (which should be `1.7`) is wrong as well as the derivative.
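Reduced reproducer, with the quantile interpolation inlined:

```julia
using Enzyme, Statistics
Enzyme.API.printall!(true)  # dump the LLVM IR Enzyme processes, to inspect the miscompile

# Mirrors the interpolation step of Statistics.quantile for a single probability p
@inline function myquantile(v::AbstractVector, p::Real; alpha::Real=1.0, beta::Real=alpha)
    n = length(v)
    m = alpha + p * (one(alpha) - alpha - beta)
    aleph = n*p + oftype(p, m)
    j = clamp(trunc(Int, aleph), 1, n-1)  # index of the lower neighbour
    γ = clamp(aleph - j, 0, 1)            # interpolation weight
    if n == 1
        a = v[1]
        b = v[1]
    else
        a = v[j]
        b = v[j + 1]
    end
    return a + γ*(b-a)
end

function f(x)
    v = [1.0, x]
    return first(map(x->myquantile(v, x, alpha=1., beta=1.), [0.7]))
end

@show f(2.0)
@show autodiff(Forward, f, Duplicated,Duplicated(2.0, 1.0))
```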
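For reference, tracing `myquantile` by hand with the inputs used in `f` (`v = [1.0, x]`, `p = 0.7`, `alpha = beta = 1`) gives the values the forward-mode call should return. `myquantile` does not sort `v`, so `j` does not depend on `x`:

$$
\begin{aligned}
n &= 2, \qquad m = \alpha + p\,(1 - \alpha - \beta) = 1 - 0.7 = 0.3,\\
\aleph &= n p + m = 1.4 + 0.3 = 1.7, \qquad j = \operatorname{clamp}(\operatorname{trunc}(\aleph), 1, n-1) = 1, \qquad \gamma = \aleph - j = 0.7,\\
f(x) &= v_j + \gamma\,(v_{j+1} - v_j) = 1 + 0.7\,(x - 1), \qquad f(2) = 1.7, \qquad f'(x) = 0.7.
\end{aligned}
$$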
This turns out to be a bug in LLVM's Attributor pass, which changes a `store double ...` into a `store double undef`: https://godbolt.org/z/Pa7Eejcq4. @jdoerfert, any insights here?
The bug appears to occur in LLVM 11 (above) and LLVM 12 (https://godbolt.org/z/YehKx37h8), but not in LLVM 13+.
This is effectively fixed in #986 by disabling the attributor on the older LLVM versions, but the workaround is still quite unsatisfying. Either way, closing this as resolved; the attributor bug itself is independent of Enzyme.jl.
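Julia 1.6 is built against LLVM 11, which is consistent with the miscompile showing up there but not on Julia 1.9. A version gate of the kind described could be keyed on `Base.libllvm_version`. The sketch below only illustrates the idea and is not the actual change in #986: `ATTRIBUTOR_SAFE` and `describe_attributor_policy` are hypothetical names, and the LLVM 13 cutoff is taken from the observation that the bug reproduces on 11 and 12 but not on 13+.

```julia
# Hypothetical sketch, not Enzyme.jl's implementation: decide whether an
# LLVM pass is safe to run based on the LLVM version Julia was built with.
const ATTRIBUTOR_SAFE = Base.libllvm_version >= v"13"

function describe_attributor_policy()
    if ATTRIBUTOR_SAFE
        return "LLVM $(Base.libllvm_version): attributor may stay enabled"
    else
        # LLVM 11/12 rewrite a live store into `store double undef`
        return "LLVM $(Base.libllvm_version): attributor disabled"
    end
end

println(describe_attributor_policy())
```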
On Julia 1.9 and Enzyme main (6692fad) this works:

But on Julia 1.6.7 it gives the wrong answer:

`autodiff(Forward, f, Duplicated(2.0, 1.0))[1]` also gives `0.7`.

A possibly simpler variant: `sum(quantile([1.0, x], 0.7))` gives the correct `0.7` on both versions, but `sum(quantile([1.0, x], [0.7]))` gives `0.7` on Julia 1.9 (correct) and `0.0` on Julia 1.6 (incorrect).
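Both formulations are linear in `x` for `x > 1` (the sorted input stays `[1.0, x]`), so a central finite difference using plain `Statistics.quantile`, with no Enzyme involved, is one way to confirm the expected derivative of `0.7` independently of the compiler version. A minimal sketch; `f_vec`, `f_scalar`, and `central_diff` are hypothetical names introduced here:

```julia
using Statistics  # stdlib; provides quantile

# The two formulations compared above: vector of probabilities vs a scalar probability
f_vec(x) = sum(quantile([1.0, x], [0.7]))
f_scalar(x) = sum(quantile([1.0, x], 0.7))

# Central finite difference; accurate here up to rounding because
# both functions are linear in x for x > 1
central_diff(g, x; h = 1e-6) = (g(x + h) - g(x - h)) / (2h)

for g in (f_vec, f_scalar)
    println(g, ": value = ", g(2.0), ", slope ≈ ", central_diff(g, 2.0))
end
```

If Enzyme's forward-mode result disagrees with this slope on a given Julia version, the discrepancy comes from the differentiated code rather than from `quantile` itself.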