-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Relax complex function signatures to make them ForwardDiff compatible #36030
Conversation
Example: using ForwardDiff
Base.float(d::ForwardDiff.Dual{T}) where T = ForwardDiff.Dual{T}(float(d.value), d.partials)
Base.prevfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(prevfloat(float(d.value)), d.partials)
Base.nextfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(nextfloat(float(d.value)), d.partials)
function Base.ldexp(x::T, e::Integer) where T<:ForwardDiff.Dual
if e >=0
x * (1<<e)
else
x / (1<<-e)
end
end julia> using Test
julia> for f in [atanh, tanh, acos, x->2^x, x->(2+im)^x, exp10, exp2, expm1, log, sqrt]
@inferred f(ForwardDiff.Dual(1, 1) + im)
end |
Any potential performance impacts? |
I don't think the CI failure is related. |
Even running on the nightly I am still having an error thrown with some of the examples in @YingboMa's comment (#36030 (comment)). i.e.
Any help/advice would be great! I am new to julia but keen to learn. |
You must define the functions listed above, the following works on a 14 day old nightly. julia> using ForwardDiff
julia> Base.float(d::ForwardDiff.Dual{T}) where T = ForwardDiff.Dual{T}(float(d.value), d.partials)
julia> Base.prevfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(prevfloat(float(d.value)), d.partials)
julia> Base.nextfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(nextfloat(float(d.value)), d.partials)
julia> function Base.ldexp(x::T, e::Integer) where T<:ForwardDiff.Dual
if e >=0
x * (1<<e)
else
x / (1<<-e)
end
end
julia> sqrt(ForwardDiff.Dual(1., 1.) + 0im)
Dual{Nothing}(1.0,0.5) + Dual{Nothing}(0.0,0.0)*im |
Amazing! It is working now for me. |
Am I doing something wrong? julia> using ForwardDiff
julia> Base.float(d::ForwardDiff.Dual{T}) where T = ForwardDiff.Dual{T}(float(d.value), d.partials)
julia> Base.prevfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(prevfloat(float(d.value)), d.partials)
julia> Base.nextfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(nextfloat(float(d.value)), d.partials)
julia> function Base.ldexp(x::T, e::Integer) where T<:ForwardDiff.Dual
if e >=0
x * (1<<e)
else
x / (1<<-e)
end
end
julia> sqrt(ForwardDiff.Dual(1., 1.) + 0im)
ERROR: StackOverflowError:
Stacktrace:
[1] sqrt(::Complex{ForwardDiff.Dual{Nothing,Float64,1}}) at .\complex.jl:506 (repeats 79984 times) |
Are you using the nightly |
This PR relaxes complex math function signatures so
log(::Complex{<:Dual})
can work.Ref: JuliaDiff/ForwardDiff.jl#455