diff --git a/datafusion/functions-nested/src/array_math.rs b/datafusion/functions-nested/src/array_math.rs new file mode 100644 index 0000000000000..1a10ba5bb8b30 --- /dev/null +++ b/datafusion/functions-nested/src/array_math.rs @@ -0,0 +1,790 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [ScalarUDFImpl] definitions for array_add, array_subtract, and array_scale functions. 
+ +use crate::utils::make_scalar_function; +use crate::vector_math::convert_to_f64_array; +use arrow::array::{ + Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait, +}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{ + DataType, Field, + DataType::{FixedSizeList, LargeList, List, Null}, +}; +use datafusion_common::cast::{as_float64_array, as_generic_list_array}; +use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; +use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_functions::downcast_arg; +use datafusion_macros::user_doc; +use itertools::Itertools; +use std::sync::Arc; + +// ============================================================================ +// array_add +// ============================================================================ + +make_udf_expr_and_func!( + ArrayAdd, + array_add, + array1 array2, + "returns element-wise addition of two numeric arrays.", + array_add_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns a new array with element-wise addition of two input arrays of equal length.", + syntax_example = "array_add(array1, array2)", + sql_example = r#"```sql +> select array_add([1.0, 2.0], [3.0, 4.0]); ++------------------------------------------+ +| array_add(List([1.0,2.0]), List([3.0,4.0])) | ++------------------------------------------+ +| [4.0, 6.0] | ++------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayAdd { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayAdd { + fn default() -> Self { + Self::new() + } +} + +impl ArrayAdd { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_add".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayAdd { + fn name(&self) -> &str { + "array_add" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + List(_) | FixedSizeList(..) | Null => Ok(List(Arc::new( + Field::new_list_field(DataType::Float64, true), + ))), + LargeList(_) => Ok(LargeList(Arc::new(Field::new_list_field( + DataType::Float64, + true, + )))), + _ => exec_err!("array_add does not support type {}", arg_types[0]), + } + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [_, _] = take_function_args(self.name(), arg_types)?; + let coercion = Some(&ListCoercion::FixedSizedListToList); + let arg_types = arg_types.iter().map(|arg_type| { + if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + Ok(coerced_type_with_base_type_only( + arg_type, + &DataType::Float64, + coercion, + )) + } else { + plan_err!("{} does not support type {arg_type}", self.name()) + } + }); + + arg_types.try_collect() + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_add_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_add_inner(args: &[ArrayRef]) -> Result { + let [array1, array2] = take_function_args("array_add", args)?; + match (array1.data_type(), array2.data_type()) { + (List(_), List(_)) => general_array_binary_op::(args, "array_add", |a, b| a + b), + (LargeList(_), LargeList(_)) => { + general_array_binary_op::(args, "array_add", |a, b| 
a + b) + } + (arg_type1, arg_type2) => { + exec_err!("array_add does not support types {arg_type1} and {arg_type2}") + } + } +} + +// ============================================================================ +// array_subtract +// ============================================================================ + +make_udf_expr_and_func!( + ArraySubtract, + array_subtract, + array1 array2, + "returns element-wise subtraction of two numeric arrays.", + array_subtract_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns a new array with element-wise subtraction of two input arrays of equal length.", + syntax_example = "array_subtract(array1, array2)", + sql_example = r#"```sql +> select array_subtract([5.0, 3.0], [1.0, 2.0]); ++-----------------------------------------------+ +| array_subtract(List([5.0,3.0]), List([1.0,2.0])) | ++-----------------------------------------------+ +| [4.0, 1.0] | ++-----------------------------------------------+ +```"#, + argument( + name = "array1", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "array2", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArraySubtract { + signature: Signature, + aliases: Vec, +} + +impl Default for ArraySubtract { + fn default() -> Self { + Self::new() + } +} + +impl ArraySubtract { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_subtract".to_string()], + } + } +} + +impl ScalarUDFImpl for ArraySubtract { + fn name(&self) -> &str { + "array_subtract" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + List(_) | FixedSizeList(..) 
| Null => Ok(List(Arc::new( + Field::new_list_field(DataType::Float64, true), + ))), + LargeList(_) => Ok(LargeList(Arc::new(Field::new_list_field( + DataType::Float64, + true, + )))), + _ => exec_err!( + "array_subtract does not support type {}", + arg_types[0] + ), + } + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [_, _] = take_function_args(self.name(), arg_types)?; + let coercion = Some(&ListCoercion::FixedSizedListToList); + let arg_types = arg_types.iter().map(|arg_type| { + if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + Ok(coerced_type_with_base_type_only( + arg_type, + &DataType::Float64, + coercion, + )) + } else { + plan_err!("{} does not support type {arg_type}", self.name()) + } + }); + + arg_types.try_collect() + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_subtract_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_subtract_inner(args: &[ArrayRef]) -> Result { + let [array1, array2] = take_function_args("array_subtract", args)?; + match (array1.data_type(), array2.data_type()) { + (List(_), List(_)) => { + general_array_binary_op::(args, "array_subtract", |a, b| a - b) + } + (LargeList(_), LargeList(_)) => { + general_array_binary_op::(args, "array_subtract", |a, b| a - b) + } + (arg_type1, arg_type2) => { + exec_err!( + "array_subtract does not support types {arg_type1} and {arg_type2}" + ) + } + } +} + +// ============================================================================ +// array_scale +// ============================================================================ + +make_udf_expr_and_func!( + ArrayScale, + array_scale, + array scalar, + "returns array with each element multiplied by a scalar.", + array_scale_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns a new array 
with each element multiplied by the given scalar value.", + syntax_example = "array_scale(array, scalar)", + sql_example = r#"```sql +> select array_scale([1.0, 2.0, 3.0], 2.0); ++-------------------------------------------+ +| array_scale(List([1.0,2.0,3.0]), 2.0) | ++-------------------------------------------+ +| [2.0, 4.0, 6.0] | ++-------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument( + name = "scalar", + description = "Float64 scalar value to multiply each element by." + ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayScale { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayScale { + fn default() -> Self { + Self::new() + } +} + +impl ArrayScale { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_scale".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayScale { + fn name(&self) -> &str { + "array_scale" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + match &arg_types[0] { + List(_) | FixedSizeList(..) | Null => Ok(List(Arc::new( + Field::new_list_field(DataType::Float64, true), + ))), + LargeList(_) => Ok(LargeList(Arc::new(Field::new_list_field( + DataType::Float64, + true, + )))), + _ => exec_err!("array_scale does not support type {}", arg_types[0]), + } + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [_, _] = take_function_args(self.name(), arg_types)?; + + let coercion = Some(&ListCoercion::FixedSizedListToList); + let first = &arg_types[0]; + let coerced_first = if matches!( + first, + Null | List(_) | LargeList(_) | FixedSizeList(..) 
+ ) { + coerced_type_with_base_type_only(first, &DataType::Float64, coercion) + } else { + return plan_err!("{} does not support type {first}", self.name()); + }; + + // Second argument is scalar Float64 + Ok(vec![coerced_first, DataType::Float64]) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_scale_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_scale_inner(args: &[ArrayRef]) -> Result { + let [array1, _scalar] = take_function_args("array_scale", args)?; + match array1.data_type() { + List(_) => general_array_scale::(args), + LargeList(_) => general_array_scale::(args), + arg_type => { + exec_err!("array_scale does not support type {arg_type}") + } + } +} + +fn general_array_scale(arrays: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&arrays[0])?; + let scalar_array = as_float64_array(&arrays[1])?; + + let mut result_values: Vec> = Vec::new(); + let mut offsets: Vec = vec![O::from_usize(0).unwrap()]; + let mut nulls: Vec = Vec::new(); + + for (i, arr) in list_array.iter().enumerate() { + let scalar_val = if scalar_array.is_null(i) { + None + } else { + Some(scalar_array.value(i)) + }; + + match compute_scale(arr, scalar_val)? 
{ + Some(scaled) => { + for j in 0..scaled.len() { + if scaled.is_null(j) { + result_values.push(None); + } else { + result_values.push(Some(scaled.value(j))); + } + } + offsets.push(O::from_usize(result_values.len()).unwrap()); + nulls.push(true); + } + None => { + offsets.push(O::from_usize(result_values.len()).unwrap()); + nulls.push(false); + } + } + } + + let values_array = Arc::new(Float64Array::from(result_values)) as ArrayRef; + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + let offset_buffer = OffsetBuffer::new(offsets.into()); + let null_buffer = arrow::buffer::NullBuffer::from(nulls); + + Ok(Arc::new(arrow::array::GenericListArray::::new( + field, + offset_buffer, + values_array, + Some(null_buffer), + )) as ArrayRef) +} + +fn compute_scale( + arr: Option, + scalar: Option, +) -> Result> { + let value = match arr { + Some(arr) => arr, + None => return Ok(None), + }; + let scalar = match scalar { + Some(s) => s, + None => return Ok(None), + }; + + let mut value = value; + + loop { + match value.data_type() { + List(_) => { + if downcast_arg!(value, ListArray).null_count() > 0 { + return Ok(None); + } + value = downcast_arg!(value, ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value, LargeListArray).null_count() > 0 { + return Ok(None); + } + value = downcast_arg!(value, LargeListArray).value(0); + } + _ => break, + } + } + + if value.null_count() != 0 { + return Ok(None); + } + + let values = convert_to_f64_array(&value)?; + let scaled: Float64Array = values + .iter() + .map(|v| v.map(|val| val * scalar)) + .collect(); + + Ok(Some(scaled)) +} + +// ============================================================================ +// Shared binary operation helper for array_add / array_subtract +// ============================================================================ + +fn general_array_binary_op( + arrays: &[ArrayRef], + fn_name: &str, + op: fn(f64, f64) -> f64, +) -> Result { + let list_array1 = 
as_generic_list_array::(&arrays[0])?; + let list_array2 = as_generic_list_array::(&arrays[1])?; + + let mut result_values: Vec> = Vec::new(); + let mut offsets: Vec = vec![O::from_usize(0).unwrap()]; + let mut nulls: Vec = Vec::new(); + + for (arr1, arr2) in list_array1.iter().zip(list_array2.iter()) { + match compute_binary_op(arr1, arr2, fn_name, op)? { + Some(result) => { + for i in 0..result.len() { + if result.is_null(i) { + result_values.push(None); + } else { + result_values.push(Some(result.value(i))); + } + } + offsets.push(O::from_usize(result_values.len()).unwrap()); + nulls.push(true); + } + None => { + offsets.push(O::from_usize(result_values.len()).unwrap()); + nulls.push(false); + } + } + } + + let values_array = Arc::new(Float64Array::from(result_values)) as ArrayRef; + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + let offset_buffer = OffsetBuffer::new(offsets.into()); + let null_buffer = arrow::buffer::NullBuffer::from(nulls); + + Ok(Arc::new(arrow::array::GenericListArray::::new( + field, + offset_buffer, + values_array, + Some(null_buffer), + )) as ArrayRef) +} + +fn compute_binary_op( + arr1: Option, + arr2: Option, + _fn_name: &str, + op: fn(f64, f64) -> f64, +) -> Result> { + let value1 = match arr1 { + Some(arr) => arr, + None => return Ok(None), + }; + let value2 = match arr2 { + Some(arr) => arr, + None => return Ok(None), + }; + + let mut value1 = value1; + let mut value2 = value2; + + loop { + match value1.data_type() { + List(_) => { + if downcast_arg!(value1, ListArray).null_count() > 0 { + return Ok(None); + } + value1 = downcast_arg!(value1, ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value1, LargeListArray).null_count() > 0 { + return Ok(None); + } + value1 = downcast_arg!(value1, LargeListArray).value(0); + } + _ => break, + } + + match value2.data_type() { + List(_) => { + if downcast_arg!(value2, ListArray).null_count() > 0 { + return Ok(None); + } + value2 = downcast_arg!(value2, 
ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value2, LargeListArray).null_count() > 0 { + return Ok(None); + } + value2 = downcast_arg!(value2, LargeListArray).value(0); + } + _ => break, + } + } + + if value1.null_count() != 0 || value2.null_count() != 0 { + return Ok(None); + } + + let values1 = convert_to_f64_array(&value1)?; + let values2 = convert_to_f64_array(&value2)?; + + if values1.len() != values2.len() { + return exec_err!("Both arrays must have the same length"); + } + + let result: Float64Array = values1 + .iter() + .zip(values2.iter()) + .map(|(v1, v2)| match (v1, v2) { + (Some(a), Some(b)) => Some(op(a, b)), + _ => None, + }) + .collect(); + + Ok(Some(result)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Float64Array, ListArray}; + use arrow::buffer::OffsetBuffer; + use arrow::datatypes::{DataType, Field}; + use std::sync::Arc; + + fn make_f64_list_array(values: Vec>>>) -> ArrayRef { + let mut flat: Vec> = Vec::new(); + let mut offsets: Vec = vec![0]; + for v in &values { + match v { + Some(inner) => { + flat.extend(inner); + offsets.push(flat.len() as i32); + } + None => { + offsets.push(flat.len() as i32); + } + } + } + let values_array = Arc::new(Float64Array::from(flat)) as ArrayRef; + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + let offset_buffer = OffsetBuffer::new(offsets.into()); + let null_buffer = arrow::buffer::NullBuffer::from( + values.iter().map(|v| v.is_some()).collect::>(), + ); + Arc::new(ListArray::new( + field, + offset_buffer, + values_array, + Some(null_buffer), + )) + } + + fn make_f64_scalar_array(values: Vec>) -> ArrayRef { + Arc::new(Float64Array::from(values)) as ArrayRef + } + + // === array_add tests === + + #[test] + fn test_array_add_basic() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(3.0), Some(4.0)])]); + let result = array_add_inner(&[arr1, arr2]).unwrap(); + let 
list = result.as_any().downcast_ref::().unwrap(); + let inner = list.value(0); + let values = inner.as_any().downcast_ref::().unwrap(); + assert!((values.value(0) - 4.0).abs() < 1e-10); + assert!((values.value(1) - 6.0).abs() < 1e-10); + } + + #[test] + fn test_array_add_null() { + let arr1 = make_f64_list_array(vec![None]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0)])]); + let result = array_add_inner(&[arr1, arr2]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + assert!(list.is_null(0)); + } + + #[test] + fn test_array_add_mismatched_lengths() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0)])]); + let result = array_add_inner(&[arr1, arr2]); + assert!(result.is_err()); + } + + #[test] + fn test_array_add_empty() { + let arr1 = make_f64_list_array(vec![Some(vec![])]); + let arr2 = make_f64_list_array(vec![Some(vec![])]); + let result = array_add_inner(&[arr1, arr2]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + assert!(!list.is_null(0)); + assert_eq!(list.value(0).len(), 0); + } + + // === array_subtract tests === + + #[test] + fn test_array_subtract_basic() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(5.0), Some(3.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); + let result = array_subtract_inner(&[arr1, arr2]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + let inner = list.value(0); + let values = inner.as_any().downcast_ref::().unwrap(); + assert!((values.value(0) - 4.0).abs() < 1e-10); + assert!((values.value(1) - 1.0).abs() < 1e-10); + } + + #[test] + fn test_array_subtract_null() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0)])]); + let arr2 = make_f64_list_array(vec![None]); + let result = array_subtract_inner(&[arr1, arr2]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + assert!(list.is_null(0)); + } + + #[test] + fn 
test_array_subtract_mismatched_lengths() { + let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0)])]); + let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); + let result = array_subtract_inner(&[arr1, arr2]); + assert!(result.is_err()); + } + + // === array_scale tests === + + #[test] + fn test_array_scale_basic() { + let arr = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0), Some(3.0)])]); + let scalar = make_f64_scalar_array(vec![Some(2.0)]); + let result = array_scale_inner(&[arr, scalar]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + let inner = list.value(0); + let values = inner.as_any().downcast_ref::().unwrap(); + assert!((values.value(0) - 2.0).abs() < 1e-10); + assert!((values.value(1) - 4.0).abs() < 1e-10); + assert!((values.value(2) - 6.0).abs() < 1e-10); + } + + #[test] + fn test_array_scale_null_array() { + let arr = make_f64_list_array(vec![None]); + let scalar = make_f64_scalar_array(vec![Some(2.0)]); + let result = array_scale_inner(&[arr, scalar]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + assert!(list.is_null(0)); + } + + #[test] + fn test_array_scale_null_scalar() { + let arr = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); + let scalar = make_f64_scalar_array(vec![None]); + let result = array_scale_inner(&[arr, scalar]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + assert!(list.is_null(0)); + } + + #[test] + fn test_array_scale_zero() { + let arr = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]); + let scalar = make_f64_scalar_array(vec![Some(0.0)]); + let result = array_scale_inner(&[arr, scalar]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + let inner = list.value(0); + let values = inner.as_any().downcast_ref::().unwrap(); + assert!((values.value(0) - 0.0).abs() < 1e-10); + assert!((values.value(1) - 0.0).abs() < 1e-10); + } + + #[test] + fn test_array_scale_empty() { + let arr = 
make_f64_list_array(vec![Some(vec![])]); + let scalar = make_f64_scalar_array(vec![Some(5.0)]); + let result = array_scale_inner(&[arr, scalar]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + assert!(!list.is_null(0)); + assert_eq!(list.value(0).len(), 0); + } +} diff --git a/datafusion/functions-nested/src/array_normalize.rs b/datafusion/functions-nested/src/array_normalize.rs new file mode 100644 index 0000000000000..7a5a8c0d758c2 --- /dev/null +++ b/datafusion/functions-nested/src/array_normalize.rs @@ -0,0 +1,325 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [ScalarUDFImpl] definitions for array_normalize function. 
+ +use crate::utils::make_scalar_function; +use crate::vector_math::{convert_to_f64_array, magnitude_f64}; +use arrow::array::{ + Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait, +}; +use arrow::buffer::OffsetBuffer; +use arrow::datatypes::{ + DataType, Field, + DataType::{FixedSizeList, LargeList, List, Null}, +}; +use datafusion_common::cast::as_generic_list_array; +use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only}; +use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args}; +use datafusion_expr::{ + ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, + Volatility, +}; +use datafusion_functions::downcast_arg; +use datafusion_macros::user_doc; +use itertools::Itertools; +use std::sync::Arc; + +make_udf_expr_and_func!( + ArrayNormalize, + array_normalize, + array, + "returns the array normalized to unit length (L2 norm).", + array_normalize_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "Returns the input array normalized to unit length (each element divided by the L2 norm). Returns NULL if the array has zero magnitude.", + syntax_example = "array_normalize(array)", + sql_example = r#"```sql +> select array_normalize([3.0, 4.0]); ++-----------------------------------+ +| array_normalize(List([3.0,4.0])) | ++-----------------------------------+ +| [0.6, 0.8] | ++-----------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." 
+ ) +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayNormalize { + signature: Signature, + aliases: Vec, +} + +impl Default for ArrayNormalize { + fn default() -> Self { + Self::new() + } +} + +impl ArrayNormalize { + pub fn new() -> Self { + Self { + signature: Signature::user_defined(Volatility::Immutable), + aliases: vec!["list_normalize".to_string()], + } + } +} + +impl ScalarUDFImpl for ArrayNormalize { + fn name(&self) -> &str { + "array_normalize" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + // Return same list type but with Float64 elements + match &arg_types[0] { + List(_) | FixedSizeList(..) | Null => Ok(List(Arc::new(Field::new_list_field( + DataType::Float64, + true, + )))), + LargeList(_) => Ok(LargeList(Arc::new(Field::new_list_field( + DataType::Float64, + true, + )))), + _ => exec_err!( + "array_normalize does not support type {}", + arg_types[0] + ), + } + } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + let [_] = take_function_args(self.name(), arg_types)?; + let coercion = Some(&ListCoercion::FixedSizedListToList); + let arg_types = arg_types.iter().map(|arg_type| { + if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) { + Ok(coerced_type_with_base_type_only( + arg_type, + &DataType::Float64, + coercion, + )) + } else { + plan_err!("{} does not support type {arg_type}", self.name()) + } + }); + + arg_types.try_collect() + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + make_scalar_function(array_normalize_inner)(&args.args) + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} + +fn array_normalize_inner(args: &[ArrayRef]) -> Result { + let [array1] = take_function_args("array_normalize", args)?; + match array1.data_type() { + List(_) => general_array_normalize::(args), + LargeList(_) => 
general_array_normalize::(args), + arg_type => { + exec_err!("array_normalize does not support type {arg_type}") + } + } +} + +fn general_array_normalize(arrays: &[ArrayRef]) -> Result { + let list_array = as_generic_list_array::(&arrays[0])?; + + let mut result_values: Vec> = Vec::new(); + let mut offsets: Vec = vec![O::from_usize(0).unwrap()]; + let mut nulls: Vec = Vec::new(); + + for arr in list_array.iter() { + match compute_normalize(arr)? { + Some(normalized) => { + for i in 0..normalized.len() { + if normalized.is_null(i) { + result_values.push(None); + } else { + result_values.push(Some(normalized.value(i))); + } + } + offsets.push(O::from_usize(result_values.len()).unwrap()); + nulls.push(true); + } + None => { + offsets.push(O::from_usize(result_values.len()).unwrap()); + nulls.push(false); + } + } + } + + let values_array = Arc::new(Float64Array::from(result_values)) as ArrayRef; + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + let offset_buffer = OffsetBuffer::new(offsets.into()); + let null_buffer = arrow::buffer::NullBuffer::from(nulls); + + Ok(Arc::new(arrow::array::GenericListArray::::new( + field, + offset_buffer, + values_array, + Some(null_buffer), + )) as ArrayRef) +} + +/// Normalizes an array to unit length by dividing each element by the L2 norm. 
+fn compute_normalize(arr: Option) -> Result> { + let value = match arr { + Some(arr) => arr, + None => return Ok(None), + }; + + let mut value = value; + + loop { + match value.data_type() { + List(_) => { + if downcast_arg!(value, ListArray).null_count() > 0 { + return Ok(None); + } + value = downcast_arg!(value, ListArray).value(0); + } + LargeList(_) => { + if downcast_arg!(value, LargeListArray).null_count() > 0 { + return Ok(None); + } + value = downcast_arg!(value, LargeListArray).value(0); + } + _ => break, + } + } + + if value.null_count() != 0 { + return Ok(None); + } + + let values = convert_to_f64_array(&value)?; + let mag = magnitude_f64(&values); + + if mag == 0.0 { + return Ok(None); + } + + let normalized: Float64Array = values + .iter() + .map(|v| v.map(|val| val / mag)) + .collect(); + + Ok(Some(normalized)) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Float64Array, ListArray}; + use arrow::buffer::OffsetBuffer; + use arrow::datatypes::{DataType, Field}; + use std::sync::Arc; + + fn make_f64_list_array(values: Vec>>>) -> ArrayRef { + let mut flat: Vec> = Vec::new(); + let mut offsets: Vec = vec![0]; + for v in &values { + match v { + Some(inner) => { + flat.extend(inner); + offsets.push(flat.len() as i32); + } + None => { + offsets.push(flat.len() as i32); + } + } + } + let values_array = Arc::new(Float64Array::from(flat)) as ArrayRef; + let field = Arc::new(Field::new_list_field(DataType::Float64, true)); + let offset_buffer = OffsetBuffer::new(offsets.into()); + let null_buffer = arrow::buffer::NullBuffer::from( + values.iter().map(|v| v.is_some()).collect::>(), + ); + Arc::new(ListArray::new( + field, + offset_buffer, + values_array, + Some(null_buffer), + )) + } + + #[test] + fn test_normalize_basic() { + let arr = make_f64_list_array(vec![Some(vec![Some(3.0), Some(4.0)])]); + let result = array_normalize_inner(&[arr]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + let inner = list.value(0); + let 
values = inner.as_any().downcast_ref::().unwrap(); + assert!((values.value(0) - 0.6).abs() < 1e-10); + assert!((values.value(1) - 0.8).abs() < 1e-10); + } + + #[test] + fn test_normalize_null_array() { + let arr = make_f64_list_array(vec![None]); + let result = array_normalize_inner(&[arr]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + assert!(list.is_null(0)); + } + + #[test] + fn test_normalize_zero_magnitude() { + let arr = make_f64_list_array(vec![Some(vec![Some(0.0), Some(0.0)])]); + let result = array_normalize_inner(&[arr]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + assert!(list.is_null(0)); + } + + #[test] + fn test_normalize_unit_vector() { + let arr = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]); + let result = array_normalize_inner(&[arr]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + let inner = list.value(0); + let values = inner.as_any().downcast_ref::().unwrap(); + assert!((values.value(0) - 1.0).abs() < 1e-10); + assert!((values.value(1) - 0.0).abs() < 1e-10); + } + + #[test] + fn test_normalize_empty_array() { + let arr = make_f64_list_array(vec![Some(vec![])]); + let result = array_normalize_inner(&[arr]).unwrap(); + let list = result.as_any().downcast_ref::().unwrap(); + // Zero magnitude -> null + assert!(list.is_null(0)); + } +} diff --git a/datafusion/functions-nested/src/cosine_distance.rs b/datafusion/functions-nested/src/cosine_distance.rs new file mode 100644 index 0000000000000..b31d9c74b597f --- /dev/null +++ b/datafusion/functions-nested/src/cosine_distance.rs @@ -0,0 +1,336 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [ScalarUDFImpl] definitions for cosine_distance function.
+
+use crate::utils::make_scalar_function;
+use crate::vector_math::{convert_to_f64_array, dot_product_f64, magnitude_f64};
+use arrow::array::{
+    Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
+};
+use arrow::datatypes::{
+    DataType,
+    DataType::{FixedSizeList, LargeList, List, Null},
+};
+use datafusion_common::cast::as_generic_list_array;
+use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
+use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    Volatility,
+};
+use datafusion_functions::downcast_arg;
+use datafusion_macros::user_doc;
+use itertools::Itertools;
+use std::sync::Arc;
+
+make_udf_expr_and_func!(
+    CosineDistance,
+    cosine_distance,
+    array1 array2,
+    "returns the cosine distance between two numeric arrays.",
+    cosine_distance_udf
+);
+
+#[user_doc(
+    doc_section(label = "Array Functions"),
+    description = "Returns the cosine distance between two input arrays of equal length. The cosine distance is defined as 1 - cosine_similarity, i.e. `1 - dot(a,b) / (||a|| * ||b||)`.",
+    syntax_example = "cosine_distance(array1, array2)",
+    sql_example = r#"```sql
+> select cosine_distance([1.0, 0.0], [0.0, 1.0]);
++-----------------------------------------------+
+| cosine_distance(List([1.0,0.0]), List([0.0,1.0])) |
++-----------------------------------------------+
+| 1.0 |
++-----------------------------------------------+
+```"#,
+    argument(
+        name = "array1",
+        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
+    ),
+    argument(
+        name = "array2",
+        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
+    )
+)]
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct CosineDistance {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for CosineDistance {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl CosineDistance {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec!["list_cosine_distance".to_string()],
+        }
+    }
+}
+
+impl ScalarUDFImpl for CosineDistance {
+    fn name(&self) -> &str {
+        "cosine_distance"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [_, _] = take_function_args(self.name(), arg_types)?;
+        let coercion = Some(&ListCoercion::FixedSizedListToList);
+        let arg_types = arg_types.iter().map(|arg_type| {
+            if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
+                Ok(coerced_type_with_base_type_only(
+                    arg_type,
+                    &DataType::Float64,
+                    coercion,
+                ))
+            } else {
+                plan_err!("{} does not support type {arg_type}", self.name())
+            }
+        });
+
+        arg_types.try_collect()
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        make_scalar_function(cosine_distance_inner)(&args.args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.doc()
+    }
+}
+
+fn cosine_distance_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let [array1, array2] = take_function_args("cosine_distance", args)?;
+    match (array1.data_type(), array2.data_type()) {
+        (List(_), List(_)) => general_cosine_distance::<i32>(args),
+        (LargeList(_), LargeList(_)) => general_cosine_distance::<i64>(args),
+        (arg_type1, arg_type2) => {
+            exec_err!(
+                "cosine_distance does not support types {arg_type1} and {arg_type2}"
+            )
+        }
+    }
+}
+
+fn general_cosine_distance<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
+    let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
+    let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
+
+    let result = list_array1
+        .iter()
+        .zip(list_array2.iter())
+        .map(|(arr1, arr2)| compute_cosine_distance(arr1, arr2))
+        .collect::<Result<Float64Array>>()?;
+
+    Ok(Arc::new(result) as ArrayRef)
+}
+
+/// Computes the cosine distance between two arrays: 1 - dot(a,b) / (||a|| * ||b||)
+fn compute_cosine_distance(
+    arr1: Option<ArrayRef>,
+    arr2: Option<ArrayRef>,
+) -> Result<Option<f64>> {
+    let value1 = match arr1 {
+        Some(arr) => arr,
+        None => return Ok(None),
+    };
+    let value2 = match arr2 {
+        Some(arr) => arr,
+        None => return Ok(None),
+    };
+
+    let mut value1 = value1;
+    let mut value2 = value2;
+
+    loop {
+        match value1.data_type() {
+            List(_) => {
+                if downcast_arg!(value1, ListArray).null_count() > 0 {
+                    return Ok(None);
+                }
+                value1 = downcast_arg!(value1, ListArray).value(0);
+            }
+            LargeList(_) => {
+                if downcast_arg!(value1, LargeListArray).null_count() > 0 {
+                    return Ok(None);
+                }
+                value1 = downcast_arg!(value1, LargeListArray).value(0);
+            }
+            _ => break,
+        }
+
+        match value2.data_type() {
+            List(_) => {
+                if downcast_arg!(value2, ListArray).null_count() > 0 {
+                    return Ok(None);
+                }
+                value2 = downcast_arg!(value2, ListArray).value(0);
+            }
+            LargeList(_) => {
+                if downcast_arg!(value2, LargeListArray).null_count() > 0 {
+                    return Ok(None);
+                }
+                value2 = downcast_arg!(value2, LargeListArray).value(0);
+            }
+            _ => break,
+        }
+    }
+
+    // Check for NULL values inside the arrays
+    if value1.null_count() != 0 || value2.null_count() != 0 {
+        return Ok(None);
+    }
+
+    let values1 = convert_to_f64_array(&value1)?;
+    let values2 = convert_to_f64_array(&value2)?;
+
+    if values1.len() != values2.len() {
+        return exec_err!("Both arrays must have the same length");
+    }
+
+    let dot = dot_product_f64(&values1, &values2);
+    let mag1 = magnitude_f64(&values1);
+    let mag2 = magnitude_f64(&values2);
+
+    if mag1 == 0.0 || mag2 == 0.0 {
+        return Ok(None);
+    }
+
+    Ok(Some(1.0 - dot / (mag1 * mag2)))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Float64Array, ListArray};
+    use arrow::buffer::OffsetBuffer;
+    use arrow::datatypes::{DataType, Field};
+    use std::sync::Arc;
+
+    fn make_f64_list_array(values: Vec<Option<Vec<Option<f64>>>>) -> ArrayRef {
+        let mut flat: Vec<Option<f64>> = Vec::new();
+        let mut offsets: Vec<i32> = vec![0];
+        for v in &values {
+            match v {
+                Some(inner) => {
+                    flat.extend(inner);
+                    offsets.push(flat.len() as i32);
+                }
+                None => {
+                    offsets.push(flat.len() as i32);
+                }
+            }
+        }
+        let values_array = Arc::new(Float64Array::from(flat)) as ArrayRef;
+        let field = Arc::new(Field::new_list_field(DataType::Float64, true));
+        let offset_buffer = OffsetBuffer::new(offsets.into());
+        let null_buffer = arrow::buffer::NullBuffer::from(
+            values.iter().map(|v| v.is_some()).collect::<Vec<bool>>(),
+        );
+        Arc::new(ListArray::new(
+            field,
+            offset_buffer,
+            values_array,
+            Some(null_buffer),
+        ))
+    }
+
+    #[test]
+    fn test_cosine_distance_orthogonal() {
+        let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]);
+        let arr2 = make_f64_list_array(vec![Some(vec![Some(0.0), Some(1.0)])]);
+        let result = cosine_distance_inner(&[arr1, arr2]).unwrap();
+        let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
+        assert!((result.value(0) - 1.0).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_cosine_distance_identical() {
+        let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0), Some(3.0)])]);
+        let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0), Some(3.0)])]);
+        let result = cosine_distance_inner(&[arr1, arr2]).unwrap();
+        let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
+        assert!(result.value(0).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_cosine_distance_opposite() {
+        let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]);
+        let arr2 = make_f64_list_array(vec![Some(vec![Some(-1.0), Some(0.0)])]);
+        let result = cosine_distance_inner(&[arr1, arr2]).unwrap();
+        let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
+        assert!((result.value(0) - 2.0).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_cosine_distance_null_array() {
+        let arr1 = make_f64_list_array(vec![None]);
+        let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]);
+        let result = cosine_distance_inner(&[arr1, arr2]).unwrap();
+        let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
+        assert!(result.is_null(0));
+    }
+
+    #[test]
+    fn test_cosine_distance_mismatched_lengths() {
+        let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]);
+        let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0)])]);
+        let result = cosine_distance_inner(&[arr1, arr2]);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_cosine_distance_empty_arrays() {
+        let arr1 = make_f64_list_array(vec![Some(vec![])]);
+        let arr2 = make_f64_list_array(vec![Some(vec![])]);
+        let result = cosine_distance_inner(&[arr1, arr2]).unwrap();
+        let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
+        // Zero magnitude -> null
+        assert!(result.is_null(0));
+    }
+
+    #[test]
+    fn test_cosine_distance_float32_coerced() {
+        // Float32 gets coerced to Float64 by coerce_types, so the inner function
+        // always receives Float64. This test confirms the pattern works.
+        let arr1 = make_f64_list_array(vec![Some(vec![Some(3.0), Some(4.0)])]);
+        let arr2 = make_f64_list_array(vec![Some(vec![Some(4.0), Some(3.0)])]);
+        let result = cosine_distance_inner(&[arr1, arr2]).unwrap();
+        let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
+        // cos_sim = (12+12) / (5*5) = 24/25 = 0.96, distance = 0.04
+        assert!((result.value(0) - 0.04).abs() < 1e-10);
+    }
+}
diff --git a/datafusion/functions-nested/src/distance.rs b/datafusion/functions-nested/src/distance.rs
index edf1806b66c2d..2d83651427f3d 100644
--- a/datafusion/functions-nested/src/distance.rs
+++ b/datafusion/functions-nested/src/distance.rs
@@ -18,6 +18,7 @@
 //! [ScalarUDFImpl] definitions for array_distance function.
 
 use crate::utils::make_scalar_function;
+use crate::vector_math::convert_to_f64_array;
 use arrow::array::{
     Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
 };
@@ -25,10 +26,7 @@ use arrow::datatypes::{
     DataType,
     DataType::{FixedSizeList, LargeList, List, Null},
 };
-use datafusion_common::cast::{
-    as_float32_array, as_float64_array, as_generic_list_array, as_int32_array,
-    as_int64_array,
-};
+use datafusion_common::cast::as_generic_list_array;
 use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
 use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args};
 use datafusion_expr::{
@@ -233,28 +231,4 @@ fn compute_array_distance(
     Ok(Some(sum_squares.sqrt()))
 }
 
-/// Converts an array of any numeric type to a Float64Array.
-fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> {
-    match array.data_type() {
-        DataType::Float64 => Ok(as_float64_array(array)?.clone()),
-        DataType::Float32 => {
-            let array = as_float32_array(array)?;
-            let converted: Float64Array =
-                array.iter().map(|v| v.map(|v| v as f64)).collect();
-            Ok(converted)
-        }
-        DataType::Int64 => {
-            let array = as_int64_array(array)?;
-            let converted: Float64Array =
-                array.iter().map(|v| v.map(|v| v as f64)).collect();
-            Ok(converted)
-        }
-        DataType::Int32 => {
-            let array = as_int32_array(array)?;
-            let converted: Float64Array =
-                array.iter().map(|v| v.map(|v| v as f64)).collect();
-            Ok(converted)
-        }
-        _ => exec_err!("Unsupported array type for conversion to Float64Array"),
-    }
-}
+// convert_to_f64_array is now shared via crate::vector_math
diff --git a/datafusion/functions-nested/src/inner_product.rs b/datafusion/functions-nested/src/inner_product.rs
new file mode 100644
index 0000000000000..1cba68583bf4a
--- /dev/null
+++ b/datafusion/functions-nested/src/inner_product.rs
@@ -0,0 +1,311 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! [ScalarUDFImpl] definitions for inner_product function.
+
+use crate::utils::make_scalar_function;
+use crate::vector_math::{convert_to_f64_array, dot_product_f64};
+use arrow::array::{
+    Array, ArrayRef, Float64Array, LargeListArray, ListArray, OffsetSizeTrait,
+};
+use arrow::datatypes::{
+    DataType,
+    DataType::{FixedSizeList, LargeList, List, Null},
+};
+use datafusion_common::cast::as_generic_list_array;
+use datafusion_common::utils::{ListCoercion, coerced_type_with_base_type_only};
+use datafusion_common::{Result, exec_err, plan_err, utils::take_function_args};
+use datafusion_expr::{
+    ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    Volatility,
+};
+use datafusion_functions::downcast_arg;
+use datafusion_macros::user_doc;
+use itertools::Itertools;
+use std::sync::Arc;
+
+make_udf_expr_and_func!(
+    InnerProduct,
+    inner_product,
+    array1 array2,
+    "returns the inner (dot) product of two numeric arrays.",
+    inner_product_udf
+);
+
+#[user_doc(
+    doc_section(label = "Array Functions"),
+    description = "Returns the inner product (dot product) of two input arrays of equal length: `sum(a[i] * b[i])`.",
+    syntax_example = "inner_product(array1, array2)",
+    sql_example = r#"```sql
+> select inner_product([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
++------------------------------------------------------------+
+| inner_product(List([1.0,2.0,3.0]), List([4.0,5.0,6.0])) |
++------------------------------------------------------------+
+| 32.0 |
++------------------------------------------------------------+
+```"#,
+    argument(
+        name = "array1",
+        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
+    ),
+    argument(
+        name = "array2",
+        description = "Array expression. Can be a constant, column, or function, and any combination of array operators."
+    )
+)]
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct InnerProduct {
+    signature: Signature,
+    aliases: Vec<String>,
+}
+
+impl Default for InnerProduct {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl InnerProduct {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+            aliases: vec![
+                "list_inner_product".to_string(),
+                "dot_product".to_string(),
+            ],
+        }
+    }
+}
+
+impl ScalarUDFImpl for InnerProduct {
+    fn name(&self) -> &str {
+        "inner_product"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Ok(DataType::Float64)
+    }
+
+    fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+        let [_, _] = take_function_args(self.name(), arg_types)?;
+        let coercion = Some(&ListCoercion::FixedSizedListToList);
+        let arg_types = arg_types.iter().map(|arg_type| {
+            if matches!(arg_type, Null | List(_) | LargeList(_) | FixedSizeList(..)) {
+                Ok(coerced_type_with_base_type_only(
+                    arg_type,
+                    &DataType::Float64,
+                    coercion,
+                ))
+            } else {
+                plan_err!("{} does not support type {arg_type}", self.name())
+            }
+        });
+
+        arg_types.try_collect()
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        make_scalar_function(inner_product_inner)(&args.args)
+    }
+
+    fn aliases(&self) -> &[String] {
+        &self.aliases
+    }
+
+    fn documentation(&self) -> Option<&Documentation> {
+        self.doc()
+    }
+}
+
+fn inner_product_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
+    let [array1, array2] = take_function_args("inner_product", args)?;
+    match (array1.data_type(), array2.data_type()) {
+        (List(_), List(_)) => general_inner_product::<i32>(args),
+        (LargeList(_), LargeList(_)) => general_inner_product::<i64>(args),
+        (arg_type1, arg_type2) => {
+            exec_err!(
+                "inner_product does not support types {arg_type1} and {arg_type2}"
+            )
+        }
+    }
+}
+
+fn general_inner_product<O: OffsetSizeTrait>(arrays: &[ArrayRef]) -> Result<ArrayRef> {
+    let list_array1 = as_generic_list_array::<O>(&arrays[0])?;
+    let list_array2 = as_generic_list_array::<O>(&arrays[1])?;
+
+    let result = list_array1
+        .iter()
+        .zip(list_array2.iter())
+        .map(|(arr1, arr2)| compute_inner_product(arr1, arr2))
+        .collect::<Result<Float64Array>>()?;
+
+    Ok(Arc::new(result) as ArrayRef)
+}
+
+/// Computes the inner product of two arrays: sum(a[i] * b[i])
+fn compute_inner_product(
+    arr1: Option<ArrayRef>,
+    arr2: Option<ArrayRef>,
+) -> Result<Option<f64>> {
+    let value1 = match arr1 {
+        Some(arr) => arr,
+        None => return Ok(None),
+    };
+    let value2 = match arr2 {
+        Some(arr) => arr,
+        None => return Ok(None),
+    };
+
+    let mut value1 = value1;
+    let mut value2 = value2;
+
+    loop {
+        match value1.data_type() {
+            List(_) => {
+                if downcast_arg!(value1, ListArray).null_count() > 0 {
+                    return Ok(None);
+                }
+                value1 = downcast_arg!(value1, ListArray).value(0);
+            }
+            LargeList(_) => {
+                if downcast_arg!(value1, LargeListArray).null_count() > 0 {
+                    return Ok(None);
+                }
+                value1 = downcast_arg!(value1, LargeListArray).value(0);
+            }
+            _ => break,
+        }
+
+        match value2.data_type() {
+            List(_) => {
+                if downcast_arg!(value2, ListArray).null_count() > 0 {
+                    return Ok(None);
+                }
+                value2 = downcast_arg!(value2, ListArray).value(0);
+            }
+            LargeList(_) => {
+                if downcast_arg!(value2, LargeListArray).null_count() > 0 {
+                    return Ok(None);
+                }
+                value2 = downcast_arg!(value2, LargeListArray).value(0);
+            }
+            _ => break,
+        }
+    }
+
+    // Check for NULL values inside the arrays
+    if value1.null_count() != 0 || value2.null_count() != 0 {
+        return Ok(None);
+    }
+
+    let values1 = convert_to_f64_array(&value1)?;
+    let values2 = convert_to_f64_array(&value2)?;
+
+    if values1.len() != values2.len() {
+        return exec_err!("Both arrays must have the same length");
+    }
+
+    Ok(Some(dot_product_f64(&values1, &values2)))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::{Float64Array, ListArray};
+    use arrow::buffer::OffsetBuffer;
+    use arrow::datatypes::{DataType, Field};
+    use std::sync::Arc;
+
+    fn make_f64_list_array(values: Vec<Option<Vec<Option<f64>>>>) -> ArrayRef {
+        let mut flat: Vec<Option<f64>> = Vec::new();
+        let mut offsets: Vec<i32> = vec![0];
+        for v in &values {
+            match v {
+                Some(inner) => {
+                    flat.extend(inner);
+                    offsets.push(flat.len() as i32);
+                }
+                None => {
+                    offsets.push(flat.len() as i32);
+                }
+            }
+        }
+        let values_array = Arc::new(Float64Array::from(flat)) as ArrayRef;
+        let field = Arc::new(Field::new_list_field(DataType::Float64, true));
+        let offset_buffer = OffsetBuffer::new(offsets.into());
+        let null_buffer = arrow::buffer::NullBuffer::from(
+            values.iter().map(|v| v.is_some()).collect::<Vec<bool>>(),
+        );
+        Arc::new(ListArray::new(
+            field,
+            offset_buffer,
+            values_array,
+            Some(null_buffer),
+        ))
+    }
+
+    #[test]
+    fn test_inner_product_basic() {
+        let arr1 =
+            make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0), Some(3.0)])]);
+        let arr2 =
+            make_f64_list_array(vec![Some(vec![Some(4.0), Some(5.0), Some(6.0)])]);
+        let result = inner_product_inner(&[arr1, arr2]).unwrap();
+        let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
+        assert!((result.value(0) - 32.0).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_inner_product_null_array() {
+        let arr1 = make_f64_list_array(vec![None]);
+        let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]);
+        let result = inner_product_inner(&[arr1, arr2]).unwrap();
+        let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
+        assert!(result.is_null(0));
+    }
+
+    #[test]
+    fn test_inner_product_mismatched_lengths() {
+        let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(2.0)])]);
+        let arr2 = make_f64_list_array(vec![Some(vec![Some(1.0)])]);
+        let result = inner_product_inner(&[arr1, arr2]);
+        assert!(result.is_err());
+    }
+
+    #[test]
+    fn test_inner_product_empty_arrays() {
+        let arr1 = make_f64_list_array(vec![Some(vec![])]);
+        let arr2 = make_f64_list_array(vec![Some(vec![])]);
+        let result = inner_product_inner(&[arr1, arr2]).unwrap();
+        let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
+        assert!((result.value(0) - 0.0).abs() < 1e-10);
+    }
+
+    #[test]
+    fn test_inner_product_orthogonal() {
+        let arr1 = make_f64_list_array(vec![Some(vec![Some(1.0), Some(0.0)])]);
+        let arr2 = make_f64_list_array(vec![Some(vec![Some(0.0), Some(1.0)])]);
+        let result = inner_product_inner(&[arr1, arr2]).unwrap();
+        let result = result.as_any().downcast_ref::<Float64Array>().unwrap();
+        assert!((result.value(0) - 0.0).abs() < 1e-10);
+    }
+}
diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs
index 99b25ec96454b..b07f661588c7c 100644
--- a/datafusion/functions-nested/src/lib.rs
+++ b/datafusion/functions-nested/src/lib.rs
@@ -38,9 +38,12 @@ pub mod macros;
 
 pub mod array_has;
+pub mod array_math;
+pub mod array_normalize;
 pub mod arrays_zip;
 pub mod cardinality;
 pub mod concat;
+pub mod cosine_distance;
 pub mod dimension;
 pub mod distance;
 pub mod empty;
@@ -48,6 +51,7 @@ pub mod except;
 pub mod expr_ext;
 pub mod extract;
 pub mod flatten;
+pub mod inner_product;
 pub mod length;
 pub mod make_array;
 pub mod map;
@@ -68,6 +72,7 @@ pub mod set_ops;
 pub mod sort;
 pub mod string;
 pub mod utils;
+pub mod vector_math;
 
 use datafusion_common::Result;
 use datafusion_execution::FunctionRegistry;
@@ -80,11 +85,16 @@ pub mod expr_fn {
     pub use super::array_has::array_has;
     pub use super::array_has::array_has_all;
     pub use super::array_has::array_has_any;
+    pub use super::array_math::array_add;
+    pub use super::array_math::array_scale;
+    pub use super::array_math::array_subtract;
+    pub use super::array_normalize::array_normalize;
     pub use super::arrays_zip::arrays_zip;
     pub use super::cardinality::cardinality;
     pub use super::concat::array_append;
     pub use super::concat::array_concat;
     pub use super::concat::array_prepend;
+    pub use super::cosine_distance::cosine_distance;
     pub use super::dimension::array_dims;
     pub use super::dimension::array_ndims;
     pub use super::distance::array_distance;
@@ -96,6 +106,7 @@ pub mod expr_fn {
     pub use super::extract::array_pop_front;
     pub use super::extract::array_slice;
     pub use super::flatten::flatten;
+    pub use super::inner_product::inner_product;
     pub use super::length::array_length;
     pub use super::make_array::make_array;
     pub use super::map_entries::map_entries;
@@ -148,10 +159,16 @@ pub fn all_default_nested_functions() -> Vec<Arc<ScalarUDF>> {
         array_has::array_has_udf(),
         array_has::array_has_all_udf(),
         array_has::array_has_any_udf(),
+        array_math::array_add_udf(),
+        array_math::array_subtract_udf(),
+        array_math::array_scale_udf(),
+        array_normalize::array_normalize_udf(),
         empty::array_empty_udf(),
         length::array_length_udf(),
+        cosine_distance::cosine_distance_udf(),
         distance::array_distance_udf(),
         flatten::flatten_udf(),
+        inner_product::inner_product_udf(),
         min_max::array_max_udf(),
         min_max::array_min_udf(),
         sort::array_sort_udf(),
diff --git a/datafusion/functions-nested/src/vector_math.rs b/datafusion/functions-nested/src/vector_math.rs
new file mode 100644
index 0000000000000..ff04f7b2fdf4a
--- /dev/null
+++ b/datafusion/functions-nested/src/vector_math.rs
@@ -0,0 +1,74 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Shared vector math primitives used by cosine_distance, inner_product,
+//! array_normalize, and related functions.
+
+use arrow::array::{ArrayRef, Float64Array};
+use datafusion_common::cast::{
+    as_float32_array, as_float64_array, as_int32_array, as_int64_array,
+};
+use datafusion_common::{Result, exec_err};
+
+/// Converts an array of any numeric type to a Float64Array.
+pub fn convert_to_f64_array(array: &ArrayRef) -> Result<Float64Array> {
+    match array.data_type() {
+        arrow::datatypes::DataType::Float64 => Ok(as_float64_array(array)?.clone()),
+        arrow::datatypes::DataType::Float32 => {
+            let array = as_float32_array(array)?;
+            let converted: Float64Array =
+                array.iter().map(|v| v.map(|v| v as f64)).collect();
+            Ok(converted)
+        }
+        arrow::datatypes::DataType::Int64 => {
+            let array = as_int64_array(array)?;
+            let converted: Float64Array =
+                array.iter().map(|v| v.map(|v| v as f64)).collect();
+            Ok(converted)
+        }
+        arrow::datatypes::DataType::Int32 => {
+            let array = as_int32_array(array)?;
+            let converted: Float64Array =
+                array.iter().map(|v| v.map(|v| v as f64)).collect();
+            Ok(converted)
+        }
+        _ => exec_err!("Unsupported array type for conversion to Float64Array"),
+    }
+}
+
+/// Computes dot product: sum(a[i] * b[i])
+pub fn dot_product_f64(a: &Float64Array, b: &Float64Array) -> f64 {
+    a.iter()
+        .zip(b.iter())
+        .map(|(v1, v2)| v1.unwrap_or(0.0) * v2.unwrap_or(0.0))
+        .sum()
+}
+
+/// Computes sum of squares: sum(a[i]^2)
+pub fn sum_of_squares_f64(a: &Float64Array) -> f64 {
+    a.iter()
+        .map(|v| {
+            let val = v.unwrap_or(0.0);
+            val * val
+        })
+        .sum()
+}
+
+/// Computes magnitude (L2 norm): sqrt(sum(a[i]^2))
+pub fn magnitude_f64(a: &Float64Array) -> f64 {
+    sum_of_squares_f64(a).sqrt()
+}
diff --git a/datafusion/sqllogictest/test_files/vector_functions.slt b/datafusion/sqllogictest/test_files/vector_functions.slt
new file mode 100644
index 0000000000000..a049db47015c2
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/vector_functions.slt
@@ -0,0 +1,302 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#############
+## Vector Distance and Array Math Functions Tests
+#############
+
+# ============================================================================
+# cosine_distance
+# ============================================================================
+
+# orthogonal vectors -> distance 1.0
+query R
+select cosine_distance([1.0, 0.0], [0.0, 1.0]);
+----
+1
+
+# identical vectors -> distance 0.0
+query R
+select cosine_distance([1.0, 2.0, 3.0], [1.0, 2.0, 3.0]);
+----
+0
+
+# opposite vectors -> distance 2.0
+query R
+select cosine_distance([1.0, 0.0], [-1.0, 0.0]);
+----
+2
+
+# alias: list_cosine_distance
+query R
+select list_cosine_distance([1.0, 0.0], [0.0, 1.0]);
+----
+1
+
+# null elements inside array
+query R
+select cosine_distance([1.0, NULL], [1.0, 2.0]);
+----
+NULL
+
+# mismatched lengths
+query error
+select cosine_distance([1.0, 2.0], [1.0]);
+
+# ============================================================================
+# inner_product / dot_product
+# ============================================================================
+
+# basic dot product: 1*4 + 2*5 + 3*6 = 32
+query R
+select inner_product([1.0, 2.0, 3.0], [4.0, 5.0, 6.0]);
+----
+32
+
+# orthogonal vectors -> 0
+query R
+select inner_product([1.0, 0.0], [0.0, 1.0]);
+----
+0
+
+# alias: list_inner_product
+query R
+select list_inner_product([1.0, 2.0], [3.0, 4.0]);
+----
+11
+
+# alias: dot_product
+query R
+select dot_product([1.0, 2.0], [3.0, 4.0]);
+----
+11
+
+# null elements inside array
+query R
+select inner_product([1.0, NULL], [1.0, 2.0]);
+----
+NULL
+
+# mismatched lengths
+query error
+select inner_product([1.0, 2.0], [1.0]);
+
+# empty arrays
+query R
+select inner_product(ARRAY[]::DOUBLE[], ARRAY[]::DOUBLE[]);
+----
+0
+
+# ============================================================================
+# array_normalize
+# ============================================================================
+
+# 3-4-5 triangle: [3/5, 4/5] = [0.6, 0.8]
+query ?
+select array_normalize([3.0, 4.0]);
+----
+[0.6, 0.8]
+
+# unit vector stays unit
+query ?
+select array_normalize([1.0, 0.0]);
+----
+[1.0, 0.0]
+
+# alias: list_normalize
+query ?
+select list_normalize([3.0, 4.0]);
+----
+[0.6, 0.8]
+
+# null elements
+query ?
+select array_normalize([1.0, NULL, 2.0]);
+----
+NULL
+
+# zero vector -> null (undefined direction)
+query ?
+select array_normalize([0.0, 0.0]);
+----
+NULL
+
+# ============================================================================
+# array_add
+# ============================================================================
+
+query ?
+select array_add([1.0, 2.0], [3.0, 4.0]);
+----
+[4.0, 6.0]
+
+# alias: list_add
+query ?
+select list_add([1.0, 2.0], [3.0, 4.0]);
+----
+[4.0, 6.0]
+
+# mismatched lengths
+query error
+select array_add([1.0, 2.0], [1.0]);
+
+# empty arrays
+query ?
+select array_add(ARRAY[]::DOUBLE[], ARRAY[]::DOUBLE[]);
+----
+[]
+
+# ============================================================================
+# array_subtract
+# ============================================================================
+
+query ?
+select array_subtract([5.0, 3.0], [1.0, 2.0]);
+----
+[4.0, 1.0]
+
+# alias: list_subtract
+query ?
+select list_subtract([5.0, 3.0], [1.0, 2.0]);
+----
+[4.0, 1.0]
+
+# mismatched lengths
+query error
+select array_subtract([1.0], [1.0, 2.0]);
+
+# ============================================================================
+# array_scale
+# ============================================================================
+
+query ?
+select array_scale([1.0, 2.0, 3.0], 2.0);
+----
+[2.0, 4.0, 6.0]
+
+# alias: list_scale
+query ?
+select list_scale([1.0, 2.0], 3.0);
+----
+[3.0, 6.0]
+
+# scale by zero
+query ?
+select array_scale([1.0, 2.0], 0.0);
+----
+[0.0, 0.0]
+
+# negative scale
+query ?
+select array_scale([1.0, 2.0], -1.0);
+----
+[-1.0, -2.0]
+
+# ============================================================================
+# Column-based null tests (NULL via column data, not bare SQL NULL literal)
+# ============================================================================
+
+statement ok
+CREATE TABLE null_test (a DOUBLE[], b DOUBLE[]) AS VALUES
+  ([1.0, 2.0], [3.0, 4.0]),
+  (NULL, [1.0, 2.0]),
+  ([1.0, 2.0], NULL);
+
+query R
+select cosine_distance(a, b) from null_test;
+----
+0.0161300899
+NULL
+NULL
+
+query R
+select inner_product(a, b) from null_test;
+----
+11
+NULL
+NULL
+
+query ?
+select array_normalize(a) from null_test;
+----
+[0.4472135954999579, 0.8944271909999159]
+NULL
+[0.4472135954999579, 0.8944271909999159]
+
+query ?
+select array_add(a, b) from null_test;
+----
+[4.0, 6.0]
+NULL
+NULL
+
+query ?
+select array_subtract(a, b) from null_test;
+----
+[-2.0, -2.0]
+NULL
+NULL
+
+statement ok
+DROP TABLE null_test;
+
+# ============================================================================
+# Vector search pattern: cosine distance ranking
+# ============================================================================
+
+statement ok
+CREATE TABLE embeddings (id INT, emb DOUBLE[]) AS VALUES
+  (1, [1.0, 0.0, 0.0]),
+  (2, [0.0, 1.0, 0.0]),
+  (3, [0.707, 0.707, 0.0]),
+  (4, [0.0, 0.0, 1.0]),
+  (5, [0.577, 0.577, 0.577]);
+
+# Find nearest neighbors by cosine distance
+query IR
+select id, round(cosine_distance(emb, [1.0, 0.0, 0.0]), 6) as dist from embeddings order by dist limit 3;
+----
+1 0
+3 0.292893
+5 0.42265
+
+# Rank by inner product (higher = more similar)
+query IR
+select id, round(inner_product(emb, [1.0, 0.0, 0.0]), 6) as score from embeddings order by score desc limit 3;
+----
+1 1
+3 0.707
+5 0.577
+
+# Normalize then compute cosine distance (should give same results)
+query IR
+select id, round(cosine_distance(array_normalize(emb), [1.0, 0.0, 0.0]), 6) as dist from embeddings order by dist limit 3;
+----
+1 0
+3 0.292893
+5 0.42265
+
+# Combine array math: subtract query vector, get L2 distance of difference
+query IR
+select id, round(array_distance(array_subtract(emb, [1.0, 0.0, 0.0]), [0.0, 0.0, 0.0]), 6) as l2_diff from embeddings order by l2_diff limit 3;
+----
+1 0
+3 0.765309
+5 0.919123
+
+statement ok
+DROP TABLE embeddings;