Skip to content

Commit

Permalink
Use RandomNumbers.jl to provide rand() for other types.
Browse files (browse the repository at this point in the history)
  • Loading branch information
maleadt committed Mar 24, 2021
1 parent 466943c commit e5fa433
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 34 deletions.
6 changes: 6 additions & 0 deletions Manifest.toml
Expand Up @@ -190,6 +190,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[RandomNumbers]]
deps = ["Random", "Requires"]
git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f"
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
version = "1.4.0"

[[Reexport]]
git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Expand Up @@ -21,6 +21,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
30 changes: 8 additions & 22 deletions src/device/intrinsics/random.jl
@@ -1,6 +1,7 @@
## random number generation

using Random
import RandomNumbers


# helpers
Expand All @@ -11,7 +12,7 @@ global_index() = (threadIdx().x, threadIdx().y, threadIdx().z,

# global state

struct ThreadLocalRNG <: AbstractRNG
struct ThreadLocalXorshift32 <: RandomNumbers.AbstractRNG{UInt32}
vals::CuDeviceArray{UInt32, 6, AS.Generic}
end

Expand All @@ -37,13 +38,13 @@ end
CuDeviceArray(dims, ptr)
end

@device_override Random.default_rng() = ThreadLocalRNG(global_random_state())
@device_override Random.default_rng() = ThreadLocalXorshift32(global_random_state())

@device_override Random.make_seed() = clock(UInt32)

function Random.seed!(rng::ThreadLocalRNG, seed::Integer)
function Random.seed!(rng::ThreadLocalXorshift32, seed::Integer)
index = global_index()
rng.vals[index...] = seed
rng.vals[index...] = seed % UInt32
return
end

Expand All @@ -57,29 +58,14 @@ function xorshift(x::UInt32)::UInt32
return x
end

function get_thread_word(rng::ThreadLocalRNG)
function Random.rand(rng::ThreadLocalXorshift32, ::Type{UInt32})
# NOTE: we add the current linear index to the local state, to make sure threads get
# different random numbers when unseeded (initial state = 0 for all threads)
index = global_index()
offset = LinearIndices(rng.vals)[index...]
state = rng.vals[index...] + UInt32(offset)

new_state = generate_next_state(state)
new_state = xorshift(state)
rng.vals[index...] = new_state

return new_state # FIXME: return old state?
end

function generate_next_state(state::UInt32)
new_val = xorshift(state)
return UInt32(new_val)
end

# TODO: support for more types (can we reuse more of the Random standard library?)
# see RandomNumbers.jl

function Random.rand(rng::ThreadLocalRNG, ::Type{Float32})
word = get_thread_word(rng)
res = (word >> 9) | reinterpret(UInt32, 1f0)
return reinterpret(Float32, res) - 1.0f0
return new_state
end
32 changes: 20 additions & 12 deletions test/device/intrinsics.jl
Expand Up @@ -1197,15 +1197,20 @@ n = 256
return nothing
end

a = CUDA.zeros(Float32, n)
b = CUDA.zeros(Float32, n)
@testset for T in (Int32, UInt32, Int64, UInt64, Int128, UInt128,
Float32, Float64)
a = CUDA.zeros(T, n)
b = CUDA.zeros(T, n)

@cuda threads=n kernel(a, b)
@cuda threads=n kernel(a, b)

# FIXME(?): each of these tests has a (very small) chance of failing, but that is somewhat inherent to calling rand() without a seed
@test allunique(Array(a))
@test allunique(Array(b))
@test Array(a) != Array(b)
@test all(Array(a) .!= Array(b))

if T == Float64
@test allunique(Array(a))
@test allunique(Array(b))
end
end
end

@testset "custom seed" begin
Expand All @@ -1216,13 +1221,16 @@ end
return nothing
end

a = CUDA.zeros(Float32, n)
b = CUDA.zeros(Float32, n)
@testset for T in (Int32, UInt32, Int64, UInt64, Int128, UInt128,
Float32, Float64)
a = CUDA.zeros(T, n)
b = CUDA.zeros(T, n)

@cuda threads=n kernel(a)
@cuda threads=n kernel(b)
@cuda threads=n kernel(a)
@cuda threads=n kernel(b)

@test Array(a) == Array(b)
@test Array(a) == Array(b)
end
end

end

0 comments on commit e5fa433

Please sign in to comment.