Skip to content

Commit

Permalink
Merge d7960b8 into 9566c8b
Browse files Browse the repository at this point in the history
  • Loading branch information
kanav99 committed Aug 18, 2019
2 parents 9566c8b + d7960b8 commit 1436d9b
Show file tree
Hide file tree
Showing 8 changed files with 179 additions and 274 deletions.
57 changes: 28 additions & 29 deletions src/caches/extrapolation_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function alg_cache(alg::AitkenNeville,u,rate_prototype,uEltypeNoUnits,uBottomElt
AitkenNevilleConstantCache(dtpropose,T,cur_order,work,A,step_no)
end

@cache mutable struct ImplicitEulerExtrapolationCache{uType,rateType,arrayType,dtType,JType,WType,F,JCType,GCType,uNoUnitsType,TFType,UFType} <: OrdinaryDiffEqMutableCache
@cache mutable struct ImplicitEulerExtrapolationCache{uType,rateType,arrayType,dtType,F,JCType,GCType,uNoUnitsType,TFType,UFType,N} <: OrdinaryDiffEqMutableCache
uprev::uType
u_tmps::Array{uType,1}
utilde::uType
Expand All @@ -82,8 +82,7 @@ end
step_no::Int
du1::rateType
du2::rateType
J::JType
W::WType
nlsolver::N
tf::TFType
uf::UFType
linsolve_tmps::Array{rateType,1}
Expand All @@ -92,7 +91,7 @@ end
grad_config::GCType
end

@cache mutable struct ImplicitEulerExtrapolationConstantCache{dtType,arrayType,TF,UF} <: OrdinaryDiffEqConstantCache
@cache mutable struct ImplicitEulerExtrapolationConstantCache{dtType,arrayType,TF,UF,N} <: OrdinaryDiffEqConstantCache
dtpropose::dtType
T::arrayType
cur_order::Int
Expand All @@ -102,6 +101,7 @@ end

tf::TF
uf::UF
nlsolver::N
end

function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
Expand All @@ -114,7 +114,8 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
step_no = zero(Int)
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
ImplicitEulerExtrapolationConstantCache(dtpropose,T,cur_order,work,A,step_no,tf,uf)
nlsolver = SemiImplicitNLSolver(nothing,nothing,nothing,uf,nothing)
ImplicitEulerExtrapolationConstantCache(dtpropose,T,cur_order,work,A,step_no,tf,uf,nlsolver)
end

function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
Expand Down Expand Up @@ -161,16 +162,6 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
W_el = similar(J)
end

W = Array{typeof(W_el),1}(undef, Threads.nthreads())
W[1] = W_el
for i=2:Threads.nthreads()
if W_el isa WOperator
W[i] = WOperator(f, dt, true)
else
W[i] = zero(W_el)
end
end

tf = DiffEqDiffTools.TimeGradientWrapper(f,uprev,p)
uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
linsolve_tmp = zero(rate_prototype)
Expand All @@ -189,9 +180,14 @@ function alg_cache(alg::ImplicitEulerExtrapolation,u,rate_prototype,uEltypeNoUni
grad_config = build_grad_config(alg,f,tf,du1,t)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,du1,du2)

nlsolver = Array{SemiImplicitNLSolver{typeof(W_el),typeof(J),typeof(du1),typeof(uf),typeof(jac_config)}}(undef, Threads.nthreads())
nlsolver[1] = SemiImplicitNLSolver(W_el,J,du1,uf,jac_config)
for i=2:Threads.nthreads()
nlsolver[i] = SemiImplicitNLSolver(zero(W_el),J,du1,uf,jac_config)
end

ImplicitEulerExtrapolationCache(uprev,u_tmps,utilde,tmp,atmp,k_tmps,dtpropose,T,cur_order,work,A,step_no,
du1,du2,J,W,tf,uf,linsolve_tmps,linsolve,jac_config,grad_config)
du1,du2,nlsolver,tf,uf,linsolve_tmps,linsolve,jac_config,grad_config)
end


Expand Down Expand Up @@ -387,7 +383,7 @@ function alg_cache(alg::ExtrapolationMidpointDeuflhard,u,rate_prototype,uEltypeN
ExtrapolationMidpointDeuflhardCache(utilde, u_temp1, u_temp2, u_temp3, u_temp4, tmp, T, res, fsalfirst, k, k_tmps, cc.Q, cc.n_curr, cc.n_old, cc.coefficients,cc.stage_number)
end

@cache mutable struct ImplicitDeuflhardExtrapolationConstantCache{QType,extrapolation_coefficients,TF,UF} <: OrdinaryDiffEqConstantCache
@cache mutable struct ImplicitDeuflhardExtrapolationConstantCache{QType,extrapolation_coefficients,TF,UF,N} <: OrdinaryDiffEqConstantCache
# Values that are mutated
Q::Vector{QType} # Storage for stepsize scaling factors. Q[n] contains information for extrapolation order (n + alg.n_min - 1)
n_curr::Int64 # Storage for the current extrapolation order
Expand All @@ -399,9 +395,10 @@ end

tf::TF
uf::UF
nlsolver::N
end

@cache mutable struct ImplicitDeuflhardExtrapolationCache{uType,QType,extrapolation_coefficients,rateType,JType,WType,F,JCType,GCType,uNoUnitsType,TFType,UFType} <: OrdinaryDiffEqMutableCache
@cache mutable struct ImplicitDeuflhardExtrapolationCache{uType,QType,extrapolation_coefficients,rateType,N,F,JCType,GCType,uNoUnitsType,TFType,UFType} <: OrdinaryDiffEqMutableCache
# Values that are mutated
utilde::uType
u_temp1::uType
Expand All @@ -426,8 +423,7 @@ end

du1::rateType
du2::rateType
J::JType
W::WType
nlsolver::N
tf::TFType
uf::UFType
linsolve_tmp::rateType
Expand All @@ -449,7 +445,8 @@ function alg_cache(alg::ImplicitDeuflhardExtrapolation,u,rate_prototype,uEltypeN

tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
ImplicitDeuflhardExtrapolationConstantCache(Q,n_curr,n_old,coefficients,stage_number,tf,uf)
nlsolver = SemiImplicitNLSolver(nothing,nothing,nothing,uf,nothing)
ImplicitDeuflhardExtrapolationConstantCache(Q,n_curr,n_old,coefficients,stage_number,tf,uf,nlsolver)
end

function alg_cache(alg::ImplicitDeuflhardExtrapolation,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
Expand Down Expand Up @@ -496,10 +493,10 @@ function alg_cache(alg::ImplicitDeuflhardExtrapolation,u,rate_prototype,uEltypeN
linsolve = alg.linsolve(Val{:init},uf,u)
grad_config = build_grad_config(alg,f,tf,du1,t)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,du1,du2)

nlsolver = SemiImplicitNLSolver(W,J,du1,uf,jac_config)

ImplicitDeuflhardExtrapolationCache(utilde,u_temp1,u_temp2,u_temp3,u_temp4,tmp,T,res,fsalfirst,k,k_tmps,cc.Q,cc.n_curr,cc.n_old,cc.coefficients,cc.stage_number,
du1,du2,J,W,tf,uf,linsolve_tmp,linsolve,jac_config,grad_config)
du1,du2,nlsolver,tf,uf,linsolve_tmp,linsolve,jac_config,grad_config)
end

@cache mutable struct ExtrapolationMidpointHairerWannerConstantCache{QType,extrapolation_coefficients} <: OrdinaryDiffEqConstantCache
Expand Down Expand Up @@ -587,7 +584,7 @@ function alg_cache(alg::ExtrapolationMidpointHairerWanner,u,rate_prototype,uElty
cc.Q, cc.n_curr, cc.n_old, cc.coefficients, cc.stage_number, cc.sigma)
end

@cache mutable struct ImplicitHairerWannerExtrapolationConstantCache{QType,extrapolation_coefficients,TF,UF} <: OrdinaryDiffEqConstantCache
@cache mutable struct ImplicitHairerWannerExtrapolationConstantCache{QType,extrapolation_coefficients,TF,UF,N} <: OrdinaryDiffEqConstantCache
# Values that are mutated
Q::Vector{QType} # Storage for stepsize scaling factors. Q[n] contains information for extrapolation order (n - 1)
n_curr::Int64 # Storage for the current extrapolation order
Expand All @@ -600,6 +597,7 @@ end

tf::TF
uf::UF
nlsolver::N
end

function alg_cache(alg::ImplicitHairerWannerExtrapolation,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
Expand All @@ -617,10 +615,11 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation,u,rate_prototype,uElty
# Initialize the constant cache
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
ImplicitHairerWannerExtrapolationConstantCache(Q, n_curr, n_old, coefficients, stage_number, sigma, tf, uf)
nlsolver = SemiImplicitNLSolver(nothing,nothing,nothing,uf,nothing)
ImplicitHairerWannerExtrapolationConstantCache(Q, n_curr, n_old, coefficients, stage_number, sigma, tf, uf, nlsolver)
end

@cache mutable struct ImplicitHairerWannerExtrapolationCache{uType,uNoUnitsType,rateType,QType,extrapolation_coefficients,JType,WType,F,JCType,GCType,TFType,UFType} <: OrdinaryDiffEqMutableCache
@cache mutable struct ImplicitHairerWannerExtrapolationCache{uType,uNoUnitsType,rateType,QType,extrapolation_coefficients,N,F,JCType,GCType,TFType,UFType} <: OrdinaryDiffEqMutableCache
# Values that are mutated
utilde::uType
u_temp1::uType
Expand All @@ -645,8 +644,7 @@ end

du1::rateType
du2::rateType
J::JType
W::WType
nlsolver::N
tf::TFType
uf::UFType
linsolve_tmp::rateType
Expand Down Expand Up @@ -699,9 +697,10 @@ function alg_cache(alg::ImplicitHairerWannerExtrapolation,u,rate_prototype,uElty
linsolve = alg.linsolve(Val{:init},uf,u)
grad_config = build_grad_config(alg,f,tf,du1,t)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,du1,du2)
nlsolver = SemiImplicitNLSolver(W,J,du1,uf,jac_config)

# Initialize the cache
ImplicitHairerWannerExtrapolationCache(utilde, u_temp1, u_temp2, u_temp3, u_temp4, tmp, T, res, fsalfirst, k, k_tmps,
cc.Q, cc.n_curr, cc.n_old, cc.coefficients, cc.stage_number, cc.sigma, du1, du2, J, W, tf, uf, linsolve_tmp,
cc.Q, cc.n_curr, cc.n_old, cc.coefficients, cc.stage_number, cc.sigma, du1, du2, nlsolver, tf, uf, linsolve_tmp,
linsolve, jac_config, grad_config)
end

0 comments on commit 1436d9b

Please sign in to comment.