Implement backend for ReverseDiff #29
Conversation
Codecov Report
@@            Coverage Diff             @@
##           master      #29      +/-   ##
==========================================
+ Coverage   80.05%   81.90%   +1.85%
==========================================
  Files           2        3       +1
  Lines         396      431      +35
==========================================
+ Hits          317      353      +36
+ Misses         79       78       -1
Continue to review full report at Codecov.
ReverseDiff.deriv!.(y_tracked, ws...)
ReverseDiff.reverse_pass!(tape)
return ReverseDiff.deriv.(x_tracked)
end
Is this the recommended form? In DiffEqSensitivity we use a slightly different set of instructions, see
https://github.com/SciML/DiffEqSensitivity.jl/blob/8601419ad910455ef8842597a8a13c615768477f/src/derivative_wrappers.jl#L406
(including `unseed!`, `input_hook`, and `output_hook`).
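Roughly, the pattern in the linked DiffEqSensitivity code looks like the sketch below, adapted here to a single array input. `input_hook`, `output_hook`, `unseed!`, `value!`, `increment_deriv!`, `forward_pass!`, and `reverse_pass!` are undocumented ReverseDiff internals, so this is illustrative only, not a recommendation:

```julia
using ReverseDiff

f(x) = sin.(x) .* x
x = rand(3)
ȳ = ones(3)  # cotangent to pull back

# Build the tape once; it can then be re-seeded and reused for many pullbacks.
tape = ReverseDiff.JacobianTape(f, x)
x_tracked = ReverseDiff.input_hook(tape)   # tracked input stored in the tape
y_tracked = ReverseDiff.output_hook(tape)  # tracked output stored in the tape

ReverseDiff.unseed!(x_tracked)              # clear derivatives left from earlier calls
ReverseDiff.value!(x_tracked, x)            # load fresh primal values into the tape
ReverseDiff.forward_pass!(tape)             # recompute primal values along the tape
ReverseDiff.increment_deriv!(y_tracked, ȳ)  # seed the output with the cotangent
ReverseDiff.reverse_pass!(tape)             # propagate derivatives back to the input
x̄ = ReverseDiff.deriv(x_tracked)            # the pullback result
```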
Perhaps; I wrote this a few months back and don't recall the reasons. Thanks for the reference, I'll check.
ReverseDiff is a little unclear about what constitutes its API: there's its documented API, and there's the api source directory. Some functions are declared in the latter but not documented (e.g. `input_hook` and `output_hook`), so they may be part of the API, while `unseed!`, for example, does not appear to be part of the API at all. Can you comment on why non-API functions were used in DiffEqSensitivity?
I am not sure. Probably to support some more general cases than `seeded_reverse_pass!` does? CC @ChrisRackauckas
It's to make it non-allocating and more generic. I got those hooks directly from Jarrett.
Non-allocating would only be relevant for a mutating `f!`, right? Currently I don't think mutating functions `f!` are supported by the AD API.
RE tape compilation, I can think of 3 ways to do it, each with its own level of support. Note that tape compilation requires the function and inputs to be available.

First, a user can provide the function …

Second, a user can specify a …

Third, as Turing does, the compiled tape is memoized with Memoization.jl based on the function and the size and type of the inputs. I think this would also allow for …
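To make option 3 concrete, here is a minimal sketch of the memoization approach, assuming it is enough to key on the function plus the eltype and size of the input; `_compiled_tape` and `memoized_gradient` are illustrative names, not part of this PR:

```julia
using Memoization, ReverseDiff

# Cache one compiled tape per (function, input eltype, input size) combination.
@memoize function _compiled_tape(f, T, sz)
    return ReverseDiff.compile(ReverseDiff.GradientTape(f, zeros(T, sz)))
end

function memoized_gradient(f, x::AbstractArray)
    tape = _compiled_tape(f, eltype(x), size(x))
    return ReverseDiff.gradient!(tape, x)  # reuse the cached compiled tape
end

memoized_gradient(x -> sum(sin, x), randn(10))  # compiles once, then reuses the tape
```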
@devmotion do you have other ideas regarding supporting tape compilation?
I'm not sure. I've seen some strange issues with Memoization (e.g. TuringLang/Turing.jl#1418); in general it feels a bit obscure and like hiding something that maybe should just be more explicit. Is it possible to add AD-specific keyword arguments to the API? Then another alternative could be to provide a user-facing function that returns a compiled tape (basically …)
I am a fan of option 1. We can provide an additional function …
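A minimal sketch of what option 1 could look like, assuming a dedicated backend struct that stores the compiled tape; `ReverseDiffCompiledBackend` and `compiled_reversediff_backend` are placeholder names, not an API proposal:

```julia
using ReverseDiff
using AbstractDifferentiation: AbstractBackend

# A backend that carries a tape compiled for one specific function and input shape.
struct ReverseDiffCompiledBackend{T} <: AbstractBackend
    tape::T
end

# User-facing helper: the user supplies the function and example input up front,
# so the tape can be compiled once and reused by the backend afterwards.
function compiled_reversediff_backend(f, x::AbstractArray)
    tape = ReverseDiff.compile(ReverseDiff.GradientTape(f, x))
    return ReverseDiffCompiledBackend(tape)
end
```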
@mohammed perhaps the user can just construct the object without providing a compiled tape. Then AD calls the fallbacks at AbstractDifferentiation.jl/src/AbstractDifferentiation.jl, lines 236 to 244 (commit 7cf1724).
Hmm, I think we may need to define most of the value_and_xyz functions for the compiled RD backend, since that might be easier than trying to repurpose the current fallbacks to use the tape correctly instead of the function. (I am not sure though, they might just work. Will need to implement and test it to see.)
One problem with this is that for the backend to be able to store a new compiled tape, it cannot be specialized on the function. This causes e.g. the computation of the Jacobian to be type-unstable.
One idea is to store an example output in the backend and use type assertions.
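A sketch of that idea, with made-up names (`CompiledTapeBackend`, `typed_jacobian`); the point is only that the stored example output gives a type assertion something concrete to recover inferability with:

```julia
using ReverseDiff
using AbstractDifferentiation: AbstractBackend

# The tape field is deliberately untyped so a newly compiled tape can be stored
# without changing the backend's type; the example output exists only to supply
# a concrete element type for the assertion below.
mutable struct CompiledTapeBackend{Y} <: AbstractBackend
    tape::Any
    example_output::Y
end

function typed_jacobian(ba::CompiledTapeBackend, x::AbstractVector)
    J = ReverseDiff.jacobian!(ba.tape, x)
    # The assertion restores inferability lost through the untyped tape field.
    return J::Matrix{eltype(ba.example_output)}
end
```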
Alternatively, we can warn and fall back on the uncompiled-tape behaviour, asking the user to recompile a different tape.
Then we have the same issue with type-stability for …
This could work if we only support …
Either way, since it seems a ReverseDiff backend with a compiled tape will use a different struct than one without it, we can address tape compilation in a future PR. Now that we have an implementation with a …
ReverseDiff has no analogous convenience function to `derivative`
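For background on that point: a scalar derivative can still be recovered through the array-based `ReverseDiff.gradient` by wrapping and unwrapping the input, at the cost of an extra allocation per call, which is part of what the manual `InstructionTape` approach in the benchmark below avoids; `scalar_derivative` is just an illustrative name:

```julia
using ReverseDiff

# ReverseDiff differentiates array inputs only, so wrap the scalar in a
# one-element vector and unwrap the resulting one-element gradient.
scalar_derivative(f, x::Number) = only(ReverseDiff.gradient(v -> f(only(v)), [x]))

scalar_derivative(sin, 0.5)  # ≈ cos(0.5)
```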
I ran a benchmark (derived from the one in #20), which led me to make many of the same overloads as in #20. This PR is now ready for review.

Benchmark:

using AbstractDifferentiation: AbstractDifferentiation, AD, @primitive, AbstractBackend
import AbstractDifferentiation: primal_value, jacobian, derivative, gradient, hessian, value_and_gradient, value_and_hessian
using ReverseDiff: ReverseDiff, DiffResults
using BenchmarkTools, PrettyTables
f(x) = sum(sin, x)
g(x) = sin.(x)
x = randn(10)
Δx = randn(10)
Δy = randn(10)
@inline asarray(x) = [x]
@inline asarray(x::AbstractArray) = x
primal_value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x)
# just implement the primitive
struct ReverseDiffBackend <: AbstractBackend end
@primitive function jacobian(ba::ReverseDiffBackend, f, xs...)
xs_arr = map(asarray, xs)
tape = ReverseDiff.JacobianTape(xs_arr) do (xs_arr...)
xs_new = map(xs, xs_arr) do x, x_arr
return x isa Number ? only(x_arr) : x_arr
end
return asarray(f(xs_new...))
end
results = ReverseDiff.jacobian!(tape, xs_arr)
return map(xs, results) do x, result
return x isa Number ? vec(result) : result
end
end
# above plus a few obvious optimizations
struct ReverseDiffBackendPartial <: AbstractBackend end
@primitive function jacobian(ba::ReverseDiffBackendPartial, f, xs...)
xs_arr = map(asarray, xs)
tape = ReverseDiff.JacobianTape(xs_arr) do (xs_arr...)
xs_new = map(xs, xs_arr) do x, x_arr
return x isa Number ? only(x_arr) : x_arr
end
return asarray(f(xs_new...))
end
results = ReverseDiff.jacobian!(tape, xs_arr)
return map(xs, results) do x, result
return x isa Number ? vec(result) : result
end
end
function jacobian(ba::ReverseDiffBackendPartial, f, xs::AbstractArray...)
return ReverseDiff.jacobian(asarray ∘ f, xs)
end
function derivative(ba::ReverseDiffBackendPartial, f, xs::Number...)
tape = ReverseDiff.InstructionTape()
xs_tracked = ReverseDiff.TrackedReal.(xs, zero.(xs), Ref(tape))
y_tracked = f(xs_tracked...)
ReverseDiff.seed!(y_tracked)
ReverseDiff.reverse_pass!(tape)
return ReverseDiff.deriv.(xs_tracked)
end
function gradient(ba::ReverseDiffBackendPartial, f, xs::AbstractArray...)
return ReverseDiff.gradient(f, xs)
end
function hessian(ba::ReverseDiffBackendPartial, f, x::AbstractArray)
return (ReverseDiff.hessian(f, x),)
end
# above plus overloading some value_and_XXX functions
struct ReverseDiffBackendAll <: AbstractBackend end
@primitive function jacobian(ba::ReverseDiffBackendAll, f, xs...)
xs_arr = map(asarray, xs)
tape = ReverseDiff.JacobianTape(xs_arr) do (xs_arr...)
xs_new = map(xs, xs_arr) do x, x_arr
return x isa Number ? only(x_arr) : x_arr
end
return asarray(f(xs_new...))
end
results = ReverseDiff.jacobian!(tape, xs_arr)
return map(xs, results) do x, result
return x isa Number ? vec(result) : result
end
end
function jacobian(ba::ReverseDiffBackendAll, f, xs::AbstractArray...)
return ReverseDiff.jacobian(asarray ∘ f, xs)
end
function derivative(ba::ReverseDiffBackendAll, f, xs::Number...)
tape = ReverseDiff.InstructionTape()
xs_tracked = ReverseDiff.TrackedReal.(xs, zero.(xs), Ref(tape))
y_tracked = f(xs_tracked...)
ReverseDiff.seed!(y_tracked)
ReverseDiff.reverse_pass!(tape)
return ReverseDiff.deriv.(xs_tracked)
end
function gradient(ba::ReverseDiffBackendAll, f, xs::AbstractArray...)
return ReverseDiff.gradient(f, xs)
end
function hessian(ba::ReverseDiffBackendAll, f, x::AbstractArray)
return (ReverseDiff.hessian(f, x),)
end
function value_and_gradient(ba::ReverseDiffBackendAll, f, x::AbstractArray)
result = DiffResults.GradientResult(x)
cfg = ReverseDiff.GradientConfig(x)
ReverseDiff.gradient!(result, f, x, cfg)
return DiffResults.value(result), (DiffResults.derivative(result),)
end
function value_and_hessian(ba::ReverseDiffBackendAll, f, x)
result = DiffResults.HessianResult(x)
cfg = ReverseDiff.HessianConfig(result, x)
ReverseDiff.hessian!(result, f, x, cfg)
return DiffResults.value(result), (DiffResults.hessian(result),)
end
# benchmark
suite = BenchmarkGroup()
approaches = [
"jacobian" => ReverseDiffBackend(),
"jacobian partial" => ReverseDiffBackendPartial(),
"jacobian all" => ReverseDiffBackendAll(),
]
for (name, b) in approaches
bg = BenchmarkGroup()
bg["derivative(b,f,::Number)"] = @benchmarkable $(AD.derivative)($b, sin, 0.5)
bg["gradient(b,f,x)"] = @benchmarkable $(AD.gradient)($b, $f, $x)
bg["jacobian(b,f,x)"] = @benchmarkable $(AD.jacobian)($b, $f, $x)
bg["jacobian(b,g,x)"] = @benchmarkable $(AD.jacobian)($b, $g, $x)
bg["jacobian(b,f,::Number)"] = @benchmarkable $(AD.jacobian)($b, sin, 0.5)
bg["hessian"] = @benchmarkable $(AD.hessian)($b, $f, $x)
bg["value_and_gradient"] = @benchmarkable $(AD.value_and_gradient)($b, $f, $x)
bg["value_and_jacobian(b,f,x)"] = @benchmarkable $(AD.value_and_jacobian)($b, $f, $x)
bg["value_and_jacobian(b,g,x)"] = @benchmarkable $(AD.value_and_jacobian)($b, $g, $x)
bg["value_and_jacobian(b,f,::Number)"] = @benchmarkable $(AD.value_and_jacobian)($b, sin, 0.5)
bg["value_and_hessian"] = @benchmarkable $(AD.value_and_hessian)($b, $f, $x)
bg["value_gradient_and_hessian"] = @benchmarkable $(AD.value_gradient_and_hessian)($b, $f, $x)
pushforward = AD.pushforward_function(b, g, x)
bg["pushforward"] = @benchmarkable $pushforward($((Δx,)))
value_and_pushforward = AD.value_and_pushforward_function(b, g, x)
bg["value_and_pushforward"] = @benchmarkable $value_and_pushforward($((Δx,)))
pullback = AD.pullback_function(b, g, x)
bg["pullback"] = @benchmarkable $pullback($Δy)
value_and_pullback = AD.value_and_pullback_function(b, g, x)
bg["value_and_pullback"] = @benchmarkable $value_and_pullback($Δy)
suite[name] = bg
end
tune!(suite; verbose = true)
results = run(suite, verbose = true)
map(sort(collect(keys(results["jacobian"])))) do k
(;
Symbol("function") => k,
(Symbol(approach) => results[approach][k] for (approach, _) in approaches)...
)
end |> pretty_table

The last column uses identical implementations to the ones in this PR.

┌──────────────────────────────────┬──────────────────────┬──────────────────────┬──────────────────────┐
│ function │ jacobian │ jacobian partial │ jacobian all │
│ String │ BenchmarkTools.Trial │ BenchmarkTools.Trial │ BenchmarkTools.Trial │
├──────────────────────────────────┼──────────────────────┼──────────────────────┼──────────────────────┤
│ derivative(b,f,::Number) │ Trial(990.800 ns) │ Trial(280.469 ns) │ Trial(279.181 ns) │
│ gradient(b,f,x) │ Trial(4.302 μs) │ Trial(3.427 μs) │ Trial(3.398 μs) │
│ hessian │ Trial(169.063 μs) │ Trial(96.796 μs) │ Trial(99.938 μs) │
│ jacobian(b,f,::Number) │ Trial(1.033 μs) │ Trial(1.015 μs) │ Trial(1.052 μs) │
│ jacobian(b,f,x) │ Trial(4.195 μs) │ Trial(4.115 μs) │ Trial(4.030 μs) │
│ jacobian(b,g,x) │ Trial(2.206 μs) │ Trial(1.856 μs) │ Trial(1.846 μs) │
│ pullback │ Trial(1.974 μs) │ Trial(1.175 μs) │ Trial(1.337 μs) │
│ pushforward │ Trial(3.084 μs) │ Trial(3.055 μs) │ Trial(3.164 μs) │
│ value_and_gradient │ Trial(4.393 μs) │ Trial(4.290 μs) │ Trial(3.457 μs) │
│ value_and_hessian │ Trial(171.106 μs) │ Trial(172.341 μs) │ Trial(108.149 μs) │
│ value_and_jacobian(b,f,::Number) │ Trial(1.145 μs) │ Trial(972.300 ns) │ Trial(1.193 μs) │
│ value_and_jacobian(b,f,x) │ Trial(4.429 μs) │ Trial(4.289 μs) │ Trial(4.233 μs) │
│ value_and_jacobian(b,g,x) │ Trial(2.436 μs) │ Trial(2.032 μs) │ Trial(2.009 μs) │
│ value_and_pullback │ Trial(3.588 μs) │ Trial(1.617 μs) │ Trial(1.565 μs) │
│ value_and_pushforward │ Trial(3.186 μs) │ Trial(3.106 μs) │ Trial(3.174 μs) │
│ value_gradient_and_hessian │ Trial(172.960 μs) │ Trial(171.780 μs) │ Trial(100.681 μs) │
└──────────────────────────────────┴──────────────────────┴──────────────────────┴──────────────────────┘
thanks Seth!
This PR will implement the backend for ReverseDiff. I'll also make an attempt to support optional tape compilation, at least when `pullback_function` or `value_and_pullback_function` is called.