Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

faster rand!(::MersenneTwister, ::Array{T}) for IntTypes and Float16/32 #8958

Merged
merged 5 commits into from
Nov 12, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 114 additions & 25 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,42 +11,56 @@ export srand,

abstract AbstractRNG

const MTCacheLength = dsfmt_get_min_array_size()

type MersenneTwister <: AbstractRNG
state::DSFMT_state
vals::Vector{Float64}
idx::Int
seed::Vector{UInt32}

MersenneTwister(seed) = srand(new(DSFMT_state(), Array(Float64, dsfmt_get_min_array_size())),
MersenneTwister(seed) = srand(new(DSFMT_state(), Array(Float64, MTCacheLength)),
seed)
MersenneTwister() = MersenneTwister(0)
end

## Low level API for MersenneTwister

@inline mt_avail(r::MersenneTwister) = MTCacheLength - r.idx
@inline mt_empty(r::MersenneTwister) = r.idx == MTCacheLength
@inline mt_setfull!(r::MersenneTwister) = r.idx = 0
@inline mt_setempty!(r::MersenneTwister) = r.idx = MTCacheLength
@inline mt_pop!(r::MersenneTwister) = @inbounds return r.vals[r.idx+=1]

function gen_rand(r::MersenneTwister)
dsfmt_fill_array_close1_open2!(r.state, r.vals, length(r.vals))
r.idx = 0
mt_setfull!(r)
end

@inline gen_rand_maybe(r::MersenneTwister) = r.idx == length(r.vals) && gen_rand(r)
@inline gen_rand_maybe(r::MersenneTwister) = mt_empty(r) && gen_rand(r)

# precondition: r.idx < length(r.vals)
@inline rand_close1_open2_inbounds(r::MersenneTwister) = (r.idx += 1; @inbounds return r.vals[r.idx])
@inline rand_inbounds(r::MersenneTwister) = rand_close1_open2_inbounds(r) - 1.0
abstract FloatInterval
type CloseOpen <: FloatInterval end
type Close1Open2 <: FloatInterval end

# precondition: !mt_empty(r)
@inline rand_inbounds(r::MersenneTwister, ::Type{Close1Open2}) = mt_pop!(r)
@inline rand_inbounds(r::MersenneTwister, ::Type{CloseOpen}) = rand_inbounds(r, Close1Open2) - 1.0
@inline rand_inbounds(r::MersenneTwister) = rand_inbounds(r, CloseOpen)

# produce Float64 values
@inline rand_close1_open2(r::MersenneTwister) = (gen_rand_maybe(r); rand_close1_open2_inbounds(r))
@inline rand_close_open(r::MersenneTwister) = (gen_rand_maybe(r); rand_inbounds(r))
@inline rand{I<:FloatInterval}(r::MersenneTwister, ::Type{I}) = (gen_rand_maybe(r); rand_inbounds(r, I))

# this is similar to `dsfmt_genrand_uint32` from dSFMT.h:
@inline rand_ui32(r::MersenneTwister) = reinterpret(UInt64, rand_close1_open2(r)) % UInt32
@inline rand_ui32(r::MersenneTwister) = reinterpret(UInt64, rand(r, Close1Open2)) % UInt32

@inline rand_ui52_raw(r::MersenneTwister) = reinterpret(UInt64, rand(r, Close1Open2))
@inline rand_ui2x52_raw(r::MersenneTwister) = (((rand_ui52_raw(r) % UInt128) << 64) | rand_ui52_raw(r))

function srand(r::MersenneTwister, seed::Vector{UInt32})
r.seed = seed
dsfmt_init_by_array(r.state, r.seed)
r.idx = length(r.vals)
mt_setempty!(r)
return r
end

Expand Down Expand Up @@ -126,7 +140,7 @@ globalRNG() = GLOBAL_RNG

# rand: a non-specified RNG defaults to GLOBAL_RNG

@inline rand() = rand_close_open(GLOBAL_RNG)
@inline rand() = rand(GLOBAL_RNG, CloseOpen)
@inline rand(T::Type) = rand(GLOBAL_RNG, T)
rand(::()) = rand(GLOBAL_RNG, ()) # needed to resolve ambiguity
rand(dims::Dims) = rand(GLOBAL_RNG, dims)
Expand All @@ -137,10 +151,10 @@ rand!(A::AbstractArray) = rand!(GLOBAL_RNG, A)

## random floating point values

@inline rand(r::AbstractRNG) = rand_close_open(r)
@inline rand(r::AbstractRNG) = rand(r, CloseOpen)

# MersenneTwister
rand(r::MersenneTwister, ::Type{Float64}) = rand_close_open(r)
rand(r::MersenneTwister, ::Type{Float64}) = rand(r, CloseOpen)
rand{T<:Union(Float16, Float32)}(r::MersenneTwister, ::Type{T}) = convert(T, rand(r, Float64))

## random integers (MersenneTwister)
Expand Down Expand Up @@ -181,22 +195,21 @@ end

# MersenneTwister

function rand_AbstractArray_Float64!(r::MersenneTwister, A::AbstractArray{Float64})
n = length(A)
function rand_AbstractArray_Float64!{I<:FloatInterval}(r::MersenneTwister, A::AbstractArray{Float64}, n=length(A), ::Type{I}=CloseOpen)
# what follows is equivalent to this simple loop but more efficient:
# for i=1:n
# @inbounds A[i] = rand(r)
# @inbounds A[i] = rand(r, I)
# end
m = 0
while m < n
s = length(r.vals) - r.idx
s = mt_avail(r)
if s == 0
gen_rand(r)
s = length(r.vals)
s = mt_avail(r)
end
m2 = min(n, m+s)
for i=m+1:m2
@inbounds A[i] = rand_inbounds(r)
@inbounds A[i] = rand_inbounds(r, I)
end
m = m2
end
Expand All @@ -205,13 +218,89 @@ end

rand!(r::MersenneTwister, A::AbstractArray{Float64}) = rand_AbstractArray_Float64!(r, A)

function rand!(r::MersenneTwister, A::Array{Float64})
n = length(A)
fill_array!(s::DSFMT_state, A::Array{Float64}, n::Int, ::Type{CloseOpen}) = dsfmt_fill_array_close_open!(s, A, n)
fill_array!(s::DSFMT_state, A::Array{Float64}, n::Int, ::Type{Close1Open2}) = dsfmt_fill_array_close1_open2!(s, A, n)

function rand!{I<:FloatInterval}(r::MersenneTwister, A::Array{Float64}, n=length(A), ::Type{I}=CloseOpen)
if n < dsfmt_get_min_array_size()
rand_AbstractArray_Float64!(r, A)
rand_AbstractArray_Float64!(r, A, n, I)
else
dsfmt_fill_array_close_open!(r.state, A, 2*(n ÷ 2))
isodd(n) && (A[n] = rand(r))
fill_array!(r.state, A, 2*(n ÷ 2), I)
isodd(n) && (A[n] = rand(r, I))
end
A
end

@inline mask128(u::UInt128, ::Type{Float16}) = (u & 0x03ff03ff03ff03ff03ff03ff03ff03ff) | 0x3c003c003c003c003c003c003c003c00
@inline mask128(u::UInt128, ::Type{Float32}) = (u & 0x007fffff007fffff007fffff007fffff) | 0x3f8000003f8000003f8000003f800000

function rand!{T<:Union(Float16, Float32)}(r::MersenneTwister, A::Array{T}, ::Type{Close1Open2})
n = length(A)
n128 = n * sizeof(T) ÷ 16
rand!(r, pointer_to_array(convert(Ptr{Float64}, pointer(A)), 2*n128), 2*n128, Close1Open2)
A128 = pointer_to_array(convert(Ptr{UInt128}, pointer(A)), n128)
@inbounds for i in 1:n128
u = A128[i]
u $= u << 26
# at this point, the 64 low bits of u, "k" being the k-th bit of A128[i] and "+" the bit xor, are:
# [..., 58+32,..., 53+27, 52+26, ..., 33+7, 32+6, ..., 27+1, 26, ..., 1]
# the bits needing to be random are
# [1:10, 17:26, 33:42, 49:58] (for Float16)
# [1:23, 33:55] (for Float32)
# this is obviously satisfied on the 32 low bits side, and on the high side, the entropy comes
# from bits 33:52 of A128[i] and then from bits 27:32 (which are discarded on the low side)
# this is similar for the 64 high bits of u
A128[i] = mask128(u, T)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great to get a few more eyes on this.

end
for i in 16*n128÷sizeof(T)+1:n
@inbounds A[i] = rand(r, T) + one(T)
end
A
end

function rand!{T<:Union(Float16, Float32)}(r::MersenneTwister, A::Array{T}, ::Type{CloseOpen})
rand!(r, A, Close1Open2)
I32 = one(Float32)
for i in 1:length(A)
@inbounds A[i] = T(Float32(A[i])-I32) # faster than "A[i] -= one(T)" for T==Float16
end
A
end

rand!{T<:Union(Float16, Float32)}(r::MersenneTwister, A::Array{T}) = rand!(r, A, CloseOpen)


function rand!(r::MersenneTwister, A::Array{UInt128}, n=length(A))
Af = pointer_to_array(convert(Ptr{Float64}, pointer(A)), 2n)
i = n
while true
rand!(r, Af, 2i, Close1Open2)
n < 5 && break
i = 0
@inbounds while n-i >= 5
u = A[i+=1]
A[n] $= u << 48
A[n-=1] $= u << 36
A[n-=1] $= u << 24
A[n-=1] $= u << 12
n-=1
end
end
if n > 0
u = rand_ui2x52_raw(r)
for i = 1:n
@inbounds A[i] $= u << 12*i
end
end
A
end

function rand!{T<:Union(Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128)}(r::MersenneTwister, A::Array{T})
n=length(A)
n128 = n * sizeof(T) ÷ 16
rand!(r, pointer_to_array(convert(Ptr{UInt128}, pointer(A)), n128))
for i = 16*n128÷sizeof(T)+1:n
@inbounds A[i] = rand(r, T)
end
A
end
Expand Down Expand Up @@ -811,7 +900,7 @@ const ziggurat_nor_r = 3.6541528853610087963519472518
const ziggurat_nor_inv_r = inv(ziggurat_nor_r)
const ziggurat_exp_r = 7.6971174701310497140446280481

@inline randi(rng::MersenneTwister=GLOBAL_RNG) = reinterpret(Uint64, rand_close1_open2(rng)) & 0x000fffffffffffff
@inline randi(rng::MersenneTwister=GLOBAL_RNG) = reinterpret(Uint64, rand(rng, Close1Open2)) & 0x000fffffffffffff

function randmtzig_randn(rng::MersenneTwister=GLOBAL_RNG)
@inbounds begin
Expand Down
44 changes: 42 additions & 2 deletions test/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ rand!(MersenneTwister(0), A)
@test rand(MersenneTwister(0), Int64, 1) == [172014471070449746]
A = zeros(Int64, 2, 2)
rand!(MersenneTwister(0), A)
@test A == [ 172014471070449746 -193283627354378518;
-4679130500738884555 -9008350441255501549]
@test A == [858542123778948672 5715075217119798169;
8690327730555225005 8435109092665372532]

# randn
@test randn(MersenneTwister(42)) == -0.5560268761463861
Expand Down Expand Up @@ -182,3 +182,43 @@ i8257 = 1:1/3:100
for i = 1:100
@test rand(i8257) in i8257
end

# test code paths of rand!

let mt = MersenneTwister(0)
A128 = Array(UInt128, 0)
@test length(rand!(mt, A128)) == 0
for (i,n) in enumerate([1, 3, 5, 6, 10, 11, 30])
resize!(A128, n)
rand!(mt, A128)
@test length(A128) == n
@test A128[end] == UInt128[0x15de6b23025813ad129841f537a04e40,
0xcfa4db38a2c65bc4f18c07dc91125edf,
0x33bec08136f19b54290982449b3900d5,
0xde41af3463e74cb830dad4add353ca20,
0x066d8695ebf85f833427c93416193e1f,
0x48fab49cc9fcee1c920d6dae629af446,
0x4b54632b4619f4eca22675166784d229][i]

end

srand(mt,0)
for (i,T) in enumerate([Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64, Int128, Float16, Float32])
A = Array(T, 16)
B = Array(T, 31)
rand!(mt, A)
rand!(mt, B)
@test A[end] == Any[21,0x7b,17385,0x3086,-1574090021,0xadcb4460,6797283068698303107,0x4e91c9c4d4f5f759,
-3482609696641744459568613291754091152,float16(0.03125),0.68733835f0][i]

@test B[end] == Any[49,0x65,-3725,0x719d,814246081,0xdf61843a,-1603010949539670188,0x5e4ca1658810985d,
-33032345278809823492812856023466859769,float16(0.9346),0.5929704f0][i]
end

srand(mt,0)
AF64 = Array(Float64, Base.Random.dsfmt_get_min_array_size()-1)
@test rand!(mt, AF64)[end] == 0.957735065345398
@test rand!(mt, AF64)[end] == 0.6492481059865669
resize!(AF64, 2*length(mt.vals))
@test Base.Random.rand_AbstractArray_Float64!(mt, AF64)[end] == 0.432757268470779
end