From 0770a17035e492408786692e012f0fcc96707a05 Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Wed, 31 Jan 2024 07:44:20 -0800 Subject: [PATCH] Add more generic ops, helpers, and patterns --- include/Dialect/Secret/IR/SecretOps.h | 25 +++ include/Dialect/Secret/IR/SecretOps.td | 20 +- include/Dialect/Secret/IR/SecretPatterns.h | 45 ++++ lib/Dialect/Secret/IR/SecretOps.cpp | 203 ++++++++++++++++++ lib/Dialect/Secret/IR/SecretPatterns.cpp | 107 +++++++++ .../Secret/Transforms/DistributeGeneric.cpp | 37 +--- .../Secret/Transforms/ForgetSecrets.cpp | 21 +- tests/secret/canonicalize.mlir | 20 +- 8 files changed, 420 insertions(+), 58 deletions(-) diff --git a/include/Dialect/Secret/IR/SecretOps.h b/include/Dialect/Secret/IR/SecretOps.h index b3a960d7e..78e8311dc 100644 --- a/include/Dialect/Secret/IR/SecretOps.h +++ b/include/Dialect/Secret/IR/SecretOps.h @@ -16,4 +16,29 @@ #define GET_OP_CLASSES #include "include/Dialect/Secret/IR/SecretOps.h.inc" +namespace mlir { +namespace heir { +namespace secret { + +// Extracts the given op from inside the generic body and lifting to a new +// single-op generic after the context generic op. This function assumes as a +// precondition that the opToExtract's results do not have any uses besides in +// the yield of the genericOp. The HoistOpAfterGeneric pattern tests for this +// precondition. +// +// Replaces `genericOp` with a new genericOp using `rewriter`, and returns +// the two newly created generic ops, with the first one being the replacement +// for the input `genericOp`, and the second one being the extracted genericOp. +// +// Handles adding the operands of opToExtract to the yielded values of the +// generic. The new yields may not be needed, and this can be cleaned up by +// canonicalize, or a manual application of DedupeYieldedValues and +// RemoveUnusedYieldedValues. +std::pair extractOpAfterGeneric( + GenericOp genericOp, Operation *opToExtract, PatternRewriter &rewriter); + +} // namespace secret +} // namespace heir +} // namespace mlir + #endif // HEIR_INCLUDE_DIALECT_SECRET_IR_SECRETOPS_H_ diff --git a/include/Dialect/Secret/IR/SecretOps.td b/include/Dialect/Secret/IR/SecretOps.td index 93e602cc5..eedf63a39 100644 --- a/include/Dialect/Secret/IR/SecretOps.td +++ b/include/Dialect/Secret/IR/SecretOps.td @@ -171,7 +171,8 @@ def Secret_GenericOp : Secret_Op<"generic", [ // Clones a generic op and adds new yielded values. Returns the new op and // the value range corresponding to the new result values of the generic. // Callers can follow this method with something like the following to - // replace the current generic op with the result of this method. + // replace the current generic op with the result of this method. Always + // adds the new yielded values to the end of the list of yielded values. // // auto [modifiedGeneric, newResults] = // genericOp.addNewYieldedValues(newResults, rewriter); @@ -203,6 +204,23 @@ def Secret_GenericOp : Secret_Op<"generic", [ ArrayRef yieldedIndicesToRemove, PatternRewriter &rewriter, SmallVector &remainingResults); + + // Modifies a GenericOp in place by taking the given op inside the generic + // body and lifting it into a new single-op generic before the context + // generic op. Returns the newly created GenericOp. + // + // For extractOpAfterGeneric, see SecretOps.h (it's a non-member function). + GenericOp extractOpBeforeGeneric( + Operation *opToExtract, PatternRewriter &rewriter); + + // Inlines the GenericOp in place, dropping any secret types involved. + // Extra `operands` argument allows a conversion pattern to pass + // adaptor.getOperands(). + void inlineInPlaceDroppingSecrets(PatternRewriter &rewriter, ValueRange operands); + + void inlineInPlaceDroppingSecrets(PatternRewriter &rewriter) { + inlineInPlaceDroppingSecrets(rewriter, getOperands()); + } }]; let hasCanonicalizer = 1; diff --git a/include/Dialect/Secret/IR/SecretPatterns.h b/include/Dialect/Secret/IR/SecretPatterns.h index 2d397dc4a..9ce242e7c 100644 --- a/include/Dialect/Secret/IR/SecretPatterns.h +++ b/include/Dialect/Secret/IR/SecretPatterns.h @@ -1,6 +1,8 @@ #ifndef INCLUDE_DIALECT_SECRET_IR_SECRETPATTERNS_H_ #define INCLUDE_DIALECT_SECRET_IR_SECRETPATTERNS_H_ +#include + #include "include/Dialect/Secret/IR/SecretOps.h" #include "include/Dialect/Secret/IR/SecretTypes.h" #include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project @@ -156,6 +158,49 @@ struct DedupeYieldedValues : public OpRewritePattern { PatternRewriter &rewriter) const override; }; +// Hoist an op out of a generic and place it before the generic (in a new +// generic block), if possible. This will be impossible if one of the op's +// operands depends on another SSA value defined by an op inside the +// generic. +// +// Accepts a list of op names to hoist. +struct HoistOpBeforeGeneric : public OpRewritePattern { + HoistOpBeforeGeneric(mlir::MLIRContext *context, + std::vector opTypes) + : OpRewritePattern(context, /*benefit=*/1), + opTypes(std::move(opTypes)) {} + + public: + LogicalResult matchAndRewrite(GenericOp op, + PatternRewriter &rewriter) const override; + + bool canHoist(Operation &op) const; + + private: + std::vector opTypes; +}; + +// Hoist an op out of a generic and place it after the generic (in a new +// generic block), if possible. This will be impossible if one of the op's +// results is used by another op inside the generic before the yield. +// +// Accepts a list of op names to hoist. +struct HoistOpAfterGeneric : public OpRewritePattern { + HoistOpAfterGeneric(mlir::MLIRContext *context, + std::vector opTypes) + : OpRewritePattern(context, /*benefit=*/1), + opTypes(std::move(opTypes)) {} + + public: + LogicalResult matchAndRewrite(GenericOp op, + PatternRewriter &rewriter) const override; + + bool canHoist(Operation &op) const; + + private: + std::vector opTypes; +}; + } // namespace secret } // namespace heir } // namespace mlir diff --git a/lib/Dialect/Secret/IR/SecretOps.cpp b/lib/Dialect/Secret/IR/SecretOps.cpp index 712aaa7de..a04e48da0 100644 --- a/lib/Dialect/Secret/IR/SecretOps.cpp +++ b/lib/Dialect/Secret/IR/SecretOps.cpp @@ -11,6 +11,7 @@ #include "llvm/include/llvm/ADT/STLExtras.h" // from @llvm-project #include "llvm/include/llvm/ADT/SmallVector.h" // from @llvm-project #include "llvm/include/llvm/Support/Casting.h" // from @llvm-project +#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project #include "mlir/include/mlir/IR/Attributes.h" // from @llvm-project #include "mlir/include/mlir/IR/Block.h" // from @llvm-project #include "mlir/include/mlir/IR/Builders.h" // from @llvm-project @@ -27,6 +28,8 @@ #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#define DEBUG_TYPE "secret-ops" + namespace mlir { namespace heir { namespace secret { @@ -251,6 +254,7 @@ OpFoldResult CastOp::fold(CastOp::FoldAdaptor adaptor) { } OpOperand *GenericOp::getOpOperandForBlockArgument(Value value) { + // FIXME: why can't I just dyn_cast the Value to a BlockArgument? auto *body = getBody(); int index = std::find(body->getArguments().begin(), body->getArguments().end(), value) - @@ -368,6 +372,205 @@ GenericOp GenericOp::removeYieldedValues(ArrayRef yieldedIndicesToRemove, return cloneWithNewTypes(*this, newResultTypes, rewriter); } +GenericOp GenericOp::extractOpBeforeGeneric(Operation *opToExtract, + PatternRewriter &rewriter) { + assert(opToExtract->getParentOp() == *this); + + // Result types are secret versions of the results of the op, since the + // secret will yield all of this op's results immediately. + SmallVector newResultTypes; + newResultTypes.reserve(opToExtract->getNumResults()); + for (Type ty : opToExtract->getResultTypes()) { + newResultTypes.push_back(SecretType::get(ty)); + } + + auto newGeneric = rewriter.create( + getLoc(), getInputs(), newResultTypes, + [&](OpBuilder &b, Location loc, ValueRange blockArguments) { + IRMapping mp; + for (BlockArgument blockArg : getBody()->getArguments()) { + mp.map(blockArg, blockArguments[blockArg.getArgNumber()]); + } + auto *newOp = b.clone(*opToExtract, mp); + b.create(loc, newOp->getResults()); + }); + + // Once the op is split off into a new generic op, we need to add new + // operands to the old generic op, add new corresponding block arguments, and + // replace all uses of the opToExtract's results with the created block + // arguments. + SmallVector oldGenericNewBlockArgs; + rewriter.modifyOpInPlace(*this, [&]() { + getInputsMutable().append(newGeneric.getResults()); + for (auto ty : opToExtract->getResultTypes()) { + BlockArgument arg = getBody()->addArgument(ty, opToExtract->getLoc()); + oldGenericNewBlockArgs.push_back(arg); + } + }); + rewriter.replaceOp(opToExtract, oldGenericNewBlockArgs); + + return newGeneric; +} + +// When replacing a generic op with a new one, and given an op in the original +// generic op, find the corresponding op in the new generic op. +// +// Note, this is brittle and depends on the two generic ops having identical +// copies of the same ops in the same order. +Operation *findCorrespondingOp(GenericOp oldGenericOp, GenericOp newGenericOp, + Operation *op) { + assert(oldGenericOp.getBody()->getOperations().size() == + newGenericOp.getBody()->getOperations().size() && + "findCorrespondingOp requires both oldGenericOp and newGenericOp have " + "the same size"); + for (auto [oldOp, newOp] : + llvm::zip(oldGenericOp.getBody()->getOperations(), + newGenericOp.getBody()->getOperations())) { + if (&oldOp == op) { + assert(oldOp.getName() == newOp.getName() && + "Expected corresponding op to be the same type in old and new " + "generic"); + return &newOp; + } + } + llvm_unreachable( + "findCorrespondingOp used but no corresponding op was found"); + return nullptr; +} + +std::pair extractOpAfterGeneric( + GenericOp genericOp, Operation *opToExtract, PatternRewriter &rewriter) { + assert(opToExtract->getParentOp() == genericOp); + [[maybe_unused]] auto *parent = genericOp->getParentOp(); + + LLVM_DEBUG({ + llvm::dbgs() << "At start of extracting op after generic:\n"; + parent->dump(); + }); + // The new yields may not always be needed, and this can be cleaned up by + // canonicalize, or a manual application of DedupeYieldedValues and + // RemoveUnusedYieldedValues. + auto result = + genericOp.addNewYieldedValues(opToExtract->getOperands(), rewriter); + // Can't do structured assignment of pair above, because clang fails to + // compile the usage of these values in the closure below. + // (https://stackoverflow.com/a/46115028/438830). + GenericOp genericOpWithNewYields = result.first; + ValueRange newResults = result.second; + // Keep track of the opToExtract in the new generic. + opToExtract = + findCorrespondingOp(genericOp, genericOpWithNewYields, opToExtract); + rewriter.replaceOp(genericOp, + ValueRange(genericOpWithNewYields.getResults().drop_back( + newResults.size()))); + LLVM_DEBUG({ + llvm::dbgs() << "After adding new yielded values:\n"; + parent->dump(); + llvm::dbgs() << "opToExtract is now in:\n"; + opToExtract->getParentOp()->dump(); + }); + + SmallVector newGenericOperands; + newGenericOperands.reserve(opToExtract->getNumOperands()); + for (auto operand : opToExtract->getOperands()) { + // If the yielded value is a block argument or ambient, we can just use the + // original SSA value. + auto blockArg = operand.dyn_cast(); + bool isBlockArgOfGeneric = + blockArg && blockArg.getOwner() == genericOpWithNewYields.getBody(); + bool isAmbient = + (blockArg && blockArg.getOwner() != genericOpWithNewYields.getBody()) || + (!blockArg && operand.getDefiningOp()->getBlock() != + genericOpWithNewYields.getBody()); + if (isBlockArgOfGeneric) { + newGenericOperands.push_back( + genericOpWithNewYields.getOperand(blockArg.getArgNumber())); + continue; + } + if (isAmbient) { + newGenericOperands.push_back(operand); + continue; + } + + // Otherwise, find the corresponding result of the generic op + auto yieldOperands = genericOpWithNewYields.getYieldOp().getOperands(); + int resultIndex = + std::find(yieldOperands.begin(), yieldOperands.end(), operand) - + yieldOperands.begin(); + newGenericOperands.push_back(genericOpWithNewYields.getResult(resultIndex)); + } + + // Result types are secret versions of the results of the op, since the + // secret will yield all of this op's results immediately. + SmallVector newResultTypes; + newResultTypes.reserve(opToExtract->getNumResults()); + for (Type ty : opToExtract->getResultTypes()) { + newResultTypes.push_back(SecretType::get(ty)); + } + + rewriter.setInsertionPointAfter(genericOpWithNewYields); + auto newGeneric = rewriter.create( + genericOpWithNewYields.getLoc(), newGenericOperands, newResultTypes, + [&](OpBuilder &b, Location loc, ValueRange blockArguments) { + IRMapping mp; + int i = 0; + for (Value operand : opToExtract->getOperands()) { + mp.map(operand, blockArguments[i]); + ++i; + } + auto *newOp = b.clone(*opToExtract, mp); + b.create(loc, newOp->getResults()); + }); + LLVM_DEBUG({ + llvm::dbgs() << "After adding new single-op generic:\n"; + parent->dump(); + }); + + // Once the op is split off into a new generic op, we need to erase + // the old op and remove its results from the yield op. + rewriter.setInsertionPointAfter(genericOpWithNewYields); + SmallVector remainingResults; + auto replacedGeneric = genericOpWithNewYields.removeYieldedValues( + opToExtract->getResults(), rewriter, remainingResults); + // Keep track of the opToExtract in the new generic. + opToExtract = + findCorrespondingOp(genericOpWithNewYields, replacedGeneric, opToExtract); + rewriter.replaceAllUsesWith(remainingResults, replacedGeneric.getResults()); + rewriter.eraseOp(genericOpWithNewYields); + rewriter.eraseOp(opToExtract); + LLVM_DEBUG({ + llvm::dbgs() << "After removing opToExtract from old generic:\n"; + parent->dump(); + }); + + return std::pair{replacedGeneric, newGeneric}; +} + +void GenericOp::inlineInPlaceDroppingSecrets(PatternRewriter &rewriter, + ValueRange operands) { + GenericOp &op = *this; + Block *originalBlock = op->getBlock(); + Block &opEntryBlock = op.getRegion().front(); + YieldOp yieldOp = dyn_cast(op.getRegion().back().getTerminator()); + + // Inline the op's (unique) block, including the yield op. This also + // requires splitting the parent block of the generic op, so that we have a + // clear insertion point for inlining. + Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op)); + rewriter.inlineRegionBefore(op.getRegion(), newBlock); + + // Now that op's region is inlined, the operands of its YieldOp are mapped + // to the materialized target values. Therefore, we can replace the op's + // uses with those of its YieldOp's operands. + rewriter.replaceOp(op, yieldOp->getOperands()); + + // No need for these intermediate blocks, merge them into 1. + rewriter.mergeBlocks(&opEntryBlock, originalBlock, operands); + rewriter.mergeBlocks(newBlock, originalBlock, {}); + + rewriter.eraseOp(yieldOp); +} + void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { results.addgetArguments().size(); ++i) { BlockArgument arg = body->getArguments()[i]; if (arg.use_empty()) { + LLVM_DEBUG(llvm::dbgs() << arg << " has no uses; removing\n"); hasUnusedOps = true; rewriter.modifyOpInPlace(op, [&]() { body->eraseArgument(i); @@ -99,6 +100,36 @@ LogicalResult RemoveUnusedGenericArgs::matchAndRewrite( }); // Ensure the next iteration uses the right arg number --i; + } else if (llvm::any_of(arg.getUsers(), [&](Operation *user) { + return llvm::isa(user); + })) { + LLVM_DEBUG(llvm::dbgs() << arg << " is passed through to yield\n"); + // In this case, the arg is passed through to the yield, and the yield + // can be removed and replaced with the operand. Note we don't need to + // remove the block argument itself since a subsequent iteration of this + // pattern will detect if that is possible (if it has no other uses). + Value replacementValue = op.getOperand(arg.getArgNumber()); + SmallVector yieldedValuesToRemove; + SmallVector resultsToReplace; + SmallVector replacementValues; + + for (auto &opOperand : op.getYieldOp()->getOpOperands()) { + if (opOperand.get() == arg) { + yieldedValuesToRemove.push_back(opOperand.get()); + resultsToReplace.push_back( + op.getResult(opOperand.getOperandNumber())); + replacementValues.push_back(replacementValue); + } + } + rewriter.replaceAllUsesWith(resultsToReplace, replacementValues); + + SmallVector remainingResults; + auto modifiedGeneric = op.removeYieldedValues(yieldedValuesToRemove, + rewriter, remainingResults); + rewriter.replaceAllUsesWith(remainingResults, + modifiedGeneric.getResults()); + rewriter.eraseOp(op); + return success(); } } @@ -407,6 +438,82 @@ LogicalResult DedupeYieldedValues::matchAndRewrite( return success(); } +bool HoistOpBeforeGeneric::canHoist(Operation &op) const { + bool inConfiguredList = + std::find(opTypes.begin(), opTypes.end(), op.getName().getStringRef()) != + opTypes.end(); + bool allOperandsAreBlockArgsOrAmbient = + llvm::all_of(op.getOperands(), [&](Value operand) { + return isa(operand) || + operand.getDefiningOp()->getBlock() != op.getBlock(); + }); + return inConfiguredList && allOperandsAreBlockArgsOrAmbient; +} + +LogicalResult HoistOpBeforeGeneric::matchAndRewrite( + GenericOp genericOp, PatternRewriter &rewriter) const { + auto &opRange = genericOp.getBody()->getOperations(); + if (opRange.size() <= 2) { + // This corresponds to a fixed point of the pattern: if an op is hoisted, + // it will be in a single-op generic, (yield is the second op), and if + // that triggers the pattern, it will be an infinite loop. + return failure(); + } + + auto it = std::find_if(opRange.begin(), opRange.end(), + [&](Operation &op) { return canHoist(op); }); + if (it == opRange.end()) { + return failure(); + } + + Operation *opToHoist = &*it; + LLVM_DEBUG(llvm::dbgs() << "Hoisting " << *opToHoist << "\n"); + genericOp.extractOpBeforeGeneric(opToHoist, rewriter); + LLVM_DEBUG({ + Operation *parent = genericOp->getParentOp(); + llvm::dbgs() << "After hoisting op\n"; + parent->dump(); + }); + return success(); +} + +bool HoistOpAfterGeneric::canHoist(Operation &op) const { + bool inConfiguredList = + std::find(opTypes.begin(), opTypes.end(), op.getName().getStringRef()) != + opTypes.end(); + bool allUsesAreYields = llvm::all_of( + op.getUsers(), [&](Operation *user) { return isa(user); }); + return inConfiguredList && allUsesAreYields; +} + +LogicalResult HoistOpAfterGeneric::matchAndRewrite( + GenericOp genericOp, PatternRewriter &rewriter) const { + auto &opRange = genericOp.getBody()->getOperations(); + if (opRange.size() <= 2) { + // This corresponds to a fixed point of the pattern: if an op is hoisted, + // it will be in a single-op generic, (yield is the second op), and if + // that triggers the pattern, it will be an infinite loop. + return failure(); + } + + auto it = std::find_if(opRange.begin(), opRange.end(), + [&](Operation &op) { return canHoist(op); }); + if (it == opRange.end()) { + return failure(); + } + + Operation *opToHoist = &*it; + LLVM_DEBUG(llvm::dbgs() << "Hoisting " << *opToHoist << "\n"); + + extractOpAfterGeneric(genericOp, opToHoist, rewriter); + LLVM_DEBUG({ + Operation *parent = genericOp->getParentOp(); + llvm::dbgs() << "After hoisting op\n"; + parent->dump(); + }); + return success(); +} + } // namespace secret } // namespace heir } // namespace mlir diff --git a/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp b/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp index c7cc5b6ae..0cee2550c 100644 --- a/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp +++ b/lib/Dialect/Secret/Transforms/DistributeGeneric.cpp @@ -429,43 +429,8 @@ struct SplitGeneric : public OpRewritePattern { PatternRewriter &rewriter) const { Operation &firstOp = genericOp.getBody()->front(); LLVM_DEBUG(firstOp.emitRemark() << " splitting generic after this op\n"); - - // Result types are secret versions of the results of the op, since the - // secret will yield all of this op's results immediately. - SmallVector newResultTypes; - newResultTypes.reserve(firstOp.getNumResults()); - for (Type ty : firstOp.getResultTypes()) { - newResultTypes.push_back(SecretType::get(ty)); - } - - auto newGeneric = rewriter.create( - genericOp.getLoc(), genericOp.getInputs(), newResultTypes, - [&](OpBuilder &b, Location loc, ValueRange blockArguments) { - IRMapping mp; - for (BlockArgument blockArg : genericOp.getBody()->getArguments()) { - mp.map(blockArg, blockArguments[blockArg.getArgNumber()]); - } - auto *newOp = b.clone(firstOp, mp); - b.create(loc, newOp->getResults()); - }); - + auto newGeneric = genericOp.extractOpBeforeGeneric(&firstOp, rewriter); LLVM_DEBUG(newGeneric.emitRemark() << " created new generic op\n"); - - // Once the op is split off into a new generic op, we need to add new - // operands to the old generic op, add new corresponding block arguments, - // and replace all uses of the opToDistribute's results with the created - // block arguments. - SmallVector oldGenericNewBlockArgs; - rewriter.modifyOpInPlace(genericOp, [&]() { - genericOp.getInputsMutable().append(newGeneric.getResults()); - for (auto ty : firstOp.getResultTypes()) { - BlockArgument arg = - genericOp.getBody()->addArgument(ty, firstOp.getLoc()); - oldGenericNewBlockArgs.push_back(arg); - } - }); - rewriter.replaceOp(&firstOp, oldGenericNewBlockArgs); - return newGeneric; } diff --git a/lib/Dialect/Secret/Transforms/ForgetSecrets.cpp b/lib/Dialect/Secret/Transforms/ForgetSecrets.cpp index e67d63c42..f0d0b7663 100644 --- a/lib/Dialect/Secret/Transforms/ForgetSecrets.cpp +++ b/lib/Dialect/Secret/Transforms/ForgetSecrets.cpp @@ -75,26 +75,7 @@ struct ConvertGeneric : public OpConversionPattern { LogicalResult matchAndRewrite( GenericOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Block *originalBlock = op->getBlock(); - Block &opEntryBlock = op.getRegion().front(); - YieldOp yieldOp = dyn_cast(op.getRegion().back().getTerminator()); - - // Inline the op's (unique) block, including the yield op. This also - // requires splitting the parent block of the generic op, so that we have a - // clear insertion point for inlining. - Block *newBlock = rewriter.splitBlock(originalBlock, Block::iterator(op)); - rewriter.inlineRegionBefore(op.getRegion(), newBlock); - - // Now that op's region is inlined, the operands of its YieldOp are mapped - // to the materialized target values. Therefore, we can replace the op's - // uses with those of its YieldOp's operands. - rewriter.replaceOp(op, yieldOp->getOperands()); - - // No need for these intermediate blocks, merge them into 1. - rewriter.mergeBlocks(&opEntryBlock, originalBlock, adaptor.getOperands()); - rewriter.mergeBlocks(newBlock, originalBlock, {}); - - rewriter.eraseOp(yieldOp); + op.inlineInPlaceDroppingSecrets(rewriter, adaptor.getOperands()); return success(); } }; diff --git a/tests/secret/canonicalize.mlir b/tests/secret/canonicalize.mlir index 5af109363..5d49a0a23 100644 --- a/tests/secret/canonicalize.mlir +++ b/tests/secret/canonicalize.mlir @@ -12,5 +12,23 @@ func.func @remove_unused_yielded_values(%arg0: !secret.secret) -> !secret.s // CHECK: secret.yield %[[value:.*]] : i32 secret.yield %d, %unused : i32, i32 } -> (!secret.secret, !secret.secret) - func.return %Z : !secret.secret + return %Z : !secret.secret +} + +// CHECK-LABEL: func @remove_pass_through_args +func.func @remove_pass_through_args( +// CHECK: %[[arg1:.*]]: !secret.secret, %[[arg2:.*]]: !secret.secret + %arg1 : !secret.secret, %arg2 : !secret.secret) -> (!secret.secret, !secret.secret) { + // CHECK: %[[out1:.*]] = secret.generic + %out1, %out2 = secret.generic + ins(%arg1, %arg2 : !secret.secret, !secret.secret) { + ^bb0(%x: i32, %y: i32) : + // CHECK: %[[value:.*]] = arith.addi + %z = arith.addi %x, %y : i32 + // Only yield one value + // CHECK: secret.yield %[[value]] : i32 + secret.yield %z, %y : i32, i32 + } -> (!secret.secret, !secret.secret) + // CHECK: return %[[out1]], %[[arg2]] + return %out1, %out2 : !secret.secret, !secret.secret }