Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/index_provider_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,24 @@ See also: [`is_time_dependent`](@ref), [`is_markovian`](@ref), [`constant_struct
observed(indp, sym) = observed(symbolic_container(indp), sym)
observed(indp, sym, states) = observed(symbolic_container(indp), sym, states)

"""
supports_tuple_observed(indp)

Check if the given index provider supports generating observed functions for tuples of
symbolic variables. Falls back using `symbolic_container`, and returns `false` by
default.

See also: [`observed`](@ref), [`parameter_observed`](@ref), [`symbolic_container`](@ref).
"""
function supports_tuple_observed(indp)
if hasmethod(symbolic_container, Tuple{typeof(indp)}) &&
(sc = symbolic_container(indp)) !== indp
supports_tuple_observed(sc)
else
false
end
end

"""
is_time_dependent(indp)

Expand Down
7 changes: 5 additions & 2 deletions src/parameter_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -599,12 +599,14 @@ for (t1, t2) in [
# `getp` errors on older MTK that doesn't support `parameter_observed`.
getters = getp.((sys,), p)
num_observed = count(is_observed_getter, getters)
supports_tuple = supports_tuple_observed(sys)
p_arr = p isa Tuple ? collect(p) : p

if num_observed == 0
return MultipleParametersGetter(getters)
else
pofn = parameter_observed(sys, p_arr)
pofn = supports_tuple ? parameter_observed(sys, p) :
parameter_observed(sys, p_arr)
if pofn === nothing
return MultipleParametersGetter.(getters)
end
Expand All @@ -615,7 +617,8 @@ for (t1, t2) in [
else
getter = GetParameterObservedNoTime(pofn)
end
return p isa Tuple ? AsParameterTupleWrapper{length(p)}(getter) : getter
return p isa Tuple && !supports_tuple ?
AsParameterTupleWrapper{length(p)}(getter) : getter
end
end
end
Expand Down
7 changes: 4 additions & 3 deletions src/state_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ for (t1, t2) in [
return MultipleGetters(ContinuousTimeseries(), sym)
end
sym_arr = sym isa Tuple ? collect(sym) : sym
supports_tuple = supports_tuple_observed(sys)
num_observed = 0
for s in sym
num_observed += is_observed(sys, s)
Expand All @@ -261,7 +262,7 @@ for (t1, t2) in [
if num_observed == 0 || num_observed == 1 && sym isa Tuple
return MultipleGetters(nothing, getsym.((sys,), sym))
else
obs = observed(sys, sym_arr)
obs = supports_tuple ? observed(sys, sym) : observed(sys, sym_arr)
getter = TimeIndependentObservedFunction(obs)
if sym isa Tuple
getter = AsTupleWrapper{length(sym)}(getter)
Expand All @@ -283,13 +284,13 @@ for (t1, t2) in [
getters = getsym.((sys,), sym)
return MultipleGetters(ts_idxs, getters)
else
obs = observed(sys, sym_arr)
obs = supports_tuple ? observed(sys, sym) : observed(sys, sym_arr)
getter = if is_time_dependent(sys)
TimeDependentObservedFunction{is_markovian(sys)}(ts_idxs, obs)
else
TimeIndependentObservedFunction(obs)
end
if sym isa Tuple
if sym isa Tuple && !supports_tuple
getter = AsTupleWrapper{length(sym)}(getter)
end
return getter
Expand Down
32 changes: 32 additions & 0 deletions test/state_indexing_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,35 @@ getter = getsym(sys, :(x + y))
@test getter(fi) ≈ 2.8
@test getter(fs) ≈ [3.0i + 2(ts[i] - 0.1) for i in 1:11]
@test getter(fs, 1) ≈ 2.8

struct TupleObservedWrapper{S}
sys::S
end
SymbolicIndexingInterface.symbolic_container(t::TupleObservedWrapper) = t.sys
SymbolicIndexingInterface.supports_tuple_observed(::TupleObservedWrapper) = true

@testset "Tuple observed" begin
sc = SymbolCache([:x, :y, :z], [:a, :b, :c])
sys = TupleObservedWrapper(sc)
ps = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3])
getter = getsym(sys, (:(x + y), :(y + z)))
@test all(getter(ps) .≈ (3.0, 5.0))
@test getter(ps) isa Tuple
@test_nowarn @inferred getter(ps)
getter = getsym(sys, (:(a + b), :(b + c)))
@test all(getter(ps) .≈ (0.3, 0.5))
@test getter(ps) isa Tuple
@test_nowarn @inferred getter(ps)

sc = SymbolCache([:x, :y, :z], [:a, :b, :c], :t)
sys = TupleObservedWrapper(sc)
ps = ProblemState(; u = [1.0, 2.0, 3.0], p = [0.1, 0.2, 0.3], t = 0.1)
getter = getsym(sys, (:(x + y), :(y + t)))
@test all(getter(ps) .≈ (3.0, 2.1))
@test getter(ps) isa Tuple
@test_nowarn @inferred getter(ps)
getter = getsym(sys, (:(a + b), :(b + c)))
@test all(getter(ps) .≈ (0.3, 0.5))
@test getter(ps) isa Tuple
@test_nowarn @inferred getter(ps)
end
Loading