Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dSFMT: use fill_array_* API instead of genrand_* API #8808

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions base/dSFMT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ export DSFMT_state, dsfmt_get_min_array_size, dsfmt_get_idstring,
dsfmt_genrand_close1_open2, dsfmt_gv_genrand_close1_open2,
dsfmt_genrand_close_open, dsfmt_gv_genrand_close_open,
dsfmt_genrand_uint32, dsfmt_gv_genrand_uint32,
dsfmt_fill_array_close_open!, dsfmt_fill_array_close1_open2!,
win32_SystemFunction036!

type DSFMT_state
Expand Down Expand Up @@ -95,6 +96,24 @@ function dsfmt_gv_genrand_uint32()
())
end

# precondition for dsfmt_fill_array_*:
# the underlying C array must be 16-byte aligned, which is the case for "Array"
function dsfmt_fill_array_close1_open2!(s::DSFMT_state, A::Array{Float64}, n::Int)
@assert dsfmt_min_array_size <= n <= length(A) && iseven(n)
ccall((:dsfmt_fill_array_close1_open2,:libdSFMT),
Void,
(Ptr{Void}, Ptr{Float64}, Int),
s.val, A, n)
end

function dsfmt_fill_array_close_open!(s::DSFMT_state, A::Array{Float64}, n::Int)
@assert dsfmt_min_array_size <= n <= length(A) && iseven(n)
ccall((:dsfmt_fill_array_close_open,:libdSFMT),
Void,
(Ptr{Void}, Ptr{Float64}, Int),
s.val, A, n)
end

## Windows entropy

@windows_only begin
Expand Down
34 changes: 31 additions & 3 deletions base/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,23 @@ abstract AbstractRNG
type MersenneTwister <: AbstractRNG
state::DSFMT_state
seed::Union(Uint32,Vector{Uint32})
vals::Vector{Float64}
idx::Int

function MersenneTwister(seed::Vector{Uint32})
state = DSFMT_state()
dsfmt_init_by_array(state, seed)
return new(state, seed)
return new(state, seed, Array(Float64, dsfmt_get_min_array_size()), dsfmt_get_min_array_size())
end

MersenneTwister(seed=0) = MersenneTwister(make_seed(seed))
end

function gen_rand(r::MersenneTwister)
dsfmt_fill_array_close1_open2!(r.state, r.vals, length(r.vals))
r.idx = 0
end

function srand(r::MersenneTwister, seed)
r.seed = seed
dsfmt_init_gen_rand(r.state, seed)
Expand Down Expand Up @@ -96,8 +103,8 @@ rand(::Type{Float16}) = float16(rand())

rand{T<:Real}(::Type{Complex{T}}) = complex(rand(T),rand(T))


rand(r::MersenneTwister) = dsfmt_genrand_close_open(r.state)
@inline rand_inbounds(r::MersenneTwister) = (r.idx += 1; @inbounds return r.vals[r.idx] - 1.0)
@inline rand(r::MersenneTwister) = (r.idx == length(r.vals) && gen_rand(r); rand_inbounds(r))

## random integers

Expand Down Expand Up @@ -141,6 +148,27 @@ function rand!(r::AbstractRNG, A::AbstractArray)
A
end

function rand!(r::MersenneTwister, A::Array{Float64})
n = length(A)
if n < dsfmt_get_min_array_size()
s = length(r.vals) - r.idx
m = min(n, s)
for i=1:m
@inbounds A[i] = rand_inbounds(r)
end
if n > s
gen_rand(r)
for i=m+1:n
@inbounds A[i] = rand_inbounds(r)
end
end
else
dsfmt_fill_array_close_open!(r.state, A, n & (0xfffffffffffffffe % Int))
isodd(n) && (A[n] = rand(r))
end
A
end

rand(T::Type, dims::Dims) = rand!(Array(T, dims))
rand{T<:Number}(::Type{T}) = error("no random number generator for type $T; try a more specific type")
rand{T<:Number}(::Type{T}, dims::Int...) = rand(T, dims)
Expand Down