-
Notifications
You must be signed in to change notification settings - Fork 28k
/
InjectRuntimeFilter.scala
320 lines (300 loc) · 14.3 KB
/
InjectRuntimeFilter.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.BloomFilterAggregate
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.TreePattern.{INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY, PYTHON_UDF, REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE, SCALA_UDF}
import org.apache.spark.sql.internal.SQLConf
/**
* Insert a runtime filter on one side of the join (we call this side the application side) if
* we can extract a runtime filter from the other side (creation side). A simple case is that
* the creation side is a table scan with a selective filter.
* The runtime filter is logically an IN subquery with the join keys. Currently it's always
* bloom filter but we may add other physical implementations in the future.
*/
object InjectRuntimeFilter extends Rule[LogicalPlan] with PredicateHelper with JoinSelectionHelper {

  /**
   * Injects a runtime filter keyed by `filterApplicationSideKey` on top of
   * `filterApplicationSidePlan`, built from `filterCreationSidePlan` keyed by
   * `filterCreationSideKey`. Currently this always delegates to the Bloom filter
   * implementation; other physical implementations may be added in the future.
   */
  private def injectFilter(
      filterApplicationSideKey: Expression,
      filterApplicationSidePlan: LogicalPlan,
      filterCreationSideKey: Expression,
      filterCreationSidePlan: LogicalPlan): LogicalPlan = {
    injectBloomFilter(
      filterApplicationSideKey,
      filterApplicationSidePlan,
      filterCreationSideKey,
      filterCreationSidePlan
    )
  }

  /**
   * Builds a Bloom filter over the hashed creation-side key (as a scalar subquery with a
   * single [[BloomFilterAggregate]]) and wraps the application side in a [[Filter]] that
   * tests might-contain on the hashed application-side key.
   *
   * Returns `filterApplicationSidePlan` unchanged when the creation side's estimated size
   * exceeds the configured creation-side threshold, since building the filter would then be
   * too expensive.
   */
  private def injectBloomFilter(
      filterApplicationSideKey: Expression,
      filterApplicationSidePlan: LogicalPlan,
      filterCreationSideKey: Expression,
      filterCreationSidePlan: LogicalPlan): LogicalPlan = {
    // Skip if the filter creation side is too big
    if (filterCreationSidePlan.stats.sizeInBytes > conf.runtimeFilterCreationSideThreshold) {
      return filterApplicationSidePlan
    }
    val rowCount = filterCreationSidePlan.stats.rowCount
    // Size the Bloom filter from the estimated row count when stats provide a positive one;
    // otherwise fall back to the aggregate's default expected item count.
    val bloomFilterAgg =
      if (rowCount.isDefined && rowCount.get.longValue > 0L) {
        new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideKey)), rowCount.get.longValue)
      } else {
        new BloomFilterAggregate(new XxHash64(Seq(filterCreationSideKey)))
      }
    val alias = Alias(bloomFilterAgg.toAggregateExpression(), "bloomFilter")()
    // Pre-optimize the one-row aggregate that computes the Bloom filter (prune unused
    // columns, fold constants) so the injected subquery plan is as small as possible.
    val aggregate =
      ConstantFolding(ColumnPruning(Aggregate(Nil, Seq(alias), filterCreationSidePlan)))
    val bloomFilterSubquery = ScalarSubquery(aggregate, Nil)
    val filter = BloomFilterMightContain(bloomFilterSubquery,
      new XxHash64(Seq(filterApplicationSideKey)))
    Filter(filter, filterApplicationSidePlan)
  }

  /**
   * Extracts a sub-plan which is a simple filter over scan from the input plan. The simple
   * filter should be selective and the filter condition (including expressions in the child
   * plan referenced by the filter condition) should be a simple expression, so that we do
   * not add a subquery that might have an expensive computation. The extracted sub-plan should
   * produce a superset of the entire creation side output data, so that it's still correct to
   * use the sub-plan to build the runtime filter to prune the application side.
   *
   * @param plan the filter creation side plan to search
   * @param filterCreationSideKey the creation-side join key the runtime filter will be built on
   * @return the (possibly rewritten) creation-side key together with the extracted sub-plan,
   *         or None when no suitable selective-filter-over-scan sub-plan exists
   */
  private def extractSelectiveFilterOverScan(
      plan: LogicalPlan,
      filterCreationSideKey: Expression): Option[(Expression, LogicalPlan)] = {
    // `currentPlan` tracks the candidate sub-plan to return; it is reset when descending into
    // one side of a join. `targetKey` is the creation-side key expressed in terms of the plan
    // currently being visited (it may be rewritten via transitive equi-join keys).
    def extract(
        p: LogicalPlan,
        predicateReference: AttributeSet,
        hasHitFilter: Boolean,
        hasHitSelectiveFilter: Boolean,
        currentPlan: LogicalPlan,
        targetKey: Expression): Option[(Expression, LogicalPlan)] = p match {
      case Project(projectList, child) if hasHitFilter =>
        // We need to make sure all expressions referenced by filter predicates are simple
        // expressions.
        val referencedExprs = projectList.filter(predicateReference.contains)
        if (referencedExprs.forall(isSimpleExpression)) {
          extract(
            child,
            referencedExprs.map(_.references).foldLeft(AttributeSet.empty)(_ ++ _),
            hasHitFilter,
            hasHitSelectiveFilter,
            currentPlan,
            targetKey)
        } else {
          None
        }
      case Project(_, child) =>
        // No filter seen above this project yet, so nothing should be referenced and no
        // selective filter can have been hit.
        assert(predicateReference.isEmpty && !hasHitSelectiveFilter)
        extract(child, predicateReference, hasHitFilter, hasHitSelectiveFilter, currentPlan,
          targetKey)
      case Filter(condition, child) if isSimpleExpression(condition) =>
        extract(
          child,
          predicateReference ++ condition.references,
          hasHitFilter = true,
          hasHitSelectiveFilter = hasHitSelectiveFilter || isLikelySelective(condition),
          currentPlan,
          targetKey)
      case ExtractEquiJoinKeys(_, lkeys, rkeys, _, _, left, right, _) =>
        // Runtime filters use one side of the [[Join]] to build a set of join key values and prune
        // the other side of the [[Join]]. It's also OK to use a superset of the join key values
        // (ignore null values) to do the pruning.
        // We assume other rules have already pushed predicates through join if possible.
        // So the predicate references won't pass on anymore.
        if (left.output.exists(_.semanticEquals(targetKey))) {
          extract(left, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
            currentPlan = left, targetKey = targetKey).orElse {
            // We can also extract from the right side if the join keys are transitive.
            lkeys.zip(rkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
              .flatMap { newTargetKey =>
                extract(right, AttributeSet.empty,
                  hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = right,
                  targetKey = newTargetKey)
              }
          }
        } else if (right.output.exists(_.semanticEquals(targetKey))) {
          extract(right, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
            currentPlan = right, targetKey = targetKey).orElse {
            // We can also extract from the left side if the join keys are transitive.
            rkeys.zip(lkeys).find(_._1.semanticEquals(targetKey)).map(_._2)
              .flatMap { newTargetKey =>
                extract(left, AttributeSet.empty,
                  hasHitFilter = false, hasHitSelectiveFilter = false, currentPlan = left,
                  targetKey = newTargetKey)
              }
          }
        } else {
          None
        }
      case _: LeafNode if hasHitSelectiveFilter =>
        // Reached a scan with at least one selective filter above it: the tracked candidate
        // sub-plan qualifies as the filter creation side.
        Some((targetKey, currentPlan))
      case _ => None
    }
    // Runtime filter injection is not supported on streaming plans.
    if (!plan.isStreaming) {
      extract(plan, AttributeSet.empty, hasHitFilter = false, hasHitSelectiveFilter = false,
        currentPlan = plan, targetKey = filterCreationSideKey)
    } else {
      None
    }
  }

  /**
   * Returns true if the expression contains none of the tree patterns considered expensive
   * to evaluate (UDFs, reflective invokes, JSON parsing, LIKE and regexp families).
   */
  private def isSimpleExpression(e: Expression): Boolean = {
    !e.containsAnyPattern(PYTHON_UDF, SCALA_UDF, INVOKE, JSON_TO_STRUCT, LIKE_FAMLIY,
      REGEXP_EXTRACT_FAMILY, REGEXP_REPLACE)
  }

  /**
   * Heuristically determines whether a join is likely to be planned as a shuffle join:
   * neither side is hinted for broadcast and neither side is small enough to broadcast
   * by size.
   */
  private def isProbablyShuffleJoin(left: LogicalPlan,
      right: LogicalPlan, hint: JoinHint): Boolean = {
    !hintToBroadcastLeft(hint) && !hintToBroadcastRight(hint) &&
      !canBroadcastBySize(left, conf) && !canBroadcastBySize(right, conf)
  }

  /**
   * Returns true if the plan probably contains a shuffle: a likely-shuffle join, an
   * aggregate, or a window operator anywhere in the subtree.
   */
  private def probablyHasShuffle(plan: LogicalPlan): Boolean = {
    plan.exists {
      case Join(left, right, _, _, hint) => isProbablyShuffleJoin(left, right, hint)
      case _: Aggregate => true
      case _: Window => true
      case _ => false
    }
  }

  // Returns the max scan byte size in the subtree rooted at `filterApplicationSide`.
  private def maxScanByteSize(filterApplicationSide: LogicalPlan): BigInt = {
    val defaultSizeInBytes = conf.getConf(SQLConf.DEFAULT_SIZE_IN_BYTES)
    filterApplicationSide.collect({
      case leaf: LeafNode => leaf
    }).map(scan => {
      // DEFAULT_SIZE_IN_BYTES means there's no byte size information in stats. Since we avoid
      // creating a Bloom filter when the filter application side is very small, so using 0
      // as the byte size when the actual size is unknown can avoid regression by applying BF
      // on a small table.
      if (scan.stats.sizeInBytes == defaultSizeInBytes) BigInt(0) else scan.stats.sizeInBytes
    }).max
  }

  // Returns true if `filterApplicationSide` satisfies the byte size requirement to apply a
  // Bloom filter; false otherwise.
  private def satisfyByteSizeRequirement(filterApplicationSide: LogicalPlan): Boolean = {
    // In case `filterApplicationSide` is a union of many small tables, disseminating the Bloom
    // filter to each small task might be more costly than scanning them itself. Thus, we use max
    // rather than sum here.
    val maxScanSize = maxScanByteSize(filterApplicationSide)
    maxScanSize >=
      conf.getConf(SQLConf.RUNTIME_BLOOM_FILTER_APPLICATION_SIDE_SCAN_SIZE_THRESHOLD)
  }

  /**
   * Extracts the beneficial filter creation plan with check show below:
   * - The filterApplicationSideKey can be pushed down through joins, aggregates and windows
   *   (ie the expression references originate from a single leaf node)
   * - The filter creation side has a selective predicate
   * - The max filterApplicationSide scan size is greater than a configurable threshold
   */
  private def extractBeneficialFilterCreatePlan(
      filterApplicationSide: LogicalPlan,
      filterCreationSide: LogicalPlan,
      filterApplicationSideKey: Expression,
      filterCreationSideKey: Expression): Option[(Expression, LogicalPlan)] = {
    if (findExpressionAndTrackLineageDown(
        filterApplicationSideKey, filterApplicationSide).isDefined &&
      satisfyByteSizeRequirement(filterApplicationSide)) {
      extractSelectiveFilterOverScan(filterCreationSide, filterCreationSideKey)
    } else {
      None
    }
  }

  // This checks if there is already a DPP filter, as this rule is called just after DPP.
  // Recurses down stacked Filter nodes on either side, comparing each dynamic pruning
  // subquery's key against the corresponding join key.
  @tailrec
  private def hasDynamicPruningSubquery(
      left: LogicalPlan,
      right: LogicalPlan,
      leftKey: Expression,
      rightKey: Expression): Boolean = {
    (left, right) match {
      case (Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan), _) =>
        pruningKey.fastEquals(leftKey) || hasDynamicPruningSubquery(plan, right, leftKey, rightKey)
      case (_, Filter(DynamicPruningSubquery(pruningKey, _, _, _, _, _, _), plan)) =>
        pruningKey.fastEquals(rightKey) ||
          hasDynamicPruningSubquery(left, plan, leftKey, rightKey)
      case _ => false
    }
  }

  /**
   * Returns true if `plan` already contains a Bloom-filter might-contain predicate over
   * XxHash64 of exactly `key`, to avoid injecting a duplicate runtime filter on the same key.
   */
  private def hasBloomFilter(plan: LogicalPlan, key: Expression): Boolean = {
    plan.exists {
      case Filter(condition, _) =>
        splitConjunctivePredicates(condition).exists {
          case BloomFilterMightContain(_, XxHash64(Seq(valueExpression), _))
            if valueExpression.fastEquals(key) => true
          case _ => false
        }
      case _ => false
    }
  }

  /**
   * Walks the plan bottom-up and, for each equi-join, tries each key pair: first attempt to
   * inject a runtime filter pruning the left side (built from the right), and only if that
   * did not fire, the right side (built from the left). The total number of injected filters
   * is capped by RUNTIME_FILTER_NUMBER_THRESHOLD.
   */
  private def tryInjectRuntimeFilter(plan: LogicalPlan): LogicalPlan = {
    // Mutable counter shared across the whole traversal so the cap is global, not per-join.
    var filterCounter = 0
    val numFilterThreshold = conf.getConf(SQLConf.RUNTIME_FILTER_NUMBER_THRESHOLD)
    plan transformUp {
      case join @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, left, right, hint) =>
        var newLeft = left
        var newRight = right
        leftKeys.lazyZip(rightKeys).foreach((l, r) => {
          // Check if:
          // 1. There is already a DPP filter on the key
          // 2. The keys are simple cheap expressions
          if (filterCounter < numFilterThreshold &&
            !hasDynamicPruningSubquery(left, right, l, r) &&
            isSimpleExpression(l) && isSimpleExpression(r)) {
            val oldLeft = newLeft
            val oldRight = newRight
            // Check if:
            // 1. The current join type supports prune the left side with runtime filter
            // 2. The current join is a shuffle join or a broadcast join that
            //    has a shuffle below it
            // 3. There is no bloom filter on the left key yet
            val hasShuffle = isProbablyShuffleJoin(left, right, hint)
            if (canPruneLeft(joinType) && (hasShuffle || probablyHasShuffle(left)) &&
              !hasBloomFilter(newLeft, l)) {
              extractBeneficialFilterCreatePlan(left, right, l, r).foreach {
                case (filterCreationSideKey, filterCreationSidePlan) =>
                  newLeft = injectFilter(l, newLeft, filterCreationSideKey, filterCreationSidePlan)
              }
            }
            // Did we actually inject on the left? If not, try on the right
            // Check if:
            // 1. The current join type supports prune the right side with runtime filter
            // 2. The current join is a shuffle join or a broadcast join that
            //    has a shuffle below it
            // 3. There is no bloom filter on the right key yet
            if (newLeft.fastEquals(oldLeft) && canPruneRight(joinType) &&
              (hasShuffle || probablyHasShuffle(right)) && !hasBloomFilter(newRight, r)) {
              extractBeneficialFilterCreatePlan(right, left, r, l).foreach {
                case (filterCreationSideKey, filterCreationSidePlan) =>
                  newRight = injectFilter(
                    r, newRight, filterCreationSideKey, filterCreationSidePlan)
              }
            }
            // Count the injection (left or right) against the global filter budget.
            if (!newLeft.fastEquals(oldLeft) || !newRight.fastEquals(oldRight)) {
              filterCounter = filterCounter + 1
            }
          }
        })
        join.withNewChildren(Seq(newLeft, newRight))
    }
  }

  /**
   * Rule entry point. Skips correlated subquery plans (rewriting them here could break the
   * correlation handling) and does nothing when Bloom-filter-based runtime filtering is
   * disabled by configuration; otherwise attempts injection across the whole plan.
   */
  override def apply(plan: LogicalPlan): LogicalPlan = plan match {
    case s: Subquery if s.correlated => plan
    case _ if !conf.runtimeFilterBloomFilterEnabled => plan
    case _ => tryInjectRuntimeFilter(plan)
  }
}