Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/API/codegen.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,5 +50,5 @@ ModelingToolkit.calculate_A_b
All code generation eventually calls `build_function_wrapper`.

```@docs
build_function_wrapper
ModelingToolkit.build_function_wrapper
```
6 changes: 4 additions & 2 deletions src/systems/diffeqs/basic_transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1037,9 +1037,11 @@ function respecialize(sys::AbstractSystem, mapping; all = false)
"""

if iscall(k)
op = operation(k)
op = operation(k)::BasicSymbolic
@assert !iscall(op)
op = SymbolicUtils.Sym{SymbolicUtils.FnType{Tuple{Any}, T}}(nameof(op))
args = arguments(k)
new_p = SymbolicUtils.term(op, args...; type = T)
new_p = op(args...)
else
new_p = SymbolicUtils.Sym{T}(getname(k))
end
Expand Down
12 changes: 9 additions & 3 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -763,8 +763,14 @@ function __remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = tru
oldbuf.discrete, newbuf.discrete)
@set! newbuf.constant = narrow_buffer_type_and_fallback_undefs.(
oldbuf.constant, newbuf.constant)
@set! newbuf.nonnumeric = narrow_buffer_type_and_fallback_undefs.(
oldbuf.nonnumeric, newbuf.nonnumeric)
for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric)
for i in eachindex(oldv)
isassigned(newv, i) && continue
newv[i] = oldv[i]
end
end
@set! newbuf.nonnumeric = Tuple(
typeof(oldv)(newv) for (oldv, newv) in zip(oldbuf.nonnumeric, newbuf.nonnumeric))
if !ArrayInterface.ismutable(oldbuf)
@set! newbuf.tunable = similar_type(oldbuf.tunable, eltype(newbuf.tunable))(newbuf.tunable)
@set! newbuf.initials = similar_type(oldbuf.initials, eltype(newbuf.initials))(newbuf.initials)
Expand Down Expand Up @@ -820,7 +826,7 @@ function SciMLBase.create_parameter_timeseries_collection(
isempty(ps.discrete) && return nothing
num_discretes = only(blocksize(ps.discrete[1]))
buffers = []
partition_type = Tuple{(Vector{eltype(buf)} for buf in ps.discrete)...}
partition_type = Tuple{(typeof(parent(buf)) for buf in ps.discrete)...}
for i in 1:num_discretes
ts = eltype(tspan)[]
us = NestedGetIndex{partition_type}[]
Expand Down
73 changes: 46 additions & 27 deletions src/systems/problem_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -701,9 +701,10 @@ function.
Note that the getter ONLY works for problem-like objects, since it generates an observed
function. It does NOT work for solutions.
"""
Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector)
Base.@nospecializeinfer function concrete_getu(indp, syms; eval_expression, eval_module)
@nospecialize
obsfn = build_explicit_observed_function(indp, syms; wrap_delays = false)
obsfn = build_explicit_observed_function(
indp, syms; wrap_delays = false, eval_expression, eval_module)
return ObservedWrapper{is_time_dependent(indp)}(obsfn)
end

Expand Down Expand Up @@ -757,7 +758,8 @@ takes a value provider of `srcsys` and a value provider of `dstsys` and returns
- `p_constructor`: The `p_constructor` argument to `process_SciMLProblem`.
"""
function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::AbstractSystem;
initials = false, unwrap_initials = false, p_constructor = identity)
initials = false, unwrap_initials = false, p_constructor = identity,
eval_expression = false, eval_module = @__MODULE__)
_p_constructor = p_constructor
p_constructor = PConstructorApplicator(p_constructor)
# if we call `getu` on this (and it were able to handle empty tuples) we get the
Expand All @@ -773,7 +775,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
tunable_getter = if isempty(tunable_syms)
Returns(SizedVector{0, Float64}())
else
p_constructor ∘ concrete_getu(srcsys, tunable_syms)
p_constructor ∘ concrete_getu(srcsys, tunable_syms; eval_expression, eval_module)
end
initials_getter = if initials && !isempty(syms[2])
initsyms = Vector{Any}(syms[2])
Expand All @@ -792,7 +794,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
end
end
end
p_constructor ∘ concrete_getu(srcsys, initsyms)
p_constructor ∘ concrete_getu(srcsys, initsyms; eval_expression, eval_module)
else
Returns(SizedVector{0, Float64}())
end
Expand All @@ -810,7 +812,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
# tuple of `BlockedArray`s
Base.Fix2(Broadcast.BroadcastFunction(BlockedArray), blockarrsizes) ∘
Base.Fix1(broadcast, p_constructor) ∘
getu(srcsys, syms[3])
concrete_getu(srcsys, syms[3]; eval_expression, eval_module)
end
const_getter = if syms[4] == ()
Returns(())
Expand All @@ -826,7 +828,8 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac
end)
# nonnumerics retain the assigned buffer type without narrowing
Base.Fix1(broadcast, _p_constructor) ∘
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘ getu(srcsys, syms[5])
Base.Fix1(Broadcast.BroadcastFunction(call), buftypes) ∘
concrete_getu(srcsys, syms[5]; eval_expression, eval_module)
end
getters = (
tunable_getter, initials_getter, discs_getter, const_getter, nonnumeric_getter)
Expand All @@ -853,14 +856,19 @@ Construct a `ReconstructInitializeprob` which reconstructs the `u0` and `p` of `
with values from `srcsys`.
"""
function ReconstructInitializeprob(
srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity)
srcsys::AbstractSystem, dstsys::AbstractSystem; u0_constructor = identity, p_constructor = identity,
eval_expression = false, eval_module = @__MODULE__)
@assert is_initializesystem(dstsys)
ugetter = u0_constructor ∘ getu(srcsys, unknowns(dstsys))
ugetter = u0_constructor ∘
concrete_getu(srcsys, unknowns(dstsys); eval_expression, eval_module)
if is_split(dstsys)
pgetter = get_mtkparameters_reconstructor(srcsys, dstsys; p_constructor)
pgetter = get_mtkparameters_reconstructor(
srcsys, dstsys; p_constructor, eval_expression, eval_module)
else
syms = parameters(dstsys)
pgetter = let inner = concrete_getu(srcsys, syms), p_constructor = p_constructor
pgetter = let inner = concrete_getu(srcsys, syms; eval_expression, eval_module),
p_constructor = p_constructor

function _getter2(valp, initprob)
p_constructor(inner(valp))
end
Expand Down Expand Up @@ -924,18 +932,20 @@ Given `sys` and its corresponding initialization system `initsys`, return the
`initializeprobpmap` function in `OverrideInitData` for the systems.
"""
function construct_initializeprobpmap(
sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity)
sys::AbstractSystem, initsys::AbstractSystem; p_constructor = identity, eval_expression, eval_module)
@assert is_initializesystem(initsys)
if is_split(sys)
return let getter = get_mtkparameters_reconstructor(
initsys, sys; initials = true, unwrap_initials = true, p_constructor)
initsys, sys; initials = true, unwrap_initials = true, p_constructor,
eval_expression, eval_module)
function initprobpmap_split(prob, initsol)
getter(initsol, prob)
end
end
else
return let getter = getu(initsys, parameters(sys; initial_parameters = true)),
p_constructor = p_constructor
return let getter = concrete_getu(
initsys, parameters(sys; initial_parameters = true);
eval_expression, eval_module), p_constructor = p_constructor

function initprobpmap_nosplit(prob, initsol)
return p_constructor(getter(initsol))
Expand Down Expand Up @@ -1039,14 +1049,14 @@ struct GetUpdatedU0{GG, GIU}
get_initial_unknowns::GIU
end

function GetUpdatedU0(sys::AbstractSystem, initsys::AbstractSystem, op::AbstractDict)
function GetUpdatedU0(sys::AbstractSystem, initprob::SciMLBase.AbstractNonlinearProblem, op::AbstractDict)
dvs = unknowns(sys)
eqs = equations(sys)
guessvars = trues(length(dvs))
for (i, var) in enumerate(dvs)
guessvars[i] = !isequal(get(op, var, nothing), Initial(var))
end
get_guessvars = getu(initsys, dvs[guessvars])
get_guessvars = getu(initprob, dvs[guessvars])
get_initial_unknowns = getu(sys, Initial.(dvs))
return GetUpdatedU0(guessvars, get_guessvars, get_initial_unknowns)
end
Expand Down Expand Up @@ -1108,7 +1118,7 @@ function maybe_build_initialization_problem(
guesses, missing_unknowns; implicit_dae = false,
time_dependent_init = is_time_dependent(sys), u0_constructor = identity,
p_constructor = identity, floatT = Float64, initialization_eqs = [],
use_scc = true, kwargs...)
use_scc = true, eval_expression = false, eval_module = @__MODULE__, kwargs...)
guesses = merge(ModelingToolkit.guesses(sys), todict(guesses))

if t === nothing && is_time_dependent(sys)
Expand All @@ -1117,7 +1127,7 @@ function maybe_build_initialization_problem(

initializeprob = ModelingToolkit.InitializationProblem{iip}(
sys, t, op; guesses, time_dependent_init, initialization_eqs,
use_scc, u0_constructor, p_constructor, kwargs...)
use_scc, u0_constructor, p_constructor, eval_expression, eval_module, kwargs...)
if state_values(initializeprob) !== nothing
_u0 = state_values(initializeprob)
if ArrayInterface.ismutable(_u0)
Expand Down Expand Up @@ -1145,15 +1155,16 @@ function maybe_build_initialization_problem(
initializeprob = remake(initializeprob; p = initp)

get_initial_unknowns = if time_dependent_init
GetUpdatedU0(sys, initializeprob.f.sys, op)
GetUpdatedU0(sys, initializeprob, op)
else
nothing
end
meta = InitializationMetadata(
copy(op), copy(guesses), Vector{Equation}(initialization_eqs),
use_scc, time_dependent_init,
ReconstructInitializeprob(
sys, initializeprob.f.sys; u0_constructor, p_constructor),
sys, initializeprob.f.sys; u0_constructor,
p_constructor, eval_expression, eval_module),
get_initial_unknowns, SetInitialUnknowns(sys))

if time_dependent_init
Expand All @@ -1172,10 +1183,9 @@ function maybe_build_initialization_problem(
initializeprobpmap = nothing
else
initializeprobpmap = construct_initializeprobpmap(
sys, initializeprob.f.sys; p_constructor)
sys, initializeprob.f.sys; p_constructor, eval_expression, eval_module)
end

reqd_syms = parameter_symbols(initializeprob)
# we still want the `initialization_data` because it helps with `remake`
if initializeprobmap === nothing && initializeprobpmap === nothing
update_initializeprob! = nothing
Expand All @@ -1186,7 +1196,9 @@ function maybe_build_initialization_problem(
filter!(punknowns) do p
is_parameter_solvable(p, op, defs, guesses) && get(op, p, missing) === missing
end
pvals = getu(initializeprob, punknowns)(initializeprob)
# See comment below for why `getu` is not used here.
_pgetter = build_explicit_observed_function(initializeprob.f.sys, punknowns)
pvals = _pgetter(state_values(initializeprob), parameter_values(initializeprob))
for (p, pval) in zip(punknowns, pvals)
p = unwrap(p)
op[p] = pval
Expand All @@ -1198,7 +1210,13 @@ function maybe_build_initialization_problem(
end

if time_dependent_init
uvals = getu(initializeprob, collect(missing_unknowns))(initializeprob)
# We can't use `getu` here because that goes to `SII.observed`, which goes to
# `ObservedFunctionCache` which uses `eval_expression` and `eval_module`. If
# `eval_expression == true`, this then runs into world-age issues. Building an
# RGF here is fine since it is always discarded. We can't use `eval_module` for
# the RGF since the user may not have run RGF's init.
_ugetter = build_explicit_observed_function(initializeprob.f.sys, collect(missing_unknowns))
uvals = _ugetter(state_values(initializeprob), parameter_values(initializeprob))
for (v, val) in zip(missing_unknowns, uvals)
op[v] = val
end
Expand Down Expand Up @@ -1461,7 +1479,7 @@ function process_SciMLProblem(
if is_time_dependent(sys) && t0 === nothing
t0 = zero(floatT)
end
initialization_data = SciMLBase.remake_initialization_data(
initialization_data = @invokelatest SciMLBase.remake_initialization_data(
sys, kwargs, u0, t0, p, u0, p)
kwargs = merge(kwargs, (; initialization_data))
end
Expand Down Expand Up @@ -1773,7 +1791,8 @@ Construct SciMLProblem `T` with positional arguments `args` and keywords `kwargs
"""
function maybe_codegen_scimlproblem(::Type{Val{false}}, T, args::NamedTuple; kwargs...)
# Call `remake` so it runs initialization if it is trivial
remake(T(args...; kwargs...))
# Use `@invokelatest` to avoid world-age issues with `eval_expression = true`
@invokelatest remake(T(args...; kwargs...))
end

"""
Expand Down
11 changes: 6 additions & 5 deletions test/basic_transformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,12 @@ foofn(x) = 4

@testset "`respecialize`" begin
@parameters p::AbstractFoo p2(t)::AbstractFoo = p q[1:2]::AbstractFoo r
rp,
rp2 = let
only(@parameters p::Bar),
SymbolicUtils.term(operation(p2), arguments(p2)...; type = Baz)
end
rp = only(let p = nothing
@parameters p::Bar
end)
rp2 = only(let p2 = nothing
@parameters p2(t)::Baz
end)
@variables x(t) = 1.0
@named sys1 = System([D(x) ~ foofn(p) + foofn(p2) + x], t, [x], [p, p2, q, r])

Expand Down
2 changes: 1 addition & 1 deletion test/initializationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -826,7 +826,7 @@ end
@test integ.ps[param]≈val rtol=1e-5
# some algorithms are a little temperamental
sol = solve(prob, alg)
@test sol.ps[param]≈val rtol=1e-5
@test sol.ps[param]≈val rtol=1e-5 broken=(alg===SimpleNewtonRaphson())
@test SciMLBase.successful_retcode(sol)
end

Expand Down
13 changes: 12 additions & 1 deletion test/mtkparameters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ ps = MTKParameters(
(BlockedArray([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [3, 3]),
BlockedArray(falses(1), [1, 0])),
(), (), ())
@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, Vector{Bool}}
@test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, BitVector}
tsidx1 = 1
tsidx2 = 2
@test length(ps.discrete[1][Block(tsidx1)]) == 3
Expand All @@ -368,3 +368,14 @@ with_updated_parameter_timeseries_values(
sys, ps, tsidx1 => ModelingToolkit.NestedGetIndex(([10.0, 11.0, 12.0], [false])))
@test ps.discrete[1][Block(tsidx1)] == [10.0, 11.0, 12.0]
@test ps.discrete[2][Block(tsidx1)][] == false

@testset "Avoid specialization of nonnumeric parameters on `remake_buffer`" begin
@variables x(t)
@parameters p::Any
@named sys = System(D(x) ~ x, t, [x], [p])
sys = complete(sys)
ps = MTKParameters(sys, [p => 1.0])
@test ps.nonnumeric isa Tuple{Vector{Any}}
ps2 = remake_buffer(sys, ps, [p], [:a])
@test ps2.nonnumeric isa Tuple{Vector{Any}}
end
3 changes: 3 additions & 0 deletions test/precompile_test.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Test
using ModelingToolkit
using OrdinaryDiffEqDefault

using Distributed

Expand Down Expand Up @@ -38,3 +39,5 @@ ODEPrecompileTest.f_eval_bad(u, p, 0.1)
@test parentmodule(typeof(ODEPrecompileTest.f_eval_good.f.f_oop)) ==
ODEPrecompileTest
@test ODEPrecompileTest.f_eval_good(u, p, 0.1) == [4, 0, -16]

@test_nowarn solve(ODEPrecompileTest.prob_eval)
20 changes: 20 additions & 0 deletions test/precompile_test/ODEPrecompileTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,24 @@ const f_eval_bad = system(; eval_expression = true, eval_module = @__MODULE__)
# Change the module the eval'd function is eval'd into to be the containing module,
# which should make it be in the package image
const f_eval_good = system(; eval_expression = true, eval_module = @__MODULE__)

function problem(; kwargs...)
# Define some variables
@independent_variables t
@parameters σ ρ β
@variables x(t) y(t) z(t)
D = Differential(t)

# Define a differential equation
eqs = [D(x) ~ σ * (y - x),
D(y) ~ x * (ρ - z) - y,
D(z) ~ x * y - β * z]

@named de = System(eqs, t)
de = complete(de)
return ODEProblem(de, [x => 1, y => 0, z => 0, σ => 10, ρ => 28, β => 8/3], (0.0, 5.0); kwargs...)
end

const prob_eval = problem(; eval_expression = true, eval_module = @__MODULE__)

end
Loading