-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve the type stability of sqrt(::Complex)
#54869
base: master
Are you sure you want to change the base?
Conversation
Would this break a hypothetical user-defined floating-point type with a |
To ensure type stability while keeping correctness for generic code, perhaps it'd make sense to do something like k = zero(exponent(x)) Or maybe also promote with |
That's a good point, I didn't consider cases where Also, a number julia> big(2)^typemax(Int)
ERROR: OutOfMemoryError()
Stacktrace:
[1] pow_ui!(x::BigInt, a::BigInt, b::UInt64)
@ Base.GMP.MPZ ./gmp.jl:180
[2] pow_ui
@ ./gmp.jl:181 [inlined]
[3] ^
@ ./gmp.jl:626 [inlined]
[4] bigint_pow(x::BigInt, y::Int64)
@ Base.GMP ./gmp.jl:647
[5] ^(x::BigInt, y::Int64)
@ Base.GMP ./gmp.jl:652
[6] top-level scope
@ REPL[32]:1 In your hypothetical example where I suppose changing the initialization of k = m==0 ? k : convert(typeof(k), exponent(m)) |
In the latest commit I changed the code to convert the output of Here's an example of a custom number type that outputs struct MyBigFloat <: AbstractFloat
x::BigFloat
end
Base.exponent(x::MyBigFloat) = big(exponent(x.x))
Base.promote_rule(::Type{MyBigFloat}, ::Type{<:Real}) = MyBigFloat
Base.:<(x::MyBigFloat, y::MyBigFloat) = (x.x < y.x)
Base.:-(x::MyBigFloat) = MyBigFloat(-x.x)
Base.:+(x::MyBigFloat, y::MyBigFloat) = MyBigFloat(x.x + y.x)
Base.:-(x::MyBigFloat, y::MyBigFloat) = MyBigFloat(x.x - y.x)
Base.:*(x::MyBigFloat, y::MyBigFloat) = MyBigFloat(x.x * y.x)
Base.:/(x::MyBigFloat, y::MyBigFloat) = MyBigFloat(x.x / y.x)
Base.nextfloat(x::MyBigFloat) = MyBigFloat(nextfloat(x.x))
Base.eps(type::Type{MyBigFloat}) = MyBigFloat(eps(BigFloat))
Base.ldexp(x::MyBigFloat, y::Int64) = MyBigFloat(ldexp(x.x, y))
Base.sqrt(x::MyBigFloat) = MyBigFloat(sqrt(x.x)) and here is the output of `@code_warntype Base.ssqs(MyBigFloat(1.0), MyBigFloat(2.0))`: latest commit of this PR
julia> @code_warntype Base.ssqs(MyBigFloat(1.0), MyBigFloat(2.0))
MethodInstance for Base.ssqs(::MyBigFloat, ::MyBigFloat)
from ssqs(x::T, y::T) where T<:Real @ Base complex.jl:516
Static Parameters
T = MyBigFloat
Arguments
#self#::Core.Const(Base.ssqs)
x::MyBigFloat
y::MyBigFloat
Locals
yk::MyBigFloat
xk::MyBigFloat
m::MyBigFloat
ρ::MyBigFloat
k::Int64
@_9::MyBigFloat
@_10::Bool
@_11::Bool
@_12::Bool
@_13::Int64
Body::Tuple{MyBigFloat, Int64}
1 ── Core.NewvarNode(:(yk))
│ Core.NewvarNode(:(xk))
│ Core.NewvarNode(:(m))
│ (k = 0)
│ %5 = Base.:+::Core.Const(+)
│ %6 = (x * x)::MyBigFloat
│ %7 = (y * y)::MyBigFloat
│ (ρ = (%5)(%6, %7))
│ %9 = Base.:!::Core.Const(!)
│ %10 = ρ::MyBigFloat
│ %11 = Base.isfinite(%10)::Bool
│ %12 = (%9)(%11)::Bool
└─── goto #7 if not %12
2 ── %14 = Base.isinf(x)::Bool
└─── goto #4 if not %14
3 ── (@_10 = %14)
└─── goto #5
4 ── (@_10 = Base.isinf(y))
5 ┄─ %19 = @_10::Bool
└─── goto #7 if not %19
6 ── %21 = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│ (ρ = Base.convert(%21, Base.Inf))
└─── goto #25
7 ┄─ %24 = ρ::MyBigFloat
│ %25 = Base.isinf(%24)::Bool
└─── goto #9 if not %25
8 ── goto #18
9 ── %28 = ρ::MyBigFloat
│ %29 = (%28 == 0)::Bool
└─── goto #14 if not %29
10 ─ %31 = (x != 0)::Bool
└─── goto #12 if not %31
11 ─ (@_12 = %31)
└─── goto #13
12 ─ (@_12 = y != 0)
13 ┄ %36 = @_12::Bool
│ (@_11 = %36)
└─── goto #15
14 ─ (@_11 = false)
15 ┄ %40 = @_11::Bool
└─── goto #17 if not %40
16 ─ goto #18
17 ─ %43 = Base.:<::Core.Const(<)
│ %44 = ρ::MyBigFloat
│ %45 = Base.:/::Core.Const(/)
│ %46 = Base.nextfloat::Core.Const(nextfloat)
│ %47 = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│ %48 = Base.zero(%47)::MyBigFloat
│ %49 = (%46)(%48)::MyBigFloat
│ %50 = Base.:*::Core.Const(*)
│ %51 = Base.:^::Core.Const(^)
│ %52 = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│ %53 = Base.eps(%52)::MyBigFloat
│ %54 = Core.apply_type(Base.Val, 2)::Core.Const(Val{2})
│ %55 = (%54)()::Core.Const(Val{2}())
│ %56 = Base.literal_pow(%51, %53, %55)::MyBigFloat
│ %57 = (%50)(2, %56)::MyBigFloat
│ %58 = (%45)(%49, %57)::MyBigFloat
│ %59 = (%43)(%44, %58)::Bool
└─── goto #25 if not %59
18 ┄ %61 = Base.max::Core.Const(max)
│ %62 = Base.abs(x)::MyBigFloat
│ %63 = Base.abs(y)::MyBigFloat
│ %64 = (%61)(%62, %63)::MyBigFloat
│ (@_9 = %64)
│ %66 = @_9::MyBigFloat
│ %67 = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│ %68 = (%66 isa %67)::Core.Const(true)
└─── goto #20 if not %68
19 ─ goto #21
20 ─ Core.Const(:($(Expr(:static_parameter, 1))))
│ Core.Const(:(@_9))
│ Core.Const(:(Base.convert(%71, %72)))
│ Core.Const(:($(Expr(:static_parameter, 1))))
└─── Core.Const(:(@_9 = Core.typeassert(%73, %74)))
21 ┄ %76 = @_9::MyBigFloat
│ (m = %76)
│ %78 = m::MyBigFloat
│ %79 = (%78 == 0)::Bool
└─── goto #23 if not %79
22 ─ (@_13 = 0)
└─── goto #24
23 ─ %83 = Base.convert::Core.Const(convert)
│ %84 = Base.Int::Core.Const(Int64)
│ %85 = Base.exponent::Core.Const(exponent)
│ %86 = m::MyBigFloat
│ %87 = (%85)(%86)::BigInt
└─── (@_13 = (%83)(%84, %87))
24 ┄ %89 = @_13::Int64
│ (k = %89)
│ %91 = Base.ldexp::Core.Const(ldexp)
│ %92 = k::Int64
│ %93 = -%92::Int64
│ %94 = (%91)(x, %93)::MyBigFloat
│ %95 = Base.ldexp::Core.Const(ldexp)
│ %96 = k::Int64
│ %97 = -%96::Int64
│ %98 = (%95)(y, %97)::MyBigFloat
│ (xk = %94)
│ (yk = %98)
│ %101 = Base.:+::Core.Const(+)
│ %102 = xk::MyBigFloat
│ %103 = xk::MyBigFloat
│ %104 = (%102 * %103)::MyBigFloat
│ %105 = yk::MyBigFloat
│ %106 = yk::MyBigFloat
│ %107 = (%105 * %106)::MyBigFloat
└─── (ρ = (%101)(%104, %107))
25 ┄ %109 = ρ::MyBigFloat
│ %110 = k::Int64
│ %111 = Core.tuple(%109, %110)::Tuple{MyBigFloat, Int64}
└─── return %111 so you can see there aren't any |
base/complex.jl
Outdated
elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2) | ||
m::T = max(abs(x), abs(y)) | ||
k = m==0 ? m : exponent(m) | ||
k = m==0 ? 0 : convert(Int, exponent(m)) | ||
xk, yk = ldexp(x,-k), ldexp(y,-k) | ||
ρ = xk*xk + yk*yk |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated to the issue you're trying to solve here, so feel free to ignore, but while at it: For k==0
, it seems this doesn't change anything, and as we're checking m==0
anyway (which implies k==0
), maybe this could be slightly more efficient:
elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2) | |
m::T = max(abs(x), abs(y)) | |
k = m==0 ? m : exponent(m) | |
k = m==0 ? 0 : convert(Int, exponent(m)) | |
xk, yk = ldexp(x,-k), ldexp(y,-k) | |
ρ = xk*xk + yk*yk | |
elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2) | |
m::T = max(abs(x), abs(y)) | |
if m != 0 | |
k = convert(Int, exponent(m)) | |
xk, yk = ldexp(x,-k), ldexp(y,-k) | |
ρ = xk*xk + yk*yk | |
end |
But at that point, it might make sense to subsume this in the condition above. Doesn't m == 0
imply x==0 && y==0
or is there some sneaky floating point thing going on so that abs(x)==0
even if x!=0
? Note further that that condition is a bit strange to begin with: Assuming eps(T)<sqrt(0.5)
for any reasonable float, nextfloat(zero(T))/(2*eps(T)^2)
will be be greater than 0
. So for ρ==0
, the condition will always be true, so matter x!=0 || y!=0
. However, if we want to exclude x==0 && y==0
, we could actually rewrite this as
elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2) | |
m::T = max(abs(x), abs(y)) | |
k = m==0 ? m : exponent(m) | |
k = m==0 ? 0 : convert(Int, exponent(m)) | |
xk, yk = ldexp(x,-k), ldexp(y,-k) | |
ρ = xk*xk + yk*yk | |
elseif isinf(ρ) || (ρ<nextfloat(zero(T))/(2*eps(T)^2) && (x!=0 || y!=0)) | |
m::T = max(abs(x), abs(y)) | |
k = convert(Int, exponent(m)) | |
xk, yk = ldexp(x,-k), ldexp(y,-k) | |
ρ = xk*xk + yk*yk |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That sounds reasonable to me, I'm happy to change it to that.
A slight optimization on your last suggestion could be:
elseif isinf(ρ) || ((x!=0 || y!=0) && ρ<nextfloat(zero(T))/(2*eps(T)^2))
since presumably comparisons to 0 are faster than those arithmetic operations, and a more common situation is that this function is being called with x==0
and y==0
so in that common case this will skip the check on ρ
.
From what I can tell this code path only gets called when x
and/or y
are very small or large but still finite and one or the other is nonzero. Here's a demonstration that in some limiting cases your suggestion is the same as the old code logic:
function ssqs1(x::T, y::T) where T<:Real
k = 0
ρ = x*x + y*y
if !isfinite(ρ) && (isinf(x) || isinf(y))
ρ = convert(T, Inf)
elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2)
m::T = max(abs(x), abs(y))
k = m==0 ? 0 : convert(Int, exponent(m))
xk, yk = ldexp(x,-k), ldexp(y,-k)
ρ = xk*xk + yk*yk
end
ρ, k
end
function ssqs2(x::T, y::T) where T<:Real
k = 0
ρ = x*x + y*y
if !isfinite(ρ) && (isinf(x) || isinf(y))
ρ = convert(T, Inf)
elseif isinf(ρ) || ((x!=0 || y!=0) && ρ<nextfloat(zero(T))/(2*eps(T)^2))
m::T = max(abs(x), abs(y))
k = convert(Int, exponent(m))
xk, yk = ldexp(x,-k), ldexp(y,-k)
ρ = xk*xk + yk*yk
end
ρ, k
end
for (x, y) in (
(1e-146, 0.0),
(1e-147, 0.0),
(1e-161, 0.0),
(1e-162, 0.0),
(1e154, 0.0),
(1e155, 0.0),
)
@show x, y
@show ρ = x*x + y*y
@show ssqs1(x, y)
@show ssqs2(x, y)
println()
end
which outputs:
(x, y) = (1.0e-146, 0.0)
ρ = x * x + y * y = 1.0e-292
ssqs1(x, y) = (1.0e-292, 0)
ssqs2(x, y) = (1.0e-292, 0)
(x, y) = (1.0e-147, 0.0)
ρ = x * x + y * y = 1.0e-294
ssqs1(x, y) = (2.5546755962044414, -489)
ssqs2(x, y) = (2.5546755962044414, -489)
(x, y) = (1.0e-161, 0.0)
ρ = x * x + y * y = 1.0e-322
ssqs1(x, y) = (1.2650140831706915, -535)
ssqs2(x, y) = (1.2650140831706915, -535)
(x, y) = (1.0e-162, 0.0)
ρ = x * x + y * y = 0.0
ssqs1(x, y) = (3.2384360529169696, -539)
ssqs2(x, y) = (3.2384360529169696, -539)
(x, y) = (1.0e154, 0.0)
ρ = x * x + y * y = 1.0e308
ssqs1(x, y) = (1.0e308, 0)
ssqs2(x, y) = (1.0e308, 0)
(x, y) = (1.0e155, 0.0)
ρ = x * x + y * y = Inf
ssqs1(x, y) = (3.476677903917502, 514)
ssqs2(x, y) = (3.476677903917502, 514)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with your assessment that I don't see how m
could be zero inside this branch of the code logic so I don't know why that was being checked in the first place.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
With the old logic it could, as with y==0 && y==0
, clearly (ρ==0 && (x!=0 || y!=0))
would be false, but ρ<nextfloat(zero(T))/(2*eps(T)^2)
would be true (and m==0
, obviously). So I guess that check on m
was necessary because the branch condition did not do what it was intended to.
More specifically, this improves the type stability of
Base.ssqs(x::T, y::T) where T<:Real
, which is called bysqrt(::Complex)
.Here's a demonstration:
`@code_warntype Base.ssqs(1.0, 2.0)`: this pull request
`@code_warntype Base.ssqs(1.0, 2.0)`: nightly
There may be something I'm missing about why it was written the way it was before, but this new way looks simpler to me, avoids a
Union
type in type inference, and precludes the need for a few type assertions.This showed up "in the wild" in JuliaGPU/Metal.jl#374, where the
Union
type in type inference caused issues for compilingsqrt(::Complex)
on Apple GPUs.