Skip to content

Commit 2ae0b7e

Browse files
authored
Vectorise random vectors of Float16 (#55997)
1 parent 8248bf4 commit 2ae0b7e

File tree

1 file changed

+26
-8
lines changed

1 file changed

+26
-8
lines changed

stdlib/Random/src/XoshiroSimd.jl

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,17 @@ simdThreshold(::Type{Bool}) = 640
4444
l = Float32(li >>> 8) * Float32(0x1.0p-24)
4545
(UInt64(reinterpret(UInt32, u)) << 32) | UInt64(reinterpret(UInt32, l))
4646
end
47+
# Convert the four 16-bit lanes packed in `x` into four `Float16` values in [0, 1),
# returned packed in a `UInt64` with each result in the same lane position as its input.
# Each lane keeps its top 11 bits (an integer in 0:2047, exactly representable in
# Float16) and is then scaled by 2^-11.
@inline function _bits2float(x::UInt64, ::Type{Float16})
    scale = Float16(0x1.0p-11)

    # Extract each lane (most significant first) and drop its 5 low bits.
    m_hi = ((x >>> 48) % UInt16) >>> 5
    m_mh = ((x >>> 32) % UInt16) >>> 5
    m_ml = ((x >>> 16) % UInt16) >>> 5
    m_lo = (x % UInt16) >>> 5

    # Scale into [0, 1) and take the raw Float16 bit patterns, widened for packing.
    b_hi = UInt64(reinterpret(UInt16, Float16(m_hi) * scale))
    b_mh = UInt64(reinterpret(UInt16, Float16(m_mh) * scale))
    b_ml = UInt64(reinterpret(UInt16, Float16(m_ml) * scale))
    b_lo = UInt64(reinterpret(UInt16, Float16(m_lo) * scale))

    return (b_hi << 48) | (b_mh << 32) | (b_ml << 16) | b_lo
end
4758

4859
# required operations. These could be written more concisely with `ntuple`, but the compiler
4960
# sometimes refuses to properly vectorize.
@@ -118,6 +129,18 @@ for N in [4,8,16]
118129
ret <$N x i64> %i
119130
"""
120131
@eval @inline _bits2float(x::$VT, ::Type{Float32}) = llvmcall($code, $VT, Tuple{$VT}, x)
132+
133+
# LLVM IR for the Float16 conversion of a SIMD vector of N 64-bit words:
#   1. reinterpret the <N x i64> argument as 4N 16-bit lanes,
#   2. logically shift every lane right by 5,
#   3. convert the lanes from unsigned integer to `half`,
#   4. multiply by a splatted scale constant (0x3f40000000000000 is 2^-11 written in
#      LLVM's hexadecimal double notation),
#   5. reinterpret the result back to <N x i64>.
# `$N` and `$(4N)` are interpolated when the string is built, so each iteration of the
# enclosing loop yields IR for one vector width.
code = """
%as16 = bitcast <$N x i64> %0 to <$(4N) x i16>
%shiftamt = shufflevector <1 x i16> <i16 5>, <1 x i16> undef, <$(4N) x i32> zeroinitializer
%sh = lshr <$(4N) x i16> %as16, %shiftamt
%f = uitofp <$(4N) x i16> %sh to <$(4N) x half>
%scale = shufflevector <1 x half> <half 0x3f40000000000000>, <1 x half> undef, <$(4N) x i32> zeroinitializer
%m = fmul <$(4N) x half> %f, %scale
%i = bitcast <$(4N) x half> %m to <$N x i64>
ret <$N x i64> %i
"""
# Define the `Float16` method for vector type `$VT`; it passes `x` straight to the IR above.
@eval @inline _bits2float(x::$VT, ::Type{Float16}) = llvmcall($code, $VT, Tuple{$VT}, x)
121144
end
122145
end
123146

@@ -137,7 +160,7 @@ end
137160

138161
_id(x, T) = x
139162

140-
@inline function xoshiro_bulk(rng::Union{TaskLocalRNG, Xoshiro}, dst::Ptr{UInt8}, len::Int, T::Union{Type{UInt8}, Type{Bool}, Type{Float32}, Type{Float64}}, ::Val{N}, f::F = _id) where {N, F}
163+
@inline function xoshiro_bulk(rng::Union{TaskLocalRNG, Xoshiro}, dst::Ptr{UInt8}, len::Int, T::Union{Type{UInt8}, Type{Bool}, Type{Float16}, Type{Float32}, Type{Float64}}, ::Val{N}, f::F = _id) where {N, F}
141164
if len >= simdThreshold(T)
142165
written = xoshiro_bulk_simd(rng, dst, len, T, Val(N), f)
143166
len -= written
@@ -265,13 +288,8 @@ end
265288
end
266289

267290

268-
function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Array{Float32}, ::SamplerTrivial{CloseOpen01{Float32}})
269-
GC.@preserve dst xoshiro_bulk(rng, convert(Ptr{UInt8}, pointer(dst)), length(dst)*4, Float32, xoshiroWidth(), _bits2float)
270-
dst
271-
end
272-
273-
function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Array{Float64}, ::SamplerTrivial{CloseOpen01{Float64}})
274-
GC.@preserve dst xoshiro_bulk(rng, convert(Ptr{UInt8}, pointer(dst)), length(dst)*8, Float64, xoshiroWidth(), _bits2float)
291+
# Fill `dst` in place with values sampled from `CloseOpen01{T}` for T in
# Float16/Float32/Float64, and return `dst`.
# The array is handed to `xoshiro_bulk` as a raw byte buffer of `length(dst) * sizeof(T)`
# bytes, with `_bits2float` passed as the conversion function.
function rand!(rng::Union{TaskLocalRNG, Xoshiro}, dst::Array{T}, ::SamplerTrivial{CloseOpen01{T}}) where {T<:Union{Float16,Float32,Float64}}
    # `dst` must stay rooted for as long as the raw pointer derived from it is in use.
    GC.@preserve dst begin
        nbytes = length(dst) * sizeof(T)
        bytes = convert(Ptr{UInt8}, pointer(dst))
        xoshiro_bulk(rng, bytes, nbytes, T, xoshiroWidth(), _bits2float)
    end
    return dst
end
277295

0 commit comments

Comments
 (0)