[SPARK-12576][SQL] Enable expression parsing in CatalystQl #10649

Closed · wants to merge 10 commits
@@ -131,6 +131,13 @@ selectItem
    :
    (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns)
    |
+   namedExpression
+   ;
+
+namedExpression
+@init { gParent.pushMsg("select named expression", state); }
+@after { gParent.popMsg(state); }
+   :
    ( expression
      ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))?
    ) -> ^(TOK_SELEXPR expression identifier*)
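The new namedExpression rule factors the select-list body out of selectItem, so an expression with an optional alias — or a Hive-style multi-alias list — can be parsed as a standalone entry point. A rough sketch of the inputs the rule admits, exercised through the CatalystQl.parseExpression helper added in this patch (the result comments are illustrative, not verbatim parser output):

// Illustrative sketch; exact Expression shapes may differ.
CatalystQl.parseExpression("a + 1")                     // bare expression
CatalystQl.parseExpression("a + 1 AS b")                // single alias
CatalystQl.parseExpression("a + 1 b")                   // KW_AS is optional
CatalystQl.parseExpression("stack(2, a, b) AS (x, y)")  // multi-alias, per the KW_AS LPAREN ... RPAREN branch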
@@ -30,6 +30,12 @@
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.random.RandomSampler

+private[sql] object CatalystQl {
+  val parser = new CatalystQl
+  def parseExpression(sql: String): Expression = parser.parseExpression(sql)
+  def parseTableIdentifier(sql: String): TableIdentifier = parser.parseTableIdentifier(sql)
+}
+
/**
* This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]].
*/
@@ -41,43 +47,53 @@ private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) {
}
}


  /**
-   * Returns the AST for the given SQL string.
+   * The safeParse method allows a user to focus on the parsing/AST transformation logic. This
+   * method will take care of possible errors during the parsing process.
   */
-  protected def getAst(sql: String): ASTNode = ParseDriver.parse(sql, conf)
-
-  /** Creates LogicalPlan for a given HiveQL string. */
-  def createPlan(sql: String): LogicalPlan = {
+  protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T = {
    try {
-      createPlan(sql, ParseDriver.parse(sql, conf))
+      toResult(ast)
    } catch {
      case e: MatchError => throw e
      case e: AnalysisException => throw e
      case e: Exception =>
        throw new AnalysisException(e.getMessage)
      case e: NotImplementedError =>
        throw new AnalysisException(
-          s"""
-             |Unsupported language features in query: $sql
-             |${getAst(sql).treeString}
+          s"""Unsupported language features in query
+             |== SQL ==
+             |$sql
+             |== AST ==
+             |${ast.treeString}
+             |== Error ==
+             |$e
+             |== Stacktrace ==
+             |${e.getStackTrace.head}
           """.stripMargin)
    }
  }

-  protected def createPlan(sql: String, tree: ASTNode): LogicalPlan = nodeToPlan(tree)
-
-  def parseDdl(ddl: String): Seq[Attribute] = {
-    val tree = getAst(ddl)
-    assert(tree.text == "TOK_CREATETABLE", "Only CREATE TABLE supported.")
-    val tableOps = tree.children
-    val colList = tableOps
-      .find(_.text == "TOK_TABCOLLIST")
-      .getOrElse(sys.error("No columnList!"))
-
-    colList.children.map(nodeToAttribute)
+  /** Creates LogicalPlan for a given SQL string. */
+  def parsePlan(sql: String): LogicalPlan =
+    safeParse(sql, ParseDriver.parsePlan(sql, conf))(nodeToPlan)
+
+  /** Creates Expression for a given SQL string. */
+  def parseExpression(sql: String): Expression =
+    safeParse(sql, ParseDriver.parseExpression(sql, conf))(selExprNodeToExpr(_).get)
+
+  /** Creates TableIdentifier for a given SQL string. */
+  def parseTableIdentifier(sql: String): TableIdentifier =
+    safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent)
+
+  def parseDdl(sql: String): Seq[Attribute] = {
+    safeParse(sql, ParseDriver.parseExpression(sql, conf)) { ast =>
+      val Token("TOK_CREATETABLE", children) = ast
+      children
+        .find(_.text == "TOK_TABCOLLIST")
+        .getOrElse(sys.error("No columnList!"))
+        .flatMap(_.children.map(nodeToAttribute))
+    }
  }

protected def getClauses(
@@ -187,7 +203,6 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val keyMap = keyASTs.zipWithIndex.toMap

    val bitmasks: Seq[Int] = setASTs.map {
-     case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0
      case Token("TOK_GROUPING_SETS_EXPRESSION", columns) =>
        columns.foldLeft(0)((bitmap, col) => {
          val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2)
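With this refactoring, CatalystQl exposes one public entry point per parse target, all funneled through safeParse for uniform error reporting. A minimal usage sketch, assuming code living inside Spark's sql package (the class is private[sql]):

// Sketch of the API introduced above; error handling omitted.
import org.apache.spark.sql.catalyst.CatalystQl

val parser = new CatalystQl()
val plan  = parser.parsePlan("SELECT 1")           // LogicalPlan
val expr  = parser.parseExpression("1 + a AS b")   // Expression
val table = parser.parseTableIdentifier("db.tbl")  // TableIdentifier

// The new companion object shares a single parser instance:
val expr2 = CatalystQl.parseExpression("a AND b")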
@@ -28,7 +28,25 @@
import org.apache.spark.sql.AnalysisException
* This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver
*/
object ParseDriver extends Logging {
-  def parse(command: String, conf: ParserConf): ASTNode = {
+  /** Create a LogicalPlan ASTNode from a SQL command. */
+  def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>

Review comment (Contributor): can you add some function doc for these 3 functions, and also cover what the differences are?


+    parser.statement().getTree
+  }
+
+  /** Create an Expression ASTNode from a SQL command. */
+  def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
+    parser.namedExpression().getTree
+  }
+
+  /** Create a TableIdentifier ASTNode from a SQL command. */
+  def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser =>
+    parser.tableName().getTree
+  }
+
+  private def parse(
+      command: String,
+      conf: ParserConf)(
+      toTree: SparkSqlParser => CommonTree): ASTNode = {
logInfo(s"Parsing command: $command")

// Setup error collection.
@@ -44,7 +62,7 @@ object ParseDriver extends Logging {
parser.configure(conf, reporter)

try {
-      val result = parser.statement()
+      val result = toTree(parser)

// Check errors.
reporter.checkForErrors()
@@ -57,7 +75,7 @@
if (tree.token != null || tree.getChildCount == 0) tree
else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree])
}
-      val tree = nonNullToken(result.getTree)
+      val tree = nonNullToken(result)

// Make sure all boundaries are set.
tree.setUnknownTokenBoundaries()
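After this change, ParseDriver is a small template method: lexer setup, error collection, and tree cleanup are shared in the private parse method, and the three public entry points differ only in which ANTLR start rule they invoke (statement, namedExpression, or tableName). A hedged sketch of calling the driver directly, assuming a SimpleParserConf as used for CatalystQl's default configuration:

// Sketch only; assumes SimpleParserConf() as seen in CatalystQl's default argument.
val conf = SimpleParserConf()
val planAst = ParseDriver.parsePlan("SELECT a FROM t", conf)   // rooted at the statement rule
val exprAst = ParseDriver.parseExpression("a + 1 AS b", conf)  // rooted at namedExpression
val nameAst = ParseDriver.parseTableName("db.tbl", conf)       // rooted at tableName
println(planAst.treeString)  // ASTNode.treeString, as used in safeParse's error message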
@@ -17,36 +17,157 @@

package org.apache.spark.sql.catalyst

+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
+import org.apache.spark.unsafe.types.CalendarInterval

class CatalystQlSuite extends PlanTest {
val parser = new CatalystQl()

test("test case insensitive") {
val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation)
assert(result === parser.parsePlan("seLect 1"))
assert(result === parser.parsePlan("select 1"))
assert(result === parser.parsePlan("SELECT 1"))
}

test("test NOT operator with comparison operations") {
val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE")
val expected = Project(
UnresolvedAlias(
Not(
GreaterThan(Literal(true), Literal(true)))
) :: Nil,
OneRowRelation)
comparePlans(parsed, expected)
}

test("support hive interval literal") {
def checkInterval(sql: String, result: CalendarInterval): Unit = {
val parsed = parser.parsePlan(sql)
val expected = Project(
UnresolvedAlias(
Literal(result)
) :: Nil,
OneRowRelation)
comparePlans(parsed, expected)
}

def checkYearMonth(lit: String): Unit = {
checkInterval(
s"SELECT INTERVAL '$lit' YEAR TO MONTH",
CalendarInterval.fromYearMonthString(lit))
}

def checkDayTime(lit: String): Unit = {
checkInterval(
s"SELECT INTERVAL '$lit' DAY TO SECOND",
CalendarInterval.fromDayTimeString(lit))
}

def checkSingleUnit(lit: String, unit: String): Unit = {
checkInterval(
s"SELECT INTERVAL '$lit' $unit",
CalendarInterval.fromSingleUnitString(unit, lit))
}

checkYearMonth("123-10")
checkYearMonth("496-0")
checkYearMonth("-2-3")
checkYearMonth("-123-0")

checkDayTime("99 11:22:33.123456789")
checkDayTime("-99 11:22:33.123456789")
checkDayTime("10 9:8:7.123456789")
checkDayTime("1 0:0:0")
checkDayTime("-1 0:0:0")
checkDayTime("1 0:0:1")

for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) {
checkSingleUnit("7", unit)
checkSingleUnit("-7", unit)
checkSingleUnit("0", unit)
}

checkSingleUnit("13.123456789", "second")
checkSingleUnit("-13.123456789", "second")
}

test("support scientific notation") {
def assertRight(input: String, output: Double): Unit = {
val parsed = parser.parsePlan("SELECT " + input)
val expected = Project(
UnresolvedAlias(
Literal(output)
) :: Nil,
OneRowRelation)
comparePlans(parsed, expected)
}

assertRight("9.0e1", 90)
assertRight("0.9e+2", 90)
assertRight("900e-1", 90)
assertRight("900.0E-1", 90)
assertRight("9.e+1", 90)

intercept[AnalysisException](parser.parsePlan("SELECT .e3"))
}

test("parse expressions") {
compareExpressions(
parser.parseExpression("prinln('hello', 'world')"),
UnresolvedFunction(
"prinln", Literal("hello") :: Literal("world") :: Nil, false))

compareExpressions(
parser.parseExpression("1 + r.r As q"),
Alias(Add(Literal(1), UnresolvedAttribute("r.r")), "q")())

compareExpressions(
parser.parseExpression("1 - f('o', o(bar))"),
Subtract(Literal(1),
UnresolvedFunction("f",
Literal("o") ::
UnresolvedFunction("o", UnresolvedAttribute("bar") :: Nil, false) ::
Nil, false)))
}

test("table identifier") {
assert(TableIdentifier("q") === parser.parseTableIdentifier("q"))
assert(TableIdentifier("q", Some("d")) === parser.parseTableIdentifier("d.q"))
intercept[AnalysisException](parser.parseTableIdentifier(""))
// TODO parser swallows third identifier.
// intercept[AnalysisException](parser.parseTableIdentifier("d.q.g"))

Review comment (Contributor): are we going to support this?

Reply (Contributor, PR author): I don't think we should support this. Are there use cases for this? I'll create a fix, that'll throw an AnalysisException when we encounter this.

Reply (Contributor): yea, throw exception seems reasonable to me

+  }

test("parse union/except/intersect") {
parser.createPlan("select * from t1 union all select * from t2")
parser.createPlan("select * from t1 union distinct select * from t2")
parser.createPlan("select * from t1 union select * from t2")
parser.createPlan("select * from t1 except select * from t2")
parser.createPlan("select * from t1 intersect select * from t2")
parser.createPlan("(select * from t1) union all (select * from t2)")
parser.createPlan("(select * from t1) union distinct (select * from t2)")
parser.createPlan("(select * from t1) union (select * from t2)")
parser.createPlan("select * from ((select * from t1) union (select * from t2)) t")
parser.parsePlan("select * from t1 union all select * from t2")
parser.parsePlan("select * from t1 union distinct select * from t2")
parser.parsePlan("select * from t1 union select * from t2")
parser.parsePlan("select * from t1 except select * from t2")
parser.parsePlan("select * from t1 intersect select * from t2")
parser.parsePlan("(select * from t1) union all (select * from t2)")
parser.parsePlan("(select * from t1) union distinct (select * from t2)")
parser.parsePlan("(select * from t1) union (select * from t2)")
parser.parsePlan("select * from ((select * from t1) union (select * from t2)) t")
}

test("window function: better support of parentheses") {
parser.createPlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " +
parser.parsePlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " +
"order by 2) from windowData")
parser.createPlan("select sum(product + 1) over (partition by (1 + (product / 2)) " +
parser.parsePlan("select sum(product + 1) over (partition by (1 + (product / 2)) " +
"order by 2) from windowData")
parser.createPlan("select sum(product + 1) over (partition by ((product / 2) + 1) " +
parser.parsePlan("select sum(product + 1) over (partition by ((product / 2) + 1) " +
"order by 2) from windowData")

parser.createPlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " +
parser.parsePlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " +
"from windowData")
parser.createPlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " +
parser.parsePlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " +
"from windowData")
parser.createPlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " +
parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " +
"from windowData")
}
}
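On the open question in the table identifier thread above (what to do with names of more than two parts), a sketch of the kind of guard the author describes. The helper name and its Seq[String] input are hypothetical — the real extractTableIdent consumes an ASTNode — and the author said the actual fix would come separately:

// Hypothetical guard, not the shipped fix: reject names with more than two parts.
def toTableIdentifier(parts: Seq[String]): TableIdentifier = parts match {
  case Seq(tbl)     => TableIdentifier(tbl)
  case Seq(db, tbl) => TableIdentifier(tbl, Some(db))
  case _ => throw new AnalysisException(s"Illegal table name: ${parts.mkString(".")}")
}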
@@ -38,7 +38,7 @@ private[hive] class ExtendedHiveQlParser extends AbstractSparkSQLParser {

protected lazy val hiveQl: Parser[LogicalPlan] =
restInput ^^ {
-      case statement => HiveQl.createPlan(statement.trim)
+      case statement => HiveQl.parsePlan(statement.trim)
}

protected lazy val dfs: Parser[LogicalPlan] =
@@ -414,8 +414,8 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive
alias match {
// because hive use things like `_c0` to build the expanded text
// currently we cannot support view from "create view v1(c1) as ..."
-          case None => Subquery(table.name, HiveQl.createPlan(viewText))
-          case Some(aliasText) => Subquery(aliasText, HiveQl.createPlan(viewText))
+          case None => Subquery(table.name, HiveQl.parsePlan(viewText))
+          case Some(aliasText) => Subquery(aliasText, HiveQl.parsePlan(viewText))
}
} else {
MetastoreRelation(qualifiedTableName.database, qualifiedTableName.name, alias)(table)(hive)
19 changes: 10 additions & 9 deletions in sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -229,15 +229,16 @@ private[hive] object HiveQl extends SparkQl with Logging {
CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql)
}

-  protected override def createPlan(
-      sql: String,
-      node: ASTNode): LogicalPlan = {
-    if (nativeCommands.contains(node.text)) {
-      HiveNativeCommand(sql)
-    } else {
-      nodeToPlan(node) match {
-        case NativePlaceholder => HiveNativeCommand(sql)
-        case plan => plan
+  /** Creates LogicalPlan for a given SQL string. */
+  override def parsePlan(sql: String): LogicalPlan = {
+    safeParse(sql, ParseDriver.parsePlan(sql, conf)) { ast =>
+      if (nativeCommands.contains(ast.text)) {
+        HiveNativeCommand(sql)
+      } else {
+        nodeToPlan(ast) match {
+          case NativePlaceholder => HiveNativeCommand(sql)
+          case plan => plan
+        }
}
}
}
@@ -117,9 +117,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd
* @param token a unique token in the string that should be indicated by the exception
*/
def positionTest(name: String, query: String, token: String): Unit = {
-    def ast = ParseDriver.parse(query, hiveContext.conf)
-    def parseTree =
-      Try(quietly(ast.treeString)).getOrElse("<failed to parse>")
+    def ast = ParseDriver.parsePlan(query, hiveContext.conf)
+    def parseTree = Try(quietly(ast.treeString)).getOrElse("<failed to parse>")

test(name) {
val error = intercept[AnalysisException] {
@@ -28,7 +28,7 @@ import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, M

class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll {
private def extractTableDesc(sql: String): (HiveTable, Boolean) = {
-    HiveQl.createPlan(sql).collect {
+    HiveQl.parsePlan(sql).collect {
case CreateTableAsSelect(desc, child, allowExisting) => (desc, allowExisting)
}.head
}