From 2942ea77728a62e98207576c590e3929e9bca291 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 10 May 2021 17:47:01 +0800 Subject: [PATCH 1/3] fix 305 --- datafusion/Cargo.toml | 2 +- datafusion/src/physical_plan/functions.rs | 6 ++++ .../src/physical_plan/math_expressions.rs | 36 ++++++++++++++++++- datafusion/src/physical_plan/type_coercion.rs | 1 - datafusion/tests/sql.rs | 9 +++++ 5 files changed, 51 insertions(+), 3 deletions(-) diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 4d98fdb1b207..a127076135f1 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -65,9 +65,9 @@ unicode-segmentation = { version = "^1.7.1", optional = true } regex = { version = "^1.4.3", optional = true } lazy_static = { version = "^1.4.0", optional = true } smallvec = { version = "1.6", features = ["union"] } +rand = "0.8" [dev-dependencies] -rand = "0.8" criterion = "0.3" tempfile = "3" doc-comment = "0.3" diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index c0c915f29a72..69a2a60b9a0e 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -169,6 +169,8 @@ pub enum BuiltinScalarFunction { NullIf, /// octet_length OctetLength, + /// random + Random, /// regexp_replace RegexpReplace, /// repeat @@ -275,6 +277,7 @@ impl FromStr for BuiltinScalarFunction { "md5" => BuiltinScalarFunction::MD5, "nullif" => BuiltinScalarFunction::NullIf, "octet_length" => BuiltinScalarFunction::OctetLength, + "random" => BuiltinScalarFunction::Random, "regexp_replace" => BuiltinScalarFunction::RegexpReplace, "repeat" => BuiltinScalarFunction::Repeat, "replace" => BuiltinScalarFunction::Replace, @@ -438,6 +441,7 @@ pub fn return_type( )); } }), + BuiltinScalarFunction::Random => Ok(DataType::Float64), BuiltinScalarFunction::RegexpReplace => Ok(match arg_types[0] { DataType::LargeUtf8 => DataType::LargeUtf8, DataType::Utf8 => DataType::Utf8, @@ -742,6 +746,7 @@ pub fn create_physical_expr( BuiltinScalarFunction::Ln => math_expressions::ln, BuiltinScalarFunction::Log10 => math_expressions::log10, BuiltinScalarFunction::Log2 => math_expressions::log2, + BuiltinScalarFunction::Random => math_expressions::random, BuiltinScalarFunction::Round => math_expressions::round, BuiltinScalarFunction::Signum => math_expressions::signum, BuiltinScalarFunction::Sin => math_expressions::sin, @@ -1307,6 +1312,7 @@ fn signature(fun: &BuiltinScalarFunction) -> Signature { Signature::Exact(vec![DataType::Utf8, DataType::Utf8, DataType::Utf8]), Signature::Exact(vec![DataType::LargeUtf8, DataType::Utf8, DataType::Utf8]), ]), + BuiltinScalarFunction::Random => Signature::Exact(vec![]), // math expressions expect 1 argument of type f64 or f32 // priority is given to f64 because e.g. `sqrt(1i32)` is in IR (real numbers) and thus we // return the best approximation for it (in f64). diff --git a/datafusion/src/physical_plan/math_expressions.rs b/datafusion/src/physical_plan/math_expressions.rs index 0e0bed2deac2..cfc239cde661 100644 --- a/datafusion/src/physical_plan/math_expressions.rs +++ b/datafusion/src/physical_plan/math_expressions.rs @@ -16,11 +16,12 @@ // under the License. //! Math expressions - use super::{ColumnarValue, ScalarValue}; use crate::error::{DataFusionError, Result}; use arrow::array::{Float32Array, Float64Array}; use arrow::datatypes::DataType; +use rand::{thread_rng, Rng}; +use std::iter; use std::sync::Arc; macro_rules! downcast_compute_op { @@ -100,3 +101,36 @@ math_unary_function!("exp", exp); math_unary_function!("ln", ln); math_unary_function!("log2", log2); math_unary_function!("log10", log10); + +/// random SQL function +pub fn random(args: &[ColumnarValue]) -> Result { + let len: usize = match &args[0] { + ColumnarValue::Array(array) => array.len(), + _ => { + return Err(DataFusionError::Internal( + "Expect random function to take no param".to_string(), + )) + } + }; + let mut rng = thread_rng(); + let values = iter::repeat_with(|| rng.gen_range(0.0..1.0)).take(len); + let array = Float64Array::from_iter_values(values); + Ok(ColumnarValue::Array(Arc::new(array))) +} + +#[cfg(test)] +mod tests { + + use super::*; + use arrow::array::{Float64Array, NullArray}; + + #[test] + fn test_random_expression() { + let args = vec![ColumnarValue::Array(Arc::new(NullArray::new(1)))]; + let array = random(&args).expect("fail").into_array(1); + let floats = array.as_any().downcast_ref::().expect("fail"); + + assert_eq!(floats.len(), 1); + assert!(0.0 <= floats.value(0) && floats.value(0) < 1.0); + } +} diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs index 98ae09cc381d..06d3739b53b2 100644 --- a/datafusion/src/physical_plan/type_coercion.rs +++ b/datafusion/src/physical_plan/type_coercion.rs @@ -75,7 +75,6 @@ pub fn data_types( if current_types.is_empty() { return Ok(vec![]); } - let valid_types = get_valid_types(signature, current_types)?; if valid_types diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 6edb75733490..85e71343c7ea 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -2906,7 +2906,16 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { let t2 = t2_naive.timestamp(); assert!(t1 <= t2 && t2 <= t3); assert_eq!(res2, res1); +} +#[tokio::test] +async fn test_random_expression() -> Result<()> { + let mut ctx = create_ctx()?; + let sql = format!("SELECT random() r1"); + let actual = execute(&mut ctx, sql.as_str()).await; + let r1 = actual[0][0].parse::().unwrap(); + assert!(0.0 <= r1); + assert!(r1 < 1.0); Ok(()) } From ec643c81a539d701b0d427176295a934973510ce Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sun, 16 May 2021 10:08:06 +0800 Subject: [PATCH 2/3] add supports_zero_argument --- datafusion/src/physical_plan/functions.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 69a2a60b9a0e..18becf2c8e4e 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -221,7 +221,10 @@ impl BuiltinScalarFunction { /// an allowlist of functions to take zero arguments, so that they will get special treatment /// while executing. fn supports_zero_argument(&self) -> bool { - matches!(self, BuiltinScalarFunction::Now) + matches!( + self, + BuiltinScalarFunction::Random | BuiltinScalarFunction::Now + ) } } From fb74a591fe99eee6ca0729b3e2816272deea1580 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Sun, 16 May 2021 10:44:12 +0800 Subject: [PATCH 3/3] fix unit test --- datafusion/tests/sql.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index 85e71343c7ea..eb50661b42e6 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -2906,13 +2906,15 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { let t2 = t2_naive.timestamp(); assert!(t1 <= t2 && t2 <= t3); assert_eq!(res2, res1); + + Ok(()) } #[tokio::test] async fn test_random_expression() -> Result<()> { let mut ctx = create_ctx()?; - let sql = format!("SELECT random() r1"); - let actual = execute(&mut ctx, sql.as_str()).await; + let sql = "SELECT random() r1"; + let actual = execute(&mut ctx, sql).await; let r1 = actual[0][0].parse::().unwrap(); assert!(0.0 <= r1); assert!(r1 < 1.0);