Skip to content

Commit

Permalink
Eliminate cassette
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored and vchuravy committed Jun 11, 2021
1 parent b008651 commit 607193c
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 51 deletions.
49 changes: 1 addition & 48 deletions src/Enzyme.jl
Expand Up @@ -57,13 +57,7 @@ prepare_cc(arg::DuplicatedNoNeed, args...) = (arg.val, arg.dval, prepare_cc(args
prepare_cc(arg::Annotation, args...) = (arg.val, prepare_cc(args...)...)

@inline function autodiff(f::F, args...) where F
args′ = annotate(args...)
tt′ = Tuple{map(Core.Typeof, args′)...}
ptr = Compiler.deferred_codegen(Val(f), Val(tt′), Val(true))
tt = Tuple{map(T->eltype(Core.Typeof(T)), args′)...}
rt = Core.Compiler.return_type(f, tt)
thunk = Compiler.CombinedAdjointThunk{F, rt, tt′}(f, ptr)
thunk(args′...)
autodiff_no_cassette(f, args...)
end

@inline function autodiff_no_cassette(f::F, args...) where F
Expand All @@ -77,47 +71,6 @@ end
end

import .Compiler: EnzymeCtx
# Ops that have intrinsics
for op in (sin, cos, tan, exp, log)
for (T, suffix) in ((Float32, "f32"), (Float64, "f64"))
llvmf = "llvm.$(nameof(op)).$suffix"
@eval begin
@inline function Cassette.overdub(::EnzymeCtx, ::typeof($op), x::$T)
ccall($llvmf, llvmcall, $T, ($T,), x)
end
end
end
end

for op in (copysign,)
for (T, suffix) in ((Float32, "f32"), (Float64, "f64"))
llvmf = "llvm.$(nameof(op)).$suffix"
@eval begin
@inline function Cassette.overdub(::EnzymeCtx, ::typeof($op), x::$T, y::$T)
ccall($llvmf, llvmcall, $T, ($T, $T), x, y)
end
end
end
end

for op in (asin,tanh)
for (T, llvm_t, suffix) in ((Float32, "float", "f"), (Float64, "double", ""))
mod = """
declare $llvm_t @$(nameof(op))$suffix($llvm_t)
define $llvm_t @entry($llvm_t) #0 {
%val = call $llvm_t @$op$suffix($llvm_t %0)
ret $llvm_t %val
}
attributes #0 = { alwaysinline }
"""
@eval begin
@inline function Cassette.overdub(::EnzymeCtx, ::typeof($op), x::$T)
Base.llvmcall(($mod, "entry"), $T, Tuple{$T}, x)
end
end
end
end

@inline function pack(args...)
ntuple(Val(length(args))) do i
Expand Down
63 changes: 60 additions & 3 deletions src/compiler.jl
Expand Up @@ -217,7 +217,7 @@ function enzyme!(job, mod, primalf, adjoint, split, parallel)
elseif GPUCompiler.isghosttype(rt) || Core.Compiler.isconstType(rt)
retType = API.DFT_CONSTANT
else
error("What even is $rt")
error("Unhandled return type $rt")
end

TA = TypeAnalysis(triple(mod))
Expand Down Expand Up @@ -406,8 +406,52 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
else
primal_job = similar(parent_job, job.source)
end
mod, primalf = GPUCompiler.codegen(:llvm, primal_job, optimize=false, validate=false, parent_job=parent_job)
mod, meta = GPUCompiler.codegen(:llvm, primal_job, optimize=false, validate=false, parent_job=parent_job)
primalf = meta.entry
check_ir(job, mod)

custom = []

for (v, k) in meta.compiled
for (op, name) in ((Base.sin, "sin"), (Base.cos, "cos"), (Base.tan, "tan"), (Base.exp, "exp"), (Base.log, "log"), (Base.asin, "asin"), (Base.tanh, "tanh"))
if length(v.sparam_vals) >=1 && all(x->(x in [Float32, Float64] && x==v.sparam_vals[1]), v.sparam_vals) && v.def in methods(op)
llvmfn = functions(mod)[k.specfunc]
push!(custom, llvmfn)
push!(function_attributes(llvmfn), EnumAttribute("noinline", 0, context(mod)))
if v.sparam_vals[1] == Float32
push!(function_attributes(llvmfn), StringAttribute("enzyme_math", name*"f", context(mod)))
else
push!(function_attributes(llvmfn), StringAttribute("enzyme_math", name, context(mod)))
end

# Need to wrap the code when outermost
if llvmfn == primalf

FT = eltype(llvmtype(llvmfn)::LLVM.PointerType)::LLVM.FunctionType

wrapper_f = LLVM.Function(mod, LLVM.name(llvmfn)*"wrap", FT)

ctx = context(mod)
let builder = Builder(ctx)
entry = BasicBlock(wrapper_f, "entry", ctx)
position!(builder, entry)

res = call!(builder, llvmfn, collect(parameters(wrapper_f)))

if return_type(FT) == LLVM.VoidType(ctx)
ret!(builder)
else
ret!(builder, res)
end

dispose(builder)
end
primalf = wrapper_f
end
end
end
end

if primal_job.target isa GPUCompiler.NativeCompilerTarget
target_machine = tm[]
else
Expand All @@ -426,7 +470,6 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
end
end


primalf = lower_convention(job, mod, primalf)
flush(stderr)
flush(stdout)
Expand All @@ -449,6 +492,20 @@ function GPUCompiler.codegen(output::Symbol, job::CompilerJob{<:EnzymeTarget};
augmented_primalf = nothing
end

for f in custom
iter = function_attributes(f)
elems = Vector{LLVM.API.LLVMAttributeRef}(undef, length(iter))
LLVM.API.LLVMGetAttributesAtIndex(iter.f, iter.idx, elems)
for eattr in elems
at = Attribute(eattr)
if isa(at, LLVM.EnumAttribute)
if kind(at) == "noinline"
delete!(iter, at)
end
end
end
end

linkage!(adjointf, LLVM.API.LLVMExternalLinkage)
adjointf_name = name(adjointf)

Expand Down

0 comments on commit 607193c

Please sign in to comment.