Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-12576][SQL] Enable expression parsing in CatalystQl #10649

Closed
wants to merge 10 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -41,43 +41,45 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) {
}
}


/**
* Returns the AST for the given SQL string.
*/
protected def getAst(sql: String): ASTNode = ParseDriver.parse(sql, conf)

/** Creates LogicalPlan for a given HiveQL string. */
def createPlan(sql: String): LogicalPlan = {
protected def safeParse[T](sql: String, ast: ASTNode, toResult: ASTNode => T): T = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use multiple parameter lists?

protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T

then you can call it in a slightly nicer way

safeParse(sql, ParseDriver.parseExpression(sql, conf)) { ast =>
  ...
}

try {
createPlan(sql, ParseDriver.parse(sql, conf))
toResult(ast)
} catch {
case e: MatchError => throw e
case e: AnalysisException => throw e
case e: Exception =>
throw new AnalysisException(e.getMessage)
case e: NotImplementedError =>
throw new AnalysisException(
s"""
|Unsupported language features in query: $sql
|${getAst(sql).treeString}
s"""== Unsupported language features in query ==
|== SQL ==
|$sql
|== AST ==
|${ast.treeString}
|== Error ==
|$e
|== Stacktrace ==
|${e.getStackTrace.head}
""".stripMargin)
}
}

protected def createPlan(sql: String, tree: ASTNode): LogicalPlan = nodeToPlan(tree)

def parseDdl(ddl: String): Seq[Attribute] = {
val tree = getAst(ddl)
assert(tree.text == "TOK_CREATETABLE", "Only CREATE TABLE supported.")
val tableOps = tree.children
val colList = tableOps
.find(_.text == "TOK_TABCOLLIST")
.getOrElse(sys.error("No columnList!"))

colList.children.map(nodeToAttribute)
/** Creates LogicalPlan for a given SQL string. */
def createPlan(sql: String): LogicalPlan =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, I like the old name "parseXXX" better.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but no big deal.

safeParse(sql, ParseDriver.parsePlan(sql, conf), nodeToPlan)

/**
 * Builds a Catalyst [[Expression]] from a SQL expression string.
 *
 * The text is parsed with `ParseDriver.parseExpression`; the resulting AST is then handed to
 * `safeParse`, which applies `nodeToExpr` to it.
 *
 * @param sql the expression text to parse
 * @return the expression produced by `nodeToExpr` for the parsed AST
 */
def createExpression(sql: String): Expression = {
  val ast = ParseDriver.parseExpression(sql, conf)
  safeParse(sql, ast, nodeToExpr)
}

/**
 * Extracts the column attributes declared by a CREATE TABLE statement.
 *
 * The statement is parsed with `ParseDriver.parsePlan` (the statement entry point), the root
 * is required to be a `TOK_CREATETABLE` token, and every direct child of its
 * `TOK_TABCOLLIST` node is converted with `nodeToAttribute`.
 *
 * A root that is not `TOK_CREATETABLE` fails the `Token` pattern below with a `MatchError`;
 * a statement without a column list fails with the "No columnList!" error. Both are raised
 * inside the function passed to `safeParse`.
 *
 * @param sql the CREATE TABLE statement text
 * @return one attribute per entry of the table's column list
 */
def createDdl(sql: String): Seq[Attribute] = {
  // DDL is a statement, not an expression: it has to go through the statement rule for the
  // TOK_CREATETABLE root matched below to be produced.
  safeParse(sql, ParseDriver.parsePlan(sql, conf), ast => {
    val Token("TOK_CREATETABLE", children) = ast
    val columnList = children
      .find(_.text == "TOK_TABCOLLIST")
      .getOrElse(sys.error("No columnList!"))
    // Convert only the direct children of the column list, one per declared column.
    columnList.children.map(nodeToAttribute)
  })
}

protected def getClauses(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,15 @@ import org.apache.spark.sql.AnalysisException
* This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver
*/
object ParseDriver extends Logging {
def parse(command: String, conf: ParserConf): ASTNode = {
def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add Scaladoc for these three functions, and also explain how they differ?

parser.statement().getTree
}

/**
 * Parses `command` as a standalone expression.
 *
 * Delegates to `parse`, selecting the parser's `expression` rule and handing back the tree
 * that rule produced.
 *
 * @param command the expression text to parse
 * @param conf parser configuration forwarded to `parse`
 * @return the AST returned by `parse` for the expression rule
 */
def parseExpression(command: String, conf: ParserConf): ASTNode = {
  parse(command, conf)(_.expression().getTree)
}

def parse(command: String, conf: ParserConf)(toTree: SparkSqlParser => CommonTree): ASTNode = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be private?

logInfo(s"Parsing command: $command")

// Setup error collection.
Expand All @@ -44,7 +52,7 @@ object ParseDriver extends Logging {
parser.configure(conf, reporter)

try {
val result = parser.statement()
val result = toTree(parser)

// Check errors.
reporter.checkForErrors()
Expand All @@ -57,7 +65,7 @@ object ParseDriver extends Logging {
if (tree.token != null || tree.getChildCount == 0) tree
else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree])
}
val tree = nonNullToken(result.getTree)
val tree = nonNullToken(result)

// Make sure all boundaries are set.
tree.setUnknownTokenBoundaries()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,37 @@

package org.apache.spark.sql.catalyst

import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction}
import org.apache.spark.sql.catalyst.expressions.{Subtract, Add, Literal}
import org.apache.spark.sql.catalyst.plans.PlanTest

class CatalystQlSuite extends PlanTest {
val parser = new CatalystQl()

test("parse expressions") {
  // A function call over two string literals.
  val call = parser.createExpression("prinln('hello', 'world')")
  compareExpressions(
    call,
    UnresolvedFunction("prinln", List(Literal("hello"), Literal("world")), false))

  // Addition of an integer literal and a dotted attribute name.
  val sum = parser.createExpression("1 + r.r")
  compareExpressions(sum, Add(Literal(1), UnresolvedAttribute("r.r")))

  // Subtraction whose right-hand side is a function call containing a nested call.
  val difference = parser.createExpression("1 - f('o', o(bar))")
  val inner = UnresolvedFunction("o", List(UnresolvedAttribute("bar")), false)
  val outer = UnresolvedFunction("f", List(Literal("o"), inner), false)
  compareExpressions(difference, Subtract(Literal(1), outer))
}

test("parse union/except/intersect") {
val paresr = new CatalystQl()
paresr.createPlan("select * from t1 union all select * from t2")
paresr.createPlan("select * from t1 union distinct select * from t2")
paresr.createPlan("select * from t1 union select * from t2")
paresr.createPlan("select * from t1 except select * from t2")
paresr.createPlan("select * from t1 intersect select * from t2")
parser.createPlan("select * from t1 union all select * from t2")
parser.createPlan("select * from t1 union distinct select * from t2")
parser.createPlan("select * from t1 union select * from t2")
parser.createPlan("select * from t1 except select * from t2")
parser.createPlan("select * from t1 intersect select * from t2")
}
}
21 changes: 11 additions & 10 deletions sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -229,17 +229,18 @@ private[hive] object HiveQl extends SparkQl with Logging {
CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql)
}

protected override def createPlan(
sql: String,
node: ASTNode): LogicalPlan = {
if (nativeCommands.contains(node.text)) {
HiveNativeCommand(sql)
} else {
nodeToPlan(node) match {
case NativePlaceholder => HiveNativeCommand(sql)
case plan => plan
/** Creates LogicalPlan for a given SQL string. */
override def createPlan(sql: String): LogicalPlan = {
safeParse(sql, ParseDriver.parsePlan(sql, conf), ast => {
if (nativeCommands.contains(ast.text)) {
HiveNativeCommand(sql)
} else {
nodeToPlan(ast) match {
case NativePlaceholder => HiveNativeCommand(sql)
case plan => plan
}
}
}
})
}

protected override def isNoExplainCommand(command: String): Boolean =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
* @param token a unique token in the string that should be indicated by the exception
*/
def positionTest(name: String, query: String, token: String): Unit = {
def ast = ParseDriver.parse(query, hiveContext.conf)
def parseTree =
Try(quietly(ast.treeString)).getOrElse("<failed to parse>")
def ast = ParseDriver.parsePlan(query, hiveContext.conf)
def parseTree = Try(quietly(ast.treeString)).getOrElse("<failed to parse>")

test(name) {
val error = intercept[AnalysisException] {
Expand Down