Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ PythonCall = "0.9"
Random = "1.10"
Random123 = "1.7"
ReactantCore = "0.1.3"
Reactant_jll = "0.0.34"
Reactant_jll = "0.0.36"
Scratch = "1.2"
SpecialFunctions = "2"
Statistics = "1.10"
Expand Down
41 changes: 31 additions & 10 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,9 @@ function create_result(
return Meta.quot(tocopy)
end

const opt_passes::String = join(
# Optimization passes via transform dialect
const transform_passes::String = join(
[
"inline{default-pipeline=canonicalize max-iterations=4}",
"canonicalize,cse",
"canonicalize",
"enzyme-hlo-generate-td{" *
join(
[
Expand Down Expand Up @@ -273,9 +271,22 @@ const opt_passes::String = join(
"transform-interpreter",
"enzyme-hlo-remove-transform",
],
',',
",",
)

# Optimization passes which apply to an individual function
const func_passes::String = join(
["canonicalize,cse", "canonicalize", transform_passes], ","
)

const opt_passes::String = join(
["inline{default-pipeline=canonicalize max-iterations=4}", func_passes], ','
)

# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ftynse so we now have an option in enzyme which runs a pass pipeline on all funcops created by enzyme (e.g. the new forward and reverse functionops). this is crucial for higher order AD to be fast/work. However, transform dialect (for enzyme-hlo-opt) can't generate the corresponding IR on a funcop

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are you trying to do specifically? This sounds like something that can be supported with a small fix, I just need to understand which one. Specifically, it is possible to put a module containing transform ops inside a func.func, does the pass refuse to do that? Does the interpreter fail?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah adding an extra module would work:

essentially we have https://github.com/EnzymeAD/Enzyme/blob/7bc73fa8291cfed08456ae93e6f460060a6c8344/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp#L206 and similar code for reverse.

which means we have a funcpassmanager. And yes it refuses to put the transform ops in the funcop itself (due to a symbol error iirc). But we could just take nf move it to a temporary inner module, run everything, then put it back (presuming no pass deletes)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"

function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true)
pm = MLIR.IR.PassManager()
MLIR.IR.enable_verifier!(pm, enable_verifier)
Expand Down Expand Up @@ -335,7 +346,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}"
if optimize === :all
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod,
join(
Expand All @@ -351,7 +364,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
)
elseif optimize === :before_kernel
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod,
join(
Expand Down Expand Up @@ -381,7 +396,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
)
elseif optimize === :only_enzyme
run_pass_pipeline!(mod, "enzyme-batch")
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod,
join(
Expand All @@ -391,7 +408,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
)
elseif optimize === :after_enzyme
run_pass_pipeline!(mod, "enzyme-batch")
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod,
join(
Expand All @@ -407,7 +426,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
)
elseif optimize === :before_enzyme
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
run_pass_pipeline!(
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
)
run_pass_pipeline!(
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern
)
Expand Down
2 changes: 1 addition & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ include("Overlay.jl")

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
)::RT where {copy_if_inactive,RT<:RArray}
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
if haskey(seen, prev)
return seen[prev]
end
Expand Down
11 changes: 11 additions & 0 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,14 @@ end
@test stret.st2 ≈ x .+ 1
@test stret.st1 === stret.st2
end

@testset "Nested AD" begin
x = ConcreteRNumber(3.1)
f(x) = x * x * x * x
df(x) = Enzyme.gradient(Reverse, f, x)[1]
res1 = @jit df(x)
@test res1 ≈ 4 * 3.1^3
ddf(x) = Enzyme.gradient(Reverse, df, x)[1]
res2 = @jit ddf(x)
@test res2 ≈ 4 * 3 * 3.1^2
end
Loading