From e73c84ff75e66dee9a395c9913109a965d7d68f4 Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Thu, 9 Feb 2017 15:04:36 -0500 Subject: [PATCH 1/4] got something working but not sure how good this is yet --- .../catalyst/expressions/TypedScalaUDF.scala | 112 ++++++++++++++++++ .../TypedUserDefinedFunction.scala | 32 +++++ .../org/apache/spark/sql/functions.scala | 24 +++- .../org/apache/spark/sql/TypedUDFSuite.scala | 65 ++++++++++ 4 files changed, 231 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala new file mode 100644 index 0000000000000..e4a9691505df3 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import scala.reflect.ClassTag + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.DataType + +case class TypedScalaUDF[T1, R]( + function: (T1) => R, + t1Encoder: ExpressionEncoder[T1], + rEncoder: ExpressionEncoder[R], + child1: Expression, + inputTypes: Seq[DataType] = Nil + ) extends Expression with ImplicitCastInputTypes with NonSQLExpression { + + override def children: Seq[Expression] = Seq(child1) + + // something is wrong with encoders and Option so this is used to handle that + private val t1IsOption = t1Encoder.clsTag == implicitly[ClassTag[Option[_]]] + private val rIsOption = rEncoder.clsTag == implicitly[ClassTag[Option[_]]] + + override val dataType: DataType = + if (rEncoder.flat || rIsOption) + rEncoder.schema.head.dataType + else + rEncoder.schema + + override val nullable: Boolean = + if (rEncoder.flat) + rEncoder.schema.head.nullable + else + true + + val boundT1Encoder: ExpressionEncoder[T1] = t1Encoder.resolveAndBind() + + def internalRow(x: Any): InternalRow = InternalRow(x) + + override def eval(input: InternalRow): Any = { + val rowIn = if (t1Encoder.flat || t1IsOption) InternalRow(child1.eval(input)) else child1.eval(input).asInstanceOf[InternalRow] + val t1 = boundT1Encoder.fromRow(rowIn) + val rowOut = rEncoder.toRow(function(t1)) + if (rEncoder.flat || rIsOption) rowOut.get(0, dataType) else rowOut + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val typedScalaUDF = ctx.addReferenceObj("typedScalaUDF", this, classOf[TypedScalaUDF[_, _]].getName) + + // codegen for children expressions + val eval1 = child1.genCode(ctx) + val evalCode1 = eval1.code.mkString + + // codegen for this expression + val rowInTerm = ctx.freshName("rowIn") + val rowIn = if (t1Encoder.flat || t1IsOption) + s"${classOf[InternalRow].getName} $rowInTerm = ${eval1.isNull} ? 
${typedScalaUDF}.internalRow(null) : ${typedScalaUDF}.internalRow(${eval1.value});" + else + s"${classOf[InternalRow].getName} $rowInTerm = (${classOf[InternalRow].getName}) ${eval1.value};" + val t1Term = ctx.freshName("t1") + val t1 = s"Object $t1Term = ${typedScalaUDF}.boundT1Encoder().fromRow($rowInTerm);" + val rTerm = ctx.freshName("r") + val r = s"Object $rTerm = ${typedScalaUDF}.function().apply($t1Term);" + val rowOutTerm = ctx.freshName("rowOut") + val rowOut = s"${classOf[InternalRow].getName} $rowOutTerm = ${typedScalaUDF}.rEncoder().toRow($rTerm);" + val resultTerm = ctx.freshName("result") + val result = if (rEncoder.flat || rIsOption) + s"Object $resultTerm = ${rowOutTerm}.get(0, ${typedScalaUDF}.dataType());" + else + s"Object $resultTerm = ${rowOutTerm};" + + // put it all in place + ev.copy(code = s""" + $evalCode1 + $rowIn + //System.out.println("rowIn is " + $rowInTerm); + $t1 + //System.out.println("t1 is " + $t1Term); + $r + //System.out.println("r is " + $rTerm); + $rowOut + //System.out.println("rowOut is " + $rowOutTerm); + $result + //System.out.println("result is " + $resultTerm); + + boolean ${ev.isNull} = $resultTerm == null; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${ev.value} = (${ctx.boxedType(dataType)}) $resultTerm; + } + """) + } + + override def toString: String = s"TypedUDF(${children.mkString(", ")})" +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala new file mode 100644 index 0000000000000..d94e013ea07b5 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala @@ -0,0 +1,32 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. 
+*/ + +package org.apache.spark.sql.expressions + +import org.apache.spark.sql.catalyst.expressions.TypedScalaUDF +import org.apache.spark.sql.catalyst.encoders.encoderFor +import org.apache.spark.sql.{Column, Encoder} +import org.apache.spark.sql.functions.{lit, struct} + +case class TypedUserDefinedFunction[T1: Encoder, R: Encoder](f: T1 => R) { + def apply(exprs: Column*): Column = { + val t1Encoder = encoderFor[T1] + val rEncoder = encoderFor[R] + val expr1 = if (exprs.size == 0) lit(0) else if (exprs.size == 1) exprs.head else struct(exprs: _*) + Column(TypedScalaUDF(f, t1Encoder, rEncoder, expr1.expr)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 24ed906d33683..f6667d7a07daf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -25,12 +25,12 @@ import scala.util.Try import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.expressions.UserDefinedFunction +import org.apache.spark.sql.expressions.{UserDefinedFunction, TypedUserDefinedFunction} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -3271,4 +3271,24 @@ object functions { def callUDF(udfName: String, cols: Column*): Column = withExpr { UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } + + def typedUdf[R: Encoder](f: Function0[R]): TypedUserDefinedFunction[Int, R] = { + implicit val intEncoder: Encoder[Int] = Encoders.scalaInt + TypedUserDefinedFunction{ (i: Int) => f() } + } + + def typedUdf[T1: Encoder, R: Encoder](f: Function1[T1, R]): TypedUserDefinedFunction[T1, R] = { + TypedUserDefinedFunction(f) + } + + def typedUdf[T1: Encoder, T2: Encoder, R: Encoder](f: Function2[T1, T2, R]): TypedUserDefinedFunction[(T1, T2), R] = { + implicit val t1t2Encoder: Encoder[(T1, T2)] = ExpressionEncoder.tuple(encoderFor[T1], encoderFor[T2]) + TypedUserDefinedFunction(f.tupled) + } + + def typedUdf[T1: Encoder, T2: Encoder, T3: Encoder, R: Encoder](f: Function3[T1, T2, T3, R]): TypedUserDefinedFunction[(T1, T2, T3), R] = { + implicit val t1t2t3Encoder: Encoder[(T1, T2, T3)] = ExpressionEncoder.tuple(encoderFor[T1], encoderFor[T2], encoderFor[T3]) + TypedUserDefinedFunction(f.tupled) + } } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala new file mode 100644 index 0000000000000..ca8631acd9533 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.test.SQLTestData._ +import org.apache.spark.sql.functions.{ col, udf, typedUdf } + +case class Blah(a: Int, b: String) + +class TypedUDFSuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + test("typedUdf") { + import java.lang.{ Integer => JInt } + + val df1 = Seq((1, "a", 1: JInt), (2, "b", 2: JInt), (3, "c", null)).toDF("x", "y", "z").select( + typedUdf{ () => false }.apply() as "f0", + typedUdf{ x: Int => x * 2 }.apply('x) as "f1", + typedUdf{ y: String => y + "!" }.apply('y) as "f2", + typedUdf{ blah: Blah => blah.copy(a = blah.a + 1) }.apply('x, 'y) as "f3", + typedUdf{ (x: Int, y: String) => (x + 1, y) }.apply('x, 'y) as "f4", + typedUdf{ z: Int => z + 1 }.apply('z) as "f5", + typedUdf{ z: Option[Int] => z.map(_ + 1) }.apply('z) as "f6", + typedUdf[JInt, JInt]{ z => if (z != null) z + 1 else null }.apply('z) as "f7", + typedUdf{ x: Int => x + 1 }.apply(col("x") * 2) as "f8" + ) + df1.printSchema + df1.explain + df1.show + + val df2 = testData + .withColumn("z", udf{ (x: JInt) => if (x > 10) x else null }.apply('value) as 'z) + .select( + typedUdf{ () => false }.apply() as "f0", + typedUdf{ x: Int => x * 2 }.apply('key) as "f1", + typedUdf{ y: String => y + "!" 
}.apply('value) as "f2", + typedUdf{ blah: Blah => blah.copy(a = blah.a + 1) }.apply('key, 'value) as "f3", + typedUdf{ (x: Int, y: String) => (x + 1, y) }.apply('key, 'value) as "f4", + typedUdf{ z: Int => z + 1 }.apply('z) as "f5", + typedUdf{ z: Option[Int] => z.map(_ + 1) }.apply('z) as "f6", + typedUdf[JInt, JInt]{ z => if (z != null) z + 1 else null }.apply('z) as "f7", + typedUdf{ x: Int => x + 1 }.apply(col("key") * 2) as "f8" + ) + df2.printSchema + df2.explain + df2.show + } + +} From bd111ae1f1a721ae2664d0e8a5810f018eb4c935 Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Thu, 9 Feb 2017 20:16:34 -0500 Subject: [PATCH 2/4] pattern match to create expr1 --- .../spark/sql/expressions/TypedUserDefinedFunction.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala index d94e013ea07b5..3f7b50c1151b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala @@ -26,7 +26,11 @@ case class TypedUserDefinedFunction[T1: Encoder, R: Encoder](f: T1 => R) { def apply(exprs: Column*): Column = { val t1Encoder = encoderFor[T1] val rEncoder = encoderFor[R] - val expr1 = if (exprs.size == 0) lit(0) else if (exprs.size == 1) exprs.head else struct(exprs: _*) + val expr1 = exprs match { + case Seq() => lit(0) + case Seq(expr1) => expr1 + case exprs => struct(exprs: _*) + } Column(TypedScalaUDF(f, t1Encoder, rEncoder, expr1.expr)) } } From 1a257c366aa1c0181dd7791fdd6eb05140c48d7e Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Thu, 9 Feb 2017 21:04:53 -0500 Subject: [PATCH 3/4] fix bug with internal row getting re-used and check results in unit tests --- .../sql/catalyst/expressions/TypedScalaUDF.scala | 16 +++++++++++++--- .../org/apache/spark/sql/TypedUDFSuite.scala | 15 ++++++++++++++- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala index e4a9691505df3..fd2e6d68fb950 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala @@ -55,10 +55,19 @@ case class TypedScalaUDF[T1, R]( def internalRow(x: Any): InternalRow = InternalRow(x) override def eval(input: InternalRow): Any = { - val rowIn = if (t1Encoder.flat || t1IsOption) InternalRow(child1.eval(input)) else child1.eval(input).asInstanceOf[InternalRow] + val eval1 = child1.eval(input) + //println("eval1 " + eval1) + val rowIn = if (t1Encoder.flat || t1IsOption) InternalRow(eval1) else eval1.asInstanceOf[InternalRow] + //println("rowIn" + rowIn) val t1 = boundT1Encoder.fromRow(rowIn) - val rowOut = rEncoder.toRow(function(t1)) - if (rEncoder.flat || rIsOption) rowOut.get(0, dataType) else rowOut + //println("t1 " + t1) + val r = function(t1) + //println("r " + r) + val rowOut = rEncoder.toRow(r).copy() // not entirely sure why i need the copy but i do + //println("rowOut " + rowOut) + val result = if (rEncoder.flat || rIsOption) rowOut.get(0, dataType) else rowOut + //println("result " + result) + result } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ 
-89,6 +98,7 @@ case class TypedScalaUDF[T1, R]( // put it all in place ev.copy(code = s""" $evalCode1 + //System.out.println("eval1 is " + ${eval1.value}); $rowIn //System.out.println("rowIn is " + $rowInTerm); $t1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala index ca8631acd9533..3494b152b0e3f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala @@ -29,6 +29,7 @@ class TypedUDFSuite extends QueryTest with SharedSQLContext { test("typedUdf") { import java.lang.{ Integer => JInt } + // this seems to use TypedScalaUDF.eval, not sure why val df1 = Seq((1, "a", 1: JInt), (2, "b", 2: JInt), (3, "c", null)).toDF("x", "y", "z").select( typedUdf{ () => false }.apply() as "f0", typedUdf{ x: Int => x * 2 }.apply('x) as "f1", @@ -43,9 +44,16 @@ class TypedUDFSuite extends QueryTest with SharedSQLContext { df1.printSchema df1.explain df1.show + checkAnswer(df1, Seq( + Row(false, 2, "a!", Row(2, "a"), Row(2, "a"), 2, 2, 2, 3), + Row(false, 4, "b!", Row(3, "b"), Row(3, "b"), 3, 3, 3, 5), + Row(false, 6, "c!", Row(4, "c"), Row(4, "c"), 1, null, null, 7) + )) + // this seems to use TypedScalaUDF.doGenCode, not sure why val df2 = testData - .withColumn("z", udf{ (x: JInt) => if (x > 10) x else null }.apply('value) as 'z) + .filter("key < 4") + .withColumn("z", udf{ (x: JInt) => if (x < 3) x else null }.apply('value) as 'z) .select( typedUdf{ () => false }.apply() as "f0", typedUdf{ x: Int => x * 2 }.apply('key) as "f1", @@ -60,6 +68,11 @@ class TypedUDFSuite extends QueryTest with SharedSQLContext { df2.printSchema df2.explain df2.show + checkAnswer(df2, Seq( + Row(false, 2, "1!", Row(2, "1"), Row(2, "1"), 2, 2, 2, 3), + Row(false, 4, "2!", Row(3, "2"), Row(3, "2"), 3, 3, 3, 5), + Row(false, 6, "3!", Row(4, "3"), Row(4, "3"), 1, null, null, 7) + )) } } From e1c337ae9e747a3dc02b939a87c9bb5c8605b86c Mon Sep 17 00:00:00 2001 From: Koert Kuipers Date: Thu, 9 Feb 2017 22:26:40 -0500 Subject: [PATCH 4/4] deal with annoying style rules --- .../catalyst/expressions/TypedScalaUDF.scala | 76 +++++++++++-------- .../TypedUserDefinedFunction.scala | 5 +- .../org/apache/spark/sql/functions.scala | 8 +- .../org/apache/spark/sql/TypedUDFSuite.scala | 14 ++-- 4 files changed, 60 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala index fd2e6d68fb950..2a67700e68e05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TypedScalaUDF.scala @@ -35,20 +35,20 @@ case class TypedScalaUDF[T1, R]( override def children: Seq[Expression] = Seq(child1) // something is wrong with encoders and Option so this is used to handle that - private val t1IsOption = t1Encoder.clsTag == implicitly[ClassTag[Option[_]]] + private val t1IsOption = t1Encoder.clsTag == implicitly[ClassTag[Option[_]]] private val rIsOption = rEncoder.clsTag == implicitly[ClassTag[Option[_]]] - override val dataType: DataType = - if (rEncoder.flat || rIsOption) - rEncoder.schema.head.dataType - else - rEncoder.schema + override val dataType: DataType = if (rEncoder.flat || rIsOption) { + rEncoder.schema.head.dataType + } else { + rEncoder.schema + } - override val nullable: Boolean = 
- if (rEncoder.flat) - rEncoder.schema.head.nullable - else - true + override val nullable: Boolean = if (rEncoder.flat) { + rEncoder.schema.head.nullable + } else { + true + } val boundT1Encoder: ExpressionEncoder[T1] = t1Encoder.resolveAndBind() @@ -56,60 +56,70 @@ case class TypedScalaUDF[T1, R]( override def eval(input: InternalRow): Any = { val eval1 = child1.eval(input) - //println("eval1 " + eval1) - val rowIn = if (t1Encoder.flat || t1IsOption) InternalRow(eval1) else eval1.asInstanceOf[InternalRow] - //println("rowIn" + rowIn) + // println("eval1 " + eval1) + val rowIn = if (t1Encoder.flat || t1IsOption) { + InternalRow(eval1) + } else { + eval1.asInstanceOf[InternalRow] + } + // println("rowIn" + rowIn) val t1 = boundT1Encoder.fromRow(rowIn) - //println("t1 " + t1) + // println("t1 " + t1) val r = function(t1) - //println("r " + r) + // println("r " + r) val rowOut = rEncoder.toRow(r).copy() // not entirely sure why i need the copy but i do - //println("rowOut " + rowOut) + // println("rowOut " + rowOut) val result = if (rEncoder.flat || rIsOption) rowOut.get(0, dataType) else rowOut - //println("result " + result) + // println("result " + result) result } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val typedScalaUDF = ctx.addReferenceObj("typedScalaUDF", this, classOf[TypedScalaUDF[_, _]].getName) + val typedScalaUDF = ctx.addReferenceObj("typedScalaUDF", this, + classOf[TypedScalaUDF[_, _]].getName) // codegen for children expressions val eval1 = child1.genCode(ctx) val evalCode1 = eval1.code.mkString // codegen for this expression + val rowClass = classOf[InternalRow].getName val rowInTerm = ctx.freshName("rowIn") - val rowIn = if (t1Encoder.flat || t1IsOption) - s"${classOf[InternalRow].getName} $rowInTerm = ${eval1.isNull} ? ${typedScalaUDF}.internalRow(null) : ${typedScalaUDF}.internalRow(${eval1.value});" - else - s"${classOf[InternalRow].getName} $rowInTerm = (${classOf[InternalRow].getName}) ${eval1.value};" + val rowIn = if (t1Encoder.flat || t1IsOption) { + s"""$rowClass $rowInTerm = ${eval1.isNull} ? 
+ ${typedScalaUDF}.internalRow(null) : + ${typedScalaUDF}.internalRow(${eval1.value});""" + } else { + s"""$rowClass $rowInTerm =($rowClass) ${eval1.value};""" + } val t1Term = ctx.freshName("t1") val t1 = s"Object $t1Term = ${typedScalaUDF}.boundT1Encoder().fromRow($rowInTerm);" val rTerm = ctx.freshName("r") val r = s"Object $rTerm = ${typedScalaUDF}.function().apply($t1Term);" val rowOutTerm = ctx.freshName("rowOut") - val rowOut = s"${classOf[InternalRow].getName} $rowOutTerm = ${typedScalaUDF}.rEncoder().toRow($rTerm);" + val rowOut = s"$rowClass $rowOutTerm = ${typedScalaUDF}.rEncoder().toRow($rTerm);" val resultTerm = ctx.freshName("result") - val result = if (rEncoder.flat || rIsOption) + val result = if (rEncoder.flat || rIsOption) { s"Object $resultTerm = ${rowOutTerm}.get(0, ${typedScalaUDF}.dataType());" - else + } else { s"Object $resultTerm = ${rowOutTerm};" + } // put it all in place ev.copy(code = s""" $evalCode1 - //System.out.println("eval1 is " + ${eval1.value}); + // System.out.println("eval1 is " + ${eval1.value}); $rowIn - //System.out.println("rowIn is " + $rowInTerm); + // System.out.println("rowIn is " + $rowInTerm); $t1 - //System.out.println("t1 is " + $t1Term); + // System.out.println("t1 is " + $t1Term); $r - //System.out.println("r is " + $rTerm); + // System.out.println("r is " + $rTerm); $rowOut - //System.out.println("rowOut is " + $rowOutTerm); + // System.out.println("rowOut is " + $rowOutTerm); $result - //System.out.println("result is " + $resultTerm); - + // System.out.println("result is " + $resultTerm); + boolean ${ev.isNull} = $resultTerm == null; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala index 3f7b50c1151b9..6e721928aaf60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/TypedUserDefinedFunction.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.catalyst.expressions.TypedScalaUDF import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.{Column, Encoder} +import org.apache.spark.sql.catalyst.expressions.TypedScalaUDF +import org.apache.spark.sql.Column +import org.apache.spark.sql.Encoder import org.apache.spark.sql.functions.{lit, struct} case class TypedUserDefinedFunction[T1: Encoder, R: Encoder](f: T1 => R) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index f6667d7a07daf..34f3e1ec29001 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.expressions.{UserDefinedFunction, TypedUserDefinedFunction} +import org.apache.spark.sql.expressions.{TypedUserDefinedFunction, UserDefinedFunction} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -3272,6 +3272,9 @@ object functions { UnresolvedFunction(udfName, 
cols.map(_.expr), isDistinct = false) } + // scalastyle:off parameter.number + // scalastyle:off line.size.limit + def typedUdf[R: Encoder](f: Function0[R]): TypedUserDefinedFunction[Int, R] = { implicit val intEncoder: Encoder[Int] = Encoders.scalaInt TypedUserDefinedFunction{ (i: Int) => f() } @@ -3290,5 +3293,8 @@ object functions { implicit val t1t2t3Encoder: Encoder[(T1, T2, T3)] = ExpressionEncoder.tuple(encoderFor[T1], encoderFor[T2], encoderFor[T3]) TypedUserDefinedFunction(f.tupled) } + + // scalastyle:on parameter.number + // scalastyle:on line.size.limit } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala index 3494b152b0e3f..7e56bf7a19fb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedUDFSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql +import org.apache.spark.sql.functions.{col, typedUdf, udf} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.functions.{ col, udf, typedUdf } case class Blah(a: Int, b: String) @@ -41,9 +41,9 @@ class TypedUDFSuite extends QueryTest with SharedSQLContext { typedUdf[JInt, JInt]{ z => if (z != null) z + 1 else null }.apply('z) as "f7", typedUdf{ x: Int => x + 1 }.apply(col("x") * 2) as "f8" ) - df1.printSchema - df1.explain - df1.show + // df1.printSchema + // df1.explain + // df1.show checkAnswer(df1, Seq( Row(false, 2, "a!", Row(2, "a"), Row(2, "a"), 2, 2, 2, 3), Row(false, 4, "b!", Row(3, "b"), Row(3, "b"), 3, 3, 3, 5), @@ -65,9 +65,9 @@ class TypedUDFSuite extends QueryTest with SharedSQLContext { typedUdf[JInt, JInt]{ z => if (z != null) z + 1 else null }.apply('z) as "f7", typedUdf{ x: Int => x + 1 }.apply(col("key") * 2) as "f8" ) - df2.printSchema - df2.explain - df2.show + // df2.printSchema + // df2.explain + // df2.show checkAnswer(df2, Seq( Row(false, 2, "1!", Row(2, "1"), Row(2, "1"), 2, 2, 2, 3), Row(false, 4, "2!", Row(3, "2"), Row(3, "2"), 3, 3, 3, 5),
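
Usage sketch (commentary, not part of the patch series above): a minimal example of how the typedUdf API these patches introduce would be called once applied. The SparkSession value `spark` and the DataFrame below are assumptions for illustration only; the case class `Blah` and the lambdas mirror the ones exercised in TypedUDFSuite.

    import org.apache.spark.sql.functions.typedUdf
    import spark.implicits._  // encoders plus the Symbol-to-Column conversion

    case class Blah(a: Int, b: String)

    // hypothetical input data, shaped like the df1 fixture in TypedUDFSuite
    val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("x", "y")

    // a single input column is decoded through Encoder[Int]
    val doubled = df.select(typedUdf { x: Int => x * 2 }.apply('x) as "x2")

    // multiple input columns are wrapped in a struct by
    // TypedUserDefinedFunction.apply and decoded as a case class
    val bumped = df.select(
      typedUdf { b: Blah => b.copy(a = b.a + 1) }.apply('x, 'y) as "bumped")

    // Option[T] arguments give null-safe handling of nullable columns
    val safe = df.select(
      typedUdf { z: Option[Int] => z.map(_ + 1) }.apply('x) as "zplus")

Unlike the existing udf(), which is limited to types ScalaReflection understands, this routes both the input and the result through ExpressionEncoder, which is why case-class and Option arguments work without registering anything.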