diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index 4d98fdb1b207..a127076135f1 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -65,9 +65,9 @@ unicode-segmentation = { version = "^1.7.1", optional = true }
 regex = { version = "^1.4.3", optional = true }
 lazy_static = { version = "^1.4.0", optional = true }
 smallvec = { version = "1.6", features = ["union"] }
+rand = "0.8"
 
 [dev-dependencies]
-rand = "0.8"
 criterion = "0.3"
 tempfile = "3"
 doc-comment = "0.3"
diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs
index c0c915f29a72..18becf2c8e4e 100644
--- a/datafusion/src/physical_plan/functions.rs
+++ b/datafusion/src/physical_plan/functions.rs
@@ -169,6 +169,8 @@ pub enum BuiltinScalarFunction {
     NullIf,
     /// octet_length
     OctetLength,
+    /// random
+    Random,
     /// regexp_replace
     RegexpReplace,
     /// repeat
@@ -219,7 +221,10 @@ impl BuiltinScalarFunction {
     /// an allowlist of functions to take zero arguments, so that they will get special treatment
     /// while executing.
     fn supports_zero_argument(&self) -> bool {
-        matches!(self, BuiltinScalarFunction::Now)
+        matches!(
+            self,
+            BuiltinScalarFunction::Random | BuiltinScalarFunction::Now
+        )
     }
 }
 
@@ -275,6 +280,7 @@ impl FromStr for BuiltinScalarFunction {
             "md5" => BuiltinScalarFunction::MD5,
             "nullif" => BuiltinScalarFunction::NullIf,
             "octet_length" => BuiltinScalarFunction::OctetLength,
+            "random" => BuiltinScalarFunction::Random,
             "regexp_replace" => BuiltinScalarFunction::RegexpReplace,
             "repeat" => BuiltinScalarFunction::Repeat,
             "replace" => BuiltinScalarFunction::Replace,
@@ -438,6 +444,7 @@ pub fn return_type(
                 ));
             }
         }),
+        BuiltinScalarFunction::Random => Ok(DataType::Float64),
         BuiltinScalarFunction::RegexpReplace => Ok(match arg_types[0] {
             DataType::LargeUtf8 => DataType::LargeUtf8,
             DataType::Utf8 => DataType::Utf8,
@@ -742,6 +749,7 @@ pub fn create_physical_expr(
         BuiltinScalarFunction::Ln => math_expressions::ln,
         BuiltinScalarFunction::Log10 => math_expressions::log10,
         BuiltinScalarFunction::Log2 => math_expressions::log2,
+        BuiltinScalarFunction::Random => math_expressions::random,
         BuiltinScalarFunction::Round => math_expressions::round,
         BuiltinScalarFunction::Signum => math_expressions::signum,
         BuiltinScalarFunction::Sin => math_expressions::sin,
@@ -1307,6 +1315,7 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature {
             Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]),
             Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]),
         ]),
+        BuiltinScalarFunction::Random => Signature::Exact(vec![]),
         // math expressions expect 1 argument of type f64 or f32
         // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we
         // return the best approximation for it (in f64).
diff --git a/datafusion/src/physical_plan/math_expressions.rs b/datafusion/src/physical_plan/math_expressions.rs
index 0e0bed2deac2..cfc239cde661 100644
--- a/datafusion/src/physical_plan/math_expressions.rs
+++ b/datafusion/src/physical_plan/math_expressions.rs
@@ -16,11 +16,12 @@
 // under the License.
 
 //! Math expressions
-
 use super::{ColumnarValue, ScalarValue};
 use crate::error::{DataFusionError, Result};
 use arrow::array::{Float32Array, Float64Array};
 use arrow::datatypes::DataType;
+use rand::{thread_rng, Rng};
+use std::iter;
 use std::sync::Arc;
 
 macro_rules! downcast_compute_op {
@@ -100,3 +101,36 @@ math_unary_function!("exp", exp);
 math_unary_function!("ln", ln);
 math_unary_function!("log2", log2);
 math_unary_function!("log10", log10);
+
+/// random SQL function
+pub fn random(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    let len: usize = match &args[0] {
+        ColumnarValue::Array(array) => array.len(),
+        _ => {
+            return Err(DataFusionError::Internal(
+                "Expect random function to take no param".to_string(),
+            ))
+        }
+    };
+    let mut rng = thread_rng();
+    let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len);
+    let array = Float64Array::from_iter_values(values);
+    Ok(ColumnarValue::Array(Arc::new(array)))
+}
+
+#[cfg(test)]
+mod tests {
+
+    use super::*;
+    use arrow::array::{Float64Array, NullArray};
+
+    #[test]
+    fn test_random_expression() {
+        let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))];
+        let array = random(&args).expect("fail").into_array(1);
+        let floats = array.as_any().downcast_ref::<Float64Array>().expect("fail");
+
+        assert_eq!(floats.len(), 1);
+        assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0);
+    }
+}
diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs
index 98ae09cc381d..06d3739b53b2 100644
--- a/datafusion/src/physical_plan/type_coercion.rs
+++ b/datafusion/src/physical_plan/type_coercion.rs
@@ -75,7 +75,6 @@ pub fn data_types(
     if current_types.is_empty() {
         return Ok(vec![]);
     }
-
     let valid_types = get_valid_types(signature, current_types)?;
 
     if valid_types
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 6edb75733490..eb50661b42e6 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -2910,6 +2910,17 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn test_random_expression() -> Result<()> {
+    let mut ctx = create_ctx()?;
+    let sql = "SELECT random() r1";
+    let actual = execute(&mut ctx, sql).await;
+    let r1 = actual[0][0].parse::<f64>().unwrap();
+    assert!(0.0 <= r1);
+    assert!(r1 < 1.0);
+    Ok(())
+}
+
 #[tokio::test]
 async fn test_cast_expressions_error() -> Result<()> {
     // sin(utf8) should error