diff --git a/src/caches/kencarp_kvaerno_caches.jl b/src/caches/kencarp_kvaerno_caches.jl index 42d96f675e..08cbdae6f8 100644 --- a/src/caches/kencarp_kvaerno_caches.jl +++ b/src/caches/kencarp_kvaerno_caches.jl @@ -1,5 +1,4 @@ -@cache mutable struct KenCarp3ConstantCache{F,N,Tab} <: OrdinaryDiffEqConstantCache - uf::F +@cache mutable struct KenCarp3ConstantCache{N,Tab} <: OrdinaryDiffEqConstantCache nlsolver::N tab::Tab end @@ -10,17 +9,14 @@ function alg_cache(alg::KenCarp3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, tab.c3 J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - KenCarp3ConstantCache(uf,nlsolver,tab) + KenCarp3ConstantCache(nlsolver,tab) end -@cache mutable struct KenCarp3Cache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,N,Tab,F,kType} <: SDIRKMutableCache +@cache mutable struct KenCarp3Cache{uType,rateType,uNoUnitsType,N,Tab,kType} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType z₁::uType z₂::uType z₃::uType @@ -29,15 +25,7 @@ end k2::kType k3::kType k4::kType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N tab::Tab end @@ -48,7 +36,7 @@ function alg_cache(alg::KenCarp3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, tab.c3 J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) if typeof(f) <: SplitFunction k1 = similar(u); k2 = similar(u) @@ -59,15 +47,13 @@ function alg_cache(alg::KenCarp3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo uf = DiffEqDiffTools.UJacobianWrapper(f,t,p) end - z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = z + z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = nlsolver.z atmp = similar(u,uEltypeNoUnits) - KenCarp3Cache(u,uprev,du1,fsalfirst,k,z₁,z₂,z₃,z₄,k1,k2,k3,k4,dz,b,tmp,atmp,J, - W,uf,jac_config,linsolve,nlsolver,tab) + KenCarp3Cache(u,uprev,fsalfirst,z₁,z₂,z₃,z₄,k1,k2,k3,k4,atmp,nlsolver,tab) end -@cache mutable struct Kvaerno4ConstantCache{F,N,Tab} <: OrdinaryDiffEqConstantCache - uf::F +@cache mutable struct Kvaerno4ConstantCache{N,Tab} <: OrdinaryDiffEqConstantCache nlsolver::N tab::Tab end @@ -78,30 +64,19 @@ function alg_cache(alg::Kvaerno4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, tab.c3 J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields Kvaerno4ConstantCache(uf,nlsolver,tab) end -@cache mutable struct Kvaerno4Cache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,N,Tab,F} <: SDIRKMutableCache +@cache mutable struct Kvaerno4Cache{uType,rateType,uNoUnitsType,N,Tab} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType z₁::uType z₂::uType z₃::uType z₄::uType z₅::uType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N tab::Tab end @@ -112,17 +87,15 @@ function alg_cache(alg::Kvaerno4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, tab.c3 J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) - z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = zero(u); z₅ = z + z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = zero(u); z₅ = nlsolver.z atmp = similar(u,uEltypeNoUnits) - Kvaerno4Cache(u,uprev,du1,fsalfirst,k,z₁,z₂,z₃,z₄,z₅,dz,b,tmp,atmp,J, - W,uf,jac_config,linsolve,nlsolver,tab) + Kvaerno4Cache(u,uprev,fsalfirst,z₁,z₂,z₃,z₄,z₅,atmp,nlsolver,tab) end -@cache mutable struct KenCarp4ConstantCache{F,N,Tab} <: OrdinaryDiffEqConstantCache - uf::F +@cache mutable struct KenCarp4ConstantCache{N,Tab} <: OrdinaryDiffEqConstantCache nlsolver::N tab::Tab end @@ -133,16 +106,13 @@ function alg_cache(alg::KenCarp4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, tab.c3 J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - KenCarp4ConstantCache(uf,nlsolver,tab) + KenCarp4ConstantCache(nlsolver,tab) end -@cache mutable struct KenCarp4Cache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,N,Tab,F,kType} <: SDIRKMutableCache +@cache mutable struct KenCarp4Cache{uType,rateType,uNoUnitsType,N,Tab,kType} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType z₁::uType z₂::uType z₃::uType @@ -155,15 +125,7 @@ end k4::kType k5::kType k6::kType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N tab::Tab end @@ -174,7 +136,7 @@ function alg_cache(alg::KenCarp4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, tab.c3 J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) if typeof(f) <: SplitFunction k1 = similar(u); k2 = similar(u) @@ -188,16 +150,13 @@ function alg_cache(alg::KenCarp4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo end z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = zero(u); z₅ = zero(u) - z₆ = z + z₆ = nlsolver.z atmp = similar(u,uEltypeNoUnits) - KenCarp4Cache(u,uprev,du1,fsalfirst,k,z₁,z₂,z₃,z₄,z₅,z₆,k1,k2,k3,k4,k5,k6, - dz,b,tmp,atmp,J, - W,uf,jac_config,linsolve,nlsolver,tab) + KenCarp4Cache(u,uprev,fsalfirst,z₁,z₂,z₃,z₄,z₅,z₆,k1,k2,k3,k4,k5,k6,atmp,nlsolver,tab) end -@cache mutable struct Kvaerno5ConstantCache{F,N,Tab} <: OrdinaryDiffEqConstantCache - uf::F +@cache mutable struct Kvaerno5ConstantCache{N,Tab} <: OrdinaryDiffEqConstantCache nlsolver::N tab::Tab end @@ -208,15 +167,13 @@ function alg_cache(alg::Kvaerno5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, tab.c3 J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - Kvaerno5ConstantCache(uf,nlsolver,tab) + Kvaerno5ConstantCache(nlsolver,tab) end -@cache mutable struct Kvaerno5Cache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,N,Tab,F} <: SDIRKMutableCache +@cache mutable struct Kvaerno5Cache{uType,rateType,uNoUnitsType,N,Tab} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType k::rateType z₁::uType @@ -226,15 +183,7 @@ end z₅::uType z₆::uType z₇::uType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N tab::Tab end @@ -245,18 +194,16 @@ function alg_cache(alg::Kvaerno5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, tab.c3 J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = zero(u); z₅ = zero(u) - z₆ = zero(u); z₇ = z + z₆ = zero(u); z₇ = nlsolver.z atmp = similar(u,uEltypeNoUnits) - Kvaerno5Cache(u,uprev,du1,fsalfirst,k,z₁,z₂,z₃,z₄,z₅,z₆,z₇,dz,b,tmp,atmp,J, - W,uf,jac_config,linsolve,nlsolver,tab) + Kvaerno5Cache(u,uprev,fsalfirst,z₁,z₂,z₃,z₄,z₅,z₆,z₇,atmp,nlsolver,tab) end -@cache mutable struct KenCarp5ConstantCache{F,N,Tab} <: OrdinaryDiffEqConstantCache - uf::F +@cache mutable struct KenCarp5ConstantCache{N,Tab} <: OrdinaryDiffEqConstantCache nlsolver::N tab::Tab end @@ -267,17 +214,14 @@ function alg_cache(alg::KenCarp5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, tab.c3 J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - KenCarp5ConstantCache(uf,nlsolver,tab) + KenCarp5ConstantCache(nlsolver,tab) end -@cache mutable struct KenCarp5Cache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,N,Tab,F,kType} <: SDIRKMutableCache +@cache mutable struct KenCarp5Cache{uType,rateType,uNoUnitsType,N,Tab,kType} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType z₁::uType z₂::uType z₃::uType @@ -294,15 +238,7 @@ end k6::kType k7::kType k8::kType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N tab::Tab end @@ -313,7 +249,7 @@ function alg_cache(alg::KenCarp5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, tab.c3 J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) if typeof(f) <: SplitFunction k1 = similar(u); k2 = similar(u) @@ -325,15 +261,12 @@ function alg_cache(alg::KenCarp5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo k3 = nothing; k4 = nothing k5 = nothing; k6 = nothing k7 = nothing; k8 = nothing - uf = DiffEqDiffTools.UJacobianWrapper(f,t,p) end z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = zero(u) - z₅ = zero(u); z₆ = zero(u); z₇ = zero(u); z₈ = z + z₅ = zero(u); z₆ = zero(u); z₇ = zero(u); z₈ = nlsolver.z atmp = similar(u,uEltypeNoUnits) - KenCarp5Cache(u,uprev,du1,fsalfirst,k,z₁,z₂,z₃,z₄,z₅,z₆,z₇,z₈, - k1,k2,k3,k4,k5,k6,k7,k8, - dz,b,tmp,atmp,J, - W,uf,jac_config,linsolve,nlsolver,tab) + KenCarp5Cache(u,uprev,fsalfirst,z₁,z₂,z₃,z₄,z₅,z₆,z₇,z₈, + k1,k2,k3,k4,k5,k6,k7,k8,atmp,nlsolver,tab) end diff --git a/src/caches/sdirk_caches.jl b/src/caches/sdirk_caches.jl index d7dce182da..b9b3d57040 100644 --- a/src/caches/sdirk_caches.jl +++ b/src/caches/sdirk_caches.jl @@ -33,8 +33,7 @@ function alg_cache(alg::ImplicitEuler,u,rate_prototype,uEltypeNoUnits,uBottomElt ImplicitEulerConstantCache(nlsolver) end -mutable struct ImplicitMidpointConstantCache{F,N} <: OrdinaryDiffEqConstantCache - uf::F +mutable struct ImplicitMidpointConstantCache{N} <: OrdinaryDiffEqConstantCache nlsolver::N end @@ -42,25 +41,13 @@ function alg_cache(alg::ImplicitMidpoint,u,rate_prototype,uEltypeNoUnits,uBottom γ, c = 1//2, 1//2 J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - ImplicitMidpointConstantCache(uf,nlsolver) + ImplicitMidpointConstantCache(nlsolver) end -@cache mutable struct ImplicitMidpointCache{uType,rateType,JType,WType,UF,JC,F,N} <: SDIRKMutableCache +@cache mutable struct ImplicitMidpointCache{uType,rateType,N} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType - z::uType - dz::uType - b::uType - tmp::uType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N end @@ -69,12 +56,11 @@ function alg_cache(alg::ImplicitMidpoint,u,rate_prototype,uEltypeNoUnits,uBottom γ, c = 1//2, 1//2 J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields - ImplicitMidpointCache(u,uprev,du1,fsalfirst,k,z,dz,b,tmp,J,W,uf,jac_config,linsolve,nlsolver) + fsalfirst = zero(rate_prototype) + ImplicitMidpointCache(u,uprev,fsalfirst,nlsolver) end -mutable struct TrapezoidConstantCache{F,uType,tType,N} <: OrdinaryDiffEqConstantCache - uf::F +mutable struct TrapezoidConstantCache{uType,tType,N} <: OrdinaryDiffEqConstantCache uprev3::uType tprev2::tType nlsolver::N @@ -85,31 +71,19 @@ function alg_cache(alg::Trapezoid,u,rate_prototype,uEltypeNoUnits,uBottomEltypeN γ, c = 1//2, 1 J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields uprev3 = u tprev2 = t - TrapezoidConstantCache(uf,uprev3,tprev2,nlsolver) + TrapezoidConstantCache(uprev3,tprev2,nlsolver) end -@cache mutable struct TrapezoidCache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,tType,F,N} <: SDIRKMutableCache +@cache mutable struct TrapezoidCache{uType,rateType,uNoUnitsType,tType,N} <: SDIRKMutableCache u::uType uprev::uType uprev2::uType - du1::rateType fsalfirst::rateType - k::rateType - z::uType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F uprev3::uType tprev2::tType nlsolver::N @@ -120,17 +94,16 @@ function alg_cache(alg::Trapezoid,u,rate_prototype,uEltypeNoUnits,uBottomEltypeN γ, c = 1//2, 1 J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) uprev3 = zero(u) tprev2 = t atmp = similar(u,uEltypeNoUnits) - TrapezoidCache(u,uprev,uprev2,du1,fsalfirst,k,z,dz,b,tmp,atmp,J,W,uf,jac_config,linsolve,uprev3,tprev2,nlsolver) + TrapezoidCache(u,uprev,uprev2,fsalfirst,atmp,uprev3,tprev2,nlsolver) end -mutable struct TRBDF2ConstantCache{F,Tab,N} <: OrdinaryDiffEqConstantCache - uf::F +mutable struct TRBDF2ConstantCache{Tab,N} <: OrdinaryDiffEqConstantCache nlsolver::N tab::Tab end @@ -141,28 +114,16 @@ function alg_cache(alg::TRBDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn γ, c = tab.d, tab.γ J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - TRBDF2ConstantCache(uf,nlsolver,tab) + TRBDF2ConstantCache(nlsolver,tab) end -@cache mutable struct TRBDF2Cache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,Tab,F,N} <: SDIRKMutableCache +@cache mutable struct TRBDF2Cache{uType,rateType,uNoUnitsType,Tab,N} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType zprev::uType zᵧ::uType - z::uType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N tab::Tab end @@ -173,16 +134,14 @@ function alg_cache(alg::TRBDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn γ, c = tab.d, tab.γ J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) atmp = similar(u,uEltypeNoUnits); zprev = similar(u); zᵧ = similar(u) - TRBDF2Cache(u,uprev,du1,fsalfirst,k,zprev,zᵧ,z,dz,b,tmp,atmp,J, - W,uf,jac_config,linsolve,nlsolver,tab) + TRBDF2Cache(u,uprev,fsalfirst,zprev,zᵧ,atmp,nlsolver,tab) end -mutable struct SDIRK2ConstantCache{F,N} <: OrdinaryDiffEqConstantCache - uf::F +mutable struct SDIRK2ConstantCache{N} <: OrdinaryDiffEqConstantCache nlsolver::N end @@ -191,27 +150,16 @@ function alg_cache(alg::SDIRK2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn γ, c = 1, 1 J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - SDIRK2ConstantCache(uf,nlsolver) + SDIRK2ConstantCache(nlsolver) end -@cache mutable struct SDIRK2Cache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,F,N} <: SDIRKMutableCache +@cache mutable struct SDIRK2Cache{uType,rateType,uNoUnitsType,N} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType z₁::uType z₂::uType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N end @@ -220,17 +168,15 @@ function alg_cache(alg::SDIRK2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUn γ, c = 1, 1 J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) - z₁ = similar(u); z₂ = z + z₁ = similar(u); z₂ = nlsolver.z atmp = similar(u,uEltypeNoUnits) - SDIRK2Cache(u,uprev,du1,fsalfirst,k,z₁,z₂,dz,b,tmp,atmp,J, - W,uf,jac_config,linsolve,nlsolver) + SDIRK2Cache(u,uprev,fsalfirst,z₁,z₂,atmp,nlsolver) end -mutable struct SSPSDIRK2ConstantCache{F,N} <: OrdinaryDiffEqConstantCache - uf::F +mutable struct SSPSDIRK2ConstantCache{N} <: OrdinaryDiffEqConstantCache nlsolver::N end @@ -239,26 +185,15 @@ function alg_cache(alg::SSPSDIRK2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeN γ, c = 1//4, 1//1 J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - SSPSDIRK2ConstantCache(uf,nlsolver) + SSPSDIRK2ConstantCache(nlsolver) end -@cache mutable struct SSPSDIRK2Cache{uType,rateType,JType,WType,UF,JC,F,N} <: SDIRKMutableCache +@cache mutable struct SSPSDIRK2Cache{uType,rateType,N} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType z₁::uType z₂::uType - dz::uType - b::uType - tmp::uType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N end @@ -267,17 +202,15 @@ function alg_cache(alg::SSPSDIRK2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeN γ, c = 1//4, 1//1 J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) - z₁ = similar(u); z₂ = z + z₁ = similar(u); z₂ = nlsolver.z atmp = similar(u,uEltypeNoUnits) - SSPSDIRK2Cache(u,uprev,du1,fsalfirst,k,z₁,z₂,dz,b,tmp,J, - W,uf,jac_config,linsolve,nlsolver) + SSPSDIRK2Cache(u,uprev,fsalfirst,z₁,z₂,nlsolver) end -mutable struct Kvaerno3ConstantCache{UF,Tab,N} <: OrdinaryDiffEqConstantCache - uf::UF +mutable struct Kvaerno3ConstantCache{Tab,N} <: OrdinaryDiffEqConstantCache nlsolver::N tab::Tab end @@ -288,29 +221,18 @@ function alg_cache(alg::Kvaerno3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, 2tab.γ J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - Kvaerno3ConstantCache(uf,nlsolver,tab) + Kvaerno3ConstantCache(nlsolver,tab) end -@cache mutable struct Kvaerno3Cache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,Tab,F,N} <: SDIRKMutableCache +@cache mutable struct Kvaerno3Cache{uType,rateType,uNoUnitsType,Tab,N} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType z₁::uType z₂::uType z₃::uType z₄::uType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N tab::Tab end @@ -321,17 +243,15 @@ function alg_cache(alg::Kvaerno3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo γ, c = tab.γ, 2tab.γ J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) - z₁ = similar(u); z₂ = similar(u); z₃ = similar(u); z₄ = z + z₁ = similar(u); z₂ = similar(u); z₃ = similar(u); z₄ = nlsolver.z atmp = similar(u,uEltypeNoUnits) - Kvaerno3Cache(u,uprev,du1,fsalfirst,k,z₁,z₂,z₃,z₄,dz,b,tmp,atmp,J, - W,uf,jac_config,linsolve,nlsolver,tab) + Kvaerno3Cache(u,uprev,fsalfirst,z₁,z₂,z₃,z₄,atmp,nlsolver,tab) end -mutable struct Cash4ConstantCache{F,N,Tab} <: OrdinaryDiffEqConstantCache - uf::F +mutable struct Cash4ConstantCache{N,Tab} <: OrdinaryDiffEqConstantCache nlsolver::N tab::Tab end @@ -342,30 +262,19 @@ function alg_cache(alg::Cash4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni γ, c = tab.γ,tab.γ J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - Cash4ConstantCache(uf,nlsolver,tab) + Cash4ConstantCache(nlsolver,tab) end -@cache mutable struct Cash4Cache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,N,Tab,F} <: SDIRKMutableCache +@cache mutable struct Cash4Cache{uType,rateType,uNoUnitsType,N,Tab} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType z₁::uType z₂::uType z₃::uType z₄::uType z₅::uType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N tab::Tab end @@ -376,17 +285,15 @@ function alg_cache(alg::Cash4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni γ, c = tab.γ,tab.γ J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) - z₁ = similar(u); z₂ = similar(u); z₃ = similar(u); z₄ = similar(u); z₅ = z + z₁ = similar(u); z₂ = similar(u); z₃ = similar(u); z₄ = similar(u); z₅ = nlsolver.z atmp = similar(u,uEltypeNoUnits) - Cash4Cache(u,uprev,du1,fsalfirst,k,z₁,z₂,z₃,z₄,z₅,dz,b,tmp,atmp,J, - W,uf,jac_config,linsolve,nlsolver,tab) + Cash4Cache(u,uprev,fsalfirst,z₁,z₂,z₃,z₄,z₅,atmp,nlsolver,tab) end -mutable struct Hairer4ConstantCache{F,N,Tab} <: OrdinaryDiffEqConstantCache - uf::F +mutable struct Hairer4ConstantCache{N,Tab} <: OrdinaryDiffEqConstantCache nlsolver::N tab::Tab end @@ -401,30 +308,19 @@ function alg_cache(alg::Union{Hairer4,Hairer42},u,rate_prototype,uEltypeNoUnits, γ, c = tab.γ, tab.γ J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - Hairer4ConstantCache(uf,nlsolver,tab) + Hairer4ConstantCache(nlsolver,tab) end -@cache mutable struct Hairer4Cache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,Tab,F,N} <: SDIRKMutableCache +@cache mutable struct Hairer4Cache{uType,rateType,uNoUnitsType,Tab,N} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType z₁::uType z₂::uType z₃::uType z₄::uType z₅::uType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N tab::Tab end @@ -439,31 +335,20 @@ function alg_cache(alg::Union{Hairer4,Hairer42},u,rate_prototype,uEltypeNoUnits, γ, c = tab.γ, tab.γ J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) - z₁ = similar(u); z₂ = similar(u); z₃ = similar(u); z₄ = similar(u); z₅ = z + z₁ = similar(u); z₂ = similar(u); z₃ = similar(u); z₄ = similar(u); z₅ = nlsolver.z atmp = similar(u,uEltypeNoUnits) - Hairer4Cache(u,uprev,du1,fsalfirst,k,z₁,z₂,z₃,z₄,z₅,dz,b,tmp,atmp,J, - W,uf,jac_config,linsolve,nlsolver,tab) + Hairer4Cache(u,uprev,fsalfirst,z₁,z₂,z₃,z₄,z₅,atmp,nlsolver,tab) end -@cache mutable struct ESDIRK54I8L2SACache{uType,rateType,uNoUnitsType,JType,WType,UF,JC,Tab,F,N} <: SDIRKMutableCache +@cache mutable struct ESDIRK54I8L2SACache{uType,rateType,uNoUnitsType,Tab,N} <: SDIRKMutableCache u::uType uprev::uType - du1::rateType fsalfirst::rateType - k::rateType z₁::uType; z₂::uType; z₃::uType; z₄::uType; z₅::uType; z₆::uType; z₇::uType; z₈::uType - dz::uType - b::uType - tmp::uType atmp::uNoUnitsType - J::JType - W::WType - uf::UF - jac_config::JC - linsolve::F nlsolver::N tab::Tab end @@ -474,18 +359,16 @@ function alg_cache(alg::ESDIRK54I8L2SA,u,rate_prototype,uEltypeNoUnits,uBottomEl γ, c = tab.γ, tab.γ J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getiipnlsolvefields + fsalfirst = zero(rate_prototype) z₁ = zero(u); z₂ = zero(u); z₃ = zero(u); z₄ = zero(u) - z₅ = zero(u); z₆ = zero(u); z₇ = zero(u); z₈ = z + z₅ = zero(u); z₆ = zero(u); z₇ = zero(u); z₈ = nlsolver.z atmp = similar(u,uEltypeNoUnits) - ESDIRK54I8L2SACache(u,uprev,du1,fsalfirst,k,z₁,z₂,z₃,z₄,z₅,z₆,z₇,z₈,dz,b,tmp,atmp,J, - W,uf,jac_config,linsolve,nlsolver,tab) + ESDIRK54I8L2SACache(u,uprev,fsalfirst,z₁,z₂,z₃,z₄,z₅,z₆,z₇,z₈,atmp,nlsolver,tab) end -mutable struct ESDIRK54I8L2SAConstantCache{F,N,Tab} <: OrdinaryDiffEqConstantCache - uf::F +mutable struct ESDIRK54I8L2SAConstantCache{N,Tab} <: OrdinaryDiffEqConstantCache nlsolver::N tab::Tab end @@ -496,6 +379,5 @@ function alg_cache(alg::ESDIRK54I8L2SA,u,rate_prototype,uEltypeNoUnits,uBottomEl γ, c = tab.γ,tab.γ J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits) nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c) - @getoopnlsolvefields - ESDIRK54I8L2SAConstantCache(uf,nlsolver,tab) + ESDIRK54I8L2SAConstantCache(nlsolver,tab) end diff --git a/src/perform_step/kencarp_kvaerno_perform_step.jl b/src/perform_step/kencarp_kvaerno_perform_step.jl index 57fd94c9fb..067436377d 100644 --- a/src/perform_step/kencarp_kvaerno_perform_step.jl +++ b/src/perform_step/kencarp_kvaerno_perform_step.jl @@ -23,7 +23,7 @@ function initialize!(integrator, cache::Union{Kvaerno3Cache, KenCarp5Cache}) integrator.kshortsize = 2 integrator.fsalfirst = cache.fsalfirst - integrator.fsallast = cache.k + integrator.fsallast = cache.nlsolver.k resize!(integrator.k, integrator.kshortsize) integrator.k[1] = integrator.fsalfirst integrator.k[2] = integrator.fsallast @@ -96,7 +96,8 @@ end @muladd function perform_step!(integrator, cache::Kvaerno3Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack dz,z₁,z₂,z₃,z₄,k,b,tmp,atmp,nlsolver = cache + @unpack z₁,z₂,z₃,z₄,atmp,nlsolver = cache + @unpack dz,k,tmp = nlsolver @unpack γ,a31,a32,a41,a42,a43,btilde1,btilde2,btilde3,btilde4,c3,α31,α32 = cache.tab alg = unwrap_alg(integrator, true) @@ -154,7 +155,7 @@ end @.. dz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - cache.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) else tmp .= dz end @@ -289,7 +290,8 @@ end @muladd function perform_step!(integrator, cache::KenCarp3Cache, repeat_step=false) @unpack t,dt,uprev,u,p = integrator - @unpack dz,z₁,z₂,z₃,z₄,k1,k2,k3,k4,k,b,tmp,atmp,nlsolver = cache + @unpack z₁,z₂,z₃,z₄,k1,k2,k3,k4,atmp,nlsolver = cache + @unpack dz,k,tmp = nlsolver @unpack γ,a31,a32,a41,a42,a43,btilde1,btilde2,btilde3,btilde4,c3,α31,α32 = cache.tab @unpack ea21,ea31,ea32,ea41,ea42,ea43,eb1,eb2,eb3,eb4 = cache.tab @unpack ebtilde1,ebtilde2,ebtilde3,ebtilde4 = cache.tab @@ -405,7 +407,7 @@ end end if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - cache.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) else tmp .= dz end @@ -500,7 +502,8 @@ end @muladd function perform_step!(integrator, cache::Kvaerno4Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack dz,z₁,z₂,z₃,z₄,z₅,k,b,tmp,atmp,nlsolver = cache + @unpack z₁,z₂,z₃,z₄,z₅,atmp,nlsolver = cache + @unpack dz,k,tmp = nlsolver @unpack γ,a31,a32,a41,a42,a43,a51,a52,a53,a54,c3,c4 = cache.tab @unpack α21,α31,α32,α41,α42 = cache.tab @unpack btilde1,btilde2,btilde3,btilde4,btilde5 = cache.tab @@ -567,7 +570,7 @@ end @.. dz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - cache.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) else tmp .= dz end @@ -746,7 +749,8 @@ end @muladd function perform_step!(integrator, cache::KenCarp4Cache, repeat_step=false) @unpack t,dt,uprev,u,p = integrator - @unpack dz,z₁,z₂,z₃,z₄,z₅,z₆,k,b,tmp,atmp,nlsolver = cache + @unpack z₁,z₂,z₃,z₄,z₅,z₆,atmp,nlsolver = cache + @unpack dz,k,tmp = nlsolver @unpack k1,k2,k3,k4,k5,k6 = cache @unpack γ,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a63,a64,a65,c3,c4,c5 = cache.tab @unpack α31,α32,α41,α42,α51,α52,α53,α54,α61,α62,α63,α64,α65 = cache.tab @@ -924,7 +928,7 @@ end if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - cache.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) else tmp .= dz end @@ -1037,7 +1041,8 @@ end @muladd function perform_step!(integrator, cache::Kvaerno5Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack dz,z₁,z₂,z₃,z₄,z₅,z₆,z₇,k,b,tmp,atmp,nlsolver = cache + @unpack z₁,z₂,z₃,z₄,z₅,z₆,z₇,atmp,nlsolver = cache + @unpack dz,k,tmp = nlsolver @unpack γ,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a63,a64,a65,a71,a73,a74,a75,a76,c3,c4,c5,c6 = cache.tab @unpack btilde1,btilde3,btilde4,btilde5,btilde6,btilde7 = cache.tab @unpack α31,α32,α41,α42,α43,α51,α52,α53,α61,α62,α63 = cache.tab @@ -1133,7 +1138,7 @@ end end if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - cache.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) else tmp .= dz end @@ -1354,8 +1359,9 @@ end @muladd function perform_step!(integrator, cache::KenCarp5Cache, repeat_step=false) @unpack t,dt,uprev,u,p = integrator - @unpack dz,z₁,z₂,z₃,z₄,z₅,z₆,z₇,z₈,k,b,tmp,atmp,nlsolver = cache + @unpack z₁,z₂,z₃,z₄,z₅,z₆,z₇,z₈,atmp,nlsolver = cache @unpack k1,k2,k3,k4,k5,k6,k7,k8 = cache + @unpack dz,k,tmp = nlsolver @unpack γ,a31,a32,a41,a43,a51,a53,a54,a61,a63,a64,a65,a71,a73,a74,a75,a76,a81,a84,a85,a86,a87,c3,c4,c5,c6,c7 = cache.tab @unpack α31,α32,α41,α42,α51,α52,α61,α62,α71,α72,α73,α74,α75,α81,α82,α83,α84,α85 = cache.tab @unpack btilde1,btilde4,btilde5,btilde6,btilde7,btilde8 = cache.tab @@ -1584,7 +1590,7 @@ end if isnewton(nlsolver) && alg.smooth_est # From Shampine integrator.destats.nsolve += 1 - cache.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) + nlsolver.linsolve(vec(tmp),get_W(nlsolver),vec(dz),false) else tmp .= dz end diff --git a/src/perform_step/sdirk_perform_step.jl b/src/perform_step/sdirk_perform_step.jl index bf59053abd..efcee3e2e9 100644 --- a/src/perform_step/sdirk_perform_step.jl +++ b/src/perform_step/sdirk_perform_step.jl @@ -153,7 +153,8 @@ end @muladd function perform_step!(integrator, cache::ImplicitMidpointCache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack z,tmp,nlsolver = cache + @unpack nlsolver = cache + @unpack z,tmp = nlsolver mass_matrix = integrator.f.mass_matrix alg = unwrap_alg(integrator, true) γ = 1//2 @@ -240,7 +241,8 @@ end @muladd function perform_step!(integrator, cache::TrapezoidCache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack z,jac_config,tmp,atmp,nlsolver = cache + @unpack atmp,nlsolver = cache + @unpack z,jac_config,tmp = nlsolver alg = unwrap_alg(integrator, true) mass_matrix = integrator.f.mass_matrix @@ -356,7 +358,10 @@ end @muladd function perform_step!(integrator, cache::TRBDF2Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack zprev,dz,zᵧ,z,k,b,W,tmp,atmp,nlsolver = cache + @unpack zprev,zᵧ,atmp,nlsolver = cache + @unpack dz,z,k,tmp = nlsolver + W = isnewton(nlsolver) ? get_W(nlsolver) : nothing + b = nlsolver.ztmp @unpack γ,d,ω,btilde1,btilde2,btilde3,α1,α2 = cache.tab alg = unwrap_alg(integrator, true) @@ -393,9 +398,9 @@ end if integrator.opts.adaptive @.. dz = btilde1*zprev + btilde2*zᵧ + btilde3*z - if alg.smooth_est # From Shampine + if alg.smooth_est && isnewton(nlsolver) # From Shampine integrator.destats.nsolve += 1 - cache.linsolve(vec(tmp),W,vec(dz),false) + nlsolver.linsolve(vec(tmp),W,vec(dz),false) else tmp .= dz end @@ -458,7 +463,9 @@ end @muladd function perform_step!(integrator, cache::SDIRK2Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack dz,z₁,z₂,k,b,W,jac_config,tmp,atmp,nlsolver = cache + @unpack z₁,z₂,atmp,nlsolver = cache + @unpack dz,k,jac_config,tmp = nlsolver + W = isnewton(nlsolver) ? get_W(nlsolver) : nothing alg = unwrap_alg(integrator, true) update_W!(integrator, cache, dt, repeat_step) @@ -495,9 +502,9 @@ end if integrator.opts.adaptive @.. dz = z₁/2 - z₂/2 - if alg.smooth_est # From Shampine + if alg.smooth_est && isnewton(nlsolver) # From Shampine integrator.destats.nsolve += 1 - cache.linsolve(vec(tmp),W,vec(dz),false) + nlsolver.linsolve(vec(tmp),W,vec(dz),false) else tmp .= dz end @@ -564,7 +571,8 @@ end @muladd function perform_step!(integrator, cache::SSPSDIRK2Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack dz,z₁,z₂,k,b,W,jac_config,tmp,nlsolver = cache + @unpack z₁,z₂,nlsolver = cache + @unpack dz,k,jac_config,tmp = nlsolver alg = unwrap_alg(integrator, true) γ = eltype(u)(1//4) @@ -702,7 +710,9 @@ end @muladd function perform_step!(integrator, cache::Cash4Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack dz,z₁,z₂,z₃,z₄,z₅,k,b,W,tmp,atmp,nlsolver = cache + @unpack z₁,z₂,z₃,z₄,z₅,atmp,nlsolver = cache + @unpack dz,k,tmp = nlsolver + W = isnewton(nlsolver) ? get_W(nlsolver) : nothing @unpack γ,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,c2,c3,c4 = cache.tab @unpack b1hat1,b2hat1,b3hat1,b4hat1,b1hat2,b2hat2,b3hat2,b4hat2 = cache.tab alg = unwrap_alg(integrator, true) @@ -778,9 +788,9 @@ end end @.. dz = btilde1*z₁ + btilde2*z₂ + btilde3*z₃ + btilde4*z₄ + btilde5*z₅ - if alg.smooth_est # From Shampine + if alg.smooth_est && isnewton(nlsolver) # From Shampine integrator.destats.nsolve += 1 - cache.linsolve(vec(tmp),W,vec(dz),false) + nlsolver.linsolve(vec(tmp),W,vec(dz),false) else tmp .= dz end @@ -871,7 +881,9 @@ end @muladd function perform_step!(integrator, cache::Hairer4Cache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack dz,z₁,z₂,z₃,z₄,z₅,k,b,W,jac_config,tmp,atmp,nlsolver = cache + @unpack z₁,z₂,z₃,z₄,z₅,atmp,nlsolver = cache + @unpack dz,k,jac_config,tmp = nlsolver + W = isnewton(nlsolver) ? get_W(nlsolver) : nothing @unpack γ,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,c2,c3,c4 = cache.tab @unpack α21,α31,α32,α41,α43 = cache.tab @unpack bhat1,bhat2,bhat3,bhat4,btilde1,btilde2,btilde3,btilde4,btilde5 = cache.tab @@ -945,9 +957,9 @@ end @tight_loop_macros for i in eachindex(u) dz[i] = btilde1*z₁[i] + btilde2*z₂[i] + btilde3*z₃[i] + btilde4*z₄[i] + btilde5*z₅[i] end - if alg.smooth_est # From Shampine + if alg.smooth_est && isnewton(nlsolver) # From Shampine integrator.destats.nsolve += 1 - cache.linsolve(vec(tmp),W,vec(dz),false) + nlsolver.linsolve(vec(tmp),W,vec(dz),false) else tmp .= dz end @@ -1065,7 +1077,8 @@ end @muladd function perform_step!(integrator, cache::ESDIRK54I8L2SACache, repeat_step=false) @unpack t,dt,uprev,u,f,p = integrator - @unpack z₁,z₂,z₃,z₄,z₅,z₆,z₇,z₈,k,b,tmp,atmp,nlsolver = cache + @unpack z₁,z₂,z₃,z₄,z₅,z₆,z₇,z₈,atmp,nlsolver = cache + @unpack k,tmp = cache @unpack γ, a31, a32, a41, a42, a43,