-
Notifications
You must be signed in to change notification settings - Fork 18
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement backend for ReverseDiff #29
Merged
Merged
Changes from all commits
Commits
Show all changes
15 commits
Select commit
Hold shift + click to select a range
fd68ec6
Move asarray to main file
sethaxen 7cb5c19
Create reversediff.jl
sethaxen 9a3089e
Conditionally load ReverseDiff
sethaxen d6b412d
Merge branch 'master' into reversediff
sethaxen 7ca5da5
Merge branch 'master' into reversediff
sethaxen 125be12
Add ReverseDiff tests
sethaxen 82a1c3c
Use Jacobian as primitive (for now)
sethaxen a1028f3
Overload primal_value
sethaxen b1a6507
Use ReverseDiff's own API when available
sethaxen 1e1f92c
Speed up derivative
sethaxen 6a7c0e5
Use DiffResults' functions for value_and_XXX functions
sethaxen c4a084e
Increment version number
sethaxen 1e3b73c
Add docstring
sethaxen 29323b0
Subtype AbstractReverseMode
sethaxen 07cae79
Merge branch 'master' into reversediff
mohamed82008 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
using .ReverseDiff: ReverseDiff, DiffResults | ||
|
||
primal_value(x::ReverseDiff.TrackedReal) = ReverseDiff.value(x) | ||
|
||
""" | ||
ReverseDiffBackend | ||
|
||
AD backend that uses reverse mode with ReverseDiff.jl. | ||
""" | ||
struct ReverseDiffBackend <: AbstractReverseMode end | ||
|
||
@primitive function jacobian(ba::ReverseDiffBackend, f, xs...) | ||
xs_arr = map(asarray, xs) | ||
tape = ReverseDiff.JacobianTape(xs_arr) do (xs_arr...) | ||
xs_new = map(xs, xs_arr) do x, x_arr | ||
return x isa Number ? only(x_arr) : x_arr | ||
end | ||
return asarray(f(xs_new...)) | ||
end | ||
results = ReverseDiff.jacobian!(tape, xs_arr) | ||
return map(xs, results) do x, result | ||
return x isa Number ? vec(result) : result | ||
end | ||
end | ||
function jacobian(ba::ReverseDiffBackend, f, xs::AbstractArray...) | ||
return ReverseDiff.jacobian(asarray ∘ f, xs) | ||
end | ||
|
||
function derivative(ba::ReverseDiffBackend, f, xs::Number...) | ||
tape = ReverseDiff.InstructionTape() | ||
xs_tracked = ReverseDiff.TrackedReal.(xs, zero.(xs), Ref(tape)) | ||
y_tracked = f(xs_tracked...) | ||
ReverseDiff.seed!(y_tracked) | ||
ReverseDiff.reverse_pass!(tape) | ||
return ReverseDiff.deriv.(xs_tracked) | ||
end | ||
|
||
function gradient(ba::ReverseDiffBackend, f, xs::AbstractArray...) | ||
return ReverseDiff.gradient(f, xs) | ||
end | ||
|
||
function hessian(ba::ReverseDiffBackend, f, x::AbstractArray) | ||
return (ReverseDiff.hessian(f, x),) | ||
end | ||
|
||
function value_and_gradient(ba::ReverseDiffBackend, f, x::AbstractArray) | ||
result = DiffResults.GradientResult(x) | ||
cfg = ReverseDiff.GradientConfig(x) | ||
ReverseDiff.gradient!(result, f, x, cfg) | ||
return DiffResults.value(result), (DiffResults.derivative(result),) | ||
end | ||
|
||
function value_and_hessian(ba::ReverseDiffBackend, f, x) | ||
result = DiffResults.HessianResult(x) | ||
cfg = ReverseDiff.HessianConfig(result, x) | ||
ReverseDiff.hessian!(result, f, x, cfg) | ||
return DiffResults.value(result), (DiffResults.hessian(result),) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
using AbstractDifferentiation | ||
using Test | ||
using ReverseDiff | ||
|
||
@testset "ReverseDiffBackend" begin | ||
backends = [@inferred(AD.ReverseDiffBackend())] | ||
@testset for backend in backends | ||
@testset "Derivative" begin | ||
test_derivatives(backend) | ||
end | ||
@testset "Gradient" begin | ||
test_gradients(backend) | ||
end | ||
@testset "Jacobian" begin | ||
test_jacobians(backend) | ||
end | ||
@testset "Hessian" begin | ||
test_hessians(backend) | ||
end | ||
@testset "jvp" begin | ||
test_jvp(backend) | ||
end | ||
@testset "j′vp" begin | ||
test_j′vp(backend) | ||
end | ||
@testset "Lazy Derivative" begin | ||
test_lazy_derivatives(backend) | ||
end | ||
@testset "Lazy Gradient" begin | ||
test_lazy_gradients(backend) | ||
end | ||
@testset "Lazy Jacobian" begin | ||
test_lazy_jacobians(backend) | ||
end | ||
@testset "Lazy Hessian" begin | ||
test_lazy_hessians(backend) | ||
end | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the recommended form? In DiffEqSensitivity we use a slightly different set of instructions, see
https://github.com/SciML/DiffEqSensitivity.jl/blob/8601419ad910455ef8842597a8a13c615768477f/src/derivative_wrappers.jl#L406
(including
unseed!
,input_hook
,output_hook
)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps, I wrote this a few months back and don't recall the reasons. Thanks the reference, I'll check.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ReverseDiff is a little unclear about what constitutes its API. e.g. there's its documented API, and the
api
source directory. Some functions are declared in the latter but not documented (e.g.input_hook
andoutput_hook
), so they may be part of the API, whileunseed!
, for example, does not appear to be part of the API at all. Can you comment on why non-API functions were used in DiffEqSensitivity?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure. Probably to support some more general cases than
seeded_reverse_pass!
does? CC @ChrisRackauckasThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's to make it non-allocating and more generic. I got those hooks directly from Jarrett.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Non-allocating would only be relevant for mutating
f!
, right? Currently I don't think mutatingf!
are supported by the AD API.