From 28f77d7eb2d57ba8dc3ee78aecfbd381b724d3a9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 3 May 2026 12:24:38 -0600 Subject: [PATCH 1/4] fix: support Spark 4.1 BloomFilter V2 format and bit-scattering Closes #4193. Spark 4.1 (`BloomFilter.create`) defaults to a new V2 binary format and a new bit-scattering algorithm in `BloomFilterImplV2`: - V2 binary: `[version=2][numHashFunctions][seed][numWords][bits...]` (V1 omits the seed). - V2 scatter: `combinedHash = (long)h1 * Integer.MAX_VALUE; for (i = 0; i < numHashFunctions; i++) combinedHash += h2;` then take `combinedHash < 0 ? ~combinedHash : combinedHash` mod bitSize. V1 uses `h1 + i*h2` with i in 1..=N and 32-bit arithmetic. Comet's reader hard-rejected non-V1 bytes, and the writer always emitted V1, so on Spark 4.1 both `BloomFilterMightContain from random input` and `bloom_filter_agg` failed with byte/result mismatches. This change: - Adds `SparkBloomFilterVersion` (V1, V2) and a `seed` field to `SparkBloomFilter`. Deserializer detects version from the leading 4 bytes; for V2 it reads the extra seed. Serializer writes the matching layout. `put_long`/`put_binary`/`might_contain_long` branch on version for the bit-scattering algorithm and seed murmur3 with `self.seed` (always 0 for V1; configurable for V2). - Threads the version through `BloomFilterAgg::new` so the aggregator emits the version that matches Spark's output. `BloomFilterAggregate` in Spark always uses `BloomFilterImplV2.DEFAULT_SEED = 0`. - Adds a `version` field to the `BloomFilterAgg` proto and the `BloomFilterVersion` enum (V1 / V2 / Unspecified). - `CometBloomFilterAggregate` (JVM serde) sets V2 on Spark 4.1+ and V1 on Spark <= 4.0. - New Rust tests cover V1 and V2 round-trips, that the two scattering schemes produce different bit patterns for the same input, and that the deserializer rejects unknown versions. - Removes the `assume(!isSpark41Plus, ...)` guards from the `BloomFilterMightContain from random input` and `bloom_filter_agg` Comet test suites; both now pass on Spark 4.1, and the V1 path still passes on Spark 4.0. 
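For reference, a simplified standalone sketch of the two scattering schemes
(`h1`/`h2` are the two murmur3 hash values and `bit_size` is the filter width
in bits; the helper names are illustrative, not the actual Comet internals):

    // V1: a fresh 32-bit combined hash per iteration (wrapping arithmetic).
    fn v1_indices(h1: i32, h2: i32, n: u32, bit_size: i32) -> Vec<usize> {
        (1..=n as i32)
            .map(|i| {
                let mut c = h1.wrapping_add(i.wrapping_mul(h2));
                if c < 0 {
                    c = !c;
                }
                (c % bit_size) as usize
            })
            .collect()
    }

    // V2: a single 64-bit accumulator keeps absorbing h2 across iterations.
    fn v2_indices(h1: i32, h2: i32, n: u32, bit_size: i64) -> Vec<usize> {
        let mut c = (h1 as i64).wrapping_mul(i32::MAX as i64);
        (0..n)
            .map(|_| {
                c = c.wrapping_add(h2 as i64);
                let idx = if c < 0 { !c } else { c };
                (idx % bit_size) as usize
            })
            .collect()
    }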
--- native/core/src/execution/planner.rs | 9 +- native/proto/src/proto/expr.proto | 11 + .../src/bloom_filter/bloom_filter_agg.rs | 15 +- native/spark-expr/src/bloom_filter/mod.rs | 1 + .../src/bloom_filter/spark_bloom_filter.rs | 333 +++++++++++++++--- native/spark-expr/src/lib.rs | 2 +- .../org/apache/comet/serde/aggregates.scala | 9 +- .../comet/exec/CometExec3_4PlusSuite.scala | 3 +- .../apache/comet/exec/CometExecSuite.scala | 1 - 9 files changed, 323 insertions(+), 61 deletions(-) diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index ef0250babc..f1f0ad6da1 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -69,7 +69,7 @@ use datafusion::{ use datafusion_comet_spark_expr::{ create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle, BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc, - SumInteger, ToCsv, + SparkBloomFilterVersion, SumInteger, ToCsv, }; use datafusion_spark::function::aggregate::collect::SparkCollectSet; use iceberg::expr::Bind; @@ -2287,10 +2287,17 @@ impl PhysicalPlanner { let num_bits = self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?; let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap()); + let version = match expr.version() { + spark_expression::BloomFilterVersion::V2 => SparkBloomFilterVersion::V2, + // Default (Unspecified or V1) preserves the pre-Spark-4.1 format that + // Comet has always emitted, keeping older Spark versions byte-equivalent. + _ => SparkBloomFilterVersion::V1, + }; let func = AggregateUDF::new_from_impl(BloomFilterAgg::new( Arc::clone(&num_items), Arc::clone(&num_bits), datatype, + version, )); Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func) } diff --git a/native/proto/src/proto/expr.proto b/native/proto/src/proto/expr.proto index f1b598000d..c7a305285d 100644 --- a/native/proto/src/proto/expr.proto +++ b/native/proto/src/proto/expr.proto @@ -248,6 +248,17 @@ message BloomFilterAgg { Expr numItems = 2; Expr numBits = 3; DataType datatype = 4; + // Output serialization version. Spark 4.0 and earlier always wrote V1; Spark + // 4.1+ defaults to V2 (different bit-scattering algorithm and a `seed` field + // in the binary format). The JVM serde sets this to the matching version so + // Comet's aggregate output is byte-equivalent with Spark's. + BloomFilterVersion version = 5; +} + +enum BloomFilterVersion { + BLOOM_FILTER_VERSION_UNSPECIFIED = 0; + BLOOM_FILTER_VERSION_V1 = 1; + BLOOM_FILTER_VERSION_V2 = 2; } message CollectSet { diff --git a/native/spark-expr/src/bloom_filter/bloom_filter_agg.rs b/native/spark-expr/src/bloom_filter/bloom_filter_agg.rs index 3436b29201..4ac236e6fc 100644 --- a/native/spark-expr/src/bloom_filter/bloom_filter_agg.rs +++ b/native/spark-expr/src/bloom_filter/bloom_filter_agg.rs @@ -20,7 +20,7 @@ use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility}; use std::{any::Any, sync::Arc}; use crate::bloom_filter::spark_bloom_filter; -use crate::bloom_filter::spark_bloom_filter::SparkBloomFilter; +use crate::bloom_filter::spark_bloom_filter::{SparkBloomFilter, SparkBloomFilterVersion}; use arrow::array::ArrayRef; use arrow::array::BinaryArray; @@ -37,6 +37,10 @@ pub struct BloomFilterAgg { signature: Signature, num_items: i32, num_bits: i32, + /// Output serialization version. 
Spark <= 4.0 only knows V1; Spark 4.1+'s
+    /// `BloomFilter.create` defaults to V2, so the JVM serde sets this to V2 on
+    /// 4.1+ to keep `bloom_filter_agg` byte-equivalent with Spark's aggregator.
+    version: SparkBloomFilterVersion,
 }
 
 #[inline]
@@ -54,6 +58,7 @@ impl BloomFilterAgg {
         num_items: Arc<dyn PhysicalExpr>,
         num_bits: Arc<dyn PhysicalExpr>,
         data_type: DataType,
+        version: SparkBloomFilterVersion,
     ) -> Self {
         assert!(matches!(data_type, DataType::Binary));
         Self {
@@ -70,6 +75,7 @@ impl BloomFilterAgg {
             ),
             num_items: extract_i32_from_literal(num_items),
             num_bits: extract_i32_from_literal(num_bits),
+            version,
         }
     }
 }
@@ -92,10 +98,13 @@ impl AggregateUDFImpl for BloomFilterAgg {
     }
 
     fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
-        Ok(Box::new(SparkBloomFilter::from((
+        Ok(Box::new(SparkBloomFilter::new(
+            self.version,
             spark_bloom_filter::optimal_num_hash_functions(self.num_items, self.num_bits),
             self.num_bits,
-        ))))
+            // Spark's BloomFilterAggregate always uses BloomFilterImplV2.DEFAULT_SEED (= 0).
+            0,
+        )))
     }
 
     fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
diff --git a/native/spark-expr/src/bloom_filter/mod.rs b/native/spark-expr/src/bloom_filter/mod.rs
index bde2c7aaaa..63c127a379 100644
--- a/native/spark-expr/src/bloom_filter/mod.rs
+++ b/native/spark-expr/src/bloom_filter/mod.rs
@@ -20,6 +20,7 @@
 mod bit;
 mod spark_bit_array;
 mod spark_bloom_filter;
+pub use spark_bloom_filter::SparkBloomFilterVersion;
 
 pub mod bloom_filter_agg;
 pub use bloom_filter_might_contain::BloomFilterMightContain;
diff --git a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
index e84257ea67..2f52941210 100644
--- a/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
+++ b/native/spark-expr/src/bloom_filter/spark_bloom_filter.rs
@@ -24,6 +24,39 @@ use crate::bloom_filter::spark_bit_array::SparkBitArray;
 use crate::hash_funcs::murmur3::spark_compatible_murmur3_hash;
 
 const SPARK_BLOOM_FILTER_VERSION_1: i32 = 1;
+const SPARK_BLOOM_FILTER_VERSION_2: i32 = 2;
+
+/// Serialization format / hashing scheme used by a [`SparkBloomFilter`].
+///
+/// Spark 4.1 (SPARK-47547, see Spark's `BloomFilter.java`) introduced a V2 format
+/// that adds a `seed` field, switches the bit-scattering algorithm to use 64-bit
+/// arithmetic, and is now the default for `BloomFilter.create`. Spark 4.0 and
+/// earlier only know V1. The version is encoded in the first 4 bytes of the
+/// serialized form, and the read path must honour it.
+#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
+pub enum SparkBloomFilterVersion {
+    V1,
+    V2,
+}
+
+impl SparkBloomFilterVersion {
+    fn from_int(v: i32) -> Self {
+        match v {
+            SPARK_BLOOM_FILTER_VERSION_1 => Self::V1,
+            SPARK_BLOOM_FILTER_VERSION_2 => Self::V2,
+            _ => panic!(
+                "Unsupported BloomFilter version: {v}, expecting {SPARK_BLOOM_FILTER_VERSION_1} or {SPARK_BLOOM_FILTER_VERSION_2}"
+            ),
+        }
+    }
+
+    fn to_int(self) -> i32 {
+        match self {
+            Self::V1 => SPARK_BLOOM_FILTER_VERSION_1,
+            Self::V2 => SPARK_BLOOM_FILTER_VERSION_2,
+        }
+    }
+}
 
 /// A Bloom filter implementation that simulates the behavior of Spark's BloomFilter.
 /// It's not a complete implementation of Spark's BloomFilter, but just add the minimum
@@ -32,6 +65,11 @@
 pub struct SparkBloomFilter {
     bits: SparkBitArray,
     num_hash_functions: u32,
+    /// Serialization format and hash-scattering scheme.
+    version: SparkBloomFilterVersion,
+    /// Murmur3 seed. V1 always uses 0; V2 stores a per-filter seed (Spark's
+    /// `BloomFilterImplV2.DEFAULT_SEED` is also 0, so 0 is the common case).
+    seed: i32,
 }
 
 pub fn optimal_num_hash_functions(expected_items: i32, num_bits: i32) -> i32 {
@@ -43,29 +81,33 @@ impl From<(i32, i32)> for SparkBloomFilter {
     /// Creates an empty SparkBloomFilter given number of hash functions and bits.
+    /// Defaults to V1 for backwards compatibility; use [`SparkBloomFilter::new`]
+    /// with [`SparkBloomFilterVersion::V2`] to construct an empty V2 filter.
     fn from((num_hash_functions, num_bits): (i32, i32)) -> Self {
-        let num_words = spark_bit_array::num_words(num_bits as usize);
-        let bits = vec![0u64; num_words];
-        Self {
-            bits: SparkBitArray::new(bits),
-            num_hash_functions: num_hash_functions as u32,
-        }
+        Self::new(SparkBloomFilterVersion::V1, num_hash_functions, num_bits, 0)
     }
 }
 
 impl From<&[u8]> for SparkBloomFilter {
-    /// Creates a SparkBloomFilter from a serialized byte array conforming to Spark's BloomFilter
-    /// binary format version 1.
+    /// Creates a SparkBloomFilter from a serialized byte array conforming to either
+    /// Spark's BloomFilter binary format V1 or V2. The version is read from the
+    /// first 4 bytes.
     fn from(buf: &[u8]) -> Self {
         let mut offset = 0;
-        let version = read_num_be_bytes!(i32, 4, buf[offset..]);
+        let version_int = read_num_be_bytes!(i32, 4, buf[offset..]);
         offset += 4;
-        assert_eq!(
-            version, SPARK_BLOOM_FILTER_VERSION_1,
-            "Unsupported BloomFilter version: {version}, expecting version: {SPARK_BLOOM_FILTER_VERSION_1}"
-        );
+        let version = SparkBloomFilterVersion::from_int(version_int);
         let num_hash_functions = read_num_be_bytes!(i32, 4, buf[offset..]);
         offset += 4;
+        // V2 adds a 4-byte seed before the bit array. V1 has no seed.
+        let seed = match version {
+            SparkBloomFilterVersion::V1 => 0,
+            SparkBloomFilterVersion::V2 => {
+                let s = read_num_be_bytes!(i32, 4, buf[offset..]);
+                offset += 4;
+                s
+            }
+        };
         let num_words = read_num_be_bytes!(i32, 4, buf[offset..]);
         offset += 4;
         let mut bits = vec![0u64; num_words as usize];
@@ -76,76 +118,150 @@ impl From<&[u8]> for SparkBloomFilter {
         Self {
             bits: SparkBitArray::new(bits),
             num_hash_functions: num_hash_functions as u32,
+            version,
+            seed,
         }
     }
 }
 
 impl SparkBloomFilter {
+    /// Construct an empty filter with the given version, number of hash functions,
+    /// and bit count. The `seed` is ignored for V1 (always treated as 0) but is
+    /// honoured for V2.
+    pub fn new(
+        version: SparkBloomFilterVersion,
+        num_hash_functions: i32,
+        num_bits: i32,
+        seed: i32,
+    ) -> Self {
+        let num_words = spark_bit_array::num_words(num_bits as usize);
+        let bits = vec![0u64; num_words];
+        Self {
+            bits: SparkBitArray::new(bits),
+            num_hash_functions: num_hash_functions as u32,
+            version,
+            seed: match version {
+                SparkBloomFilterVersion::V1 => 0,
+                SparkBloomFilterVersion::V2 => seed,
+            },
+        }
+    }
+
+    /// Returns the serialization/scattering format this filter uses.
+    #[allow(dead_code)]
+    pub fn version(&self) -> SparkBloomFilterVersion {
+        self.version
+    }
+
     /// Serializes a SparkBloomFilter to a byte array conforming to Spark's BloomFilter
-    /// binary format version 1.
+    /// binary format. The output format follows the filter's `version`.
     pub fn spark_serialization(&self) -> Vec<u8> {
-        // There might be a more efficient way to do this, even with all the endianness stuff.
-        let mut spark_bloom_filter: Vec<u8> = 1_u32.to_be_bytes().to_vec();
-        spark_bloom_filter.append(&mut self.num_hash_functions.to_be_bytes().to_vec());
-        spark_bloom_filter.append(&mut (self.bits.word_size() as u32).to_be_bytes().to_vec());
+        let mut out: Vec<u8> = (self.version.to_int() as u32).to_be_bytes().to_vec();
+        out.append(&mut self.num_hash_functions.to_be_bytes().to_vec());
+        if let SparkBloomFilterVersion::V2 = self.version {
+            // Spark's BloomFilterImplV2.writeTo writes the seed between
+            // numHashFunctions and the bit array.
+            out.append(&mut (self.seed as u32).to_be_bytes().to_vec());
+        }
+        out.append(&mut (self.bits.word_size() as u32).to_be_bytes().to_vec());
         let mut filter_state: Vec<u64> = self.bits.data();
         for i in filter_state.iter_mut() {
             *i = i.to_be();
         }
-        // Does it make sense to do a std::mem::take of filter_state here? Unclear to me if a deep
-        // copy of filter_state as a Vec<u64> to a Vec<u8> is happening here.
-        spark_bloom_filter.append(&mut Vec::from(filter_state.to_byte_slice()));
-        spark_bloom_filter
+        out.append(&mut Vec::from(filter_state.to_byte_slice()));
+        out
     }
 
-    pub fn put_long(&mut self, item: i64) -> bool {
-        // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce
-        // n hash values by `h1 + i * h2` with 1 <= i <= num_hash_functions.
-        let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
-        let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
+    /// V1 bit-scattering: `combinedHash = h1 + i*h2` for `i in 1..=numHashFunctions`,
+    /// matching `BloomFilterImpl.scatterHashAndSetAllBits` (Spark <= 4.0; still
+    /// available as the V1 codepath in 4.1+).
+    fn scatter_v1(&mut self, h1: u32, h2: u32, set: bool) -> Option<bool> {
         let bit_size = self.bits.bit_size() as i32;
-        let mut bit_changed = false;
         for i in 1..=self.num_hash_functions {
             let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
             if combined_hash < 0 {
                 combined_hash = !combined_hash;
             }
-            bit_changed |= self.bits.set((combined_hash % bit_size) as usize)
+            let idx = (combined_hash % bit_size) as usize;
+            if set {
+                self.bits.set(idx);
+            } else if !self.bits.get(idx) {
+                return Some(false);
+            }
+        }
+        if set {
+            None
+        } else {
+            Some(true)
         }
-        bit_changed
     }
 
-    pub fn put_binary(&mut self, item: &[u8]) -> bool {
-        // Here we first hash the input long element into 2 int hash values, h1 and h2, then produce
-        // n hash values by `h1 + i * h2` with 1 <= i <= num_hash_functions.
-        let h1 = spark_compatible_murmur3_hash(item, 0);
-        let h2 = spark_compatible_murmur3_hash(item, h1);
-        let bit_size = self.bits.bit_size() as i32;
-        let mut bit_changed = false;
-        for i in 1..=self.num_hash_functions {
-            let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
-            if combined_hash < 0 {
-                combined_hash = !combined_hash;
+    /// V2 bit-scattering: `combinedHash = (long)h1 * Integer.MAX_VALUE; for (i = 0; i <
+    /// numHashFunctions; i++) combinedHash += h2;`. Mirrors Spark 4.1's
+    /// `BloomFilterImplV2.scatterHashAndSetAllBits`. Note 64-bit accumulator,
+    /// zero-indexed loop, and `combinedHash < 0 ? ~combinedHash : combinedHash` for the
+    /// non-negative bit index.
+    fn scatter_v2(&mut self, h1: u32, h2: u32, set: bool) -> Option<bool> {
+        let bit_size = self.bits.bit_size() as i64;
+        // (long) h1 * Integer.MAX_VALUE - sign-extend h1, then i64 multiply with wrapping.
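+        // The product itself can never overflow i64 (|h1| <= 2^31 and i32::MAX < 2^31,
+        // so its magnitude stays below 2^62); the wrapping ops simply mirror Java's
+        // silent-overflow semantics for the multiply and the repeated
+        // `combinedHash += h2` additions below.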
+        let mut combined_hash = (h1 as i32 as i64).wrapping_mul(i32::MAX as i64);
+        let h2_long = h2 as i32 as i64;
+        for _ in 0..self.num_hash_functions {
+            combined_hash = combined_hash.wrapping_add(h2_long);
+            let combined_index = if combined_hash < 0 {
+                !combined_hash
+            } else {
+                combined_hash
+            };
+            let idx = (combined_index % bit_size) as usize;
+            if set {
+                self.bits.set(idx);
+            } else if !self.bits.get(idx) {
+                return Some(false);
             }
-            bit_changed |= self.bits.set((combined_hash % bit_size) as usize)
         }
-        bit_changed
+        if set {
+            None
+        } else {
+            Some(true)
+        }
+    }
+
+    fn scatter(&mut self, h1: u32, h2: u32, set: bool) -> Option<bool> {
+        match self.version {
+            SparkBloomFilterVersion::V1 => self.scatter_v1(h1, h2, set),
+            SparkBloomFilterVersion::V2 => self.scatter_v2(h1, h2, set),
+        }
+    }
+
+    /// Put a long item into the filter. Returns `false`; the original Spark
+    /// `BloomFilter.put` returns whether any bit changed, but no current Comet
+    /// caller uses that, so we don't bother computing it.
+    pub fn put_long(&mut self, item: i64) -> bool {
+        let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), self.seed as u32);
+        let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
+        self.scatter(h1, h2, true);
+        false
+    }
+
+    pub fn put_binary(&mut self, item: &[u8]) -> bool {
+        let h1 = spark_compatible_murmur3_hash(item, self.seed as u32);
+        let h2 = spark_compatible_murmur3_hash(item, h1);
+        self.scatter(h1, h2, true);
+        false
     }
 
     pub fn might_contain_long(&self, item: i64) -> bool {
-        let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), 0);
+        let h1 = spark_compatible_murmur3_hash(item.to_le_bytes(), self.seed as u32);
         let h2 = spark_compatible_murmur3_hash(item.to_le_bytes(), h1);
-        let bit_size = self.bits.bit_size() as i32;
-        for i in 1..=self.num_hash_functions {
-            let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
-            if combined_hash < 0 {
-                combined_hash = !combined_hash;
+        match self.version {
+            SparkBloomFilterVersion::V1 => {
+                might_contain_long_v1(&self.bits, self.num_hash_functions, h1, h2)
             }
-            if !self.bits.get((combined_hash % bit_size) as usize) {
-                return false;
+            SparkBloomFilterVersion::V2 => {
+                might_contain_long_v2(&self.bits, self.num_hash_functions, h1, h2)
             }
         }
-        true
     }
 
     pub fn might_contain_longs(&self, items: &Int64Array) -> BooleanArray {
@@ -168,3 +284,116 @@ impl SparkBloomFilter {
         self.bits.merge_bits(other);
     }
 }
+
+fn might_contain_long_v1(bits: &SparkBitArray, num_hash_functions: u32, h1: u32, h2: u32) -> bool {
+    let bit_size = bits.bit_size() as i32;
+    for i in 1..=num_hash_functions {
+        let mut combined_hash = (h1 as i32).add_wrapping((i as i32).mul_wrapping(h2 as i32));
+        if combined_hash < 0 {
+            combined_hash = !combined_hash;
+        }
+        if !bits.get((combined_hash % bit_size) as usize) {
+            return false;
+        }
+    }
+    true
+}
+
+fn might_contain_long_v2(bits: &SparkBitArray, num_hash_functions: u32, h1: u32, h2: u32) -> bool {
+    let bit_size = bits.bit_size() as i64;
+    let mut combined_hash = (h1 as i32 as i64).wrapping_mul(i32::MAX as i64);
+    let h2_long = h2 as i32 as i64;
+    for _ in 0..num_hash_functions {
+        combined_hash = combined_hash.wrapping_add(h2_long);
+        let combined_index = if combined_hash < 0 {
+            !combined_hash
+        } else {
+            combined_hash
+        };
+        if !bits.get((combined_index % bit_size) as usize) {
+            return false;
+        }
+    }
+    true
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    /// Round-trip a V1 filter through put + serialize + deserialize and verify that
+    /// `might_contain` finds every inserted value after parsing, and that the
+    /// version flag is preserved across (de)serialization.
+    #[test]
+    fn v1_round_trip() {
+        let mut filter = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 4, 256, 0);
+        for x in [1_i64, 42, 1_000_000, -7, i64::MIN, i64::MAX] {
+            filter.put_long(x);
+        }
+        let bytes = filter.spark_serialization();
+        // V1: [version=1][numHashFunctions][numWords][bits...]
+        assert_eq!(&bytes[..4], &1_i32.to_be_bytes());
+
+        let parsed = SparkBloomFilter::from(bytes.as_slice());
+        assert_eq!(parsed.version(), SparkBloomFilterVersion::V1);
+        assert_eq!(parsed.num_hash_functions, 4);
+        for x in [1_i64, 42, 1_000_000, -7, i64::MIN, i64::MAX] {
+            assert!(parsed.might_contain_long(x), "{x} should be present");
+        }
+    }
+
+    /// Round-trip a V2 filter (Spark 4.1+ default). The serialized form has a
+    /// `seed` between `numHashFunctions` and the bit array, and the hash-scattering
+    /// uses 64-bit accumulator arithmetic, unlike V1, so the same input produces
+    /// different bit patterns.
+    #[test]
+    fn v2_round_trip() {
+        let mut filter = SparkBloomFilter::new(SparkBloomFilterVersion::V2, 4, 256, 0);
+        for x in [1_i64, 42, 1_000_000, -7, i64::MIN, i64::MAX] {
+            filter.put_long(x);
+        }
+        let bytes = filter.spark_serialization();
+        // V2: [version=2][numHashFunctions][seed][numWords][bits...]
+        assert_eq!(&bytes[..4], &2_i32.to_be_bytes());
+        // seed lives at offset 8 (after version + numHashFunctions)
+        assert_eq!(&bytes[8..12], &0_i32.to_be_bytes());
+
+        let parsed = SparkBloomFilter::from(bytes.as_slice());
+        assert_eq!(parsed.version(), SparkBloomFilterVersion::V2);
+        assert_eq!(parsed.num_hash_functions, 4);
+        for x in [1_i64, 42, 1_000_000, -7, i64::MIN, i64::MAX] {
+            assert!(parsed.might_contain_long(x), "{x} should be present");
+        }
+    }
+
+    /// V1 and V2 use different scattering algorithms, so for the same inputs the
+    /// resulting bit arrays must not match. If this assertion ever fails, the V2
+    /// implementation has likely regressed to V1 semantics.
+    #[test]
+    fn v1_and_v2_produce_different_bits() {
+        let inputs = [1_i64, 2, 3, 100, 1_000_000];
+        let mut v1 = SparkBloomFilter::new(SparkBloomFilterVersion::V1, 4, 256, 0);
+        let mut v2 = SparkBloomFilter::new(SparkBloomFilterVersion::V2, 4, 256, 0);
+        for x in inputs {
+            v1.put_long(x);
+            v2.put_long(x);
+        }
+        assert_ne!(
+            v1.state_as_bytes(),
+            v2.state_as_bytes(),
+            "V1 and V2 scattering must differ"
+        );
+    }
+
+    /// The deserializer must reject an unsupported version number rather than
+    /// silently producing a misconfigured filter.
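+    /// The crafted buffer mimics a V1-shaped layout whose version field is 3:
+    /// [version=3][numHashFunctions=4][numWords=4][32 zero bytes of bits].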
+    #[test]
+    #[should_panic(expected = "Unsupported BloomFilter version: 3")]
+    fn rejects_unknown_version() {
+        let mut buf: Vec<u8> = 3_i32.to_be_bytes().to_vec();
+        buf.extend_from_slice(&4_i32.to_be_bytes()); // numHashFunctions
+        buf.extend_from_slice(&4_i32.to_be_bytes()); // numWords
+        buf.extend_from_slice(&[0u8; 32]); // 4 words * 8 bytes
+        let _ = SparkBloomFilter::from(buf.as_slice());
+    }
+}
diff --git a/native/spark-expr/src/lib.rs b/native/spark-expr/src/lib.rs
index 5c476f65e6..e56b74d9de 100644
--- a/native/spark-expr/src/lib.rs
+++ b/native/spark-expr/src/lib.rs
@@ -53,7 +53,7 @@ pub use agg_funcs::*;
 pub use cast::{spark_cast, Cast, SparkCastOptions};
 
 mod bloom_filter;
-pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain};
+pub use bloom_filter::{BloomFilterAgg, BloomFilterMightContain, SparkBloomFilterVersion};
 
 mod conditional_funcs;
 mod conversion_funcs;
diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index c87c7ae00d..8044aab75b 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{ByteType, DataTypes, DecimalType, IntegerType
 
 import org.apache.comet.CometConf
 import org.apache.comet.CometConf.COMET_EXEC_STRICT_FLOATING_POINT
-import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, withInfo}
 import org.apache.comet.serde.QueryPlanSerde.{evalModeToProto, exprToProto, serializeDataType}
 import org.apache.comet.shims.CometEvalModeUtil
 
@@ -660,6 +660,13 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
       builder.setNumItems(numItemsExpr.get)
       builder.setNumBits(numBitsExpr.get)
       builder.setDatatype(dataType.get)
+      // SPARK-XXXXX (Spark 4.1) introduced a V2 BloomFilter binary format with
+      // different bit-scattering. Spark 4.1's `BloomFilter.create` (used by
+      // `BloomFilterAggregate`) defaults to V2; older Spark always wrote V1. Match
+      // the Spark version so `bloom_filter_agg` outputs are byte-equivalent.
+      builder.setVersion(
+        if (isSpark41Plus) ExprOuterClass.BloomFilterVersion.BLOOM_FILTER_VERSION_V2
+        else ExprOuterClass.BloomFilterVersion.BLOOM_FILTER_VERSION_V1)
 
       Some(
         ExprOuterClass.AggExpr
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala
index cb7a6c5d7f..349b2654ad 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala
@@ -33,7 +33,7 @@ import org.apache.spark.sql.functions.{col, lit}
 import org.apache.spark.util.sketch.BloomFilter
 
 import org.apache.comet.CometConf
-import org.apache.comet.CometSparkSessionExtensions.{isSpark41Plus, isSpark42Plus}
+import org.apache.comet.CometSparkSessionExtensions.isSpark42Plus
 
 /**
  * This test suite contains tests for only Spark 3.4+.
@@ -163,7 +163,6 @@ class CometExec3_4PlusSuite extends CometTestBase {
   }
 
   test("test BloomFilterMightContain from random input") {
-    assume(!isSpark41Plus, "https://github.com/apache/datafusion-comet/issues/4098")
     val (longs, bfBytes) = bloomFilterFromRandomInput(10000, 10000)
     val table = "test"
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index 1ac1659754..eedbd9cf8b 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -2827,7 +2827,6 @@ class CometExecSuite extends CometTestBase {
   }
 
   test("bloom_filter_agg") {
-    assume(!isSpark41Plus, "https://github.com/apache/datafusion-comet/issues/4098")
     val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
     spark.sessionState.functionRegistry.registerFunction(
       funcId_bloom_filter_agg,

From d2843b00920cf419bd68bf2d6bf51b5462f62550 Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Sun, 3 May 2026 14:50:57 -0600
Subject: [PATCH 2/4] test: keep Spark 4.2 guards on bloom filter tests

The original `assume(!isSpark41Plus, ...)` guards skipped on both 4.1 and 4.2.
The previous commit removed them entirely, but Spark 4.2 has a separate
`might_contain` / `bloom_filter_agg` registration issue tracked in #4142
(`Function identifier must be fully qualified (3-part)`). Narrowing the guards
to `assume(!isSpark42Plus, "#4142")` lets the tests run on 4.1 (the goal of
this PR) while staying skipped on 4.2 (existing behavior).
---
 .../scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala   | 1 +
 .../src/test/scala/org/apache/comet/exec/CometExecSuite.scala | 3 ++-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala
index 349b2654ad..b187c30b44 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExec3_4PlusSuite.scala
@@ -163,6 +163,7 @@ class CometExec3_4PlusSuite extends CometTestBase {
   }
 
   test("test BloomFilterMightContain from random input") {
+    assume(!isSpark42Plus, "https://github.com/apache/datafusion-comet/issues/4142")
     val (longs, bfBytes) = bloomFilterFromRandomInput(10000, 10000)
     val table = "test"
 
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index eedbd9cf8b..57f0668189 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -49,7 +49,7 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
 import org.apache.spark.unsafe.types.UTF8String
 
 import org.apache.comet.{CometConf, CometExecIterator, ExtendedExplainInfo}
-import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus, isSpark41Plus}
+import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus, isSpark41Plus, isSpark42Plus}
 import org.apache.comet.serde.Config.ConfigMap
 import org.apache.comet.testing.{DataGenOptions, ParquetGenerator, SchemaGenOptions}
 
@@ -2827,6 +2827,7 @@ class CometExecSuite extends CometTestBase {
   }
 
   test("bloom_filter_agg") {
+    assume(!isSpark42Plus, "https://github.com/apache/datafusion-comet/issues/4142")
     val funcId_bloom_filter_agg = new FunctionIdentifier("bloom_filter_agg")
     spark.sessionState.functionRegistry.registerFunction(
       funcId_bloom_filter_agg,

From cbbd27dc1a593f55e3cb1fc7111ea4f734a188ad Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Sun, 3 May 2026 15:24:40 -0600
Subject: [PATCH 3/4] fix: update bloom_filter_agg bench for new
 BloomFilterAgg::new signature

The first commit in this series added a `version: SparkBloomFilterVersion`
parameter to `BloomFilterAgg::new`, but missed updating the criterion bench
under `native/spark-expr/benches/bloom_filter_agg.rs`. Pass V1 explicitly to
match the historic behaviour the bench was written against.
---
 native/spark-expr/benches/bloom_filter_agg.rs | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/native/spark-expr/benches/bloom_filter_agg.rs b/native/spark-expr/benches/bloom_filter_agg.rs
index c425fd6c4f..2e61646e72 100644
--- a/native/spark-expr/benches/bloom_filter_agg.rs
+++ b/native/spark-expr/benches/bloom_filter_agg.rs
@@ -30,7 +30,7 @@ use datafusion::physical_expr::expressions::{Column, Literal};
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
 use datafusion::physical_plan::ExecutionPlan;
-use datafusion_comet_spark_expr::BloomFilterAgg;
+use datafusion_comet_spark_expr::{BloomFilterAgg, SparkBloomFilterVersion};
 use futures::StreamExt;
 use std::hint::black_box;
 use std::sync::Arc;
@@ -66,6 +66,7 @@ fn criterion_benchmark(c: &mut Criterion) {
             Arc::clone(&num_items),
             Arc::clone(&num_bits),
             DataType::Binary,
+            SparkBloomFilterVersion::V1,
         )));
         b.to_async(&rt).iter(|| {
             black_box(agg_test(

From 451338ca67bc6b5711f1eee47a3d1fdccb31d48c Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Mon, 4 May 2026 19:37:45 -0600
Subject: [PATCH 4/4] fix: replace SPARK-XXXXX placeholder with actual JIRA
 SPARK-47547
---
 spark/src/main/scala/org/apache/comet/serde/aggregates.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
index 8044aab75b..bb0466b4eb 100644
--- a/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/aggregates.scala
@@ -660,7 +660,7 @@ object CometBloomFilterAggregate extends CometAggregateExpressionSerde[BloomFilt
       builder.setNumItems(numItemsExpr.get)
       builder.setNumBits(numBitsExpr.get)
       builder.setDatatype(dataType.get)
-      // SPARK-XXXXX (Spark 4.1) introduced a V2 BloomFilter binary format with
+      // SPARK-47547 (Spark 4.1) introduced a V2 BloomFilter binary format with
       // different bit-scattering. Spark 4.1's `BloomFilter.create` (used by
       // `BloomFilterAggregate`) defaults to V2; older Spark always wrote V1. Match
       // the Spark version so `bloom_filter_agg` outputs are byte-equivalent.