Skip to content

Commit 0554871

Browse files
committed
[LV] Vectorize select min/max index.
Add support for vectorizing loops that select the index of the minimum or maximum element. The patch implements vectorizing those patterns by combining Min/Max and FindFirstIV reductions. It extends matching Min/Max reductions to allow in-loop users that are FindLastIV reductions. It records a flag indicating that the Min/Max reduction is used by another reduction. When creating reduction recipes, we process any reduction that has other reduction users. The reduction using the min/max reduction needs adjusting to compute the correct result: 1. We need to find the first IV for which the condition based on the min/max reduction is true, 2. Compare the partial min/max reduction result to its final value and, 3. Select the lanes of the partial FindLastIV reductions which correspond to the lanes matching the min/max reduction result.
1 parent 6fe32b9 commit 0554871

File tree

5 files changed

+581
-86
lines changed

5 files changed

+581
-86
lines changed

llvm/include/llvm/Analysis/IVDescriptors.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ class RecurrenceDescriptor {
296296
}
297297
}
298298

299+
void setKind(RecurKind NewKind) { Kind = NewKind; }
300+
299301
/// Returns a reference to the instructions used for type-promoting the
300302
/// recurrence.
301303
const SmallPtrSet<Instruction *, 8> &getCastInsts() const { return CastInsts; }
@@ -327,6 +329,10 @@ class RecurrenceDescriptor {
327329
/// AddReductionVar method, this field will be assigned the last met store.
328330
StoreInst *IntermediateStore = nullptr;
329331

332+
/// True if this recurrence is used by another recurrence in the loop. Users
333+
/// need to ensure that the final code-gen accounts for the use in the loop.
334+
bool IsUsedByOtherRecurrence = false;
335+
330336
private:
331337
// The starting value of the recurrence.
332338
// It does not have to be zero!

llvm/lib/Analysis/IVDescriptors.cpp

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ bool RecurrenceDescriptor::AddReductionVar(
255255
// Data used for determining if the recurrence has been type-promoted.
256256
Type *RecurrenceType = Phi->getType();
257257
SmallPtrSet<Instruction *, 4> CastInsts;
258-
unsigned MinWidthCastToRecurrenceType;
258+
unsigned MinWidthCastToRecurrenceType = -1ull;
259259
Instruction *Start = Phi;
260260
bool IsSigned = false;
261261

@@ -310,6 +310,7 @@ bool RecurrenceDescriptor::AddReductionVar(
310310
// This is either:
311311
// * An instruction type other than PHI or the reduction operation.
312312
// * A PHI in the header other than the initial PHI.
313+
bool IsUsedByOtherRecurrence = false;
313314
while (!Worklist.empty()) {
314315
Instruction *Cur = Worklist.pop_back_val();
315316

@@ -371,15 +372,37 @@ bool RecurrenceDescriptor::AddReductionVar(
371372

372373
// Any reduction instruction must be of one of the allowed kinds. We ignore
373374
// the starting value (the Phi or an AND instruction if the Phi has been
374-
// type-promoted).
375+
// type-promoted) and other in-loop users, if they form a FindLastIV
376+
// reduction. In the latter case, the user of the IVDescriptors must account
377+
// for that during codegen.
375378
if (Cur != Start) {
376379
ReduxDesc =
377380
isRecurrenceInstr(TheLoop, Phi, Cur, Kind, ReduxDesc, FuncFMF, SE);
378381
ExactFPMathInst = ExactFPMathInst == nullptr
379382
? ReduxDesc.getExactFPMathInst()
380383
: ExactFPMathInst;
381-
if (!ReduxDesc.isRecurrence())
384+
if (!ReduxDesc.isRecurrence()) {
385+
if (isMinMaxRecurrenceKind(Kind)) {
386+
// If the current recurrence is Min/Max, check if the current user is
387+
// a select that is a FindLastIV reduction. During codegen, this
388+
// recurrence needs to be turned into one that finds the first IV, as
389+
// the value to compare against is a Min/Max recurrence.
390+
auto *Sel = dyn_cast<SelectInst>(Cur);
391+
if (!Sel || !Sel->getType()->isIntegerTy())
392+
return false;
393+
auto *OtherPhi = dyn_cast<PHINode>(Sel->getOperand(2));
394+
if (!OtherPhi)
395+
return false;
396+
auto NewReduxDesc =
397+
isRecurrenceInstr(TheLoop, OtherPhi, Cur, RecurKind::FindLastIV,
398+
ReduxDesc, FuncFMF, SE);
399+
if (NewReduxDesc.isRecurrence()) {
400+
IsUsedByOtherRecurrence = true;
401+
continue;
402+
}
403+
}
382404
return false;
405+
}
383406
// FIXME: FMF is allowed on phi, but propagation is not handled correctly.
384407
if (isa<FPMathOperator>(ReduxDesc.getPatternInst()) && !IsAPhi) {
385408
FastMathFlags CurFMF = ReduxDesc.getPatternInst()->getFastMathFlags();
@@ -503,7 +526,7 @@ bool RecurrenceDescriptor::AddReductionVar(
503526
// pattern or more than just a select and cmp. Zero implies that we saw a
504527
// llvm.min/max intrinsic, which is always OK.
505528
if (isMinMaxRecurrenceKind(Kind) && NumCmpSelectPatternInst != 2 &&
506-
NumCmpSelectPatternInst != 0)
529+
NumCmpSelectPatternInst != 0 && !IsUsedByOtherRecurrence)
507530
return false;
508531

509532
if (isAnyOfRecurrenceKind(Kind) && NumCmpSelectPatternInst != 1)
@@ -535,7 +558,13 @@ bool RecurrenceDescriptor::AddReductionVar(
535558
ExitInstruction = cast<Instruction>(IntermediateStore->getValueOperand());
536559
}
537560

538-
if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction)
561+
if (!FoundStartPHI || !FoundReduxOp)
562+
return false;
563+
564+
if (IsUsedByOtherRecurrence) {
565+
if (ExitInstruction)
566+
return false;
567+
} else if (!ExitInstruction)
539568
return false;
540569

541570
const bool IsOrdered =
@@ -586,8 +615,9 @@ bool RecurrenceDescriptor::AddReductionVar(
586615
// without needing a white list of instructions to ignore.
587616
// This may also be useful for the inloop reductions, if it can be
588617
// kept simple enough.
589-
collectCastInstrs(TheLoop, ExitInstruction, RecurrenceType, CastInsts,
590-
MinWidthCastToRecurrenceType);
618+
if (ExitInstruction)
619+
collectCastInstrs(TheLoop, ExitInstruction, RecurrenceType, CastInsts,
620+
MinWidthCastToRecurrenceType);
591621

592622
// We found a reduction var if we have reached the original phi node and we
593623
// only have a single instruction with out-of-loop users.
@@ -600,7 +630,7 @@ bool RecurrenceDescriptor::AddReductionVar(
600630
FMF, ExactFPMathInst, RecurrenceType, IsSigned,
601631
IsOrdered, CastInsts, MinWidthCastToRecurrenceType);
602632
RedDes = RD;
603-
633+
RedDes.IsUsedByOtherRecurrence = IsUsedByOtherRecurrence;
604634
return true;
605635
}
606636

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8499,7 +8499,7 @@ bool VPRecipeBuilder::getScaledReductions(
84998499
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
85008500
SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
85018501

8502-
if (!CM.TheLoop->contains(RdxExitInstr))
8502+
if (!RdxExitInstr || !CM.TheLoop->contains(RdxExitInstr))
85038503
return false;
85048504

85058505
auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
@@ -9621,6 +9621,75 @@ void LoopVectorizationPlanner::adjustRecipesForReductions(
96219621
PhiR->setOperand(0, Plan->getOrAddLiveIn(RdxDesc.getSentinelValue()));
96229622
}
96239623
}
9624+
9625+
// Check if any reduction is used by another reduction, by starting from
9626+
// ComputeReductionResults and checking if the recurrence descriptor is marked
9627+
// has having another recurrence user.
9628+
for (auto &R : *MiddleVPBB) {
9629+
auto *MinMaxResult = dyn_cast<VPInstruction>(&R);
9630+
if (!MinMaxResult ||
9631+
MinMaxResult->getOpcode() != VPInstruction::ComputeReductionResult)
9632+
continue;
9633+
auto *MinMaxPhiR = cast<VPReductionPHIRecipe>(MinMaxResult->getOperand(0));
9634+
auto &MinMaxRdxDesc = MinMaxPhiR->getRecurrenceDescriptor();
9635+
if (!MinMaxRdxDesc.IsUsedByOtherRecurrence)
9636+
continue;
9637+
9638+
assert(RecurrenceDescriptor::isMinMaxRecurrenceKind(
9639+
MinMaxRdxDesc.getRecurrenceKind()) &&
9640+
"only min/max reductions can be used by other recurrences");
9641+
9642+
SmallVector<VPUser *> Worklist;
9643+
append_range(Worklist, MinMaxPhiR->users());
9644+
VPReductionPHIRecipe *FindIVPhiR = nullptr;
9645+
// Starting from MinMaxPhiR's users, find the other reduction phi using
9646+
// MinMaxPhiR.
9647+
while (!Worklist.empty()) {
9648+
VPUser *Cur = Worklist.pop_back_val();
9649+
if (isa<VPHeaderPHIRecipe>(Cur)) {
9650+
if (Cur != MinMaxPhiR) {
9651+
assert(!FindIVPhiR &&
9652+
"Only the starting MinMaxPhiR or another reduction "
9653+
"phi must be reachable");
9654+
FindIVPhiR = cast<VPReductionPHIRecipe>(Cur);
9655+
}
9656+
continue;
9657+
}
9658+
// Skip recipes outside any region.
9659+
if (!cast<VPSingleDefRecipe>(Cur)->getParent()->getParent())
9660+
continue;
9661+
append_range(Worklist, cast<VPSingleDefRecipe>(Cur)->users());
9662+
}
9663+
9664+
// Find the recipe computing the result of the other reduction.
9665+
VPInstruction *FindIVResult = nullptr;
9666+
for (auto *U : FindIVPhiR->users()) {
9667+
auto *VPI = dyn_cast<VPInstruction>(U);
9668+
if (!VPI || VPI->getOpcode() != VPInstruction::ComputeFindIVResult)
9669+
continue;
9670+
FindIVResult = VPI;
9671+
}
9672+
assert(FindIVResult && "must find a matching ComputeFindIVResult");
9673+
9674+
// The reduction using MinMaxPhiR needs adjusting to compute the correct
9675+
// result:
9676+
// 1. We need to find the first IV for which the condition based on the
9677+
// min/max recurrence is true,
9678+
// 2. Compare the partial min/max reduction result to its final value and,
9679+
// 3. Select the lanes of the partial FindLastIV reductions which
9680+
// correspond to the lanes matching the min/max reduction result.
9681+
FindIVResult->moveAfter(MinMaxResult);
9682+
VPBuilder B(FindIVResult);
9683+
const_cast<RecurrenceDescriptor &>(FindIVPhiR->getRecurrenceDescriptor())
9684+
.setKind(RecurKind::FindFirstIVUMax);
9685+
auto *Cmp = B.createICmp(CmpInst::ICMP_EQ, MinMaxResult->getOperand(1),
9686+
MinMaxResult);
9687+
auto *Sel = B.createSelect(
9688+
Cmp, FindIVResult->getOperand(2),
9689+
Plan->getOrAddLiveIn(
9690+
FindIVPhiR->getRecurrenceDescriptor().getSentinelValue()));
9691+
FindIVResult->setOperand(2, Sel);
9692+
}
96249693
for (VPRecipeBase *R : ToDelete)
96259694
R->eraseFromParent();
96269695

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1804,7 +1804,8 @@ class VPHeaderPHIRecipe : public VPSingleDefRecipe, public VPPhiAccessors {
18041804
~VPHeaderPHIRecipe() override = default;
18051805

18061806
/// Method to support type inquiry through isa, cast, and dyn_cast.
1807-
static inline bool classof(const VPRecipeBase *B) {
1807+
static inline bool classof(const VPUser *U) {
1808+
auto *B = cast<VPRecipeBase>(U);
18081809
return B->getVPDefID() >= VPDef::VPFirstHeaderPHISC &&
18091810
B->getVPDefID() <= VPDef::VPLastHeaderPHISC;
18101811
}
@@ -1813,6 +1814,10 @@ class VPHeaderPHIRecipe : public VPSingleDefRecipe, public VPPhiAccessors {
18131814
return B && B->getVPDefID() >= VPRecipeBase::VPFirstHeaderPHISC &&
18141815
B->getVPDefID() <= VPRecipeBase::VPLastHeaderPHISC;
18151816
}
1817+
static inline bool classof(const VPSingleDefRecipe *B) {
1818+
return B->getVPDefID() >= VPDef::VPFirstHeaderPHISC &&
1819+
B->getVPDefID() <= VPDef::VPLastHeaderPHISC;
1820+
}
18161821

18171822
/// Generate the phi nodes.
18181823
void execute(VPTransformState &State) override = 0;

0 commit comments

Comments
 (0)