Skip to content

Commit

Permalink
Merge pull request #830 from JuliaDiffEq/WJshift
Browse files Browse the repository at this point in the history
[WIP] Remove redundant NLSolve fields from alg cache
  • Loading branch information
ChrisRackauckas authored Jul 16, 2019
2 parents 6a8c91c + d10a9f5 commit 9876d03
Show file tree
Hide file tree
Showing 16 changed files with 416 additions and 607 deletions.
2 changes: 1 addition & 1 deletion src/bdf_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ end
# Implementation of an Adaptive BDF2 Formula and Comparison with the MATLAB Ode15s paper
# E. Alberdi Celaya, J. J. Anza Aguirrezabala, and P. Chatzipantelidis
function reinterpolate_history!(cache::OrdinaryDiffEqMutableCache, D, R, k)
@unpack tmp = cache
@unpack tmp = cache.nlsolver
fill!(tmp,zero(eltype(D[1])))
for j = 1:k
for k = 1:k
Expand Down
58 changes: 16 additions & 42 deletions src/caches/adams_bashforth_moulton_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -883,120 +883,94 @@ end

# CNAB2

@cache mutable struct CNAB2ConstantCache{rateType,F,N,uType,tType} <: OrdinaryDiffEqConstantCache
@cache mutable struct CNAB2ConstantCache{rateType,N,uType,tType} <: OrdinaryDiffEqConstantCache
k2::rateType
uf::F
nlsolver::N
uprev3::uType
tprev2::tType
end

@cache mutable struct CNAB2Cache{uType,rateType,JType,WType,UF,JC,N,tType,F} <: OrdinaryDiffEqMutableCache
@cache mutable struct CNAB2Cache{uType,rateType,N,tType} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
uprev2::uType
fsalfirst::rateType
k::rateType
k1::rateType
k2::rateType
du₁::rateType
du1::rateType
z::uType
dz::uType
b::uType
tmp::uType
J::JType
W::WType
uf::UF
jac_config::JC
linsolve::F
nlsolver::N
uprev3::uType
tprev2::tType
end

function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
γ, c = 1//2, 1
W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
@getoopnlsolvefields
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)

k2 = rate_prototype
uprev3 = u
tprev2 = t

CNAB2ConstantCache(k2,uf,nlsolver,uprev3,tprev2)
CNAB2ConstantCache(k2,nlsolver,uprev3,tprev2)
end

function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
γ, c = 1//2, 1
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
@getiipnlsolvefields
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
fsalfirst = zero(rate_prototype)

k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
du₁ = zero(rate_prototype)
uprev3 = zero(u)
tprev2 = t

CNAB2Cache(u,uprev,uprev2,fsalfirst,k,k1,k2,du₁,du1,z,dz,b,tmp,J,W,uf,jac_config,linsolve,nlsolver,uprev3,tprev2)
CNAB2Cache(u,uprev,uprev2,fsalfirst,k1,k2,du₁,nlsolver,uprev3,tprev2)
end

# CNLF2

@cache mutable struct CNLF2ConstantCache{rateType,F,N,uType,tType} <: OrdinaryDiffEqConstantCache
@cache mutable struct CNLF2ConstantCache{rateType,N,uType,tType} <: OrdinaryDiffEqConstantCache
k2::rateType
uf::F
nlsolver::N
uprev2::uType
uprev3::uType
tprev2::tType
end

@cache mutable struct CNLF2Cache{uType,rateType,JType,WType,UF,JC,N,tType,F} <: OrdinaryDiffEqMutableCache
@cache mutable struct CNLF2Cache{uType,rateType,N,tType} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
uprev2::uType
fsalfirst::rateType
k::rateType
k1::rateType
k2::rateType
du₁::rateType
du1::rateType
z::uType
dz::uType
b::uType
tmp::uType
J::JType
W::WType
uf::UF
jac_config::JC
linsolve::F
nlsolver::N
uprev3::uType
tprev2::tType
end

function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
γ, c = 1//1, 1
W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
@getoopnlsolvefields
J, W = oop_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = oopnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)

k2 = rate_prototype
uprev2 = u
uprev3 = u
tprev2 = t

CNLF2ConstantCache(k2,uf,nlsolver,uprev2,uprev3,tprev2)
CNLF2ConstantCache(k2,nlsolver,uprev2,uprev3,tprev2)
end

function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
γ, c = 1//1, 1
J, W = iip_generate_W(alg,u,uprev,p,t,dt,f,uEltypeNoUnits)
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
@getiipnlsolvefields
nlsolver = iipnlsolve(alg,u,uprev,p,t,dt,f,W,J,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,γ,c)
fsalfirst = zero(rate_prototype)

k1 = zero(rate_prototype)
k2 = zero(rate_prototype)
Expand All @@ -1005,5 +979,5 @@ function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni
uprev3 = zero(u)
tprev2 = t

CNLF2Cache(u,uprev,uprev2,fsalfirst,k,k1,k2,du₁,du1,z,dz,b,tmp,J,W,uf,jac_config,linsolve,nlsolver,uprev3,tprev2)
CNLF2Cache(u,uprev,uprev2,fsalfirst,k1,k2,du₁,nlsolver,uprev3,tprev2)
end
Loading

0 comments on commit 9876d03

Please sign in to comment.