Skip to content

Commit

Permalink
Merge pull request google#413 from j2kun:more-secret-patterns
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604369311
  • Loading branch information
Copybara-Service committed Feb 5, 2024
2 parents 8093076 + 0770a17 commit b0d08fe
Show file tree
Hide file tree
Showing 8 changed files with 420 additions and 58 deletions.
25 changes: 25 additions & 0 deletions include/Dialect/Secret/IR/SecretOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,29 @@
#define GET_OP_CLASSES
#include "include/Dialect/Secret/IR/SecretOps.h.inc"

namespace mlir {
namespace heir {
namespace secret {

// Extracts the given op from inside the generic body and lifting to a new
// single-op generic after the context generic op. This function assumes as a
// precondition that the opToExtract's results do not have any uses besides in
// the yield of the genericOp. The HoistOpAfterGeneric pattern tests for this
// precondition.
//
// Replaces `genericOp` with a new genericOp using `rewriter`, and returns
// the two newly created generic ops, with the first one being the replacement
// for the input `genericOp`, and the second one being the extracted genericOp.
//
// Handles adding the operands of opToExtract to the yielded values of the
// generic. The new yields may not be needed, and this can be cleaned up by
// canonicalize, or a manual application of DedupeYieldedValues and
// RemoveUnusedYieldedValues.
std::pair<GenericOp, GenericOp> extractOpAfterGeneric(
GenericOp genericOp, Operation *opToExtract, PatternRewriter &rewriter);

} // namespace secret
} // namespace heir
} // namespace mlir

#endif // HEIR_INCLUDE_DIALECT_SECRET_IR_SECRETOPS_H_
20 changes: 19 additions & 1 deletion include/Dialect/Secret/IR/SecretOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,8 @@ def Secret_GenericOp : Secret_Op<"generic", [
// Clones a generic op and adds new yielded values. Returns the new op and
// the value range corresponding to the new result values of the generic.
// Callers can follow this method with something like the following to
// replace the current generic op with the result of this method.
// replace the current generic op with the result of this method. Always
// adds the new yielded values to the end of the list of yielded values.
//
// auto [modifiedGeneric, newResults] =
// genericOp.addNewYieldedValues(newResults, rewriter);
Expand Down Expand Up @@ -203,6 +204,23 @@ def Secret_GenericOp : Secret_Op<"generic", [
ArrayRef<int> yieldedIndicesToRemove,
PatternRewriter &rewriter,
SmallVector<Value> &remainingResults);

// Modifies a GenericOp in place by taking the given op inside the generic
// body and lifting it into a new single-op generic before the context
// generic op. Returns the newly created GenericOp.
//
// For extractOpAfterGeneric, see SecretOps.h (it's a non-member function).
GenericOp extractOpBeforeGeneric(
Operation *opToExtract, PatternRewriter &rewriter);

// Inlines the GenericOp in place, dropping any secret types involved.
// Extra `operands` argument allows a conversion pattern to pass
// adaptor.getOperands().
void inlineInPlaceDroppingSecrets(PatternRewriter &rewriter, ValueRange operands);

void inlineInPlaceDroppingSecrets(PatternRewriter &rewriter) {
inlineInPlaceDroppingSecrets(rewriter, getOperands());
}
}];

let hasCanonicalizer = 1;
Expand Down
45 changes: 45 additions & 0 deletions include/Dialect/Secret/IR/SecretPatterns.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
#ifndef INCLUDE_DIALECT_SECRET_IR_SECRETPATTERNS_H_
#define INCLUDE_DIALECT_SECRET_IR_SECRETPATTERNS_H_

#include <utility>

#include "include/Dialect/Secret/IR/SecretOps.h"
#include "include/Dialect/Secret/IR/SecretTypes.h"
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
Expand Down Expand Up @@ -156,6 +158,49 @@ struct DedupeYieldedValues : public OpRewritePattern<GenericOp> {
PatternRewriter &rewriter) const override;
};

// Hoist an op out of a generic and place it before the generic (in a new
// generic block), if possible. This will be impossible if one of the op's
// operands depends on another SSA value defined by an op inside the
// generic.
//
// Accepts a list of op names to hoist.
struct HoistOpBeforeGeneric : public OpRewritePattern<GenericOp> {
HoistOpBeforeGeneric(mlir::MLIRContext *context,
std::vector<std::string> opTypes)
: OpRewritePattern<GenericOp>(context, /*benefit=*/1),
opTypes(std::move(opTypes)) {}

public:
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override;

bool canHoist(Operation &op) const;

private:
std::vector<std::string> opTypes;
};

// Hoist an op out of a generic and place it after the generic (in a new
// generic block), if possible. This will be impossible if one of the op's
// results is used by another op inside the generic before the yield.
//
// Accepts a list of op names to hoist.
struct HoistOpAfterGeneric : public OpRewritePattern<GenericOp> {
HoistOpAfterGeneric(mlir::MLIRContext *context,
std::vector<std::string> opTypes)
: OpRewritePattern<GenericOp>(context, /*benefit=*/1),
opTypes(std::move(opTypes)) {}

public:
LogicalResult matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const override;

bool canHoist(Operation &op) const;

private:
std::vector<std::string> opTypes;
};

} // namespace secret
} // namespace heir
} // namespace mlir
Expand Down
203 changes: 203 additions & 0 deletions lib/Dialect/Secret/IR/SecretOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project
#include "llvm/include/llvm/Support/Casting.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/Block.h" // from @llvm-project
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
Expand All @@ -27,6 +28,8 @@
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

#define DEBUG_TYPE "secret-ops"

namespace mlir {
namespace heir {
namespace secret {
Expand Down Expand Up @@ -251,6 +254,7 @@ OpFoldResult CastOp::fold(CastOp::FoldAdaptor adaptor) {
}

OpOperand *GenericOp::getOpOperandForBlockArgument(Value value) {
// FIXME: why can't I just dyn_cast the Value to a BlockArgument?
auto *body = getBody();
int index = std::find(body->getArguments().begin(),
body->getArguments().end(), value) -
Expand Down Expand Up @@ -368,6 +372,205 @@ GenericOp GenericOp::removeYieldedValues(ArrayRef<int> yieldedIndicesToRemove,
return cloneWithNewTypes(*this, newResultTypes, rewriter);
}

GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract,
PatternRewriter &rewriter) {
assert(opToExtract->getParentOp() == *this);

// Result types are secret versions of the results of the op, since the
// secret will yield all of this op's results immediately.
SmallVector<Type> newResultTypes;
newResultTypes.reserve(opToExtract->getNumResults());
for (Type ty : opToExtract->getResultTypes()) {
newResultTypes.push_back(SecretType::get(ty));
}

auto newGeneric = rewriter.create<GenericOp>(
getLoc(), getInputs(), newResultTypes,
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
IRMapping mp;
for (BlockArgument blockArg : getBody()->getArguments()) {
mp.map(blockArg, blockArguments[blockArg.getArgNumber()]);
}
auto *newOp = b.clone(*opToExtract, mp);
b.create<YieldOp>(loc, newOp->getResults());
});

// Once the op is split off into a new generic op, we need to add new
// operands to the old generic op, add new corresponding block arguments, and
// replace all uses of the opToExtract's results with the created block
// arguments.
SmallVector<Value> oldGenericNewBlockArgs;
rewriter.modifyOpInPlace(*this, [&]() {
getInputsMutable().append(newGeneric.getResults());
for (auto ty : opToExtract->getResultTypes()) {
BlockArgument arg = getBody()->addArgument(ty, opToExtract->getLoc());
oldGenericNewBlockArgs.push_back(arg);
}
});
rewriter.replaceOp(opToExtract, oldGenericNewBlockArgs);

return newGeneric;
}

// When replacing a generic op with a new one, and given an op in the original
// generic op, find the corresponding op in the new generic op.
//
// Note, this is brittle and depends on the two generic ops having identical
// copies of the same ops in the same order.
Operation *findCorrespondingOp(GenericOp oldGenericOp, GenericOp newGenericOp,
Operation *op) {
assert(oldGenericOp.getBody()->getOperations().size() ==
newGenericOp.getBody()->getOperations().size() &&
"findCorrespondingOp requires both oldGenericOp and newGenericOp have "
"the same size");
for (auto [oldOp, newOp] :
llvm::zip(oldGenericOp.getBody()->getOperations(),
newGenericOp.getBody()->getOperations())) {
if (&oldOp == op) {
assert(oldOp.getName() == newOp.getName() &&
"Expected corresponding op to be the same type in old and new "
"generic");
return &newOp;
}
}
llvm_unreachable(
"findCorrespondingOp used but no corresponding op was found");
return nullptr;
}

std::pair<GenericOp, GenericOp> extractOpAfterGeneric(
GenericOp genericOp, Operation *opToExtract, PatternRewriter &rewriter) {
assert(opToExtract->getParentOp() == genericOp);
[[maybe_unused]] auto *parent = genericOp->getParentOp();

LLVM_DEBUG({
llvm::dbgs() << "At start of extracting op after generic:\n";
parent->dump();
});
// The new yields may not always be needed, and this can be cleaned up by
// canonicalize, or a manual application of DedupeYieldedValues and
// RemoveUnusedYieldedValues.
auto result =
genericOp.addNewYieldedValues(opToExtract->getOperands(), rewriter);
// Can't do structured assignment of pair above, because clang fails to
// compile the usage of these values in the closure below.
// (https://stackoverflow.com/a/46115028/438830).
GenericOp genericOpWithNewYields = result.first;
ValueRange newResults = result.second;
// Keep track of the opToExtract in the new generic.
opToExtract =
findCorrespondingOp(genericOp, genericOpWithNewYields, opToExtract);
rewriter.replaceOp(genericOp,
ValueRange(genericOpWithNewYields.getResults().drop_back(
newResults.size())));
LLVM_DEBUG({
llvm::dbgs() << "After adding new yielded values:\n";
parent->dump();
llvm::dbgs() << "opToExtract is now in:\n";
opToExtract->getParentOp()->dump();
});

SmallVector<Value> newGenericOperands;
newGenericOperands.reserve(opToExtract->getNumOperands());
for (auto operand : opToExtract->getOperands()) {
// If the yielded value is a block argument or ambient, we can just use the
// original SSA value.
auto blockArg = operand.dyn_cast<BlockArgument>();
bool isBlockArgOfGeneric =
blockArg && blockArg.getOwner() == genericOpWithNewYields.getBody();
bool isAmbient =
(blockArg && blockArg.getOwner() != genericOpWithNewYields.getBody()) ||
(!blockArg && operand.getDefiningOp()->getBlock() !=
genericOpWithNewYields.getBody());
if (isBlockArgOfGeneric) {
newGenericOperands.push_back(
genericOpWithNewYields.getOperand(blockArg.getArgNumber()));
continue;
}
if (isAmbient) {
newGenericOperands.push_back(operand);
continue;
}

// Otherwise, find the corresponding result of the generic op
auto yieldOperands = genericOpWithNewYields.getYieldOp().getOperands();
int resultIndex =
std::find(yieldOperands.begin(), yieldOperands.end(), operand) -
yieldOperands.begin();
newGenericOperands.push_back(genericOpWithNewYields.getResult(resultIndex));
}

// Result types are secret versions of the results of the op, since the
// secret will yield all of this op's results immediately.
SmallVector<Type> newResultTypes;
newResultTypes.reserve(opToExtract->getNumResults());
for (Type ty : opToExtract->getResultTypes()) {
newResultTypes.push_back(SecretType::get(ty));
}

rewriter.setInsertionPointAfter(genericOpWithNewYields);
auto newGeneric = rewriter.create<GenericOp>(
genericOpWithNewYields.getLoc(), newGenericOperands, newResultTypes,
[&](OpBuilder &b, Location loc, ValueRange blockArguments) {
IRMapping mp;
int i = 0;
for (Value operand : opToExtract->getOperands()) {
mp.map(operand, blockArguments[i]);
++i;
}
auto *newOp = b.clone(*opToExtract, mp);
b.create<YieldOp>(loc, newOp->getResults());
});
LLVM_DEBUG({
llvm::dbgs() << "After adding new single-op generic:\n";
parent->dump();
});

// Once the op is split off into a new generic op, we need to erase
// the old op and remove its results from the yield op.
rewriter.setInsertionPointAfter(genericOpWithNewYields);
SmallVector<Value> remainingResults;
auto replacedGeneric = genericOpWithNewYields.removeYieldedValues(
opToExtract->getResults(), rewriter, remainingResults);
// Keep track of the opToExtract in the new generic.
opToExtract =
findCorrespondingOp(genericOpWithNewYields, replacedGeneric, opToExtract);
rewriter.replaceAllUsesWith(remainingResults, replacedGeneric.getResults());
rewriter.eraseOp(genericOpWithNewYields);
rewriter.eraseOp(opToExtract);
LLVM_DEBUG({
llvm::dbgs() << "After removing opToExtract from old generic:\n";
parent->dump();
});

return std::pair{replacedGeneric, newGeneric};
}

void GenericOp::inlineInPlaceDroppingSecrets(PatternRewriter &rewriter,
ValueRange operands) {
GenericOp &op = *this;
Block *originalBlock = op->getBlock();
Block &opEntryBlock = op.getRegion().front();
YieldOp yieldOp = dyn_cast<YieldOp>(op.getRegion().back().getTerminator());

// Inline the op's (unique) block, including the yield op. This also
// requires splitting the parent block of the generic op, so that we have a
// clear insertion point for inlining.
Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op));
rewriter.inlineRegionBefore(op.getRegion(), newBlock);

// Now that op's region is inlined, the operands of its YieldOp are mapped
// to the materialized target values. Therefore, we can replace the op's
// uses with those of its YieldOp's operands.
rewriter.replaceOp(op, yieldOp->getOperands());

// No need for these intermediate blocks, merge them into 1.
rewriter.mergeBlocks(&opEntryBlock, originalBlock, operands);
rewriter.mergeBlocks(newBlock, originalBlock, {});

rewriter.eraseOp(yieldOp);
}

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CollapseSecretlessGeneric, RemoveUnusedYieldedValues,
Expand Down
Loading

0 comments on commit b0d08fe

Please sign in to comment.