rand: seed kernels from the host. #2035

Merged
merged 7 commits into from Aug 17, 2023
5 changes: 2 additions & 3 deletions .buildkite/pipeline.yml
@@ -19,7 +19,7 @@ steps:
cuda: "*"
commands: |
julia --project -e '
# make sure the 1.6-era Manifest works on this Julia version
# make sure the 1.7-era Manifest works on this Julia version
using Pkg
Pkg.resolve()

@@ -32,7 +32,6 @@ steps:
matrix:
setup:
julia:
- "1.6"
- "1.7"
- "1.8"
- "1.9"
@@ -315,10 +314,10 @@ steps:
matrix:
setup:
julia:
- "1.6"
- "1.7"
- "1.8"
- "1.9"
- "1.10"
- "nightly"
adjustments:
- with:
62 changes: 62 additions & 0 deletions src/compiler/compilation.jl
@@ -42,6 +42,68 @@ GPUCompiler.method_table(@nospecialize(job::CUDACompilerJob)) = method_table

GPUCompiler.kernel_state_type(job::CUDACompilerJob) = KernelState

function GPUCompiler.finish_module!(@nospecialize(job::CUDACompilerJob),
mod::LLVM.Module, entry::LLVM.Function)
entry = invoke(GPUCompiler.finish_module!,
Tuple{CompilerJob{PTXCompilerTarget}, LLVM.Module, LLVM.Function},
job, mod, entry)

# if this kernel uses our RNG, we should prime the shared state.
# XXX: these transformations should really happen at the Julia IR level...
if haskey(globals(mod), "global_random_keys")
f = initialize_rng_state
ft = typeof(f)
tt = Tuple{}

# don't recurse into `initialize_rng_state()` itself
if job.source.specTypes.parameters[1] == ft
return entry
end

# create a deferred compilation job for `initialize_rng_state()`
src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
cfg = CompilerConfig(job.config; kernel=false, name=nothing)
job = CompilerJob(src, cfg, job.world)
id = length(GPUCompiler.deferred_codegen_jobs) + 1
GPUCompiler.deferred_codegen_jobs[id] = job

# generate IR for calls to `deferred_codegen` and the resulting function pointer
top_bb = first(blocks(entry))
bb = BasicBlock(top_bb, "initialize_rng")
LLVM.@dispose builder=IRBuilder() begin
position!(builder, bb)
subprogram = LLVM.get_subprogram(entry)
if subprogram !== nothing
loc = DILocation(0, 0, subprogram)
debuglocation!(builder, loc)
end
debuglocation!(builder, first(instructions(top_bb)))

# call the `deferred_codegen` marker function
T_ptr = LLVM.Int64Type()
deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_ptr])
deferred_codegen = if haskey(functions(mod), "deferred_codegen")
functions(mod)["deferred_codegen"]
else
LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
end
fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])

# call the `initialize_rng_state` function
rt = Core.Compiler.return_type(f, tt)
llvm_rt = convert(LLVMType, rt)
llvm_ft = LLVM.FunctionType(llvm_rt)
fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
call!(builder, llvm_ft, fptr)
br!(builder, top_bb)
end

# XXX: put some of the above behind GPUCompiler abstractions
# (e.g., a compile-time version of `deferred_codegen`)
end
return entry
end


## compiler implementation (cache, configure, compile, and link)

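In effect, the `finish_module!` hook above rewrites any kernel whose module references `global_random_keys` so that execution starts by priming the per-warp RNG state. The following Julia sketch is conceptual only (an assumption for illustration; the real injection happens on LLVM IR through the deferred-codegen call, and kernel authors never write this call themselves):

using CUDA

function rng_kernel(out)
    # effectively prepended by the compiler whenever the kernel uses the device RNG;
    # written out explicitly here only to illustrate the IR injection above
    CUDA.initialize_rng_state()

    i = threadIdx().x
    @inbounds out[i] = rand(Float32)   # per-warp key now holds kernel_state().random_seed
    return
end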
11 changes: 8 additions & 3 deletions src/compiler/execution.jl
@@ -212,9 +212,9 @@ end
end
end

# add the kernel state
# add the kernel state, passing an instance with a unique seed
pushfirst!(call_t, KernelState)
pushfirst!(call_args, :(kernel.state))
pushfirst!(call_args, :(KernelState(kernel.state.exception_flag, make_seed(kernel))))

# finalize types
call_tt = Base.to_tuple_type(call_t)
@@ -328,7 +328,7 @@ function cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
if kernel === nothing
# create the kernel state object
exception_ptr = create_exceptions!(fun.mod)
state = KernelState(exception_ptr)
state = KernelState(exception_ptr, UInt32(0))

kernel = HostKernel{F,tt}(f, fun, state)
_kernel_instances[key] = kernel
@@ -344,6 +344,8 @@ function (kernel::HostKernel)(args...; threads::CuDim=1, blocks::CuDim=1, kwargs
call(kernel, map(cudaconvert, args)...; threads, blocks, kwargs...)
end

make_seed(::HostKernel) = Random.rand(UInt32)


## device-side kernels

@@ -374,6 +376,9 @@ end

(kernel::DeviceKernel)(args...; kwargs...) = call(kernel, args...; kwargs...)

# re-use the parent kernel's seed, avoiding the need for an RNG on the device
make_seed(::DeviceKernel) = kernel_state().random_seed


## other

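A short usage sketch of what the new `make_seed` methods imply for launches (assuming a working GPU and device-side `rand` support): each host-side launch draws a fresh `UInt32` seed, so repeated launches of the same compiled kernel observe different random streams, while device-side launches re-use the parent kernel's seed.

using CUDA

function fill_rand!(out)
    i = threadIdx().x
    @inbounds out[i] = rand(Float32)   # device RNG, keyed by the per-launch seed
    return
end

a = CUDA.zeros(Float32, 32)
b = CUDA.zeros(Float32, 32)
@cuda threads=32 fill_rand!(a)
@cuda threads=32 fill_rand!(b)
# different host-provided seeds, so (with overwhelming probability) different values
@assert Array(a) != Array(b)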
39 changes: 21 additions & 18 deletions src/device/random.jl
@@ -6,7 +6,12 @@ import RandomNumbers

# global state

# shared memory with the actual seed, per warp, loaded lazily or overridden by calling `seed!`
# we cannot store RNG state in thread-local memory (i.e. in the `rng` object) because that
# inflates register usage. instead, we store it in shared memory, with one entry per warp.
#
# XXX: this implies that state is shared between `rng` objects, which can be surprising.

# array with seeds, per warp, initialized on kernel start or by calling `seed!`
@eval @inline function global_random_keys()
ptr = Base.llvmcall(
$("""@global_random_keys = weak addrspace($(AS.Shared)) global [32 x i32] zeroinitializer, align 32
@@ -20,7 +25,7 @@ import RandomNumbers
CuDeviceArray{UInt32,1,AS.Shared}(ptr, (32,))
end

# shared memory with per-warp counters, incremented when generating numbers
# array with per-warp counters, incremented when generating numbers
@eval @inline function global_random_counters()
ptr = Base.llvmcall(
$("""@global_random_counters = weak addrspace($(AS.Shared)) global [32 x i32] zeroinitializer, align 32
@@ -34,6 +39,17 @@ end
CuDeviceArray{UInt32,1,AS.Shared}(ptr, (32,))
end

# initialization function, called automatically at the start of each kernel because
# there's no reliable way to detect uninitialized shared memory (see JuliaGPU/CUDA.jl#2008)
function initialize_rng_state()
threadId = threadIdx().x + (threadIdx().y - 1i32) * blockDim().x +
(threadIdx().z - 1i32) * blockDim().x * blockDim().y
warpId = (threadId - 1i32) >> 0x5 + 1i32 # fld1

@inbounds global_random_keys()[warpId] = kernel_state().random_seed
@inbounds global_random_counters()[warpId] = 0
end

@device_override Random.make_seed() = clock(UInt32)


@@ -43,19 +59,7 @@ using Random123: philox2x_round, philox2x_bumpkey

# GPU-compatible/optimized version of the generator from Random123.jl
struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64}
@inline function Philox2x32{R}() where R
rng = new{R}()
if rng.key == 0
# initialize the key. this happens when first accessing the (0-initialized)
# shared memory key from each block. if we ever want to make the device seed
# controllable from the host, this would be the place to read a global seed.
#
# note however that it is undefined how shared memory persists across e.g.
# launches, so we may not be able to rely on the zero initialization then.
rng.key = Random.make_seed()
end
return rng
end
# NOTE: the state is stored globally; see comments at the top of this file.
end

# default to 7 rounds; enough to pass SmallCrush
@@ -66,9 +70,7 @@ end
(threadIdx().z - 1i32) * blockDim().x * blockDim().y
warpId = (threadId - 1i32) >> 0x5 + 1i32 # fld1

if field === :seed
@inbounds global_random_seed()[1]
elseif field === :key
if field === :key
@inbounds global_random_keys()[warpId]
elseif field === :ctr1
@inbounds global_random_counters()[warpId]
@@ -139,6 +141,7 @@ function Random.rand(rng::Philox2x32{R},::Type{UInt64}) where {R}
# update the warp counter
# NOTE: this performs the same update on every thread in the warp, but each warp writes
# to a unique location so the duplicate writes are innocuous
# NOTE: this is not guaranteed to be visible in other kernels (JuliaGPU/CUDA.jl#2008)
# XXX: what if this overflows? we can't increment ctr2. bump the key?
rng.ctr1 += 1i32

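As an aside, the warp index computed in `initialize_rng_state` and in the `getproperty` accessor above is just `fld1(threadId, warpsize)` spelled with a shift; a host-side sanity check of that arithmetic (illustration only, assuming the usual warp size of 32):

# fld1-style warp index backing the 32-entry shared-memory arrays
warp_index(threadId) = (threadId - 1) >> 5 + 1

@assert warp_index(1)    == 1   # threads 1..32 map to slot 1
@assert warp_index(32)   == 1
@assert warp_index(33)   == 2   # threads 33..64 map to slot 2
@assert warp_index(1024) == 32  # a full 1024-thread block uses all 32 slots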
5 changes: 2 additions & 3 deletions src/device/runtime.jl
@@ -28,14 +28,13 @@ end

struct KernelState
exception_flag::Ptr{Cvoid}
random_seed::UInt32
end

@inline @generated kernel_state() = GPUCompiler.kernel_state_value(KernelState)

exception_flag() = kernel_state().exception_flag

function signal_exception()
ptr = exception_flag()
ptr = kernel_state().exception_flag
if ptr !== C_NULL
unsafe_store!(convert(Ptr{Int}, ptr), 1)
threadfence_system()
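Because the state object is passed to every kernel as an extra argument, it has to remain a plain bits type, and the added `UInt32` grows it slightly (which is what the sorting workaround below reacts to). A quick host-side check of the layout, assuming a typical 64-bit host:

struct KernelState
    exception_flag::Ptr{Cvoid}
    random_seed::UInt32
end

@assert isbitstype(KernelState)
@assert sizeof(KernelState) == 16   # 8-byte pointer + 4-byte seed + 4 bytes of padding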
11 changes: 7 additions & 4 deletions src/sorting.jl
@@ -459,8 +459,11 @@ end

#function sort

function quicksort!(c::AbstractArray{T,N}; lt::F1, by::F2, dims::Int, partial_k=nothing, block_size_shift=0) where {T,N,F1,F2}
max_depth = CUDA.limit(CUDA.LIMIT_DEV_RUNTIME_SYNC_DEPTH)
function quicksort!(c::AbstractArray{T,N}; lt::F1, by::F2, dims::Int, partial_k=nothing,
block_size_shift=0) where {T,N,F1,F2}
# XXX: after JuliaGPU/CUDA.jl#2035, which changed the kernel state struct contents,
# the max depth needed to be reduced by 1 to avoid an illegal memory access...
max_depth = CUDA.limit(CUDA.LIMIT_DEV_RUNTIME_SYNC_DEPTH) - 1
len = size(c, dims)

1 <= dims <= N || throw(ArgumentError("dimension out of range"))
@@ -884,11 +887,11 @@ function bitonic_sort!(c; by = identity, lt = isless, rev = false)
# N_pseudo_blocks = how many pseudo-blocks are in this layer of the network
N_pseudo_blocks = nextpow(2, c_len) ÷ pseudo_block_length
pseudo_blocks_per_block = threads2 ÷ pseudo_block_length

# grid dimensions
N_blocks = max(1, N_pseudo_blocks ÷ pseudo_blocks_per_block)
block_size = pseudo_block_length, threads2 ÷ pseudo_block_length

kernel1(args1...; blocks=N_blocks, threads=block_size,
shmem=bitonic_shmem(c, block_size))
break