Skip to content

Commit b310dd1

Browse files
committed
[AArch64][SVE] Lower index_vector to step_vector
As discussed in D100107, this patch first converts index_vector to step_vector, and then converts step_vector back to index_vector after LegalizeDAG. Differential Revision: https://reviews.llvm.org/D100816
1 parent ba5b015 commit b310dd1

File tree

5 files changed

+187
-75
lines changed

5 files changed

+187
-75
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
902902
setTargetDAGCombine(ISD::INSERT_VECTOR_ELT);
903903
setTargetDAGCombine(ISD::EXTRACT_VECTOR_ELT);
904904
setTargetDAGCombine(ISD::VECREDUCE_ADD);
905+
setTargetDAGCombine(ISD::STEP_VECTOR);
905906

906907
setTargetDAGCombine(ISD::GlobalAddress);
907908

@@ -1151,7 +1152,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
11511152
setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
11521153
setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
11531154
setOperationAction(ISD::VECREDUCE_SMAX, VT, Custom);
1154-
setOperationAction(ISD::STEP_VECTOR, VT, Custom);
11551155

11561156
setOperationAction(ISD::UMUL_LOHI, VT, Expand);
11571157
setOperationAction(ISD::SMUL_LOHI, VT, Expand);
@@ -4476,8 +4476,6 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
44764476
return LowerVECTOR_SHUFFLE(Op, DAG);
44774477
case ISD::SPLAT_VECTOR:
44784478
return LowerSPLAT_VECTOR(Op, DAG);
4479-
case ISD::STEP_VECTOR:
4480-
return LowerSTEP_VECTOR(Op, DAG);
44814479
case ISD::EXTRACT_SUBVECTOR:
44824480
return LowerEXTRACT_SUBVECTOR(Op, DAG);
44834481
case ISD::INSERT_SUBVECTOR:
@@ -9162,20 +9160,6 @@ SDValue AArch64TargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
91629160
return GenerateTBL(Op, ShuffleMask, DAG);
91639161
}
91649162

9165-
SDValue AArch64TargetLowering::LowerSTEP_VECTOR(SDValue Op,
9166-
SelectionDAG &DAG) const {
9167-
SDLoc dl(Op);
9168-
EVT VT = Op.getValueType();
9169-
assert(VT.isScalableVector() &&
9170-
"Only expect scalable vectors for STEP_VECTOR");
9171-
assert(VT.getScalarType() != MVT::i1 &&
9172-
"Vectors of i1 types not supported for STEP_VECTOR");
9173-
9174-
SDValue StepVal = Op.getOperand(0);
9175-
SDValue Zero = DAG.getConstant(0, dl, StepVal.getValueType());
9176-
return DAG.getNode(AArch64ISD::INDEX_VECTOR, dl, VT, Zero, StepVal);
9177-
}
9178-
91799163
SDValue AArch64TargetLowering::LowerSPLAT_VECTOR(SDValue Op,
91809164
SelectionDAG &DAG) const {
91819165
SDLoc dl(Op);
@@ -9261,9 +9245,7 @@ SDValue AArch64TargetLowering::LowerDUPQLane(SDValue Op,
92619245
SDValue SplatOne = DAG.getNode(ISD::SPLAT_VECTOR, DL, MVT::nxv2i64, One);
92629246

92639247
// create the vector 0,1,0,1,...
9264-
SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
9265-
SDValue SV = DAG.getNode(AArch64ISD::INDEX_VECTOR,
9266-
DL, MVT::nxv2i64, Zero, One);
9248+
SDValue SV = DAG.getNode(ISD::STEP_VECTOR, DL, MVT::nxv2i64, One);
92679249
SV = DAG.getNode(ISD::AND, DL, MVT::nxv2i64, SV, SplatOne);
92689250

92699251
// create the vector idx64,idx64+1,idx64,idx64+1,...
@@ -13665,15 +13647,18 @@ static SDValue LowerSVEIntrinsicIndex(SDNode *N, SelectionDAG &DAG) {
1366513647
SDLoc DL(N);
1366613648
SDValue Op1 = N->getOperand(1);
1366713649
SDValue Op2 = N->getOperand(2);
13668-
EVT ScalarTy = Op1.getValueType();
13669-
13670-
if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16)) {
13671-
Op1 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op1);
13672-
Op2 = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Op2);
13673-
}
13650+
EVT ScalarTy = Op2.getValueType();
13651+
if ((ScalarTy == MVT::i8) || (ScalarTy == MVT::i16))
13652+
ScalarTy = MVT::i32;
1367413653

13675-
return DAG.getNode(AArch64ISD::INDEX_VECTOR, DL, N->getValueType(0),
13676-
Op1, Op2);
13654+
// Lower index_vector(base, step) to mul(step_vector(1), splat(step)) + splat(base).
13655+
SDValue One = DAG.getConstant(1, DL, ScalarTy);
13656+
SDValue StepVector =
13657+
DAG.getNode(ISD::STEP_VECTOR, DL, N->getValueType(0), One);
13658+
SDValue Step = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op2);
13659+
SDValue Mul = DAG.getNode(ISD::MUL, DL, N->getValueType(0), StepVector, Step);
13660+
SDValue Base = DAG.getNode(ISD::SPLAT_VECTOR, DL, N->getValueType(0), Op1);
13661+
return DAG.getNode(ISD::ADD, DL, N->getValueType(0), Mul, Base);
1367713662
}
1367813663

1367913664
static SDValue LowerSVEIntrinsicDUP(SDNode *N, SelectionDAG &DAG) {
@@ -15463,6 +15448,19 @@ static SDValue performGlobalAddressCombine(SDNode *N, SelectionDAG &DAG,
1546315448
DAG.getConstant(MinOffset, DL, MVT::i64));
1546415449
}
1546515450

15451+
static SDValue performStepVectorCombine(SDNode *N,
15452+
TargetLowering::DAGCombinerInfo &DCI,
15453+
SelectionDAG &DAG) {
15454+
if (!DCI.isAfterLegalizeDAG())
15455+
return SDValue();
15456+
15457+
SDLoc DL(N);
15458+
EVT VT = N->getValueType(0);
15459+
SDValue StepVal = N->getOperand(0);
15460+
SDValue Zero = DAG.getConstant(0, DL, StepVal.getValueType());
15461+
return DAG.getNode(AArch64ISD::INDEX_VECTOR, DL, VT, Zero, StepVal);
15462+
}
15463+
1546615464
// Turns the vector of indices into a vector of byte offsets by scaling Offset
1546715465
// by (BitWidth / 8).
1546815466
static SDValue getScaledOffsetForBitWidth(SelectionDAG &DAG, SDValue Offset,
@@ -15977,6 +15975,8 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
1597715975
return performExtractVectorEltCombine(N, DAG);
1597815976
case ISD::VECREDUCE_ADD:
1597915977
return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
15978+
case ISD::STEP_VECTOR:
15979+
return performStepVectorCombine(N, DCI, DAG);
1598015980
case ISD::INTRINSIC_VOID:
1598115981
case ISD::INTRINSIC_W_CHAIN:
1598215982
switch (cast<ConstantSDNode>(N->getOperand(1))->getZExtValue()) {

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -938,7 +938,6 @@ class AArch64TargetLowering : public TargetLowering {
938938
SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
939939
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;
940940
SDValue LowerSPLAT_VECTOR(SDValue Op, SelectionDAG &DAG) const;
941-
SDValue LowerSTEP_VECTOR(SDValue Op, SelectionDAG &DAG) const;
942941
SDValue LowerDUPQLane(SDValue Op, SelectionDAG &DAG) const;
943942
SDValue LowerToPredicatedOp(SDValue Op, SelectionDAG &DAG, unsigned NewOp,
944943
bool OverrideNEON = false) const;

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1370,8 +1370,8 @@ let Predicates = [HasSVE] in {
13701370
defm INCP_ZP : sve_int_count_v<0b10000, "incp">;
13711371
defm DECP_ZP : sve_int_count_v<0b10100, "decp">;
13721372

1373-
defm INDEX_RR : sve_int_index_rr<"index", index_vector, index_vector_oneuse>;
1374-
defm INDEX_IR : sve_int_index_ir<"index", index_vector, index_vector_oneuse>;
1373+
defm INDEX_RR : sve_int_index_rr<"index", index_vector, index_vector_oneuse, AArch64mul_p_oneuse>;
1374+
defm INDEX_IR : sve_int_index_ir<"index", index_vector, index_vector_oneuse, AArch64mul_p, AArch64mul_p_oneuse>;
13751375
defm INDEX_RI : sve_int_index_ri<"index", index_vector, index_vector_oneuse>;
13761376
defm INDEX_II : sve_int_index_ii<"index", index_vector, index_vector_oneuse>;
13771377

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4837,7 +4837,7 @@ class sve_int_index_ir<bits<2> sz8_64, string asm, ZPRRegOp zprty,
48374837
let Inst{4-0} = Zd;
48384838
}
48394839

4840-
multiclass sve_int_index_ir<string asm, SDPatternOperator op, SDPatternOperator oneuseop> {
4840+
multiclass sve_int_index_ir<string asm, SDPatternOperator op, SDPatternOperator oneuseop, SDPatternOperator mulop, SDPatternOperator muloneuseop> {
48414841
def _B : sve_int_index_ir<0b00, asm, ZPR8, GPR32, simm5_8b>;
48424842
def _H : sve_int_index_ir<0b01, asm, ZPR16, GPR32, simm5_16b>;
48434843
def _S : sve_int_index_ir<0b10, asm, ZPR32, GPR32, simm5_32b>;
@@ -4862,6 +4862,25 @@ multiclass sve_int_index_ir<string asm, SDPatternOperator op, SDPatternOperator
48624862
def : Pat<(add (nxv2i64 (oneuseop (i64 0), GPR64:$Rm)), (nxv2i64 (AArch64dup(simm5_64b:$imm5)))),
48634863
(!cast<Instruction>(NAME # "_D") simm5_64b:$imm5, GPR64:$Rm)>;
48644864

4865+
// mul(index_vector(0, 1), dup(Y)) -> index_vector(0, Y).
4866+
def : Pat<(mulop (nxv16i1 (AArch64ptrue 31)), (nxv16i8 (oneuseop (i32 0), (i32 1))), (nxv16i8 (AArch64dup(i32 GPR32:$Rm)))),
4867+
(!cast<Instruction>(NAME # "_B") (i32 0), GPR32:$Rm)>;
4868+
def : Pat<(mulop (nxv8i1 (AArch64ptrue 31)), (nxv8i16 (oneuseop (i32 0), (i32 1))), (nxv8i16 (AArch64dup(i32 GPR32:$Rm)))),
4869+
(!cast<Instruction>(NAME # "_H") (i32 0), GPR32:$Rm)>;
4870+
def : Pat<(mulop (nxv4i1 (AArch64ptrue 31)), (nxv4i32 (oneuseop (i32 0), (i32 1))), (nxv4i32 (AArch64dup(i32 GPR32:$Rm)))),
4871+
(!cast<Instruction>(NAME # "_S") (i32 0), GPR32:$Rm)>;
4872+
def : Pat<(mulop (nxv2i1 (AArch64ptrue 31)), (nxv2i64 (oneuseop (i64 0), (i64 1))), (nxv2i64 (AArch64dup(i64 GPR64:$Rm)))),
4873+
(!cast<Instruction>(NAME # "_D") (i64 0), GPR64:$Rm)>;
4874+
4875+
// add(mul(index_vector(0, 1), dup(Y)), dup(X)) -> index_vector(X, Y).
4876+
def : Pat<(add (muloneuseop (nxv16i1 (AArch64ptrue 31)), (nxv16i8 (oneuseop (i32 0), (i32 1))), (nxv16i8 (AArch64dup(i32 GPR32:$Rm)))), (nxv16i8 (AArch64dup(simm5_8b:$imm5)))),
4877+
(!cast<Instruction>(NAME # "_B") simm5_8b:$imm5, GPR32:$Rm)>;
4878+
def : Pat<(add (muloneuseop (nxv8i1 (AArch64ptrue 31)), (nxv8i16 (oneuseop (i32 0), (i32 1))), (nxv8i16 (AArch64dup(i32 GPR32:$Rm)))), (nxv8i16 (AArch64dup(simm5_16b:$imm5)))),
4879+
(!cast<Instruction>(NAME # "_H") simm5_16b:$imm5, GPR32:$Rm)>;
4880+
def : Pat<(add (muloneuseop (nxv4i1 (AArch64ptrue 31)), (nxv4i32 (oneuseop (i32 0), (i32 1))), (nxv4i32 (AArch64dup(i32 GPR32:$Rm)))), (nxv4i32 (AArch64dup(simm5_32b:$imm5)))),
4881+
(!cast<Instruction>(NAME # "_S") simm5_32b:$imm5, GPR32:$Rm)>;
4882+
def : Pat<(add (muloneuseop (nxv2i1 (AArch64ptrue 31)), (nxv2i64 (oneuseop (i64 0), (i64 1))), (nxv2i64 (AArch64dup(i64 GPR64:$Rm)))), (nxv2i64 (AArch64dup(simm5_64b:$imm5)))),
4883+
(!cast<Instruction>(NAME # "_D") simm5_64b:$imm5, GPR64:$Rm)>;
48654884
}
48664885

48674886
class sve_int_index_ri<bits<2> sz8_64, string asm, ZPRRegOp zprty,
@@ -4924,7 +4943,7 @@ class sve_int_index_rr<bits<2> sz8_64, string asm, ZPRRegOp zprty,
49244943
let Inst{4-0} = Zd;
49254944
}
49264945

4927-
multiclass sve_int_index_rr<string asm, SDPatternOperator op, SDPatternOperator oneuseop> {
4946+
multiclass sve_int_index_rr<string asm, SDPatternOperator op, SDPatternOperator oneuseop, SDPatternOperator mulop> {
49284947
def _B : sve_int_index_rr<0b00, asm, ZPR8, GPR32>;
49294948
def _H : sve_int_index_rr<0b01, asm, ZPR16, GPR32>;
49304949
def _S : sve_int_index_rr<0b10, asm, ZPR32, GPR32>;
@@ -4944,6 +4963,16 @@ multiclass sve_int_index_rr<string asm, SDPatternOperator op, SDPatternOperator
49444963
(!cast<Instruction>(NAME # "_S") GPR32:$Rn, GPR32:$Rm)>;
49454964
def : Pat<(add (nxv2i64 (oneuseop (i64 0), GPR64:$Rm)), (nxv2i64 (AArch64dup(i64 GPR64:$Rn)))),
49464965
(!cast<Instruction>(NAME # "_D") GPR64:$Rn, GPR64:$Rm)>;
4966+
4967+
// add(mul(index_vector(0, 1), dup(Y)), dup(X)) -> index_vector(X, Y).
4968+
def : Pat<(add (mulop (nxv16i1 (AArch64ptrue 31)), (nxv16i8 (oneuseop (i32 0), (i32 1))), (nxv16i8 (AArch64dup(i32 GPR32:$Rm)))), (nxv16i8 (AArch64dup(i32 GPR32:$Rn)))),
4969+
(!cast<Instruction>(NAME # "_B") GPR32:$Rn, GPR32:$Rm)>;
4970+
def : Pat<(add (mulop (nxv8i1 (AArch64ptrue 31)), (nxv8i16 (oneuseop (i32 0), (i32 1))), (nxv8i16 (AArch64dup(i32 GPR32:$Rm)))),(nxv8i16 (AArch64dup(i32 GPR32:$Rn)))),
4971+
(!cast<Instruction>(NAME # "_H") GPR32:$Rn, GPR32:$Rm)>;
4972+
def : Pat<(add (mulop (nxv4i1 (AArch64ptrue 31)), (nxv4i32 (oneuseop (i32 0), (i32 1))), (nxv4i32 (AArch64dup(i32 GPR32:$Rm)))),(nxv4i32 (AArch64dup(i32 GPR32:$Rn)))),
4973+
(!cast<Instruction>(NAME # "_S") GPR32:$Rn, GPR32:$Rm)>;
4974+
def : Pat<(add (mulop (nxv2i1 (AArch64ptrue 31)), (nxv2i64 (oneuseop (i64 0), (i64 1))), (nxv2i64 (AArch64dup(i64 GPR64:$Rm)))),(nxv2i64 (AArch64dup(i64 GPR64:$Rn)))),
4975+
(!cast<Instruction>(NAME # "_D") GPR64:$Rn, GPR64:$Rm)>;
49474976
}
49484977

49494978
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)