diff --git a/Project.toml b/Project.toml index 7e22dd20..59d9373f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KernelAbstractions" uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c" -authors = ["Valentin Churavy and contributors"] version = "0.10.0-dev" +authors = ["Valentin Churavy and contributors"] [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -13,6 +13,9 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" OpenCL_jll = "6cb37087-e8b6-5417-8430-1f242f1e46e4" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Random123 = "74087812-796a-5b5d-8853-05524746bad3" +RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143" SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8" SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c" SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4" @@ -34,12 +37,15 @@ SparseArraysExt = "SparseArrays" Adapt = "0.4, 1.0, 2.0, 3.0, 4" Atomix = "0.1, 1" EnzymeCore = "0.7, 0.8.1" -GPUCompiler = "1.6" +GPUCompiler = "1.7.1" InteractiveUtils = "1.6" LLVM = "9.4.1" LinearAlgebra = "1.6" MacroTools = "0.5" PrecompileTools = "1" +Random = "1" +Random123 = "1.7.1" +RandomNumbers = "1.6.0" SPIRVIntrinsics = "0.5" SPIRV_LLVM_Backend_jll = "20" SPIRV_Tools_jll = "2024.4, 2025.1" diff --git a/src/pocl/compiler/compilation.jl b/src/pocl/compiler/compilation.jl index 3d5930d4..fb9f9585 100644 --- a/src/pocl/compiler/compilation.jl +++ b/src/pocl/compiler/compilation.jl @@ -19,6 +19,87 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) = in(fn, known_intrinsics) || contains(fn, "__spirv_") +GPUCompiler.kernel_state_type(::OpenCLCompilerJob) = KernelState + +function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob), + mod::LLVM.Module, entry::LLVM.Function) + entry = invoke(GPUCompiler.finish_module!, + Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function}, + job, mod, entry) + + # if this kernel uses our RNG, we should prime the shared state. + # XXX: these transformations should really happen at the Julia IR level... 
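+    # (the code below creates a deferred compilation job for `initialize_rng_state`
+    # and emits an indirect call to it at the top of the kernel; the function pointer
+    # is resolved at link time through GPUCompiler's `deferred_codegen` machinery)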
+    if haskey(functions(mod), "julia.opencl.random_keys") && job.config.kernel
+        # insert call to `initialize_rng_state`
+        f = initialize_rng_state
+        ft = typeof(f)
+        tt = Tuple{}
+
+        # create a deferred compilation job for `initialize_rng_state`
+        src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
+        cfg = CompilerConfig(job.config; kernel=false, name=nothing)
+        job = CompilerJob(src, cfg, job.world)
+        id = length(GPUCompiler.deferred_codegen_jobs) + 1
+        GPUCompiler.deferred_codegen_jobs[id] = job
+
+        # generate IR for calls to `deferred_codegen` and the resulting function pointer
+        top_bb = first(blocks(entry))
+        bb = BasicBlock(top_bb, "initialize_rng")
+        @dispose builder=IRBuilder() begin
+            position!(builder, bb)
+            subprogram = LLVM.subprogram(entry)
+            if subprogram !== nothing
+                loc = DILocation(0, 0, subprogram)
+                debuglocation!(builder, loc)
+            end
+
+            # call the `deferred_codegen` marker function
+            T_ptr = if LLVM.version() >= v"17"
+                LLVM.PointerType()
+            elseif VERSION >= v"1.12.0-DEV.225"
+                LLVM.PointerType(LLVM.Int8Type())
+            else
+                LLVM.Int64Type()
+            end
+            T_id = convert(LLVMType, Int)
+            deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_id])
+            deferred_codegen = if haskey(functions(mod), "deferred_codegen")
+                functions(mod)["deferred_codegen"]
+            else
+                LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
+            end
+            fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])
+
+            # call the `initialize_rng_state` function
+            rt = Core.Compiler.return_type(f, tt)
+            llvm_rt = convert(LLVMType, rt)
+            llvm_ft = LLVM.FunctionType(llvm_rt)
+            fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
+            call!(builder, llvm_ft, fptr)
+            br!(builder, top_bb)
+
+            # note the use of the device-side RNG in this kernel
+            push!(function_attributes(entry), StringAttribute("julia.opencl.rng", ""))
+        end
+
+        # XXX: put some of the above behind GPUCompiler abstractions
+        #      (e.g., a compile-time version of `deferred_codegen`)
+    end
+    return entry
+end
+
+function GPUCompiler.finish_linked_module!(@nospecialize(job::OpenCLCompilerJob), mod::LLVM.Module)
+    kernel_intrinsics = Dict(
+        "julia.opencl.random_keys" => (; name = "random_keys", typ = LLVMPtr{UInt32, AS.Workgroup}),
+        "julia.opencl.random_counters" => (; name = "random_counters", typ = LLVMPtr{UInt32, AS.Workgroup}),
+    )
+    for f in GPUCompiler.kernels(mod)
+        GPUCompiler.add_input_arguments!(job, mod, f, kernel_intrinsics)
+    end
+    return
+end
+
 ## compiler implementation (cache, configure, compile, and link)

@@ -60,10 +141,13 @@ end

 function compile(@nospecialize(job::CompilerJob))
     # TODO: this creates a context; cache those.
obj, meta = JuliaContext() do ctx - GPUCompiler.compile(:obj, job) - end + obj, meta = GPUCompiler.compile(:obj, job) - return (; obj, entry = LLVM.name(meta.entry)) + entry = LLVM.name(meta.entry) + device_rng = StringAttribute("julia.opencl.rng", "") in collect(function_attributes(meta.entry)) + + (; obj, entry, device_rng) + end end # link into an executable kernel @@ -74,5 +158,5 @@ function link(@nospecialize(job::CompilerJob), compiled) error("Your device does not support SPIR-V, which is currently required for native execution.") end cl.build!(prog) - return cl.Kernel(prog, compiled.entry) + (; kernel=cl.Kernel(prog, compiled.entry), compiled.device_rng) end diff --git a/src/pocl/compiler/execution.jl b/src/pocl/compiler/execution.jl index dc47cb30..8fe66260 100644 --- a/src/pocl/compiler/execution.jl +++ b/src/pocl/compiler/execution.jl @@ -146,6 +146,8 @@ end abstract type AbstractKernel{F, TT} end +pass_arg(@nospecialize dt) = !(GPUCompiler.isghosttype(dt) || Core.Compiler.isconstType(dt)) + @inline @generated function (kernel::AbstractKernel{F, TT})( args...; call_kwargs... @@ -154,8 +156,7 @@ abstract type AbstractKernel{F, TT} end args = (:(kernel.f), (:(clconvert(args[$i], svm_pointers)) for i in 1:length(args))...) # filter out ghost arguments that shouldn't be passed - predicate = dt -> GPUCompiler.isghosttype(dt) || Core.Compiler.isconstType(dt) - to_pass = map(!predicate, sig.parameters) + to_pass = map(pass_arg, sig.parameters) call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]] call_args = Union{Expr, Symbol}[x[1] for x in zip(args, to_pass) if x[2]] @@ -167,12 +168,15 @@ abstract type AbstractKernel{F, TT} end end end + pushfirst!(call_t, KernelState) + pushfirst!(call_args, :(KernelState(kernel.rng_state ? Base.rand(UInt32) : UInt32(0)))) + # finalize types call_tt = Base.to_tuple_type(call_t) return quote svm_pointers = Ptr{Cvoid}[] - $cl.clcall(kernel.fun, $call_tt, $(call_args...); svm_pointers, call_kwargs...) + $cl.clcall(kernel.fun, $call_tt, $(call_args...); svm_pointers, kernel.rng_state, call_kwargs...) end end @@ -182,6 +186,7 @@ end struct HostKernel{F, TT} <: AbstractKernel{F, TT} f::F fun::cl.Kernel + rng_state::Bool end @@ -198,15 +203,15 @@ function clfunction(f::F, tt::TT = Tuple{}; kwargs...) where {F, TT} cache = compiler_cache(ctx) source = methodinstance(F, tt) config = compiler_config(dev; kwargs...)::OpenCLCompilerConfig - fun = GPUCompiler.cached_compilation(cache, source, config, compile, link) + linked = GPUCompiler.cached_compilation(cache, source, config, compile, link) # create a callable object that captures the function instance. 
we don't need to think
     # about world age here, as GPUCompiler already does and will return a different object
-    h = hash(fun, hash(f, hash(tt)))
+    h = hash(linked.kernel, hash(f, hash(tt)))
     kernel = get(_kernel_instances, h, nothing)
     if kernel === nothing
         # create the kernel state object
-        kernel = HostKernel{F, tt}(f, fun)
+        kernel = HostKernel{F, tt}(f, linked.kernel, linked.device_rng)
         _kernel_instances[h] = kernel
     end
     return kernel::HostKernel{F, tt}
diff --git a/src/pocl/device/random.jl b/src/pocl/device/random.jl
new file mode 100644
index 00000000..b70ce781
--- /dev/null
+++ b/src/pocl/device/random.jl
@@ -0,0 +1,234 @@
+## random number generation
+
+using Random
+import RandomNumbers
+
+# local memory with the actual seed, per subgroup, set by `initialize_rng_state` or overridden by calling `seed!`
+@inline function global_random_keys()
+    n = get_num_sub_groups()
+    ptr = random_keys()::LLVMPtr{UInt32, AS.Workgroup}
+    return CLDeviceArray{UInt32, 1, AS.Workgroup}((n,), ptr)
+end
+
+# local memory with per-subgroup counters, incremented when generating numbers
+@inline function global_random_counters()
+    n = get_num_sub_groups()
+    ptr = random_counters()::LLVMPtr{UInt32, AS.Workgroup}
+    return CLDeviceArray{UInt32, 1, AS.Workgroup}((n,), ptr)
+end
+
+# initialization function, called automatically at the start of each kernel
+function initialize_rng_state()
+    subgroup_id = get_sub_group_id()
+    @inbounds global_random_keys()[subgroup_id] = kernel_state().random_seed
+    @inbounds global_random_counters()[subgroup_id] = 0
+end
+
+
+# generators
+
+using Random123: philox2x_round, philox2x_bumpkey
+
+# GPU-compatible/optimized version of the generator from Random123.jl
+struct Philox2x32{R} <: RandomNumbers.AbstractRNG{UInt64} end
+
+# default to 7 rounds; enough to pass SmallCrush
+@inline Philox2x32() = Philox2x32{7}()
+
+@inline function Base.getproperty(rng::Philox2x32, field::Symbol)
+    subgroup_id = get_sub_group_id()
+
+    if field === :key
+        @inbounds global_random_keys()[subgroup_id]
+    elseif field === :ctr1
+        @inbounds global_random_counters()[subgroup_id]
+    elseif field === :ctr2
+        unsafe_trunc(UInt32, get_global_linear_id())
+    end
+end
+
+@inline function Base.setproperty!(rng::Philox2x32, field::Symbol, x)
+    subgroup_id = get_sub_group_id()
+
+    if field === :key
+        @inbounds global_random_keys()[subgroup_id] = x
+    elseif field === :ctr1
+        @inbounds global_random_counters()[subgroup_id] = x
+    end
+    return rng
+end
+
+@device_override @inline Random.default_rng() = Philox2x32()
+
+"""
+    Random.seed!(rng::Philox2x32, seed::Integer, [counter::Integer=0])
+
+Seed the on-device Philox2x32 generator with a UInt32 number.
+Should be called by at least one work-item per subgroup.
+"""
+function Random.seed!(rng::Philox2x32, seed::Integer, counter::Integer=UInt32(0))
+    rng.key = seed % UInt32
+    rng.ctr1 = counter
+    return
+end
+
+# seeding the implicit default RNG
+@static if VERSION >= v"1.11-"
+    @device_override Random.seed!(seed) =
+        Random.seed!(Random.default_rng(), seed)
+else
+    @device_override Random.seed!(::Random._GLOBAL_RNG, seed) =
+        Random.seed!(Random.default_rng(), seed)
+end
+
+@static if VERSION >= v"1.11-"
+    # `Random.seed!(::AbstractRNG)` now passes a `nothing` seed value
+    # TODO: I don't think there is any way in OpenCL to make this nondeterministic
+    Random.seed!(rng::Philox2x32, seed::Nothing) =
+        Random.seed!(rng, kernel_state().random_seed)
+else
+    # ...
where it used to call `Random.make_seed()`
+    @device_override Random.make_seed() = kernel_state().random_seed
+end
+
+"""
+    Random.rand(rng::Philox2x32, ::Type{UInt64})
+
+Generate a 64-bit random number using the on-device Philox2x32 generator.
+"""
+function Random.rand(rng::Philox2x32{R}, ::Type{UInt64}) where {R}
+    ctr1, ctr2, key = rng.ctr1, rng.ctr2, rng.key
+
+    if R > 0                                  ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 1   key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 2   key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 3   key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 4   key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 5   key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 6   key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 7   key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 8   key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 9   key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 10  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 11  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 12  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 13  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 14  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+    if R > 15  key = philox2x_bumpkey(key); ctr1, ctr2 = philox2x_round(ctr1, ctr2, key); end
+
+    # update the subgroup counter
+    # NOTE: this performs the same update on every work-item in the subgroup, but each
+    #       subgroup writes to a unique location so the duplicate writes are innocuous
+    # XXX: what if this overflows? we can't increment ctr2. bump the key?
+    rng.ctr1 += Int32(1)
+
+    # NOTE: it's too expensive to keep both numbers around in case the user only wanted one,
+    #       so just make our 2x32 generator return 64-bit numbers by default.
+    return (ctr1 % UInt64) << 32 | (ctr2 % UInt64)
+end
+
+
+
+# a hacky method of exposing constant tables as constant GPU memory
+
+function emit_constant_array(name::Symbol, data::AbstractArray{T}) where {T}
+    @dispose ctx=Context() begin
+        T_val = convert(LLVMType, T)
+        T_ptr = convert(LLVMType, LLVMPtr{T,AS.UniformConstant})
+
+        # define function and get LLVM module
+        llvm_f, _ = create_function(T_ptr)
+        mod = LLVM.parent(llvm_f)
+
+        # create a global variable in constant memory
+        # TODO: global_var alignment?
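+        # (the table is emitted as an internal global in the `UniformConstant` address
+        # space; the generated function returns a device pointer to its first element)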
+ T_global = LLVM.ArrayType(T_val, length(data)) + # XXX: why can't we use a single name like emit_shmem + gv = GlobalVariable(mod, T_global, "gpu_$(name)_data", AS.UniformConstant) + linkage!(gv, LLVM.API.LLVMInternalLinkage) + initializer!(gv, ConstantArray(data)) + alignment!(gv, 16) + + # generate IR + @dispose builder=IRBuilder() begin + entry = BasicBlock(llvm_f, "entry") + position!(builder, entry) + + ptr = gep!(builder, T_global, gv, [ConstantInt(0), ConstantInt(0)]) + + untyped_ptr = bitcast!(builder, ptr, T_ptr) + + ret!(builder, untyped_ptr) + end + + call_function(llvm_f, LLVMPtr{T,AS.UniformConstant}) + end +end + +for var in [:ki, :wi, :fi, :ke, :we, :fe] + val = getfield(Random, var) + gpu_var = Symbol("gpu_$var") + arr_typ = :(CLDeviceArray{$(eltype(val)),$(ndims(val)),AS.UniformConstant}) + @eval @inline @generated function $gpu_var() + ptr = emit_constant_array($(QuoteNode(var)), $val) + Expr(:call, $arr_typ, $(size(val)), ptr) + end +end + +## randn + +@device_override @inline function Random.randn(rng::AbstractRNG) + @label retry + r = Random.rand(rng, Random.UInt52Raw()) + @inbounds begin + r &= 0x000fffffffffffff + rabs = Int64(r >> 1) # One bit for the sign + idx = rabs & 0xFF + x = ifelse(r % Bool, -rabs, rabs)*gpu_wi()[idx+1] + rabs < gpu_ki()[idx+1] && return x # 99.3% of the time we return here 1st try + # TODO: This code could be outlined once LLVM supports LDS access in recursively-called functions + @inbounds if idx == 0 + while true + xx = -Random.ziggurat_nor_inv_r*log(Random.rand(rng)) + yy = -log(Random.rand(rng)) + yy+yy > xx*xx && + return (rabs >> 8) % Bool ? -Random.ziggurat_nor_r-xx : Random.ziggurat_nor_r+xx + end + elseif (gpu_fi()[idx] - gpu_fi()[idx+1])*Random.rand(rng) + gpu_fi()[idx+1] < exp(-0.5*x*x) + return x # return from the triangular area + else + @goto retry + end + end +end + +@device_override @inline function Random.randn(rng::AbstractRNG, ::Type{T}) where {T <: Union{Float16, Float32}} + @invoke Random.randn(rng::AbstractRNG, T::Type{<:AbstractFloat}) +end + +## randexp + +@device_override @inline function Random.randexp(rng::AbstractRNG) + @label retry + ri = Random.rand(rng, Random.UInt52Raw()) + @inbounds begin + ri &= 0x000fffffffffffff + idx = ri & 0xFF + x = ri*gpu_we()[idx+1] + ri < gpu_ke()[idx+1] && return x # 98.9% of the time we return here 1st try + # TODO: This code could be outlined once LLVM supports LDS access in recursively-called functions + @inbounds if idx == 0 + return Random.ziggurat_exp_r - log(Random.rand(rng)) + elseif (gpu_fe()[idx] - gpu_fe()[idx+1])*Random.rand(rng) + gpu_fe()[idx+1] < exp(-x) + return x # return from the triangular area + else + @goto retry + end + end +end + +@device_override @inline function Random.randexp(rng::AbstractRNG, ::Type{T}) where {T <: Union{Float16, Float32}} + @invoke Random.randexp(rng::AbstractRNG, T::Type{<:AbstractFloat}) +end + +@device_override Random.Sampler(::Type{<:AbstractRNG}, r::AbstractUnitRange{T}, + ::Random.Repetition) where {T<:Union{Int64, UInt64}} = Random.SamplerRangeFast(r) diff --git a/src/pocl/device/runtime.jl b/src/pocl/device/runtime.jl index f980f683..b6a1aa45 100644 --- a/src/pocl/device/runtime.jl +++ b/src/pocl/device/runtime.jl @@ -30,3 +30,62 @@ function report_exception_frame(idx, func, file, line) SPIRVIntrinsics.@printf(" [%d] %s at %s:%d\n", idx, func, file, line) return end + +## kernel state + +struct KernelState + random_seed::UInt32 +end + +@inline @generated kernel_state() = GPUCompiler.kernel_state_value(KernelState) + +## 
intrinsics for adding and accessing additional kernel arguments

+# The amount of local memory needed to store the RNG state is only known at kernel
+# launch time, so the state has to be passed to the kernel as additional arguments.
+# We define intrinsics that get rewritten into such kernel arguments, which are then
+# threaded through the call graph up to the kernel entry point.
+
+function additional_arg_intr(mod::LLVM.Module, T_state, name)
+    state_intr = if haskey(functions(mod), "julia.opencl.$name")
+        functions(mod)["julia.opencl.$name"]
+    else
+        LLVM.Function(mod, "julia.opencl.$name", LLVM.FunctionType(T_state))
+    end
+    push!(function_attributes(state_intr), EnumAttribute("readnone", 0))
+
+    return state_intr
+end
+
+# run-time equivalent
+function additional_arg_value(state, name)
+    @dispose ctx=Context() begin
+        T_state = convert(LLVMType, state)
+
+        # create function
+        llvm_f, _ = create_function(T_state)
+        mod = LLVM.parent(llvm_f)
+
+        # get intrinsic
+        state_intr = additional_arg_intr(mod, T_state, name)
+        state_intr_ft = function_type(state_intr)
+
+        # generate IR
+        @dispose builder=IRBuilder() begin
+            entry = BasicBlock(llvm_f, "entry")
+            position!(builder, entry)
+
+            val = call!(builder, state_intr_ft, state_intr, Value[], name)
+
+            ret!(builder, val)
+        end
+
+        call_function(llvm_f, state)
+    end
+end
+
+for name in [:random_keys, :random_counters]
+    @eval @inline @generated $name() =
+        additional_arg_value(LLVMPtr{UInt32, AS.Workgroup}, $(String(name)))
+end
diff --git a/src/pocl/nanoOpenCL.jl b/src/pocl/nanoOpenCL.jl
index a706710d..8aeb08be 100644
--- a/src/pocl/nanoOpenCL.jl
+++ b/src/pocl/nanoOpenCL.jl
@@ -629,6 +629,22 @@ end
     )::cl_int
 end

+@checked function clGetKernelSubGroupInfo(
+        kernel, device, param_name, input_value_size,
+        input_value, param_value_size, param_value,
+        param_value_size_ret
+    )
+    @ccall libopencl.clGetKernelSubGroupInfo(
+        kernel::cl_kernel, device::cl_device_id,
+        param_name::cl_kernel_sub_group_info,
+        input_value_size::Csize_t,
+        input_value::Ptr{Cvoid},
+        param_value_size::Csize_t,
+        param_value::Ptr{Cvoid},
+        param_value_size_ret::Ptr{Csize_t}
+    )::cl_int
+end
+
 @checked function clEnqueueNDRangeKernel(
         command_queue, kernel, work_dim, global_work_offset,
         global_work_size,
@@ -1227,7 +1243,7 @@ end

 function enqueue_kernel(
         k::Kernel, global_work_size, local_work_size = nothing;
-        global_work_offset = nothing
+        global_work_offset = nothing, rng_state = false, nargs = nothing
     )
     max_work_dim = device().max_work_item_dims
     work_dim = length(global_work_size)
@@ -1271,6 +1287,20 @@
         # null local size means OpenCL decides
     end

+    if rng_state
+        if local_work_size !== nothing
+            num_sub_groups = KernelSubGroupInfo(k, device(), lsize).sub_group_count
+        else
+            num_sub_groups = KernelSubGroupInfo(k, device(), Csize_t[]).max_num_sub_groups
+        end
+        if nargs === nothing
+            nargs = k.num_args - 2
+        end
+        rng_state_size = sizeof(UInt32) * num_sub_groups
+        set_arg!(k, nargs + 1, LocalMem(UInt32, rng_state_size))
+        set_arg!(k, nargs + 2, LocalMem(UInt32, rng_state_size))
+    end
+
     n_events = cl_uint(0)
     wait_event_ids = C_NULL
     ret_event = Ref{cl_event}()
@@ -1285,7 +1315,8 @@ end

 function call(
         k::Kernel, args...; global_size = (1,), local_size = nothing,
         global_work_offset = nothing,
-        svm_pointers::Vector{Ptr{Cvoid}} = Ptr{Cvoid}[]
+        svm_pointers::Vector{Ptr{Cvoid}} = Ptr{Cvoid}[],
+        rng_state = false
     )
     set_args!(k, args...)
if !isempty(svm_pointers) @@ -1294,7 +1325,7 @@ function call( sizeof(svm_pointers), svm_pointers ) end - return enqueue_kernel(k, global_size, local_size; global_work_offset) + return enqueue_kernel(k, global_size, local_size; global_work_offset, rng_state, nargs=length(args)) end # convert the argument values to match the kernel's signature (specified by the user) @@ -1367,6 +1398,37 @@ function Base.getproperty(ki::KernelWorkGroupInfo, s::Symbol) end end +struct KernelSubGroupInfo + kernel::Kernel + device::Device + local_work_size::Vector{Csize_t} +end +sub_group_info(k::Kernel, d::Device, lsize::Vector{Csize_t}) = KernelSubGroupInfo(k, d, lsize) + +function Base.getproperty(ki::KernelSubGroupInfo, s::Symbol) + k = getfield(ki, :kernel) + d = getfield(ki, :device) + lsize = getfield(ki, :local_work_size) + + function get(val, typ) + result = Ref{typ}() + clGetKernelSubGroupInfo(k, d, val, sizeof(lsize), lsize, sizeof(typ), result, C_NULL) + return result[] + end + + return if s == :max_sub_group_size + Int(get(CL_KERNEL_MAX_SUB_GROUP_SIZE_FOR_NDRANGE, Csize_t)) + elseif s == :sub_group_count + Int(get(CL_KERNEL_SUB_GROUP_COUNT_FOR_NDRANGE, Csize_t)) + elseif s == :max_num_sub_groups + Int(get(CL_KERNEL_MAX_NUM_SUB_GROUPS, Csize_t)) + elseif s == :compile_num_sub_groups + Int(get(CL_KERNEL_COMPILE_NUM_SUB_GROUPS, Csize_t)) + else + getfield(ki, s) + end +end + mutable struct CmdQueue const id::cl_command_queue diff --git a/src/pocl/pocl.jl b/src/pocl/pocl.jl index 1cc693c8..482547cf 100644 --- a/src/pocl/pocl.jl +++ b/src/pocl/pocl.jl @@ -41,7 +41,7 @@ function queue() end using GPUCompiler -import LLVM +using LLVM using Adapt ## device overrides @@ -60,6 +60,7 @@ import Core: LLVMPtr include("device/array.jl") include("device/quirks.jl") include("device/runtime.jl") +include("device/random.jl") function Adapt.adapt_storage(to::KernelAdaptor, xs::Array{T, N}) where {T, N} return CLDeviceArray{T, N, AS.CrossWorkgroup}(size(xs), reinterpret(LLVMPtr{T, AS.CrossWorkgroup}, pointer(xs))) diff --git a/test/random.jl b/test/random.jl new file mode 100644 index 00000000..b098de63 --- /dev/null +++ b/test/random.jl @@ -0,0 +1,166 @@ +using Random + +const n = 256 + +function apply_seed(seed) + if seed === missing + # should result in different numbers across launches + Random.seed!() + # XXX: this currently doesn't work, because of the definition in Base, + # `seed!(r::MersenneTwister=default_rng())`, which breaks overriding + # `default_rng` with a non-MersenneTwister RNG. + elseif seed !== nothing + # should result in the same numbers + Random.seed!(seed) + elseif seed === nothing + # should result in different numbers across launches, + # as determined by the seed set during module loading. 
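+        # (with this patch, an unseeded stream differs between launches because each
+        # launch embeds a fresh host-side `Base.rand(UInt32)` seed in the kernel state)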
+ end +end + +function random_testsuite(backend) + eltypes = [Float16, Float32, Float64, Int32, UInt32, Int64, UInt64, Bool, UInt16] + + @testset "rand($T), seed $seed" for T in eltypes, seed in (nothing, #=missing,=# 1234) + # different kernel invocations should get different numbers + @testset "across launches" begin + @kernel function kernel(A::AbstractArray{T}, seed) where {T} + apply_seed(seed) + tid = @index(Global, Linear) + @inbounds A[tid] = rand(T) + end + + a = KernelAbstractions.zeros(backend(), T, n) + b = KernelAbstractions.zeros(backend(), T, n) + + kernel(backend())(a, seed, ndrange=n, workgroupsize=n) + KernelAbstractions.synchronize(backend()) + kernel(backend())(b, seed, ndrange=n, workgroupsize=n) + KernelAbstractions.synchronize(backend()) + + if seed === nothing || seed === missing + @test Array(a) != Array(b) + else + @test Array(a) == Array(b) + end + end + + # multiple calls to rand should get different numbers + @testset "across calls" begin + @kernel function kernel(A::AbstractArray{T}, B::AbstractArray{T}, seed) where {T} + apply_seed(seed) + tid = @index(Global, Linear) + @inbounds A[tid] = rand(T) + @inbounds B[tid] = rand(T) + end + + a = KernelAbstractions.zeros(backend(), T, n) + b = KernelAbstractions.zeros(backend(), T, n) + + kernel(backend())(a, b, seed, ndrange=n, workgroupsize=n) + KernelAbstractions.synchronize(backend()) + + @test Array(a) != Array(b) + end + + if T != Bool + # different threads should get different numbers + @testset "across threads, dim $active_dim" for active_dim in 1:6 + @kernel function kernel(A::AbstractArray{T}, seed) where {T} + apply_seed(seed) + lid = @index(Local, NTuple) + gid = @index(Group, NTuple) + id = lid[1] * lid[2] * lid[3] * gid[1] * gid[2] * gid[3] + if 1 <= id <= length(A) + @inbounds A[id] = rand(T) + end + end + + tx, ty, tz, bx, by, bz = [dim == active_dim ? 3 : 1 for dim in 1:6] + gx, gy, gz = tx*bx, ty*by, tz*bz + a = KernelAbstractions.zeros(backend(), T, 3) + + kernel(backend())(a, seed, ndrange=(gx, gy, gz), workgroupsize=(tx, ty, tz)) + KernelAbstractions.synchronize(backend()) + + # NOTE: we don't just generate two numbers and compare them, instead generating a + # couple more and checking they're not all the same, in order to avoid + # occasional collisions with lower-precision types (i.e., Float16). 
@test length(unique(Array(a))) > 1
+            end
+        end
+    end
+
+    @testset "basic randn($T), seed $seed" for T in filter(x -> x <: Base.IEEEFloat, eltypes), seed in (nothing, #=missing,=# 1234)
+        @kernel function kernel(A::AbstractArray{T}, seed) where {T}
+            apply_seed(seed)
+            tid = @index(Global, Linear)
+            @inbounds A[tid] = randn(T)
+        end
+
+        a = KernelAbstractions.zeros(backend(), T, n)
+        b = KernelAbstractions.zeros(backend(), T, n)
+
+        kernel(backend())(a, seed, ndrange=n, workgroupsize=n)
+        KernelAbstractions.synchronize(backend())
+        kernel(backend())(b, seed, ndrange=n, workgroupsize=n)
+        KernelAbstractions.synchronize(backend())
+
+        if seed === nothing || seed === missing
+            @test Array(a) != Array(b)
+        else
+            @test Array(a) == Array(b)
+        end
+    end
+
+    randexp_eltypes = filter(x -> x <: Base.IEEEFloat, eltypes)
+    # the CPU backend executes through POCL, which doesn't support log1p for
+    # Float16, so skip Float16 when testing randexp there
+    if backend == CPU
+        filter!(x -> x != Float16, randexp_eltypes)
+    end
+
+    @testset "basic randexp($T), seed $seed" for T in randexp_eltypes, seed in (nothing, #=missing,=# 1234)
+        @kernel function kernel(A::AbstractArray{T}, seed) where {T}
+            apply_seed(seed)
+            tid = @index(Global, Linear)
+            @inbounds A[tid] = randexp(T)
+        end
+
+        a = KernelAbstractions.zeros(backend(), T, n)
+        b = KernelAbstractions.zeros(backend(), T, n)
+
+        kernel(backend())(a, seed, ndrange=n, workgroupsize=n)
+        KernelAbstractions.synchronize(backend())
+        kernel(backend())(b, seed, ndrange=n, workgroupsize=n)
+        KernelAbstractions.synchronize(backend())
+
+        if seed === nothing || seed === missing
+            @test Array(a) != Array(b)
+        else
+            @test Array(a) == Array(b)
+        end
+    end
+
+    @testset "rand(::AbstractRange{$T}), seed $seed" for T in (Int32, Int64, UInt32, UInt64), seed in (nothing, #=missing,=# 1234)
+        @kernel function kernel(A::AbstractArray{T}, seed) where {T}
+            apply_seed(seed)
+            tid = @index(Global, Linear)
+            @inbounds A[tid] = rand(T(10):T(20))
+        end
+
+        a = KernelAbstractions.zeros(backend(), T, n)
+        b = KernelAbstractions.zeros(backend(), T, n)
+
+        kernel(backend())(a, seed, ndrange=n, workgroupsize=n)
+        KernelAbstractions.synchronize(backend())
+        kernel(backend())(b, seed, ndrange=n, workgroupsize=n)
+        KernelAbstractions.synchronize(backend())
+
+        if seed === nothing || seed === missing
+            @test Array(a) != Array(b)
+        else
+            @test Array(a) == Array(b)
+        end
+    end
+end
diff --git a/test/testsuite.jl b/test/testsuite.jl
index 31b801b3..4c0c09af 100644
--- a/test/testsuite.jl
+++ b/test/testsuite.jl
@@ -39,6 +39,7 @@ include("reflection.jl")
 include("examples.jl")
 include("convert.jl")
 include("specialfunctions.jl")
+include("random.jl")

 function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{String}())
     @conditional_testset "Unittests" skip_tests begin
@@ -93,6 +94,10 @@ function testsuite(backend, backend_str, backend_mod, AT, DAT; skip_tests = Set{
         examples_testsuite(backend, backend_str)
     end

+    @conditional_testset "Random" skip_tests begin
+        random_testsuite(backend)
+    end
+
     return
 end
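
For reference, the device-side RNG added by this patch is used just like host-side Random inside a kernel. A minimal sketch, mirroring `test/random.jl` (the `rand_fill!` kernel name is illustrative, and `CPU()` assumes the POCL-backed backend from this branch):

using KernelAbstractions, Random

# each work-item draws from the per-subgroup Philox2x32 stream set up by this patch
@kernel function rand_fill!(A)
    i = @index(Global, Linear)
    @inbounds A[i] = rand(Float32)
end

backend = CPU()
A = KernelAbstractions.zeros(backend, Float32, 256)
rand_fill!(backend)(A, ndrange = length(A), workgroupsize = 64)
KernelAbstractions.synchronize(backend)

# calling Random.seed!(1234) inside the kernel makes launches reproducible; without
# it, each launch receives a fresh host-generated seed via `KernelState`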