Skip to content

Commit

Permalink
feat: support first_value/last_value in range query (#3448)
Browse files Browse the repository at this point in the history
* feat: support `first_value/last_value` in range query

* chore: add sqlness test on `count`

* chore: add test
  • Loading branch information
Taylor-lagrange committed Mar 11, 2024
1 parent 21ff362 commit 8c37c3f
Show file tree
Hide file tree
Showing 3 changed files with 616 additions and 6 deletions.
278 changes: 272 additions & 6 deletions src/query/src/range_select/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::any::Any;
use std::cmp::Ordering;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::Display;
Expand All @@ -21,8 +23,8 @@ use std::task::{Context, Poll};
use std::time::Duration;

use ahash::RandomState;
use arrow::compute::{self, cast_with_options, CastOptions};
use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
use arrow::compute::{self, cast_with_options, CastOptions, SortColumn};
use arrow_schema::{DataType, Field, Schema, SchemaRef, SortOptions, TimeUnit};
use common_query::DfPhysicalPlan;
use common_recordbatch::DfSendableRecordBatchStream;
use datafusion::common::{Result as DataFusionResult, Statistics};
Expand All @@ -35,10 +37,14 @@ use datafusion::physical_plan::{
SendableRecordBatchStream,
};
use datafusion::physical_planner::create_physical_sort_expr;
use datafusion_common::utils::get_arrayref_at_indices;
use datafusion_common::utils::{get_arrayref_at_indices, get_row_at_idx};
use datafusion_common::{DFField, DFSchema, DFSchemaRef, DataFusionError, ScalarValue};
use datafusion_expr::utils::exprlist_to_fields;
use datafusion_expr::{Accumulator, Expr, ExprSchemable, LogicalPlan, UserDefinedLogicalNodeCore};
use datafusion_expr::utils::{exprlist_to_fields, COUNT_STAR_EXPANSION};
use datafusion_expr::{
lit, Accumulator, AggregateFunction, Expr, ExprSchemable, LogicalPlan,
UserDefinedLogicalNodeCore,
};
use datafusion_physical_expr::aggregate::utils::down_cast_any_ref;
use datafusion_physical_expr::expressions::create_aggregate_expr as create_aggr_expr;
use datafusion_physical_expr::hash_utils::create_hashes;
use datafusion_physical_expr::{
Expand All @@ -58,6 +64,140 @@ use crate::error::{DataFusionSnafu, RangeQuerySnafu, Result};

type Millisecond = <TimestampMillisecondType as ArrowPrimitiveType>::Native;

/// Implementation of the `first_value`/`last_value` aggregate functions
/// adapted to range query: the winning row is selected by `order_bys`, and the
/// value of `expr` on that row is the aggregate result.
#[derive(Debug)]
struct RangeFirstListValue {
    /// The expression whose value is produced for the selected row.
    expr: Arc<dyn PhysicalExpr>,
    /// Sort expressions deciding which row wins. The caller pre-reverses the
    /// sort options for `last_value`, so the accumulator only ever has to keep
    /// the "first" row under this ordering.
    order_bys: Vec<PhysicalSortExpr>,
}

impl RangeFirstListValue {
    /// Builds the aggregate expression.
    ///
    /// `order_bys` must already encode the desired direction (reversed by the
    /// caller when implementing `last_value`).
    pub fn new_aggregate_expr(
        expr: Arc<dyn PhysicalExpr>,
        order_bys: Vec<PhysicalSortExpr>,
    ) -> Arc<dyn AggregateExpr> {
        Arc::new(Self { expr, order_bys })
    }
}

impl PartialEq<dyn Any> for RangeFirstListValue {
    /// Two instances are equal when the other value is also a
    /// `RangeFirstListValue` with the same argument expression and identical
    /// ordering requirements.
    fn eq(&self, other: &dyn Any) -> bool {
        match down_cast_any_ref(other).downcast_ref::<Self>() {
            Some(rhs) => {
                self.expr.eq(&rhs.expr) && self.order_bys.iter().eq(rhs.order_bys.iter())
            }
            None => false,
        }
    }
}

impl AggregateExpr for RangeFirstListValue {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Creates an accumulator that tracks the single winning row under the
    /// configured sort options.
    fn create_accumulator(&self) -> DataFusionResult<Box<dyn Accumulator>> {
        let options: Vec<_> = self.order_bys.iter().map(|order| order.options).collect();
        Ok(Box::new(RangeFirstListValueAcc::new(options)))
    }

    /// Input columns fed to the accumulator: every ORDER BY expression first,
    /// followed by the argument expression itself.
    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        self.order_bys
            .iter()
            .map(|order| order.expr.clone())
            .chain(std::iter::once(self.expr.clone()))
            .collect()
    }

    fn field(&self) -> DataFusionResult<Field> {
        unreachable!("AggregateExpr::field will not be used in range query")
    }

    fn state_fields(&self) -> DataFusionResult<Vec<Field>> {
        unreachable!("AggregateExpr::state_fields will not be used in range query")
    }
}

/// Accumulator for [`RangeFirstListValue`]: keeps the single best row seen so
/// far according to the configured sort options.
#[derive(Debug)]
pub struct RangeFirstListValueAcc {
    /// One `SortOptions` per ORDER BY column.
    pub sort_options: Vec<SortOptions>,
    /// The full best row found so far: the sort-key values followed by the
    /// trailing argument value (empty until the first `update_batch` call).
    /// Only the leading `sort_options.len()` entries participate in
    /// comparisons.
    pub sort_columns: Vec<ScalarValue>,
    /// Value of the aggregated expression on the best row, if any row was
    /// seen.
    pub data: Option<ScalarValue>,
}

impl RangeFirstListValueAcc {
    /// Creates an empty accumulator for the given per-column sort options.
    pub fn new(sort_options: Vec<SortOptions>) -> Self {
        Self {
            sort_options,
            sort_columns: vec![],
            data: None,
        }
    }
}

impl Accumulator for RangeFirstListValueAcc {
    /// Folds one batch into the accumulator.
    ///
    /// `values` layout: the ORDER BY columns first (one per entry of
    /// `self.sort_options`), followed by the argument column of the aggregate
    /// (see `RangeFirstListValue::expressions`).
    fn update_batch(&mut self, values: &[ArrayRef]) -> DataFusionResult<()> {
        // Pair each sort column with its options; the trailing argument
        // column is excluded by the zip with `sort_options`.
        let columns: Vec<_> = values
            .iter()
            .zip(self.sort_options.iter())
            .map(|(v, s)| SortColumn {
                values: v.clone(),
                options: Some(*s),
            })
            .collect();
        // finding the Top1 problem with complexity O(n):
        // `lexsort_to_indices` with limit 1 only has to locate the minimum
        // row under the lexicographic order, not sort the whole batch.
        // NOTE(review): `.value(0)` assumes a non-empty batch — an empty
        // batch would panic here. TODO confirm callers never pass one.
        let idx = compute::lexsort_to_indices(&columns, Some(1))?.value(0);
        // Full row at `idx`: sort-key values followed by the argument value.
        let vs = get_row_at_idx(values, idx as usize)?;
        // Decide whether this batch's winner beats the current best row by
        // comparing the sort keys column by column (lexicographic order).
        let need_update = self.data.is_none()
            || vs
                .iter()
                .zip(self.sort_columns.iter())
                .zip(self.sort_options.iter())
                .find_map(|((new_value, old_value), sort_option)| {
                    if new_value.is_null() && old_value.is_null() {
                        // Tie on this column; fall through to the next one.
                        None
                    } else if sort_option.nulls_first
                        && (new_value.is_null() || old_value.is_null())
                    {
                        // Exactly one side is null and nulls sort first:
                        // the null side wins.
                        Some(new_value.is_null())
                    } else {
                        // NOTE(review): when `nulls_first` is false a null
                        // value still reaches `partial_cmp`; verify that this
                        // orders nulls last as intended for nulls-last
                        // options.
                        new_value.partial_cmp(old_value).map(|x| {
                            (x == Ordering::Greater && sort_option.descending)
                                || (x == Ordering::Less && !sort_option.descending)
                        })
                    }
                })
                // Every compared column tied: keep the existing row.
                .unwrap_or(false);
        if need_update {
            self.sort_columns = vs;
            // The argument column sits right after the sort columns.
            self.data = Some(ScalarValue::try_from_array(
                &values[self.sort_options.len()],
                idx as usize,
            )?);
        }
        Ok(())
    }

    /// Returns the argument value of the best row, or `Null` when no rows
    /// were seen.
    fn evaluate(&self) -> DataFusionResult<ScalarValue> {
        Ok(self.data.clone().unwrap_or(ScalarValue::Null))
    }

    /// Shallow size only; heap memory held by `sort_columns`/`data` is not
    /// counted.
    fn size(&self) -> usize {
        std::mem::size_of_val(self)
    }

    fn state(&self) -> DataFusionResult<Vec<ScalarValue>> {
        unreachable!("Accumulator::state will not be used in range query")
    }

    fn merge_batch(&mut self, _states: &[ArrayRef]) -> DataFusionResult<()> {
        unreachable!("Accumulator::merge_batch will not be used in range query")
    }
}

#[derive(PartialEq, Eq, Debug, Hash, Clone)]
pub enum Fill {
Null,
Expand Down Expand Up @@ -271,6 +411,7 @@ pub struct RangeSelect {
pub align: Duration,
pub align_to: i64,
pub time_index: String,
pub time_expr: Expr,
pub by: Vec<Expr>,
pub schema: DFSchemaRef,
pub by_schema: DFSchemaRef,
Expand Down Expand Up @@ -382,6 +523,7 @@ impl RangeSelect {
align,
align_to,
time_index: time_index_name,
time_expr: time_index,
schema,
by_schema,
by,
Expand Down Expand Up @@ -440,6 +582,7 @@ impl UserDefinedLogicalNodeCore for RangeSelect {
range_expr: self.range_expr.clone(),
input: Arc::new(inputs[0].clone()),
time_index: self.time_index.clone(),
time_expr: self.time_expr.clone(),
schema: self.schema.clone(),
by: self.by.clone(),
by_schema: self.by_schema.clone(),
Expand All @@ -452,14 +595,28 @@ impl UserDefinedLogicalNodeCore for RangeSelect {
impl RangeSelect {
fn create_physical_expr_list(
    &self,
    is_count_aggr: bool,
    exprs: &[Expr],
    df_schema: &Arc<DFSchema>,
    schema: &Schema,
    session_state: &SessionState,
) -> DfResult<Vec<Arc<dyn PhysicalExpr>>> {
    let props = session_state.execution_props();
    let mut physical_exprs = Vec::with_capacity(exprs.len());
    for expr in exprs {
        // `count(*)` will be rewritten by `CountWildcardRule` into `count(1)`
        // when optimizing the logical plan. That rewrite happens after the
        // range plan rewrite, and by then the aggregate plan has already been
        // replaced by the custom range plan, so `CountWildcardRule` never
        // fires. Apply the same substitution manually while building the
        // physical plan.
        let physical = if is_count_aggr && matches!(expr, Expr::Wildcard) {
            create_physical_expr(&lit(COUNT_STAR_EXPANSION), df_schema, schema, props)?
        } else {
            create_physical_expr(expr, df_schema, schema, props)?
        };
        physical_exprs.push(physical);
    }
    Ok(physical_exprs)
}

Expand Down Expand Up @@ -488,6 +645,72 @@ impl RangeSelect {
.iter()
.map(|range_fn| {
let expr = match &range_fn.expr {
Expr::AggregateFunction(aggr)
if aggr.fun == AggregateFunction::FirstValue
|| aggr.fun == AggregateFunction::LastValue =>
{
// Because we only need to find the first_value/last_value,
// the complexity of sorting the entire batch is O(nlogn).
// We can sort the batch with limit 1.
// In this case, the algorithm degenerates into finding the Top1 problem with complexity O(n).
// We need reverse the sort order of last_value to correctly apply limit 1 when sorting.
let order_by = if let Some(exprs) = &aggr.order_by {
exprs
.iter()
.map(|x| {
create_physical_sort_expr(
x,
input_dfschema,
&input_schema,
session_state.execution_props(),
)
.map(|expr| {
// reverse the last_value sort
if aggr.fun == AggregateFunction::LastValue {
PhysicalSortExpr {
expr: expr.expr,
options: SortOptions {
descending: !expr.options.descending,
nulls_first: !expr.options.nulls_first,
},
}
} else {
expr
}
})
})
.collect::<DfResult<Vec<_>>>()?
} else {
// if user not assign order by, time index is needed as default ordering
let time_index = create_physical_expr(
&self.time_expr,
input_dfschema,
&input_schema,
session_state.execution_props(),
)?;
vec![PhysicalSortExpr {
expr: time_index,
options: SortOptions {
descending: aggr.fun == AggregateFunction::LastValue,
nulls_first: false,
},
}]
};
let arg = self.create_physical_expr_list(
false,
&aggr.args,
input_dfschema,
&input_schema,
session_state,
)?;
// first_value/last_value has only one param.
// The param have been checked by datafusion in logical plan stage.
// We can safely assume that there is only one element here.
Ok(RangeFirstListValue::new_aggregate_expr(
arg[0].clone(),
order_by,
))
}
Expr::AggregateFunction(aggr) => {
let order_by = if let Some(exprs) = &aggr.order_by {
exprs
Expand All @@ -508,6 +731,7 @@ impl RangeSelect {
&aggr.fun,
false,
&self.create_physical_expr_list(
aggr.fun == AggregateFunction::Count,
&aggr.args,
input_dfschema,
&input_schema,
Expand All @@ -523,6 +747,7 @@ impl RangeSelect {
let expr = create_aggr_udf_expr(
&aggr_udf.fun,
&self.create_physical_expr_list(
false,
&aggr_udf.args,
input_dfschema,
&input_schema,
Expand Down Expand Up @@ -564,6 +789,7 @@ impl RangeSelect {
align: self.align.as_millis() as Millisecond,
align_to: self.align_to,
by: self.create_physical_expr_list(
false,
&self.by,
input_dfschema,
&input_schema,
Expand Down Expand Up @@ -1447,4 +1673,44 @@ mod test {
Fill::Linear.apply_fill_strategy(&ts, &mut test1).unwrap();
assert_eq!(test, test1);
}

#[test]
// Renamed from `test_fist_last_accumulator` to fix the "fist" typo; the test
// harness discovers tests by the `#[test]` attribute, so the rename is safe.
fn test_first_last_accumulator() {
    // Two sort columns: the first prefers the largest non-null value
    // (descending, nulls last), the second the smallest value with nulls
    // winning ties (ascending, nulls first).
    let mut acc = RangeFirstListValueAcc::new(vec![
        SortOptions {
            descending: true,
            nulls_first: false,
        },
        SortOptions {
            descending: false,
            nulls_first: true,
        },
    ]);
    // Column layout matches `update_batch`: the two sort columns first, then
    // the value column the aggregate returns.
    let batch1: Vec<Arc<dyn Array>> = vec![
        Arc::new(nullable_array!(Float64;
            0.0, null, 0.0, null, 1.0
        )),
        Arc::new(nullable_array!(Float64;
            5.0, null, 4.0, null, 3.0
        )),
        Arc::new(nullable_array!(Int64;
            1, 2, 3, 4, 5
        )),
    ];
    let batch2: Vec<Arc<dyn Array>> = vec![
        Arc::new(nullable_array!(Float64;
            3.0, 3.0, 3.0, 3.0, 3.0
        )),
        Arc::new(nullable_array!(Float64;
            null,3.0, 3.0, 3.0, 3.0
        )),
        Arc::new(nullable_array!(Int64;
            6, 7, 8, 9, 10
        )),
    ];
    // Batch 1 winner: row 4 — its first-column value 1.0 beats every other
    // non-null entry under descending order -> value 5.
    acc.update_batch(&batch1).unwrap();
    assert_eq!(acc.evaluate().unwrap(), ScalarValue::Int64(Some(5)));
    // Batch 2: every first-column value (3.0) beats the held 1.0; within the
    // batch, the null in the second column wins the tie (nulls_first)
    // -> value 6.
    acc.update_batch(&batch2).unwrap();
    assert_eq!(acc.evaluate().unwrap(), ScalarValue::Int64(Some(6)));
}
}

0 comments on commit 8c37c3f

Please sign in to comment.