Skip to content

Commit

Permalink
Merge eca741e into a992eff
Browse files Browse the repository at this point in the history
  • Loading branch information
djsegal committed Jun 19, 2018
2 parents a992eff + eca741e commit 313e971
Showing 1 changed file with 242 additions and 27 deletions.
269 changes: 242 additions & 27 deletions src/bracketing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -718,39 +718,254 @@ Called by `fzeros` or `Roots.find_zeros`.
"""
function find_zeros(f, a::Real, b::Real, args...;
no_pts::Int=100,
abstol::Real=10*eps(), reltol::Real=10*eps(), ## should be abstol, reltol as used.
no_pts::Int=101,
abstol::Real=10*eps(), reltol::Real=10*eps(), ## should be abstol, reltol as used.
kwargs...)

a, b = a < b ? (a,b) : (b,a)
rts = eltype(promote(float(a),b))[]
xs = vcat(a, a .+ (b-a) .* sort(rand(no_pts)), b)
cur_range = get_real_finite_range(f, a, b, no_pts)

root_list = Real[]

## Look in [ai, bi)
for i in 1:(no_pts+1)
ai,bi=xs[i:i+1]
if isapprox(f(ai), 0.0, rtol=reltol, atol=abstol)
push!(rts, ai)
elseif sign(f(ai)) * sign(f(bi)) < 0
push!(rts, find_zero(f, [ai, bi], Bisection()))
else
try
x = find_zero(f, ai + (0.5)* (bi-ai), Order8(); maxevals=10, abstol=abstol, reltol=reltol)
if ai < x < bi
push!(rts, x)
end
catch e
end
isempty(cur_range) && return root_list

if a != first(cur_range)
sub_no_pts = Int(floor( no_pts / 4 )) + 1
cur_sub_range = collect(linspace(a, first(cur_range), sub_no_pts))

append!(
root_list,
find_bisection_roots(f, cur_sub_range, abstol, reltol)
)
end

if b != last(cur_range)
sub_no_pts = Int(floor( no_pts / 4 )) + 1
cur_sub_range = collect(linspace(last(cur_range), b, sub_no_pts))

append!(
root_list,
find_bisection_roots(f, cur_sub_range, abstol, reltol)
)
end

append!(
root_list,
_find_recursive_zeros(
f, cur_range, args...;
no_pts=no_pts, abstol=abstol, reltol=reltol, kwargs...
)
)

# redo if it appears function oscillates alot in this interval...
if length(root_list) > Int(ceil( (1/4) * no_pts ))
return find_zeros(
f, a, b, args...;
no_pts = 10*no_pts, abstol=abstol, reltol=reltol, kwargs...
)
end

sort!(root_list)

root_list
end

function get_real_finite_range(f, a::Real, b::Real, no_pts::Int)
f_a, f_b = f(a), f(b)

work_range = collect(linspace(a,b,no_pts))

f_a_next, f_b_prev = f(work_range[2]), f(work_range[end-1])

function is_valid_f(cur_f)
tmp_f = float(cur_f)
isreal(tmp_f) && !isinf(tmp_f)
end

is_valid_a = is_valid_f(f_a) && is_valid_f(f_a_next)
is_valid_b = is_valid_f(f_b) && is_valid_f(f_b_prev)

is_valid_a && is_valid_b && return work_range
is_valid_a || is_valid_b || return []

fixed_value = is_valid_a ? a : b
wrong_value = is_valid_a ? b : a

float_value = fixed_value
stash_value = fixed_value

attempt_count = 15

for cur_attempt in 1:attempt_count
float_value += wrong_value
float_value /= 2

is_valid_f(f(float_value)) || break
stash_value = float_value
end

cur_diff = ( float_value - stash_value )
cur_diff /= ( attempt_count + 1 )

found_value = stash_value
float_value = stash_value

for cur_attempt in 1:attempt_count
float_value += cur_diff

is_valid_f(f(float_value)) || break
found_value = float_value
end

cur_range = collect(linspace(fixed_value, found_value, no_pts))

( fixed_value > found_value ) && reverse!(cur_range)

cur_range
end

function _find_recursive_zeros(f, cur_range::AbstractVector{T}, args...;
cur_depth::Int=1,
no_pts::Int=101,
abstol::Real=10*eps(), reltol::Real=10*eps(), ## should be abstol, reltol as used.
kwargs...) where T <: Real

root_list = find_root_list(f, cur_range, abstol, reltol)

isempty(root_list) && return root_list

sub_no_pts = Int(ceil( no_pts / (length(root_list)/2 * 2^(cur_depth-1)) )) + 1

cur_intervals = zip(
vcat(first(cur_range),root_list),
vcat(root_list,last(cur_range))
)

for (cur_a, cur_b) in cur_intervals
isapprox(cur_a, cur_b, rtol=reltol, atol=abstol) && continue
cur_sub_range = collect(linspace(cur_a, cur_b, sub_no_pts))[2:end-1]

cur_roots = _find_recursive_zeros(
f, cur_sub_range, args...;
no_pts=sub_no_pts, abstol=abstol, reltol=reltol,
cur_depth=(cur_depth+1), kwargs...
)

append!(root_list, cur_roots)
end

root_list
end

function find_root_list(f, cur_range::AbstractVector{T}, abstol::Real, reltol::Real) where T <: Real
isempty(cur_range) && return Real[]

cur_roots = find_bisection_roots(f, cur_range, abstol, reltol)
isempty(cur_roots) || return cur_roots

cur_roots = find_order_roots(f, cur_range, abstol, reltol)

cur_roots
end

function find_bisection_roots(f, cur_range::AbstractVector{T}, abstol::Real, reltol::Real) where T <: Real
no_pts = length(cur_range)

cur_roots = Real[]

function work_f(x)
cur_f = float(f(x))
isreal(cur_f) && return cur_f

real_f, imag_f = real(cur_f), imag(cur_f)
cur_ratio = abs( real_f / imag_f )

( cur_ratio > 1e6 ) && return real_f
( abs(real_f) < 1e-3 && abs(imag_f) < 1e-3 ) && return real_f

cur_f
end

cur_values = map(
cur_f -> isapprox(cur_f, 0.0, rtol=reltol, atol=abstol) ? 0.0 : cur_f,
map(float, work_f.(cur_range))
)

cur_signs = map(sign, cur_values[1:end-1] .* cur_values[2:end])

for (cur_index, cur_sign) in enumerate(cur_signs)
imag(cur_sign) < 1e-6 || continue
isinf(cur_sign) && continue

cur_a, cur_b = cur_range[cur_index:cur_index+1]

if iszero(cur_sign)
iszero(cur_values[cur_index]) && push!(cur_roots, cur_a)

( cur_index == no_pts - 1 ) &&
iszero(cur_values[cur_index+1]) && push!(cur_roots, cur_b)

continue
end

( real(cur_sign) > 0 ) && continue

tmp_roots = _find_recursive_bisection_roots(work_f, cur_a, cur_b)
append!(cur_roots, tmp_roots)
end
## finally, b?
isapprox(f(b), 0.0, rtol=reltol, atol=abstol) && push!(rts, b)

## redo if it appears function oscillates alot in this interval...
if length(rts) > (1/4) * no_pts
return(find_zeros(f, a, b, args...; no_pts = 10*no_pts, abstol=abstol, reltol=reltol, kwargs...))
else
return(sort(rts))
cur_roots
end

function _find_recursive_bisection_roots(f, a::Real, b::Real; cur_depth::Int=1)
cur_end_points = [a, b]

try
cur_root = find_zero(f, cur_end_points, Bisection())
return [cur_root]
catch cur_error
isa(cur_error, ArgumentError) && return []
isa(cur_error, MethodError) || rethrow(cur_error)
( cur_depth > 8 ) && return []
end

cur_roots = []
cur_average = mean(cur_end_points)

append!(cur_roots, _find_recursive_bisection_roots(f, a, cur_average, cur_depth=cur_depth+1))
append!(cur_roots, _find_recursive_bisection_roots(f, cur_average, b, cur_depth=cur_depth+1))

cur_roots
end

function find_order_roots(f, cur_range::AbstractVector{T}, abstol::Real, reltol::Real) where T <: Real
cur_roots = Real[]

min_value = first(cur_range)
max_value = last(cur_range)

cur_guesses = view( cur_range , (2:2:length(cur_range)-1) )

for cur_guess in cur_guesses
tmp_root = NaN

try
tmp_root = find_zero(
f, cur_guess, Order8();
maxevals=10, abstol=abstol, reltol=reltol
)
catch cur_error
continue
end

( min_value < tmp_root < max_value ) || continue

any(
work_root -> isapprox(tmp_root, work_root, rtol=reltol, atol=abstol),
cur_roots
) && continue

push!(cur_roots, tmp_root)
end

cur_roots
end

0 comments on commit 313e971

Please sign in to comment.