Skip to content

Commit

Permalink
[FLINK-2458] [FLINK-2449] [runtime] Access distributed cache entries …
Browse files Browse the repository at this point in the history
…from Iteration contexts & use of distributed cache from Collection Environments

This closes #970
  • Loading branch information
sachingoel0101 authored and StephanEwen committed Aug 16, 2015
1 parent 0a7cc02 commit 358259d
Show file tree
Hide file tree
Showing 20 changed files with 185 additions and 70 deletions.
Expand Up @@ -57,17 +57,6 @@ public abstract class AbstractRuntimeUDFContext implements RuntimeContext {

private final DistributedCache distributedCache;


public AbstractRuntimeUDFContext(String name,
int numParallelSubtasks, int subtaskIndex,
ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig,
Map<String, Accumulator<?,?>> accumulators)
{
this(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig,
accumulators, Collections.<String, Future<Path>>emptyMap());
}

public AbstractRuntimeUDFContext(String name,
int numParallelSubtasks, int subtaskIndex,
ClassLoader userCodeClassLoader,
Expand All @@ -79,7 +68,7 @@ public AbstractRuntimeUDFContext(String name,
this.subtaskIndex = subtaskIndex;
this.userCodeClassLoader = userCodeClassLoader;
this.executionConfig = executionConfig;
this.distributedCache = new DistributedCache(cpTasks);
this.distributedCache = new DistributedCache(Preconditions.checkNotNull(cpTasks));
this.accumulators = Preconditions.checkNotNull(accumulators);
}

Expand Down
Expand Up @@ -37,18 +37,11 @@ public class RuntimeUDFContext extends AbstractRuntimeUDFContext {
private final HashMap<String, Object> initializedBroadcastVars = new HashMap<String, Object>();

private final HashMap<String, List<?>> uninitializedBroadcastVars = new HashMap<String, List<?>>();


public RuntimeUDFContext(String name, int numParallelSubtasks, int subtaskIndex, ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig, Map<String, Accumulator<?,?>> accumulators) {
super(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig, accumulators);
}


public RuntimeUDFContext(String name, int numParallelSubtasks, int subtaskIndex, ClassLoader userCodeClassLoader,
ExecutionConfig executionConfig, Map<String, Future<Path>> cpTasks, Map<String, Accumulator<?,?>> accumulators) {
super(name, numParallelSubtasks, subtaskIndex, userCodeClassLoader, executionConfig, accumulators, cpTasks);
}


@Override
@SuppressWarnings("unchecked")
Expand Down
Expand Up @@ -27,6 +27,10 @@
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.InvalidProgramException;
Expand All @@ -37,6 +41,7 @@
import org.apache.flink.api.common.aggregators.Aggregator;
import org.apache.flink.api.common.aggregators.AggregatorWithName;
import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.api.common.functions.RichFunction;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
Expand All @@ -51,6 +56,8 @@
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeutils.CompositeType;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.fs.local.LocalFileSystem;
import org.apache.flink.types.Value;
import org.apache.flink.util.Visitor;

Expand All @@ -64,6 +71,8 @@ public class CollectionExecutor {
private final Map<Operator<?>, List<?>> intermediateResults;

private final Map<String, Accumulator<?, ?>> accumulators;

private final Map<String, Future<Path>> cachedFiles;

private final Map<String, Value> previousAggregates;

Expand All @@ -84,7 +93,7 @@ public CollectionExecutor(ExecutionConfig executionConfig) {
this.accumulators = new HashMap<String, Accumulator<?,?>>();
this.previousAggregates = new HashMap<String, Value>();
this.aggregators = new HashMap<String, Aggregator<?>>();

this.cachedFiles = new HashMap<String, Future<Path>>();
this.classLoader = getClass().getClassLoader();
}

Expand All @@ -94,7 +103,7 @@ public CollectionExecutor(ExecutionConfig executionConfig) {

public JobExecutionResult execute(Plan program) throws Exception {
long startTime = System.currentTimeMillis();

initCache(program.getCachedFiles());
Collection<? extends GenericDataSinkBase<?>> sinks = program.getDataSinks();
for (Operator<?> sink : sinks) {
execute(sink);
Expand All @@ -104,7 +113,14 @@ public JobExecutionResult execute(Plan program) throws Exception {
Map<String, Object> accumulatorResults = AccumulatorHelper.toResultMap(accumulators);
return new JobExecutionResult(null, endTime - startTime, accumulatorResults);
}


private void initCache(Set<Map.Entry<String, DistributedCache.DistributedCacheEntry>> files){
for(Map.Entry<String, DistributedCache.DistributedCacheEntry> file: files){
Future<Path> doNothing = new CompletedFuture(new Path(file.getValue().filePath));
cachedFiles.put(file.getKey(), doNothing);
}
};

private List<?> execute(Operator<?> operator) throws Exception {
return execute(operator, 0);
}
Expand Down Expand Up @@ -165,8 +181,8 @@ private <IN> void executeDataSink(GenericDataSinkBase<?> sink, int superStep) th
// build the runtime context and compute broadcast variables, if necessary
RuntimeUDFContext ctx;
if (RichOutputFormat.class.isAssignableFrom(typedSink.getUserCodeWrapper().getUserCodeClass())) {
ctx = superStep == 0 ? new RuntimeUDFContext(typedSink.getName(), 1, 0, getClass().getClassLoader(), executionConfig, accumulators) :
new IterationRuntimeUDFContext(typedSink.getName(), 1, 0, classLoader, executionConfig, accumulators);
ctx = superStep == 0 ? new RuntimeUDFContext(typedSink.getName(), 1, 0, getClass().getClassLoader(), executionConfig, cachedFiles, accumulators) :
new IterationRuntimeUDFContext(typedSink.getName(), 1, 0, classLoader, executionConfig, cachedFiles, accumulators);
} else {
ctx = null;
}
Expand All @@ -181,8 +197,8 @@ private <OUT> List<OUT> executeDataSource(GenericDataSourceBase<?, ?> source, in
// build the runtime context and compute broadcast variables, if necessary
RuntimeUDFContext ctx;
if (RichInputFormat.class.isAssignableFrom(typedSource.getUserCodeWrapper().getUserCodeClass())) {
ctx = superStep == 0 ? new RuntimeUDFContext(source.getName(), 1, 0, getClass().getClassLoader(), executionConfig, accumulators) :
new IterationRuntimeUDFContext(source.getName(), 1, 0, classLoader, executionConfig, accumulators);
ctx = superStep == 0 ? new RuntimeUDFContext(source.getName(), 1, 0, getClass().getClassLoader(), executionConfig, cachedFiles, accumulators) :
new IterationRuntimeUDFContext(source.getName(), 1, 0, classLoader, executionConfig, cachedFiles, accumulators);
} else {
ctx = null;
}
Expand All @@ -204,8 +220,10 @@ private <IN, OUT> List<OUT> executeUnaryOperator(SingleInputOperator<?, ?, ?> op
// build the runtime context and compute broadcast variables, if necessary
RuntimeUDFContext ctx;
if (RichFunction.class.isAssignableFrom(typedOp.getUserCodeWrapper().getUserCodeClass())) {
ctx = superStep == 0 ? new RuntimeUDFContext(operator.getName(), 1, 0, getClass().getClassLoader(), executionConfig, accumulators) :
new IterationRuntimeUDFContext(operator.getName(), 1, 0, classLoader, executionConfig, accumulators);
ctx = superStep == 0 ? new RuntimeUDFContext(operator.getName(), 1, 0, getClass()
.getClassLoader(), executionConfig, cachedFiles, accumulators) :
new IterationRuntimeUDFContext(operator.getName(), 1, 0, classLoader,
executionConfig, cachedFiles, accumulators);

for (Map.Entry<String, Operator<?>> bcInputs : operator.getBroadcastInputs().entrySet()) {
List<?> bcData = execute(bcInputs.getValue());
Expand Down Expand Up @@ -243,8 +261,10 @@ private <IN1, IN2, OUT> List<OUT> executeBinaryOperator(DualInputOperator<?, ?,
// build the runtime context and compute broadcast variables, if necessary
RuntimeUDFContext ctx;
if (RichFunction.class.isAssignableFrom(typedOp.getUserCodeWrapper().getUserCodeClass())) {
ctx = superStep == 0 ? new RuntimeUDFContext(operator.getName(), 1, 0, classLoader, executionConfig, accumulators) :
new IterationRuntimeUDFContext(operator.getName(), 1, 0, classLoader, executionConfig, accumulators);
ctx = superStep == 0 ? new RuntimeUDFContext(operator.getName(), 1, 0, classLoader,
executionConfig, cachedFiles, accumulators) :
new IterationRuntimeUDFContext(operator.getName(), 1, 0, classLoader,
executionConfig, cachedFiles, accumulators);

for (Map.Entry<String, Operator<?>> bcInputs : operator.getBroadcastInputs().entrySet()) {
List<?> bcData = execute(bcInputs.getValue());
Expand Down Expand Up @@ -500,8 +520,9 @@ else if (op instanceof GenericDataSourceBase) {
private class IterationRuntimeUDFContext extends RuntimeUDFContext implements IterationRuntimeContext {

public IterationRuntimeUDFContext(String name, int numParallelSubtasks, int subtaskIndex, ClassLoader classloader,
ExecutionConfig executionConfig, Map<String, Accumulator<?,?>> accumulators) {
super(name, numParallelSubtasks, subtaskIndex, classloader, executionConfig, accumulators);
ExecutionConfig executionConfig, Map<String, Future<Path>> cpTasks, Map<String,
Accumulator<?,?>> accumulators) {
super(name, numParallelSubtasks, subtaskIndex, classloader, executionConfig, cpTasks, accumulators);
}

@Override
Expand All @@ -521,4 +542,43 @@ public <T extends Value> T getPreviousIterationAggregate(String name) {
return (T) previousAggregates.get(name);
}
}

private static final class CompletedFuture implements Future<Path>{

private final Path result;

public CompletedFuture(Path entry) {
try{
LocalFileSystem fs = (LocalFileSystem) entry.getFileSystem();
result = entry.isAbsolute() ? new Path(entry.toUri().getPath()): new Path(fs.getWorkingDirectory(),entry);
} catch (Exception e){
throw new RuntimeException("DistributedCache supports only local files for Collection Environments");
}
}

@Override
public boolean cancel(boolean mayInterruptIfRunning) {
return false;
}

@Override
public boolean isCancelled() {
return false;
}

@Override
public boolean isDone() {
return true;
}

@Override
public Path get() throws InterruptedException, ExecutionException {
return result;
}

@Override
public Path get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
return get();
}
}
}
Expand Up @@ -24,10 +24,12 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Future;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.BroadcastVariableInitializer;
import org.apache.flink.core.fs.Path;
import org.junit.Test;


Expand All @@ -36,7 +38,7 @@ public class RuntimeUDFContextTest {
@Test
public void testBroadcastVariableNotFound() {
try {
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>());
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(),new HashMap<String, Accumulator<?, ?>>());

try {
ctx.getBroadcastVariable("some name");
Expand Down Expand Up @@ -66,7 +68,7 @@ public void testBroadcastVariableNotFound() {
@Test
public void testBroadcastVariableSimple() {
try {
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>());
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>());

ctx.setBroadcastVariable("name1", Arrays.asList(1, 2, 3, 4));
ctx.setBroadcastVariable("name2", Arrays.asList(1.0, 2.0, 3.0, 4.0));
Expand Down Expand Up @@ -100,7 +102,7 @@ public void testBroadcastVariableSimple() {
@Test
public void testBroadcastVariableWithInitializer() {
try {
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>());
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>());

ctx.setBroadcastVariable("name", Arrays.asList(1, 2, 3, 4));

Expand All @@ -125,7 +127,7 @@ public void testBroadcastVariableWithInitializer() {
@Test
public void testResetBroadcastVariableWithInitializer() {
try {
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>());
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>());

ctx.setBroadcastVariable("name", Arrays.asList(1, 2, 3, 4));

Expand All @@ -148,7 +150,7 @@ public void testResetBroadcastVariableWithInitializer() {
@Test
public void testBroadcastVariableWithInitializerAndMismatch() {
try {
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>());
RuntimeUDFContext ctx = new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>());

ctx.setBroadcastVariable("name", Arrays.asList(1, 2, 3, 4));

Expand Down
Expand Up @@ -20,10 +20,12 @@
package org.apache.flink.api.common.io;

import java.util.HashMap;
import java.util.concurrent.Future;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.core.fs.Path;
import org.apache.flink.types.Value;
import org.junit.Assert;
import org.junit.Test;
Expand All @@ -36,8 +38,7 @@ public class RichInputFormatTest {
@Test
public void testCheckRuntimeContextAccess() {
final SerializedInputFormat<Value> inputFormat = new SerializedInputFormat<Value>();
inputFormat.setRuntimeContext(new RuntimeUDFContext("test name", 3, 1,
getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>()));
inputFormat.setRuntimeContext(new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()));

Assert.assertEquals(inputFormat.getRuntimeContext().getIndexOfThisSubtask(), 1);
Assert.assertEquals(inputFormat.getRuntimeContext().getNumberOfParallelSubtasks(),3);
Expand Down
Expand Up @@ -20,10 +20,12 @@
package org.apache.flink.api.common.io;

import java.util.HashMap;
import java.util.concurrent.Future;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.core.fs.Path;
import org.apache.flink.types.Value;
import org.junit.Assert;
import org.junit.Test;
Expand All @@ -36,8 +38,7 @@ public class RichOutputFormatTest {
@Test
public void testCheckRuntimeContextAccess() {
final SerializedOutputFormat<Value> inputFormat = new SerializedOutputFormat<Value>();
inputFormat.setRuntimeContext(new RuntimeUDFContext("test name", 3, 1,
getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>()));
inputFormat.setRuntimeContext(new RuntimeUDFContext("test name", 3, 1, getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Future<Path>>(), new HashMap<String, Accumulator<?, ?>>()));

Assert.assertEquals(inputFormat.getRuntimeContext().getIndexOfThisSubtask(), 1);
Assert.assertEquals(inputFormat.getRuntimeContext().getNumberOfParallelSubtasks(),3);
Expand Down
Expand Up @@ -26,10 +26,12 @@
import org.apache.flink.api.common.operators.util.TestNonRichInputFormat;
import org.apache.flink.api.common.operators.util.TestRichOutputFormat;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.core.fs.Path;
import org.apache.flink.types.Nothing;
import org.junit.Test;

import java.util.HashMap;
import java.util.concurrent.Future;

import static java.util.Arrays.asList;
import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -87,15 +89,16 @@ public void testDataSourceWithRuntimeContext() {

ExecutionConfig executionConfig = new ExecutionConfig();
final HashMap<String, Accumulator<?, ?>> accumulatorMap = new HashMap<String, Accumulator<?, ?>>();
final HashMap<String, Future<Path>> cpTasks = new HashMap<>();
executionConfig.disableObjectReuse();
in.reset();
sink.executeOnCollections(asList(TestIOData.NAMES), new RuntimeUDFContext("test_sink", 1, 0, null, executionConfig, accumulatorMap), executionConfig);
sink.executeOnCollections(asList(TestIOData.NAMES), new RuntimeUDFContext("test_sink", 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);
assertEquals(out.output, asList(TestIOData.RICH_NAMES));

executionConfig.enableObjectReuse();
out.clear();
in.reset();
sink.executeOnCollections(asList(TestIOData.NAMES), new RuntimeUDFContext("test_sink", 1, 0, null, executionConfig, accumulatorMap), executionConfig);
sink.executeOnCollections(asList(TestIOData.NAMES), new RuntimeUDFContext("test_sink", 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);
assertEquals(out.output, asList(TestIOData.RICH_NAMES));
} catch(Exception e){
e.printStackTrace();
Expand Down
Expand Up @@ -25,10 +25,12 @@
import org.apache.flink.api.common.operators.util.TestNonRichInputFormat;
import org.apache.flink.api.common.operators.util.TestRichInputFormat;
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.core.fs.Path;
import org.junit.Test;

import java.util.HashMap;
import java.util.List;
import java.util.concurrent.Future;

import static java.util.Arrays.asList;
import static org.junit.Assert.assertEquals;
Expand Down Expand Up @@ -73,13 +75,14 @@ public void testDataSourceWithRuntimeContext() {
in, new OperatorInformation<String>(BasicTypeInfo.STRING_TYPE_INFO), "testSource");

final HashMap<String, Accumulator<?, ?>> accumulatorMap = new HashMap<String, Accumulator<?, ?>>();
final HashMap<String, Future<Path>> cpTasks = new HashMap<>();
ExecutionConfig executionConfig = new ExecutionConfig();
executionConfig.disableObjectReuse();
List<String> resultMutableSafe = source.executeOnCollections(new RuntimeUDFContext("test_source", 1, 0, null, executionConfig, accumulatorMap), executionConfig);
List<String> resultMutableSafe = source.executeOnCollections(new RuntimeUDFContext("test_source", 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);

in.reset();
executionConfig.enableObjectReuse();
List<String> resultRegular = source.executeOnCollections(new RuntimeUDFContext("test_source", 1, 0, null, executionConfig, accumulatorMap), executionConfig);
List<String> resultRegular = source.executeOnCollections(new RuntimeUDFContext("test_source", 1, 0, null, executionConfig, cpTasks, accumulatorMap), executionConfig);

assertEquals(asList(TestIOData.RICH_NAMES), resultMutableSafe);
assertEquals(asList(TestIOData.RICH_NAMES), resultRegular);
Expand Down

0 comments on commit 358259d

Please sign in to comment.