Skip to content

Commit

Permalink
[SPARK-19254][SQL] Support Seq, Map, and Struct in functions.lit
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
This PR adds support for Seq, Map, and Struct values in `functions.lit`; it adds a new interface named `typedLit` that takes a `TypeTag` in order to avoid type erasure.

## How was this patch tested?
Added tests in `LiteralExpressionSuite`

Author: Takeshi Yamamuro <yamamuro@apache.org>
Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes #16610 from maropu/SPARK-19254.
  • Loading branch information
maropu authored and hvanhovell committed Mar 5, 2017
1 parent f48461a commit 14bb398
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ import java.util.Objects
import javax.xml.bind.DatatypeConverter

import scala.math.{BigDecimal, BigInt}
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try

import org.json4s.JsonAST._

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -153,6 +155,14 @@ object Literal {
Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
}

/**
 * Creates a literal from a Scala value, using the `TypeTag` of `T` to pick its data type.
 *
 * The data type is taken from `ScalaReflection.schemaFor[T]`, the value is converted with the
 * matching Catalyst converter, and the two are combined into a [[Literal]]. If any step of that
 * path throws a non-fatal exception, the result falls back to `Literal(v)`.
 *
 * @tparam T static type of the value, captured through its `TypeTag`
 * @param v the Scala value to wrap
 * @return the literal built from the reflected schema, or `Literal(v)` when that attempt fails
 */
def create[T : TypeTag](v: T): Literal = {
  val reflected = Try {
    val ScalaReflection.Schema(catalystType, _) = ScalaReflection.schemaFor[T]
    val toCatalystValue = CatalystTypeConverters.createToCatalystConverter(catalystType)
    Literal(toCatalystValue(v), catalystType)
  }
  // `getOrElse` takes its argument by name, so the fallback literal is only built on failure.
  reflected.getOrElse(Literal(v))
}

/**
* Create a literal with default value for given DataType
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions

import java.nio.charset.StandardCharsets

import scala.reflect.runtime.universe.{typeTag, TypeTag}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -75,6 +77,9 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
test("boolean literals") {
  val flags = Seq(true, false)

  // Literals built through `Literal.apply`.
  flags.foreach(flag => checkEvaluation(Literal(flag), flag))

  // Literals built through the TypeTag-based `Literal.create`.
  flags.foreach(flag => checkEvaluation(Literal.create(flag), flag))
}

test("int literals") {
Expand All @@ -83,36 +88,60 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal(d.toLong), d.toLong)
checkEvaluation(Literal(d.toShort), d.toShort)
checkEvaluation(Literal(d.toByte), d.toByte)

checkEvaluation(Literal.create(d), d)
checkEvaluation(Literal.create(d.toLong), d.toLong)
checkEvaluation(Literal.create(d.toShort), d.toShort)
checkEvaluation(Literal.create(d.toByte), d.toByte)
}
checkEvaluation(Literal(Long.MinValue), Long.MinValue)
checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)

checkEvaluation(Literal.create(Long.MinValue), Long.MinValue)
checkEvaluation(Literal.create(Long.MaxValue), Long.MaxValue)
}

test("double literals") {
  // Zeroes and infinities, checked both as Double and narrowed to Float.
  val specialValues = List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity)
  for (value <- specialValues) {
    checkEvaluation(Literal(value), value)
    checkEvaluation(Literal(value.toFloat), value.toFloat)

    checkEvaluation(Literal.create(value), value)
    checkEvaluation(Literal.create(value.toFloat), value.toFloat)
  }

  val doubleBounds = Seq(Double.MinValue, Double.MaxValue)
  val floatBounds = Seq(Float.MinValue, Float.MaxValue)

  // Range boundaries through `Literal.apply`.
  doubleBounds.foreach(bound => checkEvaluation(Literal(bound), bound))
  floatBounds.foreach(bound => checkEvaluation(Literal(bound), bound))

  // Range boundaries through the TypeTag-based `Literal.create`.
  doubleBounds.foreach(bound => checkEvaluation(Literal.create(bound), bound))
  floatBounds.foreach(bound => checkEvaluation(Literal.create(bound), bound))
}

test("string literals") {
  // Empty string, ordinary text, and a string holding only the NUL character.
  val samples = Seq("", "test", "\u0000")

  samples.foreach(text => checkEvaluation(Literal(text), text))
  samples.foreach(text => checkEvaluation(Literal.create(text), text))
}

test("sum two literals") {
  val operand = 1
  val expectedSum = operand + operand

  // The same addition, once per literal factory.
  checkEvaluation(Add(Literal(operand), Literal(operand)), expectedSum)
  checkEvaluation(Add(Literal.create(operand), Literal.create(operand)), expectedSum)
}

test("binary literals") {
  val lengths = Seq(0, 2)

  // Explicit `BinaryType` passed to the two-argument `Literal.create`.
  lengths.foreach { len =>
    checkEvaluation(Literal.create(new Array[Byte](len), BinaryType), new Array[Byte](len))
  }

  // Data type derived from the `Array[Byte]` TypeTag instead.
  lengths.foreach { len =>
    checkEvaluation(Literal.create(new Array[Byte](len)), new Array[Byte](len))
  }
}

test("decimal") {
Expand All @@ -124,24 +153,63 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
Decimal((d * 1000L).toLong, 10, 3))
checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d))
checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d))

checkEvaluation(Literal.create(Decimal(d)), Decimal(d))
checkEvaluation(Literal.create(Decimal(d.toInt)), Decimal(d.toInt))
checkEvaluation(Literal.create(Decimal(d.toLong)), Decimal(d.toLong))
checkEvaluation(Literal.create(Decimal((d * 1000L).toLong, 10, 3)),
Decimal((d * 1000L).toLong, 10, 3))
checkEvaluation(Literal.create(BigDecimal(d.toString)), Decimal(d))
checkEvaluation(Literal.create(new java.math.BigDecimal(d.toString)), Decimal(d))

}
}

/**
 * Converts a Scala value to its Catalyst form, using the data type that
 * `ScalaReflection.schemaFor` reports for `T`.
 */
private def toCatalyst[T: TypeTag](value: T): Any = {
  val ScalaReflection.Schema(catalystType, _) = ScalaReflection.schemaFor[T]
  val converter = CatalystTypeConverters.createToCatalystConverter(catalystType)
  converter(value)
}

test("array") {
def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = {
val toCatalyst = (a: Array[_], elementType: DataType) => {
CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a)
}
checkEvaluation(Literal(a), toCatalyst(a, elementType))
def checkArrayLiteral[T: TypeTag](a: Array[T]): Unit = {
checkEvaluation(Literal(a), toCatalyst(a))
checkEvaluation(Literal.create(a), toCatalyst(a))
}
checkArrayLiteral(Array(1, 2, 3))
checkArrayLiteral(Array("a", "b", "c"))
checkArrayLiteral(Array(1.0, 4.0))
checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR))
}

test("seq") {
def checkSeqLiteral[T: TypeTag](a: Seq[T], elementType: DataType): Unit = {
checkEvaluation(Literal.create(a), toCatalyst(a))
}
checkArrayLiteral(Array(1, 2, 3), IntegerType)
checkArrayLiteral(Array("a", "b", "c"), StringType)
checkArrayLiteral(Array(1.0, 4.0), DoubleType)
checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
checkSeqLiteral(Seq(1, 2, 3), IntegerType)
checkSeqLiteral(Seq("a", "b", "c"), StringType)
checkSeqLiteral(Seq(1.0, 4.0), DoubleType)
checkSeqLiteral(Seq(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
CalendarIntervalType)
}

test("unsupported types (map and struct) in literals") {
test("map") {
  // A map literal must evaluate to the Catalyst conversion of the same map.
  def checkMapLiteral[T: TypeTag](m: T): Unit =
    checkEvaluation(Literal.create(m), toCatalyst(m))

  val intValued = Map("a" -> 1, "b" -> 2, "c" -> 3)
  val doubleValued = Map("1" -> 1.0, "2" -> 2.0, "3" -> 3.0)

  checkMapLiteral(intValued)
  checkMapLiteral(doubleValued)
}

test("struct") {
  // Tuples serve as the struct values; each literal is compared with its Catalyst conversion.
  def checkStructLiteral[T: TypeTag](s: T): Unit =
    checkEvaluation(Literal.create(s), toCatalyst(s))

  val flatTriple = (1, 3.0, "abcde")
  val tripleWithFloat = ("de", 1, 2.0f)
  val nestedPair = (1, ("fgh", 3.0))

  checkStructLiteral(flatTriple)
  checkStructLiteral(tripleWithFloat)
  checkStructLiteral(nestedPair)
}

test("unsupported types (map and struct) in Literal.apply") {
def checkUnsupportedTypeInLiteral(v: Any): Unit = {
val errMsgMap = intercept[RuntimeException] {
Literal(v)
Expand Down
25 changes: 17 additions & 8 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,24 @@ object functions {
* @group normal_funcs
* @since 1.3.0
*/
def lit(literal: Any): Column = {
literal match {
case c: Column => return c
case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name)
case _ => // continue
}
def lit(literal: Any): Column = typedLit(literal)

val literalExpr = Literal(literal)
Column(literalExpr)
/**
 * Creates a [[Column]] of literal value.
 *
 * A value that is already a [[Column]] is returned unchanged, and a Scala Symbol is turned
 * into a [[Column]] named after the symbol. Any other value is wrapped in a new [[Column]]
 * representing the literal. Unlike [[lit]], this function can handle parameterized scala
 * types e.g.: List, Seq and Map.
 *
 * @tparam T static type of the literal, captured through its `TypeTag`
 * @param literal the value to turn into a column
 * @return the column for the given value
 *
 * @group normal_funcs
 * @since 2.2.0
 */
def typedLit[T : TypeTag](literal: T): Column = {
  literal match {
    case existing: Column => existing
    case symbol: Symbol => new ColumnName(symbol.name)
    // Everything else goes through the TypeTag-aware literal factory.
    case _ => Column(Literal.create(literal))
  }
}

//////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -712,4 +712,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)),
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
}

test("typedLit") {
  val singleRow = Seq(Tuple1(0)).toDF("a")

  // Only check the types `lit` cannot handle
  // Seq literal
  checkAnswer(
    singleRow.select(typedLit(Seq(1, 2, 3))),
    Seq(Row(Seq(1, 2, 3))))

  // Map literal
  checkAnswer(
    singleRow.select(typedLit(Map("a" -> 1, "b" -> 2))),
    Seq(Row(Map("a" -> 1, "b" -> 2))))

  // Tuple literal, which comes back as a nested Row
  checkAnswer(
    singleRow.select(typedLit(("a", 2, 1.0))),
    Seq(Row(Row("a", 2, 1.0))))
}
}

0 comments on commit 14bb398

Please sign in to comment.