[RISCV] Use vwadd.vx for splat vector with extension (llvm#87249)
This patch allows `combineBinOp_VLToVWBinOp_VL` to handle patterns like
`(splat_vector (sext op))` or `(splat_vector (zext op))`. Then we can
use `vwadd.vx` and `vwadd.w` for such cases.

### Source code
```
define <vscale x 8 x i64> @vwadd_vx_splat_sext(<vscale x 8 x i32> %va, i32 %b) {
     %sb = sext i32 %b to i64
     %head = insertelement <vscale x 8 x i64> poison, i64 %sb, i32 0
     %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
     %vc = sext <vscale x 8 x i32> %va to <vscale x 8 x i64>
     %ve = add <vscale x 8 x i64> %vc, %splat
     ret <vscale x 8 x i64> %ve
}
```

### Before this patch
[Compiler Explorer](https://godbolt.org/z/sq191PsT4)
```
vwadd_vx_splat_sext:
  sext.w a0, a0
  vsetvli a1, zero, e64, m8, ta, ma
  vmv.v.x v16, a0
  vsetvli zero, zero, e32, m4, ta, ma
  vwadd.wv v16, v16, v8
  vmv8r.v v8, v16
  ret
```
### After this patch
```
vwadd_vx_splat_sext:
  vsetvli a1, zero, e32, m4, ta, ma
  vwadd.vx v16, v8, a0
  vmv8r.v v8, v16
  ret
```
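
The zero-extended analogue (not part of the original message) should combine the same way; a minimal IR sketch, assuming the unsigned variant lowers to `vwaddu.vx`:
```
define <vscale x 8 x i64> @vwaddu_vx_splat_zext(<vscale x 8 x i32> %va, i32 %b) {
  %zb = zext i32 %b to i64
  %head = insertelement <vscale x 8 x i64> poison, i64 %zb, i32 0
  %splat = shufflevector <vscale x 8 x i64> %head, <vscale x 8 x i64> poison, <vscale x 8 x i32> zeroinitializer
  %vc = zext <vscale x 8 x i32> %va to <vscale x 8 x i64>
  %ve = add <vscale x 8 x i64> %vc, %splat
  ret <vscale x 8 x i64> %ve
}
```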
sun-jacobi committed Apr 10, 2024
1 parent 313a33b commit 469caa3
Showing 5 changed files with 569 additions and 235 deletions.
llvm/lib/Target/RISCV/RISCVISelLowering.cpp (85 changes: 48 additions & 37 deletions)
```
@@ -13597,7 +13597,8 @@ struct NodeExtensionHelper {
 
   /// Check if this instance represents a splat.
   bool isSplat() const {
-    return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL;
+    return OrigOperand.getOpcode() == RISCVISD::VMV_V_X_VL ||
+           OrigOperand.getOpcode() == ISD::SPLAT_VECTOR;
   }
 
   /// Get the extended opcode.
```
```
@@ -13641,6 +13642,8 @@ struct NodeExtensionHelper {
     case RISCVISD::VZEXT_VL:
     case RISCVISD::FP_EXTEND_VL:
       return DAG.getNode(ExtOpc, DL, NarrowVT, Source, Mask, VL);
+    case ISD::SPLAT_VECTOR:
+      return DAG.getSplat(NarrowVT, DL, Source.getOperand(0));
     case RISCVISD::VMV_V_X_VL:
       return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, NarrowVT,
                          DAG.getUNDEF(NarrowVT), Source.getOperand(1), VL);
```
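
The new `ISD::SPLAT_VECTOR` case narrows a splat by re-splatting the original scalar at the narrow type. A standalone sketch with plain C++ arrays (an illustration, not the SelectionDAG API) of why this round-trips when the scalar's value fits the narrow element type:
```
#include <array>
#include <cassert>
#include <cstdint>

int main() {
  int64_t B = -123; // scalar whose value fits in i32 (the SupportsSExt case)

  // Narrowing the splat: truncating each lane of splat(B) is the same as
  // splatting the truncated scalar, which is all the narrow splat needs.
  std::array<int64_t, 4> Wide;
  Wide.fill(B);
  std::array<int32_t, 4> Narrow;
  Narrow.fill(static_cast<int32_t>(B));
  for (int I = 0; I < 4; ++I)
    assert(static_cast<int32_t>(Wide[I]) == Narrow[I]);

  // Because B fits in i32, sign-extending the narrow splat back to i64
  // recovers the original wide splat, so a widening op that sign-extends
  // its operand (vwadd.vx) computes the same result.
  for (int I = 0; I < 4; ++I)
    assert(static_cast<int64_t>(Narrow[I]) == Wide[I]);
}
```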
```
@@ -13776,6 +13779,47 @@ struct NodeExtensionHelper {
   /// Check if this node needs to be fully folded or extended for all users.
   bool needToPromoteOtherUsers() const { return EnforceOneUse; }
 
+  void fillUpExtensionSupportForSplat(SDNode *Root, SelectionDAG &DAG,
+                                      const RISCVSubtarget &Subtarget) {
+    unsigned Opc = OrigOperand.getOpcode();
+    MVT VT = OrigOperand.getSimpleValueType();
+
+    assert((Opc == ISD::SPLAT_VECTOR || Opc == RISCVISD::VMV_V_X_VL) &&
+           "Unexpected Opcode");
+
+    // The passthru must be undef for tail agnostic.
+    if (Opc == RISCVISD::VMV_V_X_VL && !OrigOperand.getOperand(0).isUndef())
+      return;
+
+    // Get the scalar value.
+    SDValue Op = Opc == ISD::SPLAT_VECTOR ? OrigOperand.getOperand(0)
+                                          : OrigOperand.getOperand(1);
+
+    // See if we have enough sign bits or zero bits in the scalar to use a
+    // widening opcode by splatting to smaller element size.
+    unsigned EltBits = VT.getScalarSizeInBits();
+    unsigned ScalarBits = Op.getValueSizeInBits();
+    // Make sure we're getting all element bits from the scalar register.
+    // FIXME: Support implicit sign extension of vmv.v.x?
+    if (ScalarBits < EltBits)
+      return;
+
+    unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
+    // If the narrow type cannot be expressed with a legal VMV,
+    // this is not a valid candidate.
+    if (NarrowSize < 8)
+      return;
+
+    if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
+      SupportsSExt = true;
+
+    if (DAG.MaskedValueIsZero(Op,
+                              APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
+      SupportsZExt = true;
+
+    EnforceOneUse = false;
+  }
+
   /// Helper method to set the various fields of this struct based on the
   /// type of \p Root.
   void fillUpExtensionSupport(SDNode *Root, SelectionDAG &DAG,
```
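
For intuition about the two checks at the end of `fillUpExtensionSupportForSplat`: `ComputeMaxSignificantBits(Op) <= NarrowSize` says the scalar survives truncate-then-sign-extend, and `MaskedValueIsZero` over `APInt::getBitsSetFrom(ScalarBits, NarrowSize)` says it survives truncate-then-zero-extend. A standalone sketch with plain integers (assumed helper names, not the LLVM API):
```
#include <cassert>
#include <cstdint>

// Holds iff V is unchanged by truncation to NarrowSize bits followed by
// sign extension -- the property ComputeMaxSignificantBits establishes.
// Assumes 0 < NarrowSize < 64.
bool fitsSExt(int64_t V, unsigned NarrowSize) {
  int64_t Lo = -(INT64_C(1) << (NarrowSize - 1));
  int64_t Hi = (INT64_C(1) << (NarrowSize - 1)) - 1;
  return V >= Lo && V <= Hi;
}

// Holds iff all bits from NarrowSize upward are zero -- the property
// MaskedValueIsZero checks with getBitsSetFrom(ScalarBits, NarrowSize).
bool fitsZExt(uint64_t V, unsigned NarrowSize) {
  return (V >> NarrowSize) == 0;
}

int main() {
  assert(fitsSExt(-123, 32));  // sign-extendable: vwadd.vx is usable
  assert(!fitsZExt(UINT64_C(0xFFFFFFFFFFFFFF85), 32)); // high bits set
  assert(fitsZExt(123, 32));   // zero-extendable: vwaddu.vx is usable
}
```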
```
@@ -13814,43 +13858,10 @@ struct NodeExtensionHelper {
     case RISCVISD::FP_EXTEND_VL:
       SupportsFPExt = true;
       break;
-    case RISCVISD::VMV_V_X_VL: {
-      // Historically, we didn't care about splat values not disappearing during
-      // combines.
-      EnforceOneUse = false;
-
-      // The operand is a splat of a scalar.
-
-      // The passthru must be undef for tail agnostic.
-      if (!OrigOperand.getOperand(0).isUndef())
-        break;
-
-      // Get the scalar value.
-      SDValue Op = OrigOperand.getOperand(1);
-
-      // See if we have enough sign bits or zero bits in the scalar to use a
-      // widening opcode by splatting to smaller element size.
-      MVT VT = Root->getSimpleValueType(0);
-      unsigned EltBits = VT.getScalarSizeInBits();
-      unsigned ScalarBits = Op.getValueSizeInBits();
-      // Make sure we're getting all element bits from the scalar register.
-      // FIXME: Support implicit sign extension of vmv.v.x?
-      if (ScalarBits < EltBits)
-        break;
-
-      unsigned NarrowSize = VT.getScalarSizeInBits() / 2;
-      // If the narrow type cannot be expressed with a legal VMV,
-      // this is not a valid candidate.
-      if (NarrowSize < 8)
-        break;
-
-      if (DAG.ComputeMaxSignificantBits(Op) <= NarrowSize)
-        SupportsSExt = true;
-      if (DAG.MaskedValueIsZero(Op,
-                                APInt::getBitsSetFrom(ScalarBits, NarrowSize)))
-        SupportsZExt = true;
+    case ISD::SPLAT_VECTOR:
+    case RISCVISD::VMV_V_X_VL:
+      fillUpExtensionSupportForSplat(Root, DAG, Subtarget);
       break;
-    }
     default:
       break;
     }
```