Skip to content

Commit

Permalink
Merge pull request #92 from ErikQQY/mirk5
Browse files Browse the repository at this point in the history
Add MIRK5 method
  • Loading branch information
ChrisRackauckas committed Jul 22, 2023
2 parents 44fb28d + e20749f commit ec3a956
Show file tree
Hide file tree
Showing 13 changed files with 145 additions and 44 deletions.
4 changes: 2 additions & 2 deletions src/BoundaryValueDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ include("jacobian.jl")
include("solve.jl")

export Shooting
export GeneralMIRK4, GeneralMIRK6
export MIRK4, MIRK6
export GeneralMIRK4, GeneralMIRK5, GeneralMIRK6
export MIRK4, MIRK5, MIRK6

end
15 changes: 14 additions & 1 deletion src/alg_utils.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
alg_order(alg::GeneralMIRK4) = 4
alg_order(alg::GeneralMIRK5) = 5
alg_order(alg::GeneralMIRK6) = 6
alg_order(alg::MIRK4) = 4
alg_order(alg::MIRK5) = 5
alg_order(alg::MIRK6) = 6

alg_stage(alg::GeneralMIRK4) = 3
alg_stage(alg::GeneralMIRK5) = 4
alg_stage(alg::GeneralMIRK6) = 5
alg_stage(alg::MIRK4) = 3
alg_stage(alg::MIRK5) = 4
alg_stage(alg::MIRK6) = 5

SciMLBase.isautodifferentiable(::BoundaryValueDiffEqAlgorithm) = true
SciMLBase.allows_arbitrary_number_types(::BoundaryValueDiffEqAlgorithm) = true
SciMLBase.allowscomplex(alg::BoundaryValueDiffEqAlgorithm) = true
SciMLBase.isadaptive(alg::Union{GeneralMIRK4, GeneralMIRK6, MIRK4, MIRK6}) = false
function SciMLBase.isadaptive(alg::Union{
GeneralMIRK4,
GeneralMIRK5,
GeneralMIRK6,
MIRK4,
MIRK5,
MIRK6,
})
false
end
30 changes: 30 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,20 @@ struct GeneralMIRK4{N} <: GeneralMIRK
nlsolve::N
end

"""
@article{Enright1996RungeKuttaSW,
title={Runge-Kutta Software with Defect Control for Boundary Value ODEs},
author={Wayne H. Enright and Paul H. Muir},
journal={SIAM J. Sci. Comput.},
year={1996},
volume={17},
pages={479-497}
}
"""
struct GeneralMIRK5{N} <: GeneralMIRK
nlsolve::N
end

"""
@article{Enright1996RungeKuttaSW,
title={Runge-Kutta Software with Defect Control for Boundary Value ODEs},
Expand Down Expand Up @@ -55,6 +69,20 @@ struct MIRK4{N} <: MIRK
nlsolve::N
end

"""
@article{Enright1996RungeKuttaSW,
title={Runge-Kutta Software with Defect Control for Boundary Value ODEs},
author={Wayne H. Enright and Paul H. Muir},
journal={SIAM J. Sci. Comput.},
year={1996},
volume={17},
pages={479-497}
}
"""
struct MIRK5{N} <: MIRK
nlsolve::N
end

"""
@article{Enright1996RungeKuttaSW,
title={Runge-Kutta Software with Defect Control for Boundary Value ODEs},
Expand All @@ -70,6 +98,8 @@ struct MIRK6{N} <: MIRK
end

GeneralMIRK4(; nlsolve = DEFAULT_NLSOLVE_MIRK) = GeneralMIRK4(nlsolve)
GeneralMIRK5(; nlsolve = DEFAULT_NLSOLVE_MIRK) = GeneralMIRK5(nlsolve)
GeneralMIRK6(; nlsolve = DEFAULT_NLSOLVE_MIRK) = GeneralMIRK6(nlsolve)
MIRK4(; nlsolve = DEFAULT_NLSOLVE_MIRK) = MIRK4(nlsolve)
MIRK5(; nlsolve = DEFAULT_NLSOLVE_MIRK) = MIRK5(nlsolve)
MIRK6(; nlsolve = DEFAULT_NLSOLVE_MIRK) = MIRK6(nlsolve)
10 changes: 10 additions & 0 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ function alg_cache(alg::Union{GeneralMIRK4, MIRK4}, S::BVPSystem{T, U}) where {T
MIRK4GeneralCache([similar(S.y[1]) for i in 1:(S.s)])
end

struct MIRK5GeneralCache{kType} <: GeneralMIRKCache
K::kType
end

@truncate_stacktrace MIRK5GeneralCache

function alg_cache(alg::Union{GeneralMIRK5, MIRK5}, S::BVPSystem{T, U}) where {T, U}
MIRK5GeneralCache([similar(S.y[1]) for i in 1:(S.s)])
end

struct MIRK6GeneralCache{kType} <: GeneralMIRKCache
K::kType
end
Expand Down
8 changes: 4 additions & 4 deletions src/collocation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ function BVPSystem(fun, bc, p, x, M::Integer, alg::Union{GeneralMIRK, MIRK})
order = alg_order(alg)
s = alg_stage(alg)
BVPSystem(order, M, N, fun, bc, p, s, x, y, vector_alloc(T, M, N),
vector_alloc(T, M, N),
eltype(y)(undef, M))
vector_alloc(T, M, N),
eltype(y)(undef, M))
end

# If user offers an intial guess
Expand All @@ -17,7 +17,7 @@ function BVPSystem(fun, bc, p, x, y, alg::Union{GeneralMIRK, MIRK})
order = alg_order(alg)
s = alg_stage(alg)
BVPSystem{T, U}(order, M, N, fun, bc, p, s, x, y, vector_alloc(T, M, N),
vector_alloc(T, M, N), eltype(y)(M))
vector_alloc(T, M, N), eltype(y)(M))
end

# Dispatch aware of eltype(x) != eltype(prob.u0)
Expand All @@ -28,7 +28,7 @@ function BVPSystem(prob::BVProblem, x, alg::Union{GeneralMIRK, MIRK})
order = alg_order(alg)
s = alg_stage(alg)
BVPSystem(order, M, N, prob.f, prob.bc, prob.p, s, x, y, deepcopy(y),
deepcopy(y), typeof(x)(undef, M))
deepcopy(y), typeof(x)(undef, M))
end

# Auxiliary functions for evaluation
Expand Down
2 changes: 1 addition & 1 deletion src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ end
(jw::BVPJacobianWrapper)(u, p) = (resid = similar(u); jw.loss(resid, u, p); resid)

function _construct_nonlinear_problem_with_jacobian(f!::BVPJacobianWrapper, S::BVPSystem,
y, p)
y, p)
jac_cache = FiniteDiff.JacobianCache(similar(y), similar(y), similar(y))
function jac!(J, x, p)
F = jac_cache.fx
Expand Down
37 changes: 26 additions & 11 deletions src/mirk_tableaus.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,42 @@ function constructMIRK_IV(S::BVPSystem{T, U}) where {T, U}
v = [0, 1, 1 // 2, 27 // 32]
b = [1 // 6, 1 // 6, 2 // 3, 0]
x = [0 0 0 0
0 0 0 0
1//8 -1//8 0 0
3//64 -9//64 0 0]
0 0 0 0
1//8 -1//8 0 0
3//64 -9//64 0 0]
MIRKTableau(T.(c), T.(v), T.(b), T.(x))
end

MIRK_dispatcher(S::BVPSystem, ::Type{Val{4}}) = constructMIRK_IV(S)

function constructMIRK_V(S::BVPSystem{T, U}) where {T, U}
c = [0, 1, 3 // 4, 3 // 10, 4 // 5, 13 // 23]
v = [0, 1, 27 // 32, 837 // 1250, 4 // 5, 13 // 23]
b = [5 // 54, 1 // 14, 32 // 81, 250 // 567]
x = [0 0 0 0 0 0
0 0 0 0 0 0
3//64 -9//64 0 0 0 0
21//1000 63//5000 -252//625 0 0 0
14//1125 -74//875 -128//3375 104//945 0 0
1//2 4508233//1958887 48720832//2518569 -27646420//17629983 -11517095//559682 0]
MIRKTableau(T.(c), T.(v), T.(b), T.(x))
end

MIRK_dispatcher(S::BVPSystem, ::Type{Val{5}}) = constructMIRK_V(S)

function constructMIRK_VI(S::BVPSystem{T, U}) where {T, U}
c = [0, 1, 1 // 4, 3 // 4, 1 // 2, 7 // 16, 1 // 8, 9 // 16, 3 // 8]
v = [0, 1, 5 // 32, 27 // 32, 1 // 2, 7 // 16, 1 // 8, 9 // 16, 3 // 8]
b = [7 // 90, 7 // 90, 16 // 45, 16 // 45, 2 // 15, 0, 0, 0, 0]
x = [0 0 0 0 0 0 0 0 0
0 0 0 0 0 0 0 0 0
9//64 -3//64 0 0 0 0 0 0 0
3//64 -9//64 0 0 0 0 0 0 0
-5//24 5//24 2//3 -2//3 0 0 0 0 0
1547//32768 -1225//32768 749//4096 -287//2048 -861//16384 0 0 0 0
83//1536 -13//384 283//1536 -167//1536 -49//512 0 0 0 0
1225//32768 -1547//32768 287//2048 -749//4096 861//16384 0 0 0 0
233//3456 -19//1152 0 0 0 -5//72 7//72 -17//216 0]
0 0 0 0 0 0 0 0 0
9//64 -3//64 0 0 0 0 0 0 0
3//64 -9//64 0 0 0 0 0 0 0
-5//24 5//24 2//3 -2//3 0 0 0 0 0
1547//32768 -1225//32768 749//4096 -287//2048 -861//16384 0 0 0 0
83//1536 -13//384 283//1536 -167//1536 -49//512 0 0 0 0
1225//32768 -1547//32768 287//2048 -749//4096 861//16384 0 0 0 0
233//3456 -19//1152 0 0 0 -5//72 7//72 -17//216 0]
MIRKTableau(T.(c), T.(v), T.(b), T.(x))
end

Expand Down
8 changes: 4 additions & 4 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@ function DiffEqBase.__solve(prob::BVProblem, alg::Shooting; kwargs...)
return nothing
end
opt = solve(NonlinearProblem(NonlinearFunction{true}(loss!), u0, prob.p), alg.nlsolve;
kwargs...)
kwargs...)
sol_prob = ODEProblem{iip}(prob.f, opt.u, prob.tspan, prob.p)
sol = solve(sol_prob, alg.ode_alg; kwargs...)
return DiffEqBase.solution_new_retcode(sol,
sol.retcode == opt.retcode ? ReturnCode.Success :
ReturnCode.Failure)
sol.retcode == opt.retcode ? ReturnCode.Success :
ReturnCode.Failure)
end

function DiffEqBase.__solve(prob::BVProblem, alg::Union{GeneralMIRK, MIRK}; dt = 0.0,
kwargs...)
kwargs...)
dt 0 && throw(ArgumentError("dt must be positive"))
n = Int(cld((prob.tspan[2] - prob.tspan[1]), dt))
x = collect(range(prob.tspan[1], stop = prob.tspan[2], length = n + 1))
Expand Down
4 changes: 2 additions & 2 deletions src/vector_auxiliary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function vector_alloc(u0, x)
end

function flatten_vector!(dest::T,
src::Vector{T2}) where {T <: AbstractArray, T2 <: AbstractArray}
src::Vector{T2}) where {T <: AbstractArray, T2 <: AbstractArray}
N = length(src)
M = length(src[1])
for i in eachindex(src)
Expand All @@ -26,7 +26,7 @@ function flatten_vector!(dest::T,
end

function nest_vector!(dest::Vector{T},
src::T2) where {T <: AbstractArray, T2 <: AbstractArray}
src::T2) where {T <: AbstractArray, T2 <: AbstractArray}
M = length(dest[1])
for i in eachindex(dest)
copyto!(dest[i], src[((M * (i - 1)) + 1):(M * i)])
Expand Down
2 changes: 2 additions & 0 deletions test/ensemble.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ p = [rand()]
bvp = BVProblem(ode!, bc!, initial_guess, tspan, p)
ensemble_prob = EnsembleProblem(bvp, prob_func = prob_func)
@test_nowarn sim = solve(ensemble_prob, GeneralMIRK4(), trajectories = 10, dt = 0.1)
@test_nowarn sim = solve(ensemble_prob, GeneralMIRK5(), trajectories = 10, dt = 0.1)
@test_nowarn sim = solve(ensemble_prob, GeneralMIRK5(), trajectories = 10, dt = 0.1)
25 changes: 23 additions & 2 deletions test/mirk_convergence_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ end
# Not able to change the initial condition.
# Hard coded solution.
func_2 = ODEFunction(func_2!,
analytic = (u0, p, t) -> [5 * (cos(t) - cot(5) * sin(t)),
5 * (-cos(t) * cot(5) - sin(t))])
analytic = (u0, p, t) -> [5 * (cos(t) - cot(5) * sin(t)),
5 * (-cos(t) * cot(5) - sin(t))])
tspan = (0.0, 5.0)
u0 = [5.0, -3.5]
probArr = [
Expand All @@ -57,6 +57,11 @@ prob = probArr[1]
@time sol = solve(prob, GeneralMIRK4(), dt = 0.2)
@test norm(diff(first.(sol.u)) .+ 0.2, Inf) + abs(sol[1][1] - 5) < affineTol

# GeneralMIRK5

@time sol = solve(prob, GeneralMIRK5(), dt = 0.2)
@test norm(diff(first.(sol.u)) .+ 0.2, Inf) + abs(sol[1][1] - 5) < affineTol

# GeneralMIRK6

@time sol = solve(prob, GeneralMIRK6(), dt = 0.2)
Expand All @@ -70,6 +75,11 @@ prob = probArr[2]
@time sim = test_convergence(dts, prob, GeneralMIRK4(); abstol = 1e-13, reltol = 1e-13);
@test sim.𝒪est[:final]4 atol=testTol

# GeneralMIRK5

@time sim = test_convergence(dts, prob, GeneralMIRK5(); abstol = 1e-13, reltol = 1e-13);
@test sim.𝒪est[:final]5 atol=testTol

# GeneralMIRK6

@time sim = test_convergence(dts, prob, GeneralMIRK6(); abstol = 1e-13, reltol = 1e-13);
Expand All @@ -84,6 +94,11 @@ prob = probArr[3]
@time sol = solve(prob, MIRK4(), dt = 0.2)
@test norm(diff(map(x -> x[1], sol.u)) .+ 0.2, Inf) .+ abs(sol[1][1] - 5) < affineTol

# MIRK5

@time sol = solve(prob, MIRK5(), dt = 0.2)
@test norm(diff(map(x -> x[1], sol.u)) .+ 0.2, Inf) .+ abs(sol[1][1] - 5) < affineTol

# MIRK6

@time sol = solve(prob, MIRK6(), dt = 0.2)
Expand All @@ -97,6 +112,11 @@ prob = probArr[4]
@time sim = test_convergence(dts, prob, MIRK4(); abstol = 1e-13, reltol = 1e-13);
@test sim.𝒪est[:final]4 atol=testTol

# MIRK5

@time sim = test_convergence(dts, prob, MIRK5(); abstol = 1e-13, reltol = 1e-13);
@test sim.𝒪est[:final]5 atol=testTol

# MIRK6

@time sim = test_convergence(dts, prob, MIRK6(); abstol = 1e-13, reltol = 1e-13);
Expand All @@ -121,4 +141,5 @@ end
u0 = MVector{2}([pi / 2, pi / 2])
bvp1 = BVProblem(simplependulum!, bc1!, u0, tspan)
@test_nowarn solve(bvp1, GeneralMIRK4(), dt = 0.05)
@test_nowarn solve(bvp1, GeneralMIRK5(), dt = 0.05)
@test_nowarn solve(bvp1, GeneralMIRK6(), dt = 0.05)
24 changes: 12 additions & 12 deletions test/orbital.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,27 +60,27 @@ TestTol = 0.05
bvp = BVProblem(orbital, cur_bc!, y0, tspan)
nlsolve = NewtonRaphson(; autodiff = Val(false), diff_type = Val(:central))
@time sol = solve(bvp,
Shooting(DP5(); nlsolve),
force_dtmin = true,
abstol = 1e-13,
reltol = 1e-13)
Shooting(DP5(); nlsolve),
force_dtmin = true,
abstol = 1e-13,
reltol = 1e-13)
cur_bc!(resid_f, sol, nothing, sol.t)
@test norm(resid_f, Inf) < TestTol

nlsolve = NewtonRaphson(; autodiff = Val(false), diff_type = Val(:forward))
@time sol = solve(bvp,
Shooting(DP5(); nlsolve),
force_dtmin = true,
abstol = 1e-13,
reltol = 1e-13)
Shooting(DP5(); nlsolve),
force_dtmin = true,
abstol = 1e-13,
reltol = 1e-13)
cur_bc!(resid_f, sol, nothing, sol.t)
@test norm(resid_f, Inf) < TestTol

nlsolve = NewtonRaphson(; autodiff = Val(true))
@time sol = solve(bvp,
Shooting(DP5(); nlsolve),
force_dtmin = true,
abstol = 1e-13,
reltol = 1e-13)
Shooting(DP5(); nlsolve),
force_dtmin = true,
abstol = 1e-13,
reltol = 1e-13)
cur_bc!(resid_f, sol, nothing, sol.t)
@test norm(resid_f, Inf) < TestTol
20 changes: 15 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,23 @@ using Test, SafeTestsets

@testset "Boundary Value Problem Tests" begin
@time @testset "Shooting Method Tests" begin
@time @safetestset "Shooting Tests" begin include("shooting_tests.jl") end
@time @safetestset "Orbital" begin include("orbital.jl") end
@time @safetestset "Shooting Tests" begin
include("shooting_tests.jl")
end
@time @safetestset "Orbital" begin
include("orbital.jl")
end
end

@time @testset "Collocation Method (MIRK) Tests" begin
@time @safetestset "Ensemble" begin include("ensemble.jl") end
@time @safetestset "MIRK Convergence Tests" begin include("mirk_convergence_tests.jl") end
@time @safetestset "Vector of Vector" begin include("vectorofvector_initials.jl") end
@time @safetestset "Ensemble" begin
include("ensemble.jl")
end
@time @safetestset "MIRK Convergence Tests" begin
include("mirk_convergence_tests.jl")
end
@time @safetestset "Vector of Vector" begin
include("vectorofvector_initials.jl")
end
end
end

0 comments on commit ec3a956

Please sign in to comment.