diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index 04417c666c85..e6314c973f4f 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -285,44 +285,178 @@ where return None; } - let data: &[T::Native] = array.values(); + fn sum_impl_integer(array: &PrimitiveArray) -> Option + where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, + { + let data: &[T::Native] = array.values(); - match array.nulls() { - None => { - let sum = data.iter().fold(T::default_value(), |accumulator, value| { - accumulator.add_wrapping(*value) - }); + match array.nulls() { + None => { + let sum = data.iter().fold(T::default_value(), |accumulator, value| { + accumulator.add_wrapping(*value) + }); - Some(sum) + Some(sum) + } + Some(nulls) => { + let mut sum = T::default_value(); + let data_chunks = data.chunks_exact(64); + let remainder = data_chunks.remainder(); + + let bit_chunks = nulls.inner().bit_chunks(); + data_chunks + .zip(bit_chunks.iter()) + .for_each(|(chunk, mask)| { + // index_mask has value 1 << i in the loop + let mut index_mask = 1; + chunk.iter().for_each(|value| { + if (mask & index_mask) != 0 { + sum = sum.add_wrapping(*value); + } + index_mask <<= 1; + }); + }); + + let remainder_bits = bit_chunks.remainder_bits(); + + remainder.iter().enumerate().for_each(|(i, value)| { + if remainder_bits & (1 << i) != 0 { + sum = sum.add_wrapping(*value); + } + }); + + Some(sum) + } } - Some(nulls) => { - let mut sum = T::default_value(); - let data_chunks = data.chunks_exact(64); - let remainder = data_chunks.remainder(); - - let bit_chunks = nulls.inner().bit_chunks(); - data_chunks - .zip(bit_chunks.iter()) - .for_each(|(chunk, mask)| { - // index_mask has value 1 << i in the loop - let mut index_mask = 1; - chunk.iter().for_each(|value| { - if (mask & index_mask) != 0 { - sum = sum.add_wrapping(*value); + } + + fn sum_impl_floating( + array: &PrimitiveArray, + ) -> Option + where + T: ArrowNumericType, + T::Native: ArrowNativeTypeOp, + { + let data: &[T::Native] = array.values(); + let mut chunk_acc = [T::default_value(); LANES]; + let mut rem_acc = T::default_value(); + + match array.nulls() { + None => { + let data_chunks = data.chunks_exact(LANES); + let remainder = data_chunks.remainder(); + + data_chunks.for_each(|chunk| { + let chunk: [T::Native; LANES] = chunk.try_into().unwrap(); + + for i in 0..LANES { + chunk_acc[i] = chunk_acc[i].add_wrapping(chunk[i]); + } + }); + + remainder.iter().copied().for_each(|value| { + rem_acc = rem_acc.add_wrapping(value); + }); + + let mut reduced = T::default_value(); + for v in chunk_acc { + reduced = reduced.add_wrapping(v); + } + let sum = reduced.add_wrapping(rem_acc); + + Some(sum) + } + Some(nulls) => { + // process data in chunks of 64 elements since we also get 64 bits of validity information at a time + let data_chunks = data.chunks_exact(64); + let remainder = data_chunks.remainder(); + + let bit_chunks = nulls.inner().bit_chunks(); + let remainder_bits = bit_chunks.remainder_bits(); + + data_chunks.zip(bit_chunks).for_each(|(chunk, mut mask)| { + // split chunks further into slices corresponding to the vector length + // the compiler is able to unroll this inner loop and remove bounds checks + // since the outer chunk size (64) is always a multiple of the number of lanes + chunk.chunks_exact(LANES).for_each(|chunk| { + let mut chunk: [T::Native; LANES] = chunk.try_into().unwrap(); + + for i in 0..LANES { + if mask & (1 << i) == 0 { + chunk[i] = T::default_value(); + } + chunk_acc[i] = chunk_acc[i].add_wrapping(chunk[i]); } - index_mask <<= 1; - }); + + mask >>= LANES; + }) }); - let remainder_bits = bit_chunks.remainder_bits(); + remainder.iter().enumerate().for_each(|(i, value)| { + if remainder_bits & (1 << i) != 0 { + rem_acc = rem_acc.add_wrapping(*value); + } + }); - remainder.iter().enumerate().for_each(|(i, value)| { - if remainder_bits & (1 << i) != 0 { - sum = sum.add_wrapping(*value); + let mut reduced = T::default_value(); + for v in chunk_acc { + reduced = reduced.add_wrapping(v); } - }); + let sum = reduced.add_wrapping(rem_acc); - Some(sum) + Some(sum) + } + } + } + + match T::DATA_TYPE { + DataType::Timestamp(_, _) + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Date32 + | DataType::Date64 + | DataType::Duration(_) + | DataType::Interval(_) + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => sum_impl_integer(array), + DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => match T::lanes() { + 1 => sum_impl_floating::(array), + 2 => sum_impl_floating::(array), + 4 => sum_impl_floating::(array), + 8 => sum_impl_floating::(array), + 16 => sum_impl_floating::(array), + 32 => sum_impl_floating::(array), + 64 => sum_impl_floating::(array), + unhandled => unreachable!("Unhandled number of lanes: {unhandled}"), + }, + DataType::Null + | DataType::Boolean + | DataType::Binary + | DataType::FixedSizeBinary(_) + | DataType::LargeBinary + | DataType::Utf8 + | DataType::LargeUtf8 + | DataType::List(_) + | DataType::FixedSizeList(_, _) + | DataType::LargeList(_) + | DataType::Struct(_) + | DataType::Union(_, _) + | DataType::Dictionary(_, _) + | DataType::Map(_, _) + | DataType::RunEndEncoded(_, _) => { + unreachable!("Unsupported data type: {:?}", T::DATA_TYPE) } } } diff --git a/arrow-array/src/numeric.rs b/arrow-array/src/numeric.rs index afc0e2c33010..68ecb74bc86b 100644 --- a/arrow-array/src/numeric.rs +++ b/arrow-array/src/numeric.rs @@ -113,10 +113,13 @@ where /// A subtype of primitive type that represents numeric values. #[cfg(not(feature = "simd"))] -pub trait ArrowNumericType: ArrowPrimitiveType {} +pub trait ArrowNumericType: ArrowPrimitiveType { + /// The number of SIMD lanes available + fn lanes() -> usize; +} macro_rules! make_numeric_type { - ($impl_ty:ty, $native_ty:ty, $simd_ty:ident, $simd_mask_ty:ident) => { + ($impl_ty:ty, $native_ty:ty, $simd_ty:ident, $simd_mask_ty:ident, $lanes:expr) => { #[cfg(feature = "simd")] impl ArrowNumericType for $impl_ty { type Simd = $simd_ty; @@ -336,42 +339,52 @@ macro_rules! make_numeric_type { } #[cfg(not(feature = "simd"))] - impl ArrowNumericType for $impl_ty {} + impl ArrowNumericType for $impl_ty { + #[inline] + fn lanes() -> usize { + $lanes + } + } }; } -make_numeric_type!(Int8Type, i8, i8x64, m8x64); -make_numeric_type!(Int16Type, i16, i16x32, m16x32); -make_numeric_type!(Int32Type, i32, i32x16, m32x16); -make_numeric_type!(Int64Type, i64, i64x8, m64x8); -make_numeric_type!(UInt8Type, u8, u8x64, m8x64); -make_numeric_type!(UInt16Type, u16, u16x32, m16x32); -make_numeric_type!(UInt32Type, u32, u32x16, m32x16); -make_numeric_type!(UInt64Type, u64, u64x8, m64x8); -make_numeric_type!(Float32Type, f32, f32x16, m32x16); -make_numeric_type!(Float64Type, f64, f64x8, m64x8); - -make_numeric_type!(TimestampSecondType, i64, i64x8, m64x8); -make_numeric_type!(TimestampMillisecondType, i64, i64x8, m64x8); -make_numeric_type!(TimestampMicrosecondType, i64, i64x8, m64x8); -make_numeric_type!(TimestampNanosecondType, i64, i64x8, m64x8); -make_numeric_type!(Date32Type, i32, i32x16, m32x16); -make_numeric_type!(Date64Type, i64, i64x8, m64x8); -make_numeric_type!(Time32SecondType, i32, i32x16, m32x16); -make_numeric_type!(Time32MillisecondType, i32, i32x16, m32x16); -make_numeric_type!(Time64MicrosecondType, i64, i64x8, m64x8); -make_numeric_type!(Time64NanosecondType, i64, i64x8, m64x8); -make_numeric_type!(IntervalYearMonthType, i32, i32x16, m32x16); -make_numeric_type!(IntervalDayTimeType, i64, i64x8, m64x8); -make_numeric_type!(IntervalMonthDayNanoType, i128, i128x4, m128x4); -make_numeric_type!(DurationSecondType, i64, i64x8, m64x8); -make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8); -make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8); -make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8); -make_numeric_type!(Decimal128Type, i128, i128x4, m128x4); +make_numeric_type!(Int8Type, i8, i8x64, m8x64, 64); +make_numeric_type!(Int16Type, i16, i16x32, m16x32, 32); +make_numeric_type!(Int32Type, i32, i32x16, m32x16, 16); +make_numeric_type!(Int64Type, i64, i64x8, m64x8, 8); +make_numeric_type!(UInt8Type, u8, u8x64, m8x64, 64); +make_numeric_type!(UInt16Type, u16, u16x32, m16x32, 32); +make_numeric_type!(UInt32Type, u32, u32x16, m32x16, 16); +make_numeric_type!(UInt64Type, u64, u64x8, m64x8, 8); +make_numeric_type!(Float32Type, f32, f32x16, m32x16, 16); +make_numeric_type!(Float64Type, f64, f64x8, m64x8, 8); + +make_numeric_type!(TimestampSecondType, i64, i64x8, m64x8, 8); +make_numeric_type!(TimestampMillisecondType, i64, i64x8, m64x8, 8); +make_numeric_type!(TimestampMicrosecondType, i64, i64x8, m64x8, 8); +make_numeric_type!(TimestampNanosecondType, i64, i64x8, m64x8, 8); +make_numeric_type!(Date32Type, i32, i32x16, m32x16, 16); +make_numeric_type!(Date64Type, i64, i64x8, m64x8, 8); +make_numeric_type!(Time32SecondType, i32, i32x16, m32x16, 16); +make_numeric_type!(Time32MillisecondType, i32, i32x16, m32x16, 16); +make_numeric_type!(Time64MicrosecondType, i64, i64x8, m64x8, 8); +make_numeric_type!(Time64NanosecondType, i64, i64x8, m64x8, 8); +make_numeric_type!(IntervalYearMonthType, i32, i32x16, m32x16, 16); +make_numeric_type!(IntervalDayTimeType, i64, i64x8, m64x8, 8); +make_numeric_type!(IntervalMonthDayNanoType, i128, i128x4, m128x4, 4); +make_numeric_type!(DurationSecondType, i64, i64x8, m64x8, 8); +make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8, 8); +make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8, 8); +make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8, 8); +make_numeric_type!(Decimal128Type, i128, i128x4, m128x4, 4); #[cfg(not(feature = "simd"))] -impl ArrowNumericType for Float16Type {} +impl ArrowNumericType for Float16Type { + #[inline] + fn lanes() -> usize { + Float32Type::lanes() + } +} #[cfg(feature = "simd")] impl ArrowNumericType for Float16Type { @@ -467,7 +480,12 @@ impl ArrowNumericType for Float16Type { } #[cfg(not(feature = "simd"))] -impl ArrowNumericType for Decimal256Type {} +impl ArrowNumericType for Decimal256Type { + #[inline] + fn lanes() -> usize { + 1 + } +} #[cfg(feature = "simd")] impl ArrowNumericType for Decimal256Type { diff --git a/arrow/benches/aggregate_kernels.rs b/arrow/benches/aggregate_kernels.rs index c7b09f70f70e..7536be2365e5 100644 --- a/arrow/benches/aggregate_kernels.rs +++ b/arrow/benches/aggregate_kernels.rs @@ -17,6 +17,10 @@ #[macro_use] extern crate criterion; +use arrow_array::types::{ + Float64Type, TimestampMillisecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, +}; +use arrow_array::ArrowNumericType; use criterion::Criterion; extern crate arrow; @@ -24,16 +28,18 @@ extern crate arrow; use arrow::compute::kernels::aggregate::*; use arrow::util::bench_util::*; use arrow::{array::*, datatypes::Float32Type}; +use rand::distributions::Standard; +use rand::prelude::Distribution; -fn bench_sum(arr_a: &Float32Array) { +fn bench_sum(arr_a: &PrimitiveArray) { criterion::black_box(sum(arr_a).unwrap()); } -fn bench_min(arr_a: &Float32Array) { +fn bench_min(arr_a: &PrimitiveArray) { criterion::black_box(min(arr_a).unwrap()); } -fn bench_max(arr_a: &Float32Array) { +fn bench_max(arr_a: &PrimitiveArray) { criterion::black_box(max(arr_a).unwrap()); } @@ -41,18 +47,49 @@ fn bench_min_string(arr_a: &StringArray) { criterion::black_box(min_string(arr_a).unwrap()); } +fn sum_min_max_bench( + c: &mut Criterion, + size: usize, + null_density: f32, + description: &str, +) where + T: ArrowNumericType, + Standard: Distribution, +{ + let arr_a = create_primitive_array::(size, null_density); + + c.bench_function(&format!("sum {size} {description}"), |b| { + b.iter(|| bench_sum(&arr_a)) + }); + c.bench_function(&format!("min {size} {description}"), |b| { + b.iter(|| bench_min(&arr_a)) + }); + c.bench_function(&format!("max {size} {description}"), |b| { + b.iter(|| bench_max(&arr_a)) + }); +} + fn add_benchmark(c: &mut Criterion) { - let arr_a = create_primitive_array::(512, 0.0); + sum_min_max_bench::(c, 512, 0.0, "u8 no nulls"); + sum_min_max_bench::(c, 512, 0.5, "u8 50% nulls"); + + sum_min_max_bench::(c, 512, 0.0, "u16 no nulls"); + sum_min_max_bench::(c, 512, 0.5, "u16 50% nulls"); + + sum_min_max_bench::(c, 512, 0.0, "u32 no nulls"); + sum_min_max_bench::(c, 512, 0.5, "u32 50% nulls"); + + sum_min_max_bench::(c, 512, 0.0, "u64 no nulls"); + sum_min_max_bench::(c, 512, 0.5, "u64 50% nulls"); - c.bench_function("sum 512", |b| b.iter(|| bench_sum(&arr_a))); - c.bench_function("min 512", |b| b.iter(|| bench_min(&arr_a))); - c.bench_function("max 512", |b| b.iter(|| bench_max(&arr_a))); + sum_min_max_bench::(c, 512, 0.0, "ts_millis no nulls"); + sum_min_max_bench::(c, 512, 0.5, "ts_millis 50% nulls"); - let arr_a = create_primitive_array::(512, 0.5); + sum_min_max_bench::(c, 512, 0.0, "f32 no nulls"); + sum_min_max_bench::(c, 512, 0.5, "f32 50% nulls"); - c.bench_function("sum nulls 512", |b| b.iter(|| bench_sum(&arr_a))); - c.bench_function("min nulls 512", |b| b.iter(|| bench_min(&arr_a))); - c.bench_function("max nulls 512", |b| b.iter(|| bench_max(&arr_a))); + sum_min_max_bench::(c, 512, 0.0, "f64 no nulls"); + sum_min_max_bench::(c, 512, 0.5, "f64 50% nulls"); let arr_b = create_string_array::(512, 0.0); c.bench_function("min string 512", |b| b.iter(|| bench_min_string(&arr_b)));