[SPARK-22355][SQL] Dataset.collect is not threadsafe
## What changes were proposed in this pull request?

Users may create a `Dataset` and call `collect` on it from many threads at the same time. Currently `Dataset#collect` just calls `encoder.fromRow` to convert Spark rows to objects of type T, and this encoder is per-Dataset. This means `Dataset#collect` is not thread-safe, because the encoder uses a projection that writes the object into a re-usable row.

This PR fixes the problem by creating a new projection when `Dataset#collect` is called, so that each method call gets its own re-usable row instead of sharing one per Dataset.
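
For illustration, a minimal sketch of the scenario described above: several threads collecting the same `Dataset` concurrently. This is not part of the patch; it assumes an existing `SparkSession` named `spark` with its implicits in scope.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

import spark.implicits._ // assumes `spark: SparkSession` is already defined

// A Dataset whose collect() must deserialize internal rows into Strings.
val ds = spark.range(0, 10000).map(i => s"value-$i")

// Before this fix, every collect() shared the Dataset's single bound encoder,
// whose projection wrote into one re-usable row, so concurrent calls could
// interfere with each other. After the fix, each call builds its own projection.
val futures = (1 to 8).map(_ => Future(ds.collect()))
futures.foreach(f => Await.result(f, 2.minutes))
```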

## How was this patch tested?

N/A

Author: Wenchen Fan <wenchen@databricks.com>

Closes #19577 from cloud-fan/encoder.
cloud-fan authored and gatorsmile committed Oct 27, 2017
1 parent 9b262f6 commit 5c3a1f3
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
@@ -198,15 +199,10 @@ class Dataset[T] private[sql](
*/
private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)

/**
* Encoder is used mostly as a container of serde expressions in Dataset. We build logical
* plans by these serde expressions and execute it within the query framework. However, for
* performance reasons we may want to use encoder as a function to deserialize internal rows to
* custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its
* `fromRow` method later.
*/
private val boundEnc =
exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer)
// The deserializer expression which can be used to build a projection and turn rows to objects
// of type T, after collecting rows to the driver side.
private val deserializer =
exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer

private implicit def classTag = exprEnc.clsTag

@@ -2661,7 +2657,15 @@
*/
def toLocalIterator(): java.util.Iterator[T] = {
withAction("toLocalIterator", queryExecution) { plan =>
plan.executeToIterator().map(boundEnc.fromRow).asJava
// This projection writes output to an `InternalRow`, which means applying this projection is
// not thread-safe. Here we create the projection inside this method to make `Dataset`
// thread-safe.
val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
plan.executeToIterator().map { row =>
// The row returned by SafeProjection is `SpecificInternalRow`, which ignores the data type
// parameter of its `get` method, so it's safe to use null here.
objProj(row).get(0, null).asInstanceOf[T]
}.asJava
}
}

@@ -3102,7 +3106,14 @@
* Collect all elements from a spark plan.
*/
private def collectFromPlan(plan: SparkPlan): Array[T] = {
plan.executeCollect().map(boundEnc.fromRow)
// This projection writes output to an `InternalRow`, which means applying this projection is not
// thread-safe. Here we create the projection inside this method to make `Dataset` thread-safe.
val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
plan.executeCollect().map { row =>
// The row returned by SafeProjection is `SpecificInternalRow`, which ignores the data type
// parameter of its `get` method, so it's safe to use null here.
objProj(row).get(0, null).asInstanceOf[T]
}
}

private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
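Abstracted from the two call sites above, the shape of the fix is: keep only the stateless `deserializer` expression per `Dataset`, and build a fresh projection inside every method that materializes rows on the driver. Below is a hedged sketch of how such a helper could look; the helper (`RowDeserialization.rowToObject`) is hypothetical and not part of this commit, but it reuses only the catalyst APIs already shown in the diff.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection

// Hypothetical helper object, not part of this commit.
object RowDeserialization {
  // Given a resolved and bound deserializer expression, return a function that
  // converts InternalRows to objects of type T. A new projection (and hence a
  // new re-usable output row) is created on every call to this method, which is
  // what keeps concurrent collect()/toLocalIterator() calls independent.
  def rowToObject[T](deserializer: Expression): InternalRow => T = {
    val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
    // The row returned by the safe projection ignores the data type argument of
    // its `get` method, so passing null here is fine (as in the diff above).
    row => objProj(row).get(0, null).asInstanceOf[T]
  }
}
```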
