Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 27 additions & 38 deletions stdlib/Random/src/MersenneTwister.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const MT_CACHE_I = 501 << 4 # number of bytes in the UInt128 cache
@assert dsfmt_get_min_array_size() <= MT_CACHE_F

mutable struct MersenneTwister <: AbstractRNG
seed::Any
seed::NTuple{2, UInt128}
const state::DSFMT_state
const vals::Memory{Float64}
const ints::Vector{UInt128} # it's temporarily resized internally
Expand All @@ -22,7 +22,7 @@ mutable struct MersenneTwister <: AbstractRNG
adv_ints::Int64 # state of advance when ints is filled-up

global _MersenneTwister(::UndefInitializer) =
new(nothing, DSFMT_state(),
new((UInt128(0), UInt128(0)), DSFMT_state(),
Memory{Float64}(undef, MT_CACHE_F),
Vector{UInt128}(undef, MT_CACHE_I >> 4),
MT_CACHE_F, 0, 0, Base.GMP.ZERO, -1, -1)
Expand Down Expand Up @@ -89,33 +89,26 @@ hash(r::MersenneTwister, h::UInt) =
foldr(hash, (r.seed, r.state, r.vals, r.ints, r.idxF, r.idxI); init=h)

function show(io::IO, rng::MersenneTwister)
sep = ", "
# seed
print(io, MersenneTwister, "(", repr(rng.seed[1]), sep, repr(rng.seed[2]))
if rng.adv_jump == 0 && rng.adv == 0
return print(io, MersenneTwister, "(", repr(rng.seed), ")")
return print(io, ")")
end
print(io, MersenneTwister, "(", repr(rng.seed), ", (")
# state
sep = ", "
show(io, rng.adv_jump)
print(io, sep)
show(io, rng.adv)
print(io, sep, rng.adv_jump, sep, rng.adv)
if rng.adv_vals != -1 || rng.adv_ints != -1
# "(0, 0)" is nicer on the eyes than (-1, 1002)
s = rng.adv_vals != -1
print(io, sep)
show(io, s ? rng.adv_vals : zero(rng.adv_vals))
print(io, sep)
show(io, s ? rng.idxF : zero(rng.idxF))
print(io, sep, s ? rng.adv_vals : zero(rng.adv_vals),
sep, s ? rng.idxF : zero(rng.idxF))
end
if rng.adv_ints != -1
idxI = (length(rng.ints)*16 - rng.idxI) / 8 # 8 represents one Int64
idxI = Int(idxI) # idxI should always be an integer when using public APIs
print(io, sep)
show(io, rng.adv_ints)
print(io, sep)
show(io, idxI)
print(io, sep, rng.adv_ints, sep, idxI)
end
print(io, "))")
print(io, ")")
end

### low level API
Expand Down Expand Up @@ -226,27 +219,19 @@ end

#### seed!()

function initstate!(r::MersenneTwister, data::StridedVector, seed)
# we deepcopy `seed` because the caller might mutate it, and it's useful
# to keep it constant inside `MersenneTwister`; but multiple instances
# can share the same seed without any problem (e.g. in `copy`)
r.seed = deepcopy(seed)
dsfmt_init_by_array(r.state, reinterpret(UInt32, data))
function initstate!(r::MersenneTwister, seed)
r.seed = seed # store the seed for `show`
seedvec = view(r.ints, 1:2) # re-use r.ints to temporarily store the seed
seedvec .= seed
dsfmt_init_by_array(r.state, reinterpret(UInt32, seedvec))
reset_caches!(r)
r.adv = 0
r.adv_jump = Base.GMP.ZERO
return r
end

# When a seed is not provided, we generate one via `RandomDevice()` rather
# than calling directly `initstate!` with `rand(RandomDevice(), UInt32, 8)` because the
# seed is printed in `show(::MersenneTwister)`, so we need one; the cost of `hash_seed` is a
# small overhead compared to `initstate!`.
# A random seed with 128 bits is a good compromise for almost surely getting distinct
# seeds, while having them printed reasonably tersely.
seed!(r::MersenneTwister, seeder::AbstractRNG) = seed!(r, rand(seeder, UInt128))
seed!(r::MersenneTwister, ::Nothing) = seed!(r, RandomDevice())
seed!(r::MersenneTwister, seed) = initstate!(r, rand(SeedHasher(seed), UInt32, 8), seed)
seed!(r::MersenneTwister, seeder::AbstractRNG) =
initstate!(r, rand(seeder, NTuple{2, UInt128}))


### generation
Expand Down Expand Up @@ -575,14 +560,18 @@ jump!(r::MersenneTwister, steps::Integer) = copy!(r, jump(r, steps))
# 3, 4: .adv_vals, .idxF (counters to reconstruct the float cache, optional if 5-6 not shown))
# 5, 6: .adv_ints, .idxI (counters to reconstruct the integer cache, optional)

Random.MersenneTwister(seed, advance::NTuple{6,Integer}) =
advance!(MersenneTwister(seed), advance...)
MersenneTwister(s1::Integer, s2::Integer) = initstate!(_MersenneTwister(undef), (s1, s2))

MersenneTwister(s1::Integer, s2::Integer, s3::Integer, s4::Integer,
s5::Integer, s6::Integer, s7::Integer, s8::Integer) =
advance!(MersenneTwister(s1, s2), s3, s4, s5, s6, s7, s8)

Random.MersenneTwister(seed, advance::NTuple{4,Integer}) =
MersenneTwister(seed, (advance..., 0, 0))
MersenneTwister(s1::Integer, s2::Integer, s3::Integer, s4::Integer,
s5::Integer, s6::Integer) =
MersenneTwister(s1, s2, s3, s4, s5, s6, 0, 0)

Random.MersenneTwister(seed, advance::NTuple{2,Integer}) =
MersenneTwister(seed, (advance..., 0, 0, 0, 0))
MersenneTwister(s1::Integer, s2::Integer, s3::Integer, s4::Integer) =
MersenneTwister(s1, s2, s3, s4, 0, 0, 0, 0)

# advances raw state (per fill_array!) of r by n steps (Float64 values)
function _advance_n!(r::MersenneTwister, n::Int64, work::Vector{Float64})
Expand Down
52 changes: 13 additions & 39 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -646,18 +646,6 @@ end
# MersenneTwister initialization with invalid values
@test_throws DomainError DSFMT.DSFMT_state(zeros(Int32, rand(0:DSFMT.JN32-1)))

# seed is private to MersenneTwister
let seed = rand(UInt32, 10)
r = MersenneTwister(seed)
@test r.seed == seed && r.seed !== seed
let r2 = Future.randjump(r, big(10)^20)
Random.seed!(r2)
@test seed == r.seed != r2.seed
end
resize!(seed, 4)
@test r.seed != seed
end

@testset "Random.seed!(rng, ...) returns rng" begin
# issue #21248
seed = rand(UInt)
Expand Down Expand Up @@ -957,42 +945,28 @@ end
@testset "show" begin
@testset "MersenneTwister" begin
m = MersenneTwister(123)
@test string(m) == "MersenneTwister(123)"
@test string(m) == "MersenneTwister(0xf80cc98e147960c1fefa8d41b8f5dca5, 0xea7a7dcb2e787c0120e2ccc17662fc1d)"
@test m == MersenneTwister(0xf80cc98e147960c1fefa8d41b8f5dca5, 0xea7a7dcb2e787c0120e2ccc17662fc1d)
Random.jump!(m, 2*big(10)^20)
@test string(m) == "MersenneTwister(123, (200000000000000000000, 0))"
@test m == MersenneTwister(123, (200000000000000000000, 0))
@test string(m) == "MersenneTwister(0xf80cc98e147960c1fefa8d41b8f5dca5, 0xea7a7dcb2e787c0120e2ccc17662fc1d, 200000000000000000000, 0)"
@test m == MersenneTwister(0xf80cc98e147960c1fefa8d41b8f5dca5, 0xea7a7dcb2e787c0120e2ccc17662fc1d, 200000000000000000000, 0)
rand(m)
@test string(m) == "MersenneTwister(123, (200000000000000000000, 1002, 0, 1))"
@test string(m) == "MersenneTwister(0xf80cc98e147960c1fefa8d41b8f5dca5, 0xea7a7dcb2e787c0120e2ccc17662fc1d, 200000000000000000000, 1002, 0, 1)"

@test m == MersenneTwister(123, (200000000000000000000, 1002, 0, 1))
@test m == MersenneTwister(0xf80cc98e147960c1fefa8d41b8f5dca5, 0xea7a7dcb2e787c0120e2ccc17662fc1d, 200000000000000000000, 1002, 0, 1)
rand(m, Int64)
@test string(m) == "MersenneTwister(123, (200000000000000000000, 2256, 0, 1, 1002, 1))"
@test m == MersenneTwister(123, (200000000000000000000, 2256, 0, 1, 1002, 1))
@test string(m) == "MersenneTwister(0xf80cc98e147960c1fefa8d41b8f5dca5, 0xea7a7dcb2e787c0120e2ccc17662fc1d, 200000000000000000000, 2256, 0, 1, 1002, 1)"
@test m == MersenneTwister(0xf80cc98e147960c1fefa8d41b8f5dca5, 0xea7a7dcb2e787c0120e2ccc17662fc1d, 200000000000000000000, 2256, 0, 1, 1002, 1)

m = MersenneTwister(0x0ecfd77f89dcd508caa37a17ebb7556b)
@test string(m) == "MersenneTwister(0x0ecfd77f89dcd508caa37a17ebb7556b)"
@test string(m) == "MersenneTwister(0x07a0cc280198a55c39fa6f802d242f8b, 0x8472a002c9dd8879235ae29f67bc7496)"
rand(m, Int64)
@test string(m) == "MersenneTwister(0x0ecfd77f89dcd508caa37a17ebb7556b, (0, 1254, 0, 0, 0, 1))"
@test m == MersenneTwister(0xecfd77f89dcd508caa37a17ebb7556b, (0, 1254, 0, 0, 0, 1))
@test string(m) == "MersenneTwister(0x07a0cc280198a55c39fa6f802d242f8b, 0x8472a002c9dd8879235ae29f67bc7496, 0, 1254, 0, 0, 0, 1)"
@test m == MersenneTwister(0x07a0cc280198a55c39fa6f802d242f8b, 0x8472a002c9dd8879235ae29f67bc7496, 0, 1254, 0, 0, 0, 1)

m = MersenneTwister(0); rand(m, Int64); rand(m)
@test string(m) == "MersenneTwister(0, (0, 2256, 1254, 1, 0, 1))"
@test m == MersenneTwister(0, (0, 2256, 1254, 1, 0, 1))

# negative seeds
Random.seed!(m, -3)
@test string(m) == "MersenneTwister(-3)"
Random.seed!(m, typemin(Int8))
@test string(m) == "MersenneTwister(-128)"

# string seeds
Random.seed!(m, "seed 1")
@test string(m) == "MersenneTwister(\"seed 1\")"
x = rand(m)
@test x == rand(MersenneTwister("seed 1"))
@test string(m) == """MersenneTwister("seed 1", (0, 1002, 0, 1))"""
# test that MersenneTwister's fancy constructors accept string seeds
@test MersenneTwister("seed 1", (0, 1002, 0, 1)) == m
@test string(m) == "MersenneTwister(0x48d73dc42d195740db2fa90498613fdf, 0x1911b814c02405e88c49bc52dc8a77ea, 0, 2256, 1254, 1, 0, 1)"
@test m == MersenneTwister(0x48d73dc42d195740db2fa90498613fdf, 0x1911b814c02405e88c49bc52dc8a77ea, 0, 2256, 1254, 1, 0, 1)
end

@testset "RandomDevice" begin
Expand Down