Skip to content

Commit

Permalink
Merge pull request #2647 from AayushSabharwal/as/fix-index-caching
Browse files Browse the repository at this point in the history
fix: fix incorrect indexes of array symbolics
  • Loading branch information
ChrisRackauckas authored Apr 17, 2024
2 parents 709148e + c9498e2 commit 0db9cd2
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 4 deletions.
24 changes: 20 additions & 4 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ struct ParameterIndex{P, I}
end

const ParamIndexMap = Dict{Union{Symbol, BasicSymbolic}, Tuple{Int, Int}}
const UnknownIndexMap = Dict{Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}}}
const UnknownIndexMap = Dict{
Union{Symbol, BasicSymbolic}, Union{Int, UnitRange{Int}, AbstractArray{Int}}}

struct IndexCache
unknown_idx::UnknownIndexMap
Expand All @@ -40,17 +41,32 @@ function IndexCache(sys::AbstractSystem)
for sym in unks
usym = unwrap(sym)
sym_idx = if Symbolics.isarraysymbolic(sym)
idx:(idx + length(sym) - 1)
reshape(idx:(idx + length(sym) - 1), size(sym))
else
idx
end
unk_idxs[usym] = sym_idx

if hasname(sym)
if hasname(sym) && (!istree(sym) || operation(sym) !== getindex)
unk_idxs[getname(usym)] = sym_idx
end
idx += length(sym)
end
for sym in unks
usym = unwrap(sym)
istree(sym) && operation(sym) === getindex || continue
arrsym = arguments(sym)[1]
all(haskey(unk_idxs, arrsym[i]) for i in eachindex(arrsym)) || continue

idxs = [unk_idxs[arrsym[i]] for i in eachindex(arrsym)]
if idxs == idxs[begin]:idxs[end]
idxs = reshape(idxs[begin]:idxs[end], size(idxs))
end
unk_idxs[arrsym] = idxs
if hasname(arrsym)
unk_idxs[getname(arrsym)] = idxs
end
end
end

disc_buffers = Dict{Any, Set{BasicSymbolic}}()
Expand Down Expand Up @@ -124,7 +140,7 @@ function IndexCache(sys::AbstractSystem)
for (j, p) in enumerate(buf)
idxs[p] = (i, j)
idxs[default_toterm(p)] = (i, j)
if hasname(p)
if hasname(p) && (!istree(p) || operation(p) !== getindex)
idxs[getname(p)] = (i, j)
idxs[getname(default_toterm(p))] = (i, j)
end
Expand Down
27 changes: 27 additions & 0 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1120,3 +1120,30 @@ tearing_state = TearingState(expand_connections(sys))
ts_vars = tearing_state.fullvars
orig_vars = unknowns(sys)
@test isempty(setdiff(ts_vars, orig_vars))

# Ensure indexes of array symbolics are cached appropriately
@variables x(t)[1:2]
@named sys = ODESystem(Equation[], t, [x], [])
sys1 = complete(sys)
@named sys = ODESystem(Equation[], t, [x...], [])
sys2 = complete(sys)
for sys in [sys1, sys2]
for (sym, idx) in [(x, 1:2), (x[1], 1), (x[2], 2)]
@test is_variable(sys, sym)
@test variable_index(sys, sym) == idx
end
end

@variables x(t)[1:2, 1:2]
@named sys = ODESystem(Equation[], t, [x], [])
sys1 = complete(sys)
@named sys = ODESystem(Equation[], t, [x...], [])
sys2 = complete(sys)
for sys in [sys1, sys2]
@test is_variable(sys, x)
@test variable_index(sys, x) == [1 3; 2 4]
for i in eachindex(x)
@test is_variable(sys, x[i])
@test variable_index(sys, x[i]) == variable_index(sys, x)[i]
end
end

0 comments on commit 0db9cd2

Please sign in to comment.