[SPARK-32846][SQL][PYTHON] Support createDataFrame from an RDD of pd.DataFrames #29719

Closed
133 changes: 108 additions & 25 deletions python/pyspark/sql/pandas/conversion.py
@@ -297,8 +297,11 @@ class SparkConversionMixin(object):
"""
Mix-in for the conversion from pandas to Spark. Currently, only :class:`SparkSession`
can use this class.
pandasRDD=True creates a DataFrame from an RDD of pandas DataFrames
(currently only supported using Arrow).
Comment on lines +300 to +301

Contributor:
Would it make sense to do type checking here instead of having a flag?

Author:
Can we somehow define/get the type of the RDD[py-object] without evaluating its first element?
If not, the RDD might contain any type of object, so the pandasRDD option is used as a way to differentiate between initialization from a plain RDD and from an RDD of pd.DataFrames.

Thank you for reviewing! Please let me know if there's anything else I can do to get this merged.

Contributor (@holdenk, Jul 8, 2021):
That's a good point. If we look in session.py we can see _createFromRDD does its magic there. Personally I would put this logic inside of _inferSchema and toInternal respectively, but I'm coming at this from more of a core-Spark dev perspective; maybe @HyukjinKwon has a different view.

Contributor:
Even if we don't refactor this back into session.py, I'd encourage you to look at session.py and consider structuring this in a similar way so that we don't have to have this flag here.

Author:
I agree that this seems to fit well into _inferSchema & _createFromRDD, although we would still need some way to discern between an RDD of DataFrames and other types when the user provides a schema (and we don't want to peek into the first item).

Do you think it would be better to move the pandas flag into _createFromRDD?

Contributor:
So _inferSchema does effectively peek into the first element. I think we could just put the logic down inside of the map and then the user doesn't have to specify this flag.
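
For reference, a rough sketch of the trade-off being discussed (paraphrased and simplified, not the literal session.py source; `_infer_schema` and `from_arrow_schema` are existing PySpark helpers): inferring a schema forces evaluation of one element, while a plain map stays lazy.

    def _infer_schema_sketch(rdd, names=None):
        # Paraphrase only: schema inference cannot stay lazy, it has to look
        # at a concrete element, which triggers a Spark job for one item.
        import pandas as pd
        import pyarrow as pa
        from pyspark.sql.types import _infer_schema
        from pyspark.sql.pandas.types import from_arrow_schema

        first = rdd.first()
        if isinstance(first, pd.DataFrame):
            # mirrors what this PR does when no schema is supplied
            return from_arrow_schema(pa.Schema.from_pandas(first))
        return _infer_schema(first, names)  # existing row/dict/tuple path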

Author:
That's the case when the user wants to infer the schema (so we have to peek into the RDD), but when the user does specify the schema, there's no need to peek, and we're left with no other way to tell which code path we need.

Contributor:
So let's say the user specifies a schema. In that case, inside of _createFromRDD we can just look at the type of each element that we're processing, see if it's a DataFrame, a Row or a dictionary, and dispatch the logic there. What do you think? Or is there a reason I'm missing why we couldn't do the dispatch inside of _createFromRDD based on type?
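
A minimal sketch of that per-element dispatch, assuming the _dataframe_to_arrow_record_batch helper added in this PR (the _convert_element name is illustrative only); the reply below explains why the two resulting paths would still need different JVM entry points:

    import pandas as pd
    from pyspark.sql import Row

    def _convert_element(elem, schema, timezone, safecheck):
        # Sketch only: choose the conversion path per element, so callers
        # would not need a pandasRDD flag; the surrounding map stays lazy.
        if isinstance(elem, pd.DataFrame):
            # serialized Arrow RecordBatch path (helper added in this PR)
            return _dataframe_to_arrow_record_batch(
                elem, schema=schema, timezone=timezone, safecheck=safecheck)
        if isinstance(elem, (Row, dict, tuple, list)):
            return elem  # existing row-based handling
        raise TypeError("Can not infer how to convert element of type %s" % type(elem))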

Author:
Well, in case the user specifies a schema, the entire process is lazy, so there's no need to evaluate any of the RDD elements.

If we keep everything lazy and map each element to either a Row or a RecordBatch, we would still need to know which path to take; e.g. for RecordBatches we need to call:

        from pyspark.sql.dataframe import DataFrame
        jrdd = rb_rdd._to_java_object_rdd()
        jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), self._wrapped._jsqlContext)
        df = DataFrame(jdf, self._wrapped)
        df._schema = schema
        return df

and for Rows we need to call:

        jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
        jdf = self._jsparkSession.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
        df = DataFrame(jdf, self._wrapped)
        df._schema = schema
        return df

"""
def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True,
pandasRDD=False):
from pyspark.sql import SparkSession

assert isinstance(self, SparkSession)
@@ -308,6 +311,14 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr

timezone = self._wrapped._conf.sessionLocalTimeZone()

        if self._wrapped._conf.arrowPySparkEnabled() and pandasRDD:
            from pyspark.rdd import RDD
            if not isinstance(data, RDD):
                raise ValueError('pandasRDD is set but data is of type %s, expected RDD type.'
                                 % type(data))
            # TODO: Support non-arrow conversion? might be *very* slow
            return self._create_from_pandas_rdd_with_arrow(data, schema, timezone)

# If no schema supplied by user then get the names of columns only
if schema is None:
schema = [str(x) if not isinstance(x, str) else
@@ -353,30 +364,8 @@ def _convert_from_pandas(self, pdf, schema, timezone):
assert isinstance(self, SparkSession)

if timezone is not None:
from pyspark.sql.pandas.types import _check_series_convert_timestamps_tz_local
copied = False
if isinstance(schema, StructType):
for field in schema:
# TODO: handle nested timestamps, such as ArrayType(TimestampType())?
if isinstance(field.dataType, TimestampType):
s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
if s is not pdf[field.name]:
if not copied:
# Copy once if the series is modified to prevent the original
# Pandas DataFrame from being updated
pdf = pdf.copy()
copied = True
pdf[field.name] = s
else:
for column, series in pdf.iteritems():
s = _check_series_convert_timestamps_tz_local(series, timezone)
if s is not series:
if not copied:
# Copy once if the series is modified to prevent the original
# Pandas DataFrame from being updated
pdf = pdf.copy()
copied = True
pdf[column] = s
from pyspark.sql.pandas.types import _check_dataframe_convert_timestamps_tz_local
pdf = _check_dataframe_convert_timestamps_tz_local(pdf, timezone, schema)

# Convert pandas.DataFrame to list of numpy records
np_records = pdf.to_records(index=False)
@@ -421,6 +410,39 @@ def _get_numpy_record_dtype(self, rec):
record_type_list.append((str(col_names[i]), curr_type))
return np.dtype(record_type_list) if has_rec_fix else None

    def _create_from_pandas_rdd_with_arrow(self, prdd, schema, timezone):
        """
        Create a DataFrame from an RDD of pandas.DataFrames by converting each DataFrame to an
        Arrow RecordBatch, which is then sent to the JVM.
        If a schema is passed in, the data types will be used to coerce the data in the
        pandas-to-Arrow conversion.
        """
        import pandas as pd
        import pyarrow as pa

        safecheck = self._wrapped._conf.arrowSafeTypeConversion()

        # In case no schema is passed, extract the inferred schema from the first DataFrame
        from pyspark.sql.pandas.types import from_arrow_schema
        if schema is None:
            schema = from_arrow_schema(pa.Schema.from_pandas(prdd.first()))

        # Convert to an RDD of Arrow record batches
        rb_rdd = (prdd
                  .filter(lambda x: isinstance(x, pd.DataFrame))
                  .map(lambda x: _dataframe_to_arrow_record_batch(x,
                                                                  timezone=timezone,
                                                                  schema=schema,
                                                                  safecheck=safecheck)))

        # Create a Spark DataFrame from the RDD of Arrow record batches
        from pyspark.sql.dataframe import DataFrame
        jrdd = rb_rdd._to_java_object_rdd()
        jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), self._wrapped._jsqlContext)
        df = DataFrame(jdf, self._wrapped)
        df._schema = schema
        return df

def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
"""
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
@@ -491,6 +513,67 @@ def create_RDD_server():
return df


def _sanitize_arrow_schema(schema):
    import pyarrow as pa
    import re
    sanitized_fields = []

    # Convert the pyarrow schema to a Spark-compatible one
    _SPARK_DISALLOWED_CHARS = re.compile('[ ,;{}()\n\t=]')

    def _sanitized_spark_field_name(name):
        return _SPARK_DISALLOWED_CHARS.sub('_', name)

    for field in schema:
        name = field.name
        sanitized_name = _sanitized_spark_field_name(name)

        if sanitized_name != name:
            sanitized_field = pa.field(sanitized_name, field.type,
                                       field.nullable, field.metadata)
            sanitized_fields.append(sanitized_field)
        else:
            sanitized_fields.append(field)

    new_schema = pa.schema(sanitized_fields, metadata=schema.metadata)
    return new_schema


def _dataframe_to_arrow_record_batch(pdf, schema=None, timezone=None, safecheck=False):
    """
    Convert a single pandas.DataFrame into a serialized Arrow RecordBatch. If a schema is
    passed in, the data types will be used to coerce the data in the pandas-to-Arrow
    conversion.
    """
    import pyarrow as pa
    from pyspark.sql.pandas.types import to_arrow_schema, from_arrow_schema
    from pyspark.sql.pandas.utils import require_minimum_pandas_version, \
        require_minimum_pyarrow_version

    require_minimum_pandas_version()
    require_minimum_pyarrow_version()

    # Determine arrow types to coerce data when creating batches
    if schema is not None:
        arrow_schema = to_arrow_schema(schema)
    else:
        # Any timestamps must be coerced to be compatible with Spark
        arrow_schema = to_arrow_schema(from_arrow_schema(pa.Schema.from_pandas(pdf)))

    # Sanitize the arrow schema for Spark compatibility
    arrow_schema = _sanitize_arrow_schema(arrow_schema)

    # Create an Arrow record batch, one batch per DataFrame
    from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
    arrow_data = [(pdf[col_name], arrow_type) for col_name, arrow_type
                  in zip(arrow_schema.names, arrow_schema.types)]

    col_by_name = True  # col by name only applies to StructType columns, can't happen here
    ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)

    return bytearray(ser._create_batch(arrow_data).serialize())


def _test():
import doctest
from pyspark.sql import SparkSession
36 changes: 36 additions & 0 deletions python/pyspark/sql/pandas/types.py
@@ -328,6 +328,42 @@ def _check_series_convert_timestamps_tz_local(s, timezone):
return _check_series_convert_timestamps_localize(s, timezone, None)


def _check_dataframe_convert_timestamps_tz_local(pdf, timezone, schema=None):
    """
    Convert timestamps to timezone-naive in the specified timezone or local timezone

    :param pdf: a pandas.DataFrame
    :param timezone: the timezone to convert from. if None then use local timezone
    :param schema: an optional Spark schema that defines which timestamp columns to inspect
    :return pandas.DataFrame where any timestamp column has been converted to tz-naive
    """
    copied = False
    if isinstance(schema, StructType):
        for field in schema:
            # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
            if isinstance(field.dataType, TimestampType):
                s = _check_series_convert_timestamps_tz_local(pdf[field.name], timezone)
                if s is not pdf[field.name]:
                    if not copied:
                        # Copy once if the series is modified to prevent the original
                        # Pandas DataFrame from being updated
                        pdf = pdf.copy()
                        copied = True
                    pdf[field.name] = s
    else:
        for column, series in pdf.iteritems():
            s = _check_series_convert_timestamps_tz_local(series, timezone)
            if s is not series:
                if not copied:
                    # Copy once if the series is modified to prevent the original
                    # Pandas DataFrame from being updated
                    pdf = pdf.copy()
                    copied = True
                pdf[column] = s

    return pdf


def _convert_map_items_to_dict(s):
"""
Convert a series with items as list of (key, value), as made from an Arrow column of map type,
27 changes: 20 additions & 7 deletions python/pyspark/sql/session.py
@@ -552,9 +552,11 @@ def _create_shell_session():

return SparkSession.builder.getOrCreate()

def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):
def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True,
pandasRDD=False):
"""
Creates a :class:`DataFrame` from an :class:`RDD`, a list or a :class:`pandas.DataFrame`.
Creates a :class:`DataFrame` from an :class:`RDD`, an :class:`RDD[pandas.DataFrame]`,
a list or a :class:`pandas.DataFrame`.

When ``schema`` is a list of column names, the type of each column
will be inferred from ``data``.
@@ -580,9 +582,9 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
Parameters
----------
data : :class:`RDD` or iterable
an RDD of any kind of SQL data representation (:class:`Row`,
:class:`tuple`, ``int``, ``boolean``, etc.), or :class:`list`, or
:class:`pandas.DataFrame`.
an RDD of any kind of SQL data representation (e.g. :class:`Row`,
:class:`tuple`, ``int``, ``boolean``, :class:`pandas.DataFrame`, etc.),
or :class:`list`, or :class:`pandas.DataFrame`.
schema : :class:`pyspark.sql.types.DataType`, str or list, optional
a :class:`pyspark.sql.types.DataType` or a datatype string or a list of
column names, default is None. The data type string format equals to
@@ -594,6 +596,8 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
the sample ratio of rows used for inferring
verifySchema : bool, optional
verify data types of every row against schema. Enabled by default.
pandasRDD : bool, optional
indicates that the input RDD contains pandas.DataFrame objects.

Returns
-------
@@ -637,6 +641,15 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
>>> df3.collect()
[Row(name='Alice', age=1)]

>>> # doctest: +SKIP
... prdd = sc.range(0, 4).map(lambda x: pandas.DataFrame([[x] * 4], columns=list('ABCD')))
... df4 = spark.createDataFrame(prdd, schema=None, pandasRDD=True)
... df4.collect()
[Row(A=0, B=0, C=0, D=0),
 Row(A=1, B=1, C=1, D=1),
 Row(A=2, B=2, C=2, D=2),
 Row(A=3, B=3, C=3, D=3)]

>>> spark.createDataFrame(df.toPandas()).collect() # doctest: +SKIP
[Row(name='Alice', age=1)]
>>> spark.createDataFrame(pandas.DataFrame([[1, 2]])).collect() # doctest: +SKIP
@@ -668,10 +681,10 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
has_pandas = True
except Exception:
has_pandas = False
if has_pandas and isinstance(data, pandas.DataFrame):
if has_pandas and (isinstance(data, pandas.DataFrame) or pandasRDD):
# Create a DataFrame from pandas DataFrame.
return super(SparkSession, self).createDataFrame(
data, schema, samplingRatio, verifySchema)
data, schema, samplingRatio, verifySchema, pandasRDD)
return self._create_dataframe(data, schema, samplingRatio, verifySchema)

def _create_dataframe(self, data, schema, samplingRatio, verifySchema):
11 changes: 11 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -349,6 +349,17 @@ def test_schema_conversion_roundtrip(self):
schema_rt = from_arrow_schema(arrow_schema)
self.assertEqual(self.schema, schema_rt)

    def test_createDataFrame_from_pandas_rdd(self):
        pdfs = [self.create_pandas_data_frame() for _ in range(4)]
        prdd = self.sc.parallelize(pdfs)

        df_from_rdd = self.spark.createDataFrame(prdd, schema=self.schema, pandasRDD=True)
        df_from_pdf = self.spark.createDataFrame(pd.concat(pdfs), schema=self.schema)

        result_prdd = df_from_rdd.toPandas()
        result_single_pdf = df_from_pdf.toPandas()
        assert_frame_equal(result_prdd, result_single_pdf)

def test_createDataFrame_with_array_type(self):
pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
df, df_arrow = self._createDataFrame_toggle(pdf)