Skip to content

Commit

Permalink
fix: fix incorrect indexes of array symbolics
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Apr 16, 2024
1 parent 709148e commit 745e889
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
20 changes: 18 additions & 2 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ struct ParameterIndex{P, I}
end

const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
const UnknownIndexMap = Dict{Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}}}
const UnknownIndexMap = Dict{
Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, Array{Int}}}

struct IndexCache
unknown_idx::UnknownIndexMap
Expand Down Expand Up @@ -46,11 +47,26 @@ function IndexCache(sys::AbstractSystem)
end
unk_idxs[usym] = sym_idx

if hasname(sym)
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
unk_idxs[getname(usym)] = sym_idx
end
idx += length(sym)
end
for sym in unks
usym = unwrap(sym)
istree(sym) && operation(sym) === getindex || continue
arrsym = arguments(sym)[1]
all(haskey(unk_idxs, arrsym[i]) for i in eachindex(arrsym)) || continue

Check warning on line 59 in src/systems/index_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/index_cache.jl#L58-L59

Added lines #L58 - L59 were not covered by tests

idxs = [unk_idxs[arrsym[i]] for i in eachindex(arrsym)]
if idxs == idxs[begin]:idxs[end]
idxs = idxs[begin]:idxs[end]

Check warning on line 63 in src/systems/index_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/index_cache.jl#L61-L63

Added lines #L61 - L63 were not covered by tests
end
unk_idxs[arrsym] = idxs
if hasname(arrsym)
unk_idxs[getname(arrsym)] = idxs

Check warning on line 67 in src/systems/index_cache.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/index_cache.jl#L65-L67

Added lines #L65 - L67 were not covered by tests
end
end
end

disc_buffers = Dict{Any, Set{BasicSymbolic}}()
Expand Down
13 changes: 13 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1120,3 +1120,16 @@ tearing_state = TearingState(expand_connections(sys))
ts_vars = tearing_state.fullvars
orig_vars = unknowns(sys)
@test isempty(setdiff(ts_vars, orig_vars))

# Ensure indexes of array symbolics are cached appropriately
@variables x(t)[1:2]
@named sys = ODESystem(Equation[], t, [x], [])
sys1 = complete(sys)
@named sys = ODESystem(Equation[], t, [x...], [])
sys2 = complete(sys)
for sys in [sys1, sys2]
for (sym, idx) in [(x, 1:2), (x[1], 1), (x[2], 2)]
@test is_variable(sys, sym)
@test variable_index(sys, sym) == idx
end
end

0 comments on commit 745e889

Please sign in to comment.