Skip to content

[RISCV] Add GetVTypeMinimalPredicates for the operation supported by zvfhmin. NFC. #143847

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 16, 2025

Conversation

tclin914
Copy link
Contributor

@tclin914 tclin914 commented Jun 12, 2025

This patch adds a new GetVTypeMinimalPredicates for f16 operation supported by Zvfhmin. Split the type predicates for minimal support and full compute support. This is a refactor patch for implementing vector compute support for bf16 (Zvfbfa), that we can check bf16 type whether with Zvfbfa extension in GetVTypePredicates.

…zvfhmin. NFC.

This patch adds a new `GetVTypeMinimalPredicates` for `f16` operation
supported by `Zvfhmin`. This is a refactor patch for implementing vector
compute support for bf16 (Zvfbfa), so that we can add `bf16` type to
`GetVTypePredicates` for vector computation.
@llvmbot
Copy link
Member

llvmbot commented Jun 12, 2025

@llvm/pr-subscribers-backend-risc-v

Author: Jim Lin (tclin914)

Changes

This patch adds a new GetVTypeMinimalPredicates for f16 operation supported by Zvfhmin. Split the type predicates for minimal support and full compute support. This is a refactor patch for implementing vector compute support for bf16 (Zvfbfa), that we can add bf16 type to GetVTypePredicates for vector computation.


Full diff: https://github.com/llvm/llvm-project/pull/143847.diff

3 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td (+26-32)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td (+3-5)
  • (modified) llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td (+6-7)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
index 281f8d55932b9..1e48d42ac8f38 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
@@ -783,6 +783,15 @@ class GetVTypePredicates<VTypeInfo vti> {
                                      true : [HasVInstructions]);
 }
 
+class GetVTypeMinimalPredicates<VTypeInfo vti> {
+  list<Predicate> Predicates = !cond(!eq(vti.Scalar, f16) : [HasVInstructionsF16Minimal],
+                                     !eq(vti.Scalar, bf16) : [HasVInstructionsBF16Minimal],
+                                     !eq(vti.Scalar, f32) : [HasVInstructionsAnyF],
+                                     !eq(vti.Scalar, f64) : [HasVInstructionsF64],
+                                     !eq(vti.SEW, 64) : [HasVInstructionsI64],
+                                     true : [HasVInstructions]);
+}
+
 class VPseudoUSLoadNoMask<VReg RetClass,
                           int EEW,
                           DAGOperand sewop = sew> :
@@ -4568,7 +4577,7 @@ multiclass VPatUnaryS_M<string intrinsic_name,
 multiclass VPatUnaryV_V_AnyMask<string intrinsic, string instruction,
                                 list<VTypeInfo> vtilist> {
   foreach vti = vtilist in {
-    let Predicates = GetVTypePredicates<vti>.Predicates in
+    let Predicates = GetVTypeMinimalPredicates<vti>.Predicates in
     def : VPatUnaryAnyMask<intrinsic, instruction, "VM",
                            vti.Vector, vti.Vector, vti.Mask,
                            vti.Log2SEW, vti.LMul, vti.RegClass, vti.RegClass>;
@@ -4887,7 +4896,7 @@ multiclass VPatBinaryV_VV_INT<string intrinsic, string instruction,
                               list<VTypeInfo> vtilist> {
   foreach vti = vtilist in {
     defvar ivti = GetIntVTypeInfo<vti>.Vti;
-    let Predicates = GetVTypePredicates<vti>.Predicates in
+    let Predicates = GetVTypeMinimalPredicates<vti>.Predicates in
     defm : VPatBinary<intrinsic,
                       instruction # "_VV_" # vti.LMul.MX # "_E" # vti.SEW,
                       vti.Vector, vti.Vector, ivti.Vector, vti.Mask,
@@ -4950,7 +4959,7 @@ multiclass VPatBinaryV_VX_RM<string intrinsic, string instruction,
 multiclass VPatBinaryV_VX_INT<string intrinsic, string instruction,
                           list<VTypeInfo> vtilist> {
   foreach vti = vtilist in
-    let Predicates = GetVTypePredicates<vti>.Predicates in
+    let Predicates = GetVTypeMinimalPredicates<vti>.Predicates in
     defm : VPatBinary<intrinsic, instruction # "_VX_" # vti.LMul.MX,
                       vti.Vector, vti.Vector, XLenVT, vti.Mask,
                       vti.Log2SEW, vti.RegClass,
@@ -4960,7 +4969,7 @@ multiclass VPatBinaryV_VX_INT<string intrinsic, string instruction,
 multiclass VPatBinaryV_VI<string intrinsic, string instruction,
                           list<VTypeInfo> vtilist, Operand imm_type> {
   foreach vti = vtilist in
-    let Predicates = GetVTypePredicates<vti>.Predicates in
+    let Predicates = GetVTypeMinimalPredicates<vti>.Predicates in
     defm : VPatBinary<intrinsic, instruction # "_VI_" # vti.LMul.MX,
                       vti.Vector, vti.Vector, XLenVT, vti.Mask,
                       vti.Log2SEW, vti.RegClass,
@@ -5887,12 +5896,11 @@ multiclass VPatConversionWF_VF<string intrinsic, string instruction,
     defvar fvti = fvtiToFWti.Vti;
     defvar fwti = fvtiToFWti.Wti;
     // Define vfwcvt.f.f.v for f16 when Zvfhmin is enable.
-    let Predicates = !if(!eq(fvti.Scalar, f16), [HasVInstructionsF16Minimal],
-                         !listconcat(GetVTypePredicates<fvti>.Predicates,
-                                     GetVTypePredicates<fwti>.Predicates)) in
-      defm : VPatConversion<intrinsic, instruction, "V",
-                            fwti.Vector, fvti.Vector, fwti.Mask, fvti.Log2SEW,
-                            fvti.LMul, fwti.RegClass, fvti.RegClass, isSEWAware>;
+    let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates,
+                                 GetVTypeMinimalPredicates<fwti>.Predicates) in
+    defm : VPatConversion<intrinsic, instruction, "V",
+                          fwti.Vector, fvti.Vector, fwti.Mask, fvti.Log2SEW,
+                          fvti.LMul, fwti.RegClass, fvti.RegClass, isSEWAware>;
   }
 }
 
@@ -5979,8 +5987,9 @@ multiclass VPatConversionVF_WF_RM<string intrinsic, string instruction,
   foreach fvtiToFWti = wlist in {
     defvar fvti = fvtiToFWti.Vti;
     defvar fwti = fvtiToFWti.Wti;
-    let Predicates = !listconcat(GetVTypePredicates<fvti>.Predicates,
-                                 GetVTypePredicates<fwti>.Predicates) in
+    // Define vfncvt.f.f.w for f16 when Zvfhmin is enable.
+    let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates,
+                                 GetVTypeMinimalPredicates<fwti>.Predicates) in
     defm : VPatConversionRoundingMode<intrinsic, instruction, "W",
                                       fvti.Vector, fwti.Vector, fvti.Mask, fvti.Log2SEW,
                                       fvti.LMul, fvti.RegClass, fwti.RegClass,
@@ -7005,8 +7014,7 @@ defm : VPatBinaryV_VM_XM_IM<"int_riscv_vmerge", "PseudoVMERGE">;
 // 11.16. Vector Integer Move Instructions
 //===----------------------------------------------------------------------===//
 foreach vti = AllVectors in {
-  let Predicates = !if(!eq(vti.Scalar, f16), [HasVInstructionsF16Minimal],
-                       GetVTypePredicates<vti>.Predicates) in {
+  let Predicates = GetVTypeMinimalPredicates<vti>.Predicates in {
     def : Pat<(vti.Vector (int_riscv_vmv_v_v (vti.Vector vti.RegClass:$passthru),
                                              (vti.Vector vti.RegClass:$rs1),
                                              VLOpFrag)),
@@ -7201,8 +7209,7 @@ defm : VPatConversionVI_VF<"int_riscv_vfclass", "PseudoVFCLASS">;
 // NOTE: Clang previously used int_riscv_vfmerge for vector-vector, but now uses
 // int_riscv_vmerge. Support both for compatibility.
 foreach vti = AllFloatVectors in {
-  let Predicates = !if(!eq(vti.Scalar, f16), [HasVInstructionsF16Minimal],
-                       GetVTypePredicates<vti>.Predicates) in
+  let Predicates = GetVTypeMinimalPredicates<vti>.Predicates in
     defm : VPatBinaryCarryInTAIL<"int_riscv_vmerge", "PseudoVMERGE", "VVM",
                                  vti.Vector,
                                  vti.Vector, vti.Vector, vti.Mask,
@@ -7281,16 +7288,8 @@ defm : VPatConversionVF_WI_RM<"int_riscv_vfncvt_f_xu_w", "PseudoVFNCVT_F_XU",
                               isSEWAware=1>;
 defm : VPatConversionVF_WI_RM<"int_riscv_vfncvt_f_x_w", "PseudoVFNCVT_F_X",
                               isSEWAware=1>;
-defvar WidenableFloatVectorsExceptF16 = !filter(fvtiToFWti, AllWidenableFloatVectors,
-                                                !ne(fvtiToFWti.Vti.Scalar, f16));
-defm : VPatConversionVF_WF_RM<"int_riscv_vfncvt_f_f_w", "PseudoVFNCVT_F_F",
-                           WidenableFloatVectorsExceptF16, isSEWAware=1>;
-// Define vfncvt.f.f.w for f16 when Zvfhmin is enable.
-defvar F16WidenableFloatVectors = !filter(fvtiToFWti, AllWidenableFloatVectors,
-                                          !eq(fvtiToFWti.Vti.Scalar, f16));
-let Predicates = [HasVInstructionsF16Minimal] in
 defm : VPatConversionVF_WF_RM<"int_riscv_vfncvt_f_f_w", "PseudoVFNCVT_F_F",
-                           F16WidenableFloatVectors, isSEWAware=1>;
+                              AllWidenableFloatVectors, isSEWAware=1>;
 defm : VPatConversionVF_WF_BF_RM<"int_riscv_vfncvtbf16_f_f_w", 
                                  "PseudoVFNCVTBF16_F_F", isSEWAware=1>;
 defm : VPatConversionVF_WF<"int_riscv_vfncvt_rod_f_f_w", "PseudoVFNCVT_ROD_F_F",
@@ -7425,10 +7424,7 @@ defm : VPatBinaryV_VV_INT_EEW<"int_riscv_vrgatherei16_vv", "PseudoVRGATHEREI16",
                               eew=16, vtilist=AllIntegerVectors>;
 
 defm : VPatBinaryV_VV_VX_VI_INT<"int_riscv_vrgather", "PseudoVRGATHER",
-                                AllFloatVectorsExceptFP16, uimm5>;
-let Predicates = [HasVInstructionsF16Minimal] in
-  defm : VPatBinaryV_VV_VX_VI_INT<"int_riscv_vrgather", "PseudoVRGATHER",
-                                  AllFP16Vectors, uimm5>;
+                                AllFloatVectors, uimm5>;
 defm : VPatBinaryV_VV_VX_VI_INT<"int_riscv_vrgather", "PseudoVRGATHER",
                                 AllBFloatVectors, uimm5>;
 defm : VPatBinaryV_VV_INT_EEW<"int_riscv_vrgatherei16_vv", "PseudoVRGATHEREI16",
@@ -7437,9 +7433,7 @@ defm : VPatBinaryV_VV_INT_EEW<"int_riscv_vrgatherei16_vv", "PseudoVRGATHEREI16",
 // 16.5. Vector Compress Instruction
 //===----------------------------------------------------------------------===//
 defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllIntegerVectors>;
-defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllFloatVectorsExceptFP16>;
-let Predicates = [HasVInstructionsF16Minimal] in
-  defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllFP16Vectors>;
+defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllFloatVectors>;
 defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllBFloatVectors>;
 
 // Include the non-intrinsic ISel patterns
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
index e318a78285a2e..520959b0896f7 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td
@@ -864,8 +864,7 @@ multiclass VPatAVGADD_VV_VX_RM<SDNode vop, int vxrm, string suffix = ""> {
 
 // 7.4. Vector Unit-Stride Instructions
 foreach vti = AllVectors in
-  let Predicates = !if(!eq(vti.Scalar, f16), [HasVInstructionsF16Minimal],
-                       GetVTypePredicates<vti>.Predicates) in 
+  let Predicates = GetVTypeMinimalPredicates<vti>.Predicates in
   defm : VPatUSLoadStoreSDNode<vti.Vector, vti.RegClass, vti.Log2SEW, vti.LMul,
                                vti.AVL, vti.RegClass>;
 foreach mti = AllMasks in
@@ -1449,9 +1448,8 @@ defm : VPatNConvertI2FPSDNode_W_RM<any_uint_to_fp, "PseudoVFNCVT_F_XU_W">;
 foreach fvtiToFWti = AllWidenableFloatVectors in {
   defvar fvti = fvtiToFWti.Vti;
   defvar fwti = fvtiToFWti.Wti;
-  let Predicates = !if(!eq(fvti.Scalar, f16), [HasVInstructionsF16Minimal],
-                       !listconcat(GetVTypePredicates<fvti>.Predicates,
-                                   GetVTypePredicates<fwti>.Predicates)) in
+  let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates,
+                               GetVTypeMinimalPredicates<fwti>.Predicates) in
   def : Pat<(fvti.Vector (fpround (fwti.Vector fwti.RegClass:$rs1))),
             (!cast<Instruction>("PseudoVFNCVT_F_F_W_"#fvti.LMul.MX#"_E"#fvti.SEW)
                 (fvti.Vector (IMPLICIT_DEF)),
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
index 1da4adc8c3125..52a1c2edd76f2 100644
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -2687,9 +2687,9 @@ defm : VPatWConvertI2FPVL_V<any_riscv_sint_to_fp_vl, "PseudoVFWCVT_F_X_V">;
 foreach fvtiToFWti = AllWidenableFloatVectors in {
   defvar fvti = fvtiToFWti.Vti;
   defvar fwti = fvtiToFWti.Wti;
-  let Predicates = !if(!eq(fvti.Scalar, f16), [HasVInstructionsF16Minimal],
-                       !listconcat(GetVTypePredicates<fvti>.Predicates,
-                                   GetVTypePredicates<fwti>.Predicates)) in
+  // Define vfwcvt.f.f.v for f16 when Zvfhmin is enable.
+  let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates,
+                               GetVTypeMinimalPredicates<fwti>.Predicates) in
   def : Pat<(fwti.Vector (any_riscv_fpextend_vl
                              (fvti.Vector fvti.RegClass:$rs1),
                              (fvti.Mask VMV0:$vm),
@@ -2730,10 +2730,9 @@ defm : VPatNConvertI2FP_RM_VL_W<riscv_vfcvt_rm_f_x_vl, "PseudoVFNCVT_F_X_W">;
 foreach fvtiToFWti = AllWidenableFloatVectors in {
   defvar fvti = fvtiToFWti.Vti;
   defvar fwti = fvtiToFWti.Wti;
-  // Define vfwcvt.f.f.v for f16 when Zvfhmin is enable.
-  let Predicates = !if(!eq(fvti.Scalar, f16), [HasVInstructionsF16Minimal],
-                       !listconcat(GetVTypePredicates<fvti>.Predicates,
-                                   GetVTypePredicates<fwti>.Predicates)) in {
+  // Define vfncvt.f.f.w for f16 when Zvfhmin is enable.
+  let Predicates = !listconcat(GetVTypeMinimalPredicates<fvti>.Predicates,
+                               GetVTypeMinimalPredicates<fwti>.Predicates) in {
     def : Pat<(fvti.Vector (any_riscv_fpround_vl
                                (fwti.Vector fwti.RegClass:$rs1),
                                (fwti.Mask VMV0:$vm), VLOpFrag)),

Copy link
Contributor

@wangpc-pp wangpc-pp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, nice cleanup!

@tclin914 tclin914 requested a review from 4vtomat June 12, 2025 08:47
Copy link
Contributor

@jacquesguan jacquesguan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -4568,7 +4577,7 @@ multiclass VPatUnaryS_M<string intrinsic_name,
multiclass VPatUnaryV_V_AnyMask<string intrinsic, string instruction,
list<VTypeInfo> vtilist> {
foreach vti = vtilist in {
let Predicates = GetVTypePredicates<vti>.Predicates in
let Predicates = GetVTypeMinimalPredicates<vti>.Predicates in
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why weren't any tests failing if we using the wrong predicate?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why weren't any tests failing if we using the wrong predicate?

Currently, only int_riscv_vcompress uses the multiclass VPatUnaryV_V_AnyMask.
Originally, the predicate for the fp16 vector type was written as:

let Predicates = [HasVInstructionsF16Minimal] in
  defm : VPatUnaryV_V_AnyMask<"int_riscv_vcompress", "PseudoVCOMPRESS", AllFP16Vectors>;

I simply moved this predicate into the multiclass VPatUnaryV_V_AnyMask by replacing GetVTypePredicates with GetVTypeMinimalPredicates.

I hope I didn’t misunderstand your concern.

Copy link
Contributor

@lukel97 lukel97 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines +786 to +794
class GetVTypeMinimalPredicates<VTypeInfo vti> {
list<Predicate> Predicates = !cond(!eq(vti.Scalar, f16) : [HasVInstructionsF16Minimal],
!eq(vti.Scalar, bf16) : [HasVInstructionsBF16Minimal],
!eq(vti.Scalar, f32) : [HasVInstructionsAnyF],
!eq(vti.Scalar, f64) : [HasVInstructionsF64],
!eq(vti.SEW, 64) : [HasVInstructionsI64],
true : [HasVInstructions]);
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe for a separate PR, I think for anything that's not a widening/converting instruction e.g. vmerge/vrgather/vcompress/vle/vse, we can just use the integer predicate, maybe something like:

class GetVTypeIntPredicates<VTypeInfo vti> {
   defvar ivti = GetIntVTypeInfo<vti>;
   list<Predicate> Predicates = GetVTypePredicates<ivti>;
 }

Since vmerge.vvm with a fp16 vector isn't really a zvfhmin instruction, it only requires zve32x.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is somewhat related to #143975

@topperc
Copy link
Collaborator

topperc commented Jun 13, 2025

FYI @tclin914 SiFive has patches for assembler and intrinsics for Zvfbfa internally. Should we provide those to upstream or have you already done the work yourself?

CC: @4vtomat

@tclin914
Copy link
Contributor Author

FYI @tclin914 SiFive has patches for assembler and intrinsics for Zvfbfa internally. Should we provide those to upstream or have you already done the work yourself?

CC: @4vtomat

I haven't finished work on Zvfbfa yet. Contributions from SiFive for the Zvfbfa support would be greatly appreciated.

@tclin914 tclin914 merged commit d64ee2c into llvm:main Jun 16, 2025
7 checks passed
@tclin914 tclin914 deleted the getvtypeminimalpredicates branch June 16, 2025 02:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants