/
ResolveHints.scala
292 lines (255 loc) · 12.5 KB
/
ResolveHints.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis
import java.util.Locale
import scala.collection.mutable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, IntegerLiteral, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.internal.SQLConf
/**
* Collection of rules related to hints: join strategy hints (e.g. "BROADCAST", "MERGE") and
* partitioning hints ("COALESCE", "REPARTITION", "REPARTITION_BY_RANGE").
*
* Note that these are split into separate rules because hint rules (existing or introduced in
* the future) may have ordering requirements that differ from those of join strategy hints.
*/
object ResolveHints {
/**
 * The list of allowed join strategy hints is defined in [[JoinStrategyHint.strategies]], and a
 * sequence of relation aliases can be specified with a join strategy hint, e.g., "MERGE(a, c)",
 * "BROADCAST(a)". A join strategy hint plan node will be inserted on top of any relation (that
 * is not aliased differently), subquery, or common table expression that match the specified
 * name.
 *
 * The hint resolution works by recursively traversing down the query plan to find a relation or
 * subquery that matches one of the specified relation aliases. The traversal does not go past
 * any view reference, with clause or subquery alias.
 *
 * This rule must happen before common table expressions.
 */
class ResolveJoinStrategyHints(conf: SQLConf) extends Rule[LogicalPlan] {
  // Upper-cased so the membership test in `apply` (which upper-cases the hint name) stays
  // case-insensitive regardless of how the aliases are spelled, matching `createHintInfo`.
  private val STRATEGY_HINT_NAMES =
    JoinStrategyHint.strategies.flatMap(_.hintAliases).map(_.toUpperCase(Locale.ROOT))

  private val hintErrorHandler = conf.hintErrorHandler

  def resolver: Resolver = conf.resolver

  /**
   * Builds a [[HintInfo]] whose strategy is the [[JoinStrategyHint]] listing `hintName` among
   * its aliases (compared case-insensitively); the strategy is `None` if no alias matches.
   */
  private def createHintInfo(hintName: String): HintInfo = {
    val upperName = hintName.toUpperCase(Locale.ROOT)
    HintInfo(strategy = JoinStrategyHint.strategies.find(
      _.hintAliases.exists(_.toUpperCase(Locale.ROOT) == upperName)))
  }

  // This method checks if given multi-part identifiers are matched with each other.
  // The [[ResolveJoinStrategyHints]] rule is applied before the resolution batch
  // in the analyzer and we cannot semantically compare them at this stage.
  // Therefore, we follow a simple rule; they match if an identifier in a hint
  // is a tail of an identifier in a relation. This process is independent of a session
  // catalog (`currentDb` in [[SessionCatalog]]) and it just compares them literally.
  //
  // For example,
  //  * in a query `SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN t`,
  //    the broadcast hint will match both tables, `db1.t` and `t`,
  //    even when the current db is `db2`.
  //  * in a query `SELECT /*+ BROADCAST(default.t) */ * FROM default.t JOIN t`,
  //    the broadcast hint will match the left-side table only, `default.t`.
  private def matchedIdentifier(identInHint: Seq[String], identInQuery: Seq[String]): Boolean = {
    identInHint.length <= identInQuery.length &&
      identInHint.zip(identInQuery.takeRight(identInHint.length))
        .forall { case (i1, i2) => resolver(i1, i2) }
  }

  /** The full multi-part name (qualifier followed by name) of a subquery alias. */
  private def extractIdentifier(r: SubqueryAlias): Seq[String] = {
    r.identifier.qualifier :+ r.identifier.name
  }

  /**
   * Walks `plan` top-down and wraps every relation / subquery alias matching one of
   * `relationsInHint` in a [[ResolvedHint]] carrying the `hintName` strategy.
   *
   * @param relationsInHintWithMatch accumulator: each hint identifier that matched something is
   *                                 added here so that unmatched ones can be reported later.
   */
  private def applyJoinStrategyHint(
      plan: LogicalPlan,
      relationsInHint: Set[Seq[String]],
      relationsInHintWithMatch: mutable.HashSet[Seq[String]],
      hintName: String): LogicalPlan = {
    // Returns true if some hint identifier matches `identInQuery`, recording the first such
    // identifier as used.
    def matchedIdentifierInHint(identInQuery: Seq[String]): Boolean = {
      relationsInHint.find(matchedIdentifier(_, identInQuery))
        .map(relationsInHintWithMatch.add).nonEmpty
    }

    // `canRecurse` is false for nodes whose subtree must not be searched.
    val (newNode, canRecurse) = CurrentOrigin.withOrigin(plan.origin) {
      plan match {
        case ResolvedHint(u @ UnresolvedRelation(ident, _), hint)
            if matchedIdentifierInHint(ident) =>
          (ResolvedHint(u, createHintInfo(hintName).merge(hint, hintErrorHandler)), true)
        case ResolvedHint(r: SubqueryAlias, hint)
            if matchedIdentifierInHint(extractIdentifier(r)) =>
          (ResolvedHint(r, createHintInfo(hintName).merge(hint, hintErrorHandler)), true)
        case UnresolvedRelation(ident, _) if matchedIdentifierInHint(ident) =>
          (ResolvedHint(plan, createHintInfo(hintName)), true)
        case r: SubqueryAlias if matchedIdentifierInHint(extractIdentifier(r)) =>
          (ResolvedHint(plan, createHintInfo(hintName)), true)
        case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
          // Don't traverse down these nodes.
          // For an existing strategy hint, there is no chance for a match from this point down.
          // The rest (view, with, subquery) indicates different scopes that we shouldn't traverse
          // down. Note that technically when this rule is executed, we haven't completed view
          // resolution yet and as a result the view part should be deadcode. It is kept here
          // to be more future proof in case view resolution changes.
          (plan, false)
        case _ =>
          (plan, true)
      }
    }

    // Only descend into nodes that were left unchanged and do not start a new scope.
    if (canRecurse && (plan fastEquals newNode)) {
      newNode.mapChildren { child =>
        applyJoinStrategyHint(child, relationsInHint, relationsInHintWithMatch, hintName)
      }
    } else {
      newNode
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsUp {
    case h: UnresolvedHint if STRATEGY_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) =>
      if (h.parameters.isEmpty) {
        // If there is no table alias specified, apply the hint on the entire subtree.
        ResolvedHint(h.child, createHintInfo(h.name))
      } else {
        // Otherwise, find within the subtree query plans to apply the hint.
        val relationNamesInHint = h.parameters.map {
          case tableName: String => UnresolvedAttribute.parseAttributeName(tableName)
          case tableId: UnresolvedAttribute => tableId.nameParts
          case unsupported => throw new AnalysisException("Join strategy hint parameter " +
            s"should be an identifier or string but was $unsupported (${unsupported.getClass})")
        }.toSet
        val relationsInHintWithMatch = new mutable.HashSet[Seq[String]]
        val applied = applyJoinStrategyHint(
          h.child, relationNamesInHint, relationsInHintWithMatch, h.name)

        // Report the relation identifiers in the hint that matched nothing.
        val unmatchedIdents = relationNamesInHint -- relationsInHintWithMatch
        hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, unmatchedIdents)

        applied
      }
  }
}
/**
 * COALESCE Hint accepts names "COALESCE", "REPARTITION", and "REPARTITION_BY_RANGE".
 * Each is resolved into a [[Repartition]] or [[RepartitionByExpression]] node; hints with any
 * other name are returned unchanged.
 */
class ResolveCoalesceHints(conf: SQLConf) extends Rule[LogicalPlan] {

  /**
   * Throws an [[AnalysisException]] listing every element of `partitionExprs` rejected by
   * `isValid`. The REPARTITION and REPARTITION_BY_RANGE paths use it so both report the same
   * error message.
   */
  private def validateColumnParams(
      hintName: String, partitionExprs: Seq[Any])(isValid: Any => Boolean): Unit = {
    val invalidParams = partitionExprs.filterNot(isValid)
    if (invalidParams.nonEmpty) {
      throw new AnalysisException(s"$hintName Hint parameter should include columns, but " +
        s"${invalidParams.mkString(", ")} found")
    }
  }

  /**
   * This function handles hints for "COALESCE" and "REPARTITION".
   * The "COALESCE" hint only has a partition number as a parameter. The "REPARTITION" hint
   * has a partition number, columns, or both of them as parameters.
   */
  private def createRepartition(
      shuffle: Boolean, hint: UnresolvedHint): LogicalPlan = {
    val hintName = hint.name.toUpperCase(Locale.ROOT)

    def createRepartitionByExpression(
        numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
      // Sort orders only make sense for range partitioning.
      val sortOrders = partitionExprs.filter(_.isInstanceOf[SortOrder])
      if (sortOrders.nonEmpty) throw new IllegalArgumentException(
        s"Invalid partitionExprs specified: $sortOrders\n" +
          "For range partitioning use REPARTITION_BY_RANGE instead.\n")
      validateColumnParams(hintName, partitionExprs)(_.isInstanceOf[UnresolvedAttribute])
      RepartitionByExpression(
        partitionExprs.map(_.asInstanceOf[Expression]), hint.child, numPartitions)
    }

    hint.parameters match {
      case Seq(IntegerLiteral(numPartitions)) =>
        Repartition(numPartitions, shuffle, hint.child)
      case Seq(numPartitions: Int) =>
        Repartition(numPartitions, shuffle, hint.child)
      // The "COALESCE" hint (shuffle = false) must have a partition number only
      case _ if !shuffle =>
        throw new AnalysisException(s"$hintName Hint expects a partition number as a parameter")
      // From here on the hint is "REPARTITION" (shuffle = true).
      case param @ Seq(IntegerLiteral(numPartitions), _*) =>
        createRepartitionByExpression(Some(numPartitions), param.tail)
      case param @ Seq(numPartitions: Int, _*) =>
        createRepartitionByExpression(Some(numPartitions), param.tail)
      case param =>
        createRepartitionByExpression(None, param)
    }
  }

  /**
   * This function handles hints for "REPARTITION_BY_RANGE".
   * The "REPARTITION_BY_RANGE" hint must have column names and a partition number is optional.
   * A column may be given bare (sorted ascending) or wrapped in an explicit [[SortOrder]].
   */
  private def createRepartitionByRange(hint: UnresolvedHint): RepartitionByExpression = {
    val hintName = hint.name.toUpperCase(Locale.ROOT)

    def createRepartitionByExpression(
        numPartitions: Option[Int], partitionExprs: Seq[Any]): RepartitionByExpression = {
      // The REPARTITION path redirects sort orders here, so a sort order over a column must be
      // accepted. Otherwise the `SortOrder` branch below would be unreachable.
      validateColumnParams(hintName, partitionExprs) {
        case _: UnresolvedAttribute => true
        case order: SortOrder => order.child.isInstanceOf[UnresolvedAttribute]
        case _ => false
      }
      val sortOrder = partitionExprs.map {
        case expr: SortOrder => expr
        case expr: Expression => SortOrder(expr, Ascending)
      }
      RepartitionByExpression(sortOrder, hint.child, numPartitions)
    }

    hint.parameters match {
      case param @ Seq(IntegerLiteral(numPartitions), _*) =>
        createRepartitionByExpression(Some(numPartitions), param.tail)
      case param @ Seq(numPartitions: Int, _*) =>
        createRepartitionByExpression(Some(numPartitions), param.tail)
      case param =>
        createRepartitionByExpression(None, param)
    }
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
    case hint @ UnresolvedHint(hintName, _, _) => hintName.toUpperCase(Locale.ROOT) match {
      case "REPARTITION" =>
        createRepartition(shuffle = true, hint)
      case "COALESCE" =>
        createRepartition(shuffle = false, hint)
      case "REPARTITION_BY_RANGE" =>
        createRepartitionByRange(hint)
      case _ => hint
    }
  }
}
object ResolveCoalesceHints {
  /** Upper-case names of the partitioning hints handled by the coalesce-hint rule. */
  val COALESCE_HINT_NAMES: Set[String] = Set(
    "COALESCE",
    "REPARTITION",
    "REPARTITION_BY_RANGE")
}
/**
 * Strips every remaining [[UnresolvedHint]] from the plan, reporting each one to the configured
 * hint error handler as not recognized. Used to remove invalid hints provided by the user, so it
 * must be executed after all the other hint rules.
 */
class RemoveAllHints(conf: SQLConf) extends Rule[LogicalPlan] {
  private val errorHandler = conf.hintErrorHandler

  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
    case UnresolvedHint(name, parameters, child) =>
      errorHandler.hintNotRecognized(name, parameters)
      child
  }
}
/**
 * Removes all the hints when [[SQLConf.DISABLE_HINTS]] (`spark.sql.optimizer.disableHints`) is
 * set; otherwise returns the plan untouched. Executed at the very beginning of the Analyzer to
 * disable the hint functionality.
 */
class DisableHints(conf: SQLConf) extends RemoveAllHints(conf) {
  override def apply(plan: LogicalPlan): LogicalPlan = {
    val hintsDisabled = conf.getConf(SQLConf.DISABLE_HINTS)
    if (!hintsDisabled) {
      plan
    } else {
      super.apply(plan)
    }
  }
}
}