Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"

[[LLVM]]
deps = ["CEnum", "Libdl", "Printf", "Unicode"]
git-tree-sha1 = "a662366a5d485dee882077e8da3e1a95a86d097f"
git-tree-sha1 = "cf5f5a54f381f290cca33a1ccdc12bcd1b453800"
repo-rev = "6ec68e6"
repo-url = "https://github.com/maleadt/LLVM.jl.git"
uuid = "929cbde3-209d-540e-8aea-75f648917ca0"
version = "2.0.0"

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ TimerOutputs = "0.5"
julia = "1.3"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Pkg"]
16 changes: 10 additions & 6 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,20 @@ function codegen(output::Symbol, job::CompilerJob;

## LLVM IR

@timeit_debug to "IR generation" begin
ir, kernel = irgen(job, method_instance, world)
ctx = context(ir)
kernel_fn = LLVM.name(kernel)
end

# always preload the runtime, and do so early; it cannot be part of any timing block
# because it recurses into the compiler
if libraries
runtime = load_runtime(job)
runtime = load_runtime(job, ctx)
runtime_fns = LLVM.name.(defs(runtime))
end

@timeit_debug to "LLVM middle-end" begin
ir, kernel = @timeit_debug to "IR generation" irgen(job, method_instance, world)
kernel_fn = LLVM.name(kernel)

# target-specific libraries
if libraries
undefined_fns = LLVM.name.(decls(ir))
Expand Down Expand Up @@ -198,16 +201,17 @@ function codegen(output::Symbol, job::CompilerJob;
strip=strip, validate=validate,
deferred_codegen=false)
dyn_kernel_fn = LLVM.name(dyn_kernel)
@assert context(dyn_ir) == ctx
link!(ir, dyn_ir)
changed = true
dyn_kernel_fn
end
dyn_kernel = functions(ir)[dyn_kernel_fn]

# insert a pointer to the function everywhere the kernel is used
T_ptr = convert(LLVMType, Ptr{Cvoid})
T_ptr = convert(LLVMType, Ptr{Cvoid}, ctx)
for call in worklist[dyn_job]
Builder(JuliaContext()) do builder
Builder(ctx) do builder
position!(builder, call)
fptr = ptrtoint!(builder, dyn_kernel, T_ptr)
replace_uses!(call, fptr)
Expand Down
25 changes: 14 additions & 11 deletions src/gcn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ end

function lower_throw_extra!(mod::LLVM.Module)
job = current_job::CompilerJob
ctx = context(mod)
changed = false
@timeit_debug to "lower throw (extra)" begin

Expand All @@ -67,7 +68,7 @@ function lower_throw_extra!(mod::LLVM.Module)
call = user(use)::LLVM.CallInst

# replace the throw with a trap
let builder = Builder(JuliaContext())
let builder = Builder(ctx)
position!(builder, call)
emit_exception!(builder, f_name, call)
dispose(builder)
Expand Down Expand Up @@ -104,14 +105,15 @@ function lower_throw_extra!(mod::LLVM.Module)
end

function emit_trap!(job::CompilerJob{GCNCompilerTarget}, builder, mod, inst)
ctx = context(mod)
trap = if haskey(functions(mod), "llvm.trap")
functions(mod)["llvm.trap"]
else
LLVM.Function(mod, "llvm.trap", LLVM.FunctionType(LLVM.VoidType(JuliaContext())))
LLVM.Function(mod, "llvm.trap", LLVM.FunctionType(LLVM.VoidType(ctx)))
end
if Base.libllvm_version < v"9"
rl_ft = LLVM.FunctionType(LLVM.Int32Type(JuliaContext()),
[LLVM.Int32Type(JuliaContext())])
rl_ft = LLVM.FunctionType(LLVM.Int32Type(ctx),
[LLVM.Int32Type(ctx)])
rl = if haskey(functions(mod), "llvm.amdgcn.readfirstlane")
functions(mod)["llvm.amdgcn.readfirstlane"]
else
Expand All @@ -124,8 +126,8 @@ function emit_trap!(job::CompilerJob{GCNCompilerTarget}, builder, mod, inst)
# this, the target will only attempt to do a "masked branch", which
# only works on vector instructions (trap is a scalar instruction, and
# therefore it is executed even when EXEC==0).
rl_val = call!(builder, rl, [ConstantInt(Int32(32), JuliaContext())])
rl_bc = inttoptr!(builder, rl_val, LLVM.PointerType(LLVM.Int32Type(JuliaContext())))
rl_val = call!(builder, rl, [ConstantInt(Int32(32), ctx)])
rl_bc = inttoptr!(builder, rl_val, LLVM.PointerType(LLVM.Int32Type(ctx)))
store!(builder, rl_val, rl_bc)
end
call!(builder, trap)
Expand All @@ -147,8 +149,9 @@ function wrapper_type(julia_t::Type, codegen_t::LLVMType)::LLVMType
end
# generate a kernel wrapper to fix & improve argument passing
function wrap_entry!(job::CompilerJob, mod::LLVM.Module, entry_f::LLVM.Function)
ctx = context(mod)
entry_ft = eltype(llvmtype(entry_f)::LLVM.PointerType)::LLVM.FunctionType
@compiler_assert return_type(entry_ft) == LLVM.VoidType(JuliaContext()) job
@compiler_assert return_type(entry_ft) == LLVM.VoidType(ctx) job

# filter out types which don't occur in the LLVM function signatures
sig = Base.signature_type(job.source.f, job.source.tt)::Type
Expand All @@ -165,12 +168,12 @@ function wrap_entry!(job::CompilerJob, mod::LLVM.Module, entry_f::LLVM.Function)
in zip(julia_types, parameters(entry_ft))]
wrapper_fn = LLVM.name(entry_f)
LLVM.name!(entry_f, wrapper_fn * ".inner")
wrapper_ft = LLVM.FunctionType(LLVM.VoidType(JuliaContext()), wrapper_types)
wrapper_ft = LLVM.FunctionType(LLVM.VoidType(ctx), wrapper_types)
wrapper_f = LLVM.Function(mod, wrapper_fn, wrapper_ft)

# emit IR performing the "conversions"
let builder = Builder(JuliaContext())
entry = BasicBlock(wrapper_f, "entry", JuliaContext())
let builder = Builder(ctx)
entry = BasicBlock(wrapper_f, "entry", ctx)
position!(builder, entry)

wrapper_args = Vector{LLVM.Value}()
Expand Down Expand Up @@ -211,7 +214,7 @@ function wrap_entry!(job::CompilerJob, mod::LLVM.Module, entry_f::LLVM.Function)
end

# early-inline the original entry function into the wrapper
push!(function_attributes(entry_f), EnumAttribute("alwaysinline", 0, JuliaContext()))
push!(function_attributes(entry_f), EnumAttribute("alwaysinline", 0, ctx))
linkage!(entry_f, LLVM.API.LLVMInternalLinkage)

fixup_metadata!(entry_f)
Expand Down
25 changes: 16 additions & 9 deletions src/irgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ end
else

function module_setup(job::CompilerJob, mod::LLVM.Module)
ctx = context(mod)

# configure the module
triple!(mod, llvm_triple(job.target))
datalayout!(mod, llvm_datalayout(job.target))
Expand All @@ -184,14 +186,14 @@ function module_setup(job::CompilerJob, mod::LLVM.Module)
# Set Dwarf Version to 2, the DI printer will downgrade to v2 automatically,
# but this is technically correct and the only version supported by NVPTX
LLVM.flags(mod)["Dwarf Version", LLVM.API.LLVMModuleFlagBehaviorWarning] =
Metadata(ConstantInt(Int32(2), JuliaContext()))
Metadata(ConstantInt(Int32(2), ctx))
LLVM.flags(mod)["Debug Info Version", LLVM.API.LLVMModuleFlagBehaviorError] =
Metadata(ConstantInt(DEBUG_METADATA_VERSION(), JuliaContext()))
Metadata(ConstantInt(DEBUG_METADATA_VERSION(), ctx))
else
push!(metadata(mod), "llvm.module.flags",
MDNode([ConstantInt(Int32(1), JuliaContext()), # llvm::Module::Error
MDNode([ConstantInt(Int32(1), ctx), # llvm::Module::Error
MDString("Debug Info Version"),
ConstantInt(DEBUG_METADATA_VERSION(), JuliaContext())]))
ConstantInt(DEBUG_METADATA_VERSION(), ctx)]))
end
end

Expand Down Expand Up @@ -330,12 +332,13 @@ end

function irgen(job::CompilerJob, method_instance::Core.MethodInstance, world)
entry, mod = @timeit_debug to "emission" compile_method_instance(job, method_instance, world)
ctx = context(mod)

# clean up incompatibilities
@timeit_debug to "clean-up" begin
for llvmf in functions(mod)
# only occurs in debug builds
delete!(function_attributes(llvmf), EnumAttribute("sspstrong", 0, JuliaContext()))
delete!(function_attributes(llvmf), EnumAttribute("sspstrong", 0, ctx))

if VERSION < v"1.5.0-DEV.393"
# make function names safe for ptxas
Expand Down Expand Up @@ -491,6 +494,7 @@ safe_name(x) = safe_name(repr(x))
# exception arguments) and proper debug info to unwind the stack, this pass can go.
function lower_throw!(mod::LLVM.Module)
job = current_job::CompilerJob
ctx = context(mod)
changed = false
@timeit_debug to "lower throw" begin

Expand Down Expand Up @@ -524,7 +528,7 @@ function lower_throw!(mod::LLVM.Module)
call = user(use)::LLVM.CallInst

# replace the throw with a PTX-compatible exception
let builder = Builder(JuliaContext())
let builder = Builder(ctx)
position!(builder, call)
emit_exception!(builder, name, call)
dispose(builder)
Expand Down Expand Up @@ -570,6 +574,7 @@ function emit_exception!(builder, name, inst)
bb = position(builder)
fun = LLVM.parent(bb)
mod = LLVM.parent(fun)
ctx = context(mod)

# report the exception
if Base.JLOptions().debug_level >= 1
Expand All @@ -584,12 +589,13 @@ function emit_exception!(builder, name, inst)
# report each frame
if Base.JLOptions().debug_level >= 2
rt = Runtime.get(:report_exception_frame)
ft = convert(LLVM.FunctionType, rt, ctx)
bt = backtrace(inst)
for (i,frame) in enumerate(bt)
idx = ConstantInt(rt.llvm_types[1], i)
idx = ConstantInt(parameters(ft)[1], i)
func = globalstring_ptr!(builder, String(frame.func), "di_func")
file = globalstring_ptr!(builder, String(frame.file), "di_file")
line = ConstantInt(rt.llvm_types[4], frame.line)
line = ConstantInt(parameters(ft)[4], frame.line)
call!(builder, rt, [idx, func, file, line])
end
end
Expand All @@ -601,10 +607,11 @@ function emit_exception!(builder, name, inst)
end

function emit_trap!(job::CompilerJob, builder, mod, inst)
ctx = context(mod)
trap = if haskey(functions(mod), "llvm.trap")
functions(mod)["llvm.trap"]
else
LLVM.Function(mod, "llvm.trap", LLVM.FunctionType(LLVM.VoidType(JuliaContext())))
LLVM.Function(mod, "llvm.trap", LLVM.FunctionType(LLVM.VoidType(ctx)))
end
call!(builder, trap)
end
Expand Down
3 changes: 2 additions & 1 deletion src/mcgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ end
# this pass performs that resolution at link time.
function resolve_cpu_references!(mod::LLVM.Module)
job = current_job::CompilerJob
ctx = context(mod)
changed = false

for f in functions(mod)
Expand All @@ -39,7 +40,7 @@ function resolve_cpu_references!(mod::LLVM.Module)
# eagerly resolve the address of the binding
address = ccall(:jl_cglobal, Any, (Any, Any), fn, UInt)
dereferenced = unsafe_load(address)
dereferenced = LLVM.ConstantInt(dereferenced, JuliaContext())
dereferenced = LLVM.ConstantInt(dereferenced, ctx)

function replace_bindings!(value)
changed = false
Expand Down
5 changes: 3 additions & 2 deletions src/optim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,15 @@ end
function lower_gc_frame!(fun::LLVM.Function)
job = current_job::CompilerJob
mod = LLVM.parent(fun)
ctx = context(fun)
changed = false

# plain alloc
if haskey(functions(mod), "julia.gc_alloc_obj")
alloc_obj = functions(mod)["julia.gc_alloc_obj"]
alloc_obj_ft = eltype(llvmtype(alloc_obj))
T_prjlvalue = return_type(alloc_obj_ft)
T_pjlvalue = convert(LLVMType, Any, true)
T_pjlvalue = convert(LLVMType, Any, ctx; allow_boxed=true)

for use in uses(alloc_obj)
call = user(use)::LLVM.CallInst
Expand All @@ -101,7 +102,7 @@ function lower_gc_frame!(fun::LLVM.Function)
sz = ops[2]

# replace with PTX alloc_obj
let builder = Builder(JuliaContext())
let builder = Builder(ctx)
position!(builder, call)
ptr = call!(builder, Runtime.get(:gc_pool_alloc), [sz])
replace_uses!(call, ptr)
Expand Down
26 changes: 15 additions & 11 deletions src/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,36 +58,38 @@ isintrinsic(::CompilerJob{PTXCompilerTarget}, fn::String) = in(fn, ptx_intrinsic
runtime_slug(job::CompilerJob{PTXCompilerTarget}) = "ptx-sm_$(job.target.cap.major)$(job.target.cap.minor)"

function process_kernel!(job::CompilerJob{PTXCompilerTarget}, mod::LLVM.Module, kernel::LLVM.Function)
ctx = context(mod)

# property annotations
annotations = LLVM.Value[kernel]

## kernel metadata
append!(annotations, [MDString("kernel"), ConstantInt(Int32(1), JuliaContext())])
append!(annotations, [MDString("kernel"), ConstantInt(Int32(1), ctx)])

## expected CTA sizes
if job.target.minthreads != nothing
for (dim, name) in enumerate([:x, :y, :z])
bound = dim <= length(job.target.minthreads) ? job.target.minthreads[dim] : 1
append!(annotations, [MDString("reqntid$name"),
ConstantInt(Int32(bound), JuliaContext())])
ConstantInt(Int32(bound), ctx)])
end
end
if job.target.maxthreads != nothing
for (dim, name) in enumerate([:x, :y, :z])
bound = dim <= length(job.target.maxthreads) ? job.target.maxthreads[dim] : 1
append!(annotations, [MDString("maxntid$name"),
ConstantInt(Int32(bound), JuliaContext())])
ConstantInt(Int32(bound), ctx)])
end
end

if job.target.blocks_per_sm != nothing
append!(annotations, [MDString("minctasm"),
ConstantInt(Int32(job.target.blocks_per_sm), JuliaContext())])
ConstantInt(Int32(job.target.blocks_per_sm), ctx)])
end

if job.target.maxregs != nothing
append!(annotations, [MDString("maxnreg"),
ConstantInt(Int32(job.target.maxregs), JuliaContext())])
ConstantInt(Int32(job.target.maxregs), ctx)])
end

push!(metadata(mod), "nvvm.annotations", MDNode(annotations))
Expand Down Expand Up @@ -145,6 +147,7 @@ end
# is still responsible for generating structured control flow).
function hide_unreachable!(fun::LLVM.Function)
job = current_job::CompilerJob
ctx = context(fun)
changed = false
@timeit_debug to "hide unreachable" begin

Expand All @@ -153,7 +156,7 @@ function hide_unreachable!(fun::LLVM.Function)
# when calling a `noreturn` function, LLVM places an `unreachable` after the call.
# this leads to an early `ret` from the function.
attrs = function_attributes(fun)
delete!(attrs, EnumAttribute("noreturn", 0, JuliaContext()))
delete!(attrs, EnumAttribute("noreturn", 0, ctx))

# build a map of basic block predecessors
predecessors = Dict(bb => Set{LLVM.BasicBlock}() for bb in blocks(fun))
Expand Down Expand Up @@ -184,7 +187,7 @@ function hide_unreachable!(fun::LLVM.Function)
# TODO: `unreachable; unreachable`
catch ex
isa(ex, UndefRefError) || rethrow(ex)
let builder = Builder(JuliaContext())
let builder = Builder(ctx)
position!(builder, bb)

# find the strict predecessors to this block
Expand Down Expand Up @@ -220,7 +223,7 @@ function hide_unreachable!(fun::LLVM.Function)

# apply the pending terminator rewrites
@timeit_debug to "replace" if !isempty(worklist)
let builder = Builder(JuliaContext())
let builder = Builder(ctx)
for (bb, fallthrough) in worklist
position!(builder, bb)
if fallthrough !== nothing
Expand All @@ -229,7 +232,7 @@ function hide_unreachable!(fun::LLVM.Function)
# couldn't find any other successor. this happens with functions
# that only contain a single block, or when the block is dead.
ft = eltype(llvmtype(fun))
if return_type(ft) == LLVM.VoidType(JuliaContext())
if return_type(ft) == LLVM.VoidType(ctx)
# even though returning can lead to invalid control flow,
# it mostly happens with functions that just throw,
# and leaving the unreachable there would make the optimizer
Expand All @@ -252,11 +255,12 @@ end
# if LLVM knows we're trapping, code is marked `unreachable` (see `hide_unreachable!`).
function hide_trap!(mod::LLVM.Module)
job = current_job::CompilerJob
ctx = context(mod)
changed = false
@timeit_debug to "hide trap" begin

# inline assembly to exit a thread, hiding control flow from LLVM
exit_ft = LLVM.FunctionType(LLVM.VoidType(JuliaContext()))
exit_ft = LLVM.FunctionType(LLVM.VoidType(ctx))
exit = if job.target.cap < v"7"
# ptxas for old compute capabilities has a bug where it messes up the
# synchronization stack in the presence of shared memory and thread-divergend exit.
Expand All @@ -271,7 +275,7 @@ function hide_trap!(mod::LLVM.Module)
for use in uses(trap)
val = user(use)
if isa(val, LLVM.CallInst)
let builder = Builder(JuliaContext())
let builder = Builder(ctx)
position!(builder, val)
call!(builder, exit)
dispose(builder)
Expand Down
Loading