Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure sum for better auto-vectorization for floats #4560

Closed
wants to merge 10 commits into from
192 changes: 163 additions & 29 deletions arrow-arith/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,44 +285,178 @@ where
return None;
}

let data: &[T::Native] = array.values();
fn sum_impl_integer<T>(array: &PrimitiveArray<T>) -> Option<T::Native>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW if you changed the signature to

fn sum_impl_integer<T: ArrowNativeType>(values: &[T], nulls: Option<&NullBuffer>) -> Option<T>

It would potentially save on codegen, as it would be instantiated per native type not per primitive type

where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
let data: &[T::Native] = array.values();

match array.nulls() {
None => {
let sum = data.iter().fold(T::default_value(), |accumulator, value| {
accumulator.add_wrapping(*value)
});
match array.nulls() {
None => {
let sum = data.iter().fold(T::default_value(), |accumulator, value| {
accumulator.add_wrapping(*value)
});

Some(sum)
Some(sum)
}
Some(nulls) => {
let mut sum = T::default_value();
let data_chunks = data.chunks_exact(64);
let remainder = data_chunks.remainder();

let bit_chunks = nulls.inner().bit_chunks();
data_chunks
.zip(bit_chunks.iter())
.for_each(|(chunk, mask)| {
// index_mask has value 1 << i in the loop
let mut index_mask = 1;
chunk.iter().for_each(|value| {
if (mask & index_mask) != 0 {
sum = sum.add_wrapping(*value);
}
index_mask <<= 1;
});
});

let remainder_bits = bit_chunks.remainder_bits();

remainder.iter().enumerate().for_each(|(i, value)| {
if remainder_bits & (1 << i) != 0 {
sum = sum.add_wrapping(*value);
}
});

Some(sum)
}
}
Some(nulls) => {
let mut sum = T::default_value();
let data_chunks = data.chunks_exact(64);
let remainder = data_chunks.remainder();

let bit_chunks = nulls.inner().bit_chunks();
data_chunks
.zip(bit_chunks.iter())
.for_each(|(chunk, mask)| {
// index_mask has value 1 << i in the loop
let mut index_mask = 1;
chunk.iter().for_each(|value| {
if (mask & index_mask) != 0 {
sum = sum.add_wrapping(*value);
}

fn sum_impl_floating<T, const LANES: usize>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above

array: &PrimitiveArray<T>,
) -> Option<T::Native>
where
T: ArrowNumericType,
T::Native: ArrowNativeTypeOp,
{
let data: &[T::Native] = array.values();
let mut chunk_acc = [T::default_value(); LANES];
let mut rem_acc = T::default_value();

match array.nulls() {
None => {
let data_chunks = data.chunks_exact(LANES);
let remainder = data_chunks.remainder();

data_chunks.for_each(|chunk| {
let chunk: [T::Native; LANES] = chunk.try_into().unwrap();

for i in 0..LANES {
chunk_acc[i] = chunk_acc[i].add_wrapping(chunk[i]);
}
});

remainder.iter().copied().for_each(|value| {
rem_acc = rem_acc.add_wrapping(value);
});

let mut reduced = T::default_value();
for v in chunk_acc {
reduced = reduced.add_wrapping(v);
}
let sum = reduced.add_wrapping(rem_acc);

Some(sum)
}
Some(nulls) => {
// process data in chunks of 64 elements since we also get 64 bits of validity information at a time
let data_chunks = data.chunks_exact(64);
let remainder = data_chunks.remainder();

let bit_chunks = nulls.inner().bit_chunks();
let remainder_bits = bit_chunks.remainder_bits();

data_chunks.zip(bit_chunks).for_each(|(chunk, mut mask)| {
// split chunks further into slices corresponding to the vector length
// the compiler is able to unroll this inner loop and remove bounds checks
// since the outer chunk size (64) is always a multiple of the number of lanes
chunk.chunks_exact(LANES).for_each(|chunk| {
let mut chunk: [T::Native; LANES] = chunk.try_into().unwrap();

for i in 0..LANES {
if mask & (1 << i) == 0 {
chunk[i] = T::default_value();
}
chunk_acc[i] = chunk_acc[i].add_wrapping(chunk[i]);
}
index_mask <<= 1;
});

mask >>= LANES;
})
});

let remainder_bits = bit_chunks.remainder_bits();
remainder.iter().enumerate().for_each(|(i, value)| {
if remainder_bits & (1 << i) != 0 {
rem_acc = rem_acc.add_wrapping(*value);
}
});

remainder.iter().enumerate().for_each(|(i, value)| {
if remainder_bits & (1 << i) != 0 {
sum = sum.add_wrapping(*value);
let mut reduced = T::default_value();
for v in chunk_acc {
reduced = reduced.add_wrapping(v);
}
});
let sum = reduced.add_wrapping(rem_acc);

Some(sum)
Some(sum)
}
}
}

match T::DATA_TYPE {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This match block is kind of grim, but I don't have a better solution off the top of my head... Perhaps some sort of trait 🤔

DataType::Timestamp(_, _)
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Date32
| DataType::Date64
| DataType::Duration(_)
| DataType::Interval(_)
| DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64 => sum_impl_integer(array),
DataType::Float16
| DataType::Float32
| DataType::Float64
| DataType::Decimal128(_, _)
| DataType::Decimal256(_, _) => match T::lanes() {
Comment on lines +433 to +434
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is decimal here?

1 => sum_impl_floating::<T, 1>(array),
2 => sum_impl_floating::<T, 2>(array),
4 => sum_impl_floating::<T, 4>(array),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It occurs to me that we have 3 floating point types, we could just dispatch to sum_impl_floating with the appropriate constant specified, without needing ArrowNumericType?

8 => sum_impl_floating::<T, 8>(array),
16 => sum_impl_floating::<T, 16>(array),
32 => sum_impl_floating::<T, 32>(array),
64 => sum_impl_floating::<T, 64>(array),
unhandled => unreachable!("Unhandled number of lanes: {unhandled}"),
},
DataType::Null
| DataType::Boolean
| DataType::Binary
| DataType::FixedSizeBinary(_)
| DataType::LargeBinary
| DataType::Utf8
| DataType::LargeUtf8
| DataType::List(_)
| DataType::FixedSizeList(_, _)
| DataType::LargeList(_)
| DataType::Struct(_)
| DataType::Union(_, _)
| DataType::Dictionary(_, _)
| DataType::Map(_, _)
| DataType::RunEndEncoded(_, _) => {
unreachable!("Unsupported data type: {:?}", T::DATA_TYPE)
}
}
}
Expand Down
86 changes: 52 additions & 34 deletions arrow-array/src/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,13 @@ where

/// A subtype of primitive type that represents numeric values.
#[cfg(not(feature = "simd"))]
pub trait ArrowNumericType: ArrowPrimitiveType {}
pub trait ArrowNumericType: ArrowPrimitiveType {
/// The number of SIMD lanes available
fn lanes() -> usize;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels a little off to define this for all the types, but then only use it for a special case of floats 🤔

}

macro_rules! make_numeric_type {
($impl_ty:ty, $native_ty:ty, $simd_ty:ident, $simd_mask_ty:ident) => {
($impl_ty:ty, $native_ty:ty, $simd_ty:ident, $simd_mask_ty:ident, $lanes:expr) => {
#[cfg(feature = "simd")]
impl ArrowNumericType for $impl_ty {
type Simd = $simd_ty;
Expand Down Expand Up @@ -336,42 +339,52 @@ macro_rules! make_numeric_type {
}

#[cfg(not(feature = "simd"))]
impl ArrowNumericType for $impl_ty {}
impl ArrowNumericType for $impl_ty {
#[inline]
fn lanes() -> usize {
$lanes
}
}
};
}

make_numeric_type!(Int8Type, i8, i8x64, m8x64);
make_numeric_type!(Int16Type, i16, i16x32, m16x32);
make_numeric_type!(Int32Type, i32, i32x16, m32x16);
make_numeric_type!(Int64Type, i64, i64x8, m64x8);
make_numeric_type!(UInt8Type, u8, u8x64, m8x64);
make_numeric_type!(UInt16Type, u16, u16x32, m16x32);
make_numeric_type!(UInt32Type, u32, u32x16, m32x16);
make_numeric_type!(UInt64Type, u64, u64x8, m64x8);
make_numeric_type!(Float32Type, f32, f32x16, m32x16);
make_numeric_type!(Float64Type, f64, f64x8, m64x8);

make_numeric_type!(TimestampSecondType, i64, i64x8, m64x8);
make_numeric_type!(TimestampMillisecondType, i64, i64x8, m64x8);
make_numeric_type!(TimestampMicrosecondType, i64, i64x8, m64x8);
make_numeric_type!(TimestampNanosecondType, i64, i64x8, m64x8);
make_numeric_type!(Date32Type, i32, i32x16, m32x16);
make_numeric_type!(Date64Type, i64, i64x8, m64x8);
make_numeric_type!(Time32SecondType, i32, i32x16, m32x16);
make_numeric_type!(Time32MillisecondType, i32, i32x16, m32x16);
make_numeric_type!(Time64MicrosecondType, i64, i64x8, m64x8);
make_numeric_type!(Time64NanosecondType, i64, i64x8, m64x8);
make_numeric_type!(IntervalYearMonthType, i32, i32x16, m32x16);
make_numeric_type!(IntervalDayTimeType, i64, i64x8, m64x8);
make_numeric_type!(IntervalMonthDayNanoType, i128, i128x4, m128x4);
make_numeric_type!(DurationSecondType, i64, i64x8, m64x8);
make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8);
make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8);
make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8);
make_numeric_type!(Decimal128Type, i128, i128x4, m128x4);
make_numeric_type!(Int8Type, i8, i8x64, m8x64, 64);
make_numeric_type!(Int16Type, i16, i16x32, m16x32, 32);
make_numeric_type!(Int32Type, i32, i32x16, m32x16, 16);
make_numeric_type!(Int64Type, i64, i64x8, m64x8, 8);
make_numeric_type!(UInt8Type, u8, u8x64, m8x64, 64);
make_numeric_type!(UInt16Type, u16, u16x32, m16x32, 32);
make_numeric_type!(UInt32Type, u32, u32x16, m32x16, 16);
make_numeric_type!(UInt64Type, u64, u64x8, m64x8, 8);
make_numeric_type!(Float32Type, f32, f32x16, m32x16, 16);
make_numeric_type!(Float64Type, f64, f64x8, m64x8, 8);

make_numeric_type!(TimestampSecondType, i64, i64x8, m64x8, 8);
make_numeric_type!(TimestampMillisecondType, i64, i64x8, m64x8, 8);
make_numeric_type!(TimestampMicrosecondType, i64, i64x8, m64x8, 8);
make_numeric_type!(TimestampNanosecondType, i64, i64x8, m64x8, 8);
make_numeric_type!(Date32Type, i32, i32x16, m32x16, 16);
make_numeric_type!(Date64Type, i64, i64x8, m64x8, 8);
make_numeric_type!(Time32SecondType, i32, i32x16, m32x16, 16);
make_numeric_type!(Time32MillisecondType, i32, i32x16, m32x16, 16);
make_numeric_type!(Time64MicrosecondType, i64, i64x8, m64x8, 8);
make_numeric_type!(Time64NanosecondType, i64, i64x8, m64x8, 8);
make_numeric_type!(IntervalYearMonthType, i32, i32x16, m32x16, 16);
make_numeric_type!(IntervalDayTimeType, i64, i64x8, m64x8, 8);
make_numeric_type!(IntervalMonthDayNanoType, i128, i128x4, m128x4, 4);
make_numeric_type!(DurationSecondType, i64, i64x8, m64x8, 8);
make_numeric_type!(DurationMillisecondType, i64, i64x8, m64x8, 8);
make_numeric_type!(DurationMicrosecondType, i64, i64x8, m64x8, 8);
make_numeric_type!(DurationNanosecondType, i64, i64x8, m64x8, 8);
make_numeric_type!(Decimal128Type, i128, i128x4, m128x4, 4);

#[cfg(not(feature = "simd"))]
impl ArrowNumericType for Float16Type {}
impl ArrowNumericType for Float16Type {
#[inline]
fn lanes() -> usize {
Float32Type::lanes()
}
}

#[cfg(feature = "simd")]
impl ArrowNumericType for Float16Type {
Expand Down Expand Up @@ -467,7 +480,12 @@ impl ArrowNumericType for Float16Type {
}

#[cfg(not(feature = "simd"))]
impl ArrowNumericType for Decimal256Type {}
impl ArrowNumericType for Decimal256Type {
#[inline]
fn lanes() -> usize {
1
}
}

#[cfg(feature = "simd")]
impl ArrowNumericType for Decimal256Type {
Expand Down