-
-
Notifications
You must be signed in to change notification settings - Fork 232
JumpSystems for constant and variable rate jumps #317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
a40879a
start adding JumpSystem features
isaacsas f7a21e8
updating build_function for integrators
isaacsas ff393c8
constant rate jump assembly working
isaacsas 595a530
constant and variable rate assembly working
isaacsas 976c22e
switch to header function
isaacsas f65a39d
can make JumpProblem, but wrong answers on SIR
isaacsas a2b2070
still searching for bug...
isaacsas 792a84d
still bug hunting
isaacsas 14d2551
unpack arrays into tuples
isaacsas 0392ad5
Merge remote-tracking branch 'origin/master' into jumpsystems
isaacsas 1d71d7a
add outputidxs
isaacsas 3968efd
add test for generated jump system
isaacsas File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
JumpType = Union{VariableRateJump, ConstantRateJump, MassActionJump} | ||
|
||
struct JumpSystem <: AbstractSystem | ||
eqs::Vector{JumpType} | ||
iv::Variable | ||
states::Vector{Variable} | ||
ps::Vector{Variable} | ||
name::Symbol | ||
systems::Vector{JumpSystem} | ||
end | ||
|
||
function JumpSystem(eqs, iv, states, ps; systems = JumpSystem[], | ||
name = gensym(:JumpSystem)) | ||
JumpSystem(eqs, iv, convert.(Variable, states), convert.(Variable, ps), name, systems) | ||
end | ||
|
||
|
||
|
||
generate_rate_function(js, rate) = build_function(rate, states(js), parameters(js), | ||
independent_variable(js), | ||
expression=Val{false}) | ||
|
||
generate_affect_function(js, affect, outputidxs) = build_function(affect, states(js), | ||
parameters(js), | ||
independent_variable(js), | ||
expression=Val{false}, | ||
headerfun=add_integrator_header, | ||
outputidxs=outputidxs)[2] | ||
function assemble_vrj(js, vrj, statetoid) | ||
rate = generate_rate_function(js, vrj.rate) | ||
outputvars = (convert(Variable,affect.lhs) for affect in vrj.affect!) | ||
outputidxs = ((statetoid[var] for var in outputvars)...,) | ||
affect = generate_affect_function(js, vrj.affect!, outputidxs) | ||
VariableRateJump(rate, affect) | ||
end | ||
|
||
function assemble_crj(js, crj, statetoid) | ||
rate = generate_rate_function(js, crj.rate) | ||
outputvars = (convert(Variable,affect.lhs) for affect in crj.affect!) | ||
outputidxs = ((statetoid[var] for var in outputvars)...,) | ||
affect = generate_affect_function(js, crj.affect!, outputidxs) | ||
ConstantRateJump(rate, affect) | ||
end | ||
|
||
""" | ||
```julia | ||
function DiffEqBase.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...) | ||
``` | ||
|
||
Generates a JumpProblem from a JumpSystem. | ||
""" | ||
function DiffEqJump.JumpProblem(js::JumpSystem, prob, aggregator; kwargs...) | ||
vrjs = Vector{VariableRateJump}() | ||
crjs = Vector{ConstantRateJump}() | ||
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js))) | ||
for j in equations(js) | ||
if j isa ConstantRateJump | ||
push!(crjs, assemble_crj(js, j, statetoid)) | ||
elseif j isa VariableRateJump | ||
push!(vrjs, assemble_vrj(js, j, statetoid)) | ||
else | ||
(j isa MassActionJump) && error("Generation of JumpProblems with MassActionJumps is not yet supported.") | ||
end | ||
end | ||
((prob isa DiscreteProblem) && !isempty(vrjs)) && error("Use continuous problems such as an ODEProblem or a SDEProblem with VariableRateJumps") | ||
jset = JumpSet(Tuple(vrjs), Tuple(crjs), nothing, nothing) | ||
JumpProblem(prob, aggregator, jset) | ||
end |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
using ModelingToolkit, DiffEqBase, DiffEqJump, Test, LinearAlgebra | ||
MT = ModelingToolkit | ||
|
||
# basic MT SIR model with tweaks | ||
@parameters β γ t | ||
@variables S I R | ||
rate₁ = β*S*I | ||
affect₁ = [S ~ S - 1, I ~ I + 1] | ||
rate₂ = γ*I+t | ||
affect₂ = [I ~ I - 1, R ~ R + 1] | ||
j₁ = ConstantRateJump(rate₁,affect₁) | ||
j₂ = VariableRateJump(rate₂,affect₂) | ||
js = JumpSystem([j₁,j₂], t, [S,I,R], [β,γ]) | ||
statetoid = Dict(convert(Variable,state) => i for (i,state) in enumerate(states(js))) | ||
mtjump1 = MT.assemble_crj(js, j₁, statetoid) | ||
mtjump2 = MT.assemble_vrj(js, j₂, statetoid) | ||
|
||
# doc version | ||
rate1(u,p,t) = (0.1/1000.0)*u[1]*u[2] | ||
function affect1!(integrator) | ||
integrator.u[1] -= 1 | ||
integrator.u[2] += 1 | ||
end | ||
jump1 = ConstantRateJump(rate1,affect1!) | ||
rate2(u,p,t) = 0.01u[2]+t | ||
function affect2!(integrator) | ||
integrator.u[2] -= 1 | ||
integrator.u[3] += 1 | ||
end | ||
jump2 = VariableRateJump(rate2,affect2!) | ||
|
||
# test crjs | ||
u = [100, 9, 5] | ||
p = (0.1/1000,0.01) | ||
tf = 1.0 | ||
mutable struct TestInt{U,V,T} | ||
u::U | ||
p::V | ||
t::T | ||
end | ||
mtintegrator = TestInt(u,p,tf) | ||
integrator = TestInt(u,p,tf) | ||
@test abs(mtjump1.rate(u,p,tf) - jump1.rate(u,p,tf)) < 10*eps() | ||
@test abs(mtjump2.rate(u,p,tf) - jump2.rate(u,p,tf)) < 10*eps() | ||
mtjump1.affect!(mtintegrator) | ||
jump1.affect!(integrator) | ||
@test all(integrator.u .== mtintegrator.u) | ||
mtintegrator.u .= u; integrator.u .= u | ||
mtjump2.affect!(mtintegrator) | ||
jump2.affect!(integrator) | ||
@test all(integrator.u .== mtintegrator.u) | ||
|
||
# test MT can make and solve a jump problem | ||
rate₃ = γ*I | ||
affect₃ = [I ~ I - 1, R ~ R + 1] | ||
j₃ = ConstantRateJump(rate₃,affect₃) | ||
js2 = JumpSystem([j₁,j₃], t, [S,I,R], [β,γ]) | ||
u₀ = [999,1,0]; p = (0.1/1000,0.01); tspan = (0.,250.) | ||
dprob = DiscreteProblem(u₀,tspan,p) | ||
jprob = JumpProblem(js2, dprob, Direct(), save_positions=(false,false)) | ||
Nsims = 10000 | ||
function getmean(jprob,Nsims) | ||
m = 0.0 | ||
for i = 1:Nsims | ||
sol = solve(jprob, SSAStepper()) | ||
m += sol[end,end] | ||
end | ||
m/Nsims | ||
end | ||
m = getmean(jprob,Nsims) | ||
|
||
#test the MT JumpProblem rates/affects are correct | ||
rate2(u,p,t) = 0.01u[2] | ||
jump2 = ConstantRateJump(rate2,affect2!) | ||
mtjumps = jprob.discrete_jump_aggregation | ||
@test abs(mtjumps.rates[1](u,p,tf) - jump1.rate(u,p,tf)) < 10*eps() | ||
@test abs(mtjumps.rates[2](u,p,tf) - jump2.rate(u,p,tf)) < 10*eps() | ||
mtjumps.affects | ||
jump1.affect!(integrator) | ||
@test all(integrator.u .== mtintegrator.u) | ||
mtintegrator.u .= u; integrator.u .= u | ||
mtjumps.affects | ||
jump2.affect!(integrator) | ||
@test all(integrator.u .== mtintegrator.u) | ||
|
||
# direct vers | ||
p = (0.1/1000,0.01) | ||
prob = DiscreteProblem([999,1,0],(0.0,250.0),p) | ||
r1(u,p,t) = (0.1/1000.0)*u[1]*u[2] | ||
function a1!(integrator) | ||
integrator.u[1] -= 1 | ||
integrator.u[2] += 1 | ||
end | ||
j1 = ConstantRateJump(r1,a1!) | ||
r2(u,p,t) = 0.01u[2] | ||
function a2!(integrator) | ||
integrator.u[2] -= 1 | ||
integrator.u[3] += 1 | ||
end | ||
j2 = ConstantRateJump(r2,a2!) | ||
jset = JumpSet((),(j1,j2),nothing,nothing) | ||
jprob = JumpProblem(prob,Direct(),jset, save_positions=(false,false)) | ||
m2 = getmean(jprob,Nsims) | ||
|
||
# test JumpSystem solution agrees with direct version | ||
@test abs(m-m2) ./ m < .01 |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ChrisRackauckas With this jumps seem to be working now. If this design looks ok to you I can add tests, and if you want add the
oidx
call for the other equation types ofrhss
I can add them too (right now it is only applied for the defaultrhss
case).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
probably fine to just have it here.