-
Notifications
You must be signed in to change notification settings - Fork 992
/
push_down_filter.rs
2973 lines (2704 loc) · 109 KB
/
push_down_filter.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! [`PushDownFilter`] applies filters as early as possible
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use itertools::Itertools;
use datafusion_common::tree_node::{
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{
internal_err, plan_datafusion_err, qualified_name, Column, DFSchema, DFSchemaRef,
JoinConstraint, Result,
};
use datafusion_expr::expr::Alias;
use datafusion_expr::expr_rewriter::replace_col;
use datafusion_expr::logical_plan::tree_node::unwrap_arc;
use datafusion_expr::logical_plan::{
CrossJoin, Join, JoinType, LogicalPlan, TableScan, Union,
};
use datafusion_expr::utils::{conjunction, split_conjunction, split_conjunction_owned};
use datafusion_expr::{
and, build_join_schema, or, BinaryExpr, Expr, Filter, LogicalPlanBuilder, Operator,
TableProviderFilterPushDown,
};
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
/// Optimizer rule for pushing (moving) filter expressions down in a plan so
/// they are applied as early as possible.
///
/// # Introduction
///
/// The goal of this rule is to improve query performance by eliminating
/// redundant work.
///
/// For example, given a plan that sorts all values where `a > 10`:
///
/// ```text
///  Filter (a > 10)
///    Sort (a, b)
/// ```
///
/// A better plan is to  filter the data *before* the Sort, which sorts fewer
/// rows and therefore does less work overall:
///
/// ```text
///  Sort (a, b)
///    Filter (a > 10)  <-- Filter is moved before the sort
/// ```
///
/// However it is not always possible to push filters down. For example, given a
/// plan that finds the top 3 values and then keeps only those that are greater
/// than 10, if the filter is pushed below the limit it would produce a
/// different result.
///
/// ```text
///  Filter (a > 10)   <-- can not move this Filter before the limit
///    Limit (fetch=3)
///      Sort (a, b)
/// ```
///
///
/// More formally, a filter-commutative operation is an operation `op` that
/// satisfies `filter(op(data)) = op(filter(data))`.
///
/// The filter-commutative property is plan and column-specific. A filter on `a`
/// can be pushed through an `Aggregate(group_by = [a], agg=[SUM(b)])`. However, a
/// filter on `SUM(b)` can not be pushed through the same aggregate.
///
/// # Handling Conjunctions
///
/// It is possible to only push down **part** of a filter expression if it is
/// connected with `AND`s (more formally if it is a "conjunction").
///
/// For example, given the following plan:
///
/// ```text
/// Filter(a > 10 AND SUM(b) < 5)
///   Aggregate(group_by = [a], agg = [SUM(b)])
/// ```
///
/// The `a > 10` is commutative with the `Aggregate` but `SUM(b) < 5` is not.
/// Therefore it is possible to only push part of the expression, resulting in:
///
/// ```text
/// Filter(SUM(b) < 5)
///   Aggregate(group_by = [a], agg = [SUM(b)])
///     Filter(a > 10)
/// ```
///
/// # Handling Column Aliases
///
/// This optimizer must sometimes handle re-writing filter expressions when they
/// are pushed, for example if there is a projection that aliases `a+1` to `"b"`:
///
/// ```text
/// Filter (b > 10)
///     Projection: [a+1 AS "b"]  <-- changes the name of `a+1` to `b`
/// ```
///
/// To apply the filter prior to the `Projection`, all references to `b` must be
/// rewritten to `a+1`:
///
/// ```text
/// Projection: a + 1 AS "b"
///     Filter: (a + 1 > 10)  <--- changed from b to a + 1
/// ```
/// # Implementation Notes
///
/// This implementation performs a single pass through the plan, "pushing" down
/// filters. When it passes through a filter, it stores that filter, and when it
/// reaches a plan node that does not commute with that filter, it adds the
/// filter to that place. When it passes through a projection, it re-writes the
/// filter's expression taking into account that projection.
#[derive(Default)]
pub struct PushDownFilter {}
// For a given JOIN logical plan, determine whether each side of the join is
// "preserved". A side is preserved when every output row maps directly back to
// a row of that input; a non-preserved side can contribute extra null-padded
// rows (e.g. the right side of a LEFT join).
//
// For example:
//   - In an inner join, both sides are preserved, because each row of the output
//     maps directly to a row from each side.
//   - In a left join, the left side is preserved and the right is not, because
//     there may be rows in the output that don't directly map to a row in the
//     right input (due to nulls filling where there is no match on the right).
//
// This matters because a post-join filter can always be pushed to a preserved
// side whose columns it references; pushing to a non-preserved side is not
// generally safe.
//
// Returns a tuple of booleans - (left_preserved, right_preserved).
fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
    match plan {
        // A cross join preserves both inputs.
        LogicalPlan::CrossJoin(_) => Ok((true, true)),
        LogicalPlan::Join(Join { join_type, .. }) => Ok(match join_type {
            JoinType::Inner => (true, true),
            // For semi/anti joins no columns from the non-driving side can
            // appear in output predicates, so the flag for that side is moot.
            JoinType::Left | JoinType::LeftSemi | JoinType::LeftAnti => (true, false),
            JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
                (false, true)
            }
            JoinType::Full => (false, false),
        }),
        _ => internal_err!("lr_is_preserved only valid for JOIN nodes"),
    }
}
// For a given JOIN logical plan, determine whether each side of the join is
// preserved with respect to the join's ON-clause filter.
// Conjuncts of the ON filter may only be pushed down to a side that is
// preserved in this sense.
//
// Returns a tuple of booleans - (left_preserved, right_preserved).
fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
    match plan {
        LogicalPlan::Join(Join { join_type, .. }) => Ok(match join_type {
            JoinType::Inner | JoinType::LeftSemi | JoinType::RightSemi => (true, true),
            JoinType::Left | JoinType::LeftAnti => (false, true),
            JoinType::Right | JoinType::RightAnti => (true, false),
            JoinType::Full => (false, false),
        }),
        // Cross joins have no ON filter, so this question is meaningless for them.
        LogicalPlan::CrossJoin(_) => {
            internal_err!("on_lr_is_preserved cannot be applied to CROSSJOIN nodes")
        }
        _ => internal_err!("on_lr_is_preserved only valid for JOIN nodes"),
    }
}
// Determine whether `predicate` can be pushed down to a join input with the
// given `schema`.
//
// A predicate is pushable to a side only if every column it references is
// present in that side's schema (the caller separately checks that the side
// is preserved). Schema columns are matched both in qualified and unqualified
// form, because the predicate may spell them either way.
fn can_pushdown_join_predicate(predicate: &Expr, schema: &DFSchema) -> Result<bool> {
    let schema_columns = schema
        .iter()
        .flat_map(|(qualifier, field)| {
            [
                Column::new(qualifier.cloned(), field.name()),
                // we need to push down filter using unqualified column as well
                Column::new_unqualified(field.name()),
            ]
        })
        .collect::<HashSet<_>>();
    let columns = predicate.to_columns()?;
    // All referenced columns must be covered by this side's schema. This is
    // equivalent to the former `intersection(...).len() == columns.len()`
    // check (both operands are sets) without building an intermediate HashSet.
    Ok(columns.iter().all(|c| schema_columns.contains(c)))
}
// Determine whether the predicate can evaluate as the join conditions
//
// A predicate qualifies when it is built only from plain column references,
// literals, placeholders, scalar variables and ordinary row-level operators.
// Subquery forms, outer references, unnest and scalar functions disqualify it.
// Aggregate/window/sort/wildcard/grouping-set expressions are not valid
// predicates at all and raise an internal error.
fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> {
    let mut is_evaluate = true;
    predicate.apply(|expr| match expr {
        // Leaf expressions that are always acceptable: no need to descend
        // (they have no children), so jump past them.
        Expr::Column(_)
        | Expr::Literal(_)
        | Expr::Placeholder(_)
        | Expr::ScalarVariable(_, _) => Ok(TreeNodeRecursion::Jump),
        // Expressions that make the predicate unusable as a join condition:
        // record the failure and stop traversing immediately.
        Expr::Exists { .. }
        | Expr::InSubquery(_)
        | Expr::ScalarSubquery(_)
        | Expr::OuterReferenceColumn(_, _)
        | Expr::Unnest(_)
        | Expr::ScalarFunction(_) => {
            is_evaluate = false;
            Ok(TreeNodeRecursion::Stop)
        }
        // Row-level operators are fine by themselves; keep inspecting their
        // children, which may still contain a disqualifying expression.
        Expr::Alias(_)
        | Expr::BinaryExpr(_)
        | Expr::Like(_)
        | Expr::SimilarTo(_)
        | Expr::Not(_)
        | Expr::IsNotNull(_)
        | Expr::IsNull(_)
        | Expr::IsTrue(_)
        | Expr::IsFalse(_)
        | Expr::IsUnknown(_)
        | Expr::IsNotTrue(_)
        | Expr::IsNotFalse(_)
        | Expr::IsNotUnknown(_)
        | Expr::Negative(_)
        | Expr::GetIndexedField(_)
        | Expr::Between(_)
        | Expr::Case(_)
        | Expr::Cast(_)
        | Expr::TryCast(_)
        | Expr::InList { .. } => Ok(TreeNodeRecursion::Continue),
        // These expression kinds are never valid inside a filter predicate.
        Expr::Sort(_)
        | Expr::AggregateFunction(_)
        | Expr::WindowFunction(_)
        | Expr::Wildcard { .. }
        | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"),
    })?;
    Ok(is_evaluate)
}
// examine OR clause to see if any useful clauses can be extracted and push down.
// extract at least one qual from each sub clauses of OR clause, then form the quals
// to new OR clause as predicate.
//
// Filter: (a = c and a < 20) or (b = d and b > 10)
//     join/crossjoin:
//          TableScan: projection=[a, b]
//          TableScan: projection=[c, d]
//
// is optimized to
//
// Filter: (a = c and a < 20) or (b = d and b > 10)
//     join/crossjoin:
//          Filter: (a < 20) or (b > 10)
//              TableScan: projection=[a, b]
//          TableScan: projection=[c, d]
//
// In general, predicates of this form:
//
// (A AND B) OR (C AND D)
//
// will be transformed to
//
// ((A AND B) OR (C AND D)) AND (A OR C)
//
// OR
//
// ((A AND B) OR (C AND D)) AND ((A AND B) OR C)
//
// OR
//
// do nothing.
//
fn extract_or_clauses_for_join<'a>(
    filters: &'a [Expr],
    schema: &'a DFSchema,
) -> impl Iterator<Item = Expr> + 'a {
    // Collect both the qualified and unqualified spelling of every column on
    // this join side, so predicates can match either form.
    let schema_columns: HashSet<_> = schema
        .iter()
        .flat_map(|(qualifier, field)| {
            [
                Column::new(qualifier.cloned(), field.name()),
                // we need to push down filter using unqualified column as well
                Column::new_unqualified(field.name()),
            ]
        })
        .collect();

    // For each top-level OR, try to extract a pushable sub-clause from *both*
    // branches; if either branch yields nothing, the whole OR is skipped
    // (otherwise the derived predicate could wrongly reject rows).
    filters.iter().filter_map(move |expr| match expr {
        Expr::BinaryExpr(BinaryExpr {
            left,
            op: Operator::Or,
            right,
        }) => {
            let left_extracted = extract_or_clause(left, &schema_columns)?;
            let right_extracted = extract_or_clause(right, &schema_columns)?;
            Some(or(left_extracted, right_extracted))
        }
        _ => None,
    })
}
// extract qual from OR sub-clause.
//
// A qual is extracted if it only contains set of column references in schema_columns.
//
// For AND clause, we extract from both sub-clauses, then make new AND clause by extracted
// clauses if both extracted; Otherwise, use the extracted clause from any sub-clauses or None.
//
// For OR clause, we extract from both sub-clauses, then make new OR clause by extracted clauses if both extracted;
// Otherwise, return None.
//
// For other clause, apply the rule above to extract clause.
//
// Returning `None` is always safe here: it only means no extra (redundant)
// predicate is derived for this sub-clause.
fn extract_or_clause(expr: &Expr, schema_columns: &HashSet<Column>) -> Option<Expr> {
    match expr {
        // OR: both branches must yield something, otherwise the derived
        // predicate would not be implied by the original expression.
        Expr::BinaryExpr(BinaryExpr {
            left: l_expr,
            op: Operator::Or,
            right: r_expr,
        }) => {
            let l_expr = extract_or_clause(l_expr, schema_columns)?;
            let r_expr = extract_or_clause(r_expr, schema_columns)?;
            Some(or(l_expr, r_expr))
        }
        // AND: either branch alone is already implied, so keep whatever we got.
        Expr::BinaryExpr(BinaryExpr {
            left: l_expr,
            op: Operator::And,
            right: r_expr,
        }) => {
            let l_expr = extract_or_clause(l_expr, schema_columns);
            let r_expr = extract_or_clause(r_expr, schema_columns);
            match (l_expr, r_expr) {
                (Some(l_expr), Some(r_expr)) => Some(and(l_expr, r_expr)),
                (Some(l_expr), None) => Some(l_expr),
                (None, Some(r_expr)) => Some(r_expr),
                (None, None) => None,
            }
        }
        // Any other expression is extractable as-is iff all of its column
        // references belong to this join side's schema.
        _ => {
            // Previously this used `.ok().unwrap()` and would panic if column
            // collection failed; treating failure as "nothing extractable" is
            // the conservative, non-panicking choice.
            let columns = expr.to_columns().ok()?;
            columns
                .iter()
                .all(|c| schema_columns.contains(c))
                .then(|| expr.clone())
        }
    }
}
// push down join/cross-join
//
// Routes `predicates` (conjuncts from a parent Filter), `infer_predicates`
// (extra predicates inferred from inner-join equi-keys) and `on_filter`
// (conjuncts of the join's ON filter) into the join's inputs where legal,
// then rebuilds the join and re-wraps anything that could not be pushed.
fn push_down_all_join(
    predicates: Vec<Expr>,
    infer_predicates: Vec<Expr>,
    join_plan: &LogicalPlan,
    left: &LogicalPlan,
    right: &LogicalPlan,
    on_filter: Vec<Expr>,
    is_inner_join: bool,
) -> Result<Transformed<LogicalPlan>> {
    // Remember whether the join originally had an ON filter; used below when
    // rebuilding the join's expression list.
    let on_filter_empty = on_filter.is_empty();
    // Get pushable predicates from current optimizer state
    let (left_preserved, right_preserved) = lr_is_preserved(join_plan)?;

    // The predicates can be divided to three categories:
    // 1) can push through join to its children(left or right)
    // 2) can be converted to join conditions if the join type is Inner
    // 3) should be kept as filter conditions
    let left_schema = left.schema();
    let right_schema = right.schema();
    let mut left_push = vec![];
    let mut right_push = vec![];
    let mut keep_predicates = vec![];
    let mut join_conditions = vec![];
    for predicate in predicates {
        if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? {
            left_push.push(predicate);
        } else if right_preserved
            && can_pushdown_join_predicate(&predicate, right_schema)?
        {
            right_push.push(predicate);
        } else if is_inner_join && can_evaluate_as_join_condition(&predicate)? {
            // Here we do not differ it is eq or non-eq predicate, ExtractEquijoinPredicate will extract the eq predicate
            // and convert to the join on condition
            join_conditions.push(predicate);
        } else {
            keep_predicates.push(predicate);
        }
    }

    // For infer predicates, if they can not push through join, just drop them
    // (they are derived duplicates, so dropping them loses no correctness).
    for predicate in infer_predicates {
        if left_preserved && can_pushdown_join_predicate(&predicate, left_schema)? {
            left_push.push(predicate);
        } else if right_preserved
            && can_pushdown_join_predicate(&predicate, right_schema)?
        {
            right_push.push(predicate);
        }
    }

    // ON-filter conjuncts are routed by the ON-specific preservation rules;
    // anything unpushable stays part of the join condition.
    if !on_filter.is_empty() {
        let (on_left_preserved, on_right_preserved) = on_lr_is_preserved(join_plan)?;
        for on in on_filter {
            if on_left_preserved && can_pushdown_join_predicate(&on, left_schema)? {
                left_push.push(on)
            } else if on_right_preserved
                && can_pushdown_join_predicate(&on, right_schema)?
            {
                right_push.push(on)
            } else {
                join_conditions.push(on)
            }
        }
    }

    // Extract from OR clause, generate new predicates for both side of join if possible.
    // We only track the unpushable predicates above.
    if left_preserved {
        left_push.extend(extract_or_clauses_for_join(&keep_predicates, left_schema));
        left_push.extend(extract_or_clauses_for_join(&join_conditions, left_schema));
    }
    if right_preserved {
        right_push.extend(extract_or_clauses_for_join(&keep_predicates, right_schema));
        right_push.extend(extract_or_clauses_for_join(&join_conditions, right_schema));
    }

    // Wrap each input in a new Filter when anything was pushed to it.
    let left = match conjunction(left_push) {
        Some(predicate) => {
            LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(left.clone()))?)
        }
        None => left.clone(),
    };
    let right = match conjunction(right_push) {
        Some(predicate) => {
            LogicalPlan::Filter(Filter::try_new(predicate, Arc::new(right.clone()))?)
        }
        None => right.clone(),
    };

    // Create a new Join with the new `left` and `right`
    //
    // expressions() output for Join is a vector consisting of
    //   1. join keys - columns mentioned in ON clause
    //   2. optional predicate - in case join filter is not empty,
    //      it always will be the last element, otherwise result
    //      vector will contain only join keys (without additional
    //      element representing filter).
    let mut exprs = join_plan.expressions();
    if !on_filter_empty {
        // Drop the old join filter; it is replaced by `join_conditions` below.
        exprs.pop();
    }
    // AND together the collected join conditions (if any) and append as the
    // join's new filter expression.
    exprs.extend(join_conditions.into_iter().reduce(Expr::and));
    let plan = join_plan.with_new_exprs(exprs, vec![left, right])?;

    // wrap the join on the filter whose predicates must be kept
    match conjunction(keep_predicates) {
        Some(predicate) => {
            let new_filter_plan = Filter::try_new(predicate, Arc::new(plan))?;
            Ok(Transformed::yes(LogicalPlan::Filter(new_filter_plan)))
        }
        None => Ok(Transformed::no(plan)),
    }
}
/// Pushes `parent_predicate` (the predicate of a `Filter` directly above the
/// join, if any) plus the join's own ON-filter conjuncts down into `join`.
///
/// For inner joins, equi-join key pairs are additionally used to infer a
/// mirrored copy of each predicate for the opposite side (e.g. with
/// `t1.id = t2.uid`, the filter `t1.id > 1` also yields `t2.uid > 1`), so
/// both inputs can be filtered early.
///
/// Returns the (possibly) rewritten plan; `Transformed::no` only when there
/// was nothing at all to push down.
fn push_down_join(
    plan: &LogicalPlan,
    join: &Join,
    parent_predicate: Option<&Expr>,
) -> Result<Transformed<LogicalPlan>> {
    // Split the parent predicate into individual conjunctive parts.
    let predicates = parent_predicate
        .map_or_else(Vec::new, |pred| split_conjunction_owned(pred.clone()));

    // Extract conjunctions from the JOIN's ON filter, if present.
    let on_filters = join
        .filter
        .as_ref()
        .map_or_else(Vec::new, |filter| split_conjunction_owned(filter.clone()));

    let mut is_inner_join = false;
    let infer_predicates = if join.join_type == JoinType::Inner {
        is_inner_join = true;

        // Only allow both side key is column.
        let join_col_keys = join
            .on
            .iter()
            .filter_map(|(l, r)| {
                let left_col = l.try_into_col().ok()?;
                let right_col = r.try_into_col().ok()?;
                Some((left_col, right_col))
            })
            .collect::<Vec<_>>();

        // TODO refine the logic, introduce EquivalenceProperties to logical plan and infer additional filters to push down
        // For inner joins, duplicate filters for joined columns so filters can be pushed down
        // to both sides. Take the following query as an example:
        //
        // ```sql
        // SELECT * FROM t1 JOIN t2 on t1.id = t2.uid WHERE t1.id > 1
        // ```
        //
        // `t1.id > 1` predicate needs to be pushed down to t1 table scan, while
        // `t2.uid > 1` predicate needs to be pushed down to t2 table scan.
        //
        // Join clauses with `Using` constraints also take advantage of this logic to make sure
        // predicates reference the shared join columns are pushed to both sides.
        // This logic should also been applied to conditions in JOIN ON clause
        predicates
            .iter()
            .chain(on_filters.iter())
            .filter_map(|predicate| {
                // Map every column that participates in an equi-key pair to
                // its counterpart on the other side of the join.
                let mut join_cols_to_replace = HashMap::new();
                let columns = match predicate.to_columns() {
                    Ok(columns) => columns,
                    Err(e) => return Some(Err(e)),
                };
                for col in columns.iter() {
                    for (l, r) in join_col_keys.iter() {
                        if col == l {
                            join_cols_to_replace.insert(col, r);
                            break;
                        } else if col == r {
                            join_cols_to_replace.insert(col, l);
                            break;
                        }
                    }
                }
                // Predicates that reference no join-key column produce no
                // inferred counterpart.
                if join_cols_to_replace.is_empty() {
                    return None;
                }
                let join_side_predicate =
                    match replace_col(predicate.clone(), &join_cols_to_replace) {
                        Ok(p) => p,
                        Err(e) => {
                            return Some(Err(e));
                        }
                    };
                Some(Ok(join_side_predicate))
            })
            .collect::<Result<Vec<_>>>()?
    } else {
        vec![]
    };

    // Nothing to push down at all: leave the plan untouched.
    if on_filters.is_empty() && predicates.is_empty() && infer_predicates.is_empty() {
        return Ok(Transformed::no(plan.clone()));
    }

    // Delegate the actual routing of predicates to `push_down_all_join`. The
    // result is always marked transformed because the join node is rebuilt.
    // (`.map` replaces a redundant `match Ok/Err` re-wrap.)
    push_down_all_join(
        predicates,
        infer_predicates,
        plan,
        &join.left,
        &join.right,
        on_filters,
        is_inner_join,
    )
    .map(|transformed| Transformed::yes(transformed.data))
}
impl OptimizerRule for PushDownFilter {
    // Legacy entry point: this rule opts into the `rewrite` API (see
    // `supports_rewrite`), so reaching this method is an internal error.
    fn try_optimize(
        &self,
        _plan: &LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Option<LogicalPlan>> {
        internal_err!("Should have called PushDownFilter::rewrite")
    }

    fn name(&self) -> &str {
        "push_down_filter"
    }

    // Filters are pushed from the top of the plan toward the leaves.
    fn apply_order(&self) -> Option<ApplyOrder> {
        Some(ApplyOrder::TopDown)
    }

    fn supports_rewrite(&self) -> bool {
        true
    }

    // Handles one Filter (or bare Join) node per invocation; the optimizer's
    // top-down traversal drives repeated application down the plan.
    fn rewrite(
        &self,
        plan: LogicalPlan,
        _config: &dyn OptimizerConfig,
    ) -> Result<Transformed<LogicalPlan>> {
        // Only Filter nodes are rewritten here; a bare Join still gets its
        // ON-filter conjuncts pushed into its inputs.
        let filter = match plan {
            LogicalPlan::Filter(ref filter) => filter,
            LogicalPlan::Join(ref join) => return push_down_join(&plan, join, None),
            _ => return Ok(Transformed::no(plan)),
        };

        // Dispatch on the Filter's child: each arm either pushes the predicate
        // through the child or leaves the plan untouched.
        let child_plan = filter.input.as_ref();
        let new_plan = match child_plan {
            // Filter over Filter: merge the two predicates (deduplicating
            // conjuncts already present in the parent) and recurse.
            LogicalPlan::Filter(ref child_filter) => {
                let parents_predicates = split_conjunction(&filter.predicate);
                let set: HashSet<&&Expr> = parents_predicates.iter().collect();

                // Parent conjuncts first, then child conjuncts not already seen.
                let new_predicates = parents_predicates
                    .iter()
                    .chain(
                        split_conjunction(&child_filter.predicate)
                            .iter()
                            .filter(|e| !set.contains(e)),
                    )
                    .map(|e| (*e).clone())
                    .collect::<Vec<_>>();
                let new_predicate = conjunction(new_predicates).ok_or_else(|| {
                    plan_datafusion_err!("at least one expression exists")
                })?;
                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    child_filter.input.clone(),
                )?);
                // Recurse so the merged filter keeps pushing downward.
                self.rewrite(new_filter, _config)?.data
            }
            // These operators fully commute with any filter: simply swap the
            // Filter below the child node.
            LogicalPlan::Repartition(_)
            | LogicalPlan::Distinct(_)
            | LogicalPlan::Sort(_) => {
                let new_filter = plan.with_new_exprs(
                    plan.expressions(),
                    vec![child_plan.inputs()[0].clone()],
                )?;
                child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
            }
            // SubqueryAlias: rewrite predicate columns from the alias'
            // qualified names back to the underlying input's names, then push
            // the filter beneath the alias.
            LogicalPlan::SubqueryAlias(ref subquery_alias) => {
                let mut replace_map = HashMap::new();
                for (i, (qualifier, field)) in
                    subquery_alias.input.schema().iter().enumerate()
                {
                    let (sub_qualifier, sub_field) =
                        subquery_alias.schema.qualified_field(i);
                    replace_map.insert(
                        qualified_name(sub_qualifier, sub_field.name()),
                        Expr::Column(Column::new(qualifier.cloned(), field.name())),
                    );
                }
                let new_predicate =
                    replace_cols_by_name(filter.predicate.clone(), &replace_map)?;

                let new_filter = LogicalPlan::Filter(Filter::try_new(
                    new_predicate,
                    subquery_alias.input.clone(),
                )?);
                child_plan.with_new_exprs(child_plan.expressions(), vec![new_filter])?
            }
            LogicalPlan::Projection(ref projection) => {
                // A projection is filter-commutable if it do not contain volatile predicates or contain volatile
                // predicates that are not used in the filter. However, we should re-writes all predicate expressions.
                // collect projection.
                let (volatile_map, non_volatile_map): (HashMap<_, _>, HashMap<_, _>) =
                    projection
                        .schema
                        .iter()
                        .enumerate()
                        .map(|(i, (qualifier, field))| {
                            // strip alias, as they should not be part of filters
                            let expr = match &projection.expr[i] {
                                Expr::Alias(Alias { expr, .. }) => expr.as_ref().clone(),
                                expr => expr.clone(),
                            };

                            (qualified_name(qualifier, field.name()), expr)
                        })
                        // NOTE: `is_volatile` errors are treated as volatile
                        // (kept above the projection) — the safe default.
                        .partition(|(_, value)| value.is_volatile().unwrap_or(true));

                // Conjuncts touching a volatile projected expression must stay
                // above the projection; the rest can be pushed below it.
                let mut push_predicates = vec![];
                let mut keep_predicates = vec![];
                for expr in split_conjunction_owned(filter.predicate.clone()).into_iter()
                {
                    if contain(&expr, &volatile_map) {
                        keep_predicates.push(expr);
                    } else {
                        push_predicates.push(expr);
                    }
                }

                match conjunction(push_predicates) {
                    Some(expr) => {
                        // re-write all filters based on this projection
                        // E.g. in `Filter: b\n  Projection: a > 1 as b`, we can swap them, but the filter must be "a > 1"
                        let new_filter = LogicalPlan::Filter(Filter::try_new(
                            replace_cols_by_name(expr, &non_volatile_map)?,
                            projection.input.clone(),
                        )?);

                        match conjunction(keep_predicates) {
                            None => child_plan.with_new_exprs(
                                child_plan.expressions(),
                                vec![new_filter],
                            )?,
                            Some(keep_predicate) => {
                                let child_plan = child_plan.with_new_exprs(
                                    child_plan.expressions(),
                                    vec![new_filter],
                                )?;
                                LogicalPlan::Filter(Filter::try_new(
                                    keep_predicate,
                                    Arc::new(child_plan),
                                )?)
                            }
                        }
                    }
                    None => return Ok(Transformed::no(plan)),
                }
            }
            // Union: replicate the filter onto every input, rewriting column
            // names from the union's output schema to each input's schema.
            LogicalPlan::Union(ref union) => {
                let mut inputs = Vec::with_capacity(union.inputs.len());
                for input in &union.inputs {
                    let mut replace_map = HashMap::new();
                    for (i, (qualifier, field)) in input.schema().iter().enumerate() {
                        let (union_qualifier, union_field) =
                            union.schema.qualified_field(i);
                        replace_map.insert(
                            qualified_name(union_qualifier, union_field.name()),
                            Expr::Column(Column::new(qualifier.cloned(), field.name())),
                        );
                    }

                    let push_predicate =
                        replace_cols_by_name(filter.predicate.clone(), &replace_map)?;
                    inputs.push(Arc::new(LogicalPlan::Filter(Filter::try_new(
                        push_predicate,
                        input.clone(),
                    )?)))
                }
                LogicalPlan::Union(Union {
                    inputs,
                    schema: plan.schema().clone(),
                })
            }
            // Aggregate: only conjuncts that reference group-by columns
            // exclusively may be pushed below the aggregation.
            LogicalPlan::Aggregate(ref agg) => {
                // We can push down Predicate which in groupby_expr.
                let group_expr_columns = agg
                    .group_expr
                    .iter()
                    .map(|e| Ok(Column::from_qualified_name(e.display_name()?)))
                    .collect::<Result<HashSet<_>>>()?;

                let predicates = split_conjunction_owned(filter.predicate.clone());

                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.to_columns()?;
                    if cols.iter().all(|c| group_expr_columns.contains(c)) {
                        push_predicates.push(expr);
                    } else {
                        keep_predicates.push(expr);
                    }
                }

                // As for plan Filter: Column(a+b) > 0 -- Agg: groupby:[Column(a)+Column(b)]
                // After push, we need to replace `a+b` with Column(a)+Column(b)
                // So we need create a replace_map, add {`a+b` --> Expr(Column(a)+Column(b))}
                let mut replace_map = HashMap::new();
                for expr in &agg.group_expr {
                    replace_map.insert(expr.display_name()?, expr.clone());
                }
                let replaced_push_predicates = push_predicates
                    .iter()
                    .map(|expr| replace_cols_by_name(expr.clone(), &replace_map))
                    .collect::<Result<Vec<_>>>()?;

                let child = match conjunction(replaced_push_predicates) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        agg.input.clone(),
                    )?),
                    None => (*agg.input).clone(),
                };
                let new_agg = filter
                    .input
                    .with_new_exprs(filter.input.expressions(), vec![child])?;
                match conjunction(keep_predicates) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        Arc::new(new_agg),
                    )?),
                    None => new_agg,
                }
            }
            // Join: delegate to the join-specific pushdown with this Filter's
            // predicate as the parent predicate.
            LogicalPlan::Join(ref join) => {
                push_down_join(
                    &unwrap_arc(filter.clone().input),
                    join,
                    Some(&filter.predicate),
                )?
                .data
            }
            // CrossJoin: temporarily treat it as an inner join (which the
            // pushdown machinery understands), then convert back if the
            // result still has no join keys or filter.
            LogicalPlan::CrossJoin(ref cross_join) => {
                let predicates = split_conjunction_owned(filter.predicate.clone());
                let join = convert_cross_join_to_inner_join(cross_join.clone())?;
                let join_plan = LogicalPlan::Join(join);
                let inputs = join_plan.inputs();
                let left = inputs[0];
                let right = inputs[1];
                let plan = push_down_all_join(
                    predicates,
                    vec![],
                    &join_plan,
                    left,
                    right,
                    vec![],
                    true,
                )?;
                convert_to_cross_join_if_beneficial(plan.data)?
            }
            // TableScan: offer the conjuncts to the table provider; Exact
            // filters are absorbed by the scan, the rest remain in a Filter.
            LogicalPlan::TableScan(ref scan) => {
                let filter_predicates = split_conjunction(&filter.predicate);
                let results = scan
                    .source
                    .supports_filters_pushdown(filter_predicates.as_slice())?;
                if filter_predicates.len() != results.len() {
                    return internal_err!(
                        "Vec returned length: {} from supports_filters_pushdown is not the same size as the filters passed, which length is: {}",
                        results.len(),
                        filter_predicates.len());
                }

                let zip = filter_predicates.iter().zip(results);

                // Inexact and Exact filters both go into the scan's filter
                // list (deduplicated against filters already there).
                let new_scan_filters = zip
                    .clone()
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Unsupported)
                    .map(|(pred, _)| *pred);
                let new_scan_filters: Vec<Expr> = scan
                    .filters
                    .iter()
                    .chain(new_scan_filters)
                    .unique()
                    .cloned()
                    .collect();
                // Everything that is not Exact must still be re-applied above
                // the scan for correctness.
                let new_predicate: Vec<Expr> = zip
                    .filter(|(_, res)| res != &TableProviderFilterPushDown::Exact)
                    .map(|(pred, _)| (*pred).clone())
                    .collect();

                let new_scan = LogicalPlan::TableScan(TableScan {
                    source: scan.source.clone(),
                    projection: scan.projection.clone(),
                    projected_schema: scan.projected_schema.clone(),
                    table_name: scan.table_name.clone(),
                    filters: new_scan_filters,
                    fetch: scan.fetch,
                });

                match conjunction(new_predicate) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        Arc::new(new_scan),
                    )?),
                    None => new_scan,
                }
            }
            // Extension (user-defined node): push only conjuncts that avoid
            // the node's declared "prevent pushdown" columns; the pushed
            // predicate is replicated onto every child input.
            LogicalPlan::Extension(ref extension_plan) => {
                let prevent_cols =
                    extension_plan.node.prevent_predicate_push_down_columns();

                let predicates = split_conjunction_owned(filter.predicate.clone());

                let mut keep_predicates = vec![];
                let mut push_predicates = vec![];
                for expr in predicates {
                    let cols = expr.to_columns()?;
                    if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
                        keep_predicates.push(expr);
                    } else {
                        push_predicates.push(expr);
                    }
                }

                let new_children = match conjunction(push_predicates) {
                    Some(predicate) => extension_plan
                        .node
                        .inputs()
                        .into_iter()
                        .map(|child| {
                            Ok(LogicalPlan::Filter(Filter::try_new(
                                predicate.clone(),
                                Arc::new(child.clone()),
                            )?))
                        })
                        .collect::<Result<Vec<_>>>()?,
                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
                };
                // extension with new inputs.
                let new_extension =
                    child_plan.with_new_exprs(child_plan.expressions(), new_children)?;

                match conjunction(keep_predicates) {
                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
                        predicate,
                        Arc::new(new_extension),
                    )?),
                    None => new_extension,
                }
            }
            // Any other child node: no pushdown rule applies.
            _ => return Ok(Transformed::no(plan)),
        };
        Ok(Transformed::yes(new_plan))
    }
}
impl PushDownFilter {
#[allow(missing_docs)]
pub fn new() -> Self {
Self {}
}
}
/// Converts the given cross join into an equivalent inner join that has no
/// equi-join keys and no filter condition, so the generic join pushdown
/// machinery can operate on it.
fn convert_cross_join_to_inner_join(cross_join: CrossJoin) -> Result<Join> {
    let CrossJoin { left, right, .. } = cross_join;
    // The inner-join schema of the two inputs is identical in shape to the
    // cross join's output schema.
    let schema = build_join_schema(left.schema(), right.schema(), &JoinType::Inner)?;
    Ok(Join {
        join_type: JoinType::Inner,
        join_constraint: JoinConstraint::On,
        on: vec![],
        filter: None,
        schema: DFSchemaRef::new(schema),
        null_equals_null: true,
        left,
        right,
    })
}
/// Converts the given inner join with an empty equality predicate and an
/// empty filter condition to a cross join.
fn convert_to_cross_join_if_beneficial(plan: LogicalPlan) -> Result<LogicalPlan> {
if let LogicalPlan::Join(join) = &plan {
// Can be converted back to cross join
if join.on.is_empty() && join.filter.is_none() {
return LogicalPlanBuilder::from(join.left.as_ref().clone())
.cross_join(join.right.as_ref().clone())?
.build();
}
} else if let LogicalPlan::Filter(filter) = &plan {
let new_input =
convert_to_cross_join_if_beneficial(filter.input.as_ref().clone())?;
return Filter::try_new(filter.predicate.clone(), Arc::new(new_input))
.map(LogicalPlan::Filter);
}