[FLINK-13393][FLINK-13391][table-planner-blink] Fix source conversion and source return type #9211

Status: Closed · wants to merge 3 commits
File: CodeGenUtils.scala

@@ -201,10 +201,12 @@ object CodeGenUtils {
    * If it's internally compatible, don't need to DataStructure converter.
    * clazz != classOf[Row] => Row can only infer GenericType[Row].
    */
-  def isInternalClass(clazz: Class[_], t: DataType): Boolean =
+  def isInternalClass(t: DataType): Boolean = {
+    val clazz = t.getConversionClass
     clazz != classOf[Object] && clazz != classOf[Row] &&
-      (classOf[BaseRow].isAssignableFrom(clazz) ||
-        clazz == getInternalClassForType(fromDataTypeToLogicalType(t)))
+    (classOf[BaseRow].isAssignableFrom(clazz) ||
+      clazz == getInternalClassForType(fromDataTypeToLogicalType(t)))
+  }
 
   def hashCodeForType(
       ctx: CodeGeneratorContext, t: LogicalType, term: String): String = t.getTypeRoot match {
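
Note on the new contract (not part of the diff): the conversion class now travels with the DataType itself, so callers no longer have to thread a Class[_] through. A minimal sketch of how the single-argument predicate behaves, assuming the blink planner's DataTypes/bridgedTo API; the ROW example types are illustrative:

    import org.apache.flink.table.api.DataTypes
    import org.apache.flink.table.dataformat.BaseRow

    // default conversion class of ROW is org.apache.flink.types.Row
    val external = DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.INT()))
    // isInternalClass(external) == false -> a DataStructure converter is generated

    // explicitly bridged to the internal BaseRow representation
    val internal = DataTypes.ROW(DataTypes.FIELD("f0", DataTypes.INT()))
      .bridgedTo(classOf[BaseRow])
    // isInternalClass(internal) == true -> only a cast is emitted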
@@ -680,9 +682,8 @@ object CodeGenUtils {
   def genToInternalIfNeeded(
       ctx: CodeGeneratorContext,
       t: DataType,
-      clazz: Class[_],
       term: String): String = {
-    if (isInternalClass(clazz, t)) {
+    if (isInternalClass(t)) {
       s"(${boxedTypeTermForType(fromDataTypeToLogicalType(t))}) $term"
     } else {
       genToInternal(ctx, t, term)
@@ -705,9 +706,8 @@ object CodeGenUtils {
   def genToExternalIfNeeded(
       ctx: CodeGeneratorContext,
       t: DataType,
-      clazz: Class[_],
       term: String): String = {
-    if (isInternalClass(clazz, t)) {
+    if (isInternalClass(t)) {
       s"(${boxedTypeTermForType(fromDataTypeToLogicalType(t))}) $term"
     } else {
       genToExternal(ctx, t, term)
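
For reference, the non-internal branch falls back to a DataFormatConverters-based bridge. A hedged, self-contained sketch of that conversion path (names local to this example, not the generated code itself):

    import org.apache.flink.table.api.DataTypes
    import org.apache.flink.table.dataformat.DataFormatConverters
    import org.apache.flink.types.Row

    val dt = DataTypes.ROW(DataTypes.FIELD("name", DataTypes.STRING()))
    val converter = DataFormatConverters.getConverterForDataType(dt)
      .asInstanceOf[DataFormatConverters.DataFormatConverter[AnyRef, Row]]
    val internal = converter.toInternal(Row.of("hello")) // Row -> BaseRow
    val external = converter.toExternal(internal)        // BaseRow -> Row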
File: SinkCodeGenerator.scala

@@ -18,7 +18,6 @@
 
 package org.apache.flink.table.planner.codegen
 
-import org.apache.flink.api.common.functions.InvalidTypesException
 import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
 import org.apache.flink.api.common.typeutils.CompositeType
 import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2}
@@ -31,7 +30,6 @@ import org.apache.flink.table.dataformat.util.BaseRowUtil
 import org.apache.flink.table.dataformat.{BaseRow, GenericRow}
 import org.apache.flink.table.planner.codegen.CodeGenUtils.genToExternal
 import org.apache.flink.table.planner.codegen.OperatorCodeGenerator.generateCollect
-import org.apache.flink.table.planner.sinks.DataStreamTableSink
 import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
 import org.apache.flink.table.runtime.types.TypeInfoLogicalTypeConverter.fromTypeInfoToLogicalType
 import org.apache.flink.table.runtime.typeutils.BaseRowTypeInfo
@@ -42,20 +40,6 @@ import org.apache.flink.types.Row
 
 object SinkCodeGenerator {
 
-  private[flink] def extractTableSinkTypeClass(sink: TableSink[_]): Class[_] = {
-    try {
-      sink match {
-        // DataStreamTableSink has no generic class, so we need get the type to get type class.
-        case sink: DataStreamTableSink[_] => sink.getConsumedDataType.getConversionClass
-        case _ => TypeExtractor.createTypeInfo(sink, classOf[TableSink[_]], sink.getClass, 0)
-          .getTypeClass.asInstanceOf[Class[_]]
-      }
-    } catch {
-      case _: InvalidTypesException =>
-        classOf[Object]
-    }
-  }
-
   /** Code gen a operator to convert internal type rows to external type. **/
   def generateRowConverterOperator[OUT](
       ctx: CodeGeneratorContext,
File: ScalarFunctionCallGen.scala

@@ -18,11 +18,13 @@
 
 package org.apache.flink.table.planner.codegen.calls
 
+import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.typeutils.GenericTypeInfo
 import org.apache.flink.table.dataformat.DataFormatConverters
 import org.apache.flink.table.dataformat.DataFormatConverters.getConverterForDataType
 import org.apache.flink.table.functions.ScalarFunction
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
+import org.apache.flink.table.planner.codegen.calls.ScalarFunctionCallGen.prepareFunctionArgs
 import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GenerateUtils, GeneratedExpression}
 import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils
 import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._
@@ -74,7 +76,7 @@ class ScalarFunctionCallGen(scalarFunction: ScalarFunction) extends CallGenerator
         val javaTerm = newName("javaResult")
         // it maybe a Internal class, so use resultClass is most safety.
         val javaTypeTerm = resultClass.getCanonicalName
-        val internal = genToInternalIfNeeded(ctx, resultExternalType, resultClass, javaTerm)
+        val internal = genToInternalIfNeeded(ctx, resultExternalType, javaTerm)
         s"""
            |$javaTypeTerm $javaTerm = ($javaTypeTerm) $evalResult;
            |$resultTerm = $javaTerm == null ? null : ($internal);
@@ -106,29 +108,39 @@
       ctx: CodeGeneratorContext,
       operands: Seq[GeneratedExpression],
       func: ScalarFunction): Array[GeneratedExpression] = {
 
     // get the expanded parameter types
     var paramClasses = getEvalMethodSignature(func, operands.map(_.resultType).toArray)
-
-    val signatureTypes = func
-      .getParameterTypes(paramClasses)
-      .zipWithIndex
-      .map {
-        case (t, i) =>
-          // we don't trust GenericType.
-          if (t.isInstanceOf[GenericTypeInfo[_]]) {
-            fromLogicalTypeToDataType(operands(i).resultType)
-          } else {
-            fromLegacyInfoToDataType(t)
-          }
-      }
+    prepareFunctionArgs(ctx, operands, paramClasses, func.getParameterTypes(paramClasses))
+  }
+
+}
+
+object ScalarFunctionCallGen {
+
+  def prepareFunctionArgs(
+      ctx: CodeGeneratorContext,
+      operands: Seq[GeneratedExpression],
+      parameterClasses: Array[Class[_]],
+      parameterTypes: Array[TypeInformation[_]]): Array[GeneratedExpression] = {
+
+    val signatureTypes = parameterTypes.zipWithIndex.map {
+      case (t: GenericTypeInfo[_], i) =>
+        // we don't trust GenericType, like Row and BaseRow and LocalTime
+        val returnType = fromLogicalTypeToDataType(operands(i).resultType)
+        if (operands(i).resultType.supportsOutputConversion(t.getTypeClass)) {
+          returnType.bridgedTo(t.getTypeClass)
+        } else {
+          returnType
+        }
+      case (t, _) => fromLegacyInfoToDataType(t)
+    }
 
-    paramClasses.zipWithIndex.zip(operands).map { case ((paramClass, i), operandExpr) =>
+    parameterClasses.zipWithIndex.zip(operands).map { case ((paramClass, i), operandExpr) =>
       if (paramClass.isPrimitive) {
         operandExpr
       } else {
         val externalResultTerm = genToExternalIfNeeded(
-          ctx, signatureTypes(i), paramClass, operandExpr.resultTerm)
+          ctx, signatureTypes(i), operandExpr.resultTerm)
         val exprOrNull = s"${operandExpr.nullTerm} ? null : ($externalResultTerm)"
         operandExpr.copy(resultTerm = exprOrNull)
       }
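
Why the GenericTypeInfo branch re-derives the type (illustrative sketch, not PR code): for an eval(java.time.LocalTime) signature the TypeExtractor can only produce GenericTypeInfo[LocalTime], but the operand's logical TIME type knows how to bridge to exactly that class:

    import java.time.LocalTime
    import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType
    import org.apache.flink.table.types.logical.TimeType

    val operandType = new TimeType()
    val paramClass = classOf[LocalTime]
    val signatureType =
      if (operandType.supportsOutputConversion(paramClass)) {
        // TIME(0) bridged to java.time.LocalTime, matching the eval() signature
        fromLogicalTypeToDataType(operandType).bridgedTo(paramClass)
      } else {
        fromLogicalTypeToDataType(operandType) // keep the default conversion class
      }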
File: TableFunctionCallGen.scala

@@ -18,15 +18,12 @@
 
 package org.apache.flink.table.planner.codegen.calls
 
-import org.apache.flink.api.java.typeutils.GenericTypeInfo
 import org.apache.flink.table.functions.TableFunction
 import org.apache.flink.table.planner.codegen.CodeGenUtils.genToExternalIfNeeded
 import org.apache.flink.table.planner.codegen.GeneratedExpression.NEVER_NULL
+import org.apache.flink.table.planner.codegen.calls.ScalarFunctionCallGen.prepareFunctionArgs
 import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression}
 import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils.getEvalMethodSignature
-import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType
-import org.apache.flink.table.types.logical.LogicalType
-import org.apache.flink.table.types.utils.TypeConversions
 
 /**
  * Generates a call to user-defined [[TableFunction]].
@@ -62,32 +59,8 @@ class TableFunctionCallGen(tableFunction: TableFunction[_]) extends CallGenerator
       ctx: CodeGeneratorContext,
       operands: Seq[GeneratedExpression],
       func: TableFunction[_]): Array[GeneratedExpression] = {
-
     // get the expanded parameter types
     var paramClasses = getEvalMethodSignature(func, operands.map(_.resultType).toArray)
-
-    val signatureTypes = func
-      .getParameterTypes(paramClasses)
-      .zipWithIndex
-      .map {
-        case (t, i) =>
-          // we don't trust GenericType.
-          if (t.isInstanceOf[GenericTypeInfo[_]]) {
-            fromLogicalTypeToDataType(operands(i).resultType)
-          } else {
-            TypeConversions.fromLegacyInfoToDataType(t)
-          }
-      }
-
-    paramClasses.zipWithIndex.zip(operands).map { case ((paramClass, i), operandExpr) =>
-      if (paramClass.isPrimitive) {
-        operandExpr
-      } else {
-        val externalResultTerm = genToExternalIfNeeded(
-          ctx, signatureTypes(i), paramClass, operandExpr.resultTerm)
-        val exprOrNull = s"${operandExpr.nullTerm} ? null : ($externalResultTerm)"
-        operandExpr.copy(resultTerm = exprOrNull)
-      }
-    }
+    prepareFunctionArgs(ctx, operands, paramClasses, func.getParameterTypes(paramClasses))
   }
 }
File: UserDefinedFunctionUtils.scala

@@ -171,7 +171,12 @@ object UserDefinedFunctionUtils {
       case (t: DataType, i) =>
         // we don't trust GenericType.
         if (fromDataTypeToLogicalType(t).getTypeRoot == LogicalTypeRoot.ANY) {
-          fromLogicalTypeToDataType(expectedTypes(i))
+          val returnType = fromLogicalTypeToDataType(expectedTypes(i))
+          if (expectedTypes(i).supportsOutputConversion(t.getConversionClass)) {
+            returnType.bridgedTo(t.getConversionClass)
+          } else {
+            returnType
+          }
         } else {
           t
         }
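
The same bridging idea, shown for the ANY branch above (a hedged example with assumed types, not code from the PR): a parameter extracted as GenericTypeInfo[Row] maps to LogicalTypeRoot.ANY, while the expected logical row type can still bridge back to the declared Row class:

    import org.apache.flink.table.types.logical.{IntType, RowType}
    import org.apache.flink.types.Row

    val expected = RowType.of(new IntType())
    assert(expected.supportsOutputConversion(classOf[Row]))
    // so the signature becomes fromLogicalTypeToDataType(expected)
    //   .bridgedTo(classOf[Row]) and the UDF receives Row, not an internal BaseRow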
File: PhysicalTableSourceScan.scala

@@ -18,9 +18,15 @@
 
 package org.apache.flink.table.planner.plan.nodes.physical
 
+import org.apache.flink.api.common.io.InputFormat
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.dag.Transformation
+import org.apache.flink.core.io.InputSplit
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory
 import org.apache.flink.table.planner.plan.schema.{FlinkRelOptTable, TableSourceTable}
-import org.apache.flink.table.sources.TableSource
+import org.apache.flink.table.sources.{InputFormatTableSource, StreamTableSource, TableSource}
+import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo
 
 import org.apache.calcite.plan.{RelOptCluster, RelTraitSet}
 import org.apache.calcite.rel.RelWriter
@@ -38,6 +44,9 @@ abstract class PhysicalTableSourceScan(
     relOptTable: FlinkRelOptTable)
   extends TableScan(cluster, traitSet, relOptTable) {
 
+  // cache table source transformation.
+  protected var sourceTransform: Transformation[_] = _
+
   protected val tableSourceTable: TableSourceTable[_] =
     relOptTable.unwrap(classOf[TableSourceTable[_]])

@@ -52,4 +61,19 @@ abstract class PhysicalTableSourceScan(
     super.explainTerms(pw).item("fields", getRowType.getFieldNames.asScala.mkString(", "))
   }
 
+  def getSourceTransformation(
+      streamEnv: StreamExecutionEnvironment): Transformation[_] = {
+    if (sourceTransform == null) {
+      sourceTransform = tableSource match {
+        case source: InputFormatTableSource[_] =>

[Inline review comment from a project member]: I think the problem is that InputFormatTableSource#getDataStream shouldn't call getReturnType; the following code should be moved there.

+          val resultType = fromDataTypeToLegacyInfo(source.getProducedDataType)
+            .asInstanceOf[TypeInformation[Any]]
+          streamEnv.createInput(
+            source.getInputFormat.asInstanceOf[InputFormat[Any, _ <: InputSplit]],
+            resultType).getTransformation
+        case source: StreamTableSource[_] => source.getDataStream(streamEnv).getTransformation
+      }
+    }
+    sourceTransform
+  }
 }
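
A sketch of what the reviewer's suggestion could look like (an assumed follow-up, not code from this PR): InputFormatTableSource#getDataStream would build the stream from getProducedDataType itself, so the scan node no longer needs the InputFormat special case. The helper name dataStreamFor is hypothetical:

    import org.apache.flink.api.common.io.InputFormat
    import org.apache.flink.api.common.typeinfo.TypeInformation
    import org.apache.flink.core.io.InputSplit
    import org.apache.flink.streaming.api.datastream.DataStream
    import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
    import org.apache.flink.table.sources.InputFormatTableSource
    import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo

    def dataStreamFor[T](
        source: InputFormatTableSource[T],
        execEnv: StreamExecutionEnvironment): DataStream[T] = {
      // derive the legacy TypeInformation from the produced DataType,
      // instead of calling the deprecated getReturnType
      val resultType = fromDataTypeToLegacyInfo(source.getProducedDataType)
        .asInstanceOf[TypeInformation[T]]
      execEnv.createInput(
        source.getInputFormat.asInstanceOf[InputFormat[T, _ <: InputSplit]],
        resultType)
    }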
File: BatchExecBoundedStreamScan.scala

@@ -110,9 +110,7 @@ class BatchExecBoundedStreamScan(
 
   def needInternalConversion: Boolean = {
     ScanUtil.hasTimeAttributeField(boundedStreamTable.fieldIndexes) ||
-      ScanUtil.needsConversion(
-        boundedStreamTable.dataType,
-        boundedStreamTable.dataStream.getType.getTypeClass)
+      ScanUtil.needsConversion(boundedStreamTable.dataType)
   }
 
 }
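
Since ScanUtil itself is not shown in this diff, here is a hedged reading of the new single-argument check: needsConversion can now derive the physical class from the DataType, presumably reducing to !isInternalClass(dataType):

    import org.apache.flink.table.api.DataTypes
    import org.apache.flink.table.dataformat.BaseRow

    val rowType = DataTypes.ROW(DataTypes.FIELD("id", DataTypes.BIGINT()))
    // needsConversion(rowType) -> true: external Row, a converter is required
    val internalType = rowType.bridgedTo(classOf[BaseRow])
    // needsConversion(internalType) -> false: the scan can emit BaseRow directly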
File: BatchExecSink.scala

@@ -142,8 +142,7 @@ class BatchExecSink[T](
       // Sink's input must be BatchExecNode[BaseRow] now.
       case node: BatchExecNode[BaseRow] =>
         val plan = node.translateToPlan(planner)
-        val typeClass = extractTableSinkTypeClass(sink)
-        if (CodeGenUtils.isInternalClass(typeClass, resultDataType)) {
+        if (CodeGenUtils.isInternalClass(resultDataType)) {
           plan.asInstanceOf[Transformation[T]]
         } else {
           val (converterOperator, outputTypeInfo) = generateRowConverterOperator[T](
File: BatchExecTableSourceScan.scala

@@ -19,9 +19,7 @@
 package org.apache.flink.table.planner.plan.nodes.physical.batch
 
 import org.apache.flink.api.dag.Transformation
-import org.apache.flink.api.java.typeutils.TypeExtractor
 import org.apache.flink.runtime.operators.DamBehavior
-import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment
 import org.apache.flink.table.api.TableException
 import org.apache.flink.table.dataformat.BaseRow
 import org.apache.flink.table.planner.codegen.CodeGeneratorContext
@@ -55,9 +53,6 @@ class BatchExecTableSourceScan(
   with BatchPhysicalRel
   with BatchExecNode[BaseRow]{
 
-  // cache table source transformation.
-  private var sourceTransform: Transformation[_] = _
-
   override def copy(traitSet: RelTraitSet, inputs: util.List[RelNode]): RelNode = {
     new BatchExecTableSourceScan(cluster, traitSet, relOptTable)
   }
@@ -85,15 +80,6 @@ class BatchExecTableSourceScan(
     replaceInput(ordinalInParent, newInputNode.asInstanceOf[RelNode])
   }
 
-  def getSourceTransformation(
-      streamEnv: StreamExecutionEnvironment): Transformation[_] = {
-    if (sourceTransform == null) {
-      sourceTransform = tableSource.asInstanceOf[StreamTableSource[_]].
-        getDataStream(streamEnv).getTransformation
-    }
-    sourceTransform
-  }
-
   override protected def translateToPlanInternal(
       planner: BatchPlanner): Transformation[BaseRow] = {
     val config = planner.getTableConfig
@@ -147,11 +133,7 @@ class BatchExecTableSourceScan(
       isStreamTable = false,
       tableSourceTable.selectedFields)
     ScanUtil.hasTimeAttributeField(fieldIndexes) ||
-      ScanUtil.needsConversion(
-        tableSource.getProducedDataType,
-        TypeExtractor.createTypeInfo(
-          tableSource, classOf[StreamTableSource[_]], tableSource.getClass, 0)
-          .getTypeClass.asInstanceOf[Class[_]])
+      ScanUtil.needsConversion(tableSource.getProducedDataType)
   }
 
   def getEstimatedRowCount: lang.Double = {
File: StreamExecDataStreamScan.scala

@@ -116,9 +116,7 @@ class StreamExecDataStreamScan(
 
     // when there is row time extraction expression, we need internal conversion
    // when the physical type of the input date stream is not BaseRow, we need internal conversion.
-    if (rowtimeExpr.isDefined || ScanUtil.needsConversion(
-      dataStreamTable.dataType,
-      dataStreamTable.dataStream.getType.getTypeClass)) {
+    if (rowtimeExpr.isDefined || ScanUtil.needsConversion(dataStreamTable.dataType)) {
 
       // extract time if the index is -1 or -2.
       val (extractElement, resetElement) =
File: StreamExecSink.scala

@@ -24,7 +24,7 @@ import org.apache.flink.streaming.api.transformations.OneInputTransformation
 import org.apache.flink.table.api.{Table, TableException}
 import org.apache.flink.table.dataformat.BaseRow
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory
-import org.apache.flink.table.planner.codegen.SinkCodeGenerator.{extractTableSinkTypeClass, generateRowConverterOperator}
+import org.apache.flink.table.planner.codegen.SinkCodeGenerator.generateRowConverterOperator
 import org.apache.flink.table.planner.codegen.{CodeGenUtils, CodeGeneratorContext}
 import org.apache.flink.table.planner.delegation.StreamPlanner
 import org.apache.flink.table.planner.plan.`trait`.{AccMode, AccModeTraitDef}
@@ -213,8 +213,7 @@ class StreamExecSink[T](
       }
       val resultDataType = sink.getConsumedDataType
       val resultType = fromDataTypeToLegacyInfo(resultDataType)
-      val typeClass = extractTableSinkTypeClass(sink)
-      if (CodeGenUtils.isInternalClass(typeClass, resultDataType)) {
+      if (CodeGenUtils.isInternalClass(resultDataType)) {
         parTransformation.asInstanceOf[Transformation[T]]
       } else {
         val (converterOperator, outputTypeInfo) = generateRowConverterOperator[T](