
Implement backend for ReverseDiff #29

Merged
merged 15 commits into from
Jan 26, 2022
4 changes: 3 additions & 1 deletion Project.toml
@@ -14,14 +14,16 @@ Compat = "3"
ExprTools = "0.1"
ForwardDiff = "0.10"
Requires = "0.5, 1"
ReverseDiff = "1"
julia = "1"

[extras]
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "FiniteDifferences", "ForwardDiff", "Random", "Zygote"]
test = ["Test", "FiniteDifferences", "ForwardDiff", "Random", "ReverseDiff", "Zygote"]
4 changes: 4 additions & 0 deletions src/AbstractDifferentiation.jl
@@ -640,8 +640,12 @@ function zero_matrix_like(x)
throw("The function `zero_matrix_like` is not defined for the type $(typeof(x)).")
end

@inline asarray(x) = [x]
@inline asarray(x::AbstractArray) = x

function __init__()
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" include("reversediff.jl")
@require FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" include("finitedifferences.jl")
end
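The `asarray` helpers moved into this file normalize scalar inputs so that array-only ReverseDiff entry points (such as `ReverseDiff.JacobianTape`) can handle them uniformly. A minimal self-contained sketch of the pattern, with `asarray` redefined locally for illustration:

```julia
# asarray wraps a scalar in a 1-element Vector, while arrays pass
# through untouched (same object, no copy). Downstream code can then
# treat every argument as an AbstractArray.
asarray(x) = [x]
asarray(x::AbstractArray) = x

asarray(3.0)        # 1-element Vector{Float64}: [3.0]
asarray([1.0, 2.0]) # returned unchanged, identically the same array
```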

3 changes: 0 additions & 3 deletions src/forwarddiff.jl
@@ -76,9 +76,6 @@ end
# support arrays and tuples
@noinline step_toward(x, v, h) = x .+ h .* v

@inline asarray(x) = [x]
@inline asarray(x::AbstractArray) = x

getchunksize(::Nothing) = Nothing
getchunksize(::Val{N}) where {N} = N

58 changes: 58 additions & 0 deletions src/reversediff.jl
@@ -0,0 +1,58 @@
using .ReverseDiff: ReverseDiff, DiffResults

primal_value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x)

"""
ReverseDiffBackend

AD backend that uses reverse mode with ReverseDiff.jl.
"""
struct ReverseDiffBackend <: AbstractReverseMode end

@primitive function jacobian(ba::ReverseDiffBackend, f, xs...)
xs_arr = map(asarray, xs)
tape = ReverseDiff.JacobianTape(xs_arr) do (xs_arr...)
xs_new = map(xs, xs_arr) do x, x_arr
return x isa Number ? only(x_arr) : x_arr
end
return asarray(f(xs_new...))
end
Contributor:

Is this the recommended form? In DiffEqSensitivity we use a slightly different set of instructions, see
https://github.com/SciML/DiffEqSensitivity.jl/blob/8601419ad910455ef8842597a8a13c615768477f/src/derivative_wrappers.jl#L406
(including unseed!, input_hook, output_hook)

Member Author:

Perhaps; I wrote this a few months back and don't recall the reasons. Thanks for the reference, I'll check.

Member Author:

ReverseDiff is a little unclear about what constitutes its API. e.g. there's its documented API, and the api source directory. Some functions are declared in the latter but not documented (e.g. input_hook and output_hook), so they may be part of the API, while unseed!, for example, does not appear to be part of the API at all. Can you comment on why non-API functions were used in DiffEqSensitivity?

Contributor (@frankschae, Jan 11, 2022):

I am not sure. Probably to support some more general cases than seeded_reverse_pass! does? CC @ChrisRackauckas

Member:

It's to make it non-allocating and more generic. I got those hooks directly from Jarrett.

Member Author:

Non-allocating would only be relevant for mutating f!, right? Currently I don't think mutating f! is supported by the AD API.

results = ReverseDiff.jacobian!(tape, xs_arr)
return map(xs, results) do x, result
return x isa Number ? vec(result) : result
end
end
function jacobian(ba::ReverseDiffBackend, f, xs::AbstractArray...)
return ReverseDiff.jacobian(asarray ∘ f, xs)
end

function derivative(ba::ReverseDiffBackend, f, xs::Number...)
tape = ReverseDiff.InstructionTape()
xs_tracked = ReverseDiff.TrackedReal.(xs, zero.(xs), Ref(tape))
y_tracked = f(xs_tracked...)
ReverseDiff.seed!(y_tracked)
ReverseDiff.reverse_pass!(tape)
return ReverseDiff.deriv.(xs_tracked)
end

function gradient(ba::ReverseDiffBackend, f, xs::AbstractArray...)
return ReverseDiff.gradient(f, xs)
end

function hessian(ba::ReverseDiffBackend, f, x::AbstractArray)
return (ReverseDiff.hessian(f, x),)
end

function value_and_gradient(ba::ReverseDiffBackend, f, x::AbstractArray)
result = DiffResults.GradientResult(x)
cfg = ReverseDiff.GradientConfig(x)
ReverseDiff.gradient!(result, f, x, cfg)
return DiffResults.value(result), (DiffResults.derivative(result),)
end

function value_and_hessian(ba::ReverseDiffBackend, f, x)
result = DiffResults.HessianResult(x)
cfg = ReverseDiff.HessianConfig(result, x)
ReverseDiff.hessian!(result, f, x, cfg)
return DiffResults.value(result), (DiffResults.hessian(result),)
end
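Once ReverseDiff is loaded and the `@require` hook above has included this file, the backend plugs into the generic AD API. A hypothetical usage sketch, assuming AbstractDifferentiation and ReverseDiff are installed (per the AD API conventions, results come back as a tuple with one entry per differentiated argument):

```julia
using AbstractDifferentiation  # exports the AD alias
using ReverseDiff              # triggers inclusion of reversediff.jl

backend = AD.ReverseDiffBackend()

f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]

g = AD.gradient(backend, f, x)                  # ([2.0, 4.0, 6.0],)
H = AD.hessian(backend, f, x)                   # (3×3 matrix, here 2I)
v, (dv,) = AD.value_and_gradient(backend, f, x) # 14.0 and [2.0, 4.0, 6.0]
```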
39 changes: 39 additions & 0 deletions test/reversediff.jl
@@ -0,0 +1,39 @@
using AbstractDifferentiation
using Test
using ReverseDiff

@testset "ReverseDiffBackend" begin
backends = [@inferred(AD.ReverseDiffBackend())]
@testset for backend in backends
@testset "Derivative" begin
test_derivatives(backend)
end
@testset "Gradient" begin
test_gradients(backend)
end
@testset "Jacobian" begin
test_jacobians(backend)
end
@testset "Hessian" begin
test_hessians(backend)
end
@testset "jvp" begin
test_jvp(backend)
end
@testset "j′vp" begin
test_j′vp(backend)
end
@testset "Lazy Derivative" begin
test_lazy_derivatives(backend)
end
@testset "Lazy Gradient" begin
test_lazy_gradients(backend)
end
@testset "Lazy Jacobian" begin
test_lazy_jacobians(backend)
end
@testset "Lazy Hessian" begin
test_lazy_hessians(backend)
end
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -5,5 +5,6 @@ using Test
include("test_utils.jl")
include("defaults.jl")
include("forwarddiff.jl")
include("reversediff.jl")
include("finitedifferences.jl")
end