Skip to content

Commit

Permalink
Removed the wrapping function and used an option conditional on the Python version in…
Browse files Browse the repository at this point in the history
…stead
  • Loading branch information
BryanCutler committed Apr 9, 2019
1 parent 26eaa1a commit 8dd7202
Showing 1 changed file with 27 additions and 28 deletions.
55 changes: 27 additions & 28 deletions python/pyspark/sql/tests/test_pandas_udf_grouped_map.py
Expand Up @@ -17,6 +17,7 @@

import datetime
import unittest
import sys

from collections import OrderedDict
from decimal import Decimal
Expand All @@ -31,22 +32,20 @@

if have_pandas:
import pandas as pd
from pandas.util.testing import assert_frame_equal as pd_assert_frame_equal
from pandas.util.testing import assert_frame_equal

if have_pyarrow:
import pyarrow as pa


def assert_frame_equal(left, right):
    """
    Wrapper around the Pandas frame-comparison helper.

    pd.DataFrame.assign infers mixed (unicode/str) types for column names
    on Python 2, so the column-type check is relaxed there.
    """
    import sys
    extra = {'check_column_type': False} if sys.version < '3' else {}
    pd_assert_frame_equal(left, right, **extra)
"""
Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
"""
if sys.version < '3':
_check_column_type = False
else:
_check_column_type = True


@unittest.skipIf(
Expand Down Expand Up @@ -145,9 +144,9 @@ def test_supported_types(self):
result3 = df.groupby('id').apply(udf3).sort('id').toPandas()
expected3 = expected1

assert_frame_equal(expected1, result1)
assert_frame_equal(expected2, result2)
assert_frame_equal(expected3, result3)
assert_frame_equal(expected1, result1, check_column_type=_check_column_type)
assert_frame_equal(expected2, result2, check_column_type=_check_column_type)
assert_frame_equal(expected3, result3, check_column_type=_check_column_type)

def test_array_type_correct(self):
df = self.data.withColumn("arr", array(col("id"))).repartition(1, "id")
Expand All @@ -165,7 +164,7 @@ def test_array_type_correct(self):

result = df.groupby('id').apply(udf).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(udf.func).reset_index(drop=True)
assert_frame_equal(expected, result)
assert_frame_equal(expected, result, check_column_type=_check_column_type)

def test_register_grouped_map_udf(self):
foo_udf = pandas_udf(lambda x: x, "id long", PandasUDFType.GROUPED_MAP)
Expand All @@ -187,7 +186,7 @@ def foo(pdf):

result = df.groupby('id').apply(foo).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
assert_frame_equal(expected, result)
assert_frame_equal(expected, result, check_column_type=_check_column_type)

def test_coerce(self):
df = self.data
Expand All @@ -201,7 +200,7 @@ def test_coerce(self):
result = df.groupby('id').apply(foo).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True)
expected = expected.assign(v=expected.v.astype('float64'))
assert_frame_equal(expected, result)
assert_frame_equal(expected, result, check_column_type=_check_column_type)

def test_complex_groupby(self):
df = self.data
Expand All @@ -219,7 +218,7 @@ def normalize(pdf):
expected = pdf.groupby(pdf['id'] % 2 == 0, as_index=False).apply(normalize.func)
expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
expected = expected.assign(norm=expected.norm.astype('float64'))
assert_frame_equal(expected, result)
assert_frame_equal(expected, result, check_column_type=_check_column_type)

def test_empty_groupby(self):
df = self.data
Expand All @@ -237,7 +236,7 @@ def normalize(pdf):
expected = normalize.func(pdf)
expected = expected.sort_values(['id', 'v']).reset_index(drop=True)
expected = expected.assign(norm=expected.norm.astype('float64'))
assert_frame_equal(expected, result)
assert_frame_equal(expected, result, check_column_type=_check_column_type)

def test_datatype_string(self):
df = self.data
Expand All @@ -250,7 +249,7 @@ def test_datatype_string(self):

result = df.groupby('id').apply(foo_udf).sort('id').toPandas()
expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True)
assert_frame_equal(expected, result)
assert_frame_equal(expected, result, check_column_type=_check_column_type)

def test_wrong_return_type(self):
with QuietTest(self.sc):
Expand Down Expand Up @@ -311,7 +310,7 @@ def test_timestamp_dst(self):
df = self.spark.createDataFrame(dt, 'timestamp').toDF('time')
foo_udf = pandas_udf(lambda pdf: pdf, 'time timestamp', PandasUDFType.GROUPED_MAP)
result = df.groupby('time').apply(foo_udf).sort('time')
assert_frame_equal(df.toPandas(), result.toPandas())
assert_frame_equal(df.toPandas(), result.toPandas(), check_column_type=_check_column_type)

def test_udf_with_key(self):
import numpy as np
Expand Down Expand Up @@ -365,26 +364,26 @@ def foo3(key, pdf):
expected1 = pdf.groupby('id', as_index=False)\
.apply(lambda x: udf1.func((x.id.iloc[0],), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected1, result1)
assert_frame_equal(expected1, result1, check_column_type=_check_column_type)

# Test groupby expression
result2 = df.groupby(df.id % 2).apply(udf1).sort('id', 'v').toPandas()
expected2 = pdf.groupby(pdf.id % 2, as_index=False)\
.apply(lambda x: udf1.func((x.id.iloc[0] % 2,), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected2, result2)
assert_frame_equal(expected2, result2, check_column_type=_check_column_type)

# Test complex groupby
result3 = df.groupby(df.id, df.v % 2).apply(udf2).sort('id', 'v').toPandas()
expected3 = pdf.groupby([pdf.id, pdf.v % 2], as_index=False)\
.apply(lambda x: udf2.func((x.id.iloc[0], (x.v % 2).iloc[0],), x))\
.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected3, result3)
assert_frame_equal(expected3, result3, check_column_type=_check_column_type)

# Test empty groupby
result4 = df.groupby().apply(udf3).sort('id', 'v').toPandas()
expected4 = udf3.func((), pdf)
assert_frame_equal(expected4, result4)
assert_frame_equal(expected4, result4, check_column_type=_check_column_type)

def test_column_order(self):

Expand Down Expand Up @@ -417,7 +416,7 @@ def change_col_order(pdf):
.select('id', 'u', 'v').toPandas()
pd_result = grouped_pdf.apply(change_col_order)
expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected, result)
assert_frame_equal(expected, result, check_column_type=_check_column_type)

# Function returns a pdf with positional columns, indexed by range
def range_col_order(pdf):
Expand All @@ -436,7 +435,7 @@ def range_col_order(pdf):
pd_result = grouped_pdf.apply(range_col_order)
rename_pdf(pd_result, ['id', 'u', 'v'])
expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected, result)
assert_frame_equal(expected, result, check_column_type=_check_column_type)

# Function returns a pdf with columns indexed with integers
def int_index(pdf):
Expand All @@ -454,7 +453,7 @@ def int_index(pdf):
pd_result = grouped_pdf.apply(int_index)
rename_pdf(pd_result, ['id', 'u', 'v'])
expected = pd_result.sort_values(['id', 'v']).reset_index(drop=True)
assert_frame_equal(expected, result)
assert_frame_equal(expected, result, check_column_type=_check_column_type)

@pandas_udf('id long, v int', PandasUDFType.GROUPED_MAP)
def column_name_typo(pdf):
Expand Down

0 comments on commit 8dd7202

Please sign in to comment.