[SPARK-24782][SQL] Simplify conf retrieval in SQL expressions
## What changes were proposed in this pull request?

The PR simplifies the retrieval of configs in `size` and `JsonToStructs`, as they can now be accessed from tasks too thanks to SPARK-24250.
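In short: instead of threading a conf value through an auxiliary constructor at creation time, the expression now reads it from the active `SQLConf` itself. A self-contained toy sketch of the two shapes (the `Conf`, `SizeBefore`, and `SizeAfter` names are made up for illustration, not Spark's types):

```scala
object SQLConfDemo {
  // Stand-in for SQLConf.get: a conf readable on the driver and, after
  // SPARK-24250, from inside tasks as well.
  case class Conf(legacySizeOfNull: Boolean)
  object Conf { def get: Conf = Conf(legacySizeOfNull = true) }

  // Before: the flag had to be threaded through the constructor.
  case class SizeBefore(child: Seq[Int], legacySizeOfNull: Boolean) {
    def this(child: Seq[Int]) = this(child, Conf.get.legacySizeOfNull)
  }

  // After: the expression captures the flag itself when instantiated.
  case class SizeAfter(child: Seq[Int]) {
    val legacySizeOfNull: Boolean = Conf.get.legacySizeOfNull
  }

  def main(args: Array[String]): Unit = {
    println(new SizeBefore(Seq(1, 2)).legacySizeOfNull) // true
    println(SizeAfter(Seq(1, 2)).legacySizeOfNull)      // true
  }
}
```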

## How was this patch tested?

existing UTs
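Beyond the unit tests, the conf-driven behaviour is observable end to end; a hypothetical spark-shell session, assuming `spark.sql.legacy.sizeOfNull` is the key behind `SQLConf.LEGACY_SIZE_OF_NULL`:

```scala
// Legacy behaviour: size(NULL) evaluates to -1, matching the
// `_FUNC_(NULL)` example in the expression description below.
spark.conf.set("spark.sql.legacy.sizeOfNull", "true")
spark.sql("SELECT size(NULL)").show()  // -1

// New behaviour: size(NULL) evaluates to NULL.
spark.conf.set("spark.sql.legacy.sizeOfNull", "false")
spark.sql("SELECT size(NULL)").show()  // null
```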

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #21736 from mgaido91/SPARK-24605_followup.
mgaido91 authored and gatorsmile committed Jul 12, 2018
1 parent ff7f6ef commit e008ad1
Showing 6 changed files with 57 additions and 67 deletions.
@@ -89,15 +89,9 @@ trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
       > SELECT _FUNC_(NULL);
        -1
   """)
-case class Size(
-    child: Expression,
-    legacySizeOfNull: Boolean)
-  extends UnaryExpression with ExpectsInputTypes {
+case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
 
-  def this(child: Expression) =
-    this(
-      child,
-      legacySizeOfNull = SQLConf.get.getConf(SQLConf.LEGACY_SIZE_OF_NULL))
+  val legacySizeOfNull = SQLConf.get.legacySizeOfNull
 
   override def dataType: DataType = IntegerType
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
@@ -514,10 +514,11 @@ case class JsonToStructs(
     schema: DataType,
     options: Map[String, String],
     child: Expression,
-    timeZoneId: Option[String],
-    forceNullableSchema: Boolean)
+    timeZoneId: Option[String] = None)
   extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
 
+  val forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA)
+
   // The JSON input data might be missing certain fields. We force the nullability
   // of the user-provided schema to avoid data corruptions. In particular, the parquet-mr encoder
   // can generate incorrect files if values are missing in columns declared as non-nullable.
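Same pattern as `Size`: the flag becomes a `val` captured from the active conf at instantiation. A hedged usage sketch of what the flag controls, assuming a spark-shell session and `spark.sql.fromJsonForceNullableSchema` as the key behind `SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA`:

```scala
import org.apache.spark.sql.functions.from_json
import org.apache.spark.sql.types._
import spark.implicits._  // available in spark-shell

// With the conf enabled, from_json treats the user-provided schema as
// nullable, so a field missing from the JSON surfaces as null instead of
// producing corrupt rows downstream.
spark.conf.set("spark.sql.fromJsonForceNullableSchema", "true")
val schema = new StructType().add("a", LongType, nullable = false)
Seq("{}").toDF("js").select(from_json($"js", schema)).printSchema()
// The nested field `a` is reported nullable despite the non-nullable schema.
```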
@@ -531,8 +532,7 @@ case class JsonToStructs(
       schema = JsonExprUtils.evalSchemaExpr(schema),
       options = options,
       child = child,
-      timeZoneId = None,
-      forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
+      timeZoneId = None)
 
   def this(child: Expression, schema: Expression) = this(child, schema, Map.empty[String, String])

@@ -541,13 +541,7 @@ case class JsonToStructs(
       schema = JsonExprUtils.evalSchemaExpr(schema),
       options = JsonExprUtils.convertToMapData(options),
       child = child,
-      timeZoneId = None,
-      forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
-
-  // Used in `org.apache.spark.sql.functions`
-  def this(schema: DataType, options: Map[String, String], child: Expression) =
-    this(schema, options, child, timeZoneId = None,
-      forceNullableSchema = SQLConf.get.getConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA))
+      timeZoneId = None)
 
   override def checkInputDataTypes(): TypeCheckResult = nullableSchema match {
     case _: StructType | ArrayType(_: StructType, _) | _: MapType =>
@@ -27,8 +27,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
 
   /**
    * The active config object within the current scope.
-   * Note that if you want to refer config values during execution, you have to capture them
-   * in Driver and use the captured values in Executors.
    * See [[SQLConf.get]] for more information.
    */
   def conf: SQLConf = SQLConf.get
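The removed caveat is exactly what SPARK-24250 addressed: `SQLConf.get` can now be read inside tasks as well, so callers no longer need to capture values on the driver by hand. A loose, hypothetical sketch of the thread-local-with-fallback idea (none of these names are Spark's real internals):

```scala
// Hypothetical model of a driver/executor conf lookup. On the driver the
// session's conf is registered in a thread-local; inside a task, a snapshot
// shipped with the task is consulted instead.
object ConfGetter {
  private val registered = new ThreadLocal[Option[Map[String, String]]] {
    override def initialValue(): Option[Map[String, String]] = None
  }

  // Driver side: make `conf` visible to code running on this thread.
  def withRegistered[T](conf: Map[String, String])(body: => T): T = {
    registered.set(Some(conf))
    try body finally registered.set(None)
  }

  // Executor side: fall back to the conf snapshot captured with the task.
  def get(key: String, taskSnapshot: Map[String, String]): Option[String] =
    registered.get().flatMap(_.get(key)).orElse(taskSnapshot.get(key))
}
```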
@@ -24,43 +24,48 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.DateTimeTestUtils
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.CalendarInterval
 
 class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 
-  def testSize(legacySizeOfNull: Boolean, sizeOfNull: Any): Unit = {
+  def testSize(sizeOfNull: Any): Unit = {
     val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
     val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
     val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
 
-    checkEvaluation(Size(a0, legacySizeOfNull), 3)
-    checkEvaluation(Size(a1, legacySizeOfNull), 0)
-    checkEvaluation(Size(a2, legacySizeOfNull), 2)
+    checkEvaluation(Size(a0), 3)
+    checkEvaluation(Size(a1), 0)
+    checkEvaluation(Size(a2), 2)
 
     val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType))
     val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
     val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType))
 
-    checkEvaluation(Size(m0, legacySizeOfNull), 2)
-    checkEvaluation(Size(m1, legacySizeOfNull), 0)
-    checkEvaluation(Size(m2, legacySizeOfNull), 1)
+    checkEvaluation(Size(m0), 2)
+    checkEvaluation(Size(m1), 0)
+    checkEvaluation(Size(m2), 1)
 
     checkEvaluation(
-      Size(Literal.create(null, MapType(StringType, StringType)), legacySizeOfNull),
+      Size(Literal.create(null, MapType(StringType, StringType))),
       expected = sizeOfNull)
     checkEvaluation(
-      Size(Literal.create(null, ArrayType(StringType)), legacySizeOfNull),
+      Size(Literal.create(null, ArrayType(StringType))),
       expected = sizeOfNull)
   }
 
   test("Array and Map Size - legacy") {
-    testSize(legacySizeOfNull = true, sizeOfNull = -1)
+    withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "true") {
+      testSize(sizeOfNull = -1)
+    }
   }
 
   test("Array and Map Size") {
-    testSize(legacySizeOfNull = false, sizeOfNull = null)
+    withSQLConf(SQLConf.LEGACY_SIZE_OF_NULL.key -> "false") {
+      testSize(sizeOfNull = null)
+    }
   }
 
   test("MapKeys/MapValues") {
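The tests now flip the conf around the shared helper instead of passing a flag. Roughly, `withSQLConf` sets the given pairs, runs the body, and restores the previous values; a simplified sketch of that idea (Spark's real helper lives in the test utilities and also guards against static confs), using the `SQLConf` imported above:

```scala
// Simplified sketch of the withSQLConf pattern: set, run, restore.
def withSQLConf[T](pairs: (String, String)*)(body: => T): T = {
  val conf = SQLConf.get
  // Remember the current value (if any) of every key we are about to touch.
  val previous = pairs.map { case (k, _) => k -> Option(conf.getConfString(k, null)) }
  pairs.foreach { case (k, v) => conf.setConfString(k, v) }
  try body finally previous.foreach {
    case (k, Some(v)) => conf.setConfString(k, v)
    case (k, None)    => conf.unsetConf(k)
  }
}
```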
@@ -392,7 +392,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
     val jsonData = """{"a": 1}"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId),
       InternalRow(1)
     )
   }
@@ -401,13 +401,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
     val jsonData = """{"a" 1}"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal(jsonData), gmtId),
       null
     )
 
     // Other modes should still return `null`.
     checkEvaluation(
-      JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId, true),
+      JsonToStructs(schema, Map("mode" -> PermissiveMode.name), Literal(jsonData), gmtId),
       null
     )
   }
@@ -416,70 +416,70 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
     val input = """[{"a": 1}, {"a": 2}]"""
     val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
     val output = InternalRow(1) :: InternalRow(2) :: Nil
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=object, schema=array, output=array of single row") {
     val input = """{"a": 1}"""
     val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
     val output = InternalRow(1) :: Nil
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=empty array, schema=array, output=empty array") {
     val input = "[ ]"
     val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
     val output = Nil
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=empty object, schema=array, output=array of single row with null") {
     val input = "{ }"
     val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
     val output = InternalRow(null) :: Nil
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=array of single object, schema=struct, output=single row") {
     val input = """[{"a": 1}]"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     val output = InternalRow(1)
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=array, schema=struct, output=null") {
     val input = """[{"a": 1}, {"a": 2}]"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     val output = null
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=empty array, schema=struct, output=null") {
     val input = """[]"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     val output = null
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json - input=empty object, schema=struct, output=single row with null") {
     val input = """{ }"""
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     val output = InternalRow(null)
-    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId, true), output)
+    checkEvaluation(JsonToStructs(schema, Map.empty, Literal(input), gmtId), output)
   }
 
   test("from_json null input column") {
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal.create(null, StringType), gmtId),
       null
     )
   }
 
   test("SPARK-20549: from_json bad UTF-8") {
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal(badJson), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal(badJson), gmtId),
       null)
   }

@@ -491,14 +491,14 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
     c.set(2016, 0, 1, 0, 0, 0)
     c.set(Calendar.MILLISECOND, 123)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal(jsonData1), gmtId),
       InternalRow(c.getTimeInMillis * 1000L)
     )
     // The result doesn't change because the json string includes timezone string ("Z" here),
     // which means the string represents the timestamp string in the timezone regardless of
     // the timeZoneId parameter.
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST"), true),
+      JsonToStructs(schema, Map.empty, Literal(jsonData1), Option("PST")),
       InternalRow(c.getTimeInMillis * 1000L)
     )

@@ -512,8 +512,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
         schema,
         Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss"),
         Literal(jsonData2),
-        Option(tz.getID),
-        true),
+        Option(tz.getID)),
       InternalRow(c.getTimeInMillis * 1000L)
     )
     checkEvaluation(
@@ -522,8 +521,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
         Map("timestampFormat" -> "yyyy-MM-dd'T'HH:mm:ss",
           DateTimeUtils.TIMEZONE_OPTION -> tz.getID),
         Literal(jsonData2),
-        gmtId,
-        true),
+        gmtId),
       InternalRow(c.getTimeInMillis * 1000L)
     )
   }
@@ -532,7 +530,7 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
 
   test("SPARK-19543: from_json empty input column") {
     val schema = StructType(StructField("a", IntegerType) :: Nil)
     checkEvaluation(
-      JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId, true),
+      JsonToStructs(schema, Map.empty, Literal.create(" ", StringType), gmtId),
       null
     )
@@ -687,23 +685,24 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper with
 
   test("from_json missing fields") {
     for (forceJsonNullableSchema <- Seq(false, true)) {
-      val input =
-        """{
+      withSQLConf(SQLConf.FROM_JSON_FORCE_NULLABLE_SCHEMA.key -> forceJsonNullableSchema.toString) {
+        val input =
+          """{
           | "a": 1,
           | "c": "foo"
           |}
           |""".stripMargin
-      val jsonSchema = new StructType()
-        .add("a", LongType, nullable = false)
-        .add("b", StringType, nullable = false)
-        .add("c", StringType, nullable = false)
-      val output = InternalRow(1L, null, UTF8String.fromString("foo"))
-      val expr = JsonToStructs(
-        jsonSchema, Map.empty, Literal.create(input, StringType), gmtId, forceJsonNullableSchema)
-      checkEvaluation(expr, output)
-      val schema = expr.dataType
-      val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema
-      assert(schemaToCompare == schema)
+        val jsonSchema = new StructType()
+          .add("a", LongType, nullable = false)
+          .add("b", StringType, nullable = false)
+          .add("c", StringType, nullable = false)
+        val output = InternalRow(1L, null, UTF8String.fromString("foo"))
+        val expr = JsonToStructs(jsonSchema, Map.empty, Literal.create(input, StringType), gmtId)
+        checkEvaluation(expr, output)
+        val schema = expr.dataType
+        val schemaToCompare = if (forceJsonNullableSchema) jsonSchema.asNullable else jsonSchema
+        assert(schemaToCompare == schema)
+      }
     }
   }

4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3304,7 +3304,7 @@ object functions {
    * @since 2.2.0
    */
   def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr {
-    new JsonToStructs(schema, options, e.expr)
+    JsonToStructs(schema, options, e.expr)
   }
 
   /**
@@ -3495,7 +3495,7 @@
    * @group collection_funcs
    * @since 1.5.0
    */
-  def size(e: Column): Column = withExpr { new Size(e.expr) }
+  def size(e: Column): Column = withExpr { Size(e.expr) }
 
   /**
    * Sorts the input array for the given column in ascending order,
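At the public API surface nothing changes: `functions.from_json` and `functions.size` simply switch from the auxiliary constructors to the case-class `apply` methods. A hedged spark-shell example tying the two together (the column name `js` is made up for illustration):

```scala
import org.apache.spark.sql.functions.{from_json, size}
import org.apache.spark.sql.types._
import spark.implicits._  // available in spark-shell

// Both calls below now build expressions that resolve their conf-dependent
// behaviour from the active SQLConf rather than from constructor arguments.
val schema = ArrayType(new StructType().add("a", IntegerType))
val df = Seq("""[{"a": 1}, {"a": 2}]""").toDF("js")
df.select(size(from_json($"js", schema))).show()  // 2
```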
