Skip to content

[InstCombine] Refactor fixed and scalable binop shuffle combine. NFCI #141287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 51 additions & 64 deletions llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2094,6 +2094,49 @@ static bool shouldMergeGEPs(GEPOperator &GEP, GEPOperator &Src) {
return true;
}

/// Find a constant NewC that has property:
/// shuffle(NewC, ShMask) = C
/// Returns nullptr if such a constant does not exist e.g. ShMask=<0,0> C=<1,2>
///
/// A 1-to-1 mapping is not required. Example:
/// ShMask = <1,1,2,2> and C = <5,5,6,6> --> NewC = <poison,5,6,poison>
static Constant *unshuffleConstant(ArrayRef<int> ShMask, Constant *C,
VectorType *NewCTy) {
if (isa<ScalableVectorType>(NewCTy)) {
Constant *Splat = C->getSplatValue();
if (!Splat)
return nullptr;
return ConstantVector::getSplat(NewCTy->getElementCount(), Splat);
}

if (cast<FixedVectorType>(NewCTy)->getNumElements() >
cast<FixedVectorType>(C->getType())->getNumElements())
return nullptr;

unsigned NewCNumElts = cast<FixedVectorType>(NewCTy)->getNumElements();
PoisonValue *PoisonScalar = PoisonValue::get(C->getType()->getScalarType());
SmallVector<Constant *, 16> NewVecC(NewCNumElts, PoisonScalar);
unsigned NumElts = cast<FixedVectorType>(C->getType())->getNumElements();
for (unsigned I = 0; I < NumElts; ++I) {
Constant *CElt = C->getAggregateElement(I);
if (ShMask[I] >= 0) {
assert(ShMask[I] < (int)NumElts && "Not expecting narrowing shuffle");
Constant *NewCElt = NewVecC[ShMask[I]];
// Bail out if:
// 1. The constant vector contains a constant expression.
// 2. The shuffle needs an element of the constant vector that can't
// be mapped to a new constant vector.
// 3. This is a widening shuffle that copies elements of V1 into the
// extended elements (extending with poison is allowed).
if (!CElt || (!isa<PoisonValue>(NewCElt) && NewCElt != CElt) ||
I >= NewCNumElts)
return nullptr;
NewVecC[ShMask[I]] = CElt;
}
}
return ConstantVector::get(NewVecC);
}

Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
if (!isa<VectorType>(Inst.getType()))
return nullptr;
Expand Down Expand Up @@ -2213,53 +2256,18 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
// other binops, so they can be folded. It may also enable demanded elements
// transforms.
Constant *C;
auto *InstVTy = dyn_cast<FixedVectorType>(Inst.getType());
if (InstVTy &&
match(&Inst, m_c_BinOp(m_OneUse(m_Shuffle(m_Value(V1), m_Poison(),
if (match(&Inst, m_c_BinOp(m_OneUse(m_Shuffle(m_Value(V1), m_Poison(),
m_Mask(Mask))),
m_ImmConstant(C))) &&
cast<FixedVectorType>(V1->getType())->getNumElements() <=
InstVTy->getNumElements()) {
assert(InstVTy->getScalarType() == V1->getType()->getScalarType() &&
m_ImmConstant(C)))) {
assert(Inst.getType()->getScalarType() == V1->getType()->getScalarType() &&
"Shuffle should not change scalar type");

// Find constant NewC that has property:
// shuffle(NewC, ShMask) = C
// If such constant does not exist (example: ShMask=<0,0> and C=<1,2>)
// reorder is not possible. A 1-to-1 mapping is not required. Example:
// ShMask = <1,1,2,2> and C = <5,5,6,6> --> NewC = <undef,5,6,undef>
bool ConstOp1 = isa<Constant>(RHS);
ArrayRef<int> ShMask = Mask;
unsigned SrcVecNumElts =
cast<FixedVectorType>(V1->getType())->getNumElements();
PoisonValue *PoisonScalar = PoisonValue::get(C->getType()->getScalarType());
SmallVector<Constant *, 16> NewVecC(SrcVecNumElts, PoisonScalar);
bool MayChange = true;
unsigned NumElts = InstVTy->getNumElements();
for (unsigned I = 0; I < NumElts; ++I) {
Constant *CElt = C->getAggregateElement(I);
if (ShMask[I] >= 0) {
assert(ShMask[I] < (int)NumElts && "Not expecting narrowing shuffle");
Constant *NewCElt = NewVecC[ShMask[I]];
// Bail out if:
// 1. The constant vector contains a constant expression.
// 2. The shuffle needs an element of the constant vector that can't
// be mapped to a new constant vector.
// 3. This is a widening shuffle that copies elements of V1 into the
// extended elements (extending with poison is allowed).
if (!CElt || (!isa<PoisonValue>(NewCElt) && NewCElt != CElt) ||
I >= SrcVecNumElts) {
MayChange = false;
break;
}
NewVecC[ShMask[I]] = CElt;
}
}
if (MayChange) {
Constant *NewC = ConstantVector::get(NewVecC);
// Lanes of NewC not used by the shuffle will be poison which will cause
// UB for div/rem. Mask them with a safe constant.
if (Inst.isIntDivRem())
if (Constant *NewC =
unshuffleConstant(Mask, C, cast<VectorType>(V1->getType()))) {
// For fixed vectors, lanes of NewC not used by the shuffle will be poison
// which will cause UB for div/rem. Mask them with a safe constant.
if (isa<FixedVectorType>(V1->getType()) && Inst.isIntDivRem())
NewC = getSafeVectorConstantForBinop(Opcode, NewC, ConstOp1);

// Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask)
Expand All @@ -2270,27 +2278,6 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
}
}

// Similar to the combine above, but handles the case for scalable vectors
// where both shuffle(V1, 0) and C are splats.
//
// Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)
if (isa<ScalableVectorType>(Inst.getType()) &&
match(&Inst, m_c_BinOp(m_OneUse(m_Shuffle(m_Value(V1), m_Poison(),
m_ZeroMask())),
m_ImmConstant(C)))) {
if (Constant *Splat = C->getSplatValue()) {
bool ConstOp1 = isa<Constant>(RHS);
VectorType *V1Ty = cast<VectorType>(V1->getType());
Constant *NewC = ConstantVector::getSplat(V1Ty->getElementCount(), Splat);

Value *NewLHS = ConstOp1 ? V1 : NewC;
Value *NewRHS = ConstOp1 ? NewC : V1;
VectorType *VTy = cast<VectorType>(Inst.getType());
SmallVector<int> Mask(VTy->getElementCount().getKnownMinValue(), 0);
return createBinOpShuffle(NewLHS, NewRHS, Mask);
}
}

// Try to reassociate to sink a splat shuffle after a binary operation.
if (Inst.isAssociative() && Inst.isCommutative()) {
// Canonicalize shuffle operand as LHS.
Expand Down
55 changes: 55 additions & 0 deletions llvm/test/Transforms/InstCombine/vec_shuffle.ll
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,17 @@ define <4 x i8> @widening_shuffle_add_1(<2 x i8> %x) {
ret <4 x i8> %r
}

define <vscale x 4 x i8> @widening_shuffle_add_1_scalable(<vscale x 2 x i8> %x) {
; CHECK-LABEL: @widening_shuffle_add_1_scalable(
; CHECK-NEXT: [[TMP1:%.*]] = add <vscale x 2 x i8> [[X:%.*]], splat (i8 42)
; CHECK-NEXT: [[R:%.*]] = shufflevector <vscale x 2 x i8> [[TMP1]], <vscale x 2 x i8> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: ret <vscale x 4 x i8> [[R]]
;
%widex = shufflevector <vscale x 2 x i8> %x, <vscale x 2 x i8> poison, <vscale x 4 x i32> zeroinitializer
%r = add <vscale x 4 x i8> %widex, splat (i8 42)
ret <vscale x 4 x i8> %r
}

; Reduce the width of the binop by moving it ahead of a shuffle.

define <4 x i8> @widening_shuffle_add_2(<2 x i8> %x) {
Expand Down Expand Up @@ -938,6 +949,28 @@ define <2 x i32> @shl_splat_constant1(<2 x i32> %x) {
ret <2 x i32> %r
}

define <vscale x 2 x i32> @shl_splat_constant0_scalable(<vscale x 2 x i32> %x) {
; CHECK-LABEL: @shl_splat_constant0_scalable(
; CHECK-NEXT: [[TMP1:%.*]] = shl <vscale x 2 x i32> splat (i32 5), [[X:%.*]]
; CHECK-NEXT: [[R:%.*]] = shufflevector <vscale x 2 x i32> [[TMP1]], <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
; CHECK-NEXT: ret <vscale x 2 x i32> [[R]]
;
%splat = shufflevector <vscale x 2 x i32> %x, <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
%r = shl <vscale x 2 x i32> splat (i32 5), %splat
ret <vscale x 2 x i32> %r
}

define <vscale x 2 x i32> @shl_splat_constant1_scalable(<vscale x 2 x i32> %x) {
; CHECK-LABEL: @shl_splat_constant1_scalable(
; CHECK-NEXT: [[TMP1:%.*]] = shl <vscale x 2 x i32> [[X:%.*]], splat (i32 5)
; CHECK-NEXT: [[R:%.*]] = shufflevector <vscale x 2 x i32> [[TMP1]], <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
; CHECK-NEXT: ret <vscale x 2 x i32> [[R]]
;
%splat = shufflevector <vscale x 2 x i32> %x, <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
%r = shl <vscale x 2 x i32> %splat, splat (i32 5)
ret <vscale x 2 x i32> %r
}

define <2 x i32> @ashr_splat_constant0(<2 x i32> %x) {
; CHECK-LABEL: @ashr_splat_constant0(
; CHECK-NEXT: [[TMP1:%.*]] = lshr <2 x i32> <i32 5, i32 poison>, [[X:%.*]]
Expand Down Expand Up @@ -1048,6 +1081,28 @@ define <2 x i32> @udiv_splat_constant1(<2 x i32> %x) {
ret <2 x i32> %r
}

define <vscale x 2 x i32> @udiv_splat_constant0_scalable(<vscale x 2 x i32> %x) {
; CHECK-LABEL: @udiv_splat_constant0_scalable(
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 2 x i32> [[X:%.*]], <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
; CHECK-NEXT: [[R:%.*]] = udiv <vscale x 2 x i32> splat (i32 42), [[SPLAT]]
; CHECK-NEXT: ret <vscale x 2 x i32> [[R]]
;
%splat = shufflevector <vscale x 2 x i32> %x, <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
%r = udiv <vscale x 2 x i32> splat (i32 42), %splat
ret <vscale x 2 x i32> %r
}

define <vscale x 2 x i32> @udiv_splat_constant1_scalable(<vscale x 2 x i32> %x) {
; CHECK-LABEL: @udiv_splat_constant1_scalable(
; CHECK-NEXT: [[TMP1:%.*]] = udiv <vscale x 2 x i32> [[X:%.*]], splat (i32 42)
; CHECK-NEXT: [[R:%.*]] = shufflevector <vscale x 2 x i32> [[TMP1]], <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
; CHECK-NEXT: ret <vscale x 2 x i32> [[R]]
;
%splat = shufflevector <vscale x 2 x i32> %x, <vscale x 2 x i32> poison, <vscale x 2 x i32> zeroinitializer
%r = udiv <vscale x 2 x i32> %splat, splat (i32 42)
ret <vscale x 2 x i32> %r
}

define <2 x i32> @sdiv_splat_constant0(<2 x i32> %x) {
; CHECK-LABEL: @sdiv_splat_constant0(
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <2 x i32> [[X:%.*]], <2 x i32> poison, <2 x i32> zeroinitializer
Expand Down