Skip to content

Commit

Permalink
Modify tests for unsupported types.
Browse files Browse the repository at this point in the history
  • Loading branch information
ueshin committed Dec 22, 2017
1 parent bcadac8 commit e29f833
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3194,10 +3194,11 @@ def create_pandas_data_frame(self):
return pd.DataFrame(data=data_dict)

def test_unsupported_datatype(self):
schema = StructType([StructField("decimal", DecimalType(), True)])
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
with QuietTest(self.sc):
self.assertRaises(Exception, lambda: df.toPandas())
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.toPandas()

def test_null_conversion(self):
df_null = self.spark.createDataFrame([tuple([None for _ in range(len(self.data[0]))])] +
Expand Down Expand Up @@ -3733,12 +3734,12 @@ def test_vectorized_udf_varargs(self):

def test_vectorized_udf_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col
schema = StructType([StructField("dt", DecimalType(), True)])
schema = StructType([StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(None,)], schema=schema)
f = pandas_udf(lambda x: x, DecimalType())
f = pandas_udf(lambda x: x, MapType(StringType(), IntegerType()))
with QuietTest(self.sc):
with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
df.select(f(col('dt'))).collect()
df.select(f(col('map'))).collect()

def test_vectorized_udf_null_date(self):
from pyspark.sql.functions import pandas_udf, col
Expand Down Expand Up @@ -4032,7 +4033,8 @@ def test_wrong_args(self):
def test_unsupported_types(self):
from pyspark.sql.functions import pandas_udf, col, PandasUDFType
schema = StructType(
[StructField("id", LongType(), True), StructField("dt", DecimalType(), True)])
[StructField("id", LongType(), True),
StructField("map", MapType(StringType(), IntegerType()), True)])
df = self.spark.createDataFrame([(1, None,)], schema=schema)
f = pandas_udf(lambda x: x, df.schema, PandasUDFType.GROUP_MAP)
with QuietTest(self.sc):
Expand Down

0 comments on commit e29f833

Please sign in to comment.