Skip to content

Commit

Permalink
Remove redundant aliases in NestedColumnAliasing.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed May 6, 2019
1 parent d9bcacf commit a009d3e
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,17 @@ object NestedColumnAliasing {
.filter(!_.references.subsetOf(AttributeSet(otherRootReferences)))
.groupBy(_.references.head)
.flatMap { case (attr, nestedFields: Seq[GetStructField]) =>
// Remove redundant `GetStructField`s if they share the same parent nest field.
// For example, when `a.b` and `a.b.c` are in project list, we only need to alias `a.b`.
val dedupNestedFields = nestedFields.filter {
case n @ GetStructField(child, _, _) =>
nestedFields.forall(f => f == n || child.find(_.semanticEquals(f)).isEmpty)
case _ => true
}

// Each expression can contain multiple nested fields.
// Note that we keep the original names to deliver to parquet in a case-sensitive way.
val nestedFieldToAlias = nestedFields.distinct.map { f =>
val nestedFieldToAlias = dedupNestedFields.distinct.map { f =>
val exprId = NamedExpression.newExprId
(f, Alias(f, s"_gen_alias_${exprId.id}")(exprId, Seq.empty, None))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,42 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
comparePlans(optimized, expected)
}

test("Do not generate redundant aliases if parent nested field is aliased too") {
val nestedRelation = LocalRelation('a.struct('b.struct('c.int,
'd.struct('f.int, 'g.int), 'e.int)))

val first = GetStructField('a, 0, Some("b"))
val second = Add(
GetStructField(
GetStructField('a, 0, Some("b")), 0, Some("c")),
Literal(1))
val last = GetStructField(
GetStructField(
GetStructField('a, 0, Some("b")), 1, Some("d")), 0, Some("f"))

val query = nestedRelation
.limit(5)
.select(first, second, last)
.analyze

val optimized = Optimize.execute(query)

val aliases = collectGeneratedAliases(optimized)

val expected = nestedRelation
.select(first.as(aliases(0)))
.limit(5)
.select($"${aliases(0)}".as("a.b"),
Add(
GetStructField($"${aliases(0)}", 0, Some("c")),
Literal(1)).as("(a.b.c + 1)"),
GetStructField(
GetStructField($"${aliases(0)}", 1, Some("d")), 0, Some("f")).as("a.b.d.f"))
.analyze

comparePlans(optimized, expected)
}

private def collectGeneratedAliases(query: LogicalPlan): ArrayBuffer[String] = {
val aliases = ArrayBuffer[String]()
query.transformAllExpressions {
Expand Down

0 comments on commit a009d3e

Please sign in to comment.