From fe6c86d7e10f528bfd55675bdb496f160744b070 Mon Sep 17 00:00:00 2001
From: nikolay_vasilishin
Date: Fri, 23 Dec 2016 13:50:46 +0300
Subject: [PATCH 1/4] [FLINK-3850] Add forward field annotations to DataSet
 operators generated by the Table API

- Added field forwarding to most DataSetRel implementations.
- Allowed the forwarded-fields string to be empty in SemanticPropUtil.java.
- Moved the type-based index wrapper into the FieldForwardingUtils object.
- In most cases forwarding is declared only for the conversion step.

BatchScan: forwarding at conversion
DataSetAggregate: forwarding at conversion
DataSetCalc: forwarding based on operands left unmodified by RexCalls
DataSetCorrelate: forwarding based on operands left unmodified by RexCalls
DataSetIntersect: forwarding at conversion
DataSetJoin: forwarding of fields that are not join keys
DataSetMinus: forwarding at conversion
DataSetSingleRowJoin: all fields of the multi-row dataset are forwarded;
 the single row is distributed via broadcast
DataSetSort: all fields forwarded, plus conversion

Conflicts:
	flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala
	flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
	flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala
	flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
	flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala
	flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetJoin.scala
	flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala
	flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala
	flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSort.scala
---
 .../api/java/functions/SemanticPropUtil.java  |   2 +-
 .../table/plan/nodes/CommonCorrelate.scala    |   3 +-
 .../table/plan/nodes/dataset/BatchScan.scala  |  11 +-
 .../plan/nodes/dataset/DataSetAggregate.scala |  11 ++
 .../plan/nodes/dataset/DataSetCalc.scala      |  36 +++++-
 .../plan/nodes/dataset/DataSetCorrelate.scala |  46 ++++++-
 .../plan/nodes/dataset/DataSetIntersect.scala |   1 +
 .../plan/nodes/dataset/DataSetJoin.scala      |  12 ++
 .../plan/nodes/dataset/DataSetMinus.scala     |   1 +
 .../nodes/dataset/DataSetSingleRowJoin.scala  |  27 +++-
 .../plan/nodes/dataset/DataSetSort.scala      |   4 +
 .../forwarding/FieldForwardingUtils.scala     | 115 ++++++++++++++++++
 .../datastream/DataStreamCorrelate.scala      |   3 +
 13 files changed, 260 insertions(+), 12 deletions(-)
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala
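For readers unfamiliar with semantic annotations: a forwarded-fields declaration tells the Flink optimizer which input fields a user function copies to its output unchanged, so properties such as sortings and partitionings can be preserved across the operator. A minimal, self-contained sketch of the DataSet API the patches build on (example program only, not part of the patch; `_1`/`_2` are the Scala tuple field names):

```scala
import org.apache.flink.api.scala._

object ForwardedFieldsPrimer {
  def main(args: Array[String]): Unit = {
    val env = ExecutionEnvironment.getExecutionEnvironment

    // The mapper swaps the two tuple fields: input _1 ends up in output
    // position 2 and input _2 in output position 1. The annotation states
    // exactly that, so a partitioning on _1 survives as one on _2.
    env.fromElements((1, "a"), (2, "b"))
      .map(t => (t._2, t._1))
      .withForwardedFields("_1->_2", "_2->_1")
      .print()
  }
}
```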
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/functions/SemanticPropUtil.java b/flink-java/src/main/java/org/apache/flink/api/java/functions/SemanticPropUtil.java
index aedba150a81bc..bb11a4de56655 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/functions/SemanticPropUtil.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/functions/SemanticPropUtil.java
@@ -420,7 +420,7 @@ private static void parseForwardedFields(SemanticProperties sp, String[] forward
 		}
 
 		for (String s : forwardedStr) {
-			if (s == null) {
+			if (s == null || s.trim().equals("")) {
 				continue;
 			}
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala
index 6c4066b7af9d4..d3745c5684f94 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala
@@ -46,6 +46,7 @@ trait CommonCorrelate {
       config: TableConfig,
       inputTypeInfo: TypeInformation[Row],
       udtfTypeInfo: TypeInformation[Any],
+      returnType: TypeInformation[Row],
       rowType: RelDataType,
       joinType: SemiJoinType,
       rexCall: RexCall,
@@ -54,8 +55,6 @@ trait CommonCorrelate {
       ruleDescription: String)
     : CorrelateFlatMapRunner[Row, Row] = {
 
-    val returnType = FlinkTypeFactory.toInternalRowTypeInfo(rowType)
-
     val flatMap = generateFunction(
       config,
       inputTypeInfo,
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala
index b39b8ed8896be..2f0ecc97f1f17 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala
@@ -22,6 +22,7 @@ import org.apache.flink.api.java.DataSet
 import org.apache.flink.table.api.TableConfig
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.plan.nodes.CommonScan
+import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedFields
 import org.apache.flink.table.plan.schema.FlinkTable
 import org.apache.flink.types.Row
 
@@ -53,7 +54,15 @@ trait BatchScan extends CommonScan with DataSetRel {
 
       val opName = s"from: (${getRowType.getFieldNames.asScala.toList.mkString(", ")})"
 
-      input.map(mapFunc).name(opName)
+      //Forward all fields at conversion
+      val indices = flinkTable.fieldIndexes.zipWithIndex
+      val fields: String = getForwardedFields(inputType, internalType, indices)
+
+      input
+        .map(mapFunc)
+        .withForwardedFields(fields)
+        .name(opName)
+        .asInstanceOf[DataSet[Row]]
     }
     // no conversion necessary, forward
     else {
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
index 5a4aa5916bdca..11656d9f8ea12 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
@@ -28,6 +28,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.api.java.typeutils.RowTypeInfo
 import org.apache.flink.table.api.BatchTableEnvironment
+import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.plan.nodes.CommonAggregate
 import org.apache.flink.table.runtime.aggregate.{AggregateUtil, DataSetPreAggFunction}
@@ -117,6 +118,8 @@ class DataSetAggregate(
         .groupBy(grouping: _*)
         .combineGroup(preAgg.get)
         .returns(preAggType.get)
+        // forward fields at conversion
+        .withForwardedFields(forwardFields(rowTypeInfo))
         .name(aggOpName)
         // final aggregation
         .groupBy(grouping.indices: _*)
@@ -153,4 +156,12 @@ class DataSetAggregate(
       }
     }
   }
+
+  private def forwardFields(rowTypeInfo: RowTypeInfo) = {
+    val indices = 0 to rowTypeInfo.getTotalFields
+    getForwardedInput(
+      FlinkTypeFactory.toInternalRowTypeInfo(inputType),
+      rowTypeInfo,
+      indices)
+  }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala
index e05b5a83e69cb..9818faf4ed557 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala
@@ -23,15 +23,17 @@ import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.core.Calc
 import org.apache.calcite.rel.metadata.RelMetadataQuery
 import org.apache.calcite.rel.{RelNode, RelWriter}
-import org.apache.calcite.rex._
 import org.apache.flink.api.common.functions.FlatMapFunction
 import org.apache.flink.api.java.DataSet
+import org.apache.calcite.rex._
+import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedFields
 import org.apache.flink.table.api.BatchTableEnvironment
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.codegen.CodeGenerator
 import org.apache.flink.table.plan.nodes.CommonCalc
 import org.apache.flink.types.Row
 
+import scala.collection.JavaConversions._
 /**
   * Flink RelNode which matches along with LogicalCalc.
   *
@@ -100,7 +102,35 @@ class DataSetCalc(
       body,
       returnType)
 
+    def getForwardIndices = {
+      //get indices of all modified operands
+      val modifiedOperands = calcProgram.
+        getExprList
+        .filter(_.isInstanceOf[RexCall])
+        .flatMap(_.asInstanceOf[RexCall].operands)
+        .map(_.asInstanceOf[RexLocalRef].getIndex)
+        .toSet
+
+      // get input/output indices of operands, filter modified operands and specify forwarding
+      val tuples = calcProgram.getProjectList
+        .map(ref => (ref.getName, ref.getIndex))
+        .zipWithIndex
+        .map { case ((name, inputIndex), projectIndex) => (name, inputIndex, projectIndex) }
+        //consider only input fields
+        .filter(_._2 < calcProgram.getExprList.filter(_.isInstanceOf[RexInputRef]).map(_.asInstanceOf[RexInputRef]).size)
+        .filterNot(ref => modifiedOperands.contains(ref._2))
+        .map {case (name, in, out) => (in, out)}
+      tuples
+    }
+
     val mapFunc = calcMapFunction(genFunction)
-    inputDS.flatMap(mapFunc).name(calcOpName(calcProgram, getExpressionString))
-  }
+    val tuples = getForwardIndices
+
+    val fields: String = getForwardedFields(inputDS.getType,
+      returnType,
+      tuples,
+      calcProgram = Some(calcProgram))
+
+    inputDS.flatMap(mapFunc).withForwardedFields(fields).name(calcOpName(calcProgram, getExpressionString))
+  }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
index c18a829587d5f..c1a30189356a7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
@@ -22,14 +22,19 @@ import org.apache.calcite.rel.`type`.RelDataType
 import org.apache.calcite.rel.logical.LogicalTableFunctionScan
 import org.apache.calcite.rel.metadata.RelMetadataQuery
 import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
-import org.apache.calcite.rex.{RexCall, RexNode}
+import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode}
 import org.apache.calcite.sql.SemiJoinType
 import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.common.typeutils.CompositeType
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.table.api.BatchTableEnvironment
+import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.functions.utils.TableSqlFunction
 import org.apache.flink.table.plan.nodes.CommonCorrelate
 import org.apache.flink.types.Row
+import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput
+
+import scala.collection.JavaConversions._
 
 /**
   * Flink RelNode which matches along with join a user defined table function.
@@ -97,11 +102,13 @@ class DataSetCorrelate(
     val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
     val pojoFieldMapping = sqlFunction.getPojoFieldMapping
     val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
+    val returnType = FlinkTypeFactory.toInternalRowTypeInfo(rowType)
 
     val mapFunc = correlateMapFunction(
       config,
       inputDS.getType,
       udtfTypeInfo,
+      returnType,
       getRowType,
       joinType,
       rexCall,
@@ -109,6 +116,41 @@ class DataSetCorrelate(
       Some(pojoFieldMapping),
       ruleDescription)
 
-    inputDS.flatMap(mapFunc).name(correlateOpName(rexCall, sqlFunction, relRowType))
+    def getIndices = {
+      //recursively get all operands from RexCalls
+      def extractOperands(rex: RexNode): Seq[Int] = {
+        rex match {
+          case r: RexInputRef => Seq(r.getIndex)
+          case call: RexCall => call.operands.flatMap(extractOperands)
+          case _ => Seq()
+        }
+      }
+      //get indices of all modified operands
+      val modifiedOperandsInRel = funcRel.getCall.asInstanceOf[RexCall].operands
+        .flatMap(extractOperands)
+        .toSet
+      val joinCondition = if (condition.isDefined) {
+        condition.get.asInstanceOf[RexCall].operands
+          .flatMap(extractOperands)
+          .toSet
+      } else {
+        Set()
+      }
+      val modifiedOperands = modifiedOperandsInRel ++ joinCondition
+
+      // get input/output indices of operands, filter modified operands and specify forwarding
+      val tuples = inputDS.getType.asInstanceOf[CompositeType[_]].getFieldNames
+        .zipWithIndex
+        .map(_._2)
+        .filterNot(modifiedOperands.contains)
+        .toSeq
+
+      tuples
+    }
+
+    val fields: String = getForwardedInput(inputDS.getType, returnType, getIndices)
+    inputDS.flatMap(mapFunc)
+      .withForwardedFields(fields)
+      .name(correlateOpName(rexCall, sqlFunction, relRowType))
   }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala
index 4497df33d8b4f..5f9949972f83d 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala
@@ -24,6 +24,7 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery
 import org.apache.calcite.rel.{BiRel, RelNode, RelWriter}
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.table.api.BatchTableEnvironment
+import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput
 import org.apache.flink.table.runtime.IntersectCoGroupFunction
 import org.apache.flink.types.Row
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetJoin.scala
index e6f8ca4bb82e2..de7829be91f7a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetJoin.scala
@@ -31,6 +31,7 @@ import org.apache.flink.api.java.DataSet
 import org.apache.flink.table.api.{BatchTableEnvironment, TableException}
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.codegen.CodeGenerator
+import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.{getForwardedFields, getForwardedInput}
 import org.apache.flink.table.runtime.FlatJoinRunner
 import org.apache.flink.types.Row
 
@@ -199,10 +200,21 @@ class DataSetJoin(
 
     val joinOpName = s"where: ($joinConditionToString), join: ($joinSelectionToString)"
 
+    // consider that all fields which are not keys are forwarded
+    val leftIndices = (0 until left.getRowType.getFieldCount).diff(leftKeys)
+    val fieldsLeft = getForwardedInput(leftDataSet.getType, returnType, leftIndices)
+
+    val rightIndices = (0 until right.getRowType.getFieldCount)
+      .diff(rightKeys)
+      .map(in => (in, in + left.getRowType.getFieldCount))
+    val fieldsRight = getForwardedFields(rightDataSet.getType, returnType, rightIndices)
+
     joinOperator
       .where(leftKeys.toArray: _*)
       .equalTo(rightKeys.toArray: _*)
       .`with`(joinFun)
+      .withForwardedFieldsFirst(fieldsLeft)
+      .withForwardedFieldsSecond(fieldsRight)
       .name(joinOpName)
   }
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala
index 9ba65bfdddb8c..dba23fab4bb3c 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala
@@ -24,6 +24,7 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery
 import org.apache.calcite.rel.{BiRel, RelNode, RelWriter}
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.table.api.BatchTableEnvironment
+import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput
 import org.apache.flink.table.runtime.MinusCoGroupFunction
 import org.apache.flink.types.Row
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala
index b7d1a4bfb60c6..417d84d814b94 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala
@@ -26,10 +26,12 @@ import org.apache.calcite.rex.RexNode
 import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction}
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.DataSet
-import org.apache.flink.table.api.{BatchTableEnvironment, TableConfig}
 import org.apache.flink.table.calcite.FlinkTypeFactory
+import org.apache.flink.api.java.typeutils.GenericTypeInfo
 import org.apache.flink.table.codegen.CodeGenerator
 import org.apache.flink.table.runtime.{MapJoinLeftRunner, MapJoinRightRunner}
+import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.{compositeTypeField, getForwardedFields}
+import org.apache.flink.table.api.{BatchTableEnvironment, TableConfig}
 import org.apache.flink.types.Row
 
 import scala.collection.JavaConversions._
@@ -92,11 +94,13 @@ class DataSetSingleRowJoin(
     val leftDataSet = left.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
     val rightDataSet = right.asInstanceOf[DataSetRel].translateToPlan(tableEnv)
+    val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
     val broadcastSetName = "joinSet"
     val mapSideJoin = generateMapFunction(
       tableEnv.getConfig,
       leftDataSet.getType,
       rightDataSet.getType,
+      returnType,
       leftIsSingle,
       joinCondition,
       broadcastSetName)
@@ -108,9 +112,27 @@ class DataSetSingleRowJoin(
       (leftDataSet, rightDataSet)
     }
 
+    def customWrapper(typeInformation: TypeInformation[_]) = {
+      typeInformation match {
+        case r: GenericTypeInfo[_] => if (r.getTypeClass == classOf[Row]) {
+          val leftCount: Int = leftDataSet.getType.getTotalFields
+          val rightCount: Int = rightDataSet.getType.getTotalFields
+          compositeTypeField((0 until (leftCount + rightCount)).map("f" + _))
+        } else {
+          ???
+        }
+      }
+    }
+    val offset: Int = if (leftIsSingle) 1 else 0
+    val indices = (0 until multiRowDataSet.getType.getTotalFields)
+      .map { inputIndex => (inputIndex, inputIndex + offset) }
+
+    val fields = getForwardedFields(multiRowDataSet.getType, returnType, indices, customWrapper)
+
     multiRowDataSet
       .flatMap(mapSideJoin)
       .withBroadcastSet(singleRowDataSet, broadcastSetName)
+      .withForwardedFields(fields)
       .name(getMapOperatorName)
   }
 
@@ -118,6 +140,7 @@ class DataSetSingleRowJoin(
       config: TableConfig,
       inputType1: TypeInformation[Row],
       inputType2: TypeInformation[Row],
+      returnType: TypeInformation[Row],
       firstIsSingle: Boolean,
       joinCondition: RexNode,
       broadcastInputSetName: String)
@@ -129,8 +152,6 @@ class DataSetSingleRowJoin(
       inputType1,
       Some(inputType2))
 
-    val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType)
-
     val conversion = codeGenerator.generateConverterResultExpression(
       returnType,
       joinRowType.getFieldNames)
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSort.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSort.scala
index 192237ac14ee7..d23da667ea9de 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSort.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSort.scala
@@ -29,6 +29,7 @@ import org.apache.calcite.rex.{RexLiteral, RexNode}
 import org.apache.flink.api.common.operators.Order
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.table.api.{BatchTableEnvironment, TableException}
+import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput
 import org.apache.flink.table.runtime.{CountPartitionFunction, LimitFilterFunction}
 import org.apache.flink.types.Row
 
@@ -128,9 +129,12 @@ class DataSetSort(
         broadcastName)
 
       val limitName = s"offset: $offsetToString, fetch: $fetchToString"
+      // TODO Do we need this here?
+      val allFields: String = "*"
 
       partitionedDs
         .filter(limitFunction)
+        .withForwardedFields(allFields)
         .name(limitName)
         .withBroadcastSet(partitionCount, broadcastName)
     }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala
new file mode 100644
index 0000000000000..3533daa54737f
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes.dataset.forwarding
+
+import org.apache.calcite.rex.RexProgram
+import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.table.api.TableException
+
+import scala.collection.JavaConversions._
+
+object FieldForwardingUtils {
+
+  def compositeTypeField = (fields: Seq[String]) => (v: Int) => fields(v)
+
+  private def throwMissedWrapperException(wrapperCustomCase: TypeInformation[_]) = {
+    throw new TableException(s"Implementation for $wrapperCustomCase index wrapper is missing.")
+  }
+
+  /**
+    * Wrapper for {@link getForwardedFields}
+    * @param inputType
+    * @param outputType
+    * @param forwardIndices
+    * @param wrapperCustomCase
+    * @param calcProgram
+    * @return
+    */
+  def getForwardedInput(
+      inputType: TypeInformation[_],
+      outputType: TypeInformation[_],
+      forwardIndices: Seq[Int],
+      wrapperCustomCase: TypeInformation[_] => (Int) => String = throwMissedWrapperException,
+      calcProgram: Option[RexProgram] = None) = {
+
+    getForwardedFields(inputType,
+      outputType,
+      forwardIndices.zip(forwardIndices),
+      wrapperCustomCase,
+      calcProgram)
+  }
+
+  /**
+    * Wraps provided indices with proper names, e.g. _1 for tuple, f0 for row, fieldName for POJO.
+    * @param inputType
+    * @param outputType
+    * @param forwardIndices - tuple of input-output indices of a forwarded field
+    * @param wrapperCustomCase - used for figuring out proper type in specific cases,
+    *                          e.g. {@see DataSetSingleRowJoin}
+    * @param calcProgram - used for figuring out proper type in specific cases,
+    *                    e.g. {@see DataSetCalc}
+    * @return - string with forwarded fields mapped from input to output
+    */
+  def getForwardedFields(
+      inputType: TypeInformation[_],
+      outputType: TypeInformation[_],
+      forwardIndices: Seq[(Int, Int)],
+      wrapperCustomCase: TypeInformation[_] => (Int) => String = throwMissedWrapperException,
+      calcProgram: Option[RexProgram] = None): String = {
+
+    def chooseWrapper(typeInformation: TypeInformation[_]): (Int) => String = {
+      typeInformation match {
+        case composite: CompositeType[_] =>
+          // POJOs' fields are sorted, so we can not access them by their positional index.
+          // So we collect field names from
+          // outputRowType. For all other types we get field names from inputDS.
+          val typeFieldList = composite.getFieldNames
+          var fieldWrapper: (Int) => String = compositeTypeField(typeFieldList)
+          if (calcProgram.isDefined) {
+            val projectFieldList = calcProgram.get.getOutputRowType.getFieldNames
+            if (typeFieldList.toSet == projectFieldList.toSet) {
+              fieldWrapper = compositeTypeField(projectFieldList)
+            }
+          }
+          fieldWrapper
+        case basic: BasicTypeInfo[_] => (v: Int) => s"*"
+        case _ => wrapperCustomCase(typeInformation)
+      }
+    }
+
+    val wrapInput = chooseWrapper(inputType)
+    val wrapOutput = chooseWrapper(outputType)
+
+    implicit def string2ForwardFields(left: String) = new AnyRef {
+      def ->(right: String): String = left + "->" + right
+      def simplify(): String = if (left.split("->").head == left.split("->").last) {
+        left.split("->").head
+      } else {
+        left
+      }
+    }
+
+    def wrapIndices(inputIndex: Int, outputIndex: Int): String = {
+      wrapInput(inputIndex) -> wrapOutput(outputIndex) simplify()
+    }
+
+    forwardIndices map { case (in, out) => wrapIndices(in, out) } mkString ";"
+  }
+}
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
index dd799e6946f72..f2d16c3825237 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala
@@ -26,6 +26,7 @@ import org.apache.calcite.sql.SemiJoinType
 import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.streaming.api.datastream.DataStream
 import org.apache.flink.table.api.StreamTableEnvironment
+import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.functions.utils.TableSqlFunction
 import org.apache.flink.table.plan.nodes.CommonCorrelate
 import org.apache.flink.types.Row
@@ -91,11 +92,13 @@ class DataStreamCorrelate(
     val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction]
     val pojoFieldMapping = sqlFunction.getPojoFieldMapping
     val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]]
+    val returnType = FlinkTypeFactory.toInternalRowTypeInfo(rowType)
 
     val mapFunc = correlateMapFunction(
       config,
       inputDS.getType,
       udtfTypeInfo,
+      returnType,
       getRowType,
       joinType,
       rexCall,
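To make the contract of the new utility concrete, the snippet below calls getForwardedFields exactly as defined in the file above; the expected output string is the one asserted by the unit test added in the next patch (standalone usage sketch, not part of the patch):

```scala
import org.apache.flink.api.common.typeinfo.BasicTypeInfo.{INT_TYPE_INFO, STRING_TYPE_INFO}
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedFields

object ForwardedFieldsStringDemo {
  def main(args: Array[String]): Unit = {
    // Rows use positional names f0, f1, ...; an identity pair such as
    // (0, 0) collapses "f0->f0" to just "f0" in the generated string.
    val in  = new RowTypeInfo(INT_TYPE_INFO, STRING_TYPE_INFO)
    val out = new RowTypeInfo(STRING_TYPE_INFO, INT_TYPE_INFO)

    // prints "f0->f1;f1->f0": the int moves to position 1, the string to 0
    println(getForwardedFields(in, out, Seq((0, 1), (1, 0))))
  }
}
```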
From f43ff172e5a4cc641cfffa3c325282ded2cca46e Mon Sep 17 00:00:00 2001
From: tonycox
Date: Mon, 23 Jan 2017 15:50:49 +0400
Subject: [PATCH 2/4] [FLINK-3850] Add test and code refactoring

---
 .../org/apache/flink/table/api/table.scala    |  2 +-
 .../plan/nodes/dataset/DataSetCalc.scala      | 36 ++++-----
 .../plan/nodes/dataset/DataSetCorrelate.scala | 41 ++++------
 .../forwarding/FieldForwardingUtils.scala     | 78 ++++++-----------
 .../{ => util}/ProjectionTranslator.scala     |  2 +-
 .../table/plan/util/RexFieldExtractor.scala   | 36 +++++++++
 .../api/scala/batch/table/CalcITCase.scala    | 17 +++-
 .../forwarding/FieldForwardingUtilsTest.scala | 71 +++++++++++++++++
 8 files changed, 186 insertions(+), 97 deletions(-)
 rename flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/{ => util}/ProjectionTranslator.scala (99%)
 create mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexFieldExtractor.scala
 create mode 100644 flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala

diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala
index 769e940cd72c8..df8a9167b8ed4 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/table.scala
@@ -23,7 +23,7 @@ import org.apache.flink.api.java.operators.join.JoinType
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.plan.logical.Minus
 import org.apache.flink.table.expressions.{Alias, Asc, Call, Expression, ExpressionParser, Ordering, TableFunctionCall, UnresolvedAlias}
-import org.apache.flink.table.plan.ProjectionTranslator._
+import org.apache.flink.table.plan.util.ProjectionTranslator._
 import org.apache.flink.table.plan.logical._
 import org.apache.flink.table.sinks.TableSink
 
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala
index 9818faf4ed557..ac5680a3a311f 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala
@@ -32,6 +32,7 @@ import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.codegen.CodeGenerator
 import org.apache.flink.table.plan.nodes.CommonCalc
 import org.apache.flink.types.Row
+import org.apache.flink.table.plan.util.RexProgramProjectExtractor._
 
 import scala.collection.JavaConversions._
 /**
@@ -103,34 +104,25 @@ class DataSetCalc(
       returnType)
 
     def getForwardIndices = {
-      //get indices of all modified operands
-      val modifiedOperands = calcProgram.
-        getExprList
-        .filter(_.isInstanceOf[RexCall])
-        .flatMap(_.asInstanceOf[RexCall].operands)
-        .map(_.asInstanceOf[RexLocalRef].getIndex)
-        .toSet
-
-      // get input/output indices of operands, filter modified operands and specify forwarding
-      val tuples = calcProgram.getProjectList
-        .map(ref => (ref.getName, ref.getIndex))
+      // get (input, output) indices of operands,
+      // filter modified operands and specify forwarding
+      val inputFields = extractRefInputFields(calcProgram)
+      calcProgram.getProjectList
+        .map(_.getIndex)
         .zipWithIndex
-        .map { case ((name, inputIndex), projectIndex) => (name, inputIndex, projectIndex) }
-        //consider only input fields
-        .filter(_._2 < calcProgram.getExprList.filter(_.isInstanceOf[RexInputRef]).map(_.asInstanceOf[RexInputRef]).size)
-        .filterNot(ref => modifiedOperands.contains(ref._2))
-        .map {case (name, in, out) => (in, out)}
-      tuples
+        .filter(tup => inputFields.contains(tup._1))
     }
 
     val mapFunc = calcMapFunction(genFunction)
-    val tuples = getForwardIndices
 
-    val fields: String = getForwardedFields(inputDS.getType,
+    val fields = getForwardedFields(
+      inputDS.getType,
       returnType,
-      tuples,
-      calcProgram = Some(calcProgram))
+      getForwardIndices)
 
-    inputDS.flatMap(mapFunc).withForwardedFields(fields).name(calcOpName(calcProgram, getExpressionString))
+    inputDS
+      .flatMap(mapFunc)
+      .withForwardedFields(fields)
+      .name(calcOpName(calcProgram, getExpressionString))
   }
 }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
index c1a30189356a7..f1d115e841278 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala
@@ -33,6 +33,7 @@ import org.apache.flink.table.functions.utils.TableSqlFunction
 import org.apache.flink.table.plan.nodes.CommonCorrelate
 import org.apache.flink.types.Row
 import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput
+import org.apache.flink.table.plan.util.RexFieldExtractor._
 
 import scala.collection.JavaConversions._
 
@@ -117,39 +118,25 @@ class DataSetCorrelate(
       ruleDescription)
 
     def getIndices = {
-      //recursively get all operands from RexCalls
-      def extractOperands(rex: RexNode): Seq[Int] = {
-        rex match {
-          case r: RexInputRef => Seq(r.getIndex)
-          case call: RexCall => call.operands.flatMap(extractOperands)
-          case _ => Seq()
-        }
-      }
-      //get indices of all modified operands
-      val modifiedOperandsInRel = funcRel.getCall.asInstanceOf[RexCall].operands
-        .flatMap(extractOperands)
-        .toSet
+      //get indices of all input operands
+      val inputOperandsInRel = extractRefInputFields(rexCall)
       val joinCondition = if (condition.isDefined) {
-        condition.get.asInstanceOf[RexCall].operands
-          .flatMap(extractOperands)
-          .toSet
+        extractRefInputFields(condition.get)
       } else {
-        Set()
+        Array()
       }
-      val modifiedOperands = modifiedOperandsInRel ++ joinCondition
-
-      // get input/output indices of operands, filter modified operands and specify forwarding
-      val tuples = inputDS.getType.asInstanceOf[CompositeType[_]].getFieldNames
-        .zipWithIndex
-        .map(_._2)
-        .filterNot(modifiedOperands.contains)
-        .toSeq
+      val inputOperands = inputOperandsInRel ++ joinCondition
 
-      tuples
+      inputDS.getType.asInstanceOf[CompositeType[_]]
+        .getFieldNames
+        .indices
+        .filter(inputOperands.contains)
     }
 
-    val fields: String = getForwardedInput(inputDS.getType, returnType, getIndices)
-    inputDS.flatMap(mapFunc)
+    val fields = getForwardedInput(inputDS.getType, mapFunc.getProducedType, getIndices)
+
+    inputDS
+      .flatMap(mapFunc)
       .withForwardedFields(fields)
       .name(correlateOpName(rexCall, sqlFunction, relRowType))
   }
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala
index 3533daa54737f..37d18193eb82a 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala
@@ -18,16 +18,13 @@
 
 package org.apache.flink.table.plan.nodes.dataset.forwarding
 
-import org.apache.calcite.rex.RexProgram
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
 import org.apache.flink.api.common.typeutils.CompositeType
 import org.apache.flink.table.api.TableException
 
-import scala.collection.JavaConversions._
-
 object FieldForwardingUtils {
 
-  def compositeTypeField = (fields: Seq[String]) => (v: Int) => fields(v)
+  def compositeTypeField = (fields: Seq[String]) => fields
 
   private def throwMissedWrapperException(wrapperCustomCase: TypeInformation[_]) = {
     throw new TableException(s"Implementation for $wrapperCustomCase index wrapper is missing.")
@@ -35,79 +32,70 @@ object FieldForwardingUtils {
 
   /**
     * Wrapper for {@link getForwardedFields}
-    * @param inputType
-    * @param outputType
-    * @param forwardIndices
-    * @param wrapperCustomCase
-    * @param calcProgram
-    * @return
     */
   def getForwardedInput(
       inputType: TypeInformation[_],
       outputType: TypeInformation[_],
       forwardIndices: Seq[Int],
-      wrapperCustomCase: TypeInformation[_] => (Int) => String = throwMissedWrapperException,
-      calcProgram: Option[RexProgram] = None) = {
+      wrapperCustomCase: TypeInformation[_] =>
+        (Int) => String = throwMissedWrapperException): String = {
 
     getForwardedFields(inputType,
       outputType,
       forwardIndices.zip(forwardIndices),
-      wrapperCustomCase,
-      calcProgram)
+      wrapperCustomCase)
   }
 
   /**
-    * Wraps provided indices with proper names, e.g. _1 for tuple, f0 for row, fieldName for POJO.
-    * @param inputType
-    * @param outputType
-    * @param forwardIndices - tuple of input-output indices of a forwarded field
-    * @param wrapperCustomCase - used for figuring out proper type in specific cases,
-    *                          e.g. {@see DataSetSingleRowJoin}
-    * @param calcProgram - used for figuring out proper type in specific cases,
-    *                    e.g. {@see DataSetCalc}
-    * @return - string with forwarded fields mapped from input to output
+    * Wraps provided indices with proper names.
+    * e.g. _1 for Tuple, f0 for Row, fieldName for POJO and named Row
+    *
+    * @param inputType information of input data
+    * @param outputType information of output data
+    * @param forwardIndices tuple of (input, output) indices of a forwarded field
+    * @param customWrapper used for figuring out proper type in specific cases,
+    *                      e.g. {@see DataSetSingleRowJoin}
+    * @return string with forwarded fields mapped from input to output
     */
   def getForwardedFields(
       inputType: TypeInformation[_],
       outputType: TypeInformation[_],
       forwardIndices: Seq[(Int, Int)],
-      wrapperCustomCase: TypeInformation[_] => (Int) => String = throwMissedWrapperException,
-      calcProgram: Option[RexProgram] = None): String = {
+      customWrapper: TypeInformation[_] =>
+        (Int) => String = throwMissedWrapperException): String = {
 
     def chooseWrapper(typeInformation: TypeInformation[_]): (Int) => String = {
       typeInformation match {
         case composite: CompositeType[_] =>
-          // POJOs' fields are sorted, so we can not access them by their positional index.
-          // So we collect field names from
-          // outputRowType. For all other types we get field names from inputDS.
-          val typeFieldList = composite.getFieldNames
-          var fieldWrapper: (Int) => String = compositeTypeField(typeFieldList)
-          if (calcProgram.isDefined) {
-            val projectFieldList = calcProgram.get.getOutputRowType.getFieldNames
-            if (typeFieldList.toSet == projectFieldList.toSet) {
-              fieldWrapper = compositeTypeField(projectFieldList)
-            }
-          }
-          fieldWrapper
-        case basic: BasicTypeInfo[_] => (v: Int) => s"*"
-        case _ => wrapperCustomCase(typeInformation)
+          compositeTypeField(composite.getFieldNames)
+        case basic: BasicTypeInfo[_] =>
+          (_: Int) => s"*"
+        case _ =>
+          customWrapper(typeInformation)
       }
     }
 
     val wrapInput = chooseWrapper(inputType)
     val wrapOutput = chooseWrapper(outputType)
 
-    implicit def string2ForwardFields(left: String) = new AnyRef {
-      def ->(right: String): String = left + "->" + right
-      def simplify(): String = if (left.split("->").head == left.split("->").last) {
-        left.split("->").head
-      } else {
-        left
-      }
-    }
+    forwardFields(forwardIndices, wrapInput, wrapOutput)
+  }
+
+  private def forwardFields(
+      forwardIndices: Seq[(Int, Int)],
+      wrapInput: (Int) => String,
+      wrapOutput: (Int) => String): String = {
+
+    implicit class String2ForwardFields(left: String) {
+      def ->(right: String): String = if (left == right) {
+        left
+      } else {
+        s"$left->$right"
+      }
+    }
 
     def wrapIndices(inputIndex: Int, outputIndex: Int): String = {
-      wrapInput(inputIndex) -> wrapOutput(outputIndex) simplify()
+      wrapInput(inputIndex) -> wrapOutput(outputIndex)
     }
 
     forwardIndices map { case (in, out) => wrapIndices(in, out) } mkString ";"
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/ProjectionTranslator.scala
similarity index 99%
rename from flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
rename to flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/ProjectionTranslator.scala
index 94a0aa1cf8142..00cd692ebb01e 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/ProjectionTranslator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/ProjectionTranslator.scala
@@ -16,7 +16,7 @@
  * limitations under the License.
  */
 
-package org.apache.flink.table.plan
+package org.apache.flink.table.plan.util
 
 import org.apache.flink.api.common.typeutils.CompositeType
 import org.apache.flink.table.api.TableEnvironment
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexFieldExtractor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexFieldExtractor.scala
new file mode 100644
index 0000000000000..351ff20fc1ee3
--- /dev/null
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexFieldExtractor.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.util
+
+import org.apache.calcite.rex.RexNode
+
+object RexFieldExtractor {
+
+  /**
+    * Extracts the indexes of input fields accessed by the RexNode.
+    *
+    * @param rex RexNode to analyze
+    * @return The indexes of accessed input fields
+    */
+  def extractRefInputFields(rex: RexNode): Array[Int] = {
+    val visitor = new RefFieldsVisitor
+    rex.accept(visitor)
+    visitor.getFields
+  }
+}
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala
index b78dd91fa6b0a..f6d41ec614840 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/table/CalcITCase.scala
@@ -260,7 +260,7 @@ class CalcITCase(
   }
 
   @Test
-  def testSimpleCalc(): Unit = {
+  def testSimpleCalcWithRow(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val tEnv = TableEnvironment.getTableEnvironment(env, config)
 
@@ -275,6 +275,21 @@ class CalcITCase(
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
 
+  @Test
+  def testSimpleCalcWithPojo(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    val tEnv = TableEnvironment.getTableEnvironment(env, config)
+
+    val t = CollectionDataSets.getSmallPojoDataSet(env).toTable(tEnv)
+      .select('number, 'str, 'nestedPojo)
+      .where('number < 7)
+      .select('nestedPojo.get("longNumber"), 'str)
+
+    val expected = "10000,First\n20000,Second\n30000,Third\n"
+    val results = t.toDataSet[Row].collect()
+    TestBaseUtils.compareResultAsText(results.asJava, expected)
+  }
+
   @Test
   def testCalcWithTwoFilters(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala
new file mode 100644
index 0000000000000..75ba4567723f5
--- /dev/null
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.plan.nodes.dataset.forwarding
+
+import org.apache.flink.api.common.typeinfo.{BasicArrayTypeInfo, BasicTypeInfo, TypeInformation}
+import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfo, TypeExtractor}
+import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils._
+import org.junit.Assert._
+import org.junit.Test
+
+class FieldForwardingUtilsTest {
+
+  @Test
+  def testForwarding() = {
+    val intType = BasicTypeInfo.INT_TYPE_INFO
+    val strType = BasicTypeInfo.STRING_TYPE_INFO
+    val doubleArrType = BasicArrayTypeInfo.DOUBLE_ARRAY_TYPE_INFO
+
+    val tuple = new TupleTypeInfo(intType, strType)
+    val pojo = TypeExtractor.getForClass(classOf[TestPojo])
+
+    val row = new RowTypeInfo(strType, intType, doubleArrType)
+    val namedRow = new RowTypeInfo(
+      Array[TypeInformation[_]](intType, doubleArrType, strType),
+      Array("ints", "doubleArr", "strings")
+    )
+
+    assertEquals("f0->f1;f1->f0", getForwardedFields(tuple, row, Seq((0, 1), (1, 0))))
+    assertEquals("f0->ints;f1->strings", getForwardedFields(tuple, namedRow, Seq((0, 0), (1, 2))))
+    assertEquals("f0->someInt;f1->aString", getForwardedFields(tuple, pojo, Seq((0, 2), (1, 0))))
+
+    assertEquals("*", getForwardedInput(intType, intType, Seq(0)))
+
+    val customTypeWrapper = (info: TypeInformation[_]) =>
+      info match {
+        case array: BasicArrayTypeInfo[_, _] =>
+          (_: Int) => s"*"
+      }
+    assertEquals("*", getForwardedInput(doubleArrType, doubleArrType, Seq(0), customTypeWrapper))
+  }
+}
+
+//TODO can't test it in this package
+case class TestCaseClass(aString: String, someInt: Int)
+
+final class TestPojo {
+  private var aString: String = _
+  var doubleArray: Array[Double] = _
+  var someInt: Int = 0
+
+  def setaString(aString: String) =
+    this.aString = aString
+
+  def getaString: String = aString
+}
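Note that RexFieldExtractor above delegates to a RefFieldsVisitor whose definition is not included in this diff. A minimal sketch of what such a visitor could look like, shown only as an illustration (the class name comes from the patch; the body is assumed, not taken from the Flink sources):

```scala
import org.apache.calcite.rex.{RexInputRef, RexVisitorImpl}

import scala.collection.mutable

// Illustrative only: collects the indexes of all input fields referenced by
// a RexNode tree. Passing deep = true makes RexVisitorImpl descend into the
// operands of every RexCall automatically.
class RefFieldsVisitor extends RexVisitorImpl[Unit](true) {

  private val fields = mutable.LinkedHashSet[Int]()

  override def visitInputRef(inputRef: RexInputRef): Unit =
    fields += inputRef.getIndex

  def getFields: Array[Int] = fields.toArray
}
```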
.../forwarding/FieldForwardingUtilsTest.scala | 40 ++++---- 10 files changed, 119 insertions(+), 68 deletions(-) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala index 2f0ecc97f1f17..1332c5ee83a1a 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchScan.scala @@ -56,7 +56,7 @@ trait BatchScan extends CommonScan with DataSetRel { //Forward all fields at conversion val indices = flinkTable.fieldIndexes.zipWithIndex - val fields: String = getForwardedFields(inputType, internalType, indices) + val fields = getForwardedFields(inputType, internalType, indices) input .map(mapFunc) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala index 11656d9f8ea12..81aeb6412461c 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala @@ -158,6 +158,13 @@ class DataSetAggregate( } private def forwardFields(rowTypeInfo: RowTypeInfo) = { + //Forward all fields at conversion + val inputInfo = mappedInput.getType + val indices = if (rowTypeInfo.getTotalFields < inputInfo.getTotalFields) { + 0 until rowTypeInfo.getTotalFields + } else { + 0 until inputInfo.getTotalFields + } val indices = 0 to rowTypeInfo.getTotalFields getForwardedInput( FlinkTypeFactory.toInternalRowTypeInfo(inputType), diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala index f1d115e841278..161e9757bb745 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala @@ -103,7 +103,7 @@ class DataSetCorrelate( val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction] val pojoFieldMapping = sqlFunction.getPojoFieldMapping val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]] - val returnType = FlinkTypeFactory.toInternalRowTypeInfo(rowType) + val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) val mapFunc = correlateMapFunction( config, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala index 5f9949972f83d..4ccfe32ca39a0 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala @@ -24,7 +24,8 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{BiRel, RelNode, RelWriter} import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.BatchTableEnvironment -import 
org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getDummyForwardedFields import org.apache.flink.table.runtime.IntersectCoGroupFunction import org.apache.flink.types.Row @@ -85,10 +86,19 @@ class DataSetIntersect( val coGroupOpName = s"intersect: ($intersectSelectionToString)" val coGroupFunction = new IntersectCoGroupFunction[Row](all) + val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) + + val (leftFields, rightFields) = getDummyForwardedFields( + leftDataSet, + rightDataSet, + returnType) + coGroupedDs .where("*") .equalTo("*") .`with`(coGroupFunction) + .withForwardedFieldsFirst(leftFields) + .withForwardedFieldsSecond(rightFields) .name(coGroupOpName) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala index dba23fab4bb3c..b6fae13d18f9e 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala @@ -24,7 +24,8 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{BiRel, RelNode, RelWriter} import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.BatchTableEnvironment -import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput +import org.apache.flink.table.calcite.FlinkTypeFactory +import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getDummyForwardedFields import org.apache.flink.table.runtime.MinusCoGroupFunction import org.apache.flink.types.Row @@ -96,10 +97,19 @@ class DataSetMinus( val coGroupOpName = s"minus: ($minusSelectionToString)" val coGroupFunction = new MinusCoGroupFunction[Row](all) + val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) + + val (leftFields, rightFields) = getDummyForwardedFields( + leftDataSet, + rightDataSet, + returnType) + coGroupedDs .where("*") .equalTo("*") .`with`(coGroupFunction) + .withForwardedFieldsFirst(leftFields) + .withForwardedFieldsSecond(rightFields) .name(coGroupOpName) } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala index 417d84d814b94..5c112a2ec1b10 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala @@ -27,10 +27,9 @@ import org.apache.flink.api.common.functions.{FlatJoinFunction, FlatMapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.DataSet import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.api.java.typeutils.GenericTypeInfo import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.runtime.{MapJoinLeftRunner, MapJoinRightRunner} -import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.{compositeTypeField, getForwardedFields} +import 
org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedFields import org.apache.flink.table.api.{BatchTableEnvironment, TableConfig} import org.apache.flink.types.Row @@ -112,22 +111,11 @@ class DataSetSingleRowJoin( (leftDataSet, rightDataSet) } - def customWrapper(typeInformation: TypeInformation[_]) = { - typeInformation match { - case r: GenericTypeInfo[_] => if (r.getTypeClass == classOf[Row]) { - val leftCount: Int = leftDataSet.getType.getTotalFields - val rightCount: Int = rightDataSet.getType.getTotalFields - compositeTypeField((0 until (leftCount + rightCount)).map("f" + _)) - } else { - ??? - } - } - } val offset: Int = if (leftIsSingle) 1 else 0 val indices = (0 until multiRowDataSet.getType.getTotalFields) .map { inputIndex => (inputIndex, inputIndex + offset) } - val fields = getForwardedFields(multiRowDataSet.getType, returnType, indices, customWrapper) + val fields = getForwardedFields(multiRowDataSet.getType, returnType, indices) multiRowDataSet .flatMap(mapSideJoin) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSort.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSort.scala index d23da667ea9de..0e4760436a05b 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSort.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSort.scala @@ -29,7 +29,6 @@ import org.apache.calcite.rex.{RexLiteral, RexNode} import org.apache.flink.api.common.operators.Order import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.{BatchTableEnvironment, TableException} -import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput import org.apache.flink.table.runtime.{CountPartitionFunction, LimitFilterFunction} import org.apache.flink.types.Row @@ -129,7 +128,6 @@ class DataSetSort( broadcastName) val limitName = s"offset: $offsetToString, fetch: $fetchToString" - // TODO Do we need this here? 
val allFields: String = "*" partitionedDs diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala index 37d18193eb82a..07b5ed9d1fe17 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala @@ -18,32 +18,34 @@ package org.apache.flink.table.plan.nodes.dataset.forwarding -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeinfo.{BasicArrayTypeInfo, BasicTypeInfo, TypeInformation => TypeInfo} import org.apache.flink.api.common.typeutils.CompositeType +import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.TableException +import org.apache.flink.types.Row object FieldForwardingUtils { - def compositeTypeField = (fields: Seq[String]) => fields + def compositeTypeField = (fields: Seq[(String, TypeInfo[_])]) => fields - private def throwMissedWrapperException(wrapperCustomCase: TypeInformation[_]) = { - throw new TableException(s"Implementation for $wrapperCustomCase index wrapper is missing.") + private def throwMissedWrapperException(customWrapper: TypeInfo[_]) = { + throw new TableException(s"Implementation for $customWrapper wrapper is missing.") } /** * Wrapper for {@link getForwardedFields} */ def getForwardedInput( - inputType: TypeInformation[_], - outputType: TypeInformation[_], + inputType: TypeInfo[_], + outputType: TypeInfo[_], forwardIndices: Seq[Int], - wrapperCustomCase: TypeInformation[_] => - (Int) => String = throwMissedWrapperException): String = { + customWrapper: TypeInfo[_] => + Seq[(String, TypeInfo[_])] = throwMissedWrapperException): String = { getForwardedFields(inputType, outputType, forwardIndices.zip(forwardIndices), - wrapperCustomCase) + customWrapper) } /** @@ -58,18 +60,23 @@ object FieldForwardingUtils { * @return string with forwarded fields mapped from input to output */ def getForwardedFields( - inputType: TypeInformation[_], - outputType: TypeInformation[_], + inputType: TypeInfo[_], + outputType: TypeInfo[_], forwardIndices: Seq[(Int, Int)], - customWrapper: TypeInformation[_] => - (Int) => String = throwMissedWrapperException): String = { + customWrapper: TypeInfo[_] => + Seq[(String, TypeInfo[_])] = throwMissedWrapperException): String = { + + def chooseWrapper( + typeInformation: TypeInfo[_]): Seq[(String, TypeInfo[_])] = { - def chooseWrapper(typeInformation: TypeInformation[_]): (Int) => String = { typeInformation match { case composite: CompositeType[_] => - compositeTypeField(composite.getFieldNames) + val fields = extractFields(composite) + compositeTypeField(fields) case basic: BasicTypeInfo[_] => - (_: Int) => s"*" + Seq((s"*", basic)) + case array: BasicArrayTypeInfo[_, _] => + Seq((s"*", array)) case _ => customWrapper(typeInformation) } @@ -81,23 +88,56 @@ object FieldForwardingUtils { forwardFields(forwardIndices, wrapInput, wrapOutput) } + private def extractFields( + composite: CompositeType[_]): Seq[(String, TypeInfo[_])] = { + + val types = for { + i <- 0 until composite.getArity + } yield { composite.getTypeAt(i) } + + composite.getFieldNames.zip(types) + } + private def forwardFields( forwardIndices: Seq[(Int, Int)], - wrapInput: (Int) => String, - 
wrapOutput: (Int) => String): String = { + wrappedInput: Int => (String, TypeInfo[_]), + wrappedOutput: Int => (String, TypeInfo[_])): String = { - implicit class String2ForwardFields(left: String) { - def ->(right: String): String = if (left == right) { - left + implicit class Field2ForwardField(left: (String, TypeInfo[_])) { + def ->(right: (String, TypeInfo[_])): String = if (left.equals(right)) { + s"${left._1}" } else { - s"$left->$right" + if (left._2.equals(right._2)) { + s"${left._1}->${right._1}" + } else { + null + } } } - def wrapIndices(inputIndex: Int, outputIndex: Int): String = { - wrapInput(inputIndex) -> wrapOutput(outputIndex) - } + forwardIndices map { + case (in, out) => + wrappedInput(in) -> wrappedOutput(out) + } filterNot(_ == null) mkString ";" + } + + def getDummyForwardedFields( + leftDataSet: DataSet[Row], + rightDataSet: DataSet[Row], + returnType: TypeInfo[Row]): (String, String) = { + + val leftFields = getDummyForwardedFields(leftDataSet, returnType) + val rightFields = getDummyForwardedFields(rightDataSet, returnType) + (leftFields, rightFields) + } + + def getDummyForwardedFields( + dataSet: DataSet[Row], + returnType: TypeInfo[Row]): String ={ - forwardIndices map { case (in, out) => wrapIndices(in, out) } mkString ";" + val `type` = dataSet.getType + val indices = 0 until `type`.getTotalFields + getForwardedInput(`type`, returnType, indices) } + } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala index f2d16c3825237..20bc3fdd35720 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala @@ -92,7 +92,7 @@ class DataStreamCorrelate( val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction] val pojoFieldMapping = sqlFunction.getPojoFieldMapping val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]] - val returnType = FlinkTypeFactory.toInternalRowTypeInfo(rowType) + val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) val mapFunc = correlateMapFunction( config, diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala index 75ba4567723f5..589c37b73c463 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala @@ -19,46 +19,44 @@ package org.apache.flink.table.plan.nodes.dataset.forwarding import org.apache.flink.api.common.typeinfo.{BasicArrayTypeInfo, BasicTypeInfo, TypeInformation} -import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfo, TypeExtractor} +import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo, TupleTypeInfo, TypeExtractor} import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils._ import org.junit.Assert._ import org.junit.Test class FieldForwardingUtilsTest { + val intType = BasicTypeInfo.INT_TYPE_INFO + val strType = 
BasicTypeInfo.STRING_TYPE_INFO + val doubleArrType = BasicArrayTypeInfo.DOUBLE_ARRAY_TYPE_INFO + val tuple = new TupleTypeInfo(intType, strType) + val pojo = TypeExtractor.getForClass(classOf[TestPojo]) + val row = new RowTypeInfo(strType, intType, doubleArrType) + val namedRow = new RowTypeInfo( + Array[TypeInformation[_]](intType, doubleArrType, strType), + Array("ints", "doubleArr", "strings") + ) + @Test def testForwarding() = { - val intType = BasicTypeInfo.INT_TYPE_INFO - val strType = BasicTypeInfo.STRING_TYPE_INFO - val doubleArrType = BasicArrayTypeInfo.DOUBLE_ARRAY_TYPE_INFO - - val tuple = new TupleTypeInfo(intType, strType) - val pojo = TypeExtractor.getForClass(classOf[TestPojo]) - - val row = new RowTypeInfo(strType, intType, doubleArrType) - val namedRow = new RowTypeInfo( - Array[TypeInformation[_]](intType, doubleArrType, strType), - Array("ints", "doubleArr", "strings") - ) - assertEquals("f0->f1;f1->f0", getForwardedFields(tuple, row, Seq((0, 1), (1, 0)))) assertEquals("f0->ints;f1->strings", getForwardedFields(tuple, namedRow, Seq((0, 0), (1, 2)))) assertEquals("f0->someInt;f1->aString", getForwardedFields(tuple, pojo, Seq((0, 2), (1, 0)))) - assertEquals("*", getForwardedInput(intType, intType, Seq(0))) + } + @Test + def testForwardingWithCustomType() = { + val customType = new GenericTypeInfo(classOf[Int]) val customTypeWrapper = (info: TypeInformation[_]) => info match { - case array: BasicArrayTypeInfo[_, _] => - (_: Int) => s"*" + case gen: GenericTypeInfo[Int] => + Seq((s"*", intType)) } - assertEquals("*", getForwardedInput(doubleArrType, doubleArrType, Seq(0), customTypeWrapper)) + assertEquals("*", getForwardedInput(customType, intType, Seq(0), customTypeWrapper)) } } -//TODO can't test it in this package -case class TestCaseClass(aString: String, someInt: Int) - final class TestPojo { private var aString: String = _ var doubleArray: Array[Double] = _ From 973f548cfe508189467a10d5a7b03b2635a04764 Mon Sep 17 00:00:00 2001 From: tonycox Date: Fri, 10 Mar 2017 21:41:06 +0400 Subject: [PATCH 4/4] [FLINK-3850] Address review comments --- .../table/plan/nodes/CommonCorrelate.scala | 3 +- .../plan/nodes/dataset/DataSetAggregate.scala | 52 ++++++--- .../plan/nodes/dataset/DataSetCalc.scala | 34 +++--- .../plan/nodes/dataset/DataSetCorrelate.scala | 27 +---- .../plan/nodes/dataset/DataSetIntersect.scala | 8 +- .../plan/nodes/dataset/DataSetMinus.scala | 8 +- .../nodes/dataset/DataSetSingleRowJoin.scala | 17 ++- .../forwarding/FieldForwardingUtils.scala | 104 ++++++++++-------- .../datastream/DataStreamCorrelate.scala | 2 - .../table/plan/util/RexFieldExtractor.scala | 36 ------ .../forwarding/FieldForwardingUtilsTest.scala | 11 -- 11 files changed, 140 insertions(+), 162 deletions(-) delete mode 100644 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexFieldExtractor.scala diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala index d3745c5684f94..6c4066b7af9d4 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/CommonCorrelate.scala @@ -46,7 +46,6 @@ trait CommonCorrelate { config: TableConfig, inputTypeInfo: TypeInformation[Row], udtfTypeInfo: TypeInformation[Any], - returnType: TypeInformation[Row], rowType: RelDataType, joinType: SemiJoinType,
rexCall: RexCall, @@ -55,6 +54,8 @@ trait CommonCorrelate { ruleDescription: String) : CorrelateFlatMapRunner[Row, Row] = { + val returnType = FlinkTypeFactory.toInternalRowTypeInfo(rowType) + val flatMap = generateFunction( config, inputTypeInfo, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala index 81aeb6412461c..fe4212218b014 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala @@ -28,7 +28,7 @@ import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.api.java.DataSet import org.apache.flink.api.java.typeutils.RowTypeInfo import org.apache.flink.table.api.BatchTableEnvironment -import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput +import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedFields import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.plan.nodes.CommonAggregate import org.apache.flink.table.runtime.aggregate.{AggregateUtil, DataSetPreAggFunction} @@ -112,25 +112,34 @@ class DataSetAggregate( val aggOpName = s"groupBy: (${groupingToString(inputType, grouping)}), " + s"select: ($aggString)" + val inputTypeInfo = FlinkTypeFactory.toInternalRowTypeInfo(inputType) + if (preAgg.isDefined) { + + val preAggFields = forwardFields(inputTypeInfo, preAggType.get, grouping) + val fields = continueForwardFields(preAggType.get, rowTypeInfo, grouping) inputDS // pre-aggregation .groupBy(grouping: _*) .combineGroup(preAgg.get) .returns(preAggType.get) - // forward fields at conversion - .withForwardedFields(forwardFields(rowTypeInfo)) + // forward fields at pre-aggregation + .withForwardedFields(preAggFields) .name(aggOpName) // final aggregation .groupBy(grouping.indices: _*) .reduceGroup(finalAgg) .returns(rowTypeInfo) + // forward fields at final conversion + .withForwardedFields(fields) .name(aggOpName) } else { + val fields = forwardFields(inputTypeInfo, rowTypeInfo, grouping) inputDS .groupBy(grouping: _*) .reduceGroup(finalAgg) .returns(rowTypeInfo) + .withForwardedFields(fields) .name(aggOpName) } } @@ -157,18 +166,29 @@ class DataSetAggregate( } } - private def forwardFields(rowTypeInfo: RowTypeInfo) = { - //Forward all fields at conversion - val inputInfo = mappedInput.getType - val indices = if (rowTypeInfo.getTotalFields < inputInfo.getTotalFields) { - 0 until rowTypeInfo.getTotalFields - } else { - 0 until inputInfo.getTotalFields - } - val indices = 0 to rowTypeInfo.getTotalFields - getForwardedInput( - FlinkTypeFactory.toInternalRowTypeInfo(inputType), - rowTypeInfo, - indices) + private def continueForwardFields( + inputRowType: TypeInformation[Row], + resultRowType: TypeInformation[Row], + aggIndices: Array[Int]) = { + + val names = inputType.getFieldNames + val aggKeys = aggIndices.map(names.get) + val outIndices = aggKeys.map(getRowType.getField(_, false, false).getIndex) + + getForwardedFields( + inputRowType, + resultRowType, + aggIndices.indices.zip(outIndices)) + } + + private def forwardFields( + inputRowType: TypeInformation[Row], + resultRowType: TypeInformation[Row], + aggIndices: Array[Int]) = { + + getForwardedFields( + inputRowType, + resultRowType, + 
aggIndices.zipWithIndex) } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala index ac5680a3a311f..52aa63ca39e23 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCalc.scala @@ -26,13 +26,13 @@ import org.apache.calcite.rel.{RelNode, RelWriter} import org.apache.flink.api.common.functions.FlatMapFunction import org.apache.flink.api.java.DataSet import org.apache.calcite.rex._ +import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedFields import org.apache.flink.table.api.BatchTableEnvironment import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.CodeGenerator import org.apache.flink.table.plan.nodes.CommonCalc import org.apache.flink.types.Row -import org.apache.flink.table.plan.util.RexProgramProjectExtractor._ import scala.collection.JavaConversions._ /** @@ -103,26 +103,30 @@ class DataSetCalc( body, returnType) - def getForwardIndices = { - // get (input, output) indices of operands, - // filter modified operands and specify forwarding - val inputFields = extractRefInputFields(calcProgram) - calcProgram.getProjectList - .map(_.getIndex) - .zipWithIndex - .filter(tup => inputFields.contains(tup._1)) - } - val mapFunc = calcMapFunction(genFunction) - val fields = getForwardedFields( - inputDS.getType, - returnType, - getForwardIndices) + val fields = forwardFields(inputDS.getType, returnType) inputDS .flatMap(mapFunc) .withForwardedFields(fields) .name(calcOpName(calcProgram, getExpressionString)) } + + private def forwardFields( + inputTypeInfo: TypeInformation[Row], + returnType: TypeInformation[Row]) = { + + def getForwardIndices = { + calcProgram.getProjectList.zipWithIndex.flatMap { case (p, out) => + val expr = calcProgram.getExprList.get(p.getIndex) + expr match { + case i: RexInputRef => Some((i.getIndex, out)) + case _ => None + } + } + } + + getForwardedFields(inputTypeInfo, returnType, getForwardIndices) + } } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala index 161e9757bb745..165ffe2c3742f 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetCorrelate.scala @@ -22,10 +22,9 @@ import org.apache.calcite.rel.`type`.RelDataType import org.apache.calcite.rel.logical.LogicalTableFunctionScan import org.apache.calcite.rel.metadata.RelMetadataQuery import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} -import org.apache.calcite.rex.{RexCall, RexInputRef, RexNode} +import org.apache.calcite.rex.{RexCall, RexNode} import org.apache.calcite.sql.SemiJoinType import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.common.typeutils.CompositeType import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.BatchTableEnvironment import org.apache.flink.table.calcite.FlinkTypeFactory @@ -33,9 +32,6 @@ import
org.apache.flink.table.functions.utils.TableSqlFunction import org.apache.flink.table.plan.nodes.CommonCorrelate import org.apache.flink.types.Row import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput -import org.apache.flink.table.plan.util.RexFieldExtractor._ - -import scala.collection.JavaConversions._ /** * Flink RelNode which matches along with join a user defined table function. @@ -103,13 +99,11 @@ class DataSetCorrelate( val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction] val pojoFieldMapping = sqlFunction.getPojoFieldMapping val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]] - val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) val mapFunc = correlateMapFunction( config, inputDS.getType, udtfTypeInfo, - returnType, getRowType, joinType, rexCall, @@ -117,23 +111,8 @@ class DataSetCorrelate( Some(pojoFieldMapping), ruleDescription) - def getIndices = { - //get indices of all input operands - val inputOperandsInRel = extractRefInputFields(rexCall) - val joinCondition = if (condition.isDefined) { - extractRefInputFields(condition.get) - } else { - Array() - } - val inputOperands = inputOperandsInRel ++ joinCondition - - inputDS.getType.asInstanceOf[CompositeType[_]] - .getFieldNames - .indices - .filter(inputOperands.contains) - } - - val fields = getForwardedInput(inputDS.getType, mapFunc.getProducedType, getIndices) + val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) + val fields = getForwardedInput(inputDS.getType, returnType) inputDS .flatMap(mapFunc) diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala index 4ccfe32ca39a0..06a8678d147d2 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetIntersect.scala @@ -25,7 +25,7 @@ import org.apache.calcite.rel.{BiRel, RelNode, RelWriter} import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.BatchTableEnvironment import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getDummyForwardedFields +import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput import org.apache.flink.table.runtime.IntersectCoGroupFunction import org.apache.flink.types.Row @@ -88,9 +88,9 @@ class DataSetIntersect( val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) - val (leftFields, rightFields) = getDummyForwardedFields( - leftDataSet, - rightDataSet, + val (leftFields, rightFields) = getForwardedInput( + leftDataSet.getType, + rightDataSet.getType, returnType) coGroupedDs diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala index b6fae13d18f9e..dc5491e391399 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetMinus.scala @@ -25,7 +25,7 @@ import org.apache.calcite.rel.{BiRel, RelNode, RelWriter} import 
org.apache.flink.api.java.DataSet import org.apache.flink.table.api.BatchTableEnvironment import org.apache.flink.table.calcite.FlinkTypeFactory -import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getDummyForwardedFields +import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedInput import org.apache.flink.table.runtime.MinusCoGroupFunction import org.apache.flink.types.Row @@ -99,9 +99,9 @@ class DataSetMinus( val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) - val (leftFields, rightFields) = getDummyForwardedFields( - leftDataSet, - rightDataSet, + val (leftFields, rightFields) = getForwardedInput( + leftDataSet.getType, + rightDataSet.getType, returnType) coGroupedDs diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala index 5c112a2ec1b10..3eaa178f2b384 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetSingleRowJoin.scala @@ -111,11 +111,7 @@ class DataSetSingleRowJoin( (leftDataSet, rightDataSet) } - val offset: Int = if (leftIsSingle) 1 else 0 - val indices = (0 until multiRowDataSet.getType.getTotalFields) - .map { inputIndex => (inputIndex, inputIndex + offset) } - - val fields = getForwardedFields(multiRowDataSet.getType, returnType, indices) + val fields = forwardFields(returnType, multiRowDataSet) multiRowDataSet .flatMap(mapSideJoin) @@ -124,6 +120,17 @@ class DataSetSingleRowJoin( .name(getMapOperatorName) } + private def forwardFields( + returnType: TypeInformation[Row], + multiRowDataSet: DataSet[Row]): String = { + + val offset: Int = if (leftIsSingle) 1 else 0 + val indices = (0 until multiRowDataSet.getType.getTotalFields) + .map { inputIndex => (inputIndex, inputIndex + offset) } + + getForwardedFields(multiRowDataSet.getType, returnType, indices) + } + private def generateMapFunction( config: TableConfig, inputType1: TypeInformation[Row], diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala index 07b5ed9d1fe17..58e96f0835b34 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtils.scala @@ -18,34 +18,73 @@ package org.apache.flink.table.plan.nodes.dataset.forwarding -import org.apache.flink.api.common.typeinfo.{BasicArrayTypeInfo, BasicTypeInfo, TypeInformation => TypeInfo} +import org.apache.flink.api.common.typeinfo.AtomicType +import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo +import org.apache.flink.api.common.typeinfo.BasicTypeInfo +import org.apache.flink.api.common.typeinfo.TypeInformation +import org.apache.flink.api.common.typeinfo.{TypeInformation => TypeInfo} import org.apache.flink.api.common.typeutils.CompositeType -import org.apache.flink.api.java.DataSet import org.apache.flink.table.api.TableException import org.apache.flink.types.Row object FieldForwardingUtils { - def compositeTypeField = (fields: 
Seq[(String, TypeInfo[_])]) => fields + private def throwMissedTypeInfoException(info: TypeInfo[_]) = { + throw new TableException(s"Implementation for $info wrapper is missing.") + } - private def throwMissedWrapperException(customWrapper: TypeInfo[_]) = { - throw new TableException(s"Implementation for $customWrapper wrapper is missing.") + /** + * Wrapper for {@link getForwardedInput} + * + * @param leftType left input [[TypeInformation]] + * @param rightType right input [[TypeInformation]] + * @param returnType information of output data + * @return string with forwarded fields mapped from input to output + */ + def getForwardedInput( + leftType: TypeInfo[Row], + rightType: TypeInfo[Row], + returnType: TypeInfo[Row]): (String, String) = { + + val leftFields = getForwardedInput(leftType, returnType) + val rightFields = getForwardedInput(rightType, returnType) + (leftFields, rightFields) + } + + /** + * Wrapper for {@link getForwardedInput} + * Generates default forward indices covering all fields of inputType + * + * @param inputType input [[TypeInformation]] + * @param returnType information of output data + * @return string with forwarded fields mapped from input to output + */ + def getForwardedInput( + inputType: TypeInfo[Row], + returnType: TypeInfo[Row]): String = { + + val indices = 0 until inputType.getTotalFields + getForwardedInput(inputType, returnType, indices) } /** * Wrapper for {@link getForwardedFields} + * Generates default indices by zipping forwardIndices with itself + * + * @param inputType information of input data + * @param outputType information of output data + * @param forwardIndices direct mapping of fields + * @return string with forwarded fields mapped from input to output */ def getForwardedInput( inputType: TypeInfo[_], outputType: TypeInfo[_], - forwardIndices: Seq[Int], - customWrapper: TypeInfo[_] => - Seq[(String, TypeInfo[_])] = throwMissedWrapperException): String = { + forwardIndices: Seq[Int]): String = { - getForwardedFields(inputType, + getForwardedFields( + inputType, outputType, - forwardIndices.zip(forwardIndices), - customWrapper) + forwardIndices.zip(forwardIndices)) } /** @@ -55,30 +94,27 @@ object FieldForwardingUtils { * @param inputType information of input data * @param outputType information of output data * @param forwardIndices tuple of (input, output) indices of a forwarded field - * @param customWrapper used for figuring out proper type in specific cases, - * e.g.
{@see DataSetSingleRowJoin} * @return string with forwarded fields mapped from input to output */ def getForwardedFields( inputType: TypeInfo[_], outputType: TypeInfo[_], - forwardIndices: Seq[(Int, Int)], - customWrapper: TypeInfo[_] => - Seq[(String, TypeInfo[_])] = throwMissedWrapperException): String = { + forwardIndices: Seq[(Int, Int)]): String = { def chooseWrapper( typeInformation: TypeInfo[_]): Seq[(String, TypeInfo[_])] = { typeInformation match { case composite: CompositeType[_] => - val fields = extractFields(composite) - compositeTypeField(fields) + extractFields(composite) case basic: BasicTypeInfo[_] => Seq((s"*", basic)) case array: BasicArrayTypeInfo[_, _] => Seq((s"*", array)) + case atomic: AtomicType[_] => + Seq((s"*", atomic)) case _ => - customWrapper(typeInformation) + throwMissedTypeInfoException(typeInformation) } } @@ -110,34 +146,14 @@ if (left._2.equals(right._2)) { s"${left._1}->${right._1}" } else { - null + throw new TableException("Invalid forwarded field mapping: " + "input and output field types do not match.") } } } - forwardIndices map { - case (in, out) => - wrappedInput(in) -> wrappedOutput(out) - } filterNot(_ == null) mkString ";" - } - - def getDummyForwardedFields( - leftDataSet: DataSet[Row], - rightDataSet: DataSet[Row], - returnType: TypeInfo[Row]): (String, String) = { - - val leftFields = getDummyForwardedFields(leftDataSet, returnType) - val rightFields = getDummyForwardedFields(rightDataSet, returnType) - (leftFields, rightFields) - } - - def getDummyForwardedFields( - dataSet: DataSet[Row], - returnType: TypeInfo[Row]): String ={ - - val `type` = dataSet.getType - val indices = 0 until `type`.getTotalFields - getForwardedInput(`type`, returnType, indices) + forwardIndices map { case (in, out) => + wrappedInput(in) -> wrappedOutput(out) + } mkString ";" } - } diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala index 20bc3fdd35720..0745fe0dce0cc 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamCorrelate.scala @@ -92,13 +92,11 @@ class DataStreamCorrelate( val sqlFunction = rexCall.getOperator.asInstanceOf[TableSqlFunction] val pojoFieldMapping = sqlFunction.getPojoFieldMapping val udtfTypeInfo = sqlFunction.getRowTypeInfo.asInstanceOf[TypeInformation[Any]] - val returnType = FlinkTypeFactory.toInternalRowTypeInfo(getRowType) val mapFunc = correlateMapFunction( config, inputDS.getType, udtfTypeInfo, - returnType, getRowType, joinType, rexCall, diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexFieldExtractor.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexFieldExtractor.scala deleted file mode 100644 index 351ff20fc1ee3..0000000000000 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/util/RexFieldExtractor.scala +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership.
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.flink.table.plan.util - -import org.apache.calcite.rex.RexNode - -object RexFieldExtractor { - - /** - * Extracts the indexes of input fields accessed by the RexNode. - * - * @param rex RexNode to analyze - * @return The indexes of accessed input fields - */ - def extractRefInputFields(rex: RexNode): Array[Int] = { - val visitor = new RefFieldsVisitor - rex.accept(visitor) - visitor.getFields - } -} diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala index 589c37b73c463..1fcde79264a32 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/plan/nodes/dataset/forwarding/FieldForwardingUtilsTest.scala @@ -44,17 +44,6 @@ class FieldForwardingUtilsTest { assertEquals("f0->someInt;f1->aString", getForwardedFields(tuple, pojo, Seq((0, 2), (1, 0)))) assertEquals("*", getForwardedInput(intType, intType, Seq(0))) } - - @Test - def testForwardingWithCustomType() = { - val customType = new GenericTypeInfo(classOf[Int]) - val customTypeWrapper = (info: TypeInformation[_]) => - info match { - case gen: GenericTypeInfo[Int] => - Seq((s"*", intType)) - } - assertEquals("*", getForwardedInput(customType, intType, Seq(0), customTypeWrapper)) - } } final class TestPojo {
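
For readers unfamiliar with Flink's semantic annotations, here is a minimal, self-contained sketch (not part of the patch series; the object name and sample data are invented for illustration) of how the forwarded-field strings produced by FieldForwardingUtils are consumed by the DataSet API. A declaration such as "f0->f1; f1->f0; f2" tells the optimizer which input fields arrive unchanged at which output positions, so physical properties such as partitioning or sort order on those fields can be reused instead of recomputed:

    import org.apache.flink.api.scala._

    object ForwardedFieldsSketch {
      def main(args: Array[String]): Unit = {
        val env = ExecutionEnvironment.getExecutionEnvironment
        val input = env.fromElements((1, "a", 2.0), (3, "b", 4.0))

        // The map swaps the first two fields and passes the third through,
        // and the annotation states exactly that: input f0 arrives at output
        // f1, input f1 at output f0, and f2 stays in place.
        val swapped = input
          .map(t => (t._2, t._1, t._3))
          .withForwardedFields("f0->f1; f1->f0; f2")

        swapped.print()
      }
    }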
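
One behavioral change of PATCH 4/4 is worth spelling out, reusing the fixtures of the test above (a hypothetical snippet, assuming the patched FieldForwardingUtils is on the classpath): where the earlier revisions silently dropped a mapping between fields of different types (the filtered null branch), the final revision fails fast with a TableException.

    import org.apache.flink.api.common.typeinfo.BasicTypeInfo
    import org.apache.flink.api.java.typeutils.{RowTypeInfo, TupleTypeInfo}
    import org.apache.flink.table.plan.nodes.dataset.forwarding.FieldForwardingUtils.getForwardedFields

    val tuple = new TupleTypeInfo(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO)
    val row = new RowTypeInfo(BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.INT_TYPE_INFO)

    // Types line up: Int -> Int at position 1, String -> String at position 0.
    getForwardedFields(tuple, row, Seq((0, 1), (1, 0)))  // "f0->f1;f1->f0"

    // Int mapped onto a String position: throws TableException after this patch,
    // where previously the pair was filtered out and omitted from the result.
    getForwardedFields(tuple, row, Seq((0, 0)))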