Skip to content

Commit 702c4ad

Browse files
[ISD::IndexType] Helper functions for common queries.
Add helper functions to query the signed and scaled properties of ISD::IndexType along with functions to change them. Remove setIndexType from MaskedGatherSDNode because it only has one usage and typically should only be changed alongside its index operand. Minimise the direct use of the enum values to lay the groundwork for more refactoring. Differential Revision: https://reviews.llvm.org/D123347
1 parent 1c5e85b commit 702c4ad

File tree

5 files changed

+56
-59
lines changed

5 files changed

+56
-59
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,6 +1358,26 @@ enum MemIndexType {
13581358

13591359
static const int LAST_MEM_INDEX_TYPE = UNSIGNED_UNSCALED + 1;
13601360

1361+
inline bool isIndexTypeScaled(MemIndexType IndexType) {
1362+
return IndexType == SIGNED_SCALED || IndexType == UNSIGNED_SCALED;
1363+
}
1364+
1365+
inline bool isIndexTypeSigned(MemIndexType IndexType) {
1366+
return IndexType == SIGNED_SCALED || IndexType == SIGNED_UNSCALED;
1367+
}
1368+
1369+
inline MemIndexType getSignedIndexType(MemIndexType IndexType) {
1370+
return isIndexTypeScaled(IndexType) ? SIGNED_SCALED : SIGNED_UNSCALED;
1371+
}
1372+
1373+
inline MemIndexType getUnsignedIndexType(MemIndexType IndexType) {
1374+
return isIndexTypeScaled(IndexType) ? UNSIGNED_SCALED : UNSIGNED_UNSCALED;
1375+
}
1376+
1377+
inline MemIndexType getUnscaledIndexType(MemIndexType IndexType) {
1378+
return isIndexTypeSigned(IndexType) ? SIGNED_UNSCALED : UNSIGNED_UNSCALED;
1379+
}
1380+
13611381
//===--------------------------------------------------------------------===//
13621382
/// LoadExtType enum - This enum defines the three variants of LOADEXT
13631383
/// (load with extension).

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,14 +2702,8 @@ class VPGatherScatterSDNode : public MemSDNode {
27022702
ISD::MemIndexType getIndexType() const {
27032703
return static_cast<ISD::MemIndexType>(LSBaseSDNodeBits.AddressingMode);
27042704
}
2705-
bool isIndexScaled() const {
2706-
return (getIndexType() == ISD::SIGNED_SCALED) ||
2707-
(getIndexType() == ISD::UNSIGNED_SCALED);
2708-
}
2709-
bool isIndexSigned() const {
2710-
return (getIndexType() == ISD::SIGNED_SCALED) ||
2711-
(getIndexType() == ISD::SIGNED_UNSCALED);
2712-
}
2705+
bool isIndexScaled() const { return isIndexTypeScaled(getIndexType()); }
2706+
bool isIndexSigned() const { return isIndexTypeSigned(getIndexType()); }
27132707

27142708
// In the both nodes address is Op1, mask is Op2:
27152709
// VPGatherSDNode (Chain, base, index, scale, mask, vlen)
@@ -2790,17 +2784,8 @@ class MaskedGatherScatterSDNode : public MemSDNode {
27902784
ISD::MemIndexType getIndexType() const {
27912785
return static_cast<ISD::MemIndexType>(LSBaseSDNodeBits.AddressingMode);
27922786
}
2793-
void setIndexType(ISD::MemIndexType IndexType) {
2794-
LSBaseSDNodeBits.AddressingMode = IndexType;
2795-
}
2796-
bool isIndexScaled() const {
2797-
return (getIndexType() == ISD::SIGNED_SCALED) ||
2798-
(getIndexType() == ISD::UNSIGNED_SCALED);
2799-
}
2800-
bool isIndexSigned() const {
2801-
return (getIndexType() == ISD::SIGNED_SCALED) ||
2802-
(getIndexType() == ISD::SIGNED_UNSCALED);
2803-
}
2787+
bool isIndexScaled() const { return isIndexTypeScaled(getIndexType()); }
2788+
bool isIndexSigned() const { return isIndexTypeSigned(getIndexType()); }
28042789

28052790
// In the both nodes address is Op1, mask is Op2:
28062791
// MaskedGatherSDNode (Chain, passthru, mask, base, index, scale)

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10471,24 +10471,27 @@ bool refineUniformBase(SDValue &BasePtr, SDValue &Index, bool IndexIsScaled,
1047110471
}
1047210472

1047310473
// Fold sext/zext of index into index type.
10474-
bool refineIndexType(MaskedGatherScatterSDNode *MGS, SDValue &Index,
10475-
bool Scaled, bool Signed, SelectionDAG &DAG) {
10474+
bool refineIndexType(SDValue &Index, ISD::MemIndexType &IndexType,
10475+
SelectionDAG &DAG) {
1047610476
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
1047710477

1047810478
// It's always safe to look through zero extends.
1047910479
if (Index.getOpcode() == ISD::ZERO_EXTEND) {
1048010480
SDValue Op = Index.getOperand(0);
10481-
MGS->setIndexType(Scaled ? ISD::UNSIGNED_SCALED : ISD::UNSIGNED_UNSCALED);
1048210481
if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
10482+
IndexType = ISD::getUnsignedIndexType(IndexType);
1048310483
Index = Op;
1048410484
return true;
10485+
} else if (ISD::isIndexTypeSigned(IndexType)) {
10486+
IndexType = ISD::getUnsignedIndexType(IndexType);
10487+
return true;
1048510488
}
1048610489
}
1048710490

1048810491
// It's only safe to look through sign extends when Index is signed.
10489-
if (Index.getOpcode() == ISD::SIGN_EXTEND && Signed) {
10492+
if (Index.getOpcode() == ISD::SIGN_EXTEND &&
10493+
ISD::isIndexTypeSigned(IndexType)) {
1049010494
SDValue Op = Index.getOperand(0);
10491-
MGS->setIndexType(Scaled ? ISD::SIGNED_SCALED : ISD::SIGNED_UNSCALED);
1049210495
if (TLI.shouldRemoveExtendFromGSIndex(Op.getValueType())) {
1049310496
Index = Op;
1049410497
return true;
@@ -10506,6 +10509,7 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
1050610509
SDValue Scale = MSC->getScale();
1050710510
SDValue StoreVal = MSC->getValue();
1050810511
SDValue BasePtr = MSC->getBasePtr();
10512+
ISD::MemIndexType IndexType = MSC->getIndexType();
1050910513
SDLoc DL(N);
1051010514

1051110515
// Zap scatters with a zero mask.
@@ -10514,17 +10518,16 @@ SDValue DAGCombiner::visitMSCATTER(SDNode *N) {
1051410518

1051510519
if (refineUniformBase(BasePtr, Index, MSC->isIndexScaled(), DAG)) {
1051610520
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
10517-
return DAG.getMaskedScatter(
10518-
DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
10519-
MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
10521+
return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
10522+
DL, Ops, MSC->getMemOperand(), IndexType,
10523+
MSC->isTruncatingStore());
1052010524
}
1052110525

10522-
if (refineIndexType(MSC, Index, MSC->isIndexScaled(), MSC->isIndexSigned(),
10523-
DAG)) {
10526+
if (refineIndexType(Index, IndexType, DAG)) {
1052410527
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
10525-
return DAG.getMaskedScatter(
10526-
DAG.getVTList(MVT::Other), MSC->getMemoryVT(), DL, Ops,
10527-
MSC->getMemOperand(), MSC->getIndexType(), MSC->isTruncatingStore());
10528+
return DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MSC->getMemoryVT(),
10529+
DL, Ops, MSC->getMemOperand(), IndexType,
10530+
MSC->isTruncatingStore());
1052810531
}
1052910532

1053010533
return SDValue();
@@ -10602,6 +10605,7 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
1060210605
SDValue Scale = MGT->getScale();
1060310606
SDValue PassThru = MGT->getPassThru();
1060410607
SDValue BasePtr = MGT->getBasePtr();
10608+
ISD::MemIndexType IndexType = MGT->getIndexType();
1060510609
SDLoc DL(N);
1060610610

1060710611
// Zap gathers with a zero mask.
@@ -10610,19 +10614,16 @@ SDValue DAGCombiner::visitMGATHER(SDNode *N) {
1061010614

1061110615
if (refineUniformBase(BasePtr, Index, MGT->isIndexScaled(), DAG)) {
1061210616
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
10613-
return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
10614-
MGT->getMemoryVT(), DL, Ops,
10615-
MGT->getMemOperand(), MGT->getIndexType(),
10616-
MGT->getExtensionType());
10617+
return DAG.getMaskedGather(
10618+
DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
10619+
Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
1061710620
}
1061810621

10619-
if (refineIndexType(MGT, Index, MGT->isIndexScaled(), MGT->isIndexSigned(),
10620-
DAG)) {
10622+
if (refineIndexType(Index, IndexType, DAG)) {
1062110623
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
10622-
return DAG.getMaskedGather(DAG.getVTList(N->getValueType(0), MVT::Other),
10623-
MGT->getMemoryVT(), DL, Ops,
10624-
MGT->getMemOperand(), MGT->getIndexType(),
10625-
MGT->getExtensionType());
10624+
return DAG.getMaskedGather(
10625+
DAG.getVTList(N->getValueType(0), MVT::Other), MGT->getMemoryVT(), DL,
10626+
Ops, MGT->getMemOperand(), IndexType, MGT->getExtensionType());
1062610627
}
1062710628

1062810629
return SDValue();

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8613,14 +8613,9 @@ SDValue TargetLowering::lowerCmpEqZeroToCtlzSrl(SDValue Op,
86138613
ISD::MemIndexType
86148614
TargetLowering::getCanonicalIndexType(ISD::MemIndexType IndexType, EVT MemVT,
86158615
SDValue Offsets) const {
8616-
bool IsScaledIndex =
8617-
(IndexType == ISD::SIGNED_SCALED) || (IndexType == ISD::UNSIGNED_SCALED);
8618-
bool IsSignedIndex =
8619-
(IndexType == ISD::SIGNED_SCALED) || (IndexType == ISD::SIGNED_UNSCALED);
8620-
86218616
// Scaling is unimportant for bytes, canonicalize to unscaled.
8622-
if (IsScaledIndex && MemVT.getScalarType() == MVT::i8)
8623-
return IsSignedIndex ? ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED;
8617+
if (ISD::isIndexTypeScaled(IndexType) && MemVT.getScalarType() == MVT::i8)
8618+
return ISD::getUnscaledIndexType(IndexType);
86248619

86258620
return IndexType;
86268621
}

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4711,10 +4711,8 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
47114711
return DAG.getMergeValues({Select, Load.getValue(1)}, DL);
47124712
}
47134713

4714-
bool IsScaled =
4715-
IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
4716-
bool IsSigned =
4717-
IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
4714+
bool IsScaled = MGT->isIndexScaled();
4715+
bool IsSigned = MGT->isIndexSigned();
47184716

47194717
// SVE supports an index scaled by sizeof(MemVT.elt) only, everything else
47204718
// must be calculated before hand.
@@ -4727,7 +4725,7 @@ SDValue AArch64TargetLowering::LowerMGATHER(SDValue Op,
47274725
Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
47284726

47294727
SDValue Ops[] = {Chain, PassThru, Mask, BasePtr, Index, Scale};
4730-
IndexType = IsSigned ? ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED;
4728+
IndexType = getUnscaledIndexType(IndexType);
47314729
return DAG.getMaskedGather(MGT->getVTList(), MemVT, DL, Ops,
47324730
MGT->getMemOperand(), IndexType, ExtType);
47334731
}
@@ -4812,10 +4810,8 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
48124810
EVT MemVT = MSC->getMemoryVT();
48134811
ISD::MemIndexType IndexType = MSC->getIndexType();
48144812

4815-
bool IsScaled =
4816-
IndexType == ISD::SIGNED_SCALED || IndexType == ISD::UNSIGNED_SCALED;
4817-
bool IsSigned =
4818-
IndexType == ISD::SIGNED_SCALED || IndexType == ISD::SIGNED_UNSCALED;
4813+
bool IsScaled = MSC->isIndexScaled();
4814+
bool IsSigned = MSC->isIndexSigned();
48194815

48204816
// SVE supports an index scaled by sizeof(MemVT.elt) only, everything else
48214817
// must be calculated before hand.
@@ -4828,7 +4824,7 @@ SDValue AArch64TargetLowering::LowerMSCATTER(SDValue Op,
48284824
Scale = DAG.getTargetConstant(1, DL, Scale.getValueType());
48294825

48304826
SDValue Ops[] = {Chain, StoreVal, Mask, BasePtr, Index, Scale};
4831-
IndexType = IsSigned ? ISD::SIGNED_UNSCALED : ISD::UNSIGNED_UNSCALED;
4827+
IndexType = getUnscaledIndexType(IndexType);
48324828
return DAG.getMaskedScatter(MSC->getVTList(), MemVT, DL, Ops,
48334829
MSC->getMemOperand(), IndexType,
48344830
MSC->isTruncatingStore());

0 commit comments

Comments
 (0)