This repository has been archived by the owner on May 27, 2021. It is now read-only.
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #362 from JuliaGPU/tb/dynamic_parallelism
Dynamic parallelism
- Loading branch information
Showing
15 changed files
with
865 additions
and
376 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,106 +1,199 @@ | ||
# compiler driver and main interface | ||
|
||
# (::CompilerContext) | ||
# (::CompilerJob) | ||
const compile_hook = Ref{Union{Nothing,Function}}(nothing) | ||
|
||
""" | ||
compile(dev::CuDevice, f, tt; kwargs...) | ||
compile(to::Symbol, cap::VersionNumber, f, tt, kernel=true; | ||
hooks=true, libraries=true, optimize=true, strip=false, ...) | ||
Compile a function `f` invoked with types `tt` for device `dev`, returning the compiled | ||
function and module, respectively of type `CuFunction` and `CuModule`. | ||
Compile a function `f` invoked with types `tt` for device capability `cap` to one of the | ||
following formats as specified by the `to` argument: `:julia` for Julia IR, `:llvm` for LLVM | ||
IR, `:ptx` for PTX assembly and `:cuda` for CUDA driver objects. If the `kernel` flag is | ||
set, specialized code generation and optimization for kernel functions is enabled. | ||
For a list of supported keyword arguments, refer to the documentation of | ||
[`cufunction`](@ref). | ||
The following keyword arguments are supported: | ||
- `hooks`: enable compiler hooks that drive reflection functions (default: true) | ||
- `libraries`: link auxiliary bitcode libraries that may be required (default: true) | ||
- `optimize`: optimize the code (default: true) | ||
- `strip`: strip non-functional metadata and debug information (default: false) | ||
Other keyword arguments can be found in the documentation of [`cufunction`](@ref). | ||
""" | ||
function compile(dev::CuDevice, @nospecialize(f::Core.Function), @nospecialize(tt); kwargs...) | ||
CUDAnative.configured || error("CUDAnative.jl has not been configured; cannot JIT code.") | ||
compile(to::Symbol, cap::VersionNumber, @nospecialize(f::Core.Function), @nospecialize(tt), | ||
kernel::Bool=true; hooks::Bool=true, libraries::Bool=true, | ||
optimize::Bool=true, strip::Bool=false, | ||
kwargs...) = | ||
compile(to, CompilerJob(f, tt, cap, kernel; kwargs...); | ||
hooks=hooks, libraries=libraries, optimize=optimize, strip=strip) | ||
|
||
function compile(to::Symbol, job::CompilerJob; | ||
hooks::Bool=true, libraries::Bool=true, | ||
optimize::Bool=true, strip::Bool=false) | ||
@debug "(Re)compiling function" job | ||
|
||
if hooks && compile_hook[] != nothing | ||
global globalUnique | ||
previous_globalUnique = globalUnique | ||
|
||
module_asm, module_entry = compile(supported_capability(dev), f, tt; kwargs...) | ||
compile_hook[](job) | ||
|
||
# enable debug options based on Julia's debug setting | ||
jit_options = Dict{CUDAdrv.CUjit_option,Any}() | ||
if Base.JLOptions().debug_level == 1 | ||
jit_options[CUDAdrv.GENERATE_LINE_INFO] = true | ||
elseif Base.JLOptions().debug_level >= 2 | ||
jit_options[CUDAdrv.GENERATE_DEBUG_INFO] = true | ||
globalUnique = previous_globalUnique | ||
end | ||
|
||
# Link libcudadevrt | ||
linker = CUDAdrv.CuLink(jit_options) | ||
CUDAdrv.add_file!(linker, libcudadevrt, CUDAdrv.LIBRARY) | ||
CUDAdrv.add_data!(linker, module_entry, module_asm) | ||
image = CUDAdrv.complete(linker) | ||
|
||
cuda_mod = CuModule(image, jit_options) | ||
cuda_fun = CuFunction(cuda_mod, module_entry) | ||
## Julia IR | ||
|
||
return cuda_fun, cuda_mod | ||
end | ||
check_method(job) | ||
|
||
# same as above, but without an active device | ||
function compile(cap::VersionNumber, @nospecialize(f), @nospecialize(tt); | ||
kernel=true, kwargs...) | ||
ctx = CompilerContext(f, tt, cap, kernel; kwargs...) | ||
# get the method instance | ||
world = typemax(UInt) | ||
meth = which(job.f, job.tt) | ||
sig = Base.signature_type(job.f, job.tt)::Type | ||
(ti, env) = ccall(:jl_type_intersection_with_env, Any, | ||
(Any, Any), sig, meth.sig)::Core.SimpleVector | ||
if VERSION >= v"1.2.0-DEV.320" | ||
meth = Base.func_for_method_checked(meth, ti, env) | ||
else | ||
meth = Base.func_for_method_checked(meth, ti) | ||
end | ||
linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance}, | ||
(Any, Any, Any, UInt), meth, ti, env, world) | ||
|
||
return compile(ctx) | ||
end | ||
to == :julia && return linfo | ||
|
||
function compile(ctx::CompilerContext) | ||
if compile_hook[] != nothing | ||
hook = compile_hook[] | ||
compile_hook[] = nothing | ||
|
||
global globalUnique | ||
previous_globalUnique = globalUnique | ||
## LLVM IR | ||
|
||
hook(ctx) | ||
ir, kernel = irgen(job, linfo, world) | ||
|
||
globalUnique = previous_globalUnique | ||
compile_hook[] = hook | ||
need_library(lib) = any(f -> isdeclaration(f) && | ||
intrinsic_id(f) == 0 && | ||
haskey(functions(lib), LLVM.name(f)), | ||
functions(ir)) | ||
|
||
if libraries | ||
libdevice = load_libdevice(job.cap) | ||
if need_library(libdevice) | ||
link_libdevice!(job, ir, libdevice) | ||
end | ||
end | ||
|
||
if optimize | ||
kernel = optimize!(job, ir, kernel) | ||
end | ||
|
||
## high-level code generation (Julia AST) | ||
if libraries | ||
runtime = load_runtime(job.cap) | ||
if need_library(runtime) | ||
link_library!(job, ir, runtime) | ||
end | ||
end | ||
|
||
@debug "(Re)compiling function" ctx | ||
verify(ir) | ||
|
||
check_method(ctx) | ||
if strip | ||
strip_debuginfo!(ir) | ||
end | ||
|
||
kernel_fn = LLVM.name(kernel) | ||
kernel_ft = eltype(llvmtype(kernel)) | ||
|
||
to == :llvm && return ir, kernel | ||
|
||
|
||
## dynamic parallelism | ||
|
||
kernels = OrderedDict{CompilerJob, String}(job => kernel_fn) | ||
|
||
if haskey(functions(ir), "cudanativeCompileKernel") | ||
dyn_maker = functions(ir)["cudanativeCompileKernel"] | ||
|
||
# iterative compilation (non-recursive) | ||
changed = true | ||
while changed | ||
changed = false | ||
|
||
# find dynamic kernel invocations | ||
# TODO: recover this information earlier, from the Julia IR | ||
worklist = MultiDict{CompilerJob, LLVM.CallInst}() | ||
for use in uses(dyn_maker) | ||
# decode the call | ||
call = user(use)::LLVM.CallInst | ||
id = convert(Int, first(operands(call))) | ||
|
||
global delayed_cufunctions | ||
dyn_f, dyn_tt = delayed_cufunctions[id] | ||
dyn_job = CompilerJob(dyn_f, dyn_tt, job.cap, #=kernel=# true) | ||
push!(worklist, dyn_job => call) | ||
end | ||
|
||
# compile and link | ||
for dyn_job in keys(worklist) | ||
# cached compilation | ||
dyn_kernel_fn = get!(kernels, dyn_job) do | ||
dyn_ir, dyn_kernel = compile(:llvm, dyn_job; hooks=false, | ||
optimize=optimize, strip=strip) | ||
dyn_kernel_fn = LLVM.name(dyn_kernel) | ||
dyn_kernel_ft = eltype(llvmtype(dyn_kernel)) | ||
link!(ir, dyn_ir) | ||
changed = true | ||
dyn_kernel_fn | ||
end | ||
dyn_kernel = functions(ir)[dyn_kernel_fn] | ||
|
||
# insert a pointer to the function everywhere the kernel is used | ||
T_ptr = convert(LLVMType, Ptr{Cvoid}) | ||
for call in worklist[dyn_job] | ||
Builder(JuliaContext()) do builder | ||
position!(builder, call) | ||
fptr = ptrtoint!(builder, dyn_kernel, T_ptr) | ||
replace_uses!(call, fptr) | ||
end | ||
unsafe_delete!(LLVM.parent(call), call) | ||
end | ||
end | ||
end | ||
|
||
# all dynamic launches should have been resolved | ||
@compiler_assert isempty(uses(dyn_maker)) job | ||
unsafe_delete!(ir, dyn_maker) | ||
end | ||
|
||
## low-level code generation (LLVM IR) | ||
|
||
mod, entry = irgen(ctx) | ||
## PTX machine code | ||
|
||
need_library(lib) = any(f -> isdeclaration(f) && | ||
intrinsic_id(f) == 0 && | ||
haskey(functions(lib), LLVM.name(f)), | ||
functions(mod)) | ||
prepare_execution!(job, ir) | ||
|
||
libdevice = load_libdevice(ctx.cap) | ||
if need_library(libdevice) | ||
link_libdevice!(ctx, mod, libdevice) | ||
end | ||
check_invocation(job, kernel) | ||
check_ir(job, ir) | ||
|
||
# optimize the IR | ||
entry = optimize!(ctx, mod, entry) | ||
asm = mcgen(job, ir, kernel) | ||
|
||
to == :ptx && return asm, kernel_fn | ||
|
||
runtime = load_runtime(ctx.cap) | ||
if need_library(runtime) | ||
link_library!(ctx, mod, runtime) | ||
end | ||
|
||
prepare_execution!(ctx, mod) | ||
## CUDA objects | ||
|
||
check_invocation(ctx, entry) | ||
# enable debug options based on Julia's debug setting | ||
jit_options = Dict{CUDAdrv.CUjit_option,Any}() | ||
if Base.JLOptions().debug_level == 1 | ||
jit_options[CUDAdrv.GENERATE_LINE_INFO] = true | ||
elseif Base.JLOptions().debug_level >= 2 | ||
jit_options[CUDAdrv.GENERATE_DEBUG_INFO] = true | ||
end | ||
|
||
# check generated IR | ||
check_ir(ctx, mod) | ||
verify(mod) | ||
# link the CUDA device library | ||
linker = CUDAdrv.CuLink(jit_options) | ||
CUDAdrv.add_file!(linker, libcudadevrt, CUDAdrv.LIBRARY) | ||
CUDAdrv.add_data!(linker, kernel_fn, asm) | ||
image = CUDAdrv.complete(linker) | ||
|
||
cuda_mod = CuModule(image, jit_options) | ||
cuda_fun = CuFunction(cuda_mod, kernel_fn) | ||
|
||
## machine code generation (PTX assembly) | ||
to == :cuda && return cuda_fun, cuda_mod | ||
|
||
module_asm = mcgen(ctx, mod, entry) | ||
|
||
return module_asm, LLVM.name(entry) | ||
error("Unknown compilation target $to") | ||
end |
Oops, something went wrong.