Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
250b903
Merge remote-tracking branch 'origin/master' into depgraphs
isaacsas May 9, 2020
562cdd1
add dep graphs for jumps
isaacsas May 9, 2020
9a60530
updates
isaacsas May 10, 2020
8b9246f
add dep graphs for jumps
isaacsas May 9, 2020
e87ade5
updates
isaacsas May 10, 2020
6a5745e
update to master
isaacsas May 12, 2020
891a29c
Merge remote-tracking branch 'myrepo/depgraphs' into depgraphs
isaacsas May 12, 2020
a81446e
store js eqs as arraypartition
isaacsas May 12, 2020
5bbda08
pass syms to DiscreteProblem for plotting
isaacsas May 12, 2020
c14933a
start depgraphs for AbstractSys
isaacsas May 13, 2020
d6ab464
eq to var dependencies
isaacsas May 14, 2020
80af7c0
BiPartite -> Bipartite
isaacsas May 14, 2020
9f7384d
rename and cleanup
isaacsas May 14, 2020
d855b1b
modified state dep graph 1
isaacsas May 14, 2020
1768e27
add variable dependencies
isaacsas May 14, 2020
909c6a8
add conversion to LightGraph.SimpleDiGraph
isaacsas May 14, 2020
a0d22c2
cleanup digraph generation
isaacsas May 14, 2020
16ae064
finish off dependency graphs
isaacsas May 15, 2020
ca9b78f
simplify varvar_deps
isaacsas May 15, 2020
1eaa06f
cleanup
isaacsas May 15, 2020
0ddb5fb
fix substitute_expr! deprec warning
isaacsas May 15, 2020
d8373f0
add JumpSystem dependency graph tests
isaacsas May 15, 2020
4fc4ad8
typo fix
isaacsas May 15, 2020
2d8d0c0
fix runtests
isaacsas May 15, 2020
08e76dd
make depgraphs for jumps
isaacsas May 15, 2020
b3c69a9
Merge branch 'master' of https://github.com/SciML/ModelingToolkit.jl …
isaacsas May 15, 2020
eb5089b
remove arraypartition slicing in ReactionSystem tests
isaacsas May 15, 2020
1727b3b
reenable type test
isaacsas May 15, 2020
b16a989
remove crashing SSAs in tests
isaacsas May 15, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
GeneralizedGenerated = "6b9d7cbe-bcb9-11e9-073f-15a7a543e2eb"
Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
Expand Down
9 changes: 8 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ using RecursiveArrayTools
import SymbolicUtils
import SymbolicUtils: to_symbolic, FnType

import LightGraphs: SimpleDiGraph, add_edge!

import TreeViews

"""
Expand Down Expand Up @@ -101,6 +103,7 @@ include("systems/optimization/optimizationsystem.jl")
include("systems/pde/pdesystem.jl")

include("systems/reaction/reactionsystem.jl")
include("systems/dependency_graphs.jl")

include("latexify_recipes.jl")
include("build_function.jl")
Expand All @@ -118,7 +121,7 @@ export Differential, expand_derivatives, @derivatives
export IntervalDomain, ProductDomain, ⊗, CircleDomain
export Equation, ConstrainedEquation
export Operation, Expression, Variable
export independent_variable, states, parameters, equations
export independent_variable, states, parameters, equations

export calculate_jacobian, generate_jacobian, generate_function
export calculate_tgrad, generate_tgrad
Expand All @@ -127,6 +130,10 @@ export calculate_factorized_W, generate_factorized_W
export calculate_hessian, generate_hessian
export calculate_massmatrix, generate_diffusion_function

export BipartiteGraph, equation_dependencies, variable_dependencies
export eqeq_dependencies, varvar_dependencies
export asgraph, asdigraph

export simplified_expr, rename, get_variables
export simplify, substitute
export build_function
Expand Down
114 changes: 114 additions & 0 deletions src/systems/dependency_graphs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# variables equations depend on as a vector of vectors of variables
# each system type should define extract_variables! for a single equation
function equation_dependencies(sys::AbstractSystem; variables=states(sys))
eqs = equations(sys)
deps = Set{Variable}()
depeqs_to_vars = Vector{Vector{Variable}}(undef,length(eqs))

for (i,eq) in enumerate(eqs)
depeqs_to_vars[i] = collect(get_variables!(deps, eq, variables))
empty!(deps)
end

depeqs_to_vars
end

# modeled on LightGraphs SimpleGraph
mutable struct BipartiteGraph{T <: Integer}
ne::Int
fadjlist::Vector{Vector{T}} # fadjlist[src] = [dest1,dest2,...]
badjlist::Vector{Vector{T}} # badjlist[dst] = [src1,src2,...]
end

# convert equation-variable dependencies to a bipartite graph
function asgraph(eqdeps, vtois)
fadjlist = Vector{Vector{Int}}(undef, length(eqdeps))
for (i,dep) in enumerate(eqdeps)
fadjlist[i] = sort!([vtois[var] for var in dep])
end

badjlist = [Vector{Int}() for i = 1:length(vtois)]
ne = 0
for (eqidx,vidxs) in enumerate(fadjlist)
foreach(vidx -> push!(badjlist[vidx], eqidx), vidxs)
ne += length(vidxs)
end

BipartiteGraph(ne, fadjlist, badjlist)
end

function Base.isequal(bg1::BipartiteGraph{T}, bg2::BipartiteGraph{T}) where {T<:Integer}
iseq = (bg1.ne == bg2.ne)
iseq &= (bg1.fadjlist == bg2.fadjlist)
iseq &= (bg1.badjlist == bg2.badjlist)
iseq
end

# could be made to directly generate graph and save memory
function asgraph(sys::AbstractSystem; variables=nothing, variablestoids=nothing)
vs = isnothing(variables) ? states(sys) : variables
eqdeps = equation_dependencies(sys, variables=vs)
vtois = isnothing(variablestoids) ? Dict(convert(Variable, v) => i for (i,v) in enumerate(vs)) : variablestoids
asgraph(eqdeps, vtois)
end

# for each variable determine the equations that modify it
function variable_dependencies(sys::AbstractSystem; variables=states(sys), variablestoids=nothing)
eqs = equations(sys)
vtois = isnothing(variablestoids) ? Dict(convert(Variable, v) => i for (i,v) in enumerate(variables)) : variablestoids

deps = Set{Variable}()
badjlist = Vector{Vector{Int}}(undef, length(eqs))
for (eidx,eq) in enumerate(eqs)
modified_states!(deps, eq, variables)
badjlist[eidx] = sort!([vtois[var] for var in deps])
empty!(deps)
end

fadjlist = [Vector{Int}() for i = 1:length(variables)]
ne = 0
for (eqidx,vidxs) in enumerate(badjlist)
foreach(vidx -> push!(fadjlist[vidx], eqidx), vidxs)
ne += length(vidxs)
end

BipartiteGraph(ne, fadjlist, badjlist)
end

# convert BipartiteGraph to LightGraph.SimpleDiGraph
function asdigraph(g::BipartiteGraph, sys::AbstractSystem; variables = states(sys), equationsfirst = true)
neqs = length(equations(sys))
nvars = length(variables)
fadjlist = deepcopy(g.fadjlist)
badjlist = deepcopy(g.badjlist)

# offset is for determining indices for the second set of vertices
offset = equationsfirst ? neqs : nvars
for i = 1:offset
fadjlist[i] .+= offset
end

# add empty rows for vertices without connections
append!(fadjlist, [Vector{Int}() for i=1:(equationsfirst ? nvars : neqs)])
prepend!(badjlist, [Vector{Int}() for i=1:(equationsfirst ? neqs : nvars)])

SimpleDiGraph(g.ne, fadjlist, badjlist)
end

# maps the i'th eq to equations that depend on it
function eqeq_dependencies(eqdeps::BipartiteGraph{T}, vardeps::BipartiteGraph{T}) where {T <: Integer}
g = SimpleDiGraph{T}(length(eqdeps.fadjlist))

for (eqidx,sidxs) in enumerate(vardeps.badjlist)
# states modified by eqidx
for sidx in sidxs
# equations depending on sidx
foreach(v -> add_edge!(g, eqidx, v), eqdeps.badjlist[sidx])
end
end

g
end

# maps the i'th variable to variables that depend on it
varvar_dependencies(eqdeps::BipartiteGraph{T}, vardeps::BipartiteGraph{T}) where {T <: Integer} = eqeq_dependencies(vardeps, eqdeps)
107 changes: 82 additions & 25 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump}

struct JumpSystem <: AbstractSystem
eqs::Vector{JumpType}
struct JumpSystem{U <: ArrayPartition} <: AbstractSystem
eqs::U
iv::Variable
states::Vector{Variable}
ps::Vector{Variable}
Expand All @@ -11,9 +11,22 @@ end

function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
name = gensym(:JumpSystem))
JumpSystem(eqs, iv, convert.(Variable, states), convert.(Variable, ps), name, systems)
end

ap = ArrayPartition(MassActionJump[], ConstantRateJump[], VariableRateJump[])
for eq in eqs
if eq isa MassActionJump
push!(ap.x[1], eq)
elseif eq isa ConstantRateJump
push!(ap.x[2], eq)
elseif eq isa VariableRateJump
push!(ap.x[3], eq)
else
error("JumpSystem equations must contain MassActionJumps, ConstantRateJumps, or VariableRateJumps.")
end
end

JumpSystem{typeof(ap)}(ap, convert(Variable,iv), convert.(Variable, states), convert.(Variable, ps), name, systems)
end


generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js),
Expand All @@ -26,6 +39,7 @@ generate_affect_function(js, affect, outputidxs) = build_function(affect, states
expression=Val{false},
headerfun=add_integrator_header,
outputidxs=outputidxs)[2]

function assemble_vrj(js, vrj, statetoid)
rate = generate_rate_function(js, vrj.rate)
outputvars = (convert(Variable,affect.lhs) for affect in vrj.affect!)
Expand Down Expand Up @@ -84,10 +98,13 @@ Generates a DiscreteProblem from an AbstractSystem
function DiffEqBase.DiscreteProblem(sys::AbstractSystem, u0map, tspan::Tuple,
parammap=DiffEqBase.NullParameters(); kwargs...)
u0 = varmap_to_vars(u0map, states(sys))
p = varmap_to_vars(parammap, parameters(sys))
DiscreteProblem(u0, tspan, p; kwargs...)
p = varmap_to_vars(parammap, parameters(sys))
f = (du,u,p,t) -> du.=u # identity function to make syms works
df = DiscreteFunction(f, syms=Symbol.(states(sys)))
DiscreteProblem(df, u0, tspan, p; kwargs...)
end


"""
```julia
function DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
Expand All @@ -96,25 +113,65 @@ function DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
Generates a JumpProblem from a JumpSystem.
"""
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
vrjs = Vector{VariableRateJump}()
crjs = Vector{ConstantRateJump}()
majs = Vector{MassActionJump}()
pvars = parameters(js)

statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
parammap = map((x,y)->Pair(x(),y),pvars,prob.p)

for j in equations(js)
if j isa ConstantRateJump
push!(crjs, assemble_crj(js, j, statetoid))
elseif j isa VariableRateJump
push!(vrjs, assemble_vrj(js, j, statetoid))
elseif j isa MassActionJump
push!(majs, assemble_maj(js, j, statetoid, parammap))
else
error("JumpSystems should only contain Constant, Variable or Mass Action Jumps.")
end
end
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
parammap = map((x,y)->Pair(x(),y), parameters(js), prob.p)
eqs = equations(js)

majs = MassActionJump[assemble_maj(js, j, statetoid, parammap) for j in eqs.x[1]]
crjs = ConstantRateJump[assemble_crj(js, j, statetoid) for j in eqs.x[2]]
vrjs = VariableRateJump[assemble_vrj(js, j, statetoid) for j in eqs.x[3]]
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, isempty(majs) ? nothing : majs)
JumpProblem(prob, aggregator, jset)

if needs_vartojumps_map(aggregator) || needs_depgraph(aggregator)
jdeps = asgraph(js)
vdeps = variable_dependencies(js)
vtoj = jdeps.badjlist
jtov = vdeps.badjlist
jtoj = needs_depgraph(aggregator) ? eqeq_dependencies(jdeps, vdeps).fadjlist : nothing
else
vtoj = nothing; jtov = nothing; jtoj = nothing
end

JumpProblem(prob, aggregator, jset; dep_graph=jtoj, vartojumps_map=vtoj, jumptovars_map=jtov)
end


### Functions to determine which states a jump depends on
function get_variables!(dep, jump::Union{ConstantRateJump,VariableRateJump}, variables)
foreach(var -> (var in variables) && push!(dep, var), vars(jump.rate))
dep
end

function get_variables!(dep, jump::MassActionJump, variables)
jsr = jump.scaled_rates

if jsr isa Variable
(jsr in variables) && push!(dep, jsr)
elseif jsr isa Operation
foreach(var -> (var in variables) && push!(dep, var), vars(jsr))
end

for varasop in jump.reactant_stoch
var = convert(Variable, varasop[1])
(var in variables) && push!(dep, var)
end

dep
end

### Functions to determine which states are modified by a given jump
function modified_states!(mstates, jump::Union{ConstantRateJump,VariableRateJump}, sts)
for eq in jump.affect!
st = convert(Variable, eq.lhs)
(st in sts) && push!(mstates, st)
end
end

function modified_states!(mstates, jump::MassActionJump, sts)
for (state,stoich) in jump.net_stoch
st = convert(Variable, state)
(st in sts) && push!(mstates, st)
end
end
2 changes: 1 addition & 1 deletion src/systems/reaction/reactionsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ function jumpratelaw(rx)
@unpack rate, substrates, substoich, only_use_rate = rx
rl = deepcopy(rate)
for op in get_variables(rx.rate)
rl = substitute_expr!(rl,op=>var2op(op.op))
rl = substitute(rl,op=>var2op(op.op))
end
if !only_use_rate
for (i,stoich) in enumerate(substoich)
Expand Down
74 changes: 74 additions & 0 deletions test/dep_graphs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using ModelingToolkit, LightGraphs

# use a ReactionSystem to generate systems for testing
@parameters k1 k2 t
@variables S(t) I(t) R(t)

rxs = [Reaction(k1, nothing, [S]),
Reaction(k1, [S], nothing),
Reaction(k2, [S,I], [I], [1,1], [2]),
Reaction(k2, [S,R], [S], [2,1], [2]),
Reaction(k1*I, nothing, [R]),
Reaction(k1*k2/(1+t), [S], [R])]
rs = ReactionSystem(rxs, t, [S,I,R], [k1,k2])


#################################
# testing for Jumps
#################################
js = convert(JumpSystem, rs)
S = convert(Variable,S); I = convert(Variable,I); R = convert(Variable,R)
k1 = convert(Variable,k1); k2 = convert(Variable,k2)

# eq to vars they depend on
eq_sdeps = [Variable[], [S], [S,I], [S,R], [I], [S]]
eq_sidepsf = [Int[], [1], [1,2], [1,3], [2], [1]]
eq_sidepsb = [[2,3,4,6], [3,5],[4]]
deps = equation_dependencies(js)
@test all(i -> isequal(Set(eq_sdeps[i]),Set(deps[i])), 1:length(rxs))
depsbg = asgraph(js)
@test depsbg.fadjlist == eq_sidepsf
@test depsbg.badjlist == eq_sidepsb

# eq to params they depend on
eq_pdeps = [[k1],[k1],[k2],[k2],[k1],[k1,k2]]
eq_pidepsf = [[1],[1],[2],[2],[1],[1,2]]
eq_pidepsb = [[1,2,5,6],[3,4,6]]
deps = equation_dependencies(js, variables=parameters(js))
@test all(i -> isequal(Set(eq_pdeps[i]),Set(deps[i])), 1:length(rxs))
depsbg2 = asgraph(js, variables=parameters(js))
@test depsbg2.fadjlist == eq_pidepsf
@test depsbg2.badjlist == eq_pidepsb

# var to eqs that modify them
s_eqdepsf = [[1,2,3,6],[3],[4,5,6]]
s_eqdepsb = [[1],[1],[1,2],[3],[3],[1,3]]
ne = 8
bg = BipartiteGraph(ne, s_eqdepsf, s_eqdepsb)
deps2 = variable_dependencies(js)
@test isequal(bg,deps2)

# eq to eqs that depend on them
eq_eqdeps = [[2,3,4,6],[2,3,4,6],[2,3,4,5,6],[4],[4],[2,3,4,6]]
dg = SimpleDiGraph(6)
for (eqidx,eqdeps) in enumerate(eq_eqdeps)
for eqdepidx in eqdeps
add_edge!(dg, eqidx, eqdepidx)
end
end
dg3 = eqeq_dependencies(depsbg,deps2)
@test dg == dg3

# var to vars that depend on them
var_vardeps = [[1,2,3],[1,2,3],[3]]
ne = 7
dg = SimpleDiGraph(3)
for (vidx,vdeps) in enumerate(var_vardeps)
for vdepidx in vdeps
add_edge!(dg, vidx, vdepidx)
end
end
dg4 = varvar_dependencies(depsbg,deps2)
@test dg == dg4


Loading