Skip to content

Commit

Permalink
Merge 2a58647 into 72050cb
Browse files Browse the repository at this point in the history
  • Loading branch information
jverzani committed Oct 1, 2018
2 parents 72050cb + 2a58647 commit 5f6bd88
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 63 deletions.
113 changes: 57 additions & 56 deletions src/bracketing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ tolerances are set to zero (the default) guarantees a "best" solution
When tolerances are given, this algorithm terminates when the midpoint
is approximately equal to an endpoint using absolute tolerance `xatol`
and relative tolerance `xrtol`.
and relative tolerance `xrtol`.
When a zero tolerance is given and the values are not `Float64`
values, this will call the `A42` method.
"""
struct Bisection <: AbstractBisection end # either solvable or A42
Expand All @@ -57,8 +57,8 @@ function show_tracks(l::Tracks, M::AbstractBracketing)
end
println("")
end



## helper function
function adjust_bracket(x0)
Expand All @@ -75,7 +75,7 @@ end
# a,b both finite and of the same sign
function init_state(method::AbstractBisection, fs, x)
length(x) > 1 || throw(ArgumentError(bracketing_error))

x0, x1 = adjust_bracket(x) # now finite, right order
fx0, fx1 = promote(sign(fs(x0)), sign(fs(x1)))
fx0 * fx1 > 0 && throw(ArgumentError(bracketing_error))
Expand Down Expand Up @@ -140,8 +140,8 @@ the function changes sign at one of the answer's adjacent floating
point values.
For other types, the the `A42` method (with its tolerances) is used.
"""
"""
default_tolerances(M::Union{Bisection, BisectionExact}) = default_tolerances(M,Float64, Float64)
function default_tolerances(::M, ::Type{T}, ::Type{S}) where {M<:Union{Bisection, BisectionExact},T,S}
xatol = zero(T)
Expand Down Expand Up @@ -181,7 +181,7 @@ __middle(x::Float64, y::Float64) = __middle(Float64, UInt64, x, y)
__middle(x::Float32, y::Float32) = __middle(Float32, UInt32, x, y)
__middle(x::Float16, y::Float16) = __middle(Float16, UInt16, x, y)
## fallback for non FloatNN number types
__middle(x::Number, y::Number) = 0.5*x + 0.5*y
__middle(x::Number, y::Number) = 0.5*x + 0.5*y


function __middle(T, S, x, y)
Expand All @@ -206,12 +206,12 @@ function assess_convergence(M::Bisection, state::UnivariateZeroState{T,S}, optio
assess_convergence(BisectionExact(), state, options) && return true

x0, x1 = state.xn0, state.xn1

tol = max(options.xabstol, max(abs(x0), abs(x1)) * options.xreltol)
if x1 - x0 > tol
if x1 - x0 > tol
return false
end

state.message = ""
state.x_converged = true
return true
Expand Down Expand Up @@ -275,7 +275,7 @@ function find_zero(fs, x0, method::M;
tracks = NullTracks(),
verbose=false,
kwargs...) where {M <: Union{Bisection}}

x = adjust_bracket(x0)
T = eltype(x[1])
F = callable_function(fs)
Expand All @@ -285,19 +285,20 @@ function find_zero(fs, x0, method::M;

l = (verbose && isa(tracks, NullTracks)) ? Tracks(eltype(state.xn1)[], eltype(state.fxn1)[]) : tracks


if iszero(tol)
if T <: FloatNN
find_zero(BisectionExact(), F, options, state, l)
return find_zero(F, x, BisectionExact(); tracks=tracks, verbose=verbose, kwargs...)
else
return find_zero(F, x, A42())
return find_zero(F, x, A42(); tracks=tracks, verbose=verbose, kwargs...)
end
else
find_zero(method, F, options, state, l)
end

find_zero(method, F, options, state, l)

verbose && show_trace(method, nothing, state, l)

state.xn1

end


Expand Down Expand Up @@ -371,18 +372,18 @@ function ipzero(a::T, b, c, d, fa, fb, fc, fd, delta=zero(T)) where {T}
c = a + (Q31 + Q32 + Q33)

(a + 2delta < c < b - 2delta) && return c

newton_quadratic(a,b,d,fa,fb,fd, 3, delta)

end

# return c in (a+delta, b-delta)
# adds part of `bracket` from paper with `delta`
function newton_quadratic(a::T, b, d, fa, fb, fd, k::Int, delta=zero(T)) where {T}

A = f_abd(a,b,d,fa,fb,fd)
r = isbracket(A,fa) ? b : a

# use quadratic step; if that fails, use secant step; if that fails, bisection
if !(isnan(A) || isinf(A)) || !iszero(A)
B = f_ab(a,b,fa,fb)
Expand All @@ -401,15 +402,15 @@ function newton_quadratic(a::T, b, d, fa, fb, fd, k::Int, delta=zero(T)) where {
r = secant_step(a, b, fa, fb)

if a + 2delta < r < b - 2delta
return r
return r
end

return _middle(a, b) # is in paper r + sgn * 2 * delta

end

# (todo: DRY up?)
function init_state(M::AbstractAlefeldPotraShi, f, xs)
function init_state(M::AbstractAlefeldPotraShi, f, xs)
u, v = promote(float(xs[1]), float(xs[2]))
if u > v
u, v = v, u
Expand All @@ -421,7 +422,7 @@ function init_state(M::AbstractAlefeldPotraShi, f, xs)
0, 2,
false, false, false, false,
"")

init_state!(state, M, f, (u,v), false)
state
end
Expand Down Expand Up @@ -544,8 +545,8 @@ function assess_convergence(method::AbstractAlefeldPotraShi, state::UnivariateZe
return true
end



return false
end

Expand All @@ -567,7 +568,7 @@ function update_state(M::A42, f, state::UnivariateZeroState{T,S}, options::Univa
μ, λ = 0.5, 0.7
tole = max(options.xabstol, max(abs(a),abs(b)) * options.xreltol) # paper uses 2|u|*rtol + atol
delta = λ * tole

if state.steps < 1
c = newton_quadratic(a, b, d, fa, fb, fd, 2)
else
Expand All @@ -576,7 +577,7 @@ function update_state(M::A42, f, state::UnivariateZeroState{T,S}, options::Univa
fc::S = f(c)
incfn(state)
check_zero(M, state, c, fc) && return nothing

ab::T, bb::T, db::T, fab::S, fbb::S, fdb::S = bracket(a,b,c,fa,fb,fc)
eb::T, feb::S = d, fd

Expand All @@ -586,21 +587,21 @@ function update_state(M::A42, f, state::UnivariateZeroState{T,S}, options::Univa
check_zero(M, state, cb, fcb) && return nothing

ab,bb,db,fab,fbb,fdb = bracket(ab,bb,cb,fab,fbb,fcb)


u::T, fu::S = choose_smallest(ab, bb, fab, fbb)
cb = u - 2 * fu * (bb - ab) / (fbb - fab)
ch::T = cb
if abs(cb - u) > 0.5 * (b-a)
if abs(cb - u) > 0.5 * (b-a)
ch = _middle(an, bn)
end
fch::S = f(cb)
incfn(state)
incfn(state)
check_zero(M, state, ch, fch) && return nothing

ah::T, bh::T, dh::T, fah::S, fbh::S, fdh::S = bracket(ab, bb, ch, fab, fbb, fch)

if bh - ah < μ * (b - a)
if bh - ah < μ * (b - a)
#a, b, d, fa, fb, fd = ahat, b, dhat, fahat, fb, fdhat # typo in paper
a, b, d, ee = ah, bh, dh, db
fa, fb, fd, fe = fah, fbh, fdh, fdb
Expand Down Expand Up @@ -633,39 +634,39 @@ struct AlefeldPotraShi <: AbstractAlefeldPotraShi end

# ## 3, maybe 4, functions calls per step
function update_state(M::AlefeldPotraShi, f, state::UnivariateZeroState{T,S}, options::UnivariateZeroOptions) where {T,S}

a::T,b::T,d::T = state.xn0, state.xn1, state.m[1]
fa::S,fb::S,fd::S = state.fxn0, state.fxn1, state.fm[1]

μ, λ = 0.5, 0.7
tole = max(options.xabstol, max(abs(a),abs(b)) * options.xreltol) # paper uses 2|u|*rtol + atol
delta = λ * tole

c::T = newton_quadratic(a, b, d, fa, fb, fd, 2, delta)
fc::S = f(c)
incfn(state)
check_zero(M, state, c, fc) && return nothing

a,b,d,fa,fb,fd = bracket(a,b,c,fa,fb,fc)

c = newton_quadratic(a,b,d,fa,fb,fd, 3, delta)
fc = f(c)
incfn(state)
incfn(state)
check_zero(M, state, c, fc) && return nothing

a, b, d, fa, fb, fd = bracket(a, b, c, fa, fb,fc)

u::T, fu::S = choose_smallest(a, b, fa, fb)
c = u - 2 * fu * (b - a) / (fb - fa)
if abs(c - u) > 0.5 * (b - a)
c = _middle(a, b)
c = _middle(a, b)
end
fc = f(c)
incfn(state)
incfn(state)
check_zero(M, state, c, fc) && return nothing

ahat::T, bhat::T, dhat::T, fahat::S, fbhat::S, fdhat::S = bracket(a, b, c, fa, fb, fc)
if bhat - ahat < μ * (b - a)
if bhat - ahat < μ * (b - a)
#a, b, d, fa, fb, fd = ahat, b, dhat, fahat, fb, fdhat # typo in paper
a, b, d, fa, fb, fd = ahat, bhat, dhat, fahat, fbhat, fdhat
else
Expand Down Expand Up @@ -701,18 +702,18 @@ function log_step(l::Tracks, M::Brent, state)
end

#
function init_state(M::Brent, f, xs)
function init_state(M::Brent, f, xs)
u, v = promote(float(xs[1]), float(xs[2]))
fu, fv = promote(f(u), f(v))
isbracket(fu, fv) || throw(ArgumentError(bracketing_error))

# brent store b as smaller of |fa|, |fb|
if abs(fu) > abs(fv)
a, b, fa, fb = u, v, fu, fv
else
a, b, fa, fb = v, u, fv, fu
end



state = UnivariateZeroState(b, a, [a, a], ## x1, x0, c, d
Expand All @@ -728,7 +729,7 @@ function init_state!(state::UnivariateZeroState{T,S}, ::Brent, f, xs::Union{Tupl
u::T, v::T = promote(float(xs[1]), float(xs[2]))
fu::S, fv::S = promote(f(u), f(v))
isbracket(fu, fv) || throw(ArgumentError(bracketing_error))

# brent store b as smaller of |fa|, |fb|
if abs(fu) > abs(fv)
a, b, fa, fb = u, v, fu, fv
Expand Down Expand Up @@ -769,7 +770,7 @@ function update_state(M::Brent, f, state::UnivariateZeroState{T,S}, options::Uni
end

tol = max(options.xabstol, max(abs(b), abs(c), abs(d)) * options.xreltol)
if !(u < s < v) ||
if !(u < s < v) ||
(mflag && abs(s - b) >= abs(b-c)/2) ||
(!mflag && abs(s - b) >= abs(b-c)/2) ||
(mflag && abs(b-c) <= tol) ||
Expand All @@ -791,15 +792,15 @@ function update_state(M::Brent, f, state::UnivariateZeroState{T,S}, options::Uni
else
a, fa = s, fs
end

if abs(fa) < abs(fb)
a, b, fa, fb = b, a, fb, fa
end

state.xn0, state.xn1, state.m[1], state.m[2] = a, b, c, d
state.fxn0, state.fxn1, state.fm[1] = fa, fb, fc
state.fm[2] = mflag ? one(fa) : -one(fa)

return nothing
end

Expand Down Expand Up @@ -851,7 +852,7 @@ function update_state(method::FalsePosition, fs, o::UnivariateZeroState{T,S}, op
lambda = 1/2
end

x::T = b - lambda * (b-a)
x::T = b - lambda * (b-a)
fx::S = fs(x)
incfn(o)

Expand All @@ -865,13 +866,13 @@ function update_state(method::FalsePosition, fs, o::UnivariateZeroState{T,S}, op
if sign(fx)*sign(fb) < 0
a, fa = b, fb
else
fa = galdino_reduction(method, fa, fb, fx)
fa = galdino_reduction(method, fa, fb, fx)
end
b, fb = x, fx

o.xn0, o.xn1 = a, b
o.xn0, o.xn1 = a, b
o.fxn0, o.fxn1 = fa, fb

nothing
end

Expand All @@ -888,7 +889,7 @@ galdino = Dict{Union{Int,Symbol},Function}(:1 => (fa, fb, fx) -> fa*fb/(fb+fx),
:9 => (fa, fb, fx) -> fa/(1 + fx/fb)^2,
:10 => (fa, fb, fx) -> (fa-fx)/4,
:11 => (fa, fb, fx) -> fx*fa/(fb+fx),
:12 => (fa, fb, fx) -> (fa * (1-fx/fb > 0 ? 1-fx/fb : 1/2))
:12 => (fa, fb, fx) -> (fa * (1-fx/fb > 0 ? 1-fx/fb : 1/2))
)


Expand Down
14 changes: 7 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ isissue(x) = iszero(x) || isnan(x) || isinf(x)


## find a default secant step
function _default_secant_step(x1)
h = eps(one(real(x1)))^(1/3)
function _default_secant_step(x1)
h = eps(one(real(x1)))^(1//3)
dx = h*oneunit(x1) + abs(x1)*h^2 # adjust for if eps(x1) > h
x0 = x1 + dx
x0
Expand All @@ -39,10 +39,10 @@ function steff_step(x, fx)

xbar, fxbar = real(x/oneunit(x)), fx/oneunit(fx)
thresh = max(1, abs(xbar)) * sqrt(eps(one(xbar))) #^(1/2) # max(1, sqrt(abs(x/fx))) * 1e-6
out = abs(fxbar) <= thresh ? fxbar : sign(fx) * thresh

out = abs(fxbar) <= thresh ? fxbar : sign(fx) * thresh
out * oneunit(x)

end

function guarded_secant_step(alpha, beta, falpha, fbeta)
Expand Down Expand Up @@ -104,10 +104,10 @@ function _fbracket_diff(a,b,c, fa, fb, fc)
x1, issue = _fbracket(b, c, fb, fc)
issue && return x1, issue
x2, issue = _fbracket(a, b, fa, fb)
issue && return x2, issue
issue && return x2, issue
x3, issue = _fbracket(a, c, fa, fc)
issue && return x3, issue

out = x1 - x2 + x3
out, isissue(out)
end
Expand Down

0 comments on commit 5f6bd88

Please sign in to comment.