Skip to content

Commit

Permalink
Eliminate cassette
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 2, 2021
1 parent c2d04a2 commit 45fc001
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 62 deletions.
49 changes: 1 addition & 48 deletions src/Enzyme.jl
Expand Up @@ -57,13 +57,7 @@ prepare_cc(arg::DuplicatedNoNeed, args...) = (arg.val, arg.dval, prepare_cc(args
prepare_cc(arg::Annotation, args...) = (arg.val, prepare_cc(args...)...)

@inline function autodiff(f::F, args...) where F
args′ = annotate(args...)
tt′ = Tuple{map(Core.Typeof, args′)...}
ptr = Compiler.deferred_codegen(Val(f), Val(tt′), Val(true))
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f, tt)
thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(f, ptr)
thunk(args′...)
autodiff_no_cassette(f, args...)
end

@inline function autodiff_no_cassette(f::F, args...) where F
Expand All @@ -77,47 +71,6 @@ end
end

import .Compiler: EnzymeCtx
# Unary math ops backed by an LLVM intrinsic: when overdubbed under `EnzymeCtx`,
# call `llvm.<op>.f32` / `llvm.<op>.f64` directly for Float32 / Float64 arguments.
for fn in (sin, cos, tan, exp, log), (FT, tag) in ((Float32, "f32"), (Float64, "f64"))
    intrinsic = string("llvm.", nameof(fn), ".", tag)
    @eval @inline function Cassette.overdub(::EnzymeCtx, ::typeof($fn), x::$FT)
        return ccall($intrinsic, llvmcall, $FT, ($FT,), x)
    end
end

# Binary math ops backed by an LLVM intrinsic: under `EnzymeCtx`, forward both
# same-typed floating-point arguments to `llvm.<op>.f32` / `llvm.<op>.f64`.
for fn in (copysign,), (FT, tag) in ((Float32, "f32"), (Float64, "f64"))
    intrinsic = string("llvm.", nameof(fn), ".", tag)
    @eval @inline function Cassette.overdub(::EnzymeCtx, ::typeof($fn), x::$FT, y::$FT)
        return ccall($intrinsic, llvmcall, $FT, ($FT, $FT), x, y)
    end
end

# For `asin` and `tanh`, generate one `Cassette.overdub` method per float type
# that calls an external function named after the op (with an "f" suffix for
# Float32, no suffix for Float64) through a small textual LLVM IR module.
for op in (asin,tanh)
for (T, llvm_t, suffix) in ((Float32, "float", "f"), (Float64, "double", ""))
# IR text: declares the external function and defines an always-inline `entry`
# function that forwards its single argument to it and returns the result.
mod = """
declare $llvm_t @$(nameof(op))$suffix($llvm_t)
define $llvm_t @entry($llvm_t) #0 {
%val = call $llvm_t @$op$suffix($llvm_t %0)
ret $llvm_t %val
}
attributes #0 = { alwaysinline }
"""
@eval begin
# Under `EnzymeCtx`, `op(x::T)` is replaced by a call to `entry` in the module above.
@inline function Cassette.overdub(::EnzymeCtx, ::typeof($op), x::$T)
Base.llvmcall(($mod, "entry"), $T, Tuple{$T}, x)
end
end
end
end

@inline function pack(args...)
ntuple(Val(length(args))) do i
Expand Down
61 changes: 58 additions & 3 deletions src/compiler.jl
Expand Up @@ -208,7 +208,7 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
# If requested, the shadow return value of the function
# For each active (non duplicated) argument
# The adjoint of that argument
if rt <: Integer
if rt <: Integer || rt <: DataType
retType = API.DFT_CONSTANT
elseif rt <: AbstractFloat
retType = API.DFT_OUT_DIFF
Expand Down Expand Up @@ -406,8 +406,50 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
else
primal_job = similar(parent_job, job.source)
end
mod, primalf = GPUCompiler.codegen(:llvm, primal_job, optimize=false, validate=false, parent_job=parent_job)
mod, primalf, func_origs = GPUCompiler.codegen(:llvm, primal_job, optimize=false, validate=false, parent_job=parent_job)
check_ir(job, mod)

# Functions tagged as known math routines by the loop below.
custom = []

# Walk every (k, v) pair of `func_origs` and look for methods of the listed math
# functions.
# NOTE(review): presumably `k` is the LLVM function and `v` the Julia method
# instance it was compiled from — confirm against the `GPUCompiler.codegen` result.
for (k, v) in func_origs
for (op, name) in ((Base.sin, "sin"), (Base.cos, "cos"), (Base.tan, "tan"), (Base.exp, "exp"), (Base.log, "log"), (Base.asin, "asin"), (Base.tanh, "tanh"))
# Match only when there is at least one static parameter, every static parameter
# is the same Float32/Float64 type as the first one, and `v.def` is a method of `op`.
if length(v.sparam_vals) >=1 && all(x->(x in [Float32, Float64] && x==v.sparam_vals[1]), v.sparam_vals) && v.def in methods(op)
push!(custom, k)
# Mark the function `noinline` and record the math routine's name in an
# "enzyme_math" string attribute ("f"-suffixed for Float32).
push!(function_attributes(k), EnumAttribute("noinline", 0, context(mod)))
if v.sparam_vals[1] == Float32
push!(function_attributes(k), StringAttribute("enzyme_math", name*"f", context(mod)))
else
push!(function_attributes(k), StringAttribute("enzyme_math", name, context(mod)))
end

# Need to wrap the code when outermost
# i.e. when the tagged function is itself `primalf`, build a new function of the
# same LLVM function type that simply calls it, and use that as `primalf` instead.
if k == primalf

FT = eltype(llvmtype(k)::LLVM.PointerType)::LLVM.FunctionType

# The wrapper is named after `k` with a "wrap" suffix.
wrapper_f = LLVM.Function(mod, LLVM.name(k)*"wrap", FT)

ctx = context(mod)
let builder = Builder(ctx)
entry = BasicBlock(wrapper_f, "entry", ctx)
position!(builder, entry)

# Forward all of the wrapper's parameters to the tagged function.
res = call!(builder, k, collect(parameters(wrapper_f)))

# Return the call's result unless the function type returns void.
if return_type(FT) == LLVM.VoidType(ctx)
ret!(builder)
else
ret!(builder, res)
end

dispose(builder)
end
primalf = wrapper_f
end
end
end
end

if primal_job.target isa GPUCompiler.NativeCompilerTarget
target_machine = tm[]
else
Expand All @@ -426,7 +468,6 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
end
end


primalf = lower_convention(job, mod, primalf)
flush(stderr)
flush(stdout)
Expand All @@ -449,6 +490,20 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
augmented_primalf = nothing
end

# Strip the `noinline` enum attribute from every function recorded in `custom`.
# `kind` of an `EnumAttribute` is the numeric attribute-kind id rather than the
# attribute's name, so comparing it with the string "noinline" can never match;
# compare against the kind id of a freshly built `noinline` attribute instead.
noinline_kind = kind(EnumAttribute("noinline", 0, context(mod)))
for f in custom
    iter = function_attributes(f)
    # Snapshot the raw attribute refs first so deleting while scanning is safe.
    elems = Vector{LLVM.API.LLVMAttributeRef}(undef, length(iter))
    LLVM.API.LLVMGetAttributesAtIndex(iter.f, iter.idx, elems)
    for eattr in elems
        at = Attribute(eattr)
        if isa(at, LLVM.EnumAttribute) && kind(at) == noinline_kind
            delete!(iter, at)
        end
    end
end

linkage!(adjointf, LLVM.API.LLVMExternalLinkage)
adjointf_name = name(adjointf)

Expand Down
11 changes: 7 additions & 4 deletions src/compiler/validation.jl
Expand Up @@ -201,6 +201,9 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
data = open(flib, "r") do io
lib = readmeta(io)
sections = Sections(lib)
if !(".llvmbc" in sections)
return nothing
end
llvmbc = read(findfirst(sections, ".llvmbc"))
return llvmbc
end
Expand Down Expand Up @@ -247,7 +250,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
end
end
end

b = Builder(ctx)

position!(b, inst)
Expand All @@ -271,7 +274,7 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
end
end
end

b = Builder(ctx)
position!(b, inst)
replace_uses!(inst, LLVM.inttoptr!(b, replaceWith, llvmtype(inst)))
Expand Down Expand Up @@ -327,8 +330,8 @@ function check_ir!(job, errors, imported, inst::LLVM.CallInst)
if ptr == cglobal(:malloc)
fn = "malloc"
end
if length(fn) > 1 && fromC

if length(fn) > 1 && fromC
mod = LLVM.parent(LLVM.parent(LLVM.parent(inst)))
lfn = LLVM.API.LLVMGetNamedFunction(mod, fn)
if lfn == C_NULL
Expand Down
18 changes: 11 additions & 7 deletions src/typetree.jl
Expand Up @@ -11,7 +11,7 @@ LLVM.dispose(tt::TypeTree) = API.EnzymeFreeTypeTree(tt)

# Wrap a new type tree handle obtained from `API.EnzymeNewTypeTree`.
function TypeTree()
    return TypeTree(API.EnzymeNewTypeTree())
end

# Wrap a new type tree handle obtained from `API.EnzymeNewTypeTreeCT(CT, ctx)`.
function TypeTree(CT, ctx)
    return TypeTree(API.EnzymeNewTypeTreeCT(CT, ctx))
end
function TypeTree(CT, idx, ctx)
function TypeTree(CT, idx, ctx)
tt = TypeTree(CT, ctx)
only!(tt, idx)
return tt
Expand Down Expand Up @@ -73,6 +73,10 @@ function typetree(::Type{Float64}, ctx, dl)
return TypeTree(API.DT_Double, -1, ctx)
end

# Types that are themselves `DataType`s get an empty `TypeTree()`; `ctx` and `dl`
# are accepted for signature compatibility but not used.
typetree(::Type{<:DataType}, ctx, dl) = TypeTree()

function typetree(::Type{<:Union{Ptr{T}, Core.LLVMPtr{T}}}, ctx, dl) where T
tt = typetree(T, ctx, dl)
merge!(tt, TypeTree(API.DT_Pointer, ctx))
Expand Down Expand Up @@ -123,9 +127,9 @@ function typetree(@nospecialize(T), ctx, dl)
if subT.isinlinealloc
shift!(subtree, dl, 0, sizeof(subT), offset)
else
merge!(subtree, TypeTree(API.DT_Pointer, ctx))
merge!(subtree, TypeTree(API.DT_Pointer, ctx))
only!(subtree, offset)
end
end

merge!(tt, subtree)
end
Expand All @@ -139,14 +143,14 @@ struct FnTypeInfo
end
# `cconvert` to `API.CFnTypeInfo` hands back the `FnTypeInfo` itself unchanged.
function Base.cconvert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo)
    return fnti
end
function Base.unsafe_convert(::Type{API.CFnTypeInfo}, fnti::FnTypeInfo)
args_kv = Base.unsafe_convert(Ptr{API.IntList}, Base.cconvert(Ptr{API.IntList}, fnti.known_values))
rTT = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, fnti.rTT))
args_kv = Base.unsafe_convert(Ptr{API.IntList}, Base.cconvert(Ptr{API.IntList}, fnti.known_values))
rTT = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, fnti.rTT))

tts = API.CTypeTreeRef[]
for tt in fnti.argTTs
raw_tt = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, tt))
raw_tt = Base.unsafe_convert(API.CTypeTreeRef, Base.cconvert(API.CTypeTreeRef, tt))
push!(tts, raw_tt)
end
argTTs = Base.unsafe_convert(Ptr{API.CTypeTreeRef}, Base.cconvert(Ptr{API.CTypeTreeRef}, tts))
argTTs = Base.unsafe_convert(Ptr{API.CTypeTreeRef}, Base.cconvert(Ptr{API.CTypeTreeRef}, tts))
return API.CFnTypeInfo(argTTs, rTT, args_kv)
end

0 comments on commit 45fc001

Please sign in to comment.