Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Indexing fixes #408

Merged
merged 9 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ jobs:
- Core
version:
- '1'
- '1.6'
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
26 changes: 14 additions & 12 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"

[weakdeps]
Expand All @@ -28,20 +29,21 @@ FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
JumpProcessFastBroadcastExt = "FastBroadcast"

[compat]
ArrayInterface = "6, 7"
DataStructures = "0.17, 0.18"
DiffEqBase = "6.122"
DocStringExtensions = "0.8.6, 0.9"
FunctionWrappers = "1.0"
Graphs = "1.4"
ArrayInterface = "7.9"
DataStructures = "0.18"
DiffEqBase = "6.148"
DocStringExtensions = "0.9"
FunctionWrappers = "1.1"
Graphs = "1.9"
PoissonRandom = "0.4"
RandomNumbers = "1.3"
RecursiveArrayTools = "2, 3"
Reexport = "0.2, 1.0"
SciMLBase = "1.51, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
RandomNumbers = "1.5"
RecursiveArrayTools = "3.12"
Reexport = "1.0"
SciMLBase = "2.30.1"
StaticArrays = "1.9"
SymbolicIndexingInterface = "0.3.11"
UnPack = "1.0.2"
julia = "1.6"
julia = "1.10"

[extras]
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Expand Down
1 change: 1 addition & 0 deletions src/JumpProcesses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Graphs: neighbors, outdegree

import RecursiveArrayTools: recursivecopy!
using StaticArrays, Base.Threads
import SymbolicIndexingInterface as SII

abstract type AbstractJump end
abstract type AbstractMassActionJump <: AbstractJump end
Expand Down
2 changes: 1 addition & 1 deletion src/aggregators/coevolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ end

# executing jump at the next jump time
function (p::CoevolveJumpAggregation)(integrator::I) where {I <:
AbstractSSAIntegrator}
AbstractSSAIntegrator}
if !accept_next_jump!(p, integrator, integrator.u, integrator.p, integrator.t)
return nothing
end
Expand Down
3 changes: 1 addition & 2 deletions src/aggregators/ssajump.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ end
end

@inline function concretize_affects!(p::AbstractSSAJumpAggregator{T, S, F1, F2},
::I) where {T, S, F1, F2 <: Tuple,
I <: DiffEqBase.DEIntegrator}
::I) where {T, S, F1, F2 <: Tuple, I <: DiffEqBase.DEIntegrator}
nothing
end

Expand Down
24 changes: 14 additions & 10 deletions src/problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,17 @@
# when setindex! is used.
function Base.setindex!(prob::JumpProblem, args...; kwargs...)
SciMLBase.___internal_setindex!(prob.prob, args...; kwargs...)
end

# for updating parameters in JumpProblems to update MassActionJumps
function SII.set_parameter!(prob::JumpProblem, val, idx)
ans = SII.set_parameter!(SII.parameter_values(prob), val, idx)

if using_params(prob.massaction_jump)
update_parameters!(prob.massaction_jump, prob.prob.p)
end

ans

Check warning on line 135 in src/problem.jl

View check run for this annotation

Codecov / codecov/patch

src/problem.jl#L135

Added line #L135 was not covered by tests
Comment on lines +128 to +135
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this is technically correct, set_parameter! will be called once for each parameter which would make this very expensive. I'll PR to SII to add a callback that runs at the end of setp (finalize_parameters_hook!(prob, ps)?)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. I guess we could also use that for calling reset_aggregated_jumps! when integrators get updated in callbacks, which means one less thing for users to have to handle.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AayushSabharwal please ping me when a SII release with that callback is available so I can get this finished -- we need this for making updates to Catalyst for MTK9. Thanks!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's been released now @isaacsas

Copy link
Member

@AayushSabharwal AayushSabharwal Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, I thought it was merged and would be available. Sorry 😅

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now it's out in 0.3.13

end

# when getindex is used.
Expand All @@ -151,14 +159,12 @@
JumpProblem(prob, JumpSet(jumps...); kwargs...)
end

function JumpProblem(
prob, aggregator::AbstractAggregatorAlgorithm, jumps::ConstantRateJump;
kwargs...)
function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm,
jumps::ConstantRateJump; kwargs...)
JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...)
end
function JumpProblem(
prob, aggregator::AbstractAggregatorAlgorithm, jumps::VariableRateJump;
kwargs...)
function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm,
jumps::VariableRateJump; kwargs...)
JumpProblem(prob, aggregator, JumpSet(jumps); kwargs...)
end
function JumpProblem(prob, aggregator::AbstractAggregatorAlgorithm, jumps::RegularJump;
Expand Down Expand Up @@ -321,8 +327,7 @@
end

ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0,
[-randexp(rng, ttype) for i in 1:length(jumps)])
u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:length(jumps)])
remake(prob, f = SDEFunction{isinplace(prob)}(jump_f, jump_g), g = jump_g, u0 = u0)
end

Expand All @@ -347,8 +352,7 @@
end

ttype = eltype(prob.tspan)
u0 = ExtendedJumpArray(prob.u0,
[-randexp(rng, ttype) for i in 1:length(jumps)])
u0 = ExtendedJumpArray(prob.u0, [-randexp(rng, ttype) for i in 1:length(jumps)])

Check warning on line 355 in src/problem.jl

View check run for this annotation

Codecov / codecov/patch

src/problem.jl#L355

Added line #L355 was not covered by tests
remake(prob, f = DDEFunction{isinplace(prob)}(jump_f), u0 = u0)
end

Expand Down
31 changes: 24 additions & 7 deletions test/jprob_symbol_indexing.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,36 @@
# prepares the problem
using JumpProcesses, Test
using JumpProcesses, Test, SymbolicIndexingInterface
rate1(u, p, t) = p[1]
rate2(u, p, t) = p[2]
affect1!(integ) = (integ.u[1] += 1)
affect2!(integ) = (integ.u[2] += 1)
crj1 = ConstantRateJump(rate1, affect1!)
crj2 = ConstantRateJump(rate2, affect2!)
g = DiscreteFunction((du, u, p, t) -> nothing; syms = [:a, :b], paramsyms = [:p1, :p2])
maj = MassActionJump([[1 => 1], [1 => 1]], [[1 => -1], [1 => -1]]; param_idxs = [1,2])
g = DiscreteFunction((du, u, p, t) -> nothing;
sys = SymbolicIndexingInterface.SymbolCache([:a, :b], [:p1, :p2], :t))
dprob = DiscreteProblem(g, [0, 10], (0.0, 10.0), [1.0, 2.0])
jprob = JumpProblem(dprob, Direct(), crj1, crj2)
jprob = JumpProblem(dprob, Direct(), crj1, crj2, maj)

# runs the tests
# test basic querying of u0 and p
@test jprob[:a] == 0
@test jprob[:b] == 10
@test jprob[:p1] == 1.0
@test jprob[:p2] == 2.0
@test getp(jprob,:p1)(jprob) == 1.0
@test getp(jprob,:p2)(jprob) == 2.0
@test jprob.ps[:p1] == 1.0
@test jprob.ps[:p2] == 2.0

# tests for setindex (e.g. `jprob[:a] = 10`) not possible, this requires the problem to have a .f.sys filed.,
# test updating u0
jprob[:a] = 20
@test jprob[:a] == 20

# test mass action jumps update with parameter mutation in problems
@test jprob.massaction_jump.scaled_rates[1] == 1.0
jprob.ps[:p1] = 3.0
@test jprob.ps[:p1] == 3.0
@test jprob.massaction_jump.scaled_rates[1] == 3.0
p1setter = setp(jprob, [:p1, :p2])
p1setter(jprob, [4.0, 10.0])
@test jprob.ps[:p1] == 4.0
@test jprob.ps[:p2] == 10.0
@test jprob.massaction_jump.scaled_rates == [4.0, 10.0]
Loading