Skip to content

Commit

Permalink
Handle closures
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 7, 2021
1 parent 3a08c1b commit 0f500eb
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 41 deletions.
6 changes: 3 additions & 3 deletions src/Enzyme.jl
Expand Up @@ -62,7 +62,7 @@ prepare_cc(arg::Annotation, args...) = (arg.val, prepare_cc(args...)...)
ptr = Compiler.deferred_codegen(Val(f), Val(tt′), Val(true))
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f, tt)
thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(ptr)
thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(f, ptr)
thunk(args′...)
end

Expand All @@ -72,7 +72,7 @@ end
ptr = Compiler.deferred_codegen(Val(f), Val(tt′), Val(false))
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f, tt)
thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(ptr)
thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(f, ptr)
thunk(args′...)
end

Expand Down Expand Up @@ -104,7 +104,7 @@ for op in (asin,tanh)
for (T, llvm_t, suffix) in ((Float32, "float", "f"), (Float64, "double", ""))
mod = """
declare $llvm_t @$(nameof(op))$suffix($llvm_t)
define $llvm_t @entry($llvm_t) #0 {
%val = call $llvm_t @$op$suffix($llvm_t %0)
ret $llvm_t %val
Expand Down
83 changes: 57 additions & 26 deletions src/compiler.jl
Expand Up @@ -149,6 +149,17 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
args_known_values = API.IntList[]

ctx = LLVM.context(mod)
if !GPUCompiler.isghosttype(typeof(adjoint.f)) && !Core.Compiler.isconstType(typeof(adjoint.f))
push!(args_activity, API.DFT_CONSTANT)
typeTree = typetree(typeof(adjoint.f), ctx, dl)
push!(args_typeInfo, typeTree)
if split
push!(uncacheable_args, true)
else
push!(uncacheable_args, false)
end
push!(args_known_values, API.IntList())
end

for T in tt
source_typ = eltype(T)
Expand All @@ -159,7 +170,7 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
continue
end
isboxed = GPUCompiler.deserves_argbox(source_typ)

if T <: Const
push!(args_activity, API.DFT_CONSTANT)
elseif T <: Active
Expand All @@ -173,11 +184,11 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
push!(args_activity, API.DFT_DUP_ARG)
elseif T <: DuplicatedNoNeed
push!(args_activity, API.DFT_DUP_NONEED)
else
else
@assert("illegal annotation type")
end
T = source_typ
if isboxed
if isboxed
T = Ptr{T}
end
typeTree = typetree(T, ctx, dl)
Expand All @@ -197,7 +208,7 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
# If requested, the shadow return value of the function
# For each active (non duplicated) argument
# The adjoint of that argument
if rt <: Integer
if rt <: Integer || rt <: DataType
retType = API.DFT_CONSTANT
elseif rt <: AbstractFloat
retType = API.DFT_OUT_DIFF
Expand All @@ -209,7 +220,7 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
error("What even is $rt")
end

TA = TypeAnalysis(triple(mod))
TA = TypeAnalysis(triple(mod))
logic = Logic()

if GPUCompiler.isghosttype(rt)|| Core.Compiler.isconstType(rt)
Expand Down Expand Up @@ -402,7 +413,7 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
else
target_machine = GPUCompiler.llvm_machine(primal_job.target)
end

parallel = false
process_module = false
if parent_job !== nothing
Expand Down Expand Up @@ -469,17 +480,18 @@ end

##
# Thunk
##
##

struct CombinedAdjointThunk{f, RT, TT}
struct CombinedAdjointThunk{F, RT, TT}
fn::F
# primal::Ptr{Cvoid}
adjoint::Ptr{Cvoid}
end

@inline (thunk::CombinedAdjointThunk{F, RT, TT})(args...) where {F, RT, TT} =
enzyme_call(thunk.adjoint, TT, RT, args...)
enzyme_call(thunk.adjoint, thunk.fn, TT, RT, args...)

@generated function enzyme_call(f::Ptr{Cvoid}, tt::Type{T}, rt::Type{RT}, args::Vararg{Any, N}) where {T, RT, N}
@generated function enzyme_call(fptr::Ptr{Cvoid}, f::F, tt::Type{T}, rt::Type{RT}, args::Vararg{Any, N}) where {F, T, RT, N}
argtt = tt.parameters[1]
rettype = rt.parameters[1]
argtypes = DataType[argtt.parameters...]
Expand Down Expand Up @@ -508,6 +520,20 @@ end
# By ref values we create and need to preserve
ccexprs = Union{Expr, Symbol}[] # The expressions passed to the `llvmcall`

if !GPUCompiler.isghosttype(F) && !Core.Compiler.isconstType(F)
isboxed = GPUCompiler.deserves_argbox(F)
llvmT = isboxed ? T_prjlvalue : convert(LLVMType, F, ctx)
argexpr = :(f)
if isboxed
push!(types, Any)
else
push!(types, F)
end

push!(ccexprs, argexpr)
push!(T_wrapperargs, llvmT)
end

for (i, T) in enumerate(argtypes)
source_typ = eltype(T)
if GPUCompiler.isghosttype(source_typ) || Core.Compiler.isconstType(source_typ)
Expand All @@ -528,7 +554,7 @@ end

push!(ccexprs, argexpr)
push!(T_wrapperargs, llvmT)

T <: Const && continue

if T <: Active
Expand Down Expand Up @@ -591,29 +617,34 @@ end
realparms = LLVM.Value[]
i = target+1

if !isempty(T_JuliaSRet)
if !isempty(T_JuliaSRet)
sret = inttoptr!(builder, params[1], LLVM.PointerType(LLVM.StructType(T_JuliaSRet)))
end

activeNum = 0

if !GPUCompiler.isghosttype(F) && !Core.Compiler.isconstType(F)
push!(realparms, params[i])
i+=1
end

for T in argtypes
T′ = eltype(T)

if GPUCompiler.isghosttype(T′) || Core.Compiler.isconstType(T′)
continue
end
isboxed = GPUCompiler.deserves_argbox(T′)
push!(realparms, params[i])
i+=1
if T <: Const
elseif T <: Active
isboxed = GPUCompiler.deserves_argbox(T′)
if isboxed
ptr = gep!(builder, sret, [LLVM.ConstantInt(LLVM.IntType(64, ctx), 0), LLVM.ConstantInt(LLVM.IntType(32, ctx), activeNum)])
cst = pointercast!(builder, ptr, ptr8)
push!(realparms, ptr)

cparms = LLVM.Value[cst,
cparms = LLVM.Value[cst,
LLVM.ConstantInt(LLVM.IntType(8, ctx), 0),
LLVM.ConstantInt(LLVM.IntType(64, ctx), LLVM.storage_size(dl, Base.eltype(LLVM.llvmtype(ptr)) )),
LLVM.ConstantInt(LLVM.IntType(1, ctx), 0)]
Expand All @@ -626,8 +657,8 @@ end
end
end

# Primal Return type
if i <= size(params, 1)
# Primal Differential Return type
if rettype <: AbstractFloat || rettype <: Complex{<:AbstractFloat}
push!(realparms, params[i])
end

Expand All @@ -640,7 +671,7 @@ end

ptr = inttoptr!(builder, params[target], LLVM.PointerType(ft))
val = call!(builder, ptr, realparms)
if !isempty(T_JuliaSRet)
if !isempty(T_JuliaSRet)
activeNum = 0
returnNum = 0
for T in argtypes
Expand All @@ -662,25 +693,25 @@ end
ir = string(mod)
fn = LLVM.name(llvm_f)

if !isempty(T_JuliaSRet)
if !isempty(T_JuliaSRet)
quote
Base.@_inline_meta
sret = Ref{$(Tuple{sret_types...})}()
GC.@preserve sret begin
ptr = Base.unsafe_convert(Ptr{$(Tuple{sret_types...})}, sret)
ptr = Base.unsafe_convert(Ptr{Cvoid}, ptr)
tptr = Base.unsafe_convert(Ptr{$(Tuple{sret_types...})}, sret)
tptr = Base.unsafe_convert(Ptr{Cvoid}, tptr)
Base.llvmcall(($ir,$fn), Cvoid,
$(Tuple{Ptr{Cvoid}, Ptr{Cvoid}, types...}),
ptr, f, $(ccexprs...))
tptr, fptr, $(ccexprs...))
end
sret[]
end
else
else
quote
Base.@_inline_meta
Base.llvmcall(($ir,$fn), Cvoid,
$(Tuple{Ptr{Cvoid}, types...}),
f, $(ccexprs...))
fptr, $(ccexprs...))
end
end
end
Expand Down Expand Up @@ -728,7 +759,7 @@ function _link(job, (mod, adjoint_name, primal_name))
adjoint = params.adjoint
split = params.split

primal = job.source
primal = job.source
rt = Core.Compiler.return_type(primal.f, primal.tt)

# Now invoke the JIT
Expand All @@ -751,7 +782,7 @@ function _link(job, (mod, adjoint_name, primal_name))
end

@assert primal_name === nothing
return CombinedAdjointThunk{typeof(adjoint.f), rt, adjoint.tt}(#=primal_ptr,=# adjoint_ptr)
return CombinedAdjointThunk{typeof(adjoint.f), rt, adjoint.tt}(adjoint.f, #=primal_ptr,=# adjoint_ptr)
end

# actual compilation
Expand Down Expand Up @@ -859,4 +890,4 @@ end
include("compiler/reflection.jl")
include("compiler/validation.jl")

end
end
11 changes: 7 additions & 4 deletions src/compiler/validation.jl
Expand Up @@ -201,6 +201,9 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
data = open(flib, "r") do io
lib = readmeta(io)
sections = Sections(lib)
if !(".llvmbc" in sections)
return nothing
end
llvmbc = read(findfirst(sections, ".llvmbc"))
return llvmbc
end
Expand Down Expand Up @@ -247,7 +250,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
end
end
end

b = Builder(ctx)

position!(b, inst)
Expand All @@ -271,7 +274,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
end
end
end

b = Builder(ctx)
position!(b, inst)
replace_uses!(inst, LLVM.inttoptr!(b, replaceWith, llvmtype(inst)))
Expand Down Expand Up @@ -327,8 +330,8 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
if ptr == cglobal(:malloc)
fn = "malloc"
end
if length(fn) > 1 && fromC

if length(fn) > 1 && fromC
mod = LLVM.parent(LLVM.parent(LLVM.parent(inst)))
lfn = LLVM.API.LLVMGetNamedFunction(mod, fn)
if lfn == C_NULL
Expand Down
18 changes: 11 additions & 7 deletions src/typetree.jl
Expand Up @@ -11,7 +11,7 @@ LLVM.dispose(tt::TypeTree) = API.EnzymeFreeTypeTree(tt)

TypeTree() = TypeTree(API.EnzymeNewTypeTree())
TypeTree(CT, ctx) = TypeTree(API.EnzymeNewTypeTreeCT(CT, ctx))
function TypeTree(CT, idx, ctx)
function TypeTree(CT, idx, ctx)
tt = TypeTree(CT, ctx)
only!(tt, idx)
return tt
Expand Down Expand Up @@ -73,6 +73,10 @@ function typetree(::Type{Float64}, ctx, dl)
return TypeTree(API.DT_Double, -1, ctx)
end

function typetree(::Type{<:DataType}, ctx, dl)
return TypeTree()
end

function typetree(::Type{<:Union{Ptr{T}, Core.LLVMPtr{T}}}, ctx, dl) where T
tt = typetree(T, ctx, dl)
merge!(tt, TypeTree(API.DT_Pointer, ctx))
Expand Down Expand Up @@ -123,9 +127,9 @@ function typetree(@nospecialize(T), ctx, dl)
if subT.isinlinealloc
shift!(subtree, dl, 0, sizeof(subT), offset)
else
merge!(subtree, TypeTree(API.DT_Pointer, ctx))
merge!(subtree, TypeTree(API.DT_Pointer, ctx))
only!(subtree, offset)
end
end

merge!(tt, subtree)
end
Expand All @@ -139,14 +143,14 @@ struct FnTypeInfo
end
Base.cconvert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo) = fnti
function Base.unsafe_convert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo)
args_kv = Base.unsafe_convert(Ptr{API.IntList}, Base.cconvert(Ptr{API.IntList}, fnti.known_values))
rTT = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, fnti.rTT))
args_kv = Base.unsafe_convert(Ptr{API.IntList}, Base.cconvert(Ptr{API.IntList}, fnti.known_values))
rTT = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, fnti.rTT))

tts = API.CTypeTreeRef[]
for tt in fnti.argTTs
raw_tt = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, tt))
raw_tt = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, tt))
push!(tts, raw_tt)
end
argTTs = Base.unsafe_convert(Ptr{API.CTypeTreeRef}, Base.cconvert(Ptr{API.CTypeTreeRef}, tts))
argTTs = Base.unsafe_convert(Ptr{API.CTypeTreeRef}, Base.cconvert(Ptr{API.CTypeTreeRef}, tts))
return API.CFnTypeInfo(argTTs, rTT, args_kv)
end

0 comments on commit 0f500eb

Please sign in to comment.