diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7f05d48ade2b3..a9c706057af27 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -679,6 +679,12 @@ def test_udf(self): [row] = self.sqlCtx.sql("SELECT twoArgs('test', 1)").collect() self.assertEqual(row[0], 5) + def test_udf2(self): + self.sqlCtx.registerFunction("strlen", lambda string: len(string)) + self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test") + [res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() + self.assertEqual(4, res[0]) + def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar)