diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index e4274aaa9727e..818cc2fb1e8a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.catalyst +import java.lang.reflect.Constructor + +import org.apache.commons.lang3.reflect.ConstructorUtils + import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ @@ -781,6 +785,15 @@ object ScalaReflection extends ScalaReflection { } } + /** + * Finds an accessible constructor with compatible parameters. This is a more flexible search + * than the exact matching algorithm in `Class.getConstructor`. The first assignment-compatible + * matching constructor is returned. Otherwise, it returns `None`. + */ + def findConstructor(cls: Class[_], paramTypes: Seq[Class[_]]): Option[Constructor[_]] = { + Option(ConstructorUtils.getMatchingAccessibleConstructor(cls, paramTypes: _*)) + } + /** * Whether the fields of the given type is defined entirely by its constructor parameters. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 72b202b3a5020..1645bd7d57b1d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -449,8 +449,32 @@ case class NewInstance( childrenResolved && !needOuterPointer } - override def eval(input: InternalRow): Any = - throw new UnsupportedOperationException("Only code-generated evaluation is supported.") + @transient private lazy val constructor: (Seq[AnyRef]) => Any = { + val paramTypes = ScalaReflection.expressionJavaClasses(arguments) + val getConstructor = (paramClazz: Seq[Class[_]]) => { + ScalaReflection.findConstructor(cls, paramClazz).getOrElse { + sys.error(s"Couldn't find a valid constructor on $cls") + } + } + outerPointer.map { p => + val outerObj = p() + val d = outerObj.getClass +: paramTypes + val c = getConstructor(outerObj.getClass +: paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(outerObj +: args: _*) + } + }.getOrElse { + val c = getConstructor(paramTypes) + (args: Seq[AnyRef]) => { + c.newInstance(args: _*) + } + } + } + + override def eval(input: InternalRow): Any = { + val argValues = arguments.map(_.eval(input)) + constructor(argValues.map(_.asInstanceOf[AnyRef])) + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = CodeGenerator.javaType(dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index b0188b0098def..bf805f4f29ac5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -47,6 +47,20 @@ class InvokeTargetSubClass extends InvokeTargetClass { override def binOp(e1: Int, e2: Double): Double = e1 - e2 } +// Tests for NewInstance +class Outer extends Serializable { + class Inner(val value: Int) { + override def hashCode(): Int = super.hashCode() + override def equals(other: Any): Boolean = { + if (other.isInstanceOf[Inner]) { + value == other.asInstanceOf[Inner].value + } else { + false + } + } + } +} + class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-16622: The returned value of the called method in Invoke can be null") { @@ -383,6 +397,27 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { } } + test("SPARK-23584 NewInstance should support interpreted execution") { + // Normal case test + val newInst1 = NewInstance( + cls = classOf[GenericArrayData], + arguments = Literal.fromObject(List(1, 2, 3)) :: Nil, + propagateNull = false, + dataType = ArrayType(IntegerType), + outerPointer = None) + checkObjectExprEvaluation(newInst1, new GenericArrayData(List(1, 2, 3))) + + // Inner class case test + val outerObj = new Outer() + val newInst2 = NewInstance( + cls = classOf[outerObj.Inner], + arguments = Literal(1) :: Nil, + propagateNull = false, + dataType = ObjectType(classOf[outerObj.Inner]), + outerPointer = Some(() => outerObj)) + checkObjectExprEvaluation(newInst2, new outerObj.Inner(1)) + } + test("LambdaVariable should support interpreted execution") { def genSchema(dt: DataType): Seq[StructType] = { Seq(StructType(StructField("col_1", dt, nullable = false) :: Nil), @@ -421,6 +456,7 @@ class TestBean extends Serializable { private var x: Int = 0 def setX(i: Int): Unit = x = i + def setNonPrimitive(i: AnyRef): Unit = assert(i != null, "this setter should not be called with null.") }