diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index a3c181368d5f..1cf53d5baa9e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. +mod literal_lookup_table; + use super::{Column, Literal}; -use crate::expressions::case::ResultState::{Complete, Empty, Partial}; use crate::expressions::{lit, try_cast}; use crate::PhysicalExpr; use arrow::array::*; @@ -28,7 +29,6 @@ use arrow::compute::{ use arrow::datatypes::{DataType, Schema, UInt32Type, UnionMode}; use arrow::error::ArrowError; use datafusion_common::cast::as_boolean_array; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_common::{ assert_or_internal_err, exec_err, internal_datafusion_err, internal_err, DataFusionError, HashMap, HashSet, Result, ScalarValue, @@ -38,11 +38,13 @@ use std::borrow::Cow; use std::hash::Hash; use std::{any::Any, sync::Arc}; +use crate::expressions::case::literal_lookup_table::LiteralLookupTable; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; use datafusion_physical_expr_common::datum::compare_with_eq; use itertools::Itertools; use std::fmt::{Debug, Formatter}; -type WhenThen = (Arc, Arc); +pub(super) type WhenThen = (Arc, Arc); #[derive(Debug, Hash, PartialEq, Eq)] enum EvalMethod { @@ -73,8 +75,37 @@ enum EvalMethod { /// /// CASE WHEN condition THEN expression ELSE expression END ExpressionOrExpression(ProjectedCaseBody), + + /// This is a specialization for [`EvalMethod::WithExpression`] when the value and results are literals + /// + /// See [`LiteralLookupTable`] for more details + WithExprScalarLookupTable(LiteralLookupTable), +} + +/// Implementing hash so we can use `derive` on [`EvalMethod`]. +/// +/// not implementing actual [`Hash`] as it is not dyn compatible so we cannot implement it for +/// `dyn` [`literal_lookup_table::WhenLiteralIndexMap`]. +/// +/// So implementing empty hash is still valid as the data is derived from `PhysicalExpr` s which are already hashed +impl Hash for LiteralLookupTable { + fn hash(&self, _state: &mut H) {} } +/// Implementing Equal so we can use `derive` on [`EvalMethod`]. +/// +/// not implementing actual [`PartialEq`] as it is not dyn compatible so we cannot implement it for +/// `dyn` [`literal_lookup_table::WhenLiteralIndexMap`]. +/// +/// So we always return true as the data is derived from `PhysicalExpr` s which are already compared +impl PartialEq for LiteralLookupTable { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for LiteralLookupTable {} + /// The body of a CASE expression which consists of an optional base expression, the "when/then" /// branches and an optional "else" branch. #[derive(Debug, Hash, PartialEq, Eq)] @@ -591,7 +622,7 @@ impl ResultBuilder { Self { data_type: data_type.clone(), row_count, - state: Empty, + state: ResultState::Empty, } } @@ -671,21 +702,21 @@ impl ResultBuilder { ); match &mut self.state { - Empty => { + ResultState::Empty => { let array_index = PartialResultIndex::zero(); let mut indices = vec![PartialResultIndex::none(); self.row_count]; for row_ix in row_indices.as_primitive::().values().iter() { indices[*row_ix as usize] = array_index; } - self.state = Partial { + self.state = ResultState::Partial { arrays: vec![row_values], indices, }; Ok(()) } - Partial { arrays, indices } => { + ResultState::Partial { arrays, indices } => { let array_index = PartialResultIndex::try_new(arrays.len())?; arrays.push(row_values); @@ -705,7 +736,7 @@ impl ResultBuilder { } Ok(()) } - Complete(_) => internal_err!( + ResultState::Complete(_) => internal_err!( "Cannot add a partial result when complete result is already set" ), } @@ -718,23 +749,23 @@ impl ResultBuilder { /// without any merging overhead. fn set_complete_result(&mut self, value: ColumnarValue) -> Result<()> { match &self.state { - Empty => { - self.state = Complete(value); + ResultState::Empty => { + self.state = ResultState::Complete(value); Ok(()) } - Partial { .. } => { + ResultState::Partial { .. } => { internal_err!( "Cannot set a complete result when there are already partial results" ) } - Complete(_) => internal_err!("Complete result already set"), + ResultState::Complete(_) => internal_err!("Complete result already set"), } } /// Finishes building the result and returns the final array. fn finish(self) -> Result { match self.state { - Empty => { + ResultState::Empty => { // No complete result and no partial results. // This can happen for case expressions with no else branch where no rows // matched. @@ -742,11 +773,11 @@ impl ResultBuilder { &self.data_type, )?)) } - Partial { arrays, indices } => { + ResultState::Partial { arrays, indices } => { // Merge partial results into a single array. Ok(ColumnarValue::Array(merge_n(&arrays, &indices)?)) } - Complete(v) => { + ResultState::Complete(v) => { // If we have a complete result, we can just return it. Ok(v) } @@ -781,28 +812,40 @@ impl CaseExpr { else_expr, }; - let eval_method = if body.expr.is_some() { - EvalMethod::WithExpression(body.project()?) - } else if body.when_then_expr.len() == 1 - && is_cheap_and_infallible(&(body.when_then_expr[0].1)) - && body.else_expr.is_none() - { - EvalMethod::InfallibleExprOrNull - } else if body.when_then_expr.len() == 1 - && body.when_then_expr[0].1.as_any().is::() - && body.else_expr.is_some() - && body.else_expr.as_ref().unwrap().as_any().is::() - { - EvalMethod::ScalarOrScalar - } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() { - EvalMethod::ExpressionOrExpression(body.project()?) - } else { - EvalMethod::NoExpression(body.project()?) - }; + let eval_method = Self::find_best_eval_method(&body)?; Ok(Self { body, eval_method }) } + fn find_best_eval_method(body: &CaseBody) -> Result { + if body.expr.is_some() { + if let Some(mapping) = LiteralLookupTable::maybe_new(body) { + return Ok(EvalMethod::WithExprScalarLookupTable(mapping)); + } + + return Ok(EvalMethod::WithExpression(body.project()?)); + } + + Ok( + if body.when_then_expr.len() == 1 + && is_cheap_and_infallible(&(body.when_then_expr[0].1)) + && body.else_expr.is_none() + { + EvalMethod::InfallibleExprOrNull + } else if body.when_then_expr.len() == 1 + && body.when_then_expr[0].1.as_any().is::() + && body.else_expr.is_some() + && body.else_expr.as_ref().unwrap().as_any().is::() + { + EvalMethod::ScalarOrScalar + } else if body.when_then_expr.len() == 1 && body.else_expr.is_some() { + EvalMethod::ExpressionOrExpression(body.project()?) + } else { + EvalMethod::NoExpression(body.project()?) + }, + ) + } + /// Optional base expression that can be compared to literal values in the "when" expressions pub fn expr(&self) -> Option<&Arc> { self.body.expr.as_ref() @@ -1275,6 +1318,28 @@ impl CaseExpr { self.body.expr_or_expr(batch, when_value) } } + + fn with_lookup_table( + &self, + batch: &RecordBatch, + lookup_table: &LiteralLookupTable, + ) -> Result { + let expr = self.body.expr.as_ref().unwrap(); + let evaluated_expression = expr.evaluate(batch)?; + + let is_scalar = matches!(evaluated_expression, ColumnarValue::Scalar(_)); + let evaluated_expression = evaluated_expression.to_array(1)?; + + let values = lookup_table.map_keys_to_values(&evaluated_expression)?; + + let result = if is_scalar { + ColumnarValue::Scalar(ScalarValue::try_from_array(values.as_ref(), 0)?) + } else { + ColumnarValue::Array(values) + }; + + Ok(result) + } } impl PhysicalExpr for CaseExpr { @@ -1370,6 +1435,9 @@ impl PhysicalExpr for CaseExpr { } EvalMethod::ScalarOrScalar => self.scalar_or_scalar(batch), EvalMethod::ExpressionOrExpression(p) => self.expr_or_expr(batch, p), + EvalMethod::WithExprScalarLookupTable(lookup_table) => { + self.with_lookup_table(batch, lookup_table) + } } } @@ -1515,6 +1583,7 @@ mod tests { use datafusion_expr::type_coercion::binary::comparison_coercion; use datafusion_expr_common::operator::Operator; use datafusion_physical_expr_common::physical_expr::fmt_sql; + use half::f16; #[test] fn case_with_expr() -> Result<()> { @@ -2646,4 +2715,501 @@ mod tests { assert_not_nullable(expr, schema); } } + + // Test Lookup evaluation + + fn test_case_when_literal_lookup( + values: ArrayRef, + lookup_map: &[(ScalarValue, ScalarValue)], + else_value: Option, + expected: ArrayRef, + ) { + // Create lookup + // CASE + // WHEN THEN + // WHEN THEN + // [ ELSE ] + + let schema = Schema::new(vec![Field::new( + "a", + values.data_type().clone(), + values.is_nullable(), + )]); + let schema = Arc::new(schema); + + let batch = RecordBatch::try_new(schema, vec![values]) + .expect("failed to create RecordBatch"); + + let schema = batch.schema_ref(); + let case = col("a", schema).expect("failed to create col"); + + let when_then = lookup_map + .iter() + .map(|(when, then)| { + ( + Arc::new(Literal::new(when.clone())) as _, + Arc::new(Literal::new(then.clone())) as _, + ) + }) + .collect::>(); + + let else_expr = else_value.map(|else_value| { + Arc::new(Literal::new(else_value)) as Arc + }); + let expr = CaseExpr::try_new(Some(case), when_then, else_expr) + .expect("failed to create case"); + + // Assert that we are testing what we intend to assert + assert!( + matches!( + expr.eval_method, + EvalMethod::WithExprScalarLookupTable { .. } + ), + "we should use the expected eval method" + ); + + let actual = expr + .evaluate(&batch) + .expect("failed to evaluate case") + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + + assert_eq!( + actual.data_type(), + expected.data_type(), + "Data type mismatch" + ); + + assert_eq!( + actual.as_ref(), + expected.as_ref(), + "actual (left) does not match expected (right)" + ); + } + + fn create_lookup( + when_then_pairs: impl IntoIterator, + ) -> Vec<(ScalarValue, ScalarValue)> + where + ScalarValue: From, + ScalarValue: From, + { + when_then_pairs + .into_iter() + .map(|(when, then)| (ScalarValue::from(when), ScalarValue::from(then))) + .collect() + } + + fn create_input_and_expected( + input_and_expected_pairs: impl IntoIterator, + ) -> (Input, Expected) + where + Input: Array + From>, + Expected: Array + From>, + { + let (input_items, expected_items): (Vec, Vec) = + input_and_expected_pairs.into_iter().unzip(); + + (Input::from(input_items), Expected::from(expected_items)) + } + + fn test_lookup_eval_with_and_without_else( + lookup_map: &[(ScalarValue, ScalarValue)], + input_values: ArrayRef, + expected: StringArray, + ) { + // Testing without ELSE should fallback to None + test_case_when_literal_lookup( + Arc::clone(&input_values), + lookup_map, + None, + Arc::new(expected.clone()), + ); + + // Testing with Else + let else_value = "___fallback___"; + + // Changing each expected None to be fallback + let expected_with_else = expected + .iter() + .map(|item| item.unwrap_or(else_value)) + .map(Some) + .collect::(); + + // Test case + test_case_when_literal_lookup( + input_values, + lookup_map, + Some(ScalarValue::Utf8(Some(else_value.to_string()))), + Arc::new(expected_with_else), + ); + } + + #[test] + fn test_case_when_literal_lookup_int32_to_string() { + let lookup_map = create_lookup([ + (Some(4), Some("four")), + (Some(2), Some("two")), + (Some(3), Some("three")), + (Some(1), Some("one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, Some("one")), + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_none_case_should_never_match() { + let lookup_map = create_lookup([ + (Some(4), Some("four")), + (None, Some("none")), + (Some(2), Some("two")), + (Some(1), Some("one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (Some(1), Some("one")), + (Some(5), None), // No match in WHEN + (None, None), // None cases are never match in CASE WHEN syntax + (Some(2), Some("two")), + (None, None), // None cases are never match in CASE WHEN syntax + (None, None), // None cases are never match in CASE WHEN syntax + (Some(2), Some("two")), + (Some(5), None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_int32_to_string_with_duplicate_cases() { + let lookup_map = create_lookup([ + (Some(4), Some("four")), + (Some(4), Some("no 4")), + (Some(2), Some("two")), + (Some(2), Some("no 2")), + (Some(3), Some("three")), + (Some(3), Some("no 3")), + (Some(2), Some("no 2")), + (Some(4), Some("no 4")), + (Some(2), Some("no 2")), + (Some(3), Some("no 3")), + (Some(4), Some("no 4")), + (Some(2), Some("no 2")), + (Some(3), Some("no 3")), + (Some(3), Some("no 3")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, None), // No match in WHEN + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f32_to_string_with_special_values_and_duplicate_cases( + ) { + let lookup_map = create_lookup([ + (Some(4.0), Some("four point zero")), + (Some(f32::NAN), Some("NaN")), + (Some(3.2), Some("three point two")), + // Duplicate case to make sure it is not used + (Some(f32::NAN), Some("should not use this NaN branch")), + (Some(f32::INFINITY), Some("Infinity")), + (Some(0.0), Some("zero")), + // Duplicate case to make sure it is not used + ( + Some(f32::INFINITY), + Some("should not use this Infinity branch"), + ), + (Some(1.1), Some("one point one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1.1, Some("one point one")), + (f32::NAN, Some("NaN")), + (3.2, Some("three point two")), + (3.2, Some("three point two")), + (0.0, Some("zero")), + (f32::INFINITY, Some("Infinity")), + (3.2, Some("three point two")), + (f32::NEG_INFINITY, None), // No match in WHEN + (f32::NEG_INFINITY, None), // No match in WHEN + (3.2, Some("three point two")), + (-0.0, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f16_to_string_with_special_values() { + let lookup_map = create_lookup([ + ( + ScalarValue::Float16(Some(f16::from_f32(3.2))), + Some("3 dot 2"), + ), + (ScalarValue::Float16(Some(f16::NAN)), Some("NaN")), + ( + ScalarValue::Float16(Some(f16::from_f32(17.4))), + Some("17 dot 4"), + ), + (ScalarValue::Float16(Some(f16::INFINITY)), Some("Infinity")), + (ScalarValue::Float16(Some(f16::ZERO)), Some("zero")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (f16::from_f32(3.2), Some("3 dot 2")), + (f16::NAN, Some("NaN")), + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::INFINITY, Some("Infinity")), + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::NEG_INFINITY, None), // No match in WHEN + (f16::NEG_INFINITY, None), // No match in WHEN + (f16::from_f32(17.4), Some("17 dot 4")), + (f16::NEG_ZERO, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f32_to_string_with_special_values() { + let lookup_map = create_lookup([ + (3.2, Some("3 dot 2")), + (f32::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (f32::INFINITY, Some("Infinity")), + (f32::ZERO, Some("zero")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (3.2, Some("3 dot 2")), + (f32::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (17.4, Some("17 dot 4")), + (f32::INFINITY, Some("Infinity")), + (17.4, Some("17 dot 4")), + (f32::NEG_INFINITY, None), // No match in WHEN + (f32::NEG_INFINITY, None), // No match in WHEN + (17.4, Some("17 dot 4")), + (-0.0, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_case_when_literal_lookup_f64_to_string_with_special_values() { + let lookup_map = create_lookup([ + (3.2, Some("3 dot 2")), + (f64::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (f64::INFINITY, Some("Infinity")), + (f64::ZERO, Some("zero")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (3.2, Some("3 dot 2")), + (f64::NAN, Some("NaN")), + (17.4, Some("17 dot 4")), + (17.4, Some("17 dot 4")), + (f64::INFINITY, Some("Infinity")), + (17.4, Some("17 dot 4")), + (f64::NEG_INFINITY, None), // No match in WHEN + (f64::NEG_INFINITY, None), // No match in WHEN + (17.4, Some("17 dot 4")), + (-0.0, None), // No match in WHEN + ]); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + // Test that we don't lose the decimal precision and scale info + #[test] + fn test_decimal_with_non_default_precision_and_scale() { + let lookup_map = create_lookup([ + (ScalarValue::Decimal32(Some(4), 3, 2), Some("four")), + (ScalarValue::Decimal32(Some(2), 3, 2), Some("two")), + (ScalarValue::Decimal32(Some(3), 3, 2), Some("three")), + (ScalarValue::Decimal32(Some(1), 3, 2), Some("one")), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, Some("one")), + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + let input_values = input_values + .with_precision_and_scale(3, 2) + .expect("must be able to set precision and scale"); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + // Test that we don't lose the timezone info + #[test] + fn test_timestamp_with_non_default_timezone() { + let timezone: Option> = Some("-10:00".into()); + let lookup_map = create_lookup([ + ( + ScalarValue::TimestampMillisecond(Some(4), timezone.clone()), + Some("four"), + ), + ( + ScalarValue::TimestampMillisecond(Some(2), timezone.clone()), + Some("two"), + ), + ( + ScalarValue::TimestampMillisecond(Some(3), timezone.clone()), + Some("three"), + ), + ( + ScalarValue::TimestampMillisecond(Some(1), timezone.clone()), + Some("one"), + ), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (1, Some("one")), + (2, Some("two")), + (3, Some("three")), + (3, Some("three")), + (2, Some("two")), + (3, Some("three")), + (5, None), // No match in WHEN + (5, None), // No match in WHEN + (3, Some("three")), + (5, None), // No match in WHEN + ]); + + let input_values = input_values.with_timezone_opt(timezone); + + test_lookup_eval_with_and_without_else( + &lookup_map, + Arc::new(input_values), + expected, + ); + } + + #[test] + fn test_with_strings_to_int32() { + let lookup_map = create_lookup([ + (Some("why"), Some(42)), + (Some("what"), Some(22)), + (Some("when"), Some(17)), + ]); + + let (input_values, expected) = + create_input_and_expected::([ + (Some("why"), Some(42)), + (Some("5"), None), // No match in WHEN + (None, None), // None cases are never match in CASE WHEN syntax + (Some("what"), Some(22)), + (None, None), // None cases are never match in CASE WHEN syntax + (None, None), // None cases are never match in CASE WHEN syntax + (Some("what"), Some(22)), + (Some("5"), None), // No match in WHEN + ]); + + let input_values = Arc::new(input_values) as ArrayRef; + + // Testing without ELSE should fallback to None + test_case_when_literal_lookup( + Arc::clone(&input_values), + &lookup_map, + None, + Arc::new(expected.clone()), + ); + + // Testing with Else + let else_value = 101; + + // Changing each expected None to be fallback + let expected_with_else = expected + .iter() + .map(|item| item.unwrap_or(else_value)) + .map(Some) + .collect::(); + + // Test case + test_case_when_literal_lookup( + input_values, + &lookup_map, + Some(ScalarValue::Int32(Some(else_value))), + Arc::new(expected_with_else), + ); + } } diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs new file mode 100644 index 000000000000..3e09a19190b2 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/boolean_lookup_table.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{ArrayRef, AsArray}; +use datafusion_common::{internal_err, ScalarValue}; + +#[derive(Clone, Debug)] +pub(super) struct BooleanIndexMap { + true_index: Option, + false_index: Option, +} + +impl BooleanIndexMap { + /// Try creating a new lookup table from the given literals and else index + /// The index of each literal in the vector is used as the mapped value in the lookup table. + /// + /// `literals` are guaranteed to be unique and non-nullable + pub(super) fn try_new( + unique_non_null_literals: Vec, + ) -> datafusion_common::Result { + let mut true_index: Option = None; + let mut false_index: Option = None; + + for (index, literal) in unique_non_null_literals.into_iter().enumerate() { + match literal { + ScalarValue::Boolean(Some(true)) => { + if true_index.is_some() { + return internal_err!( + "Duplicate true literal found in literals for BooleanIndexMap" + ); + } + true_index = Some(index as u32); + } + ScalarValue::Boolean(Some(false)) => { + if false_index.is_some() { + return internal_err!( + "Duplicate false literal found in literals for BooleanIndexMap" + ); + } + false_index = Some(index as u32); + } + ScalarValue::Boolean(None) => { + return internal_err!( + "Null literal found in non-null literals for BooleanIndexMap" + ) + } + _ => { + return internal_err!( + "Non-boolean literal found in literals for BooleanIndexMap" + ) + } + } + } + + Ok(Self { + true_index, + false_index, + }) + } +} + +impl WhenLiteralIndexMap for BooleanIndexMap { + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result> { + let true_index = self.true_index.unwrap_or(else_index); + let false_index = self.false_index.unwrap_or(else_index); + + Ok(array + .as_boolean() + .into_iter() + .map(|value| match value { + Some(true) => true_index, + Some(false) => false_index, + None => else_index, + }) + .collect::>()) + } +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs new file mode 100644 index 000000000000..6f9c6c488886 --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/bytes_like_lookup_table.rs @@ -0,0 +1,225 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{ + downcast_integer, Array, ArrayRef, AsArray, BinaryArray, BinaryViewArray, + DictionaryArray, FixedSizeBinaryArray, LargeBinaryArray, LargeStringArray, + StringArray, StringViewArray, +}; +use arrow::datatypes::{ + ArrowDictionaryKeyType, BinaryViewType, DataType, StringViewType, +}; +use datafusion_common::{internal_err, plan_datafusion_err, HashMap, ScalarValue}; +use std::fmt::Debug; + +/// Map from byte-like literal values to their first occurrence index +/// +/// This is a wrapper for handling different kinds of literal maps +#[derive(Clone, Debug)] +pub(super) struct BytesLikeIndexMap { + /// Map from non-null literal value the first occurrence index in the literals + map: HashMap, u32>, +} + +impl BytesLikeIndexMap { + /// Try creating a new lookup table from the given literals and else index + /// The index of each literal in the vector is used as the mapped value in the lookup table. + /// + /// `literals` are guaranteed to be unique and non-nullable + pub(super) fn try_new( + unique_non_null_literals: Vec, + ) -> datafusion_common::Result { + let input = ScalarValue::iter_to_array(unique_non_null_literals)?; + + // Literals are guaranteed to not contain nulls + if input.logical_null_count() > 0 { + return internal_err!("Literal values for WHEN clauses cannot contain nulls"); + } + + let map: HashMap, u32> = try_get_bytes_iterator(&input)? + // Flattening Option<&[u8]> to &[u8] as literals cannot contain nulls + .flatten() + .enumerate() + .map(|(map_index, value)| (value.to_vec(), map_index as u32)) + // Because literals are unique we can collect directly, and we can avoid only inserting the first occurrence + .collect(); + + Ok(Self { map }) + } +} + +impl WhenLiteralIndexMap for BytesLikeIndexMap { + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result> { + let indices = try_get_bytes_iterator(array)? + .map(|value| match value { + Some(value) => self.map.get(value).copied().unwrap_or(else_index), + None => else_index, + }) + .collect::>(); + + Ok(indices) + } +} + +fn try_get_bytes_iterator( + array: &ArrayRef, +) -> datafusion_common::Result> + '_>> { + Ok(match array.data_type() { + DataType::Utf8 => Box::new(array.as_string::().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })), + + DataType::LargeUtf8 => { + Box::new(array.as_string::().into_iter().map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + })) + } + + DataType::Binary => Box::new(array.as_binary::().into_iter()), + + DataType::LargeBinary => Box::new(array.as_binary::().into_iter()), + + DataType::FixedSizeBinary(_) => Box::new(array.as_binary::().into_iter()), + + DataType::Utf8View => Box::new( + array + .as_byte_view::() + .into_iter() + .map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + }), + ), + DataType::BinaryView => { + Box::new(array.as_byte_view::().into_iter()) + } + + DataType::Dictionary(key, _) => { + macro_rules! downcast_dictionary_array_helper { + ($t:ty) => {{ + get_bytes_iterator_for_dictionary(array.as_dictionary::<$t>())? + }}; + } + + downcast_integer! { + key.as_ref() => (downcast_dictionary_array_helper), + k => unreachable!("unsupported dictionary key type: {}", k) + } + } + t => { + return Err(plan_datafusion_err!( + "Unsupported data type for bytes lookup table: {}", + t + )) + } + }) +} + +fn get_bytes_iterator_for_dictionary( + array: &DictionaryArray, +) -> datafusion_common::Result> + '_>> { + Ok(match array.values().data_type() { + DataType::Utf8 => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter() + .map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + }), + ), + + DataType::LargeUtf8 => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter() + .map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + }), + ), + + DataType::Binary => { + Box::new(array.downcast_dict::().unwrap().into_iter()) + } + + DataType::LargeBinary => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter(), + ), + + DataType::FixedSizeBinary(_) => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter(), + ), + + DataType::Utf8View => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter() + .map(|item| { + item.map(|v| { + let bytes: &[u8] = v.as_ref(); + + bytes + }) + }), + ), + DataType::BinaryView => Box::new( + array + .downcast_dict::() + .unwrap() + .into_iter(), + ), + + t => { + return Err(plan_datafusion_err!( + "Unsupported data type for lookup table dictionary value: {}", + t + )) + } + }) +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs new file mode 100644 index 000000000000..597a4882672c --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/mod.rs @@ -0,0 +1,328 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod boolean_lookup_table; +mod bytes_like_lookup_table; +mod primitive_lookup_table; + +use crate::expressions::case::literal_lookup_table::boolean_lookup_table::BooleanIndexMap; +use crate::expressions::case::literal_lookup_table::bytes_like_lookup_table::BytesLikeIndexMap; +use crate::expressions::case::literal_lookup_table::primitive_lookup_table::PrimitiveIndexMap; +use crate::expressions::case::CaseBody; +use crate::expressions::Literal; +use arrow::array::{downcast_primitive, Array, ArrayRef, UInt32Array}; +use arrow::datatypes::DataType; +use datafusion_common::DataFusionError; +use datafusion_common::{arrow_datafusion_err, plan_datafusion_err, ScalarValue}; +use indexmap::IndexMap; +use std::fmt::Debug; + +/// Optimization for CASE expressions with literal WHEN and THEN clauses +/// +/// for this form: +/// ```sql +/// CASE +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// ELSE +/// END +/// ``` +/// +/// # Improvement idea +/// TODO - we should think of unwrapping the `IN` expressions into multiple equality comparisons +/// so it will use this optimization as well, e.g. +/// ```sql +/// -- Before +/// CASE +/// WHEN ( = ) THEN +/// WHEN ( in (, ) THEN +/// WHEN ( = ) THEN +/// ELSE +/// +/// -- After +/// CASE +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// WHEN ( = ) THEN +/// ELSE +/// END +/// ``` +/// +#[derive(Debug)] +pub(in super::super) struct LiteralLookupTable { + /// The lookup table to use for evaluating the CASE expression + lookup: Box, + + else_index: u32, + + /// [`ArrayRef`] where `array[i] = then_literals[i]` + /// the last value in the array is the else_expr + /// + /// This will be used to take from based on the indices returned by the lookup table to build the final output + then_and_else_values: ArrayRef, +} + +impl LiteralLookupTable { + pub(in super::super) fn maybe_new(body: &CaseBody) -> Option { + // We can't use the optimization if we don't have any when then pairs + if body.when_then_expr.is_empty() { + return None; + } + + // If we only have 1 than this optimization is not useful + if body.when_then_expr.len() == 1 { + return None; + } + + // Try to downcast all the WHEN/THEN expressions to literals + let when_then_exprs_maybe_literals = body + .when_then_expr + .iter() + .map(|(when, then)| { + let when_maybe_literal = when.as_any().downcast_ref::(); + let then_maybe_literal = then.as_any().downcast_ref::(); + + when_maybe_literal.zip(then_maybe_literal) + }) + .collect::>(); + + // If not all the WHEN/THEN expressions are literals we cannot use this optimization + if when_then_exprs_maybe_literals.contains(&None) { + return None; + } + + let when_then_exprs_scalars = when_then_exprs_maybe_literals + .into_iter() + // Unwrap the options as we have already checked there is no None + .flatten() + .map(|(when_lit, then_lit)| { + (when_lit.value().clone(), then_lit.value().clone()) + }) + // Only keep non-null WHEN literals + // as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE + .filter(|(when_lit, _)| !when_lit.is_null()) + .collect::>(); + + if when_then_exprs_scalars.is_empty() { + // All WHEN literals were nulls, so cannot use optimization + // + // instead, another optimization would be to go straight to the ELSE clause + return None; + } + + // Keep only the first occurrence of each when literal (as the first match is used) + // and remove nulls (as they cannot be matched - case NULL WHEN NULL THEN ... ELSE ... END always goes to ELSE) + let (when, then): (Vec, Vec) = { + let mut map = IndexMap::with_capacity(body.when_then_expr.len()); + + for (when, then) in when_then_exprs_scalars.into_iter() { + // Don't overwrite existing entries as we want to keep the first occurrence + if !map.contains_key(&when) { + map.insert(when, then); + } + } + + map.into_iter().unzip() + }; + + let else_value: ScalarValue = if let Some(else_expr) = &body.else_expr { + let literal = else_expr.as_any().downcast_ref::()?; + + literal.value().clone() + } else { + let Ok(null_scalar) = ScalarValue::try_new_null(&then[0].data_type()) else { + return None; + }; + + null_scalar + }; + + { + let when_data_type = when[0].data_type(); + + // If not all the WHEN literals are the same data type we cannot use this optimization + if when.iter().any(|l| l.data_type() != when_data_type) { + return None; + } + } + + { + let data_type = then[0].data_type(); + + // If not all the then and the else literals are the same data type we cannot use this optimization + if then.iter().any(|l| l.data_type() != data_type) { + return None; + } + + if else_value.data_type() != data_type { + return None; + } + } + + let then_and_else_values = ScalarValue::iter_to_array( + then.iter() + // The else is in the end + .chain(std::iter::once(&else_value)) + .cloned(), + ) + .ok()?; + // The else expression is in the end + let else_index = then_and_else_values.len() as u32 - 1; + + let lookup = try_creating_lookup_table(when).ok()?; + + Some(Self { + lookup, + then_and_else_values, + else_index, + }) + } + + pub(in super::super) fn map_keys_to_values( + &self, + keys_array: &ArrayRef, + ) -> datafusion_common::Result { + let take_indices = self + .lookup + .map_to_when_indices(keys_array, self.else_index)?; + + // Zero-copy conversion + let take_indices = UInt32Array::from(take_indices); + + // An optimize version would depend on the type of the values_to_take_from + // For example, if the type is view we can just keep pointing to the same value (similar to dictionary) + // if the type is dictionary we can just use the indices as is (or cast them to the key type) and create a new dictionary array + let output = + arrow::compute::take(&self.then_and_else_values, &take_indices, None) + .map_err(|e| arrow_datafusion_err!(e))?; + + Ok(output) + } +} + +/// Map values that match the WHEN literal to the index of their corresponding WHEN clause +/// +/// For example, for this CASE expression: +/// +/// ```sql +/// CASE +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// WHEN THEN +/// ELSE +/// END +/// ``` +/// +/// this will map to 0, to 1, to 2, to 3 +pub(super) trait WhenLiteralIndexMap: Debug + Send + Sync { + /// Given an array of values, returns a vector of WHEN clause indices corresponding to each value in the provided array. + /// + /// For example, for this CASE expression: + /// + /// ```sql + /// CASE + /// WHEN THEN + /// WHEN THEN + /// WHEN THEN + /// WHEN THEN + /// ELSE + /// END + /// ``` + /// + /// the array will be the evaluated values of `` + /// and if that array is: + /// - `[, , , , ]` + /// + /// the returned vector will be: + /// - `[0, 2, else_index, 1, 0]` + /// + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result>; +} + +fn try_creating_lookup_table( + unique_non_null_literals: Vec, +) -> datafusion_common::Result> { + assert_ne!( + unique_non_null_literals.len(), + 0, + "Must have at least one literal" + ); + match unique_non_null_literals[0].data_type() { + DataType::Boolean => { + let lookup_table = BooleanIndexMap::try_new(unique_non_null_literals)?; + Ok(Box::new(lookup_table)) + } + + data_type if data_type.is_primitive() => { + macro_rules! create_matching_map { + ($t:ty) => {{ + let lookup_table = + PrimitiveIndexMap::<$t>::try_new(unique_non_null_literals)?; + Ok(Box::new(lookup_table)) + }}; + } + + downcast_primitive! { + data_type => (create_matching_map), + _ => Err(plan_datafusion_err!( + "Unsupported field type for primitive: {:?}", + data_type + )), + } + } + + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::FixedSizeBinary(_) + | DataType::Utf8View + | DataType::BinaryView => { + let lookup_table = BytesLikeIndexMap::try_new(unique_non_null_literals)?; + Ok(Box::new(lookup_table)) + } + + DataType::Dictionary(_key, value) + if matches!( + value.as_ref(), + DataType::Utf8 + | DataType::LargeUtf8 + | DataType::Binary + | DataType::LargeBinary + | DataType::FixedSizeBinary(_) + | DataType::Utf8View + | DataType::BinaryView + ) => + { + let lookup_table = BytesLikeIndexMap::try_new(unique_non_null_literals)?; + Ok(Box::new(lookup_table)) + } + + _ => Err(plan_datafusion_err!( + "Unsupported data type for lookup table: {}", + unique_non_null_literals[0].data_type() + )), + } +} diff --git a/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs new file mode 100644 index 000000000000..46f3fc3a86ac --- /dev/null +++ b/datafusion/physical-expr/src/expressions/case/literal_lookup_table/primitive_lookup_table.rs @@ -0,0 +1,193 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::expressions::case::literal_lookup_table::WhenLiteralIndexMap; +use arrow::array::{Array, ArrayRef, ArrowNativeTypeOp, ArrowPrimitiveType, AsArray}; +use arrow::datatypes::{i256, IntervalDayTime, IntervalMonthDayNano}; +use datafusion_common::{internal_err, HashMap, ScalarValue}; +use half::f16; +use std::fmt::Debug; +use std::hash::Hash; + +#[derive(Clone)] +pub(super) struct PrimitiveIndexMap +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + /// Literal value to map index + /// + /// If searching this map becomes a bottleneck consider using linear map implementations for small hashmaps + map: HashMap<::HashableKey, u32>, +} + +impl Debug for PrimitiveIndexMap +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PrimitiveIndexMap") + .field("map", &self.map) + .finish() + } +} + +impl PrimitiveIndexMap +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + /// Try creating a new lookup table from the given literals and else index. + /// The index of each literal in the vector is used as the mapped value in the lookup table. + /// + /// `literals` are guaranteed to be unique and non-nullable + pub(super) fn try_new( + unique_non_null_literals: Vec, + ) -> datafusion_common::Result { + let input = ScalarValue::iter_to_array(unique_non_null_literals)?; + + // Literals are guaranteed to not contain nulls + if input.null_count() > 0 { + return internal_err!("Literal values for WHEN clauses cannot contain nulls"); + } + + let map = input + .as_primitive::() + .values() + .iter() + .enumerate() + // Because literals are unique we can collect directly, and we can avoid only inserting the first occurrence + .map(|(map_index, value)| (value.into_hashable_key(), map_index as u32)) + .collect(); + + Ok(Self { map }) + } +} + +impl WhenLiteralIndexMap for PrimitiveIndexMap +where + T: ArrowPrimitiveType, + T::Native: ToHashableKey, +{ + fn map_to_when_indices( + &self, + array: &ArrayRef, + else_index: u32, + ) -> datafusion_common::Result> { + let indices = array + .as_primitive::() + .into_iter() + .map(|value| match value { + Some(value) => self + .map + .get(&value.into_hashable_key()) + .copied() + .unwrap_or(else_index), + + None => else_index, + }) + .collect::>(); + + Ok(indices) + } +} + +// TODO - We need to port it to arrow so that it can be reused in other places + +/// Trait that help convert a value to a key that is hashable and equatable +/// This is needed as some types like f16/f32/f64 do not implement Hash/Eq directly +pub(super) trait ToHashableKey: ArrowNativeTypeOp { + /// The type that is hashable and equatable + /// It must be an Arrow native type but it NOT GUARANTEED to be the same as Self + /// this is just a helper trait so you can reuse the same code for all arrow native types + type HashableKey: Hash + Eq + Debug + Clone + Copy + Send + Sync; + + /// Converts self to a hashable key + /// the result of this value can be used as the key in hash maps/sets + fn into_hashable_key(self) -> Self::HashableKey; +} + +macro_rules! impl_to_hashable_key { + (@single_already_hashable | $t:ty) => { + impl ToHashableKey for $t { + type HashableKey = $t; + + #[inline] + fn into_hashable_key(self) -> Self::HashableKey { + self + } + } + }; + (@already_hashable | $($t:ty),+ $(,)?) => { + $( + impl_to_hashable_key!(@single_already_hashable | $t); + )+ + }; + (@float | $t:ty => $hashable:ty) => { + impl ToHashableKey for $t { + type HashableKey = $hashable; + + #[inline] + fn into_hashable_key(self) -> Self::HashableKey { + self.to_bits() + } + } + }; +} + +impl_to_hashable_key!(@already_hashable | i8, i16, i32, i64, i128, i256, u8, u16, u32, u64, IntervalDayTime, IntervalMonthDayNano); +impl_to_hashable_key!(@float | f16 => u16); +impl_to_hashable_key!(@float | f32 => u32); +impl_to_hashable_key!(@float | f64 => u64); + +#[cfg(test)] +mod tests { + use super::ToHashableKey; + use arrow::array::downcast_primitive; + + // This test ensure that all arrow primitive types implement ToHashableKey + // otherwise the code will not compile + #[test] + fn should_implement_to_hashable_key_for_all_primitives() { + #[derive(Debug, Default)] + struct ExampleSet + where + T: arrow::datatypes::ArrowPrimitiveType, + T::Native: ToHashableKey, + { + _map: std::collections::HashSet<::HashableKey>, + } + + macro_rules! create_matching_set { + ($t:ty) => {{ + let _lookup_table = ExampleSet::<$t> { + _map: Default::default(), + }; + + return; + }}; + } + + let data_type = arrow::datatypes::DataType::Float16; + + downcast_primitive! { + data_type => (create_matching_set), + _ => panic!("not implemented for {data_type}"), + } + } +}