Skip to content

[AMDGPU] Fold fmed3 when inputs include infinity #144824

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

fairywreath
Copy link
Contributor

No description provided.

@llvmbot
Copy link
Member

llvmbot commented Jun 19, 2025

@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-llvm-transforms

Author: Darren Wihandi (fairywreath)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/144824.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp (+39-1)
  • (modified) llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll (+90)
diff --git a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
index 5477c5eae9392..7554c6953d76f 100644
--- a/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
+++ b/llvm/lib/Target/AMDGPU/AMDGPUInstCombineIntrinsic.cpp
@@ -1039,7 +1039,6 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
     const APFloat *ConstSrc1 = nullptr;
     const APFloat *ConstSrc2 = nullptr;
 
-    // TODO: Also can fold to 2 operands with infinities.
     if ((match(Src0, m_APFloat(ConstSrc0)) && ConstSrc0->isNaN()) ||
         isa<UndefValue>(Src0)) {
       switch (fpenvIEEEMode(II)) {
@@ -1088,6 +1087,45 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
       case KnownIEEEMode::Unknown:
         break;
       }
+    } else if (match(Src0, m_APFloat(ConstSrc0)) && ConstSrc0->isInfinity()) {
+      switch (fpenvIEEEMode(II)) {
+      case KnownIEEEMode::On:
+        V = ConstSrc0->isNegative() ? IC.Builder.CreateMinNum(Src1, Src2)
+                                    : IC.Builder.CreateMaxNum(Src1, Src2);
+        break;
+      case KnownIEEEMode::Off:
+        V = ConstSrc0->isNegative() ? IC.Builder.CreateMinimumNum(Src1, Src2)
+                                    : IC.Builder.CreateMaximumNum(Src1, Src2);
+        break;
+      case KnownIEEEMode::Unknown:
+        break;
+      }
+    } else if (match(Src1, m_APFloat(ConstSrc1)) && ConstSrc1->isInfinity()) {
+      switch (fpenvIEEEMode(II)) {
+      case KnownIEEEMode::On:
+        V = ConstSrc1->isNegative() ? IC.Builder.CreateMinNum(Src0, Src2)
+                                    : IC.Builder.CreateMaxNum(Src0, Src2);
+        break;
+      case KnownIEEEMode::Off:
+        V = ConstSrc1->isNegative() ? IC.Builder.CreateMinimumNum(Src0, Src2)
+                                    : IC.Builder.CreateMaximumNum(Src0, Src2);
+        break;
+      case KnownIEEEMode::Unknown:
+        break;
+      }
+    } else if (match(Src2, m_APFloat(ConstSrc2)) && ConstSrc2->isInfinity()) {
+      switch (fpenvIEEEMode(II)) {
+      case KnownIEEEMode::On:
+        V = ConstSrc2->isNegative() ? IC.Builder.CreateMinNum(Src0, Src1)
+                                    : IC.Builder.CreateMaxNum(Src0, Src1);
+        break;
+      case KnownIEEEMode::Off:
+        V = ConstSrc2->isNegative() ? IC.Builder.CreateMinimumNum(Src0, Src1)
+                                    : IC.Builder.CreateMaximumNum(Src0, Src1);
+        break;
+      case KnownIEEEMode::Unknown:
+        break;
+      }
     }
 
     if (V) {
diff --git a/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll b/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
index d9311008bd680..361a2b8280910 100644
--- a/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
+++ b/llvm/test/Transforms/InstCombine/AMDGPU/fmed3.ll
@@ -521,6 +521,96 @@ define float @fmed3_neg2_3_snan1_f32(float %x, float %y) #1 {
   ret float %med3
 }
 
+define float @fmed3_inf_x_y_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_inf_x_y_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT:    [[MED3:%.*]] = call float @llvm.maxnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT:    ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_inf_x_y_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT:    [[MED3:%.*]] = call float @llvm.maximumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT:    ret float [[MED3]]
+;
+  %med3 = call float @llvm.amdgcn.fmed3.f32(float 0x7FF0000000000000, float %x, float %y)
+  ret float %med3
+}
+
+define float @fmed3_x_inf_y_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_x_inf_y_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT:    [[MED3:%.*]] = call float @llvm.maxnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT:    ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_x_inf_y_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT:    [[MED3:%.*]] = call float @llvm.maximumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT:    ret float [[MED3]]
+;
+  %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float 0x7FF0000000000000, float %y)
+  ret float %med3
+}
+
+define float @fmed3_x_y_inf_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_x_y_inf_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT:    [[MED3:%.*]] = call float @llvm.maxnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT:    ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_x_y_inf_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT:    [[MED3:%.*]] = call float @llvm.maximumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT:    ret float [[MED3]]
+;
+  %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 0x7FF0000000000000)
+  ret float %med3
+}
+
+define float @fmed3_ninf_x_y_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_ninf_x_y_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT:    [[MED3:%.*]] = call float @llvm.minnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT:    ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_ninf_x_y_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT:    [[MED3:%.*]] = call float @llvm.minimumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT:    ret float [[MED3]]
+;
+  %med3 = call float @llvm.amdgcn.fmed3.f32(float 0xFFF0000000000000, float %x, float %y)
+  ret float %med3
+}
+
+define float @fmed3_x_ninf_y_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_x_ninf_y_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT:    [[MED3:%.*]] = call float @llvm.minnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT:    ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_x_ninf_y_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT:    [[MED3:%.*]] = call float @llvm.minimumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT:    ret float [[MED3]]
+;
+  %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float 0xFFF0000000000000, float %y)
+  ret float %med3
+}
+
+define float @fmed3_x_y_ninf_f32(float %x, float %y) #1 {
+; IEEE1-LABEL: define float @fmed3_x_y_ninf_f32(
+; IEEE1-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE1-NEXT:    [[MED3:%.*]] = call float @llvm.minnum.f32(float [[X]], float [[Y]])
+; IEEE1-NEXT:    ret float [[MED3]]
+;
+; IEEE0-LABEL: define float @fmed3_x_y_ninf_f32(
+; IEEE0-SAME: float [[X:%.*]], float [[Y:%.*]]) #[[ATTR1]] {
+; IEEE0-NEXT:    [[MED3:%.*]] = call float @llvm.minimumnum.f32(float [[X]], float [[Y]])
+; IEEE0-NEXT:    ret float [[MED3]]
+;
+  %med3 = call float @llvm.amdgcn.fmed3.f32(float %x, float %y, float 0xFFF0000000000000)
+  ret float %med3
+}
+
 ; --------------------------------------------------------------------
 ; llvm.amdgcn.fmed3 with default mode implied by shader CC
 ; --------------------------------------------------------------------

@@ -1088,6 +1087,45 @@ GCNTTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
case KnownIEEEMode::Unknown:
break;
}
} else if (match(Src0, m_APFloat(ConstSrc0)) && ConstSrc0->isInfinity()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you merge this case with the nan handling above? The m_APFloat matcher already matched. If this is treated as a separate case, this should use the nicer m_Inf matcher

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants