Skip to content

Commit

Permalink
Merge e85d1e8 into fd47a82
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsas committed Apr 19, 2023
2 parents fd47a82 + e85d1e8 commit ef65cc2
Show file tree
Hide file tree
Showing 14 changed files with 172 additions and 116 deletions.
17 changes: 9 additions & 8 deletions src/aggregators/coevolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Queue method. This method handles variable intensity rates.
"""
mutable struct CoevolveJumpAggregation{T, S, F1, F2, RNG, GR, PQ} <:
AbstractSSAJumpAggregator
AbstractSSAJumpAggregator{T, S, F1, F2, RNG}
next_jump::Int # the next jump to execute
prev_jump::Int # the previous jump that was executed
next_jump_time::T # the time of the next jump
Expand Down Expand Up @@ -46,7 +46,8 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not
end

pq = MutableBinaryMinHeap{T}()
CoevolveJumpAggregation{T, S, F1, F2, RNG, typeof(dg),
affecttype = F2 <: Tuple ? F2 : Any
CoevolveJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg),
typeof(pq)}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps, rng,
dg, pq, lrates, urates, rateintervals, haslratevec)
end
Expand All @@ -55,14 +56,13 @@ end
function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps,
ma_jumps, save_positions, rng; dep_graph = nothing,
variable_jumps = nothing, kwargs...)
AffectWrapper = FunctionWrappers.FunctionWrapper{Nothing, Tuple{Any}}
RateWrapper = FunctionWrappers.FunctionWrapper{typeof(t),
Tuple{typeof(u), typeof(p), typeof(t)}}

ncrjs = (constant_jumps === nothing) ? 0 : length(constant_jumps)
nvrjs = (variable_jumps === nothing) ? 0 : length(variable_jumps)
nrjs = ncrjs + nvrjs
affects! = Vector{AffectWrapper}(undef, nrjs)
affects! = Vector{Any}(undef, nrjs)
rates = Vector{RateWrapper}(undef, nvrjs)
lrates = similar(rates)
rateintervals = similar(rates)
Expand All @@ -72,15 +72,15 @@ function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps,
idx = 1
if constant_jumps !== nothing
for crj in constant_jumps
affects![idx] = AffectWrapper(integ -> (crj.affect!(integ); nothing))
affects![idx] = integ -> (crj.affect!(integ); nothing)
urates[idx] = RateWrapper(crj.rate)
idx += 1
end
end

if variable_jumps !== nothing
for (i, vrj) in enumerate(variable_jumps)
affects![idx] = AffectWrapper(integ -> (vrj.affect!(integ); nothing))
affects![idx] = integ -> (vrj.affect!(integ); nothing)
urates[idx] = RateWrapper(vrj.urate)
idx += 1
rates[i] = RateWrapper(vrj.rate)
Expand Down Expand Up @@ -109,9 +109,10 @@ function initialize!(p::CoevolveJumpAggregation, integrator, u, params, t)
end

# execute one jump, changing the system state
function execute_jumps!(p::CoevolveJumpAggregation, integrator, u, params, t)
function execute_jumps!(p::CoevolveJumpAggregation, integrator, u, params, t, affects!)
# execute jump
u = update_state!(p, integrator, u)
u = update_state!(p, integrator, u, affects!)

# update current jump rates and times
update_dependent_rates!(p, u, params, t)
nothing
Expand Down
42 changes: 11 additions & 31 deletions src/aggregators/direct.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mutable struct DirectJumpAggregation{T, S, F1, F2, RNG} <: AbstractSSAJumpAggregator
mutable struct DirectJumpAggregation{T, S, F1, F2, RNG} <:
AbstractSSAJumpAggregator{T, S, F1, F2, RNG}
next_jump::Int
prev_jump::Int
next_jump_time::T
Expand All @@ -14,8 +15,9 @@ end
function DirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S,
rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG;
kwargs...) where {T, S, F1, F2, RNG}
DirectJumpAggregation{T, S, F1, F2, RNG}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps,
rng)
affecttype = F2 <: Tuple ? F2 : Any
DirectJumpAggregation{T, S, F1, affecttype, RNG}(nj, nj, njt, et, crs, sr, maj, rs,
affs!, sps, rng)
end

############################# Required Functions #############################
Expand Down Expand Up @@ -50,8 +52,8 @@ function initialize!(p::DirectJumpAggregation, integrator, u, params, t)
end

# execute one jump, changing the system state
@inline function execute_jumps!(p::DirectJumpAggregation, integrator, u, params, t)
update_state!(p, integrator, u)
@inline function execute_jumps!(p::DirectJumpAggregation, integrator, u, params, t, affects!)
update_state!(p, integrator, u, affects!)
nothing
end

Expand All @@ -66,8 +68,8 @@ end
######################## SSA specific helper routines ########################

# tuple-based constant jumps
function time_to_next_jump(p::DirectJumpAggregation{T, S, F1, F2, RNG}, u, params,
t) where {T, S, F1 <: Tuple, F2, RNG}
function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params,
t) where {T, S, F1 <: Tuple}
prev_rate = zero(t)
new_rate = zero(t)
cur_rates = p.cur_rates
Expand Down Expand Up @@ -108,8 +110,8 @@ end
end

# function wrapper-based constant jumps
function time_to_next_jump(p::DirectJumpAggregation{T, S, F1, F2, RNG}, u, params,
t) where {T, S, F1 <: AbstractArray, F2, RNG}
function time_to_next_jump(p::DirectJumpAggregation{T, S, F1}, u, params,
t) where {T, S, F1 <: AbstractArray}
prev_rate = zero(t)
new_rate = zero(t)
cur_rates = p.cur_rates
Expand All @@ -136,25 +138,3 @@ function time_to_next_jump(p::DirectJumpAggregation{T, S, F1, F2, RNG}, u, param
@inbounds sum_rate = cur_rates[end]
sum_rate, randexp(p.rng) / sum_rate
end

@generated function update_state!(p::DirectJumpAggregation{T, S, F1, F2}, integrator,
u) where {T, S, F1 <: Tuple, F2 <: Tuple}
quote
@unpack ma_jumps, next_jump = p
num_ma_rates = get_num_majumps(ma_jumps)
if next_jump <= num_ma_rates # is next jump a mass action jump
if u isa SVector
integrator.u = executerx(u, next_jump, ma_jumps)
else
@inbounds executerx!(u, next_jump, ma_jumps)
end
else
idx = next_jump - num_ma_rates
Base.Cartesian.@nif $(fieldcount(F2)) i->(i == idx) i->(@inbounds p.affects![i](integrator)) i->(@inbounds p.affects![fieldcount(F2)](integrator))
end

# save jump that was just executed
p.prev_jump = next_jump
return integrator.u
end
end
10 changes: 6 additions & 4 deletions src/aggregators/directcr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ by S. Mauch and M. Stalzer, ACM Trans. Comp. Biol. and Bioinf., 8, No. 1, 27-35
const MINJUMPRATE = 2.0^exponent(1e-12)

mutable struct DirectCRJumpAggregation{T, S, F1, F2, RNG, DEPGR, U <: PriorityTable,
W <: Function} <: AbstractSSAJumpAggregator
W <: Function} <:
AbstractSSAJumpAggregator{T, S, F1, F2, RNG}
next_jump::Int
prev_jump::Int
next_jump_time::T
Expand Down Expand Up @@ -61,7 +62,8 @@ function DirectCRJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T,
# construct an empty initial priority table -- we'll reset this in init
rt = PriorityTable(ratetogroup, zeros(T, 1), minrate, 2 * minrate)

DirectCRJumpAggregation{T, S, F1, F2, RNG, typeof(dg),
affecttype = F2 <: Tuple ? F2 : Any
DirectCRJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg),
typeof(rt), typeof(ratetogroup)}(nj, nj, njt, et, crs, sr, maj,
rs, affs!, sps, rng, dg,
minrate, maxrate, rt,
Expand Down Expand Up @@ -100,9 +102,9 @@ function initialize!(p::DirectCRJumpAggregation, integrator, u, params, t)
end

# execute one jump, changing the system state
function execute_jumps!(p::DirectCRJumpAggregation, integrator, u, params, t)
function execute_jumps!(p::DirectCRJumpAggregation, integrator, u, params, t, affects!)
# execute jump
u = update_state!(p, integrator, u)
u = update_state!(p, integrator, u, affects!)

# update current jump rates
update_dependent_rates!(p, u, params, t)
Expand Down
17 changes: 9 additions & 8 deletions src/aggregators/frm.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mutable struct FRMJumpAggregation{T, S, F1, F2, RNG} <: AbstractSSAJumpAggregator
mutable struct FRMJumpAggregation{T, S, F1, F2, RNG} <:
AbstractSSAJumpAggregator{T, S, F1, F2, RNG}
next_jump::Int
prev_jump::Int
next_jump_time::T
Expand All @@ -14,8 +15,9 @@ end
function FRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, maj::S, rs::F1,
affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG;
kwargs...) where {T, S, F1, F2, RNG}
FRMJumpAggregation{T, S, F1, F2, RNG}(nj, nj, njt, et, crs, sr, maj, rs, affs!, sps,
rng)
affecttype = F2 <: Tuple ? F2 : Any
FRMJumpAggregation{T, S, F1, affecttype, RNG}(nj, nj, njt, et, crs, sr, maj, rs,
affs!, sps, rng)
end

############################# Required Functions #############################
Expand Down Expand Up @@ -50,9 +52,9 @@ function initialize!(p::FRMJumpAggregation, integrator, u, params, t)
end

# execute one jump, changing the system state
@inline function execute_jumps!(p::FRMJumpAggregation, integrator, u, params, t)
@inline function execute_jumps!(p::FRMJumpAggregation, integrator, u, params, t, affects!)
# execute jump
update_state!(p, integrator, u)
update_state!(p, integrator, u, affects!)
nothing
end

Expand Down Expand Up @@ -110,9 +112,8 @@ function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2, RNG}, u, pa
end

# function wrapper-based constant jumps
function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1, F2, RNG}, u, params,
t) where {T, S, F1 <: AbstractArray, F2 <: AbstractArray,
RNG}
function next_constant_rate_jump(p::FRMJumpAggregation{T, S, F1}, u, params,
t) where {T, S, F1 <: AbstractArray}
ttnj = typemax(typeof(t))
nextrx = zero(Int)
if !isempty(p.rates)
Expand Down
16 changes: 9 additions & 7 deletions src/aggregators/nrm.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Implementation the original Next Reaction Method
# Gibson and Bruck, J. Phys. Chem. A, 104 (9), (2000)

mutable struct NRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PQ} <: AbstractSSAJumpAggregator
mutable struct NRMJumpAggregation{T, S, F1, F2, RNG, DEPGR, PQ} <:
AbstractSSAJumpAggregator{T, S, F1, F2, RNG}
next_jump::Int
prev_jump::Int
next_jump_time::T
Expand Down Expand Up @@ -38,10 +39,11 @@ function NRMJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T,

pq = MutableBinaryMinHeap{T}()

NRMJumpAggregation{T, S, F1, F2, RNG, typeof(dg), typeof(pq)}(nj, nj, njt, et, crs, sr,
maj,
rs, affs!, sps, rng, dg,
pq)
affecttype = F2 <: Tuple ? F2 : Any
NRMJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg), typeof(pq)}(nj, nj, njt, et,
crs, sr, maj,
rs, affs!, sps,
rng, dg, pq)
end

+############################# Required Functions ##############################
Expand All @@ -66,9 +68,9 @@ function initialize!(p::NRMJumpAggregation, integrator, u, params, t)
end

# execute one jump, changing the system state
function execute_jumps!(p::NRMJumpAggregation, integrator, u, params, t)
function execute_jumps!(p::NRMJumpAggregation, integrator, u, params, t, affects!)
# execute jump
u = update_state!(p, integrator, u)
u = update_state!(p, integrator, u, affects!)

# update current jump rates and times
update_dependent_rates!(p, u, params, t)
Expand Down
17 changes: 10 additions & 7 deletions src/aggregators/rdirect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
Direct with rejection sampling
"""

mutable struct RDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <: AbstractSSAJumpAggregator
mutable struct RDirectJumpAggregation{T, S, F1, F2, RNG, DEPGR} <:
AbstractSSAJumpAggregator{T, S, F1, F2, RNG}
next_jump::Int
prev_jump::Int
next_jump_time::T
Expand Down Expand Up @@ -40,10 +41,12 @@ function RDirectJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T, m
end

max_rate = maximum(crs)
return RDirectJumpAggregation{T, S, F1, F2, RNG, typeof(dg)}(nj, nj, njt, et, crs, sr,
maj, rs, affs!, sps, rng,
dg, max_rate, 0,
counter_threshold)
affecttype = F2 <: Tuple ? F2 : Any
return RDirectJumpAggregation{T, S, F1, affecttype, RNG, typeof(dg)}(nj, nj, njt, et,
crs, sr, maj, rs,
affs!, sps, rng,
dg, max_rate, 0,
counter_threshold)
end

############################# Required Functions #############################
Expand Down Expand Up @@ -72,9 +75,9 @@ end
"""
execute one jump, changing the system state and updating rates
"""
function execute_jumps!(p::RDirectJumpAggregation, integrator, u, params, t)
function execute_jumps!(p::RDirectJumpAggregation, integrator, u, params, t, affects!)
# execute jump
u = update_state!(p, integrator, u)
u = update_state!(p, integrator, u, affects!)

# update rates
update_dependent_rates!(p, u, params, t)
Expand Down
19 changes: 11 additions & 8 deletions src/aggregators/rssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# requires vartojumps_map and fluct_rates as JumpProblem keywords

mutable struct RSSAJumpAggregation{T, T2, S, F1, F2, RNG, VJMAP, JVMAP, BD, T2V} <:
AbstractSSAJumpAggregator
AbstractSSAJumpAggregator{T, S, F1, F2, RNG}
next_jump::Int
prev_jump::Int
next_jump_time::T
Expand Down Expand Up @@ -65,11 +65,14 @@ function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T,
ulow = @view cs_bnds[1, :]
uhigh = @view cs_bnds[2, :]

RSSAJumpAggregation{T, eltype(U), S, F1, F2, RNG, typeof(vtoj_map), typeof(jtov_map),
typeof(bd), typeof(ulow)}(nj, nj, njt, et, crl_bnds, crh_bnds, sr,
cs_bnds, maj, rs,
affs!, sps, rng, vtoj_map, jtov_map, bd,
ulow, uhigh)
affecttype = F2 <: Tuple ? F2 : Any
RSSAJumpAggregation{T, eltype(U), S, F1, affecttype, RNG, typeof(vtoj_map),
typeof(jtov_map), typeof(bd), typeof(ulow)}(nj, nj, njt, et,
crl_bnds, crh_bnds, sr,
cs_bnds, maj, rs,
affs!, sps, rng,
vtoj_map, jtov_map, bd,
ulow, uhigh)
end

############################# Required Functions ##############################
Expand All @@ -95,9 +98,9 @@ function initialize!(p::RSSAJumpAggregation, integrator, u, params, t)
end

# execute one jump, changing the system state
function execute_jumps!(p::RSSAJumpAggregation, integrator, u, params, t)
function execute_jumps!(p::RSSAJumpAggregation, integrator, u, params, t, affects!)
# execute jump
u = update_state!(p, integrator, u)
u = update_state!(p, integrator, u, affects!)
update_rates!(p, u, params, t)
nothing
end
Expand Down
14 changes: 7 additions & 7 deletions src/aggregators/rssacr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ Composition-Rejection with Rejection sampling method (RSSA-CR)

const MINJUMPRATE = 2.0^exponent(1e-12)

mutable struct RSSACRJumpAggregation{F, U, S, F1, F2, RNG, VJMAP, JVMAP, BD, T2V,
mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, T2V,
P <: PriorityTable, W <: Function} <:
AbstractSSAJumpAggregator
AbstractSSAJumpAggregator{F, S, F1, F2, RNG}
next_jump::Int
prev_jump::Int
next_jump_time::F
Expand All @@ -32,8 +32,7 @@ mutable struct RSSACRJumpAggregation{F, U, S, F1, F2, RNG, VJMAP, JVMAP, BD, T2V
end

function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate::F, maj::S,
rs::F1,
affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; u::U,
rs::F1, affs!::F2, sps::Tuple{Bool, Bool}, rng::RNG; u::U,
vartojumps_map = nothing, jumptovars_map = nothing,
bracket_data = nothing, minrate = convert(F, MINJUMPRATE),
maxrate = convert(F, Inf),
Expand Down Expand Up @@ -82,7 +81,8 @@ function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate:
# construct an empty initial priority table -- we'll reset this in init
rt = PriorityTable(ratetogroup, zeros(F, 1), minrate, 2 * minrate)

RSSACRJumpAggregation{typeof(njt), eltype(U), S, F1, F2, RNG, typeof(vtoj_map),
affecttype = F2 <: Tuple ? F2 : Any
RSSACRJumpAggregation{typeof(njt), S, F1, affecttype, RNG, eltype(U), typeof(vtoj_map),
typeof(jtov_map), typeof(bd), typeof(ulow), typeof(rt),
typeof(ratetogroup)}(nj, nj, njt, et, crl_bnds, crh_bnds,
sum_rate, maj, rs, affs!, sps, rng, vtoj_map,
Expand Down Expand Up @@ -119,9 +119,9 @@ function initialize!(p::RSSACRJumpAggregation, integrator, u, params, t)
end

# execute one jump, changing the system state
function execute_jumps!(p::RSSACRJumpAggregation, integrator, u, params, t)
function execute_jumps!(p::RSSACRJumpAggregation, integrator, u, params, t, affects!)
# execute jump
u = update_state!(p, integrator, u)
u = update_state!(p, integrator, u, affects!)

# update rates
update_dependent_rates!(p, u, params, t)
Expand Down
Loading

0 comments on commit ef65cc2

Please sign in to comment.