Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Random: support threads for GLOBAL_RNG access #32407

Merged
merged 1 commit into from
Jul 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Multi-threading changes
This does not include subtypes of `IO` that are entirely in-memory, such as `IOBuffer`,
although it specifically does include `BufferStream`.
([#32309], [#32174], [#31981], [#32421]).
* The global random number generator (`GLOBAL_RNG`) is now thread-safe (and thread-local) ([#32407]).

Build system changes
--------------------
Expand Down
67 changes: 58 additions & 9 deletions stdlib/Random/src/RNGs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

## RandomDevice

# SamplerUnion(Union{X,Y,...}) == Union{SamplerType{X},SamplerType{Y},...}
SamplerUnion(U::Union) = Union{map(T->SamplerType{T}, Base.uniontypes(U))...}
const SamplerBoolBitInteger = SamplerUnion(Union{Bool, BitInteger})
# SamplerUnion(X, Y, ...}) == Union{SamplerType{X}, SamplerType{Y}, ...}
SamplerUnion(U...) = Union{Any[SamplerType{T} for T in U]...}
const SamplerBoolBitInteger = SamplerUnion(Bool, BitInteger_types...)

if Sys.iswindows()
struct RandomDevice <: AbstractRNG
Expand Down Expand Up @@ -285,14 +285,63 @@ function seed!(r::MersenneTwister, seed::Vector{UInt32})
return r
end

seed!(r::MersenneTwister=GLOBAL_RNG) = seed!(r, make_seed())
seed!(r::MersenneTwister=get_local_rng()) = seed!(r, make_seed())
seed!(r::MersenneTwister, n::Integer) = seed!(r, make_seed(n))
seed!(seed::Union{Integer,Vector{UInt32}}) = seed!(GLOBAL_RNG, seed)
seed!(seed::Union{Integer,Vector{UInt32}}) = seed!(get_local_rng(), seed)
vtjnash marked this conversation as resolved.
Show resolved Hide resolved


### Global RNG (must be defined after seed!)
### Global RNG

const GLOBAL_RNG = MersenneTwister(0)
const THREAD_RNGs = MersenneTwister[]
@inline get_local_rng() = get_local_rng(Threads.threadid())
@noinline function get_local_rng(tid::Int)
@assert 0 < tid <= length(THREAD_RNGs)
if @inbounds isassigned(THREAD_RNGs, tid)
@inbounds MT = THREAD_RNGs[tid]
else
MT = MersenneTwister()
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
@inbounds THREAD_RNGs[tid] = MT
end
return MT
end
function __init__()
resize!(empty!(THREAD_RNGs), Threads.nthreads()) # ensures that we didn't save a bad object
end


struct _GLOBAL_RNG <: AbstractRNG
global const GLOBAL_RNG = _GLOBAL_RNG.instance
end

copy!(dst::MersenneTwister, ::_GLOBAL_RNG) = copy!(dst, get_local_rng())
copy!(::_GLOBAL_RNG, src::MersenneTwister) = copy!(get_local_rng(), src)
copy(::_GLOBAL_RNG) = copy(get_local_rng())

seed!(::_GLOBAL_RNG, seed::Vector{UInt32}) = seed!(get_local_rng(), seed)
seed!(::_GLOBAL_RNG, n::Integer) = seed!(get_local_rng(), n)
seed!(::_GLOBAL_RNG, ::Nothing) = seed!(get_local_rng(), nothing)

rng_native_52(::_GLOBAL_RNG) = rng_native_52(get_local_rng())
rand(::_GLOBAL_RNG, sp::SamplerBoolBitInteger) = rand(get_local_rng(), sp)
for T in (:(SamplerTrivial{UInt52Raw{UInt64}}),
:(SamplerTrivial{UInt2x52Raw{UInt128}}),
:(SamplerTrivial{UInt104Raw{UInt128}}),
:(SamplerTrivial{CloseOpen12_64}),
:(SamplerUnion(Int64, UInt64, Int128, UInt128)),
:(SamplerUnion(Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)),
)
@eval rand(::_GLOBAL_RNG, x::$T) = rand(get_local_rng(), x)
end

rand!(::_GLOBAL_RNG, A::AbstractArray{Float64}, I::SamplerTrivial{<:FloatInterval_64}) = rand!(get_local_rng(), A, I)
rand!(::_GLOBAL_RNG, A::Array{Float64}, I::SamplerTrivial{<:FloatInterval_64}) = rand!(get_local_rng(), A, I)
for T in (Float16, Float32)
@eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerTrivial{CloseOpen12{$T}}) = rand!(get_local_rng(), A, I)
@eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerTrivial{CloseOpen01{$T}}) = rand!(get_local_rng(), A, I)
end
for T in BitInteger_types
@eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerType{$T}) = rand!(get_local_rng(), A, I)
end


### generation
Expand Down Expand Up @@ -332,10 +381,10 @@ rand(r::MersenneTwister, sp::SamplerTrivial{CloseOpen12_64}) =

#### integers

rand(r::MersenneTwister, T::SamplerUnion(Union{Int64,UInt64,Int128,UInt128})) =
rand(r::MersenneTwister, T::SamplerUnion(Int64, UInt64, Int128, UInt128)) =
mt_pop!(r, T[])

rand(r::MersenneTwister, T::SamplerUnion(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
rand(r::MersenneTwister, T::SamplerUnion(Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)) =
rand(r, UInt52Raw()) % T[]

#### arrays of floats
Expand Down
34 changes: 13 additions & 21 deletions stdlib/Random/src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ abstract type Sampler{E} end
gentype(::Type{<:Sampler{E}}) where {E} = E

# temporarily for BaseBenchmarks
RangeGenerator(x) = Sampler(GLOBAL_RNG, x)
RangeGenerator(x) = Sampler(get_local_rng(), x)

# In some cases, when only 1 random value is to be generated,
# the optimal sampler can be different than if multiple values
Expand Down Expand Up @@ -247,18 +247,18 @@ rand(rng::AbstractRNG, ::UniformT{T}) where {T} = rand(rng, T)

#### scalars

rand(rng::AbstractRNG, X) = rand(rng, Sampler(rng, X, Val(1)))
rand(rng::AbstractRNG, X) = rand(rng, Sampler(rng, X, Val(1)))
# this is needed to disambiguate
rand(rng::AbstractRNG, X::Dims) = rand(rng, Sampler(rng, X, Val(1)))
rand(rng::AbstractRNG=GLOBAL_RNG, ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))
rand(rng::AbstractRNG, X::Dims) = rand(rng, Sampler(rng, X, Val(1)))
rand(rng::AbstractRNG=get_local_rng(), ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))

rand(X) = rand(GLOBAL_RNG, X)
rand(::Type{X}) where {X} = rand(GLOBAL_RNG, X)
rand(X) = rand(get_local_rng(), X)
rand(::Type{X}) where {X} = rand(get_local_rng(), X)

#### arrays

rand!(A::AbstractArray{T}, X) where {T} = rand!(GLOBAL_RNG, A, X)
rand!(A::AbstractArray{T}, ::Type{X}=T) where {T,X} = rand!(GLOBAL_RNG, A, X)
rand!(A::AbstractArray{T}, X) where {T} = rand!(get_local_rng(), A, X)
rand!(A::AbstractArray{T}, ::Type{X}=T) where {T,X} = rand!(get_local_rng(), A, X)

rand!(rng::AbstractRNG, A::AbstractArray{T}, X) where {T} = rand!(rng, A, Sampler(rng, X))
rand!(rng::AbstractRNG, A::AbstractArray{T}, ::Type{X}=T) where {T,X} = rand!(rng, A, Sampler(rng, X))
Expand All @@ -274,7 +274,7 @@ rand(r::AbstractRNG, dims::Integer...) = rand(r, Float64, Dims(dims))
rand( dims::Integer...) = rand(Float64, Dims(dims))

rand(r::AbstractRNG, X, dims::Dims) = rand!(r, Array{gentype(X)}(undef, dims), X)
rand( X, dims::Dims) = rand(GLOBAL_RNG, X, dims)
rand( X, dims::Dims) = rand(get_local_rng(), X, dims)

rand(r::AbstractRNG, X, d::Integer, dims::Integer...) = rand(r, X, Dims((d, dims...)))
rand( X, d::Integer, dims::Integer...) = rand(X, Dims((d, dims...)))
Expand All @@ -283,23 +283,12 @@ rand( X, d::Integer, dims::Integer...) = rand(X, Dims((d, dims...
# moreover, a call like rand(r, NotImplementedType()) would be an infinite loop

rand(r::AbstractRNG, ::Type{X}, dims::Dims) where {X} = rand!(r, Array{X}(undef, dims), X)
rand( ::Type{X}, dims::Dims) where {X} = rand(GLOBAL_RNG, X, dims)
rand( ::Type{X}, dims::Dims) where {X} = rand(get_local_rng(), X, dims)

rand(r::AbstractRNG, ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(r, X, Dims((d, dims...)))
rand( ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(X, Dims((d, dims...)))


## __init__ & include

function __init__()
try
seed!()
catch ex
Base.showerror_nostdio(ex,
"WARNING: Error during initialization of module Random")
end
end

include("RNGs.jl")
include("generation.jl")
include("normal.jl")
Expand Down Expand Up @@ -382,6 +371,9 @@ don't accept a seed, like `RandomDevice`.
After the call to `seed!`, `rng` is equivalent to a newly created
object initialized with the same seed.

If `rng` is not specified, it defaults to seeding the state of the
shared thread-local generator.

# Examples
```julia-repl
julia> Random.seed!(1234);
Expand Down
20 changes: 10 additions & 10 deletions stdlib/Random/src/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,8 @@ let b = UInt8['0':'9';'A':'Z';'a':'z']
global randstring
randstring(r::AbstractRNG, chars=b, n::Integer=8) = String(rand(r, chars, n))
randstring(r::AbstractRNG, n::Integer) = randstring(r, b, n)
randstring(chars=b, n::Integer=8) = randstring(GLOBAL_RNG, chars, n)
randstring(n::Integer) = randstring(GLOBAL_RNG, b, n)
randstring(chars=b, n::Integer=8) = randstring(get_local_rng(), chars, n)
randstring(n::Integer) = randstring(get_local_rng(), b, n)
end


Expand Down Expand Up @@ -140,7 +140,7 @@ julia> S
8
```
"""
randsubseq!(S::AbstractArray, A::AbstractArray, p::Real) = randsubseq!(GLOBAL_RNG, S, A, p)
randsubseq!(S::AbstractArray, A::AbstractArray, p::Real) = randsubseq!(get_local_rng(), S, A, p)

randsubseq(r::AbstractRNG, A::AbstractArray{T}, p::Real) where {T} =
randsubseq!(r, T[], A, p)
Expand All @@ -163,7 +163,7 @@ julia> randsubseq(rng, collect(1:8), 0.3)
8
```
"""
randsubseq(A::AbstractArray, p::Real) = randsubseq(GLOBAL_RNG, A, p)
randsubseq(A::AbstractArray, p::Real) = randsubseq(get_local_rng(), A, p)


## rand Less Than Masked 52 bits (helper function)
Expand Down Expand Up @@ -217,7 +217,7 @@ function shuffle!(r::AbstractRNG, a::AbstractArray)
return a
end

shuffle!(a::AbstractArray) = shuffle!(GLOBAL_RNG, a)
shuffle!(a::AbstractArray) = shuffle!(get_local_rng(), a)

"""
shuffle([rng=GLOBAL_RNG,] v::AbstractArray)
Expand Down Expand Up @@ -246,7 +246,7 @@ julia> shuffle(rng, Vector(1:10))
```
"""
shuffle(r::AbstractRNG, a::AbstractArray) = shuffle!(r, copymutable(a))
shuffle(a::AbstractArray) = shuffle(GLOBAL_RNG, a)
shuffle(a::AbstractArray) = shuffle(get_local_rng(), a)


## randperm & randperm!
Expand Down Expand Up @@ -277,7 +277,7 @@ julia> randperm(MersenneTwister(1234), 4)
```
"""
randperm(r::AbstractRNG, n::T) where {T <: Integer} = randperm!(r, Vector{T}(undef, n))
randperm(n::Integer) = randperm(GLOBAL_RNG, n)
randperm(n::Integer) = randperm(get_local_rng(), n)

"""
randperm!([rng=GLOBAL_RNG,] A::Array{<:Integer})
Expand Down Expand Up @@ -314,7 +314,7 @@ function randperm!(r::AbstractRNG, a::Array{<:Integer})
return a
end

randperm!(a::Array{<:Integer}) = randperm!(GLOBAL_RNG, a)
randperm!(a::Array{<:Integer}) = randperm!(get_local_rng(), a)


## randcycle & randcycle!
Expand Down Expand Up @@ -343,7 +343,7 @@ julia> randcycle(MersenneTwister(1234), 6)
```
"""
randcycle(r::AbstractRNG, n::T) where {T <: Integer} = randcycle!(r, Vector{T}(undef, n))
randcycle(n::Integer) = randcycle(GLOBAL_RNG, n)
randcycle(n::Integer) = randcycle(get_local_rng(), n)

"""
randcycle!([rng=GLOBAL_RNG,] A::Array{<:Integer})
Expand Down Expand Up @@ -379,4 +379,4 @@ function randcycle!(r::AbstractRNG, a::Array{<:Integer})
return a
end

randcycle!(a::Array{<:Integer}) = randcycle!(GLOBAL_RNG, a)
randcycle!(a::Array{<:Integer}) = randcycle!(get_local_rng(), a)
16 changes: 8 additions & 8 deletions stdlib/Random/src/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ julia> randn(rng, ComplexF32, (2, 3))
0.611224+1.56403im 0.355204-0.365563im 0.0905552+1.31012im
```
"""
@inline function randn(rng::AbstractRNG=GLOBAL_RNG)
@inline function randn(rng::AbstractRNG=get_local_rng())
@inbounds begin
r = rand(rng, UInt52())
rabs = Int64(r>>1) # One bit for the sign
Expand Down Expand Up @@ -95,7 +95,7 @@ julia> randexp(rng, 3, 3)
0.695867 0.693292 0.643644
```
"""
function randexp(rng::AbstractRNG=GLOBAL_RNG)
function randexp(rng::AbstractRNG=get_local_rng())
@inbounds begin
ri = rand(rng, UInt52())
idx = ri & 0xFF
Expand Down Expand Up @@ -165,7 +165,7 @@ for randfun in [:randn, :randexp]
@eval begin
# scalars
$randfun(rng::AbstractRNG, T::BitFloatType) = convert(T, $randfun(rng))
$randfun(::Type{T}) where {T} = $randfun(GLOBAL_RNG, T)
$randfun(::Type{T}) where {T} = $randfun(get_local_rng(), T)

# filling arrays
function $randfun!(rng::AbstractRNG, A::AbstractArray{T}) where T
Expand All @@ -175,19 +175,19 @@ for randfun in [:randn, :randexp]
A
end

$randfun!(A::AbstractArray) = $randfun!(GLOBAL_RNG, A)
$randfun!(A::AbstractArray) = $randfun!(get_local_rng(), A)

# generating arrays
$randfun(rng::AbstractRNG, ::Type{T}, dims::Dims ) where {T} = $randfun!(rng, Array{T}(undef, dims))
# Note that this method explicitly does not define $randfun(rng, T),
# in order to prevent an infinite recursion.
$randfun(rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...) where {T} = $randfun!(rng, Array{T}(undef, dim1, dims...))
$randfun( ::Type{T}, dims::Dims ) where {T} = $randfun(GLOBAL_RNG, T, dims)
$randfun( ::Type{T}, dims::Integer... ) where {T} = $randfun(GLOBAL_RNG, T, dims...)
$randfun( ::Type{T}, dims::Dims ) where {T} = $randfun(get_local_rng(), T, dims)
$randfun( ::Type{T}, dims::Integer... ) where {T} = $randfun(get_local_rng(), T, dims...)
$randfun(rng::AbstractRNG, dims::Dims ) = $randfun(rng, Float64, dims)
$randfun(rng::AbstractRNG, dims::Integer... ) = $randfun(rng, Float64, dims...)
$randfun( dims::Dims ) = $randfun(GLOBAL_RNG, Float64, dims)
$randfun( dims::Integer... ) = $randfun(GLOBAL_RNG, Float64, dims...)
$randfun( dims::Dims ) = $randfun(get_local_rng(), Float64, dims)
$randfun( dims::Integer... ) = $randfun(get_local_rng(), Float64, dims...)
end
end

Expand Down
59 changes: 58 additions & 1 deletion stdlib/Random/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ end

# Random.seed!(rng, ...) returns rng (#21248)
guardseed() do
g = Random.GLOBAL_RNG
g = Random.get_local_rng()
m = MersenneTwister(0)
@test Random.seed!() === g
@test Random.seed!(rand(UInt)) === g
Expand Down Expand Up @@ -713,3 +713,60 @@ end
@test rand((x, 2, 3, 4, 6)) ∈ 1:6
end
end

@testset "GLOBAL_RNG" begin
local GLOBAL_RNG = Random.GLOBAL_RNG
local LOCAL_RNG = Random.get_local_rng()
@test VERSION < v"2" # deprecate this in v2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _GLOBAL_RNG thing is rather clever! Do I get it right that it's only to avoid breaking code and that it's intended to be removed in v2?


@test Random.seed!(GLOBAL_RNG, nothing) === LOCAL_RNG
@test Random.seed!(GLOBAL_RNG, UInt32[0]) === LOCAL_RNG
@test Random.seed!(GLOBAL_RNG, 0) === LOCAL_RNG

mt = MersenneTwister(1)
@test copy!(mt, GLOBAL_RNG) === mt
@test mt == LOCAL_RNG
Random.seed!(mt, 2)
@test mt != LOCAL_RNG
@test copy!(GLOBAL_RNG, mt) === LOCAL_RNG
@test mt == LOCAL_RNG
mt2 = copy(GLOBAL_RNG)
@test mt2 isa typeof(LOCAL_RNG)
@test mt2 !== LOCAL_RNG
@test mt2 == LOCAL_RNG

for T in (Random.UInt52Raw{UInt64},
Random.UInt2x52Raw{UInt128},
Random.UInt104Raw{UInt128},
Random.CloseOpen12_64)
x = Random.SamplerTrivial(T())
@test rand(GLOBAL_RNG, x) === rand(mt, x)
end
for T in (Int64, UInt64, Int128, UInt128, Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)
x = Random.SamplerType{T}()
@test rand(GLOBAL_RNG, x) === rand(mt, x)
end

A = fill(0.0, 100, 100)
B = fill(1.0, 100, 100)
vA = view(A, :, :)
vB = view(B, :, :)
I1 = Random.SamplerTrivial(Random.CloseOpen01{Float64}())
I2 = Random.SamplerTrivial(Random.CloseOpen12{Float64}())
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(mt, B, I1) === B
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, vA, I1) === vA
rand!(mt, vB, I1)
@test A == B
for T in (Float16, Float32)
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, A, I2) === A == rand!(mt, B, I2) === B
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(mt, B, I1) === B
end
for T in Base.BitInteger_types
x = Random.SamplerType{T}()
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, A, x) === A == rand!(mt, B, x) === B
end
end
2 changes: 1 addition & 1 deletion stdlib/SparseArrays/src/SparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
rotl90, rotr90, round, setindex!, similar, size, transpose,
vec, permute!, map, map!, Array, diff, circshift!, circshift

using Random: GLOBAL_RNG, AbstractRNG, randsubseq, randsubseq!
using Random: get_local_rng, AbstractRNG, randsubseq, randsubseq!

export AbstractSparseArray, AbstractSparseMatrix, AbstractSparseVector,
SparseMatrixCSC, SparseVector, blockdiag, droptol!, dropzeros!, dropzeros,
Expand Down
Loading