Improve type constraints for AIEVec_MulElemOp #1487

muradq-amd

This PR adds constraints to aievec::mul_elem on the number of lanes and on the types of its operands and result.

Changes:

  • Add constraints to the LHS and RHS operands and result (AIE2MulElemLHS, AIE2MulElemRHS, and AIE2MulElemACC)
  • Add PredOpTrait constraints to allow only the supported type/lane-count combinations.


jsetoain left a comment

I understand you've followed the constraint scheme for aievec.matmul, but that is a very bad template for this use case: matmul has a very irregular set of type combinations, while this one is pretty regular.

You need to check that lhs/rhs match, and that either the element type of lhs/rhs matches acc, or acc is the "wide type" (i32 for integrals, f32 for floats).
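
The rule in the previous paragraph can be sketched as a small standalone C++ model. This is not MLIR code: `ElemKind`, `VecTy`, and the function names are hypothetical stand-ins for the real types and predicates, and bf16 is modeled simply as a 16-bit float.

```cpp
// Standalone model of the rule above; illustration only, not the MLIR API.
// ElemKind and VecTy are hypothetical stand-ins for mlir::Type/VectorType.
enum class ElemKind { Int, Float };

struct VecTy {
  unsigned lanes;  // number of vector lanes
  ElemKind kind;   // integral or floating-point element
  unsigned width;  // element bit width (bf16 is modeled as Float, 16)
};

// Two vector types match when lanes, kind, and width all agree.
constexpr bool sameType(VecTy a, VecTy b) {
  return a.lanes == b.lanes && a.kind == b.kind && a.width == b.width;
}

// "Wide type" as stated in the comment: i32 for integrals, f32 for floats,
// with the same number of lanes as the operand.
constexpr bool isWideOf(VecTy acc, VecTy operand) {
  return acc.lanes == operand.lanes && acc.kind == operand.kind &&
         acc.width == 32;
}

// lhs and rhs must match; acc is either the operand type or its wide type.
constexpr bool isValidMulElem(VecTy lhs, VecTy rhs, VecTy acc) {
  return sameType(lhs, rhs) && (sameType(acc, lhs) || isWideOf(acc, lhs));
}
```

Under this model, i8 x i8 with an i32 accumulator is accepted, while mismatched operands or an accumulator of a different kind are rejected. Whether the accumulator for i32 operands should instead be i64 is discussed further down the thread.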

Comment on lines 147 to 165
def AIE2MulElemLHS :
    AnyTypeOf<[VectorOfShapeAndType<[32], I8>,
               VectorOfShapeAndType<[32], I16>,
               VectorOfShapeAndType<[16], I32>,
               VectorOfShapeAndType<[16], BF16>,
               VectorOfShapeAndType<[16], F32>],
              "a vector compatible with a lhs operand of element-wise multiply and "
              # "accumulate",
              "::mlir::VectorType">;

def AIE2MulElemRHS :
    AnyTypeOf<[VectorOfShapeAndType<[32], I8>,
               VectorOfShapeAndType<[32], I16>,
               VectorOfShapeAndType<[16], I32>,
               VectorOfShapeAndType<[16], BF16>,
               VectorOfShapeAndType<[16], F32>],
              "a vector compatible with a rhs operand of element-wise multiply and "
              # "accumulate",
              "::mlir::VectorType">;

This is overkill. Lhs & Rhs types have to match, so you don't need to spell out every single valid type for each of the operands.

The op verifier checks that the lanes and datatypes are the same for operands. See

// Additional checks for FMAElem op
// Get the width of the underlying scalars of all the vectors
Type ltype = lhsType.getElementType();
Type rtype = rhsType.getElementType();
Type atype = resultType.getElementType();
unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
unsigned atypeWidth = atype.getIntOrFloatBitWidth();
// Checks on the number of lanes
unsigned rhsLanes = getVectorLaneSize(rhsType);
unsigned lhsLanes = getVectorLaneSize(lhsType);
// lane size must match
if (lhsLanes != rhsLanes) {
  return op.emitError("The number of lanes in lhs operand "
                      "must be the same as rhs operand");
}
// lhs and rhs vector's element type must match
if (ltype != rtype)
  return op.emitError("The element type of lhs and rhs "
                      "operand vectors must match");

This can probably be replaced by SameTypeOperands and SameOperandsAndResultShape traits....

However, I'm not sure how to better constrain the fact that the acc is i32 for i8/i16 operands, i64 for i32 operands, and f32 for bf16/f32 operands. Maybe just constrain it in the op verifier in cpp code?

> The op verifier checks that the lanes and datatypes are the same for operands. See
>
> // Additional checks for FMAElem op
> // Get the width of the underlying scalars of all the vectors
> Type ltype = lhsType.getElementType();
> Type rtype = rhsType.getElementType();
> Type atype = resultType.getElementType();
> unsigned ltypeWidth = ltype.getIntOrFloatBitWidth();
> unsigned rtypeWidth = rtype.getIntOrFloatBitWidth();
> unsigned atypeWidth = atype.getIntOrFloatBitWidth();
> // Checks on the number of lanes
> unsigned rhsLanes = getVectorLaneSize(rhsType);
> unsigned lhsLanes = getVectorLaneSize(lhsType);
> // lane size must match
> if (lhsLanes != rhsLanes) {
>   return op.emitError("The number of lanes in lhs operand "
>                       "must be the same as rhs operand");
> }
> // lhs and rhs vector's element type must match
> if (ltype != rtype)
>   return op.emitError("The element type of lhs and rhs "
>                       "operand vectors must match");

Notice that the C++ verifier will be replaced by the automatically generated one.

> This can probably be replaced by SameTypeOperands and SameOperandsAndResultShape traits....

Indeed.

> However, I'm not sure how to better constrain the fact that the acc is i32 for i8/i16 operands, i64 for i32 operands, and f32 for bf16/f32 operands. Maybe just constrain it in the op verifier in cpp code?

For that, you can either define a new predicate (`IsValidAccumulatorTypeFor`) that verifies the accumulator type validity using simpler pre-existing predicates, or define a C++ function that makes that check and invoke it from the predicate.
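
As a rough illustration of the second option, the check could look something like the standalone C++ sketch below. It only models the mapping quoted above (i32 for i8/i16, i64 for i32, f32 for bf16/f32); `ElemKind`, `ElemTy`, `accWidthFor`, and `isValidAccumulatorTypeFor` are hypothetical names, and a real helper would operate on `mlir::Type` values instead.

```cpp
// Standalone model of the accumulator mapping quoted above; illustration
// only. ElemKind and ElemTy are hypothetical stand-ins for mlir::Type.
enum class ElemKind { Int, Float };

struct ElemTy {
  ElemKind kind;   // integral or floating-point element
  unsigned width;  // bit width (a 16-bit Float models bf16 here)
};

// Accumulator element width expected for an operand element type.
// Returns 0 when the operand type has no accumulator in the mapping.
constexpr unsigned accWidthFor(ElemTy operand) {
  return operand.kind == ElemKind::Float
             ? ((operand.width == 16 || operand.width == 32) ? 32u : 0u)
         : (operand.width == 8 || operand.width == 16) ? 32u
         : operand.width == 32                          ? 64u
                                                        : 0u;
}

// The accumulator must have the operand's kind and the mapped width.
constexpr bool isValidAccumulatorTypeFor(ElemTy operand, ElemTy acc) {
  return accWidthFor(operand) != 0 && acc.kind == operand.kind &&
         acc.width == accWidthFor(operand);
}
```

A function of this shape keeps the mapping in one place, so the TableGen predicate only needs to call it rather than enumerate every combination.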

Comment on lines 178 to 205
class IsValidAIE2MulElemShapeAndType<string lhs, string rhs, string acc> :
  PredOpTrait<lhs # " x " # rhs # " = " # acc # " is a valid AIE2 " #
              "element-wise multiply and accumulate op",
    Or<[VectorTypesMatch<lhs, VectorOfShapeAndType<[32], I8>,
                         rhs, VectorOfShapeAndType<[32], I8>,
                         acc, VectorOfShapeAndType<[32], I8>>,
        VectorTypesMatch<lhs, VectorOfShapeAndType<[32], I8>,
                         rhs, VectorOfShapeAndType<[32], I8>,
                         acc, VectorOfShapeAndType<[32], I32>>,
        VectorTypesMatch<lhs, VectorOfShapeAndType<[32], I16>,
                         rhs, VectorOfShapeAndType<[32], I16>,
                         acc, VectorOfShapeAndType<[32], I16>>,
        VectorTypesMatch<lhs, VectorOfShapeAndType<[32], I16>,
                         rhs, VectorOfShapeAndType<[32], I16>,
                         acc, VectorOfShapeAndType<[32], I32>>,

        VectorTypesMatch<lhs, VectorOfShapeAndType<[16], I32>,
                         rhs, VectorOfShapeAndType<[16], I32>,
                         acc, VectorOfShapeAndType<[16], I32>>,
        VectorTypesMatch<lhs, VectorOfShapeAndType<[16], BF16>,
                         rhs, VectorOfShapeAndType<[16], BF16>,
                         acc, VectorOfShapeAndType<[16], BF16>>,
        VectorTypesMatch<lhs, VectorOfShapeAndType<[16], BF16>,
                         rhs, VectorOfShapeAndType<[16], BF16>,
                         acc, VectorOfShapeAndType<[16], F32>>,
        VectorTypesMatch<lhs, VectorOfShapeAndType<[16], F32>,
                         rhs, VectorOfShapeAndType<[16], F32>,
                         acc, VectorOfShapeAndType<[16], F32>>]>>;

This constraint alone already performs all the checks necessary for the supported types, but it's unnecessarily complex for such a regular operation.

Comment on lines 869 to 876
`vector<32xi8>` | `vector<32xi8>` | `vector<32xi8>`
`vector<32xi8>` | `vector<32xi8>` | `vector<32xi32>`
`vector<32xi16>` | `vector<32xi16>` | `vector<32xi16>`
`vector<32xi16>` | `vector<32xi16>` | `vector<32xi32>`
`vector<16xi32>` | `vector<16xi32>` | `vector<16xi32>`
`vector<16xbf16>` | `vector<16xbf16>` | `vector<16xbf16>`
`vector<16xbf16>` | `vector<16xbf16>` | `vector<16xf32>`
`vector<16xf32>` | `vector<16xf32>` | `vector<16xf32>`

For aievec.mul_elem the accumulator is acc32 (i32) for i8/i16 and acc64 (i64) for i32. Your Accumulator column here looks more like the result type of an arith.mulf/i op, as seen in the e2e unit tests. To handle different result types, we insert either aievec.cast or aievec.srs on the result of aievec.mul_elem. See

if (mulElemResultElWidth == resultElWidth) {
  rewriter.replaceOpWithNewOp<aievec::CastOp>(
      mulOp, resultType, mulElemOp.getResult(), /*isResAcc*/ false);
} else if (mulElemResultElWidth > resultElWidth) {
  auto shiftParamOp = rewriter.create<arith::ConstantOp>(
      mulOp.getLoc(), rewriter.getI32IntegerAttr(shiftParam));
  rewriter.replaceOpWithNewOp<aievec::SRSOp>(
      mulOp, resultType, mulElemOp.getResult(), shiftParamOp.getResult());
} else {
  return failure();
}
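
The branch structure of the snippet above can be isolated into a small standalone sketch that depends only on the two element widths. The enum `Fixup` and the function `chooseResultFixup` are hypothetical names introduced for illustration; the actual pattern creates the ops directly through the rewriter.

```cpp
// Standalone sketch of the width-based decision in the snippet above;
// illustration only. Fixup and chooseResultFixup are hypothetical names.
enum class Fixup {
  Cast,        // widths match: reinterpret the result with aievec.cast
  Srs,         // mul_elem result is wider: narrow it with aievec.srs
  Unsupported  // mul_elem result is narrower: the pattern fails
};

constexpr Fixup chooseResultFixup(unsigned mulElemResultElWidth,
                                  unsigned resultElWidth) {
  return mulElemResultElWidth == resultElWidth  ? Fixup::Cast
         : mulElemResultElWidth > resultElWidth ? Fixup::Srs
                                                : Fixup::Unsupported;
}
```

For example, an i32 accumulator feeding an i8 result would go through the `Srs` path, while an i32 accumulator feeding an i32 result only needs the cast.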

muradq-amd

Thanks @jamestcl-amd, @jsetoain for the comments.
I just revised the code changes and added simplified constraints on the types and shapes of aievec::mul_elem's operands/results. Here is the list of supported type/shape combinations:

lhs          rhs          acc
<32xi8>      <32xi8>      <32xi8>
<32xi8>      <32xi8>      <32xi32>
<64xi8>      <64xi8>      <32xi32>
<32xi16>     <32xi16>     <32xi16>
<32xi16>     <32xi16>     <32xi32>
<16xi32>     <16xi32>     <16xi32>
<16xi32>     <16xi32>     <16xi64>
<16xbf16>    <16xbf16>    <16xbf16>
<16xbf16>    <16xbf16>    <16xf32>
<16xf32>     <16xf32>     <16xf32>
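
For reference, the list above can be transcribed into a table-driven lookup, sketched here as standalone C++. The rows are copied verbatim from the list; the types `Vec` and `Combo`, the array `kSupported`, and the function `isListed` are hypothetical names, and this is only a model of the list, not the constraint code in the PR.

```cpp
// Standalone transcription of the list above; illustration only.
// kind: 'i' = integer, 'b' = bf16, 'f' = f32 (hypothetical encoding).
struct Vec {
  unsigned lanes;
  char kind;
  unsigned width;
};

struct Combo {
  Vec lhs, rhs, acc;
};

constexpr Combo kSupported[] = {
    {{32, 'i', 8}, {32, 'i', 8}, {32, 'i', 8}},
    {{32, 'i', 8}, {32, 'i', 8}, {32, 'i', 32}},
    {{64, 'i', 8}, {64, 'i', 8}, {32, 'i', 32}},
    {{32, 'i', 16}, {32, 'i', 16}, {32, 'i', 16}},
    {{32, 'i', 16}, {32, 'i', 16}, {32, 'i', 32}},
    {{16, 'i', 32}, {16, 'i', 32}, {16, 'i', 32}},
    {{16, 'i', 32}, {16, 'i', 32}, {16, 'i', 64}},
    {{16, 'b', 16}, {16, 'b', 16}, {16, 'b', 16}},
    {{16, 'b', 16}, {16, 'b', 16}, {16, 'f', 32}},
    {{16, 'f', 32}, {16, 'f', 32}, {16, 'f', 32}},
};
constexpr unsigned kNumSupported = sizeof(kSupported) / sizeof(kSupported[0]);

constexpr bool eq(Vec a, Vec b) {
  return a.lanes == b.lanes && a.kind == b.kind && a.width == b.width;
}

// Recursive linear search so the lookup also works in constant expressions.
constexpr bool isListed(Vec lhs, Vec rhs, Vec acc, unsigned i = 0) {
  return i < kNumSupported &&
         ((eq(kSupported[i].lhs, lhs) && eq(kSupported[i].rhs, rhs) &&
           eq(kSupported[i].acc, acc)) ||
          isListed(lhs, rhs, acc, i + 1));
}
```

A lookup like this is handy as a test oracle: any simplified predicate can be compared against it row by row, including the `<64xi8>` row, whose accumulator has a different lane count from its operands.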

Please let me know if you see anything that needs to be fixed. Thanks.

platforms/boards Outdated

This shouldn't be in the commit.

@@ -101,6 +101,10 @@ class ShapesCompatibleWithContraction<string lhs, string rhs, string acc> :
class VectorType<string name> : StrFunc<"cast<VectorType>($" # name #
".getType())">;

class VectorElementType<string name> :

jsetoain, May 21, 2024

These already exist in upstream mlir. Check llvm-project/mlir/include/mlir/IR/OpBase.td. There's a lot of stuff there you'll find helpful.

Comment on lines 319 to 321
Arguments<(ins AnyVector:$lhs,
AnyVector:$rhs)>,
Results<(outs AnyVector:$acc)> {

We don't support AnyVector. Use an operand type constraint for these, and then your op constraint will be simpler.
