Skip to content

Commit

Permalink
Add support for constant memory.
Browse files Browse the repository at this point in the history
  • Loading branch information
S-D-R authored and maleadt committed Mar 24, 2021
1 parent 9e2e0bb commit 466943c
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 3 deletions.
23 changes: 20 additions & 3 deletions src/compiler/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,10 +200,16 @@ end

## host-side kernels

struct HostKernel{F,TT} <: AbstractKernel{F,TT}
# Host-side handle to a compiled kernel: the context it lives in, the loaded module and
# the entry-point function. The struct is mutable because `init_random_state!` replaces
# `random_state` when a (larger) buffer has to be allocated at launch time.
mutable struct HostKernel{F,TT} <: AbstractKernel{F,TT}
    ctx::CuContext
    mod::CuModule
    fun::CuFunction

    # Per-thread RNG state backing the `global_random_state` device global:
    # - `nothing`: the kernel does not reference the global random state
    # - `missing`: the kernel references it, but no buffer has been allocated yet
    # - `CuVector{UInt32}`: the buffer that was allocated for a previous launch
    random_state::Union{Nothing,Missing,CuVector{UInt32}}

    function HostKernel{F,TT}(ctx::CuContext, mod::CuModule, fun::CuFunction, random_state) where {F,TT}
        return new{F,TT}(ctx, mod, fun, random_state)
    end
end

@doc (@doc AbstractKernel) HostKernel
Expand Down Expand Up @@ -352,10 +358,21 @@ function cufunction_link(@nospecialize(job::CompilerJob), compiled)
filter!(!isequal("exception_flag"), compiled.external_gvars)
end

return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun)
random_state = nothing
if "global_random_state" in compiled.external_gvars
random_state = missing
filter!(!isequal("global_random_state"), compiled.external_gvars)
end

return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun, random_state)
end

(kernel::HostKernel)(args...; kwargs...) = call(kernel, map(cudaconvert, args)...; kwargs...)
# Launch the kernel on converted arguments. `threads` and `blocks` are captured here
# (rather than left in `kwargs`) because the total number of launched threads determines
# how large the random state buffer has to be.
function (kernel::HostKernel)(args...; threads::CuDim=1, blocks::CuDim=1, kwargs...)
    # `nothing` means the compiled code never referenced the global random state
    if !(kernel.random_state === nothing)
        launched_threads = prod(threads) * prod(blocks)
        init_random_state!(kernel, launched_threads)
    end

    device_args = map(cudaconvert, args)
    return call(kernel, device_args...; threads, blocks, kwargs...)
end


## device-side kernels
Expand Down
1 change: 1 addition & 0 deletions src/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ include("intrinsics/memory_dynamic.jl")
include("intrinsics/atomics.jl")
include("intrinsics/misc.jl")
include("intrinsics/wmma.jl")
include("intrinsics/random.jl")

# functionality from libdevice
#
Expand Down
85 changes: 85 additions & 0 deletions src/device/intrinsics/random.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
## random number generation

using Random


# helpers

# Position of the calling thread as a 6-tuple: the three thread-index components
# followed by the three block-index components.
function global_index()
    thread = threadIdx()
    block = blockIdx()
    return (thread.x, thread.y, thread.z, block.x, block.y, block.z)
end


# global state

# Device-side random number generator that keeps one `UInt32` state word per thread.
# The state lives in a 6-dimensional device array that is indexed with the tuple
# returned by `global_index()` (thread index x/y/z, then block index x/y/z).
struct ThreadLocalRNG <: AbstractRNG
vals::CuDeviceArray{UInt32, 6, AS.Generic}
end

# Make sure `kernel` owns a random state buffer of at least `len` words (one `UInt32`
# per launched thread) and publish its address through the module's
# `global_random_state` device global.
#
# A freshly allocated buffer is zero-filled: the device-side generator assumes that the
# state of an unseeded thread starts at 0, so it must not contain leftover memory.
# An existing buffer that is large enough is reused as-is, including its contents.
function init_random_state!(kernel, len)
    if kernel.random_state === missing || length(kernel.random_state) < len
        kernel.random_state = fill!(CuVector{UInt32}(undef, len), zero(UInt32))
    end

    # the device global stores the buffer address as a plain pointer-sized value
    random_state_ptr = CuGlobal{Ptr{Cvoid}}(kernel.mod, "global_random_state")
    random_state_ptr[] = reinterpret(Ptr{Cvoid}, pointer(kernel.random_state))
end

# Device-side accessor for the per-thread random state.
#
# The inline LLVM IR declares a weak, externally initialized global named
# `global_random_state` (a word-sized integer, initially 0) and loads its value. The host
# writes the address of the state buffer into that global in `init_random_state!`.
# The loaded value is reinterpreted as a pointer to `UInt32` in the generic address space
# and wrapped in a device array.
# NOTE(review): the load uses a hard-coded `align 8` while the global's width follows
# `WORD_SIZE` -- confirm this is intended for non-64-bit targets.
@eval @inline function global_random_state()
ptr = reinterpret(LLVMPtr{UInt32, AS.Generic}, Base.llvmcall(
$("""@global_random_state = weak externally_initialized global i$(WORD_SIZE) 0
define i$(WORD_SIZE) @entry() #0 {
%ptr = load i$(WORD_SIZE), i$(WORD_SIZE)* @global_random_state, align 8
ret i$(WORD_SIZE) %ptr
}
attributes #0 = { alwaysinline }
""", "entry"), Ptr{Cvoid}, Tuple{}))
# shape the flat buffer as (block dimensions x/y/z, grid dimensions x/y/z) so that it can
# be indexed with (thread index x/y/z, block index x/y/z)
dims = (blockDim().x, blockDim().y, blockDim().z, gridDim().x, gridDim().y, gridDim().z)
CuDeviceArray(dims, ptr)
end

# In device code, the default RNG is a `ThreadLocalRNG` wrapping the global random state.
@device_override Random.default_rng() = ThreadLocalRNG(global_random_state())

# In device code, seeds are taken from `clock(UInt32)`.
# NOTE(review): presumably a hardware cycle counter -- confirm against `clock`'s definition.
@device_override Random.make_seed() = clock(UInt32)

# Set the state word of the calling thread to `seed`.
#
# The state is a single `UInt32`, so `seed` is reduced modulo 2^32 instead of being
# converted: a negative or too-large `Integer` would otherwise raise an `InexactError`
# when stored. Seeds that already fit in a `UInt32` are stored unchanged.
function Random.seed!(rng::ThreadLocalRNG, seed::Integer)
    index = global_index()
    rng.vals[index...] = seed % UInt32
    return
end


# generators

# One step of a 32-bit xorshift generator using the shift triple (13, 17, 5).
# Note that a zero input maps to zero.
function xorshift(x::UInt32)::UInt32
    x ⊻= x << 13
    x ⊻= x >> 17
    x ⊻= x << 5
    return x
end

# Advance the calling thread's state by one step and return the new state word.
function get_thread_word(rng::ThreadLocalRNG)
    idx = global_index()

    # NOTE: the thread's linear position is mixed into the stored state so that threads
    #       which start from identical state (e.g. when unseeded) still produce
    #       different random numbers
    linear = LinearIndices(rng.vals)[idx...]
    mixed = rng.vals[idx...] + UInt32(linear)

    next_word = generate_next_state(mixed)
    rng.vals[idx...] = next_word

    return next_word # FIXME: return old state?
end

# Compute the successor of `state`; currently a single xorshift step.
generate_next_state(state::UInt32) = UInt32(xorshift(state))

# TODO: support for more types (can we reuse more of the Random standard library?)
# see RandomNumbers.jl

# Draw a `Float32` from the thread's generator.
#
# The upper 23 bits of the random word are used as the fraction bits of a float whose
# sign and exponent bits are those of `1f0`; subtracting one then shifts the value down.
function Random.rand(rng::ThreadLocalRNG, ::Type{Float32})
    bits = get_thread_word(rng)
    fraction = bits >> 9
    one_bits = reinterpret(UInt32, 1f0)
    shifted = reinterpret(Float32, fraction | one_bits)
    return shifted - 1.0f0
end
46 changes: 46 additions & 0 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1180,3 +1180,49 @@ end
end

end



############################################################################################

# Tests for device-side `rand` and `Random.seed!`, launched with `n` threads in one block.
@testset "random numbers" begin

n = 256

# Without seeding: values must differ between threads and between consecutive calls.
@testset "basic" begin
# each thread draws two consecutive random values, one into each array
function kernel(A::CuDeviceArray{T}, B::CuDeviceArray{T}) where {T}
tid = threadIdx().x
A[tid] = rand(T)
B[tid] = rand(T)
return nothing
end

a = CUDA.zeros(Float32, n)
b = CUDA.zeros(Float32, n)

@cuda threads=n kernel(a, b)

# FIXME(?): all these tests have a (very small) chance to fail, but that's somewhat inherent to rand() without a seed
@test allunique(Array(a))
@test allunique(Array(b))
@test Array(a) != Array(b)
end

# With a fixed seed: two separate launches must produce identical output.
@testset "custom seed" begin
# each thread seeds its generator with the same constant before drawing one value
function kernel(A::CuDeviceArray{T}) where {T}
tid = threadIdx().x
Random.seed!(1234)
A[tid] = rand(T)
return nothing
end

a = CUDA.zeros(Float32, n)
b = CUDA.zeros(Float32, n)

@cuda threads=n kernel(a)
@cuda threads=n kernel(b)

@test Array(a) == Array(b)
end

end

0 comments on commit 466943c

Please sign in to comment.