Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax complex function signatures to make them ForwardDiff compatible #36030

Merged
merged 3 commits into from
May 28, 2020

Conversation

YingboMa
Copy link
Contributor

This PR relaxes complex math function signatures so log(::Complex{<:Dual}) can work.

Ref: JuliaDiff/ForwardDiff.jl#455

@YingboMa
Copy link
Contributor Author

YingboMa commented May 26, 2020

Example:

using ForwardDiff
Base.float(d::ForwardDiff.Dual{T}) where T = ForwardDiff.Dual{T}(float(d.value), d.partials)
Base.prevfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(prevfloat(float(d.value)), d.partials)
Base.nextfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(nextfloat(float(d.value)), d.partials)
function Base.ldexp(x::T, e::Integer) where T<:ForwardDiff.Dual
    if e >=0
        x * (1<<e)
    else
        x / (1<<-e)
    end
end
julia> using Test

julia> for f in [atanh, tanh, acos, x->2^x, x->(2+im)^x, exp10, exp2, expm1, log, sqrt]
           @inferred f(ForwardDiff.Dual(1, 1) + im)
       end

@jebej
Copy link
Contributor

jebej commented May 26, 2020

Any potential performance impacts?

base/complex.jl Outdated Show resolved Hide resolved
base/complex.jl Outdated Show resolved Hide resolved
@JeffBezanson JeffBezanson added the complex Complex numbers label May 26, 2020
@YingboMa
Copy link
Contributor Author

channels                        |      225     1               226

I don't think the CI failure is related.

@arm61
Copy link

arm61 commented Oct 12, 2020

Even running on the nightly I am still having an error thrown with some of the examples in @YingboMa's comment (#36030 (comment)). i.e.

julia> using ForwardDiff
julia> sqrt(ForwardDiff.Dual(1., 1.) + 0im)
ERROR: MethodError: no method matching nextfloat(::ForwardDiff.Dual{Nothing, Float64, 1})
Closest candidates are:
  nextfloat(::Union{Float16, Float32, Float64}, ::Integer) at float.jl:595
  nextfloat(::BigFloat) at mpfr.jl:911
  nextfloat(::AbstractFloat) at float.jl:639
  ...
Stacktrace:
 [1] ssqs(x::ForwardDiff.Dual{Nothing, Float64, 1}, y::ForwardDiff.Dual{Nothing, Float64, 1})
   @ Base ./complex.jl:472
 [2] sqrt(z::Complex{ForwardDiff.Dual{Nothing, Float64, 1}})
   @ Base ./complex.jl:487
 [3] top-level scope
   @ REPL[19]:1

Any help/advice would be great! I am new to julia but keen to learn.

@baggepinnen
Copy link
Contributor

You must define the functions listed above, the following works on a 14 day old nightly.

julia> using ForwardDiff

julia> Base.float(d::ForwardDiff.Dual{T}) where T = ForwardDiff.Dual{T}(float(d.value), d.partials)

julia> Base.prevfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(prevfloat(float(d.value)), d.partials)

julia> Base.nextfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(nextfloat(float(d.value)), d.partials)

julia> function Base.ldexp(x::T, e::Integer) where T<:ForwardDiff.Dual
           if e >=0
               x * (1<<e)
           else
               x / (1<<-e)
           end
       end

julia> sqrt(ForwardDiff.Dual(1., 1.) + 0im)
Dual{Nothing}(1.0,0.5) + Dual{Nothing}(0.0,0.0)*im

@arm61
Copy link

arm61 commented Oct 28, 2020

Amazing! It is working now for me.

@kapple19
Copy link

Am I doing something wrong?

julia> using ForwardDiff

julia> Base.float(d::ForwardDiff.Dual{T}) where T = ForwardDiff.Dual{T}(float(d.value), d.partials)

julia> Base.prevfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(prevfloat(float(d.value)), d.partials)

julia> Base.nextfloat(d::ForwardDiff.Dual{T,V,N}) where {T,V,N} = ForwardDiff.Dual{T}(nextfloat(float(d.value)), d.partials)

julia> function Base.ldexp(x::T, e::Integer) where T<:ForwardDiff.Dual
                  if e >=0
                                 x * (1<<e)
                                            else
                                                           x / (1<<-e)
                                                                      end
                                                                             end

julia> sqrt(ForwardDiff.Dual(1., 1.) + 0im)
ERROR: StackOverflowError:
Stacktrace:
 [1] sqrt(::Complex{ForwardDiff.Dual{Nothing,Float64,1}}) at .\complex.jl:506 (repeats 79984 times)

@arm61
Copy link

arm61 commented Nov 11, 2020

Are you using the nightly 1.6.x build? That's what I needed.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
complex Complex numbers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants