Skip to content

[AArch64][GlobalISel] Add push_mul_through_s/zext #141551

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion llvm/lib/Target/AArch64/AArch64Combine.td
Original file line number Diff line number Diff line change
@@ -62,8 +62,10 @@ class push_opcode_through_ext<Instruction opcode, Instruction extOpcode> : GICom

def push_sub_through_zext : push_opcode_through_ext<G_SUB, G_ZEXT>;
def push_add_through_zext : push_opcode_through_ext<G_ADD, G_ZEXT>;
def push_mul_through_zext : push_opcode_through_ext<G_MUL, G_ZEXT>;
def push_sub_through_sext : push_opcode_through_ext<G_SUB, G_SEXT>;
def push_add_through_sext : push_opcode_through_ext<G_ADD, G_SEXT>;
def push_mul_through_sext : push_opcode_through_ext<G_MUL, G_SEXT>;

def AArch64PreLegalizerCombiner: GICombiner<
"AArch64PreLegalizerCombinerImpl", [all_combines,
@@ -75,8 +77,10 @@ def AArch64PreLegalizerCombiner: GICombiner<
ext_uaddv_to_uaddlv,
push_sub_through_zext,
push_add_through_zext,
push_mul_through_zext,
push_sub_through_sext,
push_add_through_sext]> {
push_add_through_sext,
push_mul_through_sext]> {
let CombineAllMethodName = "tryCombineAllImpl";
}

69 changes: 48 additions & 21 deletions llvm/lib/Target/AArch64/GISel/AArch64PreLegalizerCombiner.cpp
Original file line number Diff line number Diff line change
@@ -229,6 +229,7 @@ void applyFoldGlobalOffset(MachineInstr &MI, MachineRegisterInfo &MRI,
}

// Combines vecreduce_add(mul(ext(x), ext(y))) -> vecreduce_add(udot(x, y))
// Or vecreduce_add(ext(mul(ext(x), ext(y)))) -> vecreduce_add(udot(x, y))
// Or vecreduce_add(ext(x)) -> vecreduce_add(udot(x, 1))
// Similar to performVecReduceAddCombine in SelectionDAG
bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
@@ -246,31 +247,57 @@ bool matchExtAddvToUdotAddv(MachineInstr &MI, MachineRegisterInfo &MRI,
if (DstTy.getScalarSizeInBits() != 32 || MidTy.getScalarSizeInBits() != 32)
return false;

LLT SrcTy;
auto I1Opc = I1->getOpcode();
if (I1Opc == TargetOpcode::G_MUL) {
// Detect mul(ext, ext) with symmetric exts. If I1Opc is G_ZEXT or G_SEXT then
// the exts must use that same opcode. On success I1Opc is set to the ext
// opcode.
auto tryMatchingMulOfExt = [&MRI](MachineInstr *MI, Register &Out1,
Register &Out2, unsigned &I1Opc) {
// If result of this has more than 1 use, then there is no point in creating
// udot instruction
if (!MRI.hasOneNonDBGUse(MidReg))
// a dot instruction
if (!MRI.hasOneNonDBGUse(MI->getOperand(0).getReg()))
return false;

MachineInstr *ExtMI1 =
getDefIgnoringCopies(I1->getOperand(1).getReg(), MRI);
getDefIgnoringCopies(MI->getOperand(1).getReg(), MRI);
MachineInstr *ExtMI2 =
getDefIgnoringCopies(I1->getOperand(2).getReg(), MRI);
getDefIgnoringCopies(MI->getOperand(2).getReg(), MRI);
LLT Ext1DstTy = MRI.getType(ExtMI1->getOperand(0).getReg());
LLT Ext2DstTy = MRI.getType(ExtMI2->getOperand(0).getReg());

if (ExtMI1->getOpcode() != ExtMI2->getOpcode() || Ext1DstTy != Ext2DstTy)
return false;
if ((I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) &&
I1Opc != ExtMI1->getOpcode())
return false;
Out1 = ExtMI1->getOperand(1).getReg();
Out2 = ExtMI2->getOperand(1).getReg();
I1Opc = ExtMI1->getOpcode();
SrcTy = MRI.getType(ExtMI1->getOperand(1).getReg());
std::get<0>(MatchInfo) = ExtMI1->getOperand(1).getReg();
std::get<1>(MatchInfo) = ExtMI2->getOperand(1).getReg();
return true;
};

LLT SrcTy;
unsigned I1Opc = I1->getOpcode();
if (I1Opc == TargetOpcode::G_MUL) {
Register Out1, Out2;
if (!tryMatchingMulOfExt(I1, Out1, Out2, I1Opc))
return false;
SrcTy = MRI.getType(Out1);
std::get<0>(MatchInfo) = Out1;
std::get<1>(MatchInfo) = Out2;
} else if (I1Opc == TargetOpcode::G_ZEXT || I1Opc == TargetOpcode::G_SEXT) {
SrcTy = MRI.getType(I1->getOperand(1).getReg());
std::get<0>(MatchInfo) = I1->getOperand(1).getReg();
std::get<1>(MatchInfo) = 0;
Register I1Op = I1->getOperand(1).getReg();
MachineInstr *M = getDefIgnoringCopies(I1Op, MRI);
Register Out1, Out2;
if (M->getOpcode() == TargetOpcode::G_MUL &&
tryMatchingMulOfExt(M, Out1, Out2, I1Opc)) {
SrcTy = MRI.getType(Out1);
std::get<0>(MatchInfo) = Out1;
std::get<1>(MatchInfo) = Out2;
} else {
SrcTy = MRI.getType(I1Op);
std::get<0>(MatchInfo) = I1Op;
std::get<1>(MatchInfo) = 0;
}
} else {
return false;
}
@@ -553,14 +580,14 @@ void applyExtUaddvToUaddlv(MachineInstr &MI, MachineRegisterInfo &MRI,
MI.eraseFromParent();
}

// Pushes ADD/SUB through extend instructions to decrease the number of extend
// instruction at the end by allowing selection of {s|u}addl sooner

// i32 add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
// Pushes ADD/SUB/MUL through extend instructions to decrease the number of
// extend instructions at the end by allowing selection of {s|u}addl sooner.
// i32 add(i32 ext i8, i32 ext i8) => i32 ext(i16 add(i16 ext i8, i16 ext i8))
bool matchPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
Register DstReg, Register SrcReg1, Register SrcReg2) {
assert((MI.getOpcode() == TargetOpcode::G_ADD ||
MI.getOpcode() == TargetOpcode::G_SUB) &&
MI.getOpcode() == TargetOpcode::G_SUB ||
MI.getOpcode() == TargetOpcode::G_MUL) &&
"Expected a G_ADD or G_SUB instruction\n");

// Deal with vector types only
@@ -594,9 +621,9 @@ void applyPushAddSubExt(MachineInstr &MI, MachineRegisterInfo &MRI,
B.buildInstr(MI.getOpcode(), {MidTy}, {Ext1Reg, Ext2Reg}).getReg(0);

// G_SUB has to sign-extend the result.
// G_ADD needs to sext from sext and can sext or zext from zext, so the
// original opcode is used.
if (MI.getOpcode() == TargetOpcode::G_ADD)
// G_ADD needs to sext from sext and can sext or zext from zext, and G_MUL
// needs to use the original opcode so the original opcode is used for both.
if (MI.getOpcode() != TargetOpcode::G_SUB)
B.buildInstr(Opc, {DstReg}, {AddReg});
else
B.buildSExt(DstReg, AddReg);
110 changes: 41 additions & 69 deletions llvm/test/CodeGen/AArch64/aarch64-wide-mul.ll
Original file line number Diff line number Diff line change
@@ -38,14 +38,12 @@ define <16 x i32> @mul_i32(<16 x i8> %a, <16 x i8> %b) {
;
; CHECK-GI-LABEL: mul_i32:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-GI-NEXT: ushll v3.8h, v1.8b, #0
; CHECK-GI-NEXT: ushll2 v4.8h, v0.16b, #0
; CHECK-GI-NEXT: ushll2 v5.8h, v1.16b, #0
; CHECK-GI-NEXT: umull v0.4s, v2.4h, v3.4h
; CHECK-GI-NEXT: umull2 v1.4s, v2.8h, v3.8h
; CHECK-GI-NEXT: umull v2.4s, v4.4h, v5.4h
; CHECK-GI-NEXT: umull2 v3.4s, v4.8h, v5.8h
; CHECK-GI-NEXT: umull v2.8h, v0.8b, v1.8b
; CHECK-GI-NEXT: umull2 v3.8h, v0.16b, v1.16b
; CHECK-GI-NEXT: ushll v0.4s, v2.4h, #0
; CHECK-GI-NEXT: ushll2 v1.4s, v2.8h, #0
; CHECK-GI-NEXT: ushll v2.4s, v3.4h, #0
; CHECK-GI-NEXT: ushll2 v3.4s, v3.8h, #0
; CHECK-GI-NEXT: ret
entry:
%ea = zext <16 x i8> %a to <16 x i32>
@@ -75,26 +73,20 @@ define <16 x i64> @mul_i64(<16 x i8> %a, <16 x i8> %b) {
;
; CHECK-GI-LABEL: mul_i64:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-GI-NEXT: ushll v3.8h, v1.8b, #0
; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-GI-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-GI-NEXT: ushll v4.4s, v2.4h, #0
; CHECK-GI-NEXT: ushll2 v5.4s, v2.8h, #0
; CHECK-GI-NEXT: ushll v2.4s, v3.4h, #0
; CHECK-GI-NEXT: ushll v6.4s, v0.4h, #0
; CHECK-GI-NEXT: ushll2 v3.4s, v3.8h, #0
; CHECK-GI-NEXT: ushll v7.4s, v1.4h, #0
; CHECK-GI-NEXT: ushll2 v16.4s, v0.8h, #0
; CHECK-GI-NEXT: ushll2 v17.4s, v1.8h, #0
; CHECK-GI-NEXT: umull v0.2d, v4.2s, v2.2s
; CHECK-GI-NEXT: umull2 v1.2d, v4.4s, v2.4s
; CHECK-GI-NEXT: umull v2.2d, v5.2s, v3.2s
; CHECK-GI-NEXT: umull2 v3.2d, v5.4s, v3.4s
; CHECK-GI-NEXT: umull v4.2d, v6.2s, v7.2s
; CHECK-GI-NEXT: umull2 v5.2d, v6.4s, v7.4s
; CHECK-GI-NEXT: umull v6.2d, v16.2s, v17.2s
; CHECK-GI-NEXT: umull2 v7.2d, v16.4s, v17.4s
; CHECK-GI-NEXT: umull v2.8h, v0.8b, v1.8b
; CHECK-GI-NEXT: umull2 v0.8h, v0.16b, v1.16b
; CHECK-GI-NEXT: ushll v1.4s, v2.4h, #0
; CHECK-GI-NEXT: ushll2 v3.4s, v2.8h, #0
; CHECK-GI-NEXT: ushll v5.4s, v0.4h, #0
; CHECK-GI-NEXT: ushll2 v7.4s, v0.8h, #0
; CHECK-GI-NEXT: ushll v0.2d, v1.2s, #0
; CHECK-GI-NEXT: ushll2 v1.2d, v1.4s, #0
; CHECK-GI-NEXT: ushll v2.2d, v3.2s, #0
; CHECK-GI-NEXT: ushll2 v3.2d, v3.4s, #0
; CHECK-GI-NEXT: ushll v4.2d, v5.2s, #0
; CHECK-GI-NEXT: ushll2 v5.2d, v5.4s, #0
; CHECK-GI-NEXT: ushll v6.2d, v7.2s, #0
; CHECK-GI-NEXT: ushll2 v7.2d, v7.4s, #0
; CHECK-GI-NEXT: ret
entry:
%ea = zext <16 x i8> %a to <16 x i64>
@@ -142,18 +134,12 @@ define <16 x i32> @mla_i32(<16 x i8> %a, <16 x i8> %b, <16 x i32> %c) {
;
; CHECK-GI-LABEL: mla_i32:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: ushll v6.8h, v0.8b, #0
; CHECK-GI-NEXT: ushll v7.8h, v1.8b, #0
; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-GI-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-GI-NEXT: umlal v2.4s, v6.4h, v7.4h
; CHECK-GI-NEXT: umlal2 v3.4s, v6.8h, v7.8h
; CHECK-GI-NEXT: umlal v4.4s, v0.4h, v1.4h
; CHECK-GI-NEXT: umlal2 v5.4s, v0.8h, v1.8h
; CHECK-GI-NEXT: mov v0.16b, v2.16b
; CHECK-GI-NEXT: mov v1.16b, v3.16b
; CHECK-GI-NEXT: mov v2.16b, v4.16b
; CHECK-GI-NEXT: mov v3.16b, v5.16b
; CHECK-GI-NEXT: umull v6.8h, v0.8b, v1.8b
; CHECK-GI-NEXT: umull2 v7.8h, v0.16b, v1.16b
; CHECK-GI-NEXT: uaddw v0.4s, v2.4s, v6.4h
; CHECK-GI-NEXT: uaddw2 v1.4s, v3.4s, v6.8h
; CHECK-GI-NEXT: uaddw v2.4s, v4.4s, v7.4h
; CHECK-GI-NEXT: uaddw2 v3.4s, v5.4s, v7.8h
; CHECK-GI-NEXT: ret
entry:
%ea = zext <16 x i8> %a to <16 x i32>
@@ -186,35 +172,21 @@ define <16 x i64> @mla_i64(<16 x i8> %a, <16 x i8> %b, <16 x i64> %c) {
;
; CHECK-GI-LABEL: mla_i64:
; CHECK-GI: // %bb.0: // %entry
; CHECK-GI-NEXT: mov v16.16b, v2.16b
; CHECK-GI-NEXT: mov v17.16b, v3.16b
; CHECK-GI-NEXT: mov v2.16b, v4.16b
; CHECK-GI-NEXT: mov v3.16b, v5.16b
; CHECK-GI-NEXT: mov v4.16b, v6.16b
; CHECK-GI-NEXT: mov v5.16b, v7.16b
; CHECK-GI-NEXT: ushll v6.8h, v0.8b, #0
; CHECK-GI-NEXT: ushll v7.8h, v1.8b, #0
; CHECK-GI-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-GI-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-GI-NEXT: ushll v18.4s, v6.4h, #0
; CHECK-GI-NEXT: ushll v20.4s, v7.4h, #0
; CHECK-GI-NEXT: ushll2 v19.4s, v6.8h, #0
; CHECK-GI-NEXT: ushll v21.4s, v0.4h, #0
; CHECK-GI-NEXT: ushll2 v22.4s, v7.8h, #0
; CHECK-GI-NEXT: ushll v23.4s, v1.4h, #0
; CHECK-GI-NEXT: ldp q6, q7, [sp]
; CHECK-GI-NEXT: ushll2 v0.4s, v0.8h, #0
; CHECK-GI-NEXT: ushll2 v1.4s, v1.8h, #0
; CHECK-GI-NEXT: umlal v16.2d, v18.2s, v20.2s
; CHECK-GI-NEXT: umlal2 v17.2d, v18.4s, v20.4s
; CHECK-GI-NEXT: umlal v2.2d, v19.2s, v22.2s
; CHECK-GI-NEXT: umlal2 v3.2d, v19.4s, v22.4s
; CHECK-GI-NEXT: umlal v4.2d, v21.2s, v23.2s
; CHECK-GI-NEXT: umlal2 v5.2d, v21.4s, v23.4s
; CHECK-GI-NEXT: umlal v6.2d, v0.2s, v1.2s
; CHECK-GI-NEXT: umlal2 v7.2d, v0.4s, v1.4s
; CHECK-GI-NEXT: mov v0.16b, v16.16b
; CHECK-GI-NEXT: mov v1.16b, v17.16b
; CHECK-GI-NEXT: umull v16.8h, v0.8b, v1.8b
; CHECK-GI-NEXT: umull2 v0.8h, v0.16b, v1.16b
; CHECK-GI-NEXT: ldp q19, q20, [sp]
; CHECK-GI-NEXT: ushll v1.4s, v16.4h, #0
; CHECK-GI-NEXT: ushll2 v16.4s, v16.8h, #0
; CHECK-GI-NEXT: ushll v17.4s, v0.4h, #0
; CHECK-GI-NEXT: ushll2 v18.4s, v0.8h, #0
; CHECK-GI-NEXT: uaddw v0.2d, v2.2d, v1.2s
; CHECK-GI-NEXT: uaddw2 v1.2d, v3.2d, v1.4s
; CHECK-GI-NEXT: uaddw v2.2d, v4.2d, v16.2s
; CHECK-GI-NEXT: uaddw2 v3.2d, v5.2d, v16.4s
; CHECK-GI-NEXT: uaddw v4.2d, v6.2d, v17.2s
; CHECK-GI-NEXT: uaddw2 v5.2d, v7.2d, v17.4s
; CHECK-GI-NEXT: uaddw v6.2d, v19.2d, v18.2s
; CHECK-GI-NEXT: uaddw2 v7.2d, v20.2d, v18.4s
; CHECK-GI-NEXT: ret
entry:
%ea = zext <16 x i8> %a to <16 x i64>
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.