[SPARK-31450][SQL] Make ExpressionEncoder thread-safe
### What changes were proposed in this pull request?
This PR moves the `ExpressionEncoder.toRow` and `ExpressionEncoder.fromRow` functions into their own function objects (`ExpressionEncoder.Serializer` & `ExpressionEncoder.Deserializer`). This effectively makes `ExpressionEncoder` stateless, thread-safe, and (more) reusable. The function objects themselves are not thread-safe, but this is documented, and they are meant to be used in a more limited scope (which makes it easier to reason about thread safety).
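
For illustration (not part of the commit), a minimal before/after sketch of the new API; the `String` encoder and the variable names here are hypothetical:

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

val encoder = ExpressionEncoder[String]().resolveAndBind()

// Before this PR: conversions went through the stateful encoder itself.
// val row = encoder.toRow("hello")   // removed by this PR
// val str = encoder.fromRow(row)     // removed by this PR

// After this PR: the conversion state lives in dedicated function objects.
val toRow = encoder.createSerializer()      // T => InternalRow (not thread-safe)
val fromRow = encoder.createDeserializer()  // InternalRow => T (not thread-safe)
val row = toRow("hello")
val str = fromRow(row)
```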

### Why are the changes needed?
`ExpressionEncoder`s are not thread-safe, and this has caused various (nasty) bugs.

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Existing tests.

Closes #28223 from hvanhovell/SPARK-31450.

Authored-by: herman <herman@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
(cherry picked from commit fab4ca5)
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
hvanhovell authored and dongjoon-hyun committed Apr 17, 2020
1 parent 9690d9f commit e7fef70
Showing 39 changed files with 282 additions and 238 deletions.
@@ -91,8 +91,8 @@ private[image] class ImageFileFormat extends FileFormat with DataSourceRegister
if (requiredSchema.isEmpty) {
filteredResult.map(_ => emptyUnsafeRow)
} else {
- val converter = RowEncoder(requiredSchema)
- filteredResult.map(row => converter.toRow(row))
+ val toRow = RowEncoder(requiredSchema).createSerializer()
+ filteredResult.map(row => toRow(row))
}
}
}
@@ -166,7 +166,7 @@ private[libsvm] class LibSVMFileFormat
LabeledPoint(label, Vectors.sparse(numFeatures, indices, values))
}

- val converter = RowEncoder(dataSchema)
+ val toRow = RowEncoder(dataSchema).createSerializer()
val fullOutput = dataSchema.map { f =>
AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()
}
@@ -178,7 +178,7 @@ private[libsvm] class LibSVMFileFormat

points.map { pt =>
val features = if (isSparse) pt.features.toSparse else pt.features.toDense
- requiredColumns(converter.toRow(Row(pt.label, features)))
+ requiredColumns(toRow(Row(pt.label, features)))
}
}
}
@@ -38,20 +38,22 @@ object UDTSerializationBenchmark extends BenchmarkBase {
val iters = 1e2.toInt
val numRows = 1e3.toInt

- val encoder = ExpressionEncoder[Vector].resolveAndBind()
+ val encoder = ExpressionEncoder[Vector]().resolveAndBind()
+ val toRow = encoder.createSerializer()
+ val fromRow = encoder.createDeserializer()

val vectors = (1 to numRows).map { i =>
Vectors.dense(Array.fill(1e5.toInt)(1.0 * i))
}.toArray
- val rows = vectors.map(encoder.toRow)
+ val rows = vectors.map(toRow)

val benchmark = new Benchmark("VectorUDT de/serialization", numRows, iters, output = output)

benchmark.addCase("serialize") { _ =>
var sum = 0
var i = 0
while (i < numRows) {
- sum += encoder.toRow(vectors(i)).numFields
+ sum += toRow(vectors(i)).numFields
i += 1
}
}
@@ -60,7 +62,7 @@ object UDTSerializationBenchmark extends BenchmarkBase {
var sum = 0
var i = 0
while (i < numRows) {
- sum += encoder.fromRow(rows(i)).numActives
+ sum += fromRow(rows(i)).numActives
i += 1
}
}
@@ -58,8 +58,7 @@ import org.apache.spark.sql.types._
* }}}
*
* == Implementation ==
- * - Encoders are not required to be thread-safe and thus they do not need to use locks to guard
- *   against concurrent access if they reuse internal buffers to improve performance.
+ * - Encoders should be thread-safe.
*
* @since 1.6.0
*/
@@ -76,10 +75,4 @@ trait Encoder[T] extends Serializable {
* A ClassTag that can be used to construct an Array to contain a collection of `T`.
*/
def clsTag: ClassTag[T]
-
-  /**
-   * Create a copied [[Encoder]]. The implementation may just copy internal reusable fields to speed
-   * up the [[Encoder]] creation.
-   */
-  def makeCopy: Encoder[T]
}
@@ -17,12 +17,15 @@

package org.apache.spark.sql.catalyst.encoders

+ import java.io.ObjectInputStream
+
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.{typeTag, TypeTag}

import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
+ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.{Deserializer, Serializer}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, InitializeJavaBean, Invoke, NewInstance}
@@ -162,6 +165,56 @@ object ExpressionEncoder {
e4: ExpressionEncoder[T4],
e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
+
+  private val anyObjectType = ObjectType(classOf[Any])
+
+  /**
+   * Function that deserializes an [[InternalRow]] into an object of type `T`. This class is not
+   * thread-safe.
+   */
+  class Deserializer[T](private val expressions: Seq[Expression])
+    extends (InternalRow => T) with Serializable {
+    @transient
+    private[this] var constructProjection: Projection = _
+
+    override def apply(row: InternalRow): T = try {
+      if (constructProjection == null) {
+        constructProjection = SafeProjection.create(expressions)
+      }
+      constructProjection(row).get(0, anyObjectType).asInstanceOf[T]
+    } catch {
+      case e: Exception =>
+        throw new RuntimeException(s"Error while decoding: $e\n" +
+          s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
+    }
+  }
+
+  /**
+   * Function that serializes an object of type `T` to an [[InternalRow]]. This class is not
+   * thread-safe. Note that multiple calls to `apply(..)` return the same actual [[InternalRow]]
+   * object. Thus, the caller should copy the result before making another call if required.
+   */
+  class Serializer[T](private val expressions: Seq[Expression])
+    extends (T => InternalRow) with Serializable {
+    @transient
+    private[this] var inputRow: GenericInternalRow = _
+
+    @transient
+    private[this] var extractProjection: UnsafeProjection = _
+
+    override def apply(t: T): InternalRow = try {
+      if (extractProjection == null) {
+        inputRow = new GenericInternalRow(1)
+        extractProjection = GenerateUnsafeProjection.generate(expressions)
+      }
+      inputRow(0) = t
+      extractProjection(inputRow)
+    } catch {
+      case e: Exception =>
+        throw new RuntimeException(s"Error while encoding: $e\n" +
+          s"${expressions.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
+    }
+  }
}

/**
@@ -301,25 +354,22 @@ case class ExpressionEncoder[T](
}

@transient
- private lazy val extractProjection = GenerateUnsafeProjection.generate({
+ private lazy val optimizedDeserializer: Seq[Expression] = {
// When using `ExpressionEncoder` directly, we will skip the normal query processing steps
// (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's
// important to codegen performance.
- val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer))
+ val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer)))
optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs
- })
-
- @transient
- private lazy val inputRow = new GenericInternalRow(1)
+ }

@transient
- private lazy val constructProjection = SafeProjection.create({
+ private lazy val optimizedSerializer = {
// When using `ExpressionEncoder` directly, we will skip the normal query processing steps
// (analyzer, optimizer, etc.). Here we apply the ReassignLambdaVariableID rule, as it's
// important to codegen performance.
- val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(Seq(deserializer)))
+ val optimizedPlan = ReassignLambdaVariableID.apply(DummyExpressionHolder(serializer))
optimizedPlan.asInstanceOf[DummyExpressionHolder].exprs
- })
+ }

/**
* Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
@@ -331,31 +381,21 @@ case class ExpressionEncoder[T](
}

/**
- * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to
- * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should
- * copy the result before making another call if required.
+ * Create a serializer that can convert an object of type `T` to a Spark SQL Row.
+ *
+ * Note that the returned [[Serializer]] is not thread safe. Multiple calls to
+ * `serializer.apply(..)` are allowed to return the same actual [[InternalRow]] object. Thus,
+ * the caller should copy the result before making another call if required.
*/
- def toRow(t: T): InternalRow = try {
-   inputRow(0) = t
-   extractProjection(inputRow)
- } catch {
-   case e: Exception =>
-     throw new RuntimeException(s"Error while encoding: $e\n" +
-       s"${serializer.map(_.simpleString(SQLConf.get.maxToStringFields)).mkString("\n")}", e)
- }
+ def createSerializer(): Serializer[T] = new Serializer[T](optimizedSerializer)

/**
- * Returns an object of type `T`, extracting the required values from the provided row. Note that
- * you must `resolveAndBind` an encoder to a specific schema before you can call this
- * function.
+ * Create a deserializer that can convert a Spark SQL Row into an object of type `T`.
+ *
+ * Note that you must `resolveAndBind` an encoder to a specific schema before you can create a
+ * deserializer.
*/
- def fromRow(row: InternalRow): T = try {
-   constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T]
- } catch {
-   case e: Exception =>
-     throw new RuntimeException(s"Error while decoding: $e\n" +
-       s"${deserializer.simpleString(SQLConf.get.maxToStringFields)}", e)
- }
+ def createDeserializer(): Deserializer[T] = new Deserializer[T](optimizedDeserializer)

/**
* The process of resolution to a given schema throws away information about where a given field
@@ -382,8 +422,6 @@ case class ExpressionEncoder[T](
.map { case(f, a) => s"${f.name}$a: ${f.dataType.simpleString}"}.mkString(", ")

override def toString: String = s"class[$schemaString]"
-
- override def makeCopy: ExpressionEncoder[T] = copy()
}

// A dummy logical plan that can hold expressions and go through optimizer rules.
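
The net effect of the ExpressionEncoder.scala changes above: the encoder itself is now stateless and shareable, while the mutable projection state lives in per-caller function objects. A hedged usage sketch (not from this commit; the tuple type, thread setup, and assertion are illustrative):

```scala
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

// One ExpressionEncoder can now be shared by many threads,
// as long as each thread creates its own Serializer/Deserializer.
val shared = ExpressionEncoder[(Int, String)]().resolveAndBind()

val threads = (1 to 4).map { id =>
  new Thread(() => {
    val toRow = shared.createSerializer()      // not thread-safe: one per thread
    val fromRow = shared.createDeserializer()  // not thread-safe: one per thread
    // The serializer may return the same InternalRow instance on every call,
    // so copy() before holding on to a result across calls.
    val row = toRow((id, s"thread-$id")).copy()
    assert(fromRow(row) == ((id, s"thread-$id")))
  })
}
threads.foreach(_.start())
threads.foreach(_.join())
```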
@@ -110,8 +110,8 @@ case class ScalaUDF(
} else {
val encoder = inputEncoders(i)
if (encoder.isDefined && encoder.get.isSerializedAsStructForTopLevel) {
- val enc = encoder.get.resolveAndBind()
- row: Any => enc.fromRow(row.asInstanceOf[InternalRow])
+ val fromRow = encoder.get.resolveAndBind().createDeserializer()
+ row: Any => fromRow(row.asInstanceOf[InternalRow])
} else {
CatalystTypeConverters.createToScalaConverter(dataType)
}
@@ -41,13 +41,13 @@ object HashBenchmark extends BenchmarkBase {
def test(name: String, schema: StructType, numRows: Int, iters: Int): Unit = {
runBenchmark(name) {
val generator = RandomDataGenerator.forType(schema, nullable = false).get
- val encoder = RowEncoder(schema)
+ val toRow = RowEncoder(schema).createSerializer()
val attrs = schema.toAttributes
val safeProjection = GenerateSafeProjection.generate(attrs, attrs)

val rows = (1 to numRows).map(_ =>
// The output of encoder is UnsafeRow, use safeProjection to turn in into safe format.
- safeProjection(encoder.toRow(generator().asInstanceOf[Row])).copy()
+ safeProjection(toRow(generator().asInstanceOf[Row])).copy()
).toArray

val benchmark = new Benchmark("Hash For " + name, iters * numRows.toLong, output = output)
@@ -37,8 +37,8 @@ object UnsafeProjectionBenchmark extends BenchmarkBase {

def generateRows(schema: StructType, numRows: Int): Array[InternalRow] = {
val generator = RandomDataGenerator.forType(schema, nullable = false).get
- val encoder = RowEncoder(schema)
- (1 to numRows).map(_ => encoder.toRow(generator().asInstanceOf[Row]).copy()).toArray
+ val toRow = RowEncoder(schema).createSerializer()
+ (1 to numRows).map(_ => toRow(generator().asInstanceOf[Row]).copy()).toArray
}

override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
@@ -22,6 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
+ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
@@ -42,37 +43,44 @@ case class NestedArrayClass(nestedArr: Array[ArrayClass])
class EncoderResolutionSuite extends PlanTest {
private val str = UTF8String.fromString("hello")

+  def testFromRow[T](
+      encoder: ExpressionEncoder[T],
+      attributes: Seq[Attribute],
+      row: InternalRow): Unit = {
+    encoder.resolveAndBind(attributes).createDeserializer().apply(row)
+  }
+
test("real type doesn't match encoder schema but they are compatible: product") {
val encoder = ExpressionEncoder[StringLongClass]

// int type can be up cast to long type
val attrs1 = Seq('a.string, 'b.int)
- encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1))
+ testFromRow(encoder, attrs1, InternalRow(str, 1))

// int type can be up cast to string type
val attrs2 = Seq('a.int, 'b.long)
- encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L))
+ testFromRow(encoder, attrs2, InternalRow(1, 2L))
}

test("real type doesn't match encoder schema but they are compatible: nested product") {
val encoder = ExpressionEncoder[ComplexClass]
val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
- encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
+ testFromRow(encoder, attrs, InternalRow(1, InternalRow(2, 3L)))
}

test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
val encoder = ExpressionEncoder.tuple(
ExpressionEncoder[StringLongClass],
ExpressionEncoder[Long])
val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
- encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+ testFromRow(encoder, attrs, InternalRow(InternalRow(str, 1.toByte), 2))
}

test("real type doesn't match encoder schema but they are compatible: primitive array") {
val encoder = ExpressionEncoder[PrimitiveArrayClass]
val attrs = Seq('arr.array(IntegerType))
val array = new GenericArrayData(Array(1, 2, 3))
- encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+ testFromRow(encoder, attrs, InternalRow(array))
}

test("the real type is not compatible with encoder schema: primitive array") {
@@ -93,7 +101,7 @@ class EncoderResolutionSuite extends PlanTest {
val encoder = ExpressionEncoder[ArrayClass]
val attrs = Seq('arr.array(new StructType().add("a", "int").add("b", "int").add("c", "int")))
val array = new GenericArrayData(Array(InternalRow(1, 2, 3)))
- encoder.resolveAndBind(attrs).fromRow(InternalRow(array))
+ testFromRow(encoder, attrs, InternalRow(array))
}

test("real type doesn't match encoder schema but they are compatible: nested array") {
Expand All @@ -103,7 +111,7 @@ class EncoderResolutionSuite extends PlanTest {
val attrs = Seq('nestedArr.array(et))
val innerArr = new GenericArrayData(Array(InternalRow(1, 2, 3)))
val outerArr = new GenericArrayData(Array(InternalRow(innerArr)))
- encoder.resolveAndBind(attrs).fromRow(InternalRow(outerArr))
+ testFromRow(encoder, attrs, InternalRow(outerArr))
}

test("the real type is not compatible with encoder schema: non-array field") {
@@ -142,14 +150,14 @@ class EncoderResolutionSuite extends PlanTest {
val attrs = 'a.array(IntegerType) :: Nil

// It should pass analysis
- val bound = encoder.resolveAndBind(attrs)
+ val fromRow = encoder.resolveAndBind(attrs).createDeserializer()

// If no null values appear, it should work fine
- bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
+ fromRow(InternalRow(new GenericArrayData(Array(1, 2))))

// If there is null value, it should throw runtime exception
val e = intercept[RuntimeException] {
- bound.fromRow(InternalRow(new GenericArrayData(Array(1, null))))
+ fromRow(InternalRow(new GenericArrayData(Array(1, null))))
}
assert(e.getMessage.contains("Null value appeared in non-nullable field"))
}
