Skip to content

Commit

Permalink
[FLINK-5945] [core] Close function in OuterJoinOperatorBase#executeOn…
Browse files Browse the repository at this point in the history
…Collections

Conclude OuterJoinOperatorBase#executeOnCollections with a call to
FunctionUtils.closeFunction(function) in order to close rich user
functions.

This closes #3453
  • Loading branch information
greghogan committed Mar 2, 2017
1 parent ba5aa10 commit 01703e6
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ protected List<OUT> executeOnCollections(List<IN1> leftInput, List<IN2> rightInp
FunctionUtils.setFunctionRuntimeContext(function, runtimeContext);
FunctionUtils.openFunction(function, this.parameters);


List<OUT> result = new ArrayList<>();
Collector<OUT> collector = new CopyingListCollector<>(result, outInformation.createSerializer(executionConfig));

Expand All @@ -113,6 +112,8 @@ protected List<OUT> executeOnCollections(List<IN1> leftInput, List<IN2> rightInp
function.join(left == null ? null : leftSerializer.copy(left), right == null ? null : rightSerializer.copy(right), collector);
}

FunctionUtils.closeFunction(function);

return result;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,34 +18,63 @@

package org.apache.flink.api.common.operators.base;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.FlatJoinFunction;
import org.apache.flink.api.common.functions.RichFlatJoinFunction;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.api.common.operators.BinaryOperatorInformation;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.fs.Path;
import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
import org.apache.flink.util.Collector;
import org.junit.Before;
import org.junit.Test;

import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

@SuppressWarnings("serial")
public class OuterJoinOperatorBaseTest implements Serializable {

private final FlatJoinFunction<String, String, String> joiner = new FlatJoinFunction<String, String, String>() {
@Override
public void join(String first, String second, Collector<String> out) throws Exception {
out.collect(String.valueOf(first) + ',' + String.valueOf(second));
}
};
private MockRichFlatJoinFunction joiner;

private OuterJoinOperatorBase<String, String, String, FlatJoinFunction<String, String, String>> baseOperator;

private ExecutionConfig executionConfig;

private RuntimeContext runtimeContext;

@SuppressWarnings({"rawtypes", "unchecked"})
private final OuterJoinOperatorBase<String, String, String, FlatJoinFunction<String, String, String>> baseOperator =
@Before
public void setup() {
joiner = new MockRichFlatJoinFunction();

baseOperator =
new OuterJoinOperatorBase(joiner,
new BinaryOperatorInformation(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.STRING_TYPE_INFO), new int[0], new int[0], "TestJoiner", null);
new BinaryOperatorInformation(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO,
BasicTypeInfo.STRING_TYPE_INFO), new int[0], new int[0], "TestJoiner", null);

executionConfig = new ExecutionConfig();

String taskName = "Test rich outer join function";
TaskInfo taskInfo = new TaskInfo(taskName, 0, 1, 0);
HashMap<String, Accumulator<?, ?>> accumulatorMap = new HashMap<>();
HashMap<String, Future<Path>> cpTasks = new HashMap<>();

runtimeContext = new RuntimeUDFContext(taskInfo, null, executionConfig, cpTasks,
accumulatorMap, new UnregisteredMetricsGroup());
}

@Test
public void testFullOuterJoinWithoutMatchingPartners() throws Exception {
Expand Down Expand Up @@ -131,18 +160,41 @@ public void testThatExceptionIsThrownForOuterJoinTypeNull() throws Exception {
baseOperator.setOuterJoinType(null);
ExecutionConfig executionConfig = new ExecutionConfig();
executionConfig.disableObjectReuse();
baseOperator.executeOnCollections(leftInput, rightInput, null, executionConfig);
baseOperator.executeOnCollections(leftInput, rightInput, runtimeContext, executionConfig);
}

private void testOuterJoin(List<String> leftInput, List<String> rightInput, List<String> expected) throws Exception {
ExecutionConfig executionConfig = new ExecutionConfig();
executionConfig.disableObjectReuse();
List<String> resultSafe = baseOperator.executeOnCollections(leftInput, rightInput, null, executionConfig);
List<String> resultSafe = baseOperator.executeOnCollections(leftInput, rightInput, runtimeContext, executionConfig);
executionConfig.enableObjectReuse();
List<String> resultRegular = baseOperator.executeOnCollections(leftInput, rightInput, null, executionConfig);
List<String> resultRegular = baseOperator.executeOnCollections(leftInput, rightInput, runtimeContext, executionConfig);

assertEquals(expected, resultSafe);
assertEquals(expected, resultRegular);

assertTrue(joiner.opened.get());
assertTrue(joiner.closed.get());
}

}
private static class MockRichFlatJoinFunction extends RichFlatJoinFunction<String, String, String> {
final AtomicBoolean opened = new AtomicBoolean(false);
final AtomicBoolean closed = new AtomicBoolean(false);

@Override
public void open(Configuration parameters) throws Exception {
opened.compareAndSet(false, true);
assertEquals(0, getRuntimeContext().getIndexOfThisSubtask());
assertEquals(1, getRuntimeContext().getNumberOfParallelSubtasks());
}

@Override
public void close() throws Exception{
closed.compareAndSet(false, true);
}

@Override
public void join(String first, String second, Collector<String> out) throws Exception {
out.collect(String.valueOf(first) + ',' + String.valueOf(second));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void testWithCompleteGraph()
public void testWithRMatGraph()
throws Exception {
ChecksumHashCode checksum = DataSetUtils.checksumHashCode(directedRMatGraph
.run(new HITS<LongValue, NullValue, NullValue>(0.000001)));
.run(new HITS<LongValue, NullValue, NullValue>(1)));

assertEquals(902, checksum.getCount());
assertEquals(0x000001cbba6dbcd0L, checksum.getChecksum());
Expand Down

0 comments on commit 01703e6

Please sign in to comment.