Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ use datafusion::{
use datafusion_comet_spark_expr::{
create_comet_physical_fun, create_comet_physical_fun_with_eval_mode, BinaryOutputStyle,
BloomFilterAgg, BloomFilterMightContain, CsvWriteOptions, EvalMode, SparkArraysZipFunc,
SumInteger, ToCsv,
SparkBloomFilterVersion, SumInteger, ToCsv,
};
use datafusion_spark::function::aggregate::collect::SparkCollectSet;
use iceberg::expr::Bind;
Expand Down Expand Up @@ -2287,10 +2287,17 @@ impl PhysicalPlanner {
let num_bits =
self.create_expr(expr.num_bits.as_ref().unwrap(), Arc::clone(&schema))?;
let datatype = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let version = match expr.version() {
spark_expression::BloomFilterVersion::V2 => SparkBloomFilterVersion::V2,
// Default (Unspecified or V1) preserves the pre-Spark-4.1 format that
// Comet has always emitted, keeping older Spark versions byte-equivalent.
_ => SparkBloomFilterVersion::V1,
};
let func = AggregateUDF::new_from_impl(BloomFilterAgg::new(
Arc::clone(&num_items),
Arc::clone(&num_bits),
datatype,
version,
));
Self::create_aggr_func_expr("bloom_filter_agg", schema, vec![child], func)
}
Expand Down
11 changes: 11 additions & 0 deletions native/proto/src/proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,17 @@ message BloomFilterAgg {
Expr numItems = 2;
Expr numBits = 3;
DataType datatype = 4;
// Output serialization version. Spark 4.0 and earlier always wrote V1; Spark
// 4.1+ defaults to V2 (different bit-scattering algorithm and a `seed` field
// in the binary format). The JVM serde sets this to the matching version so
// Comet's aggregate output is byte-equivalent with Spark's.
BloomFilterVersion version = 5;
}

// Serialization version of the bloom filter produced by BloomFilterAgg.
enum BloomFilterVersion {
// Zero value. The native planner maps every value other than V2 to V1, so
// this behaves like V1 (presumably what is read when the field is not set
// by the sender -- confirm against the JVM serde).
BLOOM_FILTER_VERSION_UNSPECIFIED = 0;
// Format written by Spark 4.0 and earlier.
BLOOM_FILTER_VERSION_V1 = 1;
// Default format in Spark 4.1+: uses a different bit-scattering algorithm
// and adds a `seed` field to the binary layout.
BLOOM_FILTER_VERSION_V2 = 2;
}

message CollectSet {
Expand Down
3 changes: 2 additions & 1 deletion native/spark-expr/benches/bloom_filter_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use datafusion::physical_expr::expressions::{Column, Literal};
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
use datafusion::physical_plan::ExecutionPlan;
use datafusion_comet_spark_expr::BloomFilterAgg;
use datafusion_comet_spark_expr::{BloomFilterAgg, SparkBloomFilterVersion};
use futures::StreamExt;
use std::hint::black_box;
use std::sync::Arc;
Expand Down Expand Up @@ -66,6 +66,7 @@ fn criterion_benchmark(c: &mut Criterion) {
Arc::clone(&num_items),
Arc::clone(&num_bits),
DataType::Binary,
SparkBloomFilterVersion::V1,
)));
b.to_async(&rt).iter(|| {
black_box(agg_test(
Expand Down
15 changes: 12 additions & 3 deletions native/spark-expr/src/bloom_filter/bloom_filter_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use datafusion::{arrow::datatypes::DataType, logical_expr::Volatility};
use std::{any::Any, sync::Arc};

use crate::bloom_filter::spark_bloom_filter;
use crate::bloom_filter::spark_bloom_filter::SparkBloomFilter;
use crate::bloom_filter::spark_bloom_filter::{SparkBloomFilter, SparkBloomFilterVersion};

use arrow::array::ArrayRef;
use arrow::array::BinaryArray;
Expand All @@ -37,6 +37,10 @@ pub struct BloomFilterAgg {
signature: Signature,
num_items: i32,
num_bits: i32,
/// Output serialization version. Spark <= 4.0 only knows V1; Spark 4.1+'s
/// `BloomFilter.create` defaults to V2, so the JVM serde sets this to V2 on
/// 4.1+ to keep `bloom_filter_agg` byte-equivalent with Spark's aggregator.
version: SparkBloomFilterVersion,
}

#[inline]
Expand All @@ -54,6 +58,7 @@ impl BloomFilterAgg {
num_items: Arc<dyn PhysicalExpr>,
num_bits: Arc<dyn PhysicalExpr>,
data_type: DataType,
version: SparkBloomFilterVersion,
) -> Self {
assert!(matches!(data_type, DataType::Binary));
Self {
Expand All @@ -70,6 +75,7 @@ impl BloomFilterAgg {
),
num_items: extract_i32_from_literal(num_items),
num_bits: extract_i32_from_literal(num_bits),
version,
}
}
}
Expand All @@ -92,10 +98,13 @@ impl AggregateUDFImpl for BloomFilterAgg {
}

fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(SparkBloomFilter::from((
Ok(Box::new(SparkBloomFilter::new(
self.version,
spark_bloom_filter::optimal_num_hash_functions(self.num_items, self.num_bits),
self.num_bits,
))))
// Spark's BloomFilterAggregate always uses BloomFilterImplV2.DEFAULT_SEED (= 0).
0,
)))
}

fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
Expand Down
1 change: 1 addition & 0 deletions native/spark-expr/src/bloom_filter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ mod bit;

mod spark_bit_array;
mod spark_bloom_filter;
pub use spark_bloom_filter::SparkBloomFilterVersion;

pub mod bloom_filter_agg;
pub use bloom_filter_might_contain::BloomFilterMightContain;
Expand Down
Loading
Loading