[SPARK-27805][PYTHON] Propagate SparkExceptions during toPandas with Arrow enabled

## What changes were proposed in this pull request?
Similar to #24070, SparkExceptions encountered during the collect in the JVM are now propagated to the Python process when `toPandas` is called with Arrow enabled.

Fixes https://jira.apache.org/jira/browse/SPARK-27805
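
For illustration, a minimal sketch (not part of the patch) of the user-visible effect. The session setup is assumed, and the config name follows the test file shown below; older releases use `spark.sql.execution.arrow.enabled`:

```python
# Sketch of the new behavior: with Arrow enabled, an error raised on the executors
# during toPandas() now surfaces in Python as a RuntimeError carrying the original
# Spark error message, rather than a less informative failure.
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")


def raise_exception():
    raise Exception("My error")


df = spark.range(3).withColumn("error", udf(raise_exception, IntegerType())())

try:
    df.toPandas()
except RuntimeError as e:
    # Before this change the Spark error was lost; the message now contains "My error".
    print(e)
```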

## How was this patch tested?
Added a new unit test, `test_propagates_spark_exception`, in `python/pyspark/sql/tests/test_arrow.py`.

Closes #24677 from dvogelbacher/dv/betterErrorMsgWhenUsingArrow.

Authored-by: David Vogelbacher <dvogelbacher@palantir.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
dvogelbacher authored and BryanCutler committed Jun 4, 2019
1 parent adf72e2 commit f9ca8ab
Showing 3 changed files with 44 additions and 14 deletions.
python/pyspark/serializers.py (5 additions, 1 deletion)
@@ -206,8 +206,12 @@ def load_stream(self, stream):
         for batch in self.serializer.load_stream(stream):
             yield batch
 
-        # load the batch order indices
+        # load the batch order indices or propagate any error that occurred in the JVM
         num = read_int(stream)
+        if num == -1:
+            error_msg = UTF8Deserializer().loads(stream)
+            raise RuntimeError("An error occurred while calling "
+                               "ArrowCollectSerializer.load_stream: {}".format(error_msg))
         batch_order = []
         for i in xrange(num):
             index = read_int(stream)
python/pyspark/sql/tests/test_arrow.py (12 additions, 0 deletions)
@@ -23,6 +23,7 @@
 import warnings
 
 from pyspark.sql import Row
+from pyspark.sql.functions import udf
 from pyspark.sql.types import *
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
@@ -205,6 +206,17 @@ def test_no_partition_frame(self):
         self.assertEqual(pdf.columns[0], "field1")
         self.assertTrue(pdf.empty)
 
+    def test_propagates_spark_exception(self):
+        df = self.spark.range(3).toDF("i")
+
+        def raise_exception():
+            raise Exception("My error")
+        exception_udf = udf(raise_exception, IntegerType())
+        df = df.withColumn("error", exception_udf())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(RuntimeError, 'My error'):
+                df.toPandas()
+
     def _createDataFrame_toggle(self, pdf, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala (27 additions, 13 deletions)
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
@@ -3321,20 +3321,34 @@ class Dataset[T] private[sql](
             }
           }
 
-        val arrowBatchRdd = toArrowBatchRdd(plan)
-        sparkSession.sparkContext.runJob(
-          arrowBatchRdd,
-          (it: Iterator[Array[Byte]]) => it.toArray,
-          handlePartitionBatches)
+        var sparkException: Option[SparkException] = None
+        try {
+          val arrowBatchRdd = toArrowBatchRdd(plan)
+          sparkSession.sparkContext.runJob(
+            arrowBatchRdd,
+            (it: Iterator[Array[Byte]]) => it.toArray,
+            handlePartitionBatches)
+        } catch {
+          case e: SparkException =>
+            sparkException = Some(e)
+        }
 
-        // After processing all partitions, end the stream and write batch order indices
+        // After processing all partitions, end the batch stream
         batchWriter.end()
-        out.writeInt(batchOrder.length)
-        // Sort by (index of partition, batch index in that partition) tuple to get the
-        // overall_batch_index from 0 to N-1 batches, which can be used to put the
-        // transferred batches in the correct order
-        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
-          out.writeInt(overallBatchIndex)
+        sparkException match {
+          case Some(exception) =>
+            // Signal failure and write error message
+            out.writeInt(-1)
+            PythonRDD.writeUTF(exception.getMessage, out)
+          case None =>
+            // Write batch order indices
+            out.writeInt(batchOrder.length)
+            // Sort by (index of partition, batch index in that partition) tuple to get the
+            // overall_batch_index from 0 to N-1 batches, which can be used to put the
+            // transferred batches in the correct order
+            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+              out.writeInt(overallBatchIndex)
+            }
         }
       }
     }
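
To make the two halves easier to see together, here is a standalone sketch of the end-of-stream framing that the Scala writer above and the Python reader in `serializers.py` now share: after the Arrow batches, one integer is either the count of batch-order indices or a -1 sentinel followed by a length-prefixed UTF-8 error message. This is illustration only, not Spark code, and the helper names are made up.

```python
# Illustrative framing for the tail of the Arrow collect stream (big-endian ints,
# as written by Java's DataOutputStream and read by PySpark's read_int).
import io
import struct


def write_tail(out, batch_order=None, error_message=None):
    """Mimics the JVM-side writer: error sentinel or batch-order indices."""
    if error_message is not None:
        out.write(struct.pack("!i", -1))                # signal failure
        encoded = error_message.encode("utf-8")
        out.write(struct.pack("!i", len(encoded)))      # length-prefixed UTF-8 message
        out.write(encoded)
    else:
        out.write(struct.pack("!i", len(batch_order)))  # number of indices
        for index in batch_order:
            out.write(struct.pack("!i", index))


def read_tail(stream):
    """Mimics the Python-side reader: raise on the -1 sentinel, else return indices."""
    num = struct.unpack("!i", stream.read(4))[0]
    if num == -1:
        length = struct.unpack("!i", stream.read(4))[0]
        raise RuntimeError(stream.read(length).decode("utf-8"))
    return [struct.unpack("!i", stream.read(4))[0] for _ in range(num)]


# Success path: the reader gets back the batch-order indices.
buf = io.BytesIO()
write_tail(buf, batch_order=[2, 0, 1])
buf.seek(0)
print(read_tail(buf))  # [2, 0, 1]

# Failure path: the reader surfaces the JVM error message as a RuntimeError.
buf = io.BytesIO()
write_tail(buf, error_message="Job aborted due to stage failure: ...")
buf.seek(0)
try:
    read_tail(buf)
except RuntimeError as e:
    print("propagated:", e)
```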
