Skip to content

Commit

Permalink
Merge pull request #410 from MSeeker1340/expRK
Browse files Browse the repository at this point in the history
Fourth order EPIRK methods
  • Loading branch information
ChrisRackauckas committed Jul 4, 2018
2 parents 2ee5a8f + 461da92 commit 2207147
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ module OrdinaryDiffEq

export GenericIIF1, GenericIIF2

export LawsonEuler, NorsettEuler, ETD1, ETDRK2, ETDRK3, ETDRK4, HochOst4, Exp4, ETD2
export LawsonEuler, NorsettEuler, ETD1, ETDRK2, ETDRK3, ETDRK4, HochOst4, Exp4, EPIRK4s3A, EPIRK4s3B, ETD2

export SymplecticEuler, VelocityVerlet, VerletLeapfrog, PseudoVerletLeapfrog,
McAte2, Ruth3, McAte3, CandyRoz4, McAte4, McAte42, McAte5,
Expand Down
2 changes: 2 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ alg_order(alg::ETDRK3) = 3
alg_order(alg::ETDRK4) = 4
alg_order(alg::HochOst4) = 4
alg_order(alg::Exp4) = 4
alg_order(alg::EPIRK4s3A) = 4
alg_order(alg::EPIRK4s3B) = 4
alg_order(alg::SplitEuler) = 1
alg_order(alg::ETD2) = 2

Expand Down
10 changes: 6 additions & 4 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -732,11 +732,13 @@ for Alg in [:LawsonEuler, :NorsettEuler, :ETDRK2, :ETDRK3, :ETDRK4, :HochOst4]
@eval Base.@pure $Alg(;krylov=false, m=30, iop=0) = $Alg(krylov, m, iop)
end
ETD1 = NorsettEuler # alias
struct Exp4 <: OrdinaryDiffEqExponentialAlgorithm
m::Int
iop::Int
for Alg in [:Exp4, :EPIRK4s3A, :EPIRK4s3B]
@eval struct $Alg <: OrdinaryDiffEqExponentialAlgorithm
m::Int
iop::Int
end
@eval Base.@pure $Alg(;m=30, iop=0) = $Alg(m, iop)
end
Base.@pure Exp4(;m=30, iop=0) = Exp4(m, iop)
struct SplitEuler <: OrdinaryDiffEqExponentialAlgorithm end
struct ETD2 <: OrdinaryDiffEqExponentialAlgorithm end

Expand Down
65 changes: 62 additions & 3 deletions src/caches/linear_nonlinear_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,6 @@ struct Exp4Cache{uType,rateType,matType,KsType} <: ExpRKCache
tmp::uType
rtmp::rateType
rtmp2::rateType
k7::rateType
K::matType
A::matType
B::matType
Expand All @@ -352,7 +351,7 @@ end
function alg_cache(alg::Exp4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
tmp = similar(u) # uType caches
rtmp, rtmp2, k7 = (zeros(rate_prototype) for i = 1:3) # rateType caches
rtmp, rtmp2 = (zeros(rate_prototype) for i = 1:2) # rateType caches
# Allocate matrices
# TODO: units
n = length(u); T = eltype(u)
Expand All @@ -362,7 +361,67 @@ function alg_cache(alg::Exp4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnit
# Allocate caches for phiv_timestep
maxiter = min(alg.m, n)
KsCache = _phiv_timestep_caches(u, maxiter, 1)
Exp4Cache(u,uprev,tmp,rtmp,rtmp2,k7,K,A,B,KsCache)
Exp4Cache(u,uprev,tmp,rtmp,rtmp2,K,A,B,KsCache)
end

struct EPIRK4s3AConstantCache <: ExpRKConstantCache end
alg_cache(alg::EPIRK4s3A,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,
uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}}) = EPIRK4s3AConstantCache()

struct EPIRK4s3ACache{uType,rateType,matType,KsType} <: ExpRKCache
u::uType
uprev::uType
tmp::uType
rtmp::rateType
rtmp2::rateType
K::matType
A::matType
B::matType
KsCache::KsType
end
function alg_cache(alg::EPIRK4s3A,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
tmp = similar(u) # uType caches
rtmp, rtmp2 = (zeros(rate_prototype) for i = 1:2) # rateType caches
# Allocate matrices
n = length(u); T = eltype(u)
K = Matrix{T}(n, 2)
A = Matrix{T}(n, n) # TODO: sparse Jacobian support
B = zeros(T, n, 5)
# Allocate caches for phiv_timestep
maxiter = min(alg.m, n)
KsCache = _phiv_timestep_caches(u, maxiter, 4)
EPIRK4s3ACache(u,uprev,tmp,rtmp,rtmp2,K,A,B,KsCache)
end

struct EPIRK4s3BConstantCache <: ExpRKConstantCache end
alg_cache(alg::EPIRK4s3B,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,
uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}}) = EPIRK4s3BConstantCache()

struct EPIRK4s3BCache{uType,rateType,matType,KsType} <: ExpRKCache
u::uType
uprev::uType
tmp::uType
rtmp::rateType
rtmp2::rateType
K::matType
A::matType
B::matType
KsCache::KsType
end
function alg_cache(alg::EPIRK4s3B,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
tmp = similar(u) # uType caches
rtmp, rtmp2 = (zeros(rate_prototype) for i = 1:2) # rateType caches
# Allocate matrices
n = length(u); T = eltype(u)
K = Matrix{T}(n, 2)
A = Matrix{T}(n, n) # TODO: sparse Jacobian support
B = zeros(T, n, 5)
# Allocate caches for phiv_timestep
maxiter = min(alg.m, n)
KsCache = _phiv_timestep_caches(u, maxiter, 4)
EPIRK4s3BCache(u,uprev,tmp,rtmp,rtmp2,K,A,B,KsCache)
end

####################################
Expand Down
17 changes: 10 additions & 7 deletions src/exponential_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,7 @@ end
Non-allocating version of `expv` that uses precomputed Krylov subspace `Ks`.
"""
function expv!(w::Vector{T}, t::Number, Ks::KrylovSubspace{B, T};
function expv!(w::AbstractVector{T}, t::Number, Ks::KrylovSubspace{B, T};
cache=nothing) where {B, T <: Number}
m, beta, V, H = Ks.m, Ks.beta, getV(Ks), getH(Ks)
@assert length(w) == size(V, 1) "Dimension mismatch"
Expand Down Expand Up @@ -451,7 +451,7 @@ end
Non-allocating version of 'phiv' that uses precomputed Krylov subspace `Ks`.
"""
function phiv!(w::Matrix{T}, t::Number, Ks::KrylovSubspace{B, T}, k::Integer;
function phiv!(w::AbstractMatrix{T}, t::Number, Ks::KrylovSubspace{B, T}, k::Integer;
cache=nothing, correct=false, errest=false) where {B, T <: Number}
m, beta, V, H = Ks.m, Ks.beta, getV(Ks), getH(Ks)
@assert size(w, 1) == size(V, 1) "Dimension mismatch"
Expand Down Expand Up @@ -529,12 +529,12 @@ end
Non-allocating version of `expv_timestep`.
"""
function expv_timestep!(u::Vector{T}, t::tType, A, b::Vector{T};
function expv_timestep!(u::AbstractVector{T}, t::tType, A, b::AbstractVector{T};
kwargs...) where {T <: Number, tType <: Real}
expv_timestep!(reshape(u, length(u), 1), [t], A, b; kwargs...)
return u
end
function expv_timestep!(U::Matrix{T}, ts::Vector{tType}, A, b::Vector{T};
function expv_timestep!(U::AbstractMatrix{T}, ts::Vector{tType}, A, b::AbstractVector{T};
kwargs...) where {T <: Number, tType <: Real}
B = reshape(b, length(b), 1)
phiv_timestep!(U, ts, A, B; kwargs...)
Expand Down Expand Up @@ -580,12 +580,12 @@ end
Non-allocating version of `phiv_timestep`.
"""
function phiv_timestep!(u::Vector{T}, t::tType, A, B::Matrix{T};
function phiv_timestep!(u::AbstractVector{T}, t::tType, A, B::AbstractMatrix{T};
kwargs...) where {T <: Number, tType <: Real}
phiv_timestep!(reshape(u, length(u), 1), [t], A, B; kwargs...)
return u
end
function phiv_timestep!(U::Matrix{T}, ts::Vector{tType}, A, B::Matrix{T}; tau::Real=0.0,
function phiv_timestep!(U::AbstractMatrix{T}, ts::Vector{tType}, A, B::AbstractMatrix{T}; tau::Real=0.0,
m::Int=min(10, size(A, 1)), tol::Real=1e-7, norm=Base.norm, iop::Int=0,
correct::Bool=false, caches=nothing, adaptive=false, delta::Real=1.2,
gamma::Real=0.8, NA::Int=0, verbose=false) where {T <: Number, tType <: Real}
Expand Down Expand Up @@ -613,7 +613,10 @@ function phiv_timestep!(U::Matrix{T}, ts::Vector{tType}, A, B::Matrix{T}; tau::R
phiv_cache = nothing # cache used by phiv!
else
u, W, P, Ks, phiv_cache = caches
@assert length(u) == n && size(W) == (n, p+1) && size(P) == (n, p+2) "Dimension mismatch"
@assert length(u) == n && size(W, 1) == n && size(P, 1) == n "Dimension mismatch"
# W and P may be bigger than actually needed
W = @view(W[:, 1:p+1])
P = @view(P[:, 1:p+2])
end
copy!(u, @view(B[:, 1])) # u(0) = b0
coeffs = ones(tType, p);
Expand Down
136 changes: 135 additions & 1 deletion src/perform_step/exponential_rk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ end

function perform_step!(integrator, cache::Exp4Cache, repeat_step=false)
@unpack t,dt,uprev,u,f,p = integrator
@unpack tmp,rtmp,rtmp2,k7,K,A,B,KsCache = cache
@unpack tmp,rtmp,rtmp2,K,A,B,KsCache = cache
f.jac(A, uprev, p, t)
alg = typeof(integrator.alg) <: CompositeAlgorithm ? integrator.alg.algs[integrator.cache.current] : integrator.alg
f0 = integrator.fsalfirst # f(u0) is fsaled
Expand Down Expand Up @@ -656,6 +656,7 @@ function perform_step!(integrator, cache::Exp4Cache, repeat_step=false)
A_mul_B!(rtmp, K, [1.0, -4/3, 1.0])
Base.axpy!(dt, rtmp, u)
# Krylov for the second remainder d7
k7 = @view(K[:, 1])
phiv_timestep!(k7, ts[1], A, B; kwargs...)
k7 ./= ts[1]
Base.axpy!(dt/6, k7, u)
Expand All @@ -665,6 +666,139 @@ function perform_step!(integrator, cache::Exp4Cache, repeat_step=false)
# integrator.k is automatically set due to aliasing
end

function perform_step!(integrator, cache::EPIRK4s3AConstantCache, repeat_step=false)
@unpack t,dt,uprev,f,p = integrator
A = f.jac(uprev, p, t)
alg = typeof(integrator.alg) <: CompositeAlgorithm ? integrator.alg.algs[integrator.cache.current] : integrator.alg
f0 = integrator.fsalfirst # f(uprev) is fsaled
kwargs = [(:tol, integrator.opts.reltol), (:iop, alg.iop), (:norm, integrator.opts.internalnorm), (:adaptive, true)]

# Compute U2 and U3 vertically
K = phiv_timestep([dt/2, 2dt/3], A, [zeros(f0) f0]; kwargs...)
U2 = uprev + K[:, 1]
U3 = uprev + K[:, 2]
R2 = f(U2, p, t + dt/2) - f0 - A*K[:, 1] # remainder of U2
R3 = f(U3, p, t + 2dt/3) - f0 - A*K[:, 2] # remainder of U3

# Update u (horizontally)
B = zeros(eltype(f0), length(f0), 5)
B[:, 2] = f0
B[:, 4] = (32R2 - 13.5R3) / dt^2
B[:, 5] = (-144R2 + 81R3) / dt^3
u = uprev + phiv_timestep(dt, A, B; kwargs...)

# Update integrator state
integrator.fsallast = f(u, p, t + dt)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
integrator.u = u
end

function perform_step!(integrator, cache::EPIRK4s3ACache, repeat_step=false)
@unpack t,dt,uprev,u,f,p = integrator
@unpack tmp,rtmp,rtmp2,K,A,B,KsCache = cache
f.jac(A, uprev, p, t)
alg = typeof(integrator.alg) <: CompositeAlgorithm ? integrator.alg.algs[integrator.cache.current] : integrator.alg
f0 = integrator.fsalfirst # f(u0) is fsaled
kwargs = [(:tol, integrator.opts.reltol), (:iop, alg.iop), (:norm, integrator.opts.internalnorm),
(:adaptive, true), (:caches, KsCache)]

# Compute U2 and U3 vertically
B[:, 2] .= f0
phiv_timestep!(K, [dt/2, 2dt/3], A, @view(B[:, 1:2]); kwargs...)
## U2 and R2
@. tmp = uprev + @view(K[:, 1]) # tmp is now U2
f(rtmp, tmp, p, t + dt/2); A_mul_B!(rtmp2, A, @view(K[:, 1]))
@. rtmp = rtmp - f0 - rtmp2 # rtmp is now R2
B[:, 4] .= (32/dt^2) * rtmp
B[:, 5] .= (-144/dt^3) * rtmp
## U3 and R3
@. tmp = uprev + @view(K[:, 2]) # tmp is now U3
f(rtmp, tmp, p, t + 2dt/3); A_mul_B!(rtmp2, A, @view(K[:, 2]))
@. rtmp = rtmp - f0 - rtmp2 # rtmp is now R3
B[:, 4] .-= (13.5/dt^2) * rtmp
B[:, 5] .+= (81/dt^3) * rtmp

# Update u
du = @view(K[:, 1])
phiv_timestep!(du, dt, A, B; kwargs...)
@. u = uprev + du

# Update integrator state
f(integrator.fsallast, u, p, t + dt)
# integrator.k is automatically set due to aliasing
end

function perform_step!(integrator, cache::EPIRK4s3BConstantCache, repeat_step=false)
@unpack t,dt,uprev,f,p = integrator
A = f.jac(uprev, p, t)
alg = typeof(integrator.alg) <: CompositeAlgorithm ? integrator.alg.algs[integrator.cache.current] : integrator.alg
f0 = integrator.fsalfirst # f(uprev) is fsaled
kwargs = [(:tol, integrator.opts.reltol), (:iop, alg.iop), (:norm, integrator.opts.internalnorm), (:adaptive, true)]

# Compute U2 and U3 vertically
K = phiv_timestep([dt/2, 3dt/4], A, [zeros(f0) zeros(f0) f0]; kwargs...)
K[:, 1] .*= 8 / (3*dt)
K[:, 2] .*= 16 / (9*dt)
U2 = uprev + K[:, 1]
U3 = uprev + K[:, 2]
R2 = f(U2, p, t + dt/2) - f0 - A*K[:, 1] # remainder of U2
R3 = f(U3, p, t + 3dt/4) - f0 - A*K[:, 2] # remainder of U3

# Update u (horizontally)
B = zeros(eltype(f0), length(f0), 5)
B[:, 2] = f0
B[:, 4] = (54R2 - 16R3) / dt^2
B[:, 5] = (-324R2 + 144R3) / dt^3
u = uprev + phiv_timestep(dt, A, B; kwargs...)

# Update integrator state
integrator.fsallast = f(u, p, t + dt)
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
integrator.u = u
end

function perform_step!(integrator, cache::EPIRK4s3BCache, repeat_step=false)
@unpack t,dt,uprev,u,f,p = integrator
@unpack tmp,rtmp,rtmp2,K,A,B,KsCache = cache
f.jac(A, uprev, p, t)
alg = typeof(integrator.alg) <: CompositeAlgorithm ? integrator.alg.algs[integrator.cache.current] : integrator.alg
f0 = integrator.fsalfirst # f(u0) is fsaled
kwargs = [(:tol, integrator.opts.reltol), (:iop, alg.iop), (:norm, integrator.opts.internalnorm),
(:adaptive, true), (:caches, KsCache)]

# Compute U2 and U3 vertically
fill!(@view(B[:, 2]), zero(eltype(B)))
B[:, 3] .= f0
phiv_timestep!(K, [dt/2, 3dt/4], A, @view(B[:, 1:3]); kwargs...)
K[:, 1] .*= 8 / (3*dt)
K[:, 2] .*= 16 / (9*dt)
## U2 and R2
@. tmp = uprev + @view(K[:, 1]) # tmp is now U2
f(rtmp, tmp, p, t + dt/2); A_mul_B!(rtmp2, A, @view(K[:, 1]))
@. rtmp = rtmp - f0 - rtmp2 # rtmp is now R2
B[:, 4] .= (54/dt^2) * rtmp
B[:, 5] .= (-324/dt^3) * rtmp
## U3 and R3
@. tmp = uprev + @view(K[:, 2]) # tmp is now U3
f(rtmp, tmp, p, t + 3dt/4); A_mul_B!(rtmp2, A, @view(K[:, 2]))
@. rtmp = rtmp - f0 - rtmp2 # rtmp is now R3
B[:, 4] .-= (16/dt^2) * rtmp
B[:, 5] .+= (144/dt^3) * rtmp

# Update u
fill!(@view(B[:, 3]), zero(eltype(B)))
B[:, 2] .= f0
du = @view(K[:, 1])
phiv_timestep!(du, dt, A, B; kwargs...)
@. u = uprev + du

# Update integrator state
f(integrator.fsallast, u, p, t + dt)
# integrator.k is automatically set due to aliasing
end

######################################################
# Multistep exponential integrators
function initialize!(integrator,cache::ETD2ConstantCache)
Expand Down
4 changes: 2 additions & 2 deletions test/linear_nonlinear_krylov_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ end
prob = ODEProblem(f, u0, (0.0, 1.0))
prob_ip = ODEProblem{true}(f_ip, u0, (0.0, 1.0))

dt = 0.1; tol=1e-5
Algs = [Exp4]
dt = 0.05; tol=1e-5
Algs = [Exp4, EPIRK4s3A, EPIRK4s3B]
for Alg in Algs
gc()
sol = solve(prob, Alg(); dt=dt, internalnorm=Base.norm, reltol=tol)
Expand Down

0 comments on commit 2207147

Please sign in to comment.