-
Notifications
You must be signed in to change notification settings - Fork 33
Pipeline for nested enzyme differentiation #452
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| ) | ||
|
|
||
| # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate | ||
| # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass]. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ftynse so we now have an option in enzyme which runs a pass pipeline on all funcops created by enzyme (e.g. the new forward and reverse functionops). this is crucial for higher order AD to be fast/work. However, transform dialect (for enzyme-hlo-opt) can't generate the corresponding IR on a funcop
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are you trying to do specifically? This sounds like something that can be supported with a small fix, I just need to understand which one. Specifically, it is possible to put a module containing transform ops inside a func.func, does the pass refuse to do that? Does the interpreter fail?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh yeah adding an extra module would work:
essentially we have https://github.com/EnzymeAD/Enzyme/blob/7bc73fa8291cfed08456ae93e6f460060a6c8344/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp#L206 and similar code for reverse.
which means we have a funcpassmanager. And yes it refuses to put the transform ops in the funcop itself (due to a symbol error iirc). But we could just take nf move it to a temporary inner module, run everything, then put it back (presuming no pass deletes)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
|
Does this need anything other than the latest jll (0.0.35)? I am getting julia> @code_hlo optimize=true ∂∂xloss_function(model, ps, st, x, δ, y)
ERROR: failed to add pipeline:<Pass-Options-Parser>: no such option postpasses
failed to add `enzyme` with options `postpasses="arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize"`
Stacktrace:
[1] add_pipeline!(op_pass::Reactant.MLIR.IR.OpPassManager, pipeline::String)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Pass.jl:195
[2] run_pass_pipeline!(mod::Reactant.MLIR.IR.Module, pass_pipeline::String; enable_verifier::Bool)
@ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:294
[3] run_pass_pipeline!
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:290 [inlined]
[4] compile_mlir!(mod::Reactant.MLIR.IR.Module, f::Function, args::Tuple{…}; optimize::Bool)
@ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:349
[5] compile_mlir!
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:325 [inlined]
[6] #6
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:316 [inlined]
[7] context!(f::Reactant.Compiler.var"#6#7"{@Kwargs{…}, typeof(∂∂xloss_function), Tuple{…}}, ctx::Reactant.MLIR.IR.Context)
@ Reactant.MLIR.IR /mnt/software/lux/Reactant.jl/src/mlir/IR/Context.jl:76
[8] compile_mlir(f::Function, args::Tuple{…}; kwargs::@Kwargs{…})
@ Reactant.Compiler /mnt/software/lux/Reactant.jl/src/Compiler.jl:314
[9] top-level scope
@ /mnt/software/lux/Reactant.jl/src/Compiler.jl:557
Some type information was truncated. Use `show(err)` to see complete types. |
|
nope, but the last jll didn't actually update the enzyme-jax commit (and JuliaPackaging/Yggdrasil#10192 will, and also should include your latest optimization too) Locally building the jll this successfully does nested AD |
|
@avik-pal with Reactant.jl/deps/ReactantExtra/WORKSPACE Line 12 in 327d252
[basically @mofeing 's refactor broke the jll build, so I reverted it to make it build, but didnt update the right place in the reversion] |
|
so in ~2.5 hours we will have end to end nested AD [hopefully] |
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
depends on JuliaPackaging/Yggdrasil#10190