[SPARK-5274][SQL] Reconcile Java and Scala UDFRegistration. #4056

Closed · wants to merge 5 commits
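As the diffs below show, this patch removes the UDFRegistration mixin from SQLContext and instead exposes the same functionality through a dedicated udf member, giving Scala and Java one reconciled registration entry point (PySpark's registerFunction is rerouted through it as well). A minimal before/after sketch of a call site, assuming an existing SparkContext named sc:

  import org.apache.spark.sql.SQLContext

  val sqlContext = new SQLContext(sc)

  // Before this patch: registerFunction was inherited from the UDFRegistration mixin.
  // sqlContext.registerFunction("len", (s: String) => s.length)

  // After this patch: registration goes through the dedicated udf member.
  sqlContext.udf.register("len", (s: String) => s.length)

  sqlContext.sql("SELECT len('spark')").first().getInt(0)  // returns 5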
python/pyspark/sql.py (8 additions, 8 deletions)
@@ -1281,14 +1281,14 @@ def registerFunction(self, name, f, returnType=StringType()):
                                       self._sc._gateway._gateway_client)
         includes = ListConverter().convert(self._sc._python_includes,
                                            self._sc._gateway._gateway_client)
-        self._ssql_ctx.registerPython(name,
-                                      bytearray(pickled_command),
-                                      env,
-                                      includes,
-                                      self._sc.pythonExec,
-                                      broadcast_vars,
-                                      self._sc._javaAccumulator,
-                                      returnType.json())
+        self._ssql_ctx.udf().registerPython(name,
+                                            bytearray(pickled_command),
+                                            env,
+                                            includes,
+                                            self._sc.pythonExec,
+                                            broadcast_vars,
+                                            self._sc._javaAccumulator,
+                                            returnType.json())

     def inferSchema(self, rdd, samplingRatio=None):
         """Infer and apply a schema to an RDD of L{Row}.
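On the Scala side, the Py4J target of this call now lives on the object returned by udf() rather than on SQLContext itself. The sketch below is only a guess at the shape of registerPython, read off the eight arguments passed above; the trait name, parameter names, and element types are all assumptions, not the committed code:

  import java.util.{List => JList, Map => JMap}

  import org.apache.spark.Accumulator
  import org.apache.spark.broadcast.Broadcast

  trait PythonUdfRegistration {  // hypothetical trait, for illustration only
    def registerPython(
        name: String,                                   // SQL-visible function name
        command: Array[Byte],                           // the pickled Python closure
        envVars: JMap[String, String],                  // worker environment variables
        pythonIncludes: JList[String],                  // files shipped to Python workers
        pythonExec: String,                             // which python binary to launch
        broadcastVars: JList[Broadcast[Array[Byte]]],   // element type assumed
        accumulator: Accumulator[JList[Array[Byte]]],
        stringDataType: String): Unit                   // returnType.json(), parsed back into a DataType
  }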
sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala (28 additions, 1 deletion)
@@ -54,7 +54,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
   extends org.apache.spark.Logging
   with CacheManager
   with ExpressionConversions
-  with UDFRegistration
   with Serializable {
 
   self =>
@@ -338,6 +337,34 @@ class SQLContext(@transient val sparkContext: SparkContext)
    */
   val experimental: ExperimentalMethods = new ExperimentalMethods(this)
 
+  /**
+   * A collection of methods for registering user-defined functions (UDF).
+   *
+   * The following example registers a Scala closure as UDF:
+   * {{{
+   *   sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1)
+   * }}}
+   *
+   * The following example registers a UDF in Java:
+   * {{{
+   *   sqlContext.udf().register("myUDF",
+   *       new UDF2<Integer, String, String>() {
+   *           @Override
+   *           public String call(Integer arg1, String arg2) {
+   *               return arg2 + arg1;
+   *           }
+   *       }, DataTypes.StringType);
+   * }}}
+   *
+   * Or, to use Java 8 lambda syntax:
+   * {{{
+   *   sqlContext.udf().register("myUDF",
+   *       (Integer arg1, String arg2) -> arg2 + arg1,
+   *       DataTypes.StringType);
+   * }}}
+   */
+  val udf: UDFRegistration = new UDFRegistration(this)

   protected[sql] class SparkPlanner extends SparkStrategies {
     val sparkContext: SparkContext = self.sparkContext
 
sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala (631 additions, 61 deletions)

Large diffs are not rendered by default.
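The new UdfRegistration.scala is large mostly because it is mechanically repetitive: judging by the scaladoc above and the Java/Scala reconciliation this PR advertises, it needs one register overload per Scala function arity and one per Java UDF interface. A compressed sketch of that shape, with bodies elided as ??? and import paths that are assumptions for a 1.3-era tree:

  import scala.reflect.runtime.universe.TypeTag

  import org.apache.spark.sql.api.java.UDF2
  import org.apache.spark.sql.types.DataType

  // Sketch only: the committed class covers every arity, not just two arguments.
  class UDFRegistration {

    // Scala variant: argument and return types are captured via TypeTags,
    // so no explicit DataType needs to be passed.
    def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](
        name: String, func: (A1, A2) => RT): Unit = ???

    // Java variant: the generic parameters of UDF2 are erased at runtime,
    // so the caller supplies the return DataType explicitly.
    def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = ???
  }

That asymmetry, TypeTags on the Scala side versus an explicit DataType on the Java side, is why both call styles appear in the SQLContext scaladoc above.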

@@ -766,7 +766,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
   }
 
   test("SPARK-3371 Renaming a function expression with group by gives error") {
-    registerFunction("len", (s: String) => s.length)
+    udf.register("len", (s: String) => s.length)
     checkAnswer(
       sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), 1)
   }
sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala (4 additions, 5 deletions)
@@ -27,23 +27,22 @@ case class FunctionResult(f1: String, f2: String)
 class UDFSuite extends QueryTest {
 
   test("Simple UDF") {
-    registerFunction("strLenScala", (_: String).length)
+    udf.register("strLenScala", (_: String).length)
     assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4)
   }
 
   test("ZeroArgument UDF") {
-    registerFunction("random0", () => { Math.random()})
+    udf.register("random0", () => { Math.random()})
     assert(sql("SELECT random0()").first().getDouble(0) >= 0.0)
   }
 
   test("TwoArgument UDF") {
-    registerFunction("strLenScala", (_: String).length + (_:Int))
+    udf.register("strLenScala", (_: String).length + (_:Int))
     assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
   }
 
-
   test("struct UDF") {
-    registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
+    udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
 
     val result =
       sql("SELECT returnStruct('test', 'test2') as ret")
@@ -81,7 +81,7 @@ class UserDefinedTypeSuite extends QueryTest {
   }
 
   test("UDTs and UDFs") {
-    registerFunction("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
+    udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
     pointsRDD.registerTempTable("points")
     checkAnswer(
       sql("SELECT testType(features) from points"),
@@ -50,7 +50,7 @@ class HiveUdfSuite extends QueryTest {
   import TestHive._
 
   test("spark sql udf test that returns a struct") {
-    registerFunction("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5))
+    udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5))
     assert(sql(
       """
         |SELECT getStruct(1).f1,