Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a basic implementation of rand() for use inside kernels #772

Merged
merged 3 commits on Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions Manifest.toml
Expand Up @@ -190,6 +190,12 @@ uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
deps = ["Serialization"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[[RandomNumbers]]
deps = ["Random", "Requires"]
git-tree-sha1 = "441e6fc35597524ada7f85e13df1f4e10137d16f"
uuid = "e6cf234a-135c-5ec9-84dd-332b85af5143"
version = "1.4.0"

[[Reexport]]
git-tree-sha1 = "57d8440b0c7d98fc4f889e478e80f268d534c9d5"
uuid = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down
1 change: 1 addition & 0 deletions Project.toml
Expand Up @@ -21,6 +21,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Memoize = "c03570c3-d221-55d1-a50c-7939bbd78826"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand Down
5 changes: 3 additions & 2 deletions src/CUDA.jl
Expand Up @@ -58,7 +58,9 @@ include("device/llvm.jl")
include("device/runtime.jl")
include("device/texture.jl")

# array essentials
include("pool.jl")
include("array.jl")

# compiler libraries
include("../lib/cupti/CUPTI.jl")
Expand All @@ -71,8 +73,7 @@ include("compiler/execution.jl")
include("compiler/exceptions.jl")
include("compiler/reflection.jl")

# array abstraction
include("array.jl")
# array implementation
include("gpuarrays.jl")
include("utilities.jl")
include("texture.jl")
Expand Down
7 changes: 0 additions & 7 deletions src/array.jl
Expand Up @@ -241,13 +241,6 @@ function Base.unsafe_convert(::Type{CuDeviceArray{T,N,AS.Global}}, a::DenseCuArr
CuDeviceArray{T,N,AS.Global}(size(a), reinterpret(LLVMPtr{T,AS.Global}, pointer(a)))
end

Adapt.adapt_storage(::Adaptor, xs::CuArray{T,N}) where {T,N} =
Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs)

# we materialize ReshapedArray/ReinterpretArray/SubArray/... directly as a device array
Adapt.adapt_structure(::Adaptor, xs::DenseCuArray{T,N}) where {T,N} =
Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs)


## interop with CPU arrays

Expand Down
30 changes: 27 additions & 3 deletions src/compiler/execution.jl
Expand Up @@ -126,6 +126,13 @@ end
# Unwrap the value carried by a device-side reference (mirror of getindex on RefValue).
Base.getindex(r::CuRefValue) = r.x
# Convert a host RefValue into a device-compatible CuRefValue, adapting its payload too.
Adapt.adapt_structure(to::Adaptor, r::Base.RefValue) = CuRefValue(adapt(to, r[]))

# When a CuArray is passed to a kernel, hand the device a CuDeviceArray that
# aliases the same global-memory buffer (no copy is made).
Adapt.adapt_storage(::Adaptor, xs::CuArray{T,N}) where {T,N} =
Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs)

# we materialize ReshapedArray/ReinterpretArray/SubArray/... directly as a device array
Adapt.adapt_structure(::Adaptor, xs::DenseCuArray{T,N}) where {T,N} =
Base.unsafe_convert(CuDeviceArray{T,N,AS.Global}, xs)

"""
cudaconvert(x)

Expand Down Expand Up @@ -193,10 +200,16 @@ end

## host-side kernels

struct HostKernel{F,TT} <: AbstractKernel{F,TT}
"""
    HostKernel{F,TT}

Host-side handle to a compiled kernel: the owning context, the loaded module,
and the function object to launch. Mutable so that RNG state can be attached
lazily after compilation.
"""
mutable struct HostKernel{F,TT} <: AbstractKernel{F,TT}
    ctx::CuContext
    mod::CuModule
    fun::CuFunction

    # Device RNG state backing `rand()` inside the kernel:
    #   nothing  -> this kernel does not use random numbers
    #   missing  -> kernel uses random numbers, but no buffer allocated yet
    #   CuVector -> allocated state buffer (sized by init_random_state!)
    random_state::Union{Nothing,Missing,CuVector{UInt32}}

    function HostKernel{F,TT}(ctx::CuContext, mod::CuModule, fun::CuFunction,
                              random_state) where {F,TT}
        return new{F,TT}(ctx, mod, fun, random_state)
    end
end

@doc (@doc AbstractKernel) HostKernel
Expand Down Expand Up @@ -345,10 +358,21 @@ function cufunction_link(@nospecialize(job::CompilerJob), compiled)
filter!(!isequal("exception_flag"), compiled.external_gvars)
end

return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun)
random_state = nothing
if "global_random_state" in compiled.external_gvars
random_state = missing
filter!(!isequal("global_random_state"), compiled.external_gvars)
end

return HostKernel{job.source.f,job.source.tt}(ctx, mod, fun, random_state)
end

(kernel::HostKernel)(args...; kwargs...) = call(kernel, map(cudaconvert, args)...; kwargs...)
# Launch the kernel. Before launching, kernels that draw random numbers get
# their per-thread RNG state sized to match this launch configuration.
function (kernel::HostKernel)(args...; threads::CuDim=1, blocks::CuDim=1, kwargs...)
    # `nothing` means the compiled kernel never references the RNG state.
    if !isnothing(kernel.random_state)
        init_random_state!(kernel, prod(threads) * prod(blocks))
    end
    return call(kernel, map(cudaconvert, args)...; threads, blocks, kwargs...)
end


## device-side kernels
Expand Down
1 change: 1 addition & 0 deletions src/device/intrinsics.jl
Expand Up @@ -41,6 +41,7 @@ include("intrinsics/memory_dynamic.jl")
include("intrinsics/atomics.jl")
include("intrinsics/misc.jl")
include("intrinsics/wmma.jl")
include("intrinsics/random.jl")

# functionality from libdevice
#
Expand Down
71 changes: 71 additions & 0 deletions src/device/intrinsics/random.jl
@@ -0,0 +1,71 @@
## random number generation

using Random
import RandomNumbers


# helpers

# 6-D coordinate uniquely identifying the current thread across the whole
# launch: thread position within its block, then block position within the grid.
global_index() = (threadIdx().x, threadIdx().y, threadIdx().z,
blockIdx().x, blockIdx().y, blockIdx().z)


# global state

# Per-thread xorshift RNG. `vals` holds one UInt32 of state per thread,
# arranged as a 6-D array addressed by global_index().
struct ThreadLocalXorshift32 <: RandomNumbers.AbstractRNG{UInt32}
vals::CuDeviceArray{UInt32, 6, AS.Generic}
end

# Ensure `kernel` has an RNG state buffer of at least `len` slots (one per
# thread in the launch), then publish its address to the module's
# `global_random_state` global so device code can find it.
function init_random_state!(kernel, len)
    state = kernel.random_state
    # Allocate on first use (`missing`), or grow when a larger launch needs more slots.
    if state === missing || length(state) < len
        kernel.random_state = CuVector{UInt32}(undef, len)
    end

    state_ptr = CuGlobal{Ptr{Cvoid}}(kernel.mod, "global_random_state")
    state_ptr[] = reinterpret(Ptr{Cvoid}, pointer(kernel.random_state))
end

# Reconstruct a device array over the RNG state buffer. The host writes the
# buffer's address into the `global_random_state` module global (see
# init_random_state!); that global is not addressable from Julia source, so we
# read it through a small llvmcall stub.
# NOTE(review): the scraped diff had GitHub review-UI text interleaved in the
# IR string here; it has been removed to restore valid code.
@eval @inline function global_random_state()
    ptr = reinterpret(LLVMPtr{UInt32, AS.Generic}, Base.llvmcall(
        $("""@global_random_state = weak externally_initialized global i$(WORD_SIZE) 0
             define i$(WORD_SIZE) @entry() #0 {
                 %ptr = load i$(WORD_SIZE), i$(WORD_SIZE)* @global_random_state, align 8
                 ret i$(WORD_SIZE) %ptr
             }
             attributes #0 = { alwaysinline }
             """, "entry"), Ptr{Cvoid}, Tuple{}))
    # One state slot per thread: dimensions mirror global_index()'s coordinates.
    dims = (blockDim().x, blockDim().y, blockDim().z, gridDim().x, gridDim().y, gridDim().z)
    CuDeviceArray(dims, ptr)
end

# On the device, Random's default RNG is the thread-local xorshift generator
# backed by the module-global state buffer.
@device_override Random.default_rng() = ThreadLocalXorshift32(global_random_state())

# Device-side seed material: the hardware clock counter (cheap, varies per call).
@device_override Random.make_seed() = clock(UInt32)

# Seed the RNG for the calling thread only: each thread owns exactly one slot
# of the shared state array, addressed by its global index.
function Random.seed!(rng::ThreadLocalXorshift32, seed::Integer)
    coords = global_index()
    rng.vals[coords...] = seed % UInt32
    return
end


# generators

# One step of Marsaglia's 32-bit xorshift generator using the classic
# 13/17/5 shift triple.
function xorshift(x::UInt32)::UInt32
    x ⊻= x << 13
    x ⊻= x >> 17
    x ⊻= x << 5
    return x
end

# Advance this thread's state slot and return the new state as the random value.
function Random.rand(rng::ThreadLocalXorshift32, ::Type{UInt32})
    # NOTE: we add the current linear index to the local state, to make sure threads get
    # different random numbers when unseeded (initial state = 0 for all threads)
    coords = global_index()
    lin = LinearIndices(rng.vals)[coords...]
    advanced = xorshift(rng.vals[coords...] + UInt32(lin))
    rng.vals[coords...] = advanced
    return advanced
end
54 changes: 54 additions & 0 deletions test/device/intrinsics.jl
Expand Up @@ -1180,3 +1180,57 @@ end
end

end



############################################################################################

@testset "random numbers" begin

n = 256

@testset "basic" begin
function kernel(A::CuDeviceArray{T}, B::CuDeviceArray{T}) where {T}
tid = threadIdx().x
A[tid] = rand(T)
B[tid] = rand(T)
return nothing
end

@testset for T in (Int32, UInt32, Int64, UInt64, Int128, UInt128,
Float32, Float64)
a = CUDA.zeros(T, n)
b = CUDA.zeros(T, n)

@cuda threads=n kernel(a, b)

@test all(Array(a) .!= Array(b))

if T == Float64
@test allunique(Array(a))
@test allunique(Array(b))
end
end
end

@testset "custom seed" begin
function kernel(A::CuDeviceArray{T}) where {T}
tid = threadIdx().x
Random.seed!(1234)
A[tid] = rand(T)
return nothing
end

@testset for T in (Int32, UInt32, Int64, UInt64, Int128, UInt128,
Float32, Float64)
a = CUDA.zeros(T, n)
b = CUDA.zeros(T, n)

@cuda threads=n kernel(a)
@cuda threads=n kernel(b)

@test Array(a) == Array(b)
end
end

end