Skip to content

Commit

Permalink
Proof of concept
Browse files Browse the repository at this point in the history
[skip ci]
  • Loading branch information
erick-xanadu committed May 2, 2024
1 parent 8101f0a commit 2fb0255
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 35 deletions.
2 changes: 1 addition & 1 deletion frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
QUANTUM_COMPILATION_PASS = (
"QuantumCompilationPass",
[
"annotate-function",
"lower-mitigation",
"lower-gradients",
"adjoint-lowering",
Expand Down Expand Up @@ -210,6 +209,7 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
"reconcile-unrealized-casts",
"gep-inbounds",
"add-exception-handling",
"annotate-debug-callback-as-enzyme-const",
],
)

Expand Down
2 changes: 1 addition & 1 deletion mlir/include/Catalyst/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def GEPInboundsPass : Pass<"gep-inbounds"> {
let constructor = "catalyst::GEPInboundsPass()";
}

def AnnotateDebugCallbackAsEnzymeConstPass : Pass<"annotate-debug-callback-as-enzyme-const"> {
def AnnotateDebugCallbackAsEnzymeConstPass : Pass<"annotate-debug-callback-as-enzyme-const", "ModuleOp"> {
let summary = "Annotates debug callbacks as enzyme_const";

let description = [{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,34 +21,6 @@

using namespace mlir;

namespace {
static constexpr llvm::StringRef debugCallback = "catalyst.debugCallback";
struct AnnotateDebugCallbackPattern : public OpRewritePattern<LLVM::CallOp> {
using OpRewritePattern<LLVM::CallOp>::OpRewritePattern;

LogicalResult match(LLVM::CallOp op) const override;
void rewrite(LLVM::CallOp op, PatternRewriter &rewriter) const override;
};

LogicalResult AnnotateDebugCallbackPattern::match(LLVM::CallOp callOp) const {
bool isDebugCallback = callOp->hasAttr(debugCallback);
return isDebugCallback ? success() : failure();
}


void AnnotateDebugCallbackPattern::rewrite(LLVM::CallOp callOp, PatternRewriter &rewriter) const
{
auto llvmPtrType = LLVM::LLVMPointerType::get(rewriter.getContext());
Value enzymeConst = rewriter.create<LLVM::AddressOfOp>(callOp->getLoc(), llvmPtrType, catalyst::gradient::enzyme_const_key);
// We need to place the an enzyme_autodiff call here.
// And the parameter needs to be:
// * callOp's FlatSymbolRefAttr
// * zip(enzyme_const, args)
rewriter.updateRootInPlace(callOp, [&] { callOp->removeAttr(debugCallback); });
}

}

namespace catalyst {

#define GEN_PASS_DEF_ANNOTATEDEBUGCALLBACKASENZYMECONSTPASS
Expand All @@ -61,12 +33,35 @@ struct AnnotateDebugCallbackAsEnzymeConstPass : impl::AnnotateDebugCallbackAsEnz
void runOnOperation() final
{

auto mod = getOperation();
StringRef pyregistryFnName = "pyregistry";
auto fnDecl = mod.lookupSymbol<LLVM::LLVMFuncOp>(pyregistryFnName);
if (!fnDecl) return;

MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add<AnnotateDebugCallbackPattern>(context);
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
signalPassFailure();
}
auto builder = OpBuilder(context);
builder.setInsertionPointToStart(mod.getBody());
auto ptrTy = LLVM::LLVMPointerType::get(context);
auto arrTy = LLVM::LLVMArrayType::get(ptrTy, 1);
auto loc = mod.getLoc();
auto isConstant = false;
auto linkage = LLVM::Linkage::External;
auto key = catalyst::gradient::enzyme_inactivefn_key;
auto glb = builder.create<LLVM::GlobalOp>(loc, arrTy, isConstant, linkage, key, nullptr);

// Create a block and push it to the global
auto *contextGlb = glb.getContext();
Block *block = new Block();
glb.getInitializerRegion().push_back(block);
builder.setInsertionPointToStart(block);

auto undef = builder.create<LLVM::UndefOp>(glb.getLoc(), arrTy);
auto fnSym = SymbolRefAttr::get(context, pyregistryFnName);
auto fnPtr = builder.create<LLVM::AddressOfOp>(glb.getLoc(), ptrTy, fnSym);
auto filledInArray = builder.create<LLVM::InsertValueOp>(glb.getLoc(), undef, fnPtr, 0);
builder.create<LLVM::ReturnOp>(glb.getLoc(), filledInArray);

Check notice on line 63 in mlir/lib/Catalyst/Transforms/AnnotateDebugCallbackAsEnzymeConstPass.cpp

View check run for this annotation

codefactor.io / CodeFactor

mlir/lib/Catalyst/Transforms/AnnotateDebugCallbackAsEnzymeConstPass.cpp#L63

Redundant blank line at the start of a code block should be deleted. (whitespace/blank_line)

Check notice on line 64 in mlir/lib/Catalyst/Transforms/AnnotateDebugCallbackAsEnzymeConstPass.cpp

View check run for this annotation

codefactor.io / CodeFactor

mlir/lib/Catalyst/Transforms/AnnotateDebugCallbackAsEnzymeConstPass.cpp#L64

Redundant blank line at the end of a code block should be deleted. (whitespace/blank_line)
}
};

Expand Down

0 comments on commit 2fb0255

Please sign in to comment.