
Commit 9b814e2

[mlir][gpu] Add GPU subgroup MMA extract and insert operations
- Introduced the `gpu.subgroup_mma_extract` operation to extract values from `!gpu.mma_matrix` by invocation and indices.
- Introduced the `gpu.subgroup_mma_insert` operation to insert values into `!gpu.mma_matrix` by invocation and indices.
- Updated the GPU-to-SPIR-V conversion patterns to lower both the extract and insert operations.
- Added test cases to validate the new operations in the GPU to SPIR-V conversion.
1 parent c27e10f commit 9b814e2

3 files changed: 163 additions, 0 deletions

mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 73 additions & 0 deletions
@@ -1919,6 +1919,79 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
   }];
 }
 
+def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
+    [Pure,
+     TypesMatchWith<"value type matches element type of mma_matrix",
+                    "matrix", "res",
+                    "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]>{
+
+  let summary = "Extract a value from GPU warp by invocation and indices";
+
+  let description = [{
+    The `gpu.subgroup_mma_extract` operation extracts a value from `!gpu.mma_matrix`
+    by the invocation in a subgroup.
+
+    This operation takes `!gpu.mma_matrix` as its first operand. It is the source
+    matrix across a subgroup. The op returns a scalar value stored in the invocation
+    in the subgroup. If there are multiple values packed in an invocation, use
+    `indices` to specify the element to extract.
+
+    Example:
+
+    ```mlir
+    %c0 = arith.constant 0 : index
+    %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+    ```
+  }];
+
+  let arguments = (ins GPU_MMAMatrix:$matrix, Variadic<Index>:$indices);
+
+  let results = (outs AnyIntegerOrFloat:$res);
+
+  let assemblyFormat = [{
+    $matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res)
+  }];
+}
+
+def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
+    [Pure,
+     TypesMatchWith<"value type matches element type of mma_matrix",
+                    "matrix", "value",
+                    "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]>{
+
+  let summary = "Insert a value into GPU warp by invocation and indices";
+
+  let description = [{
+    The `gpu.subgroup_mma_insert` operation inserts a value into `!gpu.mma_matrix`
+    by the invocation in a subgroup.
+
+    This operation takes a scalar value as its first operand and `!gpu.mma_matrix`
+    as its second operand, which is the matrix across a subgroup. The op inserts
+    the scalar value stored in the invocation in the subgroup into the matrix. If
+    there are multiple values packed in an invocation, use `indices` to specify
+    the location to insert in the packing.
+
+    The op returns `!gpu.mma_matrix` with the updated value.
+
+    Example:
+
+    ```mlir
+    %c0 = arith.constant 0 : index
+    %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
+            -> !gpu.mma_matrix<16x16xf16, "COp">
+    ```
+  }];
+
+  let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix,
+                       Variadic<Index>:$indices);
+
+  let results = (outs GPU_MMAMatrix:$res);
+
+  let assemblyFormat = [{
+    $value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res)
+  }];
+}
+
 def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">;
 def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">;
 def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">;
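For orientation, here is a small usage sketch (not part of this commit) that chains the two new ops with the existing `gpu.subgroup_mma_load_matrix` and `gpu.subgroup_mma_store_matrix` ops: each invocation extracts one of the elements it holds, scales it, and writes it back into the fragment. The memref shape, the scale constant, and the choice of a `"COp"` fragment are illustrative assumptions; only the op syntax comes from the definitions above.

```mlir
// Hypothetical fragment: scale the first packed element of an accumulator
// fragment in place. Shapes and the scale factor are assumed for illustration.
%c0  = arith.constant 0 : index
%cst = arith.constant 2.0 : f16
%acc = gpu.subgroup_mma_load_matrix %ptr[%c0, %c0] {leadDimension = 16 : index}
         : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
// Per-invocation scalar round trip through the new extract/insert ops.
%elem   = gpu.subgroup_mma_extract %acc[%c0] : !gpu.mma_matrix<16x16xf16, "COp"> -> f16
%scaled = arith.mulf %elem, %cst : f16
%upd    = gpu.subgroup_mma_insert %scaled, %acc[%c0]
            : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
gpu.subgroup_mma_store_matrix %upd, %ptr[%c0, %c0] {leadDimension = 16 : index}
  : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
```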

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 63 additions & 0 deletions
@@ -111,6 +111,68 @@ struct WmmaConstantOpToSPIRVLowering final
   }
 };
 
+/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaExtractOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value matrix = adaptor.getMatrix();
+    auto coopType =
+        getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
+            matrix.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    SmallVector<int32_t> intValues;
+    for (Value val : op.getIndices()) {
+      if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+        intValues.push_back(static_cast<int32_t>(constOp.value()));
+      } else {
+        return rewriter.notifyMatchFailure(op, "indices must be constants");
+      }
+    }
+
+    Type elementType = coopType.getElementType();
+    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+        op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
+    return success();
+  }
+};
+
+/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaInsertOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaInsertOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value value = adaptor.getValue();
+    Value matrix = adaptor.getMatrix();
+    auto coopType = getTypeConverter()->convertType(matrix.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    SmallVector<int32_t> intValues;
+    for (Value val : op.getIndices()) {
+      if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+        intValues.push_back(static_cast<int32_t>(constOp.value()));
+      } else {
+        return rewriter.notifyMatchFailure(op, "indices must be constants");
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+        op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
+    return success();
+  }
+};
+
 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
 /// the default case.
 struct WmmaElementwiseOpToSPIRVDefaultLowering final

@@ -296,6 +358,7 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
   MLIRContext *context = patterns.getContext();
   patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
                khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
                WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
   // Give the following patterns higher benefit to prevail over the default one.
   patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
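To make the effect of these patterns concrete, the sketch below shows roughly what the lowering produces; it mirrors the CHECK lines in the test added below, with `%m`, `%acc`, `%s`, `%coop`, and `%cacc` as placeholder names. Note that both patterns only match when every index is defined by an `arith.constant`, because the indices are folded into the `i32` array attribute of the SPIR-V composite ops.

```mlir
// Before (gpu dialect): extract and insert with a constant index.
%c0  = arith.constant 0 : index
%v   = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
%upd = gpu.subgroup_mma_insert %s, %acc[%c0]
         : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">

// After (spirv dialect), assuming %coop and %cacc are the converted matrix values:
%v2   = spirv.CompositeExtract %coop[0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
%upd2 = spirv.CompositeInsert %s, %cacc[0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
```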

mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir

Lines changed: 27 additions & 0 deletions
@@ -93,6 +93,33 @@ module attributes {
     gpu.return
   }
 
+  // CHECK-LABEL: spirv.func @gpu_wmma_extract_op
+  //  CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+  gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
+                                %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+    // CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+    %c0 = arith.constant 0 : index
+    %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+    memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: spirv.func @gpu_wmma_insert_op
+  //  CHECK-SAME: %[[ARG0:.+]]: f16
+  //  CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+  gpu.func @gpu_wmma_insert_op(%val: f16,
+                               %m: !gpu.mma_matrix<16x16xf16, "COp">,
+                               %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+    // CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    %c0 = arith.constant 0 : index
+    %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+    gpu.subgroup_mma_store_matrix %s0, %ptr[%c0,%c0] {leadDimension = 16 : index} :
+      !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+    gpu.return
+  }
+
   // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
   //  CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
   //  CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
