Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Mark debug.callbacks as inactive functions. #706

Merged
merged 17 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@

<h3>Improvements</h3>

* `debug.callbacks` are marked as inactive. This means `debug.callbacks` will not be considered
as active for the computation of gradients.
[(#706)](https://github.com/PennyLaneAI/catalyst/pull/706)
dime10 marked this conversation as resolved.
Show resolved Hide resolved

* Added support for IsingZZ gate in Catalyst frontend. Previously, the IsingZZ gate would be
decomposed into a CNOT and RZ gates. However, this is not needed as the PennyLane-Lightning
simulator supports this gate.
Expand Down
1 change: 1 addition & 0 deletions frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def run_writing_command(command: List[str], compile_options: Optional[CompileOpt
"reconcile-unrealized-casts",
"gep-inbounds",
"add-exception-handling",
"register-inactive-callback",
],
)

Expand Down
1 change: 1 addition & 0 deletions mlir/include/Catalyst/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ std::unique_ptr<mlir::Pass> createHloCustomCallLoweringPass();
std::unique_ptr<mlir::Pass> createQnodeToAsyncLoweringPass();
std::unique_ptr<mlir::Pass> createAddExceptionHandlingPass();
std::unique_ptr<mlir::Pass> createGEPInboundsPass();
std::unique_ptr<mlir::Pass> createRegisterInactiveCallbackPass();

void registerAllCatalystPasses();

Expand Down
11 changes: 11 additions & 0 deletions mlir/include/Catalyst/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -130,4 +130,15 @@ def GEPInboundsPass : Pass<"gep-inbounds"> {
let constructor = "catalyst::GEPInboundsPass()";
}

def RegisterInactiveCallbackPass : Pass<"register-inactive-callback", "ModuleOp"> {
let summary = "Register `inactive_callback` as inactive with Enzyme";

let dependentDialects = [
"mlir::LLVM::LLVMDialect"
];

let constructor = "catalyst::createRegisterInactiveCallbackPass()";

}

#endif // CATALYST_PASSES
28 changes: 28 additions & 0 deletions mlir/include/Gradient/Transforms/EnzymeConstants.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

namespace catalyst {
namespace gradient {

static constexpr const char *enzyme_autodiff_func_name = "__enzyme_autodiff";
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
static constexpr const char *enzyme_allocation_key = "__enzyme_allocation_like";
static constexpr const char *enzyme_custom_gradient_key = "__enzyme_register_gradient_";
static constexpr const char *enzyme_like_free_key = "__enzyme_function_like_free";
static constexpr const char *enzyme_const_key = "enzyme_const";
static constexpr const char *enzyme_dupnoneed_key = "enzyme_dupnoneed";
static constexpr const char *enzyme_inactivefn_key = "__enzyme_inactivefn";

}
}
1 change: 1 addition & 0 deletions mlir/lib/Catalyst/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ file(GLOB SRC
AsyncUtils.cpp
GEPInboundsPatterns.cpp
GEPInboundsPass.cpp
RegisterInactiveCallbackPass.cpp
)

get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ void catalyst::registerAllCatalystPasses()
mlir::registerPass(catalyst::createGEPInboundsPass);
mlir::registerPass(catalyst::createRemoveChainedSelfInversePass);
mlir::registerPass(catalyst::createAnnotateFunctionPass);
mlir::registerPass(catalyst::createRegisterInactiveCallbackPass);
}
67 changes: 67 additions & 0 deletions mlir/lib/Catalyst/Transforms/RegisterInactiveCallbackPass.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "Catalyst/Transforms/Passes.h"
#include "Catalyst/Transforms/Patterns.h"
#include "Gradient/Transforms/EnzymeConstants.h"

using namespace mlir;

namespace catalyst {

#define GEN_PASS_DEF_REGISTERINACTIVECALLBACKPASS
#define GEN_PASS_DECL_REGISTERINACTIVECALLBACKPASS
#include "Catalyst/Transforms/Passes.h.inc"

struct RegisterInactiveCallbackPass
: impl::RegisterInactiveCallbackPassBase<RegisterInactiveCallbackPass> {
using RegisterInactiveCallbackPassBase::RegisterInactiveCallbackPassBase;
void runOnOperation() final
{
auto mod = getOperation();
StringRef inactive_callbackFnName = "inactive_callback";
auto fnDecl = mod.lookupSymbol<LLVM::LLVMFuncOp>(inactive_callbackFnName);
if (!fnDecl) {
return;
}
MLIRContext *context = &getContext();
auto builder = OpBuilder(context);
builder.setInsertionPointToStart(mod.getBody());
auto ptrTy = LLVM::LLVMPointerType::get(context);
auto arrTy = LLVM::LLVMArrayType::get(ptrTy, 1);
auto loc = mod.getLoc();
auto isConstant = false;
auto linkage = LLVM::Linkage::External;
auto key = catalyst::gradient::enzyme_inactivefn_key;
auto glb = builder.create<LLVM::GlobalOp>(loc, arrTy, isConstant, linkage, key, nullptr);
// Create a block and push it to the global
Block *block = new Block();
glb.getInitializerRegion().push_back(block);
builder.setInsertionPointToStart(block);
auto undef = builder.create<LLVM::UndefOp>(glb.getLoc(), arrTy);
auto fnSym = SymbolRefAttr::get(context, inactive_callbackFnName);
auto fnPtr = builder.create<LLVM::AddressOfOp>(glb.getLoc(), ptrTy, fnSym);
auto filledInArray = builder.create<LLVM::InsertValueOp>(glb.getLoc(), undef, fnPtr, 0);
builder.create<LLVM::ReturnOp>(glb.getLoc(), filledInArray);
}
};

std::unique_ptr<Pass> createRegisterInactiveCallbackPass()
{
return std::make_unique<RegisterInactiveCallbackPass>();
}
} // namespace catalyst
2 changes: 1 addition & 1 deletion mlir/lib/Catalyst/Transforms/catalyst_to_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ struct PythonCallOpPattern : public OpConversionPattern<PythonCallOp> {

bool isVarArg = true;
LLVM::LLVMFuncOp customCallFnOp = mlir::LLVM::lookupOrCreateFn(
mod, "pyregistry", {/*args=*/i64, i64, i64}, /*ret_type=*/voidType, isVarArg);
mod, "inactive_callback", {/*args=*/i64, i64, i64}, /*ret_type=*/voidType, isVarArg);
dime10 marked this conversation as resolved.
Show resolved Hide resolved
customCallFnOp.setPrivate();
rewriter.restoreInsertionPoint(point);

Expand Down
8 changes: 1 addition & 7 deletions mlir/lib/Gradient/Transforms/ConversionPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include "Catalyst/Utils/CallGraph.h"
#include "Gradient/IR/GradientOps.h"
#include "Gradient/Transforms/EnzymeConstants.h"
#include "Gradient/Transforms/Patterns.h"
#include "Gradient/Utils/DestinationPassingStyle.h"
#include "Gradient/Utils/EinsumLinalgGeneric.h"
Expand Down Expand Up @@ -149,13 +150,6 @@ struct EnzymeMemRefInterfaceOptions {
bool dupNoNeed = false;
};

static constexpr const char *enzyme_autodiff_func_name = "__enzyme_autodiff";
static constexpr const char *enzyme_allocation_key = "__enzyme_allocation_like";
static constexpr const char *enzyme_custom_gradient_key = "__enzyme_register_gradient_";
static constexpr const char *enzyme_like_free_key = "__enzyme_function_like_free";
static constexpr const char *enzyme_const_key = "enzyme_const";
static constexpr const char *enzyme_dupnoneed_key = "enzyme_dupnoneed";

/// Enzyme custom gradients appear to exhibit better stability when they are registered for
/// functions where MemRefs are passed via wrapped pointers (!llvm.ptr<struct(ptr, ptr, i64, ...)>)
/// rather than having their fields unpacked. This function automatically transforms MemRef
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Catalyst/ConversionTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func.func @python_call () {
// CHECK: [[identifier:%.+]] = llvm.mlir.constant(0 : i64)
// CHECK: [[argcount:%.+]] = llvm.mlir.constant(0 : i64)
// CHECK: [[rescount:%.+]] = llvm.mlir.constant(0 : i64)
// CHECK: llvm.call @pyregistry([[identifier]], [[argcount]], [[rescount]])
// CHECK: llvm.call @inactive_callback([[identifier]], [[argcount]], [[rescount]])
catalyst.pycallback() { identifier = 0} : () -> ()
return
}
56 changes: 56 additions & 0 deletions mlir/test/Catalyst/RegisterInactive.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Copyright 2024 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// RUN: quantum-opt %s --register-inactive-callback --split-input-file --verify-diagnostics | FileCheck %s

// This test just makes sure that we can
// run the compiler with the option
//
// --register-inactive-callback
//
// and that if there are no callbacks present
// it doesn't change anything

// CHECK-LABEL: @test0
module @test0 {
// CHECK-NOT: llvm.mlir.global external @__enzyme_inactivefn
// CHECK-LABEL: @foo
func.func @foo() {
return
}
// CHECK-NOT: llvm.mlir.global external @__enzyme_inactivefn
}

// -----

// This test checks the invariant that after the transformation
// the attribute has been removed.

// CHECK-LABEL: @test1
module @test1 {

// CHECK: llvm.mlir.global external @__enzyme_inactivefn
// CHECK: [[undef:%.+]] = llvm.mlir.undef
// CHECK: [[ptr:%.+]] = llvm.mlir.addressof @inactive_callback
// CHECK: [[retval:%.+]] = llvm.insertvalue [[ptr]], [[undef]][0]
// CHECK: llvm.return [[retval]]

llvm.func @inactive_callback(i64, i64, i64, ...)
llvm.func @wrapper() {
%0 = llvm.mlir.constant(139935726668624 : i64) : i64
%1 = llvm.mlir.constant(0 : i64) : i64
llvm.call @inactive_callback(%0, %1, %1) vararg(!llvm.func<void (i64, i64, i64, ...)>) : (i64, i64, i64) -> ()
llvm.return
}
}
2 changes: 1 addition & 1 deletion runtime/lib/capi/ExecutionContext.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ extern void callbackCall(int64_t, int64_t, int64_t, va_list);

namespace Catalyst::Runtime {

extern "C" void pyregistry(int64_t identifier, int64_t argc, int64_t retc, ...);
extern "C" void inactive_callback(int64_t identifier, int64_t argc, int64_t retc, ...);

class MemoryManager final {
private:
Expand Down
2 changes: 1 addition & 1 deletion runtime/lib/capi/RuntimeCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ extern "C" {
using namespace Catalyst::Runtime;
using timer = catalyst::utils::Timer;

void pyregistry(int64_t identifier, int64_t argc, int64_t retc, ...)
void inactive_callback(int64_t identifier, int64_t argc, int64_t retc, ...)
{
// We need to guard calls to callback.
// These are implemented in Python.
Expand Down
Loading