Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

complex number support #455

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ForwardDiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ include("derivative.jl")
include("gradient.jl")
include("jacobian.jl")
include("hessian.jl")
include("complex.jl")

export DiffResults

Expand Down
172 changes: 172 additions & 0 deletions src/complex.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
Base.prevfloat(x::Dual{T,V}) where {T,V<:AbstractFloat} = prevfloat(x.value)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
Base.nextfloat(x::Dual{T,V}) where {T,V<:AbstractFloat} = nextfloat(x.value)

function Base.log(z::Complex{T}) where {A, T<:Dual{A,<:AbstractFloat}}
T1::T = 1.25
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
T2::T = 3
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
ln2::T = log(convert(T,2)) #0.6931471805599453
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
x, y = reim(z)
ρ, k = Base.ssqs(x,y)
ax = abs(x)
ay = abs(y)
if ax < ay
θ, β = ax, ay
else
θ, β = ay, ax
end
if k==0 && (0.5 < β*β) && (β <= T1 || ρ < T2)
ρρ = log1p((β-1)*(β+1)+θ*θ)/2
else
ρρ = log(ρ)/2 + k*ln2
end
Complex(ρρ, angle(z))
end
function Base.tanh(z::Complex{T}) where {A, T<:Dual{A,<:AbstractFloat}}
Ω = prevfloat(typemax(T))
ξ, η = reim(z)
if isnan(ξ) && η==0 return Complex(ξ, η) end
if 4*abs(ξ) > asinh(Ω) #Overflow?
Complex(copysign(one(T),ξ),
copysign(zero(T),η*(isfinite(η) ? sin(2*abs(η)) : one(η))))
else
t = tan(η)
β = 1+t*t #sec(η)^2
s = sinh(ξ)
ρ = sqrt(1 + s*s) #cosh(ξ)
if isinf(t)
Complex(ρ/s,1/t)
else
Complex(β*ρ*s,t)/(1+β*s*s)
end
end
end

_convert(T, x::Dual) = convert(T, x.value)
function Base._cpow(z::Union{Dual{A,T}, Complex{<:Dual{A,T}}}, p::Union{Dual{B,T}, Complex{<:Dual{B,T}}}) where {T,A,B}
if isreal(p)
pᵣ = real(p)
if isinteger(pᵣ) && abs(pᵣ) < typemax(Int32)
# |p| < typemax(Int32) serves two purposes: it prevents overflow
# when converting p to Int, and it also turns out to be roughly
# the crossover point for exp(p*log(z)) or similar to be faster.
if iszero(pᵣ) # fix signs of imaginary part for z^0
zer = flipsign(copysign(zero(T),pᵣ), imag(z))
return Complex(one(T), zer)
end
ip = _convert(Int, pᵣ)
if isreal(z)
zᵣ = real(z)
if ip < 0
iszero(z) && return Complex(T(NaN),T(NaN))
re = Base.power_by_squaring(inv(zᵣ), -ip)
im = -imag(z)
else
re = Base.power_by_squaring(zᵣ, ip)
im = imag(z)
end
# slightly tricky to get the correct sign of zero imag. part
return Complex(re, ifelse(iseven(ip) & signbit(zᵣ), -im, im))
else
return ip < 0 ? Base.power_by_squaring(inv(z), -ip) : Base.power_by_squaring(z, ip)
end
elseif isreal(z)
# (note: if both z and p are complex with ±0.0 imaginary parts,
# the sign of the ±0.0 imaginary part of the result is ambiguous)
if iszero(real(z))
return pᵣ > 0 ? complex(z) : Complex(T(NaN),T(NaN)) # 0 or NaN+NaN*im
elseif real(z) > 0
return Complex(real(z)^pᵣ, z isa Real ? ifelse(real(z) < 1, -imag(p), imag(p)) : flipsign(imag(z), pᵣ))
else
zᵣ = real(z)
rᵖ = (-zᵣ)^pᵣ
if isfinite(pᵣ)
# figuring out the sign of 0.0 when p is a complex number
# with zero imaginary part and integer/2 real part could be
# improved here, but it's not clear if it's worth it…
return rᵖ * complex(cospi(pᵣ), flipsign(sinpi(pᵣ),imag(z)))
else
iszero(rᵖ) && return zero(Complex{T}) # no way to get correct signs of 0.0
return Complex(T(NaN),T(NaN)) # non-finite phase angle or NaN input
end
end
else
rᵖ = abs(z)^pᵣ
ϕ = pᵣ*angle(z)
end
elseif isreal(z)
iszero(z) && return real(p) > 0 ? complex(z) : Complex(T(NaN),T(NaN)) # 0 or NaN+NaN*im
zᵣ = real(z)
pᵣ, pᵢ = reim(p)
if zᵣ > 0
rᵖ = zᵣ^pᵣ
ϕ = pᵢ*log(zᵣ)
else
r = -zᵣ
θ = copysign(T(π),imag(z))
rᵖ = r^pᵣ * exp(-pᵢ*θ)
ϕ = pᵣ*θ + pᵢ*log(r)
end
else
pᵣ, pᵢ = reim(p)
r = abs(z)
θ = angle(z)
rᵖ = r^pᵣ * exp(-pᵢ*θ)
ϕ = pᵣ*θ + pᵢ*log(r)
end

if isfinite(ϕ)
return rᵖ * cis(ϕ)
else
iszero(rᵖ) && return zero(Complex{T}) # no way to get correct signs of 0.0
return Complex(T(NaN),T(NaN)) # non-finite phase angle or NaN input
end
end

function Base.ssqs(x::T, y::T) where T<:Dual
k::Int = 0
ρ = x*x + y*y
if !isfinite(ρ) && (isinf(x) || isinf(y))
ρ = convert(T, Inf)
elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2)
m::T = max(abs(x), abs(y))
k = m==0 ? m : exponent(m)
xk, yk = ldexp(x,-k), ldexp(y,-k)
ρ = xk*xk + yk*yk
end
ρ, k
end

function Base.sqrt(z::Complex{T}) where {T<:Dual{<:Any,<:AbstractFloat}}
x, y = reim(z)
if x==y==0
return Complex(zero(x),y)
end
ρ, k::Int = Base.ssqs(x, y)
if isfinite(x) ρ=_ldexp(abs(x),-k)+sqrt(ρ) end
if isodd(k)
k = div(k-1,2)
else
k = div(k,2)-1
ρ += ρ
end
ρ = _ldexp(sqrt(ρ),k) #sqrt((abs(z)+abs(x))/2) without over/underflow
ξ = ρ
η = y
if ρ != 0
if isfinite(η) η=(η/ρ)/2 end
if x<0
ξ = abs(η)
η = copysign(ρ,y)
end
end
Complex(ξ,η)
end

# TODO: polish this ldexp function.
function _ldexp(x::T, e::Integer) where T<:Dual
if e >=0
x * (1<<e)
else
x / (1<<-e)
end
end
56 changes: 56 additions & 0 deletions test/ComplexTest.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
module ComplexTest
using ForwardDiff: Dual
GiggleLiu marked this conversation as resolved.
Show resolved Hide resolved
using Test, ForwardDiff

function numeric_jacobian_complex(f, args::T...; δ=1e-5, kwargs...) where T<:Complex
n = length(args)
J = zeros(2, 2n)
largs = [args...]
for i=1:n
# perturb real
largs[i] += δ/2
pos = f(largs...; kwargs...)
largs[i] -= δ
neg = f(largs...; kwargs...)
largs[i] += δ/2
J[1,2i-1] = (real(pos) - real(neg))/δ
J[2,2i-1] = (imag(pos) - imag(neg))/δ
# perturb real
largs[i] += δ/2*im
pos = f(largs...; kwargs...)
largs[i] -= δ*im
neg = f(largs...; kwargs...)
largs[i] += δ/2*im
J[1,2i] = (real(pos) - real(neg))/δ
J[2,2i] = (imag(pos) - imag(neg))/δ
end
return J
end

function complex_jacobian_wrapper(f)
function newf(params)
newargs = [Complex(params[2i-1], params[2i]) for i=1:length(params)÷2]
res = f(newargs...)
[real(res), imag(res)]
end
end

function check_complex_jacobian(f, args...; kwargs...)
nj = numeric_jacobian_complex(f, args...; δ=1e-5, kwargs...)
params = vcat([[x.re, x.im] for x in args]...)
fj = ForwardDiff.jacobian(complex_jacobian_wrapper(f), params)
@test isapprox(nj, fj, atol=1e-5)
end

@testset "complex instructions" begin
for OP in [+, *, /, -, ^]
println(" ...testing Complex Valued $OP")
check_complex_jacobian(OP, 4.0+2im, 2.0+1im)
end
for OP in [abs, abs2, real, imag, conj, adjoint, sin, cos, tan,
sinh, cosh, tanh, exp, log, angle, x->x^3, x->x^0.5, sqrt]
println(" ...testing Complex Valued $OP")
check_complex_jacobian(OP, 4.0+2im)
end
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ println("done (took $t seconds).")
println("Testing miscellaneous functionality...")
t = @elapsed include("MiscTest.jl")
println("done (took $t seconds).")

println("Testing complex numbers...")
t = @elapsed include("ComplexTest.jl")
println("done (took $t seconds).")