Optimize W matrix formation and fix fastconvergence judgment #706

Merged: 15 commits, Mar 21, 2019
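Two things change throughout the diffs below: the iteration matrix W is built with the opposite sign, and the fast-convergence test that decides when the Jacobian may be reused is corrected. A minimal sketch of the sign flip in plain Julia (not the package's own code; `M`, `J`, and `dtgamma` stand in for the mass matrix, the Jacobian, and gamma*dt):

```julia
using LinearAlgebra

M = Matrix(1.0I, 2, 2)        # mass matrix
J = [-2.0 1.0; 1.0 -3.0]      # Jacobian of f at (u, t)
dtgamma = 0.1                 # gamma * dt for the current step

# Convention before this PR:
W_old = M - dtgamma * J       # and M/dtgamma - J in the "transformed" form
# Convention after this PR (just the negation):
W_new = dtgamma * J - M       # and J - M/dtgamma in the "transformed" form

# Because W flips sign, the Newton update flips with it:
# old: z .+= W_old \ residual   is the same as   new: z .-= W_new \ residual
z, residual = [1.0, 1.0], [0.1, -0.2]
@assert z .+ W_old \ residual ≈ z .- W_new \ residual
```

The remaining hunks propagate this convention through `WOperator`, `jacobian2W!`, `calc_W!`, the Newton solver, and the Radau IIA5 step, and fix the reuse test in `do_newJ` (previously `do_nowJ`).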
14 changes: 14 additions & 0 deletions src/alg_utils.jl
@@ -454,6 +454,20 @@ function unwrap_alg(integrator, is_stiff)
end
end

function unwrap_cache(integrator, is_stiff)
alg = integrator.alg
cache = integrator.cache
iscomp = alg isa CompositeAlgorithm
if !iscomp
return cache
elseif alg.choice_function isa AutoSwitch
num = is_stiff ? 2 : 1
return cache.caches[num]
else
return cache.caches[integrator.cache.current]
end
end

# Whether `uprev` is used in the algorithm directly.
uses_uprev(alg::OrdinaryDiffEqAlgorithm, adaptive::Bool) = true
uses_uprev(alg::ORK256, adaptive::Bool) = false
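A rough, self-contained illustration of the dispatch the new `unwrap_cache` above performs, using hypothetical toy types rather than the package's real `CompositeAlgorithm`, `AutoSwitch`, and cache types:

```julia
# Toy stand-ins, only to show the branching logic of unwrap_cache.
struct ToyAutoSwitch end
struct ToyOtherChoice end                       # any non-AutoSwitch choice function
struct ToyComposite{C}; choice_function::C; end
struct ToyCompositeCache{T}; caches::T; current::Int; end

# Mirrors the indexing above: with AutoSwitch, slot 1 is the nonstiff cache and
# slot 2 the stiff one; other composites fall back to the currently active cache.
function toy_unwrap_cache(alg, cache, is_stiff)
    alg isa ToyComposite || return cache
    if alg.choice_function isa ToyAutoSwitch
        return cache.caches[is_stiff ? 2 : 1]
    else
        return cache.caches[cache.current]
    end
end

cache = ToyCompositeCache((:tsit5_cache, :rodas5_cache), 1)
@assert toy_unwrap_cache(ToyComposite(ToyAutoSwitch()), cache, true) == :rodas5_cache
@assert toy_unwrap_cache(ToyComposite(ToyOtherChoice()), cache, false) == :tsit5_cache
```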
2 changes: 1 addition & 1 deletion src/caches/rosenbrock_caches.jl
@@ -821,4 +821,4 @@ function alg_cache(alg::RosenbrockW6S4OS,u,rate_prototype,uEltypeNoUnits,uBottom
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
RosenbrockWConstantCache(tf,uf,RosenbrockW6S4OSConstantCache(real(uBottomEltypeNoUnits),real(tTypeNoUnits)))
end
end
87 changes: 46 additions & 41 deletions src/derivative_utils.jl
@@ -126,7 +126,7 @@ mutable struct WOperator{T,
_func_cache # cache used in `mul!`
_concrete_form # non-lazy form (matrix/number) of the operator
WOperator(mass_matrix, gamma, J, inplace; transform=false) = new{eltype(J),typeof(mass_matrix),
typeof(gamma),typeof(J)}(mass_matrix,gamma,J,inplace,transform,nothing,nothing)
typeof(gamma),typeof(J)}(mass_matrix,gamma,J,transform,inplace,nothing,nothing)
end
function WOperator(f::DiffEqBase.AbstractODEFunction, gamma, inplace; transform=false)
@assert DiffEqBase.has_jac(f) "f needs to have an associated jacobian"
@@ -152,50 +152,50 @@ function Base.convert(::Type{AbstractMatrix}, W::WOperator)
if W._concrete_form === nothing || !W.inplace
# Allocating
if W.transform
W._concrete_form = W.mass_matrix / W.gamma - convert(AbstractMatrix,W.J)
W._concrete_form = -W.mass_matrix / W.gamma + convert(AbstractMatrix,W.J)
else
W._concrete_form = W.mass_matrix - W.gamma * convert(AbstractMatrix,W.J)
W._concrete_form = -W.mass_matrix + W.gamma * convert(AbstractMatrix,W.J)
end
else
# Non-allocating
if W.transform
rmul!(copyto!(W._concrete_form, W.mass_matrix), 1/W.gamma)
axpy!(-1, convert(AbstractMatrix,W.J), W._concrete_form)
copyto!(W._concrete_form, W.mass_matrix)
axpby!(one(W.gamma), convert(AbstractMatrix,W.J), -inv(W.gamma), W._concrete_form)
else
copyto!(W._concrete_form, W.mass_matrix)
axpy!(-W.gamma, convert(AbstractMatrix,W.J), W._concrete_form)
axpby!(W.gamma, convert(AbstractMatrix,W.J), -one(W.gamma), W._concrete_form)
end
end
W._concrete_form
end
function Base.convert(::Type{Number}, W::WOperator)
if W.transform
W._concrete_form = W.mass_matrix / W.gamma - convert(Number,W.J)
W._concrete_form = -W.mass_matrix / W.gamma + convert(Number,W.J)
else
W._concrete_form = W.mass_matrix - W.gamma * convert(Number,W.J)
W._concrete_form = -W.mass_matrix + W.gamma * convert(Number,W.J)
end
W._concrete_form
end
Base.size(W::WOperator, args...) = size(W.J, args...)
function Base.getindex(W::WOperator, i::Int)
if W.transform
W.mass_matrix[i] / W.gamma - W.J[i]
-W.mass_matrix[i] / W.gamma + W.J[i]
else
W.mass_matrix[i] - W.gamma * W.J[i]
-W.mass_matrix[i] + W.gamma * W.J[i]
end
end
function Base.getindex(W::WOperator, I::Vararg{Int,N}) where {N}
if W.transform
W.mass_matrix[I...] / W.gamma - W.J[I...]
-W.mass_matrix[I...] / W.gamma + W.J[I...]
else
W.mass_matrix[I...] - W.gamma * W.J[I...]
-W.mass_matrix[I...] + W.gamma * W.J[I...]
end
end
function Base.:*(W::WOperator, x::Union{AbstractVecOrMat,Number})
if W.transform
(W.mass_matrix*x) / W.gamma - W.J*x
(W.mass_matrix*x) / -W.gamma + W.J*x
else
W.mass_matrix*x - W.gamma * (W.J*x)
-W.mass_matrix*x + W.gamma * (W.J*x)
end
end
function Base.:\(W::WOperator, x::Union{AbstractVecOrMat,Number})
@@ -214,15 +214,15 @@ function LinearAlgebra.mul!(Y::AbstractVecOrMat, W::WOperator, B::AbstractVecOrM
if W.transform
# Compute mass_matrix * B
if isa(W.mass_matrix, UniformScaling)
a = W.mass_matrix.λ / W.gamma
a = -W.mass_matrix.λ / W.gamma
@. Y = a * B
else
mul!(Y, W.mass_matrix, B)
lmul!(1/W.gamma, Y)
lmul!(-1/W.gamma, Y)
end
# Compute J * B and subtract
# Compute J * B and add
mul!(W._func_cache, W.J, B)
Y .-= W._func_cache
Y .+= W._func_cache
else
# Compute mass_matrix * B
if isa(W.mass_matrix, UniformScaling)
@@ -232,23 +232,23 @@ end
end
# Compute J * B
mul!(W._func_cache, W.J, B)
# Subtract result
axpy!(-W.gamma, W._func_cache, Y)
# Add result
axpby!(W.gamma, W._func_cache, -one(W.gamma), Y)
end
end

function do_nowJ(integrator, alg, repeat_step)::Bool
function do_newJ(integrator, alg::T, cache, repeat_step)::Bool where T
repeat_step && return false
!alg_can_repeat_jac(alg) && return true
isnewton = alg isa NewtonAlgorithm
isnewton && ( @unpack ηold,nl_iters = integrator.cache.nlsolver )
isnewton = T <: NewtonAlgorithm
isnewton && (T <: RadauIIA5 ? ( @unpack ηold,nl_iters = cache ) : ( @unpack ηold,nl_iters = cache.nlsolver ))
integrator.force_stepfail && return true
# reuse J when there is fast convergence
fastconvergence = nl_iters == 1 && ηold >= alg.new_jac_conv_bound
fastconvergence = nl_iters == 1 && ηold <= alg.new_jac_conv_bound
return !fastconvergence
end

function do_nowW(integrator, new_jac)::Bool
function do_newW(integrator, new_jac)::Bool
integrator.iter <= 1 && return true
new_jac && return true
# reuse W when the change in stepsize is small enough
@@ -261,23 +261,29 @@ end
@noinline _throwWJerror(W, J) = throw(DimensionMismatch("W: $(axes(W)), J: $(axes(J))"))
@noinline _throwWMerror(W, mass_matrix) = throw(DimensionMismatch("W: $(axes(W)), mass matrix: $(axes(mass_matrix))"))

@inline function jacobian2W!(W, mass_matrix, dtgamma, J, W_transform)::Nothing
@inline function jacobian2W!(W::AbstractMatrix, mass_matrix::MT, dtgamma::Number, J::AbstractMatrix, W_transform::Bool)::Nothing where MT
# check size and dimension
iijj = axes(W)
@boundscheck (iijj === axes(J) && length(iijj) === 2) || _throwWJerror(W, J)
mass_matrix isa UniformScaling || @boundscheck axes(mass_matrix) === axes(W) || _throwWMerror(W, mass_matrix)
@inbounds if W_transform
invdtgamma′ = inv(dtgamma)
for i in iijj[1]
@inbounds for j in iijj[2]
W[i, j] = muladd(mass_matrix[i, j], invdtgamma′, -J[i, j])
invdtgamma = inv(dtgamma)
if MT <: UniformScaling
copyto!(W, J)
@simd for i in diagind(W)
W[i] = muladd(-mass_matrix.λ, invdtgamma, J[i])
end
else
for j in iijj[2]
@simd for i in iijj[1]
W[i, j] = muladd(-mass_matrix[i, j], invdtgamma, J[i, j])
end
end
end
else
dtgamma′ = -dtgamma
for i in iijj[1]
@simd for j in iijj[2]
W[i, j] = muladd(dtgamma′, J[i, j], mass_matrix[i, j])
for j in iijj[2]
@simd for i in iijj[1]
W[i, j] = muladd(dtgamma, J[i, j], -mass_matrix[i, j])
end
end
end
@@ -306,18 +312,17 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_
# skip calculation of inv(W) if step is repeated
(!repeat_step && W_transform) ? f.invW_t(W, uprev, p, dtgamma, t) : f.invW(W, uprev, p, dtgamma, t) # W == inverse W
is_compos && calc_J!(integrator, cache, true)

elseif DiffEqBase.has_jac(f) && f.jac_prototype !== nothing
# skip calculation of J if step is repeated
( new_jac = do_nowJ(integrator, alg, repeat_step) ) && DiffEqBase.update_coefficients!(W,uprev,p,t)
( new_jac = do_newJ(integrator, alg, cache, repeat_step) ) && DiffEqBase.update_coefficients!(W,uprev,p,t)
# skip calculation of W if step is repeated
@label J2W
( new_W = do_nowW(integrator, new_jac) ) && (W.transform = W_transform; set_gamma!(W, dtgamma))
( new_W = do_newW(integrator, new_jac) ) && (W.transform = W_transform; set_gamma!(W, dtgamma))
else # concrete W using jacobian from `calc_J!`
# skip calculation of J if step is repeated
( new_jac = do_nowJ(integrator, alg, repeat_step) ) && calc_J!(integrator, cache, is_compos)
( new_jac = do_newJ(integrator, alg, cache, repeat_step) ) && calc_J!(integrator, cache, is_compos)
# skip calculation of W if step is repeated
( new_W = do_nowW(integrator, new_jac) ) && jacobian2W!(W, mass_matrix, dtgamma, J, W_transform)
( new_W = do_newW(integrator, new_jac) ) && jacobian2W!(W, mass_matrix, dtgamma, J, W_transform)
end
isnewton && set_new_W!(cache.nlsolver, new_W)
new_W && (integrator.destats.nw += 1)
@@ -345,8 +350,8 @@ function calc_W!(integrator, cache::OrdinaryDiffEqConstantCache, dtgamma, repeat
else
integrator.destats.nw += 1
J = calc_J(integrator, cache, is_compos)
W_full = W_transform ? mass_matrix*inv(dtgamma) - J :
mass_matrix - dtgamma*J
W_full = W_transform ? -mass_matrix*inv(dtgamma) + J :
-mass_matrix + dtgamma*J
W = W_full isa Number ? W_full : lu(W_full)
end
is_compos && (integrator.eigen_est = isarray ? opnorm(J, Inf) : J)
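To make the new W formation concrete, here is a standalone paraphrase of the transformed branch of `jacobian2W!` from the hunk above (simplified: no bounds checks, and only the `W = J - M/(gamma*dt)` case):

```julia
using LinearAlgebra

function toy_jacobian2W_transformed!(W, mass_matrix, dtgamma, J)
    invdtgamma = inv(dtgamma)
    if mass_matrix isa UniformScaling
        # Only the diagonal differs from J, so copy once and fix diagind(W).
        copyto!(W, J)
        @simd for i in diagind(W)
            W[i] = muladd(-mass_matrix.λ, invdtgamma, J[i])
        end
    else
        # Column-major order: j outer, i inner, so @simd runs down contiguous memory.
        @inbounds for j in axes(W, 2)
            @simd for i in axes(W, 1)
                W[i, j] = muladd(-mass_matrix[i, j], invdtgamma, J[i, j])
            end
        end
    end
    return nothing
end

J = rand(4, 4); W = similar(J); dtgamma = 0.05
toy_jacobian2W_transformed!(W, I, dtgamma, J)
@assert W ≈ J - Matrix(1.0I, 4, 4) / dtgamma
toy_jacobian2W_transformed!(W, Diagonal([1.0, 2.0, 3.0, 4.0]), dtgamma, J)
@assert W ≈ J - Diagonal([1.0, 2.0, 3.0, 4.0]) / dtgamma
```

Separately, the renamed `do_newJ` corrects the fast-convergence test: the Jacobian is now reused when the previous nonlinear solve took a single iteration with `ηold <= alg.new_jac_conv_bound`; the old `>=` had that condition inverted.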
16 changes: 9 additions & 7 deletions src/derivative_wrappers.jl
@@ -1,11 +1,12 @@
function derivative!(df::AbstractArray{<:Number}, f, x::Union{Number,AbstractArray{<:Number}}, fx::AbstractArray{<:Number}, integrator, grad_config)
alg = unwrap_alg(integrator, true)
tmp = length(x) # We calculate derivative for all elements in gradient
if get_current_alg_autodiff(integrator.alg, integrator.cache)
if alg_autodiff(alg)
ForwardDiff.derivative!(df, f, fx, x, grad_config)
integrator.destats.nf += 1
else
DiffEqDiffTools.finite_difference_gradient!(df, f, x, grad_config)
fdtype = integrator.alg.diff_type
fdtype = alg.diff_type
if fdtype == Val{:forward} || fdtype == Val{:central}
tmp *= 2
if eltype(df)<:Complex
@@ -22,7 +23,7 @@ function derivative(f, x::Union{Number,AbstractArray{<:Number}},
local d
tmp = length(x) # We calculate derivative for all elements in gradient
alg = unwrap_alg(integrator, true)
if get_current_alg_autodiff(integrator.alg, integrator.cache)
if alg_autodiff(alg)
integrator.destats.nf += 1
d = ForwardDiff.derivative(f, x)
else
@@ -38,7 +39,7 @@ end
function jacobian(f, x, integrator)
alg = unwrap_alg(integrator, true)
local tmp
if get_current_alg_autodiff(alg, integrator.cache)
if alg_autodiff(alg)
J = jacobian_autodiff(f, x)
tmp = 1
else
@@ -62,11 +63,12 @@ jacobian_finitediff(f, x::AbstractArray, diff_type) =
DiffEqDiffTools.finite_difference_jacobian(f, x, diff_type, eltype(x), Val{false})

function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, fx::AbstractArray{<:Number}, integrator::DiffEqBase.DEIntegrator, jac_config)
if get_current_alg_autodiff(integrator.alg, integrator.cache)
alg = unwrap_alg(integrator, true)
if alg_autodiff(alg)
ForwardDiff.jacobian!(J, f, fx, x, jac_config)
integrator.destats.nf += 1
else
isforward = integrator.alg.diff_type === Val{:forward}
isforward = alg.diff_type === Val{:forward}
if isforward
forwardcache = get_tmp_cache(integrator)[2]
f(forwardcache, x)
@@ -75,7 +77,7 @@ function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, f
else # not forward difference
DiffEqDiffTools.finite_difference_jacobian!(J, f, x, jac_config)
end
integrator.destats.nf += (integrator.alg.diff_type==Val{:complex} && eltype(x)<:Real || isforward) ? length(x) : 2length(x)
integrator.destats.nf += (alg.diff_type==Val{:complex} && eltype(x)<:Real || isforward) ? length(x) : 2length(x)
end
nothing
end
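These hunks route every autodiff and finite-difference decision through `unwrap_alg(integrator, true)`, so that with a composite (auto-switching) algorithm the fields of the actual stiff solver are read rather than the wrapper's. A toy illustration of the failure being avoided, with hypothetical types in place of the real ones:

```julia
# Hypothetical stand-ins; the real types live in OrdinaryDiffEq/DiffEqBase.
struct ToyRodas5; diff_type::Symbol; end          # stiff solver, carries diff_type
struct ToyTsit5 end                               # explicit solver, no diff_type
struct ToyAutoSwitchAlg{A,B}; nonstiff::A; stiff::B; end

toy_unwrap_alg(alg, is_stiff) = alg
toy_unwrap_alg(alg::ToyAutoSwitchAlg, is_stiff) = is_stiff ? alg.stiff : alg.nonstiff

alg = ToyAutoSwitchAlg(ToyTsit5(), ToyRodas5(:forward))
# alg.diff_type                                   # errors: the wrapper has no such field
@assert toy_unwrap_alg(alg, true).diff_type == :forward
```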
8 changes: 5 additions & 3 deletions src/nlsolve/newton.jl
@@ -62,7 +62,7 @@ Equations II, Springer Series in Computational Mathematics. ISBN
end
integrator.destats.nf += 1
if DiffEqBase.has_invW(f)
dz = _reshape(W * _vec(ztmp), axes(ztmp)) # Here W is actually invW
dz = _reshape(W * -_vec(ztmp), axes(ztmp)) # Here W is actually invW
else
dz = _reshape(W \ _vec(ztmp), axes(ztmp))
end
@@ -82,7 +82,7 @@ Equations II, Springer Series in Computational Mathematics. ISBN
end

# update solution
z = z .+ dz
z = z .- dz

# check stopping criterion
iter > 1 && (η = θ / (1 - θ))
@@ -104,6 +104,7 @@ end
@unpack t,dt,uprev,u,p,cache = integrator
@unpack z,dz,tmp,ztmp,k,κtol,c,γ,max_iter = nlsolver
@unpack W, new_W = nlcache
cache = unwrap_cache(integrator, true)

# precalculations
mass_matrix = integrator.f.mass_matrix
@@ -132,6 +133,7 @@ end
end
if DiffEqBase.has_invW(f)
mul!(vecdz,W,vecztmp) # Here W is actually invW
@. vecdz = -vecdz
else
cache.linsolve(vecdz,W,vecztmp,iter == 1 && new_W)
end
@@ -151,7 +153,7 @@ end
end

# update solution
z .+= dz
@. z = z - dz

# check stopping criterion
iter > 1 && (η = θ / (1 - θ))
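A small numeric check of the sign bookkeeping in these Newton hunks, under the assumption that a precomputed `invW` (the `has_invW` branch) still returns the inverse of the old-convention `M - gamma*dt*J`, which is what the added negations suggest:

```julia
using LinearAlgebra

M = Matrix(1.0I, 2, 2); J = [-2.0 1.0; 1.0 -3.0]; γdt = 0.1
z, ztmp = [1.0, 2.0], [0.3, -0.1]

W_new = γdt * J - M              # convention after this PR
invW  = inv(M - γdt * J)         # precomputed inverse in the old convention

dz_solve = W_new \ ztmp          # regular path: dz = W \ ztmp, then z .-= dz
dz_invW  = invW * -ztmp          # invW path in the diff: multiply by -ztmp instead
@assert dz_solve ≈ dz_invW
@assert z .- dz_solve ≈ z .+ (M - γdt * J) \ ztmp   # same state update as before the PR
```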
43 changes: 15 additions & 28 deletions src/perform_step/firk_perform_step.jl
@@ -52,11 +52,11 @@ end
γdt, αdt, βdt = γ/dt, α/dt, β/dt
J = calc_J(integrator, cache, is_compos)
if u isa Number
LU1 = γdt*mass_matrix - J
LU2 = (αdt + βdt*im)*mass_matrix - J
LU1 = -γdt*mass_matrix + J
LU2 = -(αdt + βdt*im)*mass_matrix + J
else
LU1 = lu(γdt*mass_matrix - J)
LU2 = lu((αdt + βdt*im)*mass_matrix - J)
LU1 = lu(-γdt*mass_matrix + J)
LU2 = lu(-(αdt + βdt*im)*mass_matrix + J)
end
integrator.destats.nw += 1

@@ -126,9 +126,9 @@ end
end
end

w1 = @. w1 + dw1
w2 = @. w2 + dw2
w3 = @. w3 + dw3
w1 = @. w1 - dw1
w2 = @. w2 - dw2
w3 = @. w3 - dw3

# transform `w` to `z`
z1 = @. T11 * w1 + T12 * w2 + T13 * w3
@@ -214,27 +214,14 @@ end
c1mc2= c1-c2
κtol = κ*tol # used in Newton iteration
γdt, αdt, βdt = γ/dt, α/dt, β/dt
new_W = true
if repeat_step || (alg_can_repeat_jac(alg) &&
(!integrator.last_stepfail && cache.nl_iters == 1 &&
cache.ηold < alg.new_jac_conv_bound))
new_jac = false
else
new_jac = true
calc_J!(integrator, cache, is_compos)
end
# skip calculation of W if step is repeated
if !repeat_step && (!alg_can_repeat_jac(alg) ||
(integrator.iter < 1 || new_jac ||
abs(dt - (t-integrator.tprev)) > 100eps(typeof(integrator.t))))
(new_jac = do_newJ(integrator, alg, cache, repeat_step)) && calc_J!(integrator, cache, is_compos)
if (new_W = do_newW(integrator, new_jac))
@inbounds for II in CartesianIndices(J)
W1[II] = γdt * mass_matrix[Tuple(II)...] - J[II]
W2[II] = (αdt + βdt*im) * mass_matrix[Tuple(II)...] - J[II]
W1[II] = -γdt * mass_matrix[Tuple(II)...] + J[II]
W2[II] = -(αdt + βdt*im) * mass_matrix[Tuple(II)...] + J[II]
end
else
new_W = false
integrator.destats.nw += 1
end
new_W && (integrator.destats.nw += 1)

# TODO better initial guess
if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
@@ -320,9 +307,9 @@ end
end
end

@. w1 = w1 + dw1
@. w2 = w2 + dw2
@. w3 = w3 + dw3
@. w1 = w1 - dw1
@. w2 = w2 - dw2
@. w3 = w3 - dw3

# transform `w` to `z`
@. z1 = T11 * w1 + T12 * w2 + T13 * w3
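The same convention reaches the Radau IIA5 step: the real and complex LU factors are now built as `J - (gamma/dt)*M` and `J - (alpha/dt + beta/dt*im)*M`, the `w` iterates are updated with `-=` instead of `+=`, and the inlined Jacobian/W reuse check is replaced by the shared `do_newJ`/`do_newW` helpers. A quick equivalence check with made-up values for the tableau constants:

```julia
using LinearAlgebra

M = Matrix(1.0I, 2, 2); J = [-2.0 1.0; 1.0 -3.0]
γdt, αdt, βdt = 3.6 / 0.1, 2.7 / 0.1, 3.0 / 0.1   # stand-ins for γ/dt, α/dt, β/dt

LU1_old = lu(γdt * M - J)                         # factorization before this PR
LU1_new = lu(-γdt * M + J)                        # sign as in the diff
LU2_new = lu(-(αdt + βdt * im) * M + J)           # complex pair, same treatment

rhs = [0.2, -0.4]
# old update: w .+= LU1_old \ rhs   matches   new update: w .-= LU1_new \ rhs
@assert LU1_old \ rhs ≈ -(LU1_new \ rhs)
@assert LU2_new \ rhs ≈ -(lu((αdt + βdt * im) * M - J) \ rhs)
```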