diff --git a/.github/workflows/Downstream.yml b/.github/workflows/Downstream.yml index ff3f6be9..dc9cf69e 100644 --- a/.github/workflows/Downstream.yml +++ b/.github/workflows/Downstream.yml @@ -20,6 +20,7 @@ jobs: - {user: SciML, repo: RecursiveArrayTools.jl, group: SymbolicIndexingInterface} - {user: JuliaSymbolics, repo: Symbolics.jl, group: SymbolicIndexingInterface} - {user: SciML, repo: SciMLBase.jl, group: SymbolicIndexingInterface} + - {user: SciML, repo: ModelingToolkit.jl, group: SymbolicIndexingInterface} steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v1 diff --git a/docs/src/api.md b/docs/src/api.md index 17e29494..3157ba7e 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -88,4 +88,5 @@ symbolic_evaluate ```@docs SymbolCache +ProblemState ``` diff --git a/docs/src/usage.md b/docs/src/usage.md index 98e5ad6e..693256d4 100644 --- a/docs/src/usage.md +++ b/docs/src/usage.md @@ -22,10 +22,10 @@ Consider the following example: ```@example Usage using ModelingToolkit, OrdinaryDiffEq, SymbolicIndexingInterface, Plots +using ModelingToolkit: t_nounits as t, D_nounits as D @parameters σ ρ β -@variables t x(t) y(t) z(t) w(t) -D = Differential(t) +@variables x(t) y(t) z(t) w(t) eqs = [D(D(x)) ~ σ * (y - x), D(y) ~ x * (ρ - z) - y, @@ -121,6 +121,30 @@ output, the following shorthand is used: sol[allvariables] # equivalent to sol[all_variable_symbols(sol)] ``` +### Evaluating expressions + +`getu` also generates functions for expressions if the object passed to it supports +[`observed`](@ref). For example: + +```@example Usage +getu(prob, x + y + z)(prob) +``` + +To evaluate this function using values other than the ones contained in `prob`, we need +an object that supports [`state_values`](@ref), [`parameter_values`](@ref), +[`current_time`](@ref). SymbolicIndexingInterface provides the [`ProblemState`](@ref) type, +which has trivial implementations of the above functions. We can thus do: + +```@example Usage +temp_state = ProblemState(; u = [0.1, 0.2, 0.3, 0.4], p = parameter_values(prob)) +getu(prob, x + y + z)(temp_state) +``` + +Note that providing all of the state vector, parameter object and time may not be +necessary if the function generated by `observed` does not access them. ModelingToolkit.jl +generates functions that access the parameters regardless of whether they are used in the +expression, and thus it needs to be provided to the `ProblemState`. + ## Parameter Indexing: Getting and Setting Parameter Values Parameters cannot be obtained using this syntax, and instead require using [`getp`](@ref) and [`setp`](@ref). diff --git a/src/SymbolicIndexingInterface.jl b/src/SymbolicIndexingInterface.jl index af90a252..e5f8e9af 100644 --- a/src/SymbolicIndexingInterface.jl +++ b/src/SymbolicIndexingInterface.jl @@ -31,6 +31,9 @@ include("parameter_indexing.jl") export state_values, set_state!, current_time, getu, setu include("state_indexing.jl") +export ProblemState +include("problem_state.jl") + export ParameterIndexingProxy include("parameter_indexing_proxy.jl") diff --git a/src/problem_state.jl b/src/problem_state.jl new file mode 100644 index 00000000..4a2312c8 --- /dev/null +++ b/src/problem_state.jl @@ -0,0 +1,23 @@ +""" + struct ProblemState + function ProblemState(; u = nothing, p = nothing, t = nothing) + +A struct which can be used as an argument to the function returned by [`getu`](@ref) or +[`setu`](@ref). It stores the state vector, parameter object and current time, and +forwards calls to [`state_values`](@ref), [`parameter_values`](@ref), +[`current_time`](@ref), [`set_state!`](@ref), [`set_parameter!`](@ref) to the contained +objects. +""" +struct ProblemState{U, P, T} + u::U + p::P + t::T +end + +ProblemState(; u = nothing, p = nothing, t = nothing) = ProblemState(u, p, t) + +state_values(ps::ProblemState) = ps.u +parameter_values(ps::ProblemState) = ps.p +current_time(ps::ProblemState) = ps.t +set_state!(ps::ProblemState, val, idx) = set_state!(ps.u, val, idx) +set_parameter!(ps::ProblemState, val, idx) = set_parameter!(ps.p, val, idx) diff --git a/test/problem_state_test.jl b/test/problem_state_test.jl new file mode 100644 index 00000000..d0609251 --- /dev/null +++ b/test/problem_state_test.jl @@ -0,0 +1,15 @@ +using SymbolicIndexingInterface +using Test + +sys = SymbolCache([:x, :y, :z], [:a, :b, :c], :t) +prob = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.5) + +for (i, sym) in enumerate(variable_symbols(sys)) + @test getu(sys, sym)(prob) == prob.u[i] +end +for (i, sym) in enumerate(parameter_symbols(sys)) + @test getp(sys, sym)(prob) == prob.p[i] +end +@test getu(sys, :t)(prob) == prob.t + +@test getu(sys, :(x + a + t))(prob) == 1.6 diff --git a/test/runtests.jl b/test/runtests.jl index 4ebb17cd..d334c172 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,3 +26,6 @@ end @safetestset "Remake test" begin @time include("remake_test.jl") end +@safetestset "ProblemState test" begin + @time include("problem_state_test.jl") +end