Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@

package org.apache.flink.api.java.operators;

import java.util.ArrayList;
import java.util.List;

import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.Public;
import org.apache.flink.api.common.InvalidProgramException;
Expand All @@ -42,106 +39,104 @@
import org.apache.flink.util.Collector;
import org.apache.flink.util.Preconditions;

import java.util.ArrayList;
import java.util.List;

/**
 * This operator represents the application of an "aggregate" operation on a data set, and the
 * result data set produced by the function.
 *
 * @param <IN> The type of the data set aggregated by the operator.
 */
@Public
public class AggregateOperator<IN> extends SingleInputOperator<IN, IN, AggregateOperator<IN>> {

private final List<AggregationFunction<?>> aggregationFunctions = new ArrayList<>(4);

private final List<Integer> fields = new ArrayList<>(4);

private final Grouping<IN> grouping;

private final String aggregateLocationName;

/**
 * Non grouped aggregation: aggregates the given tuple field across the entire input data set.
 *
 * @param input the data set to aggregate; must be of a tuple type
 * @param function the aggregation to apply to the field
 * @param field the position of the tuple field to aggregate
 * @param aggregateLocationName the call-site name used when generating the operator name
 * @throws InvalidProgramException if the input is not of a tuple type
 * @throws IllegalArgumentException if the field position is outside the tuple's arity
 */
public AggregateOperator(DataSet<IN> input, Aggregations function, int field, String aggregateLocationName) {
super(Preconditions.checkNotNull(input), input.getType());
Preconditions.checkNotNull(function);

this.aggregateLocationName = aggregateLocationName;

// field-position aggregations are only defined for tuple types
if (!input.getType().isTupleType()) {
throw new InvalidProgramException("Aggregating on field positions is only possible on tuple data types.");
}

TupleTypeInfoBase<?> inType = (TupleTypeInfoBase<?>) input.getType();

if (field < 0 || field >= inType.getArity()) {
throw new IllegalArgumentException("Aggregation field position is out of range.");
}

// pick the concrete aggregation implementation matching the field's declared type
AggregationFunctionFactory factory = function.getFactory();
AggregationFunction<?> aggFunct = factory.createAggregationFunction(inType.getTypeAt(field).getTypeClass());

// this is the first aggregation operator after a regular data set (non grouped aggregation)
this.aggregationFunctions.add(aggFunct);
this.fields.add(field);
this.grouping = null;
}

/**
 * Grouped aggregation: aggregates the given tuple field within each group of the grouping.
 *
 * @param input the grouping whose underlying data set is aggregated; must be of a tuple type
 * @param function the aggregation to apply to the field
 * @param field the position of the tuple field to aggregate
 * @param aggregateLocationName the call-site name used when generating the operator name
 * @throws InvalidProgramException if the grouped data set is not of a tuple type
 * @throws IllegalArgumentException if the field position is outside the tuple's arity
 */
public AggregateOperator(Grouping<IN> input, Aggregations function, int field, String aggregateLocationName) {
super(Preconditions.checkNotNull(input).getInputDataSet(), input.getInputDataSet().getType());
Preconditions.checkNotNull(function);

this.aggregateLocationName = aggregateLocationName;

// field-position aggregations are only defined for tuple types
if (!input.getInputDataSet().getType().isTupleType()) {
throw new InvalidProgramException("Aggregating on field positions is only possible on tuple data types.");
}

TupleTypeInfoBase<?> inType = (TupleTypeInfoBase<?>) input.getInputDataSet().getType();

if (field < 0 || field >= inType.getArity()) {
throw new IllegalArgumentException("Aggregation field position is out of range.");
}

// pick the concrete aggregation implementation matching the field's declared type
AggregationFunctionFactory factory = function.getFactory();
AggregationFunction<?> aggFunct = factory.createAggregationFunction(inType.getTypeAt(field).getTypeClass());

// set the aggregation fields
this.aggregationFunctions.add(aggFunct);
this.fields.add(field);
this.grouping = input;
}



/**
 * Chains an additional aggregation onto this operator, so several fields can be
 * aggregated in one pass.
 *
 * @param function the aggregation to apply to the additional field
 * @param field the position of the tuple field to aggregate
 * @return this operator, for fluent chaining
 * @throws IllegalArgumentException if the field position is outside the tuple's arity
 */
public AggregateOperator<IN> and(Aggregations function, int field) {
    Preconditions.checkNotNull(function);

    final TupleTypeInfoBase<?> tupleType = (TupleTypeInfoBase<?>) getType();
    if (field < 0 || field >= tupleType.getArity()) {
        throw new IllegalArgumentException("Aggregation field position is out of range.");
    }

    // resolve the aggregation implementation for the field's declared type
    final AggregationFunction<?> aggregation =
            function.getFactory().createAggregationFunction(tupleType.getTypeAt(field).getTypeClass());

    this.aggregationFunctions.add(aggregation);
    this.fields.add(field);

    return this;
}


/**
 * Convenience shortcut for chaining a SUM aggregation on the given tuple field.
 *
 * @param field the position of the tuple field to sum
 * @return this operator, for fluent chaining
 */
public AggregateOperator<IN> andSum (int field) {
    return and(Aggregations.SUM, field);
}
Expand All @@ -154,68 +149,65 @@ public AggregateOperator<IN> andMax (int field) {
return this.and(Aggregations.MAX, field);
}


@SuppressWarnings("unchecked")
@Override
@Internal
protected org.apache.flink.api.common.operators.base.GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>> translateToDataFlow(Operator<IN> input) {

// sanity check
if (this.aggregationFunctions.isEmpty() || this.aggregationFunctions.size() != this.fields.size()) {
throw new IllegalStateException();
}

// construct the aggregation function
AggregationFunction<Object>[] aggFunctions = new AggregationFunction[this.aggregationFunctions.size()];
int[] fields = new int[this.fields.size()];
StringBuilder genName = new StringBuilder();

for (int i = 0; i < fields.length; i++) {
aggFunctions[i] = (AggregationFunction<Object>) this.aggregationFunctions.get(i);
fields[i] = this.fields.get(i);

genName.append(aggFunctions[i].toString()).append('(').append(fields[i]).append(')').append(',');
}
genName.append(" at ").append(aggregateLocationName);
genName.setLength(genName.length()-1);


genName.setLength(genName.length() - 1);

@SuppressWarnings("rawtypes")
RichGroupReduceFunction<IN, IN> function = new AggregatingUdf(aggFunctions, fields);



String name = getName() != null ? getName() : genName.toString();

// distinguish between grouped reduce and non-grouped reduce
if (this.grouping == null) {
// non grouped aggregation
UnaryOperatorInformation<IN, IN> operatorInfo = new UnaryOperatorInformation<>(getInputType(), getResultType());
GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>> po =
new GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>>(function, operatorInfo, new int[0], name);

po.setCombinable(true);

// set input
po.setInput(input);
// set parallelism
po.setParallelism(this.getParallelism());

return po;
}

if (this.grouping.getKeys() instanceof Keys.ExpressionKeys) {
// grouped aggregation
int[] logicalKeyPositions = this.grouping.getKeys().computeLogicalKeyPositions();
UnaryOperatorInformation<IN, IN> operatorInfo = new UnaryOperatorInformation<>(getInputType(), getResultType());
GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>> po =
new GroupReduceOperatorBase<IN, IN, GroupReduceFunction<IN, IN>>(function, operatorInfo, logicalKeyPositions, name);

po.setCombinable(true);

po.setInput(input);
po.setParallelism(this.getParallelism());
po.setCustomPartitioner(grouping.getCustomPartitioner());

SingleInputSemanticProperties props = new SingleInputSemanticProperties();

for (int keyField : logicalKeyPositions) {
Expand All @@ -232,9 +224,9 @@ protected org.apache.flink.api.common.operators.base.GroupReduceOperatorBase<IN,
props.addForwardedField(keyField, keyField);
}
}

po.setSemanticProperties(props);

return po;
}
else if (this.grouping.getKeys() instanceof Keys.SelectorFunctionKeys) {
Expand All @@ -243,47 +235,45 @@ else if (this.grouping.getKeys() instanceof Keys.SelectorFunctionKeys) {
else {
throw new UnsupportedOperationException("Unrecognized key type.");
}

}

// --------------------------------------------------------------------------------------------

@Internal
public static final class AggregatingUdf<T extends Tuple>
private static final class AggregatingUdf<T extends Tuple>
extends RichGroupReduceFunction<T, T>
implements GroupCombineFunction<T, T> {

private static final long serialVersionUID = 1L;

private final int[] fieldPositions;

private final AggregationFunction<Object>[] aggFunctions;



/**
 * Creates the aggregating UDF.
 *
 * @param aggFunctions the aggregation functions, one per aggregated field
 * @param fieldPositions the tuple positions the aggregations are applied to;
 *     must have the same length as {@code aggFunctions}
 */
public AggregatingUdf(AggregationFunction<Object>[] aggFunctions, int[] fieldPositions) {
    Preconditions.checkNotNull(aggFunctions);
    // fix: the second null check previously re-validated aggFunctions (copy-paste bug)
    // and left fieldPositions unchecked
    Preconditions.checkNotNull(fieldPositions);
    Preconditions.checkArgument(aggFunctions.length == fieldPositions.length);

    this.aggFunctions = aggFunctions;
    this.fieldPositions = fieldPositions;
}


/**
 * Resets every aggregation function to its initial (empty) aggregate before any
 * records are processed.
 */
@Override
public void open(Configuration parameters) throws Exception {
    for (int i = 0; i < aggFunctions.length; i++) {
        aggFunctions[i].initializeAggregate();
    }
}

@Override
public void reduce(Iterable<T> records, Collector<T> out) {
final AggregationFunction<Object>[] aggFunctions = this.aggFunctions;
final int[] fieldPositions = this.fieldPositions;

// aggregators are initialized from before

T outT = null;
for (T record : records) {
outT = record;
Expand All @@ -293,20 +283,20 @@ public void reduce(Iterable<T> records, Collector<T> out) {
aggFunctions[i].aggregate(val);
}
}

for (int i = 0; i < fieldPositions.length; i++) {
Object aggVal = aggFunctions[i].getAggregate();
outT.setField(aggVal, fieldPositions[i]);
aggFunctions[i].initializeAggregate();
}

out.collect(outT);
}

/**
 * Pre-combines partial groups using the same logic as {@link #reduce}; this is safe
 * because reduce re-initializes the aggregation functions after emitting each result.
 */
@Override
public void combine(Iterable<T> records, Collector<T> out) {
reduce(records, out);
}

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;

/**
 * Resulting {@link DataSet} of bulk iterations.
 *
 * @param <T> The type of the elements in the resulting data set.
 */
@Internal
public class BulkIterationResultSet<T> extends DataSet<T> {

Expand All @@ -41,8 +45,7 @@ public class BulkIterationResultSet<T> extends DataSet<T> {

BulkIterationResultSet(ExecutionEnvironment context,
TypeInformation<T> type, IterativeDataSet<T> iterationHead,
DataSet<T> nextPartialSolution, DataSet<?> terminationCriterion)
{
DataSet<T> nextPartialSolution, DataSet<?> terminationCriterion) {
super(context, type);
this.iterationHead = iterationHead;
this.nextPartialSolution = nextPartialSolution;
Expand Down
Loading