From 466943c69c08a1e0a33d18af29f54d53fdbe84e8 Mon Sep 17 00:00:00 2001
From: Stijn De Ridder
Date: Fri, 19 Mar 2021 09:41:04 +0100
Subject: [PATCH] Add support for constant memory.

---
 src/compiler/execution.jl       | 23 ++++++--
 src/device/intrinsics.jl        |  1 +
 src/device/intrinsics/random.jl | 96 +++++++++++++++++++++++++++++++++
 test/device/intrinsics.jl       | 46 ++++++++++++++++
 4 files changed, 163 insertions(+), 3 deletions(-)
 create mode 100644 src/device/intrinsics/random.jl

diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl
index 0bca3a430f..4074aa77ac 100644
--- a/src/compiler/execution.jl
+++ b/src/compiler/execution.jl
@@ -200,10 +200,16 @@ end
 
 ## host-side kernels
 
-struct HostKernel{F,TT} <: AbstractKernel{F,TT}
+mutable struct HostKernel{F,TT} <: AbstractKernel{F,TT}
     ctx::CuContext
     mod::CuModule
     fun::CuFunction
+    # RNG state: `nothing` = kernel does not use the RNG, `missing` = not yet allocated
+    random_state::Union{Nothing,Missing,CuVector{UInt32}}
+
+    function HostKernel{F,TT}(ctx::CuContext, mod::CuModule, fun::CuFunction, random_state) where {F,TT}
+        return new{F,TT}(ctx, mod, fun, random_state)
+    end
 end
 
 @doc (@doc AbstractKernel) HostKernel
@@ -352,10 +358,21 @@ function cufunction_link(@nospecialize(job::CompilerJob), compiled)
         filter!(!isequal("exception_flag"), compiled.external_gvars)
     end
 
-    return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun)
+    random_state = nothing
+    if "global_random_state" in compiled.external_gvars
+        random_state = missing
+        filter!(!isequal("global_random_state"), compiled.external_gvars)
+    end
+
+    return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun, random_state)
 end
 
-(kernel::HostKernel)(args...; kwargs...) = call(kernel, map(cudaconvert, args)...; kwargs...)
+function (kernel::HostKernel)(args...; threads::CuDim=1, blocks::CuDim=1, kwargs...)
+    if kernel.random_state !== nothing
+        init_random_state!(kernel, prod(threads) * prod(blocks))
+    end
+    call(kernel, map(cudaconvert, args)...; threads, blocks, kwargs...)
+end
 
 
 ## device-side kernels
diff --git a/src/device/intrinsics.jl b/src/device/intrinsics.jl
index 0e0814b571..211f9176cf 100644
--- a/src/device/intrinsics.jl
+++ b/src/device/intrinsics.jl
@@ -41,6 +41,7 @@ include("intrinsics/memory_dynamic.jl")
 include("intrinsics/atomics.jl")
 include("intrinsics/misc.jl")
 include("intrinsics/wmma.jl")
+include("intrinsics/random.jl")
 
 # functionality from libdevice
 #
diff --git a/src/device/intrinsics/random.jl b/src/device/intrinsics/random.jl
new file mode 100644
index 0000000000..85b0901a14
--- /dev/null
+++ b/src/device/intrinsics/random.jl
@@ -0,0 +1,96 @@
+## random number generation
+
+using Random
+
+
+# helpers
+
+# 6-tuple (thread x, y, z, block x, y, z) of the calling thread; indexes the RNG state
+global_index() = (threadIdx().x, threadIdx().y, threadIdx().z,
+                  blockIdx().x, blockIdx().y, blockIdx().z)
+
+
+# global state
+
+# device-side RNG whose state array holds one `UInt32` word per launched thread
+struct ThreadLocalRNG <: AbstractRNG
+    vals::CuDeviceArray{UInt32, 6, AS.Generic}
+end
+
+# Make sure `kernel` owns at least `len` words of RNG state and store a pointer to that
+# buffer in the module's `global_random_state` variable (done before launching).
+function init_random_state!(kernel, len)
+    if kernel.random_state === missing || length(kernel.random_state) < len
+        # zero the buffer: `get_thread_word` relies on unseeded state starting out as 0
+        kernel.random_state = fill!(CuVector{UInt32}(undef, len), zero(UInt32))
+    end
+
+    random_state_ptr = CuGlobal{Ptr{Cvoid}}(kernel.mod, "global_random_state")
+    random_state_ptr[] = reinterpret(Ptr{Cvoid}, pointer(kernel.random_state))
+end
+
+# Load the pointer stored in `global_random_state` and wrap it in a device array with
+# one element per thread of the launch (block dimensions first, then grid dimensions).
+@eval @inline function global_random_state()
+    ptr = reinterpret(LLVMPtr{UInt32, AS.Generic}, Base.llvmcall(
+        $("""@global_random_state = weak externally_initialized global i$(WORD_SIZE) 0
+             define i$(WORD_SIZE) @entry() #0 {
+                 %ptr = load i$(WORD_SIZE), i$(WORD_SIZE)* @global_random_state, align 8
+                 ret i$(WORD_SIZE) %ptr
+             }
+             attributes #0 = { alwaysinline }
+          """, "entry"), Ptr{Cvoid}, Tuple{}))
+    dims = (blockDim().x, blockDim().y, blockDim().z, gridDim().x, gridDim().y, gridDim().z)
+    CuDeviceArray(dims, ptr)
+end
+
+@device_override Random.default_rng() = ThreadLocalRNG(global_random_state())
+
+@device_override Random.make_seed() = clock(UInt32)
+
+# overwrite the calling thread's state word with `seed`
+function Random.seed!(rng::ThreadLocalRNG, seed::Integer)
+    index = global_index()
+    rng.vals[index...] = seed
+    return
+end
+
+
+# generators
+
+# one round of 32-bit xorshift (shifts 13, 17, 5)
+function xorshift(x::UInt32)::UInt32
+    x = xor(x, x << 13)
+    x = xor(x, x >> 17)
+    x = xor(x, x << 5)
+    return x
+end
+
+function get_thread_word(rng::ThreadLocalRNG)
+    # NOTE: we add the current linear index to the local state, to make sure threads get
+    #       different random numbers when unseeded (initial state = 0 for all threads)
+    index = global_index()
+    offset = LinearIndices(rng.vals)[index...]
+    state = rng.vals[index...] + UInt32(offset)
+
+    new_state = generate_next_state(state)
+    rng.vals[index...] = new_state
+
+    return new_state # FIXME: return old state?
+end
+
+function generate_next_state(state::UInt32)
+    new_val = xorshift(state)
+    return UInt32(new_val)
+end
+
+# TODO: support for more types (can we reuse more of the Random standard library?)
+#       see RandomNumbers.jl
+
+# Uniform `Float32` in [0, 1): the top 23 bits of the random word become the mantissa
+# of a float in [1, 2) (exponent bits taken from 1f0), from which 1 is subtracted.
+function Random.rand(rng::ThreadLocalRNG, ::Type{Float32})
+    word = get_thread_word(rng)
+    res = (word >> 9) | reinterpret(UInt32, 1f0)
+    return reinterpret(Float32, res) - 1.0f0
+end
diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl
index 4b4de26133..342a0dbe57 100644
--- a/test/device/intrinsics.jl
+++ b/test/device/intrinsics.jl
@@ -1180,3 +1180,49 @@ end
 end
 
 end
+
+
+
+############################################################################################
+
+@testset "random numbers" begin
+
+n = 256
+
+@testset "basic" begin
+    function kernel(A::CuDeviceArray{T}, B::CuDeviceArray{T}) where {T}
+        tid = threadIdx().x
+        A[tid] = rand(T)
+        B[tid] = rand(T)
+        return nothing
+    end
+
+    a = CUDA.zeros(Float32, n)
+    b = CUDA.zeros(Float32, n)
+
+    @cuda threads=n kernel(a, b)
+
+    # FIXME(?): all these tests have a (very small) chance to fail, but that's somewhat inherent to rand() without a seed
+    @test allunique(Array(a))
+    @test allunique(Array(b))
+    @test Array(a) != Array(b)
+end
+
+@testset "custom seed" begin
+    function kernel(A::CuDeviceArray{T}) where {T}
+        tid = threadIdx().x
+        Random.seed!(1234)
+        A[tid] = rand(T)
+        return nothing
+    end
+
+    a = CUDA.zeros(Float32, n)
+    b = CUDA.zeros(Float32, n)
+
+    @cuda threads=n kernel(a)
+    @cuda threads=n kernel(b)
+
+    @test Array(a) == Array(b)
+end
+
+end