5 changes: 5 additions & 0 deletions python/pyspark/sql/pandas/_typing/protocols/frame.pyi
@@ -35,6 +35,8 @@ Axis = Any
Level = Any

class DataFrameLike(Protocol):
columns: Axes
dtypes: List[Any]
def __init__(
self,
data: Any = ...,
@@ -422,7 +424,10 @@ class DataFrameLike(Protocol):
self, freq: Any = ..., axis: Any = ..., copy: Any = ...
) -> DataFrameLike: ...
def isin(self, values: Any) -> DataFrameLike: ...
def copy(self) -> DataFrameLike: ...
plot: Any = ...
hist: Any = ...
boxplot: Any = ...
sparse: Any = ...
loc: Any = ...
iloc: Any = ...
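For background on what the protocol additions above buy us: `DataFrameLike` is a structural (`typing.Protocol`) description of the pandas surface PySpark relies on, so a real `pandas.DataFrame` satisfies it without pandas being imported at type-check time. A minimal sketch of the idea, with illustrative names (`FrameLike`, `describe_columns`) that are not part of this PR:

```python
from typing import Any, List, Protocol


class FrameLike(Protocol):
    # Structural subset of a pandas DataFrame, mirroring the attributes
    # declared in the stub above.
    columns: Any
    dtypes: List[Any]

    def copy(self) -> "FrameLike": ...


def describe_columns(frame: FrameLike) -> List[str]:
    # Any object exposing the members above can be passed here,
    # e.g. a real pandas.DataFrame.
    return ["{}: {}".format(name, dtype) for name, dtype in zip(frame.columns, frame.dtypes)]


if __name__ == "__main__":
    import pandas as pd  # only needed to run the example

    pdf = pd.DataFrame({"id": [1, 2], "value": [0.5, 1.5]})
    print(describe_columns(pdf))  # ['id: int64', 'value: float64']
```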
120 changes: 88 additions & 32 deletions python/pyspark/sql/pandas/conversion.py
@@ -17,23 +17,31 @@
import sys
import warnings
from collections import Counter
from typing import List, Optional, Type, Union, no_type_check, overload, TYPE_CHECKING

from pyspark.rdd import _load_from_socket
from pyspark.rdd import _load_from_socket # type: ignore[attr-defined]
from pyspark.sql.pandas.serializers import ArrowCollectSerializer
from pyspark.sql.types import IntegralType
from pyspark.sql.types import ByteType, ShortType, IntegerType, LongType, FloatType, \
DoubleType, BooleanType, MapType, TimestampType, TimestampNTZType, StructType, DataType
from pyspark.sql.utils import is_timestamp_ntz_preferred
from pyspark.traceback_utils import SCCallSiteSync

if TYPE_CHECKING:
import numpy as np
import pyarrow as pa

from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
from pyspark.sql import DataFrame


class PandasConversionMixin(object):
"""
Mix-in for the conversion from Spark to pandas. Currently, only :class:`DataFrame`
can use this class.
"""

def toPandas(self):
def toPandas(self) -> "PandasDataFrameLike":
"""
Returns the contents of this :class:`DataFrame` as Pandas ``pandas.DataFrame``.

@@ -65,9 +73,9 @@ def toPandas(self):
import numpy as np
import pandas as pd

timezone = self.sql_ctx._conf.sessionLocalTimeZone()
timezone = self.sql_ctx._conf.sessionLocalTimeZone() # type: ignore[attr-defined]

if self.sql_ctx._conf.arrowPySparkEnabled():
if self.sql_ctx._conf.arrowPySparkEnabled(): # type: ignore[attr-defined]
use_arrow = True
try:
from pyspark.sql.pandas.types import to_arrow_schema
@@ -77,7 +85,7 @@ def toPandas(self):
to_arrow_schema(self.schema)
except Exception as e:

if self.sql_ctx._conf.arrowPySparkFallbackEnabled():
if self.sql_ctx._conf.arrowPySparkFallbackEnabled(): # type: ignore[attr-defined]
msg = (
"toPandas attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
@@ -106,7 +114,10 @@ def toPandas(self):
import pyarrow
# Rename columns to avoid duplicated column names.
tmp_column_names = ['col_{}'.format(i) for i in range(len(self.columns))]
self_destruct = self.sql_ctx._conf.arrowPySparkSelfDestructEnabled()
self_destruct = (
self.sql_ctx._conf # type: ignore[attr-defined]
.arrowPySparkSelfDestructEnabled()
)
batches = self.toDF(*tmp_column_names)._collect_as_arrow(
split_batches=self_destruct)
if len(batches) > 0:
@@ -158,7 +169,7 @@ def toPandas(self):
pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns)
column_counter = Counter(self.columns)

dtype = [None] * len(self.schema)
dtype = [None] * len(self.schema) # type: List[Optional[Type]]
for fieldIdx, field in enumerate(self.schema):
# For duplicate column name, we use `iloc` to access it.
if column_counter[field.name] > 1:
@@ -179,7 +190,7 @@ def toPandas(self):
if isinstance(field.dataType, IntegralType) and pandas_col.isnull().any():
dtype[fieldIdx] = np.float64
if isinstance(field.dataType, BooleanType) and pandas_col.isnull().any():
dtype[fieldIdx] = np.object
dtype[fieldIdx] = np.object # type: ignore[attr-defined]

df = pd.DataFrame()
for index, t in enumerate(dtype):
@@ -216,7 +227,7 @@ def toPandas(self):
return pdf

@staticmethod
def _to_corrected_pandas_type(dt):
def _to_corrected_pandas_type(dt: DataType) -> Optional[Type]:
"""
When converting Spark SQL records to Pandas :class:`DataFrame`, the inferred data type
may be wrong. This method gets the corrected data type for Pandas if that type may be
@@ -236,15 +247,15 @@ def _to_corrected_pandas_type(dt):
elif type(dt) == DoubleType:
return np.float64
elif type(dt) == BooleanType:
return np.bool
return np.bool # type: ignore[attr-defined]
elif type(dt) == TimestampType:
return np.datetime64
elif type(dt) == TimestampNTZType:
return np.datetime64
else:
return None

def _collect_as_arrow(self, split_batches=False):
def _collect_as_arrow(self, split_batches: bool = False) -> List["pa.RecordBatch"]:
"""
Returns all records as a list of ArrowRecordBatches; pyarrow must be installed
and available on driver and worker Python environments.
@@ -260,7 +271,9 @@ def _collect_as_arrow(self, split_batches=False):
assert isinstance(self, DataFrame)

with SCCallSiteSync(self._sc):
port, auth_secret, jsocket_auth_server = self._jdf.collectAsArrowToPython()
port, auth_secret, jsocket_auth_server = (
self._jdf.collectAsArrowToPython() # type: ignore[operator]
)

# Collect list of un-ordered batches where last element is a list of correct order indices
try:
@@ -301,27 +314,50 @@ class SparkConversionMixin(object):
Mix-in for the conversion from pandas to Spark. Currently, only :class:`SparkSession`
can use this class.
"""
def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=True):

@overload
def createDataFrame(
self, data: "PandasDataFrameLike", samplingRatio: Optional[float] = ...
) -> "DataFrame":
...

@overload
def createDataFrame(
self,
data: "PandasDataFrameLike",
schema: Union[StructType, str],
verifySchema: bool = ...,
) -> "DataFrame":
...

def createDataFrame( # type: ignore[misc]
self,
data: "PandasDataFrameLike",
schema: Optional[Union[StructType, List[str]]] = None,
samplingRatio: Optional[float] = None,
verifySchema: bool = True
) -> "DataFrame":
from pyspark.sql import SparkSession

assert isinstance(self, SparkSession)

from pyspark.sql.pandas.utils import require_minimum_pandas_version
require_minimum_pandas_version()

timezone = self._wrapped._conf.sessionLocalTimeZone()
timezone = self._wrapped._conf.sessionLocalTimeZone() # type: ignore[attr-defined]

# If no schema supplied by user then get the names of columns only
if schema is None:
schema = [str(x) if not isinstance(x, str) else
(x.encode('utf-8') if not isinstance(x, str) else x)
for x in data.columns]
schema = [str(x) if not isinstance(x, str) else x for x in data.columns]

if self._wrapped._conf.arrowPySparkEnabled() and len(data) > 0:
if (
self._wrapped._conf.arrowPySparkEnabled() # type: ignore[attr-defined]
and len(data) > 0
):
try:
return self._create_from_pandas_with_arrow(data, schema, timezone)
except Exception as e:
if self._wrapped._conf.arrowPySparkFallbackEnabled():
if self._wrapped._conf.arrowPySparkFallbackEnabled(): # type: ignore[attr-defined]
msg = (
"createDataFrame attempted Arrow optimization because "
"'spark.sql.execution.arrow.pyspark.enabled' is set to true; however, "
@@ -339,10 +375,17 @@ def createDataFrame(self, data, schema=None, samplingRatio=None, verifySchema=Tr
"has been set to false.\n %s" % str(e))
warnings.warn(msg)
raise
data = self._convert_from_pandas(data, schema, timezone)
return self._create_dataframe(data, schema, samplingRatio, verifySchema)

def _convert_from_pandas(self, pdf, schema, timezone):
converted_data = self._convert_from_pandas(data, schema, timezone)
return self._create_dataframe( # type: ignore[attr-defined]
converted_data, schema, samplingRatio, verifySchema
)

def _convert_from_pandas(
self,
pdf: "PandasDataFrameLike",
schema: Union[StructType, str, List[str]],
timezone: str
) -> List:
"""
Convert a pandas.DataFrame to list of records that can be used to make a DataFrame

@@ -398,7 +441,7 @@ def _convert_from_pandas(self, pdf, schema, timezone):
# Convert list of numpy records to python lists
return [r.tolist() for r in np_records]

def _get_numpy_record_dtype(self, rec):
def _get_numpy_record_dtype(self, rec: "np.recarray") -> Optional["np.dtype"]:
"""
Used when converting a pandas.DataFrame to Spark using to_records(), this will correct
the dtypes of fields in a record so they can be properly loaded into Spark.
@@ -429,7 +472,12 @@ def _get_numpy_record_dtype(self, rec):
record_type_list.append((str(col_names[i]), curr_type))
return np.dtype(record_type_list) if has_rec_fix else None

def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
def _create_from_pandas_with_arrow(
self,
pdf: "PandasDataFrameLike",
schema: Union[StructType, List[str]],
timezone: str
) -> "DataFrame":
"""
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
to Arrow data, then sending to the JVM to parallelize. If a schema is passed in, the
@@ -483,27 +531,35 @@ def _create_from_pandas_with_arrow(self, pdf, schema, timezone):
arrow_data = [[(c, t) for (_, c), t in zip(pdf_slice.iteritems(), arrow_types)]
for pdf_slice in pdf_slices]

jsqlContext = self._wrapped._jsqlContext
jsqlContext = self._wrapped._jsqlContext # type: ignore[attr-defined]

safecheck = self._wrapped._conf.arrowSafeTypeConversion()
safecheck = self._wrapped._conf.arrowSafeTypeConversion() # type: ignore[attr-defined]
col_by_name = True # col by name only applies to StructType columns, can't happen here
ser = ArrowStreamPandasSerializer(timezone, safecheck, col_by_name)

@no_type_check
def reader_func(temp_filename):
return self._jvm.PythonSQLUtils.readArrowStreamFromFile(jsqlContext, temp_filename)

@no_type_check
def create_RDD_server():
return self._jvm.ArrowRDDServer(jsqlContext)

# Create Spark DataFrame from Arrow stream file, using one batch per partition
jrdd = self._sc._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
jdf = self._jvm.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
df = DataFrame(jdf, self._wrapped)
df._schema = schema
jrdd = (
self._sc # type: ignore[attr-defined]
._serialize_to_jvm(arrow_data, ser, reader_func, create_RDD_server)
)
jdf = (
self._jvm # type: ignore[attr-defined]
.PythonSQLUtils.toDataFrame(jrdd, schema.json(), jsqlContext)
)
df = DataFrame(jdf, self._wrapped) # type: ignore[attr-defined]
df._schema = schema # type: ignore[attr-defined]
return df


def _test():
def _test() -> None:
import doctest
from pyspark.sql import SparkSession
import pyspark.sql.pandas.conversion
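As a usage reference for the signatures annotated above (not part of the diff): `SparkSession.createDataFrame` accepts a pandas DataFrame through either overload, and `DataFrame.toPandas` now advertises a pandas-like return type. A minimal sketch, assuming a local Spark installation with pandas and PyArrow available:

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()

pdf = pd.DataFrame({"id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

# First overload: createDataFrame(data, samplingRatio=...)
sdf = spark.createDataFrame(pdf)

# Second overload: createDataFrame(data, schema, verifySchema=...)
sdf_typed = spark.createDataFrame(pdf, "id long, value double")

# toPandas() is annotated to return a pandas-like DataFrame.
result = sdf_typed.toPandas()
print(result.dtypes)

spark.stop()
```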
59 changes: 0 additions & 59 deletions python/pyspark/sql/pandas/conversion.pyi

This file was deleted.
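The deleted stub is superseded by the inline annotations added to `conversion.py` above; type checkers now read the signatures from the implementation itself, so there is no second file to keep in sync. A schematic illustration of the two styles, using a hypothetical module rather than the deleted file's actual contents:

```python
# Stub style (separate mymodule.pyi): declarations only, no bodies, e.g.
#     def to_records(frame: DataFrameLike) -> list: ...
#
# Inline style (mymodule.py): the same annotation sits on the implementation.
from typing import Any, List, Tuple


def to_records(frame: Any) -> List[Tuple]:
    """Toy example: turn a pandas-like frame into a list of row tuples."""
    return [tuple(row) for row in frame.itertuples(index=False)]
```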
