[mlir][TosaToLinalg] Fix bugs in PointwiseConverter #132526
Conversation
- TOSA ensures pointwise op inputs have equal ranks, so the redundant rank-expansion function is removed.
- Added `getBroadcastableOperands` to prevent crashes by handling non-broadcastable inputs in `tosa.mul` and `tosa.negate` (a minimal reproducer sketch follows below).
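For context, a minimal sketch of the kind of IR that previously tripped the `numExtraDims >= 0 && "cannot expand tensor to a lower rank"` assertion in `expandRank`; it mirrors the `@test_0d_input` regression test added in this PR (the function name here is illustrative). The rank-1 zero-point operands of `tosa.negate` have a higher rank than the rank-0 result, so the old rank-expansion path computed a negative `numExtraDims`; the rank-1 shift operand of `tosa.mul` likewise must be excluded from broadcasting.

```mlir
// Sketch only: rank-0 data operands combined with rank-1 shift and
// zero-point operands, which must not be rank-expanded or broadcast.
func.func @pointwise_0d(%arg0: tensor<i32>) -> (tensor<i32>, tensor<i32>) {
  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.mul %arg0, %arg0, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
  %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
  %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
  %1 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<i32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
  return %0, %1 : tensor<i32>, tensor<i32>
}
```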
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa

Author: Longsheng Mou (CoTinker)

Changes
Full diff: https://github.com/llvm/llvm-project/pull/132526.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6e1e3343ac169..e18fa849e9f30 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -711,50 +711,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
return nullptr;
}
-static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
- int64_t rank) {
- // No need to expand if we are already at the desired rank
- auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
- assert(tensorType && "expected a ranked tensor type");
- int64_t tensorRank = tensorType.getRank();
- int64_t numExtraDims = rank - tensorRank;
- assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
- if (!numExtraDims)
- return tensor;
-
- // Compute reassociation indices
- SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
- int64_t index = 0;
- if (tensorRank != 0) {
- for (index = 0; index <= numExtraDims; index++)
- reassociationIndices[0].push_back(index);
- for (size_t position = 1; position < reassociationIndices.size();
- position++)
- reassociationIndices[position].push_back(index++);
- }
-
- // Compute result type
- SmallVector<int64_t> resultShape;
- for (index = 0; index < numExtraDims; index++)
- resultShape.push_back(1);
- for (auto size : tensorType.getShape())
- resultShape.push_back(size);
- auto resultType =
- RankedTensorType::get(resultShape, tensorType.getElementType());
-
- // Emit 'tensor.expand_shape' op
- return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
- reassociationIndices);
-}
-
-static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
- Location loc, ValueRange operands,
- int64_t rank) {
- return llvm::map_to_vector(operands, [&](Value operand) {
- return expandRank(rewriter, loc, operand, rank);
- });
-}
-
using IndexPool = DenseMap<int64_t, Value>;
// Emit an 'arith.constant' op for the given index if it has not been created
@@ -1036,6 +992,17 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
return success();
}
+static ValueRange getBroadcastableOperands(Operation *operation,
+ ValueRange operands) {
+ // Shift cannot broadcast
+ if (isa<tosa::MulOp>(operation))
+ return operands.take_front(2);
+ // Input1_zp and output_zp cannot broadcast
+ if (isa<tosa::NegateOp>(operation))
+ return operands.take_front(1);
+ return operands;
+}
+
static LogicalResult
elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
ConversionPatternRewriter &rewriter,
@@ -1052,19 +1019,12 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
// Lower operation
IndexPool indexPool;
auto loc = operation->getLoc();
- auto rank =
- cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
- // For the mul op we need to avoid expanding the rank of the optional shift
- // input.
- auto operandsToExpand =
- isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
-
- auto expandedOperands =
- expandInputRanks(rewriter, loc, operandsToExpand, rank);
+ auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
auto [targetShape, masterOperands] =
- computeTargetShape(rewriter, loc, indexPool, expandedOperands);
- auto broadcastOperands = broadcastDynamicDimensions(
- rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
+ computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
+ auto broadcastOperands =
+ broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
+ targetShape, masterOperands);
return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
targetShape, converter);
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 18ce8571eeea0..9258442de5a45 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -664,7 +664,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
%40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
+ // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
// CHECK: [[ZERO:%.+]] = arith.constant 0
// CHECK: arith.subi [[ZERO]], %[[ARG1]]
%in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
@@ -856,7 +856,7 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
// CHECK-LABEL: @test_negate_quantized
func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
// CHECK: [[CNST:%.+]] = arith.constant 7
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
// CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
@@ -871,7 +871,7 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
%0 = tosa.negate %arg0, %in_zp0, %out_zp0 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
// CHECK: linalg.generic
- // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+ // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
// CHECK: [[C_128:%.+]] = arith.constant -128
// CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
// CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
@@ -2317,3 +2317,23 @@ func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> (
return
}
+
+// -----
+
+// CHECK-LABEL: @test_0d_input
+func.func @test_0d_input(%arg0: tensor<i32>) -> () {
+ // CHECK: linalg.generic
+ // CHECK: arith.muli
+ %shift1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg0, %shift1 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+
+ // CHECK: linalg.generic
+ // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
+ // CHECK: [[ZERO:%.+]] = arith.constant 0
+ // CHECK: arith.subi [[ZERO]], %[[ARG1]]
+ %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<i32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+
+ return
+}
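A hedged sketch, not verbatim output of the conversion, of the `linalg.generic` that the updated CHECK lines in `@test_0d_input` describe for the 0-D `tosa.negate` once the zero-point operands are excluded from broadcasting. The region now takes only the data operand and the output as block arguments; the SSA names and the `tensor.empty` init are assumptions for illustration.

```mlir
// Illustrative only: the body negates the input via 0 - x, matching the
// arith.constant / arith.subi pattern checked in the test.
%init = tensor.empty() : tensor<i32>
%res = linalg.generic {
    indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>],
    iterator_types = []
  } ins(%arg0 : tensor<i32>) outs(%init : tensor<i32>) {
^bb0(%in: i32, %out: i32):
  %zero = arith.constant 0 : i32
  %sub = arith.subi %zero, %in : i32
  linalg.yield %sub : i32
} -> tensor<i32>
```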
Looks good to me. Will let others approve since I don't actively commit to TosaToLinalg.
Okay, thanks for your review.
Thanks for looking at this @CoTinker.
Added `getBroadcastableOperands` to prevent crashes by handling non-broadcastable inputs in `tosa.mul` and `tosa.negate`. Fixes: [MLIR] `-tosa-to-linalg` triggers Assertion `numExtraDims >= 0 && "cannot expand tensor to a lower rank"' failed (#131294).