From 00225a1079a7ea4359b38310ef40b88601581985 Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Thu, 21 Mar 2019 14:31:04 +0100
Subject: [PATCH] Iterative compilation to deal with nested recursion.

---
 Project.toml           |   1 +
 src/CUDAnative.jl      |   1 +
 src/compiler/driver.jl | 116 ++++++++++++++++++++++++-----------------
 src/compiler/optim.jl  |   9 ++++
 4 files changed, 79 insertions(+), 48 deletions(-)

diff --git a/Project.toml b/Project.toml
index a62a301c..fd44b603 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,6 +5,7 @@ uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 CUDAapi = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
 CUDAdrv = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"

diff --git a/src/CUDAnative.jl b/src/CUDAnative.jl
index 018729ae..fab16f49 100644
--- a/src/CUDAnative.jl
+++ b/src/CUDAnative.jl
@@ -6,6 +6,7 @@ using LLVM
 using LLVM.Interop
 
 using Adapt
+using DataStructures
 
 using Pkg
 using Libdl

diff --git a/src/compiler/driver.jl b/src/compiler/driver.jl
index 9687b050..d6e6544f 100644
--- a/src/compiler/driver.jl
+++ b/src/compiler/driver.jl
@@ -62,7 +62,7 @@ function compile(to::Symbol, job::CompilerJob;
 
     ## LLVM IR
 
-    ir, entry = irgen(job, linfo, world)
+    ir, kernel = irgen(job, linfo, world)
 
     need_library(lib) = any(f -> isdeclaration(f) &&
                                  intrinsic_id(f) == 0 &&
@@ -76,7 +76,7 @@ function compile(to::Symbol, job::CompilerJob;
 
     # optimize the IR
     if optimize
-        entry = optimize!(job, ir, entry)
+        kernel = optimize!(job, ir, kernel)
     end
 
     runtime = load_runtime(job.cap)
@@ -90,68 +90,88 @@ function compile(to::Symbol, job::CompilerJob;
         strip_debuginfo!(ir)
     end
 
+    kernel_fn = LLVM.name(kernel)
+    kernel_ft = eltype(llvmtype(kernel))
+
+    to == :llvm && return ir, kernel
+
+
     ## dynamic parallelism
 
+    kernels = OrderedDict{CompilerJob, String}(job => kernel_fn)
+
     if haskey(functions(ir), "cudanativeCompileKernel")
-        f = functions(ir)["cudanativeCompileKernel"]
-
-        # find dynamic kernel invocations
-        # TODO: recover this information earlier, from the Julia IR
-        worklist = Dict{Tuple{Core.Function,Type}, Vector{LLVM.CallInst}}()
-        for use in uses(f)
-            # decode the call
-            call = user(use)::LLVM.CallInst
-            ops = collect(operands(call))[1:2]
-            ## addrspacecast
-            ops = LLVM.Value[first(operands(val)) for val in ops]
-            ## inttoptr
-            ops = ConstantInt[first(operands(val)) for val in ops]
-            ## integer constants
-            ops = convert.(Int, ops)
-            ## actual pointer values
-            ops = Ptr{Any}.(ops)
-
-            dyn_f, dyn_tt = unsafe_pointer_to_objref.(ops)
-            calls = get!(worklist, (dyn_f, dyn_tt), LLVM.CallInst[])
-            push!(calls, call)
-        end
+        dyn_maker = functions(ir)["cudanativeCompileKernel"]
+
+        # iterative compilation (non-recursive)
+        changed = true
+        while changed
+            changed = false
+
+            # find dynamic kernel invocations
+            # TODO: recover this information earlier, from the Julia IR
+            worklist = MultiDict{CompilerJob, LLVM.CallInst}()
+            for use in uses(dyn_maker)
+                # decode the call
+                call = user(use)::LLVM.CallInst
+                ops = collect(operands(call))[1:2]
+                ## addrspacecast
+                ops = LLVM.Value[first(operands(val)) for val in ops]
+                ## inttoptr
+                ops = ConstantInt[first(operands(val)) for val in ops]
+                ## integer constants
+                ops = convert.(Int, ops)
+                ## actual pointer values
+                ops = Ptr{Any}.(ops)
+
+                dyn_f, dyn_tt = unsafe_pointer_to_objref.(ops)
+                dyn_job = CompilerJob(dyn_f, dyn_tt,
+                                      job.cap, #=kernel=# true)
+                push!(worklist, dyn_job => call)
+            end
 
-        # compile and link
-        for (dyn_f, dyn_tt) in keys(worklist)
-            dyn_ctx = CompilerJob(dyn_f, dyn_tt, job.cap, true)
-            dyn_ir, dyn_entry =
-                compile(:llvm, dyn_ctx; hooks=false, optimize=optimize, strip=strip)
-
-            dyn_fn = LLVM.name(dyn_entry)
-            link!(ir, dyn_ir)
-            dyn_ir = nothing
-            dyn_entry = functions(ir)[dyn_fn]
-
-            # insert a call everywhere the kernel is used
-            for call in worklist[(dyn_f,dyn_tt)]
-                replace_uses!(call, dyn_entry)
-                unsafe_delete!(LLVM.parent(call), call)
+            # compile and link
+            for dyn_job in keys(worklist)
+                # cached compilation
+                dyn_kernel_fn = get!(kernels, dyn_job) do
+                    dyn_ir, dyn_kernel = compile(:llvm, dyn_job; hooks=false,
+                                                 optimize=optimize, strip=strip)
+                    dyn_kernel_fn = LLVM.name(dyn_kernel)
+                    dyn_kernel_ft = eltype(llvmtype(dyn_kernel))
+                    link!(ir, dyn_ir)
+                    changed = true
+                    dyn_kernel_fn
+                end
+                dyn_kernel = functions(ir)[dyn_kernel_fn]
+
+                # insert a pointer to the function everywhere the kernel is used
+                T_ptr = convert(LLVMType, Ptr{Cvoid})
+                for call in worklist[dyn_job]
+                    Builder(JuliaContext()) do builder
+                        position!(builder, call)
+                        fptr = ptrtoint!(builder, dyn_kernel, T_ptr)
+                        replace_uses!(call, fptr)
+                    end
+                    unsafe_delete!(LLVM.parent(call), call)
+                end
             end
         end
 
-        @compiler_assert isempty(uses(f)) job
-        unsafe_delete!(ir, f)
+        # all dynamic launches should have been resolved
+        @compiler_assert isempty(uses(dyn_maker)) job
+        unsafe_delete!(ir, dyn_maker)
     end
 
-    to == :llvm && return ir, entry
-
 
     ## PTX machine code
 
     prepare_execution!(job, ir)
 
-    check_invocation(job, entry)
+    check_invocation(job, kernel)
     check_ir(job, ir)
 
-    asm = mcgen(job, ir, entry)
+    asm = mcgen(job, ir, kernel)
 
-    to == :ptx && return asm, LLVM.name(entry)
+    to == :ptx && return asm, kernel_fn
 
 
     ## CUDA objects
@@ -167,11 +187,11 @@ function compile(to::Symbol, job::CompilerJob;
 
     # link the CUDA device library
     linker = CUDAdrv.CuLink(jit_options)
     CUDAdrv.add_file!(linker, libcudadevrt, CUDAdrv.LIBRARY)
-    CUDAdrv.add_data!(linker, LLVM.name(entry), asm)
+    CUDAdrv.add_data!(linker, kernel_fn, asm)
     image = CUDAdrv.complete(linker)
 
     cuda_mod = CuModule(image, jit_options)
-    cuda_fun = CuFunction(cuda_mod, LLVM.name(entry))
+    cuda_fun = CuFunction(cuda_mod, kernel_fn)
 
     to == :cuda && return cuda_fun, cuda_mod

diff --git a/src/compiler/optim.jl b/src/compiler/optim.jl
index f4f3302b..eae9ae3f 100644
--- a/src/compiler/optim.jl
+++ b/src/compiler/optim.jl
@@ -51,9 +51,18 @@ function optimize!(job::CompilerJob, mod::LLVM.Module, entry::LLVM.Function)
 
     ModulePassManager() do pm
         initialize!(pm)
+
+        # lower intrinsics
         add!(pm, FunctionPass("LowerGCFrame", lower_gc_frame!))
         aggressive_dce!(pm) # remove dead uses of ptls
         add!(pm, ModulePass("LowerPTLS", lower_ptls!))
+
+        # the Julia GC lowering pass also has some clean-up that is required
+        function LLVMAddLateLowerGCFramePass(PM::LLVM.API.LLVMPassManagerRef)
+            LLVM.@apicall(:LLVMExtraAddLateLowerGCFramePass,Cvoid,(LLVM.API.LLVMPassManagerRef,), PM)
+        end
+        LLVMAddLateLowerGCFramePass(LLVM.ref(pm))
+
         run!(pm, mod)
     end
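
The heart of the change is the `while changed` fixpoint loop in driver.jl: linking in a freshly compiled kernel can introduce new `cudanativeCompileKernel` call sites (a child kernel may itself launch other kernels, possibly recursively), so the module is rescanned until a pass finds nothing new, while the `kernels` cache keyed by `CompilerJob` ensures each kernel is compiled and linked only once. The following is a minimal, self-contained Julia sketch of that worklist pattern with the LLVM machinery stripped out; `compile_closure`, `codegen`, and the `launches` table are hypothetical stand-ins, not CUDAnative API.

    # Sketch of the iterative (non-recursive) compilation loop, assuming a
    # `codegen` callback that returns the kernels a given kernel launches.
    function compile_closure(root, codegen)
        # cache: kernel => its dynamic launches (compare the `kernels`
        # OrderedDict in driver.jl, which maps CompilerJob => function name)
        compiled = Dict{Symbol,Vector{Symbol}}(root => codegen(root))

        changed = true
        while changed
            changed = false
            # scan everything compiled so far for launches not yet resolved
            worklist = [callee for callees in values(compiled)
                               for callee in callees if !haskey(compiled, callee)]
            for callee in unique(worklist)
                compiled[callee] = codegen(callee)  # compile once, even if recursive
                changed = true                      # its code may launch yet more kernels
            end
        end
        return collect(keys(compiled))
    end

    # Example: `parent` launches `child`, and `child` recursively launches
    # itself; the loop terminates because `child` is only compiled once.
    launches = Dict(:parent => [:child], :child => [:child])
    compile_closure(:parent, k -> get(launches, k, Symbol[]))  # => :parent and :child

The cache is what makes nested and self-recursive launches terminate: a kernel already present in `kernels` (here, `compiled`) is reused rather than recompiled, and a pass that discovers no new kernels leaves `changed` false and ends the loop.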