
feat: Implement Spark-compatible CAST from string to integral types #307

Merged: 49 commits, May 1, 2024

Conversation

andygrove
Member

@andygrove andygrove commented Apr 23, 2024

Which issue does this PR close?

Part of #286
Closes #15

Rationale for this change

Improve compatibility with Apache Spark

What changes are included in this PR?

Add a custom implementation of CAST from string to integral types rather than delegating to DataFusion.
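For context, a rough sketch of the shape of the change (illustrative only; the enum and error type here are simplified stand-ins for the PR's actual EvalMode and CometError, and std's parse does not match Spark's full semantics):

#[derive(Clone, Copy, PartialEq)]
enum EvalMode {
    Legacy,
    Try,
    Ansi,
}

// Sketch only: the point is that the cast parses the string itself so that
// invalid input can map to null (LEGACY/TRY) or an error (ANSI), rather
// than delegating to a generic cast kernel.
fn cast_string_to_i32_sketch(s: &str, eval_mode: EvalMode) -> Result<Option<i32>, String> {
    match s.trim().parse::<i32>() {
        Ok(v) => Ok(Some(v)),
        Err(_) if eval_mode == EvalMode::Ansi => {
            Err(format!("cannot cast STRING '{s}' to INT")) // message simplified
        }
        Err(_) => Ok(None),
    }
}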

How are these changes tested?

@andygrove andygrove marked this pull request as draft April 23, 2024 15:17
@andygrove
Member Author

I am now working on refactoring to reduce code duplication by leveraging macros/generics.

@andygrove andygrove marked this pull request as ready for review April 23, 2024 19:23
(
DataType::Dictionary(key_type, value_type),
DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64,
) if key_type.as_ref() == &DataType::Int32
Member Author

@viirya do you know if dictionary keys will always be i32?

Contributor

I've been assuming it to be so, though @viirya can give us the definitive answer.

Member

I remember that in many places in the native code we assume dictionary keys are always Int32 type.

But I forget where we make that assumption. 😅

cc @sunchao Do you remember that?

Member

Oh, I see. I think the assumption comes from the native scan side, where the Parquet dictionary indices are always of integer type, so dictionary keys read from the native scan are always Int32.

Member

You can check the DictDecoder in the native scan implementation.

Member

The exception would be if some operator or expression during execution produced a dictionary with keys of a type other than Int32. But I think that should be considered a bug for us to fix, because I don't think it makes sense to change the dictionary key type.

Member

Oh, I see. I think the assumption comes from the native scan side, where the Parquet dictionary indices are always of integer type, so dictionary keys read from the native scan are always Int32.

Yes that is exactly right.

Member

Thanks @sunchao for confirming it.
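For illustration, the Int32-key assumption shows up in downcasts of this shape (a sketch against arrow-rs, not the exact code in this PR):

use arrow::array::{Array, DictionaryArray};
use arrow::datatypes::Int32Type;

// Sketch: relies on the invariant that dictionary keys produced by the
// native Parquet scan are always Int32; any other key width returns None.
fn as_int32_dictionary(array: &dyn Array) -> Option<&DictionaryArray<Int32Type>> {
    array.as_any().downcast_ref::<DictionaryArray<Int32Type>>()
}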

@@ -64,6 +68,25 @@ pub struct Cast {
pub timezone: String,
}

macro_rules! spark_cast_utf8_to_integral {
Contributor

maybe utf8_to_integer?

Spark is not involved in native exec, so I'm not sure why the spark prefix is needed.
Also, integral types include booleans, and this scope is limited to integers AFAIK.

macro_rules! spark_cast_utf8_to_integral {
($string_array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{
let mut cast_array = PrimitiveArray::<$array_type>::builder($string_array.len());
for i in 0..$string_array.len() {
Contributor

We could probably use an iterator instead of a for loop?

And let's calculate $string_array.len() once.

Member Author

I think $string_array.len() is already computed only once?

Contributor

I see them on lines 73, 74 🤔

Member Author

I missed that! Thanks

Member Author

Fixed

Ok(spark_cast(cast_result, from_type, to_type))
}

fn spark_cast_string_to_integral(
Contributor

string_to_int?

@comphead
Contributor

Thanks @andygrove. BTW, I'm wondering if this PR should also cover number pattern formatting: https://spark.apache.org/docs/latest/sql-ref-number-pattern.html#the-to_number-function

andygrove and others added 3 commits April 23, 2024 13:47
Co-authored-by: comphead <comphead@users.noreply.github.com>
…datafusion-comet into cast-string-to-integral
}

ignore("cast string to short") {
test("cast string to short") {
Contributor

We should probably have some negative tests with invalid strings.
Also, curious, what does cast(".") yield?

Member Author

The fuzz testing does generate many invalid inputs. I can add some more explicit ones to these tests, though.

cast(".") will yield different results depending on the eval mode:

  • LEGACY -> 0
  • TRY -> null
  • ANSI -> error
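To make the LEGACY case concrete, a toy sketch of the dot-truncation idea (my simplification, not the PR's parser, which also handles signs, whitespace, and overflow):

// Toy sketch: LEGACY parsing truncates at the first '.', so "." has an
// empty integer part and yields 0; TRY maps failures to null and ANSI
// errors instead (see the none_or_err helper later in this review).
fn legacy_parse_i32(s: &str) -> Option<i32> {
    let s = s.trim();
    if s.is_empty() {
        return None;
    }
    let int_part = match s.find('.') {
        Some(idx) => &s[..idx], // discard the fractional part
        None => s,
    };
    if int_part.is_empty() {
        Some(0) // "." and ".5" truncate to 0 (assumed legacy behavior)
    } else {
        int_part.parse::<i32>().ok()
    }
}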

Contributor

+1 for these tests, btw

@andygrove
Member Author

Thanks @andygrove. BTW, I'm wondering if this PR should also cover number pattern formatting: https://spark.apache.org/docs/latest/sql-ref-number-pattern.html#the-to_number-function

Sorry, I'm not sure I understand. Are you referring to the error message formatting?

@comphead
Contributor

Thanks @andygrove. BTW, I'm wondering if this PR should also cover number pattern formatting: https://spark.apache.org/docs/latest/sql-ref-number-pattern.html#the-to_number-function

Sorry, I'm not sure I understand. Are you referring to the error message formatting?

Oh, it covers just casting string to integers. I thought to_number() would also be covered, since it has casting behind the scenes.

Comment on lines 443 to 444
let negative = chars[0] == '-';
if negative || chars[0] == '+' {
Contributor

This seems wrong.
It should be chars[i] == '-' instead? Otherwise, this cast doesn't work for -124

Member Author

Thank you! The code was originally trimming the string before this point and I missed updating this when I removed the trim. I have now fixed this.
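A small test sketch of the fixed behavior (hypothetical test, mirroring the parsing loop's index handling):

#[test]
fn sign_checked_at_parse_position() {
    // Sketch of the fix: look for the sign at the current parse index
    // (after skipping leading whitespace), not at chars[0].
    let chars: Vec<char> = "  -124".chars().collect();
    let mut i = 0;
    while i < chars.len() && chars[i].is_whitespace() {
        i += 1;
    }
    let negative = i < chars.len() && chars[i] == '-';
    assert!(negative);
}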

use super::{cast_string_to_i8, EvalMode};

#[test]
fn test_cast_string_as_i8() {
Contributor

How about adding more tests for i32 and i64 with their min/max and zero inputs?

Member Author

I am going to focus on improving the tests in this PR today

Member Author

I have now added tests for all min/max boundary values in the Scala tests
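For example, a sketch of the kind of boundary coverage added (illustrative; the real assertions live in the Scala tests and this PR's Rust unit tests):

#[test]
fn cast_string_to_i32_boundaries() {
    // In TRY mode, out-of-range input maps to null rather than an error.
    for (input, expected) in [
        ("-2147483648", Some(i32::MIN)),
        ("2147483647", Some(i32::MAX)),
        ("0", Some(0)),
        ("2147483648", None), // overflows i32
    ] {
        assert_eq!(cast_string_to_i32(input, EvalMode::Try).unwrap(), expected);
    }
}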

@@ -103,10 +125,72 @@ impl Cast {
(DataType::LargeUtf8, DataType::Boolean) => {
Self::spark_cast_utf8_to_boolean::<i64>(&array, self.eval_mode)?
Contributor

Not part of this PR, but if we are going to name the added method cast_string_to_int, should this method be renamed to cast_utf8_to_boolean as well in a follow-up PR?

Member Author

I agree. I didn't want to start making unrelated changes in this PR, but we should rename this.

// Note that we are unpacking a dictionary-encoded array and then performing
// the cast. We could potentially improve performance here by casting the
// dictionary values directly without unpacking the array first, although this
// would add more complexity to the code
Contributor

I think we can leave a TODO to cast the dictionary directly?
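For reference, a sketch of the current unpack-then-cast step and where that TODO would apply (arrow-rs usage assumed; not the exact PR code):

use arrow::array::{ArrayRef, DictionaryArray};
use arrow::compute::take;
use arrow::datatypes::Int32Type;
use arrow::error::ArrowError;

// Sketch: unpack by gathering values through the keys, then cast the flat
// array. The TODO optimization would instead cast dict.values() once and
// rewrap the result with the original keys.
fn unpack_dictionary(dict: &DictionaryArray<Int32Type>) -> Result<ArrayRef, ArrowError> {
    take(dict.values(), dict.keys(), None)
}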


ignore("cast string to long") {
castTest(generateStrings(numericPattern, 8).toDF("a"), DataTypes.LongType)
private val castStringToIntegralInputs: Seq[String] = Seq(
Contributor

nit: Since the cast code handles leading and trailing whitespace, I think we can add more inputs with whitespace.

For example:

castStringToIntegralInputs.flatMap { x => Seq("  " + x, x + "   ", "   " + x + "  ") }

Contributor

@advancedxy advancedxy left a comment

Thanks for your effort @andygrove, the new code is well crafted.

@andygrove
Member Author

@viirya @sunchao @parthchandra @comphead I did quite a bit of refactoring and performance tuning over the weekend. Please take another look when you can.

@andygrove
Member Author

Thanks for your effort @andygrove, the new code is well crafted.

Thank you for the thorough review @advancedxy!

let len = $array.len();
let mut cast_array = PrimitiveArray::<$array_type>::builder(len);
for i in 0..len {
if $array.is_null(i) {
Contributor

maybe it can be simplified to

if let Some(cast_value) = $cast_method($array.value(i).trim(), $eval_mode)? {
    cast_array.append_value(cast_value);
} else {
    cast_array.append_null()
}

Member Author

If there is a null input then we will always want a null output and we don't want to add the overhead of calling the cast logic in this case.
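A self-contained sketch of the fast path being described (std's parse stands in for the Spark-compatible parser here):

use arrow::array::{Array, Int32Array, StringArray};

// Sketch: null inputs short-circuit straight to a null output slot; the
// parse function is only invoked for non-null values.
fn cast_utf8_to_i32_sketch(array: &StringArray) -> Int32Array {
    let mut builder = Int32Array::builder(array.len());
    for i in 0..array.len() {
        if array.is_null(i) {
            builder.append_null(); // fast path: no parsing for nulls
        } else if let Ok(v) = array.value(i).trim().parse::<i32>() {
            builder.append_value(v);
        } else {
            builder.append_null();
        }
    }
    builder.finish()
}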

Contributor

@comphead comphead left a comment

LGTM, thanks @andygrove. A couple of minor comments.

Comment on lines 404 to 418
/// Either return Ok(None) or Err(CometError::CastInvalidValue) depending on the evaluation mode
fn none_or_err<T>(eval_mode: EvalMode, type_name: &str, str: &str) -> CometResult<Option<T>> {
match eval_mode {
EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)),
_ => Ok(None),
}
}

fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
CometError::CastInvalidValue {
value: value.to_string(),
from_type: from_type.to_string(),
to_type: to_type.to_string(),
}
}
Member

I think these can be inline functions?

Member Author

Thanks. I have updated this.
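For reference, the resulting shape (a sketch: the invalid_value helper from the diff above with the suggested attribute added):

// Sketch of the suggested change: mark the small helper #[inline] so it can
// be inlined into the hot per-row parsing loop; none_or_err gets the same.
#[inline]
fn invalid_value(value: &str, from_type: &str, to_type: &str) -> CometError {
    CometError::CastInvalidValue {
        value: value.to_string(),
        from_type: from_type.to_string(),
        to_type: to_type.to_string(),
    }
}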

Comment on lines 255 to 261
fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> CometResult<Option<i32>> {
do_cast_string_to_int::<i32>(str, eval_mode, "INT", i32::MIN)
}

fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> CometResult<Option<i64>> {
do_cast_string_to_int::<i64>(str, eval_mode, "BIGINT", i64::MIN)
}
Member

Why do only i8 and i16 have a range check?

Member Author

The code is ported directly from Spark. This is the approach that is used there.

Member Author

Spark has IntWrapper and LongWrapper which are equivalent to do_cast_string_to_int::<i32> and do_cast_string_to_int::<i64> in this PR.

This is the logic for casting to byte in Spark. It uses IntWrapper then casts to byte.

  public boolean toByte(IntWrapper intWrapper) {
    if (toInt(intWrapper)) {
      int intValue = intWrapper.value;
      byte result = (byte) intValue;
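In Rust terms, the same parse-wide-then-narrow shape might look like this (a sketch reusing this PR's do_cast_string_to_int and none_or_err; the PR's actual helper may differ):

// Sketch: parse into i32 first (Spark's IntWrapper equivalent), then
// range-check when narrowing to i8; out-of-range follows the eval mode.
fn cast_string_to_i8_sketch(str: &str, eval_mode: EvalMode) -> CometResult<Option<i8>> {
    match do_cast_string_to_int::<i32>(str, eval_mode, "TINYINT", i32::MIN)? {
        Some(v) if i8::try_from(v).is_ok() => Ok(Some(v as i8)),
        Some(_) => none_or_err(eval_mode, "TINYINT", str),
        None => Ok(None),
    }
}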

Member Author

I pushed a commit to add some comments referencing the Spark code that this code is based on

Member Author

@viirya Let me know if there is anything else to address. I have upmerged with the latest from the main branch, so this PR is a little smaller now.

Member

I will take another look at this tonight.

Comment on lines 356 to 358
// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE /
// radix), we can just use `result > 0` to check overflow. If result
// overflows, we should stop
Member

Do you mean "more than or equal to"? I think the above condition (L352) is already for result < stop_value?

Member Author

The comment was copied from the Spark code in org/apache/spark/unsafe/types/UTF8String.java, but I agree that it seems incorrect. I have updated it.

@viirya
Member

viirya commented May 1, 2024

Looks good to me. Thanks @andygrove

@andygrove andygrove merged commit cbf4730 into apache:main May 1, 2024
28 checks passed
Contributor

@parthchandra parthchandra left a comment

LGTM

return none_or_err(eval_mode, type_name, str);
};

// We are going to process the new digit and accumulate the result. However, before
Contributor

A comment to explain why we're using subtraction instead of addition would make it easier to understand this part of the code.
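For readers following along, the reason in a nutshell: i32::MIN has a greater magnitude than i32::MAX, so accumulating in the negative range can represent every valid value (including "-2147483648") before the final sign flip. A standalone sketch, using checked arithmetic rather than the stop-value comparison used in this PR:

// Sketch: accumulate negatively so the minimum value never overflows
// before the final negation for positive inputs.
fn parse_i32_negative_accumulate(s: &str) -> Option<i32> {
    let (negative, digits) = match s.strip_prefix('-') {
        Some(rest) => (true, rest),
        None => (false, s.strip_prefix('+').unwrap_or(s)),
    };
    if digits.is_empty() {
        return None;
    }
    let mut result: i32 = 0;
    for b in digits.bytes() {
        let d = (b as char).to_digit(10)? as i32;
        result = result.checked_mul(10)?.checked_sub(d)?; // accumulate negatively
    }
    if negative {
        Some(result) // i32::MIN is representable directly
    } else {
        result.checked_neg() // fails only at 2147483648 and beyond
    }
}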
