Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,11 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
v = @view Δ[.., i]
end
if !(Δ isa NoTangent)
ForwardDiff.value.(J'vec(v))
if u0 isa Number
ForwardDiff.value.(J'v)
else
ForwardDiff.value.(J'vec(v))
end
else
zero(p)
end
Expand All @@ -639,10 +643,16 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
num_chunks = length(u0) ÷ chunk_size
num_chunks * chunk_size != length(u0) && (num_chunks += 1)

du0parts = typeof(u0[1:1])[]
du0parts = u0 isa Number ? typeof(u0)[] : typeof(u0[1:1])[]

local _du0

for j in 0:(num_chunks - 1)
local chunk
if ((j + 1) * chunk_size) <= length(u0)
if u0 isa Number
u0dualpart = seed_duals(u0, prob.f,
ForwardDiff.Chunk{chunk_size}())
elseif ((j + 1) * chunk_size) <= length(u0)
chunk = ((j * chunk_size + 1):((j + 1) * chunk_size))
u0chunk = vec(u0)[chunk]
u0dualpart = seed_duals(u0chunk, prob.f,
Expand All @@ -654,16 +664,21 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
ForwardDiff.Chunk{length(chunk)}())
end

u0dualvec = if j == 0
vcat(u0dualpart, u0[((j + 1) * chunk_size + 1):end])
elseif j == num_chunks - 1
vcat(u0[1:(j * chunk_size)], u0dualpart)
if u0 isa Number
u0dual = u0dualpart
else
vcat(u0[1:(j * chunk_size)], u0dualpart,
u0[(((j + 1) * chunk_size) + 1):end])
u0dualvec = if j == 0
vcat(u0dualpart, u0[((j + 1) * chunk_size + 1):end])
elseif j == num_chunks - 1
vcat(u0[1:(j * chunk_size)], u0dualpart)
else
vcat(u0[1:(j * chunk_size)], u0dualpart,
u0[(((j + 1) * chunk_size) + 1):end])
end

u0dual = ArrayInterfaceCore.restructure(u0, u0dualvec)
end

u0dual = ArrayInterfaceCore.restructure(u0, u0dualvec)
if p === nothing || p === DiffEqBase.NullParameters()
pdual = p
else
Expand Down Expand Up @@ -717,14 +732,26 @@ function DiffEqBase._concrete_solve_adjoint(prob, alg,
v = @view Δ[.., i]
end
if !(Δ isa NoTangent)
ForwardDiff.value.(J'vec(v))
if u0 isa Number
ForwardDiff.value.(J'v)
else
ForwardDiff.value.(J'vec(v))
end
else
zero(u0)
end
end
push!(du0parts, vec(_du0))

if !(u0 isa Number)
push!(du0parts, vec(_du0))
end
end

if u0 isa Number
first(_du0)
else
ArrayInterfaceCore.restructure(u0, reduce(vcat, du0parts))
end
ArrayInterfaceCore.restructure(u0, reduce(vcat, du0parts))
end

if originator isa SciMLBase.TrackerOriginator ||
Expand Down
18 changes: 13 additions & 5 deletions src/forward_sensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,8 +156,8 @@ const FORWARD_SENSITIVITY_PARAMETER_COMPATABILITY_MESSAGE = """
type for being the initial condition `u0` of an array. This means that
many simple types, such as `Tuple`s and `NamedTuple`s, will work as
parameters in normal contexts but will fail during ODEForwardSensitivityProblem
construction. To work around this issue for complicated cases like nested structs,
look into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl
construction. To work around this issue for complicated cases like nested structs,
look into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl
or ComponentArrays.jl.
"""

Expand Down Expand Up @@ -200,8 +200,8 @@ efficient method.
type for being the initial condition `u0` of an array. This means that
many simple types, such as `Tuple`s and `NamedTuple`s, will work as
parameters in normal contexts but will fail during ODEForwardSensitivityProblem
construction. To work around this issue for complicated cases like nested structs,
look into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl
construction. To work around this issue for complicated cases like nested structs,
look into defining `p` using `AbstractArray` libraries such as RecursiveArrayTools.jl
or ComponentArrays.jl.

### ODEForwardSensitivityProblem Syntax
Expand All @@ -218,7 +218,7 @@ ODEForwardSensitivityProblem(f::SciMLBase.AbstractODEFunction,u0,
kwargs...)
```

Once constructed, this problem can be used in `solve` just like any other ODEProblem.
Once constructed, this problem can be used in `solve` just like any other ODEProblem.
The solution can be deconstructed into the ODE solution and sensitivities parts using the
`extract_local_sensitivities` function, with the following dispatches:

Expand Down Expand Up @@ -441,6 +441,14 @@ function seed_duals(x::AbstractArray{V}, f,
duals = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, eltype(vec(x))))}.(vec(x), seeds)
end

function seed_duals(x::Number, f,
::ForwardDiff.Chunk{N} = ForwardDiff.Chunk(x, typemax(Int64))) where {V,
T,
N}
seeds = ForwardDiff.construct_seeds(ForwardDiff.Partials{N, typeof(x)})
duals = ForwardDiff.Dual{typeof(ForwardDiff.Tag(f, typeof(x)))}(x, seeds[1])
end

has_continuous_callback(cb::DiscreteCallback) = false
has_continuous_callback(cb::ContinuousCallback) = true
has_continuous_callback(cb::CallbackSet) = !isempty(cb.continuous_callbacks)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ end
@time @safetestset "Literal Adjoint" begin include("literal_adjoint.jl") end
@time @safetestset "ForwardDiff Chunking Adjoints" begin include("forward_chunking.jl") end
@time @safetestset "Stiff Adjoints" begin include("stiff_adjoints.jl") end
@time @safetestset "Scalar u0" begin include("scalar_u.jl") end
@time @safetestset "Autodiff Events" begin include("autodiff_events.jl") end
@time @safetestset "Null Parameters" begin include("null_parameters.jl") end
@time @safetestset "Forward Mode Prob Kwargs" begin include("forward_prob_kwargs.jl") end
Expand Down
23 changes: 23 additions & 0 deletions test/scalar_u.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using OrdinaryDiffEq, SciMLSensitivity, Zygote, ForwardDiff, Test

function neural_ode(u, p, t)
u * p[1]
end

p = [2.0]
u0 = rand(1)[1]
tspan = (0.0, 10.0)
t = Array(range(0, 0.10, length = 100))
prob_neuralode = ODEProblem(neural_ode, u0, tspan)

function loss_neuralode(p)
trial = Array(solve(prob_neuralode, AutoTsit5(Rosenbrock23()), u0 = u0, p = p,
saveat = t, abstol = 1e-6, reltol = 1e-6))
loss = sum(abs2, trial)
return loss
end

dp1 = Zygote.gradient(loss_neuralode, p)[1]
dp2 = ForwardDiff.gradient(loss_neuralode, p)

@test dp1≈dp2 atol=1e-8