Skip to content

Commit

Permalink
another attempt at correcting the mass action jump treatment
Browse files Browse the repository at this point in the history
  • Loading branch information
pihop committed Feb 8, 2023
1 parent 15a31f9 commit fad63ca
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 26 deletions.
55 changes: 29 additions & 26 deletions src/aggregators/extrande.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,11 @@ end
############################# Required Functions ##############################
function aggregate(aggregator::Extrande, u, p, t, end_time, constant_jumps,
ma_jumps, save_positions, rng; variable_jumps = (), kwargs...)
ma_jumps_ = !isnothing(ma_jumps) ? ma_jumps : ()
rates, affects! = get_jump_info_fwrappers(u, p, t,
(constant_jumps..., variable_jumps...,
ma_jumps_...,
NullAffectJump))
rbnds, wnds = get_va_jump_bound_info_fwrapper(u, p, t,
(constant_jumps..., variable_jumps...,
ma_jumps_...,
NullAffectJump))
build_jump_aggregation(ExtrandeJumpAggregation, u, p, t, end_time, ma_jumps,
rates, affects!, save_positions, rng; u = u, rate_bounds = rbnds,
Expand All @@ -66,45 +63,51 @@ end

@fastmath function next_extrande_jump(p::ExtrandeJumpAggregation, u, params, t)
ttnj = typemax(typeof(t))
nextrx = zero(Int)
Wmin = typemax(typeof(t))
Bmax = typemax(typeof(t))
Bmax = zero(t)

prev_rate = zero(t)
new_rate = zero(t)
cur_rates = p.cur_rates

# Mass action rates
majumps = p.ma_jumps
idx = get_num_majumps(majumps)

@inbounds for i in 1:idx
new_rate = evalrxrate(u, i, majumps)
cur_rates[i] = add_fast(new_rate, prev_rate)
prev_rate = cur_rates[i]
Bmax += prev_rate
end

# Calculate the total rate bound and the largest common validity window.
if !isempty(p.rate_bnds)
Bmax = typeof(t)(0.0)
@inbounds for i in 1:length(p.wds)
Wmin = min(Wmin, p.wds[i](u, params, t))
Bmax += p.rate_bnds[i](u, params, t)
end
end

# Rejection sampling.
if !isempty(p.rates)
nextrx = length(p.rates)
idx = 1
prop_ttnj = randexp(p.rng) / Bmax
if prop_ttnj < Wmin
nextrx = length(cur_rates)
prop_ttnj = randexp(p.rng) / Bmax
if prop_ttnj < Wmin
if !isempty(p.rates)
idx += 1
fill_cur_rates(u, params, prop_ttnj + t, p.cur_rates, idx, p.rates...)

prev_rate = zero(t)
cur_rates = p.cur_rates
@inbounds for i in idx:length(cur_rates)
cur_rates[i] = cur_rates[i] + prev_rate
cur_rates[i] = add_fast(cur_rates[i], prev_rate)
prev_rate = cur_rates[i]
end

UBmax = rand(p.rng) * Bmax
ttnj = prop_ttnj
if p.cur_rates[end] UBmax
nextrx = 1
@inbounds while p.cur_rates[nextrx] < UBmax
nextrx += 1
end
end
else
ttnj = Wmin
end
UBmax = rand(p.rng) * Bmax
ttnj = prop_ttnj
if p.cur_rates[end] UBmax
nextrx = searchsortedfirst(p.cur_rates, UBmax)
end
else
ttnj = Wmin
end

return nextrx, ttnj
Expand Down
10 changes: 10 additions & 0 deletions test/extrande.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,13 @@ ode_sol = solve(ode_prob, Tsit5())

# Test extrande against the ODE mean.
@test prod(isapprox.(means, getindex.(ode_sol(test_times).u, 1), rtol = 1e-3))

# Make sure interfaces correctly with Mass Action Jumps.
reactant_stoich = [[1 => 1]]
net_stoich = [[1 => -1]]
majd = MassActionJump(reactant_stoich, net_stoich; param_idxs = [1])
bmajd_prob = ODEProblem(f, u0, (0.0, 2pi), [0.08])
jump_bmajd_prob = JumpProblem(bmajd_prob, Extrande(), jumpb, majd)

means_mass_action = runsimulations(jump_bmajd_prob, test_times)
@test prod(isapprox.(means_mass_action, getindex.(ode_sol(test_times).u, 1), rtol = 1e-3))

0 comments on commit fad63ca

Please sign in to comment.