Skip to content

Commit

Permalink
[SPARK-32059] Add tests for Filter
Browse files Browse the repository at this point in the history
  • Loading branch information
Frank Yin committed Jun 30, 2020
1 parent acce8c5 commit c95633e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 10 deletions.
Expand Up @@ -46,7 +46,7 @@ object NestedColumnAliasing {
* the optimizer can hit an infinite loop during the [[PushDownPredicates]] rule.
*/
case Project(projectList, Filter(condition, child))
if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
val exprCandidatesToPrune = projectList ++ Seq(condition) ++ child.expressions
getAliasSubMap(exprCandidatesToPrune, child.producedAttributes.toSeq).map {
case (nestedFieldToAlias, attrToAliases) =>
Expand Down Expand Up @@ -166,10 +166,11 @@ object NestedColumnAliasing {
}

val exclusiveAttrSet = AttributeSet(exclusiveAttrs ++ otherRootReferences)
val aliasSub = nestedFieldReferences.asInstanceOf[Seq[ExtractValue]]
val groupByReferenceList = nestedFieldReferences.asInstanceOf[Seq[ExtractValue]]
.filter(!_.references.subsetOf(exclusiveAttrSet))
.groupBy(_.references.head)
.toList
val exprIdToAliases = groupByReferenceList
.flatMap { case (attr, nestedFields: Seq[ExtractValue]) =>
// Remove redundant `ExtractValue`s if they share the same parent nested field.
// For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`.
Expand Down Expand Up @@ -200,6 +201,7 @@ object NestedColumnAliasing {
None
}
}
val aliasSub = exprIdToAliases
.groupBy(_._1) // To fix same ExprId mapped to different attribute instance
.map {
case (exprId: ExprId, expressions: List[(ExprId, Seq[(ExtractValue, Alias)])]) =>
Expand Down
Expand Up @@ -525,19 +525,88 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
// Checks nested-field pruning when a Filter (`where`) sits on top of several kinds of child
// plans: a window, a sort, a distribute, a join, and (via `runTest`) limit/repartition/sample.
test("Nested field pruning for Filter") {
// Case 1: filter above a window expression partitioned by `address` and ordered by `id`.
val spec = windowSpec($"address" :: Nil, $"id".asc :: Nil, UnspecifiedFrame)
val winExpr = windowExpr(RowNumber().toAggregateExpression(), spec)
// NOTE(review): the unsuffixed `query` / `optimized` / `aliases` / `expected` lines in this
// case appear to be the pre-change side of the diff interleaved with their `*1` replacements
// (this text comes from a commit view) -- confirm against the real file before relying on them.
val query = contact.select($"name.first", winExpr.as('window))
val query1 = contact.select($"name.first", winExpr.as('window))
.where($"window" === 1 && $"name.first" === "a")
.analyze
val optimized = Optimize.execute(query)
val aliases = collectGeneratedAliases(optimized)
val expected = contact
.select($"name.first", $"address", $"id", $"name.first".as(aliases(1)))
val optimized1 = Optimize.execute(query1)
val aliases1 = collectGeneratedAliases(optimized1)
val expected1 = contact
.select($"name.first", $"address", $"id", $"name.first".as(aliases1(1)))
.window(Seq(winExpr.as("window")), Seq($"address"), Seq($"id".asc))
.select($"first", $"${aliases(1)}".as(aliases(0)), $"window")
.where($"window" === 1 && $"${aliases(0)}" === "a")
.select($"first", $"${aliases1(1)}".as(aliases1(0)), $"window")
.where($"window" === 1 && $"${aliases1(0)}" === "a")
.select($"first", $"window")
.analyze
comparePlans(optimized, expected)
comparePlans(optimized1, expected1)

// Case 2: filter above a sort on `name.first`. The expected plan aliases the nested field
// below the sort and sorts/filters on the generated aliases.
val query2 = contact.sortBy($"name.first".asc)
.where($"name.first" === "a")
.select($"name.first")
.analyze
val optimized2 = Optimize.execute(query2)
val aliases2 = collectGeneratedAliases(optimized2)
val expected2 = contact
.select($"name.first".as(aliases2(1)))
.sortBy($"${aliases2(1)}".asc)
.select($"${aliases2(1)}".as(aliases2(0)))
.where($"${aliases2(0)}" === "a")
.select($"${aliases2(0)}".as("first"))
.analyze
comparePlans(optimized2, expected2)

// Case 3: filter above a `distribute` keyed on `name.first` with 100 partitions. The expected
// plan has the same shape as the sort case, with the alias projected below the distribute.
val query3 = contact.distribute($"name.first")(100)
.where($"name.first" === "a")
.select($"name.first")
.analyze
val optimized3 = Optimize.execute(query3)
val aliases3 = collectGeneratedAliases(optimized3)
val expected3 = contact
.select($"name.first".as(aliases3(1)))
.distribute($"${aliases3(1)}")(100)
.select($"${aliases3(1)}".as(aliases3(0)))
.where($"${aliases3(0)}" === "a")
.select($"${aliases3(0)}".as("first"))
.analyze
comparePlans(optimized3, expected3)

// Case 4: filter above a join of `contact` with a two-column `department` relation. The
// expected plan reads only `id` and the aliased `name.first` from `contact`, and only `depID`
// from `department`.
val department = LocalRelation(
'depID.int,
'personID.string)
val query4 = contact.join(department, condition = Some($"id" === $"depID"))
.where($"name.first" === "a")
.select($"name.first")
.analyze
val optimized4 = Optimize.execute(query4)
val aliases4 = collectGeneratedAliases(optimized4)
val expected4 = contact
.select($"id", $"name.first".as(aliases4(1)))
.join(department.select('depID), condition = Some($"id" === $"depID"))
.select($"${aliases4(1)}".as(aliases4(0)))
.where($"${aliases4(0)}" === "a")
.select($"${aliases4(0)}".as("first"))
.analyze
comparePlans(optimized4, expected4)

// Shared check for plans built by `basePlan`: the query filters and selects `name.first` on
// top of `basePlan(contact)`; the expected plan projects the single generated alias for
// `name.first` beneath `basePlan` and filters/selects on that alias.
def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = {
val query = basePlan(contact)
.where($"name.first" === "a")
.select($"name.first")
.analyze
val optimized = Optimize.execute(query)
val aliases = collectGeneratedAliases(optimized)
val expected = basePlan(contact
.select($"name.first".as(aliases(0))))
.where($"${aliases(0)}" === "a")
.select($"${aliases(0)}".as("first"))
.analyze
comparePlans(optimized, expected)
}
// Run the shared check with limit, repartition and Sample as the plan beneath the filter.
Seq(
(plan: LogicalPlan) => plan.limit(100),
(plan: LogicalPlan) => plan.repartition(100),
(plan: LogicalPlan) => Sample(0.0, 0.6, false, 11L, plan)).foreach { base =>
runTest(base)
}
}

test("Nested field pruning for Sort") {
Expand Down

0 comments on commit c95633e

Please sign in to comment.