[FLINK-14018][python] Add translateToPlan for DataStreamPythonCalc

dianfu committed Sep 27, 2019
1 parent 9aaebb0 commit e1b1ce7

Showing 9 changed files with 283 additions and 42 deletions.
@@ -87,9 +87,9 @@ public abstract class AbstractPythonScalarFunctionOperator<IN, OUT, UDFIN, UDFOUT>
protected final int[] udfInputOffsets;

/**
- * The number of forwarded fields in the input element.
+ * The offsets of the fields which should be forwarded.
*/
- protected final int forwardedFieldCnt;
+ protected final int[] forwardedFields;

/**
* The udf input logical type.
@@ -117,12 +117,12 @@ public abstract class AbstractPythonScalarFunctionOperator<IN, OUT, UDFIN, UDFOUT>
RowType inputType,
RowType outputType,
int[] udfInputOffsets,
- int forwardedFieldCnt) {
+ int[] forwardedFields) {
this.scalarFunctions = Preconditions.checkNotNull(scalarFunctions);
this.inputType = Preconditions.checkNotNull(inputType);
this.outputType = Preconditions.checkNotNull(outputType);
this.udfInputOffsets = Preconditions.checkNotNull(udfInputOffsets);
- this.forwardedFieldCnt = forwardedFieldCnt;
+ this.forwardedFields = Preconditions.checkNotNull(forwardedFields);
}

@Override
@@ -133,7 +133,7 @@ public void open() throws Exception {
Arrays.stream(udfInputOffsets)
.mapToObj(i -> inputType.getFields().get(i))
.collect(Collectors.toList()));
- udfOutputType = new RowType(outputType.getFields().subList(forwardedFieldCnt, outputType.getFieldCount()));
+ udfOutputType = new RowType(outputType.getFields().subList(forwardedFields.length, outputType.getFieldCount()));
super.open();
}

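The key change in AbstractPythonScalarFunctionOperator is replacing the scalar count forwardedFieldCnt with an explicit offset array forwardedFields: a count can only describe a leading prefix of the input row, while an offset array can forward arbitrary columns. A minimal sketch of the difference (plain Java arrays standing in for Flink rows; all names hypothetical):

    import java.util.Arrays;

    public class ForwardedFieldsSketch {
        public static void main(String[] args) {
            Object[] inputRow = {"a", "b", "c", "d"};

            // Old model: forwardedFieldCnt = 2 can only forward the leading prefix.
            Object[] prefix = Arrays.copyOfRange(inputRow, 0, 2);

            // New model: forwardedFields = {0, 3} can forward arbitrary columns.
            int[] forwardedFields = {0, 3};
            Object[] projected = Arrays.stream(forwardedFields)
                .mapToObj(i -> inputRow[i])
                .toArray();

            System.out.println(Arrays.toString(prefix));     // [a, b]
            System.out.println(Arrays.toString(projected));  // [a, d]
        }
    }
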
@@ -37,6 +37,9 @@

import org.apache.beam.sdk.fn.data.FnDataReceiver;

+ import java.util.Arrays;
+ import java.util.stream.Collectors;

/**
* The Python {@link ScalarFunction} operator for the blink planner.
*/
@@ -71,8 +74,8 @@ public BaseRowPythonScalarFunctionOperator(
RowType inputType,
RowType outputType,
int[] udfInputOffsets,
- int forwardedFieldCnt) {
- super(scalarFunctions, inputType, outputType, udfInputOffsets, forwardedFieldCnt);
+ int[] forwardedFields) {
+ super(scalarFunctions, inputType, outputType, udfInputOffsets, forwardedFields);
}

@Override
@@ -133,18 +136,16 @@ private Projection<BaseRow, BinaryRow> createUdfInputProjection() {
}

private Projection<BaseRow, BinaryRow> createForwardedFieldProjection() {
- final int[] fields = new int[forwardedFieldCnt];
- for (int i = 0; i < fields.length; i++) {
-     fields[i] = i;
- }
-
- final RowType forwardedFieldType = new RowType(inputType.getFields().subList(0, forwardedFieldCnt));
+ final RowType forwardedFieldType = new RowType(
+     Arrays.stream(forwardedFields)
+         .mapToObj(i -> inputType.getFields().get(i))
+         .collect(Collectors.toList()));
final GeneratedProjection generatedProjection = ProjectionCodeGenerator.generateProjection(
CodeGeneratorContext.apply(new TableConfig()),
"ForwardedFieldProjection",
inputType,
forwardedFieldType,
- fields);
+ forwardedFields);
// noinspection unchecked
return generatedProjection.newInstance(Thread.currentThread().getContextClassLoader());
}
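The forwarded-field projection in BaseRowPythonScalarFunctionOperator no longer builds an identity index array; it feeds forwardedFields directly to the code generator and derives the projected row type by streaming over the offsets. A minimal sketch of that selection idiom (plain strings standing in for RowType.RowField descriptors; names hypothetical):

    import java.util.Arrays;
    import java.util.List;
    import java.util.stream.Collectors;

    public class OffsetSelectionSketch {
        public static void main(String[] args) {
            List<String> fields = Arrays.asList("id", "name", "score", "ts");
            int[] forwardedFields = {1, 3};

            // Same idiom as createForwardedFieldProjection(): pick the field
            // descriptors at the forwarded offsets, preserving offset order.
            List<String> forwarded = Arrays.stream(forwardedFields)
                .mapToObj(fields::get)
                .collect(Collectors.toList());

            System.out.println(forwarded); // [name, ts]
        }
    }
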
@@ -36,6 +36,8 @@

import org.apache.beam.sdk.fn.data.FnDataReceiver;

+ import java.util.Arrays;

/**
* The Python {@link ScalarFunction} operator for the legacy planner.
*/
@@ -59,8 +61,8 @@ public PythonScalarFunctionOperator(
RowType inputType,
RowType outputType,
int[] udfInputOffsets,
- int forwardedFieldCnt) {
- super(scalarFunctions, inputType, outputType, udfInputOffsets, forwardedFieldCnt);
+ int[] forwardedFields) {
+ super(scalarFunctions, inputType, outputType, udfInputOffsets, forwardedFields);
}

@Override
@@ -69,8 +71,8 @@ public void open() throws Exception {
this.cRowWrapper = new StreamRecordCRowWrappingCollector(output);

CRowTypeInfo forwardedInputTypeInfo = new CRowTypeInfo(new RowTypeInfo(
- inputType.getFields().stream()
-     .limit(forwardedFieldCnt)
+ Arrays.stream(forwardedFields)
+     .mapToObj(i -> inputType.getFields().get(i))
.map(RowType.RowField::getType)
.map(TypeConversions::fromLogicalToDataType)
.map(TypeConversions::fromDataTypeToLegacyInfo)
@@ -80,7 +82,7 @@ public void open() throws Exception {

@Override
public void bufferInput(CRow input) {
- CRow forwardedFieldsRow = new CRow(getForwardedRow(input.row()), input.change());
+ CRow forwardedFieldsRow = new CRow(Row.project(input.row(), forwardedFields), input.change());
if (getExecutionConfig().isObjectReuseEnabled()) {
forwardedFieldsRow = forwardedInputSerializer.copy(forwardedFieldsRow);
}
@@ -115,14 +117,6 @@ public PythonFunctionRunner<Row> createPythonFunctionRunner(FnDataReceiver<Row>
getContainingTask().getEnvironment().getTaskManagerInfo().getTmpDirectories());
}

- private Row getForwardedRow(Row input) {
-     Row row = new Row(forwardedFieldCnt);
-     for (int i = 0; i < row.getArity(); i++) {
-         row.setField(i, input.getField(i));
-     }
-     return row;
- }

/**
* The collector is used to convert a {@link Row} to a {@link CRow}.
*/
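In the legacy-planner operator the hand-written getForwardedRow() loop, which could only copy a leading prefix, is replaced by Row.project, which copies the fields at the given positions. A small, self-contained sketch of that call (field values hypothetical):

    import org.apache.flink.types.Row;

    public class RowProjectSketch {
        public static void main(String[] args) {
            Row input = Row.of("a", 1, 2.0, true);

            // Row.project copies the fields at the given offsets into a new Row,
            // mirroring what bufferInput() now does with forwardedFields.
            Row forwarded = Row.project(input, new int[]{0, 3});

            System.out.println(forwarded); // a,true
        }
    }
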
@@ -52,13 +52,13 @@ public AbstractPythonScalarFunctionOperator<BaseRow, BaseRow, BaseRow, BaseRow>
RowType inputType,
RowType outputType,
int[] udfInputOffsets,
- int forwardedFieldCnt) {
+ int[] forwardedFields) {
return new PassThroughPythonScalarFunctionOperator(
scalarFunctions,
inputType,
outputType,
udfInputOffsets,
- forwardedFieldCnt
+ forwardedFields
);
}

@@ -83,8 +83,8 @@ private static class PassThroughPythonScalarFunctionOperator extends BaseRowPythonScalarFunctionOperator
RowType inputType,
RowType outputType,
int[] udfInputOffsets,
- int forwardedFieldCnt) {
- super(scalarFunctions, inputType, outputType, udfInputOffsets, forwardedFieldCnt);
+ int[] forwardedFields) {
+ super(scalarFunctions, inputType, outputType, udfInputOffsets, forwardedFields);
}

@Override
@@ -41,9 +41,9 @@ public AbstractPythonScalarFunctionOperator<CRow, CRow, Row, Row> getTestOperator(
RowType inputType,
RowType outputType,
int[] udfInputOffsets,
- int forwardedFieldCnt) {
+ int[] forwardedFields) {
return new PassThroughPythonScalarFunctionOperator(
- scalarFunctions, inputType, outputType, udfInputOffsets, forwardedFieldCnt);
+ scalarFunctions, inputType, outputType, udfInputOffsets, forwardedFields);
}

@Override
@@ -63,8 +63,8 @@ private static class PassThroughPythonScalarFunctionOperator extends PythonScalarFunctionOperator
RowType inputType,
RowType outputType,
int[] udfInputOffsets,
- int forwardedFieldCnt) {
- super(scalarFunctions, inputType, outputType, udfInputOffsets, forwardedFieldCnt);
+ int[] forwardedFields) {
+ super(scalarFunctions, inputType, outputType, udfInputOffsets, forwardedFields);
}

@Override
@@ -204,7 +204,7 @@ private OneInputStreamOperatorTestHarness<IN, OUT> getTestHarness() throws Exception {
dataType,
dataType,
new int[]{2},
- 2
+ new int[]{0, 1}
);

return new OneInputStreamOperatorTestHarness<>(operator);
@@ -215,7 +215,7 @@ public abstract AbstractPythonScalarFunctionOperator<IN, OUT, UDFIN, UDFOUT> getTestOperator(
RowType inputType,
RowType outputType,
int[] udfInputOffsets,
- int forwardedFieldCnt);
+ int[] forwardedFields);

public abstract IN newRow(boolean accumulateMsg, Object... fields);

@@ -0,0 +1,56 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.functions.python;

import org.apache.flink.annotation.Internal;
import org.apache.flink.util.Preconditions;

/**
* A simple implementation of {@link PythonFunction}.
*/
@Internal
public final class SimplePythonFunction implements PythonFunction {

private static final long serialVersionUID = 1L;

/**
* Serialized representation of the user-defined python function.
*/
private final byte[] serializedPythonFunction;

/**
* Python execution environment.
*/
private final PythonEnv pythonEnv;

public SimplePythonFunction(byte[] serializedPythonFunction, PythonEnv pythonEnv) {
this.serializedPythonFunction = Preconditions.checkNotNull(serializedPythonFunction);
this.pythonEnv = Preconditions.checkNotNull(pythonEnv);
}

@Override
public byte[] getSerializedPythonFunction() {
return serializedPythonFunction;
}

@Override
public PythonEnv getPythonEnv() {
return pythonEnv;
}
}
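SimplePythonFunction is a plain serializable holder pairing the serialized UDF payload with its execution environment; both values are produced by the planner. A hypothetical wrapping helper, just to show the constructor contract (the helper itself is not part of the commit):

    import org.apache.flink.table.functions.python.PythonEnv;
    import org.apache.flink.table.functions.python.PythonFunction;
    import org.apache.flink.table.functions.python.SimplePythonFunction;

    public class SimplePythonFunctionSketch {

        // Hypothetical helper: both arguments are assumed to come from the
        // planner (see CommonPythonCalc below); neither may be null.
        static PythonFunction wrap(byte[] serializedPythonFunction, PythonEnv pythonEnv) {
            return new SimplePythonFunction(serializedPythonFunction, pythonEnv);
        }
    }
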
@@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.plan.nodes

import org.apache.calcite.rex.{RexCall, RexInputRef, RexLiteral, RexNode}
import org.apache.flink.table.functions.FunctionLanguage
import org.apache.flink.table.functions.python.{PythonFunction, PythonFunctionInfo, SimplePythonFunction}
import org.apache.flink.table.functions.utils.ScalarSqlFunction

import scala.collection.JavaConversions._
import scala.collection.mutable

trait CommonPythonCalc {

private[flink] def extractPythonScalarFunctionInfos(
rexCalls: Array[RexCall]): (Array[Int], Array[PythonFunctionInfo]) = {
// use a LinkedHashMap to keep the insertion order
val inputNodes = new mutable.LinkedHashMap[RexNode, Integer]()
val pythonFunctionInfos = rexCalls.map(createPythonScalarFunctionInfo(_, inputNodes))

val udfInputOffsets = inputNodes.toArray.map(_._1).map {
case inputRef: RexInputRef => inputRef.getIndex
case _: RexLiteral => throw new Exception(
"Constants cannot be used as parameters of a Python UDF for now. " +
"This will be supported in FLINK-14208")
}
(udfInputOffsets, pythonFunctionInfos)
}

private[flink] def createPythonScalarFunctionInfo(
rexCall: RexCall,
inputNodes: mutable.Map[RexNode, Integer]): PythonFunctionInfo = rexCall.getOperator match {
case sfc: ScalarSqlFunction if sfc.getScalarFunction.getLanguage == FunctionLanguage.PYTHON =>
val inputs = new mutable.ArrayBuffer[AnyRef]()
rexCall.getOperands.foreach {
case pythonRexCall: RexCall if pythonRexCall.getOperator.asInstanceOf[ScalarSqlFunction]
.getScalarFunction.getLanguage == FunctionLanguage.PYTHON =>
// Consecutive Python UDFs can be chained together
val argPythonInfo = createPythonScalarFunctionInfo(pythonRexCall, inputNodes)
inputs.append(argPythonInfo)

case argNode: RexNode =>
// Input arguments of type RexInputRef are replaced with an offset into the input row
inputNodes.get(argNode) match {
case Some(existing) => inputs.append(existing)
case None =>
val inputOffset = Integer.valueOf(inputNodes.size)
inputs.append(inputOffset)
inputNodes.put(argNode, inputOffset)
}
}

// Extracts the information necessary for Python function execution, such as
// the serialized Python function, the Python env, etc.
val pythonFunction = new SimplePythonFunction(
sfc.getScalarFunction.asInstanceOf[PythonFunction].getSerializedPythonFunction,
sfc.getScalarFunction.asInstanceOf[PythonFunction].getPythonEnv)
new PythonFunctionInfo(pythonFunction, inputs.toArray)
}
}
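
extractPythonScalarFunctionInfos relies on the LinkedHashMap to assign each distinct input expression a stable offset: the first time a RexNode is seen it gets the next free index, and later occurrences reuse it, while insertion order keeps the offsets aligned with udfInputOffsets. A minimal Java sketch of that dedup-with-insertion-order pattern (strings standing in for RexNode):

    import java.util.LinkedHashMap;
    import java.util.Map;

    public class InputDedupSketch {
        public static void main(String[] args) {
            Map<String, Integer> inputNodes = new LinkedHashMap<>();
            for (String node : new String[]{"$0", "$2", "$0", "$1", "$2"}) {
                // Reading size() in the mapping function yields the next free
                // offset for a node seen for the first time; repeats reuse it.
                inputNodes.computeIfAbsent(node, k -> inputNodes.size());
            }
            System.out.println(inputNodes); // {$0=0, $2=1, $1=2}
        }
    }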
