From 3416f3ae86e1af186a0ccd477f2fa904d15a7cf3 Mon Sep 17 00:00:00 2001 From: vasia Date: Thu, 25 Feb 2016 20:52:19 +0100 Subject: [PATCH] [FLINK-3482] implement union translation - implement custom JoinUnionTransposeRules because Calcite's only match with LogicalUnion --- .../plan/nodes/dataset/DataSetUnion.scala | 5 +- .../api/table/plan/rules/FlinkRuleSets.scala | 4 +- .../logical/FlinkJoinUnionTransposeRule.scala | 110 ++++++++++++++++++ .../api/java/table/test/UnionITCase.java | 32 ++++- .../api/scala/table/test/UnionITCase.scala | 35 +++++- 5 files changed, 172 insertions(+), 14 deletions(-) create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/logical/FlinkJoinUnionTransposeRule.scala diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetUnion.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetUnion.scala index ebfd48a5ff87c..462c4a5b2b2a3 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetUnion.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/nodes/dataset/DataSetUnion.scala @@ -59,7 +59,10 @@ class DataSetUnion( override def translateToPlan( config: TableConfig, expectedType: Option[TypeInformation[Any]]): DataSet[Any] = { - ??? + + val leftDataSet = left.asInstanceOf[DataSetRel].translateToPlan(config) + val rightDataSet = right.asInstanceOf[DataSetRel].translateToPlan(config) + leftDataSet.union(rightDataSet).asInstanceOf[DataSet[Any]] } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala index 32d9f0dcf3f6e..b5c3800efd20c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/FlinkRuleSets.scala @@ -51,8 +51,8 @@ object FlinkRuleSets { // merge and push unions rules // TODO: Add a rule to enforce binary unions UnionEliminatorRule.INSTANCE, - JoinUnionTransposeRule.LEFT_UNION, - JoinUnionTransposeRule.RIGHT_UNION, + FlinkJoinUnionTransposeRule.LEFT_UNION, + FlinkJoinUnionTransposeRule.RIGHT_UNION, // non-all Union to all-union + distinct UnionToDistinctRule.INSTANCE, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/logical/FlinkJoinUnionTransposeRule.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/logical/FlinkJoinUnionTransposeRule.scala new file mode 100644 index 0000000000000..af54f37518e71 --- /dev/null +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/plan/rules/logical/FlinkJoinUnionTransposeRule.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.api.table.plan.rules.logical + +import org.apache.calcite.plan.RelOptRule.{any, operand, convert => convertTrait} +import org.apache.calcite.plan.RelOptRule +import org.apache.calcite.plan.RelOptRuleOperand +import org.apache.calcite.plan.RelOptRuleCall +import org.apache.calcite.rel.RelNode +import java.util.ArrayList +import scala.collection.JavaConversions._ +import org.apache.calcite.rel.logical.LogicalJoin +import org.apache.calcite.rel.logical.LogicalUnion +import org.apache.calcite.rel.core.Join +import org.apache.calcite.rel.core.Union + +/** + * This rule is a copy of Calcite's JoinUnionTransposeRule. + * Calcite's implementation checks whether one of the operands is a LogicalUnion, + * which fails in our case, when it matches with a FlinkUnion. + * This rule changes this check to match Union, instead of LogicalUnion only. + * The rest of the rule's logic has not been changed. + */ +class FlinkJoinUnionTransposeRule( + operand: RelOptRuleOperand, + description: String) extends RelOptRule(operand, description) { + + override def onMatch(call: RelOptRuleCall): Unit = { + val join = call.rel(0).asInstanceOf[Join] + val (unionRel: Union, otherInput: RelNode, unionOnLeft: Boolean) = { + if (call.rel(1).isInstanceOf[Union]) { + (call.rel(1).asInstanceOf[Union], call.rel(2).asInstanceOf[RelNode], true) + } + else { + (call.rel(2).asInstanceOf[Union], call.rel(1).asInstanceOf[RelNode], false) + } + } + + if (!unionRel.all) { + return + } + if (!join.getVariablesStopped.isEmpty) { + return + } + // The UNION ALL cannot be on the null generating side + // of an outer join (otherwise we might generate incorrect + // rows for the other side for join keys which lack a match + // in one or both branches of the union) + if (unionOnLeft) { + if (join.getJoinType.generatesNullsOnLeft) { + return + } + } + else { + if (join.getJoinType.generatesNullsOnRight) { + return + } + } + val newUnionInputs = new ArrayList[RelNode] + for (input <- unionRel.getInputs) { + val (joinLeft: RelNode, joinRight: RelNode) = { + if (unionOnLeft) { + (input, otherInput) + } + else { + (otherInput, input) + } + } + + newUnionInputs.add( + join.copy( + join.getTraitSet, + join.getCondition, + joinLeft, + joinRight, + join.getJoinType, + join.isSemiJoinDone)) + } + val newUnionRel = unionRel.copy(unionRel.getTraitSet, newUnionInputs, true) + call.transformTo(newUnionRel) + } +} + +object FlinkJoinUnionTransposeRule { + val LEFT_UNION = new FlinkJoinUnionTransposeRule( + operand(classOf[LogicalJoin], operand(classOf[LogicalUnion], any), + operand(classOf[RelNode], any)), + "JoinUnionTransposeRule(Union-Other)") + + val RIGHT_UNION = new FlinkJoinUnionTransposeRule( + operand(classOf[LogicalJoin], operand(classOf[RelNode], any), + operand(classOf[LogicalUnion], any)), + "JoinUnionTransposeRule(Other-Union)") +} diff --git a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/UnionITCase.java b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/UnionITCase.java index 8876dc8cf14f8..75429c24d018c 100644 --- a/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/UnionITCase.java +++ b/flink-libraries/flink-table/src/test/java/org/apache/flink/api/java/table/test/UnionITCase.java @@ -30,7 +30,6 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import scala.NotImplementedError; import java.util.List; @@ -41,7 +40,7 @@ public UnionITCase(TestExecutionMode mode) { super(mode); } - @Test(expected = NotImplementedError.class) + @Test public void testUnion() throws Exception { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); TableEnvironment tableEnv = new TableEnvironment(); @@ -60,7 +59,7 @@ public void testUnion() throws Exception { compareResultAsText(results, expected); } - @Test(expected = NotImplementedError.class) + @Test public void testUnionWithFilter() throws Exception { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); TableEnvironment tableEnv = new TableEnvironment(); @@ -117,7 +116,7 @@ public void testUnionFieldsNameNotOverlap2() throws Exception { compareResultAsText(results, expected); } - @Test(expected = NotImplementedError.class) + @Test public void testUnionWithAggregation() throws Exception { ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); TableEnvironment tableEnv = new TableEnvironment(); @@ -136,4 +135,27 @@ public void testUnionWithAggregation() throws Exception { compareResultAsText(results, expected); } -} + @Test + public void testUnionWithJoin() throws Exception { + ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment(); + TableEnvironment tableEnv = new TableEnvironment(); + + DataSet> ds1 = CollectionDataSets.getSmall3TupleDataSet(env); + DataSet> ds2 = CollectionDataSets.get5TupleDataSet(env); + DataSet> ds3 = CollectionDataSets.getSmall5TupleDataSet(env); + + Table in1 = tableEnv.fromDataSet(ds1, "a, b, c"); + Table in2 = tableEnv.fromDataSet(ds2, "a, b, d, c, e").select("a, b, c"); + Table in3 = tableEnv.fromDataSet(ds3, "a2, b2, d2, c2, e2").select("a2, b2, c2"); + + Table joinDs = in1.unionAll(in2).join(in3).where("a === a2").select("c, c2"); + DataSet ds = tableEnv.toDataSet(joinDs, Row.class); + List results = ds.collect(); + + String expected = "Hi,Hallo\n" + "Hallo,Hallo\n" + + "Hello,Hallo Welt\n" + "Hello,Hallo Welt wie\n" + + "Hallo Welt,Hallo Welt\n" + "Hallo Welt wie,Hallo Welt\n" + + "Hallo Welt,Hallo Welt wie\n" + "Hallo Welt wie,Hallo Welt wie\n"; + compareResultAsText(results, expected); + } +} \ No newline at end of file diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/UnionITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/UnionITCase.scala index 3d03f2305805e..3708107f1b402 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/UnionITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/api/scala/table/test/UnionITCase.scala @@ -23,17 +23,21 @@ import org.apache.flink.api.scala.table._ import org.apache.flink.api.scala.util.CollectionDataSets import org.apache.flink.api.table.{ExpressionException, Row} import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode -import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils} +import org.apache.flink.test.util.TestBaseUtils import org.junit._ import org.junit.runner.RunWith import org.junit.runners.Parameterized - import scala.collection.JavaConverters._ +import org.apache.flink.api.table.test.utils.TableProgramsTestBase +import TableProgramsTestBase.TableConfigMode @RunWith(classOf[Parameterized]) -class UnionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) { +class UnionITCase( + mode: TestExecutionMode, + configMode: TableConfigMode) + extends TableProgramsTestBase(mode, configMode) { - @Test(expected = classOf[NotImplementedError]) + @Test def testUnion(): Unit = { val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c) @@ -46,7 +50,7 @@ class UnionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode TestBaseUtils.compareResultAsText(results.asJava, expected) } - @Test(expected = classOf[NotImplementedError]) + @Test def testUnionWithFilter(): Unit = { val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c) @@ -85,7 +89,7 @@ class UnionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode TestBaseUtils.compareResultAsText(results.asJava, expected) } - @Test(expected = classOf[NotImplementedError]) + @Test def testUnionWithAggregation(): Unit = { val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c) @@ -97,4 +101,23 @@ class UnionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode val expected = "18" TestBaseUtils.compareResultAsText(results.asJava, expected) } + + @Test + def testUnionWithJoin(): Unit = { + val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment + val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c) + val ds2 = CollectionDataSets.get5TupleDataSet(env).as('a, 'b, 'd, 'c, 'e) + val ds3 = CollectionDataSets.getSmall5TupleDataSet(env).as('a2, 'b2, 'd2, 'c2, 'e2) + + val joinDs = ds1.unionAll(ds2.select('a, 'b, 'c)) + .join(ds3.select('a2, 'b2, 'c2)) + .where('a ==='a2).select('c, 'c2) + + val results = joinDs.toDataSet[Row].collect() + val expected = "Hi,Hallo\n" + "Hallo,Hallo\n" + + "Hello,Hallo Welt\n" + "Hello,Hallo Welt wie\n" + + "Hallo Welt,Hallo Welt\n" + "Hallo Welt wie,Hallo Welt\n" + + "Hallo Welt,Hallo Welt wie\n" + "Hallo Welt wie,Hallo Welt wie\n" + TestBaseUtils.compareResultAsText(results.asJava, expected) + } }