
Commit

cr
dvogelbacher committed May 29, 2019
1 parent 4f57b7d commit d9936d5
Showing 3 changed files with 15 additions and 18 deletions.
22 changes: 10 additions & 12 deletions python/pyspark/serializers.py
@@ -206,19 +206,17 @@ def load_stream(self, stream):
         for batch in self.serializer.load_stream(stream):
             yield batch

-        # check success
-        success = read_bool(stream)
-        if success:
-            # load the batch order indices
-            num = read_int(stream)
-            batch_order = []
-            for i in xrange(num):
-                index = read_int(stream)
-                batch_order.append(index)
-            yield batch_order
-        else:
+        # load the batch order indices
+        num = read_int(stream)
+        if num == -1:
             error_msg = UTF8Deserializer().loads(stream)
-            raise RuntimeError("An error occurred while collecting: {}".format(error_msg))
+            raise RuntimeError("An error occurred while calling "
+                               "ArrowCollectSerializer.load_stream: {}".format(error_msg))
+        batch_order = []
+        for i in xrange(num):
+            index = read_int(stream)
+            batch_order.append(index)
+        yield batch_order

     def __repr__(self):
         return "ArrowCollectSerializer(%s)" % self.serializer
2 changes: 1 addition & 1 deletion python/pyspark/sql/tests/test_arrow.py
@@ -23,6 +23,7 @@
 import warnings

 from pyspark.sql import Row
+from pyspark.sql.functions import udf
 from pyspark.sql.types import *
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
@@ -193,7 +194,6 @@ def test_no_partition_frame(self):

     def test_propagates_spark_exception(self):
         df = self.spark.range(3).toDF("i")
-        from pyspark.sql.functions import udf

         def raise_exception():
             raise Exception("My error")
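
The remainder of test_propagates_spark_exception is collapsed in this view, so the standalone sketch below only illustrates the scenario the test exercises with the now module-level udf import; it assumes an active SparkSession named spark with Arrow-based collection enabled, and is not the hidden test body itself.

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def raise_exception():
    raise Exception("My error")

exception_udf = udf(raise_exception, IntegerType())
df = spark.range(3).toDF("i").withColumn("error", exception_udf())

try:
    df.toPandas()  # the executor-side failure should now surface on the driver
except Exception as e:
    print("propagated:", e)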
9 changes: 4 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3313,7 +3313,7 @@ class Dataset[T] private[sql](
           }
         }

-        var sparkException: Option[SparkException] = Option.empty
+        var sparkException: Option[SparkException] = None
         try {
           val arrowBatchRdd = toArrowBatchRdd(plan)
           sparkSession.sparkContext.runJob(
@@ -3322,19 +3322,18 @@
             handlePartitionBatches)
         } catch {
           case e: SparkException =>
-            sparkException = Option.apply(e)
+            sparkException = Some(e)
         }

         // After processing all partitions, end the batch stream
         batchWriter.end()
         sparkException match {
           case Some(exception) =>
             // Signal failure and write error message
-            out.writeBoolean(false)
+            out.writeInt(-1)
             PythonRDD.writeUTF(exception.getMessage, out)
           case None =>
-            // Signal success and write batch order indices
-            out.writeBoolean(true)
+            // Write batch order indices
             out.writeInt(batchOrder.length)
             // Sort by (index of partition, batch index in that partition) tuple to get the
             // overall_batch_index from 0 to N-1 batches, which can be used to put the
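
For context, the JVM side now signals failure by writing -1 followed by the error message via PythonRDD.writeUTF, and signals success by writing the batch count followed by the overall order indices. The snippet below is a rough Python rendering of that success path, not the Scala implementation; batch_order stands for the (partition index, batch index in partition) tuples collected in arrival order, and write_int comes from pyspark/serializers.py.

import io
from pyspark.serializers import write_int

def write_batch_order(stream, batch_order):
    # Write how many batches were transferred.
    write_int(len(batch_order), stream)
    # Sort the (partition, batch-in-partition) keys; for each logical position
    # emit the arrival index, so the reader can put batches back in order.
    for _, arrival_index in sorted(zip(batch_order, range(len(batch_order)))):
        write_int(arrival_index, stream)

buf = io.BytesIO()
# e.g. partition 1's only batch arrived before partition 0's two batches
write_batch_order(buf, [(1, 0), (0, 0), (0, 1)])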
