Skip to content

Commit

Permalink
Merge 846818b into 87d93d0
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsas committed Sep 7, 2022
2 parents 87d93d0 + 846818b commit 1881abd
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 26 deletions.
1 change: 1 addition & 0 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using DataStructures, PoissonRandom, Random, ArrayInterfaceCore
using FunctionWrappers, UnPack
using Graphs
using SciMLBase: SciMLBase
using Base.FastMath: add_fast

import DiffEqBase: DiscreteCallback, init, solve, solve!, plot_indices, initialize!
import Base: size, getindex, setindex!, length, similar, show, merge!, merge
Expand Down
19 changes: 9 additions & 10 deletions src/aggregators/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ end
# calculate the next jump / jump time
function generate_jumps!(p::DirectJumpAggregation, integrator, u, params, t)
p.sum_rate, ttnj = time_to_next_jump(p, u, params, t)
@fastmath p.next_jump_time = t + ttnj
p.next_jump_time = add_fast(t, ttnj)
@inbounds p.next_jump = searchsortedfirst(p.cur_rates, rand(p.rng) * p.sum_rate)
nothing
end

######################## SSA specific helper routines ########################

# tuple-based constant jumps
@fastmath function time_to_next_jump(p::DirectJumpAggregation{T, S, F1, F2, RNG}, u, params,
t) where {T, S, F1 <: Tuple, F2 <: Tuple, RNG}
function time_to_next_jump(p::DirectJumpAggregation{T, S, F1, F2, RNG}, u, params,
t) where {T, S, F1 <: Tuple, F2 <: Tuple, RNG}
prev_rate = zero(t)
new_rate = zero(t)
cur_rates = p.cur_rates
Expand All @@ -77,7 +77,7 @@ end
idx = get_num_majumps(majumps)
@inbounds for i in 1:idx
new_rate = evalrxrate(u, i, majumps)
cur_rates[i] = new_rate + prev_rate
cur_rates[i] = add_fast(new_rate, prev_rate)
prev_rate = cur_rates[i]
end

Expand All @@ -87,7 +87,7 @@ end
idx += 1
fill_cur_rates(u, params, t, cur_rates, idx, rates...)
@inbounds for i in idx:length(cur_rates)
cur_rates[i] = cur_rates[i] + prev_rate
cur_rates[i] = add_fast(cur_rates[i], prev_rate)
prev_rate = cur_rates[i]
end
end
Expand All @@ -108,9 +108,8 @@ end
end

# function wrapper-based constant jumps
@fastmath function time_to_next_jump(p::DirectJumpAggregation{T, S, F1, F2, RNG}, u, params,
t) where {T, S, F1 <: AbstractArray,
F2 <: AbstractArray, RNG}
function time_to_next_jump(p::DirectJumpAggregation{T, S, F1, F2, RNG}, u, params,
t) where {T, S, F1 <: AbstractArray, F2 <: AbstractArray, RNG}
prev_rate = zero(t)
new_rate = zero(t)
cur_rates = p.cur_rates
Expand All @@ -120,7 +119,7 @@ end
idx = get_num_majumps(majumps)
@inbounds for i in 1:idx
new_rate = evalrxrate(u, i, majumps)
cur_rates[i] = new_rate + prev_rate
cur_rates[i] = add_fast(new_rate, prev_rate)
prev_rate = cur_rates[i]
end

Expand All @@ -129,7 +128,7 @@ end
rates = p.rates
@inbounds for i in 1:length(p.rates)
new_rate = rates[i](u, params, t)
cur_rates[idx] = new_rate + prev_rate
cur_rates[idx] = add_fast(new_rate, prev_rate)
prev_rate = cur_rates[idx]
idx += 1
end
Expand Down
14 changes: 6 additions & 8 deletions src/aggregators/frm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ end
######################## SSA specific helper routines ########################

# mass action jumps
@fastmath function next_ma_jump(p::FRMJumpAggregation, u, params, t)
function next_ma_jump(p::FRMJumpAggregation, u, params, t)
ttnj = typemax(typeof(t))
nextrx = zero(Int)
majumps = p.ma_jumps
Expand All @@ -91,9 +91,8 @@ end
end

# tuple-based constant jumps
@fastmath function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2, RNG}, u,
params,
t) where {T, S, F1 <: Tuple, F2 <: Tuple, RNG}
function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2, RNG}, u, params,
t) where {T, S, F1 <: Tuple, F2 <: Tuple, RNG}
ttnj = typemax(typeof(t))
nextrx = zero(Int)
if !isempty(p.rates)
Expand All @@ -111,10 +110,9 @@ end
end

# function wrapper-based constant jumps
@fastmath function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2, RNG}, u,
params,
t) where {T, S, F1 <: AbstractArray,
F2 <: AbstractArray, RNG}
function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2, RNG}, u, params,
t) where {T, S, F1 <: AbstractArray, F2 <: AbstractArray,
RNG}
ttnj = typemax(typeof(t))
nextrx = zero(Int)
if !isempty(p.rates)
Expand Down
12 changes: 6 additions & 6 deletions src/massaction_rates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# stochiometric coefficient.
###############################################################################

@inline @fastmath function evalrxrate(speciesvec::AbstractVector{T}, rxidx::S,
majump::MassActionJump{U, V, W, X})::R where
@inline function evalrxrate(speciesvec::AbstractVector{T}, rxidx::S,
majump::MassActionJump{U, V, W, X})::R where
{T, S, R, U <: AbstractVector{R}, V, W, X}
val = one(T)
@inbounds for specstoch in majump.reactant_stoch[rxidx]
Expand All @@ -19,17 +19,17 @@
@inbounds return val * majump.scaled_rates[rxidx]
end

@inline @fastmath function executerx!(speciesvec::AbstractVector{T}, rxidx::S,
majump::M) where {T, S, M <: AbstractMassActionJump}
@inline function executerx!(speciesvec::AbstractVector{T}, rxidx::S,
majump::M) where {T, S, M <: AbstractMassActionJump}
@inbounds net_stoch = majump.net_stoch[rxidx]
@inbounds for specstoch in net_stoch
speciesvec[specstoch[1]] += specstoch[2]
end
nothing
end

@inline @fastmath function executerx(speciesvec::SVector{T}, rxidx::S,
majump::M) where {T, S, M <: AbstractMassActionJump}
@inline function executerx(speciesvec::SVector{T}, rxidx::S,
majump::M) where {T, S, M <: AbstractMassActionJump}
@inbounds net_stoch = majump.net_stoch[rxidx]
@inbounds for specstoch in net_stoch
speciesvec = setindex(speciesvec, speciesvec[specstoch[1]] + specstoch[2],
Expand Down
4 changes: 2 additions & 2 deletions test/ssa_callback_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ function fuel_affect!(integrator)
end
cb = DiscreteCallback(condition, fuel_affect!, save_positions = (false, true))

sol = solve(jump_prob, SSAStepper(), callback = cb, tstops = [5])
sol = solve(jump_prob, SSAStepper(); callback = cb, tstops = [5])
@test sol.t[1:2] == [0.0, 5.0] # no jumps between t=0 and t=5
@test sol(5 + 1e-10) == [100, 0] # state just after fueling before any decays can happen

# test can pass callbacks via JumpProblem
jump_prob2 = JumpProblem(prob, Direct(), jump; rng = rng, callback = cb)
sol2 = solve(jump_prob2, SSAStepper(), tstops = [5])
sol2 = solve(jump_prob2, SSAStepper(); tstops = [5])
@test sol2.t[1:2] == [0.0, 5.0] # no jumps between t=0 and t=5
@test sol2(5 + 1e-10) == [100, 0] # state just after fueling before any decays can happen

Expand Down

0 comments on commit 1881abd

Please sign in to comment.