11 changes: 8 additions & 3 deletions python/pyspark/sql/functions.py
@@ -1773,11 +1773,11 @@ def json_tuple(col, *fields):
@since(2.1)
def from_json(col, schema, options={}):
"""
Parses a column containing a JSON string into a [[StructType]] with the
specified schema. Returns `null`, in the case of an unparseable string.
Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]]
with the specified schema. Returns `null`, in the case of an unparseable string.

:param col: string column in json format
:param schema: a StructType to use when parsing the json column
:param schema: a StructType or ArrayType to use when parsing the json column
:param options: options to control parsing. accepts the same options as the json datasource

>>> from pyspark.sql.types import *
@@ -1786,6 +1786,11 @@ def from_json(col, schema, options={}):
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=Row(a=1))]
>>> data = [(1, '''[{"a": 1}]''')]
>>> schema = ArrayType(StructType([StructField("a", IntegerType())]))
>>> df = spark.createDataFrame(data, ("key", "value"))
>>> df.select(from_json(df.value, schema).alias("json")).collect()
[Row(json=[Row(a=1)])]
"""

sc = SparkContext._active_spark_context
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json._
import org.apache.spark.sql.catalyst.util.ParseModes
import org.apache.spark.sql.catalyst.util.{GenericArrayData, ParseModes}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
@@ -480,23 +480,45 @@ case class JsonTuple(children: Seq[Expression])
}

/**
* Converts an json input string to a [[StructType]] with the specified schema.
* Converts a json input string to a [[StructType]] or [[ArrayType]] with the specified schema.
*/
case class JsonToStruct(
schema: StructType,
schema: DataType,
options: Map[String, String],
child: Expression,
timeZoneId: Option[String] = None)
extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes {
override def nullable: Boolean = true

def this(schema: StructType, options: Map[String, String], child: Expression) =
def this(schema: DataType, options: Map[String, String], child: Expression) =
this(schema, options, child, None)

override def checkInputDataTypes(): TypeCheckResult = schema match {
[Review comment] @brkyvz (Contributor, Feb 27, 2017):
why not just override:

    override def inputTypes = new TypeCollection(ArrayType, StructType) :: Nil

[Review comment] Member Author:
Uh.. I thought schema is not the child of the expression. Let me check again!

[Review comment] Member Author:
I tried several combinations with TypeCollection but it seems not working.

case _: StructType | ArrayType(_: StructType, _) =>
super.checkInputDataTypes()
case _ => TypeCheckResult.TypeCheckFailure(
s"Input schema ${schema.simpleString} must be a struct or an array of structs.")
}
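// Editorial clarification, not part of the patch: `inputTypes` (via ExpectsInputTypes) only
// constrains the data types of this expression's children -- here the single JSON string
// column -- whereas `schema` is a constructor argument. As the author notes in the review
// thread above, a TypeCollection(ArrayType, StructType) in inputTypes therefore cannot
// express this constraint; the explicit checkInputDataTypes() override handles it instead.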

@transient
lazy val rowSchema = schema match {
case st: StructType => st
case ArrayType(st: StructType, _) => st
}

// This converts parsed rows to the desired output by the given schema.
@transient
lazy val converter = schema match {
case _: StructType =>
(rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null
[Review comment] @brkyvz (Contributor, Mar 3, 2017):
this breaks previous behavior. I would still return the first element if rows.length > 1. Feel free to push back. Also wonder what @marmbrus thinks

[Review comment] Contributor:
I'm okay breaking previous behavior because I'd call truncating an array a bug.

[Review comment] Contributor:
We should list this in the release notes though (i.e. go tag the JIRA).

case ArrayType(_: StructType, _) =>
(rows: Seq[InternalRow]) => new GenericArrayData(rows)
}

@transient
lazy val parser =
new JacksonParser(
schema,
rowSchema,
new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get))

override def dataType: DataType = schema
@@ -505,11 +527,32 @@ case class JsonToStruct(
copy(timeZoneId = Option(timeZoneId))

override def nullSafeEval(json: Any): Any = {
// When input is,
// - `null`: `null`.
// - invalid json: `null`.
// - empty string: `null`.
//
// When the schema is array,
// - json array: `Array(Row(...), ...)`
// - json object: `Array(Row(...))`
// - empty json array: `Array()`.
// - empty json object: `Array(Row(null))`.
//
// When the schema is a struct,
// - json object/array with single element: `Row(...)`
// - json array with multiple elements: `null`
// - empty json array: `null`.
// - empty json object: `Row(null)`.

// We need `null` if the input string is an empty string. `JacksonParser` can
// deal with this but produces `Nil`.
if (json.toString.trim.isEmpty) return null

try {
parser.parse(
converter(parser.parse(
json.asInstanceOf[UTF8String],
CreateJacksonParser.utf8String,
identity[UTF8String]).headOption.orNull
identity[UTF8String]))
} catch {
case _: SparkSQLJsonProcessingException => null
}
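To make the struct-versus-array conversion above easier to follow, here is a minimal, self-contained Scala sketch (not part of the patch; `DType`, `Struct`, `Arr` and the `Map`-based rows are hypothetical stand-ins for Spark's `StructType`, `ArrayType` and `InternalRow`) that mirrors the rowSchema/converter logic:

    // Hypothetical stand-ins for StructType / ArrayType and for parsed rows.
    sealed trait DType
    case class Struct(fieldNames: Seq[String]) extends DType
    case class Arr(elementType: DType) extends DType

    // Mirrors JsonToStruct.converter: a struct schema keeps exactly one parsed row and maps
    // anything else to null (the "array with multiple elements -> null" rule above), while an
    // array-of-structs schema keeps every parsed row.
    def converter(schema: DType): Seq[Map[String, Any]] => Any = schema match {
      case _: Struct =>
        rows => if (rows.length == 1) rows.head else null
      case Arr(_: Struct) =>
        rows => rows
    }

    converter(Struct(Seq("a")))(Seq(Map("a" -> 1)))                      // Map(a -> 1)
    converter(Struct(Seq("a")))(Seq(Map("a" -> 1), Map("a" -> 2)))       // null
    converter(Arr(Struct(Seq("a"))))(Seq(Map("a" -> 1), Map("a" -> 2)))  // List(Map(a -> 1), Map(a -> 2))

As in the patch, a struct schema deliberately returns null for a multi-element array instead of truncating it to the first element, which is the behavior change flagged in the review thread above.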
@@ -22,7 +22,7 @@ import java.util.Calendar
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, ParseModes}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, TimestampType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -372,6 +372,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
)
}

test("from_json - input=array, schema=array, output=array") {
[Review comment] Contributor:
these are great! thanks!

val input = """[{"a": 1}, {"a": 2}]"""
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = InternalRow(1) :: InternalRow(2) :: Nil
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=object, schema=array, output=array of single row") {
val input = """{"a": 1}"""
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = InternalRow(1) :: Nil
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=empty array, schema=array, output=empty array") {
val input = "[ ]"
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = Nil
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=empty object, schema=array, output=array of single row with null") {
val input = "{ }"
val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
val output = InternalRow(null) :: Nil
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=array of single object, schema=struct, output=single row") {
val input = """[{"a": 1}]"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = InternalRow(1)
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=array, schema=struct, output=null") {
val input = """[{"a": 1}, {"a": 2}]"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = null
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=empty array, schema=struct, output=null") {
val input = """[]"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = null
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json - input=empty object, schema=struct, output=single row with null") {
val input = """{ }"""
val schema = StructType(StructField("a", IntegerType) :: Nil)
val output = InternalRow(null)
checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output)
}

test("from_json null input column") {
val schema = StructType(StructField("a", IntegerType) :: Nil)
checkEvaluation(
52 changes: 47 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2964,7 +2964,22 @@ object functions {
* @group collection_funcs
* @since 2.1.0
*/
def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr {
def from_json(e: Column, schema: StructType, options: Map[String, String]): Column =
from_json(e, schema.asInstanceOf[DataType], options)

/**
* (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
* with the specified schema. Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
* @param options options to control how the json is parsed. accepts the same options as the
* json data source.
*
* @group collection_funcs
* @since 2.2.0
*/
def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr {
JsonToStruct(schema, options, e.expr)
}

@@ -2983,6 +2998,21 @@
def from_json(e: Column, schema: StructType, options: java.util.Map[String, String]): Column =
from_json(e, schema, options.asScala.toMap)

/**
* (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType`
* with the specified schema. Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
* @param options options to control how the json is parsed. accepts the same options as the
* json data source.
*
* @group collection_funcs
* @since 2.2.0
*/
def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column =
from_json(e, schema, options.asScala.toMap)

/**
* Parses a column containing a JSON string into a `StructType` with the specified schema.
* Returns `null`, in the case of an unparseable string.
@@ -2997,8 +3027,21 @@
from_json(e, schema, Map.empty[String, String])

/**
* Parses a column containing a JSON string into a `StructType` with the specified schema.
* Returns `null`, in the case of an unparseable string.
* Parses a column containing a JSON string into a `StructType` or `ArrayType`
* with the specified schema. Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string
*
* @group collection_funcs
* @since 2.2.0
*/
def from_json(e: Column, schema: DataType): Column =
from_json(e, schema, Map.empty[String, String])

/**
* Parses a column containing a JSON string into a `StructType` or `ArrayType`
* with the specified schema. Returns `null`, in the case of an unparseable string.
*
* @param e a string column containing JSON data.
* @param schema the schema to use when parsing the json string as a json string
@@ -3007,8 +3050,7 @@
* @since 2.1.0
*/
def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column =
from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options)

from_json(e, DataType.fromJson(schema), options)

/**
* (Scala-specific) Converts a column containing a `StructType` into a JSON string with the
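For a quick end-to-end illustration of the overloads added above, here is a usage sketch (not part of the patch) that assumes a running SparkSession named `spark` with its implicits in scope:

    import org.apache.spark.sql.functions.from_json
    import org.apache.spark.sql.types._
    import spark.implicits._  // hypothetical setup: `spark` is an existing SparkSession

    val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil))
    val df = Seq("""[{"a": 1}, {"a": 2}]""").toDF("value")

    // New DataType overload: each JSON array becomes a single array-of-structs value.
    df.select(from_json($"value", schema)).collect()
    // expected, per the semantics above: Row(Seq(Row(1), Row(2)))

    // The string-schema overload no longer casts the parsed schema to StructType, so a
    // non-struct schema in DataType JSON form (here produced by schema.json) also works.
    df.select(from_json($"value", schema.json, new java.util.HashMap[String, String]())).collect()

The JsonFunctionsSuite changes below exercise the same array-schema path through checkAnswer.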
@@ -19,7 +19,7 @@ package org.apache.spark.sql

import org.apache.spark.sql.functions.{from_json, struct, to_json}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{CalendarIntervalType, IntegerType, StructType, TimestampType}
import org.apache.spark.sql.types._

class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@@ -133,6 +133,29 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
Row(null) :: Nil)
}

test("from_json invalid schema") {
val df = Seq("""{"a" 1}""").toDS()
val schema = ArrayType(StringType)
val message = intercept[AnalysisException] {
df.select(from_json($"value", schema))
}.getMessage

assert(message.contains(
"Input schema array<string> must be a struct or an array of structs."))
}

test("from_json array support") {
val df = Seq("""[{"a": 1, "b": "a"}, {"a": 2}, { }]""").toDS()
val schema = ArrayType(
StructType(
StructField("a", IntegerType) ::
StructField("b", StringType) :: Nil))

checkAnswer(
df.select(from_json($"value", schema)),
Row(Seq(Row(1, "a"), Row(2, null), Row(null, null))))
}

test("to_json") {
val df = Seq(Tuple1(Tuple1(1))).toDF("a")
