Skip to content

Commit

Permalink
Merge 446444e into c8ec2ba
Browse files Browse the repository at this point in the history
  • Loading branch information
gzagatti committed Apr 27, 2023
2 parents c8ec2ba + 446444e commit a9f851e
Show file tree
Hide file tree
Showing 15 changed files with 523 additions and 73 deletions.
92 changes: 76 additions & 16 deletions docs/src/tutorials/discrete_stochastic_example.md
Original file line number Diff line number Diff line change
Expand Up @@ -596,12 +596,13 @@ jump4 = ConstantRateJump(rate4, affect4!)
With the jumps defined, we can build a
[`DiscreteProblem`](https://docs.sciml.ai/DiffEqDocs/stable/types/discrete_types/).
Bounded `VariableRateJump`s over a `DiscreteProblem` can currently only be
simulated with the `Coevolve` aggregator. The aggregator requires a dependency
graph to indicate when a given jump occurs which other jumps in the system
should have their rate recalculated (i.e., their rate depends on states modified
by one occurrence of the first jump). This ensures that rates, rate bounds, and
rate intervals are recalculated when invalidated due to changes in `u`. For the
current example, both processes mutually affect each other, so we have
simulated with the `Coevolve` or `CoevolveSynced` aggregators. Both aggregators
requires a dependency graph to indicate when a given jump occurs which other
jumps in the system should have their rate recalculated (i.e., their rate
depends on states modified by one occurrence of the first jump). This ensures
that rates, rate bounds, and rate intervals are recalculated when invalidated
due to changes in `u`. For the current example, both processes mutually affect
each other, so we have

```@example tut2
dep_graph = [[1, 2], [1, 2]]
Expand All @@ -628,11 +629,11 @@ We see that the time-dependent infection rate leads to a lower peak of the
infection throughout the population.

Note that bounded `VariableRateJump`s over `DiscreteProblem`s can be quite
general, but it is not possible to handle rates that change according to an
ODE/SDE modified variable. A rate such as `p[2]*u[1]*u[4]` when `u[4]` is the
solution of a continuous problem such as an ODE or SDE can only be handled using
a general `VariableRateJump` within a continuous integrator as discussed
[below](@ref VariableRateJumpSect).
general. However, when handling rates that change according to an ODE/SDE
modified variable we will need a continuous integrator as discussed
[below](@ref VariableRateJumpSect). One example of such a rate is
`p[2]*u[1]*u[4]` when `u[4]` is the solution of a continuous problem such as an
ODE or SDE.

## [Reducing Memory Use: Controlling Saving Behavior](@id save_positions_docs)

Expand Down Expand Up @@ -701,7 +702,7 @@ plot(sol; label = ["S(t)" "I(t)" "R(t)"])
```

Note that we can combine `MassActionJump`s, `ConstantRateJump`s and bounded
`VariableRateJump`s using the `Coevolve` aggregator.
`VariableRateJump`s using the `Coevolve` or `CoevolveSynced` aggregators.

## Adding Jumps to a Differential Equation

Expand All @@ -713,7 +714,7 @@ only acts on some new 4th component:
```@example tut2
using OrdinaryDiffEq
function f(du, u, p, t)
du[4] = u[2] * u[3] / 100000 - u[1] * u[4] / 100000
du[4] = u[2] * u[3] / 1e5 - u[1] * u[4] / 1e5
nothing
end
u₀ = [999.0, 10.0, 0.0, 100.0]
Expand Down Expand Up @@ -758,10 +759,10 @@ jump5 = VariableRateJump(rate5, affect5!)
```

Notice, since `rate5` depends on a variable that evolves continuously, and hence
is not constant between jumps, *we must use a general `VariableRateJump` without
upper/lower bounds*.
is not constant between jumps, *we must either use a general `VariableRateJump` without
upper/lower bounds or a bounded `VariableRateJump`*.

Solving the equation is exactly the same:
In the general case, solving the equation is exactly the same:

```@example tut2
u₀ = [999.0, 10.0, 0.0, 1.0]
Expand All @@ -774,6 +775,65 @@ plot(sol; label = ["S(t)" "I(t)" "R(t)" "u₄(t)"])
*Note that general `VariableRateJump`s require using a continuous problem, like
an ODE/SDE/DDE/DAE problem, and using floating point initial conditions.*

Alternatively, the case of bounded `VariableRateJump` requires some maths.
First, we need to obtain the upper bounds of `rate5` at time `t` given `u`.
Note that `rate5` evolves according to `u[4]` which is a separable first order
differential equation of the form ``x' = b - a x`` with general solution:

```math
x(t) = - \frac{e^{-a t - c_1 a}}{a} + \frac{b}{a}
```

This is bounded by ``b / a`` which is too high for our purposes since it would
lead to a high rate of rejection during sampling. However, since the function
is increasing we can compute the upper bound given an interval ``\Delta t``
as following:

```math
\bar{x}(s) = x(t) \, e^{-a (t + \Delta t)} + \frac{b}{a} (1 - e^{- a (t + \Delta t)}) \text{ , } \forall s \in [t, t + \Delta t]
```

However, when ``a = 0`` the differential equation becomes ``x' = b`` whose solution is ``x(t) = b t``. In which case, we obtain a different upper bound given by:

```math
\bar{x}(s) = x(t) + b * (t + \Delta t) \text{ , } \forall s \in [t, t + \Delta t]
```

These expressions allow us to write the upper-bound and the rate interval in Julia.

```@example tut2
function urate2(u, p, t)
if u[1] > 0
1e-2 * max(u[4],
(u[4] * exp(-1 * u[1] / 1e5) +
(u[2] * u[3] / u[1]) * (1 - exp(-1 * u[1] / 1e5))))
else
1e-2 * (u[4] + 1 * u[2] * u[3] / 1e5)
end
end
rateinterval2(u, p, t) = 1
```

We can then formulate the jump problem. The only aggregator that supports
bounded `VariableRateJump`s is `CoevolveSynced`. We formulate and solve the
jump problem with this aggregator. `CoevolveSynced` can be formulated as either
a discrete or continuous problem. In this case, we must formulate the problem
as continuous as it depends on a continuous variable.

```@example tut2
jump6 = VariableRateJump(rate5, affect5!; urate = urate2, rateinterval = rateinterval2)
dep_graph2 = [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
jump_prob = JumpProblem(prob, CoevolveSynced(), jump, jump2, jump6; dep_graph = dep_graph2)
sol = solve(jump_prob, Tsit5())
plot(sol; label = ["S(t)" "I(t)" "R(t)" "u₄(t)"])
```

We obtain the same solution as with `Direct`, but `CoevolveSynced` runs faster
because it doesn't need to compute the derivative of `rate5`. Each aggregator
faces a different trade-off, so the the choice of best aggregator will depend
on the problem at hand. `CoevolveSynced` requires a good understanding of the
equations involved, passing a wrong boundary can result in silent bugs.

Lastly, we are not restricted to ODEs. For example, we can solve the same jump
problem except with multiplicative noise on `u[4]` by using an `SDEProblem`
instead:
Expand Down
5 changes: 4 additions & 1 deletion src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ abstract type AbstractJump end
abstract type AbstractMassActionJump <: AbstractJump end
abstract type AbstractAggregatorAlgorithm end
abstract type AbstractJumpAggregator end
abstract type AbstractSSAIntegrator{Alg, IIP, U, T} <:
DiffEqBase.DEIntegrator{Alg, IIP, U, T} end

import Base.Threads
@static if VERSION < v"1.3"
Expand Down Expand Up @@ -51,6 +53,7 @@ include("aggregators/directcr.jl")
include("aggregators/rssacr.jl")
include("aggregators/rdirect.jl")
include("aggregators/coevolve.jl")
include("aggregators/coevolvesynced.jl")

# spatial:
include("spatial/spatial_massaction_jump.jl")
Expand Down Expand Up @@ -84,7 +87,7 @@ export Direct, DirectFW, SortingDirect, DirectCR
export BracketData, RSSA
export FRM, FRMFW, NRM
export RSSACR, RDirect
export Coevolve
export Coevolve, CoevolveSynced

export get_num_majumps, needs_depgraph, needs_vartojumps_map

Expand Down
7 changes: 6 additions & 1 deletion src/SSA_stepper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ Solution objects for pure jump problems solved via `SSAStepper`.
$(FIELDS)
"""
mutable struct SSAIntegrator{F, uType, tType, tdirType, P, S, CB, SA, OPT, TS} <:
DiffEqBase.DEIntegrator{SSAStepper, Nothing, uType, tType}
AbstractSSAIntegrator{SSAStepper, Nothing, uType, tType}
"""The underlying `prob.f` function. Not currently used."""
f::F
"""The current solution values."""
Expand Down Expand Up @@ -258,10 +258,15 @@ function DiffEqBase.step!(integrator::SSAIntegrator)
# FP error means the new time may equal the old if the next jump time is
# sufficiently small, hence we add this check to execute jumps until
# this is no longer true.
integrator.u_modified = true
while integrator.t == integrator.tstop
doaffect && integrator.cb.affect!(integrator)
end

if !integrator.u_modified
return nothing
end

if !(typeof(integrator.opts.callback.discrete_callbacks) <: Tuple{})
discrete_modified, saved_in_cb = DiffEqBase.apply_discrete_callback!(integrator,
integrator.opts.callback.discrete_callbacks...)
Expand Down
13 changes: 12 additions & 1 deletion src/aggregators/aggregators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,15 @@ evolution, Journal of Machine Learning Research 18(1), 1305–1353 (2017). doi:
"""
struct Coevolve <: AbstractAggregatorAlgorithm end

"""
A modification of the COEVOLVE algorithm for simulating any compound jump
process that evolves through time. As opposed to `Coevolve`, this method
syncs the thinning procedure with the stepper which allows it to handle
dependencies on continuous dynamics. It reduces to NRM when rates are
constant.
"""
struct CoevolveSynced <: AbstractAggregatorAlgorithm end

# spatial methods

"""
Expand All @@ -158,7 +167,7 @@ algorithm with optimal binning, Journal of Chemical Physics 143, 074108
struct DirectCRDirect <: AbstractAggregatorAlgorithm end

const JUMP_AGGREGATORS = (Direct(), DirectFW(), DirectCR(), SortingDirect(), RSSA(), FRM(),
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve())
FRMFW(), NRM(), RSSACR(), RDirect(), Coevolve(), CoevolveSynced())

# For JumpProblem construction without an aggregator
struct NullAggregator <: AbstractAggregatorAlgorithm end
Expand All @@ -170,6 +179,7 @@ needs_depgraph(aggregator::SortingDirect) = true
needs_depgraph(aggregator::NRM) = true
needs_depgraph(aggregator::RDirect) = true
needs_depgraph(aggregator::Coevolve) = true
needs_depgraph(aggregator::CoevolveSynced) = true

# true if aggregator requires a map from solution variable to dependent jumps.
# It is implicitly assumed these aggregators also require the reverse map, from
Expand All @@ -181,6 +191,7 @@ needs_vartojumps_map(aggregator::RSSACR) = true
# true if aggregator supports variable rates
supports_variablerates(aggregator::AbstractAggregatorAlgorithm) = false
supports_variablerates(aggregator::Coevolve) = true
supports_variablerates(aggregator::CoevolveSynced) = true

is_spatial(aggregator::AbstractAggregatorAlgorithm) = false
is_spatial(aggregator::NSM) = true
Expand Down
67 changes: 42 additions & 25 deletions src/aggregators/coevolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ function CoevolveJumpAggregation(nj::Int, njt::T, et::T, crs::Vector{T}, sr::Not
dg, pq, lrates, urates, rateintervals, haslratevec)
end

# display
num_constant_rate_jumps(aggregator::CoevolveJumpAggregation) = length(aggregator.urates)

# creating the JumpAggregation structure (tuple-based variable jumps)
function aggregate(aggregator::Coevolve, u, p, t, end_time, constant_jumps,
ma_jumps, save_positions, rng; dep_graph = nothing,
Expand Down Expand Up @@ -112,7 +115,6 @@ end
function execute_jumps!(p::CoevolveJumpAggregation, integrator, u, params, t, affects!)
# execute jump
u = update_state!(p, integrator, u, affects!)

# update current jump rates and times
update_dependent_rates!(p, u, params, t)
nothing
Expand All @@ -127,11 +129,11 @@ end
######################## SSA specific helper routines ########################
function update_dependent_rates!(p::CoevolveJumpAggregation, u, params, t)
@inbounds deps = p.dep_gr[p.next_jump]
@unpack cur_rates, end_time, pq = p
@unpack cur_rates, pq = p
for (ix, i) in enumerate(deps)
ti, last_urate_i = next_time(p, u, params, t, i, end_time)
ti, urate_i = next_time(p, u, params, t, i)
update!(pq, i, ti)
@inbounds cur_rates[i] = last_urate_i
@inbounds cur_rates[i] = urate_i
end
nothing
end
Expand All @@ -156,61 +158,76 @@ end
@inbounds return p.rates[lidx](u, params, t)
end

function next_time(p::CoevolveJumpAggregation{T}, u, params, t, i, tstop::T) where {T}
@unpack rng, haslratevec = p
num_majumps = get_num_majumps(p.ma_jumps)
num_cjumps = length(p.urates) - length(p.rates)
function next_time(p::CoevolveJumpAggregation{T}, u, params, t, i) where {T}
@unpack next_jump, cur_rates, ma_jumps, rates, rng, pq, urates = p
num_majumps = get_num_majumps(ma_jumps)
num_cjumps = length(urates) - length(rates)
uidx = i - num_majumps
lidx = uidx - num_cjumps
urate = uidx > 0 ? get_urate(p, uidx, u, params, t) : get_ma_urate(p, i, u, params, t)
last_urate = p.cur_rates[i]
if i != p.next_jump && last_urate > zero(t)
s = urate == zero(t) ? typemax(t) : last_urate / urate * (p.pq[i] - t)
if urate < zero(t)
error("urate = $(urate) < 0 for jump = $(i) at t = $(t) which is not allowed.")
end
last_urate = cur_rates[i]
if i != next_jump && last_urate > zero(t)
s = urate == zero(t) ? typemax(t) : last_urate / urate * (pq[i] - t)
else
s = urate == zero(t) ? typemax(t) : randexp(rng) / urate
end
_t = t + s
if lidx > 0
while t < tstop
@unpack end_time, haslratevec = p
while t < end_time
rateinterval = get_rateinterval(p, lidx, u, params, t)
if s > rateinterval
t = t + rateinterval
urate = get_urate(p, uidx, u, params, t)
if urate < zero(t)
error("urate = $(urate) < 0 for jump = $(i) at t = $(t) which is not allowed.")
end
s = urate == zero(t) ? typemax(t) : randexp(rng) / urate
_t = t + s
continue
end
(_t >= tstop) && break

(_t >= end_time) && break
lrate = haslratevec[lidx] ? get_lrate(p, lidx, u, params, t) : zero(t)
if lrate < urate
# when the lower and upper bound are the same, then v < 1 = lrate / urate = urate / urate
v = rand(rng) * urate
# first inequality is less expensive and short-circuits the evaluation
if (v > lrate) && (v > get_rate(p, lidx, u, params, _t))
t = _t
urate = get_urate(p, uidx, u, params, t)
s = urate == zero(t) ? typemax(t) : randexp(rng) / urate
_t = t + s
continue
if (v > lrate)
rate = get_rate(p, lidx, u, params, _t)
if rate < 0
error("rate = $(rate) < 0 for jump = $(i) at t = $(t) which is not allowed.")
elseif rate > urate
error("rate = $(rate) > urate = $(urate) for jump = $(i) at t = $(t) which is not allowed.")
end
if v > rate
t = _t
urate = get_urate(p, uidx, u, params, t)
if urate < zero(t)
error("urate = $(urate) < 0 for jump = $(i) at t = $(t) which is not allowed.")
end
s = urate == zero(t) ? typemax(t) : randexp(rng) / urate
_t = t + s
continue
end
end
elseif lrate > urate
error("The lower bound should be lower than the upper bound rate for t = $(t) and i = $(i), but lower bound = $(lrate) > upper bound = $(urate)")
error("lrate = $(lrate) > urate = $(urate) for jump = $(i) at t = $(t) which is not allowed.")
end
break
end
end
return _t, urate
end

# reevaulate all rates, recalculate all jump times, and reinit the priority queue
# re-evaluates all rates, recalculate all jump times, and reinit the priority queue
function fill_rates_and_get_times!(p::CoevolveJumpAggregation, u, params, t)
@unpack end_time = p
num_jumps = get_num_majumps(p.ma_jumps) + length(p.urates)
p.cur_rates = zeros(typeof(t), num_jumps)
jump_times = Vector{typeof(t)}(undef, num_jumps)
@inbounds for i in 1:num_jumps
jump_times[i], p.cur_rates[i] = next_time(p, u, params, t, i, end_time)
jump_times[i], p.cur_rates[i] = next_time(p, u, params, t, i)
end
p.pq = MutableBinaryMinHeap(jump_times)
nothing
Expand Down
Loading

0 comments on commit a9f851e

Please sign in to comment.