Skip to content

Commit

Permalink
Simplify implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Mar 19, 2021
1 parent 32e7705 commit 269a2dd
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 46 deletions.
29 changes: 13 additions & 16 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,13 @@ mutable struct HostKernel{F,TT} <: AbstractKernel{F,TT}
ctx::CuContext
mod::CuModule
fun::CuFunction
uses_random_state::Bool
random_state::Union{Nothing,Mem.DeviceBuffer}
function HostKernel{F,TT}(ctx::CuContext, mod::CuModule, fun::CuFunction, uses_random_state::Bool) where {F,TT}
kernel = new{F,TT}(ctx, mod, fun, uses_random_state, nothing)

random_state::Union{Nothing,Missing,Mem.DeviceBuffer}

function HostKernel{F,TT}(ctx::CuContext, mod::CuModule, fun::CuFunction, random_state) where {F,TT}
kernel = new{F,TT}(ctx, mod, fun, random_state)
finalizer(kernel) do k
if !isnothing(k.random_state)
if k.random_state isa Mem.DeviceBuffer
Mem.free(k.random_state)
end
end
Expand Down Expand Up @@ -356,24 +357,20 @@ function cufunction_link(@nospecialize(job::CompilerJob), compiled)
filter!(!isequal("exception_flag"), compiled.external_gvars)
end

uses_random_state = false

random_state = nothing
if "global_random_state" in compiled.external_gvars
uses_random_state = true
random_state = missing
filter!(!isequal("global_random_state"), compiled.external_gvars)
end

return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun, uses_random_state)
return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun, random_state)
end

# Launch `kernel` on `args`.
#
# Every positional argument is passed through `cudaconvert`, and the launch
# configuration (`threads`, `blocks` and any remaining keyword arguments) is
# forwarded unchanged to `call`.
#
# When `kernel.random_state` is not `nothing`, `init_random_state!` is invoked
# first with the total number of threads in the launch.
#
# `threads` and `blocks` default to 1. Each may be a single integer or a tuple
# of per-dimension integers; they are deliberately left untyped so that a
# multi-dimensional configuration such as `threads=(16, 16)` is not rejected
# before it reaches `call`.
function (kernel::HostKernel)(args...; threads=1, blocks=1, kwargs...)
    if kernel.random_state !== nothing
        # `prod` collapses both a plain integer and a tuple of dimensions to a
        # scalar count, so the state buffer covers every launched thread.
        init_random_state!(kernel, prod(threads) * prod(blocks))
    end
    call(kernel, map(cudaconvert, args)...; threads, blocks, kwargs...)
end


Expand Down
74 changes: 44 additions & 30 deletions src/device/intrinsics/random.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,30 @@
function get_global_thread_index()
block_id = (gridDim().x * (blockIdx().y - 1)) + (blockIdx().x - 1)
## random number generation

thread_id = (block_id * (blockDim().x * blockDim().y)) +
((threadIdx().y - 1) * blockDim().x) +
(threadIdx().x - 1)
using Random

# TODO: use an RNG, overriding Random.default_rng()?
# or otherwise reuse more of the Random stdlib?

return thread_id + 1
end

# global state

"""
    ThreadLocalRNGState

Immutable random number generator state holding a single `UInt32` value.

Instances are stored in, and loaded from, the global random state buffer, and
`generate_next_state` produces the successor state by applying `xorshift` to `val`.
"""
struct ThreadLocalRNGState
# current 32-bit generator value
val::UInt32
end

"""
    xorshift(x::UInt32) -> UInt32

Scramble `x` with three shift-and-xor steps: left by 13, right by 17, left by 5.
An input of zero maps to zero.
"""
function xorshift(x::UInt32)::UInt32
    x ⊻= x << 13
    x ⊻= x >> 17
    x ⊻= x << 5
    return x
end
"""
    init_random_state!(kernel, len)

Make sure `kernel` owns a device buffer with room for `len` `ThreadLocalRNGState`
entries, and store that buffer's address in the `global_random_state` global of
`kernel.mod`.

A buffer is allocated when `kernel.random_state` is `missing`. A buffer that is too
small is freed and replaced by one of the required size. A buffer that is already
large enough is kept as is. The address is written on every call.
"""
function init_random_state!(kernel, len)
    required_size = sizeof(ThreadLocalRNGState) * len

    if ismissing(kernel.random_state)
        kernel.random_state = Mem.alloc(Mem.Device, required_size; async=true)
    elseif sizeof(kernel.random_state) < required_size
        Mem.free(kernel.random_state; async=true)
        kernel.random_state = Mem.alloc(Mem.Device, required_size; async=true)
    end

    # The module global holds an untyped pointer, so the buffer's device pointer is
    # reinterpreted as `Ptr{Cvoid}` before being stored.
    random_state_ptr = CuGlobal{Ptr{Cvoid}}(kernel.mod, "global_random_state")
    random_state_ptr[] = reinterpret(Ptr{Cvoid}, convert(CuPtr{Cvoid}, kernel.random_state))
end

@eval @inline global_random_state() =
Expand Down Expand Up @@ -54,6 +57,31 @@ function seed_rng!(seed=nothing)
unsafe_store!(random_state_ptr, ThreadLocalRNGState(seed), index, Val(align))
end


# generators

"""
    get_global_thread_index()

Return the 1-based linear index of the calling thread within the whole launch.

Blocks are numbered with `x` varying fastest, then `y`, then `z`; threads inside a
block are numbered the same way. All three dimensions take part, so threads that
differ only in their `z` coordinate receive distinct indices. For launches whose
`z` extents are 1 the result equals the two-dimensional formula
`block_id * (blockDim.x * blockDim.y) + (threadIdx.y - 1) * blockDim.x + threadIdx.x`.
"""
function get_global_thread_index()
    # zero-based linear position of this block within the grid
    block_id = (blockIdx().x - 1) +
               gridDim().x * ((blockIdx().y - 1) + gridDim().y * (blockIdx().z - 1))

    threads_per_block = blockDim().x * blockDim().y * blockDim().z

    # zero-based linear position of this thread within its block
    local_id = (threadIdx().x - 1) +
               blockDim().x * ((threadIdx().y - 1) + blockDim().y * (threadIdx().z - 1))

    return block_id * threads_per_block + local_id + 1
end

"""
    xorshift(x::UInt32) -> UInt32

Return `x` after three successive shift-and-xor rounds (left 13, right 17, left 5).
Zero is a fixed point: `xorshift(0x00000000) == 0x00000000`.
"""
function xorshift(x::UInt32)::UInt32
    step1 = x ⊻ (x << 13)
    step2 = step1 ⊻ (step1 >> 17)
    return step2 ⊻ (step2 << 5)
end

"""
    generate_next_state(state::ThreadLocalRNGState) -> ThreadLocalRNGState

Return the successor of `state`: a new `ThreadLocalRNGState` whose value is
`xorshift(state.val)`. The argument is not modified.
"""
generate_next_state(state::ThreadLocalRNGState) = ThreadLocalRNGState(xorshift(state.val))

function rand()
random_state_ptr = get_random_state_ptr()
index = get_global_thread_index()
Expand All @@ -66,17 +94,3 @@ function rand()
res = (new_state.val >> 9) | reinterpret(UInt32, 1f0)
return reinterpret(Float32, res) - 1.0f0
end

"""
    init_random_state!(kernel, len)

Ensure `kernel.random_state` is a device buffer of at least
`sizeof(ThreadLocalRNGState) * len` bytes, then write its address into the
`global_random_state` global of `kernel.mod`.

An absent buffer (`nothing`) is allocated; an undersized one is freed and
replaced; a sufficiently large one is reused.
"""
function init_random_state!(kernel, len)
    nbytes = sizeof(ThreadLocalRNGState) * len
    current = kernel.random_state

    if current === nothing || sizeof(current) < nbytes
        # release the undersized buffer, if there is one, before replacing it
        current === nothing || Mem.free(current; async=true)
        kernel.random_state = Mem.alloc(Mem.Device, nbytes)
    end

    device_ptr = convert(CuPtr{Cvoid}, kernel.random_state)
    state_global = CuGlobal{Ptr{Cvoid}}(kernel.mod, "global_random_state")
    state_global[] = reinterpret(Ptr{Cvoid}, device_ptr)
end

0 comments on commit 269a2dd

Please sign in to comment.