Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[libc][math][c23] Implement higher math function cbrtf16 in LLVM libc #132484

Open
wants to merge 16 commits into
base: main
Choose a base branch
from

Conversation

krishna2803
Copy link
Contributor

This PR implements cbrtf16 for half precision float (float16).

Closes #132199.

Signed-off-by: krishna2803 <kpandey81930@gmail.com>
Signed-off-by: krishna2803 <kpandey81930@gmail.com>
Signed-off-by: krishna2803 <kpandey81930@gmail.com>
Signed-off-by: krishna2803 <kpandey81930@gmail.com>
Signed-off-by: krishna2803 <kpandey81930@gmail.com>
Signed-off-by: krishna2803 <kpandey81930@gmail.com>
@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-backend-amdgpu

Author: Krishna Pandey (krishna2803)

Changes

This PR implements cbrtf16 for half precision float (float16).

Closes #132199.


Full diff: https://github.com/llvm/llvm-project/pull/132484.diff

13 Files Affected:

  • (modified) libc/config/gpu/amdgpu/entrypoints.txt (+1)
  • (modified) libc/config/gpu/nvptx/entrypoints.txt (+1)
  • (modified) libc/config/linux/aarch64/entrypoints.txt (+1)
  • (modified) libc/config/linux/x86_64/entrypoints.txt (+1)
  • (modified) libc/src/math/CMakeLists.txt (+1)
  • (added) libc/src/math/cbrtf16.h (+21)
  • (modified) libc/src/math/generic/CMakeLists.txt (+18)
  • (modified) libc/src/math/generic/cbrtf.cpp (+1-1)
  • (added) libc/src/math/generic/cbrtf16.cpp (+164)
  • (modified) libc/test/src/math/CMakeLists.txt (+12)
  • (added) libc/test/src/math/cbrtf16_test.cpp (+56)
  • (modified) libc/test/src/math/smoke/CMakeLists.txt (+10)
  • (added) libc/test/src/math/smoke/cbrtf16_test.cpp (+33)
diff --git a/libc/config/gpu/amdgpu/entrypoints.txt b/libc/config/gpu/amdgpu/entrypoints.txt
index 291d86b4dd587..0a155e68f307d 100644
--- a/libc/config/gpu/amdgpu/entrypoints.txt
+++ b/libc/config/gpu/amdgpu/entrypoints.txt
@@ -520,6 +520,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
   list(APPEND TARGET_LIBM_ENTRYPOINTS
     # math.h C23 _Float16 entrypoints
     libc.src.math.canonicalizef16
+    libc.src.math.cbrtf16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
     libc.src.math.coshf16
diff --git a/libc/config/gpu/nvptx/entrypoints.txt b/libc/config/gpu/nvptx/entrypoints.txt
index 1ea0d9b03b37e..010265bd915cb 100644
--- a/libc/config/gpu/nvptx/entrypoints.txt
+++ b/libc/config/gpu/nvptx/entrypoints.txt
@@ -522,6 +522,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
   list(APPEND TARGET_LIBM_ENTRYPOINTS
     # math.h C23 _Float16 entrypoints
     libc.src.math.canonicalizef16
+    libc.src.math.cbrtf16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
     libc.src.math.coshf16
diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index 5c4913a658c2f..651bac5b9e79a 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -646,6 +646,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
   list(APPEND TARGET_LIBM_ENTRYPOINTS
     # math.h C23 _Float16 entrypoints
     libc.src.math.canonicalizef16
+    libc.src.math.cbrtf16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
     libc.src.math.cospif16
diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index 124b80d03d846..7f22b4a82ce14 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -660,6 +660,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.cosf16
     libc.src.math.coshf16
     libc.src.math.cospif16
+    libc.src.math.cbrtf16
     libc.src.math.exp10f16
     libc.src.math.exp10m1f16
     libc.src.math.exp2f16
diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index f18a73d46f9aa..7bcf4ee294126 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -78,6 +78,7 @@ add_math_entrypoint_object(iscanonicalf128)
 
 add_math_entrypoint_object(cbrt)
 add_math_entrypoint_object(cbrtf)
+add_math_entrypoint_object(cbrtf16)
 
 add_math_entrypoint_object(ceil)
 add_math_entrypoint_object(ceilf)
diff --git a/libc/src/math/cbrtf16.h b/libc/src/math/cbrtf16.h
new file mode 100644
index 0000000000000..bddb1eda0a9fa
--- /dev/null
+++ b/libc/src/math/cbrtf16.h
@@ -0,0 +1,21 @@
+//===-- Implementation header for cbrtf16 -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_CBRTF16_H
+#define LLVM_LIBC_SRC_MATH_CBRTF16_H
+
+#include "src/__support/macros/config.h"           // LIBC_NAMESPACE_DECL
+#include "src/__support/macros/properties/types.h" // float16
+
+namespace LIBC_NAMESPACE_DECL {
+
+float16 cbrtf16(float16 x);
+
+} // namespace LIBC_NAMESPACE_DECL
+
+#endif // LLVM_LIBC_SRC_MATH_CBRTF16_H
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index db07bd1a098cc..f88f133d8efdd 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -4816,6 +4816,24 @@ add_entrypoint_object(
     libc.src.__support.integer_literals
 )
 
+add_entrypoint_object(
+  cbrtf16
+  SRCS
+    cbrtf16.cpp
+  HDRS
+    ../cbrtf16.h
+  DEPENDS
+    libc.hdr.fenv_macros
+    libc.src.__support.FPUtil.double_double
+    libc.src.__support.FPUtil.dyadic_float
+    libc.src.__support.FPUtil.fenv_impl
+    libc.src.__support.FPUtil.fp_bits
+    libc.src.__support.FPUtil.multiply_add
+    libc.src.__support.FPUtil.polyeval
+    libc.src.__support.macros.optimization
+    libc.src.__support.integer_literals
+)
+
 add_entrypoint_object(
   dmull
   SRCS
diff --git a/libc/src/math/generic/cbrtf.cpp b/libc/src/math/generic/cbrtf.cpp
index 71b23c4a8c742..868790ee7c7c0 100644
--- a/libc/src/math/generic/cbrtf.cpp
+++ b/libc/src/math/generic/cbrtf.cpp
@@ -22,7 +22,7 @@ namespace {
 // Look up table for 2^(i/3) for i = 0, 1, 2.
 constexpr double CBRT2[3] = {1.0, 0x1.428a2f98d728bp0, 0x1.965fea53d6e3dp0};
 
-// Degree-7 polynomials approximation of ((1 + x)^(1/3) - 1)/x for 0 <= x <= 1
+// Degree-6 polynomials approximation of ((1 + x)^(1/3) - 1)/x for 0 <= x <= 1
 // generated by Sollya with:
 // > for i from 0 to 15 do {
 //     P = fpminimax(((1 + x)^(1/3) - 1)/x, 6, [|D...|], [i/16, (i + 1)/16]);
diff --git a/libc/src/math/generic/cbrtf16.cpp b/libc/src/math/generic/cbrtf16.cpp
new file mode 100644
index 0000000000000..782f1bfd1b100
--- /dev/null
+++ b/libc/src/math/generic/cbrtf16.cpp
@@ -0,0 +1,164 @@
+//===-- Implementation of sqrtf16 function --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/cbrtf16.h"
+#include "hdr/fenv_macros.h"
+#include "src/__support/FPUtil/FEnvImpl.h"
+#include "src/__support/FPUtil/FPBits.h"
+#include "src/__support/FPUtil/multiply_add.h"
+#include "src/__support/common.h"
+#include "src/__support/macros/config.h"
+#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
+
+namespace LIBC_NAMESPACE_DECL {
+
+namespace {
+
+// Look up table for 2^(i/3) for i = 0, 1, 2 in single precision
+constexpr float CBRT2[3] = {0x1p0f, 0x1.428a3p0f, 0x1.965feap0f};
+
+// Degree-4 polynomials approximation of ((1 + x)^(1/3) - 1)/x for 0 <= x <= 1
+// generated by Sollya with:
+// > display=hexadecimal;
+// for i from 0 to 15 do {
+//   P = fpminimax(((1 + x)^(1/3) - 1)/x, 4, [|SG...|], [i/16, (i + 1)/16]);
+//   print("{", coeff(P, 0), ",", coeff(P, 1), ",", coeff(P, 2), ",",
+//         coeff(P, 3), coeff(P, 4),"},");
+// };
+// Then (1 + x)^(1/3) ~ 1 + x * P(x).
+// For example: for 0 <= x <= 1/8:
+// P(x) = 0x1.555556p-2 + x * (-0x1.c71d38p-4 + x * (0x1.f9b95ap-5 + x *
+// (-0x1.4ebe18p-5 + x * 0x1.9ca9d2p-6)))
+
+constexpr float COEFFS[16][5] = {
+    {0x1.555556p-2f, -0x1.c71ea4p-4f, 0x1.faa5f2p-5f, -0x1.64febep-5f,
+     0x1.733a46p-5f},
+    {0x1.55554ep-2f, -0x1.c715f6p-4f, 0x1.f88a9ep-5f, -0x1.4456e8p-5f,
+     0x1.5b5ef2p-6f},
+    {0x1.555508p-2f, -0x1.c6f404p-4f, 0x1.f56b7ap-5f, -0x1.33cff8p-5f,
+     0x1.18f146p-6f},
+    {0x1.5553fcp-2f, -0x1.c69bacp-4f, 0x1.efed98p-5f, -0x1.204706p-5f,
+     0x1.c90976p-7f},
+    {0x1.55517p-2f, -0x1.c5f996p-4f, 0x1.e85932p-5f, -0x1.0c0c0ep-5f,
+     0x1.77c766p-7f},
+    {0x1.554c96p-2f, -0x1.c501d2p-4f, 0x1.df0fc4p-5f, -0x1.f067f2p-6f,
+     0x1.380ab8p-7f},
+    {0x1.55448cp-2f, -0x1.c3ab1ep-4f, 0x1.d45876p-5f, -0x1.ca3988p-6f,
+     0x1.04f38ap-7f},
+    {0x1.5538aap-2f, -0x1.c1f886p-4f, 0x1.c8b11p-5f, -0x1.a6a16cp-6f,
+     0x1.b847c2p-8f},
+    {0x1.55278ap-2f, -0x1.bfd538p-4f, 0x1.bbde6p-5f, -0x1.846a8cp-6f,
+     0x1.73bfcp-8f},
+    {0x1.5511dp-2f, -0x1.bd6c88p-4f, 0x1.af0a3ap-5f, -0x1.660852p-6f,
+     0x1.3dbe34p-8f},
+    {0x1.54f82ap-2f, -0x1.bada56p-4f, 0x1.a2aa0ep-5f, -0x1.4b8c2ap-6f,
+     0x1.13379cp-8f},
+    {0x1.54d512p-2f, -0x1.b7a936p-4f, 0x1.94b91ep-5f, -0x1.30792cp-6f,
+     0x1.d7883cp-9f},
+    {0x1.54a8d8p-2f, -0x1.b3fde2p-4f, 0x1.861aeep-5f, -0x1.169484p-6f,
+     0x1.92b4cap-9f},
+    {0x1.548126p-2f, -0x1.b0f4a8p-4f, 0x1.7af574p-5f, -0x1.04644ep-6f,
+     0x1.662fb6p-9f},
+    {0x1.544b9p-2f, -0x1.ad2124p-4f, 0x1.6dd75p-5f, -0x1.e0cbecp-7f,
+     0x1.387692p-9f},
+    {0x1.5422c6p-2f, -0x1.aa61bp-4f, 0x1.64f4bap-5f, -0x1.c742b2p-7f,
+     0x1.1cf15ap-9f},
+};
+
+} // anonymous namespace
+
+LLVM_LIBC_FUNCTION(float16, cbrtf16, (float16 x)) {
+  using FPBits = fputil::FPBits<float16>;
+  using FloatBits = fputil::FPBits<float>;
+
+  FPBits x_bits(x);
+
+  uint16_t x_u = x_bits.uintval();
+  uint16_t x_abs = x_u & 0x7fff;
+  uint32_t sign_bit = (x_u >> 15) << FloatBits::EXP_LEN;
+
+  // cbrtf16(0) = 0, cbrtf16(NaN) = NaN
+  if (LIBC_UNLIKELY(x_abs == 0 || x_abs >= 0x7C00)) {
+    if (x_bits.is_signaling_nan()) {
+      fputil::raise_except(FE_INVALID);
+      return FPBits::quiet_nan().uintval();
+    }
+    return x;
+  }
+
+  float xf = static_cast<float>(x);
+  FloatBits xf_bits(xf);
+
+  unsigned x_e = static_cast<unsigned>(xf_bits.get_exponent());
+  unsigned out_e = (x_e / 3 + 127) | sign_bit;
+
+  unsigned shift_e = x_e % 3;
+
+  // Set x_m = 2^(x_e % 3) * (1 + mantissa)
+  uint32_t x_m = xf_bits.get_mantissa();
+
+  // Use the leading 4 bits for look up table
+  unsigned idx = static_cast<unsigned>(x_m >> (FloatBits::FRACTION_LEN - 4));
+
+  x_m |= static_cast<uint32_t>(FloatBits::EXP_BIAS) << FloatBits::FRACTION_LEN;
+
+  float x_reduced = FloatBits(x_m).get_val();
+  float dx = x_reduced - 1.0f;
+
+  float dx_sq = dx * dx;
+
+  // fputil::multiply_add(x, y, z) = x*y + z
+
+  // c0 =  1 + x * a0
+  float c0 = fputil::multiply_add(dx, COEFFS[idx][0], 1.0f);
+  // c1 = a1 + x * a2
+  float c1 = fputil::multiply_add(dx, COEFFS[idx][2], COEFFS[idx][1]);
+  // c2 = a3 + x * a4
+  float c2 = fputil::multiply_add(dx, COEFFS[idx][4], COEFFS[idx][3]);
+  // we save a multiply_add operation by decreasing the polynomial degree by 2
+  // i.e. using a degree-4 polynomial instead of degree 6.
+
+  float dx_4 = dx_sq * dx_sq;
+
+  // p0 = c0 + x^2 * c1
+  // p0 = (1 + x * a0) + x^2 * (a1 + x * a2)
+  // p0 = 1 + x * a0 + x^2 * a1 + x^3 * a2
+  float p0 = fputil::multiply_add(dx_sq, c1, c0);
+
+  // p1 = c2
+  // p1 = x * a4
+  float p1 = c2;
+
+  // r = p0 + x^4 * p1
+  // r = (1 + x * a0 + x^2 * a1 + x^3 * a2) + x^4 (x * a4)
+  // r = 1 + x * a0 + x^2 * a1 + x^3 * a2 + x^5 * a4
+  // r = 1 + x * (a0 + a1 * x + a2 * x^2 + a3 * x^3 + a4 * x^4)
+  // r = 1 + x * P(x)
+  float r = fputil::multiply_add(dx_4, p1, p0) * CBRT2[shift_e];
+
+  uint32_t r_m = FloatBits(r).get_mantissa();
+  // For float, mantissa is 23 bits (instead of 52 for double)
+  // Check if the output is exact. To be exact, the smallest 1-bit of the
+  // output has to be at least 2^-7 or higher. So we check the lowest 15 bits
+  // to see if they are within 2^(-23 + 3) errors from all zeros, then the
+  // result cube root is exact.
+  if (LIBC_UNLIKELY(((r_m + 4) & 0x7fff) <= 8)) {
+    if ((r_m & 0x7fff) <= 4)
+      r_m &= 0xffff'ffe0;
+    else
+      r_m = (r_m & 0xffff'ffe0) + 0x20; // Round up to next multiple of 0x20
+    fputil::clear_except_if_required(FE_INEXACT);
+  }
+
+  uint32_t r_bits =
+      r_m | (static_cast<uint32_t>(out_e) << FloatBits::FRACTION_LEN);
+
+  return static_cast<float16>(FloatBits(r_bits).get_val());
+}
+
+} // namespace LIBC_NAMESPACE_DECL
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index beafa87e03a77..1fb7f47b1d541 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -2655,6 +2655,18 @@ add_fp_unittest(
     libc.src.__support.FPUtil.fp_bits
 )
 
+add_fp_unittest(
+  cbrtf16_test
+  NEED_MPFR
+  SUITE
+    libc-math-unittests
+  SRCS
+    cbrtf16_test.cpp
+  DEPENDS
+    libc.src.math.cbrtf16
+    libc.src.__support.FPUtil.fp_bits
+)
+
 add_fp_unittest(
   dmull_test
   NEED_MPFR
diff --git a/libc/test/src/math/cbrtf16_test.cpp b/libc/test/src/math/cbrtf16_test.cpp
new file mode 100644
index 0000000000000..2e2cfc079aeb5
--- /dev/null
+++ b/libc/test/src/math/cbrtf16_test.cpp
@@ -0,0 +1,56 @@
+//===-- Unittests for cbrtf16 ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "hdr/math_macros.h"
+#include "src/__support/FPUtil/FPBits.h"
+#include "src/math/cbrtf16.h"
+#include "test/UnitTest/FPMatcher.h"
+#include "test/UnitTest/Test.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+using LlvmLibcCbrtf16Test = LIBC_NAMESPACE::testing::FPTest<float16>;
+
+namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
+
+// Range: [0, Inf];
+static constexpr uint16_t POS_START = 0x0000U;
+static constexpr uint16_t POS_STOP = 0x7c00U;
+
+// Range: [-Inf, 0]
+static constexpr uint16_t NEG_START = 0x8000U;
+static constexpr uint16_t NEG_STOP = 0xfc00U;
+
+TEST_F(LlvmLibcCbrtf16Test, PositiveRange) {
+  for (uint16_t v = POS_START; v <= POS_STOP; ++v) {
+    float16 x = FPBits(v).get_val();
+    EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Cbrt, x,
+                                   LIBC_NAMESPACE::cbrtf16(x), 0.5);
+  }
+}
+
+TEST_F(LlvmLibcCbrtf16Test, NegativeRange) {
+  for (uint16_t v = NEG_START; v <= NEG_STOP; ++v) {
+    float16 x = FPBits(v).get_val();
+    EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Cbrt, x,
+                                   LIBC_NAMESPACE::cbrtf16(x), 0.5);
+  }
+}
+
+TEST_F(LlvmLibcCbrtf16Test, SpecialValues) {
+  constexpr uint16_t INPUTS[] = {
+      0x4a00, 0x4500, 0x4e00, 0x0c00, 0x4940,
+  };
+  for (uint16_t v : INPUTS) {
+    float16 x = FPBits(v).get_val();
+    mpfr::ForceRoundingMode r(mpfr::RoundingMode::Upward);
+    EXPECT_MPFR_MATCH(mpfr::Operation::Cbrt, x, LIBC_NAMESPACE::cbrtf16(x), 0.5,
+                      mpfr::RoundingMode::Upward);
+  }
+
+  ASSERT_EQ(1, 1);
+}
diff --git a/libc/test/src/math/smoke/CMakeLists.txt b/libc/test/src/math/smoke/CMakeLists.txt
index 94ec099ddfcbc..f7311d4f2080c 100644
--- a/libc/test/src/math/smoke/CMakeLists.txt
+++ b/libc/test/src/math/smoke/CMakeLists.txt
@@ -5042,6 +5042,16 @@ add_fp_unittest(
     libc.src.math.cbrt
 )
 
+add_fp_unittest(
+  cbrtf16_test
+  SUITE
+    libc-math-smoke-tests
+  SRCS
+    cbrtf16_test.cpp
+  DEPENDS
+    libc.src.math.cbrtf16
+)
+
 add_fp_unittest(
   dmull_test
   SUITE
diff --git a/libc/test/src/math/smoke/cbrtf16_test.cpp b/libc/test/src/math/smoke/cbrtf16_test.cpp
new file mode 100644
index 0000000000000..7c8a7103273a0
--- /dev/null
+++ b/libc/test/src/math/smoke/cbrtf16_test.cpp
@@ -0,0 +1,33 @@
+//===-- Unittests for cbrtf16 ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/cbrtf16.h"
+#include "test/UnitTest/FPMatcher.h"
+#include "test/UnitTest/Test.h"
+
+using LlvmLibcCbrtfTest = LIBC_NAMESPACE::testing::FPTest<float16>;
+
+using LIBC_NAMESPACE::testing::tlog;
+
+TEST_F(LlvmLibcCbrtfTest, SpecialNumbers) {
+  EXPECT_FP_EQ_ALL_ROUNDING(aNaN, LIBC_NAMESPACE::cbrtf16(aNaN));
+  EXPECT_FP_EQ_ALL_ROUNDING(inf, LIBC_NAMESPACE::cbrtf16(inf));
+  EXPECT_FP_EQ_ALL_ROUNDING(neg_inf, LIBC_NAMESPACE::cbrtf16(neg_inf));
+  EXPECT_FP_EQ_ALL_ROUNDING(zero, LIBC_NAMESPACE::cbrtf16(zero));
+  EXPECT_FP_EQ_ALL_ROUNDING(neg_zero, LIBC_NAMESPACE::cbrtf16(neg_zero));
+  EXPECT_FP_EQ_ALL_ROUNDING(1.0f, LIBC_NAMESPACE::cbrtf16(1.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(-1.0f, LIBC_NAMESPACE::cbrtf16(-1.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(2.0f, LIBC_NAMESPACE::cbrtf16(8.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(-2.0f, LIBC_NAMESPACE::cbrtf16(-8.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(3.0f, LIBC_NAMESPACE::cbrtf16(27.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(-3.0f, LIBC_NAMESPACE::cbrtf16(-27.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(5.0f, LIBC_NAMESPACE::cbrtf16(125.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(-5.0f, LIBC_NAMESPACE::cbrtf16(-125.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(40.0f, LIBC_NAMESPACE::cbrtf16(0x1.f4p15));
+  EXPECT_FP_EQ_ALL_ROUNDING(-40.0f, LIBC_NAMESPACE::cbrtf16(-0x1.f4p15));
+}

@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-libc

Author: Krishna Pandey (krishna2803)

Changes

This PR implements cbrtf16 for half precision float (float16).

Closes #132199.


Full diff: https://github.com/llvm/llvm-project/pull/132484.diff

13 Files Affected:

  • (modified) libc/config/gpu/amdgpu/entrypoints.txt (+1)
  • (modified) libc/config/gpu/nvptx/entrypoints.txt (+1)
  • (modified) libc/config/linux/aarch64/entrypoints.txt (+1)
  • (modified) libc/config/linux/x86_64/entrypoints.txt (+1)
  • (modified) libc/src/math/CMakeLists.txt (+1)
  • (added) libc/src/math/cbrtf16.h (+21)
  • (modified) libc/src/math/generic/CMakeLists.txt (+18)
  • (modified) libc/src/math/generic/cbrtf.cpp (+1-1)
  • (added) libc/src/math/generic/cbrtf16.cpp (+164)
  • (modified) libc/test/src/math/CMakeLists.txt (+12)
  • (added) libc/test/src/math/cbrtf16_test.cpp (+56)
  • (modified) libc/test/src/math/smoke/CMakeLists.txt (+10)
  • (added) libc/test/src/math/smoke/cbrtf16_test.cpp (+33)
diff --git a/libc/config/gpu/amdgpu/entrypoints.txt b/libc/config/gpu/amdgpu/entrypoints.txt
index 291d86b4dd587..0a155e68f307d 100644
--- a/libc/config/gpu/amdgpu/entrypoints.txt
+++ b/libc/config/gpu/amdgpu/entrypoints.txt
@@ -520,6 +520,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
   list(APPEND TARGET_LIBM_ENTRYPOINTS
     # math.h C23 _Float16 entrypoints
     libc.src.math.canonicalizef16
+    libc.src.math.cbrtf16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
     libc.src.math.coshf16
diff --git a/libc/config/gpu/nvptx/entrypoints.txt b/libc/config/gpu/nvptx/entrypoints.txt
index 1ea0d9b03b37e..010265bd915cb 100644
--- a/libc/config/gpu/nvptx/entrypoints.txt
+++ b/libc/config/gpu/nvptx/entrypoints.txt
@@ -522,6 +522,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
   list(APPEND TARGET_LIBM_ENTRYPOINTS
     # math.h C23 _Float16 entrypoints
     libc.src.math.canonicalizef16
+    libc.src.math.cbrtf16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
     libc.src.math.coshf16
diff --git a/libc/config/linux/aarch64/entrypoints.txt b/libc/config/linux/aarch64/entrypoints.txt
index 5c4913a658c2f..651bac5b9e79a 100644
--- a/libc/config/linux/aarch64/entrypoints.txt
+++ b/libc/config/linux/aarch64/entrypoints.txt
@@ -646,6 +646,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
   list(APPEND TARGET_LIBM_ENTRYPOINTS
     # math.h C23 _Float16 entrypoints
     libc.src.math.canonicalizef16
+    libc.src.math.cbrtf16
     libc.src.math.ceilf16
     libc.src.math.copysignf16
     libc.src.math.cospif16
diff --git a/libc/config/linux/x86_64/entrypoints.txt b/libc/config/linux/x86_64/entrypoints.txt
index 124b80d03d846..7f22b4a82ce14 100644
--- a/libc/config/linux/x86_64/entrypoints.txt
+++ b/libc/config/linux/x86_64/entrypoints.txt
@@ -660,6 +660,7 @@ if(LIBC_TYPES_HAS_FLOAT16)
     libc.src.math.cosf16
     libc.src.math.coshf16
     libc.src.math.cospif16
+    libc.src.math.cbrtf16
     libc.src.math.exp10f16
     libc.src.math.exp10m1f16
     libc.src.math.exp2f16
diff --git a/libc/src/math/CMakeLists.txt b/libc/src/math/CMakeLists.txt
index f18a73d46f9aa..7bcf4ee294126 100644
--- a/libc/src/math/CMakeLists.txt
+++ b/libc/src/math/CMakeLists.txt
@@ -78,6 +78,7 @@ add_math_entrypoint_object(iscanonicalf128)
 
 add_math_entrypoint_object(cbrt)
 add_math_entrypoint_object(cbrtf)
+add_math_entrypoint_object(cbrtf16)
 
 add_math_entrypoint_object(ceil)
 add_math_entrypoint_object(ceilf)
diff --git a/libc/src/math/cbrtf16.h b/libc/src/math/cbrtf16.h
new file mode 100644
index 0000000000000..bddb1eda0a9fa
--- /dev/null
+++ b/libc/src/math/cbrtf16.h
@@ -0,0 +1,21 @@
+//===-- Implementation header for cbrtf16 -----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIBC_SRC_MATH_CBRTF16_H
+#define LLVM_LIBC_SRC_MATH_CBRTF16_H
+
+#include "src/__support/macros/config.h"           // LIBC_NAMESPACE_DECL
+#include "src/__support/macros/properties/types.h" // float16
+
+namespace LIBC_NAMESPACE_DECL {
+
+float16 cbrtf16(float16 x);
+
+} // namespace LIBC_NAMESPACE_DECL
+
+#endif // LLVM_LIBC_SRC_MATH_CBRTF16_H
diff --git a/libc/src/math/generic/CMakeLists.txt b/libc/src/math/generic/CMakeLists.txt
index db07bd1a098cc..f88f133d8efdd 100644
--- a/libc/src/math/generic/CMakeLists.txt
+++ b/libc/src/math/generic/CMakeLists.txt
@@ -4816,6 +4816,24 @@ add_entrypoint_object(
     libc.src.__support.integer_literals
 )
 
+add_entrypoint_object(
+  cbrtf16
+  SRCS
+    cbrtf16.cpp
+  HDRS
+    ../cbrtf16.h
+  DEPENDS
+    libc.hdr.fenv_macros
+    libc.src.__support.FPUtil.double_double
+    libc.src.__support.FPUtil.dyadic_float
+    libc.src.__support.FPUtil.fenv_impl
+    libc.src.__support.FPUtil.fp_bits
+    libc.src.__support.FPUtil.multiply_add
+    libc.src.__support.FPUtil.polyeval
+    libc.src.__support.macros.optimization
+    libc.src.__support.integer_literals
+)
+
 add_entrypoint_object(
   dmull
   SRCS
diff --git a/libc/src/math/generic/cbrtf.cpp b/libc/src/math/generic/cbrtf.cpp
index 71b23c4a8c742..868790ee7c7c0 100644
--- a/libc/src/math/generic/cbrtf.cpp
+++ b/libc/src/math/generic/cbrtf.cpp
@@ -22,7 +22,7 @@ namespace {
 // Look up table for 2^(i/3) for i = 0, 1, 2.
 constexpr double CBRT2[3] = {1.0, 0x1.428a2f98d728bp0, 0x1.965fea53d6e3dp0};
 
-// Degree-7 polynomials approximation of ((1 + x)^(1/3) - 1)/x for 0 <= x <= 1
+// Degree-6 polynomials approximation of ((1 + x)^(1/3) - 1)/x for 0 <= x <= 1
 // generated by Sollya with:
 // > for i from 0 to 15 do {
 //     P = fpminimax(((1 + x)^(1/3) - 1)/x, 6, [|D...|], [i/16, (i + 1)/16]);
diff --git a/libc/src/math/generic/cbrtf16.cpp b/libc/src/math/generic/cbrtf16.cpp
new file mode 100644
index 0000000000000..782f1bfd1b100
--- /dev/null
+++ b/libc/src/math/generic/cbrtf16.cpp
@@ -0,0 +1,164 @@
+//===-- Implementation of sqrtf16 function --------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/cbrtf16.h"
+#include "hdr/fenv_macros.h"
+#include "src/__support/FPUtil/FEnvImpl.h"
+#include "src/__support/FPUtil/FPBits.h"
+#include "src/__support/FPUtil/multiply_add.h"
+#include "src/__support/common.h"
+#include "src/__support/macros/config.h"
+#include "src/__support/macros/optimization.h" // LIBC_UNLIKELY
+
+namespace LIBC_NAMESPACE_DECL {
+
+namespace {
+
+// Look up table for 2^(i/3) for i = 0, 1, 2 in single precision
+constexpr float CBRT2[3] = {0x1p0f, 0x1.428a3p0f, 0x1.965feap0f};
+
+// Degree-4 polynomials approximation of ((1 + x)^(1/3) - 1)/x for 0 <= x <= 1
+// generated by Sollya with:
+// > display=hexadecimal;
+// for i from 0 to 15 do {
+//   P = fpminimax(((1 + x)^(1/3) - 1)/x, 4, [|SG...|], [i/16, (i + 1)/16]);
+//   print("{", coeff(P, 0), ",", coeff(P, 1), ",", coeff(P, 2), ",",
+//         coeff(P, 3), coeff(P, 4),"},");
+// };
+// Then (1 + x)^(1/3) ~ 1 + x * P(x).
+// For example: for 0 <= x <= 1/8:
+// P(x) = 0x1.555556p-2 + x * (-0x1.c71d38p-4 + x * (0x1.f9b95ap-5 + x *
+// (-0x1.4ebe18p-5 + x * 0x1.9ca9d2p-6)))
+
+constexpr float COEFFS[16][5] = {
+    {0x1.555556p-2f, -0x1.c71ea4p-4f, 0x1.faa5f2p-5f, -0x1.64febep-5f,
+     0x1.733a46p-5f},
+    {0x1.55554ep-2f, -0x1.c715f6p-4f, 0x1.f88a9ep-5f, -0x1.4456e8p-5f,
+     0x1.5b5ef2p-6f},
+    {0x1.555508p-2f, -0x1.c6f404p-4f, 0x1.f56b7ap-5f, -0x1.33cff8p-5f,
+     0x1.18f146p-6f},
+    {0x1.5553fcp-2f, -0x1.c69bacp-4f, 0x1.efed98p-5f, -0x1.204706p-5f,
+     0x1.c90976p-7f},
+    {0x1.55517p-2f, -0x1.c5f996p-4f, 0x1.e85932p-5f, -0x1.0c0c0ep-5f,
+     0x1.77c766p-7f},
+    {0x1.554c96p-2f, -0x1.c501d2p-4f, 0x1.df0fc4p-5f, -0x1.f067f2p-6f,
+     0x1.380ab8p-7f},
+    {0x1.55448cp-2f, -0x1.c3ab1ep-4f, 0x1.d45876p-5f, -0x1.ca3988p-6f,
+     0x1.04f38ap-7f},
+    {0x1.5538aap-2f, -0x1.c1f886p-4f, 0x1.c8b11p-5f, -0x1.a6a16cp-6f,
+     0x1.b847c2p-8f},
+    {0x1.55278ap-2f, -0x1.bfd538p-4f, 0x1.bbde6p-5f, -0x1.846a8cp-6f,
+     0x1.73bfcp-8f},
+    {0x1.5511dp-2f, -0x1.bd6c88p-4f, 0x1.af0a3ap-5f, -0x1.660852p-6f,
+     0x1.3dbe34p-8f},
+    {0x1.54f82ap-2f, -0x1.bada56p-4f, 0x1.a2aa0ep-5f, -0x1.4b8c2ap-6f,
+     0x1.13379cp-8f},
+    {0x1.54d512p-2f, -0x1.b7a936p-4f, 0x1.94b91ep-5f, -0x1.30792cp-6f,
+     0x1.d7883cp-9f},
+    {0x1.54a8d8p-2f, -0x1.b3fde2p-4f, 0x1.861aeep-5f, -0x1.169484p-6f,
+     0x1.92b4cap-9f},
+    {0x1.548126p-2f, -0x1.b0f4a8p-4f, 0x1.7af574p-5f, -0x1.04644ep-6f,
+     0x1.662fb6p-9f},
+    {0x1.544b9p-2f, -0x1.ad2124p-4f, 0x1.6dd75p-5f, -0x1.e0cbecp-7f,
+     0x1.387692p-9f},
+    {0x1.5422c6p-2f, -0x1.aa61bp-4f, 0x1.64f4bap-5f, -0x1.c742b2p-7f,
+     0x1.1cf15ap-9f},
+};
+
+} // anonymous namespace
+
+LLVM_LIBC_FUNCTION(float16, cbrtf16, (float16 x)) {
+  using FPBits = fputil::FPBits<float16>;
+  using FloatBits = fputil::FPBits<float>;
+
+  FPBits x_bits(x);
+
+  uint16_t x_u = x_bits.uintval();
+  uint16_t x_abs = x_u & 0x7fff;
+  uint32_t sign_bit = (x_u >> 15) << FloatBits::EXP_LEN;
+
+  // cbrtf16(0) = 0, cbrtf16(NaN) = NaN
+  if (LIBC_UNLIKELY(x_abs == 0 || x_abs >= 0x7C00)) {
+    if (x_bits.is_signaling_nan()) {
+      fputil::raise_except(FE_INVALID);
+      return FPBits::quiet_nan().uintval();
+    }
+    return x;
+  }
+
+  float xf = static_cast<float>(x);
+  FloatBits xf_bits(xf);
+
+  unsigned x_e = static_cast<unsigned>(xf_bits.get_exponent());
+  unsigned out_e = (x_e / 3 + 127) | sign_bit;
+
+  unsigned shift_e = x_e % 3;
+
+  // Set x_m = 2^(x_e % 3) * (1 + mantissa)
+  uint32_t x_m = xf_bits.get_mantissa();
+
+  // Use the leading 4 bits for look up table
+  unsigned idx = static_cast<unsigned>(x_m >> (FloatBits::FRACTION_LEN - 4));
+
+  x_m |= static_cast<uint32_t>(FloatBits::EXP_BIAS) << FloatBits::FRACTION_LEN;
+
+  float x_reduced = FloatBits(x_m).get_val();
+  float dx = x_reduced - 1.0f;
+
+  float dx_sq = dx * dx;
+
+  // fputil::multiply_add(x, y, z) = x*y + z
+
+  // c0 =  1 + x * a0
+  float c0 = fputil::multiply_add(dx, COEFFS[idx][0], 1.0f);
+  // c1 = a1 + x * a2
+  float c1 = fputil::multiply_add(dx, COEFFS[idx][2], COEFFS[idx][1]);
+  // c2 = a3 + x * a4
+  float c2 = fputil::multiply_add(dx, COEFFS[idx][4], COEFFS[idx][3]);
+  // we save a multiply_add operation by decreasing the polynomial degree by 2
+  // i.e. using a degree-4 polynomial instead of degree 6.
+
+  float dx_4 = dx_sq * dx_sq;
+
+  // p0 = c0 + x^2 * c1
+  // p0 = (1 + x * a0) + x^2 * (a1 + x * a2)
+  // p0 = 1 + x * a0 + x^2 * a1 + x^3 * a2
+  float p0 = fputil::multiply_add(dx_sq, c1, c0);
+
+  // p1 = c2
+  // p1 = x * a4
+  float p1 = c2;
+
+  // r = p0 + x^4 * p1
+  // r = (1 + x * a0 + x^2 * a1 + x^3 * a2) + x^4 (x * a4)
+  // r = 1 + x * a0 + x^2 * a1 + x^3 * a2 + x^5 * a4
+  // r = 1 + x * (a0 + a1 * x + a2 * x^2 + a3 * x^3 + a4 * x^4)
+  // r = 1 + x * P(x)
+  float r = fputil::multiply_add(dx_4, p1, p0) * CBRT2[shift_e];
+
+  uint32_t r_m = FloatBits(r).get_mantissa();
+  // For float, mantissa is 23 bits (instead of 52 for double)
+  // Check if the output is exact. To be exact, the smallest 1-bit of the
+  // output has to be at least 2^-7 or higher. So we check the lowest 15 bits
+  // to see if they are within 2^(-23 + 3) errors from all zeros, then the
+  // result cube root is exact.
+  if (LIBC_UNLIKELY(((r_m + 4) & 0x7fff) <= 8)) {
+    if ((r_m & 0x7fff) <= 4)
+      r_m &= 0xffff'ffe0;
+    else
+      r_m = (r_m & 0xffff'ffe0) + 0x20; // Round up to next multiple of 0x20
+    fputil::clear_except_if_required(FE_INEXACT);
+  }
+
+  uint32_t r_bits =
+      r_m | (static_cast<uint32_t>(out_e) << FloatBits::FRACTION_LEN);
+
+  return static_cast<float16>(FloatBits(r_bits).get_val());
+}
+
+} // namespace LIBC_NAMESPACE_DECL
diff --git a/libc/test/src/math/CMakeLists.txt b/libc/test/src/math/CMakeLists.txt
index beafa87e03a77..1fb7f47b1d541 100644
--- a/libc/test/src/math/CMakeLists.txt
+++ b/libc/test/src/math/CMakeLists.txt
@@ -2655,6 +2655,18 @@ add_fp_unittest(
     libc.src.__support.FPUtil.fp_bits
 )
 
+add_fp_unittest(
+  cbrtf16_test
+  NEED_MPFR
+  SUITE
+    libc-math-unittests
+  SRCS
+    cbrtf16_test.cpp
+  DEPENDS
+    libc.src.math.cbrtf16
+    libc.src.__support.FPUtil.fp_bits
+)
+
 add_fp_unittest(
   dmull_test
   NEED_MPFR
diff --git a/libc/test/src/math/cbrtf16_test.cpp b/libc/test/src/math/cbrtf16_test.cpp
new file mode 100644
index 0000000000000..2e2cfc079aeb5
--- /dev/null
+++ b/libc/test/src/math/cbrtf16_test.cpp
@@ -0,0 +1,56 @@
+//===-- Unittests for cbrtf16 ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "hdr/math_macros.h"
+#include "src/__support/FPUtil/FPBits.h"
+#include "src/math/cbrtf16.h"
+#include "test/UnitTest/FPMatcher.h"
+#include "test/UnitTest/Test.h"
+#include "utils/MPFRWrapper/MPFRUtils.h"
+
+using LlvmLibcCbrtf16Test = LIBC_NAMESPACE::testing::FPTest<float16>;
+
+namespace mpfr = LIBC_NAMESPACE::testing::mpfr;
+
+// Range: [0, Inf];
+static constexpr uint16_t POS_START = 0x0000U;
+static constexpr uint16_t POS_STOP = 0x7c00U;
+
+// Range: [-Inf, 0]
+static constexpr uint16_t NEG_START = 0x8000U;
+static constexpr uint16_t NEG_STOP = 0xfc00U;
+
+TEST_F(LlvmLibcCbrtf16Test, PositiveRange) {
+  for (uint16_t v = POS_START; v <= POS_STOP; ++v) {
+    float16 x = FPBits(v).get_val();
+    EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Cbrt, x,
+                                   LIBC_NAMESPACE::cbrtf16(x), 0.5);
+  }
+}
+
+TEST_F(LlvmLibcCbrtf16Test, NegativeRange) {
+  for (uint16_t v = NEG_START; v <= NEG_STOP; ++v) {
+    float16 x = FPBits(v).get_val();
+    EXPECT_MPFR_MATCH_ALL_ROUNDING(mpfr::Operation::Cbrt, x,
+                                   LIBC_NAMESPACE::cbrtf16(x), 0.5);
+  }
+}
+
+TEST_F(LlvmLibcCbrtf16Test, SpecialValues) {
+  constexpr uint16_t INPUTS[] = {
+      0x4a00, 0x4500, 0x4e00, 0x0c00, 0x4940,
+  };
+  for (uint16_t v : INPUTS) {
+    float16 x = FPBits(v).get_val();
+    mpfr::ForceRoundingMode r(mpfr::RoundingMode::Upward);
+    EXPECT_MPFR_MATCH(mpfr::Operation::Cbrt, x, LIBC_NAMESPACE::cbrtf16(x), 0.5,
+                      mpfr::RoundingMode::Upward);
+  }
+
+  ASSERT_EQ(1, 1);
+}
diff --git a/libc/test/src/math/smoke/CMakeLists.txt b/libc/test/src/math/smoke/CMakeLists.txt
index 94ec099ddfcbc..f7311d4f2080c 100644
--- a/libc/test/src/math/smoke/CMakeLists.txt
+++ b/libc/test/src/math/smoke/CMakeLists.txt
@@ -5042,6 +5042,16 @@ add_fp_unittest(
     libc.src.math.cbrt
 )
 
+add_fp_unittest(
+  cbrtf16_test
+  SUITE
+    libc-math-smoke-tests
+  SRCS
+    cbrtf16_test.cpp
+  DEPENDS
+    libc.src.math.cbrtf16
+)
+
 add_fp_unittest(
   dmull_test
   SUITE
diff --git a/libc/test/src/math/smoke/cbrtf16_test.cpp b/libc/test/src/math/smoke/cbrtf16_test.cpp
new file mode 100644
index 0000000000000..7c8a7103273a0
--- /dev/null
+++ b/libc/test/src/math/smoke/cbrtf16_test.cpp
@@ -0,0 +1,33 @@
+//===-- Unittests for cbrtf16 ---------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/math/cbrtf16.h"
+#include "test/UnitTest/FPMatcher.h"
+#include "test/UnitTest/Test.h"
+
+using LlvmLibcCbrtfTest = LIBC_NAMESPACE::testing::FPTest<float16>;
+
+using LIBC_NAMESPACE::testing::tlog;
+
+TEST_F(LlvmLibcCbrtfTest, SpecialNumbers) {
+  EXPECT_FP_EQ_ALL_ROUNDING(aNaN, LIBC_NAMESPACE::cbrtf16(aNaN));
+  EXPECT_FP_EQ_ALL_ROUNDING(inf, LIBC_NAMESPACE::cbrtf16(inf));
+  EXPECT_FP_EQ_ALL_ROUNDING(neg_inf, LIBC_NAMESPACE::cbrtf16(neg_inf));
+  EXPECT_FP_EQ_ALL_ROUNDING(zero, LIBC_NAMESPACE::cbrtf16(zero));
+  EXPECT_FP_EQ_ALL_ROUNDING(neg_zero, LIBC_NAMESPACE::cbrtf16(neg_zero));
+  EXPECT_FP_EQ_ALL_ROUNDING(1.0f, LIBC_NAMESPACE::cbrtf16(1.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(-1.0f, LIBC_NAMESPACE::cbrtf16(-1.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(2.0f, LIBC_NAMESPACE::cbrtf16(8.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(-2.0f, LIBC_NAMESPACE::cbrtf16(-8.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(3.0f, LIBC_NAMESPACE::cbrtf16(27.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(-3.0f, LIBC_NAMESPACE::cbrtf16(-27.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(5.0f, LIBC_NAMESPACE::cbrtf16(125.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(-5.0f, LIBC_NAMESPACE::cbrtf16(-125.0f));
+  EXPECT_FP_EQ_ALL_ROUNDING(40.0f, LIBC_NAMESPACE::cbrtf16(0x1.f4p15));
+  EXPECT_FP_EQ_ALL_ROUNDING(-40.0f, LIBC_NAMESPACE::cbrtf16(-0x1.f4p15));
+}

Signed-off-by: krishna2803 <kpandey81930@gmail.com>
@krishna2803
Copy link
Contributor Author

hmm... there seem to be some issues when calculating the cube roots of very small or very large numbers.

@krishna2803
Copy link
Contributor Author

oh no, this seems fine for | x | 1 but not for | x | < 1

@krishna2803
Copy link
Contributor Author

krishna2803 commented Mar 21, 2025

uh oh nvm we do have a failure for some x > 1.0 as well (though all those are 1  ulp

[3/3] Running unit test libc.test.src.math.cbrtf16_test.__unit__
FAILED: libc/test/src/math/CMakeFiles/libc.test.src.math.cbrtf16_test.__unit__ /home/krishna/OpenSource/llvm-project/build/libc/test/src/math/CMakeFiles/libc.test.src.math.cbrtf16_test.__unit__ 
cd /home/krishna/OpenSource/llvm-project/build/libc/test/src/math && /home/krishna/OpenSource/llvm-project/build/libc/test/src/math/libc.test.src.math.cbrtf16_test.__unit__.__build__
[==========] Running 1 test from 1 test suite.
[ RUN      ] LlvmLibcCbrtf16Test.PositiveRange
/home/krishna/OpenSource/llvm-project/libc/test/src/math/cbrtf16_test.cpp:33: FAILURE
Failed to match __llvm_libc_21_0_0_git::cbrtf16(x) against LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<mpfr::Operation::Cbrt>( x, __llvm_libc_21_0_0_git::cbrtf16(x), 0.5, mpfr::RoundingMode::Upward).
Match value not within tolerance value of MPFR result:
  Input decimal: 3.16210937500000000000000000000000000000000000000000
     Input bits: 0x4253 = (S: 0, E: 0x0010, M: 0x0253)

  Match decimal: 1.46875000000000000000000000000000000000000000000000
     Match bits: 0x3DE0 = (S: 0, E: 0x000F, M: 0x01E0)

    MPFR result: 1.46777343750000000000000000000000000000000000000000
   MPFR rounded: 0x3DDF = (S: 0, E: 0x000F, M: 0x01DF)

      ULP error: 1.00000000000000000000000000000000000000000000000000
/home/krishna/OpenSource/llvm-project/libc/test/src/math/cbrtf16_test.cpp:33: FAILURE
Failed to match __llvm_libc_21_0_0_git::cbrtf16(x) against LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<mpfr::Operation::Cbrt>( x, __llvm_libc_21_0_0_git::cbrtf16(x), 0.5, mpfr::RoundingMode::Upward).
Match value not within tolerance value of MPFR result:
  Input decimal: 25.29687500000000000000000000000000000000000000000000
     Input bits: 0x4E53 = (S: 0, E: 0x0013, M: 0x0253)

  Match decimal: 2.93750000000000000000000000000000000000000000000000
     Match bits: 0x41E0 = (S: 0, E: 0x0010, M: 0x01E0)

    MPFR result: 2.93554687500000000000000000000000000000000000000000
   MPFR rounded: 0x41DF = (S: 0, E: 0x0010, M: 0x01DF)

      ULP error: 1.00000000000000000000000000000000000000000000000000
/home/krishna/OpenSource/llvm-project/libc/test/src/math/cbrtf16_test.cpp:33: FAILURE
Failed to match __llvm_libc_21_0_0_git::cbrtf16(x) against LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<mpfr::Operation::Cbrt>( x, __llvm_libc_21_0_0_git::cbrtf16(x), 0.5, mpfr::RoundingMode::Upward).
Match value not within tolerance value of MPFR result:
  Input decimal: 202.37500000000000000000000000000000000000000000000000
     Input bits: 0x5A53 = (S: 0, E: 0x0016, M: 0x0253)

  Match decimal: 5.87500000000000000000000000000000000000000000000000
     Match bits: 0x45E0 = (S: 0, E: 0x0011, M: 0x01E0)

    MPFR result: 5.87109375000000000000000000000000000000000000000000
   MPFR rounded: 0x45DF = (S: 0, E: 0x0011, M: 0x01DF)

      ULP error: 1.00000000000000000000000000000000000000000000000000
/home/krishna/OpenSource/llvm-project/libc/test/src/math/cbrtf16_test.cpp:33: FAILURE
Failed to match __llvm_libc_21_0_0_git::cbrtf16(x) against LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<mpfr::Operation::Cbrt>( x, __llvm_libc_21_0_0_git::cbrtf16(x), 0.5, mpfr::RoundingMode::Upward).
Match value not within tolerance value of MPFR result:
  Input decimal: 1619.00000000000000000000000000000000000000000000000000
     Input bits: 0x6653 = (S: 0, E: 0x0019, M: 0x0253)

  Match decimal: 11.75000000000000000000000000000000000000000000000000
     Match bits: 0x49E0 = (S: 0, E: 0x0012, M: 0x01E0)

    MPFR result: 11.74218750000000000000000000000000000000000000000000
   MPFR rounded: 0x49DF = (S: 0, E: 0x0012, M: 0x01DF)

      ULP error: 1.00000000000000000000000000000000000000000000000000
/home/krishna/OpenSource/llvm-project/libc/test/src/math/cbrtf16_test.cpp:33: FAILURE
Failed to match __llvm_libc_21_0_0_git::cbrtf16(x) against LIBC_NAMESPACE::testing::mpfr::get_mpfr_matcher<mpfr::Operation::Cbrt>( x, __llvm_libc_21_0_0_git::cbrtf16(x), 0.5, mpfr::RoundingMode::Upward).
Match value not within tolerance value of MPFR result:
  Input decimal: 12952.00000000000000000000000000000000000000000000000000
     Input bits: 0x7253 = (S: 0, E: 0x001C, M: 0x0253)

  Match decimal: 23.50000000000000000000000000000000000000000000000000
     Match bits: 0x4DE0 = (S: 0, E: 0x0013, M: 0x01E0)

    MPFR result: 23.48437500000000000000000000000000000000000000000000
   MPFR rounded: 0x4DDF = (S: 0, E: 0x0013, M: 0x01DF)

      ULP error: 1.00000000000000000000000000000000000000000000000000
[  FAILED  ] LlvmLibcCbrtf16Test.PositiveRange
Ran 1 tests.  PASS: 0  FAIL: 1
ninja: build stopped: subcommand failed.

Comment on lines 97 to 100
unsigned x_e = static_cast<unsigned>(xf_bits.get_exponent());
unsigned out_e = (x_e / 3 + 127) | sign_bit;

unsigned shift_e = x_e % 3;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think the error for | x | < 1 is due to these

Signed-off-by: krishna2803 <kpandey81930@gmail.com>
@krishna2803
Copy link
Contributor Author

@overmighty @lntue can you please give some guidance?


uint16_t x_u = x_bits.uintval();
uint16_t x_abs = x_u & 0x7fff;
uint32_t sign_bit = (x_u >> 15) << FloatBits::EXP_LEN;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cast (x_u >> 15) to uint32_t before left-shifting.

Signed-off-by: krishna2803 <kpandey81930@gmail.com>
float xf = static_cast<float>(x);
FloatBits xf_bits(xf);

unsigned x_e = static_cast<unsigned>(xf_bits.get_exponent());
Copy link
Contributor

@lntue lntue Mar 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

casting x_e to unsigned before dividing by 3 will give you completely wrong results for negative x_e.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh yeah sorry! i completely missed that it could be negative as wel

FloatBits xf_bits(xf);

unsigned x_e = static_cast<unsigned>(xf_bits.get_exponent());
unsigned out_e = x_e / 3 + 127;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x_e/3 is round-to-zero, so it won't be correct when x_e is negative and not divisible by 3.
If we analyze this a bit more careful: let x_e be the real exponent of x, the exponent field in xf will have value:

  x_e_biased = x_e + 127

and it is positive.
What we want is to have the final result as floor(x_e / 3) + 127. We can get that simply by subtracting 1 from the biased exponent of x before dividing by 3:

  floor( (x_e_biased - 1) / 3 ) = floor( (x_e + 126) / 3 ) = floor( x_e/3 + 42 ) = floor(x_e/3) + 42.

So you could have:

  uint16_t x_e = xf_bits.get_biased_exponent();
  uint16_t out_e = (x_e - 1) / 3 + (127 - 42);
  uint16_t shift_e = (x_e - 1) % 3;

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in 4c9b493a

Signed-off-by: krishna2803 <kpandey81930@gmail.com>
Signed-off-by: krishna2803 <kpandey81930@gmail.com>
Signed-off-by: krishna2803 <kpandey81930@gmail.com>
Signed-off-by: krishna2803 <kpandey81930@gmail.com>
Signed-off-by: krishna2803 <kpandey81930@gmail.com>
@krishna2803
Copy link
Contributor Author

there still seems to be 1 ULP errors for the some numbers where I'm underestimating the result. I'll investigate

unsigned out_e = x_e / 3 + 127;

unsigned shift_e = x_e % 3;
uint32_t x_e = xf_bits.get_biased_exponent();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the comments to explain how we derived the exponent computations.

Signed-off-by: krishna2803 <kpandey81930@gmail.com>
@krishna2803
Copy link
Contributor Author

krishna2803 commented Mar 25, 2025

x (dec) x (hex) cbrtf16(x) mpfr::Operation::Cbrt(x) Error (ULP) Rounding Mode
0.000096499919891357421875 0x0653 0.0458984375 0.045867919921875 1 Upward
0.000771999359130859375 0x1253 0.091796875 0.09173583984375 1 Upward
0.006175994873046875 0x1E53 0.18359375 0.1834716796875 1 Upward
0.049407958984375 0x2A53 0.3671875 0.366943359375 1 Upward
0.395263671875 0x3653 0.734375 0.73388671875 1 Upward
3.162109375 0x4253 1.46875 1.4677734375 1 Upward
25.296875 0x4E53 2.9375 2.935546875 1 Upward
202.375 0x5A53 5.875 5.87109375 1 Upward
1619 0x6653 11.75 11.7421875 1 Upward
12952 0x7253 23.5 23.484375 1 Upward
-0.000096499919891357421875 0x8653 -0.045867919921875 -0.04583740234375 1 Upward
-0.000771999359130859375 0x9253 -0.09173583984375 -0.0916748046875 1 Upward
-0.006175994873046875 0x9E53 -0.1834716796875 -0.183349609375 1 Upward
-0.049407958984375 0xAA53 -0.366943359375 -0.36669921875 1 Upward
-0.395263671875 0xB653 -0.73388671875 -0.7333984375 1 Upward
-3.162109375 0xC253 -1.4677734375 -1.466796875 1 Upward
-25.296875 0xCE53 -2.935546875 -2.93359375 1 Upward
-202.375 0xDA53 -5.87109375 -5.8671875 1 Upward
-1619 0xE653 -11.7421875 -11.734375 1 Upward
-12952 0xF253 -23.484375 -23.46875 1 Upward

after a bit of investigation, i found out that these errors occur when the rounding mode is upward. the result of cbrtf16 function doesn't match with mpfr. and interestingly this occurs when the last byte of the input is 0x53.

@lntue
Copy link
Contributor

lntue commented Mar 26, 2025

x (dec) x (hex) cbrtf16(x) mpfr::Operation::Cbrt(x) Error (ULP) Rounding Mode
0.000096499919891357421875 0x0653 0.0458984375 0.045867919921875 1 Upward
0.000771999359130859375 0x1253 0.091796875 0.09173583984375 1 Upward
0.006175994873046875 0x1E53 0.18359375 0.1834716796875 1 Upward
0.049407958984375 0x2A53 0.3671875 0.366943359375 1 Upward
0.395263671875 0x3653 0.734375 0.73388671875 1 Upward
3.162109375 0x4253 1.46875 1.4677734375 1 Upward
25.296875 0x4E53 2.9375 2.935546875 1 Upward
202.375 0x5A53 5.875 5.87109375 1 Upward
1619 0x6653 11.75 11.7421875 1 Upward
12952 0x7253 23.5 23.484375 1 Upward
-0.000096499919891357421875 0x8653 -0.045867919921875 -0.04583740234375 1 Upward
-0.000771999359130859375 0x9253 -0.09173583984375 -0.0916748046875 1 Upward
-0.006175994873046875 0x9E53 -0.1834716796875 -0.183349609375 1 Upward
-0.049407958984375 0xAA53 -0.366943359375 -0.36669921875 1 Upward
-0.395263671875 0xB653 -0.73388671875 -0.7333984375 1 Upward
-3.162109375 0xC253 -1.4677734375 -1.466796875 1 Upward
-25.296875 0xCE53 -2.935546875 -2.93359375 1 Upward
-202.375 0xDA53 -5.87109375 -5.8671875 1 Upward
-1619 0xE653 -11.7421875 -11.734375 1 Upward
-12952 0xF253 -23.484375 -23.46875 1 Upward
after a bit of investigation, i found out that these errors occur when the rounding mode is upward. the result of cbrtf16 function doesn't match with mpfr. and interestingly this occurs when the last byte of the input is 0x53.

I would add one of these failure to the smoke test, commenting out all other check in the smoke test, add a bunch of printing statements showing the bits / values of intermediate computations, especially around the accuracy test and see what happens there and after.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[libc][math][c23] Implement C23 math function cbrtf16
3 participants