Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 5, 2023
1 parent f7382e2 commit bdec1df
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 17 deletions.
4 changes: 2 additions & 2 deletions src/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ const CustomReversePass = Ptr{Cvoid}
# Forward `name`, `fwdhandle` and `revhandle` to libEnzyme's `EnzymeRegisterCallHandler`.
# The handles are passed as `CustomAugmentedForwardPass` and `CustomReversePass`
# respectively; the C call returns nothing.
function EnzymeRegisterCallHandler(name, fwdhandle, revhandle)
    return ccall(
        (:EnzymeRegisterCallHandler, libEnzyme),
        Cvoid,
        (Cstring, CustomAugmentedForwardPass, CustomReversePass),
        name, fwdhandle, revhandle
    )
end
# Forward `name` and `fwdhandle` to libEnzyme's `EnzymeRegisterFwdCallHandler`.
# The handle is passed as a `CustomForwardPass`; the C call returns nothing.
function EnzymeRegisterFwdCallHandler(name, fwdhandle)
    return ccall(
        (:EnzymeRegisterFwdCallHandler, libEnzyme),
        Cvoid,
        (Cstring, CustomForwardPass),
        name, fwdhandle
    )
end

# Two-argument form: hand the call instruction `ci` and the function `fn` to libEnzyme's
# `EnzymeSetCalledFunction`, both as `LLVMValueRef`.
# NOTE(review): this view also contains a three-argument method that calls the same C
# symbol with four arguments; presumably this is the pre-change definition — confirm
# which signature the linked libEnzyme actually exports.
function EnzymeSetCalledFunction(ci::LLVM.CallInst, fn::LLVM.Function)
    return ccall(
        (:EnzymeSetCalledFunction, libEnzyme),
        Cvoid,
        (LLVMValueRef, LLVMValueRef),
        ci, fn
    )
end
# Pass `fn` to libEnzyme's `EnzymeCloneFunctionWithoutReturn` and return the resulting
# `LLVMValueRef` unchanged (callers wrap it themselves).
# NOTE(review): presumably the pre-change counterpart of
# `EnzymeCloneFunctionWithoutReturnOrArgs` — confirm it is still exported by libEnzyme.
function EnzymeCloneFunctionWithoutReturn(fn::LLVM.Function)
    return ccall(
        (:EnzymeCloneFunctionWithoutReturn, libEnzyme),
        LLVMValueRef,
        (LLVMValueRef,),
        fn
    )
end
# Three-argument form: hand the call instruction `ci`, the function `fn`, and the index
# list `toremove` to libEnzyme's `EnzymeSetCalledFunction`. `toremove` is passed as a
# `Ptr{Int64}` followed by `length(toremove)` as an `Int64`.
function EnzymeSetCalledFunction(ci::LLVM.CallInst, fn::LLVM.Function, toremove)
    return ccall(
        (:EnzymeSetCalledFunction, libEnzyme),
        Cvoid,
        (LLVMValueRef, LLVMValueRef, Ptr{Int64}, Int64),
        ci, fn, toremove, length(toremove)
    )
end
# Pass `fn`, the `keepret` flag (as `UInt8`) and the index list `args` (as a `Ptr{Int64}`
# followed by `length(args)`) to libEnzyme's `EnzymeCloneFunctionWithoutReturnOrArgs`.
# Returns the raw `LLVMValueRef` produced by the C call.
function EnzymeCloneFunctionWithoutReturnOrArgs(fn::LLVM.Function, keepret, args)
    return ccall(
        (:EnzymeCloneFunctionWithoutReturnOrArgs, libEnzyme),
        LLVMValueRef,
        (LLVMValueRef, UInt8, Ptr{Int64}, Int64),
        fn, keepret, args, length(args)
    )
end
# Ask libEnzyme's `EnzymeGetShadowType` for the type corresponding to `T` at the given
# `width` (passed as `UInt64`); returns the raw `LLVMTypeRef`.
function EnzymeGetShadowType(width, T)
    return ccall(
        (:EnzymeGetShadowType, libEnzyme),
        LLVMTypeRef,
        (UInt64, LLVMTypeRef),
        width, T
    )
end

# Query libEnzyme's `EnzymeGradientUtilsGetMode` for `gutils` (an
# `EnzymeGradientUtilsRef`) and return the resulting `CDerivativeMode`.
function EnzymeGradientUtilsGetMode(gutils)
    return ccall(
        (:EnzymeGradientUtilsGetMode, libEnzyme),
        CDerivativeMode,
        (EnzymeGradientUtilsRef,),
        gutils
    )
end
Expand Down
49 changes: 44 additions & 5 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2907,6 +2907,44 @@ function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, f, tt, wor
nested_codegen!(mode, mod, funcspec, world)
end

# Attach Enzyme bookkeeping attributes to the functions of `mod`.
#
# * Every function in `mod` gets an "enzymejl_world" string attribute holding
#   `string(job.world)`.
# * For each `(mi, entry)` pair in `meta.compiled` whose
#   `GPUCompiler.safe_name(entry.specfunc)` names a function present in `mod`, that
#   function gets "enzymejl_mi" (the `UInt` value of `pointer_from_objref(mi)`) and
#   "enzymejl_rt" (the `UInt` value of `unsafe_to_pointer(RT)`, where `RT` is the return
#   type inferred for `mi`).
# * If `RT` is `is_sret` and the element type of the function's first parameter
#   `deserves_rooting`, parameter 2 and the matching argument slot of every call site are
#   tagged with an empty "enzymejl_returnRoots" attribute.
#
# Returns `nothing`; `mod` is mutated in place.
function prepare_llvm(mod, job, meta)
    ctx = LLVM.context(mod)
    interp = GPUCompiler.get_interpreter(job)

    # Stamp the world age on everything in the module.
    for f in functions(mod)
        push!(
            function_attributes(f),
            StringAttribute("enzymejl_world", string(job.world); ctx),
        )
    end

    for (mi, entry) in meta.compiled
        fname = GPUCompiler.safe_name(entry.specfunc)
        # Entries without a matching function in this module are ignored.
        haskey(functions(mod), fname) || continue
        fn = functions(mod)[fname]

        RT = Core.Compiler.typeinf_ext_toplevel(interp, mi).rettype

        # Roots are only considered when the result is returned through an sret slot;
        # the decision is based on the element type of the first parameter.
        returnRoots = false
        if is_sret(RT, ctx)
            sret_elty = eltype(value_type(parameters(fn)[1]))
            returnRoots = deserves_rooting(sret_elty)
        end

        fattrs = function_attributes(fn)
        mi_addr = string(convert(UInt, pointer_from_objref(mi)))
        push!(fattrs, StringAttribute("enzymejl_mi", mi_addr; ctx))
        rt_addr = string(convert(UInt, unsafe_to_pointer(RT)))
        push!(fattrs, StringAttribute("enzymejl_rt", rt_addr; ctx))

        returnRoots || continue

        roots_attr = StringAttribute("enzymejl_returnRoots", ""; ctx)
        push!(parameter_attributes(fn, 2), roots_attr)
        # Mirror the parameter attribute onto each call site of the function.
        for use in LLVM.uses(fn)
            u = LLVM.user(use)
            @assert isa(u, LLVM.CallInst)
            LLVM.API.LLVMAddCallSiteAttribute(u, LLVM.API.LLVMAttributeIndex(2), roots_attr)
        end
    end
    return nothing
end

function nested_codegen!(mode::API.CDerivativeMode, mod::LLVM.Module, funcspec::Core.MethodInstance, world)
# TODO: Put a cache here index on `mod` and f->tt

Expand All @@ -2928,6 +2966,8 @@ end
# TODO
parent_job = nothing
otherMod, meta = GPUCompiler.codegen(:llvm, job; optimize=false, cleanup=false, validate=false, parent_job=parent_job, ctx)
prepare_llvm(otherMod, job, meta)

entry = name(meta.entry)

for f in functions(otherMod)
Expand Down Expand Up @@ -8155,6 +8195,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
end

mod, meta = GPUCompiler.codegen(:llvm, primal_job; optimize=false, cleanup=false, validate=false, parent_job=parent_job, ctx)
prepare_llvm(mod, primal_job, meta)
inserted_ts = false
if ctx !== nothing && ctx isa LLVM.Context
@assert ctx == context(mod)
Expand Down Expand Up @@ -8292,7 +8333,6 @@ end
for a in attrs
push!(attributes, a)
end
push!(attributes, StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi))); ctx))
push!(attributes, StringAttribute("enzymejl_job", string(convert(UInt, pointer_from_objref(jobref))); ctx))
push!(attributes, StringAttribute("enzyme_math", name; ctx))
push!(attributes, EnumAttribute("noinline", 0; ctx))
Expand Down Expand Up @@ -8446,11 +8486,14 @@ end

dispose(builder)
end
attributes = function_attributes(wrapper_f)
push!(attributes, StringAttribute("enzymejl_world", string(job.world); ctx))
primalf = wrapper_f
end

source_sig = job.source.specTypes
primalf, returnRoots = lower_convention(source_sig, mod, primalf, actualRetType)
push!(function_attributes(primalf), StringAttribute("enzymejl_world", string(job.world); ctx))

if primal_job.config.target isa GPUCompiler.NativeCompilerTarget
target_machine = JIT.get_tm()
Expand Down Expand Up @@ -8494,13 +8537,9 @@ end
continue
end
attributes = function_attributes(f)
push!(attributes, StringAttribute("enzymejl_mi", string(convert(UInt, pointer_from_objref(mi))); ctx))
push!(attributes, StringAttribute("enzymejl_job", string(convert(UInt, pointer_from_objref(jobref))); ctx))
push!(jlrules, fname)
end
for f in functions(mod)
push!(function_attributes(f), StringAttribute("enzymejl_world", string(job.world); ctx))
end

GC.@preserve job jobref begin
adjointf, augmented_primalf, TapeType = enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, abiwrap, modifiedBetween, returnPrimal, jlrules, expectedTapeType)
Expand Down
44 changes: 34 additions & 10 deletions src/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ end
function propagate_returned!(mod::LLVM.Module)
ctx = LLVM.context(mod)

tofinalize = LLVM.Function[]
tofinalize = Tuple{LLVM.Function,Bool,Vector{Int64}}[]
for fn in functions(mod)
if isempty(blocks(fn))
continue
Expand All @@ -253,12 +253,16 @@ function propagate_returned!(mod::LLVM.Module)
continue
end
argn = nothing
for i in 1:length(parameters(fn))
toremove = Int64[]
for (i, arg) in enumerate(parameters(fn))
if any(kind(attr) == kind(EnumAttribute("returned"; ctx)) for attr in collect(parameter_attributes(fn, i)))
argn = i
end
if length(collect(LLVM.uses(arg))) == 0
push!(toremove, i-1)
end
end
if argn === nothing
if argn === nothing && length(toremove) == 0
continue
end
illegalUse = linkage(fn) != LLVM.API.LLVMInternalLinkage
Expand All @@ -280,26 +284,28 @@ function propagate_returned!(mod::LLVM.Module)
illegalUse = true
continue
end
LLVM.replace_uses!(un, ops[argn])
if argn !== nothing
LLVM.replace_uses!(un, ops[argn])
end
end
if !illegalUse
push!(tofinalize, fn)
push!(tofinalize, (fn, argn === nothing, toremove))
end
end
for fn in tofinalize
for (fn, keepret, toremove) in tofinalize
try
nfn = LLVM.Function(API.EnzymeCloneFunctionWithoutReturn(fn))
todo = LLVM.CallInst[]
for u in LLVM.uses(fn)
un = LLVM.user(u)
push!(todo, un)
end
nfn = LLVM.Function(API.EnzymeCloneFunctionWithoutReturnOrArgs(fn, keepret, toremove))
for un in todo
API.EnzymeSetCalledFunction(un, nfn)
API.EnzymeSetCalledFunction(un, nfn, toremove)
end
unsafe_delete!(mod, fn)
catch
break
break
end
end
end
Expand Down Expand Up @@ -452,11 +458,23 @@ end
ctx = LLVM.context(mod)
funcT = LLVM.FunctionType(LLVM.VoidType(ctx), LLVMType[], vararg=true)
func, _ = get_function!(mod, "llvm.enzymefakeuse", funcT, [EnumAttribute("readnone"; ctx)])

rfunc, _ = get_function!(mod, "llvm.enzymefakeread", funcT, [EnumAttribute("readonly"; ctx), EnumAttribute("argmemonly"; ctx), EnumAttribute("nocapture"; ctx)])

for fn in functions(mod)
if isempty(blocks(fn))
continue
end
# Ensure that interprocedural optimizations do not delete the use of returnRoots
if length(collect(parameters(fn))) >= 2 && any(kind(attr) == kind(StringAttribute("enzymejl_returnRoots"; ctx)) for attr in collect(parameter_attributes(fn, 2)))
for u in LLVM.uses(fn)
u = LLVM.user(u)
@assert isa(u, LLVM.CallInst)
B = IRBuilder(ctx)
nextInst = LLVM.Instruction(LLVM.API.LLVMGetNextInstruction(u))
position!(B, nextInst)
cl = call!(B, funcT, rfunc, LLVM.Value[operands(u)[2]])
end
end
attrs = collect(function_attributes(fn))
prevent = any(kind(attr) == kind(EnumAttribute("noinline"; ctx)) for attr in attrs) && any(kind(attr) == kind(StringAttribute("enzyme_math"; ctx)) for attr in attrs)
if prevent
Expand All @@ -478,6 +496,11 @@ end
run!(pm, mod)
end

for u in LLVM.uses(rfunc)
u = LLVM.user(u)
unsafe_delete!(LLVM.parent(u), u)
end
unsafe_delete!(mod, rfunc)
for fn in functions(mod)
for b in blocks(fn)
inst = first(LLVM.instructions(b))
Expand All @@ -489,6 +512,7 @@ end
end
end
end
unsafe_delete!(mod, func)
detect_writeonly!(mod)
nodecayed_phis!(mod)
end
Expand Down

0 comments on commit bdec1df

Please sign in to comment.