Skip to content

Commit

Permalink
Add a device-side rand().
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Mar 19, 2021
1 parent 451a338 commit 32e7705
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 3 deletions.
32 changes: 29 additions & 3 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,10 +193,21 @@ end

## host-side kernels

struct HostKernel{F,TT} <: AbstractKernel{F,TT}
mutable struct HostKernel{F,TT} <: AbstractKernel{F,TT}
ctx::CuContext
mod::CuModule
fun::CuFunction
uses_random_state::Bool
random_state::Union{Nothing,Mem.DeviceBuffer}
function HostKernel{F,TT}(ctx::CuContext, mod::CuModule, fun::CuFunction, uses_random_state::Bool) where {F,TT}
kernel = new{F,TT}(ctx, mod, fun, uses_random_state, nothing)
finalizer(kernel) do k
if !isnothing(k.random_state)
Mem.free(k.random_state)
end
end
return kernel
end
end

@doc (@doc AbstractKernel) HostKernel
Expand Down Expand Up @@ -345,10 +356,25 @@ function cufunction_link(@nospecialize(job::CompilerJob), compiled)
filter!(!isequal("exception_flag"), compiled.external_gvars)
end

return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun)
uses_random_state = false

if "global_random_state" in compiled.external_gvars
uses_random_state = true
filter!(!isequal("global_random_state"), compiled.external_gvars)
end

return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun, uses_random_state)
end

(kernel::HostKernel)(args...; kwargs...) = call(kernel, map(cudaconvert, args)...; kwargs...)
function (kernel::HostKernel)(args...; kwargs...)
if kernel.uses_random_state
kws = Dict(kwargs)
num_threads = get(kws, :threads, 1) * get(kws, :blocks, 1)

init_random_state!(kernel, num_threads)
end
call(kernel, map(cudaconvert, args)...; kwargs...)
end


## device-side kernels
Expand Down
1 change: 1 addition & 0 deletions src/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ include("intrinsics/memory_dynamic.jl")
include("intrinsics/atomics.jl")
include("intrinsics/misc.jl")
include("intrinsics/wmma.jl")
include("intrinsics/random.jl")

# functionality from libdevice
#
Expand Down
82 changes: 82 additions & 0 deletions src/device/intrinsics/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
function get_global_thread_index()
block_id = (gridDim().x * (blockIdx().y - 1)) + (blockIdx().x - 1)

thread_id = (block_id * (blockDim().x * blockDim().y)) +
((threadIdx().y - 1) * blockDim().x) +
(threadIdx().x - 1)

return thread_id + 1
end

struct ThreadLocalRNGState
val::UInt32
end

function xorshift(x::UInt32)::UInt32
x = xor(x, x << 13)
x = xor(x, x >> 17)
x = xor(x, x << 5)
return x
end

function generate_next_state(state::ThreadLocalRNGState)
new_val = xorshift(state.val)
return ThreadLocalRNGState(new_val)
end

@eval @inline global_random_state() =
Base.llvmcall(
$("""@global_random_state = weak externally_initialized global i$(WORD_SIZE) 0
define i$(WORD_SIZE) @entry() #0 {
%ptr = load i$(WORD_SIZE), i$(WORD_SIZE)* @global_random_state, align 8
ret i$(WORD_SIZE) %ptr
}
attributes #0 = { alwaysinline }
""", "entry"), Ptr{Cvoid}, Tuple{})

function get_random_state_ptr()
random_state_address = global_random_state()
random_state_ptr = LLVMPtr{ThreadLocalRNGState, AS.Generic}(random_state_address)
# FIXME: we should wrap this pointer in a CuDeviceVector but this leads to some very weird assertion errors
end

function seed_rng!(seed=nothing)
if isnothing(seed)
seed = clock(UInt32)
end

# make sure the seed is different for each thread
index = get_global_thread_index()
seed += (UInt32(index) + UInt32(10)) ^ 5

random_state_ptr = get_random_state_ptr()
align = Base.datatype_alignment(ThreadLocalRNGState)
unsafe_store!(random_state_ptr, ThreadLocalRNGState(seed), index, Val(align))
end

function rand()
random_state_ptr = get_random_state_ptr()
index = get_global_thread_index()

align = Base.datatype_alignment(ThreadLocalRNGState)
state = unsafe_load(random_state_ptr, index, Val(align))
new_state = generate_next_state(state)
unsafe_store!(random_state_ptr, new_state, index, Val(align))

res = (new_state.val >> 9) | reinterpret(UInt32, 1f0)
return reinterpret(Float32, res) - 1.0f0
end

function init_random_state!(kernel, len)
required_size = sizeof(ThreadLocalRNGState) * len

if isnothing(kernel.random_state)
kernel.random_state = Mem.alloc(Mem.Device, required_size)
elseif sizeof(kernel.random_state) < required_size
Mem.free(kernel.random_state, async=true)
kernel.random_state = Mem.alloc(Mem.Device, required_size)
end

random_state_ptr = CuGlobal{Ptr{Cvoid}}(kernel.mod, "global_random_state")
random_state_ptr[] = reinterpret(Ptr{Cvoid}, convert(CuPtr{Cvoid}, kernel.random_state))
end
46 changes: 46 additions & 0 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1179,4 +1179,50 @@ end
synchronize()
end

@testset "random" begin

n = 256

@testset "basic" begin
function kernel(A::CuDeviceArray{Float32}, B::CuDeviceArray{Float32})
tid = threadIdx().x
CUDA.seed_rng!()
A[tid] = CUDA.rand()
B[tid] = CUDA.rand()
return nothing
end

a = zeros(Float32, n)
dev_a = CuArray(a)
b = zeros(Float32, n)
dev_b = CuArray(b)

@cuda threads=n kernel(dev_a, dev_b)

# FIXME(?): all these tests have a (very small) chance to fail, but that's somewhat inherent to rand() without a seed
@test allunique(Array(dev_a))
@test allunique(Array(dev_b))
@test Array(dev_a) != Array(dev_b)
end

@testset "custom seed" begin
function kernel(A::CuDeviceArray{Float32})
tid = threadIdx().x
CUDA.seed_rng!(1234)
A[tid] = CUDA.rand()
return nothing
end

a = zeros(Float32, n)
dev_a = CuArray(a)
b = zeros(Float32, n)
dev_b = CuArray(b)

@cuda threads=n kernel(dev_a)
@cuda threads=n kernel(dev_b)

@test Array(dev_a) == Array(dev_b)
end
end

end

0 comments on commit 32e7705

Please sign in to comment.