Skip to content

Commit

Permalink
Merge pull request #501 from JuliaDiffEq/myb/fastnewton
Browse files Browse the repository at this point in the history
Fast path when `f` is linear
  • Loading branch information
ChrisRackauckas committed Oct 4, 2018
2 parents 215e441 + d5d2db8 commit 2ddf833
Show file tree
Hide file tree
Showing 10 changed files with 59 additions and 80 deletions.
2 changes: 1 addition & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ module OrdinaryDiffEq
# Internal utils
import DiffEqBase: ODE_DEFAULT_NORM, ODE_DEFAULT_ISOUTOFDOMAIN, ODE_DEFAULT_PROG_MESSAGE, ODE_DEFAULT_UNSTABLE_CHECK

using DiffEqOperators: DiffEqArrayOperator
using DiffEqOperators: DiffEqArrayOperator, DEFAULT_UPDATE_FUNC

import RecursiveArrayTools: chain, recursivecopy!

Expand Down
4 changes: 4 additions & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -767,3 +767,7 @@ const MassMatrixAlgorithms = Union{OrdinaryDiffEqRosenbrockAlgorithm,
const MultistepAlgorithms = Union{IRKN3,IRKN4,
ABDF2,
AB3,AB4,AB5,ABM32,ABM43,ABM54}

const SplitAlgorithms = Union{CNAB2,CNLF2,SBDF,
GenericIIF1,GenericIIF2,
KenCarp3,KenCarp4,KenCarp5}
14 changes: 0 additions & 14 deletions src/caches/adams_bashforth_moulton_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,6 @@ du_cache(c::CNAB2Cache) = ()
function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
@oopnlcachefields
k2 = rate_prototype
uf != nothing && ( uf = DiffEqDiffTools.UDerivativeWrapper(f.f1,t,p) )
uprev3 = u
tprev2 = t

Expand All @@ -976,12 +975,6 @@ function alg_cache(alg::CNAB2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni
k2 = zero(rate_prototype)
du₁ = zero(rate_prototype)

if typeof(f) <: SplitFunction && uf != nothing
uf = DiffEqDiffTools.UJacobianWrapper(f.f1,t,p)
linsolve = alg.linsolve(Val{:init},uf,u)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
end

uprev3 = similar(u)
tprev2 = t

Expand Down Expand Up @@ -1031,7 +1024,6 @@ du_cache(c::CNLF2Cache) = ()
function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
@oopnlcachefields
k2 = rate_prototype
uf != nothing && ( uf = DiffEqDiffTools.UDerivativeWrapper(f.f1,t,p) )
uprev2 = u
uprev3 = u
tprev2 = t
Expand All @@ -1047,12 +1039,6 @@ function alg_cache(alg::CNLF2,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUni
k2 = zero(rate_prototype)
du₁ = zero(rate_prototype)

if typeof(f) <: SplitFunction && uf != nothing
uf = DiffEqDiffTools.UJacobianWrapper(f.f1,t,p)
linsolve = alg.linsolve(Val{:init},uf,u)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
end

uprev2 = similar(u)
uprev3 = similar(u)
tprev2 = t
Expand Down
6 changes: 0 additions & 6 deletions src/caches/bdf_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ du_cache(c::SBDFCache) = ()
function alg_cache(alg::SBDF,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
@oopnlcachefields
k2 = rate_prototype
uf != nothing && ( uf = DiffEqDiffTools.UDerivativeWrapper(f.f1,t,p) )
uprev2 = u
uprev3 = u
uprev4 = u
Expand All @@ -138,11 +137,6 @@ function alg_cache(alg::SBDF,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnit
k₃ = order == 4 ? zero(rate_prototype) : k₁
du₁ = zero(rate_prototype)
du₂ = zero(rate_prototype)
if uf != nothing
uf = DiffEqDiffTools.UJacobianWrapper(f.f1,t,p)
linsolve = alg.linsolve(Val{:init},uf,u)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
end

nlsolve = typeof(_nlsolve)(NLSolverCache(κ,tol,min_iter,max_iter,10000,new_W,z,W,1//1,1,ηold,z₊,dz,tmp,b,k))
SBDFCache(1,u,uprev,fsalfirst,k,du1,z,dz,b,tmp,atmp,J,W,uf,jac_config,linsolve,nlsolve,uprev2,uprev3,uprev4,k₁,k₂,k₃,du₁,du₂)
Expand Down
24 changes: 0 additions & 24 deletions src/caches/kencarp_kvaerno_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,6 @@ end
function alg_cache(alg::KenCarp3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,
uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
@oopnlcachefields
if typeof(f) <: SplitFunction && uf != nothing
uf = DiffEqDiffTools.UDerivativeWrapper(f.f1,t,p)
end
tab = KenCarp3Tableau(uToltype,real(tTypeNoUnits))
nlsolve = typeof(_nlsolve)(NLSolverCache(κ,tol,min_iter,max_iter,10000,new_W,z,W,tab.γ,tab.c3,ηold,z₊,dz,tmp,b,k))

Expand Down Expand Up @@ -56,11 +53,6 @@ function alg_cache(alg::KenCarp3,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
if typeof(f) <: SplitFunction
k1 = similar(u,axes(u)); k2 = similar(u,axes(u))
k3 = similar(u,axes(u)); k4 = similar(u,axes(u))
if uf != nothing
uf = DiffEqDiffTools.UJacobianWrapper(f.f1,t,p)
linsolve = alg.linsolve(Val{:init},uf,u)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
end
else
k1 = nothing; k2 = nothing
k3 = nothing; k4 = nothing
Expand Down Expand Up @@ -141,9 +133,6 @@ end
function alg_cache(alg::KenCarp4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,
uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
@oopnlcachefields
if typeof(f) <: SplitFunction && uf != nothing
uf = DiffEqDiffTools.UDerivativeWrapper(f.f1,t,p)
end
uprev3 = u
tprev2 = t
tab = KenCarp4Tableau(uToltype,real(tTypeNoUnits))
Expand Down Expand Up @@ -198,11 +187,6 @@ function alg_cache(alg::KenCarp4,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
k1 = similar(u,axes(u)); k2 = similar(u,axes(u))
k3 = similar(u,axes(u)); k4 = similar(u,axes(u))
k5 = similar(u,axes(u)); k6 = similar(u,axes(u))
if uf != nothing
uf = DiffEqDiffTools.UJacobianWrapper(f.f1,t,p)
linsolve = alg.linsolve(Val{:init},uf,u)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
end
else
k1 = nothing; k2 = nothing
k3 = nothing; k4 = nothing
Expand Down Expand Up @@ -286,9 +270,6 @@ end
function alg_cache(alg::KenCarp5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,
uprev,uprev2,f,t,dt,reltol,p,calck,::Type{Val{false}})
@oopnlcachefields
if typeof(f) <: SplitFunction && uf != nothing
uf = DiffEqDiffTools.UDerivativeWrapper(f.f1,t,p)
end
tab = KenCarp5Tableau(uToltype,real(tTypeNoUnits))
nlsolve = typeof(_nlsolve)(NLSolverCache(κ,tol,min_iter,max_iter,10000,new_W,z,W,tab.γ,tab.c3,ηold,z₊,dz,tmp,b,k))

Expand Down Expand Up @@ -347,11 +328,6 @@ function alg_cache(alg::KenCarp5,u,rate_prototype,uEltypeNoUnits,uBottomEltypeNo
k3 = similar(u,axes(u)); k4 = similar(u,axes(u))
k5 = similar(u,axes(u)); k6 = similar(u,axes(u))
k7 = similar(u,axes(u)); k8 = similar(u,axes(u))
if uf != nothing
uf = DiffEqDiffTools.UJacobianWrapper(f.f1,t,p)
linsolve = alg.linsolve(Val{:init},uf,u)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
end
else
k1 = nothing; k2 = nothing
k3 = nothing; k4 = nothing
Expand Down
31 changes: 24 additions & 7 deletions src/caches/sdirk_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,19 @@ DiffEqBase.@def iipnlcachefields begin
uToltype = real(uBottomEltypeNoUnits)
ηold = one(uToltype)

if typeof(alg.nlsolve) <: NLNewton
nf = f isa SplitFunction && alg isa SplitAlgorithms ? f.f1 : f
islin = (f isa ODEFunction && islinear(f.f)) || (f isa SplitFunction && islinear(f.f1.f))
# check if `nf` is linear
if islin && alg.nlsolve isa NLNewton
# get the operator
J = nf.f
W = WOperator(f.mass_matrix, dt, J)
du1 = rate_prototype
uf = nothing
jac_config = nothing
linsolve = alg.linsolve(Val{:init},nf,u)
z₊ = z
elseif alg.nlsolve isa NLNewton
if DiffEqBase.has_jac(f) && !DiffEqBase.has_invW(f) && f.jac_prototype != nothing
W = WOperator(f, dt)
J = nothing # is J = W.J better?
Expand All @@ -17,8 +29,9 @@ DiffEqBase.@def iipnlcachefields begin
W = similar(J)
end
du1 = zero(rate_prototype)
uf = DiffEqDiffTools.UJacobianWrapper(f,t,p)
jac_config = build_jac_config(alg,f,uf,du1,uprev,u,tmp,dz)
# if the algorithm specializes on split problems the use `nf`
uf = DiffEqDiffTools.UJacobianWrapper(nf,t,p)
jac_config = build_jac_config(alg,nf,uf,du1,uprev,u,tmp,dz)
linsolve = alg.linsolve(Val{:init},uf,u)
z₊ = z
elseif typeof(alg.nlsolve) <: NLFunctional
Expand Down Expand Up @@ -47,13 +60,17 @@ DiffEqBase.@def oopnlcachefields begin
nlcache = alg.nlsolve.cache
@unpack κ,tol,max_iter,min_iter,new_W = nlcache
z = uprev
if typeof(alg.nlsolve) <: NLNewton
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
nf = f isa SplitFunction && alg isa SplitAlgorithms ? f.f1 : f
if alg.nlsolve isa NLNewton
# only use `nf` if the algorithm specializes on split eqs
uf = DiffEqDiffTools.UDerivativeWrapper(nf,t,p)
else
uf = nothing
end
if DiffEqBase.has_jac(f) && typeof(alg.nlsolve) <: NLNewton
J = f.jac(uprev, p, t)
islin = (f isa ODEFunction && islinear(f.f)) || (f isa SplitFunction && islinear(f.f1.f))
if (islin || DiffEqBase.has_jac(f)) && typeof(alg.nlsolve) <: NLNewton
# get the operator
J = islin ? nf.f : f.jac(uprev, p, t)
if !isa(J, DiffEqBase.AbstractDiffEqLinearOperator)
J = DiffEqArrayOperator(J)
end
Expand Down
44 changes: 22 additions & 22 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,10 +241,18 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_
alg = unwrap_alg(integrator, true)
isnewton = !(typeof(alg) <: OrdinaryDiffEqRosenbrockAdaptiveAlgorithm ||
typeof(alg) <: OrdinaryDiffEqRosenbrockAlgorithm)
isnewton && ( nlcache = cache.nlsolve.cache; @unpack ηold,nl_iters = cache.nlsolve.cache )
isnewton && ( nlcache = cache.nlsolve.cache; @unpack ηold,nl_iters = cache.nlsolve.cache)

# calculate W
# fast pass
# we only want to factorize the linear operator once
new_jac = true
new_W = true
if (f isa ODEFunction && islinear(f.f)) || (f isa SplitFunction && islinear(f.f1.f))
new_jac = false
@goto J2W # Jump to W calculation directly, because we already have J
end

# calculate W
if DiffEqBase.has_invW(f)
# skip calculation of inv(W) if step is repeated
!repeat_step && W_transform ? f.invW_t(W, uprev, p, dtgamma, t) :
Expand All @@ -261,6 +269,7 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_
new_jac = true
DiffEqBase.update_coefficients!(W,uprev,p,t)
end
@label J2W
# skip calculation of W if step is repeated
if !repeat_step && (!alg_can_repeat_jac(alg) ||
(integrator.iter < 1 || new_jac ||
Expand Down Expand Up @@ -310,28 +319,19 @@ function calc_W!(integrator, cache::OrdinaryDiffEqConstantCache, dtgamma, repeat
# calculate W
uf.t = t
is_compos = typeof(integrator.alg) <: CompositeAlgorithm
if !W_transform
if DiffEqBase.has_jac(f)
J = f.jac(uprev, p, t)
if !isa(J, DiffEqBase.AbstractDiffEqLinearOperator)
J = DiffEqArrayOperator(J)
end
W = WOperator(mass_matrix, dtgamma, J; transform=false)
else
J = calc_J(integrator, cache, is_compos)
W = mass_matrix - dtgamma*J
if (f isa ODEFunction && islinear(f.f)) || (f isa SplitFunction && islinear(f.f1.f))
J = f.f1.f
W = WOperator(mass_matrix, dtgamma, J; transform=W_transform)
elseif DiffEqBase.has_jac(f)
J = f.jac(uprev, p, t)
if !isa(J, DiffEqBase.AbstractDiffEqLinearOperator)
J = DiffEqArrayOperator(J)
end
W = WOperator(mass_matrix, dtgamma, J; transform=W_transform)
else
if DiffEqBase.has_jac(f)
J = f.jac(uprev, p, t)
if !isa(J, DiffEqBase.AbstractDiffEqLinearOperator)
J = DiffEqArrayOperator(J)
end
W = WOperator(mass_matrix, dtgamma, J; transform=true)
else
J = calc_J(integrator, cache, is_compos)
W = mass_matrix*inv(dtgamma) - J
end
J = calc_J(integrator, cache, is_compos)
W = W_transform ? mass_matrix*inv(dtgamma) - J :
mass_matrix - dtgamma*J
end
is_compos && (integrator.eigen_est = isarray ? opnorm(J, Inf) : J)
W
Expand Down
2 changes: 2 additions & 0 deletions src/misc_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,5 @@ macro swap!(x,y)
$(esc(y)) = tmp
end
end

islinear(f) = f isa DiffEqBase.AbstractDiffEqLinearOperator && f.update_func === DEFAULT_UPDATE_FUNC
8 changes: 4 additions & 4 deletions src/nlsolve/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ function (S::NLNewton{false})(integrator)
@unpack t,dt,uprev,u,f,p = integrator
@unpack z,tmp,W,κ,tol,c,γ,max_iter,min_iter = nlcache
mass_matrix = integrator.f.mass_matrix
#alg = unwrap_alg(integrator, true)
if typeof(integrator.f) <: SplitFunction
alg = unwrap_alg(integrator, true)
if integrator.f isa SplitFunction && alg isa SplitAlgorithms
f = integrator.f.f1
else
f = integrator.f
Expand Down Expand Up @@ -90,8 +90,8 @@ function (S::NLNewton{true})(integrator)
@unpack t,dt,uprev,u,f,p = integrator
@unpack z,dz,tmp,b,W,κ,tol,k,new_W,c,γ,max_iter,min_iter = nlcache
mass_matrix = integrator.f.mass_matrix
#alg = unwrap_alg(integrator, true)
if typeof(integrator.f) <: SplitFunction
alg = unwrap_alg(integrator, true)
if integrator.f isa SplitFunction && alg isa SplitAlgorithms
f = integrator.f.f1
else
f = integrator.f
Expand Down
4 changes: 2 additions & 2 deletions test/linear_nonlinear_convergence_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ using OrdinaryDiffEq: alg_order

Random.seed!(100)
dts = 1 ./2 .^(7:-1:4) #14->7 good plot
for Alg in [GenericIIF1,GenericIIF2,LawsonEuler,NorsettEuler,ETDRK2,ETDRK3,ETDRK4,HochOst4,Exprb32,Exprb43,ETD2]
for Alg in [GenericIIF1,GenericIIF2,LawsonEuler,NorsettEuler,ETDRK2,ETDRK3,ETDRK4,HochOst4,Exprb32,Exprb43,ETD2,KenCarp3]
sim = test_convergence(dts,prob,Alg())
if Alg in [Exprb32, Exprb43]
@test_broken abs(sim.𝒪est[:l2] - alg_order(Alg())) < 0.2
Expand All @@ -35,7 +35,7 @@ end
prob = SplitODEProblem(linnonlin_fun_iip,u0,(0.0,1.0))

dts = 1 ./2 .^(8:-1:4) #14->7 good plot
for Alg in [GenericIIF1,GenericIIF2,LawsonEuler,NorsettEuler,ETDRK2,ETDRK3,ETDRK4,HochOst4,Exprb32,Exprb43,ETD2]
for Alg in [GenericIIF1,GenericIIF2,LawsonEuler,NorsettEuler,ETDRK2,ETDRK3,ETDRK4,HochOst4,Exprb32,Exprb43,ETD2,KenCarp3]
sim = test_convergence(dts,prob,Alg())
if Alg in [Exprb32, Exprb43]
@test_broken abs(sim.𝒪est[:l2] - alg_order(Alg())) < 0.1
Expand Down

0 comments on commit 2ddf833

Please sign in to comment.