Skip to content

Commit

Permalink
update tests for RSSA
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsas committed Nov 5, 2018
1 parent 4f1a864 commit 484d5bc
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 28 deletions.
27 changes: 13 additions & 14 deletions src/aggregators/direct.jl
Expand Up @@ -9,22 +9,22 @@ mutable struct DirectJumpAggregation{T,S,F1,F2,RNG} <: AbstractSSAJumpAggregator
affects!::F2
save_positions::Tuple{Bool,Bool}
rng::RNG
DirectJumpAggregation{T,S,F1,F2,RNG}(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool,Bool}, rng::RNG) where {T,S,F1,F2,RNG} =
DirectJumpAggregation{T,S,F1,F2,RNG}(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool,Bool}, rng::RNG) where {T,S,F1,F2,RNG} =
new{T,S,F1,F2,RNG}(nj, njt, et, crs, sr, maj, rs, affs!, sps, rng)
end
DirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool,Bool}, rng::RNG; kwargs...) where {T,S,F1,F2,RNG} =
DirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1, affs!::F2, sps::Tuple{Bool,Bool}, rng::RNG; kwargs...) where {T,S,F1,F2,RNG} =
DirectJumpAggregation{T,S,F1,F2,RNG}(nj, njt, et, crs, sr, maj, rs, affs!, sps, rng)


########### The following routines should be templates for all SSAs ###########

# condition for jump to occur
@inline function (p::DirectJumpAggregation)(u, t, integrator)
@inline function (p::DirectJumpAggregation)(u, t, integrator)
p.next_jump_time == t
end

# executing jump at the next jump time
function (p::DirectJumpAggregation)(integrator)
function (p::DirectJumpAggregation)(integrator)
execute_jumps!(p, integrator, integrator.u, integrator.p, integrator.t)
generate_jumps!(p, integrator, integrator.u, integrator.p, integrator.t)
register_next_jump_time!(integrator, p, integrator.t)
Expand All @@ -41,24 +41,24 @@ end
############################# Required Functions #############################

# creating the JumpAggregation structure (tuple-based constant jumps)
function aggregate(aggregator::Direct, u, p, t, end_time, constant_jumps,
function aggregate(aggregator::Direct, u, p, t, end_time, constant_jumps,
ma_jumps, save_positions, rng; kwargs...)

# handle constant jumps using tuples
rates, affects! = get_jump_info_tuples(constant_jumps)

build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps,
build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps,
rates, affects!, save_positions, rng; kwargs...)
end

# creating the JumpAggregation structure (function wrapper-based constant jumps)
function aggregate(aggregator::DirectFW, u, p, t, end_time, constant_jumps,
function aggregate(aggregator::DirectFW, u, p, t, end_time, constant_jumps,
ma_jumps, save_positions, rng; kwargs...)

# handle constant jumps using function wrappers
rates, affects! = get_jump_info_fwrappers(u, p, t, constant_jumps)

build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps,
build_jump_aggregation(DirectJumpAggregation, u, p, t, end_time, ma_jumps,
rates, affects!, save_positions, rng; kwargs...)
end

Expand All @@ -72,7 +72,7 @@ end
@inline function execute_jumps!(p::DirectJumpAggregation, integrator, u, params, t)
num_ma_rates = get_num_majumps(p.ma_jumps)
if p.next_jump <= num_ma_rates
@inbounds executerx!(u, p.next_jump, p.ma_jumps)
@inbounds executerx!(u, p.next_jump, p.ma_jumps)
else
idx = p.next_jump - num_ma_rates
@inbounds p.affects![idx](integrator)
Expand Down Expand Up @@ -101,12 +101,12 @@ end
majumps = p.ma_jumps
idx = get_num_majumps(majumps)
@inbounds for i in 1:idx
new_rate = evalrxrate(u, i, majumps)
new_rate = evalrxrate(u, i, majumps)
cur_rates[i] = new_rate + prev_rate
prev_rate = cur_rates[i]
end
# constant jump rates

# constant jump rates
rates = p.rates
if !isempty(rates)
idx += 1
Expand Down Expand Up @@ -143,7 +143,7 @@ end
majumps = p.ma_jumps
idx = get_num_majumps(majumps)
@inbounds for i in 1:idx
new_rate = evalrxrate(u, i, majumps)
new_rate = evalrxrate(u, i, majumps)
cur_rates[i] = new_rate + prev_rate
prev_rate = cur_rates[i]
end
Expand All @@ -161,4 +161,3 @@ end
@inbounds sum_rate = cur_rates[end]
sum_rate, randexp(p.rng) / sum_rate
end

4 changes: 2 additions & 2 deletions src/aggregators/rssa.jl
Expand Up @@ -126,7 +126,7 @@ function initialize!(p::RSSAJumpAggregation, integrator, u, params, t)
sum_rate += crhigh[k]
k += 1
end
p.sum_rate = sum_rate
p.sum_rate = sum_rate

generate_jumps!(p, integrator, u, params, t)
nothing
Expand Down Expand Up @@ -213,7 +213,7 @@ end
notdone = false
end
else
@inbounds crate = rate[jidx - num_majumps](u, params, t)
@inbounds crate = p.rates[jidx - num_majumps](u, params, t)
if crate > zero(crate) && r2 <= crate
notdone = false
end
Expand Down
7 changes: 5 additions & 2 deletions test/degenerate_rx_cases.jl
Expand Up @@ -10,7 +10,7 @@ doprint = false
#using Plots; plotlyjs()
doplot = false

methods = (Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), NRM())
methods = (Direct(), DirectFW(), FRM(), FRMFW(), SortingDirect(), NRM(), RSSA())

# one reaction case, mass action jump, vector of data
rate = [2.0]
Expand Down Expand Up @@ -94,9 +94,12 @@ dep_graph = [
[1, 2],
[1, 2]
]
spec_to_dep_jumps = [[2]]
jump_to_dep_specs = [[1],[1]]
namedpars = (dep_graph=dep_graph, vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs)

for method in methods
jump_prob = JumpProblem(prob, method, jump, jump2; dep_graph=dep_graph)
jump_prob = JumpProblem(prob, method, jump, jump2; namedpars...)
sol = solve(jump_prob, SSAStepper())

if doplot
Expand Down
23 changes: 13 additions & 10 deletions test/linearreaction_test.jl
Expand Up @@ -14,8 +14,11 @@ tf = .1
baserate = .1
A0 = 100
exactmean = (t,ratevec) -> A0 * exp(-sum(ratevec) * t)
SSAalgs = [Direct()]#, DirectFW(), FRM(), FRMFW()]
SSAalgs = [Direct(),RSSA()] #[Direct(),RSSA()]#, DirectFW(), FRM(), FRMFW()]

spec_to_dep_jumps = [collect(1:Nrxs),[]]
jump_to_dep_specs = [[1,2] for i=1:Nrxs]
namedpars = (vartojumps_map=spec_to_dep_jumps, jumptovars_map=jump_to_dep_specs)
rates = ones(Float64, Nrxs) * baserate;
cumsum!(rates, rates)
exactmeanval = exactmean(tf, rates)
Expand Down Expand Up @@ -47,7 +50,7 @@ function A_to_B_tuple(N, method)
jumps = ((jump for jump in jumpvec)...,)
jset = JumpSet((), jumps, nothing, nothing)
prob = DiscreteProblem([A0,0], (0.0,tf))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...)

jump_prob
end
Expand All @@ -68,7 +71,7 @@ function A_to_B_vec(N, method)
# convert jumpvec to tuple to send to JumpProblem...
jset = JumpSet((), jumps, nothing, nothing)
prob = DiscreteProblem([A0,0], (0.0,tf))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...)

jump_prob
end
Expand All @@ -85,7 +88,7 @@ function A_to_B_ma(N, method)
majumps = MassActionJump(rates, reactstoch, netstoch)
jset = JumpSet((), (), nothing, majumps)
prob = DiscreteProblem([A0,0], (0.0,tf))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...)

jump_prob
end
Expand Down Expand Up @@ -118,7 +121,7 @@ function A_to_B_hybrid(N, method)
majumps = MassActionJump(rates[1:switchidx] , reactstoch, netstoch)
jset = JumpSet((), jumps, nothing, majumps)
prob = DiscreteProblem([A0,0], (0.0,tf))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...)

jump_prob
end
Expand Down Expand Up @@ -151,7 +154,7 @@ function A_to_B_hybrid_nojset(N, method)
majumps = MassActionJump(rates[1:switchidx] , reactstoch, netstoch)
jumps = (constjumps...,majumps)
prob = DiscreteProblem([A0,0], (0.0,tf))
jump_prob = JumpProblem(prob, method, jumps...; save_positions=(false,false))
jump_prob = JumpProblem(prob, method, jumps...; save_positions=(false,false), namedpars...)

jump_prob
end
Expand Down Expand Up @@ -181,7 +184,7 @@ function A_to_B_hybrid_vecs(N, method)
end
jset = JumpSet((), jumpvec, nothing, majumps)
prob = DiscreteProblem([A0,0], (0.0,tf))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...)

jump_prob
end
Expand Down Expand Up @@ -210,7 +213,7 @@ function A_to_B_hybrid_vecs_scalars(N, method)
end
jset = JumpSet((), jumpvec, nothing, majumps)
prob = DiscreteProblem([A0,0], (0.0,tf))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...)

jump_prob
end
Expand Down Expand Up @@ -241,7 +244,7 @@ function A_to_B_hybrid_tups_scalars(N, method)

jumps = ((maj for maj in majumpsv)..., (jump for jump in jumpvec)...)
prob = DiscreteProblem([A0,0], (0.0,tf))
jump_prob = JumpProblem(prob, method, jumps...; save_positions=(false,false))
jump_prob = JumpProblem(prob, method, jumps...; save_positions=(false,false), namedpars...)

jump_prob
end
Expand Down Expand Up @@ -272,7 +275,7 @@ function A_to_B_hybrid_tups(N, method)
jumps = ((jump for jump in jumpvec)...,)
jset = JumpSet((), jumps, nothing, majumps)
prob = DiscreteProblem([A0,0], (0.0,tf))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false))
jump_prob = JumpProblem(prob, method, jset; save_positions=(false,false), namedpars...)

jump_prob
end
Expand Down

0 comments on commit 484d5bc

Please sign in to comment.