Skip to content

Commit

Permalink
fix hypot with multiple arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Jorge Fernandez-de-Cossio-Diaz committed Jan 14, 2020
1 parent 756891d commit d3addb5
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 41 deletions.
99 changes: 65 additions & 34 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -610,56 +610,66 @@ julia> hypot(3, 4im)
5.0
```
"""
hypot(x::Number, y::Number) = hypot(promote(x, y)...)
hypot(x::Complex, y::Complex) = hypot(abs(x), abs(y))
hypot(x::T, y::T) where {T<:Real} = hypot(float(x), float(y))
hypot(x::T, y::T) where {T<:Number} = (z = y/x; abs(x) * sqrt(one(z) + z*z))
function hypot(x::T, y::T) where T<:AbstractFloat
#Return Inf if either or both imputs is Inf (Compliance with IEEE754)
if isinf(x) || isinf(y)
return T(Inf)
hypot(x::Number) = abs(float(x))
hypot(x::Number, xs::Number...) = hypot(promote(x, xs...)...)
function hypot(x::T, y::T) where T<:Number
# preserves unit
axu = abs(float(x))
ayu = abs(float(y))

# unitless
ax = axu / oneunit(axu)
ay = ayu / oneunit(ayu)


#ax = abs(float(x / oneunit(x)))
#ay = abs(float(y / oneunit(y)))

# Return Inf if either or both imputs is Inf (Compliance with IEEE754)
if isinf(ax) || isinf(ay)
return oftype(axu, Inf)
end

# Order the operands
ax,ay = abs(x), abs(y)
if ay > ax
ax,ay = ay,ax
axu, ayu = ayu, axu
ax, ay = ay, ax
end

# Widely varying operands
if ay <= ax*sqrt(eps(T)/2) #Note: This also gets ay == 0
return ax
if ay ax * sqrt(eps(typeof(ax)) / 2) # Note: This also gets ay == 0
return axu
end

# Operands do not vary widely
scale = eps(sqrt(floatmin(T))) #Rescaling constant
if ax > sqrt(floatmax(T)/2)
ax = ax*scale
ay = ay*scale
scale = eps(sqrt(floatmin(ax))) # rescaling constant
if ax > sqrt(floatmax(ax) / 2)
ax = ax * scale
ay = ay * scale
scale = inv(scale)
elseif ay < sqrt(floatmin(T))
ax = ax/scale
ay = ay/scale
elseif ay < sqrt(floatmin(ax))
ax = ax / scale
ay = ay / scale
else
scale = one(scale)
scale = oneunit(scale)
end
h = sqrt(muladd(ax,ax,ay*ay))
# This branch is correctly rounded but requires a native hardware fma.
if Base.Math.FMA_NATIVE
hsquared = h*h
axsquared = ax*ax
h -= (fma(-ay,ay,hsquared-axsquared) + fma(h,h,-hsquared) - fma(ax,ax,-axsquared))/(2*h)
# This branch is within one ulp of correctly rounded.
else
if h <= 2*ay
delta = h-ay
h -= muladd(delta,delta-2*(ax-ay),ax*(2*delta - ax))/(2*h)

h = sqrt(muladd(ax, ax, ay * ay))

if Base.Math.FMA_NATIVE # This branch is correctly rounded but requires a native hardware fma.
hsquared = h * h
axsquared = ax * ax
h -= (fma(-ay, ay, hsquared - axsquared) + fma(h, h, -hsquared) - fma(ax, ax, -axsquared))/(2*h)
else # This branch is within one ulp of correctly rounded.
if h 2ay
delta = h - ay
h -= muladd(delta, delta - 2 * (ax - ay), ax * (2delta - ax)) / (2h)
else
delta = h-ax
h -= muladd(delta,delta,muladd(ay,(4*delta-ay),2*delta*(ax-2*ay)))/(2*h)
h -= muladd(delta, delta, muladd(ay, (4delta - ay), 2delta * (ax - 2ay))) / (2h)
end
end
return h*scale
return h * scale * oneunit(axu)
end

"""
Expand All @@ -676,7 +686,28 @@ julia> hypot(3, 4im, 12.0)
13.0
```
"""
hypot(x::Number...) = sqrt(sum(abs2(y) for y in x))
function hypot(x1::T, x2::T, x3::T, xs::T...) where {T<:Number}
v = float.((x1, x2, x3, xs...))
# speculatively try fast naive approach
s = sum(abs2, v)
if isnan(s)
return any(isinf, v) ? oftype(sqrt(s), Inf) : sqrt(s) # IEEE 754
elseif isinf(s) || s oneunit(s) * _floatmin(one(s)) * (length(v) - 1)
# if overflow/underflow, try normalization
ma = maximum(abs, v)
if iszero(ma) || isinf(ma)
return ma
else
return ma * sqrt(sum(y -> abs2(y / ma), v))
end
else
return sqrt(s)
end
end

_floatmin(x::AbstractFloat) = _floatmin(typeof(x))
_floatmin(::Type{T}) where {T<:AbstractFloat} = nextfloat(zero(T)) / eps(T)
_floatmin(::Type{T}) where {T<:IEEEFloat} = floatmin(T)

atan(y::Real, x::Real) = atan(promote(float(y),float(x))...)
atan(y::T, x::T) where {T<:AbstractFloat} = Base.no_op_err("atan", T)
Expand Down
70 changes: 64 additions & 6 deletions test/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,6 @@ end
@test hypot(T(0), T(0)) === T(0)
@test hypot(T(Inf), T(Inf)) === T(Inf)
@test hypot(T(Inf), T(x)) === T(Inf)
@test hypot(T(Inf), T(NaN)) === T(Inf)
@test isnan_type(T, hypot(T(x), T(NaN)))
end
end
Expand Down Expand Up @@ -1032,8 +1031,67 @@ end

isdefined(Main, :Furlongs) || @eval Main include("testhelpers/Furlongs.jl")
using .Main.Furlongs
@test hypot(Furlong(0), Furlong(0)) == Furlong(0.0)
@test hypot(Furlong(3), Furlong(4)) == Furlong(5.0)
@test hypot(Furlong(NaN), Furlong(Inf)) == Furlong(Inf)
@test hypot(Furlong(Inf), Furlong(NaN)) == Furlong(Inf)
@test hypot(Furlong(Inf), Furlong(Inf)) == Furlong(Inf)
@test (@inferred hypot(Furlong(0), Furlong(0))) == Furlong(0.0)
@test (@inferred hypot(Furlong(3), Furlong(4))) == Furlong(5.0)
@test (@inferred hypot(Furlong(NaN), Furlong(Inf))) == Furlong(Inf)
@test (@inferred hypot(Furlong(Inf), Furlong(NaN))) == Furlong(Inf)
@test (@inferred hypot(Furlong(0), Furlong(0), Furlong(0))) == Furlong(0.0)
@test (@inferred hypot(Furlong(Inf), Furlong(Inf))) == Furlong(Inf)
@test (@inferred hypot(Furlong(1), Furlong(1), Furlong(1))) == Furlong(sqrt(3))
@test (@inferred hypot(Furlong(Inf), Furlong(NaN), Furlong(0))) == Furlong(Inf)
@test (@inferred hypot(Furlong(Inf), Furlong(Inf), Furlong(Inf))) == Furlong(Inf)
@test isnan(hypot(Furlong(NaN), Furlong(0), Furlong(1)))

@testset "hypot" begin
@test_throws MethodError hypot()
@test (@inferred hypot(floatmax())) == floatmax()
@test (@inferred hypot(floatmax(), floatmax())) == Inf
@test (@inferred hypot(floatmin(), floatmin())) == 2floatmin()
@test (@inferred hypot(floatmin(), floatmin(), floatmin())) == 3floatmin()
@test (@inferred hypot(1e-162)) 1e-162
@test (@inferred hypot(2e-162, 1e-162, 1e-162)) hypot(2,1,1)*1e-162
@test (@inferred hypot(1e162)) 1e162
@test hypot(-2) === 2.0
@test hypot(-2, 0) === 2.0
let i = typemax(Int)
@test (@inferred hypot(i, i)) i * 2
@test (@inferred hypot(i, i, i)) i * 3
@test (@inferred hypot(i, i, i, i)) 2.0i
@test (@inferred hypot(i//1, 1//i, 1//i)) i
end
let i = typemin(Int)
@test (@inferred hypot(i, i)) -√2i
@test (@inferred hypot(i, i, i)) -√3i
@test (@inferred hypot(i, i, i, i)) -2.0i
end
@testset "$T" for T in (Float32, Float64)
@test (@inferred hypot(T(Inf), T(NaN))) == T(Inf) # IEEE754 says so
@test (@inferred hypot(T(Inf), T(3//2), T(NaN))) == T(Inf)
@test (@inferred hypot(T(1e10), T(1e10), T(1e10), T(1e10))) 2e10
@test isnan_type(T, hypot(T(3), T(3//4), T(NaN)))
@test hypot(T(1), T(0)) === T(1)
@test hypot(T(1), T(0), T(0)) === T(1)
@test (@inferred hypot(T(Inf), T(Inf), T(Inf))) == T(Inf)
for s in (zero(T), floatmin(T)*1e3, floatmax(T)*1e-3, T(Inf))
@test hypot(1s, 2s) s * hypot(1,2) rtol=8eps(T)
@test hypot(1s, 2s, 3s) s * hypot(1,2,3) rtol=8eps(T)
end
end
@testset "$T" for T in (Float16, Float32, Float64, BigFloat)
let x = 1.1sqrt(floatmin(T))
@test (@inferred hypot(x,x/4)) x * sqrt(17/BigFloat(16))
@test (@inferred hypot(x,x/4,x/4)) x * sqrt(9/BigFloat(8))
end
let x = 2sqrt(nextfloat(zero(T)))
@test (@inferred hypot(x,x/4)) x * sqrt(17/BigFloat(16))
@test (@inferred hypot(x,x/4,x/4)) x * sqrt(9/BigFloat(8))
end
let x = sqrt(nextfloat(zero(T))/eps(T))/8, f = sqrt(4eps(T))
@test hypot(x,x*f) x * hypot(one(f),f) rtol=eps(T)
@test hypot(x,x*f,x*f) x * hypot(one(f),f,f) rtol=eps(T)
end
end
# hypot on Complex returns Real
@test (@inferred hypot(3, 4im)) === 5.0
@test (@inferred hypot(3, 4im, 12)) == 13.0
end
9 changes: 8 additions & 1 deletion test/testhelpers/Furlongs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ Base.oneunit(x::Type{Furlong{p,T}}) where {p,T} = Furlong{p,T}(one(T))
Base.zero(x::Furlong{p,T}) where {p,T} = Furlong{p,T}(zero(T))
Base.zero(::Type{Furlong{p,T}}) where {p,T} = Furlong{p,T}(zero(T))
Base.iszero(x::Furlong) = iszero(x.val)
Base.float(x::Furlong{p}) where {p} = Furlong{p}(float(x.val))
Base.eps(::Type{Furlong{p,T}}) where {p,T<:AbstractFloat} = Furlong{p}(eps(T))
Base.eps(::Furlong{p,T}) where {p,T<:AbstractFloat} = eps(Furlong{p,T})
Base.floatmin(::Type{Furlong{p,T}}) where {p,T<:AbstractFloat} = Furlong{p}(floatmin(T))
Base.floatmin(::Furlong{p,T}) where {p,T<:AbstractFloat} = floatmin(Furlong{p,T})
Base.floatmax(::Type{Furlong{p,T}}) where {p,T<:AbstractFloat} = Furlong{p}(floatmax(T))
Base.floatmax(::Furlong{p,T}) where {p,T<:AbstractFloat} = floatmax(Furlong{p,T})

# convert Furlong exponent p to a canonical form. This
# is not type stable, but it doesn't matter since it is used
Expand All @@ -46,7 +53,7 @@ for f in (:real,:imag,:complex,:+,:-)
end

import Base: +, -, ==, !=, <, <=, isless, isequal, *, /, //, div, rem, mod, ^, hypot
for op in (:+, :-, :hypot)
for op in (:+, :-)
@eval function $op(x::Furlong{p}, y::Furlong{p}) where {p}
v = $op(x.val, y.val)
Furlong{p}(v)
Expand Down

0 comments on commit d3addb5

Please sign in to comment.