Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][TosaToLinalg] Exit after notifyMatchFailure #132012

Merged
merged 1 commit into from
Mar 20, 2025

Conversation

CoTinker
Copy link
Contributor

This PR adds return nullptr when the shift value of tosa.mul is not constant to prevent a crash. Fixes #131766.

This PR adds `return nullptr` when the shift value of
`tosa.mul` is not constant to prevent a crash.
@llvmbot
Copy link
Member

llvmbot commented Mar 19, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tosa

Author: Longsheng Mou (CoTinker)

Changes

This PR adds return nullptr when the shift value of tosa.mul is not constant to prevent a crash. Fixes #131766.


Full diff: https://github.com/llvm/llvm-project/pull/132012.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-5)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+8)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index c0a25a56dbe2a..6e1e3343ac169 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -136,14 +136,14 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MulOp
   if (isa<tosa::MulOp>(op)) {
-    auto shift_val = cast<tosa::MulOp>(op).getShift();
-    DenseElementsAttr shift_elem;
-    if (!shift_val.getImpl() ||
-        !matchPattern(shift_val, m_Constant(&shift_elem))) {
+    auto shiftVal = cast<tosa::MulOp>(op).getShift();
+    DenseElementsAttr shiftElem;
+    if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
       (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
+      return nullptr;
     }
 
-    int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+    int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
 
     if (isa<FloatType>(elementTy)) {
       if (shift != 0) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index d00846a4c3e02..69d8471df8032 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -73,3 +73,11 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
   %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
+
+// -----
+
+func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.mul'}}
+  %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+  return %0 : tensor<2x3xi32>
+}

@llvmbot
Copy link
Member

llvmbot commented Mar 19, 2025

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

This PR adds return nullptr when the shift value of tosa.mul is not constant to prevent a crash. Fixes #131766.


Full diff: https://github.com/llvm/llvm-project/pull/132012.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+5-5)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+8)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index c0a25a56dbe2a..6e1e3343ac169 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -136,14 +136,14 @@ static Value createLinalgBodyCalculationForElementwiseOp(
 
   // tosa::MulOp
   if (isa<tosa::MulOp>(op)) {
-    auto shift_val = cast<tosa::MulOp>(op).getShift();
-    DenseElementsAttr shift_elem;
-    if (!shift_val.getImpl() ||
-        !matchPattern(shift_val, m_Constant(&shift_elem))) {
+    auto shiftVal = cast<tosa::MulOp>(op).getShift();
+    DenseElementsAttr shiftElem;
+    if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
       (void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
+      return nullptr;
     }
 
-    int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+    int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
 
     if (isa<FloatType>(elementTy)) {
       if (shift != 0) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index d00846a4c3e02..69d8471df8032 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -73,3 +73,11 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
   %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
   return %0 : tensor<*xf32>
 }
+
+// -----
+
+func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.mul'}}
+  %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+  return %0 : tensor<2x3xi32>
+}

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @CoTinker!

@lhutton1 lhutton1 merged commit 7c11d05 into llvm:main Mar 20, 2025
15 checks passed
@CoTinker CoTinker deleted the mul_shift branch March 21, 2025 01:19
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[MLIR] Segmentation fault when running tosa-to-linalg pass
3 participants