Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
11477b3
`root_eqs` are stored in ODESystem
baggepinnen Nov 8, 2021
1d52025
add docs and refactor `gen_nlsolve`
baggepinnen Nov 8, 2021
2b9ca8c
ODEProblem generates callback to handle root-finding equations
baggepinnen Nov 8, 2021
a7e1cfe
include testset
baggepinnen Nov 8, 2021
b5b5778
update dict key storing callback
baggepinnen Nov 9, 2021
e58ddb3
use build_function instead of gen_nlsolve
baggepinnen Nov 9, 2021
ce89be9
handle `get_root_eqs` separately
baggepinnen Nov 9, 2021
1ef8880
add test for multiple root_eqs on the same state
baggepinnen Nov 9, 2021
d0cd1b8
function name change
baggepinnen Nov 9, 2021
6e3758e
add documentation entry
baggepinnen Nov 9, 2021
43e71fb
determine length of vector callback based on num equations
baggepinnen Nov 9, 2021
01f3626
introduce EqAffectPair type
baggepinnen Nov 10, 2021
4e21ae5
add bouncing ball example
baggepinnen Nov 10, 2021
feae414
affect given by equations
baggepinnen Nov 10, 2021
6305d38
rename root_eqs -> events
baggepinnen Nov 10, 2021
5467944
test and document multiple events
baggepinnen Nov 10, 2021
2835bac
rename to SymbolicContinuousCallback
baggepinnen Nov 10, 2021
2367ff5
add test for multi-variable affect
baggepinnen Nov 10, 2021
b800721
fix argument check
baggepinnen Nov 10, 2021
8c8c581
add get_callback function
baggepinnen Nov 10, 2021
eca0f9b
handle and merge user provided callbacks in ODEProblem
baggepinnen Nov 10, 2021
858d7b7
better handling of callback kwarg
baggepinnen Nov 11, 2021
73d98ac
fuse continuous events into VectorContinuousCallback
baggepinnen Nov 11, 2021
22ff700
better check for empty callbacks
baggepinnen Nov 11, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/basics/AbstractSystem.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ Optionally, a system could have:

- `observed(sys)`: All observed equations of the system and its subsystems.
- `get_observed(sys)`: Observed equations of the current-level system.
- `get_continuous_events(sys)`: `SymbolicContinuousCallback`s of the current-level system.
- `get_defaults(sys)`: A `Dict` that maps variables into their default values.
- `independent_variables(sys)`: The independent variables of a system.
- `get_noiseeqs(sys)`: Noise equations of the current-level system.
Expand Down
91 changes: 91 additions & 0 deletions docs/src/basics/Composition.md
Original file line number Diff line number Diff line change
Expand Up @@ -250,3 +250,94 @@ strongly connected components calculated during the process of simplification
as the basis for building pre-simplified nonlinear systems in the implicit
solving. In summary: these problems are structurally modified, but could be
more efficient and more stable.

## Components with discontinuous dynamics
When modeling, e.g., impacts, saturations or Coulomb friction, the dynamic equations are discontinuous in either the state or one of its derivatives. This causes the solver to take very small steps around the discontinuity, and sometimes leads to early stopping due to `dt <= dt_min`. The correct way to handle such dynamics is to tell the solver about the discontinuity be means of a root-finding equation. [`ODEsystem`](@ref)s accept a keyword argument `continuous_events`
```
ODESystem(eqs, ...; continuous_events::Vector{Equation})
ODESystem(eqs, ...; continuous_events::Pair{Vector{Equation}, Vector{Equation}})
```
where equations can be added that evaluate to 0 at discontinuities.

To model events that have an affect on the state, provide `events::Pair{Vector{Equation}, Vector{Equation}}` where the first entry in the pair is a vector of equations describing event conditions, and the second vector of equations describe the affect on the state. The affect equations must be on the form
```
single_state_variable ~ expression_involving_any_variables
```

### Example: Friction
The system below illustrates how this can be used to model Coulomb friction
```julia
using ModelingToolkit, OrdinaryDiffEq, Plots
function UnitMassWithFriction(k; name)
@variables t x(t)=0 v(t)=0
D = Differential(t)
eqs = [
D(x) ~ v
D(v) ~ sin(t) - k*sign(v) # f = ma, sinusoidal force acting on the mass, and Coulomb friction opposing the movement
]
ODESystem(eqs, t, continuous_events=[v ~ 0], name=name) # when v = 0 there is a discontinuity
end
@named m = UnitMassWithFriction(0.7)
prob = ODEProblem(m, Pair[], (0, 10pi))
sol = solve(prob, Tsit5())
plot(sol)
```

### Example: Bouncing ball
In the documentation for DifferentialEquations, we have an example where a bouncing ball is simulated using callbacks which has an `affect!` on the state. We can model the same system using ModelingToolkit like this

```julia
@variables t x(t)=1 v(t)=0
D = Differential(t)

root_eqs = [x ~ 0] # the event happens at the ground x(t) = 0
affect = [v ~ -v] # the effect is that the velocity changes sign

@named ball = ODESystem([
D(x) ~ v
D(v) ~ -9.8
], t, continuous_events = root_eqs => affect) # equation => affect

ball = structural_simplify(ball)

tspan = (0.0,5.0)
prob = ODEProblem(ball, Pair[], tspan)
sol = solve(prob,Tsit5())
@assert 0 <= minimum(sol[x]) <= 1e-10 # the ball never went through the floor but got very close
plot(sol)
```

### Test bouncing ball in 2D with walls
Multiple events? No problem! This example models a bouncing ball in 2D that is enclosed by two walls at $y = \pm 1.5$.
```julia
@variables t x(t)=1 y(t)=0 vx(t)=0 vy(t)=2
D = Differential(t)

continuous_events = [ # This time we have a vector of pairs
[x ~ 0] => [vx ~ -vx]
[y ~ -1.5, y ~ 1.5] => [vy ~ -vy]
]

@named ball = ODESystem([
D(x) ~ vx,
D(y) ~ vy,
D(vx) ~ -9.8-0.1vx, # gravity + some small air resistance
D(vy) ~ -0.1vy,
], t, continuous_events = continuous_events)


ball = structural_simplify(ball)

tspan = (0.0,10.0)
prob = ODEProblem(ball, Pair[], tspan)

sol = solve(prob,Tsit5())
@assert 0 <= minimum(sol[x]) <= 1e-10 # the ball never went through the floor but got very close
@assert minimum(sol[y]) > -1.5 # check wall conditions
@assert maximum(sol[y]) < 1.5 # check wall conditions

tv = sort([LinRange(0, 10, 200); sol.t])
plot(sol(tv)[y], sol(tv)[x], line_z=tv)
vline!([-1.5, 1.5], l=(:black, 5), primary=false)
hline!([0], l=(:black, 5), primary=false)
```
35 changes: 27 additions & 8 deletions src/structural_transformation/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,39 @@ function partitions_dag(s::SystemStructure)
sparse(I, J, true, n, n)
end

function gen_nlsolve(sys, eqs, vars; checkbounds=true)
@assert !isempty(vars)
@assert length(eqs) == length(vars)
"""
exprs = gen_nlsolve(eqs::Vector{Equation}, vars::Vector, u0map::Dict; checkbounds = true)

Generate `SymbolicUtils` expressions for a root-finding function based on `eqs`,
as well as a call to the root-finding solver.

`exprs` is a two element vector
```
exprs = [fname = f, numerical_nlsolve(fname, ...)]
```

# Arguments:
- `eqs`: Equations to find roots of.
- `vars`: ???
- `u0map`: A `Dict` which maps variables in `eqs` to values, e.g., `defaults(sys)` if `eqs = equations(sys)`.
- `checkbounds`: Apply bounds checking in the generated code.
"""
function gen_nlsolve(eqs, vars, u0map::AbstractDict; checkbounds=true)
isempty(vars) && throw(ArgumentError("vars may not be empty"))
length(eqs) == length(vars) || throw(ArgumentError("vars must be of the same length as the number of equations to find the roots of"))
rhss = map(x->x.rhs, eqs)
# We use `vars` instead of `graph` to capture parameters, too.
allvars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))
params = setdiff(allvars, vars)
params = setdiff(allvars, vars) # these are not the subject of the root finding

u0map = defaults(sys)
# splatting to tighten the type
u0 = [map(var->get(u0map, var, 1e-3), vars)...]
# specialize on the scalar case
isscalar = length(u0) == 1
u0 = isscalar ? u0[1] : SVector(u0...)

fname = gensym("fun")
# f is the function to find roots on
f = Func(
[
DestructuredArgs(vars, inbounds=!checkbounds)
Expand All @@ -150,6 +167,7 @@ function gen_nlsolve(sys, eqs, vars; checkbounds=true)
isscalar ? rhss[1] : MakeArray(rhss, SVector)
) |> SymbolicUtils.Code.toexpr

# solver call contains code to call the root-finding solver on the function f
solver_call = LiteralExpr(quote
$numerical_nlsolve(
$fname,
Expand All @@ -174,8 +192,9 @@ function get_torn_eqs_vars(sys; checkbounds=true)

torn_eqs = map(idxs-> eqs[idxs], map(x->x.e_residual, partitions))
torn_vars = map(idxs->vars[idxs], map(x->x.v_residual, partitions))
u0map = defaults(sys)

gen_nlsolve.((sys,), torn_eqs, torn_vars, checkbounds=checkbounds)
gen_nlsolve.(torn_eqs, torn_vars, (u0map,), checkbounds=checkbounds)
end

function build_torn_function(
Expand Down Expand Up @@ -308,8 +327,8 @@ function build_observed_function(

torn_eqs = map(idxs-> eqs[idxs.e_residual], subset)
torn_vars = map(idxs->fullvars[idxs.v_residual], subset)

solves = gen_nlsolve.((sys,), torn_eqs, torn_vars; checkbounds=checkbounds)
u0map = defaults(sys)
solves = gen_nlsolve.(torn_eqs, torn_vars, (u0map,); checkbounds=checkbounds)
else
solves = []
end
Expand Down
49 changes: 47 additions & 2 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,40 @@ independent_variables(sys::AbstractTimeDependentSystem) = [getfield(sys, :iv)]
independent_variables(sys::AbstractTimeIndependentSystem) = []
independent_variables(sys::AbstractMultivariateSystem) = getfield(sys, :ivs)

const NULL_AFFECT = Equation[]
struct SymbolicContinuousCallback
eqs::Vector{Equation}
affect::Vector{Equation}
SymbolicContinuousCallback(eqs::Vector{Equation}, affect=NULL_AFFECT) = new(eqs, affect) # Default affect to nothing
end

Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback) = isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect)

to_equation_vector(eq::Equation) = [eq]
to_equation_vector(eqs::Vector{Equation}) = eqs
function to_equation_vector(eqs::Vector{Any})
isempty(eqs) || error("This should never happen")
Equation[]
end

SymbolicContinuousCallback(args...) = SymbolicContinuousCallback(to_equation_vector.(args)...) # wrap eq in vector
SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2])
SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough

SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb]
SymbolicContinuousCallbacks(cbs::Vector{<:SymbolicContinuousCallback}) = cbs
SymbolicContinuousCallbacks(cbs::Vector) = SymbolicContinuousCallback.(cbs)
SymbolicContinuousCallbacks(ve::Vector{Equation}) = SymbolicContinuousCallbacks(SymbolicContinuousCallback(ve))
SymbolicContinuousCallbacks(others) = SymbolicContinuousCallbacks(SymbolicContinuousCallback(others))
SymbolicContinuousCallbacks(::Nothing) = SymbolicContinuousCallbacks(Equation[])

equations(cb::SymbolicContinuousCallback) = cb.eqs
equations(cbs::Vector{<:SymbolicContinuousCallback}) = reduce(vcat, [equations(cb) for cb in cbs])
affect_equations(cb::SymbolicContinuousCallback) = cb.affect
affect_equations(cbs::Vector{SymbolicContinuousCallback}) = reduce(vcat, [affect_equations(cb) for cb in cbs])
namespace_equation(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback = SymbolicContinuousCallback(namespace_equation.(equations(cb), (s, )), namespace_equation.(affect_equations(cb), (s, )))


function structure(sys::AbstractSystem)
s = get_structure(sys)
s isa SystemStructure || throw(ArgumentError("SystemStructure is not yet initialized, please run `sys = initialize_system_structure(sys)` or `sys = alias_elimination(sys)`."))
Expand Down Expand Up @@ -415,6 +449,15 @@ function observed(sys::AbstractSystem)
init=Equation[])]
end

function continuous_events(sys::AbstractSystem)
obs = get_continuous_events(sys)
systems = get_systems(sys)
[obs;
reduce(vcat,
(map(o->namespace_equation(o, s), continuous_events(s)) for s in systems),
init=SymbolicContinuousCallback[])]
end

Base.@deprecate default_u0(x) defaults(x) false
Base.@deprecate default_p(x) defaults(x) false
function defaults(sys::AbstractSystem)
Expand Down Expand Up @@ -941,6 +984,7 @@ function Base.hash(sys::AbstractSystem, s::UInt)
s = foldr(hash, get_eqs(sys), init=s)
end
s = foldr(hash, get_observed(sys), init=s)
s = foldr(hash, get_continuous_events(sys), init=s)
s = hash(independent_variables(sys), s)
return s
end
Expand Down Expand Up @@ -968,13 +1012,14 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol=nameo
sts = union(get_states(basesys), get_states(sys))
ps = union(get_ps(basesys), get_ps(sys))
obs = union(get_observed(basesys), get_observed(sys))
evs = union(get_continuous_events(basesys), get_continuous_events(sys))
defs = merge(get_defaults(basesys), get_defaults(sys)) # prefer `sys`
syss = union(get_systems(basesys), get_systems(sys))

if length(ivs) == 0
T(eqs, sts, ps, observed = obs, defaults = defs, name=name, systems = syss)
T(eqs, sts, ps, observed = obs, defaults = defs, name=name, systems = syss, continuous_events=evs)
elseif length(ivs) == 1
T(eqs, ivs[1], sts, ps, observed = obs, defaults = defs, name = name, systems = syss)
T(eqs, ivs[1], sts, ps, observed = obs, defaults = defs, name = name, systems = syss, continuous_events=evs)
end
end

Expand Down
120 changes: 116 additions & 4 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,105 @@ function generate_difference_cb(sys::ODESystem, dvs = states(sys), ps = paramete
PeriodicCallback(cb_affect!, first(dt))
end

function generate_rootfinding_callback(sys::ODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
cbs = continuous_events(sys)
isempty(cbs) && return nothing
generate_rootfinding_callback(cbs, sys, dvs, ps; kwargs...)
end

function generate_rootfinding_callback(cbs, sys::ODESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
eqs = map(cb->cb.eqs, cbs)
num_eqs = length.(eqs)
(isempty(eqs) || sum(num_eqs) == 0) && return nothing
# fuse equations to create VectorContinuousCallback
eqs = reduce(vcat, eqs)
# rewrite all equations as 0 ~ interesting stuff
eqs = map(eqs) do eq
isequal(eq.lhs, 0) && return eq
0 ~ eq.lhs - eq.rhs
end

rhss = map(x->x.rhs, eqs)
root_eq_vars = unique(collect(Iterators.flatten(map(ModelingToolkit.vars, rhss))))

u = map(x->time_varying_as_func(value(x), sys), dvs)
p = map(x->time_varying_as_func(value(x), sys), ps)
t = get_iv(sys)
rf_oop, rf_ip = build_function(rhss, u, p, t; expression=Val{false}, kwargs...)

affect_functions = map(cbs) do cb # Keep affect function separate
eq_aff = affect_equations(cb)
affect = compile_affect(eq_aff, sys, dvs, ps; kwargs...)
end

if length(eqs) == 1
cond = function(u, t, integ)
if DiffEqBase.isinplace(integ.sol.prob)
tmp, = DiffEqBase.get_tmp_cache(integ)
rf_ip(tmp, u, integ.p, t)
tmp[1]
else
rf_oop(u, integ.p, t)
end
end
ContinuousCallback(cond, affect_functions[])
else
cond = function(out, u, t, integ)
rf_ip(out, u, integ.p, t)
end

# since there may be different number of conditions and affects,
# we build a map that translates the condition eq. number to the affect number
eq_ind2affect = reduce(vcat, [fill(i, num_eqs[i]) for i in eachindex(affect_functions)])
@assert length(eq_ind2affect) == length(eqs)
@assert maximum(eq_ind2affect) == length(affect_functions)

affect = let affect_functions=affect_functions, eq_ind2affect=eq_ind2affect
function(integ, eq_ind) # eq_ind refers to the equation index that triggered the event, each event has num_eqs[i] equations
affect_functions[eq_ind2affect[eq_ind]](integ)
end
end
VectorContinuousCallback(cond, affect, length(eqs))
end
end

compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...) = compile_affect(affect_equations(cb), args...; kwargs...)

"""
compile_affect(eqs::Vector{Equation}, sys, dvs, ps; kwargs...)
compile_affect(cb::SymbolicContinuousCallback, args...; kwargs...)

Returns a function that takes an integrator as argument and modifies the state with the affect.
"""
function compile_affect(eqs::Vector{Equation}, sys, dvs, ps; kwargs...)
if isempty(eqs)
return (args...) -> () # We don't do anything in the callback, we're just after the event
else
rhss = map(x->x.rhs, eqs)
lhss = map(x->x.lhs, eqs)
update_vars = collect(Iterators.flatten(map(ModelingToolkit.vars, lhss))) # these are the ones we're chaning
length(update_vars) == length(unique(update_vars)) == length(eqs) ||
error("affected variables not unique, each state can only be affected by one equation for a single `root_eqs => affects` pair.")
vars = states(sys)

u = map(x->time_varying_as_func(value(x), sys), vars)
p = map(x->time_varying_as_func(value(x), sys), ps)
t = get_iv(sys)
rf_oop, rf_ip = build_function(rhss, u, p, t; expression=Val{false}, kwargs...)

stateind(sym) = findfirst(isequal(sym),vars)

update_inds = stateind.(update_vars)
let update_inds=update_inds
function(integ)
lhs = @views integ.u[update_inds]
rf_ip(lhs, integ.u, integ.p, integ.t)
end
end
end
end


function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
# if something is not x(t) (the current state)
# but is `x(t-1)` or something like that, pass in `x` as a callable function rather
Expand Down Expand Up @@ -552,15 +651,28 @@ Generates an ODEProblem from an ODESystem and allows for automatically
symbolically calculating numerical enhancements.
"""
function DiffEqBase.ODEProblem{iip}(sys::AbstractODESystem,u0map,tspan,
parammap=DiffEqBase.NullParameters();kwargs...) where iip
parammap=DiffEqBase.NullParameters(); callback=nothing, kwargs...) where iip
has_difference = any(isdifferenceeq, equations(sys))
f, u0, p = process_DEProblem(ODEFunction{iip}, sys, u0map, parammap; has_difference=has_difference, kwargs...)
if has_difference
ODEProblem{iip}(f,u0,tspan,p;difference_cb=generate_difference_cb(sys;kwargs...),kwargs...)
if has_continuous_events(sys)
event_cb = generate_rootfinding_callback(sys; kwargs...)
else
event_cb = nothing
end
difference_cb = has_difference ? generate_difference_cb(sys; kwargs...) : nothing
cb = merge_cb(event_cb, difference_cb)
cb = merge_cb(cb, callback)

if cb === nothing
ODEProblem{iip}(f, u0, tspan, p; kwargs...)
else
ODEProblem{iip}(f,u0,tspan,p;kwargs...)
ODEProblem{iip}(f, u0, tspan, p; callback=cb, kwargs...)
end
end
merge_cb(::Nothing, ::Nothing) = nothing
merge_cb(::Nothing, x) = merge_cb(x, nothing)
merge_cb(x, ::Nothing) = x
merge_cb(x, y) = CallbackSet(x, y)

"""
```julia
Expand Down
Loading