-
-
Notifications
You must be signed in to change notification settings - Fork 25
/
Copy pathSimpleNonlinearSolveReverseDiffExt.jl
60 lines (51 loc) · 2.58 KB
/
SimpleNonlinearSolveReverseDiffExt.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# ReverseDiff.jl extension for SimpleNonlinearSolve.
#
# Adds methods to `SimpleNonlinearSolve.__internal_solve_up` for calls whose `u0` and/or
# `p` carry ReverseDiff tracking information. Those calls are recorded on the tape with
# `ReverseDiff.track`, so the reverse pass uses the `ReverseDiff.@grad` rule at the bottom
# of this module (built on `DiffEqBase._solve_adjoint`).
module SimpleNonlinearSolveReverseDiffExt

using ArrayInterface, DiffEqBase, ReverseDiff, SciMLBase, SimpleNonlinearSolve
import ReverseDiff: TrackedArray, TrackedReal
import SimpleNonlinearSolve: __internal_solve_up

# --- `u0` and/or `p` are `TrackedArray`s -------------------------------------------------
# Each method records the call via `ReverseDiff.track`; the forward value and pullback
# are supplied by the `ReverseDiff.@grad` rule defined below.

# Both `u0` and `p` are tracked.
function __internal_solve_up(
        prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
        p::TrackedArray, p_changed, alg, args...; kwargs...)
    return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
        u0_changed, p, p_changed, alg, args...; kwargs...)
end

# Only `p` is tracked.
function __internal_solve_up(
        prob::NonlinearProblem, sensealg, u0, u0_changed,
        p::TrackedArray, p_changed, alg, args...; kwargs...)
    return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
        u0_changed, p, p_changed, alg, args...; kwargs...)
end

# Only `u0` is tracked.
function __internal_solve_up(
        prob::NonlinearProblem, sensealg, u0::TrackedArray, u0_changed,
        p, p_changed, alg, args...; kwargs...)
    return ReverseDiff.track(__internal_solve_up, prob, sensealg, u0,
        u0_changed, p, p_changed, alg, args...; kwargs...)
end

# --- `u0` and/or `p` are arrays of `TrackedReal`s -----------------------------------------
# Each array of `TrackedReal`s is converted with `ArrayInterface.aos_to_soa` (intended to
# yield a `TrackedArray`) and the call is re-dispatched so that one of the
# `ReverseDiff.track` methods above handles it. Both change flags are passed as `true`.

# Both `u0` and `p` are arrays of `TrackedReal`s.
function __internal_solve_up(prob::NonlinearProblem, sensealg,
        u0::AbstractArray{<:TrackedReal}, u0_changed, p::AbstractArray{<:TrackedReal},
        p_changed, alg, args...; kwargs...)
    return __internal_solve_up(
        prob, sensealg, ArrayInterface.aos_to_soa(u0), true,
        ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end

# Only `p` is an array of `TrackedReal`s; `u0` is passed through unchanged.
function __internal_solve_up(prob::NonlinearProblem, sensealg, u0, u0_changed,
        p::AbstractArray{<:TrackedReal}, p_changed, alg, args...; kwargs...)
    return __internal_solve_up(
        prob, sensealg, u0, true, ArrayInterface.aos_to_soa(p), true, alg, args...; kwargs...)
end

# Only `u0` is an array of `TrackedReal`s; `p` is passed through unchanged.
# `u0` is the argument that must be converted here. Converting `p` instead would leave
# `u0`'s type as-is, so the recursive call could select this same method again (e.g. when
# `p` is a plain array that `aos_to_soa` returns unchanged), recursing without bound.
function __internal_solve_up(prob::NonlinearProblem, sensealg,
        u0::AbstractArray{<:TrackedReal}, u0_changed, p, p_changed, alg, args...; kwargs...)
    return __internal_solve_up(
        prob, sensealg, ArrayInterface.aos_to_soa(u0), true, p, true, alg, args...; kwargs...)
end

# Reverse rule for the tracked calls above.
# Forward pass: hands the untracked values of `u0` and `p` (via `ReverseDiff.value`) to
# `DiffEqBase._solve_adjoint`, tagged with `SciMLBase.ReverseDiffOriginator()`, and returns
# the solution converted to an `Array`.
# Pullback: maps the cotangents from `∇internal` onto `__internal_solve_up`'s positional
# arguments. The originator cotangent is dropped, and `u0_changed`, `p_changed` and `alg`
# receive `nothing`.
# NOTE(review): `∂args` is everything `∇internal` returns after `∂originator`. Because
# `alg` is passed to `_solve_adjoint` positionally right after the originator, `∂args` may
# begin with a cotangent for `alg`, shifting it by one relative to `args...`; confirm the
# intended alignment.
ReverseDiff.@grad function __internal_solve_up(
        prob::NonlinearProblem, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
    out, ∇internal = DiffEqBase._solve_adjoint(
        prob, sensealg, ReverseDiff.value(u0), ReverseDiff.value(p),
        SciMLBase.ReverseDiffOriginator(), alg, args...; kwargs...)
    function ∇__internal_solve_up(_args...)
        ∂prob, ∂sensealg, ∂u0, ∂p, ∂originator, ∂args... = ∇internal(_args...)
        return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...)
    end
    return Array(out), ∇__internal_solve_up
end

end