Skip to content

Commit

Permalink
MersenneTwister: hash seeds like for Xoshiro (#51436)
Browse files Browse the repository at this point in the history
This addresses a part of #37165:

> It's common that sequential seeds for RNGs are not as independent as
one might like.

This fixes the problem for `MersenneTwister`, and makes it easy to
add the same feature to other RNGs via a new `hash_seed` function, which
replaces `make_seed`.

This is an alternative to #37766.
  • Loading branch information
rfourquet committed Sep 29, 2023
1 parent 3a85776 commit ab992b9
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 182 deletions.
8 changes: 4 additions & 4 deletions stdlib/Random/docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ Random.SamplerSimple
Decoupling pre-computation from actually generating the values is part of the API, and is also available to the user. As an example, assume that `rand(rng, 1:20)` has to be called repeatedly in a loop: the way to take advantage of this decoupling is as follows:

```julia
rng = MersenneTwister()
sp = Random.Sampler(rng, 1:20) # or Random.Sampler(MersenneTwister, 1:20)
rng = Xoshiro()
sp = Random.Sampler(rng, 1:20) # or Random.Sampler(Xoshiro, 1:20)
for x in X
n = rand(rng, sp) # similar to n = rand(rng, 1:20)
# use n
Expand Down Expand Up @@ -159,8 +159,8 @@ Scalar and array methods for `Die` now work as expected:
julia> rand(Die)
Die(5)
julia> rand(MersenneTwister(0), Die)
Die(11)
julia> rand(Xoshiro(0), Die)
Die(10)
julia> rand(Die, 3)
3-element Vector{Die}:
Expand Down
3 changes: 2 additions & 1 deletion stdlib/Random/src/DSFMT.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ function dsfmt_init_gen_rand(s::DSFMT_state, seed::UInt32)
s.val, seed)
end

function dsfmt_init_by_array(s::DSFMT_state, seed::Vector{UInt32})
function dsfmt_init_by_array(s::DSFMT_state, seed::StridedVector{UInt32})
strides(seed) == (1,) || throw(ArgumentError("seed must have its stride equal to 1"))
ccall((:dsfmt_init_by_array,:libdSFMT),
Cvoid,
(Ptr{Cvoid}, Ptr{UInt32}, Int32),
Expand Down
132 changes: 68 additions & 64 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The entropy is obtained from the operating system.
"""
struct RandomDevice <: AbstractRNG; end
RandomDevice(seed::Nothing) = RandomDevice()
seed!(rng::RandomDevice) = rng
seed!(rng::RandomDevice, ::Nothing) = rng

rand(rd::RandomDevice, sp::SamplerBoolBitInteger) = Libc.getrandom!(Ref{sp[]}())[]
rand(rd::RandomDevice, ::SamplerType{Bool}) = rand(rd, UInt8) % Bool
Expand Down Expand Up @@ -44,7 +44,7 @@ const MT_CACHE_I = 501 << 4 # number of bytes in the UInt128 cache
@assert dsfmt_get_min_array_size() <= MT_CACHE_F

mutable struct MersenneTwister <: AbstractRNG
seed::Vector{UInt32}
seed::Any
state::DSFMT_state
vals::Vector{Float64}
ints::Vector{UInt128}
Expand All @@ -70,7 +70,7 @@ mutable struct MersenneTwister <: AbstractRNG
end
end

MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) =
MersenneTwister(seed, state::DSFMT_state) =
MersenneTwister(seed, state,
Vector{Float64}(undef, MT_CACHE_F),
Vector{UInt128}(undef, MT_CACHE_I >> 4),
Expand All @@ -92,19 +92,17 @@ See the [`seed!`](@ref) function for reseeding an already existing `MersenneTwis
# Examples
```jldoctest
julia> rng = MersenneTwister(1234);
julia> rng = MersenneTwister(123);
julia> x1 = rand(rng, 2)
2-element Vector{Float64}:
0.5908446386657102
0.7667970365022592
0.37453777969575874
0.8735343642013971
julia> rng = MersenneTwister(1234);
julia> x2 = rand(rng, 2)
julia> x2 = rand(MersenneTwister(123), 2)
2-element Vector{Float64}:
0.5908446386657102
0.7667970365022592
0.37453777969575874
0.8735343642013971
julia> x1 == x2
true
Expand All @@ -115,7 +113,7 @@ MersenneTwister(seed=nothing) =


function copy!(dst::MersenneTwister, src::MersenneTwister)
copyto!(resize!(dst.seed, length(src.seed)), src.seed)
dst.seed = src.seed
copy!(dst.state, src.state)
copyto!(dst.vals, src.vals)
copyto!(dst.ints, src.ints)
Expand All @@ -129,7 +127,7 @@ function copy!(dst::MersenneTwister, src::MersenneTwister)
end

copy(src::MersenneTwister) =
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), copy(src.ints),
MersenneTwister(src.seed, copy(src.state), copy(src.vals), copy(src.ints),
src.idxF, src.idxI, src.adv, src.adv_jump, src.adv_vals, src.adv_ints)


Expand All @@ -144,12 +142,10 @@ hash(r::MersenneTwister, h::UInt) =

function show(io::IO, rng::MersenneTwister)
# seed
seed = from_seed(rng.seed)
seed_str = seed <= typemax(Int) ? string(seed) : "0x" * string(seed, base=16) # DWIM
if rng.adv_jump == 0 && rng.adv == 0
return print(io, MersenneTwister, "(", seed_str, ")")
return print(io, MersenneTwister, "(", repr(rng.seed), ")")
end
print(io, MersenneTwister, "(", seed_str, ", (")
print(io, MersenneTwister, "(", repr(rng.seed), ", (")
# state
adv = Integer[rng.adv_jump, rng.adv]
if rng.adv_vals != -1 || rng.adv_ints != -1
Expand Down Expand Up @@ -277,76 +273,84 @@ end

### seeding

#### make_seed()
#### random_seed() & hash_seed()

# make_seed produces values of type Vector{UInt32}, suitable for MersenneTwister seeding
function make_seed()
# random_seed tries to produce a random seed of type UInt128 from system entropy
function random_seed()
try
return rand(RandomDevice(), UInt32, 4)
# as MersenneTwister prints its seed when `show`ed, 128 bits is a good compromise for
# almost surely always getting distinct seeds, while having them printed reasonably tersely
return rand(RandomDevice(), UInt128)
catch ex
ex isa IOError || rethrow()
@warn "Entropy pool not available to seed RNG; using ad-hoc entropy sources."
return make_seed(Libc.rand())
return Libc.rand()
end
end

"""
make_seed(n::Integer) -> Vector{UInt32}
Transform `n` into a bit pattern encoded as a `Vector{UInt32}`, suitable for
RNG seeding routines.
`make_seed` is "injective" : if `n != m`, then `make_seed(n) != `make_seed(m)`.
Moreover, if `n == m`, then `make_seed(n) == make_seed(m)`.
This is an internal function, subject to change.
"""
function make_seed(n::Integer)
neg = signbit(n)
function hash_seed(seed::Integer)
ctx = SHA.SHA2_256_CTX()
neg = signbit(seed)
if neg
n = ~n
end
@assert n >= 0
seed = UInt32[]
# we directly encode the bit pattern of `n` into the resulting vector `seed`;
# to greatly limit breaking the streams of random numbers, we encode the sign bit
# as the upper bit of `seed[end]` (i.e. for most positive seeds, `make_seed` returns
# the same vector as when we didn't encode the sign bit)
while !iszero(n)
push!(seed, n & 0xffffffff)
n >>>= 32
seed = ~seed
end
if isempty(seed) || !iszero(seed[end] & 0x80000000)
push!(seed, zero(UInt32))
end
if neg
seed[end] |= 0x80000000
@assert seed >= 0
while true
word = (seed % UInt32) & 0xffffffff
seed >>>= 32
SHA.update!(ctx, reinterpret(NTuple{4, UInt8}, word))
iszero(seed) && break
end
seed
# make sure the hash of negative numbers is different from the hash of positive numbers
neg && SHA.update!(ctx, (0x01,))
SHA.digest!(ctx)
end

# inverse of make_seed(::Integer): rebuild the integer from 32-bit limbs stored least
# significant first; the top bit of the last limb is the sign flag, and a flagged value
# is returned as the bitwise complement of the decoded magnitude
function from_seed(a::Vector{UInt32})::BigInt
    top = a[end]
    magnitude = big(top & 0x7fffffff)
    # Horner-style accumulation, from the most significant limb downwards
    for i in (length(a) - 1):-1:1
        magnitude = (magnitude << 32) + a[i]
    end
    isneg = !iszero(top & 0x80000000)
    return isneg ? ~magnitude : magnitude
end
# SHA-256 of an array of `UInt32` or `UInt64` words: every element is widened to `UInt64`
# and fed to the hash as 8 bytes, in iteration order
function hash_seed(seed::Union{AbstractArray{UInt32}, AbstractArray{UInt64}})
    sha = SHA.SHA2_256_CTX()
    foreach(seed) do word
        SHA.update!(sha, reinterpret(NTuple{8, UInt8}, UInt64(word)))
    end
    # trailing marker byte, so that the result differs from hash_seed(::Integer)
    SHA.update!(sha, (0x10,))
    return SHA.digest!(sha)
end


"""
hash_seed(seed) -> AbstractVector{UInt8}
Return a cryptographic hash of `seed` of size 256 bits (32 bytes).
`seed` can currently be of type `Union{Integer, DenseArray{UInt32}, DenseArray{UInt64}}`,
but modules can extend this function for types they own.
`hash_seed` is "injective" : if `n != m`, then `hash_seed(n) != `hash_seed(m)`.
Moreover, if `n == m`, then `hash_seed(n) == hash_seed(m)`.
This is an internal function subject to change.
"""
hash_seed

#### seed!()

function seed!(r::MersenneTwister, seed::Vector{UInt32})
copyto!(resize!(r.seed, length(seed)), seed)
dsfmt_init_by_array(r.state, r.seed)
# (re)initialize `r` from the seed material `data`, remembering `seed` for display/copy
function initstate!(r::MersenneTwister, data::StridedVector, seed)
    # `seed` is deep-copied so that a later mutation by the caller cannot alter the value
    # kept inside `r`; sharing one stored seed between several instances (e.g. in `copy`)
    # is harmless
    r.seed = deepcopy(seed)
    words = reinterpret(UInt32, data)
    dsfmt_init_by_array(r.state, words)
    reset_caches!(r)
    r.adv_jump = 0
    r.adv = 0
    return r
end

seed!(r::MersenneTwister) = seed!(r, make_seed())
seed!(r::MersenneTwister, n::Integer) = seed!(r, make_seed(n))
# when a seed is not provided, we generate one via `RandomDevice()` in `random_seed()` rather
# than calling directly `initstate!` with `rand(RandomDevice(), UInt32, whatever)` because the
# seed is printed in `show(::MersenneTwister)`, so we need one; the cost of `hash_seed` is a
# small overhead compared to `initstate!`, so this simple solution is fine
seed!(r::MersenneTwister, ::Nothing) = seed!(r, random_seed())
seed!(r::MersenneTwister, seed) = initstate!(r, hash_seed(seed), seed)


### Global RNG
Expand Down Expand Up @@ -713,7 +717,7 @@ end
function _randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X)
adv = r.adv
adv_jump = r.adv_jump
s = MersenneTwister(copy(r.seed), DSFMT.dsfmt_jump(r.state, jumppoly))
s = MersenneTwister(r.seed, DSFMT.dsfmt_jump(r.state, jumppoly))
reset_caches!(s)
s.adv = adv
s.adv_jump = adv_jump
Expand Down
25 changes: 14 additions & 11 deletions stdlib/Random/src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -356,8 +356,8 @@ julia> rand(Int, 2)
julia> using Random
julia> rand(MersenneTwister(0), Dict(1=>2, 3=>4))
1=>2
julia> rand(Xoshiro(0), Dict(1=>2, 3=>4))
3 => 4
julia> rand((2, 3))
3
Expand Down Expand Up @@ -389,15 +389,13 @@ but without allocating a new array.
# Examples
```jldoctest
julia> rng = MersenneTwister(1234);
julia> rand!(rng, zeros(5))
julia> rand!(Xoshiro(123), zeros(5))
5-element Vector{Float64}:
0.5908446386657102
0.7667970365022592
0.5662374165061859
0.4600853424625171
0.7940257103317943
0.521213795535383
0.5868067574533484
0.8908786980927811
0.19090669902576285
0.5256623915420473
```
"""
rand!
Expand Down Expand Up @@ -452,6 +450,11 @@ julia> rand(Xoshiro(), Bool) # not reproducible either
true
```
"""
seed!(rng::AbstractRNG, ::Nothing) = seed!(rng)
seed!(rng::AbstractRNG) = seed!(rng, nothing)
#=
We have this generic definition instead of the alternative option
`seed!(rng::AbstractRNG, ::Nothing) = seed!(rng)`
because the alternative would lead too easily to method ambiguities,
e.g. when we define `seed!(::Xoshiro, seed)`.
=#

end # module
17 changes: 8 additions & 9 deletions stdlib/Random/src/Xoshiro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,16 +230,20 @@ rng_native_52(::TaskLocalRNG) = UInt64
## Shared implementation between Xoshiro and TaskLocalRNG

# this variant of setstate! initializes the internal splitmix state, a.k.a. `s4`
@inline initstate!(x::Union{TaskLocalRNG, Xoshiro}, (s0, s1, s2, s3)::NTuple{4, UInt64}) =
# set the state of `x` from four `UInt64` words; the fifth word handed to `setstate!`
# is a fixed linear combination of those four
@inline function initstate!(x::Union{TaskLocalRNG, Xoshiro}, state)
    if !(length(state) == 4 && eltype(state) == UInt64)
        throw(ArgumentError("initstate! expects a list of 4 `UInt64` values"))
    end
    s0, s1, s2, s3 = state
    s4 = 1s0 + 3s1 + 5s2 + 7s3
    return setstate!(x, (s0, s1, s2, s3, s4))
end

copy(rng::Union{TaskLocalRNG, Xoshiro}) = Xoshiro(getstate(rng)...)
copy!(dst::Union{TaskLocalRNG, Xoshiro}, src::Union{TaskLocalRNG, Xoshiro}) = setstate!(dst, getstate(src))
==(x::Union{TaskLocalRNG, Xoshiro}, y::Union{TaskLocalRNG, Xoshiro}) = getstate(x) == getstate(y)
# use a magic (random) number to scramble `h` so that `hash(x)` is distinct from `hash(getstate(x))`
hash(x::Union{TaskLocalRNG, Xoshiro}, h::UInt) = hash(getstate(x), h + 0x49a62c2dda6fa9be % UInt)

function seed!(rng::Union{TaskLocalRNG, Xoshiro})
function seed!(rng::Union{TaskLocalRNG, Xoshiro}, ::Nothing)
# as we get good randomness from RandomDevice, we can skip hashing
rd = RandomDevice()
s0 = rand(rd, UInt64)
Expand All @@ -249,14 +253,9 @@ function seed!(rng::Union{TaskLocalRNG, Xoshiro})
initstate!(rng, (s0, s1, s2, s3))
end

# seed from a vector of words: the SHA-256 digest of its raw bytes, read back as four
# `UInt64` values, is passed on to `initstate!`
function seed!(
    rng::Union{TaskLocalRNG, Xoshiro},
    seed::Union{Vector{UInt32}, Vector{UInt64}},
)
    sha = SHA.SHA2_256_CTX()
    SHA.update!(sha, reinterpret(UInt8, seed))
    digest = SHA.digest!(sha)
    w0, w1, w2, w3 = reinterpret(UInt64, digest)
    return initstate!(rng, (w0, w1, w2, w3))
end
# generic seeding: hash `seed` with `hash_seed` and reinterpret the digest as `UInt64`
# words to initialize the state
function seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed)
    words = reinterpret(UInt64, hash_seed(seed))
    return initstate!(rng, words)
end

seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(seed))

@inline function rand(x::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt64})
s0, s1, s2, s3 = getstate(x)
Expand Down

0 comments on commit ab992b9

Please sign in to comment.