Skip to content

Commit

Permalink
Merge 8fe7753 into c045289
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsas committed Mar 26, 2024
2 parents c045289 + 8fe7753 commit b131adf
Show file tree
Hide file tree
Showing 14 changed files with 110 additions and 49 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ RecursiveArrayTools = "3.12"
Reexport = "1.0"
SciMLBase = "2.30.1"
StaticArrays = "1.9"
SymbolicIndexingInterface = "0.3.11"
SymbolicIndexingInterface = "0.3.13"
UnPack = "1.0.2"
julia = "1.10"

Expand Down
2 changes: 1 addition & 1 deletion src/aggregators/coevolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ end

# executing jump at the next jump time
function (p::CoevolveJumpAggregation)(integrator::I) where {I <:
AbstractSSAIntegrator}
AbstractSSAIntegrator}
if !accept_next_jump!(p, integrator, integrator.u, integrator.p, integrator.t)
return nothing
end
Expand Down
9 changes: 3 additions & 6 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,11 @@ function Base.setindex!(prob::JumpProblem, args...; kwargs...)
end

# for updating parameters in JumpProblems to update MassActionJumps
function SII.set_parameter!(prob::JumpProblem, val, idx)
ans = SII.set_parameter!(SII.parameter_values(prob), val, idx)

function SII.finalize_parameters_hook!(prob::JumpProblem, p)
if using_params(prob.massaction_jump)
update_parameters!(prob.massaction_jump, prob.prob.p)
update_parameters!(prob.massaction_jump, SII.parameter_values(prob))
end

ans
nothing
end

# when getindex is used.
Expand Down
4 changes: 2 additions & 2 deletions test/constant_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ nums = Int[]
@time for i in 1:10000
local jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng)
local sol = solve(jump_prob, FunctionMap())
push!(nums, sol[end])
push!(nums, sol.u[end])
end

@test mean(nums) - 45 < 1
Expand All @@ -47,7 +47,7 @@ nums = Int[]
@time for i in 1:10000
local jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng)
local sol = solve(jump_prob, FunctionMap())
push!(nums, sol[2])
push!(nums, sol.u[2])
end

@test sum(nums .== 0) / 10000 - 0.33 < 0.02
16 changes: 8 additions & 8 deletions test/ensemble_uniqueness.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ u0 = [0]
prob = DiscreteProblem(u0, (0.0, 100.0))
jump_prob = JumpProblem(prob, Direct(), j1, j2; rng = rng)
sol = solve(EnsembleProblem(jump_prob), FunctionMap(), trajectories = 3)
@test Array(sol[1]) !== Array(sol[2])
@test Array(sol[1]) !== Array(sol[3])
@test Array(sol[2]) !== Array(sol[3])
@test eltype(sol[1].u[1]) == Int
@test Array(sol.u[1]) !== Array(sol.u[2])
@test Array(sol.u[1]) !== Array(sol.u[3])
@test Array(sol.u[2]) !== Array(sol.u[3])
@test eltype(sol.u[1].u[1]) == Int
sol = solve(EnsembleProblem(jump_prob), SSAStepper(), trajectories = 3)
@test Array(sol[1]) !== Array(sol[2])
@test Array(sol[1]) !== Array(sol[3])
@test Array(sol[2]) !== Array(sol[3])
@test eltype(sol[1].u[1]) == Int
@test Array(sol.u[1]) !== Array(sol.u[2])
@test Array(sol.u[1]) !== Array(sol.u[3])
@test Array(sol.u[2]) !== Array(sol.u[3])
@test eltype(sol.u[1].u[1]) == Int
6 changes: 3 additions & 3 deletions test/extinction_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ sol = solve(jprob, SSAStepper(), callback = cb, save_end = false)
@test sol.t[end] < 1000.0

# test terminate
function extinction_condition(u, t, integrator)
function extinction_condition2(u, t, integrator)
u[1] == 1
end
function extinction_affect!(integrator)
function extinction_affect!2(integrator)
(saved, savedexactly) = savevalues!(integrator, true)
terminate!(integrator)
nothing
end
cb = DiscreteCallback(extinction_condition, extinction_affect!,
cb = DiscreteCallback(extinction_condition2, extinction_affect!2,
save_positions = (false, false))
dprob = DiscreteProblem(u0, (0.0, 1000.0), rates)
jprob = JumpProblem(dprob, Direct(), majump; save_positions = (false, false), rng = rng)
Expand Down
66 changes: 63 additions & 3 deletions test/jprob_symbol_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ affect1!(integ) = (integ.u[1] += 1)
affect2!(integ) = (integ.u[2] += 1)
crj1 = ConstantRateJump(rate1, affect1!)
crj2 = ConstantRateJump(rate2, affect2!)
maj = MassActionJump([[1 => 1], [1 => 1]], [[1 => -1], [1 => -1]]; param_idxs = [1,2])
maj = MassActionJump([[1 => 1], [1 => 1]], [[1 => -1], [1 => -1]]; param_idxs = [1, 2])
g = DiscreteFunction((du, u, p, t) -> nothing;
sys = SymbolicIndexingInterface.SymbolCache([:a, :b], [:p1, :p2], :t))
dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0])
Expand All @@ -15,8 +15,8 @@ jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj)
# test basic querying of u0 and p
@test jprob[:a] == 0
@test jprob[:b] == 10
@test getp(jprob,:p1)(jprob) == 1.0
@test getp(jprob,:p2)(jprob) == 2.0
@test getp(jprob, :p1)(jprob) == 1.0
@test getp(jprob, :p2)(jprob) == 2.0
@test jprob.ps[:p1] == 1.0
@test jprob.ps[:p2] == 2.0

Expand All @@ -34,3 +34,63 @@ p1setter(jprob, [4.0, 10.0])
@test jprob.ps[:p1] == 4.0
@test jprob.ps[:p2] == 10.0
@test jprob.massaction_jump.scaled_rates == [4.0, 10.0]

# integrator tests
# note that `setu` is not currently supported as `set_u!` is not implemented for SSAStepper
integ = init(jprob, SSAStepper())
@test getu(integ, [:a, :b])(integ) == [20, 10]
integ[[:b, :a]] = [40, 5]
@test getu(integ, [:a, :b])(integ) == [5, 40]
@test getp(integ, :p2)(integ) == 10.0
setp(integ, :p2)(integ, 15.0)
@test getp(integ, :p2)(integ) == 15.0
@test jprob.massaction_jump.scaled_rates[2] == 10.0 # jump rate not updated
reset_aggregated_jumps!(integ)
@test jprob.massaction_jump.scaled_rates[2] == 15.0 # jump rate now updated

# remake tests
dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0])
jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj)
jprob = remake(jprob; u0 = [:a => -10, :b => 100], p = [:p2 => 3.5, :p1 => 0.5])
@test jprob.prob.u0 == [-10, 100]
@test jprob.prob.p == [0.5, 3.5]
@test jprob.massaction_jump.scaled_rates == [0.5, 3.5]
jprob = remake(jprob; u0 = [:b => 10], p = [:p2 => 4.5])
@test jprob.prob.u0 == [-10, 10]
@test jprob.prob.p == [0.5, 4.5]
@test jprob.massaction_jump.scaled_rates == [0.5, 4.5]

# test updating problems via regular indexing still updates the mass action jump
dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0])
jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj)
@test jprob.massaction_jump.scaled_rates[1] == 1.0
jprob.ps[1] = 3.0
@test jprob.ps[1] == 3.0
@test jprob.massaction_jump.scaled_rates[1] == 3.0

# test updating integrators via regular indexing
dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0])
jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj)
integ = init(jprob, SSAStepper())
integ.u .= [40, 5]
@test getu(integ, [1, 2])(integ) == [40, 5]
@test getp(integ, 2)(integ) == 2.0
@test integ.p[2] == 2.0
@test jprob.massaction_jump.scaled_rates[2] == 2.0
setp(integ, 2)(integ, 15.0)
@test integ.p[2] == 15.0
@test getp(integ, 2)(integ) == 15.0
reset_aggregated_jumps!(integ)
@test jprob.massaction_jump.scaled_rates[2] == 15.0 # jump rate now updated

# remake tests for regular indexing
dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0])
jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj)
jprob = remake(jprob; u0 = [-10, 100], p = [0.5, 3.5])
@test jprob.prob.u0 == [-10, 100]
@test jprob.prob.p == [0.5, 3.5]
@test jprob.massaction_jump.scaled_rates == [0.5, 3.5]
jprob = remake(jprob; u0 = [2 => 10], p = [2 => 4.5])
@test jprob.prob.u0 == [-10, 10]
@test jprob.prob.p == [0.5, 4.5]
@test jprob.massaction_jump.scaled_rates == [0.5, 4.5]
2 changes: 1 addition & 1 deletion test/longtimes_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ dt = tspan[2] / 1000
dprob = DiscreteProblem(u0, tspan, p)
jprob = JumpProblem(dprob, Direct(), maj, save_positions = (false, false), rng = rng)
sol = solve(jprob, SSAStepper(), saveat = tspan[1]:dt:tspan[2])
@test length(unique(sol[(end - 10):end][:])) > 1
@test length(unique(sol.u[(end - 10):end][:])) > 1
4 changes: 2 additions & 2 deletions test/monte_carlo_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ jump_prob = JumpProblem(prob, Direct(), jump; rng = rng)
monte_prob = EnsembleProblem(jump_prob)
sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3,
save_everystep = false, dt = 0.001, adaptive = false)
@test sol[1].t[2] != sol[2].t[2] != sol[3].t[2]
@test sol.u[1].t[2] != sol.u[2].t[2] != sol.u[3].t[2]

jump = ConstantRateJump(rate, affect!)
jump_prob = JumpProblem(prob, Direct(), jump, save_positions = (true, false), rng = rng)
monte_prob = EnsembleProblem(jump_prob)
sol = solve(monte_prob, SRIW1(), EnsembleSerial(), trajectories = 3,
save_everystep = false, dt = 0.001, adaptive = false)
@test sol[1].t[2] != sol[2].t[2] != sol[3].t[2]
@test sol.u[1].t[2] != sol.u[2].t[2] != sol.u[3].t[2]
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using JumpProcesses, DiffEqBase, SafeTestsets
@time @safetestset "Thread Safety test" begin include("thread_safety.jl") end
@time @safetestset "A + B <--> C" begin include("reversible_binding.jl") end
@time @safetestset "Remake tests" begin include("remake_test.jl") end
@time @safetestset "Symbol based problem indexing" begin include("jprob_symbol_indexing.jl") end
@time @safetestset "Long time accuracy test" begin include("longtimes_test.jl") end
@time @safetestset "Hawkes process" begin include("hawkes_test.jl") end
@time @safetestset "Reaction rates" begin include("spatial/reaction_rates.jl") end
Expand All @@ -37,5 +38,4 @@ using JumpProcesses, DiffEqBase, SafeTestsets
@time @safetestset "Spatial A + B <--> C" begin include("spatial/ABC.jl") end
@time @safetestset "Spatially Varying Reaction Rates" begin include("spatial/spatial_majump.jl") end
@time @safetestset "Pure diffusion" begin include("spatial/diffusion.jl") end
@time @safetestset "Symbol based problem indexing" begin include("jprob_symbol_indexing.jl") end
end
4 changes: 2 additions & 2 deletions test/saveat_regression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ sol = JumpProcesses.solve(EnsembleProblem(jprob), SSAStepper(), saveat = ts,
trajectories = Nsims)

for i in 1:length(sol)
NA .+= sol[i][1, :]
NA .+= sol.u[i][1, :]
end

for i in 1:length(ts)
Expand All @@ -33,7 +33,7 @@ sol = JumpProcesses.solve(EnsembleProblem(jprob), SSAStepper(), trajectories = N

for i in 1:Nsims
for n in 1:length(ts)
NA[n] += sol[i](ts[n])[1]
NA[n] += sol.u[i](ts[n])[1]
end
end

Expand Down
4 changes: 2 additions & 2 deletions test/spatial/diffusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ starting_state = zeros(Int, length(u0), num_nodes)
center_node = trunc(Int, (num_nodes + 1) / 2)
starting_state[:, center_node] = copy(u0)
tspan = (0.0, tf)
prob = DiscreteProblem(starting_state, tspan, [])
prob = DiscreteProblem(starting_state, tspan)

hopping_rate = diffusivity * (linear_size / domain_size)^2
hopping_constants = [hopping_rate for i in starting_state]
Expand Down Expand Up @@ -162,7 +162,7 @@ for ci in CartesianIndices(hopping_constants)
end
starting_state = 25 * ones(Int, length(u0), num_nodes)
tspan = (0.0, 10.0)
prob = DiscreteProblem(starting_state, tspan, [])
prob = DiscreteProblem(starting_state, tspan)

jp = JumpProblem(prob, NSM(), majumps, hopping_constants = hopping_constants,
spatial_system = grid, save_positions = (false, false), rng = rng)
Expand Down
16 changes: 10 additions & 6 deletions test/spatial/spatial_majump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ for spatial_jump_prob in uniform_jump_problems
end

# birth and death zero outside of center site
f(u, p, t) = L * u - diagm([0.0, 0.0, death_rate, 0.0, 0.0]) * u + [0.0, 0.0, 1.0, 0.0, 0.0]
ode_prob = ODEProblem(f, zeros(num_nodes), tspan)
function f2(u, p, t)
L * u - diagm([0.0, 0.0, death_rate, 0.0, 0.0]) * u + [0.0, 0.0, 1.0, 0.0, 0.0]
end
ode_prob = ODEProblem(f2, zeros(num_nodes), tspan)
sol = solve(ode_prob, Tsit5())

solution = solve(non_uniform_jump_problems[1], SSAStepper())
Expand All @@ -129,8 +131,8 @@ for (i, d) in enumerate(diff)
end

# birth everywhere, death only at center site
f(u, p, t) = L * u - diagm([0.0, 0.0, death_rate, 0.0, 0.0]) * u + ones(num_nodes)
ode_prob = ODEProblem(f, zeros(num_nodes), tspan)
f3(u, p, t) = L * u - diagm([0.0, 0.0, death_rate, 0.0, 0.0]) * u + ones(num_nodes)
ode_prob = ODEProblem(f3, zeros(num_nodes), tspan)
sol = solve(ode_prob, Tsit5())

solution = solve(non_uniform_jump_problems[2], SSAStepper())
Expand All @@ -142,8 +144,10 @@ for (i, d) in enumerate(diff)
end

# birth on left end, death on right end
f(u, p, t) = L * u - diagm([0.0, 0.0, 0.0, 0.0, death_rate]) * u + [1.0, 0.0, 0.0, 0.0, 0.0]
ode_prob = ODEProblem(f, zeros(num_nodes), tspan)
function f4(u, p, t)
L * u - diagm([0.0, 0.0, 0.0, 0.0, death_rate]) * u + [1.0, 0.0, 0.0, 0.0, 0.0]
end
ode_prob = ODEProblem(f4, zeros(num_nodes), tspan)
sol = solve(ode_prob, Tsit5())

solution = solve(non_uniform_jump_problems[3], SSAStepper())
Expand Down
22 changes: 11 additions & 11 deletions test/variable_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ sol = solve(jump_prob, Rosenbrock23())
# @show sol[end]
# display(sol[end])

@test maximum([sol[i][2] for i in 1:length(sol)]) <= 1e-12
@test maximum([sol[i][3] for i in 1:length(sol)]) <= 1e-12
@test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1e-12
@test maximum([sol.u[i][3] for i in 1:length(sol)]) <= 1e-12

g = function (du, u, p, t)
du[1] = u[1]
Expand All @@ -52,8 +52,8 @@ jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng)

sol = solve(jump_prob, SRIW1())

@test maximum([sol[i][2] for i in 1:length(sol)]) <= 1e-12
@test maximum([sol[i][3] for i in 1:length(sol)]) <= 1e-12
@test maximum([sol.u[i][2] for i in 1:length(sol)]) <= 1e-12
@test maximum([sol.u[i][3] for i in 1:length(sol)]) <= 1e-12

function ff(du, u, p, t)
if p == 0
Expand Down Expand Up @@ -95,16 +95,16 @@ jump = ConstantRateJump(rate2, affect2!)
jump_prob = JumpProblem(prob, Direct(), jump; rng = rng)
sol = solve(jump_prob, Tsit5())
sol(4.0)
sol[4]
sol.u[4]

rate2(u, p, t) = u[1]
rate2b(u, p, t) = u[1]
affect2!(integrator) = (integrator.u[1] = integrator.u[1] / 2)
jump = VariableRateJump(rate2, affect2!)
jump = VariableRateJump(rate2b, affect2!)
jump2 = deepcopy(jump)
jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng)
sol = solve(jump_prob, Tsit5())
sol(4.0)
sol[4]
sol.u[4]

function g2(du, u, p, t)
du[1] = u[1]
Expand All @@ -114,7 +114,7 @@ prob = SDEProblem(f2, g2, [0.2], (0.0, 10.0))
jump_prob = JumpProblem(prob, Direct(), jump, jump2; rng = rng)
sol = solve(jump_prob, SRIW1())
sol(4.0)
sol[4]
sol.u[4]

function f3(du, u, p, t)
du .= u
Expand Down Expand Up @@ -151,7 +151,7 @@ function drift(x, p, t)
return p * x
end

function rate2(x, p, t)
function rate2c(x, p, t)
return 3 * max(0.0, x[1])
end

Expand All @@ -160,7 +160,7 @@ function affect!2(integrator)
end
x0 = rand(2)
prob = ODEProblem(drift, x0, (0.0, 10.0), 2.0)
jump = VariableRateJump(rate2, affect!2)
jump = VariableRateJump(rate2c, affect!2)
jump_prob = JumpProblem(prob, Direct(), jump)

# test to check lack of dependency graphs is caught in Coevolve for systems with non-maj
Expand Down

0 comments on commit b131adf

Please sign in to comment.