From 30b5bb22233c3b49956763b1c540a1ad16bdc82f Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Wed, 5 Apr 2017 12:24:30 +0300 Subject: [PATCH 01/12] [FLINK-5256] Extend DataSetSingleRowJoin to support Left and Right joins --- .../table/codegen/calls/ScalarOperators.scala | 10 +- .../flink/table/codegen/generated.scala | 6 +- .../nodes/dataset/DataSetSingleRowJoin.scala | 115 ++++-- .../dataSet/DataSetSingleRowJoinRule.scala | 11 +- .../table/runtime/MapJoinLeftRunner.scala | 15 +- .../table/runtime/MapJoinRightRunner.scala | 12 + .../batch/sql/DataSetSingleRowJoinTest.scala | 340 ++++++++++++++++++ .../api/scala/batch/sql/JoinITCase.scala | 260 +++++++++++++- 8 files changed, 732 insertions(+), 37 deletions(-) create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala index 0c5baa6343038..37266f4bba5eb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala @@ -951,6 +951,8 @@ object ScalarOperators { : GeneratedExpression = { val resultTerm = newName("result") val nullTerm = newName("isNull") + val leftNullTerm = newName("leftIsNull") + val rightNullTerm = newName("rightIsNull") val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType) val defaultValue = primitiveDefaultValue(resultType) @@ -995,6 +997,8 @@ object ScalarOperators { s""" |${left.code} |${right.code} + |boolean $leftNullTerm = ${left.nullTerm}; + |boolean $rightNullTerm = ${right.nullTerm}; |boolean $nullTerm = ${left.nullTerm} || ${right.nullTerm}; |$resultTypeTerm $resultTerm; |if ($nullTerm) { @@ -1013,7 +1017,11 @@ object ScalarOperators { |""".stripMargin } - GeneratedExpression(resultTerm, nullTerm, resultCode, resultType) + val retval = GeneratedExpression(resultTerm, nullTerm, resultCode, resultType) + retval.leftNullTerm = leftNullTerm + retval.rightNullTerm = rightNullTerm + + return retval } private def internalExprCasting( diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala index e26fb646ed0cd..c58f14d43e5e7 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala @@ -35,7 +35,11 @@ case class GeneratedExpression( resultTerm: String, nullTerm: String, code: String, - resultType: TypeInformation[_]) + resultType: TypeInformation[_]) { + + var leftNullTerm: String = _ + var rightNullTerm: String = _ +} object GeneratedExpression { val ALWAYS_NULL = "true" diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala index b7d1a4bfb60c6..91e1d5fe4cf06 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala @@ -20,6 +20,7 @@ package org.apache.flink.table.plan.nodes.dataset import org.apache.calcite.plan._ import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.rel.core.JoinRelType import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{BiRel, RelNode, RelWriter} import org.apache.calcite.rex.RexNode @@ -44,9 +45,11 @@ class DataSetSingleRowJoin( leftNode: RelNode, rightNode: RelNode, leftIsSingle: Boolean, + rightIsSingle: Boolean, rowRelDataType: RelDataType, joinCondition: RexNode, joinRowType: RelDataType, + joinType: JoinRelType, ruleDescription: String) extends BiRel(cluster, traitSet, leftNode, rightNode) with DataSetRel { @@ -60,9 +63,11 @@ class DataSetSingleRowJoin( inputs.get(0), inputs.get(1), leftIsSingle, + rightIsSingle, getRowType, joinCondition, joinRowType, + joinType, ruleDescription) } @@ -97,7 +102,6 @@ class DataSetSingleRowJoin( tableEnv.getConfig, leftDataSet.getType, rightDataSet.getType, - leftIsSingle, joinCondition, broadcastSetName) @@ -118,14 +122,18 @@ class DataSetSingleRowJoin( config: TableConfig, inputType1: TypeInformation[Row], inputType2: TypeInformation[Row], - firstIsSingle: Boolean, joinCondition: RexNode, broadcastInputSetName: String) : FlatMapFunction[Row, Row] = { + val nullCheck = joinType match { + case JoinRelType.LEFT | JoinRelType.RIGHT => true + case _ => false + } + val codeGenerator = new CodeGenerator( config, - false, + nullCheck, inputType1, Some(inputType2)) @@ -138,13 +146,62 @@ class DataSetSingleRowJoin( val condition = codeGenerator.generateExpression(joinCondition) val joinMethodBody = - s""" - |${condition.code} - |if (${condition.resultTerm}) { - | ${conversion.code} - | ${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); - |} - |""".stripMargin + if (joinType == JoinRelType.INNER) { + s""" + |${condition.code} + |if (${condition.resultTerm}) { + | ${conversion.code} + | ${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); + |} + |""".stripMargin + } else { + val singleNode = + if (rightIsSingle) { + rightNode + } + else { + leftNode + } + + val notSuitedToCondition = singleNode + .getRowType + .getFieldList + .map(field => getRowType.getFieldNames.indexOf(field.getName)) + .map(i => s"${conversion.resultTerm}.setField($i,null);") + + if (joinType == JoinRelType.LEFT && leftIsSingle) { + s""" + |${condition.code} + |${conversion.code} + |if(!${condition.resultTerm}){ + |${notSuitedToCondition.mkString("\n")} + |} + |if(!${condition.leftNullTerm}){ + |${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); + |} + |""".stripMargin + } else if (joinType == JoinRelType.RIGHT && rightIsSingle){ + s""" + |${condition.code} + |${conversion.code} + |if(!${condition.resultTerm}){ + |${notSuitedToCondition.mkString("\n")} + |} + |if(!${condition.leftNullTerm} && ${condition.resultTerm}){ + |${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); + |} + |""".stripMargin + } else { + s""" + |${condition.code} + |${conversion.code} + |if(!${condition.resultTerm}){ + |${notSuitedToCondition.mkString("\n")} + |} + |${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); + |""".stripMargin + } + } val genFunction = codeGenerator.generateFunction( ruleDescription, @@ -152,18 +209,34 @@ class DataSetSingleRowJoin( joinMethodBody, returnType) - if (firstIsSingle) { - new MapJoinRightRunner[Row, Row, Row]( - genFunction.name, - genFunction.code, - genFunction.returnType, - broadcastInputSetName) + if (joinType == JoinRelType.RIGHT) { + if (leftIsSingle) { + new MapJoinRightRunner[Row, Row, Row]( + genFunction.name, + genFunction.code, + genFunction.returnType, + broadcastInputSetName) + } else { + new MapJoinLeftRunner[Row, Row, Row]( + genFunction.name, + genFunction.code, + genFunction.returnType, + broadcastInputSetName) + } } else { - new MapJoinLeftRunner[Row, Row, Row]( - genFunction.name, - genFunction.code, - genFunction.returnType, - broadcastInputSetName) + if (rightIsSingle) { + new MapJoinLeftRunner[Row, Row, Row]( + genFunction.name, + genFunction.code, + genFunction.returnType, + broadcastInputSetName) + } else { + new MapJoinRightRunner[Row, Row, Row]( + genFunction.name, + genFunction.code, + genFunction.returnType, + broadcastInputSetName) + } } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala index b61573c53965d..66e3b1b414fdb 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala @@ -37,10 +37,10 @@ class DataSetSingleRowJoinRule override def matches(call: RelOptRuleCall): Boolean = { val join = call.rel(0).asInstanceOf[FlinkLogicalJoin] - if (isInnerJoin(join)) { - isSingleRow(join.getRight) || isSingleRow(join.getLeft) - } else { - false + join.getJoinType match { + case JoinRelType.INNER | JoinRelType.LEFT | JoinRelType.RIGHT => + isSingleRow(join.getRight) || isSingleRow(join.getLeft) + case _ => false } } @@ -69,6 +69,7 @@ class DataSetSingleRowJoinRule val dataSetLeftNode = RelOptRule.convert(join.getLeft, FlinkConventions.DATASET) val dataSetRightNode = RelOptRule.convert(join.getRight, FlinkConventions.DATASET) val leftIsSingle = isSingleRow(join.getLeft) + val rightIsSingle = isSingleRow(join.getRight) new DataSetSingleRowJoin( rel.getCluster, @@ -76,9 +77,11 @@ class DataSetSingleRowJoinRule dataSetLeftNode, dataSetRightNode, leftIsSingle, + rightIsSingle, rel.getRowType, join.getCondition, join.getRowType, + join.getJoinType, description) } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala index 5f3dbb4cbaf98..d7d23033ba6a8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala @@ -19,6 +19,7 @@ package org.apache.flink.table.runtime import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.types.Row import org.apache.flink.util.Collector class MapJoinLeftRunner[IN1, IN2, OUT]( @@ -31,7 +32,19 @@ class MapJoinLeftRunner[IN1, IN2, OUT]( override def flatMap(multiInput: IN1, out: Collector[OUT]): Unit = { broadcastSet match { case Some(singleInput) => function.join(multiInput, singleInput, out) - case None => + case None => { + if (isRowClass(multiInput) && returnType.getTypeClass.equals(classOf[Row])) { + val inputRow = multiInput.asInstanceOf[Row] + val countNullRecords = returnType.getTotalFields - inputRow.getArity + val nullRecords = new Row(countNullRecords) + function.join(multiInput, nullRecords.asInstanceOf[IN2], out) + } + } } } + + private def isRowClass(obj: Any) = obj match { + case r: Row => true + case _ => false + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala index e2d9331187623..1bc72d5983306 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala @@ -19,6 +19,7 @@ package org.apache.flink.table.runtime import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.types.Row import org.apache.flink.util.Collector class MapJoinRightRunner[IN1, IN2, OUT]( @@ -32,6 +33,17 @@ class MapJoinRightRunner[IN1, IN2, OUT]( broadcastSet match { case Some(singleInput) => function.join(singleInput, multiInput, out) case None => + if (isRowClass(multiInput) && returnType.getTypeClass.equals(classOf[Row])) { + val inputRow = multiInput.asInstanceOf[Row] + val countNullRecords = returnType.getTotalFields - inputRow.getArity + val nullRecords= new Row(countNullRecords) + function.join(nullRecords.asInstanceOf[IN1], multiInput, out) + } } } + + private def isRowClass(obj: Any) = obj match { + case r: Row => true + case _ => false + } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala new file mode 100644 index 0000000000000..a32b0604f3125 --- /dev/null +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala @@ -0,0 +1,340 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.api.scala.batch.sql + +import org.apache.flink.api.scala._ +import org.apache.flink.table.api.scala._ +import org.apache.flink.table.utils.TableTestBase +import org.apache.flink.table.utils.TableTestUtil._ +import org.junit.Test + +class DataSetSingleRowJoinTest extends TableTestBase { + + @Test + def testSingleRowJoinWithCalcInput(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Int)]("A", 'a1, 'a2) + + val query = + "SELECT a1, asum " + + "FROM A, (SELECT sum(a1) + sum(a2) AS asum FROM A)" + + val expected = + binaryNode( + "DataSetSingleRowJoin", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a1") + ), + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + batchTableNode(0), + tuples(List(null, null)), + term("values", "a1", "a2") + ), + term("union","a1","a2") + ), + term("select", "SUM(a1) AS $f0", "SUM(a2) AS $f1") + ), + term("select", "+($f0, $f1) AS asum") + ), + term("where", "true"), + term("join", "a1", "asum"), + term("joinType", "NestedLoopJoin") + ) + + util.verifySql(query, expected) + } + + @Test + def testSingleRowEquiJoin(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, String)]("A", 'a1, 'a2) + + val query = + "SELECT a1, a2 " + + "FROM A, (SELECT count(a1) AS cnt FROM A) " + + "WHERE a1 = cnt" + + val expected = + unaryNode( + "DataSetCalc", + binaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a1") + ), + tuples(List(null)), + term("values", "a1") + ), + term("union","a1") + ), + term("select", "COUNT(a1) AS cnt") + ), + term("where", "=(CAST(a1), cnt)"), + term("join", "a1", "a2", "cnt"), + term("joinType", "NestedLoopJoin") + ), + term("select", "a1", "a2") + ) + + util.verifySql(query, expected) + } + + @Test + def testSingleRowNotEquiJoin(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, String)]("A", 'a1, 'a2) + + val query = + "SELECT a1, a2 " + + "FROM A, (SELECT count(a1) AS cnt FROM A) " + + "WHERE a1 < cnt" + + val expected = + unaryNode( + "DataSetCalc", + binaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a1") + ), + tuples(List(null)), + term("values", "a1") + ), + term("union", "a1") + ), + term("select", "COUNT(a1) AS cnt") + ), + term("where", "<(a1, cnt)"), + term("join", "a1", "a2", "cnt"), + term("joinType", "NestedLoopJoin") + ), + term("select", "a1", "a2") + ) + + util.verifySql(query, expected) + } + + @Test + def testSingleRowJoinWithComplexPredicate(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Long)]("A", 'a1, 'a2) + util.addTable[(Int, Long)]("B", 'b1, 'b2) + + val query = + "SELECT a1, a2, b1, b2 " + + "FROM A, (SELECT min(b1) AS b1, max(b2) AS b2 FROM B) " + + "WHERE a1 < b1 AND a2 = b2" + + val expected = binaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + batchTableNode(1), + tuples(List(null, null)), + term("values", "b1", "b2") + ), + term("union","b1","b2") + ), + term("select", "MIN(b1) AS b1", "MAX(b2) AS b2") + ), + term("where", "AND(<(a1, b1)", "=(a2, b2))"), + term("join", "a1", "a2", "b1", "b2"), + term("joinType", "NestedLoopJoin") + ) + + util.verifySql(query, expected) + } + + @Test + def testSingleRowJoinLeftOuterJoin(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Int)]("A", 'a1, 'a2) + util.addTable[(Int, Int)]("B", 'b1, 'b2) + + val queryLeftJoin = + "SELECT a2 FROM A " + + "LEFT JOIN " + + "(SELECT COUNT(*) AS cnt FROM B) " + + "AS x " + + "ON a1 < cnt" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + term("where", "<(a1, cnt)"), + term("join", "a1", "a2", "cnt"), + term("joinType", "NestedLoopJoin") + ), + term("select", "a2") + ) + "\n" + + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + tuples(List(null)), term("values", "$f0") + ), + term("union", "$f0") + ), + term("select", "COUNT(*) AS cnt") + ) + + util.verifySql(queryLeftJoin, expected) + } + + @Test + def testSingleRowJoinRightOuterJoin(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Int)]("A", 'a1, 'a2) + util.addTable[(Int, Int)]("B", 'b1, 'b2) + + val queryRightJoin = + "SELECT a2 FROM A " + + "RIGHT JOIN " + + "(SELECT COUNT(*) AS cnt FROM B) " + + "AS x " + + "ON a1 < cnt" + + //val queryRightJoin = + // "SELECT a2 FROM (SELECT COUNT(*) AS cnt FROM B) RIGHT JOIN A ON a1 < cnt" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + term("where", "<(a1, cnt)"), + term("join", "a1", "a2", "cnt"), + term("joinType", "NestedLoopJoin") + ), + term("select", "a2") + ) + "\n" + + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + tuples(List(null)), term("values", "$f0") + ), + term("union", "$f0") + ), + term("select", "COUNT(*) AS cnt") + ) + + util.verifySql(queryRightJoin, expected) + } + + @Test + def testSingleRowJoinInnerJoin(): Unit = { + val util = batchTestUtil() + util.addTable[(Int, Int)]("A", 'a1, 'a2) + val query = + "SELECT a2, sum(a1) " + + "FROM A " + + "GROUP BY a2 " + + "HAVING sum(a1) > (SELECT sum(a1) * 0.1 FROM A)" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + unaryNode( + "DataSetAggregate", + batchTableNode(0), + term("groupBy", "a2"), + term("select", "a2", "SUM(a1) AS EXPR$1") + ), + term("where", ">(EXPR$1, EXPR$0)"), + term("join", "a2", "EXPR$1", "EXPR$0"), + term("joinType", "NestedLoopJoin") + ), + term("select", "a2", "EXPR$1") + ) + "\n" + + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(0), + term("select", "a1") + ), + tuples(List(null)), term("values", "a1") + ), + term("union", "a1") + ), + term("select", "SUM(a1) AS $f0") + ), + term("select", "*($f0, 0.1) AS EXPR$0") + ) + + util.verifySql(query, expected) + } +} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala index 8a8c0ce82fc56..0f09b3c19c1ac 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala @@ -40,10 +40,8 @@ class JoinITCase( @Test def testJoin(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) - val sqlQuery = "SELECT c, g FROM Table3, Table5 WHERE b = e" val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) @@ -280,8 +278,6 @@ class JoinITCase( tEnv.registerTable("Table3", ds1) tEnv.registerTable("Table5", ds2) - tEnv.sql(sqlQuery).toDataSet[Row].collect() - val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n" + "null,Hallo Welt wie\n" + "null,Hallo Welt wie gehts?\n" + "null,ABC\n" + "null,BCD\n" + "null,CDE\n" + "null,DEF\n" + "null,EFG\n" + "null,FGH\n" + "null,GHI\n" + "null,HIJ\n" + @@ -304,8 +300,6 @@ class JoinITCase( tEnv.registerTable("Table3", ds1) tEnv.registerTable("Table5", ds2) - tEnv.sql(sqlQuery).toDataSet[Row].collect() - val expected = "Hi,Hallo\n" + "Hello,Hallo Welt\n" + "Hello world,Hallo Welt\n" + "null,Hallo Welt wie\n" + "null,Hallo Welt wie gehts?\n" + "null,ABC\n" + "null,BCD\n" + "null,CDE\n" + "null,DEF\n" + "null,EFG\n" + "null,FGH\n" + "null,GHI\n" + "null,HIJ\n" + @@ -370,10 +364,258 @@ class JoinITCase( val table = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a1, 'a2, 'a3) tEnv.registerTable("A", table) - val sqlQuery1 = "SELECT * FROM A CROSS JOIN (SELECT count(*) FROM A HAVING count(*) < 0)" - val result = tEnv.sql(sqlQuery1).count() + val sqlQuery1 = "SELECT * FROM A CROSS JOIN " + + "(SELECT count(*) FROM A HAVING count(*) < 0)" + val result = tEnv.sql(sqlQuery1) + val expected =Seq( + "2,2,Hello,null", + "1,1,Hi,null", + "3,2,Hello world,null").mkString("\n") + + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testLeftNullLeftJoin (): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + + val sqlQuery = + "SELECT a, cnt " + + "FROM" + + " (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + + "LEFT JOIN A " + + "ON cnt > a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("A", ds1) + tEnv.registerTable("B", ds2) + + val result = tEnv.sql(sqlQuery).collect() + val resultSize = result.size + + Assert.assertEquals( + s"Expected empty result, but actual size result = $resultSize;\n[${result.mkString(",")}]", + resultSize,0) + } + + @Test + def testLeftNullRightJoin(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt " + + "FROM" + + " (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + + "RIGHT JOIN A " + + "ON a < cnt" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("A", ds1) + tEnv.registerTable("B", ds2) + + + val result = tEnv.sql(sqlQuery) + val expected = Seq( + "1,null", + "2,null", "2,null", + "3,null", "3,null", "3,null", + "4,null", "4,null", "4,null", "4,null", + "5,null", "5,null", "5,null", "5,null", "5,null").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testLeftSingleLeftJoin(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt " + + "FROM" + + " (SELECT COUNT(*) AS cnt FROM A) " + + "LEFT JOIN B " + + "ON cnt > a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv)as('a, 'b, 'c) + tEnv.registerTable("A", ds2) + tEnv.registerTable("B", ds1) + + val result = tEnv.sql(sqlQuery) + val expected = Seq( + "1,3", "2,3", "2,3", "3,null", "3,null", + "3,null", "4,null", "4,null", "4,null", + "4,null", "5,null", "5,null", "5,null", + "5,null", "5,null").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } - Assert.assertEquals(0, result) + @Test + def testLeftSingleRightJoin(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt " + + "FROM" + + " (SELECT COUNT(*) AS cnt FROM B) " + + "RIGHT JOIN A " + + "ON cnt > a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("A", ds1) + tEnv.registerTable("B", ds2) + + val result = tEnv.sql(sqlQuery) + val expected = Seq( + "1,3", "2,3", "2,3", "3,null", "3,null", + "3,null", "4,null", "4,null", "4,null", + "4,null", "5,null", "5,null", "5,null", + "5,null", "5,null").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testRightNullLeftJoin(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt " + + "FROM" + + " A " + + "LEFT JOIN (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + + "ON cnt > a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) + tEnv.registerTable("A", ds2) + tEnv.registerTable("B", ds1) + + val result = tEnv.sql(sqlQuery) + + val expected = Seq( + "2,null", "3,null", "1,null").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testRightNullRightJoin(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt " + + "FROM A " + + "RIGHT JOIN" + + " (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + + "ON cnt > a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) + tEnv.registerTable("A", ds2) + tEnv.registerTable("B", ds1) + + val result = tEnv.sql(sqlQuery).collect() + val resultSize = result.size + + Assert.assertEquals( + s"Expected empty result, but actual size result = $resultSize;\n[${result.mkString(",")}]", + resultSize,0) + } + + @Test + def testRightSingleLeftJoin(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt " + + "FROM" + + " A " + + "LEFT JOIN (SELECT COUNT(*) AS cnt FROM B) " + + "ON cnt < a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("A", ds1) + tEnv.registerTable("B", ds2) + + val result = tEnv.sql(sqlQuery) + + val expected = Seq( + "1,null", "2,null", "2,null", "3,null", "3,null", + "3,null", "4,3", "4,3", "4,3", + "4,3", "5,3", "5,3", "5,3", + "5,3", "5,3").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testRightSingleRightJoin(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a, cnt " + + "FROM" + + " A " + + "RIGHT JOIN (SELECT COUNT(*) AS cnt FROM B) " + + "ON cnt > a" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("A", ds1) + tEnv.registerTable("B", ds2) + + val result = tEnv.sql(sqlQuery) + val expected = Seq( + "1,3", "2,3", "2,3").mkString("\n") + + val results = result.toDataSet[Row].collect() + + TestBaseUtils.compareResultAsText(results.asJava, expected) + } + + @Test + def testRightSingleLeftJoinTwoFields(): Unit = { + val env = ExecutionEnvironment.getExecutionEnvironment + val tEnv = TableEnvironment.getTableEnvironment(env, config) + val sqlQuery = + "SELECT a,cnt, cnt2 " + + "FROM t1 " + + "LEFT JOIN (SELECT COUNT(*) AS cnt,COUNT(*) AS cnt2 FROM t2 ) AS x " + + "ON a > cnt" + + val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("t1", ds1) + tEnv.registerTable("t2", ds2) + + val result = tEnv.sql(sqlQuery) + val expected = Seq( + "1,null,null", + "2,null,null", "2,null,null", + "3,null,null", "3,null,null", "3,null,null", + "4,3,3", "4,3,3", "4,3,3", "4,3,3", + "5,3,3", "5,3,3", "5,3,3", "5,3,3", "5,3,3").mkString("\n") + + val results = result.toDataSet[Row].collect() + TestBaseUtils.compareResultAsText(results.asJava, expected) } @Test From 49fe92646ab615366bc3ff55df3b9f152f2326d0 Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Wed, 5 Apr 2017 15:31:08 +0300 Subject: [PATCH 02/12] [FLINK-5256] Extend DataSetSingleRowJoin to support Left and Right joins --- .../scala/batch/sql/SingleRowJoinTest.scala | 195 ------------------ 1 file changed, 195 deletions(-) delete mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala deleted file mode 100644 index 27e385304528b..0000000000000 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/SingleRowJoinTest.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.api.scala.batch.sql - -import org.apache.flink.api.scala._ -import org.apache.flink.table.api.scala._ -import org.apache.flink.table.utils.TableTestBase -import org.apache.flink.table.utils.TableTestUtil._ -import org.junit.Test - -class SingleRowJoinTest extends TableTestBase { - - @Test - def testSingleRowJoinWithCalcInput(): Unit = { - val util = batchTestUtil() - util.addTable[(Int, Int)]("A", 'a1, 'a2) - - val query = - "SELECT a1, asum " + - "FROM A, (SELECT sum(a1) + sum(a2) AS asum FROM A)" - - val expected = - binaryNode( - "DataSetSingleRowJoin", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a1") - ), - unaryNode( - "DataSetCalc", - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - batchTableNode(0), - tuples(List(null, null)), - term("values", "a1", "a2") - ), - term("union","a1","a2") - ), - term("select", "SUM(a1) AS $f0", "SUM(a2) AS $f1") - ), - term("select", "+($f0, $f1) AS asum") - ), - term("where", "true"), - term("join", "a1", "asum"), - term("joinType", "NestedLoopJoin") - ) - - util.verifySql(query, expected) - } - - @Test - def testSingleRowEquiJoin(): Unit = { - val util = batchTestUtil() - util.addTable[(Int, String)]("A", 'a1, 'a2) - - val query = - "SELECT a1, a2 " + - "FROM A, (SELECT count(a1) AS cnt FROM A) " + - "WHERE a1 = cnt" - - val expected = - unaryNode( - "DataSetCalc", - binaryNode( - "DataSetSingleRowJoin", - batchTableNode(0), - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a1") - ), - tuples(List(null)), - term("values", "a1") - ), - term("union","a1") - ), - term("select", "COUNT(a1) AS cnt") - ), - term("where", "=(CAST(a1), cnt)"), - term("join", "a1", "a2", "cnt"), - term("joinType", "NestedLoopJoin") - ), - term("select", "a1", "a2") - ) - - util.verifySql(query, expected) - } - - @Test - def testSingleRowNotEquiJoin(): Unit = { - val util = batchTestUtil() - util.addTable[(Int, String)]("A", 'a1, 'a2) - - val query = - "SELECT a1, a2 " + - "FROM A, (SELECT count(a1) AS cnt FROM A) " + - "WHERE a1 < cnt" - - val expected = - unaryNode( - "DataSetCalc", - binaryNode( - "DataSetSingleRowJoin", - batchTableNode(0), - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(0), - term("select", "a1") - ), - tuples(List(null)), - term("values", "a1") - ), - term("union", "a1") - ), - term("select", "COUNT(a1) AS cnt") - ), - term("where", "<(a1, cnt)"), - term("join", "a1", "a2", "cnt"), - term("joinType", "NestedLoopJoin") - ), - term("select", "a1", "a2") - ) - - util.verifySql(query, expected) - } - - @Test - def testSingleRowJoinWithComplexPredicate(): Unit = { - val util = batchTestUtil() - util.addTable[(Int, Long)]("A", 'a1, 'a2) - util.addTable[(Int, Long)]("B", 'b1, 'b2) - - val query = - "SELECT a1, a2, b1, b2 " + - "FROM A, (SELECT min(b1) AS b1, max(b2) AS b2 FROM B) " + - "WHERE a1 < b1 AND a2 = b2" - - val expected = binaryNode( - "DataSetSingleRowJoin", - batchTableNode(0), - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - batchTableNode(1), - tuples(List(null, null)), - term("values", "b1", "b2") - ), - term("union","b1","b2") - ), - term("select", "MIN(b1) AS b1", "MAX(b2) AS b2") - ), - term("where", "AND(<(a1, b1)", "=(a2, b2))"), - term("join", "a1", "a2", "b1", "b2"), - term("joinType", "NestedLoopJoin") - ) - - util.verifySql(query, expected) - } -} From 18ce31971c6e088941837d63df1c27110d63fc02 Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Wed, 19 Apr 2017 16:52:19 +0300 Subject: [PATCH 03/12] [FLINK-5256] Change tests accordingly FLINK-5520 --- .../batch/sql/DataSetSingleRowJoinTest.scala | 93 ------------------- .../api/scala/batch/sql/JoinITCase.scala | 42 ++++----- 2 files changed, 21 insertions(+), 114 deletions(-) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala index a32b0604f3125..7c9b0b05590b7 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala @@ -193,99 +193,6 @@ class DataSetSingleRowJoinTest extends TableTestBase { util.verifySql(query, expected) } - @Test - def testSingleRowJoinLeftOuterJoin(): Unit = { - val util = batchTestUtil() - util.addTable[(Int, Int)]("A", 'a1, 'a2) - util.addTable[(Int, Int)]("B", 'b1, 'b2) - - val queryLeftJoin = - "SELECT a2 FROM A " + - "LEFT JOIN " + - "(SELECT COUNT(*) AS cnt FROM B) " + - "AS x " + - "ON a1 < cnt" - - val expected = - unaryNode( - "DataSetCalc", - unaryNode( - "DataSetSingleRowJoin", - batchTableNode(0), - term("where", "<(a1, cnt)"), - term("join", "a1", "a2", "cnt"), - term("joinType", "NestedLoopJoin") - ), - term("select", "a2") - ) + "\n" + - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(1), - term("select", "0 AS $f0")), - tuples(List(null)), term("values", "$f0") - ), - term("union", "$f0") - ), - term("select", "COUNT(*) AS cnt") - ) - - util.verifySql(queryLeftJoin, expected) - } - - @Test - def testSingleRowJoinRightOuterJoin(): Unit = { - val util = batchTestUtil() - util.addTable[(Int, Int)]("A", 'a1, 'a2) - util.addTable[(Int, Int)]("B", 'b1, 'b2) - - val queryRightJoin = - "SELECT a2 FROM A " + - "RIGHT JOIN " + - "(SELECT COUNT(*) AS cnt FROM B) " + - "AS x " + - "ON a1 < cnt" - - //val queryRightJoin = - // "SELECT a2 FROM (SELECT COUNT(*) AS cnt FROM B) RIGHT JOIN A ON a1 < cnt" - - val expected = - unaryNode( - "DataSetCalc", - unaryNode( - "DataSetSingleRowJoin", - batchTableNode(0), - term("where", "<(a1, cnt)"), - term("join", "a1", "a2", "cnt"), - term("joinType", "NestedLoopJoin") - ), - term("select", "a2") - ) + "\n" + - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(1), - term("select", "0 AS $f0")), - tuples(List(null)), term("values", "$f0") - ), - term("union", "$f0") - ), - term("select", "COUNT(*) AS cnt") - ) - - util.verifySql(queryRightJoin, expected) - } - @Test def testSingleRowJoinInnerJoin(): Unit = { val util = batchTestUtil() diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala index 0f09b3c19c1ac..a22232df7ae03 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala @@ -386,7 +386,7 @@ class JoinITCase( "FROM" + " (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + "LEFT JOIN A " + - "ON cnt > a" + "ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -410,7 +410,7 @@ class JoinITCase( "FROM" + " (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + "RIGHT JOIN A " + - "ON a < cnt" + "ON a = cnt" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -440,7 +440,7 @@ class JoinITCase( "FROM" + " (SELECT COUNT(*) AS cnt FROM A) " + "LEFT JOIN B " + - "ON cnt > a" + "ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv)as('a, 'b, 'c) @@ -449,8 +449,8 @@ class JoinITCase( val result = tEnv.sql(sqlQuery) val expected = Seq( - "1,3", "2,3", "2,3", "3,null", "3,null", - "3,null", "4,null", "4,null", "4,null", + "1,null", "2,null", "2,null", "3,3", "3,3", + "3,3", "4,null", "4,null", "4,null", "4,null", "5,null", "5,null", "5,null", "5,null", "5,null").mkString("\n") @@ -468,7 +468,7 @@ class JoinITCase( "FROM" + " (SELECT COUNT(*) AS cnt FROM B) " + "RIGHT JOIN A " + - "ON cnt > a" + "ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -477,8 +477,8 @@ class JoinITCase( val result = tEnv.sql(sqlQuery) val expected = Seq( - "1,3", "2,3", "2,3", "3,null", "3,null", - "3,null", "4,null", "4,null", "4,null", + "1,null", "2,null", "2,null", "3,3", "3,3", + "3,3", "4,null", "4,null", "4,null", "4,null", "5,null", "5,null", "5,null", "5,null", "5,null").mkString("\n") @@ -496,7 +496,7 @@ class JoinITCase( "FROM" + " A " + "LEFT JOIN (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + - "ON cnt > a" + "ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) @@ -522,7 +522,7 @@ class JoinITCase( "FROM A " + "RIGHT JOIN" + " (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + - "ON cnt > a" + "ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) @@ -546,7 +546,7 @@ class JoinITCase( "FROM" + " A " + "LEFT JOIN (SELECT COUNT(*) AS cnt FROM B) " + - "ON cnt < a" + "ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -556,10 +556,10 @@ class JoinITCase( val result = tEnv.sql(sqlQuery) val expected = Seq( - "1,null", "2,null", "2,null", "3,null", "3,null", - "3,null", "4,3", "4,3", "4,3", - "4,3", "5,3", "5,3", "5,3", - "5,3", "5,3").mkString("\n") + "1,null", "2,null", "2,null", "3,3", "3,3", + "3,3", "4,null", "4,null", "4,null", + "4,null", "5,null", "5,null", "5,null", + "5,null", "5,null").mkString("\n") val results = result.toDataSet[Row].collect() @@ -575,7 +575,7 @@ class JoinITCase( "FROM" + " A " + "RIGHT JOIN (SELECT COUNT(*) AS cnt FROM B) " + - "ON cnt > a" + "ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -584,7 +584,7 @@ class JoinITCase( val result = tEnv.sql(sqlQuery) val expected = Seq( - "1,3", "2,3", "2,3").mkString("\n") + "3,3", "3,3", "3,3").mkString("\n") val results = result.toDataSet[Row].collect() @@ -599,7 +599,7 @@ class JoinITCase( "SELECT a,cnt, cnt2 " + "FROM t1 " + "LEFT JOIN (SELECT COUNT(*) AS cnt,COUNT(*) AS cnt2 FROM t2 ) AS x " + - "ON a > cnt" + "ON a = cnt" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -610,9 +610,9 @@ class JoinITCase( val expected = Seq( "1,null,null", "2,null,null", "2,null,null", - "3,null,null", "3,null,null", "3,null,null", - "4,3,3", "4,3,3", "4,3,3", "4,3,3", - "5,3,3", "5,3,3", "5,3,3", "5,3,3", "5,3,3").mkString("\n") + "3,3,3", "3,3,3", "3,3,3", + "4,null,null", "4,null,null", "4,null,null", "4,null,null", + "5,null,null", "5,null,null", "5,null,null", "5,null,null", "5,null,null").mkString("\n") val results = result.toDataSet[Row].collect() TestBaseUtils.compareResultAsText(results.asJava, expected) From d49f28314f12becf164502b363a537d5d8235b02 Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Thu, 4 May 2017 20:54:50 +0300 Subject: [PATCH 04/12] [FLINK-5256] Extend DataSetSingleRowJoin to support Left and Right joins --- .../table/codegen/calls/ScalarOperators.scala | 10 +-- .../nodes/dataset/DataSetSingleRowJoin.scala | 48 ++----------- .../dataSet/DataSetSingleRowJoinRule.scala | 7 +- .../table/runtime/MapJoinLeftRunner.scala | 5 +- .../table/runtime/MapJoinRightRunner.scala | 5 +- .../api/scala/batch/sql/JoinITCase.scala | 68 ++++--------------- 6 files changed, 28 insertions(+), 115 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala index 37266f4bba5eb..0c5baa6343038 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala @@ -951,8 +951,6 @@ object ScalarOperators { : GeneratedExpression = { val resultTerm = newName("result") val nullTerm = newName("isNull") - val leftNullTerm = newName("leftIsNull") - val rightNullTerm = newName("rightIsNull") val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType) val defaultValue = primitiveDefaultValue(resultType) @@ -997,8 +995,6 @@ object ScalarOperators { s""" |${left.code} |${right.code} - |boolean $leftNullTerm = ${left.nullTerm}; - |boolean $rightNullTerm = ${right.nullTerm}; |boolean $nullTerm = ${left.nullTerm} || ${right.nullTerm}; |$resultTypeTerm $resultTerm; |if ($nullTerm) { @@ -1017,11 +1013,7 @@ object ScalarOperators { |""".stripMargin } - val retval = GeneratedExpression(resultTerm, nullTerm, resultCode, resultType) - retval.leftNullTerm = leftNullTerm - retval.rightNullTerm = rightNullTerm - - return retval + GeneratedExpression(resultTerm, nullTerm, resultCode, resultType) } private def internalExprCasting( diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala index 91e1d5fe4cf06..9491456be98d9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala @@ -45,7 +45,6 @@ class DataSetSingleRowJoin( leftNode: RelNode, rightNode: RelNode, leftIsSingle: Boolean, - rightIsSingle: Boolean, rowRelDataType: RelDataType, joinCondition: RexNode, joinRowType: RelDataType, @@ -63,7 +62,6 @@ class DataSetSingleRowJoin( inputs.get(0), inputs.get(1), leftIsSingle, - rightIsSingle, getRowType, joinCondition, joinRowType, @@ -156,7 +154,7 @@ class DataSetSingleRowJoin( |""".stripMargin } else { val singleNode = - if (rightIsSingle) { + if (!leftIsSingle) { rightNode } else { @@ -169,29 +167,6 @@ class DataSetSingleRowJoin( .map(field => getRowType.getFieldNames.indexOf(field.getName)) .map(i => s"${conversion.resultTerm}.setField($i,null);") - if (joinType == JoinRelType.LEFT && leftIsSingle) { - s""" - |${condition.code} - |${conversion.code} - |if(!${condition.resultTerm}){ - |${notSuitedToCondition.mkString("\n")} - |} - |if(!${condition.leftNullTerm}){ - |${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); - |} - |""".stripMargin - } else if (joinType == JoinRelType.RIGHT && rightIsSingle){ - s""" - |${condition.code} - |${conversion.code} - |if(!${condition.resultTerm}){ - |${notSuitedToCondition.mkString("\n")} - |} - |if(!${condition.leftNullTerm} && ${condition.resultTerm}){ - |${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); - |} - |""".stripMargin - } else { s""" |${condition.code} |${conversion.code} @@ -201,7 +176,6 @@ class DataSetSingleRowJoin( |${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); |""".stripMargin } - } val genFunction = codeGenerator.generateFunction( ruleDescription, @@ -209,35 +183,21 @@ class DataSetSingleRowJoin( joinMethodBody, returnType) - if (joinType == JoinRelType.RIGHT) { - if (leftIsSingle) { - new MapJoinRightRunner[Row, Row, Row]( - genFunction.name, - genFunction.code, - genFunction.returnType, - broadcastInputSetName) - } else { - new MapJoinLeftRunner[Row, Row, Row]( - genFunction.name, - genFunction.code, - genFunction.returnType, - broadcastInputSetName) - } - } else { - if (rightIsSingle) { + if (!leftIsSingle) { new MapJoinLeftRunner[Row, Row, Row]( genFunction.name, genFunction.code, + nullCheck, genFunction.returnType, broadcastInputSetName) } else { new MapJoinRightRunner[Row, Row, Row]( genFunction.name, genFunction.code, + nullCheck, genFunction.returnType, broadcastInputSetName) } - } } private def getMapOperatorName: String = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala index 66e3b1b414fdb..f964a95f3d521 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/DataSetSingleRowJoinRule.scala @@ -38,8 +38,9 @@ class DataSetSingleRowJoinRule val join = call.rel(0).asInstanceOf[FlinkLogicalJoin] join.getJoinType match { - case JoinRelType.INNER | JoinRelType.LEFT | JoinRelType.RIGHT => - isSingleRow(join.getRight) || isSingleRow(join.getLeft) + case JoinRelType.INNER if isSingleRow(join.getLeft) || isSingleRow(join.getRight) => true + case JoinRelType.LEFT if isSingleRow(join.getRight) => true + case JoinRelType.RIGHT if isSingleRow(join.getLeft) => true case _ => false } } @@ -69,7 +70,6 @@ class DataSetSingleRowJoinRule val dataSetLeftNode = RelOptRule.convert(join.getLeft, FlinkConventions.DATASET) val dataSetRightNode = RelOptRule.convert(join.getRight, FlinkConventions.DATASET) val leftIsSingle = isSingleRow(join.getLeft) - val rightIsSingle = isSingleRow(join.getRight) new DataSetSingleRowJoin( rel.getCluster, @@ -77,7 +77,6 @@ class DataSetSingleRowJoinRule dataSetLeftNode, dataSetRightNode, leftIsSingle, - rightIsSingle, rel.getRowType, join.getCondition, join.getRowType, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala index d7d23033ba6a8..a902b9b569d4e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala @@ -25,7 +25,8 @@ import org.apache.flink.util.Collector class MapJoinLeftRunner[IN1, IN2, OUT]( name: String, code: String, - returnType: TypeInformation[OUT], + outerJoin: Boolean, + @transient returnType: TypeInformation[OUT], broadcastSetName: String) extends MapSideJoinRunner[IN1, IN2, IN2, IN1, OUT](name, code, returnType, broadcastSetName) { @@ -33,7 +34,7 @@ class MapJoinLeftRunner[IN1, IN2, OUT]( broadcastSet match { case Some(singleInput) => function.join(multiInput, singleInput, out) case None => { - if (isRowClass(multiInput) && returnType.getTypeClass.equals(classOf[Row])) { + if (outerJoin && isRowClass(multiInput) && returnType.getTypeClass.equals(classOf[Row])) { val inputRow = multiInput.asInstanceOf[Row] val countNullRecords = returnType.getTotalFields - inputRow.getArity val nullRecords = new Row(countNullRecords) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala index 1bc72d5983306..c2287fae6f1a3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala @@ -25,7 +25,8 @@ import org.apache.flink.util.Collector class MapJoinRightRunner[IN1, IN2, OUT]( name: String, code: String, - returnType: TypeInformation[OUT], + outerJoin: Boolean, + @transient returnType: TypeInformation[OUT], broadcastSetName: String) extends MapSideJoinRunner[IN1, IN2, IN1, IN2, OUT](name, code, returnType, broadcastSetName) { @@ -33,7 +34,7 @@ class MapJoinRightRunner[IN1, IN2, OUT]( broadcastSet match { case Some(singleInput) => function.join(singleInput, multiInput, out) case None => - if (isRowClass(multiInput) && returnType.getTypeClass.equals(classOf[Row])) { + if (outerJoin && isRowClass(multiInput) && returnType.getTypeClass.equals(classOf[Row])) { val inputRow = multiInput.asInstanceOf[Row] val countNullRecords = returnType.getTotalFields - inputRow.getArity val nullRecords= new Row(countNullRecords) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala index a22232df7ae03..9129fafbb4fe5 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala @@ -364,16 +364,9 @@ class JoinITCase( val table = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a1, 'a2, 'a3) tEnv.registerTable("A", table) - val sqlQuery1 = "SELECT * FROM A CROSS JOIN " + - "(SELECT count(*) FROM A HAVING count(*) < 0)" - val result = tEnv.sql(sqlQuery1) - val expected =Seq( - "2,2,Hello,null", - "1,1,Hi,null", - "3,2,Hello world,null").mkString("\n") - - val results = result.toDataSet[Row].collect() - TestBaseUtils.compareResultAsText(results.asJava, expected) + val sqlQuery1 = "SELECT * FROM A CROSS JOIN (SELECT count(*) FROM A HAVING count(*) < 0)" + val result = tEnv.sql(sqlQuery1).count() + Assert.assertEquals(0, result) } @Test @@ -383,10 +376,7 @@ class JoinITCase( val sqlQuery = "SELECT a, cnt " + - "FROM" + - " (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + - "LEFT JOIN A " + - "ON cnt = a" + "FROM (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) LEFT JOIN A ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -407,10 +397,7 @@ class JoinITCase( val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = "SELECT a, cnt " + - "FROM" + - " (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + - "RIGHT JOIN A " + - "ON a = cnt" + "FROM (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) RIGHT JOIN A ON a = cnt" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -436,11 +423,7 @@ class JoinITCase( val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = - "SELECT a, cnt " + - "FROM" + - " (SELECT COUNT(*) AS cnt FROM A) " + - "LEFT JOIN B " + - "ON cnt = a" + "SELECT a, cnt FROM (SELECT COUNT(*) AS cnt FROM A) LEFT JOIN B ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv)as('a, 'b, 'c) @@ -449,10 +432,7 @@ class JoinITCase( val result = tEnv.sql(sqlQuery) val expected = Seq( - "1,null", "2,null", "2,null", "3,3", "3,3", - "3,3", "4,null", "4,null", "4,null", - "4,null", "5,null", "5,null", "5,null", - "5,null", "5,null").mkString("\n") + "3,3", "3,3", "3,3").mkString("\n") val results = result.toDataSet[Row].collect() @@ -464,11 +444,7 @@ class JoinITCase( val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = - "SELECT a, cnt " + - "FROM" + - " (SELECT COUNT(*) AS cnt FROM B) " + - "RIGHT JOIN A " + - "ON cnt = a" + "SELECT a, cnt FROM (SELECT COUNT(*) AS cnt FROM B) RIGHT JOIN A ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -493,10 +469,7 @@ class JoinITCase( val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = "SELECT a, cnt " + - "FROM" + - " A " + - "LEFT JOIN (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + - "ON cnt = a" + "FROM A LEFT JOIN (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) @@ -519,10 +492,7 @@ class JoinITCase( val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = "SELECT a, cnt " + - "FROM A " + - "RIGHT JOIN" + - " (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) " + - "ON cnt = a" + "FROM A RIGHT JOIN (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) @@ -542,11 +512,7 @@ class JoinITCase( val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = - "SELECT a, cnt " + - "FROM" + - " A " + - "LEFT JOIN (SELECT COUNT(*) AS cnt FROM B) " + - "ON cnt = a" + "SELECT a, cnt FROM A LEFT JOIN (SELECT COUNT(*) AS cnt FROM B) ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -571,11 +537,7 @@ class JoinITCase( val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = - "SELECT a, cnt " + - "FROM" + - " A " + - "RIGHT JOIN (SELECT COUNT(*) AS cnt FROM B) " + - "ON cnt = a" + "SELECT a, cnt FROM A RIGHT JOIN (SELECT COUNT(*) AS cnt FROM B) ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -596,10 +558,8 @@ class JoinITCase( val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = - "SELECT a,cnt, cnt2 " + - "FROM t1 " + - "LEFT JOIN (SELECT COUNT(*) AS cnt,COUNT(*) AS cnt2 FROM t2 ) AS x " + - "ON a = cnt" + "SELECT a, cnt, cnt2 " + + "FROM t1 LEFT JOIN (SELECT COUNT(*) AS cnt,COUNT(*) AS cnt2 FROM t2 ) AS x ON a = cnt" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) From 6ef3bb6028e883d628668efa63fda2e5d256a7cf Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Thu, 4 May 2017 21:35:36 +0300 Subject: [PATCH 05/12] [FLINK-5256] Extend DataSetSingleRowJoin to support Left and Right joins --- .../scala/org/apache/flink/table/codegen/generated.scala | 6 +----- .../table/plan/nodes/dataset/DataSetSingleRowJoin.scala | 8 ++++---- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala index c58f14d43e5e7..6673afd14d9a8 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala @@ -35,11 +35,7 @@ case class GeneratedExpression( resultTerm: String, nullTerm: String, code: String, - resultType: TypeInformation[_]) { - - var leftNullTerm: String = _ - var rightNullTerm: String = _ -} + resultType: TypeInformation[_]) object GeneratedExpression { val ALWAYS_NULL = "true" diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala index 9491456be98d9..c408c3b69ff2a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala @@ -124,14 +124,14 @@ class DataSetSingleRowJoin( broadcastInputSetName: String) : FlatMapFunction[Row, Row] = { - val nullCheck = joinType match { + val isOuterJoin = joinType match { case JoinRelType.LEFT | JoinRelType.RIGHT => true case _ => false } val codeGenerator = new CodeGenerator( config, - nullCheck, + isOuterJoin, inputType1, Some(inputType2)) @@ -187,14 +187,14 @@ class DataSetSingleRowJoin( new MapJoinLeftRunner[Row, Row, Row]( genFunction.name, genFunction.code, - nullCheck, + isOuterJoin, genFunction.returnType, broadcastInputSetName) } else { new MapJoinRightRunner[Row, Row, Row]( genFunction.name, genFunction.code, - nullCheck, + isOuterJoin, genFunction.returnType, broadcastInputSetName) } From 397377ad23656f56a3e0d8b714e93ef6d5a46abd Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Thu, 4 May 2017 21:39:22 +0300 Subject: [PATCH 06/12] [FLINK-5256] Extend DataSetSingleRowJoin to support Left and Right joins --- .../main/scala/org/apache/flink/table/codegen/generated.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala index 6673afd14d9a8..e26fb646ed0cd 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/generated.scala @@ -35,7 +35,7 @@ case class GeneratedExpression( resultTerm: String, nullTerm: String, code: String, - resultType: TypeInformation[_]) + resultType: TypeInformation[_]) object GeneratedExpression { val ALWAYS_NULL = "true" From 4acaa72e9f1fab7e2ce9cdf9c67786042fbcb59c Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Fri, 5 May 2017 12:49:04 +0300 Subject: [PATCH 07/12] [FLINK-5256] Add joinType to explain method --- .../nodes/dataset/DataSetSingleRowJoin.scala | 2 +- .../batch/sql/DataSetSingleRowJoinTest.scala | 104 +++++++++++++++++- .../batch/sql/DistinctAggregateTest.scala | 6 +- 3 files changed, 103 insertions(+), 9 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala index c408c3b69ff2a..5c678ac647884 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala @@ -214,7 +214,7 @@ class DataSetSingleRowJoin( } private def joinTypeToString: String = { - "NestedLoopJoin" + "NestedLoop"+joinType.toString.toLowerCase.capitalize+"Join" } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala index 7c9b0b05590b7..a5fe845a9db90 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala @@ -63,7 +63,7 @@ class DataSetSingleRowJoinTest extends TableTestBase { ), term("where", "true"), term("join", "a1", "asum"), - term("joinType", "NestedLoopJoin") + term("joinType", "NestedLoopInnerJoin") ) util.verifySql(query, expected) @@ -105,7 +105,7 @@ class DataSetSingleRowJoinTest extends TableTestBase { ), term("where", "=(CAST(a1), cnt)"), term("join", "a1", "a2", "cnt"), - term("joinType", "NestedLoopJoin") + term("joinType", "NestedLoopInnerJoin") ), term("select", "a1", "a2") ) @@ -149,7 +149,7 @@ class DataSetSingleRowJoinTest extends TableTestBase { ), term("where", "<(a1, cnt)"), term("join", "a1", "a2", "cnt"), - term("joinType", "NestedLoopJoin") + term("joinType", "NestedLoopInnerJoin") ), term("select", "a1", "a2") ) @@ -187,12 +187,106 @@ class DataSetSingleRowJoinTest extends TableTestBase { ), term("where", "AND(<(a1, b1)", "=(a2, b2))"), term("join", "a1", "a2", "b1", "b2"), - term("joinType", "NestedLoopJoin") + term("joinType", "NestedLoopInnerJoin") ) util.verifySql(query, expected) } + @Test + def testSingleRowJoinLeftOuterJoin(): Unit = { + val util = batchTestUtil() + util.addTable[(Long, Int)]("A", 'a1, 'a2) + util.addTable[(Int, Int)]("B", 'b1, 'b2) + + val queryLeftJoin = + "SELECT a2 FROM A " + + "LEFT JOIN " + + "(SELECT COUNT(*) AS cnt FROM B) " + + "AS x " + + "ON a1 = cnt" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + term("where", "=(a1, cnt)"), + term("join", "a1", "a2", "cnt"), + term("joinType", "NestedLoopLeftJoin") + ), + term("select", "a2") + ) + "\n" + + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + tuples(List(null)), term("values", "$f0") + ), + term("union", "$f0") + ), + term("select", "COUNT(*) AS cnt") + ) + + util.verifySql(queryLeftJoin, expected) + } + + @Test + def testSingleRowJoinRightOuterJoin(): Unit = { + val util = batchTestUtil() + util.addTable[(Long, Int)]("A", 'a1, 'a2) + util.addTable[(Int, Int)]("B", 'b1, 'b2) + + val queryRightJoin = + "SELECT a2 FROM A " + + "RIGHT JOIN " + + "(SELECT COUNT(*) AS cnt FROM B) " + + "AS x " + + "ON a1 = cnt" + + //val queryRightJoin = + // "SELECT a2 FROM (SELECT COUNT(*) AS cnt FROM B) RIGHT JOIN A ON a1 < cnt" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetJoin", + batchTableNode(0), + term("where", "=(a1, cnt)"), + term("join", "a1", "a2", "cnt"), + term("joinType", "RightOuterJoin") + ), + term("select", "a2") + ) + "\n" + + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + tuples(List(null)), term("values", "$f0") + ), + term("union", "$f0") + ), + term("select", "COUNT(*) AS cnt") + ) + + util.verifySql(queryRightJoin, expected) + } + + @Test def testSingleRowJoinInnerJoin(): Unit = { val util = batchTestUtil() @@ -216,7 +310,7 @@ class DataSetSingleRowJoinTest extends TableTestBase { ), term("where", ">(EXPR$1, EXPR$0)"), term("join", "a2", "EXPR$1", "EXPR$0"), - term("joinType", "NestedLoopJoin") + term("joinType", "NestedLoopInnerJoin") ), term("select", "a2", "EXPR$1") ) + "\n" + diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala index 54b4d24a767ae..85cfb18f521a5 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DistinctAggregateTest.scala @@ -211,7 +211,7 @@ class DistinctAggregateTest extends TableTestBase { ), term("where", "true"), term("join", "EXPR$0", "EXPR$1"), - term("joinType", "NestedLoopJoin") + term("joinType", "NestedLoopInnerJoin") ) util.verifySql(sqlQuery, expected) @@ -268,7 +268,7 @@ class DistinctAggregateTest extends TableTestBase { ), term("where", "true"), term("join", "EXPR$2, EXPR$0"), - term("joinType", "NestedLoopJoin") + term("joinType", "NestedLoopInnerJoin") ), unaryNode( "DataSetAggregate", @@ -294,7 +294,7 @@ class DistinctAggregateTest extends TableTestBase { ), term("where", "true"), term("join", "EXPR$2", "EXPR$0, EXPR$1"), - term("joinType", "NestedLoopJoin") + term("joinType", "NestedLoopInnerJoin") ), term("select", "EXPR$0, EXPR$1, EXPR$2") ) From 1e567048413ee22bee73ef7c01b9f2364fe97f5a Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Fri, 5 May 2017 13:34:44 +0300 Subject: [PATCH 08/12] [FLINK-5256] Fix style --- .../flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala index 5c678ac647884..625da96f0b3b2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala @@ -214,7 +214,7 @@ class DataSetSingleRowJoin( } private def joinTypeToString: String = { - "NestedLoop"+joinType.toString.toLowerCase.capitalize+"Join" + "NestedLoop" + joinType.toString.toLowerCase.capitalize + "Join" } } From c53710e69cbafd582561745e1b1577b986ef70a1 Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Mon, 8 May 2017 12:35:00 +0300 Subject: [PATCH 09/12] [FLINK-5256] Fix conflict --- .../org/apache/flink/table/runtime/MapJoinLeftRunner.scala | 2 +- .../org/apache/flink/table/runtime/MapJoinRightRunner.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala index a902b9b569d4e..bfc818dc142b6 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala @@ -26,7 +26,7 @@ class MapJoinLeftRunner[IN1, IN2, OUT]( name: String, code: String, outerJoin: Boolean, - @transient returnType: TypeInformation[OUT], + returnType: TypeInformation[OUT], broadcastSetName: String) extends MapSideJoinRunner[IN1, IN2, IN2, IN1, OUT](name, code, returnType, broadcastSetName) { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala index c2287fae6f1a3..e5afb0df46fc4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala @@ -26,7 +26,7 @@ class MapJoinRightRunner[IN1, IN2, OUT]( name: String, code: String, outerJoin: Boolean, - @transient returnType: TypeInformation[OUT], + returnType: TypeInformation[OUT], broadcastSetName: String) extends MapSideJoinRunner[IN1, IN2, IN1, IN2, OUT](name, code, returnType, broadcastSetName) { From 417d0a6e20c65afecfb3bff77d549d8920a7cfcf Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Wed, 10 May 2017 17:39:49 +0300 Subject: [PATCH 10/12] [FLINK-5256] Extend DataSetSingleRowJoin to support Left and Right joins --- .../nodes/dataset/DataSetSingleRowJoin.scala | 74 +++++++++---------- .../table/runtime/MapJoinLeftRunner.scala | 2 + .../table/runtime/MapJoinRightRunner.scala | 2 + .../batch/sql/DataSetSingleRowJoinTest.scala | 49 ------------ 4 files changed, 41 insertions(+), 86 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala index 625da96f0b3b2..6c8fa90e867c3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala @@ -153,29 +153,29 @@ class DataSetSingleRowJoin( |} |""".stripMargin } else { - val singleNode = - if (!leftIsSingle) { - rightNode - } - else { - leftNode - } - - val notSuitedToCondition = singleNode - .getRowType - .getFieldList - .map(field => getRowType.getFieldNames.indexOf(field.getName)) - .map(i => s"${conversion.resultTerm}.setField($i,null);") - - s""" - |${condition.code} - |${conversion.code} - |if(!${condition.resultTerm}){ - |${notSuitedToCondition.mkString("\n")} - |} - |${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); - |""".stripMargin + val singleNode = + if (!leftIsSingle) { + rightNode } + else { + leftNode + } + + val notSuitedToCondition = singleNode + .getRowType + .getFieldList + .map(field => getRowType.getFieldNames.indexOf(field.getName)) + .map(i => s"${conversion.resultTerm}.setField($i,null);") + + s""" + |${condition.code} + |${conversion.code} + |if(!${condition.resultTerm}){ + |${notSuitedToCondition.mkString("\n")} + |} + |${codeGenerator.collectorTerm}.collect(${conversion.resultTerm}); + |""".stripMargin + } val genFunction = codeGenerator.generateFunction( ruleDescription, @@ -183,21 +183,21 @@ class DataSetSingleRowJoin( joinMethodBody, returnType) - if (!leftIsSingle) { - new MapJoinLeftRunner[Row, Row, Row]( - genFunction.name, - genFunction.code, - isOuterJoin, - genFunction.returnType, - broadcastInputSetName) - } else { - new MapJoinRightRunner[Row, Row, Row]( - genFunction.name, - genFunction.code, - isOuterJoin, - genFunction.returnType, - broadcastInputSetName) - } + if (!leftIsSingle) { + new MapJoinLeftRunner[Row, Row, Row]( + genFunction.name, + genFunction.code, + isOuterJoin, + genFunction.returnType, + broadcastInputSetName) + } else { + new MapJoinRightRunner[Row, Row, Row]( + genFunction.name, + genFunction.code, + isOuterJoin, + genFunction.returnType, + broadcastInputSetName) + } } private def getMapOperatorName: String = { diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala index bfc818dc142b6..170e15819669b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala @@ -33,6 +33,8 @@ class MapJoinLeftRunner[IN1, IN2, OUT]( override def flatMap(multiInput: IN1, out: Collector[OUT]): Unit = { broadcastSet match { case Some(singleInput) => function.join(multiInput, singleInput, out) + case None if outerJoin => function. + join(multiInput, null.asInstanceOf[IN2], out) case None => { if (outerJoin && isRowClass(multiInput) && returnType.getTypeClass.equals(classOf[Row])) { val inputRow = multiInput.asInstanceOf[Row] diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala index e5afb0df46fc4..6cc45e9bcd04e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala @@ -33,6 +33,8 @@ class MapJoinRightRunner[IN1, IN2, OUT]( override def flatMap(multiInput: IN2, out: Collector[OUT]): Unit = { broadcastSet match { case Some(singleInput) => function.join(singleInput, multiInput, out) + case None if outerJoin => function. + join(null.asInstanceOf[IN1], multiInput, out) case None => if (outerJoin && isRowClass(multiInput) && returnType.getTypeClass.equals(classOf[Row])) { val inputRow = multiInput.asInstanceOf[Row] diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala index a5fe845a9db90..5de3f9de00028 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala @@ -238,55 +238,6 @@ class DataSetSingleRowJoinTest extends TableTestBase { util.verifySql(queryLeftJoin, expected) } - @Test - def testSingleRowJoinRightOuterJoin(): Unit = { - val util = batchTestUtil() - util.addTable[(Long, Int)]("A", 'a1, 'a2) - util.addTable[(Int, Int)]("B", 'b1, 'b2) - - val queryRightJoin = - "SELECT a2 FROM A " + - "RIGHT JOIN " + - "(SELECT COUNT(*) AS cnt FROM B) " + - "AS x " + - "ON a1 = cnt" - - //val queryRightJoin = - // "SELECT a2 FROM (SELECT COUNT(*) AS cnt FROM B) RIGHT JOIN A ON a1 < cnt" - - val expected = - unaryNode( - "DataSetCalc", - unaryNode( - "DataSetJoin", - batchTableNode(0), - term("where", "=(a1, cnt)"), - term("join", "a1", "a2", "cnt"), - term("joinType", "RightOuterJoin") - ), - term("select", "a2") - ) + "\n" + - unaryNode( - "DataSetAggregate", - unaryNode( - "DataSetUnion", - unaryNode( - "DataSetValues", - unaryNode( - "DataSetCalc", - batchTableNode(1), - term("select", "0 AS $f0")), - tuples(List(null)), term("values", "$f0") - ), - term("union", "$f0") - ), - term("select", "COUNT(*) AS cnt") - ) - - util.verifySql(queryRightJoin, expected) - } - - @Test def testSingleRowJoinInnerJoin(): Unit = { val util = batchTestUtil() From ef2fc0e4574f6b04e1059c5ea2f1681dfd1b8557 Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Thu, 11 May 2017 15:43:55 +0300 Subject: [PATCH 11/12] [FLINK-5256] Review fix --- .../plan/nodes/logical/FlinkLogicalJoin.scala | 13 +- .../table/runtime/MapJoinLeftRunner.scala | 13 +- .../table/runtime/MapJoinRightRunner.scala | 10 -- .../batch/sql/DataSetSingleRowJoinTest.scala | 135 +++++++++++++++++- .../api/scala/batch/sql/JoinITCase.scala | 83 ++++------- 5 files changed, 171 insertions(+), 83 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala index 8df0b59c828c6..beff3ea423580 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala @@ -77,7 +77,9 @@ private class FlinkLogicalJoinConverter val join: LogicalJoin = call.rel(0).asInstanceOf[LogicalJoin] val joinInfo = join.analyzeCondition - hasEqualityPredicates(join, joinInfo) || isSingleRowInnerJoin(join) + (hasEqualityPredicates(join, joinInfo) + || isSingleRowInnerJoin(join) + || isOuterJoinWithSingleRowAtOuterSide(join, joinInfo)) } override def convert(rel: RelNode): RelNode = { @@ -101,6 +103,15 @@ private class FlinkLogicalJoinConverter !joinInfo.pairs().isEmpty && (joinInfo.isEqui || join.getJoinType == JoinRelType.INNER) } + + + private def isOuterJoinWithSingleRowAtOuterSide (join: LogicalJoin, joinInfo: JoinInfo): Boolean = { + val isLeflSingleOrEmpty = joinInfo.leftKeys.size() < 2 + val isRightSingleOrEmpty = joinInfo.rightKeys.size() < 2 + ((join.getJoinType == JoinRelType.RIGHT && isLeflSingleOrEmpty) + || (join.getJoinType == JoinRelType.LEFT && isRightSingleOrEmpty)) + } + private def isSingleRowInnerJoin(join: LogicalJoin): Boolean = { if (join.getJoinType == JoinRelType.INNER) { isSingleRow(join.getRight) || isSingleRow(join.getLeft) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala index 170e15819669b..f2812ae381367 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinLeftRunner.scala @@ -35,19 +35,8 @@ class MapJoinLeftRunner[IN1, IN2, OUT]( case Some(singleInput) => function.join(multiInput, singleInput, out) case None if outerJoin => function. join(multiInput, null.asInstanceOf[IN2], out) - case None => { - if (outerJoin && isRowClass(multiInput) && returnType.getTypeClass.equals(classOf[Row])) { - val inputRow = multiInput.asInstanceOf[Row] - val countNullRecords = returnType.getTotalFields - inputRow.getArity - val nullRecords = new Row(countNullRecords) - function.join(multiInput, nullRecords.asInstanceOf[IN2], out) - } - } + case None => } } - private def isRowClass(obj: Any) = obj match { - case r: Row => true - case _ => false - } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala index 6cc45e9bcd04e..2e31008b058f9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/MapJoinRightRunner.scala @@ -36,17 +36,7 @@ class MapJoinRightRunner[IN1, IN2, OUT]( case None if outerJoin => function. join(null.asInstanceOf[IN1], multiInput, out) case None => - if (outerJoin && isRowClass(multiInput) && returnType.getTypeClass.equals(classOf[Row])) { - val inputRow = multiInput.asInstanceOf[Row] - val countNullRecords = returnType.getTotalFields - inputRow.getArity - val nullRecords= new Row(countNullRecords) - function.join(nullRecords.asInstanceOf[IN1], multiInput, out) - } } } - private def isRowClass(obj: Any) = obj match { - case r: Row => true - case _ => false - } } diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala index 5de3f9de00028..e56fe01848b57 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala @@ -194,7 +194,7 @@ class DataSetSingleRowJoinTest extends TableTestBase { } @Test - def testSingleRowJoinLeftOuterJoin(): Unit = { + def testRightSingleLeftJoinEqualPredicate(): Unit = { val util = batchTestUtil() util.addTable[(Long, Int)]("A", 'a1, 'a2) util.addTable[(Int, Int)]("B", 'b1, 'b2) @@ -238,6 +238,139 @@ class DataSetSingleRowJoinTest extends TableTestBase { util.verifySql(queryLeftJoin, expected) } + @Test + def testRightSingleLeftJoinNotEqualPredicate(): Unit = { + val util = batchTestUtil() + util.addTable[(Long, Int)]("A", 'a1, 'a2) + util.addTable[(Int, Int)]("B", 'b1, 'b2) + + val queryLeftJoin = + "SELECT a2 FROM A " + + "LEFT JOIN " + + "(SELECT COUNT(*) AS cnt FROM B) " + + "AS x " + + "ON a1 > cnt" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + batchTableNode(0), + term("where", ">(a1, cnt)"), + term("join", "a1", "a2", "cnt"), + term("joinType", "NestedLoopLeftJoin") + ), + term("select", "a2") + ) + "\n" + + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + tuples(List(null)), term("values", "$f0") + ), + term("union", "$f0") + ), + term("select", "COUNT(*) AS cnt") + ) + + util.verifySql(queryLeftJoin, expected) + } + + @Test + def testLeftSingleRightJoinNotEqualPredicate(): Unit = { + val util = batchTestUtil() + util.addTable[(Long, Long)]("A", 'a1, 'a2) + util.addTable[(Long, Long)]("B", 'b1, 'b2) + + val queryRightJoin = + "SELECT a1 FROM (SELECT COUNT(*) AS cnt FROM B) " + + "RIGHT JOIN A " + + "ON cnt < a2" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + "", + term("where", "<(cnt, a2)"), + term("join", "cnt", "a1", "a2"), + term("joinType", "NestedLoopRightJoin") + ), + term("select", "a1") + ) + + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + tuples(List(null)), term("values", "$f0") + ), + term("union", "$f0") + ), + term("select", "COUNT(*) AS cnt") + )+ "\n" + + batchTableNode(0) + + util.verifySql(queryRightJoin, expected) + } + + @Test + def testLeftSingleRightJoinEqualPredicate(): Unit = { + val util = batchTestUtil() + util.addTable[(Long, Long)]("A", 'a1, 'a2) + util.addTable[(Long, Long)]("B", 'b1, 'b2) + + val queryRightJoin = + "SELECT a1 FROM (SELECT COUNT(*) AS cnt FROM B) " + + "RIGHT JOIN A " + + "ON cnt = a2" + + val expected = + unaryNode( + "DataSetCalc", + unaryNode( + "DataSetSingleRowJoin", + "", + term("where", "=(cnt, a2)"), + term("join", "cnt", "a1", "a2"), + term("joinType", "NestedLoopRightJoin") + ), + term("select", "a1") + ) + + unaryNode( + "DataSetAggregate", + unaryNode( + "DataSetUnion", + unaryNode( + "DataSetValues", + unaryNode( + "DataSetCalc", + batchTableNode(1), + term("select", "0 AS $f0")), + tuples(List(null)), term("values", "$f0") + ), + term("union", "$f0") + ), + term("select", "COUNT(*) AS cnt") + )+ "\n" + + batchTableNode(0) + + util.verifySql(queryRightJoin, expected) + } + @Test def testSingleRowJoinInnerJoin(): Unit = { val util = batchTestUtil() diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala index 9129fafbb4fe5..3ed44a32b54aa 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/JoinITCase.scala @@ -369,35 +369,13 @@ class JoinITCase( Assert.assertEquals(0, result) } - @Test - def testLeftNullLeftJoin (): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env, config) - - val sqlQuery = - "SELECT a, cnt " + - "FROM (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) LEFT JOIN A ON cnt = a" - - val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) - val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) - tEnv.registerTable("A", ds1) - tEnv.registerTable("B", ds2) - - val result = tEnv.sql(sqlQuery).collect() - val resultSize = result.size - - Assert.assertEquals( - s"Expected empty result, but actual size result = $resultSize;\n[${result.mkString(",")}]", - resultSize,0) - } - @Test def testLeftNullRightJoin(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = "SELECT a, cnt " + - "FROM (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) RIGHT JOIN A ON a = cnt" + "FROM (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) RIGHT JOIN A ON a < cnt" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -418,21 +396,25 @@ class JoinITCase( TestBaseUtils.compareResultAsText(results.asJava, expected) } + @Test - def testLeftSingleLeftJoin(): Unit = { + def testLeftSingleRightJoinEqualPredicate(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = - "SELECT a, cnt FROM (SELECT COUNT(*) AS cnt FROM A) LEFT JOIN B ON cnt = a" + "SELECT a, cnt FROM (SELECT COUNT(*) AS cnt FROM B) RIGHT JOIN A ON cnt = a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) - val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv)as('a, 'b, 'c) - tEnv.registerTable("A", ds2) - tEnv.registerTable("B", ds1) + val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) + tEnv.registerTable("A", ds1) + tEnv.registerTable("B", ds2) val result = tEnv.sql(sqlQuery) val expected = Seq( - "3,3", "3,3", "3,3").mkString("\n") + "1,null", "2,null", "2,null", "3,3", "3,3", + "3,3", "4,null", "4,null", "4,null", + "4,null", "5,null", "5,null", "5,null", + "5,null", "5,null").mkString("\n") val results = result.toDataSet[Row].collect() @@ -440,11 +422,11 @@ class JoinITCase( } @Test - def testLeftSingleRightJoin(): Unit = { + def testLeftSingleRightJoinNotEqualPredicate(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = - "SELECT a, cnt FROM (SELECT COUNT(*) AS cnt FROM B) RIGHT JOIN A ON cnt = a" + "SELECT a, cnt FROM (SELECT COUNT(*) AS cnt FROM B) RIGHT JOIN A ON cnt > a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -453,8 +435,8 @@ class JoinITCase( val result = tEnv.sql(sqlQuery) val expected = Seq( - "1,null", "2,null", "2,null", "3,3", "3,3", - "3,3", "4,null", "4,null", "4,null", + "1,3", "2,3", "2,3", "3,null", "3,null", + "3,null", "4,null", "4,null", "4,null", "4,null", "5,null", "5,null", "5,null", "5,null", "5,null").mkString("\n") @@ -469,7 +451,7 @@ class JoinITCase( val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = "SELECT a, cnt " + - "FROM A LEFT JOIN (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) ON cnt = a" + "FROM A LEFT JOIN (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) ON cnt > a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) @@ -487,28 +469,7 @@ class JoinITCase( } @Test - def testRightNullRightJoin(): Unit = { - val env = ExecutionEnvironment.getExecutionEnvironment - val tEnv = TableEnvironment.getTableEnvironment(env, config) - val sqlQuery = - "SELECT a, cnt " + - "FROM A RIGHT JOIN (SELECT cnt FROM (SELECT COUNT(*) AS cnt FROM B) WHERE cnt < 0) ON cnt = a" - - val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv) - val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c) - tEnv.registerTable("A", ds2) - tEnv.registerTable("B", ds1) - - val result = tEnv.sql(sqlQuery).collect() - val resultSize = result.size - - Assert.assertEquals( - s"Expected empty result, but actual size result = $resultSize;\n[${result.mkString(",")}]", - resultSize,0) - } - - @Test - def testRightSingleLeftJoin(): Unit = { + def testRightSingleLeftJoinEqualPredicate(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = @@ -533,11 +494,11 @@ class JoinITCase( } @Test - def testRightSingleRightJoin(): Unit = { + def testRightSingleLeftJoinNotEqualPredicate(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) val sqlQuery = - "SELECT a, cnt FROM A RIGHT JOIN (SELECT COUNT(*) AS cnt FROM B) ON cnt = a" + "SELECT a, cnt FROM A LEFT JOIN (SELECT COUNT(*) AS cnt FROM B) ON cnt < a" val ds1 = CollectionDataSets.get5TupleDataSet(env).toTable(tEnv).as('a, 'b, 'c, 'd, 'e) val ds2 = CollectionDataSets.getSmall3TupleDataSet(env).toTable(tEnv) @@ -545,8 +506,12 @@ class JoinITCase( tEnv.registerTable("B", ds2) val result = tEnv.sql(sqlQuery) + val expected = Seq( - "3,3", "3,3", "3,3").mkString("\n") + "1,null", "2,null", "2,null", "3,null", "3,null", + "3,null", "4,3", "4,3", "4,3", + "4,3", "5,3", "5,3", "5,3", + "5,3", "5,3").mkString("\n") val results = result.toDataSet[Row].collect() From 6f9b311bf9d20d824b955fc736b4fa86e5759a12 Mon Sep 17 00:00:00 2001 From: DmytroShkvyra Date: Thu, 11 May 2017 16:56:13 +0300 Subject: [PATCH 12/12] [FLINK-5256] Fix style --- .../table/plan/nodes/logical/FlinkLogicalJoin.scala | 5 ++++- .../scala/batch/sql/DataSetSingleRowJoinTest.scala | 11 +++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala index beff3ea423580..18621d0b296f2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/logical/FlinkLogicalJoin.scala @@ -105,7 +105,10 @@ private class FlinkLogicalJoinConverter - private def isOuterJoinWithSingleRowAtOuterSide (join: LogicalJoin, joinInfo: JoinInfo): Boolean = { + private def isOuterJoinWithSingleRowAtOuterSide( + join: LogicalJoin, + joinInfo: JoinInfo): Boolean = { + val isLeflSingleOrEmpty = joinInfo.leftKeys.size() < 2 val isRightSingleOrEmpty = joinInfo.rightKeys.size() < 2 ((join.getJoinType == JoinRelType.RIGHT && isLeflSingleOrEmpty) diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala index e56fe01848b57..37546975f2201 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/sql/DataSetSingleRowJoinTest.scala @@ -306,7 +306,7 @@ class DataSetSingleRowJoinTest extends TableTestBase { ), term("select", "a1") ) + - unaryNode( + unaryNode( "DataSetAggregate", unaryNode( "DataSetUnion", @@ -321,8 +321,8 @@ class DataSetSingleRowJoinTest extends TableTestBase { term("union", "$f0") ), term("select", "COUNT(*) AS cnt") - )+ "\n" + - batchTableNode(0) + ) + "\n" + + batchTableNode(0) util.verifySql(queryRightJoin, expected) } @@ -349,8 +349,7 @@ class DataSetSingleRowJoinTest extends TableTestBase { term("joinType", "NestedLoopRightJoin") ), term("select", "a1") - ) + - unaryNode( + ) + unaryNode( "DataSetAggregate", unaryNode( "DataSetUnion", @@ -365,7 +364,7 @@ class DataSetSingleRowJoinTest extends TableTestBase { term("union", "$f0") ), term("select", "COUNT(*) AS cnt") - )+ "\n" + + ) + "\n" + batchTableNode(0) util.verifySql(queryRightJoin, expected)