Skip to content

AMDGPU: Custom lower fptrunc vectors for f32 -> f16 #141883

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 6, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1061,10 +1061,12 @@ AMDGPULegalizerInfo::AMDGPULegalizerInfo(const GCNSubtarget &ST_,
}

auto &FPTruncActions = getActionDefinitionsBuilder(G_FPTRUNC);
if (ST.hasCvtPkF16F32Inst())
FPTruncActions.legalFor({{S32, S64}, {S16, S32}, {V2S16, V2S32}});
else
if (ST.hasCvtPkF16F32Inst()) {
FPTruncActions.legalFor({{S32, S64}, {S16, S32}, {V2S16, V2S32}})
.customFor({{V4S16, V4S32}, {V8S16, V8S32}});
} else {
FPTruncActions.legalFor({{S32, S64}, {S16, S32}});
}
FPTruncActions.scalarize(0).lower();

getActionDefinitionsBuilder(G_FPEXT)
Expand Down Expand Up @@ -2163,6 +2165,8 @@ bool AMDGPULegalizerInfo::legalizeCustom(
case TargetOpcode::G_FMINNUM_IEEE:
case TargetOpcode::G_FMAXNUM_IEEE:
return legalizeMinNumMaxNum(Helper, MI);
case TargetOpcode::G_FPTRUNC:
return legalizeFPTrunc(Helper, MI, MRI);
case TargetOpcode::G_EXTRACT_VECTOR_ELT:
return legalizeExtractVectorElt(MI, MRI, B);
case TargetOpcode::G_INSERT_VECTOR_ELT:
Expand Down Expand Up @@ -2749,6 +2753,21 @@ bool AMDGPULegalizerInfo::legalizeMinNumMaxNum(LegalizerHelper &Helper,
return Helper.lowerFMinNumMaxNum(MI) == LegalizerHelper::Legalized;
}

/// Custom legalization hook for G_FPTRUNC.
///
/// Expects the result of \p MI to be a vector of more than two s16 elements
/// and asks \p Helper to narrow the result (type index 0) to two-element
/// vectors of the same element type.
///
/// \returns true when the helper reports the instruction as legalized.
bool AMDGPULegalizerInfo::legalizeFPTrunc(LegalizerHelper &Helper,
                                          MachineInstr &MI,
                                          MachineRegisterInfo &MRI) const {
  const LLT ResultTy = MRI.getType(MI.getOperand(0).getReg());
  assert(ResultTy.isVector() && ResultTy.getNumElements() > 2);

  const LLT HalfTy = ResultTy.getElementType();
  assert(HalfTy == S16 && "Only handle vectors of half");

  // Break the wide result into two-element pieces.
  const auto Outcome = Helper.fewerElementsVector(
      MI, /*TypeIdx=*/0, LLT::fixed_vector(2, HalfTy));
  return Outcome == LegalizerHelper::Legalized;
}

bool AMDGPULegalizerInfo::legalizeExtractVectorElt(
MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B) const {
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AMDGPU/AMDGPULegalizerInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class AMDGPULegalizerInfo final : public LegalizerInfo {
bool legalizeFPTOI(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B, bool Signed) const;
bool legalizeMinNumMaxNum(LegalizerHelper &Helper, MachineInstr &MI) const;
bool legalizeFPTrunc(LegalizerHelper &Helper, MachineInstr &MI,
MachineRegisterInfo &MRI) const;
bool legalizeExtractVectorElt(MachineInstr &MI, MachineRegisterInfo &MRI,
MachineIRBuilder &B) const;
bool legalizeInsertVectorElt(MachineInstr &MI, MachineRegisterInfo &MRI,
Expand Down
40 changes: 36 additions & 4 deletions llvm/lib/Target/AMDGPU/SIISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -919,8 +919,10 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM,
setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Legal);
}

if (Subtarget->hasCvtPkF16F32Inst())
setOperationAction(ISD::FP_ROUND, MVT::v2f16, Custom);
if (Subtarget->hasCvtPkF16F32Inst()) {
setOperationAction(ISD::FP_ROUND, {MVT::v2f16, MVT::v4f16, MVT::v8f16},
Custom);
}

setTargetDAGCombine({ISD::ADD,
ISD::UADDO_CARRY,
Expand Down Expand Up @@ -6900,14 +6902,44 @@ SDValue SITargetLowering::getFPExtOrFPRound(SelectionDAG &DAG, SDValue Op,
DAG.getTargetConstant(0, DL, MVT::i32));
}

/// Break a wide vector rounding node into two-element pieces.
///
/// \p Op must be a two-operand node whose result is a vector with an even
/// number of elements and whose first operand is a vector with the same
/// number of elements (lowerFP_ROUND calls this for f32 -> f16 vectors).
/// A two-element node is returned unchanged. Otherwise every adjacent pair of
/// source elements is extracted, converted with the same opcode, and the
/// two-element results are concatenated back into the original result type.
///
/// \returns \p Op itself for two-element results, else the CONCAT_VECTORS
/// node that reassembles the pieces.
SDValue SITargetLowering::SplitFP_ROUNDVectorToPacks(SDValue Op,
                                                     SelectionDAG &DAG) const {
  unsigned Opc = Op.getOpcode();
  EVT DstVT = Op.getValueType();
  unsigned NumElts = DstVT.getVectorNumElements();
  assert(NumElts % 2 == 0 && "Only handle vectors of even number of elements");
  if (NumElts == 2) // already packed.
    return Op;

  SDValue Src = Op.getOperand(0);
  EVT SrcVT = Src.getValueType();
  assert(SrcVT.isVector() && SrcVT.getVectorNumElements() == NumElts &&
         "Source and result must have the same number of elements");

  // Reuse the node's own second operand for every piece rather than a
  // hard-coded 0. For ISD::FP_ROUND that operand is the flag recording whether
  // the rounding is known not to change the value; if it holds for the whole
  // vector it holds for each pair, so dropping it only loses information.
  SDValue Flags = Op.getOperand(1);

  LLVMContext &Context = *DAG.getContext();
  EVT SrcPkVT = EVT::getVectorVT(Context, SrcVT.getScalarType(), 2);
  EVT DstPkVT = EVT::getVectorVT(Context, DstVT.getScalarType(), 2);

  SDLoc DL(Op);
  SmallVector<SDValue, 16> Packs;
  for (unsigned Index = 0; Index < NumElts; Index += 2) {
    // Elements [Index, Index + 1] of the source form one piece.
    SDValue PkSrc = DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, SrcPkVT, Src,
                                DAG.getConstant(Index, DL, MVT::i32));
    SDValue PkDst = DAG.getNode(Opc, DL, DstPkVT, PkSrc, Flags);
    Packs.push_back(PkDst);
  }

  return DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, Packs);
}

SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const {
SDValue Src = Op.getOperand(0);
EVT SrcVT = Src.getValueType();
EVT DstVT = Op.getValueType();

if (DstVT == MVT::v2f16) {
if (DstVT.isVector() && DstVT.getScalarType() == MVT::f16) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In a follow-up, can you look into extending this to v2bf16? I'm guessing that in the final expansion sequence this will give a benefit even if the underlying v2 opcode isn't legal.

assert(Subtarget->hasCvtPkF16F32Inst() && "support v_cvt_pk_f16_f32");
return SrcVT == MVT::v2f32 ? Op : SDValue();
if (SrcVT.getScalarType() != MVT::f32)
return SDValue();
return SplitFP_ROUNDVectorToPacks(Op, DAG);
}

if (SrcVT.getScalarType() != MVT::f64)
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/Target/AMDGPU/SIISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class SITargetLowering final : public AMDGPUTargetLowering {

/// Custom lowering for ISD::FP_ROUND for MVT::f16.
SDValue lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
SDValue SplitFP_ROUNDVectorToPacks(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFMINNUM_FMAXNUM(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFMINIMUM_FMAXIMUM(SDValue Op, SelectionDAG &DAG) const;
SDValue lowerFLDEXP(SDValue Op, SelectionDAG &DAG) const;
Expand Down
120 changes: 113 additions & 7 deletions llvm/test/CodeGen/AMDGPU/fptrunc.v2f16.no.fast.math.ll
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,124 @@ define <2 x half> @v_test_cvt_v2f32_v2f16(<2 x float> %src) {
ret <2 x half> %res
}

define half @fptrunc_v2f32_v2f16_then_extract(<2 x float> %src) {
; GFX950-LABEL: fptrunc_v2f32_v2f16_then_extract:
define <4 x half> @v_test_cvt_v4f32_v4f16(<4 x float> %src) {
; GFX950-LABEL: v_test_cvt_v4f32_v4f16:
; GFX950: ; %bb.0:
; GFX950-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX950-NEXT: v_cvt_pk_f16_f32 v0, v0, v1
; GFX950-NEXT: v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:WORD_1 src1_sel:DWORD
; GFX950-NEXT: v_cvt_pk_f16_f32 v1, v2, v3
; GFX950-NEXT: s_setpc_b64 s[30:31]
%res = fptrunc <4 x float> %src to <4 x half>
ret <4 x half> %res
}

; Truncates every lane of an <8 x float> argument to half and returns the
; resulting <8 x half>. The expected output is four v_cvt_pk_f16_f32
; instructions, each converting one pair of adjacent source registers.
; NOTE(review): the function name ends in "v2f16" although the result type is
; <8 x half>; "v8f16" was probably intended. Confirm before renaming, because
; the autogenerated label line below has to change together with the name.
define <8 x half> @v_test_cvt_v8f32_v2f16(<8 x float> %src) {
; GFX950-LABEL: v_test_cvt_v8f32_v2f16:
; GFX950: ; %bb.0:
; GFX950-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX950-NEXT: v_cvt_pk_f16_f32 v0, v0, v1
; GFX950-NEXT: v_cvt_pk_f16_f32 v1, v2, v3
; GFX950-NEXT: v_cvt_pk_f16_f32 v2, v4, v5
; GFX950-NEXT: v_cvt_pk_f16_f32 v3, v6, v7
; GFX950-NEXT: s_setpc_b64 s[30:31]
%res = fptrunc <8 x float> %src to <8 x half>
ret <8 x half> %res
}

define half @fptrunc_v2f32_v2f16_extract_uses(<2 x float> %src) {
; GFX950-LABEL: fptrunc_v2f32_v2f16_extract_uses:
; GFX950: ; %bb.0:
; GFX950-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX950-NEXT: v_cvt_pk_f16_f32 v0, v0, v1
; GFX950-NEXT: v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-NEXT: s_setpc_b64 s[30:31]
%vec_half = fptrunc <2 x float> %src to <2 x half>
%first = extractelement <2 x half> %vec_half, i64 1
%second = extractelement <2 x half> %vec_half, i64 0
%res = fadd half %first, %second
ret half %res
%f0 = extractelement <2 x half> %vec_half, i64 0
%f1 = extractelement <2 x half> %vec_half, i64 1
%rslt = fadd half %f0, %f1
ret half %rslt
}

; Truncates a <4 x float> argument to <4 x half>, extracts all four lanes and
; adds them pairwise down to a single half. Both check blocks below expect two
; packed conversions followed by three f16 adds; they differ only in the order
; of the conversions and in the registers used.
define half @fptrunc_v4f32_v4f16_extract_uses(<4 x float> %vec_float) {
; GFX950-SDAG-LABEL: fptrunc_v4f32_v4f16_extract_uses:
; GFX950-SDAG: ; %bb.0:
; GFX950-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX950-SDAG-NEXT: v_cvt_pk_f16_f32 v2, v2, v3
; GFX950-SDAG-NEXT: v_cvt_pk_f16_f32 v0, v0, v1
; GFX950-SDAG-NEXT: v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-SDAG-NEXT: v_add_f16_sdwa v1, v2, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-SDAG-NEXT: v_add_f16_e32 v0, v0, v1
; GFX950-SDAG-NEXT: s_setpc_b64 s[30:31]
;
; GFX950-GISEL-LABEL: fptrunc_v4f32_v4f16_extract_uses:
; GFX950-GISEL: ; %bb.0:
; GFX950-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX950-GISEL-NEXT: v_cvt_pk_f16_f32 v0, v0, v1
; GFX950-GISEL-NEXT: v_cvt_pk_f16_f32 v1, v2, v3
; GFX950-GISEL-NEXT: v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-GISEL-NEXT: v_add_f16_sdwa v1, v1, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-GISEL-NEXT: v_add_f16_e32 v0, v0, v1
; GFX950-GISEL-NEXT: s_setpc_b64 s[30:31]
%vec_half = fptrunc <4 x float> %vec_float to <4 x half>
%f0 = extractelement <4 x half> %vec_half, i64 0
%f1 = extractelement <4 x half> %vec_half, i64 1
%f2 = extractelement <4 x half> %vec_half, i64 2
%f3 = extractelement <4 x half> %vec_half, i64 3
%sum0 = fadd half %f0, %f1
%sum1 = fadd half %f2, %f3
%rslt = fadd half %sum0, %sum1
ret half %rslt
}

; Truncates an <8 x float> argument to <8 x half>, extracts all eight lanes and
; reduces them with a pairwise tree of adds to a single half. Both check blocks
; below expect four packed conversions followed by seven f16 adds; they differ
; only in the order of the conversions and in the registers used.
define half @fptrunc_v8f32_v8f16_extract_uses(<8 x float> %vec_float) {
; GFX950-SDAG-LABEL: fptrunc_v8f32_v8f16_extract_uses:
; GFX950-SDAG: ; %bb.0:
; GFX950-SDAG-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX950-SDAG-NEXT: v_cvt_pk_f16_f32 v6, v6, v7
; GFX950-SDAG-NEXT: v_cvt_pk_f16_f32 v4, v4, v5
; GFX950-SDAG-NEXT: v_cvt_pk_f16_f32 v2, v2, v3
; GFX950-SDAG-NEXT: v_cvt_pk_f16_f32 v0, v0, v1
; GFX950-SDAG-NEXT: v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-SDAG-NEXT: v_add_f16_sdwa v1, v2, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-SDAG-NEXT: v_add_f16_sdwa v2, v4, v4 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-SDAG-NEXT: v_add_f16_sdwa v3, v6, v6 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-SDAG-NEXT: v_add_f16_e32 v0, v0, v1
; GFX950-SDAG-NEXT: v_add_f16_e32 v1, v2, v3
; GFX950-SDAG-NEXT: v_add_f16_e32 v0, v0, v1
; GFX950-SDAG-NEXT: s_setpc_b64 s[30:31]
;
; GFX950-GISEL-LABEL: fptrunc_v8f32_v8f16_extract_uses:
; GFX950-GISEL: ; %bb.0:
; GFX950-GISEL-NEXT: s_waitcnt vmcnt(0) expcnt(0) lgkmcnt(0)
; GFX950-GISEL-NEXT: v_cvt_pk_f16_f32 v0, v0, v1
; GFX950-GISEL-NEXT: v_cvt_pk_f16_f32 v1, v2, v3
; GFX950-GISEL-NEXT: v_cvt_pk_f16_f32 v2, v4, v5
; GFX950-GISEL-NEXT: v_cvt_pk_f16_f32 v3, v6, v7
; GFX950-GISEL-NEXT: v_add_f16_sdwa v0, v0, v0 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-GISEL-NEXT: v_add_f16_sdwa v1, v1, v1 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-GISEL-NEXT: v_add_f16_sdwa v2, v2, v2 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-GISEL-NEXT: v_add_f16_sdwa v3, v3, v3 dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:DWORD src1_sel:WORD_1
; GFX950-GISEL-NEXT: v_add_f16_e32 v0, v0, v1
; GFX950-GISEL-NEXT: v_add_f16_e32 v1, v2, v3
; GFX950-GISEL-NEXT: v_add_f16_e32 v0, v0, v1
; GFX950-GISEL-NEXT: s_setpc_b64 s[30:31]
%vec_half = fptrunc <8 x float> %vec_float to <8 x half>
%f0 = extractelement <8 x half> %vec_half, i64 0
%f1 = extractelement <8 x half> %vec_half, i64 1
%f2 = extractelement <8 x half> %vec_half, i64 2
%f3 = extractelement <8 x half> %vec_half, i64 3
%f4 = extractelement <8 x half> %vec_half, i64 4
%f5 = extractelement <8 x half> %vec_half, i64 5
%f6 = extractelement <8 x half> %vec_half, i64 6
%f7 = extractelement <8 x half> %vec_half, i64 7
%sum0 = fadd half %f0, %f1
%sum1 = fadd half %f2, %f3
%sum2 = fadd half %f4, %f5
%sum3 = fadd half %f6, %f7
%sum4 = fadd half %sum0, %sum1
%sum5 = fadd half %sum2, %sum3
%rslt = fadd half %sum4, %sum5
ret half %rslt
}

define <2 x half> @v_test_cvt_v2f64_v2f16(<2 x double> %src) {
Expand Down