@@ -1857,7 +1857,8 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents(
1857
1857
}
1858
1858
1859
1859
LogicalResult tosa::MulOp::verify () {
1860
- auto resElemType = getElementTypeOrSelf (getOutput ());
1860
+ const Value output = getOutput ();
1861
+ auto resElemType = getElementTypeOrSelf (output);
1861
1862
1862
1863
// Verify if the element type among operands and result match tosa
1863
1864
// specification.
@@ -1897,59 +1898,39 @@ LogicalResult tosa::MulOp::verify() {
1897
1898
// Verify the op has same ranks for all main operands (excludes extra operands
1898
1899
// such as shift of mul op, so this is the only difference with the built-in
1899
1900
// `SameOperandsAndResultRank` trait) and results types, if known.
1900
-
1901
- // delegate function that returns true if type is a shaped type with known
1902
- // rank
1903
- auto hasRank = [](const Type type) {
1904
- if (auto shaped_type = dyn_cast<ShapedType>(type))
1905
- return shaped_type.hasRank ();
1906
-
1907
- return false ;
1908
- };
1909
-
1910
- auto rankedOperandTypes =
1911
- llvm::to_vector (llvm::make_filter_range (getOperandTypes (), hasRank));
1912
-
1913
- auto rankedResultTypes =
1914
- llvm::make_filter_range (getOperation ()->getResultTypes (), hasRank);
1915
-
1916
- // If all operands and results are unranked, then no further verification.
1917
- if (rankedOperandTypes.empty () && rankedResultTypes.empty ())
1901
+ TypeRange operandTypes = getOperandTypes ();
1902
+ ShapedType aType = cast<ShapedType>(operandTypes[0 ]);
1903
+ ShapedType bType = cast<ShapedType>(operandTypes[1 ]);
1904
+
1905
+ const bool aHasRank = aType.hasRank ();
1906
+ const bool bHasRank = bType.hasRank ();
1907
+ if (aHasRank && bHasRank) {
1908
+ const int64_t aRank = aType.getRank ();
1909
+ const int64_t bRank = bType.getRank ();
1910
+ if (aRank != bRank)
1911
+ return emitOpError (" a and b operands don't have matching ranks, got " )
1912
+ << aRank << " and " << bRank;
1913
+
1914
+ // check for broadcast compatible shapes
1915
+ SmallVector<int64_t > resultShape;
1916
+ if (!mlir::OpTrait::util::getBroadcastedShape (
1917
+ aType.getShape (), bType.getShape (), resultShape))
1918
+ return emitOpError (" a and b operands don't have broadcast-compatible "
1919
+ " shapes, got " )
1920
+ << aType << " and " << bType;
1921
+ }
1922
+
1923
+ ShapedType resultType = cast<ShapedType>(output.getType ());
1924
+ if (!resultType.hasRank ())
1918
1925
return success ();
1919
1926
1920
- // delegate function that returns rank of shaped type with known rank
1921
- auto getRank = [](const Type type) {
1922
- return cast<ShapedType>(type).getRank ();
1923
- };
1924
-
1925
- auto rank = !rankedOperandTypes.empty () ? getRank (*rankedOperandTypes.begin ())
1926
- : getRank (*rankedResultTypes.begin ());
1927
-
1928
- for (size_t i = 0 ; i < 2 ; ++i) {
1929
- if (rank != getRank (rankedOperandTypes[i])) {
1930
- return emitOpError (" operands don't have matching ranks" );
1931
- }
1932
- }
1933
-
1934
- for (const auto type : rankedResultTypes) {
1935
- if (rank != getRank (type)) {
1936
- return emitOpError (" result type has different rank than operands" );
1937
- }
1938
- }
1939
-
1940
- // check for broadcast compatible shapes in first two operands (ignoring
1941
- // shift)
1942
-
1943
- // delegate function that returns shape of shaped type
1944
- auto getShape = [](const Type type) {
1945
- return mlir::cast<ShapedType>(type).getShape ();
1946
- };
1947
- SmallVector<int64_t > resultShape;
1948
- if (!mlir::OpTrait::util::getBroadcastedShape (getShape (rankedOperandTypes[0 ]),
1949
- getShape (rankedOperandTypes[1 ]),
1950
- resultShape)) {
1951
- return emitOpError (" operands don't have broadcast-compatible shapes" );
1952
- }
1927
+ const int64_t resultRank = resultType.getRank ();
1928
+ if (aHasRank && resultRank != aType.getRank ())
1929
+ return emitOpError (" result type has different rank than a, got " )
1930
+ << resultRank << " vs " << aType.getRank ();
1931
+ if (bHasRank && resultRank != bType.getRank ())
1932
+ return emitOpError (" result type has different rank than b, got " )
1933
+ << resultRank << " vs " << bType.getRank ();
1953
1934
1954
1935
return success ();
1955
1936
}
0 commit comments