Skip to content

Commit de1d5d3

Browse files
committed
[InstCombine] canonicalize funnel shift constant shift amount to be modulo bitwidth
The shift argument is defined to be modulo the bitwidth, so if that argument is a constant, we can always reduce the constant to its minimal form to allow better CSE and other follow-on transforms. We need to be careful to ignore constant expressions here, or we will likely infinite loop. I'm adding a general vector constant query for that case. Differential Revision: https://reviews.llvm.org/D59374 llvm-svn: 356192
1 parent 6e86216 commit de1d5d3

File tree

5 files changed

+56
-9
lines changed

5 files changed

+56
-9
lines changed

llvm/include/llvm/IR/Constant.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ class Constant : public User {
9090
/// elements.
9191
bool containsUndefElement() const;
9292

93+
/// Return true if this is a vector constant that includes any constant
94+
/// expressions.
95+
bool containsConstantExpression() const;
96+
9397
/// Return true if evaluation of this constant could trap. This is true for
9498
/// things like constant expressions that could divide by zero.
9599
bool canTrap() const;

llvm/lib/Analysis/InstructionSimplify.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4917,7 +4917,6 @@ static Value *simplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,
49174917
const APInt *ShAmtC;
49184918
if (match(ShAmtArg, m_APInt(ShAmtC))) {
49194919
// If there's effectively no shift, return the 1st arg or 2nd arg.
4920-
// TODO: For vectors, we could check each element of a non-splat constant.
49214920
APInt BitWidth = APInt(ShAmtC->getBitWidth(), ShAmtC->getBitWidth());
49224921
if (ShAmtC->urem(BitWidth).isNullValue())
49234922
return ArgBegin[IID == Intrinsic::fshl ? 0 : 1];

llvm/lib/IR/Constants.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,16 @@ bool Constant::containsUndefElement() const {
260260
return false;
261261
}
262262

263+
bool Constant::containsConstantExpression() const {
264+
if (!getType()->isVectorTy())
265+
return false;
266+
for (unsigned i = 0, e = getType()->getVectorNumElements(); i != e; ++i)
267+
if (isa<ConstantExpr>(getAggregateElement(i)))
268+
return true;
269+
270+
return false;
271+
}
272+
263273
/// Constructor to create a '0' constant of arbitrary type.
264274
Constant *Constant::getNullValue(Type *Ty) {
265275
switch (Ty->getTypeID()) {

llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1994,10 +1994,22 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
19941994

19951995
case Intrinsic::fshl:
19961996
case Intrinsic::fshr: {
1997+
// Canonicalize a shift amount constant operand to be modulo the bit-width.
1998+
unsigned BitWidth = II->getType()->getScalarSizeInBits();
1999+
Constant *ShAmtC;
2000+
if (match(II->getArgOperand(2), m_Constant(ShAmtC)) &&
2001+
!isa<ConstantExpr>(ShAmtC) && !ShAmtC->containsConstantExpression()) {
2002+
Constant *WidthC = ConstantInt::get(II->getType(), BitWidth);
2003+
Constant *ModuloC = ConstantExpr::getURem(ShAmtC, WidthC);
2004+
if (ModuloC != ShAmtC) {
2005+
II->setArgOperand(2, ModuloC);
2006+
return II;
2007+
}
2008+
}
2009+
19972010
const APInt *SA;
19982011
if (match(II->getArgOperand(2), m_APInt(SA))) {
19992012
Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1);
2000-
unsigned BitWidth = SA->getBitWidth();
20012013
uint64_t ShiftAmt = SA->urem(BitWidth);
20022014
assert(ShiftAmt != 0 && "SimplifyCall should have handled zero shift");
20032015
// Normalize to funnel shift left.
@@ -2020,7 +2032,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
20202032
// The shift amount (operand 2) of a funnel shift is modulo the bitwidth,
20212033
// so only the low bits of the shift amount are demanded if the bitwidth is
20222034
// a power-of-2.
2023-
unsigned BitWidth = II->getType()->getScalarSizeInBits();
20242035
if (!isPowerOf2_32(BitWidth))
20252036
break;
20262037
APInt Op2Demanded = APInt::getLowBitsSet(BitWidth, Log2_32_Ceil(BitWidth));

llvm/test/Transforms/InstCombine/fsh.ll

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ define <2 x i31> @fshl_only_op1_demanded_vec_splat(<2 x i31> %x, <2 x i31> %y) {
310310

311311
define i32 @fshl_constant_shift_amount_modulo_bitwidth(i32 %x, i32 %y) {
312312
; CHECK-LABEL: @fshl_constant_shift_amount_modulo_bitwidth(
313-
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.fshl.i32(i32 [[X:%.*]], i32 [[Y:%.*]], i32 33)
313+
; CHECK-NEXT: [[R:%.*]] = call i32 @llvm.fshl.i32(i32 [[X:%.*]], i32 [[Y:%.*]], i32 1)
314314
; CHECK-NEXT: ret i32 [[R]]
315315
;
316316
%r = call i32 @llvm.fshl.i32(i32 %x, i32 %y, i32 33)
@@ -319,16 +319,28 @@ define i32 @fshl_constant_shift_amount_modulo_bitwidth(i32 %x, i32 %y) {
319319

320320
define i33 @fshr_constant_shift_amount_modulo_bitwidth(i33 %x, i33 %y) {
321321
; CHECK-LABEL: @fshr_constant_shift_amount_modulo_bitwidth(
322-
; CHECK-NEXT: [[R:%.*]] = call i33 @llvm.fshr.i33(i33 [[X:%.*]], i33 [[Y:%.*]], i33 34)
322+
; CHECK-NEXT: [[R:%.*]] = call i33 @llvm.fshr.i33(i33 [[X:%.*]], i33 [[Y:%.*]], i33 1)
323323
; CHECK-NEXT: ret i33 [[R]]
324324
;
325325
%r = call i33 @llvm.fshr.i33(i33 %x, i33 %y, i33 34)
326326
ret i33 %r
327327
}
328328

329+
@external_global = external global i8
330+
331+
define i33 @fshr_constant_shift_amount_modulo_bitwidth_constexpr(i33 %x, i33 %y) {
332+
; CHECK-LABEL: @fshr_constant_shift_amount_modulo_bitwidth_constexpr(
333+
; CHECK-NEXT: [[R:%.*]] = call i33 @llvm.fshr.i33(i33 [[X:%.*]], i33 [[Y:%.*]], i33 ptrtoint (i8* @external_global to i33))
334+
; CHECK-NEXT: ret i33 [[R]]
335+
;
336+
%shamt = ptrtoint i8* @external_global to i33
337+
%r = call i33 @llvm.fshr.i33(i33 %x, i33 %y, i33 %shamt)
338+
ret i33 %r
339+
}
340+
329341
define <2 x i32> @fshr_constant_shift_amount_modulo_bitwidth_vec(<2 x i32> %x, <2 x i32> %y) {
330342
; CHECK-LABEL: @fshr_constant_shift_amount_modulo_bitwidth_vec(
331-
; CHECK-NEXT: [[R:%.*]] = call <2 x i32> @llvm.fshr.v2i32(<2 x i32> [[X:%.*]], <2 x i32> [[Y:%.*]], <2 x i32> <i32 34, i32 -1>)
343+
; CHECK-NEXT: [[R:%.*]] = call <2 x i32> @llvm.fshr.v2i32(<2 x i32> [[X:%.*]], <2 x i32> [[Y:%.*]], <2 x i32> <i32 2, i32 31>)
332344
; CHECK-NEXT: ret <2 x i32> [[R]]
333345
;
334346
%r = call <2 x i32> @llvm.fshr.v2i32(<2 x i32> %x, <2 x i32> %y, <2 x i32> <i32 34, i32 -1>)
@@ -373,17 +385,28 @@ define <2 x i32> @fshr_constant_shift_amount_modulo_bitwidth_vec(<2 x i32> %x, <
373385

374386
define <2 x i31> @fshl_constant_shift_amount_modulo_bitwidth_vec(<2 x i31> %x, <2 x i31> %y) {
375387
; CHECK-LABEL: @fshl_constant_shift_amount_modulo_bitwidth_vec(
376-
; CHECK-NEXT: [[R:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> <i31 34, i31 -1>)
388+
; CHECK-NEXT: [[R:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> <i31 3, i31 1>)
377389
; CHECK-NEXT: ret <2 x i31> [[R]]
378390
;
379391
%r = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> %x, <2 x i31> %y, <2 x i31> <i31 34, i31 -1>)
380392
ret <2 x i31> %r
381393
}
382394

383-
; The shift modulo bitwidth is the same for all vector elements, but this is not simplified yet.
395+
define <2 x i31> @fshl_constant_shift_amount_modulo_bitwidth_vec_const_expr(<2 x i31> %x, <2 x i31> %y) {
396+
; CHECK-LABEL: @fshl_constant_shift_amount_modulo_bitwidth_vec_const_expr(
397+
; CHECK-NEXT: [[R:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> <i31 34, i31 ptrtoint (i8* @external_global to i31)>)
398+
; CHECK-NEXT: ret <2 x i31> [[R]]
399+
;
400+
%shamt = ptrtoint i8* @external_global to i31
401+
%r = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> %x, <2 x i31> %y, <2 x i31> <i31 34, i31 ptrtoint (i8* @external_global to i31)>)
402+
ret <2 x i31> %r
403+
}
404+
405+
; The shift modulo bitwidth is the same for all vector elements.
406+
384407
define <2 x i31> @fshl_only_op1_demanded_vec_nonsplat(<2 x i31> %x, <2 x i31> %y) {
385408
; CHECK-LABEL: @fshl_only_op1_demanded_vec_nonsplat(
386-
; CHECK-NEXT: [[Z:%.*]] = call <2 x i31> @llvm.fshl.v2i31(<2 x i31> [[X:%.*]], <2 x i31> [[Y:%.*]], <2 x i31> <i31 7, i31 38>)
409+
; CHECK-NEXT: [[Z:%.*]] = lshr <2 x i31> [[Y:%.*]], <i31 24, i31 24>
387410
; CHECK-NEXT: [[R:%.*]] = and <2 x i31> [[Z]], <i31 63, i31 31>
388411
; CHECK-NEXT: ret <2 x i31> [[R]]
389412
;

0 commit comments

Comments
 (0)