diff --git a/REQUIRE b/REQUIRE index 1eb1d320a..17059319e 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1,3 +1,4 @@ julia 0.4 Polynomials Compat 0.9 +DiffEqBase diff --git a/src/ODE.jl b/src/ODE.jl index 0cf4334f7..19b779ad8 100644 --- a/src/ODE.jl +++ b/src/ODE.jl @@ -5,6 +5,11 @@ module ODE using Polynomials using Compat +using DiffEqBase + +import DiffEqBase: solve + +include("algorithm_types.jl") ## minimal function export list # adaptive non-stiff: @@ -415,12 +420,15 @@ const s4_coefficients = (0.5, ode4s_s(F, x0, tspan; jacobian=nothing) = oderosenbrock(F, x0, tspan, s4_coefficients...; jacobian=jacobian) # Use Shampine coefficients by default (matching Numerical Recipes) -const ode4s = ode4s_s +ode4s(F, x0, tspan; jacobian=nothing) = ode4s_s(F, x0, tspan; jacobian=nothing) const ms_coefficients4 = [ 1 0 0 0 -1/2 3/2 0 0 5/12 -4/3 23/12 0 -9/24 37/24 -59/24 55/24] +####### Common Interface Bindings + +include("common.jl") end # module ODE diff --git a/src/algorithm_types.jl b/src/algorithm_types.jl new file mode 100644 index 000000000..f22ac3d3e --- /dev/null +++ b/src/algorithm_types.jl @@ -0,0 +1,8 @@ +abstract ODEJLAlgorithm <: AbstractODEAlgorithm +immutable ode23 <: ODEJLAlgorithm end +immutable ode45 <: ODEJLAlgorithm end +immutable ode23s <: ODEJLAlgorithm end +immutable ode78 <: ODEJLAlgorithm end +immutable ode4 <: ODEJLAlgorithm end +immutable ode4ms <: ODEJLAlgorithm end +immutable ode4s <: ODEJLAlgorithm end diff --git a/src/common.jl b/src/common.jl new file mode 100644 index 000000000..1490ed4b4 --- /dev/null +++ b/src/common.jl @@ -0,0 +1,99 @@ +function solve{uType,tType,isinplace,algType<:ODEJLAlgorithm,F}(prob::AbstractODEProblem{uType,tType,isinplace,F}, + alg::algType,timeseries=[],ts=[],ks=[];dense=true,save_timeseries=true, + saveat=tType[],timeseries_errors=true,reltol = 1e-5, abstol = 1e-8, + dtmin = abs(prob.tspan[2]-prob.tspan[1])/1e-9, + dtmax = abs(prob.tspan[2]-prob.tspan[1])/2.5, + dt = 0.,norm = Base.vecnorm, + kwargs...) + + tspan = prob.tspan + + if tspan[end]-tspan[1] (du = zeros(u); prob.f(t,u,du); vec(du)) + elseif uType <: AbstractArray + f = (t,u) -> vec(prob.f(t,reshape(u,sizeu))) + else + f = prob.f + end + + if uType <: AbstractArray + u0 = vec(prob.u0) + else + u0 = prob.u0 + end + + if typeof(alg) <: ode23 + ts,timeseries_tmp = ODE.ode23(f,u0,Ts, + norm = norm, + abstol=abstol, + reltol=reltol, + maxstep=dtmax, + minstep=dtmin, + initstep=dt, + points=points) + elseif typeof(alg) <: ode45 + ts,timeseries_tmp = ODE.ode45(f,u0,Ts, + norm = norm, + abstol=abstol, + reltol=reltol, + maxstep=dtmax, + minstep=dtmin, + initstep=dt, + points=points) + elseif typeof(alg) <: ode78 + ts,timeseries_tmp = ODE.ode78(f,u0,Ts, + norm = norm, + abstol=abstol, + reltol=reltol, + maxstep=dtmax, + minstep=dtmin, + initstep=dt, + points=points) + elseif typeof(alg) <: ode23s + ts,timeseries_tmp = ODE.ode23s(f,u0,Ts, + norm = norm, + abstol=abstol, + reltol=reltol, + maxstep=dtmax, + minstep=dtmin, + initstep=dt, + points=points) + elseif typeof(alg) <: ode4 + ts,timeseries_tmp = ODE.ode4(f,u0,Ts) + elseif typeof(alg) <: ode4ms + ts,timeseries_tmp = ODE.ode4ms(f,u0,Ts) + elseif typeof(alg) <: ode4s + ts,timeseries_tmp = ODE.ode4s(f,u0,Ts) + end + + # Reshape the result if needed + if uType <: AbstractArray + timeseries = Vector{uType}(0) + for i=1:length(timeseries_tmp) + push!(timeseries,reshape(timeseries_tmp[i],sizeu)) + end + else + timeseries = timeseries_tmp + end + + build_solution(prob,alg,ts,timeseries, + timeseries_errors = timeseries_errors) +end + +export ODEJLAlgorithm, ode23Alg, ode23sAlg, ode45Alg, ode78Alg diff --git a/src/runge_kutta.jl b/src/runge_kutta.jl index 52b501de5..3944966f9 100644 --- a/src/runge_kutta.jl +++ b/src/runge_kutta.jl @@ -216,7 +216,7 @@ ode23(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_rk23; kwargs...) ode45_fe(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_rk45; kwargs...) ode45_dp(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_dopri5; kwargs...) # Use Dormand-Prince version of ode45 by default -const ode45 = ode45_dp +ode45(fn, y0, tspan; kwargs...) = ode45_dp(fn, y0, tspan; kwargs...) ode78(fn, y0, tspan; kwargs...) = oderk_adapt(fn, y0, tspan, bt_feh78; kwargs...) function oderk_adapt(fn, y0, tspan, btab::TableauRKExplicit; kwords...) diff --git a/test/REQUIRE b/test/REQUIRE new file mode 100644 index 000000000..62544ef8d --- /dev/null +++ b/test/REQUIRE @@ -0,0 +1 @@ +DiffEqProblemLibrary diff --git a/test/common.jl b/test/common.jl new file mode 100644 index 000000000..dd462040f --- /dev/null +++ b/test/common.jl @@ -0,0 +1,21 @@ +using ODE, DiffEqBase, DiffEqProblemLibrary + +dt=1/2^(4) + +algs = [ode23(),ode45(),ode78(),ode4(),ode4ms(),ode4s()] # no ode23s + +# Check for errors + +prob = prob_ode_linear + +for alg in algs + sol =solve(prob,alg;dt=dt,abstol=1e-6,reltol=1e-3) + @test typeof(sol[2]) <: Number +end + +prob = prob_ode_2Dlinear + +for alg in algs + sol =solve(prob,alg;dt=dt,dtmin=eps(),abstol=1e-6,reltol=1e-3) + @test size(sol[2]) == (4,2) +end diff --git a/test/runtests.jl b/test/runtests.jl index 6f8770581..f7f1fd293 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -103,5 +103,5 @@ let @test norm(refsol-y[end], Inf) < 2e-10 end include("interface-tests.jl") - +include("common.jl") println("All looks OK")