Skip to content

Commit 7e1422c

Browse files
committed
[DAGCombiner] Fold step_vector with add/mul/shl
This patch implements the following DAG combines for STEP_VECTOR:
  add step_vector(C1), step_vector(C2) -> step_vector(C1+C2)
  add (add X, step_vector(C1)), step_vector(C2) -> add X, step_vector(C1+C2)
  mul step_vector(C1), C2 -> step_vector(C1*C2)
  shl step_vector(C1), C2 -> step_vector(C1<<C2)
TestPlan: check-llvm
Differential Revision: https://reviews.llvm.org/D100088
1 parent ea14df6 commit 7e1422c

File tree

2 files changed

+100
-0
lines changed

2 files changed

+100
-0
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2503,6 +2503,31 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
25032503
return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), VS);
25042504
}
25052505

2506+
// Fold (add step_vector(c1), step_vector(c2)) to step_vector(c1+c2)
2507+
if (N0.getOpcode() == ISD::STEP_VECTOR &&
2508+
N1.getOpcode() == ISD::STEP_VECTOR) {
2509+
const APInt &C0 = N0->getConstantOperandAPInt(0);
2510+
const APInt &C1 = N1->getConstantOperandAPInt(0);
2511+
EVT SVT = N0.getOperand(0).getValueType();
2512+
SDValue NewStep = DAG.getConstant(C0 + C1, DL, SVT);
2513+
return DAG.getStepVector(DL, VT, NewStep);
2514+
}
2515+
2516+
// Fold a + step_vector(c1) + step_vector(c2) to a + step_vector(c1+c2)
2517+
if ((N0.getOpcode() == ISD::ADD) &&
2518+
(N0.getOperand(1).getOpcode() == ISD::STEP_VECTOR) &&
2519+
(N1.getOpcode() == ISD::STEP_VECTOR)) {
2520+
const APInt &SV0 = N0.getOperand(1)->getConstantOperandAPInt(0);
2521+
const APInt &SV1 = N1->getConstantOperandAPInt(0);
2522+
EVT SVT = N1.getOperand(0).getValueType();
2523+
assert(N1.getOperand(0).getValueType() ==
2524+
N0.getOperand(1)->getOperand(0).getValueType() &&
2525+
"Different operand types of STEP_VECTOR.");
2526+
SDValue NewStep = DAG.getConstant(SV0 + SV1, DL, SVT);
2527+
SDValue SV = DAG.getStepVector(DL, VT, NewStep);
2528+
return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
2529+
}
2530+
25062531
return SDValue();
25072532
}
25082533

@@ -3893,6 +3918,17 @@ SDValue DAGCombiner::visitMUL(SDNode *N) {
38933918
return DAG.getVScale(SDLoc(N), VT, C0 * C1);
38943919
}
38953920

3921+
// Fold (mul step_vector(C0), C1) to (step_vector(C0 * C1)).
3922+
APInt MulVal;
3923+
if (N0.getOpcode() == ISD::STEP_VECTOR)
3924+
if (ISD::isConstantSplatVector(N1.getNode(), MulVal)) {
3925+
const APInt &C0 = N0.getConstantOperandAPInt(0);
3926+
EVT SVT = N0.getOperand(0).getValueType();
3927+
SDValue NewStep = DAG.getConstant(
3928+
C0 * MulVal.sextOrTrunc(SVT.getSizeInBits()), SDLoc(N), SVT);
3929+
return DAG.getStepVector(SDLoc(N), VT, NewStep);
3930+
}
3931+
38963932
// Fold ((mul x, 0/undef) -> 0,
38973933
// (mul x, 1) -> x) -> x)
38983934
// -> and(x, mask)
@@ -8381,6 +8417,17 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
83818417
return DAG.getVScale(SDLoc(N), VT, C0 << C1);
83828418
}
83838419

8420+
// Fold (shl step_vector(C0), C1) to (step_vector(C0 << C1)).
8421+
APInt ShlVal;
8422+
if (N0.getOpcode() == ISD::STEP_VECTOR)
8423+
if (ISD::isConstantSplatVector(N1.getNode(), ShlVal)) {
8424+
const APInt &C0 = N0.getConstantOperandAPInt(0);
8425+
EVT SVT = N0.getOperand(0).getValueType();
8426+
SDValue NewStep = DAG.getConstant(
8427+
C0 << ShlVal.sextOrTrunc(SVT.getSizeInBits()), SDLoc(N), SVT);
8428+
return DAG.getStepVector(SDLoc(N), VT, NewStep);
8429+
}
8430+
83848431
return SDValue();
83858432
}
83868433

llvm/test/CodeGen/AArch64/sve-stepvector.ll

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,59 @@ entry:
105105
ret <vscale x 8 x i8> %0
106106
}
107107

108+
define <vscale x 8 x i8> @add_stepvector_nxv8i8() {
109+
; CHECK-LABEL: add_stepvector_nxv8i8:
110+
; CHECK: // %bb.0: // %entry
111+
; CHECK-NEXT: index z0.h, #0, #2
112+
; CHECK-NEXT: ret
113+
entry:
114+
%0 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
115+
%1 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
116+
%2 = add <vscale x 8 x i8> %0, %1
117+
ret <vscale x 8 x i8> %2
118+
}
119+
120+
define <vscale x 8 x i8> @add_stepvector_nxv8i8_1(<vscale x 8 x i8> %p) {
121+
; CHECK-LABEL: add_stepvector_nxv8i8_1:
122+
; CHECK: // %bb.0: // %entry
123+
; CHECK-NEXT: index z1.h, #0, #2
124+
; CHECK-NEXT: add z0.h, z0.h, z1.h
125+
; CHECK-NEXT: ret
126+
entry:
127+
%0 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
128+
%1 = add <vscale x 8 x i8> %p, %0
129+
%2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
130+
%3 = add <vscale x 8 x i8> %1, %2
131+
ret <vscale x 8 x i8> %3
132+
}
133+
134+
define <vscale x 8 x i8> @mul_stepvector_nxv8i8() {
135+
; CHECK-LABEL: mul_stepvector_nxv8i8:
136+
; CHECK: // %bb.0: // %entry
137+
; CHECK-NEXT: index z0.h, #0, #2
138+
; CHECK-NEXT: ret
139+
entry:
140+
%0 = insertelement <vscale x 8 x i8> poison, i8 2, i32 0
141+
%1 = shufflevector <vscale x 8 x i8> %0, <vscale x 8 x i8> poison, <vscale x 8 x i32> zeroinitializer
142+
%2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
143+
%3 = mul <vscale x 8 x i8> %2, %1
144+
ret <vscale x 8 x i8> %3
145+
}
146+
147+
define <vscale x 8 x i8> @shl_stepvector_nxv8i8() {
148+
; CHECK-LABEL: shl_stepvector_nxv8i8:
149+
; CHECK: // %bb.0: // %entry
150+
; CHECK-NEXT: index z0.h, #0, #4
151+
; CHECK-NEXT: ret
152+
entry:
153+
%0 = insertelement <vscale x 8 x i8> poison, i8 2, i32 0
154+
%1 = shufflevector <vscale x 8 x i8> %0, <vscale x 8 x i8> poison, <vscale x 8 x i32> zeroinitializer
155+
%2 = call <vscale x 8 x i8> @llvm.experimental.stepvector.nxv8i8()
156+
%3 = shl <vscale x 8 x i8> %2, %1
157+
ret <vscale x 8 x i8> %3
158+
}
159+
160+
108161
declare <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
109162
declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()
110163
declare <vscale x 8 x i16> @llvm.experimental.stepvector.nxv8i16()

0 commit comments

Comments
 (0)