Skip to content

Commit

Permalink
base:nlsolver_unified Adapt to DiffEqBase changes
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Aug 29, 2019
1 parent 8f94512 commit 148cbad
Show file tree
Hide file tree
Showing 20 changed files with 454 additions and 470 deletions.
6 changes: 3 additions & 3 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ module OrdinaryDiffEq

using DiffEqBase: check_error!, @def, @.. , _vec, _reshape

using DiffEqBase: nlsolvefail, isnewton, set_new_W!, get_W, iipnlsolve, oopnlsolve
using DiffEqBase: nlsolvefail, isnewton, set_new_W!, get_W, get_linsolve, build_nlsolver, nlsolve!

using DiffEqBase: NLSolver

using DiffEqBase: FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence
using DiffEqBase: FastConvergence, Convergence, SlowConvergence, VerySlowConvergence, Divergence, MaxItersReached

import DiffEqBase: calculate_residuals, calculate_residuals!, nlsolve_f, unwrap_cache, @tight_loop_macros, islinear

import DiffEqBase: iip_get_uf, oop_get_uf, build_jac_config
import DiffEqBase: build_jac_config

import SparseDiffTools: forwarddiff_color_jacobian!, ForwardColorJacCache

Expand Down
12 changes: 4 additions & 8 deletions src/caches/adams_bashforth_moulton_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -905,8 +905,7 @@ end

function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
γ, c = 1//2, 1
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))

k2 = rate_prototype
uprev3 = u
Expand All @@ -917,8 +916,7 @@ end

function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
γ, c = 1//2, 1
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

k1 = zero(rate_prototype)
Expand Down Expand Up @@ -955,8 +953,7 @@ end

function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
γ, c = 1//1, 1
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))

k2 = rate_prototype
uprev2 = u
Expand All @@ -968,8 +965,7 @@ end

function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
γ, c = 1//1, 1
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

k1 = zero(rate_prototype)
Expand Down
36 changes: 12 additions & 24 deletions src/caches/bdf_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ end
function alg_cache(alg::ABDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,
uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
γ, c = 2//3, 1
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))
eulercache = ImplicitEulerConstantCache(nlsolver)

dtₙ₋₁ = one(dt)
Expand All @@ -34,8 +33,7 @@ end
function alg_cache(alg::ABDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
γ, c = 2//3, 1
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

fsalfirstprev = zero(rate_prototype)
Expand Down Expand Up @@ -84,8 +82,7 @@ end

function alg_cache(alg::SBDF,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
γ, c = 1//1, 1
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))

k2 = rate_prototype
k₁ = rate_prototype; k₂ = rate_prototype; k₃ = rate_prototype
Expand All @@ -98,8 +95,7 @@ end

function alg_cache(alg::SBDF,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
γ, c = 1//1, 1
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

order = alg.order
Expand Down Expand Up @@ -144,8 +140,7 @@ end

function alg_cache(alg::QNDF1,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
γ, c = zero(inv((1-alg.kappa))), 1
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))

uprev2 = u
dtₙ₋₁ = t
Expand All @@ -162,8 +157,7 @@ end

function alg_cache(alg::QNDF1,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
γ, c = zero(inv((1-alg.kappa))), 1
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

D = Array{typeof(u)}(undef, 1, 1)
Expand Down Expand Up @@ -215,8 +209,7 @@ end

function alg_cache(alg::QNDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
γ, c = zero(inv((1-alg.kappa))), 1
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))

uprev2 = u
uprev3 = u
Expand All @@ -235,8 +228,7 @@ end

function alg_cache(alg::QNDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
γ, c = zero(inv((1-alg.kappa))), 1
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

D = Array{typeof(u)}(undef, 1, 2)
Expand Down Expand Up @@ -292,8 +284,7 @@ end

function alg_cache(alg::QNDF,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
γ, c = one(eltype(alg.kappa)), 1
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))

udiff = fill(zero(u), 1, 6)
dts = fill(zero(dt), 1, 6)
Expand All @@ -311,8 +302,7 @@ end

function alg_cache(alg::QNDF,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
γ, c = one(eltype(alg.kappa)), 1
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

udiff = Array{typeof(u)}(undef, 1, 6)
Expand Down Expand Up @@ -357,8 +347,7 @@ end
function alg_cache(alg::MEBDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
γ, c = 1, 1
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); tmp2 = zero(u)
Expand All @@ -374,7 +363,6 @@ end
function alg_cache(alg::MEBDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
γ, c = 1, 1
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))
MEBDF2ConstantCache(nlsolver)
end
30 changes: 10 additions & 20 deletions src/caches/kencarp_kvaerno_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@ function alg_cache(alg::KenCarp3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
tab = KenCarp3Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits))
γ, c = tab.γ, tab.c3
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))

KenCarp3ConstantCache(nlsolver,tab)
end
Expand All @@ -34,8 +33,7 @@ function alg_cache(alg::KenCarp3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
tab = KenCarp3Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits))
γ, c = tab.γ, tab.c3
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

if typeof(f) <: SplitFunction
Expand All @@ -62,8 +60,7 @@ function alg_cache(alg::Kvaerno4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
tab = Kvaerno4Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits))
γ, c = tab.γ, tab.c3
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))
Kvaerno4ConstantCache(nlsolver,tab)
end

Expand All @@ -85,8 +82,7 @@ function alg_cache(alg::Kvaerno4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
tab = Kvaerno4Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits))
γ, c = tab.γ, tab.c3
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = zero(u); z₅ = nlsolver.z
Expand All @@ -104,8 +100,7 @@ function alg_cache(alg::KenCarp4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
tab = KenCarp4Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits))
γ, c = tab.γ, tab.c3
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))
KenCarp4ConstantCache(nlsolver,tab)
end

Expand Down Expand Up @@ -134,8 +129,7 @@ function alg_cache(alg::KenCarp4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
tab = KenCarp4Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits))
γ, c = tab.γ, tab.c3
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

if typeof(f) <: SplitFunction
Expand Down Expand Up @@ -165,8 +159,7 @@ function alg_cache(alg::Kvaerno5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
tab = Kvaerno5Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits))
γ, c = tab.γ, tab.c3
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))

Kvaerno5ConstantCache(nlsolver,tab)
end
Expand All @@ -191,8 +184,7 @@ function alg_cache(alg::Kvaerno5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
tab = Kvaerno5Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits))
γ, c = tab.γ, tab.c3
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = zero(u); z₅ = zero(u)
Expand All @@ -211,8 +203,7 @@ function alg_cache(alg::KenCarp5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
tab = KenCarp5Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits))
γ, c = tab.γ, tab.c3
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))

KenCarp5ConstantCache(nlsolver,tab)
end
Expand Down Expand Up @@ -246,8 +237,7 @@ function alg_cache(alg::KenCarp5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
tab = KenCarp5Tableau(real(uBottomEltypeNoUnits),real(tTypeNoUnits))
γ, c = tab.γ, tab.c3
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
fsalfirst = zero(rate_prototype)

if typeof(f) <: SplitFunction
Expand Down
18 changes: 6 additions & 12 deletions src/caches/pdirk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,11 @@ end
function alg_cache(alg::PDIRK44,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
γ, c = 1.0, 1.0
if alg.threading
J1, W1 = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver1 = iipnlsolve(alg,u,uprev,p,t,dt,f,W1,J1,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
J2, W2 = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver2 = iipnlsolve(alg,u,uprev,p,t,dt,f,W2,J2,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver1 = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
nlsolver2 = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
nlsolver = [nlsolver1, nlsolver2]
else
_J, _W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
_nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,_W,_J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
_nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))
nlsolver = [_nlsolver]
end
tab = PDIRK44Tableau(real(uBottomEltypeNoUnits), real(tTypeNoUnits))
Expand All @@ -67,14 +64,11 @@ end
function alg_cache(alg::PDIRK44,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
γ, c = 1.0, 1.0
if alg.threading
J1, W1 = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver1 = oopnlsolve(alg,u,uprev,p,t,dt,f,W1,J1,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
J2, W2 = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver2 = oopnlsolve(alg,u,uprev,p,t,dt,f,W2,J2,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver1 = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))
nlsolver2 = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))
nlsolver = [nlsolver1, nlsolver2]
else
_J, _W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
_nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,_W,_J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
_nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))
nlsolver = [_nlsolver]
end
tab = PDIRK44Tableau(real(uBottomEltypeNoUnits), real(tTypeNoUnits))
Expand Down
6 changes: 2 additions & 4 deletions src/caches/rkc_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,15 @@ end

function alg_cache(alg::IRKC,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{false})
γ, c = 1.0, 1.0
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(false))
zprev = u
du₁ = rate_prototype; du₂ = rate_prototype
IRKCConstantCache(50,zprev,nlsolver,du₁,du₂)
end

function alg_cache(alg::IRKC,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Val{true})
γ, c = 1.0, 1.0
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
nlsolver = build_nlsolver(alg,alg.nlsolve,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,f,t,dt,p,γ,c,Val(true))

gprev = similar(u)
gprev2 = similar(u)
Expand Down

0 comments on commit 148cbad

Please sign in to comment.