[MLIR][XeGPU] Add transformation pattern for vector.broadcast in Wg to Sg pass #144417
Conversation
VectorType::get(sgShape, resultType.getElementType());

SmallVector<Value> newBroadcastOps;
for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
How about using a range-based for loop?
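A minimal sketch of that suggestion, reusing the PR's own variable names (untested; it only illustrates the shape of the loop):

```cpp
// Iterate directly over the remapped operand values instead of indexing.
SmallVector<Value> newBroadcastOps;
for (Value operand : adaptor.getOperands().front()) {
  auto newBroadcast = rewriter.create<vector::BroadcastOp>(
      op.getLoc(), newResultType, operand);
  xegpu::setLayoutAttr(newBroadcast->getResult(0),
                       layout.dropSgLayoutAndData());
  newBroadcastOps.push_back(newBroadcast.getResult());
}
```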
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
if (!layout || !layout.getSgLayout())
  return failure();
It looks to me that the current implementation assumes the rank of the source is the same as the rank of the result, which is a subset of the supported semantics of vector.broadcast. I believe that is partially due to a limitation of LayoutAttr. It would be better to add a check; one possible form is sketched below.
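A hypothetical version of such a check, assuming the pattern should simply bail out on rank-changing broadcasts (the early-exit style mirrors the PR; the diagnostic text is made up):

```cpp
// Bail out on rank-changing broadcasts (e.g. vector<8xf32> to
// vector<4x8xf32>); this pattern only handles rank-preserving cases.
auto srcType = dyn_cast<VectorType>(op.getSourceType());
if (!srcType || srcType.getRank() != resultType.getRank())
  return rewriter.notifyMatchFailure(
      op, "only rank-preserving broadcasts are supported");
```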
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir

Author: Nishant Patel (nbpatel)

Changes

This PR adds a transformation pattern for the vector.broadcast op in the xegpu-wg-to-sg-distribute pass.

Full diff: https://github.com/llvm/llvm-project/pull/144417.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
index a26c6b52f0ddc..96c7032d6b812 100644
--- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
+++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -328,6 +328,39 @@ struct WgToSgPrefetchNdOp : public OpConversionPattern<xegpu::PrefetchNdOp> {
}
};
+/// This pattern transforms vector.broadcast ops to work at subgroup level.
+struct WgToSgVectorBroadcastOp
+ : public OpConversionPattern<vector::BroadcastOp> {
+ using OpConversionPattern<vector::BroadcastOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::BroadcastOp op, OneToNOpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = op.getResult().getType();
+ ArrayRef<int64_t> wgShape = resultType.getShape();
+
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult());
+ if (!layout || !layout.getSgLayout())
+ return failure();
+
+ SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
+ VectorType newResultType =
+ VectorType::get(sgShape, resultType.getElementType());
+
+ SmallVector<Value> newBroadcastOps;
+ for (size_t i = 0; i < adaptor.getOperands().front().size(); ++i) {
+ auto newBroadcast = rewriter.create<vector::BroadcastOp>(
+ op.getLoc(), newResultType, adaptor.getOperands().front()[i]);
+ xegpu::setLayoutAttr(newBroadcast->getResult(0),
+ layout.dropSgLayoutAndData());
+ newBroadcastOps.push_back(newBroadcast.getResult());
+ }
+
+ rewriter.replaceOpWithMultiple(op, {newBroadcastOps});
+ return success();
+ }
+};
+
// Handles UnrealizedConversionCastOp generated during
// SCFStructuralTypeConversions (step 1). This op may appear as either a
// target or source materialization for Vector values, e.g.:
@@ -411,7 +444,8 @@ namespace xegpu {
void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
patterns.add<WgToSgCreateNdOp, WgToSgLoadNdOp, WgToSgStoreNdOp,
WgToSgUpdateNdOffsetOp, WgToSgDpasOp, WgToSgPrefetchNdOp,
- UnrealizedConversionCastOpPattern>(patterns.getContext());
+ WgToSgVectorBroadcastOp, UnrealizedConversionCastOpPattern>(
+ patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -518,6 +552,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});
+ target.addDynamicallyLegalOp<vector::BroadcastOp>(
+ [=](vector::BroadcastOp op) -> bool {
+ return isLegal(xegpu::getLayoutAttr(op.getResult()));
+ });
+
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
[=](UnrealizedConversionCastOp op) {
return llvm::is_contained(existingCastOps, op.getOperation());
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
index 35ad16d8cd9a9..60ac266b0f112 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-rr.mlir
@@ -103,6 +103,24 @@ gpu.module @test_round_robin_assignment {
gpu.return
}
+ // CHECK-LABEL: test_broadcast
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+ gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+ -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ -> vector<24x1xf32>
+ // CHECK-COUNT-3: vector.broadcast {{.*}}
+ // CHECK-SAME-COUNT-3: {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+ // CHECK-SAME-COUNT-3: : vector<2x1xf32> to vector<2x4xf32>
+ // CHECK-NOT: vector.broadcast
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [4, 1], sg_data = [2, 4], lane_layout = [2, 1], lane_data = [1, 1]>}
+ : vector<24x1xf32> to vector<24x8xf32>
+ gpu.return
+ }
+
gpu.func @test_scf_for(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
%c1 = arith.constant 1 : index
%c10 = arith.constant 10 : index
@@ -197,5 +215,4 @@ gpu.module @test_round_robin_assignment {
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [8], sg_data = [16]>>
gpu.return
}
-
}
diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
index 466842c968448..125bab349b4cb 100644
--- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
+++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir
@@ -170,6 +170,22 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
gpu.return
}
+ // CHECK-LABEL: test_broadcast
+ // CHECK-SAME: %[[ARG_0:.*]]: memref<24x1xf32>
+ gpu.func @test_broadcast(%src: memref<24x1xf32>) {
+ %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x1xf32>
+ -> !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ %load = xegpu.load_nd %tdesc
+ : !xegpu.tensor_desc<24x1xf32, #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 1], lane_layout = [2, 1], lane_data = [1, 1]>>
+ -> vector<24x1xf32>
+ // CHECK: vector.broadcast {{.*}} {layout_result_0 = #xegpu.layout<lane_layout = [2, 1], lane_data = [1, 1]>}
+ // CHECK-SAME: : vector<12x1xf32> to vector<12x8xf32>
+ %broadcast = vector.broadcast %load
+ {layout_result_0 = #xegpu.layout<sg_layout = [2, 1], sg_data = [12, 8], lane_layout = [2, 1], lane_data = [1, 1]>}
+ : vector<24x1xf32> to vector<24x8xf32>
+ gpu.return
+ }
+
gpu.func @test_scf_for(%arg0: memref<1024x1024xf16>, %arg1: memref<1024x1024xf16>, %arg2: memref<1024x1024xf32>) {
//CHECK: [[c0:%.+]] = arith.constant 0 : index
//CHECK: [[c128:%.+]] = arith.constant 128 : index
@@ -295,6 +311,5 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
gpu.return
}
-
-
}
+