Skip to content

Commit 274ac9d

Browse files
committed
[AArch64][SVE] Lowering sve.dot to DOT node
Differential Revision: https://reviews.llvm.org/D99699
1 parent ab3c5fb commit 274ac9d

File tree

4 files changed

+71
-6
lines changed

4 files changed

+71
-6
lines changed

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ bool ISD::isConstantSplatVector(const SDNode *N, APInt &SplatVal) {
145145
if (auto *Op0 = dyn_cast<ConstantSDNode>(N->getOperand(0))) {
146146
SplatVal = Op0->getAPIntValue().truncOrSelf(EltSize);
147147
return true;
148+
} else if (auto *Op0 = dyn_cast<ConstantFPSDNode>(N->getOperand(0))) {
149+
SplatVal = Op0->getValueAPF().bitcastToAPInt().truncOrSelf(EltSize);
150+
return true;
148151
}
149152
}
150153

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2153,6 +2153,24 @@ MachineBasicBlock *AArch64TargetLowering::EmitInstrWithCustomInserter(
21532153
// Lowering Code
21542154
//===----------------------------------------------------------------------===//
21552155

2156+
/// isZerosVector - Check whether SDNode N is a zero-filled vector.
2157+
static bool isZerosVector(const SDNode *N) {
2158+
// Look through a bit convert.
2159+
while (N->getOpcode() == ISD::BITCAST)
2160+
N = N->getOperand(0).getNode();
2161+
2162+
if (ISD::isConstantSplatVectorAllZeros(N))
2163+
return true;
2164+
2165+
if (N->getOpcode() != AArch64ISD::DUP)
2166+
return false;
2167+
2168+
auto Opnd0 = N->getOperand(0);
2169+
auto *CINT = dyn_cast<ConstantSDNode>(Opnd0);
2170+
auto *CFP = dyn_cast<ConstantFPSDNode>(Opnd0);
2171+
return (CINT && CINT->isNullValue()) || (CFP && CFP->isZero());
2172+
}
2173+
21562174
/// changeIntCCToAArch64CC - Convert a DAG integer condition code to an AArch64
21572175
/// CC
21582176
static AArch64CC::CondCode changeIntCCToAArch64CC(ISD::CondCode CC) {
@@ -3924,9 +3942,13 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
39243942
Op.getOperand(2));
39253943
}
39263944
case Intrinsic::aarch64_neon_sdot:
3927-
case Intrinsic::aarch64_neon_udot: {
3928-
unsigned Opcode = IntNo == Intrinsic::aarch64_neon_udot ? AArch64ISD::UDOT
3929-
: AArch64ISD::SDOT;
3945+
case Intrinsic::aarch64_neon_udot:
3946+
case Intrinsic::aarch64_sve_sdot:
3947+
case Intrinsic::aarch64_sve_udot: {
3948+
unsigned Opcode = (IntNo == Intrinsic::aarch64_neon_udot ||
3949+
IntNo == Intrinsic::aarch64_sve_udot)
3950+
? AArch64ISD::UDOT
3951+
: AArch64ISD::SDOT;
39303952
return DAG.getNode(Opcode, dl, Op.getValueType(), Op.getOperand(1),
39313953
Op.getOperand(2), Op.getOperand(3));
39323954
}
@@ -13340,7 +13362,7 @@ static SDValue performAddDotCombine(SDNode *N, SelectionDAG &DAG) {
1334013362
auto isZeroDot = [](SDValue Dot) {
1334113363
return (Dot.getOpcode() == AArch64ISD::UDOT ||
1334213364
Dot.getOpcode() == AArch64ISD::SDOT) &&
13343-
ISD::isBuildVectorAllZeros(Dot.getOperand(0).getNode());
13365+
isZerosVector(Dot.getOperand(0).getNode());
1334413366
};
1334513367
if (!isZeroDot(Dot))
1334613368
std::swap(Dot, A);

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,8 @@ let Predicates = [HasSVE] in {
353353
defm SDIV_ZPZZ : sve_int_bin_pred_sd<AArch64sdiv_p>;
354354
defm UDIV_ZPZZ : sve_int_bin_pred_sd<AArch64udiv_p>;
355355

356-
defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", int_aarch64_sve_sdot>;
357-
defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", int_aarch64_sve_udot>;
356+
defm SDOT_ZZZ : sve_intx_dot<0b0, "sdot", AArch64sdot>;
357+
defm UDOT_ZZZ : sve_intx_dot<0b1, "udot", AArch64udot>;
358358

359359
defm SDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b0, "sdot", int_aarch64_sve_sdot_lane>;
360360
defm UDOT_ZZZI : sve_intx_dot_by_indexed_elem<0b1, "udot", int_aarch64_sve_udot_lane>;

llvm/test/CodeGen/AArch64/sve-intrinsics-int-arith.ll

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,26 @@ define <vscale x 2 x i64> @sdot_i64(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b
114114
ret <vscale x 2 x i64> %out
115115
}
116116

117+
; Signed dot product (i16 inputs, i64 result) called with a zeroinitializer
; accumulator; its result is then added to %a. The expected output is one
; sdot instruction immediately followed by ret.
define <vscale x 2 x i64> @test_sdot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
; CHECK-LABEL: test_sdot_i64_zero:
; CHECK: sdot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
entry:
  %dot = call <vscale x 2 x i64> @llvm.aarch64.sve.sdot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
  %sum = add <vscale x 2 x i64> %dot, %a
  ret <vscale x 2 x i64> %sum
}
126+
127+
; Signed dot product (i8 inputs, i32 result) called with a zeroinitializer
; accumulator; its result is then added to %a. The expected output is one
; sdot instruction immediately followed by ret.
define <vscale x 4 x i32> @test_sdot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
; CHECK-LABEL: test_sdot_i32_zero:
; CHECK: sdot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
entry:
  %dot = call <vscale x 4 x i32> @llvm.aarch64.sve.sdot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
  %sum = add <vscale x 4 x i32> %dot, %a
  ret <vscale x 4 x i32> %sum
}
136+
117137
; SDOT (Indexed)
118138

119139
define <vscale x 4 x i32> @sdot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
@@ -236,6 +256,26 @@ define <vscale x 2 x i64> @udot_i64(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b
236256
ret <vscale x 2 x i64> %out
237257
}
238258

259+
; Unsigned dot product (i16 inputs, i64 result) called with a zeroinitializer
; accumulator; its result is then added to %a. The expected output is one
; udot instruction immediately followed by ret.
define <vscale x 2 x i64> @test_udot_i64_zero(<vscale x 2 x i64> %a, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c) {
; CHECK-LABEL: test_udot_i64_zero:
; CHECK: udot z0.d, z1.h, z2.h
; CHECK-NEXT: ret
entry:
  %dot = call <vscale x 2 x i64> @llvm.aarch64.sve.udot.nxv2i64(<vscale x 2 x i64> zeroinitializer, <vscale x 8 x i16> %b, <vscale x 8 x i16> %c)
  %sum = add <vscale x 2 x i64> %dot, %a
  ret <vscale x 2 x i64> %sum
}
268+
269+
; Unsigned dot product (i8 inputs, i32 result) called with a zeroinitializer
; accumulator; its result is then added to %a. The expected output is one
; udot instruction immediately followed by ret.
define <vscale x 4 x i32> @test_udot_i32_zero(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {
; CHECK-LABEL: test_udot_i32_zero:
; CHECK: udot z0.s, z1.b, z2.b
; CHECK-NEXT: ret
entry:
  %dot = call <vscale x 4 x i32> @llvm.aarch64.sve.udot.nxv4i32(<vscale x 4 x i32> zeroinitializer, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c)
  %sum = add <vscale x 4 x i32> %dot, %a
  ret <vscale x 4 x i32> %sum
}
278+
239279
; UDOT (Indexed)
240280

241281
define <vscale x 4 x i32> @udot_lane_i32(<vscale x 4 x i32> %a, <vscale x 16 x i8> %b, <vscale x 16 x i8> %c) {

0 commit comments

Comments
 (0)