Skip to content

Commit

Permalink
[LSR][AArch64] Optimize chain generation based on legal addressing mo…
Browse files Browse the repository at this point in the history
…des (llvm#94453)

LSR will generate chains of related instructions with a known increment
between them. With SVE, in the case of the test case, this can include
increments like 'vscale * 16 + 8'. The idea of this patch is if we have
a '+8' increment already calculated in the chain, we can generate a
(legal) '+ vscale*16' addressing mode from it, allowing us to use the
'[x16, llvm#1, mul vl]' addressing mode instructions.

In order to do this we keep track of the known 'bases' when generating
chains in GenerateIVChain, checking for each if the accumulated
increment expression from the base neatly folds into a legal addressing
mode. If they do not we fall back to the existing LeftOverExpr, whether
it is legal or not.

This is mostly orthogonal to llvm#88124, dealing with the generation of
chains as opposed to rest of LSR. The existing vscale addressing mode
work has greatly helped compared to the last time I looked at this,
allowing us to check that the addressing modes are indeed legal.
  • Loading branch information
davemgreen authored and Lukacma committed Jun 12, 2024
1 parent f78692c commit 7c65605
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 60 deletions.
72 changes: 58 additions & 14 deletions llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1256,7 +1256,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
bool HasBaseReg, int64_t Scale,
Instruction *Fixup = nullptr);
Instruction *Fixup = nullptr,
int64_t ScalableOffset = 0);

static unsigned getSetupCost(const SCEV *Reg, unsigned Depth) {
if (isa<SCEVUnknown>(Reg) || isa<SCEVConstant>(Reg))
Expand Down Expand Up @@ -1675,16 +1676,18 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
bool HasBaseReg, int64_t Scale,
Instruction *Fixup/*= nullptr*/) {
Instruction *Fixup /* = nullptr */,
int64_t ScalableOffset) {
switch (Kind) {
case LSRUse::Address:
return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset,
HasBaseReg, Scale, AccessTy.AddrSpace, Fixup);
HasBaseReg, Scale, AccessTy.AddrSpace,
Fixup, ScalableOffset);

case LSRUse::ICmpZero:
// There's not even a target hook for querying whether it would be legal to
// fold a GV into an ICmp.
if (BaseGV)
if (BaseGV || ScalableOffset != 0)
return false;

// ICmp only has two operands; don't allow more than two non-trivial parts.
Expand Down Expand Up @@ -1715,11 +1718,12 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,

case LSRUse::Basic:
// Only handle single-register values.
return !BaseGV && Scale == 0 && BaseOffset == 0;
return !BaseGV && Scale == 0 && BaseOffset == 0 && ScalableOffset == 0;

case LSRUse::Special:
// Special case Basic to handle -1 scales.
return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0;
return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0 &&
ScalableOffset == 0;
}

llvm_unreachable("Invalid LSRUse Kind!");
Expand Down Expand Up @@ -1843,7 +1847,7 @@ static InstructionCost getScalingFactorCost(const TargetTransformInfo &TTI,
static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
LSRUse::KindType Kind, MemAccessTy AccessTy,
GlobalValue *BaseGV, int64_t BaseOffset,
bool HasBaseReg) {
bool HasBaseReg, int64_t ScalableOffset = 0) {
// Fast-path: zero is always foldable.
if (BaseOffset == 0 && !BaseGV) return true;

Expand All @@ -1859,7 +1863,7 @@ static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
}

return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, BaseOffset,
HasBaseReg, Scale);
HasBaseReg, Scale, nullptr, ScalableOffset);
}

static bool isAlwaysFoldable(const TargetTransformInfo &TTI,
Expand Down Expand Up @@ -3165,16 +3169,30 @@ void LSRInstance::FinalizeChain(IVChain &Chain) {
static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
Value *Operand, const TargetTransformInfo &TTI) {
const SCEVConstant *IncConst = dyn_cast<SCEVConstant>(IncExpr);
if (!IncConst || !isAddressUse(TTI, UserInst, Operand))
return false;
int64_t IncOffset = 0;
int64_t ScalableOffset = 0;
if (IncConst) {
if (IncConst && IncConst->getAPInt().getSignificantBits() > 64)
return false;
IncOffset = IncConst->getValue()->getSExtValue();
} else {
// Look for mul(vscale, constant), to detect ScalableOffset.
auto *IncVScale = dyn_cast<SCEVMulExpr>(IncExpr);
if (!IncVScale || IncVScale->getNumOperands() != 2 ||
!isa<SCEVVScale>(IncVScale->getOperand(1)))
return false;
auto *Scale = dyn_cast<SCEVConstant>(IncVScale->getOperand(0));
if (!Scale || Scale->getType()->getScalarSizeInBits() > 64)
return false;
ScalableOffset = Scale->getValue()->getSExtValue();
}

if (IncConst->getAPInt().getSignificantBits() > 64)
if (!isAddressUse(TTI, UserInst, Operand))
return false;

MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand);
int64_t IncOffset = IncConst->getValue()->getSExtValue();
if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr,
IncOffset, /*HasBaseReg=*/false))
IncOffset, /*HasBaseReg=*/false, ScalableOffset))
return false;

return true;
Expand Down Expand Up @@ -3220,6 +3238,10 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
Type *IVTy = IVSrc->getType();
Type *IntTy = SE.getEffectiveSCEVType(IVTy);
const SCEV *LeftOverExpr = nullptr;
const SCEV *Accum = SE.getZero(IntTy);
SmallVector<std::pair<const SCEV *, Value *>> Bases;
Bases.emplace_back(Accum, IVSrc);

for (const IVInc &Inc : Chain) {
Instruction *InsertPt = Inc.UserInst;
if (isa<PHINode>(InsertPt))
Expand All @@ -3232,10 +3254,31 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
// IncExpr was the result of subtraction of two narrow values, so must
// be signed.
const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy);
Accum = SE.getAddExpr(Accum, IncExpr);
LeftOverExpr = LeftOverExpr ?
SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr;
}
if (LeftOverExpr && !LeftOverExpr->isZero()) {

// Look through each base to see if any can produce a nice addressing mode.
bool FoundBase = false;
for (auto [MapScev, MapIVOper] : reverse(Bases)) {
const SCEV *Remainder = SE.getMinusSCEV(Accum, MapScev);
if (canFoldIVIncExpr(Remainder, Inc.UserInst, Inc.IVOperand, TTI)) {
if (!Remainder->isZero()) {
Rewriter.clearPostInc();
Value *IncV = Rewriter.expandCodeFor(Remainder, IntTy, InsertPt);
const SCEV *IVOperExpr =
SE.getAddExpr(SE.getUnknown(MapIVOper), SE.getUnknown(IncV));
IVOper = Rewriter.expandCodeFor(IVOperExpr, IVTy, InsertPt);
} else {
IVOper = MapIVOper;
}

FoundBase = true;
break;
}
}
if (!FoundBase && LeftOverExpr && !LeftOverExpr->isZero()) {
// Expand the IV increment.
Rewriter.clearPostInc();
Value *IncV = Rewriter.expandCodeFor(LeftOverExpr, IntTy, InsertPt);
Expand All @@ -3246,6 +3289,7 @@ void LSRInstance::GenerateIVChain(const IVChain &Chain,
// If an IV increment can't be folded, use it as the next IV value.
if (!canFoldIVIncExpr(LeftOverExpr, Inc.UserInst, Inc.IVOperand, TTI)) {
assert(IVTy == IVOper->getType() && "inconsistent IV increment type");
Bases.emplace_back(Accum, IVOper);
IVSrc = IVOper;
LeftOverExpr = nullptr;
}
Expand Down
86 changes: 40 additions & 46 deletions llvm/test/CodeGen/AArch64/sve-lsrchain.ll
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,22 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
; CHECK-NEXT: // %bb.2: // %for.body.us.preheader
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: add x11, x2, x11, lsl #1
; CHECK-NEXT: mov x12, #-16 // =0xfffffffffffffff0
; CHECK-NEXT: ptrue p1.b
; CHECK-NEXT: mov w8, wzr
; CHECK-NEXT: ptrue p1.b
; CHECK-NEXT: mov x9, xzr
; CHECK-NEXT: mov w10, wzr
; CHECK-NEXT: addvl x12, x12, #1
; CHECK-NEXT: mov x13, #4 // =0x4
; CHECK-NEXT: mov x14, #8 // =0x8
; CHECK-NEXT: mov x12, #4 // =0x4
; CHECK-NEXT: mov x13, #8 // =0x8
; CHECK-NEXT: .LBB0_3: // %for.body.us
; CHECK-NEXT: // =>This Loop Header: Depth=1
; CHECK-NEXT: // Child Loop BB0_4 Depth 2
; CHECK-NEXT: add x15, x0, x9, lsl #2
; CHECK-NEXT: sbfiz x16, x8, #1, #32
; CHECK-NEXT: mov x17, x2
; CHECK-NEXT: ldp s0, s1, [x15]
; CHECK-NEXT: add x16, x16, #8
; CHECK-NEXT: ldp s2, s3, [x15, #8]
; CHECK-NEXT: ubfiz x15, x8, #1, #32
; CHECK-NEXT: add x14, x0, x9, lsl #2
; CHECK-NEXT: sbfiz x15, x8, #1, #32
; CHECK-NEXT: mov x16, x2
; CHECK-NEXT: ldp s0, s1, [x14]
; CHECK-NEXT: add x15, x15, #8
; CHECK-NEXT: ldp s2, s3, [x14, #8]
; CHECK-NEXT: ubfiz x14, x8, #1, #32
; CHECK-NEXT: fcvt h0, s0
; CHECK-NEXT: fcvt h1, s1
; CHECK-NEXT: fcvt h2, s2
Expand All @@ -43,56 +41,52 @@ define void @test(ptr nocapture noundef readonly %kernel, i32 noundef %kw, float
; CHECK-NEXT: .LBB0_4: // %for.cond.i.preheader.us
; CHECK-NEXT: // Parent Loop BB0_3 Depth=1
; CHECK-NEXT: // => This Inner Loop Header: Depth=2
; CHECK-NEXT: ld1b { z4.b }, p1/z, [x17, x15]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17]
; CHECK-NEXT: add x18, x17, x16
; CHECK-NEXT: add x3, x17, x15
; CHECK-NEXT: ld1b { z4.b }, p1/z, [x16, x14]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16]
; CHECK-NEXT: add x17, x16, x15
; CHECK-NEXT: add x18, x16, x14
; CHECK-NEXT: add x3, x17, #8
; CHECK-NEXT: add x4, x17, #16
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x17, x16]
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x16, x15]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, x12, lsl #1]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: add x18, x18, #16
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, x13, lsl #1]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #1, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x17]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #1, mul vl]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #1, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x16]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #1, mul vl]
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12]
; CHECK-NEXT: add x18, x18, x12
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #1, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #1, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: add x18, x18, #16
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #1, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #2, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x17, #1, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #2, mul vl]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #2, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x16, #1, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #2, mul vl]
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12]
; CHECK-NEXT: add x18, x18, x12
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #2, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #2, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: add x18, x18, #16
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #2, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #3, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x17, #2, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x3, #3, mul vl]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x16, #3, mul vl]
; CHECK-NEXT: st1h { z4.h }, p0, [x16, #2, mul vl]
; CHECK-NEXT: ld1h { z4.h }, p0/z, [x18, #3, mul vl]
; CHECK-NEXT: fmad z4.h, p0/m, z0.h, z5.h
; CHECK-NEXT: ld1b { z5.b }, p1/z, [x18, x12]
; CHECK-NEXT: add x18, x18, x12
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x17, #3, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z1.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x13, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x3, #3, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z2.h
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x18, x14, lsl #1]
; CHECK-NEXT: ld1h { z5.h }, p0/z, [x4, #3, mul vl]
; CHECK-NEXT: fmla z4.h, p0/m, z5.h, z3.h
; CHECK-NEXT: st1h { z4.h }, p0, [x17, #3, mul vl]
; CHECK-NEXT: addvl x17, x17, #4
; CHECK-NEXT: cmp x17, x11
; CHECK-NEXT: st1h { z4.h }, p0, [x16, #3, mul vl]
; CHECK-NEXT: addvl x16, x16, #4
; CHECK-NEXT: cmp x16, x11
; CHECK-NEXT: b.lo .LBB0_4
; CHECK-NEXT: // %bb.5: // %while.cond.i..exit_crit_edge.us
; CHECK-NEXT: // in Loop: Header=BB0_3 Depth=1
Expand Down

0 comments on commit 7c65605

Please sign in to comment.