From f33d4e06d85032e97c9adfb7c61e4eb67c2a1200 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Sat, 30 Jan 2021 14:37:21 -0500 Subject: [PATCH] Indexing hooks for symbolic DSLs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ```julia using ModelingToolkit, OrdinaryDiffEq @parameters t σ ρ β @variables x(t) y(t) z(t) D = Differential(t) eqs = [D(x) ~ σ*(y-x), D(y) ~ x*(ρ-z)-y, D(z) ~ x*y - β*z] lorenz1 = ODESystem(eqs,name=:lorenz1) lorenz2 = ODESystem(eqs,name=:lorenz2) @variables a @parameters γ connections = [0 ~ lorenz1.x + lorenz2.y + a*γ] connected = ODESystem(connections,t,[a],[γ],systems=[lorenz1,lorenz2]) u0 = [lorenz1.x => 1.0, lorenz1.y => 0.0, lorenz1.z => 0.0, lorenz2.x => 0.0, lorenz2.y => 1.0, lorenz2.z => 0.0, a => 2.0] p = [lorenz1.σ => 10.0, lorenz1.ρ => 28.0, lorenz1.β => 8/3, lorenz2.σ => 10.0, lorenz2.ρ => 28.0, lorenz2.β => 8/3, γ => 2.0] tspan = (0.0,100.0) prob = ODEProblem(connected,u0,tspan,p) sol = solve(prob,Rodas5()) sol[lorenz1.x] sol[lorenz1.x,2] sol[lorenz1.x,:] sol[lorenz1.x,1:5] ``` Now for observed functions. --- src/solutions/solution_interface.jl | 48 +++++++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index cbf2b2ad4..b629f311b 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -12,6 +12,40 @@ Base.summary(A::AbstractNoTimeSolution) = string(nameof(typeof(A))," with uType Base.show(io::IO, A::AbstractNoTimeSolution) = (print(io,"u: ");show(io, A.u)) Base.show(io::IO, m::MIME"text/plain", A::AbstractNoTimeSolution) = (print(io,"u: ");show(io,m,A.u)) +# Symbol Handling + +# For handling ambiguities +Base.@propagate_inbounds Base.getindex(A::AbstractTimeseriesSolution, I::Int) = A.u[I] +Base.@propagate_inbounds Base.getindex(A::AbstractTimeseriesSolution, I::Int...) = A.u[I[end]][Base.front(I)...] +Base.@propagate_inbounds Base.getindex(A::AbstractTimeseriesSolution, i::Int,::Colon) = [A.u[j][i] for j in 1:length(A)] +Base.@propagate_inbounds Base.getindex(A::AbstractTimeseriesSolution, i::Int,II::AbstractArray{Int}) = [A.u[j][i] for j in II] + +Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution,sym) + if issymbollike(sym) + i = sym_to_index(sym,A) + else + i = sym + end + A[i,:] +end + +Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution,sym,I::Int...) + if issymbollike(sym) + i = sym_to_index(sym,A) + else + i = sym + end + A[i,I...] +end +Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution,sym,I::Union{AbstractArray{Int},Colon,CartesianIndex}) + if issymbollike(sym) + i = sym_to_index(sym,A) + else + i = sym + end + A[i,I] +end + ## AbstractTimeseriesSolution Interface Base.summary(A::AbstractTimeseriesSolution) = string( @@ -174,7 +208,15 @@ function cleansym(sym::Symbol) end return str end -issymbollike(x) = typeof(x) <: Symbol || Symbol(typeof(x)) == :Operation || Symbol(typeof(x)) == :Variable || Symbol(typeof(x)) == :Num + +sym_to_index(sym,sol::SciMLSolution) = sym_to_index(sym,getsyms(sol)) +sym_to_index(sym,syms) = findfirst(isequal(Symbol(sym)),syms) +issymbollike(x) = typeof(x) <: Symbol || + Symbol(parameterless_type(typeof(x))) == :Operation || + Symbol(parameterless_type(typeof(x))) == :Variable || + Symbol(parameterless_type(typeof(x))) == :Sym || + Symbol(parameterless_type(typeof(x))) == :Num || + Symbol(parameterless_type(typeof(x))) == :Term function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_safety,vars,int_vars,tscale,strs) if tspan === nothing @@ -259,7 +301,7 @@ function interpret_vars(vars,sol,syms) tmp = [] for x in var if issymbollike(x) - push!(tmp,something(findfirst(isequal(Symbol(x)), syms), 0)) + push!(tmp,something(sym_to_index(x,syms),0)) else push!(tmp,x) end @@ -270,7 +312,7 @@ function interpret_vars(vars,sol,syms) var_int = tmp end elseif issymbollike(var) - var_int = something(findfirst(isequal(Symbol(var)), syms), 0) + var_int = something(sym_to_index(var,syms),0) else var_int = var end