Skip to content

Commit 0c09ace

Browse files
committed
[AArch64][GlobalISel] Add push_mul_through_s/zext
This extends the existing push_add_through_zext to handle mul, similar to performVectorExtCombine in SDAG. This allows muls to be pushed up the tree of extends, operating on smaller vector types whilst keeping the result the same (providing there are > 2x bits in the output). matchExtAddvToUdotAddv needs to be adjusted to make sure it keeps generating dot instructions from add(ext(mul(ext, ext))).
1 parent acd264d commit 0c09ace

File tree

6 files changed

+2971
-1903
lines changed

6 files changed

+2971
-1903
lines changed

llvm/lib/Target/AArch64/AArch64Combine.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,10 @@ class push_opcode_through_ext<Instruction opcode, Instruction extOpcode> : GICom
6262

6363
def push_sub_through_zext : push_opcode_through_ext<G_SUB, G_ZEXT>;
6464
def push_add_through_zext : push_opcode_through_ext<G_ADD, G_ZEXT>;
65+
def push_mul_through_zext : push_opcode_through_ext<G_MUL, G_ZEXT>;
6566
def push_sub_through_sext : push_opcode_through_ext<G_SUB, G_SEXT>;
6667
def push_add_through_sext : push_opcode_through_ext<G_ADD, G_SEXT>;
68+
def push_mul_through_sext : push_opcode_through_ext<G_MUL, G_SEXT>;
6769

6870
def AArch64PreLegalizerCombiner: GICombiner<
6971
"AArch64PreLegalizerCombinerImpl", [all_combines,
@@ -75,8 +77,10 @@ def AArch64PreLegalizerCombiner: GICombiner<
7577
ext_uaddv_to_uaddlv,
7678
push_sub_through_zext,
7779
push_add_through_zext,
80+
push_mul_through_zext,
7881
push_sub_through_sext,
79-
push_add_through_sext]> {
82+
push_add_through_sext,
83+
push_mul_through_sext]> {
8084
let CombineAllMethodName = "tryCombineAllImpl";
8185
}
8286

llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
229229
}
230230

231231
// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot(x, y))
232+
// Or vecreduce_add(ext(mul(ext(x), ext(y)))) -> vecreduce_add(udot(x, y))
232233
// Or vecreduce_add(ext(x)) -> vecreduce_add(udot(x, 1))
233234
// Similar to performVecReduceAddCombine in SelectionDAG
234235
bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
@@ -246,31 +247,57 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
246247
if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
247248
return false;
248249

249-
LLT SrcTy;
250-
auto I1Opc = I1->getOpcode();
251-
if (I1Opc == TargetOpcode::G_MUL) {
250+
// Detect mul(ext, ext) with symmetric ext's. If I1Opc is G_ZEXT or G_SEXT then
251+
// the ext's must match the same opcode. It is set to the ext opcode on
252+
// output.
253+
auto tryMatchingMulOfExt = [&MRI](MachineInstr *MI, Register &Out1,
254+
Register &Out2, unsigned &I1Opc) {
252255
// If result of this has more than 1 use, then there is no point in creating
253-
// udot instruction
254-
if (!MRI.hasOneNonDBGUse(MidReg))
256+
// a dot instruction
257+
if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
255258
return false;
256259

257260
MachineInstr *ExtMI1 =
258-
getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
261+
getDefIgnoringCopies(MI->getOperand(1).getReg(), MRI);
259262
MachineInstr *ExtMI2 =
260-
getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
263+
getDefIgnoringCopies(MI->getOperand(2).getReg(), MRI);
261264
LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
262265
LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());
263266

264267
if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
265268
return false;
269+
if ((I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) &&
270+
I1Opc != ExtMI1->getOpcode())
271+
return false;
272+
Out1 = ExtMI1->getOperand(1).getReg();
273+
Out2 = ExtMI2->getOperand(1).getReg();
266274
I1Opc = ExtMI1->getOpcode();
267-
SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
268-
std::get<0>(MatchInfo) = ExtMI1->getOperand(1).getReg();
269-
std::get<1>(MatchInfo) = ExtMI2->getOperand(1).getReg();
275+
return true;
276+
};
277+
278+
LLT SrcTy;
279+
unsigned I1Opc = I1->getOpcode();
280+
if (I1Opc == TargetOpcode::G_MUL) {
281+
Register Out1, Out2;
282+
if (!tryMatchingMulOfExt(I1, Out1, Out2, I1Opc))
283+
return false;
284+
SrcTy = MRI.getType(Out1);
285+
std::get<0>(MatchInfo) = Out1;
286+
std::get<1>(MatchInfo) = Out2;
270287
} else if (I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) {
271-
SrcTy = MRI.getType(I1->getOperand(1).getReg());
272-
std::get<0>(MatchInfo) = I1->getOperand(1).getReg();
273-
std::get<1>(MatchInfo) = 0;
288+
Register I1Op = I1->getOperand(1).getReg();
289+
MachineInstr *M = getDefIgnoringCopies(I1Op, MRI);
290+
Register Out1, Out2;
291+
if (M->getOpcode() == TargetOpcode::G_MUL &&
292+
tryMatchingMulOfExt(M, Out1, Out2, I1Opc)) {
293+
SrcTy = MRI.getType(Out1);
294+
std::get<0>(MatchInfo) = Out1;
295+
std::get<1>(MatchInfo) = Out2;
296+
} else {
297+
SrcTy = MRI.getType(I1Op);
298+
std::get<0>(MatchInfo) = I1Op;
299+
std::get<1>(MatchInfo) = 0;
300+
}
274301
} else {
275302
return false;
276303
}
@@ -553,14 +580,14 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
553580
MI.eraseFromParent();
554581
}
555582

556-
// Pushes ADD/SUB through extend instructions to decrease the number of extend
557-
// instruction at the end by allowing selection of {s|u}addl sooner
558-
559-
// i32 add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
583+
// Pushes ADD/SUB/MUL through extend instructions to decrease the number of
584+
// extend instructions at the end by allowing selection of {s|u}addl sooner:
585+
// i32 add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
560586
bool matchPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
561587
Register DstReg, Register SrcReg1, Register SrcReg2) {
562588
assert((MI.getOpcode() == TargetOpcode::G_ADD ||
563-
MI.getOpcode() == TargetOpcode::G_SUB) &&
589+
MI.getOpcode() == TargetOpcode::G_SUB ||
590+
MI.getOpcode() == TargetOpcode::G_MUL) &&
564591
"Expected a G_ADD or G_SUB instruction\n");
565592

566593
// Deal with vector types only
@@ -594,9 +621,9 @@ void applyPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
594621
B.buildInstr(MI.getOpcode(), {MidTy}, {Ext1Reg, Ext2Reg}).getReg(0);
595622

596623
// G_SUB has to sign-extend the result.
597-
// G_ADD needs to sext from sext and can sext or zext from zext, so the
598-
// original opcode is used.
599-
if (MI.getOpcode() == TargetOpcode::G_ADD)
624+
// G_ADD needs to sext from sext and can sext or zext from zext, and G_MUL
625+
// needs to use the original opcode so the original opcode is used for both.
626+
if (MI.getOpcode() != TargetOpcode::G_SUB)
600627
B.buildInstr(Opc, {DstReg}, {AddReg});
601628
else
602629
B.buildSExt(DstReg, AddReg);

llvm/test/CodeGen/AArch64/aarch64-wide-mul.ll

Lines changed: 41 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,12 @@ define <16 x i32> @mul_i32(<16 x i8> %a, <16 x i8> %b) {
3838
;
3939
; CHECK-GI-LABEL: mul_i32:
4040
; CHECK-GI: // %bb.0: // %entry
41-
; CHECK-GI-NEXT: ushll v2.8h, v0.8b, #0
42-
; CHECK-GI-NEXT: ushll v3.8h, v1.8b, #0
43-
; CHECK-GI-NEXT: ushll2 v4.8h, v0.16b, #0
44-
; CHECK-GI-NEXT: ushll2 v5.8h, v1.16b, #0
45-
; CHECK-GI-NEXT: umull v0.4s, v2.4h, v3.4h
46-
; CHECK-GI-NEXT: umull2 v1.4s, v2.8h, v3.8h
47-
; CHECK-GI-NEXT: umull v2.4s, v4.4h, v5.4h
48-
; CHECK-GI-NEXT: umull2 v3.4s, v4.8h, v5.8h
41+
; CHECK-GI-NEXT: umull v2.8h, v0.8b, v1.8b
42+
; CHECK-GI-NEXT: umull2 v3.8h, v0.16b, v1.16b
43+
; CHECK-GI-NEXT: ushll v0.4s, v2.4h, #0
44+
; CHECK-GI-NEXT: ushll2 v1.4s, v2.8h, #0
45+
; CHECK-GI-NEXT: ushll v2.4s, v3.4h, #0
46+
; CHECK-GI-NEXT: ushll2 v3.4s, v3.8h, #0
4947
; CHECK-GI-NEXT: ret
5048
entry:
5149
%ea = zext <16 x i8> %a to <16 x i32>
@@ -75,26 +73,20 @@ define <16 x i64> @mul_i64(<16 x i8> %a, <16 x i8> %b) {
7573
;
7674
; CHECK-GI-LABEL: mul_i64:
7775
; CHECK-GI: // %bb.0: // %entry
78-
; CHECK-GI-NEXT: ushll v2.8h, v0.8b, #0
79-
; CHECK-GI-NEXT: ushll v3.8h, v1.8b, #0
80-
; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
81-
; CHECK-GI-NEXT: ushll2 v1.8h, v1.16b, #0
82-
; CHECK-GI-NEXT: ushll v4.4s, v2.4h, #0
83-
; CHECK-GI-NEXT: ushll2 v5.4s, v2.8h, #0
84-
; CHECK-GI-NEXT: ushll v2.4s, v3.4h, #0
85-
; CHECK-GI-NEXT: ushll v6.4s, v0.4h, #0
86-
; CHECK-GI-NEXT: ushll2 v3.4s, v3.8h, #0
87-
; CHECK-GI-NEXT: ushll v7.4s, v1.4h, #0
88-
; CHECK-GI-NEXT: ushll2 v16.4s, v0.8h, #0
89-
; CHECK-GI-NEXT: ushll2 v17.4s, v1.8h, #0
90-
; CHECK-GI-NEXT: umull v0.2d, v4.2s, v2.2s
91-
; CHECK-GI-NEXT: umull2 v1.2d, v4.4s, v2.4s
92-
; CHECK-GI-NEXT: umull v2.2d, v5.2s, v3.2s
93-
; CHECK-GI-NEXT: umull2 v3.2d, v5.4s, v3.4s
94-
; CHECK-GI-NEXT: umull v4.2d, v6.2s, v7.2s
95-
; CHECK-GI-NEXT: umull2 v5.2d, v6.4s, v7.4s
96-
; CHECK-GI-NEXT: umull v6.2d, v16.2s, v17.2s
97-
; CHECK-GI-NEXT: umull2 v7.2d, v16.4s, v17.4s
76+
; CHECK-GI-NEXT: umull v2.8h, v0.8b, v1.8b
77+
; CHECK-GI-NEXT: umull2 v0.8h, v0.16b, v1.16b
78+
; CHECK-GI-NEXT: ushll v1.4s, v2.4h, #0
79+
; CHECK-GI-NEXT: ushll2 v3.4s, v2.8h, #0
80+
; CHECK-GI-NEXT: ushll v5.4s, v0.4h, #0
81+
; CHECK-GI-NEXT: ushll2 v7.4s, v0.8h, #0
82+
; CHECK-GI-NEXT: ushll v0.2d, v1.2s, #0
83+
; CHECK-GI-NEXT: ushll2 v1.2d, v1.4s, #0
84+
; CHECK-GI-NEXT: ushll v2.2d, v3.2s, #0
85+
; CHECK-GI-NEXT: ushll2 v3.2d, v3.4s, #0
86+
; CHECK-GI-NEXT: ushll v4.2d, v5.2s, #0
87+
; CHECK-GI-NEXT: ushll2 v5.2d, v5.4s, #0
88+
; CHECK-GI-NEXT: ushll v6.2d, v7.2s, #0
89+
; CHECK-GI-NEXT: ushll2 v7.2d, v7.4s, #0
9890
; CHECK-GI-NEXT: ret
9991
entry:
10092
%ea = zext <16 x i8> %a to <16 x i64>
@@ -142,18 +134,12 @@ define <16 x i32> @mla_i32(<16 x i8> %a, <16 x i8> %b, <16 x i32> %c) {
142134
;
143135
; CHECK-GI-LABEL: mla_i32:
144136
; CHECK-GI: // %bb.0: // %entry
145-
; CHECK-GI-NEXT: ushll v6.8h, v0.8b, #0
146-
; CHECK-GI-NEXT: ushll v7.8h, v1.8b, #0
147-
; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
148-
; CHECK-GI-NEXT: ushll2 v1.8h, v1.16b, #0
149-
; CHECK-GI-NEXT: umlal v2.4s, v6.4h, v7.4h
150-
; CHECK-GI-NEXT: umlal2 v3.4s, v6.8h, v7.8h
151-
; CHECK-GI-NEXT: umlal v4.4s, v0.4h, v1.4h
152-
; CHECK-GI-NEXT: umlal2 v5.4s, v0.8h, v1.8h
153-
; CHECK-GI-NEXT: mov v0.16b, v2.16b
154-
; CHECK-GI-NEXT: mov v1.16b, v3.16b
155-
; CHECK-GI-NEXT: mov v2.16b, v4.16b
156-
; CHECK-GI-NEXT: mov v3.16b, v5.16b
137+
; CHECK-GI-NEXT: umull v6.8h, v0.8b, v1.8b
138+
; CHECK-GI-NEXT: umull2 v7.8h, v0.16b, v1.16b
139+
; CHECK-GI-NEXT: uaddw v0.4s, v2.4s, v6.4h
140+
; CHECK-GI-NEXT: uaddw2 v1.4s, v3.4s, v6.8h
141+
; CHECK-GI-NEXT: uaddw v2.4s, v4.4s, v7.4h
142+
; CHECK-GI-NEXT: uaddw2 v3.4s, v5.4s, v7.8h
157143
; CHECK-GI-NEXT: ret
158144
entry:
159145
%ea = zext <16 x i8> %a to <16 x i32>
@@ -186,35 +172,21 @@ define <16 x i64> @mla_i64(<16 x i8> %a, <16 x i8> %b, <16 x i64> %c) {
186172
;
187173
; CHECK-GI-LABEL: mla_i64:
188174
; CHECK-GI: // %bb.0: // %entry
189-
; CHECK-GI-NEXT: mov v16.16b, v2.16b
190-
; CHECK-GI-NEXT: mov v17.16b, v3.16b
191-
; CHECK-GI-NEXT: mov v2.16b, v4.16b
192-
; CHECK-GI-NEXT: mov v3.16b, v5.16b
193-
; CHECK-GI-NEXT: mov v4.16b, v6.16b
194-
; CHECK-GI-NEXT: mov v5.16b, v7.16b
195-
; CHECK-GI-NEXT: ushll v6.8h, v0.8b, #0
196-
; CHECK-GI-NEXT: ushll v7.8h, v1.8b, #0
197-
; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
198-
; CHECK-GI-NEXT: ushll2 v1.8h, v1.16b, #0
199-
; CHECK-GI-NEXT: ushll v18.4s, v6.4h, #0
200-
; CHECK-GI-NEXT: ushll v20.4s, v7.4h, #0
201-
; CHECK-GI-NEXT: ushll2 v19.4s, v6.8h, #0
202-
; CHECK-GI-NEXT: ushll v21.4s, v0.4h, #0
203-
; CHECK-GI-NEXT: ushll2 v22.4s, v7.8h, #0
204-
; CHECK-GI-NEXT: ushll v23.4s, v1.4h, #0
205-
; CHECK-GI-NEXT: ldp q6, q7, [sp]
206-
; CHECK-GI-NEXT: ushll2 v0.4s, v0.8h, #0
207-
; CHECK-GI-NEXT: ushll2 v1.4s, v1.8h, #0
208-
; CHECK-GI-NEXT: umlal v16.2d, v18.2s, v20.2s
209-
; CHECK-GI-NEXT: umlal2 v17.2d, v18.4s, v20.4s
210-
; CHECK-GI-NEXT: umlal v2.2d, v19.2s, v22.2s
211-
; CHECK-GI-NEXT: umlal2 v3.2d, v19.4s, v22.4s
212-
; CHECK-GI-NEXT: umlal v4.2d, v21.2s, v23.2s
213-
; CHECK-GI-NEXT: umlal2 v5.2d, v21.4s, v23.4s
214-
; CHECK-GI-NEXT: umlal v6.2d, v0.2s, v1.2s
215-
; CHECK-GI-NEXT: umlal2 v7.2d, v0.4s, v1.4s
216-
; CHECK-GI-NEXT: mov v0.16b, v16.16b
217-
; CHECK-GI-NEXT: mov v1.16b, v17.16b
175+
; CHECK-GI-NEXT: umull v16.8h, v0.8b, v1.8b
176+
; CHECK-GI-NEXT: umull2 v0.8h, v0.16b, v1.16b
177+
; CHECK-GI-NEXT: ldp q19, q20, [sp]
178+
; CHECK-GI-NEXT: ushll v1.4s, v16.4h, #0
179+
; CHECK-GI-NEXT: ushll2 v16.4s, v16.8h, #0
180+
; CHECK-GI-NEXT: ushll v17.4s, v0.4h, #0
181+
; CHECK-GI-NEXT: ushll2 v18.4s, v0.8h, #0
182+
; CHECK-GI-NEXT: uaddw v0.2d, v2.2d, v1.2s
183+
; CHECK-GI-NEXT: uaddw2 v1.2d, v3.2d, v1.4s
184+
; CHECK-GI-NEXT: uaddw v2.2d, v4.2d, v16.2s
185+
; CHECK-GI-NEXT: uaddw2 v3.2d, v5.2d, v16.4s
186+
; CHECK-GI-NEXT: uaddw v4.2d, v6.2d, v17.2s
187+
; CHECK-GI-NEXT: uaddw2 v5.2d, v7.2d, v17.4s
188+
; CHECK-GI-NEXT: uaddw v6.2d, v19.2d, v18.2s
189+
; CHECK-GI-NEXT: uaddw2 v7.2d, v20.2d, v18.4s
218190
; CHECK-GI-NEXT: ret
219191
entry:
220192
%ea = zext <16 x i8> %a to <16 x i64>

0 commit comments

Comments
 (0)