Skip to content

Commit

Permalink
Add bgv-add-client-interface pass
Browse files Browse the repository at this point in the history
This PR bundles a few changes related to enabling and end-to-end lowering
through BGV to OpenFHE.

The main obstacle here is that an input function like the `simple_sum`
test added in this PR go through type changes when being compiled to
FHE. In particular, a `tensor<32xi16> -> i16` function becomes a
`rlwe_ciphertext -> rlwe_ciphertext` function without any indication
of what the original type was that generated the ciphertext.

google#614 prepared a few changes for this, and the remaining changes are:

- Adds a `bgv-add-client-interface` pass, which inserts two functions
  that encrypt and decrypt the data types input and output to each compiled
  function.
- Adds `bgv.encrypt/decrypt` ops to support the above.

The `bgv-add-client-interface` pass can likely be generalized somehow, but I
will leave that to future work.

Part of google#494

PiperOrigin-RevId: 625480692
  • Loading branch information
j2kun authored and Copybara-Service committed Apr 16, 2024
1 parent 9b17b76 commit 0895910
Show file tree
Hide file tree
Showing 12 changed files with 438 additions and 20 deletions.
28 changes: 28 additions & 0 deletions include/Dialect/BGV/IR/BGVOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ def SameOperandsAndResultRings: NativeOpTrait<"SameOperandsAndResultRings"> {
let cppNamespace = "::mlir::heir::lwe";
}

class RlweParametersMatch<
string v1, string ty1, string v2, string ty2,
string comparator = "std::equal_to<>()"
> : PredOpTrait<
"the first value's RLWE parameters matches the second value's",
CPred<
comparator # "(" #
"::llvm::cast<lwe::" # ty1 # ">($" # v1 # ".getType()).getRlweParams(), " #
"::llvm::cast<lwe::" # ty2 # ">($" # v2 # ".getType()).getRlweParams())"
>
>;

class BGV_Op<string mnemonic, list<Trait> traits = []> :
Op<BGV_Dialect, mnemonic, traits> {

Expand Down Expand Up @@ -200,4 +212,20 @@ def BGV_ModulusSwitch : BGV_Op<"modulus_switch"> {
let assemblyFormat = "operands attr-dict `:` qualified(type($input)) `->` qualified(type($output))" ;
}

def BGV_EncryptOp : BGV_Op<"encrypt", [
RlweParametersMatch<"secret_key", "RLWESecretKeyType", "output", "RLWECiphertextType">
]> {
let summary = "Encrypt an encoded plaintext into a BGV ciphertext.";
let arguments = (ins RLWEPlaintext:$input, RLWESecretKey:$secret_key);
let results = (outs RLWECiphertext:$output);
}

def BGV_DecryptOp : BGV_Op<"decrypt", [
RlweParametersMatch<"secret_key", "RLWESecretKeyType", "input", "RLWECiphertextType">
]> {
let summary = "Decrypt an encoded plaintext into a BGV ciphertext.";
let arguments = (ins RLWECiphertext:$input, RLWESecretKey:$secret_key);
let results = (outs RLWEPlaintext:$output);
}

#endif // HEIR_INCLUDE_DIALECT_BGV_IR_BGVOPS_TD_
17 changes: 17 additions & 0 deletions include/Dialect/BGV/Transforms/AddClientInterface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef INCLUDE_DIALECT_BGV_TRANSFORMS_ADDCLIENTINTERFACE_H_
#define INCLUDE_DIALECT_BGV_TRANSFORMS_ADDCLIENTINTERFACE_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace bgv {

#define GEN_PASS_DECL_ADDCLIENTINTERFACE
#include "include/Dialect/BGV/Transforms/Passes.h.inc"

} // namespace bgv
} // namespace heir
} // namespace mlir

#endif // INCLUDE_DIALECT_BGV_TRANSFORMS_ADDCLIENTINTERFACE_H_
36 changes: 36 additions & 0 deletions include/Dialect/BGV/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# AddClientInterface tablegen and headers.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=BGV",
],
"Passes.h.inc",
),
(
["-gen-pass-doc"],
"BGVPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "Passes.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)

exports_files([
"Passes.h",
"AddClientInterface.h",
])
18 changes: 18 additions & 0 deletions include/Dialect/BGV/Transforms/Passes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef INCLUDE_DIALECT_BGV_TRANSFORMS_PASSES_H_
#define INCLUDE_DIALECT_BGV_TRANSFORMS_PASSES_H_

#include "include/Dialect/BGV/IR/BGVDialect.h"
#include "include/Dialect/BGV/Transforms/AddClientInterface.h"

namespace mlir {
namespace heir {
namespace bgv {

#define GEN_PASS_REGISTRATION
#include "include/Dialect/BGV/Transforms/Passes.h.inc"

} // namespace bgv
} // namespace heir
} // namespace mlir

#endif // INCLUDE_DIALECT_BGV_TRANSFORMS_PASSES_H_
18 changes: 18 additions & 0 deletions include/Dialect/BGV/Transforms/Passes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef INCLUDE_DIALECT_BGV_TRANSFORMS_PASSES_TD_
#define INCLUDE_DIALECT_BGV_TRANSFORMS_PASSES_TD_

include "mlir/Pass/PassBase.td"

def AddClientInterface : Pass<"bgv-add-client-interface"> {
let summary = "Add client interfaces to BGV encrypted functions";
let description = [{
This pass adds encrypt and decrypt functions for each compiled function in the
IR. These functions maintain the same interface as the original function,
while the compiled function may lose some of this information by the lowerings
to ciphertext types (e.g., a scalar ciphertext, when lowered through BGV, must
be encoded as a tensor).
}];
let dependentDialects = ["mlir::heir::bgv::BGVDialect"];
}

#endif // INCLUDE_DIALECT_BGV_TRANSFORMS_PASSES_TD_
2 changes: 1 addition & 1 deletion include/Dialect/LWE/IR/LWEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def LWE_RLWEEncodeOp : LWE_Op<"rlwe_encode", [Pure, HasEncoding<"output", "encod
}

def LWE_RLWEDecodeOp : LWE_Op<"rlwe_decode", [
Pure, HasEncoding<"output", "encoding", "RLWEPlaintextType">]> {
Pure, HasEncoding<"input", "encoding", "RLWEPlaintextType">]> {
let summary = "Decode an RLWE plaintext to an underlying type";

let arguments = (ins
Expand Down
191 changes: 191 additions & 0 deletions lib/Dialect/BGV/Transforms/AddClientInterface.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
#include "include/Dialect/BGV/Transforms/AddClientInterface.h"

#include <cstddef>
#include <string>

#include "include/Dialect/BGV/IR/BGVOps.h"
#include "include/Dialect/LWE/IR/LWEAttributes.h"
#include "include/Dialect/LWE/IR/LWEOps.h"
#include "include/Dialect/LWE/IR/LWETypes.h"
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/Block.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace bgv {

#define GEN_PASS_DEF_ADDCLIENTINTERFACE
#include "include/Dialect/BGV/Transforms/Passes.h.inc"

/// Adds the client interface for a single func. This should only be used on the
/// "entry" func for the IR being compiled, but there may be multiple.
LogicalResult convertFunc(func::FuncOp op) {
auto module = op->getParentOfType<ModuleOp>();
auto argTypes = op.getArgumentTypes();
auto returnTypes = op.getResultTypes();

ImplicitLocOpBuilder builder =
ImplicitLocOpBuilder::atBlockEnd(module.getLoc(), module.getBody());

std::string encFuncName("");
llvm::raw_string_ostream encNameOs(encFuncName);
encNameOs << op.getSymName() << "__encrypt";

std::string decFuncName("");
llvm::raw_string_ostream decNameOs(decFuncName);
decNameOs << op.getSymName() << "__decrypt";

// The enryption function converts each plaintext operand to its encrypted
// form. We also have to add a secret key arg, and we put it at the end to
// maintain zippability of the non-secret-key args.
SmallVector<Type> encFuncArgTypes;
SmallVector<Type> encFuncResultTypes;
lwe::RLWEParamsAttr rlweParams = nullptr;
for (auto argTy : argTypes) {
if (auto argCtTy = dyn_cast<lwe::RLWECiphertextType>(argTy)) {
encFuncArgTypes.push_back(argCtTy.getUnderlyingType());
encFuncResultTypes.push_back(argCtTy);
if (rlweParams && rlweParams != argCtTy.getRlweParams()) {
op.emitError() << "Func op has multiple distinct RLWE params"
<< " but only 1 is currently supported per func.";
return failure();
}
rlweParams = argCtTy.getRlweParams();
continue;
}

// For plaintext arguments, the function is a no-op
encFuncArgTypes.push_back(argTy);
encFuncResultTypes.push_back(argTy);
}

if (!rlweParams) {
op.emitError("Func op has no RLWE ciphertext inputs");
return failure();
}

auto skType = lwe::RLWESecretKeyType::get(op.getContext(), rlweParams);
encFuncArgTypes.push_back(skType);

// The decryption function is the opposite.
SmallVector<Type> decFuncArgTypes;
SmallVector<Type> decFuncResultTypes;
for (auto returnTy : returnTypes) {
if (auto returnCtTy = dyn_cast<lwe::RLWECiphertextType>(returnTy)) {
decFuncArgTypes.push_back(returnCtTy);
decFuncResultTypes.push_back(returnCtTy.getUnderlyingType());
continue;
}

// For plaintext results, the function is a no-op
decFuncArgTypes.push_back(returnTy);
decFuncResultTypes.push_back(returnTy);
}
decFuncArgTypes.push_back(skType);

// Build the encryption function first
FunctionType encFuncType = FunctionType::get(
builder.getContext(), encFuncArgTypes, encFuncResultTypes);
auto encFuncOp = builder.create<func::FuncOp>(encFuncName, encFuncType);
Block *entryBlock = encFuncOp.addEntryBlock();
builder.setInsertionPointToEnd(entryBlock);
Value secretKey = encFuncOp.getArgument(encFuncOp.getNumArguments() - 1);

SmallVector<Value> encValuesToReturn;
// Use result types because arg types has the secret key at the end, but
// result types does not
// TODO(#615): encode/decode should convert scalar types to tensors.
for (size_t i = 0; i < encFuncResultTypes.size(); ++i) {
auto argTy = encFuncArgTypes[i];
auto resultTy = encFuncResultTypes[i];

// If the output is encrypted, we need to encode and encrypt
if (auto resultCtTy = dyn_cast<lwe::RLWECiphertextType>(resultTy)) {
auto plaintextTy = lwe::RLWEPlaintextType::get(
op.getContext(), resultCtTy.getEncoding(),
resultCtTy.getRlweParams().getRing(), argTy);
auto encoded = builder.create<lwe::RLWEEncodeOp>(
plaintextTy, encFuncOp.getArgument(i), resultCtTy.getEncoding(),
resultCtTy.getRlweParams().getRing());
auto encrypted = builder.create<bgv::EncryptOp>(
resultCtTy, encoded.getResult(), secretKey);
encValuesToReturn.push_back(encrypted.getResult());
continue;
}

// Otherwise, return the input unchanged.
encValuesToReturn.push_back(encFuncOp.getArgument(i));
}

builder.create<func::ReturnOp>(encValuesToReturn);

// Then the decryption function
FunctionType decFuncType = FunctionType::get(
builder.getContext(), decFuncArgTypes, decFuncResultTypes);
// Insertion point is inside the encryption func, have to move it
// back out to the module
builder.setInsertionPointToEnd(module.getBody());
auto decFuncOp = builder.create<func::FuncOp>(decFuncName, decFuncType);
builder.setInsertionPointToEnd(decFuncOp.addEntryBlock());
secretKey = decFuncOp.getArgument(decFuncOp.getNumArguments() - 1);

SmallVector<Value> decValuesToReturn;
// Use result types because arg types has the secret key at the end, but
// result types does not
for (size_t i = 0; i < decFuncResultTypes.size(); ++i) {
auto argTy = decFuncArgTypes[i];
auto resultTy = decFuncResultTypes[i];

// If the input is ciphertext, we need to decode and decrypt
if (auto argCtTy = dyn_cast<lwe::RLWECiphertextType>(argTy)) {
auto plaintextTy = lwe::RLWEPlaintextType::get(
op.getContext(), argCtTy.getEncoding(),
argCtTy.getRlweParams().getRing(), resultTy);
auto decrypted = builder.create<bgv::DecryptOp>(
plaintextTy, decFuncOp.getArgument(i), secretKey);
auto decoded = builder.create<lwe::RLWEDecodeOp>(
resultTy, decrypted.getResult(), argCtTy.getEncoding(),
argCtTy.getRlweParams().getRing());
// FIXME: if the input is a scalar type, we must add a tensor.extract op.
// The decode op's tablegen should also support having tensor types as
// outputs if it doesn't already.
decValuesToReturn.push_back(decoded.getResult());
continue;
}

// Otherwise, return the input unchanged.
decValuesToReturn.push_back(decFuncOp.getArgument(i));
}

builder.create<func::ReturnOp>(decValuesToReturn);

return success();
}

struct AddClientInterface : impl::AddClientInterfaceBase<AddClientInterface> {
using AddClientInterfaceBase::AddClientInterfaceBase;

void runOnOperation() override {
auto result =
getOperation()->walk<WalkOrder::PreOrder>([&](func::FuncOp op) {
if (failed(convertFunc(op))) {
op->emitError("Failed to add client interface for func");
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (result.wasInterrupted()) signalPassFailure();
}
};
} // namespace bgv
} // namespace heir
} // namespace mlir
35 changes: 35 additions & 0 deletions lib/Dialect/BGV/Transforms/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "Transforms",
hdrs = [
"@heir//include/Dialect/BGV/Transforms:Passes.h",
],
deps = [
":AddClientInterface",
"@heir//include/Dialect/BGV/Transforms:pass_inc_gen",
"@heir//lib/Dialect/BGV/IR:Dialect",
"@llvm-project//mlir:IR",
],
)

cc_library(
name = "AddClientInterface",
srcs = ["AddClientInterface.cpp"],
hdrs = [
"@heir//include/Dialect/BGV/Transforms:AddClientInterface.h",
],
deps = [
"@heir//include/Dialect/BGV/Transforms:pass_inc_gen",
"@heir//lib/Dialect/BGV/IR:Dialect",
"@heir//lib/Dialect/LWE/IR:Dialect",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
],
)
Loading

0 comments on commit 0895910

Please sign in to comment.