diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
index 5182af9d20e58..bc48ed4f08ce0 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala
@@ -53,7 +53,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
       case e => e
     }
 
-    val generatedSQL = toSQL(canonicalizedPlan)
+    val generatedSQL = toSQL(canonicalizedPlan, true)
     logDebug(
       s"""Built SQL query string successfully from given logical plan:
         |
@@ -78,6 +78,27 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi
     }
   }
 
+  private def toSQL(node: LogicalPlan, topNode: Boolean): String = {
+    if (topNode) {
+      node match {
+        case d: Distinct => toSQL(node)
+        case p: Project => toSQL(node)
+        case a: Aggregate => toSQL(node)
+        case s: Sort => toSQL(node)
+        case r: RepartitionByExpression => toSQL(node)
+        case _ =>
+          build(
+            "SELECT",
+            node.output.map(_.sql).mkString(", "),
+            "FROM",
+            toSQL(node)
+          )
+      }
+    } else {
+      toSQL(node)
+    }
+  }
+
   private def toSQL(node: LogicalPlan): String = node match {
     case Distinct(p: Project) =>
       projectToSQL(p, isDistinct = true)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index f9a5cf4d4781e..d708fcf8dd4d9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.hive
 
 import scala.util.control.NonFatal
 
+import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SQLTestUtils
 
@@ -54,6 +56,33 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
     sql("DROP TABLE IF EXISTS t0")
   }
 
+  private def checkPlan(plan: LogicalPlan, sqlContext: SQLContext, expected: String): Unit = {
+    val convertedSQL = try new SQLBuilder(plan, sqlContext).toSQL catch {
+      case NonFatal(e) =>
+        fail(
+          s"""Cannot convert the following logical query plan back to SQL query string:
+             |
+             |# Original logical query plan:
+             |${plan.treeString}
+           """.stripMargin, e)
+    }
+
+    try {
+      checkAnswer(sql(convertedSQL), DataFrame(sqlContext, plan))
+    } catch { case cause: Throwable =>
+      fail(
+        s"""Failed to execute converted SQL string or got wrong answer:
+           |
+           |# Converted SQL query string:
+           |$convertedSQL
+           |
+           |# Original logical query plan:
+           |${plan.treeString}
+         """.stripMargin,
+        cause)
+    }
+  }
+
   private def checkHiveQl(hiveQl: String): Unit = {
     val df = sql(hiveQl)
 
@@ -157,6 +186,18 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
       "SELECT x.key, COUNT(*) FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key group by x.key")
   }
 
+  test("join plan") {
+    val expectedSql = "SELECT x.key FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key"
+
+    val df1 = sqlContext.table("parquet_t1").as("x")
+    val df2 = sqlContext.table("parquet_t1").as("y")
+    val joinPlan = df1.join(df2).queryExecution.analyzed
+
+    // Make sure we have a plain Join operator without Project on top of it.
+    assert(joinPlan.isInstanceOf[Join])
+    checkPlan(joinPlan, sqlContext, expectedSql)
+  }
+
   test("case") {
     checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM parquet_t0")
   }
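For reference, here is a minimal, standalone sketch of the top-node dispatch the SQLBuilder change introduces. The Node/Table/Project/Join classes below are toy stand-ins, not Spark's LogicalPlan API; they only illustrate that a root operator which does not already produce a complete SELECT (e.g. a bare Join, as in the new "join plan" test) gets wrapped in an explicit projection over its output.

// Toy sketch of the topNode dispatch; all names here are illustrative only.
object TopNodeSketch {
  sealed trait Node { def output: Seq[String]; def sql: String }

  case class Table(name: String, output: Seq[String]) extends Node {
    def sql: String = name
  }

  case class Project(columns: Seq[String], child: Node) extends Node {
    def output: Seq[String] = columns
    def sql: String = s"SELECT ${columns.mkString(", ")} FROM ${child.sql}"
  }

  case class Join(left: Node, right: Node, condition: String) extends Node {
    def output: Seq[String] = left.output ++ right.output
    def sql: String = s"${left.sql} JOIN ${right.sql} ON $condition"
  }

  // Mirrors the shape of the new two-argument toSQL: SELECT-producing roots
  // are emitted as-is, anything else is wrapped in a projection of its output.
  def toSQL(node: Node, topNode: Boolean): String =
    if (topNode) {
      node match {
        case _: Project => node.sql
        case _ => s"SELECT ${node.output.mkString(", ")} FROM ${node.sql}"
      }
    } else {
      node.sql
    }

  def main(args: Array[String]): Unit = {
    val x = Table("parquet_t1 x", Seq("x.key"))
    val y = Table("parquet_t1 y", Seq("y.key"))
    // A plain Join as the root, analogous to the "join plan" test above:
    println(toSQL(Join(x, y, "x.key = y.key"), topNode = true))
    // SELECT x.key, y.key FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key
  }
}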