-
Notifications
You must be signed in to change notification settings - Fork 0
/
replace_cross_joins.go
executable file
·134 lines (120 loc) · 4.45 KB
/
replace_cross_joins.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
// Copyright 2020-2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package analyzer
import (
"github.com/Rock-liyi/p2pdb-store/sql"
"github.com/Rock-liyi/p2pdb-store/sql/expression"
"github.com/Rock-liyi/p2pdb-store/sql/plan"
)
// comparisonSatisfiesJoinCondition checks a) whether a comparison is a valid join predicate,
// and b) whether the Left/Right children of a comparison expression cover the dependency trees
// of a plan.CrossJoin's children.
//
// A comparison qualifies only when it is one of the equality / ordering operators
// (including their null-safe variants) listed in the switch below, and both of its
// operands are bare column references (*expression.GetField). It satisfies the join
// when one operand resolves against the join's left schema and the other against its
// right schema, in either order.
func comparisonSatisfiesJoinCondition(expr expression.Comparer, j *plan.CrossJoin) bool {
	switch expr.(type) {
	case *expression.Equals, *expression.NullSafeEquals, *expression.GreaterThan,
		*expression.GreaterThanOrEqual, *expression.NullSafeGreaterThanOrEqual,
		*expression.NullSafeGreaterThan, *expression.LessThan, *expression.LessThanOrEqual,
		*expression.NullSafeLessThanOrEqual, *expression.NullSafeLessThan:
		// Supported operator; continue with the operand checks below.
	default:
		return false
	}

	// expr is statically an expression.Comparer, so its operands are reachable
	// directly; no second type assertion is needed.
	le, ok := expr.Left().(*expression.GetField)
	if !ok {
		return false
	}
	re, ok := expr.Right().(*expression.GetField)
	if !ok {
		return false
	}

	lCols := j.Left().Schema()
	rCols := j.Right().Schema()
	// The comparison must straddle the join: one column from each side.
	return lCols.Contains(le.Name(), le.Table()) && rCols.Contains(re.Name(), re.Table()) ||
		rCols.Contains(le.Name(), le.Table()) && lCols.Contains(re.Name(), re.Table())
}
// expressionCoversJoin reports whether some comparison nested inside c satisfies
// the join condition of j, as decided by comparisonSatisfiesJoinCondition. The
// input conjunctions have already been split, so it does not matter which of the
// subexpression's comparisons is the one that satisfies the join.
func expressionCoversJoin(c sql.Expression, j *plan.CrossJoin) bool {
	// Presumably expression.InspectUp stops and returns true as soon as the
	// callback returns true for some node — NOTE(review): confirm in walk.go.
	return expression.InspectUp(c, func(expr sql.Expression) bool {
		comparer, ok := expr.(expression.Comparer)
		return ok && comparisonSatisfiesJoinCondition(comparer, j)
	})
}
// replaceCrossJoins recursively replaces filter-nested cross joins with equivalent inner joins.
// There are 3 phases after we identify a Filter -> ... -> CrossJoin pattern.
//  1. Build a list of conjunct predicate expressions by top-down splitting conjunctions (AND).
//  2. For every CrossJoin, check whether a subset of predicates covers it as join conditions,
//     and create a new InnerJoin with the matching predicates. A predicate is only moved into
//     a join when every column it references resolves against that join's schema.
//  3. Remove predicates from the parent Filter that have been pushed into InnerJoins.
//
// ctx, a and scope are not used here; presumably they are part of the signature
// shared by analyzer rules.
func replaceCrossJoins(ctx *sql.Context, a *Analyzer, n sql.Node, scope *Scope) (sql.Node, error) {
	if !n.Resolved() {
		return n, nil
	}

	return plan.TransformUp(n, func(node sql.Node) (sql.Node, error) {
		f, ok := node.(*plan.Filter)
		if !ok {
			return node, nil
		}

		predicates := splitConjunction(f.Expression)
		// Indexes into predicates that were pushed into at least one InnerJoin.
		movedPredicates := make(map[int]struct{})

		// NOTE(review): this walk descends through every node below the Filter.
		// Pushing predicates beneath nodes such as outer joins or limits would change
		// query results — confirm such nodes cannot sit between the Filter and a CrossJoin.
		newF, err := plan.TransformUp(f, func(child sql.Node) (sql.Node, error) {
			cj, ok := child.(*plan.CrossJoin)
			if !ok {
				return child, nil
			}

			joinSchema := cj.Schema()
			var joinExprs []sql.Expression
			for i, p := range predicates {
				if !expressionCoversJoin(p, cj) {
					continue
				}
				// A predicate can match through a single comparison nested in an OR while
				// still referencing tables outside this join (e.g. `a.x = b.y OR b.z = c.w`
				// against b CROSS JOIN c). Such a predicate cannot be evaluated on this
				// join's rows, so leave it for an enclosing join or the Filter.
				if !exprColumnsWithinSchema(p, joinSchema) {
					continue
				}
				movedPredicates[i] = struct{}{}
				joinExprs = append(joinExprs, p)
			}
			if len(joinExprs) == 0 {
				return child, nil
			}
			return plan.NewInnerJoin(cj.Left(), cj.Right(), expression.JoinAnd(joinExprs...)), nil
		})
		if err != nil {
			return f, err
		}

		// Only alter the Filter expression tree if predicates were transferred to an InnerJoin.
		if len(movedPredicates) == 0 {
			return f, nil
		}

		// The inner callback only replaces CrossJoin nodes, so the root is still a Filter.
		newFilter := newF.(*plan.Filter)

		// Remove the Filter entirely if every predicate was transferred to a join.
		if len(movedPredicates) == len(predicates) {
			return newFilter.Child, nil
		}

		remaining := make([]sql.Expression, 0, len(predicates)-len(movedPredicates))
		for i, p := range predicates {
			if _, moved := movedPredicates[i]; !moved {
				remaining = append(remaining, p)
			}
		}
		return newFilter.WithExpressions(expression.JoinAnd(remaining...))
	})
}

// exprColumnsWithinSchema reports whether every column reference (*expression.GetField)
// inside e resolves against schema.
func exprColumnsWithinSchema(e sql.Expression, schema sql.Schema) bool {
	foundOutside := expression.InspectUp(e, func(expr sql.Expression) bool {
		gf, ok := expr.(*expression.GetField)
		return ok && !schema.Contains(gf.Name(), gf.Table())
	})
	return !foundOutside
}