From 14bb398fae974137c3e38162cefc088e12838258 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 5 Mar 2017 03:53:19 -0800 Subject: [PATCH] [SPARK-19254][SQL] Support Seq, Map, and Struct in functions.lit ## What changes were proposed in this pull request? This PR adds support for Seq, Map, and Struct literals in functions.lit; it adds a new API named `typedLit` that takes a `TypeTag` to avoid type erasure. ## How was this patch tested? Added tests in `LiteralExpressionSuite` and `ColumnExpressionSuite` Author: Takeshi Yamamuro Author: Takeshi YAMAMURO Closes #16610 from maropu/SPARK-19254. --- .../sql/catalyst/expressions/literals.scala | 12 ++- .../expressions/LiteralExpressionSuite.scala | 90 ++++++++++++++++--- .../org/apache/spark/sql/functions.scala | 25 ++++-- .../spark/sql/ColumnExpressionSuite.scala | 14 +++ 4 files changed, 121 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e66fb893394eb..eaeaf08c37b4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -32,11 +32,13 @@ import java.util.Objects import javax.xml.bind.DatatypeConverter import scala.math.{BigDecimal, BigInt} +import scala.reflect.runtime.universe.TypeTag +import scala.util.Try import org.json4s.JsonAST._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -153,6 +155,14 @@ object Literal { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } + def create[T : TypeTag](v: T): Literal = Try { + val 
ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T] + val convert = CatalystTypeConverters.createToCatalystConverter(dataType) + Literal(convert(v), dataType) + }.getOrElse { + Literal(v) + } + /** * Create a literal with default value for given DataType */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 15e8e6c057baf..a9e0eb0e377a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import java.nio.charset.StandardCharsets +import scala.reflect.runtime.universe.{typeTag, TypeTag} + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -75,6 +77,9 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("boolean literals") { checkEvaluation(Literal(true), true) checkEvaluation(Literal(false), false) + + checkEvaluation(Literal.create(true), true) + checkEvaluation(Literal.create(false), false) } test("int literals") { @@ -83,36 +88,60 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal(d.toLong), d.toLong) checkEvaluation(Literal(d.toShort), d.toShort) checkEvaluation(Literal(d.toByte), d.toByte) + + checkEvaluation(Literal.create(d), d) + checkEvaluation(Literal.create(d.toLong), d.toLong) + checkEvaluation(Literal.create(d.toShort), d.toShort) + 
checkEvaluation(Literal.create(d.toByte), d.toByte) } checkEvaluation(Literal(Long.MinValue), Long.MinValue) checkEvaluation(Literal(Long.MaxValue), Long.MaxValue) + + checkEvaluation(Literal.create(Long.MinValue), Long.MinValue) + checkEvaluation(Literal.create(Long.MaxValue), Long.MaxValue) } test("double literals") { List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d => checkEvaluation(Literal(d), d) checkEvaluation(Literal(d.toFloat), d.toFloat) + + checkEvaluation(Literal.create(d), d) + checkEvaluation(Literal.create(d.toFloat), d.toFloat) } checkEvaluation(Literal(Double.MinValue), Double.MinValue) checkEvaluation(Literal(Double.MaxValue), Double.MaxValue) checkEvaluation(Literal(Float.MinValue), Float.MinValue) checkEvaluation(Literal(Float.MaxValue), Float.MaxValue) + checkEvaluation(Literal.create(Double.MinValue), Double.MinValue) + checkEvaluation(Literal.create(Double.MaxValue), Double.MaxValue) + checkEvaluation(Literal.create(Float.MinValue), Float.MinValue) + checkEvaluation(Literal.create(Float.MaxValue), Float.MaxValue) + } test("string literals") { checkEvaluation(Literal(""), "") checkEvaluation(Literal("test"), "test") checkEvaluation(Literal("\u0000"), "\u0000") + + checkEvaluation(Literal.create(""), "") + checkEvaluation(Literal.create("test"), "test") + checkEvaluation(Literal.create("\u0000"), "\u0000") } test("sum two literals") { checkEvaluation(Add(Literal(1), Literal(1)), 2) + checkEvaluation(Add(Literal.create(1), Literal.create(1)), 2) } test("binary literals") { checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0)) checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2)) + + checkEvaluation(Literal.create(new Array[Byte](0)), new Array[Byte](0)) + checkEvaluation(Literal.create(new Array[Byte](2)), new Array[Byte](2)) } test("decimal") { @@ -124,24 +153,63 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { Decimal((d * 
1000L).toLong, 10, 3)) checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d)) checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d)) + + checkEvaluation(Literal.create(Decimal(d)), Decimal(d)) + checkEvaluation(Literal.create(Decimal(d.toInt)), Decimal(d.toInt)) + checkEvaluation(Literal.create(Decimal(d.toLong)), Decimal(d.toLong)) + checkEvaluation(Literal.create(Decimal((d * 1000L).toLong, 10, 3)), + Decimal((d * 1000L).toLong, 10, 3)) + checkEvaluation(Literal.create(BigDecimal(d.toString)), Decimal(d)) + checkEvaluation(Literal.create(new java.math.BigDecimal(d.toString)), Decimal(d)) + } } + private def toCatalyst[T: TypeTag](value: T): Any = { + val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T] + CatalystTypeConverters.createToCatalystConverter(dataType)(value) + } + test("array") { - def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = { - val toCatalyst = (a: Array[_], elementType: DataType) => { - CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a) - } - checkEvaluation(Literal(a), toCatalyst(a, elementType)) + def checkArrayLiteral[T: TypeTag](a: Array[T]): Unit = { + checkEvaluation(Literal(a), toCatalyst(a)) + checkEvaluation(Literal.create(a), toCatalyst(a)) + } + checkArrayLiteral(Array(1, 2, 3)) + checkArrayLiteral(Array("a", "b", "c")) + checkArrayLiteral(Array(1.0, 4.0)) + checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR)) + } + + test("seq") { + def checkSeqLiteral[T: TypeTag](a: Seq[T], elementType: DataType): Unit = { + checkEvaluation(Literal.create(a), toCatalyst(a)) } - checkArrayLiteral(Array(1, 2, 3), IntegerType) - checkArrayLiteral(Array("a", "b", "c"), StringType) - checkArrayLiteral(Array(1.0, 4.0), DoubleType) - checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR), + checkSeqLiteral(Seq(1, 2, 3), IntegerType) + checkSeqLiteral(Seq("a", "b", "c"), StringType) + 
checkSeqLiteral(Seq(1.0, 4.0), DoubleType) + checkSeqLiteral(Seq(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR), CalendarIntervalType) } - test("unsupported types (map and struct) in literals") { + test("map") { + def checkMapLiteral[T: TypeTag](m: T): Unit = { + checkEvaluation(Literal.create(m), toCatalyst(m)) + } + checkMapLiteral(Map("a" -> 1, "b" -> 2, "c" -> 3)) + checkMapLiteral(Map("1" -> 1.0, "2" -> 2.0, "3" -> 3.0)) + } + + test("struct") { + def checkStructLiteral[T: TypeTag](s: T): Unit = { + checkEvaluation(Literal.create(s), toCatalyst(s)) + } + checkStructLiteral((1, 3.0, "abcde")) + checkStructLiteral(("de", 1, 2.0f)) + checkStructLiteral((1, ("fgh", 3.0))) + } + + test("unsupported types (map and struct) in Literal.apply") { def checkUnsupportedTypeInLiteral(v: Any): Unit = { val errMsgMap = intercept[RuntimeException] { Literal(v) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 24ed906d33683..2247010ac3f3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -91,15 +91,24 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def lit(literal: Any): Column = { - literal match { - case c: Column => return c - case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name) - case _ => // continue - } + def lit(literal: Any): Column = typedLit(literal) - val literalExpr = Literal(literal) - Column(literalExpr) + /** + * Creates a [[Column]] of literal value. + * + * The passed in object is returned directly if it is already a [[Column]]. + * If the object is a Scala Symbol, it is converted into a [[Column]] also. + * Otherwise, a new [[Column]] is created to represent the literal value. + * The difference between this function and [[lit]] is that this function + * can handle parameterized scala types e.g.: List, Seq and Map. 
+ * + * @group normal_funcs + * @since 2.2.0 + */ + def typedLit[T : TypeTag](literal: T): Column = literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => Column(Literal.create(literal)) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index ee280a313cc04..b0f398dab7455 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -712,4 +712,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)), testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39))) } + + test("typedLit") { + val df = Seq(Tuple1(0)).toDF("a") + // Only check the types `lit` cannot handle + checkAnswer( + df.select(typedLit(Seq(1, 2, 3))), + Row(Seq(1, 2, 3)) :: Nil) + checkAnswer( + df.select(typedLit(Map("a" -> 1, "b" -> 2))), + Row(Map("a" -> 1, "b" -> 2)) :: Nil) + checkAnswer( + df.select(typedLit(("a", 2, 1.0))), + Row(Row("a", 2, 1.0)) :: Nil) + } }