diff --git a/doc/changelog.md b/doc/changelog.md index af5c2dcde7..25cab2909b 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -83,6 +83,10 @@

Improvements

+* `debug.callbacks` are marked as inactive. This means `debug.callbacks` will not be considered + as active for the computation of gradients. + [(#706)](https://github.com/PennyLaneAI/catalyst/pull/706) + * Added support for IsingZZ gate in Catalyst frontend. Previously, the IsingZZ gate would be decomposed into a CNOT and RZ gates. However, this is not needed as the PennyLane-Lightning simulator supports this gate. diff --git a/frontend/catalyst/compiler.py b/frontend/catalyst/compiler.py index 89792c0af7..d091729a3f 100644 --- a/frontend/catalyst/compiler.py +++ b/frontend/catalyst/compiler.py @@ -213,6 +213,7 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt "reconcile-unrealized-casts", "gep-inbounds", "add-exception-handling", + "register-inactive-callback", ], ) diff --git a/mlir/include/Catalyst/Transforms/Passes.h b/mlir/include/Catalyst/Transforms/Passes.h index 8b4bcd46da..3e51383a9a 100644 --- a/mlir/include/Catalyst/Transforms/Passes.h +++ b/mlir/include/Catalyst/Transforms/Passes.h @@ -28,6 +28,7 @@ std::unique_ptr createHloCustomCallLoweringPass(); std::unique_ptr createQnodeToAsyncLoweringPass(); std::unique_ptr createAddExceptionHandlingPass(); std::unique_ptr createGEPInboundsPass(); +std::unique_ptr createRegisterInactiveCallbackPass(); void registerAllCatalystPasses(); diff --git a/mlir/include/Catalyst/Transforms/Passes.td b/mlir/include/Catalyst/Transforms/Passes.td index 68865fcce6..7a81625fbb 100644 --- a/mlir/include/Catalyst/Transforms/Passes.td +++ b/mlir/include/Catalyst/Transforms/Passes.td @@ -130,4 +130,15 @@ def GEPInboundsPass : Pass<"gep-inbounds"> { let constructor = "catalyst::GEPInboundsPass()"; } +def RegisterInactiveCallbackPass : Pass<"register-inactive-callback", "ModuleOp"> { + let summary = "Register `inactive_callback` as inactive with Enzyme"; + + let dependentDialects = [ + "mlir::LLVM::LLVMDialect" + ]; + + let constructor = "catalyst::createRegisterInactiveCallbackPass()"; + +} + #endif // CATALYST_PASSES diff --git a/mlir/include/Gradient/Transforms/EnzymeConstants.h b/mlir/include/Gradient/Transforms/EnzymeConstants.h new file mode 100644 index 0000000000..2ee45cb41b --- /dev/null +++ b/mlir/include/Gradient/Transforms/EnzymeConstants.h @@ -0,0 +1,28 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +namespace catalyst { +namespace gradient { + +static constexpr const char *enzyme_autodiff_func_name = "__enzyme_autodiff"; +static constexpr const char *enzyme_allocation_key = "__enzyme_allocation_like"; +static constexpr const char *enzyme_custom_gradient_key = "__enzyme_register_gradient_"; +static constexpr const char *enzyme_like_free_key = "__enzyme_function_like_free"; +static constexpr const char *enzyme_const_key = "enzyme_const"; +static constexpr const char *enzyme_dupnoneed_key = "enzyme_dupnoneed"; +static constexpr const char *enzyme_inactivefn_key = "__enzyme_inactivefn"; + +} +} diff --git a/mlir/lib/Catalyst/Transforms/CMakeLists.txt b/mlir/lib/Catalyst/Transforms/CMakeLists.txt index f27c719a60..9481f23b10 100644 --- a/mlir/lib/Catalyst/Transforms/CMakeLists.txt +++ b/mlir/lib/Catalyst/Transforms/CMakeLists.txt @@ -16,6 +16,7 @@ file(GLOB SRC AsyncUtils.cpp GEPInboundsPatterns.cpp GEPInboundsPass.cpp + RegisterInactiveCallbackPass.cpp ) get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) diff --git a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp index 452f785e56..ac0fc40132 100644 --- a/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp +++ b/mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp @@ -40,4 +40,5 @@ void catalyst::registerAllCatalystPasses() mlir::registerPass(catalyst::createGEPInboundsPass); mlir::registerPass(catalyst::createRemoveChainedSelfInversePass); mlir::registerPass(catalyst::createAnnotateFunctionPass); + mlir::registerPass(catalyst::createRegisterInactiveCallbackPass); } diff --git a/mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp b/mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp new file mode 100644 index 0000000000..80267b85a2 --- /dev/null +++ b/mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp @@ -0,0 +1,67 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +#include "Catalyst/Transforms/Passes.h" +#include "Catalyst/Transforms/Patterns.h" +#include "Gradient/Transforms/EnzymeConstants.h" + +using namespace mlir; + +namespace catalyst { + +#define GEN_PASS_DEF_REGISTERINACTIVECALLBACKPASS +#define GEN_PASS_DECL_REGISTERINACTIVECALLBACKPASS +#include "Catalyst/Transforms/Passes.h.inc" + +struct RegisterInactiveCallbackPass + : impl::RegisterInactiveCallbackPassBase { + using RegisterInactiveCallbackPassBase::RegisterInactiveCallbackPassBase; + void runOnOperation() final + { + auto mod = getOperation(); + StringRef inactive_callbackFnName = "inactive_callback"; + auto fnDecl = mod.lookupSymbol(inactive_callbackFnName); + if (!fnDecl) { + return; + } + MLIRContext *context = &getContext(); + auto builder = OpBuilder(context); + builder.setInsertionPointToStart(mod.getBody()); + auto ptrTy = LLVM::LLVMPointerType::get(context); + auto arrTy = LLVM::LLVMArrayType::get(ptrTy, 1); + auto loc = mod.getLoc(); + auto isConstant = false; + auto linkage = LLVM::Linkage::External; + auto key = catalyst::gradient::enzyme_inactivefn_key; + auto glb = builder.create(loc, arrTy, isConstant, linkage, key, nullptr); + // Create a block and push it to the global + Block *block = new Block(); + glb.getInitializerRegion().push_back(block); + builder.setInsertionPointToStart(block); + auto undef = builder.create(glb.getLoc(), arrTy); + auto fnSym = SymbolRefAttr::get(context, inactive_callbackFnName); + auto fnPtr = builder.create(glb.getLoc(), ptrTy, fnSym); + auto filledInArray = builder.create(glb.getLoc(), undef, fnPtr, 0); + builder.create(glb.getLoc(), filledInArray); + } +}; + +std::unique_ptr createRegisterInactiveCallbackPass() +{ + return std::make_unique(); +} +} // namespace catalyst diff --git a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp index 4166c3e04a..25359420c5 100644 --- a/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp +++ b/mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp @@ -444,7 +444,7 @@ struct PythonCallOpPattern : public OpConversionPattern { bool isVarArg = true; LLVM::LLVMFuncOp customCallFnOp = mlir::LLVM::lookupOrCreateFn( - mod, "pyregistry", {/*args=*/i64, i64, i64}, /*ret_type=*/voidType, isVarArg); + mod, "inactive_callback", {/*args=*/i64, i64, i64}, /*ret_type=*/voidType, isVarArg); customCallFnOp.setPrivate(); rewriter.restoreInsertionPoint(point); diff --git a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp index a5839be630..519bb2368e 100644 --- a/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp +++ b/mlir/lib/Gradient/Transforms/ConversionPatterns.cpp @@ -34,6 +34,7 @@ #include "Catalyst/Utils/CallGraph.h" #include "Gradient/IR/GradientOps.h" +#include "Gradient/Transforms/EnzymeConstants.h" #include "Gradient/Transforms/Patterns.h" #include "Gradient/Utils/DestinationPassingStyle.h" #include "Gradient/Utils/EinsumLinalgGeneric.h" @@ -149,13 +150,6 @@ struct EnzymeMemRefInterfaceOptions { bool dupNoNeed = false; }; -static constexpr const char *enzyme_autodiff_func_name = "__enzyme_autodiff"; -static constexpr const char *enzyme_allocation_key = "__enzyme_allocation_like"; -static constexpr const char *enzyme_custom_gradient_key = "__enzyme_register_gradient_"; -static constexpr const char *enzyme_like_free_key = "__enzyme_function_like_free"; -static constexpr const char *enzyme_const_key = "enzyme_const"; -static constexpr const char *enzyme_dupnoneed_key = "enzyme_dupnoneed"; - /// Enzyme custom gradients appear to exhibit better stability when they are registered for /// functions where MemRefs are passed via wrapped pointers (!llvm.ptr) /// rather than having their fields unpacked. This function automatically transforms MemRef diff --git a/mlir/test/Catalyst/ConversionTest.mlir b/mlir/test/Catalyst/ConversionTest.mlir index 52b6b52d8d..141ebae5e6 100644 --- a/mlir/test/Catalyst/ConversionTest.mlir +++ b/mlir/test/Catalyst/ConversionTest.mlir @@ -117,7 +117,7 @@ func.func @python_call () { // CHECK: [[identifier:%.+]] = llvm.mlir.constant(0 : i64) // CHECK: [[argcount:%.+]] = llvm.mlir.constant(0 : i64) // CHECK: [[rescount:%.+]] = llvm.mlir.constant(0 : i64) - // CHECK: llvm.call @pyregistry([[identifier]], [[argcount]], [[rescount]]) + // CHECK: llvm.call @inactive_callback([[identifier]], [[argcount]], [[rescount]]) catalyst.pycallback() { identifier = 0} : () -> () return } diff --git a/mlir/test/Catalyst/RegisterInactive.mlir b/mlir/test/Catalyst/RegisterInactive.mlir new file mode 100644 index 0000000000..ecfb0ddedd --- /dev/null +++ b/mlir/test/Catalyst/RegisterInactive.mlir @@ -0,0 +1,56 @@ +// Copyright 2024 Xanadu Quantum Technologies Inc. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// RUN: quantum-opt %s --register-inactive-callback --split-input-file --verify-diagnostics | FileCheck %s + +// This test just makes sure that we can +// run the compiler with the option +// +// --register-inactive-callback +// +// and that if there are no callbacks present +// it doesn't change anything + +// CHECK-LABEL: @test0 +module @test0 { + // CHECK-NOT: llvm.mlir.global external @__enzyme_inactivefn + // CHECK-LABEL: @foo + func.func @foo() { + return + } + // CHECK-NOT: llvm.mlir.global external @__enzyme_inactivefn +} + +// ----- + +// This test checks the invariant that after the transformation +// the attribute has been removed. + +// CHECK-LABEL: @test1 +module @test1 { + + // CHECK: llvm.mlir.global external @__enzyme_inactivefn + // CHECK: [[undef:%.+]] = llvm.mlir.undef + // CHECK: [[ptr:%.+]] = llvm.mlir.addressof @inactive_callback + // CHECK: [[retval:%.+]] = llvm.insertvalue [[ptr]], [[undef]][0] + // CHECK: llvm.return [[retval]] + + llvm.func @inactive_callback(i64, i64, i64, ...) + llvm.func @wrapper() { + %0 = llvm.mlir.constant(139935726668624 : i64) : i64 + %1 = llvm.mlir.constant(0 : i64) : i64 + llvm.call @inactive_callback(%0, %1, %1) vararg(!llvm.func) : (i64, i64, i64) -> () + llvm.return + } +} diff --git a/runtime/lib/capi/ExecutionContext.hpp b/runtime/lib/capi/ExecutionContext.hpp index 281e56bf48..8796b8d30a 100644 --- a/runtime/lib/capi/ExecutionContext.hpp +++ b/runtime/lib/capi/ExecutionContext.hpp @@ -33,7 +33,7 @@ extern void callbackCall(int64_t, int64_t, int64_t, va_list); namespace Catalyst::Runtime { -extern "C" void pyregistry(int64_t identifier, int64_t argc, int64_t retc, ...); +extern "C" void inactive_callback(int64_t identifier, int64_t argc, int64_t retc, ...); class MemoryManager final { private: diff --git a/runtime/lib/capi/RuntimeCAPI.cpp b/runtime/lib/capi/RuntimeCAPI.cpp index 928b75d4c7..86d24e5670 100644 --- a/runtime/lib/capi/RuntimeCAPI.cpp +++ b/runtime/lib/capi/RuntimeCAPI.cpp @@ -108,7 +108,7 @@ extern "C" { using namespace Catalyst::Runtime; using timer = catalyst::utils::Timer; -void pyregistry(int64_t identifier, int64_t argc, int64_t retc, ...) +void inactive_callback(int64_t identifier, int64_t argc, int64_t retc, ...) { // We need to guard calls to callback. // These are implemented in Python.