
Merge pull request #362 from JuliaGPU/tb/dynamic_parallelism
Dynamic parallelism
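This merge adds support for dynamic parallelism, i.e. launching kernels from device code: device-side calls to the `cudanativeCompileKernel` hook are resolved at compile time by recursively compiling and linking the requested kernels (see the driver changes below). A minimal sketch of what this enables, assuming the `dynamic=true` flag on `@cuda` and CuArrays.jl for device memory (neither is shown in this diff):

```julia
using CUDAnative, CuArrays

# child kernel, launched from the device
function child(a)
    i = threadIdx().x
    @inbounds a[i] += 1f0
    return
end

# parent kernel: the nested @cuda is compiled lazily through the
# `cudanativeCompileKernel` mechanism introduced in this commit
function parent(a)
    @cuda dynamic=true threads=length(a) child(a)
    return
end

a = CuArray(zeros(Float32, 4))
@cuda threads=1 parent(a)
```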
maleadt committed Mar 22, 2019
2 parents 854a230 + da51086 commit 68f687c
Showing 15 changed files with 865 additions and 376 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -5,6 +5,7 @@ uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
2 changes: 2 additions & 0 deletions src/CUDAnative.jl
@@ -6,6 +6,7 @@ using LLVM
using LLVM.Interop

using Adapt
using DataStructures

using Pkg
using Libdl
@@ -27,6 +28,7 @@ include(joinpath("device", "tools.jl"))
include(joinpath("device", "pointer.jl"))
include(joinpath("device", "array.jl"))
include(joinpath("device", "cuda.jl"))
include(joinpath("device", "llvm.jl"))
include(joinpath("device", "runtime.jl"))

include("compiler.jl")
38 changes: 19 additions & 19 deletions src/compiler/common.jl
@@ -1,6 +1,6 @@
# common functionality

struct CompilerContext
struct CompilerJob
# core invocation
f::Core.Function
tt::DataType
@@ -13,38 +13,38 @@ struct CompilerContext
blocks_per_sm::Union{Nothing,Integer}
maxregs::Union{Nothing,Integer}

CompilerContext(f, tt, cap, kernel;
CompilerJob(f, tt, cap, kernel;
minthreads=nothing, maxthreads=nothing,
blocks_per_sm=nothing, maxregs=nothing) =
new(f, tt, cap, kernel, minthreads, maxthreads, blocks_per_sm, maxregs)
end

# global context reference
# FIXME: thread through `ctx` everywhere (deadlocks the Julia compiler when doing so with
# global job reference
# FIXME: thread through `job` everywhere (deadlocks the Julia compiler when doing so with
# the LLVM passes in CUDAnative)
global_ctx = nothing
current_job = nothing


function signature(ctx::CompilerContext)
fn = typeof(ctx.f).name.mt.name
args = join(ctx.tt.parameters, ", ")
return "$fn($(join(ctx.tt.parameters, ", ")))"
function signature(job::CompilerJob)
fn = typeof(job.f).name.mt.name
args = join(job.tt.parameters, ", ")
return "$fn($(join(job.tt.parameters, ", ")))"
end


struct KernelError <: Exception
ctx::CompilerContext
job::CompilerJob
message::String
help::Union{Nothing,String}
bt::StackTraces.StackTrace

KernelError(ctx::CompilerContext, message::String, help=nothing;
KernelError(job::CompilerJob, message::String, help=nothing;
bt=StackTraces.StackTrace()) =
new(ctx, message, help, bt)
new(job, message, help, bt)
end

function Base.showerror(io::IO, err::KernelError)
println(io, "GPU compilation of $(signature(err.ctx)) failed")
println(io, "GPU compilation of $(signature(err.job)) failed")
println(io, "KernelError: $(err.message)")
println(io)
println(io, something(err.help, "Try inspecting the generated code with any of the @device_code_... macros."))
@@ -53,10 +53,10 @@ end


struct InternalCompilerError <: Exception
ctx::CompilerContext
job::CompilerJob
message::String
meta::Dict
InternalCompilerError(ctx, message; kwargs...) = new(ctx, message, kwargs)
InternalCompilerError(job, message; kwargs...) = new(job, message, kwargs)
end

function Base.showerror(io::IO, err::InternalCompilerError)
@@ -67,8 +67,8 @@ function Base.showerror(io::IO, err::InternalCompilerError)
println(io, "\nInternalCompilerError: $(err.message)")

println(io, "\nCompiler invocation:")
for field in fieldnames(CompilerContext)
println(io, " - $field = $(repr(getfield(err.ctx, field)))")
for field in fieldnames(CompilerJob)
println(io, " - $field = $(repr(getfield(err.job, field)))")
end

if !isempty(err.meta)
@@ -87,10 +87,10 @@ function Base.showerror(io::IO, err::InternalCompilerError)
versioninfo(io)
end

macro compiler_assert(ex, ctx, kwargs...)
macro compiler_assert(ex, job, kwargs...)
msg = "$ex, at $(__source__.file):$(__source__.line)"
return :($(esc(ex)) ? $(nothing)
: throw(InternalCompilerError($(esc(ctx)), $msg;
: throw(InternalCompilerError($(esc(job)), $msg;
$(map(esc, kwargs)...)))
)
end
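The renamed `CompilerJob` (formerly `CompilerContext`) bundles a complete compiler invocation: the function, its argument types, the target device capability, whether it is compiled as a kernel, and optional launch-bound hints. It is an internal type, but a sketch of constructing one directly illustrates the fields shown above (purely illustrative; `CompilerJob` and `signature` are not part of the public API):

```julia
using CUDAnative

# a trivial device function, used only to populate the job
noop(x::Float32) = return

# fields as in the struct above: f, tt, cap, kernel, plus optional launch bounds
job = CUDAnative.CompilerJob(noop, Tuple{Float32}, v"7.0", true; maxthreads=256)

CUDAnative.signature(job)   # "noop(Float32)"
```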
225 changes: 159 additions & 66 deletions src/compiler/driver.jl
@@ -1,106 +1,199 @@
# compiler driver and main interface

# (::CompilerContext)
# (::CompilerJob)
const compile_hook = Ref{Union{Nothing,Function}}(nothing)

"""
compile(dev::CuDevice, f, tt; kwargs...)
compile(to::Symbol, cap::VersionNumber, f, tt, kernel=true;
kernel=true, optimize=true, strip=false, ...)
Compile a function `f` invoked with types `tt` for device `dev`, returning the compiled
function and module, respectively of type `CuFunction` and `CuModule`.
Compile a function `f` invoked with types `tt` for device capability `cap` to one of the
following formats as specified by the `to` argument: `:julia` for Julia IR, `:llvm` for LLVM
IR, `:ptx` for PTX assembly and `:cuda` for CUDA driver objects. If the `kernel` flag is
set, specialized code generation and optimization for kernel functions is enabled.
For a list of supported keyword arguments, refer to the documentation of
[`cufunction`](@ref).
The following keyword arguments are supported:
- `hooks`: enable compiler hooks that drive reflection functions (default: true)
- `libraries`: link auxiliary bitcode libraries that may be required (default: true)
- `optimize`: optimize the code (default: true)
- `strip`: strip non-functional metadata and debug information (default: false)
Other keyword arguments can be found in the documentation of [`cufunction`](@ref).
"""
function compile(dev::CuDevice, @nospecialize(f::Core.Function), @nospecialize(tt); kwargs...)
CUDAnative.configured || error("CUDAnative.jl has not been configured; cannot JIT code.")
compile(to::Symbol, cap::VersionNumber, @nospecialize(f::Core.Function), @nospecialize(tt),
kernel::Bool=true; hooks::Bool=true, libraries::Bool=true,
optimize::Bool=true, strip::Bool=false,
kwargs...) =
compile(to, CompilerJob(f, tt, cap, kernel; kwargs...);
hooks=hooks, libraries=libraries, optimize=optimize, strip=strip)

function compile(to::Symbol, job::CompilerJob;
hooks::Bool=true, libraries::Bool=true,
optimize::Bool=true, strip::Bool=false)
@debug "(Re)compiling function" job

if hooks && compile_hook[] != nothing
global globalUnique
previous_globalUnique = globalUnique

module_asm, module_entry = compile(supported_capability(dev), f, tt; kwargs...)
compile_hook[](job)

# enable debug options based on Julia's debug setting
jit_options = Dict{CUDAdrv.CUjit_option,Any}()
if Base.JLOptions().debug_level == 1
jit_options[CUDAdrv.GENERATE_LINE_INFO] = true
elseif Base.JLOptions().debug_level >= 2
jit_options[CUDAdrv.GENERATE_DEBUG_INFO] = true
globalUnique = previous_globalUnique
end

# Link libcudadevrt
linker = CUDAdrv.CuLink(jit_options)
CUDAdrv.add_file!(linker, libcudadevrt, CUDAdrv.LIBRARY)
CUDAdrv.add_data!(linker, module_entry, module_asm)
image = CUDAdrv.complete(linker)

cuda_mod = CuModule(image, jit_options)
cuda_fun = CuFunction(cuda_mod, module_entry)
## Julia IR

return cuda_fun, cuda_mod
end
check_method(job)

# same as above, but without an active device
function compile(cap::VersionNumber, @nospecialize(f), @nospecialize(tt);
kernel=true, kwargs...)
ctx = CompilerContext(f, tt, cap, kernel; kwargs...)
# get the method instance
world = typemax(UInt)
meth = which(job.f, job.tt)
sig = Base.signature_type(job.f, job.tt)::Type
(ti, env) = ccall(:jl_type_intersection_with_env, Any,
(Any, Any), sig, meth.sig)::Core.SimpleVector
if VERSION >= v"1.2.0-DEV.320"
meth = Base.func_for_method_checked(meth, ti, env)
else
meth = Base.func_for_method_checked(meth, ti)
end
linfo = ccall(:jl_specializations_get_linfo, Ref{Core.MethodInstance},
(Any, Any, Any, UInt), meth, ti, env, world)

return compile(ctx)
end
to == :julia && return linfo

function compile(ctx::CompilerContext)
if compile_hook[] != nothing
hook = compile_hook[]
compile_hook[] = nothing

global globalUnique
previous_globalUnique = globalUnique
## LLVM IR

hook(ctx)
ir, kernel = irgen(job, linfo, world)

globalUnique = previous_globalUnique
compile_hook[] = hook
need_library(lib) = any(f -> isdeclaration(f) &&
intrinsic_id(f) == 0 &&
haskey(functions(lib), LLVM.name(f)),
functions(ir))

if libraries
libdevice = load_libdevice(job.cap)
if need_library(libdevice)
link_libdevice!(job, ir, libdevice)
end
end

if optimize
kernel = optimize!(job, ir, kernel)
end

## high-level code generation (Julia AST)
if libraries
runtime = load_runtime(job.cap)
if need_library(runtime)
link_library!(job, ir, runtime)
end
end

@debug "(Re)compiling function" ctx
verify(ir)

check_method(ctx)
if strip
strip_debuginfo!(ir)
end

kernel_fn = LLVM.name(kernel)
kernel_ft = eltype(llvmtype(kernel))

to == :llvm && return ir, kernel


## dynamic parallelism

kernels = OrderedDict{CompilerJob, String}(job => kernel_fn)

if haskey(functions(ir), "cudanativeCompileKernel")
dyn_maker = functions(ir)["cudanativeCompileKernel"]

# iterative compilation (non-recursive)
changed = true
while changed
changed = false

# find dynamic kernel invocations
# TODO: recover this information earlier, from the Julia IR
worklist = MultiDict{CompilerJob, LLVM.CallInst}()
for use in uses(dyn_maker)
# decode the call
call = user(use)::LLVM.CallInst
id = convert(Int, first(operands(call)))

global delayed_cufunctions
dyn_f, dyn_tt = delayed_cufunctions[id]
dyn_job = CompilerJob(dyn_f, dyn_tt, job.cap, #=kernel=# true)
push!(worklist, dyn_job => call)
end

# compile and link
for dyn_job in keys(worklist)
# cached compilation
dyn_kernel_fn = get!(kernels, dyn_job) do
dyn_ir, dyn_kernel = compile(:llvm, dyn_job; hooks=false,
optimize=optimize, strip=strip)
dyn_kernel_fn = LLVM.name(dyn_kernel)
dyn_kernel_ft = eltype(llvmtype(dyn_kernel))
link!(ir, dyn_ir)
changed = true
dyn_kernel_fn
end
dyn_kernel = functions(ir)[dyn_kernel_fn]

# insert a pointer to the function everywhere the kernel is used
T_ptr = convert(LLVMType, Ptr{Cvoid})
for call in worklist[dyn_job]
Builder(JuliaContext()) do builder
position!(builder, call)
fptr = ptrtoint!(builder, dyn_kernel, T_ptr)
replace_uses!(call, fptr)
end
unsafe_delete!(LLVM.parent(call), call)
end
end
end

# all dynamic launches should have been resolved
@compiler_assert isempty(uses(dyn_maker)) job
unsafe_delete!(ir, dyn_maker)
end

## low-level code generation (LLVM IR)

mod, entry = irgen(ctx)
## PTX machine code

need_library(lib) = any(f -> isdeclaration(f) &&
intrinsic_id(f) == 0 &&
haskey(functions(lib), LLVM.name(f)),
functions(mod))
prepare_execution!(job, ir)

libdevice = load_libdevice(ctx.cap)
if need_library(libdevice)
link_libdevice!(ctx, mod, libdevice)
end
check_invocation(job, kernel)
check_ir(job, ir)

# optimize the IR
entry = optimize!(ctx, mod, entry)
asm = mcgen(job, ir, kernel)

to == :ptx && return asm, kernel_fn

runtime = load_runtime(ctx.cap)
if need_library(runtime)
link_library!(ctx, mod, runtime)
end

prepare_execution!(ctx, mod)
## CUDA objects

check_invocation(ctx, entry)
# enable debug options based on Julia's debug setting
jit_options = Dict{CUDAdrv.CUjit_option,Any}()
if Base.JLOptions().debug_level == 1
jit_options[CUDAdrv.GENERATE_LINE_INFO] = true
elseif Base.JLOptions().debug_level >= 2
jit_options[CUDAdrv.GENERATE_DEBUG_INFO] = true
end

# check generated IR
check_ir(ctx, mod)
verify(mod)
# link the CUDA device library
linker = CUDAdrv.CuLink(jit_options)
CUDAdrv.add_file!(linker, libcudadevrt, CUDAdrv.LIBRARY)
CUDAdrv.add_data!(linker, kernel_fn, asm)
image = CUDAdrv.complete(linker)

cuda_mod = CuModule(image, jit_options)
cuda_fun = CuFunction(cuda_mod, kernel_fn)

## machine code generation (PTX assembly)
to == :cuda && return cuda_fun, cuda_mod

module_asm = mcgen(ctx, mod, entry)

return module_asm, LLVM.name(entry)
error("Unknown compilation target $to")
end
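The driver now funnels everything through a single `compile(to, cap, f, tt, ...)` entry point, dispatching on the `to` target described in the docstring above. A rough usage sketch (the `dummy` function is a hypothetical placeholder; the `:cuda` target additionally needs an active CUDA context):

```julia
using CUDAnative

dummy(x::Float32) = return

# stop at different stages of the pipeline via the `to` argument
linfo      = CUDAnative.compile(:julia, v"7.0", dummy, Tuple{Float32})  # method instance
ir, kernel = CUDAnative.compile(:llvm,  v"7.0", dummy, Tuple{Float32})  # LLVM module + entry
asm, name  = CUDAnative.compile(:ptx,   v"7.0", dummy, Tuple{Float32})  # PTX + kernel name
# fun, mod = CUDAnative.compile(:cuda,  v"7.0", dummy, Tuple{Float32})  # CuFunction + CuModule
```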
