Skip to content

[SelectionDAG][AArch64] Add dot product lowering in NEON for PARTIAL_REDUCE_*MLA ISD nodes #140075

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 50 additions & 13 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1451,8 +1451,16 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
for (MVT VT : { MVT::v32i8, MVT::v16i16, MVT::v8i32, MVT::v4i64 })
setOperationAction(ISD::ADD, VT, Custom);
// FADDP custom lowering
for (MVT VT : { MVT::v16f16, MVT::v8f32, MVT::v4f64 })
for (MVT VT : {MVT::v16f16, MVT::v8f32, MVT::v4f64})
setOperationAction(ISD::FADD, VT, Custom);

if (EnablePartialReduceNodes && Subtarget->hasDotProd()) {
setPartialReduceMLAAction(MVT::v2i64, MVT::v8i16, Legal);
setPartialReduceMLAAction(MVT::v4i32, MVT::v16i8, Legal);
setPartialReduceMLAAction(MVT::v2i32, MVT::v8i8, Legal);
setPartialReduceMLAAction(MVT::v2i64, MVT::v16i8, Custom);
}

} else /* !isNeonAvailable */ {
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
for (unsigned Op = 0; Op < ISD::BUILTIN_OP_END; ++Op)
Expand Down Expand Up @@ -27569,6 +27577,12 @@ void AArch64TargetLowering::ReplaceNodeResults(
if (SDValue Res = LowerVECTOR_COMPRESS(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
case ISD::PARTIAL_REDUCE_UMLA:
case ISD::PARTIAL_REDUCE_SMLA: {
if (SDValue Res = LowerPARTIAL_REDUCE_MLA(SDValue(N, 0), DAG))
Results.push_back(Res);
return;
}
case ISD::ADD:
case ISD::FADD:
ReplaceAddWithADDP(N, Results, DAG, Subtarget);
Expand Down Expand Up @@ -29518,37 +29532,60 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
}

/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
/// of (nx)v2i64/(nx)v16i8, we cannot directly lower it to a (u|s)dot. We can
/// however still make use of the dot product instruction by instead
/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
/// accumulating over two steps: (nx)v16i8 -> (nx)v4i32 -> (nx)v2i64.
/// If available, make use of the (U|S)ADDW(B|T) instructions, otherwise
/// the following pattern is emitted:
/// add(add(Acc, ext(EXTRACT_SUBVECTOR(N, 0)), ext(EXTRACT_SUBVECTOR(N,
/// NTy/2))))
SDValue
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
SelectionDAG &DAG) const {
bool Scalable = Op.getValueType().isScalableVector();

assert((!Scalable || Subtarget->isSVEorStreamingSVEAvailable()) &&
"SVE or StreamingSVE must be available when using scalable vectors.");
assert(
(Scalable || (Subtarget->isNeonAvailable() || Subtarget->hasDotProd())) &&
"Neon or dotprod must be available when using fixed-width vectors.");

SDLoc DL(Op);

SDValue Acc = Op.getOperand(0);
SDValue LHS = Op.getOperand(1);
SDValue RHS = Op.getOperand(2);
EVT ResultVT = Op.getValueType();
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);

SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
assert((Scalable && ResultVT == MVT::nxv2i64 &&
LHS.getValueType() == MVT::nxv16i8) ||
(!Scalable && ResultVT == MVT::v2i64 &&
LHS.getValueType() == MVT::v16i8));

EVT DotVT = Scalable ? MVT::nxv4i32 : MVT::v4i32;
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, DotVT,
DAG.getConstant(0, DL, DotVT), LHS, RHS);

bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
if (Scalable &&
(Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable())) {
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
}

unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
// Fold (nx)v4i32 into (nx)v2i64
auto [DotNodeLo, DotNodeHi] = DAG.SplitVector(DotNode, DL);
if (IsUnsigned) {
DotNodeLo = DAG.getZExtOrTrunc(DotNodeLo, DL, ResultVT);
DotNodeHi = DAG.getZExtOrTrunc(DotNodeHi, DL, ResultVT);
} else {
DotNodeLo = DAG.getSExtOrTrunc(DotNodeLo, DL, ResultVT);
DotNodeHi = DAG.getSExtOrTrunc(DotNodeHi, DL, ResultVT);
}
auto Lo = DAG.getNode(ISD::ADD, DL, ResultVT, Acc, DotNodeLo);
return DAG.getNode(ISD::ADD, DL, ResultVT, Lo, DotNodeHi);
}

SDValue
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -1474,6 +1474,17 @@ defm SDOTlane : SIMDThreeSameVectorDotIndex<0, 0, 0b10, "sdot", AArch64sdot>;
defm UDOTlane : SIMDThreeSameVectorDotIndex<1, 0, 0b10, "udot", AArch64udot>;
}

let Predicates = [HasNEON, HasDotProd] in {
def : Pat<(v4i32 (partial_reduce_umla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
(v4i32 (UDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
def : Pat<(v4i32 (partial_reduce_smla (v4i32 V128:$Acc), (v16i8 V128:$MulLHS), (v16i8 V128:$MulRHS))),
(v4i32 (SDOTv16i8 V128:$Acc, V128:$MulLHS, V128:$MulRHS))>;
def : Pat<(v2i32 (partial_reduce_umla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
(v2i32 (UDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
def : Pat<(v2i32 (partial_reduce_smla (v2i32 V64:$Acc), (v8i8 V64:$MulLHS), (v8i8 V64:$MulRHS))),
(v2i32 (SDOTv8i8 V64:$Acc, V64:$MulLHS, V64:$MulRHS))>;
} // End HasNEON, HasDotProd

// ARMv8.6-A BFloat
let Predicates = [HasNEON, HasBF16] in {
defm BFDOT : SIMDThreeSameVectorBFDot<1, "bfdot">;
Expand Down
Loading