Skip to content

Commit 422647f

Browse files
lhutton1rorth
authored andcommitted
[mlir][tosa] Fix MulOp verifier handling for unranked operands (llvm#141980)
The previous verifier checks did not correctly handle unranked operands. For example, it could incorrectly assume the number of `rankedOperandTypes` would be >= 2, which isn't the case when both a and b are unranked. This change simplifies these checks such that they only operate over the intended a and b operands as opposed to the shift operand as well.
1 parent 75f29b1 commit 422647f

File tree

3 files changed

+77
-53
lines changed

3 files changed

+77
-53
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 33 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1857,7 +1857,8 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
18571857
}
18581858

18591859
LogicalResult tosa::MulOp::verify() {
1860-
auto resElemType = getElementTypeOrSelf(getOutput());
1860+
const Value output = getOutput();
1861+
auto resElemType = getElementTypeOrSelf(output);
18611862

18621863
// Verify if the element type among operands and result match tosa
18631864
// specification.
@@ -1897,59 +1898,39 @@ LogicalResult tosa::MulOp::verify() {
18971898
// Verify the op has same ranks for all main operands (excludes extra operands
18981899
// such as shift of mul op, so this is the only difference with the built-in
18991900
// `SameOperandsAndResultRank` trait) and results types, if known.
1900-
1901-
// delegate function that returns true if type is a shaped type with known
1902-
// rank
1903-
auto hasRank = [](const Type type) {
1904-
if (auto shaped_type = dyn_cast<ShapedType>(type))
1905-
return shaped_type.hasRank();
1906-
1907-
return false;
1908-
};
1909-
1910-
auto rankedOperandTypes =
1911-
llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
1912-
1913-
auto rankedResultTypes =
1914-
llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
1915-
1916-
// If all operands and results are unranked, then no further verification.
1917-
if (rankedOperandTypes.empty() && rankedResultTypes.empty())
1901+
TypeRange operandTypes = getOperandTypes();
1902+
ShapedType aType = cast<ShapedType>(operandTypes[0]);
1903+
ShapedType bType = cast<ShapedType>(operandTypes[1]);
1904+
1905+
const bool aHasRank = aType.hasRank();
1906+
const bool bHasRank = bType.hasRank();
1907+
if (aHasRank && bHasRank) {
1908+
const int64_t aRank = aType.getRank();
1909+
const int64_t bRank = bType.getRank();
1910+
if (aRank != bRank)
1911+
return emitOpError("a and b operands don't have matching ranks, got ")
1912+
<< aRank << " and " << bRank;
1913+
1914+
// check for broadcast compatible shapes
1915+
SmallVector<int64_t> resultShape;
1916+
if (!mlir::OpTrait::util::getBroadcastedShape(
1917+
aType.getShape(), bType.getShape(), resultShape))
1918+
return emitOpError("a and b operands don't have broadcast-compatible "
1919+
"shapes, got ")
1920+
<< aType << " and " << bType;
1921+
}
1922+
1923+
ShapedType resultType = cast<ShapedType>(output.getType());
1924+
if (!resultType.hasRank())
19181925
return success();
19191926

1920-
// delegate function that returns rank of shaped type with known rank
1921-
auto getRank = [](const Type type) {
1922-
return cast<ShapedType>(type).getRank();
1923-
};
1924-
1925-
auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
1926-
: getRank(*rankedResultTypes.begin());
1927-
1928-
for (size_t i = 0; i < 2; ++i) {
1929-
if (rank != getRank(rankedOperandTypes[i])) {
1930-
return emitOpError("operands don't have matching ranks");
1931-
}
1932-
}
1933-
1934-
for (const auto type : rankedResultTypes) {
1935-
if (rank != getRank(type)) {
1936-
return emitOpError("result type has different rank than operands");
1937-
}
1938-
}
1939-
1940-
// check for broadcast compatible shapes in first two operands (ignoring
1941-
// shift)
1942-
1943-
// delegate function that returns shape of shaped type
1944-
auto getShape = [](const Type type) {
1945-
return mlir::cast<ShapedType>(type).getShape();
1946-
};
1947-
SmallVector<int64_t> resultShape;
1948-
if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
1949-
getShape(rankedOperandTypes[1]),
1950-
resultShape)) {
1951-
return emitOpError("operands don't have broadcast-compatible shapes");
1952-
}
1927+
const int64_t resultRank = resultType.getRank();
1928+
if (aHasRank && resultRank != aType.getRank())
1929+
return emitOpError("result type has different rank than a, got ")
1930+
<< resultRank << " vs " << aType.getRank();
1931+
if (bHasRank && resultRank != bType.getRank())
1932+
return emitOpError("result type has different rank than b, got ")
1933+
<< resultRank << " vs " << bType.getRank();
19531934

19541935
return success();
19551936
}

mlir/test/Dialect/Tosa/invalid.mlir

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1134,11 +1134,38 @@ func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
11341134
// CHECK-LABEL: test_mul_non_broadcast
11351135
func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
11361136
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
1137-
// expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
1137+
// expected-error@+1 {{'tosa.mul' op a and b operands don't have broadcast-compatible shapes, got 'tensor<13x21x2xf32>' and 'tensor<3x1x3xf32>'}}
11381138
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
11391139
return %0 : tensor<13x21x3xf32>
11401140
}
11411141

1142+
// -----
1143+
// CHECK-LABEL: test_mul_different_operand_ranks
1144+
func.func @test_mul_different_operand_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
1145+
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
1146+
// expected-error@+1 {{'tosa.mul' op a and b operands don't have matching ranks, got 2 and 3}}
1147+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
1148+
return %0 : tensor<13x21x3xf32>
1149+
}
1150+
1151+
// -----
1152+
// CHECK-LABEL: test_mul_different_a_and_result_ranks
1153+
func.func @test_mul_different_a_and_result_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
1154+
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
1155+
// expected-error@+1 {{'tosa.mul' op result type has different rank than a, got 3 vs 2}}
1156+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
1157+
return %0 : tensor<13x21x3xf32>
1158+
}
1159+
1160+
// -----
1161+
// CHECK-LABEL: test_mul_different_b_and_result_ranks
1162+
func.func @test_mul_different_b_and_result_ranks(%arg0: tensor<*xf32>, %arg1: tensor<13x12xf32>) -> tensor<13x21x3xf32> {
1163+
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
1164+
// expected-error@+1 {{'tosa.mul' op result type has different rank than b, got 3 vs 2}}
1165+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<13x12xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
1166+
return %0 : tensor<13x21x3xf32>
1167+
}
1168+
11421169
// -----
11431170
// CHECK-LABEL: test_resize_invalid_scale_values
11441171
func.func @test_resize_invalid_scale_values(%arg0: tensor<1x8x8x8xf32>) -> tensor<?x?x?x?xf32> {

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -424,6 +424,22 @@ func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tenso
424424
return %0 : tensor<13x21x3xi16>
425425
}
426426

427+
// -----
428+
// CHECK-LABEL: test_mul_unranked_b
429+
func.func @test_mul_unranked_b(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
430+
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
431+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
432+
return %0 : tensor<13x21x3xf32>
433+
}
434+
435+
// -----
436+
// CHECK-LABEL: test_mul_unranked_a_and_b
437+
func.func @test_mul_unranked_a_and_b(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
438+
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
439+
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
440+
return %0 : tensor<13x21x3xf32>
441+
}
442+
427443
// -----
428444
// CHECK-LABEL: pow
429445
func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {

0 commit comments

Comments
 (0)