Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions native-engine/auron-planner/proto/auron.proto
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,8 @@ enum WindowFunction {
ROW_NUMBER = 0;
RANK = 1;
DENSE_RANK = 2;
NTH_VALUE = 3;
NTH_VALUE_IGNORE_NULLS = 4;
}

enum AggFunction {
Expand Down
8 changes: 8 additions & 0 deletions native-engine/auron-planner/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,14 @@ impl PhysicalPlanner {
protobuf::WindowFunction::DenseRank => {
WindowFunction::RankLike(WindowRankType::DenseRank)
}
protobuf::WindowFunction::NthValue => {
WindowFunction::NthValue {
ignore_nulls: false,
}
}
protobuf::WindowFunction::NthValueIgnoreNulls => {
WindowFunction::NthValue { ignore_nulls: true }
}
},
protobuf::WindowFunctionType::Agg => match w.agg_func() {
protobuf::AggFunction::Min => WindowFunction::Agg(AggFunction::Min),
Expand Down
9 changes: 7 additions & 2 deletions native-engine/datafusion-ext-plans/src/window/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ use crate::{
agg::{AggFunction, agg::create_agg},
window::{
processors::{
agg_processor::AggProcessor, rank_processor::RankProcessor,
row_number_processor::RowNumberProcessor,
agg_processor::AggProcessor, nth_value_processor::NthValueProcessor,
rank_processor::RankProcessor, row_number_processor::RowNumberProcessor,
},
window_context::WindowContext,
},
Expand All @@ -36,6 +36,7 @@ pub mod window_context;
/// The kind of window function a `WindowExpr` evaluates.
#[derive(Debug, Clone, Copy)]
pub enum WindowFunction {
    /// Ranking-style functions (row_number / rank / dense_rank), parameterized
    /// by the concrete rank type.
    RankLike(WindowRankType),
    /// nth_value(input, offset); `ignore_nulls` selects the IGNORE NULLS variant
    /// in which null input values do not count toward the offset.
    NthValue { ignore_nulls: bool },
    /// An aggregate function evaluated as a window function over the frame.
    Agg(AggFunction),
}

Expand Down Expand Up @@ -87,6 +88,10 @@ impl WindowExpr {
WindowFunction::RankLike(WindowRankType::DenseRank) => {
Ok(Box::new(RankProcessor::new(true)))
}
WindowFunction::NthValue { ignore_nulls } => Ok(Box::new(NthValueProcessor::try_new(
self.children.clone(),
ignore_nulls,
)?)),
WindowFunction::Agg(agg_func) => {
let agg = create_agg(
agg_func.clone(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
// limitations under the License.

pub mod agg_processor;
pub mod nth_value_processor;
pub mod rank_processor;
pub mod row_number_processor;
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use arrow::{array::ArrayRef, record_batch::RecordBatch};
use datafusion::{
common::{DataFusionError, Result, ScalarValue},
physical_expr::{PhysicalExprRef, expressions::Literal},
};
use datafusion_ext_commons::downcast_any;

use crate::window::{WindowFunctionProcessor, window_context::WindowContext};

/// Streaming processor for `nth_value(input, offset)` over a cumulative ROWS
/// frame (unbounded preceding .. current row). State persists across batches
/// and is reset whenever a new partition key is observed.
pub struct NthValueProcessor {
    /// Expression producing the value column, evaluated per input batch.
    input: PhysicalExprRef,
    /// 1-based position within the partition to pick (validated > 0 in `try_new`).
    offset: usize,
    /// When true, null input values do not count toward `offset` (IGNORE NULLS).
    ignore_nulls: bool,
    /// Encoded key of the partition currently being processed; empty before
    /// the first row is seen.
    cur_partition: Box<[u8]>,
    /// Rows counted toward the offset so far in the current partition
    /// (non-null rows only when `ignore_nulls`); meaningful while `nth_value`
    /// is still unresolved.
    observed_rows: usize,
    /// The resolved nth value for the current partition once `observed_rows`
    /// reaches `offset`; `None` until then.
    nth_value: Option<ScalarValue>,
}

impl NthValueProcessor {
    /// Creates a processor from the planner-provided children:
    /// `children[0]` is the value expression and `children[1]` must be a
    /// foldable, non-null, strictly positive integer literal giving the
    /// 1-based offset.
    ///
    /// # Errors
    /// Returns `DataFusionError::Execution` when the child count is not
    /// exactly 2, or when the offset literal is missing, null, non-integer,
    /// or not strictly positive.
    pub fn try_new(children: Vec<PhysicalExprRef>, ignore_nulls: bool) -> Result<Self> {
        if children.len() != 2 {
            return Err(DataFusionError::Execution(format!(
                "nth_value expects input/offset children, got {}",
                children.len(),
            )));
        }

        // The offset must be resolvable at plan time, so it has to be a
        // Literal. Accept every integer width (the planner may fold the
        // offset into a narrow type) as long as the value is positive.
        let offset = match downcast_any!(children[1].as_ref(), Literal)?.value() {
            ScalarValue::Int8(Some(value)) if *value > 0 => *value as usize,
            ScalarValue::Int16(Some(value)) if *value > 0 => *value as usize,
            ScalarValue::Int32(Some(value)) if *value > 0 => *value as usize,
            ScalarValue::Int64(Some(value)) if *value > 0 => *value as usize,
            ScalarValue::UInt8(Some(value)) if *value > 0 => *value as usize,
            ScalarValue::UInt16(Some(value)) if *value > 0 => *value as usize,
            ScalarValue::UInt32(Some(value)) if *value > 0 => *value as usize,
            ScalarValue::UInt64(Some(value)) if *value > 0 => *value as usize,
            other => {
                return Err(DataFusionError::Execution(format!(
                    "nth_value offset must be a positive non-null foldable integer, got {other:?}",
                )));
            }
        };

        Ok(Self {
            input: children[0].clone(),
            offset,
            ignore_nulls,
            // Empty key: the first row always starts a "new" partition,
            // which is harmless because the counters below are already zeroed.
            cur_partition: Box::default(),
            observed_rows: 0,
            nth_value: None,
        })
    }
}

impl WindowFunctionProcessor for NthValueProcessor {
    /// Emits, for every input row, the current partition's nth value over the
    /// cumulative ROWS frame (unbounded preceding .. current row), or a typed
    /// null while fewer than `offset` qualifying rows have been observed.
    fn process_batch(&mut self, context: &WindowContext, batch: &RecordBatch) -> Result<ArrayRef> {
        let num_rows = batch.num_rows();
        let partition_rows = context.get_partition_rows(batch)?;
        let input_values = self.input.evaluate(batch)?.into_array(num_rows)?;
        // Typed null emitted for rows before the nth value is resolved.
        let null_value = ScalarValue::try_from(input_values.data_type().clone())?;
        let mut output = Vec::with_capacity(num_rows);

        for row_idx in 0..num_rows {
            // Detect a partition boundary and reset per-partition state.
            // Without a PARTITION BY clause the whole input is one partition,
            // so the state is simply never reset.
            if context.has_partition() {
                let partition_row = partition_rows.row(row_idx);
                if partition_row.as_ref() != self.cur_partition.as_ref() {
                    self.cur_partition = partition_row.as_ref().into();
                    self.observed_rows = 0;
                    self.nth_value = None;
                }
            }

            // Keep counting qualifying rows until the nth one is found; after
            // that the cached value is emitted for the rest of the partition.
            if self.nth_value.is_none() {
                let value = ScalarValue::try_from_array(&input_values, row_idx)?;
                if !self.ignore_nulls || !value.is_null() {
                    self.observed_rows += 1;
                    if self.observed_rows == self.offset {
                        self.nth_value = Some(value);
                    }
                }
            }

            output.push(match &self.nth_value {
                Some(value) => value.clone(),
                None => null_value.clone(),
            });
        }

        ScalarValue::iter_to_array(output)
    }
}
94 changes: 92 additions & 2 deletions native-engine/datafusion-ext-plans/src/window_exec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,11 @@ mod test {
use arrow::{array::*, datatypes::*, record_batch::RecordBatch};
use datafusion::{
assert_batches_eq,
common::Result,
physical_expr::{PhysicalSortExpr, expressions::Column},
common::{Result, ScalarValue},
physical_expr::{
PhysicalSortExpr,
expressions::{Column, Literal},
},
physical_plan::{ExecutionPlan, test::TestMemoryExec},
prelude::SessionContext,
};
Expand Down Expand Up @@ -320,6 +323,33 @@ mod test {
)?))
}

fn build_nullable_utf8_table(
a: (&str, &Vec<i32>),
b: (&str, &Vec<i32>),
c: (&str, &Vec<Option<&str>>),
) -> Result<Arc<dyn ExecutionPlan>> {
let schema = Schema::new(vec![
Field::new(a.0, DataType::Int32, false),
Field::new(b.0, DataType::Int32, false),
Field::new(c.0, DataType::Utf8, true),
]);

let batch = RecordBatch::try_new(
Arc::new(schema),
vec![
Arc::new(Int32Array::from(a.1.clone())),
Arc::new(Int32Array::from(b.1.clone())),
Arc::new(StringArray::from(c.1.clone())),
],
)?;
let schema = batch.schema();
Ok(Arc::new(TestMemoryExec::try_new(
&[vec![batch]],
schema,
None,
)?))
}

#[tokio::test]
async fn test_window() -> Result<(), Box<dyn std::error::Error>> {
let session_ctx = SessionContext::new();
Expand Down Expand Up @@ -447,6 +477,66 @@ mod test {
Ok(())
}

/// End-to-end check of the NthValue window processor over a cumulative ROWS
/// frame, covering both the default and IGNORE NULLS variants side by side.
#[tokio::test]
async fn test_nth_value_window() -> Result<(), Box<dyn std::error::Error>> {
    let session_ctx = SessionContext::new();
    let task_ctx = session_ctx.task_ctx();

    // Two partitions (grp = 1 and grp = 2); column `v` contains nulls so the
    // two variants produce observably different results.
    let input = build_nullable_utf8_table(
        ("grp", &vec![1, 1, 1, 2, 2]),
        ("id", &vec![1, 2, 3, 1, 2]),
        ("v", &vec![None, Some("b"), Some("c"), Some("x"), None]),
    )?;
    let window_exprs = vec![
        // nth_value(v, 2): nulls count toward the offset.
        WindowExpr::new(
            WindowFunction::NthValue {
                ignore_nulls: false,
            },
            vec![
                Arc::new(Column::new("v", 2)),
                Arc::new(Literal::new(ScalarValue::Int32(Some(2)))),
            ],
            Arc::new(Field::new("nth_value_all", DataType::Utf8, true)),
            DataType::Utf8,
        ),
        // nth_value(v, 2) IGNORE NULLS: only non-null rows count.
        WindowExpr::new(
            WindowFunction::NthValue { ignore_nulls: true },
            vec![
                Arc::new(Column::new("v", 2)),
                Arc::new(Literal::new(ScalarValue::Int32(Some(2)))),
            ],
            Arc::new(Field::new("nth_value_ignore_nulls", DataType::Utf8, true)),
            DataType::Utf8,
        ),
    ];
    // Partition by `grp`, order by `id` with default (ascending) sort options.
    let window = Arc::new(WindowExec::try_new(
        input,
        window_exprs,
        vec![Arc::new(Column::new("grp", 0))],
        vec![PhysicalSortExpr {
            expr: Arc::new(Column::new("id", 1)),
            options: Default::default(),
        }],
        None,
        true,
    )?);
    let stream = window.execute(0, task_ctx)?;
    let batches = datafusion::physical_plan::common::collect(stream).await?;
    // grp 1 values in id order are (null, b, c): the 2nd value counting nulls
    // is "b" (resolved at id=2); the 2nd non-null value is "c" (id=3).
    // grp 2 values are (x, null): a 2nd qualifying value is never reached in
    // either mode, so both output columns stay null for that partition.
    let expected = vec![
        "+-----+----+---+---------------+------------------------+",
        "| grp | id | v | nth_value_all | nth_value_ignore_nulls |",
        "+-----+----+---+---------------+------------------------+",
        "| 1 | 1 | | | |",
        "| 1 | 2 | b | b | |",
        "| 1 | 3 | c | b | c |",
        "| 2 | 1 | x | | |",
        "| 2 | 2 | | | |",
        "+-----+----+---+---------------+------------------------+",
    ];
    assert_batches_eq!(expected, &batches);
    Ok(())
}

#[tokio::test]
async fn test_window_group_limit() -> Result<(), Box<dyn std::error::Error>> {
let session_ctx = SessionContext::new();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,58 @@ class AuronQuerySuite extends AuronQueryTest with BaseAuronSQLSuite with AuronSQ
}
}

// Compares native nth_value results against Spark over an explicit cumulative
// ROWS frame. Gated on isSparkV31OrGreater — presumably because nth_value is
// unavailable in earlier Spark versions (confirm against the version matrix);
// the IGNORE NULLS branch is additionally gated on isSparkV32OrGreater.
test("nth_value window with row frame") {
  if (AuronTestUtils.isSparkV31OrGreater) {
    withTable("t_nth_value") {
      // Fixture with nulls in `v` so the IGNORE NULLS variant is
      // distinguishable from the default behavior.
      sql("""
          |create table t_nth_value using parquet as
          |select * from values
          | (1, 1, cast(null as string)),
          | (1, 2, 'b'),
          | (1, 3, 'c'),
          | (2, 1, 'x'),
          | (2, 2, cast(null as string))
          |as t(grp, id, v)
          |""".stripMargin)

      if (AuronTestUtils.isSparkV32OrGreater) {
        // Both variants in one query; checkSparkAnswerAndOperator also
        // asserts the plan runs on the native operator.
        checkSparkAnswerAndOperator("""
            |select
            |  grp,
            |  id,
            |  v,
            |  nth_value(v, 2) over (
            |    partition by grp
            |    order by id
            |    rows between unbounded preceding and current row
            |  ) as nth_value_all,
            |  nth_value(v, 2) ignore nulls over (
            |    partition by grp
            |    order by id
            |    rows between unbounded preceding and current row
            |  ) as nth_value_ignore_nulls
            |from t_nth_value
            |order by grp, id
            |""".stripMargin)
      } else {
        // Older Spark: only the default (nulls counted) variant.
        checkSparkAnswerAndOperator("""
            |select
            |  grp,
            |  id,
            |  v,
            |  nth_value(v, 2) over (
            |    partition by grp
            |    order by id
            |    rows between unbounded preceding and current row
            |  ) as nth_value_all
            |from t_nth_value
            |order by grp, id
            |""".stripMargin)
      }
    }
  }
}

test("standard LEFT ANTI JOIN includes NULL keys") {
// This test verifies that standard LEFT ANTI JOIN correctly includes NULL keys
// NULL keys should be in the result because NULL never matches anything
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,19 @@ abstract class NativeWindowBase(
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)

// NOTE(review): NthValue is accessed reflectively rather than referenced
// directly — presumably so this file compiles against Spark versions where
// the class does not exist (the Scala test gates on Spark 3.1+); confirm
// against the supported Spark version matrix.

// Invokes a no-argument method on `expr` via reflection and casts the result.
private def invokeNoArg[T](expr: Expression, methodName: String): T =
  expr.getClass.getMethod(methodName).invoke(expr).asInstanceOf[T]

// True when the expression is Spark's NthValue window function, matched by
// simple class name to avoid a compile-time dependency on the class.
private def isNthValue(expr: Expression): Boolean = expr.getClass.getSimpleName == "NthValue"

// The value expression child of an NthValue, read reflectively.
private def nthValueInput(expr: Expression): Expression = invokeNoArg[Expression](expr, "input")

// The 1-based offset expression child of an NthValue, read reflectively.
private def nthValueOffset(expr: Expression): Expression =
  invokeNoArg[Expression](expr, "offset")

// The IGNORE NULLS flag of an NthValue, read reflectively.
private def nthValueIgnoreNulls(expr: Expression): Boolean =
  invokeNoArg[Boolean](expr, "ignoreNulls")

private def nativeWindowExprs = windowExpression.map { named =>
val field = NativeConverters.convertField(Util.getSchema(named :: Nil).fields(0))
val windowExprBuilder = pb.WindowExprNode.newBuilder().setField(field)
Expand Down Expand Up @@ -118,6 +131,22 @@ abstract class NativeWindowBase(
windowExprBuilder.setFuncType(pb.WindowFunctionType.Window)
windowExprBuilder.setWindowFunc(pb.WindowFunction.DENSE_RANK)

case e if isNthValue(e) =>
assert(
// Spark defaults ordered nth_value() to a RANGE frame. The current native executor
// only supports cumulative ROW frames, so keep the conversion scoped to the
// explicit ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW case.
spec.frameSpecification == RowNumber().frame, // only supports RowFrame(Unbounded, CurrentRow)
s"window frame not supported: ${spec.frameSpecification}")
windowExprBuilder.setFuncType(pb.WindowFunctionType.Window)
windowExprBuilder.setWindowFunc(if (nthValueIgnoreNulls(e)) {
pb.WindowFunction.NTH_VALUE_IGNORE_NULLS
} else {
pb.WindowFunction.NTH_VALUE
})
windowExprBuilder.addChildren(NativeConverters.convertExpr(nthValueInput(e)))
windowExprBuilder.addChildren(NativeConverters.convertExpr(nthValueOffset(e)))

case e: Sum =>
assert(
spec.frameSpecification == RowNumber().frame, // only supports RowFrame(Unbounded, CurrentRow)
Expand Down
Loading