Skip to content

Commit

Permalink
[FLINK-7] Prevent range partitioning inside iterations.
Browse files Browse the repository at this point in the history
  • Loading branch information
fhueske committed Dec 21, 2015
1 parent f5957ce commit a6a0528
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 16 deletions.
@@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.java.functions;

import org.apache.flink.api.common.functions.Partitioner;

/**
 * A {@link Partitioner} over {@code Integer} keys that uses the key itself as the partition index.
 *
 * <p>The key is neither hashed nor checked against the number of partitions.
 */
public class IdPartitioner implements Partitioner<Integer> {

    private static final long serialVersionUID = -1206233785103357568L;

    /**
     * Returns the key unchanged as the target partition.
     *
     * @param key the key, interpreted directly as the partition index
     * @param numPartitions the number of partitions; not consulted by this implementation
     * @return the unboxed value of {@code key}
     */
    @Override
    public int partition(Integer key, int numPartitions) {
        final int targetPartition = key.intValue();
        return targetPartition;
    }

}
Expand Up @@ -17,8 +17,8 @@
*/ */
package org.apache.flink.optimizer.traversals; package org.apache.flink.optimizer.traversals;


import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.distributions.CommonRangeBoundaries; import org.apache.flink.api.common.distributions.CommonRangeBoundaries;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.operators.Order; import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.operators.Ordering; import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.UnaryOperatorInformation; import org.apache.flink.api.common.operators.UnaryOperatorInformation;
Expand All @@ -29,9 +29,11 @@
import org.apache.flink.api.common.typeinfo.BasicTypeInfo; import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation; import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeComparatorFactory; import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
import org.apache.flink.api.java.functions.IdPartitioner;
import org.apache.flink.optimizer.costs.Costs; import org.apache.flink.optimizer.costs.Costs;
import org.apache.flink.optimizer.dataproperties.GlobalProperties; import org.apache.flink.optimizer.dataproperties.GlobalProperties;
import org.apache.flink.optimizer.dataproperties.LocalProperties; import org.apache.flink.optimizer.dataproperties.LocalProperties;
import org.apache.flink.optimizer.plan.IterationPlanNode;
import org.apache.flink.runtime.io.network.DataExchangeMode; import org.apache.flink.runtime.io.network.DataExchangeMode;
import org.apache.flink.runtime.operators.udf.AssignRangeIndex; import org.apache.flink.runtime.operators.udf.AssignRangeIndex;
import org.apache.flink.runtime.operators.udf.RemoveRangeIndex; import org.apache.flink.runtime.operators.udf.RemoveRangeIndex;
Expand All @@ -57,7 +59,9 @@
import org.apache.flink.util.Visitor; import org.apache.flink.util.Visitor;


import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set;


/** /**
* *
Expand All @@ -76,9 +80,11 @@ public class RangePartitionRewriter implements Visitor<PlanNode> {
final static IdPartitioner idPartitioner = new IdPartitioner(); final static IdPartitioner idPartitioner = new IdPartitioner();


final OptimizedPlan plan; final OptimizedPlan plan;
final Set<IterationPlanNode> visitedIterationNodes;


public RangePartitionRewriter(OptimizedPlan plan) { public RangePartitionRewriter(OptimizedPlan plan) {
this.plan = plan; this.plan = plan;
this.visitedIterationNodes = new HashSet<>();
} }


@Override @Override
Expand All @@ -87,12 +93,26 @@ public boolean preVisit(PlanNode visitable) {
} }


@Override @Override
public void postVisit(PlanNode visitable) { public void postVisit(PlanNode node) {
final Iterable<Channel> inputChannels = visitable.getInputs();
if(node instanceof IterationPlanNode) {
IterationPlanNode iNode = (IterationPlanNode)node;
if(!visitedIterationNodes.contains(iNode)) {
visitedIterationNodes.add(iNode);
iNode.acceptForStepFunction(this);
}
}

final Iterable<Channel> inputChannels = node.getInputs();
for (Channel channel : inputChannels) { for (Channel channel : inputChannels) {
ShipStrategyType shipStrategy = channel.getShipStrategy(); ShipStrategyType shipStrategy = channel.getShipStrategy();
// Make sure we only optimize the DAG for range partition, and do not optimize multi times. // Make sure we only optimize the DAG for range partition, and do not optimize multi times.
if (shipStrategy == ShipStrategyType.PARTITION_RANGE && isOptimized(visitable)) { if (shipStrategy == ShipStrategyType.PARTITION_RANGE) {

if(node.isOnDynamicPath()) {
throw new InvalidProgramException("Range Partitioning not supported within iterations.");
}

PlanNode channelSource = channel.getSource(); PlanNode channelSource = channel.getSource();
List<Channel> newSourceOutputChannels = rewriteRangePartitionChannel(channel); List<Channel> newSourceOutputChannels = rewriteRangePartitionChannel(channel);
channelSource.getOutgoingChannels().remove(channel); channelSource.getOutgoingChannels().remove(channel);
Expand Down Expand Up @@ -214,16 +234,4 @@ private List<Channel> rewriteRangePartitionChannel(Channel channel) {
return sourceNewOutputChannels; return sourceNewOutputChannels;
} }



/**
 * Tells whether the name of the given node differs from {@code PR_NAME}.
 *
 * @param node the plan node whose name is inspected
 * @return {@code true} if the node's name is not equal to {@code PR_NAME},
 *         {@code false} if it is equal
 */
private boolean isOptimized(PlanNode node) {
    // Compare by content. The reference comparison (!=) used previously only gave the
    // intended answer when both names happened to be the very same String instance.
    return !PR_NAME.equals(node.getNodeName());
}

/**
 * Partitioner over {@code Integer} keys that hands back the key itself as the partition number.
 */
static class IDPartitioner implements Partitioner<Integer> {

    /**
     * Maps a key to the partition with the same number.
     *
     * @param key the key, used directly as the partition index
     * @param numPartitions the number of partitions; not consulted here
     * @return the unboxed value of {@code key}
     */
    @Override
    public int partition(Integer key, int numPartitions) {
        return key.intValue();
    }
}
} }
Expand Up @@ -23,6 +23,7 @@
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;


import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.FilterFunction; import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction; import org.apache.flink.api.common.functions.MapPartitionFunction;
Expand All @@ -33,6 +34,7 @@
import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.AggregateOperator; import org.apache.flink.api.java.operators.AggregateOperator;
import org.apache.flink.api.java.operators.DataSource; import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.tuple.Tuple2; import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3; import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.test.javaApiOperators.util.CollectionDataSets; import org.apache.flink.test.javaApiOperators.util.CollectionDataSets;
Expand Down Expand Up @@ -451,6 +453,34 @@ public void testRangePartitionerOnSequenceData() throws Exception {
} }
} }


/**
 * Builds a delta iteration whose step function range-partitions the workset and expects
 * executing it to fail with an {@link InvalidProgramException}.
 */
@Test(expected = InvalidProgramException.class)
public void testRangePartitionInIteration() throws Exception {

    // The check does not apply to collection execution. Throw the expected exception
    // right away so the test still passes in that mode.
    if (TestExecutionMode.COLLECTION == super.mode) {
        throw new InvalidProgramException("Does not apply for collection execution");
    }

    final ExecutionEnvironment environment = ExecutionEnvironment.getExecutionEnvironment();
    final DataSource<Long> sequence = environment.generateSequence(0, 10000);

    // Pair every number with its decimal string representation.
    final DataSet<Tuple2<Long, String>> pairs = sequence.map(new MapFunction<Long, Tuple2<Long, String>>() {
        @Override
        public Tuple2<Long, String> map(Long value) throws Exception {
            return new Tuple2<>(value, Long.toString(value));
        }
    });

    // The same data set serves as both initial solution set and initial workset.
    final DeltaIteration<Tuple2<Long, String>, Tuple2<Long, String>> iteration = pairs.iterateDelta(pairs, 10, 0);

    final DataSet<Tuple2<Long, String>> stepResult = iteration.getWorkset()
        .partitionByRange(1) // Verify that range partition is not allowed in iteration
        .join(iteration.getSolutionSet())
        .where(0).equalTo(0).projectFirst(0).projectSecond(1);

    // The step result is used as both the solution set delta and the next workset.
    final DataSet<Tuple2<Long, String>> iterationResult = iteration.closeWith(stepResult, stepResult);

    iterationResult.collect(); // should fail
}

private static class ObjectSelfKeySelector implements KeySelector<Long, Long> { private static class ObjectSelfKeySelector implements KeySelector<Long, Long> {
@Override @Override
public Long getKey(Long value) throws Exception { public Long getKey(Long value) throws Exception {
Expand Down

0 comments on commit a6a0528

Please sign in to comment.