From 48bfb57f1254a95735929fc90baac8e08d4e77d9 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Tue, 17 Nov 2015 13:54:01 +0800 Subject: [PATCH 1/5] [SPARK-11775][PYSPARK][SQL] Allow PySpark to register Java UDF --- python/pyspark/sql/context.py | 5 ++ .../apache/spark/sql/UDFRegistration.scala | 64 +++++++++++++++++++ 2 files changed, 69 insertions(+) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 8264dcf8a97d2..e48cbc5d06896 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -202,6 +202,11 @@ def registerFunction(self, name, f, returnType=StringType()): """ self.sparkSession.catalog.registerFunction(name, f, returnType) + def registerJavaFunction(self, name, javaClassName, returnType): + jdt = self._ssql_ctx.parseDataType(returnType.json()) + self._ssql_ctx.udf().registerJava(name, javaClassName, jdt) + + # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 617a14793697b..838e56168505c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,11 +17,17 @@ package org.apache.spark.sql + +import java.io.IOException +import java.util.{List => JList, Map => JMap} + import scala.reflect.runtime.universe.TypeTag import scala.util.Try import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging +import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl + import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.ScalaReflection @@ -30,6 +36,7 @@ import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} import org.apache.spark.sql.types.DataType +import org.apache.spark.util.Utils /** * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this. @@ -413,6 +420,63 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// + /** + * Register a Java UDF class + * @param name + * @param className + * @param returnType + */ + def registerJava(name: String, className: String, returnType: DataType): Unit = { + + try { + // scalastyle:off classforname + val clazz = Class.forName(className, false, Utils.getContextOrSparkClassLoader) + // scalastyle:on classforname + val udfInterfaces = clazz.getGenericInterfaces.filter(_.isInstanceOf[ParameterizedTypeImpl]).map(_.asInstanceOf[ParameterizedTypeImpl]) + .filter(_.getRawType.getName.startsWith("org.apache.spark.sql.api.java.UDF")) + if (udfInterfaces.length == 0) { + throw new IOException(s"UDF class ${className} doesn't implement any UDF interface") + } else if (udfInterfaces.length > 1) { + throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") + } else { + try { + val udf = clazz.newInstance() + udfInterfaces(0).getActualTypeArguments.length match { + case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType) + case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType) + case 4 => register(name, udf.asInstanceOf[UDF3[_, _, _, _]], returnType) + case 5 => register(name, udf.asInstanceOf[UDF4[_, _, _, _, _]], returnType) + case 6 => register(name, udf.asInstanceOf[UDF5[_, _, _, _, _, _]], returnType) + case 7 => register(name, udf.asInstanceOf[UDF6[_, _, _, _, _, _, _]], returnType) + case 8 => register(name, udf.asInstanceOf[UDF7[_, _, _, _, _, _, _, _]], returnType) + case 9 => register(name, udf.asInstanceOf[UDF8[_, _, _, _, _, _, _, _, _]], returnType) + case 10 => register(name, udf.asInstanceOf[UDF9[_, _, _, _, _, _, _, _, _, _]], returnType) + case 11 => register(name, udf.asInstanceOf[UDF10[_, _, _, _, _, _, _, _, _, _, _]], returnType) + case 12 => register(name, udf.asInstanceOf[UDF11[_, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 13 => register(name, udf.asInstanceOf[UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 14 => register(name, udf.asInstanceOf[UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 15 => register(name, udf.asInstanceOf[UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 16 => register(name, udf.asInstanceOf[UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 17 => register(name, udf.asInstanceOf[UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 18 => register(name, udf.asInstanceOf[UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 19 => register(name, udf.asInstanceOf[UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 20 => register(name, udf.asInstanceOf[UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) + case n => logError(s"UDF class with ${n} type arguments is not supported ") + } + } catch { + case e @ (_: InstantiationException | _: IllegalArgumentException) => + logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + } + } + } catch { + case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath") + } + + } + /** * Register a user-defined function with 1 arguments. * @since 1.3.0 From f2c9bd8bb5a0ebfa43c728c47239d0aa8f8f4c88 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 12 May 2016 20:33:51 +0800 Subject: [PATCH 2/5] add unit test --- python/pyspark/sql/context.py | 1 - .../org/apache/spark/sql/UDFRegistration.scala | 1 + .../test/org/apache/spark/sql/JavaUDFSuite.java | 16 ++++++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index e48cbc5d06896..3e1b57979df78 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -206,7 +206,6 @@ def registerJavaFunction(self, name, javaClassName, returnType): jdt = self._ssql_ctx.parseDataType(returnType.json()) self._ssql_ctx.udf().registerJava(name, javaClassName, jdt) - # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): """ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 838e56168505c..db412e51d3c26 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl +import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.ScalaReflection diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 2274912521a56..9c646a1fc929e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -87,4 +87,20 @@ public Integer call(String str1, String str2) { Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); } + + public static class StringLengthTest implements UDF2 { + @Override + public Integer call(String str1, String str2) throws Exception { + return new Integer(str1.length() + str2.length()); + } + } + + @SuppressWarnings("unchecked") + @Test + public void udf3Test() { + spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(), + DataTypes.IntegerType); + Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); + Assert.assertEquals(9, result.getInt(0)); + } } From e452050b047381a4ecd1a073c807d1f58548eb0b Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Sun, 9 Oct 2016 16:01:40 +0800 Subject: [PATCH 3/5] add more test and doc --- python/pyspark/sql/context.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 3e1b57979df78..0bcfce708ae93 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -28,7 +28,7 @@ from pyspark.sql.dataframe import DataFrame from pyspark.sql.readwriter import DataFrameReader from pyspark.sql.streaming import DataStreamReader -from pyspark.sql.types import Row, StringType +from pyspark.sql.types import IntegerType, Row, StringType from pyspark.sql.utils import install_exception_handler __all__ = ["SQLContext", "HiveContext", "UDFRegistration"] @@ -202,9 +202,25 @@ def registerFunction(self, name, f, returnType=StringType()): """ self.sparkSession.catalog.registerFunction(name, f, returnType) - def registerJavaFunction(self, name, javaClassName, returnType): - jdt = self._ssql_ctx.parseDataType(returnType.json()) - self._ssql_ctx.udf().registerJava(name, javaClassName, jdt) + @ignore_unicode_prefix + @since(2.1) + def registerJavaFunction(self, name, javaClassName, returnType=StringType()): + """Register a java UDF so it can be used in SQL statements. + + In addition to a name and the function itself, the return type can be optionally specified. + When the return type is not given it default to a string and conversion will automatically + be done. For any other return type, the produced object must match the specified type. + :param name: name of the UDF + :param javaClassName: fully qualified name of java class + :param returnType: a :class:`pyspark.sql.types.DataType` object + + >>> sqlContext.registerJavaFunction("stringLengthString", + ... "test.org.apache.spark.sql.StringLengthTest", IntegerType()) + >>> sqlContext.sql("SELECT stringLengthString('test')").collect() + [Row(stringLengthString(test)=u'4')] + """ + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): From d4818217dc6e29a72a4e470dbe08cda197933162 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 12 Oct 2016 11:21:27 +0800 Subject: [PATCH 4/5] address comments --- python/pyspark/sql/context.py | 22 ++++++--- .../spark/sql/test/JavaStringLength.java | 30 ++++++++++++ .../apache/spark/sql/UDFRegistration.scala | 49 +++++++++++++------ .../org/apache/spark/sql/JavaUDFSuite.java | 7 ++- 4 files changed, 83 insertions(+), 25 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/test/JavaStringLength.java diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 0bcfce708ae93..2232c697e6434 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -204,22 +204,28 @@ def registerFunction(self, name, f, returnType=StringType()): @ignore_unicode_prefix @since(2.1) - def registerJavaFunction(self, name, javaClassName, returnType=StringType()): + def registerJavaFunction(self, name, javaClassName, returnType=None): """Register a java UDF so it can be used in SQL statements. In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it default to a string and conversion will automatically - be done. For any other return type, the produced object must match the specified type. + When the return type is not given it would infer the returnType via reflection. :param name: name of the UDF :param javaClassName: fully qualified name of java class :param returnType: a :class:`pyspark.sql.types.DataType` object - >>> sqlContext.registerJavaFunction("stringLengthString", - ... "test.org.apache.spark.sql.StringLengthTest", IntegerType()) - >>> sqlContext.sql("SELECT stringLengthString('test')").collect() - [Row(stringLengthString(test)=u'4')] + >>> sqlContext.registerJavaFunction("javaStringLength", + ... "org.apache.spark.sql.test.JavaStringLength", IntegerType()) + >>> sqlContext.sql("SELECT javaStringLength('test')").collect() + [Row(UDF(test)=4)] + >>> sqlContext.registerJavaFunction("javaStringLength2", + ... "org.apache.spark.sql.test.JavaStringLength") + >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() + [Row(UDF(test)=4)] + """ - jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) + jdt = None + if returnType is not None: + jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) # TODO(andrew): delete this once we refactor things to take in SparkSession diff --git a/sql/core/src/main/java/org/apache/spark/sql/test/JavaStringLength.java b/sql/core/src/main/java/org/apache/spark/sql/test/JavaStringLength.java new file mode 100644 index 0000000000000..8938d7a1e4c55 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/test/JavaStringLength.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test; + +import org.apache.spark.sql.api.java.UDF1; + +/** + * It is used for register Java UDF from PySpark + */ +public class JavaStringLength implements UDF1 { + @Override + public Integer call(String str) throws Exception { + return new Integer(str.length()); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index db412e51d3c26..8031989892523 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,17 +17,13 @@ package org.apache.spark.sql - import java.io.IOException -import java.util.{List => JList, Map => JMap} +import java.lang.reflect.{ParameterizedType, Type} import scala.reflect.runtime.universe.TypeTag import scala.util.Try import org.apache.spark.annotation.InterfaceStability -import org.apache.spark.internal.Logging -import sun.reflect.generics.reflectiveObjects.ParameterizedTypeImpl - import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry @@ -36,7 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.expressions.{UserDefinedAggregateFunction, UserDefinedFunction} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{DataType, DataTypes} import org.apache.spark.util.Utils /** @@ -422,19 +418,21 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Register a Java UDF class - * @param name - * @param className - * @param returnType + * Register a Java UDF class using reflection, for use from pyspark + * + * @param name udf name + * @param className fully qualified class name of udf + * @param returnDataType return type of udf. If it is null, spark would try to infer + * via reflection. */ - def registerJava(name: String, className: String, returnType: DataType): Unit = { + def registerJava(name: String, className: String, returnDataType: DataType): Unit = { try { - // scalastyle:off classforname - val clazz = Class.forName(className, false, Utils.getContextOrSparkClassLoader) - // scalastyle:on classforname - val udfInterfaces = clazz.getGenericInterfaces.filter(_.isInstanceOf[ParameterizedTypeImpl]).map(_.asInstanceOf[ParameterizedTypeImpl]) - .filter(_.getRawType.getName.startsWith("org.apache.spark.sql.api.java.UDF")) + val clazz = Utils.classForName(className) + val udfInterfaces = clazz.getGenericInterfaces + .filter(_.isInstanceOf[ParameterizedType]) + .map(_.asInstanceOf[ParameterizedType]) + .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF")) if (udfInterfaces.length == 0) { throw new IOException(s"UDF class ${className} doesn't implement any UDF interface") } else if (udfInterfaces.length > 1) { @@ -442,6 +440,25 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } else { try { val udf = clazz.newInstance() + val udfReturnType = udfInterfaces(0).getActualTypeArguments.last + var returnType = returnDataType + if (returnType == null) { + if (udfReturnType.isInstanceOf[Class[_]]) { + returnType = udfReturnType.asInstanceOf[Class[_]].getCanonicalName match { + case "java.lang.String" => DataTypes.BooleanType + case "java.lang.Double" => DataTypes.DoubleType + case "java.lang.Float" => DataTypes.FloatType + case "java.lang.Byte" => DataTypes.ByteType + case "java.lang.Integer" => DataTypes.IntegerType + case "java.lang.Long" => DataTypes.LongType + case "java.lang.Short" => DataTypes.ShortType + case t => throw new RuntimeException("Can not infer the return type: ${udfReturnType}, please declare returnType explicitly.") + } + } else { + throw new RuntimeException("The return type of UDF is not valid, returnType:" + udfReturnType) + } + } + udfInterfaces(0).getActualTypeArguments.length match { case 2 => register(name, udf.asInstanceOf[UDF1[_, _]], returnType) case 3 => register(name, udf.asInstanceOf[UDF2[_, _, _]], returnType) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java index 9c646a1fc929e..8bf3278c43880 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java @@ -99,8 +99,13 @@ public Integer call(String str1, String str2) throws Exception { @Test public void udf3Test() { spark.udf().registerJava("stringLengthTest", StringLengthTest.class.getName(), - DataTypes.IntegerType); + DataTypes.IntegerType); Row result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); Assert.assertEquals(9, result.getInt(0)); + + // returnType is not provided + spark.udf().registerJava("stringLengthTest2", StringLengthTest.class.getName(), null); + result = spark.sql("SELECT stringLengthTest('test', 'test2')").head(); + Assert.assertEquals(9, result.getInt(0)); } } From 8171b8515107ea66fa277c52823167d206b4756a Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 13 Oct 2016 11:41:15 +0800 Subject: [PATCH 5/5] address comments --- python/pyspark/sql/context.py | 6 +++--- .../sql/catalyst/JavaTypeInference.scala | 2 +- .../apache/spark/sql/UDFRegistration.scala | 21 +++++-------------- .../apache/spark/sql}/JavaStringLength.java | 2 +- 4 files changed, 10 insertions(+), 21 deletions(-) rename sql/core/src/{main/java/org/apache/spark/sql/test => test/java/test/org/apache/spark/sql}/JavaStringLength.java (96%) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 2232c697e6434..de4c335ad2752 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -208,17 +208,17 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): """Register a java UDF so it can be used in SQL statements. In addition to a name and the function itself, the return type can be optionally specified. - When the return type is not given it would infer the returnType via reflection. + When the return type is not specified we would infer it via reflection. :param name: name of the UDF :param javaClassName: fully qualified name of java class :param returnType: a :class:`pyspark.sql.types.DataType` object >>> sqlContext.registerJavaFunction("javaStringLength", - ... "org.apache.spark.sql.test.JavaStringLength", IntegerType()) + ... "test.org.apache.spark.sql.JavaStringLength", IntegerType()) >>> sqlContext.sql("SELECT javaStringLength('test')").collect() [Row(UDF(test)=4)] >>> sqlContext.registerJavaFunction("javaStringLength2", - ... "org.apache.spark.sql.test.JavaStringLength") + ... "test.org.apache.spark.sql.JavaStringLength") >>> sqlContext.sql("SELECT javaStringLength2('test')").collect() [Row(UDF(test)=4)] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index e6f61b00ebd70..04f0cfce883f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -59,7 +59,7 @@ object JavaTypeInference { * @param typeToken Java type * @return (SQL data type, nullable) */ - private def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { + private[sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = { typeToken.getRawType match { case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) => (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 8031989892523..0444ad10d34fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -23,11 +23,13 @@ import java.lang.reflect.{ParameterizedType, Type} import scala.reflect.runtime.universe.TypeTag import scala.util.Try +import com.google.common.reflect.TypeToken + import org.apache.spark.annotation.InterfaceStability import org.apache.spark.internal.Logging import org.apache.spark.sql.api.java._ +import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF} import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.execution.python.UserDefinedPythonFunction @@ -425,7 +427,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends * @param returnDataType return type of udf. If it is null, spark would try to infer * via reflection. */ - def registerJava(name: String, className: String, returnDataType: DataType): Unit = { + private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = { try { val clazz = Utils.classForName(className) @@ -443,20 +445,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val udfReturnType = udfInterfaces(0).getActualTypeArguments.last var returnType = returnDataType if (returnType == null) { - if (udfReturnType.isInstanceOf[Class[_]]) { - returnType = udfReturnType.asInstanceOf[Class[_]].getCanonicalName match { - case "java.lang.String" => DataTypes.BooleanType - case "java.lang.Double" => DataTypes.DoubleType - case "java.lang.Float" => DataTypes.FloatType - case "java.lang.Byte" => DataTypes.ByteType - case "java.lang.Integer" => DataTypes.IntegerType - case "java.lang.Long" => DataTypes.LongType - case "java.lang.Short" => DataTypes.ShortType - case t => throw new RuntimeException("Can not infer the return type: ${udfReturnType}, please declare returnType explicitly.") - } - } else { - throw new RuntimeException("The return type of UDF is not valid, returnType:" + udfReturnType) - } + returnType = JavaTypeInference.inferDataType(TypeToken.of(udfReturnType))._1 } udfInterfaces(0).getActualTypeArguments.length match { diff --git a/sql/core/src/main/java/org/apache/spark/sql/test/JavaStringLength.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/test/JavaStringLength.java rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java index 8938d7a1e4c55..b90224f2ae397 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/test/JavaStringLength.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaStringLength.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.test; +package test.org.apache.spark.sql; import org.apache.spark.sql.api.java.UDF1;