@@ -50,26 +50,7 @@ static bool contractSupportsMMAMatrixType(vector::ContractionOp contract) {
   if (contract.getIndexingMaps() != infer({{m, k}, {k, n}, {m, n}}))
     return false;
 
-  // Check that the size matches what is natively supported.
-  VectorType lhsType = contract.lhs().getType().cast<VectorType>();
-  VectorType rhsType = contract.rhs().getType().cast<VectorType>();
-  VectorType accType = contract.acc().getType().cast<VectorType>();
-
-  std::tuple<int, int, int> dim(lhsType.getDimSize(0), rhsType.getDimSize(1),
-                                lhsType.getDimSize(1));
-  if (lhsType.getElementType().isInteger(8) &&
-      rhsType.getElementType().isInteger(8) &&
-      accType.getElementType().isInteger(32) &&
-      (dim == std::make_tuple(8, 8, 32) || dim == std::make_tuple(16, 16, 32) ||
-       dim == std::make_tuple(16, 8, 32)))
-    return true;
-
-  if (lhsType.getElementType().isF16() && rhsType.getElementType().isF16() &&
-      (accType.getElementType().isF16() || accType.getElementType().isF32()) &&
-      (dim == std::make_tuple(8, 8, 16) || dim == std::make_tuple(16, 16, 16) ||
-       dim == std::make_tuple(16, 8, 16)))
-    return true;
-  return false;
+  return true;
 }
 
 // Return the stide for the dimension 0 of |type| if it is a memref and has a
@@ -95,8 +76,15 @@ static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
     return false;
   if (!getMemrefConstantHorizontalStride(readOp.getShapedType()))
     return false;
+  AffineMap map = readOp.permutation_map();
+  OpBuilder b(readOp.getContext());
+  AffineExpr innerDim = b.getAffineDimExpr(map.getNumDims() - 1);
+  AffineExpr zero = b.getAffineConstantExpr(0);
+  auto broadcastInnerDim = AffineMap::get(map.getNumDims(), 0, {zero, innerDim},
+                                          readOp.getContext());
   // TODO: Support transpose once it is added to GPU dialect ops.
-  if (!readOp.permutation_map().isMinorIdentity())
+  // For now we only support (d0, d1) -> (d0, d1) and (d0, d1) -> (0, d1).
+  if (!map.isMinorIdentity() && map != broadcastInnerDim)
     return false;
   return true;
 }
@@ -142,6 +130,8 @@ convertElementwiseOpToMMA(Operation *op) {
     return gpu::MMAElementwiseOp::MAXF;
   if (isa<MinFOp>(op))
     return gpu::MMAElementwiseOp::MINF;
+  if (isa<arith::DivFOp>(op))
+    return gpu::MMAElementwiseOp::DIVF;
   return llvm::None;
 }
 
@@ -166,6 +156,44 @@ static bool supportsMMaMatrixType(Operation *op) {
   return elementwiseSupportsMMAMatrixType(op);
 }
 
+/// Return an unsorted slice handling scf.for region differently than
+/// `getSlice`. In scf.for we only want to include as part of the slice elements
+/// that are part of the use/def chain.
+static SetVector<Operation *> getSliceContract(Operation *op,
+                                               TransitiveFilter backwardFilter,
+                                               TransitiveFilter forwardFilter) {
+  SetVector<Operation *> slice;
+  slice.insert(op);
+  unsigned currentIndex = 0;
+  SetVector<Operation *> backwardSlice;
+  SetVector<Operation *> forwardSlice;
+  while (currentIndex != slice.size()) {
+    auto *currentOp = (slice)[currentIndex];
+    // Compute and insert the backwardSlice starting from currentOp.
+    backwardSlice.clear();
+    getBackwardSlice(currentOp, &backwardSlice, backwardFilter);
+    slice.insert(backwardSlice.begin(), backwardSlice.end());
+
+    // Compute and insert the forwardSlice starting from currentOp.
+    forwardSlice.clear();
+    // Special case for ForOp, we don't want to include the whole region but
+    // only the value using the region arguments.
+    // TODO: We should refine this to only care about the region arguments being
+    // converted to matrix type.
+    if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
+      for (Value forOpResult : forOp.getResults())
+        getForwardSlice(forOpResult, &forwardSlice, forwardFilter);
+      for (BlockArgument &arg : forOp.getRegionIterArgs())
+        getForwardSlice(arg, &forwardSlice, forwardFilter);
+    } else {
+      getForwardSlice(currentOp, &forwardSlice, forwardFilter);
+    }
+    slice.insert(forwardSlice.begin(), forwardSlice.end());
+    ++currentIndex;
+  }
+  return slice;
+}
+
 // Analyze slice of operations based on convert op to figure out if the whole
 // slice can be converted to MMA operations.
 static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
@@ -182,16 +210,17 @@ static SetVector<Operation *> getOpToConvert(mlir::Operation *op) {
     if (opToConvert.contains(contract.getOperation()))
       return;
     SetVector<Operation *> dependentOps =
-        getSlice(contract, hasVectorDest, hasVectorSrc);
+        getSliceContract(contract, hasVectorDest, hasVectorSrc);
     // If any instruction cannot use MMA matrix type drop the whole
-    // chaine. MMA matrix are stored in an opaque type so they cannot be used
+    // chain. MMA matrix are stored in an opaque type so they cannot be used
     // by all operations.
     if (llvm::any_of(dependentOps,
                      [](Operation *op) { return !supportsMMaMatrixType(op); }))
       return;
     opToConvert.insert(dependentOps.begin(), dependentOps.end());
   });
-  return opToConvert;
+  // Sort the operations so that we can convert them in topological order.
+  return topologicalSort(opToConvert);
 }
 
 namespace {
@@ -309,6 +338,12 @@ static void convertTransferReadOp(vector::TransferReadOp op,
   assert(transferReadSupportsMMAMatrixType(op));
   Optional<int64_t> stride =
       getMemrefConstantHorizontalStride(op.getShapedType());
+  AffineMap map = op.permutation_map();
+  // Handle broadcast by setting the stride to 0.
+  if (map.getResult(0).isa<AffineConstantExpr>()) {
+    assert(map.getResult(0).cast<AffineConstantExpr>().getValue() == 0);
+    stride = 0;
+  }
   assert(stride);
   const char *fragType = inferFragType(op);
   gpu::MMAMatrixType type =