diff --git a/src/interpolation_methods.jl b/src/interpolation_methods.jl index 697f7375..03adf558 100644 --- a/src/interpolation_methods.jl +++ b/src/interpolation_methods.jl @@ -33,11 +33,12 @@ function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, igues val end -function _interpolate(A::LinearInterpolation{<:AbstractMatrix}, t::Number, iguess) +function _interpolate(A::LinearInterpolation{<:AbstractArray}, t::Number, iguess) idx = get_idx(A, t, iguess) Δt = t - A.t[idx] slope = get_parameters(A, idx) - return A.u[:, idx] + slope * Δt + ax = axes(A.u)[1:(end - 1)] + return A.u[ax..., idx] + slope * Δt end # Quadratic Interpolation diff --git a/src/interpolation_utils.jl b/src/interpolation_utils.jl index 0df5e484..f8fc7147 100644 --- a/src/interpolation_utils.jl +++ b/src/interpolation_utils.jl @@ -94,6 +94,21 @@ function munge_data(U::StridedMatrix, t::AbstractVector) return U, t end +function munge_data(U::AbstractArray{T, N}, t) where {T, N} + TU = Base.nonmissingtype(eltype(U)) + Tt = Base.nonmissingtype(eltype(t)) + @assert length(t) == size(U, ndims(U)) + ax = axes(U)[1:(end - 1)] + non_missing_indices = collect( + i for i in 1:length(t) + if !any(ismissing, U[ax..., i]) && !ismissing(t[i]) + ) + U = cat([TU.(U[ax..., i]) for i in non_missing_indices]...; dims = ndims(U)) + t = Tt.([t[i] for i in non_missing_indices]) + + return U, t +end + seems_linear(assume_linear_t::Bool, _) = assume_linear_t seems_linear(assume_linear_t::Number, t) = looks_linear(t; threshold = assume_linear_t) diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 0701b3a2..824b2e2a 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -18,9 +18,11 @@ function safe_diff(b, a::T) where {T} b == a ? zero(T) : b - a end -function linear_interpolation_parameters(u::AbstractArray{T}, t, idx) where {T} - Δu = if u isa AbstractMatrix - [safe_diff(u[j, idx + 1], u[j, idx]) for j in 1:size(u)[1]] +function linear_interpolation_parameters(u::AbstractArray{T, N}, t, idx) where {T, N} + Δu = if N > 1 + ax = axes(u) + safe_diff.( + u[ax[1:(end - 1)]..., (idx + 1):(idx + 1)], u[ax[1:(end - 1)]..., idx:idx]) else safe_diff(u[idx + 1], u[idx]) end diff --git a/test/interpolation_tests.jl b/test/interpolation_tests.jl index 3ebc1741..69e5c197 100644 --- a/test/interpolation_tests.jl +++ b/test/interpolation_tests.jl @@ -43,11 +43,11 @@ end A = LinearInterpolation(u, t; extrapolate = true) for (_t, _u) in zip(t, eachcol(u)) - @test A(_t) == _u + @test A(_t) == reshape(_u, :, 1) end - @test A(0) == [0.0, 0.0] - @test A(5.5) == [11.0, 16.5] - @test A(11) == [22, 33] + @test A(0) == [0.0; 0.0;;] + @test A(5.5) == [11.0; 16.5;;] + @test A(11) == [22; 33;;] x = 1:10 y = 2:4