Skip to content

Commit

Permalink
[SPARK-23584][SQL] NewInstance should support interpreted execution
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
This PR adds support for interpreted execution of `NewInstance`.

## How was this patch tested?
Added tests in `ObjectExpressionsSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #20778 from maropu/SPARK-23584.
  • Loading branch information
maropu authored and hvanhovell committed Apr 19, 2018
1 parent 46bb2b5 commit 1b08c43
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

package org.apache.spark.sql.catalyst

import java.lang.reflect.Constructor

import org.apache.commons.lang3.reflect.ConstructorUtils

import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
Expand Down Expand Up @@ -781,6 +785,15 @@ object ScalaReflection extends ScalaReflection {
}
}

/**
 * Looks up an accessible constructor of `cls` whose parameters are assignment-compatible with
 * `paramTypes`. The lookup is delegated to
 * `ConstructorUtils.getMatchingAccessibleConstructor`, which is more lenient than the exact
 * matching performed by `Class.getConstructor`.
 *
 * @param cls the class whose constructors are searched
 * @param paramTypes the classes of the arguments that will be passed to the constructor
 * @return the matching constructor wrapped in `Some`, or `None` when the lookup yields `null`
 */
def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = {
  val matched: Constructor[_] =
    ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)
  // `Option.apply` maps a `null` result (no match) to `None`.
  Option(matched)
}

/**
* Whether the fields of the given type are defined entirely by its constructor parameters.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,32 @@ case class NewInstance(
childrenResolved && !needOuterPointer
}

override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
/**
 * Reflective factory used by the interpreted `eval` path. The constructor of `cls` is resolved
 * once, on first use, and the returned function instantiates `cls` from already-evaluated
 * argument values.
 *
 * When `outerPointer` is defined, the object it produces is passed as the first constructor
 * argument (and its class is prepended to the parameter types used for the lookup).
 *
 * Fails via `sys.error` when no compatible constructor is found.
 */
@transient private lazy val constructor: (Seq[AnyRef]) => Any = {
  // NOTE(review): presumably the Java classes of the argument expressions' results -
  // confirm against ScalaReflection.expressionJavaClasses.
  val paramTypes = ScalaReflection.expressionJavaClasses(arguments)
  val getConstructor = (paramClazz: Seq[Class[_]]) => {
    ScalaReflection.findConstructor(cls, paramClazz).getOrElse {
      sys.error(s"Couldn't find a valid constructor on $cls")
    }
  }
  outerPointer.map { p =>
    // The outer instance is obtained once here and captured by the returned function.
    val outerObj = p()
    val c = getConstructor(outerObj.getClass +: paramTypes)
    (args: Seq[AnyRef]) => {
      c.newInstance(outerObj +: args: _*)
    }
  }.getOrElse {
    val c = getConstructor(paramTypes)
    (args: Seq[AnyRef]) => {
      c.newInstance(args: _*)
    }
  }
}

/**
 * Interpreted evaluation: evaluates every argument expression against `input` and passes the
 * boxed results to the reflectively resolved constructor.
 *
 * When `propagateNull` is set and any argument evaluates to `null`, `null` is returned without
 * invoking the constructor.
 * NOTE(review): this is intended to match the null handling of the code-generated path -
 * confirm against `doGenCode`.
 */
override def eval(input: InternalRow): Any = {
  val argValues = arguments.map(_.eval(input).asInstanceOf[AnyRef])
  if (propagateNull && argValues.exists(_ == null)) {
    null
  } else {
    constructor(argValues)
  }
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = CodeGenerator.javaType(dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,20 @@ class InvokeTargetSubClass extends InvokeTargetClass {
// Overrides `binOp` to return the difference `e1 - e2` as a Double.
override def binOp(e1: Int, e2: Double): Double = e1.toDouble - e2
}

// Tests for NewInstance
/**
 * Fixture with a nested (inner) class, used to exercise `NewInstance` with an outer pointer.
 */
class Outer extends Serializable {
  /**
   * Inner class whose instances are compared by `value`.
   *
   * @param value the wrapped integer that drives equality and hashing
   */
  class Inner(val value: Int) {
    // Derived from `value` so that instances that are `equals` also share a hash code.
    override def hashCode(): Int = value.hashCode()

    // Matching on `Outer#Inner` keeps the original semantics: any `Inner` is comparable,
    // regardless of which `Outer` instance it belongs to.
    override def equals(other: Any): Boolean = other match {
      case that: Outer#Inner => value == that.value
      case _ => false
    }
  }
}

class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {

test("SPARK-16622: The returned value of the called method in Invoke can be null") {
Expand Down Expand Up @@ -383,6 +397,27 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("SPARK-23584 NewInstance should support interpreted execution") {
  // Top-level class: the only constructor argument is the literal list.
  val elements = List(1, 2, 3)
  val plainInstance = NewInstance(
    cls = classOf[GenericArrayData],
    arguments = Literal.fromObject(elements) :: Nil,
    propagateNull = false,
    dataType = ArrayType(IntegerType),
    outerPointer = None)
  checkObjectExprEvaluation(plainInstance, new GenericArrayData(elements))

  // Inner class: the enclosing instance is supplied through `outerPointer`.
  val enclosing = new Outer()
  val innerCls = classOf[enclosing.Inner]
  val innerInstance = NewInstance(
    cls = innerCls,
    arguments = Literal(1) :: Nil,
    propagateNull = false,
    dataType = ObjectType(innerCls),
    outerPointer = Some(() => enclosing))
  checkObjectExprEvaluation(innerInstance, new enclosing.Inner(1))
}

test("LambdaVariable should support interpreted execution") {
def genSchema(dt: DataType): Seq[StructType] = {
Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil),
Expand Down Expand Up @@ -421,6 +456,7 @@ class TestBean extends Serializable {
private var x: Int = 0

// Stores `i` in the private field `x`.
def setX(i: Int): Unit = {
  x = i
}

// Setter that only asserts its argument is non-null; nothing is stored.
def setNonPrimitive(i: AnyRef): Unit = {
  assert(i != null, "this setter should not be called with null.")
}
}

0 comments on commit 1b08c43

Please sign in to comment.