Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-35288][SQL] StaticInvoke should find the method without exact argument classes match #32413

Closed
wants to merge 10 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,34 @@ trait InvokeLike extends Expression with NonSQLExpression {
}
}
}

final def findMethod(cls: Class[_], functionName: String, argClasses: Seq[Class[_]]): Method = {
// Looking with function name + argument classes first.
try {
cls.getMethod(functionName, argClasses: _*)
} catch {
case _: NoSuchMethodException =>
// For some cases, e.g. arg class is Object, `getMethod` cannot find the method.
// We look at function name + argument length
val m = cls.getMethods.filter { m =>
m.getName == functionName && m.getParameterCount == arguments.length
}
if (m.isEmpty) {
sys.error(s"Couldn't find $functionName on $cls")
} else if (m.length > 1) {
// More than one matched method signature. Exclude synthetic one, e.g. generic one.
val realMethods = m.filter(!_.isSynthetic)
if (realMethods.length > 1) {
// Ambiguous case, we don't know which method to choose, just fail it.
sys.error(s"Found ${realMethods.length} $functionName on $cls")
} else {
realMethods.head
}
} else {
m.head
}
}
}
}

/**
Expand Down Expand Up @@ -236,7 +264,7 @@ case class StaticInvoke(
override def children: Seq[Expression] = arguments

lazy val argClasses = ScalaReflection.expressionJavaClasses(arguments)
@transient lazy val method = cls.getDeclaredMethod(functionName, argClasses : _*)
@transient lazy val method = findMethod(cls, functionName, argClasses)

override def eval(input: InternalRow): Any = {
invoke(null, method, arguments, input, dataType)
Expand Down Expand Up @@ -326,31 +354,7 @@ case class Invoke(

@transient lazy val method = targetObject.dataType match {
case ObjectType(cls) =>
// Looking with function name + argument classes first.
try {
Some(cls.getMethod(encodedFunctionName, argClasses: _*))
} catch {
case _: NoSuchMethodException =>
// For some cases, e.g. arg class is Object, `getMethod` cannot find the method.
// We look at function name + argument length
val m = cls.getMethods.filter { m =>
m.getName == encodedFunctionName && m.getParameterCount == arguments.length
}
if (m.isEmpty) {
sys.error(s"Couldn't find $encodedFunctionName on $cls")
} else if (m.length > 1) {
// More than one matched method signature. Exclude synthetic one, e.g. generic one.
val realMethods = m.filter(!_.isSynthetic)
if (realMethods.length > 1) {
// Ambiguous case, we don't know which method to choose, just fail it.
sys.error(s"Found ${realMethods.length} $encodedFunctionName on $cls")
} else {
Some(realMethods.head)
}
} else {
Some(m.head)
}
}
Some(findMethod(cls, encodedFunctionName, argClasses))
case _ => None
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -638,8 +638,22 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val clsType = ObjectType(classOf[ConcreteClass])
val obj = new ConcreteClass

val input = (1, 2)
checkObjectExprEvaluation(
Invoke(Literal(obj, clsType), "testFunc", IntegerType, Seq(Literal(1))), 0)
Invoke(Literal(obj, clsType), "testFunc", IntegerType,
Seq(Literal(input, ObjectType(input.getClass)))), 2)
Comment on lines +641 to +644
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fix this test from SPARK-35278. Original one doesn't produce synthetic method. I may miss it when I changed the code.

}

test("SPARK-35288: static invoke should find method without exact param type match") {
val input = (1, 2)

checkObjectExprEvaluation(
StaticInvoke(TestStaticInvoke.getClass, IntegerType, "func",
Seq(Literal(input, ObjectType(input.getClass)))), 3)

checkObjectExprEvaluation(
StaticInvoke(TestStaticInvoke.getClass, IntegerType, "func",
Seq(Literal(1, IntegerType))), -1)
}
}

Expand All @@ -652,10 +666,22 @@ class TestBean extends Serializable {
assert(i != null, "this setter should not be called with null.")
}

object TestStaticInvoke {
def func(param: Any): Int = param match {
case pair: Tuple2[_, _] =>
pair.asInstanceOf[Tuple2[Int, Int]]._1 + pair.asInstanceOf[Tuple2[Int, Int]]._2
case _ => -1
}
}

abstract class BaseClass[T] {
def testFunc(param: T): T
def testFunc(param: T): Int
}

class ConcreteClass extends BaseClass[Int] with Serializable {
override def testFunc(param: Int): Int = param - 1
class ConcreteClass extends BaseClass[Product] with Serializable {
override def testFunc(param: Product): Int = param match {
case _: Tuple2[_, _] => 2
case _: Tuple3[_, _, _] => 3
case _ => 4
}
}