diff --git a/doc/changelog.md b/doc/changelog.md
index af5c2dcde7..25cab2909b 100644
--- a/doc/changelog.md
+++ b/doc/changelog.md
@@ -83,6 +83,10 @@
Improvements
+* `debug.callbacks` are marked as inactive. This means `debug.callbacks` will not be considered
+ as active for the computation of gradients.
+ [(#706)](https://github.com/PennyLaneAI/catalyst/pull/706)
+
* Added support for IsingZZ gate in Catalyst frontend. Previously, the IsingZZ gate would be
decomposed into a CNOT and RZ gates. However, this is not needed as the PennyLane-Lightning
simulator supports this gate.
diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py
index 89792c0af7..d091729a3f 100644
--- a/frontend/catalyst/compiler.py
+++ b/frontend/catalyst/compiler.py
@@ -213,6 +213,7 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
"reconcile-unrealized-casts",
"gep-inbounds",
"add-exception-handling",
+ "register-inactive-callback",
],
)
diff --git a/mlir/include/Catalyst/Transforms/Passes.h b/mlir/include/Catalyst/Transforms/Passes.h
index 8b4bcd46da..3e51383a9a 100644
--- a/mlir/include/Catalyst/Transforms/Passes.h
+++ b/mlir/include/Catalyst/Transforms/Passes.h
@@ -28,6 +28,7 @@ std::unique_ptr createHloCustomCallLoweringPass();
std::unique_ptr createQnodeToAsyncLoweringPass();
std::unique_ptr createAddExceptionHandlingPass();
std::unique_ptr createGEPInboundsPass();
+std::unique_ptr createRegisterInactiveCallbackPass();
void registerAllCatalystPasses();
diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td
index 68865fcce6..7a81625fbb 100644
--- a/mlir/include/Catalyst/Transforms/Passes.td
+++ b/mlir/include/Catalyst/Transforms/Passes.td
@@ -130,4 +130,15 @@ def GEPInboundsPass : Pass<"gep-inbounds"> {
let constructor = "catalyst::GEPInboundsPass()";
}
+def RegisterInactiveCallbackPass : Pass<"register-inactive-callback", "ModuleOp"> {
+ let summary = "Register `inactive_callback` as inactive with Enzyme";
+
+ let dependentDialects = [
+ "mlir::LLVM::LLVMDialect"
+ ];
+
+ let constructor = "catalyst::createRegisterInactiveCallbackPass()";
+
+}
+
#endif // CATALYST_PASSES
diff --git a/mlir/include/Gradient/Transforms/EnzymeConstants.h b/mlir/include/Gradient/Transforms/EnzymeConstants.h
new file mode 100644
index 0000000000..2ee45cb41b
--- /dev/null
+++ b/mlir/include/Gradient/Transforms/EnzymeConstants.h
@@ -0,0 +1,28 @@
+// Copyright 2024 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+namespace catalyst {
+namespace gradient {
+
+static constexpr const char *enzyme_autodiff_func_name = "__enzyme_autodiff";
+static constexpr const char *enzyme_allocation_key = "__enzyme_allocation_like";
+static constexpr const char *enzyme_custom_gradient_key = "__enzyme_register_gradient_";
+static constexpr const char *enzyme_like_free_key = "__enzyme_function_like_free";
+static constexpr const char *enzyme_const_key = "enzyme_const";
+static constexpr const char *enzyme_dupnoneed_key = "enzyme_dupnoneed";
+static constexpr const char *enzyme_inactivefn_key = "__enzyme_inactivefn";
+
+}
+}
diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt
index f27c719a60..9481f23b10 100644
--- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt
+++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt
@@ -16,6 +16,7 @@ file(GLOB SRC
AsyncUtils.cpp
GEPInboundsPatterns.cpp
GEPInboundsPass.cpp
+ RegisterInactiveCallbackPass.cpp
)
get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
index 452f785e56..ac0fc40132 100644
--- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
+++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
@@ -40,4 +40,5 @@ void catalyst::registerAllCatalystPasses()
mlir::registerPass(catalyst::createGEPInboundsPass);
mlir::registerPass(catalyst::createRemoveChainedSelfInversePass);
mlir::registerPass(catalyst::createAnnotateFunctionPass);
+ mlir::registerPass(catalyst::createRegisterInactiveCallbackPass);
}
diff --git a/mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp b/mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp
new file mode 100644
index 0000000000..80267b85a2
--- /dev/null
+++ b/mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp
@@ -0,0 +1,67 @@
+// Copyright 2024 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "Catalyst/Transforms/Passes.h"
+#include "Catalyst/Transforms/Patterns.h"
+#include "Gradient/Transforms/EnzymeConstants.h"
+
+using namespace mlir;
+
+namespace catalyst {
+
+#define GEN_PASS_DEF_REGISTERINACTIVECALLBACKPASS
+#define GEN_PASS_DECL_REGISTERINACTIVECALLBACKPASS
+#include "Catalyst/Transforms/Passes.h.inc"
+
+struct RegisterInactiveCallbackPass
+ : impl::RegisterInactiveCallbackPassBase {
+ using RegisterInactiveCallbackPassBase::RegisterInactiveCallbackPassBase;
+ void runOnOperation() final
+ {
+ auto mod = getOperation();
+ StringRef inactive_callbackFnName = "inactive_callback";
+ auto fnDecl = mod.lookupSymbol(inactive_callbackFnName);
+ if (!fnDecl) {
+ return;
+ }
+ MLIRContext *context = &getContext();
+ auto builder = OpBuilder(context);
+ builder.setInsertionPointToStart(mod.getBody());
+ auto ptrTy = LLVM::LLVMPointerType::get(context);
+ auto arrTy = LLVM::LLVMArrayType::get(ptrTy, 1);
+ auto loc = mod.getLoc();
+ auto isConstant = false;
+ auto linkage = LLVM::Linkage::External;
+ auto key = catalyst::gradient::enzyme_inactivefn_key;
+ auto glb = builder.create(loc, arrTy, isConstant, linkage, key, nullptr);
+ // Create a block and push it to the global
+ Block *block = new Block();
+ glb.getInitializerRegion().push_back(block);
+ builder.setInsertionPointToStart(block);
+ auto undef = builder.create(glb.getLoc(), arrTy);
+ auto fnSym = SymbolRefAttr::get(context, inactive_callbackFnName);
+ auto fnPtr = builder.create(glb.getLoc(), ptrTy, fnSym);
+ auto filledInArray = builder.create(glb.getLoc(), undef, fnPtr, 0);
+ builder.create(glb.getLoc(), filledInArray);
+ }
+};
+
+std::unique_ptr createRegisterInactiveCallbackPass()
+{
+ return std::make_unique();
+}
+} // namespace catalyst
diff --git a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
index 4166c3e04a..25359420c5 100644
--- a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
+++ b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
@@ -444,7 +444,7 @@ struct PythonCallOpPattern : public OpConversionPattern {
bool isVarArg = true;
LLVM::LLVMFuncOp customCallFnOp = mlir::LLVM::lookupOrCreateFn(
- mod, "pyregistry", {/*args=*/i64, i64, i64}, /*ret_type=*/voidType, isVarArg);
+ mod, "inactive_callback", {/*args=*/i64, i64, i64}, /*ret_type=*/voidType, isVarArg);
customCallFnOp.setPrivate();
rewriter.restoreInsertionPoint(point);
diff --git a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp
index a5839be630..519bb2368e 100644
--- a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp
+++ b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp
@@ -34,6 +34,7 @@
#include "Catalyst/Utils/CallGraph.h"
#include "Gradient/IR/GradientOps.h"
+#include "Gradient/Transforms/EnzymeConstants.h"
#include "Gradient/Transforms/Patterns.h"
#include "Gradient/Utils/DestinationPassingStyle.h"
#include "Gradient/Utils/EinsumLinalgGeneric.h"
@@ -149,13 +150,6 @@ struct EnzymeMemRefInterfaceOptions {
bool dupNoNeed = false;
};
-static constexpr const char *enzyme_autodiff_func_name = "__enzyme_autodiff";
-static constexpr const char *enzyme_allocation_key = "__enzyme_allocation_like";
-static constexpr const char *enzyme_custom_gradient_key = "__enzyme_register_gradient_";
-static constexpr const char *enzyme_like_free_key = "__enzyme_function_like_free";
-static constexpr const char *enzyme_const_key = "enzyme_const";
-static constexpr const char *enzyme_dupnoneed_key = "enzyme_dupnoneed";
-
/// Enzyme custom gradients appear to exhibit better stability when they are registered for
/// functions where MemRefs are passed via wrapped pointers (!llvm.ptr)
/// rather than having their fields unpacked. This function automatically transforms MemRef
diff --git a/mlir/test/Catalyst/ConversionTest.mlir b/mlir/test/Catalyst/ConversionTest.mlir
index 52b6b52d8d..141ebae5e6 100644
--- a/mlir/test/Catalyst/ConversionTest.mlir
+++ b/mlir/test/Catalyst/ConversionTest.mlir
@@ -117,7 +117,7 @@ func.func @python_call () {
// CHECK: [[identifier:%.+]] = llvm.mlir.constant(0 : i64)
// CHECK: [[argcount:%.+]] = llvm.mlir.constant(0 : i64)
// CHECK: [[rescount:%.+]] = llvm.mlir.constant(0 : i64)
- // CHECK: llvm.call @pyregistry([[identifier]], [[argcount]], [[rescount]])
+ // CHECK: llvm.call @inactive_callback([[identifier]], [[argcount]], [[rescount]])
catalyst.pycallback() { identifier = 0} : () -> ()
return
}
diff --git a/mlir/test/Catalyst/RegisterInactive.mlir b/mlir/test/Catalyst/RegisterInactive.mlir
new file mode 100644
index 0000000000..ecfb0ddedd
--- /dev/null
+++ b/mlir/test/Catalyst/RegisterInactive.mlir
@@ -0,0 +1,56 @@
+// Copyright 2024 Xanadu Quantum Technologies Inc.
+
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// RUN: quantum-opt %s --register-inactive-callback --split-input-file --verify-diagnostics | FileCheck %s
+
+// This test just makes sure that we can
+// run the compiler with the option
+//
+// --register-inactive-callback
+//
+// and that if there are no callbacks present
+// it doesn't change anything
+
+// CHECK-LABEL: @test0
+module @test0 {
+ // CHECK-NOT: llvm.mlir.global external @__enzyme_inactivefn
+ // CHECK-LABEL: @foo
+ func.func @foo() {
+ return
+ }
+ // CHECK-NOT: llvm.mlir.global external @__enzyme_inactivefn
+}
+
+// -----
+
+// This test checks the invariant that after the transformation
+// the attribute has been removed.
+
+// CHECK-LABEL: @test1
+module @test1 {
+
+ // CHECK: llvm.mlir.global external @__enzyme_inactivefn
+ // CHECK: [[undef:%.+]] = llvm.mlir.undef
+ // CHECK: [[ptr:%.+]] = llvm.mlir.addressof @inactive_callback
+ // CHECK: [[retval:%.+]] = llvm.insertvalue [[ptr]], [[undef]][0]
+ // CHECK: llvm.return [[retval]]
+
+ llvm.func @inactive_callback(i64, i64, i64, ...)
+ llvm.func @wrapper() {
+ %0 = llvm.mlir.constant(139935726668624 : i64) : i64
+ %1 = llvm.mlir.constant(0 : i64) : i64
+ llvm.call @inactive_callback(%0, %1, %1) vararg(!llvm.func) : (i64, i64, i64) -> ()
+ llvm.return
+ }
+}
diff --git a/runtime/lib/capi/ExecutionContext.hpp b/runtime/lib/capi/ExecutionContext.hpp
index 281e56bf48..8796b8d30a 100644
--- a/runtime/lib/capi/ExecutionContext.hpp
+++ b/runtime/lib/capi/ExecutionContext.hpp
@@ -33,7 +33,7 @@ extern void callbackCall(int64_t, int64_t, int64_t, va_list);
namespace Catalyst::Runtime {
-extern "C" void pyregistry(int64_t identifier, int64_t argc, int64_t retc, ...);
+extern "C" void inactive_callback(int64_t identifier, int64_t argc, int64_t retc, ...);
class MemoryManager final {
private:
diff --git a/runtime/lib/capi/RuntimeCAPI.cpp b/runtime/lib/capi/RuntimeCAPI.cpp
index 928b75d4c7..86d24e5670 100644
--- a/runtime/lib/capi/RuntimeCAPI.cpp
+++ b/runtime/lib/capi/RuntimeCAPI.cpp
@@ -108,7 +108,7 @@ extern "C" {
using namespace Catalyst::Runtime;
using timer = catalyst::utils::Timer;
-void pyregistry(int64_t identifier, int64_t argc, int64_t retc, ...)
+void inactive_callback(int64_t identifier, int64_t argc, int64_t retc, ...)
{
// We need to guard calls to callback.
// These are implemented in Python.