Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve the type stability of sqrt(::Complex) #54869

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

mtfishman
Copy link
Contributor

@mtfishman mtfishman commented Jun 20, 2024

More specifically, this improves the type stability of Base.ssqs(x::T, y::T) where T<:Real, which is called by sqrt(::Complex).

Here's a demonstration:

`@code_warntype Base.ssqs(1.0, 2.0)`: this pull request

julia> @code_warntype Base.ssqs(1.0, 2.0)
MethodInstance for Base.ssqs(::Float64, ::Float64)
  from ssqs(x::T, y::T) where T<:Real @ Base complex.jl:516
Static Parameters
  T = Float64
Arguments
  #self#::Core.Const(Base.ssqs)
  x::Float64
  y::Float64
Locals
  yk::Float64
  xk::Float64
  m::Float64
  ρ::Float64
  k::Int64
  @_9::Float64
  @_10::Bool
  @_11::Bool
  @_12::Bool
  @_13::Int64
Body::Tuple{Float64, Int64}
1 ──        Core.NewvarNode(:(yk))
│           Core.NewvarNode(:(xk))
│           Core.NewvarNode(:(m))
│           (k = 0)
│    %5   = Base.:+::Core.Const(+)
│    %6   = (x * x)::Float64
│    %7   = (y * y)::Float64
│           (ρ = (%5)(%6, %7))
│    %9   = Base.:!::Core.Const(!)
│    %10  = ρ::Float64
│    %11  = Base.isfinite(%10)::Bool
│    %12  = (%9)(%11)::Bool
└───        goto #7 if not %12
2 ── %14  = Base.isinf(x)::Bool
└───        goto #4 if not %14
3 ──        (@_10 = %14)
└───        goto #5
4 ──        (@_10 = Base.isinf(y))
5 ┄─ %19  = @_10::Bool
└───        goto #7 if not %19
6 ── %21  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│           (ρ = Base.convert(%21, Base.Inf))
└───        goto #25
7 ┄─ %24  = ρ::Float64
│    %25  = Base.isinf(%24)::Bool
└───        goto #9 if not %25
8 ──        goto #18
9 ── %28  = ρ::Float64
│    %29  = (%28 == 0)::Bool
└───        goto #14 if not %29
10 ─ %31  = (x != 0)::Bool
└───        goto #12 if not %31
11 ─        (@_12 = %31)
└───        goto #13
12 ─        (@_12 = y != 0)
13 ┄ %36  = @_12::Bool
│           (@_11 = %36)
└───        goto #15
14 ─        (@_11 = false)
15 ┄ %40  = @_11::Bool
└───        goto #17 if not %40
16 ─        goto #18
17 ─ %43  = Base.:<::Core.Const(<)
│    %44  = ρ::Float64
│    %45  = Base.:/::Core.Const(/)
│    %46  = Base.nextfloat::Core.Const(nextfloat)
│    %47  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %48  = Base.zero(%47)::Core.Const(0.0)
│    %49  = (%46)(%48)::Core.Const(5.0e-324)
│    %50  = Base.:*::Core.Const(*)
│    %51  = Base.:^::Core.Const(^)
│    %52  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %53  = Base.eps(%52)::Core.Const(2.220446049250313e-16)
│    %54  = Core.apply_type(Base.Val, 2)::Core.Const(Val{2})
│    %55  = (%54)()::Core.Const(Val{2}())
│    %56  = Base.literal_pow(%51, %53, %55)::Core.Const(4.930380657631324e-32)
│    %57  = (%50)(2, %56)::Core.Const(9.860761315262648e-32)
│    %58  = (%45)(%49, %57)::Core.Const(5.010420900022432e-293)
│    %59  = (%43)(%44, %58)::Bool
└───        goto #25 if not %59
18 ┄ %61  = Base.max::Core.Const(max)
│    %62  = Base.abs(x)::Float64
│    %63  = Base.abs(y)::Float64
│    %64  = (%61)(%62, %63)::Float64
│           (@_9 = %64)
│    %66  = @_9::Float64
│    %67  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %68  = (%66 isa %67)::Core.Const(true)
└───        goto #20 if not %68
19 ─        goto #21
20 ─        Core.Const(:($(Expr(:static_parameter, 1))))
│           Core.Const(:(@_9))
│           Core.Const(:(Base.convert(%71, %72)))
│           Core.Const(:($(Expr(:static_parameter, 1))))
└───        Core.Const(:(@_9 = Core.typeassert(%73, %74)))
21 ┄ %76  = @_9::Float64
│           (m = %76)
│    %78  = m::Float64
│    %79  = (%78 == 0)::Bool
└───        goto #23 if not %79
22 ─ %81  = k::Core.Const(0)
│           (@_13 = %81)
└───        goto #24
23 ─ %84  = Base.exponent::Core.Const(exponent)
│    %85  = m::Float64
└───        (@_13 = (%84)(%85))
24 ┄ %87  = @_13::Int64
│           (k = %87)
│    %89  = Base.ldexp::Core.Const(ldexp)
│    %90  = k::Int64
│    %91  = -%90::Int64
│    %92  = (%89)(x, %91)::Float64
│    %93  = Base.ldexp::Core.Const(ldexp)
│    %94  = k::Int64
│    %95  = -%94::Int64
│    %96  = (%93)(y, %95)::Float64
│           (xk = %92)
│           (yk = %96)
│    %99  = Base.:+::Core.Const(+)
│    %100 = xk::Float64
│    %101 = xk::Float64
│    %102 = (%100 * %101)::Float64
│    %103 = yk::Float64
│    %104 = yk::Float64
│    %105 = (%103 * %104)::Float64
└───        (ρ = (%99)(%102, %105))
25 ┄ %107 = ρ::Float64
│    %108 = k::Int64
│    %109 = Core.tuple(%107, %108)::Tuple{Float64, Int64}
└───        return %109


julia> @btime sqrt(1.0 + 2.0im)
  1.250 ns (0 allocations: 0 bytes)
1.272019649514069 + 0.7861513777574233im

julia> versioninfo()
Julia Version 1.12.0-DEV.753
Commit 4d0149d160* (2024-06-20 16:00 UTC)
Platform Info:
  OS: macOS (arm64-apple-darwin23.5.0)
  CPU: 10 × Apple M1 Max
  WORD_SIZE: 64
  LLVM: libLLVM-17.0.6 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)

`@code_warntype Base.ssqs(1.0, 2.0)`: nightly

julia> @code_warntype Base.ssqs(1.0, 2.0)
MethodInstance for Base.ssqs(::Float64, ::Float64)
  from ssqs(x::T, y::T) where T<:Real @ Base complex.jl:516
Static Parameters
  T = Float64
Arguments
  #self#::Core.Const(Base.ssqs)
  x::Float64
  y::Float64
Locals
  yk::Float64
  xk::Float64
  m::Float64
  ρ::Float64
  k::Int64
  @_9::Int64
  @_10::Float64
  @_11::Union{Float64, Int64}
  @_12::Bool
  @_13::Bool
  @_14::Bool
  @_15::Union{Float64, Int64}
Body::Tuple{Float64, Int64}
1 ──        Core.NewvarNode(:(yk))
│           Core.NewvarNode(:(xk))
│           Core.NewvarNode(:(m))
│           Core.NewvarNode(:(ρ))
│           Core.NewvarNode(:(k))
│           (@_9 = 0)
│    %7   = @_9::Core.Const(0)
│    %8   = (%7 isa Base.Int)::Core.Const(true)
└───        goto #3 if not %8
2 ──        goto #4
3 ──        Core.Const(:(@_9))
│           Core.Const(:(Base.convert(Base.Int, %11)))
│           Core.Const(:(Base.Int))
└───        Core.Const(:(@_9 = Core.typeassert(%12, %13)))
4 ┄─ %15  = @_9::Core.Const(0)
│           (k = %15)
│    %17  = Base.:+::Core.Const(+)
│    %18  = (x * x)::Float64
│    %19  = (y * y)::Float64
│           (ρ = (%17)(%18, %19))
│    %21  = Base.:!::Core.Const(!)
│    %22  = ρ::Float64
│    %23  = Base.isfinite(%22)::Bool
│    %24  = (%21)(%23)::Bool
└───        goto #10 if not %24
5 ── %26  = Base.isinf(x)::Bool
└───        goto #7 if not %26
6 ──        (@_12 = %26)
└───        goto #8
7 ──        (@_12 = Base.isinf(y))
8 ┄─ %31  = @_12::Bool
└───        goto #10 if not %31
9 ── %33  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│           (ρ = Base.convert(%33, Base.Inf))
└───        goto #31
10 ┄ %36  = ρ::Float64
│    %37  = Base.isinf(%36)::Bool
└───        goto #12 if not %37
11 ─        goto #21
12 ─ %40  = ρ::Float64
│    %41  = (%40 == 0)::Bool
└───        goto #17 if not %41
13 ─ %43  = (x != 0)::Bool
└───        goto #15 if not %43
14 ─        (@_14 = %43)
└───        goto #16
15 ─        (@_14 = y != 0)
16 ┄ %48  = @_14::Bool
│           (@_13 = %48)
└───        goto #18
17 ─        (@_13 = false)
18 ┄ %52  = @_13::Bool
└───        goto #20 if not %52
19 ─        goto #21
20 ─ %55  = Base.:<::Core.Const(<)
│    %56  = ρ::Float64
│    %57  = Base.:/::Core.Const(/)
│    %58  = Base.nextfloat::Core.Const(nextfloat)
│    %59  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %60  = Base.zero(%59)::Core.Const(0.0)
│    %61  = (%58)(%60)::Core.Const(5.0e-324)
│    %62  = Base.:*::Core.Const(*)
│    %63  = Base.:^::Core.Const(^)
│    %64  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %65  = Base.eps(%64)::Core.Const(2.220446049250313e-16)
│    %66  = Core.apply_type(Base.Val, 2)::Core.Const(Val{2})
│    %67  = (%66)()::Core.Const(Val{2}())
│    %68  = Base.literal_pow(%63, %65, %67)::Core.Const(4.930380657631324e-32)
│    %69  = (%62)(2, %68)::Core.Const(9.860761315262648e-32)
│    %70  = (%57)(%61, %69)::Core.Const(5.010420900022432e-293)
│    %71  = (%55)(%56, %70)::Bool
└───        goto #31 if not %71
21 ┄ %73  = Base.max::Core.Const(max)
│    %74  = Base.abs(x)::Float64
│    %75  = Base.abs(y)::Float64
│    %76  = (%73)(%74, %75)::Float64
│           (@_10 = %76)
│    %78  = @_10::Float64
│    %79  = $(Expr(:static_parameter, 1))::Core.Const(Float64)
│    %80  = (%78 isa %79)::Core.Const(true)
└───        goto #23 if not %80
22 ─        goto #24
23 ─        Core.Const(:($(Expr(:static_parameter, 1))))
│           Core.Const(:(@_10))
│           Core.Const(:(Base.convert(%83, %84)))
│           Core.Const(:($(Expr(:static_parameter, 1))))
└───        Core.Const(:(@_10 = Core.typeassert(%85, %86)))
24 ┄ %88  = @_10::Float64
│           (m = %88)
│    %90  = m::Float64
│    %91  = (%90 == 0)::Bool
└───        goto #26 if not %91
25 ─ %93  = m::Float64
│           (@_15 = %93)
└───        goto #27
26 ─ %96  = Base.exponent::Core.Const(exponent)
│    %97  = m::Float64
└───        (@_15 = (%96)(%97))
27 ┄ %99  = @_15::Union{Float64, Int64}
│           (@_11 = %99)
│    %101 = @_11::Union{Float64, Int64}
│    %102 = (%101 isa Base.Int)::Bool
└───        goto #29 if not %102
28 ─        goto #30
29 ─ %105 = @_11::Float64
│    %106 = Base.convert(Base.Int, %105)::Int64
│    %107 = Base.Int::Core.Const(Int64)
└───        (@_11 = Core.typeassert(%106, %107))
30 ┄ %109 = @_11::Int64
│           (k = %109)
│    %111 = Base.ldexp::Core.Const(ldexp)
│    %112 = k::Int64
│    %113 = -%112::Int64
│    %114 = (%111)(x, %113)::Float64
│    %115 = Base.ldexp::Core.Const(ldexp)
│    %116 = k::Int64
│    %117 = -%116::Int64
│    %118 = (%115)(y, %117)::Float64
│           (xk = %114)
│           (yk = %118)
│    %121 = Base.:+::Core.Const(+)
│    %122 = xk::Float64
│    %123 = xk::Float64
│    %124 = (%122 * %123)::Float64
│    %125 = yk::Float64
│    %126 = yk::Float64
│    %127 = (%125 * %126)::Float64
└───        (ρ = (%121)(%124, %127))
31 ┄ %129 = ρ::Float64
│    %130 = k::Int64
│    %131 = Core.tuple(%129, %130)::Tuple{Float64, Int64}
└───        return %131


julia> @btime sqrt(1.0 + 2.0im)
  1.250 ns (0 allocations: 0 bytes)
1.272019649514069 + 0.7861513777574233im

julia> versioninfo()
Julia Version 1.12.0-DEV.753
Commit 4d0149d160b (2024-06-20 16:00 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: macOS (arm64-apple-darwin22.4.0)
  CPU: 10 × Apple M1 Max
  WORD_SIZE: 64
  LLVM: libLLVM-17.0.6 (ORCJIT, apple-m1)
Threads: 1 default, 0 interactive, 1 GC (on 8 virtual cores)

There may be something I'm missing about why it was written the way it was before, but this new way looks simpler to me, avoids a Union type in type inference, and precludes the need for a few type assertions.

This showed up "in the wild" in JuliaGPU/Metal.jl#374, where the Union type in type inference caused issues for compiling sqrt(::Complex) on Apple GPUs.

@nsajko nsajko added the maths Mathematical functions label Jun 21, 2024
@nsajko
Copy link
Contributor

nsajko commented Jun 21, 2024

Would this break a hypothetical user-defined floating-point type with a BigInt exponent?

@nsajko
Copy link
Contributor

nsajko commented Jun 21, 2024

To ensure type stability while keeping correctness for generic code, perhaps it'd make sense to do something like

k = zero(exponent(x))

Or maybe also promote with Int?

@mtfishman
Copy link
Contributor Author

mtfishman commented Jun 21, 2024

That's a good point, I didn't consider cases where exponent might not be outputting something of type Int. It seems like for that, instead of making it more generic, we should convert to Int like the behavior before. exponent isn't always called in the code, so it isn't clear to me how to initialize k to zero(exponent(x)) without calling exponent or inferring the output type of exponent some other way, which seems pretty heavy-duty for such a low level math function.

Also, a number x such that exponent(x) > typemax(Int) is really big, that's too big to even be represented by BigFloat/BigInt, for example:

julia> big(2)^typemax(Int)
ERROR: OutOfMemoryError()
Stacktrace:
 [1] pow_ui!(x::BigInt, a::BigInt, b::UInt64)
   @ Base.GMP.MPZ ./gmp.jl:180
 [2] pow_ui
   @ ./gmp.jl:181 [inlined]
 [3] ^
   @ ./gmp.jl:626 [inlined]
 [4] bigint_pow(x::BigInt, y::Int64)
   @ Base.GMP ./gmp.jl:647
 [5] ^(x::BigInt, y::Int64)
   @ Base.GMP ./gmp.jl:652
 [6] top-level scope
   @ REPL[32]:1

In your hypothetical example where exponent(x) outputs something like BigInt, it would have failed in the version of Base.ssqs on master with a conversion error if the output couldn't get converted to Int (i.e. if exponent(x) > typemax(Int)), so it should be ok to keep that behavior.

I suppose changing the initialization of k back from k = 0 to k::Int = 0 would make sure that if exponent(m) output something with a type other than Int it would get converted to Int like before, but maybe there is a better code pattern to do that, like:

k = m==0 ? k : convert(typeof(k), exponent(m))

@mtfishman
Copy link
Contributor Author

mtfishman commented Jun 21, 2024

In the latest commit I changed the code to convert the output of exponent to Int in Base.ssqs.

Here's an example of a custom number type that outputs BigInt from exponent:

# Thin wrapper around `BigFloat` used to exercise `Base.ssqs` with a float type
# whose `exponent` returns a `BigInt` instead of an `Int`.
struct MyBigFloat <: AbstractFloat
    x::BigFloat
end

# Any mix of `MyBigFloat` with another real promotes to `MyBigFloat`.
function Base.promote_rule(::Type{MyBigFloat}, ::Type{<:Real})
    return MyBigFloat
end

# The point of this type: widen the wrapped value's exponent to `BigInt`.
function Base.exponent(v::MyBigFloat)
    return big(exponent(v.x))
end

# Scale by a power of two, delegating to the wrapped `BigFloat`.
function Base.ldexp(v::MyBigFloat, n::Int64)
    return MyBigFloat(ldexp(v.x, n))
end

# Ordering compares the wrapped values.
function Base.:<(a::MyBigFloat, b::MyBigFloat)
    return a.x < b.x
end

# Arithmetic forwards to the wrapped values and re-wraps the result.
function Base.:-(a::MyBigFloat)
    return MyBigFloat(-a.x)
end

function Base.:+(a::MyBigFloat, b::MyBigFloat)
    return MyBigFloat(a.x + b.x)
end

function Base.:-(a::MyBigFloat, b::MyBigFloat)
    return MyBigFloat(a.x - b.x)
end

function Base.:*(a::MyBigFloat, b::MyBigFloat)
    return MyBigFloat(a.x * b.x)
end

function Base.:/(a::MyBigFloat, b::MyBigFloat)
    return MyBigFloat(a.x / b.x)
end

# Floating-point neighbourhood queries and square root, again via `BigFloat`.
function Base.nextfloat(v::MyBigFloat)
    return MyBigFloat(nextfloat(v.x))
end

function Base.eps(::Type{MyBigFloat})
    return MyBigFloat(eps(BigFloat))
end

function Base.sqrt(v::MyBigFloat)
    return MyBigFloat(sqrt(v.x))
end

and here is the output of @code_warntype Base.ssqs(MyBigFloat(1.0), MyBigFloat(2.0)) based on the latest commit:

`@code_warntype Base.ssqs(MyBigFloat(1.0), MyBigFloat(2.0))`: latest commit of this PR

julia> @code_warntype Base.ssqs(MyBigFloat(1.0), MyBigFloat(2.0))
MethodInstance for Base.ssqs(::MyBigFloat, ::MyBigFloat)
  from ssqs(x::T, y::T) where T<:Real @ Base complex.jl:516
Static Parameters
  T = MyBigFloat
Arguments
  #self#::Core.Const(Base.ssqs)
  x::MyBigFloat
  y::MyBigFloat
Locals
  yk::MyBigFloat
  xk::MyBigFloat
  m::MyBigFloat
  ρ::MyBigFloat
  k::Int64
  @_9::MyBigFloat
  @_10::Bool
  @_11::Bool
  @_12::Bool
  @_13::Int64
Body::Tuple{MyBigFloat, Int64}
1 ──        Core.NewvarNode(:(yk))
│           Core.NewvarNode(:(xk))
│           Core.NewvarNode(:(m))
│           (k = 0)
│    %5   = Base.:+::Core.Const(+)
│    %6   = (x * x)::MyBigFloat
│    %7   = (y * y)::MyBigFloat
│           (ρ = (%5)(%6, %7))
│    %9   = Base.:!::Core.Const(!)
│    %10  = ρ::MyBigFloat
│    %11  = Base.isfinite(%10)::Bool
│    %12  = (%9)(%11)::Bool
└───        goto #7 if not %12
2 ── %14  = Base.isinf(x)::Bool
└───        goto #4 if not %14
3 ──        (@_10 = %14)
└───        goto #5
4 ──        (@_10 = Base.isinf(y))
5 ┄─ %19  = @_10::Bool
└───        goto #7 if not %19
6 ── %21  = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│           (ρ = Base.convert(%21, Base.Inf))
└───        goto #25
7 ┄─ %24  = ρ::MyBigFloat
│    %25  = Base.isinf(%24)::Bool
└───        goto #9 if not %25
8 ──        goto #18
9 ── %28  = ρ::MyBigFloat
│    %29  = (%28 == 0)::Bool
└───        goto #14 if not %29
10 ─ %31  = (x != 0)::Bool
└───        goto #12 if not %31
11 ─        (@_12 = %31)
└───        goto #13
12 ─        (@_12 = y != 0)
13 ┄ %36  = @_12::Bool
│           (@_11 = %36)
└───        goto #15
14 ─        (@_11 = false)
15 ┄ %40  = @_11::Bool
└───        goto #17 if not %40
16 ─        goto #18
17 ─ %43  = Base.:<::Core.Const(<)
│    %44  = ρ::MyBigFloat
│    %45  = Base.:/::Core.Const(/)
│    %46  = Base.nextfloat::Core.Const(nextfloat)
│    %47  = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│    %48  = Base.zero(%47)::MyBigFloat
│    %49  = (%46)(%48)::MyBigFloat
│    %50  = Base.:*::Core.Const(*)
│    %51  = Base.:^::Core.Const(^)
│    %52  = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│    %53  = Base.eps(%52)::MyBigFloat
│    %54  = Core.apply_type(Base.Val, 2)::Core.Const(Val{2})
│    %55  = (%54)()::Core.Const(Val{2}())
│    %56  = Base.literal_pow(%51, %53, %55)::MyBigFloat
│    %57  = (%50)(2, %56)::MyBigFloat
│    %58  = (%45)(%49, %57)::MyBigFloat
│    %59  = (%43)(%44, %58)::Bool
└───        goto #25 if not %59
18 ┄ %61  = Base.max::Core.Const(max)
│    %62  = Base.abs(x)::MyBigFloat
│    %63  = Base.abs(y)::MyBigFloat
│    %64  = (%61)(%62, %63)::MyBigFloat
│           (@_9 = %64)
│    %66  = @_9::MyBigFloat
│    %67  = $(Expr(:static_parameter, 1))::Core.Const(MyBigFloat)
│    %68  = (%66 isa %67)::Core.Const(true)
└───        goto #20 if not %68
19 ─        goto #21
20 ─        Core.Const(:($(Expr(:static_parameter, 1))))
│           Core.Const(:(@_9))
│           Core.Const(:(Base.convert(%71, %72)))
│           Core.Const(:($(Expr(:static_parameter, 1))))
└───        Core.Const(:(@_9 = Core.typeassert(%73, %74)))
21 ┄ %76  = @_9::MyBigFloat
│           (m = %76)
│    %78  = m::MyBigFloat
│    %79  = (%78 == 0)::Bool
└───        goto #23 if not %79
22 ─        (@_13 = 0)
└───        goto #24
23 ─ %83  = Base.convert::Core.Const(convert)
│    %84  = Base.Int::Core.Const(Int64)
│    %85  = Base.exponent::Core.Const(exponent)
│    %86  = m::MyBigFloat
│    %87  = (%85)(%86)::BigInt
└───        (@_13 = (%83)(%84, %87))
24 ┄ %89  = @_13::Int64
│           (k = %89)
│    %91  = Base.ldexp::Core.Const(ldexp)
│    %92  = k::Int64
│    %93  = -%92::Int64
│    %94  = (%91)(x, %93)::MyBigFloat
│    %95  = Base.ldexp::Core.Const(ldexp)
│    %96  = k::Int64
│    %97  = -%96::Int64
│    %98  = (%95)(y, %97)::MyBigFloat
│           (xk = %94)
│           (yk = %98)
│    %101 = Base.:+::Core.Const(+)
│    %102 = xk::MyBigFloat
│    %103 = xk::MyBigFloat
│    %104 = (%102 * %103)::MyBigFloat
│    %105 = yk::MyBigFloat
│    %106 = yk::MyBigFloat
│    %107 = (%105 * %106)::MyBigFloat
└───        (ρ = (%101)(%104, %107))
25 ┄ %109 = ρ::MyBigFloat
│    %110 = k::Int64
│    %111 = Core.tuple(%109, %110)::Tuple{MyBigFloat, Int64}
└───        return %111

so you can see there aren't any Union types in type inference with this version, whereas there would be on master and also on the previous commit of this PR, as @nsajko pointed out.

base/complex.jl Outdated
Comment on lines 521 to 525
elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2)
m::T = max(abs(x), abs(y))
k = m==0 ? m : exponent(m)
k = m==0 ? 0 : convert(Int, exponent(m))
xk, yk = ldexp(x,-k), ldexp(y,-k)
ρ = xk*xk + yk*yk
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated to the issue you're trying to solve here, so feel free to ignore, but while at it: For k==0, it seems this doesn't change anything, and as we're checking m==0 anyway (which implies k==0), maybe this could be slightly more efficient:

Suggested change
elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2)
m::T = max(abs(x), abs(y))
k = m==0 ? m : exponent(m)
k = m==0 ? 0 : convert(Int, exponent(m))
xk, yk = ldexp(x,-k), ldexp(y,-k)
ρ = xk*xk + yk*yk
elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2)
m::T = max(abs(x), abs(y))
if m != 0
k = convert(Int, exponent(m))
xk, yk = ldexp(x,-k), ldexp(y,-k)
ρ = xk*xk + yk*yk
end

But at that point, it might make sense to subsume this in the condition above. Doesn't m == 0 imply x==0 && y==0 or is there some sneaky floating point thing going on so that abs(x)==0 even if x!=0? Note further that that condition is a bit strange to begin with: Assuming eps(T)<sqrt(0.5) for any reasonable float, nextfloat(zero(T))/(2*eps(T)^2) will be greater than 0. So for ρ==0, the condition will always be true, no matter whether x!=0 || y!=0 holds. However, if we want to exclude x==0 && y==0, we could actually rewrite this as

Suggested change
elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2)
m::T = max(abs(x), abs(y))
k = m==0 ? m : exponent(m)
k = m==0 ? 0 : convert(Int, exponent(m))
xk, yk = ldexp(x,-k), ldexp(y,-k)
ρ = xk*xk + yk*yk
elseif isinf(ρ) || (ρ<nextfloat(zero(T))/(2*eps(T)^2) && (x!=0 || y!=0))
m::T = max(abs(x), abs(y))
k = convert(Int, exponent(m))
xk, yk = ldexp(x,-k), ldexp(y,-k)
ρ = xk*xk + yk*yk

Copy link
Contributor Author

@mtfishman mtfishman Jun 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That sounds reasonable to me, I'm happy to change it to that.

A slight optimization on your last suggestion could be:

    elseif isinf(ρ) || ((x!=0 || y!=0) && ρ<nextfloat(zero(T))/(2*eps(T)^2))

since presumably comparisons to 0 are faster than those arithmetic operations, and a more common situation is that this function is being called with x==0 and y==0 so in that common case this will skip the check on ρ.

From what I can tell this code path only gets called when x and/or y are very small or large but still finite and one or the other is nonzero. Here's a demonstration that in some limiting cases your suggestion is the same as the old code logic:

"""
    ssqs1(x::T, y::T) where T<:Real

Scaled sum of squares using the original branch condition.

Return `(ρ, k)` with `ρ = xk*xk + yk*yk`, where `xk = ldexp(x, -k)` and
`yk = ldexp(y, -k)`. `k` is `0` unless the unscaled sum `x*x + y*y` overflowed to
`Inf` or fell below `nextfloat(zero(T))/(2*eps(T)^2)`; in that case `k` is the
exponent of `max(abs(x), abs(y))` converted to `Int` (or `0` when that maximum is zero).
If `x` or `y` is itself infinite, `ρ` is `convert(T, Inf)` and `k` is `0`.
"""
function ssqs1(x::T, y::T) where T<:Real
    k = 0
    ρ = x*x + y*y
    if !isfinite(ρ) && (isinf(x) || isinf(y))
        # A genuinely infinite input: the result is infinite, no rescaling needed.
        ρ = convert(T, Inf)
    elseif isinf(ρ) || (ρ==0 && (x!=0 || y!=0)) || ρ<nextfloat(zero(T))/(2*eps(T)^2)
        # Overflow or (near-)underflow of the squares: rescale by a power of two.
        m::T = max(abs(x), abs(y))
        # The last clause above is also true for x == y == 0, so m can be zero here;
        # skip `exponent` in that case.
        k = m==0 ? 0 : convert(Int, exponent(m))
        xk, yk = ldexp(x,-k), ldexp(y,-k)
        ρ = xk*xk + yk*yk
    end
    ρ, k
end

"""
    ssqs2(x::T, y::T) where T<:Real

Scaled sum of squares using the simplified branch condition.

Return `(ρ, k)` with `ρ = xk*xk + yk*yk`, where `xk = ldexp(x, -k)` and
`yk = ldexp(y, -k)`. Rescaling (`k != 0` possible) happens only when `x*x + y*y`
overflowed to `Inf`, or when at least one input is nonzero and the sum is below
`nextfloat(zero(T))/(2*eps(T)^2)`. Otherwise `k` is `0`.
"""
function ssqs2(x::T, y::T) where T<:Real
    sumsq = x*x + y*y

    # A genuinely infinite input: the result is infinite, no rescaling needed.
    if !isfinite(sumsq) && (isinf(x) || isinf(y))
        return convert(T, Inf), 0
    end

    if !isinf(sumsq)
        # Cheap zero tests first; the threshold is only evaluated for nonzero input.
        tiny = (x != 0 || y != 0) && sumsq < nextfloat(zero(T))/(2*eps(T)^2)
        if !tiny
            return sumsq, 0
        end
    end

    # Overflow or near-underflow: rescale both inputs by the exponent of the larger one.
    largest = convert(T, max(abs(x), abs(y)))::T
    shift = convert(Int, exponent(largest))
    xs = ldexp(x, -shift)
    ys = ldexp(y, -shift)
    return xs*xs + ys*ys, shift
end

# Compare `ssqs1` and `ssqs2` on Float64 inputs chosen in pairs on either side of
# the points where `x*x + y*y` starts to need rescaling.
for (x, y) in (
  (1e-146, 0.0),  # x*x is still above the small-value threshold
  (1e-147, 0.0),  # x*x is below the threshold
  (1e-161, 0.0),  # x*x is tiny but still nonzero
  (1e-162, 0.0),  # x*x underflows to 0.0
  (1e154, 0.0),   # x*x is still finite
  (1e155, 0.0),   # x*x overflows to Inf
)
  # `@show` prints the expression text, so the expressions are kept exactly as written.
  @show x, y
  @show ρ = x*x + y*y
  @show ssqs1(x, y)
  @show ssqs2(x, y)
  println()
end

which outputs:

(x, y) = (1.0e-146, 0.0)
ρ = x * x + y * y = 1.0e-292
ssqs1(x, y) = (1.0e-292, 0)
ssqs2(x, y) = (1.0e-292, 0)

(x, y) = (1.0e-147, 0.0)
ρ = x * x + y * y = 1.0e-294
ssqs1(x, y) = (2.5546755962044414, -489)
ssqs2(x, y) = (2.5546755962044414, -489)

(x, y) = (1.0e-161, 0.0)
ρ = x * x + y * y = 1.0e-322
ssqs1(x, y) = (1.2650140831706915, -535)
ssqs2(x, y) = (1.2650140831706915, -535)

(x, y) = (1.0e-162, 0.0)
ρ = x * x + y * y = 0.0
ssqs1(x, y) = (3.2384360529169696, -539)
ssqs2(x, y) = (3.2384360529169696, -539)

(x, y) = (1.0e154, 0.0)
ρ = x * x + y * y = 1.0e308
ssqs1(x, y) = (1.0e308, 0)
ssqs2(x, y) = (1.0e308, 0)

(x, y) = (1.0e155, 0.0)
ρ = x * x + y * y = Inf
ssqs1(x, y) = (3.476677903917502, 514)
ssqs2(x, y) = (3.476677903917502, 514)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with your assessment that I don't see how m could be zero inside this branch of the code logic so I don't know why that was being checked in the first place.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the old logic it could, as with x==0 && y==0, clearly (ρ==0 && (x!=0 || y!=0)) would be false, but ρ<nextfloat(zero(T))/(2*eps(T)^2) would be true (and m==0, obviously). So I guess that check on m was necessary because the branch condition did not do what it was intended to.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
maths Mathematical functions
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants