Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.1.0"
[deps]
Cassette = "7057c7e9-c182-5462-911a-8362d720325c"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"

Expand Down
144 changes: 75 additions & 69 deletions src/dual_context.jl
Original file line number Diff line number Diff line change
@@ -1,50 +1,39 @@
using Cassette

import Cassette: overdub
using ChainRules

import Cassette: overdub, Context, nametype, canrecurse
import ForwardDiff: Dual, value, partials, Partials, tagtype, dualtag

Cassette.@context DualContext

using ForwardDiff: Dual, value, partials, Partials,
tagtype, ≺, DualMismatchError
include("tag.jl")

@inline _dominant_dual(tag, maxi, i) = maxi
const TaggedCtx{T} = Context{nametype(DualContext), T}

Base.@pure @inline function _dominant_dual(::Val{T}, maxi, i, x::Dual{S}, tail...) where {S, T}
if T === nothing || (T !== S && T ≺ S)
_dominant_dual(Val{S}(), i, i-1, tail...)
else
_dominant_dual(Val{T}(), maxi, i-1, tail...)
end
function dualcontext()
Cassette.disablehooks(DualContext(metadata=dualtag()))
end

Base.@pure @inline function _dominant_dual(tag, maxi, i, x, tail...)
_dominant_dual(tag, maxi, i-1, tail...)
# Calls to `dualtag` are aware of the current context.
# Note that the tags produced in the current context are Tag{T} where T is the metadata type of the context
@inline function overdub(ctx::C, ::typeof(dualtag)) where {T, C <: TaggedCtx{T}}
Tag{T}()
end

@inline function dominant_dual(xs...)
_dominant_dual(Val{nothing}(), 0, length(xs), reverse(xs)...)
@inline function overdub(ctx::TaggedCtx, ::typeof(find_dual), args...)
find_dual(args...)
end

@inline _value(::Val{T}, x) where T = x
@inline _value(::Val{T}, d::Dual{T}) where T = value(d)
@inline function _value(::Val{T}, d::Dual{S}) where {T,S}
if S ≺ T
d
else
throw(DualMismatchError(T,S))
end
@inline function overdub(ctx::TaggedCtx, ::typeof(canrecurse), args...)
canrecurse(args...)
end

@inline Base.@propagate_inbounds _partials(::Val{T}, x, i...) where T = partials(x, i...)
@inline Base.@propagate_inbounds _partials(::Val{T}, d::Dual{T}, i...) where T = partials(d, i...)
@inline function _partials(::Val{T}, d::Dual{S}, i...) where {T,S}
if S ≺ T
zero(d)
else
throw(DualMismatchError(T,S))
end
end
@inline _value(::Any, x) = x
@inline _value(::T, d::Dual{T}) where T = value(d)


@inline _partials(::Any, x, i...) = partials(x, i...)
@inline _partials(::T, d::Dual{T}, i...) where T = partials(d, i...)
@inline _partials(::T, x::Dual{S}, i...) where {T,S} = partials(zero(Dual{Tag{T}}), i...)

using ChainRules
using ChainRulesCore
Expand All @@ -59,54 +48,71 @@ end
ChainRulesCore.mul_zero(::Zero, p::Partials) = zero(p)
ChainRulesCore.mul_zero(p::Partials, ::Zero) = zero(p)

@inline _values(S, xs) = map(x->_value(S, x), xs)
@inline _partialss(S, xs) = map(x->_partials(S, x), xs)

@inline overdub(ctx::DualContext, f, a...) = Cassette.recurse(ctx, f, a...)
@inline overdub(ctx::DualContext, f, a) = _overdub(ctx, f, a)
@inline overdub(ctx::DualContext, f, a, b) = _overdub(ctx, f, a, b)
@inline overdub(ctx::DualContext, f, a, b, c) = _overdub(ctx, f, a, b, c)
@inline overdub(ctx::DualContext, f, a, b, c, d) = _overdub(ctx, f, a, b, c, d)
@inline function _frule_overdub(ctx, tag::T, f, args...) where T
res = Cassette.recurse(ctx, frule, f, _values(tag, args)...)

@inline function _overdub(ctx, f, args...)
# find the position of the dual number with the highest
# precedence (dominant) tag
idx = dominant_dual(args...)
if res === nothing
# this means there is no frule (majority of all calls)
return Cassette.recurse(ctx, f, args...)
else
# this means a result and one or more partial function
# was computed
vals, ∂s = res
ps = _partialss(tag, args)

if !(∂s isa Tuple)
# a single function (scalar output)
d = overdub(ctx, ∂s, ps...)
return Dual{T}(vals, d)
else
# many partial functions (as many as outputs)
return map(vals, ∂s) do val, ∂
Dual{T}(val, overdub(ctx, ∂, ps...))
end
end
end
end

@inline anydual(x::Dual, ys...) = true
@inline anydual(x, ys...) = anydual(ys...)
@inline anydual() = false

@inline function overdub(ctx::TaggedCtx{T}, f, args...) where {T}
if nfields(args) > 4
return Cassette.recurse(ctx, f, args...)
end

if !Cassette.canrecurse(ctx, f, args...)
return Cassette.fallback(ctx, f, args...)
end
# a short-cut for compilation to cop out if this
# call doesn't deal with any dual numbers
if !anydual(args...)
return Cassette.recurse(ctx, f, args...)
end

# find the position of the dual number with the current
# context's tag or a child tag.
idx = find_dual(Tag{T}(), args...)
if idx === 0
# none of the arguments are dual
Cassette.recurse(ctx, f, args...)
return Cassette.recurse(ctx, f, args...)
else
# most dominant tag on the duals
dtag = tagtype(args[idx])

# We may now start operating for a completely
# different tag -- this is OK.
tag = tagtype(fieldtype(typeof(args), idx))()
# call ChainRules.frule to execute `f` and
# get a function that computes the partials
res = overdub(ctx, frule, f,
map(x->_value(Val{dtag}(), x), args)...)

if res === nothing
# this means there is no frule (majority of all calls)
return Cassette.recurse(ctx, f, args...)
else
# this means a result and one or more partial function
# was computed
vals, ∂s = res
ps = map(x->_partials(Val{dtag}(), x), args)

if !(∂s isa Tuple)
# a single function scalar output
d = overdub(ctx, ∂s, ps...)
return Dual{dtag}(vals, d)
else
# many partial functions (as many as outputs)
return map(vals, ∂s) do val, ∂
Dual{dtag}(val, overdub(ctx, ∂, ps...))
end
end
end
return _frule_overdub(ctx, tag, f, args...)
end
end

function dualrun(f, args...)
ctx = DualContext()
ctx = dualcontext()
Cassette.overdub(ctx, f, args...)
end

Expand Down
17 changes: 17 additions & 0 deletions src/tag.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import ForwardDiff: Dual

struct Tag{Parent} end

_find_dual(ctx::Type{T}, i) where {T} = 0
_find_dual(ctx::Type{T}, i, x::Type{<:Dual{T}}, xs...) where {T} = i
_find_dual(ctx::Type{T}, i, x, xs...) where {T} = _find_dual(ctx, i-1, xs...)

innertagtype(::Type{Tag{T}}) where T = T
@inline @generated function find_dual(T::Tag, xs...)
idx = _find_dual(T, length(xs), reverse(xs)...)
idx === 0 ?
_find_dual(innertagtype(T), length(xs), reverse(xs)...) : idx
end

# Base case where T is not a Tag
@inline @generated find_dual(T, xs...) = 0
7 changes: 1 addition & 6 deletions test/dualtest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module DualTest
using Test
using Random
using ForwardDiff
using ForwardDiff: Partials, Dual, value, partials
using ForwardDiff: Partials, Dual, value, partials, tagtype

using Cassette

Expand All @@ -25,8 +25,6 @@ end

import Calculus

struct TestTag end

samerng() = MersenneTwister(1)

# By lower-bounding the Int range at 2, we avoid cases where differentiating an
Expand All @@ -38,9 +36,6 @@ dual_isapprox(a, b) = isapprox(a, b)
dual_isapprox(a::Dual{T,T1,T2}, b::Dual{T,T3,T4}) where {T,T1,T2,T3,T4} = isapprox(value(a), value(b)) && isapprox(partials(a), partials(b))
dual_isapprox(a::Dual{T,T1,T2}, b::Dual{T3,T4,T5}) where {T,T1,T2,T3,T4,T5} = error("Tags don't match")

ForwardDiff.:≺(::Type{TestTag()}, ::Int) = true
ForwardDiff.:≺(::Int, ::Type{TestTag()}) = false

for N in (0,3), M in (1,4), V in (Int, Float32)
println(" ...testing Dual{TestTag(),$V,$N} and Dual{TestTag(),Dual{TestTag(),$V,$M},$N}")

Expand Down
51 changes: 51 additions & 0 deletions test/nested.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using Test
using ForwardDiff2: Dual, partials, dualrun, find_dual

const Tag1 = Tag{Nothing}
const Tag2 = Tag{Tag1}

@testset "find_dual" begin
@test find_dual(Tag1, 0, 0) == 0
@test find_dual(Tag1, Dual{Tag1}(1,1), 1) == 1
@test find_dual(Tag1, 1, Dual{Tag1}(1,1)) == 2
@test find_dual(Tag2, 1, Dual{Tag1}(1,1)) == 2
@test find_dual(Tag2, Dual{Tag1}(1,1), 1) == 1
@test find_dual(Tag1, Dual{Tag2}(1,1), 1) == 0

@test find_dual(Tag2, Dual{Tag1}, 1) == 0
@test find_dual(Tag1, Dual{Tag2}, 1) == 0
@test find_dual(Tag1, Dual{Tag1}, 1) == 0
end

using Cassette
Cassette.@context TestCtx

const TaggedTestCtx{T} = Cassette.Context{Cassette.nametype(TestCtx), T}

@inline function find_dual_ctx(::TaggedTestCtx{T}, args...) where T
find_dual(Tag{T}, args...)
end

@testset "find_dual_ctx" begin
ctx = TaggedTestCtx(metadata=nothing)
@test find_dual_ctx(ctx, 1, Dual{Tag1}(1,1)) == 2
@test find_dual_ctx(ctx, Dual{Tag1}(1,1), 1) == 1
@test find_dual_ctx(ctx, Dual{Tag1}(1,1), Dual{Tag2}(1,1)) == 1
@test find_dual_ctx(ctx, Dual{Tag2}(1,1), Dual{Tag1}(1,1)) == 2

ctx = TaggedTestCtx(metadata=Tag2())
@test find_dual_ctx(ctx, Dual{Tag1}(1,1), Dual{Tag2}(1,1)) == 2
@test find_dual_ctx(ctx, Dual{Tag2}(1,1), Dual{Tag1}(1,1)) == 1
end

function D(f, x)
dualrun() do
xx = Dual(x, one(x))
partials(f(xx), 1)
end
end


@testset "nested differentiation" begin
@test D(x -> x * D(y -> x + y, 1), 1) === 1
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
@time include("dualarray.jl")
@time include("nested.jl")
@time include("jacobian.jl")