Skip to content

[MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass #144417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

nbpatel
Copy link
Contributor

@nbpatel nbpatel commented Jun 16, 2025

This PR adds transformation pattern for vector.broadcast op in xegpu-wg-to-sg-distribute pass

@nbpatel nbpatel requested a review from chencha3 June 16, 2025 19:41
VectorType::get(sgShape, resultType.getElementType());

SmallVector<Value> newBroadcastOps;
for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about use range-based for loop?

xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
return failure();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks to me that the current implementation is assuming the rank of source is the same as the rank of the result, which is a subset of the supported semantics of vector.broadcast. I believe it is partially because of the limitation of LayoutAttr. It would be better to add a check.

@nbpatel nbpatel marked this pull request as ready for review June 16, 2025 20:33
@llvmbot
Copy link
Member

llvmbot commented Jun 16, 2025

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR adds transformation pattern for vector.broadcast op in xegpu-wg-to-sg-distribute pass


Full diff: https://github.com/llvm/llvm-project/pull/144417.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp (+40-1)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir (+18-1)
  • (modified) mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir (+17-2)
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a26c6b52f0ddc..96c7032d6b812 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -328,6 +328,39 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
   }
 };
 
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+struct WgToSgVectorBroadcastOp
+    : public OpConversionPattern<vector::BroadcastOp> {
+  using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = op.getResult().getType();
+    ArrayRef<int64_t> wgShape = resultType.getShape();
+
+    xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+    if (!layout || !layout.getSgLayout())
+      return failure();
+
+    SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+    VectorType newResultType =
+        VectorType::get(sgShape, resultType.getElementType());
+
+    SmallVector<Value> newBroadcastOps;
+    for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
+      auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+          op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
+      xegpu::setLayoutAttr(newBroadcast->getResult(0),
+                           layout.dropSgLayoutAndData());
+      newBroadcastOps.push_back(newBroadcast.getResult());
+    }
+
+    rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
+    return success();
+  }
+};
+
 // Handles UnrealizedConversionCastOp generated during
 // SCFStructuralTypeConversions (step 1). This op may appear as either a
 // target or source materialization for Vector values, e.g.:
@@ -411,7 +444,8 @@ namespace xegpu {
 void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
   patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
                WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
-               UnrealizedConversionCastOpPattern>(patterns.getContext());
+               WgToSgVectorBroadcastOp, UnrealizedConversionCastOpPattern>(
+      patterns.getContext());
 }
 } // namespace xegpu
 } // namespace mlir
@@ -518,6 +552,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
     return isLegal(layout);
   });
 
+  target.addDynamicallyLegalOp<vector::BroadcastOp>(
+      [=](vector::BroadcastOp op) -> bool {
+        return isLegal(xegpu::getLayoutAttr(op.getResult()));
+      });
+
   target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
       [=](UnrealizedConversionCastOp op) {
         return llvm::is_contained(existingCastOps, op.getOperation());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 35ad16d8cd9a9..60ac266b0f112 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -103,6 +103,24 @@ gpu.module @test_round_robin_assignment {
     gpu.return
   }
 
+  // CHECK-LABEL: test_broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+      -> vector<24x1xf32>
+    // CHECK-COUNT-3: vector.broadcast {{.*}}
+    // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+    // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+    // CHECK-NOT: vector.broadcast
+    %broadcast = vector.broadcast %load 
+      {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
+
   gpu.func @test_scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
     %c1 = arith.constant 1 : index
     %c10 = arith.constant 10 : index
@@ -197,5 +215,4 @@ gpu.module @test_round_robin_assignment {
     xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
     gpu.return
   }
-
 }
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 466842c968448..125bab349b4cb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,6 +170,22 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     gpu.return
   }
 
+  // CHECK-LABEL: test_broadcast
+  // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+  gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+      -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+    %load =  xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+      -> vector<24x1xf32>
+    // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+    // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
+    %broadcast = vector.broadcast %load 
+      {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
+      : vector<24x1xf32> to vector<24x8xf32>
+    gpu.return
+  }
+
   gpu.func @test_scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
     //CHECK: [[c0:%.+]] = arith.constant 0 : index
     //CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -295,6 +311,5 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
     xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
     gpu.return
   }
-
-
 }
+

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants