diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp index 98afeaa7972e..abacad8c8b1d 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/Canonicalization.cpp @@ -13,6 +13,8 @@ #include "iree/compiler/InputConversion/StableHLO/Preprocessing/Passes.h" #include "iree/compiler/InputConversion/StableHLO/Preprocessing/Rewriters.h" #include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Attributes.h" @@ -236,6 +238,151 @@ struct MulOpCanon final : OpRewritePattern { } }; +static mlir::stablehlo::ComparisonDirection invertDirection( + mlir::stablehlo::ComparisonDirection direction) { + using mlir::stablehlo::ComparisonDirection; + + switch (direction) { + case ComparisonDirection::EQ: + return ComparisonDirection::EQ; + case ComparisonDirection::GE: + return ComparisonDirection::LE; + case ComparisonDirection::LE: + return ComparisonDirection::GE; + case ComparisonDirection::GT: + return ComparisonDirection::LT; + case ComparisonDirection::LT: + return ComparisonDirection::GT; + case ComparisonDirection::NE: + return ComparisonDirection::NE; + } + + llvm_unreachable("Unhandled case"); +} + +static APInt calculateComp(mlir::stablehlo::ComparisonType kind, + mlir::stablehlo::ComparisonDirection direction, + const APInt &lhs, const APInt &rhs) { + using mlir::stablehlo::ComparisonDirection; + using mlir::stablehlo::ComparisonType; + assert(llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED}, + kind) && + "Not an integer comparison"); + + auto asBit = [](bool value) { + return value ? APInt::getAllOnes(1) : APInt::getZero(1); + }; + + // Signed comparison. + if (kind == ComparisonType::SIGNED) { + switch (direction) { + case ComparisonDirection::EQ: + return asBit(lhs == rhs); + case ComparisonDirection::GE: + return asBit(lhs.sge(rhs)); + case ComparisonDirection::GT: + return asBit(lhs.sgt(rhs)); + case ComparisonDirection::LE: + return asBit(lhs.sle(rhs)); + case ComparisonDirection::LT: + return asBit(lhs.slt(rhs)); + case ComparisonDirection::NE: + return asBit(lhs != rhs); + } + } + + // Unsigned comparison. + switch (direction) { + case ComparisonDirection::EQ: + return asBit(lhs == rhs); + case ComparisonDirection::GE: + return asBit(lhs.uge(rhs)); + case ComparisonDirection::GT: + return asBit(lhs.ugt(rhs)); + case ComparisonDirection::LE: + return asBit(lhs.ule(rhs)); + case ComparisonDirection::LT: + return asBit(lhs.ult(rhs)); + case ComparisonDirection::NE: + return asBit(lhs != rhs); + } + + llvm_unreachable("Unhandled case"); +} + +struct CompareOpCanon final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::CompareOp op, + PatternRewriter &rewriter) const override { + auto type = dyn_cast(op.getType()); + if (!type) return failure(); + + // Bail out on non-integer comparison. + // TODO: Support more comparison types. + using mlir::stablehlo::ComparisonType; + std::optional compType = op.getCompareType(); + if (!compType || + !llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED}, + *compType)) { + return failure(); + } + + using mlir::stablehlo::ComparisonDirection; + ComparisonDirection direction = op.getComparisonDirection(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + if (lhs == rhs) { + switch (direction) { + case ComparisonDirection::EQ: + case ComparisonDirection::GE: + case ComparisonDirection::LE: { + rewriter.replaceOpWithNewOp( + op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true))); + return success(); + } + case ComparisonDirection::GT: + case ComparisonDirection::LT: + case ComparisonDirection::NE: { + rewriter.replaceOpWithNewOp( + op, rewriter.getZeroAttr(type)); + return success(); + } + } + llvm_unreachable("Unhandled case"); + } + + TypedAttr lhsAttr; + matchPattern(lhs, m_Constant(&lhsAttr)); + + TypedAttr rhsAttr; + matchPattern(rhs, m_Constant(&rhsAttr)); + + // The canonical form has the constant operand as the RHS. + if (lhsAttr && !rhsAttr) { + rewriter.updateRootInPlace(op, [&op, direction, lhs, rhs] { + op.setComparisonDirection(invertDirection(direction)); + op->setOperands(ValueRange{rhs, lhs}); + }); + return success(); + } + + if (lhsAttr && rhsAttr) { + if (Attribute res = constFoldBinaryOp( + ArrayRef({lhsAttr, rhsAttr}), op.getType(), + [direction, kind = *compType](const APInt &a, const APInt &b) { + return calculateComp(kind, direction, a, b); + })) { + rewriter.replaceOpWithNewOp(op, res); + return success(); + } + } + + return failure(); + } +}; + struct BroadcastInDimOpCanon final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -499,10 +646,16 @@ struct StableHLOCanonicalize final void populateCanonicalizationPatterns(MLIRContext *context, RewritePatternSet *patterns, PatternBenefit benefit) { - patterns->add( - context, benefit); + patterns->add< + // Arithmetic ops. + AddOpCanon, SubtractOpCanon, MulOpCanon, CompareOpCanon, + // Complex ops. + RealOpCanon, ImagOpCanon, + // Query ops. + GetDimensionSizeOpCanon, GetTupleElementOpCanon, + // Shape manipulation(-ish) ops. + BroadcastInDimOpCanon, ConcatenateOpCanon, ConvertOpCanon, + DynamicReshapeOpCanon, ReshapeOpCanon, TransposeOpCanon>(context, + benefit); } } // namespace mlir::iree_compiler::stablehlo diff --git a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir index be547321dd35..f9c55ffb89e7 100644 --- a/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir +++ b/compiler/src/iree/compiler/InputConversion/StableHLO/Preprocessing/test/canonicalization.mlir @@ -109,6 +109,102 @@ func.func @multiply(%arg0: tensor<2xi32>, %arg1: tensor) // ----- +// CHECK-LABEL: func.func @compare_signed_arg +// CHECK-SAME: ([[ARG0:%.+]]: tensor) +func.func @compare_signed_arg(%arg0: tensor) + -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) { + %c0 = stablehlo.constant dense<0> : tensor + %c4 = stablehlo.constant dense<4> : tensor + %c5 = stablehlo.constant dense<5> : tensor + + %0 = stablehlo.compare EQ, %arg0, %arg0, SIGNED : (tensor, tensor) -> tensor + %1 = stablehlo.compare GT, %arg0, %arg0, SIGNED : (tensor, tensor) -> tensor + %2 = stablehlo.compare LE, %arg0, %arg0, SIGNED : (tensor, tensor) -> tensor + %3 = stablehlo.compare NE, %arg0, %arg0, SIGNED : (tensor, tensor) -> tensor + + %4 = stablehlo.compare EQ, %c5, %arg0, SIGNED : (tensor, tensor) -> tensor + %5 = stablehlo.compare LT, %c5, %arg0, SIGNED : (tensor, tensor) -> tensor + %6 = stablehlo.compare GE, %c5, %arg0, SIGNED : (tensor, tensor) -> tensor + %7 = stablehlo.compare NE, %c5, %arg0, SIGNED : (tensor, tensor) -> tensor + + // CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense : tensor + // CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense : tensor + // CHECK-DAG: [[C5:%.+]] = stablehlo.constant dense<5> : tensor + + // CHECK-DAG: [[R0:%.+]] = stablehlo.compare EQ, [[ARG0]], [[C5]], SIGNED + // CHECK-DAG: [[R1:%.+]] = stablehlo.compare GT, [[ARG0]], [[C5]], SIGNED + // CHECK-DAG: [[R2:%.+]] = stablehlo.compare LE, [[ARG0]], [[C5]], SIGNED + // CHECK-DAG: [[R3:%.+]] = stablehlo.compare NE, [[ARG0]], [[C5]], SIGNED + + // CHECK-NEXT: return [[C1]], [[C0]], [[C1]], [[C0]], [[R0]], [[R1]], [[R2]], [[R3]] + return %0, %1, %2, %3, %4, %5, %6, %7 : + tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor +} + +// ----- + +// CHECK-LABEL: func.func @compare_unsigned_arg +// CHECK-SAME: ([[ARG0:%.+]]: tensor) +func.func @compare_unsigned_arg(%arg0: tensor) + -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) { + %c0 = stablehlo.constant dense<0> : tensor + %c4 = stablehlo.constant dense<4> : tensor + %c5 = stablehlo.constant dense<5> : tensor + + %0 = stablehlo.compare EQ, %arg0, %arg0, UNSIGNED : (tensor, tensor) -> tensor + %1 = stablehlo.compare GT, %arg0, %arg0, UNSIGNED : (tensor, tensor) -> tensor + %2 = stablehlo.compare LE, %arg0, %arg0, UNSIGNED : (tensor, tensor) -> tensor + %3 = stablehlo.compare NE, %arg0, %arg0, UNSIGNED : (tensor, tensor) -> tensor + + %4 = stablehlo.compare EQ, %c5, %arg0, UNSIGNED : (tensor, tensor) -> tensor + %5 = stablehlo.compare LT, %c5, %arg0, UNSIGNED : (tensor, tensor) -> tensor + %6 = stablehlo.compare GE, %c5, %arg0, UNSIGNED : (tensor, tensor) -> tensor + %7 = stablehlo.compare NE, %c5, %arg0, UNSIGNED : (tensor, tensor) -> tensor + + // CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense : tensor + // CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense : tensor + // CHECK-DAG: [[C5:%.+]] = stablehlo.constant dense<5> : tensor + + // CHECK-DAG: [[R0:%.+]] = stablehlo.compare EQ, [[ARG0]], [[C5]], UNSIGNED + // CHECK-DAG: [[R1:%.+]] = stablehlo.compare GT, [[ARG0]], [[C5]], UNSIGNED + // CHECK-DAG: [[R2:%.+]] = stablehlo.compare LE, [[ARG0]], [[C5]], UNSIGNED + // CHECK-DAG: [[R3:%.+]] = stablehlo.compare NE, [[ARG0]], [[C5]], UNSIGNED + + // CHECK-NEXT: return [[C1]], [[C0]], [[C1]], [[C0]], [[R0]], [[R1]], [[R2]], [[R3]] + return %0, %1, %2, %3, %4, %5, %6, %7 : + tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor +} + +// ----- + +// CHECK-LABEL: func.func @compare_folds +func.func @compare_folds() + -> (tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) { + %cn1 = stablehlo.constant dense<-1> : tensor + %c0 = stablehlo.constant dense<0> : tensor + %c4 = stablehlo.constant dense<4> : tensor + %c5 = stablehlo.constant dense<5> : tensor + + %0 = stablehlo.compare EQ, %cn1, %cn1, SIGNED : (tensor, tensor) -> tensor + %1 = stablehlo.compare GT, %c5, %c5, SIGNED : (tensor, tensor) -> tensor + %2 = stablehlo.compare GE, %c4, %cn1, SIGNED : (tensor, tensor) -> tensor + %3 = stablehlo.compare LE, %c4, %c5, SIGNED : (tensor, tensor) -> tensor + + %4 = stablehlo.compare EQ, %cn1, %cn1, UNSIGNED : (tensor, tensor) -> tensor + %5 = stablehlo.compare GT, %c5, %cn1, UNSIGNED : (tensor, tensor) -> tensor + %6 = stablehlo.compare GE, %c5, %c4, UNSIGNED : (tensor, tensor) -> tensor + %7 = stablehlo.compare LE, %cn1, %c5, UNSIGNED : (tensor, tensor) -> tensor + + // CHECK-DAG: [[C0:%.+]] = stablehlo.constant dense : tensor + // CHECK-DAG: [[C1:%.+]] = stablehlo.constant dense : tensor + + // CHECK-NEXT: return [[C1]], [[C0]], [[C1]], [[C1]], [[C1]], [[C0]], [[C1]], [[C0]] + return %0, %1, %2, %3, %4, %5, %6, %7 : + tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor +} + +// ----- + // CHECK-LABEL: func.func @broadcast_in_dim // CHECK-SAME: ([[ARG0:%.+]]: tensor<3x3xi32>) func.func @broadcast_in_dim(%arg0: tensor<3x3xi32>)