-
Notifications
You must be signed in to change notification settings - Fork 14k
AMDGPU: Custom lower fptrunc vectors for f32 -> f16 #141883
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
85ecbf7
8de1b0c
4ecace6
833b904
eac6d6e
40beab1
471b3e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -919,8 +919,11 @@ SITargetLowering::SITargetLowering(const TargetMachine &TM, | |
setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Legal); | ||
} | ||
|
||
if (Subtarget->hasCvtPkF16F32Inst()) | ||
setOperationAction(ISD::FP_ROUND, MVT::v2f16, Custom); | ||
if (Subtarget->hasCvtPkF16F32Inst()) { | ||
setOperationAction(ISD::FP_ROUND, | ||
{MVT::v2f16, MVT::v4f16, MVT::v8f16, MVT::v16f16}, | ||
Custom); | ||
} | ||
|
||
setTargetDAGCombine({ISD::ADD, | ||
ISD::UADDO_CARRY, | ||
|
@@ -6900,14 +6903,35 @@ SDValue SITargetLowering::getFPExtOrFPRound(SelectionDAG &DAG, SDValue Op, | |
DAG.getTargetConstant(0, DL, MVT::i32)); | ||
} | ||
|
||
// Splits a wide vector conversion node into two half-width nodes of the same
// opcode and glues the results back together.
//
// Op  - node producing a vector result with more than two elements (the
//       element count must be a power of two); operand 0 is the source vector
//       and operand 1 is forwarded unchanged to both halves.
// DAG - selection DAG used to create the replacement nodes.
//
// Returns a CONCAT_VECTORS node of Op's original result type whose inputs are
// the conversions of the low and high halves of operand 0.
SDValue SITargetLowering::splitFP_ROUNDVectorOp(SDValue Op, | ||
SelectionDAG &DAG) const { | ||
EVT DstVT = Op.getValueType(); | ||
unsigned NumElts = DstVT.getVectorNumElements(); | ||
// Two-element results are not split here; halving also relies on an even,
// power-of-two element count.
assert(NumElts > 2 && isPowerOf2_32(NumElts)); | ||
|
||
// Split the source vector (operand 0) into its low and high halves.
auto [Lo, Hi] = DAG.SplitVectorOperand(Op.getNode(), 0); | ||
|
||
SDLoc DL(Op); | ||
unsigned Opc = Op.getOpcode(); | ||
// Operand 1 is reused as-is for both half-width nodes.
// NOTE(review): for ISD::FP_ROUND this is presumably the "trunc" flag
// operand - confirm against the ISD opcode documentation.
SDValue Flags = Op.getOperand(1); | ||
// Result type of each half: same scalar type, half as many elements.
EVT HalfDstVT = | ||
EVT::getVectorVT(*DAG.getContext(), DstVT.getScalarType(), NumElts / 2); | ||
SDValue OpLo = DAG.getNode(Opc, DL, HalfDstVT, Lo, Flags); | ||
SDValue OpHi = DAG.getNode(Opc, DL, HalfDstVT, Hi, Flags); | ||
|
||
// Reassemble the full-width result in low-then-high order.
return DAG.getNode(ISD::CONCAT_VECTORS, DL, DstVT, OpLo, OpHi); | ||
} | ||
|
||
SDValue SITargetLowering::lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const { | ||
SDValue Src = Op.getOperand(0); | ||
EVT SrcVT = Src.getValueType(); | ||
EVT DstVT = Op.getValueType(); | ||
|
||
if (DstVT == MVT::v2f16) { | ||
if (DstVT.isVector() && DstVT.getScalarType() == MVT::f16) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. In a follow-up, can you look into extending this for v2bf16? I'm guessing that in the ultimate expansion sequence, this will give a benefit even if the underlying v2 opcode isn't legal. |
||
assert(Subtarget->hasCvtPkF16F32Inst() && "support v_cvt_pk_f16_f32"); | ||
return SrcVT == MVT::v2f32 ? Op : SDValue(); | ||
if (SrcVT.getScalarType() != MVT::f32) | ||
return SDValue(); | ||
return SrcVT == MVT::v2f32 ? Op : splitFP_ROUNDVectorOp(Op, DAG); | ||
} | ||
|
||
if (SrcVT.getScalarType() != MVT::f64) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -145,6 +145,7 @@ class SITargetLowering final : public AMDGPUTargetLowering { | |
|
||
/// Custom lowering for ISD::FP_ROUND for MVT::f16. | ||
SDValue lowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const; | ||
SDValue splitFP_ROUNDVectorOp(SDValue Op, SelectionDAG &DAG) const; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Did I miss anything here? This function was added but not implemented? |
||
SDValue lowerFMINNUM_FMAXNUM(SDValue Op, SelectionDAG &DAG) const; | ||
SDValue lowerFMINIMUM_FMAXIMUM(SDValue Op, SelectionDAG &DAG) const; | ||
SDValue lowerFLDEXP(SDValue Op, SelectionDAG &DAG) const; | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't think this is actually tested