Skip to content

Commit

Permalink
Keep the shared memory key to support re-seeding.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Aug 15, 2023
1 parent 2788a9c commit 2e00444
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions src/device/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@ import RandomNumbers

# XXX: sharing state means that with multiple generators we can't guarantee determinism.

# shared memory with the actual seed (the RNG key), one entry per warp, initialized
# upon RNG construction or by calling `seed!`
#
# Returns a 32-element `CuDeviceArray{UInt32,1,AS.Shared}` wrapping the zero-initialized
# `@global_random_keys` LLVM global that is declared (with `weak` linkage) in the IR below.
# NOTE(review): 32 slots presumably matches the maximum number of warps per block — confirm.
@eval @inline function global_random_keys()
# `@eval` together with `$(...)` builds the (IR, entry-point name) tuple when this
# definition is evaluated, with `AS.Shared` interpolated into every `addrspace(...)`,
# so that `llvmcall` receives it as a literal.
ptr = Base.llvmcall(
$("""@global_random_keys = weak addrspace($(AS.Shared)) global [32 x i32] zeroinitializer, align 32
define i8 addrspace($(AS.Shared))* @entry() #0 {
%ptr = getelementptr inbounds [32 x i32], [32 x i32] addrspace($(AS.Shared))* @global_random_keys, i64 0, i64 0
%untyped_ptr = bitcast i32 addrspace($(AS.Shared))* %ptr to i8 addrspace($(AS.Shared))*
ret i8 addrspace($(AS.Shared))* %untyped_ptr
}
attributes #0 = { alwaysinline }
""", "entry"), LLVMPtr{UInt32, AS.Shared}, Tuple{})
# wrap the pointer to the first element in a 1-D device array covering all 32 keys
CuDeviceArray{UInt32,1,AS.Shared}(ptr, (32,))
end

# shared memory with per-warp counters, incremented when generating numbers
@eval @inline function global_random_counters()
ptr = Base.llvmcall(
Expand All @@ -31,6 +46,11 @@ using Random123: philox2x_round, philox2x_bumpkey

# GPU-compatible/optimized version of the generator from Random123.jl
struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64}
    # The struct declares no fields of its own; `R` is only a type parameter.
    # Constructing a generator stores the kernel's random seed as its key.
    @inline function Philox2x32{R}() where R
        generator = new{R}()
        # NOTE: assigning `key` on this fieldless struct relies on a property
        #       setter overload defined elsewhere in this file.
        seed = kernel_state().random_seed
        generator.key = seed
        return generator
    end
end

# default to 7 rounds; enough to pass SmallCrush
Expand All @@ -41,8 +61,8 @@ end
(threadIdx().z - 1i32) * blockDim().x * blockDim().y
warpId = (threadId - 1i32) >> 0x5 + 1i32 # fld1

if field === :seed
@inbounds global_random_seed()[1]
if field === :key
@inbounds global_random_keys()[warpId]
elseif field === :key
kernel_state().random_seed
elseif field === :ctr1
Expand All @@ -60,7 +80,9 @@ end
(threadIdx().z - 1i32) * blockDim().x * blockDim().y
warpId = (threadId - 1i32) >> 0x5 + 1i32 # fld1

if field === :ctr1
if field === :key
@inbounds global_random_keys()[warpId] = x
elseif field === :ctr1
@inbounds global_random_counters()[warpId] = x
end
end
Expand Down Expand Up @@ -114,7 +136,7 @@ function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
# to a unique location so the duplicate writes are innocuous
# NOTE: this update is not guaranteed to be visible in subsequent kernel launches,
# e.g., see JuliaGPU/CUDA.jl#2008
# XXX: what if this overflows? we can't increment ctr2, and the key is immutable.
# XXX: what if this overflows? we can't increment ctr2. bump the key?
rng.ctr1 += 1i32

# NOTE: it's too expensive to keep both numbers around in case the user only wanted one,
Expand Down

0 comments on commit 2e00444

Please sign in to comment.