Skip to content

Commit

Permalink
Remake with symbolic map (#1835)
Browse files Browse the repository at this point in the history
  • Loading branch information
xtalax authored Oct 18, 2022
1 parent 3d04208 commit 93a3aaf
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ module ModelingToolkit
using DocStringExtensions
using AbstractTrees
using DiffEqBase, SciMLBase, ForwardDiff, Reexport
using SciMLBase: StandardODEProblem, StandardNonlinearProblem
using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap
using Distributed
using StaticArrays, LinearAlgebra, SparseArrays, LabelledArrays
using InteractiveUtils
Expand Down
16 changes: 14 additions & 2 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ applicable.
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
toterm = Symbolics.diff2term, promotetoconcrete = nothing,
tofloat = true, use_union = false)
varlist = map(unwrap, varlist)
varlist = collect(map(unwrap, varlist))

# Edge cases where one of the arguments is effectively empty.
is_incomplete_initialization = varmap isa DiffEqBase.NullParameters ||
varmap === nothing
Expand Down Expand Up @@ -97,7 +98,7 @@ function _varmap_to_vars(varmap::Dict, varlist; defaults = Dict(), check = false
varmap[p] = fixpoint_sub(v, varmap)
end

missingvars = setdiff(varlist, keys(varmap))
missingvars = setdiff(varlist, collect(keys(varmap)))
check && (isempty(missingvars) || throw_missingvars(missingvars))

out = [varmap[var] for var in varlist]
Expand All @@ -107,6 +108,17 @@ end
throw(ArgumentError("$vars are missing from the variable map."))
end

"""
$(SIGNATURES)
Intercept the call to `handle_varmap` and convert it to an ordered list if the user has
ModelingToolkit loaded, and the problem has a symbolic origin.
"""
function SciMLBase.handle_varmap(varmap, sys::AbstractSystem; field = :states, kwargs...)
out = varmap_to_vars(varmap, getfield(sys, field); kwargs...)
return out
end

struct IsHistory end
ishistory(x) = ishistory(unwrap(x))
ishistory(x::Symbolic) = getmetadata(x, IsHistory, false)
Expand Down
17 changes: 17 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,23 @@ for p in [prob1, prob14]
@test Set(Num.(parameters(sys)) .=> p.p) == Set([k₁ => 0.04, k₂ => 3e7, k₃ => 1e4])
@test Set(Num.(states(sys)) .=> p.u0) == Set([y₁ => 1, y₂ => 0, y₃ => 0])
end
# test remake with symbols
p3 = [k₁ => 0.05,
k₂ => 2e7,
k₃ => 1.1e4]
u01 = [y₁ => 1, y₂ => 1, y₃ => 1]
prob_pmap = remake(prob14; p = p3, u0 = u01)
prob_dpmap = remake(prob14; p = Dict(p3), u0 = Dict(u01))
for p in [prob_pmap, prob_dpmap]
@test Set(Num.(parameters(sys)) .=> p.p) == Set([k₁ => 0.05, k₂ => 2e7, k₃ => 1.1e4])
@test Set(Num.(states(sys)) .=> p.u0) == Set([y₁ => 1, y₂ => 1, y₃ => 1])
end
sol_pmap = solve(prob_pmap, Rodas5())
sol_dpmap = solve(prob_dpmap, Rodas5())

@test sol_pmap.u sol_dpmap.u

# test kwargs
prob2 = ODEProblem(sys, u0, tspan, p, jac = true)
prob3 = ODEProblem(sys, u0, tspan, p, jac = true, sparse = true)
@test prob3.f.jac_prototype isa SparseMatrixCSC
Expand Down

0 comments on commit 93a3aaf

Please sign in to comment.