[SPARK-48545][SQL] Create to_avro and from_avro SQL functions to match DataFrame equivalents

### What changes were proposed in this pull request?

This PR creates two new SQL functions "to_avro" and "from_avro" to match existing DataFrame equivalents.

For example:

```
sql(
  """
    |create table t as
    |  select named_struct('u', named_struct('member0', member0, 'member1', member1)) as s
    |  from values (1, null), (null,  'a') tab(member0, member1)
    |""".stripMargin)

val jsonFormatSchema =
  """
    |{
    |  "type": "record",
    |  "name": "struct",
    |  "fields": [{
    |    "name": "u",
    |    "type": ["int","string"]
    |  }]
    |}
    |""".stripMargin

spark.sql(
  s"""
    |select from_avro(result, '$jsonFormatSchema', map()).u from (
    |  select to_avro(s, '$jsonFormatSchema') as result from t
    |)
    |""".stripMargin)
  .collect()

> {1, NULL}
> {NULL, "a"}
```
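In this round trip, `to_avro` serializes the struct column `s` to Avro binary using the given schema, and `from_avro` deserializes it back into a Catalyst struct; the `map()` passed as the third argument of `from_avro` is an empty options map.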

### Why are the changes needed?

This brings parity between SQL and DataFrame APIs in Apache Spark.
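For reference, here is a minimal sketch of the pre-existing DataFrame-side equivalents that the new SQL functions mirror (assuming the standard `org.apache.spark.sql.avro.functions` import and the `jsonFormatSchema` string from the example above):

```
import org.apache.spark.sql.avro.functions.{from_avro, to_avro}
import org.apache.spark.sql.functions.col

// Serialize the struct column to Avro binary, then deserialize it back,
// mirroring the SQL round trip shown above.
val serialized = spark.table("t").select(to_avro(col("s"), jsonFormatSchema).as("result"))
val roundTripped = serialized.select(from_avro(col("result"), jsonFormatSchema).as("u"))
```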

### Does this PR introduce _any_ user-facing change?

Yes, see above.

### How was this patch tested?

This PR adds extra unit tests, and I also checked that the functions work with `spark-shell`.
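As an illustration, a spark-shell smoke test might look like the following (a hypothetical session, not taken from the PR; it assumes the primitive Avro schema `"int"` round-trips an integer):

```
$ bin/spark-shell
scala> // Round-trip an integer through Avro binary and back.
scala> spark.sql("""SELECT from_avro(to_avro(1, '"int"'), '"int"', map()) AS v""").collect()
// expected result: Array([1])
```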

### Was this patch authored or co-authored using generative AI tooling?

No. This patch did not use GitHub Copilot or other generative AI tooling.

Closes #46977 from dtenedor/from-avro.

Authored-by: Daniel Tenedorio <daniel.tenedorio@databricks.com>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
dtenedor authored and gengliangwang committed Jun 21, 2024
1 parent 7e5a461 commit b1677a4
Showing 6 changed files with 271 additions and 3 deletions.
@@ -26,7 +26,7 @@ import org.apache.avro.generic.{GenericDatumWriter, GenericRecord, GenericRecord
import org.apache.avro.io.EncoderFactory

import org.apache.spark.SparkException
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.execution.LocalTableScanExec
import org.apache.spark.sql.functions.{col, lit, struct}
import org.apache.spark.sql.internal.SQLConf
@@ -286,4 +286,85 @@ class AvroFunctionsSuite extends QueryTest with SharedSparkSession {
      assert(msg.contains("Invalid default for field id: null not a \"long\""))
    }
  }

  test("SPARK-48545: from_avro and to_avro SQL functions") {
    withTable("t") {
      sql(
        """
          |create table t as
          |  select named_struct('u', named_struct('member0', member0, 'member1', member1)) as s
          |  from values (1, null), (null, 'a') tab(member0, member1)
          |""".stripMargin)
      val jsonFormatSchema =
        """
          |{
          |  "type": "record",
          |  "name": "struct",
          |  "fields": [{
          |    "name": "u",
          |    "type": ["int","string"]
          |  }]
          |}
          |""".stripMargin
      val toAvroSql =
        s"""
           |select to_avro(s, '$jsonFormatSchema') as result from t
           |""".stripMargin
      val avroResult = spark.sql(toAvroSql).collect()
      assert(avroResult != null)
      checkAnswer(
        spark.sql(s"select from_avro(result, '$jsonFormatSchema', map()).u from ($toAvroSql)"),
        Seq(Row(Row(1, null)),
          Row(Row(null, "a"))))

      // Negative tests.
      checkError(
        exception = intercept[AnalysisException](sql(
          s"""
             |select to_avro(s, 42) as result from t
             |""".stripMargin)),
        errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
        parameters = Map("sqlExpr" -> "\"toavro(s, 42)\"",
          "msg" -> ("The second argument of the TO_AVRO SQL function must be a constant string " +
            "containing the JSON representation of the schema to use for converting the value to " +
            "AVRO format"),
          "hint" -> ""),
        queryContext = Array(ExpectedContext(
          fragment = "to_avro(s, 42)",
          start = 8,
          stop = 21)))
      checkError(
        exception = intercept[AnalysisException](sql(
          s"""
             |select from_avro(s, 42, '') as result from t
             |""".stripMargin)),
        errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
        parameters = Map("sqlExpr" -> "\"fromavro(s, 42, )\"",
          "msg" -> ("The second argument of the FROM_AVRO SQL function must be a constant string " +
            "containing the JSON representation of the schema to use for converting the value " +
            "from AVRO format"),
          "hint" -> ""),
        queryContext = Array(ExpectedContext(
          fragment = "from_avro(s, 42, '')",
          start = 8,
          stop = 27)))
      checkError(
        exception = intercept[AnalysisException](sql(
          s"""
             |select from_avro(s, '$jsonFormatSchema', 42) as result from t
             |""".stripMargin)),
        errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
        parameters = Map(
          "sqlExpr" ->
            s"\"fromavro(s, $jsonFormatSchema, 42)\"".stripMargin,
          "msg" -> ("The third argument of the FROM_AVRO SQL function must be a constant map of " +
            "strings to strings containing the options to use for converting the value " +
            "from AVRO format"),
          "hint" -> ""),
        queryContext = Array(ExpectedContext(
          fragment = s"from_avro(s, '$jsonFormatSchema', 42)",
          start = 8,
          stop = 138)))
    }
  }
}
@@ -860,7 +860,11 @@ object FunctionRegistry {
    // Xml
    expression[XmlToStructs]("from_xml"),
    expression[SchemaOfXml]("schema_of_xml"),
-   expression[StructsToXml]("to_xml")
+   expression[StructsToXml]("to_xml"),
+
+   // Avro
+   expression[FromAvro]("from_avro"),
+   expression[ToAvro]("to_avro")
  )

  val builtin: SimpleFunctionRegistry = {
@@ -0,0 +1,175 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{MapType, NullType, StringType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils

/**
 * Converts a binary column of Avro format into its corresponding Catalyst value.
 * This is a thin wrapper over the [[AvroDataToCatalyst]] class to create a SQL function.
 *
 * @param child the Catalyst binary input column.
 * @param jsonFormatSchema the Avro schema in JSON string format.
 * @param options the options to use when performing the conversion.
 *
 * @since 4.0.0
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
    _FUNC_(child, jsonFormatSchema, options) - Converts a binary Avro value into a Catalyst value.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(s, '{"type": "record", "name": "struct", "fields": [{ "name": "u", "type": ["int","string"] }]}', map()) IS NULL AS result FROM (SELECT NAMED_STRUCT('u', NAMED_STRUCT('member0', member0, 'member1', member1)) AS s FROM VALUES (1, NULL), (NULL, 'a') tab(member0, member1));
       [false]
  """,
  note = """
    The specified schema must match actual schema of the read data, otherwise the behavior
    is undefined: it may fail or return arbitrary result.
    To deserialize the data with a compatible and evolved schema, the expected Avro schema can be
    set via the corresponding option.
  """,
  group = "misc_funcs",
  since = "4.0.0"
)
// scalastyle:on line.size.limit
case class FromAvro(child: Expression, jsonFormatSchema: Expression, options: Expression)
  extends TernaryExpression with RuntimeReplaceable {
  override def first: Expression = child
  override def second: Expression = jsonFormatSchema
  override def third: Expression = options

  override def withNewChildrenInternal(
      newFirst: Expression, newSecond: Expression, newThird: Expression): Expression = {
    copy(child = newFirst, jsonFormatSchema = newSecond, options = newThird)
  }

  override def checkInputDataTypes(): TypeCheckResult = {
    val schemaCheck = jsonFormatSchema.dataType match {
      case _: StringType |
           _: NullType
          if jsonFormatSchema.foldable =>
        None
      case _ =>
        Some(TypeCheckResult.TypeCheckFailure(
          "The second argument of the FROM_AVRO SQL function must be a constant string " +
          "containing the JSON representation of the schema to use for converting the value " +
          "from AVRO format"))
    }
    val optionsCheck = options.dataType match {
      case MapType(StringType, StringType, _) |
           MapType(NullType, NullType, _) |
           _: NullType
          if options.foldable =>
        None
      case _ =>
        Some(TypeCheckResult.TypeCheckFailure(
          "The third argument of the FROM_AVRO SQL function must be a constant map of strings to " +
          "strings containing the options to use for converting the value from AVRO format"))
    }
    schemaCheck.getOrElse(
      optionsCheck.getOrElse(
        TypeCheckResult.TypeCheckSuccess))
  }

  override def replacement: Expression = {
    val schemaValue: String = jsonFormatSchema.eval() match {
      case s: UTF8String =>
        s.toString
      case null =>
        ""
    }
    val optionsValue: Map[String, String] = options.eval() match {
      case a: ArrayBasedMapData if a.keyArray.array.nonEmpty =>
        val keys: Array[String] = a.keyArray.array.map(_.toString)
        val values: Array[String] = a.valueArray.array.map(_.toString)
        keys.zip(values).toMap
      case _ =>
        Map.empty
    }
    val constructor =
      Utils.classForName("org.apache.spark.sql.avro.AvroDataToCatalyst").getConstructors().head
    val expr = constructor.newInstance(child, schemaValue, optionsValue)
    expr.asInstanceOf[Expression]
  }
}

/**
 * Converts a Catalyst binary input value into its corresponding Avro format result.
 * This is a thin wrapper over the [[CatalystDataToAvro]] class to create a SQL function.
 *
 * @param child the Catalyst binary input column.
 * @param jsonFormatSchema the Avro schema in JSON string format.
 *
 * @since 4.0.0
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
    _FUNC_(child, jsonFormatSchema) - Converts a Catalyst binary input value into its corresponding
    Avro format result.
  """,
  examples = """
    Examples:
      > SELECT _FUNC_(s, '{"type": "record", "name": "struct", "fields": [{ "name": "u", "type": ["int","string"] }]}', MAP()) IS NULL FROM (SELECT NULL AS s);
       [true]
  """,
  group = "misc_funcs",
  since = "4.0.0"
)
// scalastyle:on line.size.limit
case class ToAvro(child: Expression, jsonFormatSchema: Expression)
  extends BinaryExpression with RuntimeReplaceable {
  override def left: Expression = child

  override def right: Expression = jsonFormatSchema

  override def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Expression = {
    copy(child = newLeft, jsonFormatSchema = newRight)
  }

  override def checkInputDataTypes(): TypeCheckResult = {
    jsonFormatSchema.dataType match {
      case _: StringType if jsonFormatSchema.foldable =>
        TypeCheckResult.TypeCheckSuccess
      case _ =>
        TypeCheckResult.TypeCheckFailure(
          "The second argument of the TO_AVRO SQL function must be a constant string " +
          "containing the JSON representation of the schema to use for converting the value " +
          "to AVRO format")
    }
  }

  override def replacement: Expression = {
    val schemaValue: Option[String] = jsonFormatSchema.eval() match {
      case null =>
        None
      case s: UTF8String =>
        Some(s.toString)
    }
    val constructor =
      Utils.classForName("org.apache.spark.sql.avro.CatalystDataToAvro").getConstructors().head
    val expr = constructor.newInstance(child, schemaValue)
    expr.asInstanceOf[Expression]
  }
}
@@ -117,6 +117,10 @@ class ExpressionsSchemaSuite extends QueryTest with SharedSparkSession {
      // Note: We need to filter out the commands that set the parameters, such as:
      // SET spark.sql.parser.escapedStringLiterals=true
      example.split(" > ").tail.filterNot(_.trim.startsWith("SET")).take(1).foreach {
        case _ if funcName == "from_avro" || funcName == "to_avro" =>
          // Skip running the example queries for the from_avro and to_avro functions because
          // these functions dynamically load the AvroDataToCatalyst or CatalystDataToAvro classes
          // which are not available in this test.
        case exampleRe(sql, _) =>
          val df = spark.sql(sql)
          val escapedSql = sql.replaceAll("\\|", "&#124;")
@@ -225,6 +225,9 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
      // Throws an error
      "org.apache.spark.sql.catalyst.expressions.RaiseErrorExpressionBuilder",
      "org.apache.spark.sql.catalyst.expressions.AssertTrue",
      // Requires dynamic class loading not available in this test suite.
      "org.apache.spark.sql.catalyst.expressions.FromAvro",
      "org.apache.spark.sql.catalyst.expressions.ToAvro",
      classOf[CurrentUser].getName,
      // The encrypt expression includes a random initialization vector to its encrypted result
      classOf[AesEncrypt].getName)
sql/gen-sql-functions-docs.py (3 changes: 2 additions & 1 deletion)
@@ -163,7 +163,8 @@ def _make_pretty_examples(jspark, infos):

pretty_output = ""
for info in infos:
if info.examples.startswith("\n Examples:"):
if (info.examples.startswith("\n Examples:")
and info.name.lower() not in ("from_avro", "to_avro")):
output = []
output.append("-- %s" % info.name)
query_examples = filter(lambda x: x.startswith(" > "), info.examples.split("\n"))
