Improve type constraints for AIEVec_MulElemOp #1487
I understand you've followed the constraint scheme for aievec.matmul, but that is a very bad template for this use case: that op has a very irregular type combination, while this one is pretty regular.
You need to check that lhs/rhs match, and that either the element type of lhs/rhs matches acc, or acc is the "wide type" (i32 for integrals, f32 for floats).
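As a rough illustration of the rule described above, here is a minimal standalone C++ sketch. It does not use MLIR: `VecTy`, `wideElemFor`, and `isRegularMulElem` are hypothetical names invented for this example, element types are modeled as plain strings, and the requirement that acc has the same lane count as the operands is an assumption of the sketch. It also does not restrict which element types are supported.

```cpp
#include <cassert>
#include <string>

// Hypothetical stand-in for an MLIR vector type: a lane count plus an element
// type name ("i8", "i16", "i32", "bf16", "f32"). Not part of mlir-aie.
struct VecTy {
  unsigned lanes;
  std::string elem;
  bool operator==(const VecTy &other) const {
    return lanes == other.lanes && elem == other.elem;
  }
};

// "Wide type" as described in the review comment: i32 for integers,
// f32 for floats.
inline std::string wideElemFor(const std::string &elem) {
  return (elem == "bf16" || elem == "f32") ? "f32" : "i32";
}

// lhs and rhs must match exactly; acc keeps the lane count (an assumption of
// this sketch) and uses either the operand element type or the wide type.
inline bool isRegularMulElem(const VecTy &lhs, const VecTy &rhs,
                             const VecTy &acc) {
  if (!(lhs == rhs))
    return false;
  if (acc.lanes != lhs.lanes)
    return false;
  return acc.elem == lhs.elem || acc.elem == wideElemFor(lhs.elem);
}
```

For example, `isRegularMulElem({32, "i8"}, {32, "i8"}, {32, "i32"})` holds, while mismatched operands such as `{32, "i8"}` and `{32, "i16"}` are rejected regardless of the accumulator.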
def AIE2MulElemLHS :
    AnyTypeOf<[VectorOfShapeAndType<[32], I8>,
               VectorOfShapeAndType<[32], I16>,
               VectorOfShapeAndType<[16], I32>,
               VectorOfShapeAndType<[16], BF16>,
               VectorOfShapeAndType<[16], F32>],
              "a vector compatible with a lhs operand of element-wise multiply and "
              # "accumulate",
              "::mlir::VectorType">;

def AIE2MulElemRHS :
    AnyTypeOf<[VectorOfShapeAndType<[32], I8>,
               VectorOfShapeAndType<[32], I16>,
               VectorOfShapeAndType<[16], I32>,
               VectorOfShapeAndType<[16], BF16>,
               VectorOfShapeAndType<[16], F32>],
              "a vector compatible with a rhs operand of element-wise multiply and "
              # "accumulate",
              "::mlir::VectorType">;
This is overkill. The lhs and rhs types have to match, so you don't need to spell out every single valid type for each of the operands.
The op verifier checks that the lanes and datatypes are the same for operands. See
mlir-aie/lib/Dialect/AIEVec/IR/AIEVecOps.cpp
Lines 872 to 894 in aba1887
// Additional checks for FMAElem op
// Get the width of the underlying scalars of all the vectors
Type ltype = lhsType.getElementType();
Type rtype = rhsType.getElementType();
Type atype = resultType.getElementType();
unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
unsigned atypeWidth = atype.getIntOrFloatBitWidth();
// Checks on the number of lanes
unsigned rhsLanes = getVectorLaneSize(rhsType);
unsigned lhsLanes = getVectorLaneSize(lhsType);
// lane size must match
if (lhsLanes != rhsLanes) {
  return op.emitError("The number of lanes in lhs operand "
                      "must be the same as rhs operand");
}
// lhs and rhs vector's element type must match
if (ltype != rtype)
  return op.emitError("The element type of lhs and rhs "
                      "operand vectors must match");
This can probably be replaced by SameTypeOperands and SameOperandsAndResultShape traits....
However, I'm not sure how to better constrain the fact that the acc is i32 for i8/i16 operands, i64 for i32 operands, and f32 for bf16/f32 operands. Maybe just constrain it in the op verifier in C++ code?
> The op verifier checks that the lanes and datatypes are the same for operands.
Notice that the C++ verifier will be replaced by the automatically generated one.
> This can probably be replaced by SameTypeOperands and SameOperandsAndResultShape traits....
Indeed.
> However, I'm not sure how to better constrain the fact that the acc is i32 for i8/i16 operands, i64 for i32 operands, and f32 for bf16/f32 operands. Maybe just constrain it in the op verifier in C++ code?
For that, you can either define a new predicate (`IsValidAccumulatorTypeFor`) that verifies the accumulator type validity using simpler pre-existing predicates, or define a C++ function that makes that check and invoke it from the predicate.
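The second option could look roughly like the standalone sketch below. It encodes only the mapping quoted in this thread (i8/i16 to i32, i32 to i64, bf16/f32 to f32). The function names `expectedAccElem` and `isValidAccumulatorTypeFor` are hypothetical, element types are modeled as strings rather than MLIR types, and how such a function would be wired into a TableGen predicate is left out.

```cpp
#include <cassert>
#include <string>

// Expected accumulator element type for a given operand element type,
// following the mapping quoted in the thread:
//   i8 / i16   -> i32
//   i32        -> i64
//   bf16 / f32 -> f32
// Returns an empty string for operand types the mapping does not cover.
inline std::string expectedAccElem(const std::string &operandElem) {
  if (operandElem == "i8" || operandElem == "i16")
    return "i32";
  if (operandElem == "i32")
    return "i64";
  if (operandElem == "bf16" || operandElem == "f32")
    return "f32";
  return "";
}

// Hypothetical check that a predicate could delegate to: true only when the
// operand type is covered and the accumulator type is the expected one.
inline bool isValidAccumulatorTypeFor(const std::string &operandElem,
                                      const std::string &accElem) {
  const std::string expected = expectedAccElem(operandElem);
  return !expected.empty() && expected == accElem;
}
```

Keeping the mapping in one small function means the predicate stays a one-line call, and any change to the supported accumulator types is made in a single place.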
class IsValidAIE2MulElemShapeAndType<string lhs, string rhs, string acc> :
    PredOpTrait<lhs # " x " # rhs # " = " # acc # " is a valid AIE2 " #
                "element-wise multiply and accumulate op",
        Or<[VectorTypesMatch<lhs, VectorOfShapeAndType<[32], I8>,
                             rhs, VectorOfShapeAndType<[32], I8>,
                             acc, VectorOfShapeAndType<[32], I8>>,
            VectorTypesMatch<lhs, VectorOfShapeAndType<[32], I8>,
                             rhs, VectorOfShapeAndType<[32], I8>,
                             acc, VectorOfShapeAndType<[32], I32>>,
            VectorTypesMatch<lhs, VectorOfShapeAndType<[32], I16>,
                             rhs, VectorOfShapeAndType<[32], I16>,
                             acc, VectorOfShapeAndType<[32], I16>>,
            VectorTypesMatch<lhs, VectorOfShapeAndType<[32], I16>,
                             rhs, VectorOfShapeAndType<[32], I16>,
                             acc, VectorOfShapeAndType<[32], I32>>,

            VectorTypesMatch<lhs, VectorOfShapeAndType<[16], I32>,
                             rhs, VectorOfShapeAndType<[16], I32>,
                             acc, VectorOfShapeAndType<[16], I32>>,
            VectorTypesMatch<lhs, VectorOfShapeAndType<[16], BF16>,
                             rhs, VectorOfShapeAndType<[16], BF16>,
                             acc, VectorOfShapeAndType<[16], BF16>>,
            VectorTypesMatch<lhs, VectorOfShapeAndType<[16], BF16>,
                             rhs, VectorOfShapeAndType<[16], BF16>,
                             acc, VectorOfShapeAndType<[16], F32>>,
            VectorTypesMatch<lhs, VectorOfShapeAndType<[16], F32>,
                             rhs, VectorOfShapeAndType<[16], F32>,
                             acc, VectorOfShapeAndType<[16], F32>>]>>;
This constraint alone already makes all the checks necessary for the supported types, but it's unnecessarily complex for such a regular operation.
`vector<32xi8>` | `vector<32xi8>` | `vector<32xi8>`
`vector<32xi8>` | `vector<32xi8>` | `vector<32xi32>`
`vector<32xi16>` | `vector<32xi16>` | `vector<32xi16>`
`vector<32xi16>` | `vector<32xi16>` | `vector<32xi32>`
`vector<16xi32>` | `vector<16xi32>` | `vector<16xi32>`
`vector<16xbf16>` | `vector<16xbf16>` | `vector<16xbf16>`
`vector<16xbf16>` | `vector<16xbf16>` | `vector<16xf32>`
`vector<16xf32>` | `vector<16xf32>` | `vector<16xf32>`
For aievec.mul_elem the accumulator is acc32 (i32) for i8/i16 and acc64 (i64) for i32. Your Accumulator column here is more like the result type of an arith.mulf/i op, from what you see in the e2e unit tests. To handle different result types, we apply either aievec.cast or aievec.srs to the result of aievec.mul_elem. See
mlir-aie/lib/Dialect/AIEVec/Transforms/VectorToAIEVecConversions.cpp
Lines 702 to 712 in aba1887
if (mulElemResultElWidth == resultElWidth) {
  rewriter.replaceOpWithNewOp<aievec::CastOp>(
      mulOp, resultType, mulElemOp.getResult(), /*isResAcc*/ false);
} else if (mulElemResultElWidth > resultElWidth) {
  auto shiftParamOp = rewriter.create<arith::ConstantOp>(
      mulOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
  rewriter.replaceOpWithNewOp<aievec::SRSOp>(
      mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
} else {
  return failure();
}
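The decision made in the excerpt above depends only on two element widths, so it can be isolated in a small standalone sketch. The enum `ResultFixup` and the function `chooseResultFixup` are hypothetical names used for illustration; the real conversion pattern rewrites MLIR ops rather than returning an enum.

```cpp
#include <cassert>

// Which op the conversion would place after aievec.mul_elem.
// Hypothetical enum for this sketch; not part of mlir-aie.
enum class ResultFixup {
  Cast,        // widths match: aievec.cast
  SRS,         // accumulator is wider than the result: aievec.srs
  Unsupported  // accumulator is narrower than the result: pattern fails
};

// Mirrors the branch structure of the excerpt: compare the element width
// produced by mul_elem with the element width of the requested result.
inline ResultFixup chooseResultFixup(unsigned mulElemResultElWidth,
                                     unsigned resultElWidth) {
  if (mulElemResultElWidth == resultElWidth)
    return ResultFixup::Cast;
  if (mulElemResultElWidth > resultElWidth)
    return ResultFixup::SRS;
  return ResultFixup::Unsupported;
}
```

With a 32-bit accumulator, a 32-bit result takes the cast path and an 8-bit or 16-bit result takes the srs path; a result wider than the accumulator is not handled.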
Thanks @jamestcl-amd, @jsetoain for the comments.
Please let me know if you see anything that needs to be fixed. Thanks.
platforms/boards
This shouldn't be in the commit.
@@ -101,6 +101,10 @@ class ShapesCompatibleWithContraction<string lhs, string rhs, string acc> :
class VectorType<string name> : StrFunc<"cast<VectorType>($" # name #
                                        ".getType())">;

class VectorElementType<string name> :
These already exist in upstream mlir. Check llvm-project/mlir/include/mlir/IR/OpBase.td. There's a lot of stuff there you'll find helpful.
Arguments<(ins AnyVector:$lhs,
               AnyVector:$rhs)>,
Results<(outs AnyVector:$acc)> {
We don't support AnyVector. You should use an operand type constraint for these, and then your op constraint will be simpler.
a38d63d to 37ef519
This PR adds constraints to aievec::mul_elem on the number of lanes and on the types of operands and results.
Changes: