Skip to content

Commit

Permalink
[FLINK-3482] implement union translation
Browse files Browse the repository at this point in the history
- implement a custom JoinUnionTransposeRule,
because Calcite's version only matches LogicalUnion

This closes #1715
  • Loading branch information
vasia committed Feb 26, 2016
1 parent 2b36401 commit 4a63e29
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,10 @@ class DataSetUnion(
/**
 * Translates this union node into a DataSet program: both inputs are translated
 * first and the two resulting DataSets are combined with `union`.
 *
 * @param config       table configuration, passed unchanged to both inputs
 * @param expectedType requested result type; not read anywhere in this body
 * @return the union of the left and right input DataSets, cast to DataSet[Any]
 */
override def translateToPlan(
config: TableConfig,
expectedType: Option[TypeInformation[Any]]): DataSet[Any] = {
// NOTE(review): this span is a rendered diff hunk without +/- markers; the `???`
// below looks like the placeholder that this commit removes -- confirm against
// the repository before treating it as live code.
???

// Both inputs are cast to DataSetRel so that their own translateToPlan can be called.
val leftDataSet = left.asInstanceOf[DataSetRel].translateToPlan(config)
val rightDataSet = right.asInstanceOf[DataSetRel].translateToPlan(config)
// The combined DataSet is cast to match the declared DataSet[Any] return type.
leftDataSet.union(rightDataSet).asInstanceOf[DataSet[Any]]
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ object FlinkRuleSets {
// merge and push unions rules
// TODO: Add a rule to enforce binary unions
UnionEliminatorRule.INSTANCE,
JoinUnionTransposeRule.LEFT_UNION,
JoinUnionTransposeRule.RIGHT_UNION,
FlinkJoinUnionTransposeRule.LEFT_UNION,
FlinkJoinUnionTransposeRule.RIGHT_UNION,
// non-all Union to all-union + distinct
UnionToDistinctRule.INSTANCE,

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.table.plan.rules.logical

import org.apache.calcite.plan.RelOptRule.{any, operand, convert => convertTrait}
import org.apache.calcite.plan.RelOptRule
import org.apache.calcite.plan.RelOptRuleOperand
import org.apache.calcite.plan.RelOptRuleCall
import org.apache.calcite.rel.RelNode
import java.util.ArrayList
import scala.collection.JavaConversions._
import org.apache.calcite.rel.logical.LogicalJoin
import org.apache.calcite.rel.logical.LogicalUnion
import org.apache.calcite.rel.core.Join
import org.apache.calcite.rel.core.Union

/**
 * Planner rule that pulls a UNION ALL above a join: a join that has a union as one
 * of its inputs is rewritten into a union of joins, one join per union input.
 *
 * This is a copy of Calcite's JoinUnionTransposeRule. According to the original
 * authors, Calcite's implementation checks whether an operand is a LogicalUnion,
 * which fails when it is matched against a Flink union; this copy checks against
 * the more general Union type instead. The remaining logic is unchanged.
 *
 * @param operand     root operand describing the sub-tree the rule fires on
 * @param description name of the rule instance, handed on to RelOptRule
 */
class FlinkJoinUnionTransposeRule(
    operand: RelOptRuleOperand,
    description: String) extends RelOptRule(operand, description) {

  /**
   * Rewrites the matched join-over-union into a union-over-joins.
   *
   * Nothing is emitted when the union is not a UNION ALL, when the join reports
   * stopped variables, or when the union sits on the null-generating side of an
   * outer join.
   *
   * @param call the rule call; rel(0) is the join, rel(1) and rel(2) are its inputs
   */
  override def onMatch(call: RelOptRuleCall): Unit = {
    val join = call.rel(0).asInstanceOf[Join]

    // Find out which join input is the union. If the first input is a union it is
    // always the one chosen, even when the second input is a union as well.
    val firstInput = call.rel(1).asInstanceOf[RelNode]
    val (unionRel, otherInput, unionOnLeft) = firstInput match {
      case leftUnion: Union => (leftUnion, call.rel(2).asInstanceOf[RelNode], true)
      case nonUnion => (call.rel(2).asInstanceOf[Union], nonUnion, false)
    }

    // The UNION ALL cannot be on the null generating side
    // of an outer join (otherwise we might generate incorrect
    // rows for the other side for join keys which lack a match
    // in one or both branches of the union)
    def unionOnNullGeneratingSide: Boolean =
      if (unionOnLeft) join.getJoinType.generatesNullsOnLeft
      else join.getJoinType.generatesNullsOnRight

    // Checks are evaluated left to right and short-circuit, mirroring the order of
    // the individual bail-out conditions.
    val canTranspose =
      unionRel.all && join.getVariablesStopped.isEmpty && !unionOnNullGeneratingSide

    if (canTranspose) {
      val newUnionInputs = new ArrayList[RelNode]
      unionRel.getInputs.foreach { input =>
        // Keep the union branch on the same side of the join that the union was on.
        val joinLeft = if (unionOnLeft) input else otherInput
        val joinRight = if (unionOnLeft) otherInput else input

        newUnionInputs.add(
          join.copy(
            join.getTraitSet,
            join.getCondition,
            joinLeft,
            joinRight,
            join.getJoinType,
            join.isSemiJoinDone))
      }
      call.transformTo(unionRel.copy(unionRel.getTraitSet, newUnionInputs, true))
    }
  }
}

/**
 * Ready-made instances of [[FlinkJoinUnionTransposeRule]], one for each side of the
 * join on which the union may appear.
 */
object FlinkJoinUnionTransposeRule {

  // Each call builds a fresh operand, so the two rule instances below do not share
  // any operand objects (same as constructing the operands inline).
  private def unionInput: RelOptRuleOperand = operand(classOf[LogicalUnion], any)

  private def arbitraryInput: RelOptRuleOperand = operand(classOf[RelNode], any)

  /** Fires on a LogicalJoin whose left input is a LogicalUnion. */
  val LEFT_UNION = new FlinkJoinUnionTransposeRule(
    operand(classOf[LogicalJoin], unionInput, arbitraryInput),
    "JoinUnionTransposeRule(Union-Other)")

  /** Fires on a LogicalJoin whose right input is a LogicalUnion. */
  val RIGHT_UNION = new FlinkJoinUnionTransposeRule(
    operand(classOf[LogicalJoin], arbitraryInput, unionInput),
    "JoinUnionTransposeRule(Other-Union)")
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import scala.NotImplementedError;

import java.util.List;

Expand All @@ -41,7 +40,7 @@ public UnionITCase(TestExecutionMode mode) {
super(mode);
}

@Test(expected = NotImplementedError.class)
@Test
public void testUnion() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();
Expand All @@ -60,7 +59,7 @@ public void testUnion() throws Exception {
compareResultAsText(results, expected);
}

@Test(expected = NotImplementedError.class)
@Test
public void testUnionWithFilter() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();
Expand Down Expand Up @@ -117,7 +116,7 @@ public void testUnionFieldsNameNotOverlap2() throws Exception {
compareResultAsText(results, expected);
}

@Test(expected = NotImplementedError.class)
@Test
public void testUnionWithAggregation() throws Exception {
ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
TableEnvironment tableEnv = new TableEnvironment();
Expand All @@ -136,4 +135,27 @@ public void testUnionWithAggregation() throws Exception {
compareResultAsText(results, expected);
}

}
@Test
public void testUnionWithJoin() throws Exception {
	ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
	TableEnvironment tableEnv = new TableEnvironment();

	// Source data sets: one 3-tuple set and two 5-tuple sets.
	DataSet<Tuple3<Integer, Long, String>> small3 = CollectionDataSets.getSmall3TupleDataSet(env);
	DataSet<Tuple5<Integer, Long, Integer, String, Long>> full5 = CollectionDataSets.get5TupleDataSet(env);
	DataSet<Tuple5<Integer, Long, Integer, String, Long>> small5 = CollectionDataSets.getSmall5TupleDataSet(env);

	// The 5-tuple tables are projected to three fields so they line up with the 3-tuple table.
	Table unionLeft = tableEnv.fromDataSet(small3, "a, b, c");
	Table unionRight = tableEnv.fromDataSet(full5, "a, b, d, c, e").select("a, b, c");
	Table joinSide = tableEnv.fromDataSet(small5, "a2, b2, d2, c2, e2").select("a2, b2, c2");

	// UNION ALL first, then join the combined rows on a === a2 and keep both string fields.
	Table unioned = unionLeft.unionAll(unionRight);
	Table joined = unioned.join(joinSide).where("a === a2");
	Table projected = joined.select("c, c2");

	List<Row> results = tableEnv.toDataSet(projected, Row.class).collect();

	String expected =
		"Hi,Hallo\n" +
		"Hallo,Hallo\n" +
		"Hello,Hallo Welt\n" +
		"Hello,Hallo Welt wie\n" +
		"Hallo Welt,Hallo Welt\n" +
		"Hallo Welt wie,Hallo Welt\n" +
		"Hallo Welt,Hallo Welt wie\n" +
		"Hallo Welt wie,Hallo Welt wie\n";
	compareResultAsText(results, expected);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,21 @@ import org.apache.flink.api.scala.table._
import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.api.table.{ExpressionException, Row}
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
import org.apache.flink.test.util.{MultipleProgramsTestBase, TestBaseUtils}
import org.apache.flink.test.util.TestBaseUtils
import org.junit._
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

import scala.collection.JavaConverters._
import org.apache.flink.api.table.test.utils.TableProgramsTestBase
import TableProgramsTestBase.TableConfigMode

@RunWith(classOf[Parameterized])
class UnionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) {
class UnionITCase(
mode: TestExecutionMode,
configMode: TableConfigMode)
extends TableProgramsTestBase(mode, configMode) {

@Test(expected = classOf[NotImplementedError])
@Test
def testUnion(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
Expand All @@ -46,7 +50,7 @@ class UnionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test(expected = classOf[NotImplementedError])
@Test
def testUnionWithFilter(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
Expand Down Expand Up @@ -85,7 +89,7 @@ class UnionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test(expected = classOf[NotImplementedError])
@Test
def testUnionWithAggregation(): Unit = {
val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment
val ds1 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
Expand All @@ -97,4 +101,23 @@ class UnionITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode
val expected = "18"
TestBaseUtils.compareResultAsText(results.asJava, expected)
}

@Test
def testUnionWithJoin(): Unit = {
  val env: ExecutionEnvironment = ExecutionEnvironment.getExecutionEnvironment

  // Source tables: one 3-tuple set and two 5-tuple sets.
  val small3 = CollectionDataSets.getSmall3TupleDataSet(env).as('a, 'b, 'c)
  val full5 = CollectionDataSets.get5TupleDataSet(env).as('a, 'b, 'd, 'c, 'e)
  val small5 = CollectionDataSets.getSmall5TupleDataSet(env).as('a2, 'b2, 'd2, 'c2, 'e2)

  // UNION ALL first, then join the combined rows on a === a2 and keep both string fields.
  val unioned = small3.unionAll(full5.select('a, 'b, 'c))
  val joined = unioned.join(small5.select('a2, 'b2, 'c2)).where('a === 'a2)
  val joinDs = joined.select('c, 'c2)

  val results = joinDs.toDataSet[Row].collect()
  val expected =
    "Hi,Hallo\n" +
    "Hallo,Hallo\n" +
    "Hello,Hallo Welt\n" +
    "Hello,Hallo Welt wie\n" +
    "Hallo Welt,Hallo Welt\n" +
    "Hallo Welt wie,Hallo Welt\n" +
    "Hallo Welt,Hallo Welt wie\n" +
    "Hallo Welt wie,Hallo Welt wie\n"
  TestBaseUtils.compareResultAsText(results.asJava, expected)
}
}

0 comments on commit 4a63e29

Please sign in to comment.