Skip to content

Commit

Permalink
Merge pull request #102 from JuliaDiffEq/myb/dis
Browse files Browse the repository at this point in the history
Interpolation at discontinuity toggle
  • Loading branch information
ChrisRackauckas committed Sep 8, 2018
2 parents 73b5074 + 4253d33 commit 16d72e4
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/composite_solution.jl
Expand Up @@ -13,8 +13,8 @@ struct RODECompositeSolution{T,N,uType,uType2,EType,tType,randType,P,A,IType} <:
retcode::Symbol
seed::UInt64
end
(sol::RODECompositeSolution)(t,deriv::Type=Val{0};idxs=nothing) = sol.interp(t,idxs,deriv,sol.prob.p)
(sol::RODECompositeSolution)(v,t,deriv::Type=Val{0};idxs=nothing) = sol.interp(v,t,idxs,deriv,sol.prob.p)
(sol::RODECompositeSolution)(t,deriv::Type=Val{0};idxs=nothing,continuity=:left) = sol.interp(t,idxs,deriv,sol.prob.p,continuity)
(sol::RODECompositeSolution)(v,t,deriv::Type=Val{0};idxs=nothing,continuity=:left) = sol.interp(v,t,idxs,deriv,sol.prob.p,continuity)

function build_solution(
prob::AbstractRODEProblem,
Expand Down
28 changes: 16 additions & 12 deletions src/dense.jl
Expand Up @@ -134,7 +134,7 @@ sde_interpolation(tvals,ts,timeseries,ks)
Get the value at tvals where the solution is known at the
times ts (sorted), with values timeseries and derivatives ks
"""
@inline function sde_interpolation(tvals,id,idxs,deriv,p)
@inline function sde_interpolation(tvals,id,idxs,deriv,p,continuity::Symbol=:left)
@unpack ts,timeseries = id
tdir = sign(ts[end]-ts[1])
idx = sortperm(tvals,rev=tdir<0)
Expand All @@ -152,10 +152,11 @@ times ts (sorted), with values timeseries and derivatives ks
t = tvals[j]
i = searchsortedfirst(@view(ts[i:end]),t,rev=tdir<0)+i-1 # It's in the interval ts[i-1] to ts[i]
if ts[i] == t
k = continuity == :right && ts[i+1] == t ? i+1 : i
if idxs == nothing
vals[j] = timeseries[i]
vals[j] = timeseries[k]
else
vals[j] = timeseries[i][idxs]
vals[j] = timeseries[k][idxs]
end
elseif ts[i-1] == t # Can happen if it's the first value!
if idxs == nothing
Expand All @@ -178,17 +179,18 @@ sde_interpolation(tval::Number,ts,timeseries,ks)
Get the value at tval where the solution is known at the
times ts (sorted), with values timeseries and derivatives ks
"""
@inline function sde_interpolation(tval::Number,id,idxs,deriv,p)
@inline function sde_interpolation(tval::Number,id,idxs,deriv,p,continuity::Symbol=:left)
@unpack ts,timeseries = id
tdir = sign(ts[end]-ts[1])
tdir*tval > tdir*ts[end] && error("Solution interpolation cannot extrapolate past the final timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.")
tdir*tval < tdir*ts[1] && error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.")
@inbounds i = searchsortedfirst(ts,tval,rev=tdir<0) # It's in the interval ts[i-1] to ts[i]
@inbounds if ts[i] == tval
k = continuity == :right && ts[i+1] == tval ? i+1 : i
if idxs == nothing
val = timeseries[i]
val = timeseries[k]
else
val = timeseries[i][idxs]
val = timeseries[k][idxs]
end
elseif ts[i-1] == tval # Can happen if it's the first value!
if idxs == nothing
Expand All @@ -204,17 +206,18 @@ times ts (sorted), with values timeseries and derivatives ks
val
end

@inline function sde_interpolation!(out,tval::Number,id,idxs,deriv,p)
@inline function sde_interpolation!(out,tval::Number,id,idxs,deriv,p,continuity::Symbol=:left)
@unpack ts,timeseries = id
tdir = sign(ts[end]-ts[1])
tdir*tval > tdir*ts[end] && error("Solution interpolation cannot extrapolate past the final timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.")
tdir*tval < tdir*ts[1] && error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.")
@inbounds i = searchsortedfirst(ts,tval,rev=tdir<0) # It's in the interval ts[i-1] to ts[i]
@inbounds if ts[i] == tval
k = continuity == :right && ts[i+1] == tval ? i+1 : i
if idxs == nothing
copyto!(out,timeseries[i])
copyto!(out,timeseries[k])
else
copyto!(out,timeseries[i][idxs])
copyto!(out,timeseries[k][idxs])
end
elseif ts[i-1] == tval # Can happen if it's the first value!
if idxs == nothing
Expand All @@ -229,7 +232,7 @@ end
end
end

@inline function sde_interpolation!(vals,tvals,id,idxs,deriv,p)
@inline function sde_interpolation!(vals,tvals,id,idxs,deriv,p,continuity::Symbol=:left)
@unpack ts,timeseries = id
tdir = sign(ts[end]-ts[1])
idx = sortperm(tvals,rev=tdir<0)
Expand All @@ -240,10 +243,11 @@ end
t = tvals[j]
i = searchsortedfirst(@view(ts[i:end]),t,rev=tdir<0)+i-1 # It's in the interval ts[i-1] to ts[i]
if ts[i] == t
k = continuity == :right && ts[i+1] == t ? i+1 : i
if idxs == nothing
vals[j] = timeseries[i]
vals[j] = timeseries[k]
else
vals[j] = timeseries[i][idxs]
vals[j] = timeseries[k][idxs]
end
elseif ts[i-1] == t # Can happen if it's the first value!
if idxs == nothing
Expand Down
4 changes: 2 additions & 2 deletions src/interp_func.jl
Expand Up @@ -4,5 +4,5 @@ struct LinearInterpolationData{uType,tType} <: AbstractDiffEqInterpolation
end

DiffEqBase.interp_summary(::LinearInterpolationData) = "1st order linear"
(interp::LinearInterpolationData)(tvals,idxs,deriv,p) = sde_interpolation(tvals,interp,idxs,deriv,p)
(interp::LinearInterpolationData)(val,tvals,idxs,deriv,p) = sde_interpolation!(val,tvals,interp,idxs,deriv,p)
(interp::LinearInterpolationData)(tvals,idxs,deriv,p,continuity::Symbol=:left) = sde_interpolation(tvals,interp,idxs,deriv,p,continuity)
(interp::LinearInterpolationData)(val,tvals,idxs,deriv,p,continuity::Symbol=:left) = sde_interpolation!(val,tvals,interp,idxs,deriv,p,continuity)

0 comments on commit 16d72e4

Please sign in to comment.