
Commit e796924

[mlir][VectorToGPU] Support more cases in conversion to MMA ops
Support load with broadcast, elementwise divf op, and remove the hardcoded restriction on the vector size. Picking the right size should be enforced by the user; unsupported sizes will fail conversion to llvm/spirv.

Differential Revision: https://reviews.llvm.org/D113618
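As a rough illustration, the vector-level IR below sketches the kind of input the conversion now accepts; it is condensed from the new test case added in this commit, with placeholder value and buffer names:

  // Read with an outer-dimension broadcast: permutation map (d0, d1, d2, d3) -> (0, d3).
  %e = vector.transfer_read %buf[%c0, %c0, %c0, %c0], %pad
    {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d3)>}
    : memref<16x16x16x16xf16>, vector<16x16xf16>
  // Elementwise divf fused after the contraction result %d.
  %f = arith.divf %d, %e : vector<16x16xf16>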
Parent: 04cbfa9

4 files changed: +88 additions, -25 deletions


mlir/include/mlir/Dialect/GPU/GPUOps.td

Lines changed: 2 additions & 1 deletion
@@ -1130,11 +1130,12 @@ def GPU_ELEMENTWISE_OP_ADD : StrEnumAttrCase<"ADDF">;
 def GPU_ELEMENTWISE_OP_MUL : StrEnumAttrCase<"MULF">;
 def GPU_ELEMENTWISE_OP_MAXF : StrEnumAttrCase<"MAXF">;
 def GPU_ELEMENTWISE_OP_MINF : StrEnumAttrCase<"MINF">;
+def GPU_ELEMENTWISE_OP_DIVF : StrEnumAttrCase<"DIVF">;
 
 def MMAElementWiseAttr : StrEnumAttr<"MMAElementwiseOp",
   "elementwise operation to apply to mma matrix",
   [GPU_ELEMENTWISE_OP_ADD, GPU_ELEMENTWISE_OP_MUL,
-   GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF]> {
+   GPU_ELEMENTWISE_OP_MAXF, GPU_ELEMENTWISE_OP_MINF, GPU_ELEMENTWISE_OP_DIVF]> {
   let cppNamespace = "::mlir::gpu";
   let storageType = "::mlir::StringAttr";
   let returnType = "::mlir::gpu::MMAElementwiseOp";

mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp

Lines changed: 2 additions & 0 deletions
@@ -304,6 +304,8 @@ static Value createScalarOp(OpBuilder &builder, Location loc,
     return builder.create<LLVM::FAddOp>(loc, operands[0].getType(), operands);
   case gpu::MMAElementwiseOp::MULF:
     return builder.create<LLVM::FMulOp>(loc, operands[0].getType(), operands);
+  case gpu::MMAElementwiseOp::DIVF:
+    return builder.create<LLVM::FDivOp>(loc, operands[0].getType(), operands);
   case gpu::MMAElementwiseOp::MAXF:
     return createMinMaxF(builder, loc, operands[0], operands[1],
                          /*isMin=*/false);
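With the DIVF enum case and the FDivOp lowering above, a GPU dialect elementwise op such as the following can now be lowered to NVVM. This is a sketch using the 16x16 f16 shapes from the new test; %d and %e are placeholder matrix values:

  %f = gpu.subgroup_mma_elementwise %d, %e {operation = "DIVF"}
    : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
    -> !gpu.mma_matrix<16x16xf16, "COp">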

mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp

Lines changed: 59 additions & 24 deletions
@@ -50,26 +50,7 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
   if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
     return false;
 
-  // Check that the size matches what is natively supported.
-  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
-  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
-  VectorType accType = contract.acc().getType().cast<VectorType>();
-
-  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
-                                lhsType.getDimSize(1));
-  if (lhsType.getElementType().isInteger(8) &&
-      rhsType.getElementType().isInteger(8) &&
-      accType.getElementType().isInteger(32) &&
-      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
-       dim == std::make_tuple(16, 8, 32)))
-    return true;
-
-  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
-      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
-      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
-       dim == std::make_tuple(16, 8, 16)))
-    return true;
-  return false;
+  return true;
 }
 
 // Return the stide for the dimension 0 of |type| if it is a memref and has a
@@ -95,8 +76,15 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
     return false;
   if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
     return false;
+  AffineMap map = readOp.permutation_map();
+  OpBuilder b(readOp.getContext());
+  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
+  AffineExpr zero = b.getAffineConstantExpr(0);
+  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
+                                          readOp.getContext());
   // TODO: Support transpose once it is added to GPU dialect ops.
-  if (!readOp.permutation_map().isMinorIdentity())
+  // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
+  if (!map.isMinorIdentity() && map != broadcastInnerDim)
     return false;
   return true;
 }
@@ -142,6 +130,8 @@ convertElementwiseOpToMMA(Operation *op) {
     return gpu::MMAElementwiseOp::MAXF;
   if (isa<MinFOp>(op))
     return gpu::MMAElementwiseOp::MINF;
+  if (isa<arith::DivFOp>(op))
+    return gpu::MMAElementwiseOp::DIVF;
   return llvm::None;
 }
 
@@ -166,6 +156,44 @@ static bool supportsMMaMatrixType(Operation *op) {
   return elementwiseSupportsMMAMatrixType(op);
 }
 
+/// Return an unsorted slice handling scf.for region differently than
+/// `getSlice`. In scf.for we only want to include as part of the slice elements
+/// that are part of the use/def chain.
+static SetVector<Operation *> getSliceContract(Operation *op,
+                                               TransitiveFilter backwardFilter,
+                                               TransitiveFilter forwardFilter) {
+  SetVector<Operation *> slice;
+  slice.insert(op);
+  unsigned currentIndex = 0;
+  SetVector<Operation *> backwardSlice;
+  SetVector<Operation *> forwardSlice;
+  while (currentIndex != slice.size()) {
+    auto *currentOp = (slice)[currentIndex];
+    // Compute and insert the backwardSlice starting from currentOp.
+    backwardSlice.clear();
+    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
+    slice.insert(backwardSlice.begin(), backwardSlice.end());
+
+    // Compute and insert the forwardSlice starting from currentOp.
+    forwardSlice.clear();
+    // Special case for ForOp, we don't want to include the whole region but
+    // only the value using the region arguments.
+    // TODO: We should refine this to only care about the region arguments being
+    // converted to matrix type.
+    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
+      for (Value forOpResult : forOp.getResults())
+        getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
+      for (BlockArgument &arg : forOp.getRegionIterArgs())
+        getForwardSlice(arg, &forwardSlice, forwardFilter);
+    } else {
+      getForwardSlice(currentOp, &forwardSlice, forwardFilter);
+    }
+    slice.insert(forwardSlice.begin(), forwardSlice.end());
+    ++currentIndex;
+  }
+  return slice;
+}
+
 // Analyze slice of operations based on convert op to figure out if the whole
 // slice can be converted to MMA operations.
 static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
@@ -182,16 +210,17 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
     if (opToConvert.contains(contract.getOperation()))
       return;
     SetVector<Operation *> dependentOps =
-        getSlice(contract, hasVectorDest, hasVectorSrc);
+        getSliceContract(contract, hasVectorDest, hasVectorSrc);
     // If any instruction cannot use MMA matrix type drop the whole
-    // chaine. MMA matrix are stored in an opaque type so they cannot be used
+    // chain. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
     if (llvm::any_of(dependentOps,
                      [](Operation *op) { return !supportsMMaMatrixType(op); }))
       return;
     opToConvert.insert(dependentOps.begin(), dependentOps.end());
   });
-  return opToConvert;
+  // Sort the operations so that we can convert them in topological order.
+  return topologicalSort(opToConvert);
 }
 
 namespace {
@@ -309,6 +338,12 @@ static void convertTransferReadOp(vector::TransferReadOp op,
   assert(transferReadSupportsMMAMatrixType(op));
   Optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
+  AffineMap map = op.permutation_map();
+  // Handle broadcast by setting the stride to 0.
+  if (map.getResult(0).isa<AffineConstantExpr>()) {
+    assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
+    stride = 0;
+  }
   assert(stride);
   const char *fragType = inferFragType(op);
   gpu::MMAMatrixType type =
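Taken together, the read-path changes above mean a transfer_read whose permutation map broadcasts the outer dimension (a constant-zero result) is accepted and emitted as a matrix load with a zero stride. A sketch of the before/after pair, based on the new test case (value names are placeholders):

  // Vector dialect input: the outer dimension is broadcast.
  %e = vector.transfer_read %src[%c0, %c0, %c0, %c0], %pad
    {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3) -> (0, d3)>}
    : memref<16x16x16x16xf16>, vector<16x16xf16>

  // GPU dialect output: the broadcast becomes leadDimension = 0.
  %e_mma = gpu.subgroup_mma_load_matrix %src[%c0, %c0, %c0, %c0] {leadDimension = 0 : index}
    : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">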

mlir/test/Conversion/VectorToGPU/vector-to-mma-ops.mlir

Lines changed: 25 additions & 0 deletions
@@ -106,3 +106,28 @@ func @matmul_fused_elementwise(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16
   vector.transfer_write %E, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
   return
 }
+
+// CHECK-LABEL: func @matmul_fused_broadcast
+//   CHECK-DAG: %[[CST_0:.+]] = arith.constant 0.000000e+00 : f16
+//   CHECK-DAG: %[[A:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
+//   CHECK-DAG: %[[B:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : memref<16x16xf16> -> !gpu.mma_matrix<16x16xf16, "BOp">
+//   CHECK-DAG: %[[C0:.+]] = gpu.subgroup_mma_constant_matrix %[[CST_0]] : !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK: %[[D:.+]] = gpu.subgroup_mma_compute %[[A]], %[[B]], %[[C0]] : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp"> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK: %[[E:.+]] = gpu.subgroup_mma_load_matrix %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] {leadDimension = 0 : index} : memref<16x16x16x16xf16> -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK: %[[F:.+]] = gpu.subgroup_mma_elementwise %[[D]], %[[E]] {operation = "DIVF"} : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">) -> !gpu.mma_matrix<16x16xf16, "COp">
+//       CHECK: gpu.subgroup_mma_store_matrix %[[F]], %{{.*}}[%{{.*}}, %{{.*}}] {leadDimension = 16 : index} : !gpu.mma_matrix<16x16xf16, "COp">, memref<16x16xf16>
+func @matmul_fused_broadcast(%arg0: memref<16x16xf16>, %arg1: memref<16x16xf16>,
+                             %arg2: memref<16x16xf16>, %arg3: memref<16x16x16x16xf16>) {
+  %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16xf16>
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f16
+  %A = vector.transfer_read %arg0[%c0, %c0], %cst {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %B = vector.transfer_read %arg1[%c0, %c0], %cst {permutation_map = #map0, in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
+  %D = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %A, %B, %cst_0 : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
+  %E = vector.transfer_read %arg3[%c0, %c0, %c0, %c0], %cst
+    {in_bounds = [true, true], permutation_map = affine_map<(d0, d1, d2, d3)->(0, d3)>}
+    : memref<16x16x16x16xf16>, vector<16x16xf16>
+  %F = arith.divf %D, %E : vector<16x16xf16>
+  vector.transfer_write %F, %arg2[%c0, %c0] {in_bounds = [true, true]} : vector<16x16xf16>, memref<16x16xf16>
+  return
+}
