Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Add MemRefElementTypeInterface to gpu.mma_matrix #132312

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

bondhugula
Copy link
Contributor

Add MemRefElementTypeInterface to gpu.mma_matrix and introduce an
interface method that would allow analyses and cost models to work with
it. This enables creation of memrefs of mma_matrix type, which in turn
enables seamless fusion in the presence affine load/stores on such mma memrefs
or forwarding of stores to loads out of the box.

@bondhugula bondhugula marked this pull request as draft March 21, 2025 00:03
@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-mlir-affine
@llvm/pr-subscribers-mlir-ods
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir-gpu

Author: Uday Bondhugula (bondhugula)

Changes

Add MemRefElementTypeInterface to gpu.mma_matrix and introduce an
interface method that would allow analyses and cost models to work with
it. This enables creation of memrefs of mma_matrix type, which in turn
enables seamless fusion in the presence affine load/stores on such mma memrefs
or forwarding of stores to loads out of the box.


Full diff: https://github.com/llvm/llvm-project/pull/132312.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h (+4-1)
  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td (+4)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+13-3)
  • (modified) mlir/lib/Dialect/Affine/Analysis/Utils.cpp (+3)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+7)
  • (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+29)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+4)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 7b53594a1c8e2..d9165006329df 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -128,7 +128,8 @@ struct MMAMatrixStorageType : public TypeStorage {
 ///           : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
 // TODO: consider moving this to ODS.
 class MMAMatrixType
-    : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
+    : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType,
+                            MemRefElementTypeInterface::Trait> {
 public:
   using Base::Base;
 
@@ -163,6 +164,8 @@ class MMAMatrixType
   /// Get elementType of a single element.
   Type getElementType() const;
 
+  /// Implementation for MemRefElementTypeInterface.
+  unsigned getAnalysisSizeInBytes() const;
   /// The general form of operation this type supports is given by the equation
   /// C += A*B. This function returns which operand in the given equation is
   /// held by this type. String returned can be one of"AOp", "BOp" and "COp".
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index bc377dcc72e48..ecd55ef0494d9 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -62,6 +62,10 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
     }]>
   ];
   let skipDefaultBuilders = 1;
+  let extraClassDeclaration = [{
+    /// Best effort size for analysis purposes.
+    unsigned getAnalysisSizeInBytes() { return 8; }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..001d0d9f3e756 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -74,10 +74,20 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
     For example, scalar values such as integers can implement this interface,
     but indicator types such as `void` or `unit` should not.
 
-    The interface currently has no methods and is used by types to opt into
-    being memref elements. This may change in the future, in particular to
-    require types to provide their size or alignment given a data layout.
+    The interface currently has one method and is mainly used by types to opt
+    into being memref elements. This may change in the future, in particular to
+    require types to provide actual size or alignment given a data layout.
   }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Returns the size of the element type in bytes for purposes such as
+      analysis. Such a size is meant to be used in analysis costs models as a
+      best effort in the absence of data layout, as opposed to for
+      target-specific lowering which would require a data layout.
+    }],
+    "unsigned", "getAnalysisSizeInBytes">,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 86aba7b187535..312eaedaa13c3 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -1341,6 +1341,9 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
           vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
     else
       return std::nullopt;
+  } else if (auto memrefEltType = dyn_cast<MemRefElementTypeInterface>(
+                 memRefType.getElementType())) {
+    sizeInBits = memrefEltType.getAnalysisSizeInBytes() * 8;
   } else {
     return std::nullopt;
   }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 976432ea37120..04b8c901b50da 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -149,6 +149,13 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
          elementType.isInteger(32);
 }
 
+unsigned MMAMatrixType::getAnalysisSizeInBytes() const {
+  // The underlying element type is expected to always be int or float and
+  // typically divisible by 8 bits.
+  return ShapedType::getNumElements(getShape()) *
+         llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
+}
+
 LogicalResult
 MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 4b9eca45492fb..0c2cc5503d7c2 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -666,3 +666,32 @@ func.func @unrolled(%arg0: memref<2x4xf32>, %arg1: memref<1x2x4xf32>) {
   // PRODUCER-CONSUMER-MAXIMAL:          affine.load %{{.*}}[0, %{{.*}}, %{{.*}}]
   return
 }
+
+// Test for fusion of affine load/store on memrefs of MMA type.
+
+// PRODUCER-CONSUMER-LABEL: func @gpu_mma_cast
+func.func @gpu_mma_cast(%a: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %b: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %c: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>) {
+  affine.for %i = 0 to 8 {
+    affine.for %j = 0 to 4 {
+      %v = affine.load %a[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+      affine.store %v, %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+    }
+  }
+
+  affine.for %i = 0 to 8 {
+    affine.for %j = 0 to 4 {
+      %v = affine.load %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+      affine.store %v, %c[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+
+    }
+  }
+  // PRODUCER-CONSUMER:      affine.for %{{.*}} = 0 to 8 {
+  // PRODUCER-CONSUMER-NEXT:   affine.for %{{.*}} = 0 to 4 {
+  // PRODUCER-CONSUMER-NEXT:     affine.load
+  // PRODUCER-CONSUMER-NEXT:     affine.store
+  // PRODUCER-CONSUMER-NEXT:     affine.load
+  // PRODUCER-CONSUMER-NEXT:     affine.store
+
+  return
+  // PRODUCER-CONSUMER: return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f1c31658c13ac..c3aac18917ba7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -169,6 +169,10 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
 def TestMemRefElementType : Test_Type<"TestMemRefElementType",
                                       [MemRefElementTypeInterface]> {
   let mnemonic = "memref_element";
+
+  let extraClassDeclaration = [{
+    unsigned getAnalysisSizeInBytes() const { return 1; }
+  }];
 }
 
 def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-mlir

Author: Uday Bondhugula (bondhugula)

Changes

Add MemRefElementTypeInterface to gpu.mma_matrix and introduce an
interface method that would allow analyses and cost models to work with
it. This enables creation of memrefs of mma_matrix type, which in turn
enables seamless fusion in the presence affine load/stores on such mma memrefs
or forwarding of stores to loads out of the box.


Full diff: https://github.com/llvm/llvm-project/pull/132312.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h (+4-1)
  • (modified) mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td (+4)
  • (modified) mlir/include/mlir/IR/BuiltinTypeInterfaces.td (+13-3)
  • (modified) mlir/lib/Dialect/Affine/Analysis/Utils.cpp (+3)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+7)
  • (modified) mlir/test/Dialect/Affine/loop-fusion-4.mlir (+29)
  • (modified) mlir/test/lib/Dialect/Test/TestTypeDefs.td (+4)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 7b53594a1c8e2..d9165006329df 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -128,7 +128,8 @@ struct MMAMatrixStorageType : public TypeStorage {
 ///           : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
 // TODO: consider moving this to ODS.
 class MMAMatrixType
-    : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
+    : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType,
+                            MemRefElementTypeInterface::Trait> {
 public:
   using Base::Base;
 
@@ -163,6 +164,8 @@ class MMAMatrixType
   /// Get elementType of a single element.
   Type getElementType() const;
 
+  /// Implementation for MemRefElementTypeInterface.
+  unsigned getAnalysisSizeInBytes() const;
   /// The general form of operation this type supports is given by the equation
   /// C += A*B. This function returns which operand in the given equation is
   /// held by this type. String returned can be one of"AOp", "BOp" and "COp".
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index bc377dcc72e48..ecd55ef0494d9 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -62,6 +62,10 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
     }]>
   ];
   let skipDefaultBuilders = 1;
+  let extraClassDeclaration = [{
+    /// Best effort size for analysis purposes.
+    unsigned getAnalysisSizeInBytes() { return 8; }
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..001d0d9f3e756 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -74,10 +74,20 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
     For example, scalar values such as integers can implement this interface,
     but indicator types such as `void` or `unit` should not.
 
-    The interface currently has no methods and is used by types to opt into
-    being memref elements. This may change in the future, in particular to
-    require types to provide their size or alignment given a data layout.
+    The interface currently has one method and is mainly used by types to opt
+    into being memref elements. This may change in the future, in particular to
+    require types to provide actual size or alignment given a data layout.
   }];
+
+  let methods = [
+    InterfaceMethod<[{
+      Returns the size of the element type in bytes for purposes such as
+      analysis. Such a size is meant to be used in analysis costs models as a
+      best effort in the absence of data layout, as opposed to for
+      target-specific lowering which would require a data layout.
+    }],
+    "unsigned", "getAnalysisSizeInBytes">,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 86aba7b187535..312eaedaa13c3 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -1341,6 +1341,9 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
           vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
     else
       return std::nullopt;
+  } else if (auto memrefEltType = dyn_cast<MemRefElementTypeInterface>(
+                 memRefType.getElementType())) {
+    sizeInBits = memrefEltType.getAnalysisSizeInBytes() * 8;
   } else {
     return std::nullopt;
   }
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 976432ea37120..04b8c901b50da 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -149,6 +149,13 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
          elementType.isInteger(32);
 }
 
+unsigned MMAMatrixType::getAnalysisSizeInBytes() const {
+  // The underlying element type is expected to always be int or float and
+  // typically divisible by 8 bits.
+  return ShapedType::getNumElements(getShape()) *
+         llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
+}
+
 LogicalResult
 MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                                 ArrayRef<int64_t> shape, Type elementType,
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 4b9eca45492fb..0c2cc5503d7c2 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -666,3 +666,32 @@ func.func @unrolled(%arg0: memref<2x4xf32>, %arg1: memref<1x2x4xf32>) {
   // PRODUCER-CONSUMER-MAXIMAL:          affine.load %{{.*}}[0, %{{.*}}, %{{.*}}]
   return
 }
+
+// Test for fusion of affine load/store on memrefs of MMA type.
+
+// PRODUCER-CONSUMER-LABEL: func @gpu_mma_cast
+func.func @gpu_mma_cast(%a: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %b: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %c: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>) {
+  affine.for %i = 0 to 8 {
+    affine.for %j = 0 to 4 {
+      %v = affine.load %a[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+      affine.store %v, %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+    }
+  }
+
+  affine.for %i = 0 to 8 {
+    affine.for %j = 0 to 4 {
+      %v = affine.load %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+      affine.store %v, %c[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+
+    }
+  }
+  // PRODUCER-CONSUMER:      affine.for %{{.*}} = 0 to 8 {
+  // PRODUCER-CONSUMER-NEXT:   affine.for %{{.*}} = 0 to 4 {
+  // PRODUCER-CONSUMER-NEXT:     affine.load
+  // PRODUCER-CONSUMER-NEXT:     affine.store
+  // PRODUCER-CONSUMER-NEXT:     affine.load
+  // PRODUCER-CONSUMER-NEXT:     affine.store
+
+  return
+  // PRODUCER-CONSUMER: return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f1c31658c13ac..c3aac18917ba7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -169,6 +169,10 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
 def TestMemRefElementType : Test_Type<"TestMemRefElementType",
                                       [MemRefElementTypeInterface]> {
   let mnemonic = "memref_element";
+
+  let extraClassDeclaration = [{
+    unsigned getAnalysisSizeInBytes() const { return 1; }
+  }];
 }
 
 def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;

@bondhugula bondhugula force-pushed the uday/memref_elt_interface_mma_matrix branch 2 times, most recently from 19b2cf7 to 0837c1d Compare March 21, 2025 00:14
Add MemRefElementTypeInterface to gpu.mma_matrix and introduce an
interface method that would allow analyses and cost models to work with
it. This enables creation of memrefs of mma_matrix type, which in turn
enables seamless fusion in the presence affine load/stores on such mma memrefs
or forwarding of stores to loads out of the box.
@bondhugula bondhugula force-pushed the uday/memref_elt_interface_mma_matrix branch from 0837c1d to f89c332 Compare March 21, 2025 06:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants