Skip to content

Commit

Permalink
Dataset.collect is not threadsafe
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Oct 25, 2017
1 parent 1051ebe commit cecea8c
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
import org.apache.spark.sql.catalyst.optimizer.CombineUnions
import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
Expand Down Expand Up @@ -198,15 +199,10 @@ class Dataset[T] private[sql](
*/
private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)

/**
* Encoder is used mostly as a container of serde expressions in Dataset. We build logical
* plans by these serde expressions and execute it within the query framework. However, for
* performance reasons we may want to use encoder as a function to deserialize internal rows to
* custom objects, e.g. collect. Here we resolve and bind the encoder so that we can call its
* `fromRow` method later.
*/
private val boundEnc =
exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer)
// The deserializer expression which can be used to build a projection and turn rows to objects
// of type T, after collecting rows to the driver side.
private val deserializer =
exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer).deserializer

private implicit def classTag = exprEnc.clsTag

Expand Down Expand Up @@ -2661,7 +2657,12 @@ class Dataset[T] private[sql](
*/
def toLocalIterator(): java.util.Iterator[T] = {
withAction("toLocalIterator", queryExecution) { plan =>
plan.executeToIterator().map(boundEnc.fromRow).asJava
val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
plan.executeToIterator().map { row =>
// The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type
// parameter of its `get` method, so it's safe to use null here.
objProj(row).get(0, null).asInstanceOf[T]
}.asJava
}
}

Expand Down Expand Up @@ -3102,7 +3103,12 @@ class Dataset[T] private[sql](
* Collect all elements from a spark plan.
*/
private def collectFromPlan(plan: SparkPlan): Array[T] = {
plan.executeCollect().map(boundEnc.fromRow)
val objProj = GenerateSafeProjection.generate(deserializer :: Nil)
plan.executeCollect().map { row =>
// The row returned by SafeProjection is `SpecificInternalRow`, which ignore the data type
// parameter of its `get` method, so it's safe to use null here.
objProj(row).get(0, null).asInstanceOf[T]
}
}

private def sortInternal(global: Boolean, sortExprs: Seq[Column]): Dataset[T] = {
Expand Down

0 comments on commit cecea8c

Please sign in to comment.