Skip to content

Commit 06c02d5

Browse files
committed
[mlir][linalg] Fix tiling interface implementation offset calculation
The tiling interface implementation was making assumptions about the code generated by makeTiledShape, which were wrong. The ExtractSliceOp created may be combined with other ExtractSliceOps. To solve that, we compute the offsets directly using the new utilities. Differential Revision: https://reviews.llvm.org/D132182
1 parent 89167e3 commit 06c02d5

File tree

4 files changed

+60
-44
lines changed

4 files changed

+60
-44
lines changed

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 15 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -177,18 +177,6 @@ mlir::linalg::computeMultiTileSizes(OpBuilder &builder, LinalgOp op,
177177
return spec;
178178
}
179179

180-
/// Given a `subsetExtractOp`, a `source` and a `dest`, create a new
181-
/// `ParallelInsertSlice` op of `source` into `dest` at the same subset location
182-
/// as `subsetExtractOp`.
183-
static void
184-
createMatchingParallelSubsetInsertOp(OpBuilder &b, Location loc,
185-
tensor::ExtractSliceOp subsetExtractOp,
186-
Value source, Value dest) {
187-
b.create<tensor::ParallelInsertSliceOp>(
188-
loc, source, dest, subsetExtractOp.getMixedOffsets(),
189-
subsetExtractOp.getMixedSizes(), subsetExtractOp.getMixedStrides());
190-
}
191-
192180
/// Returns true if the maximum tile offset `tileSize * numThreads-1` is less
193181
/// than `iterationSize`.
194182
static bool canOmitTileOffsetInBoundsCheck(OpFoldResult tileSize,
@@ -333,16 +321,21 @@ static FailureOr<ForeachThreadTilingResult> tileToForeachThreadOpImpl(
333321

334322
auto tilingInterfaceOp = dyn_cast<TilingInterface>(tiledOp);
335323
assert(tilingInterfaceOp && "Tiled op does not implement TilingInterface");
336-
337-
auto tiledDestOperands = tilingInterfaceOp.getDestinationOperands(b);
338-
339-
// Create terminator with parallel subset insert operations.
340-
b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
341-
for (auto it : llvm::zip(tiledDestOperands, tilingInterfaceOp->getResults(),
342-
destOperands)) {
343-
createMatchingParallelSubsetInsertOp(
344-
b, loc, cast<tensor::ExtractSliceOp>(std::get<0>(it).getDefiningOp()),
345-
std::get<1>(it), std::get<2>(it));
324+
OpBuilder::InsertPoint insertPt = b.saveInsertionPoint();
325+
for (auto it :
326+
llvm::zip(llvm::seq(unsigned(0), unsigned(destOperands.size())),
327+
tilingInterfaceOp->getResults(), destOperands)) {
328+
b.setInsertionPoint(insertPt.getBlock(), insertPt.getPoint());
329+
SmallVector<OpFoldResult> resultOffsets, resultSizes;
330+
if (failed(op.getResultTilePosition(b, std::get<0>(it), tiledOffsets,
331+
tiledSizes, resultOffsets,
332+
resultSizes)))
333+
return op->emitOpError("output offsets couldn't be calculated");
334+
SmallVector<OpFoldResult> strides(resultSizes.size(), b.getIndexAttr(1));
335+
b.setInsertionPointToStart(foreachThreadOp.getTerminator().getBody());
336+
b.create<tensor::ParallelInsertSliceOp>(loc, std::get<1>(it),
337+
std::get<2>(it), resultOffsets,
338+
resultSizes, strides);
346339
}
347340
return ForeachThreadTilingResult{foreachThreadOp, tiledOp};
348341
}

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -161,15 +161,12 @@ struct LinalgOpTilingInterface
161161
}));
162162

163163
OpOperand *outOperand = linalgOp.getOutputOperand(resultNumber);
164-
Value sliceOpResult =
165-
makeTiledShape(b, loc, outOperand->get(), sizes,
166-
linalgOp.getTiedIndexingMap(outOperand), offsets,
167-
/*ubs*/ {}, subShapeSizes, true);
168-
auto sliceOp = sliceOpResult.getDefiningOp<tensor::ExtractSliceOp>();
169-
if (!sliceOp)
170-
return failure();
171-
resultOffsets = sliceOp.getMixedOffsets();
172-
resultSizes = sliceOp.getMixedSizes();
164+
SliceParameters sliceParams =
165+
computeSliceParameters(b, loc, outOperand->get(), sizes,
166+
linalgOp.getTiedIndexingMap(outOperand), offsets,
167+
/*ubs*/ {}, subShapeSizes, true);
168+
resultOffsets = sliceParams.offsets;
169+
resultSizes = sliceParams.sizes;
173170
return success();
174171
}
175172

mlir/test/Dialect/Linalg/multisize-tiling-full.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ func.func @two_d(%arg0: tensor<10x34xf32>,
5959
// CHECK: %[[RESPARTIAL:.+]] = tensor.insert_slice %[[RESSLICE_1]] into %[[ITERARG_2]]
6060
// CHECK: scf.yield %[[RESPARTIAL]]
6161

62-
// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][%[[I1]], 0] [2, 16] [1, 1]
62+
// CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[LOOPRES]] into %[[OUTSLICE_1]][0, 0] [2, 16] [1, 1]
6363
// CHECK: %[[OUTSLICE_3:.+]] = tensor.extract_slice %[[INSERTED]][0, 16] [2, 18] [1, 1]
6464
// CHECK: scf.for %{{.*}} iter_args(%{{.*}} = %[[OUTSLICE_3]])
6565
// CHECK-COUNT-2: tensor.extract_slice

mlir/test/Dialect/Linalg/tile-to-foreach-thread.mlir

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -split-input-file | FileCheck %s
1+
// RUN: mlir-opt %s --test-transform-dialect-interpreter -canonicalize -cse -split-input-file | FileCheck %s
22

33
// Offset per thread:
44
// CHECK-DAG: affine_map<(d0)[s0] -> (d0 * (s0 ceildiv 10))>
@@ -22,7 +22,7 @@ module {
2222
// CHECK: %[[RES:.*]] = linalg.matmul
2323
// CHECK-SAME: ins(%[[tA]], %[[tB]] : tensor<?x?xf32>, tensor<?x?xf32>)
2424
// CHECK-SAME: outs(%[[tC]] : tensor<?x?xf32>) -> tensor<?x?xf32>
25-
// CHECK-NEXT: scf.foreach_thread.perform_concurrently {
25+
// CHECK: scf.foreach_thread.perform_concurrently {
2626
// CHECK-NEXT: tensor.parallel_insert_slice %[[RES]] into %[[C]]{{.*}} :
2727
// CHECK-SAME: tensor<?x?xf32> into tensor<?x?xf32>
2828
// CHECK-NEXT: }
@@ -65,11 +65,9 @@ func.func @matmul_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf32>, %C: t
6565
// CHECK-NOT: affine.max
6666
// CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
6767
// CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
68-
// CHECK: %[[LB0_1:.+]] = affine.apply #[[$map2]](%[[IV0]])
69-
// CHECK: %[[LB1_1:.+]] = affine.apply #[[$map3]](%[[IV1]])
7068
// CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
7169
// CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
72-
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0_1]], %[[LB1_1]]] [10, %[[TS]]] [1, 1] :
70+
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
7371
// CHECK: linalg.matmul
7472
// CHECK: scf.foreach_thread.perform_concurrently
7573
// CHECK-NEXT: tensor.parallel_insert_slice
@@ -106,17 +104,13 @@ func.func @matmul_tile_size_dynamic(%A: tensor<?x?xf32>, %B: tensor<?x?xf32>, %C
106104
// CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
107105
// CHECK: %[[NT0:.+]] = affine.apply #map0()[%[[M]]]
108106
// CHECK: %[[NT1:.+]] = affine.apply #map1()[%[[N]]]
109-
// CHECK: %[[M:.+]] = tensor.dim %[[A]], %c0 :
110-
// CHECK: %[[N:.+]] = tensor.dim %[[B]], %c1 :
111107
// CHECK: scf.foreach_thread (%[[IV0:.+]], %[[IV1:.+]]) in (%[[NT0]], %[[NT1]])
112108
// CHECK: %[[TS0:.+]] = affine.min #[[$map2]](%[[IV0]])[%[[M]]]
113109
// CHECK: %[[TS1:.+]] = affine.min #[[$map4]](%[[IV1]])[%[[N]]]
114110
// CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
115111
// CHECK tensor.extract_slice %[[A]]
116112
// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
117113
// CHECK tensor.extract_slice %[[B]]
118-
// CHECK: %[[LB0:.+]] = affine.apply #[[$map5]](%[[IV0]])
119-
// CHECK: %[[LB1:.+]] = affine.apply #[[$map6]](%[[IV1]])
120114
// CHECK tensor.extract_slice %[[C]]
121115
// CHECK: linalg.matmul
122116
// CHECK: scf.foreach_thread.perform_concurrently
@@ -156,11 +150,9 @@ func.func @matmul_tile_size_static(%A: tensor<100x200xf32>, %B: tensor<200x300xf
156150
// CHECK-NOT: affine.min
157151
// CHECK: %[[LB0:.+]] = affine.apply #[[$map2]](%[[IV0]])
158152
// CHECK: %[[LB1:.+]] = affine.apply #[[$map3]](%[[IV1]])
159-
// CHECK: %[[LB0_1:.+]] = affine.apply #[[$map2]](%[[IV0]])
160-
// CHECK: %[[LB1_1:.+]] = affine.apply #[[$map3]](%[[IV1]])
161153
// CHECK: %[[tA:.+]] = tensor.extract_slice %[[A]][%[[LB0]], 0] [10, 200] [1, 1] :
162154
// CHECK: %[[tB:.+]] = tensor.extract_slice %[[B]][0, %[[LB1]]] [200, %[[TS]]] [1, 1] :
163-
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0_1]], %[[LB1_1]]] [10, %[[TS]]] [1, 1] :
155+
// CHECK: %[[tC:.+]] = tensor.extract_slice %[[C]][%[[LB0]], %[[LB1]]] [10, %[[TS]]] [1, 1] :
164156
// CHECK: linalg.matmul
165157
// CHECK: scf.foreach_thread.perform_concurrently
166158
// CHECK-NEXT: tensor.parallel_insert_slice
@@ -177,3 +169,37 @@ transform.with_pdl_patterns {
177169
%1:2 = transform.structured.tile_to_foreach_thread_op %0 tile_sizes [10, 21]
178170
}
179171
}
172+
173+
// -----
174+
175+
module {
176+
func.func @extract_source(%A: tensor<4xf32>, %B: tensor<16xf32>) -> tensor<4xf32> {
177+
%B1 = tensor.extract_slice %B[10] [4] [1] : tensor<16xf32> to tensor<4xf32>
178+
%result = linalg.generic {indexing_maps = [
179+
affine_map<(d0) -> (d0)>,affine_map<(d0) -> (d0)>],
180+
iterator_types = ["parallel"]}
181+
ins(%A : tensor<4xf32>) outs(%B1 : tensor<4xf32>) {
182+
^bb0(%arg3: f32, %arg4: f32): // no predecessors
183+
%2 = arith.addf %arg3, %arg3 : f32
184+
linalg.yield %2 : f32
185+
} -> tensor<4xf32>
186+
return %result : tensor<4xf32>
187+
}
188+
189+
transform.with_pdl_patterns {
190+
^bb0(%arg0: !pdl.operation):
191+
transform.sequence %arg0 failures(propagate) {
192+
^bb1(%arg1: !pdl.operation):
193+
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1
194+
%1:2 = transform.structured.tile_to_foreach_thread_op %0 num_threads [2] (mapped to dims [0])
195+
}
196+
}
197+
}
198+
// CHECK-DAG: #[[$map0:.+]] = affine_map<(d0) -> (d0 * 2)>
199+
200+
// CHECK-LABEL: extract_source(
201+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
202+
// CHECK: scf.foreach_thread (%[[ARG:.*]]) in (%[[C2]]) -> (tensor<4xf32>) {
203+
// CHECK: %[[OFF:.*]] = affine.apply #[[$map0]](%[[ARG]])
204+
// CHECK: scf.foreach_thread.perform_concurrently {
205+
// CHECK: tensor.parallel_insert_slice %{{.*}} into %{{.*}}[%[[OFF]]] [2] [1] : tensor<2xf32> into tensor<4xf32>

0 commit comments

Comments
 (0)