Skip to content

Commit

Permalink
Merge pull request #1893 from SciML/myb/fastobs
Browse files Browse the repository at this point in the history
Add one-arg observed dispatch
  • Loading branch information
YingboMa committed Oct 20, 2022
2 parents 4c30760 + a47662f commit 3149f77
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 8 deletions.
10 changes: 8 additions & 2 deletions src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -318,14 +318,20 @@ function build_torn_function(sys;
sol_states = sol_states,
var2assignment = var2assignment

function generated_observed(obsvar, u, p, t)
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_observed_function(state, obsvar, var_eq_matching, var_sccs,
is_solver_state_idxs, assignments, deps,
sol_states, var2assignment,
checkbounds = checkbounds)
end
obs(u, p, t)
if args === ()
let obs = obs
(u, p, t) -> obs(u, p, t)
end
else
obs(args...)
end
end
end

Expand Down
30 changes: 24 additions & 6 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,20 +328,32 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
obs = observed(sys)
observedfun = if steady_state
let sys = sys, dict = Dict()
function generated_observed(obsvar, u, p, t = Inf)
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar)
end
obs(u, p, t)
if args === ()
let obs = obs
(u, p, t = Inf) -> obs(u, p, t)
end
else
length(args) == 2 ? obs(args..., Inf) : obs(args...)
end
end
end
else
let sys = sys, dict = Dict()
function generated_observed(obsvar, u, p, t)
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
end
obs(u, p, t)
if args === ()
let obs = obs
(u, p, t) -> obs(u, p, t)
end
else
obs(args...)
end
end
end
end
Expand Down Expand Up @@ -424,11 +436,17 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = states(sys),

obs = observed(sys)
observedfun = let sys = sys, dict = Dict()
function generated_observed(obsvar, u, p, t)
function generated_observed(obsvar, args...)
obs = get!(dict, value(obsvar)) do
build_explicit_observed_function(sys, obsvar; checkbounds = checkbounds)
end
obs(u, p, t)
if args === ()
let obs = obs
(u, p, t) -> obs(u, p, t)
end
else
obs(args...)
end
end
end

Expand Down
2 changes: 2 additions & 0 deletions test/components.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ end

function check_rc_sol(sol)
rpi = sol[rc_model.resistor.p.i]
rpifun = sol.prob.f.observed(rc_model.resistor.p.i)
@test rpifun.(sol.u, (sol.prob.p,), sol.t) == rpi
@test any(!isequal(rpi[1]), rpi) # test that we don't have a constant system
@test sol[rc_model.resistor.p.i] == sol[resistor.p.i] == sol[capacitor.p.i]
@test sol[rc_model.resistor.n.i] == sol[resistor.n.i] == -sol[capacitor.p.i]
Expand Down

0 comments on commit 3149f77

Please sign in to comment.