Skip to content

Commit

Permalink
enable one-sided nested differentiation for binary Dual functions (#166)
Browse files Browse the repository at this point in the history
* enable single-sided nested differentiation for binary Dual functions

* partially fix exp perf regression

* get tests passing (don't check for currently failing SIMD instruction, see #167)

* try to fix v0.4 ambiguity warnings
  • Loading branch information
jrevels committed Dec 5, 2016
1 parent 9fb6582 commit 25b93c2
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 40 deletions.
6 changes: 4 additions & 2 deletions src/config.jl
Expand Up @@ -4,6 +4,8 @@ abstract AbstractConfig
# Config #
###########

@inline chunksize(::Tuple{}) = error("empty tuple passed to `chunksize`")

# Define a few different AbstractConfig types. All these types share the same structure,
# but feature different constructors and dispatch restrictions in downstream code.
for Config in (:GradientConfig, :JacobianConfig)
Expand All @@ -30,8 +32,8 @@ for Config in (:GradientConfig, :JacobianConfig)
Base.copy{N,T,D}(cfg::$Config{N,T,D}) = $Config{N,T,D}(cfg.seeds, copy(cfg.duals))
Base.copy{N,T,D<:Tuple}(cfg::$Config{N,T,D}) = $Config{N,T,D}(cfg.seeds, map(copy, cfg.duals))

chunksize{N}(::$Config{N}) = N
chunksize{N}(::Tuple{Vararg{$Config{N}}}) = N
@inline chunksize{N}(::$Config{N}) = N
@inline chunksize{N}(::Tuple{Vararg{$Config{N}}}) = N
end
end

Expand Down
99 changes: 64 additions & 35 deletions src/dual.jl
@@ -1,3 +1,5 @@
const ExternalReal = Union{subtypes(Real)...}

########
# Dual #
########
Expand Down Expand Up @@ -54,13 +56,40 @@ degree{N,T}(::Type{Dual{N,T}}) = 1 + degree(T)

macro ambiguous(ex)
def = ex.head == :macrocall ? ex.args[2] : ex
f = def.args[1].args[1].args[1]
sig = def.args[1]
body = def.args[2]
f = isa(sig.args[1], Expr) && sig.args[1].head == :curly ? sig.args[1].args[1] : sig.args[1]
a, b = sig.args[2].args[1], sig.args[3].args[1]
Ta, Tb = sig.args[2].args[2], sig.args[3].args[2]
if isa(a, Symbol) && isa(b, Symbol) && isa(Ta, Symbol) && isa(Tb, Symbol)
if Ta == :Real && Tb == :Dual
return quote
@inline $(f){A<:ExternalReal,B<:Dual}($(a)::Dual{0,A}, $(b)::Dual{0,B}) = Dual($(f)(value(a), value(b)))
@inline $(f){M,A<:ExternalReal,B<:Dual}($(a)::Dual{0,A}, $(b)::Dual{M,B}) = $(f)(value(a), b)
@inline $(f){N,A<:ExternalReal,B<:Dual}($(a)::Dual{N,A}, $(b)::Dual{0,B}) = $(f)(a, value(b))
@inline $(f){N,A<:ExternalReal,B<:Dual}($(a)::Dual{N,A}, $(b)::Dual{N,B}) = $(body)
@inline $(f){N,M,A<:ExternalReal,B<:Dual}($(a)::Dual{N,A}, $(b)::Dual{M,B}) = $(body)
$(esc(ex))
end
elseif Ta == :Dual && Tb == :Real
return quote
@inline $(f){A<:Dual,B<:ExternalReal}($(a)::Dual{0,A}, $(b)::Dual{0,B}) = Dual($(f)(value(a), value(b)))
@inline $(f){M,A<:Dual,B<:ExternalReal}($(a)::Dual{0,A}, $(b)::Dual{M,B}) = $(f)(value(a), b)
@inline $(f){N,A<:Dual,B<:ExternalReal}($(a)::Dual{N,A}, $(b)::Dual{0,B}) = $(f)(a, value(b))
@inline $(f){N,A<:Dual,B<:ExternalReal}($(a)::Dual{N,A}, $(b)::Dual{N,B}) = $(body)
@inline $(f){N,M,A<:Dual,B<:ExternalReal}($(a)::Dual{N,A}, $(b)::Dual{M,B}) = $(body)
$(esc(ex))
end
else
return esc(ex)
end
end
return quote
$(f)(a::Dual, b::Dual) = error("npartials($(typeof(a))) != npartials($(typeof(b)))")
@inline $(f){N,M,A<:Real,B<:Real}(a::Dual{N,A}, b::Dual{M,B}) = error("npartials($(typeof(a))) != npartials($(typeof(b)))")
if !(in($f, (isequal, ==, isless, <, <=, <)))
$(f)(a::Dual{0}, b::Dual{0}) = Dual($(f)(value(a), value(b)))
$(f)(a::Dual{0}, b::Dual) = $(f)(value(a), b)
$(f)(a::Dual, b::Dual{0}) = $(f)(a, value(b))
@inline $(f){A<:Real,B<:Real}(a::Dual{0,A}, b::Dual{0,B}) = Dual($(f)(value(a), value(b)))
@inline $(f){M,A<:Real,B<:Real}(a::Dual{0,A}, b::Dual{M,B}) = $(f)(value(a), b)
@inline $(f){N,A<:Real,B<:Real}(a::Dual{N,A}, b::Dual{0,B}) = $(f)(a, value(b))
end
$(esc(ex))
end
Expand Down Expand Up @@ -160,7 +189,6 @@ Base.promote_rule{N,A<:Real,B<:Real}(::Type{Dual{N,A}}, ::Type{B}) = Dual{N,prom
Base.promote_rule{N,A<:Real,B<:Real}(::Type{A}, ::Type{Dual{N,B}}) = Dual{N,promote_type(A, B)}

Base.convert(::Type{Dual}, n::Dual) = n
Base.convert{N1,N2,T<:Real}(D::Type{Dual{N1,T}}, n::Dual{N2}) = error("can't convert $(typeof(n)) to $(D)")
Base.convert{N,T<:Real}(::Type{Dual{N,T}}, n::Dual{N}) = Dual(convert(T, value(n)), convert(Partials{N,T}, partials(n)))
Base.convert{D<:Dual}(::Type{D}, n::D) = n
Base.convert{N,T<:Real}(::Type{Dual{N,T}}, x::Real) = Dual(convert(T, x), zero(Partials{N,T}))
Expand All @@ -181,12 +209,12 @@ Base.float{N,T}(n::Dual{N,T}) = Dual{N,promote_type(T, Float16)}(n)
#----------------------#

@ambiguous @inline @compat(Base.:+){N}(n1::Dual{N}, n2::Dual{N}) = Dual(value(n1) + value(n2), partials(n1) + partials(n2))
@inline @compat(Base.:+)(n::Dual, x::Real) = Dual(value(n) + x, partials(n))
@inline @compat(Base.:+)(x::Real, n::Dual) = n + x
@ambiguous @inline @compat(Base.:+)(n::Dual, x::Real) = Dual(value(n) + x, partials(n))
@ambiguous @inline @compat(Base.:+)(x::Real, n::Dual) = n + x

@ambiguous @inline @compat(Base.:-){N}(n1::Dual{N}, n2::Dual{N}) = Dual(value(n1) - value(n2), partials(n1) - partials(n2))
@inline @compat(Base.:-)(n::Dual, x::Real) = Dual(value(n) - x, partials(n))
@inline @compat(Base.:-)(x::Real, n::Dual) = Dual(x - value(n), -(partials(n)))
@ambiguous @inline @compat(Base.:-)(n::Dual, x::Real) = Dual(value(n) - x, partials(n))
@ambiguous @inline @compat(Base.:-)(x::Real, n::Dual) = Dual(x - value(n), -(partials(n)))
@inline @compat(Base.:-)(n::Dual) = Dual(-(value(n)), -(partials(n)))

# Multiplication #
Expand All @@ -200,8 +228,8 @@ Base.float{N,T}(n::Dual{N,T}) = Dual{N,promote_type(T, Float16)}(n)
return Dual(v1 * v2, _mul_partials(partials(n1), partials(n2), v2, v1))
end

@inline @compat(Base.:*)(n::Dual, x::Real) = Dual(value(n) * x, partials(n) * x)
@inline @compat(Base.:*)(x::Real, n::Dual) = n * x
@ambiguous @inline @compat(Base.:*)(n::Dual, x::Real) = Dual(value(n) * x, partials(n) * x)
@ambiguous @inline @compat(Base.:*)(x::Real, n::Dual) = n * x

# Division #
#----------#
Expand All @@ -211,45 +239,41 @@ end
return Dual(v1 / v2, _div_partials(partials(n1), partials(n2), v1, v2))
end

@inline function @compat(Base.:/)(x::Real, n::Dual)
@ambiguous @inline function @compat(Base.:/)(x::Real, n::Dual)
v = value(n)
divv = x / v
return Dual(divv, -(divv / v) * partials(n))
end

@inline @compat(Base.:/)(n::Dual, x::Real) = Dual(value(n) / x, partials(n) / x)
@ambiguous @inline @compat(Base.:/)(n::Dual, x::Real) = Dual(value(n) / x, partials(n) / x)

# Exponentiation #
#----------------#

for f in (macroexpand(:(@compat(Base.:^))), :(NaNMath.pow))
@eval begin
@ambiguous @inline function ($f){N}(n1::Dual{N}, n2::Dual{N})
if isconstant(n2)
return $(f)(n1, value(n2))
else
v1, v2 = value(n1), value(n2)
expv = ($f)(v1, v2)
powval = v2 * ($f)(v1, v2 - 1)
logval = expv * log(v1)
new_partials = _mul_partials(partials(n1), partials(n2), powval, logval)
return Dual(expv, new_partials)
end
v1, v2 = value(n1), value(n2)
expv = ($f)(v1, v2)
powval = v2 * ($f)(v1, v2 - 1)
logval = isconstant(n2) ? one(expv) : expv * log(v1)
new_partials = _mul_partials(partials(n1), partials(n2), powval, logval)
return Dual(expv, new_partials)
end

@inline ($f)(::Base.Irrational{:e}, n::Dual) = exp(n)
end

for T in (:Integer, :Rational, :Real)
@eval begin
@inline function ($f)(n::Dual, x::$(T))
@ambiguous @inline function ($f)(n::Dual, x::$(T))
v = value(n)
expv = ($f)(v, x)
deriv = x * ($f)(v, x - 1)
return Dual(expv, deriv * partials(n))
end

@inline function ($f)(x::$(T), n::Dual)
@ambiguous @inline function ($f)(x::$(T), n::Dual)
v = value(n)
expv = ($f)(x, v)
deriv = expv*log(x)
Expand Down Expand Up @@ -334,14 +358,19 @@ end
return Dual(h, (vx/h) * partials(x) + (vy/h) * partials(y) + (vz/h) * partials(z))
end

@inline Base.hypot{N}(x::Dual{N}, y::Dual{N}) = calc_hypot(x, y)
@inline Base.hypot(x::Dual, y::Real) = calc_hypot(x, y)
@inline Base.hypot(x::Real, y::Dual) = calc_hypot(x, y)
@ambiguous @inline Base.hypot{N}(x::Dual{N}, y::Dual{N}) = calc_hypot(x, y)
@ambiguous @inline Base.hypot(x::Dual, y::Real) = calc_hypot(x, y)
@ambiguous @inline Base.hypot(x::Real, y::Dual) = calc_hypot(x, y)

for A in (:(Dual{N}), :Real), B in (:(Dual{N}), :Real), C in (:(Dual{N}), :Real)
(A == B == C == :Real) && continue
@eval(@inline Base.hypot{N}(x::$A, y::$B, z::$C) = calc_hypot(x, y, z))
end
@inline Base.hypot(x::Dual, y::Dual, z::Dual) = calc_hypot(x, y, z)

@inline Base.hypot(x::Real, y::Dual, z::Dual) = calc_hypot(x, y, z)
@inline Base.hypot(x::Dual, y::Real, z::Dual) = calc_hypot(x, y, z)
@inline Base.hypot(x::Dual, y::Dual, z::Real) = calc_hypot(x, y, z)

@inline Base.hypot(x::Dual, y::Real, z::Real) = calc_hypot(x, y, z)
@inline Base.hypot(x::Real, y::Dual, z::Real) = calc_hypot(x, y, z)
@inline Base.hypot(x::Real, y::Real, z::Dual) = calc_hypot(x, y, z)

@inline sincos(n) = (sin(n), cos(n))

Expand All @@ -362,8 +391,8 @@ end
end

@ambiguous @inline Base.atan2{N}(y::Dual{N}, x::Dual{N}) = calc_atan2(y, x)
@inline Base.atan2(y::Real, x::Dual) = calc_atan2(y, x)
@inline Base.atan2(y::Dual, x::Real) = calc_atan2(y, x)
@ambiguous @inline Base.atan2(y::Real, x::Dual) = calc_atan2(y, x)
@ambiguous @inline Base.atan2(y::Dual, x::Real) = calc_atan2(y, x)

###################
# Pretty Printing #
Expand Down
4 changes: 3 additions & 1 deletion test/MiscTest.jl
Expand Up @@ -54,7 +54,9 @@ test_nested_jacobian_output = [-sin(1) 0.0 0.0;
-0.0 -0.0 -0.0;
-0.0 -0.0 -sin(3)]

@test_approx_eq ForwardDiff.jacobian(x -> ForwardDiff.jacobian(sin, x), [1, 2, 3]) test_nested_jacobian_output
sin_jacobian = x -> ForwardDiff.jacobian(y -> broadcast(sin, y), x)

@test_approx_eq ForwardDiff.jacobian(sin_jacobian, [1., 2., 3.]) test_nested_jacobian_output

# Issue #59 example #
#-------------------#
Expand Down
7 changes: 5 additions & 2 deletions test/SIMDTest.jl
@@ -1,7 +1,7 @@
module SIMDTest

using Base.Test
using ForwardDiff: Dual
using ForwardDiff: Dual, valtype

const DUALS = (Dual(1., 2., 3., 4.),
Dual(1., 2., 3., 4., 5.),
Expand Down Expand Up @@ -33,7 +33,10 @@ for D in map(typeof, DUALS)

exp_bitcode = sprint(io -> code_llvm(io, ^, (D, D)))
@test ismatch(r"fadd \<.*?x double\>", exp_bitcode)
@test ismatch(r"fmul \<.*?x double\>", exp_bitcode)
if !(valtype(D) <: Dual)
# see https://github.com/JuliaDiff/ForwardDiff.jl/issues/167
@test ismatch(r"fmul \<.*?x double\>", exp_bitcode)
end

sum_bitcode = sprint(io -> code_llvm(io, simd_sum, (Vector{D},)))
@test ismatch(r"fadd \<.*?x double\>", sum_bitcode)
Expand Down

0 comments on commit 25b93c2

Please sign in to comment.