-
Notifications
You must be signed in to change notification settings - Fork 14.1k
[MLIR][NVVM] Add Op to create tcgen05-mma smem descriptor #141651
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][NVVM] Add Op to create tcgen05-mma smem descriptor #141651
Conversation
This patch adds an Op to create the shared-memory descriptor for Tcgen05 MMA. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Durgadoss R (durga4github) Changes: This patch adds an Op to create the shared-memory descriptor for Tcgen05 MMA. Full diff: https://github.com/llvm/llvm-project/pull/141651.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 13f693872d890..408537be0a5e4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3373,6 +3373,70 @@ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
}];
}
+def NVVM_Tcgen05MmaSmemDescOp : NVVM_Op<"tcgen05.mma_smem_desc", []> {
+ let summary = "Constructs a Shared Memory descriptor for MMA Operands A or B";
+ let description = [{
+ The `nvvm.tcgen05.mma_smem_desc` constructs a Shared Memory descriptor
+ for tcgen05.mma. This descriptor is a 64-bit value which describes the
+ properties of the multiplicand matrix in shared memory including its location
+ in the shared memory of the current CTA.
+
+ +-----------+------+------------------------------------------------------+
+ | Bit-field | Size | Description |
+ +-----------+------+------------------------------------------------------+
+ | 0-13 | 14 | Matrix start address |
+ | 14-15 | 2 | Reserved |
+ | 16-29 | 14 | Leading dim relative-offset (or) absolute-address |
+ | 30-31 | 2 | Reserved |
+ | 32-45 | 14 | Stride dimension byte offset |
+ | 46-48 | 3 | Fixed constant value of 0b001 |
+ | 49-51 | 3 | Matrix base offset |
+ | 52 | 1 | Leading dimension stride mode: |
+ | | | 0: byte offset relative |
+ | | | 1: byte address absolute |
+ | 53-60 | 8 | Fixed constant value of 0b00000000 |
+ | 61-63 | 3 | Swizzling mode: |
+ | | | 0: No swizzling |
+ | | | 1: 128-Byte with 32B atomic swizzling |
+ | | | 2: 128-Byte swizzling |
+ | | | 4: 64-Byte swizzling |
+ | | | 6: 32-Byte swizzling |
+ | | | (Values 3, 5 and 7 are invalid) |
+ +-----------+------+------------------------------------------------------+
+
+ Example:
+ ```mlir
+ %desc = nvvm.tcgen05.mma_smem_desc (%startAddr, %leadingDimOffset, %strideDimOffset,
+ %baseOffset, %leadingDimMode, %swizzleMode) : (i32, i32, i32, i8, i1, i8) -> i64
+ ```
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor)
+ }];
+ // Operands are wider than their descriptor fields; the lowering masks each one to its field width.
+ let arguments = (ins
+ I32:$startAddr, // Matrix A or B start address (bits 13-0)
+ I32:$leadingDimOffset, // Matrix A or B leading dim byte offset (bits 29-16)
+ I32:$strideDimOffset, // Matrix A or B stride dim byte offset (bits 45-32)
+ I8:$baseOffset, // Matrix A or B base offset (bits 51-49)
+ I1:$leadingDimMode, // Matrix A or B leading dim mode (bit 52)
+ I8:$swizzleMode // Swizzle mode (bits 63-61)
+ );
+
+ let results = (outs I64:$res); // The packed 64-bit descriptor
+
+ let assemblyFormat = [{
+ `(` operands `)` attr-dict `:` `(` type(operands) `)` `->` type($res)
+ }];
+
+ let extraClassDeclaration = [{
+ static void createSmemDescriptor(Operation &op, LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase& builder);
+ }];
+
+ string llvmBuilder = [{
+ NVVM::Tcgen05MmaSmemDescOp::createSmemDescriptor(*op, moduleTranslation, builder);
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM tcgen05 LdSt Shape Attr
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 79d9d2f6255e7..8036ea27f524f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1212,6 +1212,50 @@ NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
llvm::Type::getInt32Ty(builder.getContext()));
}
+/// Packs the low `sizeInBits` bits of `field` into `result` at bit `start` and returns the OR-ed value.
+/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
+static llvm::Value *
+packValInto64Bits(llvm::IRBuilderBase &builder,
+ llvm::Value *result, // the `result` (unset bits are zero)
+ llvm::Value *field, // `field` to pack into `result`
+ unsigned sizeInBits, // Size of `field` in bits
+ unsigned start) { // Starting bit within `result`
+ field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty()); // bring narrower fields up to i32
+ // Keep only the low sizeInBits bits; the 32-bit case is special-cased so 1u is never shifted by 32.
+ unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
+ if (mask != 0xffffffffu)
+ field = builder.CreateAnd(field, builder.getInt32(mask));
+ // Widen to the result width, then move the field to its start bit.
+ field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
+ field = builder.CreateShl(field, start);
+ // OR-ing does not clear bits that are already set in result at that position.
+ return builder.CreateOr(result, field);
+}
+
+void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
+ LLVM::ModuleTranslation &mt,
+ llvm::IRBuilderBase &builder) { // emits IR that assembles the 64-bit descriptor
+ auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
+ llvm::Value *smemDesc = builder.getInt64(0); // all bits start at zero; fields are OR-ed in
+ // Each call passes (value, width in bits, start bit): three 14-bit operand fields at bits 0, 16 and 32.
+ smemDesc = packValInto64Bits(builder, smemDesc,
+ mt.lookupValue(thisOp.getStartAddr()), 14, 0);
+ smemDesc = packValInto64Bits(
+ builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
+ smemDesc = packValInto64Bits(
+ builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
+ // Constant 1 in the 3-bit field at bit 46, then base offset (bit 49), leading-dim mode (bit 52) and swizzle mode (bit 61).
+ smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
+ smemDesc = packValInto64Bits(builder, smemDesc,
+ mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
+ smemDesc = packValInto64Bits(
+ builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
+ smemDesc = packValInto64Bits(builder, smemDesc,
+ mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
+ // Bit ranges not packed above keep their initial zero. Associate the packed value with the op's result in `mt`.
+ mt.mapValue(thisOp.getRes()) = smemDesc;
+}
+
//===----------------------------------------------------------------------===//
// getIntrinsicID/getIntrinsicIDAndArgs methods
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir
new file mode 100644
index 0000000000000..5af79c6f1379b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+// Verifies the lowering of nvvm.tcgen05.mma_smem_desc: each operand is masked to its field width, widened to i64, shifted to its start bit and OR-ed in; the constant field folds to a single `or`.
+// CHECK-LABEL: define i64 @tcgen05_mma_smem_desc_test(i32 %0, i32 %1, i32 %2, i8 %3, i1 %4, i8 %5) {
+llvm.func @tcgen05_mma_smem_desc_test(%startAddr: i32, %leadingDimOffset: i32, %strideDimOffset: i32,
+ %baseOffset: i8, %leadingDimMode: i1, %swizzleMode: i8) -> i64 {
+ // CHECK-NEXT: %7 = and i32 %0, 16383
+ // CHECK-NEXT: %8 = zext i32 %7 to i64
+ // CHECK-NEXT: %9 = shl i64 %8, 0
+ // CHECK-NEXT: %10 = or i64 0, %9
+ // CHECK-NEXT: %11 = and i32 %1, 16383
+ // CHECK-NEXT: %12 = zext i32 %11 to i64
+ // CHECK-NEXT: %13 = shl i64 %12, 16
+ // CHECK-NEXT: %14 = or i64 %10, %13
+ // CHECK-NEXT: %15 = and i32 %2, 16383
+ // CHECK-NEXT: %16 = zext i32 %15 to i64
+ // CHECK-NEXT: %17 = shl i64 %16, 32
+ // CHECK-NEXT: %18 = or i64 %14, %17
+ // CHECK-NEXT: %19 = or i64 %18, 70368744177664
+ // CHECK-NEXT: %20 = zext i8 %3 to i32
+ // CHECK-NEXT: %21 = and i32 %20, 7
+ // CHECK-NEXT: %22 = zext i32 %21 to i64
+ // CHECK-NEXT: %23 = shl i64 %22, 49
+ // CHECK-NEXT: %24 = or i64 %19, %23
+ // CHECK-NEXT: %25 = zext i1 %4 to i32
+ // CHECK-NEXT: %26 = and i32 %25, 1
+ // CHECK-NEXT: %27 = zext i32 %26 to i64
+ // CHECK-NEXT: %28 = shl i64 %27, 52
+ // CHECK-NEXT: %29 = or i64 %24, %28
+ // CHECK-NEXT: %30 = zext i8 %5 to i32
+ // CHECK-NEXT: %31 = and i32 %30, 7
+ // CHECK-NEXT: %32 = zext i32 %31 to i64
+ // CHECK-NEXT: %33 = shl i64 %32, 61
+ // CHECK-NEXT: %34 = or i64 %29, %33
+ // CHECK-NEXT: ret i64 %34
+ // CHECK-NEXT: }
+ %desc = nvvm.tcgen05.mma_smem_desc (%startAddr, %leadingDimOffset, %strideDimOffset, %baseOffset, %leadingDimMode, %swizzleMode) : (i32, i32, i32, i8, i1, i8) -> i64
+ llvm.return %desc : i64
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Even though this is not part of the PTX spec, it's extremely useful to have an op for that!
This patch adds an Op to create the shared-memory descriptor for Tcgen05 MMA. Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
This patch adds an Op to create the shared-memory
descriptor for Tcgen05 MMA.