@@ -464,27 +464,39 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
 
 template <typename UniformOp, typename NonUniformOp>
 static Value createGroupReduceOpImpl(OpBuilder &builder, Location loc,
-                                     Value arg, bool isGroup, bool isUniform) {
+                                     Value arg, bool isGroup, bool isUniform,
+                                     std::optional<uint32_t> clusterSize) {
   Type type = arg.getType();
   auto scope = mlir::spirv::ScopeAttr::get(builder.getContext(),
                                            isGroup ? spirv::Scope::Workgroup
                                                    : spirv::Scope::Subgroup);
-  auto groupOp = spirv::GroupOperationAttr::get(builder.getContext(),
-                                                spirv::GroupOperation::Reduce);
+  auto groupOp = spirv::GroupOperationAttr::get(
+      builder.getContext(), clusterSize.has_value()
+                                ? spirv::GroupOperation::ClusteredReduce
+                                : spirv::GroupOperation::Reduce);
   if (isUniform) {
     return builder.create<UniformOp>(loc, type, scope, groupOp, arg)
         .getResult();
   }
-  return builder.create<NonUniformOp>(loc, type, scope, groupOp, arg, Value{})
+
+  Value clusterSizeValue;
+  if (clusterSize.has_value())
+    clusterSizeValue = builder.create<spirv::ConstantOp>(
+        loc, builder.getI32Type(),
+        builder.getIntegerAttr(builder.getI32Type(), *clusterSize));
+
+  return builder
+      .create<NonUniformOp>(loc, type, scope, groupOp, arg, clusterSizeValue)
       .getResult();
 }
 
-static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
-                                                Location loc, Value arg,
-                                                gpu::AllReduceOperation opType,
-                                                bool isGroup, bool isUniform) {
+static std::optional<Value>
+createGroupReduceOp(OpBuilder &builder, Location loc, Value arg,
+                    gpu::AllReduceOperation opType, bool isGroup,
+                    bool isUniform, std::optional<uint32_t> clusterSize) {
   enum class ElemType { Float, Boolean, Integer };
-  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool);
+  using FuncT = Value (*)(OpBuilder &, Location, Value, bool, bool,
+                          std::optional<uint32_t>);
   struct OpHandler {
     gpu::AllReduceOperation kind;
     ElemType elemType;
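For illustration (not part of the commit): with a cluster size of 4, the helper
now builds a ClusteredReduce non-uniform op that takes the materialized i32
constant as its trailing operand. In SPIR-V dialect assembly this looks roughly
like the sketch below (syntax recalled from memory; the exact format may differ
between MLIR versions):

    // Hypothetical output for a clustered f32 add with cluster size 4.
    %four = spirv.Constant 4 : i32
    %sum = spirv.GroupNonUniformFAdd <Subgroup> <ClusteredReduce> %arg
             cluster_size(%four) : f32, i32 -> f32

When no cluster size is given, clusterSizeValue stays null, so the non-uniform
op is built as a plain Reduce with no extra operand, preserving the previous
behavior.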
@@ -548,7 +560,7 @@ static std::optional<Value> createGroupReduceOp(OpBuilder &builder,
 
   for (const OpHandler &handler : handlers)
     if (handler.kind == opType && elementType == handler.elemType)
-      return handler.func(builder, loc, arg, isGroup, isUniform);
+      return handler.func(builder, loc, arg, isGroup, isUniform, clusterSize);
 
   return std::nullopt;
 }
@@ -571,7 +583,7 @@ class GPUAllReduceConversion final
 
     auto result =
        createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(), *opType,
-                            /*isGroup*/ true, op.getUniform());
+                            /*isGroup*/ true, op.getUniform(), std::nullopt);
     if (!result)
       return failure();
 
@@ -589,16 +601,17 @@ class GPUSubgroupReduceConversion final
   LogicalResult
   matchAndRewrite(gpu::SubgroupReduceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    if (op.getClusterSize())
+    if (op.getClusterStride() > 1) {
       return rewriter.notifyMatchFailure(
-          op, "lowering for clustered reduce not implemented");
+          op, "lowering for cluster stride > 1 is not implemented");
+    }
 
     if (!isa<spirv::ScalarType>(adaptor.getValue().getType()))
       return rewriter.notifyMatchFailure(op, "reduction type is not a scalar");
 
-    auto result = createGroupReduceOp(rewriter, op.getLoc(), adaptor.getValue(),
-                                      adaptor.getOp(),
-                                      /*isGroup=*/false, adaptor.getUniform());
+    auto result = createGroupReduceOp(
+        rewriter, op.getLoc(), adaptor.getValue(), adaptor.getOp(),
+        /*isGroup=*/false, adaptor.getUniform(), op.getClusterSize());
     if (!result)
       return failure();
 
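As a usage sketch (again not from the commit): with this change the pattern
matches clustered subgroup reductions with unit stride, e.g.

    // Hypothetical input that now lowers to a spirv.GroupNonUniform* op with
    // ClusteredReduce; assembly syntax may vary across MLIR versions.
    %sum = gpu.subgroup_reduce add %val cluster(size = 4) : (f32) -> f32

whereas a reduction with cluster(size = 4, stride = 2) still fails to match,
since SPIR-V's ClusteredReduce group operation only supports contiguous
clusters of invocations.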