Skip to content

Commit

Permalink
Merge pull request #706 from JuliaDiffEq/myb/W_transform
Browse files Browse the repository at this point in the history
Optimize W matrix formation and fix `fastconvergence` judgment
  • Loading branch information
YingboMa committed Mar 21, 2019
2 parents 01c54e0 + 938b627 commit 98412e9
Show file tree
Hide file tree
Showing 11 changed files with 189 additions and 146 deletions.
14 changes: 14 additions & 0 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,20 @@ function unwrap_alg(integrator, is_stiff)
end
end

# Return the cache that belongs to the algorithm currently in use.
#
# - If `integrator.alg` is not a `CompositeAlgorithm`, `integrator.cache` is
#   returned unchanged.
# - If it is composite and its `choice_function` is an `AutoSwitch`, the entry
#   of `cache.caches` is picked from `is_stiff`: index 2 when `is_stiff` is
#   true, index 1 otherwise.
# - For any other composite algorithm, the entry at `cache.current` is returned.
function unwrap_cache(integrator, is_stiff)
    alg = integrator.alg
    cache = integrator.cache

    # Nothing to unwrap for a non-composite algorithm.
    alg isa CompositeAlgorithm || return cache

    idx = if alg.choice_function isa AutoSwitch
        is_stiff ? 2 : 1
    else
        cache.current
    end
    return cache.caches[idx]
end

# Whether `uprev` is used in the algorithm directly.
# The generic `OrdinaryDiffEqAlgorithm` method answers `true`; the `ORK256`
# method answers `false`. Neither method inspects the `adaptive` argument.
uses_uprev(alg::OrdinaryDiffEqAlgorithm, adaptive::Bool) = true
uses_uprev(alg::ORK256, adaptive::Bool) = false
Expand Down
2 changes: 1 addition & 1 deletion src/caches/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -821,4 +821,4 @@ function alg_cache(alg::RosenbrockW6S4OS,u,rate_prototype,uEltypeNoUnits,uBottom
tf = DiffEqDiffTools.TimeDerivativeWrapper(f,u,p)
uf = DiffEqDiffTools.UDerivativeWrapper(f,t,p)
RosenbrockWConstantCache(tf,uf,RosenbrockW6S4OSConstantCache(real(uBottomEltypeNoUnits),real(tTypeNoUnits)))
end
end
87 changes: 46 additions & 41 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ mutable struct WOperator{T,
_func_cache # cache used in `mul!`
_concrete_form # non-lazy form (matrix/number) of the operator
WOperator(mass_matrix, gamma, J, inplace; transform=false) = new{eltype(J),typeof(mass_matrix),
typeof(gamma),typeof(J)}(mass_matrix,gamma,J,inplace,transform,nothing,nothing)
typeof(gamma),typeof(J)}(mass_matrix,gamma,J,transform,inplace,nothing,nothing)
end
function WOperator(f::DiffEqBase.AbstractODEFunction, gamma, inplace; transform=false)
@assert DiffEqBase.has_jac(f) "f needs to have an associated jacobian"
Expand All @@ -152,50 +152,50 @@ function Base.convert(::Type{AbstractMatrix}, W::WOperator)
if W._concrete_form === nothing || !W.inplace
# Allocating
if W.transform
W._concrete_form = W.mass_matrix / W.gamma - convert(AbstractMatrix,W.J)
W._concrete_form = -W.mass_matrix / W.gamma + convert(AbstractMatrix,W.J)
else
W._concrete_form = W.mass_matrix - W.gamma * convert(AbstractMatrix,W.J)
W._concrete_form = -W.mass_matrix + W.gamma * convert(AbstractMatrix,W.J)
end
else
# Non-allocating
if W.transform
rmul!(copyto!(W._concrete_form, W.mass_matrix), 1/W.gamma)
axpy!(-1, convert(AbstractMatrix,W.J), W._concrete_form)
copyto!(W._concrete_form, W.mass_matrix)
axpby!(one(W.gamma), convert(AbstractMatrix,W.J), -inv(W.gamma), W._concrete_form)
else
copyto!(W._concrete_form, W.mass_matrix)
axpy!(-W.gamma, convert(AbstractMatrix,W.J), W._concrete_form)
axpby!(W.gamma, convert(AbstractMatrix,W.J), -one(W.gamma), W._concrete_form)
end
end
W._concrete_form
end
function Base.convert(::Type{Number}, W::WOperator)
if W.transform
W._concrete_form = W.mass_matrix / W.gamma - convert(Number,W.J)
W._concrete_form = -W.mass_matrix / W.gamma + convert(Number,W.J)
else
W._concrete_form = W.mass_matrix - W.gamma * convert(Number,W.J)
W._concrete_form = -W.mass_matrix + W.gamma * convert(Number,W.J)
end
W._concrete_form
end
# The size of a `WOperator` is the size of its Jacobian `W.J`; any extra
# arguments (e.g. a dimension index) are forwarded to `size` unchanged.
function Base.size(W::WOperator, args...)
    return size(W.J, args...)
end
function Base.getindex(W::WOperator, i::Int)
if W.transform
W.mass_matrix[i] / W.gamma - W.J[i]
-W.mass_matrix[i] / W.gamma + W.J[i]
else
W.mass_matrix[i] - W.gamma * W.J[i]
-W.mass_matrix[i] + W.gamma * W.J[i]
end
end
function Base.getindex(W::WOperator, I::Vararg{Int,N}) where {N}
if W.transform
W.mass_matrix[I...] / W.gamma - W.J[I...]
-W.mass_matrix[I...] / W.gamma + W.J[I...]
else
W.mass_matrix[I...] - W.gamma * W.J[I...]
-W.mass_matrix[I...] + W.gamma * W.J[I...]
end
end
function Base.:*(W::WOperator, x::Union{AbstractVecOrMat,Number})
if W.transform
(W.mass_matrix*x) / W.gamma - W.J*x
(W.mass_matrix*x) / -W.gamma + W.J*x
else
W.mass_matrix*x - W.gamma * (W.J*x)
-W.mass_matrix*x + W.gamma * (W.J*x)
end
end
function Base.:\(W::WOperator, x::Union{AbstractVecOrMat,Number})
Expand All @@ -214,15 +214,15 @@ function LinearAlgebra.mul!(Y::AbstractVecOrMat, W::WOperator, B::AbstractVecOrM
if W.transform
# Compute mass_matrix * B
if isa(W.mass_matrix, UniformScaling)
a = W.mass_matrix.λ / W.gamma
a = -W.mass_matrix.λ / W.gamma
@. Y = a * B
else
mul!(Y, W.mass_matrix, B)
lmul!(1/W.gamma, Y)
lmul!(-1/W.gamma, Y)
end
# Compute J * B and subtract
# Compute J * B and add
mul!(W._func_cache, W.J, B)
Y .-= W._func_cache
Y .+= W._func_cache
else
# Compute mass_matrix * B
if isa(W.mass_matrix, UniformScaling)
Expand All @@ -232,23 +232,23 @@ function LinearAlgebra.mul!(Y::AbstractVecOrMat, W::WOperator, B::AbstractVecOrM
end
# Compute J * B
mul!(W._func_cache, W.J, B)
# Subtract result
axpy!(-W.gamma, W._func_cache, Y)
# Add result
axpby!(W.gamma, W._func_cache, -one(W.gamma), Y)
end
end

function do_nowJ(integrator, alg, repeat_step)::Bool
function do_newJ(integrator, alg::T, cache, repeat_step)::Bool where T
repeat_step && return false
!alg_can_repeat_jac(alg) && return true
isnewton = alg isa NewtonAlgorithm
isnewton && ( @unpack ηold,nl_iters = integrator.cache.nlsolver )
isnewton = T <: NewtonAlgorithm
isnewton && (T <: RadauIIA5 ? ( @unpack ηold,nl_iters = cache ) : ( @unpack ηold,nl_iters = cache.nlsolver ))
integrator.force_stepfail && return true
# reuse J when there is fast convergence
fastconvergence = nl_iters == 1 && ηold >= alg.new_jac_conv_bound
fastconvergence = nl_iters == 1 && ηold <= alg.new_jac_conv_bound
return !fastconvergence
end

function do_nowW(integrator, new_jac)::Bool
function do_newW(integrator, new_jac)::Bool
integrator.iter <= 1 && return true
new_jac && return true
# reuse W when the change in stepsize is small enough
Expand All @@ -261,23 +261,29 @@ end
# Throw a `DimensionMismatch` reporting the axes of `W` and `J`.
# Marked `@noinline` and kept as a separate function from its callers.
@noinline function _throwWJerror(W, J)
    msg = "W: $(axes(W)), J: $(axes(J))"
    throw(DimensionMismatch(msg))
end
# Throw a `DimensionMismatch` reporting the axes of `W` and of the mass matrix.
# Marked `@noinline` and kept as a separate function from its callers.
@noinline function _throwWMerror(W, mass_matrix)
    msg = "W: $(axes(W)), mass matrix: $(axes(mass_matrix))"
    throw(DimensionMismatch(msg))
end

@inline function jacobian2W!(W, mass_matrix, dtgamma, J, W_transform)::Nothing
@inline function jacobian2W!(W::AbstractMatrix, mass_matrix::MT, dtgamma::Number, J::AbstractMatrix, W_transform::Bool)::Nothing where MT
# check size and dimension
iijj = axes(W)
@boundscheck (iijj === axes(J) && length(iijj) === 2) || _throwWJerror(W, J)
mass_matrix isa UniformScaling || @boundscheck axes(mass_matrix) === axes(W) || _throwWMerror(W, mass_matrix)
@inbounds if W_transform
invdtgamma′ = inv(dtgamma)
for i in iijj[1]
@inbounds for j in iijj[2]
W[i, j] = muladd(mass_matrix[i, j], invdtgamma′, -J[i, j])
invdtgamma = inv(dtgamma)
if MT <: UniformScaling
copyto!(W, J)
@simd for i in diagind(W)
W[i] = muladd(-mass_matrix.λ, invdtgamma, J[i])
end
else
for j in iijj[2]
@simd for i in iijj[1]
W[i, j] = muladd(-mass_matrix[i, j], invdtgamma, J[i, j])
end
end
end
else
dtgamma′ = -dtgamma
for i in iijj[1]
@simd for j in iijj[2]
W[i, j] = muladd(dtgamma′, J[i, j], mass_matrix[i, j])
for j in iijj[2]
@simd for i in iijj[1]
W[i, j] = muladd(dtgamma, J[i, j], -mass_matrix[i, j])
end
end
end
Expand Down Expand Up @@ -306,18 +312,17 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_
# skip calculation of inv(W) if step is repeated
(!repeat_step && W_transform) ? f.invW_t(W, uprev, p, dtgamma, t) : f.invW(W, uprev, p, dtgamma, t) # W == inverse W
is_compos && calc_J!(integrator, cache, true)

elseif DiffEqBase.has_jac(f) && f.jac_prototype !== nothing
# skip calculation of J if step is repeated
( new_jac = do_nowJ(integrator, alg, repeat_step) ) && DiffEqBase.update_coefficients!(W,uprev,p,t)
( new_jac = do_newJ(integrator, alg, cache, repeat_step) ) && DiffEqBase.update_coefficients!(W,uprev,p,t)
# skip calculation of W if step is repeated
@label J2W
( new_W = do_nowW(integrator, new_jac) ) && (W.transform = W_transform; set_gamma!(W, dtgamma))
( new_W = do_newW(integrator, new_jac) ) && (W.transform = W_transform; set_gamma!(W, dtgamma))
else # concrete W using jacobian from `calc_J!`
# skip calculation of J if step is repeated
( new_jac = do_nowJ(integrator, alg, repeat_step) ) && calc_J!(integrator, cache, is_compos)
( new_jac = do_newJ(integrator, alg, cache, repeat_step) ) && calc_J!(integrator, cache, is_compos)
# skip calculation of W if step is repeated
( new_W = do_nowW(integrator, new_jac) ) && jacobian2W!(W, mass_matrix, dtgamma, J, W_transform)
( new_W = do_newW(integrator, new_jac) ) && jacobian2W!(W, mass_matrix, dtgamma, J, W_transform)
end
isnewton && set_new_W!(cache.nlsolver, new_W)
new_W && (integrator.destats.nw += 1)
Expand Down Expand Up @@ -345,8 +350,8 @@ function calc_W!(integrator, cache::OrdinaryDiffEqConstantCache, dtgamma, repeat
else
integrator.destats.nw += 1
J = calc_J(integrator, cache, is_compos)
W_full = W_transform ? mass_matrix*inv(dtgamma) - J :
mass_matrix - dtgamma*J
W_full = W_transform ? -mass_matrix*inv(dtgamma) + J :
-mass_matrix + dtgamma*J
W = W_full isa Number ? W_full : lu(W_full)
end
is_compos && (integrator.eigen_est = isarray ? opnorm(J, Inf) : J)
Expand Down
16 changes: 9 additions & 7 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
function derivative!(df::AbstractArray{<:Number}, f, x::Union{Number,AbstractArray{<:Number}}, fx::AbstractArray{<:Number}, integrator, grad_config)
alg = unwrap_alg(integrator, true)
tmp = length(x) # We calculate the derivative for all elements in the gradient
if get_current_alg_autodiff(integrator.alg, integrator.cache)
if alg_autodiff(alg)
ForwardDiff.derivative!(df, f, fx, x, grad_config)
integrator.destats.nf += 1
else
DiffEqDiffTools.finite_difference_gradient!(df, f, x, grad_config)
fdtype = integrator.alg.diff_type
fdtype = alg.diff_type
if fdtype == Val{:forward} || fdtype == Val{:central}
tmp *= 2
if eltype(df)<:Complex
Expand All @@ -22,7 +23,7 @@ function derivative(f, x::Union{Number,AbstractArray{<:Number}},
local d
tmp = length(x) # We calculate the derivative for all elements in the gradient
alg = unwrap_alg(integrator, true)
if get_current_alg_autodiff(integrator.alg, integrator.cache)
if alg_autodiff(alg)
integrator.destats.nf += 1
d = ForwardDiff.derivative(f, x)
else
Expand All @@ -38,7 +39,7 @@ end
function jacobian(f, x, integrator)
alg = unwrap_alg(integrator, true)
local tmp
if get_current_alg_autodiff(alg, integrator.cache)
if alg_autodiff(alg)
J = jacobian_autodiff(f, x)
tmp = 1
else
Expand All @@ -62,11 +63,12 @@ jacobian_finitediff(f, x::AbstractArray, diff_type) =
DiffEqDiffTools.finite_difference_jacobian(f, x, diff_type, eltype(x), Val{false})

function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, fx::AbstractArray{<:Number}, integrator::DiffEqBase.DEIntegrator, jac_config)
if get_current_alg_autodiff(integrator.alg, integrator.cache)
alg = unwrap_alg(integrator, true)
if alg_autodiff(alg)
ForwardDiff.jacobian!(J, f, fx, x, jac_config)
integrator.destats.nf += 1
else
isforward = integrator.alg.diff_type === Val{:forward}
isforward = alg.diff_type === Val{:forward}
if isforward
forwardcache = get_tmp_cache(integrator)[2]
f(forwardcache, x)
Expand All @@ -75,7 +77,7 @@ function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, f
else # not forward difference
DiffEqDiffTools.finite_difference_jacobian!(J, f, x, jac_config)
end
integrator.destats.nf += (integrator.alg.diff_type==Val{:complex} && eltype(x)<:Real || isforward) ? length(x) : 2length(x)
integrator.destats.nf += (alg.diff_type==Val{:complex} && eltype(x)<:Real || isforward) ? length(x) : 2length(x)
end
nothing
end
Expand Down
8 changes: 5 additions & 3 deletions src/nlsolve/newton.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ Equations II, Springer Series in Computational Mathematics. ISBN
end
integrator.destats.nf += 1
if DiffEqBase.has_invW(f)
dz = _reshape(W * _vec(ztmp), axes(ztmp)) # Here W is actually invW
dz = _reshape(W * -_vec(ztmp), axes(ztmp)) # Here W is actually invW
else
dz = _reshape(W \ _vec(ztmp), axes(ztmp))
end
Expand All @@ -82,7 +82,7 @@ Equations II, Springer Series in Computational Mathematics. ISBN
end

# update solution
z = z .+ dz
z = z .- dz

# check stopping criterion
iter > 1 && (η = θ / (1 - θ))
Expand All @@ -104,6 +104,7 @@ end
@unpack t,dt,uprev,u,p,cache = integrator
@unpack z,dz,tmp,ztmp,k,κtol,c,γ,max_iter = nlsolver
@unpack W, new_W = nlcache
cache = unwrap_cache(integrator, true)

# precalculations
mass_matrix = integrator.f.mass_matrix
Expand Down Expand Up @@ -132,6 +133,7 @@ end
end
if DiffEqBase.has_invW(f)
mul!(vecdz,W,vecztmp) # Here W is actually invW
@. vecdz = -vecdz
else
cache.linsolve(vecdz,W,vecztmp,iter == 1 && new_W)
end
Expand All @@ -151,7 +153,7 @@ end
end

# update solution
z .+= dz
@. z = z - dz

# check stopping criterion
iter > 1 && (η = θ / (1 - θ))
Expand Down
43 changes: 15 additions & 28 deletions src/perform_step/firk_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ end
γdt, αdt, βdt = γ/dt, α/dt, β/dt
J = calc_J(integrator, cache, is_compos)
if u isa Number
LU1 = γdt*mass_matrix - J
LU2 = (αdt + βdt*im)*mass_matrix - J
LU1 = -γdt*mass_matrix + J
LU2 = -(αdt + βdt*im)*mass_matrix + J
else
LU1 = lu(γdt*mass_matrix - J)
LU2 = lu((αdt + βdt*im)*mass_matrix - J)
LU1 = lu(-γdt*mass_matrix + J)
LU2 = lu(-(αdt + βdt*im)*mass_matrix + J)
end
integrator.destats.nw += 1

Expand Down Expand Up @@ -126,9 +126,9 @@ end
end
end

w1 = @. w1 + dw1
w2 = @. w2 + dw2
w3 = @. w3 + dw3
w1 = @. w1 - dw1
w2 = @. w2 - dw2
w3 = @. w3 - dw3

# transform `w` to `z`
z1 = @. T11 * w1 + T12 * w2 + T13 * w3
Expand Down Expand Up @@ -214,27 +214,14 @@ end
c1mc2= c1-c2
κtol = κ*tol # used in Newton iteration
γdt, αdt, βdt = γ/dt, α/dt, β/dt
new_W = true
if repeat_step || (alg_can_repeat_jac(alg) &&
(!integrator.last_stepfail && cache.nl_iters == 1 &&
cache.ηold < alg.new_jac_conv_bound))
new_jac = false
else
new_jac = true
calc_J!(integrator, cache, is_compos)
end
# skip calculation of W if step is repeated
if !repeat_step && (!alg_can_repeat_jac(alg) ||
(integrator.iter < 1 || new_jac ||
abs(dt - (t-integrator.tprev)) > 100eps(typeof(integrator.t))))
(new_jac = do_newJ(integrator, alg, cache, repeat_step)) && calc_J!(integrator, cache, is_compos)
if (new_W = do_newW(integrator, new_jac))
@inbounds for II in CartesianIndices(J)
W1[II] = γdt * mass_matrix[Tuple(II)...] - J[II]
W2[II] = (αdt + βdt*im) * mass_matrix[Tuple(II)...] - J[II]
W1[II] = -γdt * mass_matrix[Tuple(II)...] + J[II]
W2[II] = -(αdt + βdt*im) * mass_matrix[Tuple(II)...] + J[II]
end
else
new_W = false
integrator.destats.nw += 1
end
new_W && (integrator.destats.nw += 1)

# TODO better initial guess
if integrator.iter == 1 || integrator.u_modified || alg.extrapolant == :constant
Expand Down Expand Up @@ -320,9 +307,9 @@ end
end
end

@. w1 = w1 + dw1
@. w2 = w2 + dw2
@. w3 = w3 + dw3
@. w1 = w1 - dw1
@. w2 = w2 - dw2
@. w3 = w3 - dw3

# transform `w` to `z`
@. z1 = T11 * w1 + T12 * w2 + T13 * w3
Expand Down
Loading

0 comments on commit 98412e9

Please sign in to comment.