From 2fb0255ab94d81f4cb5f361941bdb1d3bc22eb3d Mon Sep 17 00:00:00 2001 From: Erick Ochoa Lopez Date: Thu, 2 May 2024 16:44:39 -0400 Subject: [PATCH] Proof of concept [skip ci] --- frontend/catalyst/compiler.py | 2 +- mlir/include/Catalyst/Transforms/Passes.td | 2 +- ...AnnotateDebugCallbackAsEnzymeConstPass.cpp | 61 +++++++++---------- 3 files changed, 30 insertions(+), 35 deletions(-) diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 0ca97dbc8b..fbc5c4018d 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -139,7 +139,6 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt QUANTUM_COMPILATION_PASS = ( "QuantumCompilationPass", [ - "annotate-function", "lower-mitigation", "lower-gradients", "adjoint-lowering", @@ -210,6 +209,7 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt "reconcile-unrealized-casts", "gep-inbounds", "add-exception-handling", + "annotate-debug-callback-as-enzyme-const", ], ) diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index 39b3fc45bb..e8a437908f 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -130,7 +130,7 @@ def GEPInboundsPass : Pass<"gep-inbounds"> { let constructor = "catalyst::GEPInboundsPass()"; } -def AnnotateDebugCallbackAsEnzymeConstPass : Pass<"annotate-debug-callback-as-enzyme-const"> { +def AnnotateDebugCallbackAsEnzymeConstPass : Pass<"annotate-debug-callback-as-enzyme-const", "ModuleOp"> { let summary = "Annotates debug callbacks as enzyme_const"; let description = [{ diff --git a/mlir/lib/Catalyst/Transforms/AnnotateDebugCallbackAsEnzymeConstPass.cpp b/mlir/lib/Catalyst/Transforms/AnnotateDebugCallbackAsEnzymeConstPass.cpp index 67e6cf3b0a..eb731185cd 100644 --- a/mlir/lib/Catalyst/Transforms/AnnotateDebugCallbackAsEnzymeConstPass.cpp +++ b/mlir/lib/Catalyst/Transforms/AnnotateDebugCallbackAsEnzymeConstPass.cpp @@ -21,34 +21,6 @@ using namespace mlir; -namespace { -static constexpr llvm::StringRef debugCallback = "catalyst.debugCallback"; -struct AnnotateDebugCallbackPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult match(LLVM::CallOp op) const override; - void rewrite(LLVM::CallOp op, PatternRewriter &rewriter) const override; -}; - -LogicalResult AnnotateDebugCallbackPattern::match(LLVM::CallOp callOp) const { - bool isDebugCallback = callOp->hasAttr(debugCallback); - return isDebugCallback ? success() : failure(); -} - - -void AnnotateDebugCallbackPattern::rewrite(LLVM::CallOp callOp, PatternRewriter &rewriter) const -{ - auto llvmPtrType = LLVM::LLVMPointerType::get(rewriter.getContext()); - Value enzymeConst = rewriter.create(callOp->getLoc(), llvmPtrType, catalyst::gradient::enzyme_const_key); - // We need to place the an enzyme_autodiff call here. - // And the parameter needs to be: - // * callOp's FlatSymbolRefAttr - // * zip(enzyme_const, args) - rewriter.updateRootInPlace(callOp, [&] { callOp->removeAttr(debugCallback); }); -} - -} - namespace catalyst { #define GEN_PASS_DEF_ANNOTATEDEBUGCALLBACKASENZYMECONSTPASS @@ -61,12 +33,35 @@ struct AnnotateDebugCallbackAsEnzymeConstPass : impl::AnnotateDebugCallbackAsEnz void runOnOperation() final { + auto mod = getOperation(); + StringRef pyregistryFnName = "pyregistry"; + auto fnDecl = mod.lookupSymbol(pyregistryFnName); + if (!fnDecl) return; + MLIRContext *context = &getContext(); - RewritePatternSet patterns(context); - patterns.add(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) { - signalPassFailure(); - } + auto builder = OpBuilder(context); + builder.setInsertionPointToStart(mod.getBody()); + auto ptrTy = LLVM::LLVMPointerType::get(context); + auto arrTy = LLVM::LLVMArrayType::get(ptrTy, 1); + auto loc = mod.getLoc(); + auto isConstant = false; + auto linkage = LLVM::Linkage::External; + auto key = catalyst::gradient::enzyme_inactivefn_key; + auto glb = builder.create(loc, arrTy, isConstant, linkage, key, nullptr); + + // Create a block and push it to the global + auto *contextGlb = glb.getContext(); + Block *block = new Block(); + glb.getInitializerRegion().push_back(block); + builder.setInsertionPointToStart(block); + + auto undef = builder.create(glb.getLoc(), arrTy); + auto fnSym = SymbolRefAttr::get(context, pyregistryFnName); + auto fnPtr = builder.create(glb.getLoc(), ptrTy, fnSym); + auto filledInArray = builder.create(glb.getLoc(), undef, fnPtr, 0); + builder.create(glb.getLoc(), filledInArray); + + } };