Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lazy W operator support #443

Merged
merged 43 commits into from
Jul 29, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
71147ee
`WOperator` and test
MSeeker1340 Jul 19, 2018
42ce8c5
Merge remote-tracking branch 'origin/master' into lazy-W
MSeeker1340 Jul 20, 2018
a4addeb
Adjust position for `WOperator`
MSeeker1340 Jul 20, 2018
197fc40
Update `calc_W!` for constant cache
MSeeker1340 Jul 20, 2018
1a1f5ad
Scalar `WOperator` compatibility
MSeeker1340 Jul 20, 2018
2f17963
Mass matrix for finite differece `calc_W!`
MSeeker1340 Jul 20, 2018
df84f70
Fix non-scalar adaptive `ImplicitEuler`
MSeeker1340 Jul 20, 2018
c57604b
Merge remote-tracking branch 'origin/master' into lazy-W
MSeeker1340 Jul 20, 2018
9f40d55
Support transformed W
MSeeker1340 Jul 20, 2018
66c62aa
Merge remote-tracking branch 'origin/master' into lazy-W
MSeeker1340 Jul 20, 2018
2c5850e
Merge remote-tracking branch 'origin/master' into lazy-W
MSeeker1340 Jul 22, 2018
f7e425e
Add constructor with uninitiated `mass_matrix`
MSeeker1340 Jul 22, 2018
4bd3617
Add `WOperator` support to in-place `calc_W!`
MSeeker1340 Jul 22, 2018
bb185eb
Update `J` and `W` allocation for `ImplicitEulerCache`
MSeeker1340 Jul 22, 2018
d175f9b
Use `mass_matrix` from `ODEFunction`
MSeeker1340 Jul 24, 2018
7b7a1cd
Merge remote-tracking branch 'origin/master' into lazy-W
MSeeker1340 Jul 24, 2018
08f460d
Merge remote-tracking branch 'origin/master' into lazy-W
MSeeker1340 Jul 25, 2018
54b9029
Safer lazy W predicate
MSeeker1340 Jul 25, 2018
fcc937f
Non-allocating `convert` method
MSeeker1340 Jul 25, 2018
9fe4202
Use lazy cache for `mul!`
MSeeker1340 Jul 25, 2018
e6f639c
Bugfix
MSeeker1340 Jul 26, 2018
5cedc29
Remove `lazy_W` and always use `WOperator`
MSeeker1340 Jul 26, 2018
bf600ea
Constructor using `f`
MSeeker1340 Jul 26, 2018
c48a0de
W_transform consistency
MSeeker1340 Jul 26, 2018
7983293
Update test script
MSeeker1340 Jul 26, 2018
21b1a5b
Improve handling of out-of-place jacobian
MSeeker1340 Jul 26, 2018
be258f8
Merge remote-tracking branch 'origin/master' into lazy-W
MSeeker1340 Jul 28, 2018
66a38fb
Update sdirk methods
MSeeker1340 Jul 28, 2018
f01872d
Update ABM methods
MSeeker1340 Jul 28, 2018
e3c1505
Update bdf methods
MSeeker1340 Jul 28, 2018
3f8cefe
Update EulerIMEX
MSeeker1340 Jul 28, 2018
c4437ea
Update Kencarp Kvaerno methods
MSeeker1340 Jul 28, 2018
1eba5e9
Update Rosenbrock methods
MSeeker1340 Jul 28, 2018
b8ae169
Add back concrete W for functions without `jac`
MSeeker1340 Jul 29, 2018
e0eea48
Bugfix
MSeeker1340 Jul 29, 2018
72cde28
Fix invW
MSeeker1340 Jul 29, 2018
80d072b
Fix differentiation traits tests
MSeeker1340 Jul 29, 2018
87e174d
Fix test scripts
MSeeker1340 Jul 29, 2018
90f377b
Address `jac_prototype=nothing` case
MSeeker1340 Jul 29, 2018
e2beccb
Merge remote-tracking branch 'origin/master' into lazy-W
MSeeker1340 Jul 29, 2018
b520fe0
Integration test
MSeeker1340 Jul 29, 2018
c1e53da
Update REQUIRE
ChrisRackauckas Jul 29, 2018
284fa70
Update utility_tests.jl
ChrisRackauckas Jul 29, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
julia 0.7-beta2
DiffEqBase 3.8.0
DiffEqOperators 3.2.0
Parameters 0.5.0
ForwardDiff 0.7.0
GenericSVD 0.0.2
Expand Down
2 changes: 2 additions & 0 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ module OrdinaryDiffEq
# Internal utils
import DiffEqBase: ODE_DEFAULT_NORM, ODE_DEFAULT_ISOUTOFDOMAIN, ODE_DEFAULT_PROG_MESSAGE, ODE_DEFAULT_UNSTABLE_CHECK

using DiffEqOperators: DiffEqArrayOperator

import RecursiveArrayTools: chain, recursivecopy!

using Parameters, GenericSVD, ForwardDiff, RecursiveArrayTools,
Expand Down
26 changes: 18 additions & 8 deletions src/caches/adams_bashforth_moulton_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -933,7 +933,7 @@ mutable struct CNAB2ConstantCache{rateType,F,uToltype,uType,tType} <: OrdinaryDi
tprev2::tType
end

mutable struct CNAB2Cache{uType,rateType,uNoUnitsType,J,UF,JC,uToltype,tType,F} <: OrdinaryDiffEqMutableCache
mutable struct CNAB2Cache{uType,rateType,uNoUnitsType,J,W,UF,JC,uToltype,tType,F} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
uprev2::uType
Expand All @@ -949,7 +949,7 @@ mutable struct CNAB2Cache{uType,rateType,uNoUnitsType,J,UF,JC,uToltype,tType,F}
tmp::uType
atmp::uNoUnitsType
J::J
W::J
W::W
uf::UF
jac_config::JC
linsolve::F
Expand Down Expand Up @@ -987,8 +987,13 @@ function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni
end

function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
if DiffEqBase.has_jac(f) && !DiffEqBase.has_invW(f) && f.jac_prototype != nothing
W = WOperator(f, dt)
J = nothing # is J = W.J better?
else
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
end
z = similar(u,axes(u))
dz = similar(u,axes(u))
tmp = similar(u); b = similar(u,axes(u));
Expand Down Expand Up @@ -1043,7 +1048,7 @@ mutable struct CNLF2ConstantCache{rateType,F,uToltype,uType,tType} <: OrdinaryDi
tprev2::tType
end

mutable struct CNLF2Cache{uType,rateType,uNoUnitsType,J,UF,JC,uToltype,tType,F} <: OrdinaryDiffEqMutableCache
mutable struct CNLF2Cache{uType,rateType,uNoUnitsType,J,W,UF,JC,uToltype,tType,F} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
uprev2::uType
Expand All @@ -1059,7 +1064,7 @@ mutable struct CNLF2Cache{uType,rateType,uNoUnitsType,J,UF,JC,uToltype,tType,F}
tmp::uType
atmp::uNoUnitsType
J::J
W::J
W::W
uf::UF
jac_config::JC
linsolve::F
Expand Down Expand Up @@ -1098,8 +1103,13 @@ function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni
end

function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
if DiffEqBase.has_jac(f) && !DiffEqBase.has_invW(f) && f.jac_prototype != nothing
W = WOperator(f, dt)
J = nothing # is J = W.J better?
else
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
end
z = similar(u,axes(u))
dz = similar(u,axes(u))
tmp = similar(u); b = similar(u,axes(u));
Expand Down
52 changes: 36 additions & 16 deletions src/caches/bdf_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ function alg_cache(alg::ABDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni
ABDF2ConstantCache(uf, ηold, κ, tol, 10000, eulercache, dtₙ₋₁, fsalfirstprev)
end

mutable struct ABDF2Cache{uType,rateType,uNoUnitsType,J,UF,JC,uToltype,F,dtType} <: OrdinaryDiffEqMutableCache
mutable struct ABDF2Cache{uType,rateType,uNoUnitsType,J,W,UF,JC,uToltype,F,dtType} <: OrdinaryDiffEqMutableCache
uₙ::uType
uₙ₋₁::uType
uₙ₋₂::uType
Expand All @@ -49,7 +49,7 @@ mutable struct ABDF2Cache{uType,rateType,uNoUnitsType,J,UF,JC,uToltype,F,dtType}
tmp::uType
atmp::uNoUnitsType
J::J
W::J
W::W
uf::UF
jac_config::JC
linsolve::F
Expand All @@ -67,8 +67,13 @@ du_cache(c::ABDF2Cache) = (c.k,c.fsalfirst)
function alg_cache(alg::ABDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,
tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
du1 = zero(rate_prototype)
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
if DiffEqBase.has_jac(f) && !DiffEqBase.has_invW(f) && f.jac_prototype != nothing
W = WOperator(f, dt)
J = nothing # is J = W.J better?
else
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
end
zprev = similar(u,axes(u))
zₙ₋₁ = similar(u,axes(u)); z = similar(u,axes(u))
dz = similar(u,axes(u))
Expand Down Expand Up @@ -119,7 +124,7 @@ mutable struct QNDF1ConstantCache{F,uToltype,coefType,coefType1,dtType,uType} <:
dtₙ₋₁::dtType
end

mutable struct QNDF1Cache{uType,rateType,coefType,coefType1,coefType2,uNoUnitsType,J,UF,JC,uToltype,F,dtType} <: OrdinaryDiffEqMutableCache
mutable struct QNDF1Cache{uType,rateType,coefType,coefType1,coefType2,uNoUnitsType,J,W,UF,JC,uToltype,F,dtType} <: OrdinaryDiffEqMutableCache
uprev2::uType
du1::rateType
fsalfirst::rateType
Expand All @@ -135,7 +140,7 @@ mutable struct QNDF1Cache{uType,rateType,coefType,coefType1,coefType2,uNoUnitsTy
atmp::uNoUnitsType
utilde::uType
J::J
W::J
W::W
uf::UF
jac_config::JC
linsolve::F
Expand Down Expand Up @@ -179,8 +184,13 @@ end

function alg_cache(alg::QNDF1,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
du1 = zero(rate_prototype)
J = fill(zero(uEltypeNoUnits),length(u),length(u))
W = similar(J)
if DiffEqBase.has_jac(f) && !DiffEqBase.has_invW(f) && f.jac_prototype != nothing
W = WOperator(f, dt)
J = nothing # is J = W.J better?
else
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
end
z = similar(u,axes(u))
dz = similar(u,axes(u))
fsalfirst = zero(rate_prototype)
Expand Down Expand Up @@ -243,7 +253,7 @@ mutable struct QNDF2ConstantCache{F,uToltype,coefType,coefType1,uType,dtType} <:
dtₙ₋₂::dtType
end

mutable struct QNDF2Cache{uType,rateType,coefType,coefType1,coefType2,uNoUnitsType,J,UF,JC,uToltype,F,dtType} <: OrdinaryDiffEqMutableCache
mutable struct QNDF2Cache{uType,rateType,coefType,coefType1,coefType2,uNoUnitsType,J,W,UF,JC,uToltype,F,dtType} <: OrdinaryDiffEqMutableCache
uprev2::uType
uprev3::uType
du1::rateType
Expand All @@ -260,7 +270,7 @@ mutable struct QNDF2Cache{uType,rateType,coefType,coefType1,coefType2,uNoUnitsTy
atmp::uNoUnitsType
utilde::uType
J::J
W::J
W::W
uf::UF
jac_config::JC
linsolve::F
Expand Down Expand Up @@ -307,8 +317,13 @@ end

function alg_cache(alg::QNDF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
du1 = zero(rate_prototype)
J = fill(zero(uEltypeNoUnits),length(u),length(u))
W = similar(J)
if DiffEqBase.has_jac(f) && !DiffEqBase.has_invW(f) && f.jac_prototype != nothing
W = WOperator(f, dt)
J = nothing # is J = W.J better?
else
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
end
z = similar(u,axes(u))
dz = similar(u,axes(u))
fsalfirst = zero(rate_prototype)
Expand Down Expand Up @@ -374,7 +389,7 @@ mutable struct QNDFConstantCache{F,uToltype,coefType1,coefType2,coefType3,uType,
c::Int64
end

mutable struct QNDFCache{uType,rateType,coefType1,coefType,coefType2,coefType3,dtType,dtsType,uNoUnitsType,J,UF,JC,uToltype,F} <: OrdinaryDiffEqMutableCache
mutable struct QNDFCache{uType,rateType,coefType1,coefType,coefType2,coefType3,dtType,dtsType,uNoUnitsType,J,W,UF,JC,uToltype,F} <: OrdinaryDiffEqMutableCache
du1::rateType
fsalfirst::rateType
k::rateType
Expand All @@ -393,7 +408,7 @@ mutable struct QNDFCache{uType,rateType,coefType1,coefType,coefType2,coefType3,d
atmp::uNoUnitsType
utilde::uType
J::J
W::J
W::W
uf::UF
jac_config::JC
linsolve::F
Expand Down Expand Up @@ -440,8 +455,13 @@ end

function alg_cache(alg::QNDF,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
du1 = zero(rate_prototype)
J = fill(zero(uEltypeNoUnits),length(u),length(u))
W = similar(J)
if DiffEqBase.has_jac(f) && !DiffEqBase.has_invW(f) && f.jac_prototype != nothing
W = WOperator(f, dt)
J = nothing # is J = W.J better?
else
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
end
z = similar(u,axes(u))
dz = similar(u,axes(u))
fsalfirst = zero(rate_prototype)
Expand Down
13 changes: 9 additions & 4 deletions src/caches/euler_imex_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ mutable struct IMEXEulerConstantCache{F,uToltype} <: OrdinaryDiffEqConstantCache
newton_iters::Int
end

mutable struct IMEXEulerCache{uType,rateType,uNoUnitsType,J,UF,JC,uToltype,F} <: OrdinaryDiffEqMutableCache
mutable struct IMEXEulerCache{uType,rateType,uNoUnitsType,J,W,UF,JC,uToltype,F} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
uprev2::uType
Expand All @@ -22,7 +22,7 @@ mutable struct IMEXEulerCache{uType,rateType,uNoUnitsType,J,UF,JC,uToltype,F} <:
tmp::uType
atmp::uNoUnitsType
J::J
W::J
W::W
uf::UF
jac_config::JC
linsolve::F
Expand Down Expand Up @@ -55,8 +55,13 @@ function alg_cache(alg::IMEXEuler,u,rate_prototype,uEltypeNoUnits,uBottomEltypeN
end

function alg_cache(alg::IMEXEuler,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{true}})
J = fill(zero(uEltypeNoUnits),length(u),length(u))
W = similar(J)
if DiffEqBase.has_jac(f) && !DiffEqBase.has_invW(f) && f.jac_prototype != nothing
W = WOperator(f, dt)
J = nothing # is J = W.J better?
else
J = fill(zero(uEltypeNoUnits),length(u),length(u)) # uEltype?
W = similar(J)
end
z = similar(u,axes(u))
dz = similar(u,axes(u)); tmp = similar(u,axes(u)); b = similar(u,axes(u))
fsalfirst = zero(rate_prototype)
Expand Down