Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][TosaToLinalg] Only support ranked tensor for reduce and gather #131805

Merged
merged 1 commit into from
Mar 19, 2025

Conversation

CoTinker
Copy link
Contributor

This PR adds checks for ranked tensors in converter of reduce and gather to prevent crash. Fixes #131087.

This PR adds checks for ranked tensors in converter of reduce and gather
to prevent crash.
@llvmbot
Copy link
Member

llvmbot commented Mar 18, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Longsheng Mou (CoTinker)

Changes

This PR adds checks for ranked tensors in converter of reduce and gather to prevent crash. Fixes #131087.


Full diff: https://github.com/llvm/llvm-project/pull/131805.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+8-7)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir (+16)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index b5e0efa71b3cc..c0a25a56dbe2a 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1176,8 +1176,11 @@ template <typename OpTy>
 static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
                                                  PatternRewriter &rewriter) {
   auto loc = op->getLoc();
-  auto inputTy = cast<ShapedType>(op->getOperand(0).getType());
-  auto resultTy = cast<ShapedType>(op->getResult(0).getType());
+  auto inputTy = dyn_cast<RankedTensorType>(op->getOperand(0).getType());
+  auto resultTy = dyn_cast<RankedTensorType>(op->getResult(0).getType());
+  if (!inputTy || !resultTy)
+    return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
+
   auto elementTy = resultTy.getElementType();
   Value input = op->getOperand(0);
 
@@ -2380,11 +2383,9 @@ class GatherConverter : public OpConversionPattern<tosa::GatherOp> {
     auto input = adaptor.getOperands()[0];
     auto indices = adaptor.getOperands()[1];
 
-    auto valuesTy =
-        dyn_cast_or_null<RankedTensorType>(op.getValues().getType());
-    auto resultTy = cast<ShapedType>(op.getType());
-
-    if (!valuesTy)
+    auto valuesTy = dyn_cast<RankedTensorType>(op.getValues().getType());
+    auto resultTy = dyn_cast<RankedTensorType>(op.getType());
+    if (!valuesTy || !resultTy)
       return rewriter.notifyMatchFailure(op, "unranked tensors not supported");
 
     auto dynamicDims = inferDynamicDimsForGather(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
index d1a9671873de0..d00846a4c3e02 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir
@@ -57,3 +57,19 @@ func.func @cast_unsupported_type(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!
   %0 = tosa.cast %arg0 : (tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
   return %0 : tensor<13x21x3x!quant.uniform<i16:f32, 0.078431375324726104:128>>
 }
+
+// -----
+
+func.func @unranked_reduce(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.reduce_sum'}}
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<*xf32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}
+
+// -----
+
+func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<*xf32> {
+  // expected-error@+1 {{failed to legalize operation 'tosa.gather'}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
+  return %0 : tensor<*xf32>
+}

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@CoTinker CoTinker merged commit fbc1038 into llvm:main Mar 19, 2025
15 checks passed
@CoTinker CoTinker deleted the ranked_tensor branch March 19, 2025 01:18
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
4 participants