Skip to content

Commit

Permalink
Add fast_substitute
Browse files Browse the repository at this point in the history
Before:
```julia
julia> @time sysRed = tearing(sysEx);
  8.903631 seconds (42.38 M allocations: 2.968 GiB, 8.06% gc time)
```

After:
```julia
julia> @time tearing(sysEx);
  1.733097 seconds (10.90 M allocations: 1.059 GiB, 19.44% gc time)
```
  • Loading branch information
YingboMa committed Oct 19, 2022
1 parent 022008b commit c049454
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/structural_transformation/StructuralTransformations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ using ModelingToolkit: ODESystem, AbstractSystem, var_from_nested_derivative, Di
get_postprocess_fbody, vars!,
IncrementalCycleTracker, add_edge_checked!, topological_sort,
invalidate_cache!, Substitutions, get_or_construct_tearing_state,
AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL
AliasGraph, filter_kwargs, lower_varname, setio, SparseMatrixCLIL,
fast_substitute

using ModelingToolkit.BipartiteGraphs
import .BipartiteGraphs: invview, complete
Expand Down
8 changes: 4 additions & 4 deletions src/structural_transformation/symbolics_tearing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
idx_buffer = Int[]
sub_callback! = let eqs = neweqs, fullvars = fullvars
(ieq, s) -> begin
neweq = substitute(eqs[ieq], fullvars[s[1]] => fullvars[s[2]])
neweq = fast_substitute(eqs[ieq], fullvars[s[1]] => fullvars[s[2]])
eqs[ieq] = neweq
end
end
Expand Down Expand Up @@ -282,7 +282,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
end
for eq in 𝑑neighbors(graph, dv)
dummy_sub[dd] = v_t
neweqs[eq] = substitute(neweqs[eq], dd => v_t)
neweqs[eq] = fast_substitute(neweqs[eq], dd => v_t)
end
fullvars[dv] = v_t
# If we have:
Expand All @@ -295,7 +295,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
while (ddx = var_to_diff[dx]) !== nothing
dx_t = D(x_t)
for eq in 𝑑neighbors(graph, ddx)
neweqs[eq] = substitute(neweqs[eq], fullvars[ddx] => dx_t)
neweqs[eq] = fast_substitute(neweqs[eq], fullvars[ddx] => dx_t)
end
fullvars[ddx] = dx_t
dx = ddx
Expand Down Expand Up @@ -655,7 +655,7 @@ function tearing_reassemble(state::TearingState, var_eq_matching, ag = nothing;
obs_sub[eq.lhs] = eq.rhs
end
# TODO: compute the dependency correctly so that we don't have to do this
obs = substitute.([oldobs; subeqs], (obs_sub,))
obs = fast_substitute([oldobs; subeqs], obs_sub)
@set! sys.observed = obs
@set! state.sys = sys
@set! sys.tearing_state = state
Expand Down
3 changes: 1 addition & 2 deletions src/systems/alias_elimination.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,9 @@ function alias_elimination!(state::TearingState; kwargs...)
k === nothing && break
end
end
subfun = Base.Fix2(substitute, subs)
for ieq in eqs_to_update
eq = eqs[ieq]
eqs[ieq] = subfun(eq.lhs) ~ subfun(eq.rhs)
eqs[ieq] = fast_substitute(eq, subs)
end

for old_ieq in to_expand
Expand Down
33 changes: 33 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -741,3 +741,36 @@ function jacobian_wrt_vars(pf::F, p, input_idxs, chunk::C) where {F, C}
cfg = ForwardDiff.JacobianConfig(p_closure, p_small, chunk, tag)
ForwardDiff.jacobian(p_closure, p_small, cfg, Val(false))
end

# Symbolics needs to call unwrap on the substitution rules, but most of the time
# we don't want to do that in MTK.
function fast_substitute(eq::Equation, subs)
fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs)
end
function fast_substitute(eq::Equation, subs::Pair)
fast_substitute(eq.lhs, subs) ~ fast_substitute(eq.rhs, subs)
end
fast_substitute(eqs::AbstractArray{Equation}, subs) = fast_substitute.(eqs, (subs,))
fast_substitute(a, b) = substitute(a, b)
function fast_substitute(expr, pair::Pair)
a, b = pair
isequal(expr, a) && return b

istree(expr) || return expr
op = fast_substitute(operation(expr), pair)
canfold = Ref(!(op isa Symbolic))
args = let canfold = canfold
map(SymbolicUtils.unsorted_arguments(expr)) do x
x′ = fast_substitute(x, pair)
canfold[] = canfold[] && !(x′ isa Symbolic)
x′
end
end
canfold[] && return op(args...)

similarterm(expr,
op,
args,
symtype(expr);
metadata = metadata(expr))
end

0 comments on commit c049454

Please sign in to comment.