Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Work toward merging clock processing with the common interface #1949

Merged
merged 11 commits into from
Nov 21, 2022
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ export @variables, @parameters, @constants
export @named, @nonamespace, @namespace, extend, compose, complete
export debug_system

export Continuous, Discrete, sampletime, input_timedomain, output_timedomain
#export Continuous, Discrete, sampletime, input_timedomain, output_timedomain
#export has_discrete_domain, has_continuous_domain
#export is_discrete_domain, is_continuous_domain, is_hybrid_domain
export Sample, Hold, Shift, ShiftIndex
Expand Down
1 change: 1 addition & 0 deletions src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct Shift <: Operator
steps::Int
Shift(t, steps = 1) = new(value(t), steps)
end
normalize_to_differential(s::Shift) = Differential(s.t)^s.steps
function (D::Shift)(x, allow_zero = false)
!allow_zero && D.steps == 0 && return x
Term{symtype(x)}(D, Any[x])
Expand Down
3 changes: 2 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ for prop in [:eqs
:torn_matching
:tearing_state
:substitutions
:metadata]
:metadata
:discrete_subsystems]
fname1 = Symbol(:get_, prop)
fname2 = Symbol(:has_, prop)
@eval begin
Expand Down
72 changes: 43 additions & 29 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,16 @@ end
function ClockInference(ts::TearingState)
@unpack fullvars, structure = ts
@unpack graph = structure
eq_domain = Vector{TimeDomain}(undef, nsrcs(graph))
var_domain = Vector{TimeDomain}(undef, ndsts(graph))
eq_domain = TimeDomain[Continuous() for _ in 1:nsrcs(graph)]
var_domain = TimeDomain[Continuous() for _ in 1:ndsts(graph)]
inferred = BitSet()
for (i, v) in enumerate(fullvars)
d = get_time_domain(v)
if d isa Union{AbstractClock, Continuous}
push!(inferred, i)
dd = d
else
dd = Inferred()
var_domain[i] = dd
end
var_domain[i] = dd
end
ClockInference(ts, eq_domain, var_domain, inferred)
end
Expand All @@ -28,6 +26,7 @@ function infer_clocks!(ci::ClockInference)
@unpack ts, eq_domain, var_domain, inferred = ci
@unpack fullvars = ts
@unpack graph = ts.structure
isempty(inferred) && return ci
# TODO: add a graph type to do this lazily
var_graph = SimpleGraph(ndsts(graph))
for eq in 𝑠vertices(graph)
Expand Down Expand Up @@ -58,7 +57,6 @@ function infer_clocks!(ci::ClockInference)
vd = var_domain[v]
eqs = 𝑑neighbors(graph, v)
isempty(eqs) && continue
#eq = first(eqs)
for eq in eqs
eq_domain[eq] = vd
end
Expand Down Expand Up @@ -116,7 +114,6 @@ function split_system(ci::ClockInference)
@assert cid!==0 "Internal error! Variable $(fullvars[i]) doesn't have a inferred time domain."
var_to_cid[i] = cid
v = fullvars[i]
#TODO: remove Inferred*
if istree(v) && (o = operation(v)) isa Operator &&
input_timedomain(o) != output_timedomain(o)
push!(input_idxs[cid], i)
Expand Down Expand Up @@ -147,21 +144,28 @@ function split_system(ci::ClockInference)
@set! ts_i.structure.eq_to_diff = eq_to_diff
tss[id] = ts_i
end
return tss, inputs, continuous_id
return tss, inputs, continuous_id, id_to_clock
end

function generate_discrete_affect(syss, inputs, continuous_id, check_bounds = true)
function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
checkbounds = true,
eval_module = @__MODULE__, eval_expression = true)
@static if VERSION < v"1.7"
error("The `generate_discrete_affect` function requires at least Julia 1.7")
end
out = Sym{Any}(:out)
appended_parameters = parameters(syss[continuous_id])
param_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
offset = length(appended_parameters)
affect_funs = []
svs = []
clocks = TimeDomain[]
for (i, (sys, input)) in enumerate(zip(syss, inputs))
i == continuous_id && continue
push!(clocks, id_to_clock[i])
subs = get_substitutions(sys)
assignments = map(s -> Assignment(s.lhs, s.rhs), subs.subs)
let_body = SetArray(!check_bounds, out, rhss(equations(sys)))
let_body = SetArray(!checkbounds, out, rhss(equations(sys)))
let_block = Let(assignments, let_body, false)
needed_cont_to_disc_obs = map(v -> arguments(v)[1], input)
# TODO: filter the needed ones
Expand Down Expand Up @@ -190,27 +194,37 @@ function generate_discrete_affect(syss, inputs, continuous_id, check_bounds = tr
cont_to_disc_idxs = (offset + 1):(offset += ni)
input_offset = offset
disc_range = (offset + 1):(offset += ns)
affect! = quote
function affect!(integrator, saved_values)
@unpack u, p, t = integrator
c2d_obs = $cont_to_disc_obs
d2c_obs = $disc_to_cont_obs
c2d_view = view(p, $cont_to_disc_idxs)
d2c_view = view(p, $disc_to_cont_idxs)
disc_state = view(p, $disc_range)
disc = $disc
# Write continuous info to discrete
# Write discrete info to continuous
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
copyto!(d2c_view, d2c_obs(disc_state, p, t))
push!(saved_values.t, t)
push!(saved_values.saveval, Base.@ntuple $ns i->p[$input_offset + i])
disc(disc_state, disc_state, p, t)
end
save_vec = Expr(:ref, :Float64)
for i in 1:ns
push!(save_vec.args, :(p[$(input_offset + i)]))
end
sv = SavedValues(Float64, NTuple{ns, Float64})
affect! = :(function (integrator, saved_values)
@unpack u, p, t = integrator
c2d_obs = $cont_to_disc_obs
d2c_obs = $disc_to_cont_obs
c2d_view = view(p, $cont_to_disc_idxs)
d2c_view = view(p, $disc_to_cont_idxs)
disc_state = view(p, $disc_range)
disc = $disc
# Write continuous info to discrete
# Write discrete info to continuous
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
copyto!(d2c_view, d2c_obs(disc_state, p, t))
push!(saved_values.t, t)
push!(saved_values.saveval, $save_vec)
disc(disc_state, disc_state, p, t)
end)
sv = SavedValues(Float64, Vector{Float64})
push!(affect_funs, affect!)
push!(svs, sv)
end
return map(a -> toexpr(LiteralExpr(a)), affect_funs), svs, appended_parameters
if eval_expression
affects = map(affect_funs) do a
@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a)))
end
else
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
end
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
return affects, clocks, svs, appended_parameters, defaults
end
38 changes: 34 additions & 4 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,12 @@ function DiffEqBase.ODEProblem{false}(sys::AbstractODESystem, args...; kwargs...
ODEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

struct DiscreteSaveAffect{F, S} <: Function
f::F
s::S
end
(d::DiscreteSaveAffect)(args...) = d.f(args..., d.s)

function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
Expand All @@ -698,14 +704,38 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
has_difference = has_difference,
check_length, kwargs...)
cbs = process_events(sys; callback, has_difference, kwargs...)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
if clock isa Clock
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
else
error("$clock is not a supported clock type.")
end
end
if cbs === nothing
if length(discrete_cbs) == 1
cbs = only(discrete_cbs)
else
cbs = CallbackSet(discrete_cbs...)
end
else
cbs = CallbackSet(cbs, discrete_cbs)
end
else
svs = nothing
end
kwargs = filter_kwargs(kwargs)
pt = something(get_metadata(sys), StandardODEProblem())

if cbs === nothing
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs...)
else
ODEProblem{iip}(f, u0, tspan, p, pt; callback = cbs, kwargs...)
kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

Expand Down
10 changes: 7 additions & 3 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,17 @@ struct ODESystem <: AbstractODESystem
complete: if a model `sys` is complete, then `sys.x` no longer performs namespacing.
"""
complete::Bool
"""
discrete_subsystems: a list of discrete subsystems
"""
discrete_subsystems::Any

function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
torn_matching, connector_type, preface, cevents,
devents, metadata = nothing, tearing_state = nothing,
substitutions = nothing, complete = false;
checks::Union{Bool, Int} = true)
substitutions = nothing, complete = false,
discrete_subsystems = nothing; checks::Union{Bool, Int} = true)
if checks == true || (checks & CheckComponents) > 0
check_variables(dvs, iv)
check_parameters(ps, iv)
Expand All @@ -145,7 +149,7 @@ struct ODESystem <: AbstractODESystem
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
connector_type, preface, cevents, devents, metadata, tearing_state,
substitutions, complete)
substitutions, complete, discrete_subsystems)
end
end

Expand Down
58 changes: 54 additions & 4 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten,
value, InvalidSystemException, isdifferential, _iszero,
isparameter, isconstant,
independent_variables, SparseMatrixCLIL, AbstractSystem,
equations, isirreducible
equations, isirreducible, input_timedomain, TimeDomain
using ..BipartiteGraphs
import ..BipartiteGraphs: invview, complete
using Graphs
Expand Down Expand Up @@ -285,7 +285,7 @@ function TearingState(sys; quick_cancel = false, check = true)
!isdifferential(var) && (it = input_timedomain(var)) !== nothing
set_incidence = false
var = only(arguments(var))
var = setmetadata(var, ModelingToolkit.TimeDomain, it)
var = setmetadata(var, TimeDomain, it)
@goto ANOTHER_VAR
end
end
Expand Down Expand Up @@ -452,8 +452,59 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
end

# TODO: clean up
function merge_io(io, inputs)
isempty(inputs) && return io
if io === nothing
io = (inputs, [])
else
io = ([inputs; io[1]], io[2])
end
return io
end

function structural_simplify!(state::TearingState, io = nothing; simplify = false,
check_consistency = true, kwargs...)
if state.sys isa ODESystem
ci = ModelingToolkit.ClockInference(state)
ModelingToolkit.infer_clocks!(ci)
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
cont_io = merge_io(io, inputs[continuous_id])
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify,
check_consistency,
kwargs...)
if length(tss) > 1
# TODO: rename it to something else
discrete_subsystems = Vector{ODESystem}(undef, length(tss))
# Note that the appended_parameters must agree with
# `generate_discrete_affect`!
appended_parameters = parameters(sys)
for (i, state) in enumerate(tss)
if i == continuous_id
discrete_subsystems[i] = sys
continue
end
dist_io = merge_io(io, inputs[i])
ss, = _structural_simplify!(state, dist_io; simplify, check_consistency,
kwargs...)
append!(appended_parameters, inputs[i], states(ss))
discrete_subsystems[i] = ss
end
@set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id,
id_to_clock
@set! sys.ps = appended_parameters
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
end
else
sys, input_idxs = _structural_simplify!(state, io; simplify, check_consistency,
kwargs...)
end
has_io = io !== nothing
return has_io ? (sys, input_idxs) : sys
end

function _structural_simplify!(state::TearingState, io; simplify = false,
check_consistency = true, kwargs...)
has_io = io !== nothing
has_io && ModelingToolkit.markio!(state, io...)
state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io)
Expand All @@ -464,8 +515,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
sys = ModelingToolkit.dummy_derivative(sys, state, ag; simplify)
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
@set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullstates)
ModelingToolkit.invalidate_cache!(sys)
return has_io ? (sys, input_idxs) : sys
ModelingToolkit.invalidate_cache!(sys), input_idxs
end

end # module
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -858,3 +858,5 @@ function fast_substitute(expr, pair::Pair)
symtype(expr);
metadata = metadata(expr))
end

normalize_to_differential(s) = s
13 changes: 12 additions & 1 deletion src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ isoutput(x) = isvarkind(VariableOutput, x)
isirreducible(x) = isvarkind(VariableIrreducible, x)
state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0))::Float64

function default_toterm(x)
if istree(x) && (op = operation(x)) isa Operator
if !(op isa Differential)
x = normalize_to_differential(op)(arguments(x)...)
end
Symbolics.diff2term(x)
else
x
end
end

"""
$(SIGNATURES)

Expand All @@ -44,7 +55,7 @@ and creates the array of values in the correct order with default values when
applicable.
"""
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
toterm = Symbolics.diff2term, promotetoconcrete = nothing,
toterm = default_toterm, promotetoconcrete = nothing,
tofloat = true, use_union = false)
varlist = collect(map(unwrap, varlist))

Expand Down
Loading