Skip to content

Commit

Permalink
Merge 5c97a63 into 42404e0
Browse files Browse the repository at this point in the history
  • Loading branch information
Vilin97 committed Sep 14, 2023
2 parents 42404e0 + 5c97a63 commit aff92d8
Show file tree
Hide file tree
Showing 9 changed files with 102 additions and 24 deletions.
1 change: 1 addition & 0 deletions src/jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ function JumpSet(vj, cj, rj, maj::MassActionJump{S, T, U, V}) where {S <: Number
end

JumpSet(jump::ConstantRateJump) = JumpSet((), (jump,), nothing, nothing)
JumpSet(jumps::AbstractVector{ConstantRateJump}) = JumpSet((), jumps, nothing, nothing)
JumpSet(jump::VariableRateJump) = JumpSet((jump,), (), nothing, nothing)
JumpSet(jump::RegularJump) = JumpSet((), (), jump, nothing)
JumpSet(jump::AbstractMassActionJump) = JumpSet((), (), nothing, jump)
Expand Down
7 changes: 5 additions & 2 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,12 +198,15 @@ function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::JumpS

## Spatial jumps handling
if spatial_system !== nothing && hopping_constants !== nothing
(num_crjs(jumps) == num_vrjs(jumps) == 0) ||
error("Spatial aggregators only support MassActionJumps currently.")
(num_vrjs(jumps) == 0) ||
error("Spatial aggregators currently only support MassActionJumps and ConstantRateJumps.")

if is_spatial(aggregator)
kwargs = merge((; hopping_constants, spatial_system), kwargs)
else
if num_crjs(jumps) != 0
error("Use a spatial SSA, e.g. DirectCRDirect in order to use ConstantRateJumps.")
end
prob, maj = flatten(maj, prob, spatial_system, hopping_constants; kwargs...)
end
end
Expand Down
10 changes: 8 additions & 2 deletions src/spatial/directcrdirect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ function DirectCRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rat

# a dependency graph is needed
if dep_graph === nothing
if length(rx_rates.cr_jumps) != 0
error("Provide a dependency graph to use DirectCRDirect with constant rate jumps.")
end
dg = make_dependency_graph(num_specs, rx_rates.ma_jumps)
else
dg = dep_graph
Expand All @@ -54,6 +57,9 @@ function DirectCRDirectJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rat
end

if jumptovars_map === nothing
if length(rx_rates.cr_jumps) != 0
error("Provide a jump-to-species dependency graph to use DirectCRDirect with constant rate jumps.")
end
jtov_map = jump_to_vars_map(rx_rates.ma_jumps)
else
jtov_map = jumptovars_map
Expand Down Expand Up @@ -94,7 +100,7 @@ function aggregate(aggregator::DirectCRDirect, starting_state, p, t, end_time,

next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder
next_jump_time = typemax(typeof(end_time))
rx_rates = RxRates(num_sites(spatial_system), majumps)
rx_rates = RxRates(num_sites(spatial_system), majumps, constant_jumps)
hop_rates = HopRates(hopping_constants, spatial_system)
site_rates = zeros(typeof(end_time), num_sites(spatial_system))

Expand Down Expand Up @@ -199,4 +205,4 @@ end
number of constant rate jumps
"""
num_constant_rate_jumps(aggregator::DirectCRDirectJumpAggregation) = 0
num_constant_rate_jumps(aggregator::DirectCRDirectJumpAggregation) = length(aggregator.rx_rates.cr_jumps)
10 changes: 8 additions & 2 deletions src/spatial/nsm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ function NSMJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rates::RX, hop

# a dependency graph is needed
if dep_graph === nothing
if length(rx_rates.cr_jumps) != 0
error("Provide a dependency graph to use NSM with constant rate jumps.")
end
dg = make_dependency_graph(num_specs, rx_rates.ma_jumps)
else
dg = dep_graph
Expand All @@ -47,6 +50,9 @@ function NSMJumpAggregation(nj::SpatialJump{J}, njt::T, et::T, rx_rates::RX, hop
end

if jumptovars_map === nothing
if length(rx_rates.cr_jumps) != 0
error("Provide a jump-to-species dependency graph to use NSM with constant rate jumps.")
end
jtov_map = jump_to_vars_map(rx_rates.ma_jumps)
else
jtov_map = jumptovars_map
Expand Down Expand Up @@ -83,7 +89,7 @@ function aggregate(aggregator::NSM, starting_state, p, t, end_time, constant_jum

next_jump = SpatialJump{Int}(typemax(Int), typemax(Int), typemax(Int)) #a placeholder
next_jump_time = typemax(typeof(end_time))
rx_rates = RxRates(num_sites(spatial_system), majumps)
rx_rates = RxRates(num_sites(spatial_system), majumps, constant_jumps)
hop_rates = HopRates(hopping_constants, spatial_system)

NSMJumpAggregation(next_jump, next_jump_time, end_time, rx_rates, hop_rates,
Expand Down Expand Up @@ -187,4 +193,4 @@ end
number of constant rate jumps
"""
num_constant_rate_jumps(aggregator::NSMJumpAggregation) = 0
num_constant_rate_jumps(aggregator::NSMJumpAggregation) = length(aggregator.rx_rates.cr_jumps)
46 changes: 36 additions & 10 deletions src/spatial/reaction_rates.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""
A file with structs and functions for sampling reactions and updating reaction rates in spatial SSAs
A file with structs and functions for sampling reactions and updating reaction rates in spatial SSAs.
Massaction jumps go first in the indexing, then constant rate jumps.
"""

### spatial rx rates ###
struct RxRates{F, M}
struct RxRates{F, M, C}
"rx_rates[i,j] is rate of reaction i at site j"
rates::Matrix{F}

Expand All @@ -12,20 +13,25 @@ struct RxRates{F, M}

"AbstractMassActionJump"
ma_jumps::M

"indexable collection of ConstantRateJump"
cr_jumps::C
end

"""
RxRates(num_sites::Int, ma_jumps::M) where {M}
RxRates(num_sites::Int, ma_jumps::M, cr_jumps::C) where {M, C}
initializes RxRates with zero rates
"""
function RxRates(num_sites::Int, ma_jumps::M) where {M}
numrxjumps = get_num_majumps(ma_jumps)
function RxRates(num_sites::Int, ma_jumps::M, cr_jumps::C) where {M, C}
numrxjumps = get_num_majumps(ma_jumps) + length(cr_jumps)
rates = zeros(Float64, numrxjumps, num_sites)
RxRates{Float64, M}(rates, vec(sum(rates, dims = 1)), ma_jumps)
RxRates{Float64, M, C}(rates, vec(sum(rates, dims = 1)), ma_jumps, cr_jumps)
end
RxRates(num_sites::Int, ma_jumps::M) where {M<:AbstractMassActionJump} = RxRates(num_sites, ma_jumps, ConstantRateJump[])
RxRates(num_sites::Int, cr_jumps::C) where {C} = RxRates(num_sites, SpatialMassActionJump(), cr_jumps)

num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps)
num_rxs(rx_rates::RxRates) = get_num_majumps(rx_rates.ma_jumps) + length(rx_rates.cr_jumps)

"""
reset!(rx_rates::RxRates)
Expand All @@ -48,16 +54,21 @@ function total_site_rx_rate(rx_rates::RxRates, site)
end

"""
update_rx_rates!(rx_rates, rxs, u, site)
update_rx_rates!(rx_rates, rxs, integrator, site)
update rates of all reactions in rxs at site
"""
function update_rx_rates!(rx_rates::RxRates, rxs, u::AbstractMatrix, integrator,
site)
ma_jumps = rx_rates.ma_jumps
@inbounds for rx in rxs
rate = eval_massaction_rate(u, rx, ma_jumps, site)
set_rx_rate_at_site!(rx_rates, site, rx, rate)
if is_massaction(rx_rates, rx)
rate = eval_massaction_rate(u, rx, ma_jumps, site)
set_rx_rate_at_site!(rx_rates, site, rx, rate)
else
cr_jump = rx_rates.cr_jumps[rx - get_num_majumps(ma_jumps)]
set_rx_rate_at_site!(rx_rates, site, rx, cr_jump.rate(u, integrator.p, integrator.t, site))
end
end
end

Expand All @@ -77,6 +88,16 @@ function sample_rx_at_site(rx_rates::RxRates, site, rng)
rand(rng) * total_site_rx_rate(rx_rates, site))
end

function execute_rx_at_site!(integrator, rx_rates::RxRates, rx, site)
if is_massaction(rx_rates, rx)
@inbounds executerx!((@view integrator.u[:, site]), rx,
rx_rates.ma_jumps)
else
cr_jump = rx_rates.cr_jumps[rx - get_num_majumps(rx_rates.ma_jumps)]
cr_jump.affect!(integrator, site)
end
end

# helper functions
function set_rx_rate_at_site!(rx_rates::RxRates, site, rx, rate)
@inbounds old_rate = rx_rates.rates[rx, site]
Expand All @@ -90,5 +111,10 @@ function Base.show(io::IO, ::MIME"text/plain", rx_rates::RxRates)
println(io, "RxRates with $num_rxs reactions and $num_sites sites")
end

"Return true if jump is a massaction jump."
function is_massaction(rx_rates::RxRates, rx)
rx <= get_num_majumps(rx_rates.ma_jumps)
end

eval_massaction_rate(u, rx, ma_jumps::M, site) where {M <: SpatialMassActionJump} = evalrxrate(u, rx, ma_jumps, site)
eval_massaction_rate(u, rx, ma_jumps::M, site) where {M <: MassActionJump} = evalrxrate((@view u[:, site]), rx, ma_jumps)
6 changes: 6 additions & 0 deletions src/spatial/spatial_massaction_jump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ function SpatialMassActionJump(ma_jumps::MassActionJump{T, S, U, V}; scale_rates
scale_rates = scale_rates, useiszero = useiszero, nocopy = nocopy)
end

function SpatialMassActionJump()
empty_majump = MassActionJump(Vector{Float64}(),
Vector{Vector{Pair{Int, Int}}}(),
Vector{Vector{Pair{Int, Int}}}())
SpatialMassActionJump(empty_majump)
end
##############################################

function get_num_majumps(smaj::SpatialMassActionJump{Nothing, Nothing, S, U, V}) where
Expand Down
5 changes: 2 additions & 3 deletions src/spatial/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ struct SpatialJump{J}
"source location"
src::J

"index of jump as a hop or reaction"
"index of jump as a hop or reaction, hops first, then massaction reactions, then constant rate reactions"
jidx::Int

"destination location, equal to src for within-site reactions"
Expand Down Expand Up @@ -69,8 +69,7 @@ function update_state!(p, integrator)
execute_hop!(integrator, jump.src, jump.dst, jump.jidx)
else
rx_index = reaction_id_from_jump(p, jump)
@inbounds executerx!((@view integrator.u[:, jump.src]), rx_index,
p.rx_rates.ma_jumps)
execute_rx_at_site!(integrator, p.rx_rates, rx_index, jump.src)
end
# save jump that was just exectued
p.prev_jump = jump
Expand Down
22 changes: 22 additions & 0 deletions test/spatial/ABC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,23 @@ netstoch = [[1 => -1, 2 => -1, 3 => 1], [1 => 1, 2 => 1, 3 => -1]]
rates = [0.1 / mesh_size, 1.0]
majumps = MassActionJump(rates, reactstoch, netstoch)

# equivalent constant rate jumps
rate1(u,p,t,site) = u[1,site]*u[2,site] / 2
rate2(u,p,t,site) = u[3,site]
affect1!(integrator,site) = begin
integrator.u[1, site] -= 1
integrator.u[2, site] -= 1
integrator.u[3, site] += 1
end
affect2!(integrator,site) = begin
integrator.u[1, site] += 1
integrator.u[2, site] += 1
integrator.u[3, site] -= 1
end
crjumps = JumpSet([ConstantRateJump(rate1, affect1!), ConstantRateJump(rate2, affect2!)])
dep_graph = [[1,2],[1,2]]
jumptovars_map = [[1,2,3],[1,2,3]]

# spatial system setup
hopping_rate = diffusivity * (linear_size / domain_size)^2

Expand Down Expand Up @@ -56,6 +73,11 @@ jump_problems = JumpProblem[JumpProblem(prob, NSM(), majumps,
push!(jump_problems,
JumpProblem(prob, DirectCRDirect(), majumps, hopping_constants = hopping_constants,
spatial_system = grids[1], save_positions = (false, false), rng = rng))
# setup constant rate jump problems
push!(jump_problems, JumpProblem(prob, NSM(), crjumps, hopping_constants = hopping_constants,
spatial_system = CartesianGrid(dims), save_positions = (false, false), dep_graph = dep_graph, jumptovars_map = jumptovars_map, rng = rng))
push!(jump_problems, JumpProblem(prob, DirectCRDirect(), crjumps, hopping_constants = hopping_constants,
spatial_system = CartesianGrid(dims), save_positions = (false, false), dep_graph = dep_graph, jumptovars_map = jumptovars_map, rng = rng))
# setup flattenned jump prob
push!(jump_problems,
JumpProblem(prob, NRM(), majumps, hopping_constants = hopping_constants,
Expand Down
19 changes: 14 additions & 5 deletions test/spatial/reaction_rates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,31 @@ num_species = 3
reactstoch = [[1 => 1, 2 => 1], [3 => 1]]
netstoch = [[1 => -1, 2 => -1, 3 => 1], [1 => 1, 2 => 1, 3 => -1]]
rates = [0.1, 1.0]
num_rxs = length(rates)
ma_jumps = MassActionJump(rates, reactstoch, netstoch)
spatial_ma_jumps = SpatialMassActionJump(rates, reactstoch, netstoch)
rate_fn = (u, p, t, site) -> 1.0
affect_fn!(integrator) = nothing # a dummy reaction, does nothing
cr_jumps = [ConstantRateJump(rate_fn, affect_fn!)]
num_rxs = 3
u = ones(Int, num_species, num_nodes)
integrator = DummyIntegrator(u,nothing,nothing)
rng = StableRNG(12345)

# Test constructors
@test JP.RxRates(num_nodes, ma_jumps).ma_jumps == ma_jumps
@test JP.RxRates(num_nodes, spatial_ma_jumps).ma_jumps == spatial_ma_jumps
@test JP.RxRates(num_nodes, cr_jumps).cr_jumps == cr_jumps

# Tests for RxRates
rx_rates_list = [JP.RxRates(num_nodes, ma_jumps), JP.RxRates(num_nodes, spatial_ma_jumps)]
rx_rates_list = [JP.RxRates(num_nodes, ma_jumps, cr_jumps), JP.RxRates(num_nodes, spatial_ma_jumps, cr_jumps)]
for rx_rates in rx_rates_list
@test JP.num_rxs(rx_rates) == length(rates)
@test JP.num_rxs(rx_rates) == num_rxs
show(io, "text/plain", rx_rates)
for site in 1:num_nodes
JP.update_rx_rates!(rx_rates, 1:num_rxs, integrator, site)
@test JP.total_site_rx_rate(rx_rates, site) == 1.1
rx_props = [JP.evalrxrate(u[:, site], rx, ma_jumps) for rx in 1:num_rxs]
@test JP.total_site_rx_rate(rx_rates, site) == 2.1
majump_props = [JP.evalrxrate(u[:, site], rx, ma_jumps) for rx in 1:2]
rx_props = [majump_props..., 1.0]
rx_probs = rx_props / sum(rx_props)
d = Dict{Int, Int}()
for i in 1:num_samples
Expand Down

0 comments on commit aff92d8

Please sign in to comment.