Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-3421][SQL] Allows arbitrary character in StructField.name #2291

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,8 +278,23 @@ def __init__(self, name, dataType, nullable):
self.nullable = nullable

def __repr__(self):
    """Return a parseable string form of this field.

    The field name is double-quoted, and any backslash or double-quote
    character inside the name is backslash-escaped, so that names
    containing arbitrary characters can be parsed back (SPARK-3421).
    Note that ``nullable`` is rendered lowercase (``true``/``false``) to
    match the Scala-side string format.

    >>> print(StructField("f1", StringType(), True))
    StructField("f1",StringType,true)
    >>> print(StructField("f 1", StringType(), True))
    StructField("f 1",StringType,true)
    >>> print(StructField('f "1"', StringType(), True))
    StructField("f \\"1\\"",StringType,true)
    >>> print(StructField('f \\1', StringType(), True))
    StructField("f \\\\1",StringType,true)
    """
    # Escape backslashes first so that the quote escaping below does not
    # double-escape the backslashes it introduces.
    escapedName = self.name.replace('\\', '\\\\').replace('"', '\\"')
    return 'StructField("%s",%s,%s)' % (escapedName, self.dataType,
                                        str(self.nullable).lower())


class StructType(DataType):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ import org.apache.spark.util.Utils
* Utility functions for working with DataTypes.
*/
object DataType extends RegexParsers {
// Field names may contain arbitrary characters including spaces (e.g. "f 1"),
// so whitespace must be significant to the parser rather than silently skipped.
override def skipWhitespace: Boolean = false

protected lazy val primitiveType: Parser[DataType] =
"StringType" ^^^ StringType |
"FloatType" ^^^ FloatType |
Expand All @@ -55,17 +57,22 @@ object DataType extends RegexParsers {
}

protected lazy val structField: Parser[StructField] =
("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
case name ~ tpe ~ nullable =>
StructField(name, tpe, nullable = nullable)
"StructField(" ~> quotedString ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ {
case name ~ tpe ~ nullable =>
StructField(name, tpe, nullable = nullable)
}

// Parses a double-quoted string literal and returns its unescaped contents.
// Grammar: '"' ( any-char-except-backslash-or-quote | '\' ('\' | '"') )* '"'
// - "[^\\\\\"]".r matches a single character that is neither '\' nor '"'
// - ("\\" ~> "[\\\\\"]".r) consumes a backslash and yields the escaped
//   character itself (only '\' and '"' may be escaped)
// The accumulated characters are concatenated back into the original name.
protected lazy val quotedString: Parser[String] =
"\"" ~> rep("[^\\\\\"]".r | ("\\" ~> "[\\\\\"]".r)) <~ "\"" ^^ {
case ch => ch.mkString
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this was a typo, so I took the chance to fix it.


// Parses the literal tokens "true" / "false" into the corresponding Boolean.
protected lazy val boolVal: Parser[Boolean] =
("true" | "false") ^^ (_.toBoolean)

protected lazy val structType: Parser[DataType] =
"StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ {
"StructType(List(" ~> repsep(structField, ",") <~ "))" ^^ {
case fields => new StructType(fields)
}

Expand Down Expand Up @@ -298,11 +305,15 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT
* @param nullable Indicates if values of this field can be `null` values.
*/
case class StructField(name: String, dataType: DataType, nullable: Boolean) {

  private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = {
    builder.append(s"${prefix}-- ${name}: ${dataType.simpleString} (nullable = ${nullable})\n")
    DataType.buildFormattedString(dataType, s"$prefix |", builder)
  }

  // Renders the field as `StructField("<name>",<dataType>,<nullable>)`.
  // The name is double-quoted with embedded '\' and '"' backslash-escaped,
  // so names containing arbitrary characters can be parsed back by DataType.
  override def toString = {
    val escaped = name.map {
      case '\\'  => "\\\\"
      case '"'   => "\\\""
      case other => other.toString
    }.mkString
    s"""StructField("$escaped",$dataType,$nullable)"""
  }
}

object StructType {
Expand Down Expand Up @@ -364,6 +375,8 @@ case class StructType(fields: Seq[StructField]) extends DataType {
}

def simpleString: String = "struct"

override def toString = s"StructType(List(${fields.map(_.toString).mkString(",")}))"
}

object MapType {
Expand Down
42 changes: 42 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql

import scala.util.Random

import org.scalatest.FunSuite

class DataTypeSuite extends FunSuite {
Expand Down Expand Up @@ -55,4 +57,44 @@ class DataTypeSuite extends FunSuite {
struct(Set("b", "d", "e", "f"))
}
}

// Verifies that StructField.toString quotes the field name and escapes any
// embedded backslash / double-quote characters, so that names containing
// arbitrary characters (SPARK-3421) survive a round trip through the
// DataType string parser.
test("StructField.toString") {
def structFieldWithName(name: String) = StructField(name, StringType, nullable = true)

// A plain name is simply wrapped in double quotes.
assertResult("""StructField("a",StringType,true)""") {
structFieldWithName("a").toString
}

// Parentheses need no escaping; only '\' and '"' are special.
assertResult("""StructField("(a)",StringType,true)""") {
structFieldWithName("(a)").toString
}

// The name a\b" is rendered as a\\b\" inside the quotes.
assertResult("""StructField("a\\b\"",StringType,true)""") {
structFieldWithName("""a\b"""").toString
}
}

// Property-style round-trip test: build StructField strings from random
// field names (forced to include the tricky characters '\', '"' and space),
// then check that DataType's parser recovers the original StructType.
test("parsing StructField strings") {
// 64 random printable characters, shuffled together with the characters
// that require escaping, so every generated name exercises the escape path.
def randomFieldName = {
val mustHave = "\\\" ".toArray
Random.shuffle(Seq.fill(64)(Random.nextPrintableChar()) ++ mustHave).mkString
}

// Mirrors the escaping performed by StructField.toString: '\' and '"'
// are backslash-escaped, all other characters pass through unchanged.
def escape(str: String) = str.flatMap {
case '\"' => "\\\""
case '\\' => "\\\\"
case ch => s"$ch"
}

(0 until 100).foreach { _ =>
val name = randomFieldName
val expected = StructType(Seq(StructField(name, StringType, true)))
// Hand-build the serialized form rather than calling toString, so the
// parser is tested independently of the serializer.
val structFieldString = {
val quotedEscapedName = "\"" + escape(name) + "\""
s"StructField($quotedEscapedName,StringType,true)"
}
val structTypeString = s"StructType(List($structFieldString))"
assert(catalyst.types.DataType(structTypeString) === expected)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive.execution

import scala.util.Try

import org.apache.spark.sql.{SchemaRDD, Row}
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
Expand Down Expand Up @@ -313,7 +312,7 @@ class HiveQuerySuite extends HiveComparisonTest {
"SELECT srcalias.KEY, SRCALIAS.value FROM sRc SrCAlias WHERE SrCAlias.kEy < 15")

test("case sensitivity: registered table") {
val testData: SchemaRDD =
val testData =
TestHive.sparkContext.parallelize(
TestData(1, "str1") ::
TestData(2, "str2") :: Nil)
Expand All @@ -327,7 +326,7 @@ class HiveQuerySuite extends HiveComparisonTest {

def isExplanation(result: SchemaRDD) = {
val explanation = result.select('plan).collect().map { case Row(plan: String) => plan }
explanation.exists(_ == "== Physical Plan ==")
explanation.contains("== Physical Plan ==")
}

test("SPARK-1704: Explain commands as a SchemaRDD") {
Expand Down Expand Up @@ -467,7 +466,7 @@ class HiveQuerySuite extends HiveComparisonTest {
}

// Describe a registered temporary table.
val testData: SchemaRDD =
val testData =
TestHive.sparkContext.parallelize(
TestData(1, "str1") ::
TestData(1, "str2") :: Nil)
Expand Down