[SPARK-21307][SQL] Remove SQLConf parameters from the parser-related classes

### What changes were proposed in this pull request?
This PR removes the `SQLConf` constructor parameters from the parser-related classes; the parsers now read the active conf through `SQLConf.get` instead.
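As a minimal, self-contained sketch of the pattern (the `Conf` object and the toy parsers below are illustrative stand-ins, not Spark's real `SQLConf` or parser classes): instead of capturing a conf at construction time, the parser reads the active conf at use time.

```scala
// Sketch of the refactoring pattern, assuming a thread-local "active conf"
// (the role SQLConf.get plays in Spark). `Conf`, `OldParser`, and
// `NewParser` are made-up names for illustration only.
object Conf {
  private val active = new ThreadLocal[Map[String, String]] {
    override def initialValue(): Map[String, String] = Map.empty
  }
  def get: Map[String, String] = active.get()
  def set(conf: Map[String, String]): Unit = active.set(conf)
}

// Before: the conf is fixed when the parser is constructed, so every
// caller must thread one through, and a shared instance can never see
// later conf changes.
class OldParser(conf: Map[String, String]) {
  def escapedStringLiterals: Boolean =
    conf.getOrElse("spark.sql.parser.escapedStringLiterals", "false").toBoolean
}

// After: the conf is read at use time, so the constructor parameter
// disappears and a single shared instance (like the CatalystSqlParser
// object in this patch) always observes the current conf.
object NewParser {
  def escapedStringLiterals: Boolean =
    Conf.get.getOrElse("spark.sql.parser.escapedStringLiterals", "false").toBoolean
}
```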

### How was this patch tested?
The existing test cases.
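The test suites below switch from constructing a parser with a custom `SQLConf` to wrapping assertions in `withSQLConf`. A hedged sketch of how such a helper can scope a setting, reusing the illustrative `Conf` object from the sketch above:

```scala
// Sketch of a withSQLConf-style test helper (illustrative, not Spark's
// implementation): set the given keys, run the body, and restore the
// previous values even if the body throws.
def withConf[T](pairs: (String, String)*)(body: => T): T = {
  val previous = Conf.get
  Conf.set(previous ++ pairs)
  try body finally Conf.set(previous)
}

// Usage: the setting is visible only inside the block.
// withConf("spark.sql.parser.escapedStringLiterals" -> "true") {
//   assert(NewParser.escapedStringLiterals)
// }
```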

Author: gatorsmile <gatorsmile@gmail.com>

Closes #18531 from gatorsmile/rmSQLConfParser.
gatorsmile committed Jul 5, 2017
1 parent 742da08 commit c8e7f44
Showing 11 changed files with 121 additions and 127 deletions.
```diff
@@ -74,7 +74,7 @@ class SessionCatalog(
       functionRegistry,
       conf,
       new Configuration(),
-      new CatalystSqlParser(conf),
+      CatalystSqlParser,
       DummyFunctionResourceLoader)
   }

```
```diff
@@ -45,11 +45,9 @@ import org.apache.spark.util.random.RandomSampler
  * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or
  * TableIdentifier.
  */
-class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging {
+class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
   import ParserUtils._

-  def this() = this(new SQLConf())
-
   protected def typedVisit[T](ctx: ParseTree): T = {
     ctx.accept(this).asInstanceOf[T]
   }
@@ -1457,7 +1455,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
    * Special characters can be escaped by using Hive/C-style escaping.
    */
   private def createString(ctx: StringLiteralContext): String = {
-    if (conf.escapedStringLiterals) {
+    if (SQLConf.get.escapedStringLiterals) {
       ctx.STRING().asScala.map(stringWithoutUnescape).mkString
     } else {
       ctx.STRING().asScala.map(string).mkString
```
```diff
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{DataType, StructType}

 /**
@@ -122,13 +121,8 @@ abstract class AbstractSqlParser extends ParserInterface with Logging {
 /**
  * Concrete SQL parser for Catalyst-only SQL statements.
  */
-class CatalystSqlParser(conf: SQLConf) extends AbstractSqlParser {
-  val astBuilder = new AstBuilder(conf)
-}
-
-/** For test-only. */
 object CatalystSqlParser extends AbstractSqlParser {
-  val astBuilder = new AstBuilder(new SQLConf())
+  val astBuilder = new AstBuilder
 }

 /**
```
```diff
@@ -167,12 +167,12 @@ class ExpressionParserSuite extends PlanTest {
   }

   test("like expressions with ESCAPED_STRING_LITERALS = true") {
-    val conf = new SQLConf()
-    conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, "true")
-    val parser = new CatalystSqlParser(conf)
-    assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser)
-    assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser)
-    assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser)
+    val parser = CatalystSqlParser
+    withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> "true") {
+      assertEqual("a rlike '^\\x20[\\x20-\\x23]+$'", 'a rlike "^\\x20[\\x20-\\x23]+$", parser)
+      assertEqual("a rlike 'pattern\\\\'", 'a rlike "pattern\\\\", parser)
+      assertEqual("a rlike 'pattern\\t\\n'", 'a rlike "pattern\\t\\n", parser)
+    }
   }

   test("is null expressions") {
@@ -435,86 +435,85 @@ class ExpressionParserSuite extends PlanTest {
   }

   test("strings") {
+    val parser = CatalystSqlParser
     Seq(true, false).foreach { escape =>
-      val conf = new SQLConf()
-      conf.setConfString(SQLConf.ESCAPED_STRING_LITERALS.key, escape.toString)
-      val parser = new CatalystSqlParser(conf)
-
-      // tests that have same result whatever the conf is
-      // Single Strings.
-      assertEqual("\"hello\"", "hello", parser)
-      assertEqual("'hello'", "hello", parser)
-
-      // Multi-Strings.
-      assertEqual("\"hello\" 'world'", "helloworld", parser)
-      assertEqual("'hello' \" \" 'world'", "hello world", parser)
-
-      // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
-      // regular '%'; to get the correct result you need to add another escaped '\'.
-      // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
-      assertEqual("'pattern%'", "pattern%", parser)
-      assertEqual("'no-pattern\\%'", "no-pattern\\%", parser)
-
-      // tests that have different result regarding the conf
-      if (escape) {
-        // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to
-        // Spark 1.6 behavior.
-
-        // 'LIKE' string literals.
-        assertEqual("'pattern\\\\%'", "pattern\\\\%", parser)
-        assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser)
-
-        // Escaped characters.
-        // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work
-        // when ESCAPED_STRING_LITERALS is enabled.
-        // It is parsed literally.
-        assertEqual("'\\0'", "\\0", parser)
-
-        // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is enabled.
-        val e = intercept[ParseException](parser.parseExpression("'\''"))
-        assert(e.message.contains("extraneous input '''"))
-
-        // The unescape special characters (e.g., "\\t") for 2.0+ don't work
-        // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally.
-        assertEqual("'\\\"'", "\\\"", parser) // Double quote
-        assertEqual("'\\b'", "\\b", parser) // Backspace
-        assertEqual("'\\n'", "\\n", parser) // Newline
-        assertEqual("'\\r'", "\\r", parser) // Carriage return
-        assertEqual("'\\t'", "\\t", parser) // Tab character
-
-        // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled.
-        // They are parsed literally.
-        assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser)
-        // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled.
-        // They are parsed literally.
-        assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'",
-          "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser)
-      } else {
-        // Default behavior
-
-        // 'LIKE' string literals.
-        assertEqual("'pattern\\\\%'", "pattern\\%", parser)
-        assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser)
-
-        // Escaped characters.
-        // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
-        assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00')
-        assertEqual("'\\''", "\'", parser) // Single quote
-        assertEqual("'\\\"'", "\"", parser) // Double quote
-        assertEqual("'\\b'", "\b", parser) // Backspace
-        assertEqual("'\\n'", "\n", parser) // Newline
-        assertEqual("'\\r'", "\r", parser) // Carriage return
-        assertEqual("'\\t'", "\t", parser) // Tab character
-        assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows)
-
-        // Octals
-        assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser)
-
-        // Unicode
-        assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)",
-          parser)
-      }
+      withSQLConf(SQLConf.ESCAPED_STRING_LITERALS.key -> escape.toString) {
+        // tests that have same result whatever the conf is
+        // Single Strings.
+        assertEqual("\"hello\"", "hello", parser)
+        assertEqual("'hello'", "hello", parser)
+
+        // Multi-Strings.
+        assertEqual("\"hello\" 'world'", "helloworld", parser)
+        assertEqual("'hello' \" \" 'world'", "hello world", parser)
+
+        // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a
+        // regular '%'; to get the correct result you need to add another escaped '\'.
+        // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method?
+        assertEqual("'pattern%'", "pattern%", parser)
+        assertEqual("'no-pattern\\%'", "no-pattern\\%", parser)
+
+        // tests that have different result regarding the conf
+        if (escape) {
+          // When SQLConf.ESCAPED_STRING_LITERALS is enabled, string literal parsing fallbacks to
+          // Spark 1.6 behavior.
+
+          // 'LIKE' string literals.
+          assertEqual("'pattern\\\\%'", "pattern\\\\%", parser)
+          assertEqual("'pattern\\\\\\%'", "pattern\\\\\\%", parser)
+
+          // Escaped characters.
+          // Unescape string literal "'\\0'" for ASCII NUL (X'00') doesn't work
+          // when ESCAPED_STRING_LITERALS is enabled.
+          // It is parsed literally.
+          assertEqual("'\\0'", "\\0", parser)
+
+          // Note: Single quote follows 1.6 parsing behavior when ESCAPED_STRING_LITERALS is
+          // enabled.
+          val e = intercept[ParseException](parser.parseExpression("'\''"))
+          assert(e.message.contains("extraneous input '''"))
+
+          // The unescape special characters (e.g., "\\t") for 2.0+ don't work
+          // when ESCAPED_STRING_LITERALS is enabled. They are parsed literally.
+          assertEqual("'\\\"'", "\\\"", parser) // Double quote
+          assertEqual("'\\b'", "\\b", parser) // Backspace
+          assertEqual("'\\n'", "\\n", parser) // Newline
+          assertEqual("'\\r'", "\\r", parser) // Carriage return
+          assertEqual("'\\t'", "\\t", parser) // Tab character
+
+          // The unescape Octals for 2.0+ don't work when ESCAPED_STRING_LITERALS is enabled.
+          // They are parsed literally.
+          assertEqual("'\\110\\145\\154\\154\\157\\041'", "\\110\\145\\154\\154\\157\\041", parser)
+          // The unescape Unicode for 2.0+ doesn't work when ESCAPED_STRING_LITERALS is enabled.
+          // They are parsed literally.
+          assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'",
+            "\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029", parser)
+        } else {
+          // Default behavior
+
+          // 'LIKE' string literals.
+          assertEqual("'pattern\\\\%'", "pattern\\%", parser)
+          assertEqual("'pattern\\\\\\%'", "pattern\\\\%", parser)
+
+          // Escaped characters.
+          // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html
+          assertEqual("'\\0'", "\u0000", parser) // ASCII NUL (X'00')
+          assertEqual("'\\''", "\'", parser) // Single quote
+          assertEqual("'\\\"'", "\"", parser) // Double quote
+          assertEqual("'\\b'", "\b", parser) // Backspace
+          assertEqual("'\\n'", "\n", parser) // Newline
+          assertEqual("'\\r'", "\r", parser) // Carriage return
+          assertEqual("'\\t'", "\t", parser) // Tab character
+          assertEqual("'\\Z'", "\u001A", parser) // ASCII 26 - CTRL + Z (EOF on windows)
+
+          // Octals
+          assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!", parser)
+
+          // Unicode
+          assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)",
+            parser)
+        }
+      }
     }
   }

```
```diff
@@ -39,10 +39,11 @@ import org.apache.spark.sql.types.StructType
 /**
  * Concrete parser for Spark SQL statements.
  */
-class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser {
-  val astBuilder = new SparkSqlAstBuilder(conf)
+class SparkSqlParser extends AbstractSqlParser {

-  private val substitutor = new VariableSubstitution(conf)
+  val astBuilder = new SparkSqlAstBuilder
+
+  private val substitutor = new VariableSubstitution

   protected override def parse[T](command: String)(toResult: SqlBaseParser => T): T = {
     super.parse(substitutor.substitute(command))(toResult)
@@ -52,9 +53,11 @@ class SparkSqlParser(conf: SQLConf) extends AbstractSqlParser {
 /**
  * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier.
  */
-class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) {
+class SparkSqlAstBuilder extends AstBuilder {
   import org.apache.spark.sql.catalyst.parser.ParserUtils._

+  private def conf: SQLConf = SQLConf.get
+
   /**
    * Create a [[SetCommand]] logical plan.
    *
```
3 changes: 1 addition & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
```diff
@@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint}
 import org.apache.spark.sql.execution.SparkSqlParser
 import org.apache.spark.sql.expressions.UserDefinedFunction
-import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils

@@ -1276,7 +1275,7 @@
    */
   def expr(expr: String): Column = {
     val parser = SparkSession.getActiveSession.map(_.sessionState.sqlParser).getOrElse {
-      new SparkSqlParser(new SQLConf)
+      new SparkSqlParser
     }
     Column(parser.parseExpression(expr))
   }
```
```diff
@@ -114,7 +114,7 @@ abstract class BaseSessionStateBuilder(
    * Note: this depends on the `conf` field.
    */
   protected lazy val sqlParser: ParserInterface = {
-    extensions.buildParser(session, new SparkSqlParser(conf))
+    extensions.buildParser(session, new SparkSqlParser)
   }

   /**
```
```diff
@@ -25,7 +25,9 @@ import org.apache.spark.internal.config._
  *
  * Variable substitution is controlled by `SQLConf.variableSubstituteEnabled`.
  */
-class VariableSubstitution(conf: SQLConf) {
+class VariableSubstitution {
+
+  private def conf = SQLConf.get

   private val provider = new ConfigProvider {
     override def get(key: String): Option[String] = Option(conf.getConfString(key, ""))
```
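Since `VariableSubstitution` now looks up the conf per call, a single instance can be shared across sessions. As a rough illustration of the substitution idea only (toy code, not Spark's implementation, which also handles prefixed variables and the `variableSubstituteEnabled` flag):

```scala
// Illustrative sketch: a toy substitutor that resolves ${key} references
// against whatever configuration is active at substitution time, mirroring
// the move from a constructor-held conf to a per-call lookup.
import scala.util.matching.Regex

object ToySubstitutor {
  private val pattern: Regex = """\$\{([^}]+)\}""".r

  // Unknown keys are left as-is rather than failing.
  def substitute(conf: Map[String, String], sql: String): String =
    pattern.replaceAllIn(sql, m =>
      Regex.quoteReplacement(conf.getOrElse(m.group(1), m.matched)))
}

// Example: ToySubstitutor.substitute(Map("tbl" -> "events"),
//   "select * from ${tbl}") returns "select * from events".
```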
```diff
@@ -37,8 +37,7 @@ import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType}
  */
 class SparkSqlParserSuite extends AnalysisTest {

-  val newConf = new SQLConf
-  private lazy val parser = new SparkSqlParser(newConf)
+  private lazy val parser = new SparkSqlParser

   /**
    * Normalizes plans:
@@ -285,6 +284,7 @@ class SparkSqlParserSuite extends AnalysisTest {
   }

   test("query organization") {
+    val conf = SQLConf.get
     // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows
     val baseSql = "select * from t"
     val basePlan =
@@ -293,20 +293,20 @@
     assertEqual(s"$baseSql distribute by a, b",
       RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
         basePlan,
-        numPartitions = newConf.numShufflePartitions))
+        numPartitions = conf.numShufflePartitions))
     assertEqual(s"$baseSql distribute by a sort by b",
       Sort(SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          numPartitions = conf.numShufflePartitions)))
     assertEqual(s"$baseSql cluster by a, b",
       Sort(SortOrder(UnresolvedAttribute("a"), Ascending) ::
         SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
         global = false,
         RepartitionByExpression(UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: Nil,
           basePlan,
-          numPartitions = newConf.numShufflePartitions)))
+          numPartitions = conf.numShufflePartitions)))
   }

   test("pipeline concatenation") {
```
```diff
@@ -29,13 +29,13 @@ import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.Project
 import org.apache.spark.sql.execution.SparkSqlParser
 import org.apache.spark.sql.execution.datasources.CreateTable
-import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
+import org.apache.spark.sql.internal.HiveSerDe
 import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}


 // TODO: merge this with DDLSuite (SPARK-14441)
 class DDLCommandSuite extends PlanTest {
-  private lazy val parser = new SparkSqlParser(new SQLConf)
+  private lazy val parser = new SparkSqlParser

   private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = {
     val e = intercept[ParseException] {
```
