From 269a2dd495b774f9f0b96b33fe5a25201f18ea96 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Fri, 19 Mar 2021 10:11:52 +0100 Subject: [PATCH] Simplify implementation. --- src/compiler/execution.jl | 29 ++++++------- src/device/intrinsics/random.jl | 74 ++++++++++++++++++++------------- 2 files changed, 57 insertions(+), 46 deletions(-) diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 6a911faf3f..9dd4335c68 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -197,12 +197,13 @@ mutable struct HostKernel{F,TT} <: AbstractKernel{F,TT} ctx::CuContext mod::CuModule fun::CuFunction - uses_random_state::Bool - random_state::Union{Nothing,Mem.DeviceBuffer} - function HostKernel{F,TT}(ctx::CuContext, mod::CuModule, fun::CuFunction, uses_random_state::Bool) where {F,TT} - kernel = new{F,TT}(ctx, mod, fun, uses_random_state, nothing) + + random_state::Union{Nothing,Missing,Mem.DeviceBuffer} + + function HostKernel{F,TT}(ctx::CuContext, mod::CuModule, fun::CuFunction, random_state) where {F,TT} + kernel = new{F,TT}(ctx, mod, fun, random_state) finalizer(kernel) do k - if !isnothing(k.random_state) + if k.random_state isa Mem.DeviceBuffer Mem.free(k.random_state) end end @@ -356,24 +357,20 @@ function cufunction_link(@nospecialize(job::CompilerJob), compiled) filter!(!isequal("exception_flag"), compiled.external_gvars) end - uses_random_state = false - + random_state = nothing if "global_random_state" in compiled.external_gvars - uses_random_state = true + random_state = missing filter!(!isequal("global_random_state"), compiled.external_gvars) end - return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun, uses_random_state) + return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun, random_state) end -function (kernel::HostKernel)(args...; kwargs...) - if kernel.uses_random_state - kws = Dict(kwargs) - num_threads = get(kws, :threads, 1) * get(kws, :blocks, 1) - - init_random_state!(kernel, num_threads) +function (kernel::HostKernel)(args...; threads::Integer=1, blocks::Integer=1, kwargs...) + if kernel.random_state !== nothing + init_random_state!(kernel, threads * blocks) end - call(kernel, map(cudaconvert, args)...; kwargs...) + call(kernel, map(cudaconvert, args)...; threads, blocks, kwargs...) end diff --git a/src/device/intrinsics/random.jl b/src/device/intrinsics/random.jl index 6fa58c376f..f542e6ad3c 100644 --- a/src/device/intrinsics/random.jl +++ b/src/device/intrinsics/random.jl @@ -1,27 +1,30 @@ -function get_global_thread_index() - block_id = (gridDim().x * (blockIdx().y - 1)) + (blockIdx().x - 1) +## random number generation - thread_id = (block_id * (blockDim().x * blockDim().y)) + - ((threadIdx().y - 1) * blockDim().x) + - (threadIdx().x - 1) +using Random + +# TODO: use an RNG, overriding Random.default_rng()? +# or otherwise reuse more of the Random stdlib? - return thread_id + 1 -end + +# global state struct ThreadLocalRNGState val::UInt32 end -function xorshift(x::UInt32)::UInt32 - x = xor(x, x << 13) - x = xor(x, x >> 17) - x = xor(x, x << 5) - return x -end +function init_random_state!(kernel, len) + required_size = sizeof(ThreadLocalRNGState) * len -function generate_next_state(state::ThreadLocalRNGState) - new_val = xorshift(state.val) - return ThreadLocalRNGState(new_val) + if kernel.random_state === missing + kernel.random_state = Mem.alloc(Mem.Device, required_size; async=true) + elseif sizeof(kernel.random_state) < required_size + Mem.free(kernel.random_state, async=true) + kernel.random_state = Mem.alloc(Mem.Device, required_size; async=true) + end + @show kernel.random_state + + random_state_ptr = CuGlobal{Ptr{Cvoid}}(kernel.mod, "global_random_state") + random_state_ptr[] = reinterpret(Ptr{Cvoid}, convert(CuPtr{Cvoid}, kernel.random_state)) end @eval @inline global_random_state() = @@ -54,6 +57,31 @@ function seed_rng!(seed=nothing) unsafe_store!(random_state_ptr, ThreadLocalRNGState(seed), index, Val(align)) end + +# generators + +function get_global_thread_index() + block_id = (gridDim().x * (blockIdx().y - 1)) + (blockIdx().x - 1) + + thread_id = (block_id * (blockDim().x * blockDim().y)) + + ((threadIdx().y - 1) * blockDim().x) + + (threadIdx().x - 1) + + return thread_id + 1 +end + +function xorshift(x::UInt32)::UInt32 + x = xor(x, x << 13) + x = xor(x, x >> 17) + x = xor(x, x << 5) + return x +end + +function generate_next_state(state::ThreadLocalRNGState) + new_val = xorshift(state.val) + return ThreadLocalRNGState(new_val) +end + function rand() random_state_ptr = get_random_state_ptr() index = get_global_thread_index() @@ -66,17 +94,3 @@ function rand() res = (new_state.val >> 9) | reinterpret(UInt32, 1f0) return reinterpret(Float32, res) - 1.0f0 end - -function init_random_state!(kernel, len) - required_size = sizeof(ThreadLocalRNGState) * len - - if isnothing(kernel.random_state) - kernel.random_state = Mem.alloc(Mem.Device, required_size) - elseif sizeof(kernel.random_state) < required_size - Mem.free(kernel.random_state, async=true) - kernel.random_state = Mem.alloc(Mem.Device, required_size) - end - - random_state_ptr = CuGlobal{Ptr{Cvoid}}(kernel.mod, "global_random_state") - random_state_ptr[] = reinterpret(Ptr{Cvoid}, convert(CuPtr{Cvoid}, kernel.random_state)) -end