Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
d6290bf
feat(initialization): get gradients against initialization problem
DhairyaLGandhi Mar 6, 2025
d3199c0
test: add initialization problem to test suite
DhairyaLGandhi Mar 6, 2025
a4fa7c5
chore: pass tunables to jacobian
DhairyaLGandhi Mar 6, 2025
5a7dd26
test: cleanup imports
DhairyaLGandhi Mar 6, 2025
94ec324
chore: rm debug statements
DhairyaLGandhi Mar 6, 2025
0c5564e
test: add Core8 to CI
DhairyaLGandhi Mar 6, 2025
aca9cd4
chore: use get_initial_values
DhairyaLGandhi Mar 11, 2025
9bec784
chore: check for OVERDETERMINED initialization for solving initializa…
DhairyaLGandhi Mar 11, 2025
72cbb35
chore: pass sensealg to initial_values
DhairyaLGandhi Mar 11, 2025
2d85e19
chore: treat Delta as array
DhairyaLGandhi Mar 13, 2025
9a8a845
chore: use autojacvec from sensealg
DhairyaLGandhi Mar 13, 2025
95ebbf3
chore: move igs before solve, re-use initialization
DhairyaLGandhi Mar 13, 2025
957d7fe
chore: update igs to re-use inital values
DhairyaLGandhi Mar 13, 2025
a00574f
chore: qualify NoInit
DhairyaLGandhi Mar 13, 2025
4562f0c
chore: remove igs from steady state adjoint gor initialization
DhairyaLGandhi Mar 17, 2025
6c21324
chore: accumulate gradients in steady state adjoint explicitly
DhairyaLGandhi Mar 19, 2025
a675a7f
fix: handle MTKparameters and Arrays uniformly
DhairyaLGandhi Mar 20, 2025
7941a3c
feat: allow reverse mode for initialization solving
DhairyaLGandhi Mar 20, 2025
9557e8c
test: add more tests for parameter initialization
DhairyaLGandhi Mar 20, 2025
8feae0e
test: fix label
DhairyaLGandhi Mar 20, 2025
d3b1669
chore: rename file
DhairyaLGandhi Mar 20, 2025
e01eb77
test: fix sensealg and confusing error message
DhairyaLGandhi Mar 21, 2025
0ad6c62
chore: return new_u0
DhairyaLGandhi Mar 24, 2025
6df7987
chore: rebase branch
DhairyaLGandhi Mar 24, 2025
c4c7807
chore: mark symbol local
DhairyaLGandhi Mar 25, 2025
b85b16e
chore: pass tunables to Tape
DhairyaLGandhi Mar 25, 2025
8fc4136
chore: update new_u0, new_p to orignal vals if not initializing
DhairyaLGandhi Mar 26, 2025
1f1cce5
chore: rebase master
DhairyaLGandhi Mar 26, 2025
0d1abcc
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl in…
DhairyaLGandhi Mar 26, 2025
b164b18
chore: add MSL to test deps
DhairyaLGandhi Mar 26, 2025
915d949
feat: allow analytically solved initialization solutions to propagate…
DhairyaLGandhi Apr 3, 2025
d2fd79a
chore: force allocated buffers for vjp to be deterministic
DhairyaLGandhi Apr 3, 2025
a0cd94a
chore: pass tunables to allocate vjp of MTKParameters
DhairyaLGandhi Apr 3, 2025
885794d
test: Core6 dont access J.du
DhairyaLGandhi Apr 4, 2025
13a1ffb
chore: SteadyStateAdjoint could be thunk
DhairyaLGandhi Apr 4, 2025
7826866
chore: import AbstractThunk
DhairyaLGandhi Apr 4, 2025
6ceaa1a
chore: handle upstream pullback better in steady state adjoint
DhairyaLGandhi Apr 4, 2025
dfebd0b
chore: dont accum pullback for parameters
DhairyaLGandhi Apr 4, 2025
396f63e
test: import SII
DhairyaLGandhi Apr 5, 2025
de7e7da
test: wrap ps in ComponentArray
DhairyaLGandhi Apr 7, 2025
f5fb559
chore: call du for jacobian
DhairyaLGandhi Apr 10, 2025
019a051
chore: add recursive_copyto for identical NT trees
DhairyaLGandhi Apr 10, 2025
91ee019
deps: MSL compat
DhairyaLGandhi Apr 10, 2025
4b74718
chore: undo du access
DhairyaLGandhi Apr 10, 2025
94d5e2b
chore: handle J through fwd mode
DhairyaLGandhi Apr 10, 2025
764d3ff
chore: J = nothing instead of NT
DhairyaLGandhi Apr 11, 2025
84b2602
chore: check nothing in steady state
DhairyaLGandhi Apr 13, 2025
8a4aa79
chore: also canonicalize dp
DhairyaLGandhi Apr 13, 2025
d69ccb1
chore: adjoint of preallocation tools
DhairyaLGandhi Apr 13, 2025
2cc3673
chore: pass Δ to canonicalize
DhairyaLGandhi Apr 14, 2025
22f056a
chore: handle different parameter types
DhairyaLGandhi Apr 14, 2025
1f95b25
chore: check for number
DhairyaLGandhi Apr 14, 2025
0d78fa8
test: rm commented out code
DhairyaLGandhi Apr 14, 2025
6620e8a
chore: get tunables from dp
DhairyaLGandhi Apr 14, 2025
5f5633b
test: clean up initialization tests
DhairyaLGandhi Apr 14, 2025
984c2ce
chore: pass initializealg
DhairyaLGandhi Apr 15, 2025
82cd5fe
chore: force path through BrownBasicInit
DhairyaLGandhi Apr 16, 2025
987e8be
chore: replace NoInit with BrownBasicInit
DhairyaLGandhi Apr 17, 2025
056fffa
test: add test to check residual with initialization
DhairyaLGandhi Apr 17, 2025
a122340
chore: replace BrownBasicInit with CheckInit
DhairyaLGandhi Apr 17, 2025
e105838
chore: qualify checinit
DhairyaLGandhi Apr 17, 2025
aaddc02
test: add test to force DAE initialization and prevent simplification…
DhairyaLGandhi Apr 21, 2025
60be1c7
test: check DAE initialization takes in BrownFullBasicInit
DhairyaLGandhi Apr 21, 2025
9edbe02
chore: check for default path, handle nlsolve kwargs as ODECore inter…
DhairyaLGandhi Apr 21, 2025
308ae5c
chore: update imported symbols
DhairyaLGandhi Apr 21, 2025
a333588
Update src/concrete_solve.jl
DhairyaLGandhi Apr 21, 2025
a2cf0e6
Update src/concrete_solve.jl
ChrisRackauckas Apr 21, 2025
755c9df
Update mtk.jl
ChrisRackauckas Apr 21, 2025
75ad141
chore: typo
DhairyaLGandhi Apr 22, 2025
e07dd53
chore: qualify DefaultInit
DhairyaLGandhi Apr 22, 2025
ee804b2
chore: run parameter initialization by passing missing parameters
DhairyaLGandhi Apr 23, 2025
7abf42c
chore: rm dead code
DhairyaLGandhi Apr 23, 2025
a7d4e5a
test: allocate gt based on size of new_sol
DhairyaLGandhi Apr 23, 2025
8e660fe
Update Project.toml
ChrisRackauckas Apr 23, 2025
de63cf9
Update test/mtk.jl
ChrisRackauckas Apr 23, 2025
b88f468
Update mtk.jl
ChrisRackauckas Apr 23, 2025
7dd1cc7
Also test u0 gradients
ChrisRackauckas Apr 23, 2025
cdaa2c7
chore: merge upstream
DhairyaLGandhi Apr 24, 2025
d3608c4
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl in…
DhairyaLGandhi Apr 24, 2025
4cf7bd5
Update test/mtk.jl
DhairyaLGandhi Apr 24, 2025
2ae712b
chore: handle when p is a functor in steady state adjoint
DhairyaLGandhi Apr 24, 2025
229e691
Merge branch 'dg/initprob' of github.com:SciML/SciMLSensitivity.jl in…
DhairyaLGandhi Apr 24, 2025
6e1109e
Merge branch 'master' into dg/initprob
DhairyaLGandhi Apr 24, 2025
d32b3f2
chore: git mixup
DhairyaLGandhi Apr 24, 2025
6e549e7
chore: git mixup
DhairyaLGandhi Apr 24, 2025
9789034
chore: revert bad commit
DhairyaLGandhi Apr 24, 2025
35937c0
chore: handle nothing dtunables for SteadyStateAdjoint
DhairyaLGandhi Apr 24, 2025
f934635
chore: rm u0 nothing forced to empty array
DhairyaLGandhi Apr 24, 2025
a83cd29
chore: reverse order of nothing check
DhairyaLGandhi Apr 24, 2025
2dfbc6e
chore: rm dead code
DhairyaLGandhi Apr 24, 2025
9aecbfd
chore: DEQ handling
DhairyaLGandhi Apr 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ jobs:
- Core5
- Core6
- Core7
- Core8
- QA
- SDE1
- SDE2
Expand Down
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down Expand Up @@ -76,13 +77,15 @@ LinearAlgebra = "1.10"
LinearSolve = "2, 3"
Lux = "1"
Markdown = "1.10"
ModelingToolkit = "9.42"
ModelingToolkit = "9.74"
ModelingToolkitStandardLibrary = "2"
Mooncake = "0.4.52"
NLsolve = "4.5.1"
NonlinearSolve = "3.0.1, 4"
Optimization = "4"
OptimizationOptimisers = "0.3"
OrdinaryDiffEq = "6.81.1"
OrdinaryDiffEqCore = "1"
Pkg = "1.10"
PreallocationTools = "0.4.4"
QuadGK = "2.9.1"
Expand Down Expand Up @@ -117,6 +120,7 @@ DelayDiffEq = "bcd4f6db-9728-5f36-b5f7-82caef46ccdb"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
Expand All @@ -131,4 +135,4 @@ StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
test = ["AlgebraicMultigrid", "Aqua", "Calculus", "ComponentArrays", "DelayDiffEq", "Distributed", "Lux", "ModelingToolkit", "ModelingToolkitStandardLibrary", "Mooncake", "NLsolve", "NonlinearSolve", "Optimization", "OptimizationOptimisers", "OrdinaryDiffEq", "Pkg", "SafeTestsets", "SparseArrays", "SteadyStateDiffEq", "StochasticDiffEq", "Test"]
9 changes: 7 additions & 2 deletions src/SciMLSensitivity.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,19 @@ using SciMLBase: SciMLBase, AbstractOverloadingSensitivityAlgorithm,
RODEFunction, RODEProblem, ReturnCode, SDEFunction,
SDEProblem, VectorContinuousCallback, deleteat!,
get_tmp_cache, has_adjoint, isinplace, reinit!, remake,
solve, u_modified!, LinearAliasSpecifier
solve, u_modified!, LinearAliasSpecifier, OverrideInit, CheckInit

using OrdinaryDiffEqCore: OrdinaryDiffEqCore, BrownFullBasicInit, DefaultInit, default_nlsolve, has_autodiff

# AD Backends
using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent
using ChainRulesCore: unthunk, @thunk, NoTangent, @not_implemented, Tangent, ZeroTangent, AbstractThunk
using Enzyme: Enzyme
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Tracker: Tracker, TrackedArray
using ReverseDiff: ReverseDiff
using Zygote: Zygote
using SciMLBase.ConstructionBase

# Std Libs
using LinearAlgebra: LinearAlgebra, Diagonal, I, UniformScaling, adjoint, axpy!,
Expand All @@ -56,6 +59,8 @@ using Markdown: Markdown, @doc_str
using Random: Random, rand!
using Statistics: Statistics, mean

using LinearAlgebra: diag

abstract type SensitivityFunction end
abstract type TransformedFunction end

Expand Down
40 changes: 26 additions & 14 deletions src/adjoint_common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
unwrappedf = unwrapped_f(f)

numparams = p === nothing || p === SciMLBase.NullParameters() ? 0 : length(tunables)
numindvar = length(u0)
numindvar = isnothing(u0) ? nothing : length(u0)
isautojacvec = get_jacvec(sensealg)

issemiexplicitdae = false
Expand Down Expand Up @@ -106,18 +106,22 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
isempty(algevar_idxs) || (issemiexplicitdae = true)
end
if !issemiexplicitdae
diffvar_idxs = eachindex(u0)
diffvar_idxs = isnothing(u0) ? nothing : eachindex(u0)
algevar_idxs = 1:0
end

if !needs_jac && !issemiexplicitdae && !(autojacvec isa Bool)
J = nothing
else
if alg === nothing || SciMLBase.forwarddiffs_model_time(alg)
# 1 chunk is fine because it's only t
_J = similar(u0, numindvar, numindvar)
_J .= 0
J = dualcache(_J, ForwardDiff.pickchunksize(length(u0)))
if !isnothing(u0)
# 1 chunk is fine because it's only t
_J = similar(u0, numindvar, numindvar)
_J .= 0
J = dualcache(_J, ForwardDiff.pickchunksize(length(u0)))
else
J = nothing
end
else
J = similar(u0, numindvar, numindvar)
J .= 0
Expand All @@ -133,8 +137,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
dg_val[1] .= false
dg_val[2] .= false
else
dg_val = similar(u0, numindvar) # number of funcs size
dg_val .= false
if !isnothing(u0)
dg_val = similar(u0, numindvar) # number of funcs size
dg_val .= false
else
dg_val = nothing
end
end
else
pgpu = UGradientWrapper(g, _t, p)
Expand Down Expand Up @@ -241,8 +249,12 @@ function adjointdiffcache(g::G, sensealg, discrete, sol, dgdu::DG1, dgdp::DG2, f
pJ = if (quad || !(autojacvec isa Bool))
nothing
else
_pJ = similar(u0, numindvar, numparams)
_pJ .= false
if !isnothing(u0)
_pJ = similar(u0, numindvar, numparams)
_pJ .= false
else
_pJ = nothing
end
end

f_cache = isinplace ? deepcopy(u0) : nothing
Expand Down Expand Up @@ -379,11 +391,11 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
if !isRODE
__p = p isa SciMLBase.NullParameters ? _p :
SciMLStructures.replace(Tunable(), p, _p)
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
tape = ReverseDiff.GradientTape((y, _p, [_t])) do u, p, t
du1 = (p !== nothing && p !== SciMLBase.NullParameters()) ?
similar(p, size(u)) : similar(u)
du1 .= false
f(du1, u, p, first(t))
f(du1, u, repack(p), first(t))
return vec(du1)
end
else
Expand All @@ -402,8 +414,8 @@ function get_paramjac_config(autojacvec::ReverseDiffVJP, p, f, y, _p, _t;
# because hasportion(Tunable(), NullParameters) == false
__p = p isa SciMLBase.NullParameters ? _p :
SciMLStructures.replace(Tunable(), p, _p)
tape = ReverseDiff.GradientTape((y, __p, [_t])) do u, p, t
vec(f(u, p, first(t)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same down here?

tape = ReverseDiff.GradientTape((y, _p, [_t])) do u, p, t
vec(f(u, repack(p), first(t)))
end
else
tape = ReverseDiff.GradientTape((y, _p, [_t], _W)) do u, p, t, W
Expand Down
122 changes: 111 additions & 11 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,16 @@ function inplace_vjp(prob, u0, p, verbose, repack)

vjp = try
f = unwrapped_f(prob.f)
tspan_ = prob isa AbstractNonlinearProblem ? nothing : [prob.tspan[1]]
if p === nothing || p isa SciMLBase.NullParameters
ReverseDiff.GradientTape((copy(u0), [prob.tspan[1]])) do u, t
ReverseDiff.GradientTape((copy(u0), tspan_)) do u, t
du1 = similar(u, size(u))
du1 .= 0
f(du1, u, p, first(t))
return vec(du1)
end
else
ReverseDiff.GradientTape((copy(u0), p, [prob.tspan[1]])) do u, p, t
ReverseDiff.GradientTape((copy(u0), p, tspan_)) do u, p, t
du1 = similar(u, size(u))
du1 .= 0
f(du1, u, repack(p), first(t))
Expand Down Expand Up @@ -299,6 +300,7 @@ function DiffEqBase._concrete_solve_adjoint(
tunables, repack = Functors.functor(p)
end

u0 = state_values(prob) === nothing ? Float64[] : u0
default_sensealg = automatic_sensealg_choice(prob, u0, tunables, verbose, repack)
DiffEqBase._concrete_solve_adjoint(prob, alg, default_sensealg, u0, p,
originator::SciMLBase.ADOriginator, args...; verbose,
Expand Down Expand Up @@ -371,6 +373,7 @@ function DiffEqBase._concrete_solve_adjoint(
args...; save_start = true, save_end = true,
saveat = eltype(prob.tspan)[],
save_idxs = nothing,
initializealg_default = SciMLBase.OverrideInit(; abstol = 1e-6, reltol = 1e-3),
kwargs...)
if !(sensealg isa GaussAdjoint) &&
!(p isa Union{Nothing, SciMLBase.NullParameters, AbstractArray}) ||
Expand Down Expand Up @@ -412,16 +415,61 @@ function DiffEqBase._concrete_solve_adjoint(
Base.diff_names(Base._nt_names(values(kwargs)),
(:callback_adj, :callback))}(values(kwargs))
isq = sensealg isa QuadratureAdjoint
kwargs_init = kwargs_adj[Base.diff_names(Base._nt_names(kwargs_adj), (:initializealg,))]

if haskey(kwargs, :initializealg) || haskey(prob.kwargs, :initializealg)
initializealg = haskey(kwargs, :initializealg) ? kwargs[:initializealg] : prob.kwargs[:initializealg]
else
initializealg = DefaultInit()
end

default_inits = Union{OverrideInit, Nothing, DefaultInit}
igs, new_u0, new_p, new_initializealg = if (SciMLBase.has_initialization_data(_prob.f) && initializealg isa default_inits)
local new_u0
local new_p
initializeprob = prob.f.initialization_data.initializeprob
iu0 = state_values(initializeprob)
isAD = if iu0 === nothing
AutoForwardDiff
elseif has_autodiff(alg)
OrdinaryDiffEqCore.alg_autodiff(alg) isa AutoForwardDiff
else
true
end
nlsolve_alg = default_nlsolve(nothing, Val(isinplace(_prob)), iu0, initializeprob, isAD)
Comment on lines +434 to +439
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChrisRackauckas are alg_autodiff, has_autodiff and default_nlsolve public API of ODECore? If yes, we should PR to ODECore to (at the very least) add comments saying so. Or get around to adding SciMLPublic.jl with an @public macro.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They have at least been used downstream here for a long while. We should comment that for sure.

initializealg = initializealg isa Union{Nothing, DefaultInit} ? initializealg_default : initializealg

iy, back = Zygote.pullback(tunables) do tunables
new_prob = remake(_prob, p = repack(tunables))
new_u0, new_p, _ = SciMLBase.get_initial_values(new_prob, new_prob, new_prob.f, initializealg, Val(isinplace(new_prob));
sensealg = SteadyStateAdjoint(autojacvec = sensealg.autojacvec),
nlsolve_alg,
kwargs_init...)
new_tunables, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), new_p)
if SciMLBase.initialization_status(_prob) == SciMLBase.OVERDETERMINED
sum(new_tunables)
else
sum(new_u0) + sum(new_tunables)
end
end
igs = back(one(iy))[1] .- one(eltype(tunables))

igs, new_u0, new_p, SciMLBase.NoInit()
else
nothing, u0, p, initializealg
end
_prob = remake(_prob, u0 = new_u0, p = new_p)

if sensealg isa BacksolveAdjoint
sol = solve(_prob, alg, args...; save_noise = true,
sol = solve(_prob, alg, args...; initializealg = new_initializealg, save_noise = true,
save_start = save_start, save_end = save_end,
saveat = saveat, kwargs_fwd...)
elseif ischeckpointing(sensealg)
sol = solve(_prob, alg, args...; save_noise = true,
sol = solve(_prob, alg, args...; initializealg = new_initializealg, save_noise = true,
save_start = true, save_end = true,
saveat = saveat, kwargs_fwd...)
else
sol = solve(_prob, alg, args...; save_noise = true, save_start = true,
sol = solve(_prob, alg, args...; initializealg = new_initializealg, save_noise = true, save_start = true,
save_end = true, kwargs_fwd...)
end

Expand Down Expand Up @@ -491,6 +539,7 @@ function DiffEqBase._concrete_solve_adjoint(
_save_idxs = save_idxs === nothing ? Colon() : save_idxs

function adjoint_sensitivity_backpass(Δ)
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ
function df_iip(_out, u, p, t, i)
outtype = _out isa SubArray ?
ArrayInterface.parameterless_type(_out.parent) :
Expand Down Expand Up @@ -628,20 +677,22 @@ function DiffEqBase._concrete_solve_adjoint(
dgdu_discrete = df_iip,
sensealg = sensealg,
callback = cb2,
kwargs_adj...)
kwargs_init...)
else
du0, dp = adjoint_sensitivities(sol, alg, args...; t = ts,
dgdu_discrete = df_oop,
sensealg = sensealg,
callback = cb2,
kwargs_adj...)
kwargs_init...)
end

du0 = reshape(du0, size(u0))

dp = p === nothing || p === DiffEqBase.NullParameters() ? nothing :
dp isa AbstractArray ? reshape(dp', size(tunables)) : dp

dp = Zygote.accum(dp, igs)

_, repack_adjoint = if p === nothing || p === DiffEqBase.NullParameters() ||
!isscimlstructure(p)
nothing, x -> (x,)
Expand Down Expand Up @@ -1679,6 +1730,7 @@ function DiffEqBase._concrete_solve_adjoint(
u0, p, originator::SciMLBase.ADOriginator,
args...; save_idxs = nothing, kwargs...)
_prob = remake(prob, u0 = u0, p = p)

sol = solve(_prob, alg, args...; kwargs...)
_save_idxs = save_idxs === nothing ? Colon() : save_idxs

Expand All @@ -1688,26 +1740,74 @@ function DiffEqBase._concrete_solve_adjoint(
out = SciMLBase.sensitivity_solution(sol, sol[_save_idxs])
end

_, repack_adjoint = if isscimlstructure(p)
Zygote.pullback(p) do p
t, _, _ = canonicalize(Tunable(), p)
t
end
elseif isfunctor(p)
ps, re = Functors.functor(p)
ps, x -> (re(x),)
else
nothing, x -> (x,)
end

function steadystatebackpass(Δ)
Δ = Δ isa AbstractThunk ? unthunk(Δ) : Δ
# Δ = dg/dx or diffcache.dg_val
# del g/del p = 0
function df(_out, u, p, t, i)
if _save_idxs isa Number
_out[_save_idxs] = Δ[_save_idxs]
elseif Δ isa Number
@. _out[_save_idxs] = Δ
else
elseif Δ isa AbstractArray{<:AbstractArray} || Δ isa AbstractVectorOfArray || Δ isa AbstractArray
@. _out[_save_idxs] = Δ[_save_idxs]
elseif isnothing(_out)
_out
else
@. _out[_save_idxs] = Δ.u[_save_idxs]
end
end
dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df, initializealg = BrownFullBasicInit())

dp, Δtunables = if Δ isa AbstractArray || Δ isa Number
# if Δ isa AbstractArray, the gradients correspond to `u`
# this is something that needs changing in the future, but
# this is the applicable till the movement to structuaral
# tangents is completed
dp, Δtunables = if isscimlstructure(dp)
dp, _, _ = canonicalize(Tunable(), dp)
dp, nothing
elseif isfunctor(dp)
dp, _ = Functors.functor(dp)
dp, nothing
else
dp, nothing
end
else
dp, Δtunables = if isscimlstructure(p)
Δp = setproperties(dp, to_nt(Δ.prob.p))
Δtunables, _, _ = canonicalize(Tunable(), Δp)
dp, _, _ = canonicalize(Tunable(), dp)
dp, Δtunables
elseif isfunctor(p)
dp, _ = Functors.functor(dp)
Δtunables, _ = Functors.functor(Δ.prob.p)
dp, Δtunables
else
dp, Δ.prob.p
end
end
dp = adjoint_sensitivities(sol, alg; sensealg = sensealg, dgdu = df)

dp = Zygote.accum(dp, (isnothing(Δtunables) || isempty(Δtunables)) ? nothing : Δtunables)

if originator isa SciMLBase.TrackerOriginator ||
originator isa SciMLBase.ReverseDiffOriginator
(NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(),
(NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
else
(NoTangent(), NoTangent(), NoTangent(), NoTangent(), dp, NoTangent(),
(NoTangent(), NoTangent(), NoTangent(), NoTangent(), repack_adjoint(dp)[1], NoTangent(),
ntuple(_ -> NoTangent(), length(args))...)
end
end
Expand Down
Loading
Loading