Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Symbolic save_idxs first pass #2052

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
checkbounds = false,
sparsity = false,
analytic = nothing,
save_idxs = nothing,
kwargs...) where {iip, specialize}
f_gen = generate_function(sys, dvs, ps; expression = Val{eval_expression},
expression_module = eval_module, checkbounds = checkbounds,
Expand Down Expand Up @@ -337,7 +338,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
let sys = sys, dict = Dict()
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar)
build_explicit_observed_function(sys, obsvar; save_idxs = save_idxs)
end
if args === ()
let obs = obs
Expand All @@ -352,7 +353,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
let sys = sys, dict = Dict()
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
build_explicit_observed_function(sys, obsvar;
save_idxs = save_idxs,
checkbounds = checkbounds)
end
if args === ()
let obs = obs
Expand All @@ -365,6 +368,12 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
end
end

savedstates = if save_idxs !== nothing
states(sys)[save_idxs]
else
states(sys)
end

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
if jac
Expand All @@ -382,7 +391,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
tgrad = _tgrad === nothing ? nothing : _tgrad,
mass_matrix = _M,
jac_prototype = jac_prototype,
syms = Symbol.(states(sys)),
syms = Symbol.(savedstates),
indepsym = Symbol(get_iv(sys)),
paramsyms = Symbol.(ps),
observed = observedfun,
Expand Down Expand Up @@ -700,12 +709,18 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
parammap = DiffEqBase.NullParameters();
callback = nothing,
check_length = true,
save_idxs = nothing,
kwargs...) where {iip, specialize}
if has_symbolic_elements(save_idxs)
sym_idxs, int_idxs = partition_ints(save_idxs)
sym_idxs = unique(vcat(sym_idxs, equation_dependencies(sys)))
save_idxs = unique(vcat(SymbolicIndexingInterface.state_sym_to_index.((sys,), save_idxs), int_idxs))
end
has_difference = any(isdifferenceeq, equations(sys))
f, u0, p = process_DEProblem(ODEFunction{iip, specialize}, sys, u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan,
has_difference = has_difference,
check_length, kwargs...)
check_length, save_idxs = save_idxs, kwargs...)
cbs = process_events(sys; callback, has_difference, kwargs...)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
Expand All @@ -728,6 +743,7 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
else
svs = nothing
end

kwargs = filter_kwargs(kwargs)
pt = something(get_metadata(sys), StandardODEProblem())

Expand All @@ -738,6 +754,9 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end
if save_idxs !== nothing
kwargs1 = merge(kwargs1, (save_idxs = save_idxs,))
end
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]
Expand Down
8 changes: 7 additions & 1 deletion src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ function build_explicit_observed_function(sys, ts;
expression = false,
output_type = Array,
checkbounds = true,
save_idxs = nothing,
throw = true)
if (isscalar = !(ts isa AbstractVector))
ts = [ts]
Expand Down Expand Up @@ -366,7 +367,12 @@ function build_explicit_observed_function(sys, ts;
push!(obsexprs, lhs ← rhs)
end

dvs = DestructuredArgs(states(sys), inbounds = !checkbounds)
savedstates = if save_idxs !== nothing
states(sys)[save_idxs]
else
states(sys)
end
dvs = DestructuredArgs(savedstates, inbounds = !checkbounds)
ps = DestructuredArgs(parameters(sys), inbounds = !checkbounds)
args = [dvs, ps, ivs...]
pre = get_postprocess_fbody(sys)
Expand Down
14 changes: 14 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -879,3 +879,17 @@ function fast_substitute(expr, pair::Pair)
end

normalize_to_differential(s) = s

safe_unwrap(x) = x
safe_unwrap(x::Num) = unwrap(x)

function has_symbolic_elements(idxs)
(idxs !== nothing) && any(i -> (i isa Symbolics.Symbolic), safe_unwrap.(idxs))
end

function partition_ints(idxs)
idxs = safe_unwrap.(idxs)
syms = filter(i -> (i isa Symbolics.Symbolic), idxs)
ints = filter(i -> i isa Integer, setdiff(idxs, syms))
return syms, ints
end
35 changes: 35 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -985,3 +985,38 @@ let
prob = ODAEProblem(sys4s, [x => 1.0, D(x) => 1.0], (0, 1.0))
@test !isnothing(prob.f.sys)
end

@testset "Symbolic save_idxs" begin
@parameters t
@variables a(t) b(t) c(t) d(t) e(t)

D = Differential(t)

eqs = [D(a) ~ a,
D(b) ~ b,
D(c) ~ c,
D(d) ~ d,
e ~ d]

@named sys = ODESystem(eqs, t, [a, b, c, d, e], [];
defaults = Dict([a => 1.0,
b => 1.0,
c => 1.0,
d => 1.0,
e => 1.0]))
sys = structural_simplify(sys)
prob = ODEProblem(sys, [], (0, 1.0))
prob_sym = ODEProblem(sys, [], (0, 1.0), save_idxs = [a, b, d])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be a breaking change to move this to the problem instead of solve. Also, I don't get why it "should" be here. It makes it hard to re-solve and just save something else. Why should it be in the problem than in the solve?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see above for my concerns


sol = solve(prob, Tsit5())
sol_sym = solve(prob_sym, Tsit5())

@test sol_sym[a] ≈ sol[a]
@test sol_sym[b] ≈ sol[b]
@test sol_sym[d] ≈ sol[d]
@test sol_sym[e] ≈ sol[e]

@test sol.u != sol_sym.u

@test_throws Exception sol_sym[c]
end