diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 79fbbfb16b..8cd96a8f97 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -669,7 +669,7 @@ function SciMLBase.remake_initialization_data( end promote_type_with_nothing(::Type{T}, ::Nothing) where {T} = T -promote_type_with_nothing(::Type{T}, ::SizedVector{0}) where {T} = T +promote_type_with_nothing(::Type{T}, ::StaticVector{0}) where {T} = T function promote_type_with_nothing(::Type{T}, ::AbstractArray{T2}) where {T, T2} promote_type(T, T2) end @@ -678,7 +678,7 @@ function promote_type_with_nothing(::Type{T}, p::MTKParameters) where {T} end promote_with_nothing(::Type, ::Nothing) = nothing -promote_with_nothing(::Type, x::SizedVector{0}) = x +promote_with_nothing(::Type, x::StaticVector{0}) = x promote_with_nothing(::Type{T}, x::AbstractArray{T}) where {T} = x function promote_with_nothing(::Type{T}, x::AbstractArray{T2}) where {T, T2} if ArrayInterface.ismutable(x) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 2f802da6bb..a2421b5646 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -10,6 +10,26 @@ struct MTKParameters{T, I, D, C, N, H} constant::C nonnumeric::N caches::H + + function MTKParameters{T, I, D, C, N, H}(tunables::T, initials::I, discrete::D, + constant::C, nonnumeric::N, + caches::H) where {T, I, D, C, N, H} + if tunables isa StaticVector{0} + tunables = SVector{0, eltype(tunables)}() + end + if initials isa StaticVector{0} + initials = SVector{0, eltype(initials)}() + end + return new{typeof(tunables), typeof(initials), D, C, N, H}(tunables, initials, + discrete, constant, + nonnumeric, caches) + end + function MTKParameters(tunables::T, initials::I, discrete::D, + constant::C, nonnumeric::N, + caches::H) where {T, I, D, C, N, H} + return MTKParameters{T, I, D, C, N, H}(tunables, initials, discrete, constant, + nonnumeric, caches) + end end """ @@ -138,11 +158,11 @@ function MTKParameters( end tunable_buffer = narrow_buffer_type(tunable_buffer; p_constructor) if isempty(tunable_buffer) - tunable_buffer = SizedVector{0, Float64}() + tunable_buffer = SVector{0, Float64}() end initials_buffer = narrow_buffer_type(initials_buffer; p_constructor) if isempty(initials_buffer) - initials_buffer = SizedVector{0, Float64}() + initials_buffer = SVector{0, Float64}() end disc_buffer = narrow_buffer_type.(disc_buffer; p_constructor) const_buffer = narrow_buffer_type.(const_buffer; p_constructor) @@ -879,10 +899,10 @@ end @generated function Base.getindex( ps::MTKParameters{T, I, D, C, N, H}, idx::Int) where {T, I, D, C, N, H} paths = [] - if !(T <: SizedVector{0}) + if !(T <: StaticVector{0}) push!(paths, :(ps.tunable)) end - if !(I <: SizedVector{0}) + if !(I <: StaticVector{0}) push!(paths, :(ps.initials)) end for i in 1:fieldcount(D) @@ -909,10 +929,10 @@ end @generated function Base.length(ps::MTKParameters{ T, I, D, C, N, H}) where {T, I, D, C, N, H} len = 0 - if !(T <: SizedVector{0}) + if !(T <: SVector{0}) len += 1 end - if !(I <: SizedVector{0}) + if !(I <: SVector{0}) len += 1 end len += fieldcount(D) + fieldcount(C) + fieldcount(N) + fieldcount(H) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index fd1662be87..ab7735e05f 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -780,7 +780,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac # `syms[1]` is always the tunables because `srcsys` will have initials. tunable_syms = syms[1] tunable_getter = if isempty(tunable_syms) - Returns(SizedVector{0, Float64}()) + Returns(SVector{0, Float64}()) else p_constructor ∘ concrete_getu(srcsys, tunable_syms; eval_expression, eval_module) end @@ -803,7 +803,7 @@ function get_mtkparameters_reconstructor(srcsys::AbstractSystem, dstsys::Abstrac end p_constructor ∘ concrete_getu(srcsys, initsyms; eval_expression, eval_module) else - Returns(SizedVector{0, Float64}()) + Returns(SVector{0, Float64}()) end discs_getter = if isempty(syms[3]) Returns(()) @@ -923,7 +923,7 @@ function (rip::ReconstructInitializeprob)(srcvalp, dstvalp) if newp isa MTKParameters # and initials portion buf, repack, alias = SciMLStructures.canonicalize(SciMLStructures.Initials(), newp) - if eltype(buf) != T + if eltype(buf) != T && !(buf isa SVector{0}) newbuf = similar(buf, T) copyto!(newbuf, buf) newp = repack(newbuf) @@ -1148,8 +1148,10 @@ function maybe_build_initialization_problem( if is_split(sys) buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), initp) initp = repack(floatT.(buffer)) - buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Initials(), initp) - initp = repack(floatT.(buffer)) + if !(initp.initials isa StaticVector{0}) + buffer, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Initials(), initp) + initp = repack(floatT.(buffer)) + end elseif initp isa AbstractArray if ArrayInterface.ismutable(initp) initp′ = similar(initp, floatT) diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 966a087d9c..566e309378 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -429,3 +429,13 @@ end grad = ForwardDiff.gradient(Base.Fix2(loss, (setter, prob)), [3.0]) @test grad ≈ [0.14882627068752538] atol=1e-10 end + +@testset "MTKParameters can be made `isbits`" begin + @variables x(t) + @parameters p + @named sys = System(D(x) ~ x * p, t) + sys = complete(sys) + prob = ODEProblem(sys, SA[x => 1.0, p => 1.0], (0.0, 1.0)) + @test isbits(prob.p) + @test isbits(prob.f.initialization_data.initializeprob.p) +end diff --git a/test/split_parameters.jl b/test/split_parameters.jl index a92d9f8350..e0bd7328d9 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -6,7 +6,6 @@ using BlockArrays: BlockedArray using ModelingToolkit: t_nounits as t, D_nounits as D using ModelingToolkit: MTKParameters, ParameterIndex, NONNUMERIC_PORTION using SciMLStructures: Tunable, Discrete, Constants, Initials -using StaticArrays: SizedVector using SymbolicIndexingInterface: is_parameter, getp x = [1, 2.0, false, [1, 2, 3], Parameter(1.0)]