From 2d95fc556d8bbc8fc1e0e4a15802d3f7240fc4bf Mon Sep 17 00:00:00 2001 From: Huang-Hsiang Cheng Date: Thu, 29 Jan 2026 22:57:52 -0800 Subject: [PATCH 1/2] fix: add scalar support for bit_count expression The bit_count function now handles scalar inputs in addition to arrays. Scalar inputs return scalar outputs, maintaining proper type semantics. Enable bit_count tests in bitwise.sql Co-Authored-By: Claude Sonnet 4.5 --- .../src/bitwise_funcs/bitwise_count.rs | 57 +++++++++++++++++-- .../sql-tests/expressions/bitwise/bitwise.sql | 2 +- .../comet/CometBitwiseExpressionSuite.scala | 4 +- 3 files changed, 55 insertions(+), 8 deletions(-) diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs index 4ab63e532c..b65c507320 100644 --- a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs +++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs @@ -16,7 +16,7 @@ // under the License. use arrow::{array::*, datatypes::DataType}; -use datafusion::common::{exec_err, internal_datafusion_err, internal_err, Result}; +use datafusion::common::{exec_err, internal_datafusion_err, Result}; use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion::{error::DataFusionError, logical_expr::ColumnarValue}; use std::any::Any; @@ -99,15 +99,38 @@ pub fn spark_bit_count(args: [ColumnarValue; 1]) -> Result { DataType::Int16 => compute_op!(array, Int16Array), DataType::Int32 => compute_op!(array, Int32Array), DataType::Int64 => compute_op!(array, Int64Array), - _ => exec_err!("bit_count can't be evaluated because the expression's type is {:?}, not signed int", array.data_type()), + _ => exec_err!("bit_count can't be evaluated because the array's type is {:?}, not signed int/boolean", array.data_type()), }; result.map(ColumnarValue::Array) } - [ColumnarValue::Scalar(_)] => internal_err!("shouldn't go to bitwise count scalar path"), + [ColumnarValue::Scalar(scalar)] => { + use datafusion::common::ScalarValue; + let result = match scalar { + ScalarValue::Int8(Some(v)) => bit_count(v as i64), + ScalarValue::Int16(Some(v)) => bit_count(v as i64), + ScalarValue::Int32(Some(v)) => bit_count(v as i64), + ScalarValue::Int64(Some(v)) => bit_count(v), + ScalarValue::Boolean(Some(v)) => bit_count(if v { 1 } else { 0 }), + ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Boolean(None) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))) + } + _ => { + return exec_err!( + "bit_count can't be evaluated because the scalar's type is {:?}, not signed int/boolean", + scalar.data_type() + ) + } + }; + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) + } } } -// Here’s the equivalent Rust implementation of the bitCount function (similar to Apache Spark's bitCount for LongType) +// Here’s the equivalent Rust implementation of the bitCount function (similar to Java's bitCount for LongType) fn bit_count(i: i64) -> i32 { let mut u = i as u64; u = u - ((u >> 1) & 0x5555555555555555); @@ -121,7 +144,7 @@ fn bit_count(i: i64) -> i32 { #[cfg(test)] mod tests { - use datafusion::common::{cast::as_int32_array, Result}; + use datafusion::common::{cast::as_int32_array, Result, ScalarValue}; use super::*; @@ -133,8 +156,18 @@ mod tests { Some(12345), Some(89), Some(-3456), + Some(i32::MIN), + Some(i32::MAX), ]))); - let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]); + let expected = &Int32Array::from(vec![ + Some(1), + None, + Some(6), + Some(4), + Some(54), + Some(33), + Some(31), + ]); let ColumnarValue::Array(result) = spark_bit_count([args])? else { unreachable!() @@ -145,4 +178,16 @@ mod tests { Ok(()) } + + #[test] + fn bitwise_count_scalar() { + let args = ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MAX))); + + match spark_bit_count([args]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(actual)))) => { + assert_eq!(actual, 63) + } + _ => unreachable!(), + } + } } diff --git a/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql b/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql index 640aa1e990..0cf6125082 100644 --- a/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql +++ b/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql @@ -73,7 +73,7 @@ SELECT bit_get(i, pos) FROM test_bit_get query SELECT 1111 & 2, 1111 | 2, 1111 ^ 2 -query ignore(https://github.com/apache/datafusion-comet/issues/3341) +query spark_answer_only SELECT bit_count(0), bit_count(7), bit_count(-1) query spark_answer_only diff --git a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala index 99a57b1575..d0adbe5c56 100644 --- a/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala +++ b/spark/src/test/scala/org/apache/comet/CometBitwiseExpressionSuite.scala @@ -138,7 +138,9 @@ class CometBitwiseExpressionSuite extends CometTestBase with AdaptiveSparkPlanHe test("bitwise_count - min/max values") { Seq(false, true).foreach { dictionary => - withSQLConf("parquet.enable.dictionary" -> dictionary.toString) { + withSQLConf( + "parquet.enable.dictionary" -> dictionary.toString, + "spark.sql.optimizer.excludedRules" -> "org.apache.spark.sql.catalyst.optimizer.ConstantFolding") { val table = "bitwise_count_test" withTable(table) { sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet") From c6552ad7b460aadb5354aef6f22ac1dc22ae8f44 Mon Sep 17 00:00:00 2001 From: hsiang-c <137842490+hsiang-c@users.noreply.github.com> Date: Wed, 4 Feb 2026 17:31:31 -0800 Subject: [PATCH 2/2] Update spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql Co-authored-by: Andy Grove --- .../test/resources/sql-tests/expressions/bitwise/bitwise.sql | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql b/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql index 0cf6125082..74a971f368 100644 --- a/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql +++ b/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql @@ -73,7 +73,7 @@ SELECT bit_get(i, pos) FROM test_bit_get query SELECT 1111 & 2, 1111 | 2, 1111 ^ 2 -query spark_answer_only +query SELECT bit_count(0), bit_count(7), bit_count(-1) query spark_answer_only