Skip to content

Commit

Permalink
Random,threads: allocate state at runtime for each thread
Browse files Browse the repository at this point in the history
The `Random.GLOBAL_RNG` is now a singleton placeholder object
which implements the prior `Random` public API for MersenneTwister
as a shim to support existing clients until Julia v2.0.
  • Loading branch information
vtjnash authored and JeffBezanson committed Jul 17, 2019
1 parent 0a12944 commit 84eca18
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 82 deletions.
70 changes: 61 additions & 9 deletions stdlib/Random/src/RNGs.jl
Expand Up @@ -2,9 +2,9 @@

## RandomDevice

# SamplerUnion(Union{X,Y,...}) == Union{SamplerType{X},SamplerType{Y},...}
SamplerUnion(U::Union) = Union{map(T->SamplerType{T}, Base.uniontypes(U))...}
const SamplerBoolBitInteger = SamplerUnion(Union{Bool, BitInteger})
# SamplerUnion(X, Y, ...}) == Union{SamplerType{X}, SamplerType{Y}, ...}
SamplerUnion(U...) = Union{Any[SamplerType{T} for T in U]...}
const SamplerBoolBitInteger = SamplerUnion(Bool, BitInteger_types...)

if Sys.iswindows()
struct RandomDevice <: AbstractRNG
Expand Down Expand Up @@ -285,14 +285,66 @@ function seed!(r::MersenneTwister, seed::Vector{UInt32})
return r
end

seed!(r::MersenneTwister=GLOBAL_RNG) = seed!(r, make_seed())
seed!(r::MersenneTwister=get_local_rng()) = seed!(r, make_seed())
seed!(r::MersenneTwister, n::Integer) = seed!(r, make_seed(n))
seed!(seed::Union{Integer,Vector{UInt32}}) = seed!(GLOBAL_RNG, seed)
seed!(seed::Union{Integer,Vector{UInt32}}) = seed!(get_local_rng(), seed)


### Global RNG (must be defined after seed!)
### Global RNG

const GLOBAL_RNG = MersenneTwister(0)
const THREAD_RNGs = MersenneTwister[]
@inline get_local_rng() = get_local_rng(Threads.threadid())
@noinline Base.@pure function get_local_rng(tid::Int)
#tls = task_local_storage()
#RNG = get(tls, :RNG, nothing)
#RNG isa MersenneTwister && return RNG
if length(THREAD_RNGs) < tid
resize!(THREAD_RNGs, Threads.nthreads())
end
if @inbounds isassigned(THREAD_RNGs, tid)
@inbounds MT = THREAD_RNGs[tid]
else
MT = MersenneTwister()
@inbounds THREAD_RNGs[tid] = MT
end
return MT
end
__init__() = empty!(THREAD_RNGs) # ensures that we didn't save a bad object


struct _GLOBAL_RNG <: AbstractRNG
global const GLOBAL_RNG = _GLOBAL_RNG.instance
end

copy!(dst::MersenneTwister, ::_GLOBAL_RNG) = copy!(dst, get_local_rng())
copy!(::_GLOBAL_RNG, src::MersenneTwister) = copy!(get_local_rng(), src)
copy(::_GLOBAL_RNG) = copy(get_local_rng())

seed!(::_GLOBAL_RNG, seed::Vector{UInt32}) = seed!(get_local_rng(), seed)
seed!(::_GLOBAL_RNG, n::Integer) = seed!(get_local_rng(), n)
seed!(::_GLOBAL_RNG, ::Nothing) = seed!(get_local_rng(), nothing)

rng_native_52(::_GLOBAL_RNG) = rng_native_52(get_local_rng())
rand(::_GLOBAL_RNG, sp::SamplerBoolBitInteger) = rand(get_local_rng(), sp)
for T in (:(SamplerTrivial{UInt52Raw{UInt64}}),
:(SamplerTrivial{UInt2x52Raw{UInt128}}),
:(SamplerTrivial{UInt104Raw{UInt128}}),
:(SamplerTrivial{CloseOpen12_64}),
:(SamplerUnion(Int64, UInt64, Int128, UInt128)),
:(SamplerUnion(Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)),
)
@eval rand(::_GLOBAL_RNG, x::$T) = rand(get_local_rng(), x)
end

rand!(::_GLOBAL_RNG, A::AbstractArray{Float64}, I::SamplerTrivial{<:FloatInterval_64}) = rand!(get_local_rng(), A, I)
rand!(::_GLOBAL_RNG, A::Array{Float64}, I::SamplerTrivial{<:FloatInterval_64}) = rand!(get_local_rng(), A, I)
for T in (Float16, Float32)
@eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerTrivial{CloseOpen12{$T}}) = rand!(get_local_rng(), A, I)
@eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerTrivial{CloseOpen01{$T}}) = rand!(get_local_rng(), A, I)
end
for T in BitInteger_types
@eval rand!(::_GLOBAL_RNG, A::Array{$T}, I::SamplerType{$T}) = rand!(get_local_rng(), A, I)
end


### generation
Expand Down Expand Up @@ -332,10 +384,10 @@ rand(r::MersenneTwister, sp::SamplerTrivial{CloseOpen12_64}) =

#### integers

rand(r::MersenneTwister, T::SamplerUnion(Union{Int64,UInt64,Int128,UInt128})) =
rand(r::MersenneTwister, T::SamplerUnion(Int64, UInt64, Int128, UInt128)) =
mt_pop!(r, T[])

rand(r::MersenneTwister, T::SamplerUnion(Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32})) =
rand(r::MersenneTwister, T::SamplerUnion(Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)) =
rand(r, UInt52Raw()) % T[]

#### arrays of floats
Expand Down
31 changes: 10 additions & 21 deletions stdlib/Random/src/Random.jl
Expand Up @@ -108,7 +108,7 @@ abstract type Sampler{E} end
gentype(::Type{<:Sampler{E}}) where {E} = E

# temporarily for BaseBenchmarks
RangeGenerator(x) = Sampler(GLOBAL_RNG, x)
RangeGenerator(x) = Sampler(get_local_rng(), x)

# In some cases, when only 1 random value is to be generated,
# the optimal sampler can be different than if multiple values
Expand Down Expand Up @@ -247,18 +247,18 @@ rand(rng::AbstractRNG, ::UniformT{T}) where {T} = rand(rng, T)

#### scalars

rand(rng::AbstractRNG, X) = rand(rng, Sampler(rng, X, Val(1)))
rand(rng::AbstractRNG, X) = rand(rng, Sampler(rng, X, Val(1)))
# this is needed to disambiguate
rand(rng::AbstractRNG, X::Dims) = rand(rng, Sampler(rng, X, Val(1)))
rand(rng::AbstractRNG=GLOBAL_RNG, ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))
rand(rng::AbstractRNG, X::Dims) = rand(rng, Sampler(rng, X, Val(1)))
rand(rng::AbstractRNG=get_local_rng(), ::Type{X}=Float64) where {X} = rand(rng, Sampler(rng, X, Val(1)))

rand(X) = rand(GLOBAL_RNG, X)
rand(::Type{X}) where {X} = rand(GLOBAL_RNG, X)
rand(X) = rand(get_local_rng(), X)
rand(::Type{X}) where {X} = rand(get_local_rng(), X)

#### arrays

rand!(A::AbstractArray{T}, X) where {T} = rand!(GLOBAL_RNG, A, X)
rand!(A::AbstractArray{T}, ::Type{X}=T) where {T,X} = rand!(GLOBAL_RNG, A, X)
rand!(A::AbstractArray{T}, X) where {T} = rand!(get_local_rng(), A, X)
rand!(A::AbstractArray{T}, ::Type{X}=T) where {T,X} = rand!(get_local_rng(), A, X)

rand!(rng::AbstractRNG, A::AbstractArray{T}, X) where {T} = rand!(rng, A, Sampler(rng, X))
rand!(rng::AbstractRNG, A::AbstractArray{T}, ::Type{X}=T) where {T,X} = rand!(rng, A, Sampler(rng, X))
Expand All @@ -274,7 +274,7 @@ rand(r::AbstractRNG, dims::Integer...) = rand(r, Float64, Dims(dims))
rand( dims::Integer...) = rand(Float64, Dims(dims))

rand(r::AbstractRNG, X, dims::Dims) = rand!(r, Array{gentype(X)}(undef, dims), X)
rand( X, dims::Dims) = rand(GLOBAL_RNG, X, dims)
rand( X, dims::Dims) = rand(get_local_rng(), X, dims)

rand(r::AbstractRNG, X, d::Integer, dims::Integer...) = rand(r, X, Dims((d, dims...)))
rand( X, d::Integer, dims::Integer...) = rand(X, Dims((d, dims...)))
Expand All @@ -283,23 +283,12 @@ rand( X, d::Integer, dims::Integer...) = rand(X, Dims((d, dims...
# moreover, a call like rand(r, NotImplementedType()) would be an infinite loop

rand(r::AbstractRNG, ::Type{X}, dims::Dims) where {X} = rand!(r, Array{X}(undef, dims), X)
rand( ::Type{X}, dims::Dims) where {X} = rand(GLOBAL_RNG, X, dims)
rand( ::Type{X}, dims::Dims) where {X} = rand(get_local_rng(), X, dims)

rand(r::AbstractRNG, ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(r, X, Dims((d, dims...)))
rand( ::Type{X}, d::Integer, dims::Integer...) where {X} = rand(X, Dims((d, dims...)))


## __init__ & include

function __init__()
try
seed!()
catch ex
Base.showerror_nostdio(ex,
"WARNING: Error during initialization of module Random")
end
end

include("RNGs.jl")
include("generation.jl")
include("normal.jl")
Expand Down
20 changes: 10 additions & 10 deletions stdlib/Random/src/misc.jl
Expand Up @@ -73,8 +73,8 @@ let b = UInt8['0':'9';'A':'Z';'a':'z']
global randstring
randstring(r::AbstractRNG, chars=b, n::Integer=8) = String(rand(r, chars, n))
randstring(r::AbstractRNG, n::Integer) = randstring(r, b, n)
randstring(chars=b, n::Integer=8) = randstring(GLOBAL_RNG, chars, n)
randstring(n::Integer) = randstring(GLOBAL_RNG, b, n)
randstring(chars=b, n::Integer=8) = randstring(get_local_rng(), chars, n)
randstring(n::Integer) = randstring(get_local_rng(), b, n)
end


Expand Down Expand Up @@ -140,7 +140,7 @@ julia> S
8
```
"""
randsubseq!(S::AbstractArray, A::AbstractArray, p::Real) = randsubseq!(GLOBAL_RNG, S, A, p)
randsubseq!(S::AbstractArray, A::AbstractArray, p::Real) = randsubseq!(get_local_rng(), S, A, p)

randsubseq(r::AbstractRNG, A::AbstractArray{T}, p::Real) where {T} =
randsubseq!(r, T[], A, p)
Expand All @@ -163,7 +163,7 @@ julia> randsubseq(rng, collect(1:8), 0.3)
8
```
"""
randsubseq(A::AbstractArray, p::Real) = randsubseq(GLOBAL_RNG, A, p)
randsubseq(A::AbstractArray, p::Real) = randsubseq(get_local_rng(), A, p)


## rand Less Than Masked 52 bits (helper function)
Expand Down Expand Up @@ -217,7 +217,7 @@ function shuffle!(r::AbstractRNG, a::AbstractArray)
return a
end

shuffle!(a::AbstractArray) = shuffle!(GLOBAL_RNG, a)
shuffle!(a::AbstractArray) = shuffle!(get_local_rng(), a)

"""
shuffle([rng=GLOBAL_RNG,] v::AbstractArray)
Expand Down Expand Up @@ -246,7 +246,7 @@ julia> shuffle(rng, Vector(1:10))
```
"""
shuffle(r::AbstractRNG, a::AbstractArray) = shuffle!(r, copymutable(a))
shuffle(a::AbstractArray) = shuffle(GLOBAL_RNG, a)
shuffle(a::AbstractArray) = shuffle(get_local_rng(), a)


## randperm & randperm!
Expand Down Expand Up @@ -277,7 +277,7 @@ julia> randperm(MersenneTwister(1234), 4)
```
"""
randperm(r::AbstractRNG, n::T) where {T <: Integer} = randperm!(r, Vector{T}(undef, n))
randperm(n::Integer) = randperm(GLOBAL_RNG, n)
randperm(n::Integer) = randperm(get_local_rng(), n)

"""
randperm!([rng=GLOBAL_RNG,] A::Array{<:Integer})
Expand Down Expand Up @@ -314,7 +314,7 @@ function randperm!(r::AbstractRNG, a::Array{<:Integer})
return a
end

randperm!(a::Array{<:Integer}) = randperm!(GLOBAL_RNG, a)
randperm!(a::Array{<:Integer}) = randperm!(get_local_rng(), a)


## randcycle & randcycle!
Expand Down Expand Up @@ -343,7 +343,7 @@ julia> randcycle(MersenneTwister(1234), 6)
```
"""
randcycle(r::AbstractRNG, n::T) where {T <: Integer} = randcycle!(r, Vector{T}(undef, n))
randcycle(n::Integer) = randcycle(GLOBAL_RNG, n)
randcycle(n::Integer) = randcycle(get_local_rng(), n)

"""
randcycle!([rng=GLOBAL_RNG,] A::Array{<:Integer})
Expand Down Expand Up @@ -379,4 +379,4 @@ function randcycle!(r::AbstractRNG, a::Array{<:Integer})
return a
end

randcycle!(a::Array{<:Integer}) = randcycle!(GLOBAL_RNG, a)
randcycle!(a::Array{<:Integer}) = randcycle!(get_local_rng(), a)
16 changes: 8 additions & 8 deletions stdlib/Random/src/normal.jl
Expand Up @@ -35,7 +35,7 @@ julia> randn(rng, ComplexF32, (2, 3))
0.611224+1.56403im 0.355204-0.365563im 0.0905552+1.31012im
```
"""
@inline function randn(rng::AbstractRNG=GLOBAL_RNG)
@inline function randn(rng::AbstractRNG=get_local_rng())
@inbounds begin
r = rand(rng, UInt52())
rabs = Int64(r>>1) # One bit for the sign
Expand Down Expand Up @@ -95,7 +95,7 @@ julia> randexp(rng, 3, 3)
0.695867 0.693292 0.643644
```
"""
function randexp(rng::AbstractRNG=GLOBAL_RNG)
function randexp(rng::AbstractRNG=get_local_rng())
@inbounds begin
ri = rand(rng, UInt52())
idx = ri & 0xFF
Expand Down Expand Up @@ -165,7 +165,7 @@ for randfun in [:randn, :randexp]
@eval begin
# scalars
$randfun(rng::AbstractRNG, T::BitFloatType) = convert(T, $randfun(rng))
$randfun(::Type{T}) where {T} = $randfun(GLOBAL_RNG, T)
$randfun(::Type{T}) where {T} = $randfun(get_local_rng(), T)

# filling arrays
function $randfun!(rng::AbstractRNG, A::AbstractArray{T}) where T
Expand All @@ -175,19 +175,19 @@ for randfun in [:randn, :randexp]
A
end

$randfun!(A::AbstractArray) = $randfun!(GLOBAL_RNG, A)
$randfun!(A::AbstractArray) = $randfun!(get_local_rng(), A)

# generating arrays
$randfun(rng::AbstractRNG, ::Type{T}, dims::Dims ) where {T} = $randfun!(rng, Array{T}(undef, dims))
# Note that this method explicitly does not define $randfun(rng, T),
# in order to prevent an infinite recursion.
$randfun(rng::AbstractRNG, ::Type{T}, dim1::Integer, dims::Integer...) where {T} = $randfun!(rng, Array{T}(undef, dim1, dims...))
$randfun( ::Type{T}, dims::Dims ) where {T} = $randfun(GLOBAL_RNG, T, dims)
$randfun( ::Type{T}, dims::Integer... ) where {T} = $randfun(GLOBAL_RNG, T, dims...)
$randfun( ::Type{T}, dims::Dims ) where {T} = $randfun(get_local_rng(), T, dims)
$randfun( ::Type{T}, dims::Integer... ) where {T} = $randfun(get_local_rng(), T, dims...)
$randfun(rng::AbstractRNG, dims::Dims ) = $randfun(rng, Float64, dims)
$randfun(rng::AbstractRNG, dims::Integer... ) = $randfun(rng, Float64, dims...)
$randfun( dims::Dims ) = $randfun(GLOBAL_RNG, Float64, dims)
$randfun( dims::Integer... ) = $randfun(GLOBAL_RNG, Float64, dims...)
$randfun( dims::Dims ) = $randfun(get_local_rng(), Float64, dims)
$randfun( dims::Integer... ) = $randfun(get_local_rng(), Float64, dims...)
end
end

Expand Down
59 changes: 58 additions & 1 deletion stdlib/Random/test/runtests.jl
Expand Up @@ -601,7 +601,7 @@ end

# Random.seed!(rng, ...) returns rng (#21248)
guardseed() do
g = Random.GLOBAL_RNG
g = Random.get_local_rng()
m = MersenneTwister(0)
@test Random.seed!() === g
@test Random.seed!(rand(UInt)) === g
Expand Down Expand Up @@ -713,3 +713,60 @@ end
@test rand((x, 2, 3, 4, 6)) 1:6
end
end

@testset "GLOBAL_RNG" begin
local GLOBAL_RNG = Random.GLOBAL_RNG
local LOCAL_RNG = Random.get_local_rng()
@test VERSION < v"2" # deprecate this in v2

@test Random.seed!(GLOBAL_RNG, nothing) === LOCAL_RNG
@test Random.seed!(GLOBAL_RNG, UInt32[0]) === LOCAL_RNG
@test Random.seed!(GLOBAL_RNG, 0) === LOCAL_RNG

mt = MersenneTwister(1)
@test copy!(mt, GLOBAL_RNG) === mt
@test mt == LOCAL_RNG
Random.seed!(mt, 2)
@test mt != LOCAL_RNG
@test copy!(GLOBAL_RNG, mt) === LOCAL_RNG
@test mt == LOCAL_RNG
mt2 = copy(GLOBAL_RNG)
@test mt2 isa typeof(LOCAL_RNG)
@test mt2 !== LOCAL_RNG
@test mt2 == LOCAL_RNG

for T in (Random.UInt52Raw{UInt64},
Random.UInt2x52Raw{UInt128},
Random.UInt104Raw{UInt128},
Random.CloseOpen12_64)
x = Random.SamplerTrivial(T())
@test rand(GLOBAL_RNG, x) === rand(mt, x)
end
for T in (Int64, UInt64, Int128, UInt128, Bool, Int8, UInt8, Int16, UInt16, Int32, UInt32)
x = Random.SamplerType{T}()
@test rand(GLOBAL_RNG, x) === rand(mt, x)
end

A = fill(0.0, 100, 100)
B = fill(1.0, 100, 100)
vA = view(A, :, :)
vB = view(B, :, :)
I1 = Random.SamplerTrivial(Random.CloseOpen01{Float64}())
I2 = Random.SamplerTrivial(Random.CloseOpen12{Float64}())
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(mt, B, I1) === B
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, vA, I1) === vA
rand!(mt, vB, I1)
@test A == B
for T in (Float16, Float32)
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, A, I2) === A == rand!(mt, B, I2) === B
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, A, I1) === A == rand!(mt, B, I1) === B
end
for T in Base.BitInteger_types
x = Random.SamplerType{T}()
B = fill!(B, 1.0)
@test rand!(GLOBAL_RNG, A, x) === A == rand!(mt, B, x) === B
end
end
2 changes: 1 addition & 1 deletion stdlib/SparseArrays/src/SparseArrays.jl
Expand Up @@ -29,7 +29,7 @@ import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
rotl90, rotr90, round, setindex!, similar, size, transpose,
vec, permute!, map, map!, Array, diff, circshift!, circshift

using Random: GLOBAL_RNG, AbstractRNG, randsubseq, randsubseq!
using Random: get_local_rng, AbstractRNG, randsubseq, randsubseq!

export AbstractSparseArray, AbstractSparseMatrix, AbstractSparseVector,
SparseMatrixCSC, SparseVector, blockdiag, droptol!, dropzeros!, dropzeros,
Expand Down

0 comments on commit 84eca18

Please sign in to comment.