diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala index b7d1a4bfb60c6..6c8fa90e867c3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala @@ -20,6 +20,7 @@ package org.apache.flink.table.plan.nodes.dataset import org.apache.calcite.plan._ import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.JoinRelType import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{BiRel, RelNode, RelWriter} import org.apache.calcite.rex.RexNode @@ -47,6 +48,7 @@ class DataSetSingleRowJoin( rowRelDataType: RelDataType, joinCondition: RexNode, joinRowType: RelDataType, + joinType: JoinRelType, ruleDescription: String) extends BiRel(cluster, traitSet, leftNode, rightNode) with DataSetRel { @@ -63,6 +65,7 @@ class DataSetSingleRowJoin( getRowType, joinCondition, joinRowType, + joinType, ruleDescription) } @@ -97,7 +100,6 @@ class DataSetSingleRowJoin( tableEnv.getConfig, leftDataSet.getType, rightDataSet.getType, - leftIsSingle, joinCondition, broadcastSetName) @@ -118,14 +120,18 @@ class DataSetSingleRowJoin( config: TableConfig, inputType1: TypeInformation[Row], inputType2: TypeInformation[Row], - firstIsSingle: Boolean, joinCondition: RexNode, broadcastInputSetName: String) : FlatMapFunction[Row, Row] = { + val isOuterJoin = joinType match { + case JoinRelType.LEFT | JoinRelType.RIGHT => true + case _ => false + } + val codeGenerator = new CodeGenerator( config, - false, + isOuterJoin, inputType1, Some(inputType2)) @@ -138,13 +144,38 @@ class DataSetSingleRowJoin( val condition = codeGenerator.generateExpression(joinCondition) val joinMethodBody = - s""" - |${condition.code} - |if (${condition.resultTerm}) { - | ${conversion.code} - | ${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); - |} - |""".stripMargin + if (joinType == JoinRelType.INNER) { + s""" + |${condition.code} + |if (${condition.resultTerm}) { + | ${conversion.code} + | ${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); + |} + |""".stripMargin + } else { + val singleNode = + if (!leftIsSingle) { + rightNode + } + else { + leftNode + } + + val notSuitedToCondition = singleNode + .getRowType + .getFieldList + .map(field => getRowType.getFieldNames.indexOf(field.getName)) + .map(i => s"${conversion.resultTerm}.setField($i,null);") + + s""" + |${condition.code} + |${conversion.code} + |if(!${condition.resultTerm}){ + |${notSuitedToCondition.mkString("\n")} + |} + |${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); + |""".stripMargin + } val genFunction = codeGenerator.generateFunction( ruleDescription, @@ -152,16 +183,18 @@ class DataSetSingleRowJoin( joinMethodBody, returnType) - if (firstIsSingle) { - new MapJoinRightRunner[Row, Row, Row]( + if (!leftIsSingle) { + new MapJoinLeftRunner[Row, Row, Row]( genFunction.name, genFunction.code, + isOuterJoin, genFunction.returnType, broadcastInputSetName) } else { - new MapJoinLeftRunner[Row, Row, Row]( + new MapJoinRightRunner[Row, Row, Row]( genFunction.name, genFunction.code, + isOuterJoin, genFunction.returnType, broadcastInputSetName) } @@ -181,7 +214,7 @@ class DataSetSingleRowJoin( } private def joinTypeToString: String = { - "NestedLoopJoin" + "NestedLoop" + joinType.toString.toLowerCase.capitalize + "Join" } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala index 8df0b59c828c6..18621d0b296f2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala @@ -77,7 +77,9 @@ private class FlinkLogicalJoinConverter val join: LogicalJoin = call.rel(0).asInstanceOf[LogicalJoin] val joinInfo = join.analyzeCondition - hasEqualityPredicates(join, joinInfo) || isSingleRowInnerJoin(join) + (hasEqualityPredicates(join, joinInfo) + || isSingleRowInnerJoin(join) + || isOuterJoinWithSingleRowAtOuterSide(join, joinInfo)) } override def convert(rel: RelNode): RelNode = { @@ -101,6 +103,18 @@ private class FlinkLogicalJoinConverter !joinInfo.pairs().isEmpty && (joinInfo.isEqui || join.getJoinType == JoinRelType.INNER) } + + + private def isOuterJoinWithSingleRowAtOuterSide( + join: LogicalJoin, + joinInfo: JoinInfo): Boolean = { + + val isLeflSingleOrEmpty = joinInfo.leftKeys.size() < 2 + val isRightSingleOrEmpty = joinInfo.rightKeys.size() < 2 + ((join.getJoinType == JoinRelType.RIGHT && isLeflSingleOrEmpty) + || (join.getJoinType == JoinRelType.LEFT && isRightSingleOrEmpty)) + } + private def isSingleRowInnerJoin(join: LogicalJoin): Boolean = { if (join.getJoinType == JoinRelType.INNER) { isSingleRow(join.getRight) || isSingleRow(join.getLeft) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala index b61573c53965d..f964a95f3d521 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala @@ -37,10 +37,11 @@ class DataSetSingleRowJoinRule override def matches(call: RelOptRuleCall): Boolean = { val join = call.rel(0).asInstanceOf[FlinkLogicalJoin] - if (isInnerJoin(join)) { - isSingleRow(join.getRight) || isSingleRow(join.getLeft) - } else { - false + join.getJoinType match { + case JoinRelType.INNER if isSingleRow(join.getLeft) || isSingleRow(join.getRight) => true + case JoinRelType.LEFT if isSingleRow(join.getRight) => true + case JoinRelType.RIGHT if isSingleRow(join.getLeft) => true + case _ => false } } @@ -79,6 +80,7 @@ class DataSetSingleRowJoinRule rel.getRowType, join.getCondition, join.getRowType, + join.getJoinType, description) } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala index 5f3dbb4cbaf98..f2812ae381367 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala @@ -19,11 +19,13 @@ package org.apache.flink.table.runtime import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.types.Row import org.apache.flink.util.Collector class MapJoinLeftRunner[IN1, IN2, OUT]( name: String, code: String, + outerJoin: Boolean, returnType: TypeInformation[OUT], broadcastSetName: String) extends MapSideJoinRunner[IN1, IN2, IN2, IN1, OUT](name, code, returnType, broadcastSetName) { @@ -31,7 +33,10 @@ class MapJoinLeftRunner[IN1, IN2, OUT]( override def flatMap(multiInput: IN1, out: Collector[OUT]): Unit = { broadcastSet match { case Some(singleInput) => function.join(multiInput, singleInput, out) + case None if outerJoin => function. + join(multiInput, null.asInstanceOf[IN2], out) case None => } } + } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala index e2d9331187623..2e31008b058f9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala @@ -19,11 +19,13 @@ package org.apache.flink.table.runtime import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.types.Row import org.apache.flink.util.Collector class MapJoinRightRunner[IN1, IN2, OUT]( name: String, code: String, + outerJoin: Boolean, returnType: TypeInformation[OUT], broadcastSetName: String) extends MapSideJoinRunner[IN1, IN2, IN1, IN2, OUT](name, code, returnType, broadcastSetName) { @@ -31,7 +33,10 @@ class MapJoinRightRunner[IN1, IN2, OUT]( override def flatMap(multiInput: IN2, out: Collector[OUT]): Unit = { broadcastSet match { case Some(singleInput) => function.join(singleInput, multiInput, out) + case None if outerJoin => function. + join(null.asInstanceOf[IN1], multiInput, out) case None => } } + } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala new file mode 100644 index 0000000000000..37546975f2201 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala @@ -0,0 +1,424 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.scala.batch.sql + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.utils.TableTestBase +import org.apache.flink.table.utils.TableTestUtil._ +import org.junit.Test + +class DataSetSingleRowJoinTest extends TableTestBase { + + @Test + def testSingleRowJoinWithCalcInput(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Int)]("A", 'a1, 'a2) + + val query = + "SELECT a1, asum " + + "FROM A, (SELECT sum(a1) + sum(a2) AS asum FROM A)" + + val expected = + binaryNode( + "DataSetSingleRowJoin", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a1") + ), + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + batchTableNode(0), + tuples(List(null, null)), + term("values", "a1", "a2") + ), + term("union","a1","a2") + ), + term("select", "SUM(a1) AS $f0", "SUM(a2) AS $f1") + ), + term("select", "+($f0, $f1) AS asum") + ), + term("where", "true"), + term("join", "a1", "asum"), + term("joinType", "NestedLoopInnerJoin") + ) + + util.verifySql(query, expected) + } + + @Test + def testSingleRowEquiJoin(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, String)]("A", 'a1, 'a2) + + val query = + "SELECT a1, a2 " + + "FROM A, (SELECT count(a1) AS cnt FROM A) " + + "WHERE a1 = cnt" + + val expected = + unaryNode( + "DataSetCalc", + binaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a1") + ), + tuples(List(null)), + term("values", "a1") + ), + term("union","a1") + ), + term("select", "COUNT(a1) AS cnt") + ), + term("where", "=(CAST(a1), cnt)"), + term("join", "a1", "a2", "cnt"), + term("joinType", "NestedLoopInnerJoin") + ), + term("select", "a1", "a2") + ) + + util.verifySql(query, expected) + } + + @Test + def testSingleRowNotEquiJoin(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, String)]("A", 'a1, 'a2) + + val query = + "SELECT a1, a2 " + + "FROM A, (SELECT count(a1) AS cnt FROM A) " + + "WHERE a1 < cnt" + + val expected = + unaryNode( + "DataSetCalc", + binaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a1") + ), + tuples(List(null)), + term("values", "a1") + ), + term("union", "a1") + ), + term("select", "COUNT(a1) AS cnt") + ), + term("where", "<(a1, cnt)"), + term("join", "a1", "a2", "cnt"), + term("joinType", "NestedLoopInnerJoin") + ), + term("select", "a1", "a2") + ) + + util.verifySql(query, expected) + } + + @Test + def testSingleRowJoinWithComplexPredicate(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Long)]("A", 'a1, 'a2) + util.addTable[(Int, Long)]("B", 'b1, 'b2) + + val query = + "SELECT a1, a2, b1, b2 " + + "FROM A, (SELECT min(b1) AS b1, max(b2) AS b2 FROM B) " + + "WHERE a1 < b1 AND a2 = b2" + + val expected = binaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + batchTableNode(1), + tuples(List(null, null)), + term("values", "b1", "b2") + ), + term("union","b1","b2") + ), + term("select", "MIN(b1) AS b1", "MAX(b2) AS b2") + ), + term("where", "AND(<(a1, b1)", "=(a2, b2))"), + term("join", "a1", "a2", "b1", "b2"), + term("joinType", "NestedLoopInnerJoin") + ) + + util.verifySql(query, expected) + } + + @Test + def testRightSingleLeftJoinEqualPredicate(): Unit = { + val util = batchTestUtil() + util.addTable[(Long, Int)]("A", 'a1, 'a2) + util.addTable[(Int, Int)]("B", 'b1, 'b2) + + val queryLeftJoin = + "SELECT a2 FROM A " + + "LEFT JOIN " + + "(SELECT COUNT(*) AS cnt FROM B) " + + "AS x " + + "ON a1 = cnt" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + term("where", "=(a1, cnt)"), + term("join", "a1", "a2", "cnt"), + term("joinType", "NestedLoopLeftJoin") + ), + term("select", "a2") + ) + "\n" + + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + tuples(List(null)), term("values", "$f0") + ), + term("union", "$f0") + ), + term("select", "COUNT(*) AS cnt") + ) + + util.verifySql(queryLeftJoin, expected) + } + + @Test + def testRightSingleLeftJoinNotEqualPredicate(): Unit = { + val util = batchTestUtil() + util.addTable[(Long, Int)]("A", 'a1, 'a2) + util.addTable[(Int, Int)]("B", 'b1, 'b2) + + val queryLeftJoin = + "SELECT a2 FROM A " + + "LEFT JOIN " + + "(SELECT COUNT(*) AS cnt FROM B) " + + "AS x " + + "ON a1 > cnt" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + term("where", ">(a1, cnt)"), + term("join", "a1", "a2", "cnt"), + term("joinType", "NestedLoopLeftJoin") + ), + term("select", "a2") + ) + "\n" + + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + tuples(List(null)), term("values", "$f0") + ), + term("union", "$f0") + ), + term("select", "COUNT(*) AS cnt") + ) + + util.verifySql(queryLeftJoin, expected) + } + + @Test + def testLeftSingleRightJoinNotEqualPredicate(): Unit = { + val util = batchTestUtil() + util.addTable[(Long, Long)]("A", 'a1, 'a2) + util.addTable[(Long, Long)]("B", 'b1, 'b2) + + val queryRightJoin = + "SELECT a1 FROM (SELECT COUNT(*) AS cnt FROM B) " + + "RIGHT JOIN A " + + "ON cnt < a2" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + "", + term("where", "<(cnt, a2)"), + term("join", "cnt", "a1", "a2"), + term("joinType", "NestedLoopRightJoin") + ), + term("select", "a1") + ) + + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + tuples(List(null)), term("values", "$f0") + ), + term("union", "$f0") + ), + term("select", "COUNT(*) AS cnt") + ) + "\n" + + batchTableNode(0) + + util.verifySql(queryRightJoin, expected) + } + + @Test + def testLeftSingleRightJoinEqualPredicate(): Unit = { + val util = batchTestUtil() + util.addTable[(Long, Long)]("A", 'a1, 'a2) + util.addTable[(Long, Long)]("B", 'b1, 'b2) + + val queryRightJoin = + "SELECT a1 FROM (SELECT COUNT(*) AS cnt FROM B) " + + "RIGHT JOIN A " + + "ON cnt = a2" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + "", + term("where", "=(cnt, a2)"), + term("join", "cnt", "a1", "a2"), + term("joinType", "NestedLoopRightJoin") + ), + term("select", "a1") + ) + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + tuples(List(null)), term("values", "$f0") + ), + term("union", "$f0") + ), + term("select", "COUNT(*) AS cnt") + ) + "\n" + + batchTableNode(0) + + util.verifySql(queryRightJoin, expected) + } + + @Test + def testSingleRowJoinInnerJoin(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Int)]("A", 'a1, 'a2) + val query = + "SELECT a2, sum(a1) " + + "FROM A " + + "GROUP BY a2 " + + "HAVING sum(a1) > (SELECT sum(a1) * 0.1 FROM A)" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + unaryNode( + "DataSetAggregate", + batchTableNode(0), + term("groupBy", "a2"), + term("select", "a2", "SUM(a1) AS EXPR$1") + ), + term("where", ">(EXPR$1, EXPR$0)"), + term("join", "a2", "EXPR$1", "EXPR$0"), + term("joinType", "NestedLoopInnerJoin") + ), + term("select", "a2", "EXPR$1") + ) + "\n" + + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a1") + ), + tuples(List(null)), term("values", "a1") + ), + term("union", "a1") + ), + term("select", "SUM(a1) AS $f0") + ), + term("select", "*($f0, 0.1) AS EXPR$0") + ) + + util.verifySql(query, expected) + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala index 54b4d24a767ae..85cfb18f521a5 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala @@ -211,7 +211,7 @@ class DistinctAggregateTest extends TableTestBase { ), term("where", "true"), term("join", "EXPR$0", "EXPR$1"), - term("joinType", "NestedLoopJoin") + term("joinType", "NestedLoopInnerJoin") ) util.verifySql(sqlQuery, expected) @@ -268,7 +268,7 @@ class DistinctAggregateTest extends TableTestBase { ), term("where", "true"), term("join", "EXPR$2, EXPR$0"), - term("joinType", "NestedLoopJoin") + term("joinType", "NestedLoopInnerJoin") ), unaryNode( "DataSetAggregate", @@ -294,7 +294,7 @@ class DistinctAggregateTest extends TableTestBase { ), term("where", "true"), term("join", "EXPR$2", "EXPR$0, EXPR$1"), - term("joinType", "NestedLoopJoin") + term("joinType", "NestedLoopInnerJoin") ), term("select", "EXPR$0, EXPR$1, EXPR$2") ) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala index 8a8c0ce82fc56..3ed44a32b54aa 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala @@ -40,10 +40,8 @@ class JoinITCase( @Test def testJoin(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) - val sqlQuery = "SELECT c, g FROM Table3, Table5 WHERE b = e" val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) @@ -280,8 +278,6 @@ class JoinITCase( tEnv.registerTable("Table3", ds1) tEnv.registerTable("Table5", ds2) - tEnv.sql(sqlQuery).toDataSet[Row].collect() - val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n" + "null,Hallo Welt wie\n" + "null,Hallo Welt wie gehts?\n" + "null,ABC\n" + "null,BCD\n" + "null,CDE\n" + "null,DEF\n" + "null,EFG\n" + "null,FGH\n" + "null,GHI\n" + "null,HIJ\n" + @@ -304,8 +300,6 @@ class JoinITCase( tEnv.registerTable("Table3", ds1) tEnv.registerTable("Table5", ds2) - tEnv.sql(sqlQuery).toDataSet[Row].collect() - val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n" + "null,Hallo Welt wie\n" + "null,Hallo Welt wie gehts?\n" + "null,ABC\n" + "null,BCD\n" + "null,CDE\n" + "null,DEF\n" + "null,EFG\n" + "null,FGH\n" + "null,GHI\n" + "null,HIJ\n" + @@ -372,10 +366,183 @@ class JoinITCase( val sqlQuery1 = "SELECT * FROM A CROSS JOIN (SELECT count(*) FROM A HAVING count(*) < 0)" val result = tEnv.sql(sqlQuery1).count() - Assert.assertEquals(0, result) } + @Test + def testLeftNullRightJoin(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt " + + "FROM (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) RIGHT JOIN A ON a < cnt" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("A", ds1) + tEnv.registerTable("B", ds2) + + + val result = tEnv.sql(sqlQuery) + val expected = Seq( + "1,null", + "2,null", "2,null", + "3,null", "3,null", "3,null", + "4,null", "4,null", "4,null", "4,null", + "5,null", "5,null", "5,null", "5,null", "5,null").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + + @Test + def testLeftSingleRightJoinEqualPredicate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt FROM (SELECT COUNT(*) AS cnt FROM B) RIGHT JOIN A ON cnt = a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("A", ds1) + tEnv.registerTable("B", ds2) + + val result = tEnv.sql(sqlQuery) + val expected = Seq( + "1,null", "2,null", "2,null", "3,3", "3,3", + "3,3", "4,null", "4,null", "4,null", + "4,null", "5,null", "5,null", "5,null", + "5,null", "5,null").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testLeftSingleRightJoinNotEqualPredicate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt FROM (SELECT COUNT(*) AS cnt FROM B) RIGHT JOIN A ON cnt > a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("A", ds1) + tEnv.registerTable("B", ds2) + + val result = tEnv.sql(sqlQuery) + val expected = Seq( + "1,3", "2,3", "2,3", "3,null", "3,null", + "3,null", "4,null", "4,null", "4,null", + "4,null", "5,null", "5,null", "5,null", + "5,null", "5,null").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testRightNullLeftJoin(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt " + + "FROM A LEFT JOIN (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) ON cnt > a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) + tEnv.registerTable("A", ds2) + tEnv.registerTable("B", ds1) + + val result = tEnv.sql(sqlQuery) + + val expected = Seq( + "2,null", "3,null", "1,null").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testRightSingleLeftJoinEqualPredicate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt FROM A LEFT JOIN (SELECT COUNT(*) AS cnt FROM B) ON cnt = a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("A", ds1) + tEnv.registerTable("B", ds2) + + val result = tEnv.sql(sqlQuery) + + val expected = Seq( + "1,null", "2,null", "2,null", "3,3", "3,3", + "3,3", "4,null", "4,null", "4,null", + "4,null", "5,null", "5,null", "5,null", + "5,null", "5,null").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testRightSingleLeftJoinNotEqualPredicate(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt FROM A LEFT JOIN (SELECT COUNT(*) AS cnt FROM B) ON cnt < a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("A", ds1) + tEnv.registerTable("B", ds2) + + val result = tEnv.sql(sqlQuery) + + val expected = Seq( + "1,null", "2,null", "2,null", "3,null", "3,null", + "3,null", "4,3", "4,3", "4,3", + "4,3", "5,3", "5,3", "5,3", + "5,3", "5,3").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testRightSingleLeftJoinTwoFields(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt, cnt2 " + + "FROM t1 LEFT JOIN (SELECT COUNT(*) AS cnt,COUNT(*) AS cnt2 FROM t2 ) AS x ON a = cnt" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("t1", ds1) + tEnv.registerTable("t2", ds2) + + val result = tEnv.sql(sqlQuery) + val expected = Seq( + "1,null,null", + "2,null,null", "2,null,null", + "3,3,3", "3,3,3", "3,3,3", + "4,null,null", "4,null,null", "4,null,null", "4,null,null", + "5,null,null", "5,null,null", "5,null,null", "5,null,null", "5,null,null").mkString("\n") + + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + @Test def testCrossWithUnnest(): Unit = { diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala deleted file mode 100644 index 27e385304528b..0000000000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.api.scala.batch.sql - -import org.apache.flink.api.scala._ -import org.apache.flink.table.api.scala._ -import org.apache.flink.table.utils.TableTestBase -import org.apache.flink.table.utils.TableTestUtil._ -import org.junit.Test - -class SingleRowJoinTest extends TableTestBase { - - @Test - def testSingleRowJoinWithCalcInput(): Unit = { - val util = batchTestUtil() - util.addTable[(Int, Int)]("A", 'a1, 'a2) - - val query = - "SELECT a1, asum " + - "FROM A, (SELECT sum(a1) + sum(a2) AS asum FROM A)" - - val expected = - binaryNode( - "DataSetSingleRowJoin", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a1") - ), - unaryNode( - "DataSetCalc", - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - batchTableNode(0), - tuples(List(null, null)), - term("values", "a1", "a2") - ), - term("union","a1","a2") - ), - term("select", "SUM(a1) AS $f0", "SUM(a2) AS $f1") - ), - term("select", "+($f0, $f1) AS asum") - ), - term("where", "true"), - term("join", "a1", "asum"), - term("joinType", "NestedLoopJoin") - ) - - util.verifySql(query, expected) - } - - @Test - def testSingleRowEquiJoin(): Unit = { - val util = batchTestUtil() - util.addTable[(Int, String)]("A", 'a1, 'a2) - - val query = - "SELECT a1, a2 " + - "FROM A, (SELECT count(a1) AS cnt FROM A) " + - "WHERE a1 = cnt" - - val expected = - unaryNode( - "DataSetCalc", - binaryNode( - "DataSetSingleRowJoin", - batchTableNode(0), - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a1") - ), - tuples(List(null)), - term("values", "a1") - ), - term("union","a1") - ), - term("select", "COUNT(a1) AS cnt") - ), - term("where", "=(CAST(a1), cnt)"), - term("join", "a1", "a2", "cnt"), - term("joinType", "NestedLoopJoin") - ), - term("select", "a1", "a2") - ) - - util.verifySql(query, expected) - } - - @Test - def testSingleRowNotEquiJoin(): Unit = { - val util = batchTestUtil() - util.addTable[(Int, String)]("A", 'a1, 'a2) - - val query = - "SELECT a1, a2 " + - "FROM A, (SELECT count(a1) AS cnt FROM A) " + - "WHERE a1 < cnt" - - val expected = - unaryNode( - "DataSetCalc", - binaryNode( - "DataSetSingleRowJoin", - batchTableNode(0), - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a1") - ), - tuples(List(null)), - term("values", "a1") - ), - term("union", "a1") - ), - term("select", "COUNT(a1) AS cnt") - ), - term("where", "<(a1, cnt)"), - term("join", "a1", "a2", "cnt"), - term("joinType", "NestedLoopJoin") - ), - term("select", "a1", "a2") - ) - - util.verifySql(query, expected) - } - - @Test - def testSingleRowJoinWithComplexPredicate(): Unit = { - val util = batchTestUtil() - util.addTable[(Int, Long)]("A", 'a1, 'a2) - util.addTable[(Int, Long)]("B", 'b1, 'b2) - - val query = - "SELECT a1, a2, b1, b2 " + - "FROM A, (SELECT min(b1) AS b1, max(b2) AS b2 FROM B) " + - "WHERE a1 < b1 AND a2 = b2" - - val expected = binaryNode( - "DataSetSingleRowJoin", - batchTableNode(0), - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - batchTableNode(1), - tuples(List(null, null)), - term("values", "b1", "b2") - ), - term("union","b1","b2") - ), - term("select", "MIN(b1) AS b1", "MAX(b2) AS b2") - ), - term("where", "AND(<(a1, b1)", "=(a2, b2))"), - term("join", "a1", "a2", "b1", "b2"), - term("joinType", "NestedLoopJoin") - ) - - util.verifySql(query, expected) - } -}