From 64ba019fe30ea12c554296b756004d5b36c30154 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Wed, 6 Dec 2023 13:51:38 -0800 Subject: [PATCH 01/10] Added another multiplication method for decimal 128 --- src/main/cpp/src/DecimalUtilsJni.cpp | 11 +++-- src/main/cpp/src/decimal_utils.cu | 40 +++++++++++-------- src/main/cpp/src/decimal_utils.hpp | 1 + .../nvidia/spark/rapids/jni/DecimalUtils.java | 25 +++++++++++- .../spark/rapids/jni/DecimalUtilsTest.java | 16 ++++++++ 5 files changed, 72 insertions(+), 21 deletions(-) diff --git a/src/main/cpp/src/DecimalUtilsJni.cpp b/src/main/cpp/src/DecimalUtilsJni.cpp index f732276817..6c7c1cc781 100644 --- a/src/main/cpp/src/DecimalUtilsJni.cpp +++ b/src/main/cpp/src/DecimalUtilsJni.cpp @@ -19,8 +19,13 @@ extern "C" { -JNIEXPORT jlongArray JNICALL Java_com_nvidia_spark_rapids_jni_DecimalUtils_multiply128( - JNIEnv* env, jclass, jlong j_view_a, jlong j_view_b, jint j_product_scale) +JNIEXPORT jlongArray JNICALL +Java_com_nvidia_spark_rapids_jni_DecimalUtils_multiply128(JNIEnv* env, + jclass, + jlong j_view_a, + jlong j_view_b, + jint j_product_scale, + bool cast_interim_result) { JNI_NULL_CHECK(env, j_view_a, "column is null", 0); JNI_NULL_CHECK(env, j_view_b, "column is null", 0); @@ -30,7 +35,7 @@ JNIEXPORT jlongArray JNICALL Java_com_nvidia_spark_rapids_jni_DecimalUtils_multi auto view_b = reinterpret_cast(j_view_b); auto scale = static_cast(j_product_scale); return cudf::jni::convert_table_for_return( - env, cudf::jni::multiply_decimal128(*view_a, *view_b, scale)); + env, cudf::jni::multiply_decimal128(*view_a, *view_b, scale, cast_interim_result)); } CATCH_STD(env, 0); } diff --git a/src/main/cpp/src/decimal_utils.cu b/src/main/cpp/src/decimal_utils.cu index 392fb495b4..563c378bff 100644 --- a/src/main/cpp/src/decimal_utils.cu +++ b/src/main/cpp/src/decimal_utils.cu @@ -657,14 +657,16 @@ struct dec128_multiplier { dec128_multiplier(bool* overflows, cudf::mutable_column_view const& product_view, cudf::column_view const& a_col, - cudf::column_view const& b_col) + cudf::column_view const& b_col, + bool const& cast_interim_result) : overflows(overflows), a_data(a_col.data<__int128_t>()), b_data(b_col.data<__int128_t>()), product_data(product_view.data<__int128_t>()), a_scale(a_col.type().scale()), b_scale(b_col.type().scale()), - prod_scale(product_view.type().scale()) + prod_scale(product_view.type().scale()), + cast_interim_result(cast_interim_result) { } @@ -675,21 +677,24 @@ struct dec128_multiplier { chunked256 product = multiply(a, b); + int dec_precision = precision10(product); + + int mult_scale = a_scale + b_scale; + // Spark does some really odd things that I personally think are a bug // https://issues.apache.org/jira/browse/SPARK-40129 // But to match Spark we need to first round the result to a precision of 38 // and this is specific to the value in the result of the multiply. // Then we need to round the result to the final scale that we care about. - int dec_precision = precision10(product); - int first_div_precision = dec_precision - 38; - - int mult_scale = a_scale + b_scale; - if (first_div_precision > 0) { - auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits(); - product = divide_and_round(product, first_div_scale_divisor); - - // a_scale and b_scale are negative. 
first_div_precision is not - mult_scale = a_scale + b_scale + first_div_precision; + if (cast_interim_result) { + int first_div_precision = dec_precision - 38; + if (first_div_precision > 0) { + auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits(); + product = divide_and_round(product, first_div_scale_divisor); + + // a_scale and b_scale are negative. first_div_precision is not + mult_scale = a_scale + b_scale + first_div_precision; + } } int exponent = prod_scale - mult_scale; @@ -718,6 +723,7 @@ struct dec128_multiplier { private: // output column for overflow detected bool* const overflows; + bool cast_interim_result; // input data for multiply __int128_t const* const a_data; @@ -968,6 +974,7 @@ namespace cudf::jni { std::unique_ptr multiply_decimal128(cudf::column_view const& a, cudf::column_view const& b, int32_t product_scale, + bool const& cast_interim_result, rmm::cuda_stream_view stream) { CUDF_EXPECTS(a.type().id() == cudf::type_id::DECIMAL128, "not a DECIMAL128 column"); @@ -992,10 +999,11 @@ std::unique_ptr multiply_decimal128(cudf::column_view const& a, auto overflows_view = columns[0]->mutable_view(); auto product_view = columns[1]->mutable_view(); check_scale_divisor(a.type().scale() + b.type().scale(), product_scale); - thrust::for_each(rmm::exec_policy(stream), - thrust::make_counting_iterator(0), - thrust::make_counting_iterator(num_rows), - dec128_multiplier(overflows_view.begin(), product_view, a, b)); + thrust::for_each( + rmm::exec_policy(stream), + thrust::make_counting_iterator(0), + thrust::make_counting_iterator(num_rows), + dec128_multiplier(overflows_view.begin(), product_view, a, b, cast_interim_result)); return std::make_unique(std::move(columns)); } diff --git a/src/main/cpp/src/decimal_utils.hpp b/src/main/cpp/src/decimal_utils.hpp index 95c6c56c3d..ed4ee30d95 100644 --- a/src/main/cpp/src/decimal_utils.hpp +++ b/src/main/cpp/src/decimal_utils.hpp @@ -30,6 +30,7 @@ std::unique_ptr multiply_decimal128( cudf::column_view const& a, cudf::column_view const& b, int32_t product_scale, + bool const& cast_interim_result, rmm::cuda_stream_view stream = cudf::get_default_stream()); std::unique_ptr divide_decimal128( diff --git a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java index 389679965a..adc559c096 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java @@ -25,11 +25,32 @@ public class DecimalUtils { NativeDepsLoader.loadNativeDeps(); } + /** + * Multiply two DECIMAL128 columns together into a DECIMAL128 product rounded to the specified + * scale with overflow detection. This method considers a precision greater than 38 as overflow + * even if the number still fits in a 128-bit representation. + * @param a factor input, must match row count of the other factor input + * @param b factor input, must match row count of the other factor input + * @param productScale scale to use for the product type + * @return table containing a boolean column and a DECIMAL128 product column of the specified + * scale. The boolean value will be true if an overflow was detected for that row's + * DECIMAL128 product value. A null input row will result in a corresponding null output + * row. 
+ */ + public static Table mul128(ColumnView a, ColumnView b, int productScale) { + return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, false)); + } /** * Multiply two DECIMAL128 columns together into a DECIMAL128 product rounded to the specified * scale with overflow detection. This method considers a precision greater than 38 as overflow * even if the number still fits in a 128-bit representation. + * + * WARNING: This method has a bug which we match with Spark versions before 3.4.2, 4.0.0, 3.5.1. Consider the + * following example using Decimal with a precision of 38 and scale of 10: + * -8533444864753048107770677711.1312637916 * -12.0000000000 = 102401338377036577293248132533.575166 + * while the actual answer based on Java BigDecimal is 102401338377036577293248132533.575165 + * * @param a factor input, must match row count of the other factor input * @param b factor input, must match row count of the other factor input * @param productScale scale to use for the product type @@ -39,7 +60,7 @@ public class DecimalUtils { * row. */ public static Table multiply128(ColumnView a, ColumnView b, int productScale) { - return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale)); + return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, true)); } /** @@ -148,7 +169,7 @@ public static Table add128(ColumnView a, ColumnView b, int targetScale) { return new Table(add128(a.getNativeView(), b.getNativeView(), targetScale)); } - private static native long[] multiply128(long viewA, long viewB, int productScale); + private static native long[] multiply128(long viewA, long viewB, int productScale, boolean interimCast); private static native long[] divide128(long viewA, long viewB, int quotientScale, boolean isIntegerDivide); diff --git a/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java index 4698855f31..2a5856275c 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java @@ -86,6 +86,22 @@ void simplePosMultiplyZeroByNegOne() { } } + @Test + void multiply128WithoutInterimCast() { + try (ColumnVector lhs = + makeDec128Column("-8533444864753048107770677711.1312637916"); + ColumnVector rhs = + makeDec128Column("-12.0000000000"); + ColumnVector expectedBasic = + makeDec128Column("102401338377036577293248132533.575165"); + ColumnVector expectedValid = + ColumnVector.fromBooleans(false); + Table found = DecimalUtils.mul128(lhs, rhs, -6)) { + assertColumnsAreEqual(expectedValid, found.getColumn(0)); + assertColumnsAreEqual(expectedBasic, found.getColumn(1)); + } + } + @Test void largePosMultiplyTenByTen() { try (ColumnVector lhs = From 041088d04f2bacc69431cf832c2c9b575fef5da5 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Wed, 6 Dec 2023 17:10:27 -0800 Subject: [PATCH 02/10] Signing off Signed-off-by: Raza Jafri From 81f82af5b68822c28fc285406022ee5bdae3be12 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Tue, 12 Dec 2023 12:08:53 -0800 Subject: [PATCH 03/10] addressed review comments --- src/main/cpp/src/decimal_utils.cu | 10 +++---- .../nvidia/spark/rapids/jni/DecimalUtils.java | 28 ++++++++++++------- .../spark/rapids/jni/DecimalUtilsTest.java | 2 +- 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/main/cpp/src/decimal_utils.cu b/src/main/cpp/src/decimal_utils.cu index 563c378bff..3f3fe23304 100644 --- a/src/main/cpp/src/decimal_utils.cu +++ 
b/src/main/cpp/src/decimal_utils.cu @@ -681,11 +681,11 @@ struct dec128_multiplier { int mult_scale = a_scale + b_scale; - // Spark does some really odd things that I personally think are a bug - // https://issues.apache.org/jira/browse/SPARK-40129 - // But to match Spark we need to first round the result to a precision of 38 - // and this is specific to the value in the result of the multiply. - // Then we need to round the result to the final scale that we care about. + // According to https://issues.apache.org/jira/browse/SPARK-40129 + // and https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in versions 3.2.4, 3.3.3, 3.4.1, 3.5.0 + // and 4.0.0 + // The bug is fixed for later versions but to match the legacy behavior we need to first round the result to a + // precision of 38 then we need to round the result to the final scale that we care about. if (cast_interim_result) { int first_div_precision = dec_precision - 38; if (first_div_precision > 0) { diff --git a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java index adc559c096..ae9d9b16c4 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java @@ -29,15 +29,21 @@ public class DecimalUtils { * Multiply two DECIMAL128 columns together into a DECIMAL128 product rounded to the specified * scale with overflow detection. This method considers a precision greater than 38 as overflow * even if the number still fits in a 128-bit representation. - * @param a factor input, must match row count of the other factor input - * @param b factor input, must match row count of the other factor input + * + * WARNING: This method has a bug which we match with Spark versions before 3.4.2, + * 4.0.0, 3.5.1. Consider the following example using Decimal with a precision of 38 and scale of 10: + * -8533444864753048107770677711.1312637916 * -12.0000000000 = 102401338377036577293248132533.575166 + * while the actual answer based on Java BigDecimal is 102401338377036577293248132533.575165 + * + * @param a factor input, must match row count of the other factor input + * @param b factor input, must match row count of the other factor input * @param productScale scale to use for the product type * @return table containing a boolean column and a DECIMAL128 product column of the specified - * scale. The boolean value will be true if an overflow was detected for that row's - * DECIMAL128 product value. A null input row will result in a corresponding null output - * row. + * scale. The boolean value will be true if an overflow was detected for that row's + * DECIMAL128 product value. A null input row will result in a corresponding null output + * row. */ - public static Table mul128(ColumnView a, ColumnView b, int productScale) { + public static Table multiply128(ColumnView a, ColumnView b, int productScale) { return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, false)); } @@ -46,21 +52,23 @@ public static Table mul128(ColumnView a, ColumnView b, int productScale) { * scale with overflow detection. This method considers a precision greater than 38 as overflow * even if the number still fits in a 128-bit representation. * - * WARNING: This method has a bug which we match with Spark versions before 3.4.2, 4.0.0, 3.5.1. 
Consider the
-   * following example using Decimal with a precision of 38 and scale of 10:
+   * WARNING: With interimCast set to true, this method has a bug which we match with Spark versions before 3.4.2,
+   * 4.0.0, 3.5.1. Consider the following example using Decimal with a precision of 38 and scale of 10:
    * -8533444864753048107770677711.1312637916 * -12.0000000000 = 102401338377036577293248132533.575166
    * while the actual answer based on Java BigDecimal is 102401338377036577293248132533.575165
    *
    * @param a factor input, must match row count of the other factor input
    * @param b factor input, must match row count of the other factor input
    * @param productScale scale to use for the product type
+   * @param interimCast whether to cast the result of the multiplication to 38 precision before casting it again to the final
+   * precision
    * @return table containing a boolean column and a DECIMAL128 product column of the specified
    *         scale. The boolean value will be true if an overflow was detected for that row's
    *         DECIMAL128 product value. A null input row will result in a corresponding null output
    *         row.
    */
-  public static Table multiply128(ColumnView a, ColumnView b, int productScale) {
-    return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, true));
+  public static Table multiply128(ColumnView a, ColumnView b, int productScale, boolean interimCast) {
+    return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, interimCast));
   }
 
   /**
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java
index 2a5856275c..e3ea7dca66 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java
@@ -96,7 +96,7 @@ void multiply128WithoutInterimCast() {
          makeDec128Column("102401338377036577293248132533.575165");
      ColumnVector expectedValid =
          ColumnVector.fromBooleans(false);
-    Table found = DecimalUtils.mul128(lhs, rhs, -6)) {
+    Table found = DecimalUtils.multiply128(lhs, rhs, -6, false)) {
       assertColumnsAreEqual(expectedValid, found.getColumn(0));
       assertColumnsAreEqual(expectedBasic, found.getColumn(1));
     }

From 96a929787b9bc7661d3d9924733f8bcce0162424 Mon Sep 17 00:00:00 2001
From: Raza Jafri
Date: Tue, 12 Dec 2023 13:13:02 -0800
Subject: [PATCH 04/10] fixed clang

---
 src/main/cpp/src/decimal_utils.cu | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/main/cpp/src/decimal_utils.cu b/src/main/cpp/src/decimal_utils.cu
index 3f3fe23304..84e6baa7ac 100644
--- a/src/main/cpp/src/decimal_utils.cu
+++ b/src/main/cpp/src/decimal_utils.cu
@@ -682,10 +682,10 @@ struct dec128_multiplier {
     int mult_scale = a_scale + b_scale;
 
     // According to https://issues.apache.org/jira/browse/SPARK-40129
-    // and https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in versions 3.2.4, 3.3.3, 3.4.1, 3.5.0
-    // and 4.0.0
-    // The bug is fixed for later versions but to match the legacy behavior we need to first round the result to a
-    // precision of 38 then we need to round the result to the final scale that we care about.
+    // and https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in
+    // versions 3.2.4, 3.3.3, 3.4.1, 3.5.0 and 4.0.0 The bug is fixed for later versions but to
+    // match the legacy behavior we need to first round the result to a precision of 38 then we need
+    // to round the result to the final scale that we care about.
if (cast_interim_result) { int first_div_precision = dec_precision - 38; if (first_div_precision > 0) { From d6f0cfe3309d53a6e4fe51c7494b70f294c0c61b Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Wed, 13 Dec 2023 11:20:18 -0800 Subject: [PATCH 05/10] addressed review comments --- src/main/cpp/src/decimal_utils.cu | 37 ++++++++++--------- .../nvidia/spark/rapids/jni/DecimalUtils.java | 2 +- 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/main/cpp/src/decimal_utils.cu b/src/main/cpp/src/decimal_utils.cu index 84e6baa7ac..a9d9fbf890 100644 --- a/src/main/cpp/src/decimal_utils.cu +++ b/src/main/cpp/src/decimal_utils.cu @@ -658,7 +658,7 @@ struct dec128_multiplier { cudf::mutable_column_view const& product_view, cudf::column_view const& a_col, cudf::column_view const& b_col, - bool const& cast_interim_result) + bool const cast_interim_result) : overflows(overflows), a_data(a_col.data<__int128_t>()), b_data(b_col.data<__int128_t>()), @@ -679,23 +679,24 @@ struct dec128_multiplier { int dec_precision = precision10(product); - int mult_scale = a_scale + b_scale; - - // According to https://issues.apache.org/jira/browse/SPARK-40129 - // and https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in - // versions 3.2.4, 3.3.3, 3.4.1, 3.5.0 and 4.0.0 The bug is fixed for later versions but to - // match the legacy behavior we need to first round the result to a precision of 38 then we need - // to round the result to the final scale that we care about. - if (cast_interim_result) { - int first_div_precision = dec_precision - 38; - if (first_div_precision > 0) { - auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits(); - product = divide_and_round(product, first_div_scale_divisor); - - // a_scale and b_scale are negative. first_div_precision is not - mult_scale = a_scale + b_scale + first_div_precision; - } - } + int const mult_scale = [&]() { + // According to https://issues.apache.org/jira/browse/SPARK-40129 + // and https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in + // versions 3.2.4, 3.3.3, 3.4.1, 3.5.0 and 4.0.0 The bug is fixed for later versions but to + // match the legacy behavior we need to first round the result to a precision of 38 then we need + // to round the result to the final scale that we care about. + if (cast_interim_result) { + int first_div_precision = dec_precision - 38; + if (first_div_precision > 0) { + auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits(); + product = divide_and_round(product, first_div_scale_divisor); + + // a_scale and b_scale are negative. first_div_precision is not + return a_scale + b_scale + first_div_precision; + } + } + return a_scale + b_scale; + }(); int exponent = prod_scale - mult_scale; if (exponent < 0) { diff --git a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java index ae9d9b16c4..17337691c5 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java @@ -44,7 +44,7 @@ public class DecimalUtils { * row. 
*/ public static Table multiply128(ColumnView a, ColumnView b, int productScale) { - return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, false)); + return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, true)); } /** From 88e9b7c2af65f1db5bc200a981f3301fde7d6669 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Wed, 13 Dec 2023 11:23:06 -0800 Subject: [PATCH 06/10] ran pre-commit --- src/main/cpp/src/decimal_utils.cu | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/main/cpp/src/decimal_utils.cu b/src/main/cpp/src/decimal_utils.cu index a9d9fbf890..683c300207 100644 --- a/src/main/cpp/src/decimal_utils.cu +++ b/src/main/cpp/src/decimal_utils.cu @@ -680,22 +680,22 @@ struct dec128_multiplier { int dec_precision = precision10(product); int const mult_scale = [&]() { - // According to https://issues.apache.org/jira/browse/SPARK-40129 - // and https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in - // versions 3.2.4, 3.3.3, 3.4.1, 3.5.0 and 4.0.0 The bug is fixed for later versions but to - // match the legacy behavior we need to first round the result to a precision of 38 then we need - // to round the result to the final scale that we care about. - if (cast_interim_result) { - int first_div_precision = dec_precision - 38; - if (first_div_precision > 0) { - auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits(); - product = divide_and_round(product, first_div_scale_divisor); - - // a_scale and b_scale are negative. first_div_precision is not - return a_scale + b_scale + first_div_precision; - } + // According to https://issues.apache.org/jira/browse/SPARK-40129 + // and https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in + // versions 3.2.4, 3.3.3, 3.4.1, 3.5.0 and 4.0.0 The bug is fixed for later versions but to + // match the legacy behavior we need to first round the result to a precision of 38 then we + // need to round the result to the final scale that we care about. + if (cast_interim_result) { + int first_div_precision = dec_precision - 38; + if (first_div_precision > 0) { + auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits(); + product = divide_and_round(product, first_div_scale_divisor); + + // a_scale and b_scale are negative. 
first_div_precision is not + return a_scale + b_scale + first_div_precision; } - return a_scale + b_scale; + } + return a_scale + b_scale; }(); int exponent = prod_scale - mult_scale; From b1cb173ee2fb59feb2ea769c826ec0b429d0d4c5 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Wed, 13 Dec 2023 16:17:24 -0800 Subject: [PATCH 07/10] removed pass-by reference --- src/main/cpp/src/decimal_utils.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/cpp/src/decimal_utils.cu b/src/main/cpp/src/decimal_utils.cu index 683c300207..49b766493c 100644 --- a/src/main/cpp/src/decimal_utils.cu +++ b/src/main/cpp/src/decimal_utils.cu @@ -975,7 +975,7 @@ namespace cudf::jni { std::unique_ptr multiply_decimal128(cudf::column_view const& a, cudf::column_view const& b, int32_t product_scale, - bool const& cast_interim_result, + bool const cast_interim_result, rmm::cuda_stream_view stream) { CUDF_EXPECTS(a.type().id() == cudf::type_id::DECIMAL128, "not a DECIMAL128 column"); From 185a62bc249401f00012c524fad7b1921c4e345b Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Thu, 14 Dec 2023 09:21:15 -0800 Subject: [PATCH 08/10] possible reason for CI failure, as locally it still builds --- src/main/cpp/src/decimal_utils.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/cpp/src/decimal_utils.hpp b/src/main/cpp/src/decimal_utils.hpp index ed4ee30d95..9793e63445 100644 --- a/src/main/cpp/src/decimal_utils.hpp +++ b/src/main/cpp/src/decimal_utils.hpp @@ -30,7 +30,7 @@ std::unique_ptr multiply_decimal128( cudf::column_view const& a, cudf::column_view const& b, int32_t product_scale, - bool const& cast_interim_result, + bool const cast_interim_result, rmm::cuda_stream_view stream = cudf::get_default_stream()); std::unique_ptr divide_decimal128( From 67c99577064ccfe664abfc26487213489399e047 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Thu, 14 Dec 2023 22:38:35 -0800 Subject: [PATCH 09/10] addressed review comments --- src/main/cpp/src/decimal_utils.cu | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/main/cpp/src/decimal_utils.cu b/src/main/cpp/src/decimal_utils.cu index 49b766493c..92273ff545 100644 --- a/src/main/cpp/src/decimal_utils.cu +++ b/src/main/cpp/src/decimal_utils.cu @@ -677,8 +677,6 @@ struct dec128_multiplier { chunked256 product = multiply(a, b); - int dec_precision = precision10(product); - int const mult_scale = [&]() { // According to https://issues.apache.org/jira/browse/SPARK-40129 // and https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in @@ -686,7 +684,7 @@ struct dec128_multiplier { // match the legacy behavior we need to first round the result to a precision of 38 then we // need to round the result to the final scale that we care about. 
if (cast_interim_result) { - int first_div_precision = dec_precision - 38; + auto const first_div_precision = precision10(product) - 38; if (first_div_precision > 0) { auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits(); product = divide_and_round(product, first_div_scale_divisor); @@ -724,7 +722,7 @@ struct dec128_multiplier { private: // output column for overflow detected bool* const overflows; - bool cast_interim_result; + bool const cast_interim_result; // input data for multiply __int128_t const* const a_data; From cddc36d9da31c32fb0e6263874c894c9ffd0e034 Mon Sep 17 00:00:00 2001 From: Raza Jafri Date: Thu, 14 Dec 2023 22:40:30 -0800 Subject: [PATCH 10/10] formatted Java code --- .../nvidia/spark/rapids/jni/DecimalUtilsTest.java | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java index e3ea7dca66..7f3079e825 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java @@ -88,14 +88,10 @@ void simplePosMultiplyZeroByNegOne() { @Test void multiply128WithoutInterimCast() { - try (ColumnVector lhs = - makeDec128Column("-8533444864753048107770677711.1312637916"); - ColumnVector rhs = - makeDec128Column("-12.0000000000"); - ColumnVector expectedBasic = - makeDec128Column("102401338377036577293248132533.575165"); - ColumnVector expectedValid = - ColumnVector.fromBooleans(false); + try (ColumnVector lhs = makeDec128Column("-8533444864753048107770677711.1312637916"); + ColumnVector rhs = makeDec128Column("-12.0000000000"); + ColumnVector expectedBasic = makeDec128Column("102401338377036577293248132533.575165"); + ColumnVector expectedValid = ColumnVector.fromBooleans(false); Table found = DecimalUtils.multiply128(lhs, rhs, -6, false)) { assertColumnsAreEqual(expectedValid, found.getColumn(0)); assertColumnsAreEqual(expectedBasic, found.getColumn(1));
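
Note on the rounding difference this series keeps referring to: the interim-cast behavior can be reproduced with plain java.math.BigDecimal. The sketch below is illustrative only and is not part of the patches; it assumes HALF_UP rounding (Spark's default for decimal results) and the target scale of 6 used in multiply128WithoutInterimCast.

import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;

public class InterimCastDemo {
  public static void main(String[] args) {
    BigDecimal a = new BigDecimal("-8533444864753048107770677711.1312637916");
    BigDecimal b = new BigDecimal("-12.0000000000");

    // Exact product: scale 20, more than 38 digits, i.e. wider than DECIMAL128 allows.
    BigDecimal exact = a.multiply(b);

    // Legacy Spark path (interimCast == true): squeeze to a precision of 38 first,
    // then round again to the final scale. The double rounding moves the last digit.
    BigDecimal interim = exact.round(new MathContext(38, RoundingMode.HALF_UP));
    System.out.println(interim.setScale(6, RoundingMode.HALF_UP));
    // 102401338377036577293248132533.575166

    // Fixed path (interimCast == false): a single rounding of the exact product.
    System.out.println(exact.setScale(6, RoundingMode.HALF_UP));
    // 102401338377036577293248132533.575165
  }
}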
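
Since the final API leaves the choice to the caller, a plugin targeting several Spark releases could gate the flag on whether the runtime still carries SPARK-45786. The wiring below is a hypothetical sketch: hasSpark45786Fix, sparkVersion, lhs, and rhs are placeholders, not APIs from this repo.

// Hypothetical version gate: keep the legacy double rounding only where Spark still has it.
boolean interimCast = !hasSpark45786Fix(sparkVersion);
try (Table result = DecimalUtils.multiply128(lhs, rhs, /* productScale */ -6, interimCast)) {
  ColumnVector overflowed = result.getColumn(0); // true where the product needed > 38 digits
  ColumnVector product = result.getColumn(1);    // DECIMAL128 product at scale -6
  // ... consume the columns ...
}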