[SPARK-13534][PySpark] Using Apache Arrow to increase performance of DataFrame.toPandas #15821

Closed
Commits (59)
f681d52
Initial attempt to integrate Arrow for use in dataframe.toPandas. Con…
BryanCutler Dec 14, 2016
afd5739
Test suite prototyping for collectAsArrow
icexelloss Dec 12, 2016
a4b958e
Test compiling against the newest arrow; Fix validity map; Add benchm…
icexelloss Jan 5, 2017
be508a5
Fix conversion for String type; refactor related functions to Arrow.s…
icexelloss Jan 12, 2017
5dbad22
Moved test data files to a sub-dir for arrow, merged dataType matchin…
BryanCutler Jan 12, 2017
5837b38
added some python unit tests
BryanCutler Jan 14, 2017
bdba357
Implement Arrow column writers
icexelloss Jan 17, 2017
d20437f
added bool type conversion test
BryanCutler Jan 20, 2017
2e81a93
changed scope of some functions and minor cleanup
BryanCutler Jan 24, 2017
1ce4f2d
Add support for date/timestamp/binary; Add more numbers to benchmark.…
icexelloss Jan 23, 2017
ed1f0fa
Cleanup of changes before updating the PR for review
BryanCutler Jan 24, 2017
202650e
Changed RootAllocator param to Option in collectAsArrow
BryanCutler Jan 25, 2017
fbe3b7c
renamed to ArrowConverters
BryanCutler Jan 27, 2017
f44e6d7
Adjust to cleaned up pyarrow FileReader API, support multiple record …
wesm Jan 30, 2017
e0bf11b
changed conversion to use Iterator[InternalRow] instead of Array
BryanCutler Feb 3, 2017
3090a3e
Changed tests to use generated JSON data instead of files
BryanCutler Feb 22, 2017
54884ed
updated Arrow artifacts to 0.2.0 release
BryanCutler Feb 23, 2017
42af1d5
fixed python style checks
BryanCutler Feb 23, 2017
9c8ea63
updated dependency manifest
BryanCutler Feb 23, 2017
b7c28ad
test format fix for python 2.6
BryanCutler Feb 24, 2017
2851cd6
fixed docstrings and added list of pyarrow supported types
BryanCutler Feb 28, 2017
f8f24ab
fixed memory leak of ArrowRecordBatch iterator getting consumed and b…
BryanCutler Mar 3, 2017
b6c752b
changed _collectAsArrow to private method
BryanCutler Mar 3, 2017
cbab294
added netty to exclusion list for arrow dependency
BryanCutler Mar 3, 2017
44ca3ff
dict comprehensions not supported in python 2.6
BryanCutler Mar 7, 2017
33b75b9
ensure payload batches are closed if any exception is thrown, some mi…
BryanCutler Mar 10, 2017
97742b8
changed comment for readable seekable byte channel class
BryanCutler Mar 13, 2017
b821077
Remove Date and Timestamp from supported types
icexelloss Mar 28, 2017
3d786a2
Added scaladocs to methods that did not have it
BryanCutler Apr 3, 2017
cb4c510
added check for pyarrow import error
BryanCutler Apr 3, 2017
074f66c
Merge remote-tracking branch 'upstream/master' into wip-toPandas_with…
BryanCutler Apr 13, 2017
7260217
changed pyspark script to accept all args when testing
BryanCutler Apr 13, 2017
a0483b8
added pyarrow tests to be launched during run-pip-tests when using conda
BryanCutler Apr 13, 2017
470f33d
Merge remote-tracking branch 'upstream/master' into wip-toPandas_with…
BryanCutler Apr 25, 2017
c144667
pre-update for using Arrow 0.3, cleanup of converter functions, times…
BryanCutler Apr 25, 2017
250b581
added DateType to tests
BryanCutler Apr 26, 2017
f667a7a
removed support for DateType and TimestampType for now
BryanCutler Apr 26, 2017
76f7ddb
moved ArrowConverters to o.a.s.sql.execution.arrow
BryanCutler Apr 26, 2017
89dd0f4
changed useArrow flag to SQLConf spark.sql.execution.arrow.enable
BryanCutler Apr 27, 2017
d7cb4ab
separated numeric tests, moved data to test scope
BryanCutler Apr 27, 2017
b6fe733
removed timestamp and date test until fully supported
BryanCutler Apr 27, 2017
36f8127
added exception handling in byteArrayToBatch conversion, changed Arro…
BryanCutler May 1, 2017
088f79e
added binary conversion test
BryanCutler May 1, 2017
e0449eb
fixed up unsupported test for timestamp
BryanCutler May 1, 2017
b6bfcd7
Updated Arrow version to 0.3.0
BryanCutler May 2, 2017
2c1af59
added ArrowPayload method toByteArray
BryanCutler May 9, 2017
1d471ac
removed unused imports, arrow.vector.DateUnit and TimeUnit
BryanCutler May 9, 2017
a4d6057
Added conf spark.sql.execution.arrow.maxRecordsPerBatch to limit num …
BryanCutler May 9, 2017
934c147
update dependency manifests for Arrow 0.3.0
BryanCutler May 9, 2017
2e4747b
changed tests to close resources properly
BryanCutler May 10, 2017
b4eebc2
Made JSON test data local string for each test, removed JSON generation
BryanCutler May 17, 2017
d49a14d
upgrade to use Arrow 0.4
BryanCutler May 19, 2017
a630bf0
forgot to update arrow version in dependency manifests
BryanCutler May 26, 2017
748e6fb
Changed UTF8StringColumnWriter to use VarCharVector
BryanCutler Jun 15, 2017
b361bdc
Added check for DataFrame that is filtered out completely and convert…
BryanCutler Jun 19, 2017
8bff966
Moved all work out of ArrowPayload construction to companion object
BryanCutler Jun 19, 2017
f96f555
Renamed variable to schemaCaptured
BryanCutler Jun 20, 2017
b53e09f
Merge remote-tracking branch 'upstream/master' into wip-toPandas_with…
BryanCutler Jun 22, 2017
44d7a2a
cleaned up test now that toPandas without Arrow will have correct dtypes
BryanCutler Jun 22, 2017
Files changed
2 changes: 1 addition & 1 deletion bin/pyspark
@@ -68,7 +68,7 @@ if [[ -n "$SPARK_TESTING" ]]; then
unset YARN_CONF_DIR
unset HADOOP_CONF_DIR
export PYTHONHASHSEED=0
exec "$PYSPARK_DRIVER_PYTHON" -m "$1"
exec "$PYSPARK_DRIVER_PYTHON" -m "$@"
exit
fi

5 changes: 5 additions & 0 deletions dev/deps/spark-deps-hadoop-2.6
@@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar
api-asn1-api-1.0.0-M20.jar
api-util-1.0.0-M20.jar
arpack_combined_all-0.1.jar
arrow-format-0.4.0.jar
arrow-memory-0.4.0.jar
arrow-vector-0.4.0.jar
avro-1.7.7.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
@@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar
datanucleus-rdbms-3.2.9.jar
derby-10.12.1.1.jar
eigenbase-properties-1.1.5.jar
flatbuffers-1.2.0-3f79e055.jar
gson-2.2.4.jar
guava-14.0.1.jar
guice-3.0.jar
@@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.6.5.jar
hk2-api-2.4.0-b34.jar
hk2-locator-2.4.0-b34.jar
hk2-utils-2.4.0-b34.jar
hppc-0.7.1.jar
htrace-core-3.0.4.jar
httpclient-4.5.2.jar
httpcore-4.4.4.jar
5 changes: 5 additions & 0 deletions dev/deps/spark-deps-hadoop-2.7
@@ -13,6 +13,9 @@ apacheds-kerberos-codec-2.0.0-M15.jar
api-asn1-api-1.0.0-M20.jar
api-util-1.0.0-M20.jar
arpack_combined_all-0.1.jar
arrow-format-0.4.0.jar
arrow-memory-0.4.0.jar
arrow-vector-0.4.0.jar
avro-1.7.7.jar
avro-ipc-1.7.7.jar
avro-mapred-1.7.7-hadoop2.jar
@@ -55,6 +58,7 @@ datanucleus-core-3.2.10.jar
datanucleus-rdbms-3.2.9.jar
derby-10.12.1.1.jar
eigenbase-properties-1.1.5.jar
flatbuffers-1.2.0-3f79e055.jar
gson-2.2.4.jar
guava-14.0.1.jar
guice-3.0.jar
@@ -77,6 +81,7 @@ hadoop-yarn-server-web-proxy-2.7.3.jar
hk2-api-2.4.0-b34.jar
hk2-locator-2.4.0-b34.jar
hk2-utils-2.4.0-b34.jar
hppc-0.7.1.jar
htrace-core-3.1.0-incubating.jar
httpclient-4.5.2.jar
httpcore-4.4.4.jar
6 changes: 6 additions & 0 deletions dev/run-pip-tests
@@ -83,6 +83,8 @@ for python in "${PYTHON_EXECS[@]}"; do
if [ -n "$USE_CONDA" ]; then
conda create -y -p "$VIRTUALENV_PATH" python=$python numpy pandas pip setuptools
source activate "$VIRTUALENV_PATH"
conda install -y -c conda-forge pyarrow=0.4.0
TEST_PYARROW=1
else
mkdir -p "$VIRTUALENV_PATH"
virtualenv --python=$python "$VIRTUALENV_PATH"
@@ -120,6 +122,10 @@ for python in "${PYTHON_EXECS[@]}"; do
python "$FWDIR"/dev/pip-sanity-check.py
echo "Run the tests for context.py"
python "$FWDIR"/python/pyspark/context.py
if [ -n "$TEST_PYARROW" ]; then
echo "Run tests for pyarrow"
SPARK_TESTING=1 "$FWDIR"/bin/pyspark pyspark.sql.tests ArrowTests
Member Author: This is just a temporary addition to make sure the pyarrow tests run, because for now they require a conda env.

Contributor: Will we remove this at some point?

Reply: pyarrow has pip packages now, does that help? @BryanCutler I can't read your comment and understand why they required a conda env, but maybe that's it?

Member Author (@BryanCutler, Jun 15, 2017): Yeah, this will definitely be removed at some point. I was working with @holdenk to set this up as a temporary way to get the Python Arrow tests to run. I'll look into using the pip packages and see if they can be used instead of this. Thanks @leifwalsh!

fi

cd "$FWDIR"

20 changes: 20 additions & 0 deletions pom.xml
@@ -181,6 +181,7 @@
<paranamer.version>2.6</paranamer.version>
<maven-antrun.version>1.8</maven-antrun.version>
<commons-crypto.version>1.0.0</commons-crypto.version>
<arrow.version>0.4.0</arrow.version>

<test.java.home>${java.home}</test.java.home>
<test.exclude.tags></test.exclude.tags>
@@ -1878,6 +1879,25 @@
<artifactId>paranamer</artifactId>
<version>${paranamer.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
Contributor: Why do we add the Arrow dependency at the root pom instead of only in spark-sql?

Contributor: I think this is just standard pom usage.

Member Author: This is just the dependency management section in the main pom. The only actual dependency is in spark-sql.

<artifactId>arrow-vector</artifactId>
<version>${arrow.version}</version>
<exclusions>
Contributor: Should we consider excluding netty here, since we exclude it in most of the other related projects (like parquet)? It seems to have added some unnecessary jars to the deps list.

Member Author: I added netty to the exclusions and it does not seem to cause any issues.

<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
</exclusion>
<exclusion>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
</exclusion>
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty-handler</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
</dependencyManagement>

17 changes: 17 additions & 0 deletions python/pyspark/serializers.py
@@ -182,6 +182,23 @@ def loads(self, obj):
raise NotImplementedError


class ArrowSerializer(FramedSerializer):
"""
Serializes an Arrow stream.
"""

def dumps(self, obj):
raise NotImplementedError

def loads(self, obj):
import pyarrow as pa
reader = pa.RecordBatchFileReader(pa.BufferReader(obj))
return reader.read_all()
Contributor (@cloud-fan, Jun 14, 2017): Since we are sending multiple batches from the JVM, does reader.read_all() wait for all the batches?

Member Author: This will read all batches in a framed byte array from a stream and return. The stream can have multiple framed byte arrays, so it repeats until end of stream.

How many batches this reads depends on how it was serialized. When calling toPandas(), it collects all batches and then serializes each one to Python as an iterator. So in this case, reader.read_all() will read 1 batch at a time.


def __repr__(self):
return "ArrowSerializer"


class BatchedSerializer(Serializer):

"""
48 changes: 37 additions & 11 deletions python/pyspark/sql/dataframe.py
@@ -29,7 +29,8 @@

from pyspark import copy_func, since
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.serializers import ArrowSerializer, BatchedSerializer, PickleSerializer, \
UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import _parse_datatype_json_string
@@ -1708,7 +1709,8 @@ def toDF(self, *cols):

@since(1.3)
def toPandas(self):
"""Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.
"""
Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.

This is only available if Pandas is installed and available.

@@ -1721,18 +1723,42 @@ def toPandas(self):
1 5 Bob
"""
import pandas as pd
if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true":
try:
import pyarrow
tables = self._collectAsArrow()
if tables:
table = pyarrow.concat_tables(tables)
return table.to_pandas()
else:
return pd.DataFrame.from_records([], columns=self.columns)
except ImportError as e:
msg = "note: pyarrow must be installed and available on calling Python process " \
"if using spark.sql.execution.arrow.enable=true"
raise ImportError("%s\n%s" % (e.message, msg))
else:
dtype = {}
for field in self.schema:
pandas_type = _to_corrected_pandas_type(field.dataType)
if pandas_type is not None:
dtype[field.name] = pandas_type

dtype = {}
for field in self.schema:
pandas_type = _to_corrected_pandas_type(field.dataType)
if pandas_type is not None:
dtype[field.name] = pandas_type
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)

pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
for f, t in dtype.items():
pdf[f] = pdf[f].astype(t, copy=False)
return pdf

for f, t in dtype.items():
pdf[f] = pdf[f].astype(t, copy=False)
return pdf
def _collectAsArrow(self):
"""
Returns all records as list of deserialized ArrowPayloads, pyarrow must be installed
and available.

.. note:: Experimental.
"""
with SCCallSiteSync(self._sc) as css:
port = self._jdf.collectAsArrowToPython()
return list(_load_from_socket(port, ArrowSerializer()))

##########################################################################################
# Pandas compatibility
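
Usage note (editorial, not part of the diff): the Arrow path is opt-in via the new SQL conf, and `toPandas()` keeps the existing collect-and-pickle path when the conf is off. A minimal sketch, assuming pyarrow is installed on the driver; the sample DataFrame is illustrative only.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a", 1, 2.0), ("b", 2, 4.0)], ["k", "v", "w"])

spark.conf.set("spark.sql.execution.arrow.enable", "true")
pdf_arrow = df.toPandas()   # collected as Arrow record batches, concatenated via pyarrow

spark.conf.set("spark.sql.execution.arrow.enable", "false")
pdf_plain = df.toPandas()   # existing row-based collect() path, with corrected dtypes

# The ArrowTests added below assert the two results are equal.
assert pdf_plain.equals(pdf_arrow)
```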
79 changes: 78 additions & 1 deletion python/pyspark/sql/tests.py
@@ -58,12 +58,21 @@
from pyspark.sql import SparkSession, SQLContext, HiveContext, Column, Row
from pyspark.sql.types import *
from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase, SparkSubmitTests
from pyspark.tests import QuietTest, ReusedPySparkTestCase, SparkSubmitTests
from pyspark.sql.functions import UserDefinedFunction, sha2, lit
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException


_have_arrow = False
try:
import pyarrow
_have_arrow = True
Member (@viirya, Mar 29, 2017): We should do a similar thing above when using an Arrow-required feature, e.g., ArrowSerializer.

Member Author: Do you mean to automatically enable the Arrow functionality if pyarrow is installed? Right now it is enabled manually with a flag useArrow in the public API toPandas. If enabled and pyarrow is not installed, it will give an import error.

Member: I mean we should throw an exception when useArrow is used but no pyarrow is installed.

Member: Maybe give the param doc string as the exception message? I.e., "To make use of Apache Arrow for conversion, pyarrow must be installed and available on the calling Python process (Experimental)."

except:
# No Arrow, but that's okay, we'll skip those tests
pass


class UTCOffsetTimezone(datetime.tzinfo):
"""
Specifies timezone in UTC offset
@@ -2620,6 +2629,74 @@ def range_frame_match():

importlib.reload(window)


@unittest.skipIf(not _have_arrow, "Arrow not installed")
class ArrowTests(ReusedPySparkTestCase):

@classmethod
def setUpClass(cls):
ReusedPySparkTestCase.setUpClass()
cls.spark = SparkSession(cls.sc)
cls.spark.conf.set("spark.sql.execution.arrow.enable", "true")
cls.schema = StructType([
StructField("1_str_t", StringType(), True),
StructField("2_int_t", IntegerType(), True),
StructField("3_long_t", LongType(), True),
StructField("4_float_t", FloatType(), True),
StructField("5_double_t", DoubleType(), True)])
cls.data = [("a", 1, 10, 0.2, 2.0),
("b", 2, 20, 0.4, 4.0),
("c", 3, 30, 0.8, 6.0)]

def assertFramesEqual(self, df_with_arrow, df_without):
msg = ("DataFrame from Arrow is not equal" +
("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) +
("\n\nWithout:\n%s\n%s" % (df_without, df_without.dtypes)))
self.assertTrue(df_without.equals(df_with_arrow), msg=msg)

def test_unsupported_datatype(self):
schema = StructType([StructField("array", ArrayType(IntegerType(), False), True)])
df = self.spark.createDataFrame([([1, 2, 3],)], schema=schema)
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: df.toPandas())

def test_null_conversion(self):
df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
self.data)
pdf = df_null.toPandas()
null_counts = pdf.isnull().sum().tolist()
self.assertTrue(all([c == 1 for c in null_counts]))

def test_toPandas_arrow_toggle(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
self.spark.conf.set("spark.sql.execution.arrow.enable", "false")
pdf = df.toPandas()
self.spark.conf.set("spark.sql.execution.arrow.enable", "true")
pdf_arrow = df.toPandas()
self.assertFramesEqual(pdf_arrow, pdf)

def test_pandas_round_trip(self):
import pandas as pd
import numpy as np
data_dict = {}
for j, name in enumerate(self.schema.names):
data_dict[name] = [self.data[i][j] for i in range(len(self.data))]
# need to convert these to numpy types first
data_dict["2_int_t"] = np.int32(data_dict["2_int_t"])
data_dict["4_float_t"] = np.float32(data_dict["4_float_t"])
pdf = pd.DataFrame(data=data_dict)
df = self.spark.createDataFrame(self.data, schema=self.schema)
pdf_arrow = df.toPandas()
Contributor: Call self.spark.conf.set("spark.sql.execution.arrow.enable", "true") before this.

Member Author: That's done in setUpClass, so it's already enabled for all tests. There is a test where it's toggled off and then back on, though, to test that behavior.

self.assertFramesEqual(pdf_arrow, pdf)

def test_filtered_frame(self):
df = self.spark.range(3).toDF("i")
pdf = df.filter("i < 0").toPandas()
self.assertEqual(len(pdf.columns), 1)
self.assertEqual(pdf.columns[0], "i")
self.assertTrue(pdf.empty)


if __name__ == "__main__":
from pyspark.sql.tests import *
if xmlrunner:
@@ -846,6 +846,24 @@ object SQLConf {
.intConf
.createWithDefault(UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.toInt)

val ARROW_EXECUTION_ENABLE =
buildConf("spark.sql.execution.arrow.enable")
.internal()
.doc("Make use of Apache Arrow for columnar data transfers. Currently available " +
"for use with pyspark.sql.DataFrame.toPandas with the following data types: " +
"StringType, BinaryType, BooleanType, DoubleType, FloatType, ByteType, IntegerType, " +
"LongType, ShortType")
.booleanConf
.createWithDefault(false)

val ARROW_EXECUTION_MAX_RECORDS_PER_BATCH =
buildConf("spark.sql.execution.arrow.maxRecordsPerBatch")
.internal()
.doc("When using Apache Arrow, limit the maximum number of records that can be written " +
"to a single ArrowRecordBatch in memory. If set to zero or negative there is no limit.")
.intConf
.createWithDefault(10000)

object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -1104,6 +1122,10 @@

def starSchemaFTRatio: Double = getConf(STARSCHEMA_FACT_TABLE_RATIO)

def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE)

def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH)

/** ********************** SQLConf functionality methods ************ */

/** Set Spark SQL configuration properties. */
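
For context (an editorial sketch, not part of the diff): both new confs are ordinary SQL confs and can be set from PySpark at runtime. The exact batch count below depends on partitioning, so treat it as illustrative; `_collectAsArrow` is the private helper added in dataframe.py above.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.enable", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "1000")

df = spark.range(10000)  # single LongType column "id"
# Each partition is split into ArrowRecordBatches of at most 1000 records, so the
# number of collected payloads grows as maxRecordsPerBatch shrinks.
batches = df._collectAsArrow()  # list with one deserialized Arrow batch per ArrowPayload
print(len(batches))
```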
4 changes: 4 additions & 0 deletions sql/core/pom.xml
@@ -103,6 +103,10 @@
<artifactId>jackson-databind</artifactId>
<version>${fasterxml.jackson.version}</version>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-vector</artifactId>
</dependency>
<dependency>
<groupId>org.apache.xbean</groupId>
<artifactId>xbean-asm5-shaded</artifactId>
20 changes: 20 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -47,6 +47,7 @@ import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, PartitioningCollection}
import org.apache.spark.sql.catalyst.util.{usePrettyExpression, DateTimeUtils}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload}
import org.apache.spark.sql.execution.command._
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.python.EvaluatePython
@@ -2922,6 +2923,16 @@
}
}

/**
* Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
*/
private[sql] def collectAsArrowToPython(): Int = {
withNewExecutionId {
val iter = toArrowPayload.collect().iterator.map(_.asPythonSerializable)
PythonRDD.serveIterator(iter, "serve-Arrow")
}
}

private[sql] def toPythonIterator(): Int = {
withNewExecutionId {
PythonRDD.toLocalIteratorAndServe(javaToPython.rdd)
@@ -3003,4 +3014,13 @@
Dataset(sparkSession, logicalPlan)
}
}

/** Convert to an RDD of ArrowPayload byte arrays */
private[sql] def toArrowPayload: RDD[ArrowPayload] = {
val schemaCaptured = this.schema
val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch
queryExecution.toRdd.mapPartitionsInternal { iter =>
ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch)
}
}
}