Merge pull request #2035 from JuliaGPU/tb/rand_seed

rand: seed kernels from the host.

maleadt committed Aug 17, 2023
2 parents fade845 + 556b23e · commit 9796d5a
Showing 6 changed files with 102 additions and 31 deletions.
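In practice, this change makes device-side rand() reproducible from the host: every launch of a HostKernel now draws a fresh UInt32 seed from the host RNG (see make_seed below), so seeding the host RNG pins down the device streams. A minimal sketch of the user-visible effect, assuming the same launch configuration for both launches (array size and configuration are illustrative):

using CUDA, Random

function kernel(a)
    i = threadIdx().x
    @inbounds a[i] = rand(Float32)  # device RNG, primed from the host-provided seed
    return
end

a = CUDA.zeros(Float32, 32)

Random.seed!(1234)          # seed the host RNG that make_seed(::HostKernel) draws from
@cuda threads=32 kernel(a)
x = Array(a)

Random.seed!(1234)          # same host seed, so the next launch gets the same kernel seed
@cuda threads=32 kernel(a)
@assert Array(a) == x       # identical streams across the two launches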
5 changes: 2 additions & 3 deletions .buildkite/pipeline.yml
@@ -19,7 +19,7 @@ steps:
cuda: "*"
commands: |
julia --project -e '
-            # make sure the 1.6-era Manifest works on this Julia version
+            # make sure the 1.7-era Manifest works on this Julia version
using Pkg
Pkg.resolve()
@@ -32,7 +32,6 @@ steps:
 matrix:
   setup:
     julia:
-      - "1.6"
       - "1.7"
       - "1.8"
       - "1.9"
@@ -315,10 +314,10 @@ steps:
 matrix:
   setup:
     julia:
-      - "1.6"
       - "1.7"
       - "1.8"
       - "1.9"
+      - "1.10"
       - "nightly"
adjustments:
- with:
62 changes: 62 additions & 0 deletions src/compiler/compilation.jl
@@ -42,6 +42,68 @@ GPUCompiler.method_table(@nospecialize(job::CUDACompilerJob)) = method_table

GPUCompiler.kernel_state_type(job::CUDACompilerJob) = KernelState

function GPUCompiler.finish_module!(@nospecialize(job::CUDACompilerJob),
                                    mod::LLVM.Module, entry::LLVM.Function)
    entry = invoke(GPUCompiler.finish_module!,
                   Tuple{CompilerJob{PTXCompilerTarget}, LLVM.Module, LLVM.Function},
                   job, mod, entry)

    # if this kernel uses our RNG, we should prime the shared state.
    # XXX: these transformations should really happen at the Julia IR level...
    if haskey(globals(mod), "global_random_keys")
        f = initialize_rng_state
        ft = typeof(f)
        tt = Tuple{}

        # don't recurse into `initialize_rng_state()` itself
        if job.source.specTypes.parameters[1] == ft
            return entry
        end

        # create a deferred compilation job for `initialize_rng_state()`
        src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
        cfg = CompilerConfig(job.config; kernel=false, name=nothing)
        job = CompilerJob(src, cfg, job.world)
        id = length(GPUCompiler.deferred_codegen_jobs) + 1
        GPUCompiler.deferred_codegen_jobs[id] = job

        # generate IR for calls to `deferred_codegen` and the resulting function pointer
        top_bb = first(blocks(entry))
        bb = BasicBlock(top_bb, "initialize_rng")
        LLVM.@dispose builder=IRBuilder() begin
            position!(builder, bb)
            subprogram = LLVM.get_subprogram(entry)
            if subprogram !== nothing
                loc = DILocation(0, 0, subprogram)
                debuglocation!(builder, loc)
            end
            debuglocation!(builder, first(instructions(top_bb)))

            # call the `deferred_codegen` marker function
            T_ptr = LLVM.Int64Type()
            deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_ptr])
            deferred_codegen = if haskey(functions(mod), "deferred_codegen")
                functions(mod)["deferred_codegen"]
            else
                LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
            end
            fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])

            # call the `initialize_rng_state` function
            rt = Core.Compiler.return_type(f, tt)
            llvm_rt = convert(LLVMType, rt)
            llvm_ft = LLVM.FunctionType(llvm_rt)
            fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
            call!(builder, llvm_ft, fptr)
            br!(builder, top_bb)
        end

        # XXX: put some of the above behind GPUCompiler abstractions
        # (e.g., a compile-time version of `deferred_codegen`)
    end
    return entry
end


## compiler implementation (cache, configure, compile, and link)

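The net effect of finish_module! above: any kernel whose module references the global_random_keys shared array gets a prologue block that resolves initialize_rng_state() through the deferred_codegen machinery and calls it before the original entry block runs. A rough Julia-level picture of the transformation (a sketch only; the names kernel_with_prologue and original_body are illustrative, and the real rewrite happens on LLVM IR):

function kernel_with_prologue(args...)
    initialize_rng_state()        # fills the per-warp shared state from kernel_state()
    return original_body(args...)
end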
11 changes: 8 additions & 3 deletions src/compiler/execution.jl
@@ -212,9 +212,9 @@ end
end
end

-    # add the kernel state
+    # add the kernel state, passing an instance with a unique seed
     pushfirst!(call_t, KernelState)
-    pushfirst!(call_args, :(kernel.state))
+    pushfirst!(call_args, :(KernelState(kernel.state.exception_flag, make_seed(kernel))))

# finalize types
call_tt = Base.to_tuple_type(call_t)
@@ -329,7 +329,7 @@ function cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
if kernel === nothing
# create the kernel state object
exception_ptr = create_exceptions!(fun.mod)
-        state = KernelState(exception_ptr)
+        state = KernelState(exception_ptr, UInt32(0))

kernel = HostKernel{F,tt}(f, fun, state)
_kernel_instances[key] = kernel
@@ -345,6 +345,8 @@ function (kernel::HostKernel)(args...; threads::CuDim=1, blocks::CuDim=1, kwargs...)
call(kernel, map(cudaconvert, args)...; threads, blocks, kwargs...)
end

make_seed(::HostKernel) = Random.rand(UInt32)


## device-side kernels

@@ -375,6 +377,9 @@ end

(kernel::DeviceKernel)(args...; kwargs...) = call(kernel, args...; kwargs...)

# re-use the parent kernel's seed to avoid needing an RNG on the device
make_seed(::DeviceKernel) = kernel_state().random_seed


## other

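The two make_seed methods split responsibility by launch context: host-side launches draw fresh entropy from the host RNG, while device-side (dynamic parallelism) launches propagate the seed they were given, since there is no host RNG to consult on the device. Note also that the KernelState cached in a compiled HostKernel keeps a placeholder seed of zero; the real seed is substituted at every call. A sketch of the per-launch flow, with hypothetical variable names:

kernel = cufunction(f, tt)              # cached; kernel.state carries seed 0x00000000
launch_state = KernelState(kernel.state.exception_flag,
                           make_seed(kernel))   # fresh Random.rand(UInt32) per launch
# launch_state is prepended to the kernel's argument list, so on the device
# kernel_state().random_seed differs from launch to launch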
39 changes: 21 additions & 18 deletions src/device/random.jl
@@ -6,7 +6,12 @@ import RandomNumbers

# global state

-# shared memory with the actual seed, per warp, loaded lazily or overridden by calling `seed!`
+# we cannot store RNG state in thread-local memory (i.e. in the `rng` object) because that
+# inflates register usage. instead, we store it in shared memory, with one entry per warp.
+#
+# XXX: this implies that state is shared between `rng` objects, which can be surprising.
+
+# array with seeds, per warp, initialized on kernel start or by calling `seed!`
@eval @inline function global_random_keys()
ptr = Base.llvmcall(
$("""@global_random_keys = weak addrspace($(AS.Shared)) global [32 x i32] zeroinitializer, align 32
@@ -20,7 +25,7 @@ import RandomNumbers
CuDeviceArray{UInt32,1,AS.Shared}(ptr, (32,))
end

-# shared memory with per-warp counters, incremented when generating numbers
+# array with per-warp counters, incremented when generating numbers
@eval @inline function global_random_counters()
ptr = Base.llvmcall(
$("""@global_random_counters = weak addrspace($(AS.Shared)) global [32 x i32] zeroinitializer, align 32
@@ -34,6 +39,17 @@ end
CuDeviceArray{UInt32,1,AS.Shared}(ptr, (32,))
end

# initialization function, called automatically at the start of each kernel because
# there's no reliable way to detect uninitialized shared memory (see JuliaGPU/CUDA.jl#2008)
function initialize_rng_state()
    threadId = threadIdx().x + (threadIdx().y - 1i32) * blockDim().x +
               (threadIdx().z - 1i32) * blockDim().x * blockDim().y
    warpId = (threadId - 1i32) >> 0x5 + 1i32 # fld1

    @inbounds global_random_keys()[warpId] = kernel_state().random_seed
    @inbounds global_random_counters()[warpId] = 0
end

@device_override Random.make_seed() = clock(UInt32)


@@ -43,19 +59,7 @@ using Random123: philox2x_round, philox2x_bumpkey

# GPU-compatible/optimized version of the generator from Random123.jl
struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64}
-    @inline function Philox2x32{R}() where R
-        rng = new{R}()
-        if rng.key == 0
-            # initialize the key. this happens when first accessing the (0-initialized)
-            # shared memory key from each block. if we ever want to make the device seed
-            # controlable from the host, this would be the place to read a global seed.
-            #
-            # note however that it is undefined how shared memory persists across e.g.
-            # launches, so we may not be able to rely on the zero initalization then.
-            rng.key = Random.make_seed()
-        end
-        return rng
-    end
+    # NOTE: the state is stored globally; see comments at the top of this file.
 end

# default to 7 rounds; enough to pass SmallCrush
@@ -66,9 +70,7 @@ end
                (threadIdx().z - 1i32) * blockDim().x * blockDim().y
     warpId = (threadId - 1i32) >> 0x5 + 1i32 # fld1

-    if field === :seed
-        @inbounds global_random_seed()[1]
-    elseif field === :key
+    if field === :key
         @inbounds global_random_keys()[warpId]
     elseif field === :ctr1
         @inbounds global_random_counters()[warpId]
@@ -139,6 +141,7 @@ function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
     # update the warp counter
     # NOTE: this performs the same update on every thread in the warp, but each warp writes
     #       to a unique location so the duplicate writes are innocuous
+    # NOTE: this is not guaranteed to be visible in other kernels (JuliaGPU/CUDA.jl#2008)
     # XXX: what if this overflows? we can't increment ctr2. bump the key?
     rng.ctr1 += 1i32

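For a concrete reading of the indexing used above: threads are linearized across the block, and each group of 32 consecutive threads (one warp) shares a single key/counter slot. A worked host-side example of the slot computation (the block dimensions are illustrative):

# a block of (64, 2, 1) threads = 128 threads = 4 warps
blockdim = (x = 64, y = 2, z = 1)
tid      = (x = 33, y = 2, z = 1)                  # a thread's 1-based coordinates
threadId = tid.x + (tid.y - 1) * blockdim.x +
           (tid.z - 1) * blockdim.x * blockdim.y   # == 97
warpId   = ((threadId - 1) >> 5) + 1               # == 4: the fourth warp's slot

Since the shared arrays hold 32 slots, a block can have at most 32 warps (1024 threads), which matches the hardware limit on threads per block.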
5 changes: 2 additions & 3 deletions src/device/runtime.jl
@@ -28,14 +28,13 @@ end

 struct KernelState
     exception_flag::Ptr{Cvoid}
+    random_seed::UInt32
 end

 @inline @generated kernel_state() = GPUCompiler.kernel_state_value(KernelState)

-exception_flag() = kernel_state().exception_flag
-
 function signal_exception()
-    ptr = exception_flag()
+    ptr = kernel_state().exception_flag
     if ptr !== C_NULL
         unsafe_store!(convert(Ptr{Int}, ptr), 1)
         threadfence_system()
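With the extra field, device code reads the launch-provided seed the same way it already reads the exception flag. A device-side sketch (variable names illustrative):

seed = kernel_state().random_seed      # the UInt32 chosen by make_seed at launch
flag = kernel_state().exception_flag   # the unchanged exception-signalling pointer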
11 changes: 7 additions & 4 deletions src/sorting.jl
@@ -459,8 +459,11 @@ end

#function sort

-function quicksort!(c::AbstractArray{T,N}; lt::F1, by::F2, dims::Int, partial_k=nothing, block_size_shift=0) where {T,N,F1,F2}
-    max_depth = CUDA.limit(CUDA.LIMIT_DEV_RUNTIME_SYNC_DEPTH)
+function quicksort!(c::AbstractArray{T,N}; lt::F1, by::F2, dims::Int, partial_k=nothing,
+                    block_size_shift=0) where {T,N,F1,F2}
+    # XXX: after JuliaGPU/CUDA.jl#2035, which changed the kernel state struct contents,
+    #      the max depth needed to be reduced by 1 to avoid an illegal memory crash...
+    max_depth = CUDA.limit(CUDA.LIMIT_DEV_RUNTIME_SYNC_DEPTH) - 1
len = size(c, dims)

1 <= dims <= N || throw(ArgumentError("dimension out of range"))
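The "- 1" above trades one level of device-side recursion for headroom: quicksort! relies on dynamic parallelism, and the device-runtime sync-depth limit bounds how deeply nested child kernels may still be synchronized. A sketch of the kind of guard such a limit implies (the depth_ok helper is hypothetical, not part of this file):

# recursion budget for the device-side quicksort
max_depth = CUDA.limit(CUDA.LIMIT_DEV_RUNTIME_SYNC_DEPTH) - 1
depth_ok(depth) = depth <= max_depth   # beyond this, avoid launching children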
@@ -884,11 +887,11 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false)
         # N_pseudo_blocks = how many pseudo-blocks are in this layer of the network
         N_pseudo_blocks = nextpow(2, c_len) ÷ pseudo_block_length
         pseudo_blocks_per_block = threads2 ÷ pseudo_block_length

         # grid dimensions
         N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block)
         block_size = pseudo_block_length, threads2 ÷ pseudo_block_length

         kernel1(args1...; blocks=N_blocks, threads=block_size,
                 shmem=bitonic_shmem(c, block_size))
         break
Expand Down
