@@ -1554,36 +1554,22 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
1554
1554
llvm::SmallSetVector<Value, 32 > escapingValues;
1555
1555
SmallVector<Type> inputTypes;
1556
1556
SmallVector<Type> distTypes;
1557
- auto collectEscapingValues = [&](Value value) {
1558
- if (!escapingValues.insert (value))
1559
- return ;
1560
- Type distType = value.getType ();
1561
- if (auto vecType = dyn_cast<VectorType>(distType)) {
1562
- AffineMap map = distributionMapFn (value);
1563
- distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1564
- }
1565
- inputTypes.push_back (value.getType ());
1566
- distTypes.push_back (distType);
1567
- };
1568
-
1569
1557
mlir::visitUsedValuesDefinedAbove (
1570
1558
forOp.getBodyRegion (), [&](OpOperand *operand) {
1571
1559
Operation *parent = operand->get ().getParentRegion ()->getParentOp ();
1572
1560
if (warpOp->isAncestor (parent)) {
1573
- collectEscapingValues (operand->get ());
1561
+ if (!escapingValues.insert (operand->get ()))
1562
+ return ;
1563
+ Type distType = operand->get ().getType ();
1564
+ if (auto vecType = dyn_cast<VectorType>(distType)) {
1565
+ AffineMap map = distributionMapFn (operand->get ());
1566
+ distType = getDistributedType (vecType, map, warpOp.getWarpSize ());
1567
+ }
1568
+ inputTypes.push_back (operand->get ().getType ());
1569
+ distTypes.push_back (distType);
1574
1570
}
1575
1571
});
1576
1572
1577
- // Any forOp result that is not already yielded by the warpOp
1578
- // region is also considered escaping and must be returned by the
1579
- // original warpOp.
1580
- for (OpResult forResult : forOp.getResults ()) {
1581
- // Check if this forResult is already yielded by the yield op.
1582
- if (llvm::is_contained (yield->getOperands (), forResult))
1583
- continue ;
1584
- collectEscapingValues (forResult);
1585
- }
1586
-
1587
1573
if (llvm::is_contained (distTypes, Type{}))
1588
1574
return failure ();
1589
1575
@@ -1623,12 +1609,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
1623
1609
forOp.getResultTypes ().end ());
1624
1610
llvm::SmallDenseMap<Value, int64_t > argIndexMapping;
1625
1611
for (auto [i, retIdx] : llvm::enumerate (newRetIndices)) {
1626
- auto newWarpResult = newWarpOp.getResult (retIdx);
1627
- // Unused forOp results yielded by the warpOp region are already included
1628
- // in the new ForOp.
1629
- if (llvm::is_contained (newOperands, newWarpResult))
1630
- continue ;
1631
- warpInput.push_back (newWarpResult);
1612
+ warpInput.push_back (newWarpOp.getResult (retIdx));
1632
1613
argIndexMapping[escapingValues[i]] = warpInputType.size ();
1633
1614
warpInputType.push_back (inputTypes[i]);
1634
1615
}
0 commit comments