diff --git a/Project.toml b/Project.toml index 75c3df7518..eddbe228a7 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,7 @@ version = "3.1.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" +DiffEqJump = "c894b116-72e5-5b58-be3c-e6d8d4ac2b12" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index ea20a95073..03d6d432cd 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -5,6 +5,7 @@ using StaticArrays, LinearAlgebra, SparseArrays using Latexify, Unitful, ArrayInterface using MacroTools using UnPack: @unpack +using DiffEqJump using Base.Threads import MacroTools: splitdef, combinedef, postwalk, striplines @@ -86,6 +87,8 @@ include("systems/diffeqs/first_order_transform.jl") include("systems/diffeqs/modelingtoolkitize.jl") include("systems/diffeqs/validation.jl") +include("systems/jumps/jumpsystem.jl") + include("systems/nonlinear/nonlinearsystem.jl") include("systems/optimization/optimizationsystem.jl") @@ -99,7 +102,8 @@ include("build_function.jl") export ODESystem, ODEFunction export SDESystem, SDEFunction -export ODEProblem, SDEProblem, NonlinearProblem, OptimizationProblem +export JumpSystem +export ODEProblem, SDEProblem, NonlinearProblem, OptimizationProblem, JumpProblem export NonlinearSystem, OptimizationSystem export ode_order_lowering export PDESystem diff --git a/src/build_function.jl b/src/build_function.jl index a09c9e5cf1..90f1997bb3 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -56,11 +56,50 @@ function build_function(args...;target = JuliaTarget(),kwargs...) _build_function(target,args...;kwargs...) end +function addheader(ex, fargs, iip; X=gensym(:MTIIPVar)) + if iip + wrappedex = :( + ($X,$(fargs.args...)) -> begin + $ex + nothing + end + ) + else + wrappedex = :( + ($(fargs.args...),) -> begin + $ex + end + ) + end + wrappedex +end + +function add_integrator_header(ex, fargs, iip; X=gensym(:MTIIPVar)) + integrator = gensym(:MTKIntegrator) + if iip + wrappedex = :( + $integrator -> begin + ($X,$(fargs.args...)) = (($integrator).u,($integrator).u,($integrator).p,($integrator).t) + $ex + nothing + end + ) + else + wrappedex = :( + $integrator -> begin + ($(fargs.args...),) = (($integrator).u,($integrator).p,($integrator).t) + $ex + end + ) + end + wrappedex +end + # Scalar output function _build_function(target::JuliaTarget, op::Operation, args...; conv = simplified_expr, expression = Val{true}, checkbounds = false, constructor=nothing, - linenumbers = true) + linenumbers = true, headerfun=addheader) argnames = [gensym(:MTKArg) for i in 1:length(args)] arg_pairs = map(vars_to_pairs,zip(argnames,args)) @@ -74,13 +113,8 @@ function _build_function(target::JuliaTarget, op::Operation, args...; bounds_block = checkbounds ? let_expr : :(@inbounds begin $let_expr end) fargs = Expr(:tuple,argnames...) - - oop_ex = :( - ($(fargs.args...),) -> begin - $bounds_block - end - ) - + oop_ex = headerfun(bounds_block, fargs, false) + if !linenumbers oop_ex = striplines(oop_ex) end @@ -95,8 +129,8 @@ end function _build_function(target::JuliaTarget, rhss, args...; conv = simplified_expr, expression = Val{true}, checkbounds = false, constructor=nothing, - linenumbers = false, multithread=false) - + linenumbers = false, multithread=false, + headerfun=addheader, outputidxs=nothing) argnames = [gensym(:MTKArg) for i in 1:length(args)] arg_pairs = map(vars_to_pairs,zip(argnames,args)) ls = reduce(vcat,first.(arg_pairs)) @@ -106,6 +140,8 @@ function _build_function(target::JuliaTarget, rhss, args...; fname = gensym(:ModelingToolkitFunction) fargs = Expr(:tuple,argnames...) + + oidx = isnothing(outputidxs) ? (i -> i) : (i -> outputidxs[i]) X = gensym(:MTIIPVar) if eltype(eltype(rhss)) <: AbstractArray # Array of arrays of arrays ip_sys_exprs = reduce(vcat,[vec(reduce(vcat,[vec([:($X[$i][$j][$k] = $(conv(rhs))) for (k, rhs) ∈ enumerate(rhsel2)]) for (j, rhsel2) ∈ enumerate(rhsel)],init=Expr[])) for (i,rhsel) ∈ enumerate(rhss)],init=Expr[]) @@ -118,7 +154,7 @@ function _build_function(target::JuliaTarget, rhss, args...; elseif rhss isa SparseMatrixCSC ip_sys_exprs = [:($X.nzval[$i] = $(conv(rhs))) for (i, rhs) ∈ enumerate(rhss.nzval)] else - ip_sys_exprs = [:($X[$i] = $(conv(rhs))) for (i, rhs) ∈ enumerate(rhss)] + ip_sys_exprs = [:($X[$(oidx(i))] = $(conv(rhs))) for (i, rhs) ∈ enumerate(rhss)] end ip_let_expr = Expr(:let, var_eqs, build_expr(:block, ip_sys_exprs)) @@ -165,26 +201,20 @@ function _build_function(target::JuliaTarget, rhss, args...; arr_bounds_block = checkbounds ? arr_let_expr : :(@inbounds begin $arr_let_expr end) ip_bounds_block = checkbounds ? ip_let_expr : :(@inbounds begin $ip_let_expr end) - oop_ex = :( - ($(fargs.args...),) -> begin - # If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways - if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC)) - return $arr_bounds_block - else - X = $bounds_block - construct = $_constructor - return construct(X) - end - end - ) - - iip_ex = :( - ($X,$(fargs.args...)) -> begin - $ip_bounds_block - nothing - end + oop_body_block = :( + # If u is a weird non-StaticArray type and we want a sparse matrix, just do the optimized sparse anyways + if $(fargs.args[1]) isa Array || (!(typeof($(fargs.args[1])) <: StaticArray) && $(rhss isa SparseMatrixCSC)) + return $arr_bounds_block + else + X = $bounds_block + construct = $_constructor + return construct(X) + end ) + oop_ex = headerfun(oop_body_block, fargs, false) + iip_ex = headerfun(ip_bounds_block, fargs, true; X=X) + if !linenumbers oop_ex = striplines(oop_ex) iip_ex = striplines(iip_ex) diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl new file mode 100644 index 0000000000..1906fd2f15 --- /dev/null +++ b/src/systems/jumps/jumpsystem.jl @@ -0,0 +1,68 @@ +JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump} + +struct JumpSystem <: AbstractSystem + eqs::Vector{JumpType} + iv::Variable + states::Vector{Variable} + ps::Vector{Variable} + name::Symbol + systems::Vector{JumpSystem} +end + +function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[], + name = gensym(:JumpSystem)) + JumpSystem(eqs, iv, convert.(Variable, states), convert.(Variable, ps), name, systems) +end + + + +generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js), + independent_variable(js), + expression=Val{false}) + +generate_affect_function(js, affect, outputidxs) = build_function(affect, states(js), + parameters(js), + independent_variable(js), + expression=Val{false}, + headerfun=add_integrator_header, + outputidxs=outputidxs)[2] +function assemble_vrj(js, vrj, statetoid) + rate = generate_rate_function(js, vrj.rate) + outputvars = (convert(Variable,affect.lhs) for affect in vrj.affect!) + outputidxs = ((statetoid[var] for var in outputvars)...,) + affect = generate_affect_function(js, vrj.affect!, outputidxs) + VariableRateJump(rate, affect) +end + +function assemble_crj(js, crj, statetoid) + rate = generate_rate_function(js, crj.rate) + outputvars = (convert(Variable,affect.lhs) for affect in crj.affect!) + outputidxs = ((statetoid[var] for var in outputvars)...,) + affect = generate_affect_function(js, crj.affect!, outputidxs) + ConstantRateJump(rate, affect) +end + +""" +```julia +function DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...) +``` + +Generates a JumpProblem from a JumpSystem. +""" +function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...) + vrjs = Vector{VariableRateJump}() + crjs = Vector{ConstantRateJump}() + statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js))) + for j in equations(js) + if j isa ConstantRateJump + push!(crjs, assemble_crj(js, j, statetoid)) + elseif j isa VariableRateJump + push!(vrjs, assemble_vrj(js, j, statetoid)) + else + (j isa MassActionJump) && error("Generation of JumpProblems with MassActionJumps is not yet supported.") + end + end + ((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps") + jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, nothing) + JumpProblem(prob, aggregator, jset) +end \ No newline at end of file diff --git a/test/jumpsystem.jl b/test/jumpsystem.jl new file mode 100644 index 0000000000..68212fd791 --- /dev/null +++ b/test/jumpsystem.jl @@ -0,0 +1,106 @@ +using ModelingToolkit, DiffEqBase, DiffEqJump, Test, LinearAlgebra +MT = ModelingToolkit + +# basic MT SIR model with tweaks +@parameters β γ t +@variables S I R +rate₁ = β*S*I +affect₁ = [S ~ S - 1, I ~ I + 1] +rate₂ = γ*I+t +affect₂ = [I ~ I - 1, R ~ R + 1] +j₁ = ConstantRateJump(rate₁,affect₁) +j₂ = VariableRateJump(rate₂,affect₂) +js = JumpSystem([j₁,j₂], t, [S,I,R], [β,γ]) +statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js))) +mtjump1 = MT.assemble_crj(js, j₁, statetoid) +mtjump2 = MT.assemble_vrj(js, j₂, statetoid) + +# doc version +rate1(u,p,t) = (0.1/1000.0)*u[1]*u[2] +function affect1!(integrator) + integrator.u[1] -= 1 + integrator.u[2] += 1 +end +jump1 = ConstantRateJump(rate1,affect1!) +rate2(u,p,t) = 0.01u[2]+t +function affect2!(integrator) + integrator.u[2] -= 1 + integrator.u[3] += 1 +end +jump2 = VariableRateJump(rate2,affect2!) + +# test crjs +u = [100, 9, 5] +p = (0.1/1000,0.01) +tf = 1.0 +mutable struct TestInt{U,V,T} + u::U + p::V + t::T +end +mtintegrator = TestInt(u,p,tf) +integrator = TestInt(u,p,tf) +@test abs(mtjump1.rate(u,p,tf) - jump1.rate(u,p,tf)) < 10*eps() +@test abs(mtjump2.rate(u,p,tf) - jump2.rate(u,p,tf)) < 10*eps() +mtjump1.affect!(mtintegrator) +jump1.affect!(integrator) +@test all(integrator.u .== mtintegrator.u) +mtintegrator.u .= u; integrator.u .= u +mtjump2.affect!(mtintegrator) +jump2.affect!(integrator) +@test all(integrator.u .== mtintegrator.u) + +# test MT can make and solve a jump problem +rate₃ = γ*I +affect₃ = [I ~ I - 1, R ~ R + 1] +j₃ = ConstantRateJump(rate₃,affect₃) +js2 = JumpSystem([j₁,j₃], t, [S,I,R], [β,γ]) +u₀ = [999,1,0]; p = (0.1/1000,0.01); tspan = (0.,250.) +dprob = DiscreteProblem(u₀,tspan,p) +jprob = JumpProblem(js2, dprob, Direct(), save_positions=(false,false)) +Nsims = 10000 +function getmean(jprob,Nsims) + m = 0.0 + for i = 1:Nsims + sol = solve(jprob, SSAStepper()) + m += sol[end,end] + end + m/Nsims +end +m = getmean(jprob,Nsims) + +#test the MT JumpProblem rates/affects are correct +rate2(u,p,t) = 0.01u[2] +jump2 = ConstantRateJump(rate2,affect2!) +mtjumps = jprob.discrete_jump_aggregation +@test abs(mtjumps.rates[1](u,p,tf) - jump1.rate(u,p,tf)) < 10*eps() +@test abs(mtjumps.rates[2](u,p,tf) - jump2.rate(u,p,tf)) < 10*eps() +mtjumps.affects![1](mtintegrator) +jump1.affect!(integrator) +@test all(integrator.u .== mtintegrator.u) +mtintegrator.u .= u; integrator.u .= u +mtjumps.affects![2](mtintegrator) +jump2.affect!(integrator) +@test all(integrator.u .== mtintegrator.u) + +# direct vers +p = (0.1/1000,0.01) +prob = DiscreteProblem([999,1,0],(0.0,250.0),p) +r1(u,p,t) = (0.1/1000.0)*u[1]*u[2] +function a1!(integrator) + integrator.u[1] -= 1 + integrator.u[2] += 1 +end +j1 = ConstantRateJump(r1,a1!) +r2(u,p,t) = 0.01u[2] +function a2!(integrator) + integrator.u[2] -= 1 + integrator.u[3] += 1 +end +j2 = ConstantRateJump(r2,a2!) +jset = JumpSet((),(j1,j2),nothing,nothing) +jprob = JumpProblem(prob,Direct(),jset, save_positions=(false,false)) +m2 = getmean(jprob,Nsims) + +# test JumpSystem solution agrees with direct version +@test abs(m-m2) ./ m < .01 \ No newline at end of file