@@ -44,6 +44,17 @@ simdThreshold(::Type{Bool}) = 640
4444 l = Float32 (li >>> 8 ) * Float32 (0x1 .0 p- 24 )
4545 (UInt64 (reinterpret (UInt32, u)) << 32 ) | UInt64 (reinterpret (UInt32, l))
4646end
47+ @inline function _bits2float (x:: UInt64 , :: Type{Float16} )
48+ i1 = (x>>> 48 ) % UInt16
49+ i2 = (x>>> 32 ) % UInt16
50+ i3 = (x>>> 16 ) % UInt16
51+ i4 = x % UInt16
52+ f1 = Float16 (i1 >>> 5 ) * Float16 (0x1 .0 p- 11 )
53+ f2 = Float16 (i2 >>> 5 ) * Float16 (0x1 .0 p- 11 )
54+ f3 = Float16 (i3 >>> 5 ) * Float16 (0x1 .0 p- 11 )
55+ f4 = Float16 (i4 >>> 5 ) * Float16 (0x1 .0 p- 11 )
56+ return (UInt64 (reinterpret (UInt16, f1)) << 48 ) | (UInt64 (reinterpret (UInt16, f2)) << 32 ) | (UInt64 (reinterpret (UInt16, f3)) << 16 ) | UInt64 (reinterpret (UInt16, f4))
57+ end
4758
4859# required operations. These could be written more concisely with `ntuple`, but the compiler
4960# sometimes refuses to properly vectorize.
@@ -118,6 +129,18 @@ for N in [4,8,16]
118129 ret <$N x i64> %i
119130 """
120131 @eval @inline _bits2float (x:: $VT , :: Type{Float32} ) = llvmcall ($ code, $ VT, Tuple{$ VT}, x)
132+
133+ code = """
134+ %as16 = bitcast <$N x i64> %0 to <$(4 N) x i16>
135+ %shiftamt = shufflevector <1 x i16> <i16 5>, <1 x i16> undef, <$(4 N) x i32> zeroinitializer
136+ %sh = lshr <$(4 N) x i16> %as16, %shiftamt
137+ %f = uitofp <$(4 N) x i16> %sh to <$(4 N) x half>
138+ %scale = shufflevector <1 x half> <half 0x3f40000000000000>, <1 x half> undef, <$(4 N) x i32> zeroinitializer
139+ %m = fmul <$(4 N) x half> %f, %scale
140+ %i = bitcast <$(4 N) x half> %m to <$N x i64>
141+ ret <$N x i64> %i
142+ """
143+ @eval @inline _bits2float (x:: $VT , :: Type{Float16} ) = llvmcall ($ code, $ VT, Tuple{$ VT}, x)
121144 end
122145end
123146
137160
138161_id (x, T) = x
139162
140- @inline function xoshiro_bulk (rng:: Union{TaskLocalRNG, Xoshiro} , dst:: Ptr{UInt8} , len:: Int , T:: Union{Type{UInt8}, Type{Bool}, Type{Float32}, Type{Float64}} , :: Val{N} , f:: F = _id) where {N, F}
163+ @inline function xoshiro_bulk (rng:: Union{TaskLocalRNG, Xoshiro} , dst:: Ptr{UInt8} , len:: Int , T:: Union{Type{UInt8}, Type{Bool}, Type{Float16}, Type{ Float32}, Type{Float64}} , :: Val{N} , f:: F = _id) where {N, F}
141164 if len >= simdThreshold (T)
142165 written = xoshiro_bulk_simd (rng, dst, len, T, Val (N), f)
143166 len -= written
265288end
266289
267290
268- function rand! (rng:: Union{TaskLocalRNG, Xoshiro} , dst:: Array{Float32} , :: SamplerTrivial{CloseOpen01{Float32}} )
269- GC. @preserve dst xoshiro_bulk (rng, convert (Ptr{UInt8}, pointer (dst)), length (dst)* 4 , Float32, xoshiroWidth (), _bits2float)
270- dst
271- end
272-
273- function rand! (rng:: Union{TaskLocalRNG, Xoshiro} , dst:: Array{Float64} , :: SamplerTrivial{CloseOpen01{Float64}} )
274- GC. @preserve dst xoshiro_bulk (rng, convert (Ptr{UInt8}, pointer (dst)), length (dst)* 8 , Float64, xoshiroWidth (), _bits2float)
291+ function rand! (rng:: Union{TaskLocalRNG, Xoshiro} , dst:: Array{T} , :: SamplerTrivial{CloseOpen01{T}} ) where {T<: Union{Float16,Float32,Float64} }
292+ GC. @preserve dst xoshiro_bulk (rng, convert (Ptr{UInt8}, pointer (dst)), length (dst)* sizeof (T), T, xoshiroWidth (), _bits2float)
275293 dst
276294end
277295
0 commit comments