Skip to content

Commit a93bfc4

Browse files
committed
[NVPTX] Add -nvptx-prec-divf32=3 to disable ftz for f32 fdiv
1 parent 656d9ba commit a93bfc4

File tree

4 files changed

+193
-91
lines changed

4 files changed

+193
-91
lines changed

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -258,6 +258,7 @@ enum class DivPrecisionLevel : unsigned {
258258
Approx = 0,
259259
Full = 1,
260260
IEEE754 = 2,
261+
IEEE754_NoFTZ = 3,
261262
};
262263

263264
} // namespace NVPTX

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 9 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -87,13 +87,15 @@ static cl::opt<unsigned> FMAContractLevelOpt(
8787

8888
static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32(
8989
"nvptx-prec-divf32", cl::Hidden,
90-
cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
91-
" IEEE Compliant F32 div.rnd if available."),
92-
cl::values(clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0",
93-
"Use div.approx"),
94-
clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
95-
clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2",
96-
"Use IEEE Compliant F32 div.rnd if available")),
90+
cl::desc(
91+
"NVPTX Specifies: Override the precision of the lowering for f32 fdiv"),
92+
cl::values(
93+
clEnumValN(NVPTX::DivPrecisionLevel::Approx, "0", "Use div.approx"),
94+
clEnumValN(NVPTX::DivPrecisionLevel::Full, "1", "Use div.full"),
95+
clEnumValN(NVPTX::DivPrecisionLevel::IEEE754, "2",
96+
"Use IEEE Compliant F32 div.rnd if available (default)"),
97+
clEnumValN(NVPTX::DivPrecisionLevel::IEEE754_NoFTZ, "3",
98+
"Use IEEE Compliant F32 div.rnd if available, no FTZ")),
9799
cl::init(NVPTX::DivPrecisionLevel::IEEE754));
98100

99101
static cl::opt<bool> UsePrecSqrtF32(

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 90 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -1222,20 +1222,20 @@ def BFNEG16x2 : FNEG_BF16_F16X2<"neg.bf16x2", v2bf16, Int32Regs, True>;
12221222
// F64 division
12231223
//
12241224
def FRCP64r :
1225-
NVPTXInst<(outs Float64Regs:$dst),
1226-
(ins Float64Regs:$b),
1227-
"rcp.rn.f64 \t$dst, $b;",
1228-
[(set f64:$dst, (fdiv f64imm_1, f64:$b))]>;
1225+
BasicNVPTXInst<(outs Float64Regs:$dst),
1226+
(ins Float64Regs:$b),
1227+
"rcp.rn.f64",
1228+
[(set f64:$dst, (fdiv f64imm_1, f64:$b))]>;
12291229
def FDIV64rr :
1230-
NVPTXInst<(outs Float64Regs:$dst),
1231-
(ins Float64Regs:$a, Float64Regs:$b),
1232-
"div.rn.f64 \t$dst, $a, $b;",
1233-
[(set f64:$dst, (fdiv f64:$a, f64:$b))]>;
1230+
BasicNVPTXInst<(outs Float64Regs:$dst),
1231+
(ins Float64Regs:$a, Float64Regs:$b),
1232+
"div.rn.f64",
1233+
[(set f64:$dst, (fdiv f64:$a, f64:$b))]>;
12341234
def FDIV64ri :
1235-
NVPTXInst<(outs Float64Regs:$dst),
1236-
(ins Float64Regs:$a, f64imm:$b),
1237-
"div.rn.f64 \t$dst, $a, $b;",
1238-
[(set f64:$dst, (fdiv f64:$a, fpimm:$b))]>;
1235+
BasicNVPTXInst<(outs Float64Regs:$dst),
1236+
(ins Float64Regs:$a, f64imm:$b),
1237+
"div.rn.f64",
1238+
[(set f64:$dst, (fdiv f64:$a, fpimm:$b))]>;
12391239

12401240
// fdiv will be converted to rcp
12411241
// fneg (fdiv 1.0, X) => fneg (rcp.rn X)
@@ -1253,42 +1253,42 @@ def fdiv_approx : PatFrag<(ops node:$a, node:$b),
12531253

12541254

12551255
def FRCP32_approx_r_ftz :
1256-
NVPTXInst<(outs Float32Regs:$dst),
1257-
(ins Float32Regs:$b),
1258-
"rcp.approx.ftz.f32 \t$dst, $b;",
1259-
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>,
1260-
Requires<[doF32FTZ]>;
1256+
BasicNVPTXInst<(outs Float32Regs:$dst),
1257+
(ins Float32Regs:$b),
1258+
"rcp.approx.ftz.f32",
1259+
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>,
1260+
Requires<[doF32FTZ]>;
12611261
def FRCP32_approx_r :
1262-
NVPTXInst<(outs Float32Regs:$dst),
1263-
(ins Float32Regs:$b),
1264-
"rcp.approx.f32 \t$dst, $b;",
1265-
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>;
1262+
BasicNVPTXInst<(outs Float32Regs:$dst),
1263+
(ins Float32Regs:$b),
1264+
"rcp.approx.f32",
1265+
[(set f32:$dst, (fdiv_approx f32imm_1, f32:$b))]>;
12661266

12671267
//
12681268
// F32 Approximate division
12691269
//
12701270
def FDIV32approxrr_ftz :
1271-
NVPTXInst<(outs Float32Regs:$dst),
1272-
(ins Float32Regs:$a, Float32Regs:$b),
1273-
"div.approx.ftz.f32 \t$dst, $a, $b;",
1274-
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>,
1275-
Requires<[doF32FTZ]>;
1271+
BasicNVPTXInst<(outs Float32Regs:$dst),
1272+
(ins Float32Regs:$a, Float32Regs:$b),
1273+
"div.approx.ftz.f32",
1274+
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>,
1275+
Requires<[doF32FTZ]>;
12761276
def FDIV32approxri_ftz :
1277-
NVPTXInst<(outs Float32Regs:$dst),
1278-
(ins Float32Regs:$a, f32imm:$b),
1279-
"div.approx.ftz.f32 \t$dst, $a, $b;",
1280-
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>,
1281-
Requires<[doF32FTZ]>;
1277+
BasicNVPTXInst<(outs Float32Regs:$dst),
1278+
(ins Float32Regs:$a, f32imm:$b),
1279+
"div.approx.ftz.f32",
1280+
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>,
1281+
Requires<[doF32FTZ]>;
12821282
def FDIV32approxrr :
1283-
NVPTXInst<(outs Float32Regs:$dst),
1284-
(ins Float32Regs:$a, Float32Regs:$b),
1285-
"div.approx.f32 \t$dst, $a, $b;",
1286-
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>;
1283+
BasicNVPTXInst<(outs Float32Regs:$dst),
1284+
(ins Float32Regs:$a, Float32Regs:$b),
1285+
"div.approx.f32",
1286+
[(set f32:$dst, (fdiv_approx f32:$a, f32:$b))]>;
12871287
def FDIV32approxri :
1288-
NVPTXInst<(outs Float32Regs:$dst),
1289-
(ins Float32Regs:$a, f32imm:$b),
1290-
"div.approx.f32 \t$dst, $a, $b;",
1291-
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>;
1288+
BasicNVPTXInst<(outs Float32Regs:$dst),
1289+
(ins Float32Regs:$a, f32imm:$b),
1290+
"div.approx.f32",
1291+
[(set f32:$dst, (fdiv_approx f32:$a, fpimm:$b))]>;
12921292
//
12931293
// F32 Semi-accurate reciprocal
12941294
//
@@ -1312,66 +1312,72 @@ def : Pat<(fdiv_full f32imm_1, f32:$b),
13121312
// F32 Semi-accurate division
13131313
//
13141314
def FDIV32rr_ftz :
1315-
NVPTXInst<(outs Float32Regs:$dst),
1316-
(ins Float32Regs:$a, Float32Regs:$b),
1317-
"div.full.ftz.f32 \t$dst, $a, $b;",
1318-
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>,
1319-
Requires<[doF32FTZ]>;
1315+
BasicNVPTXInst<(outs Float32Regs:$dst),
1316+
(ins Float32Regs:$a, Float32Regs:$b),
1317+
"div.full.ftz.f32",
1318+
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>,
1319+
Requires<[doF32FTZ]>;
13201320
def FDIV32ri_ftz :
1321-
NVPTXInst<(outs Float32Regs:$dst),
1322-
(ins Float32Regs:$a, f32imm:$b),
1323-
"div.full.ftz.f32 \t$dst, $a, $b;",
1324-
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>,
1325-
Requires<[doF32FTZ]>;
1321+
BasicNVPTXInst<(outs Float32Regs:$dst),
1322+
(ins Float32Regs:$a, f32imm:$b),
1323+
"div.full.ftz.f32",
1324+
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>,
1325+
Requires<[doF32FTZ]>;
13261326
def FDIV32rr :
1327-
NVPTXInst<(outs Float32Regs:$dst),
1328-
(ins Float32Regs:$a, Float32Regs:$b),
1329-
"div.full.f32 \t$dst, $a, $b;",
1330-
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>;
1327+
BasicNVPTXInst<(outs Float32Regs:$dst),
1328+
(ins Float32Regs:$a, Float32Regs:$b),
1329+
"div.full.f32",
1330+
[(set f32:$dst, (fdiv_full f32:$a, f32:$b))]>;
13311331
def FDIV32ri :
1332-
NVPTXInst<(outs Float32Regs:$dst),
1333-
(ins Float32Regs:$a, f32imm:$b),
1334-
"div.full.f32 \t$dst, $a, $b;",
1335-
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>;
1332+
BasicNVPTXInst<(outs Float32Regs:$dst),
1333+
(ins Float32Regs:$a, f32imm:$b),
1334+
"div.full.f32",
1335+
[(set f32:$dst, (fdiv_full f32:$a, fpimm:$b))]>;
13361336
//
13371337
// F32 Accurate reciprocal
13381338
//
1339+
1340+
def fdiv_ftz : PatFrag<(ops node:$a, node:$b),
1341+
(fdiv node:$a, node:$b), [{
1342+
return getDivF32Level(N) == NVPTX::DivPrecisionLevel::IEEE754;
1343+
}]>;
1344+
13391345
def FRCP32r_prec_ftz :
1340-
NVPTXInst<(outs Float32Regs:$dst),
1341-
(ins Float32Regs:$b),
1342-
"rcp.rn.ftz.f32 \t$dst, $b;",
1343-
[(set f32:$dst, (fdiv f32imm_1, f32:$b))]>,
1344-
Requires<[doF32FTZ]>;
1346+
BasicNVPTXInst<(outs Float32Regs:$dst),
1347+
(ins Float32Regs:$b),
1348+
"rcp.rn.ftz.f32",
1349+
[(set f32:$dst, (fdiv_ftz f32imm_1, f32:$b))]>,
1350+
Requires<[doF32FTZ]>;
13451351
def FRCP32r_prec :
1346-
NVPTXInst<(outs Float32Regs:$dst),
1347-
(ins Float32Regs:$b),
1348-
"rcp.rn.f32 \t$dst, $b;",
1349-
[(set f32:$dst, (fdiv f32imm_1, f32:$b))]>;
1352+
BasicNVPTXInst<(outs Float32Regs:$dst),
1353+
(ins Float32Regs:$b),
1354+
"rcp.rn.f32",
1355+
[(set f32:$dst, (fdiv f32imm_1, f32:$b))]>;
13501356
//
13511357
// F32 Accurate division
13521358
//
13531359
def FDIV32rr_prec_ftz :
1354-
NVPTXInst<(outs Float32Regs:$dst),
1355-
(ins Float32Regs:$a, Float32Regs:$b),
1356-
"div.rn.ftz.f32 \t$dst, $a, $b;",
1357-
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>,
1358-
Requires<[doF32FTZ]>;
1360+
BasicNVPTXInst<(outs Float32Regs:$dst),
1361+
(ins Float32Regs:$a, Float32Regs:$b),
1362+
"div.rn.ftz.f32",
1363+
[(set f32:$dst, (fdiv_ftz f32:$a, f32:$b))]>,
1364+
Requires<[doF32FTZ]>;
13591365
def FDIV32ri_prec_ftz :
1360-
NVPTXInst<(outs Float32Regs:$dst),
1361-
(ins Float32Regs:$a, f32imm:$b),
1362-
"div.rn.ftz.f32 \t$dst, $a, $b;",
1363-
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>,
1364-
Requires<[doF32FTZ]>;
1366+
BasicNVPTXInst<(outs Float32Regs:$dst),
1367+
(ins Float32Regs:$a, f32imm:$b),
1368+
"div.rn.ftz.f32",
1369+
[(set f32:$dst, (fdiv_ftz f32:$a, fpimm:$b))]>,
1370+
Requires<[doF32FTZ]>;
13651371
def FDIV32rr_prec :
1366-
NVPTXInst<(outs Float32Regs:$dst),
1367-
(ins Float32Regs:$a, Float32Regs:$b),
1368-
"div.rn.f32 \t$dst, $a, $b;",
1369-
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>;
1372+
BasicNVPTXInst<(outs Float32Regs:$dst),
1373+
(ins Float32Regs:$a, Float32Regs:$b),
1374+
"div.rn.f32",
1375+
[(set f32:$dst, (fdiv f32:$a, f32:$b))]>;
13701376
def FDIV32ri_prec :
1371-
NVPTXInst<(outs Float32Regs:$dst),
1372-
(ins Float32Regs:$a, f32imm:$b),
1373-
"div.rn.f32 \t$dst, $a, $b;",
1374-
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>;
1377+
BasicNVPTXInst<(outs Float32Regs:$dst),
1378+
(ins Float32Regs:$a, f32imm:$b),
1379+
"div.rn.f32",
1380+
[(set f32:$dst, (fdiv f32:$a, fpimm:$b))]>;
13751381

13761382
//
13771383
// FMA
Lines changed: 93 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,93 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -verify-machineinstrs -nvptx-prec-divf32=0 | FileCheck %s --check-prefix=APPROX
3+
; RUN: llc < %s -verify-machineinstrs -nvptx-prec-divf32=1 | FileCheck %s --check-prefix=FULL
4+
; RUN: llc < %s -verify-machineinstrs -nvptx-prec-divf32=2 | FileCheck %s --check-prefixes=IEEE,FTZ
5+
; RUN: llc < %s -verify-machineinstrs -nvptx-prec-divf32=3 | FileCheck %s --check-prefixes=IEEE,NOFTZ
6+
7+
target triple = "nvptx64-nvidia-cuda"
8+
9+
define float @div_ftz(float %a, float %b) "denormal-fp-math-f32" = "preserve-sign" {
10+
; APPROX-LABEL: div_ftz(
11+
; APPROX: {
12+
; APPROX-NEXT: .reg .b32 %r<4>;
13+
; APPROX-EMPTY:
14+
; APPROX-NEXT: // %bb.0:
15+
; APPROX-NEXT: ld.param.b32 %r1, [div_ftz_param_0];
16+
; APPROX-NEXT: ld.param.b32 %r2, [div_ftz_param_1];
17+
; APPROX-NEXT: div.approx.ftz.f32 %r3, %r1, %r2;
18+
; APPROX-NEXT: st.param.b32 [func_retval0], %r3;
19+
; APPROX-NEXT: ret;
20+
;
21+
; FULL-LABEL: div_ftz(
22+
; FULL: {
23+
; FULL-NEXT: .reg .b32 %r<4>;
24+
; FULL-EMPTY:
25+
; FULL-NEXT: // %bb.0:
26+
; FULL-NEXT: ld.param.b32 %r1, [div_ftz_param_0];
27+
; FULL-NEXT: ld.param.b32 %r2, [div_ftz_param_1];
28+
; FULL-NEXT: div.full.ftz.f32 %r3, %r1, %r2;
29+
; FULL-NEXT: st.param.b32 [func_retval0], %r3;
30+
; FULL-NEXT: ret;
31+
;
32+
; FTZ-LABEL: div_ftz(
33+
; FTZ: {
34+
; FTZ-NEXT: .reg .b32 %r<4>;
35+
; FTZ-EMPTY:
36+
; FTZ-NEXT: // %bb.0:
37+
; FTZ-NEXT: ld.param.b32 %r1, [div_ftz_param_0];
38+
; FTZ-NEXT: ld.param.b32 %r2, [div_ftz_param_1];
39+
; FTZ-NEXT: div.rn.ftz.f32 %r3, %r1, %r2;
40+
; FTZ-NEXT: st.param.b32 [func_retval0], %r3;
41+
; FTZ-NEXT: ret;
42+
;
43+
; NOFTZ-LABEL: div_ftz(
44+
; NOFTZ: {
45+
; NOFTZ-NEXT: .reg .b32 %r<4>;
46+
; NOFTZ-EMPTY:
47+
; NOFTZ-NEXT: // %bb.0:
48+
; NOFTZ-NEXT: ld.param.b32 %r1, [div_ftz_param_0];
49+
; NOFTZ-NEXT: ld.param.b32 %r2, [div_ftz_param_1];
50+
; NOFTZ-NEXT: div.rn.f32 %r3, %r1, %r2;
51+
; NOFTZ-NEXT: st.param.b32 [func_retval0], %r3;
52+
; NOFTZ-NEXT: ret;
53+
%val = fdiv float %a, %b
54+
ret float %val
55+
}
56+
57+
58+
define float @div(float %a, float %b) {
59+
; APPROX-LABEL: div(
60+
; APPROX: {
61+
; APPROX-NEXT: .reg .b32 %r<4>;
62+
; APPROX-EMPTY:
63+
; APPROX-NEXT: // %bb.0:
64+
; APPROX-NEXT: ld.param.b32 %r1, [div_param_0];
65+
; APPROX-NEXT: ld.param.b32 %r2, [div_param_1];
66+
; APPROX-NEXT: div.approx.f32 %r3, %r1, %r2;
67+
; APPROX-NEXT: st.param.b32 [func_retval0], %r3;
68+
; APPROX-NEXT: ret;
69+
;
70+
; FULL-LABEL: div(
71+
; FULL: {
72+
; FULL-NEXT: .reg .b32 %r<4>;
73+
; FULL-EMPTY:
74+
; FULL-NEXT: // %bb.0:
75+
; FULL-NEXT: ld.param.b32 %r1, [div_param_0];
76+
; FULL-NEXT: ld.param.b32 %r2, [div_param_1];
77+
; FULL-NEXT: div.full.f32 %r3, %r1, %r2;
78+
; FULL-NEXT: st.param.b32 [func_retval0], %r3;
79+
; FULL-NEXT: ret;
80+
;
81+
; IEEE-LABEL: div(
82+
; IEEE: {
83+
; IEEE-NEXT: .reg .b32 %r<4>;
84+
; IEEE-EMPTY:
85+
; IEEE-NEXT: // %bb.0:
86+
; IEEE-NEXT: ld.param.b32 %r1, [div_param_0];
87+
; IEEE-NEXT: ld.param.b32 %r2, [div_param_1];
88+
; IEEE-NEXT: div.rn.f32 %r3, %r1, %r2;
89+
; IEEE-NEXT: st.param.b32 [func_retval0], %r3;
90+
; IEEE-NEXT: ret;
91+
%val = fdiv float %a, %b
92+
ret float %val
93+
}

0 commit comments

Comments
 (0)