Skip to content

Commit

Permalink
Merge pull request #42 from wsmoses/vc/bump
Browse files Browse the repository at this point in the history
Bump versions
  • Loading branch information
vchuravy committed Feb 2, 2021
2 parents d89a4e8 + 476948f commit fbb4c20
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 16 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Enzyme"
uuid = "7da242da-08ed-463a-9acd-ee780be4f1d9"
authors = ["William Moses <wmoses@mit.edu>", "Valentin Churavy <vchuravy@mit.edu>"]
version = "0.3.0"
version = "0.3.1"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand All @@ -14,7 +14,7 @@ Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
[compat]
CEnum = "0.4"
Cassette = "0.3"
Enzyme_jll = "0.0.5"
Enzyme_jll = "0.0.6"
GPUCompiler = "0.8, 0.9, 0.10"
LLVM = "3.2"
julia = "1.5"
14 changes: 13 additions & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,23 @@ Create the `FunctionSpec` pair, and lookup the primal return type.

# can't return array since that's complicated.
rt = Core.Compiler.return_type(Cassette.overdub, overdub_tt)
@assert rt<:Union{AbstractFloat, Nothing}
if !(rt<:Union{AbstractFloat, Nothing})
@error "Return type should be <:Union{Nothing, AbstractFloat}" rt adjoint primal
error("Internal Enzyme Error")
end
return primal, adjoint, rt
end


function annotate!(mod)
inactive = LLVM.StringAttribute("enzyme_inactive", "", context(mod))
for inactivefn in ["jl_gc_queue_root"]
fn = functions(mod)[inactivefn]
push!(function_attributes(fn), inactive)
end
end


function enzyme!(mod, primalf, adjoint, rt, split)
ctx = context(mod)
rettype = convert(LLVMType, rt, ctx)
Expand Down
22 changes: 16 additions & 6 deletions src/compiler/reflection.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function reflect(@nospecialize(func), @nospecialize(types);
optimize::Bool=true, run_enzyme::Bool=true, second_stage::Bool=true,)
optimize::Bool=true, run_enzyme::Bool=true, second_stage::Bool=true, split::Bool=false)
primal, adjoint, rt = fspec(func, types)

target = Compiler.EnzymeTarget()
Expand All @@ -9,14 +9,24 @@ function reflect(@nospecialize(func), @nospecialize(types);
# Codegen the primal function and all its dependency in one module
mod, primalf = Compiler.codegen(:llvm, job, optimize=false, #= validate=false =#)

# Generate the wrapper, named `enzyme_entry`
llvmf = wrapper!(mod, primalf, adjoint, rt)

LLVM.strip_debuginfo!(mod)
# Run pipeline and Enzyme pass
if optimize
optimize!(mod, llvmf, run_enzyme=run_enzyme)
optimize!(mod)
end

if run_enzyme
annotate!(mod)
adjointf, augmented_primalf = enzyme!(mod, primalf, adjoint, rt, split)

if second_stage
post_optimze!(mod)
end
llvmf = adjointf
else
llvmf = primalf
end


return llvmf, mod
end

Expand Down
9 changes: 2 additions & 7 deletions src/compiler/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,16 +106,11 @@ function _thunk(@nospecialize(primal::FunctionSpec); adjoint, rt, split)
# Codegen the primal function and all its dependency in one module
mod, primalf = Compiler.codegen(:llvm, job, optimize=false, #= validate=false =#)

# LLVM.strip_debuginfo!(mod)
# Run Julia pipeline
optimize!(mod)

# Annotate
inactive = LLVM.StringAttribute("enzyme_inactive", "", context(mod))
for inactivefn in ["jl_gc_queue_root"]
fn = functions(mod)[inactivefn]
push!(function_attributes(fn), inactive)
end
# annotate
annotate!(mod)

# Generate the adjoint
adjointf, augmented_primalf = enzyme!(mod, primalf, adjoint, rt, split)
Expand Down

2 comments on commit fbb4c20

@vchuravy
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/29234

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.3.1 -m "<description of version>" fbb4c20b6631a12ff93c808c83db39445b048bbb
git push origin v0.3.1

Please sign in to comment.