Skip to content

Commit

Permalink
[StableHLO] Add canonicalization patterns for compare (iree-org#13760)
Browse files Browse the repository at this point in the history
  • Loading branch information
kuhar authored and NatashaKnk committed Jul 6, 2023
1 parent 25e0eeb commit 0aaa30b
Show file tree
Hide file tree
Showing 2 changed files with 254 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.h"
#include "iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/CommonFolders.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
Expand Down Expand Up @@ -236,6 +238,151 @@ struct MulOpCanon final : OpRewritePattern<mlir::stablehlo::MulOp> {
}
};

static mlir::stablehlo::ComparisonDirection invertDirection(
mlir::stablehlo::ComparisonDirection direction) {
using mlir::stablehlo::ComparisonDirection;

switch (direction) {
case ComparisonDirection::EQ:
return ComparisonDirection::EQ;
case ComparisonDirection::GE:
return ComparisonDirection::LE;
case ComparisonDirection::LE:
return ComparisonDirection::GE;
case ComparisonDirection::GT:
return ComparisonDirection::LT;
case ComparisonDirection::LT:
return ComparisonDirection::GT;
case ComparisonDirection::NE:
return ComparisonDirection::NE;
}

llvm_unreachable("Unhandled case");
}

static APInt calculateComp(mlir::stablehlo::ComparisonType kind,
mlir::stablehlo::ComparisonDirection direction,
const APInt &lhs, const APInt &rhs) {
using mlir::stablehlo::ComparisonDirection;
using mlir::stablehlo::ComparisonType;
assert(llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED},
kind) &&
"Not an integer comparison");

auto asBit = [](bool value) {
return value ? APInt::getAllOnes(1) : APInt::getZero(1);
};

// Signed comparison.
if (kind == ComparisonType::SIGNED) {
switch (direction) {
case ComparisonDirection::EQ:
return asBit(lhs == rhs);
case ComparisonDirection::GE:
return asBit(lhs.sge(rhs));
case ComparisonDirection::GT:
return asBit(lhs.sgt(rhs));
case ComparisonDirection::LE:
return asBit(lhs.sle(rhs));
case ComparisonDirection::LT:
return asBit(lhs.slt(rhs));
case ComparisonDirection::NE:
return asBit(lhs != rhs);
}
}

// Unsigned comparison.
switch (direction) {
case ComparisonDirection::EQ:
return asBit(lhs == rhs);
case ComparisonDirection::GE:
return asBit(lhs.uge(rhs));
case ComparisonDirection::GT:
return asBit(lhs.ugt(rhs));
case ComparisonDirection::LE:
return asBit(lhs.ule(rhs));
case ComparisonDirection::LT:
return asBit(lhs.ult(rhs));
case ComparisonDirection::NE:
return asBit(lhs != rhs);
}

llvm_unreachable("Unhandled case");
}

struct CompareOpCanon final : OpRewritePattern<mlir::stablehlo::CompareOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(mlir::stablehlo::CompareOp op,
PatternRewriter &rewriter) const override {
auto type = dyn_cast<RankedTensorType>(op.getType());
if (!type) return failure();

// Bail out on non-integer comparison.
// TODO: Support more comparison types.
using mlir::stablehlo::ComparisonType;
std::optional<ComparisonType> compType = op.getCompareType();
if (!compType ||
!llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED},
*compType)) {
return failure();
}

using mlir::stablehlo::ComparisonDirection;
ComparisonDirection direction = op.getComparisonDirection();
Value lhs = op.getLhs();
Value rhs = op.getRhs();

if (lhs == rhs) {
switch (direction) {
case ComparisonDirection::EQ:
case ComparisonDirection::GE:
case ComparisonDirection::LE: {
rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true)));
return success();
}
case ComparisonDirection::GT:
case ComparisonDirection::LT:
case ComparisonDirection::NE: {
rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(
op, rewriter.getZeroAttr(type));
return success();
}
}
llvm_unreachable("Unhandled case");
}

TypedAttr lhsAttr;
matchPattern(lhs, m_Constant(&lhsAttr));

TypedAttr rhsAttr;
matchPattern(rhs, m_Constant(&rhsAttr));

// The canonical form has the constant operand as the RHS.
if (lhsAttr && !rhsAttr) {
rewriter.updateRootInPlace(op, [&op, direction, lhs, rhs] {
op.setComparisonDirection(invertDirection(direction));
op->setOperands(ValueRange{rhs, lhs});
});
return success();
}

if (lhsAttr && rhsAttr) {
if (Attribute res = constFoldBinaryOp<IntegerAttr>(
ArrayRef<Attribute>({lhsAttr, rhsAttr}), op.getType(),
[direction, kind = *compType](const APInt &a, const APInt &b) {
return calculateComp(kind, direction, a, b);
})) {
rewriter.replaceOpWithNewOp<mlir::stablehlo::ConstantOp>(op, res);
return success();
}
}

return failure();
}
};

struct BroadcastInDimOpCanon final
: OpRewritePattern<mlir::stablehlo::BroadcastInDimOp> {
using OpRewritePattern::OpRewritePattern;
Expand Down Expand Up @@ -499,10 +646,16 @@ struct StableHLOCanonicalize final
void populateCanonicalizationPatterns(MLIRContext *context,
RewritePatternSet *patterns,
PatternBenefit benefit) {
patterns->add<AddOpCanon, SubtractOpCanon, MulOpCanon, BroadcastInDimOpCanon,
ConcatenateOpCanon, ConvertOpCanon, DynamicReshapeOpCanon,
GetTupleElementOpCanon, RealOpCanon, ImagOpCanon,
GetDimensionSizeOpCanon, ReshapeOpCanon, TransposeOpCanon>(
context, benefit);
patterns->add<
// Arithmetic ops.
AddOpCanon, SubtractOpCanon, MulOpCanon, CompareOpCanon,
// Complex ops.
RealOpCanon, ImagOpCanon,
// Query ops.
GetDimensionSizeOpCanon, GetTupleElementOpCanon,
// Shape manipulation(-ish) ops.
BroadcastInDimOpCanon, ConcatenateOpCanon, ConvertOpCanon,
DynamicReshapeOpCanon, ReshapeOpCanon, TransposeOpCanon>(context,
benefit);
}
} // namespace mlir::iree_compiler::stablehlo
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,102 @@ func.func @multiply(%arg0: tensor<2xi32>, %arg1: tensor<f32>)

// -----

// CHECK-LABEL: func.func @compare_signed_arg
// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>)
func.func @compare_signed_arg(%arg0: tensor<i32>)
-> (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) {
%c0 = stablehlo.constant dense<0> : tensor<i32>
%c4 = stablehlo.constant dense<4> : tensor<i32>
%c5 = stablehlo.constant dense<5> : tensor<i32>

%0 = stablehlo.compare EQ, %arg0, %arg0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%1 = stablehlo.compare GT, %arg0, %arg0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%2 = stablehlo.compare LE, %arg0, %arg0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%3 = stablehlo.compare NE, %arg0, %arg0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>

%4 = stablehlo.compare EQ, %c5, %arg0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%5 = stablehlo.compare LT, %c5, %arg0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%6 = stablehlo.compare GE, %c5, %arg0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%7 = stablehlo.compare NE, %c5, %arg0, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>

// CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense<false> : tensor<i1>
// CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense<true> : tensor<i1>
// CHECK-DAG: [[C5:%.+]] = stablehlo.constant dense<5> : tensor<i32>

// CHECK-DAG: [[R0:%.+]] = stablehlo.compare EQ, [[ARG0]], [[C5]], SIGNED
// CHECK-DAG: [[R1:%.+]] = stablehlo.compare GT, [[ARG0]], [[C5]], SIGNED
// CHECK-DAG: [[R2:%.+]] = stablehlo.compare LE, [[ARG0]], [[C5]], SIGNED
// CHECK-DAG: [[R3:%.+]] = stablehlo.compare NE, [[ARG0]], [[C5]], SIGNED

// CHECK-NEXT: return [[C1]], [[C0]], [[C1]], [[C0]], [[R0]], [[R1]], [[R2]], [[R3]]
return %0, %1, %2, %3, %4, %5, %6, %7 :
tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
}

// -----

// CHECK-LABEL: func.func @compare_unsigned_arg
// CHECK-SAME: ([[ARG0:%.+]]: tensor<i32>)
func.func @compare_unsigned_arg(%arg0: tensor<i32>)
-> (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) {
%c0 = stablehlo.constant dense<0> : tensor<i32>
%c4 = stablehlo.constant dense<4> : tensor<i32>
%c5 = stablehlo.constant dense<5> : tensor<i32>

%0 = stablehlo.compare EQ, %arg0, %arg0, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%1 = stablehlo.compare GT, %arg0, %arg0, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%2 = stablehlo.compare LE, %arg0, %arg0, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%3 = stablehlo.compare NE, %arg0, %arg0, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>

%4 = stablehlo.compare EQ, %c5, %arg0, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%5 = stablehlo.compare LT, %c5, %arg0, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%6 = stablehlo.compare GE, %c5, %arg0, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%7 = stablehlo.compare NE, %c5, %arg0, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>

// CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense<false> : tensor<i1>
// CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense<true> : tensor<i1>
// CHECK-DAG: [[C5:%.+]] = stablehlo.constant dense<5> : tensor<i32>

// CHECK-DAG: [[R0:%.+]] = stablehlo.compare EQ, [[ARG0]], [[C5]], UNSIGNED
// CHECK-DAG: [[R1:%.+]] = stablehlo.compare GT, [[ARG0]], [[C5]], UNSIGNED
// CHECK-DAG: [[R2:%.+]] = stablehlo.compare LE, [[ARG0]], [[C5]], UNSIGNED
// CHECK-DAG: [[R3:%.+]] = stablehlo.compare NE, [[ARG0]], [[C5]], UNSIGNED

// CHECK-NEXT: return [[C1]], [[C0]], [[C1]], [[C0]], [[R0]], [[R1]], [[R2]], [[R3]]
return %0, %1, %2, %3, %4, %5, %6, %7 :
tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
}

// -----

// CHECK-LABEL: func.func @compare_folds
func.func @compare_folds()
-> (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) {
%cn1 = stablehlo.constant dense<-1> : tensor<i32>
%c0 = stablehlo.constant dense<0> : tensor<i32>
%c4 = stablehlo.constant dense<4> : tensor<i32>
%c5 = stablehlo.constant dense<5> : tensor<i32>

%0 = stablehlo.compare EQ, %cn1, %cn1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%1 = stablehlo.compare GT, %c5, %c5, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%2 = stablehlo.compare GE, %c4, %cn1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%3 = stablehlo.compare LE, %c4, %c5, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>

%4 = stablehlo.compare EQ, %cn1, %cn1, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%5 = stablehlo.compare GT, %c5, %cn1, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%6 = stablehlo.compare GE, %c5, %c4, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%7 = stablehlo.compare LE, %cn1, %c5, UNSIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>

// CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense<false> : tensor<i1>
// CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense<true> : tensor<i1>

// CHECK-NEXT: return [[C1]], [[C0]], [[C1]], [[C1]], [[C1]], [[C0]], [[C1]], [[C0]]
return %0, %1, %2, %3, %4, %5, %6, %7 :
tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>
}

// -----

// CHECK-LABEL: func.func @broadcast_in_dim
// CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>)
func.func @broadcast_in_dim(%arg0: tensor<3x3xi32>)
Expand Down

0 comments on commit 0aaa30b

Please sign in to comment.