Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into coevolve-ii
Browse files Browse the repository at this point in the history
  • Loading branch information
gzagatti committed May 3, 2023
2 parents ebde865 + c8ec2ba commit f39ab18
Show file tree
Hide file tree
Showing 6 changed files with 124 additions and 50 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "JumpProcesses"
uuid = "ccbc3e58-028d-4f4c-8cd5-9ae44345cda5"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "9.6.2"
version = "9.6.3"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
22 changes: 18 additions & 4 deletions src/aggregators/bracketing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,27 @@ get brackets for reaction rx by first checking if the reaction is a massaction r
end
end

@inline function update_u_brackets!(p::AbstractSSAJumpAggregator, u::AbstractVector)
@unpack ulow, uhigh = p
@inbounds for (i, uval) in enumerate(u)
ulow[i], uhigh[i] = get_spec_brackets(p.bracket_data, i, uval)
end
nothing
end

@inline function update_u_brackets!(p::AbstractSSAJumpAggregator, u::SVector)
@inbounds for (i, uval) in enumerate(u)
ulow, uhigh = get_spec_brackets(p.bracket_data, i, uval)
p.ulow = setindex(p.ulow, ulow, i)
p.uhigh = setindex(p.uhigh, uhigh, i)
end
nothing
end

# set up bracketing
function set_bracketing!(p::AbstractSSAJumpAggregator, u, params, t)
# species bracketing interval
ubnds = p.cur_u_bnds
@inbounds for (i, uval) in enumerate(u)
ubnds[1, i], ubnds[2, i] = get_spec_brackets(p.bracket_data, i, uval)
end
update_u_brackets!(p, u)

# reaction rate bracketing interval
# mass action jumps
Expand Down
61 changes: 41 additions & 20 deletions src/aggregators/rssa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# functions of the current population sizes (i.e. u)
# requires vartojumps_map and fluct_rates as JumpProblem keywords

mutable struct RSSAJumpAggregation{T, T2, S, F1, F2, RNG, VJMAP, JVMAP, BD, T2V} <:
mutable struct RSSAJumpAggregation{T, S, F1, F2, RNG, VJMAP, JVMAP, BD, U} <:
AbstractSSAJumpAggregator{T, S, F1, F2, RNG}
next_jump::Int
prev_jump::Int
Expand All @@ -13,7 +13,6 @@ mutable struct RSSAJumpAggregation{T, T2, S, F1, F2, RNG, VJMAP, JVMAP, BD, T2V}
cur_rate_low::Vector{T}
cur_rate_high::Vector{T}
sum_rate::T
cur_u_bnds::Matrix{T2}
ma_jumps::S
rates::F1
affects!::F2
Expand All @@ -22,8 +21,8 @@ mutable struct RSSAJumpAggregation{T, T2, S, F1, F2, RNG, VJMAP, JVMAP, BD, T2V}
vartojumps_map::VJMAP
jumptovars_map::JVMAP
bracket_data::BD
ulow::T2V
uhigh::T2V
ulow::U
uhigh::U
end

function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T,
Expand Down Expand Up @@ -59,20 +58,16 @@ function RSSAJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::T,
# a bracket data structure is needed for updating species populations
bd = (bracket_data === nothing) ? BracketData{T, eltype(U)}() : bracket_data

# matrix to store bracketing interval for species and the relative interval width
# first row is Xlow, second is Xhigh
cs_bnds = Matrix{eltype(U)}(undef, 2, length(u))
ulow = @view cs_bnds[1, :]
uhigh = @view cs_bnds[2, :]
# current bounds on solution
ulow = similar(u)
uhigh = similar(u)

affecttype = F2 <: Tuple ? F2 : Any
RSSAJumpAggregation{T, eltype(U), S, F1, affecttype, RNG, typeof(vtoj_map),
typeof(jtov_map), typeof(bd), typeof(ulow)}(nj, nj, njt, et,
crl_bnds, crh_bnds, sr,
cs_bnds, maj, rs,
affs!, sps, rng,
vtoj_map, jtov_map, bd,
ulow, uhigh)
RSSAJumpAggregation{T, S, F1, affecttype, RNG, typeof(vtoj_map),
typeof(jtov_map), typeof(bd), U}(nj, nj, njt, et, crl_bnds,
crh_bnds, sr, maj, rs, affs!, sps,
rng, vtoj_map, jtov_map, bd, ulow,
uhigh)
end

############################# Required Functions ##############################
Expand Down Expand Up @@ -147,19 +142,45 @@ end
"""
Update rates
"""
@inline function update_rates!(p::RSSAJumpAggregation, u, params, t)
@inline function update_rates!(p::RSSAJumpAggregation, u::AbstractVector, params, t)
# update bracketing intervals
ubnds = p.cur_u_bnds
@unpack ulow, uhigh = p
sum_rate = p.sum_rate
crhigh = p.cur_rate_high

@inbounds for uidx in p.jumptovars_map[p.next_jump]
uval = u[uidx]

# if new u value is outside the bracketing interval
if uval == zero(uval) || uval < ubnds[1, uidx] || uval > ubnds[2, uidx]
if uval == zero(uval) || uval < ulow[uidx] || uval > uhigh[uidx]
# update u bracketing interval
ubnds[1, uidx], ubnds[2, uidx] = get_spec_brackets(p.bracket_data, uidx, uval)
ulow[uidx], uhigh[uidx] = get_spec_brackets(p.bracket_data, uidx, uval)

# for each dependent jump, update jump rate brackets
for jidx in p.vartojumps_map[uidx]
sum_rate -= crhigh[jidx]
p.cur_rate_low[jidx], crhigh[jidx] = get_jump_brackets(jidx, p, params, t)
sum_rate += crhigh[jidx]
end
end
end
p.sum_rate = sum_rate
end

@inline function update_rates!(p::RSSAJumpAggregation, u::SVector, params, t)
# update bracketing intervals
sum_rate = p.sum_rate
crhigh = p.cur_rate_high

@inbounds for uidx in p.jumptovars_map[p.next_jump]
uval = u[uidx]

# if new u value is outside the bracketing interval
if uval == zero(uval) || uval < p.ulow[uidx] || uval > p.uhigh[uidx]
# update u bracketing interval
ulow, uhigh = get_spec_brackets(p.bracket_data, uidx, uval)
p.ulow = setindex(p.ulow, ulow, uidx)
p.uhigh = setindex(p.uhigh, uhigh, uidx)

# for each dependent jump, update jump rate brackets
for jidx in p.vartojumps_map[uidx]
Expand Down
55 changes: 41 additions & 14 deletions src/aggregators/rssacr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Composition-Rejection with Rejection sampling method (RSSA-CR)

const MINJUMPRATE = 2.0^exponent(1e-12)

mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, T2V,
mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD,
P <: PriorityTable, W <: Function} <:
AbstractSSAJumpAggregator{F, S, F1, F2, RNG}
next_jump::Int
Expand All @@ -22,13 +22,12 @@ mutable struct RSSACRJumpAggregation{F, S, F1, F2, RNG, U, VJMAP, JVMAP, BD, T2V
vartojumps_map::VJMAP
jumptovars_map::JVMAP
bracket_data::BD
ulow::T2V
uhigh::T2V
ulow::U
uhigh::U
minrate::F
maxrate::F # initial maxrate only, table can increase beyond it!
rt::P #rate table
ratetogroup::W
cur_u_bnds::Matrix{U} # current bounds on state u
end

function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate::F, maj::S,
Expand Down Expand Up @@ -67,9 +66,8 @@ function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate:

# matrix to store bracketing interval for species and the relative interval width
# first row is Xlow, second is Xhigh
cs_bnds = Matrix{eltype(U)}(undef, 2, length(u))
ulow = @view cs_bnds[1, :]
uhigh = @view cs_bnds[2, :]
ulow = similar(u)
uhigh = similar(u)

# mapping from jump rate to group id
minexponent = exponent(minrate)
Expand All @@ -82,12 +80,12 @@ function RSSACRJumpAggregation(nj::Int, njt::F, et::F, crs::Vector{F}, sum_rate:
rt = PriorityTable(ratetogroup, zeros(F, 1), minrate, 2 * minrate)

affecttype = F2 <: Tuple ? F2 : Any
RSSACRJumpAggregation{typeof(njt), S, F1, affecttype, RNG, eltype(U), typeof(vtoj_map),
typeof(jtov_map), typeof(bd), typeof(ulow), typeof(rt),
RSSACRJumpAggregation{typeof(njt), S, F1, affecttype, RNG, U, typeof(vtoj_map),
typeof(jtov_map), typeof(bd), typeof(rt),
typeof(ratetogroup)}(nj, nj, njt, et, crl_bnds, crh_bnds,
sum_rate, maj, rs, affs!, sps, rng, vtoj_map,
jtov_map, bd, ulow, uhigh, minrate, maxrate,
rt, ratetogroup, cs_bnds)
rt, ratetogroup)
end

############################# Required Functions ##############################
Expand Down Expand Up @@ -163,17 +161,46 @@ end
"""
update bracketing for species that depend on the just executed jump
"""
@inline function update_dependent_rates!(p::RSSACRJumpAggregation, u, params, t)
@inline function update_dependent_rates!(p::RSSACRJumpAggregation, u::AbstractVector,
params, t)
# update bracketing intervals
ubnds = p.cur_u_bnds
@unpack ulow, uhigh = p
crhigh = p.cur_rate_high

@inbounds for uidx in p.jumptovars_map[p.next_jump]
uval = u[uidx]
# if new u value is outside the bracketing interval
if uval == zero(uval) || uval < ubnds[1, uidx] || uval > ubnds[2, uidx]
if uval == zero(uval) || uval < ulow[uidx] || uval > uhigh[uidx]
# update u bracketing interval
ubnds[1, uidx], ubnds[2, uidx] = get_spec_brackets(p.bracket_data, uidx, uval)
ulow[uidx], uhigh[uidx] = get_spec_brackets(p.bracket_data, uidx, uval)

# for each dependent jump, update jump rate brackets
for jidx in p.vartojumps_map[uidx]
oldrate = crhigh[jidx]
p.cur_rate_low[jidx], crhigh[jidx] = get_jump_brackets(jidx, p, params, t)

# update the priority table
update!(p.rt, jidx, oldrate, crhigh[jidx])
end
end
end

p.sum_rate = groupsum(p.rt)
nothing
end

@inline function update_dependent_rates!(p::RSSACRJumpAggregation, u::SVector, params, t)
# update bracketing intervals
crhigh = p.cur_rate_high

@inbounds for uidx in p.jumptovars_map[p.next_jump]
uval = u[uidx]
# if new u value is outside the bracketing interval
if uval == zero(uval) || uval < p.ulow[uidx] || uval > p.uhigh[uidx]
# update u bracketing interval
ulow, uhigh = get_spec_brackets(p.bracket_data, uidx, uval)
p.ulow = setindex(p.ulow, ulow, uidx)
p.uhigh = setindex(p.uhigh, uhigh, uidx)

# for each dependent jump, update jump rate brackets
for jidx in p.vartojumps_map[uidx]
Expand Down
32 changes: 22 additions & 10 deletions test/allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ let
return/ K) * (K - (X + Y))
end

function makeprob(; T = 100.0, alg = Direct(), save_positions = (false, false))
function makeprob(; T = 100.0, alg = Direct(), save_positions = (false, false),
graphkwargs = (;))
r1(u, p, t) = rate(p[1], u[1], u[2], p[2]) * u[1]
r2(u, p, t) = rate(p[1], u[2], u[1], p[2]) * u[2]
r3(u, p, t) = p[3] * u[1]
Expand Down Expand Up @@ -83,17 +84,28 @@ let
ConstantRateJump(r3, aff3!),
ConstantRateJump(r4, aff4!), ConstantRateJump(r5, aff5!),
ConstantRateJump(r6, aff6!);
save_positions)
save_positions, graphkwargs...)
return jprob
end

jprob1 = makeprob(; T = 10.0)
jprob2 = makeprob(; T = 100.0)
stepper = SSAStepper()
sol1 = solve(jprob1, stepper)
al1 = @allocated solve(jprob1, stepper)
sol2 = solve(jprob2, SSAStepper())
al2 = @allocated solve(jprob2, stepper)
idxs1 = [1, 2, 3, 4]
idxs2 = [1, 2, 4, 5, 6]
idxs = collect(1:6)
dep_graph = [copy(idxs1), copy(idxs2), copy(idxs1), copy(idxs2), copy(idxs), copy(idxs)]
vartojumps_map = [copy(idxs1), copy(idxs2)]
jumptovars_map = [[1], [2], [1], [2], [1, 2], [1, 2]]
graphkwargs = (; dep_graph, vartojumps_map, jumptovars_map)

@test al1 == al2
@testset "Allocations for $agg" for agg in JumpProcesses.JUMP_AGGREGATORS
jprob1 = makeprob(; alg = agg, T = 10.0, graphkwargs)
jprob2 = makeprob(; alg = agg, T = 100.0, graphkwargs)
stepper = SSAStepper()
sol1 = solve(jprob1, stepper)
al1 = @allocated solve(jprob1, stepper)
sol2 = solve(jprob2, SSAStepper())
al2 = @allocated solve(jprob2, stepper)
@test al1 == al2
end
end

nothing
2 changes: 1 addition & 1 deletion test/extinction_test.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using DiffEqBase, JumpProcesses, StaticArrays
using JumpProcesses, StaticArrays
using Test
using StableRNGs
rng = StableRNG(12345)
Expand Down

0 comments on commit f39ab18

Please sign in to comment.