Skip to content

Commit

Permalink
Merge 20ad0c6 into 7ff1362
Browse files Browse the repository at this point in the history
  • Loading branch information
kanav99 committed Jul 19, 2019
2 parents 7ff1362 + 20ad0c6 commit 1680c8c
Show file tree
Hide file tree
Showing 7 changed files with 235 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ module OrdinaryDiffEq
include("caches/rkc_caches.jl")
include("caches/extrapolation_caches.jl")
include("caches/prk_caches.jl")
include("caches/pdirk_caches.jl")


include("alg_utils.jl")
Expand Down Expand Up @@ -130,6 +131,7 @@ module OrdinaryDiffEq
include("perform_step/rkc_perform_step.jl")
include("perform_step/extrapolation_perform_step.jl")
include("perform_step/prk_perform_step.jl")
include("perform_step/pdirk_perform_step.jl")

include("dense/generic_dense.jl")
include("dense/interpolants.jl")
Expand Down Expand Up @@ -235,5 +237,5 @@ module OrdinaryDiffEq

export AitkenNeville, ExtrapolationMidpointDeuflhard, ExtrapolationMidpointHairerWanner, ImplicitEulerExtrapolation

export KuttaPRK2p5
export KuttaPRK2p5, PDIRK44
end # module
2 changes: 2 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ isfsal(alg::DGLDDRK84_F) = false
isfsal(alg::NDBLSRK124) = false
isfsal(alg::NDBLSRK134) = false
isfsal(alg::NDBLSRK144) = false
isfsal(alg::PDIRK44) = false
get_current_isfsal(alg, cache) = isfsal(alg)
get_current_isfsal(alg::CompositeAlgorithm, cache) = isfsal(alg.algs[cache.current])

Expand Down Expand Up @@ -342,6 +343,7 @@ alg_order(alg::RKC) = 2
alg_order(alg::IRKC) = 2

alg_order(alg::MEBDF2) = 2
alg_order(alg::PDIRK44) = 4

alg_maximum_order(alg) = alg_order(alg)
alg_maximum_order(alg::CompositeAlgorithm) = maximum(alg_order(x) for x in alg.algs)
Expand Down
12 changes: 12 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -943,6 +943,18 @@ MEBDF2(;chunk_size=0,autodiff=true,diff_type=Val{:forward},

#################################################

struct PDIRK44{CS,AD,F,F2,FDT} <: OrdinaryDiffEqNewtonAlgorithm{CS,AD}
linsolve::F
nlsolve::F2
diff_type::FDT
extrapolant::Symbol
threading::Bool
end
PDIRK44(;chunk_size=0,autodiff=true,diff_type=Val{:forward},
linsolve=DEFAULT_LINSOLVE,nlsolve=NLNewton(),
extrapolant=:constant,threading=true) =
PDIRK44{chunk_size,autodiff,typeof(linsolve),typeof(nlsolve),typeof(diff_type)}(
linsolve,nlsolve,diff_type,extrapolant,threading)
### Algorithm Groups

const MultistepAlgorithms = Union{IRKN3,IRKN4,
Expand Down
82 changes: 82 additions & 0 deletions src/caches/pdirk_caches.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
@cache struct PDIRK44Cache{uType,rateType,N,TabType} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
k1::Array{rateType}
k2::Array{rateType}
nlsolver::N
tab::TabType
end

struct PDIRK44ConstantCache{N,TabType} <: OrdinaryDiffEqConstantCache
nlsolver::N
tab::TabType
end

struct PDIRK44Tableau{T,T2}
γs::SVector{2,T2}
cs::SVector{4,T2}
α1::SVector{2,T}
α2::SVector{2,T}
b1::T
b2::T
b3::T
b4::T
end

function PDIRK44Tableau(::Type{T}, ::Type{T2}) where {T,T2}
γ1 = convert(T2, 1//2)
γ2 = convert(T2, 2//3)
γs = SVector(γ1, γ2)
c1 = convert(T2, 1//2)
c2 = convert(T2, 2//3)
c3 = convert(T2, 1//2)
c4 = convert(T2, 1//3)
cs = SVector(c1,c2,c3,c4)
α11 = convert(T, -5//2)
α12 = convert(T, -5//3)
α1 = SVector(α11, α12)
α21 = convert(T, 5//2)
α22 = convert(T, 4//3)
α2 = SVector(α21, α22)
b1 = convert(T, -1//1)
b2 = convert(T, -1//1)
b3 = convert(T, 3//2)
b4 = convert(T, 3//2)
PDIRK44Tableau{T,T2}(γs,cs,α1,α2,b1,b2,b3,b4)
end

function alg_cache(alg::PDIRK44,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
γ, c = 1.0, 1.0
if alg.threading
J1, W1 = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver1 = iipnlsolve(alg,u,uprev,p,t,dt,f,W1,J1,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
J2, W2 = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver2 = iipnlsolve(alg,u,uprev,p,t,dt,f,W2,J2,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = [nlsolver1, nlsolver2]
else
_J, _W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
_nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,_W,_J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = [_nlsolver]
end
tab = PDIRK44Tableau(real(uBottomEltypeNoUnits), real(tTypeNoUnits))
k1 = [zero(rate_prototype) for i in 1:2 ]
k2 = [zero(rate_prototype) for i in 1:2 ]
PDIRK44Cache(u,uprev,k1,k2,nlsolver,tab)
end

function alg_cache(alg::PDIRK44,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
γ, c = 1.0, 1.0
if alg.threading
J1, W1 = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver1 = oopnlsolve(alg,u,uprev,p,t,dt,f,W1,J1,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
J2, W2 = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver2 = oopnlsolve(alg,u,uprev,p,t,dt,f,W2,J2,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = [nlsolver1, nlsolver2]
else
_J, _W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
_nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,_W,_J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = [_nlsolver]
end
tab = PDIRK44Tableau(real(uBottomEltypeNoUnits), real(tTypeNoUnits))
PDIRK44ConstantCache(nlsolver,tab)
end
1 change: 1 addition & 0 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -499,5 +499,6 @@ function update_W!(nlsolver::NLSolver, integrator, cache::OrdinaryDiffEqConstant
nothing
end


iip_get_uf(alg::OrdinaryDiffEqAlgorithm,nf,t,p) = DiffEqDiffTools.UJacobianWrapper(nf,t,p)
oop_get_uf(alg::OrdinaryDiffEqAlgorithm,nf,t,p) = DiffEqDiffTools.UDerivativeWrapper(nf,t,p)
1 change: 1 addition & 0 deletions src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ function reset_fsal!(integrator)
end

nlsolve!(integrator, cache) = DiffEqBase.nlsolve!(cache.nlsolver, cache.nlsolver.cache, integrator)
nlsolve!(nlsolver::NLSolver, integrator) = DiffEqBase.nlsolve!(nlsolver, nlsolver.cache, integrator)

DiffEqBase.nlsolve_f(f, alg::OrdinaryDiffEqAlgorithm) = f isa SplitFunction && issplit(alg) ? f.f1 : f
DiffEqBase.nlsolve_f(integrator::ODEIntegrator) =
Expand Down
134 changes: 134 additions & 0 deletions src/perform_step/pdirk_perform_step.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
function initialize!(integrator, cache::PDIRK44ConstantCache) end

@muladd function perform_step!(integrator, cache::PDIRK44ConstantCache, repeat_step=false)
@unpack dt,uprev,u = integrator
alg = unwrap_alg(integrator, true)
@unpack nlsolver, tab = cache
@unpack γs,cs,α1,α2,b1,b2,b3,b4 = tab

if alg.threading == true
k2 = Array{typeof(u)}(undef,2)
k1 = Array{typeof(u)}(undef,2)
let nlsolver=nlsolver, u=u, uprev=uprev, integrator=integrator, cache=cache, dt=dt, repeat_step=repeat_step,
k1=k1
Threads.@threads for i in 1:2
nlsolver[i].z = zero(u)
nlsolver[i].tmp = uprev
update_W!(nlsolver[i], integrator, cache, γs[i]*dt, repeat_step)
nlsolver[i].γ = γs[i]
nlsolver[i].c = cs[i]
k1[i] = nlsolve!(nlsolver[i], integrator)
end
end
nlsolvefail(nlsolver[1]) && return
nlsolvefail(nlsolver[2]) && return
let nlsolver=nlsolver, u=u, uprev=uprev, integrator=integrator, cache=cache, dt=dt, repeat_step=repeat_step,
k1=k1, k2=k2
Threads.@threads for i in 1:2
nlsolver[i].c = cs[2+i]
nlsolver[i].z = zero(u)
nlsolver[i].tmp = uprev + α1[i] * k1[1] + α2[i] * k1[2]
k2[i] = DiffEqBase.nlsolve!(nlsolver[i], nlsolver[i].cache, integrator)
end
end
nlsolvefail(nlsolver[1]) && return
nlsolvefail(nlsolver[2]) && return
integrator.u = uprev + b1 * k1[1] + b2 * k2[1] + b3 * k1[2] + b4 * k2[2]
else
_nlsolver = nlsolver[1]
_nlsolver.z = zero(u)
update_W!(_nlsolver, integrator, cache, γs[1]*dt, repeat_step)
_nlsolver.tmp = uprev
_nlsolver.γ = γs[1]
_nlsolver.c = cs[1]
k11 = nlsolve!(_nlsolver, integrator)
nlsolvefail(_nlsolver) && return
_nlsolver.z = zero(u)
update_W!(_nlsolver, integrator, cache, γs[2]*dt, repeat_step)
_nlsolver.tmp = uprev
_nlsolver.γ = γs[2]
_nlsolver.c = cs[2]
k12 = nlsolve!(_nlsolver, integrator)
nlsolvefail(_nlsolver) && return
_nlsolver.z = zero(u)
update_W!(_nlsolver, integrator, cache, γs[1]*dt, repeat_step)
_nlsolver.tmp = uprev + α1[1] * k11 + α2[1] * k12
_nlsolver.γ = γs[1]
_nlsolver.c = cs[3]
k21 = nlsolve!(_nlsolver, integrator)
nlsolvefail(_nlsolver) && return
_nlsolver.z = zero(u)
update_W!(_nlsolver, integrator, cache, γs[2]*dt, repeat_step)
_nlsolver.tmp = uprev + α1[2] * k11 + α2[2] * k12
_nlsolver.γ = γs[2]
_nlsolver.c = cs[4]
k22 = nlsolve!(_nlsolver, integrator)
nlsolvefail(_nlsolver) && return
integrator.u = uprev + b1 * k11 + b2 * k21 + b3 * k12 + b4 * k22
end
end

function initialize!(integrator, cache::PDIRK44Cache) end

@muladd function perform_step!(integrator, cache::PDIRK44Cache, repeat_step=false)
@unpack t,dt,uprev,u,f,p,alg = integrator
@unpack nlsolver,k1,k2,tab = cache
@unpack γs,cs,α1,α2,b1,b2,b3,b4 = tab
if alg.threading == true
let nlsolver=nlsolver, u=u, uprev=uprev, integrator=integrator, cache=cache, dt=dt, repeat_step=repeat_step,
k1=k1
Threads.@threads for i in 1:2
nlsolver[i].z .= zero(eltype(u))
nlsolver[i].tmp .= uprev
update_W!(nlsolver[i], integrator, cache, γs[i]*dt, repeat_step)
nlsolver[i].γ = γs[i]
nlsolver[i].c = cs[i]
k1[i] .= nlsolve!(nlsolver[i], integrator)
end
end
nlsolvefail(nlsolver[1]) && return
nlsolvefail(nlsolver[2]) && return
let nlsolver=nlsolver, u=u, uprev=uprev, integrator=integrator, cache=cache, dt=dt, repeat_step=repeat_step,
k1=k1, k2=k2
Threads.@threads for i in 1:2
nlsolver[i].c = cs[2+i]
nlsolver[i].z .= zero(eltype(u))
@.. nlsolver[i].tmp = uprev + α1[i] * k1[1] + α2[i] * k1[2]
k2[i] .= nlsolve!(nlsolver[i], integrator)
end
end
nlsolvefail(nlsolver[1]) && return
nlsolvefail(nlsolver[2]) && return
else
_nlsolver = nlsolver[1]
_nlsolver.z .= zero(eltype(u))
update_W!(_nlsolver, integrator, cache, γs[1]*dt, repeat_step)
_nlsolver.tmp .= uprev
_nlsolver.γ = γs[1]
_nlsolver.c = cs[1]
k1[1] .= nlsolve!(_nlsolver, integrator)
nlsolvefail(_nlsolver) && return
_nlsolver.z .= zero(eltype(u))
update_W!(_nlsolver, integrator, cache, γs[2]*dt, repeat_step)
_nlsolver.tmp .= uprev
_nlsolver.γ = γs[2]
_nlsolver.c = cs[2]
k1[2] .= nlsolve!(_nlsolver, integrator)
nlsolvefail(_nlsolver) && return
_nlsolver.z .= zero(eltype(u))
update_W!(_nlsolver, integrator, cache, γs[1]*dt, repeat_step)
@.. _nlsolver.tmp .= uprev + α1[1] * k1[1] + α2[1] * k1[2]
_nlsolver.γ = γs[1]
_nlsolver.c = cs[3]
k2[1] .= nlsolve!(_nlsolver, integrator)
nlsolvefail(_nlsolver) && return
_nlsolver.z .= zero(eltype(u))
update_W!(_nlsolver, integrator, cache, γs[2]*dt, repeat_step)
@.. _nlsolver.tmp = uprev + α1[2] * k1[1] + α2[2] * k1[2]
_nlsolver.γ = γs[2]
_nlsolver.c = cs[4]
k2[2] .= nlsolve!(_nlsolver, integrator)
nlsolvefail(_nlsolver) && return
end
@.. u = uprev + b1 * k1[1] + b2 * k2[1] + b3 * k1[2] + b4 * k2[2]
end

0 comments on commit 1680c8c

Please sign in to comment.