Handle Decimal-128 Multiplication For Newer Spark Versions #1623

Merged Dec 15, 2023 · 10 commits (diff shown from 8 commits)

11 changes: 8 additions & 3 deletions src/main/cpp/src/DecimalUtilsJni.cpp
@@ -19,8 +19,13 @@

extern "C" {

-JNIEXPORT jlongArray JNICALL Java_com_nvidia_spark_rapids_jni_DecimalUtils_multiply128(
-  JNIEnv* env, jclass, jlong j_view_a, jlong j_view_b, jint j_product_scale)
+JNIEXPORT jlongArray JNICALL
+Java_com_nvidia_spark_rapids_jni_DecimalUtils_multiply128(JNIEnv* env,
+                                                          jclass,
+                                                          jlong j_view_a,
+                                                          jlong j_view_b,
+                                                          jint j_product_scale,
+                                                          bool cast_interim_result)
{
JNI_NULL_CHECK(env, j_view_a, "column is null", 0);
JNI_NULL_CHECK(env, j_view_b, "column is null", 0);
@@ -30,7 +35,7 @@ JNIEXPORT jlongArray JNICALL Java_com_nvidia_spark_rapids_jni_DecimalUtils_multi
auto view_b = reinterpret_cast<cudf::column_view const*>(j_view_b);
auto scale = static_cast<int>(j_product_scale);
return cudf::jni::convert_table_for_return(
-    env, cudf::jni::multiply_decimal128(*view_a, *view_b, scale));
+    env, cudf::jni::multiply_decimal128(*view_a, *view_b, scale, cast_interim_result));
}
CATCH_STD(env, 0);
}
53 changes: 31 additions & 22 deletions src/main/cpp/src/decimal_utils.cu
@@ -657,14 +657,16 @@ struct dec128_multiplier {
dec128_multiplier(bool* overflows,
cudf::mutable_column_view const& product_view,
cudf::column_view const& a_col,
-                    cudf::column_view const& b_col)
+                    cudf::column_view const& b_col,
+                    bool const cast_interim_result)
: overflows(overflows),
a_data(a_col.data<__int128_t>()),
b_data(b_col.data<__int128_t>()),
product_data(product_view.data<__int128_t>()),
a_scale(a_col.type().scale()),
b_scale(b_col.type().scale()),
-      prod_scale(product_view.type().scale())
+      prod_scale(product_view.type().scale()),
+      cast_interim_result(cast_interim_result)
{
}

@@ -675,22 +677,26 @@

chunked256 product = multiply(a, b);

-    // Spark does some really odd things that I personally think are a bug
-    // https://issues.apache.org/jira/browse/SPARK-40129
-    // But to match Spark we need to first round the result to a precision of 38
-    // and this is specific to the value in the result of the multiply.
-    // Then we need to round the result to the final scale that we care about.
-    int dec_precision = precision10(product);
-    int first_div_precision = dec_precision - 38;
-
-    int mult_scale = a_scale + b_scale;
-    if (first_div_precision > 0) {
-      auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits();
-      product = divide_and_round(product, first_div_scale_divisor);
-
-      // a_scale and b_scale are negative. first_div_precision is not
-      mult_scale = a_scale + b_scale + first_div_precision;
-    }
+    int dec_precision = precision10(product);

+    int const mult_scale = [&]() {
+      // According to https://issues.apache.org/jira/browse/SPARK-40129
+      // and https://issues.apache.org/jira/browse/SPARK-45786, Spark has a bug in
+      // versions 3.2.4, 3.3.3, 3.4.1, 3.5.0 and 4.0.0. The bug is fixed in later versions, but
+      // to match the legacy behavior we need to first round the result to a precision of 38 and
+      // then round it to the final scale that we care about.
+      if (cast_interim_result) {
+        int first_div_precision = dec_precision - 38;
+        if (first_div_precision > 0) {
+          auto const first_div_scale_divisor = pow_ten(first_div_precision).as_128_bits();
+          product = divide_and_round(product, first_div_scale_divisor);
+
+          // a_scale and b_scale are negative. first_div_precision is not
+          return a_scale + b_scale + first_div_precision;
+        }
+      }
+      return a_scale + b_scale;
+    }();

int exponent = prod_scale - mult_scale;
if (exponent < 0) {
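
To illustrate what the cast_interim_result flag controls, here is a minimal standalone sketch (not part of this diff) that reproduces both results with java.math.BigDecimal; HALF_UP is assumed as the rounding mode, matching divide_and_round's round-half-away-from-zero behavior:

import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;

public class InterimCastDemo {
  public static void main(String[] args) {
    BigDecimal a = new BigDecimal("-8533444864753048107770677711.1312637916");
    BigDecimal b = new BigDecimal("-12.0000000000");
    // The raw product has 50 digits of precision at scale 20.
    BigDecimal raw = a.multiply(b);

    // Legacy path (cast_interim_result == true): round the intermediate
    // product to 38 significant digits first, then to the final scale of 6.
    BigDecimal interim = raw.round(new MathContext(38, RoundingMode.HALF_UP));
    System.out.println(interim.setScale(6, RoundingMode.HALF_UP));
    // prints 102401338377036577293248132533.575166 (double rounding)

    // Fixed path (cast_interim_result == false): round once to scale 6.
    System.out.println(raw.setScale(6, RoundingMode.HALF_UP));
    // prints 102401338377036577293248132533.575165
  }
}

The interim step rounds the fraction .5751654992... up to .57516550, and the second rounding then sees the trailing 50 as a tie and rounds up again, producing the off-by-one difference that the new unit test below pins down.
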
@@ -718,6 +724,7 @@ struct dec128_multiplier {
private:
// output column for overflow detected
bool* const overflows;
+  bool cast_interim_result;

// input data for multiply
__int128_t const* const a_data;
@@ -968,6 +975,7 @@ namespace cudf::jni {
std::unique_ptr<cudf::table> multiply_decimal128(cudf::column_view const& a,
cudf::column_view const& b,
int32_t product_scale,
+                                                  bool const cast_interim_result,
rmm::cuda_stream_view stream)
{
CUDF_EXPECTS(a.type().id() == cudf::type_id::DECIMAL128, "not a DECIMAL128 column");
@@ -992,10 +1000,11 @@ std::unique_ptr<cudf::table> multiply_decimal128(cudf::column_view const& a,
auto overflows_view = columns[0]->mutable_view();
auto product_view = columns[1]->mutable_view();
check_scale_divisor(a.type().scale() + b.type().scale(), product_scale);
-  thrust::for_each(rmm::exec_policy(stream),
-                   thrust::make_counting_iterator<cudf::size_type>(0),
-                   thrust::make_counting_iterator<cudf::size_type>(num_rows),
-                   dec128_multiplier(overflows_view.begin<bool>(), product_view, a, b));
+  thrust::for_each(
+    rmm::exec_policy(stream),
+    thrust::make_counting_iterator<cudf::size_type>(0),
+    thrust::make_counting_iterator<cudf::size_type>(num_rows),
+    dec128_multiplier(overflows_view.begin<bool>(), product_view, a, b, cast_interim_result));
return std::make_unique<cudf::table>(std::move(columns));
}

1 change: 1 addition & 0 deletions src/main/cpp/src/decimal_utils.hpp
@@ -30,6 +30,7 @@ std::unique_ptr<cudf::table> multiply_decimal128(
cudf::column_view const& a,
cudf::column_view const& b,
int32_t product_scale,
+  bool const cast_interim_result,
rmm::cuda_stream_view stream = cudf::get_default_stream());

std::unique_ptr<cudf::table> divide_decimal128(
35 changes: 32 additions & 3 deletions src/main/java/com/nvidia/spark/rapids/jni/DecimalUtils.java
@@ -25,21 +25,50 @@ public class DecimalUtils {
NativeDepsLoader.loadNativeDeps();
}

+  /**
+   * Multiply two DECIMAL128 columns together into a DECIMAL128 product rounded to the specified
+   * scale with overflow detection. This method considers a precision greater than 38 as overflow
+   * even if the number still fits in a 128-bit representation.
+   *
+   * WARNING: This method intentionally reproduces a bug found in Spark versions before 3.4.2,
+   * 3.5.1, and 4.0.0. Consider the following example using Decimal with a precision of 38 and
+   * scale of 10:
+   * -8533444864753048107770677711.1312637916 * -12.0000000000 = 102401338377036577293248132533.575166
+   * while the actual answer based on Java BigDecimal is 102401338377036577293248132533.575165
+   *
+   * @param a factor input, must match row count of the other factor input
+   * @param b factor input, must match row count of the other factor input
+   * @param productScale scale to use for the product type
+   * @return table containing a boolean column and a DECIMAL128 product column of the specified
+   *         scale. The boolean value will be true if an overflow was detected for that row's
+   *         DECIMAL128 product value. A null input row will result in a corresponding null output
+   *         row.
+   */
+  public static Table multiply128(ColumnView a, ColumnView b, int productScale) {
+    return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, true));
+  }

+  /**
+   * Multiply two DECIMAL128 columns together into a DECIMAL128 product rounded to the specified
+   * scale with overflow detection. This method considers a precision greater than 38 as overflow
+   * even if the number still fits in a 128-bit representation.
+   *
+   * WARNING: With interimCast set to true, this method intentionally reproduces a bug found in
+   * Spark versions before 3.4.2, 3.5.1, and 4.0.0. Consider the following example using Decimal
+   * with a precision of 38 and scale of 10:
+   * -8533444864753048107770677711.1312637916 * -12.0000000000 = 102401338377036577293248132533.575166
+   * while the actual answer based on Java BigDecimal is 102401338377036577293248132533.575165
+   *
+   * @param a factor input, must match row count of the other factor input
+   * @param b factor input, must match row count of the other factor input
+   * @param productScale scale to use for the product type
+   * @param interimCast whether to round the intermediate product to a precision of 38 before
+   *        rounding again to the final scale
Comment on lines +55 to +64

ttnghia (Collaborator) commented on Dec 15, 2023:

Nit: Sorry, I think these lines are longer than usual, so we may need to rewrite them a little. As a convention, a line typically should not exceed 100 characters; the lines above run up to 120.

hyperbolic2346 (Collaborator, Author) replied:

Yeah, good point. I see the file has other instances where we are doing this. To keep this PR focused, I don't want to make the changes in other places, and since the CI has passed I would really appreciate it if we can do that as a follow-on.

+   * @return table containing a boolean column and a DECIMAL128 product column of the specified
+   *         scale. The boolean value will be true if an overflow was detected for that row's
+   *         DECIMAL128 product value. A null input row will result in a corresponding null output
+   *         row.
+   */
-  public static Table multiply128(ColumnView a, ColumnView b, int productScale) {
-    return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale));
+  public static Table multiply128(ColumnView a, ColumnView b, int productScale, boolean interimCast) {
+    return new Table(multiply128(a.getNativeView(), b.getNativeView(), productScale, interimCast));

ttnghia (Collaborator) commented:

If the interimCast is applied to Spark versions 3.2.4, 3.3.3, 3.4.1, 3.5.0 and 4.0.0, then please add such clarification to the docs of this function.

hyperbolic2346 (Collaborator, Author) replied:

I thought the docs were pretty clear; if there is anything else you want me to add, please share and I will add it.

}
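
For callers, here is a hypothetical usage sketch (not part of this diff); it assumes ColumnVector.decimalFromBigInt is the cudf Java factory for DECIMAL128 columns, and the unscaled BigIntegers encode the decimals from the WARNING example at cudf scale -10:

import java.math.BigInteger;
import ai.rapids.cudf.ColumnVector;
import ai.rapids.cudf.Table;
import com.nvidia.spark.rapids.jni.DecimalUtils;

public class Multiply128Demo {
  public static void main(String[] args) {
    // cudf scale -10 means ten fractional digits; these unscaled values are
    // -8533444864753048107770677711.1312637916 and -12.0000000000.
    try (ColumnVector a = ColumnVector.decimalFromBigInt(-10,
             new BigInteger("-85334448647530481077706777111312637916"));
         ColumnVector b = ColumnVector.decimalFromBigInt(-10,
             new BigInteger("-120000000000"));
         // Three-argument overload keeps the legacy interim cast.
         Table legacy = DecimalUtils.multiply128(a, b, -6);
         // interimCast=false matches Spark versions that carry the fix.
         Table fixed = DecimalUtils.multiply128(a, b, -6, false)) {
      // Column 0: BOOL8 overflow flags. Column 1: DECIMAL128 product, scale -6.
    }
  }
}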

/**
@@ -148,7 +177,7 @@ public static Table add128(ColumnView a, ColumnView b, int targetScale) {
return new Table(add128(a.getNativeView(), b.getNativeView(), targetScale));
}

-  private static native long[] multiply128(long viewA, long viewB, int productScale);
+  private static native long[] multiply128(long viewA, long viewB, int productScale, boolean interimCast);

private static native long[] divide128(long viewA, long viewB, int quotientScale, boolean isIntegerDivide);

16 changes: 16 additions & 0 deletions src/test/java/com/nvidia/spark/rapids/jni/DecimalUtilsTest.java
@@ -86,6 +86,22 @@ void simplePosMultiplyZeroByNegOne() {
}
}

+  @Test
+  void multiply128WithoutInterimCast() {
+    try (ColumnVector lhs =
+             makeDec128Column("-8533444864753048107770677711.1312637916");
+         ColumnVector rhs =
+             makeDec128Column("-12.0000000000");
+         ColumnVector expectedBasic =
+             makeDec128Column("102401338377036577293248132533.575165");
+         ColumnVector expectedValid =
+             ColumnVector.fromBooleans(false);
+         Table found = DecimalUtils.multiply128(lhs, rhs, -6, false)) {
+      assertColumnsAreEqual(expectedValid, found.getColumn(0));
+      assertColumnsAreEqual(expectedBasic, found.getColumn(1));
+    }
+  }

@Test
void largePosMultiplyTenByTen() {
try (ColumnVector lhs =
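
A natural companion case (not in this diff, sketched here for contrast) would exercise the legacy path with the same inputs, reusing this file's makeDec128Column helper and expecting the double-rounded value from the javadoc:

@Test
void multiply128WithInterimCast() {
  try (ColumnVector lhs = makeDec128Column("-8533444864753048107770677711.1312637916");
       ColumnVector rhs = makeDec128Column("-12.0000000000");
       // With the interim cast enabled, the legacy double-rounded result is expected.
       ColumnVector expectedBasic = makeDec128Column("102401338377036577293248132533.575166");
       ColumnVector expectedValid = ColumnVector.fromBooleans(false);
       Table found = DecimalUtils.multiply128(lhs, rhs, -6, true)) {
    assertColumnsAreEqual(expectedValid, found.getColumn(0));
    assertColumnsAreEqual(expectedBasic, found.getColumn(1));
  }
}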