Skip to content

Revert "[mlir][vector] Fix for WarpOpScfForOp failure when scf.for has results that are unused." #144124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 13, 2025

Conversation

charithaintc
Copy link
Contributor

Reverts #141853

Reverting the bug fix because it does not handle all cases correctly.

Copy link

⚠️ We detected that you are using a GitHub private e-mail address to contribute to the repo.
Please turn off Keep my email addresses private setting in your account.
See LLVM Developer Policy and LLVM Discourse for more information.

@llvmbot
Copy link
Member

llvmbot commented Jun 13, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Charitha Saumya (charithaintc)

Changes

Reverts llvm/llvm-project#141853

Reverting the bug fix because it does not handle all cases correctly.


Full diff: https://github.com/llvm/llvm-project/pull/144124.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp (+10-29)
  • (modified) mlir/test/Dialect/Vector/vector-warp-distribute.mlir (-36)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 52a9cedb43cc0..045c192787f10 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -1554,36 +1554,22 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
     llvm::SmallSetVector<Value, 32> escapingValues;
     SmallVector<Type> inputTypes;
     SmallVector<Type> distTypes;
-    auto collectEscapingValues = [&](Value value) {
-      if (!escapingValues.insert(value))
-        return;
-      Type distType = value.getType();
-      if (auto vecType = dyn_cast<VectorType>(distType)) {
-        AffineMap map = distributionMapFn(value);
-        distType = getDistributedType(vecType, map, warpOp.getWarpSize());
-      }
-      inputTypes.push_back(value.getType());
-      distTypes.push_back(distType);
-    };
-
     mlir::visitUsedValuesDefinedAbove(
         forOp.getBodyRegion(), [&](OpOperand *operand) {
           Operation *parent = operand->get().getParentRegion()->getParentOp();
           if (warpOp->isAncestor(parent)) {
-            collectEscapingValues(operand->get());
+            if (!escapingValues.insert(operand->get()))
+              return;
+            Type distType = operand->get().getType();
+            if (auto vecType = dyn_cast<VectorType>(distType)) {
+              AffineMap map = distributionMapFn(operand->get());
+              distType = getDistributedType(vecType, map, warpOp.getWarpSize());
+            }
+            inputTypes.push_back(operand->get().getType());
+            distTypes.push_back(distType);
           }
         });
 
-    // Any forOp result that is not already yielded by the warpOp
-    // region is also considered escaping and must be returned by the
-    // original warpOp.
-    for (OpResult forResult : forOp.getResults()) {
-      // Check if this forResult is already yielded by the yield op.
-      if (llvm::is_contained(yield->getOperands(), forResult))
-        continue;
-      collectEscapingValues(forResult);
-    }
-
     if (llvm::is_contained(distTypes, Type{}))
       return failure();
 
@@ -1623,12 +1609,7 @@ struct WarpOpScfForOp : public WarpDistributionPattern {
                                     forOp.getResultTypes().end());
     llvm::SmallDenseMap<Value, int64_t> argIndexMapping;
     for (auto [i, retIdx] : llvm::enumerate(newRetIndices)) {
-      auto newWarpResult = newWarpOp.getResult(retIdx);
-      // Unused forOp results yielded by the warpOp region are already included
-      // in the new ForOp.
-      if (llvm::is_contained(newOperands, newWarpResult))
-        continue;
-      warpInput.push_back(newWarpResult);
+      warpInput.push_back(newWarpOp.getResult(retIdx));
       argIndexMapping[escapingValues[i]] = warpInputType.size();
       warpInputType.push_back(inputTypes[i]);
     }
diff --git a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
index 6c7ac7a5196a7..38771f2593449 100644
--- a/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
+++ b/mlir/test/Dialect/Vector/vector-warp-distribute.mlir
@@ -584,42 +584,6 @@ func.func @warp_scf_for_multiple_yield(%arg0: index, %arg1: memref<?xf32>, %arg2
   return
 }
 
-// -----
-// CHECK-PROP-LABEL: func.func @warp_scf_for_unused_yield(
-//       CHECK-PROP: %[[W0:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<4xf32>, vector<4xf32>) {
-//       CHECK-PROP: %[[INI0:.*]] = "some_def"() : () -> vector<128xf32>
-//       CHECK-PROP: %[[INI1:.*]] = "some_def"() : () -> vector<128xf32>
-//       CHECK-PROP: gpu.yield %[[INI0]], %[[INI1]] : vector<128xf32>, vector<128xf32>
-//       CHECK-PROP: }
-//       CHECK-PROP: %[[F:.*]]:2 = scf.for %{{.*}} iter_args(%{{.*}} = %[[W0]]#0, %{{.*}} = %[[W0]]#1) -> (vector<4xf32>, vector<4xf32>) {
-//       CHECK-PROP: %[[W1:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[32] args(%{{.*}} : vector<4xf32>, vector<4xf32>) -> (vector<4xf32>, vector<4xf32>) {
-//       CHECK-PROP: %[[ACC0:.*]] = "some_def"(%{{.*}}) : (vector<128xf32>, index) -> vector<128xf32>
-//       CHECK-PROP: %[[ACC1:.*]] = "some_def"(%{{.*}}) : (index, vector<128xf32>, vector<128xf32>) -> vector<128xf32>
-//       CHECK-PROP: gpu.yield %[[ACC1]], %[[ACC0]] : vector<128xf32>, vector<128xf32>
-//       CHECK-PROP: }
-//       CHECK-PROP: scf.yield %[[W1]]#0, %[[W1]]#1 : vector<4xf32>, vector<4xf32>
-//       CHECK-PROP: }
-//       CHECK-PROP: "some_use"(%[[F]]#0) : (vector<4xf32>) -> ()
-func.func @warp_scf_for_unused_yield(%arg0: index) {
-  %c128 = arith.constant 128 : index
-  %c1 = arith.constant 1 : index
-  %c0 = arith.constant 0 : index
-  %0 = gpu.warp_execute_on_lane_0(%arg0)[32] -> (vector<4xf32>) {
-    %ini = "some_def"() : () -> (vector<128xf32>)
-    %ini1 = "some_def"() : () -> (vector<128xf32>)
-    %3:2 = scf.for %arg3 = %c0 to %c128 step %c1 iter_args(%arg4 = %ini, %arg5 = %ini1) -> (vector<128xf32>, vector<128xf32>) {
-      %add = arith.addi %arg3, %c1 : index
-      %1  = "some_def"(%arg5, %add) : (vector<128xf32>, index) -> (vector<128xf32>)
-      %acc = "some_def"(%add, %arg4, %1) : (index, vector<128xf32>, vector<128xf32>) -> (vector<128xf32>)
-      scf.yield %acc, %1 : vector<128xf32>, vector<128xf32>
-    }
-    gpu.yield %3#0 : vector<128xf32>
-  }
-  "some_use"(%0) : (vector<4xf32>) -> ()
-  return
-}
-
-
 // -----
 
 // CHECK-PROP-LABEL: func @vector_reduction(

@charithaintc charithaintc merged commit 0c7ce68 into main Jun 13, 2025
5 of 7 checks passed
@charithaintc charithaintc deleted the revert-141853-scf_for_bug branch June 13, 2025 18:02
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
…s results that are unused." (llvm#144124)

Reverts llvm#141853

Reverting the bug fix because it does not handle all cases correctly.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants