Skip to content
This repository has been archived by the owner on May 27, 2021. It is now read-only.

Commit

Permalink
Iterative compilation to deal with nested recursion.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Mar 21, 2019
1 parent e433eee commit 00225a1
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 48 deletions.
1 change: 1 addition & 0 deletions Project.toml
Expand Up @@ -5,6 +5,7 @@ uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
Expand Down
1 change: 1 addition & 0 deletions src/CUDAnative.jl
Expand Up @@ -6,6 +6,7 @@ using LLVM
using LLVM.Interop

using Adapt
using DataStructures

using Pkg
using Libdl
Expand Down
116 changes: 68 additions & 48 deletions src/compiler/driver.jl
Expand Up @@ -62,7 +62,7 @@ function compile(to::Symbol, job::CompilerJob;

## LLVM IR

ir, entry = irgen(job, linfo, world)
ir, kernel = irgen(job, linfo, world)

need_library(lib) = any(f -> isdeclaration(f) &&
intrinsic_id(f) == 0 &&
Expand All @@ -76,7 +76,7 @@ function compile(to::Symbol, job::CompilerJob;

# optimize the IR
if optimize
entry = optimize!(job, ir, entry)
kernel = optimize!(job, ir, kernel)
end

runtime = load_runtime(job.cap)
Expand All @@ -90,68 +90,88 @@ function compile(to::Symbol, job::CompilerJob;
strip_debuginfo!(ir)
end

kernel_fn = LLVM.name(kernel)
kernel_ft = eltype(llvmtype(kernel))

to == :llvm && return ir, kernel


## dynamic parallelism

kernels = OrderedDict{CompilerJob, String}(job => kernel_fn)

if haskey(functions(ir), "cudanativeCompileKernel")
f = functions(ir)["cudanativeCompileKernel"]

# find dynamic kernel invocations
# TODO: recover this information earlier, from the Julia IR
worklist = Dict{Tuple{Core.Function,Type}, Vector{LLVM.CallInst}}()
for use in uses(f)
# decode the call
call = user(use)::LLVM.CallInst
ops = collect(operands(call))[1:2]
## addrspacecast
ops = LLVM.Value[first(operands(val)) for val in ops]
## inttoptr
ops = ConstantInt[first(operands(val)) for val in ops]
## integer constants
ops = convert.(Int, ops)
## actual pointer values
ops = Ptr{Any}.(ops)

dyn_f, dyn_tt = unsafe_pointer_to_objref.(ops)
calls = get!(worklist, (dyn_f, dyn_tt), LLVM.CallInst[])
push!(calls, call)
end
dyn_maker = functions(ir)["cudanativeCompileKernel"]

# iterative compilation (non-recursive)
changed = true
while changed
changed = false

# find dynamic kernel invocations
# TODO: recover this information earlier, from the Julia IR
worklist = MultiDict{CompilerJob, LLVM.CallInst}()
for use in uses(dyn_maker)
# decode the call
call = user(use)::LLVM.CallInst
ops = collect(operands(call))[1:2]
## addrspacecast
ops = LLVM.Value[first(operands(val)) for val in ops]
## inttoptr
ops = ConstantInt[first(operands(val)) for val in ops]
## integer constants
ops = convert.(Int, ops)
## actual pointer values
ops = Ptr{Any}.(ops)

dyn_f, dyn_tt = unsafe_pointer_to_objref.(ops)
dyn_job = CompilerJob(dyn_f, dyn_tt, job.cap, #=kernel=# true)
push!(worklist, dyn_job => call)
end

# compile and link
for (dyn_f, dyn_tt) in keys(worklist)
dyn_ctx = CompilerJob(dyn_f, dyn_tt, job.cap, true)
dyn_ir, dyn_entry =
compile(:llvm, dyn_ctx; hooks=false, optimize=optimize, strip=strip)

dyn_fn = LLVM.name(dyn_entry)
link!(ir, dyn_ir)
dyn_ir = nothing
dyn_entry = functions(ir)[dyn_fn]

# insert a call everywhere the kernel is used
for call in worklist[(dyn_f,dyn_tt)]
replace_uses!(call, dyn_entry)
unsafe_delete!(LLVM.parent(call), call)
# compile and link
for dyn_job in keys(worklist)
# cached compilation
dyn_kernel_fn = get!(kernels, dyn_job) do
dyn_ir, dyn_kernel = compile(:llvm, dyn_job; hooks=false,
optimize=optimize, strip=strip)
dyn_kernel_fn = LLVM.name(dyn_kernel)
dyn_kernel_ft = eltype(llvmtype(dyn_kernel))
link!(ir, dyn_ir)
changed = true
dyn_kernel_fn
end
dyn_kernel = functions(ir)[dyn_kernel_fn]

# insert a pointer to the function everywhere the kernel is used
T_ptr = convert(LLVMType, Ptr{Cvoid})
for call in worklist[dyn_job]
Builder(JuliaContext()) do builder
position!(builder, call)
fptr = ptrtoint!(builder, dyn_kernel, T_ptr)
replace_uses!(call, fptr)
end
unsafe_delete!(LLVM.parent(call), call)
end
end
end

@compiler_assert isempty(uses(f)) job
unsafe_delete!(ir, f)
# all dynamic launches should have been resolved
@compiler_assert isempty(uses(dyn_maker)) job
unsafe_delete!(ir, dyn_maker)
end

to == :llvm && return ir, entry


## PTX machine code

prepare_execution!(job, ir)

check_invocation(job, entry)
check_invocation(job, kernel)
check_ir(job, ir)

asm = mcgen(job, ir, entry)
asm = mcgen(job, ir, kernel)

to == :ptx && return asm, LLVM.name(entry)
to == :ptx && return asm, kernel_fn


## CUDA objects
Expand All @@ -167,11 +187,11 @@ function compile(to::Symbol, job::CompilerJob;
# link the CUDA device library
linker = CUDAdrv.CuLink(jit_options)
CUDAdrv.add_file!(linker, libcudadevrt, CUDAdrv.LIBRARY)
CUDAdrv.add_data!(linker, LLVM.name(entry), asm)
CUDAdrv.add_data!(linker, kernel_fn, asm)
image = CUDAdrv.complete(linker)

cuda_mod = CuModule(image, jit_options)
cuda_fun = CuFunction(cuda_mod, LLVM.name(entry))
cuda_fun = CuFunction(cuda_mod, kernel_fn)

to == :cuda && return cuda_fun, cuda_mod

Expand Down
9 changes: 9 additions & 0 deletions src/compiler/optim.jl
Expand Up @@ -51,9 +51,18 @@ function optimize!(job::CompilerJob, mod::LLVM.Module, entry::LLVM.Function)

ModulePassManager() do pm
initialize!(pm)

# lower intrinsics
add!(pm, FunctionPass("LowerGCFrame", lower_gc_frame!))
aggressive_dce!(pm) # remove dead uses of ptls
add!(pm, ModulePass("LowerPTLS", lower_ptls!))

# the Julia GC lowering pass also has some clean-up that is required
function LLVMAddLateLowerGCFramePass(PM::LLVM.API.LLVMPassManagerRef)
LLVM.@apicall(:LLVMExtraAddLateLowerGCFramePass,Cvoid,(LLVM.API.LLVMPassManagerRef,), PM)
end
LLVMAddLateLowerGCFramePass(LLVM.ref(pm))

run!(pm, mod)
end
end
Expand Down

0 comments on commit 00225a1

Please sign in to comment.