[SPARK-11011][SQL] Narrow type of UDT serialization #11379
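In short: `UserDefinedType.serialize` now takes the UDT's own `UserType` instead of `Any`, and `UserType` gains a `>: Null` lower bound so converters can still pass `null` through. A condensed, hypothetical contrast of the two contracts (stand-in class names, not Spark's):

```scala
// Hypothetical minimal stand-ins (not Spark classes), shown only to contrast
// the old and new serialize contracts this PR switches between.
abstract class UdtBefore[UserType] {
  def serialize(obj: Any): Any        // pre-SPARK-11011: callers pass Any, every UDT pattern-matches
}

abstract class UdtAfter[UserType >: Null] {
  def serialize(obj: UserType): Any   // post-SPARK-11011: the element type is enforced at compile time
}
```

Because the signature change is binary-incompatible for existing UDT implementations, the diff also adds a MiMa exclude in project/MimaExcludes.scala (see below).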

Closed

@@ -162,7 +162,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
))
}

- override def serialize(obj: Any): InternalRow = {
+ override def serialize(obj: Matrix): InternalRow = {
val row = new GenericMutableRow(7)
obj match {
case sm: SparseMatrix =>
@@ -203,7 +203,7 @@ class VectorUDT extends UserDefinedType[Vector] {
StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
}

- override def serialize(obj: Any): InternalRow = {
+ override def serialize(obj: Vector): InternalRow = {
obj match {
case SparseVector(size, indices, values) =>
val row = new GenericMutableRow(4)
2 changes: 2 additions & 0 deletions project/MimaExcludes.scala
@@ -289,6 +289,8 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLConf$SQLConfEntry$")
) ++ Seq(
+ //SPARK-11011 UserDefinedType serialization should be strongly typed
+ ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
// SPARK-12073: backpressure rate controller consumes events preferentially from lagging partitions
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.KafkaTestUtils.createTopic"),
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.streaming.kafka.DirectKafkaInputDStream.maxMessagesPerPartition")
@@ -136,16 +136,16 @@ object CatalystTypeConverters {
override def toScalaImpl(row: InternalRow, column: Int): Any = row.get(column, dataType)
}

- private case class UDTConverter(
-     udt: UserDefinedType[_]) extends CatalystTypeConverter[Any, Any, Any] {
+ private case class UDTConverter[A >: Null](
+     udt: UserDefinedType[A]) extends CatalystTypeConverter[A, A, Any] {
// toCatalyst (it calls toCatalystImpl) will do null check.
- override def toCatalystImpl(scalaValue: Any): Any = udt.serialize(scalaValue)
+ override def toCatalystImpl(scalaValue: A): Any = udt.serialize(scalaValue)

- override def toScala(catalystValue: Any): Any = {
+ override def toScala(catalystValue: Any): A = {
if (catalystValue == null) null else udt.deserialize(catalystValue)
}

- override def toScalaImpl(row: InternalRow, column: Int): Any =
+ override def toScalaImpl(row: InternalRow, column: Int): A =
toScala(row.get(column, udt.sqlType))
}
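The new `A >: Null` lower bound exists so that `toScala` above can legally return `null` for a null Catalyst value while its result is typed as `A`. A standalone sketch of the same constraint outside Spark; `Converter` and `Boxed` are hypothetical names, not Spark classes:

```scala
// Illustration of why the converter's type parameter needs A >: Null.
class Converter[A >: Null](deserialize: Any => A) {
  // Returning null as an A only type-checks because of the A >: Null bound.
  def toScala(catalystValue: Any): A =
    if (catalystValue == null) null else deserialize(catalystValue)
}

object ConverterDemo extends App {
  case class Boxed(value: Int)
  val c = new Converter[Boxed](any => Boxed(any.asInstanceOf[Int]))
  println(c.toScala(41))   // Boxed(41)
  println(c.toScala(null)) // null, permitted because Boxed is a nullable reference type
}
```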

@@ -37,7 +37,7 @@ import org.apache.spark.annotation.DeveloperApi
* The conversion via `deserialize` occurs when reading from a `DataFrame`.
*/
@DeveloperApi
- abstract class UserDefinedType[UserType] extends DataType with Serializable {
+ abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable {

/** Underlying storage type for this UDT */
def sqlType: DataType
@@ -50,11 +50,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {

/**
* Convert the user type to a SQL datum
- *
- * TODO: Can we make this take obj: UserType? The issue is in
- * CatalystTypeConverters.convertToCatalyst, where we need to convert Any to UserType.
*/
Member Author: It might also be a good idea to update the documentation. In this case, what exactly is a SQL datum?
- def serialize(obj: Any): Any
+ def serialize(obj: UserType): Any
Member Author: FYI, this is the main change. All other modifications were necessary to accommodate it.

/** Convert a SQL datum to the user type */
def deserialize(datum: Any): UserType
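To make the review comments above concrete: the value returned by `serialize` still has to be whatever Catalyst representation matches `sqlType` (an `InternalRow`, `ArrayData`, `MapData`, etc., as in the UDTs further down), but the input no longer needs a defensive `obj match`. A minimal sketch of a UDT written against the narrowed API; `Point2D` and `Point2DUDT` are made-up names, only the signatures come from this PR, and the shape mirrors the test UDTs in this diff:

```scala
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._

// Hypothetical user class and UDT, sketched against the post-SPARK-11011 API.
case class Point2D(x: Double, y: Double)

class Point2DUDT extends UserDefinedType[Point2D] {
  override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

  // No `obj match { case p: Point2D => ... }` needed any more:
  // the parameter is already a Point2D.
  override def serialize(p: Point2D): ArrayData =
    new GenericArrayData(Array[Any](p.x, p.y))

  // deserialize still takes Any, since the Catalyst-side value is untyped;
  // it is unchanged by this PR.
  override def deserialize(datum: Any): Point2D = datum match {
    case a: ArrayData => Point2D(a.getDouble(0), a.getDouble(1))
  }

  override def userClass: Class[Point2D] = classOf[Point2D]
}
```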
@@ -36,11 +36,7 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] {

override def sqlType: DataType = IntegerType

- override def serialize(obj: Any): Int = {
-   obj match {
-     case groupableData: GroupableData => groupableData.data
-   }
- }
+ override def serialize(groupableData: GroupableData): Int = groupableData.data

override def deserialize(datum: Any): GroupableData = {
datum match {
@@ -60,13 +56,10 @@ private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] {

override def sqlType: DataType = MapType(IntegerType, IntegerType)

- override def serialize(obj: Any): MapData = {
-   obj match {
-     case groupableData: UngroupableData =>
-       val keyArray = new GenericArrayData(groupableData.data.keys.toSeq)
-       val valueArray = new GenericArrayData(groupableData.data.values.toSeq)
-       new ArrayBasedMapData(keyArray, valueArray)
-   }
+ override def serialize(ungroupableData: UngroupableData): MapData = {
+   val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq)
+   val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq)
+   new ArrayBasedMapData(keyArray, valueArray)
}

override def deserialize(datum: Any): UngroupableData = {
@@ -47,14 +47,11 @@ class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"

- override def serialize(obj: Any): GenericArrayData = {
-   obj match {
-     case p: ExamplePoint =>
-       val output = new Array[Any](2)
-       output(0) = p.x
-       output(1) = p.y
-       new GenericArrayData(output)
-   }
+ override def serialize(p: ExamplePoint): GenericArrayData = {
+   val output = new Array[Any](2)
+   output(0) = p.x
+   output(1) = p.y
+   new GenericArrayData(output)
}

override def deserialize(datum: Any): ExamplePoint = {
@@ -42,14 +42,11 @@ private[sql] class ExamplePointUDT extends UserDefinedType[ExamplePoint] {

override def pyUDT: String = "pyspark.sql.tests.ExamplePointUDT"

- override def serialize(obj: Any): GenericArrayData = {
-   obj match {
-     case p: ExamplePoint =>
-       val output = new Array[Any](2)
-       output(0) = p.x
-       output(1) = p.y
-       new GenericArrayData(output)
-   }
+ override def serialize(p: ExamplePoint): GenericArrayData = {
+   val output = new Array[Any](2)
+   output(0) = p.x
+   output(1) = p.y
+   new GenericArrayData(output)
}

override def deserialize(datum: Any): ExamplePoint = {
@@ -45,11 +45,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {

override def sqlType: DataType = ArrayType(DoubleType, containsNull = false)

- override def serialize(obj: Any): ArrayData = {
-   obj match {
-     case features: MyDenseVector =>
-       new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
-   }
+ override def serialize(features: MyDenseVector): ArrayData = {
+   new GenericArrayData(features.data.map(_.asInstanceOf[Any]))
}

override def deserialize(datum: Any): MyDenseVector = {
@@ -590,14 +590,11 @@ object TestingUDT {
.add("b", LongType, nullable = false)
.add("c", DoubleType, nullable = false)

- override def serialize(obj: Any): Any = {
+ override def serialize(n: NestedStruct): Any = {
val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType))
-   obj match {
-     case n: NestedStruct =>
-       row.setInt(0, n.a)
-       row.setLong(1, n.b)
-       row.setDouble(2, n.c)
-   }
+   row.setInt(0, n.a)
+   row.setLong(1, n.b)
+   row.setDouble(2, n.c)
}

override def userClass: Class[NestedStruct] = classOf[NestedStruct]