[MLIR][NVVM] Add Op to create tcgen05-mma smem descriptor #141651

Merged

durga4github
Contributor

This patch adds an Op to create the shared-memory
descriptor for Tcgen05 MMA.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
@durga4github durga4github requested a review from grypp as a code owner May 27, 2025 18:23
@durga4github durga4github requested review from grypp and removed request for grypp May 27, 2025 18:23
@llvmbot
Member

llvmbot commented May 27, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Durgadoss R (durga4github)

Changes

This patch adds an Op to create the shared-memory
descriptor for Tcgen05 MMA.


Full diff: https://github.com/llvm/llvm-project/pull/141651.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+64)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+44)
  • (added) mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir (+38)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 13f693872d890..408537be0a5e4 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3373,6 +3373,70 @@ def NVVM_Tcgen05CpOp : NVVM_Op<"tcgen05.cp", [NVVMRequiresSMa<[100, 101]>]> {
   }];
 }
 
+def NVVM_Tcgen05MmaSmemDescOp : NVVM_Op<"tcgen05.mma_smem_desc", []> {
+  let summary = "Constructs a Shared Memory descriptor for MMA Operands A or B";
+  let description = [{
+    The `nvvm.tcgen05.mma_smem_desc` Op constructs a Shared Memory descriptor
+    for tcgen05.mma. This descriptor is a 64-bit value which describes the
+    properties of the multiplicand matrix in shared memory, including its
+    location in the shared memory of the current CTA.
+
+    +-----------+------+------------------------------------------------------+
+    | Bit-field | Size | Description                                          |
+    +-----------+------+------------------------------------------------------+
+    | 0-13      | 14   | Matrix start address                                 |
+    | 14-15     | 2    | Reserved                                             |
+    | 16-29     | 14   | Leading dim relative-offset (or) absolute-address    |
+    | 30-31     | 2    | Reserved                                             |
+    | 32-45     | 14   | Stride dimension byte offset                         |
+    | 46-48     | 3    | Fixed constant value of 0b001                        |
+    | 49-51     | 3    | Matrix base offset                                   |
+    | 52        | 1    | Leading dimension stride mode:                       |
+    |           |      |   0: byte offset relative                            |
+    |           |      |   1: byte address absolute                           |
+    | 53-60     | 8    | Fixed constant value of 0b00000000                   |
+    | 61-63     | 3    | Swizzling mode:                                      |
+    |           |      |   0: No swizzling                                    |
+    |           |      |   1: 128-Byte with 32B atomic swizzling              |
+    |           |      |   2: 128-Byte swizzling                              |
+    |           |      |   4: 64-Byte swizzling                               |
+    |           |      |   6: 32-Byte swizzling                               |
+    |           |      |   (Values 3, 5 and 7 are invalid)                    |
+    +-----------+------+------------------------------------------------------+    
+
+    Example:
+    ```mlir
+      %desc = nvvm.tcgen05.mma_smem_desc (%startAddr, %leadingDimOffset, %strideDimOffset,
+                                          %baseOffset, %leadingDimMode, %swizzleMode) : (i32, i32, i32, i8, i1, i8) -> i64
+    ```
+    [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-shared-memory-descriptor)
+  }];
+
+  let arguments = (ins
+      I32:$startAddr,        // Matrix A or B start address (bits 13-0)
+      I32:$leadingDimOffset, // Matrix A or B leading dim byte offset (bits 29-16)
+      I32:$strideDimOffset,  // Matrix A or B stride  dim byte offset (bits 45-32)
+      I8:$baseOffset,        // Matrix A or B base offset (bits 51-49)
+      I1:$leadingDimMode,    // Matrix A or B leading dim mode (bit 52)
+      I8:$swizzleMode        // Swizzle mode (bits 63-61)
+  );
+
+  let results = (outs I64:$res);
+
+  let assemblyFormat = [{
+    `(` operands `)` attr-dict `:` `(` type(operands) `)` `->` type($res)
+  }];
+
+  let extraClassDeclaration = [{
+    static void createSmemDescriptor(Operation &op, LLVM::ModuleTranslation &mt,
+                                     llvm::IRBuilderBase& builder);
+  }];
+
+  string llvmBuilder = [{
+    NVVM::Tcgen05MmaSmemDescOp::createSmemDescriptor(*op, moduleTranslation, builder);
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // NVVM tcgen05 LdSt Shape Attr
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 79d9d2f6255e7..8036ea27f524f 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1212,6 +1212,50 @@ NVVM::DotAccumulate4WayOp::getPackedArg(llvm::Value *arg,
                                llvm::Type::getInt32Ty(builder.getContext()));
 }
 
+/// Packs the given `field` into the `result`.
+/// The `result` is 64-bits and each `field` can be 32-bits or narrower.
+static llvm::Value *
+packValInto64Bits(llvm::IRBuilderBase &builder,
+                  llvm::Value *result, // the `result` (unset bits are zero)
+                  llvm::Value *field,  // `field` to pack into `result`
+                  unsigned sizeInBits, // Size of `field` in bits
+                  unsigned start) {    // Starting bit within `result`
+  field = builder.CreateZExtOrBitCast(field, builder.getInt32Ty());
+
+  unsigned mask = (sizeInBits < 32 ? ((1u << sizeInBits) - 1) : 0xffffffffu);
+  if (mask != 0xffffffffu)
+    field = builder.CreateAnd(field, builder.getInt32(mask));
+
+  field = builder.CreateZExtOrBitCast(field, builder.getInt64Ty());
+  field = builder.CreateShl(field, start);
+
+  return builder.CreateOr(result, field);
+}
+
+void Tcgen05MmaSmemDescOp::createSmemDescriptor(Operation &op,
+                                                LLVM::ModuleTranslation &mt,
+                                                llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::Tcgen05MmaSmemDescOp>(op);
+  llvm::Value *smemDesc = builder.getInt64(0);
+
+  smemDesc = packValInto64Bits(builder, smemDesc,
+                               mt.lookupValue(thisOp.getStartAddr()), 14, 0);
+  smemDesc = packValInto64Bits(
+      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimOffset()), 14, 16);
+  smemDesc = packValInto64Bits(
+      builder, smemDesc, mt.lookupValue(thisOp.getStrideDimOffset()), 14, 32);
+
+  smemDesc = packValInto64Bits(builder, smemDesc, builder.getInt32(1), 3, 46);
+  smemDesc = packValInto64Bits(builder, smemDesc,
+                               mt.lookupValue(thisOp.getBaseOffset()), 3, 49);
+  smemDesc = packValInto64Bits(
+      builder, smemDesc, mt.lookupValue(thisOp.getLeadingDimMode()), 1, 52);
+  smemDesc = packValInto64Bits(builder, smemDesc,
+                               mt.lookupValue(thisOp.getSwizzleMode()), 3, 61);
+
+  mt.mapValue(thisOp.getRes()) = smemDesc;
+}
+
 //===----------------------------------------------------------------------===//
 // getIntrinsicID/getIntrinsicIDAndArgs methods
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir b/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir
new file mode 100644
index 0000000000000..5af79c6f1379b
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/nvvm/tcgen05-smem-desc.mlir
@@ -0,0 +1,38 @@
+// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
+
+// CHECK-LABEL: define i64 @tcgen05_mma_smem_desc_test(i32 %0, i32 %1, i32 %2, i8 %3, i1 %4, i8 %5) {
+llvm.func @tcgen05_mma_smem_desc_test(%startAddr: i32, %leadingDimOffset: i32, %strideDimOffset: i32,
+                                      %baseOffset: i8, %leadingDimMode: i1, %swizzleMode: i8) -> i64 {
+  // CHECK-NEXT: %7 = and i32 %0, 16383
+  // CHECK-NEXT: %8 = zext i32 %7 to i64
+  // CHECK-NEXT: %9 = shl i64 %8, 0
+  // CHECK-NEXT: %10 = or i64 0, %9
+  // CHECK-NEXT: %11 = and i32 %1, 16383
+  // CHECK-NEXT: %12 = zext i32 %11 to i64
+  // CHECK-NEXT: %13 = shl i64 %12, 16
+  // CHECK-NEXT: %14 = or i64 %10, %13
+  // CHECK-NEXT: %15 = and i32 %2, 16383
+  // CHECK-NEXT: %16 = zext i32 %15 to i64
+  // CHECK-NEXT: %17 = shl i64 %16, 32
+  // CHECK-NEXT: %18 = or i64 %14, %17
+  // CHECK-NEXT: %19 = or i64 %18, 70368744177664
+  // CHECK-NEXT: %20 = zext i8 %3 to i32
+  // CHECK-NEXT: %21 = and i32 %20, 7
+  // CHECK-NEXT: %22 = zext i32 %21 to i64
+  // CHECK-NEXT: %23 = shl i64 %22, 49
+  // CHECK-NEXT: %24 = or i64 %19, %23
+  // CHECK-NEXT: %25 = zext i1 %4 to i32
+  // CHECK-NEXT: %26 = and i32 %25, 1
+  // CHECK-NEXT: %27 = zext i32 %26 to i64
+  // CHECK-NEXT: %28 = shl i64 %27, 52
+  // CHECK-NEXT: %29 = or i64 %24, %28
+  // CHECK-NEXT: %30 = zext i8 %5 to i32
+  // CHECK-NEXT: %31 = and i32 %30, 7
+  // CHECK-NEXT: %32 = zext i32 %31 to i64
+  // CHECK-NEXT: %33 = shl i64 %32, 61
+  // CHECK-NEXT: %34 = or i64 %29, %33
+  // CHECK-NEXT: ret i64 %34
+  // CHECK-NEXT: }
+  %desc = nvvm.tcgen05.mma_smem_desc (%startAddr, %leadingDimOffset, %strideDimOffset, %baseOffset, %leadingDimMode, %swizzleMode) : (i32, i32, i32, i8, i1, i8) -> i64
+  llvm.return %desc : i64
+}
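
For readers who want to check the bit layout without building MLIR, here is a minimal host-side sketch of the same packing. It is not part of the patch: `packField` and `makeSmemDesc` are hypothetical names that mirror what `packValInto64Bits` and `createSmemDescriptor` emit as LLVM IR, but operate on plain integers instead of `llvm::Value`s.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side mirror of packValInto64Bits from the diff:
// masks `field` to `sizeInBits` bits and ORs it into `result` at bit `start`.
constexpr std::uint64_t packField(std::uint64_t result, std::uint32_t field,
                                  unsigned sizeInBits, unsigned start) {
  const std::uint32_t mask =
      sizeInBits < 32 ? ((1u << sizeInBits) - 1u) : 0xffffffffu;
  return result | (static_cast<std::uint64_t>(field & mask) << start);
}

// Hypothetical helper that follows the field order used by
// createSmemDescriptor; the parameter types match the Op's operand types.
constexpr std::uint64_t makeSmemDesc(std::uint32_t startAddr,
                                     std::uint32_t leadingDimOffset,
                                     std::uint32_t strideDimOffset,
                                     std::uint8_t baseOffset,
                                     bool leadingDimMode,
                                     std::uint8_t swizzleMode) {
  std::uint64_t desc = 0;
  desc = packField(desc, startAddr, 14, 0);          // bits 0-13
  desc = packField(desc, leadingDimOffset, 14, 16);  // bits 16-29
  desc = packField(desc, strideDimOffset, 14, 32);   // bits 32-45
  desc = packField(desc, 1u, 3, 46);                 // fixed 0b001, bits 46-48
  desc = packField(desc, baseOffset, 3, 49);         // bits 49-51
  desc = packField(desc, leadingDimMode ? 1u : 0u, 1, 52);  // bit 52
  desc = packField(desc, swizzleMode, 3, 61);        // bits 61-63
  return desc;
}

// With every operand zero, only the fixed 0b001 field at bit 46 is set.
static_assert(makeSmemDesc(0, 0, 0, 0, false, 0) == (1ull << 46),
              "fixed constant lands at bit 46");
```

The all-zero case yields `1 << 46`, which is the constant `70368744177664` that appears in the `%19 = or i64 %18, ...` CHECK line of the test. Bits 53-60 are never written, so they keep the fixed value of zero from the initial descriptor.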

Member

@grypp grypp left a comment

Nice! Even though this is not part of the PTX spec, it's extremely useful to have an op for that!
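
One detail worth keeping in mind when using the op: the description table lists swizzle values 3, 5 and 7 as invalid, but as far as this diff shows the op has no verifier and the lowering only masks `swizzleMode` to three bits. A caller may therefore want to check a descriptor before using it. The following is a rough sketch of such a check, assuming the bit layout from the table; `extractField`, `isValidSwizzleMode` and `isWellFormedDesc` are hypothetical helpers, not part of the patch or of MLIR.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical helper: returns `sizeInBits` bits of `desc` starting at bit
// `start`. Assumes sizeInBits < 64.
constexpr std::uint64_t extractField(std::uint64_t desc, unsigned sizeInBits,
                                     unsigned start) {
  return (desc >> start) & ((std::uint64_t{1} << sizeInBits) - 1);
}

// True only for the swizzle encodings the description table lists as valid
// (0, 1, 2, 4 and 6); 3, 5, 7 and anything wider than 3 bits is rejected.
constexpr bool isValidSwizzleMode(std::uint8_t mode) {
  return mode == 0 || mode == 1 || mode == 2 || mode == 4 || mode == 6;
}

// Hypothetical sanity check on a packed 64-bit descriptor: the two
// fixed-constant fields must hold their documented values and the swizzle
// field must hold a valid encoding. Reserved bits are not inspected.
constexpr bool isWellFormedDesc(std::uint64_t desc) {
  return extractField(desc, 3, 46) == 1 &&  // fixed 0b001
         extractField(desc, 8, 53) == 0 &&  // fixed 0b00000000
         isValidSwizzleMode(
             static_cast<std::uint8_t>(extractField(desc, 3, 61)));
}
```

A check like this could equally live in an op verifier when `swizzleMode` is a compile-time constant; whether that is desirable is a design question for the dialect and is not something this PR addresses.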

@durga4github durga4github merged commit a615975 into llvm:main May 28, 2025
14 checks passed
@durga4github durga4github deleted the durgadossr/mlir_tcgen05_smem_desc branch May 28, 2025 08:13
sivan-shani pushed a commit to sivan-shani/llvm-project that referenced this pull request Jun 3, 2025
This patch adds an Op to create the shared-memory
descriptor for Tcgen05 MMA.

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>