Skip to content

Commit

Permalink
Merge pull request #219 from Vilin97/spatially-varying-rx-rates-WIP
Browse files Browse the repository at this point in the history
working on spatially varying reaction rates, WIP
  • Loading branch information
isaacsas committed Feb 18, 2022
2 parents 8a67239 + 4ec4385 commit c842db8
Show file tree
Hide file tree
Showing 10 changed files with 248 additions and 10 deletions.
3 changes: 3 additions & 0 deletions src/DiffEqJump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import RecursiveArrayTools: recursivecopy!
using StaticArrays, Base.Threads

abstract type AbstractJump end
abstract type AbstractMassActionJump <: AbstractJump end
abstract type AbstractAggregatorAlgorithm end
abstract type AbstractJumpAggregator end

Expand All @@ -43,6 +44,7 @@ include("aggregators/rssacr.jl")
include("aggregators/rdirect.jl")

# spatial:
include("spatial/spatial_massaction_jump.jl")
include("spatial/topology.jl")
include("spatial/hop_rates.jl")
include("spatial/reaction_rates.jl")
Expand Down Expand Up @@ -85,6 +87,7 @@ export ExtendedJumpArray

# spatial structs and functions
export CartesianGrid, CartesianGridRej, CartesianGridIter
export SpatialMassActionJump
export outdegree, num_sites, neighbors
export NSM, DirectCRDirect

Expand Down
4 changes: 2 additions & 2 deletions src/jumps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ function RegularJump(rate,c,dc::AbstractMatrix; constant_c=false, mark_dist = no
RegularJump{true}(rate,_c,size(dc,2);mark_dist=mark_dist)
end

struct MassActionJump{T,S,U,V} <: AbstractJump
struct MassActionJump{T,S,U,V} <: AbstractMassActionJump
scaled_rates::T
reactant_stoch::S
net_stoch::U
Expand Down Expand Up @@ -190,7 +190,7 @@ JumpSet(vj, cj, rj, maj::MassActionJump{S,T,U,V}) where {S <: Number,T,U,V} = Ju
JumpSet(jump::ConstantRateJump) = JumpSet((),(jump,),nothing,nothing)
JumpSet(jump::VariableRateJump) = JumpSet((jump,),(),nothing,nothing)
JumpSet(jump::RegularJump) = JumpSet((),(),jump,nothing)
JumpSet(jump::MassActionJump) = JumpSet((),(),nothing,jump)
JumpSet(jump::AbstractMassActionJump) = JumpSet((),(),nothing,jump)
function JumpSet(; variable_jumps=(), constant_jumps=(),
regular_jumps=nothing, massaction_jumps=nothing)
JumpSet(variable_jumps, constant_jumps, regular_jumps, massaction_jumps)
Expand Down
19 changes: 15 additions & 4 deletions src/massaction_rates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
end

@inline @fastmath function executerx!(speciesvec::AbstractVector{T}, rxidx::S,
majump::MassActionJump{U,V,W,X}) where {T,S,U,V,W,X}
majump::M) where {T,S,M <: AbstractMassActionJump}
@inbounds net_stoch = majump.net_stoch[rxidx]
@inbounds for specstoch in net_stoch
speciesvec[specstoch[1]] += specstoch[2]
Expand All @@ -28,7 +28,7 @@ end
end

@inline @fastmath function executerx(speciesvec::SVector{T}, rxidx::S,
majump::MassActionJump{U,V,W,X}) where {T,S,U,V,W,X}
majump::M) where {T,S,M <: AbstractMassActionJump}
@inbounds net_stoch = majump.net_stoch[rxidx]
@inbounds for specstoch in net_stoch
speciesvec = setindex(speciesvec,speciesvec[specstoch[1]]+specstoch[2],specstoch[1])
Expand All @@ -53,6 +53,17 @@ function scalerates!(unscaled_rates::AbstractVector{U}, stochmat::AbstractVector
nothing
end

function scalerates!(unscaled_rates::AbstractMatrix{U}, stochmat::AbstractVector{V}) where {U,S,T,W <: Pair{S,T}, V <: AbstractVector{W}}
@inbounds for i in size(unscaled_rates, 1)
coef = one(T)
@inbounds for specstoch in stochmat[i]
coef *= factorial(specstoch[2])
end
unscaled_rates[i,:] /= coef
end
nothing
end

function scalerate(unscaled_rate::U, stochmat::AbstractVector{Pair{S,T}}) where {U <: Number, S, T}
coef = one(T)
@inbounds for specstoch in stochmat
Expand All @@ -70,7 +81,7 @@ end
# uses a Vector instead of a Set as the latter requires isEqual,
# and by using an underlying Dict can be slower for small numbers
# of dependencies
function var_to_jumps_map(numspec, ma_jumps::MassActionJump)
function var_to_jumps_map(numspec, ma_jumps::AbstractMassActionJump)

numrxs = get_num_majumps(ma_jumps)

Expand All @@ -94,7 +105,7 @@ end

# dependency graph is a map from a reaction to a vector of reactions
# that should depend on species it changes
function make_dependency_graph(numspec, ma_jumps::MassActionJump)
function make_dependency_graph(numspec, ma_jumps::AbstractMassActionJump)

numrxs = get_num_majumps(ma_jumps)
spec_to_dep_rxs = var_to_jumps_map(numspec, ma_jumps)
Expand Down
2 changes: 1 addition & 1 deletion src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ JumpProblem(prob,jumps::AbstractJump...;kwargs...) = JumpProblem(prob,JumpSet(ju
JumpProblem(prob,aggregator::AbstractAggregatorAlgorithm,jumps::ConstantRateJump;kwargs...) = JumpProblem(prob,aggregator,JumpSet(jumps);kwargs...)
JumpProblem(prob,aggregator::AbstractAggregatorAlgorithm,jumps::VariableRateJump;kwargs...) = JumpProblem(prob,aggregator,JumpSet(jumps);kwargs...)
JumpProblem(prob,aggregator::AbstractAggregatorAlgorithm,jumps::RegularJump;kwargs...) = JumpProblem(prob,aggregator,JumpSet(jumps);kwargs...)
JumpProblem(prob,aggregator::AbstractAggregatorAlgorithm,jumps::MassActionJump;kwargs...) = JumpProblem(prob,aggregator,JumpSet(jumps);kwargs...)
JumpProblem(prob,aggregator::AbstractAggregatorAlgorithm,jumps::AbstractMassActionJump;kwargs...) = JumpProblem(prob,aggregator,JumpSet(jumps);kwargs...)
JumpProblem(prob,aggregator::AbstractAggregatorAlgorithm,jumps::AbstractJump...;kwargs...) = JumpProblem(prob,aggregator,JumpSet(jumps...);kwargs...)
JumpProblem(prob,jumps::JumpSet;kwargs...) = JumpProblem(prob,NullAggregator(),jumps;kwargs...)

Expand Down
16 changes: 15 additions & 1 deletion src/spatial/flatten.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,21 @@ function flatten(ma_jump, prob::DiscreteProblem, spatial_system, hopping_constan
end
netstoch = ma_jump.net_stoch
reactstoch = ma_jump.reactant_stoch
rx_rates = ma_jump.scaled_rates
if isa(ma_jump, MassActionJump)
rx_rates = ma_jump.scaled_rates
elseif isa(ma_jump, SpatialMassActionJump)
num_nodes = num_sites(spatial_system)
if isnothing(ma_jump.uniform_rates) && isnothing(ma_jump.spatial_rates)
rx_rates = zeros(0,num_nodes)
elseif isnothing(ma_jump.uniform_rates)
rx_rates = ma_jump.spatial_rates
elseif isnothing(ma_jump.spatial_rates)
rx_rates = reshape(repeat(ma_jump.uniform_rates, num_nodes), length(ma_jump.uniform_rates), num_nodes)
else
@assert size(ma_jump.spatial_rates, 2) == num_nodes
rx_rates = cat(dims=1,reshape(repeat(ma_jump.uniform_rates, num_nodes), length(ma_jump.uniform_rates), num_nodes), ma_jump.spatial_rates)
end
end
flatten(netstoch, reactstoch, rx_rates, spatial_system, u0, tspan, hopping_constants; scale_rates = false, kwargs...)
end

Expand Down
11 changes: 9 additions & 2 deletions src/spatial/reaction_rates.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct RxRates{F,M}
"rx_rates_sum[j] is sum of reaction rates at site j"
sum_rates::Vector{F}

"MassActionJump"
"AbstractMassActionJump"
ma_jumps::M
end

Expand Down Expand Up @@ -52,13 +52,20 @@ end
update rates of all reactions in rxs at site
"""
function update_rx_rates!(rx_rates, rxs, u, site)
function update_rx_rates!(rx_rates::RxRates{F,M}, rxs, u, site) where {F, M <: MassActionJump}
ma_jumps = rx_rates.ma_jumps
@inbounds for rx in rxs
set_rx_rate_at_site!(rx_rates, site, rx, evalrxrate((@view u[:,site]), rx, ma_jumps))
end
end

function update_rx_rates!(rx_rates::RxRates{F,M}, rxs, u, site) where {F, M <: SpatialMassActionJump}
ma_jumps = rx_rates.ma_jumps
@inbounds for rx in rxs
set_rx_rate_at_site!(rx_rates, site, rx, evalrxrate(u, rx, ma_jumps, site))
end
end

"""
sample_rx_at_site(rx_rates::RxRates, site, rng)
Expand Down
71 changes: 71 additions & 0 deletions src/spatial/spatial_massaction_jump.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
struct SpatialMassActionJump{A<:Union{AbstractVector,Nothing},B<:Union{AbstractMatrix,Nothing},S,U,V} <: AbstractMassActionJump
uniform_rates::A # reactions that are uniform in space
spatial_rates::B # reactions whose rate depends on the site
reactant_stoch::S
net_stoch::U
param_mapper::V

"""
uniform rates go first in ordering
"""
function SpatialMassActionJump{A,B,S,U,V}(uniform_rates::A, spatial_rates::B, reactant_stoch::S, net_stoch::U, param_mapper::V, scale_rates::Bool, useiszero::Bool, nocopy::Bool) where {A<:Union{AbstractVector,Nothing},B<:Union{AbstractMatrix,Nothing}, S, U, V}
uniform_rates = (nocopy || isnothing(uniform_rates)) ? uniform_rates : copy(uniform_rates)
spatial_rates = (nocopy || isnothing(spatial_rates)) ? spatial_rates : copy(spatial_rates)
reactant_stoch = nocopy ? reactant_stoch : copy(reactant_stoch)
for i in eachindex(reactant_stoch)
if useiszero && (length(reactant_stoch[i]) == 1) && iszero(reactant_stoch[i][1][1])
reactant_stoch[i] = typeof(reactant_stoch[i])()
end
end
num_unif_rates = isnothing(uniform_rates) ? 0 : length(uniform_rates)
if scale_rates && num_unif_rates > 0
scalerates!(uniform_rates, reactant_stoch)
end
if scale_rates && !isnothing(spatial_rates) && !isempty(spatial_rates)
scalerates!(spatial_rates, reactant_stoch[num_unif_rates+1:end])
end
new(uniform_rates, spatial_rates, reactant_stoch, net_stoch, param_mapper)
end

end

################ Constructors ##################

SpatialMassActionJump(urates::A, srates::B, rs::S, ns::U, pmapper::V; scale_rates = true, useiszero = true, nocopy=false) where {A<:Union{AbstractVector,Nothing},B<:Union{AbstractMatrix,Nothing},S,U,V} = SpatialMassActionJump{A,B,S,U,V}(urates, srates, rs, ns, pmapper, scale_rates, useiszero, nocopy)
SpatialMassActionJump(urates::A, srates::B, rs, ns; scale_rates = true, useiszero = true, nocopy=false) where {A<:Union{AbstractVector,Nothing},B<:Union{AbstractMatrix,Nothing}} = SpatialMassActionJump(urates, srates, rs, ns, nothing; scale_rates=scale_rates, useiszero=useiszero, nocopy=nocopy)

SpatialMassActionJump(srates::B, rs, ns, pmapper; scale_rates = true, useiszero = true, nocopy=false) where {B<:Union{AbstractMatrix,Nothing}} = SpatialMassActionJump(nothing, srates, rs, ns, pmapper; scale_rates = scale_rates, useiszero = useiszero, nocopy=nocopy)
SpatialMassActionJump(srates::B, rs, ns; scale_rates = true, useiszero = true, nocopy=false) where {B<:Union{AbstractMatrix,Nothing}} = SpatialMassActionJump(nothing, srates, rs, ns, nothing; scale_rates = scale_rates, useiszero = useiszero, nocopy=nocopy)

SpatialMassActionJump(urates::A, rs, ns, pmapper; scale_rates = true, useiszero = true, nocopy=false) where {A<:Union{AbstractVector,Nothing}} = SpatialMassActionJump(urates, nothing, rs, ns, pmapper; scale_rates = scale_rates, useiszero = useiszero, nocopy=nocopy)
SpatialMassActionJump(urates::A, rs, ns; scale_rates = true, useiszero = true, nocopy=false) where {A<:Union{AbstractVector,Nothing}} = SpatialMassActionJump(urates, nothing, rs, ns, nothing; scale_rates = scale_rates, useiszero = useiszero, nocopy=nocopy)

SpatialMassActionJump(ma_jumps::MassActionJump{T,S,U,V}; scale_rates = true, useiszero = true, nocopy=false) where {T,S,U,V} = SpatialMassActionJump(ma_jumps.scaled_rates, ma_jumps.reactant_stoch, ma_jumps.net_stoch, ma_jumps.param_mapper; scale_rates = scale_rates, useiszero = useiszero, nocopy=nocopy)

##############################################

get_num_majumps(spatial_majump::SpatialMassActionJump{Nothing,Nothing,S,U,V}) where {S,U,V} = 0
get_num_majumps(spatial_majump::SpatialMassActionJump{Nothing,B,S,U,V}) where {B,S,U,V} = size(spatial_majump.spatial_rates, 1)
get_num_majumps(spatial_majump::SpatialMassActionJump{A,Nothing,S,U,V}) where {A,S,U,V} = length(spatial_majump.uniform_rates)
get_num_majumps(spatial_majump::SpatialMassActionJump{A,B,S,U,V}) where {A<:AbstractVector,B<:AbstractMatrix,S,U,V} = length(spatial_majump.uniform_rates) + size(spatial_majump.spatial_rates, 1)
using_params(spatial_majump::SpatialMassActionJump) = false

rate_at_site(rx, site, spatial_majump::SpatialMassActionJump{Nothing,B,S,U,V}) where {B,S,U,V} = spatial_majump.spatial_rates[rx, site]
rate_at_site(rx, site, spatial_majump::SpatialMassActionJump{A,Nothing,S,U,V}) where {A,S,U,V} = spatial_majump.uniform_rates[rx]
function rate_at_site(rx, site, spatial_majump::SpatialMassActionJump{A,B,S,U,V}) where {A<:AbstractVector,B<:AbstractMatrix,S,U,V}
num_unif_rxs = length(spatial_majump.uniform_rates)
rx <= num_unif_rxs ? spatial_majump.uniform_rates[rx] : spatial_majump.spatial_rates[rx-num_unif_rxs, site]
end

function evalrxrate(speciesmat::AbstractMatrix{T}, rxidx::S, majump::SpatialMassActionJump, site::Int) where {T,S}
val = one(T)
@inbounds for specstoch in majump.reactant_stoch[rxidx]
specpop = speciesmat[specstoch[1], site]
val *= specpop
@inbounds for k = 2:specstoch[2]
specpop -= one(specpop)
val *= specpop
end
end
@inbounds return val * rate_at_site(rxidx, site, majump)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ using DiffEqJump, DiffEqBase, SafeTestsets
@time @safetestset "Long time accuracy test" begin include("longtimes_test.jl") end
@time @safetestset "Spatial utilities" begin include("spatial/utils_test.jl") end
@time @safetestset "Spatial A + B <--> C" begin include("spatial/ABC.jl") end
@time @safetestset "Spatially Varying Reaction Rates" begin include("spatial/spatial_majump.jl") end
@time @safetestset "Pure diffusion" begin include("spatial/diffusion.jl") end
end
1 change: 1 addition & 0 deletions test/spatial/run_spatial_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ using DiffEqJump, DiffEqBase, SafeTestsets
@time @safetestset "Spatial utilities" begin include("utils_test.jl") end
@time @safetestset "Spatial A + B <--> C" begin include("ABC.jl") end
@time @safetestset "Pure diffusion" begin include("diffusion.jl") end
@time @safetestset "Spatially Varying Reaction Rates" begin include("spatial_majump.jl") end
end
130 changes: 130 additions & 0 deletions test/spatial/spatial_majump.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
using DiffEqJump, DiffEqBase, OrdinaryDiffEq
using Test, Graphs, LinearAlgebra

reltol = 0.05
Nsims = 10^4

dim = 1
linear_size = 5
dims = Tuple(repeat([linear_size], dim))
num_nodes = prod(dims)
center_site = trunc(Int,(linear_size^dim + 1)/2)
u0 = zeros(Int, 1, num_nodes)
end_time = 100.0
diffusivity = 1.0
death_rate = 0.1

num_species = 1
reactstoch = [Pair{Int64, Int64}[], [1 => 1]]
netstoch = [[1 => 1],[1 => -1]]
uniform_rates = ones(2, num_nodes)
uniform_rates[2,:] *= death_rate
non_uniform_rates = zeros(2, num_nodes)
non_uniform_rates[:,center_site] = uniform_rates[:,center_site]

# DiscreteProblem setup
tspan = (0.0, end_time)
prob = DiscreteProblem(u0, tspan, uniform_rates)

# spatial system
grid = Graphs.grid(dims)
hopping_constants = [diffusivity for i in u0]

# majumps
uniform_majumps_1 = SpatialMassActionJump(uniform_rates[:,1], reactstoch, netstoch)
uniform_majumps_2 = SpatialMassActionJump(uniform_rates, reactstoch, netstoch)
uniform_majumps_3 = SpatialMassActionJump([1.], reshape(uniform_rates[2,:], 1, num_nodes), reactstoch, netstoch) # hybrid
uniform_majumps_4 = SpatialMassActionJump(MassActionJump(uniform_rates[:,1], reactstoch, netstoch))
uniform_majumps = [uniform_majumps_1, uniform_majumps_2, uniform_majumps_3, uniform_majumps_4]

non_uniform_majumps_1 = SpatialMassActionJump(non_uniform_rates, reactstoch, netstoch) # reactions are zero outside of center site
non_uniform_majumps_2 = SpatialMassActionJump([1.], reshape(non_uniform_rates[2,:], 1, num_nodes), reactstoch,netstoch) # birth everywhere, death only at center site
non_uniform_majumps_3 = SpatialMassActionJump([1. 0. 0. 0. 0.; 0. 0. 0. 0. death_rate], reactstoch,netstoch) # birth on the left, death on the right
non_uniform_majumps = [non_uniform_majumps_1, non_uniform_majumps_2, non_uniform_majumps_3]

# put together the JumpProblem's
uniform_jump_problems = JumpProblem[JumpProblem(prob, NSM(), majump, hopping_constants=hopping_constants, spatial_system = grid, save_positions=(false,false)) for majump in uniform_majumps]
# flattenned
append!(uniform_jump_problems, JumpProblem[JumpProblem(prob, NRM(), majump, hopping_constants=hopping_constants, spatial_system = grid, save_positions=(false,false)) for majump in uniform_majumps])

# non-uniform
non_uniform_jump_problems = JumpProblem[JumpProblem(prob, NSM(), majump, hopping_constants=hopping_constants, spatial_system = grid, save_positions=(false,false)) for majump in non_uniform_majumps]


# testing
function get_mean_end_state(jump_prob, Nsims)
end_state = zeros(size(jump_prob.prob.u0))
for i in 1:Nsims
sol = solve(jump_prob, SSAStepper())
end_state .+= sol.u[end]
end
end_state/Nsims
end

function discrete_laplacian_from_spatial_system(spatial_system, hopping_rate)
sites = 1:num_sites(spatial_system)
laplacian = zeros(length(sites), length(sites))
for site in sites
laplacian[site,site] = -outdegree(spatial_system, site)
for nb in neighbors(spatial_system, site)
laplacian[site, nb] = 1
end
end
laplacian .*= hopping_rate
laplacian
end
L = discrete_laplacian_from_spatial_system(grid, diffusivity)

# birth and death everywhere
f(u,p,t) = L*u - death_rate*u + uniform_rates[1,:]
ode_prob = ODEProblem(f, zeros(num_nodes), tspan)
sol = solve(ode_prob, Tsit5())

for spatial_jump_prob in uniform_jump_problems
solution = solve(spatial_jump_prob, SSAStepper())
mean_end_state = get_mean_end_state(spatial_jump_prob, Nsims)
mean_end_state = reshape(mean_end_state, num_nodes)
diff = mean_end_state - sol.u[end]
for (i,d) in enumerate(diff)
@test abs(d) < reltol*sol.u[end][i]
end
end

# birth and death zero outside of center site
f(u,p,t) = L*u - diagm([0., 0., death_rate, 0., 0.])*u + [0., 0., 1., 0., 0.]
ode_prob = ODEProblem(f, zeros(num_nodes), tspan)
sol = solve(ode_prob, Tsit5())

solution = solve(non_uniform_jump_problems[1], SSAStepper())
mean_end_state = get_mean_end_state(non_uniform_jump_problems[1], Nsims)
mean_end_state = reshape(mean_end_state, num_nodes)
diff = mean_end_state - sol.u[end]
for (i,d) in enumerate(diff)
@test abs(d) < reltol*sol.u[end][i]
end

# birth everywhere, death only at center site
f(u,p,t) = L*u - diagm([0., 0., death_rate, 0., 0.])*u + ones(num_nodes)
ode_prob = ODEProblem(f, zeros(num_nodes), tspan)
sol = solve(ode_prob, Tsit5())

solution = solve(non_uniform_jump_problems[2], SSAStepper())
mean_end_state = get_mean_end_state(non_uniform_jump_problems[2], Nsims)
mean_end_state = reshape(mean_end_state, num_nodes)
diff = mean_end_state - sol.u[end]
for (i,d) in enumerate(diff)
@test abs(d) < reltol*sol.u[end][i]
end

# birth on left end, death on right end
f(u,p,t) = L*u - diagm([0., 0., 0., 0., death_rate])*u + [1., 0., 0., 0., 0.]
ode_prob = ODEProblem(f, zeros(num_nodes), tspan)
sol = solve(ode_prob, Tsit5())

solution = solve(non_uniform_jump_problems[3], SSAStepper())
mean_end_state = get_mean_end_state(non_uniform_jump_problems[3], Nsims)
mean_end_state = reshape(mean_end_state, num_nodes)
diff = mean_end_state - sol.u[end]
for (i,d) in enumerate(diff)
@test abs(d) < reltol*sol.u[end][i]
end

0 comments on commit c842db8

Please sign in to comment.