Skip to content

Commit ff88a68

Browse files
Hsiangkai
authored and sivan-shani committed
[mlir][gpu] Add GPU subgroup MMA extract and insert operations (llvm#139048)
- Introduced `gpu.subgroup_mma_extract` operation to extract values from `!gpu.mma_matrix` by invocation and indices. - Introduced `gpu.subgroup_mma_insert` operation to insert values into `!gpu.mma_matrix` by invocation and indices. - Updated the conversion patterns to SPIR-V for both extract and insert operations. - Added test cases to validate the new operations in the GPU to SPIR-V conversion. RFC: https://discourse.llvm.org/t/rfc-add-gpu-operations-to-permute-data-in-2-loaded-mma-matrix/86148?u=hsiangkai
1 parent d72e09b commit ff88a68

File tree

4 files changed

+193
-0
lines changed

4 files changed

+193
-0
lines changed

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 89 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1921,6 +1921,95 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
19211921
}];
19221922
}
19231923

1924+
def GPU_SubgroupMmaExtractThreadLocalOp : GPU_Op<"subgroup_mma_extract_thread_local",
1925+
[Pure,
1926+
TypesMatchWith<"value type matches element type of mma_matrix",
1927+
"matrix", "res",
1928+
"::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]>{
1929+
1930+
let summary = "Extract a value from GPU warp by invocation and indices";
1931+
1932+
let description = [{
1933+
The `gpu.subgroup_mma_extract_thread_local` operation extracts a value from `!gpu.mma_matrix`
1934+
that is stored at subgroup level.
1935+
1936+
This operation takes `!gpu.mma_matrix` as its first operand. It is the source
1937+
matrix across a subgroup. The op returns a scalar value stored in the invocation
1938+
in the subgroup.
1939+
1940+
Since `matrix` is packed into the threads within a subgroup, `indices` are
1941+
the indices into the values stored by each thread. That is, an index of 0 (or [0, 0])
1942+
does not necessarily refer to the first element of the matrix, but the first element
1943+
that a particular thread holds.
1944+
1945+
The mapping of matrix elements to threads is not defined by this operation and may
1946+
not be defined by some lowerings (such as the lowering to SPIR-V). However, if the
1947+
size of the subgroup is S, then `subgroup_mma_extract_thread_local` at each index in
1948+
`[0, (M * N) / S)` will have the entire matrix extracted across the subgroup.
1949+
1950+
Example:
1951+
1952+
```mlir
1953+
%c0 = arith.constant 0 : index
1954+
%val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
1955+
```
1956+
}];
1957+
1958+
let arguments = (ins GPU_MMAMatrix:$matrix, Variadic<Index>:$indices);
1959+
1960+
let results = (outs AnyIntegerOrFloat:$res);
1961+
1962+
let assemblyFormat = [{
1963+
$matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res)
1964+
}];
1965+
}
1966+
1967+
def GPU_SubgroupMmaInsertThreadLocalOp : GPU_Op<"subgroup_mma_insert_thread_local",
1968+
[Pure,
1969+
TypesMatchWith<"value type matches element type of mma_matrix",
1970+
"matrix", "value",
1971+
"::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()"> ]>{
1972+
1973+
let summary = "Insert a value into GPU warp by invocation and indices";
1974+
1975+
let description = [{
1976+
The `gpu.subgroup_mma_insert_thread_local` operation inserts a value to `!gpu.mma_matrix`
1977+
that is stored at subgroup level.
1978+
1979+
This operation takes scalar value as its first operand and `!gpu.mma_matrix`
1980+
as its second operand. The op inserts the scalar value to the matrix.
1981+
1982+
Since `matrix` is packed into the threads within a subgroup, `indices` are
1983+
the indices into the values stored by each thread. That is, an index of 0 (or [0, 0])
1984+
does not necessarily refer to the first element of the matrix, but the first element
1985+
that a particular thread holds.
1986+
1987+
The mapping of matrix elements to threads is not defined by this operation and may
1988+
not be defined by some lowerings (such as the lowering to SPIR-V). However, if the
1989+
size of the subgroup is S, then `subgroup_mma_insert_thread_local` at each index in
1990+
`[0, (M * N) / S)` will have the entire matrix inserted across the subgroup.
1991+
1992+
The op returns `!gpu.mma_matrix` with the updated value.
1993+
1994+
Example:
1995+
1996+
```mlir
1997+
%c0 = arith.constant 0 : index
1998+
%s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
1999+
-> !gpu.mma_matrix<16x16xf16, "COp">
2000+
```
2001+
}];
2002+
2003+
let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix,
2004+
Variadic<Index>:$indices);
2005+
2006+
let results = (outs GPU_MMAMatrix:$res);
2007+
2008+
let assemblyFormat = [{
2009+
$value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res)
2010+
}];
2011+
}
2012+
19242013
def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">;
19252014
def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">;
19262015
def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">;

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -111,6 +111,68 @@ struct WmmaConstantOpToSPIRVLowering final
111111
}
112112
};
113113

114+
/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
115+
/// matrix ops.
116+
struct WmmaExtractOpToSPIRVLowering final
117+
: OpConversionPattern<gpu::SubgroupMmaExtractThreadLocalOp> {
118+
using OpConversionPattern::OpConversionPattern;
119+
120+
LogicalResult
121+
matchAndRewrite(gpu::SubgroupMmaExtractThreadLocalOp op, OpAdaptor adaptor,
122+
ConversionPatternRewriter &rewriter) const override {
123+
Value matrix = adaptor.getMatrix();
124+
auto coopType =
125+
getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
126+
matrix.getType());
127+
if (!coopType)
128+
return rewriter.notifyMatchFailure(op, "type conversion failed");
129+
130+
SmallVector<int32_t> intValues;
131+
for (Value val : op.getIndices()) {
132+
if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
133+
intValues.push_back(static_cast<int32_t>(constOp.value()));
134+
} else {
135+
return rewriter.notifyMatchFailure(op, "indices must be constants");
136+
}
137+
}
138+
139+
Type elementType = coopType.getElementType();
140+
rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
141+
op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
142+
return success();
143+
}
144+
};
145+
146+
/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
147+
/// matrix ops.
148+
struct WmmaInsertOpToSPIRVLowering final
149+
: OpConversionPattern<gpu::SubgroupMmaInsertThreadLocalOp> {
150+
using OpConversionPattern::OpConversionPattern;
151+
152+
LogicalResult
153+
matchAndRewrite(gpu::SubgroupMmaInsertThreadLocalOp op, OpAdaptor adaptor,
154+
ConversionPatternRewriter &rewriter) const override {
155+
Value value = adaptor.getValue();
156+
Value matrix = adaptor.getMatrix();
157+
auto coopType = getTypeConverter()->convertType(matrix.getType());
158+
if (!coopType)
159+
return rewriter.notifyMatchFailure(op, "type conversion failed");
160+
161+
SmallVector<int32_t> intValues;
162+
for (Value val : op.getIndices()) {
163+
if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
164+
intValues.push_back(static_cast<int32_t>(constOp.value()));
165+
} else {
166+
return rewriter.notifyMatchFailure(op, "indices must be constants");
167+
}
168+
}
169+
170+
rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
171+
op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
172+
return success();
173+
}
174+
};
175+
114176
/// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
115177
/// the default case.
116178
struct WmmaElementwiseOpToSPIRVDefaultLowering final
@@ -296,6 +358,7 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
296358
MLIRContext *context = patterns.getContext();
297359
patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
298360
khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
361+
WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
299362
WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
300363
// Give the following patterns higher benefit to prevail over the default one.
301364
patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,

mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir

Lines changed: 27 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -93,6 +93,33 @@ module attributes {
9393
gpu.return
9494
}
9595

96+
// CHECK-LABEL: spirv.func @gpu_wmma_extract_thread_local_op
97+
// CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
98+
gpu.func @gpu_wmma_extract_thread_local_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
99+
%ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
100+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
101+
// CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
102+
%c0 = arith.constant 0 : index
103+
%val = gpu.subgroup_mma_extract_thread_local %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
104+
memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
105+
gpu.return
106+
}
107+
108+
// CHECK-LABEL: spirv.func @gpu_wmma_insert_thread_local_op
109+
// CHECK-SAME: %[[ARG0:.+]]: f16
110+
// CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
111+
gpu.func @gpu_wmma_insert_thread_local_op(%val: f16,
112+
%m: !gpu.mma_matrix<16x16xf16, "COp">,
113+
%ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
114+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
115+
// CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
116+
%c0 = arith.constant 0 : index
117+
%s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
118+
gpu.subgroup_mma_store_matrix %s0, %ptr[%c0,%c0] {leadDimension = 16 : index} :
119+
!gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
120+
gpu.return
121+
}
122+
96123
// CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
97124
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
98125
// CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>

mlir/test/Dialect/GPU/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -430,6 +430,20 @@ module attributes {gpu.container_module} {
430430
gpu.wait [%token16]
431431
return
432432
}
433+
434+
// CHECK-LABEL: func @extract_insert_mma
435+
func.func @extract_insert_mma(%src : !gpu.mma_matrix<16x16xf32, "COp">,
436+
%ptr: memref<16x16xf32>) {
437+
%zero = arith.constant 0.0 : f32
438+
%c0 = arith.constant 0 : index
439+
// CHECK: gpu.subgroup_mma_extract_thread_local
440+
%val = gpu.subgroup_mma_extract_thread_local %src[%c0] : !gpu.mma_matrix<16x16xf32, "COp"> -> f32
441+
%m = gpu.subgroup_mma_constant_matrix %zero : !gpu.mma_matrix<16x16xf32, "COp">
442+
// CHECK: gpu.subgroup_mma_insert_thread_local
443+
%s0 = gpu.subgroup_mma_insert_thread_local %val, %m[%c0] : f32, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "COp">
444+
gpu.subgroup_mma_store_matrix %s0, %ptr[%c0, %c0] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf32, "COp">, memref<16x16xf32>
445+
return
446+
}
433447
}
434448

435449
// Just check that this doesn't crash.

0 commit comments

Comments
 (0)