[SPARK-21567][SQL] Dataset should work with type alias
If we create a type alias for a type that works with Dataset, the alias itself doesn't work with Dataset.

A reproducible case looks like:

    object C {
      type TwoInt = (Int, Int)
      def tupleTypeAlias: TwoInt = (1, 1)
    }

    Seq(1).toDS().map(_ => ("", C.tupleTypeAlias))

It throws an exception like:

    type T1 is not a class
    scala.ScalaReflectionException: type T1 is not a class
      at scala.reflect.api.Symbols$SymbolApi$class.asClass(Symbols.scala:275)
      ...

This patch dealiases types in many places in `ScalaReflection` to fix it.
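
For background, the failure can be reproduced with plain scala-reflect, outside Spark: the alias's own type symbol is not a class, so `asClass` throws, while the dealiased type resolves to `Tuple2`. A minimal sketch (the `DealiasDemo` scaffolding is illustrative, not part of the patch):

    import scala.reflect.runtime.universe._

    object DealiasDemo extends App {
      object C { type TwoInt = (Int, Int) }

      val aliased = typeOf[C.TwoInt]
      println(aliased.typeSymbol.isClass)    // false: the symbol is the alias itself
      // aliased.typeSymbol.asClass          // would throw ScalaReflectionException

      val dealiased = aliased.dealias        // (Int, Int), i.e. Tuple2[Int, Int]
      println(dealiased.typeSymbol.isClass)  // true: class Tuple2
      println(dealiased.typeArgs)            // List(Int, Int)
    }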

Added test case.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes apache#18813 from viirya/SPARK-21567.

(cherry picked from commit ee13041)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
viirya authored and MatthewRBruce committed Jul 31, 2018
1 parent f075068 commit e74b939
Showing 2 changed files with 38 additions and 13 deletions.
27 changes: 14 additions & 13 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -63,7 +63,7 @@ object ScalaReflection extends ScalaReflection {
def dataTypeFor[T : TypeTag]: DataType = dataTypeFor(localTypeOf[T])

private def dataTypeFor(tpe: `Type`): DataType = ScalaReflectionLock.synchronized {
-    tpe match {
+    tpe.dealias match {
case t if t <:< definitions.IntTpe => IntegerType
case t if t <:< definitions.LongTpe => LongType
case t if t <:< definitions.DoubleTpe => DoubleType
@@ -93,7 +93,7 @@ object ScalaReflection extends ScalaReflection {
* JVM form instead of the Scala Array that handles auto boxing.
*/
private def arrayClassFor(tpe: `Type`): ObjectType = ScalaReflectionLock.synchronized {
-    val cls = tpe match {
+    val cls = tpe.dealias match {
case t if t <:< definitions.IntTpe => classOf[Array[Int]]
case t if t <:< definitions.LongTpe => classOf[Array[Long]]
case t if t <:< definitions.DoubleTpe => classOf[Array[Double]]
@@ -192,7 +192,7 @@ object ScalaReflection extends ScalaReflection {
case _ => UpCast(expr, expected, walkedTypePath)
}

-    tpe match {
+    tpe.dealias match {
case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath

case t if t <:< localTypeOf[Option[_]] =>
@@ -479,7 +479,7 @@ object ScalaReflection extends ScalaReflection {
}
}

-    tpe match {
+    tpe.dealias match {
case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject

case t if t <:< localTypeOf[Option[_]] =>
@@ -633,7 +633,7 @@ object ScalaReflection extends ScalaReflection {
* we also treat [[DefinedByConstructorParams]] as product type.
*/
def optionOfProductType(tpe: `Type`): Boolean = ScalaReflectionLock.synchronized {
-    tpe match {
+    tpe.dealias match {
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
definedByConstructorParams(optType)
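
The destructuring above shows why the bare `tpe` breaks for aliases: a `TypeRef` that points at an alias carries no type arguments, so `TypeRef(_, _, Seq(optType))` fails to match even though the `<:<` guard succeeds (subtype checks follow aliases). A plain scala-reflect sketch, assuming a hypothetical `MaybeInt` alias:

    import scala.reflect.runtime.universe._

    object TypeArgsDemo extends App {
      object C { type MaybeInt = Option[Int] }

      val TypeRef(_, _, aliasArgs) = typeOf[C.MaybeInt]
      println(aliasArgs)  // List(): the alias TypeRef carries no type arguments

      val TypeRef(_, _, realArgs) = typeOf[C.MaybeInt].dealias
      println(realArgs)   // List(Int): dealiasing exposes Option's argument
    }
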
@@ -680,7 +680,7 @@ object ScalaReflection extends ScalaReflection {
/*
* Retrieves the runtime class corresponding to the provided type.
*/
-  def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.typeSymbol.asClass)
+  def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.dealias.typeSymbol.asClass)

case class Schema(dataType: DataType, nullable: Boolean)
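
The fixed `getClassFromType` can be mimicked with a bare runtime mirror; without the `dealias`, `asClass` throws a `ScalaReflectionException` of the same shape as the one reported. A sketch under that assumption (`RuntimeClassDemo` is a made-up name):

    import scala.reflect.runtime.{universe => ru}

    object RuntimeClassDemo extends App {
      object C { type TwoInt = (Int, Int) }

      val mirror = ru.runtimeMirror(getClass.getClassLoader)
      // Dealias first, then resolve the runtime class, as the patched line does
      val cls = mirror.runtimeClass(ru.typeOf[C.TwoInt].dealias.typeSymbol.asClass)
      println(cls)  // class scala.Tuple2
    }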

@@ -695,7 +695,7 @@ object ScalaReflection extends ScalaReflection {

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
-    tpe match {
+    tpe.dealias match {
case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
Schema(udt, nullable = true)
@@ -761,7 +761,7 @@ object ScalaReflection extends ScalaReflection {
* Whether the fields of the given type is defined entirely by its constructor parameters.
*/
def definedByConstructorParams(tpe: Type): Boolean = {
-    tpe <:< localTypeOf[Product] || tpe <:< localTypeOf[DefinedByConstructorParams]
+    tpe.dealias <:< localTypeOf[Product] || tpe.dealias <:< localTypeOf[DefinedByConstructorParams]
}

private val javaKeywords = Set("abstract", "assert", "boolean", "break", "byte", "case", "catch",
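
Note that `<:<` already follows aliases, so this particular check likely passed even before the patch; dealiasing here mainly keeps the call site consistent with the others. A quick check (demo scaffolding is mine):

    import scala.reflect.runtime.universe._

    object ProductCheckDemo extends App {
      type TwoInt = (Int, Int)

      println(typeOf[TwoInt] <:< typeOf[Product])          // true: <:< dealiases
      println(typeOf[TwoInt].dealias <:< typeOf[Product])  // true
    }
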
@@ -816,7 +816,7 @@ trait ScalaReflection {
* synthetic classes, emulating behaviour in Java bytecode.
*/
def getClassNameFromType(tpe: `Type`): String = {
-    tpe.erasure.typeSymbol.asClass.fullName
+    tpe.dealias.erasure.typeSymbol.asClass.fullName
}
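
For the fragment above, dealiasing before erasure makes the alias resolve to its underlying class name. A one-line check (demo scaffolding is mine):

    import scala.reflect.runtime.universe._

    object ClassNameDemo extends App {
      type TwoInt = (Int, Int)

      println(typeOf[TwoInt].dealias.erasure.typeSymbol.asClass.fullName)  // scala.Tuple2
    }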

/**
@@ -835,9 +835,10 @@
* support inner class.
*/
def getConstructorParameters(tpe: Type): Seq[(String, Type)] = {
-    val formalTypeArgs = tpe.typeSymbol.asClass.typeParams
-    val TypeRef(_, _, actualTypeArgs) = tpe
-    val params = constructParams(tpe)
+    val dealiasedTpe = tpe.dealias
+    val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams
+    val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
+    val params = constructParams(dealiasedTpe)
// if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
if (actualTypeArgs.nonEmpty) {
params.map { p =>
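
The substitution inputs can be inspected directly with scala-reflect: after dealiasing, the formal type parameters come from `Tuple2` and the actuals from its applied arguments, which is exactly what the alias-typed `TypeRef` failed to provide. A sketch (demo names are mine):

    import scala.reflect.runtime.universe._

    object CtorParamsDemo extends App {
      type TwoInt = (Int, Int)

      val dealiased = typeOf[TwoInt].dealias
      val formals = dealiased.typeSymbol.asClass.typeParams  // Tuple2's T1 and T2
      val TypeRef(_, _, actuals) = dealiased                 // List(Int, Int)
      println(formals.map(_.name).zip(actuals))              // List((T1,Int), (T2,Int))
    }
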
@@ -851,7 +852,7 @@
}

protected def constructParams(tpe: Type): Seq[Symbol] = {
-    val constructorSymbol = tpe.member(nme.CONSTRUCTOR)
+    val constructorSymbol = tpe.dealias.member(nme.CONSTRUCTOR)
val params = if (constructorSymbol.isMethod) {
constructorSymbol.asMethod.paramss
} else {
24 changes: 24 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,6 +34,16 @@ import org.apache.spark.sql.types._
case class TestDataPoint(x: Int, y: Double, s: String, t: TestDataPoint2)
case class TestDataPoint2(x: Int, s: String)

+object TestForTypeAlias {
+  type TwoInt = (Int, Int)
+  type ThreeInt = (TwoInt, Int)
+  type SeqOfTwoInt = Seq[TwoInt]
+
+  def tupleTypeAlias: TwoInt = (1, 1)
+  def nestedTupleTypeAlias: ThreeInt = ((1, 1), 2)
+  def seqOfTupleTypeAlias: SeqOfTwoInt = Seq((1, 1), (2, 2))
+}

class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._

@@ -1210,6 +1220,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(df.orderBy($"id"), expected)
checkAnswer(df.orderBy('id), expected)
}

test("SPARK-21567: Dataset should work with type alias") {
checkDataset(
Seq(1).toDS().map(_ => ("", TestForTypeAlias.tupleTypeAlias)),
("", (1, 1)))

checkDataset(
Seq(1).toDS().map(_ => ("", TestForTypeAlias.nestedTupleTypeAlias)),
("", ((1, 1), 2)))

checkDataset(
Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)),
("", Seq((1, 1), (2, 2))))
}
}

case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String])