
Commit 613f72a

[rfc][mlir][gpu] Add operations to extract/insert/rotate within subgroup
Add gpu.rotate, gpu.subgroup_mma_extract, and gpu.subgroup_mma_insert operations.
1 parent 316a6ff commit 613f72a

File tree: 5 files changed (+257 −1 lines)


mlir/include/mlir/Dialect/GPU/IR/GPUOps.td

Lines changed: 102 additions & 0 deletions
````diff
@@ -1364,6 +1364,35 @@ def GPU_ShuffleOp : GPU_Op<
   ];
 }
 
+def GPU_RotateOp : GPU_Op<
+    "rotate", [Pure, AllTypesMatch<["value", "rotateResult"]>]>,
+    Arguments<(ins AnyIntegerOrFloatOr1DVector:$value, I32:$offset, I32:$width)>,
+    Results<(outs AnyIntegerOrFloatOr1DVector:$rotateResult)> {
+  let summary = "Rotate values within a subgroup.";
+  let description = [{
+    The "rotate" op moves values across lanes circularly (lanes are also
+    known as invocations or work items) within the same subgroup. The
+    `width` argument specifies the number of lanes that participate in the
+    rotation and must be uniform across all lanes. Further, the first
+    `width` lanes of the subgroup must be active.
+
+    Example:
+
+    ```mlir
+    %cst1 = arith.constant 1 : i32
+    %width = arith.constant 16 : i32
+    %1 = gpu.rotate %0, %cst1, %width : f32
+    ```
+
+    For lane `k`, returns the value from lane `(k + 1) % width`. For
+    example, lane 0 receives the value from lane 1, and lane 15 receives
+    the value from lane 0.
+  }];
+
+  let assemblyFormat = [{
+    $value `,` $offset `,` $width attr-dict `:` type($value)
+  }];
+}
+
 def GPU_BarrierOp : GPU_Op<"barrier"> {
   let summary = "Synchronizes all work items of a workgroup.";
   let description = [{
````
````diff
@@ -1919,6 +1948,79 @@ def GPU_SubgroupMmaConstantMatrixOp : GPU_Op<"subgroup_mma_constant_matrix",
   }];
 }
 
+def GPU_SubgroupMmaExtractOp : GPU_Op<"subgroup_mma_extract",
+    [Pure,
+     TypesMatchWith<"value type matches element type of mma_matrix",
+                    "matrix", "res",
+                    "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]> {
+
+  let summary = "Extract a value from a GPU warp by invocation and indices";
+
+  let description = [{
+    The `gpu.subgroup_mma_extract` operation extracts a value from the
+    `!gpu.mma_matrix` held by an invocation in a subgroup.
+
+    This operation takes `!gpu.mma_matrix` as its first operand; it is the
+    source matrix across a subgroup. The op returns the scalar value stored
+    in the invocation in the subgroup. If multiple values are packed in an
+    invocation, use `indices` to specify which element to extract.
+
+    Example:
+
+    ```mlir
+    %c0 = arith.constant 0 : index
+    %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+    ```
+  }];
+
+  let arguments = (ins GPU_MMAMatrix:$matrix, Variadic<Index>:$indices);
+
+  let results = (outs AnyIntegerOrFloat:$res);
+
+  let assemblyFormat = [{
+    $matrix`[`$indices`]` attr-dict `:` type($matrix) `->` type($res)
+  }];
+}
+
+def GPU_SubgroupMmaInsertOp : GPU_Op<"subgroup_mma_insert",
+    [Pure,
+     TypesMatchWith<"value type matches element type of mma_matrix",
+                    "matrix", "value",
+                    "::llvm::cast<gpu::MMAMatrixType>($_self).getElementType()">]> {
+
+  let summary = "Insert a value into a GPU warp by invocation and indices";
+
+  let description = [{
+    The `gpu.subgroup_mma_insert` operation inserts a value into the
+    `!gpu.mma_matrix` held by an invocation in a subgroup.
+
+    This operation takes a scalar value as its first operand and
+    `!gpu.mma_matrix` as its second operand; the latter is the matrix across
+    a subgroup. The op inserts the scalar value stored in the invocation in
+    the subgroup into the matrix. If multiple values are packed in an
+    invocation, use `indices` to specify the location in the packing at
+    which to insert.
+
+    The op returns `!gpu.mma_matrix` with the updated value.
+
+    Example:
+
+    ```mlir
+    %c0 = arith.constant 0 : index
+    %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp">
+          -> !gpu.mma_matrix<16x16xf16, "COp">
+    ```
+  }];
+
+  let arguments = (ins AnyIntegerOrFloat:$value, GPU_MMAMatrix:$matrix,
+                       Variadic<Index>:$indices);
+
+  let results = (outs GPU_MMAMatrix:$res);
+
+  let assemblyFormat = [{
+    $value`,` $matrix`[`$indices`]` attr-dict `:` type($value)`,` type($matrix) `->` type($res)
+  }];
+}
+
 def GPU_ElementwiseOpAddF : I32EnumAttrCase<"ADDF", 0, "addf">;
 def GPU_ElementwiseOpMulF : I32EnumAttrCase<"MULF", 1, "mulf">;
 def GPU_ElementwiseOpSUBF : I32EnumAttrCase<"SUBF", 2, "subf">;
````
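
For a concrete picture of the new rotate op in use, here is a minimal kernel sketch (illustrative only, not part of the commit; the module and function names are made up):

```mlir
// Sketch only: assumes a 16-lane subgroup with the first 16 lanes active.
gpu.module @rotate_example {
  gpu.func @rotate_by_one(%v: f32) kernel {
    %offset = arith.constant 1 : i32
    %width = arith.constant 16 : i32
    // Lane k receives the %v that lane (k + 1) % 16 supplied.
    %r = gpu.rotate %v, %offset, %width : f32
    gpu.return
  }
}
```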

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 40 additions & 1 deletion
```diff
@@ -122,6 +122,16 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+/// Pattern to convert a gpu.rotate op into a spirv.GroupNonUniformRotateKHR op.
+class GPURotateConversion final : public OpConversionPattern<gpu::RotateOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::RotateOp rotateOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
 public:
   using OpConversionPattern::OpConversionPattern;
@@ -458,6 +468,35 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Rotate
+//===----------------------------------------------------------------------===//
+
+LogicalResult GPURotateConversion::matchAndRewrite(
+    gpu::RotateOp rotateOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  // Require the rotate width to be the same as the target's subgroup size,
+  // given that for SPIR-V non-uniform subgroup ops, we cannot select
+  // participating invocations.
+  auto targetEnv = getTypeConverter<SPIRVTypeConverter>()->getTargetEnv();
+  unsigned subgroupSize =
+      targetEnv.getAttr().getResourceLimits().getSubgroupSize();
+  IntegerAttr widthAttr;
+  if (!matchPattern(rotateOp.getWidth(), m_Constant(&widthAttr)) ||
+      widthAttr.getValue().getZExtValue() != subgroupSize)
+    return rewriter.notifyMatchFailure(
+        rotateOp, "rotate width and target subgroup size mismatch");
+
+  Location loc = rotateOp.getLoc();
+  auto scope = rewriter.getAttr<spirv::ScopeAttr>(spirv::Scope::Subgroup);
+
+  Value result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
+      loc, scope, adaptor.getValue(), adaptor.getOffset(), rotateOp.getWidth());
+
+  rewriter.replaceOp(rotateOp, result);
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // Group ops
 //===----------------------------------------------------------------------===//
@@ -733,7 +772,7 @@ void mlir::populateGPUToSPIRVPatterns(const SPIRVTypeConverter &typeConverter,
                                       RewritePatternSet &patterns) {
   patterns.add<
       GPUBarrierConversion, GPUFuncOpConversion, GPUModuleConversion,
-      GPUReturnOpConversion, GPUShuffleConversion,
+      GPUReturnOpConversion, GPUShuffleConversion, GPURotateConversion,
       LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
       LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
       LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
```
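
To see what the pattern emits, a hedged before/after sketch (SSA names are illustrative; the lowered line mirrors the CHECK line in the new test below, with its trailing type annotation elided):

```mlir
// Before conversion, assuming the target env reports subgroup_size = 16:
%offset = arith.constant 8 : i32
%width = arith.constant 16 : i32
%r = gpu.rotate %val, %offset, %width : f32

// After -convert-gpu-to-spirv (sketch):
//   spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %offset, cluster_size(%width)
```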

mlir/lib/Conversion/GPUToSPIRV/WmmaOpsToSPIRV.cpp

Lines changed: 63 additions & 0 deletions
```diff
@@ -111,6 +111,68 @@ struct WmmaConstantOpToSPIRVLowering final
   }
 };
 
+/// Converts GPU MMA ExtractOp to CompositeExtract SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaExtractOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaExtractOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaExtractOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value matrix = adaptor.getMatrix();
+    auto coopType =
+        getTypeConverter()->convertType<spirv::CooperativeMatrixType>(
+            matrix.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    SmallVector<int32_t> intValues;
+    for (Value val : op.getIndices()) {
+      if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+        intValues.push_back(static_cast<int32_t>(constOp.value()));
+      } else {
+        return rewriter.notifyMatchFailure(op, "Indices must be constants");
+      }
+    }
+
+    Type elementType = coopType.getElementType();
+    rewriter.replaceOpWithNewOp<spirv::CompositeExtractOp>(
+        op, elementType, matrix, rewriter.getI32ArrayAttr(intValues));
+    return success();
+  }
+};
+
+/// Converts GPU MMA InsertOp to CompositeInsert SPIR-V KHR/NV cooperative
+/// matrix ops.
+struct WmmaInsertOpToSPIRVLowering final
+    : OpConversionPattern<gpu::SubgroupMmaInsertOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::SubgroupMmaInsertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value value = adaptor.getValue();
+    Value matrix = adaptor.getMatrix();
+    auto coopType = getTypeConverter()->convertType(matrix.getType());
+    if (!coopType)
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+
+    SmallVector<int32_t> intValues;
+    for (Value val : op.getIndices()) {
+      if (auto constOp = val.getDefiningOp<arith::ConstantIndexOp>()) {
+        intValues.push_back(static_cast<int32_t>(constOp.value()));
+      } else {
+        return rewriter.notifyMatchFailure(op, "Indices must be constants");
+      }
+    }
+
+    rewriter.replaceOpWithNewOp<spirv::CompositeInsertOp>(
+        op, coopType, value, matrix, rewriter.getI32ArrayAttr(intValues));
+    return success();
+  }
+};
+
 /// Converts elementwise ops to SPIR-V cooperative matrix elementwise ops for
 /// the default case.
 struct WmmaElementwiseOpToSPIRVDefaultLowering final
@@ -296,6 +358,7 @@ void mlir::populateGpuWMMAToSPIRVCoopMatrixKHRConversionPatterns(
   MLIRContext *context = patterns.getContext();
   patterns.add<khr::WmmaLoadOpToSPIRVLowering, khr::WmmaMmaOpToSPIRVLowering,
                khr::WmmaStoreOpToSPIRVLowering, WmmaConstantOpToSPIRVLowering,
+               WmmaExtractOpToSPIRVLowering, WmmaInsertOpToSPIRVLowering,
               WmmaElementwiseOpToSPIRVDefaultLowering>(converter, context);
   // Give the following patterns higher benefit to prevail over the default one.
   patterns.add<WmmaElementwiseOpToSPIRVScalarMulLowering>(converter, context,
```
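
Taken together, the two patterns reduce the new ops to plain composite ops on the converted cooperative-matrix values. A condensed sketch, assuming `%m` and `%acc` are suitably typed matrix values (names illustrative; the lowered lines mirror the CHECK lines in the test file further down):

```mlir
%c0 = arith.constant 0 : index
// Indices must fold to arith.constant ops, or both patterns bail out.
%v = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
%u = gpu.subgroup_mma_insert %v, %acc[%c0] : f32, !gpu.mma_matrix<16x16xf32, "COp"> -> !gpu.mma_matrix<16x16xf32, "COp">

// After conversion (sketch):
//   %v = spirv.CompositeExtract %m[0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
//   %u = spirv.CompositeInsert %v, %acc[0 : i32] : f32 into !spirv.coopmatrix<16x16xf32, Subgroup, MatrixAcc>
```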
Lines changed: 25 additions & 0 deletions
```diff
@@ -0,0 +1,25 @@
+// RUN: mlir-opt -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformRotateKHR], []>, #spirv.resource_limits<subgroup_size = 16>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL: spirv.func @rotate()
+  gpu.func @rotate() kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [4, 4, 1]>} {
+    // CHECK: %[[CST8_I32:.*]] = spirv.Constant 8 : i32
+    // CHECK: %[[CST16_I32:.*]] = spirv.Constant 16 : i32
+    // CHECK: %[[CST_F32:.*]] = spirv.Constant 4.200000e+01 : f32
+    %offset = arith.constant 8 : i32
+    %width = arith.constant 16 : i32
+    %val = arith.constant 42.0 : f32
+
+    // CHECK: spirv.GroupNonUniformRotateKHR <Subgroup>, %[[CST_F32]], %[[CST8_I32]], cluster_size(%[[CST16_I32]])
+    %result = gpu.rotate %val, %offset, %width : f32
+    gpu.return
+  }
+}
+
+}
```
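
A negative case the test does not exercise, sketched here as a hypothetical: with the same `subgroup_size = 16` target, a rotate whose width does not match is rejected by the conversion rather than lowered.

```mlir
// Hypothetical: not converted; the pattern reports
// "rotate width and target subgroup size mismatch".
%offset = arith.constant 1 : i32
%width = arith.constant 8 : i32
%r = gpu.rotate %val, %offset, %width : f32
```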

mlir/test/Conversion/GPUToSPIRV/wmma-ops-to-spirv-khr-coop-matrix.mlir

Lines changed: 27 additions & 0 deletions
```diff
@@ -93,6 +93,33 @@ module attributes {
     gpu.return
   }
 
+  // CHECK-LABEL: spirv.func @gpu_wmma_extract_op
+  //  CHECK-SAME: %[[ARG0:.+]]: !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+  gpu.func @gpu_wmma_extract_op(%m: !gpu.mma_matrix<16x16xf32, "AOp">,
+                                %ptr: memref<16x16xf32, #spirv.storage_class<StorageBuffer>>) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+    // CHECK: spirv.CompositeExtract %[[ARG0]][0 : i32] : !spirv.coopmatrix<16x16xf32, Subgroup, MatrixA>
+    %c0 = arith.constant 0 : index
+    %val = gpu.subgroup_mma_extract %m[%c0] : !gpu.mma_matrix<16x16xf32, "AOp"> -> f32
+    memref.store %val, %ptr[%c0, %c0] : memref<16x16xf32, #spirv.storage_class<StorageBuffer>>
+    gpu.return
+  }
+
+  // CHECK-LABEL: spirv.func @gpu_wmma_insert_op
+  //  CHECK-SAME: %[[ARG0:.+]]: f16
+  //  CHECK-SAME: %[[ARG1:.+]]: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+  gpu.func @gpu_wmma_insert_op(%val: f16,
+                               %m: !gpu.mma_matrix<16x16xf16, "COp">,
+                               %ptr: memref<16x16xf16, #spirv.storage_class<StorageBuffer>>) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [32, 4, 1]>} {
+    // CHECK: spirv.CompositeInsert %[[ARG0]], %[[ARG1]][0 : i32] : f16 into !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
+    %c0 = arith.constant 0 : index
+    %s0 = gpu.subgroup_mma_insert %val, %m[%c0] : f16, !gpu.mma_matrix<16x16xf16, "COp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+    gpu.subgroup_mma_store_matrix %s0, %ptr[%c0, %c0] {leadDimension = 16 : index} :
+      !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16, #spirv.storage_class<StorageBuffer>>
+    gpu.return
+  }
+
   // CHECK-LABEL: spirv.func @gpu_wmma_elementwise_op_default
   //  CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
   //  CHECK-SAME: !spirv.coopmatrix<16x16xf16, Subgroup, MatrixAcc>
```
