[SPARK-9114] [SQL] [PySpark] convert returned object from UDF into internal type

This PR also removes the duplicated code between registerFunction and UserDefinedFunction.
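In short: before this change, the pickled command sent to the JVM applied the UDF directly, so Python-side objects such as UDT instances came back unconverted. A hypothetical one-line sketch of the fix (the wrap_udf helper below is illustrative only; the real change lives in UserDefinedFunction._create_judf in the diff):

    def wrap_udf(f, returnType):
        # Convert each returned value into Spark's internal representation
        # before it is serialized back to the JVM.
        return lambda *args: returnType.toInternal(f(*args))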

cc JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #7450 from davies/fix_return_type and squashes the following commits:

e80bf9f [Davies Liu] remove debugging code
f94b1f6 [Davies Liu] fix mima
8f9c58b [Davies Liu] convert returned object from UDF into internal type
Davies Liu authored and davies committed Jul 20, 2015
1 parent 02181fb commit 9f913c4
Showing 6 changed files with 32 additions and 61 deletions.
4 changes: 3 additions & 1 deletion project/MimaExcludes.scala
@@ -69,7 +69,9 @@ object MimaExcludes {
       ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.json.JsonRDD"),
       // local function inside a method
       ProblemFilters.exclude[MissingMethodProblem](
-        "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1")
+        "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1"),
+      ProblemFilters.exclude[MissingMethodProblem](
+        "org.apache.spark.sql.UDFRegistration.org$apache$spark$sql$UDFRegistration$$builder$24")
     ) ++ Seq(
       // SPARK-8479 Add numNonzeros and numActives to Matrix.
       ProblemFilters.exclude[MissingMethodProblem](
16 changes: 3 additions & 13 deletions python/pyspark/sql/context.py
@@ -34,6 +34,7 @@
 from pyspark.sql.dataframe import DataFrame
 from pyspark.sql.readwriter import DataFrameReader
 from pyspark.sql.utils import install_exception_handler
+from pyspark.sql.functions import UserDefinedFunction

 try:
     import pandas
@@ -191,19 +192,8 @@ def registerFunction(self, name, f, returnType=StringType()):
         >>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
         [Row(_c0=4)]
         """
-        func = lambda _, it: map(lambda x: f(*x), it)
-        ser = AutoBatchedSerializer(PickleSerializer())
-        command = (func, None, ser, ser)
-        pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
-        self._ssql_ctx.udf().registerPython(name,
-                                            bytearray(pickled_cmd),
-                                            env,
-                                            includes,
-                                            self._sc.pythonExec,
-                                            self._sc.pythonVer,
-                                            bvars,
-                                            self._sc._javaAccumulator,
-                                            returnType.json())
+        udf = UserDefinedFunction(f, returnType, name)
+        self._ssql_ctx.udf().registerPython(name, udf._judf)

     def _inferSchemaFromList(self, data):
         """
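With the duplication gone, registerFunction is a thin wrapper that builds a UserDefinedFunction and hands its JVM-side counterpart (_judf) to the registry. A usage sketch, assuming an existing sqlContext:

    from pyspark.sql.types import IntegerType

    # Registers a named UDF usable from SQL; internally this now goes
    # through UserDefinedFunction rather than a second pickling path.
    sqlContext.registerFunction("strlen", lambda s: len(s), IntegerType())
    sqlContext.sql("SELECT strlen('test')").collect()  # [Row(_c0=4)]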
15 changes: 8 additions & 7 deletions python/pyspark/sql/functions.py
@@ -801,23 +801,24 @@ class UserDefinedFunction(object):

     .. versionadded:: 1.3
     """
-    def __init__(self, func, returnType):
+    def __init__(self, func, returnType, name=None):
         self.func = func
         self.returnType = returnType
         self._broadcast = None
-        self._judf = self._create_judf()
+        self._judf = self._create_judf(name)

-    def _create_judf(self):
-        f = self.func  # put it in closure `func`
-        func = lambda _, it: map(lambda x: f(*x), it)
+    def _create_judf(self, name):
+        f, returnType = self.func, self.returnType  # put them in closure `func`
+        func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
         ser = AutoBatchedSerializer(PickleSerializer())
         command = (func, None, ser, ser)
         sc = SparkContext._active_spark_context
         pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
         ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
         jdt = ssql_ctx.parseDataType(self.returnType.json())
-        fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
-        judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
+        if name is None:
+            name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
+        judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
                                                  sc.pythonExec, sc.pythonVer, broadcast_vars,
                                                  sc._javaAccumulator, jdt)
         return judf
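The key line is returnType.toInternal(f(*x)): every value the Python UDF returns is converted into Spark's internal representation before being pickled back to the JVM. A rough sketch of what that conversion does for two types (exact internal encodings are a PySpark implementation detail):

    import datetime
    from pyspark.sql.types import DateType, StringType

    # Dates are encoded internally as days since the Unix epoch...
    DateType().toInternal(datetime.date(1970, 1, 2))   # -> 1

    # ...while simple atomic types pass through unchanged, so plain
    # string/int UDFs behave exactly as before this change.
    StringType().toInternal("test")                    # -> 'test'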
4 changes: 3 additions & 1 deletion python/pyspark/sql/tests.py
@@ -417,12 +417,14 @@ def test_apply_schema_with_udt(self):
         self.assertEquals(point, ExamplePoint(1.0, 2.0))

     def test_udf_with_udt(self):
-        from pyspark.sql.tests import ExamplePoint
+        from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
         df = self.sc.parallelize([row]).toDF()
         self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
         udf = UserDefinedFunction(lambda p: p.y, DoubleType())
         self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+        udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+        self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])

     def test_parquet_with_udt(self):
         from pyspark.sql.tests import ExamplePoint
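The new assertions exercise a UDF whose return type is a user-defined type (UDT). For context, a minimal sketch of the shape of such a type, modeled loosely on the ExamplePoint/ExamplePointUDT pair in pyspark.sql.tests (a real UDT also declares module() and a companion scalaUDT class):

    from pyspark.sql.types import ArrayType, DoubleType, UserDefinedType

    class PointUDT(UserDefinedType):
        """Maps Point objects to/from Spark's internal representation."""

        @classmethod
        def sqlType(cls):
            # Internally, a point is just an array of two doubles.
            return ArrayType(DoubleType(), False)

        def serialize(self, obj):
            return [obj.x, obj.y]

        def deserialize(self, datum):
            return Point(datum[0], datum[1])

    class Point(object):
        __UDT__ = PointUDT()

        def __init__(self, x, y):
            self.x, self.y = x, y

For a UDT, toInternal delegates through serialize, which is what lets udf2 above return ExamplePoint objects correctly.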
44 changes: 8 additions & 36 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -22,13 +22,10 @@ import java.util.{List => JList, Map => JMap}
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.Try

-import org.apache.spark.{Accumulator, Logging}
-import org.apache.spark.api.python.PythonBroadcast
-import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.Logging
 import org.apache.spark.sql.api.java._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
-import org.apache.spark.sql.execution.PythonUDF
 import org.apache.spark.sql.types.DataType

 /**
@@ -40,44 +37,19 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {

   private val functionRegistry = sqlContext.functionRegistry

-  protected[sql] def registerPython(
-      name: String,
-      command: Array[Byte],
-      envVars: JMap[String, String],
-      pythonIncludes: JList[String],
-      pythonExec: String,
-      pythonVer: String,
-      broadcastVars: JList[Broadcast[PythonBroadcast]],
-      accumulator: Accumulator[JList[Array[Byte]]],
-      stringDataType: String): Unit = {
+  protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
     log.debug(
       s"""
         | Registering new PythonUDF:
         | name: $name
-        | command: ${command.toSeq}
-        | envVars: $envVars
-        | pythonIncludes: $pythonIncludes
-        | pythonExec: $pythonExec
-        | dataType: $stringDataType
+        | command: ${udf.command.toSeq}
+        | envVars: ${udf.envVars}
+        | pythonIncludes: ${udf.pythonIncludes}
+        | pythonExec: ${udf.pythonExec}
+        | dataType: ${udf.dataType}
       """.stripMargin)

-
-    val dataType = sqlContext.parseDataType(stringDataType)
-
-    def builder(e: Seq[Expression]): PythonUDF =
-      PythonUDF(
-        name,
-        command,
-        envVars,
-        pythonIncludes,
-        pythonExec,
-        pythonVer,
-        broadcastVars,
-        accumulator,
-        dataType,
-        e)
-
-    functionRegistry.registerFunction(name, builder)
+    functionRegistry.registerFunction(name, udf.builder)
   }

   // scalastyle:off
10 changes: 7 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Accumulator
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.python.PythonBroadcast
 import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
 import org.apache.spark.sql.execution.PythonUDF
 import org.apache.spark.sql.types.DataType

@@ -66,10 +66,14 @@ private[sql] case class UserDefinedPythonFunction(
     accumulator: Accumulator[JList[Array[Byte]]],
     dataType: DataType) {

+  def builder(e: Seq[Expression]): PythonUDF = {
+    PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
+      accumulator, dataType, e)
+  }
+
   /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
   def apply(exprs: Column*): Column = {
-    val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
-      broadcastVars, accumulator, dataType, exprs.map(_.expr))
+    val udf = builder(exprs.map(_.expr))
     Column(udf)
   }
 }