Skip to content

Commit bcd8009

Browse files
committed
[Attributor] Use the proper context instruction in genericValueTraversal
There was a TODO in genericValueTraversal to provide the context instruction and due to the lack of it users that wanted one just used something available. Unfortunately, using a fixed instruction is wrong in the presence of PHIs so we need to update the context instruction properly. Reviewed By: uenoku Differential Revision: https://reviews.llvm.org/D76870
1 parent ac96c8f commit bcd8009

File tree

2 files changed

+123
-43
lines changed

2 files changed

+123
-43
lines changed

llvm/lib/Transforms/IPO/Attributor.cpp

Lines changed: 53 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -398,8 +398,10 @@ static Value *constructPointer(Type *ResTy, Value *Ptr, int64_t Offset,
398398
template <typename AAType, typename StateTy>
399399
static bool genericValueTraversal(
400400
Attributor &A, IRPosition IRP, const AAType &QueryingAA, StateTy &State,
401-
function_ref<bool(Value &, StateTy &, bool)> VisitValueCB,
402-
int MaxValues = 8, function_ref<Value *(Value *)> StripCB = nullptr) {
401+
function_ref<bool(Value &, const Instruction *, StateTy &, bool)>
402+
VisitValueCB,
403+
const Instruction *CtxI, int MaxValues = 16,
404+
function_ref<Value *(Value *)> StripCB = nullptr) {
403405

404406
const AAIsDead *LivenessAA = nullptr;
405407
if (IRP.getAnchorScope())
@@ -408,20 +410,22 @@ static bool genericValueTraversal(
408410
/* TrackDependence */ false);
409411
bool AnyDead = false;
410412

411-
// TODO: Use Positions here to allow context sensitivity in VisitValueCB
412-
SmallPtrSet<Value *, 16> Visited;
413-
SmallVector<Value *, 16> Worklist;
414-
Worklist.push_back(&IRP.getAssociatedValue());
413+
using Item = std::pair<Value *, const Instruction *>;
414+
SmallSet<Item, 16> Visited;
415+
SmallVector<Item, 16> Worklist;
416+
Worklist.push_back({&IRP.getAssociatedValue(), CtxI});
415417

416418
int Iteration = 0;
417419
do {
418-
Value *V = Worklist.pop_back_val();
420+
Item I = Worklist.pop_back_val();
421+
Value *V = I.first;
422+
CtxI = I.second;
419423
if (StripCB)
420424
V = StripCB(V);
421425

422426
// Check if we should process the current value. To prevent endless
423427
// recursion keep a record of the values we followed!
424-
if (!Visited.insert(V).second)
428+
if (!Visited.insert(I).second)
425429
continue;
426430

427431
// Make sure we limit the compile time for complex expressions.
@@ -444,14 +448,14 @@ static bool genericValueTraversal(
444448
}
445449
}
446450
if (NewV && NewV != V) {
447-
Worklist.push_back(NewV);
451+
Worklist.push_back({NewV, CtxI});
448452
continue;
449453
}
450454

451455
// Look through select instructions, visit both potential values.
452456
if (auto *SI = dyn_cast<SelectInst>(V)) {
453-
Worklist.push_back(SI->getTrueValue());
454-
Worklist.push_back(SI->getFalseValue());
457+
Worklist.push_back({SI->getTrueValue(), CtxI});
458+
Worklist.push_back({SI->getFalseValue(), CtxI});
455459
continue;
456460
}
457461

@@ -460,20 +464,21 @@ static bool genericValueTraversal(
460464
assert(LivenessAA &&
461465
"Expected liveness in the presence of instructions!");
462466
for (unsigned u = 0, e = PHI->getNumIncomingValues(); u < e; u++) {
463-
const BasicBlock *IncomingBB = PHI->getIncomingBlock(u);
467+
BasicBlock *IncomingBB = PHI->getIncomingBlock(u);
464468
if (A.isAssumedDead(*IncomingBB->getTerminator(), &QueryingAA,
465469
LivenessAA,
466470
/* CheckBBLivenessOnly */ true)) {
467471
AnyDead = true;
468472
continue;
469473
}
470-
Worklist.push_back(PHI->getIncomingValue(u));
474+
Worklist.push_back(
475+
{PHI->getIncomingValue(u), IncomingBB->getTerminator()});
471476
}
472477
continue;
473478
}
474479

475480
// Once a leaf is reached we inform the user through the callback.
476-
if (!VisitValueCB(*V, State, Iteration > 1))
481+
if (!VisitValueCB(*V, CtxI, State, Iteration > 1))
477482
return false;
478483
} while (!Worklist.empty());
479484

@@ -710,7 +715,7 @@ void IRPosition::getAttrs(ArrayRef<Attribute::AttrKind> AKs,
710715
}
711716
if (A)
712717
for (Attribute::AttrKind AK : AKs)
713-
getAttrsFromAssumes(AK, Attrs, *A);
718+
getAttrsFromAssumes(AK, Attrs, *A);
714719
}
715720

716721
bool IRPosition::getAttrsFromIRAttr(Attribute::AttrKind AK,
@@ -1466,7 +1471,8 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
14661471
};
14671472

14681473
// Callback for a leaf value returned by the associated function.
1469-
auto VisitValueCB = [](Value &Val, RVState &RVS, bool) -> bool {
1474+
auto VisitValueCB = [](Value &Val, const Instruction *, RVState &RVS,
1475+
bool) -> bool {
14701476
auto Size = RVS.RetValsMap[&Val].size();
14711477
RVS.RetValsMap[&Val].insert(RVS.RetInsts.begin(), RVS.RetInsts.end());
14721478
bool Inserted = RVS.RetValsMap[&Val].size() != Size;
@@ -1480,18 +1486,19 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
14801486
};
14811487

14821488
// Helper method to invoke the generic value traversal.
1483-
auto VisitReturnedValue = [&](Value &RV, RVState &RVS) {
1489+
auto VisitReturnedValue = [&](Value &RV, RVState &RVS,
1490+
const Instruction *CtxI) {
14841491
IRPosition RetValPos = IRPosition::value(RV);
1485-
return genericValueTraversal<AAReturnedValues, RVState>(A, RetValPos, *this,
1486-
RVS, VisitValueCB);
1492+
return genericValueTraversal<AAReturnedValues, RVState>(
1493+
A, RetValPos, *this, RVS, VisitValueCB, CtxI);
14871494
};
14881495

14891496
// Callback for all "return intructions" live in the associated function.
14901497
auto CheckReturnInst = [this, &VisitReturnedValue, &Changed](Instruction &I) {
14911498
ReturnInst &Ret = cast<ReturnInst>(I);
14921499
RVState RVS({ReturnedValues, Changed, {}});
14931500
RVS.RetInsts.insert(&Ret);
1494-
return VisitReturnedValue(*Ret.getReturnValue(), RVS);
1501+
return VisitReturnedValue(*Ret.getReturnValue(), RVS, &I);
14951502
};
14961503

14971504
// Start by discovering returned values from all live returned instructions in
@@ -1576,7 +1583,7 @@ ChangeStatus AAReturnedValuesImpl::updateImpl(Attributor &A) {
15761583
// again.
15771584
bool Unused = false;
15781585
RVState RVS({NewRVsMap, Unused, RetValAAIt.second});
1579-
VisitReturnedValue(*CB->getArgOperand(Arg->getArgNo()), RVS);
1586+
VisitReturnedValue(*CB->getArgOperand(Arg->getArgNo()), RVS, CB);
15801587
continue;
15811588
} else if (isa<CallBase>(RetVal)) {
15821589
// Call sites are resolved by the callee attribute over time, no need to
@@ -2148,11 +2155,11 @@ struct AANonNullFloating
21482155
AC = InfoCache.getAnalysisResultForFunction<AssumptionAnalysis>(*Fn);
21492156
}
21502157

2151-
auto VisitValueCB = [&](Value &V, AANonNull::StateType &T,
2152-
bool Stripped) -> bool {
2158+
auto VisitValueCB = [&](Value &V, const Instruction *CtxI,
2159+
AANonNull::StateType &T, bool Stripped) -> bool {
21532160
const auto &AA = A.getAAFor<AANonNull>(*this, IRPosition::value(V));
21542161
if (!Stripped && this == &AA) {
2155-
if (!isKnownNonZero(&V, DL, 0, AC, getCtxI(), DT))
2162+
if (!isKnownNonZero(&V, DL, 0, AC, CtxI, DT))
21562163
T.indicatePessimisticFixpoint();
21572164
} else {
21582165
// Use abstract attribute information.
@@ -2164,8 +2171,8 @@ struct AANonNullFloating
21642171
};
21652172

21662173
StateType T;
2167-
if (!genericValueTraversal<AANonNull, StateType>(A, getIRPosition(), *this,
2168-
T, VisitValueCB))
2174+
if (!genericValueTraversal<AANonNull, StateType>(
2175+
A, getIRPosition(), *this, T, VisitValueCB, getCtxI()))
21692176
return indicatePessimisticFixpoint();
21702177

21712178
return clampStateAndIndicateChange(getState(), T);
@@ -3776,7 +3783,8 @@ struct AADereferenceableFloating
37763783

37773784
const DataLayout &DL = A.getDataLayout();
37783785

3779-
auto VisitValueCB = [&](Value &V, DerefState &T, bool Stripped) -> bool {
3786+
auto VisitValueCB = [&](Value &V, const Instruction *, DerefState &T,
3787+
bool Stripped) -> bool {
37803788
unsigned IdxWidth =
37813789
DL.getIndexSizeInBits(V.getType()->getPointerAddressSpace());
37823790
APInt Offset(IdxWidth, 0);
@@ -3831,7 +3839,7 @@ struct AADereferenceableFloating
38313839

38323840
DerefState T;
38333841
if (!genericValueTraversal<AADereferenceable, DerefState>(
3834-
A, getIRPosition(), *this, T, VisitValueCB))
3842+
A, getIRPosition(), *this, T, VisitValueCB, getCtxI()))
38353843
return indicatePessimisticFixpoint();
38363844

38373845
return Change | clampStateAndIndicateChange(getState(), T);
@@ -4073,8 +4081,8 @@ struct AAAlignFloating : AAFromMustBeExecutedContext<AAAlign, AAAlignImpl> {
40734081

40744082
const DataLayout &DL = A.getDataLayout();
40754083

4076-
auto VisitValueCB = [&](Value &V, AAAlign::StateType &T,
4077-
bool Stripped) -> bool {
4084+
auto VisitValueCB = [&](Value &V, const Instruction *,
4085+
AAAlign::StateType &T, bool Stripped) -> bool {
40784086
const auto &AA = A.getAAFor<AAAlign>(*this, IRPosition::value(V));
40794087
if (!Stripped && this == &AA) {
40804088
// Use only IR information if we did not strip anything.
@@ -4092,7 +4100,7 @@ struct AAAlignFloating : AAFromMustBeExecutedContext<AAAlign, AAAlignImpl> {
40924100

40934101
StateType T;
40944102
if (!genericValueTraversal<AAAlign, StateType>(A, getIRPosition(), *this, T,
4095-
VisitValueCB))
4103+
VisitValueCB, getCtxI()))
40964104
return indicatePessimisticFixpoint();
40974105

40984106
// TODO: If we know we visited all incoming values, thus no are assumed
@@ -4958,7 +4966,8 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl {
49584966
ChangeStatus updateImpl(Attributor &A) override {
49594967
bool HasValueBefore = SimplifiedAssociatedValue.hasValue();
49604968

4961-
auto VisitValueCB = [&](Value &V, bool &, bool Stripped) -> bool {
4969+
auto VisitValueCB = [&](Value &V, const Instruction *CtxI, bool &,
4970+
bool Stripped) -> bool {
49624971
auto &AA = A.getAAFor<AAValueSimplify>(*this, IRPosition::value(V));
49634972
if (!Stripped && this == &AA) {
49644973
// TODO: Look the instruction and check recursively.
@@ -4971,8 +4980,8 @@ struct AAValueSimplifyFloating : AAValueSimplifyImpl {
49714980
};
49724981

49734982
bool Dummy = false;
4974-
if (!genericValueTraversal<AAValueSimplify, bool>(A, getIRPosition(), *this,
4975-
Dummy, VisitValueCB))
4983+
if (!genericValueTraversal<AAValueSimplify, bool>(
4984+
A, getIRPosition(), *this, Dummy, VisitValueCB, getCtxI()))
49764985
if (!askSimplifiedValueForAAValueConstantRange(A))
49774986
return indicatePessimisticFixpoint();
49784987

@@ -6605,7 +6614,8 @@ void AAMemoryLocationImpl::categorizePtrValue(
66056614
return V;
66066615
};
66076616

6608-
auto VisitValueCB = [&](Value &V, AAMemoryLocation::StateType &T,
6617+
auto VisitValueCB = [&](Value &V, const Instruction *,
6618+
AAMemoryLocation::StateType &T,
66096619
bool Stripped) -> bool {
66106620
assert(!isa<GEPOperator>(V) && "GEPs should have been stripped.");
66116621
if (isa<UndefValue>(V))
@@ -6652,7 +6662,7 @@ void AAMemoryLocationImpl::categorizePtrValue(
66526662
};
66536663

66546664
if (!genericValueTraversal<AAMemoryLocation, AAMemoryLocation::StateType>(
6655-
A, IRPosition::value(Ptr), *this, State, VisitValueCB,
6665+
A, IRPosition::value(Ptr), *this, State, VisitValueCB, getCtxI(),
66566666
/* MaxValues */ 32, StripGEPCB)) {
66576667
LLVM_DEBUG(
66586668
dbgs() << "[AAMemoryLocation] Pointer locations not categorized\n");
@@ -7132,7 +7142,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
71327142

71337143
bool calculateBinaryOperator(
71347144
Attributor &A, BinaryOperator *BinOp, IntegerRangeState &T,
7135-
Instruction *CtxI,
7145+
const Instruction *CtxI,
71367146
SmallVectorImpl<const AAValueConstantRange *> &QuerriedAAs) {
71377147
Value *LHS = BinOp->getOperand(0);
71387148
Value *RHS = BinOp->getOperand(1);
@@ -7160,7 +7170,8 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
71607170
}
71617171

71627172
bool calculateCastInst(
7163-
Attributor &A, CastInst *CastI, IntegerRangeState &T, Instruction *CtxI,
7173+
Attributor &A, CastInst *CastI, IntegerRangeState &T,
7174+
const Instruction *CtxI,
71647175
SmallVectorImpl<const AAValueConstantRange *> &QuerriedAAs) {
71657176
assert(CastI->getNumOperands() == 1 && "Expected cast to be unary!");
71667177
// TODO: Allow non integers as well.
@@ -7178,7 +7189,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
71787189

71797190
bool
71807191
calculateCmpInst(Attributor &A, CmpInst *CmpI, IntegerRangeState &T,
7181-
Instruction *CtxI,
7192+
const Instruction *CtxI,
71827193
SmallVectorImpl<const AAValueConstantRange *> &QuerriedAAs) {
71837194
Value *LHS = CmpI->getOperand(0);
71847195
Value *RHS = CmpI->getOperand(1);
@@ -7233,9 +7244,8 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
72337244

72347245
/// See AbstractAttribute::updateImpl(...).
72357246
ChangeStatus updateImpl(Attributor &A) override {
7236-
Instruction *CtxI = getCtxI();
7237-
auto VisitValueCB = [&](Value &V, IntegerRangeState &T,
7238-
bool Stripped) -> bool {
7247+
auto VisitValueCB = [&](Value &V, const Instruction *CtxI,
7248+
IntegerRangeState &T, bool Stripped) -> bool {
72397249
Instruction *I = dyn_cast<Instruction>(&V);
72407250
if (!I || isa<CallBase>(I)) {
72417251

@@ -7285,7 +7295,7 @@ struct AAValueConstantRangeFloating : AAValueConstantRangeImpl {
72857295
IntegerRangeState T(getBitWidth());
72867296

72877297
if (!genericValueTraversal<AAValueConstantRange, IntegerRangeState>(
7288-
A, getIRPosition(), *this, T, VisitValueCB))
7298+
A, getIRPosition(), *this, T, VisitValueCB, getCtxI()))
72897299
return indicatePessimisticFixpoint();
72907300

72917301
return clampStateAndIndicateChange(getState(), T);

llvm/test/Transforms/Attributor/range.ll

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1212,6 +1212,76 @@ define i1 @callee_range_2(i1 %c1, i1 %c2) {
12121212
}
12131213

12141214

1215+
define i32 @ret100() {
1216+
; CHECK-LABEL: define {{[^@]+}}@ret100()
1217+
; CHECK-NEXT: ret i32 100
1218+
;
1219+
ret i32 100
1220+
}
1221+
1222+
define i1 @ctx_adjustment(i32 %V) {
1223+
; OLD_PM-LABEL: define {{[^@]+}}@ctx_adjustment
1224+
; OLD_PM-SAME: (i32 [[V:%.*]])
1225+
; OLD_PM-NEXT: [[C1:%.*]] = icmp sge i32 [[V]], 100
1226+
; OLD_PM-NEXT: br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
1227+
; OLD_PM: if.true:
1228+
; OLD_PM-NEXT: br label [[END:%.*]]
1229+
; OLD_PM: if.false:
1230+
; OLD_PM-NEXT: br label [[END]]
1231+
; OLD_PM: end:
1232+
; OLD_PM-NEXT: [[PHI:%.*]] = phi i32 [ [[V]], [[IF_TRUE]] ], [ 100, [[IF_FALSE]] ]
1233+
; OLD_PM-NEXT: [[C2:%.*]] = icmp sge i32 [[PHI]], 100
1234+
; OLD_PM-NEXT: ret i1 [[C2]]
1235+
;
1236+
; NEW_PM-LABEL: define {{[^@]+}}@ctx_adjustment
1237+
; NEW_PM-SAME: (i32 [[V:%.*]])
1238+
; NEW_PM-NEXT: [[C1:%.*]] = icmp sge i32 [[V]], 100
1239+
; NEW_PM-NEXT: br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
1240+
; NEW_PM: if.true:
1241+
; NEW_PM-NEXT: br label [[END:%.*]]
1242+
; NEW_PM: if.false:
1243+
; NEW_PM-NEXT: br label [[END]]
1244+
; NEW_PM: end:
1245+
; NEW_PM-NEXT: ret i1 true
1246+
;
1247+
; CGSCC_OLD_PM-LABEL: define {{[^@]+}}@ctx_adjustment
1248+
; CGSCC_OLD_PM-SAME: (i32 [[V:%.*]])
1249+
; CGSCC_OLD_PM-NEXT: [[C1:%.*]] = icmp sge i32 [[V]], 100
1250+
; CGSCC_OLD_PM-NEXT: br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
1251+
; CGSCC_OLD_PM: if.true:
1252+
; CGSCC_OLD_PM-NEXT: br label [[END:%.*]]
1253+
; CGSCC_OLD_PM: if.false:
1254+
; CGSCC_OLD_PM-NEXT: br label [[END]]
1255+
; CGSCC_OLD_PM: end:
1256+
; CGSCC_OLD_PM-NEXT: [[PHI:%.*]] = phi i32 [ [[V]], [[IF_TRUE]] ], [ 100, [[IF_FALSE]] ]
1257+
; CGSCC_OLD_PM-NEXT: [[C2:%.*]] = icmp sge i32 [[PHI]], 100
1258+
; CGSCC_OLD_PM-NEXT: ret i1 [[C2]]
1259+
;
1260+
; CGSCC_NEW_PM-LABEL: define {{[^@]+}}@ctx_adjustment
1261+
; CGSCC_NEW_PM-SAME: (i32 [[V:%.*]])
1262+
; CGSCC_NEW_PM-NEXT: [[C1:%.*]] = icmp sge i32 [[V]], 100
1263+
; CGSCC_NEW_PM-NEXT: br i1 [[C1]], label [[IF_TRUE:%.*]], label [[IF_FALSE:%.*]]
1264+
; CGSCC_NEW_PM: if.true:
1265+
; CGSCC_NEW_PM-NEXT: br label [[END:%.*]]
1266+
; CGSCC_NEW_PM: if.false:
1267+
; CGSCC_NEW_PM-NEXT: br label [[END]]
1268+
; CGSCC_NEW_PM: end:
1269+
; CGSCC_NEW_PM-NEXT: ret i1 true
1270+
;
1271+
%c1 = icmp sge i32 %V, 100
1272+
br i1 %c1, label %if.true, label %if.false
1273+
if.true:
1274+
br label %end
1275+
if.false:
1276+
%call = call i32 @ret100()
1277+
br label %end
1278+
end:
1279+
%phi = phi i32 [ %V, %if.true ], [ %call, %if.false ]
1280+
%c2 = icmp sge i32 %phi, 100
1281+
ret i1 %c2
1282+
}
1283+
1284+
12151285
!0 = !{i32 0, i32 10}
12161286
!1 = !{i32 10, i32 100}
12171287
; CHECK: !0 = !{i32 0, i32 10}

0 commit comments

Comments
 (0)