Skip to content

Commit

Permalink
Merge d328716 into a11e3fe
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Dec 12, 2016
2 parents a11e3fe + d328716 commit 776c2b4
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 3 deletions.
1 change: 1 addition & 0 deletions REQUIRE
@@ -1,3 +1,4 @@
julia 0.4
Polynomials
Compat 0.9
DiffEqBase
10 changes: 9 additions & 1 deletion src/ODE.jl
Expand Up @@ -5,6 +5,11 @@ module ODE

using Polynomials
using Compat
using DiffEqBase

import DiffEqBase: solve

include("algorithm_types.jl")

## minimal function export list
# adaptive non-stiff:
Expand Down Expand Up @@ -415,12 +420,15 @@ const s4_coefficients = (0.5,
ode4s_s(F, x0, tspan; jacobian=nothing) = oderosenbrock(F, x0, tspan, s4_coefficients...; jacobian=jacobian)

# Use Shampine coefficients by default (matching Numerical Recipes)
const ode4s = ode4s_s
ode4s(F, x0, tspan; jacobian=nothing) = ode4s_s(F, x0, tspan; jacobian=nothing)

const ms_coefficients4 = [ 1 0 0 0
-1/2 3/2 0 0
5/12 -4/3 23/12 0
-9/24 37/24 -59/24 55/24]

####### Common Interface Bindings

include("common.jl")

end # module ODE
8 changes: 8 additions & 0 deletions src/algorithm_types.jl
@@ -0,0 +1,8 @@
abstract ODEJLAlgorithm <: AbstractODEAlgorithm
immutable ode23 <: ODEJLAlgorithm end
immutable ode45 <: ODEJLAlgorithm end
immutable ode23s <: ODEJLAlgorithm end
immutable ode78 <: ODEJLAlgorithm end
immutable ode4 <: ODEJLAlgorithm end
immutable ode4ms <: ODEJLAlgorithm end
immutable ode4s <: ODEJLAlgorithm end
99 changes: 99 additions & 0 deletions src/common.jl
@@ -0,0 +1,99 @@
function solve{uType,tType,isinplace,algType<:ODEJLAlgorithm,F}(prob::AbstractODEProblem{uType,tType,isinplace,F},
alg::algType,timeseries=[],ts=[],ks=[];dense=true,save_timeseries=true,
saveat=tType[],timeseries_errors=true,reltol = 1e-5, abstol = 1e-8,
dtmin = abs(prob.tspan[2]-prob.tspan[1])/1e-9,
dtmax = abs(prob.tspan[2]-prob.tspan[1])/2.5,
dt = 0.,norm = Base.vecnorm,
kwargs...)

tspan = prob.tspan

if tspan[end]-tspan[1]<tType(0)
error("final time must be greater than starting time. Aborting.")
end

u0 = prob.u0

Ts = sort(unique([tspan[1];saveat;tspan[2]]))

if save_timeseries
points = :all
else
points = :specified
end

sizeu = size(prob.u0)

if isinplace
f = (t,u) -> (du = zeros(u); prob.f(t,u,du); vec(du))
elseif uType <: AbstractArray
f = (t,u) -> vec(prob.f(t,reshape(u,sizeu)))
else
f = prob.f
end

if uType <: AbstractArray
u0 = vec(prob.u0)
else
u0 = prob.u0
end

if typeof(alg) <: ode23
ts,timeseries_tmp = ODE.ode23(f,u0,Ts,
norm = norm,
abstol=abstol,
reltol=reltol,
maxstep=dtmax,
minstep=dtmin,
initstep=dt,
points=points)
elseif typeof(alg) <: ode45
ts,timeseries_tmp = ODE.ode45(f,u0,Ts,
norm = norm,
abstol=abstol,
reltol=reltol,
maxstep=dtmax,
minstep=dtmin,
initstep=dt,
points=points)
elseif typeof(alg) <: ode78
ts,timeseries_tmp = ODE.ode78(f,u0,Ts,
norm = norm,
abstol=abstol,
reltol=reltol,
maxstep=dtmax,
minstep=dtmin,
initstep=dt,
points=points)
elseif typeof(alg) <: ode23s
ts,timeseries_tmp = ODE.ode23s(f,u0,Ts,
norm = norm,
abstol=abstol,
reltol=reltol,
maxstep=dtmax,
minstep=dtmin,
initstep=dt,
points=points)
elseif typeof(alg) <: ode4
ts,timeseries_tmp = ODE.ode4(f,u0,Ts)
elseif typeof(alg) <: ode4ms
ts,timeseries_tmp = ODE.ode4ms(f,u0,Ts)
elseif typeof(alg) <: ode4s
ts,timeseries_tmp = ODE.ode4s(f,u0,Ts)
end

# Reshape the result if needed
if uType <: AbstractArray
timeseries = Vector{uType}(0)
for i=1:length(timeseries_tmp)
push!(timeseries,reshape(timeseries_tmp[i],sizeu))
end
else
timeseries = timeseries_tmp
end

build_solution(prob,alg,ts,timeseries,
timeseries_errors = timeseries_errors)
end

export ODEJLAlgorithm, ode23Alg, ode23sAlg, ode45Alg, ode78Alg
2 changes: 1 addition & 1 deletion src/runge_kutta.jl
Expand Up @@ -216,7 +216,7 @@ ode23(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_rk23; kwargs...)
ode45_fe(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_rk45; kwargs...)
ode45_dp(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_dopri5; kwargs...)
# Use Dormand-Prince version of ode45 by default
const ode45 = ode45_dp
ode45(fn, y0, tspan; kwargs...) = ode45_dp(fn, y0, tspan; kwargs...)
ode78(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_feh78; kwargs...)

function oderk_adapt(fn, y0, tspan, btab::TableauRKExplicit; kwords...)
Expand Down
1 change: 1 addition & 0 deletions test/REQUIRE
@@ -0,0 +1 @@
DiffEqProblemLibrary
21 changes: 21 additions & 0 deletions test/common.jl
@@ -0,0 +1,21 @@
using ODE, DiffEqBase, DiffEqProblemLibrary

dt=1/2^(4)

algs = [ode23(),ode45(),ode78(),ode4(),ode4ms(),ode4s()] # no ode23s

# Check for errors

prob = prob_ode_linear

for alg in algs
sol =solve(prob,alg;dt=dt,abstol=1e-6,reltol=1e-3)
@test typeof(sol[2]) <: Number
end

prob = prob_ode_2Dlinear

for alg in algs
sol =solve(prob,alg;dt=dt,dtmin=eps(),abstol=1e-6,reltol=1e-3)
@test size(sol[2]) == (4,2)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Expand Up @@ -103,5 +103,5 @@ let
@test norm(refsol-y[end], Inf) < 2e-10
end
include("interface-tests.jl")

include("common.jl")
println("All looks OK")

0 comments on commit 776c2b4

Please sign in to comment.