Skip to content

Commit

Permalink
[FLINK-2662] [optimizer] Fix computation of global properties of union operator.
Browse files Browse the repository at this point in the history

- Fixes invalid shipping strategy between consecutive unions.

This closes apache#2848.
  • Loading branch information
fhueske authored and alpinegizmo committed Nov 28, 2016
1 parent e0e33f1 commit d8f2f75
Show file tree
Hide file tree
Showing 2 changed files with 258 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,35 @@ public GlobalProperties computeGlobalProperties(GlobalProperties in1, GlobalProp

if (in1.getPartitioning() == PartitioningProperty.HASH_PARTITIONED &&
in2.getPartitioning() == PartitioningProperty.HASH_PARTITIONED &&
in1.getPartitioningFields().equals(in2.getPartitioningFields()))
{
in1.getPartitioningFields().equals(in2.getPartitioningFields())) {
newProps.setHashPartitioned(in1.getPartitioningFields());
}

else if (in1.getPartitioning() == PartitioningProperty.RANGE_PARTITIONED &&
in2.getPartitioning() == PartitioningProperty.RANGE_PARTITIONED &&
in1.getPartitioningOrdering().equals(in2.getPartitioningOrdering()) &&
(
in1.getDataDistribution() == null && in2.getDataDistribution() == null ||
in1.getDataDistribution() != null && in1.getDataDistribution().equals(in2.getDataDistribution())
)
) {
if (in1.getDataDistribution() == null) {
newProps.setRangePartitioned(in1.getPartitioningOrdering());
}
else {
newProps.setRangePartitioned(in1.getPartitioningOrdering(), in1.getDataDistribution());
}
}
else if (in1.getPartitioning() == PartitioningProperty.CUSTOM_PARTITIONING &&
in2.getPartitioning() == PartitioningProperty.CUSTOM_PARTITIONING &&
in1.getPartitioningFields().equals(in2.getPartitioningFields()) &&
in1.getCustomPartitioner().equals(in2.getCustomPartitioner())) {
newProps.setCustomPartitioned(in1.getPartitioningFields(), in1.getCustomPartitioner());
}
else if (in1.getPartitioning() == PartitioningProperty.FORCED_REBALANCED &&
in2.getPartitioning() == PartitioningProperty.FORCED_REBALANCED) {
newProps.setForcedRebalanced();
}

return newProps;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,25 @@
package org.apache.flink.optimizer;

import junit.framework.Assert;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.util.FieldList;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.common.Plan;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.optimizer.dataproperties.PartitioningProperty;
import org.apache.flink.optimizer.plan.Channel;
import org.apache.flink.optimizer.plan.DualInputPlanNode;
import org.apache.flink.optimizer.plan.NAryUnionPlanNode;
import org.apache.flink.optimizer.plan.OptimizedPlan;
import org.apache.flink.optimizer.plan.SingleInputPlanNode;
import org.apache.flink.optimizer.plan.SourcePlanNode;
import org.apache.flink.optimizer.plantranslate.JobGraphGenerator;
import org.apache.flink.optimizer.util.CompilerTestBase;
import org.apache.flink.runtime.operators.Driver;
import org.apache.flink.runtime.operators.DriverStrategy;
import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
import org.junit.Test;

Expand Down Expand Up @@ -87,7 +94,7 @@ public void testUnionReplacement() {
*
*/
@Test
public void testUnionWithTwoOutputsTest() throws Exception {
public void testUnionWithTwoOutputs() throws Exception {

// -----------------------------------------------------------------------------------------
// Build test program
Expand Down Expand Up @@ -120,38 +127,253 @@ public void testUnionWithTwoOutputsTest() throws Exception {
SingleInputPlanNode groupRed2 = resolver.getNode("2");

// check partitioning is correct
Assert.assertTrue("Reduce input should be partitioned on 0.",
assertTrue("Reduce input should be partitioned on 0.",
groupRed1.getInput().getGlobalProperties().getPartitioningFields().isExactMatch(new FieldList(0)));
Assert.assertTrue("Reduce input should be partitioned on 1.",
assertTrue("Reduce input should be partitioned on 1.",
groupRed2.getInput().getGlobalProperties().getPartitioningFields().isExactMatch(new FieldList(1)));

// check group reduce inputs are n-ary unions with three inputs
Assert.assertTrue("Reduce input should be n-ary union with three inputs.",
assertTrue("Reduce input should be n-ary union with three inputs.",
groupRed1.getInput().getSource() instanceof NAryUnionPlanNode &&
((NAryUnionPlanNode) groupRed1.getInput().getSource()).getListOfInputs().size() == 3);
Assert.assertTrue("Reduce input should be n-ary union with three inputs.",
assertTrue("Reduce input should be n-ary union with three inputs.",
groupRed2.getInput().getSource() instanceof NAryUnionPlanNode &&
((NAryUnionPlanNode) groupRed2.getInput().getSource()).getListOfInputs().size() == 3);

// check channel from union to group reduce is forwarding
Assert.assertTrue("Channel between union and group reduce should be forwarding",
assertTrue("Channel between union and group reduce should be forwarding",
groupRed1.getInput().getShipStrategy().equals(ShipStrategyType.FORWARD));
Assert.assertTrue("Channel between union and group reduce should be forwarding",
assertTrue("Channel between union and group reduce should be forwarding",
groupRed2.getInput().getShipStrategy().equals(ShipStrategyType.FORWARD));

// check that all inputs of unions are hash partitioned
List<Channel> union123In = ((NAryUnionPlanNode) groupRed1.getInput().getSource()).getListOfInputs();
for(Channel i : union123In) {
Assert.assertTrue("Union input channel should hash partition on 0",
assertTrue("Union input channel should hash partition on 0",
i.getShipStrategy().equals(ShipStrategyType.PARTITION_HASH) &&
i.getShipStrategyKeys().isExactMatch(new FieldList(0)));
}
List<Channel> union234In = ((NAryUnionPlanNode) groupRed2.getInput().getSource()).getListOfInputs();
for(Channel i : union234In) {
Assert.assertTrue("Union input channel should hash partition on 0",
assertTrue("Union input channel should hash partition on 0",
i.getShipStrategy().equals(ShipStrategyType.PARTITION_HASH) &&
i.getShipStrategyKeys().isExactMatch(new FieldList(1)));
}

}

/**
 * Checks that a plan with consecutive UNIONs followed by PartitionByHash is correctly translated.
 *
 * The program can be illustrated as follows:
 *
 * Src1 -\
 *        >-> Union12--<
 * Src2 -/              \
 *                       >-> Union123 -> PartitionByHash -> Output
 * Src3 ----------------/
 *
 * In the resulting plan, the hash partitioning (ShippingStrategy.PARTITION_HASH) must be
 * pushed to the inputs of the unions (Src1, Src2, Src3).
 */
@Test
public void testConsecutiveUnionsWithHashPartitioning() throws Exception {

	// -----------------------------------------------------------------------------------------
	// Build test program
	// -----------------------------------------------------------------------------------------

	ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
	env.setParallelism(DEFAULT_PARALLELISM);

	DataSet<Tuple2<Long, Long>> src1 = env.fromElements(new Tuple2<>(0L, 0L));
	DataSet<Tuple2<Long, Long>> src2 = env.fromElements(new Tuple2<>(0L, 0L));
	DataSet<Tuple2<Long, Long>> src3 = env.fromElements(new Tuple2<>(0L, 0L));

	DataSet<Tuple2<Long, Long>> union12 = src1.union(src2);
	DataSet<Tuple2<Long, Long>> union123 = union12.union(src3);

	union123.partitionByHash(1).output(new DiscardingOutputFormat<Tuple2<Long, Long>>()).name("out");

	// -----------------------------------------------------------------------------------------
	// Verify optimized plan
	// -----------------------------------------------------------------------------------------

	OptimizedPlan optimizedPlan = compileNoStats(env.createProgramPlan());

	OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(optimizedPlan);

	SingleInputPlanNode sink = resolver.getNode("out");

	// the sink must see the hash partitioning on field 1 established below it
	assertEquals("Sink input should be hash partitioned.",
		PartitioningProperty.HASH_PARTITIONED, sink.getInput().getGlobalProperties().getPartitioning());
	assertEquals("Sink input should be hash partitioned on 1.",
		new FieldList(1), sink.getInput().getGlobalProperties().getPartitioningFields());

	SingleInputPlanNode partitioner = (SingleInputPlanNode) sink.getInput().getSource();
	// the partition operator itself does no work; partitioning must already have happened on its inputs
	assertEquals("Partitioner should be a no-op driver.",
		DriverStrategy.UNARY_NO_OP, partitioner.getDriverStrategy());
	assertEquals("Partitioner input should be hash partitioned.",
		PartitioningProperty.HASH_PARTITIONED, partitioner.getInput().getGlobalProperties().getPartitioning());
	assertEquals("Partitioner input should be hash partitioned on 1.",
		new FieldList(1), partitioner.getInput().getGlobalProperties().getPartitioningFields());
	assertEquals("Partitioner input channel should be forwarding",
		ShipStrategyType.FORWARD, partitioner.getInput().getShipStrategy());

	NAryUnionPlanNode union = (NAryUnionPlanNode) partitioner.getInput().getSource();
	// the hash partitioning must be pushed down to every union input, i.e., directly onto the sources
	for (Channel c : union.getInputs()) {
		assertEquals("Union input should be hash partitioned",
			PartitioningProperty.HASH_PARTITIONED, c.getGlobalProperties().getPartitioning());
		assertEquals("Union input channel should be hash partitioning",
			ShipStrategyType.PARTITION_HASH, c.getShipStrategy());
		assertTrue("Union input should be data source",
			c.getSource() instanceof SourcePlanNode);
	}
}

/**
 * Checks that a plan with consecutive UNIONs followed by REBALANCE is correctly translated.
 *
 * The program can be illustrated as follows:
 *
 * Src1 -\
 *        >-> Union12--<
 * Src2 -/              \
 *                       >-> Union123 -> Rebalance -> Output
 * Src3 ----------------/
 *
 * In the resulting plan, the Rebalance (ShippingStrategy.PARTITION_FORCED_REBALANCE) must be
 * pushed to the inputs of the unions (Src1, Src2, Src3).
 */
@Test
public void testConsecutiveUnionsWithRebalance() throws Exception {

	// -----------------------------------------------------------------------------------------
	// Build test program
	// -----------------------------------------------------------------------------------------

	ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
	env.setParallelism(DEFAULT_PARALLELISM);

	DataSet<Tuple2<Long, Long>> src1 = env.fromElements(new Tuple2<>(0L, 0L));
	DataSet<Tuple2<Long, Long>> src2 = env.fromElements(new Tuple2<>(0L, 0L));
	DataSet<Tuple2<Long, Long>> src3 = env.fromElements(new Tuple2<>(0L, 0L));

	DataSet<Tuple2<Long, Long>> union12 = src1.union(src2);
	DataSet<Tuple2<Long, Long>> union123 = union12.union(src3);

	union123.rebalance().output(new DiscardingOutputFormat<Tuple2<Long, Long>>()).name("out");

	// -----------------------------------------------------------------------------------------
	// Verify optimized plan
	// -----------------------------------------------------------------------------------------

	OptimizedPlan optimizedPlan = compileNoStats(env.createProgramPlan());

	OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(optimizedPlan);

	SingleInputPlanNode sink = resolver.getNode("out");

	// the sink must see the forced rebalancing established below it
	assertEquals("Sink input should be force rebalanced.",
		PartitioningProperty.FORCED_REBALANCED, sink.getInput().getGlobalProperties().getPartitioning());

	SingleInputPlanNode partitioner = (SingleInputPlanNode) sink.getInput().getSource();
	// the rebalance operator itself does no work; rebalancing must already have happened on its inputs
	assertEquals("Partitioner should be a no-op driver.",
		DriverStrategy.UNARY_NO_OP, partitioner.getDriverStrategy());
	assertEquals("Partitioner input should be force rebalanced.",
		PartitioningProperty.FORCED_REBALANCED, partitioner.getInput().getGlobalProperties().getPartitioning());
	assertEquals("Partitioner input channel should be forwarding",
		ShipStrategyType.FORWARD, partitioner.getInput().getShipStrategy());

	NAryUnionPlanNode union = (NAryUnionPlanNode) partitioner.getInput().getSource();
	// the forced rebalancing must be pushed down to every union input, i.e., directly onto the sources
	for (Channel c : union.getInputs()) {
		assertEquals("Union input should be force rebalanced",
			PartitioningProperty.FORCED_REBALANCED, c.getGlobalProperties().getPartitioning());
		assertEquals("Union input channel should be rebalancing",
			ShipStrategyType.PARTITION_FORCED_REBALANCE, c.getShipStrategy());
		assertTrue("Union input should be data source",
			c.getSource() instanceof SourcePlanNode);
	}
}

/**
 * Checks that a plan with consecutive UNIONs followed by PARTITION_RANGE is correctly translated.
 *
 * The program can be illustrated as follows:
 *
 * Src1 -\
 *        >-> Union12--<
 * Src2 -/              \
 *                       >-> Union123 -> PartitionByRange -> Output
 * Src3 ----------------/
 *
 * In the resulting plan, the range partitioning must be
 * pushed to the inputs of the unions (Src1, Src2, Src3).
 */
@Test
public void testConsecutiveUnionsWithRangePartitioning() throws Exception {

	// -----------------------------------------------------------------------------------------
	// Build test program
	// -----------------------------------------------------------------------------------------

	ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
	env.setParallelism(DEFAULT_PARALLELISM);

	DataSet<Tuple2<Long, Long>> src1 = env.fromElements(new Tuple2<>(0L, 0L));
	DataSet<Tuple2<Long, Long>> src2 = env.fromElements(new Tuple2<>(0L, 0L));
	DataSet<Tuple2<Long, Long>> src3 = env.fromElements(new Tuple2<>(0L, 0L));

	DataSet<Tuple2<Long, Long>> union12 = src1.union(src2);
	DataSet<Tuple2<Long, Long>> union123 = union12.union(src3);

	union123.partitionByRange(1).output(new DiscardingOutputFormat<Tuple2<Long, Long>>()).name("out");

	// -----------------------------------------------------------------------------------------
	// Verify optimized plan
	// -----------------------------------------------------------------------------------------

	OptimizedPlan optimizedPlan = compileNoStats(env.createProgramPlan());

	OptimizerPlanNodeResolver resolver = getOptimizerPlanNodeResolver(optimizedPlan);

	SingleInputPlanNode sink = resolver.getNode("out");

	// the sink must see the range partitioning on field 1 established below it
	assertEquals("Sink input should be range partitioned.",
		PartitioningProperty.RANGE_PARTITIONED, sink.getInput().getGlobalProperties().getPartitioning());
	assertEquals("Sink input should be range partitioned on 1",
		new Ordering(1, null, Order.ASCENDING), sink.getInput().getGlobalProperties().getPartitioningOrdering());

	SingleInputPlanNode partitioner = (SingleInputPlanNode) sink.getInput().getSource();
	// the partition operator itself does no work; partitioning must already have happened on its inputs
	assertEquals("Partitioner should be a no-op driver.",
		DriverStrategy.UNARY_NO_OP, partitioner.getDriverStrategy());
	assertEquals("Partitioner input should be range partitioned.",
		PartitioningProperty.RANGE_PARTITIONED, partitioner.getInput().getGlobalProperties().getPartitioning());
	assertEquals("Partitioner input should be range partitioned on 1",
		new Ordering(1, null, Order.ASCENDING), partitioner.getInput().getGlobalProperties().getPartitioningOrdering());
	assertEquals("Partitioner input channel should be forwarding",
		ShipStrategyType.FORWARD, partitioner.getInput().getShipStrategy());

	NAryUnionPlanNode union = (NAryUnionPlanNode) partitioner.getInput().getSource();
	// all union inputs should be range partitioned
	for (Channel c : union.getInputs()) {
		assertEquals("Union input should be range partitioned",
			PartitioningProperty.RANGE_PARTITIONED, c.getGlobalProperties().getPartitioning());
		assertEquals("Union input channel should be forwarding",
			ShipStrategyType.FORWARD, c.getShipStrategy());
		// range partitioning is executed as custom partitioning with prior sampling,
		// so each union input is a sampling map whose own input ships PARTITION_CUSTOM
		SingleInputPlanNode partitionMap = (SingleInputPlanNode) c.getSource();
		assertEquals(DriverStrategy.MAP, partitionMap.getDriverStrategy());
		assertEquals(ShipStrategyType.PARTITION_CUSTOM, partitionMap.getInput().getShipStrategy());
	}
}

}

0 comments on commit d8f2f75

Please sign in to comment.