Skip to content

[MLIR][NVVM] Update dot.accumulate.4way NVVM Op #141223

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 29, 2025

Conversation

Wolfram70
Copy link
Contributor

@Wolfram70 Wolfram70 commented May 23, 2025

This change refactors and updates the dot.accumulate.4way NVVM Op to be more descriptive and readable.

@llvmbot
Copy link
Member

llvmbot commented May 23, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Srinivasa Ravi (Wolfram70)

Changes

This change refactors and updates the dot.accumulate.4way NVVM Op to be more descriptive and readable.


Full diff: https://github.com/llvm/llvm-project/pull/141223.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+21-21)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+24-24)
  • (modified) mlir/test/Dialect/LLVMIR/nvvm.mlir (+3-3)
  • (modified) mlir/test/Target/LLVMIR/nvvmir.mlir (+4-4)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 0c5c87cfe002f..166c6821c0743 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3533,36 +3533,38 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st", [NVVMRequiresSMa<[100, 101]>]> {
 }
 
 //===----------------------------------------------------------------------===//
-// NVVM dot.accumulate.4way Op
+// NVVM dot.accumulate Ops
 //===----------------------------------------------------------------------===//
 
-def DotAccumulate4WayS8 : I32EnumAttrCase<"S8", 1, "s8">;
-def DotAccumulate4WayU8 : I32EnumAttrCase<"U8", 0, "u8">;
+def DotAccumulateUnsigned : I32EnumAttrCase<"UNSIGNED", 0, "unsigned">;
+def DotAccumulateSigned : I32EnumAttrCase<"SIGNED", 1, "signed">;
 
-def DotAccumulate4WayType : I32EnumAttr<"DotAccumulate4WayType",
-                              "NVVM DotAccumulate4WayType",
-                              [DotAccumulate4WayS8, DotAccumulate4WayU8]> {
+def DotAccumulateType : I32EnumAttr<"DotAccumulateType",
+                              "NVVM DotAccumulateType",
+                              [DotAccumulateSigned, DotAccumulateUnsigned]> {
   let cppNamespace = "::mlir::NVVM";
   let genSpecializedAttr = 0;
 }
 
-def DotAccumulate4WayTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulate4WayType, "dot_accumulate_4way_type"> {
+def DotAccumulateTypeAttr : EnumAttr<NVVM_Dialect, DotAccumulateType, "dot_accumulate_type"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
 def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
-  let summary = "Four-way byte dot product-accumulate instruction.";
+  let summary = "Four-way byte dot product-accumulate instruction";
   let description = [{
     Performs a four-way byte dot-product which is accumulated in a 32-bit
     result.
     Operand `a` and `b` are vectors of 4 bytes between which the dot product is 
     computed.
+
     The `a_type` and `b_type` attributes specify the type of the elements in `a`
     and `b` respectively.
-    If `a_type` or `b_type` is `s8`, then the elements in the corresponding 
+    If `a_type` or `b_type` is `signed`, then the elements in the corresponding 
     vector are sign-extended to 32-bit before the dot product is computed.
-    If `a_type` or `b_type` is `u8`, then the elements in the corresponding 
-    vector are zero-extended to 32-bit instead.
+    If `a_type` or `b_type` is `unsigned`, then the elements in the 
+    corresponding vector are zero-extended to 32-bit instead.
+
     Operand `c` is a 32-bit integer to which the result is accumulated. It is
     treated as holding a signed integer if any of `a_type` or `b_type` is `s8`.
     
@@ -3571,9 +3573,9 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
   
   let arguments = (ins
     VectorOfLengthAndType<[4], [I8]>:$a,
-    DotAccumulate4WayTypeAttr:$a_type,
+    DotAccumulateTypeAttr:$a_type,
     VectorOfLengthAndType<[4], [I8]>:$b,
-    DotAccumulate4WayTypeAttr:$b_type,
+    DotAccumulateTypeAttr:$b_type,
     I32:$c
   );
 
@@ -3582,17 +3584,15 @@ def NVVM_DotAccumulate4WayOp : NVVM_Op<"dot.accumulate.4way"> {
   let assemblyFormat = "$a $a_type `,` $b $b_type `,` $c attr-dict `:` type($a) `,` type($b)";
   
   let extraClassDeclaration = [{
-    static llvm::Intrinsic::ID
-    getIntrinsicID(NVVM::DotAccumulate4WayType a_type, 
-                   NVVM::DotAccumulate4WayType b_type);
-    llvm::Value* getPackedArg(llvm::Value* arg, llvm::IRBuilderBase& builder);
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase &builder);
   }];
 
   string llvmBuilder = [{
-    llvm::Intrinsic::ID id = NVVM::DotAccumulate4WayOp::getIntrinsicID($a_type, $b_type);
-    llvm::Value* argA = op.getPackedArg($a, builder);
-    llvm::Value* argB = op.getPackedArg($b, builder);
-    $res = createIntrinsicCall(builder, id, {argA, argB, $c});
+    auto [id, args] = NVVM::DotAccumulate4WayOp::getIntrinsicIDAndArgs(
+                        *op, moduleTranslation, builder);
+    $res = createIntrinsicCall(builder, id, args);
   }];
 }
 
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 9f55fe315106c..ef3067b4e351d 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1205,13 +1205,6 @@ LogicalResult NVVM::VoteSyncOp::verify() {
   return success();
 }
 
-llvm::Value *
-NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
-                                        llvm::IRBuilderBase &builder) {
-  return builder.CreateBitCast(arg,
-                               llvm::Type::getInt32Ty(builder.getContext()));
-}
-
 //===----------------------------------------------------------------------===//
 // getIntrinsicID/getIntrinsicIDAndArgs methods
 //===----------------------------------------------------------------------===//
@@ -1627,24 +1620,31 @@ static void nvvmInferResultRanges(Operation *op, Value result,
   }
 }
 
-llvm::Intrinsic::ID
-DotAccumulate4WayOp::getIntrinsicID(NVVM::DotAccumulate4WayType a_type,
-                                    NVVM::DotAccumulate4WayType b_type) {
-  bool is_a_siext = a_type == NVVM::DotAccumulate4WayType::S8;
-  bool is_b_siext = b_type == NVVM::DotAccumulate4WayType::S8;
+static llvm::Value *getAsPackedI32(llvm::Value *arg,
+                                   llvm::IRBuilderBase &builder) {
+  return builder.CreateBitCast(arg,
+                               llvm::Type::getInt32Ty(builder.getContext()));
+}
+
+NVVM::IDArgPair DotAccumulate4WayOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto curOp = cast<NVVM::DotAccumulate4WayOp>(op);
+
+  llvm::SmallVector<llvm::Value *> args;
+  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getA()), builder));
+  args.push_back(getAsPackedI32(mt.lookupValue(curOp.getB()), builder));
+  args.push_back(mt.lookupValue(curOp.getC()));
+
+  bool is_a_siext = curOp.getAType() == NVVM::DotAccumulateType::SIGNED;
+  bool is_b_siext = curOp.getBType() == NVVM::DotAccumulateType::SIGNED;
   unsigned type = (is_a_siext << 1) | is_b_siext;
-  switch (type) {
-  case 0:
-    return llvm::Intrinsic::nvvm_idp4a_u_u;
-  case 1:
-    return llvm::Intrinsic::nvvm_idp4a_u_s;
-  case 2:
-    return llvm::Intrinsic::nvvm_idp4a_s_u;
-  case 3:
-    return llvm::Intrinsic::nvvm_idp4a_s_s;
-  default:
-    llvm_unreachable("Invalid DP4a type");
-  }
+  const llvm::Intrinsic::ID ids[] = {
+      llvm::Intrinsic::nvvm_idp4a_u_u,
+      llvm::Intrinsic::nvvm_idp4a_u_s,
+      llvm::Intrinsic::nvvm_idp4a_s_u,
+      llvm::Intrinsic::nvvm_idp4a_s_s,
+  };
+  return {ids[type], args};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index e8425638cc9be..77b302155cb12 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -579,11 +579,11 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
 }
 
 // CHECK-LABEL: @dot_accumulate_4way
-func.func @dot_accumulate_4way(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
+func.func @dot_accumulate_4way(%a_vec: vector<4xi8>, %b_vec: vector<4xi8>, %c: i32) {
   // CHECK:   nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
-  %1 = nvvm.dot.accumulate.4way %a_vec <u8>, %b_vec <u8>, %c: vector<4xi8>, vector<4xi8>
+  %1 = nvvm.dot.accumulate.4way %a_vec <unsigned>, %b_vec <unsigned>, %c: vector<4xi8>, vector<4xi8>
   // CHECK:   nvvm.dot.accumulate.4way %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
-  %3 = nvvm.dot.accumulate.4way %a_vec <s8>, %b_vec <s8>, %c: vector<4xi8>, vector<4xi8>
+  %3 = nvvm.dot.accumulate.4way %a_vec <signed>, %b_vec <signed>, %c: vector<4xi8>, vector<4xi8>
   return
 }
 
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 90519a9402621..0be9007dd53e4 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -851,18 +851,18 @@ llvm.func @nvvm_dot_accumulate_4way(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32)
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %0 = nvvm.dot.accumulate.4way %a <u8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+  %0 = nvvm.dot.accumulate.4way %a <unsigned>, %b <unsigned>, %c: vector<4xi8>, vector<4xi8>
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %1 = nvvm.dot.accumulate.4way %a <s8>, %b <u8>, %c: vector<4xi8>, vector<4xi8>
+  %1 = nvvm.dot.accumulate.4way %a <signed>, %b <unsigned>, %c: vector<4xi8>, vector<4xi8>
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+  %2 = nvvm.dot.accumulate.4way %a <unsigned>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
   // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
   // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
-  %3 = nvvm.dot.accumulate.4way %a <s8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
+  %3 = nvvm.dot.accumulate.4way %a <signed>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
   llvm.return
 }

@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-update-dp4a branch from 12cfdab to 380fcb8 Compare May 26, 2025 07:50
@Wolfram70 Wolfram70 self-assigned this May 26, 2025
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-update-dp4a branch from 380fcb8 to c82edd3 Compare May 28, 2025 08:32
This change refactors and updates the dot.accumulate.4way NVVM Op to
be more descriptive and readable.
@Wolfram70 Wolfram70 force-pushed the dev/Wolfram70/mlir-nvvm-update-dp4a branch from c82edd3 to 807ad35 Compare May 28, 2025 10:15
@Wolfram70
Copy link
Contributor Author

Merging as this PR is just a split from #140518 where the changes are already reviewed.

@Wolfram70 Wolfram70 merged commit aca088d into llvm:main May 29, 2025
11 checks passed
// CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
// CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
%2 = nvvm.dot.accumulate.4way %a <u8>, %b <s8>, %c: vector<4xi8>, vector<4xi8>
%2 = nvvm.dot.accumulate.4way %a <unsigned>, %b <signed>, %c: vector<4xi8>, vector<4xi8>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do understand this change. It is nice when you think this PR in isolation.

Thats said, we want to use mlir builtin types. When we do, we need to change IR again because they don't print 'unsigned'. Considering that, should we make this change?

Copy link
Contributor Author

@Wolfram70 Wolfram70 May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I am not sure if we want to use the builtin types for this Op specifically (the signed 8-bit integer SI8) since from what I understand, the signed types can't be directly lowered to LLVM. But if that changes in the future and we don't want to change the existing IR, maybe one potential solution would be to make these attributes optional and also support lowering through the types themselves? Please let me know what you think.

I think like we discussed in #139043 (comment), the cleaner solution would be different Ops for signed/unsigned (like in the arith dialect) but in this case since we have four different intrinsics to lower to instead, this seemed like a good compromise.

svkeerthy pushed a commit that referenced this pull request May 29, 2025
This change refactors and updates the `dot.accumulate.4way` NVVM Op to
be more descriptive and readable.
google-yfyang pushed a commit to google-yfyang/llvm-project that referenced this pull request May 29, 2025
This change refactors and updates the `dot.accumulate.4way` NVVM Op to
be more descriptive and readable.
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
This change refactors and updates the `dot.accumulate.4way` NVVM Op to
be more descriptive and readable.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants