57 changes: 51 additions & 6 deletions native/spark-expr/src/bitwise_funcs/bitwise_count.rs
@@ -16,7 +16,7 @@
 // under the License.

 use arrow::{array::*, datatypes::DataType};
-use datafusion::common::{exec_err, internal_datafusion_err, internal_err, Result};
+use datafusion::common::{exec_err, internal_datafusion_err, Result};
 use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
 use datafusion::{error::DataFusionError, logical_expr::ColumnarValue};
 use std::any::Any;
@@ -99,15 +99,38 @@ pub fn spark_bit_count(args: [ColumnarValue; 1]) -> Result<ColumnarValue> {
                 DataType::Int16 => compute_op!(array, Int16Array),
                 DataType::Int32 => compute_op!(array, Int32Array),
                 DataType::Int64 => compute_op!(array, Int64Array),
-                _ => exec_err!("bit_count can't be evaluated because the expression's type is {:?}, not signed int", array.data_type()),
+                _ => exec_err!("bit_count can't be evaluated because the array's type is {:?}, not signed int/boolean", array.data_type()),
             };
             result.map(ColumnarValue::Array)
         }
-        [ColumnarValue::Scalar(_)] => internal_err!("shouldn't go to bitwise count scalar path"),
+        [ColumnarValue::Scalar(scalar)] => {
+            use datafusion::common::ScalarValue;
+            let result = match scalar {
+                ScalarValue::Int8(Some(v)) => bit_count(v as i64),
+                ScalarValue::Int16(Some(v)) => bit_count(v as i64),
+                ScalarValue::Int32(Some(v)) => bit_count(v as i64),
+                ScalarValue::Int64(Some(v)) => bit_count(v),
+                ScalarValue::Boolean(Some(v)) => bit_count(if v { 1 } else { 0 }),
+                ScalarValue::Int8(None)
+                | ScalarValue::Int16(None)
+                | ScalarValue::Int32(None)
+                | ScalarValue::Int64(None)
+                | ScalarValue::Boolean(None) => {
+                    return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None)))
+                }
+                _ => {
+                    return exec_err!(
+                        "bit_count can't be evaluated because the scalar's type is {:?}, not signed int/boolean",
+                        scalar.data_type()
+                    )
+                }
+            };
+            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result))))
+        }
     }
 }

-// Here’s the equivalent Rust implementation of the bitCount function (similar to Apache Spark's bitCount for LongType)
+// Here’s the equivalent Rust implementation of the bitCount function (similar to Java's bitCount for LongType)
 fn bit_count(i: i64) -> i32 {
     let mut u = i as u64;
     u = u - ((u >> 1) & 0x5555555555555555);
@@ -121,7 +144,7 @@

 #[cfg(test)]
 mod tests {
-    use datafusion::common::{cast::as_int32_array, Result};
+    use datafusion::common::{cast::as_int32_array, Result, ScalarValue};

     use super::*;

@@ -133,8 +156,18 @@ mod tests {
             Some(12345),
             Some(89),
             Some(-3456),
+            Some(i32::MIN),
+            Some(i32::MAX),
         ])));
-        let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]);
+        let expected = &Int32Array::from(vec![
+            Some(1),
+            None,
+            Some(6),
+            Some(4),
+            Some(54),
+            Some(33),
+            Some(31),
+        ]);

         let ColumnarValue::Array(result) = spark_bit_count([args])? else {
             unreachable!()
@@ -145,4 +178,16 @@

         Ok(())
     }
+
+    #[test]
+    fn bitwise_count_scalar() {
+        let args = ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MAX)));
+
+        match spark_bit_count([args]) {
+            Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(actual)))) => {
+                assert_eq!(actual, 63)
+            }
+            _ => unreachable!(),
+        }
+    }
 }
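
The truncated body of `bit_count` above is the classic SWAR popcount used by Java's `Long.bitCount`; the diff only shows its first step. Below is a minimal, self-contained sketch assuming the elided lines follow that standard sequence; the `main` cross-check against Rust's built-in `count_ones` is illustrative and not part of this PR:

```rust
// SWAR popcount, equivalent to Java's Long.bitCount. The full body here is an
// assumption about the elided lines, based on the first step visible in the diff.
fn bit_count(i: i64) -> i32 {
    let mut u = i as u64;
    u -= (u >> 1) & 0x5555555555555555; // sum adjacent bit pairs
    u = (u & 0x3333333333333333) + ((u >> 2) & 0x3333333333333333); // 4-bit sums
    u = (u + (u >> 4)) & 0x0f0f0f0f0f0f0f0f; // 8-bit sums
    u += u >> 8; // fold byte lanes together
    u += u >> 16;
    u += u >> 32;
    (u as i32) & 0x7f // the total fits in 7 bits (at most 64 set bits)
}

fn main() {
    // Edge cases mirroring the tests above, checked against the intrinsic popcount.
    for v in [0i64, 7, -1, -3456, i64::MIN, i64::MAX] {
        assert_eq!(bit_count(v), v.count_ones() as i32);
    }
    println!("bit_count agrees with i64::count_ones on all sampled values");
}
```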
(next file: SQL logic test; filename not shown in this view)
@@ -73,7 +73,7 @@ SELECT bit_get(i, pos) FROM test_bit_get
 query
 SELECT 1111 & 2, 1111 | 2, 1111 ^ 2

-query ignore(https://github.com/apache/datafusion-comet/issues/3341)
+query
 SELECT bit_count(0), bit_count(7), bit_count(-1)

 query spark_answer_only
(next file: CometBitwiseExpressionSuite.scala; full path not shown in this view)
@@ -138,7 +138,9 @@ class CometBitwiseExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper

test("bitwise_count - min/max values") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
withSQLConf(
"parquet.enable.dictionary" -> dictionary.toString,
"spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") {
val table = "bitwise_count_test"
withTable(table) {
sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet")
Expand Down