Skip to content

Commit

Permalink
ISel: introduce vector ISD::LRINT, ISD::LLRINT; custom RISCV lowering (
Browse files Browse the repository at this point in the history
…llvm#66924)

The issue llvm#55208 noticed that std::rint is vectorized by the
SLPVectorizer, but a very similar function, std::lrint, is not.
std::lrint corresponds to ISD::LRINT in the SelectionDAG, and
std::llrint is a familiar cousin corresponding to ISD::LLRINT. Now,
neither ISD::LRINT nor ISD::LLRINT has a corresponding vector variant,
and the LangRef makes this clear in the documentation of llvm.lrint.*
and llvm.llrint.*.

This patch extends the LangRef to include vector variants of
llvm.lrint.* and llvm.llrint.*, and lays the necessary ground-work of
scalarizing it for all targets. However, this patch would be devoid of
motivation unless we show the utility of these new vector variants.
Hence, the RISCV target has been chosen to implement a custom lowering
to the vfcvt.x.f.v instruction. The patch also includes a CostModel for
RISCV, and a trivial follow-up can potentially enable the SLPVectorizer
to vectorize std::lrint and std::llrint, fixing llvm#55208.

The patch includes tests, obviously for the RISCV target, but also for
the X86, AArch64, and PowerPC targets to justify the addition of the
vector variants to the LangRef.
  • Loading branch information
artagnon committed Oct 19, 2023
1 parent 3d7802d commit 98c90a1
Show file tree
Hide file tree
Showing 21 changed files with 12,200 additions and 15 deletions.
6 changes: 4 additions & 2 deletions llvm/docs/LangRef.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15760,7 +15760,8 @@ Syntax:
"""""""

This is an overloaded intrinsic. You can use ``llvm.lrint`` on any
floating-point type. Not all targets support all types however.
floating-point type or vector of floating-point type. Not all targets
support all types however.

::

Expand Down Expand Up @@ -15804,7 +15805,8 @@ Syntax:
"""""""

This is an overloaded intrinsic. You can use ``llvm.llrint`` on any
floating-point type. Not all targets support all types however.
floating-point type or vector of floating-point type. Not all targets
support all types however.

::

Expand Down
6 changes: 6 additions & 0 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
case Intrinsic::rint:
ISD = ISD::FRINT;
break;
case Intrinsic::lrint:
ISD = ISD::LRINT;
break;
case Intrinsic::llrint:
ISD = ISD::LLRINT;
break;
case Intrinsic::round:
ISD = ISD::FROUND;
break;
Expand Down
20 changes: 20 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,7 @@ namespace {
SDValue visitUINT_TO_FP(SDNode *N);
SDValue visitFP_TO_SINT(SDNode *N);
SDValue visitFP_TO_UINT(SDNode *N);
SDValue visitXRINT(SDNode *N);
SDValue visitFP_ROUND(SDNode *N);
SDValue visitFP_EXTEND(SDNode *N);
SDValue visitFNEG(SDNode *N);
Expand Down Expand Up @@ -1911,6 +1912,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) {
}

SDValue DAGCombiner::visit(SDNode *N) {
// clang-format off
switch (N->getOpcode()) {
default: break;
case ISD::TokenFactor: return visitTokenFactor(N);
Expand Down Expand Up @@ -2011,6 +2013,8 @@ SDValue DAGCombiner::visit(SDNode *N) {
case ISD::UINT_TO_FP: return visitUINT_TO_FP(N);
case ISD::FP_TO_SINT: return visitFP_TO_SINT(N);
case ISD::FP_TO_UINT: return visitFP_TO_UINT(N);
case ISD::LRINT:
case ISD::LLRINT: return visitXRINT(N);
case ISD::FP_ROUND: return visitFP_ROUND(N);
case ISD::FP_EXTEND: return visitFP_EXTEND(N);
case ISD::FNEG: return visitFNEG(N);
Expand Down Expand Up @@ -2065,6 +2069,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
#include "llvm/IR/VPIntrinsics.def"
return visitVPOp(N);
}
// clang-format on
return SDValue();
}

Expand Down Expand Up @@ -17480,6 +17485,21 @@ SDValue DAGCombiner::visitFP_TO_UINT(SDNode *N) {
return FoldIntToFPToInt(N, DAG);
}

/// Combine an ISD::LRINT or ISD::LLRINT node.
///
/// Returns a replacement value when one of the folds below applies, or an
/// empty SDValue to signal that nothing changed.
SDValue DAGCombiner::visitXRINT(SDNode *N) {
  SDValue Src = N->getOperand(0);
  EVT ResultVT = N->getValueType(0);

  // (lrint|llrint undef) -> undef
  if (Src.isUndef())
    return DAG.getUNDEF(ResultVT);

  // Nothing further to do unless the operand is an FP constant (scalar or
  // build vector).
  if (!DAG.isConstantFPBuildVectorOrConstantFP(Src))
    return SDValue();

  // (lrint|llrint c1fp) -> c1
  // The node is re-requested from the DAG with the same opcode and operand;
  // presumably getNode is where the actual constant folding happens -- confirm
  // against SelectionDAG::getNode.
  return DAG.getNode(N->getOpcode(), SDLoc(N), ResultVT, Src);
}

SDValue DAGCombiner::visitFP_ROUND(SDNode *N) {
SDValue N0 = N->getOperand(0);
SDValue N1 = N->getOperand(1);
Expand Down
8 changes: 6 additions & 2 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2198,6 +2198,7 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
// to use the promoted float operand. Nodes that produce at least one
// promotion-requiring floating point result have their operands legalized as
// a part of PromoteFloatResult.
// clang-format off
switch (N->getOpcode()) {
default:
#ifndef NDEBUG
Expand All @@ -2209,7 +2210,9 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
case ISD::BITCAST: R = PromoteFloatOp_BITCAST(N, OpNo); break;
case ISD::FCOPYSIGN: R = PromoteFloatOp_FCOPYSIGN(N, OpNo); break;
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT: R = PromoteFloatOp_FP_TO_XINT(N, OpNo); break;
case ISD::FP_TO_UINT:
case ISD::LRINT:
case ISD::LLRINT: R = PromoteFloatOp_UnaryOp(N, OpNo); break;
case ISD::FP_TO_SINT_SAT:
case ISD::FP_TO_UINT_SAT:
R = PromoteFloatOp_FP_TO_XINT_SAT(N, OpNo); break;
Expand All @@ -2218,6 +2221,7 @@ bool DAGTypeLegalizer::PromoteFloatOperand(SDNode *N, unsigned OpNo) {
case ISD::SETCC: R = PromoteFloatOp_SETCC(N, OpNo); break;
case ISD::STORE: R = PromoteFloatOp_STORE(N, OpNo); break;
}
// clang-format on

if (R.getNode())
ReplaceValueWith(SDValue(N, 0), R);
Expand Down Expand Up @@ -2251,7 +2255,7 @@ SDValue DAGTypeLegalizer::PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo) {
}

// Rebuild the node on the promoted float operand, keeping its opcode and its
// (integer) result type. Shared by FP_TO_SINT, FP_TO_UINT, LRINT and LLRINT.
SDValue DAGTypeLegalizer::PromoteFloatOp_FP_TO_XINT(SDNode *N, unsigned OpNo) {
SDValue DAGTypeLegalizer::PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo) {
SDValue Op = GetPromotedFloat(N->getOperand(0));
return DAG.getNode(N->getOpcode(), SDLoc(N), N->getValueType(0), Op);
}
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteFloatOp_BITCAST(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_FCOPYSIGN(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_FP_EXTEND(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_FP_TO_XINT(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_UnaryOp(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_FP_TO_XINT_SAT(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_STORE(SDNode *N, unsigned OpNo);
SDValue PromoteFloatOp_SELECT_CC(SDNode *N, unsigned OpNo);
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,8 @@ SDValue VectorLegalizer::LegalizeOp(SDValue Op) {
case ISD::FCEIL:
case ISD::FTRUNC:
case ISD::FRINT:
case ISD::LRINT:
case ISD::LLRINT:
case ISD::FNEARBYINT:
case ISD::FROUND:
case ISD::FROUNDEVEN:
Expand Down
16 changes: 15 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
case ISD::FP_TO_SINT:
case ISD::FP_TO_UINT:
case ISD::FRINT:
case ISD::LRINT:
case ISD::LLRINT:
case ISD::FROUND:
case ISD::FROUNDEVEN:
case ISD::FSIN:
Expand Down Expand Up @@ -681,6 +683,8 @@ bool DAGTypeLegalizer::ScalarizeVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::FP_TO_UINT:
case ISD::SINT_TO_FP:
case ISD::UINT_TO_FP:
case ISD::LRINT:
case ISD::LLRINT:
Res = ScalarizeVecOp_UnaryOp(N);
break;
case ISD::STRICT_SINT_TO_FP:
Expand Down Expand Up @@ -1097,6 +1101,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::VP_FP_TO_UINT:
case ISD::FRINT:
case ISD::VP_FRINT:
case ISD::LRINT:
case ISD::LLRINT:
case ISD::FROUND:
case ISD::VP_FROUND:
case ISD::FROUNDEVEN:
Expand Down Expand Up @@ -2974,6 +2980,8 @@ bool DAGTypeLegalizer::SplitVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::ZERO_EXTEND:
case ISD::ANY_EXTEND:
case ISD::FTRUNC:
case ISD::LRINT:
case ISD::LLRINT:
Res = SplitVecOp_UnaryOp(N);
break;
case ISD::FLDEXP:
Expand Down Expand Up @@ -4209,6 +4217,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
case ISD::FLOG2:
case ISD::FNEARBYINT:
case ISD::FRINT:
case ISD::LRINT:
case ISD::LLRINT:
case ISD::FROUND:
case ISD::FROUNDEVEN:
case ISD::FSIN:
Expand Down Expand Up @@ -5958,7 +5968,11 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
case ISD::STRICT_FSETCCS: Res = WidenVecOp_STRICT_FSETCC(N); break;
case ISD::VSELECT: Res = WidenVecOp_VSELECT(N); break;
case ISD::FLDEXP:
case ISD::FCOPYSIGN: Res = WidenVecOp_UnrollVectorOp(N); break;
case ISD::FCOPYSIGN:
case ISD::LRINT:
case ISD::LLRINT:
Res = WidenVecOp_UnrollVectorOp(N);
break;
case ISD::IS_FPCLASS: Res = WidenVecOp_IS_FPCLASS(N); break;

case ISD::ANY_EXTEND:
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5135,6 +5135,8 @@ bool SelectionDAG::isKnownNeverNaN(SDValue Op, bool SNaN, unsigned Depth) const
case ISD::FROUND:
case ISD::FROUNDEVEN:
case ISD::FRINT:
case ISD::LRINT:
case ISD::LLRINT:
case ISD::FNEARBYINT:
case ISD::FLDEXP: {
if (SNaN)
Expand Down
12 changes: 6 additions & 6 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -873,13 +873,13 @@ void TargetLoweringBase::initActions() {

// These operations default to expand for vector types.
if (VT.isVector())
setOperationAction({ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG,
ISD::ANY_EXTEND_VECTOR_INREG,
ISD::SIGN_EXTEND_VECTOR_INREG,
ISD::ZERO_EXTEND_VECTOR_INREG, ISD::SPLAT_VECTOR},
VT, Expand);
setOperationAction(
{ISD::FCOPYSIGN, ISD::SIGN_EXTEND_INREG, ISD::ANY_EXTEND_VECTOR_INREG,
ISD::SIGN_EXTEND_VECTOR_INREG, ISD::ZERO_EXTEND_VECTOR_INREG,
ISD::SPLAT_VECTOR, ISD::LRINT, ISD::LLRINT},
VT, Expand);

// Constrained floating-point operations default to expand.
// Constrained floating-point operations default to expand.
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \
setOperationAction(ISD::STRICT_##DAGN, VT, Expand);
#include "llvm/IR/ConstrainedOps.def"
Expand Down
22 changes: 20 additions & 2 deletions llvm/lib/IR/Verifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5669,10 +5669,28 @@ void Verifier::visitIntrinsicCall(Intrinsic::ID ID, CallBase &Call) {
}
break;
}
case Intrinsic::lround:
case Intrinsic::llround:
case Intrinsic::lrint:
case Intrinsic::llrint: {
Type *ValTy = Call.getArgOperand(0)->getType();
Type *ResultTy = Call.getType();
Check(
ValTy->isFPOrFPVectorTy() && ResultTy->isIntOrIntVectorTy(),
"llvm.lrint, llvm.llrint: argument must be floating-point or vector "
"of floating-points, and result must be integer or vector of integers",
&Call);
Check(ValTy->isVectorTy() == ResultTy->isVectorTy(),
"llvm.lrint, llvm.llrint: argument and result disagree on vector use",
&Call);
if (ValTy->isVectorTy()) {
Check(cast<VectorType>(ValTy)->getElementCount() ==
cast<VectorType>(ResultTy)->getElementCount(),
"llvm.lrint, llvm.llrint: argument must be same length as result",
&Call);
}
break;
}
case Intrinsic::lround:
case Intrinsic::llround: {
Type *ValTy = Call.getArgOperand(0)->getType();
Type *ResultTy = Call.getType();
Check(!ValTy->isVectorTy() && !ResultTy->isVectorTy(),
Expand Down
30 changes: 29 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
VT, Custom);
setOperationAction({ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT}, VT,
Custom);

setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
setOperationAction(
{ISD::SADDSAT, ISD::UADDSAT, ISD::SSUBSAT, ISD::USUBSAT}, VT, Legal);

Expand Down Expand Up @@ -2950,6 +2950,31 @@ lowerFTRUNC_FCEIL_FFLOOR_FROUND(SDValue Op, SelectionDAG &DAG,
DAG.getTargetConstant(FRM, DL, Subtarget.getXLenVT()));
}

// Lower vector LRINT and LLRINT by converting to the integer domain.
//
// Op is the ISD::LRINT / ISD::LLRINT node being lowered: its value type is the
// vector result type and operand 0 is the source vector. The conversion is
// emitted as a single RISCVISD::VFCVT_X_F_VL node with the default mask and
// VL. Fixed-length vectors are first moved into scalable containers, and the
// result is converted back to the fixed-length type afterwards.
static SDValue lowerVectorXRINT(SDValue Op, SelectionDAG &DAG,
                                const RISCVSubtarget &Subtarget) {
  MVT VT = Op.getSimpleValueType();
  assert(VT.isVector() && "Unexpected type");

  SDLoc DL(Op);
  SDValue Src = Op.getOperand(0);
  MVT ContainerVT = VT;

  if (VT.isFixedLengthVector()) {
    ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);

    // The source has a different (floating-point) element type than the
    // result, so it needs a container derived from its own type. Reusing the
    // result container here would hand convertToScalableVector a source whose
    // element type does not match the container it is being placed in.
    MVT SrcContainerVT = getContainerForFixedLengthVector(
        DAG, Src.getSimpleValueType(), Subtarget);
    assert(ContainerVT.getVectorElementCount() ==
               SrcContainerVT.getVectorElementCount() &&
           "Expected same element count");
    Src = convertToScalableVector(SrcContainerVT, Src, DAG, Subtarget);
  }

  // Mask and VL are derived from the result type and its container.
  auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
  SDValue Converted =
      DAG.getNode(RISCVISD::VFCVT_X_F_VL, DL, ContainerVT, Src, Mask, VL);

  // Scalable results are already in their final type.
  if (!VT.isFixedLengthVector())
    return Converted;

  return convertFromScalableVector(VT, Converted, DAG, Subtarget);
}

static SDValue
getVSlidedown(SelectionDAG &DAG, const RISCVSubtarget &Subtarget,
const SDLoc &DL, EVT VT, SDValue Merge, SDValue Op,
Expand Down Expand Up @@ -5978,6 +6003,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
case ISD::FROUND:
case ISD::FROUNDEVEN:
return lowerFTRUNC_FCEIL_FFLOOR_FROUND(Op, DAG, Subtarget);
case ISD::LRINT:
case ISD::LLRINT:
return lowerVectorXRINT(Op, DAG, Subtarget);
case ISD::VECREDUCE_ADD:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_SMAX:
Expand Down
27 changes: 27 additions & 0 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,31 @@ static const CostTblEntry VectorIntrinsicCostTable[]{
{Intrinsic::rint, MVT::nxv2f64, 7},
{Intrinsic::rint, MVT::nxv4f64, 7},
{Intrinsic::rint, MVT::nxv8f64, 7},
{Intrinsic::lrint, MVT::v2i32, 1},
{Intrinsic::lrint, MVT::v4i32, 1},
{Intrinsic::lrint, MVT::v8i32, 1},
{Intrinsic::lrint, MVT::v16i32, 1},
{Intrinsic::lrint, MVT::nxv1i32, 1},
{Intrinsic::lrint, MVT::nxv2i32, 1},
{Intrinsic::lrint, MVT::nxv4i32, 1},
{Intrinsic::lrint, MVT::nxv8i32, 1},
{Intrinsic::lrint, MVT::nxv16i32, 1},
{Intrinsic::lrint, MVT::v2i64, 1},
{Intrinsic::lrint, MVT::v4i64, 1},
{Intrinsic::lrint, MVT::v8i64, 1},
{Intrinsic::lrint, MVT::v16i64, 1},
{Intrinsic::lrint, MVT::nxv1i64, 1},
{Intrinsic::lrint, MVT::nxv2i64, 1},
{Intrinsic::lrint, MVT::nxv4i64, 1},
{Intrinsic::lrint, MVT::nxv8i64, 1},
{Intrinsic::llrint, MVT::v2i64, 1},
{Intrinsic::llrint, MVT::v4i64, 1},
{Intrinsic::llrint, MVT::v8i64, 1},
{Intrinsic::llrint, MVT::v16i64, 1},
{Intrinsic::llrint, MVT::nxv1i64, 1},
{Intrinsic::llrint, MVT::nxv2i64, 1},
{Intrinsic::llrint, MVT::nxv4i64, 1},
{Intrinsic::llrint, MVT::nxv8i64, 1},
{Intrinsic::nearbyint, MVT::v2f32, 9},
{Intrinsic::nearbyint, MVT::v4f32, 9},
{Intrinsic::nearbyint, MVT::v8f32, 9},
Expand Down Expand Up @@ -1051,6 +1076,8 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
case Intrinsic::floor:
case Intrinsic::trunc:
case Intrinsic::rint:
case Intrinsic::lrint:
case Intrinsic::llrint:
case Intrinsic::round:
case Intrinsic::roundeven: {
// These all use the same code.
Expand Down

0 comments on commit 98c90a1

Please sign in to comment.