Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speed up pow by 4x or more #210

Closed
wants to merge 11 commits into from
4 changes: 3 additions & 1 deletion src/IntervalArithmetic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ using SetRounding

using Markdown


using LinearAlgebra
import LinearAlgebra: ×, dot, norm
export ×, dot
Expand All @@ -36,7 +37,8 @@ import Base:
abs, abs2,
show,
isinteger, setdiff,
parse, hash
parse, hash,
power_by_squaring

import Base: # for IntervalBox
broadcast, length,
Expand Down
85 changes: 52 additions & 33 deletions src/intervals/functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
# Use the BigFloat version from MPFR instead, which is correctly-rounded:

# Write explicitly like this to avoid ambiguity warnings:

for T in (:Integer, :Rational, :Float64, :BigFloat, :Interval)
@eval ^(a::Interval{Float64}, x::$T) = atomic(Interval{Float64}, big53(a)^x)
@eval pow(a::Interval{Float64}, x::$T) = atomic(Interval{Float64}, big53(a)^x)
end


Expand All @@ -16,10 +17,16 @@ end
# overwrite new behaviour for small integer powers from
# https://github.com/JuliaLang/julia/pull/24240:

Base.literal_pow(::typeof(^), x::Interval{T}, ::Val{p}) where {T,p} = x^p
Base.literal_pow(::typeof(^), x::Interval{T}, ::Val{p}) where {T,p} = ^(x, p)


"""
pow(x::Interval, y)

function ^(a::Interval{BigFloat}, n::Integer)
Slow, correctly-rounded calculation of `x^y`.
This uses `BigFloat`s internally.
"""
function pow(a::Interval{BigFloat}, n::Integer)
isempty(a) && return a
n == 0 && return one(a)
n == 1 && return a
Expand Down Expand Up @@ -69,23 +76,12 @@ function ^(a::Interval{BigFloat}, n::Integer)
end
end

function sqr(a::Interval{T}) where T<:Real
return a^2
# isempty(a) && return a
# if a.lo ≥ zero(T)
# return @round(a.lo^2, a.hi^2)
#
# elseif a.hi ≤ zero(T)
# return @round(a.hi^2, a.lo^2)
# end
#
# return @round(mig(a)^2, mag(a)^2)
end
sqr(a::Interval{T}) where {T<:Real} = a^2

^(a::Interval{BigFloat}, x::AbstractFloat) = a^big(x)
pow(a::Interval{BigFloat}, x::AbstractFloat) = a^big(x)

# Floating-point power of a BigFloat interval:
function ^(a::Interval{BigFloat}, x::BigFloat)
function pow(a::Interval{BigFloat}, x::BigFloat)

domain = Interval{BigFloat}(0, Inf)

Expand Down Expand Up @@ -135,14 +131,14 @@ function ^(a::Interval{BigFloat}, x::BigFloat)
return hull(lo, hi)
end

function ^(a::Interval{Rational{T}}, x::AbstractFloat) where T<:Integer
function pow(a::Interval{Rational{T}}, x::AbstractFloat) where T<:Integer
a = Interval(a.lo.num/a.lo.den, a.hi.num/a.hi.den)
a = a^x
atomic(Interval{Rational{T}}, a)
end

# Rational power
function ^(a::Interval{BigFloat}, r::Rational{S}) where S<:Integer
function pow(a::Interval{BigFloat}, r::Rational{S}) where S<:Integer
T = BigFloat
domain = Interval{T}(0, Inf)

Expand All @@ -164,7 +160,7 @@ function ^(a::Interval{BigFloat}, r::Rational{S}) where S<:Integer
end

# Interval power of an interval:
function ^(a::Interval{BigFloat}, x::Interval)
function pow(a::Interval{BigFloat}, x::Interval)
T = BigFloat
domain = Interval{T}(0, Inf)

Expand Down Expand Up @@ -194,36 +190,59 @@ function sqrt(a::Interval{T}) where T
end

"""
pow(x::Interval, n::Integer)
^(x::Interval, n::Integer)

A faster implementation of `x^n`, currently using `power_by_squaring`.
`pow(x, n)` will usually return an interval that is slightly larger than that calculated by `x^n`, but is guaranteed to be a correct
A fast implementation of `x^n`, using `power_by_squaring`.
`^(x, n)` will usually return an interval that is slightly larger than that calculated by `pow(x, n)`, but is guaranteed to be a correct
enclosure when using multiplication with correct rounding.
"""
function pow(x::Interval, n::Integer) # fast integer power
function ^(x::Interval, n::Integer) # fast integer power

isempty(x) && return x

if iseven(n) && 0 ∈ x
if n < 0
return inv(x^(-n))
end

if iseven(n)
if 0 ∈ x

return Interval(zero(eltype(x)),
power_by_squaring(mag(x), n, RoundUp))

elseif x.lo > 0
return Interval(power_by_squaring(x.lo, n, RoundDown),
power_by_squaring(x.hi, n, RoundUp))

else # x.lo < x.hi < 0
return Interval(power_by_squaring(-x.hi, n, RoundDown),
power_by_squaring(-x.lo, n, RoundUp))
end

return hull(zero(x),
hull(Base.power_by_squaring(Interval(mig(x)), n),
Base.power_by_squaring(Interval(mag(x)), n))
)
else # odd n

else
a = power_by_squaring(x.lo, n, RoundDown)
b = power_by_squaring(x.hi, n, RoundUp)

return hull( Base.power_by_squaring(Interval(x.lo), n),
Base.power_by_squaring(Interval(x.hi), n) )
return Interval(a, b)

end
#
# else # completely negative interval
# a = power_by_squaring(x.lo, n, RoundDown)
# b = power_by_squaring(x.hi, n, RoundUp)
#
# return Interval(a, b)
#end

end

function pow(x::Interval, y::Real) # fast real power, including for y an Interval
function ^(x::Interval, y::Real) # fast real power, including for y an Interval

isempty(x) && return x

isinteger(y) && return x^(convert(Integer, y))

return exp(y * log(x))

end
Expand Down
3 changes: 2 additions & 1 deletion src/intervals/intervals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Interval(a::T, b::S) where {T<:Real, S<:Real} = Interval(promote(a,b)...)
Interval(a::T, b::T) where T<:Integer = Interval(float(a), float(b))
Interval(a::T, b::T) where T<:Irrational = Interval(float(a), float(b))

eltype(x::Interval{T}) where T<:Real = T
eltype(x::Interval{T}) where T<:Real = typeof(x)

Interval(x::Interval) = x
Interval(x::Complex) = Interval(real(x)) + im*Interval(imag(x))
Expand Down Expand Up @@ -125,6 +125,7 @@ include("conversion.jl")
include("precision.jl")
include("set_operations.jl")
include("arithmetic.jl")
include("powers.jl")
include("functions.jl")
include("trigonometric.jl")
include("hyperbolic.jl")
Expand Down
52 changes: 52 additions & 0 deletions src/intervals/powers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# power_by_squaring adapted from Base Julia

function power_by_squaring(x::AbstractFloat, p::Integer, r::RoundingMode)

if p == 1
return x
elseif p == 0
return one(x)
elseif p == 2
return *(x, x, r) # multiplication with directed rounding
elseif p < 0
isone(x) && return copy(x)
isone(-x) && return iseven(p) ? one(x) : copy(x)
Base.throw_domerr_powbysq(x, p)
end
t = trailing_zeros(p) + 1
p >>= t
while (t -= 1) > 0
x = *(x, x, r)
end
y = x
while p > 0
t = trailing_zeros(p) + 1
p >>= t
while (t -= 1) >= 0
x = *(x, x, r)
end
y = *(y, x, r)
end
return y
end


function fast_sqrt(x::AbstractFloat, ::RoundingMode{:Down})
y = sqrt(x)

while *(y, y, RoundUp) > x
y = prevfloat(y)
end

return y
end

function fast_sqrt(x::AbstractFloat, ::RoundingMode{:Up})
y = sqrt(x)

while *(y, y, RoundDown) > x
y = nextfloat(y)
end

return y
end