[mlir][tosa] Fix MulOp verifier handling for unranked operands #141980
Conversation
The previous verifier checks did not correctly handle unranked operands. For example, they could incorrectly assume the number of `rankedOperandTypes` would be >= 2, which isn't the case when both a and b are unranked.

This change simplifies the checks so that they operate only on the intended a and b operands, rather than also on the shift operand.

Change-Id: I0d0b7f7e8058f9a25dcb6c051aa0375cf780b80c
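For illustration, this is the failure mode, adapted from the `test_mul_unranked_a_and_b` test added in this patch: with both a and b unranked, filtering the operand types by rank leaves fewer than two entries, so indexing the first two "ranked operand types" went out of bounds.

```mlir
// Both a and b are unranked; only the shift operand and the result carry a
// rank, so the old `rankedOperandTypes` vector held fewer than two entries.
// Under the new checks this IR verifies cleanly.
func.func @test_mul_unranked_a_and_b(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}
```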
@llvm/pr-subscribers-mlir

Author: Luke Hutton (lhutton1)

Changes

The previous verifier checks did not correctly handle unranked operands. For example, they could incorrectly assume the number of `rankedOperandTypes` would be >= 2, which isn't the case when both a and b are unranked.

This change simplifies the checks so that they operate only on the intended a and b operands, rather than also on the shift operand.

Full diff: https://github.com/llvm/llvm-project/pull/141980.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 3ee5a85a21dca..298802fc7fa6c 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1779,7 +1779,8 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
}
LogicalResult tosa::MulOp::verify() {
- auto resElemType = getElementTypeOrSelf(getOutput());
+ const Value output = getOutput();
+ auto resElemType = getElementTypeOrSelf(output);
// Verify if the element type among operands and result match tosa
// specification.
@@ -1819,59 +1820,39 @@ LogicalResult tosa::MulOp::verify() {
// Verify the op has same ranks for all main operands (excludes extra operands
// such as shift of mul op, so this is the only difference with the built-in
// `SameOperandsAndResultRank` trait) and results types, if known.
-
- // delegate function that returns true if type is a shaped type with known
- // rank
- auto hasRank = [](const Type type) {
- if (auto shaped_type = dyn_cast<ShapedType>(type))
- return shaped_type.hasRank();
-
- return false;
- };
-
- auto rankedOperandTypes =
- llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
-
- auto rankedResultTypes =
- llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
-
- // If all operands and results are unranked, then no further verification.
- if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+ TypeRange operandTypes = getOperandTypes();
+ ShapedType aType = cast<ShapedType>(operandTypes[0]);
+ ShapedType bType = cast<ShapedType>(operandTypes[1]);
+
+ const bool aHasRank = aType.hasRank();
+ const bool bHasRank = bType.hasRank();
+ if (aHasRank && bHasRank) {
+ const int64_t aRank = aType.getRank();
+ const int64_t bRank = bType.getRank();
+ if (aRank != bRank)
+ return emitOpError("a and b operands don't have matching ranks, got ")
+ << aRank << " and " << bRank;
+
+ // check for broadcast compatible shapes
+ SmallVector<int64_t> resultShape;
+ if (!mlir::OpTrait::util::getBroadcastedShape(
+ aType.getShape(), bType.getShape(), resultShape))
+ return emitOpError("a and b operands don't have broadcast-compatible "
+ "shapes, got ")
+ << aType << " and " << bType;
+ }
+
+ ShapedType resultType = cast<ShapedType>(output.getType());
+ if (!resultType.hasRank())
return success();
- // delegate function that returns rank of shaped type with known rank
- auto getRank = [](const Type type) {
- return cast<ShapedType>(type).getRank();
- };
-
- auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
- : getRank(*rankedResultTypes.begin());
-
- for (size_t i = 0; i < 2; ++i) {
- if (rank != getRank(rankedOperandTypes[i])) {
- return emitOpError("operands don't have matching ranks");
- }
- }
-
- for (const auto type : rankedResultTypes) {
- if (rank != getRank(type)) {
- return emitOpError("result type has different rank than operands");
- }
- }
-
- // check for broadcast compatible shapes in first two operands (ignoring
- // shift)
-
- // delegate function that returns shape of shaped type
- auto getShape = [](const Type type) {
- return mlir::cast<ShapedType>(type).getShape();
- };
- SmallVector<int64_t> resultShape;
- if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
- getShape(rankedOperandTypes[1]),
- resultShape)) {
- return emitOpError("operands don't have broadcast-compatible shapes");
- }
+ const int64_t resultRank = resultType.getRank();
+ if (aHasRank && resultRank != aType.getRank())
+ return emitOpError("result type has different rank than a, got ")
+ << resultRank << " vs " << aType.getRank();
+ if (bHasRank && resultRank != bType.getRank())
+ return emitOpError("result type has different rank than b, got ")
+ << resultRank << " vs " << bType.getRank();
return success();
}
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index 7b589fa839b44..3298e518de2f5 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -1107,11 +1107,38 @@ func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
// CHECK-LABEL: test_mul_non_broadcast
func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
- // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
+ // expected-error@+1 {{'tosa.mul' op a and b operands don't have broadcast-compatible shapes, got 'tensor<13x21x2xf32>' and 'tensor<3x1x3xf32>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}
+// -----
+// CHECK-LABEL: test_mul_different_operand_ranks
+func.func @test_mul_different_operand_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op a and b operands don't have matching ranks, got 2 and 3}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_different_a_and_result_ranks
+func.func @test_mul_different_a_and_result_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op result type has different rank than a, got 3 vs 2}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_different_b_and_result_ranks
+func.func @test_mul_different_b_and_result_ranks(%arg0: tensor<*xf32>, %arg1: tensor<13x12xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ // expected-error@+1 {{'tosa.mul' op result type has different rank than b, got 3 vs 2}}
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<13x12xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: test_resize_invalid_scale_values
func.func @test_resize_invalid_scale_values(%arg0: tensor<1x8x8x8xf32>) -> tensor<?x?x?x?xf32> {
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index 5ec506a45b3ad..882b59d029a4a 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -424,6 +424,22 @@ func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tenso
return %0 : tensor<13x21x3xi16>
}
+// -----
+// CHECK-LABEL: test_mul_unranked_b
+func.func @test_mul_unranked_b(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_mul_unranked_a_and_b
+func.func @test_mul_unranked_a_and_b(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
+ %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
+ return %0 : tensor<13x21x3xf32>
+}
+
// -----
// CHECK-LABEL: pow
func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
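One behaviour of the rewritten verifier worth noting: when the result type is unranked, the rank comparison against a and b is skipped and verification succeeds early. The patch does not add a test for that case, so the following is a hypothetical sketch of IR that should verify under the new logic:

```mlir
// Hypothetical example (not among this patch's tests): ranked a and b with an
// unranked result. The a/b rank and broadcast checks still run, then the
// verifier returns success early because the result has no rank to compare.
func.func @mul_unranked_result(%arg0: tensor<2x3xf32>, %arg1: tensor<2x1xf32>) -> tensor<*xf32> {
  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<2x3xf32>, tensor<2x1xf32>, tensor<1xi8>) -> tensor<*xf32>
  return %0 : tensor<*xf32>
}
```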
Patch looks good, need an approver to sign off and merge
LGTM
…141980) The previous verifier checks did not correctly handle unranked operands. For example, it could incorrectly assume the number of `rankedOperandTypes` would be >= 2, which isn't the case when both a and b are unranked. This change simplifies these checks such that they only operate over the intended a and b operands as opposed to the shift operand as well.