Skip to content
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "3.1.1"
[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqJump = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12"
DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down
6 changes: 5 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using StaticArrays, LinearAlgebra, SparseArrays
using Latexify, Unitful, ArrayInterface
using MacroTools
using UnPack: @unpack
using DiffEqJump

using Base.Threads
import MacroTools: splitdef, combinedef, postwalk, striplines
Expand Down Expand Up @@ -86,6 +87,8 @@ include("systems/diffeqs/first_order_transform.jl")
include("systems/diffeqs/modelingtoolkitize.jl")
include("systems/diffeqs/validation.jl")

include("systems/jumps/jumpsystem.jl")

include("systems/nonlinear/nonlinearsystem.jl")

include("systems/optimization/optimizationsystem.jl")
Expand All @@ -99,7 +102,8 @@ include("build_function.jl")

export ODESystem, ODEFunction
export SDESystem, SDEFunction
export ODEProblem, SDEProblem, NonlinearProblem, OptimizationProblem
export JumpSystem
export ODEProblem, SDEProblem, NonlinearProblem, OptimizationProblem, JumpProblem
export NonlinearSystem, OptimizationSystem
export ode_order_lowering
export PDESystem
Expand Down
88 changes: 59 additions & 29 deletions src/build_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,50 @@ function build_function(args...;target = JuliaTarget(),kwargs...)
_build_function(target,args...;kwargs...)
end

function addheader(ex, fargs, iip; X=gensym(:MTIIPVar))
if iip
wrappedex = :(
($X,$(fargs.args...)) -> begin
$ex
nothing
end
)
else
wrappedex = :(
($(fargs.args...),) -> begin
$ex
end
)
end
wrappedex
end

function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar))
integrator = gensym(:MTKIntegrator)
if iip
wrappedex = :(
$integrator -> begin
($X,$(fargs.args...)) = (($integrator).u,($integrator).u,($integrator).p,($integrator).t)
$ex
nothing
end
)
else
wrappedex = :(
$integrator -> begin
($(fargs.args...),) = (($integrator).u,($integrator).p,($integrator).t)
$ex
end
)
end
wrappedex
end

# Scalar output
function _build_function(target::JuliaTarget, op::Operation, args...;
conv = simplified_expr, expression = Val{true},
checkbounds = false, constructor=nothing,
linenumbers = true)
linenumbers = true, headerfun=addheader)

argnames = [gensym(:MTKArg) for i in 1:length(args)]
arg_pairs = map(vars_to_pairs,zip(argnames,args))
Expand All @@ -74,13 +113,8 @@ function _build_function(target::JuliaTarget, op::Operation, args...;
bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end)

fargs = Expr(:tuple,argnames...)

oop_ex = :(
($(fargs.args...),) -> begin
$bounds_block
end
)

oop_ex = headerfun(bounds_block, fargs, false)

if !linenumbers
oop_ex = striplines(oop_ex)
end
Expand All @@ -95,8 +129,8 @@ end
function _build_function(target::JuliaTarget, rhss, args...;
conv = simplified_expr, expression = Val{true},
checkbounds = false, constructor=nothing,
linenumbers = false, multithread=false)

linenumbers = false, multithread=false,
headerfun=addheader, outputidxs=nothing)
argnames = [gensym(:MTKArg) for i in 1:length(args)]
arg_pairs = map(vars_to_pairs,zip(argnames,args))
ls = reduce(vcat,first.(arg_pairs))
Expand All @@ -106,6 +140,8 @@ function _build_function(target::JuliaTarget, rhss, args...;
fname = gensym(:ModelingToolkitFunction)
fargs = Expr(:tuple,argnames...)


oidx = isnothing(outputidxs) ? (i -> i) : (i -> outputidxs[i])
Copy link
Member Author

@isaacsas isaacsas May 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ChrisRackauckas With this jumps seem to be working now. If this design looks ok to you I can add tests, and if you want add the oidx call for the other equation types of rhss I can add them too (right now it is only applied for the default rhss case).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably fine to just have it here.

X = gensym(:MTIIPVar)
if eltype(eltype(rhss)) <: AbstractArray # Array of arrays of arrays
ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) ∈ enumerate(rhsel2)]) for (j, rhsel2) ∈ enumerate(rhsel)],init=Expr[])) for (i,rhsel) ∈ enumerate(rhss)],init=Expr[])
Expand All @@ -118,7 +154,7 @@ function _build_function(target::JuliaTarget, rhss, args...;
elseif rhss isa SparseMatrixCSC
ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) ∈ enumerate(rhss.nzval)]
else
ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) ∈ enumerate(rhss)]
ip_sys_exprs = [:($X[$(oidx(i))] = $(conv(rhs))) for (i, rhs) ∈ enumerate(rhss)]
end

ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs))
Expand Down Expand Up @@ -165,26 +201,20 @@ function _build_function(target::JuliaTarget, rhss, args...;
arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end)
ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end)

oop_ex = :(
($(fargs.args...),) -> begin
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC))
return $arr_bounds_block
else
X = $bounds_block
construct = $_constructor
return construct(X)
end
end
)

iip_ex = :(
($X,$(fargs.args...)) -> begin
$ip_bounds_block
nothing
end
oop_body_block = :(
# If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways
if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC))
return $arr_bounds_block
else
X = $bounds_block
construct = $_constructor
return construct(X)
end
)

oop_ex = headerfun(oop_body_block, fargs, false)
iip_ex = headerfun(ip_bounds_block, fargs, true; X=X)

if !linenumbers
oop_ex = striplines(oop_ex)
iip_ex = striplines(iip_ex)
Expand Down
68 changes: 68 additions & 0 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump}

struct JumpSystem <: AbstractSystem
eqs::Vector{JumpType}
iv::Variable
states::Vector{Variable}
ps::Vector{Variable}
name::Symbol
systems::Vector{JumpSystem}
end

function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[],
name = gensym(:JumpSystem))
JumpSystem(eqs, iv, convert.(Variable, states), convert.(Variable, ps), name, systems)
end



generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js),
independent_variable(js),
expression=Val{false})

generate_affect_function(js, affect, outputidxs) = build_function(affect, states(js),
parameters(js),
independent_variable(js),
expression=Val{false},
headerfun=add_integrator_header,
outputidxs=outputidxs)[2]
function assemble_vrj(js, vrj, statetoid)
rate = generate_rate_function(js, vrj.rate)
outputvars = (convert(Variable,affect.lhs) for affect in vrj.affect!)
outputidxs = ((statetoid[var] for var in outputvars)...,)
affect = generate_affect_function(js, vrj.affect!, outputidxs)
VariableRateJump(rate, affect)
end

function assemble_crj(js, crj, statetoid)
rate = generate_rate_function(js, crj.rate)
outputvars = (convert(Variable,affect.lhs) for affect in crj.affect!)
outputidxs = ((statetoid[var] for var in outputvars)...,)
affect = generate_affect_function(js, crj.affect!, outputidxs)
ConstantRateJump(rate, affect)
end

"""
```julia
function DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
```

Generates a JumpProblem from a JumpSystem.
"""
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...)
vrjs = Vector{VariableRateJump}()
crjs = Vector{ConstantRateJump}()
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
for j in equations(js)
if j isa ConstantRateJump
push!(crjs, assemble_crj(js, j, statetoid))
elseif j isa VariableRateJump
push!(vrjs, assemble_vrj(js, j, statetoid))
else
(j isa MassActionJump) && error("Generation of JumpProblems with MassActionJumps is not yet supported.")
end
end
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps")
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, nothing)
JumpProblem(prob, aggregator, jset)
end
106 changes: 106 additions & 0 deletions test/jumpsystem.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
using ModelingToolkit, DiffEqBase, DiffEqJump, Test, LinearAlgebra
MT = ModelingToolkit

# basic MT SIR model with tweaks
@parameters β γ t
@variables S I R
rate₁ = β*S*I
affect₁ = [S ~ S - 1, I ~ I + 1]
rate₂ = γ*I+t
affect₂ = [I ~ I - 1, R ~ R + 1]
j₁ = ConstantRateJump(rate₁,affect₁)
j₂ = VariableRateJump(rate₂,affect₂)
js = JumpSystem([j₁,j₂], t, [S,I,R], [β,γ])
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js)))
mtjump1 = MT.assemble_crj(js, j₁, statetoid)
mtjump2 = MT.assemble_vrj(js, j₂, statetoid)

# doc version
rate1(u,p,t) = (0.1/1000.0)*u[1]*u[2]
function affect1!(integrator)
integrator.u[1] -= 1
integrator.u[2] += 1
end
jump1 = ConstantRateJump(rate1,affect1!)
rate2(u,p,t) = 0.01u[2]+t
function affect2!(integrator)
integrator.u[2] -= 1
integrator.u[3] += 1
end
jump2 = VariableRateJump(rate2,affect2!)

# test crjs
u = [100, 9, 5]
p = (0.1/1000,0.01)
tf = 1.0
mutable struct TestInt{U,V,T}
u::U
p::V
t::T
end
mtintegrator = TestInt(u,p,tf)
integrator = TestInt(u,p,tf)
@test abs(mtjump1.rate(u,p,tf) - jump1.rate(u,p,tf)) < 10*eps()
@test abs(mtjump2.rate(u,p,tf) - jump2.rate(u,p,tf)) < 10*eps()
mtjump1.affect!(mtintegrator)
jump1.affect!(integrator)
@test all(integrator.u .== mtintegrator.u)
mtintegrator.u .= u; integrator.u .= u
mtjump2.affect!(mtintegrator)
jump2.affect!(integrator)
@test all(integrator.u .== mtintegrator.u)

# test MT can make and solve a jump problem
rate₃ = γ*I
affect₃ = [I ~ I - 1, R ~ R + 1]
j₃ = ConstantRateJump(rate₃,affect₃)
js2 = JumpSystem([j₁,j₃], t, [S,I,R], [β,γ])
u₀ = [999,1,0]; p = (0.1/1000,0.01); tspan = (0.,250.)
dprob = DiscreteProblem(u₀,tspan,p)
jprob = JumpProblem(js2, dprob, Direct(), save_positions=(false,false))
Nsims = 10000
function getmean(jprob,Nsims)
m = 0.0
for i = 1:Nsims
sol = solve(jprob, SSAStepper())
m += sol[end,end]
end
m/Nsims
end
m = getmean(jprob,Nsims)

#test the MT JumpProblem rates/affects are correct
rate2(u,p,t) = 0.01u[2]
jump2 = ConstantRateJump(rate2,affect2!)
mtjumps = jprob.discrete_jump_aggregation
@test abs(mtjumps.rates[1](u,p,tf) - jump1.rate(u,p,tf)) < 10*eps()
@test abs(mtjumps.rates[2](u,p,tf) - jump2.rate(u,p,tf)) < 10*eps()
mtjumps.affects![1](mtintegrator)
jump1.affect!(integrator)
@test all(integrator.u .== mtintegrator.u)
mtintegrator.u .= u; integrator.u .= u
mtjumps.affects![2](mtintegrator)
jump2.affect!(integrator)
@test all(integrator.u .== mtintegrator.u)

# direct vers
p = (0.1/1000,0.01)
prob = DiscreteProblem([999,1,0],(0.0,250.0),p)
r1(u,p,t) = (0.1/1000.0)*u[1]*u[2]
function a1!(integrator)
integrator.u[1] -= 1
integrator.u[2] += 1
end
j1 = ConstantRateJump(r1,a1!)
r2(u,p,t) = 0.01u[2]
function a2!(integrator)
integrator.u[2] -= 1
integrator.u[3] += 1
end
j2 = ConstantRateJump(r2,a2!)
jset = JumpSet((),(j1,j2),nothing,nothing)
jprob = JumpProblem(prob,Direct(),jset, save_positions=(false,false))
m2 = getmean(jprob,Nsims)

# test JumpSystem solution agrees with direct version
@test abs(m-m2) ./ m < .01