Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.state.StateCheckpointer;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.util.PriorityQueue;

/**
* A RuntimeContext contains information about the context in which functions are executed. Each parallel instance
Expand Down Expand Up @@ -230,4 +232,16 @@ <S, C extends Serializable> OperatorState<S> getOperatorState(String name, S def
*/
<S extends Serializable> OperatorState<S> getOperatorState(String name, S defaultState,
boolean partitioned) throws IOException;

/**
* Returns a customized {@link PriorityQueue} for elements of the given type, ordered according to
* the given compare order.
*
* @param typeInformation Type information of the queue elements.
* @param k Expected number of elements that will be polled from the queue.
* @param order Compare order of the queue.
*              NOTE(review): the exact meaning of {@code true}/{@code false} is defined by the
*              implementing context — confirm against the implementations.
* @param <T> Element type.
* @return A Flink {@link PriorityQueue} implementation.
* @throws Exception If the priority queue cannot be created.
*/
<T> PriorityQueue<T> getPriorityQueue(TypeInformation<T> typeInformation, int k, boolean order) throws Exception;
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.OperatorState;
import org.apache.flink.api.common.state.StateCheckpointer;
import org.apache.flink.api.common.typeinfo.AtomicType;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.core.fs.Path;

/**
Expand Down Expand Up @@ -176,4 +180,26 @@ public <S extends Serializable> OperatorState<S> getOperatorState(String name, S
boolean partitioned) throws IOException{
throw new UnsupportedOperationException("Operator state is only accessible for streaming operators.");
}

/**
 * Builds a {@link TypeComparator} for the given type that applies a single sort order flag
 * to the whole value.
 *
 * <p>For a {@link CompositeType}, every field index from 0 to {@code getTotalFields() - 1} is used
 * as a key and each key receives the same sort order flag. For an {@link AtomicType}, the comparator
 * is created directly from the flag.
 *
 * @param typeInfo Type information of the elements to compare.
 * @param sortOrder Sort order flag handed to the created comparator (for composite types, to every key field).
 * @param executionConfig Execution config forwarded to the comparator factory method of the type.
 * @param <T> Element type.
 * @return The comparator created by the type information.
 * @throws RuntimeException If the type is neither a {@link CompositeType} nor an {@link AtomicType}.
 */
protected <T> TypeComparator<T> createComparator(TypeInformation<T> typeInfo, boolean sortOrder, ExecutionConfig executionConfig) {

    // Composite types are checked first, exactly as before: compare on all fields.
    if (typeInfo instanceof CompositeType) {
        final int fieldCount = typeInfo.getTotalFields();

        final int[] keyPositions = new int[fieldCount];
        for (int position = 0; position < fieldCount; position++) {
            keyPositions[position] = position;
        }

        // Every key field shares the same sort order flag.
        final boolean[] keyOrders = new boolean[fieldCount];
        java.util.Arrays.fill(keyOrders, sortOrder);

        return ((CompositeType<T>) typeInfo).createComparator(keyPositions, keyOrders, 0, executionConfig);
    }

    // Atomic types create their comparator directly from the flag.
    if (typeInfo instanceof AtomicType) {
        return ((AtomicType<T>) typeInfo).createComparator(sortOrder, executionConfig);
    }

    throw new RuntimeException("Unrecognized type: " + typeInfo);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@

package org.apache.flink.api.common.functions.util;

import java.io.IOException;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
Expand All @@ -27,7 +29,10 @@
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.core.fs.Path;
import org.apache.flink.util.PriorityQueue;

/**
* A standalone implementation of the {@link RuntimeContext}, created by runtime UDF operators.
Expand Down Expand Up @@ -91,6 +96,51 @@ public <T, C> C getBroadcastVariableWithInitializer(String name, BroadcastVariab
}
}
}

/**
 * Creates a {@link PriorityQueue} for the standalone RuntimeContext by wrapping a
 * {@code java.util.PriorityQueue}.
 *
 * <p>The backing heap is ordered by the {@link TypeComparator} that {@code createComparator}
 * returns for the given type information and order flag; the comparator is created with a
 * {@code null} execution config. {@code k} is passed to the backing heap as its initial capacity.
 *
 * @param typeInformation Type information of the queue elements.
 * @param k Initial capacity of the backing heap.
 * @param order Order flag forwarded to {@code createComparator}.
 * @param <T> Element type.
 * @return A queue whose {@code next()} and {@code next(T)} both poll the head of the backing heap.
 * @throws Exception Declared by the overridden method.
 */
@Override
public <T> PriorityQueue<T> getPriorityQueue(final TypeInformation<T> typeInformation, final int k, final boolean order) throws Exception {
    // Resolve the Flink comparator first, then adapt it to java.util.Comparator for the heap.
    final TypeComparator<T> elementComparator = createComparator(typeInformation, order, null);
    final Comparator<T> heapOrdering = new Comparator<T>() {
        @Override
        public int compare(T first, T second) {
            return elementComparator.compare(first, second);
        }
    };
    final java.util.PriorityQueue<T> backingHeap = new java.util.PriorityQueue<T>(k, heapOrdering);

    return new PriorityQueue<T>() {

        @Override
        public void insert(T element) throws IOException {
            backingHeap.add(element);
        }

        @Override
        public T next() throws IOException {
            return backingHeap.poll();
        }

        @Override
        public T next(T reuse) throws IOException {
            // The reuse instance is not used; the polled heap element is returned as is.
            return backingHeap.poll();
        }

        @Override
        public int size() {
            return backingHeap.size();
        }

        @Override
        public void close() throws IOException {
            backingHeap.clear();
        }
    };
}

// --------------------------------------------------------------------------------------------

Expand Down
59 changes: 59 additions & 0 deletions flink-core/src/main/java/org/apache/flink/util/PriorityQueue.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.util;

import java.io.IOException;

/**
* Interface for Flink's customized priority queue.
*
* <p>Elements are added with {@code insert}, removed in priority order with one of the
* {@code next} methods, and the queue is released with {@code close}.
*
* @param <T> Element type.
*/
public interface PriorityQueue<T> {
/**
* Inserts an element into the queue.
*
* @param element The element to insert.
* @throws IOException If the implementation fails to insert the element.
*/
void insert(T element) throws IOException;

/**
* Polls the next smallest element. A new element instance is created each time.
*
* @return The next smallest element.
*         NOTE(review): callers appear to treat {@code null} as "queue is empty" — confirm that
*         every implementation honors this.
* @throws IOException If the implementation fails to retrieve the element.
*/
T next() throws IOException;

/**
* Polls the next smallest element, reusing the given element instance.
*
* @param reuse The element instance offered for reuse.
* @return The next smallest element.
*         NOTE(review): callers appear to treat {@code null} as "queue is empty" — confirm that
*         every implementation honors this.
* @throws IOException If the implementation fails to retrieve the element.
*/
T next(T reuse) throws IOException;

/**
* Returns the size of the priority queue.
*
* @return The priority queue size.
*/
int size();

/**
* Closes the priority queue and releases all assigned resources.
*
* @throws IOException If the implementation fails to release its resources.
*/
void close() throws IOException;
}
29 changes: 29 additions & 0 deletions flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.functions.SelectByMaxFunction;
import org.apache.flink.api.java.functions.SelectByMinFunction;
import org.apache.flink.api.java.functions.TopKMapPartition;
import org.apache.flink.api.java.functions.TopKReducer;
import org.apache.flink.api.java.io.CsvOutputFormat;
import org.apache.flink.api.java.io.DiscardingOutputFormat;
import org.apache.flink.api.java.io.PrintingOutputFormat;
Expand Down Expand Up @@ -570,6 +572,33 @@ public GroupReduceOperator<T, T> first(int n) {

return reduceGroup(new FirstReducer<T>(n));
}

/**
 * Returns a new set containing the top k elements of this {@link DataSet}, polling the largest
 * elements.
 *
 * <p>Shorthand for {@code topK(k, true)}.
 *
 * @param k Expected number of returned elements.
 * @return A {@link GroupReduceOperator} which represents the top K elements DataSet.
 */
public GroupReduceOperator<T, T> topK(int k) {
    // "true" selects the largest elements in the two-argument variant.
    final boolean pollLargest = true;
    return topK(k, pollLargest);
}

/**
 * Returns a new set containing the top k elements of this {@link DataSet}.
 *
 * <p>The selection runs in two steps: a {@link TopKMapPartition} is applied to every partition,
 * and its output is then passed through a non-grouped {@link TopKReducer}.
 *
 * @param k Expected number of returned elements; must be at least 1.
 * @param order True, poll largest elements. False, poll smallest elements.
 * @return A {@link GroupReduceOperator} which represents the top K elements DataSet.
 * @throws InvalidProgramException If {@code k} is smaller than 1.
 */
public GroupReduceOperator<T, T> topK(int k, boolean order) {
    if (k <= 0) {
        throw new InvalidProgramException("Parameter k of topK(k) must be at least 1.");
    }

    // Step 1: pre-select candidates inside each partition.
    DataSet<T> partitionCandidates = mapPartition(new TopKMapPartition<T>(getType(), k, order));

    // Step 2: no group operation, so there would be only 1 reduce task producing the final result.
    return partitionCandidates.reduceGroup(new TopKReducer<T>(getType(), k, order));
}

// --------------------------------------------------------------------------------------------
// distinct
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.functions;

import org.apache.flink.api.common.functions.RichMapPartitionFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.util.Collector;
import org.apache.flink.util.PriorityQueue;

/**
 * Map-partition function that emits at most {@code k} elements of each partition, selected through
 * a {@link PriorityQueue} obtained from the runtime context.
 *
 * <p>The queue polls its smallest element first, so the {@code order} flag is inverted before it is
 * handed to the runtime context.
 *
 * @param <T> Element type.
 */
public class TopKMapPartition<T> extends RichMapPartitionFunction<T, T> {
    private final TypeInformation<T> typeInformation;
    private final int k;
    private final boolean order;

    /**
     * @param typeInformation Type information of the elements, used to create the priority queue.
     * @param k Maximum number of elements emitted per partition.
     * @param order Selection order; its negation is passed to the priority queue.
     */
    public TopKMapPartition(TypeInformation<T> typeInformation, int k, boolean order) {
        this.typeInformation = typeInformation;
        this.k = k;
        this.order = order;
    }

    /**
     * Inserts all values of the partition into a priority queue and emits up to {@code k} polled
     * elements.
     *
     * @param values Elements of the partition.
     * @param out Collector receiving the selected elements.
     * @throws Exception If creating, filling, polling or closing the queue fails, or if collecting fails.
     */
    @Override
    public void mapPartition(Iterable<T> values, Collector<T> out) throws Exception {
        // PriorityQueue poll smallest element every time, so we reverse the order here.
        PriorityQueue<T> priorityQueue = getRuntimeContext().getPriorityQueue(this.typeInformation, k, !this.order);

        try {
            for (T value : values) {
                priorityQueue.insert(value);
            }

            // Poll at most k times; a null element means the queue has no more elements.
            for (int emitted = 0; emitted < k; emitted++) {
                T element = priorityQueue.next();
                if (element == null) {
                    break;
                }
                out.collect(element);
            }
        } finally {
            // Always release the queue's resources, even if inserting or collecting failed.
            priorityQueue.close();
        }
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.api.java.functions;

import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.util.Collector;
import org.apache.flink.util.PriorityQueue;

/**
 * Group-reduce function that emits at most {@code k} elements of its input, selected through a
 * {@link PriorityQueue} obtained from the runtime context.
 *
 * <p>The queue polls its smallest element first, so the {@code order} flag is inverted before it is
 * handed to the runtime context.
 *
 * @param <T> Element type.
 */
public class TopKReducer<T> extends RichGroupReduceFunction<T, T> {
    private final TypeInformation<T> typeInformation;
    private final int k;
    private final boolean order;

    /**
     * @param typeInformation Type information of the elements, used to create the priority queue.
     * @param k Maximum number of elements emitted.
     * @param order Selection order; its negation is passed to the priority queue.
     */
    public TopKReducer(TypeInformation<T> typeInformation, int k, boolean order) {
        this.typeInformation = typeInformation;
        this.k = k;
        this.order = order;
    }

    /**
     * Inserts all input values into a priority queue and emits up to {@code k} polled elements.
     *
     * @param values Input elements of the reduce call.
     * @param out Collector receiving the selected elements.
     * @throws Exception If creating, filling, polling or closing the queue fails, or if collecting fails.
     */
    @Override
    public void reduce(Iterable<T> values, Collector<T> out) throws Exception {
        // PriorityQueue poll smallest element every time, so we reverse the order here.
        PriorityQueue<T> priorityQueue = getRuntimeContext().getPriorityQueue(this.typeInformation, k, !this.order);

        try {
            for (T value : values) {
                priorityQueue.insert(value);
            }

            // Poll at most k times; a null element means the queue has no more elements.
            for (int emitted = 0; emitted < k; emitted++) {
                T element = priorityQueue.next();
                if (element == null) {
                    break;
                }
                out.collect(element);
            }
        } finally {
            // Always release the queue's resources, even if inserting or collecting failed.
            priorityQueue.close();
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.core.fs.Path;
import org.apache.flink.runtime.operators.PactTaskContext;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.flink.api.common.aggregators.Aggregator;
Expand Down Expand Up @@ -172,7 +173,7 @@ public DistributedRuntimeUDFContext createRuntimeContext(String taskName) {
Environment env = getEnvironment();
return new IterativeRuntimeUdfContext(taskName, env.getNumberOfSubtasks(),
env.getIndexInSubtaskGroup(), getUserCodeClassLoader(), getExecutionConfig(),
env.getDistributedCacheEntries(), this.accumulatorMap);
env.getDistributedCacheEntries(), this.accumulatorMap, this);
}

// --------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -363,8 +364,8 @@ private class IterativeRuntimeUdfContext extends DistributedRuntimeUDFContext im

public IterativeRuntimeUdfContext(String name, int numParallelSubtasks, int subtaskIndex, ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig, Map<String, Future<Path>> cpTasks,
Map<String, Accumulator<?,?>> accumulatorMap) {
super(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig, cpTasks, accumulatorMap);
Map<String, Accumulator<?,?>> accumulatorMap, PactTaskContext pactTaskContext) {
super(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig, cpTasks, accumulatorMap, pactTaskContext);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,7 @@ public DistributedRuntimeUDFContext createRuntimeContext(String taskName) {

return new DistributedRuntimeUDFContext(taskName, env.getNumberOfSubtasks(),
env.getIndexInSubtaskGroup(), getUserCodeClassLoader(), getExecutionConfig(),
env.getDistributedCacheEntries(), this.accumulatorMap);
env.getDistributedCacheEntries(), this.accumulatorMap, this);
}

// --------------------------------------------------------------------------------------------
Expand Down
Loading