Skip to content

Commit

Permalink
fix problem symbolic indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
isaacsas committed Mar 20, 2024
1 parent 7a279c6 commit dbd37f6
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 19 deletions.
18 changes: 10 additions & 8 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
Expand All @@ -28,18 +29,19 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
JumpProcessFastBroadcastExt = "FastBroadcast"

[compat]
ArrayInterface = "7"
ArrayInterface = "7.9"
DataStructures = "0.18"
DiffEqBase = "6.122"
DiffEqBase = "6.148"
DocStringExtensions = "0.9"
FunctionWrappers = "1.0"
Graphs = "1.4"
FunctionWrappers = "1.1"
Graphs = "1.9"
PoissonRandom = "0.4"
RandomNumbers = "1.3"
RecursiveArrayTools = "3"
RandomNumbers = "1.5"
RecursiveArrayTools = "3.12"
Reexport = "1.0"
SciMLBase = "1.51, 2"
StaticArrays = "1.0"
SciMLBase = "2.30.1"
StaticArrays = "1.9"
SymbolicIndexingInterface = "0.3.11"
UnPack = "1.0.2"
julia = "1.10"

Expand Down
1 change: 1 addition & 0 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Graphs: neighbors, outdegree

import RecursiveArrayTools: recursivecopy!
using StaticArrays, Base.Threads
import SymbolicIndexingInterface as SII

abstract type AbstractJump end
abstract type AbstractMassActionJump <: AbstractJump end
Expand Down
13 changes: 8 additions & 5 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,15 @@ end
# when setindex! is used.
function Base.setindex!(prob::JumpProblem, args...; kwargs...)
SciMLBase.___internal_setindex!(prob.prob, args...; kwargs...)
end

# for updating parameters in JumpProblems to update MassActionJumps
function SII.set_parameter!(prob::JumpProblem, val, idx)
ans = SII.set_parameter!(SII.parameter_values(prob), val, idx)

# since parameters are no longer allowed to be mutated and the preceding will error this
# isn't needed
# if using_params(prob.massaction_jump)
# update_parameters!(prob.massaction_jump, prob.prob.p)
# end
if using_params(prob.massaction_jump)
update_parameters!(prob.massaction_jump, prob.prob.p)
end
end

# when getindex is used.
Expand Down
27 changes: 21 additions & 6 deletions test/jprob_symbol_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,36 @@
# prepares the problem
using JumpProcesses, Test
using JumpProcesses, Test, SymbolicIndexingInterface
rate1(u, p, t) = p[1]
rate2(u, p, t) = p[2]
affect1!(integ) = (integ.u[1] += 1)
affect2!(integ) = (integ.u[2] += 1)
crj1 = ConstantRateJump(rate1, affect1!)
crj2 = ConstantRateJump(rate2, affect2!)
g = DiscreteFunction((du, u, p, t) -> nothing; sys = SymbolicIndexingInterface.SymbolCache([:a, :b], [:p1, :p2], :t))
maj = MassActionJump([[1 => 1], [1 => 1]], [[1 => -1], [1 => -1]]; param_idxs = [1,2])
g = DiscreteFunction((du, u, p, t) -> nothing;
sys = SymbolicIndexingInterface.SymbolCache([:a, :b], [:p1, :p2], :t))
dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0])
jprob = JumpProblem(dprob, Direct(), crj1, crj2)
jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj)

# runs the tests
# test basic querying of u0 and p
@test jprob[:a] == 0
@test jprob[:b] == 10

# these are no longer supported by SciMLBase
@test getp(jprob,:p1)(jprob) == 1.0
@test getp(jprob,:p2)(jprob) == 2.0
@test jprob.ps[:p1] == 1.0
@test jprob.ps[:p2] == 2.0

# test updating u0
jprob[:a] = 20
@test jprob[:a] == 20

# test mass action jumps update with parameter mutation in problems
@test jprob.massaction_jump.scaled_rates[1] == 1.0
jprob.ps[:p1] = 3.0
@test jprob.ps[:p1] == 3.0
@test jprob.massaction_jump.scaled_rates[1] == 3.0
p1setter = setp(jprob, [:p1, :p2])
p1setter(jprob, [4.0, 10.0])
@test jprob.ps[:p1] == 4.0
@test jprob.ps[:p2] == 10.0
@test jprob.massaction_jump.scaled_rates == [4.0, 10.0]

0 comments on commit dbd37f6

Please sign in to comment.