
[FLINK-5280] [table] Refactor TableSource interface.

This closes #3039.
1 parent d4d7cc3 · commit 38ded2bb00aeb5c9581fa7ef313e5b9f803f5c26 · mushketyk committed with fhueske on Dec 22, 2016
Showing with 259 additions and 135 deletions.
  1. +3 −18 ...nector-kafka-base/src/main/java/org/apache/flink/streaming/connectors/kafka/KafkaTableSource.java
  2. +1 −1 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/BatchTableEnvironment.scala
  3. +1 −1 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/StreamTableEnvironment.scala
  4. +98 −37 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableEnvironment.scala
  5. +4 −18 .../flink-table/src/main/scala/org/apache/flink/table/functions/utils/UserDefinedFunctionUtils.scala
  6. +8 −3 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/FlinkRel.scala
  7. +5 −3 ...s/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/BatchTableSourceScan.scala
  8. +5 −3 ...ink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/StreamTableSourceScan.scala
  9. +3 −3 ...ink-table/src/main/scala/org/apache/flink/table/plan/rules/dataSet/BatchTableSourceScanRule.scala
  10. +2 −1 ...ain/scala/org/apache/flink/table/plan/rules/dataSet/PushProjectIntoBatchTableSourceScanRule.scala
  11. +2 −1 ...scala/org/apache/flink/table/plan/rules/datastream/PushProjectIntoStreamTableSourceScanRule.scala
  12. +3 −3 ...table/src/main/scala/org/apache/flink/table/plan/rules/datastream/StreamTableSourceScanRule.scala
  13. +3 −3 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/FlinkTable.scala
  14. +6 −7 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/schema/TableSourceTable.scala
  15. +1 −10 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/CsvTableSource.scala
  16. +35 −0 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/DefinedFieldNames.scala
  17. +11 −10 flink-libraries/flink-table/src/main/scala/org/apache/flink/table/sources/TableSource.scala
  18. +4 −7 ...-libraries/flink-table/src/test/java/org/apache/flink/table/api/java/batch/TableSourceITCase.java
  19. +5 −2 flink-libraries/flink-table/src/test/scala/org/apache/flink/table/TableEnvironmentTest.scala
  20. +19 −0 ...braries/flink-table/src/test/scala/org/apache/flink/table/api/scala/batch/TableSourceITCase.scala
  21. +40 −2 flink-libraries/flink-table/src/test/scala/org/apache/flink/table/utils/CommonTestData.scala
  22. +0 −2 flink-tests/src/test/scala/org/apache/flink/api/scala/io/CsvInputFormatTest.scala
@@ -19,12 +19,12 @@
package org.apache.flink.streaming.connectors.kafka;
import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.types.Row;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
-import org.apache.flink.table.sources.StreamTableSource;
import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.streaming.util.serialization.DeserializationSchema;
+import org.apache.flink.table.sources.StreamTableSource;
+import org.apache.flink.types.Row;
import org.apache.flink.util.Preconditions;
import java.util.Properties;
@@ -112,23 +112,8 @@
}
@Override
- public int getNumberOfFields() {
- return fieldNames.length;
- }
-
- @Override
- public String[] getFieldsNames() {
- return fieldNames;
- }
-
- @Override
- public TypeInformation<?>[] getFieldTypes() {
- return fieldTypes;
- }
-
- @Override
public TypeInformation<Row> getReturnType() {
- return new RowTypeInfo(fieldTypes);
+ return new RowTypeInfo(fieldTypes, fieldNames);
}
/**
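
The hunk above removes the per-source field accessors (getNumberOfFields, getFieldsNames, getFieldTypes): after the refactoring, field names and types travel with the RowTypeInfo returned from getReturnType. A minimal REPL-style Scala sketch of that pattern (the field names and types here are illustrative, not taken from KafkaTableSource):

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.java.typeutils.RowTypeInfo
import org.apache.flink.types.Row

// Names and types are bundled into the return type instead of being exposed
// through dedicated getter methods on the table source.
val fieldNames = Array("id", "name")
val fieldTypes: Array[TypeInformation[_]] = Array(
  BasicTypeInfo.LONG_TYPE_INFO,
  BasicTypeInfo.STRING_TYPE_INFO)

val returnType: TypeInformation[Row] = new RowTypeInfo(fieldTypes, fieldNames)
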
@@ -298,7 +298,7 @@ abstract class BatchTableEnvironment(
* @return The [[DataSet]] that corresponds to the translated [[Table]].
*/
protected def translate[A](logicalPlan: RelNode)(implicit tpe: TypeInformation[A]): DataSet[A] = {
- validateType(tpe)
+ TableEnvironment.validateType(tpe)
logicalPlan match {
case node: DataSetRel =>
@@ -308,7 +308,7 @@ abstract class StreamTableEnvironment(
protected def translate[A]
(logicalPlan: RelNode)(implicit tpe: TypeInformation[A]): DataStream[A] = {
- validateType(tpe)
+ TableEnvironment.validateType(tpe)
logicalPlan match {
case node: DataStreamRel =>
@@ -18,8 +18,8 @@
package org.apache.flink.table.api
-import _root_.java.util.concurrent.atomic.AtomicInteger
import _root_.java.lang.reflect.Modifier
+import _root_.java.util.concurrent.atomic.AtomicInteger
import org.apache.calcite.config.Lex
import org.apache.calcite.jdbc.CalciteSchema
@@ -32,7 +32,8 @@ import org.apache.calcite.sql.parser.SqlParser
import org.apache.calcite.sql.util.ChainedSqlOperatorTable
import org.apache.calcite.tools.{FrameworkConfig, Frameworks, RuleSet, RuleSets}
import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
-import org.apache.flink.api.java.typeutils.{PojoTypeInfo, RowTypeInfo, TupleTypeInfo}
+import org.apache.flink.api.common.typeutils.CompositeType
+import org.apache.flink.api.java.typeutils.{PojoTypeInfo, TupleTypeInfo}
import org.apache.flink.api.java.{ExecutionEnvironment => JavaBatchExecEnv}
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.api.scala.{ExecutionEnvironment => ScalaBatchExecEnv}
@@ -48,6 +49,7 @@ import org.apache.flink.table.functions.{ScalarFunction, TableFunction}
import org.apache.flink.table.plan.cost.DataSetCostFactory
import org.apache.flink.table.plan.schema.RelTable
import org.apache.flink.table.sinks.TableSink
+import org.apache.flink.table.sources.{DefinedFieldNames, TableSource}
import org.apache.flink.table.validate.FunctionCatalog
import _root_.scala.collection.JavaConverters._
@@ -336,48 +338,16 @@ abstract class TableEnvironment(val config: TableConfig) {
frameworkConfig
}
- protected def validateType(typeInfo: TypeInformation[_]): Unit = {
- val clazz = typeInfo.getTypeClass
- if ((clazz.isMemberClass && !Modifier.isStatic(clazz.getModifiers)) ||
- !Modifier.isPublic(clazz.getModifiers) ||
- clazz.getCanonicalName == null) {
- throw TableException(s"Class '$clazz' described in type information '$typeInfo' must be " +
- s"static and globally accessible.")
- }
- }
-
/**
* Returns field names and field positions for a given [[TypeInformation]].
*
- * Field names are automatically extracted for
- * [[org.apache.flink.api.common.typeutils.CompositeType]].
- * The method fails if inputType is not a
- * [[org.apache.flink.api.common.typeutils.CompositeType]].
- *
* @param inputType The TypeInformation to extract the field names and positions from.
* @tparam A The type of the TypeInformation.
* @return A tuple of two arrays holding the field names and corresponding field positions.
*/
protected[flink] def getFieldInfo[A](inputType: TypeInformation[A]):
- (Array[String], Array[Int]) =
- {
- validateType(inputType)
-
- val fieldNames: Array[String] = inputType match {
- case t: TupleTypeInfo[A] => t.getFieldNames
- case c: CaseClassTypeInfo[A] => c.getFieldNames
- case p: PojoTypeInfo[A] => p.getFieldNames
- case r: RowTypeInfo => r.getFieldNames
- case tpe =>
- throw new TableException(s"Type $tpe lacks explicit field naming")
- }
- val fieldIndexes = fieldNames.indices.toArray
-
- if (fieldNames.contains("*")) {
- throw new TableException("Field name can not be '*'.")
- }
-
- (fieldNames, fieldIndexes)
+ (Array[String], Array[Int]) = {
+ (TableEnvironment.getFieldNames(inputType), TableEnvironment.getFieldIndices(inputType))
}
/**
@@ -393,7 +363,7 @@ abstract class TableEnvironment(val config: TableConfig) {
inputType: TypeInformation[A],
exprs: Array[Expression]): (Array[String], Array[Int]) = {
- validateType(inputType)
+ TableEnvironment.validateType(inputType)
val indexedNames: Array[(Int, String)] = inputType match {
case a: AtomicType[A] =>
@@ -554,4 +524,95 @@ object TableEnvironment {
new ScalaStreamTableEnv(executionEnvironment, tableConfig)
}
+
+ /**
+ * Returns field names for a given [[TypeInformation]].
+ *
+   * @param inputType The TypeInformation to extract the field names from.
+   * @tparam A The type of the TypeInformation.
+   * @return An array holding the field names.
+ */
+ def getFieldNames[A](inputType: TypeInformation[A]): Array[String] = {
+ validateType(inputType)
+
+ val fieldNames: Array[String] = inputType match {
+ case t: CompositeType[_] => t.getFieldNames
+ case a: AtomicType[_] => Array("f0")
+ case tpe =>
+ throw new TableException(s"Currently only CompositeType and AtomicType are supported. " +
+ s"Type $tpe lacks explicit field naming")
+ }
+
+ if (fieldNames.contains("*")) {
+ throw new TableException("Field name can not be '*'.")
+ }
+
+ fieldNames
+ }
+
+ /**
+   * Validates that the class represented by the typeInfo is static and globally accessible.
+   *
+   * @param typeInfo The type to check.
+   * @throws TableException if the type does not meet these criteria.
+ */
+ def validateType(typeInfo: TypeInformation[_]): Unit = {
+ val clazz = typeInfo.getTypeClass
+ if ((clazz.isMemberClass && !Modifier.isStatic(clazz.getModifiers)) ||
+ !Modifier.isPublic(clazz.getModifiers) ||
+ clazz.getCanonicalName == null) {
+ throw TableException(s"Class '$clazz' described in type information '$typeInfo' must be " +
+ s"static and globally accessible.")
+ }
+ }
+
+ /**
+   * Returns field indices for a given [[TypeInformation]].
+   *
+   * @param inputType The TypeInformation to extract the field positions from.
+   * @return An array holding the field positions.
+ */
+ def getFieldIndices(inputType: TypeInformation[_]): Array[Int] = {
+ getFieldNames(inputType).indices.toArray
+ }
+
+ /**
+ * Returns field types for a given [[TypeInformation]].
+ *
+ * @param inputType The TypeInformation to extract field types from.
+ * @return An array holding the field types.
+ */
+ def getFieldTypes(inputType: TypeInformation[_]): Array[TypeInformation[_]] = {
+ validateType(inputType)
+
+ inputType match {
+ case t: CompositeType[_] => 0.until(t.getArity).map(t.getTypeAt(_)).toArray
+ case a: AtomicType[_] => Array(a.asInstanceOf[TypeInformation[_]])
+ case tpe =>
+ throw new TableException(s"Currently only CompositeType and AtomicType are supported.")
+ }
+ }
+
+ /**
+ * Returns field names for a given [[TableSource]].
+ *
+ * @param tableSource The TableSource to extract field names from.
+ * @tparam A The type of the TableSource.
+ * @return An array holding the field names.
+ */
+ def getFieldNames[A](tableSource: TableSource[A]): Array[String] = tableSource match {
+ case d: DefinedFieldNames => d.getFieldNames
+ case _ => TableEnvironment.getFieldNames(tableSource.getReturnType)
+ }
+
+ /**
+ * Returns field indices for a given [[TableSource]].
+ *
+ * @param tableSource The TableSource to extract field indices from.
+ * @tparam A The type of the TableSource.
+ * @return An array holding the field indices.
+ */
+ def getFieldIndices[A](tableSource: TableSource[A]): Array[Int] = tableSource match {
+ case d: DefinedFieldNames => d.getFieldIndices
+ case _ => TableEnvironment.getFieldIndices(tableSource.getReturnType)
+ }
}
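
The new DefinedFieldNames source file is listed in the change summary, but its hunk is not shown on this page. Judging only from how TableEnvironment pattern-matches on it above (d.getFieldNames and d.getFieldIndices), the trait presumably has roughly the shape sketched below; a source that wants names other than those derived from its return type would mix it in. The CustomSource class and its fields are hypothetical, not part of the commit.

import org.apache.flink.table.sources.BatchTableSource
import org.apache.flink.types.Row

// Plausible shape of the new trait, inferred from the pattern matches above.
trait DefinedFieldNames {
  /** Field names the planner should use instead of those derived from the return type. */
  def getFieldNames: Array[String]

  /** Positions of those fields within the return type. */
  def getFieldIndices: Array[Int]
}

// Hypothetical source overriding the derived names; the remaining TableSource
// methods are left abstract to keep the sketch short.
abstract class CustomSource extends BatchTableSource[Row] with DefinedFieldNames {
  override def getFieldNames: Array[String] = Array("user_id", "user_name")
  override def getFieldIndices: Array[Int] = Array(0, 1)
}
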
@@ -29,7 +29,7 @@ import org.apache.flink.api.common.typeinfo.{AtomicType, TypeInformation}
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils.TypeExtractor
import org.apache.flink.table.calcite.FlinkTypeFactory
-import org.apache.flink.table.api.{TableException, ValidationException}
+import org.apache.flink.table.api.{TableEnvironment, TableException, ValidationException}
import org.apache.flink.table.functions.{ScalarFunction, TableFunction, UserDefinedFunction}
import org.apache.flink.table.plan.schema.FlinkTableFunctionImpl
import org.apache.flink.util.InstantiationUtil
@@ -268,23 +268,9 @@ object UserDefinedFunctionUtils {
def getFieldInfo(inputType: TypeInformation[_])
: (Array[String], Array[Int], Array[TypeInformation[_]]) = {
- val fieldNames: Array[String] = inputType match {
- case t: CompositeType[_] => t.getFieldNames
- case a: AtomicType[_] => Array("f0")
- case tpe =>
- throw new TableException(s"Currently only CompositeType and AtomicType are supported. " +
- s"Type $tpe lacks explicit field naming")
- }
- val fieldIndexes = fieldNames.indices.toArray
- val fieldTypes: Array[TypeInformation[_]] = fieldNames.map { i =>
- inputType match {
- case t: CompositeType[_] => t.getTypeAt(i).asInstanceOf[TypeInformation[_]]
- case a: AtomicType[_] => a.asInstanceOf[TypeInformation[_]]
- case tpe =>
- throw new TableException(s"Currently only CompositeType and AtomicType are supported.")
- }
- }
- (fieldNames, fieldIndexes, fieldTypes)
+ (TableEnvironment.getFieldNames(inputType),
+ TableEnvironment.getFieldIndices(inputType),
+ TableEnvironment.getFieldTypes(inputType))
}
/**
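
getFieldInfo in UserDefinedFunctionUtils now simply delegates to the new TableEnvironment helpers. A short REPL-style sketch of what those helpers return, following the CompositeType/AtomicType cases added above (the "_1"/"_2" names assume Flink's usual field naming for Scala tuples):

import org.apache.flink.api.scala._
import org.apache.flink.table.api.TableEnvironment

val tupleType = createTypeInformation[(Int, String)]
TableEnvironment.getFieldNames(tupleType)   // Array("_1", "_2")
TableEnvironment.getFieldIndices(tupleType) // Array(0, 1)
TableEnvironment.getFieldTypes(tupleType)   // Array(Int type info, String type info)

val longType = createTypeInformation[Long]
TableEnvironment.getFieldNames(longType)    // Array("f0"), the AtomicType fallback
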
@@ -18,7 +18,9 @@
package org.apache.flink.table.plan.nodes
-import org.apache.calcite.rel.`type`.RelDataType
+import java.util
+
+import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeField}
import org.apache.calcite.rex._
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.flink.api.common.functions.MapFunction
@@ -103,10 +105,12 @@ trait FlinkRel {
}
+
private[flink] def estimateRowSize(rowType: RelDataType): Double = {
+ val fieldList = rowType.getFieldList
- rowType.getFieldList.map(_.getType.getSqlTypeName).foldLeft(0) { (s, t) =>
- t match {
+ fieldList.map(_.getType.getSqlTypeName).zipWithIndex.foldLeft(0) { (s, t) =>
+ t._1 match {
case SqlTypeName.TINYINT => s + 1
case SqlTypeName.SMALLINT => s + 2
case SqlTypeName.INTEGER => s + 4
@@ -120,6 +124,7 @@ trait FlinkRel {
case typeName if SqlTypeName.YEAR_INTERVAL_TYPES.contains(typeName) => s + 8
case typeName if SqlTypeName.DAY_INTERVAL_TYPES.contains(typeName) => s + 4
case SqlTypeName.TIME | SqlTypeName.TIMESTAMP | SqlTypeName.DATE => s + 12
+ case SqlTypeName.ROW => s + estimateRowSize(fieldList.get(t._2).getType()).asInstanceOf[Int]
case _ => throw TableException(s"Unsupported data type encountered: $t")
}
}
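
The new SqlTypeName.ROW case makes the size estimate recurse into nested row types. With the byte weights listed above, a row typed as (INTEGER, ROW(SMALLINT, TIMESTAMP)), for example, is now estimated at 4 + (2 + 12) = 18 bytes instead of hitting the "Unsupported data type" exception.
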
@@ -23,7 +23,7 @@ import org.apache.calcite.rel.metadata.RelMetadataQuery
import org.apache.calcite.rel.{RelNode, RelWriter}
import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.DataSet
-import org.apache.flink.table.api.BatchTableEnvironment
+import org.apache.flink.table.api.{BatchTableEnvironment, TableEnvironment}
import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.schema.TableSourceTable
import org.apache.flink.table.sources.BatchTableSource
@@ -38,7 +38,9 @@ class BatchTableSourceScan(
override def deriveRowType() = {
val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
- flinkTypeFactory.buildRowDataType(tableSource.getFieldsNames, tableSource.getFieldTypes)
+ flinkTypeFactory.buildRowDataType(
+ TableEnvironment.getFieldNames(tableSource),
+ TableEnvironment.getFieldTypes(tableSource.getReturnType))
}
override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
@@ -57,7 +59,7 @@ class BatchTableSourceScan(
override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw)
- .item("fields", tableSource.getFieldsNames.mkString(", "))
+ .item("fields", TableEnvironment.getFieldNames(tableSource).mkString(", "))
}
override def translateToPlan(
@@ -26,7 +26,7 @@ import org.apache.flink.table.calcite.FlinkTypeFactory
import org.apache.flink.table.plan.schema.TableSourceTable
import org.apache.flink.table.sources.StreamTableSource
import org.apache.flink.streaming.api.datastream.DataStream
-import org.apache.flink.table.api.StreamTableEnvironment
+import org.apache.flink.table.api.{StreamTableEnvironment, TableEnvironment}
/** Flink RelNode to read data from an external source defined by a [[StreamTableSource]]. */
class StreamTableSourceScan(
@@ -38,7 +38,9 @@ class StreamTableSourceScan(
override def deriveRowType() = {
val flinkTypeFactory = cluster.getTypeFactory.asInstanceOf[FlinkTypeFactory]
- flinkTypeFactory.buildRowDataType(tableSource.getFieldsNames, tableSource.getFieldTypes)
+ flinkTypeFactory.buildRowDataType(
+ TableEnvironment.getFieldNames(tableSource),
+ TableEnvironment.getFieldTypes(tableSource.getReturnType))
}
override def computeSelfCost (planner: RelOptPlanner, metadata: RelMetadataQuery): RelOptCost = {
@@ -57,7 +59,7 @@ class StreamTableSourceScan(
override def explainTerms(pw: RelWriter): RelWriter = {
super.explainTerms(pw)
- .item("fields", tableSource.getFieldsNames.mkString(", "))
+ .item("fields", TableEnvironment.getFieldNames(tableSource).mkString(", "))
}
override def translateToPlan(
@@ -39,9 +39,9 @@ class BatchTableSourceScanRule
/** Rule must only match if TableScan targets a [[BatchTableSource]] */
override def matches(call: RelOptRuleCall): Boolean = {
val scan: TableScan = call.rel(0).asInstanceOf[TableScan]
- val dataSetTable = scan.getTable.unwrap(classOf[TableSourceTable])
+ val dataSetTable = scan.getTable.unwrap(classOf[TableSourceTable[_]])
dataSetTable match {
- case tst: TableSourceTable =>
+ case tst: TableSourceTable[_] =>
tst.tableSource match {
case _: BatchTableSource[_] =>
true
@@ -57,7 +57,7 @@ class BatchTableSourceScanRule
val scan: TableScan = rel.asInstanceOf[TableScan]
val traitSet: RelTraitSet = rel.getTraitSet.replace(DataSetConvention.INSTANCE)
- val tableSource = scan.getTable.unwrap(classOf[TableSourceTable]).tableSource
+ val tableSource = scan.getTable.unwrap(classOf[TableSourceTable[_]]).tableSource
.asInstanceOf[BatchTableSource[_]]
new BatchTableSourceScan(
rel.getCluster,