diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index cbf2b2ad4..b629f311b 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -12,6 +12,40 @@ Base.summary(A::AbstractNoTimeSolution) = string(nameof(typeof(A))," with uType Base.show(io::IO, A::AbstractNoTimeSolution) = (print(io,"u: ");show(io, A.u)) Base.show(io::IO, m::MIME"text/plain", A::AbstractNoTimeSolution) = (print(io,"u: ");show(io,m,A.u)) +# Symbol Handling + +# For handling ambiguities +Base.@propagate_inbounds Base.getindex(A::AbstractTimeseriesSolution, I::Int) = A.u[I] +Base.@propagate_inbounds Base.getindex(A::AbstractTimeseriesSolution, I::Int...) = A.u[I[end]][Base.front(I)...] +Base.@propagate_inbounds Base.getindex(A::AbstractTimeseriesSolution, i::Int,::Colon) = [A.u[j][i] for j in 1:length(A)] +Base.@propagate_inbounds Base.getindex(A::AbstractTimeseriesSolution, i::Int,II::AbstractArray{Int}) = [A.u[j][i] for j in II] + +Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution,sym) + if issymbollike(sym) + i = sym_to_index(sym,A) + else + i = sym + end + A[i,:] +end + +Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution,sym,I::Int...) + if issymbollike(sym) + i = sym_to_index(sym,A) + else + i = sym + end + A[i,I...] +end +Base.@propagate_inbounds function Base.getindex(A::AbstractTimeseriesSolution,sym,I::Union{AbstractArray{Int},Colon,CartesianIndex}) + if issymbollike(sym) + i = sym_to_index(sym,A) + else + i = sym + end + A[i,I] +end + ## AbstractTimeseriesSolution Interface Base.summary(A::AbstractTimeseriesSolution) = string( @@ -174,7 +208,15 @@ function cleansym(sym::Symbol) end return str end -issymbollike(x) = typeof(x) <: Symbol || Symbol(typeof(x)) == :Operation || Symbol(typeof(x)) == :Variable || Symbol(typeof(x)) == :Num + +sym_to_index(sym,sol::SciMLSolution) = sym_to_index(sym,getsyms(sol)) +sym_to_index(sym,syms) = findfirst(isequal(Symbol(sym)),syms) +issymbollike(x) = typeof(x) <: Symbol || + Symbol(parameterless_type(typeof(x))) == :Operation || + Symbol(parameterless_type(typeof(x))) == :Variable || + Symbol(parameterless_type(typeof(x))) == :Sym || + Symbol(parameterless_type(typeof(x))) == :Num || + Symbol(parameterless_type(typeof(x))) == :Term function diffeq_to_arrays(sol,plot_analytic,denseplot,plotdensity,tspan,axis_safety,vars,int_vars,tscale,strs) if tspan === nothing @@ -259,7 +301,7 @@ function interpret_vars(vars,sol,syms) tmp = [] for x in var if issymbollike(x) - push!(tmp,something(findfirst(isequal(Symbol(x)), syms), 0)) + push!(tmp,something(sym_to_index(x,syms),0)) else push!(tmp,x) end @@ -270,7 +312,7 @@ function interpret_vars(vars,sol,syms) var_int = tmp end elseif issymbollike(var) - var_int = something(findfirst(isequal(Symbol(var)), syms), 0) + var_int = something(sym_to_index(var,syms),0) else var_int = var end