[SPARK-12828][SQL] Add natural join support #10762

Closed
wants to merge 17 commits
@@ -117,15 +117,20 @@ joinToken
@init { gParent.pushMsg("join type specifier", state); }
@after { gParent.popMsg(state); }
:
KW_JOIN -> TOK_JOIN
| KW_INNER KW_JOIN -> TOK_JOIN
| COMMA -> TOK_JOIN
| KW_CROSS KW_JOIN -> TOK_CROSSJOIN
| KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN
| KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN
| KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN
| KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN
| KW_ANTI KW_JOIN -> TOK_ANTIJOIN
KW_JOIN -> TOK_JOIN
| KW_INNER KW_JOIN -> TOK_JOIN
| KW_NATURAL KW_JOIN -> TOK_NATURALJOIN
| KW_NATURAL KW_INNER KW_JOIN -> TOK_NATURALJOIN
| COMMA -> TOK_JOIN
| KW_CROSS KW_JOIN -> TOK_CROSSJOIN
| KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN
| KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN
| KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN
| KW_NATURAL KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_NATURALLEFTOUTERJOIN
| KW_NATURAL KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_NATURALRIGHTOUTERJOIN
| KW_NATURAL KW_FULL (KW_OUTER)? KW_JOIN -> TOK_NATURALFULLOUTERJOIN
| KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN
| KW_ANTI KW_JOIN -> TOK_ANTIJOIN
;
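Editor's illustration (not part of the diff): the SQL forms the new grammar alternatives accept, and the AST token each one produces, assuming a HiveContext-backed sqlContext and hypothetical tables t1 and t2:

sqlContext.sql("SELECT * FROM t1 NATURAL JOIN t2")             // TOK_NATURALJOIN
sqlContext.sql("SELECT * FROM t1 NATURAL INNER JOIN t2")       // TOK_NATURALJOIN
sqlContext.sql("SELECT * FROM t1 NATURAL LEFT OUTER JOIN t2")  // TOK_NATURALLEFTOUTERJOIN
sqlContext.sql("SELECT * FROM t1 NATURAL RIGHT JOIN t2")       // TOK_NATURALRIGHTOUTERJOIN
sqlContext.sql("SELECT * FROM t1 NATURAL FULL OUTER JOIN t2")  // TOK_NATURALFULLOUTERJOIN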

lateralView
@@ -328,6 +328,8 @@ KW_WEEK: 'WEEK'|'WEEKS';
KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS';
KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS';

KW_NATURAL: 'NATURAL';
Contributor

Is the NATURAL keyword reserved? If it is not, please add it to the nonReserved rule in the expression parser.

Contributor Author

I think it should be reserved.
See:
MySQL: https://dev.mysql.com/doc/refman/5.5/en/keywords.html
PostgreSQL: http://www.postgresql.org/docs/7.3/static/sql-keywords-appendix.html

It is also reserved in SQL92 and SQL99 (as noted in the PostgreSQL documentation).
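(Editor's sketch, not from the PR: one practical consequence of keeping NATURAL reserved is that a column literally named natural can no longer appear as a bare identifier and would need backtick quoting, roughly as below; the table t is hypothetical.)

sqlContext.sql("SELECT `natural` FROM t")   // quoted identifier still parses
// sqlContext.sql("SELECT natural FROM t")  // would fail to parse once NATURAL is reserved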


// Operators
// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work.

@@ -96,6 +96,10 @@ TOK_RIGHTOUTERJOIN;
TOK_FULLOUTERJOIN;
TOK_UNIQUEJOIN;
TOK_CROSSJOIN;
TOK_NATURALJOIN;
TOK_NATURALLEFTOUTERJOIN;
TOK_NATURALRIGHTOUTERJOIN;
TOK_NATURALFULLOUTERJOIN;
TOK_LOAD;
TOK_EXPORT;
TOK_IMPORT;
@@ -495,6 +495,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case "TOK_LEFTSEMIJOIN" => LeftSemi
case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
case "TOK_NATURALJOIN" => NaturalJoin(Inner)
case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter)
case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter)
case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter)
}
Join(nodeToRelation(relation1),
nodeToRelation(relation2),
@@ -23,6 +23,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
@@ -80,6 +81,7 @@ class Analyzer(
ResolveAliases ::
ResolveWindowOrder ::
ResolveWindowFrame ::
ResolveNaturalJoin ::
ExtractWindowExpressions ::
GlobalAggregates ::
ResolveAggregateFunctions ::
@@ -1159,6 +1161,47 @@
}
}
}

/**
* Removes natural joins by computing the output columns from the two sides,
* then applying a Project on top of a normal Join to eliminate the natural join.
*/
object ResolveNaturalJoin extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
// Should not skip unresolved nodes because natural join is always unresolved.
case j @ Join(left, right, NaturalJoin(joinType), condition) if j.resolvedExceptNatural =>
// find common column names from both sides, should be treated like usingColumns
val joinNames = left.output.map(_.name).intersect(right.output.map(_.name))
Contributor

You need to add more inline comments to this function to explain what's going on.

val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
val joinPairs = leftKeys.zip(rightKeys)
// Add joinPairs to joinConditions
val newCondition = (condition ++ joinPairs.map {
case (l, r) => EqualTo(l, r)
}).reduceLeftOption(And)
// columns not in joinPairs
val lUniqueOutput = left.output.filterNot(att => leftKeys.contains(att))
val rUniqueOutput = right.output.filterNot(att => rightKeys.contains(att))
// only keep one copy of the join columns in the output; which copy (and its nullability) depends on the joinType
val projectList = joinType match {
case LeftOuter =>
leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true))
Contributor

Why are we switching the ordering of the output columns?

Contributor

Never mind, I figured it out.

case RightOuter =>
rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput
case FullOuter =>
// in a full outer join the join columns should be non-null whenever either side is non-null, so coalesce them
val joinedCols = joinPairs.map {
case (l, r) => Alias(Coalesce(Seq(l, r)), l.name)()
}
joinedCols ++ lUniqueOutput.map(_.withNullability(true)) ++
rUniqueOutput.map(_.withNullability(true))
case _ =>
rightKeys ++ lUniqueOutput ++ rUniqueOutput
}
// use Project to trim unnecessary fields
Project(projectList, Join(left, right, joinType, newCondition))
}
}
}
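Editor's illustration of the rewrite (a sketch mirroring ResolveNaturalJoinSuite below; the relation and column names are made up): for nt1(k, v1) and nt2(k, v2) the only common name is k, so a natural inner join becomes an ordinary inner join on k plus a Project that keeps a single k column followed by each side's unique columns:

// hypothetical relations built with the Catalyst test DSL
val k = 'k.string; val v1 = 'v1.int; val v2 = 'v2.int
val nt1 = LocalRelation(k, v1)
val nt2 = LocalRelation(k, v2)
// before resolution
nt1.join(nt2, NaturalJoin(Inner), None)
// after ResolveNaturalJoin: plain inner join on the shared column, plus a Project
// (in the real plan the two k attributes carry different expression ids)
nt1.join(nt2, Inner, Some(EqualTo(k, k))).select(k, v1, v2)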

/**
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral}
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions}
import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, LeftSemi, RightOuter}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types._
@@ -919,6 +919,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
(rightFilterConditions ++ commonFilterCondition).
reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
case FullOuter => f // DO Nothing for Full Outer Join
case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
Contributor

Do we need to catch it? I think we can guarantee there is no NaturalJoin left after CheckAnalysis.

Contributor Author

Yes, but joinType is a sealed abstract class, so we have to put something here.
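(Editor's note, a minimal sketch of the point being made; the function and case arms are illustrative only: because JoinType is sealed, the compiler checks pattern matches for exhaustiveness, so omitting a NaturalJoin case would produce a "match may not be exhaustive" warning even though the analyzer should have eliminated natural joins before the optimizer runs.)

def example(joinType: JoinType): String = joinType match {
  case Inner => "inner"
  case LeftOuter | RightOuter | FullOuter | LeftSemi => "other"
  // required to keep the match exhaustive over the sealed hierarchy
  case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
}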

}

// push down the join filter into sub query scanning if applicable
@@ -953,6 +954,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {

Join(newLeft, newRight, LeftOuter, newJoinCond)
case FullOuter => f
case NaturalJoin(_) => sys.error("Untransformed NaturalJoin node")
}
}
}
@@ -60,3 +60,7 @@ case object FullOuter extends JoinType {
case object LeftSemi extends JoinType {
override def sql: String = "LEFT SEMI"
}

case class NaturalJoin(tpe: JoinType) extends JoinType {
override def sql: String = "NATURAL " + tpe.sql
}
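For instance (editor's sketch, assuming the existing JoinType.sql strings such as "INNER" and "LEFT OUTER"), the generated SQL fragment simply composes with the wrapped type:

NaturalJoin(Inner).sql      // "NATURAL INNER"
NaturalJoin(LeftOuter).sql  // "NATURAL LEFT OUTER"
NaturalJoin(FullOuter).sql  // "NATURAL FULL OUTER"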
@@ -172,12 +172,20 @@ case class Join(
def selfJoinResolved: Boolean = left.outputSet.intersect(right.outputSet).isEmpty

// Joins are only resolved if they don't introduce ambiguous expression ids.
Contributor

We should make resolved false if the type is a natural join.

override lazy val resolved: Boolean = {
// NaturalJoin should be ready for resolution only if everything else is resolved here
lazy val resolvedExceptNatural: Boolean = {
childrenResolved &&
expressions.forall(_.resolved) &&
selfJoinResolved &&
condition.forall(_.dataType == BooleanType)
}

// If this is not a natural join, use `resolvedExceptNatural`. If it is a natural join, we still
// need to eliminate the natural join before we mark it resolved.
override lazy val resolved: Boolean = joinType match {
case NaturalJoin(_) => false
case _ => resolvedExceptNatural
}
}

/**
@@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

class ResolveNaturalJoinSuite extends AnalysisTest {
lazy val a = 'a.string
lazy val b = 'b.string
lazy val c = 'c.string
lazy val aNotNull = a.notNull
lazy val bNotNull = b.notNull
lazy val cNotNull = c.notNull
lazy val r1 = LocalRelation(a, b)
lazy val r2 = LocalRelation(a, c)
lazy val r3 = LocalRelation(aNotNull, bNotNull)
lazy val r4 = LocalRelation(bNotNull, cNotNull)

test("natural inner join") {
val plan = r1.join(r2, NaturalJoin(Inner), None)
val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)
checkAnalysis(plan, expected)
}

test("natural left join") {
val plan = r1.join(r2, NaturalJoin(LeftOuter), None)
val expected = r1.join(r2, LeftOuter, Some(EqualTo(a, a))).select(a, b, c)
checkAnalysis(plan, expected)
}

test("natural right join") {
val plan = r1.join(r2, NaturalJoin(RightOuter), None)
val expected = r1.join(r2, RightOuter, Some(EqualTo(a, a))).select(a, b, c)
checkAnalysis(plan, expected)
}

test("natural full outer join") {
val plan = r1.join(r2, NaturalJoin(FullOuter), None)
val expected = r1.join(r2, FullOuter, Some(EqualTo(a, a))).select(
Alias(Coalesce(Seq(a, a)), "a")(), b, c)
checkAnalysis(plan, expected)
}

test("natural inner join with no nullability") {
val plan = r3.join(r4, NaturalJoin(Inner), None)
val expected = r3.join(r4, Inner, Some(EqualTo(bNotNull, bNotNull))).select(
bNotNull, aNotNull, cNotNull)
checkAnalysis(plan, expected)
}

test("natural left join with no nullability") {
val plan = r3.join(r4, NaturalJoin(LeftOuter), None)
val expected = r3.join(r4, LeftOuter, Some(EqualTo(bNotNull, bNotNull))).select(
bNotNull, aNotNull, c)
checkAnalysis(plan, expected)
}

test("natural right join with no nullability") {
val plan = r3.join(r4, NaturalJoin(RightOuter), None)
val expected = r3.join(r4, RightOuter, Some(EqualTo(bNotNull, bNotNull))).select(
bNotNull, a, cNotNull)
checkAnalysis(plan, expected)
}

test("natural full outer join with no nullability") {
val plan = r3.join(r4, NaturalJoin(FullOuter), None)
val expected = r3.join(r4, FullOuter, Some(EqualTo(bNotNull, bNotNull))).select(
Alias(Coalesce(Seq(bNotNull, bNotNull)), "b")(), a, c)
checkAnalysis(plan, expected)
}
}
@@ -474,6 +474,7 @@ class DataFrame private[sql](
val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true)
Alias(Coalesce(Seq(leftCol, rightCol)), col)()
}
case NaturalJoin(_) => sys.error("NaturalJoin with using clause is not supported.")
Contributor

Are we going to support natural join in the DataFrame API? If so, I think we should also change JoinType.apply.

Contributor

No, I don't think we need to.

Contributor

Then this case is unreachable, since JoinType.apply won't produce a natural join.

Contributor

Yup, although we should still throw an exception here, just in case a future refactoring makes this reachable.
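(Editor's sketch of the user-facing distinction discussed here: natural joins stay reachable only through SQL, while the DataFrame API keeps its explicit usingColumns form; sqlContext, df1, df2, nt1 and nt2 are hypothetical.)

// SQL: the shared columns are discovered automatically
sqlContext.sql("SELECT * FROM nt1 NATURAL JOIN nt2")
// DataFrame API: the shared column must be named explicitly
df1.join(df2, "k")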

}
// The nullability of output of joined could be different than original column,
// so we can only compare them by exprId
24 changes: 24 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2056,4 +2056,28 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
}

test("natural join") {
Contributor

We should also create a unit test for the analyzer.

Contributor Author

Done

val df1 = Seq(("one", 1), ("two", 2), ("three", 3)).toDF("k", "v1")
val df2 = Seq(("one", 1), ("two", 22), ("one", 5)).toDF("k", "v2")
withTempTable("nt1", "nt2") {
df1.registerTempTable("nt1")
df2.registerTempTable("nt2")
checkAnswer(
sql("SELECT * FROM nt1 natural join nt2 where k = \"one\""),
Row("one", 1, 1) :: Row("one", 1, 5) :: Nil)

checkAnswer(
sql("SELECT * FROM nt1 natural left join nt2 order by v1, v2"),
Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Row("three", 3, null) :: Nil)

checkAnswer(
sql("SELECT * FROM nt1 natural right join nt2 order by v1, v2"),
Row("one", 1, 1) :: Row("one", 1, 5) :: Row("two", 2, 22) :: Nil)

checkAnswer(
sql("SELECT count(*) FROM nt1 natural full outer join nt2"),
Row(4) :: Nil)
}
}
}