Skip to content

Commit

Permalink
[MLIR] Mark debug.callbacks as inactive functions. (#706)
Browse files Browse the repository at this point in the history
**Context:** `debug.callbacks` will always be inactive functions when
taking the gradient of them.

**Description of the Change:** Change `pyregistry` to
`inactive_callback` and mark all `inactive_callback`s as inactive.

**Benefits:** Inactive callbacks are inactive.

Notes:

* Future PRs will make active_callbacks through specialization and will
call inactive_callback as a primitive.
* Future PR will re-enable callbacks in gradients.

[sc-60496]
  • Loading branch information
erick-xanadu committed May 24, 2024
1 parent 715523e commit 2636257
Show file tree
Hide file tree
Showing 14 changed files with 175 additions and 11 deletions.
4 changes: 4 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@

<h3>Improvements</h3>

* `debug.callbacks` are marked as inactive. This means `debug.callbacks` will not be considered
as active for the computation of gradients.
[(#706)](https://github.com/PennyLaneAI/catalyst/pull/706)

* Added support for IsingZZ gate in Catalyst frontend. Previously, the IsingZZ gate would be
decomposed into a CNOT and RZ gates. However, this is not needed as the PennyLane-Lightning
simulator supports this gate.
Expand Down
1 change: 1 addition & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
"reconcile-unrealized-casts",
"gep-inbounds",
"add-exception-handling",
"register-inactive-callback",
],
)

Expand Down
1 change: 1 addition & 0 deletions mlir/include/Catalyst/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ std::unique_ptr<mlir::Pass> createHloCustomCallLoweringPass();
std::unique_ptr<mlir::Pass> createQnodeToAsyncLoweringPass();
std::unique_ptr<mlir::Pass> createAddExceptionHandlingPass();
std::unique_ptr<mlir::Pass> createGEPInboundsPass();
std::unique_ptr<mlir::Pass> createRegisterInactiveCallbackPass();

void registerAllCatalystPasses();

Expand Down
11 changes: 11 additions & 0 deletions mlir/include/Catalyst/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,15 @@ def GEPInboundsPass : Pass<"gep-inbounds"> {
let constructor = "catalyst::GEPInboundsPass()";
}

def RegisterInactiveCallbackPass : Pass<"register-inactive-callback", "ModuleOp"> {
let summary = "Register `inactive_callback` as inactive with Enzyme";

let dependentDialects = [
"mlir::LLVM::LLVMDialect"
];

let constructor = "catalyst::createRegisterInactiveCallbackPass()";

}

#endif // CATALYST_PASSES
28 changes: 28 additions & 0 deletions mlir/include/Gradient/Transforms/EnzymeConstants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

namespace catalyst {
namespace gradient {

static constexpr const char *enzyme_autodiff_func_name = "__enzyme_autodiff";
static constexpr const char *enzyme_allocation_key = "__enzyme_allocation_like";
static constexpr const char *enzyme_custom_gradient_key = "__enzyme_register_gradient_";
static constexpr const char *enzyme_like_free_key = "__enzyme_function_like_free";
static constexpr const char *enzyme_const_key = "enzyme_const";
static constexpr const char *enzyme_dupnoneed_key = "enzyme_dupnoneed";
static constexpr const char *enzyme_inactivefn_key = "__enzyme_inactivefn";

}
}
1 change: 1 addition & 0 deletions mlir/lib/Catalyst/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ file(GLOB SRC
AsyncUtils.cpp
GEPInboundsPatterns.cpp
GEPInboundsPass.cpp
RegisterInactiveCallbackPass.cpp
)

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ void catalyst::registerAllCatalystPasses()
mlir::registerPass(catalyst::createGEPInboundsPass);
mlir::registerPass(catalyst::createRemoveChainedSelfInversePass);
mlir::registerPass(catalyst::createAnnotateFunctionPass);
mlir::registerPass(catalyst::createRegisterInactiveCallbackPass);
}
67 changes: 67 additions & 0 deletions mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "Catalyst/Transforms/Passes.h"
#include "Catalyst/Transforms/Patterns.h"
#include "Gradient/Transforms/EnzymeConstants.h"

using namespace mlir;

namespace catalyst {

#define GEN_PASS_DEF_REGISTERINACTIVECALLBACKPASS
#define GEN_PASS_DECL_REGISTERINACTIVECALLBACKPASS
#include "Catalyst/Transforms/Passes.h.inc"

struct RegisterInactiveCallbackPass
: impl::RegisterInactiveCallbackPassBase<RegisterInactiveCallbackPass> {
using RegisterInactiveCallbackPassBase::RegisterInactiveCallbackPassBase;
void runOnOperation() final
{
auto mod = getOperation();
StringRef inactive_callbackFnName = "inactive_callback";
auto fnDecl = mod.lookupSymbol<LLVM::LLVMFuncOp>(inactive_callbackFnName);
if (!fnDecl) {
return;
}
MLIRContext *context = &getContext();
auto builder = OpBuilder(context);
builder.setInsertionPointToStart(mod.getBody());
auto ptrTy = LLVM::LLVMPointerType::get(context);
auto arrTy = LLVM::LLVMArrayType::get(ptrTy, 1);
auto loc = mod.getLoc();
auto isConstant = false;
auto linkage = LLVM::Linkage::External;
auto key = catalyst::gradient::enzyme_inactivefn_key;
auto glb = builder.create<LLVM::GlobalOp>(loc, arrTy, isConstant, linkage, key, nullptr);
// Create a block and push it to the global
Block *block = new Block();
glb.getInitializerRegion().push_back(block);
builder.setInsertionPointToStart(block);
auto undef = builder.create<LLVM::UndefOp>(glb.getLoc(), arrTy);
auto fnSym = SymbolRefAttr::get(context, inactive_callbackFnName);
auto fnPtr = builder.create<LLVM::AddressOfOp>(glb.getLoc(), ptrTy, fnSym);
auto filledInArray = builder.create<LLVM::InsertValueOp>(glb.getLoc(), undef, fnPtr, 0);
builder.create<LLVM::ReturnOp>(glb.getLoc(), filledInArray);
}
};

std::unique_ptr<Pass> createRegisterInactiveCallbackPass()
{
return std::make_unique<RegisterInactiveCallbackPass>();
}
} // namespace catalyst
2 changes: 1 addition & 1 deletion mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ struct PythonCallOpPattern : public OpConversionPattern<PythonCallOp> {

bool isVarArg = true;
LLVM::LLVMFuncOp customCallFnOp = mlir::LLVM::lookupOrCreateFn(
mod, "pyregistry", {/*args=*/i64, i64, i64}, /*ret_type=*/voidType, isVarArg);
mod, "inactive_callback", {/*args=*/i64, i64, i64}, /*ret_type=*/voidType, isVarArg);
customCallFnOp.setPrivate();
rewriter.restoreInsertionPoint(point);

Expand Down
8 changes: 1 addition & 7 deletions mlir/lib/Gradient/Transforms/ConversionPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include "Catalyst/Utils/CallGraph.h"
#include "Gradient/IR/GradientOps.h"
#include "Gradient/Transforms/EnzymeConstants.h"
#include "Gradient/Transforms/Patterns.h"
#include "Gradient/Utils/DestinationPassingStyle.h"
#include "Gradient/Utils/EinsumLinalgGeneric.h"
Expand Down Expand Up @@ -149,13 +150,6 @@ struct EnzymeMemRefInterfaceOptions {
bool dupNoNeed = false;
};

static constexpr const char *enzyme_autodiff_func_name = "__enzyme_autodiff";
static constexpr const char *enzyme_allocation_key = "__enzyme_allocation_like";
static constexpr const char *enzyme_custom_gradient_key = "__enzyme_register_gradient_";
static constexpr const char *enzyme_like_free_key = "__enzyme_function_like_free";
static constexpr const char *enzyme_const_key = "enzyme_const";
static constexpr const char *enzyme_dupnoneed_key = "enzyme_dupnoneed";

/// Enzyme custom gradients appear to exhibit better stability when they are registered for
/// functions where MemRefs are passed via wrapped pointers (!llvm.ptr<struct(ptr, ptr, i64, ...)>)
/// rather than having their fields unpacked. This function automatically transforms MemRef
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Catalyst/ConversionTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func.func @python_call () {
// CHECK: [[identifier:%.+]] = llvm.mlir.constant(0 : i64)
// CHECK: [[argcount:%.+]] = llvm.mlir.constant(0 : i64)
// CHECK: [[rescount:%.+]] = llvm.mlir.constant(0 : i64)
// CHECK: llvm.call @pyregistry([[identifier]], [[argcount]], [[rescount]])
// CHECK: llvm.call @inactive_callback([[identifier]], [[argcount]], [[rescount]])
catalyst.pycallback() { identifier = 0} : () -> ()
return
}
56 changes: 56 additions & 0 deletions mlir/test/Catalyst/RegisterInactive.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// RUN: quantum-opt %s --register-inactive-callback --split-input-file --verify-diagnostics | FileCheck %s

// This test just makes sure that we can
// run the compiler with the option
//
// --register-inactive-callback
//
// and that if there are no callbacks present
// it doesn't change anything

// CHECK-LABEL: @test0
module @test0 {
// CHECK-NOT: llvm.mlir.global external @__enzyme_inactivefn
// CHECK-LABEL: @foo
func.func @foo() {
return
}
// CHECK-NOT: llvm.mlir.global external @__enzyme_inactivefn
}

// -----

// This test checks the invariant that after the transformation
// the attribute has been removed.

// CHECK-LABEL: @test1
module @test1 {

// CHECK: llvm.mlir.global external @__enzyme_inactivefn
// CHECK: [[undef:%.+]] = llvm.mlir.undef
// CHECK: [[ptr:%.+]] = llvm.mlir.addressof @inactive_callback
// CHECK: [[retval:%.+]] = llvm.insertvalue [[ptr]], [[undef]][0]
// CHECK: llvm.return [[retval]]

llvm.func @inactive_callback(i64, i64, i64, ...)
llvm.func @wrapper() {
%0 = llvm.mlir.constant(139935726668624 : i64) : i64
%1 = llvm.mlir.constant(0 : i64) : i64
llvm.call @inactive_callback(%0, %1, %1) vararg(!llvm.func<void (i64, i64, i64, ...)>) : (i64, i64, i64) -> ()
llvm.return
}
}
2 changes: 1 addition & 1 deletion runtime/lib/capi/ExecutionContext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ extern void callbackCall(int64_t, int64_t, int64_t, va_list);

namespace Catalyst::Runtime {

extern "C" void pyregistry(int64_t identifier, int64_t argc, int64_t retc, ...);
extern "C" void inactive_callback(int64_t identifier, int64_t argc, int64_t retc, ...);

class MemoryManager final {
private:
Expand Down
2 changes: 1 addition & 1 deletion runtime/lib/capi/RuntimeCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ extern "C" {
using namespace Catalyst::Runtime;
using timer = catalyst::utils::Timer;

void pyregistry(int64_t identifier, int64_t argc, int64_t retc, ...)
void inactive_callback(int64_t identifier, int64_t argc, int64_t retc, ...)
{
// We need to guard calls to callback.
// These are implemented in Python.
Expand Down

0 comments on commit 2636257

Please sign in to comment.