Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ object ScalaReflection extends ScalaReflection {
case t if t <:< definitions.BooleanTpe => BooleanType
case t if t <:< localTypeOf[Array[Byte]] => BinaryType
case _ =>
val className: String = tpe.erasure.typeSymbol.asClass.fullName
val className = getClassNameFromType(tpe)
className match {
case "scala.Array" =>
val TypeRef(_, _, Seq(elementType)) = tpe
Expand Down Expand Up @@ -320,9 +320,23 @@ object ScalaReflection extends ScalaReflection {
}
}

/** Returns expressions for extracting all the fields from the given type. */
/**
* Returns expressions for extracting all the fields from the given type.
*
* If the given type is not supported, i.e. no encoder can be built for this type,
* an [[UnsupportedOperationException]] will be thrown with a detailed error message explaining
* the type path walked so far and which class is not supported.
* There are 4 kinds of type path:
* * the root type: `root class: "abc.xyz.MyClass"`
* * the value type of [[Option]]: `option value class: "abc.xyz.MyClass"`
* * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"`
* * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")`
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any suggestions for the error message format? cc @marmbrus @rxin

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually think this looks pretty good.

*/
def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
extractorFor(inputObject, localTypeOf[T]) match {
val tpe = localTypeOf[T]
val clsName = getClassNameFromType(tpe)
val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
extractorFor(inputObject, tpe, walkedTypePath) match {
case s: CreateNamedStruct => s
case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
}
Expand All @@ -331,7 +345,28 @@ object ScalaReflection extends ScalaReflection {
/** Helper for extracting internal fields from a case class. */
private def extractorFor(
inputObject: Expression,
tpe: `Type`): Expression = ScalaReflectionLock.synchronized {
tpe: `Type`,
walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized {

// Converts the external collection expression `input`, whose elements have type `elementType`,
// into a catalyst array expression. Native element types are wrapped directly in a
// `GenericArrayData`; every other element type is converted element-by-element.
def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
  val externalDataType = dataTypeFor(elementType)
  val Schema(elemCatalystType, elemNullable) = silentSchemaFor(elementType)

  if (!isNativeType(elemCatalystType)) {
    val elementClassName = getClassNameFromType(elementType)
    val elementPath = s"""- array element class: "$elementClassName"""" +: walkedTypePath
    // `MapObjects` only invokes its function lazily, so run `extractorFor` once up front:
    // an unsupported element type is then reported right away, with the walked type path.
    extractorFor(inputObject, elementType, elementPath)

    MapObjects(
      element => extractorFor(element, elementType, elementPath),
      input,
      externalDataType)
  } else {
    NewInstance(
      classOf[GenericArrayData],
      List(input),
      dataType = ArrayType(elemCatalystType, elemNullable))
  }
}

if (!inputObject.dataType.isInstanceOf[ObjectType]) {
inputObject
} else {
Expand Down Expand Up @@ -378,15 +413,16 @@ object ScalaReflection extends ScalaReflection {

// For non-primitives, we can just extract the object from the Option and then recurse.
case other =>
val className: String = optType.erasure.typeSymbol.asClass.fullName
val className = getClassNameFromType(optType)
val classObj = Utils.classForName(className)
val optionObjectType = ObjectType(classObj)
val newPath = s"""- option value class: "$className"""" +: walkedTypePath

val unwrapped = UnwrapOption(optionObjectType, inputObject)
expressions.If(
IsNull(unwrapped),
expressions.Literal.create(null, schemaFor(optType).dataType),
extractorFor(unwrapped, optType))
expressions.Literal.create(null, silentSchemaFor(optType).dataType),
extractorFor(unwrapped, optType, newPath))
}

case t if t <:< localTypeOf[Product] =>
Expand All @@ -412,7 +448,10 @@ object ScalaReflection extends ScalaReflection {
val fieldName = p.name.toString
val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil
val clsName = getClassNameFromType(fieldType)
val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath

expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil
})

case t if t <:< localTypeOf[Array[_]] =>
Expand Down Expand Up @@ -500,23 +539,11 @@ object ScalaReflection extends ScalaReflection {
Invoke(inputObject, "booleanValue", BooleanType)

case other =>
throw new UnsupportedOperationException(s"Extractor for type $other is not supported")
throw new UnsupportedOperationException(
s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
}
}
}

// Converts the external collection expression `input`, whose elements have type `elementType`,
// into a catalyst array expression. Native element types are wrapped directly in a
// `GenericArrayData`; every other element type is converted element-by-element.
private def toCatalystArray(input: Expression, elementType: `Type`): Expression = {
  val externalDataType = dataTypeFor(elementType)
  val Schema(elemCatalystType, elemNullable) = schemaFor(elementType)

  if (!isNativeType(elemCatalystType)) {
    MapObjects(element => extractorFor(element, elementType), input, externalDataType)
  } else {
    NewInstance(
      classOf[GenericArrayData],
      List(input),
      dataType = ArrayType(elemCatalystType, elemNullable))
  }
}
}

/**
Expand Down Expand Up @@ -561,7 +588,7 @@ trait ScalaReflection {

/** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */
def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized {
val className: String = tpe.erasure.typeSymbol.asClass.fullName
val className = getClassNameFromType(tpe)
tpe match {
case t if Utils.classIsLoadable(className) &&
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) =>
Expand Down Expand Up @@ -637,6 +664,23 @@ trait ScalaReflection {
}
}

/**
 * Variant of `schemaFor` that does not fail on an unsupported type.
 *
 * Delegates to `schemaFor(tpe)`; if that throws an [[UnsupportedOperationException]], a nullable
 * `NullType` schema is returned instead. Any other exception propagates unchanged.
 */
private def silentSchemaFor(tpe: `Type`): Schema = {
  try {
    schemaFor(tpe)
  } catch {
    case _: UnsupportedOperationException =>
      Schema(NullType, nullable = true)
  }
}

/** Returns the full name of the class symbol obtained from the erasure of `tpe`. */
private def getClassNameFromType(tpe: `Type`): String = {
  val erasedClass = tpe.erasure.typeSymbol.asClass
  erasedClass.fullName
}

/**
* Returns classes of input parameters of scala function object.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,22 @@

package org.apache.spark.sql.catalyst.encoders

import scala.reflect.ClassTag

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Encoders

/** Plain (non-case) class used by the error-message test as a type that cannot be encoded. */
class NonEncodable(i: Int)

/** Holds [[NonEncodable]] directly as a field. */
case class ComplexNonEncodable1(name1: NonEncodable)

/** Nests [[ComplexNonEncodable1]] one level deeper. */
case class ComplexNonEncodable2(name2: ComplexNonEncodable1)

/** Holds [[NonEncodable]] as the value of an `Option`. */
case class ComplexNonEncodable3(name3: Option[NonEncodable])

/** Holds [[NonEncodable]] as the element of an `Array`. */
case class ComplexNonEncodable4(name4: Array[NonEncodable])

/** Holds [[NonEncodable]] as the element of an `Array` wrapped in an `Option`. */
case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]])

class EncoderErrorMessageSuite extends SparkFunSuite {

Expand All @@ -37,4 +50,53 @@ class EncoderErrorMessageSuite extends SparkFunSuite {
intercept[UnsupportedOperationException] { Encoders.javaSerialization[Long] }
intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] }
}

test("nice error message for missing encoder") {
  // Each case tries to build an encoder for a type that (transitively) contains
  // `NonEncodable` and checks that the resulting error message lists every step of the
  // type path walked from the root class down to the unsupported class.

  // Unsupported class used directly as a field.
  val errorMsg1 =
    intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage
  assert(errorMsg1.contains(
    s"""root class: "${clsName[ComplexNonEncodable1]}""""))
  assert(errorMsg1.contains(
    s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))

  // Unsupported class nested two fields deep.
  val errorMsg2 =
    intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage
  assert(errorMsg2.contains(
    s"""root class: "${clsName[ComplexNonEncodable2]}""""))
  assert(errorMsg2.contains(
    s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")"""))
  // The innermost step must be checked against this case's message, not the previous one.
  assert(errorMsg2.contains(
    s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))

  // Unsupported class as an Option value.
  val errorMsg3 =
    intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage
  assert(errorMsg3.contains(
    s"""root class: "${clsName[ComplexNonEncodable3]}""""))
  assert(errorMsg3.contains(
    s"""field (class: "scala.Option", name: "name3")"""))
  assert(errorMsg3.contains(
    s"""option value class: "${clsName[NonEncodable]}""""))

  // Unsupported class as an Array element.
  val errorMsg4 =
    intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage
  assert(errorMsg4.contains(
    s"""root class: "${clsName[ComplexNonEncodable4]}""""))
  assert(errorMsg4.contains(
    s"""field (class: "scala.Array", name: "name4")"""))
  assert(errorMsg4.contains(
    s"""array element class: "${clsName[NonEncodable]}""""))

  // Unsupported class as the element of an Array wrapped in an Option.
  val errorMsg5 =
    intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage
  assert(errorMsg5.contains(
    s"""root class: "${clsName[ComplexNonEncodable5]}""""))
  assert(errorMsg5.contains(
    s"""field (class: "scala.Option", name: "name5")"""))
  assert(errorMsg5.contains(
    s"""option value class: "scala.Array""""))
  assert(errorMsg5.contains(
    s"""array element class: "${clsName[NonEncodable]}""""))
}

/** Returns the name of the runtime class of `T`, as reported by its [[ClassTag]]. */
private def clsName[T : ClassTag]: String = {
  val tag = implicitly[ClassTag[T]]
  tag.runtimeClass.getName
}
}