[SPARK-25351][SQL][PYTHON] Handle Pandas category type when converting from Python with Arrow

### What changes were proposed in this pull request?
Handle the Pandas category type when converting from Python with Arrow enabled. A category column is converted to the type of its category elements, matching the behavior with Arrow disabled.
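
For illustration only (not part of the commit message), a minimal sketch of the resulting behavior, assuming a Spark 3.0 session with pandas and pyarrow available:

```python
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
pdf["B"] = pdf["A"].astype("category")  # column B has pandas dtype 'category'

# The categorical column is converted to the dtype of its categories
# (string here), the same result as with Arrow disabled.
df = spark.createDataFrame(pdf)
print(df.dtypes)  # expected: [('A', 'string'), ('B', 'string')]
```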

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
New unit tests were added for `createDataFrame` and scalar `pandas_udf`

Closes #26585 from jalpan-randeri/feature-pyarrow-dictionary-type.

Authored-by: Jalpan Randeri <randerij@amazon.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
Jalpan Randeri authored and BryanCutler committed May 28, 2020
1 parent efe7fd2 commit 339b0ec
Showing 4 changed files with 52 additions and 0 deletions.
3 changes: 3 additions & 0 deletions python/pyspark/sql/pandas/serializers.py
@@ -154,6 +154,9 @@ def create_array(s, t):
            # Ensure timestamp series are in expected form for Spark internal representation
            if t is not None and pa.types.is_timestamp(t):
                s = _check_series_convert_timestamps_internal(s, self._timezone)
            elif type(s.dtype) == pd.CategoricalDtype:
                # Note: This can be removed once minimum pyarrow version is >= 0.16.1
                s = s.astype(s.dtypes.categories.dtype)
            try:
                array = pa.Array.from_pandas(s, mask=mask, type=t, safe=self._safecheck)
            except pa.ArrowException as e:
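
As an aside (illustrative, not part of the diff), a minimal sketch of what this cast avoids on older pyarrow versions:

```python
import pandas as pd
import pyarrow as pa

s = pd.Series([u"a", u"b", u"a"], dtype="category")

# Without the cast, Arrow infers a dictionary type for a categorical series,
# which the conversion path here did not previously handle.
print(pa.Array.from_pandas(s).type)  # e.g. dictionary<values=string, indices=int8>

# Casting to the categories' dtype yields a plain array of the element type.
plain = s.astype(s.dtype.categories.dtype)
print(pa.Array.from_pandas(plain).type)  # string
```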
2 changes: 2 additions & 0 deletions python/pyspark/sql/pandas/types.py
@@ -114,6 +114,8 @@ def from_arrow_type(at):
        return StructType(
            [StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
             for field in at])
    elif types.is_dictionary(at):
        spark_type = from_arrow_type(at.value_type)
    else:
        raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
    return spark_type
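
A short illustration of the new branch (hypothetical usage of this internal helper, assuming pyspark 3.0's `pyspark.sql.pandas.types` module):

```python
import pyarrow as pa
from pyspark.sql.pandas.types import from_arrow_type

# An Arrow dictionary type now maps to the Spark type of its values.
dict_type = pa.dictionary(pa.int8(), pa.string())
print(from_arrow_type(dict_type))  # StringType
```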
26 changes: 26 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -415,6 +415,32 @@ def run_test(num_records, num_parts, max_records, use_delay=False):
        for case in cases:
            run_test(*case)

    def test_createDataFrame_with_category_type(self):
        pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
        pdf["B"] = pdf["A"].astype('category')
        category_first_element = dict(enumerate(pdf['B'].cat.categories))[0]

        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": True}):
            arrow_df = self.spark.createDataFrame(pdf)
            arrow_type = arrow_df.dtypes[1][1]
            result_arrow = arrow_df.toPandas()
            arrow_first_category_element = result_arrow["B"][0]

        with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
            df = self.spark.createDataFrame(pdf)
            spark_type = df.dtypes[1][1]
            result_spark = df.toPandas()
            spark_first_category_element = result_spark["B"][0]

        assert_frame_equal(result_spark, result_arrow)

        # Ensure the original category elements are strings
        assert isinstance(category_first_element, str)
        # The Spark and Arrow-enabled DataFrame types must both match pandas
        assert spark_type == arrow_type == 'string'
        assert isinstance(arrow_first_category_element, str)
        assert isinstance(spark_first_category_element, str)


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
21 changes: 21 additions & 0 deletions python/pyspark/sql/tests/test_pandas_udf_scalar.py
@@ -897,6 +897,27 @@ def test_timestamp_dst(self):
        result = df.withColumn('time', foo_udf(df.time))
        self.assertEquals(df.collect(), result.collect())

    def test_udf_category_type(self):

        @pandas_udf('string')
        def to_category_func(x):
            return x.astype('category')

        pdf = pd.DataFrame({"A": [u"a", u"b", u"c", u"a"]})
        df = self.spark.createDataFrame(pdf)
        df = df.withColumn("B", to_category_func(df['A']))
        result_spark = df.toPandas()

        spark_type = df.dtypes[1][1]
        # The Spark DataFrame type must match pandas with Arrow execution enabled
        assert spark_type == 'string'

        # Each result value in column 'B' must equal the value in column 'A'
        for i in range(0, len(result_spark["A"])):
            assert result_spark["A"][i] == result_spark["B"][i]
            assert isinstance(result_spark["A"][i], str)
            assert isinstance(result_spark["B"][i], str)

    @unittest.skipIf(sys.version_info[:2] < (3, 5), "Type hints are supported from Python 3.5.")
    def test_type_annotation(self):
        # Regression test to check if type hints can be used. See SPARK-23569.
