
Commit 7173372

fairywreath authored and tomtor committed
[mlir][spirv] Implement lowering gpu.subgroup_reduce with cluster size for SPIRV (llvm#141402)
Implement lowering of `gpu.subgroup_reduce` with a cluster size attribute to SPIRV by using the `ClusteredReduce` group operation.
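In short: a `gpu.subgroup_reduce` carrying a `cluster(size = N)` attribute now maps to the corresponding `spirv.GroupNonUniform*` op with the `ClusteredReduce` group operation, with the cluster size materialized as an `i32` constant operand. A minimal before/after sketch, mirroring the test added in this commit (SSA value names are illustrative):

```mlir
// Input (gpu dialect): reduce across clusters of 8 invocations.
%reduced = gpu.subgroup_reduce add %arg cluster(size = 8) : (f32) -> (f32)

// Output (spirv dialect): the cluster size becomes an i32 constant passed
// as the cluster_size operand of the non-uniform clustered reduction.
%cluster_size = spirv.Constant 8 : i32
%result = spirv.GroupNonUniformFAdd <Subgroup> <ClusteredReduce> %arg cluster_size(%cluster_size) : f32, i32 -> f32
```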
1 parent c9e1c8b commit 7173372

2 files changed (+70, −16 lines)


mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 29 additions & 16 deletions
@@ -464,27 +464,39 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
 
 template <typename UniformOp, typename NonUniformOp>
 static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
-                                     Value arg, bool isGroup, bool isUniform) {
+                                     Value arg, bool isGroup, bool isUniform,
+                                     std::optional<uint32_t> clusterSize) {
   Type type = arg.getType();
   auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
                                            isGroup ? spirv::Scope::Workgroup
                                                    : spirv::Scope::Subgroup);
-  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
-                                                spirv::GroupOperation::Reduce);
+  auto groupOp = spirv::GroupOperationAttr::get(
+      builder.getContext(), clusterSize.has_value()
+                                ? spirv::GroupOperation::ClusteredReduce
+                                : spirv::GroupOperation::Reduce);
   if (isUniform) {
     return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
         .getResult();
   }
-  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
+
+  Value clusterSizeValue;
+  if (clusterSize.has_value())
+    clusterSizeValue = builder.create<spirv::ConstantOp>(
+        loc, builder.getI32Type(),
+        builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
+
+  return builder
+      .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
       .getResult();
 }
 
-static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
-                                                Location loc, Value arg,
-                                                gpu::AllReduceOperation opType,
-                                                bool isGroup, bool isUniform) {
+static std::optional<Value>
+createGroupReduceOp(OpBuilder &builder, Location loc, Value arg,
+                    gpu::AllReduceOperation opType, bool isGroup,
+                    bool isUniform, std::optional<uint32_t> clusterSize) {
   enum class ElemType { Float, Boolean, Integer };
-  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
+  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
+                          std::optional<uint32_t>);
   struct OpHandler {
     gpu::AllReduceOperation kind;
     ElemType elemType;
@@ -548,7 +560,7 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
 
   for (const OpHandler &handler : handlers)
     if (handler.kind == opType && elementType == handler.elemType)
-      return handler.func(builder, loc, arg, isGroup, isUniform);
+      return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
 
   return std::nullopt;
 }
@@ -571,7 +583,7 @@ class GPUAllReduceConversion final
 
     auto result =
         createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
-                            /*isGroup*/ true, op.getUniform());
+                            /*isGroup*/ true, op.getUniform(), std::nullopt);
     if (!result)
       return failure();
 
@@ -589,16 +601,17 @@ class GPUSubgroupReduceConversion final
   LogicalResult
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (op.getClusterSize())
+    if (op.getClusterStride() > 1) {
       return rewriter.notifyMatchFailure(
-          op, "lowering for clustered reduce not implemented");
+          op, "lowering for cluster stride > 1 is not implemented");
+    }
 
     if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
       return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
 
-    auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
-                                      adaptor.getOp(),
-                                      /*isGroup=*/false, adaptor.getUniform());
+    auto result = createGroupReduceOp(
+        rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
+        /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize());
     if (!result)
       return failure();
 
mlir/test/Conversion/GPUToSPIRV/reductions.mlir

Lines changed: 41 additions & 0 deletions
@@ -789,3 +789,44 @@ gpu.module @kernels {
   }
 }
 }
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupUniformArithmeticKHR, GroupNonUniformClustered], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  // CHECK-LABEL: spirv.func @test_subgroup_reduce_clustered
+  // CHECK-SAME: (%[[ARG:.*]]: f32)
+  // CHECK: %[[CLUSTER_SIZE:.*]] = spirv.Constant 8 : i32
+  gpu.func @test_subgroup_reduce_clustered(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // CHECK: %{{.*}} = spirv.GroupNonUniformFAdd <Subgroup> <ClusteredReduce> %[[ARG]] cluster_size(%[[CLUSTER_SIZE]]) : f32, i32 -> f32
+    %reduced = gpu.subgroup_reduce add %arg cluster(size = 8) : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
+
+// -----
+
+// Subgroup reduce with cluster stride > 1 is not yet supported.
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.3, [Kernel, Addresses, Groups, GroupUniformArithmeticKHR, GroupNonUniformClustered], []>, #spirv.resource_limits<>>
+} {
+
+gpu.module @kernels {
+  gpu.func @test_invalid_subgroup_reduce_clustered_stride(%arg : f32) kernel
+    attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
+    // expected-error @+1 {{failed to legalize operation 'gpu.subgroup_reduce'}}
+    %reduced = gpu.subgroup_reduce add %arg cluster(size = 8, stride = 2) : (f32) -> (f32)
+    gpu.return
+  }
+}
+
+}
