[SPARK-2781][SQL] Check resolution of LogicalPlans in Analyzer.
staple committed Sep 10, 2014
1 parent 79cdb9b commit 701dcd2
Showing 8 changed files with 103 additions and 20 deletions.
@@ -40,7 +40,12 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
// TODO: pass this in as a parameter.
val fixedPoint = FixedPoint(100)

-val batches: Seq[Batch] = Seq(
+/**
+ * Override to provide additional rules for the "Resolution" batch.
+ */
+val extendedRules: List[Rule[LogicalPlan]] = Nil
+
+lazy val batches: Seq[Batch] = Seq(
Batch("MultiInstanceRelations", Once,
NewRelationInstances),
Batch("CaseInsensitiveAttributeReferences", Once,
@@ -54,23 +59,31 @@ class Analyzer(catalog: Catalog, registry: FunctionRegistry, caseSensitive: Bool
StarExpansion ::
ResolveFunctions ::
GlobalAggregates ::
-UnresolvedHavingClauseAttributes ::
-typeCoercionRules :_*),
+UnresolvedHavingClauseAttributes ::
+typeCoercionRules :::
+extendedRules : _*),
Batch("Check Analysis", Once,
CheckResolution),
Batch("AnalysisOperators", fixedPoint,
EliminateAnalysisOperators)
)

/**
-* Makes sure all attributes have been resolved.
+* Makes sure all attributes and logical plans have been resolved.
*/
object CheckResolution extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
plan.transform {
case p if p.expressions.exists(!_.resolved) =>
throw new TreeNodeException(p,
s"Unresolved attributes: ${p.expressions.filterNot(_.resolved).mkString(",")}")
+case p if !p.resolved && p.childrenResolved =>
+throw new TreeNodeException(p, "Unresolved plan found")
+} match {
+// As a backstop, use the root node to check that the entire plan tree is resolved.
+case p if !p.resolved =>
+throw new TreeNodeException(p, "Unresolved plan in tree")
+case p => p
}
}
}
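A usage sketch, not part of this diff (catalog, registry, and MyRule are assumed names; MyRule stands for any Rule[LogicalPlan]): a subclass injects extra rules into the "Resolution" batch by overriding extendedRules.

// Sketch: plug an extra rule into the "Resolution" batch.
// `catalog`, `registry` and MyRule are assumed, not defined in this diff.
val analyzer = new Analyzer(catalog, registry, caseSensitive = true) {
  override val extendedRules = MyRule :: Nil
}

Note that batches is now a lazy val: a plain val would be initialized in the superclass constructor, before a subclass's extendedRules override is set, so the extra rules would not be picked up.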
@@ -286,6 +286,10 @@ trait HiveTypeCoercion {
// If the data type is not boolean and is being cast boolean, turn it into a comparison
// with the numeric value, i.e. x != 0. This will coerce the type into numeric type.
case Cast(e, BooleanType) if e.dataType != BooleanType => Not(EqualTo(e, Literal(0)))
+// Stringify boolean if casting to StringType.
+// TODO Ensure true/false string letter casing is consistent with Hive in all cases.
+case Cast(e, StringType) if e.dataType == BooleanType =>
+If(e, Literal("true"), Literal("false"))
// Turn true into 1, and false into 0 if casting boolean into other types.
case Cast(e, dataType) if e.dataType == BooleanType =>
Cast(If(e, Literal(1), Literal(0)), dataType)
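To illustrate the new case (a sketch built only from constructors that already appear in this diff; b is a hypothetical boolean attribute), a boolean-to-string cast is rewritten into a conditional over string literals:

// What the new BooleanCasts case does to CAST(b AS STRING):
val b = AttributeReference("b", BooleanType)() // hypothetical boolean column
val before = Cast(b, StringType)
val after = If(b, Literal("true"), Literal("false")) // rewritten form

So, for example, CAST(TRUE AS STRING) evaluates to the string "true", as exercised by the new tests below.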
@@ -58,7 +58,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] {

/**
* Returns true if this expression and all its children have been resolved to a specific schema
-* and false if it is still contains any unresolved placeholders. Implementations of LogicalPlan
+* and false if it still contains any unresolved placeholders. Implementations of LogicalPlan
* can override this (e.g.
* [[org.apache.spark.sql.catalyst.analysis.UnresolvedRelation UnresolvedRelation]]
* should return `false`).
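For context, the default this comment describes is derived from the node's expressions and children; approximately (paraphrasing the surrounding class, not this diff):

// Approximate default in LogicalPlan: resolved when every expression
// and every child is resolved.
lazy val resolved: Boolean = !expressions.exists(!_.resolved) && childrenResolved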
@@ -93,6 +93,17 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
val e = intercept[TreeNodeException[_]] {
caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
}
-assert(e.getMessage().toLowerCase.contains("unresolved"))
+assert(e.getMessage().toLowerCase.contains("unresolved attribute"))
}

test("throw errors for unresolved plans during analysis") {
case class UnresolvedTestPlan() extends LeafNode {
override lazy val resolved = false
override def output = Nil
}
val e = intercept[TreeNodeException[_]] {
caseSensitiveAnalyze(UnresolvedTestPlan())
}
assert(e.getMessage().toLowerCase.contains("unresolved plan"))
}
}
@@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.analysis

import org.scalatest.FunSuite

+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
import org.apache.spark.sql.catalyst.types._

class HiveTypeCoercionSuite extends FunSuite {
@@ -84,4 +86,16 @@ class HiveTypeCoercionSuite extends FunSuite {
widenTest(StringType, MapType(IntegerType, StringType, true), None)
widenTest(ArrayType(IntegerType), StructType(Seq()), None)
}

test("boolean casts") {
def ruleTest(initial: Expression, transformed: Expression) {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
assert(BooleanCasts(Project(Seq(Alias(initial, "a")()), testRelation)) ==
Project(Seq(Alias(transformed, "a")()), testRelation))
}
// Remove superflous boolean -> boolean casts.
ruleTest(Cast(Literal(true), BooleanType), Literal(true))
// Stringify boolean when casting to string.
ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false")))
}
}
44 changes: 40 additions & 4 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -17,6 +17,7 @@

package org.apache.spark.sql

+import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.test._
import org.scalatest.BeforeAndAfterAll
@@ -477,18 +478,48 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
(3, null)))
}

test("EXCEPT") {
test("UNION") {
checkAnswer(
sql("SELECT * FROM lowerCaseData UNION SELECT * FROM upperCaseData"),
(1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
(4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData UNION SELECT * FROM lowerCaseData"),
(1, "a") :: (2, "b") :: (3, "c") :: (4, "d") :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData UNION ALL SELECT * FROM lowerCaseData"),
(1, "a") :: (1, "a") :: (2, "b") :: (2, "b") :: (3, "c") :: (3, "c") ::
(4, "d") :: (4, "d") :: Nil)
}

test("UNION with column mismatches") {
// Column name mismatches are allowed.
checkAnswer(
sql("SELECT n,l FROM lowerCaseData UNION SELECT N as x1, L as x2 FROM upperCaseData"),
(1, "A") :: (1, "a") :: (2, "B") :: (2, "b") :: (3, "C") :: (3, "c") ::
(4, "D") :: (4, "d") :: (5, "E") :: (6, "F") :: Nil)
// Column type mismatches are not allowed, forcing a type coercion.
checkAnswer(
sql("SELECT n FROM lowerCaseData UNION SELECT L FROM upperCaseData"),
("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Tuple1(_)))
// Column type mismatches where a coercion is not possible, in this case between integer
// and array types, trigger a TreeNodeException.
intercept[TreeNodeException[_]] {
sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect()
}
}

test("EXCEPT") {
checkAnswer(
sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData "),
sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM upperCaseData"),
(1, "a") ::
(2, "b") ::
(3, "c") ::
(4, "d") :: Nil)
checkAnswer(
sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData "), Nil)
sql("SELECT * FROM lowerCaseData EXCEPT SELECT * FROM lowerCaseData"), Nil)
checkAnswer(
sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData "), Nil)
sql("SELECT * FROM upperCaseData EXCEPT SELECT * FROM upperCaseData"), Nil)
}

test("INTERSECT") {
@@ -635,5 +666,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Seq()
)

test("cast boolean to string") {
// TODO Ensure true/false string letter casing is consistent with Hive in all cases.
checkAnswer(
sql("SELECT CAST(TRUE AS STRING), CAST(FALSE AS STRING) FROM testData LIMIT 1"),
("true", "false") :: Nil)
}
}
@@ -262,7 +262,9 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
/* An analyzer that uses the Hive metastore. */
@transient
override protected[sql] lazy val analyzer =
-new Analyzer(catalog, functionRegistry, caseSensitive = false)
+new Analyzer(catalog, functionRegistry, caseSensitive = false) {
+override val extendedRules = catalog.CreateTables :: catalog.PreInsertionCasts :: Nil
+}

/**
* Runs the specified SQL query using Hive.
@@ -353,9 +355,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {

/** Extends QueryExecution with hive specific features. */
protected[sql] abstract class QueryExecution extends super.QueryExecution {
-// TODO: Create mixin for the analyzer instead of overriding things here.
-override lazy val optimizedPlan =
-optimizer(ExtractPythonUdfs(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))))
+// TODO: Utilize extendedRules in the analyzer instead of overriding things here.
+override lazy val optimizedPlan = optimizer(ExtractPythonUdfs(analyzed))

override lazy val toRdd: RDD[Row] = executedPlan.execute().map(_.copy())

@@ -109,15 +109,17 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
*/
object CreateTables extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+// Wait until children are resolved.
+case p: LogicalPlan if !p.childrenResolved => p

case InsertIntoCreatedTable(db, tableName, child) =>
val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)

createTable(databaseName, tblName, child.output)

InsertIntoTable(
-EliminateAnalysisOperators(
-lookupRelation(Some(databaseName), tblName, None)),
+lookupRelation(Some(databaseName), tblName, None),
Map.empty,
child,
overwrite = false)
@@ -130,15 +132,17 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
*/
object PreInsertionCasts extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
-// Wait until children are resolved
+// Wait until children are resolved.
case p: LogicalPlan if !p.childrenResolved => p

-case p @ InsertIntoTable(table: MetastoreRelation, _, child, _) =>
+case p @ InsertIntoTable(
+LowerCaseSchema(table: MetastoreRelation), _, child, _) =>
castChildOutput(p, table, child)

case p @ logical.InsertIntoTable(
-InMemoryRelation(_, _, _,
-HiveTableScan(_, table, _)), _, child, _) =>
+LowerCaseSchema(
+InMemoryRelation(_, _, _,
+HiveTableScan(_, table, _))), _, child, _) =>
castChildOutput(p, table, child)
}
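The childrenResolved guard at the top of both rules matters more now that they run inside the analyzer's fixed-point "Resolution" batch instead of after analysis: a rule can fire on a partially resolved tree, so it should no-op and wait for a later iteration. A sketch of the pattern (MyCatalogRule is hypothetical):

// Hypothetical rule using the same guard pattern as CreateTables and
// PreInsertionCasts above.
object MyCatalogRule extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // Bail out until children are resolved; the fixed point retries later.
    case p: LogicalPlan if !p.childrenResolved => p
    // With resolved children it is safe to do the real rewrite here.
    case other => other
  }
}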
