From cf0076b9257d65605ed3153f0b59cd89cdb145fc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 5 Jan 2018 07:57:51 +0800 Subject: [PATCH 1/3] fix --- python/pyspark/sql/context.py | 17 ++++- python/pyspark/sql/tests.py | 6 ++ .../apache/spark/sql/UDFRegistration.scala | 73 ++++++++++++------- .../org/apache/spark/sql/JavaRandUDF.java | 30 ++++++++ .../org/apache/spark/sql/JavaUDFSuite.java | 8 ++ 5 files changed, 104 insertions(+), 30 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaRandUDF.java diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b8d86cc098e94..0cffc33e59c85 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -218,7 +218,7 @@ def registerFunction(self, name, f, returnType=StringType()): @ignore_unicode_prefix @since(2.1) - def registerJavaFunction(self, name, javaClassName, returnType=None): + def registerJavaFunction(self, name, javaClassName, returnType=None, deterministic=True): """Register a java UDF so it can be used in SQL statements. In addition to a name and the function itself, the return type can be optionally specified. @@ -226,6 +226,8 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): :param name: name of the UDF :param javaClassName: fully qualified name of java class :param returnType: a :class:`pyspark.sql.types.DataType` object + :param deterministic: a flag indicating if the UDF is deterministic. Deterministic UDF + returns same result each time it is invoked with a particular input. >>> sqlContext.registerJavaFunction("javaStringLength", ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) @@ -236,11 +238,18 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() [Row(UDF:javaStringLength2(test)=4)] + >>> from pyspark.sql.types import DoubleType + >>> sqlContext.registerJavaFunction("javaRand", + ... 
"test.org.apache.spark.sql.JavaRandUDF", DoubleType(), deterministic=False) + >>> sqlContext.sql("SELECT javaRand(3)").collect() # doctest: +SKIP + [Row(UDF:javaRand(3)=3.12345)] + """ jdt = None if returnType is not None: jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) - self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + self.sparkSession._jsparkSession.udf().registerJava( + name, javaClassName, jdt, deterministic) @ignore_unicode_prefix @since(2.3) @@ -578,8 +587,8 @@ def __init__(self, sqlContext): def register(self, name, f, returnType=StringType()): return self.sqlContext.registerFunction(name, f, returnType) - def registerJavaFunction(self, name, javaClassName, returnType=None): - self.sqlContext.registerJavaFunction(name, javaClassName, returnType) + def registerJavaFunction(self, name, javaClassName, returnType=None, deterministic=True): + self.sqlContext.registerJavaFunction(name, javaClassName, returnType, deterministic) def registerJavaUDAF(self, name, javaClassName): self.sqlContext.registerJavaUDAF(name, javaClassName) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6dc767f9ec46e..96d8af68df0e4 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -549,6 +549,12 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + def test_java_udf(self): + self.spark.udf.registerJavaFunction("javaRand", "test.org.apache.spark.sql.JavaRandUDF", + DoubleType(), deterministic=False) + row = self.spark.sql("SELECT javaRand(3)").collect() + self.assertTrue(row[0] >= 3.0) + def test_non_existed_udf(self): spark = self.spark self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index f94baef39dfad..64abc4604e4dd 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -601,7 +601,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Register a Java UDF class using reflection, for use from pyspark + * Register a deterministic Java UDF class using reflection, for use from pyspark * * @param name udf name * @param className fully qualified class name of udf @@ -609,7 +609,24 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * via reflection. */ private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = { + registerJava(name, className, returnDataType, deterministic = true) + } + /** + * Register a Java UDF class using reflection, for use from pyspark + * + * @param name udf name + * @param className fully qualified class name of udf + * @param returnDataType return type of udf. If it is null, spark would try to infer + * via reflection. + * @param deterministic True if the UDF is deterministic. Deterministic UDF returns same result + * each time it is invoked with a particular input. 
+ */ + private[sql] def registerJava( + name: String, + className: String, + returnDataType: DataType, + deterministic: Boolean): Unit = { try { val clazz = Utils.classForName(className) val udfInterfaces = clazz.getGenericInterfaces @@ -622,40 +639,44 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class $className") } else { try { - val udf = clazz.newInstance() + val javaUDFClass = clazz.newInstance() val udfReturnType = udfInterfaces(0).getActualTypeArguments.last var returnType = returnDataType if (returnType == null) { returnType = JavaTypeInference.inferDataType(udfReturnType)._1 } - udfInterfaces(0).getActualTypeArguments.length match { - case 1 => register(name, udf.asInstanceOf[UDF0[_]], returnType) - case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType) - case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType) - case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType) - case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType) - case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType) - case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType) - case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType) - case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType) - case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType) - case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType) - case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case 15 => 
register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + import org.apache.spark.sql.functions.udf + + val javaUDF = udfInterfaces(0).getActualTypeArguments.length match { + case 1 => udf(javaUDFClass.asInstanceOf[UDF0[_]], returnType) + case 2 => udf(javaUDFClass.asInstanceOf[UDF1[_, _]], returnType) + case 3 => udf(javaUDFClass.asInstanceOf[UDF2[_, _, _]], returnType) + case 4 => udf(javaUDFClass.asInstanceOf[UDF3[_, _, _, _]], returnType) + case 5 => udf(javaUDFClass.asInstanceOf[UDF4[_, _, _, _, _]], returnType) + case 6 => udf(javaUDFClass.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType) + case 7 => udf(javaUDFClass.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType) + case 8 => udf(javaUDFClass.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType) + case 9 => udf(javaUDFClass.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType) + case 10 => udf(javaUDFClass.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType) + case 11 => 
udf(javaUDFClass.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType) + case 12 => udf(javaUDFClass.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 13 => udf(javaUDFClass.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 14 => udf(javaUDFClass.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 15 => udf(javaUDFClass.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 16 => udf(javaUDFClass.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 17 => udf(javaUDFClass.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 18 => udf(javaUDFClass.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 19 => udf(javaUDFClass.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 20 => udf(javaUDFClass.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 21 => udf(javaUDFClass.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 22 => udf(javaUDFClass.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 23 => udf(javaUDFClass.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case n => throw new AnalysisException(s"UDF class with $n type arguments is not supported.") } + val javaUDFWithDeterminism = if (deterministic) javaUDF else javaUDF.asNondeterministic() + register(name, javaUDFWithDeterminism.withName(name)) } catch { case e @ (_: InstantiationException | _: IllegalArgumentException) => throw new AnalysisException(s"Can not instantiate class $className, please make sure it has public non argument constructor") diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRandUDF.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRandUDF.java new file mode 100644 index 0000000000000..5bf4d9f19700a --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRandUDF.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql; + +import org.apache.spark.sql.api.java.UDF1; + +/** + * It is used to register a Java UDF from PySpark. + */ +public class JavaRandUDF implements UDF1<Integer, Double> { + @Override + public Double call(Integer i) { + return i + Math.random(); + } +} \ No newline at end of file diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 5bf1888826186..a6fd4823efc77 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -121,4 +121,12 @@ public void udf6Test() { Row result = spark.sql("SELECT returnOne()").head(); Assert.assertEquals(1, result.getInt(0)); } + + @SuppressWarnings("unchecked") + @Test + public void udf7Test() { + spark.udf().registerJava("randUDF", JavaRandUDF.class.getName(), DataTypes.DoubleType, false); + Row result = spark.sql("SELECT randUDF(1)").head(); + Assert.assertTrue(result.getDouble(0) >= 0.0); + } } From d63e816672da36266b81989ba6c07bb398aaaf4e Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 5 Jan 2018 08:03:42 +0800 Subject: [PATCH 2/3] style --- .../src/main/scala/org/apache/spark/sql/UDFRegistration.scala | 1 - .../src/test/java/test/org/apache/spark/sql/JavaRandUDF.java | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 64abc4604e4dd..83d7d882d8071 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -647,7 +647,6 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } import org.apache.spark.sql.functions.udf - val javaUDF = udfInterfaces(0).getActualTypeArguments.length match { case 1 =>
udf(javaUDFClass.asInstanceOf[UDF0[_]], returnType) case 2 => udf(javaUDFClass.asInstanceOf[UDF1[_, _]], returnType) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRandUDF.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRandUDF.java index 5bf4d9f19700a..df806bc18288d 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaRandUDF.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRandUDF.java @@ -27,4 +27,4 @@ public class JavaRandUDF implements UDF1 { public Double call(Integer i) { return i + Math.random(); } -} \ No newline at end of file +} From 7e4f3c0f0c4082bf166cec0e72f9f86f5d23aac8 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 5 Jan 2018 08:42:48 +0800 Subject: [PATCH 3/3] fix --- python/pyspark/sql/context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 0cffc33e59c85..9ef157285913b 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -223,10 +223,11 @@ def registerJavaFunction(self, name, javaClassName, returnType=None, determinist In addition to a name and the function itself, the return type can be optionally specified. When the return type is not specified we would infer it via reflection. + :param name: name of the UDF :param javaClassName: fully qualified name of java class :param returnType: a :class:`pyspark.sql.types.DataType` object - :param deterministic: a flag indicating if the UDF is deterministic. Deterministic UDF + :param deterministic: a flag indicating if the UDF is deterministic. Deterministic UDF returns same result each time it is invoked with a particular input. >>> sqlContext.registerJavaFunction("javaStringLength",