Skip to content

Commit

Permalink
[FLINK-1819][core] allow access to RuntimeContext from Input and Outp…
Browse files Browse the repository at this point in the history
…utFormats

1. Allow access to Runtime Context from I/O formats.
2. Make all existing I/O formats context aware.

This closes #966.
  • Loading branch information
sachingoel0101 authored and mxm committed Aug 13, 2015
1 parent 7cc85c7 commit 26c6447
Show file tree
Hide file tree
Showing 34 changed files with 1,001 additions and 49 deletions.
Expand Up @@ -46,15 +46,15 @@
import org.apache.flink.core.fs.Path;

/**
* The base class for {@link InputFormat}s that read from files. For specific input types the
* The base class for {@link RichInputFormat}s that read from files. For specific input types the
* {@link #nextRecord(Object)} and {@link #reachedEnd()} methods need to be implemented.
* Additionally, one may override {@link #open(FileInputSplit)} and {@link #close()} to
* change the life cycle behavior.
*
* <p>After the {@link #open(FileInputSplit)} method completed, the file input data is available
* from the {@link #stream} field.</p>
*/
public abstract class FileInputFormat<OT> implements InputFormat<OT, FileInputSplit> {
public abstract class FileInputFormat<OT> extends RichInputFormat<OT, FileInputSplit> {

// -------------------------------------- Constants -------------------------------------------

Expand Down
Expand Up @@ -33,10 +33,11 @@
import org.apache.flink.core.fs.FileSystem.WriteMode;

/**
* The abstract base class for all output formats that are file based. Contains the logic to open/close the target
* The abstract base class for all Rich output formats that are file based. Contains the logic to
* open/close the target
* file streams.
*/
public abstract class FileOutputFormat<IT> implements OutputFormat<IT>, InitializeOnMaster, CleanupWhenUnsuccessful {
public abstract class FileOutputFormat<IT> extends RichOutputFormat<IT> implements InitializeOnMaster, CleanupWhenUnsuccessful {

private static final long serialVersionUID = 1L;

Expand Down
Expand Up @@ -26,9 +26,9 @@
import org.apache.flink.core.io.GenericInputSplit;

/**
* Generic base class for all inputs that are not based on files.
* Generic base class for all Rich inputs that are not based on files.
*/
public abstract class GenericInputFormat<OT> implements InputFormat<OT, GenericInputSplit> {
public abstract class GenericInputFormat<OT> extends RichInputFormat<OT, GenericInputSplit> {

private static final long serialVersionUID = 1L;

Expand Down
Expand Up @@ -19,6 +19,7 @@
package org.apache.flink.api.common.io;


import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.io.statistics.BaseStatistics;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.io.InputSplit;
Expand Down Expand Up @@ -52,14 +53,15 @@
* @param <S> The InputSplit type of the wrapped InputFormat.
*
* @see org.apache.flink.api.common.io.InputFormat
* @see org.apache.flink.api.common.io.RichInputFormat
* @see org.apache.flink.api.common.operators.base.JoinOperatorBase
* @see org.apache.flink.api.common.operators.base.CrossOperatorBase
* @see org.apache.flink.api.common.operators.base.MapOperatorBase
* @see org.apache.flink.api.common.operators.base.FlatMapOperatorBase
* @see org.apache.flink.api.common.operators.base.FilterOperatorBase
* @see org.apache.flink.api.common.operators.base.MapPartitionOperatorBase
*/
public final class ReplicatingInputFormat<OT, S extends InputSplit> implements InputFormat<OT, S> {
public final class ReplicatingInputFormat<OT, S extends InputSplit> extends RichInputFormat<OT, S> {

private static final long serialVersionUID = 1L;

Expand Down Expand Up @@ -112,4 +114,20 @@ public OT nextRecord(OT reuse) throws IOException {
public void close() throws IOException {
this.replicatedIF.close();
}

/**
 * Forwards the given runtime context to the wrapped input format.
 *
 * <p>The context is only handed on when the wrapped format is a {@link RichInputFormat};
 * for any other wrapped format the call does nothing.
 *
 * @param context the runtime context to pass to the wrapped format
 */
@Override
public void setRuntimeContext(RuntimeContext context) {
    if (!(this.replicatedIF instanceof RichInputFormat)) {
        // The wrapped format cannot accept a context, so there is nothing to forward.
        return;
    }
    final RichInputFormat<?, ?> richDelegate = (RichInputFormat<?, ?>) this.replicatedIF;
    richDelegate.setRuntimeContext(context);
}

/**
 * Returns the runtime context held by the wrapped input format.
 *
 * @return whatever the wrapped {@link RichInputFormat} reports as its runtime context
 * @throws RuntimeException if the wrapped format is not a {@link RichInputFormat}
 */
@Override
public RuntimeContext getRuntimeContext() {
    if (!(this.replicatedIF instanceof RichInputFormat)) {
        throw new RuntimeException("The underlying input format to this ReplicatingInputFormat isn't context aware");
    }
    final RichInputFormat<?, ?> richDelegate = (RichInputFormat<?, ?>) this.replicatedIF;
    return richDelegate.getRuntimeContext();
}
}
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.common.io;

import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.core.io.InputSplit;

/**
 * Abstract base class for "rich" input formats, i.e. input formats that hold a
 * {@link RuntimeContext}.
 *
 * <p>The context is stored through {@link #setRuntimeContext(RuntimeContext)} and read back
 * through {@link #getRuntimeContext()}. It is kept in a transient field.
 *
 * @param <OT> the first type argument of the implemented {@link InputFormat}
 * @param <T> the input split type of the implemented {@link InputFormat}
 */
public abstract class RichInputFormat<OT, T extends InputSplit> implements InputFormat<OT, T> {

    private static final long serialVersionUID = 1L;

    // --------------------------------------------------------------------------------------------
    //  Runtime context access
    // --------------------------------------------------------------------------------------------

    /** The context handed in via the setter; {@code null} until the setter has been called. */
    private transient RuntimeContext runtimeContext;

    /**
     * Stores the runtime context for later retrieval.
     *
     * @param t the runtime context to remember
     */
    public void setRuntimeContext(RuntimeContext t) {
        runtimeContext = t;
    }

    /**
     * Returns the previously stored runtime context.
     *
     * @return the stored runtime context, never {@code null}
     * @throws IllegalStateException if no (non-null) context has been set so far
     */
    public RuntimeContext getRuntimeContext() {
        if (runtimeContext == null) {
            throw new IllegalStateException("The runtime context has not been initialized yet. Try accessing " +
                    "it in one of the other life cycle methods.");
        }
        return runtimeContext;
    }
}
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.common.io;

import org.apache.flink.api.common.functions.RuntimeContext;

/**
 * Abstract base class for "rich" output formats, i.e. output formats that hold a
 * {@link RuntimeContext}.
 *
 * <p>The context is stored through {@link #setRuntimeContext(RuntimeContext)} and read back
 * through {@link #getRuntimeContext()}. It is kept in a transient field.
 *
 * @param <IT> the type argument of the implemented {@link OutputFormat}
 */
public abstract class RichOutputFormat<IT> implements OutputFormat<IT> {

    private static final long serialVersionUID = 1L;

    // --------------------------------------------------------------------------------------------
    //  Runtime context access
    // --------------------------------------------------------------------------------------------

    /** The context handed in via the setter; {@code null} until the setter has been called. */
    private transient RuntimeContext runtimeContext;

    /**
     * Stores the runtime context for later retrieval.
     *
     * @param t the runtime context to remember
     */
    public void setRuntimeContext(RuntimeContext t) {
        runtimeContext = t;
    }

    /**
     * Returns the previously stored runtime context.
     *
     * @return the stored runtime context, never {@code null}
     * @throws IllegalStateException if no (non-null) context has been set so far
     */
    public RuntimeContext getRuntimeContext() {
        if (runtimeContext == null) {
            throw new IllegalStateException("The runtime context has not been initialized yet. Try accessing " +
                    "it in one of the other life cycle methods.");
        }
        return runtimeContext;
    }
}
Expand Up @@ -28,8 +28,7 @@
*
* @see SerializedInputFormat
*/
public class SerializedOutputFormat<T extends IOReadableWritable> extends
BinaryOutputFormat<T> {
public class SerializedOutputFormat<T extends IOReadableWritable> extends BinaryOutputFormat<T> {

private static final long serialVersionUID = 1L;

Expand Down
Expand Up @@ -40,6 +40,8 @@
import org.apache.flink.api.common.functions.IterationRuntimeContext;
import org.apache.flink.api.common.functions.RichFunction;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.api.common.io.RichInputFormat;
import org.apache.flink.api.common.io.RichOutputFormat;
import org.apache.flink.api.common.operators.base.BulkIterationBase;
import org.apache.flink.api.common.operators.base.BulkIterationBase.PartialSolutionPlaceHolder;
import org.apache.flink.api.common.operators.base.DeltaIterationBase;
Expand Down Expand Up @@ -128,10 +130,10 @@ else if (operator instanceof DualInputOperator) {
result = executeBinaryOperator((DualInputOperator<?, ?, ?, ?>) operator, superStep);
}
else if (operator instanceof GenericDataSourceBase) {
result = executeDataSource((GenericDataSourceBase<?, ?>) operator);
result = executeDataSource((GenericDataSourceBase<?, ?>) operator, superStep);
}
else if (operator instanceof GenericDataSinkBase) {
executeDataSink((GenericDataSinkBase<?>) operator);
executeDataSink((GenericDataSinkBase<?>) operator, superStep);
result = Collections.emptyList();
}
else {
Expand All @@ -148,7 +150,7 @@ else if (operator instanceof GenericDataSinkBase) {
// Operator class specific execution methods
// --------------------------------------------------------------------------------------------

private <IN> void executeDataSink(GenericDataSinkBase<?> sink) throws Exception {
private <IN> void executeDataSink(GenericDataSinkBase<?> sink, int superStep) throws Exception {
Operator<?> inputOp = sink.getInput();
if (inputOp == null) {
throw new InvalidProgramException("The data sink " + sink.getName() + " has no input.");
Expand All @@ -160,13 +162,31 @@ private <IN> void executeDataSink(GenericDataSinkBase<?> sink) throws Exception
@SuppressWarnings("unchecked")
GenericDataSinkBase<IN> typedSink = (GenericDataSinkBase<IN>) sink;

typedSink.executeOnCollections(input, executionConfig);
// build the runtime context and compute broadcast variables, if necessary
RuntimeUDFContext ctx;
if (RichOutputFormat.class.isAssignableFrom(typedSink.getUserCodeWrapper().getUserCodeClass())) {
ctx = superStep == 0 ? new RuntimeUDFContext(typedSink.getName(), 1, 0, getClass().getClassLoader(), executionConfig, accumulators) :
new IterationRuntimeUDFContext(typedSink.getName(), 1, 0, classLoader, executionConfig, accumulators);
} else {
ctx = null;
}

typedSink.executeOnCollections(input, ctx, executionConfig);
}

private <OUT> List<OUT> executeDataSource(GenericDataSourceBase<?, ?> source) throws Exception {
private <OUT> List<OUT> executeDataSource(GenericDataSourceBase<?, ?> source, int superStep)
throws Exception {
@SuppressWarnings("unchecked")
GenericDataSourceBase<OUT, ?> typedSource = (GenericDataSourceBase<OUT, ?>) source;
return typedSource.executeOnCollections(executionConfig);
// build the runtime context and compute broadcast variables, if necessary
RuntimeUDFContext ctx;
if (RichInputFormat.class.isAssignableFrom(typedSource.getUserCodeWrapper().getUserCodeClass())) {
ctx = superStep == 0 ? new RuntimeUDFContext(source.getName(), 1, 0, getClass().getClassLoader(), executionConfig, accumulators) :
new IterationRuntimeUDFContext(source.getName(), 1, 0, classLoader, executionConfig, accumulators);
} else {
ctx = null;
}
return typedSource.executeOnCollections(ctx, executionConfig);
}

private <IN, OUT> List<OUT> executeUnaryOperator(SingleInputOperator<?, ?, ?> operator, int superStep) throws Exception {
Expand Down
Expand Up @@ -19,15 +19,17 @@

package org.apache.flink.api.common.operators;

import java.util.List;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.distributions.DataDistribution;
import org.apache.flink.api.common.io.FinalizeOnMaster;
import org.apache.flink.api.common.io.InitializeOnMaster;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.io.OutputFormat;
import org.apache.flink.api.common.io.RichOutputFormat;
import org.apache.flink.api.common.io.InitializeOnMaster;
import org.apache.flink.api.common.io.FinalizeOnMaster;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
import org.apache.flink.api.common.typeinfo.AtomicType;
Expand Down Expand Up @@ -296,11 +298,11 @@ public void accept(Visitor<Operator<?>> visitor) {
visitor.postVisit(this);
}
}

// --------------------------------------------------------------------------------------------

@SuppressWarnings("unchecked")
protected void executeOnCollections(List<IN> inputData, ExecutionConfig executionConfig) throws Exception {
protected void executeOnCollections(List<IN> inputData, RuntimeContext ctx, ExecutionConfig executionConfig) throws Exception {
OutputFormat<IN> format = this.formatWrapper.getUserCodeObject();
TypeInformation<IN> inputType = getInput().getOperatorInfo().getOutputType();

Expand Down Expand Up @@ -328,9 +330,11 @@ public int compare(IN o1, IN o2) {
if(format instanceof InitializeOnMaster) {
((InitializeOnMaster)format).initializeGlobal(1);
}

format.configure(this.parameters);


if(format instanceof RichOutputFormat){
((RichOutputFormat) format).setRuntimeContext(ctx);
}
format.open(0, 1);
for (IN element : inputData) {
format.writeRecord(element);
Expand Down
Expand Up @@ -24,7 +24,9 @@

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.Partitioner;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.io.InputFormat;
import org.apache.flink.api.common.io.RichInputFormat;
import org.apache.flink.api.common.operators.util.UserCodeClassWrapper;
import org.apache.flink.api.common.operators.util.UserCodeObjectWrapper;
import org.apache.flink.api.common.operators.util.UserCodeWrapper;
Expand Down Expand Up @@ -200,14 +202,18 @@ public void accept(Visitor<Operator<?>> visitor) {
visitor.postVisit(this);
}
}

// --------------------------------------------------------------------------------------------

protected List<OUT> executeOnCollections(ExecutionConfig executionConfig) throws Exception {
protected List<OUT> executeOnCollections(RuntimeContext ctx, ExecutionConfig executionConfig) throws Exception {
@SuppressWarnings("unchecked")
InputFormat<OUT, InputSplit> inputFormat = (InputFormat<OUT, InputSplit>) this.formatWrapper.getUserCodeObject();
inputFormat.configure(this.parameters);


if(inputFormat instanceof RichInputFormat){
((RichInputFormat) inputFormat).setRuntimeContext(ctx);
}

List<OUT> result = new ArrayList<OUT>();

// splits
Expand Down
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/


package org.apache.flink.api.common.io;

import java.util.HashMap;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.functions.util.RuntimeUDFContext;
import org.apache.flink.types.Value;
import org.junit.Assert;
import org.junit.Test;

/**
 * Tests runtime context access from inside a RichInputFormat class.
 */
public class RichInputFormatTest {

    /**
     * Sets a runtime context on an input format and checks that the values read back through
     * {@code getRuntimeContext()} are the ones the context was created with.
     */
    @Test
    public void testCheckRuntimeContextAccess() {
        final SerializedInputFormat<Value> inputFormat = new SerializedInputFormat<Value>();
        // NOTE(review): the arguments 3 and 1 are presumably the parallelism and the subtask
        // index, which is what the assertions below check for - confirm against RuntimeUDFContext.
        inputFormat.setRuntimeContext(new RuntimeUDFContext("test name", 3, 1,
                getClass().getClassLoader(), new ExecutionConfig(), new HashMap<String, Accumulator<?, ?>>()));

        // JUnit's assertEquals takes (expected, actual); keeping that order makes a failure
        // message report the right value as "expected".
        Assert.assertEquals(1, inputFormat.getRuntimeContext().getIndexOfThisSubtask());
        Assert.assertEquals(3, inputFormat.getRuntimeContext().getNumberOfParallelSubtasks());
    }
}

0 comments on commit 26c6447

Please sign in to comment.