diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 2db599047bcd..0b4c3af1d9c5 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -15,25 +15,28 @@ // specific language governing permissions and limitations // under the License. +use super::{Column, Literal}; +use crate::expressions::case::ResultState::{Complete, Empty, Partial}; use crate::expressions::try_cast; use crate::PhysicalExpr; -use std::borrow::Cow; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use arrow::array::*; use arrow::compute::kernels::zip::zip; -use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; -use arrow::datatypes::{DataType, Schema}; +use arrow::compute::{ + is_not_null, not, nullif, prep_null_mask_filter, FilterBuilder, FilterPredicate, +}; +use arrow::datatypes::{DataType, Schema, UInt32Type}; +use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, }; use datafusion_expr::ColumnarValue; - -use super::{Column, Literal}; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; +use std::borrow::Cow; +use std::fmt::{Debug, Formatter}; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; type WhenThen = (Arc, Arc); @@ -98,7 +101,7 @@ pub struct CaseExpr { } impl std::fmt::Display for CaseExpr { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { write!(f, "CASE ")?; if let Some(e) = &self.expr { write!(f, "{e} ")?; @@ -122,6 +125,419 @@ fn is_cheap_and_infallible(expr: &Arc) -> bool { expr.as_any().is::() } +/// Creates a [FilterPredicate] from a boolean array. 
+fn create_filter(predicate: &BooleanArray) -> FilterPredicate { + let mut filter_builder = FilterBuilder::new(predicate); + // Always optimize the filter since it is used multiple times. + filter_builder = filter_builder.optimize(); + filter_builder.build() +} + +// This should be removed when https://github.com/apache/arrow-rs/pull/8693 +// is merged and becomes available. +fn filter_record_batch( + record_batch: &RecordBatch, + filter: &FilterPredicate, +) -> std::result::Result { + let filtered_columns = record_batch + .columns() + .iter() + .map(|a| filter_array(a, filter)) + .collect::, _>>()?; + // SAFETY: since we start from a valid RecordBatch, there's no need to revalidate the schema + // since the set of columns has not changed. + // The input column arrays all had the same length (since they're coming from a valid RecordBatch) + // and filtering them with the same filter produces a new set of arrays with identical + // lengths. + unsafe { + Ok(RecordBatch::new_unchecked( + record_batch.schema(), + filtered_columns, + filter.count(), + )) + } +} + +// This function exists purely to be able to use the same call style +// for `filter_record_batch` and `filter_array` at the point of use. +// When https://github.com/apache/arrow-rs/pull/8693 is available, replace +// both with method calls on `FilterPredicate`. +#[inline(always)] +fn filter_array( + array: &dyn Array, + filter: &FilterPredicate, +) -> std::result::Result { + filter.filter(array) +} + +/// Merges elements by index from a list of [`ArrayData`], creating a new array from +/// those values. +/// +/// Each element in `indices` is the index of an array in `values`. The `indices` array is processed +/// sequentially. The first occurrence of index value `n` will be mapped to the first +/// value of the array at index `n`. The second occurrence to the second value, and so on. +/// An index value where `PartialResultIndex::is_none` is `true` is used to indicate null values. 
+/// +/// # Implementation notes +/// +/// This algorithm is similar in nature to both `zip` and `interleave`, but there are some important +/// differences. +/// +/// In contrast to `zip`, this function supports multiple input arrays. Instead of a boolean +/// selection vector, an index array is used to take values from the input arrays, and a special marker +/// value is used to indicate null values. +/// +/// In contrast to `interleave`, this function does not use pairs of indices. The values in +/// `indices` serve the same purpose as the first value in the pairs passed to `interleave`. +/// The index in the array is implicit and is derived from the number of times a particular array +/// index occurs. +/// The more constrained indexing mechanism used by this algorithm makes it easier to copy values +/// in contiguous slices. In the example below, the two consecutive elements from array `2` can be +/// copied in a single operation from the source array instead of copying them one by one. +/// Long spans of null values are also especially cheap because they do not need to be represented +/// in an input array. +/// +/// # Panics +/// +/// This function does not check that the number of occurrences of any particular array index matches +/// the length of the corresponding input array. If an array contains more values than required, the +/// surplus values will be ignored. If an array contains fewer values than necessary, this function +/// will panic. 
+/// +/// # Example +/// +/// ```text +/// ┌───────────┐ ┌─────────┐ ┌─────────┐ +/// │┌─────────┐│ │ None │ │ NULL │ +/// ││ A ││ ├─────────┤ ├─────────┤ +/// │└─────────┘│ │ 1 │ │ B │ +/// │┌─────────┐│ ├─────────┤ ├─────────┤ +/// ││ B ││ │ 0 │ merge(values, indices) │ A │ +/// │└─────────┘│ ├─────────┤ ─────────────────────────▶ ├─────────┤ +/// │┌─────────┐│ │ None │ │ NULL │ +/// ││ C ││ ├─────────┤ ├─────────┤ +/// │├─────────┤│ │ 2 │ │ C │ +/// ││ D ││ ├─────────┤ ├─────────┤ +/// │└─────────┘│ │ 2 │ │ D │ +/// └───────────┘ └─────────┘ └─────────┘ +/// values indices result +/// +/// ``` +fn merge(values: &[ArrayData], indices: &[PartialResultIndex]) -> Result { + #[cfg(debug_assertions)] + for ix in indices { + if let Some(index) = ix.index() { + assert!( + index < values.len(), + "Index out of bounds: {} >= {}", + index, + values.len() + ); + } + } + + let data_refs = values.iter().collect(); + let mut mutable = MutableArrayData::new(data_refs, true, indices.len()); + + // This loop extends the mutable array by taking slices from the partial results. + // + // take_offsets keeps track of how many values have been taken from each array. + let mut take_offsets = vec![0; values.len() + 1]; + let mut start_row_ix = 0; + loop { + let array_ix = indices[start_row_ix]; + + // Determine the length of the slice to take. + let mut end_row_ix = start_row_ix + 1; + while end_row_ix < indices.len() && indices[end_row_ix] == array_ix { + end_row_ix += 1; + } + let slice_length = end_row_ix - start_row_ix; + + // Extend mutable with either nulls or with values from the array. + match array_ix.index() { + None => mutable.extend_nulls(slice_length), + Some(index) => { + let start_offset = take_offsets[index]; + let end_offset = start_offset + slice_length; + mutable.extend(index, start_offset, end_offset); + take_offsets[index] = end_offset; + } + } + + if end_row_ix == indices.len() { + break; + } else { + // Set the start_row_ix for the next slice. 
+ start_row_ix = end_row_ix; + } + } + + Ok(make_array(mutable.freeze())) +} + +/// An index into the partial results array that's more compact than `usize`. +/// +/// `u32::MAX` is reserved as a special 'none' value. This is used instead of +/// `Option` to keep the array of indices as compact as possible. +#[derive(Copy, Clone, PartialEq, Eq)] +struct PartialResultIndex { + index: u32, +} + +const NONE_VALUE: u32 = u32::MAX; + +impl PartialResultIndex { + /// Returns the 'none' placeholder value. + fn none() -> Self { + Self { index: NONE_VALUE } + } + + fn zero() -> Self { + Self { index: 0 } + } + + /// Creates a new partial result index. + /// + /// If the provided value is greater than or equal to `u32::MAX` + /// an error will be returned. + fn try_new(index: usize) -> Result { + let Ok(index) = u32::try_from(index) else { + return internal_err!("Partial result index exceeds limit"); + }; + + if index == NONE_VALUE { + return internal_err!("Partial result index exceeds limit"); + } + + Ok(Self { index }) + } + + /// Determines if this index is the 'none' placeholder value or not. + fn is_none(&self) -> bool { + self.index == NONE_VALUE + } + + /// Returns `Some(index)` if this value is not the 'none' placeholder, `None` otherwise. + fn index(&self) -> Option { + if self.is_none() { + None + } else { + Some(self.index as usize) + } + } +} + +impl Debug for PartialResultIndex { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + if self.is_none() { + write!(f, "null") + } else { + write!(f, "{}", self.index) + } + } +} + +enum ResultState { + /// The final result is an array containing only null values. + Empty, + /// The final result needs to be computed by merging the data in `arrays`. + Partial { + // A `Vec` of partial results that should be merged. + // `indices` contains indexes into this vec. + arrays: Vec, + // Indicates per result row from which array in `arrays` a value should be taken. 
+ indices: Vec, + }, + /// A single branch matched all input rows. When creating the final result, no further merging + /// of partial results is necessary. + Complete(ColumnarValue), +} + +/// A builder for constructing result arrays for CASE expressions. +/// +/// Rather than building a monolithic array containing all results, it maintains a set of +/// partial result arrays and a mapping that indicates for each row which partial array +/// contains the result value for that row. +/// +/// On finish(), the builder will merge all partial results into a single array if necessary. +/// If all rows evaluated to the same array, that array can be returned directly without +/// any merging overhead. +struct ResultBuilder { + data_type: DataType, + /// The number of rows in the final result. + row_count: usize, + state: ResultState, +} + +impl ResultBuilder { + /// Creates a new ResultBuilder that will produce arrays of the given data type. + /// + /// The `row_count` parameter indicates the number of rows in the final result. + fn new(data_type: &DataType, row_count: usize) -> Self { + Self { + data_type: data_type.clone(), + row_count, + state: Empty, + } + } + + /// Adds a result for one branch of the case expression. + /// + /// `row_indices` should be a [UInt32Array] containing [RecordBatch] relative row indices + /// for which `value` contains result values. + /// + /// If `value` is a scalar, the scalar value will be used as the value for each row in `row_indices`. + /// + /// If `value` is an array, the values from the array and the indices from `row_indices` will be + /// processed pairwise. The lengths of `value` and `row_indices` must match. + /// + /// The diagram below shows a situation where a when expression matched rows 1 and 4 of the + /// record batch. The then expression produced the value array `[C, D]`. 
+ /// After adding this result, the result array will have been added to `partial arrays` and + /// `partial indices` will have been updated at indexes `1` and `4`. + /// + /// ```text + /// ┌─────────┐ ┌─────────┐┌───────────┐ ┌─────────┐┌───────────┐ + /// │ C │ │ 0: None ││┌ 0 ──────┐│ │ 0: None ││┌ 0 ──────┐│ + /// ├─────────┤ ├─────────┤││ A ││ ├─────────┤││ A ││ + /// │ D │ │ 1: None ││└─────────┘│ │ 1: 2 ││└─────────┘│ + /// └─────────┘ ├─────────┤│┌ 1 ──────┐│ add_branch_result( ├─────────┤│┌ 1 ──────┐│ + /// matching │ 2: 0 │││ B ││ row indices, │ 2: 0 │││ B ││ + /// 'then' values ├─────────┤│└─────────┘│ value ├─────────┤│└─────────┘│ + /// │ 3: None ││ │ ) │ 3: None ││┌ 2 ──────┐│ + /// ┌─────────┐ ├─────────┤│ │ ─────────────────────────▶ ├─────────┤││ C ││ + /// │ 1 │ │ 4: None ││ │ │ 4: 2 ││├─────────┤│ + /// ├─────────┤ ├─────────┤│ │ ├─────────┤││ D ││ + /// │ 4 │ │ 5: 1 ││ │ │ 5: 1 ││└─────────┘│ + /// └─────────┘ └─────────┘└───────────┘ └─────────┘└───────────┘ + /// row indices partial partial partial partial + /// indices arrays indices arrays + /// ``` + fn add_branch_result( + &mut self, + row_indices: &ArrayRef, + value: ColumnarValue, + ) -> Result<()> { + match value { + ColumnarValue::Array(a) => { + if a.len() != row_indices.len() { + internal_err!("Array length must match row indices length") + } else if row_indices.len() == self.row_count { + self.set_complete_result(ColumnarValue::Array(a)) + } else { + self.add_partial_result(row_indices, a.to_data()) + } + } + ColumnarValue::Scalar(s) => { + if row_indices.len() == self.row_count { + self.set_complete_result(ColumnarValue::Scalar(s)) + } else { + self.add_partial_result( + row_indices, + s.to_array_of_size(row_indices.len())?.to_data(), + ) + } + } + } + } + + /// Adds a partial result array. + /// + /// This method adds the given array data as a partial result and updates the index mapping + /// to indicate that the specified rows should take their values from this array. 
+ /// The partial results will be merged into a single array when finish() is called. + fn add_partial_result( + &mut self, + row_indices: &ArrayRef, + row_values: ArrayData, + ) -> Result<()> { + if row_indices.null_count() != 0 { + return internal_err!("Row indices must not contain nulls"); + } + + match &mut self.state { + Empty => { + let array_index = PartialResultIndex::zero(); + let mut indices = vec![PartialResultIndex::none(); self.row_count]; + for row_ix in row_indices.as_primitive::().values().iter() { + indices[*row_ix as usize] = array_index; + } + + self.state = Partial { + arrays: vec![row_values], + indices, + }; + + Ok(()) + } + Partial { arrays, indices } => { + let array_index = PartialResultIndex::try_new(arrays.len())?; + + arrays.push(row_values); + + for row_ix in row_indices.as_primitive::().values().iter() { + // This check is only active in debug builds because the callers of this method, + // `case_when_with_expr` and `case_when_no_expr`, already ensure that + // they calculate a value for each row at most once. + #[cfg(debug_assertions)] + if !indices[*row_ix as usize].is_none() { + return internal_err!("Duplicate value for row {}", *row_ix); + } + + indices[*row_ix as usize] = array_index; + } + Ok(()) + } + Complete(_) => internal_err!( + "Cannot add a partial result when complete result is already set" + ), + } + } + + /// Sets a result that applies to all rows. + /// + /// This is an optimization for cases where all rows evaluate to the same result. + /// When a complete result is set, the builder will return it directly from finish() + /// without any merging overhead. + fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> { + match &self.state { + Empty => { + self.state = Complete(value); + Ok(()) + } + Partial { .. 
} => { + internal_err!( + "Cannot set a complete result when there are already partial results" + ) + } + Complete(_) => internal_err!("Complete result already set"), + } + } + + /// Finishes building the result and returns the final array. + fn finish(self) -> Result { + match self.state { + Empty => { + // No complete result and no partial results. + // This can happen for case expressions with no else branch where no rows + // matched. + Ok(ColumnarValue::Scalar(ScalarValue::try_new_null( + &self.data_type, + )?)) + } + Partial { arrays, indices } => { + // Merge partial results into a single array. + Ok(ColumnarValue::Array(merge(&arrays, &indices)?)) + } + Complete(v) => { + // If we have a complete result, we can just return it. + Ok(v) + } + } + } +} + impl CaseExpr { /// Create a new CASE WHEN expression pub fn try_new( @@ -196,82 +612,146 @@ impl CaseExpr { /// END fn case_when_with_expr(&self, batch: &RecordBatch) -> Result { let return_type = self.data_type(&batch.schema())?; - let expr = self.expr.as_ref().unwrap(); - let base_value = expr.evaluate(batch)?; - let base_value = base_value.into_array(batch.num_rows())?; - let base_nulls = is_null(base_value.as_ref())?; - - // start with nulls as default output - let mut current_value = new_null_array(&return_type, batch.num_rows()); - // We only consider non-null values while comparing with whens - let mut remainder = not(&base_nulls)?; - let mut non_null_remainder_count = remainder.true_count(); - for i in 0..self.when_then_expr.len() { - // If there are no rows left to process, break out of the loop early - if non_null_remainder_count == 0 { - break; - } + let mut result_builder = ResultBuilder::new(&return_type, batch.num_rows()); + + // `remainder_rows` contains the indices of the rows that need to be evaluated + let mut remainder_rows: ArrayRef = + Arc::new(UInt32Array::from_iter_values(0..batch.num_rows() as u32)); + // `remainder_batch` contains the rows themselves that need to be evaluated + let 
mut remainder_batch = Cow::Borrowed(batch); + + // evaluate the base expression + let mut base_values = self + .expr + .as_ref() + .unwrap() + .evaluate(batch)? + .into_array(batch.num_rows())?; - let when_predicate = &self.when_then_expr[i].0; - let when_value = when_predicate.evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows())?; - // build boolean array representing which rows match the "when" value - let when_match = compare_with_eq( - &when_value, - &base_value, - // The types of case and when expressions will be coerced to match. - // We only need to check if the base_value is nested. - base_value.data_type().is_nested(), - )?; - // Treat nulls as false - let when_match = match when_match.null_count() { - 0 => Cow::Borrowed(&when_match), - _ => Cow::Owned(prep_null_mask_filter(&when_match)), - }; - // Make sure we only consider rows that have not been matched yet - let when_value = and(&when_match, &remainder)?; + // Fill in a result value already for rows where the base expression value is null + // Since each when expression is tested against the base expression using the equality + // operator, null base values can never match any when expression. `x = NULL` is falsy, + // for all possible values of `x`. + if base_values.null_count() > 0 { + // Use `is_not_null` since this is a cheap clone of the null buffer from 'base_value'. + // We already checked there are nulls, so we can be sure a new buffer will not be + // created. + let base_not_nulls = is_not_null(base_values.as_ref())?; + let base_all_null = base_values.null_count() == remainder_batch.num_rows(); + + // If there is an else expression, use that as the default value for the null rows + // Otherwise the default `null` value from the result builder will be used. 
+ if let Some(e) = self.else_expr() { + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - // If the predicate did not match any rows, continue to the next branch immediately - let when_match_count = when_value.true_count(); - if when_match_count == 0 { - continue; + if base_all_null { + // All base values were null, so no need to filter + let nulls_value = expr.evaluate(&remainder_batch)?; + result_builder.add_branch_result(&remainder_rows, nulls_value)?; + } else { + // Filter out the null rows and evaluate the else expression for those + let nulls_filter = create_filter(&not(&base_not_nulls)?); + let nulls_batch = + filter_record_batch(&remainder_batch, &nulls_filter)?; + let nulls_rows = filter_array(&remainder_rows, &nulls_filter)?; + let nulls_value = expr.evaluate(&nulls_batch)?; + result_builder.add_branch_result(&nulls_rows, nulls_value)?; + } } - let then_expression = &self.when_then_expr[i].1; - let then_value = then_expression.evaluate_selection(batch, &when_value)?; + // All base values are null, so we can return early + if base_all_null { + return result_builder.finish(); + } - current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_value)? - } - ColumnarValue::Scalar(then_value) => { - zip(&when_value, &then_value.to_scalar()?, &current_value)? + // Remove the null rows from the remainder batch + let not_null_filter = create_filter(&base_not_nulls); + remainder_batch = + Cow::Owned(filter_record_batch(&remainder_batch, &not_null_filter)?); + remainder_rows = filter_array(&remainder_rows, &not_null_filter)?; + base_values = filter_array(&base_values, &not_null_filter)?; + } + + // The types of case and when expressions will be coerced to match. + // We only need to check if the base_value is nested. 
+ let base_value_is_nested = base_values.data_type().is_nested(); + + for i in 0..self.when_then_expr.len() { + // Evaluate the 'when' predicate for the remainder batch + // This results in a boolean array with the same length as the remaining number of rows + let when_expr = &self.when_then_expr[i].0; + let when_value = match when_expr.evaluate(&remainder_batch)? { + ColumnarValue::Array(a) => { + compare_with_eq(&a, &base_values, base_value_is_nested) } - ColumnarValue::Array(then_value) => { - zip(&when_value, &then_value, ¤t_value)? + ColumnarValue::Scalar(s) => { + let scalar = Scalar::new(s.to_array()?); + compare_with_eq(&scalar, &base_values, base_value_is_nested) } - }; + }?; - remainder = and_not(&remainder, &when_value)?; - non_null_remainder_count -= when_match_count; - } + // `true_count` ignores `true` values where the validity bit is not set, so there's + // no need to call `prep_null_mask_filter`. + let when_true_count = when_value.true_count(); - if let Some(e) = self.else_expr() { - // null and unmatched tuples should be assigned else value - remainder = or(&base_nulls, &remainder)?; + // If the 'when' predicate did not match any rows, continue to the next branch immediately + if when_true_count == 0 { + continue; + } - if remainder.true_count() > 0 { - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + // If the 'when' predicate matched all remaining rows, there is no need to filter + if when_true_count == remainder_batch.num_rows() { + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate(&remainder_batch)?; + result_builder.add_branch_result(&remainder_rows, then_value)?; + return result_builder.finish(); + } + + // Filter the remainder batch based on the 'when' value + // This results in a batch containing only the rows that need to be evaluated + // for the current branch + // Still no need to call 
`prep_null_mask_filter` since `create_filter` will already do + // this unconditionally. + let then_filter = create_filter(&when_value); + let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; + let then_rows = filter_array(&remainder_rows, &then_filter)?; - let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; - current_value = zip(&remainder, &else_, ¤t_value)?; + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate(&then_batch)?; + result_builder.add_branch_result(&then_rows, then_value)?; + + // If this is the last 'when' branch and there is no 'else' expression, there's no + // point in calculating the remaining rows. + if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 { + return result_builder.finish(); } + + // Prepare the next when branch (or the else branch) + let next_selection = match when_value.null_count() { + 0 => not(&when_value), + _ => { + // `prep_null_mask_filter` is required to ensure the not operation treats nulls + // as false + not(&prep_null_mask_filter(&when_value)) + } + }?; + let next_filter = create_filter(&next_selection); + remainder_batch = + Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); + remainder_rows = filter_array(&remainder_rows, &next_filter)?; + base_values = filter_array(&base_values, &next_filter)?; + } + + // If we reached this point, some rows were left unmatched. + // Check if those need to be evaluated using the 'else' expression. 
+ if let Some(e) = self.else_expr() { + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + let else_value = expr.evaluate(&remainder_batch)?; + result_builder.add_branch_result(&remainder_rows, else_value)?; } - Ok(ColumnarValue::Array(current_value)) + result_builder.finish() } /// This function evaluates the form of CASE where each WHEN expression is a boolean @@ -283,70 +763,86 @@ impl CaseExpr { /// END fn case_when_no_expr(&self, batch: &RecordBatch) -> Result { let return_type = self.data_type(&batch.schema())?; + let mut result_builder = ResultBuilder::new(&return_type, batch.num_rows()); - // start with nulls as default output - let mut current_value = new_null_array(&return_type, batch.num_rows()); - let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); - let mut remainder_count = batch.num_rows(); - for i in 0..self.when_then_expr.len() { - // If there are no rows left to process, break out of the loop early - if remainder_count == 0 { - break; - } + // `remainder_rows` contains the indices of the rows that need to be evaluated + let mut remainder_rows: ArrayRef = + Arc::new(UInt32Array::from_iter(0..batch.num_rows() as u32)); + // `remainder_batch` contains the rows themselves that need to be evaluated + let mut remainder_batch = Cow::Borrowed(batch); + for i in 0..self.when_then_expr.len() { + // Evaluate the 'when' predicate for the remainder batch + // This results in a boolean array with the same length as the remaining number of rows let when_predicate = &self.when_then_expr[i].0; - let when_value = when_predicate.evaluate_selection(batch, &remainder)?; - let when_value = when_value.into_array(batch.num_rows())?; + let when_value = when_predicate + .evaluate(&remainder_batch)? 
+ .into_array(remainder_batch.num_rows())?; let when_value = as_boolean_array(&when_value).map_err(|_| { internal_datafusion_err!("WHEN expression did not return a BooleanArray") })?; - // Treat 'NULL' as false value - let when_value = match when_value.null_count() { - 0 => Cow::Borrowed(when_value), - _ => Cow::Owned(prep_null_mask_filter(when_value)), - }; - // Make sure we only consider rows that have not been matched yet - let when_value = and(&when_value, &remainder)?; - // If the predicate did not match any rows, continue to the next branch immediately - let when_match_count = when_value.true_count(); - if when_match_count == 0 { + // `true_count` ignores `true` values where the validity bit is not set, so there's + // no need to call `prep_null_mask_filter`. + let when_true_count = when_value.true_count(); + + // If the 'when' predicate did not match any rows, continue to the next branch immediately + if when_true_count == 0 { continue; } + // If the 'when' predicate matched all remaining rows, there is no need to filter + if when_true_count == remainder_batch.num_rows() { + let then_expression = &self.when_then_expr[i].1; + let then_value = then_expression.evaluate(&remainder_batch)?; + result_builder.add_branch_result(&remainder_rows, then_value)?; + return result_builder.finish(); + } + + // Filter the remainder batch based on the 'when' value + // This results in a batch containing only the rows that need to be evaluated + // for the current branch + // Still no need to call `prep_null_mask_filter` since `create_filter` will already do + // this unconditionally. 
+ let then_filter = create_filter(when_value); + let then_batch = filter_record_batch(&remainder_batch, &then_filter)?; + let then_rows = filter_array(&remainder_rows, &then_filter)?; + let then_expression = &self.when_then_expr[i].1; - let then_value = then_expression.evaluate_selection(batch, &when_value)?; + let then_value = then_expression.evaluate(&then_batch)?; + result_builder.add_branch_result(&then_rows, then_value)?; - current_value = match then_value { - ColumnarValue::Scalar(ScalarValue::Null) => { - nullif(current_value.as_ref(), &when_value)? - } - ColumnarValue::Scalar(then_value) => { - zip(&when_value, &then_value.to_scalar()?, ¤t_value)? - } - ColumnarValue::Array(then_value) => { - zip(&when_value, &then_value, ¤t_value)? - } - }; + // If this is the last 'when' branch and there is no 'else' expression, there's no + // point in calculating the remaining rows. + if self.else_expr.is_none() && i == self.when_then_expr.len() - 1 { + return result_builder.finish(); + } - // Succeed tuples should be filtered out for short-circuit evaluation, - // null values for the current when expr should be kept - remainder = and_not(&remainder, &when_value)?; - remainder_count -= when_match_count; + // Prepare the next when branch (or the else branch) + let next_selection = match when_value.null_count() { + 0 => not(when_value), + _ => { + // `prep_null_mask_filter` is required to ensure the not operation treats nulls + // as false + not(&prep_null_mask_filter(when_value)) + } + }?; + let next_filter = create_filter(&next_selection); + remainder_batch = + Cow::Owned(filter_record_batch(&remainder_batch, &next_filter)?); + remainder_rows = filter_array(&remainder_rows, &next_filter)?; } + // If we reached this point, some rows were left unmatched. + // Check if those need to be evaluated using the 'else' expression. 
if let Some(e) = self.else_expr() { - if remainder_count > 0 { - // keep `else_expr`'s data type and return type consistent - let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; - let else_ = expr - .evaluate_selection(batch, &remainder)? - .into_array(batch.num_rows())?; - current_value = zip(&remainder, &else_, ¤t_value)?; - } + // keep `else_expr`'s data type and return type consistent + let expr = try_cast(Arc::clone(e), &batch.schema(), return_type.clone())?; + let else_value = expr.evaluate(&remainder_batch)?; + result_builder.add_branch_result(&remainder_rows, else_value)?; } - Ok(ColumnarValue::Array(current_value)) + result_builder.finish() } /// This function evaluates the specialized case of: @@ -587,7 +1083,7 @@ impl PhysicalExpr for CaseExpr { } } - fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE ")?; if let Some(e) = &self.expr { e.fmt_sql(f)?; diff --git a/datafusion/sqllogictest/test_files/case.slt b/datafusion/sqllogictest/test_files/case.slt index 352300e753a7..4eaa87b0b516 100644 --- a/datafusion/sqllogictest/test_files/case.slt +++ b/datafusion/sqllogictest/test_files/case.slt @@ -595,3 +595,25 @@ SELECT CASE WHEN a = 'a' THEN 0 WHEN a = 'b' THEN 1 ELSE 2 END FROM (VALUES (NUL ---- 2 2 + +# The `WHEN 1/0` is not effectively reachable in this query and should never be executed +query T +SELECT CASE a WHEN 1 THEN 'a' WHEN 2 THEN 'b' WHEN 1 / 0 THEN 'c' ELSE 'd' END FROM (VALUES (1), (2)) t(a) +---- +a +b + +# The `WHEN 1/0` is not effectively reachable in this query and should never be executed +query T +SELECT CASE WHEN a = 1 THEN 'a' WHEN a = 2 THEN 'b' WHEN a = 1 / 0 THEN 'c' ELSE 'd' END FROM (VALUES (1), (2)) t(a) +---- +a +b + +# The `WHEN 1/0` is not effectively reachable in this query and should never be executed +query T +SELECT CASE WHEN a = 0 THEN 'a' WHEN 1 / a = 1 THEN 'b' ELSE 'c' END FROM (VALUES (0), (1), 
(2)) t(a) +---- +a +b +c