Skip to content

Commit 10dc8bc

Browse files
authored
[mlir][vector] Fix for WarpOpScfForOp failure when scf.for has results that are unused. (#141853)
Currently, only the values defined outside the ForOp but inside the original WarpOp are considered "escaping values". However, this is not sufficient if the ForOp has some unused results: in that case, the corresponding IterArgs must also be yielded by the original WarpOp. This PR adds the code changes required to achieve this.
1 parent 5213c57 commit 10dc8bc

File tree

2 files changed

+65
-10
lines changed

2 files changed

+65
-10
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,22 +1554,36 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
15541554
llvm::SmallSetVector<Value, 32> escapingValues;
15551555
SmallVector<Type> inputTypes;
15561556
SmallVector<Type> distTypes;
1557+
auto collectEscapingValues = [&](Value value) {
1558+
if (!escapingValues.insert(value))
1559+
return;
1560+
Type distType = value.getType();
1561+
if (auto vecType = dyn_cast<VectorType>(distType)) {
1562+
AffineMap map = distributionMapFn(value);
1563+
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1564+
}
1565+
inputTypes.push_back(value.getType());
1566+
distTypes.push_back(distType);
1567+
};
1568+
15571569
mlir::visitUsedValuesDefinedAbove(
15581570
forOp.getBodyRegion(), [&](OpOperand *operand) {
15591571
Operation *parent = operand->get().getParentRegion()->getParentOp();
15601572
if (warpOp->isAncestor(parent)) {
1561-
if (!escapingValues.insert(operand->get()))
1562-
return;
1563-
Type distType = operand->get().getType();
1564-
if (auto vecType = dyn_cast<VectorType>(distType)) {
1565-
AffineMap map = distributionMapFn(operand->get());
1566-
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1567-
}
1568-
inputTypes.push_back(operand->get().getType());
1569-
distTypes.push_back(distType);
1573+
collectEscapingValues(operand->get());
15701574
}
15711575
});
15721576

1577+
// Any forOp result that is not already yielded by the warpOp
1578+
// region is also considered escaping and must be returned by the
1579+
// original warpOp.
1580+
for (OpResult forResult : forOp.getResults()) {
1581+
// Check if this forResult is already yielded by the yield op.
1582+
if (llvm::is_contained(yield->getOperands(), forResult))
1583+
continue;
1584+
collectEscapingValues(forResult);
1585+
}
1586+
15731587
if (llvm::is_contained(distTypes, Type{}))
15741588
return failure();
15751589

@@ -1609,7 +1623,12 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
16091623
forOp.getResultTypes().end());
16101624
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
16111625
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1612-
warpInput.push_back(newWarpOp.getResult(retIdx));
1626+
auto newWarpResult = newWarpOp.getResult(retIdx);
1627+
// Unused forOp results yielded by the warpOp region are already included
1628+
// in the new ForOp.
1629+
if (llvm::is_contained(newOperands, newWarpResult))
1630+
continue;
1631+
warpInput.push_back(newWarpResult);
16131632
argIndexMapping[escapingValues[i]] = warpInputType.size();
16141633
warpInputType.push_back(inputTypes[i]);
16151634
}

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,42 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
584584
return
585585
}
586586

587+
// -----
588+
// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_yield(
589+
// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
590+
// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
591+
// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
592+
// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
593+
// CHECK-PROP: }
594+
// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
595+
// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
596+
// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
597+
// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
598+
// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
599+
// CHECK-PROP: }
600+
// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
601+
// CHECK-PROP: }
602+
// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
603+
func.func @warp_scf_for_unused_yield(%arg0: index) {
604+
%c128 = arith.constant 128 : index
605+
%c1 = arith.constant 1 : index
606+
%c0 = arith.constant 0 : index
607+
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
608+
%ini = "some_def"() : () -> (vector<128xf32>)
609+
%ini1 = "some_def"() : () -> (vector<128xf32>)
610+
%3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
611+
%add = arith.addi %arg3, %c1 : index
612+
%1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
613+
%acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
614+
scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
615+
}
616+
gpu.yield %3#0 : vector<128xf32>
617+
}
618+
"some_use"(%0) : (vector<4xf32>) -> ()
619+
return
620+
}
621+
622+
587623
// -----
588624

589625
// CHECK-PROP-LABEL: func @vector_reduction(

0 commit comments

Comments
 (0)