10 changes: 8 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "KernelAbstractions"
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
authors = ["Valentin Churavy <v.churavy@gmail.com> and contributors"]
version = "0.10.0-dev"
authors = ["Valentin Churavy <v.churavy@gmail.com> and contributors"]

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,6 +13,9 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
OpenCL_jll = "6cb37087-e8b6-5417-8430-1f242f1e46e4"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
@@ -34,12 +37,15 @@ SparseArraysExt = "SparseArrays"
Adapt = "0.4, 1.0, 2.0, 3.0, 4"
Atomix = "0.1, 1"
EnzymeCore = "0.7, 0.8.1"
GPUCompiler = "1.6"
GPUCompiler = "1.7.1"
InteractiveUtils = "1.6"
LLVM = "9.4.1"
LinearAlgebra = "1.6"
MacroTools = "0.5"
PrecompileTools = "1"
Random = "1"
Random123 = "1.7.1"
RandomNumbers = "1.6.0"
SPIRVIntrinsics = "0.5"
SPIRV_LLVM_Backend_jll = "20"
SPIRV_Tools_jll = "2024.4, 2025.1"
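The new dependencies (Random, Random123, RandomNumbers) back the device-side RNG wired up in the rest of this diff. The snippet below is only a host-side illustration of the counter-based (Philox-style) generators these packages provide; it assumes the device implementation follows the same key/counter scheme and is not code from this PR.

```julia
using Random123

# Philox is a counter-based generator: a small key plus a counter fully
# determine the output, which is what makes cheap per-workitem streams possible.
rng = Philox2x(UInt32, 0x12345678)   # key/seed; on-device this comes per launch
xs  = rand(rng, UInt32, 4)           # advancing the counter yields the stream
```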
92 changes: 88 additions & 4 deletions src/pocl/compiler/compilation.jl
@@ -19,6 +19,87 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
in(fn, known_intrinsics) ||
contains(fn, "__spirv_")

GPUCompiler.kernel_state_type(::OpenCLCompilerJob) = KernelState

function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
mod::LLVM.Module, entry::LLVM.Function)
entry = invoke(GPUCompiler.finish_module!,
Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
job, mod, entry)

# if this kernel uses our RNG, we should prime the shared state.
# XXX: these transformations should really happen at the Julia IR level...
if haskey(functions(mod), "julia.opencl.random_keys") && job.config.kernel
# insert call to `initialize_rng_state`
f = initialize_rng_state
ft = typeof(f)
tt = Tuple{}

# create a deferred compilation job for `initialize_rng_state`
src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
cfg = CompilerConfig(job.config; kernel=false, name=nothing)
job = CompilerJob(src, cfg, job.world)
id = length(GPUCompiler.deferred_codegen_jobs) + 1
GPUCompiler.deferred_codegen_jobs[id] = job

# generate IR for calls to `deferred_codegen` and the resulting function pointer
top_bb = first(blocks(entry))
bb = BasicBlock(top_bb, "initialize_rng")
@dispose builder=IRBuilder() begin
position!(builder, bb)
subprogram = LLVM.subprogram(entry)
if subprogram !== nothing
loc = DILocation(0, 0, subprogram)
debuglocation!(builder, loc)
end
debuglocation!(builder, first(instructions(top_bb)))

# call the `deferred_codegen` marker function
T_ptr = if LLVM.version() >= v"17"
LLVM.PointerType()
elseif VERSION >= v"1.12.0-DEV.225"
LLVM.PointerType(LLVM.Int8Type())
else
LLVM.Int64Type()
end
T_id = convert(LLVMType, Int)
deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_id])
deferred_codegen = if haskey(functions(mod), "deferred_codegen")
functions(mod)["deferred_codegen"]
else
LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
end
fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])

# call the `initialize_rng_state` function
rt = Core.Compiler.return_type(f, tt)
llvm_rt = convert(LLVMType, rt)
llvm_ft = LLVM.FunctionType(llvm_rt)
fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
call!(builder, llvm_ft, fptr)
br!(builder, top_bb)

# note the use of the device-side RNG in this kernel
push!(function_attributes(entry), StringAttribute("julia.opencl.rng", ""))
end

# XXX: put some of the above behind GPUCompiler abstractions
# (e.g., a compile-time version of `deferred_codegen`)
end
return entry
end
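Conceptually, the IR surgery above prepends a call to `initialize_rng_state` to any kernel that touches the device RNG, resolving the callee through GPUCompiler's `deferred_codegen` indirection. A rough Julia-level picture of the result is sketched below; `original_kernel_body` is a hypothetical stand-in for the user's kernel code, not a real function in this PR.

```julia
# What the rewritten entry point is morally equivalent to: the workgroup-local
# RNG state is primed before any user code executes.
function kernel_entry(state, args...)
    initialize_rng_state()               # compiled separately, reached via deferred_codegen
    original_kernel_body(state, args...)
    return
end
```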

function GPUCompiler.finish_linked_module!(@nospecialize(job::OpenCLCompilerJob), mod::LLVM.Module)
for f in GPUCompiler.kernels(mod)
kernel_intrinsics = Dict(
"julia.opencl.random_keys" => (; name = "random_keys", typ = LLVMPtr{UInt32, AS.Workgroup}),
"julia.opencl.random_counters" => (; name = "random_counters", typ = LLVMPtr{UInt32, AS.Workgroup}),
)
GPUCompiler.add_input_arguments!(job, mod, f, kernel_intrinsics)
end
return
end
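After linking, kernels that reference the `julia.opencl.random_keys` / `julia.opencl.random_counters` placeholders gain hidden workgroup-local `UInt32` arrays through `add_input_arguments!`. The toy below is a hypothetical host-side model of how a per-workitem stream could be derived from such shared state; the actual device generator is implemented elsewhere and is Philox-based, not this ad-hoc mix.

```julia
# Toy model only: index the workgroup-local key/counter arrays by local id and
# advance this workitem's counter on every draw.
function workitem_rand(random_keys::Vector{UInt32}, random_counters::Vector{UInt32},
                       local_id::Int)
    key     = random_keys[local_id]
    counter = (random_counters[local_id] += UInt32(1))
    return (key ⊻ counter) * 0x9e3779b9   # stand-in mixing step
end
```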


## compiler implementation (cache, configure, compile, and link)

@@ -60,10 +141,13 @@ end
function compile(@nospecialize(job::CompilerJob))
# TODO: this creates a context; cache those.
obj, meta = JuliaContext() do ctx
GPUCompiler.compile(:obj, job)
end
obj, meta = GPUCompiler.compile(:obj, job)

return (; obj, entry = LLVM.name(meta.entry))
entry = LLVM.name(meta.entry)
device_rng = StringAttribute("julia.opencl.rng", "") in collect(function_attributes(meta.entry))

(; obj, entry, device_rng)
end
end

# link into an executable kernel
@@ -74,5 +158,5 @@ function link(@nospecialize(job::CompilerJob), compiled)
error("Your device does not support SPIR-V, which is currently required for native execution.")
end
cl.build!(prog)
return cl.Kernel(prog, compiled.entry)
(; kernel=cl.Kernel(prog, compiled.entry), compiled.device_rng)
end
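With these changes `compile` and `link` return named tuples rather than bare objects, so the RNG flag recovered from the `julia.opencl.rng` function attribute survives until the kernel object is built. A simplified sketch of the flow is below; constructing `job` is omitted and the real call path goes through `GPUCompiler.cached_compilation`.

```julia
# Simplified flow, assuming `job` is an OpenCL CompilerJob built elsewhere.
compiled = compile(job)         # (; obj, entry, device_rng)
linked   = link(job, compiled)  # (; kernel, device_rng)
if linked.device_rng
    @info "kernel uses the device-side RNG; each launch gets a fresh seed"
end
```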
17 changes: 11 additions & 6 deletions src/pocl/compiler/execution.jl
@@ -146,6 +146,8 @@ end

abstract type AbstractKernel{F, TT} end

pass_arg(@nospecialize dt) = !(GPUCompiler.isghosttype(dt) || Core.Compiler.isconstType(dt))

@inline @generated function (kernel::AbstractKernel{F, TT})(
args...;
call_kwargs...
@@ -154,8 +156,7 @@ abstract type AbstractKernel{F, TT} end
args = (:(kernel.f), (:(clconvert(args[$i], svm_pointers)) for i in 1:length(args))...)

# filter out ghost arguments that shouldn't be passed
predicate = dt -> GPUCompiler.isghosttype(dt) || Core.Compiler.isconstType(dt)
to_pass = map(!predicate, sig.parameters)
to_pass = map(pass_arg, sig.parameters)
call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]]
call_args = Union{Expr, Symbol}[x[1] for x in zip(args, to_pass) if x[2]]

@@ -167,12 +168,15 @@ abstract type AbstractKernel{F, TT} end
end
end

pushfirst!(call_t, KernelState)
pushfirst!(call_args, :(KernelState(kernel.rng_state ? Base.rand(UInt32) : UInt32(0))))

# finalize types
call_tt = Base.to_tuple_type(call_t)

return quote
svm_pointers = Ptr{Cvoid}[]
$cl.clcall(kernel.fun, $call_tt, $(call_args...); svm_pointers, call_kwargs...)
$cl.clcall(kernel.fun, $call_tt, $(call_args...); svm_pointers, kernel.rng_state, call_kwargs...)
end
end
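The `KernelState` prepended to every call is how the host hands a per-launch seed to the device: a fresh `rand(UInt32)` when the compiled kernel was flagged as using the RNG, and zero otherwise so RNG-free kernels pay nothing extra. The struct below is a hypothetical sketch of what that state carries; the real definition lives in the backend's device runtime.

```julia
# Hypothetical sketch of the launch-time state object.
struct KernelState
    random_seed::UInt32
end

rng_state = true                                               # stands in for kernel.rng_state
state = KernelState(rng_state ? Base.rand(UInt32) : UInt32(0))
```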

@@ -182,6 +186,7 @@ end
struct HostKernel{F, TT} <: AbstractKernel{F, TT}
f::F
fun::cl.Kernel
rng_state::Bool
end


@@ -198,15 +203,15 @@ function clfunction(f::F, tt::TT = Tuple{}; kwargs...) where {F, TT}
cache = compiler_cache(ctx)
source = methodinstance(F, tt)
config = compiler_config(dev; kwargs...)::OpenCLCompilerConfig
fun = GPUCompiler.cached_compilation(cache, source, config, compile, link)
linked = GPUCompiler.cached_compilation(cache, source, config, compile, link)

# create a callable object that captures the function instance. we don't need to think
# about world age here, as GPUCompiler already does and will return a different object
h = hash(fun, hash(f, hash(tt)))
h = hash(linked.kernel, hash(f, hash(tt)))
kernel = get(_kernel_instances, h, nothing)
if kernel === nothing
# create the kernel state object
kernel = HostKernel{F, tt}(f, fun)
kernel = HostKernel{F, tt}(f, linked.kernel, linked.device_rng)
_kernel_instances[h] = kernel
end
return kernel::HostKernel{F, tt}
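None of this plumbing is visible to users: a kernel that draws random numbers simply sees a differently seeded stream on every launch. The usage sketch below goes through KernelAbstractions and assumes the KA front end routes `rand` to this device RNG on the OpenCL/POCL backend; `backend` and `A` are placeholders for a real backend instance and a device array.

```julia
using KernelAbstractions

@kernel function fill_random!(A)
    i = @index(Global)
    @inbounds A[i] = rand(Float32)   # served by the device-side RNG added in this PR
end

# fill_random!(backend)(A; ndrange=length(A))   # fresh KernelState seed per launch
```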