Skip to content

Commit

Permalink
Merge pull request #1949 from SciML/myb_fb/clocks
Browse files Browse the repository at this point in the history
WIP: work toward merging clock processing with the common interface
  • Loading branch information
YingboMa committed Nov 21, 2022
2 parents 4ab3846 + 17be238 commit 6e49923
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 89 deletions.
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ export @variables, @parameters, @constants
export @named, @nonamespace, @namespace, extend, compose, complete
export debug_system

export Continuous, Discrete, sampletime, input_timedomain, output_timedomain
#export Continuous, Discrete, sampletime, input_timedomain, output_timedomain
#export has_discrete_domain, has_continuous_domain
#export is_discrete_domain, is_continuous_domain, is_hybrid_domain
export Sample, Hold, Shift, ShiftIndex
Expand Down
1 change: 1 addition & 0 deletions src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct Shift <: Operator
steps::Int
Shift(t, steps = 1) = new(value(t), steps)
end
normalize_to_differential(s::Shift) = Differential(s.t)^s.steps
function (D::Shift)(x, allow_zero = false)
!allow_zero && D.steps == 0 && return x
Term{symtype(x)}(D, Any[x])
Expand Down
3 changes: 2 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ for prop in [:eqs
:torn_matching
:tearing_state
:substitutions
:metadata]
:metadata
:discrete_subsystems]
fname1 = Symbol(:get_, prop)
fname2 = Symbol(:has_, prop)
@eval begin
Expand Down
72 changes: 43 additions & 29 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,16 @@ end
function ClockInference(ts::TearingState)
@unpack fullvars, structure = ts
@unpack graph = structure
eq_domain = Vector{TimeDomain}(undef, nsrcs(graph))
var_domain = Vector{TimeDomain}(undef, ndsts(graph))
eq_domain = TimeDomain[Continuous() for _ in 1:nsrcs(graph)]
var_domain = TimeDomain[Continuous() for _ in 1:ndsts(graph)]
inferred = BitSet()
for (i, v) in enumerate(fullvars)
d = get_time_domain(v)
if d isa Union{AbstractClock, Continuous}
push!(inferred, i)
dd = d
else
dd = Inferred()
var_domain[i] = dd
end
var_domain[i] = dd
end
ClockInference(ts, eq_domain, var_domain, inferred)
end
Expand All @@ -28,6 +26,7 @@ function infer_clocks!(ci::ClockInference)
@unpack ts, eq_domain, var_domain, inferred = ci
@unpack fullvars = ts
@unpack graph = ts.structure
isempty(inferred) && return ci
# TODO: add a graph type to do this lazily
var_graph = SimpleGraph(ndsts(graph))
for eq in 𝑠vertices(graph)
Expand Down Expand Up @@ -58,7 +57,6 @@ function infer_clocks!(ci::ClockInference)
vd = var_domain[v]
eqs = 𝑑neighbors(graph, v)
isempty(eqs) && continue
#eq = first(eqs)
for eq in eqs
eq_domain[eq] = vd
end
Expand Down Expand Up @@ -116,7 +114,6 @@ function split_system(ci::ClockInference)
@assert cid!==0 "Internal error! Variable $(fullvars[i]) doesn't have a inferred time domain."
var_to_cid[i] = cid
v = fullvars[i]
#TODO: remove Inferred*
if istree(v) && (o = operation(v)) isa Operator &&
input_timedomain(o) != output_timedomain(o)
push!(input_idxs[cid], i)
Expand Down Expand Up @@ -147,21 +144,28 @@ function split_system(ci::ClockInference)
@set! ts_i.structure.eq_to_diff = eq_to_diff
tss[id] = ts_i
end
return tss, inputs, continuous_id
return tss, inputs, continuous_id, id_to_clock
end

function generate_discrete_affect(syss, inputs, continuous_id, check_bounds = true)
function generate_discrete_affect(syss, inputs, continuous_id, id_to_clock;
checkbounds = true,
eval_module = @__MODULE__, eval_expression = true)
@static if VERSION < v"1.7"
error("The `generate_discrete_affect` function requires at least Julia 1.7")
end
out = Sym{Any}(:out)
appended_parameters = parameters(syss[continuous_id])
param_to_idx = Dict{Any, Int}(reverse(en) for en in enumerate(appended_parameters))
offset = length(appended_parameters)
affect_funs = []
svs = []
clocks = TimeDomain[]
for (i, (sys, input)) in enumerate(zip(syss, inputs))
i == continuous_id && continue
push!(clocks, id_to_clock[i])
subs = get_substitutions(sys)
assignments = map(s -> Assignment(s.lhs, s.rhs), subs.subs)
let_body = SetArray(!check_bounds, out, rhss(equations(sys)))
let_body = SetArray(!checkbounds, out, rhss(equations(sys)))
let_block = Let(assignments, let_body, false)
needed_cont_to_disc_obs = map(v -> arguments(v)[1], input)
# TODO: filter the needed ones
Expand Down Expand Up @@ -190,27 +194,37 @@ function generate_discrete_affect(syss, inputs, continuous_id, check_bounds = tr
cont_to_disc_idxs = (offset + 1):(offset += ni)
input_offset = offset
disc_range = (offset + 1):(offset += ns)
affect! = quote
function affect!(integrator, saved_values)
@unpack u, p, t = integrator
c2d_obs = $cont_to_disc_obs
d2c_obs = $disc_to_cont_obs
c2d_view = view(p, $cont_to_disc_idxs)
d2c_view = view(p, $disc_to_cont_idxs)
disc_state = view(p, $disc_range)
disc = $disc
# Write continuous info to discrete
# Write discrete info to continuous
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
copyto!(d2c_view, d2c_obs(disc_state, p, t))
push!(saved_values.t, t)
push!(saved_values.saveval, Base.@ntuple $ns i->p[$input_offset + i])
disc(disc_state, disc_state, p, t)
end
save_vec = Expr(:ref, :Float64)
for i in 1:ns
push!(save_vec.args, :(p[$(input_offset + i)]))
end
sv = SavedValues(Float64, NTuple{ns, Float64})
affect! = :(function (integrator, saved_values)
@unpack u, p, t = integrator
c2d_obs = $cont_to_disc_obs
d2c_obs = $disc_to_cont_obs
c2d_view = view(p, $cont_to_disc_idxs)
d2c_view = view(p, $disc_to_cont_idxs)
disc_state = view(p, $disc_range)
disc = $disc
# Write continuous info to discrete
# Write discrete info to continuous
copyto!(c2d_view, c2d_obs(integrator.u, p, t))
copyto!(d2c_view, d2c_obs(disc_state, p, t))
push!(saved_values.t, t)
push!(saved_values.saveval, $save_vec)
disc(disc_state, disc_state, p, t)
end)
sv = SavedValues(Float64, Vector{Float64})
push!(affect_funs, affect!)
push!(svs, sv)
end
return map(a -> toexpr(LiteralExpr(a)), affect_funs), svs, appended_parameters
if eval_expression
affects = map(affect_funs) do a
@RuntimeGeneratedFunction(eval_module, toexpr(LiteralExpr(a)))
end
else
affects = map(a -> toexpr(LiteralExpr(a)), affect_funs)
end
defaults = Dict{Any, Any}(v => 0.0 for v in Iterators.flatten(inputs))
return affects, clocks, svs, appended_parameters, defaults
end
38 changes: 34 additions & 4 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -686,6 +686,12 @@ function DiffEqBase.ODEProblem{false}(sys::AbstractODESystem, args...; kwargs...
ODEProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

struct DiscreteSaveAffect{F, S} <: Function
f::F
s::S
end
(d::DiscreteSaveAffect)(args...) = d.f(args..., d.s)

function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
Expand All @@ -698,14 +704,38 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
has_difference = has_difference,
check_length, kwargs...)
cbs = process_events(sys; callback, has_difference, kwargs...)
if has_discrete_subsystems(sys) && (dss = get_discrete_subsystems(sys)) !== nothing
affects, clocks, svs = ModelingToolkit.generate_discrete_affect(dss...)
discrete_cbs = map(affects, clocks, svs) do affect, clock, sv
if clock isa Clock
PeriodicCallback(DiscreteSaveAffect(affect, sv), clock.dt)
else
error("$clock is not a supported clock type.")
end
end
if cbs === nothing
if length(discrete_cbs) == 1
cbs = only(discrete_cbs)
else
cbs = CallbackSet(discrete_cbs...)
end
else
cbs = CallbackSet(cbs, discrete_cbs)
end
else
svs = nothing
end
kwargs = filter_kwargs(kwargs)
pt = something(get_metadata(sys), StandardODEProblem())

if cbs === nothing
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs...)
else
ODEProblem{iip}(f, u0, tspan, p, pt; callback = cbs, kwargs...)
kwargs1 = (;)
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end
if svs !== nothing
kwargs1 = merge(kwargs1, (disc_saved_values = svs,))
end
ODEProblem{iip}(f, u0, tspan, p, pt; kwargs1..., kwargs...)
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

Expand Down
10 changes: 7 additions & 3 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,17 @@ struct ODESystem <: AbstractODESystem
complete: if a model `sys` is complete, then `sys.x` no longer performs namespacing.
"""
complete::Bool
"""
discrete_subsystems: a list of discrete subsystems
"""
discrete_subsystems::Any

function ODESystem(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad,
jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults,
torn_matching, connector_type, preface, cevents,
devents, metadata = nothing, tearing_state = nothing,
substitutions = nothing, complete = false;
checks::Union{Bool, Int} = true)
substitutions = nothing, complete = false,
discrete_subsystems = nothing; checks::Union{Bool, Int} = true)
if checks == true || (checks & CheckComponents) > 0
check_variables(dvs, iv)
check_parameters(ps, iv)
Expand All @@ -145,7 +149,7 @@ struct ODESystem <: AbstractODESystem
new(tag, deqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac,
ctrl_jac, Wfact, Wfact_t, name, systems, defaults, torn_matching,
connector_type, preface, cevents, devents, metadata, tearing_state,
substitutions, complete)
substitutions, complete, discrete_subsystems)
end
end

Expand Down
58 changes: 54 additions & 4 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import ..ModelingToolkit: isdiffeq, var_from_nested_derivative, vars!, flatten,
value, InvalidSystemException, isdifferential, _iszero,
isparameter, isconstant,
independent_variables, SparseMatrixCLIL, AbstractSystem,
equations, isirreducible
equations, isirreducible, input_timedomain, TimeDomain
using ..BipartiteGraphs
import ..BipartiteGraphs: invview, complete
using Graphs
Expand Down Expand Up @@ -285,7 +285,7 @@ function TearingState(sys; quick_cancel = false, check = true)
!isdifferential(var) && (it = input_timedomain(var)) !== nothing
set_incidence = false
var = only(arguments(var))
var = setmetadata(var, ModelingToolkit.TimeDomain, it)
var = setmetadata(var, TimeDomain, it)
@goto ANOTHER_VAR
end
end
Expand Down Expand Up @@ -452,8 +452,59 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
end

# TODO: clean up
function merge_io(io, inputs)
isempty(inputs) && return io
if io === nothing
io = (inputs, [])
else
io = ([inputs; io[1]], io[2])
end
return io
end

function structural_simplify!(state::TearingState, io = nothing; simplify = false,
check_consistency = true, kwargs...)
if state.sys isa ODESystem
ci = ModelingToolkit.ClockInference(state)
ModelingToolkit.infer_clocks!(ci)
tss, inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
cont_io = merge_io(io, inputs[continuous_id])
sys, input_idxs = _structural_simplify!(tss[continuous_id], cont_io; simplify,
check_consistency,
kwargs...)
if length(tss) > 1
# TODO: rename it to something else
discrete_subsystems = Vector{ODESystem}(undef, length(tss))
# Note that the appended_parameters must agree with
# `generate_discrete_affect`!
appended_parameters = parameters(sys)
for (i, state) in enumerate(tss)
if i == continuous_id
discrete_subsystems[i] = sys
continue
end
dist_io = merge_io(io, inputs[i])
ss, = _structural_simplify!(state, dist_io; simplify, check_consistency,
kwargs...)
append!(appended_parameters, inputs[i], states(ss))
discrete_subsystems[i] = ss
end
@set! sys.discrete_subsystems = discrete_subsystems, inputs, continuous_id,
id_to_clock
@set! sys.ps = appended_parameters
@set! sys.defaults = merge(ModelingToolkit.defaults(sys),
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
end
else
sys, input_idxs = _structural_simplify!(state, io; simplify, check_consistency,
kwargs...)
end
has_io = io !== nothing
return has_io ? (sys, input_idxs) : sys
end

function _structural_simplify!(state::TearingState, io; simplify = false,
check_consistency = true, kwargs...)
has_io = io !== nothing
has_io && ModelingToolkit.markio!(state, io...)
state, input_idxs = ModelingToolkit.inputs_to_parameters!(state, io)
Expand All @@ -464,8 +515,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
sys = ModelingToolkit.dummy_derivative(sys, state, ag; simplify)
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
@set! sys.observed = ModelingToolkit.topsort_equations(observed(sys), fullstates)
ModelingToolkit.invalidate_cache!(sys)
return has_io ? (sys, input_idxs) : sys
ModelingToolkit.invalidate_cache!(sys), input_idxs
end

end # module
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -858,3 +858,5 @@ function fast_substitute(expr, pair::Pair)
symtype(expr);
metadata = metadata(expr))
end

normalize_to_differential(s) = s
13 changes: 12 additions & 1 deletion src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,17 @@ isoutput(x) = isvarkind(VariableOutput, x)
isirreducible(x) = isvarkind(VariableIrreducible, x)
state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0))::Float64

function default_toterm(x)
if istree(x) && (op = operation(x)) isa Operator
if !(op isa Differential)
x = normalize_to_differential(op)(arguments(x)...)
end
Symbolics.diff2term(x)
else
x
end
end

"""
$(SIGNATURES)
Expand All @@ -44,7 +55,7 @@ and creates the array of values in the correct order with default values when
applicable.
"""
function varmap_to_vars(varmap, varlist; defaults = Dict(), check = true,
toterm = Symbolics.diff2term, promotetoconcrete = nothing,
toterm = default_toterm, promotetoconcrete = nothing,
tofloat = true, use_union = false)
varlist = collect(map(unwrap, varlist))

Expand Down
Loading

0 comments on commit 6e49923

Please sign in to comment.