
[mlir][TosaToLinalg] Fix bugs in PointwiseConverter #132526

Merged
GeorgeARM merged 1 commit into llvm:main from CoTinker:no_expand_rank on Mar 26, 2025

Conversation

CoTinker
Contributor

- TOSA ensures pointwise op inputs have equal ranks, so the redundant
  rank expansion function is removed.
- Added `getBroadcastableOperands` to prevent crashes by handling
  non-broadcastable inputs in `tosa.mul` and `tosa.negate`.
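
For reference, a minimal reproducer for the crash (distilled from the new `@test_0d_input` regression test in this PR; the function name is illustrative):

```mlir
// Before this fix, lowering through TosaToLinalg crashed here: the
// rank-1 shift operand was rank-expanded and broadcast against the
// rank-0 data operands. It is now excluded from broadcasting.
func.func @repro(%arg0: tensor<i32>) -> tensor<i32> {
  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.mul %arg0, %arg0, %shift : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
  return %0 : tensor<i32>
}
```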
@llvmbot
Member

llvmbot commented Mar 22, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Longsheng Mou (CoTinker)

Changes
  • TOSA ensures pointwise op inputs have equal ranks, so the redundant rank expansion function is removed.
  • Added getBroadcastableOperands to prevent crashes by handling non-broadcastable inputs in tosa.mul and tosa.negate. Fixes #131294.

Full diff: https://github.com/llvm/llvm-project/pull/132526.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+16-56)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+23-3)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 6e1e3343ac169..e18fa849e9f30 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -711,50 +711,6 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   return nullptr;
 }
 
-static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
-                        int64_t rank) {
-  // No need to expand if we are already at the desired rank
-  auto tensorType = dyn_cast<RankedTensorType>(tensor.getType());
-  assert(tensorType && "expected a ranked tensor type");
-  int64_t tensorRank = tensorType.getRank();
-  int64_t numExtraDims = rank - tensorRank;
-  assert(numExtraDims >= 0 && "cannot expand tensor to a lower rank");
-  if (!numExtraDims)
-    return tensor;
-
-  // Compute reassociation indices
-  SmallVector<ReassociationIndices> reassociationIndices(tensorRank);
-  int64_t index = 0;
-  if (tensorRank != 0) {
-    for (index = 0; index <= numExtraDims; index++)
-      reassociationIndices[0].push_back(index);
-    for (size_t position = 1; position < reassociationIndices.size();
-         position++)
-      reassociationIndices[position].push_back(index++);
-  }
-
-  // Compute result type
-  SmallVector<int64_t> resultShape;
-  for (index = 0; index < numExtraDims; index++)
-    resultShape.push_back(1);
-  for (auto size : tensorType.getShape())
-    resultShape.push_back(size);
-  auto resultType =
-      RankedTensorType::get(resultShape, tensorType.getElementType());
-
-  // Emit 'tensor.expand_shape' op
-  return rewriter.create<tensor::ExpandShapeOp>(loc, resultType, tensor,
-                                                reassociationIndices);
-}
-
-static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
-                                           Location loc, ValueRange operands,
-                                           int64_t rank) {
-  return llvm::map_to_vector(operands, [&](Value operand) {
-    return expandRank(rewriter, loc, operand, rank);
-  });
-}
-
 using IndexPool = DenseMap<int64_t, Value>;
 
 // Emit an 'arith.constant' op for the given index if it has not been created
@@ -1036,6 +992,17 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
   return success();
 }
 
+static ValueRange getBroadcastableOperands(Operation *operation,
+                                           ValueRange operands) {
+  // Shift cannot broadcast
+  if (isa<tosa::MulOp>(operation))
+    return operands.take_front(2);
+  // Input1_zp and output_zp cannot broadcast
+  if (isa<tosa::NegateOp>(operation))
+    return operands.take_front(1);
+  return operands;
+}
+
 static LogicalResult
 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                  ConversionPatternRewriter &rewriter,
@@ -1052,19 +1019,12 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
   // Lower operation
   IndexPool indexPool;
   auto loc = operation->getLoc();
-  auto rank =
-      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
-  // For the mul op we need to avoid expanding the rank of the optional shift
-  // input.
-  auto operandsToExpand =
-      isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
-
-  auto expandedOperands =
-      expandInputRanks(rewriter, loc, operandsToExpand, rank);
+  auto operandsToBroadcast = getBroadcastableOperands(operation, operands);
   auto [targetShape, masterOperands] =
-      computeTargetShape(rewriter, loc, indexPool, expandedOperands);
-  auto broadcastOperands = broadcastDynamicDimensions(
-      rewriter, loc, indexPool, expandedOperands, targetShape, masterOperands);
+      computeTargetShape(rewriter, loc, indexPool, operandsToBroadcast);
+  auto broadcastOperands =
+      broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast,
+                                 targetShape, masterOperands);
   return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands,
                                     targetShape, converter);
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 18ce8571eeea0..9258442de5a45 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -664,7 +664,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
   %40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32):
+  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
   // CHECK: [[ZERO:%.+]] = arith.constant 0
   // CHECK: arith.subi [[ZERO]], %[[ARG1]]
   %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
@@ -856,7 +856,7 @@ func.func @test_bool(%arg0: tensor<1xi1>, %arg1: tensor<1xi1>) -> () {
 // CHECK-LABEL: @test_negate_quantized
 func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
   // CHECK: [[CNST:%.+]] = arith.constant 7
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
   // CHECK: [[SUB:%.+]] = arith.subi [[CNST]], [[EXT]]
@@ -871,7 +871,7 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
   %0 = tosa.negate %arg0, %in_zp0, %out_zp0 : (tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1xi8>
 
   // CHECK: linalg.generic
-  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8, %[[BBARG2:.+]]: i8, %[[BBARG3:.+]]: i8
+  // CHECK: ^bb0(%[[BBARG0:.+]]: i8, %[[BBARG1:.+]]: i8
   // CHECK: [[C_128:%.+]] = arith.constant -128
   // CHECK: [[EXT:%.+]] = arith.extsi %[[BBARG0]] : i8 to i16
   // CHECK: [[SUB:%.+]] = arith.subi [[C_128]], [[EXT]]
@@ -2317,3 +2317,23 @@ func.func @clamp_nan_ignore(%arg0: tensor<5x4xf32>, %arg1: tensor<5x4xf32>) -> (
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: @test_0d_input
+func.func @test_0d_input(%arg0: tensor<i32>) -> () {
+  // CHECK: linalg.generic
+  // CHECK: arith.muli
+  %shift1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+  %0 = tosa.mul %arg0, %arg0, %shift1 : (tensor<i32>, tensor<i32>, tensor<1xi8>) -> tensor<i32>
+
+  // CHECK: linalg.generic
+  // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
+  // CHECK: [[ZERO:%.+]] = arith.constant 0
+  // CHECK: arith.subi [[ZERO]], %[[ARG1]]
+  %in_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %out_zp = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32>
+  %5 = tosa.negate %arg0, %in_zp, %out_zp : (tensor<i32>, tensor<1xi32>, tensor<1xi32>) -> tensor<i32>
+
+  return
+}
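
For readers of the diff: with zero input/output zero-points, the 0-d tosa.negate above lowers to a plain rank-0 linalg.generic. A sketch of what the CHECK lines describe (the affine maps and SSA names are assumptions, not verbatim pass output):

```mlir
// Sketch reconstructed from the CHECK lines; the body reduces to 0 - x.
#map = affine_map<() -> ()>
%init = tensor.empty() : tensor<i32>
%neg = linalg.generic {indexing_maps = [#map, #map], iterator_types = []}
    ins(%arg0 : tensor<i32>) outs(%init : tensor<i32>) {
^bb0(%in: i32, %out: i32):
  %zero = arith.constant 0 : i32
  %sub = arith.subi %zero, %in : i32
  linalg.yield %sub : i32
} -> tensor<i32>
```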

@CoTinker CoTinker requested a review from Jerry-Ge March 25, 2025 01:18
@lhutton1 (Contributor)

@CoTinker CoTinker requested a review from RoboTux March 25, 2025 12:27
@Jerry-Ge Jerry-Ge (Member) left a comment

Looks good to me. Will let others approve since I don't commit to TosaToLinalg actively.

@CoTinker
Contributor Author

> Looks good to me. Will let others approve since I don't commit to TosaToLinalg actively.

Okay, thanks for your review.

@GeorgeARM GeorgeARM (Contributor) left a comment

Thanks for looking at this @CoTinker.

@GeorgeARM GeorgeARM merged commit 73f487d into llvm:main Mar 26, 2025
15 checks passed
@CoTinker CoTinker deleted the no_expand_rank branch March 26, 2025 08:39