Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster integer hashing (fixes #37800 UB) #38031

Merged
merged 2 commits into from Oct 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 0 additions & 3 deletions base/Base.jl
Expand Up @@ -324,9 +324,6 @@ using .MPFR

include("combinatorics.jl")

# more hashing definitions
include("hashing2.jl")

# irrational mathematical constants
include("irrationals.jl")
include("mathconstants.jl")
Expand Down
4 changes: 2 additions & 2 deletions base/abstractset.jl
Expand Up @@ -64,10 +64,10 @@ julia> union!(a, 1:2:8);

julia> a
Set{Int64} with 5 elements:
7
5
4
7
3
5
1
```
"""
Expand Down
132 changes: 124 additions & 8 deletions base/float.jl
Expand Up @@ -460,17 +460,133 @@ Test whether a number is infinite.
"""
isinf(x::Real) = !isnan(x) & !isfinite(x)

## hashing small, built-in numeric types ##
const hx_NaN = hash_uint64(reinterpret(UInt64, NaN))
let Tf = Float64, Tu = UInt64, Ti = Int64
@eval function hash(x::$Tf, h::UInt)
# see comments on trunc and hash(Real, UInt)
if $(Tf(typemin(Ti))) <= x < $(Tf(typemax(Ti)))
xi = fptosi($Ti, x)
if isequal(xi, x)
return hash(xi, h)
end
elseif $(Tf(typemin(Tu))) <= x < $(Tf(typemax(Tu)))
xu = fptoui($Tu, x)
if isequal(xu, x)
return hash(xu, h)
end
elseif isnan(x)
return hx_NaN ⊻ h # NaN does not have a stable bit pattern
end
return hash_uint64(bitcast(UInt64, x)) - 3h
end
end

hash(x::Float32, h::UInt) = hash(Float64(x), h)
hash(x::Float16, h::UInt) = hash(Float64(x), h)

hx(a::UInt64, b::Float64, h::UInt) = hash_uint64((3a + reinterpret(UInt64,b)) - h)
const hx_NaN = hx(UInt64(0), NaN, UInt(0 ))
## generic hashing for rational values ##

hash(x::UInt64, h::UInt) = hx(x, Float64(x), h)
hash(x::Int64, h::UInt) = hx(reinterpret(UInt64, abs(x)), Float64(x), h)
hash(x::Float64, h::UInt) = isnan(x) ? (hx_NaN ⊻ h) : hx(fptoui(UInt64, abs(x)), x, h)
function hash(x::Real, h::UInt)
# decompose x as num*2^pow/den
num, pow, den = decompose(x)

# handle special values
num == 0 && den == 0 && return hash(NaN, h)
num == 0 && return hash(ifelse(den > 0, 0.0, -0.0), h)
den == 0 && return hash(ifelse(num > 0, Inf, -Inf), h)

# normalize decomposition
if den < 0
num = -num
den = -den
end
z = trailing_zeros(num)
if z != 0
num >>= z
pow += z
end
z = trailing_zeros(den)
if z != 0
den >>= z
pow -= z
end

# handle values representable as Int64, UInt64, Float64
if den == 1
left = ndigits0z(num,2) + pow
right = trailing_zeros(num) + pow
if -1074 <= right
if 0 <= right && left <= 64
left <= 63 && return hash(Int64(num) << Int(pow), h)
signbit(num) == signbit(den) && return hash(UInt64(num) << Int(pow), h)
end # typemin(Int64) handled by Float64 case
left <= 1024 && left - right <= 53 && return hash(ldexp(Float64(num),pow), h)
end
end

# handle generic rational values
h = hash_integer(den, h)
h = hash_integer(pow, h)
h = hash_integer(num, h)
return h
end

#=
`decompose(x)`: non-canonical decomposition of rational values as `num*2^pow/den`.

The decompose function is the point where rational-valued numeric types that support
hashing hook into the hashing protocol. `decompose(x)` should return three integer
values `num, pow, den`, such that the value of `x` is mathematically equal to

num*2^pow/den

The decomposition need not be canonical in the sense that it just needs to be *some*
way to express `x` in this form, not any particular way – with the restriction that
`num` and `den` may not share any odd common factors. They may, however, have powers
of two in common – the generic hashing code will normalize those as necessary.

Special values:

- `x` is zero: `num` should be zero and `den` should have the same sign as `x`
- `x` is infinite: `den` should be zero and `num` should have the same sign as `x`
- `x` is not a number: `num` and `den` should both be zero
=#

decompose(x::Integer) = x, 0, 1

function decompose(x::Float16)::NTuple{3,Int}
isnan(x) && return 0, 0, 0
isinf(x) && return ifelse(x < 0, -1, 1), 0, 0
n = reinterpret(UInt16, x)
s = (n & 0x03ff) % Int16
e = ((n & 0x7c00) >> 10) % Int
s |= Int16(e != 0) << 10
d = ifelse(signbit(x), -1, 1)
s, e - 25 + (e == 0), d
end

function decompose(x::Float32)::NTuple{3,Int}
isnan(x) && return 0, 0, 0
isinf(x) && return ifelse(x < 0, -1, 1), 0, 0
n = reinterpret(UInt32, x)
s = (n & 0x007fffff) % Int32
e = ((n & 0x7f800000) >> 23) % Int
s |= Int32(e != 0) << 23
d = ifelse(signbit(x), -1, 1)
s, e - 150 + (e == 0), d
end

function decompose(x::Float64)::Tuple{Int64, Int, Int}
isnan(x) && return 0, 0, 0
isinf(x) && return ifelse(x < 0, -1, 1), 0, 0
n = reinterpret(UInt64, x)
s = (n & 0x000fffffffffffff) % Int64
e = ((n & 0x7ff0000000000000) >> 52) % Int
s |= Int64(e != 0) << 52
d = ifelse(signbit(x), -1, 1)
s, e - 1075 + (e == 0), d
end

hash(x::Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32}, h::UInt) = hash(Int64(x), h)
hash(x::Float32, h::UInt) = hash(Float64(x), h)

"""
precision(num::AbstractFloat)
Expand Down
75 changes: 74 additions & 1 deletion base/gmp.jl
Expand Up @@ -132,7 +132,7 @@ module MPZ
# - a method modifying its input has a "!" appendend to its name, according to Julia's conventions
# - some convenient methods are added (in addition to the pure MPZ ones), e.g. `add(a, b) = add!(BigInt(), a, b)`
# and `add!(x, a) = add!(x, x, a)`.
using .Base.GMP: BigInt, Limb, BITS_PER_LIMB
using ..GMP: BigInt, Limb, BITS_PER_LIMB

const mpz_t = Ref{BigInt}
const bitcnt_t = Culong
Expand Down Expand Up @@ -764,4 +764,77 @@ function Base.deepcopy_internal(x::BigInt, stackdict::IdDict)
return y
end

## streamlined hashing for BigInt, by avoiding allocation from shifts ##

if Limb === UInt
# this condition is true most (all?) of the time, and in this case we can define
# an optimized version of the above hash_integer(::Integer, ::UInt) method for BigInt
# used e.g. for Rational{BigInt}
function hash_integer(n::BigInt, h::UInt)
GC.@preserve n begin
s = n.size
s == 0 && return hash_integer(0, h)
p = convert(Ptr{UInt}, n.d)
b = unsafe_load(p)
h ⊻= hash_uint(ifelse(s < 0, -b, b) ⊻ h)
for k = 2:abs(s)
h ⊻= hash_uint(unsafe_load(p, k) ⊻ h)
end
return h
end
end

_divLimb(n) = UInt === UInt64 ? n >>> 6 : n >>> 5
_modLimb(n) = UInt === UInt64 ? n & 63 : n & 31

function hash(x::BigInt, h::UInt)
GC.@preserve x begin
sz = x.size
sz == 0 && return hash(0, h)
ptr = Ptr{UInt}(x.d)
if sz == 1
return hash(unsafe_load(ptr), h)
elseif sz == -1
limb = unsafe_load(ptr)
limb <= typemin(Int) % UInt && return hash(-(limb % Int), h)
end
pow = trailing_zeros(x)
nd = ndigits0z(x, 2)
idx = _divLimb(pow) + 1
shift = _modLimb(pow) % UInt
upshift = BITS_PER_LIMB - shift
asz = abs(sz)
if shift == 0
limb = unsafe_load(ptr, idx)
else
limb1 = unsafe_load(ptr, idx)
limb2 = idx < asz ? unsafe_load(ptr, idx+1) : UInt(0)
limb = limb2 << upshift | limb1 >> shift
end
if nd <= 1024 && nd - pow <= 53
return hash(ldexp(flipsign(Float64(limb), sz), pow), h)
end
h = hash_integer(1, h)
h = hash_integer(pow, h)
h ⊻= hash_uint(flipsign(limb, sz) ⊻ h)
for idx = idx+1:asz
if shift == 0
limb = unsafe_load(ptr, idx)
else
limb1 = limb2
if idx == asz
limb = limb1 >> shift
limb == 0 && break # don't hash leading zeros
else
limb2 = unsafe_load(ptr, idx+1)
limb = limb2 << upshift | limb1 >> shift
end
end
h ⊻= hash_uint(limb ⊻ h)
end
return h
end
end
end

end # module
17 changes: 17 additions & 0 deletions base/hashing.jl
Expand Up @@ -66,6 +66,23 @@ else
hash_uint(x::UInt) = hash_32_32(x)
end

## efficient value-based hashing of integers ##

hash(x::Int64, h::UInt) = hash_uint64(bitcast(UInt64, x)) - 3h
hash(x::UInt64, h::UInt) = hash_uint64(x) - 3h
hash(x::Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32}, h::UInt) = hash(Int64(x), h)

function hash_integer(n::Integer, h::UInt)
h ⊻= hash_uint((n % UInt) ⊻ h)
n = abs(n)
n >>>= sizeof(UInt) << 3
while n != 0
h ⊻= hash_uint((n % UInt) ⊻ h)
n >>>= sizeof(UInt) << 3
end
return h
end

## symbol & expression hashing ##

if UInt === UInt64
Expand Down