Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ export SplitEuler

export Nystrom4, FineRKN4, FineRKN5, Nystrom4VelocityIndependent,
Nystrom5VelocityIndependent,
IRKN3, IRKN4, DPRKN4, DPRKN5, DPRKN6, DPRKN6FM, DPRKN8, DPRKN12, ERKN4, ERKN5, ERKN7, RKN4
IRKN3, IRKN4, DPRKN4, DPRKN5, DPRKN6, DPRKN6FM, DPRKN8, DPRKN12, ERKN4, ERKN5, ERKN7,
RKN4

export ROCK2, ROCK4, RKC, IRKC, ESERK4, ESERK5, SERK2

Expand Down
1 change: 1 addition & 0 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1217,6 +1217,7 @@ struct ERKN7 <: OrdinaryDiffEqAdaptivePartitionedAlgorithm end
Does not include an adaptive method. Solves for for d-dimensional differential systems of second order linear inhomogeneous equations.

!!! warn

This method is only fourth order for these systems, the method is second order otherwise!

## References
Expand Down
8 changes: 4 additions & 4 deletions src/caches/rkn_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -693,9 +693,9 @@ end
end

function alg_cache(alg::RKN4, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
reduced_rate_prototype = rate_prototype.x[2]
k₁ = zero(rate_prototype)
k₂ = zero(reduced_rate_prototype)
Expand All @@ -712,4 +712,4 @@ function alg_cache(alg::RKN4, u, rate_prototype, ::Type{uEltypeNoUnits},
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
RKN4ConstantCache()
end
end
6 changes: 4 additions & 2 deletions src/caches/rosenbrock_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1045,7 +1045,8 @@ function alg_cache(alg::Rodas5, u, rate_prototype, ::Type{uEltypeNoUnits},
constvalue(tTypeNoUnits)), J, W, linsolve)
end

function alg_cache(alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
function alg_cache(
alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{true}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
Expand Down Expand Up @@ -1093,7 +1094,8 @@ function alg_cache(alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, :
linsolve, jac_config, grad_config, reltol, alg)
end

function alg_cache(alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
function alg_cache(
alg::Union{Rodas5P, Rodas5Pe, Rodas5Pr}, u, rate_prototype, ::Type{uEltypeNoUnits},
::Type{uBottomEltypeNoUnits}, ::Type{tTypeNoUnits}, uprev, uprev2, f, t,
dt, reltol, p, calck,
::Val{false}) where {uEltypeNoUnits, uBottomEltypeNoUnits, tTypeNoUnits}
Expand Down
9 changes: 6 additions & 3 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -750,7 +750,8 @@ end
if J isa StaticArray &&
integrator.alg isa
Union{
Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P, Rodas4P2,
Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
W = W_transform ? J - mass_matrix * inv(dtgamma) :
dtgamma * J - mass_matrix
else
Expand All @@ -775,7 +776,8 @@ end
W_full
elseif len !== nothing &&
integrator.alg isa
Union{Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
Union{Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P,
Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
StaticWOperator(W_full)
else
DiffEqBase.default_factorize(W_full)
Expand Down Expand Up @@ -923,7 +925,8 @@ function build_J_W(alg, u, uprev, p, t, dt, f::F, ::Type{uEltypeNoUnits},
len = StaticArrayInterface.known_length(typeof(J))
if len !== nothing &&
alg isa
Union{Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P, Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
Union{Rosenbrock23, Rodas23W, Rodas3P, Rodas4, Rodas4P,
Rodas4P2, Rodas5, Rodas5P, Rodas5Pe, Rodas5Pr}
StaticWOperator(J, false)
else
ArrayInterface.lu_instance(J)
Expand Down
6 changes: 6 additions & 0 deletions src/integrators/controllers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ end

@inline function step_reject_controller!(integrator, alg)
step_reject_controller!(integrator, integrator.opts.controller, alg)
cache = integrator.cache
if hasfield(typeof(cache), :nlsolve)
nlsolve = cache.nlsolve
nlsolve.prev_θ = one(nlsolve.prev_θ)
end
return nothing
end

reset_alg_dependent_opts!(controller::AbstractController, alg1, alg2) = nothing
Expand Down
19 changes: 14 additions & 5 deletions src/nlsolve/nlsolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ dt⋅f(innertmp + γ⋅z, p, t + c⋅dt) + outertmp = z

where `dt` is the step size and `γ` and `c` are constants, and return the solution `z`.
"""
function nlsolve!(nlsolver::AbstractNLSolver, integrator::DiffEqBase.DEIntegrator,
cache = nothing, repeat_step = false)
function nlsolve!(nlsolver::NL, integrator::DiffEqBase.DEIntegrator,
cache = nothing, repeat_step = false) where {NL <: AbstractNLSolver}
always_new = is_always_new(nlsolver)
check_div′ = check_div(nlsolver)
@label REDO
Expand Down Expand Up @@ -59,9 +59,11 @@ function nlsolve!(nlsolver::AbstractNLSolver, integrator::DiffEqBase.DEIntegrato
break
end

prev_θ = hasfield(NL, :prev_θ) ? nlsolver.prev_θ : one(ndz)

# check divergence (not in initial step)
if iter > 1
θ = ndz / ndzprev
θ = prev_θ = max(0.3 * prev_θ, ndz / ndzprev)

# When one Newton iteration basically does nothing, it's likely that we
# are at the precision limit of floating point number. Thus, we just call
Expand All @@ -84,13 +86,20 @@ function nlsolve!(nlsolver::AbstractNLSolver, integrator::DiffEqBase.DEIntegrato
nlsolver.nfails += 1
break
end
else
θ = min(one(prev_θ), prev_θ)
end

if hasfield(NL, :prev_θ)
nlsolver.prev_θ = prev_θ
end

apply_step!(nlsolver, integrator)

# check for convergence
iter > 1 && (η = DiffEqBase.value(θ / (1 - θ)))
if (iter == 1 && ndz < 1e-5) || (iter > 1 && (η >= zero(η) && η * ndz < κ))
η = DiffEqBase.value(θ / (1 - θ))
if (iter == 1 && ndz < 1e-5) ||
((iter > 1 || isnewton(nlsolver)) && η >= zero(η) && η * ndz < κ)
nlsolver.status = Convergence
nlsolver.nfails = 0
break
Expand Down
10 changes: 7 additions & 3 deletions src/nlsolve/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ end
abstract type AbstractNLSolver{algType, iip} end

mutable struct NLSolver{algType, iip, uType, gamType, tmpType, tType,
C <: AbstractNLSolverCache} <: AbstractNLSolver{algType, iip}
C <: AbstractNLSolverCache, E} <: AbstractNLSolver{algType, iip}
z::uType
tmp::uType # DIRK and multistep methods only use tmp
tmp2::tmpType # for GLM if neccssary
Expand All @@ -88,13 +88,16 @@ mutable struct NLSolver{algType, iip, uType, gamType, tmpType, tType,
cache::C
method::MethodType
nfails::Int
prev_θ::E
end

# default to DIRK
function NLSolver{iip, tType}(z, tmp, ztmp, γ, c, α, alg, κ, fast_convergence_cutoff, ηold,
iter, maxiters, status, cache, method = DIRK, tmp2 = nothing,
nfails::Int = 0) where {iip, tType}
NLSolver{typeof(alg), iip, typeof(z), typeof(γ), typeof(tmp2), tType, typeof(cache)}(z,
RT = real(eltype(z))
NLSolver{typeof(alg), iip, typeof(z), typeof(γ), typeof(tmp2), tType, typeof(cache), RT}(
z,
tmp,
tmp2,
ztmp,
Expand All @@ -115,7 +118,8 @@ function NLSolver{iip, tType}(z, tmp, ztmp, γ, c, α, alg, κ, fast_convergence
status,
cache,
method,
nfails)
nfails,
one(RT))
end

# caches
Expand Down
40 changes: 20 additions & 20 deletions src/perform_step/rkn_perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1840,12 +1840,12 @@ end
duprev, uprev = integrator.uprev.x
u, du = integrator.u.x
#define dt values
halfdt = dt/2
halfdt = dt / 2
dtsq = dt^2
eightdtsq = dtsq/8
halfdtsq = dtsq/2
sixthdtsq = dtsq/6
sixthdt = dt/6
eightdtsq = dtsq / 8
halfdtsq = dtsq / 2
sixthdtsq = dtsq / 6
sixthdt = dt / 6
ttmp = t + halfdt

#perform operations to find k values
Expand All @@ -1860,13 +1860,13 @@ end
k₃ = f.f1(kdu, ku, p, t + dt)

#perform final calculations to determine new y and y'.
u = uprev + sixthdtsq* (1*k₁ + 2*k₂ + 0*k₃) + dt * duprev
du = duprev + sixthdt * (1*k₁ + 4*k₂ + 1*k₃)
u = uprev + sixthdtsq * (1 * k₁ + 2 * k₂ + 0 * k₃) + dt * duprev
du = duprev + sixthdt * (1 * k₁ + 4 * k₂ + 1 * k₃)

integrator.u = ArrayPartition((du, u))
integrator.fsallast = ArrayPartition((f.f1(du, u, p, t + dt), f.f2(du, u, p, t + dt)))
integrator.stats.nf += 2
integrator.stats.nf2 += 1
integrator.stats.nf2 += 1
integrator.k[1] = integrator.fsalfirst
integrator.k[2] = integrator.fsallast
end
Expand All @@ -1879,32 +1879,32 @@ end
kdu, ku = integrator.cache.tmp.x[1], integrator.cache.tmp.x[2]

#define dt values
halfdt = dt/2
halfdt = dt / 2
dtsq = dt^2
eightdtsq = dtsq/8
halfdtsq = dtsq/2
sixthdtsq = dtsq/6
sixthdt = dt/6
eightdtsq = dtsq / 8
halfdtsq = dtsq / 2
sixthdtsq = dtsq / 6
sixthdt = dt / 6
ttmp = t + halfdt

#perform operations to find k values
k₁ = integrator.fsalfirst.x[1]
@.. broadcast=false ku = uprev + halfdt * duprev + eightdtsq * k₁
@.. broadcast=false kdu = duprev + halfdt * k₁
@.. broadcast=false ku=uprev + halfdt * duprev + eightdtsq * k₁
@.. broadcast=false kdu=duprev + halfdt * k₁

f.f1(k₂, kdu, ku, p, ttmp)
@.. broadcast=false ku = uprev + dt * duprev + halfdtsq * k₂
@.. broadcast=false kdu = duprev + dt * k₂
@.. broadcast=false ku=uprev + dt * duprev + halfdtsq * k₂
@.. broadcast=false kdu=duprev + dt * k₂

f.f1(k₃, kdu, ku, p, t + dt)

#perform final calculations to determine new y and y'.
@.. broadcast=false u = uprev + sixthdtsq* (1*k₁ + 2*k₂ + 0*k₃) + dt * duprev
@.. broadcast=false du = duprev + sixthdt * (1*k₁ + 4*k₂ + 1*k₃)
@.. broadcast=false u=uprev + sixthdtsq * (1 * k₁ + 2 * k₂ + 0 * k₃) + dt * duprev
@.. broadcast=false du=duprev + sixthdt * (1 * k₁ + 4 * k₂ + 1 * k₃)

f.f1(k.x[1], du, u, p, t + dt)
f.f2(k.x[2], du, u, p, t + dt)

integrator.stats.nf += 2
integrator.stats.nf2 += 1
end
end
Loading