write success and optional error message
dvogelbacher committed May 22, 2019
1 parent 64fcc12 commit 3b4fcbe
Showing 3 changed files with 44 additions and 24 deletions.
20 changes: 13 additions & 7 deletions python/pyspark/serializers.py
@@ -206,13 +206,19 @@ def load_stream(self, stream):
         for batch in self.serializer.load_stream(stream):
             yield batch
 
-        # load the batch order indices
-        num = read_int(stream)
-        batch_order = []
-        for i in xrange(num):
-            index = read_int(stream)
-            batch_order.append(index)
-        yield batch_order
+        # check success
+        success = read_bool(stream)
+        if success:
+            # load the batch order indices
+            num = read_int(stream)
+            batch_order = []
+            for i in xrange(num):
+                index = read_int(stream)
+                batch_order.append(index)
+            yield batch_order
+        else:
+            error_msg = UTF8Deserializer().loads(stream)
+            raise RuntimeError("An error occurred while collecting: {}".format(error_msg))
 
     def __repr__(self):
         return "ArrowCollectSerializer(%s)" % self.serializer
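For context, the tail of the Arrow collection stream now carries a success flag followed by either the batch-order indices or a UTF-8 error string. The sketch below round-trips that framing with plain struct packing; the write_* helpers are made up for illustration and only mirror what read_int, read_bool, and UTF8Deserializer expect on the read side.

import struct
from io import BytesIO

# Hypothetical writer helpers that mirror the big-endian framing the reader
# (read_int, read_bool, UTF8Deserializer) consumes; illustration only.
def write_int(value, stream):
    stream.write(struct.pack("!i", value))

def write_bool(value, stream):
    stream.write(struct.pack("!?", value))

def write_utf(text, stream):
    data = text.encode("utf-8")
    write_int(len(data), stream)   # length prefix, then the raw UTF-8 bytes
    stream.write(data)

def write_stream_tail(stream, batch_order=None, error_msg=None):
    """Write the trailing success flag plus payload, as the JVM side now does."""
    if error_msg is None:
        write_bool(True, stream)             # signal success
        write_int(len(batch_order), stream)  # number of batch order indices
        for index in batch_order:
            write_int(index, stream)
    else:
        write_bool(False, stream)            # signal failure
        write_utf(error_msg, stream)

# Success case: a bool, a count, then the indices.
buf = BytesIO()
write_stream_tail(buf, batch_order=[2, 0, 1])
assert buf.getvalue() == b"\x01" + struct.pack("!iiii", 3, 2, 0, 1)

# Failure case: a bool, then a length-prefixed error message.
buf = BytesIO()
write_stream_tail(buf, error_msg="My error")
assert buf.getvalue() == b"\x00" + struct.pack("!i", 8) + b"My error"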
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/test_arrow.py
@@ -200,7 +200,7 @@ def raise_exception():
         exception_udf = udf(raise_exception, IntegerType())
         df = df.withColumn("error", exception_udf())
         with QuietTest(self.sc):
-            with self.assertRaisesRegexp(Exception, 'My error'):
+            with self.assertRaisesRegexp(RuntimeError, 'My error'):
                 df.toPandas()
 
     def _createDataFrame_toggle(self, pdf, schema=None):
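As a usage-level illustration of what the updated test asserts, here is a minimal sketch assuming a local SparkSession and the Arrow config name used at the time (spark.sql.execution.arrow.enabled; newer releases renamed it to spark.sql.execution.arrow.pyspark.enabled). It is not part of the commit itself.

from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = (SparkSession.builder
         .master("local[2]")
         .config("spark.sql.execution.arrow.enabled", "true")
         .getOrCreate())

def raise_exception():
    raise Exception("My error")

exception_udf = udf(raise_exception, IntegerType())
df = spark.range(3).withColumn("error", exception_udf())

try:
    df.toPandas()  # the collection job fails on the executors
except RuntimeError as e:
    # With this change the JVM error text is forwarded to Python
    assert "My error" in str(e)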
46 changes: 30 additions & 16 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -18,15 +18,14 @@
 package org.apache.spark.sql
 
 import java.io.{ByteArrayOutputStream, CharArrayWriter, DataOutputStream}
-import java.nio.charset.StandardCharsets
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
@@ -40,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.HiveTableRelation
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateSafeProjection
-import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JSONOptions}
+import org.apache.spark.sql.catalyst.json.{JSONOptions, JacksonGenerator}
 import org.apache.spark.sql.catalyst.optimizer.CombineUnions
 import org.apache.spark.sql.catalyst.parser.{ParseException, ParserUtils}
 import org.apache.spark.sql.catalyst.plans._
@@ -3313,20 +3312,35 @@ class Dataset[T] private[sql](
           }
         }
 
-        val arrowBatchRdd = toArrowBatchRdd(plan)
-        sparkSession.sparkContext.runJob(
-          arrowBatchRdd,
-          (it: Iterator[Array[Byte]]) => it.toArray,
-          handlePartitionBatches)
+        var sparkException: Option[SparkException] = Option.empty
+        try {
+          val arrowBatchRdd = toArrowBatchRdd(plan)
+          sparkSession.sparkContext.runJob(
+            arrowBatchRdd,
+            (it: Iterator[Array[Byte]]) => it.toArray,
+            handlePartitionBatches)
+        } catch {
+          case e: SparkException =>
+            sparkException = Option.apply(e)
+        }
 
-        // After processing all partitions, end the stream and write batch order indices
+        // After processing all partitions, end the batch stream
         batchWriter.end()
-        out.writeInt(batchOrder.length)
-        // Sort by (index of partition, batch index in that partition) tuple to get the
-        // overall_batch_index from 0 to N-1 batches, which can be used to put the
-        // transferred batches in the correct order
-        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
-          out.writeInt(overallBatchIndex)
-        }
+        sparkException match {
+          case Some(exception) =>
+            // Signal failure and write error message
+            out.writeBoolean(false)
+            PythonRDD.writeUTF(exception.getMessage, out)
+          case None =>
+            // Signal success and write batch order indices
+            out.writeBoolean(true)
+            out.writeInt(batchOrder.length)
+            // Sort by (index of partition, batch index in that partition) tuple to get the
+            // overall_batch_index from 0 to N-1 batches, which can be used to put the
+            // transferred batches in the correct order
+            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+              out.writeInt(overallBatchIndex)
+            }
+        }
       }
     }
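On success, the indices written after the batches tell the Python side where each batch, received in arrival order, belongs in the final result: batchOrder is sorted by its (partition index, batch index in partition) key and the arrival positions are written out in that order. A minimal sketch of the reordering this enables, simplified from what pyspark does after load_stream (the variable names are illustrative):

# Batches arrive in whatever order partitions finish; the final message from
# load_stream is batch_order, where batch_order[k] is the arrival position of
# the batch that belongs at result position k.
received = ["batch from partition 1", "batch from partition 2", "batch from partition 0"]
batch_order = [2, 0, 1]  # indices read after the success flag

ordered = [received[i] for i in batch_order]
assert ordered == ["batch from partition 0",
                   "batch from partition 1",
                   "batch from partition 2"]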
