Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jump functions (jump, long_jump) for Xoshiro #47743

Merged
merged 11 commits into from
Sep 28, 2023
102 changes: 102 additions & 0 deletions stdlib/Random/src/Xoshiro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,108 @@ end

rng_native_52(::Xoshiro) = UInt64

# Jump functions from: https://xoshiro.di.unimi.it/xoshiro256plusplus.c

for (fname, JUMP) in ((:jump_128, (0x180ec6d33cfd0aba, 0xd5a61266f0c9392c, 0xa9582618e03fc9aa, 0x39abdc4529b1661c)),
(:jump_192, (0x76e15d3efefdcbbf, 0xc5004e441c522fb3, 0x77710069854ee241, 0x39109bb02acbe635)))
local fname! = Symbol(fname, :!)
@eval function $fname!(rng::Xoshiro)
_s0 = 0x0000000000000000
_s1 = 0x0000000000000000
_s2 = 0x0000000000000000
_s3 = 0x0000000000000000
s0, s1, s2, s3 = rng.s0, rng.s1, rng.s2, rng.s3
for j in $JUMP
for b in 0x0000000000000000:0x000000000000003f
if (j & 0x0000000000000001 << b) != 0
_s0 ⊻= s0
_s1 ⊻= s1
_s2 ⊻= s2
_s3 ⊻= s3
end
t = s1 << 17
s2 = xor(s2, s0)
s3 = xor(s3, s1)
s1 = xor(s1, s2)
s0 = xor(s0, s3)
s2 = xor(s2, t)
s3 = s3 << 45 | s3 >> 19
end
end
setstate!(rng, (_s0, _s1, _s2, _s3, nothing))
end
@eval $fname(rng::Xoshiro) = $fname!(copy(rng))

@eval function $fname!(rng::Xoshiro, n::Integer)
n < 0 && throw(DomainError(n, "the number of jumps must be ≥ 0"))
i = zero(n)
while i < n
$fname!(rng)
i += one(n)
end
rng
end

@eval $fname(rng::Xoshiro, n::Integer) = $fname!(copy(rng), n)
end

for (fname, sz) in ((:jump_128, 128), (:jump_192, 192))
local fname! = Symbol(fname, :!)
local see_other = Symbol(fname === :jump_128 ? :jump_192 : :jump_128)
local see_other! = Symbol(see_other, :!)
local seq_pow = 256 - sz
@eval begin
"""
$($fname!)(rng::Xoshiro, [n::Integer=1])

Jump forward, advancing the state equivalent to `2^$($sz)` calls which consume
8 bytes (i.e. a full `UInt64`) each.

If `n > 0` is provided, the state is advanced equivalent to `n * 2^$($sz)` calls; if `n = 0`,
the state remains unchanged.

This can be used to generate `2^$($seq_pow)` non-overlapping subsequences for parallel computations.

See also: [`$($fname)`](@ref), [`$($see_other!)`](@ref)

# Examples
```julia-repl
julia> $($fname!)($($fname!)(Xoshiro(1))) == $($fname!)(Xoshiro(1), 2)
true
```
"""
function $fname! end
end

@eval begin
"""
$($fname)(rng::Xoshiro, [n::Integer=1])

Return a copy of `rng` with the state advanced equivalent to `n * 2^$($sz)` calls which consume
8 bytes (i.e. a full `UInt64`) each; if `n = 0`, the state of the returned copy will be
identical to `rng`.

This can be used to generate `2^$($seq_pow)` non-overlapping subsequences for parallel computations.

See also: [`$($fname!)`](@ref), [`$($see_other)`](@ref)

# Examples
```julia-repl
julia> x = Xoshiro(1);

julia> $($fname)($($fname)(x)) == $($fname)(x, 2)
true

julia> $($fname)(x, 0) == x
true

julia> $($fname)(x, 0) === x
false
```
"""
function $fname end
end
end

## Task local RNG

Expand Down
64 changes: 64 additions & 0 deletions stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ using Random
using Random.DSFMT

using Random: Sampler, SamplerRangeFast, SamplerRangeInt, SamplerRangeNDL, MT_CACHE_F, MT_CACHE_I
using Random: jump, long_jump, jump!, long_jump!
andrewjradcliffe marked this conversation as resolved.
Show resolved Hide resolved

import Future # randjump

Expand Down Expand Up @@ -1036,6 +1037,7 @@ guardseed() do
end
end

<<<<<<< HEAD
andrewjradcliffe marked this conversation as resolved.
Show resolved Hide resolved
@testset "TaskLocalRNG: stream collision smoke test" begin
# spawn a trinary tree of tasks:
# - spawn three recursive child tasks in each
Expand Down Expand Up @@ -1105,3 +1107,65 @@ end
@test TaskLocalRNG() == rng3
end
end

# Xoshiro jumps
@testset "Xoshiro jump, basic" begin
x1 = Xoshiro(1)
x2 = Xoshiro(1)

@test jump_128!(jump_128!(x1)) == jump_128!(x1, 2)

xo1 = Xoshiro(0xfff0241072ddab67, 0xc53bc12f4c3f0b4e, 0x56d451780b2dd4ba, 0x50a4aa153d208dd8)
@test rand(jump_128(xo1), UInt64) == 0x87c158da8c35824d
@test rand(jump_192(xo1), UInt64) == 0xcaecd5afdd0847d5

@test rand(jump_128(xo1, 98765), UInt64) == 0xcbec1d5053142608
@test rand(jump_192(xo1, 98765), UInt64) == 0x3b97a94c44d66216

# Throws where appropriate
@test_throws DomainError jump_128(Xoshiro(1), -1)
@test_throws DomainError jump_128!(Xoshiro(1), -1)
@test_throws DomainError jump_192(Xoshiro(1), -1)
@test_throws DomainError jump_192!(Xoshiro(1), -1)

# clean copy when non-mut and no state advance
x = Xoshiro(1)
@test jump_128(x, 0) == x
@test jump_128(x, 0) !== x
@test jump_192(x, 0) == x
@test jump_192(x, 0) !== x

y = Xoshiro(1)
@test jump_128!(x, 0) == y
@test jump_192!(x, 0) == y
end

@testset "Xoshiro jump, various seeds" begin
for seed in (0, 1, 0xa0a3f09d0cecd878, 0x7ff8)
x = Xoshiro(seed)
@test jump_128(jump_128(jump_128(x))) == jump_128(x, 3)
x1 = Xoshiro(seed)
@test jump_128!(jump_128!(jump_128!(x1))) == jump_128(x, 3)
jump_128!(x1, 997)
x2 = jump_128!(Xoshiro(seed), 1000)
for T ∈ (Float64, UInt64, Int, Char, Bool)
@test rand(x1, T, 5) == rand(x2, T, 5)
@test rand(jump_128!(x1), T, 5) == rand(jump_128!(x2), T, 5)
end
end
end

@testset "Xoshiro jump_192, various seeds" begin
for seed in (0, 1, 0xa0a3f09d0cecd878, 0x7ff8)
x = Xoshiro(seed)
@test jump_192(jump_192(jump_192(x))) == jump_192(x, 3)
x1 = Xoshiro(seed)
@test jump_192!(jump_192!(jump_192!(x1))) == jump_192(x, 3)
jump_192!(x1, 997)
x2 = jump_192!(Xoshiro(seed), 1000)
for T ∈ (Float64, UInt64, Int, Char, Bool)
@test rand(x1, T, 5) == rand(x2, T, 5)
@test rand(jump_192!(x1), T, 5) == rand(jump_192!(x2), T, 5)
end
end
end