Skip to content

Commit

Permalink
Merge pull request #119 from JuliaDiffEq/common
Browse files Browse the repository at this point in the history
add common interface bindings
  • Loading branch information
ChrisRackauckas committed Dec 22, 2016
2 parents a11e3fe + 8af616f commit 8d4827b
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 15 deletions.
1 change: 0 additions & 1 deletion .travis.yml
Expand Up @@ -3,7 +3,6 @@ os:
- osx
- linux
julia:
- 0.4
- release
- nightly
matrix:
Expand Down
3 changes: 2 additions & 1 deletion REQUIRE
@@ -1,3 +1,4 @@
julia 0.4
julia 0.5
Polynomials
Compat 0.9
DiffEqBase
4 changes: 2 additions & 2 deletions appveyor.yml
@@ -1,7 +1,7 @@
environment:
matrix:
- JULIAVERSION: "julialang/bin/winnt/x86/0.4/julia-0.4-latest-win32.exe"
- JULIAVERSION: "julialang/bin/winnt/x64/0.4/julia-0.4-latest-win64.exe"
- JULIAVERSION: "julialang/bin/winnt/x86/0.5/julia-0.5-latest-win32.exe"
- JULIAVERSION: "julialang/bin/winnt/x64/0.5/julia-0.5-latest-win64.exe"
- JULIAVERSION: "julianightlies/bin/winnt/x86/julia-latest-win32.exe"
- JULIAVERSION: "julianightlies/bin/winnt/x64/julia-latest-win64.exe"

Expand Down
26 changes: 19 additions & 7 deletions src/ODE.jl
Expand Up @@ -5,6 +5,11 @@ module ODE

using Polynomials
using Compat
using DiffEqBase

import DiffEqBase: solve

include("algorithm_types.jl")

## minimal function export list
# adaptive non-stiff:
Expand All @@ -16,6 +21,9 @@ export ode23s
# non-adaptive stiff:
export ode4s

# Common Interface
export ODEjlAlgorithm

## complete function export list: see runtests.jl

###############################################################################
Expand Down Expand Up @@ -152,7 +160,7 @@ include("runge_kutta.jl")

# ODE_MS Fixed-step, fixed-order multi-step numerical method
# with Adams-Bashforth-Moulton coefficients
function ode_ms(F, x0, tspan, order::Integer)
function ode_ms(F, x0, tspan, order::Integer; kwargs...)
h = diff(tspan)
x = Array(typeof(x0), length(tspan))
x[1] = x0
Expand Down Expand Up @@ -189,8 +197,8 @@ function ode_ms(F, x0, tspan, order::Integer)
end

# Use order 4 by default
ode4ms(F, x0, tspan) = ode_ms(F, x0, tspan, 4)
ode5ms(F, x0, tspan) = ODE.ode_ms(F, x0, tspan, 5)
ode4ms(F, x0, tspan; kwargs...) = ode_ms(F, x0, tspan, 4; kwargs...)
ode5ms(F, x0, tspan; kwargs...) = ODE.ode_ms(F, x0, tspan, 5; kwargs...)

###############################################################################
## STIFF SOLVERS
Expand Down Expand Up @@ -345,7 +353,7 @@ end

#ODEROSENBROCK Solve stiff differential equations, Rosenbrock method
# with provided coefficients.
function oderosenbrock(F, x0, tspan, gamma, a, b, c; jacobian=nothing)
function oderosenbrock(F, x0, tspan, gamma, a, b, c; jacobian=nothing, kwargs...)

if typeof(jacobian) == Function
G = jacobian
Expand Down Expand Up @@ -398,7 +406,7 @@ const kr4_coefficients = (0.231,
6.02015272865 0.1597500684673 0 0
-1.856343618677 -8.50538085819 -2.08407513602 0],)

ode4s_kr(F, x0, tspan; jacobian=nothing) = oderosenbrock(F, x0, tspan, kr4_coefficients...; jacobian=jacobian)
ode4s_kr(F, x0, tspan; jacobian=nothing, kwargs...) = oderosenbrock(F, x0, tspan, kr4_coefficients...; jacobian=jacobian, kwargs...)

# Shampine coefficients
const s4_coefficients = (0.5,
Expand All @@ -412,15 +420,19 @@ const s4_coefficients = (0.5,
372/25 12/5 0 0
-112/125 -54/125 -2/5 0],)

ode4s_s(F, x0, tspan; jacobian=nothing) = oderosenbrock(F, x0, tspan, s4_coefficients...; jacobian=jacobian)
ode4s_s(F, x0, tspan; jacobian=nothing, kwargs...) =
oderosenbrock(F, x0, tspan, s4_coefficients...; jacobian=jacobian, kwargs...)

# Use Shampine coefficients by default (matching Numerical Recipes)
const ode4s = ode4s_s
ode4s(F, x0, tspan; jacobian=nothing, kwargs...) = ode4s_s(F, x0, tspan; jacobian=nothing, kwargs...)

const ms_coefficients4 = [ 1 0 0 0
-1/2 3/2 0 0
5/12 -4/3 23/12 0
-9/24 37/24 -59/24 55/24]

####### Common Interface Bindings

include("common.jl")

end # module ODE
10 changes: 10 additions & 0 deletions src/algorithm_types.jl
@@ -0,0 +1,10 @@
abstract ODEjlAlgorithm <: AbstractODEAlgorithm
# Making the ODE-solver functions into types lets us dispatch on them.
# Used in the DiffEqBase interface.
immutable ode23 <: ODEjlAlgorithm end
immutable ode45 <: ODEjlAlgorithm end
immutable ode23s <: ODEjlAlgorithm end
immutable ode78 <: ODEjlAlgorithm end
immutable ode4 <: ODEjlAlgorithm end
immutable ode4ms <: ODEjlAlgorithm end
immutable ode4s <: ODEjlAlgorithm end
56 changes: 56 additions & 0 deletions src/common.jl
@@ -0,0 +1,56 @@
function solve{uType,tType,isinplace,AlgType<:ODEjlAlgorithm,F}(prob::AbstractODEProblem{uType,tType,isinplace,F},
alg::AlgType,timeseries=[],ts=[],ks=[];dense=true,save_timeseries=true,
saveat=tType[],timeseries_errors=true,reltol = 1e-5, abstol = 1e-8,
dtmin = abs(prob.tspan[2]-prob.tspan[1])/1e-9,
dtmax = abs(prob.tspan[2]-prob.tspan[1])/2.5,
dt = 0.,norm = Base.vecnorm,
kwargs...)

tspan = prob.tspan

u0 = prob.u0

Ts = unique([tspan[1];saveat;tspan[2]])

if save_timeseries
points = :all
else
points = :specified
end

sizeu = size(prob.u0)

if isinplace
f = (t,u) -> (du = zeros(u); prob.f(t,u,du); vec(du))
elseif uType <: AbstractArray
f = (t,u) -> vec(prob.f(t,reshape(u,sizeu)))
else
f = prob.f
end

u0 = uType <: AbstractArray ? vec(prob.u0) : prob.u0

# Calling the solver, i.e. if the algorithm is ode45,
# then AlgType(...) is ode45(...)
ts,timeseries_tmp = AlgType(f,u0,Ts;
norm = norm,
abstol=abstol,
reltol=reltol,
maxstep=dtmax,
minstep=dtmin,
initstep=dt,
points=points)

# Reshape the result if needed
if uType <: AbstractArray
timeseries = Vector{uType}(0)
for i=1:length(timeseries_tmp)
push!(timeseries,reshape(timeseries_tmp[i],sizeu))
end
else
timeseries = timeseries_tmp
end

build_solution(prob,alg,ts,timeseries,
timeseries_errors = timeseries_errors)
end
6 changes: 3 additions & 3 deletions src/runge_kutta.jl
Expand Up @@ -165,9 +165,9 @@ const bt_feh78 = TableauRKExplicit(:feh78, (7,8), Rational{Int64},
ode1(fn, y0, tspan) = oderk_fixed(fn, y0, tspan, bt_feuler)
ode2_midpoint(fn, y0, tspan) = oderk_fixed(fn, y0, tspan, bt_midpoint)
ode2_heun(fn, y0, tspan) = oderk_fixed(fn, y0, tspan, bt_heun)
ode4(fn, y0, tspan) = oderk_fixed(fn, y0, tspan, bt_rk4)
ode4(fn, y0, tspan;kwargs...) = oderk_fixed(fn, y0, tspan, bt_rk4;kwargs...)

function oderk_fixed(fn, y0, tspan, btab::TableauRKExplicit)
function oderk_fixed(fn, y0, tspan, btab::TableauRKExplicit;kwargs...)
# Non-arrays y0 treat as scalar
fn_(t, y) = [fn(t, y[1])]
t,y = oderk_fixed(fn_, [y0], tspan, btab)
Expand Down Expand Up @@ -216,7 +216,7 @@ ode23(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_rk23; kwargs...)
ode45_fe(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_rk45; kwargs...)
ode45_dp(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_dopri5; kwargs...)
# Use Dormand-Prince version of ode45 by default
const ode45 = ode45_dp
ode45(fn, y0, tspan; kwargs...) = ode45_dp(fn, y0, tspan; kwargs...)
ode78(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_feh78; kwargs...)

function oderk_adapt(fn, y0, tspan, btab::TableauRKExplicit; kwords...)
Expand Down
1 change: 1 addition & 0 deletions test/REQUIRE
@@ -0,0 +1 @@
DiffEqProblemLibrary
21 changes: 21 additions & 0 deletions test/common.jl
@@ -0,0 +1,21 @@
using ODE, DiffEqBase, DiffEqProblemLibrary

dt=1/2^(4)

algs = [ode23(),ode45(),ode78(),ode4(),ode4ms(),ode4s()] # no ode23s

# Check for errors

prob = prob_ode_linear

for alg in algs
sol =solve(prob,alg;dt=dt,abstol=1e-6,reltol=1e-3)
@test typeof(sol[2]) <: Number
end

prob = prob_ode_2Dlinear

for alg in algs
sol =solve(prob,alg;dt=dt,dtmin=eps(),abstol=1e-6,reltol=1e-3)
@test size(sol[2]) == (4,2)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Expand Up @@ -103,5 +103,5 @@ let
@test norm(refsol-y[end], Inf) < 2e-10
end
include("interface-tests.jl")

include("common.jl")
println("All looks OK")

0 comments on commit 8d4827b

Please sign in to comment.