From 9f9e0aed983ebca607044f6b8b41d8488ddbc963 Mon Sep 17 00:00:00 2001
From: Xpray
Date: Tue, 26 Dec 2017 17:40:14 +0800
Subject: [PATCH] [FLINK-8312] [Table API & SQL] Fix ScalarFunction varargs
 length exceeding 254

---
 .../functions/utils/ScalarSqlFunction.scala  | 27 +++++++++++++++++----------
 .../table/runtime/stream/sql/SqlITCase.scala | 27 +++++++++++++++++++++++++++
 2 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
index 27e093dc3fbaa..cbe2ac77748a5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/utils/ScalarSqlFunction.scala
@@ -140,7 +140,7 @@
       scalarFunction: ScalarFunction)
     : SqlOperandTypeChecker = {
-    val signatures = getMethodSignatures(scalarFunction, "eval")
+    val methods = checkAndExtractMethods(scalarFunction, "eval")
 
     /**
       * Operand type checker based on [[ScalarFunction]] given information.
       */
@@ -151,17 +151,24 @@
       }
 
       override def getOperandCountRange: SqlOperandCountRange = {
-        var min = 255
+        var min = 254 // according to JVM spec 4.3.3
         var max = -1
-        signatures.foreach( sig => {
-          var len = sig.length
-          if (len > 0 && sig(sig.length - 1).isArray) {
-            max = 254 // according to JVM spec 4.3.3
-            len = sig.length - 1
+        var isVarargs = false
+        methods.foreach(
+          m => {
+            var len = m.getParameterTypes.length
+            if (len > 0 && m.isVarArgs && m.getParameterTypes()(len - 1).isArray) {
+              isVarargs = true
+              len = len - 1
+            }
+            max = Math.max(len, max)
+            min = Math.min(len, min)
           }
-          max = Math.max(len, max)
-          min = Math.min(len, min)
-        })
+        )
+        if (isVarargs) {
+          // if the eval method is varargs, set max to -1 to skip the length check in Calcite
+          max = -1
+        }
         SqlOperandCountRanges.between(min, max)
       }
 
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
index 18b45a36ea26b..76d01262ac3fb 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/stream/sql/SqlITCase.scala
@@ -26,10 +26,12 @@
 import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
 import org.apache.flink.table.api.{TableEnvironment, Types}
 import org.apache.flink.table.api.scala._
 import org.apache.flink.table.expressions.utils.SplitUDF
+import org.apache.flink.table.expressions.utils.Func15
 import org.apache.flink.table.runtime.utils.TimeTestUtil.EventTimeSourceFunction
 import org.apache.flink.table.runtime.utils.{StreamITCase, StreamTestData, StreamingWithStateTestBase}
 import org.apache.flink.types.Row
 import org.apache.flink.table.utils.MemoryTableSinkUtil
+
 import org.junit.Assert._
 import org.junit._
@@ -516,4 +518,29 @@ class SqlITCase extends StreamingWithStateTestBase {
     val expected = List("a,a,d,d,e,e", "x,x,z,z,z,z")
     assertEquals(expected.sorted, StreamITCase.testResults.sorted)
   }
+
+  @Test
+  def testUDFWithLongVarargs(): Unit = {
+    val env = StreamExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env)
+    StreamITCase.clear
+
+    tEnv.registerFunction("func15", Func15)
+
+    val parameters = "c," + (0 until 255).map(_ => "a").mkString(",")
+    val sqlQuery = s"SELECT func15($parameters) FROM T1"
+
+    val t1 = StreamTestData.getSmall3TupleDataStream(env).toTable(tEnv).as('a, 'b, 'c)
+    tEnv.registerTable("T1", t1)
+
+    val result = tEnv.sqlQuery(sqlQuery).toAppendStream[Row]
+    result.addSink(new StreamITCase.StringSink[Row])
+    env.execute()
+
+    val expected = List(
+      "Hi255",
+      "Hello255",
+      "Hello world255")
+    assertEquals(expected.sorted, StreamITCase.testResults.sorted)
+  }
 }
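
A note on the 254/255 constants in the checker above: JVM spec 4.3.3 caps a method descriptor at 255 parameter slots, and an instance method such as eval spends one slot on "this", so a fixed-arity eval can declare at most 254 parameters; that is why min starts at 254 as the highest possible count. A varargs eval, by contrast, occupies a single array slot at the JVM level no matter how many arguments the SQL call passes, which is why the checker reports max = -1 (unlimited) as soon as any eval overload is varargs. The standalone sketch below shows the shape the reflective check sees; the ConcatCount UDF and demo object are illustrative only and are not part of the patch:

    import scala.annotation.varargs
    import org.apache.flink.table.functions.ScalarFunction

    // Illustrative varargs UDF. @varargs makes scalac emit a Java-style
    // eval(String, int[]) overload next to the Scala eval(String, Seq[Int]);
    // the array-typed overload is what the reflective extraction inspects.
    class ConcatCount extends ScalarFunction {
      @varargs
      def eval(prefix: String, values: Int*): String = prefix + values.length
    }

    object VarargsReflectionDemo {
      def main(args: Array[String]): Unit = {
        // The varargs overload has two JVM parameter slots (String, int[])
        // regardless of the call arity, so only fixed-arity eval methods are
        // constrained by the 254-slot limit of JVM spec 4.3.3.
        val evalMethod = classOf[ConcatCount].getMethods
          .filter(_.getName == "eval")
          .find(m => m.isVarArgs && m.getParameterTypes.last.isArray)
          .get
        println(evalMethod.getParameterTypes.mkString(", "))
        // prints: class java.lang.String, class [I
      }
    }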
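
Func15 itself is not part of this patch; it comes from the pre-existing test utilities in org.apache.flink.table.expressions.utils. Judging from the expected results, where each input string comes back with "255" (the number of trailing arguments) appended, its shape is presumably along the lines of this hypothetical reconstruction:

    import scala.annotation.varargs
    import org.apache.flink.table.functions.ScalarFunction

    // Hypothetical reconstruction, not the actual source: the helper appears
    // to append the vararg count to its first argument, e.g.
    // eval("Hi", <255 ints>) returning "Hi255".
    object Func15 extends ScalarFunction {
      @varargs
      def eval(a: String, b: Int*): String = a + b.length
    }

In the test, parameters expands to column c followed by 255 copies of column a, so each row of the small 3-tuple input yields its string field with "255" appended, matching the expected list.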