Skip to content

Commit

Permalink
fix: fix incorrect indexes of array symbolics
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Apr 16, 2024
1 parent 7bc758b commit e878da1
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct ParameterIndex{P, I}
end

const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
const UnknownIndexMap = Dict{Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}}}
const UnknownIndexMap = Dict{Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, Vector{Int}}}

struct IndexCache
unknown_idx::UnknownIndexMap
Expand Down Expand Up @@ -51,6 +51,19 @@ function IndexCache(sys::AbstractSystem)
end
idx += length(sym)
end
for sym in unks
usym = unwrap(sym)
istree(sym) && operation(sym) === getindex || continue
arrsym = arguments(sym)[1]
idxs = [unk_idxs[arrsym[i]] for i in eachindex(arrsym)]
if idxs == idxs[begin]:idxs[end]
idxs = idxs[begin]:idxs[end]

Check warning on line 60 in src/systems/index_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/index_cache.jl#L54-L60

Added lines #L54 - L60 were not covered by tests
end
unk_idxs[arrsym] = idxs
if hasname(arrsym)
unk_idxs[getname(arrsym)] = idxs

Check warning on line 64 in src/systems/index_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/index_cache.jl#L62-L64

Added lines #L62 - L64 were not covered by tests
end
end

Check warning on line 66 in src/systems/index_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/index_cache.jl#L66

Added line #L66 was not covered by tests
end

disc_buffers = Dict{Any, Set{BasicSymbolic}}()
Expand Down
13 changes: 13 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1066,3 +1066,16 @@ prob = SteadyStateProblem(sys, u0, p)
@test prob isa SteadyStateProblem
prob = SteadyStateProblem(ODEProblem(sys, u0, (0.0, 10.0), p))
@test prob isa SteadyStateProblem

# Ensure indexes of array symbolics are cached appropriately
@variables x(t)[1:2]
@named sys = ODESystem(Equation[], t, [x], [])
sys1 = complete(sys)
@named sys = ODESystem(Equation[], t, [x...], [])
sys2 = complete(sys)
for sys in [sys1, sys2]
for (sym, idx) in [(x, 1:2), (x[1], 1), (x[2], 2)]
@test is_variable(sys, sym)
@test variable_index(sys, sym) == idx
end
end

0 comments on commit e878da1

Please sign in to comment.