Skip to content

Commit

Permalink
improve exception type inference for core math functions
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Nov 20, 2023
1 parent 76143d3 commit 525bd6c
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 34 deletions.
12 changes: 6 additions & 6 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ to tell the compiler that indexing operations within the applied expression are
inbounds and do not need to taint `:consistent` and `:nothrow`.
"""
macro _safeindex(ex)
return esc(_safeindex(__module__, ex))
return esc(_safeindex(@__MODULE__, ex))

This comment has been minimized.

Copy link
@vtjnash

vtjnash Nov 20, 2023

Member

Might be clearer to write Base here explicitly?

This comment has been minimized.

Copy link
@aviatesk

aviatesk Nov 21, 2023

Author Member

This is used from Core.Compiler too, so this need to be @__MODULE__.

end
function _safeindex(__module__, ex)
function _safeindex(mod, ex)
isa(ex, Expr) || return ex
if ex.head === :(=)
lhs = ex.args[1]
Expand All @@ -141,16 +141,16 @@ function _safeindex(__module__, ex)
xs = lhs.args[1]
args = Vector{Any}(undef, length(lhs.args)-1)
for i = 2:length(lhs.args)
args[i-1] = _safeindex(__module__, lhs.args[i])
args[i-1] = _safeindex(mod, lhs.args[i])
end
return Expr(:call, GlobalRef(__module__, :__safe_setindex!), xs, _safeindex(__module__, rhs), args...)
return Expr(:call, GlobalRef(mod, :__safe_setindex!), xs, _safeindex(mod, rhs), args...)
end
elseif ex.head === :ref # xs[i]
return Expr(:call, GlobalRef(__module__, :__safe_getindex), ex.args...)
return Expr(:call, GlobalRef(mod, :__safe_getindex), ex.args...)
end
args = Vector{Any}(undef, length(ex.args))
for i = 1:length(ex.args)
args[i] = _safeindex(__module__, ex.args[i])
args[i] = _safeindex(mod, ex.args[i])
end
return Expr(ex.head, args...)
end
Expand Down
8 changes: 4 additions & 4 deletions base/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,10 +304,10 @@ end

# polynomial evaluation using compensated summation.
# much more accurate, especially when lo can be combined with other rounding errors
Base.@assume_effects :terminates_locally @inline function exthorner(x, p::Tuple)
hi, lo = p[end], zero(x)
for i in length(p)-1:-1:1
pi = getfield(p, i) # needed to prove consistency
Base.@assume_effects :terminates_locally @inline function exthorner(x::T, p::Tuple{T,T,T}) where T<:Union{Float32,Float64}
hi, lo = Base.@_safeindex(p[lastindex(p)]), zero(x)
Base.@_safeindex for i in length(p)-1:-1:1
pi = p[i] # needed to prove consistency
prod, err = two_mul(hi,x)
hi = pi+prod
lo = fma(lo, x, prod - (hi - pi) + err)
Expand Down
14 changes: 4 additions & 10 deletions base/special/log.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,11 @@ logbU(::Type{Float64},::Val{10}) = 0.4342944819032518
logbL(::Type{Float64},::Val{10}) = 1.098319650216765e-17

# Procedure 1
# XXX we want to mark :noub here so that this function can be concrete-folded,
# because the effect analysis currently can't prove it in the presence of `@inbounds` or
# `:boundscheck`, but still the access to `t_log_Float64` is really safe here
Base.@assume_effects :consistent :noub @inline function log_proc1(y::Float64,mf::Float64,F::Float64,f::Float64,base=Val(:ℯ))
@inline function log_proc1(y::Float64,mf::Float64,F::Float64,f::Float64,base=Val(:ℯ))
jp = unsafe_trunc(Int,128.0*F)-127

## Steps 1 and 2
@inbounds hi,lo = t_log_Float64[jp]
Base.@_safeindex hi,lo = t_log_Float64[jp]

This comment has been minimized.

Copy link
@Keno

Keno Nov 20, 2023

Member

I'd rather pull out this table lookup into a separate function along with the jp computation and a long comment about why it's inbounds.

l_hi = mf* 0.6931471805601177 + hi
l_lo = mf*-1.7239444525614835e-13 + lo

Expand Down Expand Up @@ -216,14 +213,11 @@ end
end

# Procedure 1
# XXX we want to mark :noub here so that this function can be concrete-folded,
# because the effect analysis currently can't prove it in the presence of `@inbounds` or
# `:boundscheck`, but still the access to `t_log_Float32` is really safe here
Base.@assume_effects :consistent :noub @inline function log_proc1(y::Float32,mf::Float32,F::Float32,f::Float32,base=Val(:ℯ))
@inline function log_proc1(y::Float32,mf::Float32,F::Float32,f::Float32,base=Val(:ℯ))
jp = unsafe_trunc(Int,128.0f0*F)-127

## Steps 1 and 2
@inbounds hi = t_log_Float32[jp]
Base.@_safeindex hi = t_log_Float32[jp]
l = mf*0.6931471805599453 + hi

## Step 3
Expand Down
19 changes: 8 additions & 11 deletions base/special/rem_pio2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,7 @@ function fromfraction(f::Int128)
return (z1,z2)
end

# XXX we want to mark :noub here so that this function can be concrete-folded,
# because the effect analysis currently can't prove it in the presence of `@inbounds` or
# `:boundscheck`, but still the accesses to `INV_2PI` are really safe here
Base.@assume_effects :consistent :noub function paynehanek(x::Float64)
function paynehanek(x::Float64)
# 1. Convert to form
#
# x = X * 2^k,
Expand Down Expand Up @@ -168,15 +165,15 @@ Base.@assume_effects :consistent :noub function paynehanek(x::Float64)
idx = k >> 6

shift = k - (idx << 6)
if shift == 0
@inbounds a1 = INV_2PI[idx+1]
@inbounds a2 = INV_2PI[idx+2]
@inbounds a3 = INV_2PI[idx+3]
Base.@_safeindex if shift == 0
a1 = INV_2PI[idx+1]
a2 = INV_2PI[idx+2]
a3 = INV_2PI[idx+3]
else
# use shifts to extract the relevant 64 bit window
@inbounds a1 = (idx < 0 ? zero(UInt64) : INV_2PI[idx+1] << shift) | (INV_2PI[idx+2] >> (64 - shift))
@inbounds a2 = (INV_2PI[idx+2] << shift) | (INV_2PI[idx+3] >> (64 - shift))
@inbounds a3 = (INV_2PI[idx+3] << shift) | (INV_2PI[idx+4] >> (64 - shift))
a1 = (idx < 0 ? zero(UInt64) : INV_2PI[idx+1] << shift) | (INV_2PI[idx+2] >> (64 - shift))
a2 = (INV_2PI[idx+2] << shift) | (INV_2PI[idx+3] >> (64 - shift))
a3 = (INV_2PI[idx+3] << shift) | (INV_2PI[idx+4] >> (64 - shift))
end

# 3. Perform the multiplication:
Expand Down
28 changes: 25 additions & 3 deletions test/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1543,9 +1543,8 @@ end

@testset "constant-foldability of core math functions" begin
for fn in (:sin, :cos, :tan, :log, :log2, :log10, :log1p, :exponent, :sqrt, :cbrt, :fourthroot,
:asin, :atan, :acos, :sinh, :cosh, :tanh, :asinh, :acosh, :atanh,
:exp, :exp2, :exp10, :expm1
)
:asin, :atan, :acos, :sinh, :cosh, :tanh, :asinh, :acosh, :atanh,
:exp, :exp2, :exp10, :expm1)
for T in (Float16, Float32, Float64)
@testset let f = getfield(@__MODULE__, fn), T = T
@test Core.Compiler.is_foldable(Base.infer_effects(f, (T,)))
Expand All @@ -1566,6 +1565,29 @@ end;
end
end
end;
@testset "exception type inference of core math functions" begin
for fn in (:sin, :cos, :tan, :log, :log2, :log10, :log1p, :exponent, :sqrt, :cbrt, :fourthroot,
:asin, :atan, :acos, :sinh, :cosh, :tanh, :asinh, :acosh, :atanh,
:exp, :exp2, :exp10, :expm1)
for T in (Float16, Float32, Float64)
@testset let f = getfield(@__MODULE__, fn), T = T
@test Base.exception_type(f, (T,)) <: Union{DomainError, InexactError}
@show f, T, Base.exception_type(f, (T,))
end
end
end
end;
@test Base.return_types((Int,)) do x
local r = nothing
try
r = sin(x)
catch err
if err isa DomainError
r = 0.0
end
end
return r
end |> only === Float64

@testset "BigInt Rationals with special funcs" begin
@test sinpi(big(1//1)) == big(0.0)
Expand Down

0 comments on commit 525bd6c

Please sign in to comment.