-
Notifications
You must be signed in to change notification settings - Fork 13.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][TosaToLinalg] Exit after notifyMatchFailure
#132012
Conversation
This PR adds `return nullptr` when the shift value of `tosa.mul` is not constant to prevent a crash.
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-tosa Author: Longsheng Mou (CoTinker) ChangesThis PR adds Full diff: https://github.com/llvm/llvm-project/pull/132012.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index c0a25a56dbe2a..6e1e3343ac169 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -136,14 +136,14 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::MulOp
if (isa<tosa::MulOp>(op)) {
- auto shift_val = cast<tosa::MulOp>(op).getShift();
- DenseElementsAttr shift_elem;
- if (!shift_val.getImpl() ||
- !matchPattern(shift_val, m_Constant(&shift_elem))) {
+ auto shiftVal = cast<tosa::MulOp>(op).getShift();
+ DenseElementsAttr shiftElem;
+ if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
+ return nullptr;
}
- int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
if (isa<FloatType>(elementTy)) {
if (shift != 0) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index d00846a4c3e02..69d8471df8032 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -73,3 +73,11 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
+
+// -----
+
+func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
+ // expected-error@+1 {{failed to legalize operation 'tosa.mul'}}
+ %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+ return %0 : tensor<2x3xi32>
+}
|
@llvm/pr-subscribers-mlir Author: Longsheng Mou (CoTinker) ChangesThis PR adds Full diff: https://github.com/llvm/llvm-project/pull/132012.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index c0a25a56dbe2a..6e1e3343ac169 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -136,14 +136,14 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::MulOp
if (isa<tosa::MulOp>(op)) {
- auto shift_val = cast<tosa::MulOp>(op).getShift();
- DenseElementsAttr shift_elem;
- if (!shift_val.getImpl() ||
- !matchPattern(shift_val, m_Constant(&shift_elem))) {
+ auto shiftVal = cast<tosa::MulOp>(op).getShift();
+ DenseElementsAttr shiftElem;
+ if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
+ return nullptr;
}
- int32_t shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+ int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
if (isa<FloatType>(elementTy)) {
if (shift != 0) {
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index d00846a4c3e02..69d8471df8032 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -73,3 +73,11 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}
+
+// -----
+
+func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
+ // expected-error@+1 {{failed to legalize operation 'tosa.mul'}}
+ %0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
+ return %0 : tensor<2x3xi32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @CoTinker!
This PR adds
return nullptr
when the shift value oftosa.mul
is not constant to prevent a crash. Fixes #131766.