Skip to content

Commit 5ca26d7

Browse files
[AArch64][SME2] Improve register allocation of multi-vector SME intrinsics (llvm#116399)
The FORM_TRANSPOSED_REG_TUPLE pseudos have been created to improve register allocation for intrinsics which use strided and contiguous multi-vector registers, avoiding unnecessary copies. If the operands of the pseudo are copies where the source register is in the StridedOrContiguous class, the pseudo is used by getRegAllocationHints to suggest a contiguous multi-vector register which matches the subregister sequence used by the operands. If the operands do not match this pattern, the pseudos are expanded to a REG_SEQUENCE. Patch contains changes by Matthew Devereau.
1 parent 98470c0 commit 5ca26d7

File tree

7 files changed

+805
-161
lines changed

7 files changed

+805
-161
lines changed

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@ class AArch64ExpandPseudo : public MachineFunctionPass {
6767
TargetRegisterClass ContiguousClass,
6868
TargetRegisterClass StridedClass,
6969
unsigned ContiguousOpc, unsigned StridedOpc);
70+
bool expandFormTuplePseudo(MachineBasicBlock &MBB,
71+
MachineBasicBlock::iterator MBBI,
72+
MachineBasicBlock::iterator &NextMBBI,
73+
unsigned Size);
7074
bool expandMOVImm(MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
7175
unsigned BitSize);
7276

@@ -1142,6 +1146,32 @@ bool AArch64ExpandPseudo::expandMultiVecPseudo(
11421146
return true;
11431147
}
11441148

1149+
// Expand a FORM_TRANSPOSED_REG_TUPLE_X2/X4 pseudo whose operands were not
// allocated to the subregisters of a single tuple register. Size is the
// number of vector operands (2 or 4). For every operand that is not already
// in the matching zsubN subregister of the destination tuple, emit an
// ORR_ZZZ (vector mov) into that subregister, then erase the pseudo.
bool AArch64ExpandPseudo::expandFormTuplePseudo(
    MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI,
    MachineBasicBlock::iterator &NextMBBI, unsigned Size) {
  // Parenthesize the disjunction: '&&' binds tighter than '||', so the
  // original `Size == 2 || Size == 4 && "..."` attached the (always-true)
  // message only to the Size == 4 arm.
  assert((Size == 2 || Size == 4) && "Invalid Tuple Size");
  MachineInstr &MI = *MBBI;
  Register ReturnTuple = MI.getOperand(0).getReg();

  const TargetRegisterInfo *TRI =
      MBB.getParent()->getSubtarget().getRegisterInfo();
  for (unsigned I = 0; I < Size; ++I) {
    Register FormTupleOpReg = MI.getOperand(I + 1).getReg();
    Register ReturnTupleSubReg =
        TRI->getSubReg(ReturnTuple, AArch64::zsub0 + I);
    // Add copies to ensure the subregisters remain in the correct order
    // for any contiguous operation they are used by.
    if (FormTupleOpReg != ReturnTupleSubReg)
      BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::ORR_ZZZ))
          .addReg(ReturnTupleSubReg, RegState::Define)
          .addReg(FormTupleOpReg)
          .addReg(FormTupleOpReg);
  }

  MI.eraseFromParent();
  return true;
}
1174+
11451175
/// If MBBI references a pseudo instruction that should be expanded here,
11461176
/// do the expansion and return true. Otherwise return false.
11471177
bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
@@ -1724,6 +1754,10 @@ bool AArch64ExpandPseudo::expandMI(MachineBasicBlock &MBB,
17241754
return expandMultiVecPseudo(
17251755
MBB, MBBI, AArch64::ZPR4RegClass, AArch64::ZPR4StridedRegClass,
17261756
AArch64::LDNT1D_4Z, AArch64::LDNT1D_4Z_STRIDED);
1757+
case AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO:
1758+
return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 2);
1759+
case AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO:
1760+
return expandFormTuplePseudo(MBB, MBBI, NextMBBI, 4);
17271761
}
17281762
return false;
17291763
}

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8581,6 +8581,56 @@ static bool checkZExtBool(SDValue Arg, const SelectionDAG &DAG) {
85818581
return ZExtBool;
85828582
}
85838583

8584+
// The FORM_TRANSPOSED_REG_TUPLE pseudo should only be used if the
8585+
// input operands are copy nodes where the source register is in a
8586+
// StridedOrContiguous class. For example:
8587+
//
8588+
// %3:zpr2stridedorcontiguous = LD1B_2Z_IMM_PSEUDO ..
8589+
// %4:zpr = COPY %3.zsub1:zpr2stridedorcontiguous
8590+
// %5:zpr = COPY %3.zsub0:zpr2stridedorcontiguous
8591+
// %6:zpr2stridedorcontiguous = LD1B_2Z_PSEUDO ..
8592+
// %7:zpr = COPY %6.zsub1:zpr2stridedorcontiguous
8593+
// %8:zpr = COPY %6.zsub0:zpr2stridedorcontiguous
8594+
// %9:zpr2mul2 = FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO %5:zpr, %8:zpr
8595+
//
8596+
// Returns true when every input operand of the FORM_TRANSPOSED_REG_TUPLE
// pseudo MI is the result of a COPY that reads the same subregister index of
// a register in the matching StridedOrContiguous class; only then should the
// pseudo be kept (rather than lowered to a REG_SEQUENCE).
bool shouldUseFormStridedPseudo(MachineInstr &MI) {
  MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();

  // Pick the strided-or-contiguous class matching the tuple size of the
  // pseudo (X2 -> ZPR2..., X4 -> ZPR4...).
  const TargetRegisterClass *RegClass = nullptr;
  switch (MI.getOpcode()) {
  case AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO:
    RegClass = &AArch64::ZPR2StridedOrContiguousRegClass;
    break;
  case AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO:
    RegClass = &AArch64::ZPR4StridedOrContiguousRegClass;
    break;
  default:
    llvm_unreachable("Unexpected opcode.");
  }

  // SubReg records the subregister index read by the first operand's copy;
  // every subsequent operand must read that same index.
  MCRegister SubReg = MCRegister::NoRegister;
  for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
    MachineOperand &MO = MI.getOperand(I);
    assert(MO.isReg() && "Unexpected operand to FORM_TRANSPOSED_REG_TUPLE");

    // Each input must have exactly one definition, and it must be a COPY.
    MachineOperand *Def = MRI.getOneDef(MO.getReg());
    if (!Def || !Def->getParent()->isCopy())
      return false;

    const MachineOperand &CopySrc = Def->getParent()->getOperand(1);
    unsigned OpSubReg = CopySrc.getSubReg();
    // The first operand seen establishes the required subregister index.
    if (SubReg == MCRegister::NoRegister)
      SubReg = OpSubReg;

    // The copy's source must itself have a single def, use the same
    // subregister index as the first operand, and belong to RegClass.
    MachineOperand *CopySrcOp = MRI.getOneDef(CopySrc.getReg());
    if (!CopySrcOp || !CopySrcOp->isReg() || OpSubReg != SubReg ||
        MRI.getRegClass(CopySrcOp->getReg()) != RegClass)
      return false;
  }

  return true;
}
8633+
85848634
void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
85858635
SDNode *Node) const {
85868636
// Live-in physreg copies that are glued to SMSTART are applied as
@@ -8606,6 +8656,27 @@ void AArch64TargetLowering::AdjustInstrPostInstrSelection(MachineInstr &MI,
86068656
}
86078657
}
86088658

8659+
if (MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO ||
8660+
MI.getOpcode() == AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO) {
8661+
// If input values to the FORM_TRANSPOSED_REG_TUPLE pseudo aren't copies
8662+
// from a StridedOrContiguous class, fall back on REG_SEQUENCE node.
8663+
if (shouldUseFormStridedPseudo(MI))
8664+
return;
8665+
8666+
const TargetInstrInfo *TII = Subtarget->getInstrInfo();
8667+
MachineInstrBuilder MIB = BuildMI(*MI.getParent(), MI, MI.getDebugLoc(),
8668+
TII->get(TargetOpcode::REG_SEQUENCE),
8669+
MI.getOperand(0).getReg());
8670+
8671+
for (unsigned I = 1; I < MI.getNumOperands(); ++I) {
8672+
MIB.add(MI.getOperand(I));
8673+
MIB.addImm(AArch64::zsub0 + (I - 1));
8674+
}
8675+
8676+
MI.eraseFromParent();
8677+
return;
8678+
}
8679+
86098680
// Add an implicit use of 'VG' for ADDXri/SUBXri, which are instructions that
86108681
// have nothing to do with VG, were it not that they are used to materialise a
86118682
// frame-address. If they contain a frame-index to a scalable vector, this

llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,58 @@ unsigned AArch64RegisterInfo::getRegPressureLimit(const TargetRegisterClass *RC,
10811081
}
10821082
}
10831083

1084+
// FORM_TRANSPOSED_REG_TUPLE nodes are created to improve register allocation
1085+
// where a consecutive multi-vector tuple is constructed from the same indices
1086+
// of multiple strided loads. This may still result in unnecessary copies
1087+
// between the loads and the tuple. Here we try to return a hint to assign the
1088+
// contiguous ZPRMulReg starting at the same register as the first operand of
1089+
// the pseudo, which should be a subregister of the first strided load.
1090+
//
1091+
// For example, if the first strided load has been assigned $z16_z20_z24_z28
1092+
// and the operands of the pseudo are each accessing subregister zsub2, we
1093+
// should look through Order to find a contiguous register which
1094+
// begins with $z24 (i.e. $z24_z25_z26_z27).
1095+
//
1096+
// Provide allocation hints for virtual registers defined by the
// FORM_TRANSPOSED_REG_TUPLE pseudos: hint the contiguous tuple register in
// Order whose zsub0 matches the physical register already assigned to the
// pseudo's first operand. Falls back to the generic implementation for all
// other registers (see the comment above for a worked example).
bool AArch64RegisterInfo::getRegAllocationHints(
    Register VirtReg, ArrayRef<MCPhysReg> Order,
    SmallVectorImpl<MCPhysReg> &Hints, const MachineFunction &MF,
    const VirtRegMap *VRM, const LiveRegMatrix *Matrix) const {
  const MachineRegisterInfo &MRI = MF.getRegInfo();

  // The hints below need the virtual->physical mapping; the base-class
  // interface permits a null VirtRegMap, so guard before dereferencing VRM.
  if (!VRM)
    return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF,
                                                     VRM);

  for (MachineInstr &MI : MRI.def_instructions(VirtReg)) {
    // Only the FORM_TRANSPOSED_REG_TUPLE pseudos receive special hints;
    // anything else defers to the generic implementation immediately.
    if (MI.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO &&
        MI.getOpcode() != AArch64::FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO)
      return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints,
                                                       MF, VRM);

    // The first operand must read one of the zsub0-zsub3 subregisters;
    // otherwise no tuple-start register can be derived from it.
    unsigned FirstOpSubReg = MI.getOperand(1).getSubReg();
    switch (FirstOpSubReg) {
    case AArch64::zsub0:
    case AArch64::zsub1:
    case AArch64::zsub2:
    case AArch64::zsub3:
      break;
    default:
      continue;
    }

    // Look up the physical register mapped to the first operand of the pseudo.
    Register FirstOpVirtReg = MI.getOperand(1).getReg();
    if (!VRM->hasPhys(FirstOpVirtReg))
      continue;

    // Hint every candidate in Order whose first (zsub0) subregister equals
    // the physical subregister assigned to the pseudo's first operand.
    MCRegister TupleStartReg =
        getSubReg(VRM->getPhys(FirstOpVirtReg), FirstOpSubReg);
    for (unsigned I = 0; I < Order.size(); ++I)
      if (MCRegister R = getSubReg(Order[I], AArch64::zsub0))
        if (R == TupleStartReg)
          Hints.push_back(Order[I]);
  }

  return TargetRegisterInfo::getRegAllocationHints(VirtReg, Order, Hints, MF,
                                                   VRM);
}
1135+
10841136
unsigned AArch64RegisterInfo::getLocalAddressRegister(
10851137
const MachineFunction &MF) const {
10861138
const auto &MFI = MF.getFrameInfo();

llvm/lib/Target/AArch64/AArch64RegisterInfo.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,11 @@ class AArch64RegisterInfo final : public AArch64GenRegisterInfo {
134134
unsigned getRegPressureLimit(const TargetRegisterClass *RC,
135135
MachineFunction &MF) const override;
136136

137+
bool getRegAllocationHints(Register VirtReg, ArrayRef<MCPhysReg> Order,
138+
SmallVectorImpl<MCPhysReg> &Hints,
139+
const MachineFunction &MF, const VirtRegMap *VRM,
140+
const LiveRegMatrix *Matrix) const override;
141+
137142
unsigned getLocalAddressRegister(const MachineFunction &MF) const;
138143
bool regNeedsCFI(unsigned Reg, unsigned &RegToUseForCFI) const;
139144

llvm/lib/Target/AArch64/SMEInstrFormats.td

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,30 @@ def tileslicerange0s4 : ComplexPattern<i32, 2, "SelectSMETileSlice<0, 4>", []>;
3535
let WantsRoot = true in
3636
def am_sme_indexed_b4 : ComplexPattern<iPTR, 2, "SelectAddrModeIndexedSVE<0, 15>">;
3737

38+
// The FORM_TRANSPOSED_REG_TUPLE pseudos defined below are intended to
39+
// improve register allocation for intrinsics which use strided and contiguous
40+
// multi-vector registers, avoiding unnecessary copies.
41+
// If the operands of the pseudo are copies where the source register is in
42+
// the StridedOrContiguous class, the pseudo is used to provide a hint to the
43+
// register allocator suggesting a contiguous multi-vector register which
44+
// matches the subregister sequence used by the operands.
45+
// If the operands do not match this pattern, the pseudos are expanded
46+
// to a REG_SEQUENCE using the post-isel hook.
47+
48+
// Two-vector form: collects $zn0/$zn1 into a ZPR2Mul2 tuple. Has no side
// effects; the post-isel hook either keeps the pseudo or rewrites it to a
// REG_SEQUENCE.
let hasSideEffects = 0, hasPostISelHook = 1 in
def FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO
    : Pseudo<(outs ZPR2Mul2:$tup), (ins ZPR:$zn0, ZPR:$zn1), []>, Sched<[]>;
54+
55+
// Four-vector form: collects $zn0-$zn3 into a ZPR4Mul4 tuple. Has no side
// effects; the post-isel hook either keeps the pseudo or rewrites it to a
// REG_SEQUENCE.
let hasSideEffects = 0, hasPostISelHook = 1 in
def FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO
    : Pseudo<(outs ZPR4Mul4:$tup),
             (ins ZPR:$zn0, ZPR:$zn1, ZPR:$zn2, ZPR:$zn3), []>, Sched<[]>;
61+
3862
def SDTZALoadStore : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisPtrTy<1>, SDTCisInt<2>]>;
3963
def AArch64SMELdr : SDNode<"AArch64ISD::SME_ZA_LDR", SDTZALoadStore,
4064
[SDNPHasChain, SDNPSideEffect, SDNPMayLoad]>;
@@ -173,14 +197,14 @@ class SME2_ZA_TwoOp_VG2_Multi_Index_Pat<string name, SDPatternOperator intrinsic
173197
Operand imm_ty, ComplexPattern tileslice>
174198
: Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)), vt:$Zn1, vt:$Zn2, vt:$Zm, (i32 imm_ty:$i)),
175199
(!cast<Instruction>(name # _PSEUDO) $base, $offset,
176-
(REG_SEQUENCE ZPR2Mul2, vt:$Zn1, zsub0, vt:$Zn2, zsub1), zpr_ty:$Zm, imm_ty:$i)>;
200+
(FORM_TRANSPOSED_REG_TUPLE_X2_PSEUDO vt:$Zn1,vt:$Zn2), zpr_ty:$Zm, imm_ty:$i)>;
177201

178202
class SME2_ZA_TwoOp_VG4_Multi_Index_Pat<string name, SDPatternOperator intrinsic, Operand index_ty, ZPRRegOp zpr_ty, ValueType vt,
179203
Operand imm_ty, ComplexPattern tileslice>
180204
: Pat<(intrinsic (i32 (tileslice MatrixIndexGPR32Op8_11:$base, index_ty:$offset)),
181205
vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4, vt:$Zm, (i32 imm_ty:$i)),
182206
(!cast<Instruction>(name # _PSEUDO) $base, $offset,
183-
(REG_SEQUENCE ZPR4Mul4, vt:$Zn1, zsub0, vt:$Zn2, zsub1, vt:$Zn3, zsub2, vt:$Zn4, zsub3),
207+
(FORM_TRANSPOSED_REG_TUPLE_X4_PSEUDO vt:$Zn1, vt:$Zn2, vt:$Zn3, vt:$Zn4),
184208
zpr_ty:$Zm, imm_ty:$i)>;
185209

186210
class SME2_Sat_Shift_VG2_Pat<string name, SDPatternOperator intrinsic, ValueType out_vt, ValueType in_vt, Operand imm_ty>

0 commit comments

Comments
 (0)