diff --git a/src/aggregators/direct.jl b/src/aggregators/direct.jl index a9e9d8b9..2b8c518a 100644 --- a/src/aggregators/direct.jl +++ b/src/aggregators/direct.jl @@ -9,22 +9,22 @@ mutable struct DirectJumpAggregation{T,S,F1,F2,RNG} <: AbstractSSAJumpAggregator affects!::F2 save_positions::Tuple{Bool,Bool} rng::RNG - DirectJumpAggregation{T,S,F1,F2,RNG}(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool,Bool}, rng::RNG) where {T,S,F1,F2,RNG} = + DirectJumpAggregation{T,S,F1,F2,RNG}(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool,Bool}, rng::RNG) where {T,S,F1,F2,RNG} = new{T,S,F1,F2,RNG}(nj, njt, et, crs, sr, maj, rs, affs!, sps, rng) end -DirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool,Bool}, rng::RNG; kwargs...) where {T,S,F1,F2,RNG} = +DirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool,Bool}, rng::RNG; kwargs...) where {T,S,F1,F2,RNG} = DirectJumpAggregation{T,S,F1,F2,RNG}(nj, njt, et, crs, sr, maj, rs, affs!, sps, rng) ########### The following routines should be templates for all SSAs ########### # condition for jump to occur -@inline function (p::DirectJumpAggregation)(u, t, integrator) +@inline function (p::DirectJumpAggregation)(u, t, integrator) p.next_jump_time == t end # executing jump at the next jump time -function (p::DirectJumpAggregation)(integrator) +function (p::DirectJumpAggregation)(integrator) execute_jumps!(p, integrator, integrator.u, integrator.p, integrator.t) generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t) register_next_jump_time!(integrator, p, integrator.t) @@ -41,24 +41,24 @@ end ############################# Required Functions ############################# # creating the JumpAggregation structure (tuple-based constant jumps) -function aggregate(aggregator::Direct, u, p, t, end_time, constant_jumps, +function aggregate(aggregator::Direct, u, p, t, end_time, constant_jumps, ma_jumps, save_positions, rng; kwargs...) # handle constant jumps using tuples rates, affects! = get_jump_info_tuples(constant_jumps) - build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps, + build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; kwargs...) end # creating the JumpAggregation structure (function wrapper-based constant jumps) -function aggregate(aggregator::DirectFW, u, p, t, end_time, constant_jumps, +function aggregate(aggregator::DirectFW, u, p, t, end_time, constant_jumps, ma_jumps, save_positions, rng; kwargs...) # handle constant jumps using function wrappers rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps) - build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps, + build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps, rates, affects!, save_positions, rng; kwargs...) end @@ -72,7 +72,7 @@ end @inline function execute_jumps!(p::DirectJumpAggregation, integrator, u, params, t) num_ma_rates = get_num_majumps(p.ma_jumps) if p.next_jump <= num_ma_rates - @inbounds executerx!(u, p.next_jump, p.ma_jumps) + @inbounds executerx!(u, p.next_jump, p.ma_jumps) else idx = p.next_jump - num_ma_rates @inbounds p.affects![idx](integrator) @@ -101,12 +101,12 @@ end majumps = p.ma_jumps idx = get_num_majumps(majumps) @inbounds for i in 1:idx - new_rate = evalrxrate(u, i, majumps) + new_rate = evalrxrate(u, i, majumps) cur_rates[i] = new_rate + prev_rate prev_rate = cur_rates[i] end - - # constant jump rates + + # constant jump rates rates = p.rates if !isempty(rates) idx += 1 @@ -143,7 +143,7 @@ end majumps = p.ma_jumps idx = get_num_majumps(majumps) @inbounds for i in 1:idx - new_rate = evalrxrate(u, i, majumps) + new_rate = evalrxrate(u, i, majumps) cur_rates[i] = new_rate + prev_rate prev_rate = cur_rates[i] end @@ -161,4 +161,3 @@ end @inbounds sum_rate = cur_rates[end] sum_rate, randexp(p.rng) / sum_rate end - diff --git a/src/aggregators/rssa.jl b/src/aggregators/rssa.jl index dd23319f..207c394a 100644 --- a/src/aggregators/rssa.jl +++ b/src/aggregators/rssa.jl @@ -126,7 +126,7 @@ function initialize!(p::RSSAJumpAggregation, integrator, u, params, t) sum_rate += crhigh[k] k += 1 end - p.sum_rate = sum_rate + p.sum_rate = sum_rate generate_jumps!(p, integrator, u, params, t) nothing @@ -213,7 +213,7 @@ end notdone = false end else - @inbounds crate = rate[jidx - num_majumps](u, params, t) + @inbounds crate = p.rates[jidx - num_majumps](u, params, t) if crate > zero(crate) && r2 <= crate notdone = false end diff --git a/test/degenerate_rx_cases.jl b/test/degenerate_rx_cases.jl index 5e33203c..88088ce5 100644 --- a/test/degenerate_rx_cases.jl +++ b/test/degenerate_rx_cases.jl @@ -10,7 +10,7 @@ doprint = false #using Plots; plotlyjs() doplot = false -methods = (Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), NRM()) +methods = (Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), NRM(), RSSA()) # one reaction case, mass action jump, vector of data rate = [2.0] @@ -94,9 +94,12 @@ dep_graph = [ [1, 2], [1, 2] ] +spec_to_dep_jumps = [[2]] +jump_to_dep_specs = [[1],[1]] +namedpars = (dep_graph=dep_graph, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs) for method in methods - jump_prob = JumpProblem(prob, method, jump, jump2; dep_graph=dep_graph) + jump_prob = JumpProblem(prob, method, jump, jump2; namedpars...) sol = solve(jump_prob, SSAStepper()) if doplot diff --git a/test/linearreaction_test.jl b/test/linearreaction_test.jl index 0c37ccf5..1b40e4c7 100644 --- a/test/linearreaction_test.jl +++ b/test/linearreaction_test.jl @@ -14,8 +14,11 @@ tf = .1 baserate = .1 A0 = 100 exactmean = (t,ratevec) -> A0 * exp(-sum(ratevec) * t) -SSAalgs = [Direct()]#, DirectFW(), FRM(), FRMFW()] +SSAalgs = [Direct(),RSSA()] #[Direct(),RSSA()]#, DirectFW(), FRM(), FRMFW()] +spec_to_dep_jumps = [collect(1:Nrxs),[]] +jump_to_dep_specs = [[1,2] for i=1:Nrxs] +namedpars = (vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs) rates = ones(Float64, Nrxs) * baserate; cumsum!(rates, rates) exactmeanval = exactmean(tf, rates) @@ -47,7 +50,7 @@ function A_to_B_tuple(N, method) jumps = ((jump for jump in jumpvec)...,) jset = JumpSet((), jumps, nothing, nothing) prob = DiscreteProblem([A0,0], (0.0,tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false)) + jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...) jump_prob end @@ -68,7 +71,7 @@ function A_to_B_vec(N, method) # convert jumpvec to tuple to send to JumpProblem... jset = JumpSet((), jumps, nothing, nothing) prob = DiscreteProblem([A0,0], (0.0,tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false)) + jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...) jump_prob end @@ -85,7 +88,7 @@ function A_to_B_ma(N, method) majumps = MassActionJump(rates, reactstoch, netstoch) jset = JumpSet((), (), nothing, majumps) prob = DiscreteProblem([A0,0], (0.0,tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false)) + jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...) jump_prob end @@ -118,7 +121,7 @@ function A_to_B_hybrid(N, method) majumps = MassActionJump(rates[1:switchidx] , reactstoch, netstoch) jset = JumpSet((), jumps, nothing, majumps) prob = DiscreteProblem([A0,0], (0.0,tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false)) + jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...) jump_prob end @@ -151,7 +154,7 @@ function A_to_B_hybrid_nojset(N, method) majumps = MassActionJump(rates[1:switchidx] , reactstoch, netstoch) jumps = (constjumps...,majumps) prob = DiscreteProblem([A0,0], (0.0,tf)) - jump_prob = JumpProblem(prob, method, jumps...; save_positions=(false,false)) + jump_prob = JumpProblem(prob, method, jumps...; save_positions=(false,false), namedpars...) jump_prob end @@ -181,7 +184,7 @@ function A_to_B_hybrid_vecs(N, method) end jset = JumpSet((), jumpvec, nothing, majumps) prob = DiscreteProblem([A0,0], (0.0,tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false)) + jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...) jump_prob end @@ -210,7 +213,7 @@ function A_to_B_hybrid_vecs_scalars(N, method) end jset = JumpSet((), jumpvec, nothing, majumps) prob = DiscreteProblem([A0,0], (0.0,tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false)) + jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...) jump_prob end @@ -241,7 +244,7 @@ function A_to_B_hybrid_tups_scalars(N, method) jumps = ((maj for maj in majumpsv)..., (jump for jump in jumpvec)...) prob = DiscreteProblem([A0,0], (0.0,tf)) - jump_prob = JumpProblem(prob, method, jumps...; save_positions=(false,false)) + jump_prob = JumpProblem(prob, method, jumps...; save_positions=(false,false), namedpars...) jump_prob end @@ -272,7 +275,7 @@ function A_to_B_hybrid_tups(N, method) jumps = ((jump for jump in jumpvec)...,) jset = JumpSet((), jumps, nothing, majumps) prob = DiscreteProblem([A0,0], (0.0,tf)) - jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false)) + jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...) jump_prob end