diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 348b7b09ec..c6ce34123e 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -96,7 +96,9 @@ function AffectSystem(affect::Vector{Equation}; discrete_parameters = Any[], collect_vars!(dvs, params, eq, iv) end pre_params = filter(haspre ∘ value, params) + discrete_parameters = gather_array_params(OrderedSet(discrete_parameters)) sys_params = collect(setdiff(params, union(discrete_parameters, pre_params))) + discrete_parameters = collect(discrete_parameters) discretes = map(tovar, discrete_parameters) dvs = collect(dvs) _dvs = map(default_toterm, dvs) @@ -904,7 +906,12 @@ function compile_equational_affect( obseqs, Dict([p => unPre(p) for p in parameters(affsys)])) rhss = map(x -> x.rhs, update_eqs) lhss = map(x -> x.lhs, update_eqs) - is_p = [lhs in Set(ps_to_update) for lhs in lhss] + update_ps_set = Set(ps_to_update) + is_p = map(lhss) do lhs + lhs in update_ps_set || + iscall(lhs) && operation(lhs) === getindex && + arguments(lhs)[1] in update_ps_set + end is_u = [lhs in Set(dvs_to_update) for lhs in lhss] dvs = unknowns(sys) ps = parameters(sys) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 19c78413cf..3592f8e605 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -131,8 +131,14 @@ function IndexCache(sys::AbstractSystem) end for sym in discs - is_parameter(sys, sym) || - error("Expected discrete variable $sym in callback to be a parameter") + if !is_parameter(sys, sym) + if iscall(sym) && operation(sym) === getindex && + is_parameter(sys, arguments(sym)[1]) + sym = arguments(sym)[1] + else + error("Expected discrete variable $sym in callback to be a parameter") + end + end # Only `foo(t)`-esque parameters can be saved if iscall(sym) && length(arguments(sym)) == 1 && diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 9f49950ddc..2f802da6bb 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -805,6 +805,10 @@ function Base.getindex(ngi::NestedGetIndex, idx::Tuple) i, j, k... = idx return ngi.x[i][j][k...] end +function Base.getindex(ngi::NestedGetIndex, idx::NTuple{2}) + i, j = idx + return ngi.x[i][j] +end # Required for DiffEqArray constructor to work during interpolation Base.size(::NestedGetIndex) = () @@ -826,19 +830,27 @@ function SciMLBase.create_parameter_timeseries_collection( isempty(ps.discrete) && return nothing num_discretes = only(blocksize(ps.discrete[1])) buffers = [] - partition_type = Tuple{(typeof(parent(buf)) for buf in ps.discrete)...} + partition_type = typeof(SciMLBase.get_saveable_values(sys, ps, 1)) for i in 1:num_discretes ts = eltype(tspan)[] - us = NestedGetIndex{partition_type}[] + us = partition_type[] push!(buffers, DiffEqArray(us, ts, (1, 1))) end return ParameterTimeseriesCollection(Tuple(buffers), copy(ps)) end +@inline __get_blocks(tsidx::Int) = () +@inline function __get_blocks(tsidx::Int, buffer::BlockedArray, buffers...) + (buffer[Block(tsidx)], __get_blocks(tsidx, buffers...)...) +end +@inline function __get_blocks(tsidx::Int, buffer::BlockedArray{<:AbstractArray}, buffers...) + (copy.(buffer[Block(tsidx)]), __get_blocks(tsidx, buffers...)...) +end + function SciMLBase.get_saveable_values( sys::AbstractSystem, ps::MTKParameters, timeseries_idx) - return NestedGetIndex(Tuple(buffer[Block(timeseries_idx)] for buffer in ps.discrete)) + return NestedGetIndex(__get_blocks(timeseries_idx, ps.discrete...)) end function save_callback_discretes!(integ::SciMLBase.DEIntegrator, callback) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index acf21bc361..b18ddaa73e 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1455,3 +1455,83 @@ end Pre(X) + 10.0]) end + +@testset "Issue#3990: Scalarized array passed to `discrete_parameters` of symbolic affect" begin + N = 2 + @parameters v(t)[1:N] + @parameters M(t)[1:N, 1:N] + + @variables x(t) + + Mini = rand(N, N) ./ (N^2) + vini = vec(sum(Mini, dims = 1)) + + v_eq = [D(x) ~ x * Symbolics.scalarize(sum(v))] + M_eq = [D(x) ~ x * Symbolics.scalarize(sum(M))] + + v_event = ModelingToolkit.SymbolicDiscreteCallback( + 1.0, + [v ~ -Pre(v)], + discrete_parameters = [v] + ) + + M_event = ModelingToolkit.SymbolicDiscreteCallback( + 1.0, + [M ~ -Pre(M)], + discrete_parameters = [M] + ) + + @mtkcompile v_sys = System(v_eq, t; discrete_events = v_event) + @mtkcompile M_sys = System(M_eq, t; discrete_events = M_event) + + u0p0_map = Dict(x => 1.0, M => Mini, v => vini) + + v_prob = ODEProblem(v_sys, u0p0_map, (0.0, 2.5)) + M_prob = ODEProblem(M_sys, u0p0_map, (0.0, 2.5)) + + v_sol = solve(v_prob, Tsit5()) + M_sol = solve(M_prob, Tsit5()) + + @test v_sol[v] ≈ [vini, -vini, vini] + @test M_sol[M] ≈ [Mini, -Mini, Mini] +end + +@testset "Issue#3990: Scalarized array passed to `discrete_parameters` of symbolic affect" begin + N = 2 + @parameters v(t)[1:N] + @parameters M(t)[1:N, 1:N] + + @variables x(t) + + Mini = rand(N, N) ./ (N^2) + vini = vec(sum(Mini, dims = 1)) + + v_eq = [D(x) ~ x * Symbolics.scalarize(sum(v))] + M_eq = [D(x) ~ x * Symbolics.scalarize(sum(M))] + + v_event = ModelingToolkit.SymbolicDiscreteCallback( + 1.0, + [v ~ -Pre(v)], + discrete_parameters = collect(v) + ) + + M_event = ModelingToolkit.SymbolicDiscreteCallback( + 1.0, + [M ~ -Pre(M)], + discrete_parameters = vec(collect(M)) + ) + + @mtkcompile v_sys = System(v_eq, t; discrete_events = v_event) + @mtkcompile M_sys = System(M_eq, t; discrete_events = M_event) + + u0p0_map = Dict(x => 1.0, M => Mini, v => vini) + + v_prob = ODEProblem(v_sys, u0p0_map, (0.0, 2.5)) + M_prob = ODEProblem(M_sys, u0p0_map, (0.0, 2.5)) + + v_sol = solve(v_prob, Tsit5()) + M_sol = solve(M_prob, Tsit5()) + + @test v_sol[v] ≈ [vini, -vini, vini] + @test M_sol[M] ≈ [Mini, -Mini, Mini] +end