Commit 8f9c58b

convert returned object from UDF into internal type
Davies Liu committed Jul 16, 2015
1 parent 4ea6480 commit 8f9c58b
Showing 5 changed files with 31 additions and 60 deletions.
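The heart of the change is in _create_judf (see the functions.py diff below): the function that wraps the Python UDF now passes every returned value through returnType.toInternal(...) before it is serialized back to the JVM, so a UDF can return instances of a user-defined type (UDT) rather than only plain SQL values. A minimal sketch of that wrapping, reusing f and returnType from the diff (wrap_udf is just an illustrative helper name, not part of the commit):

    # Sketch of the wrapping now done in _create_judf: each input row is fed to
    # the user's function f, and the returned object is converted into Spark's
    # internal representation before being pickled back to the JVM.
    def wrap_udf(f, returnType):
        return lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
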
16 changes: 3 additions & 13 deletions python/pyspark/sql/context.py
@@ -34,6 +34,7 @@
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
from pyspark.sql.utils import install_exception_handler
from pyspark.sql.functions import UserDefinedFunction

try:
import pandas
@@ -191,19 +192,8 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> sqlContext.sql("SELECT stringLengthInt('test')").collect()
[Row(_c0=4)]
"""
func = lambda _, it: map(lambda x: f(*x), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
pickled_cmd, bvars, env, includes = _prepare_for_python_RDD(self._sc, command, self)
self._ssql_ctx.udf().registerPython(name,
bytearray(pickled_cmd),
env,
includes,
self._sc.pythonExec,
self._sc.pythonVer,
bvars,
self._sc._javaAccumulator,
returnType.json())
udf = UserDefinedFunction(f, returnType, name)
self._ssql_ctx.udf().registerPython(name, udf._judf)

def _inferSchemaFromList(self, data):
"""
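With this change registerFunction becomes a thin wrapper: it builds a UserDefinedFunction (forwarding the SQL-visible name) and hands its Java UDF to registerPython. A hedged usage sketch of the two ways of using a Python UDF after this change; sqlContext, df, its string column name, and the variable slen are assumptions for illustration:

    from pyspark.sql.types import IntegerType
    from pyspark.sql.functions import UserDefinedFunction

    # Path 1: register by name for use from SQL (now delegates to UserDefinedFunction).
    sqlContext.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
    sqlContext.sql("SELECT stringLengthInt('test')").collect()   # [Row(_c0=4)]

    # Path 2: build the UDF object directly and apply it as a column expression.
    slen = UserDefinedFunction(lambda x: len(x), IntegerType())
    df.select(slen(df.name)).collect()
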
17 changes: 10 additions & 7 deletions python/pyspark/sql/functions.py
@@ -658,23 +658,26 @@ class UserDefinedFunction(object):
.. versionadded:: 1.3
"""
def __init__(self, func, returnType):
def __init__(self, func, returnType, name=None):
self.func = func
self.returnType = returnType
self._broadcast = None
self._judf = self._create_judf()
self._judf = self._create_judf(name)

def _create_judf(self):
f = self.func # put it in closure `func`
func = lambda _, it: map(lambda x: f(*x), it)
def _create_judf(self, name):
f, returnType = self.func, self.returnType # put them in closure `func`
from pyspark.serializers import CloudPickleSerializer
CloudPickleSerializer().dumps(returnType)
func = lambda _, it: map(lambda x: returnType.toInternal(f(*x)), it)
ser = AutoBatchedSerializer(PickleSerializer())
command = (func, None, ser, ser)
sc = SparkContext._active_spark_context
pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(self.returnType.json())
fname = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
judf = sc._jvm.UserDefinedPythonFunction(fname, bytearray(pickled_command), env, includes,
if name is None:
name = f.__name__ if hasattr(f, '__name__') else f.__class__.__name__
judf = sc._jvm.UserDefinedPythonFunction(name, bytearray(pickled_command), env, includes,
sc.pythonExec, sc.pythonVer, broadcast_vars,
sc._javaAccumulator, jdt)
return judf
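Two details in the diff above do the real work: returnType.toInternal(...) converts whatever the UDF returns into Spark's internal representation before it is pickled back to the JVM, and the eager CloudPickleSerializer().dumps(returnType) call means an unpicklable return type fails at UDF creation time rather than on the executors. A rough, hypothetical REPL-style illustration of the conversion for the UDT exercised by the new test (not part of the commit):

    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT

    udt = ExamplePointUDT()
    point = ExamplePoint(1.0, 2.0)    # the plain Python object a UDF might return
    internal = udt.toInternal(point)  # roughly [1.0, 2.0]: the UDT's underlying sqlType form
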
4 changes: 3 additions & 1 deletion python/pyspark/sql/tests.py
@@ -413,12 +413,14 @@ def test_apply_schema_with_udt(self):
self.assertEquals(point, ExamplePoint(1.0, 2.0))

def test_udf_with_udt(self):
from pyspark.sql.tests import ExamplePoint
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
df = self.sc.parallelize([row]).toDF()
self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])

def test_parquet_with_udt(self):
from pyspark.sql.tests import ExamplePoint
44 changes: 8 additions & 36 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -22,13 +22,10 @@ import java.util.{List => JList, Map => JMap}
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try

import org.apache.spark.{Accumulator, Logging}
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.Logging
import org.apache.spark.sql.api.java._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.execution.PythonUDF
import org.apache.spark.sql.types.DataType

/**
@@ -40,44 +37,19 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {

private val functionRegistry = sqlContext.functionRegistry

protected[sql] def registerPython(
name: String,
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
pythonExec: String,
pythonVer: String,
broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]],
stringDataType: String): Unit = {
protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
log.debug(
s"""
| Registering new PythonUDF:
| name: $name
| command: ${command.toSeq}
| envVars: $envVars
| pythonIncludes: $pythonIncludes
| pythonExec: $pythonExec
| dataType: $stringDataType
| command: ${udf.command.toSeq}
| envVars: ${udf.envVars}
| pythonIncludes: ${udf.pythonIncludes}
| pythonExec: ${udf.pythonExec}
| dataType: ${udf.dataType}
""".stripMargin)


val dataType = sqlContext.parseDataType(stringDataType)

def builder(e: Seq[Expression]): PythonUDF =
PythonUDF(
name,
command,
envVars,
pythonIncludes,
pythonExec,
pythonVer,
broadcastVars,
accumulator,
dataType,
e)

functionRegistry.registerFunction(name, builder)
functionRegistry.registerFunction(name, udf.builder)
}

// scalastyle:off

@@ -23,7 +23,7 @@ import org.apache.spark.Accumulator
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.expressions.ScalaUDF
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalaUDF}
import org.apache.spark.sql.execution.PythonUDF
import org.apache.spark.sql.types.DataType

@@ -66,10 +66,14 @@ private[sql] case class UserDefinedPythonFunction(
accumulator: Accumulator[JList[Array[Byte]]],
dataType: DataType) {

def builder(e: Seq[Expression]): PythonUDF = {
PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer, broadcastVars,
accumulator, dataType, e)
}

/** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
def apply(exprs: Column*): Column = {
val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, pythonVer,
broadcastVars, accumulator, dataType, exprs.map(_.expr))
val udf = builder(exprs.map(_.expr))
Column(udf)
}
}
