From c9b01c5449e872a2686b7d6d40374322e9da2233 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BE=99=E4=B8=89?= Date: Mon, 7 Jun 2021 18:29:40 +0800 Subject: [PATCH] [FLINK-22454] ignore casting on interoperable type when extracting lookup keys --- .../common/CommonPhysicalLookupJoin.scala | 7 ++- .../plan/stream/sql/join/LookupJoinTest.scala | 48 +++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/common/CommonPhysicalLookupJoin.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/common/CommonPhysicalLookupJoin.scala index bea30e113bcdc..e8f468fe95750 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/common/CommonPhysicalLookupJoin.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/nodes/physical/common/CommonPhysicalLookupJoin.scala @@ -28,6 +28,7 @@ import org.apache.flink.table.planner.plan.utils.LookupJoinUtil._ import org.apache.flink.table.planner.plan.utils.PythonUtil.containsPythonCall import org.apache.flink.table.planner.plan.utils.RelExplainUtil.preferExpressionFormat import org.apache.flink.table.planner.plan.utils.{JoinTypeUtil, LookupJoinUtil, RelExplainUtil} +import org.apache.flink.table.runtime.types.PlannerTypeUtils import org.apache.calcite.plan.{RelOptCluster, RelOptTable, RelTraitSet} import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField} @@ -295,7 +296,11 @@ abstract class CommonPhysicalLookupJoin( expr = call.getOperands.get(0) case call: RexCall if call.getOperator == SqlStdOperatorTable.CAST => // drill through identity function - expr = call.getOperands.get(0) + val outputType = call.getType + val inputType = call.getOperands.get(0).getType + val isCompatible = PlannerTypeUtils.isInteroperable( + FlinkTypeFactory.toLogicalType(outputType), FlinkTypeFactory.toLogicalType(inputType)) + expr = if (isCompatible) call.getOperands.get(0) else expr case _ => } expr match { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/LookupJoinTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/LookupJoinTest.scala index 4d417456b903a..b7a24ea384269 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/LookupJoinTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/stream/sql/join/LookupJoinTest.scala @@ -474,6 +474,54 @@ class LookupJoinTest(legacyTableSource: Boolean) extends TableTestBase with Seri util.verifyExecPlan(sql) } + @Test + def testJoinTemporalTableWithCastOnLookupTable(): Unit = { + util.addTable( + """ + |CREATE TABLE LookupTable2 ( + | `id` decimal(38, 18), + | `name` STRING, + | `age` INT + |) WITH ( + | 'connector' = 'values' + |) + |""".stripMargin) + val sql = + """ + |SELECT MyTable.b, LookupTable2.id + |FROM MyTable + |LEFT JOIN LookupTable2 FOR SYSTEM_TIME AS OF MyTable.`proctime` + |ON MyTable.a = CAST(LookupTable2.`id` as INT) + |""".stripMargin + thrown.expect(classOf[TableException]) + thrown.expectMessage("Temporal table join requires an equality condition on fields of " + + "table [default_catalog.default_database.LookupTable2]") + verifyTranslationSuccess(sql) + } + + @Test + def testJoinTemporalTableWithInteroperableCastOnLookupTable(): Unit = { + util.addTable( + """ + |CREATE TABLE LookupTable2 ( + | `id` INT, + | `name` char(10), + | `age` INT + |) WITH ( + | 'connector' = 'values' + |) + |""".stripMargin) + + val sql = + """ + |SELECT MyTable.b, LookupTable2.id + |FROM MyTable + |LEFT JOIN LookupTable2 FOR SYSTEM_TIME AS OF MyTable.`proctime` + |ON MyTable.b = CAST(LookupTable2.`name` as String) + |""".stripMargin + verifyTranslationSuccess(sql) + } + // ========================================================================================== private def createLookupTable(tableName: String, lookupFunction: UserDefinedFunction): Unit = {