-
Notifications
You must be signed in to change notification settings - Fork 14.1k
[MLIR][NVVM] Update dot.accumulate.4way NVVM Op #141223
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Update dot.accumulate.4way NVVM Op #141223
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Srinivasa Ravi (Wolfram70) ChangesThis change refactors and updates the dot.accumulate.4way NVVM Op to be more descriptive and readable. Full diff: https://github.com/llvm/llvm-project/pull/141223.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0c5c87cfe002f..166c6821c0743 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3533,36 +3533,38 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
}
//===----------------------------------------------------------------------===//
-// NVVM dot.accumulate.4way Op
+// NVVM dot.accumulate Ops
//===----------------------------------------------------------------------===//
-def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
-def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
+def DotAccumulateUnsigned : I32EnumAttrCase<"UNSIGNED", 0, "unsigned">;
+def DotAccumulateSigned : I32EnumAttrCase<"SIGNED", 1, "signed">;
-def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
- "NVVM DotAccumulate4WayType",
- [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
+def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
+ "NVVM DotAccumulateType",
+ [DotAccumulateSigned, DotAccumulateUnsigned]> {
let cppNamespace = "::mlir::NVVM";
let genSpecializedAttr = 0;
}
-def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
+def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType, "dot_accumulate_type"> {
let assemblyFormat = "`<` $value `>`";
}
def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
- let summary = "Four-way byte dot product-accumulate instruction.";
+ let summary = "Four-way byte dot product-accumulate instruction";
let description = [{
Performs a four-way byte dot-product which is accumulated in a 32-bit
result.
Operand `a` and `b` are vectors of 4 bytes between which the dot product is
computed.
+
The `a_type` and `b_type` attributes specify the type of the elements in `a`
and `b` respectively.
- If `a_type` or `b_type` is `s8`, then the elements in the corresponding
+ If `a_type` or `b_type` is `signed`, then the elements in the corresponding
vector are sign-extended to 32-bit before the dot product is computed.
- If `a_type` or `b_type` is `u8`, then the elements in the corresponding
- vector are zero-extended to 32-bit instead.
+ If `a_type` or `b_type` is `unsigned`, then the elements in the
+ corresponding vector are zero-extended to 32-bit instead.
+
Operand `c` is a 32-bit integer to which the result is accumulated. It is
treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.
@@ -3571,9 +3573,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
let arguments = (ins
VectorOfLengthAndType<[4], [I8]>:$a,
- DotAccumulate4WayTypeAttr:$a_type,
+ DotAccumulateTypeAttr:$a_type,
VectorOfLengthAndType<[4], [I8]>:$b,
- DotAccumulate4WayTypeAttr:$b_type,
+ DotAccumulateTypeAttr:$b_type,
I32:$c
);
@@ -3582,17 +3584,15 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
let extraClassDeclaration = [{
- static llvm::Intrinsic::ID
- getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
- NVVM::DotAccumulate4WayType b_type);
- llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
+ static mlir::NVVM::IDArgPair
+ getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder);
}];
string llvmBuilder = [{
- llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
- llvm::Value* argA = op.getPackedArg($a, builder);
- llvm::Value* argB = op.getPackedArg($b, builder);
- $res = createIntrinsicCall(builder, id, {argA, argB, $c});
+ auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
+ *op, moduleTranslation, builder);
+ $res = createIntrinsicCall(builder, id, args);
}];
}
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 9f55fe315106c..ef3067b4e351d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1205,13 +1205,6 @@ LogicalResult NVVM::VoteSyncOp::verify() {
return success();
}
-llvm::Value *
-NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
- llvm::IRBuilderBase &builder) {
- return builder.CreateBitCast(arg,
- llvm::Type::getInt32Ty(builder.getContext()));
-}
-
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
@@ -1627,24 +1620,31 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
-llvm::Intrinsic::ID
-DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
- NVVM::DotAccumulate4WayType b_type) {
- bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
- bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
+static llvm::Value *getAsPackedI32(llvm::Value *arg,
+ llvm::IRBuilderBase &builder) {
+ return builder.CreateBitCast(arg,
+ llvm::Type::getInt32Ty(builder.getContext()));
+}
+
+NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
+ Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+ auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
+
+ llvm::SmallVector<llvm::Value *> args;
+ args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
+ args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
+ args.push_back(mt.lookupValue(curOp.getC()));
+
+ bool is_a_siext = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
+ bool is_b_siext = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
unsigned type = (is_a_siext << 1) | is_b_siext;
- switch (type) {
- case 0:
- return llvm::Intrinsic::nvvm_idp4a_u_u;
- case 1:
- return llvm::Intrinsic::nvvm_idp4a_u_s;
- case 2:
- return llvm::Intrinsic::nvvm_idp4a_s_u;
- case 3:
- return llvm::Intrinsic::nvvm_idp4a_s_s;
- default:
- llvm_unreachable("Invalid DP4a type");
- }
+ const llvm::Intrinsic::ID ids[] = {
+ llvm::Intrinsic::nvvm_idp4a_u_u,
+ llvm::Intrinsic::nvvm_idp4a_u_s,
+ llvm::Intrinsic::nvvm_idp4a_s_u,
+ llvm::Intrinsic::nvvm_idp4a_s_s,
+ };
+ return {ids[type], args};
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e8425638cc9be..77b302155cb12 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -579,11 +579,11 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
}
// CHECK-LABEL: @dot_accumulate_4way
-func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
+func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
- %1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
+ %1 = nvvm.dot.accumulate.4way %a_vec <unsigned>, %b_vec <unsigned>, %c: vector<4xi8>, vector<4xi8>
// CHECK: nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
- %3 = nvvm.dot.accumulate.4way %a_vec <s8>, %b_vec <s8>, %c: vector<4xi8>, vector<4xi8>
+ %3 = nvvm.dot.accumulate.4way %a_vec <signed>, %b_vec <signed>, %c: vector<4xi8>, vector<4xi8>
return
}
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 90519a9402621..0be9007dd53e4 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -851,18 +851,18 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+ %0 = nvvm.dot.accumulate.4way %a <unsigned>, %b <unsigned>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+ %1 = nvvm.dot.accumulate.4way %a <signed>, %b <unsigned>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+ %2 = nvvm.dot.accumulate.4way %a <unsigned>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
- %3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+ %3 = nvvm.dot.accumulate.4way %a <signed>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
llvm.return
}
|
12cfdab
to
380fcb8
Compare
380fcb8
to
c82edd3
Compare
This change refactors and updates the dot.accumulate.4way NVVM Op to be more descriptive and readable.
c82edd3
to
807ad35
Compare
Merging as this PR is just a split from #140518 where the changes are already reviewed. |
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32 | ||
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32 | ||
// CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}}) | ||
%2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8> | ||
%2 = nvvm.dot.accumulate.4way %a <unsigned>, %b <signed>, %c: vector<4xi8>, vector<4xi8> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I do understand this change. It is nice when you think this PR in isolation.
Thats said, we want to use mlir builtin types. When we do, we need to change IR again because they don't print 'unsigned'. Considering that, should we make this change?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I am not sure if we want to use the builtin types for this Op specifically (the signed 8-bit integer SI8
) since from what I understand, the signed types can't be directly lowered to LLVM. But if that changes in the future and we don't want to change the existing IR, maybe one potential solution would be to make these attributes optional and also support lowering through the types themselves? Please let me know what you think.
I think like we discussed in #139043 (comment), the cleaner solution would be different Ops for signed/unsigned (like in the arith
dialect) but in this case since we have four different intrinsics to lower to instead, this seemed like a good compromise.
This change refactors and updates the `dot.accumulate.4way` NVVM Op to be more descriptive and readable.
This change refactors and updates the `dot.accumulate.4way` NVVM Op to be more descriptive and readable.
This change refactors and updates the `dot.accumulate.4way` NVVM Op to be more descriptive and readable.
This change refactors and updates the
dot.accumulate.4way
NVVM Op to be more descriptive and readable.