Fix PartialOrd for ScalarValue::List/FixSizeList/LargeList #8253

Merged (3 commits) on Dec 7, 2023
110 changes: 35 additions & 75 deletions datafusion/common/src/scalar.rs
@@ -358,69 +358,47 @@ impl PartialOrd for ScalarValue {
(FixedSizeBinary(_, _), _) => None,
(LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2),
(LargeBinary(_), _) => None,
(List(arr1), List(arr2)) | (FixedSizeList(arr1), FixedSizeList(arr2)) => {
if arr1.data_type() == arr2.data_type() {
let list_arr1 = as_list_array(arr1);
let list_arr2 = as_list_array(arr2);
if list_arr1.len() != list_arr2.len() {
return None;
}
for i in 0..list_arr1.len() {
let arr1 = list_arr1.value(i);
let arr2 = list_arr2.value(i);

let lt_res =
arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
let eq_res =
arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;

for j in 0..lt_res.len() {
if lt_res.is_valid(j) && lt_res.value(j) {
return Some(Ordering::Less);
}
if eq_res.is_valid(j) && !eq_res.value(j) {
return Some(Ordering::Greater);
}
}
(List(arr1), List(arr2))
| (FixedSizeList(arr1), FixedSizeList(arr2))
| (LargeList(arr1), LargeList(arr2)) => {
// ScalarValue::List / ScalarValue::FixedSizeList / ScalarValue::LargeList are guaranteed to have length 1
assert_eq!(arr1.len(), 1);
assert_eq!(arr2.len(), 1);

if arr1.data_type() != arr2.data_type() {
return None;
}

fn first_array_for_list(arr: &ArrayRef) -> ArrayRef {
if let Some(arr) = arr.as_list_opt::<i32>() {
arr.value(0)
} else if let Some(arr) = arr.as_list_opt::<i64>() {
arr.value(0)
} else if let Some(arr) = arr.as_fixed_size_list_opt() {
arr.value(0)
} else {
unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen")
Member:
I prefer internal error here

Suggested change:
- unreachable!("Since only List / LargeList / FixedSizeList are supported, this should never happen")
+ internal_err!("Since only List / LargeList / FixedSizeList are supported, this should never happen")

Contributor Author:
But I think unreachable! is the better choice in this case. Otherwise, when should we use unreachable! at all? 😕

Member (@Weijun-H, Nov 27, 2023):
I rechecked the definition of the internal error. Its message reads: 'This was likely caused by a bug in DataFusion's code and we would welcome that you file a bug report in our issue tracker'.

So it is meant for bugs that have not been observed yet. Since this is the fallback branch of an if-else chain, an internal error seems more appropriate here 🤔.

Contributor Author:
I agree we should use internal_err for bugs that have not been observed yet. If we can't be sure which value we will get, internal_err is appropriate. In this case, though, the type check has already happened, so I think it is OK to just panic if we ever get to this point. The code should never reach that line unless the Rust compiler or arrow::DataType is broken.

Contributor:
I agree panicking is fine in this case, as the types are checked in the match arms.
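
The trade-off discussed in this thread can be sketched with standard-library types only. Everything here is a hypothetical stand-in: `Column`, `InternalError`, and both helper functions are illustrative names, not DataFusion or arrow APIs. The first helper assumes the caller has already restricted the input to list variants, so the fallback arm signals a broken invariant and panics. The second assumes nothing about its input and reports the problem as an error value instead.

```rust
/// Hypothetical stand-in for a dynamically typed array (not arrow's `ArrayRef`).
#[derive(Debug, Clone, PartialEq)]
enum Column {
    List(Vec<i64>),
    LargeList(Vec<i64>),
    Int(i64),
}

/// Hypothetical error type, loosely modeled on the idea of an internal error.
#[derive(Debug, PartialEq)]
struct InternalError(String);

/// The caller has already matched on the list variants, so reaching the
/// fallback arm means an invariant was broken: panic with `unreachable!`.
fn values_checked_by_caller(col: &Column) -> &[i64] {
    match col {
        Column::List(v) | Column::LargeList(v) => v,
        Column::Int(_) => unreachable!("caller only passes list variants"),
    }
}

/// The input has not been validated, so an unexpected variant is a
/// recoverable condition: return an error instead of panicking.
fn values_unvalidated(col: &Column) -> Result<&[i64], InternalError> {
    match col {
        Column::List(v) | Column::LargeList(v) => Ok(v),
        other => Err(InternalError(format!("expected a list, got {other:?}"))),
    }
}
```

Calling `values_unvalidated(&Column::Int(7))` yields an `Err`, while passing the same value to `values_checked_by_caller` would panic. That difference is only acceptable when the surrounding code rules the case out, as the match arms do in this PR.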

}
Some(Ordering::Equal)
} else {
None
}
}
(LargeList(arr1), LargeList(arr2)) => {
if arr1.data_type() == arr2.data_type() {
let list_arr1 = as_large_list_array(arr1);
let list_arr2 = as_large_list_array(arr2);
if list_arr1.len() != list_arr2.len() {
return None;

let arr1 = first_array_for_list(arr1);
Contributor:
This code would probably be faster and simpler if it used the single lt_eq kernel: https://docs.rs/arrow/latest/arrow/compute/kernels/cmp/fn.lt_eq.html

However, I see this just follows the existing logic.

Contributor Author:
Even with lt_eq, I think we would still need to tell lt apart from eq, which takes a second kernel call (either eq or lt).
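
To illustrate the point in this reply: a `<=` mask by itself looks the same for a pair that is less and a pair that is equal, so a second comparison is needed to separate the two outcomes. Below is a minimal std-only sketch. `compare_with` and `lexicographic` are hypothetical helpers that model nullable element-wise results; they are not the arrow kernels, and both inputs are assumed to have the same length.

```rust
use std::cmp::Ordering;

/// Element-wise comparison that yields `None` when either side is null,
/// modeling the validity bitmap of a comparison result.
fn compare_with<F>(a: &[Option<i64>], b: &[Option<i64>], op: F) -> Vec<Option<bool>>
where
    F: Fn(i64, i64) -> bool,
{
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| match (x, y) {
            (Some(x), Some(y)) => Some(op(*x, *y)),
            _ => None,
        })
        .collect()
}

/// Mirrors the loop in the diff: the first valid "less" wins, the first
/// valid "not equal" means greater, and anything else is treated as equal.
fn lexicographic(a: &[Option<i64>], b: &[Option<i64>]) -> Option<Ordering> {
    let lt = compare_with(a, b, |x, y| x < y);
    let eq = compare_with(a, b, |x, y| x == y);
    for (l, e) in lt.iter().zip(eq.iter()) {
        if *l == Some(true) {
            return Some(Ordering::Less);
        }
        if *e == Some(false) {
            return Some(Ordering::Greater);
        }
    }
    Some(Ordering::Equal)
}
```

With `a = [1, 2, 3]` and `b = [1, 2, 4]`, `compare_with(a, b, <=)` and `compare_with(a, a, <=)` both produce an all-true mask, yet the first pair orders as `Less` and the second as `Equal`.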

let arr2 = first_array_for_list(arr2);

let lt_res = arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
let eq_res = arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;

for j in 0..lt_res.len() {
if lt_res.is_valid(j) && lt_res.value(j) {
return Some(Ordering::Less);
}
for i in 0..list_arr1.len() {
let arr1 = list_arr1.value(i);
let arr2 = list_arr2.value(i);

let lt_res =
arrow::compute::kernels::cmp::lt(&arr1, &arr2).ok()?;
let eq_res =
arrow::compute::kernels::cmp::eq(&arr1, &arr2).ok()?;

for j in 0..lt_res.len() {
if lt_res.is_valid(j) && lt_res.value(j) {
return Some(Ordering::Less);
}
if eq_res.is_valid(j) && !eq_res.value(j) {
return Some(Ordering::Greater);
}
}
if eq_res.is_valid(j) && !eq_res.value(j) {
return Some(Ordering::Greater);
}
Some(Ordering::Equal)
} else {
None
}

Some(Ordering::Equal)
}
(List(_), _) => None,
(LargeList(_), _) => None,
(FixedSizeList(_), _) => None,
(List(_), _) | (LargeList(_), _) | (FixedSizeList(_), _) => None,
(Date32(v1), Date32(v2)) => v1.partial_cmp(v2),
(Date32(_), _) => None,
(Date64(v1), Date64(v2)) => v1.partial_cmp(v2),
@@ -3644,24 +3622,6 @@ mod tests {
])]),
));
assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));

let a =
ScalarValue::List(Arc::new(
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
Some(vec![Some(10), Some(2), Some(3)]),
None,
Some(vec![Some(10), Some(2), Some(3)]),
]),
));
let b =
ScalarValue::List(Arc::new(
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![
Some(vec![Some(10), Some(2), Some(3)]),
None,
Some(vec![Some(10), Some(2), Some(3)]),
]),
));
assert_eq!(a.partial_cmp(&b), Some(Ordering::Equal));
}

#[test]