Skip to content

[mlir][tosa] Fix MulOp verifier handling for unranked operands #141980

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 33 additions & 52 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1779,7 +1779,8 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
}

LogicalResult tosa::MulOp::verify() {
auto resElemType = getElementTypeOrSelf(getOutput());
const Value output = getOutput();
auto resElemType = getElementTypeOrSelf(output);

// Verify if the element type among operands and result match tosa
// specification.
Expand Down Expand Up @@ -1819,59 +1820,39 @@ LogicalResult tosa::MulOp::verify() {
// Verify the op has same ranks for all main operands (excludes extra operands
// such as shift of mul op, so this is the only difference with the built-in
// `SameOperandsAndResultRank` trait) and results types, if known.

// delegate function that returns true if type is a shaped type with known
// rank
auto hasRank = [](const Type type) {
if (auto shaped_type = dyn_cast<ShapedType>(type))
return shaped_type.hasRank();

return false;
};

auto rankedOperandTypes =
llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));

auto rankedResultTypes =
llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);

// If all operands and results are unranked, then no further verification.
if (rankedOperandTypes.empty() && rankedResultTypes.empty())
TypeRange operandTypes = getOperandTypes();
ShapedType aType = cast<ShapedType>(operandTypes[0]);
ShapedType bType = cast<ShapedType>(operandTypes[1]);

const bool aHasRank = aType.hasRank();
const bool bHasRank = bType.hasRank();
if (aHasRank && bHasRank) {
const int64_t aRank = aType.getRank();
const int64_t bRank = bType.getRank();
if (aRank != bRank)
return emitOpError("a and b operands don't have matching ranks, got ")
<< aRank << " and " << bRank;

// check for broadcast compatible shapes
SmallVector<int64_t> resultShape;
if (!mlir::OpTrait::util::getBroadcastedShape(
aType.getShape(), bType.getShape(), resultShape))
return emitOpError("a and b operands don't have broadcast-compatible "
"shapes, got ")
<< aType << " and " << bType;
}

ShapedType resultType = cast<ShapedType>(output.getType());
if (!resultType.hasRank())
return success();

// delegate function that returns rank of shaped type with known rank
auto getRank = [](const Type type) {
return cast<ShapedType>(type).getRank();
};

auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
: getRank(*rankedResultTypes.begin());

for (size_t i = 0; i < 2; ++i) {
if (rank != getRank(rankedOperandTypes[i])) {
return emitOpError("operands don't have matching ranks");
}
}

for (const auto type : rankedResultTypes) {
if (rank != getRank(type)) {
return emitOpError("result type has different rank than operands");
}
}

// check for broadcast compatible shapes in first two operands (ignoring
// shift)

// delegate function that returns shape of shaped type
auto getShape = [](const Type type) {
return mlir::cast<ShapedType>(type).getShape();
};
SmallVector<int64_t> resultShape;
if (!mlir::OpTrait::util::getBroadcastedShape(getShape(rankedOperandTypes[0]),
getShape(rankedOperandTypes[1]),
resultShape)) {
return emitOpError("operands don't have broadcast-compatible shapes");
}
const int64_t resultRank = resultType.getRank();
if (aHasRank && resultRank != aType.getRank())
return emitOpError("result type has different rank than a, got ")
<< resultRank << " vs " << aType.getRank();
if (bHasRank && resultRank != bType.getRank())
return emitOpError("result type has different rank than b, got ")
<< resultRank << " vs " << bType.getRank();

return success();
}
Expand Down
29 changes: 28 additions & 1 deletion mlir/test/Dialect/Tosa/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1107,11 +1107,38 @@ func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tenso
// CHECK-LABEL: test_mul_non_broadcast
func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
%shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
// expected-error@+1 {{'tosa.mul' op a and b operands don't have broadcast-compatible shapes, got 'tensor<13x21x2xf32>' and 'tensor<3x1x3xf32>'}}
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: test_mul_different_operand_ranks
// Negative test: the verifier must reject tosa.mul when both a and b are
// ranked but their ranks differ (rank 2 vs rank 3 here).
func.func @test_mul_different_operand_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> {
  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  // expected-error@+1 {{'tosa.mul' op a and b operands don't have matching ranks, got 2 and 3}}
  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: test_mul_different_a_and_result_ranks
// Negative test: b is unranked, so the a/b rank check is skipped, but the
// ranked result (rank 3) must still match the ranked a operand (rank 2).
func.func @test_mul_different_a_and_result_ranks(%arg0: tensor<13x21xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  // expected-error@+1 {{'tosa.mul' op result type has different rank than a, got 3 vs 2}}
  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: test_mul_different_b_and_result_ranks
// Negative test: a is unranked, so the a/b rank check is skipped, but the
// ranked result (rank 3) must still match the ranked b operand (rank 2).
func.func @test_mul_different_b_and_result_ranks(%arg0: tensor<*xf32>, %arg1: tensor<13x12xf32>) -> tensor<13x21x3xf32> {
  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  // expected-error@+1 {{'tosa.mul' op result type has different rank than b, got 3 vs 2}}
  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<13x12xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: test_resize_invalid_scale_values
func.func @test_resize_invalid_scale_values(%arg0: tensor<1x8x8x8xf32>) -> tensor<?x?x?x?xf32> {
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,22 @@ func.func @test_mul_relaxed_result_type(%arg0: tensor<13x21x3xi16>, %arg1: tenso
return %0 : tensor<13x21x3xi16>
}

// -----
// CHECK-LABEL: test_mul_unranked_b
// Positive test: an unranked b operand must be accepted — rank and broadcast
// checks only apply between operands whose ranks are known, and result rank 3
// matches the ranked a operand.
func.func @test_mul_unranked_b(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: test_mul_unranked_a_and_b
// Positive test: with both a and b unranked there is nothing to compare the
// ranked result against, so the op must verify successfully.
func.func @test_mul_unranked_a_and_b(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<13x21x3xf32> {
  %shift = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
  %0 = tosa.mul %arg0, %arg1, %shift : (tensor<*xf32>, tensor<*xf32>, tensor<1xi8>) -> tensor<13x21x3xf32>
  return %0 : tensor<13x21x3xf32>
}

// -----
// CHECK-LABEL: pow
func.func @test_pow(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x1xf32>) -> tensor<13x21x3xf32> {
Expand Down
Loading