Skip to content

Commit 8a2754e

Browse files
charithaintctomtor
authored andcommitted
Revert "[mlir][vector] Fix for WarpOpScfForOp failure when scf.for has results that are unused." (llvm#144124)
Reverts llvm#141853 Reverting the bug fix because it does not handle all cases correctly.
1 parent b23ab6c commit 8a2754e

File tree

2 files changed

+10
-65
lines changed

2 files changed

+10
-65
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,36 +1554,22 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
15541554
llvm::SmallSetVector<Value, 32> escapingValues;
15551555
SmallVector<Type> inputTypes;
15561556
SmallVector<Type> distTypes;
1557-
auto collectEscapingValues = [&](Value value) {
1558-
if (!escapingValues.insert(value))
1559-
return;
1560-
Type distType = value.getType();
1561-
if (auto vecType = dyn_cast<VectorType>(distType)) {
1562-
AffineMap map = distributionMapFn(value);
1563-
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1564-
}
1565-
inputTypes.push_back(value.getType());
1566-
distTypes.push_back(distType);
1567-
};
1568-
15691557
mlir::visitUsedValuesDefinedAbove(
15701558
forOp.getBodyRegion(), [&](OpOperand *operand) {
15711559
Operation *parent = operand->get().getParentRegion()->getParentOp();
15721560
if (warpOp->isAncestor(parent)) {
1573-
collectEscapingValues(operand->get());
1561+
if (!escapingValues.insert(operand->get()))
1562+
return;
1563+
Type distType = operand->get().getType();
1564+
if (auto vecType = dyn_cast<VectorType>(distType)) {
1565+
AffineMap map = distributionMapFn(operand->get());
1566+
distType = getDistributedType(vecType, map, warpOp.getWarpSize());
1567+
}
1568+
inputTypes.push_back(operand->get().getType());
1569+
distTypes.push_back(distType);
15741570
}
15751571
});
15761572

1577-
// Any forOp result that is not already yielded by the warpOp
1578-
// region is also considered escaping and must be returned by the
1579-
// original warpOp.
1580-
for (OpResult forResult : forOp.getResults()) {
1581-
// Check if this forResult is already yielded by the yield op.
1582-
if (llvm::is_contained(yield->getOperands(), forResult))
1583-
continue;
1584-
collectEscapingValues(forResult);
1585-
}
1586-
15871573
if (llvm::is_contained(distTypes, Type{}))
15881574
return failure();
15891575

@@ -1623,12 +1609,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
16231609
forOp.getResultTypes().end());
16241610
llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
16251611
for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
1626-
auto newWarpResult = newWarpOp.getResult(retIdx);
1627-
// Unused forOp results yielded by the warpOp region are already included
1628-
// in the new ForOp.
1629-
if (llvm::is_contained(newOperands, newWarpResult))
1630-
continue;
1631-
warpInput.push_back(newWarpResult);
1612+
warpInput.push_back(newWarpOp.getResult(retIdx));
16321613
argIndexMapping[escapingValues[i]] = warpInputType.size();
16331614
warpInputType.push_back(inputTypes[i]);
16341615
}

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -584,42 +584,6 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
584584
return
585585
}
586586

587-
// -----
588-
// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_yield(
589-
// CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
590-
// CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
591-
// CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
592-
// CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
593-
// CHECK-PROP: }
594-
// CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
595-
// CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
596-
// CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
597-
// CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
598-
// CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
599-
// CHECK-PROP: }
600-
// CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
601-
// CHECK-PROP: }
602-
// CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
603-
func.func @warp_scf_for_unused_yield(%arg0: index) {
604-
%c128 = arith.constant 128 : index
605-
%c1 = arith.constant 1 : index
606-
%c0 = arith.constant 0 : index
607-
%0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
608-
%ini = "some_def"() : () -> (vector<128xf32>)
609-
%ini1 = "some_def"() : () -> (vector<128xf32>)
610-
%3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
611-
%add = arith.addi %arg3, %c1 : index
612-
%1 = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
613-
%acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
614-
scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
615-
}
616-
gpu.yield %3#0 : vector<128xf32>
617-
}
618-
"some_use"(%0) : (vector<4xf32>) -> ()
619-
return
620-
}
621-
622-
623587
// -----
624588

625589
// CHECK-PROP-LABEL: func @vector_reduction(

0 commit comments

Comments
 (0)