-
Notifications
You must be signed in to change notification settings - Fork 13.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Add MemRefElementTypeInterface to gpu.mma_matrix #132312
base: main
Are you sure you want to change the base?
[MLIR] Add MemRefElementTypeInterface to gpu.mma_matrix #132312
Conversation
@llvm/pr-subscribers-mlir-affine @llvm/pr-subscribers-mlir-gpu Author: Uday Bondhugula (bondhugula) ChangesAdd MemRefElementTypeInterface to gpu.mma_matrix and introduce an Full diff: https://github.com/llvm/llvm-project/pull/132312.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 7b53594a1c8e2..d9165006329df 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -128,7 +128,8 @@ struct MMAMatrixStorageType : public TypeStorage {
/// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
// TODO: consider moving this to ODS.
class MMAMatrixType
- : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
+ : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType,
+ MemRefElementTypeInterface::Trait> {
public:
using Base::Base;
@@ -163,6 +164,8 @@ class MMAMatrixType
/// Get elementType of a single element.
Type getElementType() const;
+ /// Implementation for MemRefElementTypeInterface.
+ unsigned getAnalysisSizeInBytes() const;
/// The general form of operation this type supports is given by the equation
/// C += A*B. This function returns which operand in the given equation is
/// held by this type. String returned can be one of"AOp", "BOp" and "COp".
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index bc377dcc72e48..ecd55ef0494d9 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -62,6 +62,10 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
}]>
];
let skipDefaultBuilders = 1;
+ let extraClassDeclaration = [{
+ /// Best effort size for analysis purposes.
+ unsigned getAnalysisSizeInBytes() { return 8; }
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..001d0d9f3e756 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -74,10 +74,20 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
For example, scalar values such as integers can implement this interface,
but indicator types such as `void` or `unit` should not.
- The interface currently has no methods and is used by types to opt into
- being memref elements. This may change in the future, in particular to
- require types to provide their size or alignment given a data layout.
+ The interface currently has one method and is mainly used by types to opt
+ into being memref elements. This may change in the future, in particular to
+ require types to provide actual size or alignment given a data layout.
}];
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns the size of the element type in bytes for purposes such as
+ analysis. Such a size is meant to be used in analysis costs models as a
+ best effort in the absence of data layout, as opposed to for
+ target-specific lowering which would require a data layout.
+ }],
+ "unsigned", "getAnalysisSizeInBytes">,
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 86aba7b187535..312eaedaa13c3 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -1341,6 +1341,9 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
else
return std::nullopt;
+ } else if (auto memrefEltType = dyn_cast<MemRefElementTypeInterface>(
+ memRefType.getElementType())) {
+ sizeInBits = memrefEltType.getAnalysisSizeInBytes() * 8;
} else {
return std::nullopt;
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 976432ea37120..04b8c901b50da 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -149,6 +149,13 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
elementType.isInteger(32);
}
+unsigned MMAMatrixType::getAnalysisSizeInBytes() const {
+ // The underlying element type is expected to always be int or float and
+ // typically divisible by 8 bits.
+ return ShapedType::getNumElements(getShape()) *
+ llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
+}
+
LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 4b9eca45492fb..0c2cc5503d7c2 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -666,3 +666,32 @@ func.func @unrolled(%arg0: memref<2x4xf32>, %arg1: memref<1x2x4xf32>) {
// PRODUCER-CONSUMER-MAXIMAL: affine.load %{{.*}}[0, %{{.*}}, %{{.*}}]
return
}
+
+// Test for fusion of affine load/store on memrefs of MMA type.
+
+// PRODUCER-CONSUMER-LABEL: func @gpu_mma_cast
+func.func @gpu_mma_cast(%a: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %b: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %c: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>) {
+ affine.for %i = 0 to 8 {
+ affine.for %j = 0 to 4 {
+ %v = affine.load %a[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+ affine.store %v, %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+ }
+ }
+
+ affine.for %i = 0 to 8 {
+ affine.for %j = 0 to 4 {
+ %v = affine.load %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+ affine.store %v, %c[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+
+ }
+ }
+ // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 8 {
+ // PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 4 {
+ // PRODUCER-CONSUMER-NEXT: affine.load
+ // PRODUCER-CONSUMER-NEXT: affine.store
+ // PRODUCER-CONSUMER-NEXT: affine.load
+ // PRODUCER-CONSUMER-NEXT: affine.store
+
+ return
+ // PRODUCER-CONSUMER: return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f1c31658c13ac..c3aac18917ba7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -169,6 +169,10 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
def TestMemRefElementType : Test_Type<"TestMemRefElementType",
[MemRefElementTypeInterface]> {
let mnemonic = "memref_element";
+
+ let extraClassDeclaration = [{
+ unsigned getAnalysisSizeInBytes() const { return 1; }
+ }];
}
def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;
|
@llvm/pr-subscribers-mlir Author: Uday Bondhugula (bondhugula) ChangesAdd MemRefElementTypeInterface to gpu.mma_matrix and introduce an Full diff: https://github.com/llvm/llvm-project/pull/132312.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
index 7b53594a1c8e2..d9165006329df 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDialect.h
@@ -128,7 +128,8 @@ struct MMAMatrixStorageType : public TypeStorage {
/// : index}: !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
// TODO: consider moving this to ODS.
class MMAMatrixType
- : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType> {
+ : public Type::TypeBase<MMAMatrixType, Type, MMAMatrixStorageType,
+ MemRefElementTypeInterface::Trait> {
public:
using Base::Base;
@@ -163,6 +164,8 @@ class MMAMatrixType
/// Get elementType of a single element.
Type getElementType() const;
+ /// Implementation for MemRefElementTypeInterface.
+ unsigned getAnalysisSizeInBytes() const;
/// The general form of operation this type supports is given by the equation
/// C += A*B. This function returns which operand in the given equation is
/// held by this type. String returned can be one of"AOp", "BOp" and "COp".
diff --git a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
index bc377dcc72e48..ecd55ef0494d9 100644
--- a/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
+++ b/mlir/include/mlir/Dialect/Ptr/IR/PtrDialect.td
@@ -62,6 +62,10 @@ def Ptr_PtrType : Ptr_Type<"Ptr", "ptr", [
}]>
];
let skipDefaultBuilders = 1;
+ let extraClassDeclaration = [{
+ /// Best effort size for analysis purposes.
+ unsigned getAnalysisSizeInBytes() { return 8; }
+ }];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
index 8aa2c55570153..001d0d9f3e756 100644
--- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
+++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td
@@ -74,10 +74,20 @@ def MemRefElementTypeInterface : TypeInterface<"MemRefElementTypeInterface"> {
For example, scalar values such as integers can implement this interface,
but indicator types such as `void` or `unit` should not.
- The interface currently has no methods and is used by types to opt into
- being memref elements. This may change in the future, in particular to
- require types to provide their size or alignment given a data layout.
+ The interface currently has one method and is mainly used by types to opt
+ into being memref elements. This may change in the future, in particular to
+ require types to provide actual size or alignment given a data layout.
}];
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns the size of the element type in bytes for purposes such as
+ analysis. Such a size is meant to be used in analysis costs models as a
+ best effort in the absence of data layout, as opposed to for
+ target-specific lowering which would require a data layout.
+ }],
+ "unsigned", "getAnalysisSizeInBytes">,
+ ];
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
index 86aba7b187535..312eaedaa13c3 100644
--- a/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
+++ b/mlir/lib/Dialect/Affine/Analysis/Utils.cpp
@@ -1341,6 +1341,9 @@ mlir::affine::getMemRefIntOrFloatEltSizeInBytes(MemRefType memRefType) {
vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
else
return std::nullopt;
+ } else if (auto memrefEltType = dyn_cast<MemRefElementTypeInterface>(
+ memRefType.getElementType())) {
+ sizeInBits = memrefEltType.getAnalysisSizeInBytes() * 8;
} else {
return std::nullopt;
}
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 976432ea37120..04b8c901b50da 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -149,6 +149,13 @@ bool MMAMatrixType::isValidElementType(Type elementType) {
elementType.isInteger(32);
}
+unsigned MMAMatrixType::getAnalysisSizeInBytes() const {
+ // The underlying element type is expected to always be int or float and
+ // typically divisible by 8 bits.
+ return ShapedType::getNumElements(getShape()) *
+ llvm::divideCeil(getElementType().getIntOrFloatBitWidth(), 8);
+}
+
LogicalResult
MMAMatrixType::verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
ArrayRef<int64_t> shape, Type elementType,
diff --git a/mlir/test/Dialect/Affine/loop-fusion-4.mlir b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
index 4b9eca45492fb..0c2cc5503d7c2 100644
--- a/mlir/test/Dialect/Affine/loop-fusion-4.mlir
+++ b/mlir/test/Dialect/Affine/loop-fusion-4.mlir
@@ -666,3 +666,32 @@ func.func @unrolled(%arg0: memref<2x4xf32>, %arg1: memref<1x2x4xf32>) {
// PRODUCER-CONSUMER-MAXIMAL: affine.load %{{.*}}[0, %{{.*}}, %{{.*}}]
return
}
+
+// Test for fusion of affine load/store on memrefs of MMA type.
+
+// PRODUCER-CONSUMER-LABEL: func @gpu_mma_cast
+func.func @gpu_mma_cast(%a: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %b: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>, %c: memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>) {
+ affine.for %i = 0 to 8 {
+ affine.for %j = 0 to 4 {
+ %v = affine.load %a[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+ affine.store %v, %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+ }
+ }
+
+ affine.for %i = 0 to 8 {
+ affine.for %j = 0 to 4 {
+ %v = affine.load %b[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+ affine.store %v, %c[%i, %j] : memref<8x4x!gpu.mma_matrix<16x16xf32, "AOp">, 3>
+
+ }
+ }
+ // PRODUCER-CONSUMER: affine.for %{{.*}} = 0 to 8 {
+ // PRODUCER-CONSUMER-NEXT: affine.for %{{.*}} = 0 to 4 {
+ // PRODUCER-CONSUMER-NEXT: affine.load
+ // PRODUCER-CONSUMER-NEXT: affine.store
+ // PRODUCER-CONSUMER-NEXT: affine.load
+ // PRODUCER-CONSUMER-NEXT: affine.store
+
+ return
+ // PRODUCER-CONSUMER: return
+}
diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
index f1c31658c13ac..c3aac18917ba7 100644
--- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td
+++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td
@@ -169,6 +169,10 @@ def TestTypeWithLayoutType : Test_Type<"TestTypeWithLayout", [
def TestMemRefElementType : Test_Type<"TestMemRefElementType",
[MemRefElementTypeInterface]> {
let mnemonic = "memref_element";
+
+ let extraClassDeclaration = [{
+ unsigned getAnalysisSizeInBytes() const { return 1; }
+ }];
}
def TestTypeTrait : NativeTypeTrait<"TestTypeTrait">;
|
19b2cf7
to
0837c1d
Compare
Add MemRefElementTypeInterface to gpu.mma_matrix and introduce an interface method that would allow analyses and cost models to work with it. This enables creation of memrefs of mma_matrix type, which in turn enables seamless fusion in the presence affine load/stores on such mma memrefs or forwarding of stores to loads out of the box.
0837c1d
to
f89c332
Compare
Add MemRefElementTypeInterface to gpu.mma_matrix and introduce an
interface method that would allow analyses and cost models to work with
it. This enables creation of memrefs of mma_matrix type, which in turn
enables seamless fusion in the presence affine load/stores on such mma memrefs
or forwarding of stores to loads out of the box.