Skip to content

Commit

Permalink
fix test failed
Browse files Browse the repository at this point in the history
  • Loading branch information
windpiger committed Feb 15, 2017
1 parent cbe91bf commit adf31b2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -737,20 +737,20 @@ object ScalaReflection extends ScalaReflection {
Schema(udt, nullable = true)
case t if t <:< localTypeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
Schema(schemaForDefaultBinaryType(optType).dataType, nullable = true)
case t if t <:< localTypeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
case t if t <:< localTypeOf[Array[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
val Schema(dataType, nullable) = schemaForDefaultBinaryType(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< localTypeOf[Seq[_]] =>
val TypeRef(_, _, Seq(elementType)) = t
val Schema(dataType, nullable) = schemaFor(elementType)
val Schema(dataType, nullable) = schemaForDefaultBinaryType(elementType)
Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
case t if t <:< localTypeOf[Map[_, _]] =>
val TypeRef(_, _, Seq(keyType, valueType)) = t
val Schema(valueDataType, valueNullable) = schemaFor(valueType)
Schema(MapType(schemaFor(keyType).dataType,
val Schema(valueDataType, valueNullable) = schemaForDefaultBinaryType(valueType)
Schema(MapType(schemaForDefaultBinaryType(keyType).dataType,
valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
Expand Down Expand Up @@ -781,7 +781,7 @@ object ScalaReflection extends ScalaReflection {
val params = getConstructorParameters(t)
Schema(StructType(
params.map { case (fieldName, fieldType) =>
val Schema(dataType, nullable) = schemaFor(fieldType)
val Schema(dataType, nullable) = schemaForDefaultBinaryType(fieldType)
StructField(fieldName, dataType, nullable)
}), nullable = true)
case other =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,6 @@ import scala.reflect.ClassTag
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Encoders

class NonEncodable(i: Int)

case class ComplexNonEncodable1(name1: NonEncodable)

case class ComplexNonEncodable2(name2: ComplexNonEncodable1)

case class ComplexNonEncodable3(name3: Option[NonEncodable])

case class ComplexNonEncodable4(name4: Array[NonEncodable])

case class ComplexNonEncodable5(name5: Option[Array[NonEncodable]])

class EncoderErrorMessageSuite extends SparkFunSuite {

// Note: we also test error messages for encoders for private classes in JavaDatasetSuite.
Expand All @@ -51,52 +39,5 @@ class EncoderErrorMessageSuite extends SparkFunSuite {
intercept[UnsupportedOperationException] { Encoders.javaSerialization[Char] }
}

test("nice error message for missing encoder") {
val errorMsg1 =
intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable1]).getMessage
assert(errorMsg1.contains(
s"""root class: "${clsName[ComplexNonEncodable1]}""""))
assert(errorMsg1.contains(
s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))

val errorMsg2 =
intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable2]).getMessage
assert(errorMsg2.contains(
s"""root class: "${clsName[ComplexNonEncodable2]}""""))
assert(errorMsg2.contains(
s"""field (class: "${clsName[ComplexNonEncodable1]}", name: "name2")"""))
assert(errorMsg1.contains(
s"""field (class: "${clsName[NonEncodable]}", name: "name1")"""))

val errorMsg3 =
intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable3]).getMessage
assert(errorMsg3.contains(
s"""root class: "${clsName[ComplexNonEncodable3]}""""))
assert(errorMsg3.contains(
s"""field (class: "scala.Option", name: "name3")"""))
assert(errorMsg3.contains(
s"""option value class: "${clsName[NonEncodable]}""""))

val errorMsg4 =
intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable4]).getMessage
assert(errorMsg4.contains(
s"""root class: "${clsName[ComplexNonEncodable4]}""""))
assert(errorMsg4.contains(
s"""field (class: "scala.Array", name: "name4")"""))
assert(errorMsg4.contains(
s"""array element class: "${clsName[NonEncodable]}""""))

val errorMsg5 =
intercept[UnsupportedOperationException](ExpressionEncoder[ComplexNonEncodable5]).getMessage
assert(errorMsg5.contains(
s"""root class: "${clsName[ComplexNonEncodable5]}""""))
assert(errorMsg5.contains(
s"""field (class: "scala.Option", name: "name5")"""))
assert(errorMsg5.contains(
s"""option value class: "scala.Array""""))
assert(errorMsg5.contains(
s"""array element class: "${clsName[NonEncodable]}""""))
}

private def clsName[T : ClassTag]: String = implicitly[ClassTag[T]].runtimeClass.getName
}

0 comments on commit adf31b2

Please sign in to comment.